brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -19,9 +19,9 @@ import brainevent
|
|
19
19
|
import brainunit as u
|
20
20
|
import jax
|
21
21
|
|
22
|
-
from brainstate import init
|
23
22
|
from brainstate._state import ParamState
|
24
23
|
from brainstate.typing import Size, ArrayLike
|
24
|
+
from . import init as init
|
25
25
|
from ._module import Module
|
26
26
|
|
27
27
|
__all__ = [
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -18,6 +18,7 @@ import jax
|
|
18
18
|
import jax.numpy as jnp
|
19
19
|
import pytest
|
20
20
|
|
21
|
+
import braintools
|
21
22
|
import brainstate
|
22
23
|
|
23
24
|
|
@@ -30,7 +31,7 @@ class TestEventLinear:
|
|
30
31
|
x = jnp.asarray(x, dtype=float)
|
31
32
|
m = brainstate.nn.EventLinear(
|
32
33
|
20, 40,
|
33
|
-
1.5 if homo_w else
|
34
|
+
1.5 if homo_w else braintools.init.KaimingUniform(),
|
34
35
|
float_as_event=bool_x
|
35
36
|
)
|
36
37
|
y = m(x)
|
@@ -42,7 +43,7 @@ class TestEventLinear:
|
|
42
43
|
n_in = 20
|
43
44
|
n_out = 30
|
44
45
|
x = brainstate.random.rand(n_in) < 0.3
|
45
|
-
fn = brainstate.nn.EventLinear(n_in, n_out,
|
46
|
+
fn = brainstate.nn.EventLinear(n_in, n_out, braintools.init.KaimingUniform())
|
46
47
|
|
47
48
|
with pytest.raises(TypeError):
|
48
49
|
print(jax.grad(lambda x: fn(x).sum())(x))
|
@@ -60,7 +61,7 @@ class TestEventLinear:
|
|
60
61
|
fn = brainstate.nn.EventLinear(
|
61
62
|
n_in,
|
62
63
|
n_out,
|
63
|
-
1.5 if homo_w else
|
64
|
+
1.5 if homo_w else braintools.init.KaimingUniform(),
|
64
65
|
float_as_event=bool_x
|
65
66
|
)
|
66
67
|
w = fn.weight.value
|
@@ -97,7 +98,7 @@ class TestEventLinear:
|
|
97
98
|
x = brainstate.random.rand(n_in)
|
98
99
|
|
99
100
|
fn = brainstate.nn.EventLinear(
|
100
|
-
n_in, n_out, 1.5 if homo_w else
|
101
|
+
n_in, n_out, 1.5 if homo_w else braintools.init.KaimingUniform(),
|
101
102
|
float_as_event=bool_x
|
102
103
|
)
|
103
104
|
w = fn.weight.value
|
brainstate/nn/_exp_euler.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -20,7 +20,7 @@ import brainunit as u
|
|
20
20
|
import jax.numpy as jnp
|
21
21
|
|
22
22
|
from brainstate import environ, random
|
23
|
-
from brainstate.
|
23
|
+
from brainstate.transform import vector_grad
|
24
24
|
|
25
25
|
__all__ = [
|
26
26
|
'exp_euler_step',
|
@@ -31,62 +31,224 @@ def exp_euler_step(
|
|
31
31
|
fn: Callable, *args, **kwargs
|
32
32
|
):
|
33
33
|
r"""
|
34
|
-
One-step Exponential Euler method for solving ODEs.
|
34
|
+
One-step Exponential Euler method for solving ODEs and SDEs.
|
35
|
+
|
36
|
+
The Exponential Euler method is a numerical integration scheme that provides improved
|
37
|
+
stability for stiff differential equations by exactly integrating the linear part of
|
38
|
+
the equation. For ODEs, it solves equations of the form:
|
39
|
+
|
40
|
+
.. math::
|
41
|
+
\frac{dx}{dt} = f(x, t)
|
42
|
+
|
43
|
+
For SDEs, it handles equations of the form:
|
44
|
+
|
45
|
+
.. math::
|
46
|
+
dx = f(x, t)dt + g(x, t)dW
|
47
|
+
|
48
|
+
where :math:`f(x, t)` is the drift term and :math:`g(x, t)` is the diffusion term.
|
49
|
+
|
50
|
+
The method linearizes the drift function around the current state and uses the
|
51
|
+
matrix exponential to integrate the linear part exactly, while treating the
|
52
|
+
remainder with standard Euler stepping.
|
53
|
+
|
54
|
+
Parameters
|
55
|
+
----------
|
56
|
+
fn : Callable
|
57
|
+
The drift function :math:`f(x, t)` to be integrated. This function should
|
58
|
+
take the state variable as the first argument, followed by optional time
|
59
|
+
and other arguments. It should return the derivative :math:`dx/dt`.
|
60
|
+
*args
|
61
|
+
Variable arguments. If the first argument is callable, it is treated as
|
62
|
+
the diffusion function for SDE integration. Otherwise, arguments are
|
63
|
+
passed to the drift function. The first non-callable argument should be
|
64
|
+
the state variable :math:`x`.
|
65
|
+
**kwargs
|
66
|
+
Additional keyword arguments passed to the drift and diffusion functions.
|
67
|
+
|
68
|
+
Returns
|
69
|
+
-------
|
70
|
+
x_next : ArrayLike
|
71
|
+
The state variable after one integration step of size ``dt``, where ``dt``
|
72
|
+
is obtained from the environment via ``environ.get('dt')``.
|
73
|
+
|
74
|
+
Raises
|
75
|
+
------
|
76
|
+
ValueError
|
77
|
+
If the input state variable dtype is not float16, bfloat16, float32, or float64.
|
78
|
+
ValueError
|
79
|
+
If drift and diffusion terms have incompatible units.
|
80
|
+
AssertionError
|
81
|
+
If ``fn`` is not callable or if no state variable is provided in ``*args``.
|
82
|
+
|
83
|
+
Notes
|
84
|
+
-----
|
85
|
+
**Unit Compatibility:**
|
86
|
+
|
87
|
+
- If the state variable :math:`x` has units :math:`[X]`, the drift function
|
88
|
+
:math:`f(x, t)` should return values with units :math:`[X]/[T]`, where
|
89
|
+
:math:`[T]` is the unit of time.
|
90
|
+
|
91
|
+
- If the state variable :math:`x` has units :math:`[X]`, the diffusion function
|
92
|
+
:math:`g(x, t)` should return values with units :math:`[X]/\sqrt{[T]}`.
|
93
|
+
|
94
|
+
**Algorithm:**
|
95
|
+
|
96
|
+
The method computes the Jacobian :math:`J = \frac{\partial f}{\partial x}` and
|
97
|
+
uses the exponential-related function :math:`\varphi(z) = (e^z - 1)/z` to update:
|
98
|
+
|
99
|
+
.. math::
|
100
|
+
x_{n+1} = x_n + dt \cdot \varphi(dt \cdot J) \cdot f(x_n, t_n)
|
101
|
+
|
102
|
+
For SDEs, a stochastic term is added:
|
103
|
+
|
104
|
+
.. math::
|
105
|
+
x_{n+1} = x_{n+1} + g(x_n, t_n) \sqrt{dt} \cdot \mathcal{N}(0, I)
|
35
106
|
|
36
107
|
Examples
|
37
108
|
--------
|
109
|
+
**ODE Integration:**
|
110
|
+
|
111
|
+
Simple exponential decay equation :math:`\frac{dx}{dt} = -x`:
|
38
112
|
|
39
|
-
|
40
|
-
... return -x
|
41
|
-
>>> x = 1.0
|
42
|
-
>>> exp_euler_step(fun, x, None)
|
113
|
+
.. code-block:: python
|
43
114
|
|
44
|
-
|
45
|
-
|
115
|
+
>>> import brainstate as bst
|
116
|
+
>>> import jax.numpy as jnp
|
117
|
+
>>>
|
118
|
+
>>> # Set time step in environment
|
119
|
+
>>> bst.environ.set(dt=0.01)
|
120
|
+
>>>
|
121
|
+
>>> # Define drift function
|
122
|
+
>>> def drift(x, t):
|
123
|
+
... return -x
|
124
|
+
>>>
|
125
|
+
>>> # Initial condition
|
126
|
+
>>> x0 = jnp.array(1.0)
|
127
|
+
>>>
|
128
|
+
>>> # Single integration step
|
129
|
+
>>> x1 = bst.nn.exp_euler_step(drift, x0, None)
|
130
|
+
>>> print(x1) # Should be close to exp(-0.01) ≈ 0.99
|
46
131
|
|
47
|
-
|
48
|
-
should have units of ( [X]/\sqrt{[T]} ).
|
132
|
+
**SDE Integration:**
|
49
133
|
|
50
|
-
|
51
|
-
fun: Callable. The function to be solved.
|
52
|
-
diffusion: Callable. The diffusion function.
|
53
|
-
*args: The input arguments.
|
54
|
-
drift: Callable. The drift function.
|
134
|
+
Ornstein-Uhlenbeck process :math:`dx = -\theta x dt + \sigma dW`:
|
55
135
|
|
56
|
-
|
57
|
-
|
136
|
+
.. code-block:: python
|
137
|
+
|
138
|
+
>>> import brainstate as bst
|
139
|
+
>>> import jax.numpy as jnp
|
140
|
+
>>>
|
141
|
+
>>> # Set time step
|
142
|
+
>>> bst.environ.set(dt=0.01)
|
143
|
+
>>>
|
144
|
+
>>> # Define drift and diffusion
|
145
|
+
>>> theta = 0.5
|
146
|
+
>>> sigma = 0.3
|
147
|
+
>>>
|
148
|
+
>>> def drift(x, t):
|
149
|
+
... return -theta * x
|
150
|
+
>>>
|
151
|
+
>>> def diffusion(x, t):
|
152
|
+
... return jnp.full_like(x, sigma)
|
153
|
+
>>>
|
154
|
+
>>> # Initial condition
|
155
|
+
>>> x0 = jnp.array(1.0)
|
156
|
+
>>>
|
157
|
+
>>> # Single SDE integration step
|
158
|
+
>>> x1 = bst.nn.exp_euler_step(drift, diffusion, x0, None)
|
159
|
+
|
160
|
+
**Multi-dimensional system:**
|
161
|
+
|
162
|
+
.. code-block:: python
|
163
|
+
|
164
|
+
>>> import brainstate as bst
|
165
|
+
>>> import jax.numpy as jnp
|
166
|
+
>>>
|
167
|
+
>>> bst.environ.set(dt=0.01)
|
168
|
+
>>>
|
169
|
+
>>> # Coupled oscillator system
|
170
|
+
>>> def drift(x, t):
|
171
|
+
... x1, x2 = x[0], x[1]
|
172
|
+
... return jnp.array([-x1 + x2, -x2 - x1])
|
173
|
+
>>>
|
174
|
+
>>> x0 = jnp.array([1.0, 0.0])
|
175
|
+
>>> x1 = bst.nn.exp_euler_step(drift, x0, None)
|
176
|
+
|
177
|
+
See Also
|
178
|
+
--------
|
179
|
+
brainstate.transform.vector_grad : Compute vector-Jacobian product used internally.
|
180
|
+
brainstate.environ.get : Retrieve environment variables like ``dt``.
|
181
|
+
|
182
|
+
References
|
183
|
+
----------
|
184
|
+
.. [1] Hochbruck, M., & Ostermann, A. (2010). Exponential integrators.
|
185
|
+
Acta Numerica, 19, 209-286.
|
186
|
+
.. [2] Cox, S. M., & Matthews, P. C. (2002). Exponential time differencing
|
187
|
+
for stiff systems. Journal of Computational Physics, 176(2), 430-455.
|
58
188
|
"""
|
59
|
-
|
189
|
+
# Validate inputs
|
190
|
+
assert callable(fn), 'The drift function should be callable.'
|
60
191
|
assert len(args) > 0, 'The input arguments should not be empty.'
|
192
|
+
|
193
|
+
# Parse arguments: check if first arg is diffusion function
|
194
|
+
diffusion = None
|
61
195
|
if callable(args[0]):
|
62
196
|
diffusion = args[0]
|
63
197
|
args = args[1:]
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
198
|
+
assert len(args) > 0, 'State variable is required after diffusion function.'
|
199
|
+
|
200
|
+
# Validate state variable dtype
|
201
|
+
state = u.math.asarray(args[0])
|
202
|
+
dtype = u.math.get_dtype(state)
|
203
|
+
if dtype not in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]:
|
68
204
|
raise ValueError(
|
69
|
-
f'
|
70
|
-
f'
|
205
|
+
f'State variable dtype must be float16, bfloat16, float32, or float64 '
|
206
|
+
f'for Exponential Euler method, but got {dtype}.'
|
71
207
|
)
|
72
208
|
|
73
|
-
#
|
74
|
-
dt = environ.
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
209
|
+
# Get time step from environment
|
210
|
+
dt = environ.get_dt()
|
211
|
+
|
212
|
+
# Compute drift term with Jacobian
|
213
|
+
# vector_grad returns (Jacobian, function_value)
|
214
|
+
jacobian, drift_value = vector_grad(fn, argnums=0, return_value=True)(*args, **kwargs)
|
215
|
+
|
216
|
+
# Convert Jacobian to proper units: [derivative_unit / state_unit] = [1/T]
|
217
|
+
jacobian_with_unit = u.Quantity(
|
218
|
+
u.get_mantissa(jacobian),
|
219
|
+
u.get_unit(drift_value) / u.get_unit(jacobian)
|
220
|
+
)
|
221
|
+
|
222
|
+
# Compute phi function: phi(z) = (exp(z) - 1) / z
|
223
|
+
# This is the exponential-related function for stability
|
224
|
+
phi = u.math.exprel(dt * jacobian_with_unit)
|
225
|
+
|
226
|
+
# Update state using exponential Euler scheme
|
227
|
+
x_next = state + dt * phi * drift_value
|
79
228
|
|
80
|
-
# diffusion
|
229
|
+
# Add diffusion term for SDE if provided
|
81
230
|
if diffusion is not None:
|
82
|
-
|
83
|
-
|
231
|
+
# Compute diffusion coefficient
|
232
|
+
diffusion_coef = diffusion(*args, **kwargs)
|
233
|
+
|
234
|
+
# Generate random noise and scale by sqrt(dt)
|
235
|
+
noise = random.randn_like(state)
|
236
|
+
diffusion_term = diffusion_coef * u.math.sqrt(dt) * noise
|
237
|
+
|
238
|
+
# Validate unit compatibility between drift and diffusion
|
239
|
+
if u.get_dim(x_next) != u.get_dim(diffusion_term):
|
84
240
|
drift_unit = u.get_unit(x_next)
|
85
241
|
time_unit = u.get_unit(dt)
|
242
|
+
expected_diffusion_unit = drift_unit / time_unit ** 0.5
|
243
|
+
actual_diffusion_unit = u.get_unit(diffusion_term)
|
86
244
|
raise ValueError(
|
87
|
-
f"
|
88
|
-
f"
|
89
|
-
f"
|
245
|
+
f"Unit mismatch between drift and diffusion terms. "
|
246
|
+
f"State has unit {u.get_unit(state)}, "
|
247
|
+
f"drift produces unit {drift_unit}, "
|
248
|
+
f"expected diffusion unit {expected_diffusion_unit}, "
|
249
|
+
f"but got {actual_diffusion_unit}."
|
90
250
|
)
|
91
|
-
|
251
|
+
|
252
|
+
x_next = x_next + diffusion_term
|
253
|
+
|
92
254
|
return x_next
|