resubmit 0.0.3.tar.gz → 0.0.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: resubmit
- Version: 0.0.3
+ Version: 0.0.5
  Summary: Small wrapper around submitit to simplify cluster submissions
  Author: Amir Mehrpanah
  License: MIT
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "resubmit"
- version = "0.0.3"
+ version = "0.0.5"
  description = "Small wrapper around submitit to simplify cluster submissions"
  readme = "README.md"
  license = { text = "MIT" }
@@ -0,0 +1,197 @@
+ import re
+ from typing import Any, Dict, List, Tuple, Union, Optional, Iterable
+ import pandas as pd
+ from itertools import product
+ import logging
+
+
+ def _normalize_regex_spec(val: Any) -> Tuple[re.Pattern, bool]:
+     """Return (compiled_pattern, exclude_flag) for a given regex spec.
+
+     Raises ValueError for unsupported types.
+     """
+     if hasattr(val, "search") and callable(val.search):
+         return val, False
+     if isinstance(val, tuple) and len(val) >= 1:
+         pat = val[0]
+         exclude = bool(val[1]) if len(val) > 1 else False
+         return pat, exclude
+     if isinstance(val, dict):
+         pat = val["pattern"]
+         exclude = bool(val.get("exclude", False))
+         return pat, exclude
+     if isinstance(val, str):
+         if val.startswith("!re:"):
+             return re.compile(val[4:]), True
+         elif val.startswith("re:"):
+             return re.compile(val[3:]), False
+     raise ValueError(f"Unsupported regex spec: {val!r}")
+
+
+ def ensure_unique_combinations(
+     df: pd.DataFrame, cols: Union[str, List[str]], raise_on_conflict: bool = True
+ ) -> Tuple[bool, Optional[pd.DataFrame]]:
+     """Check that combinations of columns `cols` are unique across `df`.
+
+     Returns (is_unique, duplicates_df) where `duplicates_df` is None when unique.
+     If `raise_on_conflict` is True, raises `ValueError` when duplicates are found.
+     """
+     if isinstance(cols, str):
+         cols = [cols]
+     # Stringify to avoid dtype mismatch effects
+     key_series = df[cols].astype(str).agg("||".join, axis=1)
+     nunique = key_series.nunique()
+     if nunique == len(df):
+         return True, None
+
+     duplicates = df[key_series.duplicated(keep=False)]
+     if raise_on_conflict:
+         raise ValueError(
+             f"Found {len(duplicates)} rows with non-unique combinations for cols={cols}."
+         )
+     return False, duplicates
+
+
+ def create_jobs_dataframe(params: Dict[str, Any]) -> pd.DataFrame:
+     """Create a job DataFrame from a parameter map.
+
+     Rules:
+     - For parameters whose values are iterable (lists, tuples), we build the Cartesian
+       product across all such parameters.
+     - If a parameter value is callable, it is evaluated AFTER the initial DataFrame
+       is created; the callable is called as `col_values = fn(df)` and the result is
+       used as the column values (must be same length as `df`).
+     - If a parameter value is a regex spec (see `_is_regex_spec`), it is applied LAST
+       as a filter on the generated DataFrame. Regex specs can be used to include or
+       exclude rows based on the stringified value of that column.
+
+     Returns a filtered DataFrame with the applied callables and regex filters.
+     """
+     # Separate static values (used for product), callables and regex specs
+     static_items = {}
+     callables: Dict[str, Any] = {}
+     regex_specs: Dict[str, Any] = {}
+     unique_items: Dict[str, Any] = {}
+
+     for k, v in params.items():
+         # support explicit regex keys like 'name__regex' or 'name_regex' to filter 'name'
+         if k.endswith("__regex") or k.endswith("_regex"):
+             if k.endswith("__regex"):
+                 base = k[: -len("__regex")]
+             else:
+                 base = k[: -len("_regex")]
+             regex_specs[base] = v
+         elif k.endswith("__callable") or k.endswith("_callable"):
+             if k.endswith("__callable"):
+                 base = k[: -len("__callable")]
+             else:
+                 base = k[: -len("_callable")]
+             callables[base] = v
+         elif k.endswith("__unique") or k.endswith("_unique"):
+             if k.endswith("__unique"):
+                 base = k[: -len("__unique")]
+             else:
+                 base = k[: -len("_unique")]
+             unique_items[base] = v
+             continue
+         else:
+             static_items[k] = v
+
+     # If there are no static items, start from single-row DataFrame so callables
+     # can still compute columns.
+     if len(static_items) == 0:
+         df = pd.DataFrame([{}])
+     else:
+         df = pd.DataFrame(
+             list(product(*static_items.values())), columns=static_items.keys()
+         )
+
+     # Apply callables (they must accept the dataframe and return a list-like)
+     for k, fn in callables.items():
+         vals = fn(df)
+         if len(vals) != len(df):
+             raise ValueError(
+                 f"Callable for param {k!r} returned length {len(vals)} != {len(df)}"
+             )
+         df[k] = vals
+
+     # Apply regex specs last as filters
+     if len(regex_specs) > 0:
+         mask = pd.Series([True] * len(df), index=df.index)
+         for k, spec in regex_specs.items():
+             pat, exclude = _normalize_regex_spec(spec)
+             col_str = df[k].astype(str)
+             matches = col_str.apply(lambda s: bool(pat.search(s)))
+             if exclude:
+                 mask = mask & ~matches
+             else:
+                 mask = mask & matches
+         df = df[mask].reset_index(drop=True)
+
+     # apply unique constraints
+     for k, unique_val in unique_items.items():
+         is_unique, duplicates = ensure_unique_combinations(
+             df,
+             k,
+             raise_on_conflict=unique_val,
+         )
+         if not is_unique:
+             logging.warning(f"Non-unique values found for column {k!r}:\n{duplicates}")
+
+     return df
+
+
+ def submit_jobs(
+     jobs_args: dict[Iterable],
+     func: Any,
+     *,
+     timeout_min: int,
+     cpus_per_task: int = 16,
+     mem_gb: int = 64,
+     num_gpus: int = 1,
+     folder: str = "logs/%j",
+     block: bool = False,
+     prompt: bool = True,
+     local_run: bool = False,
+     slurm_additional_parameters: Dict | None = None,
+ ) -> Any:
+     """
+     Submit jobs described by `jobs_args` where each entry is a dict of kwargs for `func`.
+     A dataframe is created from cartesian product of parameter lists, with support for callables and regex filtering.
+     1. use `__unique' postfix in keys to enforce uniqueness.
+     2. use `__callable' postfix in keys to define callables for column values.
+     3. use `__regex' postfix in keys to define regex filters for columns.
+
+     Args:
+         jobs_args: dict of lists of job parameters.
+         func: Function to be submitted for each job.
+         timeout_min: Job timeout in minutes.
+         cpus_per_task: Number of CPUs per task.
+         mem_gb: Memory in GB.
+         num_gpus: Number of GPUs.
+         folder: Folder for logs.
+         block: Whether to block until jobs complete.
+         prompt: Whether to prompt for confirmation before submission.
+         local_run: If True, runs the function locally instead of submitting.
+         slurm_additional_parameters: Additional Slurm parameters as a dict. If not provided, defaults to {"gpus": num_gpus}.
+     Returns:
+         The result of `submit_jobs` from `.__submit`.
+     """
+
+     jobs_df = create_jobs_dataframe(jobs_args)
+     records = jobs_df.to_dict(orient="records")
+     from .__submit import _submit_jobs
+
+     return _submit_jobs(
+         records,
+         func,
+         timeout_min=timeout_min,
+         cpus_per_task=cpus_per_task,
+         mem_gb=mem_gb,
+         num_gpus=num_gpus,
+         folder=folder,
+         block=block,
+         prompt=prompt,
+         local_run=local_run,
+         slurm_additional_parameters=slurm_additional_parameters,
+     )
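For orientation, here is a rough usage sketch of the `create_jobs_dataframe` helper added above. The parameter names (`lr`, `seed`, `run_name`) are hypothetical; the `__callable`, `__regex`, and `__unique` key suffixes follow the rules documented in the module's docstrings, and the import path matches the package's own tests.

    from resubmit.__bookkeeping import create_jobs_dataframe

    # Hypothetical sweep: 2 learning rates x 3 seeds -> 6 rows before filtering.
    params = {
        "lr": [1e-3, 1e-4],
        "seed": [0, 1, 2],
        # derived column, computed on the DataFrame after the Cartesian product
        "run_name__callable": lambda df: df["lr"].astype(str) + "-s" + df["seed"].astype(str),
        # regex filter applied last; a "!re:" prefix would exclude matches instead
        "run_name__regex": "re:s[01]$",
        # False only warns on duplicate run_name values; True would raise
        "run_name__unique": False,
    }

    jobs_df = create_jobs_dataframe(params)
    print(len(jobs_df))  # 4 rows: seeds 0 and 1 for each learning rate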
@@ -1,6 +1,6 @@
  """resubmit: small helpers around submitit for reproducible cluster submissions."""

- from .submit import submit_jobs
- from .debug import maybe_attach_debugger
+ from .__debug import maybe_attach_debugger
+ from .__bookkeeping import submit_jobs

  __all__ = ["submit_jobs", "maybe_attach_debugger"]
@@ -1,23 +1,21 @@
  """Core submission utilities wrapping submitit."""
+
  from typing import Any, Callable, Iterable, List, Optional, Dict


- def submit_jobs(
+ def _submit_jobs(
      jobs_args: Iterable[dict],
      func: Callable[[List[dict]], Any],
      *,
      timeout_min: int,
-     cpus_per_task: int = 16,
-     mem_gb: int = 64,
-     num_gpus: int = 1,
-     account: Optional[str] = None,
-     folder: str = "logs/%j",
-     block: bool = False,
-     prompt: bool = True,
-     local_run: bool = False,
+     cpus_per_task: int,
+     mem_gb: int,
+     num_gpus: int,
+     folder: str,
+     block: bool,
+     prompt: bool,
+     local_run: bool,
      slurm_additional_parameters: Optional[Dict] = None,
-     constraint: Optional[str] = None,
-     reservation: Optional[str] = None,
  ):
      """Submit jobs described by `jobs_args` where each entry is a dict of kwargs for `func`.

@@ -46,6 +44,7 @@ def submit_jobs(
          return

      import submitit
+
      print("submitting jobs")
      executor = submitit.AutoExecutor(folder=folder)

@@ -56,14 +55,6 @@ def submit_jobs(
      slurm_additional_parameters = dict(slurm_additional_parameters)
      slurm_additional_parameters.setdefault("gpus", num_gpus)

-     # Allow explicit overrides similar to `account`.
-     if account is not None:
-         slurm_additional_parameters["account"] = account
-     if reservation is not None:
-         slurm_additional_parameters["reservation"] = reservation
-     if constraint is not None:
-         slurm_additional_parameters["constraint"] = constraint
-
      print("Slurm additional parameters:", slurm_additional_parameters)

      executor.update_parameters(
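The removals above change the calling convention: `account`, `constraint`, and `reservation` are no longer dedicated keyword arguments and now travel through `slurm_additional_parameters`, which still receives a `gpus` default from `num_gpus`. A minimal migration sketch, assuming a placeholder worker function and made-up Slurm values:

    from resubmit import submit_jobs

    def train(jobs):
        # placeholder worker; receives the list of job dicts built from jobs_args
        return [job["lr"] for job in jobs]

    # 0.0.3 style (no longer accepted): constraint="thin", reservation="safe", account="proj-x"
    # 0.0.5 style: route the same settings through slurm_additional_parameters.
    submit_jobs(
        {"lr": [1e-3, 1e-4]},
        train,
        timeout_min=30,
        num_gpus=1,
        prompt=False,
        slurm_additional_parameters={
            "account": "proj-x",
            "constraint": "thin",
            "reservation": "safe",
        },
    )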
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: resubmit
- Version: 0.0.3
+ Version: 0.0.5
  Summary: Small wrapper around submitit to simplify cluster submissions
  Author: Amir Mehrpanah
  License: MIT
@@ -1,12 +1,14 @@
  LICENSE
  README.md
  pyproject.toml
+ src/resubmit/__bookkeeping.py
+ src/resubmit/__debug.py
  src/resubmit/__init__.py
- src/resubmit/debug.py
- src/resubmit/submit.py
+ src/resubmit/__submit.py
  src/resubmit.egg-info/PKG-INFO
  src/resubmit.egg-info/SOURCES.txt
  src/resubmit.egg-info/dependency_links.txt
  src/resubmit.egg-info/requires.txt
  src/resubmit.egg-info/top_level.txt
+ tests/test_bookkeeping.py
  tests/test_resubmit.py
@@ -0,0 +1,45 @@
+ import re
+ import pandas as pd
+ from resubmit.__bookkeeping import create_jobs_dataframe, ensure_unique_combinations
+
+
+ def test_create_jobs_basic():
+     params = {"a": [1, 2], "b": [10]}
+     df = create_jobs_dataframe(params)
+     assert len(df) == 2
+     assert set(df.columns) == {"a", "b"}
+
+
+ def test_create_jobs_callable():
+     params = {"a": [1, 2], "b__callable": lambda df: df["a"] * 10}
+     df = create_jobs_dataframe(params)
+     assert list(df["b"]) == [10, 20]
+
+
+ def test_create_jobs_regex_include():
+     params = {"name": ["apple", "banana", "apricot"], "name__regex": re.compile(r"^a")}
+     df = create_jobs_dataframe(params)
+     assert set(df["name"]) == {"apple", "apricot"}
+
+
+ def test_create_jobs_regex_exclude():
+     params = {"name": ["apple", "banana", "apricot"], "name_regex": "!re:^a"}
+     df = create_jobs_dataframe(params)
+     assert set(df["name"]) == {"banana"}
+
+
+ def test_ensure_unique_combinations_raises():
+     df = pd.DataFrame({"a": [1, 1, 2], "b": [3, 3, 4]})
+     try:
+         ensure_unique_combinations(df, ["a", "b"], raise_on_conflict=True)
+         raised = False
+     except ValueError:
+         raised = True
+     assert raised
+
+
+ def test_ensure_unique_combinations_ok():
+     df = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
+     ok, dup = ensure_unique_combinations(df, ["a", "b"], raise_on_conflict=False)
+     assert ok
+     assert dup is None
@@ -0,0 +1,115 @@
+ import pytest
+ from resubmit import maybe_attach_debugger
+ from resubmit.__submit import _submit_jobs
+
+
+ def dummy_func(jobs):
+     # return a list of strings to show behavior
+     return [f"ok-{j['id']}" for j in jobs]
+
+
+ def test_submit_local_run():
+     jobs = [{"id": 1}, {"id": 2}]
+     res = _submit_jobs(
+         jobs,
+         dummy_func,
+         timeout_min=1,
+         local_run=True,
+         num_gpus=0,
+         cpus_per_task=1,
+         mem_gb=8,
+         folder="dummy/%j",
+         block=False,
+         prompt=False,
+     )
+     assert res == ["ok-1", "ok-2"]
+
+
+ def test_maybe_attach_debugger_noop():
+     # should not raise when port is None or 0
+     maybe_attach_debugger(None)
+     maybe_attach_debugger(0)
+
+
+ def test_slurm_parameters_optional(monkeypatch):
+     events = {}
+
+     class DummyExecutor:
+         def __init__(self, folder):
+             events["folder"] = folder
+
+         def update_parameters(self, **kwargs):
+             # capture the parameters passed to the executor
+             events["update"] = kwargs
+
+         def map_array(self, func, jobs_list):
+             return []
+
+     class DummyModule:
+         AutoExecutor = DummyExecutor
+
+     import sys
+
+     monkeypatch.setitem(sys.modules, "submitit", DummyModule)
+
+     jobs = [{"id": 1}]
+     # default: no constraint/reservation keys
+     _submit_jobs(
+         jobs,
+         dummy_func,
+         timeout_min=1,
+         local_run=False,
+         num_gpus=2,
+         prompt=False,
+         cpus_per_task=4,
+         mem_gb=16,
+         folder="logs/%j",
+         block=False,
+     )
+     slurm = events["update"]["slurm_additional_parameters"]
+     assert slurm["gpus"] == 2
+     assert "constraint" not in slurm
+     assert "reservation" not in slurm
+
+
+ def test_slurm_parameters_settable(monkeypatch):
+     events = {}
+
+     class DummyExecutor:
+         def __init__(self, folder):
+             events["folder"] = folder
+
+         def update_parameters(self, **kwargs):
+             events["update"] = kwargs
+
+         def map_array(self, func, jobs_list):
+             return []
+
+     class DummyModule:
+         AutoExecutor = DummyExecutor
+
+     import sys
+
+     monkeypatch.setitem(sys.modules, "submitit", DummyModule)
+
+     jobs = [{"id": 1}]
+     _submit_jobs(
+         jobs,
+         dummy_func,
+         timeout_min=1,
+         local_run=False,
+         prompt=False,
+         slurm_additional_parameters={
+             "constraint": "thin",
+             "reservation": "safe",
+         },
+         cpus_per_task=4,
+         mem_gb=16,
+         folder="logs/%j",
+         block=False,
+         num_gpus=1,
+     )
+     slurm = events["update"]["slurm_additional_parameters"]
+     assert slurm["constraint"] == "thin"
+     assert slurm["reservation"] == "safe"
+
@@ -1,116 +0,0 @@
- import pytest
- from resubmit import submit_jobs, maybe_attach_debugger
-
-
- def dummy_func(jobs):
-     # return a list of strings to show behavior
-     return [f"ok-{j['id']}" for j in jobs]
-
-
- def test_submit_local_run():
-     jobs = [{"id": 1}, {"id": 2}]
-     res = submit_jobs(jobs, dummy_func, timeout_min=1, local_run=True)
-     assert res == ["ok-1", "ok-2"]
-
-
- def test_maybe_attach_debugger_noop():
-     # should not raise when port is None or 0
-     maybe_attach_debugger(None)
-     maybe_attach_debugger(0)
-
-
- def test_slurm_parameters_optional(monkeypatch):
-     events = {}
-
-     class DummyExecutor:
-         def __init__(self, folder):
-             events['folder'] = folder
-
-         def update_parameters(self, **kwargs):
-             # capture the parameters passed to the executor
-             events['update'] = kwargs
-
-         def map_array(self, func, jobs_list):
-             return []
-
-     class DummyModule:
-         AutoExecutor = DummyExecutor
-
-     import sys
-     monkeypatch.setitem(sys.modules, 'submitit', DummyModule)
-
-     jobs = [{"id": 1}]
-     # default: no constraint/reservation keys
-     submit_jobs(jobs, dummy_func, timeout_min=1, local_run=False, num_gpus=2, prompt=False)
-     slurm = events['update']['slurm_additional_parameters']
-     assert slurm['gpus'] == 2
-     assert 'constraint' not in slurm
-     assert 'reservation' not in slurm
-
-
- def test_slurm_parameters_settable(monkeypatch):
-     events = {}
-
-     class DummyExecutor:
-         def __init__(self, folder):
-             events['folder'] = folder
-
-         def update_parameters(self, **kwargs):
-             events['update'] = kwargs
-
-         def map_array(self, func, jobs_list):
-             return []
-
-     class DummyModule:
-         AutoExecutor = DummyExecutor
-
-     import sys
-     monkeypatch.setitem(sys.modules, 'submitit', DummyModule)
-
-     jobs = [{"id": 1}]
-     submit_jobs(
-         jobs,
-         dummy_func,
-         timeout_min=1,
-         local_run=False,
-         constraint='thin',
-         reservation='safe',
-         prompt=False,
-     )
-     slurm = events['update']['slurm_additional_parameters']
-     assert slurm['constraint'] == 'thin'
-     assert slurm['reservation'] == 'safe'
-
-
- def test_slurm_parameters_arg_precedence(monkeypatch):
-     events = {}
-
-     class DummyExecutor:
-         def __init__(self, folder):
-             events['folder'] = folder
-
-         def update_parameters(self, **kwargs):
-             events['update'] = kwargs
-
-         def map_array(self, func, jobs_list):
-             return []
-
-     class DummyModule:
-         AutoExecutor = DummyExecutor
-
-     import sys
-     monkeypatch.setitem(sys.modules, 'submitit', DummyModule)
-
-     jobs = [{"id": 1}]
-     # slurm_additional_parameters has constraint='foo' but explicit arg should override
-     submit_jobs(
-         jobs,
-         dummy_func,
-         timeout_min=1,
-         local_run=False,
-         slurm_additional_parameters={'constraint': 'foo'},
-         constraint='bar',
-         prompt=False,
-     )
-     slurm = events['update']['slurm_additional_parameters']
-     assert slurm['constraint'] == 'bar'