brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -13,5 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- def is_instance(x, cls):
17
- assert isinstance(x, cls), 'The input should be an instance of {}!'.format(cls)
16
+ # This module is going to be deleted in the future (near 2025-06).
brainstate/typing.py CHANGED
@@ -13,73 +13,96 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from __future__ import annotations
16
17
 
18
+ import builtins
17
19
  import functools as ft
20
+ import importlib
18
21
  import inspect
19
- import typing
20
- from typing import Sequence, Protocol, Union, Any, Generic, TypeVar, Tuple
21
22
 
22
- import brainunit as bu
23
+ import brainunit as u
23
24
  import jax
24
25
  import numpy as np
25
26
 
27
+ tp = importlib.import_module("typing")
28
+
26
29
  __all__ = [
27
- 'PyTree',
28
- 'Size',
29
- 'Axes',
30
- 'SeedOrKey',
31
- 'ArrayLike',
32
- 'DType',
33
- 'DTypeLike',
30
+ 'PathParts',
31
+ 'Predicate',
32
+ 'Filter',
33
+ 'PyTree',
34
+ 'Size',
35
+ 'Axes',
36
+ 'SeedOrKey',
37
+ 'ArrayLike',
38
+ 'DType',
39
+ 'DTypeLike',
40
+ 'Missing',
34
41
  ]
35
42
 
36
- _T = TypeVar("_T")
43
+ K = tp.TypeVar('K')
44
+
45
+
46
+ @tp.runtime_checkable
47
+ class Key(tp.Hashable, tp.Protocol):
48
+ def __lt__(self: K, value: K, /) -> bool:
49
+ ...
50
+
37
51
 
38
- _Annotation = TypeVar("_Annotation")
52
+ Ellipsis = builtins.ellipsis if tp.TYPE_CHECKING else tp.Any
39
53
 
54
+ PathParts = tp.Tuple[Key, ...]
55
+ Predicate = tp.Callable[[PathParts, tp.Any], bool]
56
+ FilterLiteral = tp.Union[type, str, Predicate, bool, Ellipsis, None]
57
+ Filter = tp.Union[FilterLiteral, tp.Tuple['Filter', ...], tp.List['Filter']]
40
58
 
41
- class _Array(Generic[_Annotation]):
42
- pass
59
+ _T = tp.TypeVar("_T")
60
+
61
+ _Annotation = tp.TypeVar("_Annotation")
62
+
63
+
64
+ class _Array(tp.Generic[_Annotation]):
65
+ pass
43
66
 
44
67
 
45
68
  _Array.__module__ = "builtins"
46
69
 
47
70
 
48
- def _item_to_str(item: Union[str, type, slice]) -> str:
49
- if isinstance(item, slice):
50
- if item.step is not None:
51
- raise NotImplementedError
52
- return _item_to_str(item.start) + ": " + _item_to_str(item.stop)
53
- elif item is ...:
54
- return "..."
55
- elif inspect.isclass(item):
56
- return item.__name__
57
- else:
58
- return repr(item)
71
+ def _item_to_str(item: tp.Union[str, type, slice]) -> str:
72
+ if isinstance(item, slice):
73
+ if item.step is not None:
74
+ raise NotImplementedError
75
+ return _item_to_str(item.start) + ": " + _item_to_str(item.stop)
76
+ elif item is ...:
77
+ return "..."
78
+ elif inspect.isclass(item):
79
+ return item.__name__
80
+ else:
81
+ return repr(item)
59
82
 
60
83
 
61
84
  def _maybe_tuple_to_str(
62
- item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]]
85
+ item: tp.Union[str, type, slice, tp.Tuple[tp.Union[str, type, slice], ...]]
63
86
  ) -> str:
64
- if isinstance(item, tuple):
65
- if len(item) == 0:
66
- # Explicit brackets
67
- return "()"
87
+ if isinstance(item, tuple):
88
+ if len(item) == 0:
89
+ # Explicit brackets
90
+ return "()"
91
+ else:
92
+ # No brackets
93
+ return ", ".join([_item_to_str(i) for i in item])
68
94
  else:
69
- # No brackets
70
- return ", ".join([_item_to_str(i) for i in item])
71
- else:
72
- return _item_to_str(item)
95
+ return _item_to_str(item)
73
96
 
74
97
 
75
98
  class Array:
76
- def __class_getitem__(cls, item):
77
- class X:
78
- pass
99
+ def __class_getitem__(cls, item):
100
+ class X:
101
+ pass
79
102
 
80
- X.__module__ = "builtins"
81
- X.__qualname__ = _maybe_tuple_to_str(item)
82
- return _Array[X]
103
+ X.__module__ = "builtins"
104
+ X.__qualname__ = _maybe_tuple_to_str(item)
105
+ return _Array[X]
83
106
 
84
107
 
85
108
  # Same __module__ trick here again. (So that we get the correct display when
@@ -89,8 +112,8 @@ class Array:
89
112
  Array.__module__ = "builtins"
90
113
 
91
114
 
92
- class _FakePyTree(Generic[_T]):
93
- pass
115
+ class _FakePyTree(tp.Generic[_T]):
116
+ pass
94
117
 
95
118
 
96
119
  _FakePyTree.__name__ = "PyTree"
@@ -99,84 +122,84 @@ _FakePyTree.__module__ = "builtins"
99
122
 
100
123
 
101
124
  class _MetaPyTree(type):
102
- def __call__(self, *args, **kwargs):
103
- raise RuntimeError("PyTree cannot be instantiated")
104
-
105
- # Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do
106
- # the custom __instancecheck__ that we want.
107
- # We can't add that __instancecheck__ via subclassing, e.g.
108
- # type("PyTree", (Generic[_T],), {}), because dynamic subclassing of typeforms
109
- # isn't allowed.
110
- # Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that
111
- # has __module__ "types", e.g. we get types.PyTree[int].
112
- @ft.lru_cache(maxsize=None)
113
- def __getitem__(cls, item):
114
- if isinstance(item, tuple):
115
- if len(item) == 2:
116
-
117
- class X(PyTree):
118
- leaftype = item[0]
119
- structure = item[1].strip()
120
-
121
- if not isinstance(X.structure, str):
122
- raise ValueError(
123
- "The structure annotation `struct` in "
124
- "`brainstate.typing.PyTree[leaftype, struct]` must be be a string, "
125
- f"e.g. `brainstate.typing.PyTree[leaftype, 'T']`. Got '{X.structure}'."
126
- )
127
- pieces = X.structure.split()
128
- if len(pieces) == 0:
129
- raise ValueError(
130
- "The string `struct` in `brainstate.typing.PyTree[leaftype, struct]` "
131
- "cannot be the empty string."
132
- )
133
- for piece_index, piece in enumerate(pieces):
134
- if (piece_index == 0) or (piece_index == len(pieces) - 1):
135
- if piece == "...":
136
- continue
137
- if not piece.isidentifier():
138
- raise ValueError(
139
- "The string `struct` in "
140
- "`brainstate.typing.PyTree[leaftype, struct]` must be be a "
141
- "whitespace-separated sequence of identifiers, e.g. "
142
- "`brainstate.typing.PyTree[leaftype, 'T']` or "
143
- "`brainstate.typing.PyTree[leaftype, 'foo bar']`.\n"
144
- "(Here, 'identifier' is used in the same sense as in "
145
- "regular Python, i.e. a valid variable name.)\n"
146
- f"Got piece '{piece}' in overall structure '{X.structure}'."
147
- )
148
- name = str(_FakePyTree[item[0]])[:-1] + ', "' + item[1].strip() + '"]'
149
- else:
150
- raise ValueError(
151
- "The subscript `foo` in `brainstate.typing.PyTree[foo]` must either be a "
152
- "leaf type, e.g. `PyTree[int]`, or a 2-tuple of leaf and "
153
- "structure, e.g. `PyTree[int, 'T']`. Received a tuple of length "
154
- f"{len(item)}."
155
- )
156
- else:
157
- name = str(_FakePyTree[item])
158
-
159
- class X(PyTree):
160
- leaftype = item
161
- structure = None
162
-
163
- X.__name__ = name
164
- X.__qualname__ = name
165
- if getattr(typing, "GENERATING_DOCUMENTATION", False):
166
- X.__module__ = "builtins"
167
- else:
168
- X.__module__ = "brainstate.typing"
169
- return X
125
+ def __call__(self, *args, **kwargs):
126
+ raise RuntimeError("PyTree cannot be instantiated")
127
+
128
+ # Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do
129
+ # the custom __instancecheck__ that we want.
130
+ # We can't add that __instancecheck__ via subclassing, e.g.
131
+ # type("PyTree", (Generic[_T],), {}), because dynamic subclassing of typeforms
132
+ # isn't allowed.
133
+ # Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that
134
+ # has __module__ "types", e.g. we get types.PyTree[int].
135
+ @ft.lru_cache(maxsize=None)
136
+ def __getitem__(cls, item):
137
+ if isinstance(item, tuple):
138
+ if len(item) == 2:
139
+
140
+ class X(PyTree):
141
+ leaftype = item[0]
142
+ structure = item[1].strip()
143
+
144
+ if not isinstance(X.structure, str):
145
+ raise ValueError(
146
+ "The structure annotation `struct` in "
147
+ "`brainstate.typing.PyTree[leaftype, struct]` must be be a string, "
148
+ f"e.g. `brainstate.typing.PyTree[leaftype, 'T']`. Got '{X.structure}'."
149
+ )
150
+ pieces = X.structure.split()
151
+ if len(pieces) == 0:
152
+ raise ValueError(
153
+ "The string `struct` in `brainstate.typing.PyTree[leaftype, struct]` "
154
+ "cannot be the empty string."
155
+ )
156
+ for piece_index, piece in enumerate(pieces):
157
+ if (piece_index == 0) or (piece_index == len(pieces) - 1):
158
+ if piece == "...":
159
+ continue
160
+ if not piece.isidentifier():
161
+ raise ValueError(
162
+ "The string `struct` in "
163
+ "`brainstate.typing.PyTree[leaftype, struct]` must be be a "
164
+ "whitespace-separated sequence of identifiers, e.g. "
165
+ "`brainstate.typing.PyTree[leaftype, 'T']` or "
166
+ "`brainstate.typing.PyTree[leaftype, 'foo bar']`.\n"
167
+ "(Here, 'identifier' is used in the same sense as in "
168
+ "regular Python, i.e. a valid variable name.)\n"
169
+ f"Got piece '{piece}' in overall structure '{X.structure}'."
170
+ )
171
+ name = str(_FakePyTree[item[0]])[:-1] + ', "' + item[1].strip() + '"]'
172
+ else:
173
+ raise ValueError(
174
+ "The subscript `foo` in `brainstate.typing.PyTree[foo]` must either be a "
175
+ "leaf type, e.g. `PyTree[int]`, or a 2-tuple of leaf and "
176
+ "structure, e.g. `PyTree[int, 'T']`. Received a tuple of length "
177
+ f"{len(item)}."
178
+ )
179
+ else:
180
+ name = str(_FakePyTree[item])
181
+
182
+ class X(PyTree):
183
+ leaftype = item
184
+ structure = None
185
+
186
+ X.__name__ = name
187
+ X.__qualname__ = name
188
+ if getattr(tp, "GENERATING_DOCUMENTATION", False):
189
+ X.__module__ = "builtins"
190
+ else:
191
+ X.__module__ = "brainstate.typing"
192
+ return X
170
193
 
171
194
 
172
195
  # Can't do `class PyTree(Generic[_T]): ...` because we need to override the
173
196
  # instancecheck for PyTree[foo], but subclassing
174
197
  # `type(Generic[int])`, i.e. `typing._GenericAlias` is disallowed.
175
198
  PyTree = _MetaPyTree("PyTree", (), {})
176
- if getattr(typing, "GENERATING_DOCUMENTATION", False):
177
- PyTree.__module__ = "builtins"
199
+ if getattr(tp, "GENERATING_DOCUMENTATION", False):
200
+ PyTree.__module__ = "builtins"
178
201
  else:
179
- PyTree.__module__ = "brainstate.typing"
202
+ PyTree.__module__ = "brainstate.typing"
180
203
  PyTree.__doc__ = """Represents a PyTree.
181
204
 
182
205
  Annotations of the following sorts are supported:
@@ -231,9 +254,9 @@ f. A structure can end with a `...`, to denote that the PyTree must be a prefix
231
254
  cases, all named pieces must already have been seen and their structures bound.
232
255
  """ # noqa: E501
233
256
 
234
- Size = Union[int, Sequence[int]]
235
- Axes = Union[int, Sequence[int]]
236
- SeedOrKey = Union[int, jax.Array, np.ndarray]
257
+ Size = tp.Union[int, tp.Sequence[int]]
258
+ Axes = tp.Union[int, tp.Sequence[int]]
259
+ SeedOrKey = tp.Union[int, jax.Array, np.ndarray]
237
260
 
238
261
  # --- Array --- #
239
262
 
@@ -241,12 +264,12 @@ SeedOrKey = Union[int, jax.Array, np.ndarray]
241
264
  # standard JAX array (i.e. not including future non-standard array types like
242
265
  # KeyArray and BInt). It's different than np.typing.ArrayLike in that it doesn't
243
266
  # accept arbitrary sequences, nor does it accept string data.
244
- ArrayLike = Union[
245
- jax.Array, # JAX array type
246
- np.ndarray, # NumPy array type
247
- np.bool_, np.number, # NumPy scalar types
248
- bool, int, float, complex, # Python scalar types
249
- bu.Quantity, # Quantity
267
+ ArrayLike = tp.Union[
268
+ jax.Array, # JAX array type
269
+ np.ndarray, # NumPy array type
270
+ np.bool_, np.number, # NumPy scalar types
271
+ bool, int, float, complex, # Python scalar types
272
+ u.Quantity, # Quantity
250
273
  ]
251
274
 
252
275
  # --- Dtype --- #
@@ -255,9 +278,9 @@ ArrayLike = Union[
255
278
  DType = np.dtype
256
279
 
257
280
 
258
- class SupportsDType(Protocol):
259
- @property
260
- def dtype(self) -> DType: ...
281
+ class SupportsDType(tp.Protocol):
282
+ @property
283
+ def dtype(self) -> DType: ...
261
284
 
262
285
 
263
286
  # DTypeLike is meant to annotate inputs to np.dtype that return
@@ -265,9 +288,13 @@ class SupportsDType(Protocol):
265
288
  # because JAX doesn't support objects or structured dtypes.
266
289
  # Unlike np.typing.DTypeLike, we exclude None, and instead require
267
290
  # explicit annotations when None is acceptable.
268
- DTypeLike = Union[
269
- str, # like 'float32', 'int32'
270
- type[Any], # like np.float32, np.int32, float, int
271
- np.dtype, # like np.dtype('float32'), np.dtype('int32')
272
- SupportsDType, # like jnp.float32, jnp.int32
291
+ DTypeLike = tp.Union[
292
+ str, # like 'float32', 'int32'
293
+ type[tp.Any], # like np.float32, np.int32, float, int
294
+ np.dtype, # like np.dtype('float32'), np.dtype('int32')
295
+ SupportsDType, # like jnp.float32, jnp.int32
273
296
  ]
297
+
298
+
299
+ class Missing:
300
+ pass
@@ -0,0 +1,52 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from ._dict import *
17
+ from ._dict import __all__ as _mapping_all
18
+ from ._error import *
19
+ from ._error import __all__ as _error_all
20
+ from ._filter import *
21
+ from ._filter import __all__ as _filter_all
22
+ from ._others import *
23
+ from ._others import __all__ as _others_all
24
+ from ._pretty_repr import *
25
+ from ._pretty_repr import __all__ as _pretty_repr_all
26
+ from ._scaling import *
27
+ from ._scaling import __all__ as _mem_scale_all
28
+ from ._struct import *
29
+ from ._struct import __all__ as _struct_all
30
+ from ._visualization import *
31
+ from ._visualization import __all__ as _visualization_all
32
+
33
+ __all__ = (
34
+ _others_all
35
+ + _mem_scale_all
36
+ + _filter_all
37
+ + _pretty_repr_all
38
+ + _struct_all
39
+ + _error_all
40
+ + _mapping_all
41
+ + _visualization_all
42
+ )
43
+ del (
44
+ _others_all,
45
+ _mem_scale_all,
46
+ _filter_all,
47
+ _pretty_repr_all,
48
+ _struct_all,
49
+ _error_all,
50
+ _mapping_all,
51
+ _visualization_all,
52
+ )
@@ -0,0 +1,100 @@
1
+ # The file is adapted from the Flax library (https://github.com/google/flax).
2
+ # The credit should go to the Flax authors.
3
+ #
4
+ # Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from __future__ import annotations
19
+
20
+ import dataclasses
21
+ from typing import Any, TypeVar, Protocol, Generic
22
+
23
+ import jax
24
+
25
+ __all__ = [
26
+ 'DelayedAccessor',
27
+ 'CallableProxy',
28
+ 'ApplyCaller',
29
+ ]
30
+
31
+ A = TypeVar('A', covariant=True) # type: ignore[not-supported-yet]
32
+
33
+
34
+ def _identity(x):
35
+ return x
36
+
37
+
38
+ @dataclasses.dataclass(frozen=True)
39
+ class GetItem:
40
+ key: Any
41
+
42
+
43
+ @dataclasses.dataclass(frozen=True)
44
+ class GetAttr:
45
+ name: str
46
+
47
+
48
+ @dataclasses.dataclass(frozen=True)
49
+ class DelayedAccessor:
50
+ actions: tuple[GetItem | GetAttr, ...] = ()
51
+
52
+ def __call__(self, x):
53
+ for action in self.actions:
54
+ if isinstance(action, GetItem):
55
+ x = x[action.key]
56
+ elif isinstance(action, GetAttr):
57
+ x = getattr(x, action.name)
58
+ return x
59
+
60
+ def __getattr__(self, name):
61
+ return DelayedAccessor(self.actions + (GetAttr(name),))
62
+
63
+ def __getitem__(self, key):
64
+ return DelayedAccessor(self.actions + (GetItem(key),))
65
+
66
+
67
+ jax.tree_util.register_static(DelayedAccessor)
68
+
69
+
70
+ class _AccessorCall(Protocol):
71
+ def __call__(self, accessor: DelayedAccessor, /, *args, **kwargs) -> Any:
72
+ ...
73
+
74
+
75
+ class CallableProxy:
76
+ def __init__(
77
+ self, fun: _AccessorCall, accessor: DelayedAccessor | None = None
78
+ ):
79
+ self._callable = fun
80
+ self._accessor = DelayedAccessor() if accessor is None else accessor
81
+
82
+ def __call__(self, *args, **kwargs):
83
+ return self._callable(self._accessor, *args, **kwargs)
84
+
85
+ def __getattr__(self, name) -> CallableProxy:
86
+ return CallableProxy(self._callable, getattr(self._accessor, name))
87
+
88
+ def __getitem__(self, key) -> CallableProxy:
89
+ return CallableProxy(self._callable, self._accessor[key])
90
+
91
+
92
+ class ApplyCaller(Protocol, Generic[A]):
93
+ def __getattr__(self, __name) -> ApplyCaller[A]:
94
+ ...
95
+
96
+ def __getitem__(self, __name) -> ApplyCaller[A]:
97
+ ...
98
+
99
+ def __call__(self, *args, **kwargs) -> tuple[Any, A]:
100
+ ...