brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -25,7 +25,80 @@ import pytest
|
|
25
25
|
import brainstate
|
26
26
|
|
27
27
|
|
28
|
+
class TestRandomExamples(unittest.TestCase):
|
29
|
+
"""Test cases that demonstrate usage examples from docstrings."""
|
30
|
+
|
31
|
+
def test_rand_examples(self):
|
32
|
+
"""Test examples from rand function docstring."""
|
33
|
+
# Generate random values in a 3x2 array
|
34
|
+
arr = brainstate.random.rand(3, 2)
|
35
|
+
self.assertEqual(arr.shape, (3, 2))
|
36
|
+
self.assertTrue((arr >= 0).all() and (arr < 1).all())
|
37
|
+
|
38
|
+
def test_randint_examples(self):
|
39
|
+
"""Test examples from randint function docstring."""
|
40
|
+
# Generate 10 random integers from 0 to 1 (exclusive)
|
41
|
+
arr = brainstate.random.randint(2, size=10)
|
42
|
+
self.assertEqual(arr.shape, (10,))
|
43
|
+
self.assertTrue((arr >= 0).all() and (arr < 2).all())
|
44
|
+
|
45
|
+
# Generate a 2x4 array of integers from 0 to 4 (exclusive)
|
46
|
+
arr = brainstate.random.randint(5, size=(2, 4))
|
47
|
+
self.assertEqual(arr.shape, (2, 4))
|
48
|
+
self.assertTrue((arr >= 0).all() and (arr < 5).all())
|
49
|
+
|
50
|
+
# Generate integers with different upper bounds using broadcasting
|
51
|
+
arr = brainstate.random.randint(1, [3, 5, 10])
|
52
|
+
self.assertEqual(arr.shape, (3,))
|
53
|
+
|
54
|
+
# Generate integers with different lower bounds
|
55
|
+
arr = brainstate.random.randint([1, 5, 7], 10)
|
56
|
+
self.assertEqual(arr.shape, (3,))
|
57
|
+
self.assertTrue((arr >= jnp.array([1, 5, 7])).all())
|
58
|
+
|
59
|
+
def test_randn_examples(self):
|
60
|
+
"""Test examples from randn function docstring."""
|
61
|
+
# Generate standard normal distributed values
|
62
|
+
arr = brainstate.random.randn(3, 2)
|
63
|
+
self.assertEqual(arr.shape, (3, 2))
|
64
|
+
|
65
|
+
def test_choice_examples(self):
|
66
|
+
"""Test examples from choice function docstring."""
|
67
|
+
# Choose from range
|
68
|
+
result = brainstate.random.choice(5)
|
69
|
+
self.assertTrue(0 <= result < 5)
|
70
|
+
|
71
|
+
# Choose multiple with probabilities
|
72
|
+
arr = brainstate.random.choice(5, 3, p=[0.1, 0.4, 0.2, 0.0, 0.3])
|
73
|
+
self.assertEqual(arr.shape, (3,))
|
74
|
+
self.assertTrue((arr >= 0).all() and (arr < 5).all())
|
75
|
+
|
76
|
+
def test_normal_examples(self):
|
77
|
+
"""Test examples from normal function docstring."""
|
78
|
+
# Standard normal
|
79
|
+
result = brainstate.random.normal()
|
80
|
+
self.assertEqual(result.shape, ())
|
81
|
+
|
82
|
+
# With different parameters
|
83
|
+
arr = brainstate.random.normal(loc=0.0, scale=1.0, size=(2, 3))
|
84
|
+
self.assertEqual(arr.shape, (2, 3))
|
85
|
+
|
86
|
+
def test_uniform_examples(self):
|
87
|
+
"""Test examples from uniform function docstring."""
|
88
|
+
# Standard uniform
|
89
|
+
result = brainstate.random.uniform()
|
90
|
+
self.assertEqual(result.shape, ())
|
91
|
+
self.assertTrue(0.0 <= result < 1.0)
|
92
|
+
|
93
|
+
# With custom range
|
94
|
+
arr = brainstate.random.uniform(low=2.0, high=5.0, size=(3, 2))
|
95
|
+
self.assertEqual(arr.shape, (3, 2))
|
96
|
+
self.assertTrue((arr >= 2.0).all() and (arr < 5.0).all())
|
97
|
+
|
98
|
+
|
28
99
|
class TestRandom(unittest.TestCase):
|
100
|
+
def setUp(self):
|
101
|
+
brainstate.environ.set(precision=32)
|
29
102
|
|
30
103
|
def test_rand(self):
|
31
104
|
brainstate.random.seed()
|