dialogue_manager.dialogue_policy.a2c_dialogue_policy

Deep dialogue policy based on the advantage actor-critic (A2C) algorithm.

Module Contents

Classes

A2CDialoguePolicy

class dialogue_manager.dialogue_policy.a2c_dialogue_policy.A2CDialoguePolicy(input_size: int, hidden_size: int, output_size: int, possible_actions: List[Any], num_timesteps: int | None = None, n_envs: int = 1)

Bases: moviebot.dialogue_manager.dialogue_policy.neural_dialogue_policy.NeuralDialoguePolicy

forward(state: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

Forward pass.

Parameters:

state – A batch of dialogue state vectors.

Returns:

The outputs of the actor and the critic.
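A minimal sketch of what such a forward pass can look like follows. It assumes separate actor and critic MLPs that read the same dialogue state vector. The class name `ActorCritic` and the layer layout are hypothetical and are not taken from the moviebot implementation. The actor produces one logit per possible action, and the critic produces a single state-value estimate.

```python
from typing import Tuple

import torch
from torch import nn


class ActorCritic(nn.Module):
    """Hypothetical actor-critic network with separate actor and critic heads."""

    def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
        super().__init__()
        # Actor: maps a dialogue state to unnormalized scores (logits) over actions.
        self.actor = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )
        # Critic: maps a dialogue state to a scalar state-value estimate.
        self.critic = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # state has shape [batch_size, input_size].
        action_logits = self.actor(state)  # [batch_size, output_size]
        state_values = self.critic(state)  # [batch_size, 1]
        return action_logits, state_values


# Usage: a batch of three 8-dimensional states, four possible actions.
model = ActorCritic(input_size=8, hidden_size=16, output_size=4)
logits, values = model(torch.zeros(3, 8))
```

Keeping the two heads separate lets the actor and critic use different learning rates. Sharing a hidden layer between them is an equally common design.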

select_action(state: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

Selects an action for the given dialogue state.

Parameters:

state – Vector representation of the dialogue state.

Returns:

The selected action ID, its log probability, the critic's state-value estimate, and the entropy of the action distribution.
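The four return values can be sketched as follows. The sketch assumes the action is sampled from a categorical distribution over the actor's logits, using PyTorch's `torch.distributions.Categorical`. The helper name `select_action` is a hypothetical stand-alone function here, not the method's actual body.

```python
from typing import Tuple

import torch
from torch.distributions import Categorical


def select_action(
    action_logits: torch.Tensor, state_values: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical helper: sample one action per state from actor logits."""
    # One categorical distribution over actions for each state in the batch.
    dist = Categorical(logits=action_logits)
    action = dist.sample()  # [batch_size]
    log_prob = dist.log_prob(action)  # [batch_size]
    entropy = dist.entropy()  # [batch_size]
    # Drop the trailing singleton dimension of the critic output.
    return action, log_prob, state_values.squeeze(-1), entropy


torch.manual_seed(0)
# Row 0 is uniform over 4 actions; row 1 almost surely picks action 0.
logits = torch.tensor([[0.0, 0.0, 0.0, 0.0], [10.0, -10.0, -10.0, -10.0]])
values = torch.tensor([[0.5], [1.5]])
action, log_prob, value, entropy = select_action(logits, values)
```

For the uniform row, the entropy is ln 4 (about 1.386), which is the maximum for four actions. For the peaked row, the entropy is close to 0.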

get_losses(rewards: torch.Tensor, action_log_probs: torch.Tensor, value_preds: torch.Tensor, entropy: torch.Tensor, mask: torch.Tensor, gamma: float = 0.99, lam: float = 0.95, entropy_coef: float = 0.01) → Tuple[torch.Tensor, torch.Tensor]

Computes the critic and actor losses for a minibatch using generalized advantage estimation (GAE).

Parameters:
  • rewards – The reward received at each timestep.

  • action_log_probs – The log probabilities of the actions taken.

  • value_preds – The state values predicted by the critic.

  • entropy – The entropy of the policy's action distribution.

  • mask – The mask used to stop bootstrapping across episode boundaries.

  • gamma – The discount factor. Defaults to 0.99.

  • lam – The GAE parameter λ (1 gives Monte-Carlo returns, 0 gives one-step TD learning). Defaults to 0.95.

  • entropy_coef – The weight of the entropy bonus in the actor loss. Defaults to 0.01.

Returns:

The critic and actor losses for the minibatch.
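The standard GAE recursion computes a one-step TD error and then accumulates the advantage backwards in time:

δ_t = r_t + γ·m_t·V(s_{t+1}) - V(s_t)
A_t = δ_t + γ·λ·m_t·A_{t+1}

Here m_t is the mask. The sketch below applies this recursion to a single trajectory in plain Python so the arithmetic is easy to check. The function names, the `last_value` bootstrap argument, and the exact loss formulas (mean squared advantage for the critic; policy gradient plus entropy bonus for the actor) are assumptions. The moviebot implementation may differ in details such as batching over environments or how the final step is bootstrapped.

```python
from typing import List, Tuple


def gae_advantages(
    rewards: List[float],
    values: List[float],
    masks: List[float],
    gamma: float = 0.99,
    lam: float = 0.95,
    last_value: float = 0.0,
) -> List[float]:
    """Generalized advantage estimation for one trajectory (hypothetical helper).

    masks[t] is 0.0 if the episode ends after step t, else 1.0, so the
    recursion never bootstraps across an episode boundary.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else last_value
        # One-step TD error: r_t + gamma * m_t * V(s_{t+1}) - V(s_t).
        delta = rewards[t] + gamma * masks[t] * next_value - values[t]
        # A_t = delta_t + gamma * lam * m_t * A_{t+1}.
        gae = delta + gamma * lam * masks[t] * gae
        advantages[t] = gae
    return advantages


def a2c_losses(
    advantages: List[float],
    log_probs: List[float],
    entropies: List[float],
    entropy_coef: float = 0.01,
) -> Tuple[float, float]:
    """Hypothetical A2C losses; returns (critic_loss, actor_loss)."""
    n = len(advantages)
    # Critic: mean squared advantage, i.e. regress V(s_t) toward its GAE target.
    critic_loss = sum(a * a for a in advantages) / n
    # Actor: policy-gradient term. In PyTorch the advantages would be detached
    # here so that the actor loss does not update the critic.
    actor_loss = -sum(a * lp for a, lp in zip(advantages, log_probs)) / n
    # Entropy bonus: subtracting it rewards more exploratory policies.
    actor_loss -= entropy_coef * sum(entropies) / n
    return critic_loss, actor_loss


# With lam = 1 and gamma = 1, the advantages are Monte-Carlo returns minus values.
adv = gae_advantages([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [1.0, 1.0, 0.0], gamma=1.0, lam=1.0)
critic_loss, actor_loss = a2c_losses(adv, [-1.0, -1.0, -1.0], [1.0, 1.0, 1.0])
```

Setting `lam=0.0` in the same call reduces each advantage to its one-step TD error, which matches the λ = 0 case described for the `lam` parameter.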

update_parameters(critic_loss: torch.Tensor, actor_loss: torch.Tensor) → None

Updates the policy parameters using the given critic and actor losses.

Parameters:
  • critic_loss – The critic loss.

  • actor_loss – The actor loss.
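A common way to apply the two losses is to give the actor and critic their own optimizers and take one gradient step with each. The sketch below assumes separate Adam optimizers; the networks, learning rates, and placeholder losses are hypothetical and only stand in for the policy's real ones.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical actor and critic networks, each with its own optimizer.
actor = nn.Linear(8, 4)
critic = nn.Linear(8, 1)
actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)
critic_optim = torch.optim.Adam(critic.parameters(), lr=5e-3)


def update_parameters(critic_loss: torch.Tensor, actor_loss: torch.Tensor) -> None:
    # Critic step: clear stale gradients, backpropagate, apply the update.
    critic_optim.zero_grad()
    critic_loss.backward()
    critic_optim.step()
    # Actor step, taken separately with the actor's own optimizer.
    actor_optim.zero_grad()
    actor_loss.backward()
    actor_optim.step()


# Placeholder losses built from random states, only to exercise the update.
states = torch.randn(5, 8)
critic_loss = critic(states).pow(2).mean()
actor_loss = -torch.log_softmax(actor(states), dim=-1)[:, 0].mean()
critic_before = critic.weight.detach().clone()
actor_before = actor.weight.detach().clone()
update_parameters(critic_loss, actor_loss)
```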

save_policy(path: str) → None

Saves the policy.

Parameters:

path – The path to save the policy to.

classmethod load_policy(path: str) → A2CDialoguePolicy

Loads the policy.

Parameters:

path – The path to load the policy from.

Returns:

The loaded policy.
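A save/load pair like the one above can be written by storing the constructor arguments alongside the model's `state_dict`. The classmethod can then rebuild the object before loading the weights. This is a sketch under that assumption: `TinyPolicy` and the checkpoint keys are hypothetical, and the file format that moviebot actually writes is not documented here.

```python
import os
import tempfile
from typing import Any, Dict

import torch
from torch import nn


class TinyPolicy(nn.Module):
    """Hypothetical policy used to illustrate a save/load round trip."""

    def __init__(self, input_size: int, output_size: int) -> None:
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.net = nn.Linear(input_size, output_size)

    def save_policy(self, path: str) -> None:
        # Keep constructor arguments with the weights so the object can be rebuilt.
        checkpoint: Dict[str, Any] = {
            "input_size": self.input_size,
            "output_size": self.output_size,
            "state_dict": self.state_dict(),
        }
        torch.save(checkpoint, path)

    @classmethod
    def load_policy(cls, path: str) -> "TinyPolicy":
        checkpoint = torch.load(path)
        # Rebuild with the saved sizes, then restore the learned weights.
        policy = cls(checkpoint["input_size"], checkpoint["output_size"])
        policy.load_state_dict(checkpoint["state_dict"])
        return policy


policy = TinyPolicy(input_size=8, output_size=4)
path = os.path.join(tempfile.mkdtemp(), "policy.pt")
policy.save_policy(path)
restored = TinyPolicy.load_policy(path)
```

Making `load_policy` a classmethod means callers do not need to know the network sizes in advance, because they are read back from the checkpoint.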