xdrl.trainer_hooks.checkpoints
Classes
PolicyCheckpointHook | Periodically checkpoint policy weights for offline analysis.
Module Contents
- class xdrl.trainer_hooks.checkpoints.PolicyCheckpointHook(*, directory, interval, policy=None, policy_path=None, prefix='policy', destination='post_steps', meta=None)
Bases: torchrl.trainers.trainers.TrainerHookBase
Periodically checkpoint policy weights for offline analysis.
- Parameters:
directory (str | pathlib.Path) – Directory where .pt checkpoint files are written.
interval (int) – Number of hook calls between checkpoints. Non-positive values disable file writes while still allowing the hook to be registered.
policy (torch.nn.Module | None) – Policy module to checkpoint. If omitted, policy_path is resolved from the trainer during registration.
policy_path (str | None) – Dotted path to the policy module on the trainer.
prefix (str) – Filename prefix used for checkpoint files.
destination (str) – TorchRL trainer operation where the hook is registered.
meta (dict[str, Any] | None) – Extra metadata persisted in every checkpoint payload.
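The interval and policy_path semantics described in the parameter list can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the helper names resolve_dotted and due_calls and the trainer layout are hypothetical, and the assumption that a checkpoint lands on every interval-th call (rather than also on the first call) is not confirmed by the documentation above.

```python
from functools import reduce
from types import SimpleNamespace


def resolve_dotted(root, path):
    """Follow a dotted attribute path such as "loss_module.actor" from root."""
    return reduce(getattr, path.split("."), root)


def due_calls(interval, total_calls):
    """Return the 1-based hook calls on which a checkpoint would be written.

    Assumption: a write happens on every `interval`-th call. Non-positive
    intervals disable writes entirely, as the parameter description states.
    """
    if interval <= 0:
        return []
    return [n for n in range(1, total_calls + 1) if n % interval == 0]


# Hypothetical trainer layout, used only to exercise the path lookup.
policy = object()
trainer = SimpleNamespace(loss_module=SimpleNamespace(actor=policy))

resolved = resolve_dotted(trainer, "loss_module.actor")
schedule = due_calls(interval=3, total_calls=10)
disabled = due_calls(interval=0, total_calls=10)
```

Under these assumptions, a hook built with interval=3 would write on calls 3, 6, and 9 of a ten-call run, while interval=0 would register normally but never write a file.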
- register(trainer, name='policy_checkpoint_hook')
Registers the hook in the trainer at a default location.
- Parameters:
trainer (Trainer) – the trainer where the hook must be registered.
name (str) – the name of the hook.
- Return type:
None
Note
To register the hook at a location other than the default, use register_op().
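The difference between register() and a manual register_op() call can be illustrated with stand-ins. StubTrainer and StubHook below are hypothetical classes written for this sketch. They only mimic the register_op(dest, op) and register_module(name, module) calls that TorchRL trainers expose, and the body of StubHook.register is an assumption about what a hook of this kind typically does, not the actual PolicyCheckpointHook code.

```python
class StubTrainer:
    """Hypothetical stand-in that records registrations like a TorchRL Trainer."""

    def __init__(self):
        self.ops = {}      # destination name -> list of registered callables
        self.modules = {}  # module name -> registered hook object

    def register_op(self, dest, op, **kwargs):
        self.ops.setdefault(dest, []).append(op)

    def register_module(self, module_name, module):
        self.modules[module_name] = module


class StubHook:
    """Hypothetical hook following the register() contract described above."""

    def __init__(self, destination="post_steps"):
        self.destination = destination
        self.calls = 0

    def __call__(self, *args, **kwargs):
        self.calls += 1

    def register(self, trainer, name="policy_checkpoint_hook"):
        # Assumed behavior: attach at the configured destination, then
        # record the hook under `name` on the trainer.
        trainer.register_op(self.destination, self)
        trainer.register_module(name, self)


trainer = StubTrainer()

# Default path: the hook chooses its own destination.
StubHook().register(trainer)

# Manual path: the caller picks the destination explicitly.
custom = StubHook()
trainer.register_op("post_optim", custom)

registered = {dest: len(ops) for dest, ops in trainer.ops.items()}
```

With a real trainer, the manual path attaches the hook wherever the caller chooses, for example after each optimization step instead of after each collection step.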