xdrl.trainer_hooks.checkpoints

Classes

PolicyCheckpointHook

Periodically checkpoint policy weights for offline analysis.

Module Contents

class xdrl.trainer_hooks.checkpoints.PolicyCheckpointHook(*, directory, interval, policy=None, policy_path=None, prefix='policy', destination='post_steps', meta=None)

Bases: torchrl.trainers.trainers.TrainerHookBase

Periodically checkpoint policy weights for offline analysis.

Parameters:
  • directory (str | pathlib.Path) – Directory where .pt checkpoint files are written.

  • interval (int) – Number of hook calls between checkpoints. Non-positive values disable file writes while still allowing the hook to be registered.

  • policy (torch.nn.Module | None) – Policy module to checkpoint. If omitted, policy_path is resolved from the trainer during registration.

  • policy_path (str | None) – Dotted path to the policy module on the trainer.

  • prefix (str) – Filename prefix used for checkpoint files.

  • destination (str) – TorchRL trainer operation where the hook is registered.

  • meta (dict[str, Any] | None) – Extra metadata persisted in every checkpoint payload.
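The policy_path parameter describes a dotted attribute path that is resolved against the trainer. A minimal sketch of how such a path could be followed, using only the standard library, is shown below. The helper resolve_dotted_path, the stand-in trainer object, and the path "loss_module.actor_network" are hypothetical and chosen for illustration; the hook's actual resolution logic is not shown in this reference.

```python
from functools import reduce
from types import SimpleNamespace


def resolve_dotted_path(root, dotted_path):
    """Follow each attribute name in ``dotted_path``, starting from ``root``."""
    return reduce(getattr, dotted_path.split("."), root)


# Hypothetical trainer-like object; a real trainer would hold actual modules.
policy = object()
trainer = SimpleNamespace(loss_module=SimpleNamespace(actor_network=policy))

# Walks trainer.loss_module, then .actor_network.
resolved = resolve_dotted_path(trainer, "loss_module.actor_network")
```

With this approach a misspelled segment raises AttributeError at resolution time, which is one reason to resolve the path once during registration rather than on every call.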

policy = None
policy_path = None
directory
interval
prefix = 'policy'
destination = 'post_steps'
meta
num_calls = 0
last_checkpoint_path: pathlib.Path | None = None
__call__()
Return type:

None
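The interval parameter and the num_calls attribute suggest a simple call counter. The sketch below shows one plausible reading of that schedule. It assumes calls are counted from 1 and that a write happens on every interval-th call; the function should_write is hypothetical, and the real hook may count differently.

```python
def should_write(num_calls, interval):
    """Return True when a checkpoint would be written on this call.

    Assumption: calls are counted from 1 and a write happens on every
    ``interval``-th call. Non-positive intervals never write.
    """
    if interval <= 0:
        return False
    return num_calls % interval == 0


# Calls 1..10 with interval=3, then with writes disabled (interval=0).
writes = [n for n in range(1, 11) if should_write(n, 3)]
disabled = [n for n in range(1, 11) if should_write(n, 0)]
```

Under these assumptions, writes contains calls 3, 6, and 9, and disabled is empty.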

register(trainer, name='policy_checkpoint_hook')

Registers the hook in the trainer at a default location.

Parameters:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

Return type:

None

Note

To register the hook at a location other than the default, use register_op().
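As a rough illustration of the note above, the sketch below registers a callable under an explicit destination name. RecordingTrainer and checkpoint_hook are hypothetical stand-ins so the example runs without TorchRL installed; only the register_op(dest, op) call shape and the "post_optim" destination name are taken from TorchRL's Trainer.

```python
class RecordingTrainer:
    """Hypothetical stand-in that only records register_op calls."""

    def __init__(self):
        self.ops = {}

    def register_op(self, dest, op, **kwargs):
        # Group registered callables by destination name.
        self.ops.setdefault(dest, []).append(op)


def checkpoint_hook():
    """Placeholder callable standing in for a PolicyCheckpointHook instance."""


trainer = RecordingTrainer()
# Register at "post_optim" instead of the documented default "post_steps".
trainer.register_op("post_optim", checkpoint_hook)
```

With a real trainer, the same call would place the hook in that destination's list of operations.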

state_dict()
Return type:

dict[str, Any]

load_state_dict(state_dict)
Parameters:

state_dict (dict[str, Any])

Return type:

None
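The state_dict() and load_state_dict() pair is typically used to resume a run without losing the hook's position in its schedule. The sketch below shows such a round trip on a stand-in class. It assumes the persisted state consists of num_calls and last_checkpoint_path; the class CounterState, the dictionary keys, and the file name policy_6.pt are hypothetical, and the real payload may differ.

```python
from pathlib import Path


class CounterState:
    """Hypothetical stand-in tracking the two documented runtime attributes."""

    def __init__(self):
        self.num_calls = 0
        self.last_checkpoint_path = None

    def state_dict(self):
        path = self.last_checkpoint_path
        # Store the path as a string so the payload stays plain data.
        return {
            "num_calls": self.num_calls,
            "last_checkpoint_path": None if path is None else str(path),
        }

    def load_state_dict(self, state_dict):
        self.num_calls = state_dict["num_calls"]
        path = state_dict["last_checkpoint_path"]
        self.last_checkpoint_path = None if path is None else Path(path)


saved = CounterState()
saved.num_calls = 6
saved.last_checkpoint_path = Path("checkpoints") / "policy_6.pt"

# A fresh instance picks up where the saved one left off.
restored = CounterState()
restored.load_state_dict(saved.state_dict())
```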