osiris-utils 1.1.10__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_hdf5_io.py +46 -0
- benchmarks/benchmark_load_all.py +54 -0
- docs/source/api/decks.rst +48 -0
- docs/source/api/postprocess.rst +66 -2
- docs/source/api/sim_diag.rst +1 -1
- docs/source/api/utilities.rst +1 -1
- docs/source/conf.py +2 -1
- docs/source/examples/example_Derivatives.md +78 -0
- docs/source/examples/example_FFT.md +152 -0
- docs/source/examples/example_InputDeck.md +148 -0
- docs/source/examples/example_Simulation_Diagnostic.md +213 -0
- docs/source/examples/quick_start.md +51 -0
- docs/source/examples.rst +14 -0
- docs/source/index.rst +8 -0
- examples/edited-deck.1d +1 -1
- examples/example_Derivatives.ipynb +24 -36
- examples/example_FFT.ipynb +44 -23
- examples/example_InputDeck.ipynb +24 -277
- examples/example_Simulation_Diagnostic.ipynb +27 -17
- examples/quick_start.ipynb +17 -1
- osiris_utils/__init__.py +10 -6
- osiris_utils/cli/__init__.py +6 -0
- osiris_utils/cli/__main__.py +85 -0
- osiris_utils/cli/export.py +199 -0
- osiris_utils/cli/info.py +156 -0
- osiris_utils/cli/plot.py +189 -0
- osiris_utils/cli/validate.py +247 -0
- osiris_utils/data/__init__.py +15 -0
- osiris_utils/data/data.py +41 -171
- osiris_utils/data/diagnostic.py +285 -274
- osiris_utils/data/simulation.py +20 -13
- osiris_utils/decks/__init__.py +4 -0
- osiris_utils/decks/decks.py +83 -8
- osiris_utils/decks/species.py +12 -9
- osiris_utils/postprocessing/__init__.py +28 -0
- osiris_utils/postprocessing/derivative.py +317 -106
- osiris_utils/postprocessing/fft.py +135 -24
- osiris_utils/postprocessing/field_centering.py +28 -14
- osiris_utils/postprocessing/heatflux_correction.py +39 -18
- osiris_utils/postprocessing/mft.py +10 -2
- osiris_utils/postprocessing/postprocess.py +8 -5
- osiris_utils/postprocessing/pressure_correction.py +29 -17
- osiris_utils/utils.py +26 -17
- osiris_utils/vis/__init__.py +3 -0
- osiris_utils/vis/plot3d.py +148 -0
- {osiris_utils-1.1.10.dist-info → osiris_utils-1.2.0.dist-info}/METADATA +55 -7
- {osiris_utils-1.1.10.dist-info → osiris_utils-1.2.0.dist-info}/RECORD +51 -34
- {osiris_utils-1.1.10.dist-info → osiris_utils-1.2.0.dist-info}/WHEEL +1 -1
- osiris_utils-1.2.0.dist-info/entry_points.txt +2 -0
- {osiris_utils-1.1.10.dist-info → osiris_utils-1.2.0.dist-info}/top_level.txt +1 -0
- osiris_utils/postprocessing/mft_for_gridfile.py +0 -55
- {osiris_utils-1.1.10.dist-info → osiris_utils-1.2.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -1,3 +1,8 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Generator
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
1
6
|
import numpy as np
|
|
2
7
|
import tqdm as tqdm
|
|
3
8
|
|
|
@@ -5,6 +10,8 @@ from ..data.diagnostic import Diagnostic
|
|
|
5
10
|
from ..data.simulation import Simulation
|
|
6
11
|
from .postprocess import PostProcess
|
|
7
12
|
|
|
13
|
+
__all__ = ["FFT_Simulation", "FFT_Diagnostic", "FFT_Species_Handler"]
|
|
14
|
+
|
|
8
15
|
|
|
9
16
|
class FFT_Simulation(PostProcess):
|
|
10
17
|
"""
|
|
@@ -21,16 +28,16 @@ class FFT_Simulation(PostProcess):
|
|
|
21
28
|
|
|
22
29
|
"""
|
|
23
30
|
|
|
24
|
-
def __init__(self, simulation, fft_axis):
|
|
31
|
+
def __init__(self, simulation: Simulation, fft_axis: int | list[int]) -> None:
|
|
25
32
|
super().__init__("FFT")
|
|
26
33
|
if not isinstance(simulation, Simulation):
|
|
27
|
-
raise ValueError("
|
|
34
|
+
raise ValueError("simulation must be a Simulation-compatible object.")
|
|
28
35
|
self._simulation = simulation
|
|
29
36
|
self._fft_axis = fft_axis
|
|
30
|
-
self._fft_computed = {}
|
|
31
|
-
self._species_handler = {}
|
|
37
|
+
self._fft_computed: dict[Any, FFT_Diagnostic] = {}
|
|
38
|
+
self._species_handler: dict[Any, FFT_Species_Handler] = {}
|
|
32
39
|
|
|
33
|
-
def __getitem__(self, key):
|
|
40
|
+
def __getitem__(self, key: Any) -> FFT_Species_Handler | FFT_Diagnostic:
|
|
34
41
|
if key in self._simulation._species:
|
|
35
42
|
if key not in self._species_handler:
|
|
36
43
|
self._species_handler[key] = FFT_Species_Handler(self._simulation[key], self._fft_axis)
|
|
@@ -40,16 +47,16 @@ class FFT_Simulation(PostProcess):
|
|
|
40
47
|
self._fft_computed[key] = FFT_Diagnostic(self._simulation[key], self._fft_axis)
|
|
41
48
|
return self._fft_computed[key]
|
|
42
49
|
|
|
43
|
-
def delete_all(self):
|
|
50
|
+
def delete_all(self) -> None:
|
|
44
51
|
self._fft_computed = {}
|
|
45
52
|
|
|
46
|
-
def delete(self, key):
|
|
53
|
+
def delete(self, key: Any) -> None:
|
|
47
54
|
if key in self._fft_computed:
|
|
48
55
|
del self._fft_computed[key]
|
|
49
56
|
else:
|
|
50
57
|
print(f"FFT {key} not found in simulation")
|
|
51
58
|
|
|
52
|
-
def process(self, diagnostic):
|
|
59
|
+
def process(self, diagnostic: Diagnostic) -> FFT_Diagnostic:
|
|
53
60
|
"""Apply FFT to a diagnostic"""
|
|
54
61
|
return FFT_Diagnostic(diagnostic, self._fft_axis)
|
|
55
62
|
|
|
@@ -77,7 +84,7 @@ class FFT_Diagnostic(Diagnostic):
|
|
|
77
84
|
|
|
78
85
|
"""
|
|
79
86
|
|
|
80
|
-
def __init__(self, diagnostic, fft_axis):
|
|
87
|
+
def __init__(self, diagnostic: Diagnostic, fft_axis: int | list[int]) -> None:
|
|
81
88
|
if hasattr(diagnostic, "_species"):
|
|
82
89
|
super().__init__(
|
|
83
90
|
simulation_folder=(diagnostic._simulation_folder if hasattr(diagnostic, "_simulation_folder") else None),
|
|
@@ -113,9 +120,11 @@ class FFT_Diagnostic(Diagnostic):
|
|
|
113
120
|
if isinstance(self._dx, (int, float)):
|
|
114
121
|
self._kmax = np.pi / (self._dx)
|
|
115
122
|
else:
|
|
116
|
-
|
|
123
|
+
# Handle if fft_axis is int
|
|
124
|
+
axes = [self._fft_axis] if isinstance(self._fft_axis, int) else self._fft_axis
|
|
125
|
+
self._kmax = np.pi / np.array([self._dx[ax - 1] for ax in axes if ax != 0])
|
|
117
126
|
|
|
118
|
-
def load_all(self):
|
|
127
|
+
def load_all(self) -> np.ndarray:
|
|
119
128
|
if self._data is not None:
|
|
120
129
|
print("Using cached data.")
|
|
121
130
|
return self._data
|
|
@@ -165,7 +174,7 @@ class FFT_Diagnostic(Diagnostic):
|
|
|
165
174
|
self._data = np.abs(result) ** 2
|
|
166
175
|
return self._data
|
|
167
176
|
|
|
168
|
-
def _data_generator(self, index):
|
|
177
|
+
def _data_generator(self, index: int) -> Generator[np.ndarray, None, None]:
|
|
169
178
|
# Get the data for this index
|
|
170
179
|
original_data = self._diag[index]
|
|
171
180
|
|
|
@@ -195,10 +204,10 @@ class FFT_Diagnostic(Diagnostic):
|
|
|
195
204
|
|
|
196
205
|
yield np.abs(result_fft) ** 2
|
|
197
206
|
|
|
198
|
-
def _get_window(self, length, axis):
|
|
207
|
+
def _get_window(self, length: int, axis: int) -> np.ndarray:
|
|
199
208
|
return np.hanning(length)
|
|
200
209
|
|
|
201
|
-
def _apply_window(self, data, window, axis):
|
|
210
|
+
def _apply_window(self, data: np.ndarray, window: np.ndarray, axis: int) -> np.ndarray:
|
|
202
211
|
ndim = data.ndim
|
|
203
212
|
window_shape = [1] * ndim
|
|
204
213
|
window_shape[axis] = len(window)
|
|
@@ -207,7 +216,7 @@ class FFT_Diagnostic(Diagnostic):
|
|
|
207
216
|
|
|
208
217
|
return data * reshaped_window
|
|
209
218
|
|
|
210
|
-
def __getitem__(self, index):
|
|
219
|
+
def __getitem__(self, index: int | slice) -> np.ndarray:
|
|
211
220
|
if self._all_loaded and self._data is not None:
|
|
212
221
|
return self._data[index]
|
|
213
222
|
|
|
@@ -221,29 +230,131 @@ class FFT_Diagnostic(Diagnostic):
|
|
|
221
230
|
else:
|
|
222
231
|
raise ValueError("Invalid index type. Use int or slice.")
|
|
223
232
|
|
|
224
|
-
def omega(self):
|
|
233
|
+
def omega(self) -> np.ndarray:
|
|
225
234
|
"""
|
|
226
|
-
Get the angular frequency array for the FFT.
|
|
235
|
+
Get the angular frequency array for the FFT along the time dimension (axis 0).
|
|
236
|
+
|
|
237
|
+
Returns
|
|
238
|
+
-------
|
|
239
|
+
np.ndarray
|
|
240
|
+
Angular frequency array for the time axis.
|
|
227
241
|
"""
|
|
228
242
|
if not self._all_loaded:
|
|
229
243
|
raise ValueError("Load the data first using load_all() method.")
|
|
230
244
|
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
245
|
+
# If the FFT was computed along the time axis (0) return temporal frequencies
|
|
246
|
+
if isinstance(self._fft_axis, (list, tuple)):
|
|
247
|
+
if 0 in self._fft_axis:
|
|
248
|
+
dt = self._dt * self._ndump
|
|
249
|
+
omega = np.fft.fftfreq(self._data.shape[0], d=dt) * 2 * np.pi
|
|
250
|
+
return np.fft.fftshift(omega)
|
|
251
|
+
# If FFT was computed along spatial axes only and a single spatial axis
|
|
252
|
+
spatial_axes = [ax for ax in self._fft_axis if ax != 0]
|
|
253
|
+
if len(spatial_axes) == 1:
|
|
254
|
+
return self.k(spatial_axes[0])
|
|
255
|
+
# Multi-dimensional spatial FFT: return concatenated or dict of k arrays
|
|
256
|
+
return self.k()
|
|
257
|
+
else:
|
|
258
|
+
if self._fft_axis == 0:
|
|
259
|
+
dt = self._dt * self._ndump
|
|
260
|
+
omega = np.fft.fftfreq(self._data.shape[0], d=dt) * 2 * np.pi
|
|
261
|
+
return np.fft.fftshift(omega)
|
|
262
|
+
# Single spatial axis: return wavenumber array for that axis
|
|
263
|
+
return self.k(self._fft_axis)
|
|
264
|
+
|
|
265
|
+
def k(self, axis: int | None = None) -> np.ndarray | dict[int, np.ndarray]:
|
|
266
|
+
"""
|
|
267
|
+
Get the wavenumber array for the FFT along spatial dimension(s).
|
|
268
|
+
|
|
269
|
+
Parameters
|
|
270
|
+
----------
|
|
271
|
+
axis : int or None, optional
|
|
272
|
+
The spatial axis to compute wavenumber for (1, 2, or 3).
|
|
273
|
+
If None, returns wavenumbers for all spatial axes in fft_axis.
|
|
274
|
+
|
|
275
|
+
Returns
|
|
276
|
+
-------
|
|
277
|
+
np.ndarray or dict
|
|
278
|
+
If axis is specified: wavenumber array for that axis.
|
|
279
|
+
If axis is None: dictionary mapping axis -> wavenumber array.
|
|
280
|
+
|
|
281
|
+
Notes
|
|
282
|
+
-----
|
|
283
|
+
When load_all() is used, time axis is 0 and spatial axes are 1,2,3.
|
|
284
|
+
When accessing single timesteps, spatial axes are 0,1,2.
|
|
285
|
+
"""
|
|
286
|
+
if self._data is None:
|
|
287
|
+
raise ValueError("Load the data first using load_all() or access via indexing.")
|
|
288
|
+
|
|
289
|
+
# Determine if we have the time dimension in the data
|
|
290
|
+
# If all_loaded is True, then axis 0 is time, spatial axes are 1,2,3
|
|
291
|
+
# If all_loaded is False, we're looking at a single timestep, spatial axes are 0,1,2
|
|
292
|
+
has_time_axis = self._all_loaded
|
|
293
|
+
|
|
294
|
+
# Determine which axes to compute k for
|
|
295
|
+
if axis is not None:
|
|
296
|
+
# Single axis specified
|
|
297
|
+
if axis == 0:
|
|
298
|
+
raise ValueError("axis must be a spatial dimension (1, 2, or 3), not 0 (time).")
|
|
299
|
+
if axis < 1 or axis > 3:
|
|
300
|
+
raise ValueError(f"axis must be 1, 2, or 3, got {axis}")
|
|
301
|
+
|
|
302
|
+
# Get dx for this axis
|
|
303
|
+
if isinstance(self._dx, (int, float)):
|
|
304
|
+
dx = self._dx
|
|
305
|
+
else:
|
|
306
|
+
dx = self._dx[axis - 1]
|
|
307
|
+
|
|
308
|
+
# Compute the actual data axis index
|
|
309
|
+
# If we have time axis, spatial axis N is at index N
|
|
310
|
+
# If no time axis (single timestep), spatial axis N is at index N-1
|
|
311
|
+
data_axis = axis if has_time_axis else axis - 1
|
|
312
|
+
|
|
313
|
+
# Compute wavenumber
|
|
314
|
+
k_array = np.fft.fftfreq(self._data.shape[data_axis], d=dx) * 2 * np.pi
|
|
315
|
+
k_array = np.fft.fftshift(k_array)
|
|
316
|
+
return k_array
|
|
317
|
+
else:
|
|
318
|
+
# axis is None: return k for all spatial axes in fft_axis
|
|
319
|
+
result = {}
|
|
320
|
+
|
|
321
|
+
if isinstance(self._fft_axis, (list, tuple)):
|
|
322
|
+
# Multi-axis FFT: return k for all spatial axes
|
|
323
|
+
spatial_axes = [ax for ax in self._fft_axis if ax != 0]
|
|
324
|
+
elif self._fft_axis == 0:
|
|
325
|
+
# Only time FFT, no spatial axes
|
|
326
|
+
raise ValueError("No spatial FFT axes to compute wavenumber for. fft_axis is 0 (time only).")
|
|
327
|
+
else:
|
|
328
|
+
# Single spatial axis
|
|
329
|
+
spatial_axes = [self._fft_axis]
|
|
330
|
+
|
|
331
|
+
for ax in spatial_axes:
|
|
332
|
+
if isinstance(self._dx, (int, float)):
|
|
333
|
+
dx = self._dx
|
|
334
|
+
else:
|
|
335
|
+
dx = self._dx[ax - 1]
|
|
336
|
+
|
|
337
|
+
# Compute the actual data axis index
|
|
338
|
+
data_axis = ax if has_time_axis else ax - 1
|
|
339
|
+
|
|
340
|
+
k_array = np.fft.fftfreq(self._data.shape[data_axis], d=dx) * 2 * np.pi
|
|
341
|
+
k_array = np.fft.fftshift(k_array)
|
|
342
|
+
result[ax] = k_array
|
|
343
|
+
|
|
344
|
+
return result
|
|
234
345
|
|
|
235
346
|
@property
|
|
236
|
-
def kmax(self):
|
|
347
|
+
def kmax(self) -> float | np.ndarray:
|
|
237
348
|
return self._kmax
|
|
238
349
|
|
|
239
350
|
|
|
240
351
|
class FFT_Species_Handler:
|
|
241
|
-
def __init__(self, species_handler, fft_axis):
|
|
352
|
+
def __init__(self, species_handler: Any, fft_axis: int | list[int]) -> None:
|
|
242
353
|
self._species_handler = species_handler
|
|
243
354
|
self._fft_axis = fft_axis
|
|
244
|
-
self._fft_computed = {}
|
|
355
|
+
self._fft_computed: dict[Any, FFT_Diagnostic] = {}
|
|
245
356
|
|
|
246
|
-
def __getitem__(self, key):
|
|
357
|
+
def __getitem__(self, key: Any) -> FFT_Diagnostic:
|
|
247
358
|
if key not in self._fft_computed:
|
|
248
359
|
diag = self._species_handler[key]
|
|
249
360
|
self._fft_computed[key] = FFT_Diagnostic(diag, self._fft_axis)
|
|
@@ -1,9 +1,15 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Generator
|
|
4
|
+
|
|
1
5
|
import numpy as np
|
|
2
6
|
|
|
3
7
|
from ..data.diagnostic import OSIRIS_FLD, Diagnostic
|
|
4
8
|
from ..data.simulation import Simulation
|
|
5
9
|
from .postprocess import PostProcess
|
|
6
10
|
|
|
11
|
+
__all__ = ["FieldCentering_Simulation", "FieldCentering_Diagnostic"]
|
|
12
|
+
|
|
7
13
|
|
|
8
14
|
class FieldCentering_Simulation(PostProcess):
|
|
9
15
|
"""
|
|
@@ -22,47 +28,47 @@ class FieldCentering_Simulation(PostProcess):
|
|
|
22
28
|
"""
|
|
23
29
|
|
|
24
30
|
def __init__(self, simulation: Simulation):
|
|
25
|
-
super().__init__("FieldCentering Simulation")
|
|
26
31
|
"""
|
|
27
32
|
Class to center the field in the simulation.
|
|
28
33
|
|
|
29
34
|
Parameters
|
|
30
35
|
----------
|
|
31
|
-
|
|
36
|
+
simulation : Simulation
|
|
32
37
|
The simulation object.
|
|
33
|
-
field : str
|
|
34
|
-
The field to center.
|
|
35
38
|
"""
|
|
39
|
+
super().__init__("FieldCentering Simulation")
|
|
40
|
+
|
|
41
|
+
# Accept Simulation-compatible objects (Simulation or other PostProcess subclasses)
|
|
36
42
|
if not isinstance(simulation, Simulation):
|
|
37
|
-
raise ValueError("
|
|
43
|
+
raise ValueError("simulation must be a Simulation-compatible object.")
|
|
38
44
|
self._simulation = simulation
|
|
39
45
|
|
|
40
46
|
self._field_centered = {}
|
|
41
47
|
# no need to create a species handler for field centering since fields are not species related
|
|
42
48
|
|
|
43
|
-
def __getitem__(self, key):
|
|
49
|
+
def __getitem__(self, key: str) -> FieldCentering_Diagnostic:
|
|
44
50
|
if key not in OSIRIS_FLD:
|
|
45
51
|
raise ValueError(f"Does it make sense to center {key} field? Only {OSIRIS_FLD} are supported.")
|
|
46
52
|
if key not in self._field_centered:
|
|
47
53
|
self._field_centered[key] = FieldCentering_Diagnostic(self._simulation[key])
|
|
48
54
|
return self._field_centered[key]
|
|
49
55
|
|
|
50
|
-
def delete_all(self):
|
|
56
|
+
def delete_all(self) -> None:
|
|
51
57
|
self._field_centered = {}
|
|
52
58
|
|
|
53
|
-
def delete(self, key):
|
|
59
|
+
def delete(self, key: str) -> None:
|
|
54
60
|
if key in self._field_centered:
|
|
55
61
|
del self._field_centered[key]
|
|
56
62
|
else:
|
|
57
63
|
print(f"Field {key} not found in simulation")
|
|
58
64
|
|
|
59
|
-
def process(self, diagnostic):
|
|
65
|
+
def process(self, diagnostic: Diagnostic) -> FieldCentering_Diagnostic:
|
|
60
66
|
"""Apply field centering to a diagnostic"""
|
|
61
67
|
return FieldCentering_Diagnostic(diagnostic)
|
|
62
68
|
|
|
63
69
|
|
|
64
70
|
class FieldCentering_Diagnostic(Diagnostic):
|
|
65
|
-
def __init__(self, diagnostic):
|
|
71
|
+
def __init__(self, diagnostic: Diagnostic):
|
|
66
72
|
"""
|
|
67
73
|
Class to center the field in the simulation. It converts fields from the Osiris yee mesh to the center of the cells.
|
|
68
74
|
It only works for periodic boundaries.
|
|
@@ -104,10 +110,18 @@ class FieldCentering_Diagnostic(Diagnostic):
|
|
|
104
110
|
self._original_name = diagnostic._name
|
|
105
111
|
self._name = diagnostic._name + "_centered"
|
|
106
112
|
|
|
107
|
-
self._data = None
|
|
113
|
+
self._data: np.ndarray | None = None
|
|
108
114
|
self._all_loaded = False
|
|
109
115
|
|
|
110
|
-
def load_all(self):
|
|
116
|
+
def load_all(self) -> np.ndarray:
|
|
117
|
+
"""
|
|
118
|
+
Load all data and center the fields.
|
|
119
|
+
|
|
120
|
+
Returns
|
|
121
|
+
-------
|
|
122
|
+
data : np.ndarray
|
|
123
|
+
The centered field data.
|
|
124
|
+
"""
|
|
111
125
|
if self._data is not None:
|
|
112
126
|
return self._data
|
|
113
127
|
|
|
@@ -217,7 +231,7 @@ class FieldCentering_Diagnostic(Diagnostic):
|
|
|
217
231
|
self._all_loaded = True
|
|
218
232
|
return self._data
|
|
219
233
|
|
|
220
|
-
def __getitem__(self, index):
|
|
234
|
+
def __getitem__(self, index: int | slice) -> np.ndarray:
|
|
221
235
|
"""Get data at a specific index"""
|
|
222
236
|
if self._all_loaded and self._data is not None:
|
|
223
237
|
return self._data[index]
|
|
@@ -232,7 +246,7 @@ class FieldCentering_Diagnostic(Diagnostic):
|
|
|
232
246
|
else:
|
|
233
247
|
raise ValueError("Invalid index type. Use int or slice.")
|
|
234
248
|
|
|
235
|
-
def _data_generator(self, index):
|
|
249
|
+
def _data_generator(self, index: int) -> Generator[np.ndarray, None, None]:
|
|
236
250
|
if self._dim == 1:
|
|
237
251
|
if self._original_name.lower() in [
|
|
238
252
|
"b2",
|
|
@@ -1,3 +1,8 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Generator
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
1
6
|
import numpy as np
|
|
2
7
|
|
|
3
8
|
from ..data.diagnostic import Diagnostic
|
|
@@ -7,9 +12,15 @@ from .pressure_correction import PressureCorrection_Simulation
|
|
|
7
12
|
|
|
8
13
|
OSIRIS_H = ["q1", "q2", "q3"]
|
|
9
14
|
|
|
15
|
+
__all__ = [
|
|
16
|
+
"HeatfluxCorrection_Simulation",
|
|
17
|
+
"HeatfluxCorrection_Diagnostic",
|
|
18
|
+
"HeatfluxCorrection_Species_Handler",
|
|
19
|
+
]
|
|
20
|
+
|
|
10
21
|
|
|
11
22
|
class HeatfluxCorrection_Simulation(PostProcess):
|
|
12
|
-
def __init__(self, simulation):
|
|
23
|
+
def __init__(self, simulation: Simulation):
|
|
13
24
|
super().__init__("HeatfluxCorrection Simulation")
|
|
14
25
|
"""
|
|
15
26
|
Class to correct pressure tensor components by subtracting Reynolds stress.
|
|
@@ -22,12 +33,12 @@ class HeatfluxCorrection_Simulation(PostProcess):
|
|
|
22
33
|
The heatflux component to center.
|
|
23
34
|
"""
|
|
24
35
|
if not isinstance(simulation, Simulation):
|
|
25
|
-
raise ValueError("
|
|
36
|
+
raise ValueError("simulation must be a Simulation-compatible object.")
|
|
26
37
|
self._simulation = simulation
|
|
27
|
-
self._heatflux_corrected = {}
|
|
28
|
-
self._species_handler = {}
|
|
38
|
+
self._heatflux_corrected: dict[str, HeatfluxCorrection_Diagnostic] = {}
|
|
39
|
+
self._species_handler: dict[str, HeatfluxCorrection_Species_Handler] = {}
|
|
29
40
|
|
|
30
|
-
def __getitem__(self, key):
|
|
41
|
+
def __getitem__(self, key: str) -> HeatfluxCorrection_Species_Handler | HeatfluxCorrection_Diagnostic:
|
|
31
42
|
if key in self._simulation._species:
|
|
32
43
|
if key not in self._species_handler:
|
|
33
44
|
self._species_handler[key] = HeatfluxCorrection_Species_Handler(self._simulation[key], self._simulation)
|
|
@@ -36,25 +47,35 @@ class HeatfluxCorrection_Simulation(PostProcess):
|
|
|
36
47
|
raise ValueError(f"Invalid heatflux component {key}. Supported: {OSIRIS_H}.")
|
|
37
48
|
if key not in self._heatflux_corrected:
|
|
38
49
|
print("Weird that it got here - heatflux is always species dependent on OSIRIS")
|
|
39
|
-
|
|
50
|
+
# This part seems to lack some arguments for HeatfluxCorrection_Diagnostic,
|
|
51
|
+
# but keeping as is for structural consistency if reached
|
|
52
|
+
# self._heatflux_corrected[key] = HeatfluxCorrection_Diagnostic(self._simulation[key], self._simulation)
|
|
40
53
|
return self._heatflux_corrected[key]
|
|
41
54
|
|
|
42
|
-
def delete_all(self):
|
|
55
|
+
def delete_all(self) -> None:
|
|
43
56
|
self._heatflux_corrected = {}
|
|
44
57
|
|
|
45
|
-
def delete(self, key):
|
|
58
|
+
def delete(self, key: str) -> None:
|
|
46
59
|
if key in self._heatflux_corrected:
|
|
47
60
|
del self._heatflux_corrected[key]
|
|
48
61
|
else:
|
|
49
62
|
print(f"Heatflux {key} not found in simulation")
|
|
50
63
|
|
|
51
|
-
def process(self, diagnostic):
|
|
64
|
+
def process(self, diagnostic: Diagnostic) -> HeatfluxCorrection_Diagnostic:
|
|
52
65
|
"""Apply heatflux correction to a diagnostic"""
|
|
66
|
+
# FIX: This is a bit of a hack, but it works for now
|
|
53
67
|
return HeatfluxCorrection_Diagnostic(diagnostic, self._simulation)
|
|
54
68
|
|
|
55
69
|
|
|
56
70
|
class HeatfluxCorrection_Diagnostic(Diagnostic):
|
|
57
|
-
def __init__(
|
|
71
|
+
def __init__(
|
|
72
|
+
self,
|
|
73
|
+
diagnostic: Diagnostic,
|
|
74
|
+
vfl_i: Diagnostic,
|
|
75
|
+
Pjj_list: list[Diagnostic],
|
|
76
|
+
vfl_j_list: list[Diagnostic],
|
|
77
|
+
Pji_list: list[Diagnostic],
|
|
78
|
+
):
|
|
58
79
|
"""
|
|
59
80
|
Class to correct the pressure in the simulation.
|
|
60
81
|
|
|
@@ -102,10 +123,10 @@ class HeatfluxCorrection_Diagnostic(Diagnostic):
|
|
|
102
123
|
self._original_name = diagnostic._name
|
|
103
124
|
self._name = diagnostic._name + "_corrected"
|
|
104
125
|
|
|
105
|
-
self._data = None
|
|
126
|
+
self._data: np.ndarray | None = None
|
|
106
127
|
self._all_loaded = False
|
|
107
128
|
|
|
108
|
-
def load_all(self):
|
|
129
|
+
def load_all(self) -> np.ndarray:
|
|
109
130
|
if self._data is not None:
|
|
110
131
|
return self._data
|
|
111
132
|
|
|
@@ -131,12 +152,12 @@ class HeatfluxCorrection_Diagnostic(Diagnostic):
|
|
|
131
152
|
# Sum over j: vfl_j * Pji
|
|
132
153
|
vfl_dot_Pji = sum(vfl_j.data * Pji.data for vfl_j, Pji in zip(self._vfl_j_list, self._Pji_list, strict=False))
|
|
133
154
|
|
|
134
|
-
self._data = 2 * q -
|
|
155
|
+
self._data = 2 * q - vfl_i * trace_P - 2 * vfl_dot_Pji
|
|
135
156
|
self._all_loaded = True
|
|
136
157
|
|
|
137
158
|
return self._data
|
|
138
159
|
|
|
139
|
-
def __getitem__(self, index):
|
|
160
|
+
def __getitem__(self, index: int | slice) -> np.ndarray:
|
|
140
161
|
"""Get data at a specific index"""
|
|
141
162
|
if self._all_loaded and self._data is not None:
|
|
142
163
|
return self._data[index]
|
|
@@ -151,7 +172,7 @@ class HeatfluxCorrection_Diagnostic(Diagnostic):
|
|
|
151
172
|
else:
|
|
152
173
|
raise ValueError("Invalid index type. Use int or slice.")
|
|
153
174
|
|
|
154
|
-
def _data_generator(self, index):
|
|
175
|
+
def _data_generator(self, index: int) -> Generator[np.ndarray, None, None]:
|
|
155
176
|
q = self._diag[index]
|
|
156
177
|
vfl_i = self._vfl_i[index]
|
|
157
178
|
trace_P = sum(Pjj[index] for Pjj in self._Pjj_list)
|
|
@@ -174,12 +195,12 @@ class HeatfluxCorrection_Species_Handler:
|
|
|
174
195
|
The simulation object.
|
|
175
196
|
"""
|
|
176
197
|
|
|
177
|
-
def __init__(self, species_handler, simulation):
|
|
198
|
+
def __init__(self, species_handler: Any, simulation: Simulation):
|
|
178
199
|
self._species_handler = species_handler
|
|
179
200
|
self._simulation = simulation
|
|
180
|
-
self._heatflux_corrected = {}
|
|
201
|
+
self._heatflux_corrected: dict[str, HeatfluxCorrection_Diagnostic] = {}
|
|
181
202
|
|
|
182
|
-
def __getitem__(self, key):
|
|
203
|
+
def __getitem__(self, key: str) -> HeatfluxCorrection_Diagnostic:
|
|
183
204
|
if key not in self._heatflux_corrected:
|
|
184
205
|
diag = self._species_handler[key]
|
|
185
206
|
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
4
|
|
|
3
5
|
from ..data.diagnostic import Diagnostic
|
|
4
6
|
from ..data.simulation import Simulation
|
|
5
7
|
from .postprocess import PostProcess
|
|
6
8
|
|
|
9
|
+
__all__ = [
|
|
10
|
+
"MFT_Simulation",
|
|
11
|
+
"MFT_Diagnostic",
|
|
12
|
+
"MFT_Species_Handler",
|
|
13
|
+
]
|
|
14
|
+
|
|
7
15
|
|
|
8
16
|
class MFT_Simulation(PostProcess):
|
|
9
17
|
"""
|
|
@@ -19,10 +27,10 @@ class MFT_Simulation(PostProcess):
|
|
|
19
27
|
|
|
20
28
|
"""
|
|
21
29
|
|
|
22
|
-
def __init__(self, simulation, mft_axis=None):
|
|
30
|
+
def __init__(self, simulation: Simulation, mft_axis: int | None = None):
|
|
23
31
|
super().__init__(f"MeanFieldTheory({mft_axis})")
|
|
24
32
|
if not isinstance(simulation, Simulation):
|
|
25
|
-
raise ValueError("
|
|
33
|
+
raise ValueError("simulation must be a Simulation-compatible object.")
|
|
26
34
|
self._simulation = simulation
|
|
27
35
|
self._mft_axis = mft_axis
|
|
28
36
|
self._mft_computed = {}
|
|
@@ -1,7 +1,11 @@
|
|
|
1
|
-
from
|
|
1
|
+
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from ..data.simulation import Simulation
|
|
3
4
|
|
|
4
|
-
|
|
5
|
+
__all__ = ["PostProcess"]
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class PostProcess(Simulation):
|
|
5
9
|
"""
|
|
6
10
|
Base class for post-processing operations.
|
|
7
11
|
Inherits from Diagnostic to ensure all operation overloads work.
|
|
@@ -14,14 +18,13 @@ class PostProcess(Diagnostic):
|
|
|
14
18
|
The species to analyze.
|
|
15
19
|
"""
|
|
16
20
|
|
|
17
|
-
def __init__(self, name, species=None):
|
|
21
|
+
def __init__(self, name: str, species: str = None):
|
|
18
22
|
# Initialize with the same interface as Diagnostic
|
|
19
|
-
super().__init__(species)
|
|
20
23
|
self._name = name
|
|
21
24
|
self._all_loaded = False
|
|
22
25
|
self._data = None
|
|
23
26
|
|
|
24
|
-
def process(self,
|
|
27
|
+
def process(self, simulation: Simulation) -> Simulation:
|
|
25
28
|
"""
|
|
26
29
|
Apply the post-processing to a diagnostic.
|
|
27
30
|
Must be implemented by subclasses.
|
|
@@ -1,3 +1,8 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Generator
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
1
6
|
import numpy as np
|
|
2
7
|
|
|
3
8
|
from ..data.diagnostic import Diagnostic
|
|
@@ -6,9 +11,15 @@ from .postprocess import PostProcess
|
|
|
6
11
|
|
|
7
12
|
OSIRIS_P = ["P11", "P12", "P13", "P21", "P22", "P23", "P31", "P32", "P33"]
|
|
8
13
|
|
|
14
|
+
__all__ = [
|
|
15
|
+
"PressureCorrection_Simulation",
|
|
16
|
+
"PressureCorrection_Diagnostic",
|
|
17
|
+
"PressureCorrection_Species_Handler",
|
|
18
|
+
]
|
|
19
|
+
|
|
9
20
|
|
|
10
21
|
class PressureCorrection_Simulation(PostProcess):
|
|
11
|
-
def __init__(self, simulation):
|
|
22
|
+
def __init__(self, simulation: Simulation):
|
|
12
23
|
super().__init__("PressureCorrection Simulation")
|
|
13
24
|
"""
|
|
14
25
|
Class to correct pressure tensor components by subtracting Reynolds stress.
|
|
@@ -21,12 +32,12 @@ class PressureCorrection_Simulation(PostProcess):
|
|
|
21
32
|
The pressure component to center.
|
|
22
33
|
"""
|
|
23
34
|
if not isinstance(simulation, Simulation):
|
|
24
|
-
raise ValueError("
|
|
35
|
+
raise ValueError("simulation must be a Simulation-compatible object.")
|
|
25
36
|
self._simulation = simulation
|
|
26
|
-
self._pressure_corrected = {}
|
|
27
|
-
self._species_handler = {}
|
|
37
|
+
self._pressure_corrected: dict[str, PressureCorrection_Diagnostic] = {}
|
|
38
|
+
self._species_handler: dict[str, PressureCorrection_Species_Handler] = {}
|
|
28
39
|
|
|
29
|
-
def __getitem__(self, key):
|
|
40
|
+
def __getitem__(self, key: str) -> PressureCorrection_Species_Handler | PressureCorrection_Diagnostic:
|
|
30
41
|
if key in self._simulation._species:
|
|
31
42
|
if key not in self._species_handler:
|
|
32
43
|
self._species_handler[key] = PressureCorrection_Species_Handler(self._simulation[key])
|
|
@@ -35,25 +46,26 @@ class PressureCorrection_Simulation(PostProcess):
|
|
|
35
46
|
raise ValueError(f"Invalid pressure component {key}. Supported: {OSIRIS_P}.")
|
|
36
47
|
if key not in self._pressure_corrected:
|
|
37
48
|
print("Weird that it got here - pressure is always species dependent on OSIRIS")
|
|
38
|
-
self._pressure_corrected[key] = PressureCorrection_Diagnostic(self._simulation[key], self._simulation)
|
|
49
|
+
# self._pressure_corrected[key] = PressureCorrection_Diagnostic(self._simulation[key], self._simulation)
|
|
39
50
|
return self._pressure_corrected[key]
|
|
40
51
|
|
|
41
|
-
def delete_all(self):
|
|
52
|
+
def delete_all(self) -> None:
|
|
42
53
|
self._pressure_corrected = {}
|
|
43
54
|
|
|
44
|
-
def delete(self, key):
|
|
55
|
+
def delete(self, key: str) -> None:
|
|
45
56
|
if key in self._pressure_corrected:
|
|
46
57
|
del self._pressure_corrected[key]
|
|
47
58
|
else:
|
|
48
59
|
print(f"Pressure {key} not found in simulation")
|
|
49
60
|
|
|
50
|
-
def process(self, diagnostic):
|
|
61
|
+
def process(self, diagnostic: Diagnostic) -> PressureCorrection_Diagnostic:
|
|
51
62
|
"""Apply pressure correction to a diagnostic"""
|
|
63
|
+
# FIX: This is a bit of a hack, but it works for now
|
|
52
64
|
return PressureCorrection_Diagnostic(diagnostic, self._simulation)
|
|
53
65
|
|
|
54
66
|
|
|
55
67
|
class PressureCorrection_Diagnostic(Diagnostic):
|
|
56
|
-
def __init__(self, diagnostic, n, ufl_j, vfl_k):
|
|
68
|
+
def __init__(self, diagnostic: Diagnostic, n: Diagnostic, ufl_j: Diagnostic, vfl_k: Diagnostic):
|
|
57
69
|
"""
|
|
58
70
|
Class to correct the pressure in the simulation.
|
|
59
71
|
|
|
@@ -100,10 +112,10 @@ class PressureCorrection_Diagnostic(Diagnostic):
|
|
|
100
112
|
self._original_name = diagnostic._name
|
|
101
113
|
self._name = diagnostic._name + "_corrected"
|
|
102
114
|
|
|
103
|
-
self._data = None
|
|
115
|
+
self._data: np.ndarray | None = None
|
|
104
116
|
self._all_loaded = False
|
|
105
117
|
|
|
106
|
-
def load_all(self):
|
|
118
|
+
def load_all(self) -> np.ndarray:
|
|
107
119
|
if self._data is not None:
|
|
108
120
|
return self._data
|
|
109
121
|
|
|
@@ -130,7 +142,7 @@ class PressureCorrection_Diagnostic(Diagnostic):
|
|
|
130
142
|
|
|
131
143
|
return self._data
|
|
132
144
|
|
|
133
|
-
def __getitem__(self, index):
|
|
145
|
+
def __getitem__(self, index: int | slice) -> np.ndarray:
|
|
134
146
|
"""Get data at a specific index"""
|
|
135
147
|
if self._all_loaded and self._data is not None:
|
|
136
148
|
return self._data[index]
|
|
@@ -145,7 +157,7 @@ class PressureCorrection_Diagnostic(Diagnostic):
|
|
|
145
157
|
else:
|
|
146
158
|
raise ValueError("Invalid index type. Use int or slice.")
|
|
147
159
|
|
|
148
|
-
def _data_generator(self, index):
|
|
160
|
+
def _data_generator(self, index: int) -> Generator[np.ndarray, None, None]:
|
|
149
161
|
yield (self._diag[index] - self._n[index] * self._vfl_k[index] * self._ufl_j[index])
|
|
150
162
|
|
|
151
163
|
|
|
@@ -166,11 +178,11 @@ class PressureCorrection_Species_Handler:
|
|
|
166
178
|
The axis to compute the derivative. Only used for 'xx', 'xt' and 'tx' types.
|
|
167
179
|
"""
|
|
168
180
|
|
|
169
|
-
def __init__(self, species_handler):
|
|
181
|
+
def __init__(self, species_handler: Any):
|
|
170
182
|
self._species_handler = species_handler
|
|
171
|
-
self._pressure_corrected = {}
|
|
183
|
+
self._pressure_corrected: dict[str, PressureCorrection_Diagnostic] = {}
|
|
172
184
|
|
|
173
|
-
def __getitem__(self, key):
|
|
185
|
+
def __getitem__(self, key: str) -> PressureCorrection_Diagnostic:
|
|
174
186
|
if key not in self._pressure_corrected:
|
|
175
187
|
diag = self._species_handler[key]
|
|
176
188
|
|