eta_utility.eta_x.agents.rule_based module
- class eta_utility.eta_x.agents.rule_based.RuleBased(policy: type[BasePolicy], env: VecEnv, verbose: int = 4, _init_setup_model: bool = True, **kwargs: Any)[source]
Bases: BaseAlgorithm, ABC
The rule-based agent base class provides the facilities to easily build a complete rule-based agent. To achieve this, only the control_rules method must be implemented. It should take an observation from the environment as input and return actions as output (a sketch follows the parameter list below).
- Parameters:
policy – Agent policy. Parameter is not used in this agent and can be set to NoPolicy.
env – Environment to be controlled.
verbose – Logging verbosity.
kwargs – Additional arguments as specified in stable_baselines3.common.base_class.
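As an illustration, a minimal subclass might look like the following sketch. The observation layout (a temperature reading in the first slot) and the single on/off action are assumptions made for this example, not part of the API.

```python
import numpy as np

from eta_utility.eta_x.agents.rule_based import RuleBased


class ThermostatAgent(RuleBased):
    """Hypothetical rule-based agent: bang-bang control of a heater."""

    def control_rules(self, observation: np.ndarray) -> np.ndarray:
        # Assumption: the first observation value is a temperature in degrees Celsius.
        temperature = observation[0]
        # Switch the single actuator on below the setpoint, off otherwise.
        return np.array([1.0 if temperature < 20.0 else 0.0])
```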
- abstract control_rules(observation: ndarray) → ndarray [source]
This abstract method should be implemented to define the control rules which determine actions from the received observations.
- Parameters:
observation – Observations as provided by a single, non-vectorized environment.
- Returns:
Action values, as determined by the control rules.
- predict(observation: np.ndarray | dict[str, np.ndarray], state: tuple[np.ndarray, ...] | None = None, episode_start: np.ndarray | None = None, deterministic: bool = True) → tuple[np.ndarray, tuple[np.ndarray, ...] | None] [source]
Perform controller operations and return actions. This method takes care of the vectorization of environments and calls the control_rules method, which implements the control rules for a single environment (see the usage sketch below).
- Parameters:
observation – The input observation.
state – The last states (not used here).
episode_start – The last masks (not used here).
deterministic – Whether to return deterministic actions. This agent always returns deterministic actions.
- Returns:
Tuple of the model’s action and the next state (state is typically None in this agent).
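A hypothetical rollout loop built on the ThermostatAgent sketched above. The four-value step return follows the stable-baselines3 VecEnv convention; Pendulum-v1 merely stands in for a real controlled plant, and the import path of NoPolicy is an assumption.

```python
import gymnasium as gym
from stable_baselines3.common.vec_env import DummyVecEnv

from eta_utility.eta_x.common import NoPolicy  # assumed import path for NoPolicy

# Pendulum-v1 is only a stand-in environment for this sketch.
env = DummyVecEnv([lambda: gym.make("Pendulum-v1")])
agent = ThermostatAgent(policy=NoPolicy, env=env)

obs = env.reset()
for _ in range(100):
    # predict() vectorizes the call to control_rules for each sub-environment.
    action, _state = agent.predict(obs)
    obs, rewards, dones, infos = env.step(action)
```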
- classmethod load(path: str | pathlib.Path | io.BufferedIOBase, env: GymEnv | None = None, device: th.device | str = 'auto', custom_objects: dict[str, Any] | None = None, print_system_info: bool = False, force_reset: bool = True, _init_setup_model: bool = False, **kwargs: Any) → RuleBased [source]
Load the model from a zip-file (see the example below the parameter list).
Warning: load re-creates the model from scratch, it does not update it in-place!
- Parameters:
path – Path to the file (or a file-like object) to load the agent from.
env – The new environment to run the loaded model on (can be None if you only need prediction from a trained model). This has priority over any saved environment.
device – Device on which the code should run.
custom_objects – Dictionary of objects to replace upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item will be used instead. Similar to custom_objects in keras.models.load_model. Useful when you have an object in the file that cannot be deserialized.
print_system_info – Whether to print system info from the saved model and the current system info (useful to debug loading issues).
force_reset – Force a call to reset() before training to avoid unexpected behavior. See https://github.com/DLR-RM/stable-baselines3/issues/597.
kwargs – Extra arguments to change the model when loading.
- get_parameter_list() → None [source]
Getting TensorFlow parameters is not implemented for the rule-based agent.
- learn(total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 100, tb_log_name: str = 'run', reset_num_timesteps: bool = True, progress_bar: bool = False) → RuleBased [source]
Return a trained model. Learning is not implemented for the rule-based agent.
- Parameters:
total_timesteps – The total number of samples (env steps) to train on.
callback – Callback(s) called at every step with state of the algorithm.
log_interval – The number of timesteps before logging.
tb_log_name – The name of the run for TensorBoard logging.
reset_num_timesteps – Whether or not to reset the current timestep number (used in logging).
progress_bar – Display a progress bar using tqdm and rich.
- Returns:
The trained model.