osiris-utils 1.1.2__py3-none-any.whl → 1.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- osiris_utils/__init__.py +22 -9
- osiris_utils/data/__init__.py +0 -0
- osiris_utils/data/data.py +418 -0
- osiris_utils/data/diagnostic.py +979 -0
- osiris_utils/data/simulation.py +203 -0
- osiris_utils/decks/__init__.py +0 -0
- osiris_utils/decks/decks.py +288 -0
- osiris_utils/decks/species.py +55 -0
- osiris_utils/gui/__init__.py +0 -0
- osiris_utils/gui/gui.py +266 -0
- osiris_utils/postprocessing/__init__.py +0 -0
- osiris_utils/postprocessing/derivative.py +243 -0
- osiris_utils/postprocessing/fft.py +240 -0
- osiris_utils/postprocessing/mft.py +348 -0
- osiris_utils/postprocessing/mft_for_gridfile.py +52 -0
- osiris_utils/postprocessing/postprocess.py +42 -0
- osiris_utils/utils.py +1 -40
- {osiris_utils-1.1.2.dist-info → osiris_utils-1.1.4.dist-info}/METADATA +20 -2
- osiris_utils-1.1.4.dist-info/RECORD +22 -0
- {osiris_utils-1.1.2.dist-info → osiris_utils-1.1.4.dist-info}/WHEEL +1 -1
- osiris_utils-1.1.2.dist-info/RECORD +0 -7
- {osiris_utils-1.1.2.dist-info → osiris_utils-1.1.4.dist-info}/licenses/LICENSE.txt +0 -0
- {osiris_utils-1.1.2.dist-info → osiris_utils-1.1.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
from ..utils import *
|
|
2
|
+
from ..data.simulation import Simulation
|
|
3
|
+
from .postprocess import PostProcess
|
|
4
|
+
from ..data.diagnostic import Diagnostic
|
|
5
|
+
import numpy as np
|
|
6
|
+
import tqdm as tqdm
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class FastFourierTransform_Simulation(PostProcess):
    """
    Compute the Fast Fourier Transform of the diagnostics of a simulation.

    Works as a lazy, caching wrapper for the FFT_Diagnostic class.
    - Indexing with a diagnostic name returns an FFT_Diagnostic.
    - Indexing with a species name returns an FFT_Species_Handler, which does the same for
      that species' diagnostics.

    Inherits from PostProcess to ensure all operation overloads work properly.

    Parameters
    ----------
    simulation : Simulation
        The simulation object.
    fft_axis : int or list/tuple of int
        The axis (or axes) along which to compute the FFT. Axis 0 is time and
        spatial axes are numbered from 1.

    Raises
    ------
    ValueError
        If ``simulation`` is not a Simulation object.

    Example
    -------
    >>> sim = Simulation('electrons', 'path/to/simulation')
    >>> fft = FastFourierTransform_Simulation(sim, 1)
    >>> fft_e1 = fft['e1']
    """

    def __init__(self, simulation, fft_axis):
        super().__init__("FFT")
        if not isinstance(simulation, Simulation):
            raise ValueError("Simulation must be a Simulation object.")
        self._simulation = simulation
        self._fft_axis = fft_axis
        # FFT_Diagnostic objects, keyed by diagnostic name.
        self._fft_computed = {}
        # FFT_Species_Handler objects, keyed by species name.
        self._species_handler = {}

    def __getitem__(self, key):
        """Return the cached FFT wrapper for a species name or a diagnostic name."""
        if key in self._simulation._species:
            if key not in self._species_handler:
                self._species_handler[key] = FFT_Species_Handler(self._simulation[key], self._fft_axis)
            return self._species_handler[key]

        if key not in self._fft_computed:
            self._fft_computed[key] = FFT_Diagnostic(self._simulation[key], self._fft_axis)
        return self._fft_computed[key]

    def delete_all(self):
        """Drop every cached FFT, including those cached inside the species handlers."""
        self._fft_computed = {}
        # Species handlers hold their own FFT caches; drop them too so nothing survives.
        self._species_handler = {}

    def delete(self, key):
        """
        Drop the cached FFT for a diagnostic name, or the cached handler for a species name.

        Prints a message if nothing is cached under ``key``.
        """
        if key in self._fft_computed:
            del self._fft_computed[key]
        elif key in self._species_handler:
            del self._species_handler[key]
        else:
            print(f"FFT {key} not found in simulation")

    def process(self, diagnostic):
        """Apply FFT to a diagnostic"""
        return FFT_Diagnostic(diagnostic, self._fft_axis)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class FFT_Diagnostic(Diagnostic):
    """
    Auxiliary class to compute the FFT of a diagnostic, so that it behaves like a Diagnostic object.
    Inherits directly from Diagnostic to ensure all operation overloads work properly.

    Each transformed axis goes through three steps:
    - It is tapered with a Hanning window.
    - It is Fourier transformed.
    - It is ``fftshift``-ed, so the zero frequency sits in the middle.

    The stored and returned values are the power spectrum ``abs(FFT) ** 2``.

    Parameters
    ----------
    diagnostic : Diagnostic
        The diagnostic to compute the FFT of.
    fft_axis : int or list/tuple of int
        The axis (or axes) along which to compute the FFT. Axis 0 is time and
        spatial axes are numbered from 1 (spatial axis ``k`` uses ``dx[k-1]``).

    Methods
    -------
    load_all()
        Load all the data and compute the FFT.
    omega()
        Get the angular frequency / wavenumber array(s) matching the FFT axes.
    __getitem__(index)
        Get data at a specific index.

    Attributes
    ----------
    kmax : float or numpy.ndarray or None
        ``pi / dx`` for the transformed spatial axes, or None if no ``dx`` is known.
    omega_max : float
        ``pi / (dt * ndump)``; set by ``load_all()``.

    Example
    -------
    >>> sim = Simulation('electrons', 'path/to/simulation')
    >>> diag = sim['e1']
    >>> fft = FFT_Diagnostic(diag, 1)
    """

    def __init__(self, diagnostic, fft_axis):
        if hasattr(diagnostic, '_species'):
            super().__init__(simulation_folder=getattr(diagnostic, '_simulation_folder', None),
                             species=diagnostic._species)
        else:
            super().__init__(None)

        self._name = f"FFT[{diagnostic._name}, {fft_axis}]"
        self._diag = diagnostic
        self._fft_axis = fft_axis
        self._data = None
        self._all_loaded = False

        # Copy grid/time metadata from the diagnostic so this object can stand in for it.
        for attr in ['_dt', '_dx', '_ndump', '_axis', '_nx', '_x', '_grid', '_dim', '_maxiter']:
            if hasattr(diagnostic, attr):
                setattr(self, attr, getattr(diagnostic, attr))

        self._kmax = self._compute_kmax()

    def _fft_axes(self):
        """Return the FFT axes as a tuple, whether a single int or a sequence was given."""
        if isinstance(self._fft_axis, (list, tuple)):
            return tuple(self._fft_axis)
        return (self._fft_axis,)

    def _compute_kmax(self):
        """
        Return ``pi / dx`` for the transformed spatial axes.

        The shape of the result depends on ``dx``:
        - Scalar ``dx``: a scalar.
        - Per-axis ``dx``: an array with one entry per transformed spatial axis, unwrapped to a
          scalar when ``fft_axis`` is a single int.
        - No ``_dx`` available: None.
        """
        dx = getattr(self, '_dx', None)
        if dx is None:
            return None
        if np.isscalar(dx):
            return np.pi / dx
        spatial_axes = [ax for ax in self._fft_axes() if ax != 0]
        kmax = np.pi / np.array([dx[ax - 1] for ax in spatial_axes], dtype=float)
        if not isinstance(self._fft_axis, (list, tuple)) and kmax.size == 1:
            return kmax[0]
        return kmax

    def load_all(self):
        """
        Load the whole diagnostic and compute its windowed, shifted power spectrum.

        The result is cached; later calls return the cached array.

        Returns
        -------
        numpy.ndarray
            ``abs(fftshift(fftn(windowed_data))) ** 2`` over the FFT axes.
        """
        if self._data is not None:
            print("Using cached FFT")
            return self._data

        if not hasattr(self._diag, '_data') or self._diag._data is None:
            self._diag.load_all()
        # Replace NaN/inf on a copy so the wrapped diagnostic's own data is left untouched.
        data = np.nan_to_num(self._diag._data)

        axes = self._fft_axes()
        # Taper every transformed axis with a Hanning window before the FFT.
        windowed = data
        for axis in axes:
            window = self._get_window(windowed.shape[axis], axis)
            windowed = self._apply_window(windowed, window, axis)

        with tqdm.tqdm(total=1, desc="FFT calculation") as pbar:
            spectrum = np.fft.fftn(windowed, axes=axes)
            pbar.update(0.5)
            spectrum = np.fft.fftshift(spectrum, axes=axes)
            pbar.update(0.5)

        self.omega_max = np.pi / self._dt / self._ndump

        self._data = np.abs(spectrum) ** 2
        self._all_loaded = True
        return self._data

    def _data_generator(self, index):
        """
        Yield the power spectrum of a single timestep along the spatial FFT axes.

        A single timestep has no time axis, so spatial axis ``k`` is data axis ``k-1``.

        Raises
        ------
        ValueError
            If the FFT axes include the time axis (0) or a negative axis.
        """
        axes = self._fft_axes()
        if 0 in axes:
            raise ValueError("Cannot generate FFT along time axis for a single timestep. Use load_all() instead.")
        if any(ax < 0 for ax in axes):
            raise ValueError("Invalid FFT axis: spatial axes are numbered from 1.")

        # Same NaN/inf cleanup as load_all() so both access paths agree.
        original_data = np.nan_to_num(self._diag[index])
        data_axes = tuple(ax - 1 for ax in axes)

        windowed = original_data
        for axis in data_axes:
            windowed = self._apply_window(windowed, self._get_window(windowed.shape[axis], axis), axis)

        spectrum = np.fft.fftn(windowed, axes=data_axes)
        spectrum = np.fft.fftshift(spectrum, axes=data_axes)
        yield np.abs(spectrum) ** 2

    def _get_window(self, length, axis):
        """Return a Hanning window of ``length`` samples (``axis`` is currently unused)."""
        return np.hanning(length)

    def _apply_window(self, data, window, axis):
        """Multiply ``data`` by a 1-D ``window`` broadcast along ``axis``."""
        window_shape = [1] * data.ndim
        window_shape[axis] = len(window)
        return data * window.reshape(window_shape)

    def __getitem__(self, index):
        """
        Return the power spectrum at ``index``.

        Behaviour depends on the state and the index type:
        - After ``load_all()``: indexes the cached array.
        - Otherwise: computes per timestep for an int (or numpy integer) or a slice.
        """
        if self._all_loaded and self._data is not None:
            return self._data[index]

        if isinstance(index, (int, np.integer)):
            return next(self._data_generator(index))
        elif isinstance(index, slice):
            start = 0 if index.start is None else index.start
            step = 1 if index.step is None else index.step
            stop = self._diag._maxiter if index.stop is None else index.stop
            return np.array([next(self._data_generator(i)) for i in range(start, stop, step)])
        else:
            raise ValueError("Invalid index type. Use int or slice.")

    def omega(self):
        """
        Get the angular frequency array for the FFT.

        Returns ``2*pi*fftfreq`` for each FFT axis, ``fftshift``-ed to match the data.
        - Axis 0 (time) uses a sample spacing of ``dt * ndump``, consistent with ``omega_max``.
        - Spatial axis ``k`` uses ``dx[k-1]``, or the scalar ``dx``.

        An int ``fft_axis`` gives a single array; a list/tuple gives a list of arrays, one per axis.

        Raises
        ------
        ValueError
            If load_all() has not been called yet.
        """
        if not self._all_loaded:
            raise ValueError("Load the data first using load_all() method.")

        frequencies = [self._angular_frequencies(axis) for axis in self._fft_axes()]
        if isinstance(self._fft_axis, (list, tuple)):
            return frequencies
        return frequencies[0]

    def _angular_frequencies(self, axis):
        """Return the shifted angular frequencies (``2*pi*fftfreq``) along one data axis."""
        if axis == 0:
            # Assumes time samples are one dump (ndump steps of dt) apart, as in omega_max.
            spacing = self._dt * self._ndump
        elif np.isscalar(self._dx):
            spacing = self._dx
        else:
            spacing = self._dx[axis - 1]
        return np.fft.fftshift(2 * np.pi * np.fft.fftfreq(self._data.shape[axis], d=spacing))

    @property
    def kmax(self):
        """``pi / dx`` for the transformed spatial axes (see class docstring)."""
        return self._kmax
|
|
229
|
+
|
|
230
|
+
class FFT_Species_Handler:
    """
    Lazily build and cache FFT_Diagnostic objects for the diagnostics of one species.

    Parameters
    ----------
    species_handler : object
        Indexable by diagnostic name (the species entry of a Simulation).
    fft_axis : int or list/tuple of int
        Forwarded unchanged to every FFT_Diagnostic this handler builds.
    """

    def __init__(self, species_handler, fft_axis):
        self._species_handler = species_handler
        self._fft_axis = fft_axis
        self._fft_computed = {}

    def __getitem__(self, key):
        cache = self._fft_computed
        try:
            return cache[key]
        except KeyError:
            pass
        # Resolve the source diagnostic first so a bad key fails with the handler's own error.
        source = self._species_handler[key]
        cache[key] = FFT_Diagnostic(source, self._fft_axis)
        return cache[key]
|
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
from ..utils import *
|
|
2
|
+
from ..data.simulation import Simulation
|
|
3
|
+
from .postprocess import PostProcess
|
|
4
|
+
from ..data.diagnostic import Diagnostic
|
|
5
|
+
|
|
6
|
+
class MeanFieldTheory_Simulation(PostProcess):
    """
    Compute the mean field theory decomposition of the diagnostics of a simulation.

    Works as a lazy, caching wrapper for the MFT_Diagnostic class.
    - Indexing with a diagnostic name returns an MFT_Diagnostic.
    - Indexing with a species name returns an MFT_Species_Handler, which does the same for
      that species' diagnostics.

    Inherits from PostProcess to ensure all operation overloads work properly.

    Parameters
    ----------
    simulation : Simulation
        The simulation object.
    mft_axis : int, optional
        The axis to compute the mean field theory along. It is numbered on the full data
        array: axis 0 is the leading (time) axis and spatial axes start at 1. It must be set
        before components are used: the 'avg'/'delta' components raise ValueError when it is None.

    Raises
    ------
    ValueError
        If ``simulation`` is not a Simulation object.

    Example
    -------
    >>> sim = Simulation('electrons', 'path/to/simulation')
    >>> mft = MeanFieldTheory_Simulation(sim, 1)
    >>> mft_e1 = mft['e1']
    """

    def __init__(self, simulation, mft_axis=None):
        super().__init__(f"MeanFieldTheory({mft_axis})")
        if not isinstance(simulation, Simulation):
            raise ValueError("Simulation must be a Simulation object.")
        self._simulation = simulation
        self._mft_axis = mft_axis
        # MFT_Diagnostic objects, keyed by diagnostic name.
        self._mft_computed = {}
        # MFT_Species_Handler objects, keyed by species name.
        self._species_handler = {}

    def __getitem__(self, key):
        """Return the cached MFT wrapper for a species name or a diagnostic name."""
        if key in self._simulation._species:
            if key not in self._species_handler:
                self._species_handler[key] = MFT_Species_Handler(self._simulation[key], self._mft_axis)
            return self._species_handler[key]
        if key not in self._mft_computed:
            self._mft_computed[key] = MFT_Diagnostic(self._simulation[key], self._mft_axis)
        return self._mft_computed[key]

    def delete_all(self):
        """Drop every cached result, including those cached inside the species handlers."""
        self._mft_computed = {}
        # Species handlers hold their own caches; drop them too so nothing survives.
        self._species_handler = {}

    def delete(self, key):
        """
        Drop the cached result for a diagnostic name, or the cached handler for a species name.

        Prints a message if nothing is cached under ``key``.
        """
        if key in self._mft_computed:
            del self._mft_computed[key]
        elif key in self._species_handler:
            del self._species_handler[key]
        else:
            print(f"MeanFieldTheory {key} not found in simulation")

    def process(self, diagnostic):
        """Apply mean field theory to a diagnostic"""
        return MFT_Diagnostic(diagnostic, self._mft_axis)
|
|
55
|
+
|
|
56
|
+
class MFT_Diagnostic(Diagnostic):
    """
    Container for the mean field theory decomposition of a diagnostic.

    Indexing with "avg" gives the average component (MFT_Diagnostic_Average).
    Indexing with "delta" gives the fluctuation component (MFT_Diagnostic_Fluctuations).
    Both are created on first access and cached.

    Parameters
    ----------
    diagnostic : Diagnostic
        The diagnostic object.
    mft_axis : int
        The axis to compute mean field theory along.

    Example
    -------
    >>> sim = Simulation('electrons', 'path/to/simulation')
    >>> diag = sim['e1']
    >>> mft = MFT_Diagnostic(diag, 1)
    >>> avg = mft['avg']
    >>> delta = mft['delta']
    """

    def __init__(self, diagnostic, mft_axis):
        # Reuse the wrapped diagnostic's species when it has one.
        if hasattr(diagnostic, '_species'):
            super().__init__(species=diagnostic._species)
        else:
            super().__init__(None)

        self._name = f"MFT[{diagnostic._name}]"
        self._diag = diagnostic
        self._mft_axis = mft_axis
        self._data = None
        self._all_loaded = False

        # "avg"/"delta" component objects, filled in on first use.
        self._components = {}

        # Mirror the grid/time metadata of the wrapped diagnostic.
        copied_attrs = ('_dt', '_dx', '_ndump', '_axis', '_nx', '_x', '_grid', '_dim', '_maxiter')
        for attr_name in copied_attrs:
            if hasattr(diagnostic, attr_name):
                setattr(self, attr_name, getattr(diagnostic, attr_name))

    def _get_component(self, key):
        """Return the cached "avg" or "delta" component, building it on first use."""
        if key not in self._components:
            component_cls = MFT_Diagnostic_Average if key == "avg" else MFT_Diagnostic_Fluctuations
            self._components[key] = component_cls(self._diag, self._mft_axis)
        return self._components[key]

    def __getitem__(self, key):
        """
        Get a component of the mean field theory.

        Parameters
        ----------
        key : str
            "avg" for the average or "delta" for the fluctuations.

        Returns
        -------
        Diagnostic
            The requested component.

        Raises
        ------
        ValueError
            If ``key`` is neither "avg" nor "delta".
        """
        if key in ("avg", "delta"):
            return self._get_component(key)
        raise ValueError("Invalid MFT component. Use 'avg' or 'delta'.")

    def load_all(self):
        """Build (if needed) and load both components; returns the component dict."""
        average = self._get_component("avg")
        fluctuations = self._get_component("delta")

        average.load_all()
        fluctuations.load_all()

        self._all_loaded = True
        return self._components
|
|
142
|
+
|
|
143
|
+
class MFT_Diagnostic_Average(Diagnostic):
    """
    Class to compute the average component of mean field theory.
    Inherits from Diagnostic to ensure all operation overloads work properly.

    The diagnostic is averaged along ``mft_axis``. The averaged axis is kept in place with
    length 1, so the result lines up with, and broadcasts against, both the original data
    and MFT_Diagnostic_Fluctuations.

    Parameters
    ----------
    diagnostic : Diagnostic
        The diagnostic object.
    mft_axis : int
        The axis to average over, numbered on the full data array. Axis 0 is the leading
        (time) axis and spatial axes start at 1; negative values count from the end.

    Raises
    ------
    ValueError
        If ``mft_axis`` is None.

    Example
    -------
    >>> sim = Simulation('electrons', 'path/to/simulation')
    >>> diag = sim['e1']
    >>> avg = MFT_Diagnostic_Average(diag, 1)
    """

    def __init__(self, diagnostic, mft_axis):
        # Initialize with the same species as the diagnostic
        if hasattr(diagnostic, '_species'):
            super().__init__(species=diagnostic._species)
        else:
            super().__init__(None)

        if mft_axis is None:
            raise ValueError("Mean field theory axis must be specified.")

        self._name = f"MFT_avg[{diagnostic._name}, {mft_axis}]"
        self._diag = diagnostic
        self._mft_axis = mft_axis
        self._data = None
        self._all_loaded = False

        # Copy grid/time metadata so this object can stand in for the diagnostic.
        for attr in ['_dt', '_dx', '_ndump', '_axis', '_nx', '_x', '_grid', '_dim', '_maxiter']:
            if hasattr(diagnostic, attr):
                setattr(self, attr, getattr(diagnostic, attr))

    def load_all(self):
        """Load all data and compute the average (cached after the first call)."""
        if self._data is not None:
            print("Data already loaded")
            return self._data

        if not hasattr(self._diag, '_data') or self._diag._data is None:
            self._diag.load_all()

        if self._mft_axis is None:
            raise ValueError("Mean field theory axis must be specified.")
        # keepdims leaves a length-1 axis where the average was taken, so the
        # result broadcasts against the original data and the fluctuations.
        self._data = self._diag._data.mean(axis=self._mft_axis, keepdims=True)

        self._all_loaded = True
        return self._data

    def _timestep_axis(self):
        """
        Translate ``mft_axis`` to the matching axis of a single-timestep array.

        A single timestep has no leading time axis, so positive axes shift down by one.
        Negative axes count from the end and are unchanged.
        """
        axis = self._mft_axis
        if axis is None:
            raise ValueError("Invalid axis for mean field theory.")
        if axis == 0:
            raise ValueError("Cannot average along the time axis for a single timestep. Use load_all() instead.")
        return axis - 1 if axis > 0 else axis

    def _data_generator(self, index):
        """Generate average data for a specific index"""
        axis = self._timestep_axis()
        data = self._diag[index]
        yield data.mean(axis=axis, keepdims=True)

    def __getitem__(self, index):
        """Get average at a specific index (int, numpy integer or slice)."""
        if self._all_loaded and self._data is not None:
            return self._data[index]

        # Otherwise compute on-demand
        if isinstance(index, (int, np.integer)):
            return next(self._data_generator(index))
        elif isinstance(index, slice):
            start = 0 if index.start is None else index.start
            step = 1 if index.step is None else index.step
            stop = self._diag._maxiter if index.stop is None else index.stop
            return np.array([next(self._data_generator(i)) for i in range(start, stop, step)])
        else:
            raise ValueError("Invalid index type. Use int or slice.")
|
|
227
|
+
|
|
228
|
+
class MFT_Diagnostic_Fluctuations(Diagnostic):
    """
    Class to compute the fluctuation component of mean field theory.
    Inherits from Diagnostic to ensure all operation overloads work properly.

    The fluctuations are the data minus its average along ``mft_axis``. They keep the
    shape of the original data.

    Parameters
    ----------
    diagnostic : Diagnostic
        The diagnostic object.
    mft_axis : int
        The axis to average over, numbered on the full data array. Axis 0 is the leading
        (time) axis and spatial axes start at 1; negative values count from the end.

    Raises
    ------
    ValueError
        If ``mft_axis`` is None.

    Example
    -------
    >>> sim = Simulation('electrons', 'path/to/simulation')
    >>> diag = sim['e1']
    >>> delta = MFT_Diagnostic_Fluctuations(diag, 1)
    """

    def __init__(self, diagnostic, mft_axis):
        # Initialize with the same species as the diagnostic
        if hasattr(diagnostic, '_species'):
            super().__init__(species=diagnostic._species)
        else:
            super().__init__(None)

        if mft_axis is None:
            raise ValueError("Mean field theory axis must be specified.")

        self._name = f"MFT_delta[{diagnostic._name}, {mft_axis}]"
        self._diag = diagnostic
        self._mft_axis = mft_axis
        self._data = None
        self._all_loaded = False

        # Copy grid/time metadata so this object can stand in for the diagnostic.
        for attr in ['_dt', '_dx', '_ndump', '_axis', '_nx', '_x', '_grid', '_dim', '_maxiter']:
            if hasattr(diagnostic, attr):
                setattr(self, attr, getattr(diagnostic, attr))

    def load_all(self):
        """Load all data and compute the fluctuations (cached after the first call)."""
        if self._data is not None:
            print("Data already loaded")
            return self._data

        if not hasattr(self._diag, '_data') or self._diag._data is None:
            self._diag.load_all()

        if self._mft_axis is None:
            raise ValueError("Mean field theory axis must be specified.")
        data = self._diag._data
        # keepdims keeps a length-1 axis so the average broadcasts against the data.
        average = data.mean(axis=self._mft_axis, keepdims=True)
        self._data = data - average

        self._all_loaded = True
        return self._data

    def _timestep_axis(self):
        """
        Translate ``mft_axis`` to the matching axis of a single-timestep array.

        A single timestep has no leading time axis, so positive axes shift down by one.
        Negative axes count from the end and are unchanged.
        """
        axis = self._mft_axis
        if axis is None:
            raise ValueError("Invalid axis for mean field theory.")
        if axis == 0:
            raise ValueError("Cannot average along the time axis for a single timestep. Use load_all() instead.")
        return axis - 1 if axis > 0 else axis

    def _data_generator(self, index):
        """Generate fluctuation data for a specific index"""
        axis = self._timestep_axis()
        data = self._diag[index]
        yield data - data.mean(axis=axis, keepdims=True)

    def __getitem__(self, index):
        """Get fluctuations at a specific index (int, numpy integer or slice)."""
        if self._all_loaded and self._data is not None:
            return self._data[index]

        # Otherwise compute on-demand
        if isinstance(index, (int, np.integer)):
            return next(self._data_generator(index))
        elif isinstance(index, slice):
            start = 0 if index.start is None else index.start
            step = 1 if index.step is None else index.step
            stop = self._diag._maxiter if index.stop is None else index.stop
            return np.array([next(self._data_generator(i)) for i in range(start, stop, step)])
        else:
            raise ValueError("Invalid index type. Use int or slice.")
|
|
323
|
+
|
|
324
|
+
class MFT_Species_Handler:
    """
    Lazily build and cache MFT_Diagnostic objects for the diagnostics of one species.

    Not intended to be used directly; MeanFieldTheory_Simulation creates one per species.

    Parameters
    ----------
    species_handler : Species_Handler
        The species handler object, indexed by diagnostic name.
    mft_axis : int
        The axis to compute the mean field theory; forwarded to every MFT_Diagnostic.
    """

    def __init__(self, species_handler, mft_axis):
        self._species_handler = species_handler
        self._mft_axis = mft_axis
        self._mft_computed = {}

    def __getitem__(self, key):
        results = self._mft_computed
        try:
            return results[key]
        except KeyError:
            pass
        # Look the diagnostic up before building the wrapper, so unknown keys
        # surface the species handler's own error.
        diagnostic = self._species_handler[key]
        results[key] = MFT_Diagnostic(diagnostic, self._mft_axis)
        return results[key]
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from ..data.data import OsirisGridFile
|
|
3
|
+
|
|
4
|
+
# Deprecated
class MFT_Single(OsirisGridFile):
    '''
    Class to handle the mean field theory on data. Inherits from OsirisGridFile.

    Deprecated. Splits ``data`` into two parts:
    - its average along ``axis``, kept as a length-1 axis;
    - the fluctuations around that average.

    Parameters
    ----------
    source : str or OsirisGridFile
        The filename, or an OsirisGridFile object whose attributes are copied.
    axis : int
        The axis to average over.
    '''
    def __init__(self, source, axis=1):
        if isinstance(source, OsirisGridFile):
            self.__dict__.update(source.__dict__)
        else:
            super().__init__(source)
        self._compute_mean_field(axis=axis)

    def _compute_mean_field(self, axis=1):
        """Store the average along ``axis`` (length-1 axis kept) and the fluctuations."""
        self._average = np.expand_dims(np.mean(self.data, axis=axis), axis=axis)
        self._fluctuations = self.data - self._average

    def __array__(self, dtype=None, copy=None):
        # NumPy >= 2 passes dtype/copy to __array__; accept them so np.asarray(obj) keeps working.
        if copy:
            return np.array(self.data, dtype=dtype, copy=True)
        return np.asarray(self.data, dtype=dtype)

    @property
    def average(self):
        """Average of ``data`` along the MFT axis (that axis has length 1)."""
        return self._average

    @property
    def delta(self):
        """Fluctuations ``data - average``."""
        return self._fluctuations

    def __str__(self):
        return super().__str__() + f'\nAverage: {self.average.shape}\nDelta: {self.delta.shape}'

    def derivative(self, field, axis=0):
        '''
        Compute the derivative of the average or the fluctuations.

        Parameters
        ----------
        field : MeanFieldTheory.average or MeanFieldTheory.delta
            The field to compute the derivative.
        axis : int
            The axis to compute the derivative. It is used both as the array axis and as
            the index into ``self.dx`` for the grid spacing.
        '''
        # Differentiate along the same axis whose spacing is used (was hard-coded to axis 0).
        return np.gradient(field, self.dx[axis], axis=axis)
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from ..data.diagnostic import Diagnostic
|
|
2
|
+
|
|
3
|
+
class PostProcess(Diagnostic):
    """
    Base class for post-processing operations.
    Inherits from Diagnostic to ensure all operation overloads work.

    Subclasses must implement ``process``.

    Parameters
    ----------
    name : str
        Name of the post-processing operation.
    species : str, optional
        The species to analyze.
    """

    def __init__(self, name, species=None):
        # Pass species by keyword. Other subclasses call Diagnostic.__init__ with
        # simulation_folder=/species= keywords, so a positional argument may bind
        # to a different parameter than species.
        super().__init__(species=species)
        self._name = name
        self._all_loaded = False
        self._data = None

    def process(self, diagnostic):
        """
        Apply the post-processing to a diagnostic.
        Must be implemented by subclasses.

        Parameters
        ----------
        diagnostic : Diagnostic
            The diagnostic to process.

        Returns
        -------
        Diagnostic or PostProcess
            The processed diagnostic.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement process method")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# PostProcessing_Simulation
|
|
42
|
+
# PostProcessing_Diagnostic
|