flopscope 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. benchmarks/__init__.py +1 -0
  2. benchmarks/__main__.py +6 -0
  3. benchmarks/_baseline.py +171 -0
  4. benchmarks/_bitwise.py +231 -0
  5. benchmarks/_complex.py +176 -0
  6. benchmarks/_contractions.py +291 -0
  7. benchmarks/_fft.py +198 -0
  8. benchmarks/_impl_urls.py +139 -0
  9. benchmarks/_linalg.py +197 -0
  10. benchmarks/_linalg_delegates.py +407 -0
  11. benchmarks/_metadata.py +141 -0
  12. benchmarks/_misc.py +653 -0
  13. benchmarks/_perf.py +321 -0
  14. benchmarks/_perm_group_calibration.py +175 -0
  15. benchmarks/_pointwise.py +372 -0
  16. benchmarks/_polynomial.py +193 -0
  17. benchmarks/_random.py +209 -0
  18. benchmarks/_reductions.py +136 -0
  19. benchmarks/_sorting.py +289 -0
  20. benchmarks/_stats.py +137 -0
  21. benchmarks/_window.py +92 -0
  22. benchmarks/accumulation/__init__.py +0 -0
  23. benchmarks/accumulation/bench_cost_compute.py +138 -0
  24. benchmarks/dashboard.py +312 -0
  25. benchmarks/runner.py +636 -0
  26. flopscope/__init__.py +273 -0
  27. flopscope/_accumulation/__init__.py +13 -0
  28. flopscope/_accumulation/_bipartite.py +121 -0
  29. flopscope/_accumulation/_burnside.py +51 -0
  30. flopscope/_accumulation/_cache.py +146 -0
  31. flopscope/_accumulation/_components.py +153 -0
  32. flopscope/_accumulation/_cost.py +1414 -0
  33. flopscope/_accumulation/_cost_descriptions.py +63 -0
  34. flopscope/_accumulation/_detection.py +318 -0
  35. flopscope/_accumulation/_ladder.py +191 -0
  36. flopscope/_accumulation/_output_orbit.py +104 -0
  37. flopscope/_accumulation/_partition.py +290 -0
  38. flopscope/_accumulation/_path_info.py +211 -0
  39. flopscope/_accumulation/_public.py +169 -0
  40. flopscope/_accumulation/_reduction.py +310 -0
  41. flopscope/_accumulation/_regimes.py +303 -0
  42. flopscope/_accumulation/_shape.py +33 -0
  43. flopscope/_accumulation/_wreath.py +209 -0
  44. flopscope/_budget.py +1027 -0
  45. flopscope/_config.py +118 -0
  46. flopscope/_counting_ops.py +451 -0
  47. flopscope/_display.py +478 -0
  48. flopscope/_docstrings.py +59 -0
  49. flopscope/_dtypes.py +20 -0
  50. flopscope/_einsum.py +717 -0
  51. flopscope/_errstate.py +25 -0
  52. flopscope/_flops.py +282 -0
  53. flopscope/_free_ops.py +2654 -0
  54. flopscope/_ndarray.py +1126 -0
  55. flopscope/_opt_einsum/LICENSE +21 -0
  56. flopscope/_opt_einsum/NOTICE +59 -0
  57. flopscope/_opt_einsum/__init__.py +209 -0
  58. flopscope/_opt_einsum/_contract.py +1478 -0
  59. flopscope/_opt_einsum/_helpers.py +164 -0
  60. flopscope/_opt_einsum/_hsluv.py +273 -0
  61. flopscope/_opt_einsum/_path_random.py +462 -0
  62. flopscope/_opt_einsum/_paths.py +1653 -0
  63. flopscope/_opt_einsum/_subgraph_symmetry.py +544 -0
  64. flopscope/_opt_einsum/_symmetry.py +140 -0
  65. flopscope/_opt_einsum/_typing.py +37 -0
  66. flopscope/_perm_group.py +717 -0
  67. flopscope/_pointwise.py +2522 -0
  68. flopscope/_polynomial.py +278 -0
  69. flopscope/_registry.py +3216 -0
  70. flopscope/_sorting_ops.py +571 -0
  71. flopscope/_symmetric.py +812 -0
  72. flopscope/_symmetry_transport.py +510 -0
  73. flopscope/_symmetry_utils.py +669 -0
  74. flopscope/_type_info.py +12 -0
  75. flopscope/_unwrap.py +70 -0
  76. flopscope/_validation.py +83 -0
  77. flopscope/_version_check.py +46 -0
  78. flopscope/_weights.py +195 -0
  79. flopscope/_window.py +177 -0
  80. flopscope/accounting.py +565 -0
  81. flopscope/data/default_weights.json +462 -0
  82. flopscope/data/weights.csv +509 -0
  83. flopscope/errors.py +197 -0
  84. flopscope/numpy/__init__.py +878 -0
  85. flopscope/numpy/fft/__init__.py +55 -0
  86. flopscope/numpy/fft/_free.py +51 -0
  87. flopscope/numpy/fft/_transforms.py +695 -0
  88. flopscope/numpy/linalg/__init__.py +105 -0
  89. flopscope/numpy/linalg/_aliases.py +126 -0
  90. flopscope/numpy/linalg/_compound.py +161 -0
  91. flopscope/numpy/linalg/_decompositions.py +353 -0
  92. flopscope/numpy/linalg/_properties.py +533 -0
  93. flopscope/numpy/linalg/_solvers.py +444 -0
  94. flopscope/numpy/linalg/_svd.py +122 -0
  95. flopscope/numpy/random/__init__.py +684 -0
  96. flopscope/numpy/random/_cost_formulas.py +115 -0
  97. flopscope/numpy/random/_counted_classes.py +241 -0
  98. flopscope/numpy/testing/__init__.py +13 -0
  99. flopscope/numpy/typing/__init__.py +30 -0
  100. flopscope/py.typed +0 -0
  101. flopscope/stats/__init__.py +84 -0
  102. flopscope/stats/_base.py +77 -0
  103. flopscope/stats/_cauchy.py +146 -0
  104. flopscope/stats/_erf.py +190 -0
  105. flopscope/stats/_expon.py +146 -0
  106. flopscope/stats/_laplace.py +150 -0
  107. flopscope/stats/_logistic.py +148 -0
  108. flopscope/stats/_lognorm.py +160 -0
  109. flopscope/stats/_ndtri.py +133 -0
  110. flopscope/stats/_norm.py +149 -0
  111. flopscope/stats/_truncnorm.py +186 -0
  112. flopscope/stats/_uniform.py +141 -0
  113. flopscope-0.2.0.dist-info/METADATA +23 -0
  114. flopscope-0.2.0.dist-info/RECORD +115 -0
  115. flopscope-0.2.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,717 @@
1
+ """Symmetry groups for exact tensor symmetry representation.
2
+
3
+ Provides the public ``SymmetryGroup`` API plus private permutation helper
4
+ objects used internally for exact finite-group algorithms.
5
+
6
+ Core algorithms:
7
+ - Dimino's algorithm for group enumeration from generators
8
+ (Butler & McKay, Comm. in Algebra, 1983)
9
+ - Burnside's lemma for orbit counting (Burnside, 1897)
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import math
15
+ import weakref
16
+ from functools import reduce
17
+ from typing import Any
18
+
19
__all__ = ["SymmetryGroup"]

# Process-wide registry consulted by ``SymmetryGroup._intern``. Values are
# held weakly, so an interned group is dropped from the registry once no
# caller references it any more.
_GROUP_INTERN: weakref.WeakValueDictionary[tuple, SymmetryGroup] = (
    weakref.WeakValueDictionary()
)
24
+
25
+
26
+ class _Cycle:
27
+ """Composable private cycle builder."""
28
+
29
+ __slots__ = ("_mapping",)
30
+
31
+ def __init__(self, *cycle: int) -> None:
32
+ self._mapping: dict[int, int] = {}
33
+ if cycle:
34
+ for i in range(len(cycle)):
35
+ self._mapping[cycle[i]] = cycle[(i + 1) % len(cycle)]
36
+
37
+ def __call__(self, *cycle: int) -> _Cycle:
38
+ new = _Cycle()
39
+ new._mapping = dict(self._mapping)
40
+ if cycle:
41
+ new_cycle_map: dict[int, int] = {}
42
+ for i in range(len(cycle)):
43
+ new_cycle_map[cycle[i]] = cycle[(i + 1) % len(cycle)]
44
+ combined: dict[int, int] = {}
45
+ all_points = set(new._mapping) | set(new_cycle_map)
46
+ for x in all_points:
47
+ y = new._mapping.get(x, x)
48
+ z = new_cycle_map.get(y, y)
49
+ if z != x:
50
+ combined[x] = z
51
+ new._mapping = combined
52
+ return new
53
+
54
+ def list(self, size: int | None = None) -> list[int]:
55
+ if size is None:
56
+ size = max(self._mapping.keys(), default=-1) + 1
57
+ size = max(size, max(self._mapping.values(), default=-1) + 1)
58
+ arr = list(range(size))
59
+ for k, v in self._mapping.items():
60
+ if k < size:
61
+ arr[k] = v
62
+ return arr
63
+
64
+
65
class _Permutation:
    """A private permutation of ``{0, 1, ..., n-1}`` in array form.

    ``array_form[i]`` is the image of ``i``. The images are stored as a
    tuple, so instances are hashable and two permutations compare equal
    exactly when their array forms are equal. Composition follows
    ``(p * q)(i) == p(q(i))``, i.e. ``q`` is applied first.
    """

    __slots__ = ("_array_form",)

    def __init__(
        self,
        array_form: list[int] | tuple[int, ...] | _Cycle,
        size: int | None = None,
    ) -> None:
        """Build a permutation from an array form, a cycle list or a ``_Cycle``.

        Args:
            array_form: Either
                * a ``_Cycle``, expanded with ``_Cycle.list(size)``;
                * a non-empty sequence whose first item is a list/tuple,
                  read as cycles composed left to right through ``_Cycle``;
                * a plain array form ``[p(0), ..., p(n-1)]`` (not validated
                  as a bijection here).
            size: Optional degree hint. A plain array form shorter than
                ``size`` is padded with fixed points; for the cycle inputs
                the value is forwarded to ``_Cycle.list``.
        """
        if isinstance(array_form, _Cycle):
            self._array_form = tuple(array_form.list(size))
        elif array_form and isinstance(array_form[0], (list, tuple)):
            composed = _Cycle()
            for cycle in array_form:
                composed = composed(*cycle)  # type: ignore[call-arg]
            self._array_form = tuple(composed.list(size))
        else:
            arr = list(array_form)
            if size is not None and size > len(arr):
                # Pad with fixed points up to the requested degree.
                arr.extend(range(len(arr), size))
            self._array_form = tuple(arr)

    @property
    def size(self) -> int:
        """Degree ``n`` of the permutation."""
        return len(self._array_form)

    @property
    def array_form(self) -> list[int]:
        """A fresh list ``[p(0), ..., p(n-1)]``."""
        return list(self._array_form)

    @property
    def is_identity(self) -> bool:
        """Whether every point is mapped to itself."""
        return all(image == i for i, image in enumerate(self._array_form))

    @classmethod
    def identity(cls, size: int) -> _Permutation:
        """The identity permutation of degree ``size``."""
        return cls(list(range(size)))

    @classmethod
    def from_cycle(cls, size: int, cycle: list[int]) -> _Permutation:
        """Permutation of degree ``size`` made of the single cycle ``cycle``."""
        arr = list(range(size))
        for i, point in enumerate(cycle):
            arr[point] = cycle[(i + 1) % len(cycle)]
        return cls(arr)

    def __mul__(self, other: _Permutation) -> _Permutation:
        """Return the composition ``i -> self(other(i))``.

        Raises:
            ValueError: if ``self`` and ``other`` have different sizes. Without
                this check a larger ``other`` would be silently truncated and
                a smaller one would fail with an opaque IndexError.
        """
        if other.size != self.size:
            raise ValueError(
                f"Cannot compose permutations of different sizes "
                f"({self.size} and {other.size})"
            )
        images = self._array_form
        return _Permutation(tuple(images[j] for j in other._array_form))

    def __invert__(self) -> _Permutation:
        """Return the inverse permutation."""
        inv = [0] * self.size
        for i, j in enumerate(self._array_form):
            inv[j] = i
        return _Permutation(inv)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _Permutation):
            return NotImplemented
        return self._array_form == other._array_form

    def __hash__(self) -> int:
        return hash(self._array_form)

    def __repr__(self) -> str:
        return f"_Permutation({list(self._array_form)})"

    @property
    def cyclic_form(self) -> list[tuple[int, ...]]:
        """Non-trivial cycles (fixed points omitted), each starting at its smallest point."""
        visited: set[int] = set()
        cycles: list[tuple[int, ...]] = []
        for i in range(self.size):
            if i in visited or self._array_form[i] == i:
                visited.add(i)
                continue
            cycle: list[int] = []
            j = i
            while j not in visited:
                cycle.append(j)
                visited.add(j)
                j = self._array_form[j]
            cycles.append(tuple(cycle))
        return cycles

    @property
    def full_cyclic_form(self) -> list[tuple[int, ...]]:
        """All cycles, including fixed points as 1-cycles."""
        visited: set[int] = set()
        cycles: list[tuple[int, ...]] = []
        for i in range(self.size):
            if i in visited:
                continue
            cycle: list[int] = []
            j = i
            while j not in visited:
                cycle.append(j)
                visited.add(j)
                j = self._array_form[j]
            cycles.append(tuple(cycle))
        return cycles

    @property
    def cycle_structure(self) -> dict[int, int]:
        """Map ``cycle length -> number of cycles`` over non-trivial cycles."""
        result: dict[int, int] = {}
        for cycle in self.cyclic_form:
            length = len(cycle)
            result[length] = result.get(length, 0) + 1
        return result

    @property
    def order(self) -> int:
        """Smallest ``k >= 1`` with ``p**k == identity`` (lcm of cycle lengths)."""
        lengths = [len(c) for c in self.full_cyclic_form]
        if not lengths:
            return 1
        return reduce(lambda a, b: a * b // math.gcd(a, b), lengths)

    def __call__(self, i: int) -> int:
        """Image of point ``i``."""
        return self._array_form[i]

    def support(self) -> set[int]:
        """Set of points moved by the permutation."""
        return {i for i in range(self.size) if self._array_form[i] != i}

    def parity(self) -> int:
        """0 for an even permutation, 1 for an odd one."""
        return sum(len(c) - 1 for c in self.cyclic_form) % 2

    def signature(self) -> int:
        """+1 for an even permutation, -1 for an odd one."""
        return 1 if self.parity() == 0 else -1

    def transpositions(self) -> list[tuple[int, int]]:
        """Each cycle ``(c0 c1 ... ck)`` split into pairs ``(c0, ci)``."""
        result: list[tuple[int, int]] = []
        for cycle in self.cyclic_form:
            for i in range(1, len(cycle)):
                result.append((cycle[0], cycle[i]))
        return result

    def as_sympy(self):
        """Convert to ``sympy.combinatorics.Permutation`` (requires sympy)."""
        try:
            from sympy.combinatorics import Permutation as SPermutation
        except ImportError:
            raise ImportError(
                "sympy is required for as_sympy(). Install with: pip install sympy"
            ) from None
        return SPermutation(self.array_form)

    @classmethod
    def from_sympy(cls, sp) -> _Permutation:
        """Build from any object exposing an ``array_form`` (e.g. a sympy Permutation)."""
        return cls(sp.array_form)
212
+
213
+
214
+ def _normalize_axes(axes: tuple[Any, ...] | list[Any]) -> tuple[Any, ...]:
215
+ norm_axes = tuple(axes)
216
+ if not norm_axes:
217
+ raise ValueError("axes must be non-empty")
218
+ if len(set(norm_axes)) != len(norm_axes):
219
+ raise ValueError("axes must be unique")
220
+ return norm_axes
221
+
222
+
223
+ def _normalize_generator_literal(
224
+ generator: list[int] | tuple[int, ...], *, degree: int
225
+ ) -> _Permutation:
226
+ arr = list(generator)
227
+ if len(arr) != degree:
228
+ raise ValueError(
229
+ f"Generator literal has degree {len(arr)}, expected degree {degree}"
230
+ )
231
+ if sorted(arr) != list(range(degree)):
232
+ raise ValueError(
233
+ f"Generator literal {arr} is not a bijection on range({degree})"
234
+ )
235
+ return _Permutation(arr)
236
+
237
+
238
+ def _closed_form_order(kind: tuple) -> int:
239
+ """Compute ``|G|`` for a known-kind structural fingerprint.
240
+
241
+ The tag layout is recursive: leaf kinds are ``(name, axes_tuple)``
242
+ where ``name`` is one of ``"identity" | "symmetric" | "cyclic" |
243
+ "dihedral"``; ``"direct_product"`` carries ``(name, children_tuple)``
244
+ where each child is itself a kind tuple.
245
+ """
246
+ name = kind[0]
247
+ if name == "identity":
248
+ return 1
249
+ if name == "symmetric":
250
+ return math.factorial(len(kind[1]))
251
+ if name == "cyclic":
252
+ return len(kind[1])
253
+ if name == "dihedral":
254
+ return 2 * len(kind[1])
255
+ if name == "direct_product":
256
+ return math.prod(_closed_form_order(child) for child in kind[1])
257
+ raise AssertionError(f"unknown kind {kind!r}")
258
+
259
+
260
class SymmetryGroup:
    """A finite permutation group defined by explicit generators.

    All generators are ``_Permutation`` objects of one common size, the
    group's *degree*. Each position ``0 .. degree-1`` may optionally be
    labelled by an entry of ``axes``. The element list and the order are
    computed lazily and cached on the instance.

    Equality and hashing compare the *canonical axis action*: every element
    written as a mapping on the semantic domain (labels, else axes, else
    positions), with the domain sorted by ``repr`` and the actions sorted.
    Two groups built from different generators but acting identically
    therefore compare equal.
    """

    __slots__ = (
        "__weakref__",
        "_generators",
        "_degree",
        "_axes",
        "_elements",
        "_order",
        "_labels",
        "_canonical_action_cache",
        "_known_kind",
    )

    def __init__(
        self,
        *generators: _Permutation,
        axes: tuple[Any, ...] | None = None,
    ) -> None:
        """Create a group from one or more same-size generators.

        Args:
            *generators: Permutations generating the group; their common
                ``size`` becomes the degree.
            axes: Optional labels, one per position, normalized with
                ``_normalize_axes`` (non-empty, unique).

        Raises:
            ValueError: if no generator is given, generator sizes differ,
                or ``axes`` is invalid or does not match the degree.
        """
        if not generators:
            raise ValueError(
                "At least one generator required (use _Permutation.identity(n) for the trivial group)"
            )
        degrees = {g.size for g in generators}
        if len(degrees) != 1:
            raise ValueError(f"All generators must have the same size, got {degrees}")
        self._generators = generators
        self._degree = generators[0].size
        if axes is not None:
            axes = _normalize_axes(axes)
            if len(axes) != self._degree:
                raise ValueError(
                    f"axes has length {len(axes)}, expected {self._degree}"
                )
        self._axes = axes
        # Lazily filled by ``elements()`` / ``order()``.
        self._elements: list[_Permutation] | None = None
        self._order: int | None = None
        # Structural fingerprint set by the named constructors; enables the
        # closed-form ``order()`` and interning. ``None`` for ad-hoc groups.
        self._known_kind: tuple | None = None
        # Optional labels that take precedence over ``axes`` in
        # ``_semantic_domain``. NOTE(review): not assigned in this module;
        # presumably set by callers elsewhere.
        self._labels: tuple[str, ...] | None = None
        # (labels the action was computed for, canonical action).
        self._canonical_action_cache: (
            tuple[
                tuple[str, ...] | None,
                tuple[tuple[Any, ...], tuple[tuple[Any, ...], ...]],
            ]
            | None
        ) = None

    @classmethod
    def from_generators(
        cls,
        generators: list[list[int] | tuple[int, ...]]
        | tuple[list[int] | tuple[int, ...], ...],
        *,
        axes: tuple[Any, ...] | list[Any],
    ) -> SymmetryGroup:
        """Build a group from array-form generator literals over ``axes``.

        Raises:
            ValueError: for invalid axes, no literals, or an invalid literal.
        """
        norm_axes = _normalize_axes(axes)
        if not generators:
            raise ValueError("At least one generator literal is required")
        norm_generators = tuple(
            _normalize_generator_literal(generator, degree=len(norm_axes))
            for generator in generators
        )
        return cls(*norm_generators, axes=norm_axes)

    @property
    def degree(self) -> int:
        """Number of positions the group acts on."""
        return self._degree

    @property
    def generators(self) -> list[_Permutation]:
        """A new list of the generators."""
        return list(self._generators)

    @property
    def generator_literals(self) -> list[list[int]]:
        """Generators as array-form lists."""
        return [generator.array_form for generator in self._generators]

    @property
    def axes(self) -> tuple[Any, ...] | None:
        """Position labels, or ``None`` when the group has none."""
        return self._axes

    def to_payload(self) -> dict[str, list[Any] | list[list[int]]]:
        """Serialize to ``{"axes": [...], "generators": [[...], ...]}``.

        Raises:
            ValueError: if the group has no axes.
        """
        if self._axes is None:
            raise ValueError("Cannot serialize a SymmetryGroup without axes")
        return {"axes": list(self._axes), "generators": self.generator_literals}

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> SymmetryGroup:
        """Inverse of ``to_payload``."""
        return cls.from_generators(payload["generators"], axes=tuple(payload["axes"]))

    def elements(self) -> list[_Permutation]:
        """Return all group elements, enumerated once with ``_dimino`` and cached.

        The cached list itself is returned (not a copy). ``_dimino`` may raise
        ``_DiminoBudgetExceeded`` for groups too large to enumerate.
        """
        if self._elements is not None:
            return self._elements
        self._elements = _dimino(self._generators)
        self._order = len(self._elements)
        return self._elements

    def order(self) -> int:
        """Return ``|G|``, using the closed form when the kind is known."""
        if self._order is not None:
            return self._order
        if self._known_kind is not None:
            self._order = _closed_form_order(self._known_kind)
            return self._order
        self._order = len(self.elements())
        return self._order

    def is_symmetric(self) -> bool:
        """Whether the group is the full symmetric group on its positions."""
        return self.order() == math.factorial(self._degree)

    def orbits(self) -> list[frozenset[int]]:
        """Partition positions into orbits (union-find over generator moves).

        Orbits are listed in order of their smallest position.
        """
        parent = list(range(self._degree))

        def find(x: int) -> int:
            while parent[x] != x:
                parent[x] = parent[parent[x]]  # path halving
                x = parent[x]
            return x

        def union(a: int, b: int) -> None:
            ra, rb = find(a), find(b)
            if ra != rb:
                parent[ra] = rb

        for g in self._generators:
            for i in range(self._degree):
                if g(i) != i:
                    union(i, g(i))

        groups: dict[int, set[int]] = {}
        for i in range(self._degree):
            groups.setdefault(find(i), set()).add(i)
        return [frozenset(s) for s in groups.values()]

    def contains(self, perm: _Permutation) -> bool:
        """Whether ``perm`` is an element of the group (enumerates elements)."""
        return perm in set(self.elements())

    @property
    def is_transitive(self) -> bool:
        """Whether all positions lie in a single orbit."""
        return len(self.orbits()) == 1

    @property
    def is_abelian(self) -> bool:
        """Whether all pairs of generators commute (equivalent to abelian)."""
        gens = self._generators
        for i, a in enumerate(gens):
            for b in gens[i + 1 :]:
                if a * b != b * a:
                    return False
        return True

    @property
    def identity(self) -> _Permutation:
        """Identity permutation of the group's degree."""
        return _Permutation.identity(self._degree)

    def _semantic_domain(self) -> tuple[Any, ...]:
        """Names for positions: labels, else axes, else ``0..degree-1``."""
        if self._labels is not None:
            return tuple(self._labels)
        if self._axes is not None:
            return self._axes
        return tuple(range(self._degree))

    def _canonical_axis_action(
        self,
    ) -> tuple[tuple[Any, ...], tuple[tuple[Any, ...], ...]]:
        """Return ``(sorted domain, sorted element actions)`` used for eq/hash.

        The cache is reused only while ``_labels`` is the same object it was
        computed for.
        """
        cached = self._canonical_action_cache
        if cached is not None and cached[0] is self._labels:
            return cached[1]
        domain = self._semantic_domain()
        labelled_axes = tuple(sorted(domain, key=repr))
        actions = []
        for elem in self.elements():
            mapping = {domain[i]: domain[j] for i, j in enumerate(elem.array_form)}
            actions.append(tuple(mapping[axis] for axis in labelled_axes))
        result = (labelled_axes, tuple(sorted(actions)))
        self._canonical_action_cache = (self._labels, result)
        return result

    def __eq__(self, other: object) -> bool:
        if self is other:
            return True
        if not isinstance(other, SymmetryGroup):
            return NotImplemented
        return self._canonical_axis_action() == other._canonical_axis_action()

    def __hash__(self) -> int:
        # Enumerates all elements on first use.
        return hash(self._canonical_axis_action())

    def equals(self, other: SymmetryGroup) -> bool:
        """Alias for ``self == other``."""
        return self == other

    def orbit(self, alpha: int) -> frozenset[int]:
        """Orbit of position ``alpha`` under the generators."""
        visited: set[int] = {alpha}
        queue: list[int] = [alpha]
        while queue:
            point = queue.pop()
            for g in self._generators:
                image = g(point)
                if image not in visited:
                    visited.add(image)
                    queue.append(image)
        return frozenset(visited)

    def pointwise_stabilizer(self, fixed: set[int]) -> SymmetryGroup:
        """Subgroup of elements fixing every position in ``fixed``.

        The result uses all surviving elements as generators.
        """
        if not fixed:
            return SymmetryGroup(*self._generators, axes=self._axes)
        surviving = [g for g in self.elements() if all(g(p) == p for p in fixed)]
        if not surviving:
            surviving = [_Permutation.identity(self._degree)]
        return SymmetryGroup(*surviving, axes=self._axes)

    def setwise_stabilizer(self, subset: set[int]) -> SymmetryGroup:
        """Subgroup of elements mapping ``subset`` onto itself."""
        if not subset or subset == set(range(self._degree)):
            return SymmetryGroup(*self._generators, axes=self._axes)
        frozen = frozenset(subset)
        surviving = [
            g for g in self.elements() if frozenset(g(x) for x in frozen) == frozen
        ]
        if not surviving:
            surviving = [_Permutation.identity(self._degree)]
        return SymmetryGroup(*surviving, axes=self._axes)

    def restrict(self, kept: tuple[int, ...]) -> SymmetryGroup:
        """Project the group onto the positions ``kept`` (renumbered 0..len-1).

        Raises:
            ValueError: if ``kept`` is empty.
            KeyError: if some element maps a kept position outside ``kept``.
        """
        new_degree = len(kept)
        if new_degree == 0:
            raise ValueError("kept must be non-empty")

        old_to_new = {old: new for new, old in enumerate(kept)}
        projected: set[_Permutation] = set()
        for g in self.elements():
            # A KeyError here means ``kept`` is not closed under the action.
            new_arr = [old_to_new[g(k)] for k in kept]
            projected.add(_Permutation(new_arr))

        new_axes: tuple[Any, ...] | None = None
        if self._axes is not None:
            new_axes = tuple(self._axes[k] for k in kept)

        gens = list(projected) if projected else [_Permutation.identity(new_degree)]
        return SymmetryGroup(*gens, axes=new_axes)

    def burnside_unique_count(self, size_dict: dict[int, int]) -> int:
        """Count orbits of index tuples via Burnside's lemma.

        Args:
            size_dict: Dimension size for each position ``0..degree-1``.

        Raises:
            ValueError: if two positions in one orbit have different sizes.
        """
        for orbit in self.orbits():
            sizes = {size_dict[i] for i in orbit}
            if len(sizes) != 1:
                raise ValueError(
                    f"Positions {orbit} are in the same orbit but have different "
                    f"dimension sizes {sizes}; all must have the same dimension size"
                )

        total_fixed = 0
        for g in self.elements():
            # Tuples fixed by g are constant on each cycle of g.
            fixed = 1
            for cycle in g.full_cyclic_form:
                fixed *= size_dict[cycle[0]]
            total_fixed += fixed

        count, remainder = divmod(total_fixed, self.order())
        assert remainder == 0, (
            f"Burnside sum {total_fixed} not divisible by |G|={self.order()}"
        )
        return count

    @classmethod
    def symmetric(cls, *, axes: tuple[Any, ...] | list[Any]) -> SymmetryGroup:
        """Full symmetric group on ``axes`` (adjacent transpositions), interned."""
        norm_axes = _normalize_axes(axes)
        k = len(norm_axes)
        if k == 1:
            g = cls(_Permutation.identity(1), axes=norm_axes)
            g._known_kind = ("identity", norm_axes)
            return cls._intern(g)
        gens = []
        for i in range(k - 1):
            arr = list(range(k))
            arr[i], arr[i + 1] = arr[i + 1], arr[i]
            gens.append(_Permutation(arr))
        g = cls(*gens, axes=norm_axes)
        g._known_kind = ("symmetric", norm_axes)
        return cls._intern(g)

    @classmethod
    def cyclic(cls, *, axes: tuple[Any, ...] | list[Any]) -> SymmetryGroup:
        """Cyclic group generated by the rotation ``i -> i+1 mod k``, interned."""
        norm_axes = _normalize_axes(axes)
        k = len(norm_axes)
        if k == 1:
            g = cls(_Permutation.identity(1), axes=norm_axes)
            g._known_kind = ("identity", norm_axes)
            return cls._intern(g)
        gen = _Permutation(list(range(1, k)) + [0])
        g = cls(gen, axes=norm_axes)
        g._known_kind = ("cyclic", norm_axes)
        return cls._intern(g)

    @classmethod
    def dihedral(cls, *, axes: tuple[Any, ...] | list[Any]) -> SymmetryGroup:
        """Dihedral group (rotation plus reflection ``i -> -i mod k``), interned.

        For ``k <= 2`` this is the symmetric group.
        """
        norm_axes = _normalize_axes(axes)
        k = len(norm_axes)
        if k <= 2:
            return cls.symmetric(axes=norm_axes)
        rotation = _Permutation(list(range(1, k)) + [0])
        reflection = _Permutation([0] + list(range(k - 1, 0, -1)))
        g = cls(rotation, reflection, axes=norm_axes)
        g._known_kind = ("dihedral", norm_axes)
        return cls._intern(g)

    @classmethod
    def young(
        cls,
        blocks: list[tuple[Any, ...] | list[Any]]
        | tuple[tuple[Any, ...] | list[Any], ...],
    ) -> SymmetryGroup:
        """Direct product of symmetric groups, one per block of axes.

        Raises:
            ValueError: if ``blocks`` is empty (or a block is invalid).
        """
        factors = [cls.symmetric(axes=tuple(block)) for block in blocks]
        if not factors:
            raise ValueError("young() requires at least one block")
        return cls.direct_product(*factors)

    @classmethod
    def direct_product(cls, *groups: SymmetryGroup) -> SymmetryGroup:
        """Direct product of groups with pairwise-disjoint axes.

        Positions are concatenated in argument order. If every factor has a
        known kind, the result gets a ``"direct_product"`` kind whose
        children are sorted by ``repr`` (so it does not encode factor order).

        Raises:
            ValueError: if no group is given, a factor lacks axes, or axes
                overlap between factors.
        """
        if not groups:
            raise ValueError("direct_product() requires at least one group")
        supports = []
        for group in groups:
            if group.axes is None:
                raise ValueError(
                    "SymmetryGroup.direct_product() requires axes on every factor"
                )
            supports.append(set(group.axes))
        for i, support in enumerate(supports):
            for other in supports[i + 1 :]:
                if support & other:
                    raise ValueError(
                        "SymmetryGroup.direct_product() requires disjoint supports"
                    )

        merged_axes: list[Any] = []
        total_degree = sum(group.degree for group in groups)
        generators: list[_Permutation] = []
        offset = 0
        for group in groups:
            assert group.axes is not None
            merged_axes.extend(group.axes)
            for gen in group.generators:
                # Embed the factor's generator into its block of positions.
                arr = list(range(total_degree))
                for i, j in enumerate(gen.array_form):
                    arr[offset + i] = offset + j
                generators.append(_Permutation(arr))
            offset += group.degree

        if not generators:
            generators.append(_Permutation.identity(total_degree))
        g = cls(*generators, axes=tuple(merged_axes))
        child_kinds = tuple(group._known_kind for group in groups)
        if all(kind is not None for kind in child_kinds):
            g._known_kind = ("direct_product", tuple(sorted(child_kinds, key=repr)))
        return cls._intern(g)

    @classmethod
    def _intern(cls, group: SymmetryGroup) -> SymmetryGroup:
        """Return the canonical instance for ``group``'s known kind and axes.

        Unknown-kind groups (``_known_kind is None``) are returned as-is
        without registry interaction. Known-kind groups participate in
        process-wide interning: the first construction wins the registry
        slot, and later equivalent constructions return the same Python
        object, so caches (``_order``, ``_elements``,
        ``_canonical_action_cache``) are shared.

        The registry key is ``(_known_kind, axes)``, not ``_known_kind``
        alone. ``direct_product`` sorts its child kinds, so the kind does
        not record axis order. Keying on the kind alone let
        ``direct_product(B, A)`` return an earlier ``direct_product(A, B)``
        whose positions were laid out in a different order than requested.
        """
        if group._known_kind is None:
            return group
        key = (group._known_kind, group._axes)
        existing = _GROUP_INTERN.get(key)
        if existing is not None:
            return existing
        _GROUP_INTERN[key] = group
        return group

    def as_sympy(self):
        """Convert to ``sympy.combinatorics.PermutationGroup`` (requires sympy)."""
        try:
            from sympy import combinatorics as _sympy_combinatorics
        except ImportError:
            raise ImportError(
                "sympy is required for as_sympy(). Install with: pip install sympy"
            ) from None
        sympy_group_cls = _sympy_combinatorics.PermutationGroup
        return sympy_group_cls(*[g.as_sympy() for g in self._generators])

    @classmethod
    def from_sympy(cls, spg, *, axes: tuple[Any, ...] | None = None) -> SymmetryGroup:
        """Build from a sympy ``PermutationGroup`` (its generators' array forms)."""
        gens = [_Permutation.from_sympy(g) for g in spg.generators]
        return cls(*gens, axes=axes)

    def __repr__(self) -> str:
        axes_str = f", axes={self._axes}" if self._axes is not None else ""
        literals = ", ".join(repr(g.array_form) for g in self._generators)
        return f"SymmetryGroup({literals}{axes_str})"
652
+
653
+
654
# Alternate names bound to the three classes defined above.
# NOTE(review): presumably kept so older imports of the ``*Compat`` names
# keep working; confirm with callers before removing.
_CycleCompat, _PermutationCompat, _SymmetryGroupCompat = (
    _Cycle,
    _Permutation,
    SymmetryGroup,
)
657
+
658
+
659
+ class _DiminoBudgetExceeded(Exception):
660
+ """Raised when _dimino exceeds the configured dimino_budget.
661
+
662
+ Callers should catch this and fall back to a dense / non-symmetry-aware
663
+ cost, emitting CostFallbackWarning.
664
+ """
665
+
666
+ def __init__(self, seen_count: int, budget: int) -> None:
667
+ super().__init__(
668
+ f"Dimino enumeration exceeded budget: visited {seen_count} elements "
669
+ f"(budget={budget}). Group is likely too large to enumerate exactly."
670
+ )
671
+ self.seen_count = seen_count
672
+ self.budget = budget
673
+
674
+
675
def _dimino(generators: tuple[_Permutation, ...]) -> list[_Permutation]:
    """Enumerate every element of the group generated by ``generators``.

    The implementation is a breadth-first closure rather than the textbook
    coset-by-coset Dimino construction. For each generator not yet seen, it
    repeatedly multiplies newly found elements by every generator, on the
    right and then on the left, until nothing new appears. The identity
    comes first in the result and no element is repeated.

    The configured ``dimino_budget`` setting bounds the work. Every element
    inserted during enumeration is counted, including each generator that
    starts a new search; the identity is always present and is not checked.
    As soon as the number of distinct elements exceeds the budget,
    :class:`_DiminoBudgetExceeded` is raised. Callers should catch it and
    fall back to a dense (no-symmetry) cost via
    :class:`flopscope.errors.CostFallbackWarning`.

    Raises:
        ValueError: if ``generators`` is empty.
        _DiminoBudgetExceeded: if the budget is exceeded.
    """
    from flopscope._config import get_setting

    if not generators:
        raise ValueError("_dimino() requires at least one generator")

    budget = int(get_setting("dimino_budget"))  # type: ignore[arg-type]
    identity = _Permutation.identity(generators[0].size)
    elements: list[_Permutation] = [identity]
    seen: set[_Permutation] = {identity}

    def record(perm: _Permutation) -> None:
        # Single insertion point so the budget check can never be skipped.
        seen.add(perm)
        if len(seen) > budget:
            raise _DiminoBudgetExceeded(len(seen), budget)

    for gen in generators:
        if gen in seen:
            continue
        record(gen)
        coset = [gen]
        frontier = [gen]
        while frontier:
            discovered: list[_Permutation] = []
            for elem in frontier:
                for g in generators:
                    for product in (elem * g, g * elem):
                        if product not in seen:
                            record(product)
                            discovered.append(product)
            frontier = discovered
            coset.extend(discovered)
        elements.extend(coset)

    return elements