flopscope 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. benchmarks/__init__.py +1 -0
  2. benchmarks/__main__.py +6 -0
  3. benchmarks/_baseline.py +171 -0
  4. benchmarks/_bitwise.py +231 -0
  5. benchmarks/_complex.py +176 -0
  6. benchmarks/_contractions.py +291 -0
  7. benchmarks/_fft.py +198 -0
  8. benchmarks/_impl_urls.py +139 -0
  9. benchmarks/_linalg.py +197 -0
  10. benchmarks/_linalg_delegates.py +407 -0
  11. benchmarks/_metadata.py +141 -0
  12. benchmarks/_misc.py +653 -0
  13. benchmarks/_perf.py +321 -0
  14. benchmarks/_perm_group_calibration.py +175 -0
  15. benchmarks/_pointwise.py +372 -0
  16. benchmarks/_polynomial.py +193 -0
  17. benchmarks/_random.py +209 -0
  18. benchmarks/_reductions.py +136 -0
  19. benchmarks/_sorting.py +289 -0
  20. benchmarks/_stats.py +137 -0
  21. benchmarks/_window.py +92 -0
  22. benchmarks/accumulation/__init__.py +0 -0
  23. benchmarks/accumulation/bench_cost_compute.py +138 -0
  24. benchmarks/dashboard.py +312 -0
  25. benchmarks/runner.py +636 -0
  26. flopscope/__init__.py +273 -0
  27. flopscope/_accumulation/__init__.py +13 -0
  28. flopscope/_accumulation/_bipartite.py +121 -0
  29. flopscope/_accumulation/_burnside.py +51 -0
  30. flopscope/_accumulation/_cache.py +146 -0
  31. flopscope/_accumulation/_components.py +153 -0
  32. flopscope/_accumulation/_cost.py +1414 -0
  33. flopscope/_accumulation/_cost_descriptions.py +63 -0
  34. flopscope/_accumulation/_detection.py +318 -0
  35. flopscope/_accumulation/_ladder.py +191 -0
  36. flopscope/_accumulation/_output_orbit.py +104 -0
  37. flopscope/_accumulation/_partition.py +290 -0
  38. flopscope/_accumulation/_path_info.py +211 -0
  39. flopscope/_accumulation/_public.py +169 -0
  40. flopscope/_accumulation/_reduction.py +310 -0
  41. flopscope/_accumulation/_regimes.py +303 -0
  42. flopscope/_accumulation/_shape.py +33 -0
  43. flopscope/_accumulation/_wreath.py +209 -0
  44. flopscope/_budget.py +1027 -0
  45. flopscope/_config.py +118 -0
  46. flopscope/_counting_ops.py +451 -0
  47. flopscope/_display.py +478 -0
  48. flopscope/_docstrings.py +59 -0
  49. flopscope/_dtypes.py +20 -0
  50. flopscope/_einsum.py +717 -0
  51. flopscope/_errstate.py +25 -0
  52. flopscope/_flops.py +282 -0
  53. flopscope/_free_ops.py +2654 -0
  54. flopscope/_ndarray.py +1126 -0
  55. flopscope/_opt_einsum/LICENSE +21 -0
  56. flopscope/_opt_einsum/NOTICE +59 -0
  57. flopscope/_opt_einsum/__init__.py +209 -0
  58. flopscope/_opt_einsum/_contract.py +1478 -0
  59. flopscope/_opt_einsum/_helpers.py +164 -0
  60. flopscope/_opt_einsum/_hsluv.py +273 -0
  61. flopscope/_opt_einsum/_path_random.py +462 -0
  62. flopscope/_opt_einsum/_paths.py +1653 -0
  63. flopscope/_opt_einsum/_subgraph_symmetry.py +544 -0
  64. flopscope/_opt_einsum/_symmetry.py +140 -0
  65. flopscope/_opt_einsum/_typing.py +37 -0
  66. flopscope/_perm_group.py +717 -0
  67. flopscope/_pointwise.py +2522 -0
  68. flopscope/_polynomial.py +278 -0
  69. flopscope/_registry.py +3216 -0
  70. flopscope/_sorting_ops.py +571 -0
  71. flopscope/_symmetric.py +812 -0
  72. flopscope/_symmetry_transport.py +510 -0
  73. flopscope/_symmetry_utils.py +669 -0
  74. flopscope/_type_info.py +12 -0
  75. flopscope/_unwrap.py +70 -0
  76. flopscope/_validation.py +83 -0
  77. flopscope/_version_check.py +46 -0
  78. flopscope/_weights.py +195 -0
  79. flopscope/_window.py +177 -0
  80. flopscope/accounting.py +565 -0
  81. flopscope/data/default_weights.json +462 -0
  82. flopscope/data/weights.csv +509 -0
  83. flopscope/errors.py +197 -0
  84. flopscope/numpy/__init__.py +878 -0
  85. flopscope/numpy/fft/__init__.py +55 -0
  86. flopscope/numpy/fft/_free.py +51 -0
  87. flopscope/numpy/fft/_transforms.py +695 -0
  88. flopscope/numpy/linalg/__init__.py +105 -0
  89. flopscope/numpy/linalg/_aliases.py +126 -0
  90. flopscope/numpy/linalg/_compound.py +161 -0
  91. flopscope/numpy/linalg/_decompositions.py +353 -0
  92. flopscope/numpy/linalg/_properties.py +533 -0
  93. flopscope/numpy/linalg/_solvers.py +444 -0
  94. flopscope/numpy/linalg/_svd.py +122 -0
  95. flopscope/numpy/random/__init__.py +684 -0
  96. flopscope/numpy/random/_cost_formulas.py +115 -0
  97. flopscope/numpy/random/_counted_classes.py +241 -0
  98. flopscope/numpy/testing/__init__.py +13 -0
  99. flopscope/numpy/typing/__init__.py +30 -0
  100. flopscope/py.typed +0 -0
  101. flopscope/stats/__init__.py +84 -0
  102. flopscope/stats/_base.py +77 -0
  103. flopscope/stats/_cauchy.py +146 -0
  104. flopscope/stats/_erf.py +190 -0
  105. flopscope/stats/_expon.py +146 -0
  106. flopscope/stats/_laplace.py +150 -0
  107. flopscope/stats/_logistic.py +148 -0
  108. flopscope/stats/_lognorm.py +160 -0
  109. flopscope/stats/_ndtri.py +133 -0
  110. flopscope/stats/_norm.py +149 -0
  111. flopscope/stats/_truncnorm.py +186 -0
  112. flopscope/stats/_uniform.py +141 -0
  113. flopscope-0.2.0.dist-info/METADATA +23 -0
  114. flopscope-0.2.0.dist-info/RECORD +115 -0
  115. flopscope-0.2.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,669 @@
1
+ """Helper primitives for exact tensor symmetry groups."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ import math
7
+ from collections import OrderedDict
8
+ from collections.abc import Iterable, Mapping, Sequence
9
+ from typing import Any
10
+
11
+ import numpy as np
12
+
13
+ from flopscope._perm_group import SymmetryGroup
14
+ from flopscope.errors import SymmetryError
15
+
16
+
17
+ def _normalize_axis_tuple(
18
+ axes: Iterable[Any],
19
+ *,
20
+ ndim: int | None = None,
21
+ what: str = "axes",
22
+ ) -> tuple[int, ...]:
23
+ norm_axes = tuple(axes)
24
+ if not norm_axes:
25
+ raise ValueError(f"{what} must be non-empty")
26
+ if not all(isinstance(axis, int) for axis in norm_axes):
27
+ raise TypeError(f"{what} must contain only integers")
28
+ if len(set(norm_axes)) != len(norm_axes):
29
+ raise ValueError(f"{what} contain duplicate entries")
30
+ if ndim is not None and any(axis < 0 or axis >= ndim for axis in norm_axes):
31
+ raise ValueError(f"{what} are out of range for ndim={ndim}")
32
+ return norm_axes
33
+
34
+
35
def normalize_symmetry_input(obj, *, ndim: int | None = None):
    """Coerce a user-supplied ``symmetry=`` value into one SymmetryGroup.

    Accepted forms:
        * ``None`` -> ``None`` (no symmetry).
        * A ``SymmetryGroup`` -> validated against ``ndim`` and returned.
        * A non-empty tuple/list whose first entry is an int, e.g. ``(0, 1)``
          -> ``SymmetryGroup.symmetric`` on those axes.
        * A non-empty tuple/list whose first entry is a tuple/list, e.g.
          ``((0, 1), (2, 3))`` -> ``SymmetryGroup.young`` with one block per
          entry; blocks must not share axes.

    Raises:
        TypeError: for a list of SymmetryGroups or any unsupported form.
        ValueError: for empty, duplicated, out-of-range, or overlapping axes.
    """
    if obj is None:
        return None
    if isinstance(obj, SymmetryGroup):
        return validate_symmetry_group(obj, ndim=ndim)

    is_group_list = (
        isinstance(obj, list)
        and obj
        and all(isinstance(candidate, SymmetryGroup) for candidate in obj)
    )
    if is_group_list:
        raise TypeError("symmetry must be a single SymmetryGroup, not a list of groups")

    if isinstance(obj, (tuple, list)) and obj:
        head = obj[0]
        if isinstance(head, int):
            # Flat shorthand: every listed axis is interchangeable.
            return SymmetryGroup.symmetric(
                axes=_normalize_axis_tuple(obj, ndim=ndim, what="symmetry axes")
            )
        if isinstance(head, (tuple, list)):
            # Partition shorthand: each block is its own symmetric factor.
            claimed: set[int] = set()
            normalized_blocks = []
            for raw_block in obj:
                axes_block = _normalize_axis_tuple(
                    raw_block, ndim=ndim, what="symmetry partition block"
                )
                clash = claimed.intersection(axes_block)
                if clash:
                    raise ValueError(
                        "symmetry partition blocks overlap on axes "
                        f"{tuple(sorted(clash))}"
                    )
                claimed.update(axes_block)
                normalized_blocks.append(axes_block)
            return SymmetryGroup.young(blocks=tuple(normalized_blocks))

    raise TypeError(
        "symmetry must be a SymmetryGroup or an approved axis/partition shorthand"
    )
71
+
72
+
73
def validate_symmetry_group(
    group: SymmetryGroup,
    *,
    ndim: int | None = None,
    shape: tuple[int, ...] | None = None,
) -> SymmetryGroup:
    """Validate tensor-facing properties of a symmetry group.

    Args:
        group: Group to check.
        ndim: Tensor rank the group must fit into. When omitted but ``shape``
            is given, ``len(shape)`` is used.
        shape: When given, every orbit of the group must act on axes of equal
            size. This check applies to groups with explicit ``axes`` and to
            groups with ``axes is None``; the latter act on axes
            ``0..degree-1``.

    Returns:
        ``group`` unchanged, so the call can be used inline.

    Raises:
        TypeError: if ``group`` is not a ``SymmetryGroup``, or its axes contain
            non-integers.
        ValueError: if the group's degree exceeds ``ndim``, or its explicit
            axes are empty, duplicated, out of range, or not already a
            normalized tuple.
        SymmetryError: if an orbit spans axes of different sizes in ``shape``.
    """
    if not isinstance(group, SymmetryGroup):
        raise TypeError("symmetry must be a SymmetryGroup")
    if ndim is None and shape is not None:
        # Lets out-of-range axes fail with a clear ValueError instead of an
        # IndexError from the shape lookup below.
        ndim = len(shape)
    axes = group.axes
    if axes is None:
        if ndim is not None and group.degree > ndim:
            raise ValueError(
                f"SymmetryGroup degree {group.degree} exceeds tensor rank {ndim}"
            )
        # Implicit-axes groups act on the leading ``degree`` axes. They still
        # need the orbit-size check; previously it was skipped for them.
        axes = tuple(range(group.degree))
    else:
        norm_axes = _normalize_axis_tuple(axes, ndim=ndim, what="SymmetryGroup axes")
        if norm_axes != axes:
            raise ValueError("SymmetryGroup axes must already be normalized")
    if shape is not None:
        for orbit in group.orbits():
            sizes = {shape[axes[i]] for i in orbit}
            if len(sizes) > 1:
                raise SymmetryError(
                    axes=tuple(axes[i] for i in orbit), max_deviation=float("inf")
                )
    return group
100
+
101
+
102
def unique_elements_for_shape(
    group: SymmetryGroup | None,
    shape: tuple[int, ...],
) -> int:
    """Count the distinct tensor entries of ``shape`` under ``group``.

    With no group every entry is distinct, so the answer is the product of
    the dimensions. Otherwise the memoised symmetry-aware count is used;
    ``shape`` is converted to a tuple first so it can serve as a cache key.
    """
    if group is not None:
        return _unique_elements_for_shape_cached(group, tuple(shape))
    return math.prod(shape)
110
+
111
+
112
# Bounded: the key space is every (group, shape) pair a process ever sees,
# so an unbounded cache could grow without limit in long-running workloads.
# 4096 is an arbitrary but generous working-set size.
@functools.lru_cache(maxsize=4096)
def _unique_elements_for_shape_cached(
    group: SymmetryGroup,
    shape: tuple[int, ...],
) -> int:
    """Memoised worker for :func:`unique_elements_for_shape`.

    Validates ``group`` against ``shape``. It then delegates the count over
    the axes the group acts on to ``group.burnside_unique_count``, keyed by
    the group's local axis index. Finally it multiplies in the full extent of
    every tensor axis the group leaves alone. ``cache_clear()`` remains
    available, as with the ``functools`` cache decorators.
    """
    validate_symmetry_group(group, ndim=len(shape), shape=shape)
    axes = group.axes
    if axes is None:
        axes = tuple(range(group.degree))
    size_dict = {local_idx: shape[axis] for local_idx, axis in enumerate(axes)}
    result = group.burnside_unique_count(size_dict)
    accounted = set(axes)
    for axis, size in enumerate(shape):
        if axis not in accounted:
            result *= size
    return result
128
+
129
+
130
+ def _build_from_kind(kind: tuple) -> SymmetryGroup | None:
131
+ """Construct an interned SymmetryGroup from a ``_known_kind`` tag.
132
+
133
+ Routes through the public factory matching the kind name. Returns
134
+ ``None`` for trivial (identity) kinds, since callers expect ``None``
135
+ to represent "no non-trivial symmetry."
136
+ """
137
+ name = kind[0]
138
+ if name == "identity":
139
+ return None
140
+ if name == "symmetric":
141
+ return SymmetryGroup.symmetric(axes=kind[1])
142
+ if name == "cyclic":
143
+ return SymmetryGroup.cyclic(axes=kind[1])
144
+ if name == "dihedral":
145
+ return SymmetryGroup.dihedral(axes=kind[1])
146
+ if name == "direct_product":
147
+ children = [_build_from_kind(child) for child in kind[1]]
148
+ non_trivial = [c for c in children if c is not None]
149
+ if not non_trivial:
150
+ return None
151
+ if len(non_trivial) == 1:
152
+ return non_trivial[0]
153
+ return SymmetryGroup.direct_product(*non_trivial)
154
+ raise AssertionError(f"unknown kind {kind!r}")
155
+
156
+
157
def embed_group(group: SymmetryGroup | None, ndim: int) -> SymmetryGroup | None:
    """Lift ``group`` so it acts on all ``ndim`` tensor axes.

    Each generator is extended to a full-rank permutation that fixes every
    axis outside the group's support (``group.axes``, or ``0..degree-1``
    when unset). A group already acting on exactly ``0..ndim-1`` is
    returned as-is, and ``None`` passes through. A group without generators
    becomes the identity on ``ndim`` axes.
    """
    if group is None:
        return None
    validate_symmetry_group(group, ndim=ndim)
    support = group.axes
    if support is None:
        support = tuple(range(group.degree))
    full_range = tuple(range(ndim))
    if support == full_range and group.degree == ndim:
        return group

    lifted = []
    for gen in group.generators:
        image = gen.array_form
        perm = list(full_range)
        # Local index ``pos`` maps to local index ``image[pos]``; translate
        # both ends to tensor axes.
        for pos, tensor_axis in enumerate(support):
            perm[tensor_axis] = support[image[pos]]
        lifted.append(perm)
    if not lifted:
        lifted = [list(full_range)]
    return SymmetryGroup.from_generators(lifted, axes=full_range)
176
+
177
+
178
def restrict_group_to_axes(
    group: SymmetryGroup | None,
    axes: Iterable[int],
) -> SymmetryGroup | None:
    """Restrict a group to an ordered subset of its tensor axes.

    The subgroup that maps the requested axes onto themselves
    (:meth:`SymmetryGroup.setwise_stabilizer`) is computed first, and then
    :meth:`SymmetryGroup.restrict` projects it onto those axes. This lets
    strict subsets of free-permuting groups (e.g. ``symmetric(A)``) project
    cleanly. Requesting exactly ``group.axes`` returns the original
    (interned) group, so its ``_known_kind`` survives. Strict-subset results
    carry ``_known_kind=None``; the "sub-symmetric is still symmetric" rule
    lives in ``reduce_group``, not here.

    Returns ``None`` for a ``None`` group, fewer than two requested axes, or
    a trivial restriction.

    Raises:
        ValueError: if a requested axis is not in the group's support, or the
            request is empty or duplicated.
    """
    if group is None:
        return None
    validate_symmetry_group(group)
    support = group.axes
    if support is None:
        support = tuple(range(group.degree))
    target = _normalize_axis_tuple(axes, what="restricted axes")
    if target == support:
        # No-op restriction: hand back the interned original with its kind.
        return group

    slot = {tensor_axis: local for local, tensor_axis in enumerate(support)}
    kept_list = []
    for tensor_axis in target:
        if tensor_axis not in slot:
            raise ValueError(
                f"restricted axes {target} are not a subset of {support}"
            )
        kept_list.append(slot[tensor_axis])
    if len(kept_list) < 2:
        return None

    kept = tuple(kept_list)
    # restrict() must only see permutations that preserve the kept set.
    sub = group.setwise_stabilizer(set(kept)).restrict(kept)
    return None if sub.order() <= 1 else sub
218
+
219
+
220
+ def _remap_kind(kind: tuple | None, axis_map: Mapping[Any, Any]) -> tuple | None:
221
+ """Apply ``axis_map`` to the axes inside a ``_known_kind`` tag.
222
+
223
+ Returns ``None`` if any leaf axis is missing from the map (caller's
224
+ responsibility to ensure full coverage).
225
+ """
226
+ if kind is None:
227
+ return None
228
+ name = kind[0]
229
+ if name in ("identity", "symmetric", "cyclic", "dihedral"):
230
+ try:
231
+ return (name, tuple(axis_map[a] for a in kind[1]))
232
+ except KeyError:
233
+ return None
234
+ if name == "direct_product":
235
+ children = tuple(_remap_kind(child, axis_map) for child in kind[1])
236
+ if any(child is None for child in children):
237
+ return None
238
+ return ("direct_product", tuple(sorted(children, key=repr)))
239
+ return None
240
+
241
+
242
+ def _reduced_kind(
243
+ kind: tuple | None,
244
+ *,
245
+ reduced_axes: set[int],
246
+ axis_map: Mapping[Any, Any],
247
+ ) -> tuple | None:
248
+ """Compute the reduced kind tag for a known-kind group.
249
+
250
+ ``reduced_axes`` is the set of tensor axes being reduced over (in the
251
+ parent group's axis space). ``axis_map`` is the surviving-axes mapping
252
+ from old tensor axes to new tensor positions (post-reduction layout).
253
+
254
+ Returns ``None`` if the result is trivial or can't be expressed in
255
+ closed form.
256
+ """
257
+ if kind is None:
258
+ return None
259
+ name = kind[0]
260
+ if name == "identity":
261
+ kept = tuple(axis_map[a] for a in kind[1] if a not in reduced_axes)
262
+ if not kept:
263
+ return None
264
+ return ("identity", kept)
265
+ if name == "symmetric":
266
+ kept = tuple(axis_map[a] for a in kind[1] if a not in reduced_axes)
267
+ if len(kept) < 2:
268
+ return None
269
+ return ("symmetric", kept)
270
+ if name == "cyclic":
271
+ # Cyclic only survives if NONE of its axes are reduced (the cycle
272
+ # would otherwise no longer be a closed orbit on the kept axes).
273
+ if any(a in reduced_axes for a in kind[1]):
274
+ return None
275
+ kept = tuple(axis_map[a] for a in kind[1])
276
+ return ("cyclic", kept)
277
+ if name == "dihedral":
278
+ if any(a in reduced_axes for a in kind[1]):
279
+ return None
280
+ kept = tuple(axis_map[a] for a in kind[1])
281
+ return ("dihedral", kept)
282
+ if name == "direct_product":
283
+ new_children = []
284
+ for child in kind[1]:
285
+ new_child = _reduced_kind(
286
+ child, reduced_axes=reduced_axes, axis_map=axis_map
287
+ )
288
+ if new_child is not None:
289
+ new_children.append(new_child)
290
+ if not new_children:
291
+ return None
292
+ if len(new_children) == 1:
293
+ return new_children[0]
294
+ return ("direct_product", tuple(sorted(new_children, key=repr)))
295
+ return None
296
+
297
+
298
def remap_group_axes(
    group: SymmetryGroup | None,
    axis_map: Mapping[int, int],
) -> SymmetryGroup | None:
    """Move ``group`` onto new tensor axes without changing how it permutes.

    Each axis of the group's support (``group.axes``, or ``0..degree-1``
    when unset) is renamed through ``axis_map``, and the generators are
    reused verbatim on the renamed axes. If the group's ``_known_kind`` tag
    can be remapped as well, the tag is transferred and the result interned.

    Raises:
        ValueError: if a support axis has no entry in ``axis_map``, or the
            renamed axes are duplicated.
    """
    if group is None:
        return None
    validate_symmetry_group(group)
    source_axes = group.axes
    if source_axes is None:
        source_axes = tuple(range(group.degree))

    targets = []
    for old_axis in source_axes:
        if old_axis not in axis_map:
            raise ValueError(f"missing remap for axis {old_axis}")
        targets.append(axis_map[old_axis])
    _normalize_axis_tuple(targets, what="remapped axes")

    result = SymmetryGroup.from_generators(
        group.generator_literals,  # pyright: ignore[reportArgumentType]
        axes=tuple(targets),  # type: ignore[arg-type]
    )
    kind_tag = _remap_kind(group._known_kind, axis_map)
    if kind_tag is None:
        return result
    result._known_kind = kind_tag
    return SymmetryGroup._intern(result)
324
+
325
+
326
def remap_group_for_expand_dims(
    group: SymmetryGroup | None,
    *,
    ndim: int,
    axis,
) -> SymmetryGroup | None:
    """Remap tensor-axis support after ``numpy.expand_dims`` axis insertion.

    ``numpy.expand_dims`` is run on a probe array, so axis handling
    (negative axes, tuples, out-of-range and repeated-axis errors) matches
    NumPy exactly. Each original axis gets a distinct probe size other than
    1, so its new position can be read off the expanded shape. The size-1
    entries mark the inserted axes.

    The probe includes a zero-length axis, so it stores no elements
    regardless of ``ndim``. A probe with sizes ``2..ndim+1`` would need
    ``(ndim+1)!`` float64 elements: hundreds of MB at ``ndim == 10``, and
    impossible shortly after.

    Returns:
        The direct product of the original group moved to its post-insertion
        axes and the symmetry of the inserted axes, or ``None`` when neither
        part is non-trivial.
    """
    if group is not None:
        validate_symmetry_group(group, ndim=ndim)
    # ndim distinct sizes, none equal to 1: (), (0,), (0, 2), (0, 2, 3), ...
    probe_shape = tuple(range(2, ndim + 1))
    if ndim:
        probe_shape = (0,) + probe_shape
    probe = np.empty(probe_shape)
    expanded_shape = np.expand_dims(probe, axis=axis).shape
    remapped = None
    if group is not None:
        axis_map = {
            old_axis: expanded_shape.index(size)
            for old_axis, size in enumerate(probe_shape)
        }
        remapped = remap_group_axes(group, axis_map)
    inserted_axes = tuple(
        axis_idx for axis_idx, size in enumerate(expanded_shape) if size == 1
    )
    inserted = inserted_axes_symmetry(inserted_axes)
    return direct_product_groups(remapped, inserted)
350
+
351
+
352
def inserted_axes_symmetry(
    inserted_positions: Sequence[int],
) -> SymmetryGroup | None:
    """Symmetry shared by freshly inserted size-1 axes.

    Axis-inserting operations (``expand_dims``, ``__getitem__`` with
    ``None``/``np.newaxis``) create length-1 axes at ``inserted_positions``.
    With two or more positions the result is
    ``SymmetryGroup.symmetric(axes=tuple(inserted_positions))``. With fewer
    there is nothing non-trivial, so ``None`` is returned.
    """
    if len(inserted_positions) >= 2:
        return SymmetryGroup.symmetric(axes=tuple(inserted_positions))
    return None
365
+
366
+
367
def intersect_groups(
    a: SymmetryGroup | None,
    b: SymmetryGroup | None,
    *,
    ndim: int,
) -> SymmetryGroup | None:
    """Return the permutations common to ``a`` and ``b``, or ``None``.

    Resolution order:
      1. Either side missing -> ``None``.
      2. Both carry the same non-``None`` ``_known_kind`` tag -> the groups
         coincide, so ``a`` is returned with its provenance intact (unless it
         is trivial).
      3. Both have the same explicit axes -> intersect their element sets on
         those axes.
      4. Otherwise both are embedded into all ``ndim`` axes and intersected
         there.
    A trivial intersection (identity only) is reported as ``None``.
    """
    if a is None or b is None:
        return None
    if a._known_kind is not None and a._known_kind == b._known_kind:
        return None if a.order() <= 1 else a

    def _shared_subgroup(left, right, axes):
        # Deterministic element order keeps the constructed group stable.
        shared = sorted(
            set(left.elements()) & set(right.elements()),
            key=lambda perm: tuple(perm.array_form),
        )
        if len(shared) <= 1:
            return None
        return SymmetryGroup(*shared, axes=axes)

    if a.axes is not None and b.axes is not None and a.axes == b.axes:
        return _shared_subgroup(a, b, a.axes)

    lifted_a = embed_group(a, ndim)
    lifted_b = embed_group(b, ndim)
    assert lifted_a is not None
    assert lifted_b is not None
    return _shared_subgroup(lifted_a, lifted_b, tuple(range(ndim)))
402
+
403
+
404
def direct_product_groups(*groups: SymmetryGroup | None) -> SymmetryGroup | None:
    """Combine groups acting on disjoint axes into a single group.

    ``None`` entries are skipped. Every other entry is validated and kept
    only if it is non-trivial. The result depends on how many remain:
    none -> ``None``; one -> that group unchanged; several -> their
    ``SymmetryGroup.direct_product``, or ``None`` if that product is trivial.
    """
    nontrivial = []
    for candidate in (g for g in groups if g is not None):
        validate_symmetry_group(candidate)
        if candidate.order() > 1:
            nontrivial.append(candidate)

    if not nontrivial:
        return None
    if len(nontrivial) == 1:
        (only,) = nontrivial
        return only
    combined = SymmetryGroup.direct_product(*nontrivial)
    return None if combined.order() <= 1 else combined
419
+
420
+
421
def setwise_stabilizer(
    group: SymmetryGroup | None,
    fixed_set: Iterable[int],
) -> SymmetryGroup | None:
    """Return the subgroup G' = {π ∈ G : π(fixed_set) = fixed_set}.

    ``fixed_set`` holds tensor-axis indices. Entries outside the group's
    support (``group.axes``, or ``0..degree-1`` when unset) are silently
    ignored. The remaining entries are translated to the group's local
    indices before delegating to ``group.setwise_stabilizer``. Returns
    ``None`` when the group is ``None`` or the stabilizer is trivial
    (order ≤ 1).
    """
    if group is None:
        return None
    support = group.axes
    if support is None:
        support = tuple(range(group.degree))

    local = set()
    for tensor_axis in fixed_set:
        if tensor_axis in support:
            local.add(support.index(tensor_axis))

    stab = group.setwise_stabilizer(local)
    if stab.order() <= 1:
        return None
    return stab
440
+
441
+
442
def group_orbits_on_axes(
    group: SymmetryGroup,
    axes: Sequence[int],
) -> list[set[int]]:
    """Return the orbits of ``group``'s action containing the given ``axes``.

    Each listed axis outside the group's support (``group.axes``, or
    ``0..degree-1`` when unset) forms a singleton orbit. Every other listed
    axis is expanded to the full set of tensor axes reachable via the
    group's generators. Orbits are listed in first-encounter order, and an
    axis already covered by an earlier orbit is not repeated.
    """
    support = group.axes
    if support is None:
        support = tuple(range(group.degree))
    # First local index of each tensor axis (mirrors tuple.index).
    first_slot: dict[int, int] = {}
    for local, tensor_axis in enumerate(support):
        first_slot.setdefault(tensor_axis, local)
    images = [gen.array_form for gen in group.generators]

    visited: set[int] = set()
    result: list[set[int]] = []
    for start in axes:
        if start in visited:
            continue
        if start not in first_slot:
            result.append({start})
            visited.add(start)
            continue
        # Depth-first closure under the generators.
        component = {start}
        stack = [start]
        while stack:
            here = first_slot[stack.pop()]
            for image in images:
                nxt = support[image[here]]
                if nxt not in component:
                    component.add(nxt)
                    stack.append(nxt)
        result.append(component)
        visited |= component
    return result
482
+
483
+
484
+ def _normalize_reps_for_output(reps, *, output_ndim: int) -> tuple[int, ...]:
485
+ """Normalize `reps` arg to a tuple of length `output_ndim`.
486
+
487
+ Matches NumPy.tile's right-alignment rule: if `reps` is shorter, it's
488
+ prepended with 1s. If `reps` is a scalar, treat as `(reps,)`.
489
+ """
490
+ if isinstance(reps, int):
491
+ reps_tup = (reps,)
492
+ else:
493
+ reps_tup = tuple(reps)
494
+ if len(reps_tup) < output_ndim:
495
+ reps_tup = (1,) * (output_ndim - len(reps_tup)) + reps_tup
496
+ return reps_tup
497
+
498
+
499
def broadcast_group(
    group: SymmetryGroup | None,
    *,
    input_shape: tuple[int, ...],
    output_shape: tuple[int, ...],
) -> SymmetryGroup | None:
    """Broadcast a single input symmetry group onto an output shape.

    Two sources of output symmetry are combined:

    * Leading axes created by broadcasting (the first
      ``len(output_shape) - len(input_shape)`` output axes): the data is
      replicated along them, so axes of equal size are interchangeable.
      Each size class of two or more becomes a symmetric factor.
    * ``group`` is kept on the input axes whose size is not changed by the
      broadcast. An axis of length 1 that is broadcast to any other length,
      including 0, is dropped. The subgroup that maps the kept axes onto
      themselves is restricted to them and shifted to output positions.

    Raises:
        ValueError: if the input rank exceeds the output rank.
    """
    if len(input_shape) > len(output_shape):
        raise ValueError("input rank cannot exceed output rank")

    factors: list[SymmetryGroup] = []
    offset = len(output_shape) - len(input_shape)

    created_by_size: OrderedDict[int, list[int]] = OrderedDict()
    for axis in range(offset):
        created_by_size.setdefault(output_shape[axis], []).append(axis)
    for block in created_by_size.values():
        if len(block) >= 2:
            factors.append(SymmetryGroup.symmetric(axes=tuple(block)))

    if group is not None:
        validate_symmetry_group(group, ndim=len(input_shape), shape=input_shape)
        axes = group.axes
        if axes is None:
            axes = tuple(range(group.degree))
        kept_local = []
        for local_idx, axis in enumerate(axes):
            out_axis = axis + offset
            # A stretched length-1 axis (to >1, or to 0) no longer has the
            # same size as the rest of its orbit.
            if input_shape[axis] == 1 and output_shape[out_axis] != 1:
                continue
            kept_local.append(local_idx)
        if len(kept_local) >= 2:
            if len(kept_local) == group.degree:
                restricted = group
            else:
                kept = tuple(kept_local)
                # Axes in one orbit share input size 1 but may broadcast to
                # different sizes, so the dropped axes need not form whole
                # orbits. restrict() must only see permutations that map the
                # kept set to itself, so stabilise it first (as
                # restrict_group_to_axes and reduce_group do).
                restricted = group.setwise_stabilizer(set(kept)).restrict(kept)
            if restricted.order() > 1:
                restricted_axes = (
                    restricted.axes
                    if restricted.axes is not None
                    else tuple(range(restricted.degree))
                )
                remapped = remap_group_axes(
                    restricted,
                    {
                        restricted_axes[new_local_idx]: axes[old_local_idx] + offset
                        for new_local_idx, old_local_idx in enumerate(kept_local)
                    },
                )
                if remapped is not None and remapped.order() > 1:
                    factors.append(remapped)

    return direct_product_groups(*factors)
552
+
553
+
554
def reduce_group(
    group: SymmetryGroup | None,
    *,
    ndim: int,
    axis: int | tuple[int, ...] | None,
    keepdims: bool = False,
) -> SymmetryGroup | None:
    """Propagate a single symmetry group through a reduction.

    Args:
        group: Symmetry of the input tensor, or ``None``.
        ndim: Rank of the input tensor.
        axis: Axis or axes being reduced. Negative values count from the
            end, and NumPy integer scalars are accepted. ``None`` (reduce
            everything) yields ``None``.
        keepdims: Whether reduced axes stay as length-1 axes. When ``False``
            the surviving axes are renumbered contiguously.

    Returns:
        The symmetry of the result on the output axes, or ``None`` if it is
        trivial.

    Raises:
        ValueError: if an entry of ``axis`` lies outside ``[-ndim, ndim)``.
    """
    if group is None or axis is None:
        return None
    validate_symmetry_group(group, ndim=ndim)
    requested = (axis,) if isinstance(axis, (int, np.integer)) else tuple(axis)
    bad = [a for a in requested if not -ndim <= a < ndim]
    if bad:
        # ``a % ndim`` alone would silently wrap e.g. axis=5 on a 3-d tensor
        # to axis 2 and report the wrong symmetry.
        raise ValueError(
            f"reduction axes {tuple(bad)} are out of range for ndim={ndim}"
        )
    axes_set = {a % ndim for a in requested}

    if keepdims:
        old_to_new = {dim: dim for dim in range(ndim)}
    else:
        surviving = [dim for dim in range(ndim) if dim not in axes_set]
        old_to_new = {dim: new_idx for new_idx, dim in enumerate(surviving)}

    # Fast path for known-kind groups: compute the reduced kind directly
    # and route through the matching factory, avoiding group enumeration.
    if group._known_kind is not None:
        reduced_kind = _reduced_kind(
            group._known_kind, reduced_axes=axes_set, axis_map=old_to_new
        )
        if reduced_kind is not None:
            return _build_from_kind(reduced_kind)
        # Otherwise (trivial result or no closed form) fall through to the
        # generic stabilizer path below.

    group_axes = group.axes
    if group_axes is None:
        group_axes = tuple(range(group.degree))
    local_reduced = {
        i for i, tensor_axis in enumerate(group_axes) if tensor_axis in axes_set
    }
    local_kept = [
        i for i, tensor_axis in enumerate(group_axes) if tensor_axis not in axes_set
    ]

    if not local_reduced:
        # Nothing the group acts on is reduced: only renumber its axes.
        remapped = remap_group_axes(
            group,
            {tensor_axis: old_to_new[tensor_axis] for tensor_axis in group_axes},
        )
        return remapped if remapped is not None and remapped.order() > 1 else None
    if not local_kept:
        return None

    # Permutations that keep the reduced axes among themselves survive,
    # acting on the kept axes.
    stabilized = group.setwise_stabilizer(local_reduced)
    restricted = stabilized.restrict(tuple(local_kept))
    if restricted.order() <= 1:
        return None
    restricted_axes = (
        restricted.axes
        if restricted.axes is not None
        else tuple(range(restricted.degree))
    )
    remapped = remap_group_axes(
        restricted,
        {
            restricted_axes[new_local_idx]: old_to_new[group_axes[old_local_idx]]
            for new_local_idx, old_local_idx in enumerate(local_kept)
        },
    )
    return remapped if remapped is not None and remapped.order() > 1 else None
624
+
625
+
626
def wrap_with_symmetry(data, symmetry: SymmetryGroup | None):
    """Return ``data`` as an ndarray, tagged with ``symmetry`` when given.

    With ``symmetry=None`` this is just ``numpy.asarray(data)``. Otherwise
    the group is first validated against the array's rank, and the array is
    then wrapped in a ``SymmetricTensor``, imported locally from
    ``flopscope._symmetric``.
    """
    array = np.asarray(data)
    if symmetry is not None:
        validate_symmetry_group(symmetry, ndim=array.ndim)
        from flopscope._symmetric import SymmetricTensor

        return SymmetricTensor(array, symmetry=symmetry)
    return array
635
+
636
+
637
def wrap_with_trusted_symmetry(data, symmetry: SymmetryGroup | None):
    """Tag ``data`` with symmetry that internal code has already proven.

    Works like :func:`wrap_with_symmetry` but skips the validation call.
    The caller guarantees that ``symmetry`` was produced or re-checked by
    trusted constructor logic. This keeps hot constructor paths cheap while
    user-facing paths stay fully validated. With ``symmetry=None`` the plain
    ``numpy.asarray(data)`` is returned. Internal call sites only.
    """
    array = np.asarray(data)
    if symmetry is not None:
        from flopscope._symmetric import SymmetricTensor

        return SymmetricTensor(array, symmetry=symmetry)
    return array
651
+
652
+
653
def wrap_with_inferred_symmetry(data, symmetry: SymmetryGroup | None):
    """Tag ``data`` with symmetry that flopscope inferred automatically.

    Same as :func:`wrap_with_trusted_symmetry` (no re-validation), except
    that the result also gets ``_symmetry_inferred = True``.
    ``_prepare_symmetric_out`` reads that flag when a non-symmetric ``out=``
    write hits the array. An inferred target is silently downgraded, while
    an explicitly symmetric one raises. Internal call sites only; never
    expose to user code.
    """
    array = np.asarray(data)
    if symmetry is None:
        return array
    from flopscope._symmetric import SymmetricTensor

    wrapped = SymmetricTensor(array, symmetry=symmetry)
    # Record provenance for later ``out=`` handling.
    wrapped._symmetry_inferred = True
    return wrapped
@@ -0,0 +1,12 @@
1
+ """Dtype introspection utilities re-exported as free (0-FLOP) helpers.
2
+
3
+ `finfo` and `iinfo` are metadata queries that return info objects
4
+ (eps, max, min, bits, etc.). They perform no computation.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import numpy as _np
10
+
11
# Re-export NumPy's dtype-info constructors unchanged: they only report
# dtype metadata (eps, max, min, bits, ...) and perform no counted work.
finfo, iinfo = _np.finfo, _np.iinfo