nshtrainer 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/__init__.py CHANGED
@@ -1,11 +1,11 @@
1
1
  from . import _experimental as _experimental
2
2
  from . import actsave as actsave
3
3
  from . import callbacks as callbacks
4
+ from . import config as config
4
5
  from . import lr_scheduler as lr_scheduler
5
6
  from . import model as model
6
7
  from . import nn as nn
7
8
  from . import optimizer as optimizer
8
- from . import snapshot as snapshot
9
9
  from . import typecheck as typecheck
10
10
  from ._snoop import snoop as snoop
11
11
  from .actsave import ActLoad as ActLoad
nshtrainer/_snoop.py CHANGED
@@ -1,216 +1 @@
1
- import contextlib
2
- from typing import Any, Protocol, cast
3
-
4
- from typing_extensions import TypeVar
5
-
6
- T = TypeVar("T", infer_variance=True)
7
-
8
-
9
- class SnoopConstructor(Protocol):
10
- def __call__(self, *args, **kwargs) -> contextlib.AbstractContextManager: ...
11
-
12
- def disable(self) -> contextlib.AbstractContextManager: ...
13
-
14
-
15
- try:
16
- import warnings
17
- from contextlib import nullcontext
18
-
19
- import lovely_numpy as lo
20
- import lovely_tensors as lt
21
- import numpy
22
- import pysnooper
23
- import pysnooper.utils
24
- import torch
25
- from pkg_resources import DistributionNotFound, get_distribution
26
-
27
- FLOATING_POINTS = set()
28
- for i in ["float", "double", "half", "complex128", "complex32", "complex64"]:
29
- if hasattr(torch, i): # older version of PyTorch do not have complex dtypes
30
- FLOATING_POINTS.add(getattr(torch, i))
31
-
32
- try:
33
- __version__ = get_distribution(__name__).version
34
- except DistributionNotFound:
35
- # package is not installed
36
- pass
37
-
38
- def default_format(x):
39
- try:
40
- formatted = str(lt.lovely(x))
41
- return formatted
42
- except BaseException:
43
- return str(x.shape)
44
-
45
- def default_numpy_format(x):
46
- return str(lo.lovely(x))
47
-
48
- class TorchSnooper(pysnooper.tracer.Tracer):
49
- def __init__(
50
- self,
51
- *args,
52
- tensor_format=default_format,
53
- numpy_format=default_numpy_format,
54
- **kwargs,
55
- ):
56
- self.orig_custom_repr = (
57
- kwargs["custom_repr"] if "custom_repr" in kwargs else ()
58
- )
59
- custom_repr = (lambda x: True, self.compute_repr)
60
- kwargs["custom_repr"] = (custom_repr,)
61
- super(TorchSnooper, self).__init__(*args, **kwargs)
62
- self.tensor_format = tensor_format
63
- self.numpy_format = numpy_format
64
-
65
- @staticmethod
66
- def is_return_types(x):
67
- return type(x).__module__ == "torch.return_types"
68
-
69
- def return_types_repr(self, x):
70
- if type(x).__name__ in {
71
- "max",
72
- "min",
73
- "median",
74
- "mode",
75
- "sort",
76
- "topk",
77
- "kthvalue",
78
- }:
79
- return (
80
- type(x).__name__
81
- + "(values="
82
- + self.tensor_format(x.values)
83
- + ", indices="
84
- + self.tensor_format(x.indices)
85
- + ")"
86
- )
87
- if type(x).__name__ == "svd":
88
- return (
89
- "svd(U="
90
- + self.tensor_format(x.U)
91
- + ", S="
92
- + self.tensor_format(x.S)
93
- + ", V="
94
- + self.tensor_format(x.V)
95
- + ")"
96
- )
97
- if type(x).__name__ == "slogdet":
98
- return (
99
- "slogdet(sign="
100
- + self.tensor_format(x.sign)
101
- + ", logabsdet="
102
- + self.tensor_format(x.logabsdet)
103
- + ")"
104
- )
105
- if type(x).__name__ == "qr":
106
- return (
107
- "qr(Q="
108
- + self.tensor_format(x.Q)
109
- + ", R="
110
- + self.tensor_format(x.R)
111
- + ")"
112
- )
113
- if type(x).__name__ == "solve":
114
- return (
115
- "solve(solution="
116
- + self.tensor_format(x.solution)
117
- + ", LU="
118
- + self.tensor_format(x.LU)
119
- + ")"
120
- )
121
- if type(x).__name__ == "geqrf":
122
- return (
123
- "geqrf(a="
124
- + self.tensor_format(x.a)
125
- + ", tau="
126
- + self.tensor_format(x.tau)
127
- + ")"
128
- )
129
- if type(x).__name__ in {"symeig", "eig"}:
130
- return (
131
- type(x).__name__
132
- + "(eigenvalues="
133
- + self.tensor_format(x.eigenvalues)
134
- + ", eigenvectors="
135
- + self.tensor_format(x.eigenvectors)
136
- + ")"
137
- )
138
- if type(x).__name__ == "triangular_solve":
139
- return (
140
- "triangular_solve(solution="
141
- + self.tensor_format(x.solution)
142
- + ", cloned_coefficient="
143
- + self.tensor_format(x.cloned_coefficient)
144
- + ")"
145
- )
146
- if type(x).__name__ == "gels":
147
- return (
148
- "gels(solution="
149
- + self.tensor_format(x.solution)
150
- + ", QR="
151
- + self.tensor_format(x.QR)
152
- + ")"
153
- )
154
- warnings.warn("Unknown return_types encountered, open a bug report!")
155
-
156
- def compute_repr(self, x):
157
- orig_repr_func = pysnooper.utils.get_repr_function(x, self.orig_custom_repr)
158
- if torch.is_tensor(x):
159
- return self.tensor_format(x)
160
- if isinstance(x, numpy.ndarray):
161
- return self.numpy_format(x)
162
- if self.is_return_types(x):
163
- return self.return_types_repr(x)
164
- if orig_repr_func is not repr:
165
- return orig_repr_func(x)
166
- if isinstance(x, (list, tuple)):
167
- content = ""
168
- for i in x:
169
- if content != "":
170
- content += ", "
171
- content += self.compute_repr(i)
172
- if isinstance(x, tuple) and len(x) == 1:
173
- content += ","
174
- if isinstance(x, tuple):
175
- return "(" + content + ")"
176
- return "[" + content + "]"
177
- if isinstance(x, dict):
178
- content = ""
179
- for k, v in x.items():
180
- if content != "":
181
- content += ", "
182
- content += self.compute_repr(k) + ": " + self.compute_repr(v)
183
- return "{" + content + "}"
184
- return repr(x)
185
-
186
- class _Snoop:
187
- disable = nullcontext
188
- __call__ = TorchSnooper
189
-
190
- snoop: SnoopConstructor = cast(Any, _Snoop())
191
-
192
- except ImportError:
193
- import warnings
194
- from contextlib import nullcontext
195
-
196
- from typing_extensions import override
197
-
198
- _has_warned = False
199
-
200
- class _snoop_cls(nullcontext):
201
- @classmethod
202
- def disable(cls):
203
- return nullcontext()
204
-
205
- @override
206
- def __enter__(self):
207
- global _has_warned
208
- if not _has_warned:
209
- warnings.warn(
210
- "snoop is not installed, please install it to enable snoop"
211
- )
212
- _has_warned = True
213
-
214
- return super().__enter__()
215
-
216
- snoop: SnoopConstructor = cast(Any, _snoop_cls)
1
+ from nshutils.snoop import * # type: ignore # noqa: F403
nshtrainer/config.py ADDED
@@ -0,0 +1,4 @@
1
+ from nshconfig import * # type: ignore # noqa: F403
2
+ from nshconfig import Config
3
+
4
+ TypedConfig = Config
nshtrainer/runner.py CHANGED
@@ -7,13 +7,13 @@ from typing_extensions import TypeVar, TypeVarTuple, Unpack, override
7
7
  from .model.config import BaseConfig
8
8
 
9
9
  TConfig = TypeVar("TConfig", bound=BaseConfig, infer_variance=True)
10
- TArguments = TypeVarTuple("TArguments")
10
+ TArguments = TypeVarTuple("TArguments", default=Unpack[tuple[()]])
11
11
  TReturn = TypeVar("TReturn", infer_variance=True)
12
12
 
13
13
 
14
14
  class Runner(
15
- _Runner[Unpack[tuple[TConfig, Unpack[TArguments]]], TReturn],
16
- Generic[TConfig, Unpack[TArguments], TReturn],
15
+ _Runner[TReturn, TConfig, Unpack[TArguments]],
16
+ Generic[TReturn, TConfig, Unpack[TArguments]],
17
17
  ):
18
18
  @override
19
19
  @classmethod
nshtrainer/typecheck.py CHANGED
@@ -1,145 +1 @@
1
- import os
2
- from collections.abc import Sequence
3
- from logging import getLogger
4
- from typing import Any
5
-
6
- import numpy as np
7
- import torch
8
- from jaxtyping import BFloat16 as BFloat16
9
- from jaxtyping import Bool as Bool
10
- from jaxtyping import Complex as Complex
11
- from jaxtyping import Complex64 as Complex64
12
- from jaxtyping import Complex128 as Complex128
13
- from jaxtyping import Float as Float
14
- from jaxtyping import Float16 as Float16
15
- from jaxtyping import Float32 as Float32
16
- from jaxtyping import Float64 as Float64
17
- from jaxtyping import Inexact as Inexact
18
- from jaxtyping import Int as Int
19
- from jaxtyping import Int4 as Int4
20
- from jaxtyping import Int8 as Int8
21
- from jaxtyping import Int16 as Int16
22
- from jaxtyping import Int32 as Int32
23
- from jaxtyping import Int64 as Int64
24
- from jaxtyping import Integer as Integer
25
- from jaxtyping import Key as Key
26
- from jaxtyping import Num as Num
27
- from jaxtyping import Real as Real
28
- from jaxtyping import Shaped as Shaped
29
- from jaxtyping import UInt as UInt
30
- from jaxtyping import UInt4 as UInt4
31
- from jaxtyping import UInt8 as UInt8
32
- from jaxtyping import UInt16 as UInt16
33
- from jaxtyping import UInt32 as UInt32
34
- from jaxtyping import UInt64 as UInt64
35
- from jaxtyping._storage import get_shape_memo, shape_str
36
- from torch import Tensor as Tensor
37
- from torch.nn.parameter import Parameter as Parameter
38
- from typing_extensions import TypeVar
39
-
40
- log = getLogger(__name__)
41
-
42
- DISABLE_ENV_KEY = "LL_DISABLE_TYPECHECKING"
43
-
44
-
45
- def typecheck_modules(modules: Sequence[str]):
46
- """
47
- Typecheck the given modules using `jaxtyping`.
48
-
49
- Args:
50
- modules: Modules to typecheck.
51
- """
52
- # If `DISABLE_ENV_KEY` is set and the environment variable is set, skip
53
- # typechecking.
54
- if DISABLE_ENV_KEY is not None and bool(int(os.environ.get(DISABLE_ENV_KEY, "0"))):
55
- log.critical(
56
- f"Type checking is disabled due to the environment variable {DISABLE_ENV_KEY}."
57
- )
58
- return
59
-
60
- # Install the jaxtyping import hook for this module.
61
- from jaxtyping import install_import_hook
62
-
63
- install_import_hook(modules, "beartype.beartype")
64
-
65
- log.critical(f"Type checking the following modules: {modules}")
66
-
67
-
68
- def typecheck_this_module(additional_modules: Sequence[str] = ()):
69
- """
70
- Typecheck the calling module and any additional modules using `jaxtyping`.
71
-
72
- Args:
73
- additional_modules: Additional modules to typecheck.
74
- """
75
- # Get the calling module's name.
76
- # Here, we can just use beartype's internal implementation behind
77
- # `beartype_this_package`.
78
- from beartype._util.func.utilfuncframe import get_frame, get_frame_package_name
79
-
80
- # Get the calling module's name.
81
- assert get_frame is not None, "get_frame is None"
82
- frame = get_frame(1)
83
- assert frame is not None, "frame is None"
84
- calling_module_name = get_frame_package_name(frame)
85
- assert calling_module_name is not None, "calling_module_name is None"
86
-
87
- # Typecheck the calling module + any additional modules.
88
- typecheck_modules((calling_module_name, *additional_modules))
89
-
90
-
91
- def _make_error_str(input: Any, t: Any) -> str:
92
- error_components: list[str] = []
93
- error_components.append("Type checking error:")
94
- if hasattr(t, "__instancecheck_str__"):
95
- error_components.append(t.__instancecheck_str__(input))
96
- if torch.is_tensor(input):
97
- try:
98
- from lovely_tensors import lovely
99
-
100
- error_components.append(repr(lovely(input)))
101
- except BaseException:
102
- error_components.append(repr(input.shape))
103
- error_components.append(shape_str(get_shape_memo()))
104
-
105
- return "\n".join(error_components)
106
-
107
-
108
- T = TypeVar("T", torch.Tensor, np.ndarray, infer_variance=True)
109
-
110
- """
111
- Patch to jaxtyping:
112
-
113
- In `jaxtyping._import_hook`, we add:
114
- def _has_isinstance_or_tassert(func_def):
115
- for node in ast.walk(func_def):
116
- if isinstance(node, ast.Call):
117
- if isinstance(node.func, ast.Name) and node.func.id == "isinstance":
118
- return True
119
- elif isinstance(node.func, ast.Name) and node.func.id == "tassert":
120
- return True
121
- return False
122
-
123
- and we check this when adding the decorators.
124
- """
125
-
126
-
127
- def tassert(t: Any, input: T | tuple[T, ...]):
128
- """
129
- Typecheck the input against the given type.
130
-
131
- Args:
132
- t: Type to check against.
133
- input: Input to check.
134
- """
135
-
136
- # Ignore typechecking if the environment variable is set.
137
- if DISABLE_ENV_KEY is not None and bool(int(os.environ.get(DISABLE_ENV_KEY, "0"))):
138
- return
139
-
140
- if isinstance(input, tuple):
141
- for i in input:
142
- assert isinstance(i, t), _make_error_str(i, t)
143
- return
144
- else:
145
- assert isinstance(input, t), _make_error_str(input, t)
1
+ from nshutils.typecheck import * # type: ignore # noqa: F403
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.1.1
3
+ Version: 0.3.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -9,15 +9,13 @@ Classifier: Programming Language :: Python :: 3
9
9
  Classifier: Programming Language :: Python :: 3.10
10
10
  Classifier: Programming Language :: Python :: 3.11
11
11
  Classifier: Programming Language :: Python :: 3.12
12
- Requires-Dist: beartype (>=0.18.5,<0.19.0)
13
- Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
14
12
  Requires-Dist: lightning
15
13
  Requires-Dist: lovely-numpy (>=0.2.13,<0.3.0)
16
14
  Requires-Dist: lovely-tensors (>=0.1.16,<0.2.0)
17
15
  Requires-Dist: nshconfig (>=0.2.0,<0.3.0)
18
- Requires-Dist: nshrunner (>=0.5.3,<0.6.0)
16
+ Requires-Dist: nshrunner (>=0.5.4,<0.6.0)
17
+ Requires-Dist: nshutils (>=0.2.0,<0.3.0)
19
18
  Requires-Dist: numpy
20
- Requires-Dist: pysnooper
21
19
  Requires-Dist: pytorch-lightning
22
20
  Requires-Dist: rich
23
21
  Requires-Dist: torch
@@ -1,9 +1,9 @@
1
- nshtrainer/__init__.py,sha256=OHbxLxVvFGW--ecuIGqkoylSVHFS4x4F1-oeuENH-Do,2212
1
+ nshtrainer/__init__.py,sha256=_r7kBmgGSLVfActlqQeupNolrmBu45xUuSS8odt3HL8,2208
2
2
  nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
3
3
  nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
4
4
  nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
5
5
  nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
6
- nshtrainer/_snoop.py,sha256=Rofv1Rd92E0LY40G3A-o9Hu0ZI73RR59wJD5l4Q3PDM,7022
6
+ nshtrainer/_snoop.py,sha256=2rEemPyMP3aIo2QgPzo_-AlT1oXGWYQipId4RQskMls,58
7
7
  nshtrainer/actsave/__init__.py,sha256=G1T-fELuGWkVqdhdyoePtj2dTOUtcIOW4VgsXv9JNTA,338
8
8
  nshtrainer/actsave/_callback.py,sha256=QoTa60F70f1RxB41VKixN9l5_htfFQxXDPHHSNFreuk,2770
9
9
  nshtrainer/actsave/_loader.py,sha256=fAhD32DrJa4onkYfcwc21YIeGEYzOSXCK_HVo9SZLgQ,4604
@@ -24,6 +24,7 @@ nshtrainer/callbacks/print_table.py,sha256=FcA-CBWwMf9c1NNRinvYpZC400RNQxuP28bJf
24
24
  nshtrainer/callbacks/throughput_monitor.py,sha256=YQLdpX3LGybIiD814yT9yCCVSEXRWf8WwsvVaN5aDBE,1848
25
25
  nshtrainer/callbacks/timer.py,sha256=sDXPPcdDKu5xnuK_bjr8plIq9MBuluNJ42Mt9LvPZzc,4610
26
26
  nshtrainer/callbacks/wandb_watch.py,sha256=pUpMsNxd03ex1rzOmFw2HzGOXjnQGaH84m8cc2dXo4g,2937
27
+ nshtrainer/config.py,sha256=IXOAl_JWFNX9kPTo_iw4Nc3qXqkKrbA6-ZrvTAjqu6A,104
27
28
  nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
28
29
  nshtrainer/data/balanced_batch_sampler.py,sha256=bcJBcQjh1hB1yKF_xSlT9AtEWv0BJjYc1CuH2BF-ea8,4392
29
30
  nshtrainer/data/transform.py,sha256=JeGxvytQly8hougrsdMmKG8gJ6qvFPDglJCO4Tp6STk,1795
@@ -47,18 +48,18 @@ nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,
47
48
  nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
48
49
  nshtrainer/nn/nonlinearity.py,sha256=owtU4kh4G98psD0axOJWVfBhm-OtJVgFM-TXSHmbNPU,3625
49
50
  nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
50
- nshtrainer/runner.py,sha256=af_EGnQTSvUgwnVhhytvY3V7o_Xg-xx-sLb8K2Szb1E,979
51
+ nshtrainer/runner.py,sha256=vyHr0EZ0PBOWZh09BtOOxio-FRQZFbVoL4cdBlI97vY,991
51
52
  nshtrainer/scripts/check_env.py,sha256=IMl6dSqsLYppI0XuCsVq8lK4bYqXwY9KHJkzsShz4Kg,806
52
53
  nshtrainer/scripts/find_packages.py,sha256=FbdlfmAefttFSMfaT0A46a-oHLP_ioaQKihwBfBeWeA,1467
53
54
  nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
54
55
  nshtrainer/trainer/signal_connector.py,sha256=QAoPM_C5JJOVQebcrJOimUUD3GHyoeZUqCEAvzZlT4U,8710
55
56
  nshtrainer/trainer/trainer.py,sha256=eYEYfY9v70MuorHcSf8nqM7f2CkmUHhpPcjCk4FJD7k,14034
56
- nshtrainer/typecheck.py,sha256=RGYHxDBcs97E6ayl6Olc43JBZXQolCtMxcLBniVCVBg,4688
57
+ nshtrainer/typecheck.py,sha256=ryV1Tzcf7hJ4I19H1oQVkikU9spmRk8jyIKQZ5UF7pQ,62
57
58
  nshtrainer/util/environment.py,sha256=_SEtiQ_s5bL5pllUlf96AOUv15kNvCPvocVC13S7mIk,4166
58
59
  nshtrainer/util/seed.py,sha256=HEXgVs-wldByahOysKwq7506OHxdYTEgmP-tDQVAEkQ,287
59
60
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
60
61
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
61
62
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
62
- nshtrainer-0.1.1.dist-info/METADATA,sha256=32iVLvdJh6OJQyD-_7NDO6IYqfHPSflDznYfYaCo8-c,882
63
- nshtrainer-0.1.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
64
- nshtrainer-0.1.1.dist-info/RECORD,,
63
+ nshtrainer-0.3.0.dist-info/METADATA,sha256=oME_P_Y7U4bavZunb-rF2m1R1w31vOS0Qh9l0Nfua68,812
64
+ nshtrainer-0.3.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
65
+ nshtrainer-0.3.0.dist-info/RECORD,,