dialogue_manager.dialogue_policy.a2c_dialogue_policy
A deep dialogue policy based on the advantage actor-critic (A2C) algorithm.
Module Contents
Classes
- class dialogue_manager.dialogue_policy.a2c_dialogue_policy.A2CDialoguePolicy(input_size: int, hidden_size: int, output_size: int, possible_actions: List[Any], num_timesteps: int | None = None, n_envs: int = 1)
Bases: moviebot.dialogue_manager.dialogue_policy.neural_dialogue_policy.NeuralDialoguePolicy
- forward(state: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Forward pass.
- Parameters:
state – A batch of dialogue state vectors.
- Returns:
A tuple containing the actor output and the critic output.
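A minimal sketch of what such a forward pass might look like, assuming separate two-layer MLPs for the actor and the critic built from `input_size`, `hidden_size`, and `output_size`. The class name `ActorCriticSketch` and the layer layout are hypothetical and not necessarily the module's actual architecture.

```python
from typing import Tuple

import torch
import torch.nn as nn


class ActorCriticSketch(nn.Module):
    """Hypothetical actor-critic network; the real layer layout may differ."""

    def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
        super().__init__()
        # Actor: maps a dialogue state to one logit per possible action.
        self.actor = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )
        # Critic: maps a dialogue state to a scalar state-value estimate.
        self.critic = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # state: [batch, input_size] -> logits: [batch, output_size], values: [batch, 1]
        return self.actor(state), self.critic(state)


net = ActorCriticSketch(input_size=8, hidden_size=16, output_size=4)
logits, values = net(torch.randn(3, 8))
print(logits.shape, values.shape)  # torch.Size([3, 4]) torch.Size([3, 1])
```

Keeping the actor and critic as separate networks means their gradients never mix; a shared trunk with two heads is an equally common design.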
- select_action(state: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
Selects an action for the given dialogue state.
- Parameters:
state – The dialogue state, represented as a vector.
- Returns:
The selected action id, its log probability, the state value, and the entropy.
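A minimal sketch of how an A2C policy commonly turns actor logits into the four returned values, assuming the action is sampled from a categorical distribution over the logits (the documentation says only that an action is "selected"). The function name `sample_action` is hypothetical.

```python
from typing import Tuple

import torch
from torch.distributions import Categorical


def sample_action(
    logits: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # Build a categorical distribution over actions from the actor's logits.
    dist = Categorical(logits=logits)
    # Sample one action id per state in the batch.
    action = dist.sample()
    # Log probability of the sampled action, used by the policy-gradient term.
    log_prob = dist.log_prob(action)
    # Entropy of the distribution, used as an exploration bonus in the actor loss.
    entropy = dist.entropy()
    # The critic's value is passed through unchanged.
    return action, log_prob, value, entropy


torch.manual_seed(0)
logits = torch.tensor([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])
value = torch.tensor([[0.3], [-0.1]])
action, log_prob, value_out, entropy = sample_action(logits, value)
```

The second row has uniform logits, so its entropy is log(3), the maximum for three actions.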
- get_losses(rewards: torch.Tensor, action_log_probs: torch.Tensor, value_preds: torch.Tensor, entropy: torch.Tensor, mask: torch.Tensor, gamma: float = 0.99, lam: float = 0.95, entropy_coef: float = 0.01) → Tuple[torch.Tensor, torch.Tensor]
Computes the critic and actor losses for a minibatch using generalized advantage estimation (GAE).
- Parameters:
rewards – The rewards received at each timestep.
action_log_probs – The log probabilities of the selected actions.
value_preds – The state values predicted by the critic.
entropy – The entropy of the policy's action distribution.
mask – The mask.
gamma – The discount factor. Defaults to 0.99.
lam – The GAE parameter λ (1 corresponds to Monte Carlo returns, 0 to one-step TD learning). Defaults to 0.95.
entropy_coef – The entropy coefficient. Defaults to 0.01.
- Returns:
The critic and actor losses for the minibatch.
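A minimal sketch of the GAE recursion and the two losses, assuming all inputs have shape [T, n_envs], that mask[t] is 0 when an episode ends at step t (1 otherwise), and that the value after the last step is bootstrapped as 0. The recursion is delta_t = r_t + gamma * m_t * V(s_{t+1}) - V(s_t) and A_t = delta_t + gamma * lam * m_t * A_{t+1}. The helper names `compute_gae` and `a2c_losses`, the mask convention, and the exact loss forms are assumptions, not necessarily what `get_losses` does.

```python
from typing import Tuple

import torch


def compute_gae(
    rewards: torch.Tensor,
    value_preds: torch.Tensor,
    mask: torch.Tensor,
    gamma: float = 0.99,
    lam: float = 0.95,
) -> torch.Tensor:
    """Returns GAE advantages; all inputs have shape [T, n_envs]."""
    T = rewards.shape[0]
    gae = torch.zeros_like(rewards[0])
    advantages = []
    # Iterate backwards so each step reuses the advantage of the step after it.
    for t in reversed(range(T)):
        # Assumption: the value after the last step is bootstrapped as 0.
        next_value = value_preds[t + 1] if t + 1 < T else torch.zeros_like(value_preds[t])
        # Assumption: mask[t] == 0 marks an episode end and cuts the bootstrap.
        delta = rewards[t] + gamma * mask[t] * next_value - value_preds[t]
        gae = delta + gamma * lam * mask[t] * gae
        advantages.append(gae)
    advantages.reverse()
    return torch.stack(advantages)


def a2c_losses(
    rewards: torch.Tensor,
    action_log_probs: torch.Tensor,
    value_preds: torch.Tensor,
    entropy: torch.Tensor,
    mask: torch.Tensor,
    gamma: float = 0.99,
    lam: float = 0.95,
    entropy_coef: float = 0.01,
) -> Tuple[torch.Tensor, torch.Tensor]:
    advantages = compute_gae(rewards, value_preds, mask, gamma, lam)
    # Critic: mean squared advantage, pushing values toward the GAE targets.
    critic_loss = advantages.pow(2).mean()
    # Actor: policy gradient with detached advantages, minus an entropy bonus.
    actor_loss = (
        -(advantages.detach() * action_log_probs).mean()
        - entropy_coef * entropy.mean()
    )
    return critic_loss, actor_loss


# Two steps in one environment; the episode ends at step 1.
rewards = torch.tensor([[1.0], [1.0]])
value_preds = torch.zeros(2, 1)
mask = torch.tensor([[1.0], [0.0]])
adv = compute_gae(rewards, value_preds, mask, gamma=0.9, lam=1.0)
# adv holds 1.9 for step 0 (1 + 0.9 * 1) and 1.0 for step 1.
critic_loss, actor_loss = a2c_losses(
    rewards, torch.zeros(2, 1), value_preds, torch.ones(2, 1), mask, gamma=0.9, lam=1.0
)
```

Detaching the advantages in the actor loss keeps the policy-gradient term from updating the critic; the critic is trained only through its own loss.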
- update_parameters(critic_loss: torch.Tensor, actor_loss: torch.Tensor) → None
Updates the policy parameters using the given critic and actor losses.
- Parameters:
critic_loss – The critic loss.
actor_loss – The actor loss.
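A minimal sketch of one common way to apply the two losses, assuming the actor and critic are separate networks, each with its own optimizer. The optimizer arguments, network shapes, and toy losses are hypothetical; the documented method takes only the two losses, so the real class presumably keeps its optimizers internally.

```python
import torch
import torch.nn as nn


def update_parameters(
    critic_optim: torch.optim.Optimizer,
    actor_optim: torch.optim.Optimizer,
    critic_loss: torch.Tensor,
    actor_loss: torch.Tensor,
) -> None:
    # Critic step: clear stale gradients, backpropagate, apply the update.
    critic_optim.zero_grad()
    critic_loss.backward()
    critic_optim.step()
    # Actor step: same pattern with the actor's own optimizer.
    actor_optim.zero_grad()
    actor_loss.backward()
    actor_optim.step()


torch.manual_seed(0)
actor = nn.Linear(4, 3)
critic = nn.Linear(4, 1)
actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)
critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-3)

state = torch.randn(5, 4)
# Toy losses standing in for the output of get_losses.
critic_loss = critic(state).pow(2).mean()
actor_loss = -torch.log_softmax(actor(state), dim=-1)[:, 0].mean()

actor_before = actor.weight.detach().clone()
critic_before = critic.weight.detach().clone()
update_parameters(critic_optim, actor_optim, critic_loss, actor_loss)
```

Because the two networks share no parameters, each `backward()` call runs through a separate graph. If the actor and critic shared layers, the second `backward()` would typically fail unless the first used `retain_graph=True`; summing the two losses into a single backward pass avoids that.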
- save_policy(path: str) → None
Saves the policy.
- Parameters:
path – The path to save the policy to.
- classmethod load_policy(path: str) → A2CDialoguePolicy
Loads the policy.
- Parameters:
path – The path to load the policy from.
- Returns:
The loaded policy.
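A minimal sketch of one way to pair `save_policy` with a `load_policy` classmethod, assuming the checkpoint stores the constructor arguments alongside the `state_dict` so the model can be rebuilt before its weights are loaded. The class `SavablePolicy` and the checkpoint keys `init_args` and `state_dict` are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SavablePolicy(nn.Module):
    """Hypothetical policy showing one way to implement save/load."""

    def __init__(self, input_size: int, output_size: int) -> None:
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.net = nn.Linear(input_size, output_size)

    def save_policy(self, path: str) -> None:
        # Store constructor arguments with the weights so load_policy can rebuild the model.
        torch.save(
            {
                "init_args": {"input_size": self.input_size, "output_size": self.output_size},
                "state_dict": self.state_dict(),
            },
            path,
        )

    @classmethod
    def load_policy(cls, path: str) -> "SavablePolicy":
        # Load onto the CPU so a checkpoint saved on a GPU still opens anywhere.
        checkpoint = torch.load(path, map_location="cpu")
        policy = cls(**checkpoint["init_args"])
        policy.load_state_dict(checkpoint["state_dict"])
        return policy


policy = SavablePolicy(input_size=4, output_size=2)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "policy.pt")
    policy.save_policy(path)
    restored = SavablePolicy.load_policy(path)
```

Saving a `state_dict` rather than pickling the whole module keeps the checkpoint independent of the class's import path. Storing custom Python objects (for example, entries of `possible_actions`) in the checkpoint would require `weights_only=False` when loading on recent PyTorch versions.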