pyMOTO 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyMOTO-1.3.0.dist-info → pyMOTO-1.5.0.dist-info}/METADATA +7 -8
- pyMOTO-1.5.0.dist-info/RECORD +29 -0
- {pyMOTO-1.3.0.dist-info → pyMOTO-1.5.0.dist-info}/WHEEL +1 -1
- pymoto/__init__.py +17 -11
- pymoto/common/domain.py +61 -5
- pymoto/common/dyadcarrier.py +87 -29
- pymoto/common/mma.py +142 -129
- pymoto/core_objects.py +129 -117
- pymoto/modules/aggregation.py +209 -0
- pymoto/modules/assembly.py +250 -10
- pymoto/modules/complex.py +3 -3
- pymoto/modules/filter.py +171 -24
- pymoto/modules/generic.py +12 -1
- pymoto/modules/io.py +85 -12
- pymoto/modules/linalg.py +92 -120
- pymoto/modules/scaling.py +5 -4
- pymoto/routines.py +34 -9
- pymoto/solvers/__init__.py +14 -0
- pymoto/solvers/auto_determine.py +108 -0
- pymoto/{common/solvers_dense.py → solvers/dense.py} +90 -70
- pymoto/solvers/iterative.py +361 -0
- pymoto/solvers/matrix_checks.py +60 -0
- pymoto/solvers/solvers.py +253 -0
- pymoto/{common/solvers_sparse.py → solvers/sparse.py} +42 -29
- pyMOTO-1.3.0.dist-info/RECORD +0 -24
- pymoto/common/solvers.py +0 -236
- {pyMOTO-1.3.0.dist-info → pyMOTO-1.5.0.dist-info}/LICENSE +0 -0
- {pyMOTO-1.3.0.dist-info → pyMOTO-1.5.0.dist-info}/top_level.txt +0 -0
- {pyMOTO-1.3.0.dist-info → pyMOTO-1.5.0.dist-info}/zip-safe +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: pyMOTO
|
3
|
-
Version: 1.
|
3
|
+
Version: 1.5.0
|
4
4
|
Summary: A modular approach for topology optimization
|
5
5
|
Home-page: https://github.com/aatmdelissen/pyMOTO
|
6
6
|
Author: Arnoud Delissen
|
@@ -25,7 +25,7 @@ Requires-Dist: scikit-sparse ; extra == 'dev'
|
|
25
25
|
Requires-Dist: pypardiso ; extra == 'dev'
|
26
26
|
Requires-Dist: jax[cpu] ; extra == 'dev'
|
27
27
|
|
28
|
-
[](https://doi.org/10.5281/zenodo.8138859)
|
29
29
|
[](https://anaconda.org/aatmdelissen/pymoto)
|
30
30
|
[](https://pypi.org/project/pyMOTO/)
|
31
31
|
|
@@ -61,15 +61,14 @@ automatically calculated.
|
|
61
61
|
4. Run the example by typing `python ex_....py` in the console
|
62
62
|
|
63
63
|
|
64
|
-
|
64
|
+
For development, a local installation of `pyMOTO` can be done by first downloading/cloning the entire git repo, and then calling
|
65
65
|
`pip install -e .` in the `pyMOTO` folder (of course from within your virtual environment).
|
66
66
|
|
67
67
|
## Dependencies
|
68
|
-
* **
|
69
|
-
* **
|
70
|
-
* **
|
71
|
-
* **Matplotlib** - Plotting and visualisation
|
72
|
-
* (optional) **SAO** - Sequential approximated optimizers
|
68
|
+
* [**numpy**](https://numpy.org/doc/stable/) - Dense linear algebra and solvers
|
69
|
+
* [**scipy**](https://docs.scipy.org/doc/scipy/) - Sparse linear algebra and solvers
|
70
|
+
* [**sympy**](https://docs.sympy.org/latest/index.html) - Symbolic differentiation for `MathGeneral` module
|
71
|
+
* [**Matplotlib**](https://matplotlib.org/stable/) - Plotting and visualisation
|
73
72
|
* (optional) [**opt_einsum**](https://optimized-einsum.readthedocs.io/en/stable/install.html) - Optimized function for `EinSum` module
|
74
73
|
|
75
74
|
For fast linear solvers for sparse matrices:
|
@@ -0,0 +1,29 @@
|
|
1
|
+
pymoto/__init__.py,sha256=YLMAiO2PZHAC6nYWXVh03rhZnZkc_Rc2z7SGQj1T8I4,2058
|
2
|
+
pymoto/core_objects.py,sha256=88AOo041wrcRSPoLCRwBXUoYANGX-b2SAA0Nuf6sn2Y,25252
|
3
|
+
pymoto/routines.py,sha256=yjvcQDcWU47ZM6ZpZoX8VwrJoN9JDO_25DSa9oVcKec,15582
|
4
|
+
pymoto/utils.py,sha256=YJ-PNLJLc12Yx6TYCrEechS2aaBRx0o4mTM1soeeyz0,1122
|
5
|
+
pymoto/common/domain.py,sha256=-eFuYRLehQ17Ai-cV59f4I9FbEM-DJAj6kjjVfj31X0,18120
|
6
|
+
pymoto/common/dyadcarrier.py,sha256=VwMbqPr0NMDPfpsH0BwvXp8M1dmh8ijFDpF6yoyTmto,19394
|
7
|
+
pymoto/common/mma.py,sha256=Pof3clOHA8PG51TmUjs11dkjSP96kovZjsPv62tI2Ec,24055
|
8
|
+
pymoto/modules/aggregation.py,sha256=Oi17hIJ6dic4lOPw16zmjbdC72MjB6XK34H80bnbWAI,7580
|
9
|
+
pymoto/modules/assembly.py,sha256=quuR8QpB2w-O0zly-xS6PK6wZQMY6S5TWh15Y9wuh14,22974
|
10
|
+
pymoto/modules/autodiff.py,sha256=WAfoAOHBSozf7jbr9gQz9Vw4a_2G9wGJxLMMqUQP0Co,1684
|
11
|
+
pymoto/modules/complex.py,sha256=B_Obk-ABdV66lEudZ5s8o6qG9NsmYlBsX-PbWvbphhc,4429
|
12
|
+
pymoto/modules/filter.py,sha256=6X9FaQMWYZ_TpHVTFiEibzlmAwmSWbydYM93LFrJ0Wo,25490
|
13
|
+
pymoto/modules/generic.py,sha256=YzsGZ8J0oLCORt78Bf2p0v4GuqpWRI77NLoCk7gqidw,10666
|
14
|
+
pymoto/modules/io.py,sha256=LcFvJ-cPgg5ee-aag8kaxHw5RzQ-ggxOM5jk7PeJ1r8,13140
|
15
|
+
pymoto/modules/linalg.py,sha256=BNkih4nvvkYuQpm4bG5U38dAjkfD5EFi3MjANpBEPPI,21927
|
16
|
+
pymoto/modules/scaling.py,sha256=uq88HHW9rP16XLz7UGc3CNBBpY2Z1glo8yjYxZEnXUg,2327
|
17
|
+
pymoto/solvers/__init__.py,sha256=9JUeD2SgZbkYFullA7s7s6SuAVv0onqAqJ8hFvNOs2g,1033
|
18
|
+
pymoto/solvers/auto_determine.py,sha256=X8MEG7h6jLfAV1inpja45_-suG8qQFMfLMDfW2ryQqQ,5134
|
19
|
+
pymoto/solvers/dense.py,sha256=9fKPCwNxRKAEk5k1A7fdLrr9ngeVssGlw-sbjWCm4iU,11235
|
20
|
+
pymoto/solvers/iterative.py,sha256=CIxJHjGnCaIjXbtO2NxV60yeDpcCbSD6Bp0xR-7vOf0,12944
|
21
|
+
pymoto/solvers/matrix_checks.py,sha256=bbrfjpTSWWnuQW3xY0_CYE8yrh5gA9K5b1LzHEOFAxI,1663
|
22
|
+
pymoto/solvers/solvers.py,sha256=RwHjZYYlE3oA0U9k7ukla2gOdmq57rSSJQvHqjaM7JU,10626
|
23
|
+
pymoto/solvers/sparse.py,sha256=w8XBlFBIfOpNnfRdLWhLzzqtD8YVxMnDBuhIabFfQQc,16664
|
24
|
+
pyMOTO-1.5.0.dist-info/LICENSE,sha256=ZXMC2Txpzs-dBwz9Me4_1rQCSVl4P1B27MomNi43F30,1072
|
25
|
+
pyMOTO-1.5.0.dist-info/METADATA,sha256=hC38SdgeKEK5NkNDh-gwc4Gz2JOFSymJB-eAMXl7HX4,5006
|
26
|
+
pyMOTO-1.5.0.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
27
|
+
pyMOTO-1.5.0.dist-info/top_level.txt,sha256=EdvAUSmFMaiqhuEZW8jxANMiK-LdPtlmDWL6SfmCdUU,7
|
28
|
+
pyMOTO-1.5.0.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
29
|
+
pyMOTO-1.5.0.dist-info/RECORD,,
|
pymoto/__init__.py
CHANGED
@@ -1,25 +1,27 @@
|
|
1
|
-
__version__ = '1.
|
1
|
+
__version__ = '1.5.0'
|
2
2
|
|
3
3
|
from .common.domain import DomainDefinition
|
4
4
|
|
5
5
|
# Imports from common
|
6
6
|
from .common.dyadcarrier import DyadCarrier
|
7
7
|
from .common.mma import MMA
|
8
|
-
|
9
|
-
|
10
|
-
from .
|
8
|
+
|
9
|
+
# Import solvers
|
10
|
+
from . import solvers
|
11
11
|
|
12
12
|
# Modular inports
|
13
13
|
from .core_objects import Signal, Module, Network, make_signals
|
14
14
|
|
15
15
|
# Import modules
|
16
16
|
from .modules.assembly import AssembleGeneral, AssembleStiffness, AssembleMass, AssemblePoisson
|
17
|
+
from .modules.assembly import ElementOperation, Strain, Stress, ElementAverage, NodalOperation, ThermoMechanical
|
17
18
|
from .modules.autodiff import AutoMod
|
18
19
|
from .modules.complex import MakeComplex, RealPart, ImagPart, ComplexNorm
|
19
20
|
from .modules.filter import FilterConv, Filter, DensityFilter, OverhangFilter
|
20
21
|
from .modules.generic import MathGeneral, EinSum, ConcatSignal
|
21
|
-
from .modules.io import PlotDomain, PlotGraph, PlotIter, WriteToVTI
|
22
|
+
from .modules.io import FigModule, PlotDomain, PlotGraph, PlotIter, WriteToVTI, ScalarToFile
|
22
23
|
from .modules.linalg import Inverse, LinSolve, EigenSolve, SystemOfEquations, StaticCondensation
|
24
|
+
from .modules.aggregation import AggScaling, AggActiveSet, Aggregation, PNorm, SoftMinMax, KSFunction
|
23
25
|
from .modules.scaling import Scaling
|
24
26
|
|
25
27
|
# Further helper routines
|
@@ -28,21 +30,25 @@ from .routines import finite_difference, minimize_oc, minimize_mma
|
|
28
30
|
__all__ = [
|
29
31
|
'Signal', 'Module', 'Network', 'make_signals',
|
30
32
|
'finite_difference', 'minimize_oc', 'minimize_mma',
|
33
|
+
|
31
34
|
# Common
|
32
35
|
'MMA',
|
33
36
|
'DyadCarrier',
|
34
37
|
'DomainDefinition',
|
35
|
-
'
|
36
|
-
|
37
|
-
|
38
|
-
|
38
|
+
'solvers',
|
39
|
+
|
40
|
+
# Helpers
|
41
|
+
"AggScaling", "AggActiveSet",
|
42
|
+
|
39
43
|
# Modules
|
40
44
|
"MathGeneral", "EinSum", "ConcatSignal",
|
41
45
|
"Inverse", "LinSolve", "EigenSolve", "SystemOfEquations", "StaticCondensation",
|
42
46
|
"AssembleGeneral", "AssembleStiffness", "AssembleMass", "AssemblePoisson",
|
47
|
+
"ElementOperation", "Strain", "Stress", "ElementAverage", "NodalOperation", "ThermoMechanical",
|
43
48
|
"FilterConv", "Filter", "DensityFilter", "OverhangFilter",
|
44
|
-
"PlotDomain", "PlotGraph", "PlotIter", "WriteToVTI",
|
49
|
+
"FigModule", "PlotDomain", "PlotGraph", "PlotIter", "WriteToVTI", "ScalarToFile",
|
45
50
|
"MakeComplex", "RealPart", "ImagPart", "ComplexNorm",
|
46
51
|
"AutoMod",
|
47
|
-
"
|
52
|
+
"Aggregation", "PNorm", "SoftMinMax", "KSFunction",
|
53
|
+
"Scaling",
|
48
54
|
]
|
pymoto/common/domain.py
CHANGED
@@ -5,6 +5,31 @@ import struct
|
|
5
5
|
import warnings
|
6
6
|
from typing import Union
|
7
7
|
import numpy as np
|
8
|
+
from matplotlib.patches import PathPatch
|
9
|
+
from matplotlib.path import Path
|
10
|
+
|
11
|
+
|
12
|
+
def plot_deformed_element(ax, x, y, **kwargs):
|
13
|
+
codes, verts = zip(*[
|
14
|
+
(Path.MOVETO, [x[0], y[0]]),
|
15
|
+
(Path.LINETO, [x[1], y[1]]),
|
16
|
+
(Path.LINETO, [x[3], y[3]]),
|
17
|
+
(Path.LINETO, [x[2], y[2]]),
|
18
|
+
(Path.CLOSEPOLY, [x[0], y[0]])])
|
19
|
+
path = Path(verts, codes)
|
20
|
+
patch = PathPatch(path, **kwargs)
|
21
|
+
ax.add_artist(patch)
|
22
|
+
return patch
|
23
|
+
|
24
|
+
|
25
|
+
def get_path(x, y):
|
26
|
+
codes, verts = zip(*[
|
27
|
+
(Path.MOVETO, [x[0], y[0]]),
|
28
|
+
(Path.LINETO, [x[1], y[1]]),
|
29
|
+
(Path.LINETO, [x[3], y[3]]),
|
30
|
+
(Path.LINETO, [x[2], y[2]]),
|
31
|
+
(Path.CLOSEPOLY, [x[0], y[0]])])
|
32
|
+
return Path(verts, codes)
|
8
33
|
|
9
34
|
|
10
35
|
class DomainDefinition:
|
@@ -100,6 +125,14 @@ class DomainDefinition:
|
|
100
125
|
self.conn = np.zeros((self.nel, self.elemnodes), dtype=int)
|
101
126
|
self.conn[el, :] = self.get_elemconnectivity(elx, ely, elz)
|
102
127
|
|
128
|
+
# Helper for element slicing
|
129
|
+
eli, elj, elk = np.meshgrid(np.arange(self.nelx), np.arange(self.nely), np.arange(max(self.nelz, 1)), indexing='ij')
|
130
|
+
self.elements = self.get_elemnumber(eli, elj, elk)
|
131
|
+
|
132
|
+
# Helper for node slicing
|
133
|
+
ndi, ndj, ndk = np.meshgrid(np.arange(self.nelx+1), np.arange(self.nely+1), np.arange(self.nelz+1), indexing='ij')
|
134
|
+
self.nodes = self.get_nodenumber(ndi, ndj, ndk)
|
135
|
+
|
103
136
|
def get_elemnumber(self, eli: Union[int, np.ndarray], elj: Union[int, np.ndarray], elk: Union[int, np.ndarray] = 0):
|
104
137
|
""" Gets the element number(s) for element(s) with given Cartesian indices (i, j, k)
|
105
138
|
|
@@ -126,7 +159,7 @@ class DomainDefinition:
|
|
126
159
|
"""
|
127
160
|
return (nodk * (self.nely + 1) + nodj) * (self.nelx + 1) + nodi
|
128
161
|
|
129
|
-
def get_node_indices(self, nod_idx: Union[int, np.ndarray]):
|
162
|
+
def get_node_indices(self, nod_idx: Union[int, np.ndarray] = None):
|
130
163
|
""" Gets the Cartesian index (i, j, k) for given node number(s)
|
131
164
|
|
132
165
|
Args:
|
@@ -135,16 +168,18 @@ class DomainDefinition:
|
|
135
168
|
Returns:
|
136
169
|
i, j, k for requested node(s); k is only returned in 3D
|
137
170
|
"""
|
171
|
+
if nod_idx is None:
|
172
|
+
nod_idx = np.arange(self.nnodes)
|
138
173
|
nodi = nod_idx % (self.nelx + 1)
|
139
174
|
nodj = (nod_idx // (self.nelx + 1)) % (self.nely + 1)
|
140
175
|
if self.dim == 2:
|
141
|
-
return nodi, nodj
|
176
|
+
return np.stack([nodi, nodj], axis=0)
|
142
177
|
nodk = nod_idx // ((self.nelx + 1)*(self.nely + 1))
|
143
|
-
return nodi, nodj, nodk
|
178
|
+
return np.stack([nodi, nodj, nodk], axis=0)
|
144
179
|
|
145
|
-
def get_node_position(self, nod_idx: Union[int, np.ndarray]):
|
180
|
+
def get_node_position(self, nod_idx: Union[int, np.ndarray] = None):
|
146
181
|
ijk = self.get_node_indices(nod_idx)
|
147
|
-
return
|
182
|
+
return (self.element_size[:self.dim] * ijk.T).T
|
148
183
|
|
149
184
|
def get_elemconnectivity(self, i: Union[int, np.ndarray], j: Union[int, np.ndarray], k: Union[int, np.ndarray] = 0):
|
150
185
|
""" Get the connectivity for element identified with Cartesian indices (i, j, k)
|
@@ -230,6 +265,27 @@ class DomainDefinition:
|
|
230
265
|
dN_dx[i, :] *= np.array([n[i] for n in self.node_numbering]) # Flip +/- signs according to node position
|
231
266
|
return dN_dx
|
232
267
|
|
268
|
+
def plot(self, ax, deformation=None, scaling=None):
|
269
|
+
patches = []
|
270
|
+
for e in range(self.nel):
|
271
|
+
n = self.conn[e]
|
272
|
+
x, y = self.get_node_position(n)
|
273
|
+
u, v = deformation[n * 2], deformation[n * 2 + 1]
|
274
|
+
color = (1 - scaling[e], 1 - scaling[e], 1 - scaling[e]) if scaling is not None else 'grey'
|
275
|
+
patch = plot_deformed_element(ax, x + u, v + y, linewidth=0.1, color=color)
|
276
|
+
patches.append(patch)
|
277
|
+
return patches
|
278
|
+
|
279
|
+
def update_plot(self, patches, deformation=None, scaling=None):
|
280
|
+
for e in range(self.nel):
|
281
|
+
patch = patches[e]
|
282
|
+
n = self.conn[e]
|
283
|
+
x, y = self.get_node_position(n)
|
284
|
+
u, v = deformation[n * 2], deformation[n * 2 + 1]
|
285
|
+
color = (1 - scaling[e], 1 - scaling[e], 1 - scaling[e]) if scaling is not None else 'grey'
|
286
|
+
patch.set_color(color)
|
287
|
+
patch.set_path(self.get_path(x + u, y + v))
|
288
|
+
|
233
289
|
# flake8: noqa: C901
|
234
290
|
def write_to_vti(self, vectors: dict, filename="out.vti", scale=1.0, origin=(0.0, 0.0, 0.0)):
|
235
291
|
""" Write all given vectors to a Paraview (VTI) file
|
pymoto/common/dyadcarrier.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1
|
-
from typing import Union, Iterable
|
1
|
+
from typing import Union, Iterable, List, Tuple
|
2
2
|
import warnings
|
3
3
|
import numpy as np
|
4
4
|
from numpy.typing import NDArray
|
5
|
-
from scipy.sparse import spmatrix
|
5
|
+
from scipy.sparse import spmatrix, coo_matrix
|
6
6
|
from ..utils import _parse_to_list
|
7
7
|
try: # Import fast optimized einsum
|
8
8
|
from opt_einsum import contract as einsum
|
@@ -31,25 +31,26 @@ def isnullslice(x):
|
|
31
31
|
|
32
32
|
|
33
33
|
class DyadCarrier(object):
|
34
|
-
""" Efficient storage for dyadic or rank-N matrix
|
34
|
+
r""" Efficient storage for dyadic or rank-N matrix
|
35
35
|
|
36
36
|
Stores only the vectors instead of creating a full rank-N matrix
|
37
37
|
:math:`\mathbf{A} = \sum_k^N \mathbf{u}_k\otimes\mathbf{v}_k`
|
38
38
|
or in index notation :math:`A_{ij} = \sum_k^N u_{ki} v_{kj}`. This saves a lot of memory for low :math:`N`.
|
39
39
|
|
40
|
-
Args:
|
41
|
-
u
|
42
|
-
v
|
40
|
+
Keyword Args:
|
41
|
+
u: List of vectors
|
42
|
+
v: List of vectors (if ``u`` is given and ``v`` not, a symmetric dyad is assumed with ``v = u``)
|
43
|
+
shape: Shape of the matrix
|
43
44
|
"""
|
44
45
|
|
45
46
|
__array_priority__ = 11.0 # For overriding numpy's ufuncs
|
46
47
|
ndim = 2 # Number of dimensions
|
47
48
|
|
48
|
-
def __init__(self, u: Iterable = None, v: Iterable = None):
|
49
|
+
def __init__(self, u: Iterable = None, v: Iterable = None, shape: Tuple[int, int] = (-1, -1)):
|
49
50
|
self.u = []
|
50
51
|
self.v = []
|
51
|
-
self.ulen =
|
52
|
-
self.vlen =
|
52
|
+
self.ulen = shape[0]
|
53
|
+
self.vlen = shape[1]
|
53
54
|
self.dtype = np.dtype('float64') # Standard data type
|
54
55
|
self.add_dyad(u, v)
|
55
56
|
|
@@ -66,8 +67,14 @@ class DyadCarrier(object):
|
|
66
67
|
else:
|
67
68
|
return self.ulen * self.vlen
|
68
69
|
|
70
|
+
@property
|
71
|
+
def n_dyads(self):
|
72
|
+
""" Number of dyads stored """
|
73
|
+
assert len(self.u) == len(self.v)
|
74
|
+
return len(self.u)
|
75
|
+
|
69
76
|
def add_dyad(self, u: Iterable, v: Iterable = None, fac: float = None):
|
70
|
-
""" Adds a list of vectors to the dyad carrier
|
77
|
+
r""" Adds a list of vectors to the dyad carrier
|
71
78
|
|
72
79
|
Checks for conforming sizes of `u` and `v`. The data inside the vectors are copied.
|
73
80
|
|
@@ -92,9 +99,7 @@ class DyadCarrier(object):
|
|
92
99
|
if len(ulist) != len(vlist):
|
93
100
|
raise TypeError("Number of vectors in u ({}) and v({}) should be equal".format(len(ulist), len(vlist)))
|
94
101
|
|
95
|
-
|
96
|
-
|
97
|
-
for i, ui, vi in zip(range(n), ulist, vlist):
|
102
|
+
for i, (ui, vi) in enumerate(zip(ulist, vlist)):
|
98
103
|
# Make sure they are numpy arrays
|
99
104
|
if not isinstance(ui, np.ndarray):
|
100
105
|
ui = np.array(ui)
|
@@ -141,14 +146,22 @@ class DyadCarrier(object):
|
|
141
146
|
|
142
147
|
def __getitem__(self, subscript):
|
143
148
|
assert len(subscript) == self.ndim, "Invalid number of slices, must be 2"
|
144
|
-
|
145
|
-
|
149
|
+
if self.shape[0] < 0 and self.shape[1] < 0:
|
150
|
+
return DyadCarrier()
|
151
|
+
|
152
|
+
usample = np.zeros(self.shape[0])[subscript[0]]
|
153
|
+
vsample = np.zeros(self.shape[1])[subscript[1]]
|
146
154
|
|
147
|
-
is_uni_slice = isscalarlike(
|
155
|
+
is_uni_slice = isscalarlike(usample) or isscalarlike(vsample)
|
148
156
|
is_np_slice = isinstance(subscript[0], np.ndarray) and isinstance(subscript[1], np.ndarray)
|
157
|
+
|
149
158
|
if is_np_slice and subscript[0].shape != subscript[1].shape:
|
150
159
|
raise IndexError(f"shape mismatch: indexing arrays could not be broadcast together "
|
151
160
|
f"with shapes {subscript[0].shape} {subscript[1].shape}")
|
161
|
+
|
162
|
+
usub = [ui[subscript[0]] for ui in self.u]
|
163
|
+
vsub = [vi[subscript[1]] for vi in self.v]
|
164
|
+
|
152
165
|
if is_uni_slice or is_np_slice:
|
153
166
|
res = 0
|
154
167
|
for (ui, vi) in zip(usub, vsub):
|
@@ -156,7 +169,7 @@ class DyadCarrier(object):
|
|
156
169
|
|
157
170
|
return res
|
158
171
|
else:
|
159
|
-
return DyadCarrier(usub, vsub)
|
172
|
+
return DyadCarrier(usub, vsub, shape=(np.size(usample), np.size(vsample)))
|
160
173
|
|
161
174
|
def __setitem__(self, subscript, value):
|
162
175
|
assert len(subscript) == self.ndim, "Invalid number of slices, must be 2"
|
@@ -171,10 +184,10 @@ class DyadCarrier(object):
|
|
171
184
|
vi[subscript[1]] = value
|
172
185
|
|
173
186
|
def __pos__(self):
|
174
|
-
return
|
187
|
+
return self.copy()
|
175
188
|
|
176
189
|
def __neg__(self):
|
177
|
-
return DyadCarrier([-uu for uu in self.u], self.v)
|
190
|
+
return DyadCarrier([-uu for uu in self.u], self.v, shape=self.shape)
|
178
191
|
|
179
192
|
def __iadd__(self, other):
|
180
193
|
self.add_dyad(other.u, other.v)
|
@@ -189,7 +202,7 @@ class DyadCarrier(object):
|
|
189
202
|
elif isdyad(other):
|
190
203
|
if other.shape != self.shape and (self.size > 0 and other.size > 0):
|
191
204
|
raise ValueError(f"Inconsistent shapes {self.shape} and {other.shape}")
|
192
|
-
return
|
205
|
+
return self.copy().__iadd__(other)
|
193
206
|
elif isdense(other):
|
194
207
|
other = np.broadcast_to(other, self.shape)
|
195
208
|
return other + self.todense()
|
@@ -219,28 +232,44 @@ class DyadCarrier(object):
|
|
219
232
|
return NotImplemented
|
220
233
|
|
221
234
|
def __rmul__(self, other): # other * self
|
222
|
-
return DyadCarrier([other*ui for ui in self.u], self.v)
|
235
|
+
return DyadCarrier([other*ui for ui in self.u], self.v, shape=self.shape)
|
223
236
|
|
224
237
|
def __mul__(self, other): # self * other
|
225
|
-
return DyadCarrier(self.u, [vi*other for vi in self.v])
|
238
|
+
return DyadCarrier(self.u, [vi*other for vi in self.v], shape=self.shape)
|
226
239
|
|
227
240
|
def copy(self):
|
228
241
|
""" Returns a deep copy of the DyadCarrier """
|
229
|
-
return DyadCarrier(self.u, self.v)
|
242
|
+
return DyadCarrier(self.u, self.v, shape=self.shape)
|
230
243
|
|
231
244
|
def conj(self):
|
232
245
|
""" Returns (a deep copied) complex conjugate of the DyadCarrier """
|
233
|
-
return DyadCarrier([u.conj() for u in self.u], [v.conj() for v in self.v])
|
246
|
+
return DyadCarrier([u.conj() for u in self.u], [v.conj() for v in self.v], shape=self.shape)
|
234
247
|
|
235
248
|
@property
|
236
249
|
def real(self):
|
237
250
|
""" Returns a deep copy of the real part of the DyadCarrier """
|
238
|
-
return DyadCarrier([*[u.real for u in self.u], *[-u.imag for u in self.u]], [*[v.real for v in self.v], *[v.imag for v in self.v]])
|
251
|
+
return DyadCarrier([*[u.real for u in self.u], *[-u.imag for u in self.u]], [*[v.real for v in self.v], *[v.imag for v in self.v]], shape=self.shape)
|
239
252
|
|
240
253
|
@property
|
241
254
|
def imag(self):
|
242
255
|
""" Returns a deep copy of the imaginary part of the DyadCarrier """
|
243
|
-
return DyadCarrier([*[u.real for u in self.u], *[u.imag for u in self.u]], [*[v.imag for v in self.v], *[v.real for v in self.v]])
|
256
|
+
return DyadCarrier([*[u.real for u in self.u], *[u.imag for u in self.u]], [*[v.imag for v in self.v], *[v.real for v in self.v]], shape=self.shape)
|
257
|
+
|
258
|
+
def min(self):
|
259
|
+
minval = 0.0
|
260
|
+
for u, v in zip(self.u, self.v):
|
261
|
+
minval += u.min() * v.min()
|
262
|
+
if len(self.u) >= 2:
|
263
|
+
warnings.warn("The minimum is an approximation")
|
264
|
+
return minval
|
265
|
+
|
266
|
+
def max(self):
|
267
|
+
maxval = 0.0
|
268
|
+
for u, v in zip(self.u, self.v):
|
269
|
+
maxval += u.max() * v.max()
|
270
|
+
if len(self.u) >= 2:
|
271
|
+
warnings.warn("The maximum is an approximation")
|
272
|
+
return maxval
|
244
273
|
|
245
274
|
# flake8: noqa: C901
|
246
275
|
def contract(self, mat: Union[NDArray, spmatrix] = None, rows: NDArray[int] = None, cols: NDArray[int] = None):
|
@@ -373,6 +402,31 @@ class DyadCarrier(object):
|
|
373
402
|
|
374
403
|
return val
|
375
404
|
|
405
|
+
def contract_multi(self, mats: List[spmatrix], dtype=None):
|
406
|
+
""" Faster version of contraction for a list of sparse matrices """
|
407
|
+
if dtype is None:
|
408
|
+
dtype = np.result_type(self.dtype, mats[0].dtype)
|
409
|
+
val = np.zeros(len(mats), dtype=dtype)
|
410
|
+
|
411
|
+
if len(self.u) == 0 or len(self.v) == 0:
|
412
|
+
return val
|
413
|
+
U = np.array(self.u).T
|
414
|
+
V = np.array(self.v).T
|
415
|
+
|
416
|
+
for i, m in enumerate(mats):
|
417
|
+
if m is None:
|
418
|
+
vali = 0.0
|
419
|
+
else:
|
420
|
+
try:
|
421
|
+
if not isinstance(m, coo_matrix):
|
422
|
+
warnings.warn("Inefficiency: Matrix must be converted to coo_matrix for contraction")
|
423
|
+
mat_coo = m.tocoo()
|
424
|
+
vali = np.einsum('ij,i,ij->', U[mat_coo.row, :], mat_coo.data, V[mat_coo.col, :])
|
425
|
+
except AttributeError:
|
426
|
+
vali = self.contract(m)
|
427
|
+
val[i] = vali
|
428
|
+
return val
|
429
|
+
|
376
430
|
def todense(self):
|
377
431
|
""" Returns a full (dense) matrix from the DyadCarrier matrix """
|
378
432
|
warning_size = 100e+6 # Bytes
|
@@ -387,6 +441,10 @@ class DyadCarrier(object):
|
|
387
441
|
|
388
442
|
return val
|
389
443
|
|
444
|
+
def toarray(self):
|
445
|
+
""" Convert to array, same as todense(). To be consistent with scipy.sparse """
|
446
|
+
return self.todense()
|
447
|
+
|
390
448
|
def iscomplex(self):
|
391
449
|
""" Check if the DyadCarrier is of complex type """
|
392
450
|
return np.iscomplexobj(np.array([], dtype=self.dtype))
|
@@ -417,7 +475,7 @@ class DyadCarrier(object):
|
|
417
475
|
|
418
476
|
def transpose(self):
|
419
477
|
""" Returns a deep copy of the transposed DyadCarrier matrix"""
|
420
|
-
return DyadCarrier(self.v, self.u)
|
478
|
+
return DyadCarrier(self.v, self.u, shape=(self.shape[1], self.shape[0]))
|
421
479
|
|
422
480
|
def dot(self, other):
|
423
481
|
""" Inner product """
|
@@ -445,9 +503,9 @@ class DyadCarrier(object):
|
|
445
503
|
if other.ndim == 1:
|
446
504
|
return self.__dot__(other)
|
447
505
|
|
448
|
-
return DyadCarrier(self.u, [vi@other for vi in self.v])
|
506
|
+
return DyadCarrier(self.u, [vi@other for vi in self.v], shape=(self.shape[0], other.shape[1]))
|
449
507
|
|
450
508
|
def __rmatmul__(self, other): # other @ self
|
451
509
|
if other.ndim == 1:
|
452
510
|
return self.__rdot__(other)
|
453
|
-
return DyadCarrier([other@ui for ui in self.u], self.v)
|
511
|
+
return DyadCarrier([other@ui for ui in self.u], self.v, shape=(other.shape[0], self.shape[1]))
|