from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import torch
from tensordict import TensorDict
from tensordict import TensorDictBase
from torchrl.trainers.algorithms.configs.common import ConfigBase, _normalize_hydra_key
from torchrl.objectives.value import GAE
from torchrl.trainers.trainers import LogValidationReward, Trainer, TrainerHookBase, _resolve_module
from xdrl.trainer_hooks import LoggingHookSet, PolicyCheckpointHook
class ReducedLogValidationReward(LogValidationReward):
    """Validation reward hook that emits scalar-friendly metrics.

    TorchRL's validation hook can return tensor metrics. This subclass keeps the
    upstream rollout behavior and averages every tensor metric that has more
    than one element before it is logged. It can also resolve the evaluation
    policy from a trainer attribute path.

    Args:
        enabled: When false, ``register`` only registers the module (no ops),
            and ``__call__`` returns an empty dict.
        pre_eval: When true (and enabled), an extra validation pass is
            registered on the trainer's ``"setup"`` op.
        policy_path: Trainer attribute path handed to ``_resolve_module`` when
            ``policy_exploration`` is still ``None`` at registration time.
        **kwargs: Forwarded unchanged to ``LogValidationReward.__init__``.
    """

    def __init__(
        self,
        *,
        enabled: bool = True,
        pre_eval: bool = False,
        policy_path: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.enabled = bool(enabled)
        self.pre_eval = bool(pre_eval)
        self.policy_path = policy_path
        # Filled in by ``register``; ``_run_pre_eval`` is a no-op until then.
        self._trainer: Trainer | None = None

    def _run_pre_eval(self) -> None:
        """Run one validation pass and log its metrics through the trainer."""
        trainer = self._trainer
        if trainer is None:
            return
        # Only an empty placeholder batch is passed. This assumes the parent
        # hook does not read the batch contents (TODO confirm against TorchRL).
        placeholder = TensorDict({}, batch_size=[])
        metrics = self(placeholder)
        if metrics:
            trainer._log(**metrics)

    @torch.inference_mode()
    def __call__(self, batch: TensorDictBase) -> dict[str, Any]:
        """Run the parent hook and mean-reduce multi-element tensor metrics.

        Returns an empty dict when the hook is disabled or when the parent
        returns ``None`` (presumably on calls where no validation is due).
        Otherwise returns a dict with the same keys: tensors with more than
        one element are replaced by their float mean, and all other values
        pass through unchanged.
        """
        if not self.enabled:
            return {}
        raw_metrics = super().__call__(batch)
        if raw_metrics is None:
            return {}

        def _scalarize(value: Any) -> Any:
            is_multi_element_tensor = isinstance(value, torch.Tensor) and value.numel() > 1
            return value.float().mean() if is_multi_element_tensor else value

        return {key: _scalarize(value) for key, value in raw_metrics.items()}

    def register(self, trainer: Trainer, name: str = "validation_reward") -> None:
        """Attach the hook to ``trainer``.

        The module is always registered. When enabled, the policy is resolved
        from ``policy_path`` if needed. The optional pre-eval pass is scheduled
        on ``"setup"``, and the hook itself runs on ``"post_steps_log"``.
        """
        self._trainer = trainer
        trainer.register_module(name, self)
        if not self.enabled:
            return
        needs_policy_lookup = self.policy_exploration is None and self.policy_path is not None
        if needs_policy_lookup:
            self.policy_exploration = _resolve_module(trainer, self.policy_path)
        if self.pre_eval:
            trainer.register_op("setup", self._run_pre_eval)
        trainer.register_op("post_steps_log", self)

    def state_dict(self) -> dict[str, Any]:
        """Return the parent's state (only when an environment is attached) plus this hook's flags."""
        # The parent state is only requested when an environment exists. The
        # parent presumably serializes environment state (see the
        # "recorder_state_dict" check in ``load_state_dict``).
        parent_state = super().state_dict() if self.environment is not None else {}
        return {
            **parent_state,
            "enabled": self.enabled,
            "pre_eval": self.pre_eval,
            "policy_path": self.policy_path,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore the hook flags, then the parent state when it is present and usable."""
        self.enabled = bool(state_dict.get("enabled", self.enabled))
        self.pre_eval = bool(state_dict.get("pre_eval", self.pre_eval))
        self.policy_path = state_dict.get("policy_path", self.policy_path)
        has_environment = self.environment is not None
        if has_environment and "recorder_state_dict" in state_dict:
            super().load_state_dict(state_dict)
class GAEHook(torch.nn.Module, TrainerHookBase):
    """Compute generalized advantage estimation from a trainer hook.

    During ``register`` the hook resolves the value network from the trainer
    and builds a TorchRL ``GAE`` module, which it then applies at
    ``pre_epoch``. This is useful when the trainer's built-in GAE path is
    disabled, or when key names need to be driven from Hydra.

    Args:
        gamma: Discount factor passed to ``GAE`` (stored as ``float``).
        lmbda: GAE lambda passed to ``GAE`` (stored as ``float``).
        value_network_path: Trainer attribute path of the value network.
        reward_key, done_key, terminated_key, advantage_key,
        value_target_key, value_key: TensorDict keys. Each is normalized with
            ``_normalize_hydra_key`` and forwarded to ``GAE.set_keys``.
        average_gae: Forwarded to ``GAE``.
    """

    def __init__(
        self,
        *,
        gamma: float,
        lmbda: float,
        value_network_path: str = "loss_module.critic_network",
        reward_key: Any = "reward",
        done_key: Any = "done",
        terminated_key: Any = "terminated",
        advantage_key: Any = "advantage",
        value_target_key: Any = "value_target",
        value_key: Any = "state_value",
        average_gae: bool = True,
    ) -> None:
        super().__init__()
        self.gamma = float(gamma)
        self.lmbda = float(lmbda)
        self.value_network_path = value_network_path
        # Normalize every configurable key in a fixed order.
        for attribute, raw_key in (
            ("reward_key", reward_key),
            ("done_key", done_key),
            ("terminated_key", terminated_key),
            ("advantage_key", advantage_key),
            ("value_target_key", value_target_key),
            ("value_key", value_key),
        ):
            setattr(self, attribute, _normalize_hydra_key(raw_key))
        self.average_gae = bool(average_gae)
        # Built lazily in ``register`` once the value network can be resolved.
        self.gae: GAE | None = None

    def __call__(self, batch: TensorDictBase) -> TensorDictBase:
        """Apply the configured ``GAE`` module to ``batch`` without tracking gradients.

        Raises:
            RuntimeError: If ``register`` has not been called yet.
        """
        estimator = self.gae
        if estimator is None:
            raise RuntimeError("GAEHook must be registered before it is called.")
        with torch.no_grad():
            return estimator(batch)

    def register(self, trainer: Trainer, name: str = "gae_hook") -> None:
        """Build the ``GAE`` estimator, register this module, and schedule it at ``pre_epoch``."""
        estimator = GAE(
            gamma=self.gamma,
            lmbda=self.lmbda,
            value_network=_resolve_module(trainer, self.value_network_path),
            average_gae=self.average_gae,
        )
        self.gae = estimator
        estimator.set_keys(
            reward=self.reward_key,
            done=self.done_key,
            terminated=self.terminated_key,
            advantage=self.advantage_key,
            value_target=self.value_target_key,
            value=self.value_key,
        )
        trainer.register_module(name, self)
        trainer.register_op("pre_epoch", self)
def _make_logging_hook_set(
    *,
    group: str,
    frame_skip: int,
    reward_key: Any | None = None,
    done_key: Any = ("next", "done"),
    episode_reward_key: Any | None = None,
    episode_reward_weights: list[float] | None = None,
    reduce_stats: bool | None = None,
    eval_hook_set: Any = None,
) -> LoggingHookSet:
    """Build a ``LoggingHookSet`` from Hydra-provided values.

    ``reward_key`` and ``episode_reward_key`` are normalized with
    ``_normalize_hydra_key`` unless they are ``None``, in which case they are
    passed as ``None``. ``done_key`` is always normalized. All other arguments
    are forwarded unchanged.
    """

    def _maybe_normalize(key: Any) -> Any:
        # ``None`` means "not configured" and is forwarded untouched.
        return key if key is None else _normalize_hydra_key(key)

    return LoggingHookSet(
        group=group,
        frame_skip=frame_skip,
        reward_key=_maybe_normalize(reward_key),
        done_key=_normalize_hydra_key(done_key),
        episode_reward_key=_maybe_normalize(episode_reward_key),
        episode_reward_weights=episode_reward_weights,
        reduce_stats=reduce_stats,
        eval_hook_set=eval_hook_set,
    )
def _make_policy_checkpoint_hook(
    *,
    directory: str,
    interval: int,
    policy: Any = None,
    policy_path: str | None = None,
    prefix: str = "policy",
    destination: str = "post_steps",
    meta: dict[str, Any] | None = None,
) -> PolicyCheckpointHook:
    """Hydra factory that forwards every argument unchanged to ``PolicyCheckpointHook``."""
    hook_kwargs: dict[str, Any] = {
        "directory": directory,
        "interval": interval,
        "policy": policy,
        "policy_path": policy_path,
        "prefix": prefix,
        "destination": destination,
        "meta": meta,
    }
    return PolicyCheckpointHook(**hook_kwargs)
def _make_log_validation_reward_hook(
    *,
    environment: Any,
    interval_frames: int,
    frames_per_batch: int,
    record_frames: int,
    policy_exploration: Any = None,
    policy_path: str | None = None,
    enabled: bool = True,
    pre_eval: bool = False,
    frame_skip: int = 1,
    exploration_type: Any = "DETERMINISTIC",
    log_keys: list[Any] | None = None,
    out_keys: dict[Any, str] | None = None,
    log_pbar: bool = False,
) -> LogValidationReward:
    """Build a ``ReducedLogValidationReward`` from Hydra-friendly arguments.

    Args:
        environment: Evaluation environment forwarded to the hook.
        interval_frames: Desired spacing between validation runs, in frames.
            Converted to ``record_interval`` as
            ``max(1, interval_frames // max(1, frames_per_batch * frame_skip))``.
            This presumably assumes the hook is called once per collected batch.
        frames_per_batch: Frames per collected batch (used in the conversion).
        record_frames: Forwarded as ``record_frames``.
        policy_exploration: Policy forwarded to the hook (may be ``None``).
        policy_path: Trainer attribute path used at registration when the
            policy is ``None``.
        enabled: Forwarded to ``ReducedLogValidationReward``.
        pre_eval: Forwarded to ``ReducedLogValidationReward``.
        frame_skip: Forwarded to the hook and used in the interval conversion.
        exploration_type: An ``ExplorationType`` member, or its name
            (case-insensitive, e.g. ``"deterministic"``).
        log_keys: Keys to log. Each is normalized with ``_normalize_hydra_key``.
            Defaults to ``[("next", "agents", "reward")]``.
        out_keys: Optional mapping from log key to metric name. Any log key it
            does not cover gets a default name: ``"r_evaluation"`` for
            ``("next", "reward")``, ``"validation/<a>/<b>/..."`` for tuple keys,
            and ``"validation/<key>"`` otherwise.
        log_pbar: Forwarded as ``log_pbar``.

    Returns:
        A configured ``ReducedLogValidationReward``.

    Raises:
        AttributeError: If ``exploration_type`` does not name an
            ``ExplorationType`` member.
    """
    from torchrl.envs import ExplorationType

    # Convert a frame budget into a call interval. Both the divisor and the
    # result are clamped to >= 1, so the interval is never zero.
    frames_per_call = max(1, int(frames_per_batch) * int(frame_skip))
    record_interval = max(1, int(interval_frames) // frames_per_call)

    if isinstance(exploration_type, ExplorationType):
        exploration = exploration_type
    else:
        exploration_name = str(exploration_type).upper()
        try:
            exploration = getattr(ExplorationType, exploration_name)
        except AttributeError as err:
            valid = ", ".join(ExplorationType.__members__)
            msg = f"Unknown exploration_type {exploration_type!r}; expected one of: {valid}."
            raise AttributeError(msg) from err

    if log_keys is None:
        normalized_log_keys = [("next", "agents", "reward")]
    else:
        normalized_log_keys = [_normalize_hydra_key(key) for key in log_keys]

    normalized_out_keys: dict[Any, str] = {}
    if out_keys is not None:
        normalized_out_keys = {_normalize_hydra_key(key): value for key, value in out_keys.items()}

    # Give every log key an output name, keeping any name the caller supplied.
    # A partial ``out_keys`` mapping would otherwise leave some log keys
    # unnamed. The hook presumably looks up ``out_keys[key]`` for each log key
    # at evaluation time (TODO confirm against TorchRL).
    for key in normalized_log_keys:
        if key in normalized_out_keys:
            continue
        if key == ("next", "reward"):
            normalized_out_keys[key] = "r_evaluation"
        elif isinstance(key, tuple):
            normalized_out_keys[key] = "validation/" + "/".join(str(item) for item in key)
        else:
            normalized_out_keys[key] = f"validation/{key}"

    return ReducedLogValidationReward(
        enabled=enabled,
        pre_eval=pre_eval,
        policy_path=policy_path,
        record_interval=record_interval,
        record_frames=record_frames,
        frame_skip=frame_skip,
        policy_exploration=policy_exploration,
        environment=environment,
        exploration_type=exploration,
        log_keys=normalized_log_keys,
        out_keys=normalized_out_keys,
        log_pbar=log_pbar,
    )
@dataclass
class LoggingHookSetConfig(ConfigBase):
    """Hydra config for ``xdrl.trainer_hooks.LoggingHookSet``.

    The target factory also requires ``group`` and ``frame_skip``. They are
    not fields here, so they are presumably supplied at instantiation time
    (TODO confirm against the caller).
    """

    # Reward key to track; ``None`` is forwarded as-is.
    reward_key: Any | None = None
    # Done key; always normalized by the factory.
    done_key: Any = ("next", "done")
    # Optional episode-reward key; ``None`` is forwarded as-is.
    episode_reward_key: Any | None = None
    # Forwarded unchanged to ``LoggingHookSet``.
    episode_reward_weights: list[float] | None = None
    reduce_stats: bool | None = None
    eval_hook_set: Any = None
    _target_: str = "xdrl.configs.hooks._make_logging_hook_set"

    def __post_init__(self) -> None:
        """No post-init validation is performed."""
@dataclass
class PolicyCheckpointHookConfig(ConfigBase):
    """Hydra config for periodic policy checkpointing.

    Targets ``_make_policy_checkpoint_hook``. That factory also requires
    ``interval``, which is not a field here, so it must be supplied when the
    config is instantiated (TODO confirm against the caller). ``prefix`` and
    ``meta`` mirror the factory's defaults so they can be set from Hydra.
    """

    # Trainer attribute path of the policy to checkpoint.
    policy_path: str | None = None
    directory: str = "checkpoints/policy"
    # Trainer op destination the hook is attached to.
    destination: str = "post_steps"
    # Declared after the original fields so the positional field order stays
    # stable. Defaults match ``_make_policy_checkpoint_hook``.
    prefix: str = "policy"
    meta: dict[str, Any] | None = None
    _target_: str = "xdrl.configs.hooks._make_policy_checkpoint_hook"

    def __post_init__(self) -> None:
        """No post-init validation is performed."""
@dataclass
class GAEHookConfig(ConfigBase):
    """Hydra config for explicit GAE computation as a trainer hook.

    Fields mirror the keyword arguments of ``GAEHook``. ``GAEHook`` requires
    ``gamma`` and ``lmbda``, which have no defaults and are not fields here,
    so they must be supplied when the config is instantiated (TODO confirm
    against the caller).
    """

    # Trainer attribute path resolved to the value network at registration.
    value_network_path: str = "loss_module.critic_network"
    # TensorDict keys, normalized by ``GAEHook`` and passed to ``GAE.set_keys``.
    reward_key: Any = "reward"
    terminated_key: Any = "terminated"
    advantage_key: Any = "advantage"
    value_target_key: Any = "value_target"
    value_key: Any = "state_value"
    average_gae: bool = True
    # ``GAEHook`` accepts ``done_key`` too. It is exposed here with the same
    # default so it can be overridden like the other keys. It is declared
    # after the original fields so the positional field order stays stable.
    done_key: Any = "done"
    _target_: str = "xdrl.configs.hooks.GAEHook"

    def __post_init__(self) -> None:
        """No post-init validation is performed."""
@dataclass
class LogValidationRewardHookConfig(ConfigBase):
    """Hydra config for scalar-friendly TorchRL validation reward logging.

    Targets ``_make_log_validation_reward_hook``. Every optional factory
    argument is exposed with the factory's default, including ``enabled``,
    ``pre_eval``, ``frame_skip`` and ``log_pbar``.
    """

    # Policy object; when ``None``, ``policy_path`` is resolved at registration.
    policy_exploration: Any = None
    policy_path: str | None = None
    environment: Any = None
    # Spacing between validation runs in frames, converted by the factory.
    interval_frames: int = 100_000
    frames_per_batch: int = 1
    record_frames: int = 100
    # ``ExplorationType`` member name, case-insensitive.
    exploration_type: str = "DETERMINISTIC"
    log_keys: list[Any] | None = field(default_factory=lambda: [["next", "agents", "reward"]])
    # Optional log-key -> metric-name mapping; uncovered keys get defaults.
    out_keys: dict[Any, str] | None = None
    # Declared after the original fields so the positional field order stays
    # stable. Defaults match ``_make_log_validation_reward_hook``.
    enabled: bool = True
    pre_eval: bool = False
    frame_skip: int = 1
    log_pbar: bool = False
    _target_: str = "xdrl.configs.hooks._make_log_validation_reward_hook"

    def __post_init__(self) -> None:
        """No post-init validation is performed."""
@dataclass
class WandbFinishHookConfig(ConfigBase):
    """Hydra config for optional W&B shutdown cleanup.

    The config has no parameters; it only names the hook class to instantiate.
    """

    _target_: str = "xdrl.trainer_hooks.logging.WandbFinishHook"

    def __post_init__(self) -> None:
        """No post-init validation is performed."""
@dataclass
class WandbFlushHookConfig(ConfigBase):
    """Hydra config for flushing pending Weights & Biases scalar rows.

    The config has no parameters; it only names the hook class to instantiate.
    """

    _target_: str = "xdrl.trainer_hooks.logging.WandbFlushHook"

    def __post_init__(self) -> None:
        """No post-init validation is performed."""