nshutils 0.19.1__tar.gz → 0.20.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshutils-0.19.1 → nshutils-0.20.0}/PKG-INFO +9 -4
- {nshutils-0.19.1 → nshutils-0.20.0}/pyproject.toml +13 -3
- nshutils-0.20.0/src/nshutils/__init__.py +20 -0
- nshutils-0.19.1/src/nshutils/__init__.py → nshutils-0.20.0/src/nshutils/__init__.pyi +2 -18
- {nshutils-0.19.1 → nshutils-0.20.0}/src/nshutils/actsave/_saver.py +8 -6
- nshutils-0.20.0/src/nshutils/logging.py +75 -0
- nshutils-0.20.0/src/nshutils/lovely/__init__.py +10 -0
- nshutils-0.20.0/src/nshutils/lovely/_base.py +155 -0
- nshutils-0.20.0/src/nshutils/lovely/_monkey_patch_all.py +68 -0
- nshutils-0.20.0/src/nshutils/lovely/config.py +47 -0
- nshutils-0.20.0/src/nshutils/lovely/jax_.py +89 -0
- nshutils-0.20.0/src/nshutils/lovely/numpy_.py +72 -0
- nshutils-0.20.0/src/nshutils/lovely/torch_.py +99 -0
- nshutils-0.20.0/src/nshutils/lovely/utils.py +345 -0
- {nshutils-0.19.1 → nshutils-0.20.0}/src/nshutils/typecheck.py +3 -2
- nshutils-0.20.0/src/nshutils/util.py +92 -0
- nshutils-0.19.1/src/nshutils/logging.py +0 -125
- {nshutils-0.19.1 → nshutils-0.20.0}/README.md +0 -0
- {nshutils-0.19.1 → nshutils-0.20.0}/src/nshutils/actsave/__init__.py +0 -0
- {nshutils-0.19.1 → nshutils-0.20.0}/src/nshutils/actsave/_loader.py +0 -0
- {nshutils-0.19.1 → nshutils-0.20.0}/src/nshutils/collections.py +0 -0
- {nshutils-0.19.1 → nshutils-0.20.0}/src/nshutils/display.py +0 -0
- {nshutils-0.19.1 → nshutils-0.20.0}/src/nshutils/snoop.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: nshutils
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.20.0
|
4
4
|
Summary:
|
5
5
|
Author: Nima Shoghi
|
6
6
|
Author-email: nimashoghi@gmail.com
|
@@ -11,14 +11,19 @@ Classifier: Programming Language :: Python :: 3.11
|
|
11
11
|
Classifier: Programming Language :: Python :: 3.12
|
12
12
|
Classifier: Programming Language :: Python :: 3.13
|
13
13
|
Provides-Extra: extra
|
14
|
+
Provides-Extra: pprint
|
15
|
+
Provides-Extra: snoop
|
14
16
|
Requires-Dist: beartype
|
15
17
|
Requires-Dist: jaxtyping
|
16
|
-
Requires-Dist:
|
17
|
-
Requires-Dist:
|
18
|
+
Requires-Dist: lazy-loader
|
19
|
+
Requires-Dist: nshconfig
|
18
20
|
Requires-Dist: numpy
|
19
21
|
Requires-Dist: pysnooper ; extra == "extra"
|
20
|
-
Requires-Dist:
|
22
|
+
Requires-Dist: pysnooper ; extra == "snoop"
|
23
|
+
Requires-Dist: rich[jupyter] ; extra == "extra"
|
24
|
+
Requires-Dist: rich[jupyter] ; extra == "pprint"
|
21
25
|
Requires-Dist: treescope ; extra == "extra"
|
26
|
+
Requires-Dist: treescope ; extra == "pprint"
|
22
27
|
Requires-Dist: typing-extensions
|
23
28
|
Requires-Dist: uuid7
|
24
29
|
Project-URL: Homepage, https://github.com/nimashoghi/nshutils
|
@@ -1,15 +1,25 @@
|
|
1
1
|
[project]
|
2
2
|
name = "nshutils"
|
3
|
-
version = "0.
|
3
|
+
version = "0.20.0"
|
4
4
|
description = ""
|
5
5
|
authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
|
6
6
|
requires-python = ">=3.10,<4.0"
|
7
7
|
readme = "README.md"
|
8
8
|
|
9
|
-
dependencies = [
|
9
|
+
dependencies = [
|
10
|
+
"lazy-loader",
|
11
|
+
"numpy",
|
12
|
+
"typing-extensions",
|
13
|
+
"jaxtyping",
|
14
|
+
"beartype",
|
15
|
+
"uuid7",
|
16
|
+
"nshconfig",
|
17
|
+
]
|
10
18
|
|
11
19
|
[project.optional-dependencies]
|
12
|
-
|
20
|
+
snoop = ["pysnooper"]
|
21
|
+
pprint = ["rich[jupyter]", "treescope"]
|
22
|
+
extra = ["pysnooper", "rich[jupyter]", "treescope"]
|
13
23
|
|
14
24
|
[project.urls]
|
15
25
|
homepage = "https://github.com/nimashoghi/nshutils"
|
@@ -0,0 +1,20 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import lazy_loader as lazy
|
4
|
+
|
5
|
+
__getattr__, __dir__, __all__ = lazy.attach_stub(__name__, __file__)
|
6
|
+
|
7
|
+
|
8
|
+
# `importlib.metadata` is part of the standard library on every supported
# Python (requires-python >= 3.10), so no `importlib_metadata` backport
# fallback is needed.
from importlib.metadata import PackageNotFoundError, version

try:
    __version__ = version(__name__)
except PackageNotFoundError:
    # Not installed as a distribution (e.g. imported straight from a source tree).
    __version__ = "unknown"
|
@@ -1,29 +1,13 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
1
|
from . import actsave as actsave
|
2
|
+
from . import lovely as lovely
|
4
3
|
from . import typecheck as typecheck
|
5
4
|
from .actsave import ActLoad as ActLoad
|
6
5
|
from .actsave import ActSave as ActSave
|
7
6
|
from .collections import apply_to_collection as apply_to_collection
|
8
7
|
from .display import display as display
|
9
8
|
from .logging import init_python_logging as init_python_logging
|
10
|
-
from .logging import
|
11
|
-
from .logging import pretty as pretty
|
9
|
+
from .logging import setup_logging as setup_logging
|
12
10
|
from .snoop import snoop as snoop
|
13
11
|
from .typecheck import tassert as tassert
|
14
12
|
from .typecheck import typecheck_modules as typecheck_modules
|
15
13
|
from .typecheck import typecheck_this_module as typecheck_this_module
|
16
|
-
|
17
|
-
try:
|
18
|
-
from importlib.metadata import PackageNotFoundError, version
|
19
|
-
except ImportError:
|
20
|
-
# For Python <3.8
|
21
|
-
from importlib_metadata import ( # pyright: ignore[reportMissingImports]
|
22
|
-
PackageNotFoundError,
|
23
|
-
version,
|
24
|
-
)
|
25
|
-
|
26
|
-
try:
|
27
|
-
__version__ = version(__name__)
|
28
|
-
except PackageNotFoundError:
|
29
|
-
__version__ = "unknown"
|
@@ -12,7 +12,7 @@ from pathlib import Path
|
|
12
12
|
from typing import TYPE_CHECKING, Generic, TypeAlias, cast, overload
|
13
13
|
|
14
14
|
import numpy as np
|
15
|
-
from typing_extensions import Never, ParamSpec, TypeVar, override
|
15
|
+
from typing_extensions import Never, ParamSpec, TypeAliasType, TypeVar, override
|
16
16
|
from uuid_extensions import uuid7str
|
17
17
|
|
18
18
|
from ..collections import apply_to_collection
|
@@ -21,21 +21,23 @@ if not TYPE_CHECKING:
|
|
21
21
|
try:
|
22
22
|
import torch # type: ignore
|
23
23
|
|
24
|
-
Tensor
|
24
|
+
Tensor = torch.Tensor
|
25
25
|
except ImportError:
|
26
26
|
torch = None
|
27
27
|
|
28
|
-
Tensor
|
28
|
+
Tensor = Never
|
29
29
|
else:
|
30
30
|
import torch # type: ignore
|
31
31
|
|
32
|
-
Tensor
|
32
|
+
Tensor = torch.Tensor
|
33
33
|
|
34
34
|
|
35
35
|
log = getLogger(__name__)
|
36
36
|
|
37
|
-
Value
|
38
|
-
|
37
|
+
# Union of the value types accepted for saving: Python scalars and strings,
# numpy arrays, torch tensors (`Tensor` is `Never` when torch is missing), or None.
Value = TypeAliasType(
    "Value",
    int | float | complex | bool | str | np.ndarray | Tensor | None,
)
# Either a `Value` or a callable returning one (presumably invoked lazily by
# the saver — confirm against its call sites).
ValueOrLambda = TypeAliasType(
    "ValueOrLambda",
    Value | Callable[..., Value],
)
|
39
41
|
|
40
42
|
|
41
43
|
def _torch_is_scripting() -> bool:
|
@@ -0,0 +1,75 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import TYPE_CHECKING
|
6
|
+
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
from .lovely._monkey_patch_all import Library
|
9
|
+
|
10
|
+
|
11
|
+
def setup_logging(
    *,
    lovely: bool | list[Library] = False,
    treescope: bool = False,
    treescope_autovisualize_arrays: bool = False,
    rich: bool = False,
    rich_tracebacks: bool = False,
    log_level: int | str | None = logging.INFO,
    log_save_dir: Path | None = None,
):
    """Configure root Python logging and optional pretty-printing integrations.

    The root logger is configured *first* (via ``logging.basicConfig``). Only
    after that are the optional integrations set up, and only then are their
    status messages logged. Module-level ``logging.info(...)`` calls made while
    the root logger has no handlers run an implicit ``basicConfig()``. That
    would turn the explicit configuration into a silent no-op: wrong level,
    and no file or rich handler.

    Args:
        lovely: ``True`` to monkey patch every supported array library that is
            installed (``"auto"``), or an explicit list of libraries to patch.
        treescope: Run ``treescope.basic_interactive_setup``. This only takes
            effect inside an IPython/Jupyter session.
        treescope_autovisualize_arrays: Forwarded as ``autovisualize_arrays``
            to ``treescope.basic_interactive_setup``.
        rich: Use ``rich.logging.RichHandler`` for console output when rich is
            importable. Otherwise a plain ``logging.StreamHandler`` is used.
        rich_tracebacks: Forwarded to ``RichHandler(rich_tracebacks=...)``.
        log_level: Level for the root logger. Level names are accepted in any
            case. ``None`` leaves the root level unchanged.
        log_save_dir: If given, records are also written to
            ``<log_save_dir>/logging.log``. The directory is created if needed.
    """
    if isinstance(log_level, str):
        # `Logger.setLevel` only recognises upper-case level names.
        log_level = log_level.upper()

    log_handlers: list[logging.Handler] = []
    if log_save_dir:
        log_save_dir = Path(log_save_dir)
        log_save_dir.mkdir(parents=True, exist_ok=True)
        log_file = log_save_dir / "logging.log"
        log_file.touch(exist_ok=True)
        log_handlers.append(logging.FileHandler(log_file))

    console_handler: logging.Handler | None = None
    rich_import_failed = False
    if rich:
        try:
            from rich.logging import RichHandler  # type: ignore

            console_handler = RichHandler(rich_tracebacks=rich_tracebacks)
        except ImportError:
            # Reported after basicConfig so the message itself does not
            # trigger an implicit default configuration.
            rich_import_failed = True
    if console_handler is None:
        # `basicConfig(handlers=[...])` adds only the handlers given. Without
        # this, a file-only setup would never print anything to the console.
        console_handler = logging.StreamHandler()
    log_handlers.append(console_handler)

    logging.basicConfig(
        level=log_level,
        format="%(message)s",
        datefmt="[%X]",
        handlers=log_handlers,
    )

    if rich_import_failed:
        logging.info(
            "Failed to import rich. Falling back to default Python logging."
        )

    if lovely:
        from .lovely._monkey_patch_all import monkey_patch

        monkey_patch("auto" if lovely is True else lovely)

    if treescope:
        try:
            # Check if we're in a Jupyter environment
            from IPython import get_ipython

            if get_ipython() is not None:
                import treescope as _treescope  # type: ignore

                _treescope.basic_interactive_setup(
                    autovisualize_arrays=treescope_autovisualize_arrays
                )
            else:
                logging.info(
                    "Treescope setup is only supported in Jupyter notebooks. Skipping."
                )
        except ImportError:
            logging.info(
                "Failed to import `treescope` or `IPython`. Ignoring `treescope` registration"
            )

    logging.info(
        "Logging initialized. "
        f"Lovely: {lovely}, Treescope: {treescope}, Rich: {rich}, "
        f"Log level: {log_level}, Log save dir: {log_save_dir}"
    )


init_python_logging = setup_logging
|
@@ -0,0 +1,10 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from ._monkey_patch_all import monkey_patch as monkey_patch
|
4
|
+
from .config import LovelyConfig as LovelyConfig
|
5
|
+
from .jax_ import jax_monkey_patch as jax_monkey_patch
|
6
|
+
from .jax_ import jax_repr as jax_repr
|
7
|
+
from .numpy_ import numpy_monkey_patch as numpy_monkey_patch
|
8
|
+
from .numpy_ import numpy_repr as numpy_repr
|
9
|
+
from .torch_ import torch_monkey_patch as torch_monkey_patch
|
10
|
+
from .torch_ import torch_repr as torch_repr
|
@@ -0,0 +1,155 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import functools
|
4
|
+
import importlib.util
|
5
|
+
import logging
|
6
|
+
from collections.abc import Callable, Iterator
|
7
|
+
|
8
|
+
from typing_extensions import ParamSpec, TypeAliasType, TypeVar
|
9
|
+
|
10
|
+
from ..util import ContextResource, resource_factory_contextmanager
|
11
|
+
from .utils import LovelyStats, format_tensor_stats
|
12
|
+
|
13
|
+
log = logging.getLogger(__name__)

# Placeholder for the concrete array type (numpy / torch / jax) that a stats
# or repr function handles; type checkers infer its variance.
TArray = TypeVar("TArray", infer_variance=True)
# Parameter specification of a wrapped monkey-patch function.
P = ParamSpec("P")

# A function that summarises an array as `LovelyStats`.
LovelyStatsFn = TypeAliasType(
    "LovelyStatsFn", Callable[[TArray], LovelyStats], type_params=(TArray,)
)
# A function that renders an array as its lovely string.
LovelyReprFn = TypeAliasType(
    "LovelyReprFn", Callable[[TArray], str], type_params=(TArray,)
)
|
28
|
+
|
29
|
+
|
30
|
+
def _find_missing_deps(dependencies: list[str]):
|
31
|
+
missing_deps: list[str] = []
|
32
|
+
|
33
|
+
for dep in dependencies:
|
34
|
+
if importlib.util.find_spec(dep) is not None:
|
35
|
+
continue
|
36
|
+
|
37
|
+
missing_deps.append(dep)
|
38
|
+
|
39
|
+
return missing_deps
|
40
|
+
|
41
|
+
|
42
|
+
def lovely_repr(dependencies: list[str]):
    """
    Build a decorator that turns an array-stats function into a repr function.

    Args:
        dependencies: Modules that must be importable. They are checked on
            every call of the resulting repr function. When any is missing, a
            warning is logged and plain ``repr(array)`` is returned instead.

    Returns:
        A decorator. It maps a function ``array -> LovelyStats`` to a function
        ``array -> str`` that formats those stats with ``format_tensor_stats``.

    Example:
        @lovely_repr(dependencies=["torch"])
        def my_array_stats(array):
            return {...}
    """

    def decorator_fn(array_stats_fn: LovelyStatsFn[TArray]) -> LovelyReprFn[TArray]:
        """
        Wrap *array_stats_fn* so that it returns a formatted string.

        Args:
            array_stats_fn: Computes the stats dictionary for one array.

        Returns:
            The repr function. ``functools.wraps`` copies the stats function's
            metadata (name, docstring) onto it.
        """

        @functools.wraps(array_stats_fn)
        def wrapper(array: TArray) -> str:
            missing = _find_missing_deps(dependencies)
            if not missing:
                return format_tensor_stats(array_stats_fn(array))

            log.warning(
                f"Missing dependencies: {', '.join(missing)}. "
                "Skipping lovely representation."
            )
            return repr(array)

        return wrapper

    return decorator_fn
|
85
|
+
|
86
|
+
|
87
|
+
# Shape of an undecorated monkey-patch function: a generator function taking
# parameters `P` and yielding `None` (see e.g. `jax_monkey_patch`).
LovelyMonkeyPatchInputFn = TypeAliasType(
    "LovelyMonkeyPatchInputFn", Callable[P, Iterator[None]], type_params=(P,)
)
# The same parameters, but returning the `ContextResource` that
# `resource_factory_contextmanager` produces from such a generator function.
LovelyMonkeyPatchFn = TypeAliasType(
    "LovelyMonkeyPatchFn", Callable[P, ContextResource[None]], type_params=(P,)
)
|
97
|
+
|
98
|
+
|
99
|
+
def _nullcontext_generator():
|
100
|
+
"""A generator that does nothing."""
|
101
|
+
yield
|
102
|
+
|
103
|
+
|
104
|
+
def _wrap_monkey_patch_fn(
    monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
    dependencies: list[str],
) -> LovelyMonkeyPatchInputFn[P]:
    """Guard a monkey-patch generator function behind a dependency check.

    The dependencies are checked each time the returned function is called.
    If any are missing, a warning is logged and a do-nothing single-yield
    generator is returned instead of calling *monkey_patch_fn*.
    """

    @functools.wraps(monkey_patch_fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterator[None]:
        missing = _find_missing_deps(dependencies)
        if not missing:
            return monkey_patch_fn(*args, **kwargs)

        log.warning(
            f"Missing dependencies: {', '.join(missing)}. "
            "Skipping monkey patch."
        )
        return _nullcontext_generator()

    return wrapper
|
120
|
+
|
121
|
+
|
122
|
+
def monkey_patch_contextmanager(dependencies: list[str]):
    """
    Build a decorator that turns a monkey-patch generator into a resource factory.

    Args:
        dependencies: Modules that must be importable for the patch to run.
            When any is missing, the patch is skipped and a warning is logged.

    Returns:
        A decorator. It first guards the generator function with the
        dependency check, then passes the result to
        ``resource_factory_contextmanager``.

    Example:
        @monkey_patch_contextmanager(dependencies=["torch"])
        def my_array_monkey_patch():
            ...
    """

    def decorator_fn(
        monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
    ) -> LovelyMonkeyPatchFn[P]:
        """
        Wrap *monkey_patch_fn* (a generator that applies, yields, and undoes a patch).

        Returns:
            The function produced by ``resource_factory_contextmanager``.
        """
        return resource_factory_contextmanager(
            _wrap_monkey_patch_fn(monkey_patch_fn, dependencies)
        )

    return decorator_fn
|
@@ -0,0 +1,68 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import contextlib
|
4
|
+
import importlib.util
|
5
|
+
import logging
|
6
|
+
from typing import Literal
|
7
|
+
|
8
|
+
from typing_extensions import TypeAliasType, assert_never
|
9
|
+
|
10
|
+
from ..util import resource_factory_contextmanager
|
11
|
+
|
12
|
+
# Names of the array libraries that `monkey_patch` knows how to patch.
Library = TypeAliasType(
    "Library",
    Literal["numpy", "torch", "jax"],
)

log = logging.getLogger(__name__)
|
15
|
+
|
16
|
+
|
17
|
+
def _find_deps() -> list[Library]:
|
18
|
+
"""
|
19
|
+
Find available libraries for monkey patching.
|
20
|
+
"""
|
21
|
+
deps: list[Library] = []
|
22
|
+
if importlib.util.find_spec("torch") is not None:
|
23
|
+
deps.append("torch")
|
24
|
+
if importlib.util.find_spec("jax") is not None:
|
25
|
+
deps.append("jax")
|
26
|
+
if importlib.util.find_spec("numpy") is not None:
|
27
|
+
deps.append("numpy")
|
28
|
+
return deps
|
29
|
+
|
30
|
+
|
31
|
+
@resource_factory_contextmanager
|
32
|
+
def monkey_patch(libraries: list[Library] | Literal["auto"] = "auto"):
|
33
|
+
if libraries == "auto":
|
34
|
+
libraries = _find_deps()
|
35
|
+
|
36
|
+
if not libraries:
|
37
|
+
raise ValueError(
|
38
|
+
"No libraries found for monkey patching. "
|
39
|
+
"Please install numpy, torch, or jax."
|
40
|
+
)
|
41
|
+
|
42
|
+
with contextlib.ExitStack() as stack:
|
43
|
+
for library in libraries:
|
44
|
+
match library:
|
45
|
+
case "torch":
|
46
|
+
from .torch_ import torch_monkey_patch
|
47
|
+
|
48
|
+
stack.enter_context(torch_monkey_patch())
|
49
|
+
case "jax":
|
50
|
+
from .jax_ import jax_monkey_patch
|
51
|
+
|
52
|
+
stack.enter_context(jax_monkey_patch())
|
53
|
+
case "numpy":
|
54
|
+
from .numpy_ import numpy_monkey_patch
|
55
|
+
|
56
|
+
stack.enter_context(numpy_monkey_patch())
|
57
|
+
case _:
|
58
|
+
assert_never(library)
|
59
|
+
|
60
|
+
log.info(
|
61
|
+
f"Monkey patched libraries: {', '.join(libraries)}. "
|
62
|
+
"You can now use the lovely functions with these libraries."
|
63
|
+
)
|
64
|
+
yield
|
65
|
+
log.info(
|
66
|
+
f"Unmonkey patched libraries: {', '.join(libraries)}. "
|
67
|
+
"You can now use the lovely functions with these libraries."
|
68
|
+
)
|
@@ -0,0 +1,47 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import ClassVar
|
4
|
+
|
5
|
+
import nshconfig as C
|
6
|
+
from typing_extensions import Self
|
7
|
+
|
8
|
+
|
9
|
+
class LovelyConfig(C.Config):
    """
    Formatting options for the lovely array representations.

    A lazily created singleton per class is available through `instance()`.
    """

    precision: int = 3
    """Number of digits after the decimal point."""

    threshold_max: int = 3
    """Absolute values larger than 10^3 use scientific notation."""

    threshold_min: int = -4
    """Absolute values smaller than 10^-4 use scientific notation."""

    sci_mode: bool | None = None
    """Force scientific notation (None=auto)."""

    show_mem_above: int = 1024
    """Show memory size if above this threshold (bytes)."""

    color: bool = True
    """Use ANSI colors in text."""

    indent: int = 2
    """Indent for nested representation."""

    config_instance: ClassVar[Self | None] = None
    """Singleton instance of the LovelyConfig class."""

    @classmethod
    def instance(cls) -> Self:
        """
        Get the singleton instance of ``cls``, creating it on first use.

        The instance is looked up in ``cls``'s own namespace rather than
        through inheritance. Otherwise, once the base class has created its
        singleton, a subclass would receive that base instance (the wrong
        type) instead of its own.
        """
        existing = cls.__dict__.get("config_instance")
        if existing is None:
            existing = cls()
            cls.config_instance = existing
        return existing
|
@@ -0,0 +1,89 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, cast
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from ._base import lovely_repr, monkey_patch_contextmanager
|
8
|
+
from .utils import LovelyStats, array_stats, patch_to
|
9
|
+
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
import jax
|
12
|
+
|
13
|
+
|
14
|
+
def _type_name(array: jax.Array):
|
15
|
+
type_name = type(array).__name__.rsplit(".", 1)[-1]
|
16
|
+
return "array" if type_name == "ArrayImpl" else type_name
|
17
|
+
|
18
|
+
|
19
|
+
_DT_NAMES = {
|
20
|
+
"float16": "f16",
|
21
|
+
"float32": "f32",
|
22
|
+
"float64": "f64",
|
23
|
+
"uint8": "u8",
|
24
|
+
"uint16": "u16",
|
25
|
+
"uint32": "u32",
|
26
|
+
"uint64": "u64",
|
27
|
+
"int8": "i8",
|
28
|
+
"int16": "i16",
|
29
|
+
"int32": "i32",
|
30
|
+
"int64": "i64",
|
31
|
+
"bfloat16": "bf16",
|
32
|
+
"complex64": "c64",
|
33
|
+
"complex128": "c128",
|
34
|
+
}
|
35
|
+
|
36
|
+
|
37
|
+
def _dtype_str(array: jax.Array) -> str:
|
38
|
+
dtype_base = str(array.dtype).rsplit(".", 1)[-1]
|
39
|
+
dtype_base = _DT_NAMES.get(dtype_base, dtype_base)
|
40
|
+
return dtype_base
|
41
|
+
|
42
|
+
|
43
|
+
def _device(array: jax.Array) -> str:
|
44
|
+
from jaxlib.xla_extension import Device
|
45
|
+
|
46
|
+
if callable(device := array.device):
|
47
|
+
device = device()
|
48
|
+
|
49
|
+
device = cast(Device, device)
|
50
|
+
if device.platform == "cpu":
|
51
|
+
return "cpu"
|
52
|
+
|
53
|
+
return f"{device.platform}:{device.id}"
|
54
|
+
|
55
|
+
|
56
|
+
@lovely_repr(dependencies=["jax"])
def jax_repr(array: jax.Array) -> LovelyStats:
    """Collect the lovely stats of a jax array (formatted by `lovely_repr`)."""
    import jax.numpy as jnp

    stats = {
        # Basic attributes
        "shape": array.shape,
        "size": array.size,
        "nbytes": array.nbytes,
        "type_name": _type_name(array),
        # Dtype
        "dtype_str": _dtype_str(array),
        "is_complex": jnp.iscomplexobj(array),
        # Device
        "device": _device(array),
    }
    # Numeric summary from `array_stats`, computed on a host numpy copy of the
    # array and merged after the metadata (same precedence as `**` unpacking).
    stats.update(array_stats(np.asarray(array)))
    return stats
|
74
|
+
|
75
|
+
|
76
|
+
@monkey_patch_contextmanager(dependencies=["jax"])
def jax_monkey_patch():
    """Route ``repr()`` and ``str()`` of jax arrays through `jax_repr` while active."""
    from jax._src import array

    array_cls = array.ArrayImpl
    # Capture the original dunder methods so they can be put back afterwards.
    originals = {name: getattr(array_cls, name) for name in ("__repr__", "__str__")}
    try:
        for name in originals:
            patch_to(array_cls, name, jax_repr)

        yield
    finally:
        for name, original in originals.items():
            patch_to(array_cls, name, original)
|
@@ -0,0 +1,72 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from ._base import lovely_repr, monkey_patch_contextmanager
|
8
|
+
from .utils import LovelyStats, array_stats
|
9
|
+
|
10
|
+
|
11
|
+
def _type_name(array: np.ndarray):
|
12
|
+
return (
|
13
|
+
"array"
|
14
|
+
if type(array) is np.ndarray
|
15
|
+
else type(array).__name__.rsplit(".", 1)[-1]
|
16
|
+
)
|
17
|
+
|
18
|
+
|
19
|
+
_DT_NAMES = {
|
20
|
+
"float16": "f16",
|
21
|
+
"float32": "f32",
|
22
|
+
"float64": "", # Default dtype in numpy
|
23
|
+
"uint8": "u8",
|
24
|
+
"uint16": "u16",
|
25
|
+
"uint32": "u32",
|
26
|
+
"uint64": "u64",
|
27
|
+
"int8": "i8",
|
28
|
+
"int16": "i16",
|
29
|
+
"int32": "i32",
|
30
|
+
"int64": "i64",
|
31
|
+
"complex64": "c64",
|
32
|
+
"complex128": "c128",
|
33
|
+
}
|
34
|
+
|
35
|
+
|
36
|
+
def _dtype_str(array: np.ndarray) -> str:
|
37
|
+
dtype_base = str(array.dtype).rsplit(".", 1)[-1]
|
38
|
+
dtype_base = _DT_NAMES.get(dtype_base, dtype_base)
|
39
|
+
return dtype_base
|
40
|
+
|
41
|
+
|
42
|
+
@lovely_repr(dependencies=["numpy"])
def numpy_repr(array: np.ndarray) -> LovelyStats:
    """Collect the lovely stats of a numpy array (formatted by `lovely_repr`)."""
    stats = {
        # Basic attributes
        "shape": array.shape,
        "size": array.size,
        "nbytes": array.nbytes,
        "type_name": _type_name(array),
        # Dtype
        "dtype_str": _dtype_str(array),
        "is_complex": np.iscomplexobj(array),
    }
    # Numeric summary from `array_stats`, merged after the metadata (same
    # precedence as `**` unpacking in a dict literal).
    stats.update(array_stats(array))
    return stats
|
56
|
+
|
57
|
+
|
58
|
+
@monkey_patch_contextmanager(dependencies=["numpy"])
def numpy_monkey_patch():
    """
    Make numpy use `numpy_repr` for array reprs while active.

    This uses numpy's ``override_repr`` print option, which requires a numpy
    version that supports it. On exit, the ``override_repr`` value that was in
    effect before patching is restored, rather than being reset to ``None``.
    So an override installed by someone else survives a nested patch.

    Messages go to this module's logger. Module-level ``logging.info`` would
    implicitly call ``logging.basicConfig()`` when the root logger has no
    handlers, which pre-empts any later explicit logging configuration.
    """
    logger = logging.getLogger(__name__)
    previous_override = np.get_printoptions().get("override_repr")
    try:
        np.set_printoptions(override_repr=numpy_repr)
        logger.info(
            f"Numpy monkey patching: using {numpy_repr.__name__} for numpy arrays. "
            f"{np.get_printoptions()=}"
        )
        yield
    finally:
        np.set_printoptions(override_repr=previous_override)
        logger.info(
            "Numpy unmonkey patching: restored the previous repr for numpy arrays. "
            f"{np.get_printoptions()=}"
        )
|