From a6d9aa4e7baa4837a191d780a30adee03bfd750e Mon Sep 17 00:00:00 2001 From: Dante Camarena Date: Fri, 8 Apr 2022 13:16:01 -0400 Subject: [PATCH 1/2] Fix string typing on trainer_factory.py --- ml-agents/mlagents/trainers/trainer/trainer_factory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml-agents/mlagents/trainers/trainer/trainer_factory.py b/ml-agents/mlagents/trainers/trainer/trainer_factory.py index 90f1aabef07..65cd610c55d 100644 --- a/ml-agents/mlagents/trainers/trainer/trainer_factory.py +++ b/ml-agents/mlagents/trainers/trainer/trainer_factory.py @@ -1,5 +1,5 @@ import os -from typing import Dict +from typing import Dict, Optional from mlagents_envs.logging_util import get_logger from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager @@ -25,7 +25,7 @@ def __init__( load_model: bool, seed: int, param_manager: EnvironmentParameterManager, - init_path: str = None, + init_path: Optional[str] = None, multi_gpu: bool = False, ): """ From adf05b5d3640a71377852bf505aac8b22202af7b Mon Sep 17 00:00:00 2001 From: Dante Camarena Date: Fri, 8 Apr 2022 13:17:42 -0400 Subject: [PATCH 2/2] Fix typing on directory_utils.py --- ml-agents/mlagents/trainers/directory_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/directory_utils.py b/ml-agents/mlagents/trainers/directory_utils.py index 80379d81e99..5e979fab93b 100644 --- a/ml-agents/mlagents/trainers/directory_utils.py +++ b/ml-agents/mlagents/trainers/directory_utils.py @@ -1,11 +1,12 @@ import os +from typing import Optional from mlagents.trainers.exception import UnityTrainerException from mlagents.trainers.settings import TrainerSettings from mlagents.trainers.model_saver.torch_model_saver import DEFAULT_CHECKPOINT_NAME def validate_existing_directories( - output_path: str, resume: bool, force: bool, init_path: str = None + output_path: str, resume: bool, force: bool, init_path: Optional[str] = None ) -> None: """ Validates that if the run_id model exists, we do not overwrite it unless --force is specified.