brainstate-0.1.9-py2.py3-none-any.whl → brainstate-0.2.0-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
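
Note on the largest batch of renames above: the brainstate.compile and brainstate.augment namespaces are consolidated into brainstate.transform in 0.2.0. A hedged migration sketch, based only on the import changes visible in the hunks below (the full set of names re-exported by brainstate.transform is not verified here):

    # brainstate 0.1.9 imports (modules removed in 0.2.0)
    # from brainstate.compile import for_loop
    # from brainstate.augment import vmap

    # brainstate 0.2.0 equivalents, as used by the updated sources in this diff
    from brainstate import transform
    from brainstate.transform import for_loop

    # e.g. transform.vmap(...) now replaces augment.vmap(...)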
brainstate/nn/_embedding.py
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -13,11 +13,18 @@
  # limitations under the License.
  # ==============================================================================

- from typing import Optional, Callable, Union
+ from functools import lru_cache
+ from typing import Callable, Optional, Tuple, Union
+
+ import numpy as np
+ import jax
+ import jax.numpy as jnp
+ import jax.tree_util as jtu
+ from jax import core as jax_core

- from brainstate import init
  from brainstate._state import ParamState
- from brainstate.typing import ArrayLike
+ from brainstate.typing import ArrayLike, Size
+ from . import init as init
  from ._module import Module

  __all__ = [
@@ -25,34 +32,377 @@ __all__ = [
  ]


+ def _normalize_embedding_size(size: Size) -> Tuple[int, ...]:
+     """Convert ``Size`` specifications to a validated tuple of integers."""
+     size_array = np.asarray(size)
+     if size_array.size == 0:
+         raise ValueError('embedding_size must contain at least one dimension.')
+     flat = size_array.reshape(-1)
+     normalized = tuple(int(dim) for dim in flat)
+     if any(dim < 0 for dim in normalized):
+         raise ValueError('embedding_size must not contain negative values.')
+     return normalized
+
+
+ @lru_cache(maxsize=None)
+ def _embedding_lookup_fn(
+     padding_idx: Optional[int],
+     scale_grad_by_freq: bool,
+ ):
+     """Return a lookup function with a custom VJP implementing embedding semantics."""
+
+     @jax.custom_vjp
+     def _lookup(weight: jax.Array, indices: jax.Array) -> jax.Array:
+         indices = jnp.asarray(indices)
+         return weight[indices]
+
+     def _lookup_fwd(weight: jax.Array, indices: jax.Array):
+         indices = jnp.asarray(indices)
+         return weight[indices], (indices, weight.shape)
+
+     def _lookup_bwd(residual, grad_output: jax.Array):
+         indices, weight_shape = residual
+         grad_output = jnp.asarray(grad_output)
+         flat_idx = jnp.ravel(indices)
+         if flat_idx.size == 0:
+             return jnp.zeros(weight_shape, dtype=grad_output.dtype), None
+
+         flat_idx = jnp.asarray(flat_idx, dtype=jnp.int32)
+         grad_flat = grad_output.reshape((flat_idx.shape[0],) + weight_shape[1:])
+
+         if scale_grad_by_freq:
+             counts = jnp.bincount(flat_idx, length=weight_shape[0])
+             counts = counts.astype(grad_flat.dtype)
+             counts = jnp.where(counts == 0, 1.0, counts)
+             scale = counts[flat_idx]
+             grad_flat = grad_flat / scale.reshape((flat_idx.shape[0],) + (1,) * (grad_flat.ndim - 1))
+
+         if padding_idx is not None:
+             pad_value = jnp.asarray(padding_idx, dtype=flat_idx.dtype)
+             mask = flat_idx != pad_value
+             broadcast_shape = (flat_idx.shape[0],) + (1,) * (grad_flat.ndim - 1)
+             grad_flat = grad_flat * mask.reshape(broadcast_shape).astype(grad_flat.dtype)
+
+         grad_weight = jnp.zeros(weight_shape, dtype=grad_output.dtype)
+         grad_weight = grad_weight.at[flat_idx].add(grad_flat)
+         return grad_weight, None
+
+     _lookup.defvjp(_lookup_fwd, _lookup_bwd)
+     return _lookup
+
+
+ def _contains_tracer(tree) -> bool:
+     """Return True if the pytree contains any JAX tracer values."""
+     return any(isinstance(leaf, jax_core.Tracer) for leaf in jtu.tree_leaves(tree))
+
+
+
  class Embedding(Module):
      r"""
      A simple lookup table that stores embeddings of a fixed size.

-     Args:
-         num_embeddings: Size of embedding dictionary. Must be non-negative.
-         embedding_size: Size of each embedding vector. Must be non-negative.
-         embedding_init: The initializer for the embedding lookup table, of shape `(num_embeddings, embedding_size)`.
+     This module is commonly used to store word embeddings and retrieve them using indices.
+     The input to the module is a list of indices, and the output is the corresponding
+     word embeddings.
+
+     Parameters
+     ----------
+     num_embeddings : int
+         Size of embedding dictionary. Must be non-negative.
+     embedding_size : Size
+         Size of each embedding vector. Can be an int or a sequence of ints, and must
+         contain only non-negative values.
+     embedding_init : Callable or ArrayLike, optional
+         The initializer for the embedding lookup table, of shape
+         ``(num_embeddings, embedding_size)``. Default is ``LecunUniform()``.
+     padding_idx : int, optional
+         If specified, the entries at ``padding_idx`` do not contribute to the gradient;
+         therefore, the embedding vector at ``padding_idx`` is not updated during training,
+         i.e., it remains as a fixed "pad". For a newly constructed Embedding, the embedding
+         vector at ``padding_idx`` will default to all zeros. Default is ``None``.
+     max_norm : float, optional
+         If given, each embedding vector with norm larger than ``max_norm`` is renormalized
+         to have norm ``max_norm``. Default is ``None``.
+     norm_type : float, optional
+         The p of the p-norm to compute for the ``max_norm`` option. Default is ``2.0``.
+     scale_grad_by_freq : bool, optional
+         If given, this scales gradients by the inverse frequency of the words in
+         the mini-batch. Default is ``False``.
+     name : str, optional
+         The name of the module.
+     param_type : type, optional
+         The parameter state type to use. Default is ``ParamState``.
+
+     Attributes
+     ----------
+     num_embeddings : int
+         Size of the embedding dictionary.
+     embedding_size : tuple[int, ...]
+         Size of each embedding vector.
+     out_size : tuple[int, ...]
+         Output size, same as ``embedding_size``.
+     weight : ParamState
+         The learnable weights of the module of shape
+         ``(num_embeddings, *embedding_size)``.
+     padding_idx : int or None
+         Index of the padding token.
+     max_norm : float or None
+         Maximum norm for embedding vectors.
+     norm_type : float
+         Type of p-norm to compute for max_norm.
+     scale_grad_by_freq : bool
+         Whether to scale gradients by frequency.
+     freeze : bool
+         Whether the embedding weights are frozen.
+
+     Examples
+     --------
+     Create an embedding layer with 10 words and 3-dimensional embeddings:
+
+     .. code-block:: python
+
+         >>> import brainstate as bst
+         >>> embedding = bst.nn.Embedding(num_embeddings=10, embedding_size=3)
+         >>> embedding.weight.value.shape
+         (10, 3)
+
+     Retrieve embeddings for specific indices:
+
+     .. code-block:: python
+
+         >>> import jax.numpy as jnp
+         >>> indices = jnp.array([1, 3, 5])
+         >>> output = embedding(indices)
+         >>> output.shape
+         (3, 3)
+
+     Use with a batch of sequences:
+
+     .. code-block:: python
+
+         >>> # Batch of 2 sequences, each with 4 tokens
+         >>> batch_indices = jnp.array([[1, 2, 3, 4],
+         ...                            [5, 6, 7, 8]])
+         >>> output = embedding(batch_indices)
+         >>> output.shape
+         (2, 4, 3)
+
+     Use padding_idx to keep padding embeddings fixed:
+
+     .. code-block:: python
+
+         >>> embedding = bst.nn.Embedding(num_embeddings=10, embedding_size=3, padding_idx=0)
+         >>> # The embedding at index 0 will remain zeros and not be updated during training
+         >>> indices = jnp.array([0, 2, 0, 5])
+         >>> output = embedding(indices)
+         >>> output[0]  # Will be zeros
+         Array([0., 0., 0.], dtype=float32)
+
+     Use max_norm to constrain embedding norms:
+
+     .. code-block:: python
+
+         >>> embedding = bst.nn.Embedding(num_embeddings=10, embedding_size=3, max_norm=1.0)
+         >>> # All embeddings accessed in a forward pass are renormalized to have norm <= 1.0
+
+     Load pretrained embeddings:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> import jax.numpy as jnp
+         >>> pretrained = jnp.array([[1.0, 2.0, 3.0],
+         ...                         [4.0, 5.0, 6.0],
+         ...                         [7.0, 8.0, 9.0]])
+         >>> embedding = bst.nn.Embedding.from_pretrained(pretrained, param_type=brainstate.FakeState)
+         >>> embedding.weight.value.shape
+         (3, 3)
      """

      def __init__(
          self,
          num_embeddings: int,
-         embedding_size: int,
+         embedding_size: Size,
          embedding_init: Union[Callable, ArrayLike] = init.LecunUniform(),
+         padding_idx: Optional[int] = None,
+         max_norm: Optional[float] = None,
+         norm_type: float = 2.0,
+         scale_grad_by_freq: bool = False,
+         freeze: bool = False,
          name: Optional[str] = None,
+         param_type: type = ParamState,
      ):
          super().__init__(name=name)
-         if num_embeddings < 0:
-             raise ValueError("num_embeddings must not be negative.")
-         if embedding_size < 0:
-             raise ValueError("embedding_size must not be negative.")
-         self.num_embeddings = num_embeddings
-         self.embedding_size = embedding_size
-         self.out_size = (embedding_size,)

-         weight = init.param(embedding_init, (self.num_embeddings, self.embedding_size))
-         self.weight = ParamState(weight)
+         self.num_embeddings = int(num_embeddings)
+         if self.num_embeddings < 0:
+             raise ValueError('num_embeddings must not be negative.')
+
+         embedding_size_tuple = _normalize_embedding_size(embedding_size)
+         self.embedding_size = embedding_size_tuple
+         self.out_size = embedding_size_tuple
+
+         if padding_idx is not None:
+             padding_idx = int(padding_idx)
+             if padding_idx < 0 or padding_idx >= self.num_embeddings:
+                 raise ValueError(f'padding_idx must be within [0, {self.num_embeddings}).')
+         self.padding_idx = padding_idx
+
+         if max_norm is not None and max_norm <= 0:
+             raise ValueError('max_norm must be positive when provided.')
+         self.max_norm = max_norm
+         self.norm_type = norm_type
+         self.scale_grad_by_freq = bool(scale_grad_by_freq)
+         self.freeze = bool(freeze)
+
+         weight_shape = (self.num_embeddings, *self.out_size)
+         weight = init.param(embedding_init, weight_shape)
+
+         if self.padding_idx is not None:
+             weight = weight.at[self.padding_idx].set(0)
+
+         self.weight = param_type(weight)
+         self._lookup = _embedding_lookup_fn(self.padding_idx, self.scale_grad_by_freq)
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         embeddings: ArrayLike,
+         padding_idx: Optional[int] = None,
+         max_norm: Optional[float] = None,
+         norm_type: float = 2.0,
+         scale_grad_by_freq: bool = False,
+         freeze: bool = True,
+         name: Optional[str] = None,
+         param_type: type = ParamState,
+     ):
+         r"""
+         Create an Embedding instance from given 2-dimensional array.
+
+         Parameters
+         ----------
+         embeddings : ArrayLike
+             Array containing weights for the Embedding. First dimension is passed to
+             Embedding as ``num_embeddings``, remaining dimensions as ``embedding_size``.
+         padding_idx : int, optional
+             If specified, the entries at ``padding_idx`` do not contribute to the gradient.
+             Default is ``None``.
+         max_norm : float, optional
+             See module initialization documentation. Default is ``None``.
+         norm_type : float, optional
+             See module initialization documentation. Default is ``2.0``.
+         scale_grad_by_freq : bool, optional
+             See module initialization documentation. Default is ``False``.
+         freeze : bool, optional
+             If ``True``, embeddings are frozen (no gradients). Default is ``True``.
+         name : str, optional
+             The name of the module.
+
+         Returns
+         -------
+         Embedding
+             An Embedding module with pretrained weights.
+
+         Examples
+         --------
+         Load pretrained word embeddings:
+
+         .. code-block:: python
+
+             >>> import jax.numpy as jnp
+             >>> import brainstate as bst
+             >>> pretrained = jnp.array([[1.0, 2.0, 3.0],
+             ...                         [4.0, 5.0, 6.0],
+             ...                         [7.0, 8.0, 9.0]])
+             >>> embedding = bst.nn.Embedding.from_pretrained(pretrained)
+             >>> embedding.weight.value.shape
+             (3, 3)
+             >>> indices = jnp.array([1])
+             >>> embedding(indices)
+             Array([[4., 5., 6.]], dtype=float32)
+         """
+         embeddings = jnp.asarray(embeddings)
+         if embeddings.ndim < 2:
+             raise ValueError('embeddings must be at least 2-dimensional')
+
+         num_embeddings = embeddings.shape[0]
+         embedding_size = embeddings.shape[1:]
+
+         instance = cls(
+             num_embeddings=num_embeddings,
+             embedding_size=embedding_size,
+             embedding_init=embeddings,
+             padding_idx=padding_idx,
+             max_norm=max_norm,
+             norm_type=norm_type,
+             scale_grad_by_freq=scale_grad_by_freq,
+             freeze=freeze,
+             name=name,
+             param_type=param_type,
+         )
+
+         instance.weight = param_type(jnp.asarray(embeddings))
+         return instance

      def update(self, indices: ArrayLike):
-         return self.weight.value[indices]
+         """
+         Retrieve embeddings for the given indices.
+
+         Parameters
+         ----------
+         indices : ArrayLike
+             Indices to retrieve embeddings for. Can be any shape.
+
+         Returns
+         -------
+         ArrayLike
+             Embeddings corresponding to the indices, with shape
+             ``(*indices.shape, *embedding_size)``.
+         """
+         indices = jnp.asarray(indices)
+         if not jnp.issubdtype(indices.dtype, jnp.integer):
+             raise TypeError('Embedding indices must be integers.')
+
+         weight_value = self.weight.value
+         effective_weight = weight_value
+
+         if self.max_norm is not None:
+             renormed_weight = self._apply_max_norm(weight_value, indices)
+             effective_weight = weight_value + jax.lax.stop_gradient(renormed_weight - weight_value)
+             if not _contains_tracer(renormed_weight):
+                 self.weight.value = renormed_weight
+
+         if self.freeze:
+             effective_weight = jax.lax.stop_gradient(effective_weight)
+
+         embeddings = self._lookup(effective_weight, indices)
+         return embeddings
+
+     def _apply_max_norm(self, weight: jax.Array, indices: jax.Array) -> jax.Array:
+         """Apply max_norm constraint to the embedding weights for the given indices."""
+         flat_idx = jnp.ravel(indices)
+         if flat_idx.size == 0:
+             return weight
+
+         flat_idx = jnp.asarray(flat_idx, dtype=jnp.int32)
+         if self.padding_idx is not None:
+             pad_value = jnp.asarray(self.padding_idx, dtype=flat_idx.dtype)
+             flat_idx = flat_idx[flat_idx != pad_value]
+
+         if flat_idx.size == 0:
+             return weight
+
+         rows = weight[flat_idx]
+         rows_flat = rows.reshape((rows.shape[0], -1))
+         row_dtype = rows_flat.dtype
+
+         norms = jnp.linalg.norm(rows_flat, ord=self.norm_type, axis=1, keepdims=True)
+         max_norm = jnp.asarray(self.max_norm, dtype=row_dtype)
+         eps = jnp.asarray(1e-8, dtype=row_dtype)
+         one = jnp.asarray(1.0, dtype=row_dtype)
+         scale = jnp.minimum(one, max_norm / (norms + eps))
+         rows_scaled = (rows_flat * scale).reshape(rows.shape)
+
+         return weight.at[flat_idx].set(rows_scaled)
+
+
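As a reading aid (not part of the released sources), a minimal sketch of the gradient semantics the custom-VJP lookup above implements, written against the Embedding API added in this hunk and following the same calling pattern as the test file below:

    import jax
    import jax.numpy as jnp
    import brainstate as bst

    # Rows at padding_idx receive zero gradient; with scale_grad_by_freq the
    # gradient of each row is divided by how often its index appears in the batch.
    emb = bst.nn.Embedding(num_embeddings=5, embedding_size=3,
                           padding_idx=0, scale_grad_by_freq=True)
    indices = jnp.array([0, 2, 2, 4])  # index 0 is the pad, index 2 appears twice

    def loss(weight):
        emb.weight.value = weight
        return jnp.sum(emb(indices))

    grad = jax.grad(loss)(emb.weight.value)
    # Expected: grad[0] == 0 (padding), grad[2] == 1 (summed 2, scaled by 1/2),
    # grad[4] == 1, and all other rows stay 0.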
brainstate/nn/_embedding_test.py
@@ -0,0 +1,156 @@
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ import unittest
+
+ import jax
+ import jax.numpy as jnp
+
+ import brainstate as bs
+
+
+ class TestEmbedding(unittest.TestCase):
+     """Comprehensive tests for the Embedding module."""
+
+     def setUp(self):
+         settings = bs.environ.all()
+         self._prev_fit = settings.get('fit', None)
+         bs.environ.set(fit=True)
+
+     def tearDown(self):
+         if self._prev_fit is None:
+             bs.environ.pop('fit', None)
+         else:
+             bs.environ.set(fit=self._prev_fit)
+
+     def test_padding_idx_zero_gradient(self):
+         embedding = bs.nn.Embedding(num_embeddings=4, embedding_size=3, padding_idx=0)
+         lookup = embedding._lookup
+         weight = jnp.arange(12.0, dtype=jnp.float32).reshape(4, 3)
+         indices = jnp.array([0, 1, 0, 2], dtype=jnp.int32)
+
+         def loss_fn(w):
+             return jnp.sum(lookup(w, indices))
+
+         grad = jax.grad(loss_fn)(weight)
+         self.assertTrue(jnp.allclose(grad[0], 0.0))
+         self.assertFalse(jnp.allclose(grad[1], 0.0))
+
+     def test_scale_grad_by_freq(self):
+         base = bs.nn.Embedding(num_embeddings=5, embedding_size=2)
+         scaled = bs.nn.Embedding(num_embeddings=5, embedding_size=2, scale_grad_by_freq=True)
+         base_lookup = base._lookup
+         scaled_lookup = scaled._lookup
+
+         weight = jnp.arange(10.0, dtype=jnp.float32).reshape(5, 2)
+         indices = jnp.array([1, 1, 2], dtype=jnp.int32)
+
+         def loss_base(w):
+             return jnp.sum(base_lookup(w, indices))
+
+         def loss_scaled(w):
+             return jnp.sum(scaled_lookup(w, indices))
+
+         base_grad = jax.grad(loss_base)(weight)
+         scaled_grad = jax.grad(loss_scaled)(weight)
+
+         counts = jnp.bincount(indices, length=weight.shape[0])
+         ones = jnp.ones((indices.shape[0], weight.shape[1]), dtype=weight.dtype)
+         expected_base = jnp.zeros_like(weight).at[indices].add(ones)
+         expected_scaled = jnp.where(counts[:, None] > 0, jnp.ones_like(weight), 0.0)
+
+         self.assertTrue(jnp.allclose(base_grad, expected_base))
+         self.assertTrue(jnp.allclose(scaled_grad, expected_scaled))
+
+     def test_lookup_grad_jit_consistent(self):
+         embedding = bs.nn.Embedding(num_embeddings=4, embedding_size=2)
+         lookup = embedding._lookup
+         weight = jnp.arange(8.0, dtype=jnp.float32).reshape(4, 2)
+         indices = jnp.array([0, 1, 1, 3], dtype=jnp.int32)
+
+         def loss_fn(w):
+             return jnp.sum(lookup(w, indices))
+
+         grad_eager = jax.grad(loss_fn)(weight)
+         grad_jitted = jax.grad(jax.jit(loss_fn))(weight)
+
+         expected = jnp.zeros_like(weight).at[indices].add(jnp.ones((indices.shape[0], weight.shape[1]), dtype=weight.dtype))
+
+         self.assertTrue(jnp.allclose(grad_eager, grad_jitted))
+         self.assertTrue(jnp.allclose(grad_eager, expected))
+
+     def test_jit_forward_with_max_norm(self):
+         embedding = bs.nn.Embedding(num_embeddings=3, embedding_size=3, max_norm=0.5)
+         heavy = jnp.array([[0.0, 0.0, 0.0], [1.0, 2.0, 2.0], [0.3, -0.4, 0.5]], dtype=jnp.float32)
+         embedding.weight.value = heavy
+         indices = jnp.array([1, 2, 1], dtype=jnp.int32)
+
+         compiled = jax.jit(lambda ids: embedding(ids))
+         out = compiled(indices)
+         self.assertEqual(out.shape, (3, 3))
+         output_norms = jnp.linalg.norm(out, axis=-1)
+         self.assertTrue(jnp.all(output_norms <= 0.5 + 1e-6))
+         # Weight remains unclipped during JIT execution but must be usable without tracer leaks
+         self.assertGreater(float(jnp.linalg.norm(embedding.weight.value[1])), 0.5)
+
+     def test_freeze_disables_gradients(self):
+         embedding = bs.nn.Embedding(num_embeddings=4, embedding_size=2, freeze=True)
+         indices = jnp.array([1, 2, 3], dtype=jnp.int32)
+
+         def loss_fn(weight):
+             embedding.weight.value = weight
+             return jnp.sum(embedding(indices))
+
+         grad = jax.grad(loss_fn)(embedding.weight.value)
+         self.assertTrue(jnp.allclose(grad, 0.0))
+
+     def test_from_pretrained_defaults_to_freeze(self):
+         pretrained = jnp.arange(12.0, dtype=jnp.float32).reshape(4, 3)
+         embedding = bs.nn.Embedding.from_pretrained(pretrained)
+         self.assertTrue(embedding.freeze)
+
+         def loss_fn(weight):
+             embedding.weight.value = weight
+             return jnp.sum(embedding(jnp.array([1, 2], dtype=jnp.int32)))
+
+         grad = jax.grad(loss_fn)(embedding.weight.value)
+         self.assertTrue(jnp.allclose(grad, 0.0))
+
+     def test_from_pretrained_unfrozen_gradients(self):
+         pretrained = jnp.arange(6.0, dtype=jnp.float32).reshape(2, 3)
+         embedding = bs.nn.Embedding.from_pretrained(pretrained, freeze=False)
+
+         def loss_fn(weight):
+             embedding.weight.value = weight
+             return jnp.sum(embedding(jnp.array([0, 1], dtype=jnp.int32)))
+
+         grad = jax.grad(loss_fn)(embedding.weight.value)
+         self.assertFalse(jnp.allclose(grad, 0.0))
+
+     def test_max_norm_renormalizes_weights(self):
+         embedding = bs.nn.Embedding(num_embeddings=3, embedding_size=3, max_norm=1.0, norm_type=2.0)
+         custom_weight = jnp.array([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [0.5, 0.5, 0.5]], dtype=jnp.float32)
+         embedding.weight.value = custom_weight
+         _ = embedding(jnp.array([1, 2], dtype=jnp.int32))
+
+         row_norm = jnp.linalg.norm(embedding.weight.value[1])
+         self.assertLessEqual(float(row_norm), 1.0 + 1e-6)
+         self.assertTrue(jnp.allclose(embedding.weight.value[0], custom_weight[0]))
+
+
+ if __name__ == '__main__':
+     unittest.main()
brainstate/nn/{_fixedprob.py → _event_fixedprob.py}
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -22,10 +22,11 @@ import jax
  import jax.numpy as jnp
  import numpy as np

- from brainstate import random, augment, environ, init
+ from brainstate import random, transform, environ
  from brainstate._state import ParamState, FakeState
- from brainstate.compile import for_loop
+ from brainstate.transform import for_loop
  from brainstate.typing import Size, ArrayLike
+ from . import init as init
  from ._module import Module

  __all__ = [
@@ -45,7 +46,7 @@ def init_indices_without_replace(
      rng = random.default_rng(seed)

      if method == 'vmap':
-         @augment.vmap(axis_size=n_pre)
+         @transform.vmap(axis_size=n_pre)
          def rand_indices():
              return rng.choice(n_post, size=(conn_num,), replace=False)

@@ -176,15 +177,14 @@ class FixedNumConn(Module):
          conn_weight = u.math.asarray(init.param(conn_weight, (), allow_none=False))
          self.weight = FakeState(conn_weight)

-     def update(self, x: jax.Array) -> Union[jax.Array, u.Quantity]:
+     def update(self, x) -> Union[jax.Array, u.Quantity]:
          if self.conn_num >= 1:
              csr = self.conn.with_data(self.weight.value)
              return x @ csr
          else:
              weight = self.weight.value
              r = u.math.zeros(x.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
-             r = u.maybe_decimal(u.Quantity(r, unit=u.get_unit(weight)))
-             return u.math.asarray(r, dtype=environ.dftype())
+             return u.maybe_decimal(u.Quantity(r, unit=u.get_unit(weight), dtype=environ.dftype()))


  class EventFixedNumConn(FixedNumConn):
@@ -225,15 +225,9 @@ class EventFixedNumConn(FixedNumConn):
      __module__ = 'brainstate.nn'

      def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
-         if self.conn_num >= 1:
-             csr = self.conn.with_data(self.weight.value)
-             return brainevent.EventArray(spk) @ csr
-         else:
-             weight = self.weight.value
-             unit = u.get_unit(weight)
-             r = jnp.zeros(spk.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
-             r = u.maybe_decimal(u.Quantity(r, unit=unit))
-             return u.math.asarray(r, dtype=environ.dftype())
+         return super().update(
+             brainevent.EventArray(spk)
+         )


  EventFixedProb = EventFixedNumConn
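
For context (not part of the released sources), a hedged usage sketch of the renamed layer; the constructor call mirrors the updated test file below, and braintools.init is assumed to provide the initializers previously found under brainstate.init:

    import jax.numpy as jnp
    import brainstate
    import braintools

    # 20 presynaptic -> 40 postsynaptic neurons with 10% fixed connection
    # probability and Kaiming-uniform weights, as exercised by the tests below.
    layer = brainstate.nn.EventFixedProb(20, 40, 0.1, braintools.init.KaimingUniform(), seed=123)

    # Binary spike vector; EventFixedNumConn.update now wraps the input in
    # brainevent.EventArray and delegates to FixedNumConn.update.
    spikes = jnp.asarray(brainstate.random.rand(20) < 0.3, dtype=float)
    out = layer(spikes)   # expected shape: (40,)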
brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py}
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ import jax.numpy as jnp
  import pytest

  import brainstate
+ import braintools


  class TestFixedProbCSR:
@@ -30,14 +31,14 @@ class TestFixedProbCSR:
          y = m(x)
          print(y)

-         m2 = brainstate.nn.EventFixedProb(20, 40, 0.1, brainstate.init.KaimingUniform(), seed=123)
+         m2 = brainstate.nn.EventFixedProb(20, 40, 0.1, braintools.init.KaimingUniform(), seed=123)
          print(m2(x))

      def test_grad_bool(self):
          n_in = 20
          n_out = 30
          x = jax.numpy.asarray(brainstate.random.rand(n_in) < 0.3, dtype=float)
-         fn = brainstate.nn.EventFixedProb(n_in, n_out, 0.1, brainstate.init.KaimingUniform(), seed=123)
+         fn = brainstate.nn.EventFixedProb(n_in, n_out, 0.1, braintools.init.KaimingUniform(), seed=123)

          def f(x):
              return fn(x).sum()
@@ -53,7 +54,7 @@ class TestFixedProbCSR:
          if homo_w:
              fn = brainstate.nn.EventFixedProb(n_in, n_out, 0.1, 1.5, seed=123)
          else:
-             fn = brainstate.nn.EventFixedProb(n_in, n_out, 0.1, brainstate.init.KaimingUniform(), seed=123)
+             fn = brainstate.nn.EventFixedProb(n_in, n_out, 0.1, braintools.init.KaimingUniform(), seed=123)
          w = fn.weight.value

          def f(x, w):
@@ -85,7 +86,7 @@ class TestFixedProbCSR:
          x = jax.numpy.asarray(brainstate.random.rand(n_in) < 0.3, dtype=float)

          fn = brainstate.nn.EventFixedProb(
-             n_in, n_out, 0.1, 1.5 if homo_w else brainstate.init.KaimingUniform(),
+             n_in, n_out, 0.1, 1.5 if homo_w else braintools.init.KaimingUniform(),
              seed=123,
          )
          w = fn.weight.value