Source code for eta_utility.eta_x.agents.nsga2

from __future__ import annotations

import time
from collections import deque
from typing import TYPE_CHECKING

import numpy as np
from attrs import define
from gymnasium import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.utils import safe_mean
from stable_baselines3.common.vec_env import VecNormalize

from eta_utility import get_logger
from eta_utility.util_julia import check_julia_package

if check_julia_package():
    from julia import Main as Jl  # noqa: I900
    from julia import ju_extensions  # noqa: I900
    from julia.ju_extensions.Agents import Nsga2 as ju_NSGA2  # noqa: I900

if TYPE_CHECKING:
    import io
    import pathlib
    from typing import Any, Callable

    import torch as th
    from julia import _jlwrapper  # noqa: I900
    from stable_baselines3.common.callbacks import BaseCallback
    from stable_baselines3.common.policies import BasePolicy
    from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
    from stable_baselines3.common.vec_env import VecEnv

Jl.eval("using PyCall")
Jl.eval("import ju_extensions.Agents.Nsga2")
log = get_logger("eta_x.agents")


@define
class _VariableParameters:
    """VariableParameters define the minimum and maximum values as well as the data type of the variables."""

    #: Data type of the variable (can be 'int' or 'float').
    dtype: str
    #: Minimum value of the variable.
    minimum: float
    #: Maximum value of the variable.
    maximum: float

    @classmethod
    def from_space(cls, space: spaces.Space) -> list[_VariableParameters]:
        """Create _VariableParameters from a gymnasium space object.

        :param space: Gymnasium space description object.
        :return: List of _VariableParameters objects (one object for each variable).
        """
        if isinstance(space, spaces.Box):
            dtype = "int" if space.dtype in {np.int32, np.int16, np.int8, np.int64} else "float"
            return [cls(dtype, minimum=space.low[key], maximum=space.high[key]) for key, _ in enumerate(space.shape)]

        elif isinstance(space, spaces.MultiDiscrete):
            return [cls("int", minimum=0, maximum=int(dim)) for dim in space.nvec]

        elif isinstance(space, spaces.MultiBinary):
            return [cls(dtype="int", minimum=0, maximum=1) for _ in range(space.n)]  # type: ignore

        elif isinstance(space, spaces.Discrete):
            return [cls(dtype="int", minimum=0, maximum=int(space.n))]

        else:
            raise ValueError("Unknown type of space for variable parameters.")
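# Illustrative sketch (not part of the module): how _VariableParameters.from_space maps a gymnasium
# space to variable parameters. The space below is a hypothetical example.
#
#     space = spaces.MultiDiscrete([15, 15, 15])
#     params = _VariableParameters.from_space(space)
#     # -> three _VariableParameters objects with dtype="int", minimum=0, maximum=15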


class Nsga2(BaseAlgorithm):
    """The NSGA2 class implements the non-dominated sorting genetic algorithm 2.

    The agent can work with discrete event systems and with continuous or mixed integer problems.
    Alternatively, a mixture of the above may be specified. The action space can specify both events
    and variables using spaces.Dict in the form::

        action_space = spaces.Dict({'events': spaces.Discrete(15),
                                    'variables': spaces.MultiDiscrete([15] * 3)})

    This specifies 15 events and an additional 3 variables. The variables will be integers and have an
    upper boundary value of 15. Other spaces (except Tuple and Dict) can be defined for the variables.
    Events only accept a Discrete space as input.

    When events are specified, a list of ordered values will be returned that should achieve a near
    optimal reward. For variables, the values will be adjusted to achieve the highest reward. Upper and
    lower boundaries as well as types will be inferred from the space.

    .. note::
        This agent does not use the observation space. Instead it only relies on rewards returned by
        the environment. Returned rewards can be tuples if multi-objective optimization is required.
        Existing environments do not have to be adjusted, however. The agent will also accept standard
        rewards and will ignore any observation spaces.

    .. note::
        The number of environments must be equal to the population for this agent because it needs one
        environment for the evaluation of every solution. This allows solutions to be evaluated in
        parallel.

    :param policy: Agent policy. Parameter is not used in this agent.
    :param env: Environment to be optimized.
    :param learning_rate: Reduction factor for the crossover and mutation rates (default: 1).
    :param verbose: Logging verbosity.
    :param population: Maximum number of parallel solutions (>= 2).
    :param mutations: Chance for mutations in existing solutions (between 0 and 1).
    :param crossovers: Chance for crossovers between solutions (between 0 and 1).
    :param n_generations: Number of generations to run the algorithm for.
    :param max_cross_len: Maximum number of genes (as a proportion of total elements) to cross over
                          between solutions (between 0 and 1) (default: 1).
    :param max_retries: Maximum number of tries to find new values before the algorithm fails and
                        returns. Using the default should usually be fine (default: 10000).
    :param sense: Determine whether the algorithm looks for minimal ("minimize") or maximal ("maximize")
                  rewards (default: "minimize").
    :param tensorboard_log: The log location for tensorboard (if None, no logging).
    :param seed: Seed for the pseudo random generators.
    :param predict_learn_steps: Number of learning steps to perform in the predict function (default: 5).
    :param _init_setup_model: Determine whether the model should be initialized during setup.
    """

    def __init__(
        self,
        policy: type[BasePolicy],
        env: VecEnv,
        learning_rate: float | Schedule = 1.0,
        verbose: int = 2,
        *,
        population: int = 100,
        mutations: float = 0.05,
        crossovers: float = 0.1,
        n_generations: int = 100,
        max_cross_len: float = 1,
        max_retries: int = 100000,
        sense: str = "minimize",
        predict_learn_steps: int = 5,
        seed: int = 42,
        tensorboard_log: str | None = None,
        _init_setup_model: bool = True,
        **kwargs: Any,
    ) -> None:
        # Some types are incorrectly defined in the super class; this fixes it for this class and suppresses warnings
        self.start_time: float | None  # type: ignore
        self.lr_schedule: Callable
        self.policy_class: type[BasePolicy]
        self.policy: BasePolicy

        # Set default values for superclass arguments
        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            seed=seed,
            support_multi_env=True,
            supported_action_spaces=(
                spaces.Box,
                spaces.Discrete,
                spaces.MultiDiscrete,
                spaces.MultiBinary,
                spaces.Dict,
            ),
            **kwargs,
        )

        if self.observation_space is not None and self.action_space is not None:
            self.policy = self.policy_class(self.observation_space, self.action_space, **self.policy_kwargs)

        log.setLevel(int(verbose * 10))
        ju_extensions.set_logger(log.level)

        if self.env is None:
            raise ValueError("The NSGA2 agent needs a specific environment to work correctly. Cannot use env = None.")
        if isinstance(self.env, VecNormalize):
            raise TypeError("The NSGA2 agent does not allow the use of normalized environments.")
        if sense not in {"minimize", "maximize"}:
            raise ValueError(f"The optimization sense must be one of 'minimize' or 'maximize', got {sense}.")

        #: Maximum number of parallel solutions (>= 2).
        self.population: int = population
        #: Chance for mutations in existing solutions (between 0 and 1).
        self.mutations: float = mutations
        #: Chance for crossovers between solutions (between 0 and 1).
        self.crossovers: float = crossovers
        #: Maximum number of genes (as a proportion of total elements) to cross over between solutions
        #: (between 0 and 1) (default: 1).
        self.max_cross_len: float = max_cross_len
        #: Maximum number of tries to find new values before the algorithm fails and returns.
        #: Using the default should usually be fine (default: 10000).
        self.max_retries: int = max_retries
        #: Sense of the optimization (maximize or minimize).
        self.sense: str = sense
        #: Maximum number of generations to run for.
        self.n_generations: int = n_generations
        #: Maximum value of the reward (positive or negative infinity, depending on the optimization sense).
        self._max_value = np.inf if sense == "minimize" else -np.inf

        #: Parameters defining how the events chromosome is generated. This is determined
        #: automatically from the action space.
        self.event_params: int = 0
        #: Parameters defining how the variables chromosome is generated. This is determined
        #: automatically from the action space.
        self.variable_params: list[_VariableParameters] = []
        #: Parent generation of solutions.
        self.generation_parent: _jlwrapper = []
        #: Offspring generation of solutions.
        self.generation_offspr: _jlwrapper = []
        #: Current learning rate of the algorithm.
        self._current_learning_rate: float = 1.0
        #: Number of solutions which have been seen before (avoids duplicate evaluation of equivalent solutions).
        self.seen_solutions: int = 0
        #: Total number of retries needed during evolution to generate unique solutions.
        self.total_retries: int = 0
        #: List of current minimal values for all parts of the reward.
        self.current_minima: np.ndarray = np.full(1, self._max_value, dtype=np.float64, order="F")
        #: Buffer for actions.
        self.ep_actions_buffer: deque = deque(maxlen=100)
        #: Buffer for rewards.
        self.ep_reward_buffer: deque = deque(maxlen=100)
        #: Sorted sets of solutions.
        self._fronts: deque = deque(maxlen=100)
        #: Number of solutions in each front.
        self._front_lengths: deque = deque(maxlen=100)
        #: Buffer for training infos.
        self.training_infos_buffer: dict = {}
        #: Number of learning steps for the predict function.
        self.predict_learn_steps: int = predict_learn_steps

        self._setup_lr_schedule()
        self._update_learning_rate()

        # Initialize and parametrize the julia functions.
        self.__jl_agent: _jlwrapper
        self._jl_Algorithm = Jl.eval(
            "pyfunctionret("
            "Nsga2.Algorithm, Any, Int, Float64, Float64, Int, Int, Int, "
            "Nsga2.VariableParameters, Float64, String, UInt64"
            ")"
        )
        self._jl_create_generation = Jl.eval("pyfunctionret(Nsga2.create_generation, Any, PyAny, Bool)")
        self._jl_create_offspring = Jl.eval("pyfunctionret(Nsga2.create_offspring, Any, PyAny, PyAny)")
        self._jl_initialize_rnd = Jl.eval("pyfunctionret(Nsga2.initialize_rnd!, Int, PyAny, PyAny)")
        self._jl_reinitialize_rnd = Jl.eval("pyfunctionret(Nsga2.initialize_rnd!, Int, PyAny, PyAny, Vector{Int})")
        self._jl_evolve = Jl.eval("pyfunctionret(Nsga2.evolve!, Int, PyAny, PyAny, PyAny, Float64)")
        self._jl_evaluate_solutions = Jl.eval(
            "pyfunctionret("
            " Nsga2.evaluate!,"
            " Tuple{Vector{Float64}, Int, Vector{Int}, Vector{Int}},"
            " PyAny,"
            " PyAny,"
            " PyAny,"
            " PyAny"
            ")"
        )
        self._jl_store_reward = Jl.eval("pyfunctionret(Nsga2.py_store_reward, nothing, PyAny, PyArray)")
        self._jl_get_actions = Jl.eval("pyfunctionret(Nsga2.py_actions, PyObject, PyAny)")
        self._jl_setup_generation = Jl.eval("pyfunctionret(Nsga2.load_generation, Any, PyArray, PyArray, Float64)")

        if _init_setup_model:
            self._setup_model()

        log.info(
            f"Agent initialized with parameters population: {self.population}, mutations: {self.mutations}, "
            f"crossovers: {self.crossovers}."
        )
        self._check_learn_config()

    @property
    def last_evaluation_actions(self) -> np.ndarray | None:
        if len(self.ep_actions_buffer) >= 1:
            return self.ep_actions_buffer[-1]
        else:
            return None

    @property
    def last_evaluation_rewards(self) -> Any | None:
        if len(self.ep_reward_buffer) >= 1:
            return self.ep_reward_buffer[-1]
        else:
            return None

    @property
    def last_evaluation_fronts(self) -> list:
        fronts = []
        beginning = 0
        for frontend in self._front_lengths[-1]:
            fronts.append([s - 1 for s in self._fronts[-1][beginning:frontend]])
            beginning = frontend
        return fronts

    def _event_and_variable_params(self) -> tuple[int, list[_VariableParameters]]:
        """Read event parameters and variable parameters from the action space.

        :return: Tuple of the events and variable configurations.
        """
        event_params = 0
        variable_space = None

        # If the type of the action space is spaces.Dict, it could contain events as well as variables.
        # Other spaces only contain variables.
        if isinstance(self.action_space, spaces.Dict):
            # Extract events space
            if "events" in self.action_space.spaces:
                if not isinstance(self.action_space.spaces["events"], spaces.Discrete):
                    raise ValueError(
                        f"Events must be specified as a discrete space. Received {type(self.action_space['events'])}."
                    )
                event_params = self.action_space.spaces["events"].n  # type: ignore

            # Extract variables spaces.
if "variables" in self.action_space.spaces: variable_space = self.action_space.spaces["variables"] else: variable_space = self.action_space # Set up the variable parameters by creating VariableParameters objects. variable_params = [] if variable_space is not None: variable_params = _VariableParameters.from_space(variable_space) log.debug( f"Successfully read action space information. " f"Length of events: {event_params}, length of variables: {len(variable_params)}." ) return event_params, variable_params def _setup_jl_agent(self) -> None: self.__jl_agent = self._jl_Algorithm( self.population, self.mutations, self.crossovers, self.max_cross_len, self.max_retries, self.event_params, self.variable_params, self._max_value, self.sense, self.seed, ) def _setup_model(self) -> None: """Set up the model by taking values from the supplied action space and initializing the first two parent generations. """ log.debug("Starting agent initialization.") # Set up learning rate and random seeding for all submodules. self._setup_lr_schedule() # Read the event and variable parameters from the action space. self.event_params, self.variable_params = self._event_and_variable_params() self._setup_jl_agent() self.set_random_seed(self.seed) log.debug("Successfully initialized NSGA 2 agent.") def _update_learning_rate(self, optmimizers: list[th.optim.Optimizer] | th.optim.Optimizer | None = None) -> None: """Update the learning rate as well as mutation and crossover rates. The mutation and crossover rates depend on the learning rate. Thus, a learning rate schedule will affect the crossover and mutation probabilities for each generation. :param optimizers: List of torch optimizers (not used by Nsga2). """ self._current_learning_rate = self.lr_schedule(self._current_progress_remaining) def _check_learn_config(self) -> None: # Check configuration of the algorithm for compatibility if self.population is None or self.population < 2: raise ValueError("The population size must be at least two.") if self.mutations is None or (not 0 <= self.mutations < 1): raise ValueError("The mutation rate must be between 0 and 1.") if self.crossovers is None or (not 0 <= self.crossovers < 0.5): raise ValueError( "The crossover rate must be between 0 and 0.5 (cannot cross more than half of population)." ) if not 0 <= self.max_cross_len <= 1: raise ValueError("The maximum crossover length must be between 0 and 1 (proportion of total length).")
    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        tb_log_name: str = "run",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> Nsga2:
        """Return a trained model.

        The environment which the agent is training on should return an info dictionary when a solution is
        invalid. The info dictionary should contain a 'valid' key which is set to False in that case. If there
        are too many invalid solutions (more than half of the population), the agent will try to re-initialize
        these solutions until there is a sufficient number of valid solutions.

        :param total_timesteps: The total number of samples (env steps) to train on.
        :param callback: Callback(s) called at every step with state of the algorithm.
        :param log_interval: The number of timesteps before logging.
        :param tb_log_name: The name of the run for TensorBoard logging.
        :param reset_num_timesteps: Whether to reset the current timestep number (used in logging).
        :param progress_bar: Parameter to show a progress bar, used by stable_baselines (currently unused!).
        :return: The trained model.
        """
        if self.n_generations is not None and total_timesteps > self.n_generations:
            total_timesteps = self.n_generations

        total_timesteps, callback = self._setup_learn(
            total_timesteps,
            callback,
            reset_num_timesteps,
            tb_log_name,
            progress_bar,
        )

        # Reset training results infos
        self._set_training_infos(iteration=0)
        self._reset_training_infos()

        callback.on_training_start(locals(), globals())

        log.info(
            f"Starting optimization for {total_timesteps} generations with parameters: "
            f"crossover rate: {self.crossovers}, mutation rate: {self.mutations}, population: {self.population}."
        )

        # Initialize the parent generation in case it is empty (usually when the algorithm is first initialized)
        self._initialize_parent_generation_if_empty()

        # Train agent
        self._train(total_timesteps, callback, log_interval)

        return self
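    # Illustrative sketch (assumption, not part of the module): an environment can mark a solution as
    # invalid by returning info dictionaries with a 'valid' key from step(). Invalid solutions receive the
    # worst possible reward and are re-initialized when more than half of the population is invalid
    # (see _evaluate below). 'solution_failed' is a hypothetical helper.
    #
    #     def step(self, actions):
    #         ...
    #         infos = [{"valid": False} if solution_failed(i) else {} for i in range(self.num_envs)]
    #         return observations, rewards, terminated, truncated, infos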
    def _train(
        self,
        total_timesteps: int,
        callback: BaseCallback,
        log_interval: int = 1,
    ) -> None:
        """Train the agent for the given number of timesteps.

        :param total_timesteps: The total number of samples (env steps) to train on.
        :param callback: Callback(s) called at every step with state of the algorithm.
        :param log_interval: The number of timesteps before logging.
        """
        # Enter time step loop (each loop is one generation of solutions)
        while self.num_timesteps < total_timesteps:
            self.num_timesteps += 1
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

            # Display training infos
            if log_interval is not None and self.training_infos_buffer["iteration"] % log_interval == 0:
                self._display_training_infos()

            self._set_training_infos(iteration_time=time.time())
            self._update_learning_rate()

            # Create empty offspring generation
            log.debug("Initializing offspring generation and performing evolution.")
            self.generation_offspr = self._jl_create_offspring(self.__jl_agent, self.generation_parent)
            self.training_infos_buffer["retries"] = self._jl_evolve(
                self.__jl_agent, self.generation_offspr, self.generation_parent, self._current_learning_rate
            )
            self._set_training_infos(evolve_time=time.time())
            self.total_retries += self.training_infos_buffer["retries"]

            log.debug("Evaluating offspring generation.")
            self.generation_offspr, self.training_infos_buffer["retries"] = self._evaluate(self.generation_offspr)
            self.total_retries += self.training_infos_buffer["retries"]
            self._set_training_infos(eval_time=time.time())

            log.debug("Performing non-dominated sort with parent and offspring.")
            new_generation_parent = self._jl_create_generation(self.__jl_agent, True)
            self.current_minima, self.seen_solutions, fronts, front_lengths = self._jl_evaluate_solutions(
                self.__jl_agent, self.generation_offspr, self.generation_parent, new_generation_parent
            )
            self._fronts.append(fronts)
            self._front_lengths.append(front_lengths)

            self.ep_reward_buffer.append(
                np.vstack(
                    ([sol.reward for sol in self.generation_offspr], [sol.reward for sol in self.generation_parent])
                )
            )
            self.ep_actions_buffer.append(
                np.hstack((self._jl_get_actions(self.generation_offspr), self._jl_get_actions(self.generation_parent)))
            )
            self.generation_parent = new_generation_parent
            self._set_training_infos(sorting_time=time.time())

            log.debug(f"Successfully created and evaluated offspring generation with {self.population} solutions.")
            self.training_infos_buffer["iteration"] += 1

            if not callback.on_step():
                break

        callback.on_training_end()

    def _initialize_parent_generation_if_empty(self) -> None:
        """Initialize the parent generation if generation_parent is empty."""
        if Jl.length(self.generation_parent) == 0:
            log.debug("Initializing parent generation.")
            self.generation_parent = self._jl_create_generation(self.__jl_agent, False)
            self.training_infos_buffer["retries"] = self._jl_initialize_rnd(self.__jl_agent, self.generation_parent)

            # Update training infos
            self._set_training_infos(evolve_time=time.time())
            self.total_retries += self.training_infos_buffer["retries"]

            log.debug("Evaluating parent generation.")
            self.generation_parent, self.training_infos_buffer["invalid_sol"] = self._evaluate(self.generation_parent)

            # Update training infos
            self._set_training_infos(eval_time=time.time(), sorting_time=time.time())  # No sorting during first step.

            log.info(f"Successfully initialized first parent generation with {self.population} solutions.")

    def _display_training_infos(self) -> None:
        """
        Display training infos.
""" assert self.ep_info_buffer is not None, "Make sure that ep_info_buffer is exists before starting to learn." assert self.start_time is not None, "Make sure that start_time is set before starting to learn." self.logger.record("time/iterations", self.training_infos_buffer["iteration"], exclude="tensorboard") if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: self.logger.record("general/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) self.logger.record("time/total", time.time() - self.start_time / 1000000000, exclude="tensorboard") self.logger.record("time/iteration", time.time() - self.training_infos_buffer["iteration_time"]) self.logger.record( "time/evolve", self.training_infos_buffer["evolve_time"] - self.training_infos_buffer["iteration_time"] ) self.logger.record( "time/evaluate", self.training_infos_buffer["eval_time"] - self.training_infos_buffer["evolve_time"] ) self.logger.record( "time/sort", self.training_infos_buffer["sorting_time"] - self.training_infos_buffer["eval_time"] ) self.logger.record("train/retries", self.training_infos_buffer["retries"]) self.logger.record("train/total_retries", self.total_retries) self.logger.record("train/learning_rate", self._current_learning_rate) self.logger.record("train/mutation_rate", self.mutations * self._current_learning_rate) self.logger.record("train/crossover_rate", self.crossovers * self._current_learning_rate) self.logger.record("train/seensolutions", self.seen_solutions) self.logger.record("evaluate/invalid", self.training_infos_buffer["invalid_sol"]) for idx, val in enumerate(self.current_minima): self.logger.record(f"evaluate/minimum_{idx}", val) self.logger.dump(step=self.num_timesteps) def _set_training_infos(self, **kwargs: Any) -> None: """Update the training infos buffer with the given values.""" self.training_infos_buffer.update(kwargs) def _reset_training_infos(self) -> None: """Reset training infos.""" self._set_training_infos( iteration_time=time.time(), evolve_time=time.time(), eval_time=time.time(), sorting_time=time.time(), retries=0, invalid_sol=0, ) def _evaluate(self, generation: _jlwrapper) -> tuple[_jlwrapper, int]: """Evaluate all solutions in the generation and store rewards :param generation: Sequence of solutions to evaluate :return: Sequence of evaluated solutions """ assert self.env is not None, "The agent needs to know the environment to evaluate solutions." rewards = np.array([]) retries = 0 infos: list[dict[str, Any]] = [] while retries < self.max_retries: _observations, rewards, terminated, truncated, infos = self.env.step( self._jl_get_actions(generation) ) # type: ignore dones = terminated | truncated self._update_info_buffer(infos, dones) # Ensure that there are always multiple rewards for every solution. if len(rewards.shape) == 1: rewards = np.reshape(rewards, (len(rewards), 1), order="F") solution_invalid = [] for idx, _ in enumerate(rewards): if "valid" in infos[idx] and infos[idx]["valid"] is False: rewards[idx] = np.full((len(rewards[idx]),), self._max_value, dtype=np.float64) solution_invalid.append(idx + 1) if len(solution_invalid) < self.population / 2: break else: retries += len(solution_invalid) retries += self._jl_reinitialize_rnd(self.__jl_agent, generation, solution_invalid) log.info( f"Randomized the generation again because " f"there were too many invalid solutions: {len(solution_invalid)}; retries: {retries}" ) assert rewards is not None self._jl_store_reward(generation, rewards) return generation, retries
    def set_random_seed(self, seed: int | None = None) -> None:
        """Set the seed of the pseudo-random generators (python, numpy, pytorch, gymnasium, julia).

        :param seed: Seed for the pseudo random generators.
        """
        if seed is None:
            return

        ju_NSGA2.seed_b(self.__jl_agent, seed)
    def _excluded_save_params(self) -> list[str]:
        """Return the names of the parameters that should be excluded from being saved by pickling.

        :return: List of parameters that should be excluded from being saved with pickle.
        """
        excluded_params = super()._excluded_save_params()
        excluded_params.extend(
            [
                "_Nsga2__jl_agent",
                "_jl_Algorithm",
                "_jl_create_generation",
                "_jl_create_offspring",
                "_jl_initialize_rnd",
                "_jl_reinitialize_rnd",
                "_jl_evolve",
                "_jl_evaluate_solutions",
                "_jl_store_reward",
                "_jl_get_actions",
                "_jl_setup_generation",
                "generation_parent",
                "generation_offspr",
            ]
        )
        return excluded_params
    @classmethod
    def load(
        cls: type[Nsga2],
        path: str | pathlib.Path | io.BufferedIOBase,
        env: GymEnv | None = None,
        device: th.device | str = "auto",
        custom_objects: dict[str, Any] | None = None,
        print_system_info: bool = False,
        force_reset: bool = True,
        **kwargs: Any,
    ) -> Nsga2:
        """Load the model from a zip-file.

        Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
        For an in-place load use ``set_parameters`` instead.

        :param path: Path to the file (or a file-like) where to load the agent from.
        :param env: The new environment to run the loaded model on (can be None if you only need prediction
                    from a trained model). Has priority over any saved environment.
        :param device: Device on which the code should run.
        :param custom_objects: Dictionary of objects to replace upon loading. If a variable is present in this
                               dictionary as a key, it will not be deserialized and the corresponding item will
                               be used instead.
        :param print_system_info: Whether to print system info from the saved model and the current system info
                                  (useful to debug loading issues).
        :param force_reset: Force call to ``reset()`` before training to avoid unexpected behavior.
        :param kwargs: Extra arguments to change the model when loading.
        :return: New model instance with loaded parameters.
        """
        model: Nsga2 = super().load(path, env, device, custom_objects, print_system_info, force_reset, **kwargs)
        log.setLevel(int(model.verbose * 10))
        model._setup_jl_agent()
        model._load_generation()
        return model
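    # Illustrative sketch (assumption, not part of the module): saving and reloading the agent. The Julia
    # callables are excluded from pickling (see _excluded_save_params) and the generations are rebuilt on
    # load from the buffered actions of the last evaluation. 'model.zip' and 'env' are hypothetical
    # placeholders.
    #
    #     agent.save("model.zip")
    #     agent = Nsga2.load("model.zip", env=env)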
    def _load_generation(self) -> None:
        self.generation_offspr = self._jl_setup_generation(
            self.ep_actions_buffer[-1][: self.population]["events"],
            self.ep_actions_buffer[-1][: self.population]["variables"],
            self._max_value,
        )
        self._evaluate(self.generation_offspr)

        self.generation_parent = self._jl_setup_generation(
            self.ep_actions_buffer[-1][self.population :]["events"],
            self.ep_actions_buffer[-1][self.population :]["variables"],
            self._max_value,
        )
        self._evaluate(self.generation_parent)
    def predict(
        self,
        observation: np.ndarray | dict[str, np.ndarray],
        state: tuple[np.ndarray, ...] | None = None,
        episode_start: np.ndarray | None = None,
        deterministic: bool = False,
    ) -> tuple[np.ndarray, tuple[np.ndarray, ...] | None]:
        """Return actions from the best solution.

        :param observation: Observation from the environment.
        :param state: State from the environment. Not relevant here.
        :param episode_start: Whether the episode has just started. Not relevant here.
        :param deterministic: Whether to use deterministic actions. Not relevant here.
        :return: Actions from the best solution.
        """
        # Reset training infos
        self._reset_training_infos()

        # Set crossover to zero
        ju_NSGA2.updateAlgorithmParameters_b(self.__jl_agent, 0.0)

        # Set up learning
        total_timesteps, callback = self._setup_learn(
            total_timesteps=self.predict_learn_steps,
            callback=None,
            reset_num_timesteps=False,
            tb_log_name="predict",
            progress_bar=False,
        )

        # Train from the parent generation
        self._train(total_timesteps, callback)

        # Select the first solution of the first front
        best_solution = self.ep_actions_buffer[-1][self._fronts[-1][: self._front_lengths[-1][0]]][0]
        return best_solution, None  # no states
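# Illustrative usage sketch (assumption, not part of the module): the agent needs one environment per
# solution, so the number of vectorized environments must equal the population. 'MyEventEnv', the policy
# argument and 'make_vec_env' usage are hypothetical placeholders, not taken from this package.
#
#     from stable_baselines3.common.env_util import make_vec_env
#
#     env = make_vec_env(MyEventEnv, n_envs=100)                # n_envs must equal population
#     agent = Nsga2(MyPolicy, env, population=100, mutations=0.05, crossovers=0.1)
#     agent.learn(total_timesteps=100)                          # capped at n_generations
#     actions, _ = agent.predict(observation=np.empty(0))       # actions of the best solution found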