osiris-utils 1.1.10a0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. benchmarks/benchmark_hdf5_io.py +46 -0
  2. benchmarks/benchmark_load_all.py +54 -0
  3. docs/source/_static/Imagem1.png +0 -0
  4. docs/source/_static/custom.css +24 -8
  5. docs/source/api/decks.rst +48 -0
  6. docs/source/api/postprocess.rst +72 -8
  7. docs/source/api/sim_diag.rst +9 -7
  8. docs/source/api/utilities.rst +6 -6
  9. docs/source/conf.py +28 -8
  10. docs/source/examples/example_Derivatives.md +78 -0
  11. docs/source/examples/example_FFT.md +152 -0
  12. docs/source/examples/example_InputDeck.md +149 -0
  13. docs/source/examples/example_Simulation_Diagnostic.md +213 -0
  14. docs/source/examples/quick_start.md +51 -0
  15. docs/source/examples.rst +14 -0
  16. docs/source/index.rst +8 -0
  17. examples/edited-deck.1d +1 -1
  18. examples/example_Derivatives.ipynb +24 -36
  19. examples/example_FFT.ipynb +44 -23
  20. examples/example_InputDeck.ipynb +24 -277
  21. examples/example_Simulation_Diagnostic.ipynb +27 -17
  22. examples/quick_start.ipynb +17 -1
  23. osiris_utils/__init__.py +10 -6
  24. osiris_utils/cli/__init__.py +6 -0
  25. osiris_utils/cli/__main__.py +85 -0
  26. osiris_utils/cli/export.py +199 -0
  27. osiris_utils/cli/info.py +156 -0
  28. osiris_utils/cli/plot.py +189 -0
  29. osiris_utils/cli/validate.py +247 -0
  30. osiris_utils/data/__init__.py +15 -0
  31. osiris_utils/data/data.py +41 -171
  32. osiris_utils/data/diagnostic.py +285 -274
  33. osiris_utils/data/simulation.py +20 -13
  34. osiris_utils/decks/__init__.py +4 -0
  35. osiris_utils/decks/decks.py +83 -8
  36. osiris_utils/decks/species.py +12 -9
  37. osiris_utils/postprocessing/__init__.py +28 -0
  38. osiris_utils/postprocessing/derivative.py +317 -106
  39. osiris_utils/postprocessing/fft.py +135 -24
  40. osiris_utils/postprocessing/field_centering.py +28 -14
  41. osiris_utils/postprocessing/heatflux_correction.py +39 -18
  42. osiris_utils/postprocessing/mft.py +10 -2
  43. osiris_utils/postprocessing/postprocess.py +8 -5
  44. osiris_utils/postprocessing/pressure_correction.py +29 -17
  45. osiris_utils/utils.py +26 -17
  46. osiris_utils/vis/__init__.py +3 -0
  47. osiris_utils/vis/plot3d.py +148 -0
  48. {osiris_utils-1.1.10a0.dist-info → osiris_utils-1.2.1.dist-info}/METADATA +61 -7
  49. {osiris_utils-1.1.10a0.dist-info → osiris_utils-1.2.1.dist-info}/RECORD +53 -35
  50. {osiris_utils-1.1.10a0.dist-info → osiris_utils-1.2.1.dist-info}/WHEEL +1 -1
  51. osiris_utils-1.2.1.dist-info/entry_points.txt +2 -0
  52. {osiris_utils-1.1.10a0.dist-info → osiris_utils-1.2.1.dist-info}/top_level.txt +1 -0
  53. osiris_utils/postprocessing/mft_for_gridfile.py +0 -55
  54. {osiris_utils-1.1.10a0.dist-info → osiris_utils-1.2.1.dist-info}/licenses/LICENSE.txt +0 -0
@@ -1,3 +1,8 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Generator
4
+ from typing import Any
5
+
1
6
  import numpy as np
2
7
  import tqdm as tqdm
3
8
 
@@ -5,6 +10,8 @@ from ..data.diagnostic import Diagnostic
5
10
  from ..data.simulation import Simulation
6
11
  from .postprocess import PostProcess
7
12
 
13
+ __all__ = ["FFT_Simulation", "FFT_Diagnostic", "FFT_Species_Handler"]
14
+
8
15
 
9
16
  class FFT_Simulation(PostProcess):
10
17
  """
@@ -21,16 +28,16 @@ class FFT_Simulation(PostProcess):
21
28
 
22
29
  """
23
30
 
24
- def __init__(self, simulation, fft_axis):
31
+ def __init__(self, simulation: Simulation, fft_axis: int | list[int]) -> None:
25
32
  super().__init__("FFT")
26
33
  if not isinstance(simulation, Simulation):
27
- raise ValueError("Simulation must be a Simulation object.")
34
+ raise ValueError("simulation must be a Simulation-compatible object.")
28
35
  self._simulation = simulation
29
36
  self._fft_axis = fft_axis
30
- self._fft_computed = {}
31
- self._species_handler = {}
37
+ self._fft_computed: dict[Any, FFT_Diagnostic] = {}
38
+ self._species_handler: dict[Any, FFT_Species_Handler] = {}
32
39
 
33
- def __getitem__(self, key):
40
+ def __getitem__(self, key: Any) -> FFT_Species_Handler | FFT_Diagnostic:
34
41
  if key in self._simulation._species:
35
42
  if key not in self._species_handler:
36
43
  self._species_handler[key] = FFT_Species_Handler(self._simulation[key], self._fft_axis)
@@ -40,16 +47,16 @@ class FFT_Simulation(PostProcess):
40
47
  self._fft_computed[key] = FFT_Diagnostic(self._simulation[key], self._fft_axis)
41
48
  return self._fft_computed[key]
42
49
 
43
- def delete_all(self):
50
+ def delete_all(self) -> None:
44
51
  self._fft_computed = {}
45
52
 
46
- def delete(self, key):
53
+ def delete(self, key: Any) -> None:
47
54
  if key in self._fft_computed:
48
55
  del self._fft_computed[key]
49
56
  else:
50
57
  print(f"FFT {key} not found in simulation")
51
58
 
52
- def process(self, diagnostic):
59
+ def process(self, diagnostic: Diagnostic) -> FFT_Diagnostic:
53
60
  """Apply FFT to a diagnostic"""
54
61
  return FFT_Diagnostic(diagnostic, self._fft_axis)
55
62
 
@@ -77,7 +84,7 @@ class FFT_Diagnostic(Diagnostic):
77
84
 
78
85
  """
79
86
 
80
- def __init__(self, diagnostic, fft_axis):
87
+ def __init__(self, diagnostic: Diagnostic, fft_axis: int | list[int]) -> None:
81
88
  if hasattr(diagnostic, "_species"):
82
89
  super().__init__(
83
90
  simulation_folder=(diagnostic._simulation_folder if hasattr(diagnostic, "_simulation_folder") else None),
@@ -113,9 +120,11 @@ class FFT_Diagnostic(Diagnostic):
113
120
  if isinstance(self._dx, (int, float)):
114
121
  self._kmax = np.pi / (self._dx)
115
122
  else:
116
- self._kmax = np.pi / np.array([self._dx[ax - 1] for ax in self._fft_axis if ax != 0])
123
+ # Handle if fft_axis is int
124
+ axes = [self._fft_axis] if isinstance(self._fft_axis, int) else self._fft_axis
125
+ self._kmax = np.pi / np.array([self._dx[ax - 1] for ax in axes if ax != 0])
117
126
 
118
- def load_all(self):
127
+ def load_all(self) -> np.ndarray:
119
128
  if self._data is not None:
120
129
  print("Using cached data.")
121
130
  return self._data
@@ -165,7 +174,7 @@ class FFT_Diagnostic(Diagnostic):
165
174
  self._data = np.abs(result) ** 2
166
175
  return self._data
167
176
 
168
- def _data_generator(self, index):
177
+ def _data_generator(self, index: int) -> Generator[np.ndarray, None, None]:
169
178
  # Get the data for this index
170
179
  original_data = self._diag[index]
171
180
 
@@ -195,10 +204,10 @@ class FFT_Diagnostic(Diagnostic):
195
204
 
196
205
  yield np.abs(result_fft) ** 2
197
206
 
198
- def _get_window(self, length, axis):
207
+ def _get_window(self, length: int, axis: int) -> np.ndarray:
199
208
  return np.hanning(length)
200
209
 
201
- def _apply_window(self, data, window, axis):
210
+ def _apply_window(self, data: np.ndarray, window: np.ndarray, axis: int) -> np.ndarray:
202
211
  ndim = data.ndim
203
212
  window_shape = [1] * ndim
204
213
  window_shape[axis] = len(window)
@@ -207,7 +216,7 @@ class FFT_Diagnostic(Diagnostic):
207
216
 
208
217
  return data * reshaped_window
209
218
 
210
- def __getitem__(self, index):
219
+ def __getitem__(self, index: int | slice) -> np.ndarray:
211
220
  if self._all_loaded and self._data is not None:
212
221
  return self._data[index]
213
222
 
@@ -221,29 +230,131 @@ class FFT_Diagnostic(Diagnostic):
221
230
  else:
222
231
  raise ValueError("Invalid index type. Use int or slice.")
223
232
 
224
- def omega(self):
233
+ def omega(self) -> np.ndarray:
225
234
  """
226
- Get the angular frequency array for the FFT.
235
+ Get the angular frequency array for the FFT along the time dimension (axis 0).
236
+
237
+ Returns
238
+ -------
239
+ np.ndarray
240
+ Angular frequency array for the time axis.
227
241
  """
228
242
  if not self._all_loaded:
229
243
  raise ValueError("Load the data first using load_all() method.")
230
244
 
231
- omega = np.fft.fftfreq(self._data.shape[self._fft_axis], d=self._dx[self._fft_axis - 1])
232
- omega = np.fft.fftshift(omega)
233
- return omega
245
+ # If the FFT was computed along the time axis (0) return temporal frequencies
246
+ if isinstance(self._fft_axis, (list, tuple)):
247
+ if 0 in self._fft_axis:
248
+ dt = self._dt * self._ndump
249
+ omega = np.fft.fftfreq(self._data.shape[0], d=dt) * 2 * np.pi
250
+ return np.fft.fftshift(omega)
251
+ # If FFT was computed along spatial axes only and a single spatial axis
252
+ spatial_axes = [ax for ax in self._fft_axis if ax != 0]
253
+ if len(spatial_axes) == 1:
254
+ return self.k(spatial_axes[0])
255
+ # Multi-dimensional spatial FFT: return concatenated or dict of k arrays
256
+ return self.k()
257
+ else:
258
+ if self._fft_axis == 0:
259
+ dt = self._dt * self._ndump
260
+ omega = np.fft.fftfreq(self._data.shape[0], d=dt) * 2 * np.pi
261
+ return np.fft.fftshift(omega)
262
+ # Single spatial axis: return wavenumber array for that axis
263
+ return self.k(self._fft_axis)
264
+
265
+ def k(self, axis: int | None = None) -> np.ndarray | dict[int, np.ndarray]:
266
+ """
267
+ Get the wavenumber array for the FFT along spatial dimension(s).
268
+
269
+ Parameters
270
+ ----------
271
+ axis : int or None, optional
272
+ The spatial axis to compute wavenumber for (1, 2, or 3).
273
+ If None, returns wavenumbers for all spatial axes in fft_axis.
274
+
275
+ Returns
276
+ -------
277
+ np.ndarray or dict
278
+ If axis is specified: wavenumber array for that axis.
279
+ If axis is None: dictionary mapping axis -> wavenumber array.
280
+
281
+ Notes
282
+ -----
283
+ When load_all() is used, time axis is 0 and spatial axes are 1,2,3.
284
+ When accessing single timesteps, spatial axes are 0,1,2.
285
+ """
286
+ if self._data is None:
287
+ raise ValueError("Load the data first using load_all() or access via indexing.")
288
+
289
+ # Determine if we have the time dimension in the data
290
+ # If all_loaded is True, then axis 0 is time, spatial axes are 1,2,3
291
+ # If all_loaded is False, we're looking at a single timestep, spatial axes are 0,1,2
292
+ has_time_axis = self._all_loaded
293
+
294
+ # Determine which axes to compute k for
295
+ if axis is not None:
296
+ # Single axis specified
297
+ if axis == 0:
298
+ raise ValueError("axis must be a spatial dimension (1, 2, or 3), not 0 (time).")
299
+ if axis < 1 or axis > 3:
300
+ raise ValueError(f"axis must be 1, 2, or 3, got {axis}")
301
+
302
+ # Get dx for this axis
303
+ if isinstance(self._dx, (int, float)):
304
+ dx = self._dx
305
+ else:
306
+ dx = self._dx[axis - 1]
307
+
308
+ # Compute the actual data axis index
309
+ # If we have time axis, spatial axis N is at index N
310
+ # If no time axis (single timestep), spatial axis N is at index N-1
311
+ data_axis = axis if has_time_axis else axis - 1
312
+
313
+ # Compute wavenumber
314
+ k_array = np.fft.fftfreq(self._data.shape[data_axis], d=dx) * 2 * np.pi
315
+ k_array = np.fft.fftshift(k_array)
316
+ return k_array
317
+ else:
318
+ # axis is None: return k for all spatial axes in fft_axis
319
+ result = {}
320
+
321
+ if isinstance(self._fft_axis, (list, tuple)):
322
+ # Multi-axis FFT: return k for all spatial axes
323
+ spatial_axes = [ax for ax in self._fft_axis if ax != 0]
324
+ elif self._fft_axis == 0:
325
+ # Only time FFT, no spatial axes
326
+ raise ValueError("No spatial FFT axes to compute wavenumber for. fft_axis is 0 (time only).")
327
+ else:
328
+ # Single spatial axis
329
+ spatial_axes = [self._fft_axis]
330
+
331
+ for ax in spatial_axes:
332
+ if isinstance(self._dx, (int, float)):
333
+ dx = self._dx
334
+ else:
335
+ dx = self._dx[ax - 1]
336
+
337
+ # Compute the actual data axis index
338
+ data_axis = ax if has_time_axis else ax - 1
339
+
340
+ k_array = np.fft.fftfreq(self._data.shape[data_axis], d=dx) * 2 * np.pi
341
+ k_array = np.fft.fftshift(k_array)
342
+ result[ax] = k_array
343
+
344
+ return result
234
345
 
235
346
  @property
236
- def kmax(self):
347
+ def kmax(self) -> float | np.ndarray:
237
348
  return self._kmax
238
349
 
239
350
 
240
351
  class FFT_Species_Handler:
241
- def __init__(self, species_handler, fft_axis):
352
+ def __init__(self, species_handler: Any, fft_axis: int | list[int]) -> None:
242
353
  self._species_handler = species_handler
243
354
  self._fft_axis = fft_axis
244
- self._fft_computed = {}
355
+ self._fft_computed: dict[Any, FFT_Diagnostic] = {}
245
356
 
246
- def __getitem__(self, key):
357
+ def __getitem__(self, key: Any) -> FFT_Diagnostic:
247
358
  if key not in self._fft_computed:
248
359
  diag = self._species_handler[key]
249
360
  self._fft_computed[key] = FFT_Diagnostic(diag, self._fft_axis)
@@ -1,9 +1,15 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Generator
4
+
1
5
  import numpy as np
2
6
 
3
7
  from ..data.diagnostic import OSIRIS_FLD, Diagnostic
4
8
  from ..data.simulation import Simulation
5
9
  from .postprocess import PostProcess
6
10
 
11
+ __all__ = ["FieldCentering_Simulation", "FieldCentering_Diagnostic"]
12
+
7
13
 
8
14
  class FieldCentering_Simulation(PostProcess):
9
15
  """
@@ -22,47 +28,47 @@ class FieldCentering_Simulation(PostProcess):
22
28
  """
23
29
 
24
30
  def __init__(self, simulation: Simulation):
25
- super().__init__("FieldCentering Simulation")
26
31
  """
27
32
  Class to center the field in the simulation.
28
33
 
29
34
  Parameters
30
35
  ----------
31
- sim : Simulation
36
+ simulation : Simulation
32
37
  The simulation object.
33
- field : str
34
- The field to center.
35
38
  """
39
+ super().__init__("FieldCentering Simulation")
40
+
41
+ # Accept Simulation-compatible objects (Simulation or other PostProcess subclasses)
36
42
  if not isinstance(simulation, Simulation):
37
- raise ValueError("Simulation must be a Simulation object.")
43
+ raise ValueError("simulation must be a Simulation-compatible object.")
38
44
  self._simulation = simulation
39
45
 
40
46
  self._field_centered = {}
41
47
  # no need to create a species handler for field centering since fields are not species related
42
48
 
43
- def __getitem__(self, key):
49
+ def __getitem__(self, key: str) -> FieldCentering_Diagnostic:
44
50
  if key not in OSIRIS_FLD:
45
51
  raise ValueError(f"Does it make sense to center {key} field? Only {OSIRIS_FLD} are supported.")
46
52
  if key not in self._field_centered:
47
53
  self._field_centered[key] = FieldCentering_Diagnostic(self._simulation[key])
48
54
  return self._field_centered[key]
49
55
 
50
- def delete_all(self):
56
+ def delete_all(self) -> None:
51
57
  self._field_centered = {}
52
58
 
53
- def delete(self, key):
59
+ def delete(self, key: str) -> None:
54
60
  if key in self._field_centered:
55
61
  del self._field_centered[key]
56
62
  else:
57
63
  print(f"Field {key} not found in simulation")
58
64
 
59
- def process(self, diagnostic):
65
+ def process(self, diagnostic: Diagnostic) -> FieldCentering_Diagnostic:
60
66
  """Apply field centering to a diagnostic"""
61
67
  return FieldCentering_Diagnostic(diagnostic)
62
68
 
63
69
 
64
70
  class FieldCentering_Diagnostic(Diagnostic):
65
- def __init__(self, diagnostic):
71
+ def __init__(self, diagnostic: Diagnostic):
66
72
  """
67
73
  Class to center the field in the simulation. It converts fields from the Osiris yee mesh to the center of the cells.
68
74
  It only works for periodic boundaries.
@@ -104,10 +110,18 @@ class FieldCentering_Diagnostic(Diagnostic):
104
110
  self._original_name = diagnostic._name
105
111
  self._name = diagnostic._name + "_centered"
106
112
 
107
- self._data = None
113
+ self._data: np.ndarray | None = None
108
114
  self._all_loaded = False
109
115
 
110
- def load_all(self):
116
+ def load_all(self) -> np.ndarray:
117
+ """
118
+ Load all data and center the fields.
119
+
120
+ Returns
121
+ -------
122
+ data : np.ndarray
123
+ The centered field data.
124
+ """
111
125
  if self._data is not None:
112
126
  return self._data
113
127
 
@@ -217,7 +231,7 @@ class FieldCentering_Diagnostic(Diagnostic):
217
231
  self._all_loaded = True
218
232
  return self._data
219
233
 
220
- def __getitem__(self, index):
234
+ def __getitem__(self, index: int | slice) -> np.ndarray:
221
235
  """Get data at a specific index"""
222
236
  if self._all_loaded and self._data is not None:
223
237
  return self._data[index]
@@ -232,7 +246,7 @@ class FieldCentering_Diagnostic(Diagnostic):
232
246
  else:
233
247
  raise ValueError("Invalid index type. Use int or slice.")
234
248
 
235
- def _data_generator(self, index):
249
+ def _data_generator(self, index: int) -> Generator[np.ndarray, None, None]:
236
250
  if self._dim == 1:
237
251
  if self._original_name.lower() in [
238
252
  "b2",
@@ -1,3 +1,8 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Generator
4
+ from typing import Any
5
+
1
6
  import numpy as np
2
7
 
3
8
  from ..data.diagnostic import Diagnostic
@@ -7,9 +12,15 @@ from .pressure_correction import PressureCorrection_Simulation
7
12
 
8
13
  OSIRIS_H = ["q1", "q2", "q3"]
9
14
 
15
+ __all__ = [
16
+ "HeatfluxCorrection_Simulation",
17
+ "HeatfluxCorrection_Diagnostic",
18
+ "HeatfluxCorrection_Species_Handler",
19
+ ]
20
+
10
21
 
11
22
  class HeatfluxCorrection_Simulation(PostProcess):
12
- def __init__(self, simulation):
23
+ def __init__(self, simulation: Simulation):
13
24
  super().__init__("HeatfluxCorrection Simulation")
14
25
  """
15
26
  Class to correct pressure tensor components by subtracting Reynolds stress.
@@ -22,12 +33,12 @@ class HeatfluxCorrection_Simulation(PostProcess):
22
33
  The heatflux component to center.
23
34
  """
24
35
  if not isinstance(simulation, Simulation):
25
- raise ValueError("Simulation must be a Simulation object.")
36
+ raise ValueError("simulation must be a Simulation-compatible object.")
26
37
  self._simulation = simulation
27
- self._heatflux_corrected = {}
28
- self._species_handler = {}
38
+ self._heatflux_corrected: dict[str, HeatfluxCorrection_Diagnostic] = {}
39
+ self._species_handler: dict[str, HeatfluxCorrection_Species_Handler] = {}
29
40
 
30
- def __getitem__(self, key):
41
+ def __getitem__(self, key: str) -> HeatfluxCorrection_Species_Handler | HeatfluxCorrection_Diagnostic:
31
42
  if key in self._simulation._species:
32
43
  if key not in self._species_handler:
33
44
  self._species_handler[key] = HeatfluxCorrection_Species_Handler(self._simulation[key], self._simulation)
@@ -36,25 +47,35 @@ class HeatfluxCorrection_Simulation(PostProcess):
36
47
  raise ValueError(f"Invalid heatflux component {key}. Supported: {OSIRIS_H}.")
37
48
  if key not in self._heatflux_corrected:
38
49
  print("Weird that it got here - heatflux is always species dependent on OSIRIS")
39
- self._heatflux_corrected[key] = HeatfluxCorrection_Diagnostic(self._simulation[key], self._simulation)
50
+ # This part seems to lack some arguments for HeatfluxCorrection_Diagnostic,
51
+ # but keeping as is for structural consistency if reached
52
+ # self._heatflux_corrected[key] = HeatfluxCorrection_Diagnostic(self._simulation[key], self._simulation)
40
53
  return self._heatflux_corrected[key]
41
54
 
42
- def delete_all(self):
55
+ def delete_all(self) -> None:
43
56
  self._heatflux_corrected = {}
44
57
 
45
- def delete(self, key):
58
+ def delete(self, key: str) -> None:
46
59
  if key in self._heatflux_corrected:
47
60
  del self._heatflux_corrected[key]
48
61
  else:
49
62
  print(f"Heatflux {key} not found in simulation")
50
63
 
51
- def process(self, diagnostic):
64
+ def process(self, diagnostic: Diagnostic) -> HeatfluxCorrection_Diagnostic:
52
65
  """Apply heatflux correction to a diagnostic"""
66
+ # FIX: This is a bit of a hack, but it works for now
53
67
  return HeatfluxCorrection_Diagnostic(diagnostic, self._simulation)
54
68
 
55
69
 
56
70
  class HeatfluxCorrection_Diagnostic(Diagnostic):
57
- def __init__(self, diagnostic, vfl_i, Pjj_list, vfl_j_list, Pji_list):
71
+ def __init__(
72
+ self,
73
+ diagnostic: Diagnostic,
74
+ vfl_i: Diagnostic,
75
+ Pjj_list: list[Diagnostic],
76
+ vfl_j_list: list[Diagnostic],
77
+ Pji_list: list[Diagnostic],
78
+ ):
58
79
  """
59
80
  Class to correct the pressure in the simulation.
60
81
 
@@ -102,10 +123,10 @@ class HeatfluxCorrection_Diagnostic(Diagnostic):
102
123
  self._original_name = diagnostic._name
103
124
  self._name = diagnostic._name + "_corrected"
104
125
 
105
- self._data = None
126
+ self._data: np.ndarray | None = None
106
127
  self._all_loaded = False
107
128
 
108
- def load_all(self):
129
+ def load_all(self) -> np.ndarray:
109
130
  if self._data is not None:
110
131
  return self._data
111
132
 
@@ -131,12 +152,12 @@ class HeatfluxCorrection_Diagnostic(Diagnostic):
131
152
  # Sum over j: vfl_j * Pji
132
153
  vfl_dot_Pji = sum(vfl_j.data * Pji.data for vfl_j, Pji in zip(self._vfl_j_list, self._Pji_list, strict=False))
133
154
 
134
- self._data = 2 * q - 0.5 * vfl_i * trace_P - vfl_dot_Pji
155
+ self._data = 2 * q - vfl_i * trace_P - 2 * vfl_dot_Pji
135
156
  self._all_loaded = True
136
157
 
137
158
  return self._data
138
159
 
139
- def __getitem__(self, index):
160
+ def __getitem__(self, index: int | slice) -> np.ndarray:
140
161
  """Get data at a specific index"""
141
162
  if self._all_loaded and self._data is not None:
142
163
  return self._data[index]
@@ -151,7 +172,7 @@ class HeatfluxCorrection_Diagnostic(Diagnostic):
151
172
  else:
152
173
  raise ValueError("Invalid index type. Use int or slice.")
153
174
 
154
- def _data_generator(self, index):
175
+ def _data_generator(self, index: int) -> Generator[np.ndarray, None, None]:
155
176
  q = self._diag[index]
156
177
  vfl_i = self._vfl_i[index]
157
178
  trace_P = sum(Pjj[index] for Pjj in self._Pjj_list)
@@ -174,12 +195,12 @@ class HeatfluxCorrection_Species_Handler:
174
195
  The simulation object.
175
196
  """
176
197
 
177
- def __init__(self, species_handler, simulation):
198
+ def __init__(self, species_handler: Any, simulation: Simulation):
178
199
  self._species_handler = species_handler
179
200
  self._simulation = simulation
180
- self._heatflux_corrected = {}
201
+ self._heatflux_corrected: dict[str, HeatfluxCorrection_Diagnostic] = {}
181
202
 
182
- def __getitem__(self, key):
203
+ def __getitem__(self, key: str) -> HeatfluxCorrection_Diagnostic:
183
204
  if key not in self._heatflux_corrected:
184
205
  diag = self._species_handler[key]
185
206
 
@@ -1,9 +1,17 @@
1
+ from __future__ import annotations
2
+
1
3
  import numpy as np
2
4
 
3
5
  from ..data.diagnostic import Diagnostic
4
6
  from ..data.simulation import Simulation
5
7
  from .postprocess import PostProcess
6
8
 
9
+ __all__ = [
10
+ "MFT_Simulation",
11
+ "MFT_Diagnostic",
12
+ "MFT_Species_Handler",
13
+ ]
14
+
7
15
 
8
16
  class MFT_Simulation(PostProcess):
9
17
  """
@@ -19,10 +27,10 @@ class MFT_Simulation(PostProcess):
19
27
 
20
28
  """
21
29
 
22
- def __init__(self, simulation, mft_axis=None):
30
+ def __init__(self, simulation: Simulation, mft_axis: int | None = None):
23
31
  super().__init__(f"MeanFieldTheory({mft_axis})")
24
32
  if not isinstance(simulation, Simulation):
25
- raise ValueError("Simulation must be a Simulation object.")
33
+ raise ValueError("simulation must be a Simulation-compatible object.")
26
34
  self._simulation = simulation
27
35
  self._mft_axis = mft_axis
28
36
  self._mft_computed = {}
@@ -1,7 +1,11 @@
1
- from ..data.diagnostic import Diagnostic
1
+ from __future__ import annotations
2
2
 
3
+ from ..data.simulation import Simulation
3
4
 
4
- class PostProcess(Diagnostic):
5
+ __all__ = ["PostProcess"]
6
+
7
+
8
+ class PostProcess(Simulation):
5
9
  """
6
10
  Base class for post-processing operations.
7
11
  Inherits from Diagnostic to ensure all operation overloads work.
@@ -14,14 +18,13 @@ class PostProcess(Diagnostic):
14
18
  The species to analyze.
15
19
  """
16
20
 
17
- def __init__(self, name, species=None):
21
+ def __init__(self, name: str, species: str = None):
18
22
  # Initialize with the same interface as Diagnostic
19
- super().__init__(species)
20
23
  self._name = name
21
24
  self._all_loaded = False
22
25
  self._data = None
23
26
 
24
- def process(self, diagnostic):
27
+ def process(self, simulation: Simulation) -> Simulation:
25
28
  """
26
29
  Apply the post-processing to a diagnostic.
27
30
  Must be implemented by subclasses.
@@ -1,3 +1,8 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Generator
4
+ from typing import Any
5
+
1
6
  import numpy as np
2
7
 
3
8
  from ..data.diagnostic import Diagnostic
@@ -6,9 +11,15 @@ from .postprocess import PostProcess
6
11
 
7
12
  OSIRIS_P = ["P11", "P12", "P13", "P21", "P22", "P23", "P31", "P32", "P33"]
8
13
 
14
+ __all__ = [
15
+ "PressureCorrection_Simulation",
16
+ "PressureCorrection_Diagnostic",
17
+ "PressureCorrection_Species_Handler",
18
+ ]
19
+
9
20
 
10
21
  class PressureCorrection_Simulation(PostProcess):
11
- def __init__(self, simulation):
22
+ def __init__(self, simulation: Simulation):
12
23
  super().__init__("PressureCorrection Simulation")
13
24
  """
14
25
  Class to correct pressure tensor components by subtracting Reynolds stress.
@@ -21,12 +32,12 @@ class PressureCorrection_Simulation(PostProcess):
21
32
  The pressure component to center.
22
33
  """
23
34
  if not isinstance(simulation, Simulation):
24
- raise ValueError("Simulation must be a Simulation object.")
35
+ raise ValueError("simulation must be a Simulation-compatible object.")
25
36
  self._simulation = simulation
26
- self._pressure_corrected = {}
27
- self._species_handler = {}
37
+ self._pressure_corrected: dict[str, PressureCorrection_Diagnostic] = {}
38
+ self._species_handler: dict[str, PressureCorrection_Species_Handler] = {}
28
39
 
29
- def __getitem__(self, key):
40
+ def __getitem__(self, key: str) -> PressureCorrection_Species_Handler | PressureCorrection_Diagnostic:
30
41
  if key in self._simulation._species:
31
42
  if key not in self._species_handler:
32
43
  self._species_handler[key] = PressureCorrection_Species_Handler(self._simulation[key])
@@ -35,25 +46,26 @@ class PressureCorrection_Simulation(PostProcess):
35
46
  raise ValueError(f"Invalid pressure component {key}. Supported: {OSIRIS_P}.")
36
47
  if key not in self._pressure_corrected:
37
48
  print("Weird that it got here - pressure is always species dependent on OSIRIS")
38
- self._pressure_corrected[key] = PressureCorrection_Diagnostic(self._simulation[key], self._simulation)
49
+ # self._pressure_corrected[key] = PressureCorrection_Diagnostic(self._simulation[key], self._simulation)
39
50
  return self._pressure_corrected[key]
40
51
 
41
- def delete_all(self):
52
+ def delete_all(self) -> None:
42
53
  self._pressure_corrected = {}
43
54
 
44
- def delete(self, key):
55
+ def delete(self, key: str) -> None:
45
56
  if key in self._pressure_corrected:
46
57
  del self._pressure_corrected[key]
47
58
  else:
48
59
  print(f"Pressure {key} not found in simulation")
49
60
 
50
- def process(self, diagnostic):
61
+ def process(self, diagnostic: Diagnostic) -> PressureCorrection_Diagnostic:
51
62
  """Apply pressure correction to a diagnostic"""
63
+ # FIX: This is a bit of a hack, but it works for now
52
64
  return PressureCorrection_Diagnostic(diagnostic, self._simulation)
53
65
 
54
66
 
55
67
  class PressureCorrection_Diagnostic(Diagnostic):
56
- def __init__(self, diagnostic, n, ufl_j, vfl_k):
68
+ def __init__(self, diagnostic: Diagnostic, n: Diagnostic, ufl_j: Diagnostic, vfl_k: Diagnostic):
57
69
  """
58
70
  Class to correct the pressure in the simulation.
59
71
 
@@ -100,10 +112,10 @@ class PressureCorrection_Diagnostic(Diagnostic):
100
112
  self._original_name = diagnostic._name
101
113
  self._name = diagnostic._name + "_corrected"
102
114
 
103
- self._data = None
115
+ self._data: np.ndarray | None = None
104
116
  self._all_loaded = False
105
117
 
106
- def load_all(self):
118
+ def load_all(self) -> np.ndarray:
107
119
  if self._data is not None:
108
120
  return self._data
109
121
 
@@ -130,7 +142,7 @@ class PressureCorrection_Diagnostic(Diagnostic):
130
142
 
131
143
  return self._data
132
144
 
133
- def __getitem__(self, index):
145
+ def __getitem__(self, index: int | slice) -> np.ndarray:
134
146
  """Get data at a specific index"""
135
147
  if self._all_loaded and self._data is not None:
136
148
  return self._data[index]
@@ -145,7 +157,7 @@ class PressureCorrection_Diagnostic(Diagnostic):
145
157
  else:
146
158
  raise ValueError("Invalid index type. Use int or slice.")
147
159
 
148
- def _data_generator(self, index):
160
+ def _data_generator(self, index: int) -> Generator[np.ndarray, None, None]:
149
161
  yield (self._diag[index] - self._n[index] * self._vfl_k[index] * self._ufl_j[index])
150
162
 
151
163
 
@@ -166,11 +178,11 @@ class PressureCorrection_Species_Handler:
166
178
  The axis to compute the derivative. Only used for 'xx', 'xt' and 'tx' types.
167
179
  """
168
180
 
169
- def __init__(self, species_handler):
181
+ def __init__(self, species_handler: Any):
170
182
  self._species_handler = species_handler
171
- self._pressure_corrected = {}
183
+ self._pressure_corrected: dict[str, PressureCorrection_Diagnostic] = {}
172
184
 
173
- def __getitem__(self, key):
185
+ def __getitem__(self, key: str) -> PressureCorrection_Diagnostic:
174
186
  if key not in self._pressure_corrected:
175
187
  diag = self._species_handler[key]
176
188