nntool 1.3.0__cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nntool might be problematic. Click here for more details.

Files changed (76) hide show
  1. nntool/__init__.py +2 -0
  2. nntool/__pycache__/__init__.cpython-39.opt-1.pyc +0 -0
  3. nntool/__pycache__/__init__.cpython-39.pyc +0 -0
  4. nntool/experiment/__init__.py +6 -0
  5. nntool/experiment/__pycache__/__init__.cpython-39.opt-1.pyc +0 -0
  6. nntool/experiment/__pycache__/__init__.cpython-39.pyc +0 -0
  7. nntool/experiment/__pycache__/config.cpython-39.opt-1.pyc +0 -0
  8. nntool/experiment/__pycache__/config.cpython-39.pyc +0 -0
  9. nntool/experiment/__pycache__/utils.cpython-39.opt-1.pyc +0 -0
  10. nntool/experiment/__pycache__/utils.cpython-39.pyc +0 -0
  11. nntool/experiment/config.py +112 -0
  12. nntool/experiment/utils.py +63 -0
  13. nntool/parser/__init__.py +1 -0
  14. nntool/parser/__pycache__/__init__.cpython-39.opt-1.pyc +0 -0
  15. nntool/parser/__pycache__/__init__.cpython-39.pyc +0 -0
  16. nntool/parser/__pycache__/parse.cpython-39.opt-1.pyc +0 -0
  17. nntool/parser/__pycache__/parse.cpython-39.pyc +0 -0
  18. nntool/parser/parse.py +22 -0
  19. nntool/plot/__init__.py +6 -0
  20. nntool/plot/__pycache__/__init__.cpython-39.opt-1.pyc +0 -0
  21. nntool/plot/__pycache__/__init__.cpython-39.pyc +0 -0
  22. nntool/plot/__pycache__/context.cpython-39.opt-1.pyc +0 -0
  23. nntool/plot/__pycache__/context.cpython-39.pyc +0 -0
  24. nntool/plot/context.py +48 -0
  25. nntool/plot/csrc/__compile__.cpython-39-x86_64-linux-gnu.so +0 -0
  26. nntool/plot/csrc/__init__.py +3 -0
  27. nntool/plot/csrc/__pycache__/__init__.cpython-39.opt-1.pyc +0 -0
  28. nntool/plot/csrc/__pycache__/__init__.cpython-39.pyc +0 -0
  29. nntool/slurm/__init__.py +9 -0
  30. nntool/slurm/__pycache__/__init__.cpython-39.opt-1.pyc +0 -0
  31. nntool/slurm/__pycache__/__init__.cpython-39.pyc +0 -0
  32. nntool/slurm/__pycache__/config.cpython-39.opt-1.pyc +0 -0
  33. nntool/slurm/__pycache__/config.cpython-39.pyc +0 -0
  34. nntool/slurm/__pycache__/function.cpython-39.opt-1.pyc +0 -0
  35. nntool/slurm/__pycache__/function.cpython-39.pyc +0 -0
  36. nntool/slurm/__pycache__/task.cpython-39.opt-1.pyc +0 -0
  37. nntool/slurm/__pycache__/task.cpython-39.pyc +0 -0
  38. nntool/slurm/__pycache__/wrap.cpython-39.opt-1.pyc +0 -0
  39. nntool/slurm/__pycache__/wrap.cpython-39.pyc +0 -0
  40. nntool/slurm/accelerator/__init__.py +0 -0
  41. nntool/slurm/accelerator/__pycache__/__init__.cpython-39.opt-1.pyc +0 -0
  42. nntool/slurm/accelerator/__pycache__/__init__.cpython-39.pyc +0 -0
  43. nntool/slurm/accelerator/__pycache__/utils.cpython-39.opt-1.pyc +0 -0
  44. nntool/slurm/accelerator/__pycache__/utils.cpython-39.pyc +0 -0
  45. nntool/slurm/accelerator/utils.py +39 -0
  46. nntool/slurm/config.py +182 -0
  47. nntool/slurm/csrc/__compile__.cpython-39-x86_64-linux-gnu.so +0 -0
  48. nntool/slurm/csrc/__init__.py +5 -0
  49. nntool/slurm/csrc/__pycache__/__init__.cpython-39.opt-1.pyc +0 -0
  50. nntool/slurm/csrc/__pycache__/__init__.cpython-39.pyc +0 -0
  51. nntool/slurm/function.py +173 -0
  52. nntool/slurm/task.py +231 -0
  53. nntool/slurm/wrap.py +210 -0
  54. nntool/train/__init__.py +25 -0
  55. nntool/train/__pycache__/__init__.cpython-39.opt-1.pyc +0 -0
  56. nntool/train/__pycache__/__init__.cpython-39.pyc +0 -0
  57. nntool/train/__pycache__/trainer.cpython-39.opt-1.pyc +0 -0
  58. nntool/train/__pycache__/trainer.cpython-39.pyc +0 -0
  59. nntool/train/trainer.py +92 -0
  60. nntool/utils/__init__.py +6 -0
  61. nntool/utils/__pycache__/__init__.cpython-39.opt-1.pyc +0 -0
  62. nntool/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  63. nntool/wandb/__init__.py +1 -0
  64. nntool/wandb/__pycache__/__init__.cpython-39.opt-1.pyc +0 -0
  65. nntool/wandb/__pycache__/__init__.cpython-39.pyc +0 -0
  66. nntool/wandb/__pycache__/config.cpython-39.opt-1.pyc +0 -0
  67. nntool/wandb/__pycache__/config.cpython-39.pyc +0 -0
  68. nntool/wandb/config.py +118 -0
  69. nntool-1.3.0.dist-info/._SOURCES.txt +0 -0
  70. nntool-1.3.0.dist-info/._dependency_links.txt +0 -0
  71. nntool-1.3.0.dist-info/._requires.txt +0 -0
  72. nntool-1.3.0.dist-info/._top_level.txt +0 -0
  73. nntool-1.3.0.dist-info/METADATA +25 -0
  74. nntool-1.3.0.dist-info/RECORD +76 -0
  75. nntool-1.3.0.dist-info/WHEEL +6 -0
  76. nntool-1.3.0.dist-info/top_level.txt +1 -0
nntool/__init__.py ADDED
@@ -0,0 +1,2 @@
1
def test_import():
    """Announce on stdout that nntool was imported, including the module's file location."""
    location = __file__
    print(f"nntool located at {location} is imported!")
@@ -0,0 +1,6 @@
1
+ from .config import BaseExperimentConfig
2
+ from .utils import (
3
+ get_current_time,
4
+ get_output_path,
5
+ read_toml_file,
6
+ )
@@ -0,0 +1,112 @@
1
+ import os
2
+
3
+ from typing import Any, Dict
4
+ from pathlib import Path
5
+ from dataclasses import dataclass
6
+ from .utils import get_output_path, read_toml_file
7
+
8
+
9
@dataclass
class BaseExperimentConfig:
    """
    Configuration class for setting up an experiment.

    On construction, ``env_toml_path`` is read (it must exist and provide a
    ``[project]`` table with a ``path`` key). The experiment name comes from the
    environment variable named by ``experiment_name_key``. The output path is then
    derived and ``set_up_stateful_fields`` is called for subclasses.

    :param config_name: The name of the configuration.
    :param output_folder: The folder path where the outputs will be saved.
    :param experiment_name_key: Key for experiment name in the environment variable, default is 'EXP_NAME'.
    :param env_toml_path: Path to the `env.toml` file, default is 'env.toml'.
    :param append_date_to_path: If True, the current date and time will be appended to the output path, default is True.
    :param existing_output_path_ok: If True, the existing output path is ok to be reused, default is False.
    :raises FileNotFoundError: if ``env_toml_path`` does not exist.
    :raises KeyError: if the env toml has no ``project.path`` entry.
    """

    # config name
    config_name: str

    # the output folder for the outputs
    output_folder: str

    # key for experiment name in the environment variable
    experiment_name_key: str = "EXP_NAME"

    # the path to the env.toml file
    env_toml_path: str = "env.toml"

    # append date time to the output path
    append_date_to_path: bool = True

    # existing output path is ok
    existing_output_path_ok: bool = False

    def __post_init__(self):
        # annotations for the derived (non-field) attributes set below
        self.experiment_name: str
        self.project_path: str
        self.output_path: str
        self.current_time: str
        self.env_toml: Dict[str, Any] = self.__prepare_env_toml_dict()

        self.experiment_name = self.__prepare_experiment_name()
        self.project_path, self.output_path, self.current_time = (
            self.__prepare_experiment_paths()
        )

        # custom post update for the derived class
        self.set_up_stateful_fields()

    def __prepare_env_toml_dict(self):
        """Load ``env_toml_path`` as a dict, failing early if the file is missing."""
        env_toml_path = Path(self.env_toml_path)
        if not env_toml_path.exists():
            raise FileNotFoundError(f"{env_toml_path} does not exist")

        config = read_toml_file(env_toml_path)
        return config

    def __prepare_experiment_name(self):
        """Read the experiment name from the environment, defaulting to ``"default"``."""
        return os.environ.get(self.experiment_name_key, "default")

    def __prepare_experiment_paths(self):
        """Return ``(project_path, output_path, current_time)`` for this experiment."""
        try:
            project_path = self.env_toml["project"]["path"]
        except KeyError as err:
            # same exception type as before, but say which entry is missing and where
            raise KeyError(
                f"`project.path` is missing from {self.env_toml_path}"
            ) from err

        output_path, current_time = get_output_path(
            output_path=os.path.join(
                self.output_folder, self.config_name, self.experiment_name
            ),
            append_date=self.append_date_to_path,
            # caching into the environment is deferred to `start()`
            cache_into_env=False,
        )
        # NOTE(review): plain string concatenation; if `output_path` is absolute
        # (e.g. taken from OUTPUT_PATH) the result contains "//" — confirm intended.
        output_path = f"{project_path}/{output_path}"
        return project_path, output_path, current_time

    def get_output_path(self) -> str:
        """Return the output path prepared for the experiment.

        :return: output path for the experiment
        """
        return self.output_path

    def get_current_time(self) -> str:
        """Return the current time for the experiment.

        :return: current time for the experiment
        """
        return self.current_time

    def set_up_stateful_fields(self):
        """
        Post configuration steps for stateful fields such as `output_path` in the derived class.
        This method should be overridden in the derived class.
        """
        pass

    def start(self):
        """
        Start the experiment. This will
        - cache `NNTOOL_OUTPUT_PATH` and `NNTOOL_OUTPUT_PATH_DATE` into environment variables, which means the later launched processes would inherit these variables.
        - create the output path if it does not exist.

        :raises FileExistsError: if the output path already exists and
            ``existing_output_path_ok`` is False.
        """
        os.environ["NNTOOL_OUTPUT_PATH"] = self.get_output_path()
        os.environ["NNTOOL_OUTPUT_PATH_DATE"] = self.get_current_time()

        # create the output path
        output_path = Path(self.get_output_path())
        output_path.mkdir(parents=True, exist_ok=self.existing_output_path_ok)
@@ -0,0 +1,63 @@
1
+ import os
2
+ import datetime
3
+ import tomli
4
+
5
+
6
def get_current_time() -> str:
    """Return the local wall-clock time rendered as ``MMDDYYYY/HHMMSS``.

    :return: a string such as ``01312024/235959``
    """
    return datetime.datetime.now().strftime("%m%d%Y/%H%M%S")
18
+
19
+
20
def read_toml_file(file_path: str) -> dict:
    """Parse a TOML file with ``tomli`` and hand back its contents.

    :param file_path: location of the TOML file (opened in binary mode)
    :return: the parsed document as a dictionary
    """
    with open(file_path, "rb") as handle:
        return tomli.load(handle)
30
+
31
+
32
def get_output_path(
    output_path: str = "./",
    append_date: bool = True,
    cache_into_env: bool = True,
) -> tuple[str, str]:
    """Resolve the output path from environment variables or the given default.

    Resolution order:

    1. ``OUTPUT_PATH`` is set: it is used verbatim (``output_path`` is ignored and no
       date folder is appended). ``current_time`` is a fresh timestamp when
       ``append_date`` is True, otherwise ``""``.
    2. ``NNTOOL_OUTPUT_PATH`` is set: it is reused verbatim. ``current_time`` is
       ``NNTOOL_OUTPUT_PATH_DATE`` (or ``""`` if that is unset) when ``append_date``
       is True, otherwise ``""``.
    3. Otherwise ``output_path`` is used, with a ``MMDDYYYY/HHMMSS`` child folder
       appended when ``append_date`` is True (e.g. ``xxx/MMDDYYYY/HHMMSS``).

    :param output_path: fallback path used when neither environment variable is set, defaults to "./"
    :param append_date: append a children folder with the date time, defaults to True
    :param cache_into_env: whether cache the newly created path into env, defaults to True
    :return: (output path, current time)
    """
    if "OUTPUT_PATH" in os.environ:
        output_path = os.environ["OUTPUT_PATH"]
        current_time = get_current_time() if append_date else ""
    elif "NNTOOL_OUTPUT_PATH" in os.environ:
        # reuse the NNTOOL_OUTPUT_PATH if it is set; tolerate a missing date
        # variable (e.g. only the path was exported) instead of raising KeyError
        output_path = os.environ["NNTOOL_OUTPUT_PATH"]
        current_time = (
            os.environ.get("NNTOOL_OUTPUT_PATH_DATE", "") if append_date else ""
        )
    else:
        current_time = get_current_time()
        if append_date:
            output_path = os.path.join(output_path, current_time)
        print(
            f"OUTPUT_PATH is not found in environment variables. NNTOOL_OUTPUT_PATH is set using path: {output_path}"
        )

    if cache_into_env:
        os.environ["NNTOOL_OUTPUT_PATH"] = output_path
        os.environ["NNTOOL_OUTPUT_PATH_DATE"] = current_time

    return output_path, current_time
@@ -0,0 +1 @@
1
+ from .parse import parse_from_cli
nntool/parser/parse.py ADDED
@@ -0,0 +1,22 @@
1
+ import tyro
2
+
3
+ from typing import Union, Callable, Type, Any
4
+
5
+
6
def parse_from_cli(
    ArgsType: Type[Any],
    parser: Union[str, Callable] = "tyro",
    *args,
    **kwargs,
) -> Any:
    """Parse command-line arguments into ``ArgsType`` using a built-in or custom parser.

    :param ArgsType: the type to parse into; passed as the first argument to the parser
    :param parser: either the name of a built-in parser (only ``"tyro"`` is supported,
        which maps to ``tyro.cli``) or a callable invoked as
        ``parser(ArgsType, *args, **kwargs)``
    :param args: extra positional arguments forwarded to the parser
    :param kwargs: extra keyword arguments forwarded to the parser
    :return: whatever the selected parser returns
    :raises ValueError: if ``parser`` is a string other than ``"tyro"``
    """
    # parse with built-in parser or custom parser function
    if isinstance(parser, str):
        if parser != "tyro":
            raise ValueError(f"Parser `{parser}` is not supported.")
        parser_fn = tyro.cli
    else:
        parser_fn = parser
    # use a dedicated name instead of rebinding the ``*args`` tuple
    parsed = parser_fn(ArgsType, *args, **kwargs)
    return parsed
@@ -0,0 +1,6 @@
1
+ import cythonpackage
2
+
3
+ cythonpackage.init(__name__)
4
+
5
+ from .csrc.latexify import *
6
+ from .context import latexify_plot, enable_latexify
nntool/plot/context.py ADDED
@@ -0,0 +1,48 @@
1
+ import os
2
+ import matplotlib
3
+ import seaborn as sns
4
+
5
+ from typing import Union
6
+ from dataclasses import dataclass
7
+ from .csrc.latexify import SIZE_SMALL, latexify, savefig
8
+
9
+
10
@dataclass
class latexify_plot:
    """Context manager that applies ``latexify`` matplotlib settings inside a ``with`` block.

    On entry (when ``enable`` is True) it sets the ``LATEXIFY`` environment variable
    to ``"1"`` and calls ``latexify`` with the configured sizes. On exit (when
    ``enable`` is True) it removes ``LATEXIFY`` and resets ``matplotlib.rcParams``
    to ``matplotlib.rcParamsDefault``. Exceptions raised in the block propagate.

    :param enable: whether to apply the settings at all
    :param width_scale_factor: forwarded to ``latexify``
    :param height_scale_factor: forwarded to ``latexify``
    :param fig_width: forwarded to ``latexify``
    :param fig_height: forwarded to ``latexify``
    :param font_size: forwarded to ``latexify``; defaults to ``SIZE_SMALL``
    """

    enable: bool = True
    width_scale_factor: float = 1
    height_scale_factor: float = 1
    fig_width: Union[float, None] = None
    fig_height: Union[float, None] = None
    font_size: int = SIZE_SMALL

    def __post_init__(self):
        # NOTE(review): not used within this class; presumably read by callers
        # as a legend font size (None when disabled) — confirm.
        self.legend_size = 7 if self.enable else None

    def __enter__(self):
        if self.enable:
            os.environ["LATEXIFY"] = "1"
            latexify(
                width_scale_factor=self.width_scale_factor,
                height_scale_factor=self.height_scale_factor,
                fig_width=self.fig_width,
                fig_height=self.fig_height,
                font_size=self.font_size,
            )
        return self

    def __exit__(self, *args):
        if self.enable:
            # Default of None: with nested contexts the inner exit has already
            # removed the variable, and a bare pop() would raise KeyError.
            os.environ.pop("LATEXIFY", None)
            matplotlib.rcParams.update(matplotlib.rcParamsDefault)

    def savefig(
        self, filename, despine: bool = True, fig_dir: str = "tests/plot", **kwargs
    ):
        """Save the current figure via ``csrc.latexify.savefig``.

        :param filename: file name passed through to ``savefig``
        :param despine: call ``sns.despine()`` before saving, defaults to True
        :param fig_dir: directory passed through to ``savefig``, defaults to "tests/plot"
        :param kwargs: extra keyword arguments passed through to ``savefig``
        """
        if despine:
            sns.despine()
        # resolves to the module-level `savefig` imported from csrc.latexify,
        # not this method (class attributes are not in a method's scope)
        savefig(filename, fig_dir=fig_dir, **kwargs)


# This is for backward compatibility
enable_latexify = latexify_plot
@@ -0,0 +1,3 @@
1
+ import cythonpackage
2
+
3
+ cythonpackage.init(__name__)
@@ -0,0 +1,9 @@
1
+ from .config import SlurmConfig, SlurmArgs
2
+ from .wrap import (
3
+ slurm_function,
4
+ slurm_fn,
5
+ slurm_launcher,
6
+ slurm_distributed_launcher,
7
+ )
8
+ from .function import SlurmFunction
9
+ from .task import PyTorchDistributedTask
File without changes
@@ -0,0 +1,39 @@
1
+ import subprocess
2
+
3
+
4
def nvidia_smi_gpu_memory_stats() -> dict:
    """
    Query ``nvidia-smi`` for per-GPU memory usage.

    :return: mapping ``"gpu_<index>_mem_used_gb"`` -> memory used, computed as the
        reported MiB value divided by 1024
    :raises Exception: (as ``RuntimeError``) if the ``nvidia-smi`` executable cannot
        be found or exits with a non-zero status
    """
    try:
        # check=True makes a non-zero exit status raise CalledProcessError;
        # Popen.communicate never raises it, so failures would go unnoticed.
        completed = subprocess.run(
            ["nvidia-smi", "--query-gpu=index,memory.used", "--format=csv,noheader"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            close_fds=True,
            check=True,
        )
    except FileNotFoundError as err:
        raise RuntimeError(
            "Failed to find the 'nvidia-smi' executable for printing GPU stats"
        ) from err
    except subprocess.CalledProcessError as err:
        raise RuntimeError(
            f"nvidia-smi returned non zero error code: {err.returncode}"
        ) from err

    out_dict = {}
    for line in completed.stdout.decode("utf-8").split("\n"):
        # expected line shape: "<index>, <used> MiB"; skip anything else
        if " MiB" not in line:
            continue
        gpu_idx, mem_used = line.split(",")
        out_dict[f"gpu_{gpu_idx}_mem_used_gb"] = (
            int(mem_used.strip().split(" ")[0]) / 1024
        )
    return out_dict
32
+
33
+
34
def nvidia_smi_gpu_memory_stats_str() -> str:
    """
    Format the per-GPU memory usage from ``nvidia_smi_gpu_memory_stats`` as one string.

    :return: comma-separated ``"<key>: <value>"`` pairs with values printed to four
        decimal places, e.g. ``"gpu_0_mem_used_gb: 1.2051, gpu_1_mem_used_gb: 0.0000"``
    :raises Exception: whatever ``nvidia_smi_gpu_memory_stats`` raises
    """
    stats = nvidia_smi_gpu_memory_stats()
    return ", ".join(f"{key}: {value:.4f}" for key, value in stats.items())
nntool/slurm/config.py ADDED
@@ -0,0 +1,182 @@
1
+ import os
2
+ import sys
3
+ from dataclasses import dataclass, field, replace
4
+ from typing import Literal, Dict, Optional
5
+
6
+
7
@dataclass
class SlurmConfig:
    """
    Configuration class for SLURM job submission and execution.

    :param mode: Running mode for the job. Options include:
        "debug" (default), "exec", "local", or "slurm".

    :param job_name: The name of the SLURM job. Default is 'JOB_NAME'.

    :param partition: The name of the SLURM partition to use. Default is 'PARTITION_NAME'.

    :param output_parent_path: The parent directory name for saving slurm folder. Default is './'.

    :param output_folder: The folder name where SLURM output files will be stored. Default is 'slurm'.
        It is normalized to end with ``slurm`` (mode "slurm") or ``slurm_<mode>`` (other modes).

    :param node_list: A string specifying the nodes to use. Leave blank to use all available nodes. Default is an empty string.

    :param node_list_exclude: A string specifying the nodes to exclude. Leave blank to use all nodes in the node list. Default is an empty string.

    :param num_of_node: The number of nodes to request. Default is 1.

    :param tasks_per_node: The number of tasks to run per node. Default is 1.

    :param gpus_per_task: The number of GPUs to request per task. Default is 0.

    :param cpus_per_task: The number of CPUs to request per task. Default is 1.

    :param gpus_per_node: The number of GPUs to request per node. If this is set, `gpus_per_task` will be ignored. Default is None.

    :param mem: The amount of memory to request. Leave blank to use the default memory configuration of the node. Default is an empty string.

    :param timeout_min: The time limit for the job in minutes. Default is `sys.maxsize` for effectively no limit.

    :param pack_code: Whether to pack the codebase before submission. Default is False.

    :param use_packed_code: Whether to use the packed code for execution. Default is False.

    :param code_root: The root directory of the codebase. Default is the current directory (``.``).

    :param code_file_suffixes: A list of file extensions for code files to be included when packing. Default includes ``.py``, ``.sh``, ``.yaml``, and ``.toml``.

    :param exclude_code_folders: A list of folder names relative to `code_root` that will be excluded from packing. Default excludes 'wandb', 'outputs', and 'datasets'.

    :param use_distributed_env: Whether to use a distributed environment for the job. Default is False.

    :param distributed_env_task: The kind of distributed environment task. Only "torch" is accepted. Default is "torch".

    :param processes_per_task: The number of processes to run per task. This value is not used by SLURM but is relevant for distributed environments. Default is 1.

    :param distributed_launch_command: The command to launch distributed environment setup, using placeholders like ``{num_processes}``, ``{num_machines}``, ``{machine_rank}``, ``{main_process_ip}``, ``{main_process_port}``. Default is an empty string.

    :param distributed_launch_command_with_entry_point: Whether `distributed_launch_command` already includes the entry point. Default is True.

    :param extra_params_kwargs: Additional parameters for the SLURM job as a dictionary of key-value pairs. Default is an empty dictionary.

    :param extra_submit_kwargs: Additional submit parameters for the SLURM job as a dictionary of key-value pairs. Default is an empty dictionary.

    :param extra_task_kwargs: Additional task parameters for the SLURM job as a dictionary of key-value pairs. Default is an empty dictionary.
    """

    # running mode
    mode: Literal["debug", "exec", "local", "slurm"] = "debug"

    # slurm job name
    job_name: str = "JOB_NAME"

    # slurm partition name
    partition: str = "PARTITION_NAME"

    # slurm output parent path
    output_parent_path: str = "./"

    # slurm output folder name
    output_folder: str = "slurm"

    # node list string (leave blank to use all nodes)
    node_list: str = ""

    # node list string to be excluded (leave blank to use all nodes in the node list)
    node_list_exclude: str = ""

    # number of nodes to request
    num_of_node: int = 1

    # tasks per node
    tasks_per_node: int = 1

    # number of gpus per task to request
    gpus_per_task: int = 0

    # number of cpus per task to request
    cpus_per_task: int = 1

    # number of gpus per node to request (if this is set, gpus_per_task will be ignored)
    gpus_per_node: Optional[int] = None

    # memory to request (leave blank to use default memory configurations in the node)
    mem: str = ""

    # time out min
    timeout_min: int = sys.maxsize

    # whether to pack code
    pack_code: bool = False

    # use packed code to run
    use_packed_code: bool = False

    # code root
    code_root: str = "."

    # code file extensions
    code_file_suffixes: list[str] = field(
        default_factory=lambda: [".py", ".sh", ".yaml", ".toml"]
    )

    # exclude folders (relative to the code root)
    exclude_code_folders: list[str] = field(
        default_factory=lambda: ["wandb", "outputs", "datasets"]
    )

    # whether to use distributed environment
    use_distributed_env: bool = False

    # distributed environment task
    distributed_env_task: Literal["torch"] = "torch"

    # processes per task (this value is not used by slurm, but in the distributed environment)
    processes_per_task: int = 1

    # distributed launch command (this will be called after the distributed environment is set up)
    # the following environment variables are available:
    # num_processes: int
    # num_machines: int
    # machine_rank: int
    # main_process_ip: str
    # main_process_port: int
    # use braces to access the environment variables, e.g. {num_processes}
    distributed_launch_command: str = ""

    # whether distributed_launch_command includes the entry point
    distributed_launch_command_with_entry_point: bool = True

    # extra slurm job parameters
    extra_params_kwargs: Dict[str, str] = field(default_factory=dict)

    # extra slurm submit parameters
    extra_submit_kwargs: Dict[str, str] = field(default_factory=dict)

    # extra slurm task parameters
    extra_task_kwargs: Dict[str, str] = field(default_factory=dict)

    def __post_init__(self):
        # Normalize the output folder so it ends with "slurm" (mode "slurm") or
        # "slurm_<mode>" (other modes). The first check keeps this idempotent:
        # `dataclasses.replace` (used by `set_output_path`) re-runs __post_init__
        # on the already-normalized value, and without it "slurm_debug" would
        # become "slurm_debug/slurm_debug".
        output_folder_suffix = "" if self.mode == "slurm" else f"_{self.mode}"
        normalized_leaf = f"slurm{output_folder_suffix}"
        if not self.output_folder.endswith(normalized_leaf):
            if self.output_folder.endswith("slurm"):
                self.output_folder = f"{self.output_folder}{output_folder_suffix}"
            else:
                self.output_folder = os.path.join(self.output_folder, normalized_leaf)

        # output path
        self.output_path: str = os.path.join(
            self.output_parent_path, self.output_folder
        )

    def set_output_path(self, output_parent_path: str) -> "SlurmConfig":
        """Return a copy of this config with ``output_parent_path`` replaced.

        ``output_path`` is recomputed for the copy; ``self`` is not modified.
        Mutable fields (lists/dicts) are shared with the original, as with
        ``dataclasses.replace``.
        """
        new_config = replace(
            self,
            output_parent_path=output_parent_path,
        )
        return new_config


SlurmArgs = SlurmConfig
@@ -0,0 +1,5 @@
1
+ import cythonpackage
2
+
3
+ cythonpackage.init(__name__)
4
+
5
+ from ._slurm import SlurmFunction as _SlurmFunction