emerge 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of emerge might be problematic. Click here for more details.

Files changed (54) hide show
  1. emerge/_emerge/bc.py +14 -20
  2. emerge/_emerge/const.py +5 -0
  3. emerge/_emerge/cs.py +2 -2
  4. emerge/_emerge/elements/femdata.py +14 -14
  5. emerge/_emerge/elements/index_interp.py +1 -1
  6. emerge/_emerge/elements/ned2_interp.py +1 -1
  7. emerge/_emerge/elements/nedelec2.py +4 -4
  8. emerge/_emerge/elements/nedleg2.py +10 -10
  9. emerge/_emerge/geo/horn.py +1 -1
  10. emerge/_emerge/geo/modeler.py +18 -19
  11. emerge/_emerge/geo/operations.py +13 -10
  12. emerge/_emerge/geo/pcb.py +180 -82
  13. emerge/_emerge/geo/pcb_tools/calculator.py +2 -2
  14. emerge/_emerge/geo/pcb_tools/macro.py +14 -13
  15. emerge/_emerge/geo/pmlbox.py +1 -1
  16. emerge/_emerge/geometry.py +47 -33
  17. emerge/_emerge/logsettings.py +15 -16
  18. emerge/_emerge/material.py +15 -11
  19. emerge/_emerge/mesh3d.py +81 -59
  20. emerge/_emerge/mesher.py +26 -21
  21. emerge/_emerge/mth/integrals.py +1 -1
  22. emerge/_emerge/mth/pairing.py +2 -2
  23. emerge/_emerge/periodic.py +34 -31
  24. emerge/_emerge/physics/microwave/adaptive_freq.py +15 -16
  25. emerge/_emerge/physics/microwave/assembly/assembler.py +120 -93
  26. emerge/_emerge/physics/microwave/assembly/curlcurl.py +1 -8
  27. emerge/_emerge/physics/microwave/assembly/generalized_eigen.py +43 -8
  28. emerge/_emerge/physics/microwave/assembly/robinbc.py +5 -5
  29. emerge/_emerge/physics/microwave/microwave_3d.py +71 -44
  30. emerge/_emerge/physics/microwave/microwave_bc.py +206 -117
  31. emerge/_emerge/physics/microwave/microwave_data.py +36 -38
  32. emerge/_emerge/physics/microwave/sc.py +26 -26
  33. emerge/_emerge/physics/microwave/simjob.py +20 -15
  34. emerge/_emerge/physics/microwave/sparam.py +12 -12
  35. emerge/_emerge/physics/microwave/touchstone.py +1 -1
  36. emerge/_emerge/plot/display.py +12 -6
  37. emerge/_emerge/plot/pyvista/display.py +44 -39
  38. emerge/_emerge/plot/pyvista/display_settings.py +1 -1
  39. emerge/_emerge/plot/simple_plots.py +15 -15
  40. emerge/_emerge/selection.py +35 -39
  41. emerge/_emerge/simmodel.py +41 -47
  42. emerge/_emerge/simulation_data.py +24 -15
  43. emerge/_emerge/solve_interfaces/cudss_interface.py +238 -0
  44. emerge/_emerge/solve_interfaces/pardiso_interface.py +24 -18
  45. emerge/_emerge/solver.py +314 -136
  46. emerge/cli.py +1 -1
  47. emerge/lib.py +245 -248
  48. {emerge-0.5.1.dist-info → emerge-0.5.3.dist-info}/METADATA +5 -1
  49. emerge-0.5.3.dist-info/RECORD +83 -0
  50. emerge/_emerge/plot/grapher.py +0 -93
  51. emerge-0.5.1.dist-info/RECORD +0 -82
  52. {emerge-0.5.1.dist-info → emerge-0.5.3.dist-info}/WHEEL +0 -0
  53. {emerge-0.5.1.dist-info → emerge-0.5.3.dist-info}/entry_points.txt +0 -0
  54. {emerge-0.5.1.dist-info → emerge-0.5.3.dist-info}/licenses/LICENSE +0 -0
@@ -27,12 +27,11 @@ from .plot.pyvista import PVDisplay
27
27
  from .dataset import SimulationDataset
28
28
  from .periodic import PeriodicCell
29
29
  from .bc import BoundaryCondition
30
- from typing import Literal, Type, Generator, Any
30
+ from typing import Literal, Generator, Any
31
31
  from loguru import logger
32
32
  import numpy as np
33
- import sys
34
- import gmsh
35
- import joblib
33
+ import gmsh # type: ignore
34
+ import joblib # type: ignore
36
35
  import os
37
36
  import inspect
38
37
  from pathlib import Path
@@ -52,11 +51,9 @@ Known problems/solutions:
52
51
  --------------------------
53
52
  """
54
53
 
55
-
56
54
  class SimulationError(Exception):
57
55
  pass
58
56
 
59
-
60
57
  ############################################################
61
58
  # BASE 3D SIMULATION MODEL #
62
59
  ############################################################
@@ -94,18 +91,13 @@ class Simulation3D:
94
91
 
95
92
  self.mesh: Mesh3D = Mesh3D(self.mesher)
96
93
  self.select: Selector = Selector()
97
- self.display: PVDisplay = None
98
- self.set_loglevel(loglevel)
99
94
 
100
95
  ## STATES
101
96
  self.__active: bool = False
102
97
  self._defined_geometries: bool = False
103
- self._cell: PeriodicCell = None
98
+ self._cell: PeriodicCell | None = None
104
99
 
105
- self.display = PVDisplay(self.mesh)
106
-
107
- if logfile:
108
- self.set_logfile()
100
+ self.display: PVDisplay = PVDisplay(self.mesh)
109
101
 
110
102
  self.save_file: bool = save_file
111
103
  self.load_file: bool = load_file
@@ -117,6 +109,10 @@ class Simulation3D:
117
109
 
118
110
  self._initialize_simulation()
119
111
 
112
+ self.set_loglevel(loglevel)
113
+ if logfile:
114
+ self.set_logfile()
115
+
120
116
  self._update_data()
121
117
 
122
118
 
@@ -203,7 +199,9 @@ class Simulation3D:
203
199
  if self.save_file:
204
200
  self.save()
205
201
  # Finalize GMSH
206
- gmsh.finalize()
202
+ if gmsh.isInitialized():
203
+ gmsh.finalize()
204
+
207
205
  logger.debug('GMSH Shut down successful')
208
206
  # set the state to active
209
207
  self.__active = False
@@ -214,28 +212,22 @@ class Simulation3D:
214
212
 
215
213
  def all_geometries(self) -> list[GeoObject]:
216
214
  """Returns all geometries stored in the simulation file."""
217
- return [obj for obj in self.sim.default.values() if isinstance(obj, GeoObject)]
215
+ return [obj for obj in self.data.sim.default.values() if isinstance(obj, GeoObject)]
218
216
 
219
217
  def all_bcs(self) -> list[BoundaryCondition]:
220
218
  """Returns all boundary condition objects stored in the simulation file"""
221
- return [obj for obj in self.sim.default.values() if isinstance(obj, BoundaryCondition)]
219
+ return [obj for obj in self.data.sim.default.values() if isinstance(obj, BoundaryCondition)]
222
220
 
223
221
  def _set_mesh(self, mesh: Mesh3D) -> None:
224
222
  """Set the current model mesh to a given mesh."""
225
223
  self.mesh = mesh
226
224
  self.mw.mesh = mesh
227
- self.mesher.mesh = mesh
228
225
  self.display._mesh = mesh
229
226
 
230
227
  ############################################################
231
228
  # PUBLIC FUNCTIONS #
232
229
  ############################################################
233
230
 
234
- @property
235
- def passed_geometries(self) -> list[GeoObject]:
236
- """"""
237
- return self.data.sim['geometries']
238
-
239
231
  def save(self) -> None:
240
232
  """Saves the current model in the provided project directory."""
241
233
  # Ensure directory exists
@@ -292,13 +284,15 @@ class Simulation3D:
292
284
  loglevel ('DEBUG','INFO','WARNING','ERROR'): The loglevel
293
285
  """
294
286
  LOG_CONTROLLER.set_std_loglevel(loglevel)
287
+ if loglevel not in ('TRACE','DEBUG'):
288
+ gmsh.option.setNumber("General.Terminal", 0)
295
289
 
296
290
  def set_logfile(self) -> None:
297
291
  """Adds a file output for the logger."""
298
292
  LOG_CONTROLLER.set_write_file(self.modelpath)
299
293
 
300
294
  def view(self,
301
- selections: list[Selection] = None,
295
+ selections: list[Selection] | None = None,
302
296
  use_gmsh: bool = False,
303
297
  volume_opacity: float = 0.1,
304
298
  surface_opacity: float = 1,
@@ -316,22 +310,15 @@ class Simulation3D:
316
310
  gmsh.model.occ.synchronize()
317
311
  gmsh.fltk.run()
318
312
  return
319
- try:
320
- for obj in self.data.sim['geometries']:
321
- if obj.dim==2:
322
- opacity=surface_opacity
323
- elif obj.dim==3:
324
- opacity=volume_opacity
325
- self.display.add_object(obj, show_edges=show_edges, opacity=opacity)
326
- if selections:
327
- [self.display.add_object(sel, color='red', opacity=0.7) for sel in selections]
328
- self.display.show()
329
- return
330
- except NotImplementedError as e:
331
- logger.warning('The provided BaseDisplay class does not support object display. Please make' \
332
- 'sure that this method is properly implemented.')
333
-
334
- def set_periodic_cell(self, cell: PeriodicCell, excluded_faces: list[FaceSelection] = None):
313
+ for geo in _GEOMANAGER.all_geometries():
314
+ self.display.add_object(geo)
315
+ if selections:
316
+ [self.display.add_object(sel, color='red', opacity=0.7) for sel in selections]
317
+ self.display.show()
318
+
319
+ return None
320
+
321
+ def set_periodic_cell(self, cell: PeriodicCell, excluded_faces: list[FaceSelection] | None = None):
335
322
  """Set the given periodic cell object as the simulations peridicity.
336
323
 
337
324
  Args:
@@ -341,18 +328,20 @@ class Simulation3D:
341
328
  self.mw.bc._cell = cell
342
329
  self._cell = cell
343
330
 
344
- def commit_geometry(self, *geometries: list[GeoObject]) -> None:
331
+ def commit_geometry(self, *geometries: GeoObject | list[GeoObject]) -> None:
345
332
  """Finalizes and locks the current geometry state of the simulation.
346
333
 
347
- The geometries may be provided (legacy behavior) but are automatically managed underwater.
334
+ The geometries may be provided (legacy behavior) but are automatically managed in the background.
348
335
 
349
336
  """
337
+ geometries_parsed: Any = None
350
338
  if not geometries:
351
- geometries = _GEOMANAGER.all_geometries()
339
+ geometries_parsed = _GEOMANAGER.all_geometries()
352
340
  else:
353
- geometries = unpack_lists(geometries + tuple([item for item in self.data.sim.default.values() if isinstance(item, GeoObject)]))
354
- self.data.sim['geometries'] = geometries
355
- self.mesher.submit_objects(geometries)
341
+ geometries_parsed = unpack_lists(geometries + tuple([item for item in self.data.sim.default.values() if isinstance(item, GeoObject)]))
342
+
343
+ self.data.sim['geometries'] = geometries_parsed
344
+ self.mesher.submit_objects(geometries_parsed)
356
345
  self._defined_geometries = True
357
346
  self.display._facetags = [dt[1] for dt in gmsh.model.get_entities(2)]
358
347
 
@@ -385,7 +374,12 @@ class Simulation3D:
385
374
  self.mesher.set_mesh_size(self.mw.get_discretizer(), self.mw.resolution)
386
375
 
387
376
  try:
377
+ gmsh.logger.start()
388
378
  gmsh.model.mesh.generate(3)
379
+ logs = gmsh.logger.get()
380
+ gmsh.logger.stop()
381
+ for log in logs:
382
+ logger.trace('[GMSH] '+log)
389
383
  except Exception:
390
384
  logger.error('GMSH Mesh error detected.')
391
385
  print(_GMSH_ERROR_TEXT)
@@ -438,9 +432,9 @@ class Simulation3D:
438
432
 
439
433
  logger.info(f'Iterating: {params}')
440
434
  if len(dims_flat)==1:
441
- yield dims_flat[0][i_iter]
435
+ yield (dims_flat[0][i_iter],)
442
436
  else:
443
- yield (dim[i_iter] for dim in dims_flat)
437
+ yield (dim[i_iter] for dim in dims_flat) # type: ignore
444
438
  self.mw.cache_matrices = True
445
439
 
446
440
  ############################################################
@@ -18,7 +18,7 @@
18
18
  from __future__ import annotations
19
19
  import numpy as np
20
20
  from loguru import logger
21
- from typing import TypeVar, Generic, Any, List, Union, Dict
21
+ from typing import TypeVar, Generic, Any, List, Union, Dict, Generator
22
22
  from collections import defaultdict
23
23
 
24
24
  T = TypeVar("T")
@@ -84,7 +84,7 @@ def generate_ndim(
84
84
  outer_data: dict[str, list[float]],
85
85
  inner_data: list[float],
86
86
  outer_labels: tuple[str, ...]
87
- ) -> np.ndarray:
87
+ ) -> tuple[np.ndarray,...]:
88
88
  """
89
89
  Generates an N-dimensional grid of values from flattened data, and returns each axis array plus the grid.
90
90
 
@@ -126,7 +126,7 @@ def generate_ndim(
126
126
  grid[tuple(idxs)] = values
127
127
 
128
128
  # Return each axis array followed by the grid
129
- return (*axes, grid)
129
+ return tuple(axes) + (grid,)
130
130
 
131
131
 
132
132
  class DataEntry:
@@ -142,7 +142,7 @@ class DataEntry:
142
142
 
143
143
  def values(self) -> list[Any]:
144
144
  """ Return all values stored in the DataEntry"""
145
- return self.data.values()
145
+ return list(self.data.values())
146
146
 
147
147
  def keys(self) -> list[str]:
148
148
  """ Return all names of data stored in the DataEntry"""
@@ -150,12 +150,15 @@ class DataEntry:
150
150
 
151
151
  def items(self) -> list[tuple[str, Any]]:
152
152
  """ Returns a list of all key: value pairs of the DataEntry."""
153
+ return list(self.data.items())
153
154
 
154
- def __eq__(self, other: dict[str, float]) -> bool:
155
+ def __eq__(self, other: Any) -> bool:
156
+ if not isinstance(other, dict):
157
+ return False
155
158
  allkeys = set(list(self.vars.keys()) + list(other.keys()))
156
159
  return all(self.vars[key]==other[key] for key in allkeys)
157
160
 
158
- def _dist(self, other: dict[str, float]) -> bool:
161
+ def _dist(self, other: dict[str, float]) -> float:
159
162
  return sum([(abs(self.vars.get(key,1e20)-other[key])/other[key]) for key in other.keys()])
160
163
 
161
164
  def __getitem__(self, key) -> Any:
@@ -184,6 +187,10 @@ class DataContainer:
184
187
  self.entries.append(entry)
185
188
  return entry
186
189
 
190
+ def iterate(self) -> Generator[tuple[dict[str, float], dict[str, Any]], None, None]:
191
+ for entry in self.entries:
192
+ yield entry.vars, entry.data
193
+
187
194
  @property
188
195
  def last(self) -> DataEntry:
189
196
  """Returns the last added entry"""
@@ -218,18 +225,18 @@ class DataContainer:
218
225
 
219
226
 
220
227
  class BaseDataset(Generic[T,M]):
221
- def __init__(self, datatype: T, matrixtype: M, scalar: bool):
228
+ def __init__(self, datatype: type[T], matrixtype: type[M], scalar: bool):
222
229
  self._datatype: type[T] = datatype
223
230
  self._matrixtype: type[M] = matrixtype
224
231
  self._variables: list[dict[str, float]] = []
225
232
  self._data_entries: list[T] = []
226
233
  self._scalar: bool = scalar
227
234
 
228
- self._gritted: bool = None
229
- self._axes: dict[str, np.ndarray] = None
230
- self._ax_ids: dict[str, int] = None
231
- self._ids: np.ndarray = None
232
- self._gridobj: M = None
235
+ self._gritted: bool | None = None
236
+ self._axes: dict[str, np.ndarray]| None = None
237
+ self._ax_ids: dict[str, int]| None = None
238
+ self._ids: np.ndarray| None = None
239
+ self._gridobj: M | None = None
233
240
 
234
241
  self._data: dict[str, Any] = dict()
235
242
 
@@ -238,11 +245,11 @@ class BaseDataset(Generic[T,M]):
238
245
 
239
246
  @property
240
247
  def _fields(self) -> list[str]:
241
- return self._datatype._fields
248
+ return self._datatype._fields # type: ignore
242
249
 
243
250
  @property
244
251
  def _copy(self) -> list[str]:
245
- return self._datatype._copy
252
+ return self._datatype._copy # type: ignore
246
253
 
247
254
  def store(self, key: str, value: Any) -> None:
248
255
  """Stores a variable with some value in the provided key.
@@ -335,7 +342,7 @@ class BaseDataset(Generic[T,M]):
335
342
  self._data_entries.append(new_entry)
336
343
  return new_entry
337
344
 
338
- def _grid_axes(self) -> None:
345
+ def _grid_axes(self) -> bool:
339
346
  """This method attepmts to create a gritted version of the scalar dataset
340
347
 
341
348
  Returns:
@@ -405,7 +412,9 @@ class BaseDataset(Generic[T,M]):
405
412
  """
406
413
  if self._gritted is None:
407
414
  self._grid_axes()
415
+
408
416
  if self._gritted is False:
409
417
  logger.error('The dataset cannot be cast to a structured grid.')
410
418
  raise ValueError('Data not in regular grid')
419
+
411
420
  return self._gridobj
@@ -0,0 +1,238 @@
1
+ # EMerge is an open source Python based FEM EM simulation module.
2
+ # Copyright (C) 2025 Robert Fennis.
3
+
4
+ # This program is free software; you can redistribute it and/or
5
+ # modify it under the terms of the GNU General Public License
6
+ # as published by the Free Software Foundation; either version 2
7
+ # of the License, or (at your option) any later version.
8
+
9
+ # This program is distributed in the hope that it will be useful,
10
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
11
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12
+ # GNU General Public License for more details.
13
+
14
+ # You should have received a copy of the GNU General Public License
15
+ # along with this program; if not, see
16
+ # <https://www.gnu.org/licenses/>.
17
+
18
+ import cupy as cp # ty: ignore
19
+ import nvmath.bindings.cudss as cudss # ty: ignore
20
+ from nvmath import CudaDataType # ty: ignore
21
+
22
+ from scipy.sparse import csr_matrix
23
+ import numpy as np
24
+
25
+ from loguru import logger
26
+
27
+
28
+ ############################################################
29
+ # CONSTANTS #
30
+ ############################################################
31
+
# Reordering algorithms that can be handed to CuDSSInterface.set_algorithm.
# The names are the author's labels for the cuDSS ``AlgType`` members;
# consult the cuDSS documentation for what each member actually selects.
(
    ALG_NEST_DISS_METIS,
    ALG_COLAMD,
    ALG_COLAMD_BLOCK_TRI,
    ALG_AMD,
) = (
    cudss.AlgType.ALG_DEFAULT,
    cudss.AlgType.ALG_1,
    cudss.AlgType.ALG_2,
    cudss.AlgType.ALG_3,
)

# CUDA value/index type tags passed to the cuDSS matrix constructors.
(
    FLOAT64,
    FLOAT32,
    COMPLEX128,
    COMPLEX64,
    INT64,
    INT32,
) = (
    CudaDataType.CUDA_R_64F,
    CudaDataType.CUDA_R_32F,
    CudaDataType.CUDA_C_64F,
    CudaDataType.CUDA_C_32F,
    CudaDataType.CUDA_R_64I,
    CudaDataType.CUDA_R_32I,
)

# CSR index arrays are zero-based (the SciPy/CuPy convention).
INDEX_BASE = cudss.IndexBase.ZERO
45
+
46
+ def _c_pointer(arry) -> int:
47
+ return int(arry.data.ptr)
48
+
49
+ ############################################################
50
+ # INTERFACE #
51
+ ############################################################
52
+
class CuDSSInterface:
    """GPU sparse direct solver for ``A x = b`` built on NVIDIA cuDSS (nvmath bindings).

    All entry points share one cuDSS handle / config / data triple:

    * :meth:`from_symbolic` uploads ``A`` and ``b`` and runs analysis
      (reordering + symbolic factorization), numeric factorization and solve.
    * :meth:`from_numeric` reuses the analysis of the last :meth:`from_symbolic`
      call and only refactorizes. ``A`` is assumed to have the same sparsity
      pattern and nonzero ordering as the analysed matrix.
    * :meth:`from_solve` reuses the current factorization for a new ``b``.

    The matrix is declared to cuDSS as ``SYMMETRIC`` with a ``FULL`` view
    (see ``MTYPE``/``MVIEW``), so callers are expected to pass symmetric
    matrices stored in full.
    """

    def __init__(self):
        # CuPy (device) copies of the current system
        self.A_cu = None
        self.b_cu = None
        self.x_cu = None
        # cuDSS matrix objects wrapping the device buffers above
        self.A_cobj = None
        self.b_cobj = None
        self.x_cobj = None
        self.A_pattern = None

        self._handle = cudss.create()
        self._config = cudss.config_create()
        self._data = cudss.data_create(self._handle)

        self.MTYPE = cudss.MatrixType.SYMMETRIC
        self.MVIEW = cudss.MatrixViewType.FULL
        self.RALG = cudss.AlgType.ALG_DEFAULT
        self.VTYPE = CudaDataType.CUDA_R_64F

        # CSR device buffers of the most recently submitted matrix
        self._INDPTR = None
        self._ROW_START = None
        self._ROW_END = None
        self._IND = None
        self._VAL = None
        self._NNZ: int | None = None
        self._COMP: bool = True   # complex-valued system?
        self._PRES: int = 2       # 1 -> single precision, anything else -> double
        self._COL_IDS = None

        # Sparsity-pattern buffers that A_cobj points at. They must stay alive
        # across later submit_matrix() calls: refactorization only swaps the
        # values pointer and keeps reading the analysed row/column arrays.
        self._pattern_refs: tuple = ()

        # True once a symbolic + numeric factorization is available.
        self._initialized = False

        self._apply_reordering_alg()

    def _apply_reordering_alg(self) -> None:
        """Write ``self.RALG`` into the cuDSS config as the reordering algorithm."""
        param = cudss.ConfigParam.REORDERING_ALG
        dtype = cudss.get_config_param_dtype(int(param))
        reorder_alg = np.array(self.RALG, dtype=dtype)

        cudss.config_set(
            self._config,
            int(param),
            reorder_alg.ctypes.data,
            reorder_alg.nbytes
        )

    def set_algorithm(self, alg_type: cudss.AlgType):
        """Select the fill-reducing reordering algorithm.

        The setting is pushed to the cuDSS config immediately, so it applies
        to the next analysis phase (:meth:`from_symbolic`).

        Args:
            alg_type (cudss.AlgType): Reordering algorithm, e.g. one of the
                module-level ``ALG_*`` constants.
        """
        self.RALG = alg_type
        # Previously only stored on the instance, so it never reached cuDSS.
        self._apply_reordering_alg()

    def init_type(self):
        """Pick the CuPy dtype and matching cuDSS value type from ``_PRES``/``_COMP``."""
        if self._PRES == 1:
            if self._COMP:
                self.c_dtype = cp.complex64
                self.VTYPE = COMPLEX64
            else:
                self.c_dtype = cp.float32
                self.VTYPE = FLOAT32
        else:
            if self._COMP:
                self.c_dtype = cp.complex128
                self.VTYPE = COMPLEX128
            else:
                self.c_dtype = cp.float64
                self.VTYPE = FLOAT64

    def submit_matrix(self, A: csr_matrix):
        """Copy ``A`` to the GPU and extract the 32-bit CSR arrays cuDSS needs.

        Args:
            A (csr_matrix): Square system matrix.
        """
        self.N = A.shape[0]
        self._COMP = bool(np.iscomplexobj(A))
        self.init_type()

        self.A_cu = cp.sparse.csr_matrix(A).astype(self.c_dtype)

        self._INDPTR = cp.ascontiguousarray(self.A_cu.indptr.astype(cp.int32))
        self._IND = cp.ascontiguousarray(self.A_cu.indices.astype(cp.int32))
        self._VAL = cp.ascontiguousarray(self.A_cu.data)
        self._NNZ = int(self._VAL.size)
        # cuDSS takes separate row-start / row-end pointers; both are views of indptr.
        self._ROW_START = self._INDPTR[:-1]
        self._ROW_END = self._INDPTR[1:]
        # Reuse the int32 column indices instead of making a second device copy.
        self._COL_IDS = self._IND

    def submit_vector(self, b: np.ndarray):
        """Copy the right-hand side ``b`` to the GPU in the current value dtype."""
        self.b_cu = cp.array(b).astype(self.c_dtype)

    def create_solvec(self):
        """Allocate the device solution buffer with the same shape/dtype as ``b``."""
        self.x_cu = cp.empty_like(self.b_cu)

    def _bind_rhs_and_solution(self):
        """Point the existing dense cuDSS objects at the current ``b_cu``/``x_cu``."""
        cudss.matrix_set_values(self.b_cobj, _c_pointer(self.b_cu))
        cudss.matrix_set_values(self.x_cobj, _c_pointer(self.x_cu))

    def _update_dss_data(self):
        """Re-point the existing cuDSS objects at freshly submitted values, b and x.

        Only the CSR values pointer changes. The analysed row/column buffers
        remain valid because they are pinned in ``_pattern_refs``. The dense
        handles are reused rather than recreated, so none leak.
        """
        cudss.matrix_set_values(self.A_cobj, _c_pointer(self._VAL))
        self._bind_rhs_and_solution()

    def _destroy_matrices(self):
        """Release any cuDSS matrix objects left from a previous analysis."""
        for attr in ('A_cobj', 'b_cobj', 'x_cobj'):
            handle = getattr(self, attr)
            if handle is not None:
                cudss.matrix_destroy(handle)
                setattr(self, attr, None)

    def _create_dss_data(self):
        """Create fresh cuDSS CSR/dense objects for the currently submitted system."""
        self._destroy_matrices()

        self.A_cobj = cudss.matrix_create_csr(
            self.N, self.N, self._NNZ,
            _c_pointer(self._ROW_START),
            _c_pointer(self._ROW_END),
            _c_pointer(self._COL_IDS),
            _c_pointer(self._VAL),
            int(INT32),
            int(self.VTYPE),
            int(self.MTYPE),
            int(self.MVIEW),
            int(INDEX_BASE),
        )
        # Keep the buffers A_cobj references alive beyond the next submit_matrix().
        self._pattern_refs = (self._INDPTR, self._ROW_START, self._ROW_END, self._COL_IDS)

        self.b_cobj = cudss.matrix_create_dn(self.N, 1, self.N, _c_pointer(self.b_cu),
                                             int(self.VTYPE), int(cudss.Layout.COL_MAJOR))
        self.x_cobj = cudss.matrix_create_dn(self.N, 1, self.N, _c_pointer(self.x_cu),
                                             int(self.VTYPE), int(cudss.Layout.COL_MAJOR))

    def from_symbolic(self, A: csr_matrix, b: np.ndarray) -> np.ndarray:
        """Solves Ax=b starting from the symbolic factorization.

        Args:
            A (csr_matrix): The input sparse matrix
            b (np.ndarray): The right-hand side vector b

        Returns:
            np.ndarray: The solution vector x (host copy, flattened)
        """
        self._initialized = False
        self.submit_matrix(A)
        self.submit_vector(b)
        self.create_solvec()
        self._create_dss_data()
        self._symbolic()
        self._numeric(False)
        self._initialized = True
        return self._solve()

    def from_numeric(self, A: csr_matrix, b: np.ndarray) -> np.ndarray:
        """Solves Ax=b starting from the numeric (re)factorization.

        Reuses the analysis of the last :meth:`from_symbolic` call. ``A`` is
        assumed to share that call's sparsity pattern. If no analysis exists
        yet, this falls back to :meth:`from_symbolic`.

        Args:
            A (csr_matrix): The input sparse matrix
            b (np.ndarray): The right-hand side vector b

        Returns:
            np.ndarray: The solution vector x (host copy, flattened)
        """
        if not self._initialized:
            return self.from_symbolic(A, b)
        self.submit_matrix(A)
        self.submit_vector(b)
        self.create_solvec()
        self._update_dss_data()
        self._numeric(True)
        return self._solve()

    def from_solve(self, b: np.ndarray) -> np.ndarray:
        """Solves Ax=b for a new right-hand side using the existing factorization.

        Args:
            b (np.ndarray): The right-hand side vector b

        Returns:
            np.ndarray: The solution vector x (host copy, flattened)

        Raises:
            RuntimeError: If no factorization exists yet.
        """
        if not self._initialized:
            raise RuntimeError('No factorization available: call from_symbolic() first.')
        self.submit_vector(b)
        self.create_solvec()
        # Without this, cuDSS would read the previous b and write the previous x.
        self._bind_rhs_and_solution()
        return self._solve()

    def _symbolic(self):
        """Run the cuDSS ANALYSIS phase (reordering + symbolic factorization)."""
        logger.trace('Executing symbolic factorization')
        cudss.execute(self._handle, cudss.Phase.ANALYSIS, self._config, self._data,
                      self.A_cobj, self.x_cobj, self.b_cobj)

    def _numeric(self, refactorize: bool = False):
        """Run the FACTORIZATION phase, or REFACTORIZATION when ``refactorize`` is set."""
        if refactorize:
            logger.trace('Refactoring matrix')
            phase = cudss.Phase.REFACTORIZATION
        else:
            phase = cudss.Phase.FACTORIZATION
            logger.trace('Executing numerical factorization')
        cudss.execute(self._handle, phase, self._config, self._data,
                      self.A_cobj, self.x_cobj, self.b_cobj)

    def _solve(self) -> np.ndarray:
        """Run the SOLVE phase and return x copied to the host as a flat array."""
        logger.trace('Solving matrix problem')
        cudss.execute(self._handle, cudss.Phase.SOLVE, self._config, self._data,
                      self.A_cobj, self.x_cobj, self.b_cobj)
        cp.cuda.runtime.deviceSynchronize()
        x_host = cp.asnumpy(self.x_cu).ravel()
        return x_host