jinns 0.8.7__py3-none-any.whl → 0.8.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +1 -0
- jinns/data/_display.py +102 -13
- jinns/experimental/__init__.py +2 -0
- jinns/experimental/_sinuspinn.py +135 -0
- jinns/experimental/_spectralpinn.py +87 -0
- {jinns-0.8.7.dist-info → jinns-0.8.8.dist-info}/METADATA +1 -1
- {jinns-0.8.7.dist-info → jinns-0.8.8.dist-info}/RECORD +10 -8
- {jinns-0.8.7.dist-info → jinns-0.8.8.dist-info}/LICENSE +0 -0
- {jinns-0.8.7.dist-info → jinns-0.8.8.dist-info}/WHEEL +0 -0
- {jinns-0.8.7.dist-info → jinns-0.8.8.dist-info}/top_level.txt +0 -0
jinns/__init__.py
CHANGED
jinns/data/_display.py
CHANGED
|
@@ -1,8 +1,13 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions for plotting
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from functools import partial
|
|
6
|
+
import warnings
|
|
1
7
|
import matplotlib.pyplot as plt
|
|
2
8
|
import jax.numpy as jnp
|
|
3
9
|
from jax import vmap
|
|
4
10
|
from mpl_toolkits.axes_grid1 import ImageGrid
|
|
5
|
-
from functools import partial
|
|
6
11
|
|
|
7
12
|
|
|
8
13
|
def plot2d(
|
|
@@ -14,8 +19,10 @@ def plot2d(
|
|
|
14
19
|
figsize=(7, 7),
|
|
15
20
|
cmap="inferno",
|
|
16
21
|
spinn=False,
|
|
22
|
+
vmin_vmax=None,
|
|
23
|
+
ax_for_plot=None,
|
|
17
24
|
):
|
|
18
|
-
"""Generic function for plotting functions over rectangular 2-D domains
|
|
25
|
+
r"""Generic function for plotting functions over rectangular 2-D domains
|
|
19
26
|
:math:`\Omega`. It treats both the stationary case :math:`u(x)` or the
|
|
20
27
|
non-stationnary case :math:`u(t, x)`.
|
|
21
28
|
|
|
@@ -40,6 +47,12 @@ def plot2d(
|
|
|
40
47
|
_description_, by default (7, 7)
|
|
41
48
|
cmap : str, optional
|
|
42
49
|
_description_, by default "inferno"
|
|
50
|
+
vmin_vmax : tuple, optional
|
|
51
|
+
The colorbar minimum and maximum value. Defaults None.
|
|
52
|
+
ax_for_plot : Matplotlib axis, optional
|
|
53
|
+
If None, jinns triggers the plotting. Otherwise this argument
|
|
54
|
+
corresponds to the axis which will host the plot. Default is None.
|
|
55
|
+
NOTE: that this argument will have an effect only if times is None.
|
|
43
56
|
|
|
44
57
|
Raises
|
|
45
58
|
------
|
|
@@ -59,25 +72,49 @@ def plot2d(
|
|
|
59
72
|
# Statio case : expect a function of one argument fun(x)
|
|
60
73
|
if not spinn:
|
|
61
74
|
v_fun = vmap(fun, 0, 0)
|
|
62
|
-
_plot_2D_statio(
|
|
63
|
-
v_fun,
|
|
75
|
+
ret = _plot_2D_statio(
|
|
76
|
+
v_fun,
|
|
77
|
+
mesh,
|
|
78
|
+
plot=not ax_for_plot,
|
|
79
|
+
colorbar=True,
|
|
80
|
+
cmap=cmap,
|
|
81
|
+
figsize=figsize,
|
|
82
|
+
vmin_vmax=vmin_vmax,
|
|
64
83
|
)
|
|
65
84
|
elif spinn:
|
|
66
85
|
values_grid = jnp.squeeze(
|
|
67
86
|
fun(jnp.stack([xy_data[0][..., None], xy_data[1][..., None]], axis=1))
|
|
68
87
|
)
|
|
69
|
-
_plot_2D_statio(
|
|
88
|
+
ret = _plot_2D_statio(
|
|
70
89
|
values_grid,
|
|
71
90
|
mesh,
|
|
72
|
-
plot=
|
|
91
|
+
plot=not ax_for_plot,
|
|
73
92
|
colorbar=True,
|
|
74
93
|
cmap=cmap,
|
|
75
94
|
spinn=True,
|
|
76
95
|
figsize=figsize,
|
|
96
|
+
vmin_vmax=vmin_vmax,
|
|
77
97
|
)
|
|
78
|
-
|
|
98
|
+
if not ax_for_plot:
|
|
99
|
+
plt.title(title)
|
|
100
|
+
else:
|
|
101
|
+
if vmin_vmax is not None:
|
|
102
|
+
im = ax_for_plot.pcolormesh(
|
|
103
|
+
mesh[0],
|
|
104
|
+
mesh[1],
|
|
105
|
+
ret[0],
|
|
106
|
+
cmap=cmap,
|
|
107
|
+
vmin=vmin_vmax[0],
|
|
108
|
+
vmax=vmin_vmax[1],
|
|
109
|
+
)
|
|
110
|
+
else:
|
|
111
|
+
im = ax_for_plot.pcolormesh(mesh[0], mesh[1], ret[0], cmap=cmap)
|
|
112
|
+
ax_for_plot.set_title(title)
|
|
113
|
+
ax_for_plot.cax.colorbar(im, format="%0.2f")
|
|
79
114
|
|
|
80
115
|
else:
|
|
116
|
+
if ax_for_plot is not None:
|
|
117
|
+
warnings.warn("ax_for_plot is ignored. jinns will plot the figure")
|
|
81
118
|
if not isinstance(times, list):
|
|
82
119
|
try:
|
|
83
120
|
times = times.tolist()
|
|
@@ -101,7 +138,12 @@ def plot2d(
|
|
|
101
138
|
if not spinn:
|
|
102
139
|
v_fun_at_t = vmap(lambda x: fun(t=jnp.array([t]), x=x), 0, 0)
|
|
103
140
|
t_slice, _ = _plot_2D_statio(
|
|
104
|
-
v_fun_at_t,
|
|
141
|
+
v_fun_at_t,
|
|
142
|
+
mesh,
|
|
143
|
+
plot=False,
|
|
144
|
+
colorbar=False,
|
|
145
|
+
cmap=None,
|
|
146
|
+
vmin_vmax=vmin_vmax,
|
|
105
147
|
)
|
|
106
148
|
elif spinn:
|
|
107
149
|
values_grid = jnp.squeeze(
|
|
@@ -113,15 +155,37 @@ def plot2d(
|
|
|
113
155
|
)[0]
|
|
114
156
|
)
|
|
115
157
|
t_slice, _ = _plot_2D_statio(
|
|
116
|
-
values_grid,
|
|
158
|
+
values_grid,
|
|
159
|
+
mesh,
|
|
160
|
+
plot=False,
|
|
161
|
+
colorbar=True,
|
|
162
|
+
spinn=True,
|
|
163
|
+
vmin_vmax=vmin_vmax,
|
|
164
|
+
)
|
|
165
|
+
if vmin_vmax is not None:
|
|
166
|
+
im = ax.pcolormesh(
|
|
167
|
+
mesh[0],
|
|
168
|
+
mesh[1],
|
|
169
|
+
t_slice,
|
|
170
|
+
cmap=cmap,
|
|
171
|
+
vmin=vmin_vmax[0],
|
|
172
|
+
vmax=vmin_vmax[1],
|
|
117
173
|
)
|
|
118
|
-
|
|
174
|
+
else:
|
|
175
|
+
im = ax.pcolormesh(mesh[0], mesh[1], t_slice, cmap=cmap)
|
|
119
176
|
ax.set_title(f"t = {times[idx] * Tmax:.2f}")
|
|
120
177
|
ax.cax.colorbar(im, format="%0.2f")
|
|
121
178
|
|
|
122
179
|
|
|
123
180
|
def _plot_2D_statio(
|
|
124
|
-
v_fun,
|
|
181
|
+
v_fun,
|
|
182
|
+
mesh,
|
|
183
|
+
plot=True,
|
|
184
|
+
colorbar=True,
|
|
185
|
+
cmap="inferno",
|
|
186
|
+
figsize=(7, 7),
|
|
187
|
+
spinn=False,
|
|
188
|
+
vmin_vmax=None,
|
|
125
189
|
):
|
|
126
190
|
"""Function that plot the function u(x) with 2-D input x using pcolormesh()
|
|
127
191
|
|
|
@@ -136,6 +200,8 @@ def _plot_2D_statio(
|
|
|
136
200
|
either show or return the plot, by default True
|
|
137
201
|
colorbar : bool, optional
|
|
138
202
|
add a colorbar, by default True
|
|
203
|
+
vmin_vmax: tuple, optional
|
|
204
|
+
The colorbar minimum and maximum value. Defaults None.
|
|
139
205
|
|
|
140
206
|
Returns
|
|
141
207
|
-------
|
|
@@ -153,7 +219,17 @@ def _plot_2D_statio(
|
|
|
153
219
|
|
|
154
220
|
if plot:
|
|
155
221
|
fig = plt.figure(figsize=figsize)
|
|
156
|
-
|
|
222
|
+
if vmin_vmax is not None:
|
|
223
|
+
im = plt.pcolormesh(
|
|
224
|
+
x_grid,
|
|
225
|
+
y_grid,
|
|
226
|
+
values_grid,
|
|
227
|
+
cmap=cmap,
|
|
228
|
+
vmin=vmin_vmax[0],
|
|
229
|
+
vmax=vmin_vmax[1],
|
|
230
|
+
)
|
|
231
|
+
else:
|
|
232
|
+
im = plt.pcolormesh(x_grid, y_grid, values_grid, cmap=cmap)
|
|
157
233
|
if colorbar:
|
|
158
234
|
fig.colorbar(im, format="%0.2f")
|
|
159
235
|
# don't plt.show() because it is done in plot2d()
|
|
@@ -217,6 +293,7 @@ def plot1d_image(
|
|
|
217
293
|
colorbar=True,
|
|
218
294
|
cmap="inferno",
|
|
219
295
|
spinn=False,
|
|
296
|
+
vmin_vmax=None,
|
|
220
297
|
):
|
|
221
298
|
"""Function for plotting the 2-D image of a function :math:`f(t, x)` where
|
|
222
299
|
`t` is time (1-D) and x is space (1-D).
|
|
@@ -237,6 +314,8 @@ def plot1d_image(
|
|
|
237
314
|
, by default ""
|
|
238
315
|
figsize : tuple, optional
|
|
239
316
|
, by default (10, 10)
|
|
317
|
+
vmin_vmax: tuple
|
|
318
|
+
The colorbar minimum and maximum value. Defaults None.
|
|
240
319
|
"""
|
|
241
320
|
|
|
242
321
|
mesh = jnp.meshgrid(times, xdata) # cartesian product
|
|
@@ -250,7 +329,17 @@ def plot1d_image(
|
|
|
250
329
|
elif spinn:
|
|
251
330
|
values_grid = jnp.squeeze(fun((times[..., None]), xdata[..., None]).T)
|
|
252
331
|
fig, ax = plt.subplots(1, 1, figsize=figsize)
|
|
253
|
-
|
|
332
|
+
if vmin_vmax is not None:
|
|
333
|
+
im = ax.pcolormesh(
|
|
334
|
+
mesh[0] * Tmax,
|
|
335
|
+
mesh[1],
|
|
336
|
+
values_grid,
|
|
337
|
+
cmap=cmap,
|
|
338
|
+
vmin=vmin_vmax[0],
|
|
339
|
+
vmax=vmin_vmax[1],
|
|
340
|
+
)
|
|
341
|
+
else:
|
|
342
|
+
im = ax.pcolormesh(mesh[0] * Tmax, mesh[1], values_grid, cmap=cmap)
|
|
254
343
|
if colorbar:
|
|
255
344
|
fig.colorbar(im, format="%0.2f")
|
|
256
345
|
ax.set_title(title)
|
jinns/experimental/__init__.py
CHANGED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import equinox as eqx
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from jinns.utils._pinn import PINN
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def almost_zero_init(
    weight: jax.Array, key: jax.Array, stddev: float = 1e-2
) -> jax.Array:
    """Draw a near-zero replacement for a (floating-point) weight array.

    The values of ``weight`` are ignored: only its shape and dtype are used to
    sample i.i.d. Gaussian values scaled by ``stddev``.

    Parameters
    ----------
    weight
        The array to be replaced (typically an ``eqx.nn.Linear`` weight). Must
        have a floating-point dtype.
    key
        A JAX PRNG key.
    stddev
        Standard deviation of the Gaussian draw, by default 1e-2.

    Returns
    -------
    jax.Array
        ``stddev * N(0, 1)`` samples with the shape and dtype of ``weight``.
    """
    return stddev * jax.random.normal(key, shape=weight.shape, dtype=weight.dtype)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _SinusPINN(eqx.Module):
    """
    A specific PINN whose activation layers are x_sin2x functions whose
    frequencies are determined by another (auxiliary) network.

    The auxiliary network maps the input ``x`` to a vector of frequencies.
    The main network then alternates between plain layers (even positions in
    ``layers_pinn``) and frequency-dependent activations (odd positions). The
    activation at position ``idx`` is called as
    ``layer(h, freq_list[(idx - 1) // 2])``.
    """

    layers_pinn: list
    layers_aux_nn: list

    def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
        """
        Parameters
        ----------
        key
            A jax random key
        list_layers_pinn
            A list as eqx_list in jinns' PINN utility for the main PINN. Each
            entry is either ``(callable,)``, used as is, or
            ``(constructor, *args)``, instantiated with a fresh subkey.
        list_layers_aux_nn
            A list as eqx_list in jinns' PINN utility for the network which
            outputs the PINN's activation frequencies. Every instantiated
            layer gets near-zero weights. Instantiated layers among the last
            two entries get a bias of 0.5; the other ones get a zero bias.
        """
        layers_pinn = []
        for entry in list_layers_pinn:
            if len(entry) == 1:
                layers_pinn.append(entry[0])
            else:
                key, subkey = jax.random.split(key, 2)
                layers_pinn.append(entry[0](*entry[1:], key=subkey))
        self.layers_pinn = layers_pinn

        n_aux = len(list_layers_aux_nn)
        layers_aux_nn = []
        for idx, entry in enumerate(list_layers_aux_nn):
            if len(entry) == 1:
                layers_aux_nn.append(entry[0])
                continue
            key, subkey = jax.random.split(key, 2)
            linear_layer = entry[0](*entry[1:], key=subkey)
            key, subkey = jax.random.split(key, 2)
            linear_layer = eqx.tree_at(
                lambda layer: layer.weight,
                linear_layer,
                almost_zero_init(linear_layer.weight, subkey),
            )
            # NOTE(review): both of the last two positions are treated as "the
            # last layer", presumably because the final entry may be an
            # activation following the last linear layer.
            if idx >= n_aux - 2:
                # last layer: almost 0 weights and 0.5 bias
                new_bias = 0.5 * jnp.ones(linear_layer.bias.shape)
            else:
                # previous layers: almost 0 weights and 0 bias
                new_bias = jnp.zeros(linear_layer.bias.shape)
            linear_layer = eqx.tree_at(
                lambda layer: layer.bias, linear_layer, new_bias
            )
            layers_aux_nn.append(linear_layer)
        self.layers_aux_nn = layers_aux_nn

    def __call__(self, x):
        """Evaluate the PINN at ``x`` with input-dependent frequencies."""
        # forward pass in the auxiliary network which determines the freq
        h = x
        for layer in self.layers_aux_nn:
            h = layer(h)
        # Squared outputs clipped to [1e-4, 5]. The bounds are passed
        # positionally: the `a_min`/`a_max` keywords are deprecated in recent
        # JAX releases.
        freq_list = jnp.clip(jnp.square(h), 1e-4, 5)

        # forward pass through the actual PINN
        h = x
        for idx, layer in enumerate(self.layers_pinn):
            if idx % 2 == 0:
                h = layer(h)
            else:
                # Currently: every two layers we have an activation
                # requiring a frequency
                h = layer(h, freq_list[(idx - 1) // 2])
        return h
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class sinusPINN(PINN):
    """
    Wrapper exposing a ``_SinusPINN`` through the jinns ``PINN`` interface.

    It MUST inherit from PINN to pass all the checks. HOWEVER we do not bother
    with reimplementing anything. The parent constructor is called with
    placeholder arguments. ``params`` and ``static`` are then overwritten with
    the partition of a freshly built ``_SinusPINN``.
    """

    def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
        """
        Parameters
        ----------
        key
            A jax random key
        list_layers_pinn
            Layer description of the main PINN (see ``_SinusPINN``)
        list_layers_aux_nn
            Layer description of the frequency network (see ``_SinusPINN``)
        """
        super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
        key, subkey = jax.random.split(key, 2)
        _pinn = _SinusPINN(subkey, list_layers_pinn, list_layers_aux_nn)
        # floating-point array leaves (returned by init_params) vs. the rest
        self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)

    def init_params(self):
        """Return the parameter pytree built at construction."""
        return self.params

    def __call__(self, x, params):
        """
        Evaluate the network at ``x``.

        ``params`` is either subscriptable with an ``"nn_params"`` entry
        holding the network parameters, or directly the parameter pytree.
        A 0-d output is returned as a 1-element array.
        """
        # Only the lookup is guarded, so errors raised by eqx.combine itself
        # are not masked by the fallback.
        try:
            nn_params = params["nn_params"]
        except (KeyError, TypeError):  # give more flexibility
            nn_params = params
        model = eqx.combine(nn_params, self.static)
        res = model(x)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def create_sinusPINN(key, list_layers_pinn, list_layers_aux_nn):
    """Build a ``sinusPINN``.

    Parameters
    ----------
    key
        A jax random key
    list_layers_pinn
        Layer description of the main PINN
    list_layers_aux_nn
        Layer description of the network producing the activation frequencies

    Returns
    -------
    sinusPINN
        The constructed network; its parameters are available through
        ``init_params()``.
    """
    return sinusPINN(key, list_layers_pinn, list_layers_aux_nn)
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import equinox as eqx
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from jinns.utils._pinn import PINN
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def almost_zero_init(
    weight: jax.Array, key: jax.Array, stddev: float = 1e-2
) -> jax.Array:
    """Draw a near-zero replacement for a (floating-point) weight array.

    The values of ``weight`` are ignored: only its shape and dtype are used to
    sample i.i.d. Gaussian values scaled by ``stddev``.

    Parameters
    ----------
    weight
        The array to be replaced (typically an ``eqx.nn.Linear`` weight). Must
        have a floating-point dtype.
    key
        A JAX PRNG key.
    stddev
        Standard deviation of the Gaussian draw, by default 1e-2.

    Returns
    -------
    jax.Array
        ``stddev * N(0, 1)`` samples with the shape and dtype of ``weight``.
    """
    return stddev * jax.random.normal(key, shape=weight.shape, dtype=weight.dtype)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _SpectralPINN(eqx.Module):
    """
    A specific PINN whose architecture is similar to the spectral method for
    the simulation of a spatial field (Chilès and Delfiner, 2012): a single
    layer with cos() activation function and a sum for the last layer.

    The output is ``sqrt(2 / nbands) * sum(h)``, where ``h`` is the result of
    applying ``layers_pinn`` in order to the input.
    """

    layers_pinn: list
    nbands: int

    def __init__(self, key, list_layers_pinn, nbands):
        """
        Parameters
        ----------
        key
            A jax random key
        list_layers_pinn
            A list as eqx_list in jinns' PINN utility for the main PINN. Each
            entry is either ``(callable,)``, used as is, or
            ``(constructor, *args)``, instantiated with a fresh subkey.
        nbands
            Number of spectral bands (i.e., neurons in the single layer of
            the PINN)

        Raises
        ------
        ValueError
            If ``nbands`` is not strictly positive. Otherwise the
            ``sqrt(2 / nbands)`` normalisation fails or yields NaN.
        """
        if nbands <= 0:
            raise ValueError(f"nbands must be strictly positive, got {nbands}")
        self.nbands = nbands
        layers_pinn = []
        for entry in list_layers_pinn:
            if len(entry) == 1:
                layers_pinn.append(entry[0])
            else:
                key, subkey = jax.random.split(key, 2)
                layers_pinn.append(entry[0](*entry[1:], key=subkey))
        self.layers_pinn = layers_pinn

    def __call__(self, x):
        """Evaluate the spectral PINN at ``x`` and return a scalar."""
        # forward pass through the actual PINN
        for layer in self.layers_pinn:
            x = layer(x)
        # normalised sum over the bands
        return jnp.sqrt(2 / self.nbands) * jnp.sum(x)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class spectralPINN(PINN):
    """
    Wrapper exposing a ``_SpectralPINN`` through the jinns ``PINN`` interface.

    It MUST inherit from PINN to pass all the checks. HOWEVER we do not bother
    with reimplementing anything. The parent constructor is called with
    placeholder arguments. ``params`` and ``static`` are then overwritten with
    the partition of a freshly built ``_SpectralPINN``.
    """

    def __init__(self, key, list_layers_pinn, nbands):
        """
        Parameters
        ----------
        key
            A jax random key
        list_layers_pinn
            Layer description of the PINN (see ``_SpectralPINN``)
        nbands
            Number of spectral bands (see ``_SpectralPINN``)
        """
        super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
        key, subkey = jax.random.split(key, 2)
        _pinn = _SpectralPINN(subkey, list_layers_pinn, nbands)
        # floating-point array leaves (returned by init_params) vs. the rest
        self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)

    def init_params(self):
        """Return the parameter pytree built at construction."""
        return self.params

    def __call__(self, x, params):
        """
        Evaluate the network at ``x``.

        ``params`` is either subscriptable with an ``"nn_params"`` entry
        holding the network parameters, or directly the parameter pytree.
        A 0-d output is returned as a 1-element array.
        """
        # Only the lookup is guarded, so errors raised by eqx.combine itself
        # are not masked by the fallback.
        try:
            nn_params = params["nn_params"]
        except (KeyError, TypeError):  # give more flexibility
            nn_params = params
        model = eqx.combine(nn_params, self.static)
        res = model(x)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def create_spectralPINN(key, list_layers_pinn, nbands):
    """Build a ``spectralPINN``.

    Parameters
    ----------
    key
        A jax random key
    list_layers_pinn
        Layer description of the PINN
    nbands
        Number of spectral bands

    Returns
    -------
    spectralPINN
        The constructed network; its parameters are available through
        ``init_params()``.
    """
    return spectralPINN(key, list_layers_pinn, nbands)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 0.8.
|
|
3
|
+
Version: 0.8.8
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -1,9 +1,11 @@
|
|
|
1
|
-
jinns/__init__.py,sha256=
|
|
1
|
+
jinns/__init__.py,sha256=T2XlmLbYqcXTumPJL00cJ80W98We5LH8Yg_Lss_exl4,139
|
|
2
2
|
jinns/data/_DataGenerators.py,sha256=N4-U4z3MG46UIzHCbKScv9Z7AN40w1wlLY_VsVNj2sI,62293
|
|
3
3
|
jinns/data/__init__.py,sha256=yBOmoavSD-cABp4XcjQY1zsEVO0mDyIhi2MJ5WNp0l8,326
|
|
4
|
-
jinns/data/_display.py,sha256=
|
|
5
|
-
jinns/experimental/__init__.py,sha256=
|
|
4
|
+
jinns/data/_display.py,sha256=vlqggDCgVMEwdGBtjVmZaTQORU6imSfDkssn2XCtITI,10392
|
|
5
|
+
jinns/experimental/__init__.py,sha256=qWbhC7Z8UgLWy0t-zU7RYze6v13-FngiCYXu-2bRVFQ,296
|
|
6
6
|
jinns/experimental/_diffrax_solver.py,sha256=sLT22byqh-6015_fhe1xtMWlFOYcCjzYKET4sLhA9R4,6818
|
|
7
|
+
jinns/experimental/_sinuspinn.py,sha256=hxSzscwMV2LayWOqenIlT1zqEVVrE5Y8CKf7bHX5XFQ,5016
|
|
8
|
+
jinns/experimental/_spectralpinn.py,sha256=-4795pa7AYtRNSE-ugan3gHh64mtu2VdrRG5AS_J9Eg,2654
|
|
7
9
|
jinns/loss/_DynamicLoss.py,sha256=L4CVmmF0rTPbHntgqsLLHlnrlQgLHsetUocpJm7ZYag,27461
|
|
8
10
|
jinns/loss/_DynamicLossAbstract.py,sha256=kTQlhLx7SBuH5dIDmYaE79sVHUZt1nUFa8LxPU5IHhM,8504
|
|
9
11
|
jinns/loss/_LossODE.py,sha256=b9doBHoQwYvlgpqzrNO4dOaTN87LRvjHtHbz9bMoH7E,22119
|
|
@@ -27,8 +29,8 @@ jinns/utils/_utils.py,sha256=8dgvWXX9NT7_7-zltWp0C9tG45ZFNwXxueyxPBb4hjo,6740
|
|
|
27
29
|
jinns/utils/_utils_uspinn.py,sha256=qcKcOw3zrwWSQyGVj6fD8c9GinHt_U6JWN_k0auTtXM,26039
|
|
28
30
|
jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
|
|
29
31
|
jinns/validation/_validation.py,sha256=KfetbzB0xTNdBcYLwFWjEtP63Tf9wJirlhgqLTJDyy4,6761
|
|
30
|
-
jinns-0.8.
|
|
31
|
-
jinns-0.8.
|
|
32
|
-
jinns-0.8.
|
|
33
|
-
jinns-0.8.
|
|
34
|
-
jinns-0.8.
|
|
32
|
+
jinns-0.8.8.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
33
|
+
jinns-0.8.8.dist-info/METADATA,sha256=oTs2EJMu4Bwo2n9DLsAPSU5edpbgPtwhNXBuW8YjpOc,2482
|
|
34
|
+
jinns-0.8.8.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
35
|
+
jinns-0.8.8.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
36
|
+
jinns-0.8.8.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|