dialogue_manager.dialogue_policy.dqn_dialogue_policy

Deep dialogue policy based on a Q-network (DQN).

Module Contents

Classes

DQNDialoguePolicy

class dialogue_manager.dialogue_policy.dqn_dialogue_policy.DQNDialoguePolicy(input_size: int, hidden_size: int, output_size: int, possible_actions: List[Any])

Bases: moviebot.dialogue_manager.dialogue_policy.neural_dialogue_policy.NeuralDialoguePolicy
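The constructor arguments suggest a network that maps a state vector of length input_size through a hidden layer of width hidden_size to output_size scores, one per entry in possible_actions. The dependency-free sketch below illustrates that reading. The single hidden layer, the weight initialization, and the check that output_size equals len(possible_actions) are all assumptions, not confirmed behavior of DQNDialoguePolicy. PolicyNetworkSketch and the action labels are hypothetical.

```python
import random
from typing import Any, List


class PolicyNetworkSketch:
    """Hypothetical stand-in for the DQNDialoguePolicy constructor.

    Assumes one hidden layer: input_size -> hidden_size -> output_size,
    with one output unit per entry in possible_actions.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int,
                 possible_actions: List[Any], seed: int = 0) -> None:
        # Assumption: each output unit corresponds to exactly one action.
        if output_size != len(possible_actions):
            raise ValueError(
                f"output_size ({output_size}) must equal "
                f"len(possible_actions) ({len(possible_actions)})"
            )
        rng = random.Random(seed)
        self.possible_actions = list(possible_actions)
        # Weight matrices are stored as lists of rows, shape (out, in).
        self.w1 = [[rng.uniform(-0.1, 0.1) for _ in range(input_size)]
                   for _ in range(hidden_size)]
        self.b1 = [0.0] * hidden_size
        self.w2 = [[rng.uniform(-0.1, 0.1) for _ in range(hidden_size)]
                   for _ in range(output_size)]
        self.b2 = [0.0] * output_size


# Hypothetical action labels, used only for illustration.
policy = PolicyNetworkSketch(4, 8, 3, ["inform", "recommend", "bye"])
print(len(policy.w1), len(policy.w1[0]))  # 8 4
print(len(policy.w2), len(policy.w2[0]))  # 3 8
```

Tying output_size to the action list makes an index produced by the network directly usable as a position in possible_actions.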

forward(state: torch.Tensor) → torch.Tensor

Forward pass of the policy.

Parameters:

state – State or batch of states.

Returns:

Probabilities of the next action(s).
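As a rough illustration of what forward computes, here is a pure-Python forward pass that accepts either a single state or a batch of states. It assumes a linear → ReLU → linear architecture with a softmax output, because the docstring describes the result as probabilities. A textbook DQN instead returns unnormalized Q-values, so whether this class applies a softmax is an assumption. The real method operates on torch.Tensor inputs rather than Python lists, and the toy weights are made up.

```python
import math
from typing import List, Sequence

Vector = List[float]
Matrix = List[List[float]]


def linear(x: Sequence[float], weights: Matrix, bias: Vector) -> Vector:
    """Computes W @ x + b for a single input vector."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, bias)]


def softmax(scores: Sequence[float]) -> Vector:
    """Numerically stable softmax: subtract the max score before exp()."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def forward_one(state: Sequence[float], w1: Matrix, b1: Vector,
                w2: Matrix, b2: Vector) -> Vector:
    """Linear -> ReLU -> linear -> softmax for one state vector."""
    hidden = [max(0.0, h) for h in linear(state, w1, b1)]
    return softmax(linear(hidden, w2, b2))


def forward(state, w1: Matrix, b1: Vector, w2: Matrix, b2: Vector):
    """Accepts one state (flat list) or a batch (list of flat lists)."""
    if state and isinstance(state[0], (list, tuple)):
        return [forward_one(s, w1, b1, w2, b2) for s in state]
    return forward_one(state, w1, b1, w2, b2)


# Toy weights: input_size=2, hidden_size=2, output_size=3.
W1, B1 = [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]
W2, B2 = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]], [0.0, 0.0, 0.0]

probs = forward([1.0, -1.0], W1, B1, W2, B2)
print(probs.index(max(probs)))  # 0
batch = forward([[0.0, 0.0], [1.0, -1.0]], W1, B1, W2, B2)
print(len(batch))  # 2
```

An all-zero state yields equal scores and therefore a uniform distribution over the three actions.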

select_action(state: torch.Tensor) → Tuple[int, Any]

Selects an action based on the current state.

Parameters:

state – The current state.

Returns:

The ID of the selected action and the action itself.
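The return value pairs an index into possible_actions with the action stored at that index. The sketch below shows one common way to make that choice in DQN-style agents, epsilon-greedy selection over per-action scores. Whether DQNDialoguePolicy.select_action explores at all is not stated here, so treat the epsilon parameter, and this function taking precomputed scores instead of a state, as assumptions. With epsilon=0.0 it reduces to a plain greedy argmax.

```python
import random
from typing import Any, List, Optional, Sequence, Tuple


def select_action(scores: Sequence[float], possible_actions: List[Any],
                  epsilon: float = 0.0,
                  rng: Optional[random.Random] = None) -> Tuple[int, Any]:
    """Returns (action_id, action) from per-action scores.

    With probability epsilon a uniformly random action is chosen
    (exploration); otherwise the highest-scoring action is chosen
    (exploitation).
    """
    if len(scores) != len(possible_actions):
        raise ValueError("one score per possible action is required")
    rng = rng or random.Random()
    if rng.random() < epsilon:
        action_id = rng.randrange(len(possible_actions))
    else:
        # max() returns the first maximal index, so ties go to the lowest id.
        action_id = max(range(len(scores)), key=lambda i: scores[i])
    return action_id, possible_actions[action_id]


actions = ["inform", "recommend", "bye"]  # hypothetical action labels
print(select_action([0.1, 0.7, 0.2], actions))  # (1, 'recommend')
```

Returning both the index and the action is convenient during training, where the index selects the Q-value to update and the action is what the dialogue manager executes.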

save_policy(path: str) → None

Saves the policy to a file.

Parameters:

path – The path to save the policy to.

classmethod load_policy(path: str) → DQNDialoguePolicy

Loads the policy from a file.

Parameters:

path – The path to load the policy from.

Returns:

The loaded policy.
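A saved policy needs enough information to rebuild the object: the constructor arguments and the learned weights. The round trip below sketches this with JSON and a hypothetical PolicySnapshot dataclass. The file format is an assumption; a torch-based class would more commonly call torch.save on a dictionary holding the model's state_dict and constructor arguments. JSON also requires possible_actions to be serializable, which the List[Any] annotation does not guarantee.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass, field
from typing import Any, List


@dataclass
class PolicySnapshot:
    """Hypothetical container: constructor arguments plus learned weights."""
    input_size: int
    hidden_size: int
    output_size: int
    possible_actions: List[Any]
    weights: dict = field(default_factory=dict)

    def save_policy(self, path: str) -> None:
        # Write every field needed to rebuild the policy to a JSON file.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(asdict(self), f)

    @classmethod
    def load_policy(cls, path: str) -> "PolicySnapshot":
        # Read the file back and pass the stored fields to the constructor.
        with open(path, encoding="utf-8") as f:
            return cls(**json.load(f))


snapshot = PolicySnapshot(4, 8, 3, ["inform", "recommend", "bye"],
                          weights={"b2": [0.0, 0.0, 0.0]})
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "policy.json")
    snapshot.save_policy(path)
    restored = PolicySnapshot.load_policy(path)
print(restored == snapshot)  # True
```

Making load_policy a classmethod, as in the documented API, lets callers rebuild a policy from a file without constructing an instance first.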