jinns 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/data/_display.py CHANGED
@@ -13,6 +13,7 @@ def plot2d(
13
13
  title="",
14
14
  figsize=(7, 7),
15
15
  cmap="inferno",
16
+ spinn=False,
16
17
  ):
17
18
  """Generic function for plotting functions over rectangular 2-D domains
18
19
  :math:`\Omega`. It treats both the stationary case :math:`u(x)` or the
@@ -56,8 +57,24 @@ def plot2d(
56
57
 
57
58
  if times is None:
58
59
  # Statio case : expect a function of one argument fun(x)
59
- v_fun = vmap(fun, 0, 0)
60
- _plot_2D_statio(v_fun, mesh, plot=True, colorbar=True, cmap=cmap)
60
+ if not spinn:
61
+ v_fun = vmap(fun, 0, 0)
62
+ _plot_2D_statio(
63
+ v_fun, mesh, plot=True, colorbar=True, cmap=cmap, figsize=figsize
64
+ )
65
+ elif spinn:
66
+ values_grid = jnp.squeeze(
67
+ fun(jnp.stack([xy_data[0][..., None], xy_data[1][..., None]], axis=1))
68
+ )
69
+ _plot_2D_statio(
70
+ values_grid,
71
+ mesh,
72
+ plot=True,
73
+ colorbar=True,
74
+ cmap=cmap,
75
+ spinn=True,
76
+ figsize=figsize,
77
+ )
61
78
  plt.title(title)
62
79
 
63
80
  else:
@@ -81,17 +98,30 @@ def plot2d(
81
98
  )
82
99
 
83
100
  for idx, (t, ax) in enumerate(zip(times, grid)):
84
- v_fun_at_t = vmap(lambda x: fun(t=jnp.array([t]), x=x), 0, 0)
85
- t_slice, _ = _plot_2D_statio(
86
- v_fun_at_t, mesh, plot=False, colorbar=False, cmap=None
87
- )
101
+ if not spinn:
102
+ v_fun_at_t = vmap(lambda x: fun(t=jnp.array([t]), x=x), 0, 0)
103
+ t_slice, _ = _plot_2D_statio(
104
+ v_fun_at_t, mesh, plot=False, colorbar=False, cmap=None
105
+ )
106
+ elif spinn:
107
+ values_grid = jnp.squeeze(
108
+ fun(
109
+ t * jnp.ones((xy_data[0].shape[0], 1)),
110
+ jnp.stack(
111
+ [xy_data[0][..., None], xy_data[1][..., None]], axis=1
112
+ ),
113
+ )[0]
114
+ )
115
+ t_slice, _ = _plot_2D_statio(
116
+ values_grid, mesh, plot=False, colorbar=True, spinn=True
117
+ )
88
118
  im = ax.pcolormesh(mesh[0], mesh[1], t_slice, cmap=cmap)
89
119
  ax.set_title(f"t = {times[idx] * Tmax}")
90
120
  ax.cax.colorbar(im)
91
121
 
92
122
 
93
123
  def _plot_2D_statio(
94
- v_fun, mesh, plot=True, colorbar=True, cmap="inferno", figsize=(7, 7)
124
+ v_fun, mesh, plot=True, colorbar=True, cmap="inferno", figsize=(7, 7), spinn=False
95
125
  ):
96
126
  """Function that plot the function u(x) with 2-D input x using pcolormesh()
97
127
 
@@ -114,8 +144,12 @@ def _plot_2D_statio(
114
144
  """
115
145
 
116
146
  x_grid, y_grid = mesh
117
- values = v_fun(jnp.vstack([x_grid.flatten(), y_grid.flatten()]).T)
118
- values_grid = values.reshape(x_grid.shape)
147
+ if not spinn:
148
+ values = v_fun(jnp.vstack([x_grid.flatten(), y_grid.flatten()]).T)
149
+ values_grid = values.reshape(x_grid.shape)
150
+ elif spinn:
151
+ # in this case v_fun is directly the values :)
152
+ values_grid = v_fun.T
119
153
 
120
154
  if plot:
121
155
  fig = plt.figure(figsize=figsize)
@@ -128,7 +162,13 @@ def _plot_2D_statio(
128
162
 
129
163
 
130
164
  def plot1d_slice(
131
- fun, xdata, time_slices=jnp.array([0]), Tmax=1, title="", figsize=(10, 10)
165
+ fun,
166
+ xdata,
167
+ time_slices=jnp.array([0]),
168
+ Tmax=1,
169
+ title="",
170
+ figsize=(10, 10),
171
+ spinn=False,
132
172
  ):
133
173
  """Function for plotting time slices of a function :math:`f(t_i, x)` where
134
174
  `t` is time (1-D) and x is 1-D
@@ -151,10 +191,16 @@ def plot1d_slice(
151
191
  """
152
192
  plt.figure(figsize=figsize)
153
193
  for t in time_slices:
154
- # fix t with partial : shape is (1,)
155
- v_u_tfixed = vmap(partial(fun, t=t * jnp.ones((1,))), 0, 0)
156
- # add an axis to xdata for the concatenate function in the neural net
157
- plt.plot(xdata, v_u_tfixed(x=xdata[:, None]), label=f"$t_i={t * Tmax}$")
194
+ if not spinn:
195
+ # fix t with partial : shape is (1,)
196
+ v_u_tfixed = vmap(partial(fun, t=t * jnp.ones((1,))), 0, 0)
197
+ # add an axis to xdata for the concatenate function in the neural net
198
+ values = v_u_tfixed(x=xdata[:, None])
199
+ elif spinn:
200
+ values = jnp.squeeze(
201
+ fun(t * jnp.ones((xdata.shape[0], 1)), xdata[..., None])[0]
202
+ )
203
+ plt.plot(xdata, values, label=f"$t_i={t * Tmax}$")
158
204
  plt.xlabel("x")
159
205
  plt.ylabel(r"$u(t_i, x)$")
160
206
  plt.legend()
@@ -162,7 +208,15 @@ def plot1d_slice(
162
208
 
163
209
 
164
210
  def plot1d_image(
165
- fun, xdata, times, Tmax=1, title="", figsize=(10, 10), colorbar=True, cmap="inferno"
211
+ fun,
212
+ xdata,
213
+ times,
214
+ Tmax=1,
215
+ title="",
216
+ figsize=(10, 10),
217
+ colorbar=True,
218
+ cmap="inferno",
219
+ spinn=False,
166
220
  ):
167
221
  """Function for plotting the 2-D image of a function :math:`f(t, x)` where
168
222
  `t` is time (1-D) and x is space (1-D).
@@ -186,12 +240,15 @@ def plot1d_image(
186
240
  """
187
241
 
188
242
  mesh = jnp.meshgrid(times, xdata) # cartesian product
189
- # the trick is to use _plot2Dstatio
190
- v_fun = vmap(lambda tx: fun(t=tx[0, None], x=tx[1, None]), 0, 0)
191
- t_grid, x_grid = mesh
192
- values_grid = v_fun(jnp.vstack([t_grid.flatten(), x_grid.flatten()]).T).reshape(
193
- t_grid.shape
194
- )
243
+ if not spinn:
244
+ # the trick is to use _plot2Dstatio
245
+ v_fun = vmap(lambda tx: fun(t=tx[0, None], x=tx[1, None]), 0, 0)
246
+ t_grid, x_grid = mesh
247
+ values_grid = v_fun(jnp.vstack([t_grid.flatten(), x_grid.flatten()]).T).reshape(
248
+ t_grid.shape
249
+ )
250
+ elif spinn:
251
+ values_grid = jnp.squeeze(fun((times[..., None]), xdata[..., None]).T)
195
252
  fig, ax = plt.subplots(1, 1, figsize=figsize)
196
253
  im = ax.pcolormesh(mesh[0] * Tmax, mesh[1], values_grid, cmap=cmap)
197
254
  if colorbar: