osiris-utils 1.1.10__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. benchmarks/benchmark_hdf5_io.py +46 -0
  2. benchmarks/benchmark_load_all.py +54 -0
  3. docs/source/api/decks.rst +48 -0
  4. docs/source/api/postprocess.rst +66 -2
  5. docs/source/api/sim_diag.rst +1 -1
  6. docs/source/api/utilities.rst +1 -1
  7. docs/source/conf.py +2 -1
  8. docs/source/examples/example_Derivatives.md +78 -0
  9. docs/source/examples/example_FFT.md +152 -0
  10. docs/source/examples/example_InputDeck.md +148 -0
  11. docs/source/examples/example_Simulation_Diagnostic.md +213 -0
  12. docs/source/examples/quick_start.md +51 -0
  13. docs/source/examples.rst +14 -0
  14. docs/source/index.rst +8 -0
  15. examples/edited-deck.1d +1 -1
  16. examples/example_Derivatives.ipynb +24 -36
  17. examples/example_FFT.ipynb +44 -23
  18. examples/example_InputDeck.ipynb +24 -277
  19. examples/example_Simulation_Diagnostic.ipynb +27 -17
  20. examples/quick_start.ipynb +17 -1
  21. osiris_utils/__init__.py +10 -6
  22. osiris_utils/cli/__init__.py +6 -0
  23. osiris_utils/cli/__main__.py +85 -0
  24. osiris_utils/cli/export.py +199 -0
  25. osiris_utils/cli/info.py +156 -0
  26. osiris_utils/cli/plot.py +189 -0
  27. osiris_utils/cli/validate.py +247 -0
  28. osiris_utils/data/__init__.py +15 -0
  29. osiris_utils/data/data.py +41 -171
  30. osiris_utils/data/diagnostic.py +285 -274
  31. osiris_utils/data/simulation.py +20 -13
  32. osiris_utils/decks/__init__.py +4 -0
  33. osiris_utils/decks/decks.py +83 -8
  34. osiris_utils/decks/species.py +12 -9
  35. osiris_utils/postprocessing/__init__.py +28 -0
  36. osiris_utils/postprocessing/derivative.py +317 -106
  37. osiris_utils/postprocessing/fft.py +135 -24
  38. osiris_utils/postprocessing/field_centering.py +28 -14
  39. osiris_utils/postprocessing/heatflux_correction.py +39 -18
  40. osiris_utils/postprocessing/mft.py +10 -2
  41. osiris_utils/postprocessing/postprocess.py +8 -5
  42. osiris_utils/postprocessing/pressure_correction.py +29 -17
  43. osiris_utils/utils.py +26 -17
  44. osiris_utils/vis/__init__.py +3 -0
  45. osiris_utils/vis/plot3d.py +148 -0
  46. {osiris_utils-1.1.10.dist-info → osiris_utils-1.2.0.dist-info}/METADATA +55 -7
  47. {osiris_utils-1.1.10.dist-info → osiris_utils-1.2.0.dist-info}/RECORD +51 -34
  48. {osiris_utils-1.1.10.dist-info → osiris_utils-1.2.0.dist-info}/WHEEL +1 -1
  49. osiris_utils-1.2.0.dist-info/entry_points.txt +2 -0
  50. {osiris_utils-1.1.10.dist-info → osiris_utils-1.2.0.dist-info}/top_level.txt +1 -0
  51. osiris_utils/postprocessing/mft_for_gridfile.py +0 -55
  52. {osiris_utils-1.1.10.dist-info → osiris_utils-1.2.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -1,19 +1,28 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Generator
4
+ from typing import Any
5
+
1
6
  import numpy as np
2
7
 
3
8
  from ..data.diagnostic import Diagnostic
4
9
  from ..data.simulation import Simulation
5
10
  from .postprocess import PostProcess
6
11
 
12
+ __all__ = ["Derivative_Diagnostic", "Derivative_Simulation", "Derivative_Species_Handler"]
13
+
7
14
 
8
15
  class Derivative_Simulation(PostProcess):
9
- """
10
- Class to compute the derivative of a diagnostic. Works as a wrapper for the Derivative_Diagnostic class.
16
+ """Class to compute the derivative of a diagnostic. Works as a wrapper for the Derivative_Diagnostic class.
11
17
  Inherits from PostProcess to ensure all operation overloads work properly.
12
18
 
19
+ This class can be initialized with either a Simulation object or another Derivative_Simulation object,
20
+ allowing for chaining derivatives (e.g., second derivative = Derivative_Simulation(Derivative_Simulation(...))).
21
+
13
22
  Parameters
14
23
  ----------
15
- simulation : Simulation
16
- The simulation object.
24
+ simulation : Simulation or Derivative_Simulation
25
+ The simulation object or another derivative simulation object.
17
26
  deriv_type : str
18
27
  The type of derivative to compute. Options are:
19
28
  - 't' for time derivative.
@@ -25,23 +34,42 @@ class Derivative_Simulation(PostProcess):
25
34
  - 'tx' for mixed derivative in one spatial axis and time.
26
35
  axis : int or tuple
27
36
  The axis to compute the derivative. Only used for 'xx', 'xt' and 'tx' types.
37
+ order : int
38
+ The order of the derivative. Currently only 2 and 4 are supported.
39
+ Order 2 uses central differences with edge_order=2 in numpy.gradient.
40
+ Order 4 uses a higher order finite difference scheme. For the edge points,
41
+ a lower order scheme is used to avoid going out of bounds.
28
42
 
29
43
  """
30
44
 
31
- def __init__(self, simulation, deriv_type, axis=None):
45
+ def __init__(
46
+ self,
47
+ simulation: Simulation | Derivative_Simulation,
48
+ deriv_type: str,
49
+ axis: int | tuple[int, int] | None = None,
50
+ order: int = 2,
51
+ ):
32
52
  super().__init__(f"Derivative({deriv_type})")
33
- if not isinstance(simulation, Simulation):
34
- raise ValueError("Simulation must be a Simulation object.")
53
+ # Accept both Simulation and Derivative_Simulation objects
54
+ if not isinstance(simulation, (Simulation, Derivative_Simulation)):
55
+ raise ValueError("simulation must be a Simulation or Derivative_Simulation object.")
35
56
  self._simulation = simulation
36
57
  self._deriv_type = deriv_type
37
58
  self._axis = axis
38
59
  self._derivatives_computed = {}
39
60
  self._species_handler = {}
61
+ self._order = order
40
62
 
41
- def __getitem__(self, key):
63
+ # Copy species list to make this class behave like a Simulation
64
+ if hasattr(simulation, "_species"):
65
+ self._species = simulation._species
66
+ else:
67
+ self._species = []
68
+
69
+ def __getitem__(self, key: Any) -> Derivative_Species_Handler | Derivative_Diagnostic:
42
70
  if key in self._simulation._species:
43
71
  if key not in self._species_handler:
44
- self._species_handler[key] = Derivative_Species_Handler(self._simulation[key], self._deriv_type, self._axis)
72
+ self._species_handler[key] = Derivative_Species_Handler(self._simulation[key], self._deriv_type, self._axis, self._order)
45
73
  return self._species_handler[key]
46
74
 
47
75
  if key not in self._derivatives_computed:
@@ -49,26 +77,64 @@ class Derivative_Simulation(PostProcess):
49
77
  diagnostic=self._simulation[key],
50
78
  deriv_type=self._deriv_type,
51
79
  axis=self._axis,
80
+ order=self._order,
52
81
  )
53
82
  return self._derivatives_computed[key]
54
83
 
55
- def delete_all(self):
84
+ def delete_all(self) -> None:
56
85
  self._derivatives_computed = {}
57
86
 
58
- def delete(self, key):
87
+ def delete(self, key: Any) -> None:
59
88
  if key in self._derivatives_computed:
60
89
  del self._derivatives_computed[key]
61
90
  else:
62
91
  print(f"Derivative {key} not found in simulation")
63
92
 
64
- def process(self, diagnostic):
93
+ def process(self, diagnostic: Diagnostic) -> Derivative_Diagnostic:
65
94
  """Apply derivative to a diagnostic"""
66
95
  return Derivative_Diagnostic(diagnostic, self._deriv_type, self._axis)
67
96
 
97
+ @property
98
+ def species(self) -> list:
99
+ """Return list of species, making this compatible with Simulation interface"""
100
+ return self._species
101
+
102
+ @property
103
+ def loaded_diagnostics(self) -> dict:
104
+ """Return loaded diagnostics, making this compatible with Simulation interface"""
105
+ return self._derivatives_computed
106
+
107
+ def add_diagnostic(self, diagnostic: Diagnostic, name: str | None = None) -> str:
108
+ """Add a custom diagnostic to the derivative simulation.
109
+
110
+ Parameters
111
+ ----------
112
+ diagnostic : Diagnostic
113
+ The diagnostic to add.
114
+ name : str, optional
115
+ The name to use as the key for accessing the diagnostic.
116
+ If None, an auto-generated name will be used.
117
+
118
+ Returns
119
+ -------
120
+ str
121
+ The name (key) used to store the diagnostic
122
+
123
+ """
124
+ if name is None:
125
+ i = 1
126
+ while f"custom_diag_{i}" in self._derivatives_computed:
127
+ i += 1
128
+ name = f"custom_diag_{i}"
129
+
130
+ if isinstance(diagnostic, Diagnostic):
131
+ self._derivatives_computed[name] = diagnostic
132
+ return name
133
+ raise ValueError("Only Diagnostic objects are supported")
134
+
68
135
 
69
136
  class Derivative_Diagnostic(Diagnostic):
70
- """
71
- Auxiliar class to compute the derivative of a diagnostic, for it to be similar in behavior to a Diagnostic object.
137
+ """Auxiliar class to compute the derivative of a diagnostic, for it to be similar in behavior to a Diagnostic object.
72
138
  Inherits directly from Diagnostic to ensure all operation overloads work properly.
73
139
 
74
140
  Parameters
@@ -89,7 +155,7 @@ class Derivative_Diagnostic(Diagnostic):
89
155
 
90
156
  """
91
157
 
92
- def __init__(self, diagnostic, deriv_type, axis=None):
158
+ def __init__(self, diagnostic: Diagnostic, deriv_type: str, axis: int | tuple[int, int] | None = None, order: int = 2) -> None:
93
159
  # Initialize using parent's __init__ with the same species
94
160
  if hasattr(diagnostic, "_species"):
95
161
  super().__init__(
@@ -107,6 +173,7 @@ class Derivative_Diagnostic(Diagnostic):
107
173
  self._axis = axis if axis is not None else diagnostic._axis
108
174
  self._data = None
109
175
  self._all_loaded = False
176
+ self._order = order
110
177
 
111
178
  # Copy all relevant attributes from diagnostic
112
179
  for attr in [
@@ -124,127 +191,269 @@ class Derivative_Diagnostic(Diagnostic):
124
191
  if hasattr(diagnostic, attr):
125
192
  setattr(self, attr, getattr(diagnostic, attr))
126
193
 
127
- def load_all(self):
194
+ @staticmethod
195
+ def _compute_fourth_order_spatial(data: np.ndarray, dx: float, axis: int) -> np.ndarray:
196
+ """Compute 4th-order spatial derivative along specified axis using vectorized operations.
197
+
198
+ Uses the 4th-order central difference stencil:
199
+ Interior: (-f[i+2] + 8*f[i+1] - 8*f[i-1] + f[i-2]) / (12*h)
200
+ Boundaries: 2nd-order forward/backward differences
201
+
202
+ Parameters
203
+ ----------
204
+ data : np.ndarray
205
+ Input data array
206
+ dx : float
207
+ Grid spacing
208
+ axis : int
209
+ Axis along which to compute derivative
210
+
211
+ Returns
212
+ -------
213
+ np.ndarray
214
+ Derivative of the input data
215
+
216
+ """
217
+ result = np.zeros_like(data)
218
+
219
+ # Build slice objects for vectorized operations
220
+ # This is much faster than looping through indices
221
+
222
+ # Central differences (vectorized)
223
+ # For each point i in [2, n-2), compute:
224
+ # (-f[i+2] + 8*f[i+1] - 8*f[i-1] + f[i-2]) / (12*h)
225
+
226
+ slices_center = [slice(None)] * data.ndim
227
+ slices_p2 = [slice(None)] * data.ndim
228
+ slices_p1 = [slice(None)] * data.ndim
229
+ slices_m1 = [slice(None)] * data.ndim
230
+ slices_m2 = [slice(None)] * data.ndim
231
+
232
+ # Target region: indices 2 to -2
233
+ slices_center[axis] = slice(2, -2)
234
+ # For the stencil, we need aligned slices
235
+ slices_p2[axis] = slice(4, None) # i+2: starts at 4, goes to end
236
+ slices_p1[axis] = slice(3, -1) # i+1: starts at 3, goes to -1
237
+ slices_m1[axis] = slice(1, -3) # i-1: starts at 1, goes to -3
238
+ slices_m2[axis] = slice(0, -4) # i-2: starts at 0, goes to -4
239
+
240
+ result[tuple(slices_center)] = (
241
+ -data[tuple(slices_p2)] + 8 * data[tuple(slices_p1)] - 8 * data[tuple(slices_m1)] + data[tuple(slices_m2)]
242
+ ) / (12 * dx)
243
+
244
+ # Boundary points using 2nd-order differences
245
+ # First point: forward difference
246
+ slices_0 = [slice(None)] * data.ndim
247
+ slices_1 = [slice(None)] * data.ndim
248
+ slices_2 = [slice(None)] * data.ndim
249
+ slices_0[axis] = 0
250
+ slices_1[axis] = 1
251
+ slices_2[axis] = 2
252
+ result[tuple(slices_0)] = (-3 * data[tuple(slices_0)] + 4 * data[tuple(slices_1)] - data[tuple(slices_2)]) / (2 * dx)
253
+
254
+ # Second point: central difference
255
+ result[tuple(slices_1)] = (data[tuple(slices_2)] - data[tuple(slices_0)]) / (2 * dx)
256
+
257
+ # Second-to-last point: central difference
258
+ slices_m2 = [slice(None)] * data.ndim
259
+ slices_m1 = [slice(None)] * data.ndim
260
+ slices_m3 = [slice(None)] * data.ndim
261
+ slices_m2[axis] = -2
262
+ slices_m1[axis] = -1
263
+ slices_m3[axis] = -3
264
+ result[tuple(slices_m2)] = (data[tuple(slices_m1)] - data[tuple(slices_m3)]) / (2 * dx)
265
+
266
+ # Last point: backward difference
267
+ result[tuple(slices_m1)] = (3 * data[tuple(slices_m1)] - 4 * data[tuple(slices_m2)] + data[tuple(slices_m3)]) / (2 * dx)
268
+
269
+ return result
270
+
271
+ def load_all(self) -> np.ndarray:
128
272
  """Load all data and compute the derivative"""
129
273
  if self._data is not None:
130
274
  print("Using cached derivative")
131
275
  return self._data
132
276
 
133
- if not hasattr(self._diag, "_data") or self._diag._data is None:
277
+ # Load diagnostic data if needed
278
+ if not self._diag._all_loaded:
134
279
  self._diag.load_all()
135
- self._data = self._diag._data
136
280
 
137
- if self._diag._all_loaded is True:
138
- print("Using cached data from diagnostic")
139
- self._data = self._diag._data
140
-
141
- if self._deriv_type == "t":
142
- result = np.gradient(self._data, self._diag._dt * self._diag._ndump, axis=0, edge_order=2)
143
-
144
- elif self._deriv_type == "x1":
145
- if self._dim == 1:
146
- result = np.gradient(self._data, self._diag._dx, axis=1, edge_order=2)
147
- else:
148
- result = np.gradient(self._data, self._diag._dx[0], axis=1, edge_order=2)
149
-
150
- elif self._deriv_type == "x2":
151
- result = np.gradient(self._data, self._diag._dx[1], axis=2, edge_order=2)
152
-
153
- elif self._deriv_type == "x3":
154
- result = np.gradient(self._data, self._diag._dx[2], axis=3, edge_order=2)
155
-
156
- elif self._deriv_type == "xx":
157
- if len(self._axis) != 2:
158
- raise ValueError("Axis must be a tuple with two elements.")
159
- result = np.gradient(
160
- np.gradient(
161
- self._data,
162
- self._diag._dx[self._axis[0] - 1],
163
- axis=self._axis[0],
281
+ # Use diagnostic data
282
+ print("Using cached data from diagnostic")
283
+ self._data = self._diag._data
284
+
285
+ if self._order == 2:
286
+ if self._deriv_type == "t":
287
+ result = np.gradient(self._data, self._diag._dt * self._diag._ndump, axis=0, edge_order=2)
288
+
289
+ elif self._deriv_type == "x1":
290
+ # Handle dx - extract scalar for 1D, use first element for multi-D
291
+ dx = self._diag._dx
292
+ if self._dim == 1 and isinstance(dx, (list, tuple, np.ndarray)):
293
+ dx = dx[0] if len(dx) >= 1 else dx
294
+ elif self._dim > 1:
295
+ dx = self._diag._dx[0]
296
+ result = np.gradient(self._data, dx, axis=1, edge_order=2)
297
+
298
+ elif self._deriv_type == "x2":
299
+ result = np.gradient(self._data, self._diag._dx[1], axis=2, edge_order=2)
300
+
301
+ elif self._deriv_type == "x3":
302
+ result = np.gradient(self._data, self._diag._dx[2], axis=3, edge_order=2)
303
+
304
+ elif self._deriv_type == "xx":
305
+ if len(self._axis) != 2:
306
+ raise ValueError("Axis must be a tuple with two elements.")
307
+ result = np.gradient(
308
+ np.gradient(
309
+ self._data,
310
+ self._diag._dx[self._axis[0] - 1],
311
+ axis=self._axis[0],
312
+ edge_order=2,
313
+ ),
314
+ self._diag._dx[self._axis[1] - 1],
315
+ axis=self._axis[1],
164
316
  edge_order=2,
165
- ),
166
- self._diag._dx[self._axis[1] - 1],
167
- axis=self._axis[1],
168
- edge_order=2,
169
- )
170
-
171
- elif self._deriv_type == "xt":
172
- if not isinstance(self._axis, int):
173
- raise ValueError("Axis must be an integer.")
174
- result = np.gradient(
175
- np.gradient(self._data, self._diag._dt, axis=0, edge_order=2),
176
- self._diag._dx[self._axis - 1],
177
- axis=self._axis[0],
178
- edge_order=2,
179
- )
317
+ )
180
318
 
181
- elif self._deriv_type == "tx":
182
- if not isinstance(self._axis, int):
183
- raise ValueError("Axis must be an integer.")
184
- result = np.gradient(
185
- np.gradient(
186
- self._data,
319
+ elif self._deriv_type == "xt":
320
+ if not isinstance(self._axis, int):
321
+ raise ValueError("Axis must be an integer.")
322
+ result = np.gradient(
323
+ np.gradient(self._data, self._diag._dt, axis=0, edge_order=2),
187
324
  self._diag._dx[self._axis - 1],
188
325
  axis=self._axis,
189
326
  edge_order=2,
190
- ),
191
- self._diag._dt,
192
- axis=0,
193
- edge_order=2,
194
- )
195
- else:
196
- raise ValueError("Invalid derivative type.")
327
+ )
328
+
329
+ elif self._deriv_type == "tx":
330
+ if not isinstance(self._axis, int):
331
+ raise ValueError("Axis must be an integer.")
332
+ result = np.gradient(
333
+ np.gradient(
334
+ self._data,
335
+ self._diag._dx[self._axis - 1],
336
+ axis=self._axis,
337
+ edge_order=2,
338
+ ),
339
+ self._diag._dt,
340
+ axis=0,
341
+ edge_order=2,
342
+ )
343
+ else:
344
+ raise ValueError("Invalid derivative type.")
345
+
346
+ elif self._order == 4:
347
+ if self._deriv_type in ["x1", "x2", "x3"]:
348
+ axis = {"x1": 1, "x2": 2, "x3": 3}[self._deriv_type]
349
+ # Extract dx as a scalar
350
+ if self._dim > 1:
351
+ dx = self._diag._dx[axis - 1]
352
+ else:
353
+ # For 1D, _dx might be a list with one element or a scalar
354
+ dx = self._diag._dx[0] if isinstance(self._diag._dx, (list, tuple, np.ndarray)) else self._diag._dx
355
+ # Ensure dx is a scalar float
356
+ if isinstance(dx, (list, tuple, np.ndarray)):
357
+ dx = float(dx) if np.isscalar(dx) else float(dx[0])
358
+ # Use vectorized helper function for massive speedup
359
+ result = self._compute_fourth_order_spatial(self._data, dx, axis)
360
+ else:
361
+ raise ValueError("Order 4 is only implemented for spatial derivatives 'x1', 'x2' and 'x3'.")
197
362
 
198
363
  # Store the result in the cache
199
364
  self._all_loaded = True
200
365
  self._data = result
201
366
  return self._data
202
367
 
203
- def _data_generator(self, index):
368
+ def _data_generator(self, index: int) -> Generator[np.ndarray, None, None]:
204
369
  """Generate data for a specific index on-demand"""
205
- if self._deriv_type == "x1":
206
- if self._dim == 1:
207
- yield np.gradient(self._diag[index], self._diag._dx, axis=0, edge_order=2)
370
+ if self._order == 2:
371
+ if self._deriv_type == "x1":
372
+ if self._dim == 1:
373
+ yield np.gradient(self._diag[index], self._diag._dx, axis=0, edge_order=2)
374
+ else:
375
+ yield np.gradient(self._diag[index], self._diag._dx[0], axis=0, edge_order=2)
376
+
377
+ elif self._deriv_type == "x2":
378
+ yield np.gradient(self._diag[index], self._diag._dx[1], axis=1, edge_order=2)
379
+
380
+ elif self._deriv_type == "x3":
381
+ yield np.gradient(self._diag[index], self._diag._dx[2], axis=2, edge_order=2)
382
+
383
+ elif self._deriv_type == "t":
384
+ if index == 0:
385
+ yield (-3 * self._diag[index] + 4 * self._diag[index + 1] - self._diag[index + 2]) / (
386
+ 2 * self._diag._dt * self._diag._ndump
387
+ )
388
+ elif index == self._diag._maxiter - 1:
389
+ yield (3 * self._diag[index] - 4 * self._diag[index - 1] + self._diag[index - 2]) / (
390
+ 2 * self._diag._dt * self._diag._ndump
391
+ )
392
+ else:
393
+ yield (self._diag[index + 1] - self._diag[index - 1]) / (2 * self._diag._dt * self._diag._ndump)
208
394
  else:
209
- yield np.gradient(self._diag[index], self._diag._dx[0], axis=0, edge_order=2)
210
-
211
- elif self._deriv_type == "x2":
212
- yield np.gradient(self._diag[index], self._diag._dx[1], axis=1, edge_order=2)
213
-
214
- elif self._deriv_type == "x3":
215
- yield np.gradient(self._diag[index], self._diag._dx[2], axis=2, edge_order=2)
216
-
217
- elif self._deriv_type == "t":
218
- if index == 0:
219
- yield (-3 * self._diag[index] + 4 * self._diag[index + 1] - self._diag[index + 2]) / (
220
- 2 * self._diag._dt * self._diag._ndump
221
- )
222
- elif index == self._diag._maxiter - 1:
223
- yield (3 * self._diag[index] - 4 * self._diag[index - 1] + self._diag[index - 2]) / (2 * self._diag._dt * self._diag._ndump)
395
+ raise ValueError("Invalid derivative type. Use 'x1', 'x2', 'x3' or 't'.")
396
+
397
+ elif self._order == 4:
398
+ if self._deriv_type in ["x1", "x2", "x3"]:
399
+ # Use vectorized helper function
400
+ data = self._diag[index]
401
+ axis_map = {"x1": 0, "x2": 1, "x3": 2}
402
+ axis = axis_map[self._deriv_type]
403
+
404
+ if self._deriv_type == "x1":
405
+ dx = self._diag._dx if self._dim == 1 else self._diag._dx[0]
406
+ elif self._deriv_type == "x2":
407
+ dx = self._diag._dx[1]
408
+ else: # x3
409
+ dx = self._diag._dx[2]
410
+
411
+ yield self._compute_fourth_order_spatial(data, dx, axis)
412
+
413
+ elif self._deriv_type == "t":
414
+ idx = index
415
+ # Fourth-order time derivative
416
+ if idx < 2:
417
+ # Forward difference for first two points
418
+ yield (-3 * self._diag[idx] + 4 * self._diag[idx + 1] - self._diag[idx + 2]) / (2 * self._diag._dt * self._diag._ndump)
419
+ elif idx >= self._diag._maxiter - 2:
420
+ # Backward difference for last two points
421
+ yield (3 * self._diag[idx] - 4 * self._diag[idx - 1] + self._diag[idx - 2]) / (2 * self._diag._dt * self._diag._ndump)
422
+ else:
423
+ # Fourth-order central: (-f[i+2] + 8*f[i+1] - 8*f[i-1] + f[i-2]) / (12*h)
424
+ yield (-self._diag[idx + 2] + 8 * self._diag[idx + 1] - 8 * self._diag[idx - 1] + self._diag[idx - 2]) / (
425
+ 12 * self._diag._dt * self._diag._ndump
426
+ )
224
427
  else:
225
- yield (self._diag[index + 1] - self._diag[index - 1]) / (2 * self._diag._dt * self._diag._ndump)
226
- else:
227
- raise ValueError("Invalid derivative type. Use 'x1', 'x2', 'x3' or 't'.")
428
+ raise ValueError("Invalid derivative type. Use 'x1', 'x2', 'x3' or 't'.")
228
429
 
229
- def __getitem__(self, index):
430
+ def __getitem__(self, index: int | slice) -> np.ndarray:
230
431
  """Get data at a specific index"""
231
432
  if self._all_loaded and self._data is not None:
232
433
  return self._data[index]
233
434
 
234
435
  if isinstance(index, int):
235
436
  return next(self._data_generator(index))
236
- elif isinstance(index, slice):
437
+ if isinstance(index, slice):
237
438
  start = 0 if index.start is None else index.start
238
439
  step = 1 if index.step is None else index.step
239
440
  stop = self._diag._maxiter if index.stop is None else index.stop
240
- return np.array([next(self._data_generator(i)) for i in range(start, stop, step)])
241
- else:
242
- raise ValueError("Invalid index type. Use int or slice.")
441
+
442
+ # Pre-allocate array for better performance
443
+ indices = range(start, stop, step)
444
+ if len(indices) > 0:
445
+ first_result = next(self._data_generator(indices[0]))
446
+ result = np.empty((len(indices),) + first_result.shape, dtype=first_result.dtype)
447
+ result[0] = first_result
448
+ for i, idx in enumerate(indices[1:], start=1):
449
+ result[i] = next(self._data_generator(idx))
450
+ return result
451
+ return np.array([])
452
+ raise ValueError("Invalid index type. Use int or slice.")
243
453
 
244
454
 
245
455
  class Derivative_Species_Handler:
246
- """
247
- Class to handle derivatives for a species.
456
+ """Class to handle derivatives for a species.
248
457
  Acts as a wrapper for the Derivative_Diagnostic class.
249
458
 
250
459
  Not intended to be used directly, but through the Derivative_Simulation class.
@@ -257,16 +466,18 @@ class Derivative_Species_Handler:
257
466
  The type of derivative to compute. Options are: 't', 'x1', 'x2', 'x3', 'xx', 'xt' and 'tx'.
258
467
  axis : int or tuple
259
468
  The axis to compute the derivative. Only used for 'xx', 'xt' and 'tx' types.
469
+
260
470
  """
261
471
 
262
- def __init__(self, species_handler, deriv_type, axis=None):
472
+ def __init__(self, species_handler: Any, deriv_type: str, axis: int | tuple[int, int] | None = None, order: int = 2) -> None:
263
473
  self._species_handler = species_handler
264
474
  self._deriv_type = deriv_type
265
475
  self._axis = axis
266
- self._derivatives_computed = {}
476
+ self._order = order
477
+ self._derivatives_computed: dict[Any, Derivative_Diagnostic] = {}
267
478
 
268
- def __getitem__(self, key):
479
+ def __getitem__(self, key: Any) -> Derivative_Diagnostic:
269
480
  if key not in self._derivatives_computed:
270
481
  diag = self._species_handler[key]
271
- self._derivatives_computed[key] = Derivative_Diagnostic(diag, self._deriv_type, self._axis)
482
+ self._derivatives_computed[key] = Derivative_Diagnostic(diag, self._deriv_type, self._axis, self._order)
272
483
  return self._derivatives_computed[key]