schubmult 2.0.4__py3-none-any.whl → 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. schubmult/__init__.py +96 -1
  2. schubmult/perm_lib.py +254 -819
  3. schubmult/poly_lib/__init__.py +31 -0
  4. schubmult/poly_lib/poly_lib.py +276 -0
  5. schubmult/poly_lib/schub_poly.py +148 -0
  6. schubmult/poly_lib/variables.py +204 -0
  7. schubmult/rings/__init__.py +18 -0
  8. schubmult/rings/_quantum_schubert_polynomial_ring.py +752 -0
  9. schubmult/rings/_schubert_polynomial_ring.py +1031 -0
  10. schubmult/rings/_tensor_schub_ring.py +128 -0
  11. schubmult/rings/_utils.py +55 -0
  12. schubmult/{sage_integration → sage}/__init__.py +4 -1
  13. schubmult/{sage_integration → sage}/_fast_double_schubert_polynomial_ring.py +67 -109
  14. schubmult/{sage_integration → sage}/_fast_schubert_polynomial_ring.py +33 -28
  15. schubmult/{sage_integration → sage}/_indexing.py +9 -5
  16. schubmult/schub_lib/__init__.py +51 -0
  17. schubmult/{schubmult_double/_funcs.py → schub_lib/double.py} +532 -596
  18. schubmult/{schubmult_q/_funcs.py → schub_lib/quantum.py} +54 -53
  19. schubmult/schub_lib/quantum_double.py +954 -0
  20. schubmult/schub_lib/schub_lib.py +659 -0
  21. schubmult/{schubmult_py/_funcs.py → schub_lib/single.py} +45 -35
  22. schubmult/schub_lib/tests/__init__.py +0 -0
  23. schubmult/schub_lib/tests/legacy_perm_lib.py +946 -0
  24. schubmult/schub_lib/tests/test_vs_old.py +109 -0
  25. schubmult/scripts/__init__.py +0 -0
  26. schubmult/scripts/schubmult_double.py +378 -0
  27. schubmult/scripts/schubmult_py.py +84 -0
  28. schubmult/scripts/schubmult_q.py +109 -0
  29. schubmult/scripts/schubmult_q_double.py +207 -0
  30. schubmult/utils/__init__.py +0 -0
  31. schubmult/{_base_argparse.py → utils/argparse.py} +29 -5
  32. schubmult/utils/logging.py +16 -0
  33. schubmult/utils/parsing.py +20 -0
  34. schubmult/utils/perm_utils.py +135 -0
  35. schubmult/utils/test_utils.py +65 -0
  36. schubmult-3.0.1.dist-info/METADATA +1234 -0
  37. schubmult-3.0.1.dist-info/RECORD +41 -0
  38. {schubmult-2.0.4.dist-info → schubmult-3.0.1.dist-info}/WHEEL +1 -1
  39. schubmult-3.0.1.dist-info/entry_points.txt +5 -0
  40. schubmult/_tests.py +0 -24
  41. schubmult/schubmult_double/__init__.py +0 -12
  42. schubmult/schubmult_double/__main__.py +0 -6
  43. schubmult/schubmult_double/_script.py +0 -474
  44. schubmult/schubmult_py/__init__.py +0 -12
  45. schubmult/schubmult_py/__main__.py +0 -6
  46. schubmult/schubmult_py/_script.py +0 -97
  47. schubmult/schubmult_q/__init__.py +0 -8
  48. schubmult/schubmult_q/__main__.py +0 -6
  49. schubmult/schubmult_q/_script.py +0 -166
  50. schubmult/schubmult_q_double/__init__.py +0 -10
  51. schubmult/schubmult_q_double/__main__.py +0 -6
  52. schubmult/schubmult_q_double/_funcs.py +0 -540
  53. schubmult/schubmult_q_double/_script.py +0 -396
  54. schubmult-2.0.4.dist-info/METADATA +0 -542
  55. schubmult-2.0.4.dist-info/RECORD +0 -30
  56. schubmult-2.0.4.dist-info/entry_points.txt +0 -5
  57. {schubmult-2.0.4.dist-info → schubmult-3.0.1.dist-info}/licenses/LICENSE +0 -0
  58. {schubmult-2.0.4.dist-info → schubmult-3.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,954 @@
1
+ # from ._vars import (
2
+ # var_y,
3
+ # var_x,
4
+ # var2,
5
+ # var3,
6
+ # q_var2,
7
+ # )
8
+ from functools import cache, cached_property
9
+
10
+ import numpy as np
11
+ from symengine import Add, Mul, Pow, expand, sympify
12
+
13
+ import schubmult.schub_lib.double as norm_yz
14
+ from schubmult.perm_lib import Permutation, code, inv, longest_element, medium_theta, permtrim, strict_theta, uncode
15
+ from schubmult.poly_lib.poly_lib import call_zvars, elem_sym_func_q, elem_sym_poly_q, q_vector
16
+ from schubmult.poly_lib.variables import CustomGeneratingSet, GeneratingSet, GeneratingSet_base
17
+ from schubmult.schub_lib.schub_lib import check_blocks, compute_vpathdicts, double_elem_sym_q, elem_sym_perms_q, elem_sym_perms_q_op, reduce_q_coeff
18
+ from schubmult.utils.logging import get_logger
19
+ from schubmult.utils.perm_utils import (
20
+ add_perm_dict,
21
+ count_less_than,
22
+ is_parabolic,
23
+ omega,
24
+ )
25
+
26
# Module logger from schubmult.utils.logging; q_partial_posify_generic reports swallowed errors through it.
logger = get_logger(__name__)
27
+
28
+
29
def _lazy_generating_set(label):
    """Return a ``cached_property`` that builds ``GeneratingSet(label)`` on first access."""
    return cached_property(lambda self: GeneratingSet(label))


class _gvars:
    """Holder for this module's default variable sets.

    Every attribute is a ``cached_property``. The object is constructed on
    first access and then reused for the lifetime of the instance.
    """

    @cached_property
    def n(self):
        # Constant 100. Not read anywhere in this module; presumably a default
        # size bound used elsewhere -- TODO confirm.
        return 100

    var1 = _lazy_generating_set("x")
    var2 = _lazy_generating_set("y")
    var3 = _lazy_generating_set("z")
    q_var = _lazy_generating_set("q")
    var_r = _lazy_generating_set("r")
    # The "generic" sets reuse the y/z labels but are distinct objects from
    # var2/var3.
    var_g1 = _lazy_generating_set("y")
    var_g2 = _lazy_generating_set("z")


# Shared instance whose attributes serve as default arguments below.
_vars = _gvars()
68
+
69
+
70
+ # def E(p, k, varl=_vars.var2[1:], var_x=_vars.var1):
71
+ # return elem_sym_poly_q(p, k, var_x[1:], varl)
72
+
73
+
74
def single_variable(coeff_dict, varnum, var2=_vars.var2, q_var=_vars.q_var):
    """Multiply a permutation-indexed coefficient dict by one x-variable.

    ``mult_poly_q_double`` dispatches here when its polynomial is
    ``var_x[varnum]``. For each permutation ``u`` with coefficient ``c``:

    * ``var2[u[varnum - 1]] * c`` is added at ``u``. ``var2[varnum]`` is used
      instead when ``varnum`` exceeds ``len(u)``.
    * ``c * mul_val`` is added for every entry with ``udiff == 1`` of
      ``elem_sym_perms_q(u, 1, varnum, q_var)``.
    * ``c * mul_val`` is subtracted for every such entry of
      ``elem_sym_perms_q(u, 1, varnum - 1, q_var)``, only when ``varnum > 1``.

    Returns a new dict; ``coeff_dict`` is not modified.
    """
    result = {}
    for perm, coeff in coeff_dict.items():
        diag_index = perm[varnum - 1] if varnum - 1 < len(perm) else varnum
        result[perm] = result.get(perm, 0) + var2[diag_index] * coeff

        terms_k = elem_sym_perms_q(perm, 1, varnum, q_var)
        terms_km1 = elem_sym_perms_q(perm, 1, varnum - 1, q_var) if varnum > 1 else []

        for target, udiff, mul_val in terms_k:
            if udiff == 1:
                result[target] = result.get(target, 0) + coeff * mul_val
        for target, udiff, mul_val in terms_km1:
            if udiff == 1:
                result[target] = result.get(target, 0) - coeff * mul_val
    return result
92
+
93
+
94
def mult_poly_q_double(coeff_dict, poly, var_x=_vars.var1, var_y=_vars.var2, q_var=_vars.q_var):
    """Multiply a permutation-indexed coefficient dict by the polynomial ``poly``.

    ``poly`` is decomposed structurally:

    * a single variable ``var_x[i]`` is handled by ``single_variable``;
    * a product is applied factor by factor;
    * ``base**e`` multiplies by ``base`` ``int(e)`` times, so a non-positive
      exponent returns ``coeff_dict`` unchanged;
    * a sum multiplies by each term and adds the results with ``add_perm_dict``;
    * anything else is treated as a scalar that scales every coefficient.

    ``var_x`` values that are not ``GeneratingSet_base`` are wrapped in
    ``CustomGeneratingSet``.
    """
    if not isinstance(var_x, GeneratingSet_base):
        var_x = CustomGeneratingSet(var_x)

    position = var_x.index(poly)
    if position != -1:
        return single_variable(coeff_dict, position, var_y, q_var)

    if isinstance(poly, Mul):
        acc = coeff_dict
        for factor in poly.args:
            acc = mult_poly_q_double(acc, factor, var_x, var_y, q_var)
        return acc

    if isinstance(poly, Pow):
        base = poly.args[0]
        exponent = int(poly.args[1])
        acc = coeff_dict
        for _ in range(exponent):
            acc = mult_poly_q_double(acc, base, var_x, var_y, q_var)
        return acc

    if isinstance(poly, Add):
        acc = {}
        for term in poly.args:
            acc = add_perm_dict(acc, mult_poly_q_double(coeff_dict, term, var_x, var_y, q_var))
        return acc

    # No x-variable structure left: treat poly as a scalar.
    return {perm: poly * coeff for perm, coeff in coeff_dict.items()}
122
+
123
+
124
def nil_hecke(perm_dict, v, n, var2=_vars.var2, var3=_vars.var3):
    """Apply the path-sum operator determined by ``v`` to every entry of ``perm_dict``.

    Same skeleton as ``schubmult_q_double``: the theta is ``strict_theta(~v)``
    and the v-paths come from ``compute_vpathdicts(th, vmu, True)``. The
    differences are that u-side steps come from
    ``elem_sym_perms_q_op(up, mx_th, th[index], n)``, and that
    ``elem_sym_func_q`` receives the new/old permutations in the order
    ``(up2, up)``. Presumably this is a nil-Hecke-type action, going by the
    name -- TODO confirm.

    Args:
        perm_dict: mapping permutation -> coefficient.
        v: permutation driving the operator. The identity returns
            ``perm_dict`` itself (same object).
        n: bound forwarded to ``elem_sym_perms_q_op``.
        var2, var3: variable sets forwarded to ``elem_sym_func_q``.

    Returns:
        New dict permutation -> coefficient. It holds, for every reached
        permutation, the weight of the paths whose v side ends at ``v * mu``.
    """
    if v == Permutation([1, 2]):
        return perm_dict
    th = strict_theta(~v)
    mu = uncode(th)
    vmu = v * mu

    ret_dict = {}
    # Drop trailing zero rows. Assumes th has a nonzero entry; an all-zero th
    # would raise IndexError here.
    while th[-1] == 0:
        th.pop()
    thL = len(th)
    vpathdicts = compute_vpathdicts(th, vmu, True)
    for u, val in perm_dict.items():
        # vpathsums[up][vp]: accumulated weight of partial paths ending at
        # ``up`` (u side) and ``vp`` (v side). Paths start at (u, identity).
        vpathsums = {u: {Permutation([1, 2]): val}}
        for index in range(thL):
            # Largest step that any v-path edge out of this row permits.
            mx_th = 0
            for vp in vpathdicts[index]:
                for v2, vdiff, s in vpathdicts[index][vp]:
                    mx_th = max(mx_th, th[index] - vdiff)
            newpathsums = {}
            for up in vpathsums:
                newperms = elem_sym_perms_q_op(up, mx_th, th[index], n)
                for up2, udiff, mul_val in newperms:
                    if up2 not in newpathsums:
                        newpathsums[up2] = {}
                    # NOTE: this rebinds the parameter name ``v``. The original
                    # ``v`` is not used again after this point.
                    for v in vpathdicts[index]:
                        sumval = vpathsums[up].get(v, 0) * mul_val
                        if sumval == 0:
                            continue
                        for v2, vdiff, s in vpathdicts[index][v]:
                            # Argument order (up2, up) is reversed relative to
                            # schubmult_q_double.
                            newpathsums[up2][v2] = newpathsums[up2].get(
                                v2,
                                0,
                            ) + s * sumval * elem_sym_func_q(
                                th[index],
                                index + 1,
                                up2,
                                up,
                                v,
                                v2,
                                udiff,
                                vdiff,
                                var2,
                                var3,
                            )
            vpathsums = newpathsums
        # Keep only the weight of paths whose v side reached vmu.
        toget = vmu
        ret_dict = add_perm_dict({ep: vpathsums[ep].get(toget, 0) for ep in vpathsums}, ret_dict)
    return ret_dict
173
+
174
+
175
@cache
def schubmult_q_double_pair(perm1, perm2, var2=None, var3=None, q_var=None):
    """Memoized ``schubmult_q_double_fast({perm1: 1}, perm2, ...)``.

    Arguments left as ``None`` fall back to the module defaults
    (``_vars.var2``, ``_vars.var3``, ``_vars.q_var``). This matches the
    defaults of ``schubmult_q_double_fast``, instead of passing ``None``
    through as a variable set.

    Results are cached and shared between callers, so the returned dict must
    not be mutated.
    """
    if var2 is None:
        var2 = _vars.var2
    if var3 is None:
        var3 = _vars.var3
    if q_var is None:
        q_var = _vars.q_var
    return schubmult_q_double_fast({perm1: 1}, perm2, var2, var3, q_var)
178
+
179
+
180
@cache
def schubmult_q_double_pair_generic(perm1, perm2):
    """Memoized product of ``{perm1: 1}`` with ``perm2`` in the generic variables.

    Uses ``_vars.var_g1`` / ``_vars.var_g2`` as the two variable sets and
    ``_vars.q_var`` for q. The cached result is shared, so do not mutate it.
    """
    start = {perm1: 1}
    return schubmult_q_double_fast(
        start,
        perm2,
        var2=_vars.var_g1,
        var3=_vars.var_g2,
        q_var=_vars.q_var,
    )
183
+
184
+
185
@cache
def schubmult_q_generic_partial_posify(u2, v2):
    """Memoized generic-variable product of ``u2`` and ``v2``, partially posified.

    Every coefficient of ``schubmult_q_double_pair_generic(u2, v2)`` is passed
    through ``q_partial_posify_generic``.
    """
    products = schubmult_q_double_pair_generic(u2, v2)
    result = {}
    for w2, coeff in products.items():
        result[w2] = q_partial_posify_generic(coeff, u2, v2, w2)
    return result
189
+
190
+
191
def q_posify(u, v, w, val, var2, var3, q_var, msg):
    """Rewrite the coefficient ``val`` of ``w`` into a (hopefully) positive form.

    ``val`` is split by ``factor_out_q_keep_factored`` into
    ``{q_monomial: coefficient}`` pieces, and each piece is rewritten
    independently:

    * integer coefficients are kept;
    * if ``code(~v) == medium_theta(~v)``, the coefficient is kept as is;
    * otherwise the q-exponent vector is repeatedly reduced with
      ``reduce_q_coeff``. If the reduced q-monomial is 1, the coefficient goes
      to ``norm_yz.posify`` with the reduced permutations; otherwise it goes to
      ``norm_yz.compute_positive_rep``.

    If rewriting one piece fails, the traceback is printed to stderr and that
    piece is dropped. The final consistency check then fails unless the
    dropped piece was zero.

    Returns:
        ``int(expand(val))`` when ``val`` is numeric, otherwise the rewritten
        expression.

    Raises:
        ValueError: if the rewritten expression does not expand to ``val``.
    """
    try:
        return int(expand(val))
    except Exception:
        pass  # not a plain integer; rewrite it piece by piece below

    val2 = 0
    q_dict = factor_out_q_keep_factored(val)
    for q_part, coeff in q_dict.items():
        try:
            val2 += q_part * int(coeff)
            continue
        except Exception:
            pass
        try:
            if code(~v) == medium_theta(~v):
                val2 += q_part * coeff
                continue
            # Absorb as much of the q-monomial into the permutations as possible.
            qv = q_vector(q_part)
            u2, v2, w2, qv, did_one = reduce_q_coeff(u, v, w, qv)
            while did_one:
                u2, v2, w2, qv, did_one = reduce_q_coeff(u2, v2, w2, qv)
            q_part2 = np.prod([q_var[i + 1] ** qv[i] for i in range(len(qv))])
            if q_part2 == 1:
                # Reduced all the way to a classical (q-free) coefficient.
                val2 += q_part * norm_yz.posify(coeff, u2, v2, w2, var2, var3, msg, False)
            else:
                val2 += q_part * norm_yz.compute_positive_rep(coeff, var2, var3, msg, False)
        except Exception:
            import traceback

            traceback.print_exc()

    if expand(val - val2) != 0:
        raise ValueError(
            f"q_posify: rewritten coefficient {val2} does not match {val} (u={u}, v={v}, w={w})",
        )
    return val2
250
+
251
+
252
+ # def q_posify(u, v, w, val, var2, var3, q_var, msg):
253
+ # if expand(val) != 0:
254
+ # try:
255
+ # int(val)
256
+ # except Exception:
257
+ # val2 = 0
258
+ # q_dict = factor_out_q_keep_factored(val)
259
+ # for q_part in q_dict:
260
+ # try:
261
+ # val2 += q_part * int(q_dict[q_part])
262
+ # except Exception:
263
+ # # if same:
264
+ # # to_add = q_part * expand(sympify(q_dict[q_part]).xreplace(subs_dict2))
265
+ # # val2 += to_add
266
+ # # else:
267
+ # try:
268
+ # if code(~v) == medium_theta(~v):
269
+ # val2 += q_part * q_dict[q_part]
270
+ # else:
271
+ # q_part2 = q_part
272
+ # qv = q_vector(q_part)
273
+ # u2, v2, w2 = u, v, w
274
+ # u2, v2, w2, qv, did_one = reduce_q_coeff(u2, v2, w2, qv)
275
+ # while did_one:
276
+ # u2, v2, w2, qv, did_one = reduce_q_coeff(u2, v2, w2, qv)
277
+ # q_part2 = np.prod(
278
+ # [q_var[i + 1] ** qv[i] for i in range(len(qv))],
279
+ # )
280
+ # if q_part2 == 1:
281
+ # # reduced to classical coefficient
282
+ # val2 += q_part * norm_yz.posify(
283
+ # q_dict[q_part],
284
+ # u2,
285
+ # v2,
286
+ # w2,
287
+ # var2,
288
+ # var3,
289
+ # msg,
290
+ # False,
291
+ # )
292
+ # else:
293
+ # val2 += q_part * norm_yz.compute_positive_rep(
294
+ # q_dict[q_part],
295
+ # var2,
296
+ # var3,
297
+ # msg,
298
+ # False,
299
+ # )
300
+ # except Exception as e:
301
+ # # print(f"Exception: {e}")
302
+ # import traceback
303
+
304
+ # traceback.print_exc()
305
+ # exit(1)
306
+ # if expand(val - val2) != 0:
307
+ # raise Exception
308
+ # val = val2
309
+ # return val
310
+ # return 0
311
+
312
+
313
def old_q_posify(u, v, w, val, var2, var3, q_var, msg):
    """Legacy variant of ``q_posify``.

    ``val`` is split by ``factor_out_q_keep_factored`` into
    ``{q_monomial: coefficient}`` pieces. Integer coefficients are kept. For
    any other piece, the q-exponent vector is reduced with ``reduce_q_coeff``;
    the coefficient is then rewritten by ``norm_yz.posify`` if the reduced
    q-monomial is 1, and by ``norm_yz.compute_positive_rep`` otherwise.

    Unlike ``q_posify``, there is no integer fast path and no ``medium_theta``
    shortcut. Errors raised while rewriting a piece propagate to the caller
    rather than terminating the interpreter.

    Raises:
        ValueError: if the rewritten expression does not expand to ``val``.
    """
    val2 = 0
    q_dict = factor_out_q_keep_factored(val)
    for q_part, coeff in q_dict.items():
        try:
            val2 += q_part * int(coeff)
            continue
        except Exception:
            pass
        qv = q_vector(q_part)
        u2, v2, w2, qv, did_one = reduce_q_coeff(u, v, w, qv)
        while did_one:
            u2, v2, w2, qv, did_one = reduce_q_coeff(u2, v2, w2, qv)
        q_part2 = np.prod([q_var[i + 1] ** qv[i] for i in range(len(qv))])
        if q_part2 == 1:
            # Reduced all the way to a classical (q-free) coefficient.
            val2 += q_part * norm_yz.posify(coeff, u2, v2, w2, var2, var3, msg, False)
        else:
            val2 += q_part * norm_yz.compute_positive_rep(coeff, var2, var3, msg, False)

    if expand(val - val2) != 0:
        raise ValueError(
            f"old_q_posify: rewritten coefficient {val2} does not match {val} (u={u}, v={v}, w={w})",
        )
    return val2
359
+
360
+
361
def q_partial_posify_generic(val, u, v, w):
    """Partially rewrite the generic-variable coefficient ``val`` of ``w``.

    This is the generic-variable counterpart of ``q_posify``. ``val`` is split
    into ``{q_monomial: coefficient}`` pieces by
    ``factor_out_q_keep_factored``:

    * integer pieces are kept;
    * when ``code(~v) == medium_theta(~v)``, every piece is kept unchanged;
    * otherwise the q-exponent vector is reduced with ``reduce_q_coeff``. If it
      reduces to the trivial monomial, the piece is rewritten by
      ``norm_yz.posify_generic_partial`` using the reduced permutations;
      otherwise it is kept unchanged.

    If rewriting one piece fails, the error is logged at DEBUG level with its
    traceback and the piece is dropped. The final consistency check then
    fails unless the dropped piece was zero.

    Returns:
        ``int(expand(val))`` when ``val`` is numeric, otherwise the rewritten
        expression.

    Raises:
        ValueError: if the rewritten expression does not expand to ``val``.
    """
    try:
        return int(expand(val))
    except Exception:
        pass  # not a plain integer; rewrite it piece by piece below

    val2 = 0
    q_dict = factor_out_q_keep_factored(val)
    for q_part, coeff in q_dict.items():
        try:
            val2 += q_part * int(coeff)
            continue
        except Exception:
            pass
        try:
            if code(~v) == medium_theta(~v):
                val2 += q_part * coeff
                continue
            qv = q_vector(q_part)
            u2, v2, w2, qv, did_one = reduce_q_coeff(u, v, w, qv)
            while did_one:
                u2, v2, w2, qv, did_one = reduce_q_coeff(u2, v2, w2, qv)
            q_part2 = np.prod([_vars.q_var[i + 1] ** qv[i] for i in range(len(qv))])
            if q_part2 == 1:
                # Reduced all the way to a classical (q-free) coefficient.
                val2 += q_part * norm_yz.posify_generic_partial(coeff, u2, v2, w2)
            else:
                val2 += q_part * coeff
        except Exception as e:
            logger.debug("Exception: %s", e, exc_info=True)

    if expand(val - val2) != 0:
        raise ValueError(
            f"q_partial_posify_generic: rewritten coefficient {val2} does not match {val} (u={u}, v={v}, w={w})",
        )
    return val2
406
+
407
def _omega_zero_indices(parabolic_index, qv):
    """Return the entries ``i`` of ``parabolic_index`` with ``omega(i, qv) == 0``.

    Entries with ``omega == -1`` are skipped. ``None`` is returned as soon as
    any entry has another ``omega`` value.
    """
    kept = []
    for idx in parabolic_index:
        weight = omega(idx, qv)
        if weight == 0:
            kept.append(idx)
        elif weight != -1:
            return None
    return kept


def apply_peterson_woodward(coeff_dict, parabolic_index, q_var=_vars.q_var):
    """Map a coefficient dict through the Peterson-Woodward comparison for ``parabolic_index``.

    For every permutation ``w_1`` and every ``{q_monomial: coefficient}`` piece
    of its coefficient (from ``factor_out_q_keep_factored``), a piece survives
    only if all of the following hold:

    * every ``omega(i, qv)`` for ``i`` in ``parabolic_index`` is 0 or -1;
    * ``check_blocks(qv, parabolic_index)`` holds;
    * ``w = w_1 * longest_element(omega-zero indices) * longest_element(parabolic_index)``
      satisfies ``is_parabolic``;
    * ``len(permtrim(w))`` is at most ``parabolic_index[-1] + 1``.

    Surviving pieces are re-expressed in q-variables renumbered to skip
    ``parabolic_index`` (converted to ``int`` when possible) and accumulated
    under ``permtrim(w)``.

    Assumes ``parabolic_index`` is non-empty, since its last entry is read.
    """
    max_len = parabolic_index[-1] + 1
    w_P = longest_element(parabolic_index)
    result = {}
    for w_1, val in coeff_dict.items():
        q_dict = factor_out_q_keep_factored(val)
        for q_part in q_dict:
            qv = q_vector(q_part)
            zero_indices = _omega_zero_indices(parabolic_index, qv)
            if zero_indices is None:
                continue
            w_P_prime = longest_element(zero_indices)
            if not check_blocks(qv, parabolic_index):
                continue
            w = (w_1 * w_P_prime) * w_P
            if not is_parabolic(w, parabolic_index):
                continue
            w = permtrim(w)
            if len(w) > max_len:
                continue
            new_q_part = np.prod(
                [q_var[idx + 1 - count_less_than(parabolic_index, idx + 1)] ** qv[idx] for idx in range(len(qv)) if idx + 1 not in parabolic_index],
            )
            try:
                new_q_part = int(new_q_part)
            except Exception:
                pass
            result[w] = result.get(w, 0) + new_q_part * q_dict[q_part]
    return result
448
+
449
+
450
def elem_sym_func_q_q(k, i, u1, u2, v1, v2, udiff, vdiff, varl1, varl2, q_var=_vars.q_var):
    """Weight of one path step, as used by ``schubpoly_quantum``.

    With ``degree = k - udiff``, returns 0 when ``degree < vdiff`` and 1 when
    ``degree == vdiff``. Otherwise it returns
    ``elem_sym_poly_q(degree - vdiff, degree, yvars, zvars, q_var)``, where:

    * ``yvars`` are ``varl1[u2[j]]`` for the positions ``j < k`` at which
      ``u1`` and ``u2`` agree;
    * ``zvars`` are ``varl2[a]`` for ``a`` in ``call_zvars(v1, v2, k, i)``.
    """
    degree = k - udiff
    if degree < vdiff:
        return 0
    if degree == vdiff:
        return 1
    fixed_y = [varl1[u2[j]] for j in range(k) if u1[j] == u2[j]]
    z_list = [varl2[a] for a in call_zvars(v1, v2, k, i)]
    return elem_sym_poly_q(degree - vdiff, degree, fixed_y, z_list, q_var)
470
+
471
+
472
def schubpoly_quantum(v, var_x=_vars.var1, var_y=_vars.var2, q_var=_vars.q_var, coeff=1):
    """Path-sum evaluation for ``v`` started at the identity, scaled by ``coeff``.

    Presumably returns the quantum double Schubert polynomial of ``v`` in
    ``var_x`` / ``var_y`` times ``coeff``, going by the name -- TODO confirm.
    Steps use ``elem_sym_perms_q`` on the u side and ``elem_sym_func_q_q``
    for the weights.

    Returns:
        ``coeff`` when ``strict_theta(~v)`` is empty. Otherwise the weight
        collected at the identity permutation; raises KeyError if no path
        ends there.
    """
    th = strict_theta(~v)
    mu = uncode(th)
    vmu = v * mu  # legacy form: permtrim(mulperm([*v], mu))
    if len(th) == 0:
        return coeff
    # Drop trailing zero rows of theta.
    while len(th) > 0 and th[-1] == 0:
        th.pop()
    vpathdicts = compute_vpathdicts(th, vmu)
    # vpathsums[up][vp]: accumulated weight of partial paths ending at ``up``
    # (u side) and ``vp`` (v side). Both sides start at the identity.
    vpathsums = {Permutation([1, 2]): {Permutation([1, 2]): coeff}}
    inv_mu = inv(mu)
    inv_vmu = inv(vmu)
    inv_u = 0  # inversion count of the identity starting point
    ret_dict = {}
    for index in range(len(th)):
        # Largest step that any v-path edge out of this row permits.
        mx_th = 0
        for vp in vpathdicts[index]:
            for v2, vdiff, s in vpathdicts[index][vp]:
                mx_th = max(mx_th, th[index] - vdiff)
        newpathsums = {}
        for up in vpathsums:
            inv_up = inv(up)
            # Cap the step at mx_th and at the remaining inversion budget
            # inv(mu) - inv(vmu) - (inv(up) - inv(u)).
            newperms = elem_sym_perms_q(
                up,
                min(mx_th, (inv_mu - (inv_up - inv_u)) - inv_vmu),
                th[index],
                q_var,
            )
            for up2, udiff, mul_val in newperms:
                if up2 not in newpathsums:
                    newpathsums[up2] = {}
                # NOTE: this rebinds the parameter name ``v``. The original
                # ``v`` is not used again after this point.
                for v in vpathdicts[index]:
                    sumval = vpathsums[up].get(v, 0) * mul_val
                    if sumval == 0:
                        continue
                    for v2, vdiff, s in vpathdicts[index][v]:
                        newpathsums[up2][v2] = newpathsums[up2].get(
                            v2,
                            0,
                        ) + s * sumval * elem_sym_func_q_q(
                            th[index],
                            index + 1,
                            up,
                            up2,
                            v,
                            v2,
                            udiff,
                            vdiff,
                            var_x,
                            var_y,
                            q_var,
                        )
        vpathsums = newpathsums
    # Keep only the weight of paths whose v side reached vmu.
    toget = vmu
    ret_dict = add_perm_dict({ep: vpathsums[ep].get(toget, 0) for ep in vpathsums}, ret_dict)
    return ret_dict[Permutation([1, 2])]
528
+
529
+
530
def schubmult_q_double(perm_dict, v, var2=_vars.var2, var3=_vars.var3, q_var=_vars.q_var):
    """Path-sum product of every entry of ``perm_dict`` with ``v`` (strict-theta variant).

    Presumably this is the quantum double Schubert product
    ``sum_u c_u * S_u * S_v``, going by the function and module names -- TODO
    confirm. ``schubmult_q_double_fast`` is the variant built on
    ``medium_theta``.

    Args:
        perm_dict: mapping permutation -> coefficient.
        v: right-hand permutation. The identity, or an empty theta, returns
            ``perm_dict`` itself.
        var2, var3: variable sets forwarded to ``elem_sym_func_q``.
        q_var: q-variable set forwarded to ``elem_sym_perms_q``.

    Returns:
        New dict permutation -> coefficient.
    """
    if v == Permutation([1, 2]):
        return perm_dict
    th = strict_theta(~v)
    mu = uncode(th)
    vmu = v * mu  # legacy form: permtrim(mulperm([*v], mu))
    inv_vmu = inv(vmu)
    inv_mu = inv(mu)
    ret_dict = {}
    if len(th) == 0:
        return perm_dict
    # Drop trailing zero rows. Assumes th has a nonzero entry; an all-zero th
    # would raise IndexError here.
    while th[-1] == 0:
        th.pop()
    thL = len(th)
    vpathdicts = compute_vpathdicts(th, vmu, True)
    for u, val in perm_dict.items():
        inv_u = inv(u)
        # vpathsums[up][vp]: accumulated weight of partial paths ending at
        # ``up`` (u side) and ``vp`` (v side). Paths start at (u, identity).
        vpathsums = {u: {Permutation([1, 2]): val}}
        for index in range(thL):
            # Largest step that any v-path edge out of this row permits.
            mx_th = 0
            for vp in vpathdicts[index]:
                for v2, vdiff, s in vpathdicts[index][vp]:
                    mx_th = max(mx_th, th[index] - vdiff)
            newpathsums = {}
            for up in vpathsums:
                inv_up = inv(up)
                # Cap the step at mx_th and at the remaining inversion budget
                # inv(mu) - inv(vmu) - (inv(up) - inv(u)).
                newperms = elem_sym_perms_q(
                    up,
                    min(mx_th, (inv_mu - (inv_up - inv_u)) - inv_vmu),
                    th[index],
                    q_var,
                )
                for up2, udiff, mul_val in newperms:
                    if up2 not in newpathsums:
                        newpathsums[up2] = {}
                    # NOTE: this rebinds the parameter name ``v``. The original
                    # ``v`` is not used again after this point.
                    for v in vpathdicts[index]:
                        sumval = vpathsums[up].get(v, 0) * mul_val
                        if sumval == 0:
                            continue
                        for v2, vdiff, s in vpathdicts[index][v]:
                            newpathsums[up2][v2] = newpathsums[up2].get(
                                v2,
                                0,
                            ) + s * sumval * elem_sym_func_q(
                                th[index],
                                index + 1,
                                up,
                                up2,
                                v,
                                v2,
                                udiff,
                                vdiff,
                                var2,
                                var3,
                            )
            vpathsums = newpathsums
        # Keep only the weight of paths whose v side reached vmu.
        toget = vmu
        ret_dict = add_perm_dict({ep: vpathsums[ep].get(toget, 0) for ep in vpathsums}, ret_dict)
    return ret_dict
589
+
590
+
591
def schubmult_q_double_fast(perm_dict, v, var2=_vars.var2, var3=_vars.var3, q_var=_vars.q_var):
    """Path-sum product of every entry of ``perm_dict`` with ``v`` (medium-theta variant).

    Same contract as ``schubmult_q_double``, but built on ``medium_theta(~v)``.
    When two adjacent theta entries are equal, both rows are processed in one
    step using ``double_elem_sym_q``. All other rows use ``elem_sym_perms_q``
    with the inversion-budget cap.

    Args:
        perm_dict: mapping permutation -> coefficient.
        v: right-hand permutation. The identity, or an empty theta, returns
            ``perm_dict`` itself.
        var2, var3: variable sets forwarded to ``elem_sym_func_q``.
        q_var: q-variable set forwarded to ``double_elem_sym_q`` and
            ``elem_sym_perms_q``.

    Returns:
        New dict permutation -> coefficient.
    """
    if v == Permutation([1, 2]):
        return perm_dict
    th = medium_theta(~v)
    if len(th) == 0:
        return perm_dict
    # Drop trailing zero rows. Assumes th has a nonzero entry; an all-zero th
    # would raise IndexError here.
    while th[-1] == 0:
        th.pop()
    mu = uncode(th)
    vmu = v * mu  # legacy form: permtrim(mulperm([*v], mu))
    inv_vmu = inv(vmu)
    inv_mu = inv(mu)
    ret_dict = {}

    thL = len(th)
    vpathdicts = compute_vpathdicts(th, vmu, True)
    for u, val in perm_dict.items():
        inv_u = inv(u)
        # vpathsums[up][vp]: accumulated weight of partial paths ending at
        # ``up`` (u side) and ``vp`` (v side).
        vpathsums = {u: {Permutation([]): val}}
        for index in range(thL):
            # A row equal to its predecessor was already consumed by the
            # paired step below.
            # NOTE(review): with three or more equal consecutive entries, the
            # third is skipped as well and never processed. Presumably
            # medium_theta never produces that -- TODO confirm.
            if index > 0 and th[index - 1] == th[index]:
                continue
            # Largest step that any v-path edge out of this row permits.
            mx_th = 0
            for vp in vpathdicts[index]:
                for v2, vdiff, s in vpathdicts[index][vp]:
                    mx_th = max(mx_th, th[index] - vdiff)
            if index < len(th) - 1 and th[index] == th[index + 1]:
                # Paired step: rows ``index`` and ``index + 1`` together.
                mx_th1 = 0
                for vp in vpathdicts[index + 1]:
                    for v2, vdiff, s in vpathdicts[index + 1][vp]:
                        mx_th1 = max(mx_th1, th[index + 1] - vdiff)
                newpathsums = {}
                for up in vpathsums:
                    # Sums after the first row, keyed by the first-row step
                    # triple (up1, udiff1, mul_val1).
                    newpathsums0 = {}
                    inv_up = inv(up)  # not used in this branch
                    # newperms: each first-row step triple -> iterable of the
                    # second-row step triples reachable from it.
                    newperms = double_elem_sym_q(up, mx_th, mx_th1, th[index], q_var)
                    for v in vpathdicts[index]:
                        sumval = vpathsums[up].get(v, 0)
                        if sumval == 0:
                            continue
                        for v2, vdiff2, s2 in vpathdicts[index][v]:
                            for up1, udiff1, mul_val1 in newperms:
                                esim1 = (
                                    elem_sym_func_q(
                                        th[index],
                                        index + 1,
                                        up,
                                        up1,
                                        v,
                                        v2,
                                        udiff1,
                                        vdiff2,
                                        var2,
                                        var3,
                                    )
                                    * mul_val1
                                    * s2
                                )
                                mulfac = sumval * esim1
                                if (up1, udiff1, mul_val1) not in newpathsums0:
                                    newpathsums0[(up1, udiff1, mul_val1)] = {}
                                newpathsums0[(up1, udiff1, mul_val1)][v2] = newpathsums0[(up1, udiff1, mul_val1)].get(v2, 0) + mulfac

                    # Second row: continue every first-row intermediate.
                    for up1, udiff1, mul_val1 in newpathsums0:
                        for v in vpathdicts[index + 1]:
                            sumval = newpathsums0[(up1, udiff1, mul_val1)].get(v, 0)
                            if sumval == 0:
                                continue
                            for v2, vdiff2, s2 in vpathdicts[index + 1][v]:
                                for up2, udiff2, mul_val2 in newperms[(up1, udiff1, mul_val1)]:
                                    esim1 = (
                                        elem_sym_func_q(
                                            th[index + 1],
                                            index + 2,
                                            up1,
                                            up2,
                                            v,
                                            v2,
                                            udiff2,
                                            vdiff2,
                                            var2,
                                            var3,
                                        )
                                        * mul_val2
                                        * s2
                                    )
                                    mulfac = sumval * esim1
                                    if up2 not in newpathsums:
                                        newpathsums[up2] = {}
                                    newpathsums[up2][v2] = newpathsums[up2].get(v2, 0) + mulfac
            else:
                # Single-row step, identical to schubmult_q_double.
                newpathsums = {}
                for up in vpathsums:
                    inv_up = inv(up)
                    # Cap the step at mx_th and at the remaining inversion
                    # budget inv(mu) - inv(vmu) - (inv(up) - inv(u)).
                    newperms = elem_sym_perms_q(
                        up,
                        min(mx_th, (inv_mu - (inv_up - inv_u)) - inv_vmu),
                        th[index],
                        q_var,
                    )
                    for up2, udiff, mul_val in newperms:
                        if up2 not in newpathsums:
                            newpathsums[up2] = {}
                        for v in vpathdicts[index]:
                            sumval = vpathsums[up].get(v, 0) * mul_val
                            if sumval == 0:
                                continue
                            for v2, vdiff, s in vpathdicts[index][v]:
                                newpathsums[up2][v2] = newpathsums[up2].get(
                                    v2,
                                    0,
                                ) + s * sumval * elem_sym_func_q(
                                    th[index],
                                    index + 1,
                                    up,
                                    up2,
                                    v,
                                    v2,
                                    udiff,
                                    vdiff,
                                    var2,
                                    var3,
                                )
            vpathsums = newpathsums
        # Keep only the weight of paths whose v side reached vmu.
        toget = vmu
        ret_dict = add_perm_dict({ep: vpathsums[ep].get(toget, 0) for ep in vpathsums}, ret_dict)
    return ret_dict
719
+
720
+
721
+ # def schubmult_q_double_fast(perm_dict, v, var2=_vars.var2, var3=_vars.var3, q_var=_vars.q_var):
722
+ # if v == (1, 2):
723
+ # return perm_dict
724
+ # th = medium_theta(inverse(v))
725
+ # if len(th) == 0:
726
+ # return perm_dict
727
+ # while th[-1] == 0:
728
+ # th.pop()
729
+ # mu = permtrim(uncode(th))
730
+ # vmu = permtrim(mulperm([*v], mu))
731
+ # inv_vmu = inv(vmu)
732
+ # inv_mu = inv(mu)
733
+ # ret_dict = {}
734
+
735
+ # thL = len(th)
736
+ # vpathdicts = compute_vpathdicts(th, vmu, True)
737
+ # for u, val in perm_dict.items():
738
+ # inv_u = inv(u)
739
+ # vpathsums = {u: {(1, 2): val}}
740
+ # for index in range(thL):
741
+ # if index > 0 and th[index - 1] == th[index]:
742
+ # continue
743
+ # mx_th = 0
744
+ # for vp in vpathdicts[index]:
745
+ # for v2, vdiff, s in vpathdicts[index][vp]:
746
+ # mx_th = max(mx_th, th[index] - vdiff)
747
+ # if index < len(th) - 1 and th[index] == th[index + 1]:
748
+ # mx_th1 = 0
749
+ # for vp in vpathdicts[index + 1]:
750
+ # for v2, vdiff, s in vpathdicts[index + 1][vp]:
751
+ # mx_th1 = max(mx_th1, th[index + 1] - vdiff)
752
+ # newpathsums = {}
753
+ # for up in vpathsums:
754
+ # newpathsums0 = {}
755
+ # inv_up = inv(up)
756
+ # newperms = double_elem_sym_q(up, mx_th, mx_th1, th[index], q_var)
757
+ # for v in vpathdicts[index]:
758
+ # sumval = vpathsums[up].get(v, 0)
759
+ # if sumval == 0:
760
+ # continue
761
+ # for v2, vdiff2, s2 in vpathdicts[index][v]:
762
+ # for up1, udiff1, mul_val1 in newperms:
763
+ # esim1 = (
764
+ # elem_sym_func_q(
765
+ # th[index],
766
+ # index + 1,
767
+ # up,
768
+ # up1,
769
+ # v,
770
+ # v2,
771
+ # udiff1,
772
+ # vdiff2,
773
+ # var2,
774
+ # var3,
775
+ # )
776
+ # * mul_val1
777
+ # * s2
778
+ # )
779
+ # mulfac = sumval * esim1
780
+ # if (up1, udiff1, mul_val1) not in newpathsums0:
781
+ # newpathsums0[(up1, udiff1, mul_val1)] = {}
782
+ # # newpathsums0[(up1, udiff1, mul_val1
783
+ # newpathsums0[(up1, udiff1, mul_val1)][v2] = newpathsums0[(up1, udiff1, mul_val1)].get(v2, 0) + mulfac
784
+
785
+ # for up1, udiff1, mul_val1 in newpathsums0:
786
+ # for v in vpathdicts[index + 1]:
787
+ # sumval = newpathsums0[(up1, udiff1, mul_val1)].get(v, 0)
788
+ # if sumval == 0:
789
+ # continue
790
+ # for v2, vdiff2, s2 in vpathdicts[index + 1][v]:
791
+ # for up2, udiff2, mul_val2 in newperms[(up1, udiff1, mul_val1)]:
792
+ # esim1 = (
793
+ # elem_sym_func_q(
794
+ # th[index + 1],
795
+ # index + 2,
796
+ # up1,
797
+ # up2,
798
+ # v,
799
+ # v2,
800
+ # udiff2,
801
+ # vdiff2,
802
+ # var2,
803
+ # var3,
804
+ # )
805
+ # * mul_val2
806
+ # * s2
807
+ # )
808
+ # mulfac = sumval * esim1
809
+ # if up2 not in newpathsums:
810
+ # newpathsums[up2] = {}
811
+ # newpathsums[up2][v2] = newpathsums[up2].get(v2, 0) + mulfac
812
+ # else:
813
+ # newpathsums = {}
814
+ # for up in vpathsums:
815
+ # inv_up = inv(up)
816
+ # newperms = elem_sym_perms_q(
817
+ # up,
818
+ # min(mx_th, (inv_mu - (inv_up - inv_u)) - inv_vmu),
819
+ # th[index],
820
+ # q_var,
821
+ # )
822
+ # for up2, udiff, mul_val in newperms:
823
+ # if up2 not in newpathsums:
824
+ # newpathsums[up2] = {}
825
+ # for v in vpathdicts[index]:
826
+ # sumval = vpathsums[up].get(v, 0) * mul_val
827
+ # if sumval == 0:
828
+ # continue
829
+ # for v2, vdiff, s in vpathdicts[index][v]:
830
+ # newpathsums[up2][v2] = newpathsums[up2].get(
831
+ # v2,
832
+ # 0,
833
+ # ) + s * sumval * elem_sym_func_q(
834
+ # th[index],
835
+ # index + 1,
836
+ # up,
837
+ # up2,
838
+ # v,
839
+ # v2,
840
+ # udiff,
841
+ # vdiff,
842
+ # var2,
843
+ # var3,
844
+ # )
845
+ # vpathsums = newpathsums
846
+ # toget = tuple(vmu)
847
+ # ret_dict = add_perm_dict({ep: vpathsums[ep].get(toget, 0) for ep in vpathsums}, ret_dict)
848
+ # return ret_dict
849
+
850
+
851
def div_diff(v, w, var2=_vars.var2, var3=_vars.var3):
    """Return the identity coefficient of ``norm_yz.schubmult_down({v: 1}, w, var2, var3)``.

    Presumably the divided difference of the double Schubert polynomial of
    ``v`` by ``w``, going by the name -- TODO confirm. Returns 0 when the
    identity does not appear in the result.
    """
    reduced = norm_yz.schubmult_down({v: 1}, w, var2, var3)
    return reduced.get(Permutation([1, 2]), 0)
855
+
856
+
857
def sum_q_dict(q_dict1, q_dict2):
    """Return a new dict with the values of both inputs added key by key.

    Keys of ``q_dict1`` come first, followed by keys that appear only in
    ``q_dict2``. Neither input is modified.
    """
    total = dict(q_dict1)
    for key, value in q_dict2.items():
        total[key] = total.get(key, 0) + value
    return total
862
+
863
+
864
def mul_q_dict(q_dict1, q_dict2):
    """Multiply two ``{key: coefficient}`` dicts like polynomials.

    Every pair of keys is multiplied with ``*``. Coefficients of pairs that
    land on the same product key are summed.
    """
    product = {}
    for key1, coeff1 in q_dict1.items():
        for key2, coeff2 in q_dict2.items():
            combined = key1 * key2
            product[combined] = product.get(combined, 0) + coeff1 * coeff2
    return product
871
+
872
+
873
def factor_out_q_keep_factored(poly, q_var=_vars.q_var):
    """Split ``poly`` into ``{q_monomial: coefficient}`` without expanding it.

    The expression tree is walked recursively, and ``q_var`` is used at every
    level:

    * an expression with no free symbol from ``q_var`` maps to ``{1: poly}``;
    * a bare q-variable maps to ``{q: 1}``;
    * a sum adds the per-term dicts (``sum_q_dict``);
    * a product multiplies them (``mul_q_dict``), so coefficients stay factored;
    * ``base**e`` multiplies the dict of ``base`` by itself ``e`` times.

    Args:
        poly: symengine expression, or anything ``sympify`` accepts.
        q_var: q-variable set. Values that are not ``GeneratingSet_base`` are
            wrapped in ``CustomGeneratingSet``.

    Raises:
        ValueError: for a q-dependent power with an exponent below 1, or for
            a q-dependent expression that is not a sum, product, power or bare
            q-variable.
    """
    if not isinstance(q_var, GeneratingSet_base):
        q_var = CustomGeneratingSet(q_var)

    if not any(q_var.index(s) != -1 for s in sympify(poly).free_symbols):
        return {1: poly}
    if q_var.index(poly) != -1:
        return {poly: 1}
    if isinstance(poly, Add):
        terms = poly.args
        ret = factor_out_q_keep_factored(terms[0], q_var)
        for term in terms[1:]:
            ret = sum_q_dict(ret, factor_out_q_keep_factored(term, q_var))
        return ret
    if isinstance(poly, Mul):
        factors = poly.args
        ret = factor_out_q_keep_factored(factors[0], q_var)
        for factor in factors[1:]:
            ret = mul_q_dict(ret, factor_out_q_keep_factored(factor, q_var))
        return ret
    if isinstance(poly, Pow):
        base = poly.args[0]
        exponent = int(poly.args[1])
        if exponent < 1:
            # A q-dependent negative/fractional power cannot be split into
            # q-monomials; repeated multiplication would silently return base.
            raise ValueError(f"factor_out_q_keep_factored: unsupported exponent {poly.args[1]} in {poly}")
        base_dict = factor_out_q_keep_factored(base, q_var)
        ret = dict(base_dict)
        for _ in range(exponent - 1):
            ret = mul_q_dict(ret, base_dict)
        return ret
    raise ValueError(f"factor_out_q_keep_factored: cannot split {poly!r} into q-parts")
917
+
918
+
919
def factor_out_q(poly):
    """Split the expansion of ``poly`` into ``{q_monomial: q_free_coefficient}``.

    ``poly`` is expanded and walked term by term. Inside each monomial, a
    factor (or a power) whose base is in ``_vars.q_var`` goes to the q part.
    Every other factor, together with the numeric coefficient, goes to the
    coefficient. Terms with equal q parts are summed, and zero terms are
    skipped.
    """
    ret = {}
    for monomial, coeff in expand(poly).as_coefficients_dict().items():
        if coeff == 0:
            continue
        factors = monomial.args if isinstance(monomial, Mul) else (monomial,)
        q_part = 1
        yz_part = coeff
        for factor in factors:
            base = factor.args[0] if isinstance(factor, Pow) else factor
            if base in _vars.q_var:
                q_part *= factor
            else:
                yz_part *= factor
        ret[q_part] = ret.get(q_part, 0) + yz_part
    return ret