eventax 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eventax-0.1.0/PKG-INFO +39 -0
- eventax-0.1.0/README.md +16 -0
- eventax-0.1.0/pyproject.toml +4 -0
- eventax-0.1.0/setup.cfg +4 -0
- eventax-0.1.0/setup.py +27 -0
- eventax-0.1.0/src/eventax/__init__.py +0 -0
- eventax-0.1.0/src/eventax/adjoint.py +535 -0
- eventax-0.1.0/src/eventax/buffer.py +528 -0
- eventax-0.1.0/src/eventax/evnn.py +927 -0
- eventax-0.1.0/src/eventax/io.py +199 -0
- eventax-0.1.0/src/eventax/neuron_models/__init__.py +29 -0
- eventax-0.1.0/src/eventax/neuron_models/alif.py +157 -0
- eventax-0.1.0/src/eventax/neuron_models/amos_wrapper.py +93 -0
- eventax-0.1.0/src/eventax/neuron_models/base_model.py +93 -0
- eventax-0.1.0/src/eventax/neuron_models/egru.py +142 -0
- eventax-0.1.0/src/eventax/neuron_models/eif.py +190 -0
- eventax-0.1.0/src/eventax/neuron_models/helpers.py +21 -0
- eventax-0.1.0/src/eventax/neuron_models/initializations.py +175 -0
- eventax-0.1.0/src/eventax/neuron_models/izhikevich.py +137 -0
- eventax-0.1.0/src/eventax/neuron_models/lif.py +149 -0
- eventax-0.1.0/src/eventax/neuron_models/multi_model.py +222 -0
- eventax-0.1.0/src/eventax/neuron_models/pizhikevich.py +195 -0
- eventax-0.1.0/src/eventax/neuron_models/plif.py +193 -0
- eventax-0.1.0/src/eventax/neuron_models/pqif.py +78 -0
- eventax-0.1.0/src/eventax/neuron_models/qif.py +83 -0
- eventax-0.1.0/src/eventax/neuron_models/refractory_wrapper.py +128 -0
- eventax-0.1.0/src/eventax.egg-info/PKG-INFO +39 -0
- eventax-0.1.0/src/eventax.egg-info/SOURCES.txt +34 -0
- eventax-0.1.0/src/eventax.egg-info/dependency_links.txt +1 -0
- eventax-0.1.0/src/eventax.egg-info/requires.txt +3 -0
- eventax-0.1.0/src/eventax.egg-info/top_level.txt +1 -0
- eventax-0.1.0/tests/test_adjoint.py +86 -0
- eventax-0.1.0/tests/test_buffer.py +328 -0
- eventax-0.1.0/tests/test_evnn.py +386 -0
- eventax-0.1.0/tests/test_io.py +225 -0
- eventax-0.1.0/tests/test_neuron_models.py +102 -0
eventax-0.1.0/PKG-INFO
ADDED
@@ -0,0 +1,39 @@
Metadata-Version: 2.4
Name: eventax
Version: 0.1.0
Summary: A Diffrax-based framework for continuous-time spiking neural networks
Home-page: https://github.com/Efficient-Scalable-Machine-Learning/eventax
Author: Lukas König
Author-email: lukmkoenig@gmail.com
License: MIT
Requires-Python: >=3.11
Description-Content-Type: text/markdown
Requires-Dist: jax>=0.9.1
Requires-Dist: diffrax>=0.7.2
Requires-Dist: equinox>=0.13.5
Dynamic: author
Dynamic: author-email
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: license
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

<p align="center">
  <img src="./docs-site/docs/img/logo5.svg" alt="Eventpropjax" width="70%">
</p>

Eventax provides a [JAX](https://github.com/google/jax) implementation of the [EventProp algorithm](https://arxiv.org/abs/2009.08378) using [Diffrax](https://github.com/patrick-kidger/diffrax) and [Equinox](https://github.com/patrick-kidger/equinox), offering full autograd support, easy extension with custom neuron dynamics, and built-in delay training.

## Features
- Fully differentiable implementation via JAX and Diffrax
- Easy extension with custom neuron model dynamics and learnable parameters
- Support for (trainable) synaptic delays
- GPU/TPU compatibility through JAX

## 📦 Installation
```bash
pip install eventax
```
eventax-0.1.0/README.md
ADDED
@@ -0,0 +1,16 @@
<p align="center">
  <img src="./docs-site/docs/img/logo5.svg" alt="Eventpropjax" width="70%">
</p>

Eventax provides a [JAX](https://github.com/google/jax) implementation of the [EventProp algorithm](https://arxiv.org/abs/2009.08378) using [Diffrax](https://github.com/patrick-kidger/diffrax) and [Equinox](https://github.com/patrick-kidger/equinox), offering full autograd support, easy extension with custom neuron dynamics, and built-in delay training.

## Features
- Fully differentiable implementation via JAX and Diffrax
- Easy extension with custom neuron model dynamics and learnable parameters
- Support for (trainable) synaptic delays
- GPU/TPU compatibility through JAX

## 📦 Installation
```bash
pip install eventax
```
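The class at the heart of this release is `EventPropAdjoint` in `src/eventax/adjoint.py` (shown in full further down). For orientation, here is a minimal sketch of how it might plug into a Diffrax solve, assuming only standard JAX/Diffrax APIs plus the `eventax.adjoint` import path visible in this diff; the dynamics, loss, and parameter values are illustrative, not part of the package:

```python
import diffrax
import jax
import jax.numpy as jnp

from eventax.adjoint import EventPropAdjoint


def vector_field(t, y, args):
    # Simple leaky dynamics dy/dt = -y / tau, with tau a learnable parameter.
    tau = args
    return -y / tau


def loss(tau):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        solver=diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=0.01,
        y0=jnp.array([1.0]),
        args=tau,
        saveat=diffrax.SaveAt(t1=True),
        # Gradients are obtained by solving the continuous adjoint ODE
        # backwards in time, rather than by differentiating solver steps.
        adjoint=EventPropAdjoint(),
    )
    return jnp.sum(sol.ys)


grad_tau = jax.grad(loss)(jnp.array(2.0))
```

Note that, per the guards in `EventPropAdjoint.loop` below, the forward solve must save plain states at `t0`/`t1`/explicit `ts` only: `SaveAt(steps=True)`, `SaveAt(dense=True)`, `SaveAt(subs=...)`, and custom `SaveAt(fn=...)` are all rejected.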
eventax-0.1.0/setup.cfg
ADDED
eventax-0.1.0/setup.py
ADDED
@@ -0,0 +1,27 @@
```python
import io
import os
from setuptools import setup, find_packages

here = os.path.abspath(os.path.dirname(__file__))
with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f:
    long_description = f.read()

setup(
    name="eventax",
    version="0.1.0",
    author="Lukas König",
    author_email="lukmkoenig@gmail.com",
    description=(
        "A Diffrax-based framework for continuous-time spiking neural networks"
    ),
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Efficient-Scalable-Machine-Learning/eventax",
    license="MIT",
    package_dir={"": "src"},
    packages=find_packages(where="src"),
    install_requires=[
        "jax>=0.9.1",
        "diffrax>=0.7.2",
        "equinox>=0.13.5",
    ],
    python_requires=">=3.11",
)
```
eventax-0.1.0/src/eventax/__init__.py
File without changes
eventax-0.1.0/src/eventax/adjoint.py
ADDED
@@ -0,0 +1,535 @@
```python
import functools as ft
import warnings
from collections.abc import Callable
from typing import Any, cast

import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω

from diffrax._heuristics import is_sde, is_unsafe_sde
from diffrax._saveat import save_y, SaveAt, SubSaveAt
from diffrax._solver import (
    AbstractItoSolver,
    AbstractStratonovichSolver,
)
from diffrax._term import AbstractTerm, AdjointTerm
from diffrax._adjoint import AbstractAdjoint

ω = cast(Callable, ω)


def _is_none(x):
    return x is None


def _is_subsaveat(x: Any) -> bool:
    return isinstance(x, SubSaveAt)


def _nondiff_solver_controller_state(
    adjoint, init_state, passed_solver_state, passed_controller_state
):
    if passed_solver_state:
        name = (
            f"When using `adjoint={adjoint.__class__.__name__}()`, then `solver_state`"
        )
        solver_fn = ft.partial(
            eqxi.nondifferentiable,
            name=name,
        )
    else:
        solver_fn = lax.stop_gradient
    if passed_controller_state:
        name = (
            f"When using `adjoint={adjoint.__class__.__name__}()`, then "
            "`controller_state`"
        )
        controller_fn = ft.partial(
            eqxi.nondifferentiable,
            name=name,
        )
    else:
        controller_fn = lax.stop_gradient
    init_state = eqx.tree_at(
        lambda s: s.solver_state,
        init_state,
        replace_fn=solver_fn,
        is_leaf=_is_none,
    )
    init_state = eqx.tree_at(
        lambda s: s.controller_state,
        init_state,
        replace_fn=controller_fn,
        is_leaf=_is_none,
    )
    return init_state


def _only_transpose_ys(final_state):
    from diffrax._integrate import SaveState

    def is_save_state(x):
        return isinstance(x, SaveState)

    def get_ys(_final_state):
        return [
            s.ys
            for s in jtu.tree_leaves(_final_state.save_state, is_leaf=is_save_state)
        ]

    def get_ts(_final_state):
        return [
            s.ts
            for s in jtu.tree_leaves(_final_state.save_state, is_leaf=is_save_state)
        ]

    ys = get_ys(final_state)
    ts = get_ts(final_state)

    named_nondiff_entries = (
        "y",
        "tprev",
        "tnext",
        "solver_state",
        "controller_state",
        "dense_ts",
        "dense_infos",
    )
    named_nondiff_values = tuple(
        eqxi.nondifferentiable_backward(getattr(final_state, k), name=k, symbolic=False)
        for k in named_nondiff_entries
    )

    final_state = eqxi.nondifferentiable_backward(final_state, symbolic=False)

    def get_named_nondiff_entries(s):
        return tuple(getattr(s, k) for k in named_nondiff_entries)

    final_state = eqx.tree_at(
        get_named_nondiff_entries, final_state, named_nondiff_values, is_leaf=_is_none
    )

    final_state = eqx.tree_at(get_ys, final_state, ys)
    final_state = eqx.tree_at(get_ts, final_state, ts)
    return final_state


_inner_loop = jax.named_call(eqxi.while_loop, name="inner-loop")
_outer_loop = jax.named_call(eqxi.while_loop, name="outer-loop")


@eqx.filter_custom_vjp
def _loop_backsolve(y__args__terms__t0__t1, *, self, throw, init_state, **kwargs):
    del throw
    y, args, terms, t0, t1 = y__args__terms__t0__t1
    init_state = eqx.tree_at(lambda s: s.y, init_state, y)
    del y
    return self._loop(
        args=args,
        terms=terms,
        init_state=init_state,
        inner_while_loop=ft.partial(_inner_loop, kind="lax"),
        outer_while_loop=ft.partial(_outer_loop, kind="lax"),
        t0=t0,
        t1=t1,
        **kwargs,
    )


@_loop_backsolve.def_fwd
def _loop_backsolve_fwd(perturbed, y__args__terms__t0__t1, **kwargs):
    del perturbed
    final_state, aux_stats = _loop_backsolve(y__args__terms__t0__t1, **kwargs)
    # Note that `final_state.save_state` has type `PyTree[SaveState]`; here we are
    # relying on the guard in `EventPropAdjoint` that it have trivial structure.
    ts = final_state.save_state.ts
    ys = final_state.save_state.ys

    event_mask = final_state.event_mask
    event_tprev = final_state.event_tprev
    event_tnext = final_state.event_tnext
    event_dense_info = final_state.event_dense_info
    event_values = final_state.event_values

    return (final_state, aux_stats), (
        ts,
        ys,
        event_mask,
        event_tprev,
        event_tnext,
        event_dense_info,
        event_values,
    )


def _materialise_none(y, grad_y):
    if grad_y is None and eqx.is_inexact_array(y):
        return jnp.zeros_like(y)
    else:
        return grad_y


@_loop_backsolve.def_bwd
def _loop_backsolve_bwd(
    residuals,
    grad_final_state__aux_stats,
    perturbed,
    y__args__terms__t0__t1,
    *,
    self,
    solver,
    stepsize_controller,
    event,
    saveat,
    dt0,
    max_steps,
    throw,
    init_state,
    progress_meter,
):
    #
    # Unpack our various arguments. Delete a lot of things just to make sure we're not
    # using them later.
    #

    del perturbed, init_state, progress_meter
    ts, ys, event_mask, event_tprev, event_tnext, event_dense_info, event_values = (
        residuals
    )
    del residuals
    grad_final_state, _ = grad_final_state__aux_stats
    # Note that `grad_final_state.save_state` has type `PyTree[SaveState]`; here we are
    # relying on the guard in `EventPropAdjoint` that it have trivial structure.
    grad_ys = grad_final_state.save_state.ys
    grad_ts = grad_final_state.save_state.ts
    # We take the simple way out and don't try to handle symbolic zeros.
    grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys)
    grad_ts = jtu.tree_map(_materialise_none, ts, grad_ts)
    del grad_final_state, grad_final_state__aux_stats
    y, args, terms, t0, t1 = y__args__terms__t0__t1
    del y__args__terms__t0__t1
    diff_args = eqx.filter(args, eqx.is_inexact_array)
    diff_terms = eqx.filter(terms, eqx.is_inexact_array)
    zeros_like_y = jtu.tree_map(jnp.zeros_like, y)
    zeros_like_diff_args = jtu.tree_map(jnp.zeros_like, diff_args)
    zeros_like_diff_terms = jtu.tree_map(jnp.zeros_like, diff_terms)

    # TODO: have this look inside MultiTerms? Need to think about the math. i.e.:
    # is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm)
    adjoint_terms = jtu.tree_map(
        AdjointTerm, terms, is_leaf=lambda x: isinstance(x, AbstractTerm)
    )
    diffeqsolve = self._diffeqsolve
    kwargs = dict(
        args=args,
        adjoint=self,
        solver=solver,
        stepsize_controller=stepsize_controller,
        terms=adjoint_terms,
        dt0=None if dt0 is None else -dt0,
        max_steps=max_steps,
        throw=throw,
    )
    kwargs.update(self.kwargs)

    # Note that `saveat.subs` has type `PyTree[SubSaveAt]`. Here we use the assumption
    # (checked in `EventPropAdjoint`) that it has trivial pytree structure.
    saveat_t0 = saveat.subs.t0

    if event is not None and event_mask is not None:

        def _event_contribution():
            """Compute the adjoint jump from the event."""
            # Get the time and state values of the event
            t_event = ω(ts)[-1].ω
            y_event = ω(ys)[-1].ω

            dL_dt_star = ω(grad_ts)[-1].ω
            lambda_plus = ω(grad_ys)[-1].ω  # λ(t*⁺) = dL/dy(t*)

            f_val = terms.vf(t_event, y_event, args)

            def _eval_cond(_y, _t):
                return event.cond_fn(
                    _t,
                    _y,
                    args,
                    terms=terms,
                    solver=solver,
                    t0=t0,
                    t1=t1,
                    dt0=dt0,
                    saveat=saveat,
                    stepsize_controller=stepsize_controller,
                    max_steps=max_steps,
                )

            _, vjp_fun = eqx.filter_vjp(_eval_cond, y_event, t_event)
            dg_dy, dg_dt = vjp_fun(1.0)

            # Compute ν = -(dL/dt* + λᵀ · f) / (∂g/∂t + ∂g/∂y · f)
            dg_dy_dot_f = jtu.tree_reduce(
                lambda a, b: a + b,
                jtu.tree_map(lambda dy, f: jnp.sum(dy * f), dg_dy, f_val),
            )
            denominator = dg_dt + dg_dy_dot_f

            lambda_dot_f = jtu.tree_reduce(
                lambda a, b: a + b,
                jtu.tree_map(lambda lp, f: jnp.sum(lp * f), lambda_plus, f_val),
            )

            nu = -(dL_dt_star + lambda_dot_f) / (denominator + 1e-12)

            # Δλ = ν · ∂g/∂y
            adjoint_jump = jtu.tree_map(lambda dy: nu * dy, dg_dy)

            # λ(t*⁻) = λ(t*⁺) + Δλ
            lambda_minus = (lambda_plus**ω + adjoint_jump**ω).ω
            return lambda_minus

        def _no_event_contribution():
            """No event occurred, just return the original gradient."""
            return ω(grad_ys)[-1].ω

        # Conditionally compute event contribution
        adjusted_lambda = lax.cond(
            event_mask,
            _event_contribution,
            _no_event_contribution,
        )

        grad_ys = jtu.tree_map(
            lambda g, new_val: g.at[-1].set(new_val),
            grad_ys,
            adjusted_lambda,
        )

    del self, solver, stepsize_controller, adjoint_terms, dt0, max_steps, throw
    del saveat
    del diff_args, diff_terms

    #
    # Now run a scan backwards in time, diffeqsolve'ing between each pair of adjacent
    # timestamps.
    #

    def _scan_fun(_state, _vals, first=False):
        _t1, _t0, _y0, _grad_y0 = _vals
        _a0, _solver_state, _controller_state = _state
        _a_y0, _a_diff_args0, _a_diff_term0 = _a0
        _a_y0 = (_a_y0**ω + _grad_y0**ω).ω
        _aug0 = (_y0, _a_y0, _a_diff_args0, _a_diff_term0)

        _sol = diffeqsolve(
            t0=_t0,
            t1=_t1,
            y0=_aug0,
            solver_state=_solver_state,
            controller_state=_controller_state,
            made_jump=not first,  # Adding _grad_y0, above, is a jump.
            saveat=SaveAt(t1=True, solver_state=True, controller_state=True),
            **kwargs,
        )

        def __get(__aug):
            assert __aug.shape[0] == 1
            return __aug[0]

        _aug1 = ω(_sol.ys).call(__get).ω
        _, _a_y1, _a_diff_args1, _a_diff_term1 = _aug1
        _a1 = (_a_y1, _a_diff_args1, _a_diff_term1)
        _solver_state = _sol.solver_state
        _controller_state = _sol.controller_state

        return (_a1, _solver_state, _controller_state), None

    state = ((zeros_like_y, zeros_like_diff_args, zeros_like_diff_terms), None, None)
    del zeros_like_y, zeros_like_diff_args, zeros_like_diff_terms

    # We always start backpropagating from `ts[-1]`.
    # We always finish backpropagating at `t0`.
    #
    # We may or may not have included `t0` in `ts`. (Depending on the value of
    # SaveAt(t0=...) on the forward pass.)
    #
    # For some of these options, we run _scan_fun once outside the loop to get access
    # to solver_state etc. of the correct PyTree structure.
    if saveat_t0:
        if len(ts) > 2:
            val0 = (ts[-2], ts[-1], ω(ys)[-1].ω, ω(grad_ys)[-1].ω)
            state, _ = _scan_fun(state, val0, first=True)
            vals = (
                ts[:-2],
                ts[1:-1],
                ω(ys)[1:-1].ω,
                ω(grad_ys)[1:-1].ω,
            )
            state, _ = lax.scan(_scan_fun, state, vals, reverse=True)

        elif len(ts) == 1:
            # nothing to do, diffeqsolve is the identity when merely SaveAt(t0=True).
            pass

        else:
            assert len(ts) == 2
            val = (ts[0], ts[1], ω(ys)[1].ω, ω(grad_ys)[1].ω)
            state, _ = _scan_fun(state, val, first=True)

        aug1, _, _ = state
        a_y1, a_diff_args1, a_diff_terms1 = aug1
        a_y1 = (ω(a_y1) + ω(grad_ys)[0]).ω

    else:
        if len(ts) > 1:
            # TODO: fold this `_scan_fun` into the `lax.scan`. This will reduce compile
            # time.
            val0 = (ts[-2], ts[-1], ω(ys)[-1].ω, ω(grad_ys)[-1].ω)
            state, _ = _scan_fun(state, val0, first=True)
            vals = (
                jnp.concatenate([t0[None], ts[:-2]]),
                ts[:-1],
                ω(ys)[:-1].ω,
                ω(grad_ys)[:-1].ω,
            )
            state, _ = lax.scan(_scan_fun, state, vals, reverse=True)

        else:
            assert len(ts) == 1
            val = (t0, ts[0], ω(ys)[0].ω, ω(grad_ys)[0].ω)
            state, _ = _scan_fun(state, val, first=True)

        aug1, _, _ = state
        a_y1, a_diff_args1, a_diff_terms1 = aug1

    # Boundary conditions
    f_t1 = terms.vf(ts[-1], ω(ys)[-1].ω, args)
    grad_t1 = jtu.tree_reduce(
        lambda a, b: a + b,
        jtu.tree_map(lambda a, f: jnp.sum(a * f), ω(grad_ys)[-1].ω, f_t1),
    )

    f_t0 = terms.vf(t0, y, args)  # Vector field at t0
    grad_t0 = -jtu.tree_reduce(
        lambda a, b: a + b,
        jtu.tree_map(lambda a, f: jnp.sum(a * f), a_y1, f_t0),
    )

    return a_y1, a_diff_args1, a_diff_terms1, grad_t0, grad_t1


class EventPropAdjoint(AbstractAdjoint):
    """Backpropagate through [`diffrax.diffeqsolve`][] by solving the continuous
    adjoint equations backwards-in-time. This is also sometimes known as
    "optimise-then-discretise", the "continuous adjoint method" or simply the
    "adjoint method".

    This will compute gradients with respect to the `terms`, `y0`, `args`, `t0`, and
    `t1` arguments passed to [`diffrax.diffeqsolve`][]. If you attempt to compute
    gradients with respect to anything else (for example arguments passed via
    closure), then a `CustomVJPException` will be raised by JAX. See also
    [this FAQ](../../further_details/faq/#im-getting-a-customvjpexception) entry.

    !!! info

        Using this method prevents computing forward-mode autoderivatives of
        [`diffrax.diffeqsolve`][]. (That is to say, `jax.jvp` will not work.)
    """  # noqa: E501

    kwargs: dict[str, Any]

    def __init__(self, **kwargs):
        """
        **Arguments:**

        - `**kwargs`: The arguments for the [`diffrax.diffeqsolve`][] operations that
            are called on the backward pass. For example use
            ```python
            EventPropAdjoint(solver=Dopri5())
            ```
            to specify a particular solver to use on the backward pass.
        """
        valid_keys = {
            "dt0",
            "solver",
            "stepsize_controller",
            "adjoint",
            "max_steps",
            "throw",
        }
        given_keys = set(kwargs.keys())
        diff_keys = given_keys - valid_keys
        if len(diff_keys) > 0:
            raise ValueError(
                "The following keyword arguments are not valid for "
                f"`EventPropAdjoint`: {diff_keys}"
            )
        self.kwargs = kwargs

    def loop(
        self,
        *,
        args,
        terms,
        solver,
        saveat,
        init_state,
        passed_solver_state,
        passed_controller_state,
        event,
        t0,
        t1,
        dt0,
        **kwargs,
    ):
        if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
            0
        ):
            raise NotImplementedError(
                "Cannot use `adjoint=EventPropAdjoint()` with `SaveAt(subs=...)`."
            )
        if saveat.dense or saveat.subs.steps:
            raise NotImplementedError(
                "Cannot use `adjoint=EventPropAdjoint()` with "
                "`saveat=SaveAt(steps=True)` or `saveat=SaveAt(dense=True)`."
            )
        if saveat.subs.fn is not save_y:
            raise NotImplementedError(
                "Cannot use `adjoint=EventPropAdjoint()` with `saveat=SaveAt(fn=...)`."
            )
        if is_unsafe_sde(terms):
            raise ValueError(
                "`adjoint=EventPropAdjoint()` does not support `UnsafeBrownianPath`. "
                "Consider using `adjoint=DirectAdjoint()` instead."
            )
        if is_sde(terms):
            if isinstance(solver, AbstractItoSolver):
                raise NotImplementedError(
                    f"`{solver.__class__.__name__}` converges to the Itô solution. "
                    "However `EventPropAdjoint` currently only supports Stratonovich "
                    "SDEs."
                )
            elif not isinstance(solver, AbstractStratonovichSolver):
                warnings.warn(
                    f"{solver.__class__.__name__} is not marked as converging to "
                    "either the Itô or the Stratonovich solution. Note that "
                    "`EventPropAdjoint` will only produce the correct solution for "
                    "Stratonovich SDEs."
                )
        if jtu.tree_structure(solver.term_structure) != jtu.tree_structure(0):
            raise NotImplementedError(
                "`diffrax.EventPropAdjoint` is only compatible with solvers that take "
                "a single term."
            )

        y = init_state.y
        init_state = eqx.tree_at(lambda s: s.y, init_state, object())
        init_state = jax.tree.map(
            lambda x: lax.stop_gradient(x) if eqx.is_array(x) else x, init_state
        )
        init_state = _nondiff_solver_controller_state(
            self, init_state, passed_solver_state, passed_controller_state
        )

        final_state, aux_stats = _loop_backsolve(
            (y, args, terms, t0, t1),
            self=self,
            saveat=saveat,
            init_state=init_state,
            solver=solver,
            event=event,
            dt0=dt0,
            **kwargs,
        )
        final_state = _only_transpose_ys(final_state)
        return final_state, aux_stats
```