jinns 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/solver/_rar.py CHANGED
@@ -11,6 +11,41 @@ from jinns.loss._LossODE import LossODE, SystemLossODE
11
11
  from jinns.loss._DynamicLossAbstract import PDEStatio
12
12
 
13
13
  from functools import partial
14
+ from jinns.utils._hyperpinn import HYPERPINN
15
+ from jinns.utils._spinn import SPINN
16
+
17
+
18
+ def _proceed_to_rar(data, i):
19
+ """Utility function with various checks to ensure we can proceed with the rar_step.
20
+ Return True if yes, and False otherwise"""
21
+
22
+ # Overall checks (universal for any data generator)
23
+ check_list = [
24
+ # check if burn-in period has ended
25
+ data.rar_parameters["start_iter"] <= i,
26
+ # check if enough iterations since last points added
27
+ (data.rar_parameters["update_every"] - 1) == data.rar_iter_from_last_sampling,
28
+ ]
29
+
30
+ # Memory allocation checks (depends on the type of DataGenerator)
31
+ # check if we still have room to append new collocation points in the
32
+ # allocated jnp.array (can concern `data.p_times` or `p_omega`)
33
+ if isinstance(data, DataGeneratorODE) or isinstance(data, CubicMeshPDENonStatio):
34
+ check_list.append(
35
+ data.rar_parameters["selected_sample_size_times"]
36
+ <= jnp.count_nonzero(data.p_times == 0),
37
+ )
38
+
39
+ if isinstance(data, CubicMeshPDEStatio) or isinstance(data, CubicMeshPDENonStatio):
40
+ # for now the above checks are redundant but there may be a time when
41
+ # we drop inheritance
42
+ check_list.append(
43
+ data.rar_parameters["selected_sample_size_omega"]
44
+ <= jnp.count_nonzero(data.p_omega == 0),
45
+ )
46
+
47
+ proceed = jnp.all(jnp.array(check_list))
48
+ return proceed
14
49
 
15
50
 
16
51
  @partial(jax.jit, static_argnames=["_rar_step_true", "_rar_step_false"])
@@ -22,21 +57,7 @@ def trigger_rar(i, loss, params, data, _rar_step_true, _rar_step_false):
22
57
  else:
23
58
  # update `data` according to rar scheme.
24
59
  data = jax.lax.cond(
25
- jnp.all(
26
- jnp.array(
27
- [
28
- # check if enough it since last points added
29
- data.rar_parameters["update_rate"]
30
- == data.rar_iter_from_last_sampling,
31
- # check if burn in period has ended
32
- data.rar_parameters["start_iter"] < i,
33
- # check if we still have room to append new
34
- # collocation points in the allocated jnp array
35
- data.rar_parameters["selected_sample_size"]
36
- <= jnp.count_nonzero(data.p == 0),
37
- ]
38
- )
39
- ),
60
+ _proceed_to_rar(data, i),
40
61
  _rar_step_true,
41
62
  _rar_step_false,
42
63
  (loss, params, data, i),
@@ -49,13 +70,37 @@ def init_rar(data):
49
70
  Separated from the main rar, because the initialization to get _true and
50
71
  _false cannot be jit-ted.
51
72
  """
73
+ # NOTE if a user misspells some entry of ``rar_parameters`` the error
74
+ # may be a bit obscure but it should be ok.
52
75
  if data.rar_parameters is None:
53
76
  _rar_step_true, _rar_step_false = None, None
54
77
  else:
55
- _rar_step_true, _rar_step_false = _rar_step_init(
56
- data.rar_parameters["sample_size"],
57
- data.rar_parameters["selected_sample_size"],
58
- )
78
+ if isinstance(data, DataGeneratorODE):
79
+ # In this case we only need rar parameters related to `times`
80
+ _rar_step_true, _rar_step_false = _rar_step_init(
81
+ data.rar_parameters["sample_size_times"],
82
+ data.rar_parameters["selected_sample_size_times"],
83
+ )
84
+ elif isinstance(data, CubicMeshPDENonStatio):
85
+ # In this case we need rar parameters related to both `times`
86
+ # and `omega`
87
+ _rar_step_true, _rar_step_false = _rar_step_init(
88
+ (
89
+ data.rar_parameters["sample_size_times"],
90
+ data.rar_parameters["sample_size_omega"],
91
+ ),
92
+ (
93
+ data.rar_parameters["selected_sample_size_times"],
94
+ data.rar_parameters["selected_sample_size_omega"],
95
+ ),
96
+ )
97
+ elif isinstance(data, CubicMeshPDEStatio):
98
+ # In this case we only need rar parameters related to `omega`
99
+ _rar_step_true, _rar_step_false = _rar_step_init(
100
+ data.rar_parameters["sample_size_omega"],
101
+ data.rar_parameters["selected_sample_size_omega"],
102
+ )
103
+
59
104
  data.rar_parameters["iter_from_last_sampling"] = 0
60
105
 
61
106
  return data, _rar_step_true, _rar_step_false
@@ -64,7 +109,7 @@ def init_rar(data):
64
109
  def _rar_step_init(sample_size, selected_sample_size):
65
110
  """
66
111
  This is a wrapper because the sampling size and
67
- selected_sample_size, must be treated static
112
+ selected_sample_size, must be treated as static
68
113
  in order to slice. So they must be set before jitting and not with the jitted
69
114
  dictionary values rar["test_points_nb"] and rar["added_points_nb"]
70
115
 
@@ -72,16 +117,10 @@ def _rar_step_init(sample_size, selected_sample_size):
72
117
  """
73
118
 
74
119
  def rar_step_true(operands):
75
- """
76
- Note: in all generality, we would need a stop gradient operator around
77
- these dynamic_loss evaluations that follow which produce weights for
78
- sampling. However, they appear through a argsort and sampling
79
- operations which definitly kills gradient flows
80
- """
81
120
  loss, params, data, i = operands
82
121
 
83
122
  if isinstance(data, DataGeneratorODE):
84
- s = data.sample_in_time_domain(sample_size)
123
+ new_omega_samples = data.sample_in_time_domain(sample_size)
85
124
 
86
125
  # We can have different types of Loss
87
126
  if isinstance(loss, LossODE):
@@ -90,7 +129,7 @@ def _rar_step_init(sample_size, selected_sample_size):
90
129
  (0),
91
130
  0,
92
131
  )
93
- dyn_on_s = v_dyn_loss(s)
132
+ dyn_on_s = v_dyn_loss(new_omega_samples)
94
133
  if dyn_on_s.ndim > 1:
95
134
  mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
96
135
  else:
@@ -106,7 +145,7 @@ def _rar_step_init(sample_size, selected_sample_size):
106
145
  (0),
107
146
  0,
108
147
  )
109
- dyn_on_s = v_dyn_loss(s)
148
+ dyn_on_s = v_dyn_loss(new_omega_samples)
110
149
  if dyn_on_s.ndim > 1:
111
150
  mse_on_s += (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
112
151
  else:
@@ -118,9 +157,7 @@ def _rar_step_init(sample_size, selected_sample_size):
118
157
  (mse_on_s.shape[0] - selected_sample_size,),
119
158
  (selected_sample_size,),
120
159
  )
121
- higher_residual_points = s[higher_residual_idx]
122
-
123
- data.rar_iter_from_last_sampling = 0
160
+ higher_residual_points = new_omega_samples[higher_residual_idx]
124
161
 
125
162
  ## add the new points in times
126
163
  # start indices of update can be dynamic but the the shape (length)
@@ -135,7 +172,7 @@ def _rar_step_init(sample_size, selected_sample_size):
135
172
  ## points are non-zero
136
173
  new_proba = 1 / (data.nt_start + data.rar_iter_nb * selected_sample_size)
137
174
  # the next work because nt_start is static
138
- data.p = data.p.at[: data.nt_start].set(new_proba)
175
+ data.p_times = data.p_times.at[: data.nt_start].set(new_proba)
139
176
 
140
177
  # the next requires a fori_loop because the range is dynamic
141
178
  def update_slices(i, p):
@@ -147,16 +184,14 @@ def _rar_step_init(sample_size, selected_sample_size):
147
184
 
148
185
  data.rar_iter_nb += 1
149
186
 
150
- data.p = jax.lax.fori_loop(0, data.rar_iter_nb, update_slices, data.p)
151
-
152
- # NOTE must return data to be correctly updated because we cannot
153
- # have side effects in this function that will be jitted
154
- return data
187
+ data.p_times = jax.lax.fori_loop(
188
+ 0, data.rar_iter_nb, update_slices, data.p_times
189
+ )
155
190
 
156
191
  elif isinstance(data, CubicMeshPDEStatio) and not isinstance(
157
192
  data, CubicMeshPDENonStatio
158
193
  ):
159
- s = data.sample_in_omega_domain(sample_size)
194
+ new_omega_samples = data.sample_in_omega_domain(sample_size)
160
195
 
161
196
  # We can have different types of Loss
162
197
  if isinstance(loss, LossPDEStatio):
@@ -169,7 +204,7 @@ def _rar_step_init(sample_size, selected_sample_size):
169
204
  (0),
170
205
  0,
171
206
  )
172
- dyn_on_s = v_dyn_loss(s)
207
+ dyn_on_s = v_dyn_loss(new_omega_samples)
173
208
  if dyn_on_s.ndim > 1:
174
209
  mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
175
210
  else:
@@ -185,7 +220,7 @@ def _rar_step_init(sample_size, selected_sample_size):
185
220
  0,
186
221
  0,
187
222
  )
188
- dyn_on_s = v_dyn_loss(s)
223
+ dyn_on_s = v_dyn_loss(new_omega_samples)
189
224
  if dyn_on_s.ndim > 1:
190
225
  mse_on_s += (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
191
226
  else:
@@ -197,12 +232,10 @@ def _rar_step_init(sample_size, selected_sample_size):
197
232
  (mse_on_s.shape[0] - selected_sample_size,),
198
233
  (selected_sample_size,),
199
234
  )
200
- higher_residual_points = s[higher_residual_idx]
235
+ higher_residual_points = new_omega_samples[higher_residual_idx]
201
236
 
202
- data.rar_iter_from_last_sampling = 0
203
-
204
- ## add the new points in times
205
- # start indices of update can be dynamic but the the shape (length)
237
+ ## add the new points in omega
238
+ # start indices of update can be dynamic but not the shape (length)
206
239
  # of the slice
207
240
  data.omega = jax.lax.dynamic_update_slice(
208
241
  data.omega,
@@ -214,7 +247,7 @@ def _rar_step_init(sample_size, selected_sample_size):
214
247
  ## points are non-zero
215
248
  new_proba = 1 / (data.n_start + data.rar_iter_nb * selected_sample_size)
216
249
  # the next work because n_start is static
217
- data.p = data.p.at[: data.n_start].set(new_proba)
250
+ data.p_omega = data.p_omega.at[: data.n_start].set(new_proba)
218
251
 
219
252
  # the next requires a fori_loop because the range is dynamic
220
253
  def update_slices(i, p):
@@ -226,145 +259,169 @@ def _rar_step_init(sample_size, selected_sample_size):
226
259
 
227
260
  data.rar_iter_nb += 1
228
261
 
229
- data.p = jax.lax.fori_loop(0, data.rar_iter_nb, update_slices, data.p)
230
-
231
- # NOTE must return data to be correctly updated because we cannot
232
- # have side effects in this function that will be jitted
233
- return data
262
+ data.p_omega = jax.lax.fori_loop(
263
+ 0, data.rar_iter_nb, update_slices, data.p_omega
264
+ )
234
265
 
235
266
  elif isinstance(data, CubicMeshPDENonStatio):
236
- st = data.sample_in_time_domain(sample_size)
237
- sx = data.sample_in_omega_domain(sample_size)
238
-
239
- # According to the Loss type we have different syntax to call the
240
- # dynamic_loss evaluate function
241
- if isinstance(loss, LossPDEStatio) and not isinstance(
242
- loss, LossPDENonStatio
243
- ):
244
- # This case might not happen very often...
245
- v_dyn_loss = vmap(
246
- lambda x: loss.dynamic_loss.evaluate(
247
- x,
248
- loss.u,
249
- params,
250
- ),
251
- (0),
252
- 0,
253
- )
254
- dyn_on_s = v_dyn_loss(sx)
255
- if dyn_on_s.ndim > 1:
256
- mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
257
- else:
258
- mse_on_s = dyn_on_s**2
259
- elif isinstance(loss, LossPDENonStatio):
267
+ # NOTE in this case sample_size and selected_sample_size
268
+ # are tuples (times, omega) => we unpack them for clarity
269
+ selected_sample_size_times, selected_sample_size_omega = (
270
+ selected_sample_size
271
+ )
272
+ sample_size_times, sample_size_omega = sample_size
273
+
274
+ new_times_samples = data.sample_in_time_domain(sample_size_times)
275
+ new_omega_samples = data.sample_in_omega_domain(sample_size_omega)
276
+
277
+ if isinstance(loss.u, HYPERPINN) or isinstance(loss.u, SPINN):
278
+ raise NotImplementedError("RAR not implemented for hyperPINN and SPINN")
279
+ else:
280
+ # do cartesian product on new points
281
+ tile_omega = jnp.tile(
282
+ new_omega_samples, reps=(sample_size_times, 1)
283
+ ) # it is tiled
284
+ repeat_times = jnp.repeat(new_times_samples, sample_size_omega, axis=0)[
285
+ ..., None
286
+ ] # it is repeated + add an axis
287
+
288
+ if isinstance(loss, LossPDENonStatio):
260
289
  v_dyn_loss = vmap(
261
290
  lambda t, x: loss.dynamic_loss.evaluate(t, x, loss.u, params),
262
291
  (0, 0),
263
292
  0,
264
293
  )
265
- dyn_on_s = v_dyn_loss(st[..., None], sx)
266
- if dyn_on_s.ndim > 1:
267
- mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
268
- else:
269
- mse_on_s = dyn_on_s**2
294
+ dyn_on_s = v_dyn_loss(repeat_times, tile_omega).reshape(
295
+ (sample_size_times, sample_size_omega)
296
+ )
297
+ mse_on_s = dyn_on_s**2
270
298
  elif isinstance(loss, SystemLossPDE):
271
- mse_on_s = 0
299
+ dyn_on_s = jnp.zeros((sample_size_times, sample_size_omega))
272
300
  for i in loss.dynamic_loss_dict.keys():
273
- if isinstance(loss.dynamic_loss_dict[i], PDEStatio):
274
- v_dyn_loss = vmap(
275
- lambda x: loss.dynamic_loss_dict[i].evaluate(
276
- x, loss.u_dict, params
277
- ),
278
- 0,
279
- 0,
280
- )
281
- dyn_on_s = v_dyn_loss(sx)
282
- if dyn_on_s.ndim > 1:
283
- mse_on_s += (
284
- jnp.linalg.norm(dyn_on_s, axis=-1) ** 2
285
- ).flatten()
286
- else:
287
- mse_on_s += dyn_on_s**2
288
- else:
289
- v_dyn_loss = vmap(
290
- lambda t, x: loss.dynamic_loss_dict[i].evaluate(
291
- t, x, loss.u_dict, params
292
- ),
293
- (0, 0),
294
- 0,
295
- )
296
- dyn_on_s = v_dyn_loss(st[..., None], sx)
297
- if dyn_on_s.ndim > 1:
298
- mse_on_s += (
299
- jnp.linalg.norm(dyn_on_s, axis=-1) ** 2
300
- ).flatten()
301
- else:
302
- mse_on_s += dyn_on_s**2
303
-
304
- ## Now that we have the residuals, select the m points
305
- # with higher dynamic loss (residuals)
306
- higher_residual_idx = jax.lax.dynamic_slice(
307
- jnp.argsort(mse_on_s),
308
- (mse_on_s.shape[0] - selected_sample_size,),
309
- (selected_sample_size,),
310
- )
311
- higher_residual_points_st = st[higher_residual_idx]
312
- higher_residual_points_sx = sx[higher_residual_idx]
301
+ v_dyn_loss = vmap(
302
+ lambda t, x: loss.dynamic_loss_dict[i].evaluate(
303
+ t, x, loss.u_dict, params
304
+ ),
305
+ (0, 0),
306
+ 0,
307
+ )
308
+ dyn_on_s += v_dyn_loss(repeat_times, tile_omega).reshape(
309
+ (sample_size_times, sample_size_omega)
310
+ )
313
311
 
314
- data.rar_iter_from_last_sampling = 0
312
+ mse_on_s = dyn_on_s**2
313
+ # -- Select the m points with highest average residuals on time and
314
+ # -- space (times in rows / omega in columns)
315
+ # mean_times = mse_on_s.mean(axis=1)
316
+ # mean_omega = mse_on_s.mean(axis=0)
317
+ # times_idx = jax.lax.dynamic_slice(
318
+ # jnp.argsort(mean_times),
319
+ # (mse_on_s.shape[0] - selected_sample_size_times,),
320
+ # (selected_sample_size_times,),
321
+ # )
322
+ # omega_idx = jax.lax.dynamic_slice(
323
+ # jnp.argsort(mean_omega),
324
+ # (mse_on_s.shape[1] - selected_sample_size_omega,),
325
+ # (selected_sample_size_omega,),
326
+ # )
327
+
328
+ # -- Select the m worst points (t, x) with highest residuals
329
+ n_select = max(selected_sample_size_times, selected_sample_size_omega)
330
+ _, idx = jax.lax.top_k(mse_on_s.flatten(), k=n_select)
331
+ arr_idx = jnp.unravel_index(idx, mse_on_s.shape)
332
+ times_idx = arr_idx[0][:selected_sample_size_times]
333
+ omega_idx = arr_idx[1][:selected_sample_size_omega]
334
+
335
+ higher_residual_points_times = new_times_samples[times_idx]
336
+ higher_residual_points_omega = new_omega_samples[omega_idx]
315
337
 
316
338
  ## add the new points in times
317
- # start indices of update can be dynamic but the the shape (length)
339
+ # start indices of update can be dynamic but not the shape (length)
318
340
  # of the slice
319
341
  data.times = jax.lax.dynamic_update_slice(
320
342
  data.times,
321
- higher_residual_points_st,
322
- (data.n_start + data.rar_iter_nb * selected_sample_size,),
343
+ higher_residual_points_times,
344
+ (data.n_start + data.rar_iter_nb * selected_sample_size_times,),
323
345
  )
324
346
 
325
347
  ## add the new points in omega
326
348
  data.omega = jax.lax.dynamic_update_slice(
327
349
  data.omega,
328
- higher_residual_points_sx,
350
+ higher_residual_points_omega,
329
351
  (
330
- data.n_start + data.rar_iter_nb * selected_sample_size,
352
+ data.n_start + data.rar_iter_nb * selected_sample_size_omega,
331
353
  data.dim,
332
354
  ),
333
355
  )
334
356
 
335
357
  ## rearrange probabilities so that the probabilities of the new
336
358
  ## points are non-zero
337
- new_proba = 1 / (data.n_start + data.rar_iter_nb * selected_sample_size)
359
+ new_p_times = 1 / (
360
+ data.nt_start + data.rar_iter_nb * selected_sample_size_times
361
+ )
338
362
  # the next work because nt_start is static
339
- data.p = data.p.at[: data.n_start].set(new_proba)
363
+ data.p_times = data.p_times.at[: data.nt_start].set(new_p_times)
340
364
 
341
- # the next requires a fori_loop because the range is dynamic
342
- def update_slices(i, p):
343
- return jax.lax.dynamic_update_slice(
344
- p,
345
- 1 / new_proba * jnp.ones((selected_sample_size,)),
346
- ((data.n_start + i * selected_sample_size),),
347
- )
365
+ # same for p_omega (work because n_start is static)
366
+ new_p_omega = 1 / (
367
+ data.n_start + data.rar_iter_nb * selected_sample_size_omega
368
+ )
369
+ data.p_omega = data.p_omega.at[: data.n_start].set(new_p_omega)
370
+
371
+ # the part of data.p_* after n_start requires a fori_loop because
372
+ # the range is dynamic
373
+ def create_update_slices(new_val, selected_sample_size):
374
+ def update_slices(i, p):
375
+ new_p = jax.lax.dynamic_update_slice(
376
+ p,
377
+ new_val * jnp.ones((selected_sample_size,)),
378
+ ((data.n_start + i * selected_sample_size),),
379
+ )
380
+ return new_p
381
+
382
+ return update_slices
348
383
 
349
384
  data.rar_iter_nb += 1
350
385
 
351
- data.p = jax.lax.fori_loop(0, data.rar_iter_nb, update_slices, data.p)
386
+ ## update rest of p_times
387
+ update_slices_times = create_update_slices(
388
+ new_p_times, selected_sample_size_times
389
+ )
390
+ data.p_times = jax.lax.fori_loop(
391
+ 0,
392
+ data.rar_iter_nb,
393
+ update_slices_times,
394
+ data.p_times,
395
+ )
396
+ ## update rest of p_omega
397
+ update_slices_omega = create_update_slices(
398
+ new_p_omega, selected_sample_size_omega
399
+ )
400
+ data.p_omega = jax.lax.fori_loop(
401
+ 0,
402
+ data.rar_iter_nb,
403
+ update_slices_omega,
404
+ data.p_omega,
405
+ )
352
406
 
353
- # NOTE must return data to be correctly updated because we cannot
354
- # have side effects in this function that will be jitted
355
- return data
407
+ # update RAR parameters for all cases
408
+ data.rar_iter_from_last_sampling = 0
409
+
410
+ # NOTE must return data to be correctly updated because we cannot
411
+ # have side effects in this function that will be jitted
412
+ return data
356
413
 
357
414
  def rar_step_false(operands):
358
415
  _, _, data, i = operands
359
416
 
360
417
  # Add 1 only if we are after the burn in period
361
- data.rar_iter_from_last_sampling = jax.lax.cond(
362
- i < data.rar_parameters["start_iter"],
363
- lambda operand: 0,
364
- lambda operand: operand + 1,
365
- (data.rar_iter_from_last_sampling),
418
+ increment = jax.lax.cond(
419
+ i <= data.rar_parameters["start_iter"],
420
+ lambda: 0,
421
+ lambda: 1,
366
422
  )
367
423
 
424
+ data.rar_iter_from_last_sampling += increment
368
425
  return data
369
426
 
370
427
  return rar_step_true, rar_step_false
jinns/solver/_seq2seq.py CHANGED
@@ -88,7 +88,7 @@ def initialize_seq2seq(loss, data, seq2seq, opt_state):
88
88
  data.curr_omega_idx = 0
89
89
  data.generate_time_data()
90
90
  data._key, data.times, _ = _reset_batch_idx_and_permute(
91
- (data._key, data.times, data.curr_omega_idx, None, data.p)
91
+ (data._key, data.times, data.curr_omega_idx, None, data.p_times)
92
92
  )
93
93
  opt_state.hyperparams["learning_rate"] = seq2seq["learning_rate"][curr_seq]
94
94
 
@@ -145,7 +145,7 @@ def _update_seq2seq_SystemLossODE(operands):
145
145
  data.curr_omega_idx = 0
146
146
  data.generate_time_data()
147
147
  data._key, data.times, _ = _reset_batch_idx_and_permute(
148
- (data._key, data.times, data.curr_omega_idx, None, data.p)
148
+ (data._key, data.times, data.curr_omega_idx, None, data.p_times)
149
149
  )
150
150
 
151
151
  opt_state.hyperparams["learning_rate"] = seq2seq["learning_rate"][curr_seq]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jinns
3
- Version: 0.8.7
3
+ Version: 0.8.9
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -1,9 +1,11 @@
1
- jinns/__init__.py,sha256=Nw5pdlmDhJwco3bXX3YttkeCF8czX_6m0poh8vu0lDQ,113
2
- jinns/data/_DataGenerators.py,sha256=N4-U4z3MG46UIzHCbKScv9Z7AN40w1wlLY_VsVNj2sI,62293
1
+ jinns/__init__.py,sha256=T2XlmLbYqcXTumPJL00cJ80W98We5LH8Yg_Lss_exl4,139
2
+ jinns/data/_DataGenerators.py,sha256=_um-giHQ8mCILUOJHX231njHTHZp4S7EcGrUs7R1dUs,61829
3
3
  jinns/data/__init__.py,sha256=yBOmoavSD-cABp4XcjQY1zsEVO0mDyIhi2MJ5WNp0l8,326
4
- jinns/data/_display.py,sha256=6renz4H7kHktutmLY7HM6PmxYH7cBfGHpC7GQa1Fnlk,7778
5
- jinns/experimental/__init__.py,sha256=3jCIy2R2i_0Erwxg-HwISdH79Nt1XCXhS9yY1F5awiY,208
4
+ jinns/data/_display.py,sha256=vlqggDCgVMEwdGBtjVmZaTQORU6imSfDkssn2XCtITI,10392
5
+ jinns/experimental/__init__.py,sha256=qWbhC7Z8UgLWy0t-zU7RYze6v13-FngiCYXu-2bRVFQ,296
6
6
  jinns/experimental/_diffrax_solver.py,sha256=sLT22byqh-6015_fhe1xtMWlFOYcCjzYKET4sLhA9R4,6818
7
+ jinns/experimental/_sinuspinn.py,sha256=hxSzscwMV2LayWOqenIlT1zqEVVrE5Y8CKf7bHX5XFQ,5016
8
+ jinns/experimental/_spectralpinn.py,sha256=-4795pa7AYtRNSE-ugan3gHh64mtu2VdrRG5AS_J9Eg,2654
7
9
  jinns/loss/_DynamicLoss.py,sha256=L4CVmmF0rTPbHntgqsLLHlnrlQgLHsetUocpJm7ZYag,27461
8
10
  jinns/loss/_DynamicLossAbstract.py,sha256=kTQlhLx7SBuH5dIDmYaE79sVHUZt1nUFa8LxPU5IHhM,8504
9
11
  jinns/loss/_LossODE.py,sha256=b9doBHoQwYvlgpqzrNO4dOaTN87LRvjHtHbz9bMoH7E,22119
@@ -13,8 +15,8 @@ jinns/loss/__init__.py,sha256=pFNYUxns-NPXBFdqrEVSiXkQLfCtKw-t2trlhvLzpYE,355
13
15
  jinns/loss/_boundary_conditions.py,sha256=YfSnLZ25hXqQ5KWAuxOrWSKkf_oBqAc9GQV4z7MjWyQ,17434
14
16
  jinns/loss/_operators.py,sha256=zDGJqYqeYH7xd-4dtGX9PS-pf0uSOpUUXGo5SVjIJ4o,11069
15
17
  jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
- jinns/solver/_rar.py,sha256=K-0y1-ofOAo1n_Ea3QShSGCGKVYTwiaE_Bz9-DZMJm8,14525
17
- jinns/solver/_seq2seq.py,sha256=FL-42hTgmVl7O3hHh1ccFVw2bT8bW82hvlDRz971Chk,5620
18
+ jinns/solver/_rar.py,sha256=IYP-jdbM0rbjBtxislrBYBuj49p9_QDOqejZKCHrKg8,17072
19
+ jinns/solver/_seq2seq.py,sha256=S6IPfsXpS_fbqIqAy01eUM7GBSBSkRzURan_J-iXXzI,5632
18
20
  jinns/solver/_solve.py,sha256=mGi0zaT_fK_QpBjTxof5Ix4mmfmnPi66CNJ3GQFZuo4,19099
19
21
  jinns/utils/__init__.py,sha256=44ms5UR6vMw3Nf6u4RCAzPFs4fom_YbBnH9mfne8m6k,313
20
22
  jinns/utils/_containers.py,sha256=eYD277fO7X4EfX7PUFCCl69r3JBfh1sCfq8LkL5gd6o,1495
@@ -27,8 +29,8 @@ jinns/utils/_utils.py,sha256=8dgvWXX9NT7_7-zltWp0C9tG45ZFNwXxueyxPBb4hjo,6740
27
29
  jinns/utils/_utils_uspinn.py,sha256=qcKcOw3zrwWSQyGVj6fD8c9GinHt_U6JWN_k0auTtXM,26039
28
30
  jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
29
31
  jinns/validation/_validation.py,sha256=KfetbzB0xTNdBcYLwFWjEtP63Tf9wJirlhgqLTJDyy4,6761
30
- jinns-0.8.7.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
31
- jinns-0.8.7.dist-info/METADATA,sha256=L0P7JvMGKrJHx9OjrtFsmNKEwdKA_RlufAbOBf5l10I,2482
32
- jinns-0.8.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
33
- jinns-0.8.7.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
34
- jinns-0.8.7.dist-info/RECORD,,
32
+ jinns-0.8.9.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
33
+ jinns-0.8.9.dist-info/METADATA,sha256=na97ODHPafvEMKGvmUq6XszFrjQ9LP8L2FtCY2gZ8oI,2482
34
+ jinns-0.8.9.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
35
+ jinns-0.8.9.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
36
+ jinns-0.8.9.dist-info/RECORD,,
File without changes
File without changes