schubmult 2.0.4__py3-none-any.whl → 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. schubmult/__init__.py +96 -1
  2. schubmult/perm_lib.py +254 -819
  3. schubmult/poly_lib/__init__.py +31 -0
  4. schubmult/poly_lib/poly_lib.py +276 -0
  5. schubmult/poly_lib/schub_poly.py +148 -0
  6. schubmult/poly_lib/variables.py +204 -0
  7. schubmult/rings/__init__.py +18 -0
  8. schubmult/rings/_quantum_schubert_polynomial_ring.py +752 -0
  9. schubmult/rings/_schubert_polynomial_ring.py +1031 -0
  10. schubmult/rings/_tensor_schub_ring.py +128 -0
  11. schubmult/rings/_utils.py +55 -0
  12. schubmult/{sage_integration → sage}/__init__.py +4 -1
  13. schubmult/{sage_integration → sage}/_fast_double_schubert_polynomial_ring.py +67 -109
  14. schubmult/{sage_integration → sage}/_fast_schubert_polynomial_ring.py +33 -28
  15. schubmult/{sage_integration → sage}/_indexing.py +9 -5
  16. schubmult/schub_lib/__init__.py +51 -0
  17. schubmult/{schubmult_double/_funcs.py → schub_lib/double.py} +532 -596
  18. schubmult/{schubmult_q/_funcs.py → schub_lib/quantum.py} +54 -53
  19. schubmult/schub_lib/quantum_double.py +954 -0
  20. schubmult/schub_lib/schub_lib.py +659 -0
  21. schubmult/{schubmult_py/_funcs.py → schub_lib/single.py} +45 -35
  22. schubmult/schub_lib/tests/__init__.py +0 -0
  23. schubmult/schub_lib/tests/legacy_perm_lib.py +946 -0
  24. schubmult/schub_lib/tests/test_vs_old.py +109 -0
  25. schubmult/scripts/__init__.py +0 -0
  26. schubmult/scripts/schubmult_double.py +378 -0
  27. schubmult/scripts/schubmult_py.py +84 -0
  28. schubmult/scripts/schubmult_q.py +109 -0
  29. schubmult/scripts/schubmult_q_double.py +207 -0
  30. schubmult/utils/__init__.py +0 -0
  31. schubmult/{_base_argparse.py → utils/argparse.py} +29 -5
  32. schubmult/utils/logging.py +16 -0
  33. schubmult/utils/parsing.py +20 -0
  34. schubmult/utils/perm_utils.py +135 -0
  35. schubmult/utils/test_utils.py +65 -0
  36. schubmult-3.0.1.dist-info/METADATA +1234 -0
  37. schubmult-3.0.1.dist-info/RECORD +41 -0
  38. {schubmult-2.0.4.dist-info → schubmult-3.0.1.dist-info}/WHEEL +1 -1
  39. schubmult-3.0.1.dist-info/entry_points.txt +5 -0
  40. schubmult/_tests.py +0 -24
  41. schubmult/schubmult_double/__init__.py +0 -12
  42. schubmult/schubmult_double/__main__.py +0 -6
  43. schubmult/schubmult_double/_script.py +0 -474
  44. schubmult/schubmult_py/__init__.py +0 -12
  45. schubmult/schubmult_py/__main__.py +0 -6
  46. schubmult/schubmult_py/_script.py +0 -97
  47. schubmult/schubmult_q/__init__.py +0 -8
  48. schubmult/schubmult_q/__main__.py +0 -6
  49. schubmult/schubmult_q/_script.py +0 -166
  50. schubmult/schubmult_q_double/__init__.py +0 -10
  51. schubmult/schubmult_q_double/__main__.py +0 -6
  52. schubmult/schubmult_q_double/_funcs.py +0 -540
  53. schubmult/schubmult_q_double/_script.py +0 -396
  54. schubmult-2.0.4.dist-info/METADATA +0 -542
  55. schubmult-2.0.4.dist-info/RECORD +0 -30
  56. schubmult-2.0.4.dist-info/entry_points.txt +0 -5
  57. {schubmult-2.0.4.dist-info → schubmult-3.0.1.dist-info}/licenses/LICENSE +0 -0
  58. {schubmult-2.0.4.dist-info → schubmult-3.0.1.dist-info}/top_level.txt +0 -0
schubmult/perm_lib.py CHANGED
@@ -1,54 +1,218 @@
1
- from bisect import bisect_left
2
- from functools import cache
3
- from itertools import chain
1
+ import math
2
+ from functools import cache, cached_property
4
3
 
5
- import numpy as np
6
- from symengine import Mul, Pow, symarray, sympify
4
+ import sympy.combinatorics.permutations as spp
5
+ from symengine import sympify
6
+ from sympy import Basic, Tuple
7
+
8
+ import schubmult.utils.logging as lg
9
+ from schubmult.utils.perm_utils import cyclic_sort, permtrim_list, sg
10
+
11
+ # schubmult.poly_lib.variables import GeneratingSet
12
+
13
+ logger = lg.get_logger(__name__)
7
14
 
8
15
  zero = sympify(0)
9
16
  n = 100
10
17
 
11
- q_var = symarray("q", n)
18
+ # TODO: permutations act
19
+
20
+
21
+ class Permutation(Basic):
22
+ def __new__(cls, perm):
23
+ return Permutation.__xnew_cached__(cls, tuple(perm))
24
+
25
+ print_as_code = False
26
+
27
+ @staticmethod
28
+ @cache
29
+ def __xnew_cached__(_class, perm):
30
+ return Permutation.__xnew__(_class, perm)
31
+
32
+ @staticmethod
33
+ def __xnew__(_class, perm):
34
+ p = tuple(permtrim_list([*perm]))
35
+ s_perm = spp.Permutation._af_new([i - 1 for i in p])
36
+ obj = Basic.__new__(_class, Tuple(*perm))
37
+ obj._s_perm = s_perm
38
+ obj._perm = p
39
+ obj._hash_code = hash(p)
40
+ cd = s_perm.inversion_vector()
41
+ obj._unique_key = (len(p), sum([cd[i] * math.factorial(len(p) - 1 - i) for i in range(len(cd))]))
42
+ return obj
43
+
44
+ @classmethod
45
+ def sorting_perm(cls, itera):
46
+ L = [i + 1 for i in range(len(itera))]
47
+ L.sort(key=lambda i: itera[i - 1])
48
+ return Permutation(L)
49
+
50
+ def _latex(self, printer):
51
+ if Permutation.print_as_code:
52
+ return printer.doprint(trimcode(self))
53
+ return printer.doprint(list(self._perm))
54
+
55
+ # pattern is a list, not a permutation
56
+ def has_pattern(self, pattern):
57
+ if self == Permutation(pattern):
58
+ return True
59
+ if len(self._perm) <= len(Permutation(pattern)):
60
+ return False
61
+ expanded = list(self) + [i for i in range(len(self)+1,len(pattern)+1)]
62
+ for i in range(len(expanded)):
63
+ rmval = expanded[i]
64
+ perm2 = [*expanded[:i], *expanded[i+1:]]
65
+ perm2 = tuple([val-1 if val>rmval else val for val in perm2])
66
+ if Permutation(perm2).has_pattern(pattern):
67
+ return True
68
+ return False
69
+
70
+ def _sympystr(self, printer):
71
+ if Permutation.print_as_code:
72
+ return printer.doprint(trimcode(self))
73
+ return printer.doprint(self._perm)
74
+
75
+ def __call__(self, i):
76
+ """1-indexed"""
77
+ return self[i - 1]
78
+
79
+ def descents(self, zero_indexed=True):
80
+ if zero_indexed:
81
+ return self._s_perm.descents()
82
+ return {i + 1 for i in self._s_perm.descents()}
83
+
84
+ def get_cycles(self):
85
+ return self.get_cycles_cached()
86
+
87
+ @cache
88
+ def get_cycles_cached(self):
89
+ return [tuple(cyclic_sort([i + 1 for i in c])) for c in self._s_perm.cyclic_form]
90
+
91
+ @property
92
+ def code(self):
93
+ return list(self.cached_code())
94
+
95
+ @cache
96
+ def cached_code(self):
97
+ return self._s_perm.inversion_vector()
98
+
99
+ @cached_property
100
+ def inv(self):
101
+ return self._s_perm.inversions()
102
+
103
+ def swap(self, i, j):
104
+ new_perm = [*self._perm]
105
+ # print(f"SWAP {new_perm=}")
106
+ if i > j:
107
+ i, j = j, i
108
+ if j >= len(new_perm):
109
+ # print(f"SWAP {j}>={new_perm=}")
110
+ new_perm += list(range(len(new_perm) + 1, j + 2))
111
+ # print(f"SWAP extended {new_perm=}")
112
+ new_perm[i], new_perm[j] = new_perm[j], new_perm[i]
113
+ # print(f"SWAP iddle {new_perm=}")
114
+ return Permutation(new_perm)
115
+
116
+ def __getitem__(self, i):
117
+ if isinstance(i, slice):
118
+ return [self[ii] for ii in range(i.start if i.start is not None else 0, i.stop if i.stop is not None else len(self))]
119
+ if i >= len(self._perm):
120
+ return i + 1
121
+ return self._perm[i]
122
+
123
+ def __setitem__(self, i, v):
124
+ raise NotImplementedError
125
+
126
+ def __hash__(self):
127
+ return self._hash_code
128
+
129
+ def __mul__(self, other):
130
+ new_sperm = other._s_perm * self._s_perm
131
+ new_perm = permtrim_list([new_sperm.array_form[i] + 1 for i in range(new_sperm.size)])
132
+ return Permutation(new_perm)
133
+
134
+ def __iter__(self):
135
+ yield from self._perm.__iter__()
136
+
137
+ def __getslice__(self, i, j):
138
+ return self._perm[i:j]
139
+
140
+ def __str__(self):
141
+ return str(self._perm)
142
+
143
+ def __add__(self, other):
144
+ if not isinstance(other, list):
145
+ raise NotImplementedError
146
+ permlist = [*self._perm, *other]
147
+ try:
148
+ return Permutation(permlist)
149
+ except Exception:
150
+ return permlist
151
+
152
+ def __radd__(self, other):
153
+ if not isinstance(other, list):
154
+ raise NotImplementedError
155
+ permlist = [*other, *self._perm]
156
+ try:
157
+ return Permutation(permlist)
158
+ except Exception:
159
+ return permlist
160
+
161
+ def __eq__(self, other):
162
+ if isinstance(other, Permutation):
163
+ # print(f"{other._perm= } {self._perm=} {type(self._perm)=}")
164
+ # return other._perm == self._perm
165
+ return other._unique_key == self._unique_key
166
+ if isinstance(other, list):
167
+ # print(f"{[*self._perm]= } {other=}")
168
+ return [*self._perm] == other
169
+ if isinstance(other, tuple):
170
+ # print(f"{self._perm=} {other=}")
171
+ return self._perm == other
172
+ return False
173
+
174
+ def __len__(self):
175
+ # print("REMOVE THIS")
176
+ return max(len(self._perm), 2)
177
+
178
+ def __invert__(self):
179
+ new_sperm = ~(self._s_perm)
180
+ new_perm = [new_sperm.array_form[i] + 1 for i in range(new_sperm.size)]
181
+ return Permutation(new_perm)
182
+
183
+ def __repr__(self):
184
+ return self.__str__()
185
+
186
+ def __lt__(self, other):
187
+ return tuple(self) < tuple(other)
188
+
12
189
 
190
+ def ensure_perms(func):
191
+ def wrapper(*args):
192
+ return func(*[Permutation(arg) if (isinstance(arg, list) or isinstance(arg, tuple)) else arg for arg in args])
13
193
 
14
- def getpermval(perm, index):
15
- if index < len(perm):
16
- return perm[index]
17
- return index + 1
194
+ return wrapper
18
195
 
19
196
 
197
+ @ensure_perms
20
198
  def inv(perm):
21
- L = len(perm)
22
- v = list(range(1, L + 1))
23
- ans = 0
24
- for i in range(L):
25
- itr = bisect_left(v, perm[i])
26
- ans += itr
27
- v = v[:itr] + v[itr + 1 :]
28
- return ans
199
+ return perm.inv
29
200
 
30
201
 
202
+ @ensure_perms
31
203
  def code(perm):
32
- L = len(perm)
33
- ret = []
34
- v = list(range(1, L + 1))
35
- for i in range(L - 1):
36
- itr = bisect_left(v, perm[i])
37
- ret += [itr]
38
- v = v[:itr] + v[itr + 1 :]
39
- return ret
204
+ return perm.code
40
205
 
41
206
 
207
+ @ensure_perms
42
208
  def mulperm(perm1, perm2):
43
- if len(perm1) < len(perm2):
44
- return [perm1[perm2[i] - 1] if perm2[i] <= len(perm1) else perm2[i] for i in range(len(perm2))]
45
- return [perm1[perm2[i] - 1] for i in range(len(perm2))] + perm1[len(perm2) :]
209
+ return perm1 * perm2
46
210
 
47
211
 
48
212
  def uncode(cd):
49
213
  cd2 = [*cd]
50
214
  if cd2 == []:
51
- return [1, 2]
215
+ return Permutation([])
52
216
  max_required = max([cd2[i] + i for i in range(len(cd2))])
53
217
  cd2 += [0 for i in range(len(cd2), max_required)]
54
218
  fullperm = [i + 1 for i in range(len(cd2) + 1)]
@@ -56,129 +220,19 @@ def uncode(cd):
56
220
  for i in range(len(cd2)):
57
221
  perm += [fullperm.pop(cd2[i])]
58
222
  perm += [fullperm[0]]
59
- return perm
60
-
61
-
62
- def reversecode(perm):
63
- ret = []
64
- for i in range(len(perm) - 1, 0, -1):
65
- ret = [0, *ret]
66
- for j in range(i, -1, -1):
67
- if perm[i] > perm[j]:
68
- ret[-1] += 1
69
- return ret
70
-
71
-
72
- def reverseuncode(cd):
73
- cd2 = list(cd)
74
- if cd2 == []:
75
- return [1, 2]
76
- # max_required = max([cd2[i]+i for i in range(len(cd2))])
77
- # cd2 += [0 for i in range(len(cd2),max_required)]
78
- fullperm = [i + 1 for i in range(len(cd2) + 1)]
79
- perm = []
80
- for i in range(len(cd2) - 1, 0, -1):
81
- perm = [fullperm[cd2[i]], *perm]
82
- fullperm.pop(cd2[i])
83
- perm += [fullperm[0]]
84
- return perm
223
+ return Permutation(perm)
85
224
 
86
225
 
226
+ @ensure_perms
87
227
  def inverse(perm):
88
- retperm = [0 for i in range(len(perm))]
89
- for i in range(len(perm)):
90
- retperm[perm[i] - 1] = i + 1
91
- return retperm
228
+ return ~perm
92
229
 
93
230
 
94
231
  def permtrim(perm):
95
- L = len(perm)
96
- while L > 2 and perm[-1] == L:
97
- L = perm.pop() - 1
98
- return perm
99
-
100
-
101
- def has_bruhat_descent(perm, i, j):
102
- if perm[i] < perm[j]:
103
- return False
104
- for p in range(i + 1, j):
105
- if perm[i] > perm[p] and perm[p] > perm[j]:
106
- return False
107
- return True
108
-
109
-
110
- def count_bruhat(perm, i, j):
111
- up_amount = 0
112
- if perm[i] < perm[j]:
113
- up_amount = 1
114
- else:
115
- up_amount = -1
116
- for k in range(i + 1, j):
117
- if perm[i] < perm[k] and perm[k] < perm[j]:
118
- up_amount += 2
119
- elif perm[i] > perm[k] and perm[k] > perm[j]:
120
- up_amount -= 2
121
- return up_amount
122
-
123
-
124
- def has_bruhat_ascent(perm, i, j):
125
- if perm[i] > perm[j]:
126
- return False
127
- for p in range(i + 1, j):
128
- if perm[i] < perm[p] and perm[p] < perm[j]:
129
- return False
130
- return True
131
-
132
-
133
- def elem_sym_perms(orig_perm, p, k):
134
- total_list = [(orig_perm, 0)]
135
- up_perm_list = [(orig_perm, 1000000000)]
136
- for pp in range(p):
137
- perm_list = []
138
- for up_perm, last in up_perm_list:
139
- up_perm2 = [*up_perm, len(up_perm) + 1]
140
- if len(up_perm2) < k + 1:
141
- up_perm2 += [i + 1 for i in range(len(up_perm2), k + 2)]
142
- pos_list = [i for i in range(k) if up_perm2[i] < last]
143
- for j in range(k, len(up_perm2)):
144
- if up_perm2[j] >= last:
145
- continue
146
- for i in pos_list:
147
- if has_bruhat_ascent(up_perm2, i, j):
148
- new_perm = [*up_perm2]
149
- new_perm[i], new_perm[j] = new_perm[j], new_perm[i]
150
- if new_perm[-1] == len(new_perm):
151
- new_perm_add = tuple(new_perm[:-1])
152
- else:
153
- new_perm_add = tuple(new_perm)
154
- perm_list += [(new_perm_add, up_perm2[j])]
155
- total_list += [(new_perm_add, pp + 1)]
156
- up_perm_list = perm_list
157
- return total_list
158
-
159
-
160
- def elem_sym_perms_op(orig_perm, p, k):
161
- total_list = [(orig_perm, 0)]
162
- up_perm_list = [(orig_perm, k)]
163
- for pp in range(p):
164
- perm_list = []
165
- for up_perm, last in up_perm_list:
166
- up_perm2 = [*up_perm]
167
- if len(up_perm2) < k + 1:
168
- up_perm2 += [i + 1 for i in range(len(up_perm2), k + 2)]
169
- pos_list = [i for i in range(k) if getpermval(up_perm2, i) == getpermval(orig_perm, i)]
170
- for j in range(last, len(up_perm2)):
171
- for i in pos_list:
172
- if has_bruhat_descent(up_perm2, i, j):
173
- new_perm = [*up_perm2]
174
- new_perm[i], new_perm[j] = new_perm[j], new_perm[i]
175
- new_perm_add = tuple(permtrim(new_perm))
176
- perm_list += [(new_perm_add, j)]
177
- total_list += [(new_perm_add, pp + 1)]
178
- up_perm_list = perm_list
179
- return total_list
232
+ return Permutation(perm)
180
233
 
181
234
 
235
+ @ensure_perms
182
236
  def strict_theta(u):
183
237
  ret = [*trimcode(u)]
184
238
  did_one = True
@@ -194,271 +248,20 @@ def strict_theta(u):
194
248
  return ret
195
249
 
196
250
 
197
- def elem_sym_perms_q(orig_perm, p, k, q_var=q_var):
198
- total_list = [(orig_perm, 0, 1)]
199
- up_perm_list = [(orig_perm, 1, 1000)]
200
- for pp in range(p):
201
- perm_list = []
202
- for up_perm, val, last_j in up_perm_list:
203
- up_perm2 = [*up_perm, len(up_perm) + 1]
204
- if len(up_perm2) < k + 1:
205
- up_perm2 += [i + 1 for i in range(len(up_perm2), k + 2)]
206
- pos_list = [i for i in range(k) if (i >= len(orig_perm) and up_perm2[i] == i + 1) or (i < len(orig_perm) and up_perm2[i] == orig_perm[i])]
207
- for j in range(min(len(up_perm2) - 1, last_j), k - 1, -1):
208
- for i in pos_list:
209
- ct = count_bruhat(up_perm2, i, j)
210
- # print(f"{up_perm2=} {ct=} {i=} {j=} {k=} {pp=}")
211
- if ct == 1 or ct == 2 * (i - j) + 1:
212
- new_perm = [*up_perm2]
213
- new_perm[i], new_perm[j] = new_perm[j], new_perm[i]
214
- new_perm_add = tuple(permtrim(new_perm))
215
- new_val = val
216
- if ct < 0:
217
- new_val *= np.prod([q_var[index] for index in range(i + 1, j + 1)])
218
- perm_list += [(new_perm_add, new_val, j)]
219
- total_list += [(new_perm_add, pp + 1, new_val)]
220
- up_perm_list = perm_list
221
- return total_list
222
-
223
-
224
- def elem_sym_perms_q_op(orig_perm, p, k, n, q_var=q_var):
225
- total_list = [(orig_perm, 0, 1)]
226
- up_perm_list = [(orig_perm, 1, k)]
227
- for pp in range(p):
228
- perm_list = []
229
- for up_perm, val, last_j in up_perm_list:
230
- up_perm2 = [*up_perm]
231
- if len(up_perm) < n:
232
- up_perm2 += [i + 1 for i in range(len(up_perm2), n)]
233
- pos_list = [i for i in range(k) if (i >= len(orig_perm) and up_perm2[i] == i + 1) or (i < len(orig_perm) and up_perm2[i] == orig_perm[i])]
234
- for j in range(last_j, n):
235
- for i in pos_list:
236
- ct = count_bruhat(up_perm2, i, j)
237
- # print(f"{up_perm2=} {ct=} {i=} {j=} {k=} {pp=}")
238
- if ct == -1 or ct == 2 * (j - i) - 1:
239
- new_perm = [*up_perm2]
240
- new_perm[i], new_perm[j] = new_perm[j], new_perm[i]
241
- new_perm_add = tuple(permtrim(new_perm))
242
- new_val = val
243
- if ct > 0:
244
- new_val *= np.prod([q_var[index] for index in range(i + 1, j + 1)])
245
- perm_list += [(new_perm_add, new_val, j)]
246
- total_list += [(new_perm_add, pp + 1, new_val)]
247
- up_perm_list = perm_list
248
- return total_list
249
-
250
-
251
- def q_vector(q_exp, q_var=q_var):
252
- qvar_list = q_var.tolist()
253
- ret = []
254
-
255
- if q_exp == 1:
256
- return ret
257
- if q_exp in q_var:
258
- i = qvar_list.index(q_exp)
259
- return [0 for j in range(i - 1)] + [1]
260
- if isinstance(q_exp, Pow):
261
- qv = q_exp.args[0]
262
- expon = int(q_exp.args[1])
263
- i = qvar_list.index(qv)
264
- return [0 for j in range(i - 1)] + [expon]
265
- if isinstance(q_exp, Mul):
266
- for a in q_exp.args:
267
- v1 = q_vector(a)
268
- v1 += [0 for i in range(len(v1), len(ret))]
269
- ret += [0 for i in range(len(ret), len(v1))]
270
- ret = [ret[i] + v1[i] for i in range(len(ret))]
271
- return ret
272
-
273
- return None
274
-
275
-
276
- def omega(i, qv):
277
- i = i - 1
278
- if len(qv) == 0 or i > len(qv):
279
- return 0
280
- if i == 0:
281
- if len(qv) == 1:
282
- return 2 * qv[0]
283
- return 2 * qv[0] - qv[1]
284
- if i == len(qv):
285
- return -qv[-1]
286
- if i == len(qv) - 1:
287
- return 2 * qv[-1] - qv[-2]
288
- return 2 * qv[i] - qv[i - 1] - qv[i + 1]
289
-
290
-
291
- def sg(i, w):
292
- if i >= len(w) - 1 or w[i] < w[i + 1]:
293
- return 0
294
- return 1
295
-
296
-
297
- def reduce_q_coeff(u, v, w, qv):
298
- for i in range(len(qv)):
299
- if sg(i, v) == 1 and sg(i, u) == 0 and sg(i, w) + omega(i + 1, qv) == 1:
300
- ret_v = [*v]
301
- ret_v[i], ret_v[i + 1] = ret_v[i + 1], ret_v[i]
302
- ret_w = [*w] + [j + 1 for j in range(len(w), i + 2)]
303
- ret_w[i], ret_w[i + 1] = ret_w[i + 1], ret_w[i]
304
- qv_ret = [*qv]
305
- if sg(i, w) == 0:
306
- qv_ret[i] -= 1
307
- return u, tuple(permtrim(ret_v)), tuple(permtrim(ret_w)), qv_ret, True
308
- if (sg(i, u) == 1 and sg(i, v) == 0 and sg(i, w) + omega(i + 1, qv) == 1) or (sg(i, u) == 1 and sg(i, v) == 1 and sg(i, w) + omega(i + 1, qv) == 2):
309
- ret_u = [*u]
310
- ret_u[i], ret_u[i + 1] = ret_u[i + 1], ret_u[i]
311
- ret_w = [*w] + [j + 1 for j in range(len(w), i + 2)]
312
- ret_w[i], ret_w[i + 1] = ret_w[i + 1], ret_w[i]
313
- qv_ret = [*qv]
314
- if sg(i, w) == 0:
315
- qv_ret[i] -= 1
316
- return tuple(permtrim(ret_u)), v, tuple(permtrim(ret_w)), qv_ret, True
317
- return u, v, w, qv, False
318
-
319
-
320
- def reduce_q_coeff_u_only(u, v, w, qv):
321
- for i in range(len(qv)):
322
- if (sg(i, u) == 1 and sg(i, v) == 0 and sg(i, w) + omega(i + 1, qv) == 1) or (sg(i, u) == 1 and sg(i, v) == 1 and sg(i, w) + omega(i + 1, qv) == 2):
323
- ret_u = [*u]
324
- ret_u[i], ret_u[i + 1] = ret_u[i + 1], ret_u[i]
325
- ret_w = [*w] + [j + 1 for j in range(len(w), i + 2)]
326
- ret_w[i], ret_w[i + 1] = ret_w[i + 1], ret_w[i]
327
- qv_ret = [*qv]
328
- if sg(i, w) == 0:
329
- qv_ret[i] -= 1
330
- return tuple(permtrim(ret_u)), v, tuple(permtrim(ret_w)), qv_ret, True
331
- return u, v, w, qv, False
332
-
333
-
334
251
  def longest_element(indices):
335
- perm = [1, 2]
252
+ perm = Permutation([1, 2])
336
253
  did_one = True
337
254
  while did_one:
338
255
  did_one = False
339
256
  for i in range(len(indices)):
340
257
  j = indices[i] - 1
341
258
  if sg(j, perm) == 0:
342
- if len(perm) < j + 2:
343
- perm = perm + list(range(len(perm) + 1, j + 3))
344
- perm[j], perm[j + 1] = perm[j + 1], perm[j]
259
+ perm = perm.swap(j, j + 1)
345
260
  did_one = True
346
261
  return permtrim(perm)
347
262
 
348
263
 
349
- def count_less_than(arr, val):
350
- ct = 0
351
- i = 0
352
- while i < len(arr) and arr[i] < val:
353
- i += 1
354
- ct += 1
355
- return ct
356
-
357
-
358
- def is_parabolic(w, parabolic_index):
359
- for i in parabolic_index:
360
- if sg(i - 1, w) == 1:
361
- return False
362
- return True
363
-
364
-
365
- def check_blocks(qv, parabolic_index):
366
- blocks = []
367
- cur_block = []
368
- last_val = -1
369
- for i in range(len(parabolic_index)):
370
- if last_val == -1 or last_val + 1 == parabolic_index[i]:
371
- last_val = parabolic_index[i]
372
- cur_block += [last_val]
373
- else:
374
- blocks += [cur_block]
375
- cur_block = []
376
- for block in blocks:
377
- for i in range(len(block)):
378
- for j in range(i, len(block)):
379
- val = 0
380
- for k in range(i, j + 1):
381
- val += omega(block[k], qv)
382
- if val != 0 and val != -1:
383
- return False
384
- return True
385
-
386
-
387
- # perms and inversion diff
388
- def kdown_perms(perm, monoperm, p, k):
389
- inv_m = inv(monoperm)
390
- inv_p = inv(perm)
391
- full_perm_list = []
392
-
393
- if inv(mulperm(list(perm), monoperm)) == inv_m - inv_p:
394
- full_perm_list += [(tuple(perm), 0, 1)]
395
-
396
- down_perm_list = [(perm, 1)]
397
- if len(perm) < k:
398
- return full_perm_list
399
- a2 = k - 1
400
- for pp in range(1, p + 1):
401
- down_perm_list2 = []
402
- for perm2, s in down_perm_list:
403
- L = len(perm2)
404
- if k > L:
405
- continue
406
- s2 = -s
407
- for b in chain(range(k - 1), range(k, L)):
408
- if perm2[b] != perm[b]:
409
- continue
410
- if b < a2:
411
- i, j = b, a2
412
- else:
413
- i, j, s2 = a2, b, s
414
- if has_bruhat_descent(perm2, i, j):
415
- new_perm = [*perm2]
416
- new_perm[a2], new_perm[b] = new_perm[b], new_perm[a2]
417
- permtrim(new_perm)
418
- down_perm_list2 += [(new_perm, s2)]
419
- if inv(mulperm(new_perm, monoperm)) == inv_m - inv_p + pp:
420
- full_perm_list += [(tuple(new_perm), pp, s2)]
421
- down_perm_list = down_perm_list2
422
- return full_perm_list
423
-
424
-
425
- def compute_vpathdicts(th, vmu, smpify=False):
426
- vpathdicts = [{} for index in range(len(th))]
427
- vpathdicts[-1][tuple(vmu)] = None
428
- thL = len(th)
429
-
430
- top = code(inverse(uncode(th)))
431
- for i in range(thL - 1, -1, -1):
432
- top2 = code(inverse(uncode(top)))
433
- while top2[-1] == 0:
434
- top2.pop()
435
- top2.pop()
436
- top = code(inverse(uncode(top2)))
437
- monoperm = uncode(top)
438
- if len(monoperm) < 2:
439
- monoperm = [1, 2]
440
- k = i + 1
441
- for last_perm in vpathdicts[i]:
442
- newperms = kdown_perms(last_perm, monoperm, th[i], k)
443
- vpathdicts[i][last_perm] = newperms
444
- if i > 0:
445
- for trip in newperms:
446
- vpathdicts[i - 1][trip[0]] = None
447
- vpathdicts2 = [{} for i in range(len(th))]
448
- for i in range(len(th)):
449
- for key, valueset in vpathdicts[i].items():
450
- for value in valueset:
451
- key2 = value[0]
452
- if key2 not in vpathdicts2[i]:
453
- vpathdicts2[i][key2] = set()
454
- v2 = value[2]
455
- if smpify:
456
- v2 = sympify(v2)
457
- vpathdicts2[i][key2].add((key, value[1], v2))
458
- # print(vpathdicts2)
459
- return vpathdicts2
460
-
461
-
264
+ @ensure_perms
462
265
  def theta(perm):
463
266
  cd = code(perm)
464
267
  for i in range(len(cd) - 1, 0, -1):
@@ -469,148 +272,30 @@ def theta(perm):
469
272
  return cd
470
273
 
471
274
 
472
- def add_perm_dict(d1, d2):
473
- for k, v in d2.items():
474
- d1[k] = d1.get(k, 0) + v
475
- return d1
476
-
477
-
478
- one = sympify(1)
479
-
480
-
481
- def elem_sym_poly_q(p, k, varl1, varl2, q_var=q_var):
482
- if p == 0 and k >= 0:
483
- return one
484
- if p < 0 or p > k:
485
- return zero
486
- return (
487
- (varl1[k - 1] - varl2[k - p]) * elem_sym_poly_q(p - 1, k - 1, varl1, varl2, q_var)
488
- + elem_sym_poly_q(p, k - 1, varl1, varl2, q_var)
489
- + q_var[k - 1] * elem_sym_poly_q(p - 2, k - 2, varl1, varl2, q_var)
490
- )
491
-
492
-
493
- def elem_sym_poly(p, k, varl1, varl2, xstart=0, ystart=0):
494
- if p > k:
495
- return zero
496
- if p == 0:
497
- return one
498
- if p == 1:
499
- res = varl1[xstart] - varl2[ystart]
500
- for i in range(1, k):
501
- res += varl1[xstart + i] - varl2[ystart + i]
502
- return res
503
- if p == k:
504
- res = (varl1[xstart] - varl2[ystart]) * (varl1[xstart + 1] - varl2[ystart])
505
- for i in range(2, k):
506
- res *= varl1[i + xstart] - varl2[ystart]
507
- return res
508
- mid = k // 2
509
- xsm = xstart + mid
510
- ysm = ystart + mid
511
- kmm = k - mid
512
- res = elem_sym_poly(p, mid, varl1, varl2, xstart, ystart) + elem_sym_poly(
513
- p,
514
- kmm,
515
- varl1,
516
- varl2,
517
- xsm,
518
- ysm,
519
- )
520
- for p2 in range(max(1, p - kmm), min(p, mid + 1)):
521
- res += elem_sym_poly(p2, mid, varl1, varl2, xstart, ystart) * elem_sym_poly(
522
- p - p2,
523
- kmm,
524
- varl1,
525
- varl2,
526
- xsm,
527
- ysm - p2,
528
- )
529
- return res
530
-
531
-
532
- @cache
533
- def call_zvars(v1, v2, k, i): # noqa: ARG001
534
- v3 = [*v2, *list(range(len(v2) + 1, i + 1))]
535
- return [v3[i - 1]] + [v3[j] for j in range(len(v1), len(v3)) if v3[j] != j + 1 and j != i - 1] + [v3[j] for j in range(len(v1)) if v1[j] != v3[j] and j != i - 1]
536
-
537
-
538
- def elem_sym_func(k, i, u1, u2, v1, v2, udiff, vdiff, varl1, varl2):
539
- newk = k - udiff
540
- if newk < vdiff:
541
- return zero
542
- if newk == vdiff:
543
- return one
544
- yvars = []
545
- for j in range(min(len(u1), k)):
546
- if u1[j] == u2[j]:
547
- yvars += [varl1[u2[j]]]
548
- for j in range(len(u1), min(k, len(u2))):
549
- if u2[j] == j + 1:
550
- yvars += [varl1[u2[j]]]
551
- for j in range(len(u2), k):
552
- yvars += [varl1[j + 1]]
553
- zvars = [varl2[i] for i in call_zvars(v1, v2, k, i)]
554
- return elem_sym_poly(newk - vdiff, newk, yvars, zvars)
555
-
556
-
557
- def elem_sym_func_q(k, i, u1, u2, v1, v2, udiff, vdiff, varl1, varl2):
558
- newk = k - udiff
559
- if newk < vdiff:
560
- return zero
561
- if newk == vdiff:
562
- return one
563
- yvars = []
564
- mlen = max(len(u1), len(u2))
565
- u1 = [*u1] + [a + 1 for a in range(len(u1), mlen)]
566
- u2 = [*u2] + [a + 1 for a in range(len(u2), mlen)]
567
- for j in range(min(len(u1), k)):
568
- if u1[j] == u2[j]:
569
- yvars += [varl1[u2[j]]]
570
- for j in range(len(u1), min(k, len(u2))):
571
- if u2[j] == j + 1:
572
- yvars += [varl1[u2[j]]]
573
- for j in range(len(u2), k):
574
- yvars += [varl1[j + 1]]
575
- zvars = [varl2[a] for a in call_zvars(v1, v2, k, i)]
576
- return elem_sym_poly(newk - vdiff, newk, yvars, zvars)
577
-
578
-
275
+ @ensure_perms
579
276
  def trimcode(perm):
580
- cd = code(perm)
277
+ cd = perm.code
581
278
  while len(cd) > 0 and cd[-1] == 0:
582
279
  cd.pop()
583
280
  return cd
584
281
 
585
282
 
586
- def p_trans(part):
587
- newpart = []
588
- if len(part) == 0 or part[0] == 0:
589
- return [0]
590
- for i in range(1, part[0] + 1):
591
- cnt = 0
592
- for j in range(len(part)):
593
- if part[j] >= i:
594
- cnt += 1
595
- if cnt == 0:
596
- break
597
- newpart += [cnt]
598
- return newpart
599
-
600
-
601
283
  def cycle(p, q):
602
- return list(range(1, p)) + [i + 1 for i in range(p, p + q)] + [p]
284
+ return Permutation(list(range(1, p)) + [i + 1 for i in range(p, p + q)] + [p])
603
285
 
604
286
 
287
+ @ensure_perms
605
288
  def phi1(u):
606
- c_star = code(inverse(u))
289
+ c_star = (~u).code
607
290
  c_star.pop(0)
608
- return inverse(uncode(c_star))
291
+ # print(f"{uncode(c_star)=}")
292
+ return ~(uncode(c_star))
609
293
 
610
294
 
295
+ @ensure_perms
611
296
  def one_dominates(u, w):
612
- c_star_u = code(inverse(u))
613
- c_star_w = code(inverse(w))
297
+ c_star_u = (~u).code
298
+ c_star_w = (~w).code
614
299
 
615
300
  a = c_star_u[0]
616
301
  b = c_star_w[0]
@@ -624,303 +309,16 @@ def one_dominates(u, w):
624
309
 
625
310
 
626
311
  def dominates(u, w):
627
- u2 = [*u]
628
- w2 = [*w]
629
- while u2 != [1, 2] and one_dominates(u2, w2):
312
+ u2 = u
313
+ w2 = w
314
+ while inv(u2) > 0 and one_dominates(u2, w2):
630
315
  u2 = phi1(u2)
631
316
  w2 = phi1(w2)
632
- if u2 == [1, 2]:
317
+ if inv(u2) == 0:
633
318
  return True
634
319
  return False
635
320
 
636
321
 
637
- def reduce_coeff(u, v, w):
638
- t_mu_u_t = theta(inverse(u))
639
- t_mu_v_t = theta(inverse(v))
640
-
641
- mu_u_inv = uncode(t_mu_u_t)
642
- mu_v_inv = uncode(t_mu_v_t)
643
-
644
- t_mu_u = p_trans(t_mu_u_t)
645
- t_mu_v = p_trans(t_mu_v_t)
646
-
647
- t_mu_u += [0 for i in range(len(t_mu_u), max(len(t_mu_u), len(t_mu_v)))]
648
- t_mu_v += [0 for i in range(len(t_mu_v), max(len(t_mu_u), len(t_mu_v)))]
649
-
650
- t_mu_uv = [t_mu_u[i] + t_mu_v[i] for i in range(len(t_mu_u))]
651
- t_mu_uv_t = p_trans(t_mu_uv)
652
-
653
- mu_uv_inv = uncode(t_mu_uv_t)
654
-
655
- if inv(mulperm(list(w), mu_uv_inv)) != inv(mu_uv_inv) - inv(w):
656
- return u, v, w
657
-
658
- umu = mulperm(list(u), mu_u_inv)
659
- vmu = mulperm(list(v), mu_v_inv)
660
- wmu = mulperm(list(w), mu_uv_inv)
661
-
662
- t_mu_w = theta(inverse(wmu))
663
-
664
- mu_w = uncode(t_mu_w)
665
-
666
- w_prime = mulperm(wmu, mu_w)
667
-
668
- if permtrim(list(w)) == permtrim(w_prime):
669
- return (permtrim(list(u)), permtrim(list(v)), permtrim(list(w)))
670
-
671
- A = []
672
- B = []
673
- indexA = 0
674
-
675
- while len(t_mu_u_t) > 0 and t_mu_u_t[-1] == 0:
676
- t_mu_u_t.pop()
677
-
678
- while len(t_mu_v_t) > 0 and t_mu_v_t[-1] == 0:
679
- t_mu_v_t.pop()
680
-
681
- while len(t_mu_uv_t) > 0 and t_mu_uv_t[-1] == 0:
682
- t_mu_uv_t.pop()
683
-
684
- for index in range(len(t_mu_uv_t)):
685
- if indexA < len(t_mu_u_t) and t_mu_uv_t[index] == t_mu_u_t[indexA]:
686
- A += [index]
687
- indexA += 1
688
- else:
689
- B += [index]
690
-
691
- mu_w_A = uncode(mu_A(code(mu_w), A))
692
- mu_w_B = uncode(mu_A(code(mu_w), B))
693
-
694
- return (
695
- permtrim(mulperm(umu, mu_w_A)),
696
- permtrim(mulperm(vmu, mu_w_B)),
697
- permtrim(w_prime),
698
- )
699
-
700
-
701
- def mu_A(mu, A):
702
- mu_t = p_trans(mu)
703
- mu_A_t = []
704
- for i in range(len(A)):
705
- if A[i] < len(mu_t):
706
- mu_A_t += [mu_t[A[i]]]
707
- return p_trans(mu_A_t)
708
-
709
-
710
- def reduce_descents(u, v, w):
711
- u2 = [*u]
712
- v2 = [*v]
713
- w2 = [*w]
714
- found_one = True
715
- while found_one:
716
- found_one = False
717
- if will_formula_work(u2, v2) or will_formula_work(v2, u2) or one_dominates(u2, w2) or is_reducible(v2) or inv(w2) - inv(u2) == 1:
718
- break
719
- for i in range(len(w2) - 2, -1, -1):
720
- if w2[i] > w2[i + 1] and i < len(v2) - 1 and v2[i] > v2[i + 1] and (i >= len(u2) - 1 or u2[i] < u2[i + 1]):
721
- w2[i], w2[i + 1] = w2[i + 1], w2[i]
722
- v2[i], v2[i + 1] = v2[i + 1], v2[i]
723
- found_one = True
724
- elif w2[i] > w2[i + 1] and i < len(u2) - 1 and u2[i] > u2[i + 1] and (i >= len(v2) - 1 or v2[i] < v2[i + 1]):
725
- w2[i], w2[i + 1] = w2[i + 1], w2[i]
726
- u2[i], u2[i + 1] = u2[i + 1], u2[i]
727
- found_one = True
728
- if found_one:
729
- break
730
- return permtrim(u2), permtrim(v2), permtrim(w2)
731
-
732
-
733
- def is_reducible(v):
734
- c03 = code(v)
735
- found0 = False
736
- good = True
737
- for i in range(len(c03)):
738
- if c03[i] == 0:
739
- found0 = True
740
- elif c03[i] != 0 and found0:
741
- good = False
742
- break
743
- return good
744
-
745
-
746
- def try_reduce_v(u, v, w):
747
- if is_reducible(v):
748
- return tuple(permtrim([*u])), tuple(permtrim([*v])), tuple(permtrim([*w]))
749
- u2 = [*u]
750
- v2 = [*v]
751
- w2 = [*w]
752
- cv = code(v2)
753
- for i in range(len(v2) - 2, -1, -1):
754
- if cv[i] == 0 and i < len(cv) - 1 and cv[i + 1] != 0:
755
- if i >= len(u2) - 1 or u2[i] < u2[i + 1]:
756
- v2[i], v2[i + 1] = v2[i + 1], v2[i]
757
- if i >= len(w2) - 1:
758
- w2 += list(range(len(w2) + 1, i + 3))
759
- w2[i + 1], w2[i] = w2[i], w2[i + 1]
760
- if is_reducible(v2):
761
- return tuple(permtrim(u2)), tuple(permtrim(v2)), tuple(permtrim(w2))
762
- return try_reduce_v(u2, v2, w2)
763
- if i < len(w2) - 1 and w2[i] > w2[i + 1]:
764
- u2[i], u2[i + 1] = u2[i + 1], u2[i]
765
- v2[i], v2[i + 1] = v2[i + 1], v2[i]
766
- return try_reduce_v(u2, v2, w2)
767
- return tuple(permtrim(u2)), tuple(permtrim(v2)), tuple(permtrim(w2))
768
- return tuple(permtrim(u2)), tuple(permtrim(v2)), tuple(permtrim(w2))
769
-
770
-
771
- def try_reduce_u(u, v, w):
772
- if one_dominates(u, w):
773
- return u, v, w
774
- u2 = [*u]
775
- v2 = [*v]
776
- w2 = [*w]
777
- cu = code(u)
778
- for i in range(len(u2) - 2, -1, -1):
779
- if cu[i] == 0 and i < len(cu) - 1 and cu[i + 1] != 0:
780
- if i >= len(v2) - 1 or v2[i] < v2[i + 1]:
781
- u2[i], u2[i + 1] = u2[i + 1], u2[i]
782
- if i > len(w2) - 1:
783
- w2 += list(range(len(w2) + 1, i + 3))
784
- w2[i + 1], w2[i] = w2[i], w2[i + 1]
785
- if one_dominates(u, w):
786
- return tuple(permtrim(u2)), tuple(permtrim(v2)), tuple(permtrim(w2))
787
- return try_reduce_u(u2, v2, w2)
788
- if i < len(w2) - 1 and w2[i] > w2[i + 1]:
789
- u2[i], u2[i + 1] = u2[i + 1], u2[i]
790
- v2[i], v2[i + 1] = v2[i + 1], v2[i]
791
- return try_reduce_u(u2, v2, w2)
792
- return tuple(permtrim(u2)), tuple(permtrim(v2)), tuple(permtrim(w2))
793
- return tuple(permtrim(u2)), tuple(permtrim(v2)), tuple(permtrim(w2))
794
-
795
-
796
- def divdiffable(v, u):
797
- inv_v = inv(v)
798
- inv_u = inv(u)
799
- perm2 = permtrim(mulperm(v, inverse(u)))
800
- if inv(perm2) != inv_v - inv_u:
801
- return []
802
- return perm2
803
-
804
-
805
- def will_formula_work(u, v):
806
- muv = uncode(theta(v))
807
- vn1muv = mulperm(inverse(v), muv)
808
- while True:
809
- found_one = False
810
- for i in range(len(vn1muv) - 1):
811
- if vn1muv[i] > vn1muv[i + 1]:
812
- found_one = True
813
- if i < len(u) - 1 and u[i] > u[i + 1]:
814
- return False
815
- vn1muv[i], vn1muv[i + 1] = vn1muv[i + 1], vn1muv[i]
816
- break
817
- if not found_one:
818
- return True
819
-
820
-
821
- def pull_out_var(vnum, v):
822
- vup = [*v, len(v) + 1]
823
- if vnum >= len(v):
824
- return [[[], v]]
825
- vpm_list = [(vup, 0)]
826
- ret_list = []
827
- for p in range(len(v) + 1 - vnum):
828
- vpm_list2 = []
829
- for vpm, b in vpm_list:
830
- if vpm[vnum - 1] == len(v) + 1:
831
- vpm2 = [*vpm]
832
- vpm2.pop(vnum - 1)
833
- vp = permtrim(vpm2)
834
- ret_list += [
835
- [
836
- [v[i] for i in range(vnum, len(v)) if ((i > len(vp) and v[i] == i) or (i <= len(vp) and v[i] == vp[i - 1]))],
837
- vp,
838
- ],
839
- ]
840
- for j in range(vnum, len(vup)):
841
- if vpm[j] <= b:
842
- continue
843
- for i in range(vnum):
844
- if has_bruhat_ascent(vpm, i, j):
845
- vpm[i], vpm[j] = vpm[j], vpm[i]
846
- vpm_list2 += [([*vpm], vpm[i])]
847
- vpm[i], vpm[j] = vpm[j], vpm[i]
848
- vpm_list = vpm_list2
849
- for vpm, b in vpm_list:
850
- if vpm[vnum - 1] == len(v) + 1:
851
- vpm2 = [*vpm]
852
- vpm2.pop(vnum - 1)
853
- vp = permtrim(vpm2)
854
- ret_list += [
855
- [
856
- [v[i] for i in range(vnum, len(v)) if ((i > len(vp) and v[i] == i) or (i <= len(vp) and v[i] == vp[i - 1]))],
857
- vp,
858
- ],
859
- ]
860
- return ret_list
861
-
862
-
863
- def get_cycles(perm):
864
- cycle_set = []
865
- done_vals = set()
866
- for i in range(len(perm)):
867
- p = i + 1
868
- if perm[i] == p:
869
- continue
870
- if p in done_vals:
871
- continue
872
- cycle = []
873
- m = -1
874
- max_index = -1
875
- while p not in done_vals:
876
- cycle += [p]
877
- done_vals.add(p)
878
- if p > m:
879
- m = p
880
- max_index = len(cycle) - 1
881
- p = perm[p - 1]
882
- cycle = tuple(cycle[max_index + 1 :] + cycle[: max_index + 1])
883
- cycle_set += [cycle]
884
- return cycle_set
885
-
886
-
887
- def double_elem_sym_q(u, p1, p2, k, q_var=q_var):
888
- ret_list = {}
889
- perms1 = elem_sym_perms_q(u, p1, k, q_var)
890
- iu = inverse(u)
891
- for perm1, udiff1, mul_val1 in perms1:
892
- perms2 = elem_sym_perms_q(perm1, p2, k, q_var)
893
- cycles1 = get_cycles(tuple(permtrim(mulperm(iu, [*perm1]))))
894
- cycles1_dict = {}
895
- for c in cycles1:
896
- if c[-1] not in cycles1_dict:
897
- cycles1_dict[c[-1]] = []
898
- cycles1_dict[c[-1]] += [set(c)]
899
- ip1 = inverse(perm1)
900
- for perm2, udiff2, mul_val2 in perms2:
901
- cycles2 = get_cycles(tuple(permtrim(mulperm(ip1, [*perm2]))))
902
- good = True
903
- for i in range(len(cycles2)):
904
- c2 = cycles2[i]
905
- if c2[-1] not in cycles1_dict:
906
- continue
907
- for c1_s in cycles1_dict[c2[-1]]:
908
- for a in range(len(c2) - 2, -1, -1):
909
- if c2[a] in c1_s:
910
- good = False
911
- break
912
- if not good:
913
- break
914
- if not good:
915
- break
916
-
917
- if good:
918
- if (perm1, udiff1, mul_val1) not in ret_list:
919
- ret_list[(perm1, udiff1, mul_val1)] = []
920
- ret_list[(perm1, udiff1, mul_val1)] += [(perm2, udiff2, mul_val2)]
921
- return ret_list
922
-
923
-
924
322
  def medium_theta(perm):
925
323
  cd = code(perm)
926
324
  found_one = True
@@ -932,8 +330,45 @@ def medium_theta(perm):
932
330
  cd[i], cd[i + 1] = cd[i + 1] + 1, cd[i]
933
331
  break
934
332
  if cd[i] == cd[i + 1] and cd[i] != 0 and i > 0 and cd[i - 1] <= cd[i] + 1:
935
- # if cd[i]==cd[i+1] and i>0 and cd[i-1]<=cd[i]+1:
936
333
  cd[i] += 1
937
334
  found_one = True
938
335
  break
939
336
  return cd
337
+
338
+
339
+ def split_perms(perms):
340
+ perms2 = [perms[0]]
341
+ for perm in perms[1:]:
342
+ cd = code(perm)
343
+ index = -1
344
+ not_zero = False
345
+ did = False
346
+ for i in range(len(cd)):
347
+ if cd[i] != 0:
348
+ not_zero = True
349
+ elif not_zero and cd[i] == 0:
350
+ not_zero = False
351
+ index = i
352
+ num_zeros_to_miss = 0
353
+ for j in range(index):
354
+ if cd[j] != 0:
355
+ num_zeros_to_miss = max(num_zeros_to_miss, cd[j] - (index - 1 - j))
356
+ num_zeros = 0
357
+ for j in range(index, len(cd)):
358
+ if cd[j] != 0:
359
+ break
360
+ num_zeros += 1
361
+ if num_zeros >= num_zeros_to_miss:
362
+ cd1 = cd[:index]
363
+ cd2 = [0 for i in range(index)] + cd[index:]
364
+ perms2 += [
365
+ uncode(cd1),
366
+ uncode(cd2),
367
+ ]
368
+ did = True
369
+ break
370
+ if not did:
371
+ perms2 += [perm]
372
+ return perms2
373
+
374
+ bad_classical_patterns = [Permutation([1,4,2,3]), Permutation([1,4,3,2]), Permutation([4,1,3,2]), Permutation([3,1,4,2])]