schubmult 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36):
  1. schubmult/__init__.py +1 -0
  2. schubmult/_base_argparse.py +174 -0
  3. schubmult/perm_lib.py +999 -0
  4. schubmult/sage_integration/__init__.py +25 -0
  5. schubmult/sage_integration/_fast_double_schubert_polynomial_ring.py +528 -0
  6. schubmult/sage_integration/_fast_schubert_polynomial_ring.py +356 -0
  7. schubmult/sage_integration/_indexing.py +44 -0
  8. schubmult/schubmult_double/__init__.py +18 -0
  9. schubmult/schubmult_double/__main__.py +5 -0
  10. schubmult/schubmult_double/_funcs.py +1590 -0
  11. schubmult/schubmult_double/_script.py +407 -0
  12. schubmult/schubmult_double/_vars.py +16 -0
  13. schubmult/schubmult_py/__init__.py +10 -0
  14. schubmult/schubmult_py/__main__.py +5 -0
  15. schubmult/schubmult_py/_funcs.py +111 -0
  16. schubmult/schubmult_py/_script.py +115 -0
  17. schubmult/schubmult_py/_vars.py +3 -0
  18. schubmult/schubmult_q/__init__.py +12 -0
  19. schubmult/schubmult_q/__main__.py +5 -0
  20. schubmult/schubmult_q/_funcs.py +304 -0
  21. schubmult/schubmult_q/_script.py +157 -0
  22. schubmult/schubmult_q/_vars.py +18 -0
  23. schubmult/schubmult_q_double/__init__.py +14 -0
  24. schubmult/schubmult_q_double/__main__.py +5 -0
  25. schubmult/schubmult_q_double/_funcs.py +507 -0
  26. schubmult/schubmult_q_double/_script.py +337 -0
  27. schubmult/schubmult_q_double/_vars.py +21 -0
  28. schubmult-2.0.0.dist-info/METADATA +455 -0
  29. schubmult-2.0.0.dist-info/RECORD +36 -0
  30. schubmult-2.0.0.dist-info/WHEEL +5 -0
  31. schubmult-2.0.0.dist-info/entry_points.txt +5 -0
  32. schubmult-2.0.0.dist-info/licenses/LICENSE +674 -0
  33. schubmult-2.0.0.dist-info/top_level.txt +2 -0
  34. tests/__init__.py +0 -0
  35. tests/test_fast_double_schubert.py +145 -0
  36. tests/test_fast_schubert.py +38 -0
schubmult/perm_lib.py ADDED
@@ -0,0 +1,999 @@
1
+ from symengine import sympify, Mul, Pow, symarray
2
+ from functools import cache
3
+ from itertools import chain
4
+ from bisect import bisect_left
5
+ import numpy as np
6
+
7
# Shared symengine zero, used as the symbolic zero returned by several helpers below.
zero = sympify(0)
# Number of quantum parameters created for the default q_var.
n = 100

# Default quantum parameters: symengine's symarray("q", n), used as an array-like
# (indexed as q_var[i] and converted with q_var.tolist()). Presumably the symbols
# are named q_0 .. q_99 — confirm against symengine's symarray naming.
q_var = symarray("q", n)
11
+
12
+
13
def getpermval(perm, index):
    """Return ``perm`` at 0-based ``index``, treating positions past the end as fixed points.

    Permutations are stored in one-line notation with trailing fixed points
    dropped, so any position at or beyond ``len(perm)`` maps to ``index + 1``.
    """
    return perm[index] if index < len(perm) else index + 1
17
+
18
+
19
def inv(perm):
    """Return the number of inversions (the length) of ``perm``.

    For each value, in order, count the still-unused values that are smaller.
    Those values occur later in ``perm`` and so form inversions with it.
    """
    unused = list(range(1, len(perm) + 1))
    total = 0
    for value in perm:
        rank = bisect_left(unused, value)
        total += rank
        # Slice deletion is a no-op when rank == len(unused), matching the
        # original slice-and-concatenate removal.
        del unused[rank : rank + 1]
    return total
28
+
29
+
30
def code(perm):
    """Return the Lehmer code of ``perm`` without its (always zero) final entry.

    Entry ``i`` is the number of values after position ``i`` that are smaller
    than ``perm[i]``. The result has ``len(perm) - 1`` entries.
    """
    unused = list(range(1, len(perm) + 1))
    lehmer = []
    for idx in range(len(perm) - 1):
        rank = bisect_left(unused, perm[idx])
        lehmer.append(rank)
        del unused[rank : rank + 1]
    return lehmer
39
+
40
+
41
def mulperm(perm1, perm2):
    """Compose permutations in one-line notation: ``result[i] = perm1[perm2[i] - 1]``.

    The shorter argument is implicitly extended by fixed points, so the result
    has length ``max(len(perm1), len(perm2))``.
    """
    size1 = len(perm1)
    if size1 < len(perm2):
        return [perm1[val - 1] if val <= size1 else val for val in perm2]
    composed = [perm1[val - 1] for val in perm2]
    return composed + perm1[len(perm2) :]
48
+
49
+
50
def uncode(cd):
    """Rebuild a permutation from its Lehmer code (inverse of ``code``).

    The code is zero-padded so that every entry ``cd[i]`` fits, i.e. to length
    ``max(cd[i] + i)``. An empty code yields ``[1, 2]``. Trailing fixed points
    are not trimmed.
    """
    digits = list(cd)
    if not digits:
        return [1, 2]
    needed = max(entry + pos for pos, entry in enumerate(digits))
    digits.extend([0] * (needed - len(digits)))
    pool = list(range(1, len(digits) + 2))
    perm = [pool.pop(entry) for entry in digits]
    perm.append(pool[0])
    return perm
62
+
63
+
64
def reversecode(perm):
    """Return the left-to-right counting code of ``perm``.

    For each 0-based position ``i`` in ``1 .. len(perm) - 1``, entry ``i - 1`` of
    the result is ``#{j < i : perm[j] < perm[i]}``. Position 0 is omitted
    because its count is always 0.
    """
    ret = []
    for i in range(len(perm) - 1, 0, -1):
        # The original prepended a zero slot for position i but incremented
        # ret[-1] (the slot of the first position processed), so every count
        # accumulated in the last entry. Store the count in the new slot instead.
        smaller_left = sum(1 for j in range(i) if perm[i] > perm[j])
        ret.insert(0, smaller_left)
    return ret
72
+
73
+
74
def reverseuncode(cd):
    """Build a one-line permutation from a code read right to left.

    Values are drawn from ``1 .. len(cd) + 1``. For positions ``len(cd) - 1``
    down to ``1``, the ``cd[pos]``-th remaining value (0-based) is taken and
    prepended. The smallest remaining value is then appended at the end.
    An empty code yields ``[1, 2]``.

    NOTE(review): ``cd[0]`` is never read and the leftover value is appended
    rather than prepended, so this does not obviously invert ``reversecode``.
    Confirm the intended semantics with callers.
    """
    digits = list(cd)
    if not digits:
        return [1, 2]
    pool = list(range(1, len(digits) + 2))
    head = []
    for pos in range(len(digits) - 1, 0, -1):
        head.insert(0, pool.pop(digits[pos]))
    head.append(pool[0])
    return head
87
+
88
+
89
def inverse(perm):
    """Return the inverse of ``perm`` in one-line notation (1-based values)."""
    result = [0] * len(perm)
    for pos, value in enumerate(perm, start=1):
        result[value - 1] = pos
    return result
94
+
95
+
96
def permtrim(perm):
    """Strip trailing fixed points from ``perm`` in place and return it.

    The list is never shortened below length 2. It is mutated, and some
    callers in this module rely on that and ignore the return value.
    """
    size = len(perm)
    while size > 2 and perm[-1] == size:
        perm.pop()
        size -= 1
    return perm
101
+
102
+
103
def has_bruhat_descent(perm, i, j):
    """Return whether swapping positions ``i < j`` of ``perm`` is a Bruhat covering step down.

    That is the case when ``perm[i] > perm[j]`` and no position strictly
    between them holds a value strictly between ``perm[j]`` and ``perm[i]``.
    """
    high, low = perm[i], perm[j]
    if high < low:
        return False
    return not any(high > perm[mid] > low for mid in range(i + 1, j))
110
+
111
+
112
def count_bruhat(perm, i, j):
    """Return the signed change in inversion count from swapping positions ``i < j``.

    The base is +1 when ``perm[i] < perm[j]`` and -1 otherwise. Each value
    strictly between them (position- and value-wise) adds +2 or -2 in the same
    direction.
    """
    left, right = perm[i], perm[j]
    amount = 1 if left < right else -1
    for mid in range(i + 1, j):
        middle = perm[mid]
        if left < middle < right:
            amount += 2
        elif left > middle > right:
            amount -= 2
    return amount
124
+
125
+
126
def has_bruhat_ascent(perm, i, j):
    """Return whether swapping positions ``i < j`` of ``perm`` is a Bruhat covering step up.

    That is the case when ``perm[i] < perm[j]`` and no position strictly
    between them holds a value strictly between ``perm[i]`` and ``perm[j]``.
    """
    low, high = perm[i], perm[j]
    if low > high:
        return False
    return not any(low < perm[mid] < high for mid in range(i + 1, j))
133
+
134
+
135
def elem_sym_perms(orig_perm, p, k):
    """Enumerate chains of at most ``p`` Bruhat-ascent transpositions across position ``k``.

    Each step takes a permutation produced by the previous step, appends the
    fixed point ``len + 1`` (and pads with fixed points up to length ``k + 2``
    if needed), and swaps a position ``i < k`` with a position ``j >= k``
    whenever ``has_bruhat_ascent`` holds. The value taken from position ``j``
    must be strictly smaller than the value taken at the previous step (the
    ``last`` bound; the first step uses a huge sentinel). The value at ``i``
    must also be below that bound.

    NOTE(review): this looks like the chain enumeration of a Pieri-type rule
    for ``e_p`` in the first ``k`` variables. Confirm against the callers.

    Args:
        orig_perm: Starting permutation in one-line notation.
        p: Maximum number of steps.
        k: Split position (0-based positions ``< k`` vs ``>= k``).

    Returns:
        List of ``(perm, steps)`` pairs starting with ``(orig_perm, 0)``. Each new
        ``perm`` is a tuple with at most one trailing fixed point removed. No
        deduplication is performed.
    """
    total_list = [(orig_perm, 0)]
    # 1000000000 acts as "no bound yet" for the decreasing-value constraint.
    up_perm_list = [(orig_perm, 1000000000)]
    for pp in range(p):
        perm_list = []
        for up_perm, last in up_perm_list:
            # Append a new largest value so it can move into the first k positions.
            up_perm2 = [*up_perm, len(up_perm) + 1]
            if len(up_perm2) < k + 1:
                up_perm2 += [i + 1 for i in range(len(up_perm2), k + 2)]
            pos_list = [i for i in range(k) if up_perm2[i] < last]
            for j in range(k, len(up_perm2)):
                if up_perm2[j] >= last:
                    continue
                for i in pos_list:
                    if has_bruhat_ascent(up_perm2, i, j):
                        new_perm = [*up_perm2]
                        new_perm[i], new_perm[j] = new_perm[j], new_perm[i]
                        # Drop at most one trailing fixed point.
                        if new_perm[-1] == len(new_perm):
                            new_perm_add = tuple(new_perm[:-1])
                        else:
                            new_perm_add = tuple(new_perm)
                        # The moved value becomes the bound for the next step.
                        perm_list += [(new_perm_add, up_perm2[j])]
                        total_list += [(new_perm_add, pp + 1)]
        up_perm_list = perm_list
    return total_list
160
+
161
+
162
def elem_sym_perms_op(orig_perm, p, k):
    """Downward counterpart of ``elem_sym_perms``: chains of at most ``p`` Bruhat descents.

    Each step swaps a position ``i < k`` with a position ``j >= last`` whenever
    ``has_bruhat_descent`` holds. Position ``i`` must still hold its value from
    ``orig_perm`` (compared via ``getpermval``). ``last`` starts at ``k`` and is
    then the ``j`` used in the previous step, so ``j`` is weakly increasing along
    a chain. Permutations shorter than ``k + 1`` are padded with fixed points to
    length ``k + 2``.

    Returns:
        List of ``(perm, steps)`` pairs starting with ``(orig_perm, 0)``. New
        permutations are trimmed tuples. No deduplication is performed.
    """
    total_list = [(orig_perm, 0)]
    up_perm_list = [(orig_perm, k)]
    for pp in range(p):
        perm_list = []
        for up_perm, last in up_perm_list:
            up_perm2 = [*up_perm]
            if len(up_perm2) < k + 1:
                up_perm2 += [i + 1 for i in range(len(up_perm2), k + 2)]
            # Only positions whose value is unchanged from orig_perm may move.
            pos_list = [i for i in range(k) if getpermval(up_perm2, i) == getpermval(orig_perm, i)]
            for j in range(last, len(up_perm2)):
                for i in pos_list:
                    if has_bruhat_descent(up_perm2, i, j):
                        new_perm = [*up_perm2]
                        new_perm[i], new_perm[j] = new_perm[j], new_perm[i]
                        new_perm_add = tuple(permtrim(new_perm))
                        # j becomes the lower bound for the next step's j.
                        perm_list += [(new_perm_add, j)]
                        total_list += [(new_perm_add, pp + 1)]
        up_perm_list = perm_list
    return total_list
182
+
183
+
184
def strict_theta(u):
    """Normalize the trimmed Lehmer code of ``u`` by repeated local moves.

    Scanning right to left, the first adjacent pair ``(a, b)`` with ``b != 0``
    and ``a <= b`` is replaced by ``(b + 1, a)``. Scanning restarts until no
    such pair remains. Trailing zeros are then stripped.
    """
    entries = list(trimcode(u))
    changed = True
    while changed:
        changed = False
        for pos in reversed(range(len(entries) - 1)):
            if entries[pos + 1] != 0 and entries[pos] <= entries[pos + 1]:
                entries[pos], entries[pos + 1] = entries[pos + 1] + 1, entries[pos]
                changed = True
                break
    while entries and entries[-1] == 0:
        entries.pop()
    return entries
197
+
198
+
199
def elem_sym_perms_q(orig_perm, p, k, q_var=q_var):
    """Quantum analogue of ``elem_sym_perms``: chains of at most ``p`` quantum-Bruhat edges.

    Each step swaps a position ``i < k`` that still holds its ``orig_perm`` value
    with a position ``j >= k``. ``j`` is weakly decreasing along a chain: it runs
    from ``min(len - 1, last_j)`` down to ``k``. A swap is accepted when
    ``count_bruhat`` is ``1`` (length goes up by one) or ``2 * (i - j) + 1``
    (length drops by ``2 * (j - i) - 1``). The second case multiplies the weight
    by ``q_var[i + 1] * ... * q_var[j]``.

    Returns:
        List of ``(perm, steps, weight)`` triples starting with
        ``(orig_perm, 0, 1)``. New permutations are trimmed tuples.
    """
    total_list = [(orig_perm, 0, 1)]
    # Entries are (perm, accumulated q-weight, last j); 1000 = "no bound yet".
    up_perm_list = [(orig_perm, 1, 1000)]
    for pp in range(p):
        perm_list = []
        for up_perm, val, last_j in up_perm_list:
            up_perm2 = [*up_perm, len(up_perm) + 1]
            if len(up_perm2) < k + 1:
                up_perm2 += [i + 1 for i in range(len(up_perm2), k + 2)]
            # Positions < k whose value is unchanged from orig_perm (fixed points past its end).
            pos_list = [
                i
                for i in range(k)
                if (i >= len(orig_perm) and up_perm2[i] == i + 1)
                or (i < len(orig_perm) and up_perm2[i] == orig_perm[i])
            ]
            for j in range(min(len(up_perm2) - 1, last_j), k - 1, -1):
                for i in pos_list:
                    ct = count_bruhat(up_perm2, i, j)
                    # Accept a Bruhat cover (ct == 1) or a quantum edge with maximal length drop.
                    if ct == 1 or ct == 2 * (i - j) + 1:
                        new_perm = [*up_perm2]
                        new_perm[i], new_perm[j] = new_perm[j], new_perm[i]
                        new_perm_add = tuple(permtrim(new_perm))
                        new_val = val
                        if ct < 0:
                            new_val *= np.prod([q_var[index] for index in range(i + 1, j + 1)])
                        perm_list += [(new_perm_add, new_val, j)]
                        total_list += [(new_perm_add, pp + 1, new_val)]
        up_perm_list = perm_list
    return total_list
229
+
230
+
231
def elem_sym_perms_q_op(orig_perm, p, k, n, q_var=q_var):
    """Downward quantum chains: the "op" counterpart of ``elem_sym_perms_q``.

    Permutations are padded with fixed points to length ``n`` (this parameter
    shadows the module-level ``n``). Each step swaps a position ``i < k`` that
    still holds its ``orig_perm`` value with a position ``j`` in
    ``last_j .. n - 1``. ``last_j`` starts at ``k``, so ``j`` is weakly
    increasing along a chain. A swap is accepted when ``count_bruhat`` is ``-1``
    (length drops by one) or ``2 * (j - i) - 1`` (length rises by that amount).
    The second case multiplies the weight by ``q_var[i + 1] * ... * q_var[j]``.

    Returns:
        List of ``(perm, steps, weight)`` triples starting with
        ``(orig_perm, 0, 1)``. New permutations are trimmed tuples.
    """
    total_list = [(orig_perm, 0, 1)]
    up_perm_list = [(orig_perm, 1, k)]
    for pp in range(p):
        perm_list = []
        for up_perm, val, last_j in up_perm_list:
            up_perm2 = [*up_perm]
            if len(up_perm) < n:
                up_perm2 += [i + 1 for i in range(len(up_perm2), n)]
            # Positions < k whose value is unchanged from orig_perm.
            pos_list = [
                i
                for i in range(k)
                if (i >= len(orig_perm) and up_perm2[i] == i + 1)
                or (i < len(orig_perm) and up_perm2[i] == orig_perm[i])
            ]
            for j in range(last_j, n):
                for i in pos_list:
                    ct = count_bruhat(up_perm2, i, j)
                    # Accept a Bruhat descent cover (ct == -1) or a quantum up-edge.
                    if ct == -1 or ct == 2 * (j - i) - 1:
                        new_perm = [*up_perm2]
                        new_perm[i], new_perm[j] = new_perm[j], new_perm[i]
                        new_perm_add = tuple(permtrim(new_perm))
                        new_val = val
                        if ct > 0:
                            new_val *= np.prod([q_var[index] for index in range(i + 1, j + 1)])
                        perm_list += [(new_perm_add, new_val, j)]
                        total_list += [(new_perm_add, pp + 1, new_val)]
        up_perm_list = perm_list
    return total_list
261
+
262
+
263
def q_vector(q_exp, q_var=q_var):
    """Return the exponent vector of a monomial in the quantum parameters.

    Entry ``i - 1`` of the result is the exponent of ``q_var[i]``, so entry 0
    corresponds to ``q_var[1]``. Trailing zero entries are omitted.

    Args:
        q_exp: ``1``, a single ``q_var`` entry, a power of one, or a product of
            such factors.
        q_var: Sequence of quantum parameters (numpy array or plain sequence).

    Returns:
        The exponent list (``[]`` for ``1``), or ``None`` if ``q_exp`` has none of
        the recognized forms.

    NOTE(review): ``q_var[0]`` maps to ``[1]`` exactly like ``q_var[1]``. The
    q-weights built in this module only use indices >= 1. Confirm no caller
    passes ``q_var[0]``.
    """
    # Accept any sequence, not only numpy arrays.
    qvar_list = q_var.tolist() if hasattr(q_var, "tolist") else list(q_var)
    ret = []

    if q_exp == 1:
        return ret
    if q_exp in qvar_list:
        i = qvar_list.index(q_exp)
        return [0 for j in range(i - 1)] + [1]
    if isinstance(q_exp, Pow):
        qv = q_exp.args[0]
        expon = int(q_exp.args[1])
        i = qvar_list.index(qv)
        return [0 for j in range(i - 1)] + [expon]
    if isinstance(q_exp, Mul):
        for a in q_exp.args:
            # Recurse with the caller's q_var. The original call omitted it and
            # silently used the module-level default, mis-indexing custom arrays.
            v1 = q_vector(a, q_var)
            # Pad both vectors to a common length, then add entrywise.
            v1 += [0 for i in range(len(v1), len(ret))]
            ret += [0 for i in range(len(ret), len(v1))]
            ret = [ret[i] + v1[i] for i in range(len(ret))]
        return ret

    return None
288
+
289
+
290
def omega(i, qv):
    """Return entry ``i`` (1-based, ``i >= 1``) of ``qv`` hit by the type-A Cartan matrix.

    The value is ``2 * qv[i-1] - qv[i-2] - qv[i]``, where out-of-range entries
    count as 0. Indices beyond ``len(qv) + 1`` and an empty ``qv`` give 0.
    """
    idx = i - 1
    size = len(qv)
    if size == 0 or idx > size:
        return 0
    if idx == 0:
        return 2 * qv[0] if size == 1 else 2 * qv[0] - qv[1]
    if idx == size:
        return -qv[-1]
    if idx == size - 1:
        return 2 * qv[-1] - qv[-2]
    return 2 * qv[idx] - qv[idx - 1] - qv[idx + 1]
303
+
304
+
305
def sg(i, w):
    """Return 1 if ``w`` has a descent at 0-based position ``i``, else 0.

    Positions at or beyond ``len(w) - 1`` count as having no descent.
    """
    no_descent = i >= len(w) - 1 or w[i] < w[i + 1]
    return 0 if no_descent else 1
309
+
310
+
311
def reduce_q_coeff(u, v, w, qv):
    """Try one reduction step on a quantum structure-constant datum ``(u, v, w, qv)``.

    Scans positions ``i < len(qv)``. Let ``level = sg(i, w) + omega(i + 1, qv)``.
    At the first ``i`` where one of these patterns applies, positions ``i`` and
    ``i + 1`` are swapped in ``w`` (padded with fixed points first) and in one
    of ``u``/``v``:

    * ``v`` has a descent at ``i``, ``u`` has none, and ``level == 1``: swap in ``v``.
    * ``u`` has a descent at ``i`` and either ``v`` has none with ``level == 1``,
      or ``v`` has one with ``level == 2``: swap in ``u``.

    If ``w`` had no descent at ``i``, entry ``i`` of the q-vector is decremented.

    Returns:
        ``(u, v, w, qv, True)`` with the modified entries (changed permutations
        trimmed and converted to tuples) when a step applied. Otherwise the
        inputs unchanged followed by ``False``.
    """

    def _swapped(perm, pos):
        # Copy of perm with positions pos and pos + 1 exchanged.
        out = [*perm]
        out[pos], out[pos + 1] = out[pos + 1], out[pos]
        return out

    for i in range(len(qv)):
        desc_u, desc_v, desc_w = sg(i, u), sg(i, v), sg(i, w)
        level = desc_w + omega(i + 1, qv)
        if desc_v == 1 and desc_u == 0 and level == 1:
            new_u, new_v = u, tuple(permtrim(_swapped(v, i)))
        elif desc_u == 1 and ((desc_v == 0 and level == 1) or (desc_v == 1 and level == 2)):
            new_u, new_v = tuple(permtrim(_swapped(u, i))), v
        else:
            continue
        padded_w = [*w] + [j + 1 for j in range(len(w), i + 2)]
        new_w = tuple(permtrim(_swapped(padded_w, i)))
        new_qv = [*qv]
        if desc_w == 0:
            new_qv[i] -= 1
        return new_u, new_v, new_w, new_qv, True
    return u, v, w, qv, False
341
+
342
+
343
def reduce_q_coeff_u_only(u, v, w, qv):
    """Variant of ``reduce_q_coeff`` that only ever swaps in ``u`` (never in ``v``).

    At the first position ``i < len(qv)`` where ``u`` has a descent and either
    ``v`` has none with ``sg(i, w) + omega(i + 1, qv) == 1``, or ``v`` has one
    with that sum ``== 2``, positions ``i`` and ``i + 1`` are swapped in ``u``
    and in ``w`` (padded with fixed points first). If ``w`` had no descent
    there, entry ``i`` of the q-vector is decremented.

    Returns:
        ``(u, v, w, qv, True)`` after a step (``u``/``w`` as trimmed tuples), or
        the inputs unchanged followed by ``False``.
    """

    def _swapped(perm, pos):
        # Copy of perm with positions pos and pos + 1 exchanged.
        out = [*perm]
        out[pos], out[pos + 1] = out[pos + 1], out[pos]
        return out

    for i in range(len(qv)):
        if sg(i, u) != 1:
            continue
        desc_v, desc_w = sg(i, v), sg(i, w)
        level = desc_w + omega(i + 1, qv)
        if not ((desc_v == 0 and level == 1) or (desc_v == 1 and level == 2)):
            continue
        new_u = tuple(permtrim(_swapped(u, i)))
        padded_w = [*w] + [j + 1 for j in range(len(w), i + 2)]
        new_w = tuple(permtrim(_swapped(padded_w, i)))
        new_qv = [*qv]
        if desc_w == 0:
            new_qv[i] -= 1
        return new_u, v, new_w, new_qv, True
    return u, v, w, qv, False
364
+
365
+
366
def longest_element(indices):
    """Return the longest element of the parabolic subgroup generated by ``s_j`` for ``j`` in ``indices``.

    Starting from the identity, each simple transposition ``s_j`` (1-based
    ``j``) is applied wherever it creates a new descent. Passes repeat until
    none does. The result is trimmed of trailing fixed points.
    """
    perm = [1, 2]
    progress = True
    while progress:
        progress = False
        for idx in indices:
            pos = idx - 1
            if sg(pos, perm) != 0:
                continue
            if len(perm) < pos + 2:
                perm = perm + list(range(len(perm) + 1, pos + 3))
            perm[pos], perm[pos + 1] = perm[pos + 1], perm[pos]
            progress = True
    return permtrim(perm)
379
+
380
+
381
def count_less_than(arr, val):
    """Return the length of the longest prefix of ``arr`` whose items are all ``< val``."""
    for count, item in enumerate(arr):
        if not item < val:
            return count
    return len(arr)
388
+
389
+
390
def is_parabolic(w, parabolic_index):
    """Return whether ``w`` has no descent at any 1-based index in ``parabolic_index``."""
    return not any(sg(idx - 1, w) == 1 for idx in parabolic_index)
395
+
396
+
397
def check_blocks(qv, parabolic_index):
    """Check a q-exponent vector against each block of a parabolic index set.

    ``parabolic_index`` is split into maximal runs of consecutive integers
    ("blocks"). For every block and every contiguous sub-run of it, the sum of
    ``omega(index, qv)`` over the sub-run must be 0 or -1.

    Args:
        qv: Exponent vector of the quantum parameters (as from ``q_vector``).
        parabolic_index: 1-based simple-reflection indices, expected in
            ascending order. Out-of-order values simply start new blocks.

    Returns:
        ``True`` if every sub-run sum is 0 or -1, else ``False``. An empty
        ``parabolic_index`` yields ``True``.
    """
    # Split into maximal runs of consecutive indices. The original loop never
    # appended the final run and dropped the first index of every later run,
    # so a single contiguous index set was never checked at all.
    blocks = []
    cur_block = []
    for index in parabolic_index:
        if cur_block and index == cur_block[-1] + 1:
            cur_block.append(index)
        else:
            if cur_block:
                blocks.append(cur_block)
            cur_block = [index]
    if cur_block:
        blocks.append(cur_block)

    for block in blocks:
        for start in range(len(block)):
            # Running sum over block[start..end]: same sums as recomputing per end.
            val = 0
            for end in range(start, len(block)):
                val += omega(block[end], qv)
                if val != 0 and val != -1:
                    return False
    return True
417
+
418
+
419
# perms and inversion diff
def kdown_perms(perm, monoperm, p, k):
    """Enumerate chains of at most ``p`` Bruhat descents through position ``k - 1``.

    Each step swaps 0-based position ``a2 = k - 1`` with another position ``b``
    whose value is still the one in ``perm``, when ``has_bruhat_descent`` holds
    for the pair ordered as ``(min, max)``. A step result is recorded when
    ``inv(new_perm * monoperm) == inv(monoperm) - inv(perm) + steps``.

    Args:
        perm: Starting permutation (list or tuple).
        monoperm: Permutation whose length bounds the admissible results.
        p: Maximum number of steps.
        k: 1-based position that every transposition involves.

    Returns:
        List of ``(perm_tuple, steps, sign)``. The sign starts at 1 and is
        negated for each step that swaps with a position left of ``k - 1``.
        The start permutation is included with ``steps == 0`` only if it
        satisfies the length condition.
    """
    inv_m = inv(monoperm)
    inv_p = inv(perm)
    full_perm_list = []

    if inv(mulperm(list(perm), monoperm)) == inv_m - inv_p:
        full_perm_list += [(tuple(perm), 0, 1)]

    down_perm_list = [(perm, 1)]
    if len(perm) < k:
        return full_perm_list
    a2 = k - 1
    for pp in range(1, p + 1):
        down_perm_list2 = []
        for perm2, s in down_perm_list:
            L = len(perm2)
            if L < k:
                continue
            # Left positions (b < a2) come first in the chain below and use -s;
            # once b > a2 the else-branch switches s2 to s for the rest.
            s2 = -s
            for b in chain(range(k - 1), range(k, L)):
                if perm2[b] != perm[b]:
                    continue
                if b < a2:
                    i, j = b, a2
                else:
                    i, j, s2 = a2, b, s
                if has_bruhat_descent(perm2, i, j):
                    new_perm = [*perm2]
                    new_perm[a2], new_perm[b] = new_perm[b], new_perm[a2]
                    # Trims new_perm in place; the return value is not needed.
                    permtrim(new_perm)
                    down_perm_list2 += [(new_perm, s2)]
                    if inv(mulperm(new_perm, monoperm)) == inv_m - inv_p + pp:
                        full_perm_list += [(tuple(new_perm), pp, s2)]
        down_perm_list = down_perm_list2
    return full_perm_list
455
+
456
+
457
def compute_vpathdicts(th, vmu, smpify=False):
    """Build, level by level, the graph of downward moves produced by ``kdown_perms``.

    Args:
        th: Sequence of step counts. ``th[i]`` is the ``p`` passed to
            ``kdown_perms`` at level ``i`` (with ``k = i + 1``). Must be
            non-empty, because ``vpathdicts[-1]`` is indexed immediately.
        vmu: Starting permutation placed at the last level.
        smpify: If true, the third component of every edge is passed through
            ``sympify``.

    Returns:
        A list with one dict per level. ``result[i]`` maps each permutation
        produced at level ``i`` to a set of ``(parent_perm, steps, value)``
        triples. ``parent_perm`` is the level-``i`` input it came from, and
        ``steps``/``value`` are the second/third entries reported by
        ``kdown_perms``.
    """
    vpathdicts = [{} for index in range(len(th))]
    vpathdicts[-1][tuple(vmu)] = None
    thL = len(th)

    # ``top`` is the code of the inverse of uncode(th). Each level below strips
    # it down further to obtain that level's monoperm.
    top = code(inverse(uncode(th)))
    for i in range(thL - 1, -1, -1):
        top2 = code(inverse(uncode(top)))
        # NOTE(review): raises IndexError if top2 is all zeros — presumably
        # impossible for the th values callers pass; confirm.
        while top2[-1] == 0:
            top2.pop()
        top2.pop()
        top = code(inverse(uncode(top2)))
        monoperm = uncode(top)
        if len(monoperm) < 2:
            monoperm = [1, 2]
        k = i + 1
        # Assigning to existing keys while iterating is safe (the dict size is unchanged).
        for last_perm in vpathdicts[i]:
            newperms = kdown_perms(last_perm, monoperm, th[i], k)
            vpathdicts[i][last_perm] = newperms
            if i > 0:
                for trip in newperms:
                    vpathdicts[i - 1][trip[0]] = None
    # Invert the edges: key each level by the reached permutation instead of its source.
    vpathdicts2 = [{} for i in range(len(th))]
    for i in range(len(th)):
        for key, valueset in vpathdicts[i].items():
            for value in valueset:
                key2 = value[0]
                if key2 not in vpathdicts2[i]:
                    vpathdicts2[i][key2] = set()
                v2 = value[2]
                if smpify:
                    v2 = sympify(v2)
                vpathdicts2[i][key2].add((key, value[1], v2))
    return vpathdicts2
492
+
493
+
494
def theta(perm):
    """Turn the Lehmer code of ``perm`` into a weakly decreasing sequence.

    Entries are processed right to left. Each entry ``cd[i]`` is incremented
    once for every earlier entry ``cd[j]`` (``j < i``, scanned right to left)
    that is smaller than the running value of ``cd[i]``. The result is sorted
    in decreasing order.
    """
    entries = code(perm)
    for right in range(len(entries) - 1, 0, -1):
        for left in reversed(range(right)):
            if entries[left] < entries[right]:
                entries[right] += 1
    return sorted(entries, reverse=True)
502
+
503
+
504
def add_perm_dict(d1, d2):
    """Add the values of ``d2`` into ``d1`` key by key; missing keys start at 0.

    ``d1`` is modified in place and also returned.
    """
    d1.update({key: d1.get(key, 0) + coeff for key, coeff in d2.items()})
    return d1
508
+
509
+
510
# Shared symengine one, returned by the polynomial helpers below for empty products (p == 0).
one = sympify(1)
511
+
512
+
513
def elem_sym_poly_q(p, k, varl1, varl2, q_var=q_var):
    """Quantum elementary symmetric polynomial ``E_p`` in the first ``k`` variables.

    Computed by the recursion::

        E_p(k) = (x_k - y_{k-p+1}) * E_{p-1}(k-1) + E_p(k-1) + q_{k-1} * E_{p-2}(k-2)

    with ``x = varl1``, ``y = varl2`` (0-based lists), ``E_0 = 1`` for ``k >= 0``
    and ``E_p = 0`` when ``p < 0`` or ``p > k``.
    """
    if p == 0 and k >= 0:
        return one
    if p < 0 or p > k:
        return zero

    def sub(pp, kk):
        return elem_sym_poly_q(pp, kk, varl1, varl2, q_var)

    linear_term = (varl1[k - 1] - varl2[k - p]) * sub(p - 1, k - 1)
    return linear_term + sub(p, k - 1) + q_var[k - 1] * sub(p - 2, k - 2)
523
+
524
+
525
def elem_sym_poly(p, k, varl1, varl2, xstart=0, ystart=0):
    """Factorial-style elementary symmetric polynomial of degree ``p`` in ``k`` variables.

    Uses ``x = varl1[xstart:]`` and ``y = varl2[ystart:]``. Base cases:

    * ``p > k``: ``zero``; ``p == 0``: ``one``;
    * ``p == 1``: ``sum(x_i - y_i for i < k)``;
    * ``p == k``: ``prod(x_i - y_0 for i < k)``.

    Otherwise the variables are split in half and the pieces are combined
    recursively, shifting the ``y`` offset of the right half by the degree
    taken from the left half.
    """
    if p > k:
        return zero
    if p == 0:
        return one
    if p == 1:
        total = varl1[xstart] - varl2[ystart]
        for off in range(1, k):
            total += varl1[xstart + off] - varl2[ystart + off]
        return total
    if p == k:
        y_first = varl2[ystart]
        total = (varl1[xstart] - y_first) * (varl1[xstart + 1] - y_first)
        for off in range(2, k):
            total *= varl1[xstart + off] - y_first
        return total
    half = k // 2
    rest = k - half
    x_mid = xstart + half
    y_mid = ystart + half
    total = elem_sym_poly(p, half, varl1, varl2, xstart, ystart) + elem_sym_poly(
        p, rest, varl1, varl2, x_mid, y_mid
    )
    for p_left in range(max(1, p - rest), min(p, half + 1)):
        left = elem_sym_poly(p_left, half, varl1, varl2, xstart, ystart)
        right = elem_sym_poly(p - p_left, rest, varl1, varl2, x_mid, y_mid - p_left)
        total += left * right
    return total
553
+
554
+
555
@cache
def call_zvars(v1, v2, k, i):
    """Return the y-variable indices used by ``elem_sym_func`` (memoized).

    ``v2`` is padded with fixed points up to length ``i``. The result is
    ``v2[i-1]``, then the values of ``v2`` beyond ``len(v1)`` that are not fixed
    points, then the values at positions ``< len(v1)`` where ``v1`` and ``v2``
    differ. Position ``i - 1`` is excluded from the last two groups. ``k`` is
    unused. Arguments must be hashable (tuples). The cached list is shared
    between calls, so callers must not mutate it.
    """
    padded = list(v2) + list(range(len(v2) + 1, i + 1))
    pivot = i - 1
    tail = [padded[j] for j in range(len(v1), len(padded)) if padded[j] != j + 1 and j != pivot]
    moved = [padded[j] for j in range(len(v1)) if v1[j] != padded[j] and j != pivot]
    return [padded[pivot], *tail, *moved]
564
+
565
+
566
def elem_sym_func(k, i, u1, u2, v1, v2, udiff, vdiff, varl1, varl2):
    """Evaluate the elementary-symmetric factor for one step of a double product.

    The degree is ``newk - vdiff`` in ``newk = k - udiff`` variables. The
    x-variables are ``varl1[u2[j]]`` for positions ``j < k`` that ``u1`` and
    ``u2`` agree on (both extended by fixed points). The y-variables are
    ``varl2`` indexed by ``call_zvars(v1, v2, k, i)``.

    Returns:
        ``zero`` if ``newk < vdiff``, ``one`` if ``newk == vdiff``, otherwise
        the ``elem_sym_poly`` value.
    """
    newk = k - udiff
    if newk < vdiff:
        return zero
    if newk == vdiff:
        return one
    # Pad both permutations with fixed points to a common length, as
    # elem_sym_func_q does. Without this, a u2 shorter than u1 raised IndexError;
    # when len(u1) <= len(u2) the result is unchanged.
    mlen = max(len(u1), len(u2))
    u1 = [*u1] + [a + 1 for a in range(len(u1), mlen)]
    u2 = [*u2] + [a + 1 for a in range(len(u2), mlen)]
    yvars = [varl1[u2[j]] for j in range(min(mlen, k)) if u1[j] == u2[j]]
    # Positions beyond both permutations are shared fixed points.
    yvars += [varl1[j + 1] for j in range(mlen, k)]
    zvars = [varl2[a] for a in call_zvars(v1, v2, k, i)]
    return elem_sym_poly(newk - vdiff, newk, yvars, zvars)
584
+
585
+
586
def elem_sym_func_q(k, i, u1, u2, v1, v2, udiff, vdiff, varl1, varl2):
    """Quantum-side counterpart of ``elem_sym_func``.

    ``u1`` and ``u2`` are padded with fixed points to a common length. The
    x-variables are ``varl1[u2[j]]`` for positions ``j < k`` where they agree,
    followed by ``varl1[j + 1]`` for positions beyond both. The y-variables are
    ``varl2`` indexed by ``call_zvars``.

    Returns:
        ``zero`` if ``k - udiff < vdiff``, ``one`` if equal, otherwise the
        ``elem_sym_poly`` value of degree ``k - udiff - vdiff``.
    """
    newk = k - udiff
    if newk < vdiff:
        return zero
    if newk == vdiff:
        return one
    width = max(len(u1), len(u2))
    top = list(u1) + list(range(len(u1) + 1, width + 1))
    bottom = list(u2) + list(range(len(u2) + 1, width + 1))
    yvars = [varl1[bottom[j]] for j in range(min(width, k)) if top[j] == bottom[j]]
    yvars.extend(varl1[j + 1] for j in range(width, k))
    zvars = [varl2[a] for a in call_zvars(v1, v2, k, i)]
    return elem_sym_poly(newk - vdiff, newk, yvars, zvars)
607
+
608
+
609
def trimcode(perm):
    """Return the Lehmer code of ``perm`` with trailing zeros removed."""
    lehmer = code(perm)
    end = len(lehmer)
    while end and lehmer[end - 1] == 0:
        end -= 1
    del lehmer[end:]
    return lehmer
614
+
615
+
616
def p_trans(part):
    """Return the conjugate (transpose) of the partition ``part``.

    Column ``c`` (``1 .. part[0]``) has length ``#{x in part : x >= c}``; the
    number of columns is taken from ``part[0]``. An empty partition or one
    starting with 0 yields ``[0]``.
    """
    if len(part) == 0 or part[0] == 0:
        return [0]
    return [sum(1 for row in part if row >= col) for col in range(1, part[0] + 1)]
629
+
630
+
631
def cycle(p, q):
    """Return the one-line permutation fixing ``1 .. p-1`` and rotating ``p .. p+q`` left by one."""
    return list(range(1, p)) + list(range(p + 1, p + q + 1)) + [p]
633
+
634
+
635
def phi1(u):
    """Drop the first entry of ``code(inverse(u))`` and rebuild the permutation.

    Presumably this removes the value 1 from ``u`` and standardizes the rest.
    Confirm before relying on that reading.
    """
    inverse_code = code(inverse(u))
    del inverse_code[0]
    return inverse(uncode(inverse_code))
640
+
641
+
642
def one_dominates(u, w):
    """Return whether ``u`` has no descent between the positions of 1 in ``u`` and ``w``.

    ``a`` and ``b`` are the first entries of ``code(inverse(u))`` and
    ``code(inverse(w))``, i.e. the 0-based positions of the value 1. The result
    is ``False`` iff ``u`` has a descent at some position in ``a .. b - 1`` that
    lies before its last position.
    """
    inv_code_u = code(inverse(u))
    inv_code_w = code(inverse(w))
    for pos in range(inv_code_u[0], inv_code_w[0]):
        if pos >= len(u) - 1:
            return True
        if u[pos] > u[pos + 1]:
            return False
    return True
655
+
656
+
657
def dominates(u, w):
    """Return whether repeated ``phi1`` reduction under ``one_dominates`` takes ``u`` to the identity.

    While ``u`` is not ``[1, 2]`` and ``one_dominates(u, w)`` holds, both
    permutations are replaced by their ``phi1`` images.
    """
    cur_u, cur_w = [*u], [*w]
    while cur_u != [1, 2] and one_dominates(cur_u, cur_w):
        cur_u, cur_w = phi1(cur_u), phi1(cur_w)
    return cur_u == [1, 2]
666
+
667
+
668
def reduce_coeff(u, v, w):
    """Try to replace the triple ``(u, v, w)`` by a smaller one via partition shifts.

    Partitions ``theta(inverse(u))`` and ``theta(inverse(v))`` are combined by
    adding their conjugates. The triple is multiplied by the corresponding
    permutations and reduced again through ``theta(inverse(w * mu_uv_inv))``.
    NOTE(review): presumably this implements a known reduction of the
    structure constant c^w_{u,v}; confirm against the author's reference.

    Returns:
        ``(u, v, w)`` unchanged (original objects) when
        ``inv(w * mu_uv_inv) != inv(mu_uv_inv) - inv(w)``. Otherwise a triple
        of trimmed lists.
    """
    t_mu_u_t = theta(inverse(u))
    t_mu_v_t = theta(inverse(v))

    mu_u_inv = uncode(t_mu_u_t)
    mu_v_inv = uncode(t_mu_v_t)

    # Conjugate partitions of the two thetas.
    t_mu_u = p_trans(t_mu_u_t)
    t_mu_v = p_trans(t_mu_v_t)

    # Zero-pad both conjugates to a common length so they can be added entrywise.
    t_mu_u += [0 for i in range(len(t_mu_u), max(len(t_mu_u), len(t_mu_v)))]
    t_mu_v += [0 for i in range(len(t_mu_v), max(len(t_mu_u), len(t_mu_v)))]

    t_mu_uv = [t_mu_u[i] + t_mu_v[i] for i in range(len(t_mu_u))]
    t_mu_uv_t = p_trans(t_mu_uv)

    mu_uv_inv = uncode(t_mu_uv_t)

    # Only proceed when lengths subtract exactly under the combined shift.
    if inv(mulperm(list(w), mu_uv_inv)) != inv(mu_uv_inv) - inv(w):
        return u, v, w

    umu = mulperm(list(u), mu_u_inv)
    vmu = mulperm(list(v), mu_v_inv)
    wmu = mulperm(list(w), mu_uv_inv)

    t_mu_w = theta(inverse(wmu))

    mu_w = uncode(t_mu_w)

    w_prime = mulperm(wmu, mu_w)

    # No progress: return the inputs trimmed (note permtrim(w_prime) trims it in place).
    if permtrim(list(w)) == permtrim(w_prime):
        return (permtrim(list(u)), permtrim(list(v)), permtrim(list(w)))

    A = []
    B = []
    indexA = 0

    while len(t_mu_u_t) > 0 and t_mu_u_t[-1] == 0:
        t_mu_u_t.pop()

    while len(t_mu_v_t) > 0 and t_mu_v_t[-1] == 0:
        t_mu_v_t.pop()

    while len(t_mu_uv_t) > 0 and t_mu_uv_t[-1] == 0:
        t_mu_uv_t.pop()

    # Split the row indices of the combined partition: rows matched in order by
    # u's partition go to A, the rest to B.
    for index in range(len(t_mu_uv_t)):
        if indexA < len(t_mu_u_t) and t_mu_uv_t[index] == t_mu_u_t[indexA]:
            A += [index]
            indexA += 1
        else:
            B += [index]

    mu_w_A = uncode(mu_A(code(mu_w), A))
    mu_w_B = uncode(mu_A(code(mu_w), B))

    return (
        permtrim(mulperm(umu, mu_w_A)),
        permtrim(mulperm(vmu, mu_w_B)),
        permtrim(w_prime),
    )
730
+
731
+
732
def mu_A(mu, A):
    """Restrict ``mu`` to the columns listed in ``A``.

    Takes the conjugate of ``mu``, keeps the entries at the (0-based) indices in
    ``A`` that are in range, in ``A``'s order, and returns the conjugate of that
    selection.
    """
    columns = p_trans(mu)
    picked = [columns[idx] for idx in A if idx < len(columns)]
    return p_trans(picked)
739
+
740
+
741
def reduce_descents(u, v, w):
    """Remove shared descents from ``(u, v, w)`` until a stopping test succeeds.

    Each pass first stops if ``will_formula_work(u, v)``,
    ``will_formula_work(v, u)``, ``one_dominates(u, w)`` or ``is_reducible(v)``
    holds, or if ``inv(w) - inv(u) == 1``. Otherwise, scanning right to left,
    it finds the first position ``i`` where ``w`` has a descent together with
    exactly one of ``v``/``u``, and swaps ``i, i + 1`` in ``w`` and in that
    permutation. Passes repeat until no such position exists.

    Returns:
        ``(u, v, w)`` as trimmed lists.
    """
    u2 = [*u]
    v2 = [*v]
    w2 = [*w]
    found_one = True
    while found_one:
        found_one = False
        if (
            will_formula_work(u2, v2)
            or will_formula_work(v2, u2)
            or one_dominates(u2, w2)
            or is_reducible(v2)
            or inv(w2) - inv(u2) == 1
        ):
            break
        for i in range(len(w2) - 2, -1, -1):
            # Descent shared by w and v only.
            if (
                w2[i] > w2[i + 1]
                and i < len(v2) - 1
                and v2[i] > v2[i + 1]
                and (i >= len(u2) - 1 or u2[i] < u2[i + 1])
            ):
                w2[i], w2[i + 1] = w2[i + 1], w2[i]
                v2[i], v2[i + 1] = v2[i + 1], v2[i]
                found_one = True
            # Descent shared by w and u only.
            elif (
                w2[i] > w2[i + 1]
                and i < len(u2) - 1
                and u2[i] > u2[i + 1]
                and (i >= len(v2) - 1 or v2[i] < v2[i + 1])
            ):
                w2[i], w2[i + 1] = w2[i + 1], w2[i]
                u2[i], u2[i + 1] = u2[i + 1], u2[i]
                found_one = True
            # One swap per pass; re-run the stopping tests before the next one.
            if found_one:
                break
    return permtrim(u2), permtrim(v2), permtrim(w2)
778
+
779
+
780
def is_reducible(v):
    """Return whether the nonzero entries of ``code(v)`` form a prefix.

    Equivalently, no nonzero code entry appears after a zero entry.
    """
    seen_zero = False
    for entry in code(v):
        if entry == 0:
            seen_zero = True
        elif seen_zero:
            return False
    return True
791
+
792
+
793
def try_reduce_v(u, v, w):
    """Make ``v`` satisfy ``is_reducible`` by simultaneous adjacent swaps on the triple.

    If ``v`` is already reducible, the trimmed triple is returned. Otherwise,
    scanning right to left, at the first position ``i`` where ``code(v)`` has a
    0 followed by a nonzero entry:

    * if ``u`` has no descent at ``i``: swap ``i, i + 1`` in ``v`` and in ``w``
      (padding ``w`` with fixed points if needed), then stop if ``v`` became
      reducible, else recurse;
    * elif ``w`` has a descent at ``i``: swap in ``u`` and ``v`` and recurse;
    * else stop.

    Returns:
        ``(u, v, w)`` as trimmed tuples.
    """
    if is_reducible(v):
        return tuple(permtrim([*u])), tuple(permtrim([*v])), tuple(permtrim([*w]))
    u2 = [*u]
    v2 = [*v]
    w2 = [*w]
    cv = code(v2)
    for i in range(len(v2) - 2, -1, -1):
        if cv[i] == 0 and i < len(cv) - 1 and cv[i + 1] != 0:
            if i >= len(u2) - 1 or u2[i] < u2[i + 1]:
                v2[i], v2[i + 1] = v2[i + 1], v2[i]
                # Extend w with fixed points so position i + 1 exists.
                if i >= len(w2) - 1:
                    w2 += [j for j in range(len(w2) + 1, i + 3)]
                w2[i + 1], w2[i] = w2[i], w2[i + 1]
                if is_reducible(v2):
                    return tuple(permtrim(u2)), tuple(permtrim(v2)), tuple(permtrim(w2))
                else:
                    return try_reduce_v(u2, v2, w2)
            elif i < len(w2) - 1 and w2[i] > w2[i + 1]:
                u2[i], u2[i + 1] = u2[i + 1], u2[i]
                v2[i], v2[i + 1] = v2[i + 1], v2[i]
                return try_reduce_v(u2, v2, w2)
            else:
                return tuple(permtrim(u2)), tuple(permtrim(v2)), tuple(permtrim(w2))
    return tuple(permtrim(u2)), tuple(permtrim(v2)), tuple(permtrim(w2))
818
+
819
+
820
def try_reduce_u(u, v, w):
    """Make ``one_dominates(u, w)`` hold by simultaneous adjacent swaps on the triple.

    This mirrors ``try_reduce_v`` with the roles of ``u`` and ``v`` exchanged.
    Scanning right to left, at the first position ``i`` where ``code(u)`` has a
    0 followed by a nonzero entry:

    * if ``v`` has no descent at ``i``: swap ``i, i + 1`` in ``u`` and in ``w``
      (padding ``w`` with fixed points if needed), then stop if
      ``one_dominates`` now holds, else recurse;
    * elif ``w`` has a descent at ``i``: swap in ``u`` and ``v`` and recurse;
    * else stop.

    Returns:
        ``(u, v, w)`` as trimmed tuples, on every path (consistent with
        ``try_reduce_v``).
    """
    if one_dominates(u, w):
        return tuple(permtrim([*u])), tuple(permtrim([*v])), tuple(permtrim([*w]))
    u2 = [*u]
    v2 = [*v]
    w2 = [*w]
    cu = code(u)
    for i in range(len(u2) - 2, -1, -1):
        if cu[i] == 0 and i < len(cu) - 1 and cu[i + 1] != 0:
            if i >= len(v2) - 1 or v2[i] < v2[i + 1]:
                u2[i], u2[i + 1] = u2[i + 1], u2[i]
                # '>=' (was '>'): when i == len(w2) - 1, w2[i + 1] does not
                # exist yet, and the old guard raised IndexError.
                if i >= len(w2) - 1:
                    w2 += [j for j in range(len(w2) + 1, i + 3)]
                w2[i + 1], w2[i] = w2[i], w2[i + 1]
                # Test the updated pair (was the stale one_dominates(u, w),
                # which is always False here).
                if one_dominates(u2, w2):
                    return tuple(permtrim(u2)), tuple(permtrim(v2)), tuple(permtrim(w2))
                return try_reduce_u(u2, v2, w2)
            if i < len(w2) - 1 and w2[i] > w2[i + 1]:
                u2[i], u2[i + 1] = u2[i + 1], u2[i]
                v2[i], v2[i + 1] = v2[i + 1], v2[i]
                return try_reduce_u(u2, v2, w2)
            return tuple(permtrim(u2)), tuple(permtrim(v2)), tuple(permtrim(w2))
    return tuple(permtrim(u2)), tuple(permtrim(v2)), tuple(permtrim(w2))
845
+
846
+
847
def divdiffable(v, u):
    """Return ``v * u^{-1}`` (trimmed) when lengths subtract exactly, else ``[]``.

    The quotient is accepted only if ``inv(v * u^{-1}) == inv(v) - inv(u)``.
    """
    len_v, len_u = inv(v), inv(u)
    quotient = permtrim(mulperm(v, inverse(u)))
    return quotient if inv(quotient) == len_v - len_u else []
854
+
855
+
856
def will_formula_work(u, v):
    """Check that ``v^{-1} * mu_v`` can be sorted without touching a descent of ``u``.

    ``mu_v = uncode(theta(v))``. The product is bubble-sorted one adjacent swap
    at a time, always at the leftmost descent. The check fails as soon as that
    position is also a descent of ``u``.
    """
    muv = uncode(theta(v))
    word = mulperm(inverse(v), muv)
    while True:
        for pos in range(len(word) - 1):
            if word[pos] > word[pos + 1]:
                if pos < len(u) - 1 and u[pos] > u[pos + 1]:
                    return False
                word[pos], word[pos + 1] = word[pos + 1], word[pos]
                break
        else:
            # No descent left: fully sorted.
            return True
871
+
872
+
873
def pull_out_var(vnum, v):
    """Enumerate the ways to pull variable ``vnum`` out of ``v`` via Bruhat-ascent chains.

    ``v`` (a list; it is concatenated with a list) is extended by the value
    ``len(v) + 1``. Chains of Bruhat ascents are then built, each swapping a
    position ``i < vnum`` with a position ``j >= vnum``. The value taken from
    position ``j`` must strictly increase along a chain (the ``b`` bound).
    Whenever the value ``len(v) + 1`` sits at position ``vnum - 1``, that entry
    is deleted, the rest is trimmed into ``vp``, and a result is recorded.

    Returns:
        A list of ``[values, vp]`` pairs, where ``values`` lists ``v[i]`` for
        ``i >= vnum`` meeting the comparison below. If ``vnum >= len(v)``,
        returns ``[[[], v]]``.

    NOTE(review): the filter compares ``v[i]`` (0-based position ``i``) with
    ``i`` and with ``vp[i - 1]``. That index offset looks unusual; confirm it is
    intended.
    """
    vup = v + [len(v) + 1]
    if vnum >= len(v):
        return [[[], v]]
    vpm_list = [(vup, 0)]
    ret_list = []
    for p in range(len(v) + 1 - vnum):
        vpm_list2 = []
        for vpm, b in vpm_list:
            # The new largest value reached position vnum - 1: record a result.
            if vpm[vnum - 1] == len(v) + 1:
                vpm2 = [*vpm]
                vpm2.pop(vnum - 1)
                vp = permtrim(vpm2)
                ret_list += [
                    [
                        [
                            v[i]
                            for i in range(vnum, len(v))
                            if ((i > len(vp) and v[i] == i) or (i <= len(vp) and v[i] == vp[i - 1]))
                        ],
                        vp,
                    ]
                ]
            for j in range(vnum, len(vup)):
                if vpm[j] <= b:
                    continue
                for i in range(vnum):
                    if has_bruhat_ascent(vpm, i, j):
                        # Swap, snapshot, then swap back so vpm is reused unchanged.
                        vpm[i], vpm[j] = vpm[j], vpm[i]
                        vpm_list2 += [([*vpm], vpm[i])]
                        vpm[i], vpm[j] = vpm[j], vpm[i]
        vpm_list = vpm_list2
    # Check the final generation, which the loop above produced but did not inspect.
    for vpm, b in vpm_list:
        if vpm[vnum - 1] == len(v) + 1:
            vpm2 = [*vpm]
            vpm2.pop(vnum - 1)
            vp = permtrim(vpm2)
            ret_list += [
                [
                    [
                        v[i]
                        for i in range(vnum, len(v))
                        if ((i > len(vp) and v[i] == i) or (i <= len(vp) and v[i] == vp[i - 1]))
                    ],
                    vp,
                ]
            ]
    return ret_list
921
+
922
+
923
def get_cycles(perm):
    """Return the non-trivial cycles of ``perm`` as tuples.

    Each cycle is listed in the order ``x -> perm[x - 1]`` and rotated so that
    its largest element comes last. Cycles appear in order of their smallest
    starting position.
    """
    found = []
    visited = set()
    for start in range(1, len(perm) + 1):
        if perm[start - 1] == start or start in visited:
            continue
        orbit = []
        top, top_pos = -1, -1
        cur = start
        while cur not in visited:
            visited.add(cur)
            if cur > top:
                top, top_pos = cur, len(orbit)
            orbit.append(cur)
            cur = perm[cur - 1]
        found.append(tuple(orbit[top_pos + 1 :] + orbit[: top_pos + 1]))
    return found
945
+
946
+
947
def double_elem_sym_q(u, p1, p2, k, q_var=q_var):
    """Pair up two successive quantum Pieri chains starting from ``u``.

    For every ``(perm1, udiff1, mul_val1)`` from ``elem_sym_perms_q(u, p1, k)``
    and every ``(perm2, udiff2, mul_val2)`` from
    ``elem_sym_perms_q(perm1, p2, k)``, the pair is kept unless some cycle of
    ``perm1^{-1} * perm2`` has a non-final element lying in a cycle of
    ``u^{-1} * perm1`` with the same maximum. Cycles come from ``get_cycles``,
    which puts the maximum last.

    Returns:
        Dict mapping each first-chain triple to the list of compatible
        second-chain triples. First-chain triples with no compatible partner
        are absent.
    """
    ret_list = {}
    first_chain = elem_sym_perms_q(u, p1, k, q_var)
    iu = inverse(u)
    for perm1, udiff1, mul_val1 in first_chain:
        second_chain = elem_sym_perms_q(perm1, p2, k, q_var)
        # Group the element sets of u^{-1} * perm1's cycles by their maximum.
        cycles_by_max = {}
        for cyc in get_cycles(tuple(permtrim(mulperm(iu, [*perm1])))):
            cycles_by_max.setdefault(cyc[-1], []).append(set(cyc))
        ip1 = inverse(perm1)
        for perm2, udiff2, mul_val2 in second_chain:
            cycles2 = get_cycles(tuple(permtrim(mulperm(ip1, [*perm2]))))
            clash = any(
                elem in earlier
                for c2 in cycles2
                for earlier in cycles_by_max.get(c2[-1], ())
                for elem in c2[:-1]
            )
            if not clash:
                ret_list.setdefault((perm1, udiff1, mul_val1), []).append((perm2, udiff2, mul_val2))
    return ret_list
982
+
983
+
984
def medium_theta(perm):
    """Rewrite the Lehmer code of ``perm`` by repeated local moves until none applies.

    Each pass scans left to right and applies the first applicable move:

    * if ``cd[i] < cd[i + 1]``: replace the pair by ``(cd[i + 1] + 1, cd[i])``;
    * if ``cd[i] == cd[i + 1] != 0``, ``i > 0`` and ``cd[i - 1] <= cd[i] + 1``:
      increment ``cd[i]``.

    Returns:
        The rewritten code. Trailing zeros are not stripped.

    NOTE(review): termination is not evident from the code alone. Confirm the
    moves cannot cycle for the inputs this is used with.
    """
    cd = code(perm)
    found_one = True
    while found_one:
        found_one = False
        for i in range(len(cd) - 1):
            if cd[i] < cd[i + 1]:
                found_one = True
                cd[i], cd[i + 1] = cd[i + 1] + 1, cd[i]
                break
            if cd[i] == cd[i + 1] and cd[i] != 0 and i > 0 and cd[i - 1] <= cd[i] + 1:
                cd[i] += 1
                found_one = True
                break
    return cd