schubmult 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. schubmult/__init__.py +1 -0
  2. schubmult/_base_argparse.py +174 -0
  3. schubmult/perm_lib.py +999 -0
  4. schubmult/sage_integration/__init__.py +25 -0
  5. schubmult/sage_integration/_fast_double_schubert_polynomial_ring.py +528 -0
  6. schubmult/sage_integration/_fast_schubert_polynomial_ring.py +356 -0
  7. schubmult/sage_integration/_indexing.py +44 -0
  8. schubmult/schubmult_double/__init__.py +18 -0
  9. schubmult/schubmult_double/__main__.py +5 -0
  10. schubmult/schubmult_double/_funcs.py +1590 -0
  11. schubmult/schubmult_double/_script.py +407 -0
  12. schubmult/schubmult_double/_vars.py +16 -0
  13. schubmult/schubmult_py/__init__.py +10 -0
  14. schubmult/schubmult_py/__main__.py +5 -0
  15. schubmult/schubmult_py/_funcs.py +111 -0
  16. schubmult/schubmult_py/_script.py +115 -0
  17. schubmult/schubmult_py/_vars.py +3 -0
  18. schubmult/schubmult_q/__init__.py +12 -0
  19. schubmult/schubmult_q/__main__.py +5 -0
  20. schubmult/schubmult_q/_funcs.py +304 -0
  21. schubmult/schubmult_q/_script.py +157 -0
  22. schubmult/schubmult_q/_vars.py +18 -0
  23. schubmult/schubmult_q_double/__init__.py +14 -0
  24. schubmult/schubmult_q_double/__main__.py +5 -0
  25. schubmult/schubmult_q_double/_funcs.py +507 -0
  26. schubmult/schubmult_q_double/_script.py +337 -0
  27. schubmult/schubmult_q_double/_vars.py +21 -0
  28. schubmult-2.0.0.dist-info/METADATA +455 -0
  29. schubmult-2.0.0.dist-info/RECORD +36 -0
  30. schubmult-2.0.0.dist-info/WHEEL +5 -0
  31. schubmult-2.0.0.dist-info/entry_points.txt +5 -0
  32. schubmult-2.0.0.dist-info/licenses/LICENSE +674 -0
  33. schubmult-2.0.0.dist-info/top_level.txt +2 -0
  34. tests/__init__.py +0 -0
  35. tests/test_fast_double_schubert.py +145 -0
  36. tests/test_fast_schubert.py +38 -0
@@ -0,0 +1,304 @@
1
+ from ._vars import (
2
+ var_x,
3
+ )
4
+ from symengine import Add, Mul, Pow
5
+ from schubmult.perm_lib import (
6
+ elem_sym_perms_q,
7
+ add_perm_dict,
8
+ compute_vpathdicts,
9
+ inverse,
10
+ strict_theta,
11
+ medium_theta,
12
+ permtrim,
13
+ inv,
14
+ mulperm,
15
+ code,
16
+ uncode,
17
+ double_elem_sym_q,
18
+ q_var,
19
+ )
20
+
21
+ # from symengine import sympify, Add, Mul, Pow, symarray, Symbol, expand
22
+ # from schubmult._base_argparse import schub_argparse
23
+ # from schubmult.perm_lib import (
24
+ # trimcode,
25
+ # elem_sym_perms_q,
26
+ # add_perm_dict,
27
+ # compute_vpathdicts,
28
+ # inverse,
29
+ # strict_theta,
30
+ # medium_theta,
31
+ # permtrim,
32
+ # inv,
33
+ # mulperm,
34
+ # code,
35
+ # uncode,
36
+ # double_elem_sym_q,
37
+ # longest_element,
38
+ # check_blocks,
39
+ # is_parabolic,
40
+ # q_vector,
41
+ # omega,
42
+ # count_less_than,
43
+ # q_var,
44
+ # sg,
45
+ # n,
46
+ # )
47
+ # import numpy as np
48
+ # from schubmult.schubmult_q_double import factor_out_q_keep_factored
49
+
50
+
51
def single_variable(coeff_dict, varnum, var_q=q_var):
    """Multiply every entry of ``coeff_dict`` by the single variable ``x_varnum``.

    ``coeff_dict`` maps permutations to coefficients. For each permutation the
    product is assembled from ``elem_sym_perms_q`` calls of degree 1. Only
    result triples ``(perm, udiff, mul_val)`` with ``udiff == 1`` are kept.

    - Terms obtained with ``varnum`` variables are added.
    - When ``varnum > 1``, terms obtained with ``varnum - 1`` variables are
      subtracted.

    This reads as ``x_k = e_1(x_1..x_k) - e_1(x_1..x_{k-1})``, assuming
    ``elem_sym_perms_q(u, 1, k, q)`` expands ``e_1`` in the first ``k``
    variables. TODO: confirm against ``perm_lib``.

    Args:
        coeff_dict: mapping ``{permutation: coefficient}``.
        varnum: index of the variable to multiply by.
        var_q: quantum variables forwarded to ``elem_sym_perms_q``.

    Returns:
        A new dict ``{permutation: coefficient}``.
    """
    product = {}
    for base_perm in coeff_dict:
        coeff = coeff_dict[base_perm]
        upper_terms = elem_sym_perms_q(base_perm, 1, varnum, var_q)
        lower_terms = (
            elem_sym_perms_q(base_perm, 1, varnum - 1, var_q) if varnum > 1 else []
        )
        for new_perm, degree, factor in upper_terms:
            if degree != 1:
                continue
            product[new_perm] = product.get(new_perm, 0) + coeff * factor
        for new_perm, degree, factor in lower_terms:
            if degree != 1:
                continue
            product[new_perm] = product.get(new_perm, 0) - coeff * factor
    return product
65
+
66
+
67
def mult_poly(coeff_dict, poly, var_x=var_x, var_q=q_var):
    """Multiply the expansion ``coeff_dict`` by the symengine expression ``poly``.

    Dispatch on the shape of ``poly``:

    * an element of ``var_x``: delegate to ``single_variable`` with its index
      in ``var_x``;
    * ``Mul``: multiply by each factor in turn;
    * ``Pow``: multiply by the base ``int(exponent)`` times;
    * ``Add``: multiply by each summand and combine with ``add_perm_dict``;
    * anything else: treated as a scalar multiplying every coefficient.

    NOTE(review): a ``Pow`` whose integer exponent is <= 0 runs zero loop
    iterations and returns ``coeff_dict`` unchanged. Confirm that callers
    never pass such exponents.

    Args:
        coeff_dict: mapping ``{permutation: coefficient}``.
        poly: symengine expression to multiply by.
        var_x: sequence of x-variables recognised as generators.
        var_q: quantum variables forwarded down to ``single_variable``.

    Returns:
        A dict ``{permutation: coefficient}``. For ``Mul``/``Pow`` inputs with
        no factors or iterations, this may be ``coeff_dict`` itself.
    """
    if poly in var_x:
        return single_variable(coeff_dict, var_x.index(poly), var_q=var_q)

    if isinstance(poly, Mul):
        result = coeff_dict
        for factor in poly.args:
            result = mult_poly(result, factor, var_x, var_q=var_q)
        return result

    if isinstance(poly, Pow):
        base, power = poly.args[0], int(poly.args[1])
        result = coeff_dict
        for _ in range(power):
            result = mult_poly(result, base, var_x, var_q=var_q)
        return result

    if isinstance(poly, Add):
        result = {}
        for term in poly.args:
            result = add_perm_dict(result, mult_poly(coeff_dict, term, var_x, var_q=var_q))
        return result

    # Scalar (no generator involved at this level): scale every coefficient.
    return {perm: poly * coeff_dict[perm] for perm in coeff_dict}
92
+
93
def _max_path_gap(vpathdict, target):
    """Return ``max(0, target - vdiff)`` over every step stored in ``vpathdict``.

    ``vpathdict`` is one level of the ``compute_vpathdicts`` output: a dict
    whose values are iterables of ``(v2, vdiff, s)`` triples.
    """
    best = 0
    for steps in vpathdict.values():
        for _v2, vdiff, _s in steps:
            if target - vdiff > best:
                best = target - vdiff
    return best


def schubmult_db(perm_dict, v, q_var=q_var):
    """Multiply every entry of ``perm_dict`` by the class indexed by ``v``.

    This is the default (non-``args.slow``) path of the ``schubmult_q``
    script. Compared with ``schubmult`` it uses ``medium_theta``. Whenever two
    consecutive theta entries are equal, it processes both levels in a single
    step through ``double_elem_sym_q``.

    Args:
        perm_dict: mapping ``{permutation tuple: coefficient}``.
        v: permutation tuple. ``(1, 2)`` is returned-through immediately
            (presumably the trimmed identity).
        q_var: quantum variables forwarded to ``double_elem_sym_q`` and
            ``elem_sym_perms_q``.

    Returns:
        A dict ``{permutation tuple: coefficient}``. ``perm_dict`` itself is
        returned when ``v == (1, 2)`` or when the theta vector of ``v`` has
        no nonzero entry.

    NOTE(review): the pairing logic assumes ``medium_theta`` never yields
    three equal consecutive entries. A third equal entry would be skipped by
    the ``continue`` at the top of the level loop. Confirm against
    ``medium_theta``.
    """
    if v == (1, 2):
        return perm_dict
    th = medium_theta(inverse(v))
    # An empty or all-zero theta means there is nothing to multiply by.
    # Checking with any() also keeps the trailing-zero pop below from
    # exhausting the list and raising IndexError on th[-1].
    if not any(th):
        return perm_dict
    while th[-1] == 0:
        th.pop()
    mu = permtrim(uncode(th))
    vmu = permtrim(mulperm([*v], mu))
    inv_vmu = inv(vmu)
    inv_mu = inv(mu)

    thL = len(th)
    vpathdicts = compute_vpathdicts(th, vmu, True)
    toget = tuple(vmu)  # path-sum entry read out at the end of every u
    ret_dict = {}
    for u, val in perm_dict.items():
        inv_u = inv(u)
        vpathsums = {u: {(1, 2): val}}
        for index in range(thL):
            if index > 0 and th[index - 1] == th[index]:
                # This level was already consumed together with the previous
                # one by the paired (double_elem_sym_q) update.
                continue
            mx_th = _max_path_gap(vpathdicts[index], th[index])
            if index < thL - 1 and th[index] == th[index + 1]:
                # Two equal consecutive levels: advance through both at once.
                mx_th1 = _max_path_gap(vpathdicts[index + 1], th[index + 1])
                newpathsums = {}
                for up in vpathsums:
                    newpathsums0 = {}
                    # newperms maps (up1, udiff1, mul_val1) -> iterable of
                    # (up2, udiff2, mul_val2) for the second level.
                    newperms = double_elem_sym_q(up, mx_th, mx_th1, th[index], q_var)
                    for vp in vpathdicts[index]:
                        sumval = vpathsums[up].get(vp, 0)
                        if sumval == 0:
                            continue
                        for v2, vdiff2, s2 in vpathdicts[index][vp]:
                            for key1 in newperms:
                                up1, udiff1, mul_val1 = key1
                                if key1 not in newpathsums0:
                                    newpathsums0[key1] = {}
                                if udiff1 + vdiff2 == th[index]:
                                    newpathsums0[key1][v2] = (
                                        newpathsums0[key1].get(v2, 0)
                                        + s2 * sumval * mul_val1
                                    )

                    for key1, partial in newpathsums0.items():
                        for vp in vpathdicts[index + 1]:
                            sumval = partial.get(vp, 0)
                            if sumval == 0:
                                continue
                            for v2, vdiff2, s2 in vpathdicts[index + 1][vp]:
                                for up2, udiff2, mul_val2 in newperms[key1]:
                                    if up2 not in newpathsums:
                                        newpathsums[up2] = {}
                                    if udiff2 + vdiff2 == th[index + 1]:
                                        newpathsums[up2][v2] = (
                                            newpathsums[up2].get(v2, 0)
                                            + s2 * sumval * mul_val2
                                        )
            else:
                newpathsums = {}
                for up in vpathsums:
                    inv_up = inv(up)
                    newperms = elem_sym_perms_q(
                        up,
                        min(mx_th, (inv_mu - (inv_up - inv_u)) - inv_vmu),
                        th[index],
                        q_var,
                    )
                    for up2, udiff, mul_val in newperms:
                        if up2 not in newpathsums:
                            newpathsums[up2] = {}
                        for vp in vpathdicts[index]:
                            sumval = vpathsums[up].get(vp, 0)
                            if sumval == 0:
                                continue
                            for v2, vdiff, s in vpathdicts[index][vp]:
                                if udiff + vdiff == th[index]:
                                    newpathsums[up2][v2] = (
                                        newpathsums[up2].get(v2, 0)
                                        + s * sumval * mul_val
                                    )
            vpathsums = newpathsums
        ret_dict = add_perm_dict(
            {ep: vpathsums[ep].get(toget, 0) for ep in vpathsums}, ret_dict
        )
    return ret_dict
195
+
196
+
197
def schubmult(perm_dict, v, q_var=q_var):
    """Multiply every entry of ``perm_dict`` by the class indexed by ``v``.

    This is the path the ``schubmult_q`` script takes when ``args.slow`` is
    set; ``schubmult_db`` is the default. It walks the levels of
    ``strict_theta(inverse(v))`` one at a time using ``elem_sym_perms_q``.

    Args:
        perm_dict: mapping ``{permutation tuple: coefficient}``.
        v: permutation tuple (the script passes it through ``permtrim``).
        q_var: quantum variables forwarded to ``elem_sym_perms_q``. The
            default is the ``q_var`` imported from ``schubmult.perm_lib``.
            It is exposed for parity with ``schubmult_db``.

    Returns:
        A dict ``{permutation tuple: coefficient}``. ``perm_dict`` itself is
        returned when the theta vector of ``v`` has no nonzero entry.
    """
    th = strict_theta(inverse(v))
    # An empty or all-zero theta means there is nothing to multiply by.
    # Checking with any() also keeps the trailing-zero pop below from
    # exhausting the list and raising IndexError on th[-1].
    if not any(th):
        return perm_dict
    mu = permtrim(uncode(th))
    vmu = permtrim(mulperm([*v], mu))
    inv_vmu = inv(vmu)
    inv_mu = inv(mu)
    while th[-1] == 0:
        th.pop()
    thL = len(th)
    vpathdicts = compute_vpathdicts(th, vmu, True)
    toget = tuple(vmu)  # path-sum entry read out at the end of every u
    ret_dict = {}
    for u, val in perm_dict.items():
        inv_u = inv(u)
        vpathsums = {u: {(1, 2): val}}
        for index in range(thL):
            # Largest th[index] - vdiff over this level's path steps, floored at 0.
            mx_th = max(
                [0]
                + [
                    th[index] - vdiff
                    for steps in vpathdicts[index].values()
                    for _v2, vdiff, _s in steps
                ]
            )
            newpathsums = {}
            for up in vpathsums:
                inv_up = inv(up)
                newperms = elem_sym_perms_q(
                    up,
                    min(mx_th, (inv_mu - (inv_up - inv_u)) - inv_vmu),
                    th[index],
                    q_var,
                )
                for up2, udiff, mul_val in newperms:
                    if up2 not in newpathsums:
                        newpathsums[up2] = {}
                    for vp in vpathdicts[index]:
                        sumval = vpathsums[up].get(vp, 0)
                        if sumval == 0:
                            continue
                        for v2, vdiff, s in vpathdicts[index][vp]:
                            if udiff + vdiff == th[index]:
                                newpathsums[up2][v2] = (
                                    newpathsums[up2].get(v2, 0) + s * sumval * mul_val
                                )
            vpathsums = newpathsums
        ret_dict = add_perm_dict(
            {ep: vpathsums[ep].get(toget, 0) for ep in vpathsums}, ret_dict
        )
    return ret_dict
244
+
245
+
246
+
247
+
248
def grass_q_replace(perm, k, d, n):
    """Rebuild ``perm`` from a modified three-block labelling, or return None.

    Returns None when ``d > k``, or when any of ``code(perm)[k - d:k]`` is
    missing or smaller than ``d``. Otherwise:

    1. ``perm`` is padded with fixed points up to length ``n``.
    2. Every value in positions ``k..n-1`` is labelled 2; all other values
       are labelled 0.
    3. The ``d`` largest 0-labelled values and the ``d`` smallest 2-labelled
       values are relabelled 1.
    4. The result places the 0-values in increasing order from position 0,
       the 1-values from position ``k - d``, and the 2-values from position
       ``k + d``. It is then ``permtrim``-ed and returned as a tuple.

    Going by the name, this is presumably a quantum-to-classical
    Grassmannian replacement with degree ``d``. TODO: confirm the intended
    meaning.
    """
    if d > k:
        return None
    perm_code = code(perm)
    if any(pos >= len(perm_code) or perm_code[pos] < d for pos in range(k - d, k)):
        return None

    padded = list(perm) + list(range(len(perm) + 1, n + 1))
    # labels[value - 1] is the block label (0, 1 or 2) of that value.
    labels = [0] * n
    for pos in range(k, n):
        labels[padded[pos] - 1] = 2

    def _relabel(positions, old_label):
        # Turn the first d entries carrying old_label (in scan order) into 1s.
        left = d
        for pos in positions:
            if left == 0:
                break
            if labels[pos] == old_label:
                labels[pos] = 1
                left -= 1

    _relabel(reversed(range(len(labels))), 0)
    _relabel(range(len(labels)), 2)

    new_perm = [0] * n
    next_slot = [0, k - d, k + d]  # first free position of blocks 0, 1, 2
    for value, label in enumerate(labels, start=1):
        new_perm[next_slot[label]] = value
        next_slot[label] += 1
    return tuple(permtrim(new_perm))
292
+
293
+
294
def to_two_step(perm, k1, k2, n):
    """Return the two-step (0/1/2) block labelling of the values ``1..n``.

    ``perm`` is padded with fixed points up to length ``n``. The value at
    position ``pos`` gets label 0 if ``pos < k1``, 1 if ``k1 <= pos < k2``,
    and 2 otherwise. The result is indexed by value: ``rep[j]`` is the label
    of value ``j + 1``.
    """
    padded = list(perm) + list(range(len(perm) + 1, n + 1))
    rep = [0] * n
    for pos in range(n):
        if pos < k1:
            label = 0
        elif pos < k2:
            label = 1
        else:
            label = 2
        rep[padded[pos] - 1] = label
    return rep
@@ -0,0 +1,157 @@
1
+ from ._funcs import (
2
+ schubmult,
3
+ schubmult_db,
4
+ mult_poly,
5
+ )
6
+ from symengine import sympify
7
+ from schubmult._base_argparse import schub_argparse
8
+ from schubmult.perm_lib import (
9
+ trimcode,
10
+ permtrim,
11
+ inv,
12
+ mulperm,
13
+ uncode,
14
+ longest_element,
15
+ check_blocks,
16
+ is_parabolic,
17
+ q_vector,
18
+ omega,
19
+ count_less_than,
20
+ q_var,
21
+ sg,
22
+ )
23
+ import numpy as np
24
+ from schubmult.schubmult_q_double import factor_out_q_keep_factored
25
+
26
+
27
def _project_to_parabolic(coeff_dict, parabolic_index):
    """Re-index ``coeff_dict`` for the parabolic index set ``parabolic_index``.

    Each coefficient is split into q-parts with ``factor_out_q_keep_factored``.
    A q-part survives only when all of the following hold:

    - ``omega`` is 0 or -1 at every index in ``parabolic_index``;
    - ``check_blocks`` accepts its q vector;
    - ``permtrim(w_1 * w_P' * w_P)`` passes ``is_parabolic``.

    Here ``w_P`` is the longest element for ``parabolic_index`` and ``w_P'``
    is the longest element for the indices where ``omega`` is 0. Surviving
    q-parts are rewritten in ``q_var`` indexed over the non-parabolic
    positions and summed per resulting permutation.

    Returns:
        A dict ``{permutation tuple: coefficient}``.
    """
    w_P = longest_element(parabolic_index)
    projected = {}
    for w_1, val in coeff_dict.items():
        q_dict = factor_out_q_keep_factored(val)
        for q_part, q_val_part in q_dict.items():
            qv = q_vector(q_part)
            # Indices where omega is 0 form the smaller parabolic set for w_P'.
            # Any omega value other than 0 or -1 discards this q-part.
            parabolic_index2 = []
            good = True
            for p in parabolic_index:
                om = omega(p, qv)
                if om == 0:
                    parabolic_index2.append(p)
                elif om != -1:
                    good = False
                    break
            if not good:
                continue
            w_P_prime = longest_element(parabolic_index2)
            if not check_blocks(qv, parabolic_index):
                continue
            w = permtrim(mulperm(mulperm([*w_1], w_P_prime), w_P))
            if not is_parabolic(w, parabolic_index):
                continue
            w = tuple(w)

            new_q_part = np.prod(
                [
                    q_var[index + 1 - count_less_than(parabolic_index, index + 1)] ** exponent
                    for index, exponent in enumerate(qv)
                    if index + 1 not in parabolic_index
                ]
            )
            # np.prod of an empty list is the float 1.0, so coerce numeric
            # results back to int. Symbolic products presumably cannot be
            # converted and are kept as-is (best effort, hence the broad except).
            try:
                new_q_part = int(new_q_part)
            except Exception:
                pass
            projected[w] = projected.get(w, 0) + new_q_part * q_val_part
    return projected


def _display_full(coeff_dict, args, formatter):
    """Print each nonzero coefficient of ``coeff_dict`` as ``<perm> <value>``.

    When ``args.parabolic`` is non-empty, the dict is first re-indexed by
    ``_project_to_parabolic``. Entries are sorted by ``(inv(perm), *perm)``.
    Each value is passed through ``sympify(...).expand()`` and skipped if it
    equals 0; otherwise it is printed with ``formatter``. The permutation is
    shown as ``trimcode(perm)`` when ``args.ascode`` is set.
    """
    ascode = args.ascode
    parabolic_index = [int(s) for s in args.parabolic]
    if parabolic_index:
        coeff_dict = _project_to_parabolic(coeff_dict, parabolic_index)

    for perm in sorted(coeff_dict, key=lambda x: (inv(x), *x)):
        val = sympify(coeff_dict[perm]).expand()
        if val != 0:
            shown = trimcode(perm) if ascode else perm
            print(f"{str(shown)} {formatter(val)}")
87
+
88
+
89
def main():
    """Command-line entry point of the ``schubmult_q`` script.

    Steps:

    1. Parse arguments with ``schub_argparse``.
    2. Convert the permutations (or codes, when ``args.ascode``) to integers.
    3. Multiply them left to right: ``schubmult`` when ``args.slow``,
       otherwise ``schubmult_db``.
    4. Optionally multiply by the polynomial built from ``args.mult``.
    5. Print the result via ``_display_full`` when ``args.pr`` is set.

    Invalid input prints a message and exits with status 1. A
    ``BrokenPipeError`` while running is swallowed.
    """
    # sys.exit is always available. The bare exit() builtin is injected by
    # the site module and is not guaranteed to exist in programs.
    import sys

    try:
        args, formatter = schub_argparse(
            "schubmult_q",
            "Compute products of quantum Schubert polynomials",
            quantum=True,
        )

        mult = args.mult is not None
        mulstring = " ".join(args.mult) if mult else ""

        perms = args.perms

        for perm in perms:
            try:
                for i in range(len(perm)):
                    perm[i] = int(perm[i])
            except (TypeError, ValueError):
                print("Permutations must have integer values")
                raise

        ascode = args.ascode
        pr = args.pr
        parabolic_index = [int(s) for s in args.parabolic]
        parabolic = len(parabolic_index) != 0
        slow = args.slow

        if parabolic and len(perms) != 2:
            print("Only two permutations supported for parabolic.")
            sys.exit(1)

        if ascode:
            # Convert codes to permutations in place (perms is args.perms).
            for i, perm in enumerate(perms):
                perms[i] = uncode(perm)

        if parabolic:
            for p in parabolic_index:
                index = p - 1
                if sg(index, perms[0]) == 1 or sg(index, perms[1]) == 1:
                    print(
                        "Parabolic given but elements are not minimal length coset representatives."
                    )
                    sys.exit(1)

        coeff_dict = {tuple(permtrim([*perms[0]])): 1}
        multiply = schubmult if slow else schubmult_db
        for perm in perms[1:]:
            coeff_dict = multiply(coeff_dict, tuple(permtrim([*perm])))

        if mult:
            coeff_dict = mult_poly(coeff_dict, sympify(mulstring))

        if pr:
            _display_full(coeff_dict, args, formatter)
    except BrokenPipeError:
        # The output consumer closed the pipe (e.g. piped into `head`);
        # stop quietly instead of dumping a traceback.
        pass
154
+
155
+
156
if __name__ == "__main__":
    # Run the CLI when executed directly, using main()'s return value as the
    # exit status (main() returns None, i.e. status 0, on success).
    raise SystemExit(main())
@@ -0,0 +1,18 @@
1
+ from symengine import symarray
2
+
3
n = 100

var = symarray("x", n)
var2 = symarray("y", n)
var3 = var2
var_r = symarray("r", n)

# Plain-list copy of the x-variables x_0, ..., x_{n-1}. It is sized by ``n``
# instead of a second hard-coded 100 so the two arrays cannot drift apart.
var_x = symarray("x", n).tolist()

# subs_dict maps y_i -> r_0 + r_1 + ... + r_{i-1} for 1 <= i < n. A running
# prefix sum replaces re-summing from scratch for every i, and the loop
# names are underscore-prefixed so they are not exported as public
# module attributes.
subs_dict = {}
_prefix_sum = var_r[0]
for _i in range(1, n):
    subs_dict[var2[_i]] = _prefix_sum
    _prefix_sum = _prefix_sum + var_r[_i]
@@ -0,0 +1,14 @@
1
+ from ._funcs import (
2
+ schubmult,
3
+ schubmult_db,
4
+ mult_poly,
5
+ factor_out_q_keep_factored
6
+ )
7
+
8
+
9
# Names re-exported from ._funcs as this package's public API.
__all__ = [
    "schubmult",
    "schubmult_db",
    "mult_poly",
    "factor_out_q_keep_factored",
]
@@ -0,0 +1,5 @@
1
+ import sys
2
+ from ._script import main
3
+
4
if __name__ == "__main__":
    # `python -m` entry: hand main()'s return value to the interpreter as
    # the process exit status.
    exit_status = main()
    sys.exit(exit_status)