nshutils 0.19.1__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshutils/__init__.py +4 -13
- nshutils/__init__.pyi +13 -0
- nshutils/actsave/_saver.py +14 -10
- nshutils/logging.py +14 -64
- nshutils/lovely/__init__.py +10 -0
- nshutils/lovely/_base.py +155 -0
- nshutils/lovely/_monkey_patch_all.py +68 -0
- nshutils/lovely/config.py +47 -0
- nshutils/lovely/jax_.py +89 -0
- nshutils/lovely/numpy_.py +72 -0
- nshutils/lovely/torch_.py +99 -0
- nshutils/lovely/utils.py +345 -0
- nshutils/typecheck.py +3 -2
- nshutils/util.py +92 -0
- {nshutils-0.19.1.dist-info → nshutils-0.21.0.dist-info}/METADATA +12 -6
- nshutils-0.21.0.dist-info/RECORD +22 -0
- {nshutils-0.19.1.dist-info → nshutils-0.21.0.dist-info}/WHEEL +1 -1
- nshutils-0.19.1.dist-info/RECORD +0 -12
nshutils/__init__.py
CHANGED
@@ -1,18 +1,9 @@
 from __future__ import annotations
 
-
-
-
-
-from .collections import apply_to_collection as apply_to_collection
-from .display import display as display
-from .logging import init_python_logging as init_python_logging
-from .logging import lovely as lovely
-from .logging import pretty as pretty
-from .snoop import snoop as snoop
-from .typecheck import tassert as tassert
-from .typecheck import typecheck_modules as typecheck_modules
-from .typecheck import typecheck_this_module as typecheck_this_module
+import lazy_loader as lazy
+
+__getattr__, __dir__, __all__ = lazy.attach_stub(__name__, __file__)
+
 
 try:
     from importlib.metadata import PackageNotFoundError, version
nshutils/__init__.pyi
ADDED
@@ -0,0 +1,13 @@
+from . import actsave as actsave
+from . import lovely as lovely
+from . import typecheck as typecheck
+from .actsave import ActLoad as ActLoad
+from .actsave import ActSave as ActSave
+from .collections import apply_to_collection as apply_to_collection
+from .display import display as display
+from .logging import init_python_logging as init_python_logging
+from .logging import setup_logging as setup_logging
+from .snoop import snoop as snoop
+from .typecheck import tassert as tassert
+from .typecheck import typecheck_modules as typecheck_modules
+from .typecheck import typecheck_this_module as typecheck_this_module
nshutils/actsave/_saver.py
CHANGED
@@ -9,10 +9,10 @@ from dataclasses import dataclass
 from functools import wraps
 from logging import getLogger
 from pathlib import Path
-from typing import TYPE_CHECKING, Generic,
+from typing import TYPE_CHECKING, Generic, Literal, cast, overload
 
 import numpy as np
-from typing_extensions import Never, ParamSpec, TypeVar, override
+from typing_extensions import Never, ParamSpec, TypeAliasType, TypeVar, override
 from uuid_extensions import uuid7str
 
 from ..collections import apply_to_collection
@@ -21,25 +21,29 @@ if not TYPE_CHECKING:
     try:
         import torch  # type: ignore
 
-        Tensor
+        Tensor = torch.Tensor
+        _torch_installed = True
     except ImportError:
         torch = None
+        _torch_installed = False
 
-        Tensor
+        Tensor = Never
 else:
     import torch  # type: ignore
 
-    Tensor
-
+    Tensor = torch.Tensor
+    _torch_installed: Literal[True] = True
 
 log = getLogger(__name__)
 
-Value
-
+Value = TypeAliasType(
+    "Value", int | float | complex | bool | str | np.ndarray | Tensor | None
+)
+ValueOrLambda = TypeAliasType("ValueOrLambda", Value | Callable[..., Value])
 
 
 def _torch_is_scripting() -> bool:
-    if
+    if _torch_installed:
         return False
 
     return torch.jit.is_scripting()
@@ -54,7 +58,7 @@ def _to_numpy(activation: Value) -> np.ndarray:
         return np.array(activation)
     elif isinstance(activation, np.ndarray):
         return activation
-    elif isinstance(activation, Tensor):
+    elif _torch_installed and isinstance(activation, Tensor):
         activation_ = activation.detach()
         if activation_.is_floating_point():
             # NOTE: We need to convert to float32 because [b]float16 is not supported by numpy
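The hunks above replace an implicit torch requirement with an explicit `_torch_installed` flag. A standalone, simplified sketch of the same guard pattern (not the package's exact code):

    from typing import Any

    import numpy as np

    try:
        import torch

        Tensor = torch.Tensor
        _torch_installed = True
    except ImportError:
        torch = None
        Tensor = None
        _torch_installed = False


    def to_numpy(value: Any) -> np.ndarray:
        # isinstance() is only attempted when torch was actually importable,
        # so the module keeps working on torch-free installs.
        if _torch_installed and isinstance(value, Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)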
nshutils/logging.py
CHANGED
@@ -2,12 +2,15 @@ from __future__ import annotations
 
 import logging
 from pathlib import Path
+from typing import TYPE_CHECKING
 
+if TYPE_CHECKING:
+    from .lovely._monkey_patch_all import Library
 
-
+
+def setup_logging(
     *,
-
-    lovely_numpy: bool = False,
+    lovely: bool | list[Library] = False,
     treescope: bool = False,
     treescope_autovisualize_arrays: bool = False,
     rich: bool = False,
@@ -15,25 +18,10 @@ def init_python_logging(
     log_level: int | str | None = logging.INFO,
     log_save_dir: Path | None = None,
 ):
-    if
-
-            import lovely_tensors as _lovely_tensors  # type: ignore
-
-            _lovely_tensors.monkey_patch()
-        except ImportError:
-            logging.info(
-                "Failed to import `lovely_tensors`. Ignoring pretty PyTorch tensor formatting"
-            )
-
-    if lovely_numpy:
-        try:
-            import lovely_numpy as _lovely_numpy  # type: ignore
+    if lovely:
+        from .lovely._monkey_patch_all import monkey_patch
 
-
-        except ImportError:
-            logging.info(
-                "Failed to import `lovely_numpy`. Ignoring pretty numpy array formatting"
-            )
+        monkey_patch("auto" if lovely is True else lovely)
 
     if treescope:
         try:
@@ -77,49 +65,11 @@ def init_python_logging(
         datefmt="[%X]",
         handlers=log_handlers,
     )
-
-
-
-
-    lovely_tensors: bool = True,
-    lovely_numpy: bool = True,
-    treescope: bool = False,
-    treescope_autovisualize_arrays: bool = False,
-    log_level: int | str | None = logging.INFO,
-    log_save_dir: Path | None = None,
-    rich_log_handler: bool = False,
-    rich_tracebacks: bool = False,
-):
-    init_python_logging(
-        lovely_tensors=lovely_tensors,
-        lovely_numpy=lovely_numpy,
-        treescope=treescope,
-        treescope_autovisualize_arrays=treescope_autovisualize_arrays,
-        rich=rich_log_handler,
-        log_level=log_level,
-        log_save_dir=log_save_dir,
-        rich_tracebacks=rich_tracebacks,
+    logging.info(
+        "Logging initialized. "
+        f"Lovely: {lovely}, Treescope: {treescope}, Rich: {rich}, "
+        f"Log level: {log_level}, Log save dir: {log_save_dir}"
     )
 
 
-
-    *,
-    lovely_tensors: bool = True,
-    lovely_numpy: bool = True,
-    treescope: bool = False,
-    treescope_autovisualize_arrays: bool = False,
-    log_level: int | str | None = logging.INFO,
-    log_save_dir: Path | None = None,
-    rich_log_handler: bool = False,
-    rich_tracebacks: bool = False,
-):
-    pretty(
-        lovely_tensors=lovely_tensors,
-        lovely_numpy=lovely_numpy,
-        treescope=treescope,
-        treescope_autovisualize_arrays=treescope_autovisualize_arrays,
-        log_level=log_level,
-        log_save_dir=log_save_dir,
-        rich_log_handler=rich_log_handler,
-        rich_tracebacks=rich_tracebacks,
-    )
+init_python_logging = setup_logging
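Illustrative call sites for the reworked entry point (option names taken from the hunks above):

    from nshutils import setup_logging

    setup_logging(lovely=True)                  # "auto": patch whichever of numpy/torch/jax is installed
    setup_logging(lovely=["numpy", "torch"])    # or patch an explicit subset
    setup_logging(rich=True, log_level="INFO")  # the remaining logging options are unchanged

    # init_python_logging is kept as an alias of setup_logging for older callers.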
nshutils/lovely/__init__.py
ADDED
@@ -0,0 +1,10 @@
+from __future__ import annotations
+
+from ._monkey_patch_all import monkey_patch as monkey_patch
+from .config import LovelyConfig as LovelyConfig
+from .jax_ import jax_monkey_patch as jax_monkey_patch
+from .jax_ import jax_repr as jax_repr
+from .numpy_ import numpy_monkey_patch as numpy_monkey_patch
+from .numpy_ import numpy_repr as numpy_repr
+from .torch_ import torch_monkey_patch as torch_monkey_patch
+from .torch_ import torch_repr as torch_repr
nshutils/lovely/_base.py
ADDED
@@ -0,0 +1,155 @@
+from __future__ import annotations
+
+import functools
+import importlib.util
+import logging
+from collections.abc import Callable, Iterator
+
+from typing_extensions import ParamSpec, TypeAliasType, TypeVar
+
+from ..util import ContextResource, resource_factory_contextmanager
+from .utils import LovelyStats, format_tensor_stats
+
+log = logging.getLogger(__name__)
+
+TArray = TypeVar("TArray", infer_variance=True)
+P = ParamSpec("P")
+
+LovelyStatsFn = TypeAliasType(
+    "LovelyStatsFn",
+    Callable[[TArray], LovelyStats],
+    type_params=(TArray,),
+)
+LovelyReprFn = TypeAliasType(
+    "LovelyReprFn",
+    Callable[[TArray], str],
+    type_params=(TArray,),
+)
+
+
+def _find_missing_deps(dependencies: list[str]):
+    missing_deps: list[str] = []
+
+    for dep in dependencies:
+        if importlib.util.find_spec(dep) is not None:
+            continue
+
+        missing_deps.append(dep)
+
+    return missing_deps
+
+
+def lovely_repr(dependencies: list[str]):
+    """
+    Decorator to create a lovely representation function for an array.
+
+    Args:
+        dependencies: List of dependencies to check before running the function.
+            If any dependency is not available, the function will not run.
+
+    Returns:
+        A decorator function that takes a function and returns a lovely representation function.
+
+    Example:
+        @lovely_repr(dependencies=["torch"])
+        def my_array_stats(array):
+            return {...}
+    """
+
+    def decorator_fn(array_stats_fn: LovelyStatsFn[TArray]) -> LovelyReprFn[TArray]:
+        """
+        Decorator to create a lovely representation function for an array.
+
+        Args:
+            array_stats_fn: A function that takes an array and returns its stats.
+
+        Returns:
+            A function that takes an array and returns its lovely representation.
+        """
+
+        @functools.wraps(array_stats_fn)
+        def wrapper(array: TArray) -> str:
+            if missing_deps := _find_missing_deps(dependencies):
+                log.warning(
+                    f"Missing dependencies: {', '.join(missing_deps)}. "
+                    "Skipping lovely representation."
+                )
+                return repr(array)
+
+            stats = array_stats_fn(array)
+            return format_tensor_stats(stats)
+
+        return wrapper
+
+    return decorator_fn
+
+
+LovelyMonkeyPatchInputFn = TypeAliasType(
+    "LovelyMonkeyPatchInputFn",
+    Callable[P, Iterator[None]],
+    type_params=(P,),
+)
+LovelyMonkeyPatchFn = TypeAliasType(
+    "LovelyMonkeyPatchFn",
+    Callable[P, ContextResource[None]],
+    type_params=(P,),
+)
+
+
+def _nullcontext_generator():
+    """A generator that does nothing."""
+    yield
+
+
+def _wrap_monkey_patch_fn(
+    monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
+    dependencies: list[str],
+) -> LovelyMonkeyPatchInputFn[P]:
+    @functools.wraps(monkey_patch_fn)
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterator[None]:
+        if missing_deps := _find_missing_deps(dependencies):
+            log.warning(
+                f"Missing dependencies: {', '.join(missing_deps)}. "
+                "Skipping monkey patch."
+            )
+            return _nullcontext_generator()
+
+        return monkey_patch_fn(*args, **kwargs)
+
+    return wrapper
+
+
+def monkey_patch_contextmanager(dependencies: list[str]):
+    """
+    Decorator to create a monkey patch function for an array.
+
+    Args:
+        dependencies: List of dependencies to check before running the function.
+            If any dependency is not available, the function will not run.
+
+    Returns:
+        A decorator function that takes a function and returns a monkey patch function.
+
+    Example:
+        @monkey_patch_contextmanager(dependencies=["torch"])
+        def my_array_monkey_patch():
+            ...
+    """
+
+    def decorator_fn(
+        monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
+    ) -> LovelyMonkeyPatchFn[P]:
+        """
+        Decorator to create a monkey patch function for an array.
+
+        Args:
+            monkey_patch_fn: A function that applies the monkey patch.
+
+        Returns:
+            A function that applies the monkey patch.
+        """
+
+        wrapped_fn = _wrap_monkey_patch_fn(monkey_patch_fn, dependencies)
+        return resource_factory_contextmanager(wrapped_fn)
+
+    return decorator_fn
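A hypothetical extension built on the two decorators above — only `lovely_repr`, `monkey_patch_contextmanager`, and the `LovelyStats` keys come from the diff; the rest is made up for illustration:

    from nshutils.lovely._base import lovely_repr, monkey_patch_contextmanager


    @lovely_repr(dependencies=["numpy"])
    def my_repr(array):
        # Return a LovelyStats dict; format_tensor_stats() turns it into the final string.
        return {"shape": array.shape, "size": array.size, "type_name": "array"}


    @monkey_patch_contextmanager(dependencies=["numpy"])
    def my_monkey_patch():
        # ... install my_repr somewhere ...
        yield
        # ... restore the original on exit ...

If numpy were missing, the first decorator falls back to the plain `repr()` and the second becomes a no-op context, both after logging a warning.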
nshutils/lovely/_monkey_patch_all.py
ADDED
@@ -0,0 +1,68 @@
+from __future__ import annotations
+
+import contextlib
+import importlib.util
+import logging
+from typing import Literal
+
+from typing_extensions import TypeAliasType, assert_never
+
+from ..util import resource_factory_contextmanager
+
+Library = TypeAliasType("Library", Literal["numpy", "torch", "jax"])
+
+log = logging.getLogger(__name__)
+
+
+def _find_deps() -> list[Library]:
+    """
+    Find available libraries for monkey patching.
+    """
+    deps: list[Library] = []
+    if importlib.util.find_spec("torch") is not None:
+        deps.append("torch")
+    if importlib.util.find_spec("jax") is not None:
+        deps.append("jax")
+    if importlib.util.find_spec("numpy") is not None:
+        deps.append("numpy")
+    return deps
+
+
+@resource_factory_contextmanager
+def monkey_patch(libraries: list[Library] | Literal["auto"] = "auto"):
+    if libraries == "auto":
+        libraries = _find_deps()
+
+    if not libraries:
+        raise ValueError(
+            "No libraries found for monkey patching. "
+            "Please install numpy, torch, or jax."
+        )
+
+    with contextlib.ExitStack() as stack:
+        for library in libraries:
+            match library:
+                case "torch":
+                    from .torch_ import torch_monkey_patch
+
+                    stack.enter_context(torch_monkey_patch())
+                case "jax":
+                    from .jax_ import jax_monkey_patch
+
+                    stack.enter_context(jax_monkey_patch())
+                case "numpy":
+                    from .numpy_ import numpy_monkey_patch
+
+                    stack.enter_context(numpy_monkey_patch())
+                case _:
+                    assert_never(library)
+
+        log.info(
+            f"Monkey patched libraries: {', '.join(libraries)}. "
+            "You can now use the lovely functions with these libraries."
+        )
+        yield
+        log.info(
+            f"Unmonkey patched libraries: {', '.join(libraries)}. "
+            "You can now use the lovely functions with these libraries."
+        )
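Because `monkey_patch` is wrapped by `resource_factory_contextmanager` (see nshutils/util.py below), it can either scope the patch to a block or stay active until explicitly closed — an illustrative sketch:

    from nshutils.lovely import monkey_patch

    # Scoped: reprs are patched for the duration of the block and restored on exit.
    with monkey_patch("auto"):
        ...

    # Open-ended: the patch is applied at call time and undone on .close().
    patch = monkey_patch(["numpy"])
    ...
    patch.close()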
nshutils/lovely/config.py
ADDED
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+from typing import ClassVar
+
+import nshconfig as C
+from typing_extensions import Self
+
+
+class LovelyConfig(C.Config):
+    """
+    This class is used to manage the configuration of the Lovely library.
+    It inherits from the Config class in the nshconfig module.
+    """
+
+    precision: int = 3
+    """Number of digits after the decimal point."""
+
+    threshold_max: int = 3
+    """Absolute values larger than 10^3 use scientific notation."""
+
+    threshold_min: int = -4
+    """Absolute values smaller than 10^-4 use scientific notation."""
+
+    sci_mode: bool | None = None
+    """Force scientific notation (None=auto)."""
+
+    show_mem_above: int = 1024
+    """Show memory size if above this threshold (bytes)."""
+
+    color: bool = True
+    """Use ANSI colors in text."""
+
+    indent: int = 2
+    """Indent for nested representation."""
+
+    config_instance: ClassVar[Self | None] = None
+    """Singleton instance of the LovelyConfig class."""
+
+    @classmethod
+    def instance(cls) -> Self:
+        """
+        Get the singleton instance of the LovelyConfig class.
+        If it doesn't exist, create it.
+        """
+        if cls.config_instance is None:
+            cls.config_instance = cls()
+        return cls.config_instance
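Since the formatters read `LovelyConfig.instance()` at render time, adjusting the singleton changes all subsequent output. A sketch, assuming the nshconfig model permits attribute assignment:

    from nshutils.lovely import LovelyConfig

    cfg = LovelyConfig.instance()
    cfg.precision = 5   # more digits in printed stats
    cfg.color = False   # drop ANSI colors, e.g. when writing to a log file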
nshutils/lovely/jax_.py
ADDED
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, cast
+
+import numpy as np
+
+from ._base import lovely_repr, monkey_patch_contextmanager
+from .utils import LovelyStats, array_stats, patch_to
+
+if TYPE_CHECKING:
+    import jax
+
+
+def _type_name(array: jax.Array):
+    type_name = type(array).__name__.rsplit(".", 1)[-1]
+    return "array" if type_name == "ArrayImpl" else type_name
+
+
+_DT_NAMES = {
+    "float16": "f16",
+    "float32": "f32",
+    "float64": "f64",
+    "uint8": "u8",
+    "uint16": "u16",
+    "uint32": "u32",
+    "uint64": "u64",
+    "int8": "i8",
+    "int16": "i16",
+    "int32": "i32",
+    "int64": "i64",
+    "bfloat16": "bf16",
+    "complex64": "c64",
+    "complex128": "c128",
+}
+
+
+def _dtype_str(array: jax.Array) -> str:
+    dtype_base = str(array.dtype).rsplit(".", 1)[-1]
+    dtype_base = _DT_NAMES.get(dtype_base, dtype_base)
+    return dtype_base
+
+
+def _device(array: jax.Array) -> str:
+    from jaxlib.xla_extension import Device
+
+    if callable(device := array.device):
+        device = device()
+
+    device = cast(Device, device)
+    if device.platform == "cpu":
+        return "cpu"
+
+    return f"{device.platform}:{device.id}"
+
+
+@lovely_repr(dependencies=["jax"])
+def jax_repr(array: jax.Array) -> LovelyStats:
+    import jax.numpy as jnp
+
+    return {
+        # Basic attributes
+        "shape": array.shape,
+        "size": array.size,
+        "nbytes": array.nbytes,
+        "type_name": _type_name(array),
+        # Dtype
+        "dtype_str": _dtype_str(array),
+        "is_complex": jnp.iscomplexobj(array),
+        # Device
+        "device": _device(array),
+        # Depending of whether the tensor is complex or not, we will call the appropriate stats function
+        **array_stats(np.asarray(array)),
+    }
+
+
+@monkey_patch_contextmanager(dependencies=["jax"])
+def jax_monkey_patch():
+    from jax._src import array
+
+    prev_repr = array.ArrayImpl.__repr__
+    prev_str = array.ArrayImpl.__str__
+    try:
+        patch_to(array.ArrayImpl, "__repr__", jax_repr)
+        patch_to(array.ArrayImpl, "__str__", jax_repr)
+
+        yield
+    finally:
+        patch_to(array.ArrayImpl, "__repr__", prev_repr)
+        patch_to(array.ArrayImpl, "__str__", prev_str)
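`jax_repr` (like `numpy_repr` and `torch_repr` below) is also usable as a plain function, without patching anything — illustrative:

    import jax.numpy as jnp

    from nshutils.lovely import jax_repr

    x = jnp.linspace(0.0, 1.0, 5)
    print(jax_repr(x))  # one-line summary: type/shape, dtype, device, min/max/mean/std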
nshutils/lovely/numpy_.py
ADDED
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+import logging
+
+import numpy as np
+
+from ._base import lovely_repr, monkey_patch_contextmanager
+from .utils import LovelyStats, array_stats
+
+
+def _type_name(array: np.ndarray):
+    return (
+        "array"
+        if type(array) is np.ndarray
+        else type(array).__name__.rsplit(".", 1)[-1]
+    )
+
+
+_DT_NAMES = {
+    "float16": "f16",
+    "float32": "f32",
+    "float64": "",  # Default dtype in numpy
+    "uint8": "u8",
+    "uint16": "u16",
+    "uint32": "u32",
+    "uint64": "u64",
+    "int8": "i8",
+    "int16": "i16",
+    "int32": "i32",
+    "int64": "i64",
+    "complex64": "c64",
+    "complex128": "c128",
+}
+
+
+def _dtype_str(array: np.ndarray) -> str:
+    dtype_base = str(array.dtype).rsplit(".", 1)[-1]
+    dtype_base = _DT_NAMES.get(dtype_base, dtype_base)
+    return dtype_base
+
+
+@lovely_repr(dependencies=["numpy"])
+def numpy_repr(array: np.ndarray) -> LovelyStats:
+    return {
+        # Basic attributes
+        "shape": array.shape,
+        "size": array.size,
+        "nbytes": array.nbytes,
+        "type_name": _type_name(array),
+        # Dtype
+        "dtype_str": _dtype_str(array),
+        "is_complex": np.iscomplexobj(array),
+        # Depending of whether the tensor is complex or not, we will call the appropriate stats function
+        **array_stats(array),
+    }
+
+
+@monkey_patch_contextmanager(dependencies=["numpy"])
+def numpy_monkey_patch():
+    try:
+        np.set_printoptions(override_repr=numpy_repr)
+        logging.info(
+            f"Numpy monkey patching: using {numpy_repr.__name__} for numpy arrays. "
+            f"{np.get_printoptions()=}"
+        )
+        yield
+    finally:
+        np.set_printoptions(override_repr=None)
+        logging.info(
+            f"Numpy unmonkey patching: using {numpy_repr.__name__} for numpy arrays. "
+            f"{np.get_printoptions()=}"
+        )
nshutils/lovely/torch_.py
ADDED
@@ -0,0 +1,99 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+from ._base import lovely_repr, monkey_patch_contextmanager
+from .utils import LovelyStats, array_stats, patch_to
+
+if TYPE_CHECKING:
+    import torch
+
+
+def _type_name(tensor: torch.Tensor):
+    import torch
+
+    return (
+        "tensor"
+        if type(tensor) is torch.Tensor
+        else type(tensor).__name__.split(".")[-1]
+    )
+
+
+_DT_NAMES = {
+    "float32": "",  # Default dtype
+    "float16": "f16",
+    "float64": "f64",
+    "bfloat16": "bf16",
+    "uint8": "u8",
+    "int8": "i8",
+    "int16": "i16",
+    "int32": "i32",
+    "int64": "i64",
+    "complex32": "c32",
+    "complex64": "c64",
+    "complex128": "c128",
+}
+
+
+def _dtype_str(tensor: torch.Tensor) -> str:
+    dtype_base = str(tensor.dtype).rsplit(".", 1)[-1]
+    dtype_base = _DT_NAMES.get(dtype_base, dtype_base)
+    return dtype_base
+
+
+def _to_np(tensor: torch.Tensor) -> np.ndarray:
+    import torch
+
+    # Get tensor data as CPU NumPy array for analysis
+    t_cpu = tensor.detach().cpu()
+
+    # Convert bfloat16 to float32 for numpy compatibility
+    if tensor.dtype == torch.bfloat16:
+        t_cpu = t_cpu.to(torch.float32)
+
+    # Convert to NumPy
+    t_np = t_cpu.numpy()
+
+    return t_np
+
+
+@lovely_repr(dependencies=["torch"])
+def torch_repr(tensor: torch.Tensor) -> LovelyStats:
+    return {
+        # Basic attributes
+        "shape": tensor.shape,
+        "size": tensor.numel(),
+        "nbytes": tensor.element_size() * tensor.numel(),
+        "type_name": _type_name(tensor),
+        # Device
+        "device": str(tensor.device) if tensor.device else None,
+        "is_meta": device.type == "meta" if (device := tensor.device) else False,
+        # Grad
+        "requires_grad": tensor.requires_grad,
+        # Dtype
+        "dtype_str": _dtype_str(tensor),
+        "is_complex": tensor.is_complex(),
+        # Depending of whether the tensor is complex or not, we will call the appropriate stats function
+        **array_stats(_to_np(tensor)),
+    }
+
+
+@monkey_patch_contextmanager(dependencies=["torch"])
+def torch_monkey_patch():
+    import torch
+
+    original_repr = torch.Tensor.__repr__
+    original_str = torch.Tensor.__str__
+    original_parameter_repr = torch.nn.Parameter.__repr__
+    try:
+        patch_to(torch.Tensor, "__repr__", torch_repr)
+        patch_to(torch.Tensor, "__str__", torch_repr)
+        del torch.nn.Parameter.__repr__
+
+        yield
+    finally:
+        patch_to(torch.Tensor, "__repr__", original_repr)
+        patch_to(torch.Tensor, "__str__", original_str)
+        patch_to(torch.nn.Parameter, "__repr__", original_parameter_repr)
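An illustrative use of the torch patch on its own; note that `torch.nn.Parameter.__repr__` is removed for the duration, so parameters fall back to the patched `Tensor` repr:

    import torch

    from nshutils.lovely import torch_monkey_patch

    with torch_monkey_patch():
        print(torch.randn(2, 3))             # compact one-line summary
        print(torch.nn.Linear(3, 1).weight)  # Parameters use the same summary
    print(torch.randn(2, 3))                 # original torch repr is restored here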
nshutils/lovely/utils.py
ADDED
@@ -0,0 +1,345 @@
+from __future__ import annotations
+
+import sys
+from collections import defaultdict
+from collections.abc import Callable, Sequence
+from typing import Any
+
+import numpy as np
+from typing_extensions import TypedDict
+
+from .config import LovelyConfig
+
+
+class LovelyStats(TypedDict, total=False):
+    """Statistics for tensor representation."""
+
+    # Basic tensor information
+    shape: Sequence[int]
+    size: int
+    type_name: str
+    dtype_str: str
+    device: str | None
+    nbytes: int
+    requires_grad: bool
+    is_meta: bool
+
+    # Content flags
+    all_zeros: bool
+    has_nan: bool
+    has_pos_inf: bool
+    has_neg_inf: bool
+    is_complex: bool
+
+    # Numeric statistics
+    min: float | None
+    max: float | None
+    mean: float | None
+    std: float | None
+
+    # Complex number statistics
+    mag_min: float | None
+    mag_max: float | None
+    real_min: float | None
+    real_max: float | None
+    imag_min: float | None
+    imag_max: float | None
+
+    # Representation
+    values_str: str | None
+
+
+# Formatting utilities
+def sci_mode(f: float) -> bool:
+    """Determine if a float should be displayed in scientific notation."""
+    config = LovelyConfig.instance()
+    return (abs(f) < 10**config.threshold_min) or (abs(f) > 10**config.threshold_max)
+
+
+def pretty_str(x: Any) -> str:
+    """Format a number or array for pretty display.
+
+    Works with scalars, numpy arrays, torch tensors, and jax arrays.
+    """
+    if isinstance(x, int):
+        return f"{x}"
+    elif isinstance(x, float):
+        if x == 0.0:
+            return "0."
+
+        sci = (
+            sci_mode(x)
+            if LovelyConfig.instance().sci_mode is None
+            else LovelyConfig.instance().sci_mode
+        )
+        fmt = f"{{:.{LovelyConfig.instance().precision}{'e' if sci else 'f'}}}"
+        return fmt.format(x)
+    elif isinstance(x, complex):
+        # Handle complex numbers
+        real_part = pretty_str(x.real)
+        imag_part = pretty_str(abs(x.imag))
+        sign = "+" if x.imag >= 0 else "-"
+        return f"{real_part}{sign}{imag_part}j"
+
+    # Handle array-like objects
+    try:
+        if hasattr(x, "ndim") and x.ndim == 0:
+            return pretty_str(x.item())
+        elif hasattr(x, "shape") and len(x.shape) > 0:
+            slices = [pretty_str(x[i]) for i in range(min(x.shape[0], 10))]
+            if x.shape[0] > 10:
+                slices.append("...")
+            return "[" + ", ".join(slices) + "]"
+    except:
+        pass
+
+    # Fallback
+    return str(x)
+
+
+def sparse_join(items: list[str | None], sep: str = " ") -> str:
+    """Join non-empty strings with a separator."""
+    return sep.join([item for item in items if item])
+
+
+def ansi_color(s: str, col: str, use_color: bool = True) -> str:
+    """Add ANSI color to a string if use_color is True."""
+    if not use_color:
+        return s
+
+    style = defaultdict(str)
+    style["grey"] = "\x1b[38;2;127;127;127m"
+    style["red"] = "\x1b[31m"
+    end_style = "\x1b[0m"
+
+    return style[col] + s + end_style
+
+
+def bytes_to_human(num_bytes: int) -> str:
+    """Convert bytes to a human-readable string (b, Kb, Mb, Gb)."""
+    units = ["b", "Kb", "Mb", "Gb"]
+
+    value = num_bytes
+    matched_unit: str | None = None
+    for unit in units:
+        if value < 1024 / 10:
+            matched_unit = unit
+            break
+        value /= 1024.0
+
+    assert matched_unit is not None, "No matching unit found"
+
+    if value % 1 == 0 or value >= 10:
+        return f"{round(value)}{matched_unit}"
+    else:
+        return f"{value:.1f}{matched_unit}"
+
+
+def in_debugger() -> bool:
+    """Returns True if running in a debugger."""
+    return getattr(sys, "gettrace", None) is not None and sys.gettrace() is not None
+
+
+# Common tensor representation
+def format_tensor_stats(tensor_stats: LovelyStats, color: bool | None = None) -> str:
+    """Format tensor stats into a pretty string representation."""
+    conf = LovelyConfig.instance()
+    if color is None:
+        color = conf.color
+    if in_debugger():
+        color = False
+
+    # Basic tensor info
+    shape_str = str(list(shape)) if (shape := tensor_stats.get("shape")) else None
+    type_str = sparse_join([tensor_stats.get("type_name"), shape_str], sep="")
+
+    # Calculate memory usage
+    numel = None
+    if (size := tensor_stats.get("size")) and (nbytes := tensor_stats.get("nbytes")):
+        shape = tensor_stats.get("shape", [])
+
+        if shape and max(shape) != size:
+            numel = f"n={size}"
+            if conf.show_mem_above <= nbytes:
+                numel = sparse_join([numel, f"({bytes_to_human(nbytes)})"])
+        elif conf.show_mem_above <= nbytes:
+            numel = bytes_to_human(nbytes)
+
+    # Handle empty tensors
+    if tensor_stats.get("size", 0) == 0:
+        common = ansi_color("empty", "grey", color)
+    # Handle all zeros
+    elif tensor_stats.get("all_zeros"):
+        common = ansi_color("all_zeros", "grey", color)
+    # Handle complex tensors
+    elif tensor_stats.get("is_complex"):
+        complex_info = []
+
+        # For magnitude stats
+        if (mag_min := tensor_stats.get("mag_min")) is not None and (
+            mag_max := tensor_stats.get("mag_max")
+        ) is not None:
+            complex_info.append(f"|z|∈[{pretty_str(mag_min)}, {pretty_str(mag_max)}]")
+
+        # For real part stats
+        if (real_min := tensor_stats.get("real_min")) is not None and (
+            real_max := tensor_stats.get("real_max")
+        ) is not None:
+            complex_info.append(f"Re∈[{pretty_str(real_min)}, {pretty_str(real_max)}]")
+
+        # For imaginary part stats
+        if (imag_min := tensor_stats.get("imag_min")) is not None and (
+            imag_max := tensor_stats.get("imag_max")
+        ) is not None:
+            complex_info.append(f"Im∈[{pretty_str(imag_min)}, {pretty_str(imag_max)}]")
+
+        common = sparse_join(complex_info)
+    # Handle normal tensors with stats
+    elif (min_val := tensor_stats.get("min")) is not None and (
+        max_val := tensor_stats.get("max")
+    ) is not None:
+        minmax = None
+        meanstd = None
+
+        if tensor_stats.get("size", 0) > 2:
+            minmax = f"x∈[{pretty_str(min_val)}, {pretty_str(max_val)}]"
+
+        if (
+            (mean := tensor_stats.get("mean")) is not None
+            and (std := tensor_stats.get("std")) is not None
+            and tensor_stats.get("size", 0) >= 2
+        ):
+            meanstd = f"μ={pretty_str(mean)} σ={pretty_str(std)}"
+
+        common = sparse_join([minmax, meanstd])
+    else:
+        common = None
+
+    # Handle warnings
+    warnings = []
+    if tensor_stats.get("has_nan"):
+        warnings.append(ansi_color("NaN!", "red", color))
+    if tensor_stats.get("has_pos_inf"):
+        warnings.append(ansi_color("+Inf!", "red", color))
+    if tensor_stats.get("has_neg_inf"):
+        warnings.append(ansi_color("-Inf!", "red", color))
+
+    attention = sparse_join(warnings)
+    common = sparse_join([common, attention])
+
+    # Other tensor attributes
+    dtype = tensor_stats.get("dtype_str", "")
+    device = tensor_stats.get("device")
+    grad = "grad" if tensor_stats.get("requires_grad") else None
+
+    # Format values for small tensors
+    vals = None
+    if (
+        0 < tensor_stats.get("size", 0) <= 10
+        and tensor_stats.get("is_meta", False) is False
+    ):
+        vals = tensor_stats.get("values_str")
+
+    # Join all parts
+    result = sparse_join([type_str, dtype, numel, common, grad, device, vals])
+
+    return result
+
+
+def real_stats(array: np.ndarray) -> LovelyStats:
+    stats: LovelyStats = {}
+
+    # Check for special values
+    stats["has_nan"] = bool(np.isnan(array).any())
+    stats["has_pos_inf"] = bool(np.isposinf(array).any())
+    stats["has_neg_inf"] = bool(np.isneginf(array).any())
+
+    # Only compute min/max/mean/std for good data
+    good_data = array[np.isfinite(array)]
+
+    if len(good_data) > 0:
+        stats["min"] = float(good_data.min())
+        stats["max"] = float(good_data.max())
+        stats["all_zeros"] = stats["min"] == 0 and stats["max"] == 0 and array.size > 1
+
+    if len(good_data) > 1:
+        stats["mean"] = float(good_data.mean())
+        stats["std"] = float(good_data.std())
+
+    # Get string representation of values for small tensors
+    if 0 < array.size <= 10:
+        stats["values_str"] = pretty_str(array)
+
+    return stats
+
+
+def complex_stats(array: np.ndarray) -> LovelyStats:
+    stats: LovelyStats = {}
+
+    # Calculate magnitude (absolute value)
+    magnitude = np.abs(array)
+
+    # Check for special values in real or imaginary parts
+    stats["has_nan"] = bool(np.isnan(array.real).any() or np.isnan(array.imag).any())
+
+    # Get statistics for magnitude
+    good_mag = magnitude[np.isfinite(magnitude)]
+    if len(good_mag) > 0:
+        stats["mag_min"] = float(good_mag.min())
+        stats["mag_max"] = float(good_mag.max())
+        stats["all_zeros"] = (
+            stats["mag_min"] == 0 and stats["mag_max"] == 0 and array.size > 1
+        )
+
+    # Get statistics for real and imaginary parts
+    real_part = array.real
+    imag_part = array.imag
+
+    good_real = real_part[np.isfinite(real_part)]
+    good_imag = imag_part[np.isfinite(imag_part)]
+
+    if len(good_real := real_part[np.isfinite(real_part)]):
+        stats["real_min"] = float(good_real.min())
+        stats["real_max"] = float(good_real.max())
+
+    if len(good_imag := imag_part[np.isfinite(imag_part)]):
+        stats["imag_min"] = float(good_imag.min())
+        stats["imag_max"] = float(good_imag.max())
+
+    # Get string representation of values for small tensors
+    if 0 < array.size <= 10:
+        stats["values_str"] = pretty_str(array)
+
+    return stats
+
+
+def array_stats(array: np.ndarray, ignore_empty: bool = True) -> LovelyStats:
+    """Compute all statistics for a given array.
+
+    Args:
+        array (np.ndarray): The input array.
+        ignore_empty (bool): If True, ignore empty arrays.
+
+    Returns:
+        LovelyStats: A dictionary containing the computed statistics.
+    """
+    if ignore_empty and array.size == 0:
+        return {}
+
+    if np.iscomplexobj(array):
+        return complex_stats(array)
+    else:
+        return real_stats(array)
+
+
+def patch_to(
+    cls: type,
+    name: str,
+    func: Callable[[Any], Any],
+    as_property: bool = False,
+) -> None:
+    """Simple patch_to implementation to avoid fastcore dependency."""
+    if as_property:
+        setattr(cls, name, property(func))
+    else:
+        setattr(cls, name, func)
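Illustrative direct use of the helpers above (the rendered text depends on LovelyConfig settings, so it is not shown):

    import numpy as np

    from nshutils.lovely.utils import array_stats, format_tensor_stats

    x = np.array([1.0, 2.0, float("nan")])
    stats = array_stats(x)             # {'has_nan': True, 'min': 1.0, 'max': 2.0, ...}
    print(format_tensor_stats(stats))  # compact one-line summary with a NaN! warning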
nshutils/typecheck.py
CHANGED
@@ -1,9 +1,9 @@
 from __future__ import annotations
 
+import logging
 import os
 import sys
 from collections.abc import Sequence
-from logging import getLogger
 from types import FrameType as _FrameType
 from typing import Any
 
@@ -52,7 +52,8 @@ try:
     import jax  # type: ignore
 except ImportError:
     jax = None
-
+
+log = logging.getLogger(__name__)
 
 DISABLE_ENV_KEY = "NSHUTILS_DISABLE_TYPECHECKING"
 
nshutils/util.py
ADDED
@@ -0,0 +1,92 @@
+from __future__ import annotations
+
+import contextlib
+import functools
+from collections.abc import Callable, Iterator
+from typing import Any, Generic
+
+from typing_extensions import ParamSpec, TypeVar, override
+
+R = TypeVar("R")
+P = ParamSpec("P")
+
+
+class ContextResource(contextlib.AbstractContextManager[R], Generic[R]):
+    """A class that provides both direct access to a resource and context management."""
+
+    def __init__(self, resource: R, cleanup_func: Callable[[R], Any]):
+        self.resource = resource
+        self._cleanup_func = cleanup_func
+
+    @override
+    def __enter__(self) -> R:
+        """When used as a context manager, return the wrapped resource."""
+        return self.resource
+
+    @override
+    def __exit__(self, *exc_info) -> None:
+        """Clean up the resource when exiting the context."""
+        self._cleanup_func(self.resource)
+
+    def close(self) -> None:
+        """Explicitly clean up the resource."""
+        self._cleanup_func(self.resource)
+
+
+def resource_factory(
+    create_func: Callable[P, R], cleanup_func: Callable[[R], None]
+) -> Callable[P, ContextResource[R]]:
+    """
+    Create a factory function that returns a ContextResource.
+
+    Args:
+        create_func: Function that creates the resource
+        cleanup_func: Function that cleans up the resource
+
+    Returns:
+        A function that returns a ContextResource wrapping the created resource
+    """
+
+    @functools.wraps(create_func)
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> ContextResource[R]:
+        resource = create_func(*args, **kwargs)
+        return ContextResource(resource, cleanup_func)
+
+    return wrapper
+
+
+def resource_factory_from_context_fn(
+    context_func: Callable[P, contextlib.AbstractContextManager[R]],
+) -> Callable[P, ContextResource[R]]:
+    """
+    Create a factory function that returns a ContextResource.
+
+    Args:
+        context_func: Function that creates the resource
+
+    Returns:
+        A function that returns a ContextResource wrapping the created resource
+    """
+
+    @functools.wraps(context_func)
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> ContextResource[R]:
+        context = context_func(*args, **kwargs)
+        resource = context.__enter__()
+        return ContextResource(resource, lambda _: context.__exit__(None, None, None))
+
+    return wrapper
+
+
+def resource_factory_contextmanager(
+    context_func: Callable[P, Iterator[R]],
+) -> Callable[P, ContextResource[R]]:
+    """
+    Create a factory function that returns a ContextResource.
+
+    Args:
+        context_func: Generator function that creates the resource, yields it, and cleans up the resource when done.
+
+    Returns:
+        A function that returns a ContextResource wrapping the created resource
+    """
+    return resource_factory_from_context_fn(contextlib.contextmanager(context_func))
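Illustrative use of `resource_factory_contextmanager` (the temp-directory resource is hypothetical; only the decorator and ContextResource come from the file above):

    import shutil
    import tempfile

    from nshutils.util import resource_factory_contextmanager


    @resource_factory_contextmanager
    def scratch_dir():
        path = tempfile.mkdtemp()
        try:
            yield path  # the yielded value becomes ContextResource.resource
        finally:
            shutil.rmtree(path)


    with scratch_dir() as path:  # cleaned up automatically when the block exits
        ...

    handle = scratch_dir()       # or keep it open and clean up explicitly
    print(handle.resource)
    handle.close()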
{nshutils-0.19.1.dist-info → nshutils-0.21.0.dist-info}/METADATA
CHANGED
@@ -1,24 +1,30 @@
 Metadata-Version: 2.3
 Name: nshutils
-Version: 0.19.1
+Version: 0.21.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
-Requires-Python: >=3.
+Requires-Python: >=3.9,<4.0
 Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Provides-Extra: extra
+Provides-Extra: pprint
+Provides-Extra: snoop
 Requires-Dist: beartype
 Requires-Dist: jaxtyping
-Requires-Dist:
-Requires-Dist:
+Requires-Dist: lazy-loader
+Requires-Dist: nshconfig
 Requires-Dist: numpy
 Requires-Dist: pysnooper ; extra == "extra"
-Requires-Dist:
-Requires-Dist:
+Requires-Dist: pysnooper ; extra == "snoop"
+Requires-Dist: rich[jupyter] ; extra == "extra"
+Requires-Dist: rich[jupyter] ; extra == "pprint"
+Requires-Dist: treescope ; (python_version >= "3.10") and (extra == "extra")
+Requires-Dist: treescope ; (python_version >= "3.10") and (extra == "pprint")
 Requires-Dist: typing-extensions
 Requires-Dist: uuid7
 Project-URL: Homepage, https://github.com/nimashoghi/nshutils
nshutils-0.21.0.dist-info/RECORD
ADDED
@@ -0,0 +1,22 @@
+nshutils/__init__.py,sha256=AFx1d5k34MyJ2kCHQL5vrZB8GDp2nYUaIUEjszSa25I,477
+nshutils/__init__.pyi,sha256=ICbY2_XBAlXIVOGyK4PQpatmlUFHHc5-bqM4sfFZoAY,613
+nshutils/actsave/__init__.py,sha256=hAVsog9d1g3_rQN1TRslrl6sK1PhCGbjy8PPUAmJI58,203
+nshutils/actsave/_loader.py,sha256=mof3HezUNvLliz7macstX6ewXW05L0Mtv3zJyrbmImg,4640
+nshutils/actsave/_saver.py,sha256=LulgC_B7oYtsGjW_I_pnSLrf1k9lmMvHqRxcmfqqrjU,10441
+nshutils/collections.py,sha256=QWGyANmo4Efq4XRNHDSTE9tRLStwEZHGwE0ATHR-Vqo,5233
+nshutils/display.py,sha256=Ge63yllx7gi-MKL3mKQeQ5doql_nj56-o5aoTVmusDg,1473
+nshutils/logging.py,sha256=78pv3-I_gmbKSf5_mYYBr6_H4GNBGErghAdhH9wfYIc,2205
+nshutils/lovely/__init__.py,sha256=gbWMNs7xfK1CiNdkHvfH0KcyaGjdZ8_WUBGfaEUDN4I,451
+nshutils/lovely/_base.py,sha256=c2XxNJlEdET2mP2gLzlYY1KHsEN4H9eDD_x8SptuBTA,4277
+nshutils/lovely/_monkey_patch_all.py,sha256=WZsC6Xp5-Z2GBd6xyZZEsD4C2xNyZy0YBfjZzAy3m8M,2028
+nshutils/lovely/config.py,sha256=jsUK_kfEvthL94qxpHxM-Xobdv67sWZpnH_ag4HLTNo,1274
+nshutils/lovely/jax_.py,sha256=mPH-tSOzWE27ymupllBnO4O6avvJJxpF2a1-G4dN5Ow,2214
+nshutils/lovely/numpy_.py,sha256=iHA4BCJIW9IU6DXKEfYbh9RA2xyeXeL0tyHdEja89Sw,1866
+nshutils/lovely/torch_.py,sha256=3wnXLa-1xwuQVk1fM50mBqOVDv5wZHjfnBwzOnjcFjg,2638
+nshutils/lovely/utils.py,sha256=2ksT5YGVViFuWc8jSkwVCsABripJmyVJdEDDH7aab70,10459
+nshutils/snoop.py,sha256=7d7_Q5sJmINL1J29wcnxEvpV95zvZYNoVn5frCq-rww,7393
+nshutils/typecheck.py,sha256=Gi7xtfilN_UwZ1FTFqBVKDhcQzBEDonVxIv3bUj-uXY,5582
+nshutils/util.py,sha256=tx-XiRbOrpafV3OkJDE5IVFtzn3kN7uSZ8FkMor0H5c,2845
+nshutils-0.21.0.dist-info/METADATA,sha256=Qa4bVvnzN2Ip9fbmuaToECZbRSKhgj-zmaNqTKmXXJ4,4431
+nshutils-0.21.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+nshutils-0.21.0.dist-info/RECORD,,
nshutils-0.19.1.dist-info/RECORD
DELETED
@@ -1,12 +0,0 @@
-nshutils/__init__.py,sha256=9RJO-Bt7uN-KHf7Yi9G7s_5AL8n-dwF2SClNMNQpiQE,985
-nshutils/actsave/__init__.py,sha256=hAVsog9d1g3_rQN1TRslrl6sK1PhCGbjy8PPUAmJI58,203
-nshutils/actsave/_loader.py,sha256=mof3HezUNvLliz7macstX6ewXW05L0Mtv3zJyrbmImg,4640
-nshutils/actsave/_saver.py,sha256=Kor7PEk__noRDEAfCZ_I4vxql5UfWSIQIQ1OmW2RRTI,10290
-nshutils/collections.py,sha256=QWGyANmo4Efq4XRNHDSTE9tRLStwEZHGwE0ATHR-Vqo,5233
-nshutils/display.py,sha256=Ge63yllx7gi-MKL3mKQeQ5doql_nj56-o5aoTVmusDg,1473
-nshutils/logging.py,sha256=-6IB0GTDDS8ue1H2tzkv_OLf4bZVN1ywL08TlDZWbtQ,3737
-nshutils/snoop.py,sha256=7d7_Q5sJmINL1J29wcnxEvpV95zvZYNoVn5frCq-rww,7393
-nshutils/typecheck.py,sha256=UOUYfa72wTmc-a7VQw52tKFb4U10xq1qcZuEWc2sAd8,5588
-nshutils-0.19.1.dist-info/METADATA,sha256=nxPB1rd9Lv_CnL03BAMWGvBnRXYZ4VLsFTYpyAlhSAY,4168
-nshutils-0.19.1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-nshutils-0.19.1.dist-info/RECORD,,