nshutils 0.22.4.tar.gz → 0.22.6.tar.gz

This diff shows the contents of two publicly available versions of a package released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (22):
  1. {nshutils-0.22.4 → nshutils-0.22.6}/PKG-INFO +1 -1
  2. {nshutils-0.22.4 → nshutils-0.22.6}/pyproject.toml +1 -1
  3. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/actsave/_saver.py +5 -5
  4. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/lovely/_base.py +9 -6
  5. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/lovely/jax_.py +6 -2
  6. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/lovely/numpy_.py +6 -2
  7. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/lovely/torch_.py +2 -2
  8. {nshutils-0.22.4 → nshutils-0.22.6}/README.md +0 -0
  9. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/__init__.py +0 -0
  10. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/__init__.pyi +0 -0
  11. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/actsave/__init__.py +0 -0
  12. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/actsave/_loader.py +0 -0
  13. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/collections.py +0 -0
  14. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/display.py +0 -0
  15. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/logging.py +0 -0
  16. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/lovely/__init__.py +0 -0
  17. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/lovely/_monkey_patch_all.py +0 -0
  18. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/lovely/config.py +0 -0
  19. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/lovely/utils.py +0 -0
  20. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/snoop.py +0 -0
  21. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/typecheck.py +0 -0
  22. {nshutils-0.22.4 → nshutils-0.22.6}/src/nshutils/util.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshutils
3
- Version: 0.22.4
3
+ Version: 0.22.6
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "nshutils"
3
- version = "0.22.4"
3
+ version = "0.22.6"
4
4
  description = ""
5
5
  authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
6
6
  requires-python = ">=3.9,<4.0"
@@ -9,7 +9,7 @@ from dataclasses import dataclass
9
9
  from functools import wraps
10
10
  from logging import getLogger
11
11
  from pathlib import Path
12
- from typing import TYPE_CHECKING, Generic, Literal, Union, cast, overload
12
+ from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast, overload
13
13
 
14
14
  import numpy as np
15
15
  from typing_extensions import Never, ParamSpec, TypeAliasType, TypeVar, override
@@ -36,8 +36,9 @@ else:
36
36
 
37
37
  log = getLogger(__name__)
38
38
 
39
+ # Updated to include Any for arbitrary types
39
40
  Value = TypeAliasType(
40
- "Value", Union[int, float, complex, bool, str, np.ndarray, Tensor, None]
41
+ "Value", Union[int, float, complex, bool, str, np.ndarray, Tensor, Any, None]
41
42
  )
42
43
  ValueOrLambda = TypeAliasType("ValueOrLambda", Union[Value, Callable[..., Value]])
43
44
 
@@ -65,9 +66,8 @@ def _to_numpy(activation: Value) -> np.ndarray:
65
66
  activation_ = activation_.float()
66
67
  return activation_.cpu().numpy()
67
68
  else:
68
- log.warning(f"Unrecognized activation type {type(activation)}")
69
-
70
- return activation
69
+ # Handle arbitrary objects using numpy object dtype
70
+ return np.array(activation, dtype=object)
71
71
 
72
72
 
73
73
  T = TypeVar("T", infer_variance=True)
@@ -17,7 +17,7 @@ P = ParamSpec("P")
17
17
 
18
18
  LovelyStatsFn = TypeAliasType(
19
19
  "LovelyStatsFn",
20
- Callable[[TArray], LovelyStats],
20
+ Callable[[TArray], LovelyStats | None],
21
21
  type_params=(TArray,),
22
22
  )
23
23
  LovelyReprFn = TypeAliasType(
@@ -39,14 +39,14 @@ def _find_missing_deps(dependencies: list[str]):
39
39
  return missing_deps
40
40
 
41
41
 
42
- def lovely_repr(dependencies: list[str]):
42
+ def lovely_repr(dependencies: list[str], fallback_repr: Callable[[TArray], str]):
43
43
  """
44
44
  Decorator to create a lovely representation function for an array.
45
45
 
46
46
  Args:
47
47
  dependencies: List of dependencies to check before running the function.
48
48
  If any dependency is not available, the function will not run.
49
-
49
+ fallback_repr: A function that takes an array and returns its fallback representation.
50
50
  Returns:
51
51
  A decorator function that takes a function and returns a lovely representation function.
52
52
 
@@ -61,7 +61,8 @@ def lovely_repr(dependencies: list[str]):
61
61
  Decorator to create a lovely representation function for an array.
62
62
 
63
63
  Args:
64
- array_stats_fn: A function that takes an array and returns its stats.
64
+ array_stats_fn: A function that takes an array and returns its stats,
65
+ or `None` if the array is not supported.
65
66
 
66
67
  Returns:
67
68
  A function that takes an array and returns its lovely representation.
@@ -74,9 +75,11 @@ def lovely_repr(dependencies: list[str]):
74
75
  f"Missing dependencies: {', '.join(missing_deps)}. "
75
76
  "Skipping lovely representation."
76
77
  )
77
- return repr(array)
78
+ return fallback_repr(array)
79
+
80
+ if (stats := array_stats_fn(array)) is None:
81
+ return fallback_repr(array)
78
82
 
79
- stats = array_stats_fn(array)
80
83
  return format_tensor_stats(stats)
81
84
 
82
85
  return wrapper
@@ -53,10 +53,14 @@ def _device(array: jax.Array) -> str:
53
53
  return f"{device.platform}:{device.id}"
54
54
 
55
55
 
56
- @lovely_repr(dependencies=["jax"])
57
- def jax_repr(array: jax.Array) -> LovelyStats:
56
+ @lovely_repr(dependencies=["jax"], fallback_repr=jax.Array.__repr__)
57
+ def jax_repr(array: jax.Array) -> LovelyStats | None:
58
58
  import jax.numpy as jnp
59
59
 
60
+ # For dtypes like `object` or `str`, we let the fallback repr handle it
61
+ if not jnp.issubdtype(array.dtype, jnp.number):
62
+ return None
63
+
60
64
  return {
61
65
  # Basic attributes
62
66
  "shape": array.shape,
@@ -51,8 +51,12 @@ def _dtype_str(array: np.ndarray) -> str:
51
51
  return dtype_base
52
52
 
53
53
 
54
- @lovely_repr(dependencies=["numpy"])
55
- def numpy_repr(array: np.ndarray) -> LovelyStats:
54
+ @lovely_repr(dependencies=["numpy"], fallback_repr=np.array_repr)
55
+ def numpy_repr(array: np.ndarray) -> LovelyStats | None:
56
+ # For dtypes like `object` or `str`, we let the fallback repr handle it
57
+ if not np.issubdtype(array.dtype, np.number):
58
+ return None
59
+
56
60
  return {
57
61
  # Basic attributes
58
62
  "shape": array.shape,
@@ -59,8 +59,8 @@ def _to_np(tensor: torch.Tensor) -> np.ndarray:
59
59
  return t_np
60
60
 
61
61
 
62
- @lovely_repr(dependencies=["torch"])
63
- def torch_repr(tensor: torch.Tensor) -> LovelyStats:
62
+ @lovely_repr(dependencies=["torch"], fallback_repr=torch.Tensor.__repr__)
63
+ def torch_repr(tensor: torch.Tensor) -> LovelyStats | None:
64
64
  return {
65
65
  # Basic attributes
66
66
  "shape": tensor.shape,
Files 8–22 in the list above are unchanged (+0 -0).