nshtrainer-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +64 -0
- nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer/_experimental/flops/__init__.py +48 -0
- nshtrainer/_experimental/flops/flop_counter.py +787 -0
- nshtrainer/_experimental/flops/module_tracker.py +140 -0
- nshtrainer/_snoop.py +216 -0
- nshtrainer/_submit/print_environment_info.py +31 -0
- nshtrainer/_submit/session/_output.py +12 -0
- nshtrainer/_submit/session/_script.py +109 -0
- nshtrainer/_submit/session/lsf.py +467 -0
- nshtrainer/_submit/session/slurm.py +573 -0
- nshtrainer/_submit/session/unified.py +350 -0
- nshtrainer/actsave/__init__.py +7 -0
- nshtrainer/actsave/_callback.py +75 -0
- nshtrainer/actsave/_loader.py +144 -0
- nshtrainer/actsave/_saver.py +337 -0
- nshtrainer/callbacks/__init__.py +35 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
- nshtrainer/callbacks/base.py +113 -0
- nshtrainer/callbacks/early_stopping.py +112 -0
- nshtrainer/callbacks/ema.py +383 -0
- nshtrainer/callbacks/finite_checks.py +75 -0
- nshtrainer/callbacks/gradient_skipping.py +103 -0
- nshtrainer/callbacks/interval.py +322 -0
- nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
- nshtrainer/callbacks/log_epoch.py +35 -0
- nshtrainer/callbacks/norm_logging.py +187 -0
- nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
- nshtrainer/callbacks/print_table.py +90 -0
- nshtrainer/callbacks/throughput_monitor.py +56 -0
- nshtrainer/callbacks/timer.py +157 -0
- nshtrainer/callbacks/wandb_watch.py +103 -0
- nshtrainer/config.py +289 -0
- nshtrainer/data/__init__.py +4 -0
- nshtrainer/data/balanced_batch_sampler.py +132 -0
- nshtrainer/data/transform.py +67 -0
- nshtrainer/lr_scheduler/__init__.py +18 -0
- nshtrainer/lr_scheduler/_base.py +101 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
- nshtrainer/model/__init__.py +44 -0
- nshtrainer/model/base.py +641 -0
- nshtrainer/model/config.py +2064 -0
- nshtrainer/model/modules/callback.py +157 -0
- nshtrainer/model/modules/debug.py +42 -0
- nshtrainer/model/modules/distributed.py +70 -0
- nshtrainer/model/modules/logger.py +170 -0
- nshtrainer/model/modules/profiler.py +24 -0
- nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
- nshtrainer/model/modules/shared_parameters.py +72 -0
- nshtrainer/nn/__init__.py +19 -0
- nshtrainer/nn/mlp.py +106 -0
- nshtrainer/nn/module_dict.py +66 -0
- nshtrainer/nn/module_list.py +50 -0
- nshtrainer/nn/nonlinearity.py +157 -0
- nshtrainer/optimizer.py +62 -0
- nshtrainer/runner.py +21 -0
- nshtrainer/scripts/check_env.py +41 -0
- nshtrainer/scripts/find_packages.py +51 -0
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/signal_connector.py +208 -0
- nshtrainer/trainer/trainer.py +340 -0
- nshtrainer/typecheck.py +144 -0
- nshtrainer/util/environment.py +119 -0
- nshtrainer/util/seed.py +11 -0
- nshtrainer/util/singleton.py +89 -0
- nshtrainer/util/slurm.py +49 -0
- nshtrainer/util/typed.py +2 -0
- nshtrainer/util/typing_utils.py +19 -0
- nshtrainer-0.1.0.dist-info/METADATA +18 -0
- nshtrainer-0.1.0.dist-info/RECORD +72 -0
- nshtrainer-0.1.0.dist-info/WHEEL +4 -0
nshtrainer/_experimental/flops/module_tracker.py
ADDED
@@ -0,0 +1,140 @@
+import weakref
+from typing import Set
+
+import torch
+from torch.autograd.graph import register_multi_grad_hook
+from torch.nn.modules.module import (
+    register_module_forward_hook,
+    register_module_forward_pre_hook,
+)
+from torch.utils._pytree import tree_flatten
+
+__all__ = ["ModuleTracker"]
+
+
+class ModuleTracker:
+    """
+    ``ModuleTracker`` is a context manager that tracks the nn.Module hierarchy during execution
+    so that other system can query which Module is currently being executed (or its backward is being
+    executed).
+
+    You can access the ``parents`` attribute on this context manager to get the set of all the
+    Modules currently being executed via their fqn (fully qualified name, also used as the key within
+    the state_dict).
+    You can access the ``is_bw`` attribute to know if you are currently running in backward or not.
+
+    Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag
+    will remain ``True`` after the forward until another Module is executed. If you need it to be
+    more accurate, please submit an issue requesting this. Adding a map from fqn to the module instance
+    is possible but not done yet, please submit an issue requesting this if you need it.
+
+    Example usage
+
+    .. code-block:: python
+
+        mod = torch.nn.Linear(2, 2)
+
+        with ModuleTracker() as tracker:
+            # Access anything during the forward pass
+            def my_linear(m1, m2, bias):
+                print(f"Current modules: {tracker.parents}")
+                return torch.mm(m1, m2.t()) + bias
+            torch.nn.functional.linear = my_linear
+
+            mod(torch.rand(2, 2))
+
+    """
+
+    parents: Set[str]
+    """
+    A Set containing the fqn for each module currently running their forward
+    """
+
+    def __init__(self):
+        self.parents = {"Global"}
+        self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
+        self._seen_modules: weakref.WeakSet = weakref.WeakSet()
+        self._has_callback = False
+
+    def _maybe_set_engine_callback(self):
+        # This assumes no concurrent calls to backward
+        if self._has_callback:
+            return
+
+        def callback():
+            self.parents = {"Global"}
+            self._has_callback = False
+
+        torch.autograd.Variable._execution_engine.queue_callback(callback)
+        self._has_callback = True
+
+    @property
+    def is_bw(self):
+        """
+        A boolean marking if this is currently running during the backward pass or not
+        """
+        return torch._C._current_graph_task_id() != -1
+
+    def _get_mod_name(self, mod):
+        if mod not in self._known_modules:
+            self._known_modules[mod] = type(mod).__name__
+        mod_name = self._known_modules[mod]
+        if mod not in self._seen_modules:
+            for name, submod in mod.named_children():
+                self._known_modules[submod] = f"{mod_name}.{name}"
+                self._get_mod_name(submod)
+            self._seen_modules.add(mod)
+        return mod_name
+
+    def _get_append_fn(self, name, is_bw):
+        def fn(*args):
+            if is_bw:
+                self._maybe_set_engine_callback()
+            if name in self.parents:
+                print(
+                    "The module hierarchy tracking seems to be messed up."
+                    "Please file a bug to PyTorch."
+                )
+            self.parents.add(name)
+
+        return fn
+
+    def _get_pop_fn(self, name, is_bw):
+        def fn(*args):
+            if name in self.parents:
+                self.parents.remove(name)
+            elif not is_bw:
+                # Due to some input/output not requiring gradients, we cannot enforce
+                # proper nesting in backward
+                raise RuntimeError(
+                    "The Module hierarchy tracking is wrong. Report a bug to PyTorch"
+                )
+
+        return fn
+
+    def _fw_pre_hook(self, mod, input):
+        name = self._get_mod_name(mod)
+        self._get_append_fn(name, False)()
+
+        args, _ = tree_flatten(input)
+        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
+        if tensors:
+            register_multi_grad_hook(tensors, self._get_pop_fn(name, True))
+
+    def _fw_post_hook(self, mod, input, output):
+        name = self._get_mod_name(mod)
+        self._get_pop_fn(name, False)()
+
+        args, _ = tree_flatten(output)
+        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
+        if tensors:
+            register_multi_grad_hook(tensors, self._get_append_fn(name, True))
+
+    def __enter__(self):
+        self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
+        self._fw_post_handle = register_module_forward_hook(self._fw_post_hook)
+        return self
+
+    def __exit__(self, *args):
+        self._fw_pre_handle.remove()
+        self._fw_post_handle.remove()
nshtrainer/_snoop.py
ADDED
@@ -0,0 +1,216 @@
+import contextlib
+from typing import Any, Protocol, cast
+
+from typing_extensions import TypeVar
+
+T = TypeVar("T", infer_variance=True)
+
+
+class SnoopConstructor(Protocol):
+    def __call__(self, *args, **kwargs) -> contextlib.AbstractContextManager: ...
+
+    def disable(self) -> contextlib.AbstractContextManager: ...
+
+
+try:
+    import warnings
+    from contextlib import nullcontext
+
+    import lovely_numpy as lo
+    import lovely_tensors as lt
+    import numpy
+    import pysnooper
+    import pysnooper.utils
+    import torch
+    from pkg_resources import DistributionNotFound, get_distribution
+
+    FLOATING_POINTS = set()
+    for i in ["float", "double", "half", "complex128", "complex32", "complex64"]:
+        if hasattr(torch, i):  # older version of PyTorch do not have complex dtypes
+            FLOATING_POINTS.add(getattr(torch, i))
+
+    try:
+        __version__ = get_distribution(__name__).version
+    except DistributionNotFound:
+        # package is not installed
+        pass
+
+    def default_format(x):
+        try:
+            formatted = str(lt.lovely(x))
+            return formatted
+        except BaseException:
+            return str(x.shape)
+
+    def default_numpy_format(x):
+        return str(lo.lovely(x))
+
+    class TorchSnooper(pysnooper.tracer.Tracer):
+        def __init__(
+            self,
+            *args,
+            tensor_format=default_format,
+            numpy_format=default_numpy_format,
+            **kwargs,
+        ):
+            self.orig_custom_repr = (
+                kwargs["custom_repr"] if "custom_repr" in kwargs else ()
+            )
+            custom_repr = (lambda x: True, self.compute_repr)
+            kwargs["custom_repr"] = (custom_repr,)
+            super(TorchSnooper, self).__init__(*args, **kwargs)
+            self.tensor_format = tensor_format
+            self.numpy_format = numpy_format
+
+        @staticmethod
+        def is_return_types(x):
+            return type(x).__module__ == "torch.return_types"
+
+        def return_types_repr(self, x):
+            if type(x).__name__ in {
+                "max",
+                "min",
+                "median",
+                "mode",
+                "sort",
+                "topk",
+                "kthvalue",
+            }:
+                return (
+                    type(x).__name__
+                    + "(values="
+                    + self.tensor_format(x.values)
+                    + ", indices="
+                    + self.tensor_format(x.indices)
+                    + ")"
+                )
+            if type(x).__name__ == "svd":
+                return (
+                    "svd(U="
+                    + self.tensor_format(x.U)
+                    + ", S="
+                    + self.tensor_format(x.S)
+                    + ", V="
+                    + self.tensor_format(x.V)
+                    + ")"
+                )
+            if type(x).__name__ == "slogdet":
+                return (
+                    "slogdet(sign="
+                    + self.tensor_format(x.sign)
+                    + ", logabsdet="
+                    + self.tensor_format(x.logabsdet)
+                    + ")"
+                )
+            if type(x).__name__ == "qr":
+                return (
+                    "qr(Q="
+                    + self.tensor_format(x.Q)
+                    + ", R="
+                    + self.tensor_format(x.R)
+                    + ")"
+                )
+            if type(x).__name__ == "solve":
+                return (
+                    "solve(solution="
+                    + self.tensor_format(x.solution)
+                    + ", LU="
+                    + self.tensor_format(x.LU)
+                    + ")"
+                )
+            if type(x).__name__ == "geqrf":
+                return (
+                    "geqrf(a="
+                    + self.tensor_format(x.a)
+                    + ", tau="
+                    + self.tensor_format(x.tau)
+                    + ")"
+                )
+            if type(x).__name__ in {"symeig", "eig"}:
+                return (
+                    type(x).__name__
+                    + "(eigenvalues="
+                    + self.tensor_format(x.eigenvalues)
+                    + ", eigenvectors="
+                    + self.tensor_format(x.eigenvectors)
+                    + ")"
+                )
+            if type(x).__name__ == "triangular_solve":
+                return (
+                    "triangular_solve(solution="
+                    + self.tensor_format(x.solution)
+                    + ", cloned_coefficient="
+                    + self.tensor_format(x.cloned_coefficient)
+                    + ")"
+                )
+            if type(x).__name__ == "gels":
+                return (
+                    "gels(solution="
+                    + self.tensor_format(x.solution)
+                    + ", QR="
+                    + self.tensor_format(x.QR)
+                    + ")"
+                )
+            warnings.warn("Unknown return_types encountered, open a bug report!")
+
+        def compute_repr(self, x):
+            orig_repr_func = pysnooper.utils.get_repr_function(x, self.orig_custom_repr)
+            if torch.is_tensor(x):
+                return self.tensor_format(x)
+            if isinstance(x, numpy.ndarray):
+                return self.numpy_format(x)
+            if self.is_return_types(x):
+                return self.return_types_repr(x)
+            if orig_repr_func is not repr:
+                return orig_repr_func(x)
+            if isinstance(x, (list, tuple)):
+                content = ""
+                for i in x:
+                    if content != "":
+                        content += ", "
+                    content += self.compute_repr(i)
+                if isinstance(x, tuple) and len(x) == 1:
+                    content += ","
+                if isinstance(x, tuple):
+                    return "(" + content + ")"
+                return "[" + content + "]"
+            if isinstance(x, dict):
+                content = ""
+                for k, v in x.items():
+                    if content != "":
+                        content += ", "
+                    content += self.compute_repr(k) + ": " + self.compute_repr(v)
+                return "{" + content + "}"
+            return repr(x)
+
+    class _Snoop:
+        disable = nullcontext
+        __call__ = TorchSnooper
+
+    snoop: SnoopConstructor = cast(Any, _Snoop())
+
+except ImportError:
+    import warnings
+    from contextlib import nullcontext
+
+    from typing_extensions import override
+
+    _has_warned = False
+
+    class _snoop_cls(nullcontext):
+        @classmethod
+        def disable(cls):
+            return nullcontext()
+
+        @override
+        def __enter__(self):
+            global _has_warned
+            if not _has_warned:
+                warnings.warn(
+                    "snoop is not installed, please install it to enable snoop"
+                )
+                _has_warned = True
+
+            return super().__enter__()
+
+    snoop: SnoopConstructor = cast(Any, _snoop_cls)
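
A hedged usage sketch for the snoop constructor defined above (not part of the wheel): with pysnooper, lovely-tensors, and lovely-numpy installed it traces executed lines and renders tensors compactly; without them it degrades to the warning-only no-op from the except ImportError branch.

# Hypothetical usage sketch (not part of this package diff).
import torch

from nshtrainer._snoop import snoop


def project(x: torch.Tensor) -> torch.Tensor:
    w = torch.randn(x.shape[-1], 4)
    return x @ w


with snoop():  # line-by-line trace; tensors are summarized rather than fully printed
    y = project(torch.randn(2, 8))

with snoop.disable():  # explicit no-op context for silencing a block
    y = project(torch.randn(2, 8))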
nshtrainer/_submit/print_environment_info.py
ADDED
@@ -0,0 +1,31 @@
+import logging
+import os
+import sys
+
+
+def print_environment_info(log: logging.Logger | None = None):
+    if log is None:
+        logging.basicConfig(level=logging.INFO)
+        log = logging.getLogger(__name__)
+
+    log_message_lines: list[str] = []
+    log_message_lines.append("Python executable: " + sys.executable)
+    log_message_lines.append("Python version: " + sys.version)
+    log_message_lines.append("Python prefix: " + sys.prefix)
+    log_message_lines.append("Python path:")
+    for path in sys.path:
+        log_message_lines.append(f"  {path}")
+
+    log_message_lines.append("Environment variables:")
+    for key, value in os.environ.items():
+        log_message_lines.append(f"  {key}={value}")
+
+    log_message_lines.append("Command line arguments:")
+    for i, arg in enumerate(sys.argv):
+        log_message_lines.append(f"  {i}: {arg}")
+
+    log.critical("\n".join(log_message_lines))
+
+
+if __name__ == "__main__":
+    print_environment_info()
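
For reference, the module above can be run directly or handed an existing logger; a small, hypothetical example (the logger name is illustrative):

# Hypothetical usage sketch (not part of this package diff).
import logging

from nshtrainer._submit.print_environment_info import print_environment_info

logging.basicConfig(level=logging.INFO)
print_environment_info(logging.getLogger("env-diagnostics"))
# Or, from a shell (assuming the package is importable):
#   python -m nshtrainer._submit.print_environment_info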
nshtrainer/_submit/session/_script.py
ADDED
@@ -0,0 +1,109 @@
+from collections.abc import Iterable, Mapping, Sequence
+from pathlib import Path
+
+
+def _create_launcher_script_file(
+    script_path: Path,
+    original_command: str | Iterable[str],
+    environment: Mapping[str, str],
+    setup_commands: Sequence[str],
+    chmod: bool = True,
+    prepend_command_with_exec: bool = True,
+    command_prefix: str | None = None,
+    # ^ If True, the original command will be prepended with 'exec' to replace the shell process
+    # with the command. This is useful for ensuring that the command is the only process in the
+    # process tree (e.g. for better signal handling).
+):
+    """
+    Creates a helper bash script for running the given function.
+
+    The core idea: The helper script is essentially one additional layer of indirection
+    that allows us to encapsulates the environment setup and the actual function call
+    in a single bash script (that does not require properly set up Python environment).
+
+    In effect, this allows us to, for example:
+    - Easily run the function in the correct environment
+        (without having to deal with shell hooks)
+        using `conda run -n myenv bash /path/to/helper.sh`.
+    - Easily run the function in a Singularity container
+        using `singularity exec my_container.sif bash /path/to/helper.sh`.
+    """
+    with script_path.open("w") as f:
+        f.write("#!/bin/bash\n\n")
+        f.write("set -e\n\n")
+
+        if environment:
+            for key, value in environment.items():
+                f.write(f"export {key}={value}\n")
+            f.write("\n")
+
+        if setup_commands:
+            for setup_command in setup_commands:
+                f.write(f"{setup_command}\n")
+            f.write("\n")
+
+        if not isinstance(original_command, str):
+            original_command = " ".join(original_command)
+
+        if command_prefix:
+            original_command = f"{command_prefix} {original_command}"
+
+        if prepend_command_with_exec:
+            original_command = f"exec {original_command}"
+        f.write(f"{original_command}\n")
+
+    if chmod:
+        # Make the script executable
+        script_path.chmod(0o755)
+
+
+def write_helper_script(
+    base_dir: Path,
+    command: str | Iterable[str],
+    environment: Mapping[str, str],
+    setup_commands: Sequence[str],
+    chmod: bool = True,
+    prepend_command_with_exec: bool = True,
+    command_prefix: str | None = None,
+    file_name: str = "helper.sh",
+):
+    """
+    Creates a helper bash script for running the given function.
+
+    The core idea: The helper script is essentially one additional layer of indirection
+    that allows us to encapsulates the environment setup and the actual function call
+    in a single bash script (that does not require properly set up Python environment).
+
+    In effect, this allows us to, for example:
+    - Easily run the function in the correct environment
+        (without having to deal with shell hooks)
+        using `conda run -n myenv bash /path/to/helper.sh`.
+    - Easily run the function in a Singularity container
+        using `singularity exec my_container.sif bash /path/to/helper.sh`.
+    """
+
+    out_path = base_dir / file_name
+    _create_launcher_script_file(
+        out_path,
+        command,
+        environment,
+        setup_commands,
+        chmod,
+        prepend_command_with_exec,
+        command_prefix,
+    )
+    return out_path
+
+
+DEFAULT_TEMPLATE = "bash {script}"
+
+
+def helper_script_to_command(script: Path, template: str | None) -> str:
+    if not template:
+        template = DEFAULT_TEMPLATE
+
+    # Make sure the template has '{script}' in it
+    if "{script}" not in template:
+        raise ValueError(f"Template must contain '{{script}}'. Got: {template!r}")
+
+    return template.format(script=str(script.absolute()))
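
To make the indirection concrete, a hedged sketch of how these helpers compose (paths, commands, and the setup step are illustrative, not taken from the package):

# Hypothetical usage sketch (not part of this package diff).
from pathlib import Path

from nshtrainer._submit.session._script import (
    helper_script_to_command,
    write_helper_script,
)

script = write_helper_script(
    base_dir=Path("/tmp/run"),  # illustrative directory; must already exist
    command=["python", "train.py", "--config", "cfg.yaml"],
    environment={"CUDA_VISIBLE_DEVICES": "0"},
    setup_commands=["module load cuda/12.1"],  # illustrative setup step
)

# Wrap the helper script in a launcher template, e.g. for a Singularity container:
print(helper_script_to_command(script, "singularity exec train.sif bash {script}"))
# -> "singularity exec train.sif bash /tmp/run/helper.sh"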