xdrl.trainer_hooks

Classes

PolicyCheckpointHook

Periodically checkpoint policy weights for offline analysis.

LoggingCollectionMetricsHook

Log collection metrics in the collection/ namespace.

LoggingCountersHook

Track frame and iteration counters in the counters/ namespace.

LoggingEvaluationHookSet

Compose deterministic and non-deterministic evaluation hooks.

LoggingEvaluationMetricsHook

Run periodic evaluation and log metrics under eval/<subgroup>/.

LoggingHookSet

Compose the default xdrl logging hooks.

WandbFinishHook

Optional shutdown hook for config-driven Weights & Biases cleanup.

WandbFlushHook

Flush pending W&B scalar rows emitted through TorchRL's scalar logger.

Package Contents

class xdrl.trainer_hooks.PolicyCheckpointHook(*, directory, interval, policy=None, policy_path=None, prefix='policy', destination='post_steps', meta=None)

Bases: torchrl.trainers.trainers.TrainerHookBase

Periodically checkpoint policy weights for offline analysis.

Parameters:
  • directory (str | pathlib.Path) – Directory where .pt checkpoint files are written.

  • interval (int) – Number of hook calls between checkpoints. Non-positive values disable file writes while still allowing the hook to be registered.

  • policy (torch.nn.Module | None) – Policy module to checkpoint. If omitted, policy_path is resolved from the trainer during registration.

  • policy_path (str | None) – Dotted path to the policy module on the trainer.

  • prefix (str) – Filename prefix used for checkpoint files.

  • destination (str) – TorchRL trainer operation where the hook is registered.

  • meta (dict[str, Any] | None) – Extra metadata persisted in every checkpoint payload.
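The way interval, prefix, num_calls and last_checkpoint_path interact can be sketched with the standard library alone. This is a hypothetical stand-in, not the xdrl implementation: the class name `CheckpointSchedule`, the filename pattern `<prefix>_<num_calls>.pt` with zero padding, and the choice to keep counting calls when interval is non-positive are all assumptions, and no file is actually written.

```python
from pathlib import Path
from typing import Optional


class CheckpointSchedule:
    """Hypothetical sketch of PolicyCheckpointHook's interval logic (no file I/O)."""

    def __init__(self, directory: str, interval: int, prefix: str = "policy") -> None:
        self.directory = Path(directory)
        self.interval = interval
        self.prefix = prefix
        self.num_calls = 0
        self.last_checkpoint_path: Optional[Path] = None

    def __call__(self) -> Optional[Path]:
        # Assumption: every hook call is counted, even when checkpointing is disabled.
        self.num_calls += 1
        # Non-positive intervals disable writes; otherwise fire every `interval` calls.
        if self.interval <= 0 or self.num_calls % self.interval != 0:
            return None
        # Assumed naming scheme: <prefix>_<zero-padded call count>.pt
        path = self.directory / f"{self.prefix}_{self.num_calls:06d}.pt"
        self.last_checkpoint_path = path
        return path


schedule = CheckpointSchedule("checkpoints", interval=3)
written = [p for p in (schedule() for _ in range(7)) if p is not None]
print([p.name for p in written])  # ['policy_000003.pt', 'policy_000006.pt']
```

Keeping the call counter separate from the write decision means a disabled hook (interval <= 0) still has a meaningful num_calls when its state is saved and restored.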

policy = None
policy_path = None
directory
interval
prefix = 'policy'
destination = 'post_steps'
meta
num_calls = 0
last_checkpoint_path: pathlib.Path | None = None
__call__()
Return type:

None

register(trainer, name='policy_checkpoint_hook')

Registers the hook in the trainer at a default location.

Parameters:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

Return type:

None

Note

To register the hook at a location other than the default, use register_op().

state_dict()
Return type:

dict[str, Any]

load_state_dict(state_dict)
Parameters:

state_dict (dict[str, Any])

Return type:

None

class xdrl.trainer_hooks.LoggingCollectionMetricsHook(group='agents', reward_key=None, done_key=('next', 'done'), episode_reward_key=None, episode_reward_weights=None, reduce_stats=None)

Bases: torchrl.trainers.trainers.TrainerHookBase

Log collection metrics in the collection/ namespace.

The hook reads reward, done, and optional episode-reward tensors from a collected TensorDict. Vector-valued rewards can be scalarized with explicit weights, which is useful for MO-Gymnasium and other multi-objective runs.
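A minimal sketch of the scalarization step, using plain Python lists in place of TensorDict entries. The helper name `finished_episode_returns`, the weighted-sum rule, and the idea of reading episode rewards only at steps where the done flag is set are assumptions about how episode_reward_weights is applied, not a description of the xdrl code.

```python
from typing import Sequence


def finished_episode_returns(
    episode_rewards: Sequence[Sequence[float]],
    dones: Sequence[bool],
    weights: Sequence[float],
) -> list[float]:
    """Weighted-sum scalarization of vector episode rewards at done steps (sketch)."""
    returns = []
    for reward_vector, done in zip(episode_rewards, dones):
        if not done:
            continue  # only completed episodes contribute a return
        if len(reward_vector) != len(weights):
            raise ValueError("one weight is needed per reward objective")
        returns.append(sum(w * r for w, r in zip(weights, reward_vector)))
    return returns


# Two objectives per step; episodes finish at steps 1 and 3.
print(finished_episode_returns(
    [[0.0, -1.0], [10.0, -2.0], [0.0, -1.0], [4.0, -4.0]],
    [False, True, False, True],
    [1.0, 0.5],
))  # [9.0, 2.0]
```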

Parameters:
  • group (str)

  • reward_key (tuple[str, Ellipsis] | None)

  • done_key (tuple[str, Ellipsis])

  • episode_reward_key (tuple[str, Ellipsis] | None)

  • episode_reward_weights (collections.abc.Sequence[float] | None)

  • reduce_stats (bool | None)
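Assuming reduce_stats behaves here as documented for LoggingEvaluationMetricsHook (vector metrics reduced to min/mean/max), a vector metric such as one value per agent or per objective is summarized rather than logged element by element. The helper `summarize`, the `_min`/`_mean`/`_max` and `_<index>` key suffixes, and the `collection/reward` name are hypothetical; the real key layout may differ.

```python
from statistics import fmean
from typing import Sequence


def summarize(name: str, values: Sequence[float], reduce_stats: bool) -> dict[str, float]:
    """Log a vector metric either element-wise or as min/mean/max (hypothetical keys)."""
    if reduce_stats:
        return {
            f"{name}_min": float(min(values)),
            f"{name}_mean": fmean(values),
            f"{name}_max": float(max(values)),
        }
    # Without reduction, emit one scalar per element.
    return {f"{name}_{i}": float(v) for i, v in enumerate(values)}


print(summarize("collection/reward", [1.0, 2.0, 6.0], reduce_stats=True))
# {'collection/reward_min': 1.0, 'collection/reward_mean': 3.0, 'collection/reward_max': 6.0}
```

Reduction keeps the number of logged series constant as the number of agents or objectives grows, at the cost of losing per-element detail.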

group = 'agents'
reward_key = None
done_key = ('next', 'done')
episode_reward_key = None
episode_reward_weights
reduce_stats = None
__call__(batch)
Parameters:

batch (tensordict.TensorDictBase)

Return type:

dict[str, float]

register(trainer, name='logging_collection_metrics')

Registers the hook in the trainer at a default location.

Parameters:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

Return type:

None

Note

To register the hook at a location other than the default, use register_op().

state_dict()
Return type:

dict[str, Any]

load_state_dict(state_dict)
Parameters:

state_dict (dict[str, Any])

Return type:

None

class xdrl.trainer_hooks.LoggingCountersHook(frame_skip=1)

Bases: torchrl.trainers.trainers.TrainerHookBase

Track frame and iteration counters in the counters/ namespace.

Parameters:

frame_skip (int)
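A stand-alone sketch of frame and iteration counters with a frame_skip multiplier, plus a state_dict round trip for resuming a run. The class `CountersSketch`, the `counters/total_frames` and `counters/iteration` key names, passing the frame count as a plain integer, and the state_dict layout are all assumptions; the real hook derives the frame count from a TensorDictBase batch via _current_frames.

```python
class CountersSketch:
    """Hypothetical stand-in for LoggingCountersHook's bookkeeping."""

    def __init__(self, frame_skip: int = 1) -> None:
        self.frame_skip = frame_skip
        self.total_frames = 0
        self.iteration = 0

    def __call__(self, batch_frames: int) -> dict[str, int]:
        # Each environment step is assumed to cover `frame_skip` raw frames.
        self.total_frames += batch_frames * self.frame_skip
        self.iteration += 1
        return {"counters/total_frames": self.total_frames, "counters/iteration": self.iteration}

    def state_dict(self) -> dict[str, int]:
        return {"total_frames": self.total_frames, "iteration": self.iteration}

    def load_state_dict(self, state_dict: dict[str, int]) -> None:
        self.total_frames = state_dict["total_frames"]
        self.iteration = state_dict["iteration"]


counters = CountersSketch(frame_skip=4)
counters(1000)
print(counters(1000))  # {'counters/total_frames': 8000, 'counters/iteration': 2}

resumed = CountersSketch(frame_skip=4)
resumed.load_state_dict(counters.state_dict())
print(resumed(1000))  # {'counters/total_frames': 12000, 'counters/iteration': 3}
```

Persisting the counters is what lets a resumed run continue its x-axis instead of restarting at zero.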

frame_skip = 1
total_frames = 0
iteration = 0
_current_frames(batch)
Parameters:

batch (tensordict.TensorDictBase)

Return type:

int

__call__(batch)
Parameters:

batch (tensordict.TensorDictBase)

Return type:

dict[str, int]

register(trainer, name='logging_counters')

Registers the hook in the trainer at a default location.

Parameters:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

Return type:

None

Note

To register the hook at a location other than the default, use register_op().

state_dict()
Return type:

dict[str, Any]

load_state_dict(state_dict)
Parameters:

state_dict (dict[str, Any])

Return type:

None

class xdrl.trainer_hooks.LoggingEvaluationHookSet(*, policy, environment, group, interval_frames, max_steps, deterministic, non_deterministic, render, video_fps, render_kwargs=None, reward_key=None, reduce_stats=None, logger=None)

Compose deterministic and non-deterministic evaluation hooks.

This wrapper is useful when the same environment and policy should be evaluated with both exploration settings on the same schedule.
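The composition can be pictured as one child hook per enabled exploration setting, all run on the same step with their results merged into a single dictionary. In this sketch the child hooks are plain callables returning placeholder values; `EvalHookSetSketch`, `make_eval`, the `non_deterministic` subgroup label and the `eval/<subgroup>/reward` key are assumptions, and only the `eval/<subgroup>/` namespace and the "deterministic" label come from this page.

```python
from typing import Callable

EvalFn = Callable[[int], dict[str, float]]


def make_eval(subgroup: str, reward: float) -> EvalFn:
    # Stand-in for LoggingEvaluationMetricsHook.run: returns metrics for one subgroup.
    def run(step: int) -> dict[str, float]:
        return {f"eval/{subgroup}/reward": reward}
    return run


class EvalHookSetSketch:
    """Hypothetical composite: one child per enabled exploration setting."""

    def __init__(self, deterministic: bool, non_deterministic: bool) -> None:
        self.hooks: list[EvalFn] = []
        if deterministic:
            self.hooks.append(make_eval("deterministic", reward=1.0))
        if non_deterministic:
            self.hooks.append(make_eval("non_deterministic", reward=0.5))

    def run(self, *, step: int) -> dict[str, float]:
        metrics: dict[str, float] = {}
        for hook in self.hooks:
            # Subgroups use distinct key prefixes, so merging never overwrites entries.
            metrics.update(hook(step))
        return metrics


print(EvalHookSetSketch(deterministic=True, non_deterministic=True).run(step=0))
# {'eval/deterministic/reward': 1.0, 'eval/non_deterministic/reward': 0.5}
```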

Parameters:
  • policy (torch.nn.Module)

  • environment

  • group (str)

  • interval_frames (int)

  • max_steps (int)

  • deterministic (bool)

  • non_deterministic (bool)

  • render (bool)

  • video_fps (int)

  • render_kwargs (dict[str, Any] | None)

  • reward_key (tuple[str, Ellipsis] | None)

  • reduce_stats (bool | None)

  • logger (Any | None)

hooks: list[LoggingEvaluationMetricsHook] = []
register(trainer, name='logging_evaluation_metrics')
Parameters:
  • trainer (torchrl.trainers.trainers.Trainer)

  • name (str)

Return type:

None

run(*, step)
Parameters:

step (int)

Return type:

dict[str, float]

close()
Return type:

None

class xdrl.trainer_hooks.LoggingEvaluationMetricsHook(*, policy, environment, group, metric_subgroup, interval_frames, max_steps, deterministic, render, video_fps, render_kwargs=None, reward_key=None, reduce_stats=None, logger=None)

Bases: torchrl.trainers.trainers.TrainerHookBase

Run periodic evaluation and log metrics under eval/<subgroup>/.

Parameters:
  • policy (torch.nn.Module) – Policy module used during rollout.

  • environment – TorchRL environment with a rollout method.

  • group (str) – Agent/group namespace used for reward keys.

  • metric_subgroup (str) – Evaluation label such as "deterministic".

  • interval_frames (int) – Collected-frame interval between evaluations.

  • max_steps (int) – Maximum rollout length.

  • deterministic (bool) – Whether to force deterministic exploration.

  • render (bool) – Whether to capture rendered frames and log a video.

  • video_fps (int) – Video frame rate passed to the logger.

  • render_kwargs (dict[str, Any] | None) – Optional keyword arguments passed to environment render.

  • reward_key (tuple[str, Ellipsis] | None) – TensorDict key for rollout rewards.

  • reduce_stats (bool | None) – Whether vector metrics are reduced to min/mean/max.

  • logger (Any | None) – Optional logger used before the hook is registered on a trainer.
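One plausible reading of interval_frames is that an evaluation fires whenever the collected-frame count crosses another multiple of the interval, so uneven batch sizes neither skip nor duplicate an evaluation. The sketch below implements that rule with a hypothetical `EvalSchedule` helper; it is an assumption, not the hook's confirmed behavior.

```python
class EvalSchedule:
    """Fire once each time total collected frames reach a new multiple of interval_frames."""

    def __init__(self, interval_frames: int) -> None:
        self.interval_frames = interval_frames
        self.next_eval_at = interval_frames

    def should_evaluate(self, total_frames: int) -> bool:
        if self.interval_frames <= 0 or total_frames < self.next_eval_at:
            return False
        # Advance past every threshold already crossed, so one large batch triggers only once.
        while self.next_eval_at <= total_frames:
            self.next_eval_at += self.interval_frames
        return True


schedule = EvalSchedule(interval_frames=10_000)
print([schedule.should_evaluate(f) for f in (6_000, 12_000, 18_000, 35_000, 38_000)])
# [False, True, False, True, False]
```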

policy
environment
group
reward_key = None
reduce_stats = None
metric_subgroup
interval_frames
max_steps
deterministic
render
render_kwargs
video_fps
logger = None
trainer: torchrl.trainers.trainers.Trainer | None = None
static _to_frame_array(frame)
Parameters:

frame (Any)

Return type:

numpy.ndarray

_extract_render_frame(output)
Parameters:

output (Any)

Return type:

numpy.ndarray

_renderable_candidates()
Return type:

list[Any]

_render_frame()
Return type:

numpy.ndarray

_evaluate_once(step)
Parameters:

step (int)

Return type:

dict[str, float]

_log_direct(metrics, step)
Parameters:
  • metrics (dict[str, float])

  • step (int)

Return type:

None

run(*, step)
Parameters:

step (int)

Return type:

dict[str, float]

__call__(_batch)
Parameters:

_batch (tensordict.TensorDictBase)

Return type:

dict[str, float]

register(trainer, name='logging_evaluation_metrics')

Registers the hook in the trainer at a default location.

Parameters:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

Return type:

None

Note

To register the hook at a location other than the default, use register_op().

close()
Return type:

None

state_dict()
Return type:

dict[str, Any]

load_state_dict(state_dict)
Parameters:

state_dict (dict[str, Any])

Return type:

None

class xdrl.trainer_hooks.LoggingHookSet(*, group, frame_skip, reward_key=None, done_key=('next', 'done'), episode_reward_key=None, episode_reward_weights=None, reduce_stats=None, eval_hook_set=None)

Bases: torchrl.trainers.trainers.TrainerHookBase

Compose the default xdrl logging hooks.

The hook set registers collection metrics, training metrics, counters, progress metrics, timers, and optional evaluation hooks as a single object. This keeps Hydra configs concise while preserving independently testable hook components.
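The single-object pattern can be sketched with a stand-in trainer that only records registrations. `StubTrainer`, `HookSetSketch` and the two child functions are hypothetical; the `register_op(dest, op)` method mirrors the name of TorchRL's `Trainer.register_op`, and "pre_steps_log" is used as an illustrative destination, but which destinations xdrl actually uses is not stated on this page.

```python
from collections import defaultdict
from typing import Any, Callable


class StubTrainer:
    """Hypothetical stand-in that only records which ops were registered where."""

    def __init__(self) -> None:
        self.ops: dict[str, list[Callable[..., Any]]] = defaultdict(list)

    def register_op(self, dest: str, op: Callable[..., Any]) -> None:
        self.ops[dest].append(op)


def collection_metrics(batch: Any) -> dict[str, float]:
    return {"collection/reward_mean": 0.0}  # placeholder value


def counters(batch: Any) -> dict[str, int]:
    return {"counters/iteration": 1}  # placeholder value


class HookSetSketch:
    """One object whose register() wires several child hooks into the trainer."""

    def __init__(self) -> None:
        self.children = [("pre_steps_log", collection_metrics), ("pre_steps_log", counters)]

    def register(self, trainer: StubTrainer) -> None:
        for dest, op in self.children:
            trainer.register_op(dest, op)


trainer = StubTrainer()
HookSetSketch().register(trainer)
print({dest: [op.__name__ for op in ops] for dest, ops in trainer.ops.items()})
# {'pre_steps_log': ['collection_metrics', 'counters']}
```

A config then only needs to name one object, while each child function can still be unit-tested on its own.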

Parameters:
  • group (str)

  • frame_skip (int)

  • reward_key (tuple[str, Ellipsis] | None)

  • done_key (tuple[str, Ellipsis])

  • episode_reward_key (tuple[str, Ellipsis] | None)

  • episode_reward_weights (collections.abc.Sequence[float] | None)

  • reduce_stats (bool | None)

  • eval_hook_set (LoggingEvaluationHookSet | None)
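The private _timers_start/_timers_end pair and the _iteration_start, _previous_iteration_end, _collection_time and _total_time attributes suggest per-iteration wall-clock bookkeeping. The sketch below shows one way to do it with an injectable clock; the class `IterationTimers`, the split into collection time (the gap between one iteration's end and the next batch) and processing time, and the `timers/` key names are assumptions.

```python
import time
from typing import Callable, Optional


class IterationTimers:
    """Hypothetical sketch of cumulative collection vs. processing time."""

    def __init__(self, clock: Callable[[], float] = time.perf_counter) -> None:
        self.clock = clock
        self.iteration_start: Optional[float] = None
        self.previous_iteration_end: Optional[float] = None
        self.collection_time = 0.0
        self.processing_time = 0.0

    def start(self) -> None:
        # A new batch has arrived; the gap since the last iteration ended was spent collecting.
        now = self.clock()
        if self.previous_iteration_end is not None:
            self.collection_time += now - self.previous_iteration_end
        self.iteration_start = now

    def end(self) -> dict[str, float]:
        # The iteration's hooks and optimization steps are done.
        now = self.clock()
        if self.iteration_start is not None:
            self.processing_time += now - self.iteration_start
        self.previous_iteration_end = now
        return {
            "timers/collection_time": self.collection_time,
            "timers/processing_time": self.processing_time,
            "timers/total_time": self.collection_time + self.processing_time,
        }


ticks = iter([0.0, 1.0, 1.5, 3.0])  # fake clock readings for a deterministic example
timers = IterationTimers(clock=lambda: next(ticks))
timers.start()
timers.end()
timers.start()
print(timers.end())
# {'timers/collection_time': 0.5, 'timers/processing_time': 2.5, 'timers/total_time': 3.0}
```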

group
collection_hook
training_hook
counters_hook
progress_hook
eval_hook_set = None
_iteration_start: float | None = None
_previous_iteration_end: float | None = None
_collection_time = 0.0
_total_time = 0.0
_timers_start(batch)
Parameters:

batch (tensordict.TensorDictBase)

Return type:

tensordict.TensorDictBase

_timers_end(_batch)
Parameters:

_batch (tensordict.TensorDictBase)

Return type:

dict[str, float]

register(trainer, name='logging_hooks')

Registers the hook in the trainer at a default location.

Parameters:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

Return type:

None

Note

To register the hook at a location other than the default, use register_op().

run_pre_eval()
Return type:

dict[str, float]

close()
Return type:

None

state_dict()
Return type:

dict[str, Any]

load_state_dict(state_dict)
Parameters:

state_dict (dict[str, Any])

Return type:

None

class xdrl.trainer_hooks.WandbFinishHook(enabled=True)

Bases: torchrl.trainers.trainers.TrainerHookBase

Optional shutdown hook for config-driven Weights & Biases cleanup.

The hook intentionally swallows import/runtime errors so offline or disabled W&B runs do not fail trainer shutdown.
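The error-swallowing shutdown can be sketched as follows. `wandb.finish()` is the public W&B call for closing the active run; the function name `finish_wandb`, its boolean return value, and the exact exceptions caught are assumptions, and the xdrl hook may differ.

```python
def finish_wandb(enabled: bool = True) -> bool:
    """Best-effort W&B shutdown; returns True only if wandb.finish() ran without error."""
    if not enabled:
        return False
    try:
        import wandb  # optional dependency; may be absent in offline setups
    except ImportError:
        return False
    try:
        wandb.finish()  # close the active run, if any
    except Exception:
        # Never let logging cleanup break trainer shutdown.
        return False
    return True
```

Catching `Exception` rather than a bare `except:` still lets `KeyboardInterrupt` and `SystemExit` propagate, so a user can interrupt a hanging shutdown.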

Parameters:

enabled (bool)

enabled = True
__call__()
Return type:

None

register(trainer, name='wandb_finish')

Registers the hook in the trainer at a default location.

Parameters:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

Return type:

None

Note

To register the hook at a location other than the default, use register_op().

state_dict()
Return type:

dict[str, Any]

load_state_dict(state_dict)
Parameters:

state_dict (dict[str, Any])

Return type:

None

class xdrl.trainer_hooks.WandbFlushHook(enabled=True)

Bases: torchrl.trainers.trainers.TrainerHookBase

Flush pending W&B scalar rows emitted through TorchRL’s scalar logger.

TorchRL logs scalar metrics one at a time, and its W&B logger issues those calls with commit=False by default so that metrics for the same step are grouped into a single row. This hook commits the pending row after each trainer iteration and again before W&B is finished, so metrics become visible while long-running jobs are still in progress.

Parameters:

enabled (bool)
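The commit mechanics can be illustrated with a fake experiment object that mimics the buffering of W&B's `log(data, step=None, commit=None)`: values logged with `commit=False` accumulate in a pending row until a call with `commit=True`. `FakeExperiment`, `PendingRowFlusher`, and the rule of skipping a flush when the step registry has not changed since the last one (loosely inspired by `_last_flushed_steps`) are assumptions, not the xdrl implementation.

```python
from typing import Any, Optional


class FakeExperiment:
    """Mimics W&B's commit buffering: commit=False values accumulate until committed."""

    def __init__(self) -> None:
        self.pending: dict[str, Any] = {}
        self.rows: list[dict[str, Any]] = []

    def log(self, data: dict[str, Any], step: Optional[int] = None, commit: Optional[bool] = None) -> None:
        self.pending.update(data)
        if commit:
            self.rows.append(dict(self.pending))
            self.pending.clear()


StepRegistry = tuple[tuple[str, int], ...]


class PendingRowFlusher:
    """Commit the pending row only when new (key, step) pairs have been logged."""

    def __init__(self) -> None:
        self.last_flushed_steps: StepRegistry = ()

    def __call__(self, experiment: Optional[FakeExperiment], step_registry: StepRegistry) -> bool:
        if experiment is None or step_registry == self.last_flushed_steps:
            return False
        experiment.log({}, commit=True)  # in this fake, an empty commit closes the pending row
        self.last_flushed_steps = step_registry
        return True


exp = FakeExperiment()
exp.log({"train/loss": 0.3}, step=0, commit=False)
exp.log({"collection/reward_mean": 1.2}, step=0, commit=False)
flush = PendingRowFlusher()
registry = (("train/loss", 0), ("collection/reward_mean", 0))
print(flush(exp, registry), exp.rows)
# True [{'train/loss': 0.3, 'collection/reward_mean': 1.2}]
print(flush(exp, registry))  # False: nothing new since the last flush
```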

enabled = True
trainer: torchrl.trainers.trainers.Trainer | None = None
_last_flushed_steps: tuple[tuple[str, int], Ellipsis] = ()
static _wandb_step_registry(logger)
Parameters:

logger (Any)

Return type:

tuple[tuple[str, int], Ellipsis]

static _wandb_experiment(logger)
Parameters:

logger (Any)

Return type:

Any | None

__call__(*_args, **_kwargs)
Parameters:
  • _args (Any)

  • _kwargs (Any)

Return type:

None

register(trainer, name='wandb_flush')

Registers the hook in the trainer at a default location.

Parameters:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

Return type:

None

Note

To register the hook at a location other than the default, use register_op().

state_dict()
Return type:

dict[str, Any]

load_state_dict(state_dict)
Parameters:

state_dict (dict[str, Any])

Return type:

None