pyMOTO 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pyMOTO
3
- Version: 1.3.0
3
+ Version: 1.5.0
4
4
  Summary: A modular approach for topology optimization
5
5
  Home-page: https://github.com/aatmdelissen/pyMOTO
6
6
  Author: Arnoud Delissen
@@ -25,7 +25,7 @@ Requires-Dist: scikit-sparse ; extra == 'dev'
25
25
  Requires-Dist: pypardiso ; extra == 'dev'
26
26
  Requires-Dist: jax[cpu] ; extra == 'dev'
27
27
 
28
- [![10.5281/zenodo.7708738](https://zenodo.org/badge/DOI/10.5281/zenodo.7708738.svg)](https://doi.org/10.5281/zenodo.7708738)
28
+ [![10.5281/zenodo.8138859](https://zenodo.org/badge/DOI/10.5281/zenodo.8138859.svg)](https://doi.org/10.5281/zenodo.8138859)
29
29
  [![anaconda.org/aatmdelissen/pymoto](https://anaconda.org/aatmdelissen/pymoto/badges/version.svg)](https://anaconda.org/aatmdelissen/pymoto)
30
30
  [![pypi.org/project/pyMOTO](https://badge.fury.io/py/pyMOTO.svg)](https://pypi.org/project/pyMOTO/)
31
31
 
@@ -61,15 +61,14 @@ automatically calculated.
61
61
  4. Run the example by typing `python ex_....py` in the console
62
62
 
63
63
 
64
- A local installation for development in `pyMOTO` can be done by first downloading the entire git repo, and then calling
64
+ For development, a local installation of `pyMOTO` can be done by first downloading/cloning the entire git repo, and then calling
65
65
  `pip install -e .` in the `pyMOTO` folder (of course from within your virtual environment).
66
66
 
67
67
  ## Dependencies
68
- * **NumPy** - Dense linear algebra and solvers
69
- * **SciPy** - Sparse linear algebra and solvers
70
- * **SymPy** - Symbolic differentiation for `MathGeneral` module
71
- * **Matplotlib** - Plotting and visualisation
72
- * (optional) **SAO** - Sequential approximated optimizers
68
+ * [**numpy**](https://numpy.org/doc/stable/) - Dense linear algebra and solvers
69
+ * [**scipy**](https://docs.scipy.org/doc/scipy/) - Sparse linear algebra and solvers
70
+ * [**sympy**](https://docs.sympy.org/latest/index.html) - Symbolic differentiation for `MathGeneral` module
71
+ * [**Matplotlib**](https://matplotlib.org/stable/) - Plotting and visualisation
73
72
  * (optional) [**opt_einsum**](https://optimized-einsum.readthedocs.io/en/stable/install.html) - Optimized function for `EinSum` module
74
73
 
75
74
  For fast linear solvers for sparse matrices:
@@ -0,0 +1,29 @@
1
+ pymoto/__init__.py,sha256=YLMAiO2PZHAC6nYWXVh03rhZnZkc_Rc2z7SGQj1T8I4,2058
2
+ pymoto/core_objects.py,sha256=88AOo041wrcRSPoLCRwBXUoYANGX-b2SAA0Nuf6sn2Y,25252
3
+ pymoto/routines.py,sha256=yjvcQDcWU47ZM6ZpZoX8VwrJoN9JDO_25DSa9oVcKec,15582
4
+ pymoto/utils.py,sha256=YJ-PNLJLc12Yx6TYCrEechS2aaBRx0o4mTM1soeeyz0,1122
5
+ pymoto/common/domain.py,sha256=-eFuYRLehQ17Ai-cV59f4I9FbEM-DJAj6kjjVfj31X0,18120
6
+ pymoto/common/dyadcarrier.py,sha256=VwMbqPr0NMDPfpsH0BwvXp8M1dmh8ijFDpF6yoyTmto,19394
7
+ pymoto/common/mma.py,sha256=Pof3clOHA8PG51TmUjs11dkjSP96kovZjsPv62tI2Ec,24055
8
+ pymoto/modules/aggregation.py,sha256=Oi17hIJ6dic4lOPw16zmjbdC72MjB6XK34H80bnbWAI,7580
9
+ pymoto/modules/assembly.py,sha256=quuR8QpB2w-O0zly-xS6PK6wZQMY6S5TWh15Y9wuh14,22974
10
+ pymoto/modules/autodiff.py,sha256=WAfoAOHBSozf7jbr9gQz9Vw4a_2G9wGJxLMMqUQP0Co,1684
11
+ pymoto/modules/complex.py,sha256=B_Obk-ABdV66lEudZ5s8o6qG9NsmYlBsX-PbWvbphhc,4429
12
+ pymoto/modules/filter.py,sha256=6X9FaQMWYZ_TpHVTFiEibzlmAwmSWbydYM93LFrJ0Wo,25490
13
+ pymoto/modules/generic.py,sha256=YzsGZ8J0oLCORt78Bf2p0v4GuqpWRI77NLoCk7gqidw,10666
14
+ pymoto/modules/io.py,sha256=LcFvJ-cPgg5ee-aag8kaxHw5RzQ-ggxOM5jk7PeJ1r8,13140
15
+ pymoto/modules/linalg.py,sha256=BNkih4nvvkYuQpm4bG5U38dAjkfD5EFi3MjANpBEPPI,21927
16
+ pymoto/modules/scaling.py,sha256=uq88HHW9rP16XLz7UGc3CNBBpY2Z1glo8yjYxZEnXUg,2327
17
+ pymoto/solvers/__init__.py,sha256=9JUeD2SgZbkYFullA7s7s6SuAVv0onqAqJ8hFvNOs2g,1033
18
+ pymoto/solvers/auto_determine.py,sha256=X8MEG7h6jLfAV1inpja45_-suG8qQFMfLMDfW2ryQqQ,5134
19
+ pymoto/solvers/dense.py,sha256=9fKPCwNxRKAEk5k1A7fdLrr9ngeVssGlw-sbjWCm4iU,11235
20
+ pymoto/solvers/iterative.py,sha256=CIxJHjGnCaIjXbtO2NxV60yeDpcCbSD6Bp0xR-7vOf0,12944
21
+ pymoto/solvers/matrix_checks.py,sha256=bbrfjpTSWWnuQW3xY0_CYE8yrh5gA9K5b1LzHEOFAxI,1663
22
+ pymoto/solvers/solvers.py,sha256=RwHjZYYlE3oA0U9k7ukla2gOdmq57rSSJQvHqjaM7JU,10626
23
+ pymoto/solvers/sparse.py,sha256=w8XBlFBIfOpNnfRdLWhLzzqtD8YVxMnDBuhIabFfQQc,16664
24
+ pyMOTO-1.5.0.dist-info/LICENSE,sha256=ZXMC2Txpzs-dBwz9Me4_1rQCSVl4P1B27MomNi43F30,1072
25
+ pyMOTO-1.5.0.dist-info/METADATA,sha256=hC38SdgeKEK5NkNDh-gwc4Gz2JOFSymJB-eAMXl7HX4,5006
26
+ pyMOTO-1.5.0.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
27
+ pyMOTO-1.5.0.dist-info/top_level.txt,sha256=EdvAUSmFMaiqhuEZW8jxANMiK-LdPtlmDWL6SfmCdUU,7
28
+ pyMOTO-1.5.0.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
29
+ pyMOTO-1.5.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.42.0)
2
+ Generator: setuptools (75.3.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
pymoto/__init__.py CHANGED
@@ -1,25 +1,27 @@
1
- __version__ = '1.3.0'
1
+ __version__ = '1.5.0'
2
2
 
3
3
  from .common.domain import DomainDefinition
4
4
 
5
5
  # Imports from common
6
6
  from .common.dyadcarrier import DyadCarrier
7
7
  from .common.mma import MMA
8
- from .common.solvers import matrix_is_complex, matrix_is_diagonal, matrix_is_symmetric, matrix_is_hermitian, LinearSolver, LDAWrapper
9
- from .common.solvers_dense import SolverDiagonal, SolverDenseQR, SolverDenseLU, SolverDenseCholesky, SolverDenseLDL
10
- from .common.solvers_sparse import SolverSparsePardiso, SolverSparseLU, SolverSparseCholeskyScikit, SolverSparseCholeskyCVXOPT
8
+
9
+ # Import solvers
10
+ from . import solvers
11
11
 
12
12
  # Modular inports
13
13
  from .core_objects import Signal, Module, Network, make_signals
14
14
 
15
15
  # Import modules
16
16
  from .modules.assembly import AssembleGeneral, AssembleStiffness, AssembleMass, AssemblePoisson
17
+ from .modules.assembly import ElementOperation, Strain, Stress, ElementAverage, NodalOperation, ThermoMechanical
17
18
  from .modules.autodiff import AutoMod
18
19
  from .modules.complex import MakeComplex, RealPart, ImagPart, ComplexNorm
19
20
  from .modules.filter import FilterConv, Filter, DensityFilter, OverhangFilter
20
21
  from .modules.generic import MathGeneral, EinSum, ConcatSignal
21
- from .modules.io import PlotDomain, PlotGraph, PlotIter, WriteToVTI
22
+ from .modules.io import FigModule, PlotDomain, PlotGraph, PlotIter, WriteToVTI, ScalarToFile
22
23
  from .modules.linalg import Inverse, LinSolve, EigenSolve, SystemOfEquations, StaticCondensation
24
+ from .modules.aggregation import AggScaling, AggActiveSet, Aggregation, PNorm, SoftMinMax, KSFunction
23
25
  from .modules.scaling import Scaling
24
26
 
25
27
  # Further helper routines
@@ -28,21 +30,25 @@ from .routines import finite_difference, minimize_oc, minimize_mma
28
30
  __all__ = [
29
31
  'Signal', 'Module', 'Network', 'make_signals',
30
32
  'finite_difference', 'minimize_oc', 'minimize_mma',
33
+
31
34
  # Common
32
35
  'MMA',
33
36
  'DyadCarrier',
34
37
  'DomainDefinition',
35
- 'matrix_is_complex', 'matrix_is_diagonal', 'matrix_is_symmetric', 'matrix_is_hermitian',
36
- 'LinearSolver', 'LDAWrapper',
37
- 'SolverDiagonal', 'SolverDenseQR', 'SolverDenseLU', 'SolverDenseCholesky', 'SolverDenseLDL',
38
- 'SolverSparsePardiso', 'SolverSparseLU', 'SolverSparseCholeskyScikit', 'SolverSparseCholeskyCVXOPT',
38
+ 'solvers',
39
+
40
+ # Helpers
41
+ "AggScaling", "AggActiveSet",
42
+
39
43
  # Modules
40
44
  "MathGeneral", "EinSum", "ConcatSignal",
41
45
  "Inverse", "LinSolve", "EigenSolve", "SystemOfEquations", "StaticCondensation",
42
46
  "AssembleGeneral", "AssembleStiffness", "AssembleMass", "AssemblePoisson",
47
+ "ElementOperation", "Strain", "Stress", "ElementAverage", "NodalOperation", "ThermoMechanical",
43
48
  "FilterConv", "Filter", "DensityFilter", "OverhangFilter",
44
- "PlotDomain", "PlotGraph", "PlotIter", "WriteToVTI",
49
+ "FigModule", "PlotDomain", "PlotGraph", "PlotIter", "WriteToVTI", "ScalarToFile",
45
50
  "MakeComplex", "RealPart", "ImagPart", "ComplexNorm",
46
51
  "AutoMod",
47
- "Scaling"
52
+ "Aggregation", "PNorm", "SoftMinMax", "KSFunction",
53
+ "Scaling",
48
54
  ]
pymoto/common/domain.py CHANGED
@@ -5,6 +5,31 @@ import struct
5
5
  import warnings
6
6
  from typing import Union
7
7
  import numpy as np
8
+ from matplotlib.patches import PathPatch
9
+ from matplotlib.path import Path
10
+
11
+
12
+ def plot_deformed_element(ax, x, y, **kwargs):
13
+ codes, verts = zip(*[
14
+ (Path.MOVETO, [x[0], y[0]]),
15
+ (Path.LINETO, [x[1], y[1]]),
16
+ (Path.LINETO, [x[3], y[3]]),
17
+ (Path.LINETO, [x[2], y[2]]),
18
+ (Path.CLOSEPOLY, [x[0], y[0]])])
19
+ path = Path(verts, codes)
20
+ patch = PathPatch(path, **kwargs)
21
+ ax.add_artist(patch)
22
+ return patch
23
+
24
+
25
+ def get_path(x, y):
26
+ codes, verts = zip(*[
27
+ (Path.MOVETO, [x[0], y[0]]),
28
+ (Path.LINETO, [x[1], y[1]]),
29
+ (Path.LINETO, [x[3], y[3]]),
30
+ (Path.LINETO, [x[2], y[2]]),
31
+ (Path.CLOSEPOLY, [x[0], y[0]])])
32
+ return Path(verts, codes)
8
33
 
9
34
 
10
35
  class DomainDefinition:
@@ -100,6 +125,14 @@ class DomainDefinition:
100
125
  self.conn = np.zeros((self.nel, self.elemnodes), dtype=int)
101
126
  self.conn[el, :] = self.get_elemconnectivity(elx, ely, elz)
102
127
 
128
+ # Helper for element slicing
129
+ eli, elj, elk = np.meshgrid(np.arange(self.nelx), np.arange(self.nely), np.arange(max(self.nelz, 1)), indexing='ij')
130
+ self.elements = self.get_elemnumber(eli, elj, elk)
131
+
132
+ # Helper for node slicing
133
+ ndi, ndj, ndk = np.meshgrid(np.arange(self.nelx+1), np.arange(self.nely+1), np.arange(self.nelz+1), indexing='ij')
134
+ self.nodes = self.get_nodenumber(ndi, ndj, ndk)
135
+
103
136
  def get_elemnumber(self, eli: Union[int, np.ndarray], elj: Union[int, np.ndarray], elk: Union[int, np.ndarray] = 0):
104
137
  """ Gets the element number(s) for element(s) with given Cartesian indices (i, j, k)
105
138
 
@@ -126,7 +159,7 @@ class DomainDefinition:
126
159
  """
127
160
  return (nodk * (self.nely + 1) + nodj) * (self.nelx + 1) + nodi
128
161
 
129
- def get_node_indices(self, nod_idx: Union[int, np.ndarray]):
162
+ def get_node_indices(self, nod_idx: Union[int, np.ndarray] = None):
130
163
  """ Gets the Cartesian index (i, j, k) for given node number(s)
131
164
 
132
165
  Args:
@@ -135,16 +168,18 @@ class DomainDefinition:
135
168
  Returns:
136
169
  i, j, k for requested node(s); k is only returned in 3D
137
170
  """
171
+ if nod_idx is None:
172
+ nod_idx = np.arange(self.nnodes)
138
173
  nodi = nod_idx % (self.nelx + 1)
139
174
  nodj = (nod_idx // (self.nelx + 1)) % (self.nely + 1)
140
175
  if self.dim == 2:
141
- return nodi, nodj
176
+ return np.stack([nodi, nodj], axis=0)
142
177
  nodk = nod_idx // ((self.nelx + 1)*(self.nely + 1))
143
- return nodi, nodj, nodk
178
+ return np.stack([nodi, nodj, nodk], axis=0)
144
179
 
145
- def get_node_position(self, nod_idx: Union[int, np.ndarray]):
180
+ def get_node_position(self, nod_idx: Union[int, np.ndarray] = None):
146
181
  ijk = self.get_node_indices(nod_idx)
147
- return [idx * self.element_size[ii] for ii, idx in enumerate(ijk)]
182
+ return (self.element_size[:self.dim] * ijk.T).T
148
183
 
149
184
  def get_elemconnectivity(self, i: Union[int, np.ndarray], j: Union[int, np.ndarray], k: Union[int, np.ndarray] = 0):
150
185
  """ Get the connectivity for element identified with Cartesian indices (i, j, k)
@@ -230,6 +265,27 @@ class DomainDefinition:
230
265
  dN_dx[i, :] *= np.array([n[i] for n in self.node_numbering]) # Flip +/- signs according to node position
231
266
  return dN_dx
232
267
 
268
+ def plot(self, ax, deformation=None, scaling=None):
269
+ patches = []
270
+ for e in range(self.nel):
271
+ n = self.conn[e]
272
+ x, y = self.get_node_position(n)
273
+ u, v = deformation[n * 2], deformation[n * 2 + 1]
274
+ color = (1 - scaling[e], 1 - scaling[e], 1 - scaling[e]) if scaling is not None else 'grey'
275
+ patch = plot_deformed_element(ax, x + u, v + y, linewidth=0.1, color=color)
276
+ patches.append(patch)
277
+ return patches
278
+
279
+ def update_plot(self, patches, deformation=None, scaling=None):
280
+ for e in range(self.nel):
281
+ patch = patches[e]
282
+ n = self.conn[e]
283
+ x, y = self.get_node_position(n)
284
+ u, v = deformation[n * 2], deformation[n * 2 + 1]
285
+ color = (1 - scaling[e], 1 - scaling[e], 1 - scaling[e]) if scaling is not None else 'grey'
286
+ patch.set_color(color)
287
+ patch.set_path(self.get_path(x + u, y + v))
288
+
233
289
  # flake8: noqa: C901
234
290
  def write_to_vti(self, vectors: dict, filename="out.vti", scale=1.0, origin=(0.0, 0.0, 0.0)):
235
291
  """ Write all given vectors to a Paraview (VTI) file
@@ -1,8 +1,8 @@
1
- from typing import Union, Iterable
1
+ from typing import Union, Iterable, List, Tuple
2
2
  import warnings
3
3
  import numpy as np
4
4
  from numpy.typing import NDArray
5
- from scipy.sparse import spmatrix
5
+ from scipy.sparse import spmatrix, coo_matrix
6
6
  from ..utils import _parse_to_list
7
7
  try: # Import fast optimized einsum
8
8
  from opt_einsum import contract as einsum
@@ -31,25 +31,26 @@ def isnullslice(x):
31
31
 
32
32
 
33
33
  class DyadCarrier(object):
34
- """ Efficient storage for dyadic or rank-N matrix
34
+ r""" Efficient storage for dyadic or rank-N matrix
35
35
 
36
36
  Stores only the vectors instead of creating a full rank-N matrix
37
37
  :math:`\mathbf{A} = \sum_k^N \mathbf{u}_k\otimes\mathbf{v}_k`
38
38
  or in index notation :math:`A_{ij} = \sum_k^N u_{ki} v_{kj}`. This saves a lot of memory for low :math:`N`.
39
39
 
40
- Args:
41
- u : (optional) List of vectors
42
- v : (optional) List of vectors (if ``u`` is given and ``v`` not, a symmetric dyad is assumed with ``v = u``)
40
+ Keyword Args:
41
+ u: List of vectors
42
+ v: List of vectors (if ``u`` is given and ``v`` not, a symmetric dyad is assumed with ``v = u``)
43
+ shape: Shape of the matrix
43
44
  """
44
45
 
45
46
  __array_priority__ = 11.0 # For overriding numpy's ufuncs
46
47
  ndim = 2 # Number of dimensions
47
48
 
48
- def __init__(self, u: Iterable = None, v: Iterable = None):
49
+ def __init__(self, u: Iterable = None, v: Iterable = None, shape: Tuple[int, int] = (-1, -1)):
49
50
  self.u = []
50
51
  self.v = []
51
- self.ulen = -1
52
- self.vlen = -1
52
+ self.ulen = shape[0]
53
+ self.vlen = shape[1]
53
54
  self.dtype = np.dtype('float64') # Standard data type
54
55
  self.add_dyad(u, v)
55
56
 
@@ -66,8 +67,14 @@ class DyadCarrier(object):
66
67
  else:
67
68
  return self.ulen * self.vlen
68
69
 
70
+ @property
71
+ def n_dyads(self):
72
+ """ Number of dyads stored """
73
+ assert len(self.u) == len(self.v)
74
+ return len(self.u)
75
+
69
76
  def add_dyad(self, u: Iterable, v: Iterable = None, fac: float = None):
70
- """ Adds a list of vectors to the dyad carrier
77
+ r""" Adds a list of vectors to the dyad carrier
71
78
 
72
79
  Checks for conforming sizes of `u` and `v`. The data inside the vectors are copied.
73
80
 
@@ -92,9 +99,7 @@ class DyadCarrier(object):
92
99
  if len(ulist) != len(vlist):
93
100
  raise TypeError("Number of vectors in u ({}) and v({}) should be equal".format(len(ulist), len(vlist)))
94
101
 
95
- n = len(ulist)
96
-
97
- for i, ui, vi in zip(range(n), ulist, vlist):
102
+ for i, (ui, vi) in enumerate(zip(ulist, vlist)):
98
103
  # Make sure they are numpy arrays
99
104
  if not isinstance(ui, np.ndarray):
100
105
  ui = np.array(ui)
@@ -141,14 +146,22 @@ class DyadCarrier(object):
141
146
 
142
147
  def __getitem__(self, subscript):
143
148
  assert len(subscript) == self.ndim, "Invalid number of slices, must be 2"
144
- usub = [ui[subscript[0]] for ui in self.u]
145
- vsub = [vi[subscript[1]] for vi in self.v]
149
+ if self.shape[0] < 0 and self.shape[1] < 0:
150
+ return DyadCarrier()
151
+
152
+ usample = np.zeros(self.shape[0])[subscript[0]]
153
+ vsample = np.zeros(self.shape[1])[subscript[1]]
146
154
 
147
- is_uni_slice = isscalarlike(usub[0]) or isscalarlike(vsub[0])
155
+ is_uni_slice = isscalarlike(usample) or isscalarlike(vsample)
148
156
  is_np_slice = isinstance(subscript[0], np.ndarray) and isinstance(subscript[1], np.ndarray)
157
+
149
158
  if is_np_slice and subscript[0].shape != subscript[1].shape:
150
159
  raise IndexError(f"shape mismatch: indexing arrays could not be broadcast together "
151
160
  f"with shapes {subscript[0].shape} {subscript[1].shape}")
161
+
162
+ usub = [ui[subscript[0]] for ui in self.u]
163
+ vsub = [vi[subscript[1]] for vi in self.v]
164
+
152
165
  if is_uni_slice or is_np_slice:
153
166
  res = 0
154
167
  for (ui, vi) in zip(usub, vsub):
@@ -156,7 +169,7 @@ class DyadCarrier(object):
156
169
 
157
170
  return res
158
171
  else:
159
- return DyadCarrier(usub, vsub)
172
+ return DyadCarrier(usub, vsub, shape=(np.size(usample), np.size(vsample)))
160
173
 
161
174
  def __setitem__(self, subscript, value):
162
175
  assert len(subscript) == self.ndim, "Invalid number of slices, must be 2"
@@ -171,10 +184,10 @@ class DyadCarrier(object):
171
184
  vi[subscript[1]] = value
172
185
 
173
186
  def __pos__(self):
174
- return DyadCarrier(self.u, self.v)
187
+ return self.copy()
175
188
 
176
189
  def __neg__(self):
177
- return DyadCarrier([-uu for uu in self.u], self.v)
190
+ return DyadCarrier([-uu for uu in self.u], self.v, shape=self.shape)
178
191
 
179
192
  def __iadd__(self, other):
180
193
  self.add_dyad(other.u, other.v)
@@ -189,7 +202,7 @@ class DyadCarrier(object):
189
202
  elif isdyad(other):
190
203
  if other.shape != self.shape and (self.size > 0 and other.size > 0):
191
204
  raise ValueError(f"Inconsistent shapes {self.shape} and {other.shape}")
192
- return DyadCarrier(self.u, self.v).__iadd__(other)
205
+ return self.copy().__iadd__(other)
193
206
  elif isdense(other):
194
207
  other = np.broadcast_to(other, self.shape)
195
208
  return other + self.todense()
@@ -219,28 +232,44 @@ class DyadCarrier(object):
219
232
  return NotImplemented
220
233
 
221
234
  def __rmul__(self, other): # other * self
222
- return DyadCarrier([other*ui for ui in self.u], self.v)
235
+ return DyadCarrier([other*ui for ui in self.u], self.v, shape=self.shape)
223
236
 
224
237
  def __mul__(self, other): # self * other
225
- return DyadCarrier(self.u, [vi*other for vi in self.v])
238
+ return DyadCarrier(self.u, [vi*other for vi in self.v], shape=self.shape)
226
239
 
227
240
  def copy(self):
228
241
  """ Returns a deep copy of the DyadCarrier """
229
- return DyadCarrier(self.u, self.v)
242
+ return DyadCarrier(self.u, self.v, shape=self.shape)
230
243
 
231
244
  def conj(self):
232
245
  """ Returns (a deep copied) complex conjugate of the DyadCarrier """
233
- return DyadCarrier([u.conj() for u in self.u], [v.conj() for v in self.v])
246
+ return DyadCarrier([u.conj() for u in self.u], [v.conj() for v in self.v], shape=self.shape)
234
247
 
235
248
  @property
236
249
  def real(self):
237
250
  """ Returns a deep copy of the real part of the DyadCarrier """
238
- return DyadCarrier([*[u.real for u in self.u], *[-u.imag for u in self.u]], [*[v.real for v in self.v], *[v.imag for v in self.v]])
251
+ return DyadCarrier([*[u.real for u in self.u], *[-u.imag for u in self.u]], [*[v.real for v in self.v], *[v.imag for v in self.v]], shape=self.shape)
239
252
 
240
253
  @property
241
254
  def imag(self):
242
255
  """ Returns a deep copy of the imaginary part of the DyadCarrier """
243
- return DyadCarrier([*[u.real for u in self.u], *[u.imag for u in self.u]], [*[v.imag for v in self.v], *[v.real for v in self.v]])
256
+ return DyadCarrier([*[u.real for u in self.u], *[u.imag for u in self.u]], [*[v.imag for v in self.v], *[v.real for v in self.v]], shape=self.shape)
257
+
258
+ def min(self):
259
+ minval = 0.0
260
+ for u, v in zip(self.u, self.v):
261
+ minval += u.min() * v.min()
262
+ if len(self.u) >= 2:
263
+ warnings.warn("The minimum is an approximation")
264
+ return minval
265
+
266
+ def max(self):
267
+ maxval = 0.0
268
+ for u, v in zip(self.u, self.v):
269
+ maxval += u.max() * v.max()
270
+ if len(self.u) >= 2:
271
+ warnings.warn("The maximum is an approximation")
272
+ return maxval
244
273
 
245
274
  # flake8: noqa: C901
246
275
  def contract(self, mat: Union[NDArray, spmatrix] = None, rows: NDArray[int] = None, cols: NDArray[int] = None):
@@ -373,6 +402,31 @@ class DyadCarrier(object):
373
402
 
374
403
  return val
375
404
 
405
+ def contract_multi(self, mats: List[spmatrix], dtype=None):
406
+ """ Faster version of contraction for a list of sparse matrices """
407
+ if dtype is None:
408
+ dtype = np.result_type(self.dtype, mats[0].dtype)
409
+ val = np.zeros(len(mats), dtype=dtype)
410
+
411
+ if len(self.u) == 0 or len(self.v) == 0:
412
+ return val
413
+ U = np.array(self.u).T
414
+ V = np.array(self.v).T
415
+
416
+ for i, m in enumerate(mats):
417
+ if m is None:
418
+ vali = 0.0
419
+ else:
420
+ try:
421
+ if not isinstance(m, coo_matrix):
422
+ warnings.warn("Inefficiency: Matrix must be converted to coo_matrix for contraction")
423
+ mat_coo = m.tocoo()
424
+ vali = np.einsum('ij,i,ij->', U[mat_coo.row, :], mat_coo.data, V[mat_coo.col, :])
425
+ except AttributeError:
426
+ vali = self.contract(m)
427
+ val[i] = vali
428
+ return val
429
+
376
430
  def todense(self):
377
431
  """ Returns a full (dense) matrix from the DyadCarrier matrix """
378
432
  warning_size = 100e+6 # Bytes
@@ -387,6 +441,10 @@ class DyadCarrier(object):
387
441
 
388
442
  return val
389
443
 
444
+ def toarray(self):
445
+ """ Convert to array, same as todense(). To be consistent with scipy.sparse """
446
+ return self.todense()
447
+
390
448
  def iscomplex(self):
391
449
  """ Check if the DyadCarrier is of complex type """
392
450
  return np.iscomplexobj(np.array([], dtype=self.dtype))
@@ -417,7 +475,7 @@ class DyadCarrier(object):
417
475
 
418
476
  def transpose(self):
419
477
  """ Returns a deep copy of the transposed DyadCarrier matrix"""
420
- return DyadCarrier(self.v, self.u)
478
+ return DyadCarrier(self.v, self.u, shape=(self.shape[1], self.shape[0]))
421
479
 
422
480
  def dot(self, other):
423
481
  """ Inner product """
@@ -445,9 +503,9 @@ class DyadCarrier(object):
445
503
  if other.ndim == 1:
446
504
  return self.__dot__(other)
447
505
 
448
- return DyadCarrier(self.u, [vi@other for vi in self.v])
506
+ return DyadCarrier(self.u, [vi@other for vi in self.v], shape=(self.shape[0], other.shape[1]))
449
507
 
450
508
  def __rmatmul__(self, other): # other @ self
451
509
  if other.ndim == 1:
452
510
  return self.__rdot__(other)
453
- return DyadCarrier([other@ui for ui in self.u], self.v)
511
+ return DyadCarrier([other@ui for ui in self.u], self.v, shape=(other.shape[0], self.shape[1]))