jinns 0.6.0__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -48,6 +48,7 @@ class FisherKPP(PDENonStatio):
48
48
  """
49
49
  super().__init__(Tmax, eq_params_heterogeneity)
50
50
 
51
+ @PDENonStatio.evaluate_heterogeneous_parameters
51
52
  def evaluate(self, t, x, u, params):
52
53
  r"""
53
54
  Evaluate the dynamic loss at :math:`(t,x)`.
@@ -67,16 +68,12 @@ class FisherKPP(PDENonStatio):
67
68
  differential equation parameters and the neural network parameter
68
69
  """
69
70
  if isinstance(u, PINN):
70
- params["eq_params"] = self._eval_heterogeneous_parameters(
71
- params["eq_params"], t, x, self.eq_params_heterogeneity
72
- )
73
-
74
71
  # Note that the last dim of u is nec. 1
75
72
  u_ = lambda t, x: u(t, x, params)[0]
76
73
 
77
74
  du_dt = grad(u_, 0)(t, x)
78
75
 
79
- lap = _laplacian_rev(u, params, x, t)[..., None]
76
+ lap = _laplacian_rev(t, x, u, params)[..., None]
80
77
 
81
78
  return du_dt + self.Tmax * (
82
79
  -params["eq_params"]["D"] * lap
@@ -87,17 +84,12 @@ class FisherKPP(PDENonStatio):
87
84
  )
88
85
  )
89
86
  if isinstance(u, SPINN):
90
- x_grid = _get_grid(x)
91
- params["eq_params"] = self._eval_heterogeneous_parameters(
92
- params["eq_params"], t, x_grid, self.eq_params_heterogeneity
93
- )
94
-
95
87
  u_tx, du_dt = jax.jvp(
96
88
  lambda t: u(t, x, params),
97
89
  (t,),
98
90
  (jnp.ones_like(t),),
99
91
  )
100
- lap = _laplacian_fwd(u, params, x, t)[..., None]
92
+ lap = _laplacian_fwd(t, x, u, params)[..., None]
101
93
  return du_dt + self.Tmax * (
102
94
  -params["eq_params"]["D"] * lap
103
95
  - u_tx
@@ -160,12 +152,8 @@ class BurgerEquation(PDENonStatio):
160
152
  differential equation parameters and the neural network parameter
161
153
  """
162
154
  if isinstance(u, PINN):
163
- params["eq_params"] = self._eval_heterogeneous_parameters(
164
- params["eq_params"], t, x, self.eq_params_heterogeneity
165
- )
166
-
167
155
  # Note that the last dim of u is nec. 1
168
- u_ = lambda t, x: u(t, x, params)[0]
156
+ u_ = lambda t, x: jnp.squeeze(u(t, x, params)[u.slice_solution])
169
157
  du_dt = grad(u_, 0)
170
158
  du_dx = grad(u_, 1)
171
159
  d2u_dx2 = grad(
@@ -179,10 +167,6 @@ class BurgerEquation(PDENonStatio):
179
167
  )
180
168
 
181
169
  if isinstance(u, SPINN):
182
- x_grid = _get_grid(x)
183
- params["eq_params"] = self._eval_heterogeneous_parameters(
184
- params["eq_params"], t, x_grid, self.eq_params_heterogeneity
185
- )
186
170
  # d=2 JVP calls are expected since we have time and x
187
171
  # then with a batch of size B, we then have Bd JVP calls
188
172
  u_tx, du_dt = jax.jvp(
@@ -352,10 +336,6 @@ class FPENonStatioLoss2D(PDENonStatio):
352
336
  differential equation parameters and the neural network parameter
353
337
  """
354
338
  if isinstance(u, PINN):
355
- params["eq_params"] = self._eval_heterogeneous_parameters(
356
- params["eq_params"], t, x, self.eq_params_heterogeneity
357
- )
358
-
359
339
  # Note that the last dim of u is nec. 1
360
340
  u_ = lambda t, x: u(t, x, params)[0]
361
341
 
@@ -411,10 +391,6 @@ class FPENonStatioLoss2D(PDENonStatio):
411
391
 
412
392
  if isinstance(u, SPINN):
413
393
  x_grid = _get_grid(x)
414
- params["eq_params"] = self._eval_heterogeneous_parameters(
415
- params["eq_params"], t, x_grid, self.eq_params_heterogeneity
416
- )
417
-
418
394
  _, du_dt = jax.jvp(
419
395
  lambda t: u(t, x, params),
420
396
  (t,),
@@ -634,12 +610,12 @@ class MassConservation2DStatio(PDEStatio):
634
610
  if isinstance(u_dict[self.nn_key], PINN):
635
611
  u = u_dict[self.nn_key]
636
612
 
637
- return _div_rev(u, params, x)[..., None]
613
+ return _div_rev(None, x, u, params)[..., None]
638
614
 
639
615
  if isinstance(u_dict[self.nn_key], SPINN):
640
616
  u = u_dict[self.nn_key]
641
617
 
642
- return _div_fwd(u, params, x)[..., None]
618
+ return _div_fwd(None, x, u, params)[..., None]
643
619
  raise ValueError("u is not among the recognized types (PINN or SPINN)")
644
620
 
645
621
 
@@ -724,12 +700,12 @@ class NavierStokes2DStatio(PDEStatio):
724
700
  if isinstance(u_dict[self.u_key], PINN):
725
701
  u = u_dict[self.u_key]
726
702
 
727
- u_dot_nabla_x_u = _u_dot_nabla_times_u_rev(u, u_params, x)
703
+ u_dot_nabla_x_u = _u_dot_nabla_times_u_rev(None, x, u, u_params)
728
704
 
729
705
  p = lambda x: u_dict[self.p_key](x, p_params)
730
706
  jac_p = jacrev(p, 0)(x) # compute the gradient
731
707
 
732
- vec_laplacian_u = _vectorial_laplacian(u, u_params, x, u_vec_ndim=2)
708
+ vec_laplacian_u = _vectorial_laplacian(None, x, u, u_params, u_vec_ndim=2)
733
709
 
734
710
  # dynamic loss on x axis
735
711
  result_x = (
@@ -751,7 +727,7 @@ class NavierStokes2DStatio(PDEStatio):
751
727
  if isinstance(u_dict[self.u_key], SPINN):
752
728
  u = u_dict[self.u_key]
753
729
 
754
- u_dot_nabla_x_u = _u_dot_nabla_times_u_fwd(u, u_params, x)
730
+ u_dot_nabla_x_u = _u_dot_nabla_times_u_fwd(None, x, u, u_params)
755
731
 
756
732
  p = lambda x: u_dict[self.p_key](x, p_params)
757
733
 
@@ -761,7 +737,7 @@ class NavierStokes2DStatio(PDEStatio):
761
737
  _, dp_dy = jax.jvp(p, (x,), (tangent_vec_1,))
762
738
 
763
739
  vec_laplacian_u = jnp.moveaxis(
764
- _vectorial_laplacian(u, u_params, x, u_vec_ndim=2),
740
+ _vectorial_laplacian(None, x, u, u_params, u_vec_ndim=2),
765
741
  source=0,
766
742
  destination=-1,
767
743
  )
@@ -23,8 +23,8 @@ class DynamicLoss:
23
23
  Default None. A dict with the keys being the same as in eq_params
24
24
  and the value being either None (no heterogeneity) or a function
25
25
  which encodes for the spatio-temporal heterogeneity of the parameter.
26
- Such a function must be jittable and take three arguments `t`,
27
- `x` and `params["eq_params"]` even if one is not used. Therefore,
26
+ Such a function must be jittable and take four arguments `t`, `x`,
27
+ `u` and `params` even if one is not used. Therefore,
28
28
  one can introduce spatio-temporal covariates upon which a particular
29
29
  parameter can depend, e.g. in a GLM fashion. The effect of these
30
30
  covariables can themselves be estimated by being in `eq_params` too.
@@ -35,20 +35,24 @@ class DynamicLoss:
35
35
  self.Tmax = Tmax
36
36
  self.eq_params_heterogeneity = eq_params_heterogeneity
37
37
 
38
- def _eval_heterogeneous_parameters(
39
- self, eq_params, t, x, eq_params_heterogeneity=None
40
- ):
38
+ @staticmethod
39
+ def _eval_heterogeneous_parameters(t, x, u, params, eq_params_heterogeneity=None):
41
40
  eq_params_ = {}
42
41
  if eq_params_heterogeneity is None:
43
- return eq_params
44
- for k, p in eq_params.items():
42
+ return params["eq_params"]
43
+ for k, p in params["eq_params"].items():
45
44
  try:
46
45
  if eq_params_heterogeneity[k] is None:
47
46
  eq_params_[k] = p
48
47
  else:
49
- eq_params_[k] = eq_params_heterogeneity[k](
50
- t, x, eq_params # heterogeneity encoded through a function
51
- )
48
+ if t is None:
49
+ eq_params_[k] = eq_params_heterogeneity[k](
50
+ x, u, params # heterogeneity encoded through a function
51
+ )
52
+ else:
53
+ eq_params_[k] = eq_params_heterogeneity[k](
54
+ t, x, u, params # heterogeneity encoded through a function
55
+ )
52
56
  except KeyError:
53
57
  # we authorize missing eq_params_heterogeneity key
54
58
  # if its heterogeneity is None anyway
@@ -79,6 +83,31 @@ class ODE(DynamicLoss):
79
83
  """
80
84
  super().__init__(Tmax, eq_params_heterogeneity)
81
85
 
86
+ def eval_heterogeneous_parameters(self, t, u, params, eq_params_heterogeneity=None):
87
+ return super()._eval_heterogeneous_parameters(
88
+ t, None, u, params, eq_params_heterogeneity
89
+ )
90
+
91
+ @staticmethod
92
+ def evaluate_heterogeneous_parameters(evaluate):
93
+ """
94
+ Decorator which aims to decorate the evaluate methods of Dynamic losses.
95
+ It calls _eval_heterogeneous_parameters which applies the
96
+ user defined rules to obtain spatially / temporally heterogeneous
97
+ parameters
98
+ """
99
+
100
+ def wrapper(*args):
101
+ self, t, u, params = args
102
+ params["eq_params"] = self.eval_heterogeneous_parameters(
103
+ t, u, params, self.eq_params_heterogeneity
104
+ )
105
+ new_args = args[:-1] + (params,)
106
+ res = evaluate(*new_args)
107
+ return res
108
+
109
+ return wrapper
110
+
82
111
 
83
112
  class PDEStatio(DynamicLoss):
84
113
  r"""
@@ -99,6 +128,31 @@ class PDEStatio(DynamicLoss):
99
128
  """
100
129
  super().__init__(eq_params_heterogeneity=eq_params_heterogeneity)
101
130
 
131
+ def eval_heterogeneous_parameters(self, x, u, params, eq_params_heterogeneity=None):
132
+ return super()._eval_heterogeneous_parameters(
133
+ None, x, u, params, eq_params_heterogeneity
134
+ )
135
+
136
+ @staticmethod
137
+ def evaluate_heterogeneous_parameters(evaluate):
138
+ """
139
+ Decorator which aims to decorate the evaluate methods of Dynamic losses.
140
+ It calls _eval_heterogeneous_parameters which applies the
141
+ user defined rules to obtain spatially / temporally heterogeneous
142
+ parameters
143
+ """
144
+
145
+ def wrapper(*args):
146
+ self, x, u, params = args
147
+ params["eq_params"] = self.eval_heterogeneous_parameters(
148
+ x, u, params, self.eq_params_heterogeneity
149
+ )
150
+ new_args = args[:-1] + (params,)
151
+ res = evaluate(*new_args)
152
+ return res
153
+
154
+ return wrapper
155
+
102
156
 
103
157
  class PDENonStatio(DynamicLoss):
104
158
  r"""
@@ -122,3 +176,30 @@ class PDENonStatio(DynamicLoss):
122
176
  heterogeneity for no parameters.
123
177
  """
124
178
  super().__init__(Tmax, eq_params_heterogeneity)
179
+
180
+ def eval_heterogeneous_parameters(
181
+ self, t, x, u, params, eq_params_heterogeneity=None
182
+ ):
183
+ return super()._eval_heterogeneous_parameters(
184
+ t, x, u, params, eq_params_heterogeneity
185
+ )
186
+
187
+ @staticmethod
188
+ def evaluate_heterogeneous_parameters(evaluate):
189
+ """
190
+ Decorator which aims to decorate the evaluate methods of Dynamic losses.
191
+ It calls _eval_heterogeneous_parameters which applies the
192
+ user defined rules to obtain spatially / temporally heterogeneous
193
+ parameters
194
+ """
195
+
196
+ def wrapper(*args):
197
+ self, t, x, u, params = args
198
+ params["eq_params"] = self.eval_heterogeneous_parameters(
199
+ t, x, u, params, self.eq_params_heterogeneity
200
+ )
201
+ new_args = args[:-1] + (params,)
202
+ res = evaluate(*new_args)
203
+ return res
204
+
205
+ return wrapper
jinns/loss/_LossPDE.py CHANGED
@@ -254,6 +254,7 @@ class LossPDEStatio(LossPDEAbstract):
254
254
  derivative_keys=None,
255
255
  omega_boundary_fun=None,
256
256
  omega_boundary_condition=None,
257
+ omega_boundary_dim=None,
257
258
  norm_key=None,
258
259
  norm_borders=None,
259
260
  norm_samples=None,
@@ -310,6 +311,12 @@ class LossPDEStatio(LossPDEAbstract):
310
311
  enforce a particular boundary condition on this facet.
311
312
  The facet called "xmin", resp. "xmax" etc., in 2D,
312
313
  refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
314
+ omega_boundary_dim
315
+ Either None, or a jnp.s_ or a dict of jnp.s_ with keys following
316
+ the logic of omega_boundary_fun. It indicates which dimension(s) of
317
+ the PINN will be forced to match the boundary condition
318
+ Note that it must be a slice and not an integer (a preprocessing of the
319
+ user provided argument takes care of it)
313
320
  norm_key
314
321
  Jax random key to draw samples in for the Monte Carlo computation
315
322
  of the normalization constant. Default is None
@@ -345,7 +352,7 @@ class LossPDEStatio(LossPDEAbstract):
345
352
  raise ValueError(
346
353
  f"obs_batch must be a list of size 2. You gave {len(obs_batch)}"
347
354
  )
348
- if any(isinstance(b, jnp.ndarray) for b in obs_batch):
355
+ if not all(isinstance(b, jnp.ndarray) for b in obs_batch):
349
356
  raise ValueError("Every element of obs_batch should be a jnp.array.")
350
357
  n_obs = obs_batch[0].shape[0]
351
358
  if any(b.shape[0] != n_obs for b in obs_batch):
@@ -411,16 +418,49 @@ class LossPDEStatio(LossPDEAbstract):
411
418
 
412
419
  self.omega_boundary_fun = omega_boundary_fun
413
420
  self.omega_boundary_condition = omega_boundary_condition
421
+
422
+ self.omega_boundary_dim = omega_boundary_dim
423
+ if isinstance(self.omega_boundary_fun, dict):
424
+ if self.omega_boundary_dim is None:
425
+ self.omega_boundary_dim = {
426
+ k: jnp.s_[::] for k in self.omega_boundary_fun.keys()
427
+ }
428
+ if list(self.omega_boundary_dim.keys()) != list(
429
+ self.omega_boundary_fun.keys()
430
+ ):
431
+ raise ValueError(
432
+ "If omega_boundary_fun is a dict,"
433
+ " omega_boundary_dim should be a dict with the same keys"
434
+ )
435
+ for k, v in self.omega_boundary_dim.items():
436
+ if isinstance(v, int):
437
+ # rewrite it as a slice to ensure that axis does not disappear when
438
+ # indexing
439
+ self.omega_boundary_dim[k] = jnp.s_[v : v + 1]
440
+
441
+ else:
442
+ if self.omega_boundary_dim is None:
443
+ self.omega_boundary_dim = jnp.s_[::]
444
+ if isinstance(self.omega_boundary_dim, int):
445
+ # rewrite it as a slice to ensure that axis does not disappear when
446
+ # indexing
447
+ self.omega_boundary_dim = jnp.s_[
448
+ self.omega_boundary_dim : self.omega_boundary_dim + 1
449
+ ]
450
+ if not isinstance(self.omega_boundary_dim, slice):
451
+ raise ValueError("self.omega_boundary_dim must be a jnp.s_" " object")
452
+
414
453
  self.dynamic_loss = dynamic_loss
415
454
  self.obs_batch = obs_batch
416
455
 
417
- if sobolev_m is not None:
456
+ self.sobolev_m = sobolev_m
457
+ if self.sobolev_m is not None:
418
458
  self.sobolev_reg = _sobolev(
419
- self.u, sobolev_m
459
+ self.u, self.sobolev_m
420
460
  ) # we return a function, that way
421
461
  # the order of sobolev_m is static and the conditional in the recursive
422
462
  # function is properly set
423
- self.sobolev_m = sobolev_m
463
+ self.sobolev_m = self.sobolev_m
424
464
  else:
425
465
  self.sobolev_reg = None
426
466
 
@@ -512,7 +552,7 @@ class LossPDEStatio(LossPDEAbstract):
512
552
  if self.normalization_loss is not None:
513
553
  if isinstance(self.u, PINN):
514
554
  v_u = vmap(
515
- lambda x: self.u(x, params_),
555
+ lambda x: self.u(x, params_)[self.u.slice_solution],
516
556
  (0),
517
557
  0,
518
558
  )
@@ -560,6 +600,7 @@ class LossPDEStatio(LossPDEAbstract):
560
600
  self.u,
561
601
  params_,
562
602
  idx,
603
+ self.omega_boundary_dim[facet],
563
604
  )
564
605
  )
565
606
  else:
@@ -574,6 +615,7 @@ class LossPDEStatio(LossPDEAbstract):
574
615
  self.u,
575
616
  params_,
576
617
  facet,
618
+ self.omega_boundary_dim,
577
619
  )
578
620
  )
579
621
  else:
@@ -585,11 +627,11 @@ class LossPDEStatio(LossPDEAbstract):
585
627
  # TODO implement for SPINN
586
628
  if isinstance(self.u, PINN):
587
629
  v_u = vmap(
588
- lambda x: self.u(x, params_),
630
+ lambda x: self.u(x, params_)[self.u.slice_solution],
589
631
  0,
590
632
  0,
591
633
  )
592
- val = v_u(self.obs_batch[0][:, None])
634
+ val = v_u(self.obs_batch[0])
593
635
  mse_observation_loss = jnp.mean(
594
636
  self.loss_weights["observations"]
595
637
  * jnp.sum(
@@ -655,6 +697,7 @@ class LossPDEStatio(LossPDEAbstract):
655
697
  "derivative_keys": self.derivative_keys,
656
698
  "omega_boundary_fun": self.omega_boundary_fun,
657
699
  "omega_boundary_condition": self.omega_boundary_condition,
700
+ "omega_boundary_dim": self.omega_boundary_dim,
658
701
  "norm_borders": self.norm_borders,
659
702
  "sobolev_m": self.sobolev_m,
660
703
  }
@@ -670,6 +713,7 @@ class LossPDEStatio(LossPDEAbstract):
670
713
  aux_data["derivative_keys"],
671
714
  aux_data["omega_boundary_fun"],
672
715
  aux_data["omega_boundary_condition"],
716
+ aux_data["omega_boundary_dim"],
673
717
  norm_key,
674
718
  aux_data["norm_borders"],
675
719
  norm_samples,
@@ -706,6 +750,7 @@ class LossPDENonStatio(LossPDEStatio):
706
750
  derivative_keys=None,
707
751
  omega_boundary_fun=None,
708
752
  omega_boundary_condition=None,
753
+ omega_boundary_dim=None,
709
754
  initial_condition_fun=None,
710
755
  norm_key=None,
711
756
  norm_borders=None,
@@ -760,6 +805,12 @@ class LossPDENonStatio(LossPDEStatio):
760
805
  enforce a particular boundary condition on this facet.
761
806
  The facet called "xmin", resp. "xmax" etc., in 2D,
762
807
  refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
808
+ omega_boundary_dim
809
+ Either None, or a jnp.s_ or a dict of jnp.s_ with keys following
810
+ the logic of omega_boundary_fun. It indicates which dimension(s) of
811
+ the PINN will be forced to match the boundary condition
812
+ Note that it must be a slice and not an integer (a preprocessing of the
813
+ user provided argument takes care of it)
763
814
  initial_condition_fun
764
815
  A function representing the temporal initial condition. If None
765
816
  (default) then no initial condition is applied
@@ -815,6 +866,7 @@ class LossPDENonStatio(LossPDEStatio):
815
866
  derivative_keys,
816
867
  omega_boundary_fun,
817
868
  omega_boundary_condition,
869
+ omega_boundary_dim,
818
870
  norm_key,
819
871
  norm_borders,
820
872
  norm_samples,
@@ -978,6 +1030,7 @@ class LossPDENonStatio(LossPDEStatio):
978
1030
  self.u,
979
1031
  params_,
980
1032
  idx,
1033
+ self.omega_boundary_dim[facet],
981
1034
  )
982
1035
  )
983
1036
  else:
@@ -993,6 +1046,7 @@ class LossPDENonStatio(LossPDEStatio):
993
1046
  self.u,
994
1047
  params_,
995
1048
  facet,
1049
+ self.omega_boundary_dim,
996
1050
  )
997
1051
  )
998
1052
  else:
@@ -1036,7 +1090,7 @@ class LossPDENonStatio(LossPDEStatio):
1036
1090
  # TODO implement for SPINN
1037
1091
  if isinstance(self.u, PINN):
1038
1092
  v_u = vmap(
1039
- lambda t, x: self.u(t, x, params_),
1093
+ lambda t, x: self.u(t, x, params_)[self.u.slice_solution],
1040
1094
  (0, 0),
1041
1095
  0,
1042
1096
  )
@@ -1108,6 +1162,7 @@ class LossPDENonStatio(LossPDEStatio):
1108
1162
  "derivative_keys": self.derivative_keys,
1109
1163
  "omega_boundary_fun": self.omega_boundary_fun,
1110
1164
  "omega_boundary_condition": self.omega_boundary_condition,
1165
+ "omega_boundary_dim": self.omega_boundary_dim,
1111
1166
  "initial_condition_fun": self.initial_condition_fun,
1112
1167
  "norm_borders": self.norm_borders,
1113
1168
  "sobolev_m": self.sobolev_m,
@@ -1124,6 +1179,7 @@ class LossPDENonStatio(LossPDEStatio):
1124
1179
  aux_data["derivative_keys"],
1125
1180
  aux_data["omega_boundary_fun"],
1126
1181
  aux_data["omega_boundary_condition"],
1182
+ aux_data["omega_boundary_dim"],
1127
1183
  aux_data["initial_condition_fun"],
1128
1184
  norm_key,
1129
1185
  aux_data["norm_borders"],
@@ -1163,6 +1219,7 @@ class SystemLossPDE:
1163
1219
  derivative_keys_dict=None,
1164
1220
  omega_boundary_fun_dict=None,
1165
1221
  omega_boundary_condition_dict=None,
1222
+ omega_boundary_dim_dict=None,
1166
1223
  initial_condition_fun_dict=None,
1167
1224
  norm_key_dict=None,
1168
1225
  norm_borders_dict=None,
@@ -1199,15 +1256,17 @@ class SystemLossPDE:
1199
1256
  are not specified then the default behaviour for `derivative_keys`
1200
1257
  of LossODE is used
1201
1258
  omega_boundary_fun_dict
1202
- A dict of functions to be matched in the border condition, or a
1203
- dict of dict of functions (see doc for `omega_boundary_fun` in
1204
- LossPDEStatio or LossPDENonStatio).
1205
- Must share the keys of `u_dict`
1259
+ A dict of dict of functions (see doc for `omega_boundary_fun` in
1260
+ LossPDEStatio or LossPDENonStatio). Default is None.
1261
+ Must share the keys of `u_dict`.
1206
1262
  omega_boundary_condition_dict
1207
- A dict of either None (no condition), or a string defining the boundary
1208
- condition e.g. Dirichlet or Von Neumann, or a dict of dict of
1209
- strings (see doc for `omega_boundary_fun` in
1210
- LossPDEStatio or LossPDENonStatio).
1263
+ A dict of dict of strings (see doc for
1264
+ `omega_boundary_condition` in
1265
+ LossPDEStatio or LossPDENonStatio). Default is None.
1266
+ Must share the keys of `u_dict`
1267
+ omega_boundary_dim_dict
1268
+ A dict of dict of slices (see doc for `omega_boundary_dim` in
1269
+ LossPDEStatio or LossPDENonStatio). Default is None.
1211
1270
  Must share the keys of `u_dict`
1212
1271
  initial_condition_fun_dict
1213
1272
  A dict of functions representing the temporal initial condition. If None
@@ -1265,6 +1324,10 @@ class SystemLossPDE:
1265
1324
  self.omega_boundary_condition_dict = {k: None for k in u_dict.keys()}
1266
1325
  else:
1267
1326
  self.omega_boundary_condition_dict = omega_boundary_condition_dict
1327
+ if omega_boundary_dim_dict is None:
1328
+ self.omega_boundary_dim_dict = {k: None for k in u_dict.keys()}
1329
+ else:
1330
+ self.omega_boundary_dim_dict = omega_boundary_dim_dict
1268
1331
  if initial_condition_fun_dict is None:
1269
1332
  self.initial_condition_fun_dict = {k: None for k in u_dict.keys()}
1270
1333
  else:
@@ -1309,6 +1372,7 @@ class SystemLossPDE:
1309
1372
  or u_dict.keys() != self.obs_batch_dict.keys()
1310
1373
  or u_dict.keys() != self.omega_boundary_fun_dict.keys()
1311
1374
  or u_dict.keys() != self.omega_boundary_condition_dict.keys()
1375
+ or u_dict.keys() != self.omega_boundary_dim_dict.keys()
1312
1376
  or u_dict.keys() != self.initial_condition_fun_dict.keys()
1313
1377
  or u_dict.keys() != self.norm_key_dict.keys()
1314
1378
  or u_dict.keys() != self.norm_borders_dict.keys()
@@ -1345,6 +1409,7 @@ class SystemLossPDE:
1345
1409
  derivative_keys=self.derivative_keys_dict[i],
1346
1410
  omega_boundary_fun=self.omega_boundary_fun_dict[i],
1347
1411
  omega_boundary_condition=self.omega_boundary_condition_dict[i],
1412
+ omega_boundary_dim=self.omega_boundary_dim_dict[i],
1348
1413
  norm_key=self.norm_key_dict[i],
1349
1414
  norm_borders=self.norm_borders_dict[i],
1350
1415
  norm_samples=self.norm_samples_dict[i],
@@ -1366,6 +1431,7 @@ class SystemLossPDE:
1366
1431
  derivative_keys=self.derivative_keys_dict[i],
1367
1432
  omega_boundary_fun=self.omega_boundary_fun_dict[i],
1368
1433
  omega_boundary_condition=self.omega_boundary_condition_dict[i],
1434
+ omega_boundary_dim=self.omega_boundary_dim_dict[i],
1369
1435
  initial_condition_fun=self.initial_condition_fun_dict[i],
1370
1436
  norm_key=self.norm_key_dict[i],
1371
1437
  norm_borders=self.norm_borders_dict[i],
@@ -11,7 +11,7 @@ from jinns.utils._spinn import SPINN
11
11
 
12
12
 
13
13
  def _compute_boundary_loss_statio(
14
- boundary_condition_type, f, border_batch, u, params, facet
14
+ boundary_condition_type, f, border_batch, u, params, facet, dim_to_apply
15
15
  ):
16
16
  r"""A generic function that will compute the mini-batch MSE of a
17
17
  boundary condition in the stationary case, given by:
@@ -44,6 +44,9 @@ def _compute_boundary_loss_statio(
44
44
  facet:
45
45
  An integer which represents the id of the facet which is currently
46
46
  considered (in the order provided by the DataGenerator which is fixed)
47
+ dim_to_apply
48
+ A jnp.s_ object which indicates which dimension(s) of u will be forced
49
+ to match the boundary condition
47
50
 
48
51
  Returns
49
52
  -------
@@ -51,17 +54,24 @@ def _compute_boundary_loss_statio(
51
54
  the MSE computed on `border_batch`
52
55
  """
53
56
  if boundary_condition_type.lower() in "dirichlet":
54
- mse = boundary_dirichlet_statio(f, border_batch, u, params)
57
+ mse = boundary_dirichlet_statio(f, border_batch, u, params, dim_to_apply)
55
58
  elif any(
56
59
  boundary_condition_type.lower() in s
57
60
  for s in ["von neumann", "vn", "vonneumann"]
58
61
  ):
59
- mse = boundary_neumann_statio(f, border_batch, u, params, facet)
62
+ mse = boundary_neumann_statio(f, border_batch, u, params, facet, dim_to_apply)
60
63
  return mse
61
64
 
62
65
 
63
66
  def _compute_boundary_loss_nonstatio(
64
- boundary_condition_type, f, times_batch, border_batch, u, params, facet
67
+ boundary_condition_type,
68
+ f,
69
+ times_batch,
70
+ border_batch,
71
+ u,
72
+ params,
73
+ facet,
74
+ dim_to_apply,
65
75
  ):
66
76
  r"""A generic function that will compute the mini-batch MSE of a
67
77
  boundary condition in the non-stationary case, given by:
@@ -95,6 +105,8 @@ def _compute_boundary_loss_nonstatio(
95
105
  facet:
96
106
  An integer which represents the id of the facet which is currently
97
107
  considered (in the order provided by the DataGenerator which is fixed)
108
+ dim_to_apply
109
+ A jnp.s_ object. The dimension of u on which to apply the boundary condition
98
110
 
99
111
  Returns
100
112
  -------
@@ -102,16 +114,20 @@ def _compute_boundary_loss_nonstatio(
102
114
  the MSE computed on `border_batch`
103
115
  """
104
116
  if boundary_condition_type.lower() in "dirichlet":
105
- mse = boundary_dirichlet_nonstatio(f, times_batch, border_batch, u, params)
117
+ mse = boundary_dirichlet_nonstatio(
118
+ f, times_batch, border_batch, u, params, dim_to_apply
119
+ )
106
120
  elif any(
107
121
  boundary_condition_type.lower() in s
108
122
  for s in ["von neumann", "vn", "vonneumann"]
109
123
  ):
110
- mse = boundary_neumann_nonstatio(f, times_batch, border_batch, u, params, facet)
124
+ mse = boundary_neumann_nonstatio(
125
+ f, times_batch, border_batch, u, params, facet, dim_to_apply
126
+ )
111
127
  return mse
112
128
 
113
129
 
114
- def boundary_dirichlet_statio(f, border_batch, u, params):
130
+ def boundary_dirichlet_statio(f, border_batch, u, params, dim_to_apply):
115
131
  r"""
116
132
  This omega boundary condition enforces a solution that is equal to f on
117
133
  border batch.
@@ -129,17 +145,19 @@ def boundary_dirichlet_statio(f, border_batch, u, params):
129
145
  Typically, it is a dictionary of
130
146
  dictionaries: `eq_params` and `nn_params``, respectively the
131
147
  differential equation parameters and the neural network parameter
148
+ dim_to_apply
149
+ A jnp.s_ object. The dimension of u on which to apply the boundary condition
132
150
  """
133
151
  if isinstance(u, PINN):
134
152
  v_u_boundary = vmap(
135
- lambda dx: u(dx, params) - f(dx),
153
+ lambda dx: u(dx, params)[dim_to_apply] - f(dx),
136
154
  (0),
137
155
  0,
138
156
  )
139
157
 
140
158
  mse_u_boundary = jnp.sum((v_u_boundary(border_batch)) ** 2, axis=-1)
141
159
  elif isinstance(u, SPINN):
142
- values = u(border_batch, params)
160
+ values = u(border_batch, params)[..., dim_to_apply]
143
161
  x_grid = _get_grid(border_batch)
144
162
  boundaries = _check_user_func_return(f(x_grid), values.shape)
145
163
  res = values - boundaries
@@ -150,7 +168,7 @@ def boundary_dirichlet_statio(f, border_batch, u, params):
150
168
  return mse_u_boundary
151
169
 
152
170
 
153
- def boundary_neumann_statio(f, border_batch, u, params, facet):
171
+ def boundary_neumann_statio(f, border_batch, u, params, facet, dim_to_apply):
154
172
  r"""
155
173
  This omega boundary condition enforces a solution where :math:`\nabla u\cdot
156
174
  n` is equal to `f` on omega borders. :math:`n` is the unitary
@@ -173,6 +191,8 @@ def boundary_neumann_statio(f, border_batch, u, params, facet):
173
191
  facet:
174
192
  An integer which represents the id of the facet which is currently
175
193
  considered (in the order provided by the DataGenerator which is fixed)
194
+ dim_to_apply
195
+ A jnp.s_ object. The dimension of u on which to apply the boundary condition
176
196
  """
177
197
  # We resort to the shape of the border_batch to determine the dimension as
178
198
  # described in the border_batch function
@@ -186,7 +206,7 @@ def boundary_neumann_statio(f, border_batch, u, params, facet):
186
206
  n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])
187
207
 
188
208
  if isinstance(u, PINN):
189
- u_ = lambda x, params: u(x, params)[0]
209
+ u_ = lambda x, params: jnp.squeeze(u(x, params)[dim_to_apply])
190
210
  v_neumann = vmap(
191
211
  lambda dx: jnp.dot(
192
212
  grad(u_, 0)(dx, params),
@@ -206,7 +226,7 @@ def boundary_neumann_statio(f, border_batch, u, params, facet):
206
226
  lambda x: u(
207
227
  x,
208
228
  params,
209
- ),
229
+ )[..., dim_to_apply],
210
230
  (border_batch,),
211
231
  (jnp.ones_like(border_batch),),
212
232
  )
@@ -246,7 +266,9 @@ def boundary_neumann_statio(f, border_batch, u, params, facet):
246
266
  return mse_u_boundary
247
267
 
248
268
 
249
- def boundary_dirichlet_nonstatio(f, times_batch, omega_border_batch, u, params):
269
+ def boundary_dirichlet_nonstatio(
270
+ f, times_batch, omega_border_batch, u, params, dim_to_apply
271
+ ):
250
272
  """
251
273
  This omega boundary condition enforces a solution that is equal to f
252
274
  at times_batch x omega borders
@@ -266,6 +288,8 @@ def boundary_dirichlet_nonstatio(f, times_batch, omega_border_batch, u, params):
266
288
  Typically, it is a dictionary of
267
289
  dictionaries: `eq_params` and `nn_params`, respectively the
268
290
  differential equation parameters and the neural network parameter
291
+ dim_to_apply
292
+ A jnp.s_ object. The dimension of u on which to apply the boundary condition
269
293
  """
270
294
  if isinstance(u, PINN):
271
295
  tile_omega_border_batch = jnp.tile(
@@ -280,7 +304,7 @@ def boundary_dirichlet_nonstatio(f, times_batch, omega_border_batch, u, params):
280
304
  t,
281
305
  dx,
282
306
  params,
283
- )
307
+ )[dim_to_apply]
284
308
  - f(t, dx),
285
309
  (0, 0),
286
310
  0,
@@ -304,7 +328,7 @@ def boundary_dirichlet_nonstatio(f, times_batch, omega_border_batch, u, params):
304
328
  # otherwise we require batches to have same shape and we do not need
305
329
  # this operation
306
330
 
307
- values = u(times_batch, tile_omega_border_batch, params)
331
+ values = u(times_batch, tile_omega_border_batch, params)[..., dim_to_apply]
308
332
  tx_grid = _get_grid(jnp.concatenate([times_batch, omega_border_batch], axis=-1))
309
333
  boundaries = _check_user_func_return(
310
334
  f(tx_grid[..., 0:1], tx_grid[..., 1:]), values.shape
@@ -314,7 +338,9 @@ def boundary_dirichlet_nonstatio(f, times_batch, omega_border_batch, u, params):
314
338
  return mse_u_boundary
315
339
 
316
340
 
317
- def boundary_neumann_nonstatio(f, times_batch, omega_border_batch, u, params, facet):
341
+ def boundary_neumann_nonstatio(
342
+ f, times_batch, omega_border_batch, u, params, facet, dim_to_apply
343
+ ):
318
344
  r"""
319
345
  This omega boundary condition enforces a solution where :math:`\nabla u\cdot
320
346
  n` is equal to `f` at time_batch x omega borders. :math:`n` is the unitary
@@ -338,6 +364,8 @@ def boundary_neumann_nonstatio(f, times_batch, omega_border_batch, u, params, fa
338
364
  facet:
339
365
  An integer which represents the id of the facet which is currently
340
366
  considered (in the order provided by the DataGenerator which is fixed)
367
+ dim_to_apply
368
+ A jnp.s_ object. The dimension of u on which to apply the boundary condition
341
369
  """
342
370
  # We resort to the shape of the border_batch to determine the dimension as
343
371
  # described in the border_batch function
@@ -358,7 +386,7 @@ def boundary_neumann_nonstatio(f, times_batch, omega_border_batch, u, params, fa
358
386
  def rep_times(k):
359
387
  return jnp.repeat(times_batch, k, axis=0)
360
388
 
361
- u_ = lambda t, x, params: u(t, x, params)[0]
389
+ u_ = lambda t, x, params: jnp.squeeze(u(t, x, params)[dim_to_apply])
362
390
  v_neumann = vmap(
363
391
  lambda t, dx: jnp.dot(
364
392
  grad(u_, 1)(t, dx, params),
@@ -389,7 +417,7 @@ def boundary_neumann_nonstatio(f, times_batch, omega_border_batch, u, params, fa
389
417
  # high dim output at once
390
418
  if omega_border_batch.shape[0] == 1: # i.e. case 1D
391
419
  _, du_dx = jax.jvp(
392
- lambda x: u(times_batch, x, params),
420
+ lambda x: u(times_batch, x, params)[..., dim_to_apply],
393
421
  (omega_border_batch,),
394
422
  (jnp.ones_like(omega_border_batch),),
395
423
  )
@@ -402,12 +430,12 @@ def boundary_neumann_nonstatio(f, times_batch, omega_border_batch, u, params, fa
402
430
  jnp.array([0.0, 1.0])[None], omega_border_batch.shape[0], axis=0
403
431
  )
404
432
  _, du_dx1 = jax.jvp(
405
- lambda x: u(times_batch, x, params),
433
+ lambda x: u(times_batch, x, params)[..., dim_to_apply],
406
434
  (omega_border_batch,),
407
435
  (tangent_vec_0,),
408
436
  )
409
437
  _, du_dx2 = jax.jvp(
410
- lambda x: u(times_batch, x, params),
438
+ lambda x: u(times_batch, x, params)[..., dim_to_apply],
411
439
  (omega_border_batch,),
412
440
  (tangent_vec_1,),
413
441
  )
jinns/loss/_operators.py CHANGED
@@ -9,7 +9,7 @@ from jinns.utils._pinn import PINN
9
9
  from jinns.utils._spinn import SPINN
10
10
 
11
11
 
12
- def _div_rev(u, params, x, t=None):
12
+ def _div_rev(t, x, u, params):
13
13
  r"""
14
14
  Compute the divergence of a vector field :math:`\mathbf{u}`, i.e.,
15
15
  :math:`\nabla \cdot \mathbf{u}(\mathbf{x})` with :math:`\mathbf{u}` a vector
@@ -21,14 +21,14 @@ def _div_rev(u, params, x, t=None):
21
21
  if t is None:
22
22
  du_dxi = grad(lambda x, params: u(x, params)[i], 0)(x, params)[i]
23
23
  else:
24
- du_dxi = grad(lambda t, x, params: u(x, params)[i], 1)(x, params)[i]
24
+ du_dxi = grad(lambda t, x, params: u(t, x, params)[i], 1)(t, x, params)[i]
25
25
  return _, du_dxi
26
26
 
27
27
  _, accu = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[0]))
28
28
  return jnp.sum(accu)
29
29
 
30
30
 
31
- def _div_fwd(u, params, x, t=None):
31
+ def _div_fwd(t, x, u, params):
32
32
  r"""
33
33
  Compute the divergence of a **batched** vector field :math:`\mathbf{u}`, i.e.,
34
34
  :math:`\nabla \cdot \mathbf{u}(\mathbf{x})` with :math:`\mathbf{u}` a vector
@@ -55,7 +55,7 @@ def _div_fwd(u, params, x, t=None):
55
55
  return jnp.sum(accu, axis=0)
56
56
 
57
57
 
58
- def _laplacian_rev(u, params, x, t=None):
58
+ def _laplacian_rev(t, x, u, params):
59
59
  r"""
60
60
  Compute the Laplacian of a scalar field :math:`u` (from :math:`\mathbb{R}^d`
61
61
  to :math:`\mathbb{R}`) for :math:`\mathbf{x}` of arbitrary dimension, i.e.,
@@ -98,7 +98,7 @@ def _laplacian_rev(u, params, x, t=None):
98
98
  # return jnp.sum(trace_hessian)
99
99
 
100
100
 
101
- def _laplacian_fwd(u, params, x, t=None):
101
+ def _laplacian_fwd(t, x, u, params):
102
102
  r"""
103
103
  Compute the Laplacian of a **batched** scalar field :math:`u`
104
104
  (from :math:`\mathbb{R}^{b\times d}` to :math:`\mathbb{R}^{b\times b}`)
@@ -134,7 +134,7 @@ def _laplacian_fwd(u, params, x, t=None):
134
134
  return jnp.sum(trace_hessian, axis=0)
135
135
 
136
136
 
137
- def _vectorial_laplacian(u, params, x, t=None, u_vec_ndim=None):
137
+ def _vectorial_laplacian(t, x, u, params, u_vec_ndim=None):
138
138
  r"""
139
139
  Compute the vectorial Laplacian of a vector field :math:`\mathbf{u}` (from
140
140
  :math:`\mathbb{R}^d`
@@ -163,7 +163,7 @@ def _vectorial_laplacian(u, params, x, t=None, u_vec_ndim=None):
163
163
  uj = lambda x, params: jnp.expand_dims(u(x, params)[j], axis=-1)
164
164
  else:
165
165
  uj = lambda t, x, params: jnp.expand_dims(u(t, x, params)[j], axis=-1)
166
- lap_on_j = _laplacian_rev(uj, params, x, t)
166
+ lap_on_j = _laplacian_rev(t, x, uj, params)
167
167
  elif isinstance(u, SPINN):
168
168
  if t is None:
169
169
  uj = lambda x, params: jnp.expand_dims(u(x, params)[..., j], axis=-1)
@@ -171,7 +171,7 @@ def _vectorial_laplacian(u, params, x, t=None, u_vec_ndim=None):
171
171
  uj = lambda t, x, params: jnp.expand_dims(
172
172
  u(t, x, params)[..., j], axis=-1
173
173
  )
174
- lap_on_j = _laplacian_fwd(uj, params, x, t)
174
+ lap_on_j = _laplacian_fwd(t, x, uj, params)
175
175
 
176
176
  return _, lap_on_j
177
177
 
@@ -179,7 +179,7 @@ def _vectorial_laplacian(u, params, x, t=None, u_vec_ndim=None):
179
179
  return vec_lap
180
180
 
181
181
 
182
- def _u_dot_nabla_times_u_rev(u, params, x, t=None):
182
+ def _u_dot_nabla_times_u_rev(t, x, u, params):
183
183
  r"""
184
184
  Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
185
185
  :math:`\mathbf{x}` of arbitrary
@@ -224,7 +224,7 @@ def _u_dot_nabla_times_u_rev(u, params, x, t=None):
224
224
  raise NotImplementedError("x.ndim must be 2")
225
225
 
226
226
 
227
- def _u_dot_nabla_times_u_fwd(u, params, x, t=None):
227
+ def _u_dot_nabla_times_u_fwd(t, x, u, params):
228
228
  r"""
229
229
  Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
230
230
  :math:`\mathbf{x}` of arbitrary dimension **with a batch dimension**.
jinns/utils/_pinn.py CHANGED
@@ -64,16 +64,44 @@ class PINN:
64
64
  The function create_PINN has the role to populate the `__call__` function
65
65
  """
66
66
 
67
- def __init__(self, key, eqx_list, output_slice=None):
67
+ def __init__(
68
+ self,
69
+ key,
70
+ eqx_list,
71
+ slice_solution,
72
+ eq_type,
73
+ input_transform,
74
+ output_transform,
75
+ output_slice=None,
76
+ ):
68
77
  _pinn = _MLP(key, eqx_list)
69
78
  self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
79
+ self.slice_solution = slice_solution
80
+ self.eq_type = eq_type
81
+ self.input_transform = input_transform
82
+ self.output_transform = output_transform
70
83
  self.output_slice = output_slice
71
84
 
72
85
  def init_params(self):
73
86
  return self.params
74
87
 
75
- def __call__(self, *args, **kwargs):
76
- return self.apply_fn(self, *args, **kwargs)
88
+ def __call__(self, *args):
89
+ if self.eq_type == "ODE":
90
+ (t, params) = args
91
+ t = t[None] # Add dimension which is lacking for the ODE batches
92
+ return self._eval_nn(
93
+ t, params, self.input_transform, self.output_transform
94
+ ).squeeze()
95
+ if self.eq_type == "statio_PDE":
96
+ (x, params) = args
97
+ return self._eval_nn(x, params, self.input_transform, self.output_transform)
98
+ if self.eq_type == "nonstatio_PDE":
99
+ (t, x, params) = args
100
+ t_x = jnp.concatenate([t, x], axis=-1)
101
+ return self._eval_nn(
102
+ t_x, params, self.input_transform, self.output_transform
103
+ )
104
+ raise ValueError("Wrong value for self.eq_type")
77
105
 
78
106
  def _eval_nn(self, inputs, params, input_transform, output_transform):
79
107
  """
@@ -82,7 +110,7 @@ class PINN:
82
110
  """
83
111
  try:
84
112
  model = eqx.combine(params["nn_params"], self.static)
85
- except: # give more flexibility
113
+ except (KeyError, TypeError) as e: # give more flexibility
86
114
  model = eqx.combine(params, self.static)
87
115
  res = output_transform(inputs, model(input_transform(inputs, params)).squeeze())
88
116
 
@@ -103,6 +131,7 @@ def create_PINN(
103
131
  input_transform=None,
104
132
  output_transform=None,
105
133
  shared_pinn_outputs=None,
134
+ slice_solution=None,
106
135
  ):
107
136
  """
108
137
  Utility function to create a standard PINN neural network with the equinox
@@ -152,6 +181,13 @@ def create_PINN(
152
181
  network. In this case we return a list of PINNs, one for each output in
153
182
  shared_pinn_outputs. This is useful to create PINNs that share the
154
183
  same network and same parameters. Default is None, we only return one PINN.
184
+ slice_solution
185
+ A jnp.s_ object which indicates which axis of the PINN output is
186
+ dedicated to the actual equation solution. Default None
187
+ means that slice_solution = the whole PINN output. This argument is useful
188
+ when the PINN is also used to output equation parameters, for example.
189
+ Note that it must be a slice and not an integer (a preprocessing of the
190
+ user provided argument takes care of it)
155
191
 
156
192
 
157
193
  Returns
@@ -182,6 +218,19 @@ def create_PINN(
182
218
  if eq_type != "ODE" and dim_x == 0:
183
219
  raise RuntimeError("Wrong parameter combination eq_type and dim_x")
184
220
 
221
+ try:
222
+ nb_outputs_declared = eqx_list[-1][2] # normally we look for 3rd ele of
223
+ # last layer
224
+ except IndexError:
225
+ nb_outputs_declared = eqx_list[-2][2]
226
+
227
+ if slice_solution is None:
228
+ slice_solution = jnp.s_[0:nb_outputs_declared]
229
+ if isinstance(slice_solution, int):
230
+ # rewrite it as a slice to ensure that axis does not disappear when
231
+ # indexing
232
+ slice_solution = jnp.s_[slice_solution : slice_solution + 1]
233
+
185
234
  if input_transform is None:
186
235
 
187
236
  def input_transform(_in, _params):
@@ -192,34 +241,19 @@ def create_PINN(
192
241
  def output_transform(_in_pinn, _out_pinn):
193
242
  return _out_pinn
194
243
 
195
- if eq_type == "ODE":
196
-
197
- def apply_fn(self, t, params):
198
- t = t[
199
- None
200
- ] # Note that we added a dimension to t which is lacking for the ODE batches
201
- return self._eval_nn(t, params, input_transform, output_transform).squeeze()
202
-
203
- elif eq_type == "statio_PDE":
204
- # Here we add an argument `x` which can be high dimensional
205
- def apply_fn(self, x, params):
206
- return self._eval_nn(x, params, input_transform, output_transform)
207
-
208
- elif eq_type == "nonstatio_PDE":
209
- # Here we add an argument `x` which can be high dimensional
210
- def apply_fn(self, t, x, params):
211
- t_x = jnp.concatenate([t, x], axis=-1)
212
- return self._eval_nn(t_x, params, input_transform, output_transform)
213
-
214
- else:
215
- raise RuntimeError("Wrong parameter value for eq_type")
216
-
217
244
  if shared_pinn_outputs is not None:
218
245
  pinns = []
219
246
  static = None
220
247
  for output_slice in shared_pinn_outputs:
221
- pinn = PINN(key, eqx_list, output_slice)
222
- pinn.apply_fn = apply_fn
248
+ pinn = PINN(
249
+ key,
250
+ eqx_list,
251
+ slice_solution,
252
+ eq_type,
253
+ input_transform,
254
+ output_transform,
255
+ output_slice,
256
+ )
223
257
  # all the pinns are in fact the same so we share the same static
224
258
  if static is None:
225
259
  static = pinn.static
@@ -227,6 +261,7 @@ def create_PINN(
227
261
  pinn.static = static
228
262
  pinns.append(pinn)
229
263
  return pinns
230
- pinn = PINN(key, eqx_list)
231
- pinn.apply_fn = apply_fn
264
+ pinn = PINN(
265
+ key, eqx_list, slice_solution, eq_type, input_transform, output_transform
266
+ )
232
267
  return pinn
jinns/utils/_spinn.py CHANGED
@@ -82,16 +82,35 @@ class SPINN:
82
82
  The function create_SPINN has the role to populate the `__call__` function
83
83
  """
84
84
 
85
- def __init__(self, key, d, r, eqx_list, m=1):
85
+ def __init__(self, key, d, r, eqx_list, eq_type, m=1):
86
86
  self.d, self.r, self.m = d, r, m
87
87
  _spinn = _SPINN(key, d, r, eqx_list, m)
88
88
  self.params, self.static = eqx.partition(_spinn, eqx.is_inexact_array)
89
+ self.eq_type = eq_type
89
90
 
90
91
  def init_params(self):
91
92
  return self.params
92
93
 
93
- def __call__(self, *args, **kwargs):
94
- return self.apply_fn(self, *args, **kwargs)
94
+ def __call__(self, *args):
95
+ if self.eq_type == "statio_PDE":
96
+ (x, params) = args
97
+ try:
98
+ spinn = eqx.combine(params["nn_params"], self.static)
99
+ except (KeyError, TypeError) as e:
100
+ spinn = eqx.combine(params, self.static)
101
+ v_model = jax.vmap(spinn, (0))
102
+ res = v_model(t=None, x=x)
103
+ return self._eval_nn(res)
104
+ if self.eq_type == "nonstatio_PDE":
105
+ (t, x, params) = args
106
+ try:
107
+ spinn = eqx.combine(params["nn_params"], self.static)
108
+ except (KeyError, TypeError) as e:
109
+ spinn = eqx.combine(params, self.static)
110
+ v_model = jax.vmap(spinn, ((0, 0)))
111
+ res = v_model(t, x)
112
+ return self._eval_nn(res)
113
+ raise RuntimeError("Wrong parameter value for eq_type")
95
114
 
96
115
  def _eval_nn(self, res):
97
116
  """
@@ -209,32 +228,6 @@ def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
209
228
  "Too many dimensions, not enough letters available in jnp.einsum"
210
229
  )
211
230
 
212
- if eq_type == "statio_PDE":
213
-
214
- def apply_fn(self, x, params):
215
- try:
216
- spinn = eqx.combine(params["nn_params"], self.static)
217
- except: # give more flexibility
218
- spinn = eqx.combine(params, self.static)
219
- v_model = jax.vmap(spinn, (0))
220
- res = v_model(t=None, x=x)
221
- return self._eval_nn(res)
222
-
223
- elif eq_type == "nonstatio_PDE":
224
-
225
- def apply_fn(self, t, x, params):
226
- try:
227
- spinn = eqx.combine(params["nn_params"], self.static)
228
- except: # give more flexibility
229
- spinn = eqx.combine(params, self.static)
230
- v_model = jax.vmap(spinn, ((0, 0)))
231
- res = v_model(t, x)
232
- return self._eval_nn(res)
233
-
234
- else:
235
- raise RuntimeError("Wrong parameter value for eq_type")
236
-
237
- spinn = SPINN(key, d, r, eqx_list, m)
238
- spinn.apply_fn = apply_fn
231
+ spinn = SPINN(key, d, r, eqx_list, eq_type, m)
239
232
 
240
233
  return spinn
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jinns
3
- Version: 0.6.0
3
+ Version: 0.6.1
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -2,24 +2,24 @@ jinns/__init__.py,sha256=Nw5pdlmDhJwco3bXX3YttkeCF8czX_6m0poh8vu0lDQ,113
2
2
  jinns/data/_DataGenerators.py,sha256=nIuKtkX4V4ckfT4-g0bjlY7BLkgcok5JbI9OzJn73mA,44461
3
3
  jinns/data/__init__.py,sha256=S13J59Fxuph4uNJ542fP_Mj8U72ilhb5t_UQ-c1k3nY,232
4
4
  jinns/data/_display.py,sha256=NfINLJAGmQSPz30cWVaeQFpabzCXprp4RNH6Iycx-VU,7722
5
- jinns/loss/_DynamicLoss.py,sha256=_2ezagpJhnOKxH8--r-Dld4YN_9GAGtGbBO2ehr8uyM,28250
6
- jinns/loss/_DynamicLossAbstract.py,sha256=Sldyu3C8VKLUKhj7Hb4dU84dq0RFvwa3DuKFRD1mK7A,4819
5
+ jinns/loss/_DynamicLoss.py,sha256=JOn6rpVcM825QHFWBQBOGCrR9tvStDaZiaB6pxZ0fhE,27336
6
+ jinns/loss/_DynamicLossAbstract.py,sha256=NnvS8WKJZZZiGMfidxpGlkmWtINn_UDj3BNXe9l9O84,7787
7
7
  jinns/loss/_LossODE.py,sha256=nYWzUVO6vpZMYXZkeUpDRMZSowyGfGWUdwzBNciQkyo,20639
8
- jinns/loss/_LossPDE.py,sha256=RrZRhMEIHbQn9Fc1SjkJoGiaOHHFaQJdFBbx7V4Ug1E,68822
8
+ jinns/loss/_LossPDE.py,sha256=gPi75x4MTjNSxY0Vz7swnsG1serPOtIPeKYsV5uYkr4,72221
9
9
  jinns/loss/__init__.py,sha256=pFNYUxns-NPXBFdqrEVSiXkQLfCtKw-t2trlhvLzpYE,355
10
- jinns/loss/_boundary_conditions.py,sha256=nKxVSgq9GfmxDQlSilzFFq3uV0t2GjxvUirCtBZvcE0,15235
11
- jinns/loss/_operators.py,sha256=ye4JP8SVJ3Shr7nWCiIxa5KmgPMZcCcHduRgq8cm-J4,11098
10
+ jinns/loss/_boundary_conditions.py,sha256=K4cLYe5ecW2gJznv51nM3zOvE4UO-ZvVmVp3VIiuddM,16296
11
+ jinns/loss/_operators.py,sha256=zDGJqYqeYH7xd-4dtGX9PS-pf0uSOpUUXGo5SVjIJ4o,11069
12
12
  jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
13
  jinns/solver/_rar.py,sha256=ZoklcHq5MFLUsIiWMUH_426MRxq7BUZkD86jcQYau0I,13918
14
14
  jinns/solver/_seq2seq.py,sha256=ihHnb6UpvShgEXOynUhsrPyt0wqNXIjeEL8HxvgRRHE,5985
15
15
  jinns/solver/_solve.py,sha256=-Ji16NJ9CMBcFDKLFpBibvQwow9ysAB4pVGwRRqoXRE,10138
16
16
  jinns/utils/__init__.py,sha256=bksGuq0mNoqciKIhEA8wOHkUppYQYJqOiHK2SEn5jac,227
17
17
  jinns/utils/_optim.py,sha256=Q8a_dMqb6rjdv1qGibgvl7cv8DENfXsthDJ56RziWJc,4441
18
- jinns/utils/_pinn.py,sha256=oxAuqTt3biGx66sOZAg_NrioaxqL9BAWI0Jj6lQfl7o,8142
19
- jinns/utils/_spinn.py,sha256=RB3erlaD2iQ1_SalsakIslBF-3Iizatb9NLG97SZ5es,7943
18
+ jinns/utils/_pinn.py,sha256=wEbCU4aDT8nd1sz72ZrgR8e-xlR-P0iczHk_hnnZQBo,9427
19
+ jinns/utils/_spinn.py,sha256=RZveR5R2HKMyQKuWQE_H6WKAfxABCkB6YGKyDZK051w,7889
20
20
  jinns/utils/_utils.py,sha256=b0-R6jtyMescaDfod0DuzeLMpm1C7gTwpU1MZEW2468,6060
21
- jinns-0.6.0.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
22
- jinns-0.6.0.dist-info/METADATA,sha256=gezTMTKxIyTGKWMKkhylm2mNxNiZY4w_NysNbyiizn4,2388
23
- jinns-0.6.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
24
- jinns-0.6.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
25
- jinns-0.6.0.dist-info/RECORD,,
21
+ jinns-0.6.1.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
22
+ jinns-0.6.1.dist-info/METADATA,sha256=yUHlhcVpsm8b1vHJSy5_TZjEUaaufI3niTvWZNXehks,2388
23
+ jinns-0.6.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
24
+ jinns-0.6.1.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
25
+ jinns-0.6.1.dist-info/RECORD,,
File without changes
File without changes