schubmult 2.0.2__py3-none-any.whl → 2.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. schubmult/__init__.py +1 -1
  2. schubmult/_base_argparse.py +42 -8
  3. schubmult/_tests.py +24 -0
  4. schubmult/perm_lib.py +52 -112
  5. schubmult/sage_integration/__init__.py +13 -13
  6. schubmult/sage_integration/_fast_double_schubert_polynomial_ring.py +139 -118
  7. schubmult/sage_integration/_fast_schubert_polynomial_ring.py +88 -49
  8. schubmult/sage_integration/_indexing.py +35 -32
  9. schubmult/schubmult_double/__init__.py +6 -12
  10. schubmult/schubmult_double/__main__.py +2 -1
  11. schubmult/schubmult_double/_funcs.py +245 -281
  12. schubmult/schubmult_double/_script.py +128 -70
  13. schubmult/schubmult_py/__init__.py +5 -3
  14. schubmult/schubmult_py/__main__.py +2 -1
  15. schubmult/schubmult_py/_funcs.py +68 -23
  16. schubmult/schubmult_py/_script.py +40 -58
  17. schubmult/schubmult_q/__init__.py +3 -7
  18. schubmult/schubmult_q/__main__.py +2 -1
  19. schubmult/schubmult_q/_funcs.py +41 -60
  20. schubmult/schubmult_q/_script.py +39 -30
  21. schubmult/schubmult_q_double/__init__.py +5 -11
  22. schubmult/schubmult_q_double/__main__.py +2 -1
  23. schubmult/schubmult_q_double/_funcs.py +99 -66
  24. schubmult/schubmult_q_double/_script.py +209 -150
  25. schubmult-2.0.4.dist-info/METADATA +542 -0
  26. schubmult-2.0.4.dist-info/RECORD +30 -0
  27. {schubmult-2.0.2.dist-info → schubmult-2.0.4.dist-info}/WHEEL +1 -1
  28. schubmult-2.0.4.dist-info/entry_points.txt +5 -0
  29. {schubmult-2.0.2.dist-info → schubmult-2.0.4.dist-info}/top_level.txt +0 -1
  30. schubmult/schubmult_double/_vars.py +0 -18
  31. schubmult/schubmult_py/_vars.py +0 -3
  32. schubmult/schubmult_q/_vars.py +0 -18
  33. schubmult/schubmult_q_double/_vars.py +0 -21
  34. schubmult-2.0.2.dist-info/METADATA +0 -455
  35. schubmult-2.0.2.dist-info/RECORD +0 -36
  36. schubmult-2.0.2.dist-info/entry_points.txt +0 -5
  37. tests/__init__.py +0 -0
  38. tests/test_fast_double_schubert.py +0 -145
  39. tests/test_fast_schubert.py +0 -38
  40. {schubmult-2.0.2.dist-info → schubmult-2.0.4.dist-info}/licenses/LICENSE +0 -0
@@ -1,48 +1,81 @@
1
1
  from bisect import bisect_left
2
- from functools import cache
2
+ from functools import cache, cached_property
3
+
4
+ import numpy as np
5
+ import psutil
6
+ import pulp as pu
7
+ import sympy
3
8
  from cachetools import cached
4
9
  from cachetools.keys import hashkey
5
- from symengine import sympify, Add, Mul, Pow, expand, Integer
10
+ from sortedcontainers import SortedList
11
+ from symengine import Add, Integer, Mul, Pow, expand, symarray, sympify
12
+
6
13
  from schubmult.perm_lib import (
7
- elem_sym_perms,
8
- elem_sym_poly,
9
14
  add_perm_dict,
10
- dominates,
11
- compute_vpathdicts,
12
- inverse,
13
- theta,
14
- permtrim,
15
- inv,
16
- mulperm,
17
15
  code,
18
- uncode,
16
+ compute_vpathdicts,
17
+ cycle,
18
+ divdiffable,
19
+ dominates,
19
20
  elem_sym_func,
21
+ elem_sym_perms,
20
22
  elem_sym_perms_op,
21
- divdiffable,
22
- pull_out_var,
23
- cycle,
24
- will_formula_work,
25
- one_dominates,
23
+ elem_sym_poly,
24
+ inv,
25
+ inverse,
26
26
  is_reducible,
27
+ mulperm,
28
+ one_dominates,
29
+ permtrim,
30
+ phi1,
31
+ pull_out_var,
27
32
  reduce_coeff,
28
33
  reduce_descents,
34
+ theta,
29
35
  try_reduce_u,
30
36
  try_reduce_v,
31
- phi1,
37
+ uncode,
38
+ will_formula_work,
32
39
  zero,
33
40
  )
34
- import numpy as np
35
- import pulp as pu
36
- import sympy
37
- import psutil
38
- from sortedcontainers import SortedList
39
- from ._vars import (
40
- n,
41
- var2,
42
- var3,
43
- var_x,
44
- var_y,
45
- )
41
+
42
+ # NO GLOBAL VARS
43
+ # from ._vars import (
44
+ # n,
45
+ # var2,
46
+ # var3,
47
+ # _vars.var1,
48
+ # var_y,
49
+ # )
50
+
51
+
52
+ class _gvars:
53
+ @cached_property
54
+ def n(self):
55
+ return 100
56
+
57
+ # @cached_property
58
+ # def fvar(self):
59
+ # return 100
60
+
61
+ @cached_property
62
+ def var1(self):
63
+ return tuple(symarray("x", self.n).tolist())
64
+
65
+ @cached_property
66
+ def var2(self):
67
+ return tuple(symarray("y", self.n).tolist())
68
+
69
+ @cached_property
70
+ def var3(self):
71
+ return tuple(symarray("z", self.n).tolist())
72
+
73
+ @cached_property
74
+ def var_r(self):
75
+ return symarray("r", 100)
76
+
77
+
78
+ _vars = _gvars()
46
79
 
47
80
 
48
81
  def count_sorted(mn, tp):
@@ -54,11 +87,11 @@ def count_sorted(mn, tp):
54
87
  return ct
55
88
 
56
89
 
57
- def E(p, k, varl=var_y[1:]):
58
- return elem_sym_poly(p, k, var_x[1:], varl)
90
+ def E(p, k, varl=_vars.var2[1:]):
91
+ return elem_sym_poly(p, k, _vars.var1[1:], varl)
59
92
 
60
93
 
61
- def single_variable(coeff_dict, varnum, var2=var2):
94
+ def single_variable(coeff_dict, varnum, var2=None):
62
95
  ret = {}
63
96
  for u in coeff_dict:
64
97
  if varnum - 1 < len(u):
@@ -78,7 +111,7 @@ def single_variable(coeff_dict, varnum, var2=var2):
78
111
  return ret
79
112
 
80
113
 
81
- def single_variable_down(coeff_dict, varnum):
114
+ def single_variable_down(coeff_dict, varnum, var2=_vars.var2):
82
115
  ret = {}
83
116
  for u in coeff_dict:
84
117
  if varnum - 1 < len(u):
@@ -98,58 +131,56 @@ def single_variable_down(coeff_dict, varnum):
98
131
  return ret
99
132
 
100
133
 
101
- def mult_poly(coeff_dict, poly, var_x=var_x, var_y=var2):
134
+ def mult_poly(coeff_dict, poly, var_x=_vars.var1, var_y=_vars.var2):
102
135
  if poly in var_x:
103
136
  return single_variable(coeff_dict, var_x.index(poly), var_y)
104
- elif isinstance(poly, Mul):
137
+ if isinstance(poly, Mul):
105
138
  ret = coeff_dict
106
139
  for a in poly.args:
107
140
  ret = mult_poly(ret, a, var_x, var_y)
108
141
  return ret
109
- elif isinstance(poly, Pow):
142
+ if isinstance(poly, Pow):
110
143
  base = poly.args[0]
111
144
  exponent = int(poly.args[1])
112
145
  ret = coeff_dict
113
146
  for i in range(int(exponent)):
114
147
  ret = mult_poly(ret, base, var_x, var_y)
115
148
  return ret
116
- elif isinstance(poly, Add):
149
+ if isinstance(poly, Add):
117
150
  ret = {}
118
151
  for a in poly.args:
119
152
  ret = add_perm_dict(ret, mult_poly(coeff_dict, a, var_x, var_y))
120
153
  return ret
121
- else:
122
- ret = {}
123
- for perm in coeff_dict:
124
- ret[perm] = poly * coeff_dict[perm]
125
- return ret
154
+ ret = {}
155
+ for perm in coeff_dict:
156
+ ret[perm] = poly * coeff_dict[perm]
157
+ return ret
126
158
 
127
159
 
128
160
  def mult_poly_down(coeff_dict, poly):
129
- if poly in var_x:
130
- return single_variable_down(coeff_dict, var_x.index(poly))
131
- elif isinstance(poly, Mul):
161
+ if poly in _vars.var1:
162
+ return single_variable_down(coeff_dict, _vars.var1.index(poly))
163
+ if isinstance(poly, Mul):
132
164
  ret = coeff_dict
133
165
  for a in poly.args:
134
166
  ret = mult_poly_down(ret, a)
135
167
  return ret
136
- elif isinstance(poly, Pow):
168
+ if isinstance(poly, Pow):
137
169
  base = poly.args[0]
138
170
  exponent = int(poly.args[1])
139
171
  ret = coeff_dict
140
172
  for i in range(int(exponent)):
141
173
  ret = mult_poly_down(ret, base)
142
174
  return ret
143
- elif isinstance(poly, Add):
175
+ if isinstance(poly, Add):
144
176
  ret = {}
145
177
  for a in poly.args:
146
178
  ret = add_perm_dict(ret, mult_poly_down(coeff_dict, a))
147
179
  return ret
148
- else:
149
- ret = {}
150
- for perm in coeff_dict:
151
- ret[perm] = poly * coeff_dict[perm]
152
- return ret
180
+ ret = {}
181
+ for perm in coeff_dict:
182
+ ret[perm] = poly * coeff_dict[perm]
183
+ return ret
153
184
 
154
185
 
155
186
  def nilhecke_mult(coeff_dict1, coeff_dict2):
@@ -168,7 +199,7 @@ def nilhecke_mult(coeff_dict1, coeff_dict2):
168
199
  return ret
169
200
 
170
201
 
171
- def forwardcoeff(u, v, perm, var2=var2, var3=var3):
202
+ def forwardcoeff(u, v, perm, var2=None, var3=None):
172
203
  th = theta(v)
173
204
  muv = uncode(th)
174
205
  vmun1 = mulperm(inverse([*v]), muv)
@@ -180,7 +211,7 @@ def forwardcoeff(u, v, perm, var2=var2, var3=var3):
180
211
  return 0
181
212
 
182
213
 
183
- def dualcoeff(u, v, perm, var2=var2, var3=var3):
214
+ def dualcoeff(u, v, perm, var2=None, var3=None):
184
215
  if u == (1, 2):
185
216
  vp = mulperm([*v], inverse(perm))
186
217
  if inv(vp) == inv(v) - inv(perm):
@@ -236,19 +267,18 @@ def dualpieri(mu, v, w):
236
267
  continue
237
268
  vl = pull_out_var(lm[i] + 1, vpl)
238
269
  for pw, vpl2 in vl:
239
- res2 += [[vlist + [pw], vpl2]]
270
+ res2 += [[[*vlist, pw], vpl2]]
240
271
  res = res2
241
272
  if len(lm) == len(cn1w):
242
273
  return res
243
- else:
244
- res2 = []
245
- for vlist, vplist in res:
246
- vp = vplist
247
- vpl = divdiffable(vp, c)
248
- if vpl == []:
249
- continue
250
- res2 += [[vlist, vpl]]
251
- return res2
274
+ res2 = []
275
+ for vlist, vplist in res:
276
+ vp = vplist
277
+ vpl = divdiffable(vp, c)
278
+ if vpl == []:
279
+ continue
280
+ res2 += [[vlist, vpl]]
281
+ return res2
252
282
 
253
283
 
254
284
  dimen = 0
@@ -256,11 +286,11 @@ monom_to_vec = {}
256
286
 
257
287
 
258
288
  @cache
259
- def schubmult_one(perm1, perm2, var2=var2, var3=var3):
289
+ def schubmult_one(perm1, perm2, var2=None, var3=None):
260
290
  return schubmult({perm1: 1}, perm2, var2, var3)
261
291
 
262
292
 
263
- def schubmult(perm_dict, v, var2=var2, var3=var3):
293
+ def schubmult(perm_dict, v, var2=None, var3=None):
264
294
  vn1 = inverse(v)
265
295
  th = theta(vn1)
266
296
  if len(th) == 0:
@@ -283,13 +313,14 @@ def schubmult(perm_dict, v, var2=var2, var3=var3):
283
313
  mx_th = 0
284
314
  for vp in vpathdicts[index]:
285
315
  for v2, vdiff, s in vpathdicts[index][vp]:
286
- if th[index] - vdiff > mx_th:
287
- mx_th = th[index] - vdiff
316
+ mx_th = max(mx_th, th[index] - vdiff)
288
317
  newpathsums = {}
289
318
  for up in vpathsums:
290
319
  inv_up = inv(up)
291
320
  newperms = elem_sym_perms(
292
- up, min(mx_th, (inv_mu - (inv_up - inv_u)) - inv_vmu), th[index]
321
+ up,
322
+ min(mx_th, (inv_mu - (inv_up - inv_u)) - inv_vmu),
323
+ th[index],
293
324
  )
294
325
  for up2, udiff in newperms:
295
326
  if up2 not in newpathsums:
@@ -300,7 +331,8 @@ def schubmult(perm_dict, v, var2=var2, var3=var3):
300
331
  continue
301
332
  for v2, vdiff, s in vpathdicts[index][v]:
302
333
  newpathsums[up2][v2] = newpathsums[up2].get(
303
- v2, zero
334
+ v2,
335
+ zero,
304
336
  ) + s * sumval * elem_sym_func(
305
337
  th[index],
306
338
  index + 1,
@@ -319,7 +351,7 @@ def schubmult(perm_dict, v, var2=var2, var3=var3):
319
351
  return ret_dict
320
352
 
321
353
 
322
- def schubmult_down(perm_dict, v, var2=var2, var3=var3):
354
+ def schubmult_down(perm_dict, v, var2=None, var3=None):
323
355
  vn1 = inverse(v)
324
356
  th = theta(vn1)
325
357
  if th[0] == 0:
@@ -338,8 +370,7 @@ def schubmult_down(perm_dict, v, var2=var2, var3=var3):
338
370
  mx_th = 0
339
371
  for vp in vpathdicts[index]:
340
372
  for v2, vdiff, s in vpathdicts[index][vp]:
341
- if th[index] - vdiff > mx_th:
342
- mx_th = th[index] - vdiff
373
+ mx_th = max(mx_th, th[index] - vdiff)
343
374
  newpathsums = {}
344
375
  for up in vpathsums:
345
376
  newperms = elem_sym_perms_op(up, mx_th, th[index])
@@ -352,7 +383,8 @@ def schubmult_down(perm_dict, v, var2=var2, var3=var3):
352
383
  continue
353
384
  for v2, vdiff, s in vpathdicts[index][v]:
354
385
  newpathsums[up2][v2] = newpathsums[up2].get(
355
- v2, zero
386
+ v2,
387
+ zero,
356
388
  ) + s * sumval * elem_sym_func(
357
389
  th[index],
358
390
  index + 1,
@@ -371,8 +403,7 @@ def schubmult_down(perm_dict, v, var2=var2, var3=var3):
371
403
  return ret_dict
372
404
 
373
405
 
374
- def poly_to_vec(poly, vec0=None):
375
- global dimen, monom_to_vec, base_vec
406
+ def poly_to_vec(poly, vec0=None, var3=_vars.var3):
376
407
  poly = expand(poly.xreplace({var3[1]: 0}))
377
408
 
378
409
  dc = poly.as_coefficients_dict()
@@ -396,18 +427,18 @@ def poly_to_vec(poly, vec0=None):
396
427
  return vec
397
428
 
398
429
 
399
- def shiftsub(pol):
400
- subs_dict = dict([(var2[i], var2[i + 1]) for i in range(99)])
430
+ def shiftsub(pol, var2=_vars.var2):
431
+ subs_dict = {var2[i]: var2[i + 1] for i in range(99)}
401
432
  return sympify(pol).subs(subs_dict)
402
433
 
403
434
 
404
- def shiftsubz(pol):
405
- subs_dict = dict([(var3[i], var3[i + 1]) for i in range(99)])
435
+ def shiftsubz(pol, var3=_vars.var3):
436
+ subs_dict = {var3[i]: var3[i + 1] for i in range(99)}
406
437
  return sympify(pol).subs(subs_dict)
407
438
 
408
439
 
409
440
  def init_basevec(dc):
410
- global dimen, monom_to_vec, base_vec
441
+ global dimen, monom_to_vec, base_vec # noqa: PLW0603
411
442
  monom_to_vec = {}
412
443
  index = 0
413
444
  for mn in dc:
@@ -430,12 +461,11 @@ def split_flat_term(arg):
430
461
  ys += [arg2.args[1]]
431
462
  else:
432
463
  ys += [arg2]
464
+ elif isinstance(arg2, Mul):
465
+ for i in range(abs(int(arg2.args[0]))):
466
+ zs += [-arg2.args[1]]
433
467
  else:
434
- if isinstance(arg2, Mul):
435
- for i in range(abs(int(arg2.args[0]))):
436
- zs += [-arg2.args[1]]
437
- else:
438
- zs += [arg2]
468
+ zs += [arg2]
439
469
  return ys, zs
440
470
 
441
471
 
@@ -449,11 +479,11 @@ def is_flat_term(term):
449
479
  return True
450
480
 
451
481
 
452
- def flatten_factors(term, var2=var3, var3=var3):
482
+ def flatten_factors(term):
453
483
  found_one = False
454
484
  if is_flat_term(term):
455
485
  return term, False
456
- elif isinstance(term, Pow):
486
+ if isinstance(term, Pow):
457
487
  if is_flat_term(term.args[0]) and len(term.args[0].args) > 2:
458
488
  ys, zs = split_flat_term(term.args[0])
459
489
  terms = [1]
@@ -464,11 +494,10 @@ def flatten_factors(term, var2=var3, var3=var3):
464
494
  terms2 += [t * (ys[i] + zs[i])]
465
495
  terms = terms2
466
496
  return Add(*terms)
467
- elif is_flat_term(term.args[0]):
497
+ if is_flat_term(term.args[0]):
468
498
  return term, False
469
- else:
470
- return flatten_factors(term.args[0]) ** term.args[1], True
471
- elif isinstance(term, Mul):
499
+ return flatten_factors(term.args[0]) ** term.args[1], True
500
+ if isinstance(term, Mul):
472
501
  terms = [1]
473
502
  for arg in term.args:
474
503
  terms2 = []
@@ -496,7 +525,7 @@ def flatten_factors(term, var2=var3, var3=var3):
496
525
  else:
497
526
  term = Add(*terms)
498
527
  return term, found_one
499
- elif isinstance(term, Add):
528
+ if isinstance(term, Add):
500
529
  res = 0
501
530
  for arg in term.args:
502
531
  flat, found = flatten_factors(arg)
@@ -504,14 +533,16 @@ def flatten_factors(term, var2=var3, var3=var3):
504
533
  found_one = True
505
534
  res += flat
506
535
  return res, found_one
536
+ return None
507
537
 
508
538
 
509
539
  def fres(v):
510
540
  for s in v.free_symbols:
511
541
  return s
542
+ return None
512
543
 
513
544
 
514
- def split_mul(arg0, var2=var2, var3=var3):
545
+ def split_mul(arg0, var2=None, var3=None):
515
546
  monoms = SortedList()
516
547
 
517
548
  var2s = {fres(var2[i]): i for i in range(len(var2))}
@@ -582,7 +613,7 @@ def is_negative(term):
582
613
  sign = 1
583
614
  if isinstance(term, Integer) or isinstance(term, int):
584
615
  return term < 0
585
- elif isinstance(term, Mul):
616
+ if isinstance(term, Mul):
586
617
  for arg in term.args:
587
618
  if isinstance(arg, Integer):
588
619
  sign *= arg
@@ -602,11 +633,11 @@ def is_negative(term):
602
633
  return sign < 0
603
634
 
604
635
 
605
- def find_base_vectors(monom_list, monom_list_neg, var2, var3, depth):
636
+ def find_base_vectors(monom_list, var2, var3, depth):
606
637
  size = 0
607
638
  mn_fullcount = {}
608
639
  # pairs_checked = set()
609
- monom_list = set([tuple(mn) for mn in monom_list])
640
+ monom_list = {tuple(mn) for mn in monom_list}
610
641
  ct = 0
611
642
  while ct < depth and size != len(monom_list):
612
643
  size = len(monom_list)
@@ -705,7 +736,7 @@ def find_base_vectors(monom_list, monom_list_neg, var2, var3, depth):
705
736
  return ret, monom_list
706
737
 
707
738
 
708
- def compute_positive_rep(val, var2=var2, var3=var3, msg=False, do_pos_neg=True):
739
+ def compute_positive_rep(val, var2=None, var3=None, msg=False, do_pos_neg=True):
709
740
  notint = False
710
741
  try:
711
742
  int(expand(val))
@@ -766,9 +797,9 @@ def compute_positive_rep(val, var2=var2, var3=var3, msg=False, do_pos_neg=True):
766
797
  depth = 1
767
798
 
768
799
  mons = split_monoms(pos_part, varsimp2, varsimp3)
769
- mons = set([tuple(mn) for mn in mons])
800
+ mons = {tuple(mn) for mn in mons}
770
801
  mons2 = split_monoms(neg_part, varsimp2, varsimp3)
771
- mons2 = set([tuple(mn2) for mn2 in mons2])
802
+ mons2 = {tuple(mn2) for mn2 in mons2}
772
803
 
773
804
  # mons2 = split_monoms(neg_part)
774
805
  # for mn in mons2:
@@ -795,12 +826,9 @@ def compute_positive_rep(val, var2=var2, var3=var3, msg=False, do_pos_neg=True):
795
826
  for j in range(len(bad_vectors) - 1, -1, -1):
796
827
  base_monoms.pop(bad_vectors[j])
797
828
 
798
- vrs = [
799
- pu.LpVariable(name=f"a{i}", lowBound=0, cat="Integer")
800
- for i in range(len(base_vectors))
801
- ]
829
+ vrs = [pu.LpVariable(name=f"a{i}", lowBound=0, cat="Integer") for i in range(len(base_vectors))]
802
830
  lp_prob = pu.LpProblem("Problem", pu.LpMinimize)
803
- lp_prob += int(0)
831
+ lp_prob += 0
804
832
  eqs = [*base_vec]
805
833
  for j in range(len(base_vectors)):
806
834
  for i in base_vectors[j]:
@@ -842,7 +870,7 @@ def compute_positive_rep(val, var2=var2, var3=var3, msg=False, do_pos_neg=True):
842
870
  lookup[key] = []
843
871
  mm0n1 = mm0[:n1]
844
872
  st = set(mm0n1)
845
- if len(st.intersection(set([0, 1]))) == len(st) and 1 in st:
873
+ if len(st.intersection({0, 1})) == len(st) and 1 in st:
846
874
  lookup[key] += [mm0]
847
875
  if mm0n1 == L1:
848
876
  mn1L += [mm0]
@@ -852,16 +880,12 @@ def compute_positive_rep(val, var2=var2, var3=var3, msg=False, do_pos_neg=True):
852
880
  if mn1[i] != 0:
853
881
  arr = np.array(comblistmn1)
854
882
  comblistmn12 = []
855
- mn1_2 = tuple([*mn1[n1:i]] + [0] + [*mn1[i + 1 :]])
883
+ mn1_2 = (*mn1[n1:i], 0, *mn1[i + 1 :])
856
884
  for mm0 in lookup[mn1_2]:
857
885
  comblistmn12 += (
858
886
  arr
859
887
  * np.prod(
860
- [
861
- varsimp2[k] - varsimp3[i - n1]
862
- for k in range(n1)
863
- if mm0[k] == 1
864
- ]
888
+ [varsimp2[k] - varsimp3[i - n1] for k in range(n1) if mm0[k] == 1],
865
889
  )
866
890
  ).tolist()
867
891
  comblistmn1 = comblistmn12
@@ -871,12 +895,9 @@ def compute_positive_rep(val, var2=var2, var3=var3, msg=False, do_pos_neg=True):
871
895
  if vec0 is not None:
872
896
  base_vectors += [vec0]
873
897
  base_monoms += [b1]
874
- vrs = [
875
- pu.LpVariable(name=f"a{i}", lowBound=0, cat="Integer")
876
- for i in range(len(base_vectors))
877
- ]
898
+ vrs = [pu.LpVariable(name=f"a{i}", lowBound=0, cat="Integer") for i in range(len(base_vectors))]
878
899
  lp_prob = pu.LpProblem("Problem", pu.LpMinimize)
879
- lp_prob += int(0)
900
+ lp_prob += 0
880
901
  eqs = [*base_vec]
881
902
  for j in range(len(base_vectors)):
882
903
  for i in base_vectors[j]:
@@ -915,7 +936,7 @@ def compute_positive_rep(val, var2=var2, var3=var3, msg=False, do_pos_neg=True):
915
936
  return val2
916
937
 
917
938
 
918
- def is_split_two(u, v, w):
939
+ def is_split_two(u, v, w): # noqa: ARG001
919
940
  if inv(w) - inv(u) != 2:
920
941
  return False, []
921
942
  diff_perm = mulperm(inverse([*u]), [*w])
@@ -924,7 +945,7 @@ def is_split_two(u, v, w):
924
945
  for i in range(len(identity)):
925
946
  if diff_perm[i] != identity[i]:
926
947
  cycle0 = set()
927
- cycle = set([i + 1])
948
+ cycle = {i + 1}
928
949
  last = i
929
950
  while len(cycle0) != len(cycle):
930
951
  cycle0 = cycle
@@ -936,8 +957,7 @@ def is_split_two(u, v, w):
936
957
  break
937
958
  if len(cycles) == 2:
938
959
  return True, cycles
939
- else:
940
- return False, []
960
+ return False, []
941
961
 
942
962
 
943
963
  def is_coeff_irreducible(u, v, w):
@@ -970,9 +990,9 @@ def is_hook(cd):
970
990
  return False
971
991
 
972
992
 
973
- def div_diff(i, poly):
993
+ def div_diff(i, poly, var2=_vars.var2):
974
994
  return sympify(
975
- sympy.div(sympy.sympify(poly - permy(poly, i)), sympy.sympify(var2[i] - var2[i + 1]))[0]
995
+ sympy.div(sympy.sympify(poly - permy(poly, i)), sympy.sympify(var2[i] - var2[i + 1]))[0],
976
996
  )
977
997
 
978
998
 
@@ -997,30 +1017,31 @@ def skew_div_diff(u, w, poly):
997
1017
  u2 = [*u]
998
1018
  u2[d], u2[d + 1] = u2[d + 1], u2[d]
999
1019
  return skew_div_diff(u2, w2, permy(poly, d + 1))
1000
- else:
1001
- return skew_div_diff(u, w2, div_diff(d + 1, poly))
1020
+ return skew_div_diff(u, w2, div_diff(d + 1, poly))
1002
1021
 
1003
1022
 
1004
1023
  @cached(
1005
1024
  cache={},
1006
- key=lambda val,
1025
+ key=lambda val, u2, v2, w2, var2=None, var3=None, msg=False, do_pos_neg=True, sign_only=False: hashkey(u2, v2, w2, var2, var3, msg, do_pos_neg, sign_only),
1026
+ )
1027
+ def posify(
1028
+ val,
1007
1029
  u2,
1008
1030
  v2,
1009
1031
  w2,
1010
- var2=var2,
1011
- var3=var3,
1032
+ var2=None,
1033
+ var3=None,
1012
1034
  msg=False,
1013
1035
  do_pos_neg=True,
1014
- sign_only=False: hashkey(u2, v2, w2, var2, var3, msg, do_pos_neg, sign_only),
1015
- )
1016
- def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, sign_only=False):
1036
+ sign_only=False,
1037
+ n=_vars.n,
1038
+ ):
1017
1039
  if inv(u2) + inv(v2) - inv(w2) == 0:
1018
1040
  return val
1019
1041
  cdv = code(v2)
1020
- if set(cdv) == set([0, 1]) and do_pos_neg:
1042
+ if set(cdv) == {0, 1} and do_pos_neg:
1021
1043
  return val
1022
- # if is_hook(cdv):
1023
- # print(f"Could've {cdv}")
1044
+
1024
1045
  if not sign_only and expand(val) == 0:
1025
1046
  return 0
1026
1047
 
@@ -1036,7 +1057,7 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1036
1057
  u, v, w = reduce_coeff(u, v, w)
1037
1058
  if is_coeff_irreducible(u, v, w):
1038
1059
  while is_coeff_irreducible(u, v, w) and tuple(permtrim(w0)) != tuple(
1039
- permtrim([*w])
1060
+ permtrim([*w]),
1040
1061
  ):
1041
1062
  w0 = w
1042
1063
  u, v, w = reduce_descents(u, v, w)
@@ -1072,12 +1093,7 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1072
1093
  r = inv(w) - inv_u
1073
1094
  val = 0
1074
1095
  w2 = w
1075
- hvarset = (
1076
- [w2[i] for i in range(min(len(w2), k))]
1077
- + [i + 1 for i in range(len(w2), k)]
1078
- + [w2[b] for b in range(k, len(u)) if u[b] != w2[b]]
1079
- + [w2[b] for b in range(len(u), len(w2))]
1080
- )
1096
+ hvarset = [w2[i] for i in range(min(len(w2), k))] + [i + 1 for i in range(len(w2), k)] + [w2[b] for b in range(k, len(u)) if u[b] != w2[b]] + [w2[b] for b in range(len(u), len(w2))]
1081
1097
  val = elem_sym_poly(
1082
1098
  p - r,
1083
1099
  k + p - 1,
@@ -1095,9 +1111,7 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1095
1111
  for i in range(len(w)):
1096
1112
  if a == -1 and u[i] != w[i]:
1097
1113
  a = i
1098
- elif i >= len(u) and w[i] != i + 1:
1099
- b = i
1100
- elif b == -1 and u[i] != w[i]:
1114
+ elif (i >= len(u) and w[i] != i + 1) or (b == -1 and u[i] != w[i]):
1101
1115
  b = i
1102
1116
  arr = [[[], v]]
1103
1117
  d = -1
@@ -1116,15 +1130,14 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1116
1130
  for vr, v2 in arr:
1117
1131
  dpret = pull_out_var(i2, [*v2])
1118
1132
  for v3r, v3 in dpret:
1119
- arr2 += [[vr + [v3r], v3]]
1133
+ arr2 += [[[*vr, v3r], v3]]
1120
1134
  arr = arr2
1121
1135
  val = 0
1122
1136
  for L in arr:
1123
1137
  v3 = [*L[-1]]
1124
1138
  if v3[0] < v3[1]:
1125
1139
  continue
1126
- else:
1127
- v3[0], v3[1] = v3[1], v3[0]
1140
+ v3[0], v3[1] = v3[1], v3[0]
1128
1141
  toadd = 1
1129
1142
  for i in range(d):
1130
1143
  if i in [a, b]:
@@ -1185,7 +1198,7 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1185
1198
  for vr, v2 in arr:
1186
1199
  dpret = pull_out_var(i2, [*v2])
1187
1200
  for v3r, v3 in dpret:
1188
- arr2 += [[vr + [(v3r, i + 1)], v3]]
1201
+ arr2 += [[[*vr, (v3r, i + 1)], v3]]
1189
1202
  arr = arr2
1190
1203
  val = 0
1191
1204
 
@@ -1195,8 +1208,7 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1195
1208
  v3 = [*L[-1]]
1196
1209
  if v3[real_a1] < v3[real_b1]:
1197
1210
  continue
1198
- else:
1199
- v3[real_a1], v3[real_b1] = v3[real_b1], v3[real_a1]
1211
+ v3[real_a1], v3[real_b1] = v3[real_b1], v3[real_a1]
1200
1212
  arr2 += [[L[0], v3]]
1201
1213
  arr = arr2
1202
1214
  if not good2:
@@ -1214,7 +1226,7 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1214
1226
  for vr, v2 in arr:
1215
1227
  dpret = pull_out_var(i2, [*v2])
1216
1228
  for v3r, v3 in dpret:
1217
- arr2 += [[vr + [(v3r, var_index)], v3]]
1229
+ arr2 += [[[*vr, (v3r, var_index)], v3]]
1218
1230
  arr = arr2
1219
1231
  if good2:
1220
1232
  arr2 = []
@@ -1223,8 +1235,7 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1223
1235
  try:
1224
1236
  if v3[real_a2] < v3[real_b2]:
1225
1237
  continue
1226
- else:
1227
- v3[real_a2], v3[real_b2] = v3[real_b2], v3[real_a2]
1238
+ v3[real_a2], v3[real_b2] = v3[real_b2], v3[real_a2]
1228
1239
  except IndexError:
1229
1240
  continue
1230
1241
  arr2 += [[L[0], v3]]
@@ -1244,7 +1255,7 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1244
1255
  for vr, v2 in arr:
1245
1256
  dpret = pull_out_var(i2, [*v2])
1246
1257
  for v3r, v3 in dpret:
1247
- arr2 += [[vr + [(v3r, var_index)], v3]]
1258
+ arr2 += [[[*vr, (v3r, var_index)], v3]]
1248
1259
  arr = arr2
1249
1260
 
1250
1261
  for L in arr:
@@ -1253,7 +1264,7 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1253
1264
  doschubpoly = True
1254
1265
  if (not good1 or not good2) and v3[0] < v3[1] and (good1 or good2):
1255
1266
  continue
1256
- elif (good1 or good2) and (not good1 or not good2):
1267
+ if (good1 or good2) and (not good1 or not good2):
1257
1268
  v3[0], v3[1] = v3[1], v3[0]
1258
1269
  elif not good1 and not good2:
1259
1270
  doschubpoly = False
@@ -1266,10 +1277,9 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1266
1277
  if len(v3) <= 3 or v3[2] < v3[3]:
1267
1278
  coeff = 0
1268
1279
  continue
1269
- else:
1270
- v3[0], v3[1] = v3[1], v3[0]
1271
- v3[2], v3[3] = v3[3], v3[2]
1272
- coeff = permy(schubpoly(v3, var2, var3), 2)
1280
+ v3[0], v3[1] = v3[1], v3[0]
1281
+ v3[2], v3[3] = v3[3], v3[2]
1282
+ coeff = permy(schubpoly(v3, var2, var3), 2)
1273
1283
  elif len(v3) <= 3 or v3[2] < v3[3]:
1274
1284
  if len(v3) <= 3:
1275
1285
  v3 += [4]
@@ -1277,7 +1287,8 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1277
1287
  coeff = permy(
1278
1288
  posify(
1279
1289
  schubmult_one((1, 3, 2), tuple(permtrim([*v3])), var2, var3).get(
1280
- (2, 4, 3, 1), 0
1290
+ (2, 4, 3, 1),
1291
+ 0,
1281
1292
  ),
1282
1293
  (1, 3, 2),
1283
1294
  tuple(permtrim([*v3])),
@@ -1292,7 +1303,8 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1292
1303
  else:
1293
1304
  coeff = permy(
1294
1305
  schubmult_one((1, 3, 2), tuple(permtrim([*v3])), var2, var3).get(
1295
- (2, 4, 1, 3), 0
1306
+ (2, 4, 1, 3),
1307
+ 0,
1296
1308
  ),
1297
1309
  2,
1298
1310
  )
@@ -1321,50 +1333,6 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1321
1333
  if sign_only:
1322
1334
  return 0
1323
1335
  val = forwardcoeff(u, v, w, var2, var3)
1324
- # elif inv(w) - inv(u) == 2:
1325
- # indices = []
1326
- # for i in range(len(w)):
1327
- # if i>=len(u) or u[i]!=w[i]:
1328
- # indices += [i+1]
1329
- # arr = [[[],v]]
1330
- # d = -1
1331
- # for i in range(len(v)-1):
1332
- # if v[i]>v[i+1]:
1333
- # d = i + 1
1334
- # for i in range(d):
1335
- # arr2 = []
1336
- #
1337
- # if i+1 in indices:
1338
- # continue
1339
- # i2 = 1
1340
- # i2 += len([aa for aa in indices if i+1>aa])
1341
- # for vr, v2 in arr:
1342
- # dpret = pull_out_var(i2,[*v2])
1343
- # for v3r, v3 in dpret:
1344
- # arr2 += [[vr + [(v3r,i+1)],v3]]
1345
- # arr = arr2
1346
- # val = 0
1347
- #
1348
- # for L in arr:
1349
- # v3 = [*L[-1]]
1350
- # tomul = 1
1351
- # pooly = skew_div_diff(u,w,schubpoly(v3,[0,*[var2[a] for a in indices]],var3))
1352
- # coeff = compute_positive_rep(pooly,var2,var3,msg,False)
1353
- # if coeff == -1:
1354
- # return -1
1355
- # tomul = sympify(coeff)
1356
- # toadd = 1
1357
- # for i in range(len(L[0])):
1358
- # var_index = L[0][i][1]
1359
- # oaf = L[0][i][0]
1360
- # if var_index-1>=len(w):
1361
- # yv = var_index
1362
- # else:
1363
- # yv = w[var_index-1]
1364
- # for j in range(len(oaf)):
1365
- # toadd*= var2[yv] - var3[oaf[j]]
1366
- # toadd*=tomul#.subs(subs_dict3)
1367
- # val += toadd
1368
1336
  else:
1369
1337
  c01 = code(u)
1370
1338
  c02 = code(w)
@@ -1401,13 +1369,17 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1401
1369
  break
1402
1370
  v3 = uncode(newc)
1403
1371
  coeff_dict = schubmult_one(
1404
- tuple(permtrim([*u])), tuple(permtrim(uncode(elemc))), var2, var3
1372
+ tuple(permtrim([*u])),
1373
+ tuple(permtrim(uncode(elemc))),
1374
+ var2,
1375
+ var3,
1405
1376
  )
1406
1377
  val = 0
1407
1378
  for new_w in coeff_dict:
1408
1379
  tomul = coeff_dict[new_w]
1409
1380
  newval = schubmult_one(new_w, tuple(permtrim(uncode(newc))), var2, var3).get(
1410
- tuple(permtrim([*w])), 0
1381
+ tuple(permtrim([*w])),
1382
+ 0,
1411
1383
  )
1412
1384
  newval = posify(
1413
1385
  newval,
@@ -1428,7 +1400,8 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1428
1400
  w3 = uncode([0] + c02[1:])
1429
1401
  val = 0
1430
1402
  val = schubmult_one(tuple(permtrim(u3)), tuple(permtrim([*v])), var2, var3).get(
1431
- tuple(permtrim(w3)), 0
1403
+ tuple(permtrim(w3)),
1404
+ 0,
1432
1405
  )
1433
1406
  val = posify(
1434
1407
  val,
@@ -1455,78 +1428,24 @@ def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, si
1455
1428
  tomul *= var2[1] - var3[arr[i]]
1456
1429
 
1457
1430
  val2 = schubmult_one(tuple(permtrim(u3)), tuple(permtrim(v3)), var2, var3).get(
1458
- tuple(permtrim(w3)), 0
1431
+ tuple(permtrim(w3)),
1432
+ 0,
1459
1433
  )
1460
1434
  val2 = posify(val2, u3, tuple(permtrim(v3)), w3, var2, var3, msg, do_pos_neg)
1461
1435
  val += tomul * shiftsub(val2)
1462
- # elif inv(w)-inv(u)==2 and len(trimcode(u)) == len(trimcode(w)):
1463
- # indices = []
1464
- # for i in range(len(w)):
1465
- # if i>=len(u) or u[i]!=w[i]:
1466
- # indices += [i+1]
1467
- # arr = [[[],v]]
1468
- # d = -1
1469
- # for i in range(len(v)-1):
1470
- # if v[i]>v[i+1]:
1471
- # d = i + 1
1472
- # for i in range(d):
1473
- # arr2 = []
1474
- #
1475
- # if i+1 in indices:
1476
- # continue
1477
- # i2 = 1
1478
- # i2 += len([aa for aa in indices if i+1>aa])
1479
- # for vr, v2 in arr:
1480
- # dpret = pull_out_var(i2,[*v2])
1481
- # for v3r, v3 in dpret:
1482
- # arr2 += [[vr + [(v3r,i+1)],v3]]
1483
- # arr = arr2
1484
- # val = 0
1485
- #
1486
- # for L in arr:
1487
- # v3 = [*L[-1]]
1488
- # tomul = 1
1489
- # toadd = 1
1490
- # for i in range(len(L[0])):
1491
- # var_index = L[0][i][1]
1492
- # oaf = L[0][i][0]
1493
- # if var_index-1>=len(w):
1494
- # yv = var_index
1495
- # else:
1496
- # yv = w[var_index-1]
1497
- # for j in range(len(oaf)):
1498
- # toadd*= var2[yv] - var3[oaf[j]]
1499
- # pooly = skew_div_diff(u,w,schubpoly(v3,[0,*[var2[a] for a in indices]],var3))
1500
- # if toadd == 0:
1501
- # continue
1502
- # if pooly !=0:
1503
- # coeff = compute_positive_rep(pooly,var2,var3,msg,False)
1504
- # else:
1505
- # coeff = 0
1506
- # if coeff == -1:
1507
- # return -1
1508
- # tomul = sympify(coeff)
1509
- # toadd*=tomul#.subs(subs_dict3)
1510
- # val += toadd
1511
- else:
1512
- if not sign_only:
1513
- if inv(u) + inv(v) - inv(w) == 1:
1514
- val2 = compute_positive_rep(val, var2, var3, msg, False)
1515
- else:
1516
- val2 = compute_positive_rep(val, var2, var3, msg, do_pos_neg)
1517
- if val2 is not None:
1518
- val = val2
1436
+ elif not sign_only:
1437
+ if inv(u) + inv(v) - inv(w) == 1:
1438
+ val2 = compute_positive_rep(val, var2, var3, msg, False)
1519
1439
  else:
1520
- # st = str(expand(val))
1521
- # if st.find("-")!=-1:
1522
- # return -1
1523
- # else:
1524
- # return val
1525
- d = expand(val).as_coefficients_dict()
1526
- for v in d.values():
1527
- if v < 0:
1528
- return -1
1529
- return 1
1440
+ val2 = compute_positive_rep(val, var2, var3, msg, do_pos_neg)
1441
+ if val2 is not None:
1442
+ val = val2
1443
+ else:
1444
+ d = expand(val).as_coefficients_dict()
1445
+ for v in d.values():
1446
+ if v < 0:
1447
+ return -1
1448
+ return 1
1530
1449
  return val
1531
1450
 
1532
1451
 
@@ -1551,8 +1470,7 @@ def split_perms(perms):
1551
1470
  for j in range(index, len(cd)):
1552
1471
  if cd[j] != 0:
1553
1472
  break
1554
- else:
1555
- num_zeros += 1
1473
+ num_zeros += 1
1556
1474
  if num_zeros >= num_zeros_to_miss:
1557
1475
  cd1 = cd[:index]
1558
1476
  cd2 = [0 for i in range(index)] + cd[index:]
@@ -1567,7 +1485,7 @@ def split_perms(perms):
1567
1485
  return perms2
1568
1486
 
1569
1487
 
1570
- def schubpoly(v, var2=var2, var3=var3, start_var=1):
1488
+ def schubpoly(v, var2=None, var3=None, start_var=1):
1571
1489
  n = 0
1572
1490
  for j in range(len(v) - 2, -1, -1):
1573
1491
  if v[j] > v[j + 1]:
@@ -1585,6 +1503,52 @@ def schubpoly(v, var2=var2, var3=var3, start_var=1):
1585
1503
  return ret
1586
1504
 
1587
1505
 
1588
- def permy(val, i):
1506
+ def permy(val, i, var2=_vars.var2):
1589
1507
  subsdict = {var2[i]: var2[i + 1], var2[i + 1]: var2[i]}
1590
1508
  return sympify(val).subs(subsdict)
1509
+
1510
+
1511
+ def schub_coprod(mperm, indices, var2=_vars.var2, var3=_vars.var3):
1512
+ indices = sorted(indices)
1513
+ subs_dict_coprod = {}
1514
+ k = len(indices)
1515
+ n = len(mperm)
1516
+ kcd = [indices[i] - i - 1 for i in range(len(indices))] + [n + 1 - k for i in range(k, n)]
1517
+ max_required = max([kcd[i] + i for i in range(len(kcd))])
1518
+ kcd2 = kcd + [0 for i in range(len(kcd), max_required)] + [0]
1519
+ N = len(kcd)
1520
+ kperm = permtrim(inverse(uncode(kcd2)))
1521
+ inv_kperm = inv(kperm)
1522
+ vn = symarray("soible", 100)
1523
+
1524
+ for i in range(1, N * 2 + 1):
1525
+ if i <= N:
1526
+ subs_dict_coprod[vn[i]] = var2[i]
1527
+ else:
1528
+ subs_dict_coprod[vn[i]] = var3[i - N]
1529
+
1530
+ coeff_dict = {tuple(kperm): 1}
1531
+ coeff_dict = schubmult(coeff_dict, mperm, vn, var2)
1532
+
1533
+ inverse_kperm = inverse(kperm)
1534
+
1535
+ ret_dict = {}
1536
+ for perm in coeff_dict:
1537
+ downperm = mulperm(list(perm), inverse_kperm)
1538
+ if inv(downperm) == inv(perm) - inv_kperm:
1539
+ flag = True
1540
+ for i in range(N):
1541
+ if downperm[i] > N:
1542
+ flag = False
1543
+ break
1544
+ if not flag:
1545
+ continue
1546
+ firstperm = downperm[0:N]
1547
+ secondperm = [downperm[i] - N for i in range(N, len(downperm))]
1548
+
1549
+ val = sympify(coeff_dict[perm]).subs(subs_dict_coprod)
1550
+
1551
+ key = (tuple(permtrim(firstperm)), tuple(permtrim(secondperm)))
1552
+ ret_dict[key] = val
1553
+
1554
+ return ret_dict