brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/nn/_embedding.py
@@ -1,408 +1,408 @@
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- from functools import lru_cache
- from typing import Callable, Optional, Tuple, Union
-
- import numpy as np
- import jax
- import jax.numpy as jnp
- import jax.tree_util as jtu
- from jax import core as jax_core
-
- from brainstate._state import ParamState
- from brainstate.typing import ArrayLike, Size
- from . import init as init
- from ._module import Module
-
- __all__ = [
-     'Embedding',
- ]
-
-
- def _normalize_embedding_size(size: Size) -> Tuple[int, ...]:
-     """Convert ``Size`` specifications to a validated tuple of integers."""
-     size_array = np.asarray(size)
-     if size_array.size == 0:
-         raise ValueError('embedding_size must contain at least one dimension.')
-     flat = size_array.reshape(-1)
-     normalized = tuple(int(dim) for dim in flat)
-     if any(dim < 0 for dim in normalized):
-         raise ValueError('embedding_size must not contain negative values.')
-     return normalized
-
-
- @lru_cache(maxsize=None)
- def _embedding_lookup_fn(
-     padding_idx: Optional[int],
-     scale_grad_by_freq: bool,
- ):
-     """Return a lookup function with a custom VJP implementing embedding semantics."""
-
-     @jax.custom_vjp
-     def _lookup(weight: jax.Array, indices: jax.Array) -> jax.Array:
-         indices = jnp.asarray(indices)
-         return weight[indices]
-
-     def _lookup_fwd(weight: jax.Array, indices: jax.Array):
-         indices = jnp.asarray(indices)
-         return weight[indices], (indices, weight.shape)
-
-     def _lookup_bwd(residual, grad_output: jax.Array):
-         indices, weight_shape = residual
-         grad_output = jnp.asarray(grad_output)
-         flat_idx = jnp.ravel(indices)
-         if flat_idx.size == 0:
-             return jnp.zeros(weight_shape, dtype=grad_output.dtype), None
-
-         flat_idx = jnp.asarray(flat_idx, dtype=jnp.int32)
-         grad_flat = grad_output.reshape((flat_idx.shape[0],) + weight_shape[1:])
-
-         if scale_grad_by_freq:
-             counts = jnp.bincount(flat_idx, length=weight_shape[0])
-             counts = counts.astype(grad_flat.dtype)
-             counts = jnp.where(counts == 0, 1.0, counts)
-             scale = counts[flat_idx]
-             grad_flat = grad_flat / scale.reshape((flat_idx.shape[0],) + (1,) * (grad_flat.ndim - 1))
-
-         if padding_idx is not None:
-             pad_value = jnp.asarray(padding_idx, dtype=flat_idx.dtype)
-             mask = flat_idx != pad_value
-             broadcast_shape = (flat_idx.shape[0],) + (1,) * (grad_flat.ndim - 1)
-             grad_flat = grad_flat * mask.reshape(broadcast_shape).astype(grad_flat.dtype)
-
-         grad_weight = jnp.zeros(weight_shape, dtype=grad_output.dtype)
-         grad_weight = grad_weight.at[flat_idx].add(grad_flat)
-         return grad_weight, None
-
-     _lookup.defvjp(_lookup_fwd, _lookup_bwd)
-     return _lookup
-
-
- def _contains_tracer(tree) -> bool:
-     """Return True if the pytree contains any JAX tracer values."""
-     return any(isinstance(leaf, jax_core.Tracer) for leaf in jtu.tree_leaves(tree))
-
-
-
- class Embedding(Module):
-     r"""
-     A simple lookup table that stores embeddings of a fixed size.
-
-     This module is commonly used to store word embeddings and retrieve them using indices.
-     The input to the module is a list of indices, and the output is the corresponding
-     word embeddings.
-
-     Parameters
-     ----------
-     num_embeddings : int
-         Size of embedding dictionary. Must be non-negative.
-     embedding_size : Size
-         Size of each embedding vector. Can be an int or a sequence of ints, and must
-         contain only non-negative values.
-     embedding_init : Callable or ArrayLike, optional
-         The initializer for the embedding lookup table, of shape
-         ``(num_embeddings, embedding_size)``. Default is ``LecunUniform()``.
-     padding_idx : int, optional
-         If specified, the entries at ``padding_idx`` do not contribute to the gradient;
-         therefore, the embedding vector at ``padding_idx`` is not updated during training,
-         i.e., it remains as a fixed "pad". For a newly constructed Embedding, the embedding
-         vector at ``padding_idx`` will default to all zeros. Default is ``None``.
-     max_norm : float, optional
-         If given, each embedding vector with norm larger than ``max_norm`` is renormalized
-         to have norm ``max_norm``. Default is ``None``.
-     norm_type : float, optional
-         The p of the p-norm to compute for the ``max_norm`` option. Default is ``2.0``.
-     scale_grad_by_freq : bool, optional
-         If given, this scales gradients by the inverse frequency of the words in
-         the mini-batch. Default is ``False``.
-     name : str, optional
-         The name of the module.
-     param_type : type, optional
-         The parameter state type to use. Default is ``ParamState``.
-
-     Attributes
-     ----------
-     num_embeddings : int
-         Size of the embedding dictionary.
-     embedding_size : tuple[int, ...]
-         Size of each embedding vector.
-     out_size : tuple[int, ...]
-         Output size, same as ``embedding_size``.
-     weight : ParamState
-         The learnable weights of the module of shape
-         ``(num_embeddings, *embedding_size)``.
-     padding_idx : int or None
-         Index of the padding token.
-     max_norm : float or None
-         Maximum norm for embedding vectors.
-     norm_type : float
-         Type of p-norm to compute for max_norm.
-     scale_grad_by_freq : bool
-         Whether to scale gradients by frequency.
-     freeze : bool
-         Whether the embedding weights are frozen.
-
-     Examples
-     --------
-     Create an embedding layer with 10 words and 3-dimensional embeddings:
-
-     .. code-block:: python
-
-         >>> import brainstate as bst
-         >>> embedding = bst.nn.Embedding(num_embeddings=10, embedding_size=3)
-         >>> embedding.weight.value.shape
-         (10, 3)
-
-     Retrieve embeddings for specific indices:
-
-     .. code-block:: python
-
-         >>> import jax.numpy as jnp
-         >>> indices = jnp.array([1, 3, 5])
-         >>> output = embedding(indices)
-         >>> output.shape
-         (3, 3)
-
-     Use with a batch of sequences:
-
-     .. code-block:: python
-
-         >>> # Batch of 2 sequences, each with 4 tokens
-         >>> batch_indices = jnp.array([[1, 2, 3, 4],
-         ...                            [5, 6, 7, 8]])
-         >>> output = embedding(batch_indices)
-         >>> output.shape
-         (2, 4, 3)
-
-     Use padding_idx to keep padding embeddings fixed:
-
-     .. code-block:: python
-
-         >>> embedding = bst.nn.Embedding(num_embeddings=10, embedding_size=3, padding_idx=0)
-         >>> # The embedding at index 0 will remain zeros and not be updated during training
-         >>> indices = jnp.array([0, 2, 0, 5])
-         >>> output = embedding(indices)
-         >>> output[0] # Will be zeros
-         Array([0., 0., 0.], dtype=float32)
-
-     Use max_norm to constrain embedding norms:
-
-     .. code-block:: python
-
-         >>> embedding = bst.nn.Embedding(num_embeddings=10, embedding_size=3, max_norm=1.0)
-         >>> # All embeddings accessed in a forward pass are renormalized to have norm <= 1.0
-
-     Load pretrained embeddings:
-
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> import jax.numpy as jnp
-         >>> pretrained = jnp.array([[1.0, 2.0, 3.0],
-         ...                         [4.0, 5.0, 6.0],
-         ...                         [7.0, 8.0, 9.0]])
-         >>> embedding = bst.nn.Embedding.from_pretrained(pretrained, param_type=brainstate.FakeState)
-         >>> embedding.weight.value.shape
-         (3, 3)
-     """
-
-     def __init__(
-         self,
-         num_embeddings: int,
-         embedding_size: Size,
-         embedding_init: Union[Callable, ArrayLike] = init.LecunUniform(),
-         padding_idx: Optional[int] = None,
-         max_norm: Optional[float] = None,
-         norm_type: float = 2.0,
-         scale_grad_by_freq: bool = False,
-         freeze: bool = False,
-         name: Optional[str] = None,
-         param_type: type = ParamState,
-     ):
-         super().__init__(name=name)
-
-         self.num_embeddings = int(num_embeddings)
-         if self.num_embeddings < 0:
-             raise ValueError('num_embeddings must not be negative.')
-
-         embedding_size_tuple = _normalize_embedding_size(embedding_size)
-         self.embedding_size = embedding_size_tuple
-         self.out_size = embedding_size_tuple
-
-         if padding_idx is not None:
-             padding_idx = int(padding_idx)
-             if padding_idx < 0 or padding_idx >= self.num_embeddings:
-                 raise ValueError(f'padding_idx must be within [0, {self.num_embeddings}).')
-         self.padding_idx = padding_idx
-
-         if max_norm is not None and max_norm <= 0:
-             raise ValueError('max_norm must be positive when provided.')
-         self.max_norm = max_norm
-         self.norm_type = norm_type
-         self.scale_grad_by_freq = bool(scale_grad_by_freq)
-         self.freeze = bool(freeze)
-
-         weight_shape = (self.num_embeddings, *self.out_size)
-         weight = init.param(embedding_init, weight_shape)
-
-         if self.padding_idx is not None:
-             weight = weight.at[self.padding_idx].set(0)
-
-         self.weight = param_type(weight)
-         self._lookup = _embedding_lookup_fn(self.padding_idx, self.scale_grad_by_freq)
-
-     @classmethod
-     def from_pretrained(
-         cls,
-         embeddings: ArrayLike,
-         padding_idx: Optional[int] = None,
-         max_norm: Optional[float] = None,
-         norm_type: float = 2.0,
-         scale_grad_by_freq: bool = False,
-         freeze: bool = True,
-         name: Optional[str] = None,
-         param_type: type = ParamState,
-     ):
-         r"""
-         Create an Embedding instance from given 2-dimensional array.
-
-         Parameters
-         ----------
-         embeddings : ArrayLike
-             Array containing weights for the Embedding. First dimension is passed to
-             Embedding as ``num_embeddings``, remaining dimensions as ``embedding_size``.
-         padding_idx : int, optional
-             If specified, the entries at ``padding_idx`` do not contribute to the gradient.
-             Default is ``None``.
-         max_norm : float, optional
-             See module initialization documentation. Default is ``None``.
-         norm_type : float, optional
-             See module initialization documentation. Default is ``2.0``.
-         scale_grad_by_freq : bool, optional
-             See module initialization documentation. Default is ``False``.
-         freeze : bool, optional
-             If ``True``, embeddings are frozen (no gradients). Default is ``True``.
-         name : str, optional
-             The name of the module.
-
-         Returns
-         -------
-         Embedding
-             An Embedding module with pretrained weights.
-
-         Examples
-         --------
-         Load pretrained word embeddings:
-
-         .. code-block:: python
-
-             >>> import jax.numpy as jnp
-             >>> import brainstate as bst
-             >>> pretrained = jnp.array([[1.0, 2.0, 3.0],
-             ...                         [4.0, 5.0, 6.0],
-             ...                         [7.0, 8.0, 9.0]])
-             >>> embedding = bst.nn.Embedding.from_pretrained(pretrained)
-             >>> embedding.weight.value.shape
-             (3, 3)
-             >>> indices = jnp.array([1])
-             >>> embedding(indices)
-             Array([[4., 5., 6.]], dtype=float32)
-         """
-         embeddings = jnp.asarray(embeddings)
-         if embeddings.ndim < 2:
-             raise ValueError('embeddings must be at least 2-dimensional')
-
-         num_embeddings = embeddings.shape[0]
-         embedding_size = embeddings.shape[1:]
-
-         instance = cls(
-             num_embeddings=num_embeddings,
-             embedding_size=embedding_size,
-             embedding_init=embeddings,
-             padding_idx=padding_idx,
-             max_norm=max_norm,
-             norm_type=norm_type,
-             scale_grad_by_freq=scale_grad_by_freq,
-             freeze=freeze,
-             name=name,
-             param_type=param_type,
-         )
-
-         instance.weight = param_type(jnp.asarray(embeddings))
-         return instance
-
-     def update(self, indices: ArrayLike):
-         """
-         Retrieve embeddings for the given indices.
-
-         Parameters
-         ----------
-         indices : ArrayLike
-             Indices to retrieve embeddings for. Can be any shape.
-
-         Returns
-         -------
-         ArrayLike
-             Embeddings corresponding to the indices, with shape
-             ``(*indices.shape, *embedding_size)``.
-         """
-         indices = jnp.asarray(indices)
-         if not jnp.issubdtype(indices.dtype, jnp.integer):
-             raise TypeError('Embedding indices must be integers.')
-
-         weight_value = self.weight.value
-         effective_weight = weight_value
-
-         if self.max_norm is not None:
-             renormed_weight = self._apply_max_norm(weight_value, indices)
-             effective_weight = weight_value + jax.lax.stop_gradient(renormed_weight - weight_value)
-             if not _contains_tracer(renormed_weight):
-                 self.weight.value = renormed_weight
-
-         if self.freeze:
-             effective_weight = jax.lax.stop_gradient(effective_weight)
-
-         embeddings = self._lookup(effective_weight, indices)
-         return embeddings
-
-     def _apply_max_norm(self, weight: jax.Array, indices: jax.Array) -> jax.Array:
-         """Apply max_norm constraint to the embedding weights for the given indices."""
-         flat_idx = jnp.ravel(indices)
-         if flat_idx.size == 0:
-             return weight
-
-         flat_idx = jnp.asarray(flat_idx, dtype=jnp.int32)
-         if self.padding_idx is not None:
-             pad_value = jnp.asarray(self.padding_idx, dtype=flat_idx.dtype)
-             flat_idx = flat_idx[flat_idx != pad_value]
-
-         if flat_idx.size == 0:
-             return weight
-
-         rows = weight[flat_idx]
-         rows_flat = rows.reshape((rows.shape[0], -1))
-         row_dtype = rows_flat.dtype
-
-         norms = jnp.linalg.norm(rows_flat, ord=self.norm_type, axis=1, keepdims=True)
-         max_norm = jnp.asarray(self.max_norm, dtype=row_dtype)
-         eps = jnp.asarray(1e-8, dtype=row_dtype)
-         one = jnp.asarray(1.0, dtype=row_dtype)
-         scale = jnp.minimum(one, max_norm / (norms + eps))
-         rows_scaled = (rows_flat * scale).reshape(rows.shape)
-
-         return weight.at[flat_idx].set(rows_scaled)
-
-
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from functools import lru_cache
+ from typing import Callable, Optional, Tuple, Union
+
+ import numpy as np
+ import jax
+ import jax.numpy as jnp
+ import jax.tree_util as jtu
+ from jax import core as jax_core
+
+ from brainstate._state import ParamState
+ from brainstate.typing import ArrayLike, Size
+ from . import init as init
+ from ._module import Module
+
+ __all__ = [
+     'Embedding',
+ ]
+
+
+ def _normalize_embedding_size(size: Size) -> Tuple[int, ...]:
+     """Convert ``Size`` specifications to a validated tuple of integers."""
+     size_array = np.asarray(size)
+     if size_array.size == 0:
+         raise ValueError('embedding_size must contain at least one dimension.')
+     flat = size_array.reshape(-1)
+     normalized = tuple(int(dim) for dim in flat)
+     if any(dim < 0 for dim in normalized):
+         raise ValueError('embedding_size must not contain negative values.')
+     return normalized
+
+
+ @lru_cache(maxsize=None)
+ def _embedding_lookup_fn(
+     padding_idx: Optional[int],
+     scale_grad_by_freq: bool,
+ ):
+     """Return a lookup function with a custom VJP implementing embedding semantics."""
+
+     @jax.custom_vjp
+     def _lookup(weight: jax.Array, indices: jax.Array) -> jax.Array:
+         indices = jnp.asarray(indices)
+         return weight[indices]
+
+     def _lookup_fwd(weight: jax.Array, indices: jax.Array):
+         indices = jnp.asarray(indices)
+         return weight[indices], (indices, weight.shape)
+
+     def _lookup_bwd(residual, grad_output: jax.Array):
+         indices, weight_shape = residual
+         grad_output = jnp.asarray(grad_output)
+         flat_idx = jnp.ravel(indices)
+         if flat_idx.size == 0:
+             return jnp.zeros(weight_shape, dtype=grad_output.dtype), None
+
+         flat_idx = jnp.asarray(flat_idx, dtype=jnp.int32)
+         grad_flat = grad_output.reshape((flat_idx.shape[0],) + weight_shape[1:])
+
+         if scale_grad_by_freq:
+             counts = jnp.bincount(flat_idx, length=weight_shape[0])
+             counts = counts.astype(grad_flat.dtype)
+             counts = jnp.where(counts == 0, 1.0, counts)
+             scale = counts[flat_idx]
+             grad_flat = grad_flat / scale.reshape((flat_idx.shape[0],) + (1,) * (grad_flat.ndim - 1))
+
+         if padding_idx is not None:
+             pad_value = jnp.asarray(padding_idx, dtype=flat_idx.dtype)
+             mask = flat_idx != pad_value
+             broadcast_shape = (flat_idx.shape[0],) + (1,) * (grad_flat.ndim - 1)
+             grad_flat = grad_flat * mask.reshape(broadcast_shape).astype(grad_flat.dtype)
+
+         grad_weight = jnp.zeros(weight_shape, dtype=grad_output.dtype)
+         grad_weight = grad_weight.at[flat_idx].add(grad_flat)
+         return grad_weight, None
+
+     _lookup.defvjp(_lookup_fwd, _lookup_bwd)
+     return _lookup
+
+
+ def _contains_tracer(tree) -> bool:
+     """Return True if the pytree contains any JAX tracer values."""
+     return any(isinstance(leaf, jax_core.Tracer) for leaf in jtu.tree_leaves(tree))
+
+
+
+ class Embedding(Module):
+     r"""
+     A simple lookup table that stores embeddings of a fixed size.
+
+     This module is commonly used to store word embeddings and retrieve them using indices.
+     The input to the module is a list of indices, and the output is the corresponding
+     word embeddings.
+
+     Parameters
+     ----------
+     num_embeddings : int
+         Size of embedding dictionary. Must be non-negative.
+     embedding_size : Size
+         Size of each embedding vector. Can be an int or a sequence of ints, and must
+         contain only non-negative values.
+     embedding_init : Callable or ArrayLike, optional
+         The initializer for the embedding lookup table, of shape
+         ``(num_embeddings, embedding_size)``. Default is ``LecunUniform()``.
+     padding_idx : int, optional
+         If specified, the entries at ``padding_idx`` do not contribute to the gradient;
+         therefore, the embedding vector at ``padding_idx`` is not updated during training,
+         i.e., it remains as a fixed "pad". For a newly constructed Embedding, the embedding
+         vector at ``padding_idx`` will default to all zeros. Default is ``None``.
+     max_norm : float, optional
+         If given, each embedding vector with norm larger than ``max_norm`` is renormalized
+         to have norm ``max_norm``. Default is ``None``.
+     norm_type : float, optional
+         The p of the p-norm to compute for the ``max_norm`` option. Default is ``2.0``.
+     scale_grad_by_freq : bool, optional
+         If given, this scales gradients by the inverse frequency of the words in
+         the mini-batch. Default is ``False``.
+     name : str, optional
+         The name of the module.
+     param_type : type, optional
+         The parameter state type to use. Default is ``ParamState``.
+
+     Attributes
+     ----------
+     num_embeddings : int
+         Size of the embedding dictionary.
+     embedding_size : tuple[int, ...]
+         Size of each embedding vector.
+     out_size : tuple[int, ...]
+         Output size, same as ``embedding_size``.
+     weight : ParamState
+         The learnable weights of the module of shape
+         ``(num_embeddings, *embedding_size)``.
+     padding_idx : int or None
+         Index of the padding token.
+     max_norm : float or None
+         Maximum norm for embedding vectors.
+     norm_type : float
+         Type of p-norm to compute for max_norm.
+     scale_grad_by_freq : bool
+         Whether to scale gradients by frequency.
+     freeze : bool
+         Whether the embedding weights are frozen.
+
+     Examples
+     --------
+     Create an embedding layer with 10 words and 3-dimensional embeddings:
+
+     .. code-block:: python
+
+         >>> import brainstate as brainstate
+         >>> embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3)
+         >>> embedding.weight.value.shape
+         (10, 3)
+
+     Retrieve embeddings for specific indices:
+
+     .. code-block:: python
+
+         >>> import jax.numpy as jnp
+         >>> indices = jnp.array([1, 3, 5])
+         >>> output = embedding(indices)
+         >>> output.shape
+         (3, 3)
+
+     Use with a batch of sequences:
+
+     .. code-block:: python
+
+         >>> # Batch of 2 sequences, each with 4 tokens
+         >>> batch_indices = jnp.array([[1, 2, 3, 4],
+         ...                            [5, 6, 7, 8]])
+         >>> output = embedding(batch_indices)
+         >>> output.shape
+         (2, 4, 3)
+
+     Use padding_idx to keep padding embeddings fixed:
+
+     .. code-block:: python
+
+         >>> embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3, padding_idx=0)
+         >>> # The embedding at index 0 will remain zeros and not be updated during training
+         >>> indices = jnp.array([0, 2, 0, 5])
+         >>> output = embedding(indices)
+         >>> output[0] # Will be zeros
+         Array([0., 0., 0.], dtype=float32)
+
+     Use max_norm to constrain embedding norms:
+
+     .. code-block:: python
+
+         >>> embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3, max_norm=1.0)
+         >>> # All embeddings accessed in a forward pass are renormalized to have norm <= 1.0
+
+     Load pretrained embeddings:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> import jax.numpy as jnp
+         >>> pretrained = jnp.array([[1.0, 2.0, 3.0],
+         ...                         [4.0, 5.0, 6.0],
+         ...                         [7.0, 8.0, 9.0]])
+         >>> embedding = brainstate.nn.Embedding.from_pretrained(pretrained, param_type=brainstate.FakeState)
+         >>> embedding.weight.value.shape
+         (3, 3)
+     """
+
+     def __init__(
+         self,
+         num_embeddings: int,
+         embedding_size: Size,
+         embedding_init: Union[Callable, ArrayLike] = init.LecunUniform(),
+         padding_idx: Optional[int] = None,
+         max_norm: Optional[float] = None,
+         norm_type: float = 2.0,
+         scale_grad_by_freq: bool = False,
+         freeze: bool = False,
+         name: Optional[str] = None,
+         param_type: type = ParamState,
+     ):
+         super().__init__(name=name)
+
+         self.num_embeddings = int(num_embeddings)
+         if self.num_embeddings < 0:
+             raise ValueError('num_embeddings must not be negative.')
+
+         embedding_size_tuple = _normalize_embedding_size(embedding_size)
+         self.embedding_size = embedding_size_tuple
+         self.out_size = embedding_size_tuple
+
+         if padding_idx is not None:
+             padding_idx = int(padding_idx)
+             if padding_idx < 0 or padding_idx >= self.num_embeddings:
+                 raise ValueError(f'padding_idx must be within [0, {self.num_embeddings}).')
+         self.padding_idx = padding_idx
+
+         if max_norm is not None and max_norm <= 0:
+             raise ValueError('max_norm must be positive when provided.')
+         self.max_norm = max_norm
+         self.norm_type = norm_type
+         self.scale_grad_by_freq = bool(scale_grad_by_freq)
+         self.freeze = bool(freeze)
+
+         weight_shape = (self.num_embeddings, *self.out_size)
+         weight = init.param(embedding_init, weight_shape)
+
+         if self.padding_idx is not None:
+             weight = weight.at[self.padding_idx].set(0)
+
+         self.weight = param_type(weight)
+         self._lookup = _embedding_lookup_fn(self.padding_idx, self.scale_grad_by_freq)
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         embeddings: ArrayLike,
+         padding_idx: Optional[int] = None,
+         max_norm: Optional[float] = None,
+         norm_type: float = 2.0,
+         scale_grad_by_freq: bool = False,
+         freeze: bool = True,
+         name: Optional[str] = None,
+         param_type: type = ParamState,
+     ):
+         r"""
+         Create an Embedding instance from given 2-dimensional array.
+
+         Parameters
+         ----------
+         embeddings : ArrayLike
+             Array containing weights for the Embedding. First dimension is passed to
+             Embedding as ``num_embeddings``, remaining dimensions as ``embedding_size``.
+         padding_idx : int, optional
+             If specified, the entries at ``padding_idx`` do not contribute to the gradient.
+             Default is ``None``.
+         max_norm : float, optional
+             See module initialization documentation. Default is ``None``.
+         norm_type : float, optional
+             See module initialization documentation. Default is ``2.0``.
+         scale_grad_by_freq : bool, optional
+             See module initialization documentation. Default is ``False``.
+         freeze : bool, optional
+             If ``True``, embeddings are frozen (no gradients). Default is ``True``.
+         name : str, optional
+             The name of the module.
+
+         Returns
+         -------
+         Embedding
+             An Embedding module with pretrained weights.
+
+         Examples
+         --------
+         Load pretrained word embeddings:
+
+         .. code-block:: python
+
+             >>> import jax.numpy as jnp
+             >>> import brainstate as brainstate
+             >>> pretrained = jnp.array([[1.0, 2.0, 3.0],
+             ...                         [4.0, 5.0, 6.0],
+             ...                         [7.0, 8.0, 9.0]])
+             >>> embedding = brainstate.nn.Embedding.from_pretrained(pretrained)
+             >>> embedding.weight.value.shape
+             (3, 3)
+             >>> indices = jnp.array([1])
+             >>> embedding(indices)
+             Array([[4., 5., 6.]], dtype=float32)
+         """
+         embeddings = jnp.asarray(embeddings)
+         if embeddings.ndim < 2:
+             raise ValueError('embeddings must be at least 2-dimensional')
+
+         num_embeddings = embeddings.shape[0]
+         embedding_size = embeddings.shape[1:]
+
+         instance = cls(
+             num_embeddings=num_embeddings,
+             embedding_size=embedding_size,
+             embedding_init=embeddings,
+             padding_idx=padding_idx,
+             max_norm=max_norm,
+             norm_type=norm_type,
+             scale_grad_by_freq=scale_grad_by_freq,
+             freeze=freeze,
+             name=name,
+             param_type=param_type,
+         )
+
+         instance.weight = param_type(jnp.asarray(embeddings))
+         return instance
+
+     def update(self, indices: ArrayLike):
+         """
+         Retrieve embeddings for the given indices.
+
+         Parameters
+         ----------
+         indices : ArrayLike
+             Indices to retrieve embeddings for. Can be any shape.
+
+         Returns
+         -------
+         ArrayLike
+             Embeddings corresponding to the indices, with shape
+             ``(*indices.shape, *embedding_size)``.
+         """
+         indices = jnp.asarray(indices)
+         if not jnp.issubdtype(indices.dtype, jnp.integer):
+             raise TypeError('Embedding indices must be integers.')
+
+         weight_value = self.weight.value
+         effective_weight = weight_value
+
+         if self.max_norm is not None:
+             renormed_weight = self._apply_max_norm(weight_value, indices)
+             effective_weight = weight_value + jax.lax.stop_gradient(renormed_weight - weight_value)
+             if not _contains_tracer(renormed_weight):
+                 self.weight.value = renormed_weight
+
+         if self.freeze:
+             effective_weight = jax.lax.stop_gradient(effective_weight)
+
+         embeddings = self._lookup(effective_weight, indices)
+         return embeddings
+
+     def _apply_max_norm(self, weight: jax.Array, indices: jax.Array) -> jax.Array:
+         """Apply max_norm constraint to the embedding weights for the given indices."""
+         flat_idx = jnp.ravel(indices)
+         if flat_idx.size == 0:
+             return weight
+
+         flat_idx = jnp.asarray(flat_idx, dtype=jnp.int32)
+         if self.padding_idx is not None:
+             pad_value = jnp.asarray(self.padding_idx, dtype=flat_idx.dtype)
+             flat_idx = flat_idx[flat_idx != pad_value]
+
+         if flat_idx.size == 0:
+             return weight
+
+         rows = weight[flat_idx]
+         rows_flat = rows.reshape((rows.shape[0], -1))
+         row_dtype = rows_flat.dtype
+
+         norms = jnp.linalg.norm(rows_flat, ord=self.norm_type, axis=1, keepdims=True)
+         max_norm = jnp.asarray(self.max_norm, dtype=row_dtype)
+         eps = jnp.asarray(1e-8, dtype=row_dtype)
+         one = jnp.asarray(1.0, dtype=row_dtype)
+         scale = jnp.minimum(one, max_norm / (norms + eps))
+         rows_scaled = (rows_flat * scale).reshape(rows.shape)
+
+         return weight.at[flat_idx].set(rows_scaled)
+
+
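
For orientation, the snippet below is a minimal usage sketch of the brainstate.nn.Embedding API touched by this diff. It is assembled from the docstring examples in the file above, assumes brainstate 0.2.2 is installed, and shows nothing beyond what the changed module documents.

    import jax.numpy as jnp
    import brainstate

    # A 10-entry table of 3-dimensional embeddings; index 0 is treated as padding,
    # so its row is initialized to zeros and excluded from gradients by the custom VJP.
    embedding = brainstate.nn.Embedding(num_embeddings=10, embedding_size=3, padding_idx=0)

    indices = jnp.array([[0, 2, 0, 5],
                         [1, 4, 7, 9]])
    output = embedding(indices)        # shape (2, 4, 3): (*indices.shape, *embedding_size)

    # Wrap an existing table; from_pretrained freezes the weights by default.
    pretrained = jnp.array([[1.0, 2.0, 3.0],
                            [4.0, 5.0, 6.0]])
    frozen = brainstate.nn.Embedding.from_pretrained(pretrained)
    print(frozen.weight.value.shape)   # (2, 3)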