schubmult 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. schubmult/__init__.py +1 -0
  2. schubmult/_base_argparse.py +174 -0
  3. schubmult/perm_lib.py +999 -0
  4. schubmult/sage_integration/__init__.py +25 -0
  5. schubmult/sage_integration/_fast_double_schubert_polynomial_ring.py +528 -0
  6. schubmult/sage_integration/_fast_schubert_polynomial_ring.py +356 -0
  7. schubmult/sage_integration/_indexing.py +44 -0
  8. schubmult/schubmult_double/__init__.py +18 -0
  9. schubmult/schubmult_double/__main__.py +5 -0
  10. schubmult/schubmult_double/_funcs.py +1590 -0
  11. schubmult/schubmult_double/_script.py +407 -0
  12. schubmult/schubmult_double/_vars.py +16 -0
  13. schubmult/schubmult_py/__init__.py +10 -0
  14. schubmult/schubmult_py/__main__.py +5 -0
  15. schubmult/schubmult_py/_funcs.py +111 -0
  16. schubmult/schubmult_py/_script.py +115 -0
  17. schubmult/schubmult_py/_vars.py +3 -0
  18. schubmult/schubmult_q/__init__.py +12 -0
  19. schubmult/schubmult_q/__main__.py +5 -0
  20. schubmult/schubmult_q/_funcs.py +304 -0
  21. schubmult/schubmult_q/_script.py +157 -0
  22. schubmult/schubmult_q/_vars.py +18 -0
  23. schubmult/schubmult_q_double/__init__.py +14 -0
  24. schubmult/schubmult_q_double/__main__.py +5 -0
  25. schubmult/schubmult_q_double/_funcs.py +507 -0
  26. schubmult/schubmult_q_double/_script.py +337 -0
  27. schubmult/schubmult_q_double/_vars.py +21 -0
  28. schubmult-2.0.0.dist-info/METADATA +455 -0
  29. schubmult-2.0.0.dist-info/RECORD +36 -0
  30. schubmult-2.0.0.dist-info/WHEEL +5 -0
  31. schubmult-2.0.0.dist-info/entry_points.txt +5 -0
  32. schubmult-2.0.0.dist-info/licenses/LICENSE +674 -0
  33. schubmult-2.0.0.dist-info/top_level.txt +2 -0
  34. tests/__init__.py +0 -0
  35. tests/test_fast_double_schubert.py +145 -0
  36. tests/test_fast_schubert.py +38 -0
@@ -0,0 +1,1590 @@
1
+ from bisect import bisect_left
2
+ from functools import cache
3
+ from cachetools import cached
4
+ from cachetools.keys import hashkey
5
+ from symengine import sympify, Add, Mul, Pow, expand, Integer
6
+ from schubmult.perm_lib import (
7
+ elem_sym_perms,
8
+ elem_sym_poly,
9
+ add_perm_dict,
10
+ dominates,
11
+ compute_vpathdicts,
12
+ inverse,
13
+ theta,
14
+ permtrim,
15
+ inv,
16
+ mulperm,
17
+ code,
18
+ uncode,
19
+ elem_sym_func,
20
+ elem_sym_perms_op,
21
+ divdiffable,
22
+ pull_out_var,
23
+ cycle,
24
+ will_formula_work,
25
+ one_dominates,
26
+ is_reducible,
27
+ reduce_coeff,
28
+ reduce_descents,
29
+ try_reduce_u,
30
+ try_reduce_v,
31
+ phi1,
32
+ zero,
33
+ )
34
+ import numpy as np
35
+ import pulp as pu
36
+ import sympy
37
+ import psutil
38
+ from sortedcontainers import SortedList
39
+ from ._vars import (
40
+ n,
41
+ var2,
42
+ var3,
43
+ var_x,
44
+ var_y,
45
+ )
46
+
47
+
48
def count_sorted(mn, tp):
    """Return the number of occurrences of ``tp`` in the sorted sequence ``mn``.

    ``mn`` must be sorted in ascending order (for example a ``SortedList`` of
    monomial tuples).  Any element type that supports ``<`` comparisons with
    ``tp`` works.

    Args:
        mn: ascending sequence to search.
        tp: value to count.

    Returns:
        int: how many entries of ``mn`` equal ``tp`` (0 if none, including
        when ``mn`` is empty).
    """
    from bisect import bisect_right

    # Equal elements occupy the half-open range [bisect_left, bisect_right).
    # Computing both bounds avoids the old scan, which never advanced its index
    # (an infinite loop whenever tp was present) and indexed past the end when
    # tp sorted after every element.
    start = bisect_left(mn, tp)
    return bisect_right(mn, tp, start) - start
55
+
56
+
57
def E(p, k, varl=var_y[1:]):
    """Shorthand for ``elem_sym_poly`` over the x variables.

    Forwards ``p`` and ``k`` to ``elem_sym_poly``, supplying ``var_x[1:]`` as
    the first variable list and ``varl`` as the second.  ``varl`` defaults to
    ``var_y[1:]``; that default is evaluated once, at definition time.
    """
    x_variables = var_x[1:]
    result = elem_sym_poly(p, k, x_variables, varl)
    return result
59
+
60
+
61
def single_variable(coeff_dict, varnum, var2=var2):
    """Multiply a permutation-indexed expansion by the variable at ``varnum``.

    For every permutation ``u`` in ``coeff_dict`` (with coefficient ``c``):

    * ``u`` itself receives ``var2[u[varnum - 1]] * c``.  When ``u`` is shorter
      than ``varnum``, ``var2[varnum]`` is used instead.
    * Each permutation returned by ``elem_sym_perms(u, 1, varnum)`` with length
      difference 1 receives ``+c``.
    * When ``varnum > 1``, each permutation returned by
      ``elem_sym_perms(u, 1, varnum - 1)`` with length difference 1 receives
      ``-c``.

    Returns:
        dict: a new permutation -> coefficient dict; the input is not modified.
    """
    result = {}
    for perm, coeff in coeff_dict.items():
        # Query both neighbour sets before accumulating, as the original did.
        raise_moves = elem_sym_perms(perm, 1, varnum)
        lower_moves = elem_sym_perms(perm, 1, varnum - 1) if varnum > 1 else []

        diag_index = varnum if varnum - 1 >= len(perm) else perm[varnum - 1]
        result[perm] = result.get(perm, 0) + var2[diag_index] * coeff

        for target, length_gap in raise_moves:
            if length_gap == 1:
                result[target] = result.get(target, 0) + coeff
        for target, length_gap in lower_moves:
            if length_gap == 1:
                result[target] = result.get(target, 0) - coeff
    return result
79
+
80
+
81
def single_variable_down(coeff_dict, varnum):
    """Counterpart of ``single_variable`` that uses ``elem_sym_perms_op``.

    The accumulation pattern is the same as in ``single_variable``:

    * a diagonal ``var2`` term on each permutation;
    * ``+c`` for each length-1 neighbour from ``elem_sym_perms_op(u, 1, varnum)``;
    * ``-c`` for each length-1 neighbour from
      ``elem_sym_perms_op(u, 1, varnum - 1)`` when ``varnum > 1``.

    The diagonal term always uses the module-level ``var2``; it is not a
    parameter here.

    Returns:
        dict: a new permutation -> coefficient dict.
    """
    result = {}
    for perm, coeff in coeff_dict.items():
        raise_moves = elem_sym_perms_op(perm, 1, varnum)
        lower_moves = elem_sym_perms_op(perm, 1, varnum - 1) if varnum > 1 else []

        diag_index = varnum if varnum - 1 >= len(perm) else perm[varnum - 1]
        result[perm] = result.get(perm, 0) + var2[diag_index] * coeff

        for target, length_gap in raise_moves:
            if length_gap == 1:
                result[target] = result.get(target, 0) + coeff
        for target, length_gap in lower_moves:
            if length_gap == 1:
                result[target] = result.get(target, 0) - coeff
    return result
99
+
100
+
101
def mult_poly(coeff_dict, poly, var_x=var_x, var_y=var2):
    """Multiply a permutation -> coefficient dict by the expression ``poly``.

    Recurses structurally over ``poly``:

    * a bare element of ``var_x`` is handled by ``single_variable`` at its index;
    * ``Mul`` nodes multiply factor by factor;
    * ``Pow`` nodes repeat the base ``int(exponent)`` times;
    * ``Add`` nodes sum the per-term results with ``add_perm_dict``;
    * anything else is treated as a scalar that scales every coefficient.

    Note: a non-positive exponent yields ``range()`` of nothing, so the input
    dict is returned unchanged in that case.
    """
    if poly in var_x:
        return single_variable(coeff_dict, var_x.index(poly), var_y)

    if isinstance(poly, Mul):
        product = coeff_dict
        for factor in poly.args:
            product = mult_poly(product, factor, var_x, var_y)
        return product

    if isinstance(poly, Pow):
        base = poly.args[0]
        power = int(poly.args[1])
        product = coeff_dict
        for _ in range(power):
            product = mult_poly(product, base, var_x, var_y)
        return product

    if isinstance(poly, Add):
        total = {}
        for summand in poly.args:
            total = add_perm_dict(total, mult_poly(coeff_dict, summand, var_x, var_y))
        return total

    # Scalar factor: scale every coefficient.
    return {perm: poly * coeff_dict[perm] for perm in coeff_dict}
126
+
127
+
128
def mult_poly_down(coeff_dict, poly):
    """Counterpart of ``mult_poly`` that dispatches to ``single_variable_down``.

    Walks ``poly`` structurally exactly like ``mult_poly``, with the same
    handling of ``Mul``, ``Pow`` (repeated ``int(exponent)`` times), ``Add``
    (combined with ``add_perm_dict``) and scalars.  Single variables use the
    module-level ``var_x`` and ``single_variable_down``.
    """
    if poly in var_x:
        return single_variable_down(coeff_dict, var_x.index(poly))

    if isinstance(poly, Mul):
        product = coeff_dict
        for factor in poly.args:
            product = mult_poly_down(product, factor)
        return product

    if isinstance(poly, Pow):
        base = poly.args[0]
        power = int(poly.args[1])
        product = coeff_dict
        for _ in range(power):
            product = mult_poly_down(product, base)
        return product

    if isinstance(poly, Add):
        total = {}
        for summand in poly.args:
            total = add_perm_dict(total, mult_poly_down(coeff_dict, summand))
        return total

    # Scalar factor: scale every coefficient.
    return {perm: poly * coeff_dict[perm] for perm in coeff_dict}
153
+
154
+
155
def nilhecke_mult(coeff_dict1, coeff_dict2):
    """Combine two permutation -> coefficient dicts via ``mult_poly_down``.

    For each ``(w, poly)`` in ``coeff_dict2``:

    1. ``coeff_dict1`` is multiplied by ``poly`` using ``mult_poly_down``.
    2. Each resulting permutation ``v`` is composed with ``w`` via ``mulperm``.
    3. The product is kept only when lengths add, i.e.
       ``inv(v*w) == inv(v) + inv(w)``.

    Kept products are trimmed with ``permtrim`` and accumulated under tuple keys.
    """
    result = {}
    for w, poly in coeff_dict2.items():
        w_list = [*w]
        w_length = inv(w_list)
        scaled = mult_poly_down(coeff_dict1, poly)
        for v, coeff in scaled.items():
            v_list = [*v]
            composed = mulperm(v_list, w_list)
            # Drop products whose length is not additive.
            if inv(composed) != inv(v_list) + w_length:
                continue
            key = tuple(permtrim(composed))
            result[key] = result.get(key, 0) + coeff
    return result
169
+
170
+
171
def forwardcoeff(u, v, perm, var2=var2, var3=var3):
    """Coefficient of ``perm`` obtained by routing through ``uncode(theta(v))``.

    Let ``mu_v = uncode(theta(v))`` and ``w = perm * (v^{-1} * mu_v)``.

    * If the composition that builds ``w`` is not length-additive, the result
      is 0.
    * Otherwise the result is the entry for ``permtrim(w)`` in
      ``schubmult_one(permtrim(u), mu_v)``, or 0 if that entry is absent.
    """
    mu_v = uncode(theta(v))
    v_inv_mu = mulperm(inverse([*v]), mu_v)
    w = mulperm([*perm], v_inv_mu)
    if inv(w) != inv(v_inv_mu) + inv(perm):
        return 0
    expansion = schubmult_one(tuple(permtrim([*u])), tuple(mu_v), var2, var3)
    return expansion.get(tuple(permtrim(w)), 0)
181
+
182
+
183
def dualcoeff(u, v, perm, var2=var2, var3=var3):
    """Coefficient computed via ``dualpieri`` and ``schubpoly``.

    When ``u == (1, 2)`` (a tuple comparison, so a list ``[1, 2]`` does not
    match):

    * let ``vp = v * perm^{-1}``;
    * if ``inv(vp) == inv(v) - inv(perm)`` the result is ``schubpoly(vp, ...)``;
    * otherwise the result is 0.

    Otherwise the ``dualpieri`` expansion is taken:

    * directly on ``(u, v, perm)`` if ``dominates(u, perm)``;
    * else on ``(uncode(theta(u)), v, w)`` after the length-additivity check
      that builds ``w``.

    Each ``[vlist, vp]`` entry contributes
    ``prod(var2[i + 1] - var3[j] for j in vlist[i]) * schubpoly(vp, var2, var3, len(vlist) + 1)``.
    """
    if u == (1, 2):
        vp = mulperm([*v], inverse(perm))
        if inv(vp) != inv(v) - inv(perm):
            return 0
        return schubpoly(vp, var2, var3)

    if dominates(u, perm):
        pieces = dualpieri([*u], [*v], [*perm])
    else:
        mu_u = uncode(theta(u))
        u_inv_mu = mulperm(inverse([*u]), mu_u)
        w = mulperm([*perm], u_inv_mu)
        pieces = dualpieri(mu_u, [*v], w) if inv(w) == inv(u_inv_mu) + inv(perm) else []

    total = 0
    for vlist, vp in pieces:
        term = 1
        for row, z_indices in enumerate(vlist, start=1):
            for z_index in z_indices:
                term *= var2[row] - var3[z_index]
        term *= schubpoly(vp, var2, var3, len(vlist) + 1)
        total += term
    return total
211
+
212
+
213
def dualpieri(mu, v, w):
    """Expand ``v`` against the inverse codes of ``mu`` and ``w``.

    Let ``lm`` and ``cn1w`` be the codes of ``mu^{-1}`` and ``w^{-1}`` with
    trailing zeros stripped.  The result is empty when ``cn1w`` is shorter than
    ``lm`` or is smaller than ``lm`` at some position.

    Otherwise, for each position ``i`` of ``lm``:

    * ``divdiffable`` is applied with ``cycle(lm[i] + 1, cn1w[i] - lm[i])``;
      an empty result drops that branch;
    * ``pull_out_var(lm[i] + 1, ...)`` then splits the surviving permutation,
      appending the pulled-out list to ``vlist``.

    If ``cn1w`` is longer than ``lm``, the extra positions are folded into a
    single permutation ``c`` that is applied with one last ``divdiffable``.

    Returns:
        list of ``[vlist, vp]`` pairs, where ``vlist[i]`` is the list produced
        at step ``i`` and ``vp`` is the remaining permutation (consumed by
        ``dualcoeff``).
    """
    lm = code(inverse(mu))
    cn1w = code(inverse(w))
    # Strip trailing zeros in place; assumes code() returns a fresh list -- TODO confirm.
    while len(lm) > 0 and lm[-1] == 0:
        lm.pop()
    while len(cn1w) > 0 and cn1w[-1] == 0:
        cn1w.pop()
    if len(cn1w) < len(lm):
        return []
    for i in range(len(lm)):
        if lm[i] > cn1w[i]:
            return []
    # Compose the cycles for the positions of cn1w beyond len(lm).
    c = [1, 2]
    for i in range(len(lm), len(cn1w)):
        c = mulperm(cycle(i - len(lm) + 1, cn1w[i]), c)
    c = permtrim(c)
    res = [[[], v]]
    for i in range(len(lm)):
        res2 = []
        for vlist, vplist in res:
            vp = vplist
            vpl = divdiffable(vp, cycle(lm[i] + 1, cn1w[i] - lm[i]))
            # An empty list from divdiffable is treated as "not applicable": drop the branch.
            if vpl == []:
                continue
            vl = pull_out_var(lm[i] + 1, vpl)
            for pw, vpl2 in vl:
                res2 += [[vlist + [pw], vpl2]]
        res = res2
    if len(lm) == len(cn1w):
        return res
    else:
        res2 = []
        for vlist, vplist in res:
            vp = vplist
            vpl = divdiffable(vp, c)
            if vpl == []:
                continue
            res2 += [[vlist, vpl]]
        return res2
252
+
253
+
254
# Module-level monomial index shared by init_basevec, poly_to_vec and
# compute_positive_rep.  init_basevec rebuilds all three together.

# Number of distinct nonzero monomials currently registered.
dimen = 0
# Maps monomial -> coordinate index in the coefficient vector.
monom_to_vec = {}
# Zero vector of length ``dimen``.  compute_positive_rep copies it to seed the
# LP constraint rows.  It is declared ``global`` by poly_to_vec and
# init_basevec, so define it here too instead of leaving it unbound until
# init_basevec first runs.
base_vec = []
256
+
257
+
258
@cache
def schubmult_one(perm1, perm2, var2=var2, var3=var3):
    """Memoised ``schubmult`` of the single class ``{perm1: 1}`` with ``perm2``.

    ``functools.cache`` keys on the arguments, so they must be hashable; pass
    permutations as tuples.  The cached dict is shared between callers, so
    treat the return value as read-only.
    """
    seed = {perm1: 1}
    return schubmult(seed, perm2, var2, var3)
261
+
262
+
263
def schubmult(perm_dict, v, var2=var2, var3=var3):
    """Multiply every entry of ``perm_dict`` by the polynomial indexed by ``v``.

    ``perm_dict`` maps permutations (tuples) to coefficients.  If
    ``theta(inverse(v))`` is empty or starts with 0, ``perm_dict`` itself is
    returned (not a copy).

    Otherwise the product is built one entry of ``th = theta(v^{-1})`` at a
    time:

    * ``elem_sym_perms`` proposes successor permutations of the ``u`` side;
    * ``elem_sym_func`` supplies each step's polynomial weight;
    * ``compute_vpathdicts`` lists the allowed moves on the ``v`` side.

    Only path sums that end at ``vmu = permtrim(v * mu)`` are kept.  The
    per-``u`` results are merged with ``add_perm_dict``.
    """
    vn1 = inverse(v)
    th = theta(vn1)
    if len(th) == 0:
        return perm_dict
    if th[0] == 0:
        return perm_dict
    mu = permtrim(uncode(th))
    vmu = permtrim(mulperm([*v], mu))
    inv_vmu = inv(vmu)
    inv_mu = inv(mu)
    ret_dict = {}
    # th[0] != 0 here, so this terminates.
    while th[-1] == 0:
        th.pop()
    thL = len(th)
    vpathdicts = compute_vpathdicts(th, vmu, True)
    for u, val in perm_dict.items():
        inv_u = inv(u)
        # Paths on the v side start at (1, 2), the trimmed identity used throughout this module.
        vpathsums = {u: {(1, 2): val}}
        for index in range(thL):
            # Largest th[index] - vdiff over all v-side moves at this step.
            mx_th = 0
            for vp in vpathdicts[index]:
                for v2, vdiff, s in vpathdicts[index][vp]:
                    if th[index] - vdiff > mx_th:
                        mx_th = th[index] - vdiff
            newpathsums = {}
            for up in vpathsums:
                inv_up = inv(up)
                # Degree bound: the step budget mx_th, capped by the length still
                # available, (inv_mu - (inv_up - inv_u)) - inv_vmu.
                newperms = elem_sym_perms(
                    up, min(mx_th, (inv_mu - (inv_up - inv_u)) - inv_vmu), th[index]
                )
                for up2, udiff in newperms:
                    if up2 not in newpathsums:
                        newpathsums[up2] = {}
                    # NOTE: this loop variable shadows the parameter ``v``; v is not used afterwards.
                    for v in vpathdicts[index]:
                        sumval = vpathsums[up].get(v, zero)
                        if sumval == 0:
                            continue
                        for v2, vdiff, s in vpathdicts[index][v]:
                            newpathsums[up2][v2] = newpathsums[up2].get(
                                v2, zero
                            ) + s * sumval * elem_sym_func(
                                th[index],
                                index + 1,
                                up,
                                up2,
                                v,
                                v2,
                                udiff,
                                vdiff,
                                var2,
                                var3,
                            )
            vpathsums = newpathsums
        # Keep only the sums whose v-side path ended at vmu.
        toget = tuple(vmu)
        ret_dict = add_perm_dict({ep: vpathsums[ep].get(toget, 0) for ep in vpathsums}, ret_dict)
    return ret_dict
320
+
321
+
322
def schubmult_down(perm_dict, v, var2=var2, var3=var3):
    """"Down" analogue of ``schubmult`` for a permutation -> coefficient dict.

    The structure mirrors ``schubmult``, with two differences:

    * successor permutations come from ``elem_sym_perms_op``, with degree
      bound ``mx_th`` only;
    * ``elem_sym_func`` receives ``(up2, up)`` in swapped order.

    Returns ``perm_dict`` itself when ``theta(inverse(v))`` is empty or starts
    with 0.  Otherwise returns a new dict accumulated with ``add_perm_dict``.
    """
    vn1 = inverse(v)
    th = theta(vn1)
    # Same early exit as schubmult.  Without the emptiness check, an empty theta
    # raised IndexError on th[0].
    if len(th) == 0 or th[0] == 0:
        return perm_dict
    mu = permtrim(uncode(th))
    vmu = permtrim(mulperm([*v], mu))
    ret_dict = {}

    # th[0] != 0, so this terminates.
    while th[-1] == 0:
        th.pop()
    thL = len(th)
    vpathdicts = compute_vpathdicts(th, vmu, True)
    for u, val in perm_dict.items():
        vpathsums = {u: {(1, 2): val}}
        for index in range(thL):
            # Largest th[index] - vdiff over all v-side moves at this step.
            mx_th = 0
            for vp in vpathdicts[index]:
                for v2, vdiff, s in vpathdicts[index][vp]:
                    if th[index] - vdiff > mx_th:
                        mx_th = th[index] - vdiff
            newpathsums = {}
            for up in vpathsums:
                newperms = elem_sym_perms_op(up, mx_th, th[index])
                for up2, udiff in newperms:
                    if up2 not in newpathsums:
                        newpathsums[up2] = {}
                    for v in vpathdicts[index]:
                        sumval = vpathsums[up].get(v, zero)
                        if sumval == 0:
                            continue
                        for v2, vdiff, s in vpathdicts[index][v]:
                            newpathsums[up2][v2] = newpathsums[up2].get(
                                v2, zero
                            ) + s * sumval * elem_sym_func(
                                th[index],
                                index + 1,
                                up2,
                                up,
                                v,
                                v2,
                                udiff,
                                vdiff,
                                var2,
                                var3,
                            )
            vpathsums = newpathsums
        toget = tuple(vmu)
        ret_dict = add_perm_dict({ep: vpathsums[ep].get(toget, 0) for ep in vpathsums}, ret_dict)
    return ret_dict
372
+
373
+
374
def poly_to_vec(poly, vec0=None):
    """Encode ``poly`` as a sparse coefficient vector over the registered monomials.

    The polynomial is first specialised at ``var3[1] = 0`` and expanded.  When
    ``vec0`` is None, the module-level monomial index is rebuilt from this
    polynomial via ``init_basevec``.  Each nonzero coefficient is stored as
    ``abs(int(coeff))``; the sign is discarded.

    Args:
        poly: expression supporting ``xreplace``.
        vec0: optional upper-bound vector indexed like the result.

    Returns:
        ``{index: abs_coeff}``, or None if some monomial is not registered or
        (with ``vec0``) some coefficient exceeds ``vec0`` at its index.
    """
    global dimen, monom_to_vec, base_vec
    coeffs = expand(poly.xreplace({var3[1]: 0})).as_coefficients_dict()

    if vec0 is None:
        init_basevec(coeffs)

    sparse = {}
    for monom in coeffs:
        raw = coeffs[monom]
        if raw == 0:
            continue
        magnitude = abs(int(raw))
        if monom not in monom_to_vec:
            return None
        slot = monom_to_vec[monom]
        if vec0 is not None and vec0[slot] < magnitude:
            return None
        sparse[slot] = magnitude
    return sparse
397
+
398
+
399
def shiftsub(pol):
    """Return ``pol`` with ``var2[i]`` replaced by ``var2[i + 1]`` for i in 0..98.

    The mapping is passed as a single dict to ``subs`` on ``sympify(pol)``.
    """
    shift_map = {var2[i]: var2[i + 1] for i in range(99)}
    return sympify(pol).subs(shift_map)
402
+
403
+
404
def shiftsubz(pol):
    """Return ``pol`` with ``var3[i]`` replaced by ``var3[i + 1]`` for i in 0..98.

    The mapping is passed as a single dict to ``subs`` on ``sympify(pol)``.
    """
    shift_map = {var3[i]: var3[i + 1] for i in range(99)}
    return sympify(pol).subs(shift_map)
407
+
408
+
409
def init_basevec(dc):
    """Rebuild the module-level monomial index from the coefficient dict ``dc``.

    Every key of ``dc`` with a nonzero value gets the next coordinate, in
    ``dc`` iteration order.  Three module globals are reset:

    * ``monom_to_vec``: monomial -> index;
    * ``dimen``: number of such monomials;
    * ``base_vec``: a zero list of length ``dimen``.
    """
    global dimen, monom_to_vec, base_vec
    kept = [monom for monom in dc if dc[monom] != 0]
    monom_to_vec = {monom: slot for slot, monom in enumerate(kept)}
    dimen = len(kept)
    base_vec = [0] * dimen
420
+
421
+
422
def split_flat_term(arg):
    """Split the expanded sum ``arg`` into lists of y-parts and z-parts.

    Each summand of ``expand(arg)`` is classified by whether its string form
    contains ``"y"``:

    * y-summands: a ``Mul`` contributes ``args[1]`` repeated
      ``int(args[0])`` times; anything else is kept as is.
    * other summands: a ``Mul`` contributes ``-args[1]`` repeated
      ``abs(int(args[0]))`` times; anything else is kept as is.

    Returns:
        ``(ys, zs)`` lists.
    """
    expanded = expand(arg)
    ys, zs = [], []
    for summand in expanded.args:
        product_form = isinstance(summand, Mul)
        if "y" in str(summand):
            if product_form:
                ys.extend([summand.args[1]] * int(summand.args[0]))
            else:
                ys.append(summand)
        elif product_form:
            zs.extend([-summand.args[1]] * abs(int(summand.args[0])))
        else:
            zs.append(summand)
    return ys, zs
440
+
441
+
442
def is_flat_term(term):
    """Return True when ``term`` is "flat".

    A term is flat when it is an integer, or when every monomial of
    ``expand(term)`` mentions at most one ``y``/``z`` character and has no
    ``**`` in its string form.
    """
    if isinstance(term, (Integer, int)):
        return True
    for monom in expand(term).as_coefficients_dict():
        text = str(monom)
        if text.count("y") + text.count("z") > 1 or "**" in text:
            return False
    return True
450
+
451
+
452
def flatten_factors(term, var2=var2, var3=var3):
    """Distribute non-flat products and powers into sums of flat pieces.

    "Flat" is defined by ``is_flat_term``.  A flat sum with more than two
    summands is split by ``split_flat_term`` into binomials ``ys[i] + zs[i]``,
    which are then distributed through the enclosing product or power.

    ``var2``/``var3`` are accepted for interface compatibility but are not
    used.  (The old default ``var2=var3`` was a typo; the value was never read.)

    Returns:
        ``(flattened_term, found_one)``.  ``found_one`` is True when some
        distribution happened, so callers can repeat until it is False.
    """
    found_one = False
    if is_flat_term(term):
        return term, False

    if isinstance(term, Pow):
        base, exponent = term.args[0], term.args[1]
        if is_flat_term(base) and len(base.args) > 2:
            ys, zs = split_flat_term(base)
            pieces = [ys[i] + zs[i] for i in range(len(ys))]
            # (p_1 + ... + p_m) ** k: multiply every partial product by every
            # piece, k times.  The old code called len() on the integer
            # exponent and returned a bare value instead of a pair.
            terms = [1]
            for _ in range(int(exponent)):
                terms = [t * piece for t in terms for piece in pieces]
            return Add(*terms), True
        if is_flat_term(base):
            return term, False
        # Recurse into the base; the recursive call returns a pair, so unpack
        # it before exponentiating.
        flat_base, _ = flatten_factors(base)
        return flat_base ** exponent, True

    if isinstance(term, Mul):
        terms = [1]
        for arg in term.args:
            terms2 = []
            if isinstance(arg, Add) and not is_flat_term(expand(arg)):
                found_one = True
                for term3 in terms:
                    for arg2 in arg.args:
                        flat, _ = flatten_factors(arg2)
                        terms2 += [term3 * flat]
            elif isinstance(arg, Add) and is_flat_term(arg) and len(arg.args) > 2:
                found_one = True
                ys, zs = split_flat_term(arg)
                for term3 in terms:
                    for i in range(len(ys)):
                        terms2 += [term3 * (ys[i] + zs[i])]
            else:
                flat, found = flatten_factors(arg)
                if found:
                    found_one = True
                for term3 in terms:
                    terms2 += [term3 * flat]
            terms = terms2
        if len(terms) == 1:
            term = terms[0]
        else:
            term = Add(*terms)
        return term, found_one

    if isinstance(term, Add):
        res = 0
        for arg in term.args:
            flat, found = flatten_factors(arg)
            if found:
                found_one = True
            res += flat
        return res, found_one

    # Unrecognised node types used to fall through and return None, breaking
    # every caller that unpacks a pair.  Return them untouched instead.
    return term, False
507
+
508
+
509
def fres(v):
    """Return the first element of ``v.free_symbols`` in iteration order.

    Returns None when ``v.free_symbols`` is empty.
    """
    return next(iter(v.free_symbols), None)
512
+
513
+
514
def split_mul(arg0, var2=var2, var3=var3):
    """Decompose a product of binomials into a sorted list of index pairs.

    Each binomial factor (two summands, one containing ``z``) becomes a pair
    ``(a, b)``, where ``a`` is the ``var2`` index of the y-part and ``b`` is the
    ``var3`` index of the z-part.  Signs are removed by string inspection.
    Indices are found by matching ``fres`` of each part against ``fres`` of the
    ``var2``/``var3`` entries.

    Handling of ``arg0``:

    * A top-level ``Pow`` is expanded and its pair repeated ``int(exponent)``
      times.
    * For a product, each factor is handled as follows:
      - integer factors are skipped;
      - a flat factor that expands to 0 stops processing;
      - flat factors add one pair;
      - ``Pow`` factors add their pair ``int(exponent)`` times.

    Returns:
        SortedList of ``(var2_index, var3_index)`` tuples.
    """
    monoms = SortedList()

    y_lookup = {fres(var2[i]): i for i in range(len(var2))}
    z_lookup = {fres(var3[i]): i for i in range(len(var3))}

    def pair_of(binomial):
        # Orient (y-part, z-part) and strip leading minus signs, then map to indices.
        yval = binomial.args[0]
        zval = binomial.args[1]
        if "z" in str(yval):
            yval, zval = zval, yval
        if "-" in str(zval):
            zval = -zval
        if "-" in str(yval):
            yval = -yval
        return (y_lookup[fres(yval)], z_lookup[fres(zval)])

    if isinstance(arg0, Pow):
        tup = pair_of(expand(arg0.args[0]))
        for _ in range(int(arg0.args[1])):
            monoms += [tup]
        return monoms

    for factor in arg0.args:
        if is_flat_term(factor):
            if isinstance(factor, (Integer, int)):
                continue
            factor = expand(factor)
            if factor == 0:
                break
            monoms += [pair_of(factor)]
        elif isinstance(factor, Pow):
            tup = pair_of(factor.args[0])
            for _ in range(int(factor.args[1])):
                monoms += [tup]
    return monoms
565
+
566
+
567
def split_monoms(pos_part, var2, var3):
    """Apply ``split_mul`` to each summand of ``pos_part``.

    * For an ``Add``, every summand is split.
    * For a single ``Mul`` or ``Pow``, that one term is split.
    * For anything else, ``[pos_part]`` (a plain list) is returned unchanged.

    Returns:
        SortedList of ``split_mul`` results, or ``[pos_part]`` as noted above.
    """
    if isinstance(pos_part, Add):
        summands = pos_part.args
    elif isinstance(pos_part, (Mul, Pow)):
        summands = [pos_part]
    else:
        return [pos_part]
    arrs = SortedList()
    for summand in summands:
        arrs += [split_mul(summand, var2, var3)]
    return arrs
579
+
580
+
581
def is_negative(term):
    """Heuristically decide whether the product term ``term`` is negative.

    * Integers compare directly against 0.
    * For a ``Mul``, the sign is the product of:
      - its ``Integer`` factors;
      - -1 for each ``Add`` factor whose string contains ``"-y"``;
      - ``(+/-1) ** exponent`` for each ``Pow`` factor, with the base sign
        chosen by the same ``"-y"`` test.
    * A bare ``Pow`` is handled like a single such factor.

    Returns:
        bool: True when the accumulated sign is negative.
    """
    if isinstance(term, (Integer, int)):
        return term < 0

    def _power_sign(power):
        # Sign of a power factor: the base's sign raised to the power's own exponent.
        base_sign = -1 if "-y" in str(power.args[0]) else 1
        return base_sign ** power.args[1]

    sign = 1
    if isinstance(term, Mul):
        for arg in term.args:
            if isinstance(arg, Integer):
                sign *= arg
            elif isinstance(arg, Add):
                if "-y" in str(arg):
                    sign *= -1
            elif isinstance(arg, Pow):
                # Bug fix: previously raised to term.args[1] (the product's second
                # factor) instead of this factor's exponent arg.args[1].
                sign *= _power_sign(arg)
    elif isinstance(term, Pow):
        sign *= _power_sign(term)
    return sign < 0
603
+
604
+
605
def find_base_vectors(monom_list, monom_list_neg, var2, var3, depth):
    """Grow a set of monomials by swapping z-indices, then expand the square-free ones.

    Each monomial is a sorted tuple of ``(var2_index, var3_index)`` pairs, as
    produced by ``split_mul``.  The growth runs for up to ``depth`` rounds,
    stopping early once a round adds nothing.

    In each round, for every ordered pair ``(mn, mn2)`` whose factor counts
    (measured over the factors of ``mn``) differ by exactly one:

    * let ``(a1, b1)`` be the factor only in ``mn`` and ``(a2, b2)`` the
      factor only in ``mn2``;
    * if ``b1 != b2``, add two new monomials:
      - ``mn`` with ``(a1, b1)`` replaced by ``(a1, b2)``;
      - ``mn2`` with ``(a2, b2)`` replaced by ``(a2, b1)``;
    * both new monomials keep sorted order.

    Args:
        monom_list: iterable of monomials (each convertible to a tuple).
        monom_list_neg: unused by this function.
        var2, var3: indexable variable sequences used to build the products.
        depth: maximum number of growth rounds.

    Returns:
        ``(ret, monom_list)``:

        * ``ret`` is the product ``prod(var2[a] - var3[b])`` of every final
          monomial with no repeated factor;
        * ``monom_list`` is the final set of monomial tuples.
    """
    size = 0
    # Cache of monomial -> {factor: multiplicity}.
    mn_fullcount = {}
    # pairs_checked = set()
    monom_list = set([tuple(mn) for mn in monom_list])
    ct = 0
    while ct < depth and size != len(monom_list):
        size = len(monom_list)
        # found = False
        # for mn in mons2:
        #     if mn not in monom_list:
        #         found = True
        #         break
        # if not found:
        #     print("Breaking")
        #     break

        monom_list2 = set(monom_list)
        additional_set2 = set()
        for mn in monom_list:
            # res = 1
            # for tp in mn:
            #     res *= var2[tp[0]] - var3[tp[1]]
            # if poly_to_vec(res,vec) is None:
            #     continue

            mncount = mn_fullcount.get(mn, {})
            if mncount == {}:
                for tp in mn:
                    mncount[tp] = mncount.get(tp, 0) + 1
                mn_fullcount[mn] = mncount
            for mn2 in monom_list:
                # if (mn,mn2) in pairs_checked:
                #     continue
                mn2count = mn_fullcount.get(mn2, {})
                if mn2count == {}:
                    for tp in mn2:
                        mn2count[tp] = mn2count.get(tp, 0) + 1
                    mn_fullcount[mn2] = mn2count
                # Multiplicity difference, measured only over factors of mn;
                # stop counting once it exceeds 1.
                num_diff = 0
                for tp in mncount:
                    pt = mn2count.get(tp, 0) - mncount[tp]
                    num_diff += abs(pt)
                    if num_diff > 1:
                        break
                if num_diff == 1:
                    diff_term1 = None
                    diff_term2 = None
                    for tp in mn2count:
                        if mn2count[tp] > mncount.get(tp, 0):
                            diff_term2 = tp
                            break
                    for tp2 in mncount:
                        if mncount[tp2] > mn2count.get(tp2, 0):
                            diff_term1 = tp2
                            break
                    # print(f"{mn,mn2}")
                    # NOTE(review): a debugging abort -- this terminates the whole
                    # process instead of raising; consider raising an exception.
                    if diff_term1 is None or diff_term2 is None:
                        print(f"{mn=} {mn2=}")
                        exit(1)
                    if diff_term2[1] == diff_term1[1]:
                        continue
                    # Swap the z-indices of the two differing factors.
                    new_term1 = (diff_term1[0], diff_term2[1])
                    new_term2 = (diff_term2[0], diff_term1[1])
                    # mn3 = [*mn]
                    # mn4 = list(mn2)
                    # Remove one occurrence of the old factor and insert the new
                    # one at its sorted position, keeping tuples sorted.
                    index = bisect_left(mn, diff_term1)
                    mn3 = list(mn[:index]) + list(mn[index + 1 :])
                    index = bisect_left(mn3, new_term1)
                    mn3_t = tuple(mn3[:index] + [new_term1] + mn3[index:])
                    index2 = bisect_left(mn2, diff_term2)
                    mn4 = list(mn2[:index2]) + list(mn2[index2 + 1 :])
                    index2 = bisect_left(mn4, new_term2)
                    mn4_t = tuple(mn4[:index2] + [new_term2] + mn4[index2:])
                    # res = 1
                    # for tp in mn3_t:
                    #     res *= var2[tp[0]] - var3[tp[1]]
                    # if poly_to_vec(res,vec) is not None:
                    if mn3_t not in monom_list2:
                        additional_set2.add(mn3_t)
                        monom_list2.add(mn3_t)
                    # res = 1
                    # for tp in mn4_t:
                    #     res *= var2[tp[0]] - var3[tp[1]]
                    ##
                    ## additional_set2.add(mn3_t)
                    # if poly_to_vec(res,vec) is not None:
                    if mn4_t not in monom_list2:
                        additional_set2.add(mn4_t)
                        monom_list2.add(mn4_t)
        monom_list = monom_list2
        ct += 1
    ret = []
    for mn in monom_list:
        # Only monomials without a repeated factor are turned into products.
        if len(mn) != len(set(mn)):
            continue
        res = 1
        for tp in mn:
            res *= var2[tp[0]] - var3[tp[1]]
        ret += [res]
    return ret, monom_list
706
+
707
+
708
def compute_positive_rep(val, var2=var2, var3=var3, msg=False, do_pos_neg=True):
    """Rewrite ``val`` as a nonnegative integer combination of candidate products.

    If ``expand(val)`` converts to an integer, it is returned directly.

    Otherwise:

    1. ``val`` is encoded with ``poly_to_vec``.  This specialises
       ``var3[1] = 0``, stores absolute coefficient values, and resets the
       module-level ``dimen``/``monom_to_vec``/``base_vec``.
    2. A family of candidate products of ``(var2[a] - var3[b])`` factors is
       generated.
    3. An integer program (PuLP + CBC) searches for nonnegative integer
       multipliers ``a_i`` whose weighted candidate vectors reproduce that
       vector.

    The result is ``sum(a_i * candidate_i)`` over the nonzero multipliers.
    Presumably the intent is to exhibit a positive formula -- TODO confirm with
    callers.

    Args:
        val: symengine expression.
        var2, var3: variable sequences.  Only the symbols of ``val`` that appear
            in them are used.
        msg: forwarded to ``pu.PULP_CBC_CMD(msg=...)``.
        do_pos_neg: how candidates are generated.
            - True: flatten ``val``, split it into positive and negative
              summands with ``is_negative``, and grow candidates from the
              positive part with ``find_base_vectors``.  If there is no
              negative part, the positive part is returned directly.
            - False: build candidates monomial-by-monomial from
              ``sympy.poly(val)``.

    Raises:
        ValueError: "Found counterexample" when ``find_base_vectors`` adds no new
            monomials before the LP succeeds.
        KeyboardInterrupt: re-raised after terminating solver child processes.
    """
    notint = False
    try:
        int(expand(val))
        val2 = expand(val)
    except Exception:
        notint = True
    if notint:
        frees = val.free_symbols
        var2list = [*var2]
        var3list = [*var3]

        # Replace each variable-list entry by its underlying free symbol.
        for i in range(len(var2list)):
            symset = var2list[i].free_symbols
            for sym in symset:
                var2list[i] = sym

        for i in range(len(var3list)):
            symset = var3list[i].free_symbols
            for sym in symset:
                var3list[i] = sym

        # Keep only the symbols of val, ordered by their position in var2/var3.
        varsimp2 = [m for m in frees if m in var2list]
        varsimp3 = [m for m in frees if m in var3list]
        varsimp2.sort(key=lambda k: var2list.index(k))
        varsimp3.sort(key=lambda k: var3list.index(k))

        var22 = [sympy.sympify(m) for m in varsimp2]
        var33 = [sympy.sympify(m) for m in varsimp3]
        n1 = len(varsimp2)

        # Map the symbols back to the corresponding var2/var3 entries.
        for i in range(len(varsimp2)):
            varsimp2[i] = var2[var2list.index(varsimp2[i])]
        for i in range(len(varsimp3)):
            varsimp3[i] = var3[var3list.index(varsimp3[i])]

        base_vectors = []
        base_monoms = []
        # Target vector; vec0=None also rebuilds the module-level monomial index.
        vec = poly_to_vec(val, None)

        if do_pos_neg:
            smp = val
            # Repeat flattening until no further distribution happens.
            flat, found_one = flatten_factors(smp)
            while found_one:
                flat, found_one = flatten_factors(flat, varsimp2, varsimp3)
            pos_part = 0
            neg_part = 0
            if isinstance(flat, Add) and not is_flat_term(flat):
                for arg in flat.args:
                    if expand(arg) == 0:
                        continue
                    if not is_negative(arg):
                        pos_part += arg
                    else:
                        neg_part -= arg
                if neg_part == 0:
                    # print("no neg")
                    return pos_part
            # NOTE(review): if flat is not an Add, pos_part stays 0 and
            # split_monoms(0, ...) returns [0], so tuple(0) below raises TypeError.
            depth = 1

            mons = split_monoms(pos_part, varsimp2, varsimp3)
            mons = set([tuple(mn) for mn in mons])
            mons2 = split_monoms(neg_part, varsimp2, varsimp3)
            mons2 = set([tuple(mn2) for mn2 in mons2])

            # mons2 = split_monoms(neg_part)
            # for mn in mons2:
            #     if mn not in mons:
            #         mons.add(mn)
            # print(mons)
            status = 0
            size = len(mons)
            # Grow the candidate set one round at a time until the LP reports status 1.
            while status != 1:
                base_monoms, mons = find_base_vectors(mons, mons2, varsimp2, varsimp3, depth)
                if len(mons) == size:
                    raise ValueError("Found counterexample")

                size = len(mons)
                base_vectors = []
                bad = False
                bad_vectors = []
                # Drop candidates with an unregistered monomial or a coefficient
                # exceeding the target's.
                for i in range(len(base_monoms)):
                    vec0 = poly_to_vec(base_monoms[i], vec)
                    if vec0 is not None:
                        base_vectors += [vec0]
                    else:
                        bad_vectors += [i]
                for j in range(len(bad_vectors) - 1, -1, -1):
                    base_monoms.pop(bad_vectors[j])

                vrs = [
                    pu.LpVariable(name=f"a{i}", lowBound=0, cat="Integer")
                    for i in range(len(base_vectors))
                ]
                lp_prob = pu.LpProblem("Problem", pu.LpMinimize)
                # Constant objective: only feasibility matters.
                lp_prob += int(0)
                eqs = [*base_vec]
                for j in range(len(base_vectors)):
                    for i in base_vectors[j]:
                        bvi = base_vectors[j][i]
                        if bvi == 1:
                            eqs[i] += vrs[j]
                        else:
                            eqs[i] += bvi * vrs[j]
                for i in range(dimen):
                    try:
                        lp_prob += eqs[i] == vec[i]
                    except TypeError:
                        # Presumably raised when no candidate touches coordinate i
                        # (eqs[i] is still a plain int), so retry with more candidates.
                        bad = True
                        break
                if bad:
                    continue
                try:
                    solver = pu.PULP_CBC_CMD(msg=msg)
                    status = lp_prob.solve(solver)
                except KeyboardInterrupt:
                    # Kill the solver subprocess(es) before propagating the interrupt.
                    current_process = psutil.Process()
                    children = current_process.children(recursive=True)
                    for child in children:
                        child_process = psutil.Process(child.pid)
                        child_process.terminate()
                        child_process.kill()
                    raise KeyboardInterrupt()
                status = lp_prob.status
        else:
            val_poly = sympy.poly(expand(val), *var22, *var33)
            vec = poly_to_vec(val)
            mn = val_poly.monoms()
            L1 = tuple([0 for i in range(n1)])
            mn1L = []
            lookup = {}
            # Group monomials by their var3-exponent part.  lookup keeps those whose
            # var2 exponents are 0/1 with at least one 1; mn1L keeps those with no
            # var2 exponent at all.
            for mm0 in mn:
                key = mm0[n1:]
                if key not in lookup:
                    lookup[key] = []
                mm0n1 = mm0[:n1]
                st = set(mm0n1)
                if len(st.intersection(set([0, 1]))) == len(st) and 1 in st:
                    lookup[key] += [mm0]
                if mm0n1 == L1:
                    mn1L += [mm0]
            for mn1 in mn1L:
                comblistmn1 = [1]
                for i in range(n1, len(mn1)):
                    if mn1[i] != 0:
                        arr = np.array(comblistmn1)
                        comblistmn12 = []
                        mn1_2 = tuple([*mn1[n1:i]] + [0] + [*mn1[i + 1 :]])
                        for mm0 in lookup[mn1_2]:
                            comblistmn12 += (
                                arr
                                * np.prod(
                                    [
                                        varsimp2[k] - varsimp3[i - n1]
                                        for k in range(n1)
                                        if mm0[k] == 1
                                    ]
                                )
                            ).tolist()
                        comblistmn1 = comblistmn12
                for i in range(len(comblistmn1)):
                    b1 = comblistmn1[i]
                    vec0 = poly_to_vec(b1, vec)
                    if vec0 is not None:
                        base_vectors += [vec0]
                        base_monoms += [b1]
            vrs = [
                pu.LpVariable(name=f"a{i}", lowBound=0, cat="Integer")
                for i in range(len(base_vectors))
            ]
            lp_prob = pu.LpProblem("Problem", pu.LpMinimize)
            lp_prob += int(0)
            eqs = [*base_vec]
            for j in range(len(base_vectors)):
                for i in base_vectors[j]:
                    bvi = base_vectors[j][i]
                    if bvi == 1:
                        eqs[i] += vrs[j]
                    else:
                        eqs[i] += bvi * vrs[j]
            # NOTE(review): unlike the do_pos_neg branch, no TypeError guard here.
            for i in range(dimen):
                lp_prob += eqs[i] == vec[i]
            try:
                solver = pu.PULP_CBC_CMD(msg=msg)
                status = lp_prob.solve(solver)
            except KeyboardInterrupt:
                current_process = psutil.Process()
                children = current_process.children(recursive=True)
                for child in children:
                    child_process = psutil.Process(child.pid)
                    child_process.terminate()
                    child_process.kill()
                raise KeyboardInterrupt()
            # print(f"{pos_part=}")
            # print(f"{neg_part=}")
            # else:
            #     print(f"No dice {flat=}")
            #     exit(1)
            # #val = pos_part - neg_part

            # depth+=1
        # Assemble the answer from the nonzero LP multipliers.
        val2 = 0
        for k in range(len(base_vectors)):
            x = vrs[k].value()
            b1 = base_monoms[k]
            if x != 0 and x is not None:
                val2 += int(x) * b1
    return val2
916
+
917
+
918
def is_split_two(u, v, w):
    """Check whether ``u^{-1} * w`` moves exactly two distinct point-pairs.

    Returns ``(False, [])`` unless ``inv(w) - inv(u) == 2``.

    For every position ``p`` (1-based) that ``d = u^{-1} * w`` does not fix,
    the set ``{p, d(p)}`` is collected, keeping only distinct sets and stopping
    once more than two are found.  The result is ``(True, pairs)`` when exactly
    two distinct sets were collected, else ``(False, [])``.  Exactly two sets
    means ``d`` consists of two disjoint transpositions.

    ``v`` is not used.
    """
    if inv(w) - inv(u) != 2:
        return False, []
    diff_perm = mulperm(inverse([*u]), [*w])
    pairs = []
    for position, image in enumerate(diff_perm, start=1):
        if image == position:
            continue
        moved = {position, image}
        if moved not in pairs:
            pairs += [moved]
        if len(pairs) > 2:
            break
    if len(pairs) != 2:
        return False, []
    return True, pairs
941
+
942
+
943
def is_coeff_irreducible(u, v, w):
    """Return True when none of this module's shortcut reductions applies.

    The checks run in order, and each returns False immediately when it
    applies:

    1. ``will_formula_work(u, v)`` or ``will_formula_work(v, u)``;
    2. ``one_dominates(u, w)`` or ``is_reducible(v)``;
    3. ``inv(w) - inv(u) <= 1``;
    4. ``is_split_two(u, v, w)`` reports True.

    If all checks pass, the result is True exactly when ``code(v)`` has more
    than one nonzero entry.
    """
    if will_formula_work(u, v) or will_formula_work(v, u):
        return False
    if one_dominates(u, w) or is_reducible(v):
        return False
    if inv(w) - inv(u) <= 1:
        return False
    if is_split_two(u, v, w)[0]:
        return False
    return sum(1 for entry in code(v) if entry != 0) > 1
953
+
954
+
955
def is_hook(cd):
    """Return True when the sequence ``cd`` has the "hook" shape checked here.

    Reading left to right:

    * leading zeros are allowed;
    * then a run of 1s may follow;
    * the first entry greater than 1, or any zero after the run of 1s, must be
      followed only by zeros.

    At least one nonzero entry must be present.  An empty or all-zero ``cd``
    gives False.
    """
    started = False      # saw a 1
    done = False         # saw an entry > 1
    closed = False       # saw a 0 after a 1
    for entry in cd:
        if entry != 0 and (done or closed):
            return False
        if entry == 1 and not started:
            started = True
        if entry > 1:
            done = True
        if entry == 0 and started:
            closed = True
    return started or done
971
+
972
+
973
def div_diff(i, poly):
    """Divided difference of ``poly`` in ``var2[i]`` and ``var2[i + 1]``.

    Computes the sympy polynomial quotient of ``poly - permy(poly, i)`` by
    ``var2[i] - var2[i + 1]`` (the remainder is discarded) and converts it back
    with symengine's ``sympify``.

    NOTE(review): ``permy`` is not imported in the visible header; it is
    presumably defined elsewhere in this file -- confirm.
    """
    numerator = sympy.sympify(poly - permy(poly, i))
    denominator = sympy.sympify(var2[i] - var2[i + 1])
    quotient, _remainder = sympy.div(numerator, denominator)
    return sympify(quotient)
977
+
978
+
979
def skew_div_diff(u, w, poly):
    """Apply the operators dictated by the descents of ``w`` (relative to ``u``) to ``poly``.

    Repeatedly take the first descent ``d`` of ``w`` (0-based ``w[d] > w[d+1]``)
    and swap it away.  If ``u`` also descends at ``d``:

    * swap ``u`` there too;
    * apply ``permy(poly, d + 1)``.

    Otherwise apply ``div_diff(d + 1, poly)``.

    When ``w`` has no descents left, the result is ``poly`` if ``u`` has no
    descents either, and 0 otherwise.  The caller's ``u`` and ``w`` are not
    modified.
    """
    u = [*u]
    w = [*w]
    while True:
        d = next((i for i in range(len(w) - 1) if w[i] > w[i + 1]), -1)
        if d == -1:
            u_descends = any(u[i] > u[i + 1] for i in range(len(u) - 1))
            return 0 if u_descends else poly
        w[d], w[d + 1] = w[d + 1], w[d]
        if d < len(u) - 1 and u[d] > u[d + 1]:
            u[d], u[d + 1] = u[d + 1], u[d]
            poly = permy(poly, d + 1)
        else:
            poly = div_diff(d + 1, poly)
1002
+
1003
+
1004
+ @cached(
1005
+ cache={},
1006
+ key=lambda val,
1007
+ u2,
1008
+ v2,
1009
+ w2,
1010
+ var2=var2,
1011
+ var3=var3,
1012
+ msg=False,
1013
+ do_pos_neg=True,
1014
+ sign_only=False: hashkey(u2, v2, w2, var2, var3, msg, do_pos_neg, sign_only),
1015
+ )
1016
+ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, sign_only=False):
# Rewrite the coefficient `val` attached to the triple (u2, v2, w2), or, with
# sign_only=True, report a sign code instead.  The recursive calls below pass
# as `val` the entry .get(w, 0) of schubmult_one(u, v, var2, var3), i.e. the
# coefficient of w in that product.  The target form is presumably manifestly
# positive (per the name and the compute_positive_rep fallback) -- not
# verified here.
#
# Parameters:
#   val        -- coefficient expression; passed to expand() (sympy-style).
#   u2, v2, w2 -- permutations (sequences of ints) describing the coefficient.
#   var2, var3 -- integer-indexed variable sequences; default to the
#                 module-level var2/var3 bound when this def runs.
#   msg        -- forwarded to compute_positive_rep and to recursive calls.
#   do_pos_neg -- forwarded likewise; also enables the {0, 1}-code early exit.
#   sign_only  -- if true, most formula branches return 0 immediately; the
#                 final fallback returns -1 if expand(val) has a negative
#                 coefficient and 1 otherwise.
# Returns: val unchanged, 0, a rewritten expression, or -1/0/1 (sign_only).
#
# NOTE(review): this copy of the source lost all indentation; any nesting
# implied by the comments below is inferred from statement order -- confirm
# against the original file before relying on it.
#
# Degree check: when inv(u2) + inv(v2) == inv(w2), val is returned as-is.
1017
+ if inv(u2) + inv(v2) - inv(w2) == 0:
1018
+ return val
# When code(v2) takes exactly the values {0, 1} and do_pos_neg is set, val
# is returned unchanged.
1019
+ cdv = code(v2)
1020
+ if set(cdv) == set([0, 1]) and do_pos_neg:
1021
+ return val
1022
+ # if is_hook(cdv):
1023
+ # print(f"Could've {cdv}")
# A coefficient that expands to zero short-circuits to 0 (this check is
# skipped in sign_only mode).
1024
+ if not sign_only and expand(val) == 0:
1025
+ return 0
1026
+
# Reduction cascade: try_reduce_v, then try_reduce_u, then
# reduce_descents / reduce_coeff (repeated while permtrim(w) keeps changing).
# Each step is attempted only while is_coeff_irreducible(u, v, w) holds.
# Nesting inferred (see note at the top).
1027
+ u, v, w = try_reduce_v(u2, v2, w2)
1028
+ if is_coeff_irreducible(u, v, w):
1029
+ u, v, w = try_reduce_u(u2, v2, w2)
1030
+ if is_coeff_irreducible(u, v, w):
1031
+ u, v, w = [*u2], [*v2], [*w2]
1032
+ if is_coeff_irreducible(u, v, w):
1033
+ w0 = [*w]
1034
+ u, v, w = reduce_descents(u, v, w)
1035
+ if is_coeff_irreducible(u, v, w):
1036
+ u, v, w = reduce_coeff(u, v, w)
1037
+ if is_coeff_irreducible(u, v, w):
1038
+ while is_coeff_irreducible(u, v, w) and tuple(permtrim(w0)) != tuple(
1039
+ permtrim([*w])
1040
+ ):
1041
+ w0 = w
1042
+ u, v, w = reduce_descents(u, v, w)
1043
+ if is_coeff_irreducible(u, v, w):
1044
+ u, v, w = reduce_coeff(u, v, w)
# Normalize to tuples; the w != w2 test below is a tuple comparison.
1045
+ u = tuple(u)
1046
+ v = tuple(v)
1047
+ w = tuple(w)
1048
+
# In sign_only mode, a w that the reductions changed short-circuits to 0.
# NOTE(review): w is a tuple here, so a w2 passed as a list would always
# compare unequal; the recursive calls in this function pass tuples.
1049
+ if w != w2 and sign_only:
1050
+ return 0
1051
+
# If the triple is still irreducible, try one more try_reduce_v, then
# try_reduce_u, keeping a result only if it is no longer irreducible.
# Then query is_split_two; its (flag, pairs) result drives the
# split_two_b branch further below.
1052
+ if is_coeff_irreducible(u, v, w):
1053
+ u3, v3, w3 = try_reduce_v(u, v, w)
1054
+ if not is_coeff_irreducible(u3, v3, w3):
1055
+ u, v, w = u3, v3, w3
1056
+ else:
1057
+ u3, v3, w3 = try_reduce_u(u, v, w)
1058
+ if not is_coeff_irreducible(u3, v3, w3):
1059
+ u, v, w = u3, v3, w3
1060
+ split_two_b, split_two = is_split_two(u, v, w)
1061
+
# Case: code(v) has exactly one nonzero entry, value p at 1-based position
# k.  The coefficient is written directly with elem_sym_poly, using
# r = inv(w) - inv(u) and the var2 indices collected in hvarset.
1062
+ if len([i for i in code(v) if i != 0]) == 1:
1063
+ if sign_only:
1064
+ return 0
1065
+ cv = code(v)
1066
+ for i in range(len(cv)):
1067
+ if cv[i] != 0:
1068
+ k = i + 1
1069
+ p = cv[i]
1070
+ break
1071
+ inv_u = inv(u)
1072
+ r = inv(w) - inv_u
1073
+ val = 0
1074
+ w2 = w
1075
+ hvarset = (
1076
+ [w2[i] for i in range(min(len(w2), k))]
1077
+ + [i + 1 for i in range(len(w2), k)]
1078
+ + [w2[b] for b in range(k, len(u)) if u[b] != w2[b]]
1079
+ + [w2[b] for b in range(len(u), len(w2))]
1080
+ )
# NOTE(review): `n` is not assigned anywhere in this function; presumably
# a module-level name -- confirm it is defined and large enough here.
1081
+ val = elem_sym_poly(
1082
+ p - r,
1083
+ k + p - 1,
1084
+ [-var3[i] for i in range(1, n)],
1085
+ [-var2[i] for i in hvarset],
1086
+ )
# Case: will_formula_work(v, u) or dominates(u, w) -> dualcoeff.
1087
+ elif will_formula_work(v, u) or dominates(u, w):
1088
+ if sign_only:
1089
+ return 0
1090
+ val = dualcoeff(u, v, w, var2, var3)
# Case: inv(w) - inv(u) == 1.  a and b (0-based, a found first) mark where
# u and w differ; b may also be a position at or past len(u) where
# w[i] != i + 1.  Variables are pulled out of v with pull_out_var at every
# 0-based position i < d (d = 1-based position of v's last descent) except
# a and b.  Terms whose remaining permutation has v3[0] < v3[1] are dropped.
# The rest swap those two entries and multiply the collected
# (var2[...] - var3[...]) factors by
# schubpoly(v3, [0, var2[w[a]], var2[w[b]]], var3).
1091
+ elif inv(w) - inv(u) == 1:
1092
+ if sign_only:
1093
+ return 0
1094
+ a, b = -1, -1
1095
+ for i in range(len(w)):
1096
+ if a == -1 and u[i] != w[i]:
1097
+ a = i
1098
+ elif i >= len(u) and w[i] != i + 1:
1099
+ b = i
1100
+ elif b == -1 and u[i] != w[i]:
1101
+ b = i
1102
+ arr = [[[], v]]
1103
+ d = -1
1104
+ for i in range(len(v) - 1):
1105
+ if v[i] > v[i + 1]:
1106
+ d = i + 1
1107
+ for i in range(d):
1108
+ arr2 = []
1109
+ if i in [a, b]:
1110
+ continue
1111
+ i2 = 1
1112
+ if i > b:
1113
+ i2 += 2
1114
+ elif i > a:
1115
+ i2 += 1
1116
+ for vr, v2 in arr:
1117
+ dpret = pull_out_var(i2, [*v2])
1118
+ for v3r, v3 in dpret:
1119
+ arr2 += [[vr + [v3r], v3]]
1120
+ arr = arr2
1121
+ val = 0
1122
+ for L in arr:
1123
+ v3 = [*L[-1]]
1124
+ if v3[0] < v3[1]:
1125
+ continue
1126
+ else:
1127
+ v3[0], v3[1] = v3[1], v3[0]
1128
+ toadd = 1
1129
+ for i in range(d):
1130
+ if i in [a, b]:
1131
+ continue
1132
+ i2 = i
1133
+ if i > b:
1134
+ i2 = i - 2
1135
+ elif i > a:
1136
+ i2 = i - 1
1137
+ oaf = L[0][i2]
1138
+ if i >= len(w):
1139
+ yv = i + 1
1140
+ else:
1141
+ yv = w[i]
1142
+ for j in range(len(oaf)):
1143
+ toadd *= var2[yv] - var3[oaf[j]]
1144
+ toadd *= schubpoly(v3, [0, var2[w[a]], var2[w[b]]], var3)
1145
+ val += toadd
# Case: is_split_two reported a split.  Its two pairs (1-based) are
# converted to 0-based; real_* are their ranks among the four sorted
# positions in spo.
1146
+ elif split_two_b:
1147
+ if sign_only:
1148
+ return 0
1149
+ cycles = split_two
1150
+ a1, b1 = cycles[0]
1151
+ a2, b2 = cycles[1]
1152
+ a1 -= 1
1153
+ b1 -= 1
1154
+ a2 -= 1
1155
+ b2 -= 1
1156
+ spo = sorted([a1, b1, a2, b2])
1157
+ real_a1 = min(spo.index(a1), spo.index(b1))
1158
+ real_a2 = min(spo.index(a2), spo.index(b2))
1159
+ real_b1 = max(spo.index(a1), spo.index(b1))
1160
+ real_b2 = max(spo.index(a2), spo.index(b2))
1161
+
# good1/good2: whether the two ranks of a pair are adjacent.  When exactly
# one pair is good, (a, b) is set to the other pair and later feeds varo.
1162
+ good1 = False
1163
+ good2 = False
1164
+ if real_b1 - real_a1 == 1:
1165
+ good1 = True
1166
+ if real_b2 - real_a2 == 1:
1167
+ good2 = True
1168
+ a, b = -1, -1
1169
+ if good1 and not good2:
1170
+ a, b = min(a2, b2), max(a2, b2)
1171
+ if good2 and not good1:
1172
+ a, b = min(a1, b1), max(a1, b1)
1173
+ arr = [[[], v]]
1174
+ d = -1
1175
+ for i in range(len(v) - 1):
1176
+ if v[i] > v[i + 1]:
1177
+ d = i + 1
1178
+ for i in range(d):
1179
+ arr2 = []
1180
+
1181
+ if i in [a1, b1, a2, b2]:
1182
+ continue
1183
+ i2 = 1
1184
+ i2 += len([aa for aa in [a1, b1, a2, b2] if i > aa])
1185
+ for vr, v2 in arr:
1186
+ dpret = pull_out_var(i2, [*v2])
1187
+ for v3r, v3 in dpret:
1188
+ arr2 += [[vr + [(v3r, i + 1)], v3]]
1189
+ arr = arr2
1190
+ val = 0
1191
+
1192
+ if good1:
1193
+ arr2 = []
1194
+ for L in arr:
1195
+ v3 = [*L[-1]]
1196
+ if v3[real_a1] < v3[real_b1]:
1197
+ continue
1198
+ else:
1199
+ v3[real_a1], v3[real_b1] = v3[real_b1], v3[real_a1]
1200
+ arr2 += [[L[0], v3]]
1201
+ arr = arr2
1202
+ if not good2:
1203
+ for i in range(4):
1204
+ arr2 = []
1205
+
1206
+ if i in [real_a2, real_b2]:
1207
+ continue
1208
+ if i == real_a1:
1209
+ var_index = min(a1, b1) + 1
1210
+ elif i == real_b1:
1211
+ var_index = max(a1, b1) + 1
1212
+ i2 = 1
1213
+ i2 += len([aa for aa in [real_a2, real_b2] if i > aa])
1214
+ for vr, v2 in arr:
1215
+ dpret = pull_out_var(i2, [*v2])
1216
+ for v3r, v3 in dpret:
1217
+ arr2 += [[vr + [(v3r, var_index)], v3]]
1218
+ arr = arr2
1219
+ if good2:
1220
+ arr2 = []
1221
+ for L in arr:
1222
+ v3 = [*L[-1]]
1223
+ try:
1224
+ if v3[real_a2] < v3[real_b2]:
1225
+ continue
1226
+ else:
1227
+ v3[real_a2], v3[real_b2] = v3[real_b2], v3[real_a2]
1228
+ except IndexError:
1229
+ continue
1230
+ arr2 += [[L[0], v3]]
1231
+ arr = arr2
1232
+ if not good1:
1233
+ for i in range(4):
1234
+ arr2 = []
1235
+
1236
+ if i in [real_a1, real_b1]:
1237
+ continue
1238
+ i2 = 1
1239
+ i2 += len([aa for aa in [real_a1, real_b1] if i > aa])
1240
+ if i == real_a2:
1241
+ var_index = min(a2, b2) + 1
1242
+ elif i == real_b2:
1243
+ var_index = max(a2, b2) + 1
1244
+ for vr, v2 in arr:
1245
+ dpret = pull_out_var(i2, [*v2])
1246
+ for v3r, v3 in dpret:
1247
+ arr2 += [[vr + [(v3r, var_index)], v3]]
1248
+ arr = arr2
1249
+
1250
+ for L in arr:
1251
+ v3 = [*L[-1]]
1252
+ tomul = 1
1253
+ doschubpoly = True
1254
+ if (not good1 or not good2) and v3[0] < v3[1] and (good1 or good2):
1255
+ continue
1256
+ elif (good1 or good2) and (not good1 or not good2):
1257
+ v3[0], v3[1] = v3[1], v3[0]
1258
+ elif not good1 and not good2:
1259
+ doschubpoly = False
1260
+ if v3[0] < v3[1]:
1261
+ dual_u = uncode([2, 0])
1262
+ dual_w = [4, 2, 1, 3]
1263
+ coeff = permy(dualcoeff(dual_u, v3, dual_w, var2, var3), 2)
1264
+
1265
+ elif len(v3) < 3 or v3[1] < v3[2]:
1266
+ if len(v3) <= 3 or v3[2] < v3[3]:
# NOTE(review): this `coeff = 0` is never read -- the `continue` that
# follows skips the term.
1267
+ coeff = 0
1268
+ continue
1269
+ else:
1270
+ v3[0], v3[1] = v3[1], v3[0]
1271
+ v3[2], v3[3] = v3[3], v3[2]
1272
+ coeff = permy(schubpoly(v3, var2, var3), 2)
1273
+ elif len(v3) <= 3 or v3[2] < v3[3]:
1274
+ if len(v3) <= 3:
1275
+ v3 += [4]
1276
+ v3[2], v3[3] = v3[3], v3[2]
1277
+ coeff = permy(
1278
+ posify(
1279
+ schubmult_one((1, 3, 2), tuple(permtrim([*v3])), var2, var3).get(
1280
+ (2, 4, 3, 1), 0
1281
+ ),
1282
+ (1, 3, 2),
1283
+ tuple(permtrim([*v3])),
1284
+ (2, 4, 3, 1),
1285
+ var2,
1286
+ var3,
1287
+ msg,
1288
+ do_pos_neg,
1289
+ ),
1290
+ 2,
1291
+ )
1292
+ else:
1293
+ coeff = permy(
1294
+ schubmult_one((1, 3, 2), tuple(permtrim([*v3])), var2, var3).get(
1295
+ (2, 4, 1, 3), 0
1296
+ ),
1297
+ 2,
1298
+ )
1299
+ tomul = sympify(coeff)
1300
+ toadd = 1
1301
+ for i in range(len(L[0])):
1302
+ var_index = L[0][i][1]
1303
+ oaf = L[0][i][0]
1304
+ if var_index - 1 >= len(w):
1305
+ yv = var_index
1306
+ else:
1307
+ yv = w[var_index - 1]
1308
+ for j in range(len(oaf)):
1309
+ toadd *= var2[yv] - var3[oaf[j]]
1310
+ if (not good1 or not good2) and (good1 or good2):
1311
+ varo = [0, var2[w[a]], var2[w[b]]]
1312
+ else:
1313
+ varo = [0, *[var2[w[spo[k]]] for k in range(4)]]
1314
+ if doschubpoly:
1315
+ toadd *= schubpoly(v3, varo, var3)
1316
+ else:
# NOTE(review): sympy's subs() with a dict substitutes sequentially unless
# simultaneous=True, so if a varo entry is itself a var2[j] replaced later
# in subs_dict3 the replacements can chain -- confirm this is intended.
1317
+ subs_dict3 = {var2[i]: varo[i] for i in range(len(varo))}
1318
+ toadd *= tomul.subs(subs_dict3)
1319
+ val += toadd
# Case: will_formula_work(u, v) -> forwardcoeff.
1320
+ elif will_formula_work(u, v):
1321
+ if sign_only:
1322
+ return 0
1323
+ val = forwardcoeff(u, v, w, var2, var3)
# The commented-out block below is a disabled branch for
# inv(w) - inv(u) == 2, kept as-is.
1324
+ # elif inv(w) - inv(u) == 2:
1325
+ # indices = []
1326
+ # for i in range(len(w)):
1327
+ # if i>=len(u) or u[i]!=w[i]:
1328
+ # indices += [i+1]
1329
+ # arr = [[[],v]]
1330
+ # d = -1
1331
+ # for i in range(len(v)-1):
1332
+ # if v[i]>v[i+1]:
1333
+ # d = i + 1
1334
+ # for i in range(d):
1335
+ # arr2 = []
1336
+ #
1337
+ # if i+1 in indices:
1338
+ # continue
1339
+ # i2 = 1
1340
+ # i2 += len([aa for aa in indices if i+1>aa])
1341
+ # for vr, v2 in arr:
1342
+ # dpret = pull_out_var(i2,[*v2])
1343
+ # for v3r, v3 in dpret:
1344
+ # arr2 += [[vr + [(v3r,i+1)],v3]]
1345
+ # arr = arr2
1346
+ # val = 0
1347
+ #
1348
+ # for L in arr:
1349
+ # v3 = [*L[-1]]
1350
+ # tomul = 1
1351
+ # pooly = skew_div_diff(u,w,schubpoly(v3,[0,*[var2[a] for a in indices]],var3))
1352
+ # coeff = compute_positive_rep(pooly,var2,var3,msg,False)
1353
+ # if coeff == -1:
1354
+ # return -1
1355
+ # tomul = sympify(coeff)
1356
+ # toadd = 1
1357
+ # for i in range(len(L[0])):
1358
+ # var_index = L[0][i][1]
1359
+ # oaf = L[0][i][0]
1360
+ # if var_index-1>=len(w):
1361
+ # yv = var_index
1362
+ # else:
1363
+ # yv = w[var_index-1]
1364
+ # for j in range(len(oaf)):
1365
+ # toadd*= var2[yv] - var3[oaf[j]]
1366
+ # toadd*=tomul#.subs(subs_dict3)
1367
+ # val += toadd
# General case: compare code(u), code(w), code(v) and the codes of
# inverse(u), inverse(w) to choose a recursive reduction.
1368
+ else:
1369
+ c01 = code(u)
1370
+ c02 = code(w)
1371
+ c03 = code(v)
1372
+
1373
+ c1 = code(inverse(u))
1374
+ c2 = code(inverse(w))
1375
+
# one_dominates(u, w): swap positions c2[0]-1 and c2[0] in both w and v
# until code(inverse(w))[0] equals code(inverse(u))[0], then refresh the
# cached codes.
1376
+ if one_dominates(u, w):
1377
+ if sign_only:
1378
+ return 0
1379
+ while c1[0] != c2[0]:
1380
+ w = [*w]
1381
+ v = [*v]
1382
+ w[c2[0] - 1], w[c2[0]] = w[c2[0]], w[c2[0] - 1]
1383
+ v[c2[0] - 1], v[c2[0]] = v[c2[0]], v[c2[0] - 1]
1384
+ w = tuple(w)
1385
+ v = tuple(v)
1386
+ c2 = code(inverse(w))
1387
+ c03 = code(v)
1388
+ c01 = code(u)
1389
+ c02 = code(w)
1390
+
# is_reducible(v): split the leading nonzero run of code(v) into a
# one-per-entry part (elemc) and the remainder (newc).  Multiply u by
# uncode(elemc) with schubmult_one, then recurse on each resulting term
# and combine the results through shiftsubz.
1391
+ if is_reducible(v):
1392
+ if sign_only:
1393
+ return 0
1394
+ newc = []
1395
+ elemc = []
1396
+ for i in range(len(c03)):
1397
+ if c03[i] > 0:
1398
+ newc += [c03[i] - 1]
1399
+ elemc += [1]
1400
+ else:
1401
+ break
1402
+ v3 = uncode(newc)
1403
+ coeff_dict = schubmult_one(
1404
+ tuple(permtrim([*u])), tuple(permtrim(uncode(elemc))), var2, var3
1405
+ )
1406
+ val = 0
1407
+ for new_w in coeff_dict:
1408
+ tomul = coeff_dict[new_w]
1409
+ newval = schubmult_one(new_w, tuple(permtrim(uncode(newc))), var2, var3).get(
1410
+ tuple(permtrim([*w])), 0
1411
+ )
1412
+ newval = posify(
1413
+ newval,
1414
+ new_w,
1415
+ tuple(permtrim(uncode(newc))),
1416
+ w,
1417
+ var2,
1418
+ var3,
1419
+ msg,
1420
+ do_pos_neg,
1421
+ )
1422
+ val += tomul * shiftsubz(newval)
# code(u)[0] == code(w)[0] != 0: drop that first code entry from u and w,
# recurse, then apply permy(val, i) for i = 1..varl.
# NOTE(review): this relies on permy swapping var2[i] and var2[i + 1]
# simultaneously; a plain sympy subs(dict) applies the two replacements in
# sequence and collapses the variables -- check permy.
1423
+ elif c01[0] == c02[0] and c01[0] != 0:
1424
+ if sign_only:
1425
+ return 0
1426
+ varl = c01[0]
1427
+ u3 = uncode([0] + c01[1:])
1428
+ w3 = uncode([0] + c02[1:])
1429
+ val = 0
1430
+ val = schubmult_one(tuple(permtrim(u3)), tuple(permtrim([*v])), var2, var3).get(
1431
+ tuple(permtrim(w3)), 0
1432
+ )
1433
+ val = posify(
1434
+ val,
1435
+ tuple(permtrim(u3)),
1436
+ tuple(permtrim([*v])),
1437
+ tuple(permtrim(w3)),
1438
+ var2,
1439
+ var3,
1440
+ msg,
1441
+ do_pos_neg,
1442
+ )
1443
+ for i in range(varl):
1444
+ val = permy(val, i + 1)
# code(inverse(u))[0] == code(inverse(w))[0]: pull the variable at
# position c1[0] + 1 out of v and apply phi1 to u and w.  Recurse on each
# piece and add shiftsub(result) times the (var2[1] - var3[...]) factors.
1445
+ elif c1[0] == c2[0]:
1446
+ if sign_only:
1447
+ return 0
1448
+ vp = pull_out_var(c1[0] + 1, [*v])
1449
+ u3 = tuple(permtrim(phi1(u)))
1450
+ w3 = tuple(permtrim(phi1(w)))
1451
+ val = 0
1452
+ for arr, v3 in vp:
1453
+ tomul = 1
1454
+ for i in range(len(arr)):
1455
+ tomul *= var2[1] - var3[arr[i]]
1456
+
1457
+ val2 = schubmult_one(tuple(permtrim(u3)), tuple(permtrim(v3)), var2, var3).get(
1458
+ tuple(permtrim(w3)), 0
1459
+ )
1460
+ val2 = posify(val2, u3, tuple(permtrim(v3)), w3, var2, var3, msg, do_pos_neg)
1461
+ val += tomul * shiftsub(val2)
1462
+ # elif inv(w)-inv(u)==2 and len(trimcode(u)) == len(trimcode(w)):
1463
+ # indices = []
1464
+ # for i in range(len(w)):
1465
+ # if i>=len(u) or u[i]!=w[i]:
1466
+ # indices += [i+1]
1467
+ # arr = [[[],v]]
1468
+ # d = -1
1469
+ # for i in range(len(v)-1):
1470
+ # if v[i]>v[i+1]:
1471
+ # d = i + 1
1472
+ # for i in range(d):
1473
+ # arr2 = []
1474
+ #
1475
+ # if i+1 in indices:
1476
+ # continue
1477
+ # i2 = 1
1478
+ # i2 += len([aa for aa in indices if i+1>aa])
1479
+ # for vr, v2 in arr:
1480
+ # dpret = pull_out_var(i2,[*v2])
1481
+ # for v3r, v3 in dpret:
1482
+ # arr2 += [[vr + [(v3r,i+1)],v3]]
1483
+ # arr = arr2
1484
+ # val = 0
1485
+ #
1486
+ # for L in arr:
1487
+ # v3 = [*L[-1]]
1488
+ # tomul = 1
1489
+ # toadd = 1
1490
+ # for i in range(len(L[0])):
1491
+ # var_index = L[0][i][1]
1492
+ # oaf = L[0][i][0]
1493
+ # if var_index-1>=len(w):
1494
+ # yv = var_index
1495
+ # else:
1496
+ # yv = w[var_index-1]
1497
+ # for j in range(len(oaf)):
1498
+ # toadd*= var2[yv] - var3[oaf[j]]
1499
+ # pooly = skew_div_diff(u,w,schubpoly(v3,[0,*[var2[a] for a in indices]],var3))
1500
+ # if toadd == 0:
1501
+ # continue
1502
+ # if pooly !=0:
1503
+ # coeff = compute_positive_rep(pooly,var2,var3,msg,False)
1504
+ # else:
1505
+ # coeff = 0
1506
+ # if coeff == -1:
1507
+ # return -1
1508
+ # tomul = sympify(coeff)
1509
+ # toadd*=tomul#.subs(subs_dict3)
1510
+ # val += toadd
# Fallback.  Normally: compute_positive_rep, with its last argument forced
# to False when inv(u) + inv(v) - inv(w) == 1; val is kept when it returns
# None.  In sign_only mode (assuming the inner `else:` pairs with
# `if not sign_only:`): -1 if expand(val) has a negative coefficient,
# else 1.
1511
+ else:
1512
+ if not sign_only:
1513
+ if inv(u) + inv(v) - inv(w) == 1:
1514
+ val2 = compute_positive_rep(val, var2, var3, msg, False)
1515
+ else:
1516
+ val2 = compute_positive_rep(val, var2, var3, msg, do_pos_neg)
1517
+ if val2 is not None:
1518
+ val = val2
1519
+ else:
1520
+ # st = str(expand(val))
1521
+ # if st.find("-")!=-1:
1522
+ # return -1
1523
+ # else:
1524
+ # return val
# NOTE(review): the loop variable below shadows the permutation v.
1525
+ d = expand(val).as_coefficients_dict()
1526
+ for v in d.values():
1527
+ if v < 0:
1528
+ return -1
1529
+ return 1
1530
+ return val
1531
+
1532
+
1533
def split_perms(perms):
    """Split later permutations of ``perms`` at a long-enough zero gap in their code.

    The first entry of ``perms`` is kept as the first entry of the result;
    an empty ``perms`` raises ``IndexError``.  For every later permutation
    ``perm``, ``code(perm)`` is scanned left to right.  At each zero that
    directly follows a run of nonzero entries (position ``idx``), two
    quantities are compared:

    * ``needed`` -- the largest ``cd[j] - (idx - 1 - j)`` over nonzero
      ``cd[j]`` with ``j < idx`` (never below 0);
    * ``available`` -- the length of the run of zeros starting at ``idx``.

    At the first position with ``available >= needed`` the code is cut into
    ``cd[:idx]`` and ``[0] * idx + cd[idx:]``.  Each part is converted back
    with ``uncode`` and ``permtrim`` (as a tuple), and both are appended in
    that order.  If no position qualifies, ``perm`` itself is appended.

    Returns a new list; ``perms`` is only read.
    """
    result = [perms[0]]
    for perm in perms[1:]:
        cd = code(perm)
        after_nonzero = False
        for idx, entry in enumerate(cd):
            if entry != 0:
                after_nonzero = True
                continue
            if not after_nonzero:
                continue
            # idx is a zero immediately following a run of nonzero entries.
            after_nonzero = False
            needed = 0
            for j in range(idx):
                if cd[j] != 0:
                    needed = max(needed, cd[j] - (idx - 1 - j))
            available = 0
            for later in cd[idx:]:
                if later != 0:
                    break
                available += 1
            if available >= needed:
                head = cd[:idx]
                tail = [0] * idx + cd[idx:]
                result.extend(
                    (
                        tuple(permtrim(uncode(head))),
                        tuple(permtrim(uncode(tail))),
                    )
                )
                break
        else:
            # No cut point qualified: keep the permutation whole.
            result.append(perm)
    return result
1568
+
1569
+
1570
def schubpoly(v, var2=var2, var3=var3, start_var=1):
    """Build the polynomial indexed by the permutation ``v`` recursively.

    Let ``n`` be the 1-based position of the last descent of ``v``, i.e. the
    largest ``n`` with ``v[n - 1] > v[n]``.  If ``v`` has no descent the
    result is ``1``.  Otherwise, for every pair ``(pw, vp)`` produced by
    ``pull_out_var(n, v)``, the product of ``var2[start_var + n - 1] - var3[p]``
    over ``p`` in ``pw`` is multiplied by ``schubpoly(vp, ...)`` and added to
    the result.  Per the name this is presumably the double Schubert
    polynomial of ``v``; that is not verified here.

    ``var2`` and ``var3`` are integer-indexed variable sequences whose
    defaults are the module-level objects bound when this ``def`` ran.
    ``start_var`` shifts the ``var2`` index and is passed unchanged to the
    recursive calls.
    """
    # Scan right-to-left for the last descent; 0 means v has none.
    last_descent = next(
        (pos + 1 for pos in reversed(range(len(v) - 1)) if v[pos] > v[pos + 1]),
        0,
    )
    if last_descent == 0:
        return 1
    total = 0
    for pulled, smaller in pull_out_var(last_descent, v):
        factor = 1
        for p in pulled:
            factor *= var2[start_var + last_descent - 1] - var3[p]
        total += factor * schubpoly(smaller, var2, var3, start_var)
    return total
1586
+
1587
+
1588
def permy(val, i):
    """Return ``val`` with the variables ``var2[i]`` and ``var2[i + 1]`` swapped.

    ``val`` goes through ``sympify`` first, so plain numbers such as ``0``
    are accepted.  ``var2`` is the module-level variable sequence; it is not
    a parameter of this function.

    The substitution uses ``simultaneous=True``.  With a dict argument,
    sympy's ``subs`` applies the replacements one after another by default.
    The second replacement would then also rewrite the output of the first,
    collapsing both variables into a single one instead of exchanging them.
    """
    subsdict = {var2[i]: var2[i + 1], var2[i + 1]: var2[i]}
    return sympify(val).subs(subsdict, simultaneous=True)