nshtrainer-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. nshtrainer/__init__.py +64 -0
  2. nshtrainer/_experimental/__init__.py +2 -0
  3. nshtrainer/_experimental/flops/__init__.py +48 -0
  4. nshtrainer/_experimental/flops/flop_counter.py +787 -0
  5. nshtrainer/_experimental/flops/module_tracker.py +140 -0
  6. nshtrainer/_snoop.py +216 -0
  7. nshtrainer/_submit/print_environment_info.py +31 -0
  8. nshtrainer/_submit/session/_output.py +12 -0
  9. nshtrainer/_submit/session/_script.py +109 -0
  10. nshtrainer/_submit/session/lsf.py +467 -0
  11. nshtrainer/_submit/session/slurm.py +573 -0
  12. nshtrainer/_submit/session/unified.py +350 -0
  13. nshtrainer/actsave/__init__.py +7 -0
  14. nshtrainer/actsave/_callback.py +75 -0
  15. nshtrainer/actsave/_loader.py +144 -0
  16. nshtrainer/actsave/_saver.py +337 -0
  17. nshtrainer/callbacks/__init__.py +35 -0
  18. nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
  19. nshtrainer/callbacks/base.py +113 -0
  20. nshtrainer/callbacks/early_stopping.py +112 -0
  21. nshtrainer/callbacks/ema.py +383 -0
  22. nshtrainer/callbacks/finite_checks.py +75 -0
  23. nshtrainer/callbacks/gradient_skipping.py +103 -0
  24. nshtrainer/callbacks/interval.py +322 -0
  25. nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
  26. nshtrainer/callbacks/log_epoch.py +35 -0
  27. nshtrainer/callbacks/norm_logging.py +187 -0
  28. nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
  29. nshtrainer/callbacks/print_table.py +90 -0
  30. nshtrainer/callbacks/throughput_monitor.py +56 -0
  31. nshtrainer/callbacks/timer.py +157 -0
  32. nshtrainer/callbacks/wandb_watch.py +103 -0
  33. nshtrainer/config.py +289 -0
  34. nshtrainer/data/__init__.py +4 -0
  35. nshtrainer/data/balanced_batch_sampler.py +132 -0
  36. nshtrainer/data/transform.py +67 -0
  37. nshtrainer/lr_scheduler/__init__.py +18 -0
  38. nshtrainer/lr_scheduler/_base.py +101 -0
  39. nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
  40. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
  41. nshtrainer/model/__init__.py +44 -0
  42. nshtrainer/model/base.py +641 -0
  43. nshtrainer/model/config.py +2064 -0
  44. nshtrainer/model/modules/callback.py +157 -0
  45. nshtrainer/model/modules/debug.py +42 -0
  46. nshtrainer/model/modules/distributed.py +70 -0
  47. nshtrainer/model/modules/logger.py +170 -0
  48. nshtrainer/model/modules/profiler.py +24 -0
  49. nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
  50. nshtrainer/model/modules/shared_parameters.py +72 -0
  51. nshtrainer/nn/__init__.py +19 -0
  52. nshtrainer/nn/mlp.py +106 -0
  53. nshtrainer/nn/module_dict.py +66 -0
  54. nshtrainer/nn/module_list.py +50 -0
  55. nshtrainer/nn/nonlinearity.py +157 -0
  56. nshtrainer/optimizer.py +62 -0
  57. nshtrainer/runner.py +21 -0
  58. nshtrainer/scripts/check_env.py +41 -0
  59. nshtrainer/scripts/find_packages.py +51 -0
  60. nshtrainer/trainer/__init__.py +1 -0
  61. nshtrainer/trainer/signal_connector.py +208 -0
  62. nshtrainer/trainer/trainer.py +340 -0
  63. nshtrainer/typecheck.py +144 -0
  64. nshtrainer/util/environment.py +119 -0
  65. nshtrainer/util/seed.py +11 -0
  66. nshtrainer/util/singleton.py +89 -0
  67. nshtrainer/util/slurm.py +49 -0
  68. nshtrainer/util/typed.py +2 -0
  69. nshtrainer/util/typing_utils.py +19 -0
  70. nshtrainer-0.1.0.dist-info/METADATA +18 -0
  71. nshtrainer-0.1.0.dist-info/RECORD +72 -0
  72. nshtrainer-0.1.0.dist-info/WHEEL +4 -0
nshtrainer/_experimental/flops/module_tracker.py ADDED
@@ -0,0 +1,140 @@
+ import weakref
+ from typing import Set
+
+ import torch
+ from torch.autograd.graph import register_multi_grad_hook
+ from torch.nn.modules.module import (
+     register_module_forward_hook,
+     register_module_forward_pre_hook,
+ )
+ from torch.utils._pytree import tree_flatten
+
+ __all__ = ["ModuleTracker"]
+
+
+ class ModuleTracker:
+     """
+     ``ModuleTracker`` is a context manager that tracks the nn.Module hierarchy during execution
+     so that other systems can query which Module is currently being executed (or whose backward
+     is being executed).
+
+     You can access the ``parents`` attribute on this context manager to get the set of all the
+     Modules currently being executed via their fqn (fully qualified name, also used as the key within
+     the state_dict).
+     You can access the ``is_bw`` attribute to know if you are currently running in backward or not.
+
+     Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag
+     will remain ``True`` after the forward until another Module is executed. If you need it to be
+     more accurate, please submit an issue requesting this. Adding a map from fqn to the module instance
+     is possible but not done yet; please submit an issue requesting this if you need it.
+
+     Example usage
+
+     .. code-block:: python
+
+         mod = torch.nn.Linear(2, 2)
+
+         with ModuleTracker() as tracker:
+             # Access anything during the forward pass
+             def my_linear(m1, m2, bias):
+                 print(f"Current modules: {tracker.parents}")
+                 return torch.mm(m1, m2.t()) + bias
+             torch.nn.functional.linear = my_linear
+
+             mod(torch.rand(2, 2))
+
+     """
+
+     parents: Set[str]
+     """
+     A Set containing the fqn for each module currently running their forward
+     """
+
+     def __init__(self):
+         self.parents = {"Global"}
+         self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
+         self._seen_modules: weakref.WeakSet = weakref.WeakSet()
+         self._has_callback = False
+
+     def _maybe_set_engine_callback(self):
+         # This assumes no concurrent calls to backward
+         if self._has_callback:
+             return
+
+         def callback():
+             self.parents = {"Global"}
+             self._has_callback = False
+
+         torch.autograd.Variable._execution_engine.queue_callback(callback)
+         self._has_callback = True
+
+     @property
+     def is_bw(self):
+         """
+         A boolean marking if this is currently running during the backward pass or not
+         """
+         return torch._C._current_graph_task_id() != -1
+
+     def _get_mod_name(self, mod):
+         if mod not in self._known_modules:
+             self._known_modules[mod] = type(mod).__name__
+         mod_name = self._known_modules[mod]
+         if mod not in self._seen_modules:
+             for name, submod in mod.named_children():
+                 self._known_modules[submod] = f"{mod_name}.{name}"
+                 self._get_mod_name(submod)
+             self._seen_modules.add(mod)
+         return mod_name
+
+     def _get_append_fn(self, name, is_bw):
+         def fn(*args):
+             if is_bw:
+                 self._maybe_set_engine_callback()
+             if name in self.parents:
+                 print(
+                     "The module hierarchy tracking seems to be messed up. "
+                     "Please file a bug to PyTorch."
+                 )
+             self.parents.add(name)
+
+         return fn
+
+     def _get_pop_fn(self, name, is_bw):
+         def fn(*args):
+             if name in self.parents:
+                 self.parents.remove(name)
+             elif not is_bw:
+                 # Due to some input/output not requiring gradients, we cannot enforce
+                 # proper nesting in backward
+                 raise RuntimeError(
+                     "The Module hierarchy tracking is wrong. Report a bug to PyTorch"
+                 )
+
+         return fn
+
+     def _fw_pre_hook(self, mod, input):
+         name = self._get_mod_name(mod)
+         self._get_append_fn(name, False)()
+
+         args, _ = tree_flatten(input)
+         tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
+         if tensors:
+             register_multi_grad_hook(tensors, self._get_pop_fn(name, True))
+
+     def _fw_post_hook(self, mod, input, output):
+         name = self._get_mod_name(mod)
+         self._get_pop_fn(name, False)()
+
+         args, _ = tree_flatten(output)
+         tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
+         if tensors:
+             register_multi_grad_hook(tensors, self._get_append_fn(name, True))
+
+     def __enter__(self):
+         self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
+         self._fw_post_handle = register_module_forward_hook(self._fw_post_hook)
+         return self
+
+     def __exit__(self, *args):
+         self._fw_pre_handle.remove()
+         self._fw_post_handle.remove()
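
A minimal usage sketch (not part of the package; the import path is inferred from the file listing above): a custom module prints the hierarchy that ModuleTracker reports while a small nested model runs a forward pass.

import torch
import torch.nn as nn

from nshtrainer._experimental.flops.module_tracker import ModuleTracker


class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 1)

    def forward(self, x):
        # While we are inside Inner.forward, `tracker.parents` holds the fqn of
        # every enclosing module plus the "Global" sentinel.
        print(sorted(tracker.parents), "is_bw:", tracker.is_bw)
        return self.proj(x)


model = nn.Sequential(nn.Linear(4, 4), Inner())

with ModuleTracker() as tracker:
    model(torch.rand(2, 4))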
nshtrainer/_snoop.py ADDED
@@ -0,0 +1,216 @@
+ import contextlib
+ from typing import Any, Protocol, cast
+
+ from typing_extensions import TypeVar
+
+ T = TypeVar("T", infer_variance=True)
+
+
+ class SnoopConstructor(Protocol):
+     def __call__(self, *args, **kwargs) -> contextlib.AbstractContextManager: ...
+
+     def disable(self) -> contextlib.AbstractContextManager: ...
+
+
+ try:
+     import warnings
+     from contextlib import nullcontext
+
+     import lovely_numpy as lo
+     import lovely_tensors as lt
+     import numpy
+     import pysnooper
+     import pysnooper.utils
+     import torch
+     from pkg_resources import DistributionNotFound, get_distribution
+
+     FLOATING_POINTS = set()
+     for i in ["float", "double", "half", "complex128", "complex32", "complex64"]:
+         if hasattr(torch, i):  # older versions of PyTorch do not have complex dtypes
+             FLOATING_POINTS.add(getattr(torch, i))
+
+     try:
+         __version__ = get_distribution(__name__).version
+     except DistributionNotFound:
+         # package is not installed
+         pass
+
+     def default_format(x):
+         try:
+             formatted = str(lt.lovely(x))
+             return formatted
+         except BaseException:
+             return str(x.shape)
+
+     def default_numpy_format(x):
+         return str(lo.lovely(x))
+
+     class TorchSnooper(pysnooper.tracer.Tracer):
+         def __init__(
+             self,
+             *args,
+             tensor_format=default_format,
+             numpy_format=default_numpy_format,
+             **kwargs,
+         ):
+             self.orig_custom_repr = (
+                 kwargs["custom_repr"] if "custom_repr" in kwargs else ()
+             )
+             custom_repr = (lambda x: True, self.compute_repr)
+             kwargs["custom_repr"] = (custom_repr,)
+             super(TorchSnooper, self).__init__(*args, **kwargs)
+             self.tensor_format = tensor_format
+             self.numpy_format = numpy_format
+
+         @staticmethod
+         def is_return_types(x):
+             return type(x).__module__ == "torch.return_types"
+
+         def return_types_repr(self, x):
+             if type(x).__name__ in {
+                 "max",
+                 "min",
+                 "median",
+                 "mode",
+                 "sort",
+                 "topk",
+                 "kthvalue",
+             }:
+                 return (
+                     type(x).__name__
+                     + "(values="
+                     + self.tensor_format(x.values)
+                     + ", indices="
+                     + self.tensor_format(x.indices)
+                     + ")"
+                 )
+             if type(x).__name__ == "svd":
+                 return (
+                     "svd(U="
+                     + self.tensor_format(x.U)
+                     + ", S="
+                     + self.tensor_format(x.S)
+                     + ", V="
+                     + self.tensor_format(x.V)
+                     + ")"
+                 )
+             if type(x).__name__ == "slogdet":
+                 return (
+                     "slogdet(sign="
+                     + self.tensor_format(x.sign)
+                     + ", logabsdet="
+                     + self.tensor_format(x.logabsdet)
+                     + ")"
+                 )
+             if type(x).__name__ == "qr":
+                 return (
+                     "qr(Q="
+                     + self.tensor_format(x.Q)
+                     + ", R="
+                     + self.tensor_format(x.R)
+                     + ")"
+                 )
+             if type(x).__name__ == "solve":
+                 return (
+                     "solve(solution="
+                     + self.tensor_format(x.solution)
+                     + ", LU="
+                     + self.tensor_format(x.LU)
+                     + ")"
+                 )
+             if type(x).__name__ == "geqrf":
+                 return (
+                     "geqrf(a="
+                     + self.tensor_format(x.a)
+                     + ", tau="
+                     + self.tensor_format(x.tau)
+                     + ")"
+                 )
+             if type(x).__name__ in {"symeig", "eig"}:
+                 return (
+                     type(x).__name__
+                     + "(eigenvalues="
+                     + self.tensor_format(x.eigenvalues)
+                     + ", eigenvectors="
+                     + self.tensor_format(x.eigenvectors)
+                     + ")"
+                 )
+             if type(x).__name__ == "triangular_solve":
+                 return (
+                     "triangular_solve(solution="
+                     + self.tensor_format(x.solution)
+                     + ", cloned_coefficient="
+                     + self.tensor_format(x.cloned_coefficient)
+                     + ")"
+                 )
+             if type(x).__name__ == "gels":
+                 return (
+                     "gels(solution="
+                     + self.tensor_format(x.solution)
+                     + ", QR="
+                     + self.tensor_format(x.QR)
+                     + ")"
+                 )
+             warnings.warn("Unknown return_types encountered, open a bug report!")
+
+         def compute_repr(self, x):
+             orig_repr_func = pysnooper.utils.get_repr_function(x, self.orig_custom_repr)
+             if torch.is_tensor(x):
+                 return self.tensor_format(x)
+             if isinstance(x, numpy.ndarray):
+                 return self.numpy_format(x)
+             if self.is_return_types(x):
+                 return self.return_types_repr(x)
+             if orig_repr_func is not repr:
+                 return orig_repr_func(x)
+             if isinstance(x, (list, tuple)):
+                 content = ""
+                 for i in x:
+                     if content != "":
+                         content += ", "
+                     content += self.compute_repr(i)
+                 if isinstance(x, tuple) and len(x) == 1:
+                     content += ","
+                 if isinstance(x, tuple):
+                     return "(" + content + ")"
+                 return "[" + content + "]"
+             if isinstance(x, dict):
+                 content = ""
+                 for k, v in x.items():
+                     if content != "":
+                         content += ", "
+                     content += self.compute_repr(k) + ": " + self.compute_repr(v)
+                 return "{" + content + "}"
+             return repr(x)
+
+     class _Snoop:
+         disable = nullcontext
+         __call__ = TorchSnooper
+
+     snoop: SnoopConstructor = cast(Any, _Snoop())
+
+ except ImportError:
+     import warnings
+     from contextlib import nullcontext
+
+     from typing_extensions import override
+
+     _has_warned = False
+
+     class _snoop_cls(nullcontext):
+         @classmethod
+         def disable(cls):
+             return nullcontext()
+
+         @override
+         def __enter__(self):
+             global _has_warned
+             if not _has_warned:
+                 warnings.warn(
+                     "snoop is not installed, please install it to enable snoop"
+                 )
+                 _has_warned = True
+
+             return super().__enter__()
+
+     snoop: SnoopConstructor = cast(Any, _snoop_cls)
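
A hedged usage sketch inferred from the code above (not from package docs): when the optional dependencies (pysnooper, lovely-tensors, lovely-numpy) are installed, `snoop()` behaves like a pysnooper tracer that renders tensors compactly; otherwise it degrades to a warning no-op context manager.

import torch

from nshtrainer._snoop import snoop


def scale(x: torch.Tensor) -> torch.Tensor:
    # Traces the block line by line, showing tensor values via lovely-tensors
    # (or doing nothing, with a one-time warning, if the extras are missing).
    with snoop():
        y = x * 2.0
        return y


scale(torch.rand(3))

# To keep the block structure but skip tracing at a call site, swap in:
with snoop.disable():
    scale(torch.rand(3))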
nshtrainer/_submit/print_environment_info.py ADDED
@@ -0,0 +1,31 @@
+ import logging
+ import os
+ import sys
+
+
+ def print_environment_info(log: logging.Logger | None = None):
+     if log is None:
+         logging.basicConfig(level=logging.INFO)
+         log = logging.getLogger(__name__)
+
+     log_message_lines: list[str] = []
+     log_message_lines.append("Python executable: " + sys.executable)
+     log_message_lines.append("Python version: " + sys.version)
+     log_message_lines.append("Python prefix: " + sys.prefix)
+     log_message_lines.append("Python path:")
+     for path in sys.path:
+         log_message_lines.append(f" {path}")
+
+     log_message_lines.append("Environment variables:")
+     for key, value in os.environ.items():
+         log_message_lines.append(f" {key}={value}")
+
+     log_message_lines.append("Command line arguments:")
+     for i, arg in enumerate(sys.argv):
+         log_message_lines.append(f" {i}: {arg}")
+
+     log.critical("\n".join(log_message_lines))
+
+
+ if __name__ == "__main__":
+     print_environment_info()
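
A brief, hedged example (not part of the package) of routing this dump through a job-specific logger; the logger name is hypothetical and the import path is taken from the file listing above.

import logging

from nshtrainer._submit.print_environment_info import print_environment_info

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("my_job")  # hypothetical logger name
print_environment_info(log)  # logs interpreter, sys.path, env vars, and argv at CRITICAL level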
nshtrainer/_submit/session/_output.py ADDED
@@ -0,0 +1,12 @@
+ from dataclasses import dataclass
+ from pathlib import Path
+
+
+ @dataclass(frozen=True)
+ class SubmitOutput:
+     command_parts: list[str]
+     script_path: Path
+
+     @property
+     def command(self) -> str:
+         return " ".join(self.command_parts)
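
A small, hedged sketch of how this dataclass is presumably consumed (paths and arguments are hypothetical): the list form stays structured for subprocess use, while `.command` gives a printable string.

from pathlib import Path

from nshtrainer._submit.session._output import SubmitOutput

out = SubmitOutput(
    command_parts=["sbatch", "/tmp/run/submit.sh"],  # hypothetical scheduler command
    script_path=Path("/tmp/run/submit.sh"),
)
print(out.command)  # "sbatch /tmp/run/submit.sh"
# subprocess.run(out.command_parts, check=True) would avoid shell-quoting issues.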
nshtrainer/_submit/session/_script.py ADDED
@@ -0,0 +1,109 @@
+ from collections.abc import Iterable, Mapping, Sequence
+ from pathlib import Path
+
+
+ def _create_launcher_script_file(
+     script_path: Path,
+     original_command: str | Iterable[str],
+     environment: Mapping[str, str],
+     setup_commands: Sequence[str],
+     chmod: bool = True,
+     prepend_command_with_exec: bool = True,
+     # ^ If True, the original command will be prepended with 'exec' to replace the shell process
+     # with the command. This is useful for ensuring that the command is the only process in the
+     # process tree (e.g. for better signal handling).
+     command_prefix: str | None = None,
+ ):
+     """
+     Creates a helper bash script for running the given function.
+
+     The core idea: The helper script is essentially one additional layer of indirection
+     that allows us to encapsulate the environment setup and the actual function call
+     in a single bash script (that does not require a properly set up Python environment).
+
+     In effect, this allows us to, for example:
+     - Easily run the function in the correct environment
+       (without having to deal with shell hooks)
+       using `conda run -n myenv bash /path/to/helper.sh`.
+     - Easily run the function in a Singularity container
+       using `singularity exec my_container.sif bash /path/to/helper.sh`.
+     """
+     with script_path.open("w") as f:
+         f.write("#!/bin/bash\n\n")
+         f.write("set -e\n\n")
+
+         if environment:
+             for key, value in environment.items():
+                 f.write(f"export {key}={value}\n")
+             f.write("\n")
+
+         if setup_commands:
+             for setup_command in setup_commands:
+                 f.write(f"{setup_command}\n")
+             f.write("\n")
+
+         if not isinstance(original_command, str):
+             original_command = " ".join(original_command)
+
+         if command_prefix:
+             original_command = f"{command_prefix} {original_command}"
+
+         if prepend_command_with_exec:
+             original_command = f"exec {original_command}"
+         f.write(f"{original_command}\n")
+
+     if chmod:
+         # Make the script executable
+         script_path.chmod(0o755)
+
+
+ def write_helper_script(
+     base_dir: Path,
+     command: str | Iterable[str],
+     environment: Mapping[str, str],
+     setup_commands: Sequence[str],
+     chmod: bool = True,
+     prepend_command_with_exec: bool = True,
+     command_prefix: str | None = None,
+     file_name: str = "helper.sh",
+ ):
+     """
+     Creates a helper bash script for running the given function.
+
+     The core idea: The helper script is essentially one additional layer of indirection
+     that allows us to encapsulate the environment setup and the actual function call
+     in a single bash script (that does not require a properly set up Python environment).
+
+     In effect, this allows us to, for example:
+     - Easily run the function in the correct environment
+       (without having to deal with shell hooks)
+       using `conda run -n myenv bash /path/to/helper.sh`.
+     - Easily run the function in a Singularity container
+       using `singularity exec my_container.sif bash /path/to/helper.sh`.
+     """
+
+     out_path = base_dir / file_name
+     _create_launcher_script_file(
+         out_path,
+         command,
+         environment,
+         setup_commands,
+         chmod,
+         prepend_command_with_exec,
+         command_prefix,
+     )
+     return out_path
+
+
+ DEFAULT_TEMPLATE = "bash {script}"
+
+
+ def helper_script_to_command(script: Path, template: str | None) -> str:
+     if not template:
+         template = DEFAULT_TEMPLATE
+
+     # Make sure the template has '{script}' in it
+     if "{script}" not in template:
+         raise ValueError(f"Template must contain '{{script}}'. Got: {template!r}")
+
+     return template.format(script=str(script.absolute()))
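
A hedged end-to-end sketch of the helper-script flow described in the docstrings above (directory, environment variables, and setup commands are hypothetical): write a launcher script that exports variables, runs setup commands, then exec's the training command, and render the final command a scheduler would execute.

from pathlib import Path

from nshtrainer._submit.session._script import (
    helper_script_to_command,
    write_helper_script,
)

base_dir = Path("/tmp/my_job")  # hypothetical job directory
base_dir.mkdir(parents=True, exist_ok=True)

# Writes /tmp/my_job/helper.sh containing the exports, setup commands,
# and `exec python train.py --config config.yaml`, then marks it executable.
script = write_helper_script(
    base_dir,
    command=["python", "train.py", "--config", "config.yaml"],
    environment={"CUDA_VISIBLE_DEVICES": "0"},
    setup_commands=["module load cuda/12.1"],
)

# Wrap the helper script in a launcher template, e.g. for `conda run`.
print(helper_script_to_command(script, "conda run -n myenv bash {script}"))
# -> conda run -n myenv bash /tmp/my_job/helper.sh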