openscvx 0.1.3__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

@@ -17,7 +17,30 @@ def dVdt(
17
17
  n_u: int,
18
18
  N: int,
19
19
  dis_type: str,
20
+ **params
20
21
  ) -> jnp.ndarray:
22
+ """Compute the time derivative of the augmented state vector.
23
+
24
+ This function computes the time derivative of the augmented state vector V,
25
+ which includes the state, state transition matrix, and control influence matrix.
26
+
27
+ Args:
28
+ tau (float): Current normalized time in [0,1].
29
+ V (jnp.ndarray): Augmented state vector.
30
+ u_cur (np.ndarray): Control input at current node.
31
+ u_next (np.ndarray): Control input at next node.
32
+ state_dot (callable): Function computing state derivatives.
33
+ A (callable): Function computing state Jacobian.
34
+ B (callable): Function computing control Jacobian.
35
+ n_x (int): Number of states.
36
+ n_u (int): Number of controls.
37
+ N (int): Number of nodes in trajectory.
38
+ dis_type (str): Discretization type ("ZOH" or "FOH").
39
+ **params: Additional parameters passed to state_dot, A, and B.
40
+
41
+ Returns:
42
+ jnp.ndarray: Time derivative of augmented state vector.
43
+ """
21
44
  # Define the nodes
22
45
  nodes = jnp.arange(0, N-1)
23
46
 
@@ -52,15 +75,15 @@ def dVdt(
52
75
  u = u[: x.shape[0]]
53
76
 
54
77
  # Compute the nonlinear propagation term
55
- f = state_dot(x, u[:, :-1], nodes)
78
+ f = state_dot(x, u[:, :-1], nodes, *params.items())
56
79
  F = s[:, None] * f
57
80
 
58
81
  # Evaluate the State Jacobian
59
- dfdx = A(x, u[:, :-1], nodes)
82
+ dfdx = A(x, u[:, :-1], nodes, *params.items())
60
83
  sdfdx = s[:, None, None] * dfdx
61
84
 
62
85
  # Evaluate the Control Jacobian
63
- dfdu_veh = B(x, u[:, :-1], nodes)
86
+ dfdu_veh = B(x, u[:, :-1], nodes, *params.items())
64
87
  dfdu = dfdu.at[:, :, :-1].set(s[:, None, None] * dfdu_veh)
65
88
  dfdu = dfdu.at[:, :, -1].set(f)
66
89
 
@@ -76,7 +99,8 @@ def dVdt(
76
99
  dVdt = dVdt.at[:, i3:i4].set((jnp.matmul(sdfdx, V[:, i3:i4].reshape(-1, n_x, n_u)) + dfdu * beta).reshape(-1, n_x * n_u))
77
100
  dVdt = dVdt.at[:, i4:i5].set((jnp.matmul(sdfdx, V[:, i4:i5].reshape(-1, n_x)[..., None]).squeeze(-1) + z).reshape(-1, n_x))
78
101
  # fmt: on
79
- return dVdt.flatten()
102
+
103
+ return dVdt.reshape(-1)
80
104
 
81
105
 
82
106
  def calculate_discretization(
@@ -94,8 +118,38 @@ def calculate_discretization(
94
118
  rtol,
95
119
  atol,
96
120
  dis_type: str,
121
+ **kwargs
97
122
  ):
98
-
123
+ """Calculate the discretized system matrices.
124
+
125
+ This function computes the discretized system matrices (A_bar, B_bar, C_bar)
126
+ and defect vector (z_bar) using numerical integration.
127
+
128
+ Args:
129
+ x: State trajectory.
130
+ u: Control trajectory.
131
+ state_dot (callable): Function computing state derivatives.
132
+ A (callable): Function computing state Jacobian.
133
+ B (callable): Function computing control Jacobian.
134
+ n_x (int): Number of states.
135
+ n_u (int): Number of controls.
136
+ N (int): Number of nodes in trajectory.
137
+ custom_integrator (bool): Whether to use custom RK45 integrator.
138
+ debug (bool): Whether to use debug mode.
139
+ solver (str): Name of the solver to use.
140
+ rtol (float): Relative tolerance for integration.
141
+ atol (float): Absolute tolerance for integration.
142
+ dis_type (str): Discretization type ("ZOH" or "FOH").
143
+ **kwargs: Additional parameters passed to state_dot, A, and B.
144
+
145
+ Returns:
146
+ tuple: (A_bar, B_bar, C_bar, z_bar, Vmulti) where:
147
+ - A_bar: Discretized state transition matrix
148
+ - B_bar: Discretized control influence matrix
149
+ - C_bar: Discretized control influence matrix for next node
150
+ - z_bar: Defect vector
151
+ - Vmulti: Full augmented state trajectory
152
+ """
99
153
  # Define indices for slicing the augmented state vector
100
154
  i0 = 0
101
155
  i1 = n_x
@@ -104,67 +158,91 @@ def calculate_discretization(
104
158
  i4 = i3 + n_x * n_u
105
159
  i5 = i4 + n_x
106
160
 
107
- # initial augmented state
161
+ # Initial augmented state
108
162
  V0 = jnp.zeros((N - 1, i5))
109
163
  V0 = V0.at[:, :n_x].set(x[:-1].astype(float))
110
- V0 = V0.at[:, n_x : n_x + n_x * n_x].set(
164
+ V0 = V0.at[:, n_x:n_x + n_x * n_x].set(
111
165
  jnp.eye(n_x).reshape(1, -1).repeat(N - 1, axis=0)
112
166
  )
113
167
 
114
- # choose integrator
168
+ # Bundle the arguments that are forwarded to dVdt by name
169
+ integrator_args = dict(
170
+ u_cur=u[:-1].astype(float),
171
+ u_next=u[1:].astype(float),
172
+ state_dot=state_dot,
173
+ A=A,
174
+ B=B,
175
+ n_x=n_x,
176
+ n_u=n_u,
177
+ N=N,
178
+ dis_type=dis_type,
179
+ **kwargs # <-- adds parameter values with names
180
+ )
181
+
182
+ # Define dVdt wrapper using named arguments
183
+ def dVdt_wrapped(t, y):
184
+ return dVdt(t, y, **integrator_args)
185
+
186
+ # Choose integrator
115
187
  if custom_integrator:
116
- # fmt: off
117
188
  sol = solve_ivp_rk45(
118
- lambda t,y,*a: dVdt(t, y, *a),
119
- 1.0/(N-1),
189
+ dVdt_wrapped,
190
+ 1.0 / (N - 1),
120
191
  V0.reshape(-1),
121
- args=(u[:-1].astype(float), u[1:].astype(float),
122
- state_dot, A, B, n_x, n_u, N, dis_type),
192
+ args=(),
123
193
  is_not_compiled=debug,
124
194
  )
125
- # fmt: on
126
195
  else:
127
- # fmt: off
128
196
  sol = solve_ivp_diffrax(
129
- lambda t,y,*a: dVdt(t, y, *a),
130
- 1.0/(N-1),
197
+ dVdt_wrapped,
198
+ 1.0 / (N - 1),
131
199
  V0.reshape(-1),
132
- args=(u[:-1].astype(float), u[1:].astype(float),
133
- state_dot, A, B, n_x, n_u, N, dis_type),
134
200
  solver_name=solver,
135
201
  rtol=rtol,
136
202
  atol=atol,
203
+ args=(),
137
204
  extra_kwargs=None,
138
205
  )
139
- # fmt: on
140
206
 
141
207
  Vend = sol[-1].T.reshape(-1, i5)
142
208
  Vmulti = sol.T
143
209
 
144
- # fmt: off
145
- A_bar = Vend[:, i1:i2].reshape(N-1, n_x, n_x).transpose(1,2,0).reshape(n_x*n_x, -1, order='F').T
146
- B_bar = Vend[:, i2:i3].reshape(N-1, n_x, n_u).transpose(1,2,0).reshape(n_x*n_u, -1, order='F').T
147
- C_bar = Vend[:, i3:i4].reshape(N-1, n_x, n_u).transpose(1,2,0).reshape(n_x*n_u, -1, order='F').T
210
+ A_bar = Vend[:, i1:i2].reshape(N - 1, n_x, n_x).transpose(1, 2, 0).reshape(n_x * n_x, -1, order='F').T
211
+ B_bar = Vend[:, i2:i3].reshape(N - 1, n_x, n_u).transpose(1, 2, 0).reshape(n_x * n_u, -1, order='F').T
212
+ C_bar = Vend[:, i3:i4].reshape(N - 1, n_x, n_u).transpose(1, 2, 0).reshape(n_x * n_u, -1, order='F').T
148
213
  z_bar = Vend[:, i4:i5]
149
- # fmt: on
150
214
 
151
215
  return A_bar, B_bar, C_bar, z_bar, Vmulti
152
216
 
153
217
 
154
- def get_discretization_solver(dyn: Dynamics, params):
155
- return lambda x, u: calculate_discretization(
218
+ def get_discretization_solver(dyn: Dynamics, settings, param_map):
219
+ """Create a discretization solver function.
220
+
221
+ This function creates a solver that computes the discretized system matrices
222
+ using the specified dynamics and settings.
223
+
224
+ Args:
225
+ dyn (Dynamics): System dynamics object.
226
+ settings: Configuration settings for discretization.
227
+ param_map (dict): Mapping of parameter names to values.
228
+
229
+ Returns:
230
+ callable: A function that computes the discretized system matrices.
231
+ """
232
+ return lambda x, u, *params: calculate_discretization(
156
233
  x=x,
157
234
  u=u,
158
235
  state_dot=dyn.f,
159
236
  A=dyn.A,
160
237
  B=dyn.B,
161
- n_x=params.sim.n_states,
162
- n_u=params.sim.n_controls,
163
- N=params.scp.n,
164
- custom_integrator=params.dis.custom_integrator,
165
- debug=params.dev.debug,
166
- solver=params.dis.solver,
167
- rtol=params.dis.rtol,
168
- atol=params.dis.atol,
169
- dis_type=params.dis.dis_type,
170
- )
238
+ n_x=settings.sim.n_states,
239
+ n_u=settings.sim.n_controls,
240
+ N=settings.scp.n,
241
+ custom_integrator=settings.dis.custom_integrator,
242
+ debug=settings.dev.debug,
243
+ solver=settings.dis.solver,
244
+ rtol=settings.dis.rtol,
245
+ atol=settings.dis.atol,
246
+ dis_type=settings.dis.dis_type,
247
+ **dict(zip(param_map.keys(), params)) # <--- Named keyword args
248
+ )
openscvx/dynamics.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from dataclasses import dataclass
2
- from typing import Callable, Optional
2
+ from typing import Callable, Optional, Union
3
3
  import functools
4
4
 
5
5
  import jax.numpy as jnp
@@ -7,25 +7,165 @@ import jax.numpy as jnp
7
7
 
8
8
  @dataclass
9
9
  class Dynamics:
10
+ """
11
+ Dataclass to hold a system dynamics function and (optionally) its gradients.
12
+ This class is intended to be instantiated using the `dynamics` decorator wrapped around a function defining the system dynamics.
13
+ Both the dynamics and optional gradients should be composed of `jax` primitives to enable efficient computation.
14
+
15
+ Usage examples:
16
+
17
+ ```python
18
+ @dynamics
19
+ def f(x_, u_):
20
+ return x_ + u_
21
+ # f is now a Dynamics object
22
+ ```
23
+
24
+ ```python
25
+ @dynamics(A=grad_f_x, B=grad_f_u)
26
+ def f(x_, u_):
27
+ return x_ + u_
28
+ ```
29
+
30
+ Or, if a more lambda-function-style is desired, the function can be directly wrapped:
31
+
32
+ ```python
33
+ dyn = dynamics(lambda x_, u_: x_ + u_)
34
+ ```
35
+
36
+ ---
37
+ **Using Parameters in Dynamics**
38
+
39
+ You can use symbolic `Parameter` objects in your dynamics function to represent tunable or environment-dependent values. **The argument names for parameters must match the parameter name with an underscore suffix** (e.g., `I_sp_` for a parameter named `I_sp`). This is required for the parameter mapping to work correctly.
40
+
41
+ Example (3DoF rocket landing):
42
+
43
+ ```python
44
+ from openscvx.backend.parameter import Parameter
45
+ import jax.numpy as jnp
46
+
47
+ I_sp = Parameter("I_sp")
48
+ g = Parameter("g")
49
+ theta = Parameter("theta")
50
+
51
+ @dynamics
52
+ def rocket_dynamics(x_, u_, I_sp_, g_, theta_):
53
+ m = x_[6]
54
+ T = u_
55
+ r_dot = x_[3:6]
56
+ g_vec = jnp.array([0, 0, g_])
57
+ v_dot = T/m - g_vec
58
+ m_dot = -jnp.linalg.norm(T) / (I_sp_ * 9.807 * jnp.cos(theta_))
59
+ t_dot = 1
60
+ return jnp.hstack([r_dot, v_dot, m_dot, t_dot])
61
+
62
+ # Set parameter values before solving
63
+ I_sp.value = 225
64
+ g.value = 3.7114
65
+ theta.value = 27 * jnp.pi / 180
66
+ ```
67
+
68
+ ---
69
+ **Using Parameters in Nodal Constraints**
70
+
71
+ You can also use symbolic `Parameter` objects in nodal constraints. As with dynamics, the argument names for parameters in the constraint function must match the parameter name with an underscore suffix (e.g., `g_` for a parameter named `g`).
72
+
73
+ Example:
74
+
75
+ ```python
76
+ from openscvx.backend.parameter import Parameter
77
+ from openscvx.constraints import nodal
78
+ import jax.numpy as jnp
79
+
80
+ g = Parameter("g")
81
+ g.value = 3.7114
82
+
83
+ @nodal
84
+ def terminal_velocity_constraint(x_, u_, g_):
85
+ # Enforce a terminal velocity constraint using the gravity parameter
86
+ return x_[5] + g_ * x_[7] # e.g., vz + g * t <= 0 at final node
87
+ ```
88
+
89
+ When building your problem, collect all parameters with `Parameter.get_all()` and pass them to your problem setup.
90
+
91
+ Args:
92
+ f (Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
93
+ Function defining the continuous time nonlinear system dynamics as x_dot = f(x, u, ...params).
94
+ - x: 1D array (state at a single node), shape (n_x,)
95
+ - u: 1D array (control at a single node), shape (n_u,)
96
+ - Additional parameters: passed as keyword arguments with names matching the parameter name plus an underscore (e.g., g_ for Parameter('g')).
97
+ If you want to use parameters, include them as extra arguments with the underscore naming convention.
98
+ If you use vectorized integration or batch evaluation, x and u may be 2D arrays (N, n_x) and (N, n_u).
99
+ A (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
100
+ Jacobian of `f` w.r.t. `x`. If not specified, will be calculated using `jax.jacfwd`.
101
+ B (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
102
+ Jacobian of `f` w.r.t. `u`. If not specified, will be calculated using `jax.jacfwd`.
103
+
104
+ Returns:
105
+ Dynamics: A dataclass bundling the system dynamics function and Jacobians.
106
+ """
107
+
10
108
  f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
11
109
  A: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
12
110
  B: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
13
111
 
112
+
14
113
  def dynamics(
15
- _func=None,
114
+ _func: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
16
115
  *,
17
- A: Optional[Callable] = None,
18
- B: Optional[Callable] = None,):
19
- """Decorator to mark a function as defining the system dynamics.
116
+ A: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
117
+ B: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
118
+ ) -> Union[Callable, Dynamics]:
119
+ """
120
+ Decorator that wraps a function defining the system dynamics as a `Dynamics` object.
121
+ You may optionally specify the system gradients w.r.t. `x`, `u` if desired, if not specified they will be calculated using `jax.jacfwd`.
122
+ Note: the dynamics as well as the optional gradients should be composed of `jax` primitives to enable efficient computation.
123
+
124
+ This decorator may be used with or without arguments:
20
125
 
21
- Use as:
22
- @dynamics(A=my_grad_f_x, B=my_grad_f_u)')
23
- def my_dynamics(x,u): ...
126
+ ```
127
+ @dynamics
128
+ def f(x, u): ...
129
+ ```
130
+
131
+ or
132
+
133
+ ```
134
+ @dynamics(A=grad_f_x, B=grad_f_u)
135
+ def f(x, u): ...
136
+ ```
137
+
138
+ or, if a more lambda-function-style is desired, the function can be directly wrapped
139
+
140
+ ```
141
+ dyn = dynamics(f)
142
+ dyn_lambda = dynamics(lambda x, u: ...)
143
+ ```
144
+
145
+ Args:
146
+ _func (callable, optional): The function to wrap. Populated
147
+ when using @dynamics with no extra args.
148
+ A (callable, optional): Jacobian of f wrt state x. Computed
149
+ via jax.jacfwd if not provided.
150
+ B (callable, optional): Jacobian of f wrt input u. Computed
151
+ via jax.jacfwd if not provided.
152
+
153
+ Returns:
154
+ Union[Callable, Dynamics]
155
+ A decorator if called without a function, or a `Dynamics` dataclass bundling system dynamics function
156
+ and Jacobians when applied to a function.
157
+
158
+ Examples:
159
+ >>> @dynamics
160
+ ... def f(x, u):
161
+ ... return x + u
162
+ >>> isinstance(f, Dynamics)
163
+ True
24
164
  """
25
165
 
26
166
  def decorator(f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
27
- # wrap so name, doc, signature stay on f
28
- wrapped = functools.wraps(f)(f)
167
+ # Had to unwrap to ensure arguments are visible downstream. Originally wrapped so name, doc, signature stay on f
168
+ wrapped = f
29
169
  return Dynamics(
30
170
  f=wrapped,
31
171
  A=A,
@@ -38,4 +178,3 @@ def dynamics(
38
178
  # if called as dynamics(func), we immediately decorate
39
179
  else:
40
180
  return decorator(_func)
41
-
openscvx/integrators.py CHANGED
@@ -1,7 +1,52 @@
1
+ import os
2
+ os.environ["EQX_ON_ERROR"] = "nan"
3
+
4
+
5
+ from typing import Callable, Any, Optional
6
+
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import diffrax as dfx
11
+ from typing import Callable, Any
12
+
1
13
  import jax
2
14
  import jax.numpy as jnp
3
15
  import diffrax as dfx
4
16
 
17
+ from diffrax._global_interpolation import DenseInterpolation
18
+ from jax import tree_util
19
+
20
+ # Safely check if DenseInterpolation is already registered
21
+ try:
22
+ # Attempt to flatten a dummy DenseInterpolation instance
23
+ # Provide dummy arguments to create a valid instance
24
+ dummy_instance = DenseInterpolation(
25
+ ts=jnp.array([]),
26
+ ts_size=0,
27
+ infos=None,
28
+ interpolation_cls=None,
29
+ direction=None,
30
+ t0_if_trivial=0.0,
31
+ y0_if_trivial=jnp.array([]),
32
+ )
33
+ tree_util.tree_flatten(dummy_instance)
34
+ except ValueError:
35
+ # Register DenseInterpolation as a PyTree node if not already registered
36
+ def dense_interpolation_flatten(obj):
37
+ # Flatten the internal data of DenseInterpolation
38
+ return (obj._data,), None
39
+
40
+ def dense_interpolation_unflatten(aux_data, children):
41
+ # Reconstruct DenseInterpolation from its flattened data
42
+ return DenseInterpolation(*children)
43
+
44
+ tree_util.register_pytree_node(
45
+ DenseInterpolation,
46
+ dense_interpolation_flatten,
47
+ dense_interpolation_unflatten,
48
+ )
49
+
5
50
  SOLVER_MAP = {
6
51
  "Tsit5": dfx.Tsit5,
7
52
  "Euler": dfx.Euler,
@@ -19,7 +64,30 @@ SOLVER_MAP = {
19
64
  }
20
65
 
21
66
  # fmt: off
22
- def rk45_step(f, t, y, h, *args):
67
+ def rk45_step(
68
+ f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
69
+ t: jnp.ndarray,
70
+ y: jnp.ndarray,
71
+ h: float,
72
+ *args
73
+ ) -> jnp.ndarray:
74
+ """
75
+ Perform a single RK45 (Runge-Kutta-Fehlberg) integration step.
76
+
77
+ This implements the classic Fehlberg coefficients for an
78
+ explicit 4(5) method, returning the fourth-order estimate.
79
+
80
+ Args:
81
+ f (Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]):
82
+ ODE right-hand side; signature f(t, y, *args) -> dy/dt.
83
+ t (jnp.ndarray): Current time.
84
+ y (jnp.ndarray): Current state vector.
85
+ h (float): Step size.
86
+ *args: Additional arguments passed to `f`.
87
+
88
+ Returns:
89
+ jnp.ndarray: Next state estimate at t + h.
90
+ """
23
91
  k1 = f(t, y, *args)
24
92
  k2 = f(t + h/4, y + h*k1/4, *args)
25
93
  k3 = f(t + 3*h/8, y + 3*h*k1/32 + 9*h*k2/32, *args)
@@ -31,14 +99,31 @@ def rk45_step(f, t, y, h, *args):
31
99
 
32
100
 
33
101
  def solve_ivp_rk45(
34
- f,
102
+ f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
35
103
  tau_final: float,
36
- y_0,
104
+ y_0: jnp.ndarray,
37
105
  args,
38
106
  tau_0: float = 0.0,
39
107
  num_substeps: int = 50,
40
108
  is_not_compiled: bool = False,
41
109
  ):
110
+ """
111
+ Solve an initial-value ODE problem using fixed-step RK45 integration.
112
+
113
+ Args:
114
+ f (Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]):
115
+ ODE right-hand side; signature f(t, y, *args) -> dy/dt.
116
+ tau_final (float): Final integration time.
117
+ y_0 (jnp.ndarray): Initial state at tau_0.
118
+ args (tuple): Extra arguments to pass to `f`.
119
+ tau_0 (float, optional): Initial time. Defaults to 0.0.
120
+ num_substeps (int, optional): Number of output time points. Defaults to 50.
121
+ is_not_compiled (bool, optional): If True, use Python loop instead of
122
+ JAX `lax.fori_loop`. Defaults to False.
123
+
124
+ Returns:
125
+ jnp.ndarray: Array of shape (num_substeps, state_dim) with solution at each time.
126
+ """
42
127
  substeps = jnp.linspace(tau_0, tau_final, num_substeps)
43
128
 
44
129
  h = (tau_final - tau_0) / (len(substeps) - 1)
@@ -65,17 +150,40 @@ def solve_ivp_rk45(
65
150
 
66
151
 
67
152
  def solve_ivp_diffrax(
68
- f,
69
- tau_final,
70
- y_0,
153
+ f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
154
+ tau_final: float,
155
+ y_0: jnp.ndarray,
71
156
  args,
72
157
  tau_0: float = 0.0,
73
158
  num_substeps: int = 50,
74
- solver_name="Dopri8",
159
+ solver_name: str = "Dopri8",
75
160
  rtol: float = 1e-3,
76
161
  atol: float = 1e-6,
77
162
  extra_kwargs=None,
78
163
  ):
164
+ """
165
+ Solve an initial-value ODE problem using a Diffrax adaptive solver.
166
+
167
+ Args:
168
+ f (Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]): ODE right-hand side; signature f(t, y, *args) -> dy/dt.
169
+ tau_final (float): Final integration time.
170
+ y_0 (jnp.ndarray): Initial state at tau_0.
171
+ args (tuple): Extra arguments to pass to `f` in the solver term.
172
+ tau_0 (float, optional): Initial time. Defaults to 0.0.
173
+ num_substeps (int, optional): Number of save points between tau_0 and tau_final.
174
+ Defaults to 50.
175
+ solver_name (str, optional): Key into SOLVER_MAP for the Diffrax solver class.
176
+ Defaults to "Dopri8".
177
+ rtol (float, optional): Relative tolerance for adaptive stepping. Defaults to 1e-3.
178
+ atol (float, optional): Absolute tolerance for adaptive stepping. Defaults to 1e-6.
179
+ extra_kwargs (dict, optional): Additional keyword arguments forwarded to `diffeqsolve`.
180
+
181
+ Returns:
182
+ jnp.ndarray: Solution states at the requested save points, shape (num_substeps, state_dim).
183
+
184
+ Raises:
185
+ ValueError: If `solver_name` is not in SOLVER_MAP.
186
+ """
79
187
  substeps = jnp.linspace(tau_0, tau_final, num_substeps)
80
188
 
81
189
  solver_class = SOLVER_MAP.get(solver_name)
@@ -95,26 +203,31 @@ def solve_ivp_diffrax(
95
203
  args=args,
96
204
  stepsize_controller=stepsize_controller,
97
205
  saveat=dfx.SaveAt(ts=substeps),
206
+ progress_meter=dfx.NoProgressMeter(),
98
207
  **(extra_kwargs or {}),
99
208
  )
100
209
 
101
210
  return solution.ys
102
211
 
103
212
 
104
- # TODO: (norrisg) this function is basically identical to `solve_ivp_diffrax`, could combine, but requires returning solution and getting `.ys` wherever the `solve_ivp_diffrax` is called
105
213
  def solve_ivp_diffrax_prop(
106
- f,
107
- tau_final,
108
- y_0,
214
+ f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
215
+ tau_final: float,
216
+ y_0: jnp.ndarray,
109
217
  args,
110
218
  tau_0: float = 0.0,
111
219
  num_substeps: int = 50,
112
- solver_name="Dopri8",
220
+ solver_name: str = "Dopri8",
113
221
  rtol: float = 1e-3,
114
222
  atol: float = 1e-6,
115
223
  extra_kwargs=None,
224
+ save_time: jnp.ndarray = None,
225
+ mask: jnp.ndarray = None,
116
226
  ):
117
- substeps = jnp.linspace(tau_0, tau_final, num_substeps)
227
+ if save_time is None:
228
+ raise ValueError("save_time must be provided for export compatibility.")
229
+ if mask is None:
230
+ mask = jnp.ones_like(save_time, dtype=bool)
118
231
 
119
232
  solver_class = SOLVER_MAP.get(solver_name)
120
233
  if solver_class is None:
@@ -123,17 +236,23 @@ def solve_ivp_diffrax_prop(
123
236
 
124
237
  term = dfx.ODETerm(lambda t, y, args: f(t, y, *args))
125
238
  stepsize_controller = dfx.PIDController(rtol=rtol, atol=atol)
239
+
126
240
  solution = dfx.diffeqsolve(
127
241
  term,
128
242
  solver=solver,
129
243
  t0=tau_0,
130
244
  t1=tau_final,
131
- dt0=(tau_final - tau_0) / (len(substeps) - 1),
245
+ dt0=(tau_final - tau_0) / 1,
132
246
  y0=y_0,
133
247
  args=args,
134
248
  stepsize_controller=stepsize_controller,
135
- saveat=dfx.SaveAt(dense=True, ts=substeps),
249
+ saveat=dfx.SaveAt(dense=True),
136
250
  **(extra_kwargs or {}),
137
251
  )
138
252
 
139
- return solution
253
+ # Evaluate all save_time points (static size), then mask them
254
+ all_evals = jax.vmap(solution.evaluate)(save_time) # shape: (MAX_TAU_LEN, n_states)
255
+ masked_array = jnp.where(mask[:, None], all_evals, jnp.zeros_like(all_evals))
256
+ # shape is unchanged (same as all_evals); rows where mask is False are zero-filled
257
+
258
+ return masked_array