eventax-0.1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. eventax-0.1.0/PKG-INFO +39 -0
  2. eventax-0.1.0/README.md +16 -0
  3. eventax-0.1.0/pyproject.toml +4 -0
  4. eventax-0.1.0/setup.cfg +4 -0
  5. eventax-0.1.0/setup.py +27 -0
  6. eventax-0.1.0/src/eventax/__init__.py +0 -0
  7. eventax-0.1.0/src/eventax/adjoint.py +535 -0
  8. eventax-0.1.0/src/eventax/buffer.py +528 -0
  9. eventax-0.1.0/src/eventax/evnn.py +927 -0
  10. eventax-0.1.0/src/eventax/io.py +199 -0
  11. eventax-0.1.0/src/eventax/neuron_models/__init__.py +29 -0
  12. eventax-0.1.0/src/eventax/neuron_models/alif.py +157 -0
  13. eventax-0.1.0/src/eventax/neuron_models/amos_wrapper.py +93 -0
  14. eventax-0.1.0/src/eventax/neuron_models/base_model.py +93 -0
  15. eventax-0.1.0/src/eventax/neuron_models/egru.py +142 -0
  16. eventax-0.1.0/src/eventax/neuron_models/eif.py +190 -0
  17. eventax-0.1.0/src/eventax/neuron_models/helpers.py +21 -0
  18. eventax-0.1.0/src/eventax/neuron_models/initializations.py +175 -0
  19. eventax-0.1.0/src/eventax/neuron_models/izhikevich.py +137 -0
  20. eventax-0.1.0/src/eventax/neuron_models/lif.py +149 -0
  21. eventax-0.1.0/src/eventax/neuron_models/multi_model.py +222 -0
  22. eventax-0.1.0/src/eventax/neuron_models/pizhikevich.py +195 -0
  23. eventax-0.1.0/src/eventax/neuron_models/plif.py +193 -0
  24. eventax-0.1.0/src/eventax/neuron_models/pqif.py +78 -0
  25. eventax-0.1.0/src/eventax/neuron_models/qif.py +83 -0
  26. eventax-0.1.0/src/eventax/neuron_models/refractory_wrapper.py +128 -0
  27. eventax-0.1.0/src/eventax.egg-info/PKG-INFO +39 -0
  28. eventax-0.1.0/src/eventax.egg-info/SOURCES.txt +34 -0
  29. eventax-0.1.0/src/eventax.egg-info/dependency_links.txt +1 -0
  30. eventax-0.1.0/src/eventax.egg-info/requires.txt +3 -0
  31. eventax-0.1.0/src/eventax.egg-info/top_level.txt +1 -0
  32. eventax-0.1.0/tests/test_adjoint.py +86 -0
  33. eventax-0.1.0/tests/test_buffer.py +328 -0
  34. eventax-0.1.0/tests/test_evnn.py +386 -0
  35. eventax-0.1.0/tests/test_io.py +225 -0
  36. eventax-0.1.0/tests/test_neuron_models.py +102 -0
eventax-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,39 @@
+ Metadata-Version: 2.4
+ Name: eventax
+ Version: 0.1.0
+ Summary: A Diffrax-based framework for continuous-time spiking neural networks
+ Home-page: https://github.com/Efficient-Scalable-Machine-Learning/eventax
+ Author: Lukas König
+ Author-email: lukmkoenig@gmail.com
+ License: MIT
+ Requires-Python: >=3.11
+ Description-Content-Type: text/markdown
+ Requires-Dist: jax>=0.9.1
+ Requires-Dist: diffrax>=0.7.2
+ Requires-Dist: equinox>=0.13.5
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: license
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
+
+ <p align="center">
+ <img src="./docs-site/docs/img/logo5.svg" alt="Eventpropjax" width="70%">
+ </p>
+
+ Eventax provides a [JAX](https://github.com/google/jax) implementation of the [EventProp algorithm](https://arxiv.org/abs/2009.08378) using [Diffrax](https://github.com/patrick-kidger/diffrax) and [Equinox](https://github.com/patrick-kidger/equinox), offering full autograd support, easy extension with custom neuron dynamics, and built-in delay training.
+
+ ## Features
+ - Fully differentiable implementation via JAX and Diffrax
+ - Easy extension with custom neuron model dynamics + learnable parameters
+ - Support for (trainable) synaptic delays
+ - GPU/TPU compatibility through JAX
+
+ ## 📦 Installation
+ ```bash
+ pip install eventax
+ ```
eventax-0.1.0/README.md ADDED
@@ -0,0 +1,16 @@
+ <p align="center">
+ <img src="./docs-site/docs/img/logo5.svg" alt="Eventpropjax" width="70%">
+ </p>
+
+ Eventax provides a [JAX](https://github.com/google/jax) implementation of the [EventProp algorithm](https://arxiv.org/abs/2009.08378) using [Diffrax](https://github.com/patrick-kidger/diffrax) and [Equinox](https://github.com/patrick-kidger/equinox), offering full autograd support, easy extension with custom neuron dynamics, and built-in delay training.
+
+ ## Features
+ - Fully differentiable implementation via JAX and Diffrax
+ - Easy extension with custom neuron model dynamics + learnable parameters
+ - Support for (trainable) synaptic delays
+ - GPU/TPU compatibility through JAX
+
+ ## 📦 Installation
+ ```bash
+ pip install eventax
+ ```
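
The README stops at installation, so for orientation here is a minimal sketch of how the `EventPropAdjoint` from `src/eventax/adjoint.py` (diffed below) plugs into a standard Diffrax solve. The toy ODE, loss, and parameter values are illustrative only; the import path simply mirrors the `src/` layout declared in `setup.py`:

```python
import diffrax
import jax
import jax.numpy as jnp

from eventax.adjoint import EventPropAdjoint  # mirrors src/eventax/adjoint.py

# A toy exponential-decay ODE. With EventPropAdjoint, the backward pass
# integrates the continuous adjoint ODE instead of differentiating through
# the solver's internal steps.
term = diffrax.ODETerm(lambda t, y, args: -args["tau"] * y)


def loss(y0):
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Dopri5(),
        t0=0.0,
        t1=1.0,
        dt0=0.01,
        y0=y0,
        args={"tau": 2.0},
        adjoint=EventPropAdjoint(),  # gradients via the continuous adjoint
    )
    return jnp.sum(sol.ys**2)


print(jax.grad(loss)(jnp.array(1.0)))
```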
eventax-0.1.0/pyproject.toml ADDED
@@ -0,0 +1,4 @@
+ # pyproject.toml
+ [build-system]
+ requires = ["setuptools>=64", "wheel"]
+ build-backend = "setuptools.build_meta"
eventax-0.1.0/setup.cfg ADDED
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
eventax-0.1.0/setup.py ADDED
@@ -0,0 +1,27 @@
+ import io
+ import os
+ from setuptools import setup, find_packages
+ here = os.path.abspath(os.path.dirname(__file__))
+ with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f:
+     long_description = f.read()
+ setup(
+     name="eventax",
+     version="0.1.0",
+     author="Lukas König",
+     author_email="lukmkoenig@gmail.com",
+     description=(
+         "A Diffrax-based framework for continuous-time spiking neural networks"
+     ),
+     long_description=long_description,
+     long_description_content_type="text/markdown",
+     url="https://github.com/Efficient-Scalable-Machine-Learning/eventax",
+     license="MIT",
+     package_dir={"": "src"},
+     packages=find_packages(where="src"),
+     install_requires=[
+         "jax>=0.9.1",
+         "diffrax>=0.7.2",
+         "equinox>=0.13.5",
+     ],
+     python_requires=">=3.11",
+ )
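
Neither file pins a build frontend. Given the setuptools backend declared in `pyproject.toml` above, building this sdist locally would typically look like the following (standard PyPA tooling, not specific to this package):

```bash
pip install build
python -m build          # writes an sdist and a wheel into dist/
pip install dist/*.whl
```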
eventax-0.1.0/src/eventax/__init__.py ADDED
File without changes
eventax-0.1.0/src/eventax/adjoint.py ADDED
@@ -0,0 +1,535 @@
+ import functools as ft
+ import warnings
+ from collections.abc import Callable
+ from typing import Any, cast
+
+ import equinox as eqx
+ import equinox.internal as eqxi
+ import jax
+ import jax.lax as lax
+ import jax.numpy as jnp
+ import jax.tree_util as jtu
+ from equinox.internal import ω
+
+ from diffrax._heuristics import is_sde, is_unsafe_sde
+ from diffrax._saveat import save_y, SaveAt, SubSaveAt
+ from diffrax._solver import (
+     AbstractItoSolver,
+     AbstractStratonovichSolver,
+ )
+ from diffrax._term import AbstractTerm, AdjointTerm
+ from diffrax._adjoint import AbstractAdjoint
+
+ ω = cast(Callable, ω)
+
+
+ def _is_none(x):
+     return x is None
+
+
+ def _is_subsaveat(x: Any) -> bool:
+     return isinstance(x, SubSaveAt)
+
+
+ def _nondiff_solver_controller_state(
+     adjoint, init_state, passed_solver_state, passed_controller_state
+ ):
+     if passed_solver_state:
+         name = (
+             f"When using `adjoint={adjoint.__class__.__name__}()`, then `solver_state`"
+         )
+         solver_fn = ft.partial(
+             eqxi.nondifferentiable,
+             name=name,
+         )
+     else:
+         solver_fn = lax.stop_gradient
+     if passed_controller_state:
+         name = (
+             f"When using `adjoint={adjoint.__class__.__name__}()`, then "
+             "`controller_state`"
+         )
+         controller_fn = ft.partial(
+             eqxi.nondifferentiable,
+             name=name,
+         )
+     else:
+         controller_fn = lax.stop_gradient
+     init_state = eqx.tree_at(
+         lambda s: s.solver_state,
+         init_state,
+         replace_fn=solver_fn,
+         is_leaf=_is_none,
+     )
+     init_state = eqx.tree_at(
+         lambda s: s.controller_state,
+         init_state,
+         replace_fn=controller_fn,
+         is_leaf=_is_none,
+     )
+     return init_state
+
+
+ def _only_transpose_ys(final_state):
+     from diffrax._integrate import SaveState
+
+     def is_save_state(x):
+         return isinstance(x, SaveState)
+
+     def get_ys(_final_state):
+         return [
+             s.ys
+             for s in jtu.tree_leaves(_final_state.save_state, is_leaf=is_save_state)
+         ]
+
+     def get_ts(_final_state):
+         return [
+             s.ts
+             for s in jtu.tree_leaves(_final_state.save_state, is_leaf=is_save_state)
+         ]
+
+     ys = get_ys(final_state)
+     ts = get_ts(final_state)
+
+     named_nondiff_entries = (
+         "y",
+         "tprev",
+         "tnext",
+         "solver_state",
+         "controller_state",
+         "dense_ts",
+         "dense_infos",
+     )
+     named_nondiff_values = tuple(
+         eqxi.nondifferentiable_backward(getattr(final_state, k), name=k, symbolic=False)
+         for k in named_nondiff_entries
+     )
+
+     final_state = eqxi.nondifferentiable_backward(final_state, symbolic=False)
+
+     def get_named_nondiff_entries(s):
+         return tuple(getattr(s, k) for k in named_nondiff_entries)
+
+     final_state = eqx.tree_at(
+         get_named_nondiff_entries, final_state, named_nondiff_values, is_leaf=_is_none
+     )
+
+     final_state = eqx.tree_at(get_ys, final_state, ys)
+     final_state = eqx.tree_at(get_ts, final_state, ts)
+     return final_state
+
+
+ _inner_loop = jax.named_call(eqxi.while_loop, name="inner-loop")
+ _outer_loop = jax.named_call(eqxi.while_loop, name="outer-loop")
+
+
+ @eqx.filter_custom_vjp
+ def _loop_backsolve(y__args__terms__t0__t1, *, self, throw, init_state, **kwargs):
+     del throw
+     y, args, terms, t0, t1 = y__args__terms__t0__t1
+     init_state = eqx.tree_at(lambda s: s.y, init_state, y)
+     del y
+     return self._loop(
+         args=args,
+         terms=terms,
+         init_state=init_state,
+         inner_while_loop=ft.partial(_inner_loop, kind="lax"),
+         outer_while_loop=ft.partial(_outer_loop, kind="lax"),
+         t0=t0,
+         t1=t1,
+         **kwargs,
+     )
+
+
+ @_loop_backsolve.def_fwd
+ def _loop_backsolve_fwd(perturbed, y__args__terms__t0__t1, **kwargs):
+     del perturbed
+     final_state, aux_stats = _loop_backsolve(y__args__terms__t0__t1, **kwargs)
+     # Note that `final_state.save_state` has type `PyTree[SaveState]`; here we are
+     # relying on the guard in `EventPropAdjoint` that it have trivial structure.
+     ts = final_state.save_state.ts
+     ys = final_state.save_state.ys
+
+     event_mask = final_state.event_mask
+     event_tprev = final_state.event_tprev
+     event_tnext = final_state.event_tnext
+     event_dense_info = final_state.event_dense_info
+     event_values = final_state.event_values
+
+     return (final_state, aux_stats), (ts, ys, event_mask, event_tprev, event_tnext, event_dense_info, event_values)
+
+
+ def _materialise_none(y, grad_y):
+     if grad_y is None and eqx.is_inexact_array(y):
+         return jnp.zeros_like(y)
+     else:
+         return grad_y
+
+
+ @_loop_backsolve.def_bwd
+ def _loop_backsolve_bwd(
+     residuals,
+     grad_final_state__aux_stats,
+     perturbed,
+     y__args__terms__t0__t1,
+     *,
+     self,
+     solver,
+     stepsize_controller,
+     event,
+     saveat,
+     dt0,
+     max_steps,
+     throw,
+     init_state,
+     progress_meter,
+ ):
+     #
+     # Unpack our various arguments. Delete a lot of things just to make sure we're not
+     # using them later.
+     #
+
+     del perturbed, init_state, progress_meter
+     ts, ys, event_mask, event_tprev, event_tnext, event_dense_info, event_values = residuals
+     del residuals
+     grad_final_state, _ = grad_final_state__aux_stats
+     # Note that `grad_final_state.save_state` has type `PyTree[SaveState]`; here we are
+     # relying on the guard in `EventPropAdjoint` that it have trivial structure.
+     grad_ys = grad_final_state.save_state.ys
+     grad_ts = grad_final_state.save_state.ts
+     # We take the simple way out and don't try to handle symbolic zeros.
+     grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys)
+     grad_ts = jtu.tree_map(_materialise_none, ts, grad_ts)
+     del grad_final_state, grad_final_state__aux_stats
+     y, args, terms, t0, t1 = y__args__terms__t0__t1
+     del y__args__terms__t0__t1
+     diff_args = eqx.filter(args, eqx.is_inexact_array)
+     diff_terms = eqx.filter(terms, eqx.is_inexact_array)
+     zeros_like_y = jtu.tree_map(jnp.zeros_like, y)
+     zeros_like_diff_args = jtu.tree_map(jnp.zeros_like, diff_args)
+     zeros_like_diff_terms = jtu.tree_map(jnp.zeros_like, diff_terms)
+
+     # TODO: have this look inside MultiTerms? Need to think about the math. i.e.:
+     # is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm)
+     adjoint_terms = jtu.tree_map(
+         AdjointTerm, terms, is_leaf=lambda x: isinstance(x, AbstractTerm)
+     )
+     diffeqsolve = self._diffeqsolve
+     kwargs = dict(
+         args=args,
+         adjoint=self,
+         solver=solver,
+         stepsize_controller=stepsize_controller,
+         terms=adjoint_terms,
+         dt0=None if dt0 is None else -dt0,
+         max_steps=max_steps,
+         throw=throw,
+     )
+     kwargs.update(self.kwargs)
+
+     # Note that `saveat.subs` has type `PyTree[SubSaveAt]`. Here we use the assumption
+     # (checked in `EventPropAdjoint`) that it has trivial pytree structure.
+     saveat_t0 = saveat.subs.t0
+
+     if event is not None and event_mask is not None:
+         def _event_contribution():
+             """Compute the adjoint jump from the event."""
+             # Get the time and state values of the event
+             t_event = ω(ts)[-1].ω
+             y_event = ω(ys)[-1].ω
+
+             dL_dt_star = ω(grad_ts)[-1].ω
+             lambda_plus = ω(grad_ys)[-1].ω  # λ(t*⁺) = dL/dy(t*)
+
+             f_val = terms.vf(t_event, y_event, args)
+
+             def _eval_cond(_y, _t):
+                 return event.cond_fn(
+                     _t, _y, args,
+                     terms=terms,
+                     solver=solver,
+                     t0=t0,
+                     t1=t1,
+                     dt0=dt0,
+                     saveat=saveat,
+                     stepsize_controller=stepsize_controller,
+                     max_steps=max_steps,
+                 )
+
+             _, vjp_fun = eqx.filter_vjp(_eval_cond, y_event, t_event)
+             dg_dy, dg_dt = vjp_fun(1.0)
+
+             # Compute ν = -(dL/dt* + λᵀ · f) / (∂g/∂t + ∂g/∂y · f)
+             dg_dy_dot_f = jtu.tree_reduce(
+                 lambda a, b: a + b,
+                 jtu.tree_map(lambda dy, f: jnp.sum(dy * f), dg_dy, f_val),
+             )
+             denominator = dg_dt + dg_dy_dot_f
+
+             lambda_dot_f = jtu.tree_reduce(
+                 lambda a, b: a + b,
+                 jtu.tree_map(lambda lp, f: jnp.sum(lp * f), lambda_plus, f_val),
+             )
+
+             nu = -(dL_dt_star + lambda_dot_f) / (denominator + 1e-12)
+
+             # Δλ = ν · ∂g/∂y
+             adjoint_jump = jtu.tree_map(lambda dy: nu * dy, dg_dy)
+
+             # λ(t*⁻) = λ(t*⁺) + Δλ
+             lambda_minus = (lambda_plus**ω + adjoint_jump**ω).ω
+             return lambda_minus
+
+         def _no_event_contribution():
+             """No event occurred, just return the original gradient."""
+             return ω(grad_ys)[-1].ω
+
+         # Conditionally compute event contribution
+         adjusted_lambda = lax.cond(
+             event_mask,
+             _event_contribution,
+             _no_event_contribution,
+         )
+
+         grad_ys = jtu.tree_map(
+             lambda g, new_val: g.at[-1].set(new_val),
+             grad_ys,
+             adjusted_lambda,
+         )
+
+     del self, solver, stepsize_controller, adjoint_terms, dt0, max_steps, throw
+     del saveat
+     del diff_args, diff_terms
+
+     #
+     # Now run a scan backwards in time, diffeqsolve'ing between each pair of adjacent
+     # timestamps.
+     #
+
+     def _scan_fun(_state, _vals, first=False):
+         _t1, _t0, _y0, _grad_y0 = _vals
+         _a0, _solver_state, _controller_state = _state
+         _a_y0, _a_diff_args0, _a_diff_term0 = _a0
+         _a_y0 = (_a_y0**ω + _grad_y0**ω).ω
+         _aug0 = (_y0, _a_y0, _a_diff_args0, _a_diff_term0)
+
+         _sol = diffeqsolve(
+             t0=_t0,
+             t1=_t1,
+             y0=_aug0,
+             solver_state=_solver_state,
+             controller_state=_controller_state,
+             made_jump=not first,  # Adding _grad_y0, above, is a jump.
+             saveat=SaveAt(t1=True, solver_state=True, controller_state=True),
+             **kwargs,
+         )
+
+         def __get(__aug):
+             assert __aug.shape[0] == 1
+             return __aug[0]
+
+         _aug1 = ω(_sol.ys).call(__get).ω
+         _, _a_y1, _a_diff_args1, _a_diff_term1 = _aug1
+         _a1 = (_a_y1, _a_diff_args1, _a_diff_term1)
+         _solver_state = _sol.solver_state
+         _controller_state = _sol.controller_state
+
+         return (_a1, _solver_state, _controller_state), None
+
+     state = ((zeros_like_y, zeros_like_diff_args, zeros_like_diff_terms), None, None)
+     del zeros_like_y, zeros_like_diff_args, zeros_like_diff_terms
+
+     # We always start backpropagating from `ts[-1]`.
+     # We always finish backpropagating at `t0`.
+     #
+     # We may or may not have included `t0` in `ts`. (Depending on the value of
+     # SaveAt(t0=...) on the forward pass.)
+     #
+     # For some of these options, we run _scan_fun once outside the loop to get access
+     # to solver_state etc. of the correct PyTree structure.
+     if saveat_t0:
+         if len(ts) > 2:
+             val0 = (ts[-2], ts[-1], ω(ys)[-1].ω, ω(grad_ys)[-1].ω)
+             state, _ = _scan_fun(state, val0, first=True)
+             vals = (
+                 ts[:-2],
+                 ts[1:-1],
+                 ω(ys)[1:-1].ω,
+                 ω(grad_ys)[1:-1].ω,
+             )
+             state, _ = lax.scan(_scan_fun, state, vals, reverse=True)
+
+         elif len(ts) == 1:
+             # nothing to do, diffeqsolve is the identity when merely SaveAt(t0=True).
+             pass
+
+         else:
+             assert len(ts) == 2
+             val = (ts[0], ts[1], ω(ys)[1].ω, ω(grad_ys)[1].ω)
+             state, _ = _scan_fun(state, val, first=True)
+
+         aug1, _, _ = state
+         a_y1, a_diff_args1, a_diff_terms1 = aug1
+         a_y1 = (ω(a_y1) + ω(grad_ys)[0]).ω
+
+     else:
+         if len(ts) > 1:
+             # TODO: fold this `_scan_fun` into the `lax.scan`. This will reduce compile
+             # time.
+             val0 = (ts[-2], ts[-1], ω(ys)[-1].ω, ω(grad_ys)[-1].ω)
+             state, _ = _scan_fun(state, val0, first=True)
+             vals = (
+                 jnp.concatenate([t0[None], ts[:-2]]),
+                 ts[:-1],
+                 ω(ys)[:-1].ω,
+                 ω(grad_ys)[:-1].ω,
+             )
+             state, _ = lax.scan(_scan_fun, state, vals, reverse=True)
+
+         else:
+             assert len(ts) == 1
+             val = (t0, ts[0], ω(ys)[0].ω, ω(grad_ys)[0].ω)
+             state, _ = _scan_fun(state, val, first=True)
+
+         aug1, _, _ = state
+         a_y1, a_diff_args1, a_diff_terms1 = aug1
+
+     # Boundary conditions
+     f_t1 = terms.vf(ts[-1], ω(ys)[-1].ω, args)
+     grad_t1 = jtu.tree_reduce(
+         lambda a, b: a + b,
+         jtu.tree_map(lambda a, f: jnp.sum(a * f), ω(grad_ys)[-1].ω, f_t1),
+     )
+
+     f_t0 = terms.vf(t0, y, args)  # Vector field at t0
+     grad_t0 = -jtu.tree_reduce(
+         lambda a, b: a + b,
+         jtu.tree_map(lambda a, f: jnp.sum(a * f), a_y1, f_t0),
+     )
+
+     return a_y1, a_diff_args1, a_diff_terms1, grad_t0, grad_t1
+
+
+ class EventPropAdjoint(AbstractAdjoint):
+     """Backpropagate through [`diffrax.diffeqsolve`][] by solving the continuous
+     adjoint equations backwards-in-time. This is also sometimes known as
+     "optimise-then-discretise", the "continuous adjoint method" or simply the "adjoint
+     method".
+
+     This will compute gradients with respect to the `terms`, `y0`, `args`, `t0`, and
+     `t1` arguments passed to [`diffrax.diffeqsolve`][]. If you attempt to compute
+     gradients with respect to anything else (for example arguments passed via
+     closure), then a `CustomVJPException` will be raised by JAX. See also
+     [this FAQ](../../further_details/faq/#im-getting-a-customvjpexception)
+     entry.
+
+     !!! info
+
+         Using this method prevents computing forward-mode autoderivatives of
+         [`diffrax.diffeqsolve`][]. (That is to say, `jax.jvp` will not work.)
+     """  # noqa: E501
+
+     kwargs: dict[str, Any]
+
+     def __init__(self, **kwargs):
+         """
+         **Arguments:**
+
+         - `**kwargs`: The arguments for the [`diffrax.diffeqsolve`][] operations that
+             are called on the backward pass. For example use
+             ```python
+             EventPropAdjoint(solver=Dopri5())
+             ```
+             to specify a particular solver to use on the backward pass.
+         """
+         valid_keys = {
+             "dt0",
+             "solver",
+             "stepsize_controller",
+             "adjoint",
+             "max_steps",
+             "throw",
+         }
+         given_keys = set(kwargs.keys())
+         diff_keys = given_keys - valid_keys
+         if len(diff_keys) > 0:
+             raise ValueError(
+                 "The following keyword arguments are not valid for `EventPropAdjoint`: "
+                 f"{diff_keys}"
+             )
+         self.kwargs = kwargs
+
+     def loop(
+         self,
+         *,
+         args,
+         terms,
+         solver,
+         saveat,
+         init_state,
+         passed_solver_state,
+         passed_controller_state,
+         event,
+         t0,
+         t1,
+         dt0,
+         **kwargs,
+     ):
+         if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
+             0
+         ):
+             raise NotImplementedError(
+                 "Cannot use `adjoint=EventPropAdjoint()` with `SaveAt(subs=...)`."
+             )
+         if saveat.dense or saveat.subs.steps:
+             raise NotImplementedError(
+                 "Cannot use `adjoint=EventPropAdjoint()` with "
+                 "`saveat=SaveAt(steps=True)` or `saveat=SaveAt(dense=True)`."
+             )
+         if saveat.subs.fn is not save_y:
+             raise NotImplementedError(
+                 "Cannot use `adjoint=EventPropAdjoint()` with `saveat=SaveAt(fn=...)`."
+             )
+         if is_unsafe_sde(terms):
+             raise ValueError(
+                 "`adjoint=EventPropAdjoint()` does not support `UnsafeBrownianPath`. "
+                 "Consider using `adjoint=DirectAdjoint()` instead."
+             )
+         if is_sde(terms):
+             if isinstance(solver, AbstractItoSolver):
+                 raise NotImplementedError(
+                     f"`{solver.__class__.__name__}` converges to the Itô solution. "
+                     "However `EventPropAdjoint` currently only supports Stratonovich "
+                     "SDEs."
+                 )
+             elif not isinstance(solver, AbstractStratonovichSolver):
+                 warnings.warn(
+                     f"{solver.__class__.__name__} is not marked as converging to "
+                     "either the Itô or the Stratonovich solution. Note that "
+                     "`EventPropAdjoint` will only produce the correct solution for "
+                     "Stratonovich SDEs."
+                 )
+         if jtu.tree_structure(solver.term_structure) != jtu.tree_structure(0):
+             raise NotImplementedError(
+                 "`diffrax.EventPropAdjoint` is only compatible with solvers that take "
+                 "a single term."
+             )
+
+         y = init_state.y
+         init_state = eqx.tree_at(lambda s: s.y, init_state, object())
+         init_state = jax.tree.map(
+             lambda x: lax.stop_gradient(x) if eqx.is_array(x) else x, init_state
+         )
+         init_state = _nondiff_solver_controller_state(
+             self, init_state, passed_solver_state, passed_controller_state
+         )
+
+         final_state, aux_stats = _loop_backsolve(
+             (y, args, terms, t0, t1),
+             self=self,
+             saveat=saveat,
+             init_state=init_state,
+             solver=solver,
+             event=event,
+             dt0=dt0,
+             **kwargs,
+         )
+         final_state = _only_transpose_ys(final_state)
+         return final_state, aux_stats
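
For reference, the event branch of `_loop_backsolve_bwd` implements the adjoint jump that its inline comments describe. Restated in LaTeX (with $g$ the event condition, $f$ the vector field, $\lambda$ the adjoint state, and $t^*$ the event time):

```latex
\nu = -\,\frac{\partial L/\partial t^{*} \;+\; \lambda(t^{*+})^{\top} f}
              {\partial g/\partial t \;+\; (\partial g/\partial y)^{\top} f},
\qquad
\lambda(t^{*-}) \;=\; \lambda(t^{*+}) \;+\; \nu\,\frac{\partial g}{\partial y}.
```

The `1e-12` added to the denominator when computing `nu` guards against a vanishing transversality term (a grazing event). The final "Boundary conditions" block likewise realises $dL/dt_1 = \lambda(t_1)^{\top} f(t_1, y(t_1))$ and $dL/dt_0 = -\lambda(t_0)^{\top} f(t_0, y_0)$, the gradients returned for the integration endpoints.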