schubmult 2.0.4__py3-none-any.whl → 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. schubmult/__init__.py +96 -1
  2. schubmult/perm_lib.py +254 -819
  3. schubmult/poly_lib/__init__.py +31 -0
  4. schubmult/poly_lib/poly_lib.py +276 -0
  5. schubmult/poly_lib/schub_poly.py +148 -0
  6. schubmult/poly_lib/variables.py +204 -0
  7. schubmult/rings/__init__.py +18 -0
  8. schubmult/rings/_quantum_schubert_polynomial_ring.py +752 -0
  9. schubmult/rings/_schubert_polynomial_ring.py +1031 -0
  10. schubmult/rings/_tensor_schub_ring.py +128 -0
  11. schubmult/rings/_utils.py +55 -0
  12. schubmult/{sage_integration → sage}/__init__.py +4 -1
  13. schubmult/{sage_integration → sage}/_fast_double_schubert_polynomial_ring.py +67 -109
  14. schubmult/{sage_integration → sage}/_fast_schubert_polynomial_ring.py +33 -28
  15. schubmult/{sage_integration → sage}/_indexing.py +9 -5
  16. schubmult/schub_lib/__init__.py +51 -0
  17. schubmult/{schubmult_double/_funcs.py → schub_lib/double.py} +532 -596
  18. schubmult/{schubmult_q/_funcs.py → schub_lib/quantum.py} +54 -53
  19. schubmult/schub_lib/quantum_double.py +954 -0
  20. schubmult/schub_lib/schub_lib.py +659 -0
  21. schubmult/{schubmult_py/_funcs.py → schub_lib/single.py} +45 -35
  22. schubmult/schub_lib/tests/__init__.py +0 -0
  23. schubmult/schub_lib/tests/legacy_perm_lib.py +946 -0
  24. schubmult/schub_lib/tests/test_vs_old.py +109 -0
  25. schubmult/scripts/__init__.py +0 -0
  26. schubmult/scripts/schubmult_double.py +378 -0
  27. schubmult/scripts/schubmult_py.py +84 -0
  28. schubmult/scripts/schubmult_q.py +109 -0
  29. schubmult/scripts/schubmult_q_double.py +207 -0
  30. schubmult/utils/__init__.py +0 -0
  31. schubmult/{_base_argparse.py → utils/argparse.py} +29 -5
  32. schubmult/utils/logging.py +16 -0
  33. schubmult/utils/parsing.py +20 -0
  34. schubmult/utils/perm_utils.py +135 -0
  35. schubmult/utils/test_utils.py +65 -0
  36. schubmult-3.0.1.dist-info/METADATA +1234 -0
  37. schubmult-3.0.1.dist-info/RECORD +41 -0
  38. {schubmult-2.0.4.dist-info → schubmult-3.0.1.dist-info}/WHEEL +1 -1
  39. schubmult-3.0.1.dist-info/entry_points.txt +5 -0
  40. schubmult/_tests.py +0 -24
  41. schubmult/schubmult_double/__init__.py +0 -12
  42. schubmult/schubmult_double/__main__.py +0 -6
  43. schubmult/schubmult_double/_script.py +0 -474
  44. schubmult/schubmult_py/__init__.py +0 -12
  45. schubmult/schubmult_py/__main__.py +0 -6
  46. schubmult/schubmult_py/_script.py +0 -97
  47. schubmult/schubmult_q/__init__.py +0 -8
  48. schubmult/schubmult_q/__main__.py +0 -6
  49. schubmult/schubmult_q/_script.py +0 -166
  50. schubmult/schubmult_q_double/__init__.py +0 -10
  51. schubmult/schubmult_q_double/__main__.py +0 -6
  52. schubmult/schubmult_q_double/_funcs.py +0 -540
  53. schubmult/schubmult_q_double/_script.py +0 -396
  54. schubmult-2.0.4.dist-info/METADATA +0 -542
  55. schubmult-2.0.4.dist-info/RECORD +0 -30
  56. schubmult-2.0.4.dist-info/entry_points.txt +0 -5
  57. {schubmult-2.0.4.dist-info → schubmult-3.0.1.dist-info}/licenses/LICENSE +0 -0
  58. {schubmult-2.0.4.dist-info → schubmult-3.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,31 @@
1
+ # from .poly_lib import (
2
+ # call_zvars,
3
+ # efficient_subs,
4
+ # elem_sym_func,
5
+ # elem_sym_func_q,
6
+ # elem_sym_poly,
7
+ # elem_sym_poly_q,
8
+ # expand,
9
+ # q_vector,
10
+ # xreplace_genvars,
11
+ # )
12
+ # from .schub_poly import div_diff, perm_act, schubpoly, skew_div_diff
13
+ # from .variables import GeneratingSet, base_index
14
+
15
+ # __all__ = [
16
+ # "GeneratingSet",
17
+ # "base_index",
18
+ # "call_zvars",
19
+ # "div_diff",
20
+ # "efficient_subs",
21
+ # "elem_sym_func",
22
+ # "elem_sym_func_q",
23
+ # "elem_sym_poly",
24
+ # "elem_sym_poly_q",
25
+ # "expand",
26
+ # "perm_act",
27
+ # "q_vector",
28
+ # "schubpoly",
29
+ # "skew_div_diff",
30
+ # "xreplace_genvars",
31
+ # ]
@@ -0,0 +1,276 @@
1
+ from functools import cache, cached_property
2
+
3
+ import symengine
4
+ import sympy
5
+ from symengine import Mul, Pow, sympify
6
+
7
+ import schubmult.perm_lib as pl
8
+ import schubmult.poly_lib.variables as vv
9
+
10
+ # import vv.GeneratingSet, vv.base_index
11
+
12
+
13
+ # Indexed._sympystr = lambda x, p: f"{p.doprint(x.args[0])}_{x.args[1]}"
14
def expand(val):
    """Return ``val`` after expansion by ``symengine.expand``."""
    expanded = symengine.expand(val)
    return expanded
16
+
17
+
18
class _gvars:
    """Holder for the default generating sets used by this module.

    Every attribute is a ``cached_property``: the value is built on first
    access and then reused for the lifetime of the instance.
    """

    @cached_property
    def n(self):
        """The constant 100."""
        return 100

    @cached_property
    def var1(self):
        """Generating set labelled ``"x"``."""
        return vv.GeneratingSet("x")

    @cached_property
    def var2(self):
        """Generating set labelled ``"y"``."""
        return vv.GeneratingSet("y")

    @cached_property
    def var3(self):
        """Generating set labelled ``"z"``."""
        return vv.GeneratingSet("z")

    @cached_property
    def var_r(self):
        """Generating set labelled ``"r"``."""
        return vv.GeneratingSet("r")

    @cached_property
    def var_g1(self):
        """Generating set labelled ``"y"`` (same label as ``var2``)."""
        return vv.GeneratingSet("y")

    @cached_property
    def var_g2(self):
        """Generating set labelled ``"z"`` (same label as ``var3``)."""
        return vv.GeneratingSet("z")

    @cached_property
    def q_var(self):
        """Generating set labelled ``"q"``."""
        return vv.GeneratingSet("q")
54
+
55
+
56
# symengine representation of the integer 0, returned by the elementary
# symmetric helpers below when a term vanishes.
zero = sympify(0)

# Module-level holder of the default generating sets; its ``q_var`` attribute
# is used as a default argument by functions defined later in this module.
_vars = _gvars()
59
+
60
def sv_posify(val, var2):
    """Rewrite ``val`` through the auxiliary ``r`` variables and back.

    Steps performed:

    1. For ``i`` in ``1..99`` substitute
       ``var2[i] -> var2[1] + r_1 + ... + r_{i-1}``.
    2. Simplify the result with ``sympy.simplify`` and convert it back with
       symengine's ``sympify``.
    3. Replace each ``r_i`` (``i`` in ``1..len(r) - 2``) by
       ``var2[i + 1] - var2[i]`` using ``xreplace``.

    Args:
        val: expression to rewrite (anything ``sympify`` accepts).
        var2: indexable generating set; indices ``1..99`` are read.

    Returns:
        The rewritten expression.
    """
    var_r = vv.GeneratingSet("r")

    # Build the prefix sums incrementally: the sum for index i + 1 is the sum
    # for index i plus r_i.  This yields the same expressions as rebuilding
    # each sum from scratch, with a linear number of symbolic additions.
    subs_dict = {}
    prefix_sum = sympify(var2[1])
    for i in range(1, 100):
        subs_dict[sympify(var2[i])] = prefix_sum
        prefix_sum = prefix_sum + var_r[i]

    val = sympify(sympy.simplify(efficient_subs(sympify(val), subs_dict)))

    # Map every r_i back to the difference of consecutive var2 generators.
    back_subs = {var_r[i]: var2[i + 1] - var2[i] for i in range(1, len(var_r) - 1)}
    return val.xreplace(back_subs)
73
+
74
def act(w, poly, genset):
    """Permute the generators of ``genset`` that occur in ``poly`` by ``w``.

    ``w`` is wrapped in ``pl.Permutation`` and ``genset`` in
    ``vv.CustomGeneratingSet`` when they are not already of those kinds.
    Each free symbol of ``poly`` found in ``genset`` at index ``j`` is
    replaced by ``genset[w(j)]``; other symbols are left alone.
    """
    perm = w if isinstance(w, pl.Permutation) else pl.Permutation(w)
    gens = genset if isinstance(genset, vv.GeneratingSet_base) else vv.CustomGeneratingSet(genset)

    replacements = {}
    for sym in poly.free_symbols:
        position = gens.index(sym)
        if position == -1:
            # Symbol does not belong to this generating set.
            continue
        replacements[sym] = gens[perm(position)]
    return efficient_subs(poly, replacements)
84
+
85
def elem_sym_func(k, i, u1, u2, v1, v2, udiff, vdiff, varl1, varl2):
    """Build the ``elem_sym_poly`` factor for one step from (u1, v1) to (u2, v2).

    With ``newk = k - udiff`` the result is ``zero`` when ``newk < vdiff`` and
    ``one`` when ``newk == vdiff``.  Otherwise the first-family variables are
    selected from ``varl1`` by comparing ``u1`` and ``u2`` position by
    position, the second-family variables are ``varl2`` indexed by
    ``call_zvars(v1, v2, k, i)``, and the value is
    ``elem_sym_poly(newk - vdiff, newk, yvars, zvars)``.
    """
    reduced_k = k - udiff
    if reduced_k < vdiff:
        return zero
    if reduced_k == vdiff:
        return one

    len_u1 = len(u1)
    len_u2 = len(u2)

    # Positions covered by u1: keep those where u1 and u2 agree.
    yvars = [varl1[u2[pos]] for pos in range(min(len_u1, k)) if u1[pos] == u2[pos]]
    # Positions past u1 but inside u2: keep the fixed points of u2.
    yvars.extend(varl1[u2[pos]] for pos in range(len_u1, min(k, len_u2)) if u2[pos] == pos + 1)
    # Positions past u2 up to k: always kept.
    yvars.extend(varl1[pos + 1] for pos in range(len_u2, k))

    zvars = [varl2[z_index] for z_index in call_zvars(v1, v2, k, i)]
    return elem_sym_poly(reduced_k - vdiff, reduced_k, yvars, zvars)
102
+
103
+
104
+ # def elem_sym_func_q(k, i, u1, u2, v1, v2, udiff, vdiff, varl1, varl2):
105
+ # newk = k - udiff
106
+ # if newk < vdiff:
107
+ # return zero
108
+ # if newk == vdiff:
109
+ # return one
110
+ # yvars = []
111
+ # mlen = max(len(u1), len(u2))
112
+ # u1 = [*u1] + [a + 1 for a in range(len(u1), mlen)]
113
+ # u2 = [*u2] + [a + 1 for a in range(len(u2), mlen)]
114
+ # for j in range(min(len(u1), k)):
115
+ # if u1[j] == u2[j]:
116
+ # yvars += [varl1[u2[j]]]
117
+ # for j in range(len(u1), min(k, len(u2))):
118
+ # if u2[j] == j + 1:
119
+ # yvars += [varl1[u2[j]]]
120
+ # for j in range(len(u2), k):
121
+ # yvars += [varl1[j + 1]]
122
+ # zvars = [varl2[a] for a in call_zvars(v1, v2, k, i)]
123
+ # return elem_sym_poly(newk - vdiff, newk, yvars, zvars)
124
+
125
+
126
def elem_sym_func_q(k, i, u1, u2, v1, v2, udiff, vdiff, varl1, varl2):
    """Variant of ``elem_sym_func`` that compares ``u1`` and ``u2`` on all of ``range(k)``.

    With ``newk = k - udiff`` the result is ``zero`` when ``newk < vdiff`` and
    ``one`` when ``newk == vdiff``.  Otherwise ``varl1[u2[j]]`` is collected
    for every ``j < k`` with ``u1[j] == u2[j]``, ``varl2`` is indexed by
    ``call_zvars(v1, v2, k, i)``, and the value is
    ``elem_sym_poly(newk - vdiff, newk, yvars, zvars)``.
    """
    shifted_k = k - udiff
    if shifted_k < vdiff:
        return zero
    if shifted_k == vdiff:
        return one

    # NOTE(review): u1 and u2 are indexed up to k - 1 with no length check —
    # presumably they tolerate indices past their length; confirm with callers.
    yvars = [varl1[u2[pos]] for pos in range(k) if u1[pos] == u2[pos]]
    zvars = [varl2[z_index] for z_index in call_zvars(v1, v2, k, i)]
    return elem_sym_poly(shifted_k - vdiff, shifted_k, yvars, zvars)
149
+
150
+
151
# symengine representation of the integer 1.  It is defined after some of the
# functions that return it, which works because the name is only looked up
# when those functions are called.
one = sympify(1)
152
+
153
+
154
def elem_sym_poly_q(p, k, varl1, varl2, q_var=_vars.q_var):
    """Evaluate the three-term recurrence with a ``q_var`` correction term.

    Base cases: ``zero`` when ``p < 0`` or ``p > k``; ``one`` when ``p == 0``
    (and hence ``k >= 0``).  Otherwise::

        E(p, k) = (varl1[k-1] - varl2[k-p]) * E(p-1, k-1)
                  + E(p, k-1)
                  + q_var[k-1] * E(p-2, k-2)
    """
    if p < 0 or p > k:
        return zero
    if p == 0:
        return one

    linear_term = (varl1[k - 1] - varl2[k - p]) * elem_sym_poly_q(p - 1, k - 1, varl1, varl2, q_var)
    carried_term = elem_sym_poly_q(p, k - 1, varl1, varl2, q_var)
    quantum_term = q_var[k - 1] * elem_sym_poly_q(p - 2, k - 2, varl1, varl2, q_var)
    return linear_term + carried_term + quantum_term
164
+
165
+
166
def complete_sym_poly(p, k, vrs):
    """Return the degree-``p`` complete homogeneous sum in the first ``k`` entries of ``vrs``.

    Computed by splitting the ``k`` variables into two halves and convolving
    the results for the halves over all ways of splitting the degree.

    Base cases: ``1`` when ``p == 0`` and ``k >= 0``; ``0`` when ``k <= 0``
    otherwise; ``vrs[0] ** p`` when ``k == 1``.
    """
    if p == 0 and k >= 0:
        return 1
    if k <= 0:
        # Non-zero degree with no variables, or a negative variable count.
        return 0
    if k == 1:
        return vrs[0] ** p

    half = k // 2
    left_vars = vrs[:half]
    right_vars = vrs[half:]
    return sum(
        complete_sym_poly(deg, half, left_vars) * complete_sym_poly(p - deg, k - half, right_vars)
        for deg in range(p + 1)
    )
180
+
181
+
182
+
183
def elem_sym_poly(p, k, varl1, varl2, xstart=0, ystart=0):
    """Evaluate the two-family elementary polynomial by divide and conquer.

    ``varl1`` is read from offset ``xstart`` and ``varl2`` from ``ystart``.

    Direct cases:
      * ``p > k``  -> ``zero``
      * ``p == 0`` -> ``one``
      * ``p == 1`` -> sum of ``varl1[xstart + j] - varl2[ystart + j]`` for ``j < k``
      * ``p == k`` -> product of ``varl1[xstart + j] - varl2[ystart]`` for ``j < k``

    Otherwise the ``k`` variables are split into a left block of ``k // 2``
    and a right block of the rest, and the result combines the two blocks
    over every split ``p2 + (p - p2)`` of the degree; the right block's
    ``varl2`` offset is shifted back by ``p2``.
    """
    if p > k:
        return zero
    if p == 0:
        return one

    if p == 1:
        total = varl1[xstart] - varl2[ystart]
        for offset in range(1, k):
            total += varl1[xstart + offset] - varl2[ystart + offset]
        return total

    if p == k:
        # Every factor uses the same varl2 entry.
        product = (varl1[xstart] - varl2[ystart]) * (varl1[xstart + 1] - varl2[ystart])
        for offset in range(2, k):
            product *= varl1[offset + xstart] - varl2[ystart]
        return product

    left_size = k // 2
    right_size = k - left_size
    x_right = xstart + left_size
    y_right = ystart + left_size

    # Terms where the whole degree comes from one block.
    total = elem_sym_poly(p, left_size, varl1, varl2, xstart, ystart) + elem_sym_poly(
        p, right_size, varl1, varl2, x_right, y_right
    )
    # Terms where the degree is shared between the blocks.
    for left_deg in range(max(1, p - right_size), min(p, left_size + 1)):
        left_part = elem_sym_poly(left_deg, left_size, varl1, varl2, xstart, ystart)
        right_part = elem_sym_poly(p - left_deg, right_size, varl1, varl2, x_right, y_right - left_deg)
        total += left_part * right_part
    return total
220
+
221
+
222
+ # def call_zvars(v1, v2, k, i):
223
+ # v3 = [*v2, *list(range(len(v2) + 1, i + 1))]
224
+ # return [v3[i - 1]] + [v3[j] for j in range(len(v1), len(v3)) if v3[j] != j + 1 and j != i - 1] + [v3[j] for j in range(len(v1)) if v1[j] != v3[j] and j != i - 1]
225
+
226
+
227
@cache
def call_zvars(v1, v2, k, i):  # noqa: ARG001
    """Return the list of ``v2`` values used as second-family variable indices.

    The list is built from three pieces, in this order:

    1. ``v2[i - 1]``;
    2. ``v2[j]`` for ``j`` from ``len(v1)`` up to ``max(len(v2), i)`` where
       ``v2[j] != j + 1`` and ``j != i - 1``;
    3. ``v2[j]`` for ``j < len(v1)`` where ``v1[j] != v2[j]`` and ``j != i - 1``.

    ``k`` is accepted but not used.  Results are memoised with
    ``functools.cache``, so the arguments must be hashable and the returned
    list object is shared between calls with equal arguments.
    """
    pivot = i - 1
    len_v1 = len(v1)
    upper = len(v2) + max(0, i - len(v2))

    selected = [v2[pivot]]
    selected += [v2[j] for j in range(len_v1, upper) if v2[j] != j + 1 and j != pivot]
    selected += [v2[j] for j in range(len_v1) if v1[j] != v2[j] and j != pivot]
    return selected
230
+
231
+
232
def efficient_subs(expr, subs_dict):
    """Substitute into ``expr`` using only the entries of ``subs_dict`` it needs.

    ``expr`` is sympified, the mapping is restricted to keys that are free
    symbols of ``expr``, and the restricted mapping is passed to ``subs``.
    """
    expr = sympify(expr)
    relevant = {sym: subs_dict[sym] for sym in expr.free_symbols if sym in subs_dict}
    return expr.subs(relevant)
239
+
240
+
241
def q_vector(q_exp, q_var=_vars.q_var):
    """Return the exponent list of a monomial in the ``q_var`` generators.

    Entry ``j`` of the result (0-based) is the exponent of the generator that
    ``q_var.index`` reports at position ``j + 1``.

    Args:
        q_exp: ``1``, a single generator, a power of a generator, or a
            product of such factors.
        q_var: generating set providing ``index``; defaults to the module's
            ``q`` generating set.

    Returns:
        A list of exponents (empty for ``1``), or ``None`` when ``q_exp`` has
        none of the recognised shapes.

    Raises:
        IndexError: if ``q_exp`` is a power whose base is not in ``q_var``.
    """
    ret = []

    if q_exp == 1:
        return ret

    position = q_var.index(q_exp)
    if position != -1:
        return [0] * (position - 1) + [1]

    if isinstance(q_exp, Pow):
        base = q_exp.args[0]
        expon = int(q_exp.args[1])
        position = q_var.index(base)
        if position == -1:
            raise IndexError
        return [0] * (position - 1) + [expon]

    if isinstance(q_exp, Mul):
        for factor in q_exp.args:
            # Forward q_var: the recursion previously fell back to the default
            # generating set, ignoring a caller-supplied one.
            v1 = q_vector(factor, q_var)
            # Pad both vectors to a common length before adding them.
            width = max(len(v1), len(ret))
            v1 = v1 + [0] * (width - len(v1))
            ret = ret + [0] * (width - len(ret))
            ret = [left + right for left, right in zip(ret, v1)]
        return ret

    return None
266
+
267
+
268
def xreplace_genvars(poly, vars1, vars2):
    """Swap the module's default generators in ``poly`` for caller-supplied ones.

    A free symbol found in ``_vars.var_g1`` at index ``j`` becomes
    ``vars1[j]``; otherwise, one found in ``_vars.var_g2`` at index ``j``
    becomes ``vars2[j]``.  The replacement is done with ``xreplace``.
    """
    expr = sympify(poly)
    mapping = {}
    for sym in expr.free_symbols:
        first_pos = _vars.var_g1.index(sym)
        if first_pos != -1:
            mapping[sym] = vars1[first_pos]
            continue
        second_pos = _vars.var_g2.index(sym)
        if second_pos != -1:
            mapping[sym] = vars2[second_pos]
    return expr.xreplace(mapping)
276
+ # print(f"{poly2=} {poly2.free_symbols=}")
@@ -0,0 +1,148 @@
1
+ import sympy
2
+
3
+ import schubmult.perm_lib as pl
4
+ import schubmult.schub_lib.schub_lib as schub_lib
5
+ from schubmult.poly_lib.poly_lib import call_zvars
6
+ from schubmult.poly_lib.variables import GeneratingSet
7
+
8
+
9
def perm_act(val, i, var2=None):
    """Exchange ``var2[i]`` and ``var2[i + 1]`` in ``val``.

    Args:
        val: expression to act on; converted with ``sympy.sympify``.
        i: index of the first of the two generators to exchange.
        var2: indexable generating set.  When omitted, ``GeneratingSet("y")``
            is used, the same default ``schubpoly`` uses for its first
            variable family.

    Returns:
        The sympy expression with the two generators exchanged.
    """
    if var2 is None:
        # The declared default used to be subscripted directly and crashed.
        var2 = GeneratingSet("y")
    swap = {var2[i]: var2[i + 1], var2[i + 1]: var2[i]}
    # sympy applies the entries of a substitution mapping one after another by
    # default, which would turn both generators into the same one; a true
    # exchange needs simultaneous=True.
    return sympy.sympify(val).subs(swap, simultaneous=True)
12
+
13
+
14
def div_diff(i, poly, var2=None):
    """Return the quotient of ``poly - s_i(poly)`` by ``var2[i] - var2[i + 1]``.

    ``s_i(poly)`` is ``perm_act(poly, i, var2)``, i.e. ``poly`` with
    ``var2[i]`` and ``var2[i + 1]`` exchanged.  The division is done with
    ``sympy.div`` and only the quotient is kept.

    Args:
        i: index of the first generator of the pair.
        poly: polynomial expression.
        var2: indexable generating set.  When omitted, ``GeneratingSet("y")``
            is used, the same default ``schubpoly`` uses for its first
            variable family.

    Returns:
        The quotient as a sympy expression.
    """
    if var2 is None:
        var2 = GeneratingSet("y")
    # Forward var2 explicitly: perm_act was previously called without it and
    # therefore always received None.
    numerator = sympy.sympify(poly - perm_act(poly, i, var2))
    denominator = sympy.sympify(var2[i] - var2[i + 1])
    quotient, _remainder = sympy.div(numerator, denominator)
    return sympy.sympify(quotient)
18
+
19
+
20
def elem_func_func(k, i, v1, v2, vdiff, varl1, varl2, elem_func):
    """Evaluate a caller-supplied elementary function for one path step.

    Returns ``0`` when ``k < vdiff`` and ``1`` when ``k == vdiff``.  Otherwise
    ``varl2`` is indexed by ``call_zvars(v1, v2, k, i)`` and the result is
    ``elem_func(k - vdiff, k, varl1[1:], zvars)``.
    """
    if k < vdiff:
        return 0
    if k == vdiff:
        return 1
    zvars = [varl2[z_index] for z_index in call_zvars(v1, v2, k, i)]
    # varl1 is passed without its entry at index 0.
    return elem_func(k - vdiff, k, varl1[1:], zvars)
28
+
29
+
30
def _sum_over_vpaths(th, vmu, var_x, var_y, elem_func):
    """Accumulate the signed path sums shared by the ``*_from_elems`` functions.

    ``th`` is trimmed of trailing zeros in place.  Starting from
    ``{Permutation([1, 2]): elem_func(0, 0, var_x, var_y)}``, each level of
    ``schub_lib.compute_vpathdicts(th, vmu)`` maps the current sums to new
    ones, weighting every edge ``(v2, vdiff, s)`` by
    ``s * elem_func_func(th[index], index + 1, perm, v2, vdiff, ...)``.
    The entry for ``vmu`` (or ``0`` when absent) is returned.
    """
    if len(th) == 0:
        return elem_func(0, 0, var_x, var_y)
    while len(th) > 0 and th[-1] == 0:
        th.pop()
    vpathdicts = schub_lib.compute_vpathdicts(th, vmu)
    vpathsums = {pl.Permutation([1, 2]): elem_func(0, 0, var_x, var_y)}
    for index in range(len(th)):
        newpathsums = {}
        for perm in vpathdicts[index]:
            sumval = vpathsums.get(perm, 0)
            if sumval == 0:
                # Nothing reaches this permutation; its edges contribute nothing.
                continue
            for v2, vdiff, s in vpathdicts[index][perm]:
                newpathsums[v2] = newpathsums.get(v2, 0) + s * sumval * elem_func_func(
                    th[index],
                    index + 1,
                    perm,
                    v2,
                    vdiff,
                    var_x,
                    var_y,
                    elem_func=elem_func,
                )
        vpathsums = newpathsums
    return vpathsums.get(vmu, 0)


def schubpoly_from_elems(v, var_x=None, var_y=None, elem_func=None, mumu=None):
    """Build the polynomial indexed by ``v`` from ``elem_func`` evaluations.

    Args:
        v: permutation (anything ``pl.Permutation`` accepts).
        var_x, var_y: variable families handed through to ``elem_func``.
        elem_func: callable invoked as ``elem_func(p, k, xvars, yvars)``.
        mumu: optional permutation.  When truthy, ``pl.code(mumu)`` is used as
            the code and ``mumu`` as the multiplier; otherwise the code is
            ``pl.strict_theta(~Permutation(v))`` and the multiplier is
            ``pl.uncode`` of it.

    Returns:
        The accumulated path sum for ``Permutation(v) * mu``.
    """
    if mumu:
        th = pl.code(mumu)
        mu = mumu
    else:
        th = pl.strict_theta(~pl.Permutation(v))
        mu = pl.uncode(th)
    vmu = pl.Permutation(v) * mu
    return _sum_over_vpaths(th, vmu, var_x, var_y, elem_func)


def schubpoly_classical_from_elems(v, var_x=None, var_y=None, elem_func=None):
    """Same construction as ``schubpoly_from_elems`` but based on ``pl.theta``.

    The code is ``pl.theta(~Permutation(v))`` and the multiplier is
    ``pl.uncode`` of it; the remaining steps are identical.
    """
    th = pl.theta(~pl.Permutation(v))
    mu = pl.uncode(th)
    vmu = pl.Permutation(v) * mu
    return _sum_over_vpaths(th, vmu, var_x, var_y, elem_func)
109
+
110
+
111
def schubpoly(v, var2=GeneratingSet("y"), var3=GeneratingSet("z"), start_var=1):
    """Recursively expand the polynomial indexed by ``v`` in ``var2`` and ``var3``.

    Let ``n`` be the 1-based position of the last descent of ``v``.  When
    there is no descent the result is ``1``.  Otherwise, for every pair
    ``(pw, vp)`` returned by ``schub_lib.pull_out_var(n, v)``, the product of
    ``var2[start_var + n - 1] - var3[p]`` over ``p`` in ``pw`` is multiplied
    by the polynomial for ``vp``, and these terms are summed.
    """
    last_descent = 0
    for pos in reversed(range(len(v) - 1)):
        if v[pos] > v[pos + 1]:
            last_descent = pos + 1
            break
    if last_descent == 0:
        return 1

    total = 0
    for pw, vp in schub_lib.pull_out_var(last_descent, v):
        factor = 1
        for p in pw:
            factor *= var2[start_var + last_descent - 1] - var3[p]
        total += factor * schubpoly(vp, var2, var3, start_var)
    return total
127
+
128
+
129
def skew_div_diff(u, w, poly, var2=None):
    """Apply the skew divided-difference recursion for ``(u, w)`` to ``poly``.

    Let ``d`` be the first descent of ``w``.  When ``w`` has no descent the
    result is ``poly`` if ``u`` has no descent either, else ``0``.  Otherwise
    ``w`` is replaced by ``w.swap(d, d + 1)`` and

    * if ``u`` also descends at ``d``, ``u`` is swapped the same way and
      ``poly`` is replaced by ``perm_act(poly, d + 1, var2)``;
    * else ``poly`` is replaced by ``div_diff(d + 1, poly, var2)``.

    Args:
        u, w: permutations supporting ``len``, indexing and ``swap``.
        poly: polynomial expression.
        var2: generating set the operators act on.  When omitted,
            ``GeneratingSet("y")`` is used, the same default ``schubpoly``
            uses for its first variable family.  Previously no generating set
            reached ``perm_act``/``div_diff``, so they failed on ``None``.

    Returns:
        The transformed polynomial, or ``0``.
    """
    if var2 is None:
        var2 = GeneratingSet("y")

    d = next((idx for idx in range(len(w) - 1) if w[idx] > w[idx + 1]), -1)
    if d == -1:
        u_has_descent = any(u[idx] > u[idx + 1] for idx in range(len(u) - 1))
        return 0 if u_has_descent else poly

    w2 = w.swap(d, d + 1)
    if d < len(u) - 1 and u[d] > u[d + 1]:
        u2 = u.swap(d, d + 1)
        return skew_div_diff(u2, w2, perm_act(poly, d + 1, var2), var2)
    return skew_div_diff(u, w2, div_diff(d + 1, poly, var2), var2)
@@ -0,0 +1,204 @@
1
+ # class generators with base
2
+ # symbols cls argument!
3
+
4
+ import re
5
+ from bisect import bisect_left
6
+ from functools import cache
7
+ from typing import ClassVar
8
+
9
+ from symengine import SympifyError, symbols, sympify
10
+ from sympy import Basic, Tuple
11
+ from sympy.core.symbol import Str
12
+
13
+ from schubmult.utils.logging import get_logger
14
+
15
+ logger = get_logger(__name__)
16
+
17
+
18
class GeneratingSet_base(Basic):
    """Common sympy ``Basic`` base class for the generating-set types.

    Subclasses are expected to provide ``__getitem__`` and ``__len__``.
    """

    def __new__(cls, *args):
        return Basic.__new__(cls, *args)

    def __getitem__(self, i):
        # ``NotImplemented`` is the sentinel for binary-operator dispatch;
        # returning it here handed callers a bogus "element".  An abstract
        # method signals itself by raising instead.
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
27
+
28
+
29
+ # variable registry
30
+ # TODO: ensure sympifies
31
+ # TODO: masked generating set
32
class GeneratingSet(GeneratingSet_base):
    """A named family of 100 symbols ``<name>_0 .. <name>_99``.

    Instances are created through a cache keyed on the class and the name, so
    constructing the same name twice returns the same object.  Equality and
    hashing are based on the label only.
    """

    def __new__(cls, name):
        return GeneratingSet.__xnew_cached__(cls, name)

    _registry: ClassVar = {}

    _index_pattern = re.compile("^([^_]+)_([0-9]+)$")
    _sage_index_pattern = re.compile("^([^0-9]+)([0-9]+)$")

    # is_Atom = True
    # TODO: masked generating set
    @staticmethod
    @cache
    def __xnew_cached__(_class, name):
        return GeneratingSet.__xnew__(_class, Str(str(name)))

    @staticmethod
    def __xnew__(_class, name):
        obj = GeneratingSet_base.__new__(_class, name)
        obj._symbols_arr = tuple(symbols(f"{name}_{i}") for i in range(100))
        # Reverse map from symbol to its position, used by ``index``.
        obj._index_lookup = {sym: i for i, sym in enumerate(obj._symbols_arr)}
        return obj

    def __call__(self, index):
        """1-indexed"""
        return self[index - 1]

    @property
    def label(self):
        """The name of the family as a plain string."""
        return str(self.args[0])

    def index(self, v):
        """Return the position of ``v`` in this set, or ``-1`` if absent.

        ``v`` is looked up as given and, failing that, after ``sympify``;
        values that cannot be sympified yield ``-1``.
        """
        try:
            return self._index_lookup.get(v, self._index_lookup.get(sympify(v), -1))
        except SympifyError:
            return -1

    def __repr__(self):
        return f"GeneratingSet('{self.label}')"

    def __str__(self):
        # The original returned ``self.name``, an attribute this class never
        # defines, so ``str()`` raised AttributeError; the label is the name.
        return self.label

    def _latex(self, printer):
        return printer.doprint(self.label)

    def _sympystr(self, printer):
        return printer.doprint(self.label)

    def __getitem__(self, i):
        return self._symbols_arr[i]

    def __len__(self):
        return len(self._symbols_arr)

    def __hash__(self):
        return hash(self.label)

    def __iter__(self):
        yield from self._symbols_arr

    def __eq__(self, other):
        return isinstance(other, GeneratingSet) and self.label == other.label
96
+
97
+
98
class MaskedGeneratingSet(GeneratingSet_base):
    """View of a generating set with the indices in ``index_mask`` skipped.

    Index 0 always maps to index 0 of the base set; every unmasked base index
    ``i >= 1`` is renumbered to ``i`` minus the number of masked indices
    below it.  Instances are cached on ``(class, gset, sorted mask)``.
    """

    def __new__(cls, gset, index_mask):
        return MaskedGeneratingSet.__xnew_cached__(cls, gset, tuple(sorted(index_mask)))

    @staticmethod
    @cache
    def __xnew_cached__(_class, gset, index_mask):
        return MaskedGeneratingSet.__xnew__(_class, gset, index_mask)

    @staticmethod
    def __xnew__(_class, gset, index_mask):
        obj = GeneratingSet_base.__new__(_class, gset, Tuple(*index_mask))
        mask_dict = {0: 0}
        for i in range(1, len(gset._symbols_arr)):
            # Number of masked indices strictly below i.
            index = bisect_left(index_mask, i)
            if index >= len(index_mask) or index_mask[index] != i:
                # i itself is not masked: it becomes index i - index here.
                mask_dict[i - index] = i
        obj._mask = mask_dict
        obj._index_lookup = {gset[mask_dict[i]]: i for i in range(len(gset) - len(index_mask))}
        obj._label = gset.label
        return obj

    @property
    def label(self):
        """Label of this set; initially the base set's label."""
        return str(self._label)

    def set_label(self, label):
        """Override the label reported by ``label``."""
        self._label = label

    @property
    def index_mask(self):
        """The masked base indices as a tuple."""
        return tuple(self.args[1])

    def complement(self):
        """Return the masked set that hides exactly the indices visible here (from 1 up)."""
        # Build the membership set once instead of once per candidate index.
        masked = set(self.index_mask)
        return MaskedGeneratingSet(self.base_genset, [i for i in range(1, len(self.base_genset)) if i not in masked])

    @property
    def base_genset(self):
        """The underlying generating set."""
        return self.args[0]

    def __call__(self, index):
        """1-indexed"""
        return self[index - 1]

    def __getitem__(self, index):
        if isinstance(index, slice):
            # slice.indices resolves missing/negative bounds and the step
            # against our length; the hand-rolled version ignored the step
            # and mishandled negative or out-of-range bounds.
            return [self[ii] for ii in range(*index.indices(len(self)))]
        return self.base_genset[self._mask[index]]

    def __iter__(self):
        yield from [self[i] for i in range(len(self))]

    def index(self, v):
        """Return the position of ``v`` in this masked set, or ``-1`` if absent."""
        try:
            return self._index_lookup.get(v, self._index_lookup.get(sympify(v), -1))
        except SympifyError:
            return -1

    def __len__(self):
        return len(self.base_genset) - len(self.index_mask)
169
+
170
+
171
class CustomGeneratingSet(GeneratingSet_base):
    """Generating set backed by an explicit sequence of generators.

    The generators are sympified and stored in order; instances are cached
    on the tuple of generators passed in.
    """

    def __new__(cls, gens):
        return CustomGeneratingSet.__xnew_cached__(cls, tuple(gens))

    @staticmethod
    @cache
    def __xnew_cached__(_class, gens):
        return CustomGeneratingSet.__xnew__(_class, gens)

    @staticmethod
    def __xnew__(_class, gens):
        instance = GeneratingSet_base.__new__(_class, Tuple(*gens))
        instance._symbols_arr = [sympify(gen) for gen in gens]
        # Reverse map from generator to its position, used by ``index``.
        instance._index_lookup = {sym: pos for pos, sym in enumerate(instance._symbols_arr)}
        return instance

    def __getitem__(self, index):
        return self._symbols_arr[index]

    def __iter__(self):
        for pos in range(len(self)):
            yield self[pos]

    def index(self, v):
        """Return the position of ``v`` among the generators, or ``-1`` if absent."""
        try:
            return self._index_lookup.get(v, self._index_lookup.get(sympify(v), -1))
        except SympifyError:
            return -1

    def __call__(self, index):
        """1-indexed"""
        return self[index - 1]

    def __len__(self):
        return len(self.args[0])
@@ -0,0 +1,18 @@
1
+ from ._quantum_schubert_polynomial_ring import QDSx, QPDSx, QPSx, QSx, QuantumDoubleSchubertAlgebraElement, QuantumDoubleSchubertAlgebraElement_basis, make_parabolic_quantum_basis
2
+ from ._schubert_polynomial_ring import DoubleSchubertAlgebraElement, DoubleSchubertAlgebraElement_basis, DSx, Sx
3
+ from ._utils import poly_ring
4
+
5
+ __all__ = [
6
+ "DSx",
7
+ "DoubleSchubertAlgebraElement",
8
+ "DoubleSchubertAlgebraElement_basis",
9
+ "QDSx",
10
+ "QPDSx",
11
+ "QPSx",
12
+ "QSx",
13
+ "QuantumDoubleSchubertAlgebraElement",
14
+ "QuantumDoubleSchubertAlgebraElement_basis",
15
+ "Sx",
16
+ "make_parabolic_quantum_basis",
17
+ "poly_ring",
18
+ ]