diff --git a/src/tasks.py b/src/tasks.py index 29284e5..3006d90 100644 --- a/src/tasks.py +++ b/src/tasks.py @@ -42,8 +42,7 @@ def train(config: Dict[str, Any], """ seed_everything(seed, workers=True) - if not os.path.exists(out_dir): - os.makedirs(out_dir) + os.makedirs(out_dir, exist_ok=True) logger, data, trainer_kwargs, model, callbacks = config["log"], \ config["data"], \ @@ -84,8 +83,7 @@ def finetune(config: Dict[str, Any], """ seed_everything(seed, workers=True) - if not os.path.exists(out_dir): - os.makedirs(out_dir) + os.makedirs(out_dir, exist_ok=True) logger, data, trainer_kwargs, model, callbacks = config["log"], \ config["data"], \ @@ -142,8 +140,7 @@ def evaluate( Note: Evaluation must be run in a single run as resuming the trainer state is not supported for prediction. """ seed_everything(seed, workers=True) - if not os.path.exists(out_dir): - os.makedirs(out_dir) + os.makedirs(out_dir, exist_ok=True) logger, data, trainer_kwargs, model, callbacks = config["log"], \ config["data"], \