schubmult 2.0.3__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. schubmult/__init__.py +94 -1
  2. schubmult/perm_lib.py +233 -880
  3. schubmult/poly_lib/__init__.py +31 -0
  4. schubmult/poly_lib/poly_lib.py +244 -0
  5. schubmult/poly_lib/schub_poly.py +148 -0
  6. schubmult/poly_lib/variables.py +204 -0
  7. schubmult/rings/__init__.py +17 -0
  8. schubmult/rings/_quantum_schubert_polynomial_ring.py +788 -0
  9. schubmult/rings/_schubert_polynomial_ring.py +1006 -0
  10. schubmult/rings/_tensor_schub_ring.py +128 -0
  11. schubmult/rings/_utils.py +55 -0
  12. schubmult/{sage_integration → sage}/__init__.py +17 -15
  13. schubmult/{sage_integration → sage}/_fast_double_schubert_polynomial_ring.py +142 -220
  14. schubmult/{sage_integration → sage}/_fast_schubert_polynomial_ring.py +78 -72
  15. schubmult/sage/_indexing.py +51 -0
  16. schubmult/schub_lib/__init__.py +51 -0
  17. schubmult/{schubmult_double/_funcs.py → schub_lib/double.py} +618 -798
  18. schubmult/{schubmult_q/_funcs.py → schub_lib/quantum.py} +70 -72
  19. schubmult/schub_lib/quantum_double.py +954 -0
  20. schubmult/schub_lib/schub_lib.py +659 -0
  21. schubmult/{schubmult_py/_funcs.py → schub_lib/single.py} +58 -48
  22. schubmult/schub_lib/tests/__init__.py +0 -0
  23. schubmult/schub_lib/tests/legacy_perm_lib.py +946 -0
  24. schubmult/schub_lib/tests/test_vs_old.py +109 -0
  25. schubmult/scripts/__init__.py +0 -0
  26. schubmult/scripts/schubmult_double.py +378 -0
  27. schubmult/scripts/schubmult_py.py +84 -0
  28. schubmult/scripts/schubmult_q.py +109 -0
  29. schubmult/scripts/schubmult_q_double.py +207 -0
  30. schubmult/utils/__init__.py +0 -0
  31. schubmult/{_base_argparse.py → utils/argparse.py} +40 -11
  32. schubmult/utils/logging.py +16 -0
  33. schubmult/utils/parsing.py +20 -0
  34. schubmult/utils/perm_utils.py +135 -0
  35. schubmult/utils/test_utils.py +65 -0
  36. schubmult-3.0.0.dist-info/METADATA +1234 -0
  37. schubmult-3.0.0.dist-info/RECORD +41 -0
  38. {schubmult-2.0.3.dist-info → schubmult-3.0.0.dist-info}/WHEEL +1 -1
  39. schubmult-3.0.0.dist-info/entry_points.txt +5 -0
  40. schubmult/_tests.py +0 -9
  41. schubmult/sage_integration/_indexing.py +0 -51
  42. schubmult/schubmult_double/__init__.py +0 -22
  43. schubmult/schubmult_double/__main__.py +0 -5
  44. schubmult/schubmult_double/_script.py +0 -474
  45. schubmult/schubmult_py/__init__.py +0 -13
  46. schubmult/schubmult_py/__main__.py +0 -5
  47. schubmult/schubmult_py/_script.py +0 -96
  48. schubmult/schubmult_q/__init__.py +0 -13
  49. schubmult/schubmult_q/__main__.py +0 -5
  50. schubmult/schubmult_q/_script.py +0 -160
  51. schubmult/schubmult_q_double/__init__.py +0 -17
  52. schubmult/schubmult_q_double/__main__.py +0 -5
  53. schubmult/schubmult_q_double/_funcs.py +0 -540
  54. schubmult/schubmult_q_double/_script.py +0 -398
  55. schubmult-2.0.3.dist-info/METADATA +0 -455
  56. schubmult-2.0.3.dist-info/RECORD +0 -30
  57. schubmult-2.0.3.dist-info/entry_points.txt +0 -5
  58. {schubmult-2.0.3.dist-info → schubmult-3.0.0.dist-info}/licenses/LICENSE +0 -0
  59. {schubmult-2.0.3.dist-info → schubmult-3.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,31 @@
1
+ # from .poly_lib import (
2
+ # call_zvars,
3
+ # efficient_subs,
4
+ # elem_sym_func,
5
+ # elem_sym_func_q,
6
+ # elem_sym_poly,
7
+ # elem_sym_poly_q,
8
+ # expand,
9
+ # q_vector,
10
+ # xreplace_genvars,
11
+ # )
12
+ # from .schub_poly import div_diff, perm_act, schubpoly, skew_div_diff
13
+ # from .variables import GeneratingSet, base_index
14
+
15
+ # __all__ = [
16
+ # "GeneratingSet",
17
+ # "base_index",
18
+ # "call_zvars",
19
+ # "div_diff",
20
+ # "efficient_subs",
21
+ # "elem_sym_func",
22
+ # "elem_sym_func_q",
23
+ # "elem_sym_poly",
24
+ # "elem_sym_poly_q",
25
+ # "expand",
26
+ # "perm_act",
27
+ # "q_vector",
28
+ # "schubpoly",
29
+ # "skew_div_diff",
30
+ # "xreplace_genvars",
31
+ # ]
@@ -0,0 +1,244 @@
1
+ from functools import cache, cached_property
2
+
3
+ import symengine
4
+ from symengine import Mul, Pow, sympify
5
+
6
+ import schubmult.perm_lib as pl
7
+ import schubmult.poly_lib.variables as vv
8
+
9
+ # import vv.GeneratingSet, vv.base_index
10
+
11
+
12
+ # Indexed._sympystr = lambda x, p: f"{p.doprint(x.args[0])}_{x.args[1]}"
13
def expand(val):
    """Return ``val`` with all products and powers multiplied out.

    Thin wrapper around :func:`symengine.expand`.
    """
    expanded = symengine.expand(val)
    return expanded
15
+
16
+
17
class _gvars:
    """Module-wide default generating sets, created lazily.

    Every attribute is a ``functools.cached_property``: it is built on first
    access and stored on the instance afterwards.
    """

    @cached_property
    def n(self):
        """Default variable count."""
        return 100

    @cached_property
    def var1(self):
        """Generating set named ``x``."""
        return vv.GeneratingSet("x")

    @cached_property
    def var2(self):
        """Generating set named ``y``."""
        return vv.GeneratingSet("y")

    @cached_property
    def var3(self):
        """Generating set named ``z``."""
        return vv.GeneratingSet("z")

    @cached_property
    def var_r(self):
        """Generating set named ``r``."""
        return vv.GeneratingSet("r")

    @cached_property
    def var_g1(self):
        """First 'generic' set; also named ``y`` (see ``xreplace_genvars``)."""
        return vv.GeneratingSet("y")

    @cached_property
    def var_g2(self):
        """Second 'generic' set; also named ``z`` (see ``xreplace_genvars``)."""
        return vv.GeneratingSet("z")

    @cached_property
    def q_var(self):
        """Quantum parameters, named ``q``."""
        return vv.GeneratingSet("q")


zero = sympify(0)

_vars = _gvars()
58
+
59
def act(w, poly, genset):
    """Permute the generators of ``genset`` that occur in ``poly`` by ``w``.

    Each free symbol of ``poly`` found at index ``j`` of ``genset`` is
    replaced with ``genset[w(j)]``; other symbols are untouched.

    ``w`` may be a ``Permutation`` or anything ``Permutation`` accepts.
    ``genset`` may be a ``GeneratingSet_base`` or an iterable of generators,
    which is wrapped in a ``CustomGeneratingSet``.
    """
    perm = w if isinstance(w, pl.Permutation) else pl.Permutation(w)
    gens = genset if isinstance(genset, vv.GeneratingSet_base) else vv.CustomGeneratingSet(genset)
    replacements = {}
    for sym in poly.free_symbols:
        pos = gens.index(sym)
        if pos == -1:
            continue
        replacements[sym] = gens[perm(pos)]
    return efficient_subs(poly, replacements)
69
+
70
def elem_sym_func(k, i, u1, u2, v1, v2, udiff, vdiff, varl1, varl2):
    """Elementary-symmetric factor for one step of the double-Schubert product.

    With ``newk = k - udiff``: returns ``zero`` if ``newk < vdiff``, ``one``
    if they are equal, and otherwise
    ``elem_sym_poly(newk - vdiff, newk, yvars, zvars)``.

    ``yvars`` takes ``varl1`` entries from the first ``k`` positions where
    ``u1`` and ``u2`` agree. Past ``len(u1)``, positions where ``u2`` is fixed
    are used. Past ``len(u2)``, ``varl1[j + 1]`` is used.

    ``zvars`` takes ``varl2`` entries at the indices returned by
    ``call_zvars(v1, v2, k, i)``.
    """
    newk = k - udiff
    if newk < vdiff:
        return zero
    if newk == vdiff:
        return one
    len_u1, len_u2 = len(u1), len(u2)
    yvars = [varl1[u2[j]] for j in range(min(len_u1, k)) if u1[j] == u2[j]]
    yvars.extend(varl1[u2[j]] for j in range(len_u1, min(k, len_u2)) if u2[j] == j + 1)
    yvars.extend(varl1[j + 1] for j in range(len_u2, k))
    zvars = [varl2[idx] for idx in call_zvars(v1, v2, k, i)]
    return elem_sym_poly(newk - vdiff, newk, yvars, zvars)
87
+
88
+
89
+ # def elem_sym_func_q(k, i, u1, u2, v1, v2, udiff, vdiff, varl1, varl2):
90
+ # newk = k - udiff
91
+ # if newk < vdiff:
92
+ # return zero
93
+ # if newk == vdiff:
94
+ # return one
95
+ # yvars = []
96
+ # mlen = max(len(u1), len(u2))
97
+ # u1 = [*u1] + [a + 1 for a in range(len(u1), mlen)]
98
+ # u2 = [*u2] + [a + 1 for a in range(len(u2), mlen)]
99
+ # for j in range(min(len(u1), k)):
100
+ # if u1[j] == u2[j]:
101
+ # yvars += [varl1[u2[j]]]
102
+ # for j in range(len(u1), min(k, len(u2))):
103
+ # if u2[j] == j + 1:
104
+ # yvars += [varl1[u2[j]]]
105
+ # for j in range(len(u2), k):
106
+ # yvars += [varl1[j + 1]]
107
+ # zvars = [varl2[a] for a in call_zvars(v1, v2, k, i)]
108
+ # return elem_sym_poly(newk - vdiff, newk, yvars, zvars)
109
+
110
+
111
def elem_sym_func_q(k, i, u1, u2, v1, v2, udiff, vdiff, varl1, varl2):
    """Quantum variant of ``elem_sym_func``.

    Uses the same ``zero``/``one`` guards on ``newk = k - udiff`` versus
    ``vdiff``.

    The y-variables come only from positions ``j < k`` where
    ``u1[j] == u2[j]``, so ``u1`` and ``u2`` must be indexable up to
    ``k - 1``.
    """
    newk = k - udiff
    if newk < vdiff:
        return zero
    if newk == vdiff:
        return one
    yvars = [varl1[u2[j]] for j in range(k) if u1[j] == u2[j]]
    zvars = [varl2[a] for a in call_zvars(v1, v2, k, i)]
    return elem_sym_poly(newk - vdiff, newk, yvars, zvars)


one = sympify(1)
137
+
138
+
139
def elem_sym_poly_q(p, k, varl1, varl2, q_var=_vars.q_var):
    """Quantum elementary symmetric polynomial of degree ``p`` in ``k`` variables.

    Defined by the recurrence on index ``k - 1``:

    - ``(varl1[k-1] - varl2[k-p]) * E(p-1, k-1)``
    - ``+ E(p, k-1)``
    - ``+ q_var[k-1] * E(p-2, k-2)``

    Base cases: ``one`` when ``p == 0`` and ``k >= 0``; ``zero`` when
    ``p < 0`` or ``p > k``.
    """
    if p == 0 and k >= 0:
        return one
    if p < 0 or p > k:
        return zero
    linear_term = (varl1[k - 1] - varl2[k - p]) * elem_sym_poly_q(p - 1, k - 1, varl1, varl2, q_var)
    drop_term = elem_sym_poly_q(p, k - 1, varl1, varl2, q_var)
    quantum_term = q_var[k - 1] * elem_sym_poly_q(p - 2, k - 2, varl1, varl2, q_var)
    return linear_term + drop_term + quantum_term
149
+
150
+
151
def elem_sym_poly(p, k, varl1, varl2, xstart=0, ystart=0):
    """Degree-``p`` elementary-symmetric-style polynomial in ``k`` variable pairs.

    Works on the windows ``varl1[xstart:]`` and ``varl2[ystart:]``.

    Base cases visible in the code:

    - ``p > k`` or ``p < 0``: ``zero``
    - ``p == 0``: ``one``
    - ``p == 1``: ``sum_{t<k} (varl1[xstart+t] - varl2[ystart+t])``
    - ``p == k``: ``prod_{t<k} (varl1[xstart+t] - varl2[ystart])``

    Otherwise the ``k`` variables are split at ``mid = k // 2`` and the
    halves are combined, shifting the second half's y-start by ``p2``.
    """
    if p > k:
        return zero
    if p == 0:
        return one
    if p < 0:
        # Without this guard a negative degree falls through to the split
        # below and recurses forever once k reaches 0. Returning zero matches
        # elem_sym_poly_q.
        return zero
    if p == 1:
        res = varl1[xstart] - varl2[ystart]
        for i in range(1, k):
            res += varl1[xstart + i] - varl2[ystart + i]
        return res
    if p == k:
        res = (varl1[xstart] - varl2[ystart]) * (varl1[xstart + 1] - varl2[ystart])
        for i in range(2, k):
            res *= varl1[i + xstart] - varl2[ystart]
        return res
    mid = k // 2
    xsm = xstart + mid
    ysm = ystart + mid
    kmm = k - mid
    res = elem_sym_poly(p, mid, varl1, varl2, xstart, ystart) + elem_sym_poly(
        p,
        kmm,
        varl1,
        varl2,
        xsm,
        ysm,
    )
    # Mixed terms: degree p2 from the first half, p - p2 from the second.
    for p2 in range(max(1, p - kmm), min(p, mid + 1)):
        res += elem_sym_poly(p2, mid, varl1, varl2, xstart, ystart) * elem_sym_poly(
            p - p2,
            kmm,
            varl1,
            varl2,
            xsm,
            ysm - p2,
        )
    return res
188
+
189
+
190
+ # def call_zvars(v1, v2, k, i):
191
+ # v3 = [*v2, *list(range(len(v2) + 1, i + 1))]
192
+ # return [v3[i - 1]] + [v3[j] for j in range(len(v1), len(v3)) if v3[j] != j + 1 and j != i - 1] + [v3[j] for j in range(len(v1)) if v1[j] != v3[j] and j != i - 1]
193
+
194
+
195
@cache
def call_zvars(v1, v2, k, i):  # noqa: ARG001
    """Indices into the z-variables used by ``elem_sym_func``.

    The result is the concatenation of:

    - ``v2[i - 1]``.
    - ``v2[j]`` for ``len(v1) <= j < max(len(v2), i)``, where
      ``v2[j] != j + 1`` and ``j != i - 1``.
    - ``v2[j]`` for ``j < len(v1)``, where ``v1[j] != v2[j]`` and
      ``j != i - 1``.

    ``k`` is unused. Results are memoised, so the same list object is
    returned on repeated calls; callers should not mutate it.
    """
    skip = i - 1
    len_v1 = len(v1)
    upper = max(len(v2), i)
    pivot = [v2[skip]]
    beyond_v1 = [v2[j] for j in range(len_v1, upper) if v2[j] != j + 1 and j != skip]
    changed = [v2[j] for j in range(len_v1) if v1[j] != v2[j] and j != skip]
    return pivot + beyond_v1 + changed
198
+
199
+
200
def efficient_subs(expr, subs_dict):
    """Substitute into ``expr`` using only the entries of ``subs_dict`` it needs.

    ``expr`` is converted with symengine's ``sympify``. ``subs_dict`` is
    filtered to keys that are free symbols of ``expr`` before ``subs`` is
    called.
    """
    expr = sympify(expr)
    relevant = {sym: subs_dict[sym] for sym in expr.free_symbols if sym in subs_dict}
    return expr.subs(relevant)
207
+
208
+
209
def q_vector(q_exp, q_var=_vars.q_var):
    """Exponent vector of a monomial in the quantum variables ``q_var``.

    A generator at index ``i`` of ``q_var`` maps to position ``i - 1``, so
    ``q_var[1]`` gives ``[1]`` and ``q_var[2]**3`` gives ``[0, 3]``.

    - Products are handled component-wise by recursion.
    - ``1`` yields ``[]``.
    - A ``Pow`` whose base is not in ``q_var`` raises ``IndexError``.
    - Any other input returns ``None``.
    """
    ret = []

    if q_exp == 1:
        return ret
    i = q_var.index(q_exp)
    if i != -1:
        return [0] * (i - 1) + [1]
    if isinstance(q_exp, Pow):
        qv = q_exp.args[0]
        expon = int(q_exp.args[1])
        i = q_var.index(qv)
        if i == -1:
            raise IndexError
        return [0] * (i - 1) + [expon]
    if isinstance(q_exp, Mul):
        for a in q_exp.args:
            # Forward q_var: the previous code recursed with the default
            # generating set, silently ignoring a caller-supplied q_var.
            v1 = q_vector(a, q_var)
            # Pad both vectors to a common length before adding them.
            v1 += [0] * (len(ret) - len(v1))
            ret += [0] * (len(v1) - len(ret))
            ret = [r + s for r, s in zip(ret, v1)]
        return ret

    return None
234
+
235
+
236
def xreplace_genvars(poly, vars1, vars2):
    """Rename the generic generators in ``poly``.

    Symbols from ``_vars.var_g1`` become the same-index entries of ``vars1``.
    Symbols from ``_vars.var_g2`` become the same-index entries of ``vars2``.
    ``var_g1`` is checked first.
    """
    expr = sympify(poly)
    first, second = _vars.var_g1, _vars.var_g2
    mapping = {}
    for sym in expr.free_symbols:
        pos = first.index(sym)
        if pos != -1:
            mapping[sym] = vars1[pos]
            continue
        pos = second.index(sym)
        if pos != -1:
            mapping[sym] = vars2[pos]
    return expr.xreplace(mapping)
@@ -0,0 +1,148 @@
1
+ import sympy
2
+
3
+ import schubmult.perm_lib as pl
4
+ import schubmult.schub_lib.schub_lib as schub_lib
5
+ from schubmult.poly_lib.poly_lib import call_zvars
6
+ from schubmult.poly_lib.variables import GeneratingSet
7
+
8
+
9
def perm_act(val, i, var2=None):
    """Swap the generators ``var2[i]`` and ``var2[i + 1]`` in ``val``.

    ``val`` is converted with :func:`sympy.sympify`. The swap has to be
    simultaneous. Sympy's ``subs`` applies a dict one pair at a time by
    default, so ``{a: b, b: a}`` would rewrite ``a`` to ``b`` and then every
    ``b``, including the new ones, back again. That collapses both
    generators into one.
    """
    subsdict = {var2[i]: var2[i + 1], var2[i + 1]: var2[i]}
    return sympy.sympify(val).subs(subsdict, simultaneous=True)
12
+
13
+
14
def div_diff(i, poly, var2=None):
    """Divided difference of ``poly`` with respect to ``var2[i]`` and ``var2[i + 1]``.

    Computes ``(poly - perm_act(poly, i, var2)) / (var2[i] - var2[i + 1])``.
    The division is done with :func:`sympy.div`, and only the quotient is
    returned.
    """
    # Forward var2: the previous call perm_act(poly, i) left var2=None, so
    # the swap always failed by indexing None.
    numerator = sympy.sympify(poly - perm_act(poly, i, var2))
    denominator = sympy.sympify(var2[i] - var2[i + 1])
    quotient = sympy.div(numerator, denominator)[0]
    return sympy.sympify(quotient)
18
+
19
+
20
def elem_func_func(k, i, v1, v2, vdiff, varl1, varl2, elem_func):
    """One path-step factor for ``schubpoly_from_elems``.

    Returns:

    - ``0`` if ``k < vdiff``.
    - ``1`` if ``k == vdiff``.
    - Otherwise ``elem_func(k - vdiff, k, varl1[1:], zvars)``, where
      ``zvars`` are the ``varl2`` entries selected by ``call_zvars``.
    """
    if k < vdiff:
        return 0
    if k == vdiff:
        return 1
    zvars = [varl2[z] for z in call_zvars(v1, v2, k, i)]
    return elem_func(k - vdiff, k, varl1[1:], zvars)
28
+
29
+
30
def schubpoly_from_elems(v, var_x=None, var_y=None, elem_func=None, mumu=None):
    """Expand the polynomial indexed by ``v`` as a sum over v-paths of ``elem_func`` factors.

    If ``mumu`` is given, its code is used as ``th``. Otherwise
    ``th = strict_theta(~v)`` and ``mu = uncode(th)``.

    Partial sums are propagated along ``schub_lib.compute_vpathdicts(th, v*mu)``.
    Each step multiplies by ``elem_func_func(...)``. The function returns the
    sum accumulated at ``v*mu``, or ``0`` if none is.
    """
    if mumu:
        th = pl.code(mumu)
        mu = mumu
    else:
        th = pl.strict_theta(~pl.Permutation(v))
        mu = pl.uncode(th)
    vmu = pl.Permutation(v) * mu
    if len(th) == 0:
        return elem_func(0, 0, var_x, var_y)
    while len(th) > 0 and th[-1] == 0:
        th.pop()
    vpathdicts = schub_lib.compute_vpathdicts(th, vmu)
    vpathsums = {pl.Permutation([1, 2]): elem_func(0, 0, var_x, var_y)}
    for index in range(len(th)):
        # (The former unused `mx_th` pre-scan over vpathdicts was dead code.)
        step = vpathdicts[index]
        newpathsums = {}
        for u in step:
            sumval = vpathsums.get(u, 0)
            if sumval == 0:
                continue
            for v2, vdiff, s in step[u]:
                newpathsums[v2] = newpathsums.get(v2, 0) + s * sumval * elem_func_func(
                    th[index],
                    index + 1,
                    u,
                    v2,
                    vdiff,
                    var_x,
                    var_y,
                    elem_func=elem_func,
                )
        vpathsums = newpathsums
    return vpathsums.get(vmu, 0)
72
+
73
def schubpoly_classical_from_elems(v, var_x=None, var_y=None, elem_func=None):
    """Like ``schubpoly_from_elems``, but with ``th = theta(~v)``.

    ``theta`` is used instead of ``strict_theta``, and there is no ``mumu``
    override. The function returns the path sum accumulated at ``v*mu``, or
    ``0`` if none is.
    """
    th = pl.theta(~pl.Permutation(v))
    mu = pl.uncode(th)
    vmu = pl.Permutation(v) * mu
    if len(th) == 0:
        return elem_func(0, 0, var_x, var_y)
    while len(th) > 0 and th[-1] == 0:
        th.pop()
    vpathdicts = schub_lib.compute_vpathdicts(th, vmu)
    vpathsums = {pl.Permutation([1, 2]): elem_func(0, 0, var_x, var_y)}
    for index in range(len(th)):
        # (The former unused `mx_th` pre-scan over vpathdicts was dead code.)
        step = vpathdicts[index]
        newpathsums = {}
        for u in step:
            sumval = vpathsums.get(u, 0)
            if sumval == 0:
                continue
            for v2, vdiff, s in step[u]:
                newpathsums[v2] = newpathsums.get(v2, 0) + s * sumval * elem_func_func(
                    th[index],
                    index + 1,
                    u,
                    v2,
                    vdiff,
                    var_x,
                    var_y,
                    elem_func=elem_func,
                )
        vpathsums = newpathsums
    return vpathsums.get(vmu, 0)
109
+
110
+
111
def schubpoly(v, var2=GeneratingSet("y"), var3=GeneratingSet("z"), start_var=1):
    """Recursively build the polynomial for ``v`` by pulling out its last descent variable.

    ``n`` is the 1-based position of the last descent of ``v``; ``1`` is
    returned if ``v`` has none.

    For each ``(pw, vp)`` in ``schub_lib.pull_out_var(n, v)``, the function
    adds ``prod_{p in pw} (var2[start_var + n - 1] - var3[p])`` times
    ``schubpoly(vp, ...)``.
    """
    last_descent = next((j + 1 for j in range(len(v) - 2, -1, -1) if v[j] > v[j + 1]), 0)
    if last_descent == 0:
        return 1
    total = 0
    for pw, vp in schub_lib.pull_out_var(last_descent, v):
        factor = 1
        for p in pw:
            factor *= var2[start_var + last_descent - 1] - var3[p]
        total += factor * schubpoly(vp, var2, var3, start_var)
    return total
127
+
128
+
129
def skew_div_diff(u, w, poly, var2=None):
    """Apply the skew divided-difference operator determined by ``(u, w)`` to ``poly``.

    Let ``d`` be the first descent of ``w`` (0-based).

    - If ``w`` has no descent, return ``poly`` when ``u`` also has none,
      and ``0`` otherwise.
    - If ``u`` also descends at ``d``, apply ``perm_act`` at ``d + 1`` and
      recurse with both ``u`` and ``w`` swapped at ``d``.
    - Otherwise apply ``div_diff`` at ``d + 1`` and recurse with only ``w``
      swapped.

    ``var2`` is the generating set the operators act on. It is forwarded to
    ``perm_act`` and ``div_diff``. Previously it was never passed, so both
    helpers indexed ``None`` and failed.
    """
    d = next((i for i in range(len(w) - 1) if w[i] > w[i + 1]), -1)
    if d == -1:
        d2 = next((i for i in range(len(u) - 1) if u[i] > u[i + 1]), -1)
        return poly if d2 == -1 else 0
    w2 = w.swap(d, d + 1)
    if d < len(u) - 1 and u[d] > u[d + 1]:
        u2 = u.swap(d, d + 1)
        return skew_div_diff(u2, w2, perm_act(poly, d + 1, var2), var2)
    return skew_div_diff(u, w2, div_diff(d + 1, poly, var2), var2)
@@ -0,0 +1,204 @@
1
+ # class generators with base
2
+ # symbols cls argument!
3
+
4
+ import re
5
+ from bisect import bisect_left
6
+ from functools import cache
7
+ from typing import ClassVar
8
+
9
+ from symengine import SympifyError, symbols, sympify
10
+ from sympy import Basic, Tuple
11
+ from sympy.core.symbol import Str
12
+
13
+ from schubmult.utils.logging import get_logger
14
+
15
logger = get_logger(__name__)


class GeneratingSet_base(Basic):
    """Abstract base for indexable collections of generator symbols.

    ``GeneratingSet``, ``MaskedGeneratingSet`` and ``CustomGeneratingSet``
    override ``__getitem__`` and ``__len__``.
    """

    def __new__(cls, *args):
        return Basic.__new__(cls, *args)

    def __getitem__(self, i):
        # This used to return the ``NotImplemented`` sentinel, which only has
        # meaning for binary operators. Callers got a bogus "generator" back.
        # On a class with no __iter__, Python's legacy iteration protocol,
        # which probes __getitem__ until IndexError, would never terminate.
        raise NotImplementedError(f"{type(self).__name__} must implement __getitem__")

    def __len__(self):
        # Returning NotImplemented made len() fail with an opaque TypeError.
        raise NotImplementedError(f"{type(self).__name__} must implement __len__")
+
28
+
29
+ # variable registry
30
+ # TODO: ensure sympifies
31
+ # TODO: masked generating set
32
class GeneratingSet(GeneratingSet_base):
    """Named family of 100 symengine symbols ``{name}_0`` ... ``{name}_99``.

    Construction is memoised per ``(class, name)`` through
    ``__xnew_cached__``, so constructing the same name twice returns the same
    object. ``gs[i]`` is 0-based; ``gs(i)`` is 1-based.
    """

    def __new__(cls, name):
        return GeneratingSet.__xnew_cached__(cls, name)

    _registry: ClassVar = {}

    _index_pattern = re.compile("^([^_]+)_([0-9]+)$")
    _sage_index_pattern = re.compile("^([^0-9]+)([0-9]+)$")

    # TODO: masked generating set
    @staticmethod
    @cache
    def __xnew_cached__(_class, name):
        return GeneratingSet.__xnew__(_class, Str(str(name)))

    @staticmethod
    def __xnew__(_class, name):
        obj = GeneratingSet_base.__new__(_class, name)
        # Fixed-size table of symbols plus a reverse lookup: symbol -> index.
        obj._symbols_arr = tuple([symbols(f"{name}_{i}") for i in range(100)])
        obj._index_lookup = {obj._symbols_arr[i]: i for i in range(len(obj._symbols_arr))}
        return obj

    def __call__(self, index):
        """Return the generator at 1-based position ``index``."""
        return self[index - 1]

    @property
    def label(self):
        """The set's name, as a plain ``str``."""
        return str(self.args[0])

    def index(self, v):
        """Return the 0-based position of ``v`` in this set, or -1 if absent.

        ``v`` is looked up as given, then after symengine ``sympify``. A value
        that cannot be sympified yields -1.
        """
        try:
            return self._index_lookup.get(v, self._index_lookup.get(sympify(v), -1))
        except SympifyError:
            return -1

    def __repr__(self):
        return f"GeneratingSet('{self.label}')"

    def __str__(self):
        # Previously returned ``self.name``, an attribute this class never
        # defines, so str(gs) raised AttributeError. Use the label, matching
        # _sympystr.
        return self.label

    def _latex(self, printer):
        return printer.doprint(self.label)

    def _sympystr(self, printer):
        return printer.doprint(self.label)

    def __getitem__(self, i):
        return self._symbols_arr[i]

    def __len__(self):
        return len(self._symbols_arr)

    def __hash__(self):
        return hash(self.label)

    def __iter__(self):
        yield from [self[i] for i in range(len(self))]

    def __eq__(self, other):
        return isinstance(other, GeneratingSet) and self.label == other.label
96
+
97
+
98
class MaskedGeneratingSet(GeneratingSet_base):
    """View of a generating set with the indices in ``index_mask`` removed.

    Position 0 always maps to ``gset[0]``. Position ``j >= 1`` maps to the
    ``j``-th index of ``gset`` (counting from 1) that is not masked.
    Instances are memoised per ``(class, gset, sorted(index_mask))``.
    """

    def __new__(cls, gset, index_mask):
        return MaskedGeneratingSet.__xnew_cached__(cls, gset, tuple(sorted(index_mask)))

    @staticmethod
    @cache
    def __xnew_cached__(_class, gset, index_mask):
        return MaskedGeneratingSet.__xnew__(_class, gset, index_mask)

    @staticmethod
    def __xnew__(_class, gset, index_mask):
        obj = GeneratingSet_base.__new__(_class, gset, Tuple(*index_mask))
        # mask_dict: position in this view -> index in gset.
        mask_dict = {}
        mask_dict[0] = 0
        for i in range(1, len(gset._symbols_arr)):
            # `index` = how many masked indices are < i; i itself is masked
            # exactly when index_mask[index] == i.
            index = bisect_left(index_mask, i)
            if index >= len(index_mask) or index_mask[index] != i:
                mask_dict[i - index] = i
        obj._mask = mask_dict
        obj._index_lookup = {gset[mask_dict[i]]: i for i in range(len(gset) - len(index_mask))}
        obj._label = gset.label
        return obj

    @property
    def label(self):
        return str(self._label)

    def set_label(self, label):
        # NOTE(review): instances are memoised, so this relabels every holder
        # of the same (gset, mask) view, not just the caller's.
        self._label = label

    @property
    def index_mask(self):
        return tuple(self.args[1])

    def complement(self):
        """Masked view keeping exactly the indices (>= 1) this view removes."""
        # Build the membership set once. It was previously rebuilt for every
        # candidate index inside the comprehension.
        masked = set(self.index_mask)
        return MaskedGeneratingSet(self.base_genset, [i for i in range(1, len(self.base_genset)) if i not in masked])

    @property
    def base_genset(self):
        return self.args[0]

    def __call__(self, index):
        """Return the generator at 1-based position ``index``."""
        return self[index - 1]

    def __getitem__(self, index):
        if isinstance(index, slice):
            # slice.indices normalises None/negative bounds and step, clamped
            # to len(self), matching list slicing.
            return [self[ii] for ii in range(*index.indices(len(self)))]
        return self.base_genset[self._mask[index]]

    def __iter__(self):
        yield from [self[i] for i in range(len(self))]

    def index(self, v):
        """Return the position of ``v`` in this view, or -1 if absent."""
        try:
            return self._index_lookup.get(v, self._index_lookup.get(sympify(v), -1))
        except SympifyError:
            return -1

    def __len__(self):
        return len(self.base_genset) - len(self.index_mask)
169
+
170
+
171
class CustomGeneratingSet(GeneratingSet_base):
    """Generating set built from an explicit, finite sequence of generators.

    Each generator is converted with symengine ``sympify``. Instances are
    memoised per ``(class, tuple(gens))``.
    """

    def __new__(cls, gens):
        return CustomGeneratingSet.__xnew_cached__(cls, tuple(gens))

    @staticmethod
    @cache
    def __xnew_cached__(_class, gens):
        return CustomGeneratingSet.__xnew__(_class, gens)

    @staticmethod
    def __xnew__(_class, gens):
        obj = GeneratingSet_base.__new__(_class, Tuple(*gens))
        converted = [sympify(g) for g in gens]
        obj._symbols_arr = converted
        # Reverse lookup; for duplicate generators the last position wins.
        obj._index_lookup = {sym: pos for pos, sym in enumerate(converted)}
        return obj

    def __call__(self, index):
        """Return the generator at 1-based position ``index``."""
        return self[index - 1]

    def __getitem__(self, index):
        return self._symbols_arr[index]

    def __len__(self):
        return len(self.args[0])

    def __iter__(self):
        snapshot = [self[pos] for pos in range(len(self))]
        yield from snapshot

    def index(self, v):
        """Return the 0-based position of ``v``, or -1 if it is not a generator here."""
        lookup = self._index_lookup
        try:
            fallback = lookup.get(sympify(v), -1)
        except SympifyError:
            return -1
        return lookup.get(v, fallback)
@@ -0,0 +1,17 @@
1
+ from ._quantum_schubert_polynomial_ring import QDSx, QPDSx, QSx, QuantumDoubleSchubertAlgebraElement, QuantumDoubleSchubertAlgebraElement_basis, make_parabolic_quantum_basis
2
+ from ._schubert_polynomial_ring import DoubleSchubertAlgebraElement, DoubleSchubertAlgebraElement_basis, DSx, Sx
3
+ from ._utils import poly_ring
4
+
5
+ __all__ = [
6
+ "DSx",
7
+ "DoubleSchubertAlgebraElement",
8
+ "DoubleSchubertAlgebraElement_basis",
9
+ "QDSx",
10
+ "QPDSx",
11
+ "QSx",
12
+ "QuantumDoubleSchubertAlgebraElement",
13
+ "QuantumDoubleSchubertAlgebraElement_basis",
14
+ "Sx",
15
+ "make_parabolic_quantum_basis",
16
+ "poly_ring",
17
+ ]