fargopy 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fargopy/__init__.py CHANGED
@@ -351,7 +351,8 @@ from fargopy.sys import *
351
351
  from fargopy.fields import *
352
352
  from fargopy.simulation import *
353
353
  from fargopy.plot import *
354
- #from fargopy.Fsimulation
354
+ #from fargopy.fsimulation import *
355
+ from fargopy.flux import *
355
356
 
356
357
  # Showing version
357
358
  print(f"Running FARGOpy version {version}")
fargopy/fields.py CHANGED
@@ -8,6 +8,29 @@ import fargopy
8
8
  ###############################################################
9
9
  import numpy as np
10
10
  import re
11
+ import pandas as pd
12
+
13
+ import matplotlib.pyplot as plt
14
+ import plotly.figure_factory as ff
15
+ from plotly.subplots import make_subplots
16
+ import plotly.graph_objects as go
17
+ from matplotlib.animation import FFMpegWriter
18
+ from scipy.interpolate import RBFInterpolator
19
+ from scipy.interpolate import interp1d
20
+ from scipy.interpolate import LinearNDInterpolator
21
+ from scipy.spatial import cKDTree
22
+
23
+
24
+ from joblib import Parallel, delayed
25
+
26
+
27
+ from ipywidgets import interact, FloatSlider, IntSlider
28
+ from celluloid import Camera
29
+ from IPython.display import HTML, Video
30
+
31
+ from scipy.interpolate import griddata
32
+ from scipy.integrate import solve_ivp
33
+ from tqdm import tqdm
11
34
 
12
35
  ###############################################################
13
36
  # Constants
@@ -242,3 +265,541 @@ class Field(fargopy.Fargobj):
242
265
  def __repr__(self):
243
266
  return str(self.data)
244
267
 
268
+
269
class FieldInterpolator:
    """Load simulation field snapshots into a DataFrame and interpolate them in space and time.

    ``sim`` must provide ``domains`` (with ``r``, ``theta`` and ``phi`` coordinate
    arrays) and ``load_field(name, snapshot=..., type=...)``. The objects returned by
    ``load_field`` are used through ``.data``, ``.meshslice(slice=...)`` and, for the
    vector field ``gasv``, ``.to_cartesian()``.

    Typical use::

        fi = FieldInterpolator(sim)
        fi.load_data(field="gasdens", slice="phi=0", snapshots=[10, 20])
        rho = fi.evaluate(time=0.5, var1=x, var2=z)
    """

    # Fields understood by load_data(); each fills a "<name>_mesh" DataFrame column.
    _SUPPORTED_FIELDS = ("gasdens", "gasv", "gasenergy")
    # Order in which evaluate() picks the column to interpolate when several are loaded.
    _FIELD_PRIORITY = ("gasdens_mesh", "gasenergy_mesh", "gasv_mesh")
    _INTERPOLATORS = ("rbf", "griddata", "linearnd", "idw")
    # RBF kernels that need an explicit shape parameter (epsilon).
    _KERNELS_REQUIRING_EPSILON = ("gaussian", "multiquadric", "inverse_multiquadric", "inverse_quadratic")

    # Slice-string grammar used by create_mesh().
    _RANGE_PATTERN = re.compile(r"(\w+)=\[(.+?)\]")          # e.g. r=[0.8,1.2]
    _VALUE_PATTERN = re.compile(r"(\w+)=([-\d.]+)(\s*deg)?")  # e.g. phi=0, z=0, theta=90 deg
    _DEGREE_PATTERN = re.compile(r"([-+\d.eE]+)\s*deg")       # e.g. -25 deg

    def __init__(self, sim):
        """Wrap ``sim``; nothing is loaded until load_data() is called."""
        self.sim = sim
        self.snapshot_time_table = None
        self.df = None
        self.field = None
        self.slice = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _spherical_grid(r, theta, phi):
        """Return Cartesian (x, y, z) of a (theta, r, phi) ``indexing='ij'`` meshgrid of 1-D axes."""
        theta_grid, r_grid, phi_grid = np.meshgrid(theta, r, phi, indexing="ij")
        x = r_grid * np.sin(theta_grid) * np.cos(phi_grid)
        y = r_grid * np.sin(theta_grid) * np.sin(phi_grid)
        z = r_grid * np.cos(theta_grid)
        return x, y, z

    @staticmethod
    def _slice_type(slice_spec):
        """Return "theta" or "phi": the angle the (lower-cased) slice string pins to one value.

        An angle assigned a single value (``phi=0``) wins over one assigned a range
        (``theta=[0 deg,90 deg]``). If neither angle is pinned, fall back to whichever
        angle is mentioned, checking theta first.
        """
        for angle in ("theta", "phi"):
            # "<angle>=" followed by something other than "[" means a single value.
            if re.search(rf"\b{angle}\s*=\s*(?=[^\s\[])", slice_spec):
                return angle
        if "theta" in slice_spec:
            return "theta"
        if "phi" in slice_spec:
            return "phi"
        raise ValueError("The 'slice' parameter must contain 'theta' or 'phi'.")

    @classmethod
    def _parse_number(cls, text):
        """Parse a numeric token; a value suffixed with 'deg' is converted to radians."""
        text = text.strip()
        match = cls._DEGREE_PATTERN.fullmatch(text)
        if match:
            return float(match.group(1)) * np.pi / 180
        return float(text)

    def _load_slice(self, name, snap, slice_spec, slice_type):
        """Load the 2D slice of field ``name`` at ``snap`` together with its plane coordinates."""
        if name == "gasv":
            vx, vy, vz = self.sim.load_field("gasv", snapshot=snap, type="vector").to_cartesian()
            # phi slices use the x-z plane (vx, vz); theta slices the x-y plane (vx, vy).
            second = vz if slice_type == "phi" else vy
            comp1, mesh = vx.meshslice(slice=slice_spec)
            comp2, _ = second.meshslice(slice=slice_spec)
            values = np.array([comp1, comp2])
        else:
            values, mesh = self.sim.load_field(name, snapshot=snap, type="scalar").meshslice(slice=slice_spec)
        vertical = mesh.z if slice_type == "phi" else mesh.y
        return {"var1_mesh": mesh.x, "var2_mesh": vertical, f"{name}_mesh": values}

    def _load_volume(self, name, snap):
        """Load the full 3D data of field ``name`` at ``snap`` (gasv as stacked vx, vy, vz)."""
        if name == "gasv":
            vx, vy, vz = self.sim.load_field("gasv", snapshot=snap, type="vector").to_cartesian()
            return np.array([vx.data, vy.data, vz.data])
        return self.sim.load_field(name, snapshot=snap, type="scalar").data

    def _normalize_time(self, time):
        """Map ``time`` (normalized float or snapshot index) onto the normalized-time axis.

        A float in [0, 1] is returned unchanged. An int (or numpy integer) is looked
        up as a snapshot. A float > 1 is a fractional snapshot, interpolated between
        its two neighbours.
        """
        table = getattr(self, "snapshot_time_table", None)
        is_integer = isinstance(time, (int, np.integer))
        if table is None:
            if is_integer:
                raise ValueError("snapshot_time_table not found. Did you call load_data()?")
            return time

        snaps = table["Snapshot"].values
        min_snap, max_snap = snaps.min(), snaps.max()
        if isinstance(time, float) and 0 <= time <= 1:
            return time
        if not (is_integer or (isinstance(time, float) and time > 1)):
            raise ValueError(
                f"Invalid time value: {time}. Must be a normalized time in [0,1] or a snapshot index in [{min_snap},{max_snap}]."
            )
        if time < min_snap or time > max_snap:
            raise ValueError(
                f"Selected snapshot (time={time}) is outside the loaded range [{min_snap}, {max_snap}]."
            )

        def normalized_at(snap):
            hits = table.loc[table["Snapshot"] == snap, "Normalized_time"]
            return None if hits.empty else float(hits.iloc[0])

        if is_integer or np.isclose(time, np.round(time)):
            snap = int(round(time))
            value = normalized_at(snap)
            if value is None:
                raise ValueError(f"Snapshot {snap} not found in snapshot_time_table.")
            return value

        snap0, snap1 = int(np.floor(time)), int(np.ceil(time))
        if snap0 < min_snap or snap1 > max_snap:
            raise ValueError(
                f"Selected snapshot (time={time}) requires neighbors [{snap0}, {snap1}] outside the loaded range [{min_snap}, {max_snap}]."
            )
        t0, t1 = normalized_at(snap0), normalized_at(snap1)
        if t0 is None or t1 is None:
            raise ValueError(f"Snapshots {snap0} or {snap1} not found in snapshot_time_table.")
        factor = (time - snap0) / (snap1 - snap0)
        return (1 - factor) * t0 + factor * t1

    # --------------------------------------------------------------- public API

    def load_data(self, field=None, slice=None, snapshots=None):
        """Load field data for one snapshot or an inclusive range of snapshots.

        Parameters:
            field (str or list of str): Field(s) to load: "gasdens", "gasv" and/or "gasenergy".
            slice (str, optional): Slice definition passed to ``meshslice``, e.g. "phi=0",
                "theta=45 deg" or "r=[0.8,1.2],phi=0,theta=[0 deg,90 deg]". The angle given
                a single value selects the plane: fixed phi -> (x, z), fixed theta -> (x, y).
                When omitted or empty, the full 3D arrays are loaded.
            snapshots (int, list or tuple): A single snapshot index, a one-element sequence,
                or an inclusive ``[first, last]`` range.

        Returns:
            pd.DataFrame: One row per snapshot with columns ``snapshot``, ``time`` (normalized
            to [0, 1] across the loaded range), ``var1_mesh``, ``var2_mesh`` (plus
            ``var3_mesh`` in 3D) and one ``<field>_mesh`` column per field. The frame is also
            stored in ``self.df``.

        Raises:
            ValueError: On a missing or unsupported field, malformed ``snapshots``, a slice
                mentioning neither theta nor phi, or missing simulation domains.
        """
        # ---- validate every input before touching any state -------------------
        if field is None:
            raise ValueError("You must specify at least one field to load using the 'field' parameter.")
        fields = [field] if isinstance(field, str) else list(dict.fromkeys(field))
        if not fields:
            raise ValueError("You must specify at least one field to load using the 'field' parameter.")
        unknown = [name for name in fields if name not in self._SUPPORTED_FIELDS]
        if unknown:
            raise ValueError(f"Unsupported field(s) {unknown}. Choose from {list(self._SUPPORTED_FIELDS)}.")

        if isinstance(snapshots, (int, np.integer)):
            snapshots = [int(snapshots)]
        elif isinstance(snapshots, np.ndarray):
            snapshots = snapshots.tolist()
        if not isinstance(snapshots, (list, tuple)):
            raise ValueError("'snapshots' must be an integer, a list, or a tuple.")
        if len(snapshots) == 1:
            snaps = list(snapshots)
            time_values = [0.0]  # a lone snapshot sits at normalized time 0
        elif len(snapshots) == 2:
            first, last = snapshots
            if first > last:
                raise ValueError("The range in 'snapshots' is invalid. The first value must be less than or equal to the second.")
            snaps = np.arange(first, last + 1)
            time_values = np.linspace(0, 1, len(snaps))
        else:
            raise ValueError("'snapshots' must be a single index or a [first, last] range.")

        if not hasattr(self.sim, "domains") or self.sim.domains is None:
            raise ValueError("Simulation domains are not loaded. Ensure the simulation data is properly initialized.")

        self.field = field
        self.slice = slice
        slice_type = None
        if slice:
            slice = slice.lower()  # normalize for matching
            slice_type = self._slice_type(slice)

        self.snapshot_time_table = pd.DataFrame({"Snapshot": snaps, "Normalized_time": time_values})

        # ---- announce what is being loaded and prepare columns ----------------
        if slice:
            messages = {
                "gasdens": f'Loading 2D density data for slice: {slice}.',
                "gasv": f'Loading 2D gas velocity data for slice: {slice}.',
                "gasenergy": f'Loading 2D gas energy data for slice {slice}',
            }
            columns = ["snapshot", "time", "var1_mesh", "var2_mesh"]
            volume = None
        else:
            messages = {
                "gasdens": 'Loading 3D density data ',
                "gasv": 'Loading 3D gas velocity data',
                "gasenergy": 'Loading 3D gas energy data',
            }
            columns = ["snapshot", "time", "var1_mesh", "var2_mesh", "var3_mesh"]
            domains = self.sim.domains
            # The same Cartesian mesh is shared by every 3D snapshot row.
            volume = self._spherical_grid(domains.r, domains.theta, domains.phi)
        for name in fields:
            print(messages[name])
            columns.append(f"{name}_mesh")

        # ---- one row per snapshot; build the frame once at the end ------------
        rows = []
        for snap, t in zip(snaps, time_values):
            row = {"snapshot": snap, "time": t}
            for name in fields:
                if slice:
                    row.update(self._load_slice(name, snap, slice, slice_type))
                else:
                    row["var1_mesh"], row["var2_mesh"], row["var3_mesh"] = volume
                    row[f"{name}_mesh"] = self._load_volume(name, snap)
            rows.append(row)

        self.df = pd.DataFrame(rows, columns=columns)
        return self.df

    def times(self):
        """Return the snapshot -> normalized-time table built by load_data().

        Raises:
            ValueError: If load_data() has not been called yet.
        """
        if self.snapshot_time_table is None:
            raise ValueError("No data loaded. Run load_data() first.")
        return self.snapshot_time_table

    def create_mesh(self, slice=None, nr=50, ntheta=50, nphi=50):
        """Create a Cartesian mesh from a slice definition, or over the whole domain.

        Parameters:
            slice (str, optional): e.g. "r=[0.8,1.2],phi=0,theta=[0 deg,90 deg]" or
                "z=0,r=[0.8,1.2],phi=[-10 deg,10 deg]". Ranges use ``key=[a,b]``; single
                values use ``key=v``. Numbers suffixed with ``deg`` are converted to
                radians. Unspecified ranges default to the simulation domain.
            nr, ntheta, nphi (int): Number of samples along r, theta and phi.

        Returns:
            tuple: ``(x, y, z)`` arrays.
            - No slice, or both angles given as ranges: 3D arrays of shape
              (ntheta, nr, nphi).
            - ``z=`` given: a constant-z plane of shape (nr, nphi). Here r is used as the
              in-plane (cylindrical) radius.
            - Fixed phi: shape (ntheta, nr).
            - Fixed theta: shape (nphi, nr).
        """
        domains = self.sim.domains
        r_range = [domains.r.min(), domains.r.max()]
        theta_range = [domains.theta.min(), domains.theta.max()]
        phi_range = [domains.phi.min(), domains.phi.max()]

        if not slice:
            return self._spherical_grid(
                np.linspace(r_range[0], r_range[1], nr),
                np.linspace(theta_range[0], theta_range[1], ntheta),
                np.linspace(phi_range[0], phi_range[1], nphi),
            )

        # Ranges such as r=[0.8,1.2] or phi=[-10 deg,10 deg].
        for key, raw in self._RANGE_PATTERN.findall(slice):
            bounds = [self._parse_number(token) for token in raw.split(",")]
            if key == "r":
                r_range = bounds
            elif key == "phi":
                phi_range = bounds
            elif key == "theta":
                theta_range = bounds

        # Single values such as phi=0, z=0 or theta=90 deg.
        z_value = None
        for key, number, degrees in self._VALUE_PATTERN.findall(slice):
            value = float(number) * np.pi / 180 if degrees else float(number)
            if key == "z":
                z_value = value
            elif key == "phi":
                phi_range = [value, value]
            elif key == "theta":
                theta_range = [value, value]

        r_axis = np.linspace(r_range[0], r_range[1], nr)

        # An explicit z cut takes precedence over the angular ranges.
        if z_value is not None:
            phi_axis = np.linspace(phi_range[0], phi_range[1], nphi)
            r_grid, phi_grid = np.meshgrid(r_axis, phi_axis, indexing="ij")
            x = r_grid * np.cos(phi_grid)
            y = r_grid * np.sin(phi_grid)
            return x, y, np.full_like(x, z_value)

        phi_fixed = phi_range[0] == phi_range[1]
        theta_fixed = theta_range[0] == theta_range[1]

        if not phi_fixed and not theta_fixed:
            return self._spherical_grid(
                r_axis,
                np.linspace(theta_range[0], theta_range[1], ntheta),
                np.linspace(phi_range[0], phi_range[1], nphi),
            )

        if phi_fixed:  # meridional cut at constant phi
            phi = phi_range[0]
            theta_axis = np.linspace(theta_range[0], theta_range[1], ntheta)
            theta_grid, r_grid = np.meshgrid(theta_axis, r_axis, indexing="ij")
            x = r_grid * np.sin(theta_grid) * np.cos(phi)
            y = r_grid * np.sin(theta_grid) * np.sin(phi)
            z = r_grid * np.cos(theta_grid)
            return x, y, z

        # Remaining case: constant theta.
        theta = theta_range[0]
        phi_axis = np.linspace(phi_range[0], phi_range[1], nphi)
        phi_grid, r_grid = np.meshgrid(phi_axis, r_axis, indexing="ij")
        x = r_grid * np.sin(theta) * np.cos(phi_grid)
        y = r_grid * np.sin(theta) * np.sin(phi_grid)
        z = r_grid * np.cos(theta)
        return x, y, z

    def evaluate(
        self, time, var1, var2=None, var3=None,
        interpolator="griddata", method="linear",
        rbf_kwargs=None, griddata_kwargs=None, idw_kwargs=None
    ):
        """Interpolate the loaded field at arbitrary points and time.

        Parameters:
            time (float or int): Normalized time in [0, 1] (float), or a snapshot index
                within the loaded range (int, or a float > 1 for a fractional snapshot).
            var1, var2, var3 (scalar or array-like): Query coordinates.
                - Pass only var1 for 1D interpolation.
                - Pass var1 and var2 for 2D, matching the plane stored by a sliced load.
                - Pass all three (x, y, z) for 3D.
                The arrays are broadcast against each other.
            interpolator (str): "griddata" (default), "rbf", "linearnd" or "idw".
            method (str): Passed as ``method`` to griddata and as ``kernel`` to
                RBFInterpolator; unused by "linearnd" and "idw".
            rbf_kwargs (dict, optional): Extra keyword arguments for RBFInterpolator.
            griddata_kwargs (dict, optional): Extra keyword arguments for griddata.
            idw_kwargs (dict, optional): IDW options, e.g. ``{"power": 2, "k": 8}``.

        Returns:
            A float for scalar queries, otherwise an ndarray with the broadcast query
            shape. For "gasv" there is a leading axis with one entry per component.

            Handling of time depends on how many snapshots were loaded:
            - One snapshot: time is ignored.
            - Two snapshots: the result is blended linearly, with the weight clamped
              to [0, 1].
            - More than two: piecewise-linear interpolation is used, giving NaN
              outside the range.

        Raises:
            ValueError: On an invalid time or interpolator, nothing loaded, no
                interpolatable field, inconsistent coordinates, or an RBF kernel that
                needs ``epsilon``.
        """
        time = self._normalize_time(time)

        if interpolator not in self._INTERPOLATORS:
            raise ValueError("Invalid method. Choose either 'rbf', 'griddata', 'idw', or 'linearnd'.")

        df = getattr(self, "df", None)
        if df is None:
            raise ValueError("No data loaded. Run load_data() first.")
        field_name = next((col for col in self._FIELD_PRIORITY if col in df.columns), None)
        if field_name is None:
            raise ValueError("No valid field found in the DataFrame for interpolation.")

        df_sorted = df.sort_values("time")
        times = df_sorted["time"].to_numpy(dtype=float)
        n_snaps = len(times)

        # ---- query points (loop-invariant: built once) -------------------------
        if var3 is not None:
            if var2 is None:
                raise ValueError("3D interpolation requires var2 as well as var3.")
            axes = (var1, var2, var3)
        elif var2 is not None:
            axes = (var1, var2)
        else:
            axes = (var1,)
        ndim = len(axes)
        is_scalar = all(np.isscalar(a) for a in axes)
        grids = np.broadcast_arrays(*[np.asarray(a, dtype=float) for a in axes])
        result_shape = grids[0].shape
        xi = np.column_stack([g.ravel() for g in grids])
        coord_columns = ("var1_mesh", "var2_mesh", "var3_mesh")[:ndim]

        rbf_kwargs = {} if rbf_kwargs is None else rbf_kwargs
        griddata_kwargs = {} if griddata_kwargs is None else griddata_kwargs
        idw_kwargs = {} if idw_kwargs is None else idw_kwargs
        if interpolator == "rbf" and method in self._KERNELS_REQUIRING_EPSILON and "epsilon" not in rbf_kwargs:
            raise ValueError(f"Kernel '{method}' requires 'epsilon' in rbf_kwargs.")

        # ---- spatial backends: (coords (N, D), values (N,), points (M, D)) -> (M,)
        def _idw(coords, values, points):
            power = idw_kwargs.get("power", 2)
            # Never request more neighbours than there are samples.
            k = min(idw_kwargs.get("k", 8), len(coords))
            dists, idxs = cKDTree(coords).query(points, k=k)
            # query() drops the neighbour axis when k == 1; restore it.
            dists = np.reshape(dists, (len(points), -1))
            idxs = np.reshape(idxs, (len(points), -1))
            dists = np.where(dists == 0, 1e-10, dists)  # exact hits dominate instead of dividing by 0
            weights = 1 / dists**power
            weights /= weights.sum(axis=1, keepdims=True)
            return np.sum(values[idxs] * weights, axis=1)

        def _rbf(coords, values, points):
            return RBFInterpolator(coords, values, kernel=method, **rbf_kwargs)(points)

        def _griddata(coords, values, points):
            return griddata(coords, values, points, method=method, **griddata_kwargs)

        def _linearnd(coords, values, points):
            return LinearNDInterpolator(coords, values)(points)

        backend = {"rbf": _rbf, "griddata": _griddata, "linearnd": _linearnd, "idw": _idw}[interpolator]

        def spatial(idx, component=None):
            """Interpolate snapshot ``idx`` (optionally one velocity component) at ``xi``."""
            row = df_sorted.iloc[idx]
            coords = np.column_stack([np.asarray(row[col]).ravel() for col in coord_columns])
            data = np.asarray(row[field_name])
            if field_name == "gasv_mesh" and component is not None:
                data = data[component]
            return backend(coords, data.ravel(), xi)

        def finish(res):
            res = np.asarray(res)
            return res.item() if is_scalar else res.reshape(result_shape)

        # ---- Case 1: a single snapshot, time is irrelevant -----------------------
        if n_snaps == 1:
            if field_name == "gasv_mesh":
                results = Parallel(n_jobs=-1)(delayed(spatial)(0, c) for c in range(ndim))
                return np.array([finish(res) for res in results])
            # A single task: evaluate in-process rather than spawning workers.
            return finish(spatial(0))

        # ---- Case 2: two snapshots, clamped linear blend -------------------------
        if n_snaps == 2:
            t0, t1 = times[0], times[1]
            factor = (time - t0) / (t1 - t0) if abs(t1 - t0) > 1e-10 else 0
            factor = max(0, min(factor, 1))

            def blend(component=None):
                return (1 - factor) * spatial(0, component) + factor * spatial(1, component)

            if field_name == "gasv_mesh":
                results = Parallel(n_jobs=-1)(delayed(blend)(c) for c in range(ndim))
                return np.array([finish(res) for res in results])
            return finish(blend())

        # ---- Case 3: more than two snapshots, piecewise-linear in time -----------
        def through_time(component=None):
            samples = Parallel(n_jobs=-1)(delayed(spatial)(i, component) for i in range(n_snaps))
            stacked = np.stack([np.asarray(s) for s in samples], axis=0)
            curve = interp1d(times, stacked, axis=0, kind="linear", bounds_error=False, fill_value=np.nan)
            return finish(curve(time))

        if field_name == "gasv_mesh":
            return np.array([through_time(c) for c in range(ndim)])
        return through_time()