brainstate-0.1.9-py2.py3-none-any.whl → brainstate-0.2.0-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -17,42 +17,153 @@
17
17
  A ``State``-based Transformation System for Program Compilation and Augmentation
18
18
  """
19
19
 
20
- __version__ = "0.1.9"
20
+ __version__ = "0.2.0"
21
+ __versio_info__ = (0, 2, 0)
21
22
 
22
- from . import augment
23
- from . import compile
24
23
  from . import environ
25
- from . import functional
26
24
  from . import graph
27
- from . import init
28
25
  from . import mixin
29
26
  from . import nn
30
- from . import optim
31
27
  from . import random
32
- from . import surrogate
33
28
  from . import transform
34
29
  from . import typing
35
30
  from . import util
31
+ from ._error import *
32
+ from ._error import __all__ as _error_all
36
33
  from ._state import *
37
34
  from ._state import __all__ as _state_all
38
35
 
36
+ # Create deprecated module proxies with scoped APIs
37
+ from ._deprecation import create_deprecated_module_proxy
38
+
39
+ # Augment module scope
40
+ _augment_apis = {
41
+ 'GradientTransform': 'brainstate.transform._autograd',
42
+ 'grad': 'brainstate.transform._autograd',
43
+ 'vector_grad': 'brainstate.transform._autograd',
44
+ 'hessian': 'brainstate.transform._autograd',
45
+ 'jacobian': 'brainstate.transform._autograd',
46
+ 'jacrev': 'brainstate.transform._autograd',
47
+ 'jacfwd': 'brainstate.transform._autograd',
48
+ 'abstract_init': 'brainstate.transform._eval_shape',
49
+ 'vmap': 'brainstate.transform._mapping',
50
+ 'pmap': 'brainstate.transform._mapping',
51
+ 'map': 'brainstate.transform._mapping',
52
+ 'vmap_new_states': 'brainstate.transform._mapping',
53
+ 'restore_rngs': 'brainstate.transform._random',
54
+ }
55
+
56
+ augment = create_deprecated_module_proxy(
57
+ deprecated_name='brainstate.augment',
58
+ replacement_module=transform,
59
+ replacement_name='brainstate.transform',
60
+ scoped_apis=_augment_apis
61
+ )
62
+
63
+ # Compile module scope
64
+ _compile_apis = {
65
+ 'checkpoint': 'brainstate.transform._ad_checkpoint',
66
+ 'remat': 'brainstate.transform._ad_checkpoint',
67
+ 'cond': 'brainstate.transform._conditions',
68
+ 'switch': 'brainstate.transform._conditions',
69
+ 'ifelse': 'brainstate.transform._conditions',
70
+ 'jit_error_if': 'brainstate.transform._error_if',
71
+ 'jit': 'brainstate.transform._jit',
72
+ 'scan': 'brainstate.transform._loop_collect_return',
73
+ 'checkpointed_scan': 'brainstate.transform._loop_collect_return',
74
+ 'for_loop': 'brainstate.transform._loop_collect_return',
75
+ 'checkpointed_for_loop': 'brainstate.transform._loop_collect_return',
76
+ 'while_loop': 'brainstate.transform._loop_no_collection',
77
+ 'bounded_while_loop': 'brainstate.transform._loop_no_collection',
78
+ 'StatefulFunction': 'brainstate.transform._make_jaxpr',
79
+ 'make_jaxpr': 'brainstate.transform._make_jaxpr',
80
+ 'ProgressBar': 'brainstate.transform._progress_bar',
81
+ }
82
+
83
+ compile = create_deprecated_module_proxy(
84
+ deprecated_name='brainstate.compile',
85
+ replacement_module=transform,
86
+ replacement_name='brainstate.transform',
87
+ scoped_apis=_compile_apis
88
+ )
89
+
90
+ # Functional module scope - use direct attribute access from nn module
91
+ _functional_apis = {
92
+ 'weight_standardization': 'brainstate.nn._normalizations',
93
+ 'clip_grad_norm': 'brainstate.nn._others',
94
+ 'tanh': 'brainstate.nn._activations',
95
+ 'relu': 'brainstate.nn._activations',
96
+ 'squareplus': 'brainstate.nn._activations',
97
+ 'softplus': 'brainstate.nn._activations',
98
+ 'soft_sign': 'brainstate.nn._activations',
99
+ 'sigmoid': 'brainstate.nn._activations',
100
+ 'silu': 'brainstate.nn._activations',
101
+ 'swish': 'brainstate.nn._activations',
102
+ 'log_sigmoid': 'brainstate.nn._activations',
103
+ 'elu': 'brainstate.nn._activations',
104
+ 'leaky_relu': 'brainstate.nn._activations',
105
+ 'hard_tanh': 'brainstate.nn._activations',
106
+ 'celu': 'brainstate.nn._activations',
107
+ 'selu': 'brainstate.nn._activations',
108
+ 'gelu': 'brainstate.nn._activations',
109
+ 'glu': 'brainstate.nn._activations',
110
+ 'logsumexp': 'brainstate.nn._activations',
111
+ 'log_softmax': 'brainstate.nn._activations',
112
+ 'softmax': 'brainstate.nn._activations',
113
+ 'standardize': 'brainstate.nn._activations',
114
+ 'relu6': 'brainstate.nn._activations',
115
+ 'hard_sigmoid': 'brainstate.nn._activations',
116
+ 'sparse_plus': 'brainstate.nn._activations',
117
+ 'hard_silu': 'brainstate.nn._activations',
118
+ 'hard_swish': 'brainstate.nn._activations',
119
+ 'hard_shrink': 'brainstate.nn._activations',
120
+ 'rrelu': 'brainstate.nn._activations',
121
+ 'mish': 'brainstate.nn._activations',
122
+ 'soft_shrink': 'brainstate.nn._activations',
123
+ 'prelu': 'brainstate.nn._activations',
124
+ 'softmin': 'brainstate.nn._activations',
125
+ 'one_hot': 'brainstate.nn._activations',
126
+ 'sparse_sigmoid': 'brainstate.nn._activations',
127
+ }
128
+
129
+ functional = create_deprecated_module_proxy(
130
+ deprecated_name='brainstate.functional',
131
+ replacement_module=nn,
132
+ replacement_name='brainstate.nn',
133
+ scoped_apis=_functional_apis
134
+ )
135
+
136
+
137
+ def __getattr__(name):
138
+ if name in ['surrogate', 'init', 'optim']:
139
+ import warnings
140
+ warnings.warn(
141
+ f"brainstate.{name} module is deprecated and will be removed in a future version. "
142
+ f"Please use braintools.{name} instead.",
143
+ DeprecationWarning,
144
+ stacklevel=2
145
+ )
146
+ import braintools
147
+ return getattr(braintools, name)
148
+ raise AttributeError(
149
+ f'module {__name__!r} has no attribute {name!r}'
150
+ )
151
+
152
+
39
153
  __all__ = [
40
- 'augment',
41
- 'compile',
42
154
  'environ',
43
- 'functional',
44
155
  'graph',
45
- 'init',
46
156
  'mixin',
47
157
  'nn',
48
- 'optim',
49
158
  'random',
50
- 'surrogate',
159
+ 'transform',
51
160
  'typing',
52
161
  'util',
53
- 'transform',
162
+ # Deprecated modules
163
+ 'augment',
164
+ 'compile',
165
+ 'functional',
54
166
  ]
55
- __all__ = __all__ + _state_all
56
-
57
- # ----------------------- #
58
- del _state_all
167
+ __all__ = __all__ + _state_all + _error_all
168
+ del _state_all, create_deprecated_module_proxy, _augment_apis, _compile_apis, _functional_apis
169
+ del _error_all
brainstate/_compatible_import.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,12 +15,40 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
+ """
19
+ Compatibility layer for JAX version differences.
20
+
21
+ This module provides a compatibility layer to handle differences between various
22
+ versions of JAX, ensuring that BrainState works correctly across different JAX
23
+ versions. It imports the appropriate modules and functions based on the detected
24
+ JAX version and provides fallback implementations when necessary.
25
+
26
+ Key Features:
27
+ - Version-aware imports for JAX core functionality
28
+ - Compatibility wrappers for changed APIs
29
+ - Fallback implementations for deprecated functions
30
+ - Type-safe utility functions
31
+
32
+ Examples:
33
+ Basic usage:
34
+
35
+ >>> from brainstate._compatible_import import safe_map, safe_zip
36
+ >>> result = safe_map(lambda x: x * 2, [1, 2, 3])
37
+ >>> pairs = safe_zip([1, 2, 3], ['a', 'b', 'c'])
38
+
39
+ Using JAX core types:
40
+
41
+ >>> from brainstate._compatible_import import Primitive, ClosedJaxpr
42
+ >>> # These imports work across different JAX versions
43
+ """
18
44
 
19
45
  from contextlib import contextmanager
20
46
  from functools import partial
21
47
  from typing import Iterable, Hashable, TypeVar, Callable
22
48
 
23
49
  import jax
50
+ from jax.core import get_aval, Tracer
51
+ from saiunit._compatible_import import wrap_init
24
52
 
25
53
  __all__ = [
26
54
  'ClosedJaxpr',
@@ -36,6 +64,12 @@ __all__ = [
36
64
  'wraps',
37
65
  'Device',
38
66
  'wrap_init',
67
+ 'Var',
68
+ 'JaxprEqn',
69
+ 'Jaxpr',
70
+ 'Literal',
71
+
72
+ 'make_iota', 'to_elt', 'BatchTracer', 'BatchTrace',
39
73
  ]
40
74
 
41
75
  T = TypeVar("T")
@@ -43,24 +77,45 @@ T1 = TypeVar("T1")
43
77
  T2 = TypeVar("T2")
44
78
  T3 = TypeVar("T3")
45
79
 
46
- from saiunit._compatible_import import wrap_init
47
-
48
- from jax.core import get_aval, Tracer
49
-
50
80
  if jax.__version_info__ < (0, 5, 0):
51
81
  from jax.lib.xla_client import Device
52
82
  else:
53
83
  from jax import Device
54
84
 
85
+ if jax.__version_info__ < (0, 7, 1):
86
+ from jax.interpreters.batching import make_iota, to_elt, BatchTracer, BatchTrace
87
+ else:
88
+ from jax._src.interpreters.batching import make_iota, to_elt, BatchTracer, BatchTrace
89
+
55
90
  if jax.__version_info__ < (0, 4, 38):
56
91
  from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
92
+ from jax.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
57
93
  else:
58
94
  from jax.extend.core import ClosedJaxpr, Primitive, jaxpr_as_fun
95
+ from jax.extend.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
59
96
  from jax.core import trace_ctx
60
97
 
61
98
 
62
99
  @contextmanager
63
100
  def extend_axis_env_nd(name_size_pairs: Iterable[tuple[Hashable, int]]):
101
+ """
102
+ Context manager to temporarily extend the JAX axis environment.
103
+
104
+ Extends the current JAX axis environment with new named axes for
105
+ vectorized computations, then restores the previous environment.
106
+
107
+ Args:
108
+ name_size_pairs: Iterable of (name, size) tuples specifying
109
+ the named axes to add to the environment.
110
+
111
+ Yields:
112
+ None: Context with extended axis environment.
113
+
114
+ Examples:
115
+ >>> with extend_axis_env_nd([('batch', 32), ('seq', 128)]):
116
+ ... # Code using vectorized operations with named axes
117
+ ... pass
118
+ """
64
119
  prev = trace_ctx.axis_env
65
120
  try:
66
121
  trace_ctx.set_axis_env(prev.extend_pure(name_size_pairs))
@@ -73,6 +128,29 @@ if jax.__version_info__ < (0, 6, 0):
73
128
 
74
129
  else:
75
130
  def safe_map(f, *args):
131
+ """
132
+ Map a function over multiple sequences with length checking.
133
+
134
+ Applies a function to corresponding elements from multiple sequences,
135
+ ensuring all sequences have the same length.
136
+
137
+ Args:
138
+ f: Function to apply to elements from each sequence.
139
+ *args: Variable number of sequences to map over.
140
+
141
+ Returns:
142
+ list: Results of applying f to corresponding elements.
143
+
144
+ Raises:
145
+ AssertionError: If input sequences have different lengths.
146
+
147
+ Examples:
148
+ >>> safe_map(lambda x, y: x + y, [1, 2, 3], [4, 5, 6])
149
+ [5, 7, 9]
150
+
151
+ >>> safe_map(str.upper, ['a', 'b', 'c'])
152
+ ['A', 'B', 'C']
153
+ """
76
154
  args = list(map(list, args))
77
155
  n = len(args[0])
78
156
  for arg in args[1:]:
@@ -81,6 +159,28 @@ else:
81
159
 
82
160
 
83
161
  def safe_zip(*args):
162
+ """
163
+ Zip multiple sequences with length checking.
164
+
165
+ Combines corresponding elements from multiple sequences into tuples,
166
+ ensuring all sequences have the same length.
167
+
168
+ Args:
169
+ *args: Variable number of sequences to zip together.
170
+
171
+ Returns:
172
+ list: List of tuples containing corresponding elements.
173
+
174
+ Raises:
175
+ AssertionError: If input sequences have different lengths.
176
+
177
+ Examples:
178
+ >>> safe_zip([1, 2, 3], ['a', 'b', 'c'])
179
+ [(1, 'a'), (2, 'b'), (3, 'c')]
180
+
181
+ >>> safe_zip([1, 2], [3, 4], [5, 6])
182
+ [(1, 3, 5), (2, 4, 6)]
183
+ """
84
184
  args = list(map(list, args))
85
185
  n = len(args[0])
86
186
  for arg in args[1:]:
@@ -89,7 +189,32 @@ else:
89
189
 
90
190
 
91
191
  def unzip2(xys: Iterable[tuple[T1, T2]]) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
92
- """Unzip sequence of length-2 tuples into two tuples."""
192
+ """
193
+ Unzip sequence of length-2 tuples into two tuples.
194
+
195
+ Takes an iterable of 2-tuples and separates them into two tuples
196
+ containing the first and second elements respectively.
197
+
198
+ Args:
199
+ xys: Iterable of 2-tuples to unzip.
200
+
201
+ Returns:
202
+ tuple: A 2-tuple containing:
203
+ - Tuple of all first elements
204
+ - Tuple of all second elements
205
+
206
+ Examples:
207
+ >>> pairs = [(1, 'a'), (2, 'b'), (3, 'c')]
208
+ >>> nums, letters = unzip2(pairs)
209
+ >>> nums
210
+ (1, 2, 3)
211
+ >>> letters
212
+ ('a', 'b', 'c')
213
+
214
+ Notes:
215
+ We deliberately don't use zip(*xys) because it is lazily evaluated,
216
+ is too permissive about inputs, and does not guarantee a length-2 output.
217
+ """
93
218
  # Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
94
219
  # is too permissive about inputs, and does not guarantee a length-2 output.
95
220
  xs: list[T1] = []
@@ -101,6 +226,30 @@ else:
101
226
 
102
227
 
103
228
  def fun_name(fun: Callable):
229
+ """
230
+ Extract the name of a function, handling special cases.
231
+
232
+ Attempts to get the name of a function, with special handling for
233
+ partial functions and fallback for unnamed functions.
234
+
235
+ Args:
236
+ fun: The function to get the name from.
237
+
238
+ Returns:
239
+ str: The function name, or "<unnamed function>" if no name available.
240
+
241
+ Examples:
242
+ >>> def my_function():
243
+ ... pass
244
+ >>> fun_name(my_function)
245
+ 'my_function'
246
+
247
+ >>> from functools import partial
248
+ >>> add = lambda x, y: x + y
249
+ >>> add_one = partial(add, 1)
250
+ >>> fun_name(add_one)
251
+ '<lambda>'
252
+ """
104
253
  name = getattr(fun, "__name__", None)
105
254
  if name is not None:
106
255
  return name
@@ -117,8 +266,34 @@ else:
117
266
  **kwargs,
118
267
  ) -> Callable[[T], T]:
119
268
  """
120
- Like functools.wraps, but with finer-grained control over the name and docstring
121
- of the resulting function.
269
+ Enhanced function wrapper with fine-grained control.
270
+
271
+ Like functools.wraps, but provides more control over the name and docstring
272
+ of the resulting function. Useful for creating custom decorators.
273
+
274
+ Args:
275
+ wrapped: The function being wrapped.
276
+ namestr: Optional format string for the wrapper function name.
277
+ Can use {fun} placeholder for the original function name.
278
+ docstr: Optional format string for the wrapper function docstring.
279
+ Can use {fun}, {doc}, and other kwargs as placeholders.
280
+ **kwargs: Additional keyword arguments for format string substitution.
281
+
282
+ Returns:
283
+ Callable: A decorator function that applies the wrapping.
284
+
285
+ Examples:
286
+ >>> def my_decorator(func):
287
+ ... @wraps(func, namestr="decorated_{fun}")
288
+ ... def wrapper(*args, **kwargs):
289
+ ... return func(*args, **kwargs)
290
+ ... return wrapper
291
+
292
+ >>> @my_decorator
293
+ ... def example():
294
+ ... pass
295
+ >>> example.__name__
296
+ 'decorated_example'
122
297
  """
123
298
 
124
299
  def wrapper(fun: T) -> T:
@@ -141,8 +316,25 @@ else:
141
316
 
142
317
 
143
318
  def to_concrete_aval(aval):
319
+ """
320
+ Convert an abstract value to its concrete representation.
321
+
322
+ Takes an abstract value and attempts to convert it to a concrete value,
323
+ handling JAX Tracer objects appropriately.
324
+
325
+ Args:
326
+ aval: The abstract value to convert.
327
+
328
+ Returns:
329
+ The concrete value representation, or the original aval if already concrete.
330
+
331
+ Examples:
332
+ >>> import jax.numpy as jnp
333
+ >>> arr = jnp.array([1, 2, 3])
334
+ >>> concrete = to_concrete_aval(arr)
335
+ # Returns the concrete array value
336
+ """
144
337
  aval = get_aval(aval)
145
338
  if isinstance(aval, Tracer):
146
339
  return aval.to_concrete_value()
147
340
  return aval
148
-