2 changes: 2 additions & 0 deletions train.py
@@ -64,6 +64,7 @@
type=int,
default=500,
)
parser.add_argument("--n-models", help="Number of models for optimizing hyperparameters.", type=int, default=1)
parser.add_argument(
"-optimize", "--optimize-hyperparameters", action="store_true", default=False, help="Run hyperparameters search"
)
@@ -201,6 +202,7 @@
args.storage,
args.study_name,
args.n_trials,
args.n_models,
args.n_jobs,
args.sampler,
args.pruner,
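Assuming the zoo's usual --algo and --env flags, and a --pruner option that accepts "none" (which the assertion added below requires when training several models per trial), a search that trains three models per trial and keeps the median of their rewards could be launched with something like:

python train.py --algo ppo --env CartPole-v1 -optimize --n-trials 100 --n-models 3 --pruner none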
116 changes: 65 additions & 51 deletions utils/exp_manager.py
@@ -73,6 +73,7 @@ def __init__(
storage: Optional[str] = None,
study_name: Optional[str] = None,
n_trials: int = 1,
n_models: int = 1,
n_jobs: int = 1,
sampler: str = "tpe",
pruner: str = "median",
@@ -133,10 +134,13 @@ def __init__(
self.no_optim_plots = no_optim_plots
# maximum number of trials for finding the best hyperparams
self.n_trials = n_trials
        # number of models trained per trial; the reported trial value is the median of their scores
self.n_models = n_models
# number of parallel jobs when doing hyperparameter search
self.n_jobs = n_jobs
self.sampler = sampler
self.pruner = pruner
assert not (self.n_models > 1 and self.pruner != "none"), "Pruner is not currently supported for multiple models"
self.n_startup_trials = n_startup_trials
self.n_evaluations = n_evaluations
self.deterministic_eval = not self.is_atari(self.env_id)
@@ -649,15 +653,18 @@ def objective(self, trial: optuna.Trial) -> float:
if self.verbose >= 2:
trial_verbosity = self.verbose

model = ALGOS[self.algo](
env=env,
tensorboard_log=None,
# We do not seed the trial
seed=None,
verbose=trial_verbosity,
device=self.device,
**kwargs,
)
models = [
ALGOS[self.algo](
env=env,
tensorboard_log=None,
# We do not seed the trial
seed=None,
verbose=trial_verbosity if model_idx == 0 else 0,
device=self.device,
**kwargs,
)
for model_idx in range(self.n_models)
]

eval_env = self.create_envs(n_envs=self.n_eval_envs, eval_env=True)

@@ -668,51 +675,58 @@ def objective(self, trial: optuna.Trial) -> float:
path = None
if self.optimization_log_path is not None:
path = os.path.join(self.optimization_log_path, f"trial_{str(trial.number)}")
callbacks = get_callback_list({"callback": self.specified_callbacks})
eval_callback = TrialEvalCallback(
eval_env,
trial,
best_model_save_path=path,
log_path=path,
n_eval_episodes=self.n_eval_episodes,
eval_freq=optuna_eval_freq,
deterministic=self.deterministic_eval,
)
callbacks.append(eval_callback)

learn_kwargs = {}
# Special case for ARS
if self.algo == "ars" and self.n_envs > 1:
learn_kwargs["async_eval"] = AsyncEval(
[lambda: self.create_envs(n_envs=1, no_log=True) for _ in range(self.n_envs)], model.policy
rewards = np.zeros(self.n_models)
for model_idx, model in enumerate(models):
callbacks = get_callback_list({"callback": self.specified_callbacks})
eval_callback = TrialEvalCallback(
eval_env,
trial,
best_model_save_path=path,
log_path=path,
n_eval_episodes=self.n_eval_episodes,
eval_freq=optuna_eval_freq,
deterministic=self.deterministic_eval,
)
callbacks.append(eval_callback)

try:
model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)
# Free memory
model.env.close()
eval_env.close()
except (AssertionError, ValueError) as e:
# Sometimes, random hyperparams can generate NaN
# Free memory
model.env.close()
eval_env.close()
# Prune hyperparams that generate NaNs
print(e)
print("============")
print("Sampled hyperparams:")
pprint(sampled_hyperparams)
raise optuna.exceptions.TrialPruned()
is_pruned = eval_callback.is_pruned
reward = eval_callback.last_mean_reward

del model.env, eval_env
del model

if is_pruned:
raise optuna.exceptions.TrialPruned()

return reward
learn_kwargs = {}
# Special case for ARS
if self.algo == "ars" and self.n_envs > 1:
learn_kwargs["async_eval"] = AsyncEval(
[lambda: self.create_envs(n_envs=1, no_log=True) for _ in range(self.n_envs)], model.policy
)

            try:
                model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)
                # Free memory (the training env is shared by all models, so only close it after the last one)
                if model_idx == self.n_models - 1:
                    model.env.close()
except (AssertionError, ValueError) as e:
# Sometimes, random hyperparams can generate NaN
# Free memory
model.env.close()
eval_env.close()
# Prune hyperparams that generate NaNs
print(e)
print("============")
print("Sampled hyperparams:")
pprint(sampled_hyperparams)
raise optuna.exceptions.TrialPruned()
is_pruned = eval_callback.is_pruned
rewards[model_idx] = eval_callback.last_mean_reward

del model.env
del model

if is_pruned:
eval_env.close()
del eval_env
raise optuna.exceptions.TrialPruned()

eval_env.close()
del eval_env

return np.median(rewards)

def hyperparameters_optimization(self) -> None:

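For context, here is a minimal, self-contained sketch (not the zoo's code) of the median-of-n-models idea the new objective implements: train several models per trial with the same sampled hyperparameters and report the median of their scores, which is less sensitive to a single lucky or unlucky seed. The train_and_evaluate function is a hypothetical stand-in for model.learn() plus evaluation, and the single learning-rate hyperparameter is purely illustrative.

import numpy as np
import optuna


def train_and_evaluate(learning_rate: float, rng: np.random.Generator) -> float:
    # Hypothetical stand-in for training one model and returning its mean evaluation reward.
    noise = rng.normal(scale=50.0)
    return -((np.log10(learning_rate) + 3.0) ** 2) * 100.0 + noise


def objective(trial: optuna.Trial, n_models: int = 3) -> float:
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    rng = np.random.default_rng(trial.number)
    # Train n_models independent models with the same sampled hyperparameters.
    rewards = np.zeros(n_models)
    for model_idx in range(n_models):
        rewards[model_idx] = train_and_evaluate(learning_rate, rng)
    # The median reduces the influence of a single outlier run.
    return float(np.median(rewards))


if __name__ == "__main__":
    # Pruning is disabled, mirroring the assertion added in ExperimentManager.
    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.NopPruner())
    study.optimize(lambda trial: objective(trial, n_models=3), n_trials=20)
    print(study.best_params)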