Source code for eta_utility.eta_x.common.extractors
from __future__ import annotations

from collections.abc import Mapping, Sequence
from logging import getLogger
from typing import TYPE_CHECKING

import torch as th
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.utils import get_device

from .common import deserialize_net_arch

if TYPE_CHECKING:
    from typing import Any

    import gymnasium

log = getLogger(__name__)
class CustomExtractor(BaseFeaturesExtractor):
    """
    Advanced feature extractor which allows the definition of arbitrary network structures. Layers can be any
    of the layers defined in `torch.nn <https://pytorch.org/docs/stable/nn.html>`_. The net_arch parameter is
    interpreted by the function :py:func:`eta_utility.eta_x.common.common.deserialize_net_arch`.

    :param observation_space: gymnasium observation space of the environment.
    :param net_arch: The architecture of the advanced feature extractor. See
        :py:func:`eta_utility.eta_x.common.deserialize_net_arch` for the accepted syntax.
    :param device: Torch device used for training.
    """
    def __init__(
        self,
        observation_space: gymnasium.Space,
        *,
        net_arch: Sequence[Mapping[str, Any]],
        device: th.device | str = "auto",
    ):
        device = get_device(device)

        # Build the torch network from the architecture description.
        network = deserialize_net_arch(net_arch, in_features=observation_space.shape[0], device=device)  # type: ignore

        # Determine the output dimension of the network by passing a sampled
        # observation through it once (no gradients needed).
        with th.no_grad():
            output = network(th.as_tensor(observation_space.sample()[None]).float())

        super().__init__(observation_space, output.shape[1])
        self.network = network
    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Perform a forward pass through the network.

        :param observations: Observations to pass through the network.
        :return: Output of the network.
        """
        return self.network(observations)
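
# Usage sketch (not part of the library source): one hedged way to plug CustomExtractor
# into a stable-baselines3 algorithm, via the ``features_extractor_class`` and
# ``features_extractor_kwargs`` entries of ``policy_kwargs``. The ``net_arch`` entries
# shown below are hypothetical placeholders; the accepted keys are defined by
# :py:func:`eta_utility.eta_x.common.deserialize_net_arch`.
#
#     from stable_baselines3 import PPO
#
#     policy_kwargs = {
#         "features_extractor_class": CustomExtractor,
#         "features_extractor_kwargs": {
#             "net_arch": [
#                 {"layer": "Linear", "out_features": 64},  # hypothetical entry
#                 {"activation_func": "ReLU"},              # hypothetical entry
#             ],
#         },
#     }
#     model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs)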