mxlpy 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
mxlpy/fit.py CHANGED
@@ -28,12 +28,16 @@ from mxlpy.types import (
 
 __all__ = [
     "InitialGuess",
+    "LossFn",
     "MinimizeFn",
+    "ProtocolResidualFn",
     "ResidualFn",
     "SteadyStateResidualFn",
     "TimeSeriesResidualFn",
+    "rmse",
     "steady_state",
     "time_course",
+    "time_course_over_protocol",
 ]
 
 if TYPE_CHECKING:
@@ -44,6 +48,21 @@ if TYPE_CHECKING:
 type InitialGuess = dict[str, float]
 type ResidualFn = Callable[[Array], float]
 type MinimizeFn = Callable[[ResidualFn, InitialGuess], dict[str, float]]
+type LossFn = Callable[
+    [
+        pd.DataFrame | pd.Series,
+        pd.DataFrame | pd.Series,
+    ],
+    float,
+]
+
+
+def rmse(
+    y_pred: pd.DataFrame | pd.Series,
+    y_true: pd.DataFrame | pd.Series,
+) -> float:
+    """Calculate root mean square error between model and data."""
+    return cast(float, np.sqrt(np.mean(np.square(y_pred - y_true))))
 
 
 class SteadyStateResidualFn(Protocol):
@@ -58,6 +77,7 @@ class SteadyStateResidualFn(Protocol):
         model: Model,
         y0: dict[str, float],
         integrator: IntegratorType,
+        loss_fn: LossFn,
     ) -> float:
         """Calculate residual error between model steady state and experimental data."""
         ...
@@ -75,6 +95,27 @@ class TimeSeriesResidualFn(Protocol):
         model: Model,
         y0: dict[str, float],
         integrator: IntegratorType,
+        loss_fn: LossFn,
+    ) -> float:
+        """Calculate residual error between model time course and experimental data."""
+        ...
+
+
+class ProtocolResidualFn(Protocol):
+    """Protocol for time series residual functions."""
+
+    def __call__(
+        self,
+        par_values: Array,
+        # This will be filled out by partial
+        par_names: list[str],
+        data: pd.DataFrame,
+        model: Model,
+        y0: dict[str, float],
+        integrator: IntegratorType,
+        loss_fn: LossFn,
+        protocol: pd.DataFrame,
+        time_points_per_step: int = 10,
     ) -> float:
         """Calculate residual error between model time course and experimental data."""
         ...
@@ -109,6 +150,7 @@ def _steady_state_residual(
     model: Model,
     y0: dict[str, float] | None,
     integrator: IntegratorType,
+    loss_fn: LossFn,
 ) -> float:
     """Calculate residual error between model steady state and experimental data.
 
@@ -119,6 +161,7 @@ def _steady_state_residual(
         y0: Initial conditions
         par_names: Names of parameters being fit
         integrator: ODE integrator class to use
+        loss_fn: Loss function to use for residual calculation
 
     Returns:
         float: Root mean square error between model and data
@@ -143,9 +186,11 @@ def _steady_state_residual(
     )
     if res is None:
         return cast(float, np.inf)
-    results_ss = res.get_combined()
-    diff = data - results_ss.loc[:, data.index]  # type: ignore
-    return cast(float, np.sqrt(np.mean(np.square(diff))))
+
+    return loss_fn(
+        res.get_combined().loc[:, cast(list, data.index)],
+        data,
+    )
 
 
 def _time_course_residual(
@@ -156,6 +201,53 @@ def _time_course_residual(
     model: Model,
     y0: dict[str, float] | None,
     integrator: IntegratorType,
+    loss_fn: LossFn,
+) -> float:
+    """Calculate residual error between model time course and experimental data.
+
+    Args:
+        par_values: Parameter values to test
+        data: Experimental time course data
+        model: Model instance to simulate
+        y0: Initial conditions
+        par_names: Names of parameters being fit
+        integrator: ODE integrator class to use
+        loss_fn: Loss function to use for residual calculation
+
+    Returns:
+        float: Root mean square error between model and data
+
+    """
+    res = (
+        Simulator(
+            model.update_parameters(dict(zip(par_names, par_values, strict=True))),
+            y0=y0,
+            integrator=integrator,
+        )
+        .simulate_time_course(cast(list, data.index))
+        .get_result()
+    )
+    if res is None:
+        return cast(float, np.inf)
+    results_ss = res.get_combined()
+
+    return loss_fn(
+        results_ss.loc[:, cast(list, data.columns)],
+        data,
+    )
+
+
+def _protocol_residual(
+    par_values: ArrayLike,
+    # This will be filled out by partial
+    par_names: list[str],
+    data: pd.DataFrame,
+    model: Model,
+    y0: dict[str, float] | None,
+    integrator: IntegratorType,
+    loss_fn: LossFn,
+    protocol: pd.DataFrame,
+    time_points_per_step: int = 10,
 ) -> float:
     """Calculate residual error between model time course and experimental data.
 
@@ -166,6 +258,9 @@ def _time_course_residual(
         y0: Initial conditions
         par_names: Names of parameters being fit
         integrator: ODE integrator class to use
+        loss_fn: Loss function to use for residual calculation
+        protocol: Experimental protocol
+        time_points_per_step: Number of time points per step in the protocol
 
     Returns:
         float: Root mean square error between model and data
@@ -177,14 +272,20 @@ def _time_course_residual(
             y0=y0,
             integrator=integrator,
         )
-        .simulate_time_course(data.index)  # type: ignore
+        .simulate_over_protocol(
+            protocol=protocol,
+            time_points_per_step=time_points_per_step,
+        )
         .get_result()
     )
     if res is None:
         return cast(float, np.inf)
     results_ss = res.get_combined()
-    diff = data - results_ss.loc[:, data.columns]  # type: ignore
-    return cast(float, np.sqrt(np.mean(np.square(diff))))
+
+    return loss_fn(
+        results_ss.loc[:, cast(list, data.columns)],
+        data,
+    )
 
 
 def steady_state(
@@ -195,6 +296,7 @@ def steady_state(
     minimize_fn: MinimizeFn = _default_minimize_fn,
     residual_fn: SteadyStateResidualFn = _steady_state_residual,
     integrator: IntegratorType = DefaultIntegrator,
+    loss_fn: LossFn = rmse,
 ) -> dict[str, float]:
     """Fit model parameters to steady-state experimental data.
 
@@ -210,6 +312,7 @@ def steady_state(
         minimize_fn: Function to minimize fitting error
         residual_fn: Function to calculate fitting error
         integrator: ODE integrator class
+        loss_fn: Loss function to use for residual calculation
 
     Returns:
         dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
@@ -232,6 +335,7 @@ def steady_state(
             y0=y0,
             par_names=par_names,
             integrator=integrator,
+            loss_fn=loss_fn,
         ),
     )
     res = minimize_fn(fn, p0)
@@ -249,6 +353,62 @@ def time_course(
     minimize_fn: MinimizeFn = _default_minimize_fn,
     residual_fn: TimeSeriesResidualFn = _time_course_residual,
     integrator: IntegratorType = DefaultIntegrator,
+    loss_fn: LossFn = rmse,
+) -> dict[str, float]:
+    """Fit model parameters to time course of experimental data.
+
+    Examples:
+        >>> time_course(model, p0, data)
+        {'k1': 0.1, 'k2': 0.2}
+
+    Args:
+        model: Model instance to fit
+        data: Experimental time course data
+        p0: Initial parameter guesses as {parameter_name: value}
+        y0: Initial conditions as {species_name: value}
+        minimize_fn: Function to minimize fitting error
+        residual_fn: Function to calculate fitting error
+        integrator: ODE integrator class
+        loss_fn: Loss function to use for residual calculation
+
+    Returns:
+        dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
+
+    Note:
+        Uses L-BFGS-B optimization with bounds [1e-12, 1e6] for all parameters
+
+    """
+    par_names = list(p0.keys())
+    p_orig = model.parameters
+
+    fn = cast(
+        ResidualFn,
+        partial(
+            residual_fn,
+            data=data,
+            model=model,
+            y0=y0,
+            par_names=par_names,
+            integrator=integrator,
+            loss_fn=loss_fn,
+        ),
+    )
+    res = minimize_fn(fn, p0)
+    model.update_parameters(p_orig)
+    return res
+
+
+def time_course_over_protocol(
+    model: Model,
+    p0: dict[str, float],
+    data: pd.DataFrame,
+    protocol: pd.DataFrame,
+    y0: dict[str, float] | None = None,
+    minimize_fn: MinimizeFn = _default_minimize_fn,
+    residual_fn: ProtocolResidualFn = _protocol_residual,
+    integrator: IntegratorType = DefaultIntegrator,
+    loss_fn: LossFn = rmse,
+    time_points_per_step: int = 10,
 ) -> dict[str, float]:
     """Fit model parameters to time course of experimental data.
 
@@ -258,12 +418,15 @@ def time_course(
 
     Args:
         model: Model instance to fit
-        data: Experimental time course data as pandas DataFrame
         p0: Initial parameter guesses as {parameter_name: value}
+        data: Experimental time course data
+        protocol: Experimental protocol
         y0: Initial conditions as {species_name: value}
         minimize_fn: Function to minimize fitting error
        residual_fn: Function to calculate fitting error
         integrator: ODE integrator class
+        loss_fn: Loss function to use for residual calculation
+        time_points_per_step: Number of time points per step in the protocol
 
     Returns:
         dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
@@ -284,6 +447,9 @@ def time_course(
             y0=y0,
             par_names=par_names,
             integrator=integrator,
+            loss_fn=loss_fn,
+            protocol=protocol,
+            time_points_per_step=time_points_per_step,
         ),
     )
     res = minimize_fn(fn, p0)
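
The net effect of these fit.py changes: every residual helper now accepts a loss_fn, the public entry points default it to the new rmse, and time_course_over_protocol/_protocol_residual add fitting against a stepwise protocol via simulate_over_protocol. Any callable matching the LossFn alias can be swapped in; a minimal sketch (assumes mxlpy 0.18.0 is installed; model, data, and protocol are hypothetical placeholders, so the fitting calls are shown commented out):

import numpy as np
import pandas as pd

from mxlpy import fit

def mae(y_pred: pd.DataFrame | pd.Series, y_true: pd.DataFrame | pd.Series) -> float:
    # Any (prediction, truth) -> float callable satisfies fit.LossFn.
    return float(np.mean(np.abs(np.asarray(y_pred - y_true))))

print(mae(pd.Series([1.0, 2.0]), pd.Series([1.5, 1.5])))  # 0.5

# With a real Model and DataFrame in scope, the new keyword threads through
# to the residual functions (placeholders, shown for shape only):
# fit.time_course(model, p0={"k1": 1.0}, data=data, loss_fn=mae)
# fit.time_course_over_protocol(
#     model, p0={"k1": 1.0}, data=data, protocol=protocol, loss_fn=mae
# )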
mxlpy/identify.py CHANGED
@@ -19,6 +19,7 @@ def _mc_fit_time_course_worker(
     p0: pd.Series,
     model: Model,
     data: pd.DataFrame,
+    loss_fn: fit.LossFn,
 ) -> float:
     p_fit = fit.time_course(model=model, p0=p0.to_dict(), data=data)
     return fit._time_course_residual(  # noqa: SLF001
@@ -28,6 +29,7 @@ def _mc_fit_time_course_worker(
         model=model,
         y0=None,
         integrator=fit.DefaultIntegrator,
+        loss_fn=loss_fn,
     )
 
 
@@ -37,6 +39,7 @@ def profile_likelihood(
     parameter_name: str,
     parameter_values: Array,
     n_random: int = 10,
+    loss_fn: fit.LossFn = fit.rmse,
 ) -> pd.Series:
     """Estimate the profile likelihood of model parameters given data.
 
@@ -46,6 +49,7 @@ def profile_likelihood(
         parameter_name: The name of the parameter to profile.
         parameter_values: The values of the parameter to profile.
        n_random: Number of Monte Carlo samples.
+        loss_fn: Loss function to use for fitting.
 
     """
     parameter_distributions = sample(
@@ -57,7 +61,9 @@ def profile_likelihood(
     for value in tqdm(parameter_values, desc=parameter_name):
         model.update_parameter(parameter_name, value)
         res[value] = parallelise(
-            partial(_mc_fit_time_course_worker, model=model, data=data),
+            partial(
+                _mc_fit_time_course_worker, model=model, data=data, loss_fn=loss_fn
+            ),
             inputs=list(
                 parameter_distributions.drop(columns=parameter_name).iterrows()
             ),
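
identify.profile_likelihood gains the same hook, forwarding loss_fn to each Monte Carlo worker. A sketch of the updated call (assumes mxlpy 0.18.0; model and data are placeholders for a Model and experimental DataFrame, so the call is shown commented out):

import numpy as np

from mxlpy import fit, identify

# profile = identify.profile_likelihood(
#     model,
#     data,
#     parameter_name="k1",
#     parameter_values=np.linspace(0.1, 2.0, 20),
#     n_random=10,
#     loss_fn=fit.rmse,  # the default; any fit.LossFn callable works
# )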
mxlpy/nn/_torch.py CHANGED
@@ -8,17 +8,77 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, cast
 
+import numpy as np
+import pandas as pd
 import torch
+import tqdm
 from torch import nn
+from torch.utils.data import DataLoader, TensorDataset
+
+type LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
 
 if TYPE_CHECKING:
     from collections.abc import Callable
 
-__all__ = ["DefaultDevice", "LSTM", "MLP"]
+    from torch.optim.adam import Adam
+
+    from mxlpy.types import Array
+
+__all__ = ["DefaultDevice", "LSTM", "LossFn", "MLP", "train"]
 
 DefaultDevice = torch.device("cpu")
 
 
+def train(
+    aprox: nn.Module,
+    features: Array,
+    targets: Array,
+    epochs: int,
+    optimizer: Adam,
+    device: torch.device,
+    batch_size: int | None,
+    loss_fn: LossFn,
+) -> pd.Series:
+    """Train the neural network using mini-batch gradient descent.
+
+    Args:
+        aprox: Neural network model to train.
+        features: Input features as a tensor.
+        targets: Target values as a tensor.
+        epochs: Number of training epochs.
+        optimizer: Optimizer for training.
+        device: torch device
+        batch_size: Size of mini-batches for training.
+        loss_fn: Loss function
+
+    Returns:
+        pd.Series: Series containing the training loss history.
+
+    """
+    losses = {}
+
+    data = TensorDataset(
+        torch.tensor(features.astype(np.float32), dtype=torch.float32, device=device),
+        torch.tensor(targets.astype(np.float32), dtype=torch.float32, device=device),
+    )
+    data_loader = DataLoader(
+        data,
+        batch_size=len(features) if batch_size is None else batch_size,
+        shuffle=True,
+    )
+
+    for i in tqdm.trange(epochs):
+        epoch_loss = 0
+        for xb, yb in data_loader:
+            optimizer.zero_grad()
+            loss = loss_fn(aprox(xb), yb)
+            loss.backward()
+            optimizer.step()
+            epoch_loss += loss.item() * xb.size(0)
+        losses[i] = epoch_loss / len(data_loader.dataset)  # type: ignore
+    return pd.Series(losses, dtype=float)
+
+
 class MLP(nn.Module):
     """Multilayer Perceptron (MLP) for surrogate modeling and neural posterior estimation.
 
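
The new shared train() helper replaces the previous full-batch/mini-batch split with a single DataLoader-backed loop; batch_size=None becomes one batch spanning the whole dataset. A self-contained toy run (assumes mxlpy 0.18.0 and torch are installed; the data and network are illustrative, not from the package):

import numpy as np
from torch import nn
from torch.optim.adam import Adam

from mxlpy.nn._torch import DefaultDevice, train

rng = np.random.default_rng(seed=0)
features = rng.normal(size=(64, 3))                    # 64 samples, 3 inputs
targets = features @ np.array([[1.0], [2.0], [3.0]])   # linear toy targets

net = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 1))
losses = train(
    aprox=net,
    features=features,
    targets=targets,
    epochs=10,
    optimizer=Adam(net.parameters()),
    device=DefaultDevice,
    batch_size=None,  # None -> a single batch covering all 64 samples
    loss_fn=nn.MSELoss(),
)
print(losses.iloc[-1])  # mean per-sample loss of the final epoch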
mxlpy/npe/_torch.py CHANGED
@@ -18,11 +18,10 @@ from typing import TYPE_CHECKING, Self, cast
 import numpy as np
 import pandas as pd
 import torch
-import tqdm
 from torch import nn
 from torch.optim.adam import Adam
 
-from mxlpy.nn._torch import LSTM, MLP, DefaultDevice
+from mxlpy.nn._torch import LSTM, MLP, DefaultDevice, train
 from mxlpy.parallel import Cache
 from mxlpy.types import AbstractEstimator
 
@@ -161,28 +160,16 @@ class TorchSteadyStateTrainer:
             batch_size: Size of mini-batches for training (None for full-batch)
 
         """
-        features = torch.Tensor(self.features.to_numpy(), device=self.device)
-        targets = torch.Tensor(self.targets.to_numpy(), device=self.device)
-
-        if batch_size is None:
-            losses = _train_full(
-                approximator=self.approximator,
-                features=features,
-                targets=targets,
-                epochs=epochs,
-                optimizer=self.optimizer,
-                loss_fn=self.loss_fn,
-            )
-        else:
-            losses = _train_batched(
-                approximator=self.approximator,
-                features=features,
-                targets=targets,
-                epochs=epochs,
-                optimizer=self.optimizer,
-                batch_size=batch_size,
-                loss_fn=self.loss_fn,
-            )
+        losses = train(
+            aprox=self.approximator,
+            features=self.features.to_numpy(),
+            targets=self.targets.to_numpy(),
+            epochs=epochs,
+            optimizer=self.optimizer,
+            batch_size=batch_size,
+            loss_fn=self.loss_fn,
+            device=self.device,
+        )
 
         if len(self.losses) > 0:
             losses.index += self.losses[-1].index[-1]
@@ -260,37 +247,22 @@ class TorchTimeCourseTrainer:
             batch_size: Size of mini-batches for training (None for full-batch)
 
         """
-        features = torch.Tensor(
-            np.swapaxes(
+        losses = train(
+            aprox=self.approximator,
+            features=np.swapaxes(
                 self.features.to_numpy().reshape(
                     (len(self.targets), -1, len(self.features.columns))
                 ),
                 axis1=0,
                 axis2=1,
             ),
+            targets=self.targets.to_numpy(),
+            epochs=epochs,
+            optimizer=self.optimizer,
+            batch_size=batch_size,
+            loss_fn=self.loss_fn,
             device=self.device,
         )
-        targets = torch.Tensor(self.targets.to_numpy(), device=self.device)
-
-        if batch_size is None:
-            losses = _train_full(
-                approximator=self.approximator,
-                features=features,
-                targets=targets,
-                epochs=epochs,
-                optimizer=self.optimizer,
-                loss_fn=self.loss_fn,
-            )
-        else:
-            losses = _train_batched(
-                approximator=self.approximator,
-                features=features,
-                targets=targets,
-                epochs=epochs,
-                optimizer=self.optimizer,
-                batch_size=batch_size,
-                loss_fn=self.loss_fn,
-            )
 
         if len(self.losses) > 0:
             losses.index += self.losses[-1].index[-1]
@@ -309,49 +281,6 @@ class TorchTimeCourseTrainer:
         )
 
 
-def _train_batched(
-    approximator: nn.Module,
-    features: torch.Tensor,
-    targets: torch.Tensor,
-    epochs: int,
-    optimizer: Adam,
-    batch_size: int,
-    loss_fn: LossFn,
-) -> pd.Series:
-    losses = {}
-    for epoch in tqdm.trange(epochs):
-        permutation = torch.randperm(features.size()[0])
-        epoch_loss = 0
-        for i in range(0, features.size()[0], batch_size):
-            optimizer.zero_grad()
-            indices = permutation[i : i + batch_size]
-            loss = loss_fn(approximator(features[indices]), targets[indices])
-            loss.backward()
-            optimizer.step()
-            epoch_loss += loss.detach().numpy()
-
-        losses[epoch] = epoch_loss / (features.size()[0] / batch_size)
-    return pd.Series(losses, dtype=float)
-
-
-def _train_full(
-    approximator: nn.Module,
-    features: torch.Tensor,
-    targets: torch.Tensor,
-    epochs: int,
-    optimizer: Adam,
-    loss_fn: LossFn,
-) -> pd.Series:
-    losses = {}
-    for i in tqdm.trange(epochs):
-        optimizer.zero_grad()
-        loss = loss_fn(approximator(features), targets)
-        loss.backward()
-        optimizer.step()
-        losses[i] = loss.detach().numpy()
-    return pd.Series(losses, dtype=float)
-
-
 def train_torch_steady_state(
     features: pd.DataFrame,
     targets: pd.DataFrame,
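
Beyond deduplication, the swap changes the loss bookkeeping slightly: _train_batched averaged raw per-batch losses over the batch count, whereas train() weights each batch loss by its size before dividing by the dataset length. With a ragged final batch the two disagree; a quick arithmetic check (numbers invented for illustration):

# 48 samples in batches of 32 and 16, with mean batch losses 0.5 and 0.2
batch_losses = [(32, 0.5), (16, 0.2)]

old_style = sum(loss for _, loss in batch_losses) / (48 / 32)      # ~0.4667
new_style = sum(size * loss for size, loss in batch_losses) / 48   # 0.4
print(old_style, new_style)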