eta_utility.eta_x.common.extractors module

class eta_utility.eta_x.common.extractors.CustomExtractor(observation_space: gymnasium.Space, *, net_arch: Sequence[Mapping[str, Any]], device: th.device | str = 'auto')

Bases: BaseFeaturesExtractor

Advanced feature extractor that allows the definition of arbitrary network structures. Layers can be any of the layers defined in torch.nn. The net_arch parameter is interpreted by the function eta_utility.eta_x.common.common.deserialize_net_arch().
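
A small pure-Python sketch can show how a net_arch-style list could be turned into a chain of layers. This is not the library's implementation. The "layer" key, the LAYER_REGISTRY of factories and build_layers are hypothetical, and strings stand in for torch.nn modules. The sketch looks up each named layer, passes the remaining keys as keyword arguments, and carries the output size of one layer forward as the input size of the next.

```python
from typing import Any, Callable, Mapping, Sequence

# Hypothetical stand-ins for a few torch.nn classes. Each factory receives the
# number of input features plus keyword arguments and returns a
# (description, out_features) pair instead of a real module.
def _linear(in_features: int, *, out_features: int, bias: bool = True) -> tuple[str, int]:
    return f"Linear(in_features={in_features}, out_features={out_features}, bias={bias})", out_features


def _activation(name: str) -> Callable[..., tuple[str, int]]:
    def factory(in_features: int) -> tuple[str, int]:
        return f"{name}()", in_features  # activations keep the feature count
    return factory


LAYER_REGISTRY: dict[str, Callable[..., tuple[str, int]]] = {
    "Linear": _linear,
    "ReLU": _activation("ReLU"),
    "Tanh": _activation("Tanh"),
}


def build_layers(net_arch: Sequence[Mapping[str, Any]], in_features: int) -> tuple[list[str], int]:
    """Turn a list of layer mappings into layer descriptions, chaining feature counts."""
    layers: list[str] = []
    features = in_features
    for spec in net_arch:
        kwargs = dict(spec)
        name = kwargs.pop("layer")  # hypothetical key naming the layer class
        description, features = LAYER_REGISTRY[name](features, **kwargs)
        layers.append(description)
    return layers, features


net_arch = [
    {"layer": "Linear", "out_features": 64},
    {"layer": "ReLU"},
    {"layer": "Linear", "out_features": 32},
]
layers, features_dim = build_layers(net_arch, in_features=10)
# layers -> ['Linear(in_features=10, out_features=64, bias=True)', 'ReLU()',
#            'Linear(in_features=64, out_features=32, bias=True)']
# features_dim -> 32
```

The final feature count (32 in this sketch) plays the role of the output size that a Stable-Baselines3 feature extractor reports through its features_dim attribute.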

Parameters:
  • observation_space – Gymnasium observation space of the environment.

  • net_arch – Architecture of the feature extractor, given as a sequence of mappings. See eta_utility.eta_x.common.deserialize_net_arch() for the syntax.

  • device – Torch device used for training (default: 'auto').

forward(observations: Tensor) → Tensor

Perform a forward pass through the network.

Parameters:
  • observations – Observations to pass through the network.

Returns:

Output of the network.
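
In a sequentially built network, a forward pass applies the layers in order and feeds each output into the next layer. The pure-Python sketch below illustrates that idea with plain lists standing in for tensors. It assumes the extractor's layers are chained sequentially, which the description above suggests but does not state outright. The names forward, scale and relu are hypothetical helpers, not part of eta_utility or torch.

```python
from typing import Callable, Sequence

Vector = list[float]              # stand-in for a torch Tensor
Layer = Callable[[Vector], Vector]


def forward(layers: Sequence[Layer], observations: Vector) -> Vector:
    """Apply each layer in order, feeding every output into the next layer."""
    x = observations
    for layer in layers:
        x = layer(x)
    return x


def scale(factor: float) -> Layer:
    # Hypothetical "layer" that multiplies every element by a constant.
    return lambda v: [factor * value for value in v]


def relu(v: Vector) -> Vector:
    # Element-wise ReLU: negative values become 0.0.
    return [max(0.0, value) for value in v]


out = forward([scale(2.0), relu], [-1.0, 0.5, 3.0])
# out == [0.0, 1.0, 6.0]
```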