nshutils 0.19.1__tar.gz → 0.20.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshutils
3
- Version: 0.19.1
3
+ Version: 0.20.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -11,14 +11,19 @@ Classifier: Programming Language :: Python :: 3.11
11
11
  Classifier: Programming Language :: Python :: 3.12
12
12
  Classifier: Programming Language :: Python :: 3.13
13
13
  Provides-Extra: extra
14
+ Provides-Extra: pprint
15
+ Provides-Extra: snoop
14
16
  Requires-Dist: beartype
15
17
  Requires-Dist: jaxtyping
16
- Requires-Dist: lovely-numpy ; extra == "extra"
17
- Requires-Dist: lovely-tensors ; extra == "extra"
18
+ Requires-Dist: lazy-loader
19
+ Requires-Dist: nshconfig
18
20
  Requires-Dist: numpy
19
21
  Requires-Dist: pysnooper ; extra == "extra"
20
- Requires-Dist: rich ; extra == "extra"
22
+ Requires-Dist: pysnooper ; extra == "snoop"
23
+ Requires-Dist: rich[jupyter] ; extra == "extra"
24
+ Requires-Dist: rich[jupyter] ; extra == "pprint"
21
25
  Requires-Dist: treescope ; extra == "extra"
26
+ Requires-Dist: treescope ; extra == "pprint"
22
27
  Requires-Dist: typing-extensions
23
28
  Requires-Dist: uuid7
24
29
  Project-URL: Homepage, https://github.com/nimashoghi/nshutils
@@ -1,15 +1,25 @@
1
1
  [project]
2
2
  name = "nshutils"
3
- version = "0.19.1"
3
+ version = "0.20.0"
4
4
  description = ""
5
5
  authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
6
6
  requires-python = ">=3.10,<4.0"
7
7
  readme = "README.md"
8
8
 
9
- dependencies = ["numpy", "jaxtyping", "typing-extensions", "beartype", "uuid7"]
9
+ dependencies = [
10
+ "lazy-loader",
11
+ "numpy",
12
+ "typing-extensions",
13
+ "jaxtyping",
14
+ "beartype",
15
+ "uuid7",
16
+ "nshconfig",
17
+ ]
10
18
 
11
19
  [project.optional-dependencies]
12
- extra = ["pysnooper", "lovely-numpy", "lovely-tensors", "rich", "treescope"]
20
+ snoop = ["pysnooper"]
21
+ pprint = ["rich[jupyter]", "treescope"]
22
+ extra = ["pysnooper", "rich[jupyter]", "treescope"]
13
23
 
14
24
  [project.urls]
15
25
  homepage = "https://github.com/nimashoghi/nshutils"
@@ -0,0 +1,20 @@
1
+ from __future__ import annotations
2
+
3
+ import lazy_loader as lazy
4
+
5
# Bind the module-level ``__getattr__``, ``__dir__`` and ``__all__`` returned
# by ``lazy_loader.attach_stub``.
# NOTE(review): attach_stub presumably derives the lazily imported names from
# the ``.pyi`` stub that sits next to this file — confirm against the
# lazy_loader documentation.
__getattr__, __dir__, __all__ = lazy.attach_stub(__name__, __file__)
6
+
7
+
8
# ``importlib.metadata`` has been in the standard library since Python 3.8 and
# this package requires Python >= 3.10, so the ``importlib_metadata`` backport
# fallback could never be reached and is not needed.
from importlib.metadata import PackageNotFoundError, version

try:
    # Look up the version of the installed distribution named after this module.
    __version__ = version(__name__)
except PackageNotFoundError:
    # No installed distribution metadata was found for this module name.
    __version__ = "unknown"
@@ -1,29 +1,13 @@
1
- from __future__ import annotations
2
-
3
1
  from . import actsave as actsave
2
+ from . import lovely as lovely
4
3
  from . import typecheck as typecheck
5
4
  from .actsave import ActLoad as ActLoad
6
5
  from .actsave import ActSave as ActSave
7
6
  from .collections import apply_to_collection as apply_to_collection
8
7
  from .display import display as display
9
8
  from .logging import init_python_logging as init_python_logging
10
- from .logging import lovely as lovely
11
- from .logging import pretty as pretty
9
+ from .logging import setup_logging as setup_logging
12
10
  from .snoop import snoop as snoop
13
11
  from .typecheck import tassert as tassert
14
12
  from .typecheck import typecheck_modules as typecheck_modules
15
13
  from .typecheck import typecheck_this_module as typecheck_this_module
16
-
17
- try:
18
- from importlib.metadata import PackageNotFoundError, version
19
- except ImportError:
20
- # For Python <3.8
21
- from importlib_metadata import ( # pyright: ignore[reportMissingImports]
22
- PackageNotFoundError,
23
- version,
24
- )
25
-
26
- try:
27
- __version__ = version(__name__)
28
- except PackageNotFoundError:
29
- __version__ = "unknown"
@@ -12,7 +12,7 @@ from pathlib import Path
12
12
  from typing import TYPE_CHECKING, Generic, TypeAlias, cast, overload
13
13
 
14
14
  import numpy as np
15
- from typing_extensions import Never, ParamSpec, TypeVar, override
15
+ from typing_extensions import Never, ParamSpec, TypeAliasType, TypeVar, override
16
16
  from uuid_extensions import uuid7str
17
17
 
18
18
  from ..collections import apply_to_collection
@@ -21,21 +21,23 @@ if not TYPE_CHECKING:
21
21
  try:
22
22
  import torch # type: ignore
23
23
 
24
- Tensor: TypeAlias = torch.Tensor
24
+ Tensor = torch.Tensor
25
25
  except ImportError:
26
26
  torch = None
27
27
 
28
- Tensor: TypeAlias = Never
28
+ Tensor = Never
29
29
  else:
30
30
  import torch # type: ignore
31
31
 
32
- Tensor: TypeAlias = torch.Tensor
32
+ Tensor = torch.Tensor
33
33
 
34
34
 
35
35
  log = getLogger(__name__)
36
36
 
37
- Value: TypeAlias = int | float | complex | bool | str | np.ndarray | Tensor | None
38
- ValueOrLambda: TypeAlias = Value | Callable[..., Value]
37
+ Value = TypeAliasType(
38
+ "Value", int | float | complex | bool | str | np.ndarray | Tensor | None
39
+ )
40
+ ValueOrLambda = TypeAliasType("ValueOrLambda", Value | Callable[..., Value])
39
41
 
40
42
 
41
43
  def _torch_is_scripting() -> bool:
@@ -0,0 +1,75 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING
6
+
7
+ if TYPE_CHECKING:
8
+ from .lovely._monkey_patch_all import Library
9
+
10
+
11
def setup_logging(
    *,
    lovely: bool | list[Library] = False,
    treescope: bool = False,
    treescope_autovisualize_arrays: bool = False,
    rich: bool = False,
    rich_tracebacks: bool = False,
    log_level: int | str | None = logging.INFO,
    log_save_dir: Path | None = None,
) -> None:
    """
    Configure root logging and, optionally, the pretty-printing integrations.

    The root logger is configured *first*. Any root-level ``logging.info``
    call made while the root logger has no handlers implicitly runs
    ``logging.basicConfig()``, which would turn the explicit ``basicConfig``
    call below into a no-op (dropping ``log_level`` and the handlers).

    Args:
        lovely: ``True`` to monkey patch every detected array library
            (``"auto"``), or an explicit list of libraries to patch.
        treescope: Whether to run treescope's basic interactive setup. Only
            done when running under IPython; otherwise a message is logged.
        treescope_autovisualize_arrays: Forwarded to
            ``treescope.basic_interactive_setup``.
        rich: Whether to add a ``rich.logging.RichHandler``. Skipped with a
            log message if ``rich`` cannot be imported.
        rich_tracebacks: Forwarded to ``RichHandler``.
        log_level: Level passed to ``logging.basicConfig``.
        log_save_dir: If given, log records are also written to
            ``<log_save_dir>/logging.log``. The directory is created if it
            does not exist.
    """
    log_handlers: list[logging.Handler] = []
    if log_save_dir is not None:
        # Create the directory up front so ``touch`` cannot fail on a
        # not-yet-existing path.
        log_save_dir.mkdir(parents=True, exist_ok=True)
        log_file = log_save_dir / "logging.log"
        log_file.touch(exist_ok=True)
        log_handlers.append(logging.FileHandler(log_file))

    rich_import_failed = False
    if rich:
        try:
            from rich.logging import RichHandler  # type: ignore

            log_handlers.append(RichHandler(rich_tracebacks=rich_tracebacks))
        except ImportError:
            # Reported only after basicConfig (see docstring).
            rich_import_failed = True

    logging.basicConfig(
        level=log_level,
        format="%(message)s",
        datefmt="[%X]",
        handlers=log_handlers,
    )

    if rich_import_failed:
        logging.info("Failed to import rich. Falling back to default Python logging.")

    if lovely:
        from .lovely._monkey_patch_all import monkey_patch

        monkey_patch("auto" if lovely is True else lovely)

    if treescope:
        try:
            # Check if we're in a Jupyter environment
            from IPython import get_ipython

            if get_ipython() is not None:
                import treescope as _treescope  # type: ignore

                _treescope.basic_interactive_setup(
                    autovisualize_arrays=treescope_autovisualize_arrays
                )
            else:
                logging.info(
                    "Treescope setup is only supported in Jupyter notebooks. Skipping."
                )
        except ImportError:
            logging.info(
                "Failed to import `treescope` or `IPython`. Ignoring `treescope` registration"
            )

    logging.info(
        "Logging initialized. "
        f"Lovely: {lovely}, Treescope: {treescope}, Rich: {rich}, "
        f"Log level: {log_level}, Log save dir: {log_save_dir}"
    )


# Alias exposing the same function under the name ``init_python_logging``.
init_python_logging = setup_logging
@@ -0,0 +1,10 @@
1
+ from __future__ import annotations
2
+
3
+ from ._monkey_patch_all import monkey_patch as monkey_patch
4
+ from .config import LovelyConfig as LovelyConfig
5
+ from .jax_ import jax_monkey_patch as jax_monkey_patch
6
+ from .jax_ import jax_repr as jax_repr
7
+ from .numpy_ import numpy_monkey_patch as numpy_monkey_patch
8
+ from .numpy_ import numpy_repr as numpy_repr
9
+ from .torch_ import torch_monkey_patch as torch_monkey_patch
10
+ from .torch_ import torch_repr as torch_repr
@@ -0,0 +1,155 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import importlib.util
5
+ import logging
6
+ from collections.abc import Callable, Iterator
7
+
8
+ from typing_extensions import ParamSpec, TypeAliasType, TypeVar
9
+
10
+ from ..util import ContextResource, resource_factory_contextmanager
11
+ from .utils import LovelyStats, format_tensor_stats
12
+
13
+ log = logging.getLogger(__name__)
14
+
15
# Type variable standing for the array type a stats/repr function accepts.
TArray = TypeVar("TArray", infer_variance=True)
# Parameter specification used to carry a wrapped callable's signature.
P = ParamSpec("P")

# Generic alias: a callable taking one array and returning its ``LovelyStats``.
LovelyStatsFn = TypeAliasType(
    "LovelyStatsFn",
    Callable[[TArray], LovelyStats],
    type_params=(TArray,),
)
# Generic alias: a callable taking one array and returning a string.
LovelyReprFn = TypeAliasType(
    "LovelyReprFn",
    Callable[[TArray], str],
    type_params=(TArray,),
)
28
+
29
+
30
def _find_missing_deps(dependencies: list[str]):
    """
    Return the names in ``dependencies`` that cannot be located for import.

    A name counts as missing when ``importlib.util.find_spec`` returns
    ``None`` for it. The input order is preserved.
    """
    return [
        name for name in dependencies if importlib.util.find_spec(name) is None
    ]
40
+
41
+
42
def lovely_repr(dependencies: list[str]):
    """
    Build a decorator that turns an array-stats function into a repr function.

    Args:
        dependencies: Module names that must be importable for the lovely
            representation to be produced. They are checked on every call of
            the resulting repr function.

    Returns:
        A decorator that accepts a stats function and returns a function
        mapping an array to its string representation.

    Example:
        @lovely_repr(dependencies=["torch"])
        def my_array_stats(array):
            return {...}
    """

    def decorate(array_stats_fn: LovelyStatsFn[TArray]) -> LovelyReprFn[TArray]:
        """
        Wrap ``array_stats_fn`` so that it returns a formatted string.

        Args:
            array_stats_fn: Function computing the stats of an array.

        Returns:
            A function that formats those stats, or falls back to the plain
            ``repr`` of the array when a dependency is missing.
        """

        @functools.wraps(array_stats_fn)
        def render(array: TArray) -> str:
            unavailable = _find_missing_deps(dependencies)
            if not unavailable:
                return format_tensor_stats(array_stats_fn(array))

            # Degrade gracefully: warn and use the regular repr.
            log.warning(
                f"Missing dependencies: {', '.join(unavailable)}. "
                "Skipping lovely representation."
            )
            return repr(array)

        return render

    return decorate
85
+
86
+
87
# Generic alias: a callable with parameters ``P`` that returns an iterator
# yielding ``None`` — the shape of the functions that get wrapped below.
LovelyMonkeyPatchInputFn = TypeAliasType(
    "LovelyMonkeyPatchInputFn",
    Callable[P, Iterator[None]],
    type_params=(P,),
)
# Generic alias: a callable with parameters ``P`` that returns a
# ``ContextResource[None]`` — the shape of the wrapped result.
LovelyMonkeyPatchFn = TypeAliasType(
    "LovelyMonkeyPatchFn",
    Callable[P, ContextResource[None]],
    type_params=(P,),
)
97
+
98
+
99
def _nullcontext_generator():
    """Yield ``None`` exactly once and then finish; performs no other work."""
    yield None
102
+
103
+
104
def _wrap_monkey_patch_fn(
    monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
    dependencies: list[str],
) -> LovelyMonkeyPatchInputFn[P]:
    """
    Guard ``monkey_patch_fn`` behind a dependency check.

    Args:
        monkey_patch_fn: The function that applies the patch.
        dependencies: Module names that must be importable for the patch to
            be applied. Checked each time the returned function is called.

    Returns:
        A function with the same signature. When a dependency is missing it
        logs a warning and returns a do-nothing generator instead of calling
        ``monkey_patch_fn``.
    """

    @functools.wraps(monkey_patch_fn)
    def guarded(*args: P.args, **kwargs: P.kwargs) -> Iterator[None]:
        unavailable = _find_missing_deps(dependencies)
        if not unavailable:
            return monkey_patch_fn(*args, **kwargs)

        log.warning(
            f"Missing dependencies: {', '.join(unavailable)}. "
            "Skipping monkey patch."
        )
        return _nullcontext_generator()

    return guarded
120
+
121
+
122
def monkey_patch_contextmanager(dependencies: list[str]):
    """
    Build a decorator that turns a patch generator into a monkey patch function.

    Args:
        dependencies: Module names that must be importable for the patch to
            run; when any is missing the patch is skipped.

    Returns:
        A decorator that accepts a patch function and returns the wrapped
        monkey patch function.

    Example:
        @monkey_patch_contextmanager(dependencies=["torch"])
        def my_array_monkey_patch():
            ...
    """

    def decorate(
        monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
    ) -> LovelyMonkeyPatchFn[P]:
        """
        Add the dependency guard to ``monkey_patch_fn`` and wrap the result.

        Args:
            monkey_patch_fn: A function that applies the monkey patch.

        Returns:
            The guarded function passed through
            ``resource_factory_contextmanager``.
        """
        return resource_factory_contextmanager(
            _wrap_monkey_patch_fn(monkey_patch_fn, dependencies)
        )

    return decorate
@@ -0,0 +1,68 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import importlib.util
5
+ import logging
6
+ from typing import Literal
7
+
8
+ from typing_extensions import TypeAliasType, assert_never
9
+
10
+ from ..util import resource_factory_contextmanager
11
+
12
# Names of the array libraries accepted by the monkey patching entry point.
Library = TypeAliasType("Library", Literal["numpy", "torch", "jax"])

# Module-level logger.
log = logging.getLogger(__name__)
15
+
16
+
17
def _find_deps() -> list[Library]:
    """
    Return the supported libraries that can be located for import.

    Candidates are probed in the fixed order torch, jax, numpy, and the
    result keeps that order.
    """
    candidates: tuple[Library, ...] = ("torch", "jax", "numpy")
    return [
        name for name in candidates if importlib.util.find_spec(name) is not None
    ]
29
+
30
+
31
+ @resource_factory_contextmanager
32
+ def monkey_patch(libraries: list[Library] | Literal["auto"] = "auto"):
33
+ if libraries == "auto":
34
+ libraries = _find_deps()
35
+
36
+ if not libraries:
37
+ raise ValueError(
38
+ "No libraries found for monkey patching. "
39
+ "Please install numpy, torch, or jax."
40
+ )
41
+
42
+ with contextlib.ExitStack() as stack:
43
+ for library in libraries:
44
+ match library:
45
+ case "torch":
46
+ from .torch_ import torch_monkey_patch
47
+
48
+ stack.enter_context(torch_monkey_patch())
49
+ case "jax":
50
+ from .jax_ import jax_monkey_patch
51
+
52
+ stack.enter_context(jax_monkey_patch())
53
+ case "numpy":
54
+ from .numpy_ import numpy_monkey_patch
55
+
56
+ stack.enter_context(numpy_monkey_patch())
57
+ case _:
58
+ assert_never(library)
59
+
60
+ log.info(
61
+ f"Monkey patched libraries: {', '.join(libraries)}. "
62
+ "You can now use the lovely functions with these libraries."
63
+ )
64
+ yield
65
+ log.info(
66
+ f"Unmonkey patched libraries: {', '.join(libraries)}. "
67
+ "You can now use the lovely functions with these libraries."
68
+ )
@@ -0,0 +1,47 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import ClassVar
4
+
5
+ import nshconfig as C
6
+ from typing_extensions import Self
7
+
8
+
9
class LovelyConfig(C.Config):
    """
    This class is used to manage the configuration of the Lovely library.
    It inherits from the Config class in the nshconfig module.
    """

    # NOTE(review): the string literals following each field are attribute
    # docstrings; they may be consumed as field descriptions by the config
    # base class, so they are left untouched — confirm before rewording.

    precision: int = 3
    """Number of digits after the decimal point."""

    threshold_max: int = 3
    """Absolute values larger than 10^3 use scientific notation."""

    threshold_min: int = -4
    """Absolute values smaller than 10^-4 use scientific notation."""

    sci_mode: bool | None = None
    """Force scientific notation (None=auto)."""

    show_mem_above: int = 1024
    """Show memory size if above this threshold (bytes)."""

    color: bool = True
    """Use ANSI colors in text."""

    indent: int = 2
    """Indent for nested representation."""

    # Cache used by ``instance()``; annotated as a ClassVar, so it is a class
    # attribute rather than a per-instance configuration field.
    config_instance: ClassVar[Self | None] = None
    """Singleton instance of the LovelyConfig class."""

    @classmethod
    def instance(cls) -> Self:
        """
        Return the shared configuration instance, creating it on first use.

        On the first call the instance is built with ``cls()`` (all field
        defaults) and stored on ``config_instance``; later calls return the
        stored object.

        Returns:
            The cached configuration instance.
        """
        # NOTE(review): the check uses ordinary class attribute lookup, so a
        # subclass presumably sees an instance already cached on a parent
        # class — confirm whether that sharing is intended.
        if cls.config_instance is None:
            cls.config_instance = cls()
        return cls.config_instance
@@ -0,0 +1,89 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, cast
4
+
5
+ import numpy as np
6
+
7
+ from ._base import lovely_repr, monkey_patch_contextmanager
8
+ from .utils import LovelyStats, array_stats, patch_to
9
+
10
+ if TYPE_CHECKING:
11
+ import jax
12
+
13
+
14
def _type_name(array: jax.Array):
    """
    Return a short display name for the type of ``array``.

    The class name is reduced to its last dot-separated component, and the
    name ``ArrayImpl`` is shown as ``"array"``.
    """
    short_name = type(array).__name__.rpartition(".")[2]
    if short_name == "ArrayImpl":
        return "array"
    return short_name
17
+
18
+
19
# Abbreviations used when displaying dtype names; names that are not listed
# here are displayed unchanged.
_DT_NAMES = {
    # floating point
    "float16": "f16",
    "float32": "f32",
    "float64": "f64",
    # unsigned integers
    "uint8": "u8",
    "uint16": "u16",
    "uint32": "u32",
    "uint64": "u64",
    # signed integers
    "int8": "i8",
    "int16": "i16",
    "int32": "i32",
    "int64": "i64",
    # brain float
    "bfloat16": "bf16",
    # complex
    "complex64": "c64",
    "complex128": "c128",
}


def _dtype_str(array: jax.Array) -> str:
    """
    Return the abbreviated dtype name of ``array``.

    The dtype is converted to a string, reduced to its last dot-separated
    component and then mapped through ``_DT_NAMES``.
    """
    full_name = str(array.dtype).rpartition(".")[2]
    return _DT_NAMES.get(full_name, full_name)
41
+
42
+
43
def _device(array: jax.Array) -> str:
    """
    Return a short label for the device that holds ``array``.

    Args:
        array: Object whose ``device`` is either a device object or a
            zero-argument callable returning one.

    Returns:
        ``"cpu"`` when the device platform is ``"cpu"``, otherwise
        ``"<platform>:<id>"``.
    """
    device = array.device
    if callable(device):
        device = device()

    # The cast exists only for the type checker. A string forward reference is
    # used so that no private ``jaxlib`` module has to be imported at runtime;
    # such an import made this function fail whenever that module was missing.
    device = cast("jax.Device", device)
    if device.platform == "cpu":
        return "cpu"

    return f"{device.platform}:{device.id}"
54
+
55
+
56
@lovely_repr(dependencies=["jax"])
def jax_repr(array: jax.Array) -> LovelyStats:
    """
    Collect the stats used to render a JAX array.

    Gathers the basic attributes, dtype information and device of ``array``
    and merges in the numeric stats computed on its NumPy conversion.
    """
    import jax.numpy as jnp

    # Basic attributes
    shape = array.shape
    size = array.size
    nbytes = array.nbytes
    type_name = _type_name(array)
    # Dtype
    dtype_str = _dtype_str(array)
    is_complex = jnp.iscomplexobj(array)
    # Device
    device = _device(array)

    return {
        "shape": shape,
        "size": size,
        "nbytes": nbytes,
        "type_name": type_name,
        "dtype_str": dtype_str,
        "is_complex": is_complex,
        "device": device,
        # Numeric stats are computed on the NumPy conversion of the array.
        **array_stats(np.asarray(array)),
    }
74
+
75
+
76
@monkey_patch_contextmanager(dependencies=["jax"])
def jax_monkey_patch():
    """
    Replace ``__repr__`` and ``__str__`` of JAX's ``ArrayImpl`` with ``jax_repr``.

    The previous methods are captured before patching and put back when the
    generator is resumed or closed after the ``yield``.
    """
    from jax._src import array as jax_array

    impl_cls = jax_array.ArrayImpl
    original_repr = impl_cls.__repr__
    original_str = impl_cls.__str__
    try:
        for method_name in ("__repr__", "__str__"):
            patch_to(impl_cls, method_name, jax_repr)

        yield
    finally:
        # Restore in the same order they were patched.
        patch_to(impl_cls, "__repr__", original_repr)
        patch_to(impl_cls, "__str__", original_str)
@@ -0,0 +1,72 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+
5
+ import numpy as np
6
+
7
+ from ._base import lovely_repr, monkey_patch_contextmanager
8
+ from .utils import LovelyStats, array_stats
9
+
10
+
11
def _type_name(array: np.ndarray):
    """
    Return a short display name for the type of ``array``.

    Exact ``np.ndarray`` instances are shown as ``"array"``; anything else
    uses the last dot-separated component of its class name.
    """
    if type(array) is np.ndarray:
        return "array"
    return type(array).__name__.rpartition(".")[2]
17
+
18
+
19
# Abbreviations used when displaying dtype names; names that are not listed
# here are displayed unchanged.
_DT_NAMES = {
    # floating point
    "float16": "f16",
    "float32": "f32",
    "float64": "",  # Default dtype in numpy
    # unsigned integers
    "uint8": "u8",
    "uint16": "u16",
    "uint32": "u32",
    "uint64": "u64",
    # signed integers
    "int8": "i8",
    "int16": "i16",
    "int32": "i32",
    "int64": "i64",
    # complex
    "complex64": "c64",
    "complex128": "c128",
}


def _dtype_str(array: np.ndarray) -> str:
    """
    Return the abbreviated dtype name of ``array``.

    The dtype is converted to a string, reduced to its last dot-separated
    component and then mapped through ``_DT_NAMES``; ``float64`` maps to the
    empty string.
    """
    full_name = str(array.dtype).rpartition(".")[2]
    return _DT_NAMES.get(full_name, full_name)
40
+
41
+
42
@lovely_repr(dependencies=["numpy"])
def numpy_repr(array: np.ndarray) -> LovelyStats:
    """
    Collect the stats used to render a NumPy array.

    Gathers the basic attributes and dtype information of ``array`` and
    merges in the numeric stats computed by ``array_stats``.
    """
    # Basic attributes
    shape = array.shape
    size = array.size
    nbytes = array.nbytes
    type_name = _type_name(array)
    # Dtype
    dtype_str = _dtype_str(array)
    is_complex = np.iscomplexobj(array)

    return {
        "shape": shape,
        "size": size,
        "nbytes": nbytes,
        "type_name": type_name,
        "dtype_str": dtype_str,
        "is_complex": is_complex,
        # Numeric stats computed on the array itself.
        **array_stats(array),
    }
56
+
57
+
58
@monkey_patch_contextmanager(dependencies=["numpy"])
def numpy_monkey_patch():
    """
    Route NumPy array reprs through ``numpy_repr`` via the print options.

    The ``override_repr`` print option that was active beforehand is put back
    after the ``yield`` instead of being cleared unconditionally.
    """
    # A named logger is used instead of the root-level ``logging.info``
    # helper: that helper implicitly configures the root logger when it has
    # no handlers, which would pre-empt a later ``logging.basicConfig`` call.
    logger = logging.getLogger(__name__)

    # ``.get`` tolerates print-option dicts that lack an ``override_repr`` key.
    previous_override = np.get_printoptions().get("override_repr")
    try:
        np.set_printoptions(override_repr=numpy_repr)
        logger.info(
            f"Numpy monkey patching: using {numpy_repr.__name__} for numpy arrays. "
            f"{np.get_printoptions()=}"
        )
        yield
    finally:
        np.set_printoptions(override_repr=previous_override)
        logger.info(
            f"Numpy unmonkey patching: no longer using {numpy_repr.__name__} for numpy arrays. "
            f"{np.get_printoptions()=}"
        )