osiris-utils 1.1.2__py3-none-any.whl → 1.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,266 @@
1
+ import sys
2
+ import os
3
+ from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QPushButton,
4
+ QFileDialog, QMessageBox, QComboBox, QHBoxLayout,
5
+ QVBoxLayout, QLabel, QLineEdit, QFrame, QDoubleSpinBox)
6
+ from PySide6.QtCore import Qt
7
+ from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
8
+ import matplotlib.pyplot as plt
9
+ from matplotlib.colors import LogNorm
10
+ import numpy as np
11
+ from ..data.data import OsirisGridFile # Update import as needed
12
+ from ..utils import integrate, transverse_average # Update import as needed
13
+
14
class LAVA_Qt(QMainWindow):
    """Main window of LAVA, a Qt viewer for OSIRIS grid files.

    The window stacks three rows: a control bar (folder browser, save button,
    file selector, plot-type selector), a row of editable title / axis-label
    fields, and an embedded Matplotlib canvas. Only files whose ``type`` is
    ``'grid'`` and that have one or two axes are plotted.
    """

    def __init__(self):
        """Build the window and all of its widgets; no data is loaded yet."""
        super().__init__()
        self.setWindowTitle('LAVA (LabAstro Visualization Assistant) - OSIRIS Data Grid Viewer')
        self.setGeometry(100, 100, 1000, 600)

        # Viewer state. ``data_info`` holds ``(x, data)`` for 1D files and
        # ``(x, y, data)`` for 2D files; ``dims`` tells which layout applies.
        self.data_info = None
        self.dims = 0
        self.type = None
        self.current_ax = None
        self.current_folder = None

        # Main widget and layout
        self.main_widget = QWidget()
        self.setCentralWidget(self.main_widget)
        self.main_layout = QVBoxLayout(self.main_widget)

        # Create UI elements (added to the layout top to bottom)
        self.create_controls()
        self.create_labels_section()
        self.create_plot_area()

    def create_controls(self):
        """Create the top control bar and wire its signals."""
        control_frame = QWidget()
        control_layout = QHBoxLayout(control_frame)

        # Buttons
        self.browse_btn = QPushButton('Browse Folder')
        self.browse_btn.clicked.connect(self.load_folder)
        self.save_btn = QPushButton('Save Plot')
        self.save_btn.clicked.connect(self.save_plot)

        # File selector
        self.file_selector = QComboBox()
        self.file_selector.setPlaceholderText('Select file...')
        self.file_selector.currentIndexChanged.connect(self.file_selection_changed)
        self.file_selector.view().setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.file_selector.setSizeAdjustPolicy(QComboBox.AdjustToContents)

        # Plot type combo box; every change of its text triggers a replot
        self.plot_combo = QComboBox()
        self.plot_combo.addItem('Select Plot Type')
        self.plot_combo.currentTextChanged.connect(self.plot_data)

        control_layout.addWidget(self.browse_btn)
        control_layout.addWidget(self.save_btn)
        control_layout.addWidget(QLabel('Files:'))
        control_layout.addWidget(self.file_selector)
        control_layout.addWidget(QLabel('Plot Type:'))
        control_layout.addWidget(self.plot_combo)
        self.main_layout.addWidget(control_frame)

    def create_labels_section(self):
        """Create the row of editable title / x-label / y-label fields."""
        labels_frame = QWidget()
        labels_layout = QHBoxLayout(labels_frame)

        # Title and labels
        self.title_edit = QLineEdit()
        self.xlabel_edit = QLineEdit()
        self.ylabel_edit = QLineEdit()

        # Any edit is pushed to the current plot immediately
        self.title_edit.textChanged.connect(self.update_plot_labels)
        self.xlabel_edit.textChanged.connect(self.update_plot_labels)
        self.ylabel_edit.textChanged.connect(self.update_plot_labels)

        labels_layout.addWidget(QLabel('Title:'))
        labels_layout.addWidget(self.title_edit)
        labels_layout.addWidget(QLabel('X Label:'))
        labels_layout.addWidget(self.xlabel_edit)
        labels_layout.addWidget(QLabel('Y Label:'))
        labels_layout.addWidget(self.ylabel_edit)

        self.main_layout.addWidget(labels_frame)

    @staticmethod
    def _dump_number(filename):
        """Return the integer after the last ``-`` of *filename* (extension
        removed), or 0 when that part is not a number."""
        stem = os.path.splitext(filename)[0]
        try:
            return int(stem.split('-')[-1])
        except ValueError:
            return 0  # Fallback for malformed filenames

    @staticmethod
    def _axis_values(start, step, count):
        """Return *count* coordinates ``start, start + step, ...``.

        Built from an integer range so the length always matches the data;
        ``np.arange(start, stop, step)`` with a float step may come out one
        element longer or shorter.
        """
        return start + step * np.arange(count)

    def load_folder(self):
        """Ask for a folder and list its ``.h5`` files in the file selector.

        Files are ordered by the number that follows the last ``-`` in their
        name. Selecting the first entry triggers loading and plotting it.
        """
        folderpath = QFileDialog.getExistingDirectory(
            self, 'Select Folder with HDF5 Files'
        )

        if not folderpath:
            return

        try:
            h5_files = [f for f in os.listdir(folderpath) if f.endswith('.h5')]
        except OSError as e:
            QMessageBox.critical(self, 'Error', str(e))
            return

        self.current_folder = folderpath
        self.file_selector.clear()

        if not h5_files:
            QMessageBox.critical(self, 'Error', 'No HDF5 files found in selected folder')
            return

        h5_files.sort(key=self._dump_number)
        self.file_selector.addItems(h5_files)
        self.file_selector.setCurrentIndex(0)

    def file_selection_changed(self, index):
        '''Handle file selection change in the combo box'''
        if index >= 0 and self.current_folder:
            filename = self.file_selector.itemText(index)
            self.process_file(filename)

    def process_file(self, filename):
        """Load *filename* from the current folder and plot it.

        The viewer state (``dims``, ``type``, ``data_info``) is replaced only
        after the file was read and validated, so a failing file leaves the
        previously displayed data consistent. Errors are reported in a dialog.
        """
        try:
            filepath = os.path.join(self.current_folder, filename)
            gridfile = OsirisGridFile(filepath)
            dims = len(gridfile.axis)
            file_type = gridfile.type

            if file_type != 'grid':
                QMessageBox.information(self, 'Info', f'{file_type} data not supported yet')
                return

            if dims == 1:
                x = self._axis_values(gridfile.grid[0], gridfile.dx, np.shape(gridfile.data)[0])
                xlabel = r'$%s$ [$%s$]' % (gridfile.axis[0]['long_name'], gridfile.axis[0]['units'])
                ylabel = r'$%s$ [$%s$]' % (gridfile.label, gridfile.units)
                data_info = (x, gridfile.data)
            elif dims == 2:
                # plot_2d draws ``data.T`` with x horizontal, so the first
                # data axis runs along x and the second along y.
                nx, ny = np.shape(gridfile.data)[:2]
                x = self._axis_values(gridfile.grid[0][0], gridfile.dx[0], nx)
                y = self._axis_values(gridfile.grid[1][0], gridfile.dx[1], ny)
                xlabel = r'$%s$ [$%s$]' % (gridfile.axis[0]['long_name'], gridfile.axis[0]['units'])
                ylabel = r'$%s$ [$%s$]' % (gridfile.axis[1]['long_name'], gridfile.axis[1]['units'])
                data_info = (x, y, gridfile.data)
            elif dims == 3:
                raise ValueError('3D not supported yet')
            else:
                raise ValueError('Unsupported dimensionality')

            title = r'$%s$ [$%s$]' % (gridfile.label, gridfile.units)

            # Everything parsed: commit the new state, then refresh the UI.
            self.dims = dims
            self.type = file_type
            self.data_info = data_info
            self.xlabel_edit.setText(xlabel)
            self.ylabel_edit.setText(ylabel)
            self.title_edit.setText(title)
            self.update_plot_menu()
            self.plot_data()

        except Exception as e:
            # UI boundary: whatever the reader raised is shown to the user.
            QMessageBox.critical(self, 'Error', str(e))

    def create_plot_area(self):
        """Create the Matplotlib figure and embed its canvas in the window."""
        self.figure = plt.figure(figsize=(8, 6))
        self.canvas = FigureCanvas(self.figure)
        self.main_layout.addWidget(self.canvas)

    def _apply_labels(self):
        """Copy the three text fields onto the current axes and figure."""
        self.current_ax.set_xlabel(self.xlabel_edit.text())
        self.current_ax.set_ylabel(self.ylabel_edit.text())
        self.figure.suptitle(self.title_edit.text())

    def update_plot_labels(self):
        """Apply the edited title and axis labels and redraw, if a plot exists."""
        if self.current_ax:
            self._apply_labels()
            self.canvas.draw()

    def plot_data(self):
        """Clear the figure and redraw the loaded data with the selected plot type."""
        self.figure.clear()
        if self.dims == 1:
            self.plot_1d()
        elif self.dims == 2:
            self.plot_2d()
        self.update_plot_labels()
        self.canvas.draw()

    def plot_1d(self):
        """Draw 1D data as a line or scatter plot, depending on the selection."""
        x, data = self.data_info
        self.current_ax = self.figure.add_subplot(111)
        plot_type = self.plot_combo.currentText()

        if 'Line' in plot_type:
            self.current_ax.plot(x, data)
        elif 'Scatter' in plot_type:
            self.current_ax.scatter(x, data)

        self._apply_labels()

    def plot_2d(self):
        """Draw 2D data as an image, a transverse average, its integral, or a
        log-scaled image, depending on the selection."""
        x, y, data = self.data_info
        self.current_ax = self.figure.add_subplot(111)
        plot_type = self.plot_combo.currentText()

        if 'Quantity' in plot_type:
            img = self.current_ax.imshow(data.T, extent=(x[0], x[-1], y[0], y[-1]), origin='lower', aspect='auto')
            self.figure.colorbar(img)
        elif 'Integral' in plot_type:
            # Integration step is the actual spacing of the x axis; the axis
            # need not start at zero, so x[-1] / len(x) is not the spacing.
            dx = (x[-1] - x[0]) / (len(x) - 1) if len(x) > 1 else 0.0
            avg = integrate(transverse_average(data), dx)
            self.current_ax.plot(x, avg)
        elif 'Transverse' in plot_type:
            avg = transverse_average(data)
            self.current_ax.plot(x, avg)
        elif 'Phase' in plot_type:
            # Magnitude on a logarithmic colour scale
            img = self.current_ax.imshow(np.abs(data.T), extent=(x[0], x[-1], y[0], y[-1]), origin='lower', aspect='auto', norm=LogNorm())
            self.figure.colorbar(img)

        self._apply_labels()

    def update_plot_menu(self):
        """Refill the plot-type selector for the current dimensionality,
        keeping the previous choice when it is still available."""
        # Save current plot type before clearing
        current_plot_type = self.plot_combo.currentText()
        self.plot_combo.clear()

        # Determine items based on dimensions
        if self.dims == 1:
            items = ['Line Plot', 'Scatter Plot']
        elif self.dims == 2:
            items = ['Quantity Plot', 'T. Average Integral', 'Transverse Average', 'Phase Space']
        else:
            items = []

        self.plot_combo.addItems(items)

        # Restore previous selection if possible
        if current_plot_type in items:
            self.plot_combo.setCurrentText(current_plot_type)
        else:
            self.plot_combo.setCurrentIndex(0 if items else -1)

    def save_plot(self):
        """Ask for a file name and save the current figure as PNG or PDF."""
        filepath, _ = QFileDialog.getSaveFileName(
            self, 'Save Plot', '', 'PNG Files (*.png);;PDF Files (*.pdf)'
        )

        if filepath:
            self.figure.savefig(filepath, dpi=800, bbox_inches='tight')
258
+
259
def LAVA():
    """Start the Qt application, show the LAVA window and exit the process
    with the event loop's return code once it finishes."""
    qt_app = QApplication(sys.argv)
    viewer = LAVA_Qt()
    viewer.show()
    exit_code = qt_app.exec()
    sys.exit(exit_code)


if __name__ == '__main__':
    LAVA()
File without changes
@@ -0,0 +1,243 @@
1
+ from ..utils import *
2
+ from ..data.simulation import Simulation
3
+ from .postprocess import PostProcess
4
+ from ..data.diagnostic import Diagnostic
5
+
6
class Derivative_Simulation(PostProcess):
    """
    Class to compute the derivative of a diagnostic. Works as a wrapper for the Derivative_Diagnostic class.
    Inherits from PostProcess to ensure all operation overloads work properly.

    Parameters
    ----------
    simulation : Simulation
        The simulation object.
    type : str
        The type of derivative to compute. Options are:
            - 't' for the time derivative.
            - 'x1' for the derivative along the first spatial axis.
            - 'x2' for the derivative along the second spatial axis.
            - 'x3' for the derivative along the third spatial axis.
            - 'xx' for successive derivatives along two spatial axes.
            - 'xt' for the time derivative followed by a spatial derivative.
            - 'tx' for a spatial derivative followed by the time derivative.
    axis : int or tuple
        The axis to compute the derivative. Only used for 'xx', 'xt' and 'tx' types.

    Example
    -------
    >>> sim = Simulation('electrons', 'path/to/simulation')
    >>> derivative = Derivative_Simulation(sim, 'x1')
    >>> deriv_e1_wrt_x1 = derivative['e1']
    """

    def __init__(self, simulation, type, axis=None):
        super().__init__(f"Derivative({type})")
        if not isinstance(simulation, Simulation):
            raise ValueError("Simulation must be a Simulation object.")
        self._simulation = simulation
        self._type = type
        self._axis = axis
        # Caches: derivative diagnostics by quantity name, and per-species
        # handlers by species name.
        self._derivatives_computed = {}
        self._species_handler = {}

    def __getitem__(self, key):
        """Return the derivative for *key*.

        A species name yields a Derivative_Species_Handler; any other key
        yields a Derivative_Diagnostic. Both are created on first access and
        cached.
        """
        if key in self._simulation._species:
            if key not in self._species_handler:
                self._species_handler[key] = Derivative_Species_Handler(self._simulation[key], self._type, self._axis)
            return self._species_handler[key]

        if key not in self._derivatives_computed:
            self._derivatives_computed[key] = Derivative_Diagnostic(diagnostic=self._simulation[key],
                                                                   type=self._type, axis=self._axis)
        return self._derivatives_computed[key]

    def delete_all(self):
        """Drop every cached derivative, including those held per species."""
        self._derivatives_computed = {}
        # Species handlers keep their own derivative caches; release them too.
        self._species_handler = {}

    def delete(self, key):
        """Drop the cached derivative of quantity *key*, if there is one."""
        if key in self._derivatives_computed:
            del self._derivatives_computed[key]
        else:
            print(f"Derivative {key} not found in simulation")

    def process(self, diagnostic):
        """Apply derivative to a diagnostic"""
        return Derivative_Diagnostic(diagnostic, self._type, self._axis)
67
+
68
+
69
class Derivative_Diagnostic(Diagnostic):
    """
    Auxiliar class to compute the derivative of a diagnostic, for it to be similar in behavior to a Diagnostic object.
    Inherits directly from Diagnostic to ensure all operation overloads work properly.

    Parameters
    ----------
    diagnostic : Diagnostic
        The diagnostic object.
    type : str
        The type of derivative to compute. Options are: 't', 'x1', 'x2', 'x3', 'xx', 'xt' and 'tx'.
    axis : int or tuple
        The axis to compute the derivative. Only used for 'xx', 'xt' and 'tx' types.
        Axes are 1-based spatial directions; in the fully loaded array they are
        also the array axes, because array axis 0 is the time axis.

    Methods
    -------
    load_all()
        Load all the data and compute the derivative.
    __getitem__(index)
        Get data at a specific index.

    Example
    -------
    >>> sim = Simulation('electrons', 'path/to/simulation')
    >>> diag = sim['e1']
    >>> derivative = Derivative_Diagnostic(diag, 'x1')
    """

    # Spatial direction (1-based) addressed by each single-axis derivative type
    _SPATIAL_TYPES = {"x1": 1, "x2": 2, "x3": 3}

    def __init__(self, diagnostic, type, axis=None):
        # Initialize using parent's __init__ with the same species
        if hasattr(diagnostic, '_species'):
            super().__init__(simulation_folder=diagnostic._simulation_folder if hasattr(diagnostic, '_simulation_folder') else None,
                             species=diagnostic._species)
        else:
            super().__init__(None)

        self._diag = diagnostic
        self._type = type
        # The requested derivative axis is kept apart from ``_axis``: the loop
        # below copies the diagnostic's own ``_axis`` onto this object, which
        # used to overwrite the caller's argument.
        self._deriv_axis = axis
        if not hasattr(diagnostic, '_axis'):
            # Nothing to mirror; keep exposing the requested axis as before.
            self._axis = axis
        self._data = None
        self._all_loaded = False

        # Copy all relevant attributes from diagnostic
        for attr in ['_dt', '_dx', '_ndump', '_axis', '_nx', '_x', '_grid', '_dim', '_maxiter']:
            if hasattr(diagnostic, attr):
                setattr(self, attr, getattr(diagnostic, attr))

    def load_metadata(self):
        """Copy metadata from original diagnostic to ensure consistency"""
        self._dt = self._diag._dt
        self._dx = self._diag._dx
        self._ndump = self._diag._ndump
        self._axis = self._diag._axis
        self._nx = self._diag._nx
        self._x = self._diag._x
        self._grid = self._diag._grid
        self._dim = self._diag._dim
        self._maxiter = self._diag._maxiter

    def _time_step(self):
        """Return the time between two consecutive dumps (``dt * ndump``)."""
        return self._diag._dt * self._diag._ndump

    def _spacing(self, direction):
        """Return the grid spacing along the 1-based spatial *direction*.

        For 1D diagnostics, or whenever the spacing is a scalar, it is
        returned as is; otherwise it is indexed with ``direction - 1``.
        """
        dx = self._diag._dx
        if getattr(self, '_dim', None) == 1 or np.ndim(dx) == 0:
            return dx
        return dx[direction - 1]

    def _single_deriv_axis(self):
        """Return the derivative axis for 'xt'/'tx', which must be an int."""
        if not isinstance(self._deriv_axis, int):
            raise ValueError("Axis must be an integer.")
        return self._deriv_axis

    def load_all(self):
        """Load all data and compute the derivative"""
        if self._data is not None:
            print("Using cached derivative")
            return self._data

        if not hasattr(self._diag, '_data') or self._diag._data is None:
            self._diag.load_all()

        data = self._diag._data
        kind = self._type

        if kind == "t":
            result = np.gradient(data, self._time_step(), axis=0, edge_order=2)

        elif kind in self._SPATIAL_TYPES:
            # Axis 0 is time, so direction k is array axis k and uses the
            # spacing of that same direction.
            direction = self._SPATIAL_TYPES[kind]
            result = np.gradient(data, self._spacing(direction), axis=direction, edge_order=2)

        elif kind == "xx":
            axes = self._deriv_axis
            if not isinstance(axes, (tuple, list)) or len(axes) != 2:
                raise ValueError("Axis must be a tuple with two elements.")
            first, second = axes
            inner = np.gradient(data, self._spacing(first), axis=first, edge_order=2)
            result = np.gradient(inner, self._spacing(second), axis=second, edge_order=2)

        elif kind == "xt":
            direction = self._single_deriv_axis()
            # Time derivative first, then the spatial one
            inner = np.gradient(data, self._time_step(), axis=0, edge_order=2)
            result = np.gradient(inner, self._spacing(direction), axis=direction, edge_order=2)

        elif kind == "tx":
            direction = self._single_deriv_axis()
            # Spatial derivative first, then the time one
            inner = np.gradient(data, self._spacing(direction), axis=direction, edge_order=2)
            result = np.gradient(inner, self._time_step(), axis=0, edge_order=2)

        else:
            raise ValueError("Invalid derivative type.")

        # Store the result in the cache
        self._all_loaded = True
        self._data = result
        return self._data

    def _data_generator(self, index):
        """Generate data for a specific index on-demand"""
        if self._type in self._SPATIAL_TYPES:
            # A single dump has no time axis, so direction k is array axis k - 1.
            direction = self._SPATIAL_TYPES[self._type]
            yield np.gradient(self._diag[index], self._spacing(direction), axis=direction - 1, edge_order=2)

        elif self._type == "t":
            two_steps = 2 * self._time_step()
            if index == 0:
                # Second-order forward difference at the first dump
                yield (-3 * self._diag[index] + 4 * self._diag[index + 1] - self._diag[index + 2]) / two_steps
            elif index == self._diag._maxiter - 1:
                # Second-order backward difference at the last dump
                yield (3 * self._diag[index] - 4 * self._diag[index - 1] + self._diag[index - 2]) / two_steps
            else:
                # Central difference everywhere else
                yield (self._diag[index + 1] - self._diag[index - 1]) / two_steps
        else:
            raise ValueError("Invalid derivative type. Use 'x1', 'x2', 'x3' or 't'.")

    def __getitem__(self, index):
        """Get data at a specific index"""
        if self._all_loaded and self._data is not None:
            return self._data[index]

        if isinstance(index, (int, np.integer)):
            if index < 0:
                # Count from the end so the boundary stencils pick the right dump
                index += self._diag._maxiter
            return next(self._data_generator(index))

        if isinstance(index, slice):
            # slice.indices resolves missing and negative bounds against the
            # number of dumps and clamps them to the valid range.
            dumps = range(*index.indices(self._diag._maxiter))
            return np.array([next(self._data_generator(i)) for i in dumps])

        raise ValueError("Invalid index type. Use int or slice.")
216
+
217
class Derivative_Species_Handler:
    """
    Per-species front end for derivatives.

    Wraps a species handler and hands out one Derivative_Diagnostic per
    quantity, creating it on first access and reusing it afterwards.

    Not intended to be used directly, but through the Derivative_Simulation class.

    Parameters
    ----------
    species_handler : Species_Handler
        The species handler object.
    type : str
        The type of derivative to compute. Options are: 't', 'x1', 'x2', 'x3', 'xx', 'xt' and 'tx'.
    axis : int or tuple
        The axis to compute the derivative. Only used for 'xx', 'xt' and 'tx' types.
    """

    def __init__(self, species_handler, type, axis=None):
        self._derivatives_computed = {}
        self._species_handler = species_handler
        self._type, self._axis = type, axis

    def __getitem__(self, key):
        """Return the derivative of quantity *key*, building it if not cached."""
        cached = self._derivatives_computed.get(key)
        if cached is None:
            source_diag = self._species_handler[key]
            cached = Derivative_Diagnostic(source_diag, self._type, self._axis)
            self._derivatives_computed[key] = cached
        return cached