xdrl.configs.hooks#

Classes#

ReducedLogValidationReward

Validation reward hook that emits scalar-friendly metrics.

GAEHook

Compute generalized advantage estimation from a trainer hook.

LoggingHookSetConfig

Hydra config for xdrl.trainer_hooks.LoggingHookSet.

PolicyCheckpointHookConfig

Hydra config for periodic policy checkpointing.

GAEHookConfig

Hydra config for explicit GAE computation as a trainer hook.

LogValidationRewardHookConfig

Hydra config for scalar-friendly TorchRL validation reward logging.

WandbFinishHookConfig

Hydra config for optional W&B shutdown cleanup.

WandbFlushHookConfig

Hydra config for flushing pending Weights & Biases scalar rows.

Functions#

_make_logging_hook_set(*, group, frame_skip[, ...])

_make_policy_checkpoint_hook(*, directory, interval[, ...])

_make_log_validation_reward_hook(*, environment, ...)

Module Contents#

class xdrl.configs.hooks.ReducedLogValidationReward(*, enabled=True, pre_eval=False, policy_path=None, **kwargs)[source]#

Bases: torchrl.trainers.trainers.LogValidationReward

Validation reward hook that emits scalar-friendly metrics.

TorchRL’s validation hook can return tensor metrics. This subclass keeps the upstream rollout behavior but reduces multi-valued tensors before logging and optionally resolves the evaluation policy from a trainer attribute path.

Parameters:
  • enabled (bool)

  • pre_eval (bool)

  • policy_path (str | None)
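The two behaviours described above (resolving a policy from a trainer attribute path and reducing multi-valued metrics to scalars) can be sketched with the standard library alone. This is only an illustration under stated assumptions: `resolve_attr_path`, `reduce_metric`, the `eval_reward` metric name, and the `SimpleNamespace` trainer are hypothetical stand-ins. Nested lists take the place of tensors. A mean is assumed as the reduction, which the source does not specify.

```python
from functools import reduce
from statistics import fmean
from types import SimpleNamespace


def resolve_attr_path(root, path):
    """Follow a dotted attribute path such as 'loss_module.actor_network'."""
    return reduce(getattr, path.split("."), root)


def reduce_metric(value):
    """Collapse a number or a (possibly nested) list of numbers to one float.

    Assumption: the mean is used as the reduction; the real hook may differ.
    """
    if isinstance(value, (int, float)):
        return float(value)
    flat = []
    stack = [value]
    while stack:
        item = stack.pop()
        if isinstance(item, (list, tuple)):
            stack.extend(item)  # descend into nested containers
        else:
            flat.append(float(item))
    return fmean(flat)


# Hypothetical stand-in for a trainer that exposes the policy as an attribute.
trainer = SimpleNamespace(loss_module=SimpleNamespace(actor_network="policy"))
policy = resolve_attr_path(trainer, "loss_module.actor_network")

# A multi-valued metric (for example, one reward per agent and environment).
metrics = {"eval_reward": [[1.0, 3.0], [2.0, 6.0]]}
scalars = {name: reduce_metric(value) for name, value in metrics.items()}
```

Reducing before logging matters because most scalar loggers expect one number per metric name and step.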

enabled = True[source]#
pre_eval = False[source]#
policy_path = None[source]#
_trainer: torchrl.trainers.trainers.Trainer | None = None[source]#
_run_pre_eval()[source]#
Return type:

None

__call__(batch)[source]#
Parameters:

batch (tensordict.TensorDictBase)

Return type:

dict[str, Any]

register(trainer, name='validation_reward')[source]#

Registers the hook in the trainer at a default location.

Parameters:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

Return type:

None

Note

To register the hook at another location than the default, use register_op().

state_dict()[source]#
Return type:

dict[str, Any]

load_state_dict(state_dict)[source]#
Parameters:

state_dict (dict[str, Any])

Return type:

None

class xdrl.configs.hooks.GAEHook(*, gamma, lmbda, value_network_path='loss_module.critic_network', reward_key='reward', done_key='done', terminated_key='terminated', advantage_key='advantage', value_target_key='value_target', value_key='state_value', average_gae=True)[source]#

Bases: torch.nn.Module, torchrl.trainers.trainers.TrainerHookBase

Compute generalized advantage estimation from a trainer hook.

The hook resolves the value network from the trainer during registration and applies TorchRL’s GAE module at pre_epoch. This is useful when the trainer’s built-in GAE path is disabled or when key names need to be driven from Hydra.

Parameters:
  • gamma (float)

  • lmbda (float)

  • value_network_path (str)

  • reward_key (Any)

  • done_key (Any)

  • terminated_key (Any)

  • advantage_key (Any)

  • value_target_key (Any)

  • value_key (Any)

  • average_gae (bool)
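For orientation, the quantity that generalized advantage estimation produces can be written as a short backward recursion. The sketch below uses plain Python lists instead of tensors and is not the TorchRL implementation. The function name `gae` and its argument layout are assumptions made for illustration. The standardization step that `average_gae` presumably enables is left out.

```python
def gae(rewards, values, next_values, dones, terminateds, gamma, lmbda):
    """Compute GAE advantages and value targets for a single trajectory.

    delta_t = r_t + gamma * V(s_{t+1}) * (1 - terminated_t) - V(s_t)
    A_t     = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_terminated = 1.0 - float(terminateds[t])  # no bootstrap past termination
        not_done = 1.0 - float(dones[t])  # no accumulation across episodes
        delta = rewards[t] + gamma * next_values[t] * not_terminated - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    # The value target is the advantage added back onto the state value.
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


advantages, value_targets = gae(
    rewards=[1.0, 1.0, 1.0],
    values=[0.0, 0.0, 0.0],
    next_values=[0.0, 0.0, 0.0],
    dones=[False, False, True],
    terminateds=[False, False, True],
    gamma=0.5,
    lmbda=1.0,
)
```

The inputs map onto the key parameters above: `reward_key`, `done_key`, and `terminated_key` name what is read, `value_key` names the state value, and `advantage_key` and `value_target_key` name the two outputs.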

gamma[source]#
lmbda[source]#
value_network_path = 'loss_module.critic_network'[source]#
reward_key = 'reward'[source]#
done_key = 'done'[source]#
terminated_key = 'terminated'[source]#
advantage_key = 'advantage'[source]#
value_target_key = 'value_target'[source]#
value_key = 'state_value'[source]#
average_gae = True[source]#
gae: torchrl.objectives.value.GAE | None = None[source]#
__call__(batch)[source]#
Parameters:

batch (tensordict.TensorDictBase)

Return type:

tensordict.TensorDictBase

register(trainer, name='gae_hook')[source]#

Registers the hook in the trainer at a default location.

Parameters:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

Return type:

None

Note

To register the hook at another location than the default, use register_op().

xdrl.configs.hooks._make_logging_hook_set(*, group, frame_skip, reward_key=None, done_key=('next', 'done'), episode_reward_key=None, episode_reward_weights=None, reduce_stats=None, eval_hook_set=None)[source]#
Parameters:
  • group (str)

  • frame_skip (int)

  • reward_key (Any | None)

  • done_key (Any)

  • episode_reward_key (Any | None)

  • episode_reward_weights (list[float] | None)

  • reduce_stats (bool | None)

  • eval_hook_set (Any)

Return type:

xdrl.trainer_hooks.LoggingHookSet

xdrl.configs.hooks._make_policy_checkpoint_hook(*, directory, interval, policy=None, policy_path=None, prefix='policy', destination='post_steps', meta=None)[source]#
Parameters:
  • directory (str)

  • interval (int)

  • policy (Any)

  • policy_path (str | None)

  • prefix (str)

  • destination (str)

  • meta (dict[str, Any] | None)

Return type:

xdrl.trainer_hooks.PolicyCheckpointHook

xdrl.configs.hooks._make_log_validation_reward_hook(*, environment, interval_frames, frames_per_batch, record_frames, policy_exploration=None, policy_path=None, enabled=True, pre_eval=False, frame_skip=1, exploration_type='DETERMINISTIC', log_keys=None, out_keys=None, log_pbar=False)[source]#
Parameters:
  • environment (Any)

  • interval_frames (int)

  • frames_per_batch (int)

  • record_frames (int)

  • policy_exploration (Any)

  • policy_path (str | None)

  • enabled (bool)

  • pre_eval (bool)

  • frame_skip (int)

  • exploration_type (str)

  • log_keys (list[Any] | None)

  • out_keys (dict[Any, str] | None)

  • log_pbar (bool)

Return type:

torchrl.trainers.trainers.LogValidationReward

class xdrl.configs.hooks.LoggingHookSetConfig[source]#

Bases: torchrl.trainers.algorithms.configs.common.ConfigBase

Hydra config for xdrl.trainer_hooks.LoggingHookSet.

group: str = 'agents'[source]#
frame_skip: int = 1[source]#
reward_key: Any | None = None[source]#
done_key: Any = ('next', 'done')[source]#
episode_reward_key: Any | None = None[source]#
episode_reward_weights: list[float] | None = None[source]#
reduce_stats: bool | None = None[source]#
eval_hook_set: Any = None[source]#
_target_: str = 'xdrl.configs.hooks._make_logging_hook_set'[source]#
__post_init__()[source]#

Post-initialization hook for configuration classes.

Return type:

None
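The `_target_` field is what lets Hydra turn a config like this into an object: the dotted path is imported and called with the remaining fields as keyword arguments. The sketch below mimics that mechanism with the standard library so it runs without Hydra or xdrl installed. The `instantiate` helper here is a simplified stand-in, not Hydra's own function, and `datetime.timedelta` is used purely as a convenient importable target.

```python
import importlib


def instantiate(config):
    """Simplified stand-in for Hydra-style instantiation from `_target_`."""
    kwargs = dict(config)  # copy so the caller's config is not mutated
    module_name, _, attr_name = kwargs.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_name), attr_name)
    return target(**kwargs)


# With xdrl installed, the target would instead be
# 'xdrl.configs.hooks._make_logging_hook_set' with fields such as `group`.
config = {"_target_": "datetime.timedelta", "days": 1, "hours": 12}
obj = instantiate(config)
```

In a real project the equivalent call would be `hydra.utils.instantiate`, which additionally handles nested configs and partial instantiation.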

class xdrl.configs.hooks.PolicyCheckpointHookConfig[source]#

Bases: torchrl.trainers.algorithms.configs.common.ConfigBase

Hydra config for periodic policy checkpointing.

policy: Any = None[source]#
policy_path: str | None = None[source]#
directory: str = 'checkpoints/policy'[source]#
interval: int = 0[source]#
prefix: str = 'policy'[source]#
destination: str = 'post_steps'[source]#
meta: dict[str, Any] | None[source]#
_target_: str = 'xdrl.configs.hooks._make_policy_checkpoint_hook'[source]#
__post_init__()[source]#

Post-initialization hook for configuration classes.

Return type:

None

class xdrl.configs.hooks.GAEHookConfig[source]#

Bases: torchrl.trainers.algorithms.configs.common.ConfigBase

Hydra config for explicit GAE computation as a trainer hook.

gamma: float = 0.99[source]#
lmbda: float = 0.95[source]#
value_network_path: str = 'loss_module.critic_network'[source]#
reward_key: Any = 'reward'[source]#
done_key: Any = 'done'[source]#
terminated_key: Any = 'terminated'[source]#
advantage_key: Any = 'advantage'[source]#
value_target_key: Any = 'value_target'[source]#
value_key: Any = 'state_value'[source]#
average_gae: bool = True[source]#
_target_: str = 'xdrl.configs.hooks.GAEHook'[source]#
__post_init__()[source]#

Post-initialization hook for configuration classes.

Return type:

None

class xdrl.configs.hooks.LogValidationRewardHookConfig[source]#

Bases: torchrl.trainers.algorithms.configs.common.ConfigBase

Hydra config for scalar-friendly TorchRL validation reward logging.

policy_exploration: Any = None[source]#
policy_path: str | None = None[source]#
environment: Any = None[source]#
enabled: bool = True[source]#
pre_eval: bool = False[source]#
interval_frames: int = 100000[source]#
frames_per_batch: int = 1[source]#
record_frames: int = 100[source]#
frame_skip: int = 1[source]#
exploration_type: str = 'DETERMINISTIC'[source]#
log_keys: list[Any] | None = [['next', 'agents', 'reward']][source]#
out_keys: dict[Any, str] | None = None[source]#
log_pbar: bool = False[source]#
_target_: str = 'xdrl.configs.hooks._make_log_validation_reward_hook'[source]#
__post_init__()[source]#

Post-initialization hook for configuration classes.

Return type:

None
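The default `log_keys` value is a list of lists because YAML has no tuple type, whereas nested TensorDict keys are tuples of strings. A plausible reading, and it is only an assumption about what the factory does, is that each inner list is converted to a tuple before lookup. The sketch below shows that conversion, with nested dictionaries standing in for a TensorDict. `to_nested_key` and `get_nested` are hypothetical helpers.

```python
def to_nested_key(key):
    """Turn a YAML-style list key into a tuple; leave string keys unchanged."""
    return tuple(key) if isinstance(key, list) else key


def get_nested(data, key):
    """Look up a string key or a tuple key in nested dictionaries."""
    if isinstance(key, str):
        return data[key]
    for part in key:
        data = data[part]
    return data


log_keys = [["next", "agents", "reward"]]  # default from the config above
rollout = {"next": {"agents": {"reward": [0.5, 1.5]}}}  # stand-in for a TensorDict

logged = {
    to_nested_key(key): get_nested(rollout, to_nested_key(key))
    for key in log_keys
}
```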

class xdrl.configs.hooks.WandbFinishHookConfig[source]#

Bases: torchrl.trainers.algorithms.configs.common.ConfigBase

Hydra config for optional W&B shutdown cleanup.

enabled: bool = True[source]#
_target_: str = 'xdrl.trainer_hooks.logging.WandbFinishHook'[source]#
__post_init__()[source]#

Post-initialization hook for configuration classes.

Return type:

None

class xdrl.configs.hooks.WandbFlushHookConfig[source]#

Bases: torchrl.trainers.algorithms.configs.common.ConfigBase

Hydra config for flushing pending Weights & Biases scalar rows.

enabled: bool = True[source]#
_target_: str = 'xdrl.trainer_hooks.logging.WandbFlushHook'[source]#
__post_init__()[source]#

Post-initialization hook for configuration classes.

Return type:

None