xdrl.trainer_hooks
Submodules
Classes
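| Class | Description |
| --- | --- |
| PolicyCheckpointHook | Periodically checkpoint policy weights for offline analysis. |
| LoggingCollectionMetricsHook | Log collection metrics in the collection/ namespace. |
| LoggingCountersHook | Track frame and iteration counters in the counters/ namespace. |
| LoggingEvaluationHookSet | Compose deterministic and non-deterministic evaluation hooks. |
| LoggingEvaluationMetricsHook | Run periodic evaluation and log metrics under eval/<subgroup>/. |
| LoggingHookSet | Compose the default xdrl logging hooks. |
| WandbFinishHook | Optional shutdown hook for config-driven Weights & Biases cleanup. |
| WandbFlushHook | Flush pending W&B scalar rows emitted through TorchRL's scalar logger. |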
Package Contents
- class xdrl.trainer_hooks.PolicyCheckpointHook(*, directory, interval, policy=None, policy_path=None, prefix='policy', destination='post_steps', meta=None)
Bases: torchrl.trainers.trainers.TrainerHookBase
Periodically checkpoint policy weights for offline analysis.
- Parameters:
directory (str | pathlib.Path) – Directory where .pt checkpoint files are written.
interval (int) – Number of hook calls between checkpoints. Non-positive values disable file writes while still allowing the hook to be registered.
policy (torch.nn.Module | None) – Policy module to checkpoint. If omitted, policy_path is resolved from the trainer during registration.
policy_path (str | None) – Dotted path to the policy module on the trainer.
prefix (str) – Filename prefix used for checkpoint files.
destination (str) – TorchRL trainer operation where the hook is registered.
meta (dict[str, Any] | None) – Extra metadata persisted in every checkpoint payload.
- policy = None
- policy_path = None
- directory
- interval
- prefix = 'policy'
- destination = 'post_steps'
- meta
- num_calls = 0
- last_checkpoint_path: pathlib.Path | None = None
- register(trainer, name='policy_checkpoint_hook')
Registers the hook in the trainer at a default location.
- Parameters:
trainer (Trainer) – the trainer where the hook must be registered.
name (str) – the name of the hook.
- Return type:
None
Note
To register the hook at another location than the default, use register_op().
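The following is a minimal, dependency-free sketch of the interval semantics described above. It assumes that every hook call increments `num_calls` and that a checkpoint is written whenever `num_calls` is a multiple of `interval`. Several details are illustrative assumptions and do not come from the xdrl implementation: the class `SketchCheckpointer`, the `{prefix}_{num_calls:06d}.pt` filename pattern, the payload keys, and the use of `pickle` in place of `torch.save`.

```python
from __future__ import annotations

import pickle
import tempfile
from pathlib import Path
from typing import Any


class SketchCheckpointer:
    """Hypothetical stand-in for the PolicyCheckpointHook call logic."""

    def __init__(self, directory, interval: int, prefix: str = "policy",
                 meta: dict[str, Any] | None = None) -> None:
        self.directory = Path(directory)
        self.interval = interval
        self.prefix = prefix
        self.meta = dict(meta or {})
        self.num_calls = 0
        self.last_checkpoint_path: Path | None = None

    def __call__(self, state_dict: dict[str, Any]) -> None:
        self.num_calls += 1
        # Non-positive intervals keep the hook callable but never write files.
        if self.interval <= 0 or self.num_calls % self.interval != 0:
            return
        self.directory.mkdir(parents=True, exist_ok=True)
        # Assumed filename pattern; the real hook may name files differently.
        path = self.directory / f"{self.prefix}_{self.num_calls:06d}.pt"
        payload = {"state_dict": state_dict, "num_calls": self.num_calls, "meta": self.meta}
        with path.open("wb") as fh:
            pickle.dump(payload, fh)  # torch.save(payload, path) in a PyTorch setting
        self.last_checkpoint_path = path


with tempfile.TemporaryDirectory() as tmp:
    hook = SketchCheckpointer(tmp, interval=3, meta={"run": "demo"})
    for _ in range(7):
        hook({"weight": [0.1, 0.2]})
    written = sorted(p.name for p in Path(tmp).iterdir())
print(written)  # ['policy_000003.pt', 'policy_000006.pt']
```

If you pass `interval=0` to the sketch, the hook stays callable and keeps counting calls, but it never writes a file. This matches the documented behavior for non-positive intervals.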
- class xdrl.trainer_hooks.LoggingCollectionMetricsHook(group='agents', reward_key=None, done_key=('next', 'done'), episode_reward_key=None, episode_reward_weights=None, reduce_stats=None)
Bases: torchrl.trainers.trainers.TrainerHookBase
Log collection metrics in the collection/ namespace.
The hook reads reward, done, and optional episode-reward tensors from a collected TensorDict. Vector-valued rewards can be scalarized with explicit weights, which is useful for MO-Gymnasium and other multi-objective runs.
- Parameters:
group (str)
reward_key (tuple[str, Ellipsis] | None)
done_key (tuple[str, Ellipsis])
episode_reward_key (tuple[str, Ellipsis] | None)
episode_reward_weights (collections.abc.Sequence[float] | None)
reduce_stats (bool | None)
- group = 'agents'
- reward_key = None
- done_key = ('next', 'done')
- episode_reward_key = None
- episode_reward_weights
- reduce_stats = None
- __call__(batch)
- Parameters:
batch (tensordict.TensorDictBase)
- Return type:
dict[str, float]
- register(trainer, name='logging_collection_metrics')
Registers the hook in the trainer at a default location.
- Parameters:
trainer (Trainer) – the trainer where the hook must be registered.
name (str) – the name of the hook.
- Return type:
None
Note
To register the hook at another location than the default, use register_op().
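Weighted scalarization reduces each vector reward to a single number, r = Σ_k w_k · r_k. The sketch below applies that reduction and then computes min/mean/max statistics. It uses plain Python lists instead of TensorDicts. The function names and the `collection/...` metric keys are hypothetical, because the exact keys xdrl logs are not documented here.

```python
from __future__ import annotations

from collections.abc import Sequence
from statistics import fmean


def scalarize(rewards: Sequence[Sequence[float]], weights: Sequence[float]) -> list[float]:
    """Weighted sum over the objective axis: r = sum_k w_k * r_k."""
    scalars = []
    for row in rewards:
        if len(row) != len(weights):
            raise ValueError(f"expected {len(weights)} objectives, got {len(row)}")
        scalars.append(sum(w * r for w, r in zip(weights, row)))
    return scalars


def collection_metrics(rewards, dones, weights=None) -> dict[str, float]:
    """Hypothetical min/mean/max summary under an assumed collection/ key layout."""
    values = scalarize(rewards, weights) if weights is not None else [float(r) for r in rewards]
    return {
        "collection/reward/min": min(values),
        "collection/reward/mean": fmean(values),
        "collection/reward/max": max(values),
        "collection/done_count": float(sum(bool(d) for d in dones)),
    }


# Two objectives per step, e.g. from an MO-Gymnasium environment.
metrics = collection_metrics(
    rewards=[[1.0, 0.0], [0.5, 0.5], [0.0, 4.0]],
    dones=[False, True, True],
    weights=[1.0, 0.5],
)
print(metrics)
# {'collection/reward/min': 0.75, 'collection/reward/mean': 1.25,
#  'collection/reward/max': 2.0, 'collection/done_count': 2.0}
```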
- class xdrl.trainer_hooks.LoggingCountersHook(frame_skip=1)
Bases: torchrl.trainers.trainers.TrainerHookBase
Track frame and iteration counters in the counters/ namespace.
- Parameters:
frame_skip (int)
- frame_skip = 1
- total_frames = 0
- iteration = 0
- register(trainer, name='logging_counters')
Registers the hook in the trainer at a default location.
- Parameters:
trainer (Trainer) – the trainer where the hook must be registered.
name (str) – the name of the hook.
- Return type:
None
Note
To register the hook at another location than the default, use register_op().
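Here is a sketch of the counter bookkeeping. It assumes that each collected step stands for `frame_skip` raw environment frames, which is a common frame-skip convention. How xdrl actually applies `frame_skip` is not stated above. `SketchCounters` and the exact metric keys are hypothetical.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class SketchCounters:
    """Hypothetical mirror of the LoggingCountersHook state."""

    frame_skip: int = 1
    total_frames: int = 0
    iteration: int = 0

    def update(self, batch_frames: int) -> dict[str, float]:
        # Assumption: each collected step represents `frame_skip` raw frames.
        self.total_frames += batch_frames * self.frame_skip
        self.iteration += 1
        return {
            "counters/total_frames": float(self.total_frames),
            "counters/iteration": float(self.iteration),
        }


counters = SketchCounters(frame_skip=4)
for _ in range(3):
    metrics = counters.update(batch_frames=1000)
print(metrics)  # {'counters/total_frames': 12000.0, 'counters/iteration': 3.0}
```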
- class xdrl.trainer_hooks.LoggingEvaluationHookSet(*, policy, environment, group, interval_frames, max_steps, deterministic, non_deterministic, render, video_fps, render_kwargs=None, reward_key=None, reduce_stats=None, logger=None)
Compose deterministic and non-deterministic evaluation hooks.
This wrapper is useful when the same environment and policy should be evaluated both with and without deterministic exploration on the same schedule.
- Parameters:
policy (torch.nn.Module)
environment
group (str)
interval_frames (int)
max_steps (int)
deterministic (bool)
non_deterministic (bool)
render (bool)
video_fps (int)
render_kwargs (dict[str, Any] | None)
reward_key (tuple[str, Ellipsis] | None)
reduce_stats (bool | None)
logger (Any | None)
- hooks: list[LoggingEvaluationMetricsHook] = []
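The sketch below shows how the set might expand its two flags into individual hooks. It uses a hypothetical `EvalSpec` record in place of real `LoggingEvaluationMetricsHook` instances. The `"deterministic"` label appears in the `metric_subgroup` documentation. `"non_deterministic"` is an assumed label for the other setting.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class EvalSpec:
    """Hypothetical record of the settings for one evaluation hook."""

    metric_subgroup: str
    deterministic: bool
    interval_frames: int


def build_eval_specs(*, deterministic: bool, non_deterministic: bool,
                     interval_frames: int) -> list[EvalSpec]:
    """Expand the two flags into one spec per enabled exploration setting."""
    specs = []
    if deterministic:
        specs.append(EvalSpec("deterministic", True, interval_frames))
    if non_deterministic:
        # Assumed label; the docs only name "deterministic" explicitly.
        specs.append(EvalSpec("non_deterministic", False, interval_frames))
    return specs


specs = build_eval_specs(deterministic=True, non_deterministic=True, interval_frames=10_000)
print([s.metric_subgroup for s in specs])  # ['deterministic', 'non_deterministic']
```

Both specs share `interval_frames`. Under this sketch's assumptions, that shared value is what keeps the two evaluations on the same schedule.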
- class xdrl.trainer_hooks.LoggingEvaluationMetricsHook(*, policy, environment, group, metric_subgroup, interval_frames, max_steps, deterministic, render, video_fps, render_kwargs=None, reward_key=None, reduce_stats=None, logger=None)
Bases: torchrl.trainers.trainers.TrainerHookBase
Run periodic evaluation and log metrics under eval/<subgroup>/.
- Parameters:
policy (torch.nn.Module) – Policy module used during rollout.
environment – TorchRL environment with a rollout method.
group (str) – Agent/group namespace used for reward keys.
metric_subgroup (str) – Evaluation label such as "deterministic".
interval_frames (int) – Number of collected frames between evaluations.
max_steps (int) – Maximum rollout length.
deterministic (bool) – Whether to force deterministic exploration.
render (bool) – Whether to capture rendered frames and log a video.
video_fps (int) – Video frame rate passed to the logger.
render_kwargs (dict[str, Any] | None) – Optional keyword arguments passed to the environment's render call.
reward_key (tuple[str, Ellipsis] | None) – TensorDict key for rollout rewards.
reduce_stats (bool | None) – Whether vector metrics are reduced to min/mean/max.
logger (Any | None) – Optional logger used before the hook is registered on a trainer.
- policy
- environment
- group
- reward_key = None
- reduce_stats = None
- metric_subgroup
- interval_frames
- max_steps
- deterministic
- render
- render_kwargs
- video_fps
- logger = None
- trainer: torchrl.trainers.trainers.Trainer | None = None
- _log_direct(metrics, step)
- Parameters:
metrics (dict[str, float])
step (int)
- Return type:
None
- __call__(_batch)
- Parameters:
_batch (tensordict.TensorDictBase)
- Return type:
dict[str, float]
- register(trainer, name='logging_evaluation_metrics')
Registers the hook in the trainer at a default location.
- Parameters:
trainer (Trainer) – the trainer where the hook must be registered.
name (str) – the name of the hook.
- Return type:
None
Note
To register the hook at another location than the default, use register_op().
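The `interval_frames` schedule can be sketched as a threshold that advances past every multiple of the interval already crossed. That way, one large batch triggers a single evaluation rather than several. `SketchEvalSchedule` and its method names are hypothetical, and the real hook may count frames or handle skipped thresholds differently. The `namespaced` helper only illustrates the documented `eval/<subgroup>/` prefix.

```python
from __future__ import annotations


class SketchEvalSchedule:
    """Hypothetical frame-based evaluation trigger."""

    def __init__(self, interval_frames: int, metric_subgroup: str) -> None:
        self.interval_frames = interval_frames
        self.metric_subgroup = metric_subgroup
        self.next_eval_at = interval_frames

    def should_evaluate(self, collected_frames: int) -> bool:
        # Guard against an infinite loop below; how xdrl treats non-positive
        # intervals is not documented.
        if self.interval_frames <= 0 or collected_frames < self.next_eval_at:
            return False
        # Advance past every threshold already crossed so one large batch
        # triggers a single evaluation.
        while self.next_eval_at <= collected_frames:
            self.next_eval_at += self.interval_frames
        return True

    def namespaced(self, metrics: dict[str, float]) -> dict[str, float]:
        """Prefix metric names with the eval/<subgroup>/ namespace."""
        return {f"eval/{self.metric_subgroup}/{name}": value for name, value in metrics.items()}


sched = SketchEvalSchedule(interval_frames=5000, metric_subgroup="deterministic")
triggers = [f for f in range(1000, 16000, 1000) if sched.should_evaluate(f)]
print(triggers)  # [5000, 10000, 15000]
print(sched.namespaced({"reward/mean": 1.5}))  # {'eval/deterministic/reward/mean': 1.5}
```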
- class xdrl.trainer_hooks.LoggingHookSet(*, group, frame_skip, reward_key=None, done_key=('next', 'done'), episode_reward_key=None, episode_reward_weights=None, reduce_stats=None, eval_hook_set=None)
Bases: torchrl.trainers.trainers.TrainerHookBase
Compose the default xdrl logging hooks.
The hook set registers collection metrics, training metrics, counters, progress metrics, timers, and optional evaluation hooks as a single object. This keeps Hydra configs concise while preserving independently testable hook components.
- Parameters:
group (str)
frame_skip (int)
reward_key (tuple[str, Ellipsis] | None)
done_key (tuple[str, Ellipsis])
episode_reward_key (tuple[str, Ellipsis] | None)
episode_reward_weights (collections.abc.Sequence[float] | None)
reduce_stats (bool | None)
eval_hook_set (LoggingEvaluationHookSet | None)
- group
- collection_hook
- training_hook
- counters_hook
- progress_hook
- eval_hook_set = None
- _iteration_start: float | None = None
- _previous_iteration_end: float | None = None
- _collection_time = 0.0
- _total_time = 0.0
- _timers_start(batch)
- Parameters:
batch (tensordict.TensorDictBase)
- Return type:
tensordict.TensorDictBase
- _timers_end(_batch)
- Parameters:
_batch (tensordict.TensorDictBase)
- Return type:
dict[str, float]
- register(trainer, name='logging_hooks')
Registers the hook in the trainer at a default location.
- Parameters:
trainer (Trainer) – the trainer where the hook must be registered.
name (str) – the name of the hook.
- Return type:
None
Note
To register the hook at another location than the default, use register_op().
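The private `_timers_start`/`_timers_end` pair suggests a start/stop timer around each trainer iteration. The first passes the batch through unchanged, which fits a batch-processing op. The second returns metrics, which fits a logging op. The sketch below is a hypothetical reconstruction. It assumes that the gap between one iteration's end and the next iteration's start is collection time. The class name, the injectable clock, and the `timers/...` keys are all illustrative.

```python
from __future__ import annotations

import time
from collections.abc import Callable


class SketchIterationTimers:
    """Hypothetical start/stop timer pair; not the xdrl implementation."""

    def __init__(self, clock: Callable[[], float] = time.perf_counter) -> None:
        self.clock = clock
        self._iteration_start: float | None = None
        self._previous_iteration_end: float | None = None
        self._collection_time = 0.0
        self._total_time = 0.0

    def start(self) -> None:
        now = self.clock()
        # Assumption: time since the previous iteration ended was spent collecting.
        # The very first collection happens before any hook runs, so it is not counted.
        if self._previous_iteration_end is not None:
            gap = now - self._previous_iteration_end
            self._collection_time += gap
            self._total_time += gap
        self._iteration_start = now

    def end(self) -> dict[str, float]:
        now = self.clock()
        start = self._iteration_start if self._iteration_start is not None else now
        iteration_time = now - start
        self._total_time += iteration_time
        self._previous_iteration_end = now
        return {
            "timers/iteration_time": iteration_time,
            "timers/collection_time": self._collection_time,
            "timers/total_time": self._total_time,
        }


ticks = iter([0.0, 1.0, 3.0, 4.5])  # fake clock readings in seconds
timers = SketchIterationTimers(clock=lambda: next(ticks))
results = []
for _ in range(2):
    timers.start()
    results.append(timers.end())
print(results[-1])
# {'timers/iteration_time': 1.5, 'timers/collection_time': 2.0, 'timers/total_time': 4.5}
```

Injecting the clock keeps the timing logic deterministic in tests. In a live run, the default `time.perf_counter` would be used.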
- class xdrl.trainer_hooks.WandbFinishHook(enabled=True)
Bases: torchrl.trainers.trainers.TrainerHookBase
Optional shutdown hook for config-driven Weights & Biases cleanup.
The hook intentionally swallows import and runtime errors so that offline or disabled W&B runs do not cause trainer shutdown to fail.
- Parameters:
enabled (bool)
- enabled = True
- register(trainer, name='wandb_finish')
Registers the hook in the trainer at a default location.
- Parameters:
trainer (Trainer) – the trainer where the hook must be registered.
name (str) – the name of the hook.
- Return type:
None
Note
To register the hook at another location than the default, use register_op().
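The following sketches the error-tolerant shutdown described above. It relies only on `wandb.run` and `wandb.finish()` from the public W&B client. The function name `finish_wandb_quietly` and its boolean return value are assumptions made for the example.

```python
from __future__ import annotations


def finish_wandb_quietly(enabled: bool = True) -> bool:
    """Sketch of error-tolerant W&B shutdown; True only if wandb.finish() ran."""
    if not enabled:
        return False
    try:
        import wandb  # optional dependency; may be missing in offline setups

        if wandb.run is None:
            return False  # no active run, nothing to finish
        wandb.finish()
        return True
    except Exception:
        # Swallow import and runtime errors so trainer shutdown never fails on W&B.
        return False


print(finish_wandb_quietly(enabled=False))  # False
```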
- class xdrl.trainer_hooks.WandbFlushHook(enabled=True)
Bases: torchrl.trainers.trainers.TrainerHookBase
Flush pending W&B scalar rows emitted through TorchRL’s scalar logger.
TorchRL logs scalar metrics one at a time, and its W&B logger defaults those calls to commit=False so that metrics for the same step can be grouped into one row. This hook commits the pending row after each trainer iteration and again before W&B is finished. As a result, metrics appear while long-running jobs are still in progress instead of only at the end.
- Parameters:
enabled (bool)
- enabled = True
- trainer: torchrl.trainers.trainers.Trainer | None = None
- _last_flushed_steps: tuple[tuple[str, int], Ellipsis] = ()
- static _wandb_step_registry(logger)
- Parameters:
logger (Any)
- Return type:
tuple[tuple[str, int], Ellipsis]
- register(trainer, name='wandb_flush')
Registers the hook in the trainer at a default location.
- Parameters:
trainer (Trainer) – the trainer where the hook must be registered.
name (str) – the name of the hook.
- Return type:
None
Note
To register the hook at another location than the default, use register_op().
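The commit pattern can be illustrated without a W&B account. The sketch uses a hypothetical `FakeRun` that mimics `commit=False` accumulation: values logged with `commit=False` merge into a pending row, and a later `log({}, commit=True)` writes that row out. `SketchFlush` skips a flush when the step snapshot has not changed since the last one. That is an assumption about what `_last_flushed_steps` is used for. The metric and step key names are illustrative.

```python
from __future__ import annotations

from typing import Any


class FakeRun:
    """Hypothetical stand-in for a W&B run: commit=False calls merge into one pending row."""

    def __init__(self) -> None:
        self.pending: dict[str, Any] = {}
        self.rows: list[dict[str, Any]] = []

    def log(self, data: dict[str, Any], commit: bool = True) -> None:
        self.pending.update(data)
        # This fake simply ignores commits when nothing is pending.
        if commit and self.pending:
            self.rows.append(self.pending)
            self.pending = {}


class SketchFlush:
    """Commit the pending row once per new step snapshot."""

    def __init__(self, run: FakeRun) -> None:
        self.run = run
        self._last_flushed_steps: tuple[tuple[str, int], ...] = ()

    def __call__(self, steps: tuple[tuple[str, int], ...]) -> bool:
        if steps == self._last_flushed_steps:
            return False  # nothing new since the previous flush
        self.run.log({}, commit=True)
        self._last_flushed_steps = steps
        return True


run = FakeRun()
flush = SketchFlush(run)
for step in (1, 2):
    # Scalar-logger pattern: one metric per call, commit deferred.
    run.log({"train/loss": 0.5 / step}, commit=False)
    run.log({"collection/reward/mean": float(step)}, commit=False)
    flush((("train", step),))
print(len(run.rows), run.pending)  # 2 {}
```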