fargopy 0.4.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fargopy/fields.py CHANGED
@@ -6,8 +6,10 @@ import fargopy
6
6
  ###############################################################
7
7
  # Required packages
8
8
  ###############################################################
9
+ import os
9
10
  import numpy as np
10
11
  import re
12
+ import re
11
13
  import pandas as pd
12
14
 
13
15
  import matplotlib.pyplot as plt
@@ -21,7 +23,7 @@ from scipy.interpolate import LinearNDInterpolator
21
23
  from scipy.spatial import cKDTree
22
24
 
23
25
 
24
- from joblib import Parallel, delayed
26
+ from joblib import Parallel, delayed, parallel_config
25
27
 
26
28
 
27
29
  from ipywidgets import interact, FloatSlider, IntSlider
@@ -31,14 +33,17 @@ from IPython.display import HTML, Video
31
33
  from scipy.interpolate import griddata
32
34
  from scipy.integrate import solve_ivp
33
35
  from tqdm import tqdm
36
+ from pathlib import Path
37
+ import fargopy as fp
38
+ from scipy.ndimage import gaussian_filter
34
39
 
35
40
  ###############################################################
36
41
  # Constants
37
42
  ###############################################################
38
43
  # Map of coordinates into FARGO3D coordinates
39
- """This dictionary maps the coordinates regular names (r, phi, theta, etc.) of
40
- different coordinate systems into the FARGO3D x, y, z
41
- """
44
+ # This dictionary maps the coordinates regular names (r, phi, theta, etc.) of
45
+ # different coordinate systems into the FARGO3D x, y, z
46
+
42
47
  COORDS_MAP = dict(
43
48
  cartesian = dict(x='x',y='y',z='z'),
44
49
  cylindrical = dict(phi='x',r='y',z='z'),
@@ -49,21 +54,54 @@ COORDS_MAP = dict(
49
54
  # Classes
50
55
  ###############################################################
51
56
  class Field(fargopy.Fargobj):
52
- """Fields:
53
-
54
- Attributes:
55
- coordinates: type of coordinates (cartesian, cylindrical, spherical)
56
- data: numpy arrays with data of the field
57
-
58
- Methods:
59
- slice: get an slice of a field along a given spatial direction.
60
- Examples:
61
- >>> density.slice(r=0.5) # Take the closest slice to r = 0.5
62
- >>> density.slice(ir=20) # Take the slice through the 20 shell
63
- >>> density.slice(phi=30*RAD,interp='nearest') # Take a slice interpolating to the nearest
57
+ """Represents a simulation field (scalar or vector) with coordinate system and domain information.
58
+
59
+ The ``Field`` object encapsulates the data arrays and associated
60
+ coordinate meshes for a specific simulation variable. It supports
61
+ slicing, coordinate transformation and simple visualization
62
+ helpers.
63
+
64
+ Attributes
65
+ ----------
66
+ data : np.ndarray
67
+ Numpy array containing the physical data.
68
+ coordinates : str
69
+ Coordinate system type ('cartesian', 'cylindrical', 'spherical').
70
+ domains : object
71
+ Object containing domain-specific geometry (e.g., mesh arrays).
72
+ type : str
73
+ Field type ('scalar' or 'vector').
74
+
75
+ Examples
76
+ --------
77
+ Load a field from a simulation object (returns a Field instance
78
+ if interpolation is disabled or a FieldInterpolator otherwise):
79
+
80
+ >>> fp.Simulation.load_field(fields='gasdens', snapshot=0, interpolate=False)
81
+
82
+ Access data and mesh:
83
+
84
+ >>> rho = field.data
85
+ >>> xmesh = field.mesh.x
64
86
  """
65
87
 
66
88
  def __init__(self,data=None,coordinates='cartesian',domains=None,type='scalar',**kwargs):
89
+ """
90
+ Initialize a Field object.
91
+
92
+ Parameters
93
+ ----------
94
+ data : np.ndarray, optional
95
+ Field data array.
96
+ coordinates : str, optional
97
+ Coordinate system ('cartesian', 'cylindrical', 'spherical').
98
+ domains : object, optional
99
+ Domain information for each coordinate.
100
+ type : str, optional
101
+ Field type ('scalar' or 'vector').
102
+ **kwargs : dict
103
+ Additional keyword arguments.
104
+ """
67
105
  super().__init__(**kwargs)
68
106
  self.data = data
69
107
  self.coordinates = coordinates
@@ -71,9 +109,34 @@ class Field(fargopy.Fargobj):
71
109
  self.type = type
72
110
 
73
111
  def meshslice(self,slice=None,component=0,verbose=False):
74
- """Perform a slice on a field and produce as an output the
75
- corresponding field slice and the associated matrices of
76
- coordinates for plotting.
112
+ """Perform a slice on a field and produce the corresponding field slice and
113
+ associated coordinate matrices for plotting.
114
+
115
+ Parameters
116
+ ----------
117
+ slice : str
118
+ Slice definition string (e.g., 'z=0').
119
+ component : int, optional
120
+ Component index for vector fields (default: 0).
121
+ verbose : bool, optional
122
+ If True, print debug information.
123
+
124
+ Returns
125
+ -------
126
+ tuple
127
+ (sliced field, mesh dictionary with coordinates). The mesh dictionary
128
+ contains coordinate arrays (x, y, z, r, phi, theta) corresponding
129
+ to the slice.
130
+
131
+ Examples
132
+ --------
133
+ Slice a field at z=0:
134
+
135
+ >>> field_slice, mesh = field.meshslice(slice='z=0')
136
+
137
+ Plot the slice:
138
+
139
+ >>> plt.pcolormesh(mesh.x, mesh.y, field_slice)
77
140
  """
78
141
  # Analysis of the slice
79
142
  if slice is None:
@@ -83,7 +146,7 @@ class Field(fargopy.Fargobj):
83
146
  slice = slice.replace('deg','*fargopy.DEG')
84
147
 
85
148
  # Perform the slice
86
- slice_cmd = f"self.slice({slice},pattern=True,verbose={verbose})"
149
+ slice_cmd = f"self._slice({slice},pattern=True,verbose={verbose})"
87
150
  slice,pattern = eval(slice_cmd)
88
151
 
89
152
  # Create the mesh
@@ -122,36 +185,36 @@ class Field(fargopy.Fargobj):
122
185
 
123
186
  return slice,mesh
124
187
 
125
- def slice(self,verbose=False,pattern=False,**kwargs):
126
- """Extract an slice of a 3-dimensional FARGO3D field
127
-
128
- Parameters:
129
- quiet: boolean, default = False:
130
- If True extract the slice quietly.
131
- Else, print some control messages.
132
-
133
- pattern: boolean, default = False:
134
- If True return the pattern of the slice, eg. [:,:,:]
135
-
136
- ir, iphi, itheta, ix, iy, iz: string or integer:
137
- Index or range of indexes of the corresponding coordinate.
138
-
139
- r, phi, theta, x, y, z: float/list/tuple:
140
- Value for slicing. The slicing search for the closest
141
- value in the domain.
142
-
143
- Returns:
144
- slice: sliced field.
145
-
146
- Examples:
147
- # 0D: Get the value of the field in iphi = 0, itheta = -1 and close to r = 0.82
148
- >>> gasvz.slice(iphi=0,itheta=-1,r=0.82)
149
-
150
- # 1D: Get all values of the field in radial direction at iphi = 0, itheta = -1
151
- >>> gasvz.slice(iphi=0,itheta=-1)
152
-
153
- # 2D: Get all values of the field for values close to phi = 0
154
- >>> gasvz.slice(phi=0)
188
+ def _slice(self,verbose=False,pattern=False,**kwargs):
189
+ """
190
+ Extract a slice of a 3-dimensional FARGO3D field.
191
+
192
+ Parameters
193
+ ----------
194
+ verbose : bool, optional
195
+ If True, print debug information.
196
+ pattern : bool, optional
197
+ If True, return the pattern of the slice (e.g., [:,:,:]).
198
+ ir, iphi, itheta, ix, iy, iz : int or str, optional
199
+ Index or range of indexes for the corresponding coordinate.
200
+ r, phi, theta, x, y, z : float, list, or tuple, optional
201
+ Value or range for slicing. The closest value in the domain is used.
202
+
203
+ Returns
204
+ -------
205
+ np.ndarray or tuple
206
+ Sliced field, and optionally the pattern string if pattern=True.
207
+
208
+ Examples
209
+ --------
210
+ # 0D: Get the value of the field at iphi=0, itheta=-1, and close to r=0.82
211
+ >>> gasvz._slice(iphi=0, itheta=-1, r=0.82)
212
+
213
+ # 1D: Get all values in radial direction at iphi=0, itheta=-1
214
+ >>> gasvz._slice(iphi=0, itheta=-1)
215
+
216
+ # 2D: Get all values for values close to phi=0
217
+ >>> gasvz._slice(phi=0)
155
218
  """
156
219
  # By default slice
157
220
  ivar = dict(x=':',y=':',z=':')
@@ -218,6 +281,20 @@ class Field(fargopy.Fargobj):
218
281
  return slice
219
282
 
220
283
  def to_cartesian(self):
284
+ """
285
+ Convert the field to cartesian coordinates.
286
+
287
+ Returns
288
+ -------
289
+ Field or tuple of Field
290
+ The field in cartesian coordinates. For scalar fields, returns the field itself.
291
+ For vector fields, returns a tuple (vx, vy, vz).
292
+
293
+ Examples
294
+ --------
295
+ >>> v = sim.load_field('gasv', snapshot=0)
296
+ >>> vx, vy, vz = v.to_cartesian()
297
+ """
221
298
  if self.type == 'scalar':
222
299
  # Scalar fields are invariant under coordinate transformations
223
300
  return self
@@ -257,187 +334,580 @@ class Field(fargopy.Fargobj):
257
334
  Field(vz,coordinates=self.coordinates,domains=self.domains,type='scalar'))
258
335
 
259
336
  def get_size(self):
337
+ """
338
+ Return the size of the field data in megabytes (MB).
339
+
340
+ Returns
341
+ -------
342
+ float
343
+ Size in MB.
344
+ """
260
345
  return self.data.nbytes/1024**2
261
346
 
262
347
  def __str__(self):
348
+ """
349
+ String representation of the field data.
350
+
351
+ Returns
352
+ -------
353
+ str
354
+ """
263
355
  return str(self.data)
264
356
 
265
357
  def __repr__(self):
358
+ """
359
+ String representation of the field data.
360
+
361
+ Returns
362
+ -------
363
+ str
364
+ """
266
365
  return str(self.data)
267
366
 
268
367
 
368
+ # ###############################################################
369
+ # FieldInterpolator
370
+ # ###############################################################
371
+ # This class is used to load and interpolate fields from a FARGO3D simulation.
372
+ # It provides methods to load data, create meshes, and perform interpolation.
373
+ # It also handles 2D and 3D data loading based on the provided parameters.
374
+ #################################################################
375
+
376
+
269
377
  class FieldInterpolator:
270
- def __init__(self, sim):
378
+ """Loads and interpolates fields from a FARGO3D simulation.
379
+
380
+ The ``FieldInterpolator`` facilitates loading, slicing, and interpolating
381
+ simulation data across multiple snapshots and fields. It handles coordinate
382
+ transformations and dimensionality reduction based on slice definitions.
383
+
384
+ Attributes
385
+ ----------
386
+ sim : Simulation
387
+ The simulation object.
388
+ df : pd.DataFrame or None
389
+ DataFrame containing loaded field data organized by snapshot and time.
390
+ snapshot_time_table : pd.DataFrame or None
391
+ Table mapping snapshots to normalized time.
392
+ snapshot : list or None
393
+ List of loaded snapshots.
394
+ slice : str or None
395
+ Slice definition string used to load the data.
396
+ dim : int or None
397
+ Dimensionality of the loaded data (e.g., 2 for a 2D slice).
398
+
399
+ Examples
400
+ --------
401
+ Load multiple fields from a snapshot with interpolation enabled:
402
+
403
+ >>> data = sim.load_field(fields=['gasdens', 'gasv'], snapshot=4)
404
+
405
+ Load a specific slice:
406
+
407
+ >>> dens = sim.load_field(fields='gasdens', slice='r=[0.8,1.2],phi=[-25 deg,25 deg],theta=1.56', snapshot=4)
408
+ """
409
+
410
+ def __init__(self, sim, df=None):
411
+ """
412
+ Initialize a FieldInterpolator.
413
+
414
+ Parameters
415
+ ----------
416
+ sim : Simulation
417
+ The simulation object.
418
+ df : pd.DataFrame, optional
419
+ DataFrame with preloaded field data.
420
+ """
271
421
  self.sim = sim
272
422
  self.snapshot_time_table = None
423
+ self.snapshot = None
424
+ self.slice = None
425
+ self.dim=None
426
+ self.df = df
427
+ self._domain_limits = None
428
+ self._df_sorted_cache = None
429
+ self._slice_type = None
430
+ self._slice_ranges = None
431
+
432
+ def _reset_caches(self):
433
+ """Clear cached dataframe and slice metadata before loading or evaluating new data."""
434
+ self._df_sorted_cache = None
435
+ self._slice_type = None
436
+ self._slice_ranges = None
437
+
438
+ def _cache_domain_limits(self):
439
+ """Cache domain extrema for r, theta, and phi to avoid repeated property access."""
440
+ if self._domain_limits is not None:
441
+ return
442
+ dom = self.sim.domains
443
+ self._domain_limits = dict(
444
+ r=(dom.r.min(), dom.r.max()),
445
+ theta=(dom.theta.min(), dom.theta.max()),
446
+ phi=(dom.phi.min(), dom.phi.max())
447
+ )
448
+
449
+ def _detect_slice_type(self, slice_str):
450
+ """Return the canonical slice type ('theta', 'phi', 'r', or None) inferred from the user string."""
451
+ if not slice_str:
452
+ return None
453
+ txt = slice_str.replace(" ", "").lower()
454
+ m_theta = re.search(r"theta=([^\[\],]+)(?![\]])", txt)
455
+ m_phi = re.search(r"phi=([^\[\],]+)(?![\]])", txt)
456
+ if m_theta and not re.search(r"theta=\[", txt) and m_phi and not re.search(r"phi=\[", txt):
457
+ return "r"
458
+ if m_theta and not re.search(r"theta=\[", txt):
459
+ return "theta"
460
+ if m_phi and not re.search(r"phi=\[", txt):
461
+ return "phi"
462
+ return None
463
+
464
+ def _parse_slice_ranges(self, slice_str):
465
+ """Parse the slice expression into numeric (r, theta, phi, z) bounds expressed in radians when needed."""
466
+ ranges = {"r": None, "theta": None, "phi": None, "z": None}
467
+ if not slice_str:
468
+ return ranges
469
+ txt = slice_str.lower()
470
+
471
+ def _to_float(value):
472
+ value = value.strip()
473
+ match = re.match(r"(-?\d+(?:\.\d+)?)\s*deg", value)
474
+ return np.deg2rad(float(match.group(1))) if match else float(value)
475
+
476
+ range_pattern = re.compile(r"(r|theta|phi|z)=\[(.+?)\]")
477
+ value_pattern = re.compile(r"(r|theta|phi|z)=([^\[\],]+)")
478
+
479
+ for key, vals in range_pattern.findall(txt):
480
+ lo, hi = [_to_float(v) for v in vals.split(",")]
481
+ ranges[key] = (min(lo, hi), max(lo, hi))
482
+ for key, val in value_pattern.findall(txt):
483
+ if ranges.get(key) is None:
484
+ parsed = _to_float(val)
485
+ ranges[key] = (parsed, parsed)
486
+ return ranges
487
+
488
+ def _get_sorted_dataframe(self, dataframe):
489
+ """Return the dataframe sorted by normalized time, reusing a cached copy when possible."""
490
+ if self._df_sorted_cache and self._df_sorted_cache[0] is dataframe:
491
+ return self._df_sorted_cache[1]
492
+ df_sorted = dataframe.sort_values("time")
493
+ self._df_sorted_cache = (dataframe, df_sorted)
494
+ return df_sorted
495
+
496
+ def __getattr__(self, name):
497
+ """
498
+ Delegate attribute access to the internal DataFrame if present.
499
+
500
+ Parameters
501
+ ----------
502
+ name : str
503
+ Attribute name.
504
+
505
+ Returns
506
+ -------
507
+ object
508
+ Attribute from the DataFrame if available.
509
+
510
+ Raises
511
+ ------
512
+ AttributeError
513
+ If the attribute is not found.
514
+ """
515
+ df = object.__getattribute__(self, 'df')
516
+ if df is not None and hasattr(df, name):
517
+ return getattr(df, name)
518
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
519
+
520
+ def _run_parallel(self, tasks, backend='threading'):
521
+ tasks = list(tasks)
522
+ if not tasks:
523
+ return []
524
+ with parallel_config(n_jobs=-1, prefer='threads'):
525
+ return Parallel(backend=backend)(tasks)
526
+
527
+ def load_data(self, fields=None, slice=None, snapshots=1, cut=None, coords='cartesian'):
528
+ """Load one or multiple fields into a unified DataFrame.
529
+
530
+ This method loads simulation data for the specified fields and snapshots,
531
+ potentially applying a spatial slice or cut. The data is stored in
532
+ an internal DataFrame (`self.df`) for further processing or interpolation.
533
+
534
+ Parameters
535
+ ----------
536
+ fields : str or list of str
537
+ Name(s) of the fields to load (e.g., 'gasdens', 'gasv').
538
+ slice : str, optional
539
+ Slice definition (e.g., 'z=0', 'theta=1.57').
540
+ snapshots : int or list of int, optional
541
+ Snapshot number(s) to load. Can be a single integer or a range [start, end].
542
+ cut : list, optional
543
+ Geometric cut parameters (sphere or cylinder mask).
544
+ coords : str, optional
545
+ Coordinate system for vector transformation ('cartesian' by default).
546
+
547
+ Returns
548
+ -------
549
+ pd.DataFrame
550
+ The DataFrame containing the loaded data.
551
+
552
+ Examples
553
+ --------
554
+ Load density and velocity for snapshot 10:
555
+
556
+ >>> loader = fp.FieldInterpolator(sim)
557
+ >>> df = loader.load_data(fields=['gasdens', 'gasv'], snapshots=10)
558
+
559
+ Load a 2D slice at z=0 (midplane):
560
+
561
+ >>> df_slice = loader.load_data(fields='gasdens', slice='z=0', snapshots=10)
562
+ """
273
563
 
274
- def load_data(self, field=None, slice=None, snapshots=None):
275
- self.field = field
276
- self.slice=slice
277
-
278
- # Convert a single snapshot to a list
564
+ # -------------------------
565
+ # Validate arguments
566
+ # -------------------------
567
+ self._reset_caches()
568
+
569
+ if fields is None:
570
+ raise ValueError("You must specify at least one field.")
571
+ if isinstance(fields, str):
572
+ fields = [fields]
573
+
574
+ self.fields = fields
575
+ self.slice = slice
576
+
577
+ # Convert snapshot into list
279
578
  if isinstance(snapshots, int):
280
579
  snapshots = [snapshots]
580
+ self.snapshot = snapshots
581
+
582
+ # Detect dimensionality from the sliced data (if a slice is provided)
583
+ if slice is not None:
584
+ test_field = self.sim._load_field_raw('gasdens', snapshot=snapshots[0], field_type='scalar')
585
+ try:
586
+ data_slice, mesh = test_field.meshslice(slice=slice)
587
+ self.dim = len(np.array(data_slice).shape)
588
+ except Exception:
589
+ # Fallback: assume full 3D
590
+ self.dim = 3
591
+ else:
592
+ self.dim = 3
281
593
 
282
-
283
- # Handle the case where snapshots is a single value or a list with one value
594
+ # Snapshot & time arrays
284
595
  if len(snapshots) == 1:
285
-
286
596
  snaps = snapshots
287
- time_values = [0] # Single snapshot corresponds to a single time value
597
+ time_values = [0]
288
598
  else:
289
599
  snaps = np.arange(snapshots[0], snapshots[1] + 1)
290
600
  time_values = np.linspace(0, 1, len(snaps))
291
601
 
292
- # Guarda la tabla como DataFrame
602
+ # Store snapshot-time table
293
603
  self.snapshot_time_table = pd.DataFrame({
294
604
  "Snapshot": snaps,
295
605
  "Normalized_time": time_values
296
606
  })
297
607
 
298
- """
299
- Loads data in 2D or 3D depending on the provided parameters.
300
-
301
- Parameters:
302
- field (list of str, optional): List of fields to load (e.g., ["gasdens", "gasv"]).
303
- slice (str, optional): Slice definition, e.g., "phi=0", "theta=45", or "z=0,r=[0.8,1.2],phi=[-10 deg,10 deg]".
304
- snapshots (list or int, optional): List of snapshot indices or a single snapshot to load. Required for both 2D and 3D.
305
- Returns:
306
- pd.DataFrame: DataFrame containing the loaded data.
307
- """
308
- if field is None:
309
- raise ValueError("You must specify at least one field to load using the 'fields' parameter.")
310
-
311
- # Validate and parse the slice parameter
312
- slice_type = None
313
- if slice:
314
- slice = slice.lower() # Normalize to lowercase for consistency
315
- if "theta" in slice:
316
- slice_type = "theta"
317
- elif "phi" in slice:
318
- slice_type = "phi"
319
- else:
320
- raise ValueError("The 'slice' parameter must contain 'theta' or 'phi'.")
321
-
322
- if not isinstance(snapshots, (int, list, tuple)):
323
- raise ValueError("'snapshots' must be an integer, a list, or a tuple.")
324
-
325
- if isinstance(snapshots, (list, tuple)) and len(snapshots) == 2:
326
- if snapshots[0] > snapshots[1]:
327
- raise ValueError("The range in 'snapshots' is invalid. The first value must be less than or equal to the second.")
328
-
608
+ # Slice handling
329
609
  if not hasattr(self.sim, "domains") or self.sim.domains is None:
330
- raise ValueError("Simulation domains are not loaded. Ensure the simulation data is properly initialized.")
331
-
332
-
333
-
334
-
335
- if slice: # Load 2D data
336
- # Dynamically create DataFrame columns based on the fields
337
- columns = ['snapshot', 'time', 'var1_mesh', 'var2_mesh']
338
- if field == "gasdens":
339
- print(f'Loading 2D density data for slice: {slice}.')
340
- columns.append('gasdens_mesh')
341
- if field == "gasv":
342
- columns.append('gasv_mesh')
343
- print(f'Loading 2D gas velocity data for slice: {slice}.')
344
- if field == 'gasenergy':
345
- columns.append('gasenergy_mesh')
346
- print(f'Loading 2D gas energy data for slice {slice}')
347
- df_snapshots = pd.DataFrame(columns=columns)
610
+ raise ValueError("Simulation domains are not loaded.")
611
+ self._cache_domain_limits()
612
+ self._slice_type = self._detect_slice_type(slice)
613
+ self._slice_ranges = self._parse_slice_ranges(slice)
614
+
615
+ # -------------------------
616
+ # Helper for rotation (phi slices)
617
+ # -------------------------
618
+ def _rotation(X, Y, Z, phi0):
619
+ X_rot = X * np.cos(phi0) + Y * np.sin(phi0)
620
+ Y_rot = -X * np.sin(phi0) + Y * np.cos(phi0)
621
+ return X_rot, Y_rot, Z.copy()
622
+
623
+ # =====================================================================
624
+ # ======================== 2D CASE ================================
625
+ # =====================================================================
626
+ if self.dim < 3:
627
+
628
+ # Collect rows and build DataFrame once to avoid repeated concat
629
+ rows = []
348
630
 
349
631
  for i, snap in enumerate(snaps):
350
- row = {'snapshot': snap, 'time': time_values[i]}
351
-
352
- # Assign coordinates for all fields
353
- if field == 'gasdens':
354
- gasd = self.sim.load_field('gasdens', snapshot=snap, type='scalar')
355
- gasd_slice, mesh = gasd.meshslice(slice=slice)
356
- if slice_type == "phi":
357
- row["var1_mesh"], row["var2_mesh"] = getattr(mesh, "x"), getattr(mesh, "z")
358
- elif slice_type == "theta":
359
- row["var1_mesh"], row["var2_mesh"] = getattr(mesh, "x"), getattr(mesh, "y")
360
- row['gasdens_mesh'] = gasd_slice
361
-
362
- if field == "gasv":
363
- gasv = self.sim.load_field('gasv', snapshot=snap, type='vector')
364
- gasvx, gasvy, gasvz = gasv.to_cartesian()
365
- if slice_type == "phi":
366
- # Plane XZ: use vx and vz
367
- vel1_slice, mesh1 = getattr(gasvx, f'meshslice')(slice=slice)
368
- vel2_slice, mesh2 = getattr(gasvz, f'meshslice')(slice=slice)
369
- row["var1_mesh"], row["var2_mesh"] = getattr(mesh1, "x"), getattr(mesh1, "z")
370
- row['gasv_mesh'] = np.array([vel1_slice, vel2_slice])
371
- elif slice_type == "theta":
372
- # Plane XY: use vx and vy
373
- vel1_slice, mesh1 = getattr(gasvx, f'meshslice')(slice=slice)
374
- vel2_slice, mesh2 = getattr(gasvy, f'meshslice')(slice=slice)
375
- row["var1_mesh"], row["var2_mesh"] = getattr(mesh1, "x"), getattr(mesh1, "y")
376
- row['gasv_mesh'] = np.array([vel1_slice, vel2_slice])
377
-
378
- if field == "gasenergy":
379
- gasenergy = self.sim.load_field('gasenergy', snapshot=snap, type='scalar')
380
- gasenergy_slice, mesh = gasenergy.meshslice(slice=slice)
381
- row["gasenergy_mesh"] = gasenergy_slice
382
- if slice_type == "phi":
383
- row["var1_mesh"], row["var2_mesh"] = getattr(mesh, "x"), getattr(mesh, "z")
384
- elif slice_type == "theta":
385
- row["var1_mesh"], row["var2_mesh"] = getattr(mesh, "x"), getattr(mesh, "y")
386
-
387
- # Convert the row to a DataFrame and concatenate it
388
- row_df = pd.DataFrame([row])
389
- df_snapshots = pd.concat([df_snapshots, row_df], ignore_index=True)
390
632
 
633
+ row = {'snapshot': snap, 'time': time_values[i]}
634
+ coords_assigned = False # Only assign var1/var2/var3 once
635
+
636
+ # Loop over requested fields
637
+ for field in fields:
638
+
639
+ # -----------------
640
+ # GASDENS 2D
641
+ # -----------------
642
+ if field == 'gasdens':
643
+ gasd = self.sim._load_field_raw('gasdens', snapshot=snap, field_type='scalar')
644
+ data_slice, mesh = gasd.meshslice(slice=slice)
645
+
646
+ # assign coordinates only once
647
+ if not coords_assigned:
648
+ if coords == 'cartesian':
649
+ # rotate if phi is fixed
650
+ try:
651
+ if np.all(mesh.phi.ravel() == mesh.phi.ravel()[0]):
652
+ phi0 = mesh.phi.ravel()[0]
653
+ x_rot, y_rot, z_rot = _rotation(mesh.x, mesh.y, mesh.z, phi0)
654
+ row['var1_mesh'] = x_rot
655
+ row['var2_mesh'] = y_rot
656
+ row['var3_mesh'] = z_rot
657
+ else:
658
+ row['var1_mesh'] = mesh.x
659
+ row['var2_mesh'] = mesh.y
660
+ row['var3_mesh'] = mesh.z
661
+ except Exception:
662
+ # Fallback if mesh lacks phi
663
+ row['var1_mesh'] = mesh.x
664
+ row['var2_mesh'] = mesh.y
665
+ row['var3_mesh'] = mesh.z
666
+ else:
667
+ # original coordinate names as defined in simulation
668
+ vnames = getattr(self.sim.vars, 'VARIABLES', ['x', 'y', 'z'])
669
+ row['var1_mesh'] = getattr(mesh, vnames[0])
670
+ row['var2_mesh'] = getattr(mesh, vnames[1])
671
+ row['var3_mesh'] = getattr(mesh, vnames[2])
672
+ coords_assigned = True
673
+
674
+ row['gasdens_mesh'] = data_slice
675
+
676
+ # -----------------
677
+ # GASV 2D
678
+ # -----------------
679
+ if field == 'gasv':
680
+ gasv_raw = self.sim._load_field_raw('gasv', snapshot=snap, field_type='vector')
681
+ if coords == 'cartesian':
682
+ gasvx, gasvy, gasvz = gasv_raw.to_cartesian()
683
+ v1, mesh = gasvx.meshslice(slice=slice)
684
+ v2, mesh = gasvy.meshslice(slice=slice)
685
+ v3, mesh = gasvz.meshslice(slice=slice)
686
+
687
+ if not coords_assigned:
688
+ row['var1_mesh'] = mesh.x
689
+ row['var2_mesh'] = mesh.y
690
+ row['var3_mesh'] = mesh.z
691
+ coords_assigned = True
692
+
693
+ row['gasv_mesh'] = np.array([v1, v2, v3])
694
+ else:
695
+ v_slice, mesh = gasv_raw.meshslice(slice=slice)
696
+ v1, v2, v3 = v_slice[0], v_slice[1], v_slice[2]
697
+ if not coords_assigned:
698
+ vnames = getattr(self.sim.vars, 'VARIABLES', ['x', 'y', 'z'])
699
+ row['var1_mesh'] = getattr(mesh, vnames[0])
700
+ row['var2_mesh'] = getattr(mesh, vnames[1])
701
+ row['var3_mesh'] = getattr(mesh, vnames[2])
702
+ coords_assigned = True
703
+ row['gasv_mesh'] = np.array([v1, v2, v3])
704
+
705
+ # -----------------
706
+ # GASENERGY 2D
707
+ # -----------------
708
+ if field == 'gasenergy':
709
+ gasen = self.sim._load_field_raw('gasenergy', snapshot=snap, field_type='scalar')
710
+ data_slice, mesh = gasen.meshslice(slice=slice)
711
+
712
+ if not coords_assigned:
713
+ if coords == 'cartesian':
714
+ row['var1_mesh'] = mesh.x
715
+ row['var2_mesh'] = mesh.y
716
+ row['var3_mesh'] = mesh.z
717
+ else:
718
+ vnames = getattr(self.sim.vars, 'VARIABLES', ['x', 'y', 'z'])
719
+ row['var1_mesh'] = getattr(mesh, vnames[0])
720
+ row['var2_mesh'] = getattr(mesh, vnames[1])
721
+ row['var3_mesh'] = getattr(mesh, vnames[2])
722
+ coords_assigned = True
723
+
724
+ row['gasenergy_mesh'] = data_slice
725
+
726
+ # collect row dicts and build DataFrame once
727
+ rows.append(row)
728
+
729
+ df_snapshots = pd.DataFrame(rows)
391
730
  self.df = df_snapshots
392
731
  return df_snapshots
393
732
 
394
- elif slice is None: # Load 3D data
395
- # Generate 3D mesh
396
- theta, r, phi = np.meshgrid(self.sim.domains.theta, self.sim.domains.r, self.sim.domains.phi, indexing='ij')
397
- x, y, z = r * np.sin(theta) * np.cos(phi), r * np.sin(theta) * np.sin(phi), r * np.cos(theta)
398
-
399
- # Dynamically create DataFrame columns based on the fields
400
- columns = ['snapshot', 'time', 'var1_mesh', 'var2_mesh', 'var3_mesh']
401
- if field == "gasdens":
402
- print(f'Loading 3D density data ')
403
- columns.append('gasdens_mesh')
404
- if field == "gasv":
405
- columns.append('gasv_mesh')
406
- print(f'Loading 3D gas velocity data')
407
- if field == 'gasenergy':
408
- columns.append('gasenergy_mesh')
409
- print(f'Loading 3D gas energy data')
410
-
411
- df_snapshots = pd.DataFrame(columns=columns)
733
+ # =====================================================================
734
+ # ======================== 3D CASE ================================
735
+ # =====================================================================
736
+ if self.dim == 3:
737
+
738
+ # Build full mesh
739
+ theta, r, phi = np.meshgrid(
740
+ self.sim.domains.theta,
741
+ self.sim.domains.r,
742
+ self.sim.domains.phi,
743
+ indexing='ij'
744
+ )
745
+ x = r * np.sin(theta) * np.cos(phi)
746
+ y = r * np.sin(theta) * np.sin(phi)
747
+ z = r * np.cos(theta)
748
+
749
+ # Apply spherical or cylindrical mask (optional)
750
+ if cut is not None:
751
+ if len(cut) == 5:
752
+ xc, yc, zc, rc, hc = cut
753
+ r_xy = np.sqrt((x - xc)**2 + (y - yc)**2)
754
+ zmin, zmax = zc - hc/2, zc + hc/2
755
+ mask = (r_xy <= rc) & (z >= zmin) & (z <= zmax)
756
+ elif len(cut) == 4:
757
+ xc, yc, zc, rs = cut
758
+ r_sph = np.sqrt((x - xc)**2 + (y - yc)**2 + (z - zc)**2)
759
+ mask = r_sph <= rs
760
+ else:
761
+ raise ValueError("cut must have 4 (sphere) or 5 (cylinder) elements.")
762
+ else:
763
+ mask = None
764
+
765
+ # Collect rows and build DataFrame once to avoid repeated concat
766
+ rows = []
412
767
 
413
768
  for i, snap in enumerate(snaps):
414
- row = {'snapshot': snap, 'time': time_values[i]}
415
-
416
- # Assign coordinates for all fields
417
- if field == 'gasdens':
418
- gasd = self.sim.load_field('gasdens', snapshot=snap, type='scalar')
419
- row["var1_mesh"], row["var2_mesh"], row["var3_mesh"] = x, y, z
420
- row['gasdens_mesh'] = gasd.data
421
-
422
- if field == "gasv":
423
- gasv = self.sim.load_field('gasv', snapshot=snap, type='vector')
424
- gasvx, gasvy, gasvz = gasv.to_cartesian()
425
- row["var1_mesh"], row["var2_mesh"], row["var3_mesh"] = x, y, z
426
- row['gasv_mesh'] = np.array([gasvx.data, gasvy.data, gasvz.data])
427
-
428
- if field == "gasenergy":
429
- gasenergy = self.sim.load_field('gasenergy', snapshot=snap, type='scalar')
430
- row["gasenergy_mesh"] = gasenergy.data
431
- row["var1_mesh"], row["var2_mesh"], row["var3_mesh"] = x, y, z
432
-
433
- # Convert the row to a DataFrame and concatenate it
434
- row_df = pd.DataFrame([row])
435
- df_snapshots = pd.concat([df_snapshots, row_df], ignore_index=True)
436
769
 
770
+ row = {'snapshot': snap, 'time': time_values[i]}
771
+ coords_assigned = False
772
+
773
+ # Loop over requested fields
774
+ for field in fields:
775
+
776
+ # -----------------
777
+ # GASDENS 3D
778
+ # -----------------
779
+ if field == "gasdens":
780
+ gasd = self.sim._load_field_raw('gasdens', snapshot=snap, field_type='scalar')
781
+
782
+ if not coords_assigned:
783
+ if coords == 'cartesian':
784
+ if mask is not None:
785
+ row["var1_mesh"] = x[mask]
786
+ row["var2_mesh"] = y[mask]
787
+ row["var3_mesh"] = z[mask]
788
+ else:
789
+ row["var1_mesh"] = x
790
+ row["var2_mesh"] = y
791
+ row["var3_mesh"] = z
792
+ else:
793
+ # original coordinate variables order
794
+ v0, v1, v2 = self.sim.vars.VARIABLES
795
+ mapping = dict(r=r,phi=phi,theta=theta,x=x,y=y,z=z)
796
+ if mask is not None:
797
+ row["var1_mesh"] = mapping[v0][mask]
798
+ row["var2_mesh"] = mapping[v1][mask]
799
+ row["var3_mesh"] = mapping[v2][mask]
800
+ else:
801
+ row["var1_mesh"] = mapping[v0]
802
+ row["var2_mesh"] = mapping[v1]
803
+ row["var3_mesh"] = mapping[v2]
804
+ coords_assigned = True
805
+
806
+ row["gasdens_mesh"] = gasd.data[mask] if mask is not None else gasd.data
807
+ # -----------------
808
+ # GASV 3D
809
+ # -----------------
810
+ if field == "gasv":
811
+ gasv_raw = self.sim._load_field_raw('gasv', snapshot=snap, field_type='vector')
812
+ if coords == 'cartesian':
813
+ gasvx, gasvy, gasvz = gasv_raw.to_cartesian()
814
+
815
+ if not coords_assigned:
816
+ if mask is not None:
817
+ row["var1_mesh"] = x[mask]
818
+ row["var2_mesh"] = y[mask]
819
+ row["var3_mesh"] = z[mask]
820
+ else:
821
+ row["var1_mesh"] = x
822
+ row["var2_mesh"] = y
823
+ row["var3_mesh"] = z
824
+ coords_assigned = True
825
+
826
+ if mask is not None:
827
+ row["gasv_mesh"] = np.array([
828
+ gasvx.data[mask],
829
+ gasvy.data[mask],
830
+ gasvz.data[mask]
831
+ ])
832
+ else:
833
+ row["gasv_mesh"] = np.array([
834
+ gasvx.data,
835
+ gasvy.data,
836
+ gasvz.data
837
+ ])
838
+ else:
839
+ vdata = gasv_raw.data
840
+ if not coords_assigned:
841
+ v0, v1, v2 = self.sim.vars.VARIABLES
842
+ mapping = dict(r=r,phi=phi,theta=theta,x=x,y=y,z=z)
843
+ if mask is not None:
844
+ row["var1_mesh"] = mapping[v0][mask]
845
+ row["var2_mesh"] = mapping[v1][mask]
846
+ row["var3_mesh"] = mapping[v2][mask]
847
+ else:
848
+ row["var1_mesh"] = mapping[v0]
849
+ row["var2_mesh"] = mapping[v1]
850
+ row["var3_mesh"] = mapping[v2]
851
+ coords_assigned = True
852
+
853
+ if mask is not None:
854
+ row["gasv_mesh"] = np.array([vdata[0][mask], vdata[1][mask], vdata[2][mask]])
855
+ else:
856
+ row["gasv_mesh"] = np.array([vdata[0], vdata[1], vdata[2]])
857
+
858
+ # -----------------
859
+ # GASENERGY 3D
860
+ # -----------------
861
+ if field == "gasenergy":
862
+ gasen = self.sim._load_field_raw('gasenergy', snapshot=snap, field_type='scalar')
863
+
864
+ if not coords_assigned:
865
+ if coords == 'cartesian':
866
+ if mask is not None:
867
+ row["var1_mesh"] = x[mask]
868
+ row["var2_mesh"] = y[mask]
869
+ row["var3_mesh"] = z[mask]
870
+ else:
871
+ row["var1_mesh"] = x
872
+ row["var2_mesh"] = y
873
+ row["var3_mesh"] = z
874
+ else:
875
+ v0, v1, v2 = self.sim.vars.VARIABLES
876
+ mapping = dict(r=r,phi=phi,theta=theta,x=x,y=y,z=z)
877
+ if mask is not None:
878
+ row["var1_mesh"] = mapping[v0][mask]
879
+ row["var2_mesh"] = mapping[v1][mask]
880
+ row["var3_mesh"] = mapping[v2][mask]
881
+ else:
882
+ row["var1_mesh"] = mapping[v0]
883
+ row["var2_mesh"] = mapping[v1]
884
+ row["var3_mesh"] = mapping[v2]
885
+ coords_assigned = True
886
+
887
+ row["gasenergy_mesh"] = gasen.data[mask] if mask is not None else gasen.data
888
+
889
+ # collect row dicts and build DataFrame once
890
+ rows.append(row)
891
+
892
+ df_snapshots = pd.DataFrame(rows)
437
893
  self.df = df_snapshots
438
894
  return df_snapshots
895
+
439
896
 
440
897
  def times(self):
898
+ """
899
+ Return the snapshot time table mapping snapshots to normalized time.
900
+
901
+ Returns
902
+ -------
903
+ pd.DataFrame
904
+ DataFrame with columns 'Snapshot' and 'Normalized_time'.
905
+
906
+ Raises
907
+ ------
908
+ ValueError
909
+ If no data has been loaded.
910
+ """
441
911
  if self.snapshot_time_table is None:
442
912
  raise ValueError("No data loaded. Run load_data() first.")
443
913
  return self.snapshot_time_table
@@ -453,14 +923,21 @@ class FieldInterpolator:
453
923
  Create a mesh grid based on the slice definition provided by the user.
454
924
  If no slice is provided, create a full 3D mesh within the simulation domain.
455
925
 
456
- Parameters:
457
- slice (str, optional): The slice definition string (e.g., "r=[0.8,1.2],phi=0,theta=[0 deg,90 deg]").
458
- nr (int): Number of divisions in r.
459
- ntheta (int): Number of divisions in theta.
460
- nphi (int): Number of divisions in phi.
461
-
462
- Returns:
463
- tuple: Mesh grid (x, y, z) based on the slice definition or the full domain.
926
+ Parameters
927
+ ----------
928
+ slice : str, optional
929
+ The slice definition string (e.g., "r=[0.8,1.2],phi=0,theta=[0 deg,90 deg]").
930
+ nr : int
931
+ Number of divisions in r.
932
+ ntheta : int
933
+ Number of divisions in theta.
934
+ nphi : int
935
+ Number of divisions in phi.
936
+
937
+ Returns
938
+ -------
939
+ tuple
940
+ Mesh grid (x, y, z) based on the slice definition or the full domain.
464
941
  """
465
942
  import numpy as np
466
943
  import re
@@ -553,253 +1030,761 @@ class FieldInterpolator:
553
1030
  else:
554
1031
  raise ValueError("Slice definition must include either 'z', 'phi', or 'theta'.")
555
1032
 
1033
+
1034
+
1035
def _domain_mask(self, xi, slice=None):
    """
    Build a boolean mask selecting the evaluation points that lie inside the
    simulation domain and inside any user-specified slice limits.

    Parameters
    ----------
    xi : array-like
        Evaluation points. A 2D array with two columns is read as (x, y) for
        an XY cut (theta fixed) or as (x, z) for an XZ cut (phi fixed); a 1D
        array is read as values of the single free variable (r, theta or
        phi); a 2D array with three columns is accepted unfiltered.
    slice : str, optional
        Slice definition (e.g. ``"r=[0.8,1.2],phi=0,theta=[0 deg,90 deg]"``).
        Defaults to ``self.slice``.

    Returns
    -------
    numpy.ndarray of bool or None
        One flag per point. ``None`` is returned for inputs with more than
        three columns, which no branch handles.

    Notes
    -----
    The plane of a 2D cut is taken from the angle that is fixed to a scalar
    in the slice string. Merely mentioning ``theta`` is not enough: an XZ
    cut normally also carries a ``theta=[lo,hi]`` range, and treating that
    as an XY cut would filter (x, z) points with the phi limits.
    NOTE(review): cached ``self._slice_ranges`` take precedence over ranges
    parsed from an explicitly passed ``slice`` -- confirm this is intended.
    """
    self._cache_domain_limits()
    slice = slice or self.slice
    slice_ranges = self._slice_ranges or self._parse_slice_ranges(slice)
    r_bounds = slice_ranges.get('r')
    theta_bounds = slice_ranges.get('theta')
    phi_bounds = slice_ranges.get('phi')
    r_min, r_max = self._domain_limits['r']
    theta_min, theta_max = self._domain_limits['theta']
    phi_min, phi_max = self._domain_limits['phi']
    eps = 1e-7  # tolerance so points sitting exactly on a limit are kept
    xi = np.asarray(xi)
    ndim = xi.shape[1] if xi.ndim > 1 else 1
    spec = slice.lower() if isinstance(slice, str) else ""

    def _bounded(vals, bounds, default):
        # Lower/upper tests against the user range, or the domain when absent.
        lo, hi = default if bounds is None else bounds
        return vals >= lo - eps, vals <= hi + eps

    def _phi_in_bounds(phi_vals):
        # NOTE(review): phi_vals come from arctan2, i.e. lie in [-pi, pi];
        # confirm the parsed phi bounds use the same convention.
        if phi_bounds is None:
            return np.ones_like(phi_vals, dtype=bool)
        lo, hi = phi_bounds
        if lo <= hi:
            return (phi_vals >= lo - eps) & (phi_vals <= hi + eps)
        # lo > hi denotes a range that wraps around the branch cut
        return (phi_vals >= lo - eps) | (phi_vals <= hi + eps)

    def _is_fixed(name):
        # True when `name` is assigned a scalar rather than a [lo,hi] range.
        return re.search(rf"\b{name}\s*=\s*[^\s\[]", spec) is not None

    if ndim == 2:
        if _is_fixed('theta'):
            plane = 'xy'
        elif _is_fixed('phi'):
            plane = 'xz'
        elif slice is not None and 'theta' in slice:
            # No angle is fixed to a scalar: keep the historical substring test.
            plane = 'xy'
        elif slice is not None and 'phi' in slice:
            plane = 'xz'
        else:
            plane = None

        first, second = xi[:, 0], xi[:, 1]
        r = np.sqrt(first**2 + second**2)

        if plane == 'xy':
            # XY plane: columns are (x, y); filter by r and phi
            r_ge, r_le = _bounded(r, r_bounds, (r_min, r_max))
            return r_ge & r_le & _phi_in_bounds(np.arctan2(second, first))

        if plane == 'xz':
            # XZ plane: columns are (x, z); filter by r and theta
            theta = np.arccos(second / np.clip(r, 1e-14, None))
            r_ge, r_le = _bounded(r, r_bounds, (r_min, r_max))
            if theta_bounds:
                lo, hi = theta_bounds
                theta_mask = (theta >= lo - eps) & (theta <= hi + eps)
            else:
                theta_mask = (
                    ((theta > theta_min) | np.isclose(theta, theta_min, atol=eps)) &
                    ((theta < theta_max) | np.isclose(theta, theta_max, atol=eps))
                )
            return r_ge & r_le & theta_mask

        # Default: treat as XY and keep points strictly inside the radial domain
        return (r > r_min) & (r < r_max)

    if ndim == 1:
        xi_1d = xi.ravel()

        def _is_range(b):
            return (b is not None) and (abs(b[1] - b[0]) > 1e-12)

        # The free variable is the one given as a genuine range; otherwise it
        # is the first variable that does not appear in the slice at all.
        if _is_range(r_bounds):
            free = 'r'
        elif _is_range(theta_bounds):
            free = 'theta'
        elif _is_range(phi_bounds):
            free = 'phi'
        else:
            has_r = re.search(r"\br\s*=", spec) is not None
            has_th = re.search(r"\btheta\s*=", spec) is not None
            has_ph = re.search(r"\bphi\s*=", spec) is not None
            if not has_r:
                free = 'r'
            elif not has_th:
                free = 'theta'
            elif not has_ph:
                free = 'phi'
            else:
                free = 'r'

        if free == 'r':
            lo, hi = (r_bounds if r_bounds is not None else (r_min, r_max))
            return (xi_1d >= lo - eps) & (xi_1d <= hi + eps)

        if free == 'theta':
            lo, hi = (theta_bounds if theta_bounds is not None else (theta_min, theta_max))
            return (xi_1d >= lo - eps) & (xi_1d <= hi + eps)

        # free == 'phi'
        lo, hi = (phi_bounds if phi_bounds is not None else (phi_min, phi_max))
        if lo <= hi:
            return (xi_1d >= lo - eps) & (xi_1d <= hi + eps)
        # wrap-around range (e.g. [5.5, 0.5] in radians)
        return (xi_1d >= lo - eps) | (xi_1d <= hi + eps)

    if ndim == 3:
        # Full 3D input is not filtered.
        return np.ones(xi.shape[0], dtype=bool)

    # More than three columns: no mask is defined for this input.
    return None
1164
+
556
1165
  def evaluate(
557
- self, time, var1, var2=None, var3=None,
1166
+ self, time, var1, var2=None, var3=None, dataframe=None,
558
1167
  interpolator="griddata", method="linear",
559
- rbf_kwargs=None, griddata_kwargs=None, idw_kwargs=None
1168
+ rbf_kwargs=None, griddata_kwargs=None, idw_kwargs=None,
1169
+ sigma_smooth=None, field=None, reflect=False
560
1170
  ):
561
1171
  """
562
- Interpolates a field in 1D, 2D, or 3D using RBFInterpolator, griddata, LinearNDInterpolator, or IDW.
563
- Supports both grids and discrete points.
564
-
565
- Parameters:
566
- ...
567
- interpolator (str): Interpolation family, either "rbf", "griddata", "linearnd", or "idw". Default is "griddata".
568
- idw_kwargs (dict): Optional kwargs for IDW, e.g. {'power': 2, 'k': 8}
569
- ...
570
- """
1172
+ Evaluate the selected field at arbitrary spatial coordinates using
1173
+ multi-snapshot interpolation. Supports scalar and vector fields,
1174
+ 1D/2D/3D geometry, time interpolation, and several interpolation
1175
+ backends. Designed for unified DataFrames (gasdens + gasv + others).
571
1176
 
1177
+ Parameters
1178
+ ----------
1179
+ time : float or int
1180
+ Normalized time in [0,1] or snapshot index.
572
1181
 
573
- # --- Handle time input: explicit and robust: normalized time [0,1] or snapshot index ---
574
- if hasattr(self, "snapshot_time_table") and self.snapshot_time_table is not None:
575
- snaps = self.snapshot_time_table["Snapshot"].values
576
- min_snap, max_snap = snaps.min(), snaps.max()
577
- # If time is float in [0,1], treat as normalized time (directly)
578
- if isinstance(time, float) and 0 <= time <= 1:
579
- pass # Use as normalized time directly
580
- # If time is int or float > 1, treat as snapshot or fractional snapshot
581
- elif (isinstance(time, int) or (isinstance(time, float) and time > 1)):
582
- if time < min_snap or time > max_snap:
583
- raise ValueError(
584
- f"Selected snapshot (time={time}) is outside the loaded range [{min_snap}, {max_snap}]."
585
- )
586
- if isinstance(time, int) or np.isclose(time, np.round(time)):
587
- # Exact snapshot
588
- row = self.snapshot_time_table[self.snapshot_time_table["Snapshot"] == int(round(time))]
589
- if not row.empty:
590
- time = float(row["Normalized_time"].values[0])
591
- else:
592
- raise ValueError(f"Snapshot {int(round(time))} not found in snapshot_time_table.")
593
- else:
594
- # Fractional snapshot: interpolate between neighbors
595
- snap0 = int(np.floor(time))
596
- snap1 = int(np.ceil(time))
597
- if snap0 < min_snap or snap1 > max_snap:
598
- raise ValueError(
599
- f"Selected snapshot (time={time}) requires neighbors [{snap0}, {snap1}] outside the loaded range [{min_snap}, {max_snap}]."
600
- )
601
- row0 = self.snapshot_time_table[self.snapshot_time_table["Snapshot"] == snap0]
602
- row1 = self.snapshot_time_table[self.snapshot_time_table["Snapshot"] == snap1]
603
- if not row0.empty and not row1.empty:
604
- t0 = float(row0["Normalized_time"].values[0])
605
- t1 = float(row1["Normalized_time"].values[0])
606
- factor = (time - snap0) / (snap1 - snap0)
607
- time = (1 - factor) * t0 + factor * t1
608
- else:
609
- raise ValueError(f"Snapshots {snap0} or {snap1} not found in snapshot_time_table.")
610
- else:
611
- raise ValueError(
612
- f"Invalid time value: {time}. Must be a normalized time in [0,1] or a snapshot index in [{min_snap},{max_snap}]."
613
- )
614
- else:
615
- if isinstance(time, int):
616
- raise ValueError("snapshot_time_table not found. Did you call load_data()?")
617
-
618
- if interpolator not in ["rbf", "griddata", "linearnd","idw"]:
619
- raise ValueError("Invalid method. Choose either 'rbf', 'griddata', 'idw', or 'linearnd'.")
620
-
621
- # Automatically determine the field to interpolate
622
- if "gasdens_mesh" in self.df.columns:
623
- field_name = "gasdens_mesh"
624
- elif "gasenergy_mesh" in self.df.columns:
625
- field_name = "gasenergy_mesh"
626
- elif "gasv_mesh" in self.df.columns: # Velocity field
627
- field_name = "gasv_mesh"
628
- else:
629
- raise ValueError("No valid field found in the DataFrame for interpolation.")
1182
+ var1, var2, var3 : array-like or float
1183
+ Evaluation coordinates (x,y,z for 3D). Scalars are accepted.
630
1184
 
631
- # Sort the DataFrame by time
632
- df_sorted = self.df.sort_values("time")
633
- times = df_sorted["time"].values
634
- n_snaps = len(times)
1185
+ dataframe : pandas.DataFrame, optional
1186
+ Custom DataFrame. If omitted, self.df is used.
1187
+
1188
+ interpolator : {"griddata","rbf","linearnd","idw"}
1189
+ Backend used for spatial interpolation.
635
1190
 
636
- # Check if the input is a single point or a mesh
637
- is_scalar = np.isscalar(var1) and (var2 is None or np.isscalar(var2)) and (var3 is None or np.isscalar(var3))
638
- result_shape = () if is_scalar else var1.shape
1191
+ method : str
1192
+ Kernel/method used by backend (e.g., 'linear' for griddata).
639
1193
 
640
- if rbf_kwargs is None:
641
- rbf_kwargs = {}
642
- if griddata_kwargs is None:
643
- griddata_kwargs = {}
1194
+ sigma_smooth : float or None
1195
+ Optional Gaussian smoothing.
644
1196
 
1197
+ field : {"gasdens","gasv","gasenergy"} or None
1198
+ Field to evaluate. If None and DF has >1 field → explicit error.
645
1199
 
1200
+ Returns
1201
+ -------
1202
+ ndarray or float
1203
+ Interpolated value(s). Vector fields return shape (3,N) or (3,...).
1204
+ """
646
1205
 
647
- if idw_kwargs is None:
648
- idw_kwargs = {}
1206
+ # ===============================================================
1207
+ # Basic validation
1208
+ # ===============================================================
1209
+ if sigma_smooth is not None and sigma_smooth <= 0:
1210
+ raise ValueError("sigma_smooth must be None or positive.")
1211
+
1212
+ df = dataframe if dataframe is not None else self.df
1213
+ if df is None:
1214
+ raise ValueError("No dataframe available.")
1215
+
1216
+ # ===============================================================
1217
+ # FIELD SELECTION (safe and explicit)
1218
+ # ===============================================================
1219
+ field_map = {
1220
+ "gasdens": "gasdens_mesh",
1221
+ "gasv": "gasv_mesh",
1222
+ "gasenergy": "gasenergy_mesh"
1223
+ }
1224
+
1225
+ if field is not None:
1226
+ if field in field_map:
1227
+ field = field_map[field]
1228
+ if field not in df.columns:
1229
+ raise ValueError(
1230
+ f"Field '{field}' not in DF. Available: {list(df.columns)}"
1231
+ )
1232
+ else:
1233
+ # Autodetect only if exactly one exists
1234
+ candidates = [
1235
+ c for c in df.columns
1236
+ if c in ("gasdens_mesh","gasv_mesh","gasenergy_mesh")
1237
+ ]
1238
+ if len(candidates) != 1:
1239
+ raise ValueError(
1240
+ f"Multiple fields present {candidates}. "
1241
+ "Specify field='gasdens', 'gasv', or 'gasenergy'."
1242
+ )
1243
+ field = candidates[0]
649
1244
 
1245
+ # ===============================================================
1246
+ # Prepare snapshot ordering
1247
+ # ===============================================================
1248
+ df_sorted = self._get_sorted_dataframe(df)
1249
+ times = df_sorted["time"].values
1250
+ nsnaps = len(times)
1251
+
1252
+ # Detect scalar inputs
1253
+ is_scalar = (
1254
+ np.isscalar(var1)
1255
+ and (var2 is None or np.isscalar(var2))
1256
+ and (var3 is None or np.isscalar(var3))
1257
+ )
1258
+ result_shape = () if is_scalar else np.asarray(var1).shape
1259
+
1260
+ if np.isscalar(var1): var1 = np.array([var1])
1261
+ if np.isscalar(var2): var2 = np.array([var2])
1262
+ if np.isscalar(var3): var3 = np.array([var3])
1263
+
1264
+ # Convenience: allow calling evaluate(var1=..., var2=...) for
1265
+ # 2D XZ slices where the expected coordinates are (var1,var3).
1266
+ # If the slice type is not 'theta' (i.e. XZ) and the user passed
1267
+ # a value for var2 but left var3=None, treat var2 as var3.
1268
+ try:
1269
+ slice_type_tmp = self._slice_type or self._detect_slice_type(self.slice)
1270
+ except Exception:
1271
+ slice_type_tmp = None
1272
+ if self.dim == 2 and slice_type_tmp is not None and slice_type_tmp != 'theta':
1273
+ if var3 is None and var2 is not None:
1274
+ var3 = var2
1275
+ var2 = None
1276
+
1277
+ # ===============================================================
1278
+ # Smoothing helper
1279
+ # ===============================================================
1280
+ def _smooth(values):
1281
+ if sigma_smooth is None or np.isscalar(values):
1282
+ return values
1283
+
1284
+ arr = np.asarray(values)
1285
+ if arr.ndim == 0:
1286
+ return values
1287
+
1288
+ # Vector smoothing
1289
+ if field == "gasv_mesh" and arr.ndim >= 2:
1290
+ out = np.empty_like(arr)
1291
+ for k in range(arr.shape[0]):
1292
+ out[k] = gaussian_filter(arr[k], sigma=sigma_smooth)
1293
+ return out
1294
+
1295
+ return gaussian_filter(arr, sigma=sigma_smooth)
1296
+
1297
+ # ===============================================================
1298
+ # Interpolation backends
1299
+ # ===============================================================
650
1300
 
651
1301
  def idw_interp(coords, values, xi):
652
- # Forzar a 2D: (N, D) y (M, D)
653
1302
  coords = np.asarray(coords)
654
- xi = np.asarray(xi)
655
- if coords.ndim > 2:
656
- coords = coords.reshape(-1, coords.shape[-1])
657
- if xi.ndim > 2:
658
- xi = xi.reshape(-1, xi.shape[-1])
659
1303
  values = np.asarray(values).ravel()
660
- power = idw_kwargs.get('power', 2)
661
- k = idw_kwargs.get('k', 8)
1304
+ xi = np.asarray(xi)
1305
+
1306
+ mask = self._domain_mask(xi)
1307
+ if reflect:
1308
+ mask = np.ones(xi.shape[0], dtype=bool)
1309
+ out = np.zeros(xi.shape[0])
662
1310
  tree = cKDTree(coords)
663
- dists, idxs = tree.query(xi, k=k)
664
- dists = np.where(dists == 0, 1e-10, dists)
665
- weights = 1 / dists**power
666
- weights /= weights.sum(axis=1, keepdims=True)
667
- return np.sum(values[idxs] * weights, axis=1)
1311
+ k = idw_kwargs.get("k", 8)
1312
+ power = idw_kwargs.get("power", 2)
1313
+
1314
+ # If mask selects points, compute only there. Otherwise try for all xi.
1315
+ if np.any(mask):
1316
+ d, idxs = tree.query(xi[mask], k=k)
1317
+ d = np.where(d == 0, 1e-10, d)
1318
+ w = 1 / d**power
1319
+ w /= w.sum(axis=1, keepdims=True)
1320
+ out[mask] = np.sum(values[idxs] * w, axis=1)
1321
+ return out
1322
+
1323
+ # Fallback: compute for all xi
1324
+ d, idxs = tree.query(xi, k=k)
1325
+ d = np.where(d == 0, 1e-10, d)
1326
+ w = 1 / d**power
1327
+ w /= w.sum(axis=1, keepdims=True)
1328
+ out = np.sum(values[idxs] * w, axis=1)
1329
+ return out
1330
+
668
1331
 
669
1332
  def rbf_interp(coords, values, xi):
670
- # Check if epsilon is required for the selected kernel
671
- kernels_requiring_epsilon = ["gaussian", "multiquadric", "inverse_multiquadric", "inverse_quadratic"]
672
- if method in kernels_requiring_epsilon and "epsilon" not in rbf_kwargs:
673
- raise ValueError(f"Kernel '{method}' requires 'epsilon' in rbf_kwargs.")
674
- interpolator_obj = RBFInterpolator(
675
- coords, values.ravel(),
676
- kernel=method,
677
- **rbf_kwargs
678
- )
679
- return interpolator_obj(xi)
1333
+ coords = np.asarray(coords)
1334
+ values = np.asarray(values).ravel()
1335
+ xi = np.asarray(xi)
680
1336
 
1337
+ mask = self._domain_mask(xi)
1338
+ if reflect:
1339
+ mask = np.ones(xi.shape[0], dtype=bool)
1340
+ out = np.full(xi.shape[0], np.nan)
1341
+
1342
+ # Try interpolate where mask True
1343
+ try:
1344
+ obj = RBFInterpolator(coords, values, kernel=method, **(rbf_kwargs or {}))
1345
+ except Exception:
1346
+ return np.zeros(xi.shape[0])
1347
+
1348
+ if np.any(mask):
1349
+ out[mask] = obj(xi[mask])
1350
+ # attempt to leave other positions as NaN
1351
+ return np.where(np.isfinite(out), out, np.nan)
1352
+
1353
+ # Fallback: evaluate on all xi
1354
+ vals_all = obj(xi)
1355
+ return np.where(np.isfinite(vals_all), vals_all, np.nan)
1356
+
1357
+
681
1358
  def griddata_interp(coords, values, xi):
682
- return griddata(coords, values.ravel(), xi, method=method, **griddata_kwargs)
683
-
1359
+ # --- Apply domain mask: only interpolate inside the simulation domain ---
1360
+ mask = self._domain_mask(xi)
1361
+ if reflect:
1362
+ mask = np.ones(xi.shape[0], dtype=bool)
1363
+ out = np.full(xi.shape[0], np.nan)
1364
+
1365
+ # If mask has selected points, interpolate only there
1366
+ if np.any(mask):
1367
+ out[mask] = griddata(coords, values.ravel(), xi[mask], method=method)
1368
+ # leave outside as NaN -> caller can mask later
1369
+ return np.where(np.isfinite(out), out, np.nan)
1370
+
1371
+ # Fallback: try interpolate for all xi (useful when domain mask selection fails)
1372
+ try:
1373
+ vals_all = griddata(coords, values.ravel(), xi, method=method)
1374
+ return np.where(np.isfinite(vals_all), vals_all, np.nan)
1375
+ except Exception:
1376
+ return np.zeros(xi.shape[0])
1377
+
1378
+
684
1379
  def linearnd_interp(coords, values, xi):
685
- interp_obj = LinearNDInterpolator(coords, values.ravel())
686
- return interp_obj(xi)
687
-
688
- def interp(idx, field, component=None):
689
- if var2 is None and var3 is None: # 1D interpolation
690
- coord_x = np.array(df_sorted.iloc[idx]["var1_mesh"])
691
- if field == "gasv_mesh" and component is not None:
692
- data = np.array(df_sorted.iloc[idx][field])[component]
693
- else:
694
- data = np.array(df_sorted.iloc[idx][field])
695
- coords = coord_x.reshape(-1, 1)
696
- xi = var1.reshape(-1, 1) if not is_scalar else np.array([[var1]])
697
- if interpolator == "rbf":
698
- return rbf_interp(coords, data, xi)
699
- elif interpolator == "linearnd":
700
- return linearnd_interp(coords, data, xi)
701
- elif interpolator == "idw":
702
- return idw_interp(coords, data, xi)
703
- else:
704
- return griddata_interp(coords, data, xi)
705
-
706
- elif var3 is not None: # 3D interpolation
707
- coord_x = np.array(df_sorted.iloc[idx]["var1_mesh"])
708
- coord_y = np.array(df_sorted.iloc[idx]["var2_mesh"])
709
- coord_z = np.array(df_sorted.iloc[idx]["var3_mesh"])
710
- if field == "gasv_mesh" and component is not None:
711
- data = np.array(df_sorted.iloc[idx][field])[component]
712
- else:
713
- data = np.array(df_sorted.iloc[idx][field])
714
- coords = np.column_stack((coord_x.ravel(), coord_y.ravel(), coord_z.ravel()))
715
-
1380
+ coords = np.asarray(coords)
1381
+ values = np.asarray(values).ravel()
1382
+ xi = np.asarray(xi)
1383
+
1384
+ mask = self._domain_mask(xi)
1385
+ if reflect:
1386
+ mask = np.ones(xi.shape[0], dtype=bool)
1387
+ out = np.full(xi.shape[0], np.nan)
1388
+ obj = LinearNDInterpolator(coords, values)
1389
+
1390
+ if np.any(mask):
1391
+ out[mask] = obj(xi[mask])
1392
+ return np.where(np.isfinite(out), out, np.nan)
1393
+
1394
+ # Fallback: evaluate on all xi
1395
+ vals_all = obj(xi)
1396
+ return np.where(np.isfinite(vals_all), vals_all, np.zeros_like(vals_all))
1397
+
1398
+ # ===============================================================
1399
+ # Main interpolation kernel
1400
+ # ===============================================================
1401
+ slice_type = self._slice_type or self._detect_slice_type(self.slice)
1402
+
1403
+ def interp(idx, field_name, comp=None):
1404
+ row = df_sorted.iloc[idx]
1405
+
1406
+ cx = np.array(row["var1_mesh"])
1407
+ cy = np.array(row["var2_mesh"])
1408
+ cz = np.array(row["var3_mesh"])
1409
+
1410
+ # Build coordinate arrays
1411
+ if self.dim == 3:
1412
+ coords = np.column_stack((cx.ravel(), cy.ravel(), cz.ravel()))
716
1413
  xi = np.column_stack((var1.ravel(), var2.ravel(), var3.ravel()))
717
- if interpolator == "rbf":
718
- return rbf_interp(coords, data, xi)
719
- elif interpolator == "linearnd":
720
- return linearnd_interp(coords, data, xi)
721
- elif interpolator == "idw":
722
- return idw_interp(coords, data, xi)
1414
+ elif self.dim == 2:
1415
+ if slice_type == "theta":
1416
+ coords = np.column_stack((cx.ravel(), cy.ravel()))
1417
+ xi = np.column_stack((var1.ravel(), var2.ravel()))
723
1418
  else:
724
- return griddata_interp(coords, data, xi)
725
- else: # 2D interpolation
726
- coord1 = np.array(df_sorted.iloc[idx]["var1_mesh"])
727
- coord2 = np.array(df_sorted.iloc[idx]["var2_mesh"])
728
- if field == "gasv_mesh" and component is not None:
729
- data = np.array(df_sorted.iloc[idx][field])[component]
1419
+ coords = np.column_stack((cx.ravel(), cz.ravel()))
1420
+ xi = np.column_stack((var1.ravel(), var3.ravel()))
1421
+ elif self.dim == 1:
1422
+ coords = np.sqrt(cx**2 + cy**2 + cz**2)
1423
+ xi = np.asarray(var1)
1424
+
1425
+ # Select field
1426
+ data = row[field_name]
1427
+
1428
+ # -------------------------------------------
1429
+ # UNIVERSAL VECTOR COMPONENT SELECTOR
1430
+ # -------------------------------------------
1431
+ if isinstance(data, np.ndarray) and comp is not None:
1432
+ if data.ndim == 2 and data.shape[0] == 3:
1433
+ data = data[comp]
1434
+ elif data.ndim == 2 and data.shape[1] == 3:
1435
+ data = data[:, comp]
1436
+ elif data.ndim == 3 and data.shape[0] == 3:
1437
+ data = data[comp].ravel()
1438
+ elif data.ndim == 4 and data.shape[0] == 3:
1439
+ data = data[comp].ravel()
1440
+ elif data.ndim == 4 and data.shape[-1] == 3:
1441
+ data = data[..., comp].ravel()
730
1442
  else:
731
- data = np.array(df_sorted.iloc[idx][field])
732
- coords = np.column_stack((coord1.ravel(), coord2.ravel()))
733
- xi = np.column_stack((var1.ravel(), var2.ravel()))
734
- if interpolator == "rbf":
735
- return rbf_interp(coords, data, xi)
736
- elif interpolator == "linearnd":
737
- return linearnd_interp(coords, data, xi)
738
- elif interpolator == "idw":
739
- return idw_interp(coords, data, xi)
740
- else:
741
- return griddata_interp(coords, data, xi)
742
-
1443
+ raise ValueError(f"Cannot extract vector component from {data.shape}")
1444
+
1445
+ # -------------------------------------------------
1446
+ # Reflection augmentation
1447
+ # If `reflect=True` we augment the interpolation
1448
+ # dataset reflecting across the equatorial plane z=0
1449
+ # (i.e. z -> -z). For 2D XZ cuts (coords (x,z)) we flip z.
1450
+ # For vector components, the component normal to the
1451
+ # reflection plane (vz) changes sign.
1452
+ # -------------------------------------------------
1453
+ if reflect:
1454
+ try:
1455
+ # Normalize coords to shape (N, ndim)
1456
+ coords_arr = np.asarray(coords)
1457
+ ndim = coords_arr.shape[1] if coords_arr.ndim == 2 else 1
1458
+ coords_orig = coords_arr.reshape(-1, ndim)
1459
+ coords_ref = coords_orig.copy()
1460
+
1461
+ # Flip only the z coordinate (index -1 if 3D, index 1 if 2D XZ)
1462
+ if coords_orig.shape[1] == 3:
1463
+ coords_ref[:, 2] *= -1
1464
+ elif coords_orig.shape[1] == 2:
1465
+ # assume (x,z) layout for XZ cuts
1466
+ coords_ref[:, 1] *= -1
1467
+
1468
+ # Prepare data values (flattened)
1469
+ data_flat = np.asarray(data).ravel()
1470
+
1471
+ # For vector components, reflect sign for the
1472
+ # component perpendicular to the plane (vz -> -vz)
1473
+ if field_name == 'gasv_mesh' and comp is not None:
1474
+ # comp: 2->vz (flip), others unchanged
1475
+ if comp == 2:
1476
+ data_ref = -data_flat
1477
+ else:
1478
+ data_ref = data_flat.copy()
1479
+ else:
1480
+ # Scalars or already selected components
1481
+ data_ref = data_flat.copy()
1482
+
1483
+ # Augment coords and data for interpolation
1484
+ coords = np.vstack([coords_orig, coords_ref])
1485
+ data = np.concatenate([data_flat, data_ref])
1486
+ except Exception:
1487
+ # On error, fallback to original coords/data
1488
+ coords = np.asarray(coords)
1489
+ data = np.asarray(data)
1490
+
1491
+ # Dispatch backend
1492
+ if interpolator == "rbf":
1493
+ return rbf_interp(coords, data, xi)
1494
+ elif interpolator == "linearnd":
1495
+ return linearnd_interp(coords, data, xi)
1496
+ elif interpolator == "idw":
1497
+ return idw_interp(coords, data, xi)
1498
+ else:
1499
+ return griddata_interp(coords, data, xi)
1500
+
1501
+ # ===============================================================
1502
+ # TIME INTERPOLATION
1503
+ # ===============================================================
1504
+ if nsnaps == 1:
1505
+ if field == "gasv_mesh":
1506
+ vals = [interp(0, field, c) for c in range(3)]
1507
+ arr = np.array([v.item() if is_scalar else v.reshape(result_shape) for v in vals])
1508
+ return _smooth(arr)
1509
+ v = interp(0, field)
1510
+ return _smooth(v.item() if is_scalar else v.reshape(result_shape))
1511
+
1512
+ # Two snapshots
1513
+ if nsnaps == 2:
1514
+ i0, i1 = 0, 1
1515
+ t0, t1 = times[i0], times[i1]
1516
+ fac = (time - t0) / (t1 - t0) if abs(t1 - t0) > 1e-12 else 0
1517
+ fac = np.clip(fac, 0, 1)
1518
+
1519
+ def blend(c=None):
1520
+ v0 = interp(i0, field, c)
1521
+ v1 = interp(i1, field, c)
1522
+ return (1 - fac) * v0 + fac * v1
1523
+
1524
+ if field == "gasv_mesh":
1525
+ vals = [blend(c) for c in range(3)]
1526
+ arr = np.array([v.item() if is_scalar else v.reshape(result_shape) for v in vals])
1527
+ return _smooth(arr)
1528
+
1529
+ v = blend()
1530
+ return _smooth(v.item() if is_scalar else v.reshape(result_shape))
1531
+
1532
+ # Many snapshots
1533
+ i0 = np.searchsorted(times, time) - 1
1534
+ i0 = np.clip(i0, 0, nsnaps - 2)
1535
+ i1 = i0 + 1
1536
+ t0, t1 = times[i0], times[i1]
1537
+ fac = (time - t0) / (t1 - t0) if abs(t1 - t0) > 1e-12 else 0
1538
+ fac = np.clip(fac, 0, 1)
1539
+
1540
+ def blend(c=None):
1541
+ v0 = interp(i0, field, c)
1542
+ v1 = interp(i1, field, c)
1543
+ return (1 - fac) * v0 + fac * v1
1544
+
1545
+ if field == "gasv_mesh":
1546
+ vals = [blend(c) for c in range(3)]
1547
+ arr = np.array([v.item() if is_scalar else v.reshape(result_shape) for v in vals])
1548
+ return _smooth(arr)
1549
+
1550
+ v = blend()
1551
+ return _smooth(v.item() if is_scalar else v.reshape(result_shape))
1552
+
1553
+
1554
+ def plot(self, t=0, contour_levels=10, component='vz', smoothing_sigma=None, field=None):
1555
+ """
1556
+ Automatically determines the plane (XY, XZ, or 3D) and plots the field data.
1557
+
1558
+ Parameters
1559
+ ----------
1560
+ t : int, optional
1561
+ Index of the snapshot/time to plot.
1562
+ contour_levels : int, optional
1563
+ Number of contour levels for 2D plots.
1564
+ component : str, optional
1565
+ Component to plot for vector fields ('vx', 'vy', 'vz').
1566
+ field : str or None, optional
1567
+ Which field to plot when multiple fields were loaded (e.g. 'gasdens', 'gasv', 'gasenergy').
1568
+ If None and exactly one field is present, that field is plotted. If None and multiple
1569
+ candidate fields exist, a ValueError is raised instructing the user to pick one.
1570
+ """
743
1571
 
744
- # --- Case 1: only a snapshot ---
745
- if n_snaps == 1:
746
- def eval_single(component=None):
747
- return interp(0, field_name, component)
748
- if field_name == "gasv_mesh":
749
- components = 3 if var3 is not None else 2 if var2 is not None else 1
750
- results = Parallel(n_jobs=-1)(
751
- delayed(eval_single)(i) for i in range(components)
752
- )
753
- return np.array([res.item() if is_scalar else res.reshape(result_shape) for res in results])
1572
+ if self.df is None:
1573
+ raise ValueError("No data loaded. Run load_field() first.")
1574
+
1575
+ if component=='vz':
1576
+ comp = 2
1577
+ if component=='vy':
1578
+ comp = 1
1579
+ if component=='vx':
1580
+ comp = 0
1581
+ df_names = self.df.columns.tolist()
1582
+
1583
+ # Map short names to dataframe column names
1584
+ field_map = {
1585
+ 'gasdens': 'gasdens_mesh',
1586
+ 'gasv': 'gasv_mesh',
1587
+ 'gasenergy': 'gasenergy_mesh'
1588
+ }
1589
+
1590
+ # Detect candidate fields present in the DataFrame
1591
+ candidates = [c for c in df_names if c in ('gasdens_mesh', 'gasv_mesh', 'gasenergy_mesh')]
1592
+
1593
+ # Resolve user-requested field
1594
+ if field is not None:
1595
+ # allow short names or full column names
1596
+ if field in field_map:
1597
+ field_col = field_map[field]
754
1598
  else:
755
- # Trivial escalar case: parallelization over the single snapshot
756
- result = Parallel(n_jobs=-1)([delayed(eval_single)()])
757
- result = result[0]
758
- return result.item() if is_scalar else result.reshape(result_shape)
759
-
760
- # --- Case 2: Two snapshots, linear temporal interpolation ---
761
- elif n_snaps == 2:
762
- idx, idx_after = 0, 1
763
- t0, t1 = times[idx], times[idx_after]
764
- factor = (time - t0) / (t1 - t0) if abs(t1 - t0) > 1e-10 else 0
765
- factor = max(0, min(factor, 1))
766
- def temporal_interp(component=None):
767
- val0 = interp(idx, field_name, component)
768
- val1 = interp(idx_after, field_name, component)
769
- return (1 - factor) * val0 + factor * val1
770
- if field_name == "gasv_mesh":
771
- components = 3 if var3 is not None else 2 if var2 is not None else 1
772
- results = Parallel(n_jobs=-1)(
773
- delayed(temporal_interp)(i) for i in range(components)
774
- )
775
- return np.array([res.item() if is_scalar else res.reshape(result_shape) for res in results])
1599
+ field_col = field
1600
+ if field_col not in df_names:
1601
+ raise ValueError(f"Requested field '{field}' not present. Available: {candidates}")
1602
+ else:
1603
+ if len(candidates) == 1:
1604
+ field_col = candidates[0]
776
1605
  else:
777
- # Escalar: paralelización sobre ambos snapshots
778
- results = Parallel(n_jobs=2)(
779
- delayed(temporal_interp)() for _ in range(1)
1606
+ raise ValueError(
1607
+ f"Multiple fields present {candidates}. Specify which to plot using field='gasdens' or 'gasv'."
780
1608
  )
781
- result = results[0]
782
- return result.item() if is_scalar else result.reshape(result_shape)
783
1609
 
784
- # --- Case 3: More than two snapshots, spline temporal interpolation ---
1610
+ # Extract the mesh grids and field data after slicing
1611
+ var1 = self.df['var1_mesh'][t]
1612
+ var2 = self.df['var2_mesh'][t]
1613
+ var3 = self.df['var3_mesh'][t]
1614
+
1615
+ # Load the original field (before slicing) if needed elsewhere
1616
+ d3 = self.sim._load_field_raw('gasdens', snapshot=int(self.df['snapshot'][t]), field_type='scalar')
1617
+
1618
+ # Prepare field_data according to selected field
1619
+ raw_field = self.df[field_col][t]
1620
+ is_vector = (field_col == 'gasv_mesh')
1621
+ if is_vector:
1622
+ # choose component or compute magnitude if needed
1623
+ data_arr = np.asarray(raw_field)
1624
+ # handle common memory layouts: (3, ... ) or (..., 3)
1625
+ if data_arr.ndim >= 1 and data_arr.shape[0] == 3:
1626
+ # shape (3, N...) -> select component
1627
+ field_data = data_arr[comp]
1628
+ elif data_arr.ndim >= 1 and data_arr.shape[-1] == 3:
1629
+ field_data = data_arr[..., comp]
1630
+ else:
1631
+ # fallback: try to interpret as magnitude
1632
+ try:
1633
+ field_data = np.sqrt(np.sum(data_arr**2, axis=0))
1634
+ except Exception:
1635
+ field_data = data_arr
1636
+ # do not apply log to vector components
785
1637
  else:
786
- def eval_all_snaps(component=None):
787
- return Parallel(n_jobs=-1)(
788
- delayed(interp)(i, field_name, component) for i in range(n_snaps)
789
- )
790
- if field_name == "gasv_mesh":
791
- components = 3 if var3 is not None else 2 if var2 is not None else 1
792
- results = []
793
- for comp in range(components):
794
- values = eval_all_snaps(component=comp)
795
- values = np.stack([v if not is_scalar else np.array([v]) for v in values], axis=0)
796
- f = interp1d(times, values, axis=0, kind='linear', bounds_error=False, fill_value=np.nan)
797
- res = f(time)
798
- results.append(res.item() if is_scalar else res.reshape(result_shape))
799
- return np.array(results)
1638
+ # scalar field: apply safe log10 for plotting
1639
+ field_data = np.asarray(raw_field)
1640
+ # avoid log of non-positive numbers
1641
+ with np.errstate(divide='ignore'):
1642
+ field_data = np.log10(np.where(field_data <= 0, np.nan, field_data))
1643
+
1644
+ # Get the shapes of the resulting mesh grids after applying the slice
1645
+ sliced_shape = var1.shape # Assuming var1, var2, var3 have the same shape after slicing
1646
+
1647
+ # Detect fixed angles in slice string
1648
+ slice_str = self.slice if hasattr(self, 'slice') and self.slice is not None else ""
1649
+ # Fixed theta: e.g. 'theta=1.56' (not theta=[...])
1650
+ fixed_theta = re.search(r'theta\s*=\s*([^\[\],]+)', slice_str)
1651
+ fixed_phi = re.search(r'phi\s*=\s*([^\[\],]+)', slice_str)
1652
+
1653
+ if fixed_theta:
1654
+ plane = 'XY'
1655
+ elif fixed_phi:
1656
+ plane = 'XZ'
1657
+ else:
1658
+ plane = 'XY' # Default/fallback
1659
+ # Check the number of dimensions in the sliced shape
1660
+
1661
+
1662
+ if len(sliced_shape) == 3:
1663
+ var1_flat = var1.ravel()
1664
+ var2_flat = var2.ravel()
1665
+ var3_flat = var3.ravel()
1666
+ data = field_data.ravel()
1667
+
1668
+ fig = plt.figure(figsize=(8, 6))
1669
+ ax = fig.add_subplot(111, projection='3d')
1670
+ scatter = ax.scatter(var1_flat, var2_flat, var3_flat, c=data, cmap='Spectral_r', s=5)
1671
+ cbar = fig.colorbar(scatter, ax=ax)
1672
+ cbar.ax.tick_params(labelsize=12)
1673
+ if is_vector:
1674
+ cbar.set_label(rf"{component} [units]", size=14)
800
1675
  else:
801
- values = eval_all_snaps()
802
- values = np.stack([v if not is_scalar else np.array([v]) for v in values], axis=0)
803
- f = interp1d(times, values, axis=0, kind='linear', bounds_error=False, fill_value=np.nan)
804
- result = f(time)
805
- return result.item() if is_scalar else result.reshape(result_shape)
1676
+ cbar.set_label(rf"$\log_{{10}}(field)$", size=14)
1677
+ ax.set_xlabel("X",size=14)
1678
+ ax.set_ylabel("Y",size=14)
1679
+ ax.set_zlabel("Z",size=14)
1680
+
1681
+ fp.Plot.fargopy_mark(ax)
1682
+ plt.show()
1683
+
1684
+
1685
+ elif len(sliced_shape) == 2:
1686
+ # Optional smoothing to remove interpolation artefacts (triangular edges)
1687
+ if smoothing_sigma is not None:
1688
+ try:
1689
+ field_data = gaussian_filter(field_data, sigma=smoothing_sigma)
1690
+ except Exception:
1691
+ # If smoothing fails, fall back to original data
1692
+ field_data = field_data
1693
+ if plane == 'XY':
1694
+ fig, ax = plt.subplots(figsize=(8, 6))
1695
+ mesh = ax.pcolormesh(var1, var2, field_data, shading='auto', cmap='Spectral_r')
1696
+ cbar = fig.colorbar(mesh)
1697
+ cbar.ax.tick_params(labelsize=12)
1698
+ if is_vector:
1699
+ cbar.set_label(rf"{component} [units]", size=14)
1700
+ else:
1701
+ cbar.set_label(rf"$\log_{{10}}(field)$", size=14)
1702
+ ax.set_xlabel("X",size=14)
1703
+ ax.set_ylabel("Y",size=14)
1704
+ ax.tick_params(axis='both', which='major', labelsize=12)
1705
+
1706
+ fp.Plot.fargopy_mark(ax)
1707
+ plt.show()
1708
+ elif plane == 'XZ':
1709
+ fig, ax = plt.subplots(figsize=(8, 6))
1710
+ mesh = ax.pcolormesh(var1, var3, field_data, shading='auto', cmap='Spectral_r')
1711
+ cbar = fig.colorbar(mesh)
1712
+ cbar.ax.tick_params(labelsize=12)
1713
+ if is_vector:
1714
+ cbar.set_label(rf"{component} [units]", size=14)
1715
+ else:
1716
+ cbar.set_label(rf"$\log_{{10}}(field)$", size=14)
1717
+ ax.set_xlabel("X",size=14)
1718
+ ax.set_ylabel("Z",size=14)
1719
+ ax.tick_params(axis='both', which='major', labelsize=12)
1720
+
1721
+ fp.Plot.fargopy_mark(ax)
1722
+ plt.show()
1723
+ else:
1724
+ fig, ax = plt.subplots(figsize=(8, 6))
1725
+ mesh = ax.pcolormesh(var1, var2, field_data, shading='auto', cmap='Spectral_r')
1726
+ cbar = fig.colorbar(mesh)
1727
+ cbar.ax.tick_params(labelsize=12)
1728
+ cbar.set_label(rf"$\log_{{10}}(field)$", size=14)
1729
+ ax.set_xlabel("X",size=14)
1730
+ ax.set_ylabel("Y",size=14)
1731
+ ax.tick_params(axis='both', which='major', labelsize=12)
1732
+
1733
+ fp.Plot.fargopy_mark(ax)
1734
+ plt.show()
1735
+
1736
def cut_sector(self, xc, yc, zc, rc, hc, dataframe=None):
    """
    Keep only the mesh points that lie inside a cylinder.

    For every row of the DataFrame the columns 'var1_mesh', 'var2_mesh'
    and 'var3_mesh' are used as the x, y and z coordinates of the mesh
    points. A point is kept when its distance to (xc, yc) in the
    var1/var2 plane is at most ``rc`` and its var3 value lies within
    ``hc / 2`` of ``zc`` (both limits inclusive). The same selection is
    applied to every other column whose name ends in '_mesh'.

    Parameters
    ----------
    xc, yc, zc : float
        Center coordinates of the cylinder.
    rc : float
        Cylinder radius.
    hc : float
        Cylinder height (the cylinder spans zc - hc/2 .. zc + hc/2).
    dataframe : pd.DataFrame, optional
        DataFrame to filter (default: self.df). It must provide the
        columns 'snapshot', 'time', 'var1_mesh', 'var2_mesh' and
        'var3_mesh'. The input is not modified.

    Returns
    -------
    pd.DataFrame
        New DataFrame with one row per input row and the columns
        'snapshot', 'time', the three coordinate columns and every field
        column, in that order. Coordinate and scalar field cells are
        flattened 1-D arrays of the selected points. Vector field cells
        keep their component axis: (ncomp, npoints) for component-first
        input, (npoints, ncomp) for component-last input.

    Raises
    ------
    ValueError
        If no DataFrame is passed and self.df is None.
    IndexError
        If a field array matches the mesh shape neither on its leading
        nor on its trailing axes.
    """
    if dataframe is None:
        if self.df is None:
            raise ValueError("No DataFrame loaded. Run load_data() first or pass a DataFrame.")
        dataframe = self.df

    coord_cols = ('var1_mesh', 'var2_mesh', 'var3_mesh')
    # Every other '*_mesh' column is treated as field data on the same mesh.
    field_cols = [
        col for col in dataframe.columns
        if isinstance(col, str) and col.endswith('_mesh') and col not in coord_cols
    ]

    # The vertical limits do not depend on the row: compute them once.
    z_min = zc - hc / 2
    z_max = zc + hc / 2

    records = []
    for _, row in dataframe.iterrows():
        x = np.asarray(row['var1_mesh'])
        y = np.asarray(row['var2_mesh'])
        z = np.asarray(row['var3_mesh'])

        # Boolean mask selecting the points inside the cylinder.
        r_xy = np.sqrt((x - xc)**2 + (y - yc)**2)
        mask = (r_xy <= rc) & (z >= z_min) & (z <= z_max)

        record = {
            'snapshot': row['snapshot'],
            'time': row['time'],
            'var1_mesh': x[mask],
            'var2_mesh': y[mask],
            'var3_mesh': z[mask],
        }

        for col in field_cols:
            data = np.asarray(row[col])
            matches_leading = data.shape[:mask.ndim] == mask.shape
            matches_trailing = mask.ndim > 0 and data.shape[-mask.ndim:] == mask.shape
            if matches_trailing and not matches_leading:
                # Component-first vector field, e.g. shape (3, ...): the
                # mask has the mesh shape, so it must index the trailing
                # axes; data[mask] would raise IndexError here.
                record[col] = data[..., mask]
            else:
                # Scalar field or component-last vector field. A shape
                # that matches nothing still raises IndexError.
                record[col] = data[mask]
        records.append(record)

    # Passing the columns explicitly keeps the layout stable even when
    # the input has no rows.
    return pd.DataFrame(records, columns=['snapshot', 'time', *coord_cols, *field_cols])