mxlpy 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/fit.py +173 -7
- mxlpy/identify.py +7 -1
- mxlpy/nn/_torch.py +61 -1
- mxlpy/npe/_torch.py +19 -90
- mxlpy/plot.py +194 -50
- mxlpy/surrogates/_torch.py +11 -101
- mxlpy/types.py +0 -3
- {mxlpy-0.17.0.dist-info → mxlpy-0.18.0.dist-info}/METADATA +7 -7
- {mxlpy-0.17.0.dist-info → mxlpy-0.18.0.dist-info}/RECORD +11 -11
- mxlpy-0.18.0.dist-info/licenses/LICENSE +21 -0
- mxlpy-0.17.0.dist-info/licenses/LICENSE +0 -674
- {mxlpy-0.17.0.dist-info → mxlpy-0.18.0.dist-info}/WHEEL +0 -0
mxlpy/fit.py
CHANGED
@@ -28,12 +28,16 @@ from mxlpy.types import (
|
|
28
28
|
|
29
29
|
__all__ = [
|
30
30
|
"InitialGuess",
|
31
|
+
"LossFn",
|
31
32
|
"MinimizeFn",
|
33
|
+
"ProtocolResidualFn",
|
32
34
|
"ResidualFn",
|
33
35
|
"SteadyStateResidualFn",
|
34
36
|
"TimeSeriesResidualFn",
|
37
|
+
"rmse",
|
35
38
|
"steady_state",
|
36
39
|
"time_course",
|
40
|
+
"time_course_over_protocol",
|
37
41
|
]
|
38
42
|
|
39
43
|
if TYPE_CHECKING:
|
@@ -44,6 +48,21 @@ if TYPE_CHECKING:
|
|
44
48
|
type InitialGuess = dict[str, float]
|
45
49
|
type ResidualFn = Callable[[Array], float]
|
46
50
|
type MinimizeFn = Callable[[ResidualFn, InitialGuess], dict[str, float]]
|
51
|
+
type LossFn = Callable[
|
52
|
+
[
|
53
|
+
pd.DataFrame | pd.Series,
|
54
|
+
pd.DataFrame | pd.Series,
|
55
|
+
],
|
56
|
+
float,
|
57
|
+
]
|
58
|
+
|
59
|
+
|
60
|
+
def rmse(
|
61
|
+
y_pred: pd.DataFrame | pd.Series,
|
62
|
+
y_true: pd.DataFrame | pd.Series,
|
63
|
+
) -> float:
|
64
|
+
"""Calculate root mean square error between model and data."""
|
65
|
+
return cast(float, np.sqrt(np.mean(np.square(y_pred - y_true))))
|
47
66
|
|
48
67
|
|
49
68
|
class SteadyStateResidualFn(Protocol):
|
@@ -58,6 +77,7 @@ class SteadyStateResidualFn(Protocol):
|
|
58
77
|
model: Model,
|
59
78
|
y0: dict[str, float],
|
60
79
|
integrator: IntegratorType,
|
80
|
+
loss_fn: LossFn,
|
61
81
|
) -> float:
|
62
82
|
"""Calculate residual error between model steady state and experimental data."""
|
63
83
|
...
|
@@ -75,6 +95,27 @@ class TimeSeriesResidualFn(Protocol):
|
|
75
95
|
model: Model,
|
76
96
|
y0: dict[str, float],
|
77
97
|
integrator: IntegratorType,
|
98
|
+
loss_fn: LossFn,
|
99
|
+
) -> float:
|
100
|
+
"""Calculate residual error between model time course and experimental data."""
|
101
|
+
...
|
102
|
+
|
103
|
+
|
104
|
+
class ProtocolResidualFn(Protocol):
|
105
|
+
"""Protocol for time series residual functions."""
|
106
|
+
|
107
|
+
def __call__(
|
108
|
+
self,
|
109
|
+
par_values: Array,
|
110
|
+
# This will be filled out by partial
|
111
|
+
par_names: list[str],
|
112
|
+
data: pd.DataFrame,
|
113
|
+
model: Model,
|
114
|
+
y0: dict[str, float],
|
115
|
+
integrator: IntegratorType,
|
116
|
+
loss_fn: LossFn,
|
117
|
+
protocol: pd.DataFrame,
|
118
|
+
time_points_per_step: int = 10,
|
78
119
|
) -> float:
|
79
120
|
"""Calculate residual error between model time course and experimental data."""
|
80
121
|
...
|
@@ -109,6 +150,7 @@ def _steady_state_residual(
|
|
109
150
|
model: Model,
|
110
151
|
y0: dict[str, float] | None,
|
111
152
|
integrator: IntegratorType,
|
153
|
+
loss_fn: LossFn,
|
112
154
|
) -> float:
|
113
155
|
"""Calculate residual error between model steady state and experimental data.
|
114
156
|
|
@@ -119,6 +161,7 @@ def _steady_state_residual(
|
|
119
161
|
y0: Initial conditions
|
120
162
|
par_names: Names of parameters being fit
|
121
163
|
integrator: ODE integrator class to use
|
164
|
+
loss_fn: Loss function to use for residual calculation
|
122
165
|
|
123
166
|
Returns:
|
124
167
|
float: Root mean square error between model and data
|
@@ -143,9 +186,11 @@ def _steady_state_residual(
|
|
143
186
|
)
|
144
187
|
if res is None:
|
145
188
|
return cast(float, np.inf)
|
146
|
-
|
147
|
-
|
148
|
-
|
189
|
+
|
190
|
+
return loss_fn(
|
191
|
+
res.get_combined().loc[:, cast(list, data.index)],
|
192
|
+
data,
|
193
|
+
)
|
149
194
|
|
150
195
|
|
151
196
|
def _time_course_residual(
|
@@ -156,6 +201,53 @@ def _time_course_residual(
|
|
156
201
|
model: Model,
|
157
202
|
y0: dict[str, float] | None,
|
158
203
|
integrator: IntegratorType,
|
204
|
+
loss_fn: LossFn,
|
205
|
+
) -> float:
|
206
|
+
"""Calculate residual error between model time course and experimental data.
|
207
|
+
|
208
|
+
Args:
|
209
|
+
par_values: Parameter values to test
|
210
|
+
data: Experimental time course data
|
211
|
+
model: Model instance to simulate
|
212
|
+
y0: Initial conditions
|
213
|
+
par_names: Names of parameters being fit
|
214
|
+
integrator: ODE integrator class to use
|
215
|
+
loss_fn: Loss function to use for residual calculation
|
216
|
+
|
217
|
+
Returns:
|
218
|
+
float: Root mean square error between model and data
|
219
|
+
|
220
|
+
"""
|
221
|
+
res = (
|
222
|
+
Simulator(
|
223
|
+
model.update_parameters(dict(zip(par_names, par_values, strict=True))),
|
224
|
+
y0=y0,
|
225
|
+
integrator=integrator,
|
226
|
+
)
|
227
|
+
.simulate_time_course(cast(list, data.index))
|
228
|
+
.get_result()
|
229
|
+
)
|
230
|
+
if res is None:
|
231
|
+
return cast(float, np.inf)
|
232
|
+
results_ss = res.get_combined()
|
233
|
+
|
234
|
+
return loss_fn(
|
235
|
+
results_ss.loc[:, cast(list, data.columns)],
|
236
|
+
data,
|
237
|
+
)
|
238
|
+
|
239
|
+
|
240
|
+
def _protocol_residual(
|
241
|
+
par_values: ArrayLike,
|
242
|
+
# This will be filled out by partial
|
243
|
+
par_names: list[str],
|
244
|
+
data: pd.DataFrame,
|
245
|
+
model: Model,
|
246
|
+
y0: dict[str, float] | None,
|
247
|
+
integrator: IntegratorType,
|
248
|
+
loss_fn: LossFn,
|
249
|
+
protocol: pd.DataFrame,
|
250
|
+
time_points_per_step: int = 10,
|
159
251
|
) -> float:
|
160
252
|
"""Calculate residual error between model time course and experimental data.
|
161
253
|
|
@@ -166,6 +258,9 @@ def _time_course_residual(
|
|
166
258
|
y0: Initial conditions
|
167
259
|
par_names: Names of parameters being fit
|
168
260
|
integrator: ODE integrator class to use
|
261
|
+
loss_fn: Loss function to use for residual calculation
|
262
|
+
protocol: Experimental protocol
|
263
|
+
time_points_per_step: Number of time points per step in the protocol
|
169
264
|
|
170
265
|
Returns:
|
171
266
|
float: Root mean square error between model and data
|
@@ -177,14 +272,20 @@ def _time_course_residual(
|
|
177
272
|
y0=y0,
|
178
273
|
integrator=integrator,
|
179
274
|
)
|
180
|
-
.
|
275
|
+
.simulate_over_protocol(
|
276
|
+
protocol=protocol,
|
277
|
+
time_points_per_step=time_points_per_step,
|
278
|
+
)
|
181
279
|
.get_result()
|
182
280
|
)
|
183
281
|
if res is None:
|
184
282
|
return cast(float, np.inf)
|
185
283
|
results_ss = res.get_combined()
|
186
|
-
|
187
|
-
return
|
284
|
+
|
285
|
+
return loss_fn(
|
286
|
+
results_ss.loc[:, cast(list, data.columns)],
|
287
|
+
data,
|
288
|
+
)
|
188
289
|
|
189
290
|
|
190
291
|
def steady_state(
|
@@ -195,6 +296,7 @@ def steady_state(
|
|
195
296
|
minimize_fn: MinimizeFn = _default_minimize_fn,
|
196
297
|
residual_fn: SteadyStateResidualFn = _steady_state_residual,
|
197
298
|
integrator: IntegratorType = DefaultIntegrator,
|
299
|
+
loss_fn: LossFn = rmse,
|
198
300
|
) -> dict[str, float]:
|
199
301
|
"""Fit model parameters to steady-state experimental data.
|
200
302
|
|
@@ -210,6 +312,7 @@ def steady_state(
|
|
210
312
|
minimize_fn: Function to minimize fitting error
|
211
313
|
residual_fn: Function to calculate fitting error
|
212
314
|
integrator: ODE integrator class
|
315
|
+
loss_fn: Loss function to use for residual calculation
|
213
316
|
|
214
317
|
Returns:
|
215
318
|
dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
|
@@ -232,6 +335,7 @@ def steady_state(
|
|
232
335
|
y0=y0,
|
233
336
|
par_names=par_names,
|
234
337
|
integrator=integrator,
|
338
|
+
loss_fn=loss_fn,
|
235
339
|
),
|
236
340
|
)
|
237
341
|
res = minimize_fn(fn, p0)
|
@@ -249,6 +353,62 @@ def time_course(
|
|
249
353
|
minimize_fn: MinimizeFn = _default_minimize_fn,
|
250
354
|
residual_fn: TimeSeriesResidualFn = _time_course_residual,
|
251
355
|
integrator: IntegratorType = DefaultIntegrator,
|
356
|
+
loss_fn: LossFn = rmse,
|
357
|
+
) -> dict[str, float]:
|
358
|
+
"""Fit model parameters to time course of experimental data.
|
359
|
+
|
360
|
+
Examples:
|
361
|
+
>>> time_course(model, p0, data)
|
362
|
+
{'k1': 0.1, 'k2': 0.2}
|
363
|
+
|
364
|
+
Args:
|
365
|
+
model: Model instance to fit
|
366
|
+
data: Experimental time course data
|
367
|
+
p0: Initial parameter guesses as {parameter_name: value}
|
368
|
+
y0: Initial conditions as {species_name: value}
|
369
|
+
minimize_fn: Function to minimize fitting error
|
370
|
+
residual_fn: Function to calculate fitting error
|
371
|
+
integrator: ODE integrator class
|
372
|
+
loss_fn: Loss function to use for residual calculation
|
373
|
+
|
374
|
+
Returns:
|
375
|
+
dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
|
376
|
+
|
377
|
+
Note:
|
378
|
+
Uses L-BFGS-B optimization with bounds [1e-12, 1e6] for all parameters
|
379
|
+
|
380
|
+
"""
|
381
|
+
par_names = list(p0.keys())
|
382
|
+
p_orig = model.parameters
|
383
|
+
|
384
|
+
fn = cast(
|
385
|
+
ResidualFn,
|
386
|
+
partial(
|
387
|
+
residual_fn,
|
388
|
+
data=data,
|
389
|
+
model=model,
|
390
|
+
y0=y0,
|
391
|
+
par_names=par_names,
|
392
|
+
integrator=integrator,
|
393
|
+
loss_fn=loss_fn,
|
394
|
+
),
|
395
|
+
)
|
396
|
+
res = minimize_fn(fn, p0)
|
397
|
+
model.update_parameters(p_orig)
|
398
|
+
return res
|
399
|
+
|
400
|
+
|
401
|
+
def time_course_over_protocol(
|
402
|
+
model: Model,
|
403
|
+
p0: dict[str, float],
|
404
|
+
data: pd.DataFrame,
|
405
|
+
protocol: pd.DataFrame,
|
406
|
+
y0: dict[str, float] | None = None,
|
407
|
+
minimize_fn: MinimizeFn = _default_minimize_fn,
|
408
|
+
residual_fn: ProtocolResidualFn = _protocol_residual,
|
409
|
+
integrator: IntegratorType = DefaultIntegrator,
|
410
|
+
loss_fn: LossFn = rmse,
|
411
|
+
time_points_per_step: int = 10,
|
252
412
|
) -> dict[str, float]:
|
253
413
|
"""Fit model parameters to time course of experimental data.
|
254
414
|
|
@@ -258,12 +418,15 @@ def time_course(
|
|
258
418
|
|
259
419
|
Args:
|
260
420
|
model: Model instance to fit
|
261
|
-
data: Experimental time course data as pandas DataFrame
|
262
421
|
p0: Initial parameter guesses as {parameter_name: value}
|
422
|
+
data: Experimental time course data
|
423
|
+
protocol: Experimental protocol
|
263
424
|
y0: Initial conditions as {species_name: value}
|
264
425
|
minimize_fn: Function to minimize fitting error
|
265
426
|
residual_fn: Function to calculate fitting error
|
266
427
|
integrator: ODE integrator class
|
428
|
+
loss_fn: Loss function to use for residual calculation
|
429
|
+
time_points_per_step: Number of time points per step in the protocol
|
267
430
|
|
268
431
|
Returns:
|
269
432
|
dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
|
@@ -284,6 +447,9 @@ def time_course(
|
|
284
447
|
y0=y0,
|
285
448
|
par_names=par_names,
|
286
449
|
integrator=integrator,
|
450
|
+
loss_fn=loss_fn,
|
451
|
+
protocol=protocol,
|
452
|
+
time_points_per_step=time_points_per_step,
|
287
453
|
),
|
288
454
|
)
|
289
455
|
res = minimize_fn(fn, p0)
|
mxlpy/identify.py
CHANGED
@@ -19,6 +19,7 @@ def _mc_fit_time_course_worker(
|
|
19
19
|
p0: pd.Series,
|
20
20
|
model: Model,
|
21
21
|
data: pd.DataFrame,
|
22
|
+
loss_fn: fit.LossFn,
|
22
23
|
) -> float:
|
23
24
|
p_fit = fit.time_course(model=model, p0=p0.to_dict(), data=data)
|
24
25
|
return fit._time_course_residual( # noqa: SLF001
|
@@ -28,6 +29,7 @@ def _mc_fit_time_course_worker(
|
|
28
29
|
model=model,
|
29
30
|
y0=None,
|
30
31
|
integrator=fit.DefaultIntegrator,
|
32
|
+
loss_fn=loss_fn,
|
31
33
|
)
|
32
34
|
|
33
35
|
|
@@ -37,6 +39,7 @@ def profile_likelihood(
|
|
37
39
|
parameter_name: str,
|
38
40
|
parameter_values: Array,
|
39
41
|
n_random: int = 10,
|
42
|
+
loss_fn: fit.LossFn = fit.rmse,
|
40
43
|
) -> pd.Series:
|
41
44
|
"""Estimate the profile likelihood of model parameters given data.
|
42
45
|
|
@@ -46,6 +49,7 @@ def profile_likelihood(
|
|
46
49
|
parameter_name: The name of the parameter to profile.
|
47
50
|
parameter_values: The values of the parameter to profile.
|
48
51
|
n_random: Number of Monte Carlo samples.
|
52
|
+
loss_fn: Loss function to use for fitting.
|
49
53
|
|
50
54
|
"""
|
51
55
|
parameter_distributions = sample(
|
@@ -57,7 +61,9 @@ def profile_likelihood(
|
|
57
61
|
for value in tqdm(parameter_values, desc=parameter_name):
|
58
62
|
model.update_parameter(parameter_name, value)
|
59
63
|
res[value] = parallelise(
|
60
|
-
partial(
|
64
|
+
partial(
|
65
|
+
_mc_fit_time_course_worker, model=model, data=data, loss_fn=loss_fn
|
66
|
+
),
|
61
67
|
inputs=list(
|
62
68
|
parameter_distributions.drop(columns=parameter_name).iterrows()
|
63
69
|
),
|
mxlpy/nn/_torch.py
CHANGED
@@ -8,17 +8,77 @@ from __future__ import annotations
|
|
8
8
|
|
9
9
|
from typing import TYPE_CHECKING, cast
|
10
10
|
|
11
|
+
import numpy as np
|
12
|
+
import pandas as pd
|
11
13
|
import torch
|
14
|
+
import tqdm
|
12
15
|
from torch import nn
|
16
|
+
from torch.utils.data import DataLoader, TensorDataset
|
17
|
+
|
18
|
+
type LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
|
13
19
|
|
14
20
|
if TYPE_CHECKING:
|
15
21
|
from collections.abc import Callable
|
16
22
|
|
17
|
-
|
23
|
+
from torch.optim.adam import Adam
|
24
|
+
|
25
|
+
from mxlpy.types import Array
|
26
|
+
|
27
|
+
__all__ = ["DefaultDevice", "LSTM", "LossFn", "MLP", "train"]
|
18
28
|
|
19
29
|
DefaultDevice = torch.device("cpu")
|
20
30
|
|
21
31
|
|
32
|
+
def train(
|
33
|
+
aprox: nn.Module,
|
34
|
+
features: Array,
|
35
|
+
targets: Array,
|
36
|
+
epochs: int,
|
37
|
+
optimizer: Adam,
|
38
|
+
device: torch.device,
|
39
|
+
batch_size: int | None,
|
40
|
+
loss_fn: LossFn,
|
41
|
+
) -> pd.Series:
|
42
|
+
"""Train the neural network using mini-batch gradient descent.
|
43
|
+
|
44
|
+
Args:
|
45
|
+
aprox: Neural network model to train.
|
46
|
+
features: Input features as a tensor.
|
47
|
+
targets: Target values as a tensor.
|
48
|
+
epochs: Number of training epochs.
|
49
|
+
optimizer: Optimizer for training.
|
50
|
+
device: torch device
|
51
|
+
batch_size: Size of mini-batches for training.
|
52
|
+
loss_fn: Loss function
|
53
|
+
|
54
|
+
Returns:
|
55
|
+
pd.Series: Series containing the training loss history.
|
56
|
+
|
57
|
+
"""
|
58
|
+
losses = {}
|
59
|
+
|
60
|
+
data = TensorDataset(
|
61
|
+
torch.tensor(features.astype(np.float32), dtype=torch.float32, device=device),
|
62
|
+
torch.tensor(targets.astype(np.float32), dtype=torch.float32, device=device),
|
63
|
+
)
|
64
|
+
data_loader = DataLoader(
|
65
|
+
data,
|
66
|
+
batch_size=len(features) if batch_size is None else batch_size,
|
67
|
+
shuffle=True,
|
68
|
+
)
|
69
|
+
|
70
|
+
for i in tqdm.trange(epochs):
|
71
|
+
epoch_loss = 0
|
72
|
+
for xb, yb in data_loader:
|
73
|
+
optimizer.zero_grad()
|
74
|
+
loss = loss_fn(aprox(xb), yb)
|
75
|
+
loss.backward()
|
76
|
+
optimizer.step()
|
77
|
+
epoch_loss += loss.item() * xb.size(0)
|
78
|
+
losses[i] = epoch_loss / len(data_loader.dataset) # type: ignore
|
79
|
+
return pd.Series(losses, dtype=float)
|
80
|
+
|
81
|
+
|
22
82
|
class MLP(nn.Module):
|
23
83
|
"""Multilayer Perceptron (MLP) for surrogate modeling and neural posterior estimation.
|
24
84
|
|
mxlpy/npe/_torch.py
CHANGED
@@ -18,11 +18,10 @@ from typing import TYPE_CHECKING, Self, cast
|
|
18
18
|
import numpy as np
|
19
19
|
import pandas as pd
|
20
20
|
import torch
|
21
|
-
import tqdm
|
22
21
|
from torch import nn
|
23
22
|
from torch.optim.adam import Adam
|
24
23
|
|
25
|
-
from mxlpy.nn._torch import LSTM, MLP, DefaultDevice
|
24
|
+
from mxlpy.nn._torch import LSTM, MLP, DefaultDevice, train
|
26
25
|
from mxlpy.parallel import Cache
|
27
26
|
from mxlpy.types import AbstractEstimator
|
28
27
|
|
@@ -161,28 +160,16 @@ class TorchSteadyStateTrainer:
|
|
161
160
|
batch_size: Size of mini-batches for training (None for full-batch)
|
162
161
|
|
163
162
|
"""
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
loss_fn=self.loss_fn,
|
175
|
-
)
|
176
|
-
else:
|
177
|
-
losses = _train_batched(
|
178
|
-
approximator=self.approximator,
|
179
|
-
features=features,
|
180
|
-
targets=targets,
|
181
|
-
epochs=epochs,
|
182
|
-
optimizer=self.optimizer,
|
183
|
-
batch_size=batch_size,
|
184
|
-
loss_fn=self.loss_fn,
|
185
|
-
)
|
163
|
+
losses = train(
|
164
|
+
aprox=self.approximator,
|
165
|
+
features=self.features.to_numpy(),
|
166
|
+
targets=self.targets.to_numpy(),
|
167
|
+
epochs=epochs,
|
168
|
+
optimizer=self.optimizer,
|
169
|
+
batch_size=batch_size,
|
170
|
+
loss_fn=self.loss_fn,
|
171
|
+
device=self.device,
|
172
|
+
)
|
186
173
|
|
187
174
|
if len(self.losses) > 0:
|
188
175
|
losses.index += self.losses[-1].index[-1]
|
@@ -260,37 +247,22 @@ class TorchTimeCourseTrainer:
|
|
260
247
|
batch_size: Size of mini-batches for training (None for full-batch)
|
261
248
|
|
262
249
|
"""
|
263
|
-
|
264
|
-
|
250
|
+
losses = train(
|
251
|
+
aprox=self.approximator,
|
252
|
+
features=np.swapaxes(
|
265
253
|
self.features.to_numpy().reshape(
|
266
254
|
(len(self.targets), -1, len(self.features.columns))
|
267
255
|
),
|
268
256
|
axis1=0,
|
269
257
|
axis2=1,
|
270
258
|
),
|
259
|
+
targets=self.targets.to_numpy(),
|
260
|
+
epochs=epochs,
|
261
|
+
optimizer=self.optimizer,
|
262
|
+
batch_size=batch_size,
|
263
|
+
loss_fn=self.loss_fn,
|
271
264
|
device=self.device,
|
272
265
|
)
|
273
|
-
targets = torch.Tensor(self.targets.to_numpy(), device=self.device)
|
274
|
-
|
275
|
-
if batch_size is None:
|
276
|
-
losses = _train_full(
|
277
|
-
approximator=self.approximator,
|
278
|
-
features=features,
|
279
|
-
targets=targets,
|
280
|
-
epochs=epochs,
|
281
|
-
optimizer=self.optimizer,
|
282
|
-
loss_fn=self.loss_fn,
|
283
|
-
)
|
284
|
-
else:
|
285
|
-
losses = _train_batched(
|
286
|
-
approximator=self.approximator,
|
287
|
-
features=features,
|
288
|
-
targets=targets,
|
289
|
-
epochs=epochs,
|
290
|
-
optimizer=self.optimizer,
|
291
|
-
batch_size=batch_size,
|
292
|
-
loss_fn=self.loss_fn,
|
293
|
-
)
|
294
266
|
|
295
267
|
if len(self.losses) > 0:
|
296
268
|
losses.index += self.losses[-1].index[-1]
|
@@ -309,49 +281,6 @@ class TorchTimeCourseTrainer:
|
|
309
281
|
)
|
310
282
|
|
311
283
|
|
312
|
-
def _train_batched(
|
313
|
-
approximator: nn.Module,
|
314
|
-
features: torch.Tensor,
|
315
|
-
targets: torch.Tensor,
|
316
|
-
epochs: int,
|
317
|
-
optimizer: Adam,
|
318
|
-
batch_size: int,
|
319
|
-
loss_fn: LossFn,
|
320
|
-
) -> pd.Series:
|
321
|
-
losses = {}
|
322
|
-
for epoch in tqdm.trange(epochs):
|
323
|
-
permutation = torch.randperm(features.size()[0])
|
324
|
-
epoch_loss = 0
|
325
|
-
for i in range(0, features.size()[0], batch_size):
|
326
|
-
optimizer.zero_grad()
|
327
|
-
indices = permutation[i : i + batch_size]
|
328
|
-
loss = loss_fn(approximator(features[indices]), targets[indices])
|
329
|
-
loss.backward()
|
330
|
-
optimizer.step()
|
331
|
-
epoch_loss += loss.detach().numpy()
|
332
|
-
|
333
|
-
losses[epoch] = epoch_loss / (features.size()[0] / batch_size)
|
334
|
-
return pd.Series(losses, dtype=float)
|
335
|
-
|
336
|
-
|
337
|
-
def _train_full(
|
338
|
-
approximator: nn.Module,
|
339
|
-
features: torch.Tensor,
|
340
|
-
targets: torch.Tensor,
|
341
|
-
epochs: int,
|
342
|
-
optimizer: Adam,
|
343
|
-
loss_fn: LossFn,
|
344
|
-
) -> pd.Series:
|
345
|
-
losses = {}
|
346
|
-
for i in tqdm.trange(epochs):
|
347
|
-
optimizer.zero_grad()
|
348
|
-
loss = loss_fn(approximator(features), targets)
|
349
|
-
loss.backward()
|
350
|
-
optimizer.step()
|
351
|
-
losses[i] = loss.detach().numpy()
|
352
|
-
return pd.Series(losses, dtype=float)
|
353
|
-
|
354
|
-
|
355
284
|
def train_torch_steady_state(
|
356
285
|
features: pd.DataFrame,
|
357
286
|
targets: pd.DataFrame,
|