nshtrainer 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,14 +1,13 @@
1
1
  from abc import ABC, abstractmethod
2
2
  from typing import Annotated, Literal
3
3
 
4
+ import nshconfig as C
4
5
  import torch
5
6
  import torch.nn as nn
6
7
  from typing_extensions import override
7
8
 
8
- from ..config import Field, TypedConfig
9
9
 
10
-
11
- class BaseNonlinearityConfig(TypedConfig, ABC):
10
+ class BaseNonlinearityConfig(C.Config, ABC):
12
11
  @abstractmethod
13
12
  def create_module(self) -> nn.Module:
14
13
  pass
@@ -153,5 +152,5 @@ NonlinearityConfig = Annotated[
153
152
  | SiLUNonlinearityConfig
154
153
  | MishNonlinearityConfig
155
154
  | SwiGLUNonlinearityConfig,
156
- Field(discriminator="name"),
155
+ C.Field(discriminator="name"),
157
156
  ]
nshtrainer/optimizer.py CHANGED
@@ -2,14 +2,13 @@ from abc import ABC, abstractmethod
2
2
  from collections.abc import Iterable
3
3
  from typing import Annotated, Any, Literal, TypeAlias
4
4
 
5
+ import nshconfig as C
5
6
  import torch.nn as nn
6
7
  from torch.optim import Optimizer
7
8
  from typing_extensions import override
8
9
 
9
- from .config import Field, TypedConfig
10
10
 
11
-
12
- class OptimizerConfigBase(TypedConfig, ABC):
11
+ class OptimizerConfigBase(C.Config, ABC):
13
12
  @abstractmethod
14
13
  def create_optimizer(
15
14
  self,
@@ -56,7 +55,4 @@ class AdamWConfig(OptimizerConfigBase):
56
55
  )
57
56
 
58
57
 
59
- OptimizerConfig: TypeAlias = Annotated[
60
- AdamWConfig,
61
- Field(discriminator="name"),
62
- ]
58
+ OptimizerConfig: TypeAlias = Annotated[AdamWConfig, C.Field(discriminator="name")]
nshtrainer/runner.py CHANGED
@@ -1,8 +1,8 @@
1
- from dataclasses import dataclass
2
1
  from typing import Generic
3
2
 
3
+ from nshrunner import RunInfo
4
4
  from nshrunner import Runner as _Runner
5
- from typing_extensions import Concatenate, TypeVar, TypeVarTuple, Unpack, override
5
+ from typing_extensions import TypeVar, TypeVarTuple, Unpack, override
6
6
 
7
7
  from .model.config import BaseConfig
8
8
 
@@ -11,11 +11,21 @@ TArguments = TypeVarTuple("TArguments")
11
11
  TReturn = TypeVar("TReturn", infer_variance=True)
12
12
 
13
13
 
14
- @dataclass(frozen=True)
15
14
  class Runner(
16
15
  _Runner[Unpack[tuple[TConfig, Unpack[TArguments]]], TReturn],
17
16
  Generic[TConfig, Unpack[TArguments], TReturn],
18
17
  ):
19
18
  @override
20
- def default_validate_fn():
21
- pass
19
+ @classmethod
20
+ def default_validate_fn(cls, config: TConfig, *args: Unpack[TArguments]) -> None:
21
+ super().default_validate_fn(config, *args)
22
+
23
+ @override
24
+ @classmethod
25
+ def default_info_fn(cls, config: TConfig, *args: Unpack[TArguments]) -> RunInfo:
26
+ run_info = super().default_info_fn(config, *args)
27
+ return {
28
+ **run_info,
29
+ "id": config.id,
30
+ "base_dir": config.directory.project_root,
31
+ }
@@ -25,14 +25,21 @@ _SIGNUM = int | signal.Signals
25
25
  _HANDLER: TypeAlias = Callable[[_SIGNUM, FrameType], Any] | int | signal.Handlers | None
26
26
 
27
27
 
28
- class _SignalConnector(_LightningSignalConnector):
29
- def _auto_requeue_signals(self) -> list[signal.Signals]:
30
- from ..model.base import BaseConfig
28
+ def _resolve_requeue_signals():
29
+ signals: list[signal.Signals] = []
30
+
31
+ if timeout_signal_name := os.environ.get("NSHRUNNER_TIMEOUT_SIGNAL"):
32
+ signals.append(signal.Signals[timeout_signal_name])
31
33
 
32
- if not isinstance(config := self.trainer.lightning_module.hparams, BaseConfig):
33
- return []
34
+ if preempt_signal_name := os.environ.get("NSHRUNNER_PREEMPT_SIGNAL"):
35
+ signals.append(signal.Signals[preempt_signal_name])
34
36
 
35
- signals = config.runner.submit._resolved_auto_requeue_signals()
37
+ return signals
38
+
39
+
40
+ class _SignalConnector(_LightningSignalConnector):
41
+ def _auto_requeue_signals(self) -> list[signal.Signals]:
42
+ signals = _resolve_requeue_signals()
36
43
  signals_set = set(signals)
37
44
  valid_signals: set[signal.Signals] = signal.valid_signals()
38
45
  assert signals_set.issubset(
@@ -42,25 +49,29 @@ class _SignalConnector(_LightningSignalConnector):
42
49
 
43
50
  def _compose_and_register(
44
51
  self,
45
- signum: _SIGNUM,
52
+ signum: signal.Signals,
46
53
  handlers: list[_HANDLER],
47
54
  replace_existing: bool = False,
48
55
  ):
49
56
  if self._is_on_windows():
50
- log.info(f"Signal {signum} has no handlers or is not supported on Windows.")
57
+ log.info(
58
+ f"Signal {signum.name} has no handlers or is not supported on Windows."
59
+ )
51
60
  return
52
61
 
53
62
  if self._has_already_handler(signum):
54
63
  if not replace_existing:
55
64
  log.info(
56
- f"Signal {signum} already has a handler. Adding ours to the existing one."
65
+ f"Signal {signum.name} already has a handler. Adding ours to the existing one."
57
66
  )
58
67
  handlers.append(signal.getsignal(signum))
59
68
  else:
60
- log.info(f"Replacing existing handler for signal {signum} with ours.")
69
+ log.info(
70
+ f"Replacing existing handler for signal {signum.name} with ours."
71
+ )
61
72
 
62
73
  self._register_signal(signum, _HandlersCompose(handlers))
63
- log.info(f"Registered {len(handlers)} handlers for signal {signum}.")
74
+ log.info(f"Registered {len(handlers)} handlers for signal {signum.name}.")
64
75
 
65
76
  @override
66
77
  def register_signal_handlers(self) -> None:
@@ -31,7 +31,7 @@ log = logging.getLogger(__name__)
31
31
 
32
32
  def _is_bf16_supported_no_emulation():
33
33
  r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
34
- version = cast(Any, torch.version)
34
+ version = getattr(torch, "version")
35
35
 
36
36
  # Check for ROCm, if true return true, no ROCM_VERSION check required,
37
37
  # since it is supported on AMD GPU archs.
nshtrainer/typecheck.py CHANGED
@@ -82,6 +82,7 @@ def typecheck_this_module(additional_modules: Sequence[str] = ()):
82
82
  frame = get_frame(1)
83
83
  assert frame is not None, "frame is None"
84
84
  calling_module_name = get_frame_package_name(frame)
85
+ assert calling_module_name is not None, "calling_module_name is None"
85
86
 
86
87
  # Typecheck the calling module + any additional modules.
87
88
  typecheck_modules((calling_module_name, *additional_modules))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -9,10 +9,21 @@ Classifier: Programming Language :: Python :: 3
9
9
  Classifier: Programming Language :: Python :: 3.10
10
10
  Classifier: Programming Language :: Python :: 3.11
11
11
  Classifier: Programming Language :: Python :: 3.12
12
+ Requires-Dist: beartype (>=0.18.5,<0.19.0)
13
+ Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
14
+ Requires-Dist: lightning
15
+ Requires-Dist: lovely-numpy (>=0.2.13,<0.3.0)
16
+ Requires-Dist: lovely-tensors (>=0.1.16,<0.2.0)
12
17
  Requires-Dist: nshconfig (>=0.2.0,<0.3.0)
13
- Requires-Dist: nshrunner (>=0.1.0,<0.2.0)
18
+ Requires-Dist: nshrunner (>=0.5.3,<0.6.0)
19
+ Requires-Dist: numpy
20
+ Requires-Dist: pysnooper
21
+ Requires-Dist: pytorch-lightning
22
+ Requires-Dist: rich
14
23
  Requires-Dist: torch
24
+ Requires-Dist: torchmetrics
15
25
  Requires-Dist: typing-extensions
26
+ Requires-Dist: wrapt
16
27
  Description-Content-Type: text/markdown
17
28
 
18
29
 
@@ -1,22 +1,16 @@
1
- nshtrainer/__init__.py,sha256=o39TbnjwUYzE4POcncUiDx02Ey-Hzx8UGuwJDjMcKZU,2971
1
+ nshtrainer/__init__.py,sha256=OHbxLxVvFGW--ecuIGqkoylSVHFS4x4F1-oeuENH-Do,2212
2
2
  nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
3
3
  nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
4
4
  nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
5
5
  nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
6
6
  nshtrainer/_snoop.py,sha256=Rofv1Rd92E0LY40G3A-o9Hu0ZI73RR59wJD5l4Q3PDM,7022
7
- nshtrainer/_submit/print_environment_info.py,sha256=enbJGl_iHIlhKN8avzKnoZSb0zUQ_fUdnsQ8a_9tbYk,963
8
- nshtrainer/_submit/session/_output.py,sha256=CNGH5W6_XxAC5-TRvMAMxOHd3fjGpJhK-7RGTDyvMu4,245
9
- nshtrainer/_submit/session/_script.py,sha256=0AeBgBduDsoIEBrY9kebARiBUEGc50JAD9oE_IDiLnA,3775
10
- nshtrainer/_submit/session/lsf.py,sha256=p19EP6OhROZxqfRhzeTD7GDmfYaREIKMXMOI8G933FE,14307
11
- nshtrainer/_submit/session/slurm.py,sha256=JpAjQvck4LjGN8o8fOvIeMuFqrg1cioANoVsX5hU-3g,17594
12
- nshtrainer/_submit/session/unified.py,sha256=gfh-AtnMyFHzcQOUlhlAR__vaWDk1r9XCivz_t_lHKk,11695
13
7
  nshtrainer/actsave/__init__.py,sha256=G1T-fELuGWkVqdhdyoePtj2dTOUtcIOW4VgsXv9JNTA,338
14
8
  nshtrainer/actsave/_callback.py,sha256=QoTa60F70f1RxB41VKixN9l5_htfFQxXDPHHSNFreuk,2770
15
9
  nshtrainer/actsave/_loader.py,sha256=fAhD32DrJa4onkYfcwc21YIeGEYzOSXCK_HVo9SZLgQ,4604
16
10
  nshtrainer/actsave/_saver.py,sha256=0EHmQDhqVxQWRWWSyt03eP1K9ETiACMQYmsZkDMt6HY,9451
17
- nshtrainer/callbacks/__init__.py,sha256=ohE_MO_kX1o4SZwcipIXUA9m7XYcijEKJtGcoU8dTkY,1667
11
+ nshtrainer/callbacks/__init__.py,sha256=I6W33ityL9Ko8jjqHh3WH_8miV59SAe9LxInhoqX5XE,1665
18
12
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
19
- nshtrainer/callbacks/base.py,sha256=WESZz1VSTl1xSGVXBmxFqWwbLxXcJp97jpg9zrE0EsY,3560
13
+ nshtrainer/callbacks/base.py,sha256=LrcRUV02bZEKXRIRvhHT9qsvw_kwoWiAdQkVMyKc5NU,3542
20
14
  nshtrainer/callbacks/early_stopping.py,sha256=jriSU761wf_qTJ9Bos0D3h5aDvZHYpRqK62Ne8aWp5I,3768
21
15
  nshtrainer/callbacks/ema.py,sha256=zKCtvzZFo0ORlwNZHjaMk-sJoxrlTtFWOzR-yGy95W0,12134
22
16
  nshtrainer/callbacks/finite_checks.py,sha256=kX3TIJsxyqx0GuLJfYsqVgKU27zwjG9Z8324lyCFtwM,2087
@@ -30,17 +24,16 @@ nshtrainer/callbacks/print_table.py,sha256=FcA-CBWwMf9c1NNRinvYpZC400RNQxuP28bJf
30
24
  nshtrainer/callbacks/throughput_monitor.py,sha256=YQLdpX3LGybIiD814yT9yCCVSEXRWf8WwsvVaN5aDBE,1848
31
25
  nshtrainer/callbacks/timer.py,sha256=sDXPPcdDKu5xnuK_bjr8plIq9MBuluNJ42Mt9LvPZzc,4610
32
26
  nshtrainer/callbacks/wandb_watch.py,sha256=pUpMsNxd03ex1rzOmFw2HzGOXjnQGaH84m8cc2dXo4g,2937
33
- nshtrainer/config.py,sha256=0Fj5w-ry0BRl2_zJI6jwCnmMWE3p_eD8_Wn-NyFkTqU,10442
34
27
  nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
35
28
  nshtrainer/data/balanced_batch_sampler.py,sha256=bcJBcQjh1hB1yKF_xSlT9AtEWv0BJjYc1CuH2BF-ea8,4392
36
29
  nshtrainer/data/transform.py,sha256=JeGxvytQly8hougrsdMmKG8gJ6qvFPDglJCO4Tp6STk,1795
37
- nshtrainer/lr_scheduler/__init__.py,sha256=GNGmkcJD3jgCMk7pfaanAYrKz9957qkx6_Q0rssiHK0,738
38
- nshtrainer/lr_scheduler/_base.py,sha256=1tWMABevKZAuGhJN8Me2E9eqEyqoLtsG0bADPjED7a4,3752
39
- nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=VhsxZJ_Mw9zjkAGunFQ1KRub5_QM5NRqaEFWtmedFp8,5212
40
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=Ct-uLo8Q4t7lJ_HwoLRhNmudnCw4cSnblpBEg22aVTI,2691
41
- nshtrainer/model/__init__.py,sha256=PdvZkpAVkqvCLipGJvEHFU3WxnSMxYpvtuOkvLIenxg,2078
42
- nshtrainer/model/base.py,sha256=bhngGHxr0suQB9Ezi_3d5JgDWYqS_yPgGJZrGmc1TnI,23571
43
- nshtrainer/model/config.py,sha256=RMDdrbtvwm5vTFPxQ2x1hqiBIEEE-OAknhF6KTWfkkk,70293
30
+ nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzqutESHE,736
31
+ nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
32
+ nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=mn6cyizyI_stkXtg6zxIEGF9btIxMRWigUHUTlUYCSw,5221
33
+ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=2ZdlV0RUMwg2DClzqYHr8_EKT1jZBUlSD39e-XlCsC4,2764
34
+ nshtrainer/model/__init__.py,sha256=y32Hla-5whpzLL2BtCJpBakSp8o-1nQbpO0j_-xq_Po,1864
35
+ nshtrainer/model/base.py,sha256=EMkOtp4YWGPHM0HPSTLbx75T9vlYmXO4XyD725xU70w,21453
36
+ nshtrainer/model/config.py,sha256=6lATW6-Z1SIDgQ1IWrGBVQKTr8DhL5b_rFbJHQz0d5o,66796
44
37
  nshtrainer/model/modules/callback.py,sha256=JF59U9-CjJsAIspEhTJbVaGN0wGctZG7UquE3IS7R8A,6408
45
38
  nshtrainer/model/modules/debug.py,sha256=DTVty8cKnzj1GCULRyGx_sWTTsq9NLi30dzqjRTnuCU,1127
46
39
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -52,21 +45,20 @@ nshtrainer/nn/__init__.py,sha256=57LPaP3G-BBGD2eGxbBUABNgYl3s_oASwrtOSS4bzTs,133
52
45
  nshtrainer/nn/mlp.py,sha256=i-dHk0tomO_XlU6cKN4CC4HxTaYb-ukBCAgY1ySXl4I,3963
53
46
  nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,2329
54
47
  nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
55
- nshtrainer/nn/nonlinearity.py,sha256=IhIR8NCTY3Np9dMDnUouERR8ZhWpK3S0hTbT0i8HezU,3645
56
- nshtrainer/optimizer.py,sha256=JiLNRtcfYxyhAab1Z1QcEzmrX9S_JyrBS67TXy12kXI,1557
57
- nshtrainer/runner.py,sha256=9HsYB58aasY9RVvya_gPECDs_MBhM1fl4cbM3iJYTDc,600
48
+ nshtrainer/nn/nonlinearity.py,sha256=owtU4kh4G98psD0axOJWVfBhm-OtJVgFM-TXSHmbNPU,3625
49
+ nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
50
+ nshtrainer/runner.py,sha256=af_EGnQTSvUgwnVhhytvY3V7o_Xg-xx-sLb8K2Szb1E,979
58
51
  nshtrainer/scripts/check_env.py,sha256=IMl6dSqsLYppI0XuCsVq8lK4bYqXwY9KHJkzsShz4Kg,806
59
52
  nshtrainer/scripts/find_packages.py,sha256=FbdlfmAefttFSMfaT0A46a-oHLP_ioaQKihwBfBeWeA,1467
60
53
  nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
61
- nshtrainer/trainer/signal_connector.py,sha256=aGg6kRiHiqtAdGlEvEvGLmOy7AvRHTSkXdTmZpRXbjU,8435
62
- nshtrainer/trainer/trainer.py,sha256=oi8KdHF1AdZ54KFbCFAEI7W-C7qRtRe-KtOjNwBuS3M,14033
63
- nshtrainer/typecheck.py,sha256=CFkmPIxCU24CHk_7_pykb-Y1PRNhpLgsVZw1zuuOS_U,4614
54
+ nshtrainer/trainer/signal_connector.py,sha256=QAoPM_C5JJOVQebcrJOimUUD3GHyoeZUqCEAvzZlT4U,8710
55
+ nshtrainer/trainer/trainer.py,sha256=eYEYfY9v70MuorHcSf8nqM7f2CkmUHhpPcjCk4FJD7k,14034
56
+ nshtrainer/typecheck.py,sha256=RGYHxDBcs97E6ayl6Olc43JBZXQolCtMxcLBniVCVBg,4688
64
57
  nshtrainer/util/environment.py,sha256=_SEtiQ_s5bL5pllUlf96AOUv15kNvCPvocVC13S7mIk,4166
65
58
  nshtrainer/util/seed.py,sha256=HEXgVs-wldByahOysKwq7506OHxdYTEgmP-tDQVAEkQ,287
66
- nshtrainer/util/singleton.py,sha256=nLhpuMZxl0zdNsnvS97o4ASUnKzCWYEKLzR_j9oP_xs,2208
67
59
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
68
60
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
69
61
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
70
- nshtrainer-0.1.0.dist-info/METADATA,sha256=3zdNPxyB-I6Gudq2gTaU0crdgmDCcGCp6Zudef0DtuM,529
71
- nshtrainer-0.1.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
72
- nshtrainer-0.1.0.dist-info/RECORD,,
62
+ nshtrainer-0.1.1.dist-info/METADATA,sha256=32iVLvdJh6OJQyD-_7NDO6IYqfHPSflDznYfYaCo8-c,882
63
+ nshtrainer-0.1.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
64
+ nshtrainer-0.1.1.dist-info/RECORD,,
@@ -1,31 +0,0 @@
1
- import logging
2
- import os
3
- import sys
4
-
5
-
6
- def print_environment_info(log: logging.Logger | None = None):
7
- if log is None:
8
- logging.basicConfig(level=logging.INFO)
9
- log = logging.getLogger(__name__)
10
-
11
- log_message_lines: list[str] = []
12
- log_message_lines.append("Python executable: " + sys.executable)
13
- log_message_lines.append("Python version: " + sys.version)
14
- log_message_lines.append("Python prefix: " + sys.prefix)
15
- log_message_lines.append("Python path:")
16
- for path in sys.path:
17
- log_message_lines.append(f" {path}")
18
-
19
- log_message_lines.append("Environment variables:")
20
- for key, value in os.environ.items():
21
- log_message_lines.append(f" {key}={value}")
22
-
23
- log_message_lines.append("Command line arguments:")
24
- for i, arg in enumerate(sys.argv):
25
- log_message_lines.append(f" {i}: {arg}")
26
-
27
- log.critical("\n".join(log_message_lines))
28
-
29
-
30
- if __name__ == "__main__":
31
- print_environment_info()
@@ -1,12 +0,0 @@
1
- from dataclasses import dataclass
2
- from pathlib import Path
3
-
4
-
5
- @dataclass(frozen=True)
6
- class SubmitOutput:
7
- command_parts: list[str]
8
- script_path: Path
9
-
10
- @property
11
- def command(self) -> str:
12
- return " ".join(self.command_parts)
@@ -1,109 +0,0 @@
1
- from collections.abc import Iterable, Mapping, Sequence
2
- from pathlib import Path
3
-
4
-
5
- def _create_launcher_script_file(
6
- script_path: Path,
7
- original_command: str | Iterable[str],
8
- environment: Mapping[str, str],
9
- setup_commands: Sequence[str],
10
- chmod: bool = True,
11
- prepend_command_with_exec: bool = True,
12
- command_prefix: str | None = None,
13
- # ^ If True, the original command will be prepended with 'exec' to replace the shell process
14
- # with the command. This is useful for ensuring that the command is the only process in the
15
- # process tree (e.g. for better signal handling).
16
- ):
17
- """
18
- Creates a helper bash script for running the given function.
19
-
20
- The core idea: The helper script is essentially one additional layer of indirection
21
- that allows us to encapsulates the environment setup and the actual function call
22
- in a single bash script (that does not require properly set up Python environment).
23
-
24
- In effect, this allows us to, for example:
25
- - Easily run the function in the correct environment
26
- (without having to deal with shell hooks)
27
- using `conda run -n myenv bash /path/to/helper.sh`.
28
- - Easily run the function in a Singularity container
29
- using `singularity exec my_container.sif bash /path/to/helper.sh`.
30
- """
31
- with script_path.open("w") as f:
32
- f.write("#!/bin/bash\n\n")
33
- f.write("set -e\n\n")
34
-
35
- if environment:
36
- for key, value in environment.items():
37
- f.write(f"export {key}={value}\n")
38
- f.write("\n")
39
-
40
- if setup_commands:
41
- for setup_command in setup_commands:
42
- f.write(f"{setup_command}\n")
43
- f.write("\n")
44
-
45
- if not isinstance(original_command, str):
46
- original_command = " ".join(original_command)
47
-
48
- if command_prefix:
49
- original_command = f"{command_prefix} {original_command}"
50
-
51
- if prepend_command_with_exec:
52
- original_command = f"exec {original_command}"
53
- f.write(f"{original_command}\n")
54
-
55
- if chmod:
56
- # Make the script executable
57
- script_path.chmod(0o755)
58
-
59
-
60
- def write_helper_script(
61
- base_dir: Path,
62
- command: str | Iterable[str],
63
- environment: Mapping[str, str],
64
- setup_commands: Sequence[str],
65
- chmod: bool = True,
66
- prepend_command_with_exec: bool = True,
67
- command_prefix: str | None = None,
68
- file_name: str = "helper.sh",
69
- ):
70
- """
71
- Creates a helper bash script for running the given function.
72
-
73
- The core idea: The helper script is essentially one additional layer of indirection
74
- that allows us to encapsulates the environment setup and the actual function call
75
- in a single bash script (that does not require properly set up Python environment).
76
-
77
- In effect, this allows us to, for example:
78
- - Easily run the function in the correct environment
79
- (without having to deal with shell hooks)
80
- using `conda run -n myenv bash /path/to/helper.sh`.
81
- - Easily run the function in a Singularity container
82
- using `singularity exec my_container.sif bash /path/to/helper.sh`.
83
- """
84
-
85
- out_path = base_dir / file_name
86
- _create_launcher_script_file(
87
- out_path,
88
- command,
89
- environment,
90
- setup_commands,
91
- chmod,
92
- prepend_command_with_exec,
93
- command_prefix,
94
- )
95
- return out_path
96
-
97
-
98
- DEFAULT_TEMPLATE = "bash {script}"
99
-
100
-
101
- def helper_script_to_command(script: Path, template: str | None) -> str:
102
- if not template:
103
- template = DEFAULT_TEMPLATE
104
-
105
- # Make sure the template has '{script}' in it
106
- if "{script}" not in template:
107
- raise ValueError(f"Template must contain '{{script}}'. Got: {template!r}")
108
-
109
- return template.format(script=str(script.absolute()))