xdrl.configs.hooks
Classes
ReducedLogValidationReward | Validation reward hook that emits scalar-friendly metrics.
GAEHook | Compute generalized advantage estimation from a trainer hook.
LoggingHookSetConfig | Hydra config for xdrl.trainer_hooks.LoggingHookSet.
PolicyCheckpointHookConfig | Hydra config for periodic policy checkpointing.
GAEHookConfig | Hydra config for explicit GAE computation as a trainer hook.
LogValidationRewardHookConfig | Hydra config for scalar-friendly TorchRL validation reward logging.
WandbFinishHookConfig | Hydra config for optional W&B shutdown cleanup.
 | Hydra config for flushing pending Weights & Biases scalar rows.
Functions
Module Contents
- class xdrl.configs.hooks.ReducedLogValidationReward(*, enabled=True, pre_eval=False, policy_path=None, **kwargs)
Bases: torchrl.trainers.trainers.LogValidationReward
Validation reward hook that emits scalar-friendly metrics.
TorchRL’s validation hook can return tensor metrics. This subclass keeps the upstream rollout behavior but reduces multi-valued tensors before logging and optionally resolves the evaluation policy from a trainer attribute path.
- Parameters:
enabled (bool)
pre_eval (bool)
policy_path (str | None) – optional trainer attribute path from which the evaluation policy is resolved.
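TorchRL's validation hook can return tensor metrics, and the class description above says this subclass reduces multi-valued tensors before logging. The sketch below shows one such reduction, assuming a mean over all elements; the reduction ReducedLogValidationReward actually applies is not documented here, and the helpers `to_scalar` and `reduce_metrics` are hypothetical.

```python
from numbers import Real
from statistics import fmean
from typing import Any


def to_scalar(value: Any) -> float:
    """Reduce one metric to a float (hypothetical helper; mean reduction assumed)."""
    # torch.Tensor-like objects: average every element, then convert to a Python float.
    if hasattr(value, "numel") and hasattr(value, "mean"):
        return float(value.float().mean().item())
    # Plain Python numbers are already scalar.
    if isinstance(value, Real):
        return float(value)
    # Flat sequences of numbers (e.g. one reward per environment) are averaged.
    return fmean(float(v) for v in value)


def reduce_metrics(metrics: dict[str, Any]) -> dict[str, float]:
    """Apply to_scalar to every entry so the logger only receives scalars."""
    return {key: to_scalar(val) for key, val in metrics.items()}


print(reduce_metrics({"validation_reward": [1.0, 2.0, 3.0], "steps": 200}))
```

A scalar-only dictionary like this is the shape most experiment loggers accept without extra handling of tensor values.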
- register(trainer, name='validation_reward')
Registers the hook in the trainer at a default location.
- Parameters:
trainer (Trainer) – the trainer where the hook must be registered.
name (str) – the name of the hook.
- Return type:
None
Note
To register the hook at a location other than the default, use register_op().
- class xdrl.configs.hooks.GAEHook(*, gamma, lmbda, value_network_path='loss_module.critic_network', reward_key='reward', done_key='done', terminated_key='terminated', advantage_key='advantage', value_target_key='value_target', value_key='state_value', average_gae=True)
Bases: torch.nn.Module, torchrl.trainers.trainers.TrainerHookBase
Compute generalized advantage estimation from a trainer hook.
The hook resolves the value network from the trainer during registration and applies TorchRL’s GAE module at the pre_epoch stage. This is useful when the trainer’s built-in GAE path is disabled or when key names need to be driven from Hydra.
- Parameters:
gamma (float)
lmbda (float)
value_network_path (str)
reward_key (Any)
done_key (Any)
terminated_key (Any)
advantage_key (Any)
value_target_key (Any)
value_key (Any)
average_gae (bool)
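GAEHook resolves its value network from the trainer at registration time using value_network_path, whose default 'loss_module.critic_network' reads as a dotted attribute path. A minimal sketch of that kind of lookup follows, assuming plain getattr traversal; the hook's real resolution logic is not shown in this reference, and `resolve_attr_path` and the stand-in trainer objects are hypothetical.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any


def resolve_attr_path(root: Any, path: str) -> Any:
    """Follow a dotted attribute path such as 'loss_module.critic_network'."""
    try:
        # getattr(getattr(root, "loss_module"), "critic_network"), for any depth.
        return reduce(getattr, path.split("."), root)
    except AttributeError as err:
        raise AttributeError(f"could not resolve {path!r} on {type(root).__name__}") from err


# Stand-in objects shaped like a trainer that holds a loss module with a critic.
critic = object()
trainer = SimpleNamespace(loss_module=SimpleNamespace(critic_network=critic))
print(resolve_attr_path(trainer, "loss_module.critic_network") is critic)  # True
```

Resolving at registration rather than construction means the Hydra config only needs to carry a string, while the actual module is looked up once the trainer exists.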
- __call__(batch)
- Parameters:
batch (tensordict.TensorDictBase)
- Return type:
tensordict.TensorDictBase
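For intuition about what __call__ computes, the sketch below writes the standard GAE recursion over plain Python lists, so it runs without PyTorch. It is not TorchRL's vectorized GAE module. It assumes that terminated stops bootstrapping from the next state value, that done (termination or truncation) stops the advantage from accumulating across a trajectory boundary, and that average_gae standardizes the advantages; `compute_gae` and `standardize` are hypothetical names.

```python
from statistics import fmean, pstdev


def compute_gae(rewards, values, next_values, dones, terminateds, gamma, lmbda):
    """Return (advantages, value_targets) for one time-ordered trajectory batch."""
    advantages = [0.0] * len(rewards)
    running = 0.0  # A_{t+1}, accumulated backwards in time
    for t in reversed(range(len(rewards))):
        # TD error: bootstrap from V(s_{t+1}) unless the episode terminated at step t.
        bootstrap = 0.0 if terminateds[t] else gamma * next_values[t]
        delta = rewards[t] + bootstrap - values[t]
        # A_t = delta_t + gamma * lmbda * A_{t+1}, reset at any episode boundary.
        carry = 0.0 if dones[t] else gamma * lmbda * running
        running = delta + carry
        advantages[t] = running
    # Value targets are advantages plus the current value estimates.
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


def standardize(xs, eps=1e-8):
    """Zero-mean, unit-variance rescaling (assumed behavior of average_gae)."""
    mean, std = fmean(xs), pstdev(xs)
    return [(x - mean) / (std + eps) for x in xs]


adv, target = compute_gae(
    rewards=[1.0, 1.0], values=[0.0, 0.0], next_values=[0.0, 0.0],
    dones=[False, True], terminateds=[False, True], gamma=0.5, lmbda=1.0,
)
print(adv, target)  # [1.5, 1.0] [1.5, 1.0]
```

With lmbda=1.0 and zero value estimates, the first advantage reduces to the discounted return 1 + 0.5 * 1 = 1.5, which is a quick sanity check on the recursion. The exact epsilon and standard-deviation convention used by average_gae may differ from this sketch.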
- register(trainer, name='gae_hook')
Registers the hook in the trainer at a default location.
- Parameters:
trainer (Trainer) – the trainer where the hook must be registered.
name (str) – the name of the hook.
- Return type:
None
Note
To register the hook at a location other than the default, use register_op().
- xdrl.configs.hooks._make_logging_hook_set(*, group, frame_skip, reward_key=None, done_key=('next', 'done'), episode_reward_key=None, episode_reward_weights=None, reduce_stats=None, eval_hook_set=None)
- Parameters:
group (str)
frame_skip (int)
reward_key (Any | None)
done_key (Any)
episode_reward_key (Any | None)
episode_reward_weights (list[float] | None)
reduce_stats (bool | None)
eval_hook_set (Any)
- Return type:
- xdrl.configs.hooks._make_policy_checkpoint_hook(*, directory, interval, policy=None, policy_path=None, prefix='policy', destination='post_steps', meta=None)
- Parameters:
directory (str)
interval (int)
policy (Any)
policy_path (str | None)
prefix (str)
destination (str)
meta (dict[str, Any] | None)
- Return type:
- xdrl.configs.hooks._make_log_validation_reward_hook(*, environment, interval_frames, frames_per_batch, record_frames, policy_exploration=None, policy_path=None, enabled=True, pre_eval=False, frame_skip=1, exploration_type='DETERMINISTIC', log_keys=None, out_keys=None, log_pbar=False)
- Parameters:
environment (Any)
interval_frames (int)
frames_per_batch (int)
record_frames (int)
policy_exploration (Any)
policy_path (str | None)
enabled (bool)
pre_eval (bool)
frame_skip (int)
exploration_type (str)
log_keys (list[Any] | None)
out_keys (dict[Any, str] | None)
log_pbar (bool)
- Return type:
torchrl.trainers.trainers.LogValidationReward
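_make_log_validation_reward_hook takes the validation interval in frames (interval_frames) together with frames_per_batch. Assuming the returned LogValidationReward hook counts its interval in collected batches, the factory has to turn a frame count into a batch count. The sketch below shows one plausible conversion that rounds up; it is not the factory's actual code, and `frames_to_batch_interval` is a hypothetical name.

```python
def frames_to_batch_interval(interval_frames: int, frames_per_batch: int) -> int:
    """Convert a validation interval in frames to a whole number of batches (hypothetical)."""
    if interval_frames <= 0 or frames_per_batch <= 0:
        raise ValueError("interval_frames and frames_per_batch must be positive")
    # Integer ceiling division: rounding up means validation never runs more often
    # than requested, and a positive interval always yields at least one batch.
    return -(-interval_frames // frames_per_batch)


print(frames_to_batch_interval(10_000, 1_000))  # 10
print(frames_to_batch_interval(2_500, 1_000))   # 3
print(frames_to_batch_interval(500, 1_000))     # 1
```

Rounding down instead would make validation run slightly more often than interval_frames when the two values do not divide evenly, and would need an explicit floor of one batch.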
- class xdrl.configs.hooks.LoggingHookSetConfig
Bases: torchrl.trainers.algorithms.configs.common.ConfigBase
Hydra config for xdrl.trainer_hooks.LoggingHookSet.
- class xdrl.configs.hooks.PolicyCheckpointHookConfig
Bases: torchrl.trainers.algorithms.configs.common.ConfigBase
Hydra config for periodic policy checkpointing.
- class xdrl.configs.hooks.GAEHookConfig
Bases: torchrl.trainers.algorithms.configs.common.ConfigBase
Hydra config for explicit GAE computation as a trainer hook.
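The config classes are Hydra structured configs built on TorchRL's ConfigBase, and their fields are not listed in this reference. As an illustration only, the plain dataclass below mirrors the documented GAEHook signature, assuming the usual Hydra convention of a `_target_` field that names the object to instantiate. `GAEHookConfigSketch` and its `_target_` value are hypothetical, and the real GAEHookConfig may use different field names or defaults.

```python
from dataclasses import asdict, dataclass


@dataclass
class GAEHookConfigSketch:
    """Hypothetical stand-in for GAEHookConfig; defaults copied from GAEHook's signature."""

    gamma: float  # required, as in GAEHook(*, gamma, ...)
    lmbda: float  # required, as in GAEHook(*, lmbda, ...)
    _target_: str = "xdrl.configs.hooks.GAEHook"  # assumed Hydra instantiation target
    value_network_path: str = "loss_module.critic_network"
    reward_key: str = "reward"
    done_key: str = "done"
    terminated_key: str = "terminated"
    advantage_key: str = "advantage"
    value_target_key: str = "value_target"
    value_key: str = "state_value"
    average_gae: bool = True


cfg = GAEHookConfigSketch(gamma=0.99, lmbda=0.95)
print(asdict(cfg)["_target_"])  # xdrl.configs.hooks.GAEHook
```

The dataclass itself runs without Hydra. With Hydra installed, a mapping carrying a `_target_` key, such as asdict(cfg), is the kind of input hydra.utils.instantiate accepts, which is how key names like reward_key can be driven from a config file instead of code.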
- class xdrl.configs.hooks.LogValidationRewardHookConfig
Bases: torchrl.trainers.algorithms.configs.common.ConfigBase
Hydra config for scalar-friendly TorchRL validation reward logging.
- class xdrl.configs.hooks.WandbFinishHookConfig
Bases: torchrl.trainers.algorithms.configs.common.ConfigBase
Hydra config for optional W&B shutdown cleanup.