osiris-utils 1.1.3__py3-none-any.whl → 1.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,334 @@
1
+ from ..utils import *
2
+ from ..data.simulation import Simulation
3
+ from .postprocess import PostProcess
4
+ from ..data.diagnostic import Diagnostic
5
+
6
class MeanFieldTheory_Simulation(PostProcess):
    """
    Class to compute the mean field theory of a diagnostic. Works as a wrapper for the MFT_Diagnostic class.
    Inherits from PostProcess to ensure all operation overloads work properly.

    ``obj[key]`` returns a cached ``MFT_Species_Handler`` when ``key`` is one of the
    simulation's species, otherwise a cached ``MFT_Diagnostic`` wrapping ``simulation[key]``.

    Parameters
    ----------
    simulation : Simulation
        The simulation object.
    mft_axis : int
        The axis to compute the mean field theory. It is not validated here; a ``None``
        axis only raises once an average/fluctuation component is created.

    Raises
    ------
    ValueError
        If ``simulation`` is not a Simulation object.
    """

    def __init__(self, simulation, mft_axis=None):
        super().__init__(f"MeanFieldTheory({mft_axis})")
        if not isinstance(simulation, Simulation):
            raise ValueError("Simulation must be a Simulation object.")
        self._simulation = simulation
        self._mft_axis = mft_axis
        # Caches: diagnostic key -> MFT_Diagnostic, species key -> MFT_Species_Handler
        self._mft_computed = {}
        self._species_handler = {}

    def __getitem__(self, key):
        """Return the cached MFT wrapper for a species name or a diagnostic name."""
        if key in self._simulation._species:
            if key not in self._species_handler:
                self._species_handler[key] = MFT_Species_Handler(self._simulation[key], self._mft_axis)
            return self._species_handler[key]
        if key not in self._mft_computed:
            self._mft_computed[key] = MFT_Diagnostic(self._simulation[key], self._mft_axis)
        return self._mft_computed[key]

    def delete_all(self):
        """Drop every cached MFT wrapper, including the per-species handlers."""
        self._mft_computed = {}
        # Species handlers keep their own MFT_Diagnostic caches; drop them too so that
        # delete_all actually releases everything that was computed.
        self._species_handler = {}

    def delete(self, key):
        """Drop the cached MFT wrapper for ``key`` (a diagnostic or a species name)."""
        if key in self._mft_computed:
            del self._mft_computed[key]
        elif key in self._species_handler:
            del self._species_handler[key]
        else:
            print(f"MeanFieldTheory {key} not found in simulation")

    def process(self, diagnostic):
        """Apply mean field theory to a diagnostic"""
        return MFT_Diagnostic(diagnostic, self._mft_axis)
50
+
51
class MFT_Diagnostic(Diagnostic):
    """
    Container exposing the mean-field decomposition of a diagnostic.

    ``obj["avg"]`` gives an ``MFT_Diagnostic_Average`` and ``obj["delta"]`` an
    ``MFT_Diagnostic_Fluctuations``. Each component is created on first request and
    then reused.

    Parameters
    ----------
    diagnostic : Diagnostic
        The diagnostic object.
    mft_axis : int
        The axis to compute mean field theory along.
    """

    def __init__(self, diagnostic, mft_axis):
        # Run the parent initialiser, forwarding species/folder when the diagnostic has them.
        if not hasattr(diagnostic, '_species'):
            super().__init__(None)
        else:
            folder = getattr(diagnostic, '_simulation_folder', None)
            super().__init__(simulation_folder=folder, species=diagnostic._species)

        self._name = f"MFT[{diagnostic._name}]"
        self._diag = diagnostic
        self._mft_axis = mft_axis
        self._data = None
        self._all_loaded = False

        # "avg" / "delta" component objects, built lazily.
        self._components = {}

        # Mirror whatever grid/time metadata the wrapped diagnostic carries.
        mirrored = ('_dt', '_dx', '_ndump', '_axis', '_nx', '_x', '_grid', '_dim',
                    '_maxiter', '_tunits', '_type')
        for name in mirrored:
            if hasattr(diagnostic, name):
                setattr(self, name, getattr(diagnostic, name))

    def _mft_component(self, key):
        """Return the cached "avg"/"delta" component, constructing it on first use."""
        component = self._components.get(key)
        if component is None:
            factory = MFT_Diagnostic_Average if key == "avg" else MFT_Diagnostic_Fluctuations
            component = factory(self._diag, self._mft_axis)
            self._components[key] = component
        return component

    def __getitem__(self, key):
        """
        Get a component of the mean field theory.

        Parameters
        ----------
        key : str
            "avg" for the average, "delta" for the fluctuations.

        Returns
        -------
        Diagnostic
            The requested component.

        Raises
        ------
        ValueError
            For any other key.
        """
        if key not in ("avg", "delta"):
            raise ValueError("Invalid MFT component. Use 'avg' or 'delta'.")
        return self._mft_component(key)

    def load_all(self):
        """Build both components, load their data and return the component dict."""
        average = self._mft_component("avg")
        fluctuations = self._mft_component("delta")

        average.load_all()
        fluctuations.load_all()

        # The container itself holds no data; the flag only records that both parts are loaded.
        self._all_loaded = True
        return self._components
132
+
133
class MFT_Diagnostic_Average(Diagnostic):
    """
    Class to compute the average component of mean field theory.
    Inherits from Diagnostic to ensure all operation overloads work properly.

    The wrapped diagnostic is averaged along ``mft_axis``, and the result gets a trailing
    singleton dimension (``np.expand_dims(..., axis=-1)``). ``mft_axis`` indexes the fully
    loaded data, which carries a leading iteration axis. A single iteration obtained
    through ``diagnostic[i]`` is therefore averaged along ``mft_axis - 1``.

    Parameters
    ----------
    diagnostic : Diagnostic
        The diagnostic object.
    mft_axis : int
        The axis to compute the mean field theory.

    Raises
    ------
    ValueError
        If ``mft_axis`` is None.
    """

    def __init__(self, diagnostic, mft_axis):
        # Initialize with the same species as the diagnostic
        if hasattr(diagnostic, '_species'):
            super().__init__(simulation_folder=diagnostic._simulation_folder if hasattr(diagnostic, '_simulation_folder') else None,
                             species=diagnostic._species)
        else:
            super().__init__(None)

        if mft_axis is None:
            raise ValueError("Mean field theory axis must be specified.")

        self.postprocess_name = "MFT_AVG"

        self._name = f"MFT_avg[{diagnostic._name}, {mft_axis}]"
        self._diag = diagnostic
        self._mft_axis = mft_axis
        self._data = None
        self._all_loaded = False

        # Copy grid/time metadata so this object can stand in for the diagnostic
        for attr in ['_dt', '_dx', '_ndump', '_axis', '_nx', '_x', '_grid', '_dim', '_maxiter', '_type']:
            if hasattr(diagnostic, attr):
                setattr(self, attr, getattr(diagnostic, attr))

    def load_all(self):
        """Load the wrapped diagnostic (if needed) and average it along ``mft_axis``."""
        if self._data is not None:
            print("Data already loaded")
            return self._data

        if not hasattr(self._diag, '_data') or self._diag._data is None:
            self._diag.load_all()

        if self._mft_axis is None:
            raise ValueError("Mean field theory axis must be specified.")
        # NOTE(review): the singleton axis is appended at the end, not re-inserted at
        # mft_axis; this shape is kept because callers may rely on it.
        self._data = np.expand_dims(self._diag._data.mean(axis=self._mft_axis), axis=-1)

        self._all_loaded = True
        return self._data

    def _data_generator(self, index):
        """Yield the average of the single iteration ``index``."""
        if self._mft_axis is None:
            raise ValueError("Invalid axis for mean field theory.")
        data = self._diag[index]
        # A single iteration has no leading iteration axis, hence ``mft_axis - 1``.
        yield np.expand_dims(data.mean(axis=self._mft_axis - 1), axis=-1)

    def __getitem__(self, index):
        """
        Get the average at an iteration index or over a slice of iterations.

        Raises
        ------
        ValueError
            If ``index`` is neither an integer nor a slice (and data is not fully loaded).
        """
        if self._all_loaded and self._data is not None:
            return self._data[index]

        # Compute on demand. NumPy integers (e.g. from np.arange) are accepted too.
        if isinstance(index, (int, np.integer)):
            return next(self._data_generator(int(index)))
        if isinstance(index, slice):
            # slice.indices resolves None/negative bounds and clamps to the number of
            # iterations, giving normal Python slicing semantics.
            iterations = range(*index.indices(self._diag._maxiter))
            return np.array([next(self._data_generator(i)) for i in iterations])
        raise ValueError("Invalid index type. Use int or slice.")
215
+
216
class MFT_Diagnostic_Fluctuations(Diagnostic):
    """
    Class to compute the fluctuation component of mean field theory.
    Inherits from Diagnostic to ensure all operation overloads work properly.

    The fluctuations are ``data - mean(data, axis=mft_axis)``, with the mean broadcast
    back over the averaged axis. ``mft_axis`` indexes the fully loaded data, which carries
    a leading iteration axis. A single iteration obtained through ``diagnostic[i]`` is
    therefore processed along ``mft_axis - 1``.

    Parameters
    ----------
    diagnostic : Diagnostic
        The diagnostic object.
    mft_axis : int
        The axis to compute the mean field theory.

    Raises
    ------
    ValueError
        If ``mft_axis`` is None.
    """

    def __init__(self, diagnostic, mft_axis):
        # Initialize with the same species as the diagnostic
        if hasattr(diagnostic, '_species'):
            super().__init__(simulation_folder=diagnostic._simulation_folder if hasattr(diagnostic, '_simulation_folder') else None,
                             species=diagnostic._species)
        else:
            super().__init__(None)

        if mft_axis is None:
            raise ValueError("Mean field theory axis must be specified.")

        self.postprocess_name = "MFT_FLT"

        self._name = f"MFT_delta[{diagnostic._name}, {mft_axis}]"
        self._diag = diagnostic
        self._mft_axis = mft_axis
        self._data = None
        self._all_loaded = False

        # Copy grid/time metadata so this object can stand in for the diagnostic
        for attr in ['_dt', '_dx', '_ndump', '_axis', '_nx', '_x', '_grid', '_dim', '_maxiter', '_type']:
            if hasattr(diagnostic, attr):
                setattr(self, attr, getattr(diagnostic, attr))

    def load_all(self):
        """Load the wrapped diagnostic (if needed) and subtract its mean along ``mft_axis``."""
        if self._data is not None:
            print("Data already loaded")
            return self._data

        if not hasattr(self._diag, '_data') or self._diag._data is None:
            self._diag.load_all()

        if self._mft_axis is None:
            raise ValueError("Mean field theory axis must be specified.")
        # keepdims leaves a size-1 axis where the mean was taken, so it broadcasts back.
        avg = self._diag._data.mean(axis=self._mft_axis, keepdims=True)
        self._data = self._diag._data - avg

        self._all_loaded = True
        return self._data

    def _data_generator(self, index):
        """Yield the fluctuations of the single iteration ``index``."""
        if self._mft_axis is None:
            raise ValueError("Invalid axis for mean field theory.")
        data = self._diag[index]
        # A single iteration has no leading iteration axis, hence ``mft_axis - 1``;
        # re-insert the averaged axis so the subtraction broadcasts.
        avg = np.expand_dims(data.mean(axis=self._mft_axis - 1), axis=self._mft_axis - 1)
        yield data - avg

    def __getitem__(self, index):
        """
        Get the fluctuations at an iteration index or over a slice of iterations.

        Raises
        ------
        ValueError
            If ``index`` is neither an integer nor a slice (and data is not fully loaded).
        """
        if self._all_loaded and self._data is not None:
            return self._data[index]

        # Compute on demand. NumPy integers (e.g. from np.arange) are accepted too.
        if isinstance(index, (int, np.integer)):
            return next(self._data_generator(int(index)))
        if isinstance(index, slice):
            # slice.indices resolves None/negative bounds and clamps to the number of
            # iterations, giving normal Python slicing semantics.
            iterations = range(*index.indices(self._diag._maxiter))
            return np.array([next(self._data_generator(i)) for i in iterations])
        raise ValueError("Invalid index type. Use int or slice.")
309
+
310
class MFT_Species_Handler:
    """
    Per-species front end for mean field theory.

    Each key is looked up in the wrapped species handler, and the result is wrapped in an
    ``MFT_Diagnostic``. Wrappers are cached by key, so repeated access returns the same
    object.

    Not intended to be used directly, but through the MeanFieldTheory_Simulation class.

    Parameters
    ----------
    species_handler : Species_Handler
        The species handler object.
    mft_axis : int
        The axis to compute the mean field theory.
    """

    def __init__(self, species_handler, mft_axis):
        self._species_handler = species_handler
        self._mft_axis = mft_axis
        self._mft_computed = {}

    def __getitem__(self, key):
        cache = self._mft_computed
        try:
            return cache[key]
        except KeyError:
            pass
        wrapped = MFT_Diagnostic(self._species_handler[key], self._mft_axis)
        cache[key] = wrapped
        return wrapped
@@ -0,0 +1,52 @@
1
+ import numpy as np
2
+ from ..data.data import OsirisGridFile
3
+
4
+ # Deprecated
5
class MFT_Single(OsirisGridFile):
    '''
    Class to handle the mean field theory on data. Inherits from OsirisGridFile.

    On construction the data are split into an average along ``axis`` (kept as a
    size-1 dimension so it broadcasts) and the fluctuations around that average.

    Parameters
    ----------
    source : str or OsirisGridFile
        The filename or an OsirisGridFile object.
    axis : int
        The axis to average over.
    '''
    def __init__(self, source, axis=1):
        if isinstance(source, OsirisGridFile):
            # Adopt the existing object's (shallow-copied) state instead of re-reading.
            self.__dict__.update(source.__dict__)
        else:
            super().__init__(source)
        self._compute_mean_field(axis=axis)

    def _compute_mean_field(self, axis=1):
        # Re-insert the averaged axis so the subtraction below broadcasts.
        self._average = np.expand_dims(np.mean(self.data, axis=axis), axis=axis)
        self._fluctuations = self.data - self._average

    def __array__(self, dtype=None, copy=None):
        # NumPy passes ``dtype`` (and, since 2.0, ``copy``) to __array__; accepting them
        # lets ``np.asarray(obj, dtype=...)`` work instead of raising TypeError.
        if copy:
            return np.array(self.data, dtype=dtype, copy=True)
        return np.asarray(self.data, dtype=dtype)

    @property
    def average(self):
        return self._average

    @property
    def delta(self):
        return self._fluctuations

    def __str__(self):
        return super().__str__() + f'\nAverage: {self.average.shape}\nDelta: {self.delta.shape}'

    def derivative(self, field, axis=0):
        '''
        Compute the derivative of the average or the fluctuations.

        Parameters
        ----------
        field : MeanFieldTheory.average or MeanFieldTheory.delta
            The field to compute the derivative.
        axis : int
            The axis to compute the derivative.

        Returns
        -------
        numpy.ndarray
            ``np.gradient`` of ``field`` along ``axis`` with spacing ``self.dx[axis]``.
        '''
        # Differentiate along the same axis whose grid spacing is used
        # (previously always axis 0, which was wrong for axis != 0).
        return np.gradient(field, self.dx[axis], axis=axis)
@@ -0,0 +1,42 @@
1
+ from ..data.diagnostic import Diagnostic
2
+
3
class PostProcess(Diagnostic):
    """
    Base class for post-processing operations.
    Inherits from Diagnostic to ensure all operation overloads work.

    Parameters
    ----------
    name : str
        Name of the post-processing operation.
    species : str
        The species to analyze.
    """

    def __init__(self, name, species=None):
        # Pass ``species`` by keyword. Diagnostic is initialised elsewhere with
        # ``simulation_folder=``/``species=`` keywords, and a positional species would
        # presumably bind to the first parameter (the simulation folder) instead.
        super().__init__(simulation_folder=None, species=species)
        self._name = name
        self._all_loaded = False
        self._data = None

    def process(self, diagnostic):
        """
        Apply the post-processing to a diagnostic.
        Must be implemented by subclasses.

        Parameters
        ----------
        diagnostic : Diagnostic
            The diagnostic to process.

        Returns
        -------
        Diagnostic or PostProcess
            The processed diagnostic.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement process method")
39
+
40
+
41
+ # PostProcessing_Simulation
42
+ # PostProcessing_Diagnostic
@@ -0,0 +1,171 @@
1
+ from ..utils import *
2
+ from ..data.simulation import Simulation
3
+ from .postprocess import PostProcess
4
+ from ..data.diagnostic import Diagnostic
5
+
6
# The nine pressure tensor component names, "P" + row index + column index (P11 ... P33).
OSIRIS_P = [f"P{row}{col}" for row in "123" for col in "123"]
7
+
8
class PressureCorrection_Simulation(PostProcess):
    """
    Class to correct pressure tensor components by subtracting Reynolds stress.

    Pressure is species dependent in OSIRIS, so components are accessed through a
    species: ``obj[species][P_jk]`` returns a cached ``PressureCorrection_Diagnostic``
    built by ``PressureCorrection_Species_Handler``.

    Parameters
    ----------
    simulation : Simulation
        The simulation object.

    Raises
    ------
    ValueError
        If ``simulation`` is not a Simulation object.
    """

    def __init__(self, simulation):
        super().__init__(f"PressureCorrection Simulation")
        if not isinstance(simulation, Simulation):
            raise ValueError("Simulation must be a Simulation object.")
        self._simulation = simulation
        self._pressure_corrected = {}
        # species key -> PressureCorrection_Species_Handler
        self._species_handler = {}

    def __getitem__(self, key):
        """
        Return the species handler for a species key.

        Raises
        ------
        ValueError
            For an unknown key, or for a bare pressure component (it needs a species).
        """
        if key in self._simulation._species:
            if key not in self._species_handler:
                self._species_handler[key] = PressureCorrection_Species_Handler(self._simulation[key])
            return self._species_handler[key]
        if key not in OSIRIS_P:
            raise ValueError(f"Invalid pressure component {key}. Supported: {OSIRIS_P}.")
        # PressureCorrection_Diagnostic needs the species density and fluid velocities,
        # so a pressure component without a species cannot be corrected. (The previous
        # code called it with the wrong number of arguments and always raised TypeError.)
        raise ValueError(
            f"Pressure component {key} is species dependent in OSIRIS; "
            f"access it through a species, e.g. obj['<species>']['{key}']."
        )

    def delete_all(self):
        """Drop every cached correction, including the per-species handlers."""
        self._pressure_corrected = {}
        self._species_handler = {}

    def delete(self, key):
        """Drop the cached entry for ``key`` (a pressure component or a species name)."""
        if key in self._pressure_corrected:
            del self._pressure_corrected[key]
        elif key in self._species_handler:
            del self._species_handler[key]
        else:
            print(f"Pressure {key} not found in simulation")

    def process(self, diagnostic):
        """
        Apply pressure correction to a diagnostic.

        The density and fluid velocities are taken from the diagnostic's species in the
        wrapped simulation (``n``, ``ufl{j}`` falling back to ``vfl{j}``, and ``vfl{k}``).

        Raises
        ------
        ValueError
            If the diagnostic is not a pressure component, or its species cannot be
            resolved to a species of the simulation.
        """
        name = getattr(diagnostic, "_name", None)
        if name not in OSIRIS_P:
            raise ValueError(f"Invalid pressure component {name}. Supported: {OSIRIS_P}")

        species = getattr(diagnostic, "_species", None)
        # _species may be a species object exposing ``_name`` or already the name itself.
        species_name = getattr(species, "_name", species)
        if species_name is None or species_name not in self._simulation._species:
            raise ValueError(
                "Pressure correction needs a species-dependent pressure diagnostic "
                "whose species belongs to this simulation."
            )

        species_diags = self._simulation[species_name]
        j, k = name[-2], name[-1]
        n = species_diags["n"]
        try:
            ufl = species_diags[f"ufl{j}"]
        except Exception:
            # Fall back to vfl_j when ufl_j is not available for this species.
            ufl = species_diags[f"vfl{j}"]
        vfl = species_diags[f"vfl{k}"]
        return PressureCorrection_Diagnostic(diagnostic, n, ufl, vfl)
52
+
53
class PressureCorrection_Diagnostic(Diagnostic):
    """
    Class to correct the pressure in the simulation.

    Computes ``P_jk - n * vfl_k * ufl_j``, either for all iterations at once
    (``load_all``) or per index on demand (``__getitem__``).

    Parameters
    ----------
    diagnostic : Diagnostic
        The pressure diagnostic; its ``_name`` must be one of ``OSIRIS_P``.
    n : Diagnostic
        Density diagnostic.
    ufl_j : Diagnostic
        Fluid velocity for the first tensor index (``j`` of ``P_jk``).
    vfl_k : Diagnostic
        Fluid velocity for the second tensor index (``k`` of ``P_jk``).

    Raises
    ------
    ValueError
        If ``diagnostic._name`` is not a supported pressure component.
    """

    def __init__(self, diagnostic, n, ufl_j, vfl_k):
        if hasattr(diagnostic, '_species'):
            super().__init__(simulation_folder=diagnostic._simulation_folder if hasattr(diagnostic, '_simulation_folder') else None,
                             species=diagnostic._species)
        else:
            super().__init__(None)

        self.postprocess_name = "P_CORR"

        if diagnostic._name not in OSIRIS_P:
            raise ValueError(f"Invalid pressure component {diagnostic._name}. Supported: {OSIRIS_P}")

        self._diag = diagnostic

        # Density and velocities are passed in, so this class does not depend on the simulation
        self._n = n
        self._ufl_j = ufl_j
        self._vfl_k = vfl_k

        for attr in ['_dt', '_dx', '_ndump', '_axis', '_nx', '_x', '_grid', '_dim', '_maxiter', '_type']:
            if hasattr(diagnostic, attr):
                setattr(self, attr, getattr(diagnostic, attr))

        self._original_name = diagnostic._name
        self._name = diagnostic._name + "_corrected"

        self._data = None
        self._all_loaded = False

    def load_all(self):
        """Load every iteration of the inputs and return the corrected pressure array."""
        if self._data is not None:
            return self._data

        if not hasattr(self._diag, '_data') or self._diag._data is None:
            self._diag.load_all()

        # _species may be a species object or None (no species); avoid AttributeError.
        species = getattr(self, "_species", None)
        species_name = getattr(species, "_name", species)
        print(f"Loading {species_name} {self._original_name} diagnostic")
        self._n.load_all()
        self._ufl_j.load_all()
        self._vfl_k.load_all()

        n = self._n.data
        u = self._ufl_j.data
        v = self._vfl_k.data

        self._data = self._diag.data - n * v * u
        self._all_loaded = True

        # TODO: consider unloading n / ufl_j / vfl_k here to save memory.
        return self._data

    def __getitem__(self, index):
        """
        Get corrected data at an iteration index or over a slice of iterations.

        Raises
        ------
        ValueError
            If ``index`` is neither an integer nor a slice (and data is not fully loaded).
        """
        if self._all_loaded and self._data is not None:
            return self._data[index]

        # Compute on demand. NumPy integers (e.g. from np.arange) are accepted too.
        if isinstance(index, (int, np.integer)):
            return next(self._data_generator(int(index)))
        if isinstance(index, slice):
            # slice.indices resolves None/negative bounds and clamps to the number of
            # iterations, giving normal Python slicing semantics.
            iterations = range(*index.indices(self._diag._maxiter))
            return np.array([next(self._data_generator(i)) for i in iterations])
        raise ValueError("Invalid index type. Use int or slice.")

    def _data_generator(self, index):
        """Yield the corrected pressure for the single iteration ``index``."""
        yield self._diag[index] - self._n[index] * self._vfl_k[index] * self._ufl_j[index]
136
+
137
class PressureCorrection_Species_Handler:
    """
    Class to handle pressure correction for a species.
    Acts as a wrapper for the PressureCorrection_Diagnostic class.

    Not intended to be used directly, but through the PressureCorrection_Simulation class.

    ``obj["Pjk"]`` returns a cached ``PressureCorrection_Diagnostic`` built from the
    species' ``Pjk``, ``n``, ``ufl{j}`` (falling back to ``vfl{j}`` when ``ufl{j}``
    cannot be obtained) and ``vfl{k}``.

    Parameters
    ----------
    species_handler : Species_Handler
        The species handler object.
    """

    def __init__(self, species_handler):
        self._species_handler = species_handler
        self._pressure_corrected = {}

    def __getitem__(self, key):
        if key not in self._pressure_corrected:
            diag = self._species_handler[key]

            # Density and velocities always depend on the species, so they are fetched here.
            n = self._species_handler["n"]
            # Tensor indices are the last two characters of the component name (P_jk).
            j, k = key[-2], key[-1]
            try:
                ufl = self._species_handler[f"ufl{j}"]
            except Exception:
                # Fall back to vfl_j when ufl_j is not available for this species.
                ufl = self._species_handler[f"vfl{j}"]
            vfl = self._species_handler[f"vfl{k}"]
            self._pressure_corrected[key] = PressureCorrection_Diagnostic(diag, n, ufl, vfl)
        return self._pressure_corrected[key]