nshutils 0.19.0__tar.gz → 0.20.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshutils
3
- Version: 0.19.0
3
+ Version: 0.20.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -11,14 +11,19 @@ Classifier: Programming Language :: Python :: 3.11
11
11
  Classifier: Programming Language :: Python :: 3.12
12
12
  Classifier: Programming Language :: Python :: 3.13
13
13
  Provides-Extra: extra
14
+ Provides-Extra: pprint
15
+ Provides-Extra: snoop
14
16
  Requires-Dist: beartype
15
17
  Requires-Dist: jaxtyping
16
- Requires-Dist: lovely-numpy ; extra == "extra"
17
- Requires-Dist: lovely-tensors ; extra == "extra"
18
+ Requires-Dist: lazy-loader
19
+ Requires-Dist: nshconfig
18
20
  Requires-Dist: numpy
19
21
  Requires-Dist: pysnooper ; extra == "extra"
20
- Requires-Dist: rich ; extra == "extra"
22
+ Requires-Dist: pysnooper ; extra == "snoop"
23
+ Requires-Dist: rich[jupyter] ; extra == "extra"
24
+ Requires-Dist: rich[jupyter] ; extra == "pprint"
21
25
  Requires-Dist: treescope ; extra == "extra"
26
+ Requires-Dist: treescope ; extra == "pprint"
22
27
  Requires-Dist: typing-extensions
23
28
  Requires-Dist: uuid7
24
29
  Project-URL: Homepage, https://github.com/nimashoghi/nshutils
@@ -1,15 +1,25 @@
1
1
  [project]
2
2
  name = "nshutils"
3
- version = "0.19.0"
3
+ version = "0.20.0"
4
4
  description = ""
5
5
  authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
6
6
  requires-python = ">=3.10,<4.0"
7
7
  readme = "README.md"
8
8
 
9
- dependencies = ["numpy", "jaxtyping", "typing-extensions", "beartype", "uuid7"]
9
+ dependencies = [
10
+ "lazy-loader",
11
+ "numpy",
12
+ "typing-extensions",
13
+ "jaxtyping",
14
+ "beartype",
15
+ "uuid7",
16
+ "nshconfig",
17
+ ]
10
18
 
11
19
  [project.optional-dependencies]
12
- extra = ["pysnooper", "lovely-numpy", "lovely-tensors", "rich", "treescope"]
20
+ snoop = ["pysnooper"]
21
+ pprint = ["rich[jupyter]", "treescope"]
22
+ extra = ["pysnooper", "rich[jupyter]", "treescope"]
13
23
 
14
24
  [project.urls]
15
25
  homepage = "https://github.com/nimashoghi/nshutils"
@@ -0,0 +1,20 @@
1
+ from __future__ import annotations
2
+
3
+ import lazy_loader as lazy
4
+
5
+ __getattr__, __dir__, __all__ = lazy.attach_stub(__name__, __file__)
6
+
7
+
8
+ try:
9
+ from importlib.metadata import PackageNotFoundError, version
10
+ except ImportError:
11
+ # For Python <3.8
12
+ from importlib_metadata import ( # pyright: ignore[reportMissingImports]
13
+ PackageNotFoundError,
14
+ version,
15
+ )
16
+
17
+ try:
18
+ __version__ = version(__name__)
19
+ except PackageNotFoundError:
20
+ __version__ = "unknown"
@@ -1,14 +1,12 @@
1
- from __future__ import annotations
2
-
3
1
  from . import actsave as actsave
2
+ from . import lovely as lovely
4
3
  from . import typecheck as typecheck
5
4
  from .actsave import ActLoad as ActLoad
6
5
  from .actsave import ActSave as ActSave
7
6
  from .collections import apply_to_collection as apply_to_collection
8
7
  from .display import display as display
9
8
  from .logging import init_python_logging as init_python_logging
10
- from .logging import lovely as lovely
11
- from .logging import pretty as pretty
9
+ from .logging import setup_logging as setup_logging
12
10
  from .snoop import snoop as snoop
13
11
  from .typecheck import tassert as tassert
14
12
  from .typecheck import typecheck_modules as typecheck_modules
@@ -12,7 +12,7 @@ from pathlib import Path
12
12
  from typing import TYPE_CHECKING, Generic, TypeAlias, cast, overload
13
13
 
14
14
  import numpy as np
15
- from typing_extensions import Never, ParamSpec, TypeVar, override
15
+ from typing_extensions import Never, ParamSpec, TypeAliasType, TypeVar, override
16
16
  from uuid_extensions import uuid7str
17
17
 
18
18
  from ..collections import apply_to_collection
@@ -21,21 +21,23 @@ if not TYPE_CHECKING:
21
21
  try:
22
22
  import torch # type: ignore
23
23
 
24
- Tensor: TypeAlias = torch.Tensor
24
+ Tensor = torch.Tensor
25
25
  except ImportError:
26
26
  torch = None
27
27
 
28
- Tensor: TypeAlias = Never
28
+ Tensor = Never
29
29
  else:
30
30
  import torch # type: ignore
31
31
 
32
- Tensor: TypeAlias = torch.Tensor
32
+ Tensor = torch.Tensor
33
33
 
34
34
 
35
35
  log = getLogger(__name__)
36
36
 
37
- Value: TypeAlias = int | float | complex | bool | str | np.ndarray | Tensor | None
38
- ValueOrLambda: TypeAlias = Value | Callable[..., Value]
37
+ Value = TypeAliasType(
38
+ "Value", int | float | complex | bool | str | np.ndarray | Tensor | None
39
+ )
40
+ ValueOrLambda = TypeAliasType("ValueOrLambda", Value | Callable[..., Value])
39
41
 
40
42
 
41
43
  def _torch_is_scripting() -> bool:
@@ -0,0 +1,75 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING
6
+
7
+ if TYPE_CHECKING:
8
+ from .lovely._monkey_patch_all import Library
9
+
10
+
11
+ def setup_logging(
12
+ *,
13
+ lovely: bool | list[Library] = False,
14
+ treescope: bool = False,
15
+ treescope_autovisualize_arrays: bool = False,
16
+ rich: bool = False,
17
+ rich_tracebacks: bool = False,
18
+ log_level: int | str | None = logging.INFO,
19
+ log_save_dir: Path | None = None,
20
+ ):
21
+ if lovely:
22
+ from .lovely._monkey_patch_all import monkey_patch
23
+
24
+ monkey_patch("auto" if lovely is True else lovely)
25
+
26
+ if treescope:
27
+ try:
28
+ # Check if we're in a Jupyter environment
29
+ from IPython import get_ipython
30
+
31
+ if get_ipython() is not None:
32
+ import treescope as _treescope # type: ignore
33
+
34
+ _treescope.basic_interactive_setup(
35
+ autovisualize_arrays=treescope_autovisualize_arrays
36
+ )
37
+ else:
38
+ logging.info(
39
+ "Treescope setup is only supported in Jupyter notebooks. Skipping."
40
+ )
41
+ except ImportError:
42
+ logging.info(
43
+ "Failed to import `treescope` or `IPython`. Ignoring `treescope` registration"
44
+ )
45
+
46
+ log_handlers: list[logging.Handler] = []
47
+ if log_save_dir:
48
+ log_file = log_save_dir / "logging.log"
49
+ log_file.touch(exist_ok=True)
50
+ log_handlers.append(logging.FileHandler(log_file))
51
+
52
+ if rich:
53
+ try:
54
+ from rich.logging import RichHandler # type: ignore
55
+
56
+ log_handlers.append(RichHandler(rich_tracebacks=rich_tracebacks))
57
+ except ImportError:
58
+ logging.info(
59
+ "Failed to import rich. Falling back to default Python logging."
60
+ )
61
+
62
+ logging.basicConfig(
63
+ level=log_level,
64
+ format="%(message)s",
65
+ datefmt="[%X]",
66
+ handlers=log_handlers,
67
+ )
68
+ logging.info(
69
+ "Logging initialized. "
70
+ f"Lovely: {lovely}, Treescope: {treescope}, Rich: {rich}, "
71
+ f"Log level: {log_level}, Log save dir: {log_save_dir}"
72
+ )
73
+
74
+
75
+ init_python_logging = setup_logging
@@ -0,0 +1,10 @@
1
+ from __future__ import annotations
2
+
3
+ from ._monkey_patch_all import monkey_patch as monkey_patch
4
+ from .config import LovelyConfig as LovelyConfig
5
+ from .jax_ import jax_monkey_patch as jax_monkey_patch
6
+ from .jax_ import jax_repr as jax_repr
7
+ from .numpy_ import numpy_monkey_patch as numpy_monkey_patch
8
+ from .numpy_ import numpy_repr as numpy_repr
9
+ from .torch_ import torch_monkey_patch as torch_monkey_patch
10
+ from .torch_ import torch_repr as torch_repr
@@ -0,0 +1,155 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import importlib.util
5
+ import logging
6
+ from collections.abc import Callable, Iterator
7
+
8
+ from typing_extensions import ParamSpec, TypeAliasType, TypeVar
9
+
10
+ from ..util import ContextResource, resource_factory_contextmanager
11
+ from .utils import LovelyStats, format_tensor_stats
12
+
13
+ log = logging.getLogger(__name__)
14
+
15
+ TArray = TypeVar("TArray", infer_variance=True)
16
+ P = ParamSpec("P")
17
+
18
+ LovelyStatsFn = TypeAliasType(
19
+ "LovelyStatsFn",
20
+ Callable[[TArray], LovelyStats],
21
+ type_params=(TArray,),
22
+ )
23
+ LovelyReprFn = TypeAliasType(
24
+ "LovelyReprFn",
25
+ Callable[[TArray], str],
26
+ type_params=(TArray,),
27
+ )
28
+
29
+
30
+ def _find_missing_deps(dependencies: list[str]):
31
+ missing_deps: list[str] = []
32
+
33
+ for dep in dependencies:
34
+ if importlib.util.find_spec(dep) is not None:
35
+ continue
36
+
37
+ missing_deps.append(dep)
38
+
39
+ return missing_deps
40
+
41
+
42
+ def lovely_repr(dependencies: list[str]):
43
+ """
44
+ Decorator to create a lovely representation function for an array.
45
+
46
+ Args:
47
+ dependencies: List of dependencies to check before running the function.
48
+ If any dependency is not available, the function will not run.
49
+
50
+ Returns:
51
+ A decorator function that takes a function and returns a lovely representation function.
52
+
53
+ Example:
54
+ @lovely_repr(dependencies=["torch"])
55
+ def my_array_stats(array):
56
+ return {...}
57
+ """
58
+
59
+ def decorator_fn(array_stats_fn: LovelyStatsFn[TArray]) -> LovelyReprFn[TArray]:
60
+ """
61
+ Decorator to create a lovely representation function for an array.
62
+
63
+ Args:
64
+ array_stats_fn: A function that takes an array and returns its stats.
65
+
66
+ Returns:
67
+ A function that takes an array and returns its lovely representation.
68
+ """
69
+
70
+ @functools.wraps(array_stats_fn)
71
+ def wrapper(array: TArray) -> str:
72
+ if missing_deps := _find_missing_deps(dependencies):
73
+ log.warning(
74
+ f"Missing dependencies: {', '.join(missing_deps)}. "
75
+ "Skipping lovely representation."
76
+ )
77
+ return repr(array)
78
+
79
+ stats = array_stats_fn(array)
80
+ return format_tensor_stats(stats)
81
+
82
+ return wrapper
83
+
84
+ return decorator_fn
85
+
86
+
87
+ LovelyMonkeyPatchInputFn = TypeAliasType(
88
+ "LovelyMonkeyPatchInputFn",
89
+ Callable[P, Iterator[None]],
90
+ type_params=(P,),
91
+ )
92
+ LovelyMonkeyPatchFn = TypeAliasType(
93
+ "LovelyMonkeyPatchFn",
94
+ Callable[P, ContextResource[None]],
95
+ type_params=(P,),
96
+ )
97
+
98
+
99
+ def _nullcontext_generator():
100
+ """A generator that does nothing."""
101
+ yield
102
+
103
+
104
+ def _wrap_monkey_patch_fn(
105
+ monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
106
+ dependencies: list[str],
107
+ ) -> LovelyMonkeyPatchInputFn[P]:
108
+ @functools.wraps(monkey_patch_fn)
109
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterator[None]:
110
+ if missing_deps := _find_missing_deps(dependencies):
111
+ log.warning(
112
+ f"Missing dependencies: {', '.join(missing_deps)}. "
113
+ "Skipping monkey patch."
114
+ )
115
+ return _nullcontext_generator()
116
+
117
+ return monkey_patch_fn(*args, **kwargs)
118
+
119
+ return wrapper
120
+
121
+
122
+ def monkey_patch_contextmanager(dependencies: list[str]):
123
+ """
124
+ Decorator to create a monkey patch function for an array.
125
+
126
+ Args:
127
+ dependencies: List of dependencies to check before running the function.
128
+ If any dependency is not available, the function will not run.
129
+
130
+ Returns:
131
+ A decorator function that takes a function and returns a monkey patch function.
132
+
133
+ Example:
134
+ @monkey_patch_contextmanager(dependencies=["torch"])
135
+ def my_array_monkey_patch():
136
+ ...
137
+ """
138
+
139
+ def decorator_fn(
140
+ monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
141
+ ) -> LovelyMonkeyPatchFn[P]:
142
+ """
143
+ Decorator to create a monkey patch function for an array.
144
+
145
+ Args:
146
+ monkey_patch_fn: A function that applies the monkey patch.
147
+
148
+ Returns:
149
+ A function that applies the monkey patch.
150
+ """
151
+
152
+ wrapped_fn = _wrap_monkey_patch_fn(monkey_patch_fn, dependencies)
153
+ return resource_factory_contextmanager(wrapped_fn)
154
+
155
+ return decorator_fn
@@ -0,0 +1,68 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import importlib.util
5
+ import logging
6
+ from typing import Literal
7
+
8
+ from typing_extensions import TypeAliasType, assert_never
9
+
10
+ from ..util import resource_factory_contextmanager
11
+
12
+ Library = TypeAliasType("Library", Literal["numpy", "torch", "jax"])
13
+
14
+ log = logging.getLogger(__name__)
15
+
16
+
17
+ def _find_deps() -> list[Library]:
18
+ """
19
+ Find available libraries for monkey patching.
20
+ """
21
+ deps: list[Library] = []
22
+ if importlib.util.find_spec("torch") is not None:
23
+ deps.append("torch")
24
+ if importlib.util.find_spec("jax") is not None:
25
+ deps.append("jax")
26
+ if importlib.util.find_spec("numpy") is not None:
27
+ deps.append("numpy")
28
+ return deps
29
+
30
+
31
+ @resource_factory_contextmanager
32
+ def monkey_patch(libraries: list[Library] | Literal["auto"] = "auto"):
33
+ if libraries == "auto":
34
+ libraries = _find_deps()
35
+
36
+ if not libraries:
37
+ raise ValueError(
38
+ "No libraries found for monkey patching. "
39
+ "Please install numpy, torch, or jax."
40
+ )
41
+
42
+ with contextlib.ExitStack() as stack:
43
+ for library in libraries:
44
+ match library:
45
+ case "torch":
46
+ from .torch_ import torch_monkey_patch
47
+
48
+ stack.enter_context(torch_monkey_patch())
49
+ case "jax":
50
+ from .jax_ import jax_monkey_patch
51
+
52
+ stack.enter_context(jax_monkey_patch())
53
+ case "numpy":
54
+ from .numpy_ import numpy_monkey_patch
55
+
56
+ stack.enter_context(numpy_monkey_patch())
57
+ case _:
58
+ assert_never(library)
59
+
60
+ log.info(
61
+ f"Monkey patched libraries: {', '.join(libraries)}. "
62
+ "You can now use the lovely functions with these libraries."
63
+ )
64
+ yield
65
+ log.info(
66
+ f"Unmonkey patched libraries: {', '.join(libraries)}. "
67
+ "You can now use the lovely functions with these libraries."
68
+ )
@@ -0,0 +1,47 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import ClassVar
4
+
5
+ import nshconfig as C
6
+ from typing_extensions import Self
7
+
8
+
9
+ class LovelyConfig(C.Config):
10
+ """
11
+ This class is used to manage the configuration of the Lovely library.
12
+ It inherits from the Config class in the nshconfig module.
13
+ """
14
+
15
+ precision: int = 3
16
+ """Number of digits after the decimal point."""
17
+
18
+ threshold_max: int = 3
19
+ """Absolute values larger than 10^3 use scientific notation."""
20
+
21
+ threshold_min: int = -4
22
+ """Absolute values smaller than 10^-4 use scientific notation."""
23
+
24
+ sci_mode: bool | None = None
25
+ """Force scientific notation (None=auto)."""
26
+
27
+ show_mem_above: int = 1024
28
+ """Show memory size if above this threshold (bytes)."""
29
+
30
+ color: bool = True
31
+ """Use ANSI colors in text."""
32
+
33
+ indent: int = 2
34
+ """Indent for nested representation."""
35
+
36
+ config_instance: ClassVar[Self | None] = None
37
+ """Singleton instance of the LovelyConfig class."""
38
+
39
+ @classmethod
40
+ def instance(cls) -> Self:
41
+ """
42
+ Get the singleton instance of the LovelyConfig class.
43
+ If it doesn't exist, create it.
44
+ """
45
+ if cls.config_instance is None:
46
+ cls.config_instance = cls()
47
+ return cls.config_instance
@@ -0,0 +1,89 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, cast
4
+
5
+ import numpy as np
6
+
7
+ from ._base import lovely_repr, monkey_patch_contextmanager
8
+ from .utils import LovelyStats, array_stats, patch_to
9
+
10
+ if TYPE_CHECKING:
11
+ import jax
12
+
13
+
14
+ def _type_name(array: jax.Array):
15
+ type_name = type(array).__name__.rsplit(".", 1)[-1]
16
+ return "array" if type_name == "ArrayImpl" else type_name
17
+
18
+
19
+ _DT_NAMES = {
20
+ "float16": "f16",
21
+ "float32": "f32",
22
+ "float64": "f64",
23
+ "uint8": "u8",
24
+ "uint16": "u16",
25
+ "uint32": "u32",
26
+ "uint64": "u64",
27
+ "int8": "i8",
28
+ "int16": "i16",
29
+ "int32": "i32",
30
+ "int64": "i64",
31
+ "bfloat16": "bf16",
32
+ "complex64": "c64",
33
+ "complex128": "c128",
34
+ }
35
+
36
+
37
+ def _dtype_str(array: jax.Array) -> str:
38
+ dtype_base = str(array.dtype).rsplit(".", 1)[-1]
39
+ dtype_base = _DT_NAMES.get(dtype_base, dtype_base)
40
+ return dtype_base
41
+
42
+
43
+ def _device(array: jax.Array) -> str:
44
+ from jaxlib.xla_extension import Device
45
+
46
+ if callable(device := array.device):
47
+ device = device()
48
+
49
+ device = cast(Device, device)
50
+ if device.platform == "cpu":
51
+ return "cpu"
52
+
53
+ return f"{device.platform}:{device.id}"
54
+
55
+
56
+ @lovely_repr(dependencies=["jax"])
57
+ def jax_repr(array: jax.Array) -> LovelyStats:
58
+ import jax.numpy as jnp
59
+
60
+ return {
61
+ # Basic attributes
62
+ "shape": array.shape,
63
+ "size": array.size,
64
+ "nbytes": array.nbytes,
65
+ "type_name": _type_name(array),
66
+ # Dtype
67
+ "dtype_str": _dtype_str(array),
68
+ "is_complex": jnp.iscomplexobj(array),
69
+ # Device
70
+ "device": _device(array),
71
+ # Depending on whether the tensor is complex or not, we will call the appropriate stats function
72
+ **array_stats(np.asarray(array)),
73
+ }
74
+
75
+
76
+ @monkey_patch_contextmanager(dependencies=["jax"])
77
+ def jax_monkey_patch():
78
+ from jax._src import array
79
+
80
+ prev_repr = array.ArrayImpl.__repr__
81
+ prev_str = array.ArrayImpl.__str__
82
+ try:
83
+ patch_to(array.ArrayImpl, "__repr__", jax_repr)
84
+ patch_to(array.ArrayImpl, "__str__", jax_repr)
85
+
86
+ yield
87
+ finally:
88
+ patch_to(array.ArrayImpl, "__repr__", prev_repr)
89
+ patch_to(array.ArrayImpl, "__str__", prev_str)
@@ -0,0 +1,72 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+
5
+ import numpy as np
6
+
7
+ from ._base import lovely_repr, monkey_patch_contextmanager
8
+ from .utils import LovelyStats, array_stats
9
+
10
+
11
+ def _type_name(array: np.ndarray):
12
+ return (
13
+ "array"
14
+ if type(array) is np.ndarray
15
+ else type(array).__name__.rsplit(".", 1)[-1]
16
+ )
17
+
18
+
19
+ _DT_NAMES = {
20
+ "float16": "f16",
21
+ "float32": "f32",
22
+ "float64": "", # Default dtype in numpy
23
+ "uint8": "u8",
24
+ "uint16": "u16",
25
+ "uint32": "u32",
26
+ "uint64": "u64",
27
+ "int8": "i8",
28
+ "int16": "i16",
29
+ "int32": "i32",
30
+ "int64": "i64",
31
+ "complex64": "c64",
32
+ "complex128": "c128",
33
+ }
34
+
35
+
36
+ def _dtype_str(array: np.ndarray) -> str:
37
+ dtype_base = str(array.dtype).rsplit(".", 1)[-1]
38
+ dtype_base = _DT_NAMES.get(dtype_base, dtype_base)
39
+ return dtype_base
40
+
41
+
42
+ @lovely_repr(dependencies=["numpy"])
43
+ def numpy_repr(array: np.ndarray) -> LovelyStats:
44
+ return {
45
+ # Basic attributes
46
+ "shape": array.shape,
47
+ "size": array.size,
48
+ "nbytes": array.nbytes,
49
+ "type_name": _type_name(array),
50
+ # Dtype
51
+ "dtype_str": _dtype_str(array),
52
+ "is_complex": np.iscomplexobj(array),
53
+ # Depending on whether the tensor is complex or not, we will call the appropriate stats function
54
+ **array_stats(array),
55
+ }
56
+
57
+
58
+ @monkey_patch_contextmanager(dependencies=["numpy"])
59
+ def numpy_monkey_patch():
60
+ try:
61
+ np.set_printoptions(override_repr=numpy_repr)
62
+ logging.info(
63
+ f"Numpy monkey patching: using {numpy_repr.__name__} for numpy arrays. "
64
+ f"{np.get_printoptions()=}"
65
+ )
66
+ yield
67
+ finally:
68
+ np.set_printoptions(override_repr=None)
69
+ logging.info(
70
+ f"Numpy unmonkey patching: using {numpy_repr.__name__} for numpy arrays. "
71
+ f"{np.get_printoptions()=}"
72
+ )
@@ -0,0 +1,99 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import numpy as np
6
+
7
+ from ._base import lovely_repr, monkey_patch_contextmanager
8
+ from .utils import LovelyStats, array_stats, patch_to
9
+
10
+ if TYPE_CHECKING:
11
+ import torch
12
+
13
+
14
+ def _type_name(tensor: torch.Tensor):
15
+ import torch
16
+
17
+ return (
18
+ "tensor"
19
+ if type(tensor) is torch.Tensor
20
+ else type(tensor).__name__.split(".")[-1]
21
+ )
22
+
23
+
24
+ _DT_NAMES = {
25
+ "float32": "", # Default dtype
26
+ "float16": "f16",
27
+ "float64": "f64",
28
+ "bfloat16": "bf16",
29
+ "uint8": "u8",
30
+ "int8": "i8",
31
+ "int16": "i16",
32
+ "int32": "i32",
33
+ "int64": "i64",
34
+ "complex32": "c32",
35
+ "complex64": "c64",
36
+ "complex128": "c128",
37
+ }
38
+
39
+
40
+ def _dtype_str(tensor: torch.Tensor) -> str:
41
+ dtype_base = str(tensor.dtype).rsplit(".", 1)[-1]
42
+ dtype_base = _DT_NAMES.get(dtype_base, dtype_base)
43
+ return dtype_base
44
+
45
+
46
+ def _to_np(tensor: torch.Tensor) -> np.ndarray:
47
+ import torch
48
+
49
+ # Get tensor data as CPU NumPy array for analysis
50
+ t_cpu = tensor.detach().cpu()
51
+
52
+ # Convert bfloat16 to float32 for numpy compatibility
53
+ if tensor.dtype == torch.bfloat16:
54
+ t_cpu = t_cpu.to(torch.float32)
55
+
56
+ # Convert to NumPy
57
+ t_np = t_cpu.numpy()
58
+
59
+ return t_np
60
+
61
+
62
+ @lovely_repr(dependencies=["torch"])
63
+ def torch_repr(tensor: torch.Tensor) -> LovelyStats:
64
+ return {
65
+ # Basic attributes
66
+ "shape": tensor.shape,
67
+ "size": tensor.numel(),
68
+ "nbytes": tensor.element_size() * tensor.numel(),
69
+ "type_name": _type_name(tensor),
70
+ # Device
71
+ "device": str(tensor.device) if tensor.device else None,
72
+ "is_meta": device.type == "meta" if (device := tensor.device) else False,
73
+ # Grad
74
+ "requires_grad": tensor.requires_grad,
75
+ # Dtype
76
+ "dtype_str": _dtype_str(tensor),
77
+ "is_complex": tensor.is_complex(),
78
+ # Depending on whether the tensor is complex or not, we will call the appropriate stats function
79
+ **array_stats(_to_np(tensor)),
80
+ }
81
+
82
+
83
+ @monkey_patch_contextmanager(dependencies=["torch"])
84
+ def torch_monkey_patch():
85
+ import torch
86
+
87
+ original_repr = torch.Tensor.__repr__
88
+ original_str = torch.Tensor.__str__
89
+ original_parameter_repr = torch.nn.Parameter.__repr__
90
+ try:
91
+ patch_to(torch.Tensor, "__repr__", torch_repr)
92
+ patch_to(torch.Tensor, "__str__", torch_repr)
93
+ del torch.nn.Parameter.__repr__
94
+
95
+ yield
96
+ finally:
97
+ patch_to(torch.Tensor, "__repr__", original_repr)
98
+ patch_to(torch.Tensor, "__str__", original_str)
99
+ patch_to(torch.nn.Parameter, "__repr__", original_parameter_repr)
@@ -0,0 +1,345 @@
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ from collections import defaultdict
5
+ from collections.abc import Callable, Sequence
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ from typing_extensions import TypedDict
10
+
11
+ from .config import LovelyConfig
12
+
13
+
14
+ class LovelyStats(TypedDict, total=False):
15
+ """Statistics for tensor representation."""
16
+
17
+ # Basic tensor information
18
+ shape: Sequence[int]
19
+ size: int
20
+ type_name: str
21
+ dtype_str: str
22
+ device: str | None
23
+ nbytes: int
24
+ requires_grad: bool
25
+ is_meta: bool
26
+
27
+ # Content flags
28
+ all_zeros: bool
29
+ has_nan: bool
30
+ has_pos_inf: bool
31
+ has_neg_inf: bool
32
+ is_complex: bool
33
+
34
+ # Numeric statistics
35
+ min: float | None
36
+ max: float | None
37
+ mean: float | None
38
+ std: float | None
39
+
40
+ # Complex number statistics
41
+ mag_min: float | None
42
+ mag_max: float | None
43
+ real_min: float | None
44
+ real_max: float | None
45
+ imag_min: float | None
46
+ imag_max: float | None
47
+
48
+ # Representation
49
+ values_str: str | None
50
+
51
+
52
+ # Formatting utilities
53
+ def sci_mode(f: float) -> bool:
54
+ """Determine if a float should be displayed in scientific notation."""
55
+ config = LovelyConfig.instance()
56
+ return (abs(f) < 10**config.threshold_min) or (abs(f) > 10**config.threshold_max)
57
+
58
+
59
+ def pretty_str(x: Any) -> str:
60
+ """Format a number or array for pretty display.
61
+
62
+ Works with scalars, numpy arrays, torch tensors, and jax arrays.
63
+ """
64
+ if isinstance(x, int):
65
+ return f"{x}"
66
+ elif isinstance(x, float):
67
+ if x == 0.0:
68
+ return "0."
69
+
70
+ sci = (
71
+ sci_mode(x)
72
+ if LovelyConfig.instance().sci_mode is None
73
+ else LovelyConfig.instance().sci_mode
74
+ )
75
+ fmt = f"{{:.{LovelyConfig.instance().precision}{'e' if sci else 'f'}}}"
76
+ return fmt.format(x)
77
+ elif isinstance(x, complex):
78
+ # Handle complex numbers
79
+ real_part = pretty_str(x.real)
80
+ imag_part = pretty_str(abs(x.imag))
81
+ sign = "+" if x.imag >= 0 else "-"
82
+ return f"{real_part}{sign}{imag_part}j"
83
+
84
+ # Handle array-like objects
85
+ try:
86
+ if hasattr(x, "ndim") and x.ndim == 0:
87
+ return pretty_str(x.item())
88
+ elif hasattr(x, "shape") and len(x.shape) > 0:
89
+ slices = [pretty_str(x[i]) for i in range(min(x.shape[0], 10))]
90
+ if x.shape[0] > 10:
91
+ slices.append("...")
92
+ return "[" + ", ".join(slices) + "]"
93
+ except:
94
+ pass
95
+
96
+ # Fallback
97
+ return str(x)
98
+
99
+
100
+ def sparse_join(items: list[str | None], sep: str = " ") -> str:
101
+ """Join non-empty strings with a separator."""
102
+ return sep.join([item for item in items if item])
103
+
104
+
105
+ def ansi_color(s: str, col: str, use_color: bool = True) -> str:
106
+ """Add ANSI color to a string if use_color is True."""
107
+ if not use_color:
108
+ return s
109
+
110
+ style = defaultdict(str)
111
+ style["grey"] = "\x1b[38;2;127;127;127m"
112
+ style["red"] = "\x1b[31m"
113
+ end_style = "\x1b[0m"
114
+
115
+ return style[col] + s + end_style
116
+
117
+
118
+ def bytes_to_human(num_bytes: int) -> str:
119
+ """Convert bytes to a human-readable string (b, Kb, Mb, Gb)."""
120
+ units = ["b", "Kb", "Mb", "Gb"]
121
+
122
+ value = num_bytes
123
+ matched_unit: str | None = None
124
+ for unit in units:
125
+ if value < 1024 / 10:
126
+ matched_unit = unit
127
+ break
128
+ value /= 1024.0
129
+
130
+ assert matched_unit is not None, "No matching unit found"
131
+
132
+ if value % 1 == 0 or value >= 10:
133
+ return f"{round(value)}{matched_unit}"
134
+ else:
135
+ return f"{value:.1f}{matched_unit}"
136
+
137
+
138
+ def in_debugger() -> bool:
139
+ """Returns True if running in a debugger."""
140
+ return getattr(sys, "gettrace", None) is not None and sys.gettrace() is not None
141
+
142
+
143
+ # Common tensor representation
144
+ def format_tensor_stats(tensor_stats: LovelyStats, color: bool | None = None) -> str:
145
+ """Format tensor stats into a pretty string representation."""
146
+ conf = LovelyConfig.instance()
147
+ if color is None:
148
+ color = conf.color
149
+ if in_debugger():
150
+ color = False
151
+
152
+ # Basic tensor info
153
+ shape_str = str(list(shape)) if (shape := tensor_stats.get("shape")) else None
154
+ type_str = sparse_join([tensor_stats.get("type_name"), shape_str], sep="")
155
+
156
+ # Calculate memory usage
157
+ numel = None
158
+ if (size := tensor_stats.get("size")) and (nbytes := tensor_stats.get("nbytes")):
159
+ shape = tensor_stats.get("shape", [])
160
+
161
+ if shape and max(shape) != size:
162
+ numel = f"n={size}"
163
+ if conf.show_mem_above <= nbytes:
164
+ numel = sparse_join([numel, f"({bytes_to_human(nbytes)})"])
165
+ elif conf.show_mem_above <= nbytes:
166
+ numel = bytes_to_human(nbytes)
167
+
168
+ # Handle empty tensors
169
+ if tensor_stats.get("size", 0) == 0:
170
+ common = ansi_color("empty", "grey", color)
171
+ # Handle all zeros
172
+ elif tensor_stats.get("all_zeros"):
173
+ common = ansi_color("all_zeros", "grey", color)
174
+ # Handle complex tensors
175
+ elif tensor_stats.get("is_complex"):
176
+ complex_info = []
177
+
178
+ # For magnitude stats
179
+ if (mag_min := tensor_stats.get("mag_min")) is not None and (
180
+ mag_max := tensor_stats.get("mag_max")
181
+ ) is not None:
182
+ complex_info.append(f"|z|∈[{pretty_str(mag_min)}, {pretty_str(mag_max)}]")
183
+
184
+ # For real part stats
185
+ if (real_min := tensor_stats.get("real_min")) is not None and (
186
+ real_max := tensor_stats.get("real_max")
187
+ ) is not None:
188
+ complex_info.append(f"Re∈[{pretty_str(real_min)}, {pretty_str(real_max)}]")
189
+
190
+ # For imaginary part stats
191
+ if (imag_min := tensor_stats.get("imag_min")) is not None and (
192
+ imag_max := tensor_stats.get("imag_max")
193
+ ) is not None:
194
+ complex_info.append(f"Im∈[{pretty_str(imag_min)}, {pretty_str(imag_max)}]")
195
+
196
+ common = sparse_join(complex_info)
197
+ # Handle normal tensors with stats
198
+ elif (min_val := tensor_stats.get("min")) is not None and (
199
+ max_val := tensor_stats.get("max")
200
+ ) is not None:
201
+ minmax = None
202
+ meanstd = None
203
+
204
+ if tensor_stats.get("size", 0) > 2:
205
+ minmax = f"x∈[{pretty_str(min_val)}, {pretty_str(max_val)}]"
206
+
207
+ if (
208
+ (mean := tensor_stats.get("mean")) is not None
209
+ and (std := tensor_stats.get("std")) is not None
210
+ and tensor_stats.get("size", 0) >= 2
211
+ ):
212
+ meanstd = f"μ={pretty_str(mean)} σ={pretty_str(std)}"
213
+
214
+ common = sparse_join([minmax, meanstd])
215
+ else:
216
+ common = None
217
+
218
+ # Handle warnings
219
+ warnings = []
220
+ if tensor_stats.get("has_nan"):
221
+ warnings.append(ansi_color("NaN!", "red", color))
222
+ if tensor_stats.get("has_pos_inf"):
223
+ warnings.append(ansi_color("+Inf!", "red", color))
224
+ if tensor_stats.get("has_neg_inf"):
225
+ warnings.append(ansi_color("-Inf!", "red", color))
226
+
227
+ attention = sparse_join(warnings)
228
+ common = sparse_join([common, attention])
229
+
230
+ # Other tensor attributes
231
+ dtype = tensor_stats.get("dtype_str", "")
232
+ device = tensor_stats.get("device")
233
+ grad = "grad" if tensor_stats.get("requires_grad") else None
234
+
235
+ # Format values for small tensors
236
+ vals = None
237
+ if (
238
+ 0 < tensor_stats.get("size", 0) <= 10
239
+ and tensor_stats.get("is_meta", False) is False
240
+ ):
241
+ vals = tensor_stats.get("values_str")
242
+
243
+ # Join all parts
244
+ result = sparse_join([type_str, dtype, numel, common, grad, device, vals])
245
+
246
+ return result
247
+
248
+
249
+ def real_stats(array: np.ndarray) -> LovelyStats:
250
+ stats: LovelyStats = {}
251
+
252
+ # Check for special values
253
+ stats["has_nan"] = bool(np.isnan(array).any())
254
+ stats["has_pos_inf"] = bool(np.isposinf(array).any())
255
+ stats["has_neg_inf"] = bool(np.isneginf(array).any())
256
+
257
+ # Only compute min/max/mean/std for good data
258
+ good_data = array[np.isfinite(array)]
259
+
260
+ if len(good_data) > 0:
261
+ stats["min"] = float(good_data.min())
262
+ stats["max"] = float(good_data.max())
263
+ stats["all_zeros"] = stats["min"] == 0 and stats["max"] == 0 and array.size > 1
264
+
265
+ if len(good_data) > 1:
266
+ stats["mean"] = float(good_data.mean())
267
+ stats["std"] = float(good_data.std())
268
+
269
+ # Get string representation of values for small tensors
270
+ if 0 < array.size <= 10:
271
+ stats["values_str"] = pretty_str(array)
272
+
273
+ return stats
274
+
275
+
276
+ def complex_stats(array: np.ndarray) -> LovelyStats:
277
+ stats: LovelyStats = {}
278
+
279
+ # Calculate magnitude (absolute value)
280
+ magnitude = np.abs(array)
281
+
282
+ # Check for special values in real or imaginary parts
283
+ stats["has_nan"] = bool(np.isnan(array.real).any() or np.isnan(array.imag).any())
284
+
285
+ # Get statistics for magnitude
286
+ good_mag = magnitude[np.isfinite(magnitude)]
287
+ if len(good_mag) > 0:
288
+ stats["mag_min"] = float(good_mag.min())
289
+ stats["mag_max"] = float(good_mag.max())
290
+ stats["all_zeros"] = (
291
+ stats["mag_min"] == 0 and stats["mag_max"] == 0 and array.size > 1
292
+ )
293
+
294
+ # Get statistics for real and imaginary parts
295
+ real_part = array.real
296
+ imag_part = array.imag
297
+
298
+ good_real = real_part[np.isfinite(real_part)]
299
+ good_imag = imag_part[np.isfinite(imag_part)]
300
+
301
+ if len(good_real := real_part[np.isfinite(real_part)]):
302
+ stats["real_min"] = float(good_real.min())
303
+ stats["real_max"] = float(good_real.max())
304
+
305
+ if len(good_imag := imag_part[np.isfinite(imag_part)]):
306
+ stats["imag_min"] = float(good_imag.min())
307
+ stats["imag_max"] = float(good_imag.max())
308
+
309
+ # Get string representation of values for small tensors
310
+ if 0 < array.size <= 10:
311
+ stats["values_str"] = pretty_str(array)
312
+
313
+ return stats
314
+
315
+
316
+ def array_stats(array: np.ndarray, ignore_empty: bool = True) -> LovelyStats:
317
+ """Compute all statistics for a given array.
318
+
319
+ Args:
320
+ array (np.ndarray): The input array.
321
+ ignore_empty (bool): If True, ignore empty arrays.
322
+
323
+ Returns:
324
+ LovelyStats: A dictionary containing the computed statistics.
325
+ """
326
+ if ignore_empty and array.size == 0:
327
+ return {}
328
+
329
+ if np.iscomplexobj(array):
330
+ return complex_stats(array)
331
+ else:
332
+ return real_stats(array)
333
+
334
+
335
+ def patch_to(
336
+ cls: type,
337
+ name: str,
338
+ func: Callable[[Any], Any],
339
+ as_property: bool = False,
340
+ ) -> None:
341
+ """Simple patch_to implementation to avoid fastcore dependency."""
342
+ if as_property:
343
+ setattr(cls, name, property(func))
344
+ else:
345
+ setattr(cls, name, func)
@@ -1,9 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import logging
3
4
  import os
4
5
  import sys
5
6
  from collections.abc import Sequence
6
- from logging import getLogger
7
7
  from types import FrameType as _FrameType
8
8
  from typing import Any
9
9
 
@@ -52,7 +52,8 @@ try:
52
52
  import jax # type: ignore
53
53
  except ImportError:
54
54
  jax = None
55
- log = getLogger(__name__)
55
+
56
+ log = logging.getLogger(__name__)
56
57
 
57
58
  DISABLE_ENV_KEY = "NSHUTILS_DISABLE_TYPECHECKING"
58
59
 
@@ -0,0 +1,92 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import functools
5
+ from collections.abc import Callable, Iterator
6
+ from typing import Any, Generic
7
+
8
+ from typing_extensions import ParamSpec, TypeVar, override
9
+
10
+ R = TypeVar("R")
11
+ P = ParamSpec("P")
12
+
13
+
14
class ContextResource(contextlib.AbstractContextManager[R], Generic[R]):
    """Wrap a resource for both direct access and context management.

    The wrapped object is available as ``resource``. The cleanup function runs
    the first time either a ``with`` block exits or ``close()`` is called;
    later exits or ``close()`` calls are no-ops, so combining the two styles
    (or closing twice) does not run the cleanup a second time.
    """

    def __init__(self, resource: R, cleanup_func: Callable[[R], Any]):
        """Store the resource and the function that releases it.

        Args:
            resource: The object handed out to callers.
            cleanup_func: Called with ``resource`` to release it.
        """
        self.resource = resource
        self._cleanup_func = cleanup_func
        # Tracks whether cleanup already ran so it is never repeated.
        self._closed = False

    @override
    def __enter__(self) -> R:
        """When used as a context manager, return the wrapped resource."""
        return self.resource

    @override
    def __exit__(self, *exc_info) -> None:
        """Clean up the resource when exiting the context.

        Returns None, so an exception raised in the ``with`` body propagates.
        """
        self._cleanup_once()

    def close(self) -> None:
        """Explicitly clean up the resource (no-op if already cleaned up)."""
        self._cleanup_once()

    def _cleanup_once(self) -> None:
        """Run the cleanup function unless it has already been run."""
        if self._closed:
            return
        # Mark first: a cleanup that raises is not retried on a later call.
        self._closed = True
        self._cleanup_func(self.resource)
34
+
35
+
36
def resource_factory(
    create_func: Callable[P, R], cleanup_func: Callable[[R], None]
) -> Callable[P, ContextResource[R]]:
    """
    Turn a create/cleanup function pair into a ContextResource factory.

    Args:
        create_func: Called with the factory's arguments to build the resource
        cleanup_func: Passed to the ContextResource as its cleanup function

    Returns:
        A function with ``create_func``'s signature and metadata that returns
        the created resource wrapped in a ContextResource
    """

    def build(*args: P.args, **kwargs: P.kwargs) -> ContextResource[R]:
        return ContextResource(create_func(*args, **kwargs), cleanup_func)

    # Same effect as decorating ``build`` with ``@functools.wraps(create_func)``.
    return functools.wraps(create_func)(build)
56
+
57
+
58
def resource_factory_from_context_fn(
    context_func: Callable[P, contextlib.AbstractContextManager[R]],
) -> Callable[P, ContextResource[R]]:
    """
    Turn a context-manager-returning function into a ContextResource factory.

    The returned factory enters the context manager immediately and wraps the
    entered value. Cleanup exits the underlying context manager with
    ``(None, None, None)``, so it is not told about an exception raised inside
    a ``with`` block on the returned ContextResource.

    Args:
        context_func: Function that returns a context manager for the resource

    Returns:
        A function that returns a ContextResource wrapping the entered resource
    """

    @functools.wraps(context_func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> ContextResource[R]:
        manager = context_func(*args, **kwargs)

        def leave(_resource):
            # The resource argument is ignored; exiting the manager is the cleanup.
            return manager.__exit__(None, None, None)

        entered = manager.__enter__()
        return ContextResource(entered, leave)

    return wrapper
78
+
79
+
80
def resource_factory_contextmanager(
    context_func: Callable[P, Iterator[R]],
) -> Callable[P, ContextResource[R]]:
    """
    Turn a generator function into a ContextResource factory.

    Args:
        context_func: Generator function that creates the resource, yields it, and cleans up the resource when done.

    Returns:
        A function that returns a ContextResource wrapping the yielded resource
    """
    # Step 1: give the generator function context-manager semantics.
    as_context_manager = contextlib.contextmanager(context_func)
    # Step 2: reuse the context-manager based factory.
    return resource_factory_from_context_fn(as_context_manager)
@@ -1,125 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import logging
4
- from pathlib import Path
5
-
6
-
7
- def init_python_logging(
8
- *,
9
- lovely_tensors: bool = False,
10
- lovely_numpy: bool = False,
11
- treescope: bool = False,
12
- treescope_autovisualize_arrays: bool = False,
13
- rich: bool = False,
14
- rich_tracebacks: bool = False,
15
- log_level: int | str | None = logging.INFO,
16
- log_save_dir: Path | None = None,
17
- ):
18
- if lovely_tensors:
19
- try:
20
- import lovely_tensors as _lovely_tensors # type: ignore
21
-
22
- _lovely_tensors.monkey_patch()
23
- except ImportError:
24
- logging.info(
25
- "Failed to import `lovely_tensors`. Ignoring pretty PyTorch tensor formatting"
26
- )
27
-
28
- if lovely_numpy:
29
- try:
30
- import lovely_numpy as _lovely_numpy # type: ignore
31
-
32
- _lovely_numpy.set_config(repr=_lovely_numpy.lovely)
33
- except ImportError:
34
- logging.info(
35
- "Failed to import `lovely_numpy`. Ignoring pretty numpy array formatting"
36
- )
37
-
38
- if treescope:
39
- try:
40
- # Check if we're in a Jupyter environment
41
- from IPython import get_ipython
42
-
43
- if get_ipython() is not None:
44
- import treescope as _treescope # type: ignore
45
-
46
- _treescope.basic_interactive_setup(
47
- autovisualize_arrays=treescope_autovisualize_arrays
48
- )
49
- else:
50
- logging.info(
51
- "Treescope setup is only supported in Jupyter notebooks. Skipping."
52
- )
53
- except ImportError:
54
- logging.info(
55
- "Failed to import `treescope` or `IPython`. Ignoring `treescope` registration"
56
- )
57
-
58
- log_handlers: list[logging.Handler] = []
59
- if log_save_dir:
60
- log_file = log_save_dir / "logging.log"
61
- log_file.touch(exist_ok=True)
62
- log_handlers.append(logging.FileHandler(log_file))
63
-
64
- if rich:
65
- try:
66
- from rich.logging import RichHandler # type: ignore
67
-
68
- log_handlers.append(RichHandler(rich_tracebacks=rich_tracebacks))
69
- except ImportError:
70
- logging.info(
71
- "Failed to import rich. Falling back to default Python logging."
72
- )
73
-
74
- logging.basicConfig(
75
- level=log_level,
76
- format="%(message)s",
77
- datefmt="[%X]",
78
- handlers=log_handlers,
79
- )
80
-
81
-
82
def pretty(
    *,
    lovely_tensors: bool = True,
    lovely_numpy: bool = True,
    treescope: bool = False,
    treescope_autovisualize_arrays: bool = False,
    log_level: int | str | None = logging.INFO,
    log_save_dir: Path | None = None,
    rich_log_handler: bool = False,
    rich_tracebacks: bool = False,
):
    """Run ``init_python_logging`` with tensor and array formatting on by default.

    Every argument is forwarded unchanged, except that ``rich_log_handler`` is
    passed as ``init_python_logging``'s ``rich`` argument.
    """
    forwarded = {
        "lovely_tensors": lovely_tensors,
        "lovely_numpy": lovely_numpy,
        "treescope": treescope,
        "treescope_autovisualize_arrays": treescope_autovisualize_arrays,
        "rich": rich_log_handler,
        "log_level": log_level,
        "log_save_dir": log_save_dir,
        "rich_tracebacks": rich_tracebacks,
    }
    init_python_logging(**forwarded)
103
-
104
-
105
def lovely(
    *,
    lovely_tensors: bool = True,
    lovely_numpy: bool = True,
    treescope: bool = False,
    treescope_autovisualize_arrays: bool = False,
    log_level: int | str | None = logging.INFO,
    log_save_dir: Path | None = None,
    rich_log_handler: bool = False,
    rich_tracebacks: bool = False,
):
    """Alias for ``pretty``: same parameters and defaults, all forwarded by name."""
    forwarded = {
        "lovely_tensors": lovely_tensors,
        "lovely_numpy": lovely_numpy,
        "treescope": treescope,
        "treescope_autovisualize_arrays": treescope_autovisualize_arrays,
        "log_level": log_level,
        "log_save_dir": log_save_dir,
        "rich_log_handler": rich_log_handler,
        "rich_tracebacks": rich_tracebacks,
    }
    pretty(**forwarded)
File without changes