nshutils 0.22.4__py3-none-any.whl → 0.22.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshutils/actsave/_saver.py +5 -5
- nshutils/lovely/_base.py +9 -6
- nshutils/lovely/jax_.py +6 -2
- nshutils/lovely/numpy_.py +6 -2
- nshutils/lovely/torch_.py +2 -2
- {nshutils-0.22.4.dist-info → nshutils-0.22.6.dist-info}/METADATA +1 -1
- {nshutils-0.22.4.dist-info → nshutils-0.22.6.dist-info}/RECORD +8 -8
- {nshutils-0.22.4.dist-info → nshutils-0.22.6.dist-info}/WHEEL +0 -0
nshutils/actsave/_saver.py
CHANGED
@@ -9,7 +9,7 @@ from dataclasses import dataclass
|
|
9
9
|
from functools import wraps
|
10
10
|
from logging import getLogger
|
11
11
|
from pathlib import Path
|
12
|
-
from typing import TYPE_CHECKING, Generic, Literal, Union, cast, overload
|
12
|
+
from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast, overload
|
13
13
|
|
14
14
|
import numpy as np
|
15
15
|
from typing_extensions import Never, ParamSpec, TypeAliasType, TypeVar, override
|
@@ -36,8 +36,9 @@ else:
|
|
36
36
|
|
37
37
|
log = getLogger(__name__)
|
38
38
|
|
39
|
+
# Updated to include Any for arbitrary types
|
39
40
|
Value = TypeAliasType(
|
40
|
-
"Value", Union[int, float, complex, bool, str, np.ndarray, Tensor, None]
|
41
|
+
"Value", Union[int, float, complex, bool, str, np.ndarray, Tensor, Any, None]
|
41
42
|
)
|
42
43
|
ValueOrLambda = TypeAliasType("ValueOrLambda", Union[Value, Callable[..., Value]])
|
43
44
|
|
@@ -65,9 +66,8 @@ def _to_numpy(activation: Value) -> np.ndarray:
|
|
65
66
|
activation_ = activation_.float()
|
66
67
|
return activation_.cpu().numpy()
|
67
68
|
else:
|
68
|
-
|
69
|
-
|
70
|
-
return activation
|
69
|
+
# Handle arbitrary objects using numpy object dtype
|
70
|
+
return np.array(activation, dtype=object)
|
71
71
|
|
72
72
|
|
73
73
|
T = TypeVar("T", infer_variance=True)
|
nshutils/lovely/_base.py
CHANGED
@@ -17,7 +17,7 @@ P = ParamSpec("P")
|
|
17
17
|
|
18
18
|
LovelyStatsFn = TypeAliasType(
|
19
19
|
"LovelyStatsFn",
|
20
|
-
Callable[[TArray], LovelyStats],
|
20
|
+
Callable[[TArray], LovelyStats | None],
|
21
21
|
type_params=(TArray,),
|
22
22
|
)
|
23
23
|
LovelyReprFn = TypeAliasType(
|
@@ -39,14 +39,14 @@ def _find_missing_deps(dependencies: list[str]):
|
|
39
39
|
return missing_deps
|
40
40
|
|
41
41
|
|
42
|
-
def lovely_repr(dependencies: list[str]):
|
42
|
+
def lovely_repr(dependencies: list[str], fallback_repr: Callable[[TArray], str]):
|
43
43
|
"""
|
44
44
|
Decorator to create a lovely representation function for an array.
|
45
45
|
|
46
46
|
Args:
|
47
47
|
dependencies: List of dependencies to check before running the function.
|
48
48
|
If any dependency is not available, the function will not run.
|
49
|
-
|
49
|
+
fallback_repr: A function that takes an array and returns its fallback representation.
|
50
50
|
Returns:
|
51
51
|
A decorator function that takes a function and returns a lovely representation function.
|
52
52
|
|
@@ -61,7 +61,8 @@ def lovely_repr(dependencies: list[str]):
|
|
61
61
|
Decorator to create a lovely representation function for an array.
|
62
62
|
|
63
63
|
Args:
|
64
|
-
array_stats_fn: A function that takes an array and returns its stats
|
64
|
+
array_stats_fn: A function that takes an array and returns its stats,
|
65
|
+
or `None` if the array is not supported.
|
65
66
|
|
66
67
|
Returns:
|
67
68
|
A function that takes an array and returns its lovely representation.
|
@@ -74,9 +75,11 @@ def lovely_repr(dependencies: list[str]):
|
|
74
75
|
f"Missing dependencies: {', '.join(missing_deps)}. "
|
75
76
|
"Skipping lovely representation."
|
76
77
|
)
|
77
|
-
return
|
78
|
+
return fallback_repr(array)
|
79
|
+
|
80
|
+
if (stats := array_stats_fn(array)) is None:
|
81
|
+
return fallback_repr(array)
|
78
82
|
|
79
|
-
stats = array_stats_fn(array)
|
80
83
|
return format_tensor_stats(stats)
|
81
84
|
|
82
85
|
return wrapper
|
nshutils/lovely/jax_.py
CHANGED
@@ -53,10 +53,14 @@ def _device(array: jax.Array) -> str:
|
|
53
53
|
return f"{device.platform}:{device.id}"
|
54
54
|
|
55
55
|
|
56
|
-
@lovely_repr(dependencies=["jax"])
|
57
|
-
def jax_repr(array: jax.Array) -> LovelyStats:
|
56
|
+
@lovely_repr(dependencies=["jax"], fallback_repr=jax.Array.__repr__)
|
57
|
+
def jax_repr(array: jax.Array) -> LovelyStats | None:
|
58
58
|
import jax.numpy as jnp
|
59
59
|
|
60
|
+
# For dtypes like `object` or `str`, we let the fallback repr handle it
|
61
|
+
if not jnp.issubdtype(array.dtype, jnp.number):
|
62
|
+
return None
|
63
|
+
|
60
64
|
return {
|
61
65
|
# Basic attributes
|
62
66
|
"shape": array.shape,
|
nshutils/lovely/numpy_.py
CHANGED
@@ -51,8 +51,12 @@ def _dtype_str(array: np.ndarray) -> str:
|
|
51
51
|
return dtype_base
|
52
52
|
|
53
53
|
|
54
|
-
@lovely_repr(dependencies=["numpy"])
|
55
|
-
def numpy_repr(array: np.ndarray) -> LovelyStats:
|
54
|
+
@lovely_repr(dependencies=["numpy"], fallback_repr=np.array_repr)
|
55
|
+
def numpy_repr(array: np.ndarray) -> LovelyStats | None:
|
56
|
+
# For dtypes like `object` or `str`, we let the fallback repr handle it
|
57
|
+
if not np.issubdtype(array.dtype, np.number):
|
58
|
+
return None
|
59
|
+
|
56
60
|
return {
|
57
61
|
# Basic attributes
|
58
62
|
"shape": array.shape,
|
nshutils/lovely/torch_.py
CHANGED
@@ -59,8 +59,8 @@ def _to_np(tensor: torch.Tensor) -> np.ndarray:
|
|
59
59
|
return t_np
|
60
60
|
|
61
61
|
|
62
|
-
@lovely_repr(dependencies=["torch"])
|
63
|
-
def torch_repr(tensor: torch.Tensor) -> LovelyStats:
|
62
|
+
@lovely_repr(dependencies=["torch"], fallback_repr=torch.Tensor.__repr__)
|
63
|
+
def torch_repr(tensor: torch.Tensor) -> LovelyStats | None:
|
64
64
|
return {
|
65
65
|
# Basic attributes
|
66
66
|
"shape": tensor.shape,
|
@@ -2,21 +2,21 @@ nshutils/__init__.py,sha256=AFx1d5k34MyJ2kCHQL5vrZB8GDp2nYUaIUEjszSa25I,477
|
|
2
2
|
nshutils/__init__.pyi,sha256=ICbY2_XBAlXIVOGyK4PQpatmlUFHHc5-bqM4sfFZoAY,613
|
3
3
|
nshutils/actsave/__init__.py,sha256=hAVsog9d1g3_rQN1TRslrl6sK1PhCGbjy8PPUAmJI58,203
|
4
4
|
nshutils/actsave/_loader.py,sha256=mof3HezUNvLliz7macstX6ewXW05L0Mtv3zJyrbmImg,4640
|
5
|
-
nshutils/actsave/_saver.py,sha256=
|
5
|
+
nshutils/actsave/_saver.py,sha256=GyoJpeJIIG_Y2Hr3MTXylTM6YNyoInf2mvKs0foU7no,10528
|
6
6
|
nshutils/collections.py,sha256=QWGyANmo4Efq4XRNHDSTE9tRLStwEZHGwE0ATHR-Vqo,5233
|
7
7
|
nshutils/display.py,sha256=Ge63yllx7gi-MKL3mKQeQ5doql_nj56-o5aoTVmusDg,1473
|
8
8
|
nshutils/logging.py,sha256=78pv3-I_gmbKSf5_mYYBr6_H4GNBGErghAdhH9wfYIc,2205
|
9
9
|
nshutils/lovely/__init__.py,sha256=gbWMNs7xfK1CiNdkHvfH0KcyaGjdZ8_WUBGfaEUDN4I,451
|
10
|
-
nshutils/lovely/_base.py,sha256=
|
10
|
+
nshutils/lovely/_base.py,sha256=SH8MY6wbucB7us4pU1cekoU2glanutEB8vDrLDkLEL0,4544
|
11
11
|
nshutils/lovely/_monkey_patch_all.py,sha256=zgMupp2Wc_O9R3arl-BAIePpvQSi6TCeshGMaui-Cc8,1986
|
12
12
|
nshutils/lovely/config.py,sha256=lVNMuU1oUvsYlGN0Sn-m6iOLbJIchVnWDpyHm09nWo8,1224
|
13
|
-
nshutils/lovely/jax_.py,sha256=
|
14
|
-
nshutils/lovely/numpy_.py,sha256=
|
15
|
-
nshutils/lovely/torch_.py,sha256=
|
13
|
+
nshutils/lovely/jax_.py,sha256=J64CtP2mx_yPQ7ZiLKSrqEJfPDBBbx3j6o4zKc1f1eg,2404
|
14
|
+
nshutils/lovely/numpy_.py,sha256=iaTUQpamudyWd7pTBJyj_hcS9X8H53muCw2IYY9ZcUM,3345
|
15
|
+
nshutils/lovely/torch_.py,sha256=Jt0c7ysfrbgxORhL6fpbxCnfBDoYaTtTyc935yCJ7H0,2682
|
16
16
|
nshutils/lovely/utils.py,sha256=2ksT5YGVViFuWc8jSkwVCsABripJmyVJdEDDH7aab70,10459
|
17
17
|
nshutils/snoop.py,sha256=7d7_Q5sJmINL1J29wcnxEvpV95zvZYNoVn5frCq-rww,7393
|
18
18
|
nshutils/typecheck.py,sha256=Gi7xtfilN_UwZ1FTFqBVKDhcQzBEDonVxIv3bUj-uXY,5582
|
19
19
|
nshutils/util.py,sha256=tx-XiRbOrpafV3OkJDE5IVFtzn3kN7uSZ8FkMor0H5c,2845
|
20
|
-
nshutils-0.22.
|
21
|
-
nshutils-0.22.
|
22
|
-
nshutils-0.22.
|
20
|
+
nshutils-0.22.6.dist-info/METADATA,sha256=cclgVNUapsycUT65AJHSQXQpej_WEwwhRIl9vwrLMbY,4406
|
21
|
+
nshutils-0.22.6.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
22
|
+
nshutils-0.22.6.dist-info/RECORD,,
|
File without changes
|