diff --git a/ml-agents/mlagents/trainers/directory_utils.py b/ml-agents/mlagents/trainers/directory_utils.py index 80379d81e99..5e979fab93b 100644 --- a/ml-agents/mlagents/trainers/directory_utils.py +++ b/ml-agents/mlagents/trainers/directory_utils.py @@ -1,11 +1,12 @@ import os +from typing import Optional from mlagents.trainers.exception import UnityTrainerException from mlagents.trainers.settings import TrainerSettings from mlagents.trainers.model_saver.torch_model_saver import DEFAULT_CHECKPOINT_NAME def validate_existing_directories( - output_path: str, resume: bool, force: bool, init_path: str = None + output_path: str, resume: bool, force: bool, init_path: Optional[str] = None ) -> None: """ Validates that if the run_id model exists, we do not overwrite it unless --force is specified. diff --git a/ml-agents/mlagents/trainers/trainer/trainer_factory.py b/ml-agents/mlagents/trainers/trainer/trainer_factory.py index 90f1aabef07..65cd610c55d 100644 --- a/ml-agents/mlagents/trainers/trainer/trainer_factory.py +++ b/ml-agents/mlagents/trainers/trainer/trainer_factory.py @@ -1,5 +1,5 @@ import os -from typing import Dict +from typing import Dict, Optional from mlagents_envs.logging_util import get_logger from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager @@ -25,7 +25,7 @@ def __init__( load_model: bool, seed: int, param_manager: EnvironmentParameterManager, - init_path: str = None, + init_path: Optional[str] = None, multi_gpu: bool = False, ): """