nshutils 0.19.1__tar.gz → 0.21.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,24 +1,30 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshutils
3
- Version: 0.19.1
3
+ Version: 0.21.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
7
- Requires-Python: >=3.10,<4.0
7
+ Requires-Python: >=3.9,<4.0
8
8
  Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.9
9
10
  Classifier: Programming Language :: Python :: 3.10
10
11
  Classifier: Programming Language :: Python :: 3.11
11
12
  Classifier: Programming Language :: Python :: 3.12
12
13
  Classifier: Programming Language :: Python :: 3.13
13
14
  Provides-Extra: extra
15
+ Provides-Extra: pprint
16
+ Provides-Extra: snoop
14
17
  Requires-Dist: beartype
15
18
  Requires-Dist: jaxtyping
16
- Requires-Dist: lovely-numpy ; extra == "extra"
17
- Requires-Dist: lovely-tensors ; extra == "extra"
19
+ Requires-Dist: lazy-loader
20
+ Requires-Dist: nshconfig
18
21
  Requires-Dist: numpy
19
22
  Requires-Dist: pysnooper ; extra == "extra"
20
- Requires-Dist: rich ; extra == "extra"
21
- Requires-Dist: treescope ; extra == "extra"
23
+ Requires-Dist: pysnooper ; extra == "snoop"
24
+ Requires-Dist: rich[jupyter] ; extra == "extra"
25
+ Requires-Dist: rich[jupyter] ; extra == "pprint"
26
+ Requires-Dist: treescope ; (python_version >= "3.10") and (extra == "extra")
27
+ Requires-Dist: treescope ; (python_version >= "3.10") and (extra == "pprint")
22
28
  Requires-Dist: typing-extensions
23
29
  Requires-Dist: uuid7
24
30
  Project-URL: Homepage, https://github.com/nimashoghi/nshutils
@@ -1,15 +1,25 @@
1
1
  [project]
2
2
  name = "nshutils"
3
- version = "0.19.1"
3
+ version = "0.21.0"
4
4
  description = ""
5
5
  authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
6
- requires-python = ">=3.10,<4.0"
6
+ requires-python = ">=3.9,<4.0"
7
7
  readme = "README.md"
8
8
 
9
- dependencies = ["numpy", "jaxtyping", "typing-extensions", "beartype", "uuid7"]
9
+ dependencies = [
10
+ "lazy-loader",
11
+ "numpy",
12
+ "typing-extensions",
13
+ "jaxtyping",
14
+ "beartype",
15
+ "uuid7",
16
+ "nshconfig",
17
+ ]
10
18
 
11
19
  [project.optional-dependencies]
12
- extra = ["pysnooper", "lovely-numpy", "lovely-tensors", "rich", "treescope"]
20
+ snoop = ["pysnooper"]
21
+ pprint = ["rich[jupyter]", "treescope; python_version >= '3.10'"]
22
+ extra = ["pysnooper", "rich[jupyter]", "treescope; python_version >= '3.10'"]
13
23
 
14
24
  [project.urls]
15
25
  homepage = "https://github.com/nimashoghi/nshutils"
@@ -0,0 +1,20 @@
1
+ from __future__ import annotations
2
+
3
+ import lazy_loader as lazy
4
+
5
# Lazily expose the names declared in the package's ``__init__.pyi`` stub
# (lazy_loader resolves them on first attribute access).
__getattr__, __dir__, __all__ = lazy.attach_stub(__name__, __file__)


# ``importlib.metadata`` is in the standard library on every Python version this
# package supports (Requires-Python >= 3.9), so the former fallback to the
# ``importlib_metadata`` back-port for Python < 3.8 was dead code.
from importlib.metadata import PackageNotFoundError, version  # noqa: E402

try:
    # Version of the installed distribution named after this package.
    __version__ = version(__name__)
except PackageNotFoundError:
    # No distribution metadata is available (package not installed).
    __version__ = "unknown"
@@ -1,29 +1,13 @@
1
- from __future__ import annotations
2
-
3
1
  from . import actsave as actsave
2
+ from . import lovely as lovely
4
3
  from . import typecheck as typecheck
5
4
  from .actsave import ActLoad as ActLoad
6
5
  from .actsave import ActSave as ActSave
7
6
  from .collections import apply_to_collection as apply_to_collection
8
7
  from .display import display as display
9
8
  from .logging import init_python_logging as init_python_logging
10
- from .logging import lovely as lovely
11
- from .logging import pretty as pretty
9
+ from .logging import setup_logging as setup_logging
12
10
  from .snoop import snoop as snoop
13
11
  from .typecheck import tassert as tassert
14
12
  from .typecheck import typecheck_modules as typecheck_modules
15
13
  from .typecheck import typecheck_this_module as typecheck_this_module
16
-
17
- try:
18
- from importlib.metadata import PackageNotFoundError, version
19
- except ImportError:
20
- # For Python <3.8
21
- from importlib_metadata import ( # pyright: ignore[reportMissingImports]
22
- PackageNotFoundError,
23
- version,
24
- )
25
-
26
- try:
27
- __version__ = version(__name__)
28
- except PackageNotFoundError:
29
- __version__ = "unknown"
@@ -9,10 +9,10 @@ from dataclasses import dataclass
9
9
  from functools import wraps
10
10
  from logging import getLogger
11
11
  from pathlib import Path
12
- from typing import TYPE_CHECKING, Generic, TypeAlias, cast, overload
12
+ from typing import TYPE_CHECKING, Generic, Literal, cast, overload
13
13
 
14
14
  import numpy as np
15
- from typing_extensions import Never, ParamSpec, TypeVar, override
15
+ from typing_extensions import Never, ParamSpec, TypeAliasType, TypeVar, override
16
16
  from uuid_extensions import uuid7str
17
17
 
18
18
  from ..collections import apply_to_collection
@@ -21,25 +21,29 @@ if not TYPE_CHECKING:
21
21
  try:
22
22
  import torch # type: ignore
23
23
 
24
- Tensor: TypeAlias = torch.Tensor
24
+ Tensor = torch.Tensor
25
+ _torch_installed = True
25
26
  except ImportError:
26
27
  torch = None
28
+ _torch_installed = False
27
29
 
28
- Tensor: TypeAlias = Never
30
+ Tensor = Never
29
31
  else:
30
32
  import torch # type: ignore
31
33
 
32
- Tensor: TypeAlias = torch.Tensor
33
-
34
+ Tensor = torch.Tensor
35
+ _torch_installed: Literal[True] = True
34
36
 
log = getLogger(__name__)

# These aliases are evaluated at import time. PEP 604 ``X | Y`` unions of plain
# classes (and ``TypeAliasType | ...``) raise ``TypeError`` on Python 3.9, which
# the package declares support for. ``typing.Union`` builds the same union on
# every supported version.
from typing import Union  # noqa: E402

# A single activation value: a Python scalar, string, numpy array, torch tensor
# (``Tensor`` is ``Never`` when torch is not installed), or ``None``.
Value = TypeAliasType(
    "Value",
    Union[int, float, complex, bool, str, np.ndarray, Tensor, None],
)
# Either a value or a callable that produces one.
ValueOrLambda = TypeAliasType(
    "ValueOrLambda",
    Union[Value, Callable[..., Value]],
)
39
43
 
40
44
 
def _torch_is_scripting() -> bool:
    """Return ``torch.jit.is_scripting()``, or ``False`` when torch is not installed.

    The previous condition was inverted. With torch installed it always returned
    ``False``; without torch it fell through to ``torch.jit`` on ``None`` and
    raised ``AttributeError``.
    """
    if not _torch_installed:
        return False

    return torch.jit.is_scripting()
@@ -54,7 +58,7 @@ def _to_numpy(activation: Value) -> np.ndarray:
54
58
  return np.array(activation)
55
59
  elif isinstance(activation, np.ndarray):
56
60
  return activation
57
- elif isinstance(activation, Tensor):
61
+ elif _torch_installed and isinstance(activation, Tensor):
58
62
  activation_ = activation.detach()
59
63
  if activation_.is_floating_point():
60
64
  # NOTE: We need to convert to float32 because [b]float16 is not supported by numpy
@@ -0,0 +1,75 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING
6
+
7
+ if TYPE_CHECKING:
8
+ from .lovely._monkey_patch_all import Library
9
+
10
+
11
def setup_logging(
    *,
    lovely: bool | list[Library] = False,
    treescope: bool = False,
    treescope_autovisualize_arrays: bool = False,
    rich: bool = False,
    rich_tracebacks: bool = False,
    log_level: int | str | None = logging.INFO,
    log_save_dir: Path | None = None,
):
    """Configure root logging and optional pretty-printing integrations.

    Args:
        lovely: ``True`` to monkey-patch array representations via
            ``monkey_patch("auto")``, or an explicit list of libraries to patch.
            Falsy values disable patching.
        treescope: Call ``treescope.basic_interactive_setup`` when running inside
            an IPython session; otherwise an info message is logged.
        treescope_autovisualize_arrays: Forwarded as ``autovisualize_arrays`` to
            ``treescope.basic_interactive_setup``.
        rich: Send console output through ``rich.logging.RichHandler``. If rich
            cannot be imported, a plain ``logging.StreamHandler`` is used.
        rich_tracebacks: Forwarded as ``rich_tracebacks`` to ``RichHandler``.
        log_level: Level for the root logger; ``None`` leaves it unchanged.
        log_save_dir: If given, records are also written to
            ``log_save_dir / "logging.log"``. The directory is created if needed.

    Note:
        As with ``logging.basicConfig``, the root logger's handlers and level are
        left untouched if it already has handlers.
    """
    # Messages produced before the root logger is configured are buffered.
    # The module-level ``logging.info`` helper implicitly calls
    # ``logging.basicConfig()`` when the root logger has no handlers. That would
    # install a default WARNING-level handler and make the ``basicConfig`` call
    # below a no-op.
    deferred_messages: list[str] = []

    if lovely:
        from .lovely._monkey_patch_all import monkey_patch

        monkey_patch("auto" if lovely is True else lovely)

    if treescope:
        try:
            # Check if we're in a Jupyter environment
            from IPython import get_ipython

            if get_ipython() is not None:
                import treescope as _treescope  # type: ignore

                _treescope.basic_interactive_setup(
                    autovisualize_arrays=treescope_autovisualize_arrays
                )
            else:
                deferred_messages.append(
                    "Treescope setup is only supported in Jupyter notebooks. Skipping."
                )
        except ImportError:
            deferred_messages.append(
                "Failed to import `treescope` or `IPython`. Ignoring `treescope` registration"
            )

    log_handlers: list[logging.Handler] = []
    if log_save_dir:
        # ``Path.touch`` raises if the parent directory does not exist.
        log_save_dir.mkdir(parents=True, exist_ok=True)
        log_file = log_save_dir / "logging.log"
        log_file.touch(exist_ok=True)
        log_handlers.append(logging.FileHandler(log_file))

    console_handler: logging.Handler | None = None
    if rich:
        try:
            from rich.logging import RichHandler  # type: ignore

            console_handler = RichHandler(rich_tracebacks=rich_tracebacks)
        except ImportError:
            deferred_messages.append(
                "Failed to import rich. Falling back to default Python logging."
            )
    if console_handler is None:
        # ``basicConfig(handlers=...)`` installs exactly the given handlers. An
        # empty list would leave the root logger with no output at all.
        console_handler = logging.StreamHandler()
    log_handlers.append(console_handler)

    logging.basicConfig(
        level=log_level,
        format="%(message)s",
        datefmt="[%X]",
        handlers=log_handlers,
    )
    for message in deferred_messages:
        logging.info(message)
    logging.info(
        "Logging initialized. "
        f"Lovely: {lovely}, Treescope: {treescope}, Rich: {rich}, "
        f"Log level: {log_level}, Log save dir: {log_save_dir}"
    )


# Alternate name; the package also re-exports this as ``init_python_logging``.
init_python_logging = setup_logging
@@ -0,0 +1,10 @@
1
+ from __future__ import annotations
2
+
3
+ from ._monkey_patch_all import monkey_patch as monkey_patch
4
+ from .config import LovelyConfig as LovelyConfig
5
+ from .jax_ import jax_monkey_patch as jax_monkey_patch
6
+ from .jax_ import jax_repr as jax_repr
7
+ from .numpy_ import numpy_monkey_patch as numpy_monkey_patch
8
+ from .numpy_ import numpy_repr as numpy_repr
9
+ from .torch_ import torch_monkey_patch as torch_monkey_patch
10
+ from .torch_ import torch_repr as torch_repr
@@ -0,0 +1,155 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import importlib.util
5
+ import logging
6
+ from collections.abc import Callable, Iterator
7
+
8
+ from typing_extensions import ParamSpec, TypeAliasType, TypeVar
9
+
10
+ from ..util import ContextResource, resource_factory_contextmanager
11
+ from .utils import LovelyStats, format_tensor_stats
12
+
13
log = logging.getLogger(__name__)

# Type of the array object handled by a stats/repr function; the variance is
# left for type checkers to infer.
TArray = TypeVar("TArray", infer_variance=True)
# Parameters of a wrapped monkey-patch function.
P = ParamSpec("P")

# Function that computes the statistics used to render an array.
LovelyStatsFn = TypeAliasType(
    "LovelyStatsFn", Callable[[TArray], LovelyStats], type_params=(TArray,)
)
# Function that renders an array as its final string representation.
LovelyReprFn = TypeAliasType(
    "LovelyReprFn", Callable[[TArray], str], type_params=(TArray,)
)
28
+
29
+
30
def _find_missing_deps(dependencies: list[str]):
    """Return, in input order, the names in ``dependencies`` that cannot be located.

    A name counts as missing when ``importlib.util.find_spec`` returns ``None``.
    """
    return [name for name in dependencies if importlib.util.find_spec(name) is None]
40
+
41
+
42
def lovely_repr(dependencies: list[str]):
    """Build a decorator that turns an array-statistics function into a repr function.

    Args:
        dependencies: Module names that must be locatable. If any is missing when
            the produced repr function is called, it logs a warning and falls
            back to ``repr(array)``.

    Returns:
        A decorator taking a stats function and returning a function that formats
        those stats with ``format_tensor_stats``.

    Example:
        @lovely_repr(dependencies=["torch"])
        def my_array_stats(array):
            return {...}
    """

    def decorator_fn(array_stats_fn: LovelyStatsFn[TArray]) -> LovelyReprFn[TArray]:
        """Wrap ``array_stats_fn`` so that calling it returns a formatted string."""

        @functools.wraps(array_stats_fn)
        def wrapper(array: TArray) -> str:
            # Dependencies are re-checked on every call.
            missing = _find_missing_deps(dependencies)
            if not missing:
                return format_tensor_stats(array_stats_fn(array))

            log.warning(
                f"Missing dependencies: {', '.join(missing)}. "
                "Skipping lovely representation."
            )
            return repr(array)

        return wrapper

    return decorator_fn
85
+
86
+
87
# Signature of a user-written monkey patch: a callable returning an iterator that
# yields ``None`` (e.g. a generator that patches, yields, then restores).
LovelyMonkeyPatchInputFn = TypeAliasType(
    "LovelyMonkeyPatchInputFn", Callable[P, Iterator[None]], type_params=(P,)
)
# Signature of the same patch after ``monkey_patch_contextmanager`` has wrapped
# it with ``resource_factory_contextmanager``.
LovelyMonkeyPatchFn = TypeAliasType(
    "LovelyMonkeyPatchFn", Callable[P, ContextResource[None]], type_params=(P,)
)
97
+
98
+
99
def _nullcontext_generator():
    """No-op patch body: yield ``None`` exactly once and do nothing else."""
    yield None
102
+
103
+
104
def _wrap_monkey_patch_fn(
    monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
    dependencies: list[str],
) -> LovelyMonkeyPatchInputFn[P]:
    """Guard ``monkey_patch_fn`` behind a call-time dependency check.

    If any of ``dependencies`` is missing, the returned function logs a warning
    and returns a generator that does nothing. Otherwise it returns
    ``monkey_patch_fn(*args, **kwargs)`` unchanged.
    """

    @functools.wraps(monkey_patch_fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterator[None]:
        missing = _find_missing_deps(dependencies)
        if not missing:
            return monkey_patch_fn(*args, **kwargs)

        log.warning(
            f"Missing dependencies: {', '.join(missing)}. "
            "Skipping monkey patch."
        )
        return _nullcontext_generator()

    return wrapper
120
+
121
+
122
def monkey_patch_contextmanager(dependencies: list[str]):
    """Build a decorator that turns a patch generator into a context resource.

    Args:
        dependencies: Module names that must be locatable. If any is missing when
            the patch is invoked, a warning is logged and nothing is patched.

    Returns:
        A decorator that adds the dependency guard to the function and then wraps
        it with ``resource_factory_contextmanager``.

    Example:
        @monkey_patch_contextmanager(dependencies=["torch"])
        def my_array_monkey_patch():
            ...
    """

    def decorator_fn(
        monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
    ) -> LovelyMonkeyPatchFn[P]:
        """Apply the dependency guard, then ``resource_factory_contextmanager``."""
        return resource_factory_contextmanager(
            _wrap_monkey_patch_fn(monkey_patch_fn, dependencies)
        )

    return decorator_fn
@@ -0,0 +1,68 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import importlib.util
5
+ import logging
6
+ from typing import Literal
7
+
8
+ from typing_extensions import TypeAliasType, assert_never
9
+
10
+ from ..util import resource_factory_contextmanager
11
+
12
# Array libraries whose representations ``monkey_patch`` knows how to patch.
Library = TypeAliasType("Library", Literal["numpy", "torch", "jax"])

log = logging.getLogger(__name__)
15
+
16
+
17
def _find_deps() -> list[Library]:
    """Return the patchable libraries that are installed.

    Candidates are probed with ``importlib.util.find_spec`` in the order torch,
    jax, numpy, and that order is kept in the result.
    """
    candidates: tuple[Library, ...] = ("torch", "jax", "numpy")
    return [name for name in candidates if importlib.util.find_spec(name) is not None]
29
+
30
+
31
@resource_factory_contextmanager
def monkey_patch(libraries: list[Library] | Literal["auto"] = "auto"):
    """Monkey patch the array representations of the given libraries.

    Patches are applied before the ``yield`` and undone in reverse order by the
    ``contextlib.ExitStack`` once the generator resumes after it.

    Args:
        libraries: Libraries to patch, or ``"auto"`` to patch every supported
            library found by ``_find_deps`` (torch, jax, numpy).

    Raises:
        ValueError: If the resulting list of libraries is empty.
    """
    if libraries == "auto":
        libraries = _find_deps()

    if not libraries:
        raise ValueError(
            "No libraries found for monkey patching. "
            "Please install numpy, torch, or jax."
        )

    with contextlib.ExitStack() as stack:
        for library in libraries:
            # if/elif rather than ``match``: the package supports Python 3.9,
            # where the ``match`` statement is a syntax error.
            if library == "torch":
                from .torch_ import torch_monkey_patch

                stack.enter_context(torch_monkey_patch())
            elif library == "jax":
                from .jax_ import jax_monkey_patch

                stack.enter_context(jax_monkey_patch())
            elif library == "numpy":
                from .numpy_ import numpy_monkey_patch

                stack.enter_context(numpy_monkey_patch())
            else:
                assert_never(library)

        log.info(
            f"Monkey patched libraries: {', '.join(libraries)}. "
            "You can now use the lovely functions with these libraries."
        )
        yield

    # Logged only after the ExitStack has actually restored the originals.
    log.info(f"Unmonkey patched libraries: {', '.join(libraries)}.")
@@ -0,0 +1,47 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import ClassVar, Optional
4
+
5
+ import nshconfig as C
6
+ from typing_extensions import Self
7
+
8
+
9
class LovelyConfig(C.Config):
    """
    Formatting options for the lovely array representations.

    ``LovelyConfig.instance()`` returns a shared instance that is created with
    default values on first use.
    """

    # NOTE: annotations use ``typing.Optional`` rather than ``X | None``. The
    # config base class evaluates field annotations at class-creation time, and
    # ``X | None`` cannot be evaluated on Python 3.9, which the package supports.

    precision: int = 3
    """Number of digits after the decimal point."""

    threshold_max: int = 3
    """Absolute values larger than 10^3 use scientific notation."""

    threshold_min: int = -4
    """Absolute values smaller than 10^-4 use scientific notation."""

    sci_mode: Optional[bool] = None
    """Force scientific notation (None=auto)."""

    show_mem_above: int = 1024
    """Show memory size if above this threshold (bytes)."""

    color: bool = True
    """Use ANSI colors in text."""

    indent: int = 2
    """Indent for nested representation."""

    config_instance: ClassVar[Optional[Self]] = None
    """Singleton instance of the LovelyConfig class."""

    @classmethod
    def instance(cls) -> Self:
        """
        Get the singleton instance of the LovelyConfig class.
        If it doesn't exist, create it with default values.
        """
        if cls.config_instance is None:
            cls.config_instance = cls()
        return cls.config_instance
@@ -0,0 +1,89 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, cast
4
+
5
+ import numpy as np
6
+
7
+ from ._base import lovely_repr, monkey_patch_contextmanager
8
+ from .utils import LovelyStats, array_stats, patch_to
9
+
10
+ if TYPE_CHECKING:
11
+ import jax
12
+
13
+
14
def _type_name(array: jax.Array):
    """Short display name of ``type(array)``; jax's ``ArrayImpl`` is shown as ``"array"``."""
    name = type(array).__name__.rpartition(".")[2]
    if name == "ArrayImpl":
        return "array"
    return name
17
+
18
+
19
# Abbreviated display names for dtypes. Dtypes not listed here are displayed
# by their full name.
_DT_NAMES = dict(
    float16="f16",
    float32="f32",
    float64="f64",
    uint8="u8",
    uint16="u16",
    uint32="u32",
    uint64="u64",
    int8="i8",
    int16="i16",
    int32="i32",
    int64="i64",
    bfloat16="bf16",
    complex64="c64",
    complex128="c128",
)
35
+
36
+
37
def _dtype_str(array: jax.Array) -> str:
    """Abbreviated dtype name of ``array`` (e.g. ``"f32"``); unknown dtypes keep their name."""
    full_name = str(array.dtype).rpartition(".")[2]
    return _DT_NAMES.get(full_name, full_name)
41
+
42
+
43
def _device(array: jax.Array) -> str:
    """Describe the device holding ``array``: ``"cpu"`` or ``"<platform>:<id>"``.

    ``array.device`` may be a bound method or the device object itself; the
    ``callable`` check handles both. The cast is purely static: the previous
    version imported the private ``jaxlib.xla_extension`` module at runtime just
    to name the type, which ties the function to jaxlib internals.
    """
    device = array.device
    if callable(device):
        device = device()

    device = cast("jax.Device", device)
    if device.platform == "cpu":
        return "cpu"

    return f"{device.platform}:{device.id}"
54
+
55
+
56
@lovely_repr(dependencies=["jax"])
def jax_repr(array: jax.Array) -> LovelyStats:
    """Gather the statistics that ``lovely_repr`` formats for a jax array.

    Descriptive attributes (shape, size, byte count, type and dtype names,
    device) are read from the jax array. The numeric summary comes from
    ``array_stats`` applied to ``np.asarray(array)``.
    """
    import jax.numpy as jnp

    shape = array.shape
    size = array.size
    nbytes = array.nbytes
    type_name = _type_name(array)
    dtype_str = _dtype_str(array)
    is_complex = jnp.iscomplexobj(array)
    device = _device(array)
    # Numeric statistics; ``array_stats`` picks the appropriate computation for
    # complex vs. real data. Merged last, so its keys win on collision.
    numeric_stats = array_stats(np.asarray(array))

    return {
        "shape": shape,
        "size": size,
        "nbytes": nbytes,
        "type_name": type_name,
        "dtype_str": dtype_str,
        "is_complex": is_complex,
        "device": device,
        **numeric_stats,
    }
74
+
75
+
76
@monkey_patch_contextmanager(dependencies=["jax"])
def jax_monkey_patch():
    """Route ``repr()``/``str()`` of jax ``ArrayImpl`` objects through ``jax_repr``.

    The original ``__repr__``/``__str__`` are restored in ``finally`` whenever the
    generator is resumed after its ``yield``, closed, or has an exception thrown in.
    """
    from jax._src import array as jax_array_module

    array_impl = jax_array_module.ArrayImpl
    originals = {
        "__repr__": array_impl.__repr__,
        "__str__": array_impl.__str__,
    }
    try:
        for dunder in originals:
            patch_to(array_impl, dunder, jax_repr)

        yield
    finally:
        for dunder, original in originals.items():
            patch_to(array_impl, dunder, original)