flopscope 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +1 -0
- benchmarks/__main__.py +6 -0
- benchmarks/_baseline.py +171 -0
- benchmarks/_bitwise.py +231 -0
- benchmarks/_complex.py +176 -0
- benchmarks/_contractions.py +291 -0
- benchmarks/_fft.py +198 -0
- benchmarks/_impl_urls.py +139 -0
- benchmarks/_linalg.py +197 -0
- benchmarks/_linalg_delegates.py +407 -0
- benchmarks/_metadata.py +141 -0
- benchmarks/_misc.py +653 -0
- benchmarks/_perf.py +321 -0
- benchmarks/_perm_group_calibration.py +175 -0
- benchmarks/_pointwise.py +372 -0
- benchmarks/_polynomial.py +193 -0
- benchmarks/_random.py +209 -0
- benchmarks/_reductions.py +136 -0
- benchmarks/_sorting.py +289 -0
- benchmarks/_stats.py +137 -0
- benchmarks/_window.py +92 -0
- benchmarks/accumulation/__init__.py +0 -0
- benchmarks/accumulation/bench_cost_compute.py +138 -0
- benchmarks/dashboard.py +312 -0
- benchmarks/runner.py +636 -0
- flopscope/__init__.py +273 -0
- flopscope/_accumulation/__init__.py +13 -0
- flopscope/_accumulation/_bipartite.py +121 -0
- flopscope/_accumulation/_burnside.py +51 -0
- flopscope/_accumulation/_cache.py +146 -0
- flopscope/_accumulation/_components.py +153 -0
- flopscope/_accumulation/_cost.py +1414 -0
- flopscope/_accumulation/_cost_descriptions.py +63 -0
- flopscope/_accumulation/_detection.py +318 -0
- flopscope/_accumulation/_ladder.py +191 -0
- flopscope/_accumulation/_output_orbit.py +104 -0
- flopscope/_accumulation/_partition.py +290 -0
- flopscope/_accumulation/_path_info.py +211 -0
- flopscope/_accumulation/_public.py +169 -0
- flopscope/_accumulation/_reduction.py +310 -0
- flopscope/_accumulation/_regimes.py +303 -0
- flopscope/_accumulation/_shape.py +33 -0
- flopscope/_accumulation/_wreath.py +209 -0
- flopscope/_budget.py +1027 -0
- flopscope/_config.py +118 -0
- flopscope/_counting_ops.py +451 -0
- flopscope/_display.py +478 -0
- flopscope/_docstrings.py +59 -0
- flopscope/_dtypes.py +20 -0
- flopscope/_einsum.py +717 -0
- flopscope/_errstate.py +25 -0
- flopscope/_flops.py +282 -0
- flopscope/_free_ops.py +2654 -0
- flopscope/_ndarray.py +1126 -0
- flopscope/_opt_einsum/LICENSE +21 -0
- flopscope/_opt_einsum/NOTICE +59 -0
- flopscope/_opt_einsum/__init__.py +209 -0
- flopscope/_opt_einsum/_contract.py +1478 -0
- flopscope/_opt_einsum/_helpers.py +164 -0
- flopscope/_opt_einsum/_hsluv.py +273 -0
- flopscope/_opt_einsum/_path_random.py +462 -0
- flopscope/_opt_einsum/_paths.py +1653 -0
- flopscope/_opt_einsum/_subgraph_symmetry.py +544 -0
- flopscope/_opt_einsum/_symmetry.py +140 -0
- flopscope/_opt_einsum/_typing.py +37 -0
- flopscope/_perm_group.py +717 -0
- flopscope/_pointwise.py +2522 -0
- flopscope/_polynomial.py +278 -0
- flopscope/_registry.py +3216 -0
- flopscope/_sorting_ops.py +571 -0
- flopscope/_symmetric.py +812 -0
- flopscope/_symmetry_transport.py +510 -0
- flopscope/_symmetry_utils.py +669 -0
- flopscope/_type_info.py +12 -0
- flopscope/_unwrap.py +70 -0
- flopscope/_validation.py +83 -0
- flopscope/_version_check.py +46 -0
- flopscope/_weights.py +195 -0
- flopscope/_window.py +177 -0
- flopscope/accounting.py +565 -0
- flopscope/data/default_weights.json +462 -0
- flopscope/data/weights.csv +509 -0
- flopscope/errors.py +197 -0
- flopscope/numpy/__init__.py +878 -0
- flopscope/numpy/fft/__init__.py +55 -0
- flopscope/numpy/fft/_free.py +51 -0
- flopscope/numpy/fft/_transforms.py +695 -0
- flopscope/numpy/linalg/__init__.py +105 -0
- flopscope/numpy/linalg/_aliases.py +126 -0
- flopscope/numpy/linalg/_compound.py +161 -0
- flopscope/numpy/linalg/_decompositions.py +353 -0
- flopscope/numpy/linalg/_properties.py +533 -0
- flopscope/numpy/linalg/_solvers.py +444 -0
- flopscope/numpy/linalg/_svd.py +122 -0
- flopscope/numpy/random/__init__.py +684 -0
- flopscope/numpy/random/_cost_formulas.py +115 -0
- flopscope/numpy/random/_counted_classes.py +241 -0
- flopscope/numpy/testing/__init__.py +13 -0
- flopscope/numpy/typing/__init__.py +30 -0
- flopscope/py.typed +0 -0
- flopscope/stats/__init__.py +84 -0
- flopscope/stats/_base.py +77 -0
- flopscope/stats/_cauchy.py +146 -0
- flopscope/stats/_erf.py +190 -0
- flopscope/stats/_expon.py +146 -0
- flopscope/stats/_laplace.py +150 -0
- flopscope/stats/_logistic.py +148 -0
- flopscope/stats/_lognorm.py +160 -0
- flopscope/stats/_ndtri.py +133 -0
- flopscope/stats/_norm.py +149 -0
- flopscope/stats/_truncnorm.py +186 -0
- flopscope/stats/_uniform.py +141 -0
- flopscope-0.2.0.dist-info/METADATA +23 -0
- flopscope-0.2.0.dist-info/RECORD +115 -0
- flopscope-0.2.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,669 @@
|
|
|
1
|
+
"""Helper primitives for exact tensor symmetry groups."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import functools
|
|
6
|
+
import math
|
|
7
|
+
from collections import OrderedDict
|
|
8
|
+
from collections.abc import Iterable, Mapping, Sequence
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from flopscope._perm_group import SymmetryGroup
|
|
14
|
+
from flopscope.errors import SymmetryError
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _normalize_axis_tuple(
|
|
18
|
+
axes: Iterable[Any],
|
|
19
|
+
*,
|
|
20
|
+
ndim: int | None = None,
|
|
21
|
+
what: str = "axes",
|
|
22
|
+
) -> tuple[int, ...]:
|
|
23
|
+
norm_axes = tuple(axes)
|
|
24
|
+
if not norm_axes:
|
|
25
|
+
raise ValueError(f"{what} must be non-empty")
|
|
26
|
+
if not all(isinstance(axis, int) for axis in norm_axes):
|
|
27
|
+
raise TypeError(f"{what} must contain only integers")
|
|
28
|
+
if len(set(norm_axes)) != len(norm_axes):
|
|
29
|
+
raise ValueError(f"{what} contain duplicate entries")
|
|
30
|
+
if ndim is not None and any(axis < 0 or axis >= ndim for axis in norm_axes):
|
|
31
|
+
raise ValueError(f"{what} are out of range for ndim={ndim}")
|
|
32
|
+
return norm_axes
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def normalize_symmetry_input(obj, *, ndim: int | None = None):
    """Turn a user-facing ``symmetry=`` value into one ``SymmetryGroup``.

    Accepted forms:

    * ``None`` -> ``None`` (no symmetry);
    * a ``SymmetryGroup`` -> validated against ``ndim`` and returned;
    * a flat non-empty tuple/list of ints -> ``SymmetryGroup.symmetric``;
    * a non-empty tuple/list of axis tuples/lists -> ``SymmetryGroup.young``
      with pairwise-disjoint blocks.

    Raises:
        TypeError: For a list of groups, or any other unsupported value.
        ValueError: For invalid axes or overlapping partition blocks.
    """
    if obj is None:
        return None
    if isinstance(obj, SymmetryGroup):
        return validate_symmetry_group(obj, ndim=ndim)

    unsupported = (
        "symmetry must be a SymmetryGroup or an approved axis/partition shorthand"
    )
    if not isinstance(obj, (tuple, list)) or not obj:
        raise TypeError(unsupported)
    if isinstance(obj, list) and all(isinstance(g, SymmetryGroup) for g in obj):
        raise TypeError("symmetry must be a single SymmetryGroup, not a list of groups")

    head = obj[0]
    if isinstance(head, int):
        flat_axes = _normalize_axis_tuple(obj, ndim=ndim, what="symmetry axes")
        return SymmetryGroup.symmetric(axes=flat_axes)
    if not isinstance(head, (tuple, list)):
        raise TypeError(unsupported)

    partition = []
    claimed: set[int] = set()
    for raw_block in obj:
        part = _normalize_axis_tuple(
            raw_block, ndim=ndim, what="symmetry partition block"
        )
        clash = claimed.intersection(part)
        if clash:
            raise ValueError(
                "symmetry partition blocks overlap on axes "
                f"{tuple(sorted(clash))}"
            )
        claimed.update(part)
        partition.append(part)
    return SymmetryGroup.young(blocks=tuple(partition))
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def validate_symmetry_group(
    group: SymmetryGroup,
    *,
    ndim: int | None = None,
    shape: tuple[int, ...] | None = None,
) -> SymmetryGroup:
    """Check that ``group`` can act on a tensor of the given rank/shape.

    Args:
        group: The group to validate.
        ndim: Tensor rank. Defaults to ``len(shape)`` when only ``shape``
            is given.
        shape: When given, every orbit of ``group`` must cover axes of equal
            size.

    Returns:
        ``group`` itself, unchanged.

    Raises:
        TypeError: If ``group`` is not a ``SymmetryGroup``.
        ValueError: If the group's degree exceeds ``ndim`` (axis-less groups),
            or its ``axes`` are invalid or not already a normalized tuple.
        SymmetryError: If an orbit spans axes whose sizes in ``shape`` differ.
    """
    if not isinstance(group, SymmetryGroup):
        raise TypeError("symmetry must be a SymmetryGroup")
    # Derive the rank from ``shape`` so the range checks below also protect
    # the ``shape[...]`` lookups from IndexError.
    if ndim is None and shape is not None:
        ndim = len(shape)
    axes = group.axes
    if axes is None:
        if ndim is not None and group.degree > ndim:
            raise ValueError(
                f"SymmetryGroup degree {group.degree} exceeds tensor rank {ndim}"
            )
        # An axis-less group acts on tensor axes 0..degree-1 (the same
        # convention embed_group and the element counter use), so the shape
        # consistency check below must still apply to it.
        tensor_axes: tuple[int, ...] = tuple(range(group.degree))
    else:
        norm_axes = _normalize_axis_tuple(axes, ndim=ndim, what="SymmetryGroup axes")
        if norm_axes != axes:
            raise ValueError("SymmetryGroup axes must already be normalized")
        tensor_axes = axes
    if shape is not None:
        for orbit in group.orbits():
            orbit_axes = tuple(tensor_axes[i] for i in orbit)
            if len({shape[axis] for axis in orbit_axes}) > 1:
                raise SymmetryError(axes=orbit_axes, max_deviation=float("inf"))
    return group
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def unique_elements_for_shape(
    group: SymmetryGroup | None,
    shape: tuple[int, ...],
) -> int:
    """Count the distinct values a tensor of ``shape`` can hold under ``group``.

    With no group every element is independent, so the answer is the plain
    product of the dimensions. Otherwise the (memoized) symmetry-aware
    counter is consulted with ``shape`` coerced to a hashable tuple.
    """
    if group is not None:
        return _unique_elements_for_shape_cached(group, tuple(shape))
    return math.prod(shape)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@functools.cache
|
|
113
|
+
def _unique_elements_for_shape_cached(
|
|
114
|
+
group: SymmetryGroup,
|
|
115
|
+
shape: tuple[int, ...],
|
|
116
|
+
) -> int:
|
|
117
|
+
validate_symmetry_group(group, ndim=len(shape), shape=shape)
|
|
118
|
+
axes = group.axes
|
|
119
|
+
if axes is None:
|
|
120
|
+
axes = tuple(range(group.degree))
|
|
121
|
+
size_dict = {local_idx: shape[axis] for local_idx, axis in enumerate(axes)}
|
|
122
|
+
result = group.burnside_unique_count(size_dict)
|
|
123
|
+
accounted = set(axes)
|
|
124
|
+
for axis, size in enumerate(shape):
|
|
125
|
+
if axis not in accounted:
|
|
126
|
+
result *= size
|
|
127
|
+
return result
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _build_from_kind(kind: tuple) -> SymmetryGroup | None:
|
|
131
|
+
"""Construct an interned SymmetryGroup from a ``_known_kind`` tag.
|
|
132
|
+
|
|
133
|
+
Routes through the public factory matching the kind name. Returns
|
|
134
|
+
``None`` for trivial (identity) kinds, since callers expect ``None``
|
|
135
|
+
to represent "no non-trivial symmetry."
|
|
136
|
+
"""
|
|
137
|
+
name = kind[0]
|
|
138
|
+
if name == "identity":
|
|
139
|
+
return None
|
|
140
|
+
if name == "symmetric":
|
|
141
|
+
return SymmetryGroup.symmetric(axes=kind[1])
|
|
142
|
+
if name == "cyclic":
|
|
143
|
+
return SymmetryGroup.cyclic(axes=kind[1])
|
|
144
|
+
if name == "dihedral":
|
|
145
|
+
return SymmetryGroup.dihedral(axes=kind[1])
|
|
146
|
+
if name == "direct_product":
|
|
147
|
+
children = [_build_from_kind(child) for child in kind[1]]
|
|
148
|
+
non_trivial = [c for c in children if c is not None]
|
|
149
|
+
if not non_trivial:
|
|
150
|
+
return None
|
|
151
|
+
if len(non_trivial) == 1:
|
|
152
|
+
return non_trivial[0]
|
|
153
|
+
return SymmetryGroup.direct_product(*non_trivial)
|
|
154
|
+
raise AssertionError(f"unknown kind {kind!r}")
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def embed_group(group: SymmetryGroup | None, ndim: int) -> SymmetryGroup | None:
    """Lift ``group`` so that it acts on all ``ndim`` tensor axes.

    Each generator is rewritten as a permutation of ``range(ndim)`` that
    moves the group's support axes as before and fixes every other axis.
    A group already supported on exactly ``range(ndim)`` is returned as-is;
    a group with no generators becomes the identity on ``ndim`` axes.
    """
    if group is None:
        return None
    validate_symmetry_group(group, ndim=ndim)
    support = group.axes
    if support is None:
        support = tuple(range(group.degree))
    full_axes = tuple(range(ndim))
    if support == full_axes and group.degree == ndim:
        return group

    lifted = []
    for gen in group.generators:
        image = list(full_axes)
        for pos, tensor_axis in enumerate(support):
            # Translate the local image index back into a tensor axis.
            image[tensor_axis] = support[gen.array_form[pos]]
        lifted.append(image)
    return SymmetryGroup.from_generators(lifted or [list(full_axes)], axes=full_axes)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def restrict_group_to_axes(
    group: SymmetryGroup | None,
    axes: Iterable[int],
) -> SymmetryGroup | None:
    """Project ``group`` onto an ordered subset of its tensor axes.

    Takes the setwise stabilizer of the requested axes first, so that
    :meth:`SymmetryGroup.restrict` only sees permutations that map that set
    onto itself, then restricts. Requesting exactly ``group.axes`` returns the
    original (interned) group, which keeps its ``_known_kind``; any strict
    subset yields a group without provenance — the "a sub-symmetric group is
    still symmetric" shortcut is handled in ``reduce_group`` instead.

    Returns ``None`` when fewer than two axes are requested or the result is
    trivial.

    Raises:
        ValueError: If the requested axes are invalid or not all in the
            group's support.
    """
    if group is None:
        return None
    validate_symmetry_group(group)
    support = group.axes
    if support is None:
        support = tuple(range(group.degree))
    requested = _normalize_axis_tuple(axes, what="restricted axes")
    if requested == support:
        # No-op: hand back the interned original so its kind is preserved.
        return group
    if any(ax not in support for ax in requested):
        raise ValueError(
            f"restricted axes {requested} are not a subset of {support}"
        )
    positions = tuple(support.index(ax) for ax in requested)
    if len(positions) < 2:
        return None
    projected = group.setwise_stabilizer(set(positions)).restrict(positions)
    return projected if projected.order() > 1 else None
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def _remap_kind(kind: tuple | None, axis_map: Mapping[Any, Any]) -> tuple | None:
|
|
221
|
+
"""Apply ``axis_map`` to the axes inside a ``_known_kind`` tag.
|
|
222
|
+
|
|
223
|
+
Returns ``None`` if any leaf axis is missing from the map (caller's
|
|
224
|
+
responsibility to ensure full coverage).
|
|
225
|
+
"""
|
|
226
|
+
if kind is None:
|
|
227
|
+
return None
|
|
228
|
+
name = kind[0]
|
|
229
|
+
if name in ("identity", "symmetric", "cyclic", "dihedral"):
|
|
230
|
+
try:
|
|
231
|
+
return (name, tuple(axis_map[a] for a in kind[1]))
|
|
232
|
+
except KeyError:
|
|
233
|
+
return None
|
|
234
|
+
if name == "direct_product":
|
|
235
|
+
children = tuple(_remap_kind(child, axis_map) for child in kind[1])
|
|
236
|
+
if any(child is None for child in children):
|
|
237
|
+
return None
|
|
238
|
+
return ("direct_product", tuple(sorted(children, key=repr)))
|
|
239
|
+
return None
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def _reduced_kind(
|
|
243
|
+
kind: tuple | None,
|
|
244
|
+
*,
|
|
245
|
+
reduced_axes: set[int],
|
|
246
|
+
axis_map: Mapping[Any, Any],
|
|
247
|
+
) -> tuple | None:
|
|
248
|
+
"""Compute the reduced kind tag for a known-kind group.
|
|
249
|
+
|
|
250
|
+
``reduced_axes`` is the set of tensor axes being reduced over (in the
|
|
251
|
+
parent group's axis space). ``axis_map`` is the surviving-axes mapping
|
|
252
|
+
from old tensor axes to new tensor positions (post-reduction layout).
|
|
253
|
+
|
|
254
|
+
Returns ``None`` if the result is trivial or can't be expressed in
|
|
255
|
+
closed form.
|
|
256
|
+
"""
|
|
257
|
+
if kind is None:
|
|
258
|
+
return None
|
|
259
|
+
name = kind[0]
|
|
260
|
+
if name == "identity":
|
|
261
|
+
kept = tuple(axis_map[a] for a in kind[1] if a not in reduced_axes)
|
|
262
|
+
if not kept:
|
|
263
|
+
return None
|
|
264
|
+
return ("identity", kept)
|
|
265
|
+
if name == "symmetric":
|
|
266
|
+
kept = tuple(axis_map[a] for a in kind[1] if a not in reduced_axes)
|
|
267
|
+
if len(kept) < 2:
|
|
268
|
+
return None
|
|
269
|
+
return ("symmetric", kept)
|
|
270
|
+
if name == "cyclic":
|
|
271
|
+
# Cyclic only survives if NONE of its axes are reduced (the cycle
|
|
272
|
+
# would otherwise no longer be a closed orbit on the kept axes).
|
|
273
|
+
if any(a in reduced_axes for a in kind[1]):
|
|
274
|
+
return None
|
|
275
|
+
kept = tuple(axis_map[a] for a in kind[1])
|
|
276
|
+
return ("cyclic", kept)
|
|
277
|
+
if name == "dihedral":
|
|
278
|
+
if any(a in reduced_axes for a in kind[1]):
|
|
279
|
+
return None
|
|
280
|
+
kept = tuple(axis_map[a] for a in kind[1])
|
|
281
|
+
return ("dihedral", kept)
|
|
282
|
+
if name == "direct_product":
|
|
283
|
+
new_children = []
|
|
284
|
+
for child in kind[1]:
|
|
285
|
+
new_child = _reduced_kind(
|
|
286
|
+
child, reduced_axes=reduced_axes, axis_map=axis_map
|
|
287
|
+
)
|
|
288
|
+
if new_child is not None:
|
|
289
|
+
new_children.append(new_child)
|
|
290
|
+
if not new_children:
|
|
291
|
+
return None
|
|
292
|
+
if len(new_children) == 1:
|
|
293
|
+
return new_children[0]
|
|
294
|
+
return ("direct_product", tuple(sorted(new_children, key=repr)))
|
|
295
|
+
return None
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def remap_group_axes(
    group: SymmetryGroup | None,
    axis_map: Mapping[int, int],
) -> SymmetryGroup | None:
    """Relabel the tensor axes ``group`` acts on, keeping its local action.

    The generators are reused verbatim; only the ``axes`` labels change.
    When the group's ``_known_kind`` tag can be remapped too, it is attached
    to the new group, which is then interned.

    Raises:
        ValueError: If an axis of the group has no entry in ``axis_map``, or
            the remapped axes are invalid (e.g. collide).
    """
    if group is None:
        return None
    validate_symmetry_group(group)
    support = group.axes
    if support is None:
        support = tuple(range(group.degree))
    for ax in support:
        if ax not in axis_map:
            raise ValueError(f"missing remap for axis {ax}")
    target_axes = tuple(axis_map[ax] for ax in support)
    # Only validates (raises on duplicates / non-int targets).
    _normalize_axis_tuple(target_axes, what="remapped axes")
    relabelled = SymmetryGroup.from_generators(
        group.generator_literals,  # pyright: ignore[reportArgumentType]
        axes=target_axes,  # type: ignore[arg-type]
    )
    renamed_kind = _remap_kind(group._known_kind, axis_map)
    if renamed_kind is None:
        return relabelled
    # The kind must be attached before interning.
    relabelled._known_kind = renamed_kind
    return SymmetryGroup._intern(relabelled)
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def remap_group_for_expand_dims(
    group: SymmetryGroup | None,
    *,
    ndim: int,
    axis,
) -> SymmetryGroup | None:
    """Remap tensor-axis support after ``numpy.expand_dims`` axis insertion.

    ``axis`` is interpreted exactly as ``numpy.expand_dims`` would interpret
    it (negative values, tuples, and its errors for invalid input). Original
    axes are relabelled to their post-insertion positions, and two or more
    inserted size-1 axes contribute a symmetric factor among themselves.
    """
    if group is not None:
        validate_symmetry_group(group, ndim=ndim)
    # Let NumPy resolve ``axis``: give each original axis a distinct size
    # >= 2 so it can be located by size afterwards, while inserted axes are
    # the only size-1 entries. The probe is a zero-stride broadcast view of a
    # single element; materializing it with ``np.empty(probe_shape)`` would
    # allocate (ndim + 1)! elements (~320 MB at ndim=10).
    probe_shape = tuple(range(2, 2 + ndim))
    probe = np.broadcast_to(np.empty(()), probe_shape)
    expanded_shape = np.expand_dims(probe, axis=axis).shape
    remapped = None
    if group is not None:
        axis_map = {
            old_axis: expanded_shape.index(size)
            for old_axis, size in enumerate(probe_shape)
        }
        remapped = remap_group_axes(group, axis_map)
    inserted_axes = tuple(
        axis_idx for axis_idx, size in enumerate(expanded_shape) if size == 1
    )
    inserted = inserted_axes_symmetry(inserted_axes)
    return direct_product_groups(remapped, inserted)
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def inserted_axes_symmetry(
    inserted_positions: Sequence[int],
) -> SymmetryGroup | None:
    """Symmetry among freshly inserted size-1 axes.

    Axis-inserting operations (``expand_dims``, indexing with
    ``None``/``np.newaxis``) create axes that are all interchangeable, so two
    or more of them form ``SymmetryGroup.symmetric`` over their output
    positions. Zero or one position gives ``None`` (nothing non-trivial).
    """
    if len(inserted_positions) >= 2:
        return SymmetryGroup.symmetric(axes=tuple(inserted_positions))
    return None
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def intersect_groups(
    a: SymmetryGroup | None,
    b: SymmetryGroup | None,
    *,
    ndim: int,
) -> SymmetryGroup | None:
    """Return the symmetry shared by ``a`` and ``b`` (``None`` if trivial).

    Three strategies, cheapest first:

    1. identical non-``None`` ``_known_kind`` tags: ``a`` is returned as-is,
       keeping provenance without enumerating elements;
    2. identical ``axes``: intersect the element sets directly;
    3. otherwise embed both into ``ndim`` axes and intersect there.
    """
    if a is None or b is None:
        return None

    def _meet(left, right, support):
        # Deterministic element order so the built group is reproducible.
        shared = sorted(
            set(left.elements()) & set(right.elements()),
            key=lambda perm: tuple(perm.array_form),
        )
        if len(shared) <= 1:
            return None
        return SymmetryGroup(*shared, axes=support)

    tag = a._known_kind
    if tag is not None and tag == b._known_kind:
        # A group intersected with itself is itself; trivial means None.
        return a if a.order() > 1 else None
    if a.axes is not None and a.axes == b.axes:
        return _meet(a, b, a.axes)

    lifted_a = embed_group(a, ndim)
    lifted_b = embed_group(b, ndim)
    assert lifted_a is not None
    assert lifted_b is not None
    return _meet(lifted_a, lifted_b, tuple(range(ndim)))
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def direct_product_groups(*groups: SymmetryGroup | None) -> SymmetryGroup | None:
    """Combine groups acting on disjoint axes into one product group.

    ``None`` entries are skipped; every other entry is validated, and those
    of order 1 are dropped. No survivors gives ``None``, a single survivor
    is returned unchanged, and otherwise ``SymmetryGroup.direct_product`` is
    used (its result is ``None``-ed if trivial).
    """
    nontrivial = []
    for candidate in groups:
        if candidate is not None:
            validate_symmetry_group(candidate)
            if candidate.order() > 1:
                nontrivial.append(candidate)
    if len(nontrivial) == 0:
        return None
    if len(nontrivial) == 1:
        return nontrivial[0]
    combined = SymmetryGroup.direct_product(*nontrivial)
    if combined.order() <= 1:
        return None
    return combined
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def setwise_stabilizer(
    group: SymmetryGroup | None,
    fixed_set: Iterable[int],
) -> SymmetryGroup | None:
    """Return the subgroup G' = {π ∈ G : π(fixed_set) = fixed_set}.

    ``fixed_set`` holds tensor-axis indices. Entries the group does not act
    on are ignored; the rest are converted to the group's local indices
    before delegating to :meth:`SymmetryGroup.setwise_stabilizer`. A trivial
    stabilizer (order ≤ 1) is reported as ``None``.
    """
    if group is None:
        return None
    support = group.axes
    if support is None:
        support = tuple(range(group.degree))
    local_points = {support.index(ax) for ax in fixed_set if ax in support}
    stabilizer = group.setwise_stabilizer(local_points)
    if stabilizer.order() <= 1:
        return None
    return stabilizer
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
def group_orbits_on_axes(
    group: SymmetryGroup,
    axes: Sequence[int],
) -> list[set[int]]:
    """Split ``axes`` into orbits under the action of ``group``'s generators.

    An axis the group does not act on forms a singleton orbit. For an axis
    the group does act on, the orbit is every tensor axis reachable by
    repeatedly applying generators; this may include axes not listed in
    ``axes``. Orbits are listed in the order their first member appears in
    ``axes``, and members of an already-emitted orbit are skipped.
    """
    support = group.axes
    if support is None:
        support = tuple(range(group.degree))
    visited: set[int] = set()
    result: list[set[int]] = []
    for start in list(axes):
        if start in visited:
            continue
        if start not in support:
            # Untouched by the group: a fixed point.
            result.append({start})
            visited.add(start)
            continue
        component: set[int] = set()
        pending = [start]
        while pending:
            current = pending.pop()
            if current in component:
                continue
            component.add(current)
            local = support.index(current)
            for gen in group.generators:
                image = support[gen.array_form[local]]
                if image not in component:
                    pending.append(image)
        result.append(component)
        visited.update(component)
    return result
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def _normalize_reps_for_output(reps, *, output_ndim: int) -> tuple[int, ...]:
|
|
485
|
+
"""Normalize `reps` arg to a tuple of length `output_ndim`.
|
|
486
|
+
|
|
487
|
+
Matches NumPy.tile's right-alignment rule: if `reps` is shorter, it's
|
|
488
|
+
prepended with 1s. If `reps` is a scalar, treat as `(reps,)`.
|
|
489
|
+
"""
|
|
490
|
+
if isinstance(reps, int):
|
|
491
|
+
reps_tup = (reps,)
|
|
492
|
+
else:
|
|
493
|
+
reps_tup = tuple(reps)
|
|
494
|
+
if len(reps_tup) < output_ndim:
|
|
495
|
+
reps_tup = (1,) * (output_ndim - len(reps_tup)) + reps_tup
|
|
496
|
+
return reps_tup
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
def broadcast_group(
    group: SymmetryGroup | None,
    *,
    input_shape: tuple[int, ...],
    output_shape: tuple[int, ...],
) -> SymmetryGroup | None:
    """Broadcast a single input symmetry group onto an output shape.

    Two sources of output symmetry are combined:

    * leading axes created by broadcasting: any two of equal size are
      interchangeable, so each equal-size block of two or more gets a
      symmetric factor;
    * the input ``group``, restricted to the input axes that are not
      stretched from size 1, and shifted by the rank difference.

    Raises:
        ValueError: If the input rank exceeds the output rank, or ``group``
            does not fit ``input_shape``.
    """
    if len(input_shape) > len(output_shape):
        raise ValueError("input rank cannot exceed output rank")

    factors: list[SymmetryGroup] = []
    offset = len(output_shape) - len(input_shape)

    created_by_size: OrderedDict[int, list[int]] = OrderedDict()
    for axis in range(offset):
        created_by_size.setdefault(output_shape[axis], []).append(axis)
    for block in created_by_size.values():
        if len(block) >= 2:
            factors.append(SymmetryGroup.symmetric(axes=tuple(block)))

    if group is not None:
        validate_symmetry_group(group, ndim=len(input_shape), shape=input_shape)
        axes = group.axes
        if axes is None:
            axes = tuple(range(group.degree))
        kept_local = []
        stretched_local: set[int] = set()
        for local_idx, axis in enumerate(axes):
            if input_shape[axis] == 1 and output_shape[axis + offset] > 1:
                stretched_local.add(local_idx)
            else:
                kept_local.append(local_idx)
        if len(kept_local) >= 2:
            if not stretched_local:
                restricted = group
            else:
                # Like restrict_group_to_axes / reduce_group: only
                # permutations that map the kept set onto itself may be
                # handed to restrict().
                stabilized = group.setwise_stabilizer(stretched_local)
                restricted = stabilized.restrict(tuple(kept_local))
            restricted_axes = (
                restricted.axes
                if restricted.axes is not None
                else tuple(range(restricted.degree))
            )
            remapped = remap_group_axes(
                restricted,
                {
                    restricted_axes[new_local_idx]: axes[old_local_idx] + offset
                    for new_local_idx, old_local_idx in enumerate(kept_local)
                },
            )
            if remapped is not None and remapped.order() > 1:
                factors.append(remapped)

    return direct_product_groups(*factors)
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def reduce_group(
    group: SymmetryGroup | None,
    *,
    ndim: int,
    axis: int | tuple[int, ...] | None,
    keepdims: bool = False,
) -> SymmetryGroup | None:
    """Symmetry left on the output of a reduction over ``axis``.

    ``axis=None`` (reduce everything) or no input group yields ``None``.
    Axes are wrapped modulo ``ndim``. Surviving axes keep their positions
    when ``keepdims`` is true and are packed to the front otherwise. Groups
    with a ``_known_kind`` are first tried through a closed-form shortcut;
    otherwise the result is the setwise stabilizer of the reduced axes,
    restricted to the kept axes and relabelled to output positions.
    """
    if group is None or axis is None:
        return None
    validate_symmetry_group(group, ndim=ndim)
    requested = (axis,) if isinstance(axis, int) else axis
    reduced = {ax % ndim for ax in requested}

    if keepdims:
        new_position = {d: d for d in range(ndim)}
    else:
        survivors = [d for d in range(ndim) if d not in reduced]
        new_position = {d: i for i, d in enumerate(survivors)}

    def _nontrivial(candidate):
        return candidate if candidate is not None and candidate.order() > 1 else None

    # Closed-form shortcut for known kinds (avoids enumerating the group via
    # _dimino). A None result means "trivial or not closed-form"; either way
    # the generic computation below produces the right answer.
    if group._known_kind is not None:
        shortcut = _reduced_kind(
            group._known_kind, reduced_axes=reduced, axis_map=new_position
        )
        if shortcut is not None:
            return _build_from_kind(shortcut)

    support = group.axes
    if support is None:
        support = tuple(range(group.degree))
    dropped = {i for i, tensor_axis in enumerate(support) if tensor_axis in reduced}
    retained = [
        i for i, tensor_axis in enumerate(support) if tensor_axis not in reduced
    ]

    if not dropped:
        # The group is untouched; only its axis labels move.
        return _nontrivial(
            remap_group_axes(group, {t: new_position[t] for t in support})
        )
    if not retained:
        return None

    residual = group.setwise_stabilizer(dropped).restrict(tuple(retained))
    if residual.order() <= 1:
        return None
    residual_axes = (
        residual.axes if residual.axes is not None else tuple(range(residual.degree))
    )
    relabel = {
        residual_axes[j]: new_position[support[i]] for j, i in enumerate(retained)
    }
    return _nontrivial(remap_group_axes(residual, relabel))
|
|
624
|
+
|
|
625
|
+
|
|
626
|
+
def wrap_with_symmetry(data, symmetry: SymmetryGroup | None):
    """Return ``data`` as an array, tagged with ``symmetry`` if one is given.

    Without a group this is just ``np.asarray(data)``. With one, the group
    is first validated against the array's rank and the array is wrapped in
    a ``SymmetricTensor``.
    """
    arr = np.asarray(data)
    if symmetry is not None:
        validate_symmetry_group(symmetry, ndim=arr.ndim)
        # Imported lazily (kept local as in the original code).
        from flopscope._symmetric import SymmetricTensor

        return SymmetricTensor(arr, symmetry=symmetry)
    return arr
|
|
635
|
+
|
|
636
|
+
|
|
637
|
+
def wrap_with_trusted_symmetry(data, symmetry: SymmetryGroup | None):
    """Wrap data in a ``SymmetricTensor`` without re-validating ``symmetry``.

    Internal use only: callers must pass a group that trusted constructor
    logic already produced or validated. Skipping the validation call keeps
    constructor hot paths fast; user-facing paths go through
    ``wrap_with_symmetry`` and stay fully validated. With no group the result
    is plain ``np.asarray(data)``.
    """
    arr = np.asarray(data)
    if symmetry is not None:
        from flopscope._symmetric import SymmetricTensor

        return SymmetricTensor(arr, symmetry=symmetry)
    return arr
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
def wrap_with_inferred_symmetry(data, symmetry: SymmetryGroup | None):
    """Wrap data with symmetry metadata that was inferred, not user-declared.

    Behaves like ``wrap_with_trusted_symmetry`` (no re-validation), but the
    resulting ``SymmetricTensor`` is flagged ``_symmetry_inferred = True``.
    ``_prepare_symmetric_out`` reads that flag to decide whether a
    non-symmetric ``out=`` write should quietly downgrade the target
    (inferred) or raise (explicit). Internal call sites only — never expose
    this to user code. With no group the result is plain ``np.asarray(data)``.
    """
    arr = np.asarray(data)
    if symmetry is None:
        return arr
    from flopscope._symmetric import SymmetricTensor

    tagged = SymmetricTensor(arr, symmetry=symmetry)
    tagged._symmetry_inferred = True
    return tagged
|
flopscope/_type_info.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Dtype introspection utilities re-exported as free (0-FLOP) helpers.
|
|
2
|
+
|
|
3
|
+
`finfo` and `iinfo` are metadata queries that return info objects
|
|
4
|
+
(eps, max, min, bits, etc.). They perform no computation.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import numpy as _np
|
|
10
|
+
|
|
11
|
+
# Plain aliases: NumPy's info classes are re-exported unchanged, so calls and
# isinstance checks behave exactly like ``numpy.finfo`` / ``numpy.iinfo``.
finfo, iinfo = _np.finfo, _np.iinfo
|