jinns 0.8.6__py3-none-any.whl → 0.8.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/loss/_LossODE.py +6 -0
- jinns/loss/_LossPDE.py +18 -18
- jinns/solver/_solve.py +264 -121
- jinns/utils/_containers.py +57 -0
- jinns/validation/__init__.py +1 -0
- jinns/validation/_validation.py +214 -0
- {jinns-0.8.6.dist-info → jinns-0.8.7.dist-info}/METADATA +1 -1
- {jinns-0.8.6.dist-info → jinns-0.8.7.dist-info}/RECORD +11 -8
- {jinns-0.8.6.dist-info → jinns-0.8.7.dist-info}/LICENSE +0 -0
- {jinns-0.8.6.dist-info → jinns-0.8.7.dist-info}/WHEEL +0 -0
- {jinns-0.8.6.dist-info → jinns-0.8.7.dist-info}/top_level.txt +0 -0
jinns/loss/_LossODE.py
CHANGED
|
@@ -19,6 +19,8 @@ from jinns.loss._Losses import (
|
|
|
19
19
|
)
|
|
20
20
|
from jinns.utils._pinn import PINN
|
|
21
21
|
|
|
22
|
+
_LOSS_WEIGHT_KEYS_ODE = ["observations", "dyn_loss", "initial_condition"]
|
|
23
|
+
|
|
22
24
|
|
|
23
25
|
@register_pytree_node_class
|
|
24
26
|
class LossODE:
|
|
@@ -128,6 +130,10 @@ class LossODE:
|
|
|
128
130
|
if self.obs_slice is None:
|
|
129
131
|
self.obs_slice = jnp.s_[...]
|
|
130
132
|
|
|
133
|
+
for k in _LOSS_WEIGHT_KEYS_ODE:
|
|
134
|
+
if k not in self.loss_weights.keys():
|
|
135
|
+
self.loss_weights[k] = 0
|
|
136
|
+
|
|
131
137
|
def __call__(self, *args, **kwargs):
|
|
132
138
|
return self.evaluate(*args, **kwargs)
|
|
133
139
|
|
jinns/loss/_LossPDE.py
CHANGED
|
@@ -31,6 +31,16 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
|
|
|
31
31
|
"vonneumann",
|
|
32
32
|
]
|
|
33
33
|
|
|
34
|
+
_LOSS_WEIGHT_KEYS_PDESTATIO = [
|
|
35
|
+
"sobolev",
|
|
36
|
+
"observations",
|
|
37
|
+
"norm_loss",
|
|
38
|
+
"boundary_loss",
|
|
39
|
+
"dyn_loss",
|
|
40
|
+
]
|
|
41
|
+
|
|
42
|
+
_LOSS_WEIGHT_KEYS_PDENONSTATIO = _LOSS_WEIGHT_KEYS_PDESTATIO + ["initial_condition"]
|
|
43
|
+
|
|
34
44
|
|
|
35
45
|
@register_pytree_node_class
|
|
36
46
|
class LossPDEAbstract:
|
|
@@ -269,8 +279,8 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
269
279
|
the PINN object
|
|
270
280
|
loss_weights
|
|
271
281
|
a dictionary with values used to ponderate each term in the loss
|
|
272
|
-
function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss
|
|
273
|
-
and `
|
|
282
|
+
function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`,
|
|
283
|
+
`observations` and `sobolev`.
|
|
274
284
|
Note that we can have jnp.arrays with the same dimension of
|
|
275
285
|
`u` which then ponderates each output of `u`
|
|
276
286
|
dynamic_loss
|
|
@@ -441,18 +451,12 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
441
451
|
) # we return a function, that way
|
|
442
452
|
# the order of sobolev_m is static and the conditional in the recursive
|
|
443
453
|
# function is properly set
|
|
444
|
-
self.sobolev_m = self.sobolev_m
|
|
445
454
|
else:
|
|
446
455
|
self.sobolev_reg = None
|
|
447
456
|
|
|
448
|
-
|
|
449
|
-
self.loss_weights
|
|
450
|
-
|
|
451
|
-
if self.omega_boundary_fun is None:
|
|
452
|
-
self.loss_weights["boundary_loss"] = 0
|
|
453
|
-
|
|
454
|
-
if self.sobolev_reg is None:
|
|
455
|
-
self.loss_weights["sobolev"] = 0
|
|
457
|
+
for k in _LOSS_WEIGHT_KEYS_PDESTATIO:
|
|
458
|
+
if k not in self.loss_weights.keys():
|
|
459
|
+
self.loss_weights[k] = 0
|
|
456
460
|
|
|
457
461
|
if (
|
|
458
462
|
isinstance(self.omega_boundary_fun, dict)
|
|
@@ -533,7 +537,6 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
533
537
|
)
|
|
534
538
|
else:
|
|
535
539
|
mse_norm_loss = jnp.array(0.0)
|
|
536
|
-
self.loss_weights["norm_loss"] = 0
|
|
537
540
|
|
|
538
541
|
# boundary part
|
|
539
542
|
params_ = _set_derivatives(params, "boundary_loss", self.derivative_keys)
|
|
@@ -567,7 +570,6 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
567
570
|
)
|
|
568
571
|
else:
|
|
569
572
|
mse_observation_loss = jnp.array(0.0)
|
|
570
|
-
self.loss_weights["observations"] = 0
|
|
571
573
|
|
|
572
574
|
# Sobolev regularization
|
|
573
575
|
params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
|
|
@@ -582,7 +584,6 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
582
584
|
)
|
|
583
585
|
else:
|
|
584
586
|
mse_sobolev_loss = jnp.array(0.0)
|
|
585
|
-
self.loss_weights["sobolev"] = 0
|
|
586
587
|
|
|
587
588
|
# total loss
|
|
588
589
|
total_loss = (
|
|
@@ -785,8 +786,9 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
785
786
|
else:
|
|
786
787
|
self.sobolev_reg = None
|
|
787
788
|
|
|
788
|
-
|
|
789
|
-
self.loss_weights
|
|
789
|
+
for k in _LOSS_WEIGHT_KEYS_PDENONSTATIO:
|
|
790
|
+
if k not in self.loss_weights.keys():
|
|
791
|
+
self.loss_weights[k] = 0
|
|
790
792
|
|
|
791
793
|
def __call__(self, *args, **kwargs):
|
|
792
794
|
return self.evaluate(*args, **kwargs)
|
|
@@ -924,7 +926,6 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
924
926
|
)
|
|
925
927
|
else:
|
|
926
928
|
mse_observation_loss = jnp.array(0.0)
|
|
927
|
-
self.loss_weights["observations"] = 0
|
|
928
929
|
|
|
929
930
|
# Sobolev regularization
|
|
930
931
|
params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
|
|
@@ -939,7 +940,6 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
939
940
|
)
|
|
940
941
|
else:
|
|
941
942
|
mse_sobolev_loss = jnp.array(0.0)
|
|
942
|
-
self.loss_weights["sobolev"] = 0.0
|
|
943
943
|
|
|
944
944
|
# total loss
|
|
945
945
|
total_loss = (
|
jinns/solver/_solve.py
CHANGED
|
@@ -3,9 +3,9 @@ This modules implements the main `solve()` function of jinns which
|
|
|
3
3
|
handles the optimization process
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
+
import copy
|
|
7
|
+
from functools import partial
|
|
6
8
|
import optax
|
|
7
|
-
from tqdm import tqdm
|
|
8
|
-
from jax_tqdm import scan_tqdm
|
|
9
9
|
import jax
|
|
10
10
|
from jax import jit
|
|
11
11
|
import jax.numpy as jnp
|
|
@@ -19,8 +19,32 @@ from jinns.data._DataGenerators import (
|
|
|
19
19
|
append_param_batch,
|
|
20
20
|
append_obs_batch,
|
|
21
21
|
)
|
|
22
|
+
from jinns.utils._containers import *
|
|
22
23
|
|
|
23
|
-
|
|
24
|
+
|
|
25
|
+
def check_batch_size(other_data, main_data, attr_name):
|
|
26
|
+
if (
|
|
27
|
+
(
|
|
28
|
+
isinstance(main_data, DataGeneratorODE)
|
|
29
|
+
and getattr(other_data, attr_name) != main_data.temporal_batch_size
|
|
30
|
+
)
|
|
31
|
+
or (
|
|
32
|
+
isinstance(main_data, CubicMeshPDEStatio)
|
|
33
|
+
and not isinstance(main_data, CubicMeshPDENonStatio)
|
|
34
|
+
and getattr(other_data, attr_name) != main_data.omega_batch_size
|
|
35
|
+
)
|
|
36
|
+
or (
|
|
37
|
+
isinstance(main_data, CubicMeshPDENonStatio)
|
|
38
|
+
and getattr(other_data, attr_name)
|
|
39
|
+
!= main_data.omega_batch_size * main_data.temporal_batch_size
|
|
40
|
+
)
|
|
41
|
+
):
|
|
42
|
+
raise ValueError(
|
|
43
|
+
"Optional other_data.param_batch_size must be"
|
|
44
|
+
" equal to main_data.temporal_batch_size or main_data.omega_batch_size or"
|
|
45
|
+
" the product of both dependeing on the type of the main"
|
|
46
|
+
" datagenerator"
|
|
47
|
+
)
|
|
24
48
|
|
|
25
49
|
|
|
26
50
|
def solve(
|
|
@@ -35,6 +59,7 @@ def solve(
|
|
|
35
59
|
tracked_params_key_list=None,
|
|
36
60
|
param_data=None,
|
|
37
61
|
obs_data=None,
|
|
62
|
+
validation=None,
|
|
38
63
|
obs_batch_sharding=None,
|
|
39
64
|
):
|
|
40
65
|
"""
|
|
@@ -88,6 +113,20 @@ def solve(
|
|
|
88
113
|
obs_data
|
|
89
114
|
Default None. A DataGeneratorObservations object which can be used to
|
|
90
115
|
sample minibatches of observations
|
|
116
|
+
validation
|
|
117
|
+
Default None. Otherwise, a callable ``eqx.Module`` which implements a
|
|
118
|
+
validation strategy. See documentation of :obj:`~jinns.validation.
|
|
119
|
+
_validation.AbstractValidationModule` for the general interface, and
|
|
120
|
+
:obj:`~jinns.validation._validation.ValidationLoss` for a practical
|
|
121
|
+
implementation of a vanilla validation strategy on a validation set of
|
|
122
|
+
collocation points.
|
|
123
|
+
|
|
124
|
+
**Note**: The ``__call__(self, params)`` method should have
|
|
125
|
+
the latter prescribed signature and return ``(validation [eqx.Module],
|
|
126
|
+
early_stop [bool], validation_criterion [Array])``. It is called every
|
|
127
|
+
``validation.call_every`` iteration. Users are free to design any
|
|
128
|
+
validation strategy of their choice, and to decide on the early
|
|
129
|
+
stopping criterion.
|
|
91
130
|
obs_batch_sharding
|
|
92
131
|
Default None. An optional sharding object to constrain the obs_batch.
|
|
93
132
|
Typically, a SingleDeviceSharding(gpu_device) when obs_data has been
|
|
@@ -114,39 +153,14 @@ def solve(
|
|
|
114
153
|
A dictionary. At each key an array of the values of the parameters
|
|
115
154
|
given in tracked_params_key_list is stored
|
|
116
155
|
"""
|
|
117
|
-
params = init_params
|
|
118
|
-
last_non_nan_params = init_params.copy()
|
|
119
|
-
|
|
120
156
|
if param_data is not None:
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
and obs_data.obs_batch_size != data.temporal_batch_size
|
|
126
|
-
)
|
|
127
|
-
or (
|
|
128
|
-
isinstance(data, CubicMeshPDEStatio)
|
|
129
|
-
and not isinstance(data, CubicMeshPDENonStatio)
|
|
130
|
-
and param_data.param_batch_size != data.omega_batch_size
|
|
131
|
-
and obs_data.obs_batch_size != data.omega_batch_size
|
|
132
|
-
)
|
|
133
|
-
or (
|
|
134
|
-
isinstance(data, CubicMeshPDENonStatio)
|
|
135
|
-
and param_data.param_batch_size
|
|
136
|
-
!= data.omega_batch_size * data.temporal_batch_size
|
|
137
|
-
and obs_data.obs_batch_size
|
|
138
|
-
!= data.omega_batch_size * data.temporal_batch_size
|
|
139
|
-
)
|
|
140
|
-
):
|
|
141
|
-
raise ValueError(
|
|
142
|
-
"Optional param_data.param_batch_size must be"
|
|
143
|
-
" equal to data.temporal_batch_size or data.omega_batch_size or"
|
|
144
|
-
" the product of both dependeing on the type of the main"
|
|
145
|
-
" datagenerator"
|
|
146
|
-
)
|
|
157
|
+
check_batch_size(param_data, data, "param_batch_size")
|
|
158
|
+
|
|
159
|
+
if obs_data is not None:
|
|
160
|
+
check_batch_size(obs_data, data, "obs_batch_size")
|
|
147
161
|
|
|
148
162
|
if opt_state is None:
|
|
149
|
-
opt_state = optimizer.init(
|
|
163
|
+
opt_state = optimizer.init(init_params)
|
|
150
164
|
|
|
151
165
|
# RAR sampling init (outside scanned function to avoid dynamic slice error)
|
|
152
166
|
# If RAR is not used the _rar_step_*() are just None and data is unchanged
|
|
@@ -160,7 +174,7 @@ def solve(
|
|
|
160
174
|
), "data.method must be uniform if using seq2seq learning !"
|
|
161
175
|
data, opt_state = initialize_seq2seq(loss, data, seq2seq, opt_state)
|
|
162
176
|
|
|
163
|
-
|
|
177
|
+
train_loss_values = jnp.zeros((n_iter))
|
|
164
178
|
# depending on obs_batch_sharding we will get the simple get_batch or the
|
|
165
179
|
# get_batch with device_put, the latter is not jittable
|
|
166
180
|
get_batch = get_get_batch(obs_batch_sharding)
|
|
@@ -168,68 +182,125 @@ def solve(
|
|
|
168
182
|
# initialize the dict for stored parameter values
|
|
169
183
|
# we need to get a loss_term to init stuff
|
|
170
184
|
batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
|
|
171
|
-
_, loss_terms = loss(
|
|
185
|
+
_, loss_terms = loss(init_params, batch_ini)
|
|
172
186
|
if tracked_params_key_list is None:
|
|
173
187
|
tracked_params_key_list = []
|
|
174
|
-
tracked_params = _tracked_parameters(
|
|
188
|
+
tracked_params = _tracked_parameters(init_params, tracked_params_key_list)
|
|
175
189
|
stored_params = jax.tree_util.tree_map(
|
|
176
190
|
lambda tracked_param, param: (
|
|
177
191
|
jnp.zeros((n_iter,) + param.shape) if tracked_param else None
|
|
178
192
|
),
|
|
179
193
|
tracked_params,
|
|
180
|
-
|
|
194
|
+
init_params,
|
|
181
195
|
)
|
|
182
196
|
|
|
183
197
|
# initialize the dict for stored loss values
|
|
184
198
|
stored_loss_terms = jax.tree_util.tree_map(
|
|
185
|
-
lambda
|
|
199
|
+
lambda _: jnp.zeros((n_iter)), loss_terms
|
|
200
|
+
)
|
|
201
|
+
|
|
202
|
+
train_data = DataGeneratorContainer(
|
|
203
|
+
data=data, param_data=param_data, obs_data=obs_data
|
|
204
|
+
)
|
|
205
|
+
optimization = OptimizationContainer(
|
|
206
|
+
params=init_params, last_non_nan_params=init_params.copy(), opt_state=opt_state
|
|
207
|
+
)
|
|
208
|
+
optimization_extra = OptimizationExtraContainer(
|
|
209
|
+
curr_seq=curr_seq,
|
|
210
|
+
seq2seq=seq2seq,
|
|
186
211
|
)
|
|
212
|
+
loss_container = LossContainer(
|
|
213
|
+
stored_loss_terms=stored_loss_terms,
|
|
214
|
+
train_loss_values=train_loss_values,
|
|
215
|
+
)
|
|
216
|
+
stored_objects = StoredObjectContainer(
|
|
217
|
+
stored_params=stored_params,
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
if validation is not None:
|
|
221
|
+
validation_crit_values = jnp.zeros(n_iter)
|
|
222
|
+
else:
|
|
223
|
+
validation_crit_values = None
|
|
224
|
+
|
|
225
|
+
break_fun = get_break_fun(n_iter)
|
|
187
226
|
|
|
227
|
+
iteration = 0
|
|
188
228
|
carry = (
|
|
189
|
-
|
|
190
|
-
init_params.copy(),
|
|
191
|
-
data,
|
|
192
|
-
curr_seq,
|
|
193
|
-
seq2seq,
|
|
194
|
-
stored_params,
|
|
195
|
-
stored_loss_terms,
|
|
229
|
+
iteration,
|
|
196
230
|
loss,
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
231
|
+
optimization,
|
|
232
|
+
optimization_extra,
|
|
233
|
+
train_data,
|
|
234
|
+
validation,
|
|
235
|
+
loss_container,
|
|
236
|
+
stored_objects,
|
|
237
|
+
validation_crit_values,
|
|
201
238
|
)
|
|
202
239
|
|
|
203
|
-
def one_iteration(carry
|
|
240
|
+
def one_iteration(carry):
|
|
204
241
|
(
|
|
205
|
-
|
|
206
|
-
last_non_nan_params,
|
|
207
|
-
data,
|
|
208
|
-
curr_seq,
|
|
209
|
-
seq2seq,
|
|
210
|
-
stored_params,
|
|
211
|
-
stored_loss_terms,
|
|
242
|
+
i,
|
|
212
243
|
loss,
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
244
|
+
optimization,
|
|
245
|
+
optimization_extra,
|
|
246
|
+
train_data,
|
|
247
|
+
validation,
|
|
248
|
+
loss_container,
|
|
249
|
+
stored_objects,
|
|
250
|
+
validation_crit_values,
|
|
217
251
|
) = carry
|
|
218
|
-
batch, data, param_data, obs_data = get_batch(data, param_data, obs_data)
|
|
219
252
|
|
|
253
|
+
batch, data, param_data, obs_data = get_batch(
|
|
254
|
+
train_data.data, train_data.param_data, train_data.obs_data
|
|
255
|
+
)
|
|
256
|
+
|
|
257
|
+
# Gradient step
|
|
220
258
|
(
|
|
221
259
|
loss,
|
|
222
|
-
|
|
260
|
+
train_loss_value,
|
|
223
261
|
loss_terms,
|
|
224
262
|
params,
|
|
225
263
|
opt_state,
|
|
226
264
|
last_non_nan_params,
|
|
227
265
|
) = gradient_step(
|
|
228
|
-
loss,
|
|
266
|
+
loss,
|
|
267
|
+
optimizer,
|
|
268
|
+
batch,
|
|
269
|
+
optimization.params,
|
|
270
|
+
optimization.opt_state,
|
|
271
|
+
optimization.last_non_nan_params,
|
|
229
272
|
)
|
|
230
273
|
|
|
231
|
-
# Print loss during optimization
|
|
232
|
-
print_fn(i,
|
|
274
|
+
# Print train loss value during optimization
|
|
275
|
+
print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")
|
|
276
|
+
|
|
277
|
+
if validation is not None:
|
|
278
|
+
# there is a jax.lax.cond because we do not necessarily call the
|
|
279
|
+
# validation step every iteration
|
|
280
|
+
(
|
|
281
|
+
validation, # always return `validation` for in-place mutation
|
|
282
|
+
early_stopping,
|
|
283
|
+
validation_criterion,
|
|
284
|
+
) = jax.lax.cond(
|
|
285
|
+
i % validation.call_every == 0,
|
|
286
|
+
lambda operands: operands[0](*operands[1:]), # validation.__call__()
|
|
287
|
+
lambda operands: (
|
|
288
|
+
operands[0],
|
|
289
|
+
False,
|
|
290
|
+
validation_crit_values[i - 1],
|
|
291
|
+
),
|
|
292
|
+
(
|
|
293
|
+
validation, # validation must be in operands
|
|
294
|
+
params,
|
|
295
|
+
),
|
|
296
|
+
)
|
|
297
|
+
# Print validation loss value during optimization
|
|
298
|
+
print_fn(i, validation_criterion, print_loss_every, prefix="[validation] ")
|
|
299
|
+
validation_crit_values = validation_crit_values.at[i].set(
|
|
300
|
+
validation_criterion
|
|
301
|
+
)
|
|
302
|
+
else:
|
|
303
|
+
early_stopping = False
|
|
233
304
|
|
|
234
305
|
# Trigger RAR
|
|
235
306
|
loss, params, data = trigger_rar(
|
|
@@ -238,84 +309,98 @@ def solve(
|
|
|
238
309
|
|
|
239
310
|
# Trigger seq2seq
|
|
240
311
|
loss, params, data, opt_state, curr_seq, seq2seq = trigger_seq2seq(
|
|
241
|
-
i,
|
|
312
|
+
i,
|
|
313
|
+
loss,
|
|
314
|
+
params,
|
|
315
|
+
data,
|
|
316
|
+
opt_state,
|
|
317
|
+
optimization_extra.curr_seq,
|
|
318
|
+
optimization_extra.seq2seq,
|
|
242
319
|
)
|
|
243
320
|
|
|
244
321
|
# save loss value and selected parameters
|
|
245
|
-
stored_params, stored_loss_terms,
|
|
322
|
+
stored_params, stored_loss_terms, train_loss_values = store_loss_and_params(
|
|
246
323
|
i,
|
|
247
324
|
params,
|
|
248
|
-
stored_params,
|
|
249
|
-
stored_loss_terms,
|
|
250
|
-
|
|
251
|
-
|
|
325
|
+
stored_objects.stored_params,
|
|
326
|
+
loss_container.stored_loss_terms,
|
|
327
|
+
loss_container.train_loss_values,
|
|
328
|
+
train_loss_value,
|
|
252
329
|
loss_terms,
|
|
253
330
|
tracked_params,
|
|
254
331
|
)
|
|
332
|
+
i += 1
|
|
333
|
+
|
|
255
334
|
return (
|
|
256
|
-
|
|
257
|
-
last_non_nan_params,
|
|
258
|
-
data,
|
|
259
|
-
curr_seq,
|
|
260
|
-
seq2seq,
|
|
261
|
-
stored_params,
|
|
262
|
-
stored_loss_terms,
|
|
335
|
+
i,
|
|
263
336
|
loss,
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
337
|
+
OptimizationContainer(params, last_non_nan_params, opt_state),
|
|
338
|
+
OptimizationExtraContainer(curr_seq, seq2seq, early_stopping),
|
|
339
|
+
DataGeneratorContainer(data, param_data, obs_data),
|
|
340
|
+
validation,
|
|
341
|
+
LossContainer(stored_loss_terms, train_loss_values),
|
|
342
|
+
StoredObjectContainer(stored_params),
|
|
343
|
+
validation_crit_values,
|
|
344
|
+
)
|
|
269
345
|
|
|
270
|
-
# Main optimization loop. We use the
|
|
271
|
-
# if no mixing devices. Otherwise we use the
|
|
346
|
+
# Main optimization loop. We use the LAX while loop (fully jitted) version
|
|
347
|
+
# if no mixing devices. Otherwise we use the standard while loop. Here devices only
|
|
272
348
|
# concern obs_batch, but it could lead to more complex scheme in the future
|
|
273
349
|
if obs_batch_sharding is not None:
|
|
274
|
-
|
|
275
|
-
carry
|
|
350
|
+
while break_fun(carry):
|
|
351
|
+
carry = one_iteration(carry)
|
|
276
352
|
else:
|
|
277
|
-
carry
|
|
278
|
-
scan_tqdm(n_iter)(one_iteration),
|
|
279
|
-
carry,
|
|
280
|
-
jnp.arange(n_iter),
|
|
281
|
-
)
|
|
353
|
+
carry = jax.lax.while_loop(break_fun, one_iteration, carry)
|
|
282
354
|
|
|
283
355
|
(
|
|
284
|
-
|
|
285
|
-
last_non_nan_params,
|
|
286
|
-
data,
|
|
287
|
-
curr_seq,
|
|
288
|
-
seq2seq,
|
|
289
|
-
stored_params,
|
|
290
|
-
stored_loss_terms,
|
|
356
|
+
i,
|
|
291
357
|
loss,
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
358
|
+
optimization,
|
|
359
|
+
optimization_extra,
|
|
360
|
+
train_data,
|
|
361
|
+
validation,
|
|
362
|
+
loss_container,
|
|
363
|
+
stored_objects,
|
|
364
|
+
validation_crit_values,
|
|
296
365
|
) = carry
|
|
297
366
|
|
|
298
367
|
jax.debug.print(
|
|
299
|
-
"
|
|
300
|
-
i=
|
|
301
|
-
|
|
368
|
+
"Final iteration {i}: train loss value = {train_loss_val}",
|
|
369
|
+
i=i,
|
|
370
|
+
train_loss_val=loss_container.train_loss_values[i - 1],
|
|
302
371
|
)
|
|
372
|
+
if validation is not None:
|
|
373
|
+
jax.debug.print(
|
|
374
|
+
"validation loss value = {validation_loss_val}",
|
|
375
|
+
validation_loss_val=validation_crit_values[i - 1],
|
|
376
|
+
)
|
|
303
377
|
|
|
378
|
+
if validation is None:
|
|
379
|
+
return (
|
|
380
|
+
optimization.last_non_nan_params,
|
|
381
|
+
loss_container.train_loss_values,
|
|
382
|
+
loss_container.stored_loss_terms,
|
|
383
|
+
train_data.data,
|
|
384
|
+
loss,
|
|
385
|
+
optimization.opt_state,
|
|
386
|
+
stored_objects.stored_params,
|
|
387
|
+
)
|
|
304
388
|
return (
|
|
305
|
-
last_non_nan_params,
|
|
306
|
-
|
|
307
|
-
stored_loss_terms,
|
|
308
|
-
data,
|
|
389
|
+
optimization.last_non_nan_params,
|
|
390
|
+
loss_container.train_loss_values,
|
|
391
|
+
loss_container.stored_loss_terms,
|
|
392
|
+
train_data.data,
|
|
309
393
|
loss,
|
|
310
|
-
opt_state,
|
|
311
|
-
stored_params,
|
|
394
|
+
optimization.opt_state,
|
|
395
|
+
stored_objects.stored_params,
|
|
396
|
+
validation_crit_values,
|
|
312
397
|
)
|
|
313
398
|
|
|
314
399
|
|
|
315
400
|
@partial(jit, static_argnames=["optimizer"])
|
|
316
401
|
def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params):
|
|
317
402
|
"""
|
|
318
|
-
|
|
403
|
+
optimizer cannot be jit-ted.
|
|
319
404
|
"""
|
|
320
405
|
value_grad_loss = jax.value_and_grad(loss, has_aux=True)
|
|
321
406
|
(loss_val, loss_terms), grads = value_grad_loss(params, batch)
|
|
@@ -340,14 +425,14 @@ def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params
|
|
|
340
425
|
)
|
|
341
426
|
|
|
342
427
|
|
|
343
|
-
@jit
|
|
344
|
-
def print_fn(i, loss_val, print_loss_every):
|
|
428
|
+
@partial(jit, static_argnames=["prefix"])
|
|
429
|
+
def print_fn(i, loss_val, print_loss_every, prefix=""):
|
|
345
430
|
# note that if the following is not jitted in the main for loop, it is
|
|
346
431
|
# super slow
|
|
347
432
|
_ = jax.lax.cond(
|
|
348
433
|
i % print_loss_every == 0,
|
|
349
434
|
lambda _: jax.debug.print(
|
|
350
|
-
"Iteration {i}: loss value = {loss_val}",
|
|
435
|
+
prefix + "Iteration {i}: loss value = {loss_val}",
|
|
351
436
|
i=i,
|
|
352
437
|
loss_val=loss_val,
|
|
353
438
|
),
|
|
@@ -362,8 +447,8 @@ def store_loss_and_params(
|
|
|
362
447
|
params,
|
|
363
448
|
stored_params,
|
|
364
449
|
stored_loss_terms,
|
|
365
|
-
|
|
366
|
-
|
|
450
|
+
train_loss_values,
|
|
451
|
+
train_loss_val,
|
|
367
452
|
loss_terms,
|
|
368
453
|
tracked_params,
|
|
369
454
|
):
|
|
@@ -384,8 +469,66 @@ def store_loss_and_params(
|
|
|
384
469
|
loss_terms,
|
|
385
470
|
)
|
|
386
471
|
|
|
387
|
-
|
|
388
|
-
return stored_params, stored_loss_terms,
|
|
472
|
+
train_loss_values = train_loss_values.at[i].set(train_loss_val)
|
|
473
|
+
return (stored_params, stored_loss_terms, train_loss_values)
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def get_break_fun(n_iter):
|
|
477
|
+
"""
|
|
478
|
+
Wrapper to get the break_fun with appropriate `n_iter`
|
|
479
|
+
"""
|
|
480
|
+
|
|
481
|
+
@jit
|
|
482
|
+
def break_fun(carry):
|
|
483
|
+
"""
|
|
484
|
+
Function to break from the main optimization loop
|
|
485
|
+
We check several conditions
|
|
486
|
+
"""
|
|
487
|
+
|
|
488
|
+
def stop_while_loop(msg):
|
|
489
|
+
"""
|
|
490
|
+
Note that the message is wrapped in the jax.lax.cond because a
|
|
491
|
+
string is not a valid JAX type that can be fed into the operands
|
|
492
|
+
"""
|
|
493
|
+
jax.debug.print(f"Stopping main optimization loop, cause: {msg}")
|
|
494
|
+
return False
|
|
495
|
+
|
|
496
|
+
def continue_while_loop(_):
|
|
497
|
+
return True
|
|
498
|
+
|
|
499
|
+
(i, _, optimization, optimization_extra, _, _, _, _, _) = carry
|
|
500
|
+
|
|
501
|
+
# Condition 1
|
|
502
|
+
bool_max_iter = jax.lax.cond(
|
|
503
|
+
i >= n_iter,
|
|
504
|
+
lambda _: stop_while_loop("max iteration is reached"),
|
|
505
|
+
continue_while_loop,
|
|
506
|
+
None,
|
|
507
|
+
)
|
|
508
|
+
# Condition 2
|
|
509
|
+
bool_nan_in_params = jax.lax.cond(
|
|
510
|
+
_check_nan_in_pytree(optimization.params),
|
|
511
|
+
lambda _: stop_while_loop(
|
|
512
|
+
"NaN values in parameters " "(returning last non NaN values)"
|
|
513
|
+
),
|
|
514
|
+
continue_while_loop,
|
|
515
|
+
None,
|
|
516
|
+
)
|
|
517
|
+
# Condition 3
|
|
518
|
+
bool_early_stopping = jax.lax.cond(
|
|
519
|
+
optimization_extra.early_stopping,
|
|
520
|
+
lambda _: stop_while_loop("early stopping"),
|
|
521
|
+
continue_while_loop,
|
|
522
|
+
_,
|
|
523
|
+
)
|
|
524
|
+
|
|
525
|
+
# stop when one of the cond to continue is False
|
|
526
|
+
return jax.tree_util.tree_reduce(
|
|
527
|
+
lambda x, y: jnp.logical_and(jnp.array(x), jnp.array(y)),
|
|
528
|
+
(bool_max_iter, bool_nan_in_params, bool_early_stopping),
|
|
529
|
+
)
|
|
530
|
+
|
|
531
|
+
return break_fun
|
|
389
532
|
|
|
390
533
|
|
|
391
534
|
def get_get_batch(obs_batch_sharding):
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""
|
|
2
|
+
NamedTuples definition
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Union, NamedTuple
|
|
6
|
+
from jaxtyping import PyTree
|
|
7
|
+
from jax.typing import ArrayLike
|
|
8
|
+
import optax
|
|
9
|
+
import jax.numpy as jnp
|
|
10
|
+
from jinns.loss._LossODE import LossODE, SystemLossODE
|
|
11
|
+
from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
|
|
12
|
+
from jinns.data._DataGenerators import (
|
|
13
|
+
DataGeneratorODE,
|
|
14
|
+
CubicMeshPDEStatio,
|
|
15
|
+
CubicMeshPDENonStatio,
|
|
16
|
+
DataGeneratorParameter,
|
|
17
|
+
DataGeneratorObservations,
|
|
18
|
+
DataGeneratorObservationsMultiPINNs,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DataGeneratorContainer(NamedTuple):
|
|
23
|
+
data: Union[DataGeneratorODE, CubicMeshPDEStatio, CubicMeshPDENonStatio]
|
|
24
|
+
param_data: Union[DataGeneratorParameter, None] = None
|
|
25
|
+
obs_data: Union[
|
|
26
|
+
DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
|
|
27
|
+
] = None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ValidationContainer(NamedTuple):
|
|
31
|
+
loss: Union[
|
|
32
|
+
LossODE, SystemLossODE, LossPDEStatio, LossPDENonStatio, SystemLossPDE, None
|
|
33
|
+
]
|
|
34
|
+
data: DataGeneratorContainer
|
|
35
|
+
hyperparams: PyTree = None
|
|
36
|
+
loss_values: Union[ArrayLike, None] = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class OptimizationContainer(NamedTuple):
|
|
40
|
+
params: dict
|
|
41
|
+
last_non_nan_params: dict
|
|
42
|
+
opt_state: optax.OptState
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class OptimizationExtraContainer(NamedTuple):
|
|
46
|
+
curr_seq: int
|
|
47
|
+
seq2seq: Union[dict, None]
|
|
48
|
+
early_stopping: bool = False
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class LossContainer(NamedTuple):
|
|
52
|
+
stored_loss_terms: dict
|
|
53
|
+
train_loss_values: ArrayLike
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class StoredObjectContainer(NamedTuple):
|
|
57
|
+
stored_params: Union[list, None]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ._validation import AbstractValidationModule, ValidationLoss
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implements some validation functions and their associated hyperparameter
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
import abc
|
|
7
|
+
from typing import Union
|
|
8
|
+
import equinox as eqx
|
|
9
|
+
import jax
|
|
10
|
+
import jax.numpy as jnp
|
|
11
|
+
from jaxtyping import Array, Bool, PyTree, Int
|
|
12
|
+
import jinns
|
|
13
|
+
import jinns.data
|
|
14
|
+
from jinns.loss import LossODE, LossPDENonStatio, LossPDEStatio
|
|
15
|
+
from jinns.data._DataGenerators import (
|
|
16
|
+
DataGeneratorODE,
|
|
17
|
+
CubicMeshPDEStatio,
|
|
18
|
+
CubicMeshPDENonStatio,
|
|
19
|
+
DataGeneratorParameter,
|
|
20
|
+
DataGeneratorObservations,
|
|
21
|
+
DataGeneratorObservationsMultiPINNs,
|
|
22
|
+
append_obs_batch,
|
|
23
|
+
append_param_batch,
|
|
24
|
+
)
|
|
25
|
+
import jinns.loss
|
|
26
|
+
|
|
27
|
+
# Using eqx Module for the DataClass + Pytree inheritance
|
|
28
|
+
# Abstract class and abstract/final pattern is used
|
|
29
|
+
# see : https://docs.kidger.site/equinox/pattern/
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class AbstractValidationModule(eqx.Module):
|
|
33
|
+
"""Abstract class representing interface for any validation module. It must
|
|
34
|
+
1. have a ``call_every`` attribute.
|
|
35
|
+
2. implement a ``__call__`` returning ``(AbstractValidationModule, Bool, Array)``
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
call_every: eqx.AbstractVar[Int] # Mandatory for all validation step,
|
|
39
|
+
# it tells that the validation step is performed every call_every
|
|
40
|
+
# iterations.
|
|
41
|
+
|
|
42
|
+
@abc.abstractmethod
|
|
43
|
+
def __call__(
|
|
44
|
+
self, params: PyTree
|
|
45
|
+
) -> tuple["AbstractValidationModule", Bool, Array]:
|
|
46
|
+
raise NotImplementedError
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class ValidationLoss(AbstractValidationModule):
|
|
50
|
+
"""
|
|
51
|
+
Implementation of a vanilla validation module returning the PINN loss
|
|
52
|
+
on a validation set of collocation points. This can be used as a baseline
|
|
53
|
+
for more complicated validation strategy.
|
|
54
|
+
"""
|
|
55
|
+
|
|
56
|
+
loss: Union[callable, LossODE, LossPDEStatio, LossPDENonStatio] = eqx.field(
|
|
57
|
+
converter=copy.deepcopy
|
|
58
|
+
)
|
|
59
|
+
validation_data: Union[DataGeneratorODE, CubicMeshPDEStatio, CubicMeshPDENonStatio]
|
|
60
|
+
validation_param_data: Union[DataGeneratorParameter, None] = None
|
|
61
|
+
validation_obs_data: Union[
|
|
62
|
+
DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
|
|
63
|
+
] = None
|
|
64
|
+
call_every: Int = 250 # concrete typing
|
|
65
|
+
early_stopping: Bool = True # globally control if early stopping happens
|
|
66
|
+
|
|
67
|
+
patience: Union[Int] = 10
|
|
68
|
+
best_val_loss: Array = eqx.field(
|
|
69
|
+
converter=jnp.asarray, default_factory=lambda: jnp.array(jnp.inf)
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
counter: Array = eqx.field(
|
|
73
|
+
converter=jnp.asarray, default_factory=lambda: jnp.array(0.0)
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
def __call__(self, params) -> tuple["ValidationLoss", Bool, Array]:
|
|
77
|
+
# do in-place mutation
|
|
78
|
+
val_batch = self.validation_data.get_batch()
|
|
79
|
+
if self.validation_param_data is not None:
|
|
80
|
+
val_batch = append_param_batch(
|
|
81
|
+
val_batch, self.validation_param_data.get_batch()
|
|
82
|
+
)
|
|
83
|
+
if self.validation_obs_data is not None:
|
|
84
|
+
val_batch = append_obs_batch(
|
|
85
|
+
val_batch, self.validation_obs_data.get_batch()
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
validation_loss_value, _ = self.loss(params, val_batch)
|
|
89
|
+
(counter, best_val_loss) = jax.lax.cond(
|
|
90
|
+
validation_loss_value < self.best_val_loss,
|
|
91
|
+
lambda _: (jnp.array(0.0), validation_loss_value), # reset
|
|
92
|
+
lambda operands: (operands[0] + 1, operands[1]), # increment
|
|
93
|
+
(self.counter, self.best_val_loss),
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
# use eqx.tree_at to update attributes
|
|
97
|
+
# (https://github.com/patrick-kidger/equinox/issues/396)
|
|
98
|
+
new = eqx.tree_at(lambda t: t.counter, self, counter)
|
|
99
|
+
new = eqx.tree_at(lambda t: t.best_val_loss, new, best_val_loss)
|
|
100
|
+
|
|
101
|
+
bool_early_stopping = jax.lax.cond(
|
|
102
|
+
jnp.logical_and(
|
|
103
|
+
jnp.array(self.counter == self.patience),
|
|
104
|
+
jnp.array(self.early_stopping),
|
|
105
|
+
),
|
|
106
|
+
lambda _: True,
|
|
107
|
+
lambda _: False,
|
|
108
|
+
None,
|
|
109
|
+
)
|
|
110
|
+
# return `new` because there is no in-place modification of the eqx.Module
|
|
111
|
+
return (new, bool_early_stopping, validation_loss_value)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
if __name__ == "__main__":
    # Manual smoke test: build a ValidationLoss around a Burgers-equation
    # loss, call it, and print identity / norm_key / counter information.
    # `eqx`, `jinns` and `ValidationLoss` are not imported here and are
    # expected to be available from the module scope above.
    import jax
    import jax.numpy as jnp
    import jax.random as random
    from jinns.loss import BurgerEquation

    key = random.PRNGKey(1)
    key, subkey = random.split(key)

    # Settings passed positionally to the validation data generator
    n = 50
    nb = 2 * 2 * 10
    nt = 10
    omega_batch_size = 10
    omega_border_batch_size = 10
    temporal_batch_size = 4
    dim = 1
    xmin = 0
    xmax = 1
    tmin, tmax = 0, 1
    method = "uniform"

    val_data = jinns.data.CubicMeshPDENonStatio(
        subkey,
        n,
        nb,
        nt,
        omega_batch_size,
        omega_border_batch_size,
        temporal_batch_size,
        dim,
        (xmin,),
        (xmax,),
        tmin,
        tmax,
        method,
    )

    # Layer specification handed to create_PINN
    eqx_list = [
        [eqx.nn.Linear, 2, 50],
        [jax.nn.tanh],
        [eqx.nn.Linear, 50, 50],
        [jax.nn.tanh],
        [eqx.nn.Linear, 50, 50],
        [jax.nn.tanh],
        [eqx.nn.Linear, 50, 50],
        [jax.nn.tanh],
        [eqx.nn.Linear, 50, 50],
        [jax.nn.tanh],
        [eqx.nn.Linear, 50, 2],
    ]

    key, subkey = random.split(key)
    u = jinns.utils.create_PINN(
        subkey, eqx_list, "nonstatio_PDE", 2, slice_solution=jnp.s_[:1]
    )
    init_nn_params = u.init_params()

    dyn_loss = BurgerEquation()
    loss_weights = {"dyn_loss": 1, "boundary_loss": 10, "observations": 10}

    key, subkey = random.split(key)
    loss = jinns.loss.LossPDENonStatio(
        u=u,
        loss_weights=loss_weights,
        dynamic_loss=dyn_loss,
        norm_key=subkey,
        norm_borders=(-1, 1),
    )
    print(id(loss))
    validation = ValidationLoss(
        call_every=250,
        early_stopping=True,
        patience=1000,
        loss=loss,
        validation_data=val_data,
        validation_param_data=None,
    )
    # Compare the objects themselves: `id(a) is not id(b)` tests the identity
    # of two freshly created ints and says nothing about `a` and `b`.
    print(validation.loss is not loss)  # should be True (deepcopy)

    init_params = {"nn_params": init_nn_params, "eq_params": {"nu": 1.0}}

    print(validation.loss is loss)
    loss.evaluate(init_params, val_data.get_batch())
    print(loss.norm_key)
    print("Call validation once")
    validation, _, _ = validation(init_params)
    print(validation.loss is loss)
    print(validation.loss.norm_key == loss.norm_key)
    print("Crate new pytree from validation and call it once")
    new_val = eqx.tree_at(lambda t: t.counter, validation, jnp.array(3.0))
    print(validation.loss is new_val.loss)  # FALSE
    # test if attribute have been modified
    new_val, _, _ = new_val(init_params)
    print(f"{new_val.loss is loss=}")
    print(f"{loss.norm_key=}")
    print(f"{validation.loss.norm_key=}")
    print(f"{new_val.loss.norm_key=}")
    print(f"{new_val.loss.norm_key == loss.norm_key=}")
    print(f"{new_val.loss.norm_key == validation.loss.norm_key=}")
    print(new_val.counter)
    print(validation.counter)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 0.8.6
|
|
3
|
+
Version: 0.8.7
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -6,8 +6,8 @@ jinns/experimental/__init__.py,sha256=3jCIy2R2i_0Erwxg-HwISdH79Nt1XCXhS9yY1F5awi
|
|
|
6
6
|
jinns/experimental/_diffrax_solver.py,sha256=sLT22byqh-6015_fhe1xtMWlFOYcCjzYKET4sLhA9R4,6818
|
|
7
7
|
jinns/loss/_DynamicLoss.py,sha256=L4CVmmF0rTPbHntgqsLLHlnrlQgLHsetUocpJm7ZYag,27461
|
|
8
8
|
jinns/loss/_DynamicLossAbstract.py,sha256=kTQlhLx7SBuH5dIDmYaE79sVHUZt1nUFa8LxPU5IHhM,8504
|
|
9
|
-
jinns/loss/_LossODE.py,sha256=
|
|
10
|
-
jinns/loss/_LossPDE.py,sha256=
|
|
9
|
+
jinns/loss/_LossODE.py,sha256=b9doBHoQwYvlgpqzrNO4dOaTN87LRvjHtHbz9bMoH7E,22119
|
|
10
|
+
jinns/loss/_LossPDE.py,sha256=purAEtc0e71kv9XnZUT-a7MrkDAkM_3tTI4xJPu6fH4,61629
|
|
11
11
|
jinns/loss/_Losses.py,sha256=XOL3MFiKEd3ndsc78Qnpi1vbgR0B2HaAWOGGW2meDM8,11190
|
|
12
12
|
jinns/loss/__init__.py,sha256=pFNYUxns-NPXBFdqrEVSiXkQLfCtKw-t2trlhvLzpYE,355
|
|
13
13
|
jinns/loss/_boundary_conditions.py,sha256=YfSnLZ25hXqQ5KWAuxOrWSKkf_oBqAc9GQV4z7MjWyQ,17434
|
|
@@ -15,8 +15,9 @@ jinns/loss/_operators.py,sha256=zDGJqYqeYH7xd-4dtGX9PS-pf0uSOpUUXGo5SVjIJ4o,1106
|
|
|
15
15
|
jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
16
|
jinns/solver/_rar.py,sha256=K-0y1-ofOAo1n_Ea3QShSGCGKVYTwiaE_Bz9-DZMJm8,14525
|
|
17
17
|
jinns/solver/_seq2seq.py,sha256=FL-42hTgmVl7O3hHh1ccFVw2bT8bW82hvlDRz971Chk,5620
|
|
18
|
-
jinns/solver/_solve.py,sha256=
|
|
18
|
+
jinns/solver/_solve.py,sha256=mGi0zaT_fK_QpBjTxof5Ix4mmfmnPi66CNJ3GQFZuo4,19099
|
|
19
19
|
jinns/utils/__init__.py,sha256=44ms5UR6vMw3Nf6u4RCAzPFs4fom_YbBnH9mfne8m6k,313
|
|
20
|
+
jinns/utils/_containers.py,sha256=eYD277fO7X4EfX7PUFCCl69r3JBfh1sCfq8LkL5gd6o,1495
|
|
20
21
|
jinns/utils/_hyperpinn.py,sha256=93hbiATdp5W4l1cu9Oe6O2c45o-ZF_z2u6FzNLyjnm4,10878
|
|
21
22
|
jinns/utils/_optim.py,sha256=550kxH75TL30o1iKx1swJyP0KqyUPsJ7-imL1w65Qd0,4444
|
|
22
23
|
jinns/utils/_pinn.py,sha256=mhA4-3PazyQTbWIx9oLaNwL0QDe8ZIBhbiy5J3kwa4I,9471
|
|
@@ -24,8 +25,10 @@ jinns/utils/_save_load.py,sha256=qgZ23nUcB8-B5IZ2guuUWC4M7r5Lxd_Ms3staScdyJo,566
|
|
|
24
25
|
jinns/utils/_spinn.py,sha256=SzOUt1KHtB9QOpghpvitnXN-KEqXUXbvabC5k0TnKEo,7793
|
|
25
26
|
jinns/utils/_utils.py,sha256=8dgvWXX9NT7_7-zltWp0C9tG45ZFNwXxueyxPBb4hjo,6740
|
|
26
27
|
jinns/utils/_utils_uspinn.py,sha256=qcKcOw3zrwWSQyGVj6fD8c9GinHt_U6JWN_k0auTtXM,26039
|
|
27
|
-
jinns
|
|
28
|
-
jinns
|
|
29
|
-
jinns-0.8.
|
|
30
|
-
jinns-0.8.
|
|
31
|
-
jinns-0.8.
|
|
28
|
+
jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
|
|
29
|
+
jinns/validation/_validation.py,sha256=KfetbzB0xTNdBcYLwFWjEtP63Tf9wJirlhgqLTJDyy4,6761
|
|
30
|
+
jinns-0.8.7.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
31
|
+
jinns-0.8.7.dist-info/METADATA,sha256=L0P7JvMGKrJHx9OjrtFsmNKEwdKA_RlufAbOBf5l10I,2482
|
|
32
|
+
jinns-0.8.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
33
|
+
jinns-0.8.7.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
34
|
+
jinns-0.8.7.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|