emerge 0.6.7__py3-none-any.whl → 0.6.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of emerge might be problematic. See the registry's advisory for this release for more details.

Files changed (33):
  1. emerge/__init__.py +2 -2
  2. emerge/_emerge/_cache_check.py +1 -1
  3. emerge/_emerge/elements/femdata.py +3 -2
  4. emerge/_emerge/elements/index_interp.py +1 -2
  5. emerge/_emerge/elements/ned2_interp.py +16 -16
  6. emerge/_emerge/elements/nedelec2.py +17 -6
  7. emerge/_emerge/elements/nedleg2.py +21 -9
  8. emerge/_emerge/geo/__init__.py +1 -1
  9. emerge/_emerge/geo/horn.py +0 -1
  10. emerge/_emerge/geo/modeler.py +1 -1
  11. emerge/_emerge/geo/operations.py +13 -0
  12. emerge/_emerge/geo/pcb_tools/calculator.py +2 -3
  13. emerge/_emerge/geo/pmlbox.py +35 -11
  14. emerge/_emerge/geometry.py +13 -6
  15. emerge/_emerge/material.py +334 -82
  16. emerge/_emerge/mesh3d.py +14 -8
  17. emerge/_emerge/physics/microwave/assembly/assembler.py +43 -20
  18. emerge/_emerge/physics/microwave/microwave_3d.py +57 -44
  19. emerge/_emerge/physics/microwave/microwave_bc.py +26 -24
  20. emerge/_emerge/physics/microwave/microwave_data.py +90 -7
  21. emerge/_emerge/plot/pyvista/display.py +53 -15
  22. emerge/_emerge/plot/pyvista/display_settings.py +4 -1
  23. emerge/_emerge/plot/simple_plots.py +42 -26
  24. emerge/_emerge/projects/_load_base.txt +1 -2
  25. emerge/_emerge/selection.py +4 -0
  26. emerge/_emerge/simmodel.py +21 -9
  27. emerge/_emerge/solver.py +45 -18
  28. emerge/lib.py +256 -250
  29. {emerge-0.6.7.dist-info → emerge-0.6.9.dist-info}/METADATA +2 -1
  30. {emerge-0.6.7.dist-info → emerge-0.6.9.dist-info}/RECORD +33 -33
  31. {emerge-0.6.7.dist-info → emerge-0.6.9.dist-info}/licenses/LICENSE +2 -2
  32. {emerge-0.6.7.dist-info → emerge-0.6.9.dist-info}/WHEEL +0 -0
  33. {emerge-0.6.7.dist-info → emerge-0.6.9.dist-info}/entry_points.txt +0 -0
@@ -284,15 +284,16 @@ class PVDisplay(BaseDisplay):
284
284
  self._ruler.min_length = max(1e-3, min(self._mesh.edge_lengths))
285
285
  self._update_camera()
286
286
  self._add_aux_items()
287
+ # self._plot.renderer.enable_depth_peeling(20, 0.8)
288
+ # self._plot.enable_anti_aliasing(self.set.anti_aliassing)
287
289
  if self._do_animate:
288
290
  self._wire_close_events()
289
291
  self.add_text('Press Q to close!',color='red', position='upper_left')
290
292
  self._plot.show(auto_close=False, interactive_update=True, before_close_callback=self._close_callback)
291
293
  self._animate()
292
-
293
-
294
294
  else:
295
295
  self._plot.show()
296
+
296
297
  self._reset()
297
298
 
298
299
  def set_mesh(self, mesh: Mesh3D):
@@ -315,7 +316,6 @@ class PVDisplay(BaseDisplay):
315
316
  """The private callback function that stops the animation.
316
317
  """
317
318
  self._stop = True
318
- print('CLOSE!')
319
319
 
320
320
  def _animate(self) -> None:
321
321
  """Private function that starts the animation loop.
@@ -441,8 +441,20 @@ class PVDisplay(BaseDisplay):
441
441
  opacity = obj.opacity
442
442
  line_width = 0.5
443
443
  color = obj.color_rgb
444
+ metal = obj._metal
444
445
  style='surface'
445
446
 
447
+ # Default render settings
448
+ metallic = 0.05
449
+ roughness = 0.5
450
+ pbr = False
451
+
452
+ if metal:
453
+ pbr = True
454
+ metallic = 0.8
455
+ roughness = 0.3
456
+
457
+ # Default keyword arguments when plotting Mesh mode.
446
458
  if mesh is True:
447
459
  show_edges = True
448
460
  opacity = 0.7
@@ -450,13 +462,28 @@ class PVDisplay(BaseDisplay):
450
462
  style='wireframe'
451
463
  color=next(C_CYCLE)
452
464
 
453
- kwargs = setdefault(kwargs, color=color, opacity=opacity, line_width=line_width, show_edges=show_edges, pickable=True, style=style)
465
+ # Defining the default keyword arguments for PyVista
466
+ kwargs = setdefault(kwargs, color=color,
467
+ opacity=opacity,
468
+ metallic=metallic,
469
+ pbr=pbr,
470
+ roughness=roughness,
471
+ line_width=line_width,
472
+ show_edges=show_edges,
473
+ pickable=True,
474
+ style=style)
454
475
  mesh_obj = self.mesh(obj)
455
476
 
456
477
  if mesh is True and volume_mesh is True:
457
478
  mesh_obj = mesh_obj.extract_all_edges()
458
-
459
479
  actor = self._plot.add_mesh(mesh_obj, *args, **kwargs)
480
+
481
+ # Push 3D Geometries back to avoid Z-fighting with 2D geometries.
482
+ if obj.dim==3:
483
+ mapper = actor.GetMapper()
484
+ mapper.SetResolveCoincidentTopology(1)
485
+ mapper.SetRelativeCoincidentTopologyPolygonOffsetParameters(1,1)
486
+
460
487
  self._plot.add_mesh(self._volume_edges(_select(obj)), color='#000000', line_width=2, show_edges=True)
461
488
 
462
489
  def add_scatter(self, xs: np.ndarray, ys: np.ndarray, zs: np.ndarray):
@@ -567,6 +594,7 @@ class PVDisplay(BaseDisplay):
567
594
  grid = pv.StructuredGrid(x,y,z)
568
595
  field_flat = field.flatten(order='F')
569
596
 
597
+
570
598
  if scale=='log':
571
599
  T = lambda x: np.log10(np.abs(x))
572
600
  elif scale=='symlog':
@@ -582,25 +610,28 @@ class PVDisplay(BaseDisplay):
582
610
  self._ctr += 1
583
611
  grid[name] = static_field
584
612
 
613
+ grid_no_nan = grid.threshold(scalars=name)
614
+
615
+ # Determine color limits
585
616
  if clim is None:
586
- fmin = np.min(static_field)
587
- fmax = np.max(static_field)
617
+ fmin = np.nanmin(static_field)
618
+ fmax = np.nanmax(static_field)
588
619
  clim = (fmin, fmax)
589
-
590
620
  if symmetrize:
591
- lim = max(abs(clim[0]),abs(clim[1]))
621
+ lim = max(abs(clim[0]), abs(clim[1]))
592
622
  clim = (-lim, lim)
593
623
 
594
624
  kwargs = setdefault(kwargs, cmap=cmap, clim=clim, opacity=opacity, pickable=False, multi_colors=True)
595
- actor = self._plot.add_mesh(grid, scalars=name, **kwargs)
625
+ actor = self._plot.add_mesh(grid_no_nan, scalars=name, **kwargs)
626
+
596
627
 
597
628
  if self._do_animate:
598
629
  def on_update(obj: _AnimObject, phi: complex):
599
- field = obj.T(np.real(obj.field*phi))
600
- obj.grid[name] = field
601
- self._objs.append(_AnimObject(field_flat, T, grid, actor, on_update)) # type: ignore
602
-
603
-
630
+ field_anim = obj.T(np.real(obj.field * phi))
631
+ obj.grid[name] = field_anim
632
+ self._objs.append(_AnimObject(field_flat, T, grid_no_nan, actor, on_update))
633
+
634
+
604
635
  def add_title(self, title: str) -> None:
605
636
  """Adds a title
606
637
 
@@ -652,6 +683,11 @@ class PVDisplay(BaseDisplay):
652
683
  dx = dx.flatten().real
653
684
  dy = dy.flatten().real
654
685
  dz = dz.flatten().real
686
+
687
+ ids = np.invert(np.isnan(dx))
688
+
689
+ x, y, z, dx, dy, dz = x[ids], y[ids], z[ids], dx[ids], dy[ids], dz[ids]
690
+
655
691
  dmin = _min_distance(x,y,z)
656
692
 
657
693
  dmax = np.max(_norm(dx,dy,dz))
@@ -667,8 +703,10 @@ class PVDisplay(BaseDisplay):
667
703
  kwargs = dict()
668
704
  if color is not None:
669
705
  kwargs['color'] = color
706
+
670
707
  pl = self._plot.add_arrows(Coo, Vec, scalars=None, clim=None, cmap=None, **kwargs)
671
708
 
709
+
672
710
  def add_contour(self,
673
711
  X: np.ndarray,
674
712
  Y: np.ndarray,
@@ -1,3 +1,4 @@
1
+ from typing import Literal
1
2
 
2
3
  class PVDisplaySettings:
3
4
 
@@ -22,4 +23,6 @@ class PVDisplaySettings:
22
23
  self.background_bottom: str = "#c0d2e8"
23
24
  self.background_top: str = "#ffffff"
24
25
  self.grid_line_color: str = "#8e8e8e"
25
- self.z_boost: float = 1e-6
26
+ self.z_boost: float = 0#1e-9
27
+ self.depth_peeling: bool = True
28
+ self.anti_aliassing: Literal["msaa","ssaa",'fxaa'] = "msaa"
@@ -406,21 +406,21 @@ and sparse frequency annotations (e.g., labeled by frequency).
406
406
  plt.tight_layout()
407
407
  plt.show()
408
408
 
409
- def plot_sp(f: np.ndarray, S: list[np.ndarray] | np.ndarray,
410
- dblim=[-40, 5],
411
- xunit="GHz",
412
- levelindicator: int | float | None = None,
413
- noise_floor=-150,
414
- fill_areas: list[tuple] | None = None,
415
- spec_area: list[tuple[float,...]] | None = None,
416
- unwrap_phase=False,
417
- logx: bool = False,
418
- labels: list[str] | None = None,
419
- linestyles: list[str] | None = None,
420
- colorcycle: list[int] | None = None,
421
- filename: str | None = None,
422
- show_plot: bool = True,
423
- figdata: tuple | None = None) -> tuple[plt.Figure, plt.Axes, plt.Axes]:
409
+ def plot_sp(f: np.ndarray | list[np.ndarray], S: list[np.ndarray] | np.ndarray,
410
+ dblim=[-40, 5],
411
+ xunit="GHz",
412
+ levelindicator: int | float | None = None,
413
+ noise_floor=-150,
414
+ fill_areas: list[tuple] | None = None,
415
+ spec_area: list[tuple[float,...]] | None = None,
416
+ unwrap_phase=False,
417
+ logx: bool = False,
418
+ labels: list[str] | None = None,
419
+ linestyles: list[str] | None = None,
420
+ colorcycle: list[int] | None = None,
421
+ filename: str | None = None,
422
+ show_plot: bool = True,
423
+ figdata: tuple | None = None) -> tuple[plt.Figure, plt.Axes, plt.Axes]:
424
424
  """Plot S-parameters in dB and phase
425
425
 
426
426
  Args:
@@ -444,7 +444,12 @@ def plot_sp(f: np.ndarray, S: list[np.ndarray] | np.ndarray,
444
444
  Ss = [S]
445
445
  else:
446
446
  Ss = S
447
-
447
+
448
+ if not isinstance(f, list):
449
+ fs = [f for _ in Ss]
450
+ else:
451
+ fs = f
452
+
448
453
  if linestyles is None:
449
454
  linestyles = ['-' for _ in S]
450
455
 
@@ -452,7 +457,8 @@ def plot_sp(f: np.ndarray, S: list[np.ndarray] | np.ndarray,
452
457
  colorcycle = [i for i, S in enumerate(S)]
453
458
 
454
459
  unitdivider: dict[str, float] = {"MHz": 1e6, "GHz": 1e9, "kHz": 1e3}
455
- fnew = f / unitdivider[xunit]
460
+
461
+ fs = [f / unitdivider[xunit] for f in fs]
456
462
 
457
463
  if figdata is None:
458
464
  # Create two subplots: one for magnitude and one for phase
@@ -463,10 +469,10 @@ def plot_sp(f: np.ndarray, S: list[np.ndarray] | np.ndarray,
463
469
  minphase, maxphase = -180, 180
464
470
 
465
471
  maxy = 0
466
- for s, ls, cid in zip(Ss, linestyles, colorcycle):
472
+ for f, s, ls, cid in zip(fs, Ss, linestyles, colorcycle):
467
473
  # Calculate and plot magnitude in dB
468
474
  SdB = 20 * np.log10(np.abs(s) + 10**(noise_floor/20) * np.random.rand(*s.shape) + 10**((noise_floor-30)/20))
469
- ax_mag.plot(fnew, SdB, label="Magnitude (dB)", linestyle=ls, color=EMERGE_COLORS[cid % len(EMERGE_COLORS)])
475
+ ax_mag.plot(f, SdB, label="Magnitude (dB)", linestyle=ls, color=EMERGE_COLORS[cid % len(EMERGE_COLORS)])
470
476
  if np.max(SdB) > maxy:
471
477
  maxy = np.max(SdB)
472
478
  # Calculate and plot phase in degrees
@@ -475,12 +481,12 @@ def plot_sp(f: np.ndarray, S: list[np.ndarray] | np.ndarray,
475
481
  phase = np.unwrap(phase, period=360)
476
482
  minphase = min(np.min(phase), minphase)
477
483
  maxphase = max(np.max(phase), maxphase)
478
- ax_phase.plot(fnew, phase, label="Phase (degrees)", linestyle=ls, color=EMERGE_COLORS[cid % len(EMERGE_COLORS)])
484
+ ax_phase.plot(f, phase, label="Phase (degrees)", linestyle=ls, color=EMERGE_COLORS[cid % len(EMERGE_COLORS)])
479
485
 
480
486
  # Annotate level indicators if specified
481
487
  if isinstance(levelindicator, (int, float)) and levelindicator is not None:
482
488
  lvl = levelindicator
483
- fcross = hintersections(fnew, SdB, lvl)
489
+ fcross = hintersections(f, SdB, lvl)
484
490
  for fs in fcross:
485
491
  ax_mag.annotate(
486
492
  f"{str(fs)[:4]}{xunit}",
@@ -500,16 +506,18 @@ def plot_sp(f: np.ndarray, S: list[np.ndarray] | np.ndarray,
500
506
  f2 = fmax / unitdivider[xunit]
501
507
  ax_mag.fill_between([f1, f2], vmin,vmax, color='red', alpha=0.2)
502
508
  # Configure magnitude plot (ax_mag)
509
+ fmin = min([min(f) for f in fs])
510
+ fmax = max([max(f) for f in fs])
503
511
  ax_mag.set_ylabel("Magnitude (dB)")
504
512
  ax_mag.set_xlabel(f"Frequency ({xunit})")
505
- ax_mag.axis([min(fnew), max(fnew), dblim[0], max(maxy*1.1,dblim[1])]) # type: ignore
513
+ ax_mag.axis([fmin, fmax, dblim[0], max(maxy*1.1,dblim[1])]) # type: ignore
506
514
  ax_mag.axhline(y=0, color="k", linewidth=1)
507
515
  ax_mag.xaxis.set_minor_locator(tck.AutoMinorLocator(2))
508
516
  ax_mag.yaxis.set_minor_locator(tck.AutoMinorLocator(2))
509
517
  # Configure phase plot (ax_phase)
510
518
  ax_phase.set_ylabel("Phase (degrees)")
511
519
  ax_phase.set_xlabel(f"Frequency ({xunit})")
512
- ax_phase.axis([min(fnew), max(fnew), minphase, maxphase]) # type: ignore
520
+ ax_phase.axis([fmin, fmax, minphase, maxphase]) # type: ignore
513
521
  ax_phase.xaxis.set_minor_locator(tck.AutoMinorLocator(2))
514
522
  ax_phase.yaxis.set_minor_locator(tck.AutoMinorLocator(2))
515
523
  if logx:
@@ -533,7 +541,7 @@ def plot_ff(
533
541
  dB: bool = False,
534
542
  labels: Optional[List[str]] = None,
535
543
  xlabel: str = "Theta (rad)",
536
- ylabel: str = "|E|",
544
+ ylabel: str = "",
537
545
  linestyles: Union[str, List[str]] = "-",
538
546
  linewidth: float = 2.0,
539
547
  markers: Optional[Union[str, List[Optional[str]]]] = None,
@@ -612,6 +620,8 @@ def plot_ff(
612
620
  def plot_ff_polar(
613
621
  theta: np.ndarray,
614
622
  E: Union[np.ndarray, Sequence[np.ndarray]],
623
+ dB: bool = False,
624
+ dBfloor: float = -30,
615
625
  labels: Optional[List[str]] = None,
616
626
  linestyles: Union[str, List[str]] = "-",
617
627
  linewidth: float = 2.0,
@@ -649,6 +659,8 @@ def plot_ff_polar(
649
659
  E_list = list(E)
650
660
  n_series = len(E_list)
651
661
 
662
+ if dB:
663
+ E_list = [20*np.log10(np.clip(np.abs(e), a_min=10**(dBfloor/20), a_max = 1e9)) for e in E_list]
652
664
  # Style broadcasting
653
665
  def _broadcast(param, default):
654
666
  if isinstance(param, list):
@@ -665,11 +677,15 @@ def plot_ff_polar(
665
677
  ax.set_theta_zero_location(zero_location) # type: ignore
666
678
  ax.set_theta_direction(-1 if clockwise else 1) # type: ignore
667
679
  ax.set_rlabel_position(rlabel_angle) # type: ignore
680
+ ymin = min([min(E) for E in E_list])
681
+ ymax = max([max(E) for E in E_list])
682
+ yrange = ymax-ymin
668
683
 
684
+ ax.set_ylim(ymin-0.05*yrange, ymax+0.05*yrange)
669
685
  for i, Ei in enumerate(E_list):
670
- mag = np.abs(Ei)
686
+
671
687
  ax.plot(
672
- theta, mag,
688
+ theta, Ei,
673
689
  linestyle=linestyles[i],
674
690
  linewidth=linewidth,
675
691
  marker=markers[i],
@@ -17,8 +17,7 @@ with em.Simulation("myfile", load_file=True) as m:
17
17
  S11 = data.S(1,1)
18
18
  S21 = data.S(2,1)
19
19
  plt.plot_sp(f/1e9, [S11, S21])
20
-
21
- m.set_mesh(m.data.mw.field[0].mesh)
20
+
22
21
  m.display.add_object(m['box'])
23
22
  m.display.add_surf(*m.data.mw.field[0].cutplane(1*mm, z=5*mm).scalar('Ez','real'))
24
23
  m.display.show()
@@ -205,6 +205,10 @@ class Selection:
205
205
  def centers(self) -> list[tuple[float, float, float],]:
206
206
  return [gmsh.model.occ.get_center_of_mass(self.dim, tag) for tag in self.tags]
207
207
 
208
+ @property
209
+ def _metal(self) -> bool:
210
+ return False
211
+
208
212
  @property
209
213
  def opacity(self) -> float:
210
214
  return 0.6
@@ -31,7 +31,7 @@ from typing import Literal, Generator, Any
31
31
  from loguru import logger
32
32
  import numpy as np
33
33
  import gmsh # type: ignore
34
- import joblib # type: ignore
34
+ import cloudpickle
35
35
  import os
36
36
  import inspect
37
37
  from pathlib import Path
@@ -181,7 +181,10 @@ class Simulation:
181
181
 
182
182
  # Restier the Exit GMSH function on proper program abortion
183
183
  register(self._exit_gmsh)
184
-
184
+ else:
185
+ gmsh.finalize()
186
+ gmsh.initialize()
187
+
185
188
  # Create a new GMSH model or load it
186
189
  if not self.load_file:
187
190
  gmsh.model.add(self.modelname)
@@ -245,13 +248,20 @@ class Simulation:
245
248
  vM, vm, vp = [float(x) for x in version.split('.')]
246
249
  cM, cm, cp = [float(x) for x in __version__.split('.')]
247
250
  if vM != cM:
248
- raise VersionError(f"You are running a script designed for version {version} with a possibly incompatible version of EMerge {__version__}")
251
+ raise VersionError(f"You are running a script designed for version {version} with a possibly incompatible version of EMerge {__version__}. \n You can upgrade your version of emerge with: pip --upgrade emerge")
249
252
  if vm != cm:
250
- raise VersionError(f"You are running a script designed for version {version} with a possibly incompatible version of EMerge {__version__}")
253
+ raise VersionError(f"You are running a script designed for version {version} with a possibly incompatible version of EMerge {__version__}. \n You can upgrade your version of emerge with: pip --upgrade emerge")
251
254
  if vp != cp:
252
- logger.warning(f"You are running a script designed for version {version} with a possibly incompatible version of EMerge {__version__}")
255
+ logger.warning("You are running a script designed for a different version of EMerge.")
256
+ logger.warning(f"The script version: {version}")
257
+ logger.warning(f"EMerge version: {__version__}")
258
+ logger.warning("Usually EMerge works without a problem but Errors may occur.")
259
+ logger.warning("You can upgrade your version of emerge with: pip --upgrade emerge")
253
260
  logger.warning("You may suppress this error by removing the call to .check_version().")
254
- input('Press enter to proceed...')
261
+ logger.warning("Press Ctrl+C to abort.")
262
+ ans = input('Press enter to proceed or [Q] to quit:')
263
+ if ans.lower().strip()=='q':
264
+ quit()
255
265
 
256
266
  def save(self) -> None:
257
267
  """Saves the current model in the provided project directory."""
@@ -276,7 +286,8 @@ class Simulation:
276
286
  # Pack and save data
277
287
  dataset = dict(simdata=self.data, mesh=self.mesh)
278
288
  data_path = self.modelpath / 'simdata.emerge'
279
- joblib.dump(dataset, str(data_path))
289
+ with open(str(data_path), "wb") as f_out:
290
+ cloudpickle.dump(dataset, f_out)
280
291
  logger.info(f"Saved simulation data to: {data_path}")
281
292
 
282
293
  def load(self) -> None:
@@ -297,7 +308,8 @@ class Simulation:
297
308
  #self.mesh.update([])
298
309
 
299
310
  # Load data
300
- datapack = joblib.load(str(data_path))
311
+ with open(str(data_path), "rb") as f_in:
312
+ datapack= cloudpickle.load(f_in)
301
313
  self.data = datapack['simdata']
302
314
  self._set_mesh(datapack['mesh'])
303
315
  logger.info(f"Loaded simulation data from: {data_path}")
@@ -458,7 +470,7 @@ class Simulation:
458
470
 
459
471
  logger.info(f'Iterating: {params}')
460
472
  if len(dims_flat)==1:
461
- yield (dims_flat[0][i_iter],)
473
+ yield dims_flat[0][i_iter]
462
474
  else:
463
475
  yield (dim[i_iter] for dim in dims_flat) # type: ignore
464
476
  self.mw.cache_matrices = True
emerge/_emerge/solver.py CHANGED
@@ -289,10 +289,14 @@ class Solver:
289
289
 
290
290
  def __init__(self):
291
291
  self.own_preconditioner: bool = False
292
+ self.initialized: bool = False
292
293
 
293
294
  def __str__(self) -> str:
294
295
  return f'{self.__class__.__name__}'
295
296
 
297
+ def initialize(self) -> None:
298
+ return None
299
+
296
300
  def duplicate(self) -> Solver:
297
301
  return self.__class__()
298
302
 
@@ -324,6 +328,9 @@ class EigSolver:
324
328
  def __init__(self):
325
329
  self.own_preconditioner: bool = False
326
330
 
331
+ def initialize(self) -> None:
332
+ return None
333
+
327
334
  def __str__(self) -> str:
328
335
  return f'{self.__class__.__name__}'
329
336
 
@@ -513,7 +520,19 @@ class SolverUMFPACK(Solver):
513
520
  super().__init__()
514
521
  self.A: np.ndarray = None
515
522
  self.b: np.ndarray = None
516
- self.umfpack: um.UmfpackContext = um.UmfpackContext('zl')
523
+
524
+ self.umfpack: um.UmfpackContext | None = None
525
+
526
+ # SETTINGS
527
+ self._pivoting_threshold: float = 0.001
528
+
529
+ self.fact_symb: bool = False
530
+ self.initalized: bool = False
531
+
532
+ def initialize(self):
533
+ if self.initalized:
534
+ return
535
+ self.umfpack = um.UmfpackContext('zl')
517
536
  self.umfpack.control[um.UMFPACK_PRL] = 0 # ty: ignore
518
537
  self.umfpack.control[um.UMFPACK_IRSTEP] = 2 # ty: ignore
519
538
  self.umfpack.control[um.UMFPACK_STRATEGY] = um.UMFPACK_STRATEGY_SYMMETRIC # ty: ignore
@@ -522,21 +541,16 @@ class SolverUMFPACK(Solver):
522
541
  self.umfpack.control[um.UMFPACK_SYM_PIVOT_TOLERANCE] = 0.001 # ty: ignore
523
542
  self.umfpack.control[um.UMFPACK_BLOCK_SIZE] = 64 # ty: ignore
524
543
  self.umfpack.control[um.UMFPACK_FIXQ] = -1 # ty: ignore
525
-
526
- # SETTINGS
527
- self._pivoting_threshold: float = 0.001
528
-
529
- self.fact_symb: bool = False
530
-
544
+ self.initalized = True
531
545
  def reset(self) -> None:
532
546
  self.fact_symb = False
533
547
 
534
- def set_options(self,
535
- pivoting_threshold: float | None = None) -> None:
536
- if pivoting_threshold is not None:
537
- self.umfpack.control[um.UMFPACK_PIVOT_TOLERANCE] = pivoting_threshold # ty: ignore
538
- self.umfpack.control[um.UMFPACK_SYM_PIVOT_TOLERANCE] = pivoting_threshold # ty: ignore
539
- self._pivoting_threshold = pivoting_threshold
548
+ def set_options(self, pivoting_threshold: float | None = None) -> None:
549
+ self.initialize()
550
+ if pivoting_threshold is not None:
551
+ self.umfpack.control[um.UMFPACK_PIVOT_TOLERANCE] = pivoting_threshold # ty: ignore
552
+ self.umfpack.control[um.UMFPACK_SYM_PIVOT_TOLERANCE] = pivoting_threshold # ty: ignore
553
+ self._pivoting_threshold = pivoting_threshold
540
554
 
541
555
  def duplicate(self) -> Solver:
542
556
  new_solver = self.__class__()
@@ -568,11 +582,17 @@ class SolverPardiso(Solver):
568
582
 
569
583
  def __init__(self):
570
584
  super().__init__()
571
- self.solver: PardisoInterface = PardisoInterface()
585
+ self.solver: PardisoInterface | None = None
572
586
  self.fact_symb: bool = False
573
587
  self.A: np.ndarray = None
574
588
  self.b: np.ndarray = None
575
589
 
590
+ def initialize(self) -> None:
591
+ if self.initialized:
592
+ return
593
+ self.solver = PardisoInterface()
594
+ self.initialized = True
595
+
576
596
  def solve(self, A, b, precon, reuse_factorization: bool = False, id: int = -1) -> tuple[np.ndarray, SolveReport]:
577
597
  logger.info(f'[ID={id}] Calling Pardiso Solver')
578
598
  if self.fact_symb is False:
@@ -594,11 +614,18 @@ class SolverPardiso(Solver):
594
614
  class CuDSSSolver(Solver):
595
615
  real_only = False
596
616
  def __init__(self):
597
- self._cudss = CuDSSInterface()
617
+ super().__init__()
618
+ self._cudss: CuDSSInterface | None = None
598
619
  self.fact_symb: bool = False
599
620
  self.fact_numb: bool = False
621
+
622
+ def initialize(self) -> None:
623
+ if self.initialized:
624
+ return
625
+ self._cudss = CuDSSInterface()
600
626
  self._cudss._PRES = 2
601
-
627
+ self.initialized = True
628
+
602
629
  def reset(self) -> None:
603
630
  self.fact_symb = False
604
631
  self.fact_numb = False
@@ -1085,7 +1112,7 @@ class SolveRoutine:
1085
1112
  np.ndarray: The resultant solution.
1086
1113
  """
1087
1114
  solver: Solver = self._get_solver(A, b)
1088
-
1115
+ solver.initialize()
1089
1116
  NF = A.shape[0]
1090
1117
  NS = solve_ids.shape[0]
1091
1118
 
@@ -1178,7 +1205,7 @@ class SolveRoutine:
1178
1205
  SolveReport: The solution report
1179
1206
  """
1180
1207
  solver = self._get_eig_solver_bma(A, B, direct=direct)
1181
-
1208
+ solver.initialize()
1182
1209
  NF = A.shape[0]
1183
1210
  NS = solve_ids.shape[0]
1184
1211