osiris-utils 1.1.2__py3-none-any.whl → 1.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,240 @@
1
+ from ..utils import *
2
+ from ..data.simulation import Simulation
3
+ from .postprocess import PostProcess
4
+ from ..data.diagnostic import Diagnostic
5
+ import numpy as np
6
+ import tqdm as tqdm
7
+
8
+
9
class FastFourierTransform_Simulation(PostProcess):
    """
    Class to handle the Fast Fourier Transform on data. Works as a wrapper for the FFT_Diagnostic class.
    Inherits from PostProcess to ensure all operation overloads work properly.

    Results are cached: diagnostics in one dictionary, per-species handlers in
    another, so repeated lookups with the same key return the same object.

    Parameters
    ----------
    simulation : Simulation
        The simulation object.
    fft_axis : int or list/tuple of int
        The axis (or axes) to compute the FFT along. Passed unchanged to
        FFT_Diagnostic / FFT_Species_Handler.

    Raises
    ------
    ValueError
        If `simulation` is not a Simulation instance.

    Example
    -------
    >>> sim = Simulation('electrons', 'path/to/simulation')
    >>> fft = FastFourierTransform_Simulation(sim, 1)
    >>> fft_e1 = fft['e1']
    """

    def __init__(self, simulation, fft_axis):
        super().__init__("FFT")
        if not isinstance(simulation, Simulation):
            raise ValueError("Simulation must be a Simulation object.")
        self._simulation = simulation
        self._fft_axis = fft_axis
        self._fft_computed = {}      # key -> FFT_Diagnostic
        self._species_handler = {}   # species name -> FFT_Species_Handler

    def __getitem__(self, key):
        """Return the (cached) FFT wrapper for a species or a diagnostic key."""
        # Species keys get a handler that builds FFT diagnostics lazily.
        if key in self._simulation._species:
            if key not in self._species_handler:
                self._species_handler[key] = FFT_Species_Handler(self._simulation[key], self._fft_axis)
            return self._species_handler[key]

        if key not in self._fft_computed:
            self._fft_computed[key] = FFT_Diagnostic(self._simulation[key], self._fft_axis)
        return self._fft_computed[key]

    def delete_all(self):
        """Drop every cached FFT, including the per-species handlers."""
        self._fft_computed = {}
        # Species handlers hold their own FFT caches; release them as well,
        # otherwise their data would outlive a "delete all" request.
        self._species_handler = {}

    def delete(self, key):
        """Drop the cached FFT (or species handler) stored under `key`."""
        if key in self._fft_computed:
            del self._fft_computed[key]
        elif key in self._species_handler:
            del self._species_handler[key]
        else:
            print(f"FFT {key} not found in simulation")

    def process(self, diagnostic):
        """Apply FFT to a diagnostic"""
        return FFT_Diagnostic(diagnostic, self._fft_axis)
59
+
60
+
61
class FFT_Diagnostic(Diagnostic):
    """
    Auxiliar class to compute the FFT of a diagnostic, for it to be similar in behavior to a Diagnostic object.
    Inherits directly from Diagnostic to ensure all operation overloads work properly.

    The stored result is the power spectrum ``abs(FFT)**2``, shifted so that the
    zero frequency sits in the middle of each transformed axis. A Hanning window
    is applied along every transformed axis before the transform.

    Axis convention: in the fully loaded data, axis 0 is time and axes 1, 2, ...
    are spatial. Data for a single timestep has no time axis, so spatial axis
    ``n`` becomes axis ``n - 1`` there.

    Parameters
    ----------
    diagnostic : Diagnostic
        The diagnostic to compute the FFT.
    fft_axis : int or list/tuple of int
        The axis (or axes) to compute the FFT.

    Methods
    -------
    load_all()
        Load all the data and compute the FFT.
    omega()
        Get the shifted frequency array for the FFT.
    __getitem__(index)
        Get data at a specific index.

    Example
    -------
    >>> sim = Simulation('electrons', 'path/to/simulation')
    >>> diag = sim['e1']
    >>> fft = FFT_Diagnostic(diag, 1)
    """

    def __init__(self, diagnostic, fft_axis):
        if hasattr(diagnostic, '_species'):
            super().__init__(simulation_folder=getattr(diagnostic, '_simulation_folder', None),
                             species=diagnostic._species)
        else:
            super().__init__(None)

        self._name = f"FFT[{diagnostic._name}, {fft_axis}]"
        self._diag = diagnostic
        self._fft_axis = fft_axis
        self._data = None
        self._all_loaded = False

        # Copy all relevant attributes from diagnostic
        for attr in ['_dt', '_dx', '_ndump', '_axis', '_nx', '_x', '_grid', '_dim', '_maxiter']:
            if hasattr(diagnostic, attr):
                setattr(self, attr, getattr(diagnostic, attr))

        if isinstance(self._dx, (int, float)):
            self._kmax = np.pi / (self._dx)
        else:
            # Iterate over the normalised axis tuple so that a single int axis
            # works too (iterating the raw int raised TypeError).
            self._kmax = np.pi / np.array(
                [self._dx[self._to_spatial_axis(ax)] for ax in self._axes_as_tuple() if ax != 0]
            )

    def _axes_as_tuple(self):
        """Return the FFT axes as a tuple, whether one axis or several were given."""
        if isinstance(self._fft_axis, (list, tuple)):
            return tuple(self._fft_axis)
        return (self._fft_axis,)

    @staticmethod
    def _to_spatial_axis(axis):
        """Map an axis of the full (time-first) data to the per-timestep layout."""
        # Positive axes shift down by one because the time axis is gone;
        # negative axes count from the end and are unaffected.
        return axis - 1 if axis > 0 else axis

    def load_all(self):
        """
        Load the whole diagnostic and compute its windowed power spectrum.

        Returns
        -------
        numpy.ndarray
            ``abs(FFT)**2`` of the data, fft-shifted along the FFT axes. The
            result is cached and returned directly on later calls.
        """
        if self._data is not None:
            print("Using cached FFT")
            return self._data

        if not hasattr(self._diag, '_data') or self._diag._data is None:
            self._diag.load_all()
        # nan_to_num returns a new array: the source diagnostic keeps its own
        # data untouched instead of having its NaNs silently replaced.
        data = np.nan_to_num(self._diag._data)

        axes = self._axes_as_tuple()

        # Window every transformed axis (time included) to reduce leakage.
        for axis in axes:
            window = self._get_window(data.shape[axis], axis)
            data = self._apply_window(data, window, axis)

        with tqdm.tqdm(total=1, desc="FFT calculation") as pbar:
            data_fft = np.fft.fftn(data, axes=axes)
            pbar.update(0.5)
            result = np.fft.fftshift(data_fft, axes=axes)
            pbar.update(0.5)

        # Nyquist angular frequency for samples spaced dt * ndump apart.
        self.omega_max = np.pi / self._dt / self._ndump

        self._all_loaded = True
        self._data = np.abs(result)**2
        return self._data

    def _data_generator(self, index):
        """
        Yield the spatial power spectrum of the single timestep `index`.

        Raises
        ------
        ValueError
            If no spatial axis was requested: a time-axis FFT needs the whole
            time series, which a single timestep cannot provide.
        """
        # The time axis (0) is skipped; remaining axes are mapped to the
        # per-timestep layout.
        spatial_axes = [self._to_spatial_axis(ax) for ax in self._axes_as_tuple() if ax != 0]
        if not spatial_axes:
            raise ValueError("Cannot generate FFT along time axis for a single timestep. Use load_all() instead.")

        result = self._diag[index]
        for axis in spatial_axes:
            window = self._get_window(result.shape[axis], axis)
            result = self._apply_window(result, window, axis)

        result_fft = np.fft.fftn(result, axes=spatial_axes)
        result_fft = np.fft.fftshift(result_fft, axes=spatial_axes)

        yield np.abs(result_fft)**2

    def _get_window(self, length, axis):
        """Return the window of `length` samples used along `axis` (Hanning)."""
        return np.hanning(length)

    def _apply_window(self, data, window, axis):
        """Multiply `data` by the 1-D `window` broadcast along `axis`."""
        window_shape = [1] * data.ndim
        window_shape[axis] = len(window)
        return data * window.reshape(window_shape)

    def __getitem__(self, index):
        """
        Get the power spectrum at a timestep (int) or range of timesteps (slice).

        Once load_all() has run, this indexes the cached array directly.

        Raises
        ------
        ValueError
            If `index` is neither an integer nor a slice.
        """
        if self._all_loaded and self._data is not None:
            return self._data[index]

        # Accept numpy integers too (e.g. indices coming from np.arange).
        if isinstance(index, (int, np.integer)):
            return next(self._data_generator(index))
        elif isinstance(index, slice):
            start = 0 if index.start is None else index.start
            step = 1 if index.step is None else index.step
            stop = self._diag._maxiter if index.stop is None else index.stop
            return np.array([next(self._data_generator(i)) for i in range(start, stop, step)])
        else:
            raise ValueError("Invalid index type. Use int or slice.")

    def omega(self):
        """
        Get the shifted frequency array for the FFT.

        When the FFT includes the time axis (0), frequencies are computed for
        that axis with sample spacing ``dt * ndump``. Otherwise a single
        spatial FFT axis is required and its grid spacing is used.

        Returns
        -------
        numpy.ndarray
            ``np.fft.fftshift(np.fft.fftfreq(n, d=spacing))``.
            NOTE(review): these are cycles per unit as returned by fftfreq;
            multiply by 2*pi for an angular frequency comparable to omega_max.

        Raises
        ------
        ValueError
            If load_all() has not been called, or if several spatial axes were
            transformed without the time axis.
        """
        if not self._all_loaded:
            raise ValueError("Load the data first using load_all() method.")

        axes = self._axes_as_tuple()
        if 0 in axes:
            # Time axis: the spacing between dumps, not a spatial cell size.
            axis = 0
            spacing = self._dt * self._ndump
        elif len(axes) == 1:
            axis = axes[0]
            if isinstance(self._dx, (int, float)):
                spacing = self._dx
            else:
                spacing = self._dx[self._to_spatial_axis(axis)]
        else:
            raise ValueError("omega() needs a single FFT axis or an FFT that includes the time axis (0).")

        omega = np.fft.fftfreq(self._data.shape[axis], d=spacing)
        omega = np.fft.fftshift(omega)
        return omega

    @property
    def kmax(self):
        """pi / dx for the spatial FFT axes, as computed at construction."""
        return self._kmax
229
+
230
class FFT_Species_Handler:
    """
    Lazily builds and caches FFT diagnostics for one species.

    Parameters
    ----------
    species_handler : object
        Indexable by diagnostic key; each lookup returns the diagnostic to
        transform.
    fft_axis : int or list/tuple of int
        Forwarded unchanged to FFT_Diagnostic.
    """

    def __init__(self, species_handler, fft_axis):
        self._fft_computed = {}
        self._fft_axis = fft_axis
        self._species_handler = species_handler

    def __getitem__(self, key):
        """Return the FFT diagnostic for `key`, creating it on first access."""
        cache = self._fft_computed
        if key in cache:
            return cache[key]
        source = self._species_handler[key]
        cache[key] = FFT_Diagnostic(source, self._fft_axis)
        return cache[key]
@@ -0,0 +1,348 @@
1
+ from ..utils import *
2
+ from ..data.simulation import Simulation
3
+ from .postprocess import PostProcess
4
+ from ..data.diagnostic import Diagnostic
5
+
6
class MeanFieldTheory_Simulation(PostProcess):
    """
    Class to compute the mean field theory of a diagnostic. Works as a wrapper for the MFT_Diagnostic class.
    Inherits from PostProcess to ensure all operation overloads work properly.

    Results are cached: diagnostics in one dictionary, per-species handlers in
    another, so repeated lookups with the same key return the same object.

    Parameters
    ----------
    simulation : Simulation
        The simulation object.
    mft_axis : int
        The axis to compute the mean field theory. Passed unchanged to
        MFT_Diagnostic / MFT_Species_Handler.

    Raises
    ------
    ValueError
        If `simulation` is not a Simulation instance.

    Example
    -------
    >>> sim = Simulation('electrons', 'path/to/simulation')
    >>> mft = MeanFieldTheory_Simulation(sim, 1)
    >>> mft_e1 = mft['e1']
    """

    def __init__(self, simulation, mft_axis=None):
        super().__init__(f"MeanFieldTheory({mft_axis})")
        if not isinstance(simulation, Simulation):
            raise ValueError("Simulation must be a Simulation object.")
        self._simulation = simulation
        self._mft_axis = mft_axis
        self._mft_computed = {}      # key -> MFT_Diagnostic
        self._species_handler = {}   # species name -> MFT_Species_Handler

    def __getitem__(self, key):
        """Return the (cached) MFT wrapper for a species or a diagnostic key."""
        # Species keys get a handler that builds MFT diagnostics lazily.
        if key in self._simulation._species:
            if key not in self._species_handler:
                self._species_handler[key] = MFT_Species_Handler(self._simulation[key], self._mft_axis)
            return self._species_handler[key]
        if key not in self._mft_computed:
            self._mft_computed[key] = MFT_Diagnostic(self._simulation[key], self._mft_axis)
        return self._mft_computed[key]

    def delete_all(self):
        """Drop every cached result, including the per-species handlers."""
        self._mft_computed = {}
        # Species handlers hold their own caches; release them as well,
        # otherwise their data would outlive a "delete all" request.
        self._species_handler = {}

    def delete(self, key):
        """Drop the cached result (or species handler) stored under `key`."""
        if key in self._mft_computed:
            del self._mft_computed[key]
        elif key in self._species_handler:
            del self._species_handler[key]
        else:
            print(f"MeanFieldTheory {key} not found in simulation")

    def process(self, diagnostic):
        """Apply mean field theory to a diagnostic"""
        return MFT_Diagnostic(diagnostic, self._mft_axis)
55
+
56
class MFT_Diagnostic(Diagnostic):
    """
    Container for the mean field decomposition of a diagnostic.

    The two parts are exposed by key: ``"avg"`` (an MFT_Diagnostic_Average)
    and ``"delta"`` (an MFT_Diagnostic_Fluctuations). Each part is created the
    first time it is requested and reused afterwards.

    Parameters
    ----------
    diagnostic : Diagnostic
        The diagnostic to decompose.
    mft_axis : int
        The axis handed to both components.

    Example
    -------
    >>> sim = Simulation('electrons', 'path/to/simulation')
    >>> diag = sim['e1']
    >>> mft = MFT_Diagnostic(diag, 1)
    >>> avg = mft['avg']
    >>> delta = mft['delta']
    """

    def __init__(self, diagnostic, mft_axis):
        # Reuse the wrapped diagnostic's species when it has one.
        if not hasattr(diagnostic, '_species'):
            super().__init__(None)
        else:
            super().__init__(species=diagnostic._species)

        self._name = f"MFT[{diagnostic._name}]"
        self._diag = diagnostic
        self._mft_axis = mft_axis
        self._data = None
        self._all_loaded = False

        # Filled on demand with the "avg" / "delta" component objects.
        self._components = {}

        # Mirror grid/time metadata from the wrapped diagnostic when present.
        inherited = ('_dt', '_dx', '_ndump', '_axis', '_nx', '_x', '_grid', '_dim', '_maxiter')
        for attr_name in inherited:
            if not hasattr(diagnostic, attr_name):
                continue
            setattr(self, attr_name, getattr(diagnostic, attr_name))

    def __getitem__(self, key):
        """
        Return one part of the decomposition.

        Parameters
        ----------
        key : str
            "avg" for the average, "delta" for the fluctuations.

        Returns
        -------
        Diagnostic
            The requested component, built on first access.

        Raises
        ------
        ValueError
            For any other key.
        """
        if key == "avg":
            name, factory = "avg", MFT_Diagnostic_Average
        elif key == "delta":
            name, factory = "delta", MFT_Diagnostic_Fluctuations
        else:
            raise ValueError("Invalid MFT component. Use 'avg' or 'delta'.")

        if name not in self._components:
            self._components[name] = factory(self._diag, self._mft_axis)
        return self._components[name]

    def load_all(self):
        """Create (if needed) and load both components; return the component dict."""
        parts = (
            ("avg", MFT_Diagnostic_Average),
            ("delta", MFT_Diagnostic_Fluctuations),
        )
        # Build any missing component first, average before fluctuations.
        for name, factory in parts:
            if name not in self._components:
                self._components[name] = factory(self._diag, self._mft_axis)

        # Then load them in the same order.
        for name, _ in parts:
            self._components[name].load_all()

        self._all_loaded = True
        return self._components
142
+
143
class MFT_Diagnostic_Average(Diagnostic):
    """
    Class to compute the average component of mean field theory.
    Inherits from Diagnostic to ensure all operation overloads work properly.

    Axis convention: `mft_axis` refers to the fully loaded data, where axis 0
    is time. Data for a single timestep has no time axis, so a positive
    `mft_axis` maps to ``mft_axis - 1`` there.

    Parameters
    ----------
    diagnostic : Diagnostic
        The diagnostic object.
    mft_axis : int
        The axis to compute the mean field theory.

    Raises
    ------
    ValueError
        If `mft_axis` is None.

    Example
    -------
    >>> sim = Simulation('electrons', 'path/to/simulation')
    >>> diag = sim['e1']
    >>> avg = MFT_Diagnostic_Average(diag, 1)
    """

    def __init__(self, diagnostic, mft_axis):
        # Initialize with the same species as the diagnostic
        if hasattr(diagnostic, '_species'):
            super().__init__(species=diagnostic._species)
        else:
            super().__init__(None)

        if mft_axis is None:
            raise ValueError("Mean field theory axis must be specified.")

        self._name = f"MFT_avg[{diagnostic._name}, {mft_axis}]"
        self._diag = diagnostic
        self._mft_axis = mft_axis
        self._data = None
        self._all_loaded = False

        # Copy all relevant attributes from diagnostic
        for attr in ['_dt', '_dx', '_ndump', '_axis', '_nx', '_x', '_grid', '_dim', '_maxiter']:
            if hasattr(diagnostic, attr):
                setattr(self, attr, getattr(diagnostic, attr))

    def load_all(self):
        """
        Load all data and compute the average along `mft_axis`.

        Returns
        -------
        numpy.ndarray
            The mean over `mft_axis` with a trailing length-1 axis appended.
            The result is cached and returned directly on later calls.
        """
        if self._data is not None:
            print("Data already loaded")
            return self._data

        if not hasattr(self._diag, '_data') or self._diag._data is None:
            self._diag.load_all()

        if self._mft_axis is None:
            raise ValueError("Mean field theory axis must be specified.")

        # NOTE(review): the length-1 axis is appended at the END, not at
        # mft_axis. For a non-last mft_axis this differs from a keepdims shape
        # (the one MFT_Diagnostic_Fluctuations uses) - confirm this is intended.
        self._data = np.expand_dims(self._diag._data.mean(axis=self._mft_axis), axis=-1)

        self._all_loaded = True
        return self._data

    def _data_generator(self, index):
        """
        Generate average data for a specific index.

        Raises
        ------
        ValueError
            If the axis is None, or is the time axis (0): a time average needs
            the whole series, which a single timestep cannot provide.
        """
        if self._mft_axis is None:
            raise ValueError("Invalid axis for mean field theory.")
        if self._mft_axis == 0:
            # Previously this fell through to axis -1 and silently averaged
            # over the last spatial axis instead.
            raise ValueError("Cannot compute the average along the time axis for a single timestep. Use load_all() instead.")

        # Positive axes shift down by one (no time axis in a single timestep);
        # negative axes count from the end and stay as they are.
        axis = self._mft_axis - 1 if self._mft_axis > 0 else self._mft_axis

        data = self._diag[index]
        yield np.expand_dims(data.mean(axis=axis), axis=-1)

    def __getitem__(self, index):
        """
        Get average at a specific index (int) or range of indices (slice).

        Once load_all() has run, this indexes the cached array directly.

        Raises
        ------
        ValueError
            If `index` is neither an integer nor a slice.
        """
        if self._all_loaded and self._data is not None:
            return self._data[index]

        # Otherwise compute on-demand; numpy integers are accepted as well.
        if isinstance(index, (int, np.integer)):
            return next(self._data_generator(index))
        elif isinstance(index, slice):
            start = 0 if index.start is None else index.start
            step = 1 if index.step is None else index.step
            stop = self._diag._maxiter if index.stop is None else index.stop
            return np.array([next(self._data_generator(i)) for i in range(start, stop, step)])
        else:
            raise ValueError("Invalid index type. Use int or slice.")
227
+
228
class MFT_Diagnostic_Fluctuations(Diagnostic):
    """
    Class to compute the fluctuation component of mean field theory.
    Inherits from Diagnostic to ensure all operation overloads work properly.

    The fluctuation is the data minus its mean along `mft_axis`.

    Axis convention: `mft_axis` refers to the fully loaded data, where axis 0
    is time. Data for a single timestep has no time axis, so a positive
    `mft_axis` maps to ``mft_axis - 1`` there.

    Parameters
    ----------
    diagnostic : Diagnostic
        The diagnostic object.
    mft_axis : int
        The axis to compute the mean field theory.

    Raises
    ------
    ValueError
        If `mft_axis` is None.

    Example
    -------
    >>> sim = Simulation('electrons', 'path/to/simulation')
    >>> diag = sim['e1']
    >>> delta = MFT_Diagnostic_Fluctuations(diag, 1)
    """

    def __init__(self, diagnostic, mft_axis):
        # Initialize with the same species as the diagnostic
        if hasattr(diagnostic, '_species'):
            super().__init__(species=diagnostic._species)
        else:
            super().__init__(None)

        if mft_axis is None:
            raise ValueError("Mean field theory axis must be specified.")

        self._name = f"MFT_delta[{diagnostic._name}, {mft_axis}]"
        self._diag = diagnostic
        self._mft_axis = mft_axis
        self._data = None
        self._all_loaded = False

        # Copy all relevant attributes from diagnostic
        for attr in ['_dt', '_dx', '_ndump', '_axis', '_nx', '_x', '_grid', '_dim', '_maxiter']:
            if hasattr(diagnostic, attr):
                setattr(self, attr, getattr(diagnostic, attr))

    def load_all(self):
        """
        Load all data and compute the fluctuations along `mft_axis`.

        Returns
        -------
        numpy.ndarray
            Data minus its mean over `mft_axis`, same shape as the data. The
            result is cached and returned directly on later calls.
        """
        if self._data is not None:
            print("Data already loaded")
            return self._data

        if not hasattr(self._diag, '_data') or self._diag._data is None:
            self._diag.load_all()

        if self._mft_axis is None:
            raise ValueError("Mean field theory axis must be specified.")

        # keepdims leaves a length-1 axis in place of mft_axis, so the mean
        # broadcasts against the data without a manual reshape.
        full = self._diag._data
        self._data = full - full.mean(axis=self._mft_axis, keepdims=True)

        self._all_loaded = True
        return self._data

    def _data_generator(self, index):
        """
        Generate fluctuation data for a specific index.

        Raises
        ------
        ValueError
            If the axis is None, or is the time axis (0): a time average needs
            the whole series, which a single timestep cannot provide.
        """
        if self._mft_axis is None:
            raise ValueError("Invalid axis for mean field theory.")
        if self._mft_axis == 0:
            # Previously this fell through to axis -1 and silently used the
            # last spatial axis instead.
            raise ValueError("Cannot compute fluctuations along the time axis for a single timestep. Use load_all() instead.")

        # Positive axes shift down by one (no time axis in a single timestep);
        # negative axes count from the end and stay as they are.
        axis = self._mft_axis - 1 if self._mft_axis > 0 else self._mft_axis

        data = self._diag[index]
        yield data - data.mean(axis=axis, keepdims=True)

    def __getitem__(self, index):
        """
        Get fluctuations at a specific index (int) or range of indices (slice).

        Once load_all() has run, this indexes the cached array directly.

        Raises
        ------
        ValueError
            If `index` is neither an integer nor a slice.
        """
        if self._all_loaded and self._data is not None:
            return self._data[index]

        # Otherwise compute on-demand; numpy integers are accepted as well.
        if isinstance(index, (int, np.integer)):
            return next(self._data_generator(index))
        elif isinstance(index, slice):
            start = 0 if index.start is None else index.start
            step = 1 if index.step is None else index.step
            stop = self._diag._maxiter if index.stop is None else index.stop
            return np.array([next(self._data_generator(i)) for i in range(start, stop, step)])
        else:
            raise ValueError("Invalid index type. Use int or slice.")
323
+
324
class MFT_Species_Handler:
    """
    Lazily builds and caches mean field theory diagnostics for one species.

    Not meant to be instantiated directly; MeanFieldTheory_Simulation creates
    one per species.

    Parameters
    ----------
    species_handler : Species_Handler
        Indexable by diagnostic key; each lookup returns the diagnostic to
        decompose.
    mft_axis : int
        Forwarded unchanged to MFT_Diagnostic.
    """

    def __init__(self, species_handler, mft_axis):
        self._mft_computed = {}
        self._mft_axis = mft_axis
        self._species_handler = species_handler

    def __getitem__(self, key):
        """Return the MFT diagnostic for `key`, creating it on first access."""
        cache = self._mft_computed
        if key in cache:
            return cache[key]
        source = self._species_handler[key]
        cache[key] = MFT_Diagnostic(source, self._mft_axis)
        return cache[key]
@@ -0,0 +1,52 @@
1
+ import numpy as np
2
+ from ..data.data import OsirisGridFile
3
+
4
+ # Deprecated
5
class MFT_Single(OsirisGridFile):
    '''
    Class to handle the mean field theory on data. Inherits from OsirisGridFile.

    On construction the data is split into its mean along `axis` (kept as a
    length-1 axis so it broadcasts against the data) and the remainder.

    Parameters
    ----------
    source : str or OsirisGridFile
        The filename or an OsirisGridFile object.
    axis : int
        The axis to average over.
    '''
    def __init__(self, source, axis=1):
        if isinstance(source, OsirisGridFile):
            # Adopt the already-loaded file's attributes instead of re-reading.
            self.__dict__.update(source.__dict__)
        else:
            super().__init__(source)
        self._compute_mean_field(axis=axis)

    def _compute_mean_field(self, axis=1):
        '''Store the mean along `axis` and the fluctuations around it.'''
        self._average = np.mean(self.data, axis=axis, keepdims=True)
        self._fluctuations = self.data - self._average

    def __array__(self, dtype=None, copy=None):
        '''
        NumPy conversion hook: expose the underlying data array.

        `dtype` and `copy` are the optional arguments NumPy may pass to
        ``__array__``; both default to the previous behaviour of returning
        ``self.data`` as is.
        '''
        arr = np.asarray(self.data, dtype=dtype)
        if copy and arr is self.data:
            arr = arr.copy()
        return arr

    @property
    def average(self):
        '''Mean of the data along the averaging axis (length-1 axis kept).'''
        return self._average

    @property
    def delta(self):
        '''Data minus its average.'''
        return self._fluctuations

    def __str__(self):
        return super().__str__() + f'\nAverage: {self.average.shape}\nDelta: {self.delta.shape}'

    def derivative(self, field, axis=0):
        '''
        Compute the derivative of the average or the fluctuations.

        Parameters
        ----------
        field : MFT_Single.average or MFT_Single.delta
            The field to compute the derivative.
        axis : int
            The axis to compute the derivative.

        Returns
        -------
        numpy.ndarray
            ``np.gradient`` of `field` along `axis`, using ``self.dx[axis]``
            as the sample spacing.
        '''
        # Differentiate along the requested axis; it was previously fixed to
        # axis 0 while the spacing was taken from dx[axis].
        return np.gradient(field, self.dx[axis], axis=axis)
@@ -0,0 +1,42 @@
1
+ from ..data.diagnostic import Diagnostic
2
+
3
class PostProcess(Diagnostic):
    """
    Common parent of the post-processing wrappers.

    Deriving from Diagnostic lets post-processing objects take part in the
    same operator overloads as ordinary diagnostics.

    Parameters
    ----------
    name : str
        Label stored as the object's name.
    species : str, optional
        Species forwarded to the Diagnostic constructor.
    """

    def __init__(self, name, species=None):
        # Diagnostic sets up its own state first; the fields below are then
        # (re)assigned for the post-processing object.
        super().__init__(species)
        self._data = None
        self._all_loaded = False
        self._name = name

    def process(self, diagnostic):
        """
        Run this post-processing step on a diagnostic.

        Concrete subclasses override this; the base version only raises.

        Parameters
        ----------
        diagnostic : Diagnostic
            The diagnostic to transform.

        Returns
        -------
        Diagnostic or PostProcess
            The transformed diagnostic (in subclasses).

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement process method")
39
+
40
+
41
+ # PostProcessing_Simulation
42
+ # PostProcessing_Diagnostic