osiris-utils 1.1.2__py3-none-any.whl → 1.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,979 @@
1
+ """
2
+ The utilities on data.py are cool but not useful when you want to work with whole data of a simulation instead
3
+ of just a single file. This is what this file is for - deal with ''folders'' of data.
4
+
5
+ Took some inspiration from Diogo and Madox's work.
6
+
7
+ This would be awesome to compute time derivatives.
8
+ """
9
+ import numpy as np
10
+ import os
11
+
12
+ from .data import OsirisGridFile
13
+ import tqdm
14
+ import matplotlib.pyplot as plt
15
+ import warnings
16
+ from typing import Literal
17
+ from ..decks.decks import InputDeckIO, deval
18
+
19
def get_dimension_from_deck(deck: "InputDeckIO") -> int:
    """Infer the spatial dimension (1, 2 or 3) of a simulation from its input deck.

    Probes the deck's ``grid`` section for an ``nx_p(1:D)`` parameter and
    returns the first dimension ``D`` for which the lookup succeeds.

    Parameters
    ----------
    deck : InputDeckIO
        Parsed OSIRIS input deck.

    Returns
    -------
    int
        The grid dimension (1, 2 or 3).

    Raises
    ------
    ValueError
        If no ``nx_p(1:D)`` parameter exists for any D in 1..3.
        (ValueError is a subclass of Exception, so callers catching the
        previous generic Exception still work.)
    """
    for dim in range(1, 4):
        try:
            deck.get_param(section='grid', param=f'nx_p(1:{dim})')
            return dim
        # Only swallow real lookup failures; a bare `except:` would also
        # hide KeyboardInterrupt/SystemExit.
        except Exception:
            continue

    raise ValueError('Error parsing grid dimension')
28
+
29
# Quantity derived from the charge density dumps (rescaled by rqm).
OSIRIS_DENSITY = ["n"]
# Per-species grid reports: charge and current components.
OSIRIS_SPECIE_REPORTS = ["charge", "q1", "q2", "q3", "j1", "j2", "j3"]
# Per-species velocity-distribution (UDIST) moments.
OSIRIS_SPECIE_REP_UDIST = [
    "vfl1", "vfl2", "vfl3",
    "ufl1", "ufl2", "ufl3",
    "P11", "P12", "P13", "P22", "P23", "P33",
    "T11", "T12", "T13", "T22", "T23", "T33",
]
# Field diagnostics. Bug fix: "epart_3" was a typo for "part_e3" (compare
# with its part_e1/part_e2 siblings), so the third particle-field
# component was never recognised as a valid quantity.
OSIRIS_FLD = [
    "e1", "e2", "e3", "b1", "b2", "b3",
    "part_e1", "part_e2", "part_e3", "part_b1", "part_b2", "part_b3",
    "ext_e1", "ext_e2", "ext_e3", "ext_b1", "ext_b2", "ext_b3",
]
# Phase-space diagnostics (non-exhaustive; there may be more).
OSIRIS_PHA = ["p1x1", "p1x2", "p1x3", "p2x1", "p2x2", "p2x3", "p3x1", "p3x2", "p3x3", "gammax1", "gammax2", "gammax3"]
# Every quantity the Diagnostic class knows how to load.
OSIRIS_ALL = OSIRIS_DENSITY + OSIRIS_SPECIE_REPORTS + OSIRIS_SPECIE_REP_UDIST + OSIRIS_FLD + OSIRIS_PHA
54
+
55
def which_quantities():
    """Print the list of quantities that can be requested from a Diagnostic."""
    for line in ("Available quantities:", OSIRIS_ALL):
        print(line)
58
+
59
+
60
class Diagnostic:
    """
    Class to handle diagnostics. This is the "base" class of the code. Diagnostics can be loaded
    from OSIRIS output files, but are also created when performing operations with other
    diagnostics. Post-processed quantities are also considered diagnostics, so operations can be
    performed with them as well.

    Parameters
    ----------
    simulation_folder : str, optional
        The path to the simulation folder (where the input deck is located).
    species : object, optional
        The species to handle the diagnostics (expected to expose ``name`` and ``rqm``).
    input_deck : InputDeckIO, optional
        Parsed input deck, if already available.

    Attributes
    ----------
    species : str
        The species to handle the diagnostics.
    dx : np.ndarray(float) or float
        The grid spacing in each direction (float in 1D, ndarray in 2D/3D).
    nx : np.ndarray(int) or int
        The number of grid points in each direction (int in 1D, ndarray in 2D/3D).
    x : np.ndarray
        The grid points.
    dt : float
        The time step.
    grid : np.ndarray
        The grid boundaries.
    axis : dict
        The axis information: per direction, a dictionary with the keys "name",
        "long_name", "units" and "plot_label".
    units : str
        Units of the diagnostic. May be unavailable for derived/post-processed diagnostics.
    name : str
        Name of the diagnostic. May be unavailable for derived/post-processed diagnostics.
    label : str
        Label of the diagnostic. May be unavailable for derived/post-processed diagnostics.
    dim : int
        The dimension of the diagnostic.
    ndump : int
        The number of steps between dumps.
    maxiter : int
        The maximum number of iterations.
    tunits : str
        The time units.
    path : str
        The path to the diagnostic.
    simulation_folder : str
        The path to the simulation folder.
    all_loaded : bool
        Whether the data is already loaded into memory (avoids loading it multiple times).
    data : np.ndarray
        The diagnostic data; populated only once data is loaded into memory.

    Methods
    -------
    get_quantity(quantity)
        Get the data for a given quantity.
    load_all()
        Load all data into memory.
    load(index)
        Load data for a given index.
    __getitem__(index)
        Get data for a given index without loading it into memory.
    __iter__()
        Iterate over the data without loading it into memory.
    __add__/__sub__/__mul__/__truediv__/__pow__
        Arithmetic between diagnostics and scalars/arrays/diagnostics.
    plot_3d(idx, scale_type="default", boundaries=None)
        Plot a 3D scatter plot of the diagnostic data.
    time(index)
        Get the time for a given index.

    Examples
    --------
    >>> diag = Diagnostic("path/to/simulation", species=electrons)
    >>> diag.get_quantity("charge")
    >>> diag.load_all()
    >>> print(diag.data.shape)
    (100, 100, 100)

    >>> diag = Diagnostic("path/to/simulation", species=electrons)
    >>> diag.get_quantity("charge")
    >>> diag[0]
    array with the data for the first timestep
    """
    def __init__(self, simulation_folder=None, species=None, input_deck=None):
        """Create a diagnostic, optionally bound to a simulation folder.

        Raises
        ------
        FileNotFoundError
            If ``simulation_folder`` is given but is not an existing directory.
        """
        # Preserve the original falsy-to-None normalisation.
        self._species = species if species else None

        # Grid/metadata attributes; filled in by _load_attributes().
        self._dx = None
        self._nx = None
        self._x = None
        self._dt = None
        self._grid = None
        self._axis = None
        self._units = None
        self._name = None
        self._label = None
        self._dim = None
        self._ndump = None
        self._maxiter = None
        self._tunits = None

        if simulation_folder:
            # Validate before storing so a bad path never leaves a half-built object.
            if not os.path.isdir(simulation_folder):
                raise FileNotFoundError(f"Simulation folder {simulation_folder} not found.")
            self._simulation_folder = simulation_folder
        else:
            self._simulation_folder = None

        # Load input deck if available.
        self._input_deck = input_deck if input_deck else None

        self._all_loaded = False
        # Bug fix: initialise _data so `data`, `load_all` and `unload` never
        # hit an AttributeError before anything has been loaded.
        self._data = None
        self._quantity = None
183
+
184
+ def get_quantity(self, quantity):
185
+ """
186
+ Get the data for a given quantity.
187
+
188
+ Parameters
189
+ ----------
190
+ quantity : str
191
+ The quantity to get the data.
192
+ """
193
+ self._quantity = quantity
194
+
195
+ if self._quantity not in OSIRIS_ALL:
196
+ raise ValueError(f"Invalid quantity {self._quantity}. Use which_quantities() to see the available quantities.")
197
+ if self._quantity in OSIRIS_SPECIE_REP_UDIST:
198
+ if self._species is None:
199
+ raise ValueError("Species not set.")
200
+ self._get_moment(self._species.name, self._quantity)
201
+ elif self._quantity in OSIRIS_SPECIE_REPORTS:
202
+ if self._species is None:
203
+ raise ValueError("Species not set.")
204
+ self._get_density(self._species.name, self._quantity)
205
+ elif self._quantity in OSIRIS_FLD:
206
+ self._get_field(self._quantity)
207
+ elif self._quantity in OSIRIS_PHA:
208
+ if self._species is None:
209
+ raise ValueError("Species not set.")
210
+ self._get_phase_space(self._species.name, self._quantity)
211
+ elif self._quantity == "n":
212
+ if self._species is None:
213
+ raise ValueError("Species not set.")
214
+ self._get_density(self._species.name, "charge")
215
+ else:
216
+ raise ValueError(f"Invalid quantity {self._quantity}. Or it's not implemented yet (this may happen for phase space quantities).")
217
+
218
+ def _get_moment(self, species, moment):
219
+ if self._simulation_folder is None:
220
+ raise ValueError("Simulation folder not set. If you're using CustomDiagnostic, this method is not available.")
221
+ self._path = f"{self._simulation_folder}/MS/UDIST/{species}/{moment}/"
222
+ self._file_template = os.listdir(self._path)[0][:-9]
223
+ self._maxiter = len(os.listdir(self._path))
224
+ self._load_attributes(self._file_template, self._input_deck)
225
+
226
+ def _get_field(self, field):
227
+ if self._simulation_folder is None:
228
+ raise ValueError("Simulation folder not set. If you're using CustomDiagnostic, this method is not available.")
229
+ self._path = f"{self._simulation_folder}/MS/FLD/{field}/"
230
+ self._file_template = os.listdir(self._path)[0][:-9]
231
+ self._maxiter = len(os.listdir(self._path))
232
+ self._load_attributes(self._file_template, self._input_deck)
233
+
234
+ def _get_density(self, species, quantity):
235
+ if self._simulation_folder is None:
236
+ raise ValueError("Simulation folder not set. If you're using CustomDiagnostic, this method is not available.")
237
+ self._path = f"{self._simulation_folder}/MS/DENSITY/{species}/{quantity}/"
238
+ self._file_template = os.listdir(self._path)[0][:-9]
239
+ self._maxiter = len(os.listdir(self._path))
240
+ self._load_attributes(self._file_template, self._input_deck)
241
+
242
+ def _get_phase_space(self, species, type):
243
+ if self._simulation_folder is None:
244
+ raise ValueError("Simulation folder not set. If you're using CustomDiagnostic, this method is not available.")
245
+ self._path = f"{self._simulation_folder}/MS/PHA/{type}/{species}/"
246
+ self._file_template = os.listdir(self._path)[0][:-9]
247
+ self._maxiter = len(os.listdir(self._path))
248
+ self._load_attributes(self._file_template, self._input_deck)
249
+
250
+ def _load_attributes(self, file_template, input_deck): # this will be replaced by reading the input deck
251
+ # This can go wrong! NDUMP
252
+ # if input_deck is not None:
253
+ # self._dt = float(input_deck["time_step"][0]["dt"])
254
+ # self._ndump = int(input_deck["time_step"][0]["ndump"])
255
+ # self._dim = get_dimension_from_deck(input_deck)
256
+ # self._nx = np.array(list(map(int, input_deck["grid"][0][f"nx_p(1:{self._dim})"].split(','))))
257
+ # xmin = [deval(input_deck["space"][0][f"xmin(1:{self._dim})"].split(',')[i]) for i in range(self._dim)]
258
+ # xmax = [deval(input_deck["space"][0][f"xmax(1:{self._dim})"].split(',')[i]) for i in range(self._dim)]
259
+ # self._grid = np.array([[xmin[i], xmax[i]] for i in range(self._dim)])
260
+ # self._dx = (self._grid[:,1] - self._grid[:,0])/self._nx
261
+ # self._x = [np.arange(self._grid[i,0], self._grid[i,1], self._dx[i]) for i in range(self._dim)]
262
+
263
+ try:
264
+ path_file1 = os.path.join(self._path, file_template + "000001.h5")
265
+ dump1 = OsirisGridFile(path_file1)
266
+ self._dx = dump1.dx
267
+ self._nx = dump1.nx
268
+ self._x = dump1.x
269
+ self._dt = dump1.dt
270
+ self._grid = dump1.grid
271
+ self._axis = dump1.axis
272
+ self._units = dump1.units
273
+ self._name = dump1.name
274
+ self._label = dump1.label
275
+ self._dim = dump1.dim
276
+ self._ndump = dump1.iter
277
+ self._tunits = dump1.time[1]
278
+ except:
279
+ pass
280
+
281
+ def _data_generator(self, index):
282
+ if self._simulation_folder is None:
283
+ raise ValueError("Simulation folder not set.")
284
+ file = os.path.join(self._path, self._file_template + f"{index:06d}.h5")
285
+ data_object = OsirisGridFile(file)
286
+ yield data_object.data if self._quantity not in OSIRIS_DENSITY else self._species.rqm * data_object.data
287
+
288
+ def load_all(self):
289
+ """
290
+ Load all data into memory (all iterations).
291
+
292
+ Returns
293
+ -------
294
+ data : np.ndarray
295
+ The data for all iterations. Also stored in the attribute data.
296
+ """
297
+ # If data is already loaded, don't do anything
298
+ if self._all_loaded and self._data is not None:
299
+ print("Data already loaded.")
300
+ return self._data
301
+
302
+ # If this is a derived diagnostic without files
303
+ if self._simulation_folder is None:
304
+ # If it has a data generator but no direct files
305
+ try:
306
+ print("This appears to be a derived diagnostic. Loading data from generators...")
307
+ # Get the maximum size from the diagnostic attributes
308
+ if hasattr(self, '_maxiter') and self._maxiter is not None:
309
+ size = self._maxiter
310
+ else:
311
+ # Try to infer from a related diagnostic
312
+ if hasattr(self, '_diag') and hasattr(self._diag, '_maxiter'):
313
+ size = self._diag._maxiter
314
+ else:
315
+ # Default to a reasonable number if we can't determine
316
+ size = 100
317
+ print(f"Warning: Could not determine timestep count, using {size}.")
318
+
319
+ # Load data for all timesteps using the generator - this may take a while
320
+ self._data = np.stack([self[i] for i in tqdm.tqdm(range(size), desc="Loading data")])
321
+ self._all_loaded = True
322
+ return self._data
323
+
324
+ except Exception as e:
325
+ raise ValueError(f"Could not load derived diagnostic data: {str(e)}")
326
+
327
+ # Original implementation for file-based diagnostics
328
+ print("Loading all data from files. This may take a while.")
329
+ size = len(sorted(os.listdir(self._path)))
330
+ self._data = np.stack([self[i] for i in tqdm.tqdm(range(size), desc="Loading data")])
331
+ self._all_loaded = True
332
+ return self._data
333
+
334
+ def unload(self):
335
+ """
336
+ Unload data from memory. This is useful to free memory when the data is not needed anymore.
337
+ """
338
+ print("Unloading data from memory.")
339
+ if self._all_loaded == False:
340
+ print("Data is not loaded.")
341
+ return
342
+ self._data = None
343
+ self._all_loaded = False
344
+
345
+ def load(self, index):
346
+ """
347
+ Load data for a given index into memory. Not recommended. Use load_all for all data or access via generator or index for better performance.
348
+ """
349
+ self._data = next(self._data_generator(index))
350
+
351
    def __getitem__(self, index):
        """Return the data for an int index or a slice of timesteps.

        Resolution order: cached array first, then the lazy generator
        (file-based or derived). Slices are materialised with np.stack,
        so they DO allocate the full selection.
        """
        # For derived diagnostics with cached data
        if self._all_loaded and self._data is not None:
            return self._data[index]

        # For standard diagnostics with files
        if isinstance(index, int):
            if self._simulation_folder is not None and hasattr(self, '_data_generator'):
                return next(self._data_generator(index))

            # For derived diagnostics with custom generators
            if hasattr(self, '_data_generator') and callable(self._data_generator):
                return next(self._data_generator(index))

        elif isinstance(index, slice):
            start = 0 if index.start is None else index.start
            step = 1 if index.step is None else index.step

            # An open-ended slice needs an explicit stop; try the known
            # iteration count first, then the number of dump files on disk.
            if index.stop is None:
                if hasattr(self, '_maxiter') and self._maxiter is not None:
                    stop = self._maxiter
                elif self._simulation_folder is not None and hasattr(self, '_path'):
                    stop = len(sorted(os.listdir(self._path)))
                else:
                    stop = 100 # Default if we can't determine
                    print(f"Warning: Could not determine iteration count for iteration, using {stop}.")
            else:
                stop = index.stop

            indices = range(start, stop, step)
            if self._simulation_folder is not None and hasattr(self, '_data_generator'):
                return np.stack([next(self._data_generator(i)) for i in indices])
            elif hasattr(self, '_data_generator') and callable(self._data_generator):
                return np.stack([next(self._data_generator(i)) for i in indices])

        # If we get here, we don't know how to get data for this index
        raise ValueError(f"Cannot retrieve data for this diagnostic at index {index}. No data loaded and no generator available.")
388
+
389
    def __iter__(self):
        """Yield one timestep at a time without loading everything into memory.

        Iteration sources, in order of preference: dump files on disk,
        the cached array, or the diagnostic's lazy generator.
        """
        # If this is a file-based diagnostic
        if self._simulation_folder is not None:
            for i in range(len(sorted(os.listdir(self._path)))):
                yield next(self._data_generator(i))

        # If this is a derived diagnostic and data is already loaded
        elif self._all_loaded and self._data is not None:
            for i in range(self._data.shape[0]):
                yield self._data[i]

        # If this is a derived diagnostic with custom generator but no loaded data
        elif hasattr(self, '_data_generator') and callable(self._data_generator):
            # Determine how many iterations to go through
            max_iter = self._maxiter
            if max_iter is None:
                # Fall back to the parent diagnostic's count, if any.
                if hasattr(self, '_diag') and hasattr(self._diag, '_maxiter'):
                    max_iter = self._diag._maxiter
                else:
                    max_iter = 100 # Default if we can't determine
                    print(f"Warning: Could not determine iteration count for iteration, using {max_iter}.")

            for i in range(max_iter):
                yield next(self._data_generator(i))

        # If we don't know how to handle this
        else:
            raise ValueError("Cannot iterate over this diagnostic. No data loaded and no generator available.")
417
+
418
+ def __add__(self, other):
419
+ if isinstance(other, (int, float, np.ndarray)):
420
+ result = Diagnostic(species=self._species)
421
+
422
+ for attr in ['_dx', '_nx', '_x', '_dt', '_grid', '_axis', '_dim', '_ndump', '_maxiter']:
423
+ if hasattr(self, attr):
424
+ setattr(result, attr, getattr(self, attr))
425
+
426
+ # Make sure _maxiter is set even for derived diagnostics
427
+ if not hasattr(result, '_maxiter') or result._maxiter is None:
428
+ if hasattr(self, '_maxiter') and self._maxiter is not None:
429
+ result._maxiter = self._maxiter
430
+
431
+ # result._name = self._name + " + " + str(other) if isinstance(other, (int, float)) else self._name + " + np.ndarray"
432
+
433
+ if self._all_loaded:
434
+ result._data = self._data + other
435
+ result._all_loaded = True
436
+ else:
437
+ def gen_scalar_add(original_gen, scalar):
438
+ for val in original_gen:
439
+ yield val + scalar
440
+
441
+ original_generator = self._data_generator
442
+ result._data_generator = lambda index: gen_scalar_add(original_generator(index), other)
443
+
444
+ return result
445
+
446
+ elif isinstance(other, Diagnostic):
447
+ result = Diagnostic(species=self._species)
448
+
449
+ for attr in ['_dx', '_nx', '_x', '_dt', '_grid', '_axis', '_dim', '_ndump', '_maxiter']:
450
+ if hasattr(self, attr):
451
+ setattr(result, attr, getattr(self, attr))
452
+
453
+ if not hasattr(result, '_maxiter') or result._maxiter is None:
454
+ if hasattr(self, '_maxiter') and self._maxiter is not None:
455
+ result._maxiter = self._maxiter
456
+
457
+ # result._name = self._name + " + " + str(other._name)
458
+
459
+ if self._all_loaded:
460
+ other.load_all()
461
+ result._data = self._data + other._data
462
+ result._all_loaded = True
463
+ else:
464
+ def gen_diag_add(original_gen1, original_gen2):
465
+ for val1, val2 in zip(original_gen1, original_gen2):
466
+ yield val1 + val2
467
+
468
+ original_generator = self._data_generator
469
+ other_generator = other._data_generator
470
+ result._data_generator = lambda index: gen_diag_add(original_generator(index), other_generator(index))
471
+
472
+ return result
473
+
474
+ def __sub__(self, other):
475
+ if isinstance(other, (int, float, np.ndarray)):
476
+ result = Diagnostic(species=self._species)
477
+
478
+ for attr in ['_dx', '_nx', '_x', '_dt', '_grid', '_axis', '_dim', '_ndump', '_maxiter']:
479
+ if hasattr(self, attr):
480
+ setattr(result, attr, getattr(self, attr))
481
+
482
+ if not hasattr(result, '_maxiter') or result._maxiter is None:
483
+ if hasattr(self, '_maxiter') and self._maxiter is not None:
484
+ result._maxiter = self._maxiter
485
+
486
+ # result._name = self._name + " - " + str(other) if isinstance(other, (int, float)) else self._name + " - np.ndarray"
487
+
488
+ if self._all_loaded:
489
+ result._data = self._data - other
490
+ result._all_loaded = True
491
+ else:
492
+ def gen_scalar_sub(original_gen, scalar):
493
+ for val in original_gen:
494
+ yield val - scalar
495
+
496
+ original_generator = self._data_generator
497
+ result._data_generator = lambda index: gen_scalar_sub(original_generator(index), other)
498
+
499
+ return result
500
+
501
+ elif isinstance(other, Diagnostic):
502
+
503
+
504
+ result = Diagnostic(species=self._species)
505
+
506
+ for attr in ['_dx', '_nx', '_x', '_dt', '_grid', '_axis', '_dim', '_ndump', '_maxiter']:
507
+ if hasattr(self, attr):
508
+ setattr(result, attr, getattr(self, attr))
509
+
510
+ if not hasattr(result, '_maxiter') or result._maxiter is None:
511
+ if hasattr(self, '_maxiter') and self._maxiter is not None:
512
+ result._maxiter = self._maxiter
513
+
514
+ # result._name = self._name + " - " + str(other._name)
515
+
516
+ if self._all_loaded:
517
+ other.load_all()
518
+ result._data = self._data - other._data
519
+ result._all_loaded = True
520
+ else:
521
+ def gen_diag_sub(original_gen1, original_gen2):
522
+ for val1, val2 in zip(original_gen1, original_gen2):
523
+ yield val1 - val2
524
+
525
+ original_generator = self._data_generator
526
+ other_generator = other._data_generator
527
+ result._data_generator = lambda index: gen_diag_sub(original_generator(index), other_generator(index))
528
+
529
+ return result
530
+
531
+ def __mul__(self, other):
532
+ if isinstance(other, (int, float, np.ndarray)):
533
+ result = Diagnostic(species=self._species)
534
+
535
+ for attr in ['_dx', '_nx', '_x', '_dt', '_grid', '_axis', '_dim', '_ndump', '_maxiter']:
536
+ if hasattr(self, attr):
537
+ setattr(result, attr, getattr(self, attr))
538
+
539
+ if not hasattr(result, '_maxiter') or result._maxiter is None:
540
+ if hasattr(self, '_maxiter') and self._maxiter is not None:
541
+ result._maxiter = self._maxiter
542
+
543
+ # result._name = self._name + " * " + str(other) if isinstance(other, (int, float)) else self._name + " * np.ndarray"
544
+
545
+ if self._all_loaded:
546
+ result._data = self._data * other
547
+ result._all_loaded = True
548
+ else:
549
+ def gen_scalar_mul(original_gen, scalar):
550
+ for val in original_gen:
551
+ yield val * scalar
552
+
553
+ original_generator = self._data_generator
554
+ result._data_generator = lambda index: gen_scalar_mul(original_generator(index), other)
555
+
556
+ return result
557
+
558
+ elif isinstance(other, Diagnostic):
559
+ result = Diagnostic(species=self._species)
560
+
561
+ for attr in ['_dx', '_nx', '_x', '_dt', '_grid', '_axis', '_dim', '_ndump', '_maxiter']:
562
+ if hasattr(self, attr):
563
+ setattr(result, attr, getattr(self, attr))
564
+
565
+ if not hasattr(result, '_maxiter') or result._maxiter is None:
566
+ if hasattr(self, '_maxiter') and self._maxiter is not None:
567
+ result._maxiter = self._maxiter
568
+
569
+ # result._name = self._name + " * " + str(other._name)
570
+
571
+ if self._all_loaded:
572
+ other.load_all()
573
+ result._data = self._data * other._data
574
+ result._all_loaded = True
575
+ else:
576
+ def gen_diag_mul(original_gen1, original_gen2):
577
+ for val1, val2 in zip(original_gen1, original_gen2):
578
+ yield val1 * val2
579
+
580
+ original_generator = self._data_generator
581
+ other_generator = other._data_generator
582
+ result._data_generator = lambda index: gen_diag_mul(original_generator(index), other_generator(index))
583
+
584
+ return result
585
+
586
+ def __truediv__(self, other):
587
+ if isinstance(other, (int, float, np.ndarray)):
588
+ result = Diagnostic(species=self._species)
589
+
590
+ for attr in ['_dx', '_nx', '_x', '_dt', '_grid', '_axis', '_dim', '_ndump', '_maxiter']:
591
+ if hasattr(self, attr):
592
+ setattr(result, attr, getattr(self, attr))
593
+
594
+ if not hasattr(result, '_maxiter') or result._maxiter is None:
595
+ if hasattr(self, '_maxiter') and self._maxiter is not None:
596
+ result._maxiter = self._maxiter
597
+
598
+ # result._name = self._name + " / " + str(other) if isinstance(other, (int, float)) else self._name + " / np.ndarray"
599
+
600
+ if self._all_loaded:
601
+ result._data = self._data / other
602
+ result._all_loaded = True
603
+ else:
604
+ def gen_scalar_div(original_gen, scalar):
605
+ for val in original_gen:
606
+ yield val / scalar
607
+
608
+ original_generator = self._data_generator
609
+ result._data_generator = lambda index: gen_scalar_div(original_generator(index), other)
610
+
611
+ return result
612
+
613
+ elif isinstance(other, Diagnostic):
614
+
615
+ result = Diagnostic(species=self._species)
616
+
617
+ for attr in ['_dx', '_nx', '_x', '_dt', '_grid', '_axis', '_dim', '_ndump', '_maxiter']:
618
+ if hasattr(self, attr):
619
+ setattr(result, attr, getattr(self, attr))
620
+
621
+ if not hasattr(result, '_maxiter') or result._maxiter is None:
622
+ if hasattr(self, '_maxiter') and self._maxiter is not None:
623
+ result._maxiter = self._maxiter
624
+
625
+ # result._name = self._name + " / " + str(other._name)
626
+
627
+ if self._all_loaded:
628
+ other.load_all()
629
+ result._data = self._data / other._data
630
+ result._all_loaded = True
631
+ else:
632
+ def gen_diag_div(original_gen1, original_gen2):
633
+ for val1, val2 in zip(original_gen1, original_gen2):
634
+ yield val1 / val2
635
+
636
+ original_generator = self._data_generator
637
+ other_generator = other._data_generator
638
+ result._data_generator = lambda index: gen_diag_div(original_generator(index), other_generator(index))
639
+
640
+ return result
641
+
642
+ def __pow__(self, other):
643
+ # power by scalar
644
+ if isinstance(other, (int, float)):
645
+ result = Diagnostic(species=self._species)
646
+
647
+ for attr in ['_dx', '_nx', '_x', '_dt', '_grid', '_axis', '_dim', '_ndump', '_maxiter']:
648
+ if hasattr(self, attr):
649
+ setattr(result, attr, getattr(self, attr))
650
+
651
+ if not hasattr(result, '_maxiter') or result._maxiter is None:
652
+ if hasattr(self, '_maxiter') and self._maxiter is not None:
653
+ result._maxiter = self._maxiter
654
+
655
+ # result._name = self._name + " ^(" + str(other) + ")"
656
+ # result._label = self._label + rf"$ ^{other}$"
657
+
658
+ if self._all_loaded:
659
+ result._data = self._data ** other
660
+ result._all_loaded = True
661
+ else:
662
+ def gen_scalar_pow(original_gen, scalar):
663
+ for val in original_gen:
664
+ yield val ** scalar
665
+
666
+ original_generator = self._data_generator
667
+ result._data_generator = lambda index: gen_scalar_pow(original_generator(index), other)
668
+
669
+ return result
670
+
671
+ # power by another diagnostic
672
+ elif isinstance(other, Diagnostic):
673
+ raise ValueError("Power by another diagnostic is not supported. Why would you do that?")
674
+
675
+ def __radd__(self, other):
676
+ return self + other
677
+
678
+ def __rsub__(self, other): # I don't know if this is correct because I'm not sure if the order of the subtraction is correct
679
+ return - self + other
680
+
681
+ def __rmul__(self, other):
682
+ return self * other
683
+
684
+ def __rtruediv__(self, other): # division is not commutative
685
+ if isinstance(other, (int, float, np.ndarray)):
686
+ result = Diagnostic(species=self._species)
687
+
688
+ for attr in ['_dx', '_nx', '_x', '_dt', '_grid', '_axis', '_dim', '_ndump', '_maxiter']:
689
+ if hasattr(self, attr):
690
+ setattr(result, attr, getattr(self, attr))
691
+
692
+ if not hasattr(result, '_maxiter') or result._maxiter is None:
693
+ if hasattr(self, '_maxiter') and self._maxiter is not None:
694
+ result._maxiter = self._maxiter
695
+
696
+ # result._name = str(other) + " / " + self._name if isinstance(other, (int, float)) else "np.ndarray / " + self._name
697
+
698
+ if self._all_loaded:
699
+ result._data = other / self._data
700
+ result._all_loaded = True
701
+ else:
702
+ def gen_scalar_rdiv(scalar, original_gen):
703
+ for val in original_gen:
704
+ yield scalar / val
705
+
706
+ original_generator = self._data_generator
707
+ result._data_generator = lambda index: gen_scalar_rdiv(other, original_generator(index))
708
+
709
+ return result
710
+
711
+ elif isinstance(other, Diagnostic):
712
+
713
+ result = Diagnostic(species=self._species)
714
+
715
+ for attr in ['_dx', '_nx', '_x', '_dt', '_grid', '_axis', '_dim', '_ndump', '_maxiter']:
716
+ if hasattr(self, attr):
717
+ setattr(result, attr, getattr(self, attr))
718
+
719
+ if not hasattr(result, '_maxiter') or result._maxiter is None:
720
+ if hasattr(self, '_maxiter') and self._maxiter is not None:
721
+ result._maxiter = self._maxiter
722
+
723
+ # result._name = str(other._name) + " / " + self._name
724
+
725
+ if self._all_loaded:
726
+ other.load_all()
727
+ result._data = other._data / self._data
728
+ result._all_loaded = True
729
+ else:
730
+ def gen_diag_div(original_gen1, original_gen2):
731
+ for val1, val2 in zip(original_gen1, original_gen2):
732
+ yield val2 / val1
733
+
734
+ original_generator = self._data_generator
735
+ other_generator = other._data_generator
736
+ result._data_generator = lambda index: gen_diag_div(original_generator(index), other_generator(index))
737
+
738
+ return result
739
+
740
    def plot_3d(self, idx, scale_type: Literal["zero_centered", "pos", "neg", "default"] = "default", boundaries: np.ndarray = None):
        """
        Plots a 3D scatter plot of the diagnostic data (grid data).

        Parameters
        ----------
        idx : int
            Index of the data to plot.
        scale_type : Literal["zero_centered", "pos", "neg", "default"], optional
            Type of scaling for the colormap:
            - "zero_centered": Center colormap around zero.
            - "pos": Colormap for positive values.
            - "neg": Colormap for negative values.
            - "default": Standard colormap.
        boundaries : np.ndarray, optional
            Boundaries to plot part of the data. (3,2) If None, uses the default grid boundaries.

        Returns
        -------
        fig : matplotlib.figure.Figure
            The figure object containing the plot.
        ax : matplotlib.axes._subplots.Axes3DSubplot
            The 3D axes object of the plot.

        Example
        -------
        sim = ou.Simulation("electrons", "path/to/simulation")
        fig, ax = sim["b3"].plot_3d(55, scale_type="zero_centered", boundaries= [[0, 40], [0, 40], [0, 20]])
        plt.show()
        """


        if self._dim != 3:
            raise ValueError("This method is only available for 3D diagnostics.")

        if boundaries is None:
            boundaries = self._grid

        # Coerce user-supplied boundaries to an ndarray; fall back to the full
        # grid if the input cannot be interpreted.
        if not isinstance(boundaries, np.ndarray):
            try :
                boundaries = np.array(boundaries)
            # NOTE(review): bare except — also hides KeyboardInterrupt; consider
            # narrowing to Exception.
            except:
                boundaries = self._grid
                warnings.warn("boundaries cannot be accessed as a numpy array with shape (3, 2), using default instead")

        if boundaries.shape != (3, 2):
            warnings.warn("boundaries should have shape (3, 2), using default instead")
            boundaries = self._grid

        # Load data
        if self._all_loaded:
            data = self._data[idx]
        else:
            data = self[idx]

        # Full 3D coordinate grid; "ij" indexing keeps axis order aligned with data.
        X, Y, Z = np.meshgrid(self._x[0], self._x[1], self._x[2], indexing="ij")

        # Flatten arrays for scatter plot
        X_flat, Y_flat, Z_flat, = X.ravel(), Y.ravel(), Z.ravel()
        data_flat = data.ravel()

        # Apply filter: Keep only chosen points
        mask = (X_flat > boundaries[0][0]) & (X_flat < boundaries[0][1]) & (Y_flat > boundaries[1][0]) & (Y_flat < boundaries[1][1]) & (Z_flat > boundaries[2][0]) & (Z_flat < boundaries[2][1])
        X_cut, Y_cut, Z_cut, data_cut = X_flat[mask], Y_flat[mask], Z_flat[mask], data_flat[mask]

        # NOTE(review): color limits below are computed from the FULL domain
        # (data_flat), not the cropped selection (data_cut) — confirm intended.
        if scale_type == "zero_centered":
            # Center colormap around zero
            cmap = "seismic"
            vmax = np.max(np.abs(data_flat)) # Find max absolute value
            vmin = -vmax
        elif scale_type == "pos":
            cmap = "plasma"
            vmax = np.max(data_flat)
            vmin = 0

        elif scale_type == "neg":
            cmap = "plasma"
            vmax = 0
            vmin = np.min(data_flat)
        else:
            cmap = "plasma"
            vmax = np.max(data_flat)
            vmin = np.min(data_flat)

        norm = plt.Normalize(vmin=vmin, vmax=vmax)

        # Plot
        fig = plt.figure(figsize=(10, 7))
        ax = fig.add_subplot(111, projection="3d")

        # Scatter plot with the colormap chosen above
        sc = ax.scatter(X_cut, Y_cut, Z_cut, c=data_cut, cmap=cmap, norm=norm, alpha=1)

        # Set limits to maintain full background
        ax.set_xlim(*self._grid[0])
        ax.set_ylim(*self._grid[1])
        ax.set_zlim(*self._grid[2])

        # Colorbar
        cbar = plt.colorbar(sc, ax=ax, shrink=0.6)

        # Labels
        # TODO try to use a latex label instead of _name
        cbar.set_label(r"${}$".format(self._name) + r"$\ [{}]$".format(self._units))
        ax.set_title(r"$t={:.2f}$".format(self.time(idx)[0]) + r"$\ [{}]$".format(self.time(idx)[1]))
        ax.set_xlabel(r"${}$".format(self.axis[0]["long_name"]) + r"$\ [{}]$".format(self.axis[0]["units"]))
        ax.set_ylabel(r"${}$".format(self.axis[1]["long_name"]) + r"$\ [{}]$".format(self.axis[1]["units"]))
        ax.set_zlabel(r"${}$".format(self.axis[2]["long_name"]) + r"$\ [{}]$".format(self.axis[2]["units"]))

        return fig, ax
850
+
851
+ # Getters
852
+ @property
853
+ def data(self):
854
+ if self._data is None:
855
+ raise ValueError("Data not loaded into memory. Use get_* method with load_all=True or access via generator/index.")
856
+ return self._data
857
+
858
    @property
    def dx(self):
        """Grid spacing per direction (float in 1D, ndarray in 2D/3D)."""
        return self._dx

    @property
    def nx(self):
        """Number of grid points per direction (int in 1D, ndarray in 2D/3D)."""
        return self._nx

    @property
    def x(self):
        """Grid point coordinates along each axis."""
        return self._x

    @property
    def dt(self):
        """Simulation time step."""
        return self._dt

    @property
    def grid(self):
        """Grid boundaries (one [min, max] pair per direction)."""
        return self._grid

    @property
    def axis(self):
        """Axis metadata: per-direction dicts with name/long_name/units/plot_label."""
        return self._axis

    @property
    def units(self):
        """Units of the diagnostic quantity (may be None for derived diagnostics)."""
        return self._units

    @property
    def tunits(self):
        """Time units string."""
        return self._tunits

    @property
    def name(self):
        """Name of the diagnostic (may be None for derived diagnostics)."""
        return self._name

    @property
    def dim(self):
        """Spatial dimension of the diagnostic (1, 2 or 3)."""
        return self._dim
897
+
898
+ @property
899
+ def path(self):
900
+ return self
901
+
902
    @property
    def simulation_folder(self):
        """Path to the simulation folder, or None for derived diagnostics."""
        return self._simulation_folder

    @property
    def ndump(self):
        """Number of simulation steps between dumps."""
        return self._ndump

    @property
    def all_loaded(self):
        """True once load_all() has cached every timestep in memory."""
        return self._all_loaded

    @property
    def maxiter(self):
        """Number of available iterations (dump files)."""
        return self._maxiter

    @property
    def label(self):
        """Plot label of the diagnostic (may be None for derived diagnostics)."""
        return self._label

    @property
    def quantity(self):
        """The quantity selected via get_quantity(), or None."""
        return self._quantity
925
+
926
+ def time(self, index):
927
+ return [index * self._dt * self._ndump, self._tunits]
928
+
929
    # Setters: metadata and data can be overridden by post-processing code
    # that assembles derived diagnostics by hand.
    @dx.setter
    def dx(self, value):
        self._dx = value

    @nx.setter
    def nx(self, value):
        self._nx = value

    @x.setter
    def x(self, value):
        self._x = value

    @dt.setter
    def dt(self, value):
        self._dt = value

    @grid.setter
    def grid(self, value):
        self._grid = value

    @axis.setter
    def axis(self, value):
        self._axis = value

    @units.setter
    def units(self, value):
        self._units = value

    @tunits.setter
    def tunits(self, value):
        self._tunits = value

    @name.setter
    def name(self, value):
        self._name = value

    @dim.setter
    def dim(self, value):
        self._dim = value

    @ndump.setter
    def ndump(self, value):
        self._ndump = value

    @data.setter
    def data(self, value):
        self._data = value

    @quantity.setter
    def quantity(self, key):
        self._quantity = key