schubmult 2.0.4__py3-none-any.whl → 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. schubmult/__init__.py +96 -1
  2. schubmult/perm_lib.py +254 -819
  3. schubmult/poly_lib/__init__.py +31 -0
  4. schubmult/poly_lib/poly_lib.py +276 -0
  5. schubmult/poly_lib/schub_poly.py +148 -0
  6. schubmult/poly_lib/variables.py +204 -0
  7. schubmult/rings/__init__.py +18 -0
  8. schubmult/rings/_quantum_schubert_polynomial_ring.py +752 -0
  9. schubmult/rings/_schubert_polynomial_ring.py +1031 -0
  10. schubmult/rings/_tensor_schub_ring.py +128 -0
  11. schubmult/rings/_utils.py +55 -0
  12. schubmult/{sage_integration → sage}/__init__.py +4 -1
  13. schubmult/{sage_integration → sage}/_fast_double_schubert_polynomial_ring.py +67 -109
  14. schubmult/{sage_integration → sage}/_fast_schubert_polynomial_ring.py +33 -28
  15. schubmult/{sage_integration → sage}/_indexing.py +9 -5
  16. schubmult/schub_lib/__init__.py +51 -0
  17. schubmult/{schubmult_double/_funcs.py → schub_lib/double.py} +532 -596
  18. schubmult/{schubmult_q/_funcs.py → schub_lib/quantum.py} +54 -53
  19. schubmult/schub_lib/quantum_double.py +954 -0
  20. schubmult/schub_lib/schub_lib.py +659 -0
  21. schubmult/{schubmult_py/_funcs.py → schub_lib/single.py} +45 -35
  22. schubmult/schub_lib/tests/__init__.py +0 -0
  23. schubmult/schub_lib/tests/legacy_perm_lib.py +946 -0
  24. schubmult/schub_lib/tests/test_vs_old.py +109 -0
  25. schubmult/scripts/__init__.py +0 -0
  26. schubmult/scripts/schubmult_double.py +378 -0
  27. schubmult/scripts/schubmult_py.py +84 -0
  28. schubmult/scripts/schubmult_q.py +109 -0
  29. schubmult/scripts/schubmult_q_double.py +207 -0
  30. schubmult/utils/__init__.py +0 -0
  31. schubmult/{_base_argparse.py → utils/argparse.py} +29 -5
  32. schubmult/utils/logging.py +16 -0
  33. schubmult/utils/parsing.py +20 -0
  34. schubmult/utils/perm_utils.py +135 -0
  35. schubmult/utils/test_utils.py +65 -0
  36. schubmult-3.0.1.dist-info/METADATA +1234 -0
  37. schubmult-3.0.1.dist-info/RECORD +41 -0
  38. {schubmult-2.0.4.dist-info → schubmult-3.0.1.dist-info}/WHEEL +1 -1
  39. schubmult-3.0.1.dist-info/entry_points.txt +5 -0
  40. schubmult/_tests.py +0 -24
  41. schubmult/schubmult_double/__init__.py +0 -12
  42. schubmult/schubmult_double/__main__.py +0 -6
  43. schubmult/schubmult_double/_script.py +0 -474
  44. schubmult/schubmult_py/__init__.py +0 -12
  45. schubmult/schubmult_py/__main__.py +0 -6
  46. schubmult/schubmult_py/_script.py +0 -97
  47. schubmult/schubmult_q/__init__.py +0 -8
  48. schubmult/schubmult_q/__main__.py +0 -6
  49. schubmult/schubmult_q/_script.py +0 -166
  50. schubmult/schubmult_q_double/__init__.py +0 -10
  51. schubmult/schubmult_q_double/__main__.py +0 -6
  52. schubmult/schubmult_q_double/_funcs.py +0 -540
  53. schubmult/schubmult_q_double/_script.py +0 -396
  54. schubmult-2.0.4.dist-info/METADATA +0 -542
  55. schubmult-2.0.4.dist-info/RECORD +0 -30
  56. schubmult-2.0.4.dist-info/entry_points.txt +0 -5
  57. {schubmult-2.0.4.dist-info → schubmult-3.0.1.dist-info}/licenses/LICENSE +0 -0
  58. {schubmult-2.0.4.dist-info → schubmult-3.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,946 @@
1
+ from bisect import bisect_left
2
+ from functools import cache
3
+ from itertools import chain
4
+
5
+ import numpy as np
6
+ from symengine import Mul, Pow, symarray, sympify
7
+
8
# Default number of quantum parameters and the array of ``n`` symbols with
# prefix "q" used as the default ``q_var`` argument by the quantum helpers.
n = 100
q_var = symarray("q", n)

# Shared symengine zero returned by the polynomial helpers for vanishing cases.
zero = sympify(0)
12
+
13
+
14
def getpermval(perm, index):
    """Return ``perm`` at 0-based ``index``; positions past the end are fixed points (``index + 1``)."""
    return perm[index] if index < len(perm) else index + 1
18
+
19
+
20
def inv(perm):
    """Return the number of inversions of the one-line permutation ``perm``.

    For each value (left to right), its rank among the values not yet seen
    is the number of smaller values still to its right; these ranks are summed.
    """
    unused = list(range(1, len(perm) + 1))
    count = 0
    for value in perm:
        rank = bisect_left(unused, value)
        count += rank
        # Slice deletion (not ``del``) so an out-of-range rank is a no-op.
        unused[rank : rank + 1] = []
    return count
29
+
30
+
31
def code(perm):
    """Return the Lehmer code of ``perm`` without its final (always zero) entry.

    Entry ``i`` counts the positions after ``i`` that hold a smaller value.
    """
    unused = list(range(1, len(perm) + 1))
    lehmer = []
    for value in perm[:-1]:
        rank = bisect_left(unused, value)
        lehmer.append(rank)
        unused[rank : rank + 1] = []
    return lehmer
40
+
41
+
42
def mulperm(perm1, perm2):
    """Return the composition ``perm1 * perm2`` in one-line notation.

    Entry ``i`` is ``perm1[perm2[i] - 1]``. Values of ``perm2`` beyond
    ``len(perm1)`` are treated as fixed points of ``perm1``. When ``perm1``
    is at least as long as ``perm2``, its remaining tail is appended, which
    requires ``perm1`` to be a list.
    """
    size1 = len(perm1)
    if size1 < len(perm2):
        return [perm1[x - 1] if x <= size1 else x for x in perm2]
    head = [perm1[x - 1] for x in perm2]
    return head + perm1[len(perm2) :]
46
+
47
+
48
def uncode(cd):
    """Return the permutation whose (trimmed) Lehmer code is ``cd``.

    The code is zero-padded until every entry fits. Each entry then picks
    (and removes) that 0-based rank from the remaining values. An empty
    code yields ``[1, 2]``.
    """
    digits = list(cd)
    if not digits:
        return [1, 2]
    needed = max(c + i for i, c in enumerate(digits))
    digits += [0] * (needed - len(digits))
    pool = list(range(1, len(digits) + 2))
    perm = [pool.pop(c) for c in digits]
    perm.append(pool[0])
    return perm
60
+
61
+
62
def reversecode(perm):
    """Return the reverse code of ``perm``.

    The result has ``len(perm) - 1`` entries. Entry ``i - 1`` (for
    ``i = 1 .. len(perm) - 1``) counts the positions ``j < i`` with
    ``perm[j] < perm[i]``. Each count is stored in its own slot.
    """
    return [sum(1 for j in range(i) if perm[j] < perm[i]) for i in range(1, len(perm))]
70
+
71
+
72
def reverseuncode(cd):
    """Return the permutation with reverse code ``cd``.

    ``cd[i - 1]`` is the number of positions ``j < i`` with
    ``perm[j] < perm[i]``, i.e. the 0-based rank of ``perm[i]`` among the
    values at positions ``0..i``. Reading ``cd`` from the right, each entry
    therefore picks (and removes) that rank from the values not yet placed.
    The single leftover value goes in position 0. Every entry of ``cd`` is
    used, so the result has ``len(cd) + 1`` entries. An empty code yields
    ``[1, 2]``, matching ``uncode``.
    """
    digits = list(cd)
    if not digits:
        return [1, 2]
    pool = list(range(1, len(digits) + 2))
    placed_from_right = [pool.pop(rank) for rank in reversed(digits)]
    return [pool[0], *reversed(placed_from_right)]
85
+
86
+
87
def inverse(perm):
    """Return the inverse of the one-line permutation ``perm`` as a list."""
    result = [0] * len(perm)
    for position, value in enumerate(perm, start=1):
        result[value - 1] = position
    return result
92
+
93
+
94
def permtrim(perm):
    """Strip trailing fixed points from ``perm`` in place and return it.

    At least two entries are always kept. Callers rely on the in-place
    mutation, so ``perm`` must be a list whenever something is removed.
    """
    size = len(perm)
    while size > 2 and perm[-1] == size:
        perm.pop()
        size -= 1
    return perm
99
+
100
+
101
def has_bruhat_descent(perm, i, j):
    """Return True if swapping positions ``i < j`` of ``perm`` is a Bruhat cover downward.

    This requires ``perm[i] >= perm[j]`` with no value strictly between
    them at a position strictly between ``i`` and ``j``.
    """
    high, low = perm[i], perm[j]
    if high < low:
        return False
    return not any(low < perm[m] < high for m in range(i + 1, j))
108
+
109
+
110
def count_bruhat(perm, i, j):
    """Return the length change score for transposing positions ``i`` and ``j`` of ``perm``.

    The score starts at +1 if ``perm[i] < perm[j]``, else -1. It moves by
    +2 for each intermediate value strictly between them in increasing
    order, and by -2 for each one strictly between them in decreasing order.
    """
    left, right = perm[i], perm[j]
    score = 1 if left < right else -1
    for m in range(i + 1, j):
        between = perm[m]
        if left < between < right:
            score += 2
        elif left > between > right:
            score -= 2
    return score
122
+
123
+
124
def has_bruhat_ascent(perm, i, j):
    """Return True if swapping positions ``i < j`` of ``perm`` is a Bruhat cover upward.

    This requires ``perm[i] <= perm[j]`` with no value strictly between
    them at a position strictly between ``i`` and ``j``.
    """
    low, high = perm[i], perm[j]
    if low > high:
        return False
    return all(not (low < perm[m] < high) for m in range(i + 1, j))
131
+
132
+
133
def elem_sym_perms(orig_perm, p, k):
    """Enumerate permutations reachable from ``orig_perm`` by up to ``p`` ascent swaps.

    Each step swaps a position ``i < k`` with a position ``j >= k`` when
    ``has_bruhat_ascent`` allows it. The value moved into position ``i``
    must be smaller than the value moved at the previous step, and the
    value at ``i`` must also be below that bound.

    Returns:
        A list of ``(perm, steps)`` pairs starting with ``(orig_perm, 0)``.
        Swapped permutations are tuples with a trailing fixed point (if
        present) dropped.
    """
    results = [(orig_perm, 0)]
    frontier = [(orig_perm, 1000000000)]  # sentinel bound above any value
    for step in range(1, p + 1):
        next_frontier = []
        for base, bound in frontier:
            ext = [*base, len(base) + 1]
            if len(ext) < k + 1:
                ext.extend(range(len(ext) + 1, k + 3))
            left_slots = [i for i in range(k) if ext[i] < bound]
            for j in range(k, len(ext)):
                moved = ext[j]
                if moved >= bound:
                    continue
                for i in left_slots:
                    if not has_bruhat_ascent(ext, i, j):
                        continue
                    swapped = list(ext)
                    swapped[i], swapped[j] = swapped[j], swapped[i]
                    if swapped[-1] == len(swapped):
                        key = tuple(swapped[:-1])
                    else:
                        key = tuple(swapped)
                    next_frontier.append((key, moved))
                    results.append((key, step))
        frontier = next_frontier
    return results
158
+
159
+
160
def elem_sym_perms_op(orig_perm, p, k):
    """Enumerate permutations reachable from ``orig_perm`` by up to ``p`` descent swaps.

    Each step swaps a position ``i < k`` whose value still matches
    ``orig_perm`` with a position ``j`` (at or after the previous step's
    ``j``, starting from ``k``) when ``has_bruhat_descent`` allows it.

    Returns:
        A list of ``(perm, steps)`` pairs starting with ``(orig_perm, 0)``.
        Swapped permutations are trimmed tuples.
    """
    results = [(orig_perm, 0)]
    frontier = [(orig_perm, k)]
    for step in range(1, p + 1):
        next_frontier = []
        for base, start_j in frontier:
            ext = list(base)
            if len(ext) < k + 1:
                ext.extend(range(len(ext) + 1, k + 3))
            untouched = [i for i in range(k) if getpermval(ext, i) == getpermval(orig_perm, i)]
            for j in range(start_j, len(ext)):
                for i in untouched:
                    if not has_bruhat_descent(ext, i, j):
                        continue
                    swapped = list(ext)
                    swapped[i], swapped[j] = swapped[j], swapped[i]
                    key = tuple(permtrim(swapped))
                    next_frontier.append((key, j))
                    results.append((key, step))
        frontier = next_frontier
    return results
180
+
181
+
182
def strict_theta(u):
    """Return the strictly decreasing rearrangement of ``u``'s trimmed code.

    Repeatedly, scanning from the right, the first adjacent pair with a
    nonzero right entry and ``left <= right`` is replaced by
    ``(right + 1, left)``, until no such pair exists. Trailing zeros are
    then dropped. Zeros never take part in a swap, so trimming them once at
    the end gives the same result as trimming after every pass.
    """
    theta_code = list(trimcode(u))
    changed = True
    while changed:
        changed = False
        for i in reversed(range(len(theta_code) - 1)):
            left, right = theta_code[i], theta_code[i + 1]
            if right != 0 and left <= right:
                theta_code[i], theta_code[i + 1] = right + 1, left
                changed = True
                break
    while theta_code and theta_code[-1] == 0:
        theta_code.pop()
    return theta_code
195
+
196
+
197
def elem_sym_perms_q(orig_perm, p, k, q_var=q_var):
    """Enumerate quantum-Bruhat moves from ``orig_perm`` for up to ``p`` steps.

    Each step swaps a position ``i < k`` whose value still matches
    ``orig_perm`` with a position ``j >= k``. The ``j`` positions are
    scanned downward from the previous step's ``j``. A swap is accepted
    when ``count_bruhat`` is 1 (an ordinary cover) or ``2 * (i - j) + 1``.
    In the latter, negative case the weight is multiplied by
    ``q_var[i + 1] * ... * q_var[j]``.

    Returns:
        A list of ``(perm, steps, weight)`` triples starting with
        ``(orig_perm, 0, 1)``. Swapped permutations are trimmed tuples.
    """
    results = [(orig_perm, 0, 1)]
    frontier = [(orig_perm, 1, 1000)]
    for step in range(1, p + 1):
        next_frontier = []
        for base, weight, max_j in frontier:
            ext = [*base, len(base) + 1]
            if len(ext) < k + 1:
                ext.extend(range(len(ext) + 1, k + 3))
            fixed = [i for i in range(k) if ext[i] == getpermval(orig_perm, i)]
            for j in range(min(len(ext) - 1, max_j), k - 1, -1):
                for i in fixed:
                    ct = count_bruhat(ext, i, j)
                    if ct != 1 and ct != 2 * (i - j) + 1:
                        continue
                    swapped = list(ext)
                    swapped[i], swapped[j] = swapped[j], swapped[i]
                    key = tuple(permtrim(swapped))
                    new_weight = weight
                    if ct < 0:
                        new_weight *= np.prod([q_var[idx] for idx in range(i + 1, j + 1)])
                    next_frontier.append((key, new_weight, j))
                    results.append((key, step, new_weight))
        frontier = next_frontier
    return results
222
+
223
+
224
def elem_sym_perms_q_op(orig_perm, p, k, n, q_var=q_var):
    """Enumerate downward quantum-Bruhat moves from ``orig_perm`` within ``S_n``.

    Each step swaps a position ``i < k`` whose value still matches
    ``orig_perm`` with a position ``j`` (at or after the previous step's
    ``j``, below ``n``). A swap is accepted when ``count_bruhat`` is -1 or
    ``2 * (j - i) - 1``. In the latter, positive case the weight is
    multiplied by ``q_var[i + 1] * ... * q_var[j]``.

    Returns:
        A list of ``(perm, steps, weight)`` triples starting with
        ``(orig_perm, 0, 1)``. Swapped permutations are trimmed tuples.
    """
    results = [(orig_perm, 0, 1)]
    frontier = [(orig_perm, 1, k)]
    for step in range(1, p + 1):
        next_frontier = []
        for base, weight, min_j in frontier:
            ext = list(base)
            ext.extend(range(len(ext) + 1, n + 1))  # pad to length n (no-op if already long enough)
            fixed = [i for i in range(k) if ext[i] == getpermval(orig_perm, i)]
            for j in range(min_j, n):
                for i in fixed:
                    ct = count_bruhat(ext, i, j)
                    if ct != -1 and ct != 2 * (j - i) - 1:
                        continue
                    swapped = list(ext)
                    swapped[i], swapped[j] = swapped[j], swapped[i]
                    key = tuple(permtrim(swapped))
                    new_weight = weight
                    if ct > 0:
                        new_weight *= np.prod([q_var[idx] for idx in range(i + 1, j + 1)])
                    next_frontier.append((key, new_weight, j))
                    results.append((key, step, new_weight))
        frontier = next_frontier
    return results
249
+
250
+
251
def q_vector(q_exp, q_var=q_var):
    """Return the exponent vector of a monomial in the ``q_var`` symbols.

    - ``1`` maps to ``[]``.
    - A single symbol ``q_var[i]`` maps to ``i - 1`` zeros followed by 1.
    - A power ``q_var[i] ** e`` maps to ``i - 1`` zeros followed by ``e``.
    - A product is the entrywise sum of its factors' vectors, zero-padded
      to a common length.
    - Any other expression gives ``None``.

    The ``q_var`` passed by the caller is propagated into the product's
    factors, so a custom symbol array is honoured consistently.
    """
    qvar_list = q_var.tolist()
    ret = []

    if q_exp == 1:
        return ret
    if q_exp in q_var:
        i = qvar_list.index(q_exp)
        return [0] * (i - 1) + [1]
    if isinstance(q_exp, Pow):
        base = q_exp.args[0]
        expon = int(q_exp.args[1])
        i = qvar_list.index(base)
        return [0] * (i - 1) + [expon]
    if isinstance(q_exp, Mul):
        for factor in q_exp.args:
            part = q_vector(factor, q_var)
            width = max(len(part), len(ret))
            part += [0] * (width - len(part))
            ret += [0] * (width - len(ret))
            ret = [x + y for x, y in zip(ret, part)]
        return ret

    return None
274
+
275
+
276
def omega(i, qv):
    """Return the ``omega`` value at 1-based index ``i`` for the exponent vector ``qv``.

    For interior positions this is ``2*qv[i-1] - qv[i-2] - qv[i]``; the
    boundary cases omit the missing neighbours. An empty ``qv``, or an index
    more than one past its end, gives 0.
    """
    pos = i - 1
    size = len(qv)
    if size == 0 or pos > size:
        return 0
    if pos == 0:
        return 2 * qv[0] if size == 1 else 2 * qv[0] - qv[1]
    if pos == size:
        return -qv[-1]
    if pos == size - 1:
        return 2 * qv[-1] - qv[-2]
    return 2 * qv[pos] - qv[pos - 1] - qv[pos + 1]
289
+
290
+
291
def sg(i, w):
    """Return 1 if ``w`` has a descent at 0-based position ``i``, else 0.

    Positions at or past the last index count as ascents.
    """
    has_descent = i < len(w) - 1 and not w[i] < w[i + 1]
    return 1 if has_descent else 0
295
+
296
+
297
def reduce_q_coeff(u, v, w, qv):
    """Try one reduction step on the quantum coefficient data ``(u, v, w, qv)``.

    At the first position ``i`` (below ``len(qv)``) where a rule applies,
    the descent at ``i`` is swapped out of ``v`` (first rule) or ``u``
    (second rule), and also out of ``w`` (padding ``w`` if needed).
    ``qv[i]`` is decreased by one when ``w`` had no descent there.

    Returns:
        ``(u, v, w, qv, True)`` with the reduced data, or the inputs
        unchanged plus ``False`` if no rule applies.
    """
    for i in range(len(qv)):
        w_desc = sg(i, w)
        level = w_desc + omega(i + 1, qv)
        u_desc = sg(i, u)
        v_desc = sg(i, v)
        if v_desc == 1 and u_desc == 0 and level == 1:
            new_v = list(v)
            new_v[i], new_v[i + 1] = new_v[i + 1], new_v[i]
            new_w = [*w, *range(len(w) + 1, i + 3)]
            new_w[i], new_w[i + 1] = new_w[i + 1], new_w[i]
            new_qv = list(qv)
            if w_desc == 0:
                new_qv[i] -= 1
            return u, tuple(permtrim(new_v)), tuple(permtrim(new_w)), new_qv, True
        # Equivalent to: u descends and (v ascends and level 1, or v descends and level 2).
        if u_desc == 1 and level == 1 + v_desc:
            new_u = list(u)
            new_u[i], new_u[i + 1] = new_u[i + 1], new_u[i]
            new_w = [*w, *range(len(w) + 1, i + 3)]
            new_w[i], new_w[i + 1] = new_w[i + 1], new_w[i]
            new_qv = list(qv)
            if w_desc == 0:
                new_qv[i] -= 1
            return tuple(permtrim(new_u)), v, tuple(permtrim(new_w)), new_qv, True
    return u, v, w, qv, False
318
+
319
+
320
def reduce_q_coeff_u_only(u, v, w, qv):
    """Like ``reduce_q_coeff`` but only ever swaps a descent out of ``u`` (and ``w``).

    Returns ``(u, v, w, qv, True)`` with the reduced data, or the inputs
    unchanged plus ``False``.
    """
    for i in range(len(qv)):
        if sg(i, u) != 1:
            continue
        w_desc = sg(i, w)
        # u descends and (v ascends with level 1, or v descends with level 2).
        if w_desc + omega(i + 1, qv) != 1 + sg(i, v):
            continue
        new_u = list(u)
        new_u[i], new_u[i + 1] = new_u[i + 1], new_u[i]
        new_w = [*w, *range(len(w) + 1, i + 3)]
        new_w[i], new_w[i + 1] = new_w[i + 1], new_w[i]
        new_qv = list(qv)
        if w_desc == 0:
            new_qv[i] -= 1
        return tuple(permtrim(new_u)), v, tuple(permtrim(new_w)), new_qv, True
    return u, v, w, qv, False
332
+
333
+
334
def longest_element(indices):
    """Return the trimmed permutation obtained by sorting out all descents at ``indices``.

    Starting from the identity, repeatedly swap positions ``idx - 1`` and
    ``idx`` for each 1-based ``idx`` in ``indices`` where there is no
    descent yet, until every such position is a descent.
    """
    perm = [1, 2]
    swapped_any = True
    while swapped_any:
        swapped_any = False
        for index in indices:
            j = index - 1
            if sg(j, perm) != 0:
                continue
            if len(perm) < j + 2:
                perm = perm + list(range(len(perm) + 1, j + 3))
            perm[j], perm[j + 1] = perm[j + 1], perm[j]
            swapped_any = True
    return permtrim(perm)
347
+
348
+
349
def count_less_than(arr, val):
    """Return the length of the leading run of ``arr`` whose entries are ``< val``.

    Counting stops at the first entry that is not less than ``val``, so this
    equals the total count only when ``arr`` is sorted.
    """
    count = 0
    for item in arr:
        if not item < val:
            break
        count += 1
    return count
356
+
357
+
358
def is_parabolic(w, parabolic_index):
    """Return True if ``w`` has no descent at any 1-based index in ``parabolic_index``."""
    return all(sg(index - 1, w) != 1 for index in parabolic_index)
363
+
364
+
365
def check_blocks(qv, parabolic_index):
    """Return True if every run of consecutive parabolic indices has admissible omega sums.

    ``parabolic_index`` is split, in the given order, into maximal runs of
    consecutive integers. Every index, including those that start a new
    run and those in the final run, belongs to exactly one run. For each
    contiguous stretch ``block[i..j]`` of a run, the sum of
    ``omega(index, qv)`` must be 0 or -1.
    """
    blocks = []
    current = []
    for index in parabolic_index:
        if current and current[-1] + 1 != index:
            blocks.append(current)
            current = []
        current.append(index)
    if current:
        blocks.append(current)

    for block in blocks:
        for start in range(len(block)):
            # Running sum over block[start..stop] for each stop >= start.
            total = 0
            for stop in range(start, len(block)):
                total += omega(block[stop], qv)
                if total not in (0, -1):
                    return False
    return True
385
+
386
+
387
# perms and inversion diff
def kdown_perms(perm, monoperm, p, k):
    """Enumerate permutations reachable from ``perm`` by swaps at 0-based position ``k - 1``.

    Over up to ``p`` rounds, position ``k - 1`` is swapped with a position
    ``b`` whose value still agrees with ``perm``. The swap is made only when
    ``has_bruhat_descent`` holds for the ordered pair ``(min, max)`` of
    ``(b, k - 1)``. A candidate is recorded when
    ``inv(candidate * monoperm) == inv(monoperm) - inv(perm) + round``.

    The sign is negated for a swap with a position left of ``k - 1`` and
    kept for a swap to the right.

    Returns:
        A list of ``(perm_tuple, round, sign)``. ``perm`` itself appears as
        ``(tuple(perm), 0, 1)`` if it satisfies the round-0 condition.
    """
    inv_m = inv(monoperm)
    inv_p = inv(perm)
    full_perm_list = []
    if inv(mulperm(list(perm), monoperm)) == inv_m - inv_p:
        full_perm_list.append((tuple(perm), 0, 1))

    down_perm_list = [(perm, 1)]
    if len(perm) < k:
        return full_perm_list
    a2 = k - 1
    for pp in range(1, p + 1):
        down_perm_list2 = []
        for perm2, s in down_perm_list:
            L = len(perm2)
            if k > L:
                continue
            s2 = -s
            for b in chain(range(k - 1), range(k, L)):
                if perm2[b] != perm[b]:
                    continue
                if b < a2:
                    i, j = b, a2
                else:
                    i, j, s2 = a2, b, s
                if not has_bruhat_descent(perm2, i, j):
                    continue
                new_perm = [*perm2]
                new_perm[a2], new_perm[b] = new_perm[b], new_perm[a2]
                permtrim(new_perm)  # trims new_perm in place
                down_perm_list2.append((new_perm, s2))
                if inv(mulperm(new_perm, monoperm)) == inv_m - inv_p + pp:
                    full_perm_list.append((tuple(new_perm), pp, s2))
        down_perm_list = down_perm_list2
    return full_perm_list
429
+
430
+
431
def compute_vpathdicts(th, vmu, smpify=False):
    """Build per-level path dictionaries for walking down from ``vmu`` along ``th``.

    Levels are processed from last to first, calling ``kdown_perms`` for
    every permutation reached at that level. The resulting edges are then
    inverted so that level ``i`` maps each target permutation to a set of
    ``(source, steps, sign)`` triples. When ``smpify`` is true, the sign is
    passed through ``sympify``.
    """
    level_count = len(th)
    forward = [{} for _ in range(level_count)]
    forward[-1][tuple(vmu)] = None

    top = code(inverse(uncode(th)))
    for i in reversed(range(level_count)):
        top2 = code(inverse(uncode(top)))
        while top2[-1] == 0:
            top2.pop()
        top2.pop()
        top = code(inverse(uncode(top2)))
        monoperm = uncode(top)
        if len(monoperm) < 2:
            monoperm = [1, 2]
        for last_perm in forward[i]:
            newperms = kdown_perms(last_perm, monoperm, th[i], i + 1)
            forward[i][last_perm] = newperms
            if i > 0:
                for trip in newperms:
                    forward[i - 1][trip[0]] = None

    backward = [{} for _ in range(level_count)]
    for i, level in enumerate(forward):
        for source, edges in level.items():
            for target, steps, sign in edges:
                if smpify:
                    sign = sympify(sign)
                backward[i].setdefault(target, set()).add((source, steps, sign))
    return backward
466
+
467
+
468
def theta(perm):
    """Return the partition ``theta(perm)`` derived from ``perm``'s code.

    Walking right to left, each code entry is increased once for every
    earlier entry that is smaller than its current value. The result is
    sorted in decreasing order.
    """
    cd = code(perm)
    for i in reversed(range(1, len(cd))):
        for j in reversed(range(i)):
            if cd[j] < cd[i]:
                cd[i] += 1
    return sorted(cd, reverse=True)
476
+
477
+
478
def add_perm_dict(d1, d2):
    """Add the coefficients of ``d2`` into ``d1`` in place and return ``d1``."""
    lookup = d1.get
    for perm, coeff in d2.items():
        d1[perm] = lookup(perm, 0) + coeff
    return d1
482
+
483
+
484
# Shared symengine constant 1, returned by elem_sym_poly, elem_sym_poly_q and
# the elem_sym_func helpers for their degree-0 / trivial cases.
one = sympify(1)
485
+
486
+
487
def elem_sym_poly_q(p, k, varl1, varl2, q_var=q_var):
    """Return the quantum elementary symmetric polynomial of degree ``p`` in ``k`` variables.

    The recurrence is ``E(p, k) = (x_k - y_{k-p+1}) E(p-1, k-1) + E(p, k-1)
    + q_k E(p-2, k-2)``. Here ``x``/``y``/``q`` are read 0-based from
    ``varl1``/``varl2``/``q_var``. The base cases are ``one`` for
    ``p == 0`` and ``zero`` when ``p < 0`` or ``p > k``.
    """
    if p == 0 and k >= 0:
        return one
    if p < 0 or p > k:
        return zero
    linear_term = (varl1[k - 1] - varl2[k - p]) * elem_sym_poly_q(p - 1, k - 1, varl1, varl2, q_var)
    skip_term = elem_sym_poly_q(p, k - 1, varl1, varl2, q_var)
    quantum_term = q_var[k - 1] * elem_sym_poly_q(p - 2, k - 2, varl1, varl2, q_var)
    return linear_term + skip_term + quantum_term
497
+
498
+
499
def elem_sym_poly(p, k, varl1, varl2, xstart=0, ystart=0):
    """Return a degree-``p`` elementary symmetric polynomial in ``k`` shifted differences.

    The ``k`` x-variables start at ``varl1[xstart]`` and the y-variables at
    ``varl2[ystart]``. The cases are:

    - ``p > k`` gives ``zero`` and ``p == 0`` gives ``one``.
    - ``p == 1`` sums ``x_{xstart+i} - y_{ystart+i}``.
    - ``p == k`` multiplies ``x_{xstart+i} - y_{ystart}``.
    - Otherwise the variables are split in half and the two halves'
      contributions are combined recursively, shifting the y-start of the
      right half by the degree taken on the left.
    """
    if p > k:
        return zero
    if p == 0:
        return one
    if p == 1:
        total = varl1[xstart] - varl2[ystart]
        for offset in range(1, k):
            total += varl1[xstart + offset] - varl2[ystart + offset]
        return total
    if p == k:
        shift = varl2[ystart]
        product = (varl1[xstart] - shift) * (varl1[xstart + 1] - shift)
        for offset in range(2, k):
            product *= varl1[xstart + offset] - shift
        return product
    half = k // 2
    rest = k - half
    x_mid = xstart + half
    y_mid = ystart + half
    res = elem_sym_poly(p, half, varl1, varl2, xstart, ystart) + elem_sym_poly(p, rest, varl1, varl2, x_mid, y_mid)
    for low in range(max(1, p - rest), min(p, half + 1)):
        left = elem_sym_poly(low, half, varl1, varl2, xstart, ystart)
        right = elem_sym_poly(p - low, rest, varl1, varl2, x_mid, y_mid - low)
        res += left * right
    return res
536
+
537
+
538
@cache
def call_zvars(v1, v2, k, i):  # noqa: ARG001
    """Return the indices of the z-variables used by ``elem_sym_func``.

    ``v2`` is padded with fixed points up to length ``i``. The result is
    built in three parts:

    - the entry at position ``i - 1`` comes first;
    - then the non-fixed entries beyond ``len(v1)``;
    - then the entries that differ from ``v1``.

    Position ``i - 1`` is excluded from the latter two parts. ``k`` is
    unused, and all arguments must be hashable because of ``@cache``.
    """
    padded = [*v2, *range(len(v2) + 1, i + 1)]
    head = [padded[i - 1]]
    beyond_v1 = [padded[j] for j in range(len(v1), len(padded)) if padded[j] != j + 1 and j != i - 1]
    changed = [padded[j] for j in range(len(v1)) if v1[j] != padded[j] and j != i - 1]
    return head + beyond_v1 + changed
542
+
543
+
544
def elem_sym_func(k, i, u1, u2, v1, v2, udiff, vdiff, varl1, varl2):
    """Return the elementary-symmetric factor for one step of the double multiplication.

    With ``newk = k - udiff``:

    - ``newk < vdiff`` gives ``zero`` and ``newk == vdiff`` gives ``one``.
    - Otherwise the y-variables are the ``varl1`` entries at the first ``k``
      positions left fixed between ``u1`` and ``u2``. The z-variables are
      the ``varl2`` entries selected by ``call_zvars``.
    """
    newk = k - udiff
    if newk < vdiff:
        return zero
    if newk == vdiff:
        return one
    yvars = [varl1[u2[j]] for j in range(min(len(u1), k)) if u1[j] == u2[j]]
    yvars += [varl1[u2[j]] for j in range(len(u1), min(k, len(u2))) if u2[j] == j + 1]
    yvars += [varl1[j + 1] for j in range(len(u2), k)]
    zvars = [varl2[idx] for idx in call_zvars(v1, v2, k, i)]
    return elem_sym_poly(newk - vdiff, newk, yvars, zvars)
561
+
562
+
563
def elem_sym_func_q(k, i, u1, u2, v1, v2, udiff, vdiff, varl1, varl2):
    """Quantum variant of ``elem_sym_func``: ``u1`` and ``u2`` are first padded to equal length.

    With ``newk = k - udiff``:

    - ``newk < vdiff`` gives ``zero`` and ``newk == vdiff`` gives ``one``.
    - Otherwise the result is ``elem_sym_poly`` in the y-variables of the
      first ``k`` positions fixed between the padded ``u1`` and ``u2``, and
      the z-variables chosen by ``call_zvars``.
    """
    newk = k - udiff
    if newk < vdiff:
        return zero
    if newk == vdiff:
        return one
    width = max(len(u1), len(u2))
    u1 = [*u1, *range(len(u1) + 1, width + 1)]
    u2 = [*u2, *range(len(u2) + 1, width + 1)]
    # After padding, u1 and u2 have the same length, so no position is
    # covered by u2 but not by u1; only positions past both need filling.
    yvars = [varl1[u2[j]] for j in range(min(width, k)) if u1[j] == u2[j]]
    yvars.extend(varl1[j + 1] for j in range(width, k))
    zvars = [varl2[a] for a in call_zvars(v1, v2, k, i)]
    return elem_sym_poly(newk - vdiff, newk, yvars, zvars)
583
+
584
+
585
def trimcode(perm):
    """Return ``code(perm)`` with trailing zeros removed."""
    trimmed = code(perm)
    while trimmed and not trimmed[-1]:
        trimmed.pop()
    return trimmed
590
+
591
+
592
def p_trans(part):
    """Return the conjugate (transpose) of the partition ``part``.

    Row ``r`` of the result counts the entries of ``part`` that are
    ``>= r``, for ``r = 1 .. part[0]``. An empty partition, or one whose
    first entry is 0, gives ``[0]``.
    """
    if len(part) == 0 or part[0] == 0:
        return [0]
    conjugate = []
    for row in range(1, part[0] + 1):
        size = sum(1 for entry in part if entry >= row)
        if size == 0:
            break
        conjugate.append(size)
    return conjugate
605
+
606
+
607
def cycle(p, q):
    """Return the cycle ``(p, p+1, ..., p+q)`` in one-line notation: ``1..p-1, p+1..p+q, p``."""
    return [*range(1, p), *range(p + 1, p + q + 1), p]
609
+
610
+
611
def phi1(u):
    """Drop the first entry of the code of ``u^{-1}`` and return the inverse of the resulting permutation.

    Raises IndexError if ``u^{-1}`` has an empty code.
    """
    inverse_code = code(inverse(u))
    del inverse_code[0]
    return inverse(uncode(inverse_code))
615
+
616
+
617
def one_dominates(u, w):
    """Return True if ``u`` has no descent in positions ``[a, b)``.

    Here ``a`` and ``b`` are the first code entries of ``u^{-1}`` and
    ``w^{-1}``. Reaching the end of ``u`` counts as success.
    """
    start = code(inverse(u))[0]
    stop = code(inverse(w))[0]
    for pos in range(start, stop):
        if pos >= len(u) - 1:
            return True
        if u[pos] > u[pos + 1]:
            return False
    return True
630
+
631
+
632
def dominates(u, w):
    """Return True if repeatedly applying ``phi1`` to both ``u`` and ``w`` reduces ``u`` to the identity.

    The iteration continues only while ``one_dominates`` holds for the
    current pair.
    """
    cur_u = [*u]
    cur_w = [*w]
    while cur_u != [1, 2] and one_dominates(cur_u, cur_w):
        cur_u, cur_w = phi1(cur_u), phi1(cur_w)
    return cur_u == [1, 2]
641
+
642
+
643
def reduce_coeff(u, v, w):
    """Try to reduce the coefficient triple ``(u, v, w)`` via the ``theta`` partitions.

    Returns:
        - ``(u, v, w)`` unchanged if ``w`` is not compatible with the
          combined partition;
        - the trimmed inputs if the reduction leaves ``w`` unchanged;
        - otherwise the reduced ``(u', v', w')``, all trimmed lists.
    """
    theta_u = theta(inverse(u))
    theta_v = theta(inverse(v))
    mu_u_inv = uncode(theta_u)
    mu_v_inv = uncode(theta_v)

    # Add the conjugate partitions entrywise, then conjugate back.
    conj_u = p_trans(theta_u)
    conj_v = p_trans(theta_v)
    width = max(len(conj_u), len(conj_v))
    conj_u += [0] * (width - len(conj_u))
    conj_v += [0] * (width - len(conj_v))
    theta_uv = p_trans([a + b for a, b in zip(conj_u, conj_v)])
    mu_uv_inv = uncode(theta_uv)

    if inv(mulperm(list(w), mu_uv_inv)) != inv(mu_uv_inv) - inv(w):
        return u, v, w

    umu = mulperm(list(u), mu_u_inv)
    vmu = mulperm(list(v), mu_v_inv)
    wmu = mulperm(list(w), mu_uv_inv)
    mu_w = uncode(theta(inverse(wmu)))
    w_prime = mulperm(wmu, mu_w)

    if permtrim(list(w)) == permtrim(w_prime):
        return (permtrim(list(u)), permtrim(list(v)), permtrim(list(w)))

    # Trailing zeros are stripped only now, after the uncode calls above
    # (which depend on them).
    for seq in (theta_u, theta_v, theta_uv):
        while seq and seq[-1] == 0:
            seq.pop()

    # Split the rows of theta_uv into those matched in order by theta_u (A)
    # and the rest (B).
    A = []
    B = []
    matched = 0
    for index, part in enumerate(theta_uv):
        if matched < len(theta_u) and part == theta_u[matched]:
            A.append(index)
            matched += 1
        else:
            B.append(index)

    mu_w_code = code(mu_w)
    mu_w_A = uncode(mu_A(mu_w_code, A))
    mu_w_B = uncode(mu_A(mu_w_code, B))

    return (
        permtrim(mulperm(umu, mu_w_A)),
        permtrim(mulperm(vmu, mu_w_B)),
        permtrim(w_prime),
    )
705
+
706
+
707
def mu_A(mu, A):
    """Return the partition built from the columns of ``mu`` selected by the indices ``A``.

    Out-of-range indices are skipped.
    """
    conj = p_trans(mu)
    picked = [conj[a] for a in A if a < len(conj)]
    return p_trans(picked)
714
+
715
+
716
def reduce_descents(u, v, w):
    """Peel shared descents off ``(u, v, w)`` until a direct formula applies.

    Each pass stops immediately if any of these hold:

    - ``will_formula_work`` holds either way round;
    - ``one_dominates(u, w)`` holds;
    - ``v`` is reducible;
    - ``inv(w) - inv(u) == 1``.

    Otherwise it scans right to left for a descent of ``w`` shared with
    exactly one of ``v`` or ``u`` (the other ascending there), and swaps it
    out of both. Returns the trimmed ``(u, v, w)`` lists.
    """
    u2, v2, w2 = [*u], [*v], [*w]
    while True:
        if will_formula_work(u2, v2) or will_formula_work(v2, u2) or one_dominates(u2, w2) or is_reducible(v2) or inv(w2) - inv(u2) == 1:
            break
        for i in range(len(w2) - 2, -1, -1):
            if not w2[i] > w2[i + 1]:
                continue
            u_ascends = i >= len(u2) - 1 or u2[i] < u2[i + 1]
            v_ascends = i >= len(v2) - 1 or v2[i] < v2[i + 1]
            if i < len(v2) - 1 and v2[i] > v2[i + 1] and u_ascends:
                v2[i], v2[i + 1] = v2[i + 1], v2[i]
            elif i < len(u2) - 1 and u2[i] > u2[i + 1] and v_ascends:
                u2[i], u2[i + 1] = u2[i + 1], u2[i]
            else:
                continue
            w2[i], w2[i + 1] = w2[i + 1], w2[i]
            break
        else:
            # No swap was possible in this pass.
            break
    return permtrim(u2), permtrim(v2), permtrim(w2)
737
+
738
+
739
def is_reducible(v):
    """Return True if no nonzero entry of ``code(v)`` follows a zero entry."""
    seen_zero = False
    for entry in code(v):
        if entry == 0:
            seen_zero = True
        elif seen_zero:
            return False
    return True
750
+
751
+
752
def try_reduce_v(u, v, w):
    """Move ``v`` toward a reducible permutation by recursive adjacent swaps.

    At the rightmost position ``i`` where ``code(v)`` has a 0 followed by a
    nonzero entry, one of two moves is made:

    - if ``u`` ascends there, swap ``v`` and ``w`` (padding ``w``);
    - else if ``w`` descends there, swap ``u`` and ``v``.

    The process then recurses. Returns the trimmed ``(u, v, w)`` as tuples.
    """
    if is_reducible(v):
        return tuple(permtrim([*u])), tuple(permtrim([*v])), tuple(permtrim([*w]))
    u2, v2, w2 = [*u], [*v], [*w]
    cv = code(v2)
    for i in range(len(v2) - 2, -1, -1):
        if not (cv[i] == 0 and i < len(cv) - 1 and cv[i + 1] != 0):
            continue
        if i >= len(u2) - 1 or u2[i] < u2[i + 1]:
            v2[i], v2[i + 1] = v2[i + 1], v2[i]
            if i >= len(w2) - 1:
                w2.extend(range(len(w2) + 1, i + 3))
            w2[i], w2[i + 1] = w2[i + 1], w2[i]
            if is_reducible(v2):
                break
            return try_reduce_v(u2, v2, w2)
        if i < len(w2) - 1 and w2[i] > w2[i + 1]:
            u2[i], u2[i + 1] = u2[i + 1], u2[i]
            v2[i], v2[i + 1] = v2[i + 1], v2[i]
            return try_reduce_v(u2, v2, w2)
        break
    return tuple(permtrim(u2)), tuple(permtrim(v2)), tuple(permtrim(w2))
775
+
776
+
777
def try_reduce_u(u, v, w):
    """Move ``u`` toward ``one_dominates(u, w)`` by recursive adjacent swaps.

    At the rightmost position ``i`` where ``code(u)`` has a 0 followed by a
    nonzero entry, one of two moves is made:

    - if ``v`` ascends there, swap ``u`` and ``w``;
    - else if ``w`` descends there, swap ``u`` and ``v``.

    In the first case ``w`` is padded with fixed points so that index
    ``i + 1`` always exists. The process then recurses. Returns ``(u, v, w)``
    unchanged if it already dominates, otherwise trimmed tuples.
    """
    if one_dominates(u, w):
        return u, v, w
    u2, v2, w2 = [*u], [*v], [*w]
    cu = code(u)
    for i in range(len(u2) - 2, -1, -1):
        if not (cu[i] == 0 and i < len(cu) - 1 and cu[i + 1] != 0):
            continue
        if i >= len(v2) - 1 or v2[i] < v2[i + 1]:
            u2[i], u2[i + 1] = u2[i + 1], u2[i]
            # ``>=`` (as in try_reduce_v): when i == len(w2) - 1, w2[i + 1]
            # does not exist yet and must be padded in.
            if i >= len(w2) - 1:
                w2 += list(range(len(w2) + 1, i + 3))
            w2[i], w2[i + 1] = w2[i + 1], w2[i]
            # Check the updated pair; the inputs are already known not to dominate.
            if one_dominates(u2, w2):
                break
            return try_reduce_u(u2, v2, w2)
        if i < len(w2) - 1 and w2[i] > w2[i + 1]:
            u2[i], u2[i + 1] = u2[i + 1], u2[i]
            v2[i], v2[i + 1] = v2[i + 1], v2[i]
            return try_reduce_u(u2, v2, w2)
        break
    return tuple(permtrim(u2)), tuple(permtrim(v2)), tuple(permtrim(w2))
800
+
801
+
802
def divdiffable(v, u):
    """Return the trimmed ``v * u^{-1}`` if its length is ``inv(v) - inv(u)``, else ``[]``."""
    quotient = permtrim(mulperm(v, inverse(u)))
    return quotient if inv(quotient) == inv(v) - inv(u) else []
809
+
810
+
811
def will_formula_work(u, v):
    """Return True if sorting ``v^{-1} * mu(v)`` never passes a descent of ``u``.

    ``mu(v)`` is ``uncode(theta(v))``. The product is bubble-sorted one
    leftmost descent at a time. The result is False as soon as a swapped
    position is also a descent of ``u``.
    """
    perm = mulperm(inverse(v), uncode(theta(v)))
    while True:
        for i in range(len(perm) - 1):
            if perm[i] > perm[i + 1]:
                if i < len(u) - 1 and u[i] > u[i + 1]:
                    return False
                perm[i], perm[i + 1] = perm[i + 1], perm[i]
                break
        else:
            return True
825
+
826
+
827
def _pull_out_entry(v, vnum, vpm):
    """Build one ``[kept_values, reduced_perm]`` entry for ``pull_out_var``."""
    reduced = [*vpm]
    reduced.pop(vnum - 1)
    vp = permtrim(reduced)
    kept = [v[i] for i in range(vnum, len(v)) if ((i > len(vp) and v[i] == i) or (i <= len(vp) and v[i] == vp[i - 1]))]
    return [kept, vp]


def pull_out_var(vnum, v):
    """Expand ``v`` with respect to the variable at 1-based position ``vnum``.

    ``v`` is extended by the value ``len(v) + 1``. Bruhat ascents are then
    applied, swapping a position ``< vnum`` with a position ``>= vnum``
    while the moved-in values strictly increase. Each permutation that puts
    ``len(v) + 1`` at position ``vnum`` contributes one entry.

    Returns:
        A list of ``[kept_values, reduced_perm]`` pairs, or ``[[[], v]]``
        when ``vnum >= len(v)``.
    """
    vup = [*v, len(v) + 1]
    if vnum >= len(v):
        return [[[], v]]
    top = len(v) + 1
    frontier = [(vup, 0)]
    result = []
    for _ in range(len(v) + 1 - vnum):
        next_frontier = []
        for vpm, floor in frontier:
            if vpm[vnum - 1] == top:
                result.append(_pull_out_entry(v, vnum, vpm))
            for j in range(vnum, len(vup)):
                if vpm[j] <= floor:
                    continue
                for i in range(vnum):
                    if has_bruhat_ascent(vpm, i, j):
                        # Swap temporarily, record a copy, then restore.
                        vpm[i], vpm[j] = vpm[j], vpm[i]
                        next_frontier.append(([*vpm], vpm[i]))
                        vpm[i], vpm[j] = vpm[j], vpm[i]
        frontier = next_frontier
    for vpm, _floor in frontier:
        if vpm[vnum - 1] == top:
            result.append(_pull_out_entry(v, vnum, vpm))
    return result
867
+
868
+
869
def get_cycles(perm):
    """Return the non-trivial cycles of ``perm`` as tuples.

    Cycles are listed in order of their smallest starting position. Each
    tuple is rotated so that the cycle's largest element comes last.
    """
    cycles = []
    visited = set()
    for start in range(1, len(perm) + 1):
        if perm[start - 1] == start or start in visited:
            continue
        orbit = []
        p = start
        while p not in visited:
            orbit.append(p)
            visited.add(p)
            p = perm[p - 1]
        peak = orbit.index(max(orbit))
        cycles.append(tuple(orbit[peak + 1 :] + orbit[: peak + 1]))
    return cycles
891
+
892
+
893
def double_elem_sym_q(u, p1, p2, k, q_var=q_var):
    """Pair quantum elementary moves of degree ``p1`` from ``u`` with compatible degree-``p2`` moves.

    For every ``(perm1, steps1, weight1)`` from ``elem_sym_perms_q(u, p1, k)``
    and every ``(perm2, steps2, weight2)`` from
    ``elem_sym_perms_q(perm1, p2, k)``, the second move is kept unless one
    of its cycles shares its largest element with a first-move cycle and
    also meets that cycle in any other element.

    Returns:
        A dict mapping each first-move triple to its list of kept
        second-move triples. First moves with no kept second move are
        absent.
    """
    ret_list = {}
    iu = inverse(u)
    for first in elem_sym_perms_q(u, p1, k, q_var):
        perm1 = first[0]
        first_cycles_by_max = {}
        for c in get_cycles(tuple(permtrim(mulperm(iu, [*perm1])))):
            first_cycles_by_max.setdefault(c[-1], []).append(set(c))
        ip1 = inverse(perm1)
        for second in elem_sym_perms_q(perm1, p2, k, q_var):
            cycles2 = get_cycles(tuple(permtrim(mulperm(ip1, [*second[0]]))))
            clash = any(
                c2[a] in c1_set
                for c2 in cycles2
                for c1_set in first_cycles_by_max.get(c2[-1], ())
                for a in range(len(c2) - 1)
            )
            if not clash:
                ret_list.setdefault(first, []).append(second)
    return ret_list
929
+
930
+
931
def medium_theta(perm):
    """Return the code of ``perm`` after repeated local straightening moves.

    Scanning left to right, the first applicable rule is applied and the
    scan restarts:

    - an ascent ``cd[i] < cd[i+1]`` becomes ``(cd[i+1] + 1, cd[i])``;
    - otherwise, a nonzero tie ``cd[i] == cd[i+1]`` with ``i > 0`` and
      ``cd[i-1] <= cd[i] + 1`` bumps ``cd[i]`` by one.

    The loop ends when no rule applies.
    """
    cd = code(perm)
    while True:
        for i in range(len(cd) - 1):
            if cd[i] < cd[i + 1]:
                cd[i], cd[i + 1] = cd[i + 1] + 1, cd[i]
                break
            if cd[i] == cd[i + 1] and cd[i] != 0 and i > 0 and cd[i - 1] <= cd[i] + 1:
                cd[i] += 1
                break
        else:
            return cd