jaxsim 0.2.dev188__py3-none-any.whl → 0.2.dev364__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +88 -72
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/collision.py +14 -0
  26. jaxsim/parsers/descriptions/link.py +13 -2
  27. jaxsim/parsers/kinematic_graph.py +5 -0
  28. jaxsim/parsers/rod/utils.py +7 -8
  29. jaxsim/rbda/__init__.py +7 -0
  30. jaxsim/rbda/aba.py +295 -0
  31. jaxsim/rbda/collidable_points.py +142 -0
  32. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  33. jaxsim/rbda/forward_kinematics.py +113 -0
  34. jaxsim/rbda/jacobian.py +201 -0
  35. jaxsim/rbda/rnea.py +237 -0
  36. jaxsim/rbda/soft_contacts.py +296 -0
  37. jaxsim/rbda/utils.py +152 -0
  38. jaxsim/terrain/__init__.py +2 -0
  39. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  40. jaxsim/utils/__init__.py +1 -4
  41. jaxsim/utils/hashless.py +18 -0
  42. jaxsim/utils/jaxsim_dataclass.py +281 -30
  43. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
  44. jaxsim-0.2.dev364.dist-info/RECORD +64 -0
  45. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
  46. jaxsim/high_level/__init__.py +0 -2
  47. jaxsim/high_level/common.py +0 -11
  48. jaxsim/high_level/joint.py +0 -148
  49. jaxsim/high_level/link.py +0 -259
  50. jaxsim/high_level/model.py +0 -1686
  51. jaxsim/math/conv.py +0 -114
  52. jaxsim/math/joint.py +0 -102
  53. jaxsim/math/plucker.py +0 -100
  54. jaxsim/physics/__init__.py +0 -12
  55. jaxsim/physics/algos/__init__.py +0 -0
  56. jaxsim/physics/algos/aba.py +0 -254
  57. jaxsim/physics/algos/aba_motors.py +0 -284
  58. jaxsim/physics/algos/forward_kinematics.py +0 -79
  59. jaxsim/physics/algos/jacobian.py +0 -98
  60. jaxsim/physics/algos/rnea.py +0 -180
  61. jaxsim/physics/algos/rnea_motors.py +0 -196
  62. jaxsim/physics/algos/soft_contacts.py +0 -523
  63. jaxsim/physics/algos/utils.py +0 -69
  64. jaxsim/physics/model/__init__.py +0 -0
  65. jaxsim/physics/model/ground_contact.py +0 -55
  66. jaxsim/physics/model/physics_model.py +0 -388
  67. jaxsim/physics/model/physics_model_state.py +0 -283
  68. jaxsim/simulation/__init__.py +0 -4
  69. jaxsim/simulation/integrators.py +0 -393
  70. jaxsim/simulation/ode.py +0 -290
  71. jaxsim/simulation/ode_data.py +0 -96
  72. jaxsim/simulation/ode_integration.py +0 -62
  73. jaxsim/simulation/simulator.py +0 -543
  74. jaxsim/simulation/simulator_callbacks.py +0 -79
  75. jaxsim/simulation/utils.py +0 -15
  76. jaxsim/sixd/__init__.py +0 -2
  77. jaxsim/utils/oop.py +0 -536
  78. jaxsim/utils/vmappable.py +0 -117
  79. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  80. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
  81. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
jaxsim/api/references.py CHANGED
@@ -7,10 +7,11 @@ import jax.numpy as jnp
7
7
  import jax_dataclasses
8
8
 
9
9
  import jaxsim.api as js
10
- import jaxsim.physics.model.physics_model_state
11
10
  import jaxsim.typing as jtp
12
- from jaxsim import VelRepr
13
- from jaxsim.simulation.ode_data import ODEInput
11
+ from jaxsim.utils.tracing import not_tracing
12
+
13
+ from .common import VelRepr
14
+ from .ode_data import ODEInput
14
15
 
15
16
  try:
16
17
  from typing import Self
@@ -95,7 +96,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
95
96
 
96
97
  # Create a zero references object.
97
98
  references = JaxSimModelReferences(
98
- input=ODEInput.zero(physics_model=model.physics_model),
99
+ input=ODEInput.zero(model=model),
99
100
  velocity_representation=velocity_representation,
100
101
  )
101
102
 
@@ -132,7 +133,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
132
133
  valid = True
133
134
 
134
135
  if model is not None:
135
- valid = valid and self.input.valid(physics_model=model.physics_model)
136
+ valid = valid and self.input.valid(model=model)
136
137
 
137
138
  return valid
138
139
 
@@ -188,7 +189,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
188
189
 
189
190
  # If we have the model, we can extract the link names, if not provided.
190
191
  link_names = link_names if link_names is not None else model.link_names()
191
- link_idxs = jaxsim.api.link.names_to_idxs(link_names=link_names, model=model)
192
+ link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)
192
193
 
193
194
  # In inertial-fixed representation, we already have the link forces.
194
195
  if self.velocity_representation is VelRepr.Inertial:
@@ -198,7 +199,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
198
199
  msg = "Missing model data to use a representation different from {}"
199
200
  raise ValueError(msg.format(VelRepr.Inertial.name))
200
201
 
201
- if not data.valid(model=model):
202
+ if not_tracing(self.input.physics_model.f_ext) and not data.valid(model=model):
202
203
  raise ValueError("The provided data is not valid for the model")
203
204
 
204
205
  # Helper function to convert a single 6D force to the active representation.
@@ -252,7 +253,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
252
253
 
253
254
  return self.input.physics_model.tau
254
255
 
255
- if not self.valid(model=model):
256
+ if not_tracing(self.input.physics_model.tau) and not self.valid(model=model):
256
257
  msg = "The actuation object is not compatible with the provided model"
257
258
  raise ValueError(msg)
258
259
 
@@ -303,7 +304,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
303
304
  if model is None:
304
305
  return replace(forces=forces)
305
306
 
306
- if not self.valid(model=model):
307
+ if not_tracing(forces) and not self.valid(model=model):
307
308
  msg = "The references object is not compatible with the provided model"
308
309
  raise ValueError(msg)
309
310
 
@@ -379,7 +380,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
379
380
 
380
381
  # If we have the model, we can extract the link names if not provided.
381
382
  link_names = link_names if link_names is not None else model.link_names()
382
- link_idxs = jaxsim.api.link.names_to_idxs(link_names=link_names, model=model)
383
+ link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)
383
384
 
384
385
  # Compute the bias depending on whether we either set or add the link forces.
385
386
  W_f0_L = (
@@ -401,7 +402,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
401
402
  msg = "Missing model data to use a representation different from {}"
402
403
  raise ValueError(msg.format(VelRepr.Inertial.name))
403
404
 
404
- if not data.valid(model=model):
405
+ if not_tracing(forces) and not data.valid(model=model):
405
406
  raise ValueError("The provided data is not valid for the model")
406
407
 
407
408
  # Helper function to convert a single 6D force to the inertial representation.
jaxsim/integrators/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
- from . import fixed_step
2
- from .common import Integrator, Time, TimeStep
1
+ from . import fixed_step, variable_step
2
+ from .common import Integrator, SystemDynamics, Time, TimeStep
jaxsim/integrators/common.py CHANGED
@@ -5,9 +5,12 @@ from typing import Any, ClassVar, Generic, Protocol, Type, TypeVar
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
  import jax_dataclasses
8
+ import jaxlie
8
9
  from jax_dataclasses import Static
9
10
 
11
+ import jaxsim.api as js
10
12
  import jaxsim.typing as jtp
13
+ from jaxsim.math import Quaternion
11
14
  from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability
12
15
 
13
16
  try:
@@ -46,6 +49,9 @@ class SystemDynamics(Protocol[State, StateDerivative]):
46
49
  @jax_dataclasses.pytree_dataclass
47
50
  class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
48
51
 
52
+ AfterInitKey: ClassVar[str] = "after_init"
53
+ InitializingKey: ClassVar[str] = "initializing"
54
+
49
55
  AuxDictDynamicsKey: ClassVar[str] = "aux_dict_dynamics"
50
56
 
51
57
  dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
@@ -58,7 +64,10 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
58
64
 
59
65
  @classmethod
60
66
  def build(
61
- cls: Type[Self], *, dynamics: SystemDynamics[State, StateDerivative], **kwargs
67
+ cls: Type[Self],
68
+ *,
69
+ dynamics: SystemDynamics[State, StateDerivative],
70
+ **kwargs,
62
71
  ) -> Self:
63
72
  """
64
73
  Build the integrator object.
@@ -102,9 +111,9 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
102
111
  with integrator.mutable_context(mutability=Mutability.MUTABLE):
103
112
  xf = integrator(x0, t0, dt, **kwargs)
104
113
 
105
- assert Integrator.AuxDictDynamicsKey in integrator.params
106
-
107
- return xf, integrator.params
114
+ return xf, integrator.params | {
115
+ Integrator.AfterInitKey: jnp.array(False).astype(bool)
116
+ }
108
117
 
109
118
  @abc.abstractmethod
110
119
  def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
@@ -116,7 +125,7 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
116
125
  t0: Time,
117
126
  dt: TimeStep,
118
127
  *,
119
- key: jax.Array | None = None,
128
+ include_dynamics_aux_dict: bool = False,
120
129
  **kwargs,
121
130
  ) -> dict[str, Any]:
122
131
  """
@@ -126,7 +135,6 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
126
135
  x0: The initial state of the system.
127
136
  t0: The initial time of the system.
128
137
  dt: The time step of the integration.
129
- key: An optional random key to initialize the integrator.
130
138
 
131
139
  Returns:
132
140
  The auxiliary dictionary of the integrator.
@@ -141,17 +149,43 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
141
149
  the first step will be wrong.
142
150
  """
143
151
 
144
- _, aux_dict_dynamics = self.dynamics(x0, t0)
145
-
146
152
  with self.editable(validate=False) as integrator:
153
+
154
+ # Initialize the integrator parameters.
155
+ # For initialization purpose, the integrators can check if the
156
+ # `Integrator.InitializingKey` is present in their parameters.
157
+ # The AfterInitKey is used in the first step after initialization.
158
+ integrator.params = {
159
+ Integrator.InitializingKey: jnp.array(True),
160
+ Integrator.AfterInitKey: jnp.array(False),
161
+ }
162
+
163
+ # Run a dummy call of the integrator.
164
+ # It is used only to get the params so that we know the structure
165
+ # of the corresponding pytree.
147
166
  _ = integrator(x0, t0, dt, **kwargs)
148
- aux_dict_step = integrator.params
149
167
 
150
- if Integrator.AuxDictDynamicsKey in aux_dict_dynamics:
151
- msg = "You cannot create a key '{}' in the __call__ method."
152
- raise KeyError(msg.format(Integrator.AuxDictDynamicsKey))
168
+ # Remove the injected key.
169
+ _ = integrator.params.pop(Integrator.InitializingKey)
153
170
 
154
- return {Integrator.AuxDictDynamicsKey: aux_dict_dynamics} | aux_dict_step
171
+ # Make sure that all leafs of the dictionary are JAX arrays.
172
+ # Also, since these are dummy parameters, set them all to zero.
173
+ params_after_init = jax.tree_util.tree_map(
174
+ lambda l: jnp.zeros_like(l), integrator.params
175
+ )
176
+
177
+ # Mark the next step as first step after initialization.
178
+ params_after_init = params_after_init | {
179
+ Integrator.AfterInitKey: jnp.array(True)
180
+ }
181
+
182
+ # Store the zero parameters in the integrator.
183
+ # When the integrator is stepped, this is used to check if the passed
184
+ # parameters are valid.
185
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
186
+ self.params = params_after_init
187
+
188
+ return params_after_init
155
189
 
156
190
 
157
191
  @jax_dataclasses.pytree_dataclass
@@ -209,20 +243,10 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
209
243
  The integrator object.
210
244
  """
211
245
 
212
- # Adjust the shape of the tableau coefficients.
213
- c = jnp.atleast_1d(cls.c.squeeze())
214
- b = jnp.atleast_2d(jnp.vstack(cls.b.squeeze()))
215
- A = jnp.atleast_2d(cls.A.squeeze())
216
-
217
246
  # Check validity of the Butcher tableau.
218
- if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
247
+ if not ExplicitRungeKutta.butcher_tableau_is_valid(A=cls.A, b=cls.b, c=cls.c):
219
248
  raise ValueError("The Butcher tableau of this class is not valid.")
220
249
 
221
- # Store the adjusted shapes of the tableau coefficients.
222
- cls.c = c
223
- cls.b = b
224
- cls.A = A
225
-
226
250
  # Check that b.T has enough rows based on the configured index of the solution.
227
251
  if cls.row_index_of_solution >= cls.b.T.shape[0]:
228
252
  msg = "The index of the solution ({}-th row of `b.T`) is out of range ({})."
@@ -506,3 +530,65 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
506
530
  # possibly intermediate kᵢ derivative).
507
531
  # Note that if multiple rows match (it should not), we return the first match.
508
532
  return True, int(jnp.where(rows_of_A_with_fsal == True)[0].tolist()[0])
533
+
534
+
535
+ class ExplicitRungeKuttaSO3Mixin:
536
+ """
537
+ Mixin class to apply over explicit RK integrators defined on
538
+ `PyTreeType = ODEState` to integrate the quaternion on SO(3).
539
+ """
540
+
541
+ @classmethod
542
+ def integrate_rk_stage(
543
+ cls, x0: js.ode_data.ODEState, t0: Time, dt: TimeStep, k: js.ode_data.ODEState
544
+ ) -> js.ode_data.ODEState:
545
+
546
+ op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
547
+ xf: js.ode_data.ODEState = jax.tree_util.tree_map(op, x0, k)
548
+
549
+ W_Q_B_t0 = x0.physics_model.base_quaternion
550
+ W_ω_WB_t0 = x0.physics_model.base_angular_velocity
551
+
552
+ return xf.replace(
553
+ physics_model=xf.physics_model.replace(
554
+ base_quaternion=Quaternion.integration(
555
+ quaternion=W_Q_B_t0,
556
+ dt=dt,
557
+ omega=W_ω_WB_t0,
558
+ omega_in_body_fixed=False,
559
+ ),
560
+ )
561
+ )
562
+
563
+ @classmethod
564
+ def post_process_state(
565
+ cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
566
+ ) -> js.ode_data.ODEState:
567
+
568
+ # Indices to convert quaternions between serializations.
569
+ to_xyzw = jnp.array([1, 2, 3, 0])
570
+ to_wxyz = jnp.array([3, 0, 1, 2])
571
+
572
+ # Get the initial quaternion.
573
+ W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw(
574
+ xyzw=x0.physics_model.base_quaternion[to_xyzw]
575
+ )
576
+
577
+ # Get the final angular velocity.
578
+ # This is already computed by averaging the kᵢ in RK-based schemes.
579
+ # Therefore, by using the ω at tf, we obtain a RK scheme operating
580
+ # on the SO(3) manifold.
581
+ W_ω_WB_tf = xf.physics_model.base_angular_velocity
582
+
583
+ # Integrate the quaternion on SO(3).
584
+ # Note that we left-multiply with the exponential map since the angular
585
+ # velocity is expressed in the inertial frame.
586
+ W_Q_B_tf = jaxlie.SO3.exp(tangent=dt * W_ω_WB_tf) @ W_Q_B_t0
587
+
588
+ # Replace the quaternion in the final state.
589
+ return xf.replace(
590
+ physics_model=xf.physics_model.replace(
591
+ base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
592
+ ),
593
+ validate=True,
594
+ )
jaxsim/integrators/fixed_step.py CHANGED
@@ -3,14 +3,12 @@ from typing import ClassVar, Generic
3
3
  import jax
4
4
  import jax.numpy as jnp
5
5
  import jax_dataclasses
6
- import jaxlie
7
6
 
8
- from jaxsim.simulation.ode_data import ODEState
7
+ import jaxsim.api as js
9
8
 
10
- from .common import ExplicitRungeKutta, PyTreeType, Time, TimeStep
11
-
12
- ODEStateDerivative = ODEState
9
+ from .common import ExplicitRungeKutta, ExplicitRungeKuttaSO3Mixin, PyTreeType
13
10
 
11
+ ODEStateDerivative = js.ode_data.ODEState
14
12
 
15
13
  # =====================================================
16
14
  # Explicit Runge-Kutta integrators operating on PyTrees
@@ -20,37 +18,23 @@ ODEStateDerivative = ODEState
20
18
  @jax_dataclasses.pytree_dataclass
21
19
  class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
22
20
 
23
- A: ClassVar[jax.typing.ArrayLike] = jnp.array(
24
- [
25
- [0],
26
- ]
27
- ).astype(float)
21
+ A: ClassVar[jax.typing.ArrayLike] = jnp.atleast_2d(0).astype(float)
28
22
 
29
- b: ClassVar[jax.typing.ArrayLike] = (
30
- jnp.array(
31
- [
32
- [1],
33
- ]
34
- )
35
- .astype(float)
36
- .transpose()
37
- )
23
+ b: ClassVar[jax.typing.ArrayLike] = jnp.atleast_2d(1).astype(float).transpose()
38
24
 
39
- c: ClassVar[jax.typing.ArrayLike] = jnp.array(
40
- [0],
41
- ).astype(float)
25
+ c: ClassVar[jax.typing.ArrayLike] = jnp.atleast_1d(0).astype(float)
42
26
 
43
27
  row_index_of_solution: ClassVar[int] = 0
44
28
  order_of_bT_rows: ClassVar[tuple[int, ...]] = (1,)
45
29
 
46
30
 
47
31
  @jax_dataclasses.pytree_dataclass
48
- class Heun(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
32
+ class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
49
33
 
50
34
  A: ClassVar[jax.typing.ArrayLike] = jnp.array(
51
35
  [
52
36
  [0, 0],
53
- [1 / 2, 0],
37
+ [1, 0],
54
38
  ]
55
39
  ).astype(float)
56
40
 
@@ -103,56 +87,16 @@ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
103
87
  # ===============================================================================
104
88
 
105
89
 
106
- class ExplicitRungeKuttaSO3Mixin:
107
- """
108
- Mixin class to apply over explicit RK integrators defined on
109
- `PyTreeType = ODEState` to integrate the quaternion on SO(3).
110
- """
111
-
112
- @classmethod
113
- def post_process_state(
114
- cls, x0: ODEState, t0: Time, xf: ODEState, dt: TimeStep
115
- ) -> ODEState:
116
-
117
- # Indices to convert quaternions between serializations.
118
- to_xyzw = jnp.array([1, 2, 3, 0])
119
- to_wxyz = jnp.array([3, 0, 1, 2])
120
-
121
- # Get the initial quaternion.
122
- W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw(
123
- xyzw=x0.physics_model.base_quaternion[to_xyzw]
124
- )
125
-
126
- # Get the final angular velocity.
127
- # This is already computed by averaging the kᵢ in RK-based schemes.
128
- # Therefore, by using the ω at tf, we obtain a RK scheme operating
129
- # on the SO(3) manifold.
130
- W_ω_WB_tf = xf.physics_model.base_angular_velocity
131
-
132
- # Integrate the quaternion on SO(3).
133
- # Note that we left-multiply with the exponential map since the angular
134
- # velocity is expressed in the inertial frame.
135
- W_Q_B_tf = jaxlie.SO3.exp(tangent=dt * W_ω_WB_tf) @ W_Q_B_t0
136
-
137
- # Replace the quaternion in the final state.
138
- return xf.replace(
139
- physics_model=xf.physics_model.replace(
140
- base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
141
- ),
142
- validate=True,
143
- )
144
-
145
-
146
90
  @jax_dataclasses.pytree_dataclass
147
- class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, Heun[ODEState]):
91
+ class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[js.ode_data.ODEState]):
148
92
  pass
149
93
 
150
94
 
151
95
  @jax_dataclasses.pytree_dataclass
152
- class HeunSO3(ExplicitRungeKuttaSO3Mixin, Heun[ODEState]):
96
+ class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[js.ode_data.ODEState]):
153
97
  pass
154
98
 
155
99
 
156
100
  @jax_dataclasses.pytree_dataclass
157
- class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[ODEState]):
101
+ class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[js.ode_data.ODEState]):
158
102
  pass