schubmult-2.0.2-py3-none-any.whl → schubmult-2.0.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. schubmult/__init__.py +1 -1
  2. schubmult/_base_argparse.py +35 -6
  3. schubmult/_tests.py +9 -0
  4. schubmult/sage_integration/__init__.py +1 -0
  5. schubmult/sage_integration/_fast_double_schubert_polynomial_ring.py +68 -11
  6. schubmult/sage_integration/_fast_schubert_polynomial_ring.py +43 -5
  7. schubmult/sage_integration/_indexing.py +11 -4
  8. schubmult/schubmult_double/__init__.py +6 -2
  9. schubmult/schubmult_double/__main__.py +1 -1
  10. schubmult/schubmult_double/_funcs.py +112 -32
  11. schubmult/schubmult_double/_script.py +109 -51
  12. schubmult/schubmult_py/__init__.py +5 -2
  13. schubmult/schubmult_py/__main__.py +1 -1
  14. schubmult/schubmult_py/_funcs.py +54 -9
  15. schubmult/schubmult_py/_script.py +33 -52
  16. schubmult/schubmult_q/__init__.py +1 -0
  17. schubmult/schubmult_q/__main__.py +1 -1
  18. schubmult/schubmult_q/_funcs.py +21 -37
  19. schubmult/schubmult_q/_script.py +19 -16
  20. schubmult/schubmult_q_double/__init__.py +1 -0
  21. schubmult/schubmult_q_double/__main__.py +1 -1
  22. schubmult/schubmult_q_double/_funcs.py +57 -24
  23. schubmult/schubmult_q_double/_script.py +200 -139
  24. {schubmult-2.0.2.dist-info → schubmult-2.0.3.dist-info}/METADATA +4 -4
  25. schubmult-2.0.3.dist-info/RECORD +30 -0
  26. {schubmult-2.0.2.dist-info → schubmult-2.0.3.dist-info}/WHEEL +1 -1
  27. schubmult-2.0.3.dist-info/entry_points.txt +5 -0
  28. {schubmult-2.0.2.dist-info → schubmult-2.0.3.dist-info}/top_level.txt +0 -1
  29. schubmult/schubmult_double/_vars.py +0 -18
  30. schubmult/schubmult_py/_vars.py +0 -3
  31. schubmult/schubmult_q/_vars.py +0 -18
  32. schubmult/schubmult_q_double/_vars.py +0 -21
  33. schubmult-2.0.2.dist-info/RECORD +0 -36
  34. schubmult-2.0.2.dist-info/entry_points.txt +0 -5
  35. tests/__init__.py +0 -0
  36. tests/test_fast_double_schubert.py +0 -145
  37. tests/test_fast_schubert.py +0 -38
  38. {schubmult-2.0.2.dist-info → schubmult-2.0.3.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@ from bisect import bisect_left
2
2
  from functools import cache
3
3
  from cachetools import cached
4
4
  from cachetools.keys import hashkey
5
- from symengine import sympify, Add, Mul, Pow, expand, Integer
5
+ from symengine import sympify, Add, Mul, Pow, expand, Integer, symarray
6
6
  from schubmult.perm_lib import (
7
7
  elem_sym_perms,
8
8
  elem_sym_poly,
@@ -36,13 +36,45 @@ import pulp as pu
36
36
  import sympy
37
37
  import psutil
38
38
  from sortedcontainers import SortedList
39
- from ._vars import (
40
- n,
41
- var2,
42
- var3,
43
- var_x,
44
- var_y,
45
- )
39
+ from functools import cached_property
40
+
41
+ # NO GLOBAL VARS
42
+ # from ._vars import (
43
+ # n,
44
+ # var2,
45
+ # var3,
46
+ # _vars.var1,
47
+ # var_y,
48
+ # )
49
+
50
+
51
+ class _gvars:
52
+ @cached_property
53
+ def n(self):
54
+ return 100
55
+
56
+ # @cached_property
57
+ # def fvar(self):
58
+ # return 100
59
+
60
+ @cached_property
61
+ def var1(self):
62
+ return tuple(symarray("x", self.n).tolist())
63
+
64
+ @cached_property
65
+ def var2(self):
66
+ return tuple(symarray("y", self.n).tolist())
67
+
68
+ @cached_property
69
+ def var3(self):
70
+ return tuple(symarray("z", self.n).tolist())
71
+
72
+ @cached_property
73
+ def var_r(self):
74
+ return symarray("r", 100)
75
+
76
+
77
+ _vars = _gvars()
46
78
 
47
79
 
48
80
  def count_sorted(mn, tp):
@@ -54,11 +86,11 @@ def count_sorted(mn, tp):
54
86
  return ct
55
87
 
56
88
 
57
- def E(p, k, varl=var_y[1:]):
58
- return elem_sym_poly(p, k, var_x[1:], varl)
89
+ def E(p, k, varl=_vars.var2[1:]):
90
+ return elem_sym_poly(p, k, _vars.var1[1:], varl)
59
91
 
60
92
 
61
- def single_variable(coeff_dict, varnum, var2=var2):
93
+ def single_variable(coeff_dict, varnum, var2=None):
62
94
  ret = {}
63
95
  for u in coeff_dict:
64
96
  if varnum - 1 < len(u):
@@ -78,7 +110,7 @@ def single_variable(coeff_dict, varnum, var2=var2):
78
110
  return ret
79
111
 
80
112
 
81
- def single_variable_down(coeff_dict, varnum):
113
+ def single_variable_down(coeff_dict, varnum, var2=_vars.var2):
82
114
  ret = {}
83
115
  for u in coeff_dict:
84
116
  if varnum - 1 < len(u):
@@ -98,7 +130,7 @@ def single_variable_down(coeff_dict, varnum):
98
130
  return ret
99
131
 
100
132
 
101
- def mult_poly(coeff_dict, poly, var_x=var_x, var_y=var2):
133
+ def mult_poly(coeff_dict, poly, var_x=_vars.var1, var_y=_vars.var2):
102
134
  if poly in var_x:
103
135
  return single_variable(coeff_dict, var_x.index(poly), var_y)
104
136
  elif isinstance(poly, Mul):
@@ -126,8 +158,8 @@ def mult_poly(coeff_dict, poly, var_x=var_x, var_y=var2):
126
158
 
127
159
 
128
160
  def mult_poly_down(coeff_dict, poly):
129
- if poly in var_x:
130
- return single_variable_down(coeff_dict, var_x.index(poly))
161
+ if poly in _vars.var1:
162
+ return single_variable_down(coeff_dict, _vars.var1.index(poly))
131
163
  elif isinstance(poly, Mul):
132
164
  ret = coeff_dict
133
165
  for a in poly.args:
@@ -168,7 +200,7 @@ def nilhecke_mult(coeff_dict1, coeff_dict2):
168
200
  return ret
169
201
 
170
202
 
171
- def forwardcoeff(u, v, perm, var2=var2, var3=var3):
203
+ def forwardcoeff(u, v, perm, var2=None, var3=None):
172
204
  th = theta(v)
173
205
  muv = uncode(th)
174
206
  vmun1 = mulperm(inverse([*v]), muv)
@@ -180,7 +212,7 @@ def forwardcoeff(u, v, perm, var2=var2, var3=var3):
180
212
  return 0
181
213
 
182
214
 
183
- def dualcoeff(u, v, perm, var2=var2, var3=var3):
215
+ def dualcoeff(u, v, perm, var2=None, var3=None):
184
216
  if u == (1, 2):
185
217
  vp = mulperm([*v], inverse(perm))
186
218
  if inv(vp) == inv(v) - inv(perm):
@@ -256,11 +288,11 @@ monom_to_vec = {}
256
288
 
257
289
 
258
290
  @cache
259
- def schubmult_one(perm1, perm2, var2=var2, var3=var3):
291
+ def schubmult_one(perm1, perm2, var2=None, var3=None):
260
292
  return schubmult({perm1: 1}, perm2, var2, var3)
261
293
 
262
294
 
263
- def schubmult(perm_dict, v, var2=var2, var3=var3):
295
+ def schubmult(perm_dict, v, var2=None, var3=None):
264
296
  vn1 = inverse(v)
265
297
  th = theta(vn1)
266
298
  if len(th) == 0:
@@ -319,7 +351,7 @@ def schubmult(perm_dict, v, var2=var2, var3=var3):
319
351
  return ret_dict
320
352
 
321
353
 
322
- def schubmult_down(perm_dict, v, var2=var2, var3=var3):
354
+ def schubmult_down(perm_dict, v, var2=None, var3=None):
323
355
  vn1 = inverse(v)
324
356
  th = theta(vn1)
325
357
  if th[0] == 0:
@@ -371,7 +403,7 @@ def schubmult_down(perm_dict, v, var2=var2, var3=var3):
371
403
  return ret_dict
372
404
 
373
405
 
374
- def poly_to_vec(poly, vec0=None):
406
+ def poly_to_vec(poly, vec0=None, var3=_vars.var3):
375
407
  global dimen, monom_to_vec, base_vec
376
408
  poly = expand(poly.xreplace({var3[1]: 0}))
377
409
 
@@ -396,12 +428,12 @@ def poly_to_vec(poly, vec0=None):
396
428
  return vec
397
429
 
398
430
 
399
- def shiftsub(pol):
431
+ def shiftsub(pol, var2=_vars.var2, var3=_vars.var3):
400
432
  subs_dict = dict([(var2[i], var2[i + 1]) for i in range(99)])
401
433
  return sympify(pol).subs(subs_dict)
402
434
 
403
435
 
404
- def shiftsubz(pol):
436
+ def shiftsubz(pol, var2=_vars.var2, var3=_vars.var3):
405
437
  subs_dict = dict([(var3[i], var3[i + 1]) for i in range(99)])
406
438
  return sympify(pol).subs(subs_dict)
407
439
 
@@ -449,7 +481,7 @@ def is_flat_term(term):
449
481
  return True
450
482
 
451
483
 
452
- def flatten_factors(term, var2=var3, var3=var3):
484
+ def flatten_factors(term, var2=_vars.var2, var3=_vars.var3):
453
485
  found_one = False
454
486
  if is_flat_term(term):
455
487
  return term, False
@@ -511,7 +543,7 @@ def fres(v):
511
543
  return s
512
544
 
513
545
 
514
- def split_mul(arg0, var2=var2, var3=var3):
546
+ def split_mul(arg0, var2=None, var3=None):
515
547
  monoms = SortedList()
516
548
 
517
549
  var2s = {fres(var2[i]): i for i in range(len(var2))}
@@ -705,7 +737,7 @@ def find_base_vectors(monom_list, monom_list_neg, var2, var3, depth):
705
737
  return ret, monom_list
706
738
 
707
739
 
708
- def compute_positive_rep(val, var2=var2, var3=var3, msg=False, do_pos_neg=True):
740
+ def compute_positive_rep(val, var2=None, var3=None, msg=False, do_pos_neg=True):
709
741
  notint = False
710
742
  try:
711
743
  int(expand(val))
@@ -970,7 +1002,7 @@ def is_hook(cd):
970
1002
  return False
971
1003
 
972
1004
 
973
- def div_diff(i, poly):
1005
+ def div_diff(i, poly, var2=_vars.var2):
974
1006
  return sympify(
975
1007
  sympy.div(sympy.sympify(poly - permy(poly, i)), sympy.sympify(var2[i] - var2[i + 1]))[0]
976
1008
  )
@@ -1007,13 +1039,15 @@ def skew_div_diff(u, w, poly):
1007
1039
  u2,
1008
1040
  v2,
1009
1041
  w2,
1010
- var2=var2,
1011
- var3=var3,
1042
+ var2=None,
1043
+ var3=None,
1012
1044
  msg=False,
1013
1045
  do_pos_neg=True,
1014
1046
  sign_only=False: hashkey(u2, v2, w2, var2, var3, msg, do_pos_neg, sign_only),
1015
1047
  )
1016
- def posify(val, u2, v2, w2, var2=var2, var3=var3, msg=False, do_pos_neg=True, sign_only=False):
1048
+ def posify(
1049
+ val, u2, v2, w2, var2=None, var3=None, msg=False, do_pos_neg=True, sign_only=False, n=_vars.n
1050
+ ):
1017
1051
  if inv(u2) + inv(v2) - inv(w2) == 0:
1018
1052
  return val
1019
1053
  cdv = code(v2)
@@ -1567,7 +1601,7 @@ def split_perms(perms):
1567
1601
  return perms2
1568
1602
 
1569
1603
 
1570
- def schubpoly(v, var2=var2, var3=var3, start_var=1):
1604
+ def schubpoly(v, var2=None, var3=None, start_var=1):
1571
1605
  n = 0
1572
1606
  for j in range(len(v) - 2, -1, -1):
1573
1607
  if v[j] > v[j + 1]:
@@ -1585,6 +1619,52 @@ def schubpoly(v, var2=var2, var3=var3, start_var=1):
1585
1619
  return ret
1586
1620
 
1587
1621
 
1588
- def permy(val, i):
1622
+ def permy(val, i, var2=_vars.var2):
1589
1623
  subsdict = {var2[i]: var2[i + 1], var2[i + 1]: var2[i]}
1590
1624
  return sympify(val).subs(subsdict)
1625
+
1626
+
1627
+ def schub_coprod(mperm, indices, var2=_vars.var2, var3=_vars.var3):
1628
+ indices = sorted(indices)
1629
+ subs_dict_coprod = {}
1630
+ k = len(indices)
1631
+ n = len(mperm)
1632
+ kcd = [indices[i] - i - 1 for i in range(len(indices))] + [n + 1 - k for i in range(k, n)]
1633
+ max_required = max([kcd[i] + i for i in range(len(kcd))])
1634
+ kcd2 = kcd + [0 for i in range(len(kcd), max_required)] + [0]
1635
+ N = len(kcd)
1636
+ kperm = permtrim(inverse(uncode(kcd2)))
1637
+ inv_kperm = inv(kperm)
1638
+ vn = symarray("soible", 100)
1639
+
1640
+ for i in range(1, N * 2 + 1):
1641
+ if i <= N:
1642
+ subs_dict_coprod[vn[i]] = var2[i]
1643
+ else:
1644
+ subs_dict_coprod[vn[i]] = var3[i - N]
1645
+
1646
+ coeff_dict = {tuple(kperm): 1}
1647
+ coeff_dict = schubmult(coeff_dict, mperm, vn, var2)
1648
+
1649
+ inverse_kperm = inverse(kperm)
1650
+
1651
+ ret_dict = {}
1652
+ for perm in coeff_dict:
1653
+ downperm = mulperm(list(perm), inverse_kperm)
1654
+ if inv(downperm) == inv(perm) - inv_kperm:
1655
+ flag = True
1656
+ for i in range(N):
1657
+ if downperm[i] > N:
1658
+ flag = False
1659
+ break
1660
+ if not flag:
1661
+ continue
1662
+ firstperm = downperm[0:N]
1663
+ secondperm = [downperm[i] - N for i in range(N, len(downperm))]
1664
+
1665
+ val = sympify(coeff_dict[perm]).subs(subs_dict_coprod)
1666
+
1667
+ key = (tuple(permtrim(firstperm)), tuple(permtrim(secondperm)))
1668
+ ret_dict[key] = val
1669
+
1670
+ return ret_dict
@@ -1,14 +1,9 @@
1
1
  import numpy as np
2
2
  import sympy
3
3
  import sys
4
- from ._vars import (
5
- var2,
6
- var3,
7
- var_x,
8
- var,
9
- var_r
10
- )
11
- from ._funcs import (
4
+
5
+ # from schubmult.schubmult_double._vars import var_x, var, var_r
6
+ from schubmult.schubmult_double._funcs import (
12
7
  mult_poly,
13
8
  mult_poly_down,
14
9
  schubmult,
@@ -17,7 +12,7 @@ from ._funcs import (
17
12
  posify,
18
13
  split_perms,
19
14
  )
20
- from symengine import expand, sympify
15
+ from symengine import expand, sympify, symarray
21
16
  from schubmult._base_argparse import schub_argparse
22
17
  from schubmult.perm_lib import (
23
18
  add_perm_dict,
@@ -33,6 +28,37 @@ from schubmult.perm_lib import (
33
28
  trimcode,
34
29
  )
35
30
 
31
+ from functools import cached_property
32
+
33
+
34
+ class _gvars:
35
+ @cached_property
36
+ def n(self):
37
+ return 100
38
+
39
+ # @cached_property
40
+ # def fvar(self):
41
+ # return 100
42
+
43
+ @cached_property
44
+ def var1(self):
45
+ return tuple(symarray("x", self.n).tolist())
46
+
47
+ @cached_property
48
+ def var2(self):
49
+ return tuple(symarray("y", self.n).tolist())
50
+
51
+ @cached_property
52
+ def var3(self):
53
+ return tuple(symarray("z", self.n).tolist())
54
+
55
+ @cached_property
56
+ def var_r(self):
57
+ return symarray("r", 100)
58
+
59
+
60
+ _vars = _gvars()
61
+
36
62
 
37
63
  def _display(val):
38
64
  print(val)
@@ -42,13 +68,20 @@ def _display_full(
42
68
  coeff_dict,
43
69
  args,
44
70
  formatter,
71
+ var2,
72
+ var3,
45
73
  posified=None,
46
74
  check_coeff_dict=None,
47
75
  kperm=None,
48
- var2=var2,
49
- var3=var3,
50
76
  N=None,
51
77
  ):
78
+ subs_dict2 = {}
79
+ for i in range(1, 100):
80
+ sm = var2[1]
81
+ for j in range(1, i):
82
+ sm += _vars.var_r[j]
83
+ subs_dict2[var2[i]] = sm
84
+ raw_result_dict = {}
52
85
  perms = args.perms
53
86
  mult = args.mult
54
87
  ascode = args.ascode
@@ -88,9 +121,9 @@ def _display_full(
88
121
 
89
122
  for i in range(1, 100):
90
123
  if i <= N:
91
- subs_dict[var[i]] = var2[i]
124
+ subs_dict[_vars.var1[i]] = var2[i]
92
125
  else:
93
- subs_dict[var[i]] = var3[i - N]
126
+ subs_dict[_vars.var1[i]] = var3[i - N]
94
127
 
95
128
  coeff_perms.sort(key=lambda x: (inv(x), *x))
96
129
 
@@ -117,6 +150,7 @@ def _display_full(
117
150
  else:
118
151
  width = max([len(str(perm[0]) + " " + str(perm[1])) for perm in perm_pairs])
119
152
 
153
+ subs_dict2 = {}
120
154
  for perm in coeff_perms:
121
155
  val = coeff_dict[perm]
122
156
  downperm = mulperm(list(perm), inverse_kperm)
@@ -132,13 +166,7 @@ def _display_full(
132
166
  secondperm = [downperm[i] - N for i in range(N, len(downperm))]
133
167
  val = sympify(val).subs(subs_dict)
134
168
 
135
- if same and display_positive:
136
- subs_dict2 = {}
137
- for i in range(1, 100):
138
- sm = var2[1]
139
- for j in range(1, i):
140
- sm += var_r[j]
141
- subs_dict2[var2[i]] = sm
169
+ if same and display_positive:
142
170
  val = expand(sympify(val).xreplace(subs_dict2))
143
171
 
144
172
  if val != 0:
@@ -164,7 +192,7 @@ def _display_full(
164
192
  exit(1)
165
193
  val = val2
166
194
  else:
167
- val = 0
195
+ val = 0
168
196
  if val != 0:
169
197
  if not ascode:
170
198
  width2 = (
@@ -172,18 +200,26 @@ def _display_full(
172
200
  - len(str(permtrim(firstperm)))
173
201
  - len(str(permtrim(secondperm)))
174
202
  )
175
- _display(
176
- f"{tuple(permtrim(firstperm))}{' ':>{width2}}{tuple(permtrim(secondperm))} {formatter(val)}"
177
- )
203
+ raw_result_dict[
204
+ (tuple(permtrim(firstperm)), tuple(permtrim(secondperm)))
205
+ ] = val
206
+ if formatter:
207
+ _display(
208
+ f"{tuple(permtrim(firstperm))}{' ':>{width2}}{tuple(permtrim(secondperm))} {formatter(val)}"
209
+ )
178
210
  else:
179
211
  width2 = (
180
212
  width
181
213
  - len(str(trimcode(firstperm)))
182
214
  - len(str(trimcode(secondperm)))
183
215
  )
184
- _display(
185
- f"{trimcode(firstperm)}{' ':>{width2}}{trimcode(secondperm)} {formatter(val)}"
186
- )
216
+ raw_result_dict[
217
+ (tuple(trimcode(firstperm)), tuple(trimcode(secondperm)))
218
+ ] = val
219
+ if formatter:
220
+ _display(
221
+ f"{trimcode(firstperm)}{' ':>{width2}}{trimcode(secondperm)} {formatter(val)}"
222
+ )
187
223
  else:
188
224
  if ascode:
189
225
  width = max([len(str(trimcode(perm))) for perm in coeff_dict.keys()])
@@ -192,7 +228,7 @@ def _display_full(
192
228
 
193
229
  coeff_perms = list(coeff_dict.keys())
194
230
  coeff_perms.sort(key=lambda x: (inv(x), *x))
195
-
231
+
196
232
  for perm in coeff_perms:
197
233
  val = coeff_dict[perm]
198
234
  if val != 0:
@@ -207,7 +243,7 @@ def _display_full(
207
243
  for i in range(1, 100):
208
244
  sm = var2[1]
209
245
  for j in range(1, i):
210
- sm += var_r[j]
246
+ sm += _vars.var_r[j]
211
247
  subs_dict[var2[i]] = sm
212
248
  val = expand(sympify(coeff_dict[perm]).xreplace(subs_dict))
213
249
  else:
@@ -255,14 +291,29 @@ def _display_full(
255
291
  exit(1)
256
292
  if val != 0:
257
293
  if ascode:
258
- _display(f"{str(trimcode(perm)):>{width}} {formatter(val)}")
294
+ raw_result_dict[tuple(trimcode(perm))] = val
295
+ if formatter:
296
+ _display(f"{str(trimcode(perm)):>{width}} {formatter(val)}")
259
297
  else:
260
- _display(f"{str(perm):>{width}} {formatter(val)}")
261
-
262
-
263
- def main():
264
- global var2, var3
298
+ raw_result_dict[tuple(perm)] = val
299
+ if formatter:
300
+ _display(f"{str(perm):>{width}} {formatter(val)}")
301
+ return raw_result_dict
302
+
303
+
304
+ def main(argv=None):
305
+ import logging
306
+
307
+ logging.basicConfig(
308
+ level=logging.ERROR, format="%(asctime)s %(levelname)s %(message) %(module) s"
309
+ )
310
+ logger = logging.getLogger(__name__)
311
+ logger.log(logging.DEBUG, f"main {argv=}")
312
+ if argv is None:
313
+ argv = sys.argv
265
314
  try:
315
+ var2 = tuple(symarray("y", 100).tolist())
316
+ var3 = tuple(symarray("z", 100).tolist())
266
317
  sys.setrecursionlimit(1000000)
267
318
 
268
319
  # TEMP
@@ -271,6 +322,7 @@ def main():
271
322
  args, formatter = schub_argparse(
272
323
  "schubmult_double",
273
324
  "Compute coefficients of product of double Schubert polynomials in the same or different sets of coefficient variables",
325
+ argv=argv[1:],
274
326
  yz=True,
275
327
  )
276
328
 
@@ -287,9 +339,11 @@ def main():
287
339
  display_positive = args.display_positive
288
340
  pr = args.pr
289
341
 
342
+ # logger.log(logging.DEBUG, f"main boing 1 {var2=}{var3=}{same=}")
290
343
  if same:
344
+ # logger.log(logging.DEBUG, f"main OOO {same=}")
291
345
  var3 = var2
292
-
346
+ # logger.log(logging.DEBUG, f"main boing 2 {var2=}{var3=}{same=}")
293
347
  posified = False
294
348
  if coprod:
295
349
  if ascode:
@@ -307,10 +361,11 @@ def main():
307
361
 
308
362
  kperm = inverse(uncode(kcd))
309
363
  coeff_dict = {tuple(kperm): 1}
310
- coeff_dict = schubmult(coeff_dict, perms[0], var, var2)
364
+ coeff_dict = schubmult(coeff_dict, perms[0], _vars.var1, var2)
311
365
 
312
- if pr:
313
- _display_full(
366
+ if pr or formatter is None:
367
+ # logger.log(logging.DEBUG, f"main {var2=}{var3=}{same=}")
368
+ return _display_full(
314
369
  coeff_dict,
315
370
  args,
316
371
  formatter,
@@ -339,13 +394,13 @@ def main():
339
394
  coeff_dict = {perms[0]: 1}
340
395
  check_coeff_dict = {perms[0]: 1}
341
396
 
342
- if mult:
343
- for v in var2:
344
- globals()[str(v)] = v
345
- for v in var3:
346
- globals()[str(v)] = v
347
- for v in var_x:
348
- globals()[str(v)] = v
397
+ # if mult:
398
+ # for v in var2:
399
+ # ()[str(v)] = v
400
+ # for v in var3:
401
+ # globals()[str(v)] = v
402
+ # for v in _vars.var1:
403
+ # globals()[str(v)] = v
349
404
 
350
405
  if down:
351
406
  for perm in orig_perms[1:]:
@@ -356,6 +411,7 @@ def main():
356
411
  else:
357
412
  for perm in orig_perms[1:]:
358
413
  check_coeff_dict = schubmult(check_coeff_dict, perm, var2, var3)
414
+ # coeff_dict = check_coeff_dict
359
415
  if mult:
360
416
  mul_exp = eval(mulstring)
361
417
  check_coeff_dict = mult_poly(check_coeff_dict, mul_exp)
@@ -398,19 +454,21 @@ def main():
398
454
  elif not posified:
399
455
  coeff_dict = check_coeff_dict
400
456
 
401
- if pr:
402
- _display_full(
457
+ if pr or formatter is None:
458
+ raw_result_dict = _display_full(
403
459
  coeff_dict,
404
460
  args,
405
461
  formatter,
462
+ var2,
463
+ var3,
406
464
  posified=posified,
407
465
  check_coeff_dict=check_coeff_dict,
408
- var2=var2,
409
- var3=var3,
410
466
  )
467
+ if formatter is None:
468
+ return raw_result_dict
411
469
  except BrokenPipeError:
412
470
  pass
413
471
 
414
472
 
415
473
  if __name__ == "__main__":
416
- main()
474
+ main(sys.argv)
@@ -1,10 +1,13 @@
1
1
  from ._funcs import (
2
2
  schubmult,
3
- mult_poly
3
+ mult_poly,
4
+ schub_coprod,
4
5
  )
5
6
 
6
7
 
7
8
  __all__ = [
8
9
  "schubmult",
9
- "mult_poly"
10
+ "mult_poly",
11
+ "schub_coprod"
10
12
  ]
13
+
@@ -2,4 +2,4 @@ import sys
2
2
  from ._script import main
3
3
 
4
4
  if __name__ == "__main__":
5
- sys.exit(main())
5
+ sys.exit(main(sys.argv))
@@ -1,4 +1,3 @@
1
- from ._vars import var_x
2
1
  from schubmult.perm_lib import (
3
2
  elem_sym_perms,
4
3
  add_perm_dict,
@@ -10,7 +9,21 @@ from schubmult.perm_lib import (
10
9
  mulperm,
11
10
  uncode,
12
11
  )
13
- from symengine import Add, Mul, Pow
12
+ from symengine import Add, Mul, Pow, symarray
13
+ from functools import cached_property
14
+
15
+
16
+ class _gvars:
17
+ @cached_property
18
+ def n(self):
19
+ return 100
20
+
21
+ @cached_property
22
+ def var_x(self):
23
+ return tuple(symarray("x", self.n).tolist())
24
+
25
+
26
+ _vars = _gvars()
14
27
 
15
28
 
16
29
  def single_variable(coeff_dict, varnum):
@@ -29,7 +42,7 @@ def single_variable(coeff_dict, varnum):
29
42
  return ret
30
43
 
31
44
 
32
- def mult_poly(coeff_dict, poly, var_x=var_x):
45
+ def mult_poly(coeff_dict, poly, var_x=_vars.var_x):
33
46
  if poly in var_x:
34
47
  return single_variable(coeff_dict, var_x.index(poly))
35
48
  elif isinstance(poly, Mul):
@@ -100,12 +113,44 @@ def schubmult(perm_dict, v):
100
113
  if vdiff + udiff == th[index]:
101
114
  if up2 not in newpathsums:
102
115
  newpathsums[up2] = {}
103
- newpathsums[up2][v2] = (
104
- newpathsums[up2].get(v2, 0) + addsumval
105
- )
116
+ newpathsums[up2][v2] = newpathsums[up2].get(v2, 0) + addsumval
106
117
  vpathsums = newpathsums
107
118
  toget = tuple(vmu)
108
- ret_dict = add_perm_dict(
109
- {ep: vpathsums[ep].get(toget, 0) for ep in vpathsums}, ret_dict
110
- )
119
+ ret_dict = add_perm_dict({ep: vpathsums[ep].get(toget, 0) for ep in vpathsums}, ret_dict)
120
+ return ret_dict
121
+
122
+
123
+ def schub_coprod(perm, indices):
124
+ mperm = tuple(perm)
125
+ indices = sorted(indices)
126
+ ret_dict = {}
127
+ k = len(indices)
128
+ n = len(mperm)
129
+ kcd = [indices[i] - i - 1 for i in range(len(indices))] + [n + 1 - k for i in range(k, n)]
130
+ max_required = max([kcd[i] + i for i in range(len(kcd))])
131
+ kcd2 = kcd + [0 for i in range(len(kcd), max_required)] + [0]
132
+ N = len(kcd)
133
+ kperm = permtrim(inverse(uncode(kcd2)))
134
+ coeff_dict = {tuple(kperm): 1}
135
+ coeff_dict = schubmult(coeff_dict, tuple(mperm))
136
+
137
+ inv_kperm = inv(kperm)
138
+ inverse_kperm = inverse(kperm)
139
+ # total_sum = 0
140
+ for perm, val in coeff_dict.items():
141
+ if val == 0:
142
+ continue
143
+ pperm = [*perm]
144
+ downperm = mulperm(pperm, inverse_kperm)
145
+ if inv(downperm) == inv(pperm) - inv_kperm:
146
+ flag = True
147
+ for i in range(N):
148
+ if downperm[i] > N:
149
+ flag = False
150
+ break
151
+ if not flag:
152
+ continue
153
+ firstperm = tuple(permtrim((list(downperm[0:N]))))
154
+ secondperm = tuple(permtrim(([downperm[i] - N for i in range(N, len(downperm))])))
155
+ ret_dict[(firstperm, secondperm)] = val
111
156
  return ret_dict