fargopy 0.3.15__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fargopy/fields.py CHANGED
@@ -6,16 +6,44 @@ import fargopy
6
6
  ###############################################################
7
7
  # Required packages
8
8
  ###############################################################
9
+ import os
9
10
  import numpy as np
10
11
  import re
12
+ import re
13
+ import pandas as pd
14
+
15
+ import matplotlib.pyplot as plt
16
+ import plotly.figure_factory as ff
17
+ from plotly.subplots import make_subplots
18
+ import plotly.graph_objects as go
19
+ from matplotlib.animation import FFMpegWriter
20
+ from scipy.interpolate import RBFInterpolator
21
+ from scipy.interpolate import interp1d
22
+ from scipy.interpolate import LinearNDInterpolator
23
+ from scipy.spatial import cKDTree
24
+
25
+
26
+ from joblib import Parallel, delayed, parallel_config
27
+
28
+
29
+ from ipywidgets import interact, FloatSlider, IntSlider
30
+ from celluloid import Camera
31
+ from IPython.display import HTML, Video
32
+
33
+ from scipy.interpolate import griddata
34
+ from scipy.integrate import solve_ivp
35
+ from tqdm import tqdm
36
+ from pathlib import Path
37
+ import fargopy as fp
38
+ from scipy.ndimage import gaussian_filter
11
39
 
12
40
  ###############################################################
13
41
  # Constants
14
42
  ###############################################################
15
43
  # Map of coordinates into FARGO3D coordinates
16
- """This dictionary maps the coordinates regular names (r, phi, theta, etc.) of
17
- different coordinate systems into the FARGO3D x, y, z
18
- """
44
+ # This dictionary maps the conventional coordinate names (r, phi, theta, etc.) of
45
+ # each coordinate system onto the FARGO3D x, y, z axes.
46
+
19
47
  COORDS_MAP = dict(
20
48
  cartesian = dict(x='x',y='y',z='z'),
21
49
  cylindrical = dict(phi='x',r='y',z='z'),
@@ -26,21 +54,54 @@ COORDS_MAP = dict(
26
54
  # Classes
27
55
  ###############################################################
28
56
  class Field(fargopy.Fargobj):
29
- """Fields:
30
-
31
- Attributes:
32
- coordinates: type of coordinates (cartesian, cylindrical, spherical)
33
- data: numpy arrays with data of the field
34
-
35
- Methods:
36
- slice: get an slice of a field along a given spatial direction.
37
- Examples:
38
- >>> density.slice(r=0.5) # Take the closest slice to r = 0.5
39
- >>> density.slice(ir=20) # Take the slice through the 20 shell
40
- >>> density.slice(phi=30*RAD,interp='nearest') # Take a slice interpolating to the nearest
57
+ """Represents a simulation field (scalar or vector) with coordinate system and domain information.
58
+
59
+ The ``Field`` object encapsulates the data arrays and associated
60
+ coordinate meshes for a specific simulation variable. It supports
61
+ slicing, coordinate transformation and simple visualization
62
+ helpers.
63
+
64
+ Attributes
65
+ ----------
66
+ data : np.ndarray
67
+ Numpy array containing the physical data.
68
+ coordinates : str
69
+ Coordinate system type ('cartesian', 'cylindrical', 'spherical').
70
+ domains : object
71
+ Object containing domain-specific geometry (e.g., mesh arrays).
72
+ type : str
73
+ Field type ('scalar' or 'vector').
74
+
75
+ Examples
76
+ --------
77
+ Load a field from a simulation object (returns a Field instance
78
+ if interpolation is disabled or a FieldInterpolator otherwise):
79
+
80
+ >>> fp.Simulation.load_field(fields='gasdens', snapshot=0, interpolate=False)
81
+
82
+ Access data and mesh:
83
+
84
+ >>> rho = field.data
85
+ >>> xmesh = field.mesh.x
41
86
  """
42
87
 
43
88
  def __init__(self,data=None,coordinates='cartesian',domains=None,type='scalar',**kwargs):
89
+ """
90
+ Initialize a Field object.
91
+
92
+ Parameters
93
+ ----------
94
+ data : np.ndarray, optional
95
+ Field data array.
96
+ coordinates : str, optional
97
+ Coordinate system ('cartesian', 'cylindrical', 'spherical').
98
+ domains : object, optional
99
+ Domain information for each coordinate.
100
+ type : str, optional
101
+ Field type ('scalar' or 'vector').
102
+ **kwargs : dict
103
+ Additional keyword arguments.
104
+ """
44
105
  super().__init__(**kwargs)
45
106
  self.data = data
46
107
  self.coordinates = coordinates
@@ -48,9 +109,34 @@ class Field(fargopy.Fargobj):
48
109
  self.type = type
49
110
 
50
111
  def meshslice(self,slice=None,component=0,verbose=False):
51
- """Perform a slice on a field and produce as an output the
52
- corresponding field slice and the associated matrices of
53
- coordinates for plotting.
112
+ """Perform a slice on a field and produce the corresponding field slice and
113
+ associated coordinate matrices for plotting.
114
+
115
+ Parameters
116
+ ----------
117
+ slice : str
118
+ Slice definition string (e.g., 'z=0').
119
+ component : int, optional
120
+ Component index for vector fields (default: 0).
121
+ verbose : bool, optional
122
+ If True, print debug information.
123
+
124
+ Returns
125
+ -------
126
+ tuple
127
+ (sliced field, mesh dictionary with coordinates). The mesh dictionary
128
+ contains coordinate arrays (x, y, z, r, phi, theta) corresponding
129
+ to the slice.
130
+
131
+ Examples
132
+ --------
133
+ Slice a field at z=0:
134
+
135
+ >>> field_slice, mesh = field.meshslice(slice='z=0')
136
+
137
+ Plot the slice:
138
+
139
+ >>> plt.pcolormesh(mesh.x, mesh.y, field_slice)
54
140
  """
55
141
  # Analysis of the slice
56
142
  if slice is None:
@@ -60,7 +146,7 @@ class Field(fargopy.Fargobj):
60
146
  slice = slice.replace('deg','*fargopy.DEG')
61
147
 
62
148
  # Perform the slice
63
- slice_cmd = f"self.slice({slice},pattern=True,verbose={verbose})"
149
+ slice_cmd = f"self._slice({slice},pattern=True,verbose={verbose})"
64
150
  slice,pattern = eval(slice_cmd)
65
151
 
66
152
  # Create the mesh
@@ -99,36 +185,36 @@ class Field(fargopy.Fargobj):
99
185
 
100
186
  return slice,mesh
101
187
 
102
- def slice(self,verbose=False,pattern=False,**kwargs):
103
- """Extract an slice of a 3-dimensional FARGO3D field
104
-
105
- Parameters:
106
- quiet: boolean, default = False:
107
- If True extract the slice quietly.
108
- Else, print some control messages.
109
-
110
- pattern: boolean, default = False:
111
- If True return the pattern of the slice, eg. [:,:,:]
112
-
113
- ir, iphi, itheta, ix, iy, iz: string or integer:
114
- Index or range of indexes of the corresponding coordinate.
188
+ def _slice(self,verbose=False,pattern=False,**kwargs):
189
+ """
190
+ Extract a slice of a 3-dimensional FARGO3D field.
115
191
 
116
- r, phi, theta, x, y, z: float/list/tuple:
117
- Value for slicing. The slicing search for the closest
118
- value in the domain.
192
+ Parameters
193
+ ----------
194
+ verbose : bool, optional
195
+ If True, print debug information.
196
+ pattern : bool, optional
197
+ If True, return the pattern of the slice (e.g., [:,:,:]).
198
+ ir, iphi, itheta, ix, iy, iz : int or str, optional
199
+ Index or range of indexes for the corresponding coordinate.
200
+ r, phi, theta, x, y, z : float, list, or tuple, optional
201
+ Value or range for slicing. The closest value in the domain is used.
119
202
 
120
- Returns:
121
- slice: sliced field.
203
+ Returns
204
+ -------
205
+ np.ndarray or tuple
206
+ Sliced field, and optionally the pattern string if pattern=True.
122
207
 
123
- Examples:
124
- # 0D: Get the value of the field in iphi = 0, itheta = -1 and close to r = 0.82
125
- >>> gasvz.slice(iphi=0,itheta=-1,r=0.82)
208
+ Examples
209
+ --------
210
+ # 0D: Get the value of the field at iphi=0, itheta=-1, and close to r=0.82
211
+ >>> gasvz.slice(iphi=0, itheta=-1, r=0.82)
126
212
 
127
- # 1D: Get all values of the field in radial direction at iphi = 0, itheta = -1
128
- >>> gasvz.slice(iphi=0,itheta=-1)
213
+ # 1D: Get all values in radial direction at iphi=0, itheta=-1
214
+ >>> gasvz.slice(iphi=0, itheta=-1)
129
215
 
130
- # 2D: Get all values of the field for values close to phi = 0
131
- >>> gasvz.slice(phi=0)
216
+ # 2D: Get all values for values close to phi=0
217
+ >>> gasvz.slice(phi=0)
132
218
  """
133
219
  # By default slice
134
220
  ivar = dict(x=':',y=':',z=':')
@@ -195,6 +281,20 @@ class Field(fargopy.Fargobj):
195
281
  return slice
196
282
 
197
283
  def to_cartesian(self):
284
+ """
285
+ Convert the field to cartesian coordinates.
286
+
287
+ Returns
288
+ -------
289
+ Field or tuple of Field
290
+ The field in cartesian coordinates. For scalar fields, returns the field itself.
291
+ For vector fields, returns a tuple (vx, vy, vz).
292
+
293
+ Examples
294
+ --------
295
+ >>> v = sim.load_field('gasv', snapshot=0)
296
+ >>> vx, vy, vz = v.to_cartesian()
297
+ """
198
298
  if self.type == 'scalar':
199
299
  # Scalar fields are invariant under coordinate transformations
200
300
  return self
@@ -234,11 +334,1457 @@ class Field(fargopy.Fargobj):
234
334
  Field(vz,coordinates=self.coordinates,domains=self.domains,type='scalar'))
235
335
 
236
336
  def get_size(self):
337
+ """
338
+ Return the size of the field data in megabytes (MB).
339
+
340
+ Returns
341
+ -------
342
+ float
343
+ Size in MB.
344
+ """
237
345
  return self.data.nbytes/1024**2
238
346
 
239
347
  def __str__(self):
348
+ """
349
+ String representation of the field data.
350
+
351
+ Returns
352
+ -------
353
+ str
354
+ """
240
355
  return str(self.data)
241
356
 
242
357
  def __repr__(self):
358
+ """
359
+ String representation of the field data.
360
+
361
+ Returns
362
+ -------
363
+ str
364
+ """
243
365
  return str(self.data)
244
366
 
367
+
368
+ # ###############################################################
369
+ # FieldInterpolator
370
+ # ###############################################################
371
+ # This class is used to load and interpolate fields from a FARGO3D simulation.
372
+ # It provides methods to load data, create meshes, and perform interpolation.
373
+ # It also handles 2D and 3D data loading based on the provided parameters.
374
+ #################################################################
375
+
376
+
377
+ class FieldInterpolator:
378
+ """Loads and interpolates fields from a FARGO3D simulation.
379
+
380
+ The ``FieldInterpolator`` facilitates loading, slicing, and interpolating
381
+ simulation data across multiple snapshots and fields. It handles coordinate
382
+ transformations and dimensionality reduction based on slice definitions.
383
+
384
+ Attributes
385
+ ----------
386
+ sim : Simulation
387
+ The simulation object.
388
+ df : pd.DataFrame or None
389
+ DataFrame containing loaded field data organized by snapshot and time.
390
+ snapshot_time_table : pd.DataFrame or None
391
+ Table mapping snapshots to normalized time.
392
+ snapshot : list or None
393
+ List of loaded snapshots.
394
+ slice : str or None
395
+ Slice definition string used to load the data.
396
+ dim : int or None
397
+ Dimensionality of the loaded data (e.g., 2 for a 2D slice).
398
+
399
+ Examples
400
+ --------
401
+ Load multiple fields from a snapshot with interpolation enabled:
402
+
403
+ >>> data = sim.load_field(fields=['gasdens', 'gasv'], snapshot=4)
404
+
405
+ Load a specific slice:
406
+
407
+ >>> dens = sim.load_field(fields='gasdens', slice='r=[0.8,1.2],phi=[-25 deg,25 deg],theta=1.56', snapshot=4)
408
+ """
409
+
410
def __init__(self, sim, df=None):
    """
    Bind a new interpolator to a simulation.

    Parameters
    ----------
    sim : Simulation
        Simulation object whose fields and domains are read later.
    df : pd.DataFrame, optional
        Already-loaded field data to wrap; left as ``None`` otherwise.
    """
    # Objects supplied by the caller.
    self.sim, self.df = sim, df
    # Metadata describing the most recent load; empty until data is loaded.
    self.snapshot_time_table = self.snapshot = self.slice = self.dim = None
    # Private caches, filled lazily by helper methods.
    self._domain_limits = self._df_sorted_cache = None
    self._slice_type = self._slice_ranges = None
431
+
432
def _reset_caches(self):
    """Forget the memoised sorted dataframe and the parsed slice metadata.

    Called before new data is loaded so results from an earlier load are
    never reused. The cached domain limits are left untouched.
    """
    for attr_name in ("_df_sorted_cache", "_slice_type", "_slice_ranges"):
        setattr(self, attr_name, None)
437
+
438
def _cache_domain_limits(self):
    """Memoise the (min, max) extent of the r, theta and phi domains.

    The limits are read from ``self.sim.domains`` only on the first call;
    subsequent calls are no-ops while ``self._domain_limits`` is set.
    """
    if self._domain_limits is None:
        domains = self.sim.domains
        limits = {}
        for axis in ("r", "theta", "phi"):
            values = getattr(domains, axis)
            limits[axis] = (values.min(), values.max())
        self._domain_limits = limits
448
+
449
def _detect_slice_type(self, slice_str):
    """Classify a slice string by which angles it fixes to a single value.

    Spaces are ignored and matching is case-insensitive.

    Returns
    -------
    str or None
        ``"r"`` when both theta and phi are scalar values, ``"theta"`` or
        ``"phi"`` when only that angle is, and ``None`` otherwise (including
        an empty or ``None`` string). An angle given as a ``[lo,hi]`` range
        does not count as fixed.
    """
    if not slice_str:
        return None
    compact = slice_str.replace(" ", "").lower()

    def _is_fixed(scalar_regex, range_regex):
        # Scalar assignment present and no bracketed range for that angle.
        return bool(re.search(scalar_regex, compact)) and not re.search(range_regex, compact)

    theta_fixed = _is_fixed(r"theta=([^\[\],]+)(?![\]])", r"theta=\[")
    phi_fixed = _is_fixed(r"phi=([^\[\],]+)(?![\]])", r"phi=\[")

    if theta_fixed and phi_fixed:
        return "r"
    if theta_fixed:
        return "theta"
    if phi_fixed:
        return "phi"
    return None
463
+
464
def _parse_slice_ranges(self, slice_str):
    """Extract numeric (lo, hi) bounds for r, theta, phi and z from a slice string.

    Values written as ``<number> deg`` are converted to radians. A bracketed
    range ``key=[a,b]`` yields ``(min(a, b), max(a, b))``; a scalar
    ``key=v`` yields ``(v, v)`` unless a range was already found for that
    key. Keys that do not appear map to ``None``.

    Returns
    -------
    dict
        Mapping with keys "r", "theta", "phi" and "z".
    """
    bounds = dict.fromkeys(("r", "theta", "phi", "z"))
    if not slice_str:
        return bounds
    lowered = slice_str.lower()
    degree_regex = re.compile(r"(-?\d+(?:\.\d+)?)\s*deg")

    def as_number(token):
        token = token.strip()
        hit = degree_regex.match(token)
        if hit:
            return np.deg2rad(float(hit.group(1)))
        return float(token)

    # Bracketed ranges take precedence over scalar assignments.
    for name, body in re.findall(r"(r|theta|phi|z)=\[(.+?)\]", lowered):
        first, second = [as_number(part) for part in body.split(",")]
        bounds[name] = (min(first, second), max(first, second))
    for name, raw in re.findall(r"(r|theta|phi|z)=([^\[\],]+)", lowered):
        if bounds.get(name) is None:
            scalar = as_number(raw)
            bounds[name] = (scalar, scalar)
    return bounds
487
+
488
def _get_sorted_dataframe(self, dataframe):
    """Return ``dataframe`` sorted by its "time" column.

    The sorted copy is memoised in ``self._df_sorted_cache`` as a
    ``(source, sorted)`` pair and reused while the same dataframe object
    (compared by identity) is passed in again.
    """
    cache = self._df_sorted_cache
    cache_hit = bool(cache) and cache[0] is dataframe
    if not cache_hit:
        cache = (dataframe, dataframe.sort_values("time"))
        self._df_sorted_cache = cache
    return cache[1]
495
+
496
def __getattr__(self, name):
    """
    Forward attribute lookups that fail on the interpolator to its DataFrame.

    Parameters
    ----------
    name : str
        Attribute name that was not found through normal lookup.

    Returns
    -------
    object
        ``getattr(self.df, name)`` when ``self.df`` is set and has it.

    Raises
    ------
    AttributeError
        If there is no DataFrame or it lacks the attribute.
    """
    # object.__getattribute__ makes a missing ``df`` raise AttributeError
    # directly instead of re-entering this method.
    wrapped = object.__getattribute__(self, 'df')
    if wrapped is None or not hasattr(wrapped, name):
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
    return getattr(wrapped, name)
519
+
520
def _run_parallel(self, tasks, backend='threading'):
    """Run joblib delayed tasks with all available workers.

    Parameters
    ----------
    tasks : iterable
        Delayed calls (``joblib.delayed(f)(...)``); consumed into a list.
    backend : str, optional
        joblib backend name passed to ``Parallel`` (default 'threading').

    Returns
    -------
    list
        Results of the tasks; an empty list when there are no tasks.
    """
    pending = list(tasks)
    if pending:
        with parallel_config(n_jobs=-1, prefer='threads'):
            return Parallel(backend=backend)(pending)
    return []
526
+
527
def load_data(self, fields=None, slice=None, snapshots=1, cut=None, coords='cartesian'):
    """Load one or multiple fields into a unified DataFrame.

    This method loads simulation data for the specified fields and snapshots,
    potentially applying a spatial slice or cut. The data is stored in
    an internal DataFrame (`self.df`) for further processing or interpolation.

    Each row corresponds to one snapshot and holds ``snapshot``, ``time``
    (0 for a single snapshot, otherwise spread uniformly over [0, 1]), the
    coordinate arrays ``var1_mesh``, ``var2_mesh``, ``var3_mesh`` and one
    ``<field>_mesh`` column per supported field ('gasdens', 'gasv',
    'gasenergy'). Other field names are ignored.

    Parameters
    ----------
    fields : str or list of str
        Name(s) of the fields to load ('gasdens', 'gasv', 'gasenergy').
    slice : str, optional
        Slice definition (e.g., 'z=0', 'theta=1.57'). If the sliced data has
        fewer than 3 dimensions the sliced (2D) path is used; otherwise the
        full 3D domain is loaded.
    snapshots : int or list of int, optional
        A single snapshot (Python or NumPy integer), a one-element list, or
        an inclusive range ``[start, end]``.
    cut : list, optional
        Geometric mask used only in the 3D path: ``[xc, yc, zc, radius]``
        keeps a sphere, ``[xc, yc, zc, radius, height]`` a z-aligned cylinder.
    coords : str, optional
        'cartesian' (default) stores x, y, z coordinates and cartesian vector
        components; any other value stores the native coordinates named by
        ``sim.vars.VARIABLES`` and the raw vector components.

    Returns
    -------
    pd.DataFrame
        The DataFrame containing the loaded data (also stored in ``self.df``).
        ``None`` is returned if the detected dimensionality exceeds 3.

    Raises
    ------
    ValueError
        If ``fields`` is None, the simulation domains are not loaded, or
        ``cut`` does not have 4 or 5 elements.

    Notes
    -----
    In the sliced path with cartesian coordinates, scalar fields ('gasdens'
    and 'gasenergy') on a constant-phi slice get coordinates rotated by
    -phi, so the slice lies in the x-z plane. 'gasv' coordinates are not
    rotated. Coordinates are taken from the first field in ``fields`` that
    provides them.

    Examples
    --------
    Load density and velocity for snapshot 10:

    >>> loader = fp.FieldInterpolator(sim)
    >>> df = loader.load_data(fields=['gasdens', 'gasv'], snapshots=10)

    Load a 2D slice at z=0 (midplane):

    >>> df_slice = loader.load_data(fields='gasdens', slice='z=0', snapshots=10)
    """

    # -------------------------
    # Validate arguments
    # -------------------------
    self._reset_caches()

    if fields is None:
        raise ValueError("You must specify at least one field.")
    if isinstance(fields, str):
        fields = [fields]

    self.fields = fields
    self.slice = slice

    # Normalize a single snapshot into a list. NumPy integer scalars (e.g.
    # values taken from np.arange) are accepted as well as Python ints.
    if isinstance(snapshots, (int, np.integer)):
        snapshots = [snapshots]
    self.snapshot = snapshots

    # Detect dimensionality by slicing a density field (if a slice is given).
    if slice is not None:
        test_field = self.sim._load_field_raw('gasdens', snapshot=snapshots[0], field_type='scalar')
        try:
            data_slice, _ = test_field.meshslice(slice=slice)
            self.dim = len(np.array(data_slice).shape)
        except Exception:
            # Fallback: assume full 3D
            self.dim = 3
    else:
        self.dim = 3

    # Snapshot & time arrays: one snapshot gets t=0, an inclusive
    # [start, end] range is spread uniformly over [0, 1].
    if len(snapshots) == 1:
        snaps = snapshots
        time_values = [0]
    else:
        snaps = np.arange(snapshots[0], snapshots[1] + 1)
        time_values = np.linspace(0, 1, len(snaps))

    # Store snapshot-time table
    self.snapshot_time_table = pd.DataFrame({
        "Snapshot": snaps,
        "Normalized_time": time_values
    })

    # Slice handling
    if not hasattr(self.sim, "domains") or self.sim.domains is None:
        raise ValueError("Simulation domains are not loaded.")
    self._cache_domain_limits()
    self._slice_type = self._detect_slice_type(slice)
    self._slice_ranges = self._parse_slice_ranges(slice)

    def _rotation(X, Y, Z, phi0):
        """Rotate (X, Y) by -phi0 about the z axis; Z is returned as a copy."""
        X_rot = X * np.cos(phi0) + Y * np.sin(phi0)
        Y_rot = -X * np.sin(phi0) + Y * np.cos(phi0)
        return X_rot, Y_rot, Z.copy()

    def _set_coords(row, triple):
        """Store a coordinate triple in the row's var1/var2/var3 columns."""
        row['var1_mesh'], row['var2_mesh'], row['var3_mesh'] = triple

    def _coords_2d(mesh, rotate_fixed_phi):
        """Coordinate triple for a sliced mesh.

        Cartesian mode returns x, y, z, rotated into the x-z plane when
        ``rotate_fixed_phi`` is set and phi is constant over the mesh. Other
        modes return the native coordinates named by ``sim.vars.VARIABLES``.
        """
        if coords != 'cartesian':
            vnames = getattr(self.sim.vars, 'VARIABLES', ['x', 'y', 'z'])
            return getattr(mesh, vnames[0]), getattr(mesh, vnames[1]), getattr(mesh, vnames[2])
        if rotate_fixed_phi:
            try:
                phi_flat = mesh.phi.ravel()
                if np.all(phi_flat == phi_flat[0]):
                    return _rotation(mesh.x, mesh.y, mesh.z, phi_flat[0])
            except Exception:
                # Best-effort: a mesh without a usable phi keeps plain x, y, z.
                pass
        return mesh.x, mesh.y, mesh.z

    def _collect_rows_2d():
        """Build one row per snapshot from sliced (fewer than 3D) data."""
        rows = []
        for snap, t in zip(snaps, time_values):
            row = {'snapshot': snap, 'time': t}
            coords_assigned = False  # only the first field sets var1/var2/var3
            for field in fields:
                if field in ('gasdens', 'gasenergy'):
                    scalar = self.sim._load_field_raw(field, snapshot=snap, field_type='scalar')
                    data_slice, mesh = scalar.meshslice(slice=slice)
                    if not coords_assigned:
                        # All scalar fields share the same (possibly rotated)
                        # coordinates, whichever of them is loaded first.
                        _set_coords(row, _coords_2d(mesh, rotate_fixed_phi=True))
                        coords_assigned = True
                    row[f'{field}_mesh'] = data_slice
                elif field == 'gasv':
                    gasv_raw = self.sim._load_field_raw('gasv', snapshot=snap, field_type='vector')
                    if coords == 'cartesian':
                        gasvx, gasvy, gasvz = gasv_raw.to_cartesian()
                        v1, mesh = gasvx.meshslice(slice=slice)
                        v2, mesh = gasvy.meshslice(slice=slice)
                        v3, mesh = gasvz.meshslice(slice=slice)
                    else:
                        v_slice, mesh = gasv_raw.meshslice(slice=slice)
                        v1, v2, v3 = v_slice[0], v_slice[1], v_slice[2]
                    if not coords_assigned:
                        # Vector components are not rotated, so their
                        # coordinates are left unrotated as well.
                        _set_coords(row, _coords_2d(mesh, rotate_fixed_phi=False))
                        coords_assigned = True
                    row['gasv_mesh'] = np.array([v1, v2, v3])
            rows.append(row)
        return rows

    def _collect_rows_3d():
        """Build one row per snapshot from full 3D data, optionally masked by ``cut``."""
        dom = self.sim.domains
        theta, r, phi = np.meshgrid(dom.theta, dom.r, dom.phi, indexing='ij')
        x = r * np.sin(theta) * np.cos(phi)
        y = r * np.sin(theta) * np.sin(phi)
        z = r * np.cos(theta)

        # Apply spherical or cylindrical mask (optional)
        if cut is None:
            mask = None
        elif len(cut) == 5:
            xc, yc, zc, rc, hc = cut
            r_xy = np.sqrt((x - xc)**2 + (y - yc)**2)
            mask = (r_xy <= rc) & (z >= zc - hc / 2) & (z <= zc + hc / 2)
        elif len(cut) == 4:
            xc, yc, zc, rs = cut
            r_sph = np.sqrt((x - xc)**2 + (y - yc)**2 + (z - zc)**2)
            mask = r_sph <= rs
        else:
            raise ValueError("cut must have 4 (sphere) or 5 (cylinder) elements.")

        def _select(arr):
            """Apply the optional cut mask (flattening to 1D when masked)."""
            return arr if mask is None else arr[mask]

        def _coords_3d():
            """Coordinate triple for the full grid, in cartesian or native form."""
            if coords == 'cartesian':
                return _select(x), _select(y), _select(z)
            v0, v1, v2 = self.sim.vars.VARIABLES
            mapping = dict(r=r, phi=phi, theta=theta, x=x, y=y, z=z)
            return _select(mapping[v0]), _select(mapping[v1]), _select(mapping[v2])

        rows = []
        for snap, t in zip(snaps, time_values):
            row = {'snapshot': snap, 'time': t}
            coords_assigned = False
            for field in fields:
                if field in ('gasdens', 'gasenergy'):
                    scalar = self.sim._load_field_raw(field, snapshot=snap, field_type='scalar')
                    if not coords_assigned:
                        _set_coords(row, _coords_3d())
                        coords_assigned = True
                    row[f'{field}_mesh'] = _select(scalar.data)
                elif field == 'gasv':
                    gasv_raw = self.sim._load_field_raw('gasv', snapshot=snap, field_type='vector')
                    if coords == 'cartesian':
                        gasvx, gasvy, gasvz = gasv_raw.to_cartesian()
                        components = (gasvx.data, gasvy.data, gasvz.data)
                    else:
                        vdata = gasv_raw.data
                        components = (vdata[0], vdata[1], vdata[2])
                    if not coords_assigned:
                        _set_coords(row, _coords_3d())
                        coords_assigned = True
                    row['gasv_mesh'] = np.array([_select(c) for c in components])
            rows.append(row)
        return rows

    if self.dim < 3:
        rows = _collect_rows_2d()
    elif self.dim == 3:
        rows = _collect_rows_3d()
    else:
        # Dimensionality above 3 is not handled; nothing is loaded.
        return None

    # Build the DataFrame once instead of concatenating per snapshot.
    df_snapshots = pd.DataFrame(rows)
    self.df = df_snapshots
    return df_snapshots
895
+
896
+
897
def times(self):
    """
    Give the table that maps each loaded snapshot to its normalized time.

    Returns
    -------
    pd.DataFrame
        Table with the columns 'Snapshot' and 'Normalized_time'.

    Raises
    ------
    ValueError
        When no snapshot/time table has been built yet.
    """
    table = self.snapshot_time_table
    if table is not None:
        return table
    raise ValueError("No data loaded. Run load_data() first.")
914
+
915
def create_mesh(
    self,
    slice=None,
    nr=50,
    ntheta=50,
    nphi=50
):
    """
    Create a mesh grid based on the slice definition provided by the user.
    If no slice is provided, create a full 3D mesh within the simulation domain.

    Parameters
    ----------
    slice : str, optional
        The slice definition string (e.g., "r=[0.8,1.2],phi=0,theta=[0 deg,90 deg]").
        Angles may be written in radians or as ``<value> deg`` (with or
        without a space). Recognised keys: ``r`` (range only), ``theta`` and
        ``phi`` (range or single value), ``z`` (single value).
    nr : int
        Number of divisions in r.
    ntheta : int
        Number of divisions in theta.
    nphi : int
        Number of divisions in phi.

    Returns
    -------
    tuple of np.ndarray
        Cartesian (x, y, z) arrays. Shapes:

        - volume: (ntheta, nr, nphi);
        - constant z: (nr, nphi), with cylindrical radius r;
        - constant phi: (ntheta, nr);
        - constant theta: (nphi, nr).

    Raises
    ------
    ValueError
        If a number in the slice string cannot be parsed.
    """
    dom = self.sim.domains

    def _spherical_volume(r_lims, theta_lims, phi_lims):
        """Full 3D spherical grid converted to cartesian coordinates."""
        r = np.linspace(r_lims[0], r_lims[1], nr)
        theta = np.linspace(theta_lims[0], theta_lims[1], ntheta)
        phi = np.linspace(phi_lims[0], phi_lims[1], nphi)
        theta_grid, r_grid, phi_grid = np.meshgrid(theta, r, phi, indexing='ij')
        x = r_grid * np.sin(theta_grid) * np.cos(phi_grid)
        y = r_grid * np.sin(theta_grid) * np.sin(phi_grid)
        z = r_grid * np.cos(theta_grid)
        return x, y, z

    # If no slice is provided, create a full 3D mesh using the simulation domains
    if not slice:
        return _spherical_volume(
            [dom.r.min(), dom.r.max()],
            [dom.theta.min(), dom.theta.max()],
            [dom.phi.min(), dom.phi.max()],
        )

    # Initialize default ranges
    r_range = [dom.r.min(), dom.r.max()]
    theta_range = [dom.theta.min(), dom.theta.max()]
    phi_range = [dom.phi.min(), dom.phi.max()]
    z_value = None

    # Regular expressions to extract parameters. Single values may carry a
    # "deg" suffix, which must be captured so it can be converted to radians.
    range_pattern = re.compile(r"(\w+)=\[(.+?)\]")                    # r=[0.8,1.2]
    value_pattern = re.compile(r"(\w+)=([-\d.]+(?:\s*deg)?)")         # phi=0, z=0, phi=30 deg
    degree_pattern = re.compile(r"([-\d.]+)\s*deg")                   # -25 deg, -25deg

    def _to_number(text):
        """Parse a number, converting a ``<value> deg`` token to radians."""
        text = text.strip()
        in_degrees = degree_pattern.fullmatch(text)
        if in_degrees:
            return float(in_degrees.group(1)) * np.pi / 180
        return float(text)

    # Process ranges
    for match in range_pattern.finditer(slice):
        key, values = match.groups()
        values = [_to_number(v) for v in values.split(',')]
        if key == 'r':
            r_range = values
        elif key == 'phi':
            phi_range = values
        elif key == 'theta':
            theta_range = values

    # Process single values
    for match in value_pattern.finditer(slice):
        key, value = match.groups()
        value = _to_number(value)
        if key == 'z':
            z_value = value
        elif key == 'phi':
            phi_range = [value, value]
        elif key == 'theta':
            theta_range = [value, value]

    # Slice at constant z (XY plane in cartesian). Checked first: with the
    # default angular ranges the volume branch below would otherwise always
    # match and the z slice would never be produced.
    if z_value is not None:
        r = np.linspace(r_range[0], r_range[1], nr)
        phi = np.linspace(phi_range[0], phi_range[1], nphi)
        r_grid, phi_grid = np.meshgrid(r, phi, indexing='ij')
        x = r_grid * np.cos(phi_grid)
        y = r_grid * np.sin(phi_grid)
        z = np.full_like(x, z_value)
        return x, y, z

    # 3D mesh: both angles are intervals
    if (phi_range[0] != phi_range[1]) and (theta_range[0] != theta_range[1]):
        return _spherical_volume(r_range, theta_range, phi_range)

    # 2D mesh: one angle is fixed (slice)
    if phi_range[0] == phi_range[1]:  # Slice at constant phi (XZ plane)
        r = np.linspace(r_range[0], r_range[1], nr)
        theta = np.linspace(theta_range[0], theta_range[1], ntheta)
        phi = phi_range[0]
        theta_grid, r_grid = np.meshgrid(theta, r, indexing='ij')
        x = r_grid * np.sin(theta_grid) * np.cos(phi)
        y = r_grid * np.sin(theta_grid) * np.sin(phi)
        z = r_grid * np.cos(theta_grid)
        return x, y, z

    if theta_range[0] == theta_range[1]:  # Slice at constant theta (XY plane)
        r = np.linspace(r_range[0], r_range[1], nr)
        phi = np.linspace(phi_range[0], phi_range[1], nphi)
        theta = theta_range[0]
        phi_grid, r_grid = np.meshgrid(phi, r, indexing='ij')
        x = r_grid * np.sin(theta) * np.cos(phi_grid)
        y = r_grid * np.sin(theta) * np.sin(phi_grid)
        z = r_grid * np.cos(theta)
        return x, y, z

    # Defensive: the branches above cover every combination of ranges.
    raise ValueError("Slice definition must include either 'z', 'phi', or 'theta'.")
1032
+
1033
+
1034
+
1035
def _domain_mask(self, xi, slice=None):
    """
    Build a boolean mask selecting evaluation points inside the simulation domain.

    Points are kept when they lie within the cached domain limits
    (``self._domain_limits``) or, when the slice specifies them, within the
    user's radial/angular ranges. Every bound is tested with a 1e-7 tolerance,
    except in the default 2D branch, which uses strict ``r_min < r < r_max``.

    Parameters
    ----------
    xi : array-like
        Evaluation points. Shape (N,) for 1D cuts. Shape (N, 2) for 2D cuts:
        (x, y) when the slice fixes theta (XY plane), (x, z) when it fixes phi
        (XZ plane). Shape (N, 3) for 3D, which is never masked.
    slice : str, optional
        Slice specification; defaults to ``self.slice``.

    Returns
    -------
    numpy.ndarray of bool, shape (N,)
        True where the point should be interpolated.
    """
    self._cache_domain_limits()
    slice = slice or self.slice
    slice_str = slice or ""
    slice_ranges = self._slice_ranges or self._parse_slice_ranges(slice)
    r_bounds = slice_ranges.get('r')
    theta_bounds = slice_ranges.get('theta')
    phi_bounds = slice_ranges.get('phi')
    r_min, r_max = self._domain_limits['r']
    theta_min, theta_max = self._domain_limits['theta']
    phi_min, phi_max = self._domain_limits['phi']
    eps = 1e-7
    xi = np.asarray(xi)
    ndim = xi.shape[1] if xi.ndim > 1 else 1

    def _bounded(vals, bounds, default):
        # Inclusive (lower, upper) tests against user bounds, or the domain if none given.
        lo, hi = default if bounds is None else bounds
        return vals >= lo - eps, vals <= hi + eps

    def _phi_in_bounds(phi_vals):
        # lo > hi denotes a wrap-around azimuthal range.
        if phi_bounds is None:
            return np.ones_like(phi_vals, dtype=bool)
        lo, hi = phi_bounds
        if lo <= hi:
            return (phi_vals >= lo - eps) & (phi_vals <= hi + eps)
        return (phi_vals >= lo - eps) | (phi_vals <= hi + eps)

    def _fixed(name):
        # True when the slice pins `name` to a scalar ("theta=1.57"),
        # as opposed to a range ("theta=[1.3,1.8]").
        return re.search(rf"\b{name}\s*=\s*[^\s\[]", slice_str) is not None

    if ndim == 2:
        # A fixed angle decides the plane. A bare substring test is only a
        # fallback, because an XZ cut may still mention a theta *range*.
        if _fixed("theta"):
            plane = "xy"
        elif _fixed("phi"):
            plane = "xz"
        elif "theta" in slice_str:
            plane = "xy"
        elif "phi" in slice_str:
            plane = "xz"
        else:
            plane = None

        if plane == "xy":
            # XY plane (theta fixed): filter by r and phi
            x, y = xi[:, 0], xi[:, 1]
            r = np.sqrt(x**2 + y**2)
            phi = np.arctan2(y, x)
            r_ge, r_le = _bounded(r, r_bounds, (r_min, r_max))
            return r_ge & r_le & _phi_in_bounds(phi)

        if plane == "xz":
            # XZ plane (phi fixed): filter by r and theta
            x, z = xi[:, 0], xi[:, 1]
            r = np.sqrt(x**2 + z**2)
            # Clip the cosine: rounding can push |z/r| slightly above 1 (-> NaN).
            cos_theta = np.clip(z / np.clip(r, 1e-14, None), -1.0, 1.0)
            theta = np.arccos(cos_theta)
            r_ge, r_le = _bounded(r, r_bounds, (r_min, r_max))
            if theta_bounds:
                lo, hi = theta_bounds
                theta_mask = (theta >= lo - eps) & (theta <= hi + eps)
            else:
                theta_mask = (
                    ((theta > theta_min) | np.isclose(theta, theta_min, atol=eps)) &
                    ((theta < theta_max) | np.isclose(theta, theta_max, atol=eps))
                )
            return r_ge & r_le & theta_mask

        # Default: treat as XY and filter by domain radius only (strict bounds)
        r = np.sqrt(xi[:, 0]**2 + xi[:, 1]**2)
        return (r > r_min) & (r < r_max)

    if ndim == 1:
        # 1D input: could be r, theta, or phi
        xi_1d = xi.ravel()

        def _is_range(b):
            return (b is not None) and (abs(b[1] - b[0]) > 1e-12)

        # Decide which variable is the "free" one in the 1D cut. Explicit
        # ranges win; otherwise it is the variable absent from the slice.
        if _is_range(r_bounds):
            free = 'r'
        elif _is_range(theta_bounds):
            free = 'theta'
        elif _is_range(phi_bounds):
            free = 'phi'
        else:
            s_low = slice_str.lower()
            has_r = re.search(r"\br\s*=", s_low) is not None
            has_th = re.search(r"\btheta\s*=", s_low) is not None
            has_ph = re.search(r"\bphi\s*=", s_low) is not None
            if not has_r:
                free = 'r'
            elif not has_th:
                free = 'theta'
            elif not has_ph:
                free = 'phi'
            else:
                free = 'r'

        if free == 'r':
            lo, hi = r_bounds if r_bounds is not None else (r_min, r_max)
            return (xi_1d >= lo - eps) & (xi_1d <= hi + eps)

        if free == 'theta':
            lo, hi = theta_bounds if theta_bounds is not None else (theta_min, theta_max)
            return (xi_1d >= lo - eps) & (xi_1d <= hi + eps)

        # free == 'phi'
        lo, hi = phi_bounds if phi_bounds is not None else (phi_min, phi_max)
        if lo <= hi:
            return (xi_1d >= lo - eps) & (xi_1d <= hi + eps)
        # wrap-around range (e.g. [5.5, 0.5] in radians)
        return (xi_1d >= lo - eps) | (xi_1d <= hi + eps)

    # 3D (or higher) coordinates: no masking is applied
    return np.ones(xi.shape[0], dtype=bool)
+
1165
def evaluate(
    self, time, var1, var2=None, var3=None, dataframe=None,
    interpolator="griddata", method="linear",
    rbf_kwargs=None, griddata_kwargs=None, idw_kwargs=None,
    sigma_smooth=None, field=None, reflect=False
):
    """
    Evaluate the selected field at arbitrary spatial coordinates.

    The field is interpolated spatially in the two snapshots that bracket
    ``time`` (compared against the DataFrame's ``time`` column). The two
    results are then blended linearly in time. Scalar and vector fields,
    1D/2D/3D geometry and several interpolation backends are supported.

    Parameters
    ----------
    time : float
        Time at which to evaluate, compared against ``df["time"]``. Outside
        the stored range the first/last snapshot pair is used and the blend
        factor is clipped to [0, 1].
    var1, var2, var3 : array-like or float
        Evaluation coordinates. Use (x, y, z) in 3D, (x, y) for XY (theta)
        cuts and (x, z) for other 2D cuts; in that case ``var2`` is accepted
        in place of ``var3``. In 1D, ``var1`` is the radius. Scalars and
        lists are accepted.
    dataframe : pandas.DataFrame, optional
        Custom DataFrame. If omitted, ``self.df`` is used.
    interpolator : {"griddata", "rbf", "linearnd", "idw"}
        Spatial backend. Unknown names fall back to "griddata".
    method : str
        griddata ``method`` or RBFInterpolator ``kernel``.
    rbf_kwargs, griddata_kwargs : dict, optional
        Extra keyword arguments forwarded to RBFInterpolator / griddata
        (e.g. ``{"fill_value": 0.0}`` for griddata).
    idw_kwargs : dict, optional
        Inverse-distance options: ``k`` (neighbours, default 8) and
        ``power`` (default 2).
    sigma_smooth : float or None
        Optional Gaussian smoothing of the result. Vector results are
        smoothed per component; single-point results are left untouched.
    field : {"gasdens", "gasv", "gasenergy"} or column name, optional
        Field to evaluate. Required when the DataFrame holds more than one.
    reflect : bool
        Augment the data with its mirror image across z = 0 (vz changes
        sign) and disable domain masking. Reflection is skipped for XY
        (theta) cuts and 1D data, where it would add nothing new.

    Returns
    -------
    float or ndarray
        Interpolated value(s) with the shape of ``var1``. A vector field
        returns an array of shape (3,) for scalar input, otherwise (3, ...).

    Raises
    ------
    ValueError
        If ``sigma_smooth`` is not positive, no DataFrame/snapshot is
        available, or the field is missing or ambiguous.
    """
    # ===============================================================
    # Basic validation
    # ===============================================================
    if sigma_smooth is not None and sigma_smooth <= 0:
        raise ValueError("sigma_smooth must be None or positive.")

    df = dataframe if dataframe is not None else self.df
    if df is None:
        raise ValueError("No dataframe available.")

    # Normalise backend options once. Previously idw_kwargs=None crashed the
    # IDW backend (None.get) and griddata_kwargs was silently ignored.
    rbf_kwargs = dict(rbf_kwargs or {})
    griddata_kwargs = dict(griddata_kwargs or {})
    idw_kwargs = dict(idw_kwargs or {})

    # ===============================================================
    # FIELD SELECTION (safe and explicit)
    # ===============================================================
    field_map = {
        "gasdens": "gasdens_mesh",
        "gasv": "gasv_mesh",
        "gasenergy": "gasenergy_mesh",
    }

    if field is not None:
        field = field_map.get(field, field)
        if field not in df.columns:
            raise ValueError(
                f"Field '{field}' not in DF. Available: {list(df.columns)}"
            )
    else:
        # Autodetect only if exactly one known field column exists
        candidates = [c for c in df.columns if c in field_map.values()]
        if not candidates:
            raise ValueError(
                "No field column found in DF. Expected one of "
                f"{list(field_map.values())}."
            )
        if len(candidates) > 1:
            raise ValueError(
                f"Multiple fields present {candidates}. "
                "Specify field='gasdens', 'gasv', or 'gasenergy'."
            )
        field = candidates[0]
    is_vector = field == "gasv_mesh"

    # ===============================================================
    # Prepare snapshot ordering and evaluation coordinates
    # ===============================================================
    df_sorted = self._get_sorted_dataframe(df)
    times = df_sorted["time"].values
    nsnaps = len(times)
    if nsnaps == 0:
        raise ValueError("No snapshots available in dataframe.")

    # Detect scalar inputs (the result is then returned as a plain scalar)
    is_scalar = (
        np.isscalar(var1)
        and (var2 is None or np.isscalar(var2))
        and (var3 is None or np.isscalar(var3))
    )
    result_shape = () if is_scalar else np.asarray(var1).shape

    # Promote inputs to ndarrays (scalars -> shape (1,), lists -> arrays)
    # so ``.ravel()`` works below; None stays None.
    var1, var2, var3 = (
        None if v is None else np.atleast_1d(v) for v in (var1, var2, var3)
    )

    # Convenience: for 2D XZ slices the expected coordinates are
    # (var1, var3); accept evaluate(var1=..., var2=...) as well.
    try:
        slice_type_tmp = self._slice_type or self._detect_slice_type(self.slice)
    except Exception:
        slice_type_tmp = None
    if self.dim == 2 and slice_type_tmp is not None and slice_type_tmp != 'theta':
        if var3 is None and var2 is not None:
            var3 = var2
            var2 = None

    # ===============================================================
    # Smoothing helper
    # ===============================================================
    def _smooth(values):
        if sigma_smooth is None or np.isscalar(values):
            return values
        arr = np.asarray(values)
        if arr.ndim == 0:
            return values
        if is_vector:
            # A (3,) array is one evaluation point: smoothing it would
            # blend vx, vy and vz together, so leave it untouched.
            if arr.ndim < 2:
                return arr
            out = np.empty_like(arr)
            for k in range(arr.shape[0]):
                out[k] = gaussian_filter(arr[k], sigma=sigma_smooth)
            return out
        return gaussian_filter(arr, sigma=sigma_smooth)

    # ===============================================================
    # Interpolation backends
    # ===============================================================
    def _points(a):
        # KD-tree / RBF need (N, D) arrays; 1D radial data arrives as (N,).
        a = np.asarray(a)
        return a.reshape(-1, 1) if a.ndim == 1 else a

    def _mask_for(xi):
        # With reflection the mirrored data cover the full plane, so no masking.
        if reflect:
            return np.ones(xi.shape[0], dtype=bool)
        return self._domain_mask(xi)

    def idw_interp(coords, values, xi):
        values = np.asarray(values).ravel()
        xi = np.asarray(xi)
        mask = _mask_for(xi)
        coords_p, xi_p = _points(coords), _points(xi)
        # k cannot exceed the number of data points (cKDTree would return
        # out-of-range indices).
        k = min(int(idw_kwargs.get("k", 8)), coords_p.shape[0])
        power = idw_kwargs.get("power", 2)
        tree = cKDTree(coords_p)

        def _weighted(query):
            d, idxs = tree.query(query, k=k)
            # k == 1 yields 1-D results; normalise to (M, k)
            d = np.asarray(d).reshape(len(query), k)
            idxs = np.asarray(idxs).reshape(len(query), k)
            d = np.where(d == 0, 1e-10, d)
            w = 1 / d**power
            w /= w.sum(axis=1, keepdims=True)
            return np.sum(values[idxs] * w, axis=1)

        # Points outside the mask are reported as 0 by this backend.
        if np.any(mask):
            out = np.zeros(xi_p.shape[0])
            out[mask] = _weighted(xi_p[mask])
            return out
        # Fallback: compute for all xi
        return _weighted(xi_p)

    def rbf_interp(coords, values, xi):
        values = np.asarray(values).ravel()
        xi = np.asarray(xi)
        mask = _mask_for(xi)
        out = np.full(xi.shape[0], np.nan)
        try:
            obj = RBFInterpolator(_points(coords), values, kernel=method, **rbf_kwargs)
        except Exception:
            # Best effort: an RBF system that cannot be built yields zeros.
            return np.zeros(xi.shape[0])
        xi_p = _points(xi)
        if np.any(mask):
            out[mask] = obj(xi_p[mask])
            return np.where(np.isfinite(out), out, np.nan)
        # Fallback: evaluate on all xi
        vals_all = obj(xi_p)
        return np.where(np.isfinite(vals_all), vals_all, np.nan)

    def griddata_interp(coords, values, xi):
        values = np.asarray(values)
        xi = np.asarray(xi)
        mask = _mask_for(xi)
        out = np.full(xi.shape[0], np.nan)
        if np.any(mask):
            out[mask] = griddata(
                coords, values.ravel(), xi[mask], method=method, **griddata_kwargs
            )
            # Outside the mask stays NaN so the caller can mask later
            return np.where(np.isfinite(out), out, np.nan)
        # Fallback: try all xi (useful when the domain mask selects nothing)
        try:
            vals_all = griddata(coords, values.ravel(), xi, method=method, **griddata_kwargs)
            return np.where(np.isfinite(vals_all), vals_all, np.nan)
        except Exception:
            return np.zeros(xi.shape[0])

    def linearnd_interp(coords, values, xi):
        coords = np.asarray(coords)
        values = np.asarray(values).ravel()
        xi = np.asarray(xi)
        mask = _mask_for(xi)
        out = np.full(xi.shape[0], np.nan)
        obj = LinearNDInterpolator(coords, values)
        if np.any(mask):
            out[mask] = obj(xi[mask])
            return np.where(np.isfinite(out), out, np.nan)
        # Fallback: evaluate on all xi (non-finite results become 0 here)
        vals_all = obj(xi)
        return np.where(np.isfinite(vals_all), vals_all, np.zeros_like(vals_all))

    backend = {
        "rbf": rbf_interp,
        "linearnd": linearnd_interp,
        "idw": idw_interp,
    }.get(interpolator, griddata_interp)

    # ===============================================================
    # Main interpolation kernel
    # ===============================================================
    slice_type = self._slice_type or self._detect_slice_type(self.slice)

    def _reflect(coords, data, field_name, comp):
        """Append the mirror image across z = 0, when that adds information."""
        coords_arr = np.asarray(coords)
        if coords_arr.ndim != 2:
            # 1D radial coordinate: r is invariant under z -> -z
            return coords, data
        if coords_arr.shape[1] == 3:
            z_col = 2
        elif coords_arr.shape[1] == 2 and slice_type != "theta":
            z_col = 1  # (x, z) layout of XZ cuts
        else:
            # XY cuts lie in z = 0: flipping y would mirror the disk
            # across the x-axis, which is not a z-reflection.
            return coords, data
        coords_ref = coords_arr.copy()
        coords_ref[:, z_col] *= -1
        data_flat = np.asarray(data).ravel()
        # vz is the only component normal to the reflection plane
        if field_name == 'gasv_mesh' and comp == 2:
            data_ref = -data_flat
        else:
            data_ref = data_flat.copy()
        return np.vstack([coords_arr, coords_ref]), np.concatenate([data_flat, data_ref])

    def interp(idx, field_name, comp=None):
        row = df_sorted.iloc[idx]

        cx = np.array(row["var1_mesh"])
        cy = np.array(row["var2_mesh"])
        cz = np.array(row["var3_mesh"])

        # Build data-point and target coordinate arrays
        if self.dim == 3:
            coords = np.column_stack((cx.ravel(), cy.ravel(), cz.ravel()))
            xi = np.column_stack((var1.ravel(), var2.ravel(), var3.ravel()))
        elif self.dim == 2:
            if slice_type == "theta":
                coords = np.column_stack((cx.ravel(), cy.ravel()))
                xi = np.column_stack((var1.ravel(), var2.ravel()))
            else:
                coords = np.column_stack((cx.ravel(), cz.ravel()))
                xi = np.column_stack((var1.ravel(), var3.ravel()))
        elif self.dim == 1:
            coords = np.sqrt(cx**2 + cy**2 + cz**2)
            xi = np.asarray(var1)

        data = row[field_name]

        # Universal vector component selector: (3, ...) or (..., 3) layouts
        if isinstance(data, np.ndarray) and comp is not None:
            if data.ndim == 2 and data.shape[0] == 3:
                data = data[comp]
            elif data.ndim == 2 and data.shape[1] == 3:
                data = data[:, comp]
            elif data.ndim == 3 and data.shape[0] == 3:
                data = data[comp].ravel()
            elif data.ndim == 4 and data.shape[0] == 3:
                data = data[comp].ravel()
            elif data.ndim == 4 and data.shape[-1] == 3:
                data = data[..., comp].ravel()
            else:
                raise ValueError(f"Cannot extract vector component from {data.shape}")

        if reflect:
            coords, data = _reflect(coords, data, field_name, comp)

        return backend(coords, data, xi)

    # ===============================================================
    # TIME INTERPOLATION
    # ===============================================================
    if nsnaps == 1:
        def _at(comp=None):
            return interp(0, field, comp)
    else:
        # Bracketing pair; times outside the range clamp to the end pairs
        i0 = int(np.clip(np.searchsorted(times, time) - 1, 0, nsnaps - 2))
        i1 = i0 + 1
        t0, t1 = times[i0], times[i1]
        fac = (time - t0) / (t1 - t0) if abs(t1 - t0) > 1e-12 else 0
        fac = np.clip(fac, 0, 1)

        def _at(comp=None):
            v0 = interp(i0, field, comp)
            v1 = interp(i1, field, comp)
            return (1 - fac) * v0 + fac * v1

    def _shaped(v):
        return v.item() if is_scalar else v.reshape(result_shape)

    if is_vector:
        return _smooth(np.array([_shaped(_at(c)) for c in range(3)]))
    return _smooth(_shaped(_at()))
+
1553
+
1554
def plot(self, t=0, contour_levels=10, component='vz', smoothing_sigma=None, field=None):
    """
    Plot one snapshot of a loaded field, choosing the view from the mesh rank.

    A 3D mesh is drawn as a 3D scatter plot. A 2D mesh is drawn with
    ``pcolormesh``: in the XZ plane when the slice fixes phi, otherwise in
    the XY plane. Scalar fields are shown as ``log10`` (non-positive values
    become NaN); the vector field shows the selected component. Meshes of
    any other rank are not plotted.

    Parameters
    ----------
    t : int, optional
        Key used to pick the row from each DataFrame column (``self.df[col][t]``).
    contour_levels : int, optional
        Currently unused; kept for backward compatibility.
    component : {'vx', 'vy', 'vz'}, optional
        Component to plot for vector fields.
    smoothing_sigma : float or None, optional
        If given, 2D data are Gaussian-smoothed before plotting. The
        unsmoothed data are kept if filtering fails.
    field : str or None, optional
        Which field to plot when multiple fields were loaded (e.g. 'gasdens',
        'gasv', 'gasenergy'). If None and exactly one field is present, that
        field is plotted.

    Raises
    ------
    ValueError
        If no data is loaded, the requested field is missing, the field is
        ambiguous or absent, or ``component`` is invalid for a vector field.
    """
    if self.df is None:
        raise ValueError("No data loaded. Run load_field() first.")

    df_names = self.df.columns.tolist()

    # Map short names to dataframe column names
    field_map = {
        'gasdens': 'gasdens_mesh',
        'gasv': 'gasv_mesh',
        'gasenergy': 'gasenergy_mesh',
    }
    candidates = [c for c in df_names if c in field_map.values()]

    # Resolve the field column (short names or full column names)
    if field is not None:
        field_col = field_map.get(field, field)
        if field_col not in df_names:
            raise ValueError(f"Requested field '{field}' not present. Available: {candidates}")
    elif len(candidates) == 1:
        field_col = candidates[0]
    elif not candidates:
        raise ValueError(f"No field to plot. Available columns: {df_names}")
    else:
        raise ValueError(
            f"Multiple fields present {candidates}. "
            "Specify which to plot using field='gasdens', 'gasv' or 'gasenergy'."
        )

    # Mesh grids and field data for the requested snapshot
    var1 = self.df['var1_mesh'][t]
    var2 = self.df['var2_mesh'][t]
    var3 = self.df['var3_mesh'][t]
    raw_field = self.df[field_col][t]

    is_vector = (field_col == 'gasv_mesh')
    if is_vector:
        data_arr = np.asarray(raw_field)
        comp = {'vx': 0, 'vy': 1, 'vz': 2}.get(component)
        has_component_axis = data_arr.ndim >= 1 and 3 in (data_arr.shape[0], data_arr.shape[-1])
        if has_component_axis and comp is None:
            raise ValueError(f"component must be 'vx', 'vy' or 'vz', got {component!r}.")
        if data_arr.ndim >= 1 and data_arr.shape[0] == 3:
            field_data = data_arr[comp]          # layout (3, ...)
        elif data_arr.ndim >= 1 and data_arr.shape[-1] == 3:
            field_data = data_arr[..., comp]     # layout (..., 3)
        else:
            # Fallback: interpret as magnitude over the leading axis
            try:
                field_data = np.sqrt(np.sum(data_arr**2, axis=0))
            except Exception:
                field_data = data_arr
    else:
        # Scalar field: safe log10 (non-positive values -> NaN)
        field_data = np.asarray(raw_field)
        with np.errstate(divide='ignore'):
            field_data = np.log10(np.where(field_data <= 0, np.nan, field_data))

    def _label_colorbar(cbar):
        cbar.ax.tick_params(labelsize=12)
        if is_vector:
            cbar.set_label(rf"{component} [units]", size=14)
        else:
            cbar.set_label(rf"$\log_{{10}}(field)$", size=14)

    sliced_shape = var1.shape  # var1, var2, var3 share the sliced shape

    # A scalar-fixed angle (e.g. 'theta=1.56', not 'theta=[...]') picks the plane
    slice_str = self.slice if hasattr(self, 'slice') and self.slice is not None else ""
    fixed_theta = re.search(r'theta\s*=\s*([^\[\],]+)', slice_str)
    fixed_phi = re.search(r'phi\s*=\s*([^\[\],]+)', slice_str)
    plane = 'XZ' if (fixed_phi and not fixed_theta) else 'XY'

    if len(sliced_shape) == 3:
        fig = plt.figure(figsize=(8, 6))
        ax = fig.add_subplot(111, projection='3d')
        scatter = ax.scatter(var1.ravel(), var2.ravel(), var3.ravel(),
                             c=field_data.ravel(), cmap='Spectral_r', s=5)
        _label_colorbar(fig.colorbar(scatter, ax=ax))
        ax.set_xlabel("X", size=14)
        ax.set_ylabel("Y", size=14)
        ax.set_zlabel("Z", size=14)
        fp.Plot.fargopy_mark(ax)
        plt.show()

    elif len(sliced_shape) == 2:
        # Optional smoothing to remove interpolation artefacts (triangular edges)
        if smoothing_sigma is not None:
            try:
                field_data = gaussian_filter(field_data, sigma=smoothing_sigma)
            except Exception:
                pass  # deliberate best effort: keep the unsmoothed data

        vertical, vertical_label = (var3, "Z") if plane == 'XZ' else (var2, "Y")
        fig, ax = plt.subplots(figsize=(8, 6))
        mesh = ax.pcolormesh(var1, vertical, field_data, shading='auto', cmap='Spectral_r')
        _label_colorbar(fig.colorbar(mesh))
        ax.set_xlabel("X", size=14)
        ax.set_ylabel(vertical_label, size=14)
        ax.tick_params(axis='both', which='major', labelsize=12)
        fp.Plot.fargopy_mark(ax)
        plt.show()
+
1736
def cut_sector(self, xc, yc, zc, rc, hc, dataframe=None):
    """
    Keep only the data inside a vertical cylinder.

    The cylinder is centred at (xc, yc, zc), has radius ``rc`` in the x-y
    plane and spans ``zc - hc/2 <= z <= zc + hc/2``. Each row of the
    DataFrame (snapshot) is filtered independently.

    Parameters
    ----------
    xc, yc, zc : float
        Center coordinates of the cylinder.
    rc : float
        Cylinder radius.
    hc : float
        Full cylinder height.
    dataframe : pd.DataFrame, optional
        DataFrame to filter (default: self.df). Mesh columns must be named
        'var1_mesh', 'var2_mesh', 'var3_mesh'.

    Returns
    -------
    pd.DataFrame
        One row per input row with 'snapshot', 'time', the filtered mesh
        coordinates and every other '*_mesh' column filtered by the same
        mask. Vector columns laid out as (3, ...) keep their component axis
        and become (3, n_inside).

    Raises
    ------
    ValueError
        If no DataFrame is passed and ``self.df`` is None.
    """
    if dataframe is None:
        if self.df is None:
            raise ValueError("No DataFrame loaded. Run load_data() first or pass a DataFrame.")
        dataframe = self.df

    coord_cols = ('var1_mesh', 'var2_mesh', 'var3_mesh')
    field_cols = [
        col for col in dataframe.columns
        if isinstance(col, str) and col.endswith('_mesh') and col not in coord_cols
    ]
    z_min = zc - hc / 2
    z_max = zc + hc / 2

    def _select(data, mask):
        # Apply the spatial mask, skipping a leading component axis if present.
        if data.shape == mask.shape:
            return data[mask]
        if data.ndim == mask.ndim + 1 and data.shape[1:] == mask.shape:
            return data[:, mask]  # vector layout (3, ...)
        return data[mask]  # e.g. layout (..., 3) -> (n_inside, 3)

    records = []
    for _, row in dataframe.iterrows():
        x = np.array(row['var1_mesh'])
        y = np.array(row['var2_mesh'])
        z = np.array(row['var3_mesh'])
        # Points inside the cylinder
        r_xy = np.sqrt((x - xc)**2 + (y - yc)**2)
        mask = (r_xy <= rc) & (z >= z_min) & (z <= z_max)

        filtered = {
            'snapshot': row['snapshot'],
            'time': row['time'],
            'var1_mesh': x[mask],
            'var2_mesh': y[mask],
            'var3_mesh': z[mask],
        }
        for col in field_cols:
            filtered[col] = _select(np.array(row[col]), mask)
        records.append(filtered)

    return pd.DataFrame(records)