schubmult 2.0.2__py3-none-any.whl → 2.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. schubmult/__init__.py +1 -1
  2. schubmult/_base_argparse.py +42 -8
  3. schubmult/_tests.py +24 -0
  4. schubmult/perm_lib.py +52 -112
  5. schubmult/sage_integration/__init__.py +13 -13
  6. schubmult/sage_integration/_fast_double_schubert_polynomial_ring.py +139 -118
  7. schubmult/sage_integration/_fast_schubert_polynomial_ring.py +88 -49
  8. schubmult/sage_integration/_indexing.py +35 -32
  9. schubmult/schubmult_double/__init__.py +6 -12
  10. schubmult/schubmult_double/__main__.py +2 -1
  11. schubmult/schubmult_double/_funcs.py +245 -281
  12. schubmult/schubmult_double/_script.py +128 -70
  13. schubmult/schubmult_py/__init__.py +5 -3
  14. schubmult/schubmult_py/__main__.py +2 -1
  15. schubmult/schubmult_py/_funcs.py +68 -23
  16. schubmult/schubmult_py/_script.py +40 -58
  17. schubmult/schubmult_q/__init__.py +3 -7
  18. schubmult/schubmult_q/__main__.py +2 -1
  19. schubmult/schubmult_q/_funcs.py +41 -60
  20. schubmult/schubmult_q/_script.py +39 -30
  21. schubmult/schubmult_q_double/__init__.py +5 -11
  22. schubmult/schubmult_q_double/__main__.py +2 -1
  23. schubmult/schubmult_q_double/_funcs.py +99 -66
  24. schubmult/schubmult_q_double/_script.py +209 -150
  25. schubmult-2.0.4.dist-info/METADATA +542 -0
  26. schubmult-2.0.4.dist-info/RECORD +30 -0
  27. {schubmult-2.0.2.dist-info → schubmult-2.0.4.dist-info}/WHEEL +1 -1
  28. schubmult-2.0.4.dist-info/entry_points.txt +5 -0
  29. {schubmult-2.0.2.dist-info → schubmult-2.0.4.dist-info}/top_level.txt +0 -1
  30. schubmult/schubmult_double/_vars.py +0 -18
  31. schubmult/schubmult_py/_vars.py +0 -3
  32. schubmult/schubmult_q/_vars.py +0 -18
  33. schubmult/schubmult_q_double/_vars.py +0 -21
  34. schubmult-2.0.2.dist-info/METADATA +0 -455
  35. schubmult-2.0.2.dist-info/RECORD +0 -36
  36. schubmult-2.0.2.dist-info/entry_points.txt +0 -5
  37. tests/__init__.py +0 -0
  38. tests/test_fast_double_schubert.py +0 -145
  39. tests/test_fast_schubert.py +0 -38
  40. {schubmult-2.0.2.dist-info → schubmult-2.0.4.dist-info}/licenses/LICENSE +0 -0
@@ -1,31 +1,69 @@
1
1
  import sys
2
+ from functools import cached_property
3
+
2
4
  import numpy as np
3
- from ._vars import var_x, var2, var3, var2_t, var3_t
4
- from ._funcs import schubmult, schubmult_db, mult_poly, nil_hecke, factor_out_q_keep_factored
5
- from schubmult.schubmult_double import compute_positive_rep, posify, div_diff
6
- from symengine import expand, sympify, symarray
5
+ from symengine import expand, symarray, sympify
6
+
7
+ from schubmult._base_argparse import schub_argparse
7
8
  from schubmult.perm_lib import (
9
+ check_blocks,
10
+ code,
11
+ count_less_than,
12
+ inv,
8
13
  inverse,
14
+ is_parabolic,
15
+ longest_element,
9
16
  medium_theta,
10
- permtrim,
11
- inv,
12
17
  mulperm,
13
- uncode,
18
+ omega,
19
+ permtrim,
14
20
  q_var,
15
21
  q_vector,
16
22
  reduce_q_coeff,
17
- code,
18
23
  trimcode,
19
- longest_element,
20
- check_blocks,
21
- is_parabolic,
22
- count_less_than,
23
- omega,
24
+ uncode,
24
25
  )
25
- from schubmult._base_argparse import schub_argparse
26
+ from schubmult.schubmult_double import compute_positive_rep, div_diff, posify
27
+ from schubmult.schubmult_q_double._funcs import (
28
+ factor_out_q_keep_factored,
29
+ # mult_poly,
30
+ nil_hecke,
31
+ schubmult,
32
+ schubmult_db,
33
+ )
34
+
35
+
36
+ class _gvars:
37
+ @cached_property
38
+ def n(self):
39
+ return 100
40
+
41
+ # @cached_property
42
+ # def fvar(self):
43
+ # return 100
44
+
45
+ @cached_property
46
+ def var1(self):
47
+ return tuple(symarray("x", self.n).tolist())
48
+
49
+ @cached_property
50
+ def var2(self):
51
+ return tuple(symarray("y", self.n).tolist())
52
+
53
+ @cached_property
54
+ def var3(self):
55
+ return tuple(symarray("z", self.n).tolist())
56
+
57
+ @cached_property
58
+ def var_r(self):
59
+ return symarray("r", 100)
60
+
26
61
 
62
+ _vars = _gvars()
27
63
 
28
- def _display_full(coeff_dict, args, formatter, posified=None, var2=var2, var3=var3):
64
+
65
+ def _display_full(coeff_dict, args, formatter, posified=None, var2=_vars.var2, var3=_vars.var3):
66
+ raw_result_dict = {}
29
67
  mult = args.mult
30
68
 
31
69
  perms = args.perms
@@ -37,12 +75,17 @@ def _display_full(coeff_dict, args, formatter, posified=None, var2=var2, var3=va
37
75
  display_positive = args.display_positive
38
76
  expa = args.expa
39
77
  slow = args.slow
40
- nilhecke_apply = False
78
+ nilhecke_apply = False
79
+ subs_dict2 = {}
80
+ for i in range(1, 100):
81
+ sm = var2[1]
82
+ for j in range(1, i):
83
+ sm += _vars.var_r[j]
84
+ subs_dict2[var2[i]] = sm
41
85
 
42
86
  coeff_perms = list(coeff_dict.keys())
43
87
  coeff_perms.sort(key=lambda x: (inv(x), *x))
44
88
 
45
- var_r = symarray("r", 100)
46
89
  for perm in coeff_perms:
47
90
  val = coeff_dict[perm]
48
91
  if expand(val) != 0:
@@ -50,109 +93,111 @@ def _display_full(coeff_dict, args, formatter, posified=None, var2=var2, var3=va
50
93
  int(val)
51
94
  except Exception:
52
95
  val2 = 0
53
- if display_positive and not posified and not same:
96
+ if display_positive and not posified:
54
97
  q_dict = factor_out_q_keep_factored(val)
55
98
  for q_part in q_dict:
56
99
  try:
57
100
  val2 += q_part * int(q_dict[q_part])
58
101
  except Exception:
59
- try:
60
- if len(perms) == 2:
61
- u = tuple(permtrim([*perms[0]]))
62
- v = tuple(permtrim([*perms[1]]))
63
- if (
64
- len(perms) == 2
65
- and code(inverse(perms[1])) == medium_theta(inverse(perms[1]))
66
- and not mult
67
- and not slow
68
- and not nilhecke_apply
69
- ):
70
- val2 += q_part * q_dict[q_part]
71
- else:
72
- q_part2 = q_part
73
- if not mult and not nilhecke_apply and len(perms) == 2:
74
- qv = q_vector(q_part)
75
- u2, v2, w2 = u, v, perm
76
- u2, v2, w2, qv, did_one = reduce_q_coeff(u2, v2, w2, qv)
77
- while did_one:
102
+ if same:
103
+ to_add = q_part * expand(sympify(q_dict[q_part]).xreplace(subs_dict2))
104
+ val2 += to_add
105
+ else:
106
+ try:
107
+ if len(perms) == 2:
108
+ u = tuple(permtrim([*perms[0]]))
109
+ v = tuple(permtrim([*perms[1]]))
110
+ if (
111
+ len(perms) == 2
112
+ and code(inverse(perms[1])) == medium_theta(inverse(perms[1]))
113
+ and not mult
114
+ and not slow
115
+ and not nilhecke_apply
116
+ ):
117
+ val2 += q_part * q_dict[q_part]
118
+ else:
119
+ q_part2 = q_part
120
+ if not mult and not nilhecke_apply and len(perms) == 2:
121
+ qv = q_vector(q_part)
122
+ u2, v2, w2 = u, v, perm
78
123
  u2, v2, w2, qv, did_one = reduce_q_coeff(u2, v2, w2, qv)
79
- q_part2 = np.prod(
80
- [q_var[i + 1] ** qv[i] for i in range(len(qv))]
81
- )
82
- if q_part2 == 1:
83
- # reduced to classical coefficient
84
- val2 += q_part * posify(
85
- q_dict[q_part],
86
- u2,
87
- v2,
88
- w2,
89
- var2_t,
90
- var3_t,
91
- msg,
92
- False,
124
+ while did_one:
125
+ u2, v2, w2, qv, did_one = reduce_q_coeff(u2, v2, w2, qv)
126
+ q_part2 = np.prod(
127
+ [q_var[i + 1] ** qv[i] for i in range(len(qv))],
93
128
  )
129
+ if q_part2 == 1:
130
+ # reduced to classical coefficient
131
+ val2 += q_part * posify(
132
+ q_dict[q_part],
133
+ u2,
134
+ v2,
135
+ w2,
136
+ var2,
137
+ var3,
138
+ msg,
139
+ False,
140
+ )
141
+ else:
142
+ val2 += q_part * compute_positive_rep(
143
+ q_dict[q_part],
144
+ var2,
145
+ var3,
146
+ msg,
147
+ False,
148
+ )
94
149
  else:
95
150
  val2 += q_part * compute_positive_rep(
96
151
  q_dict[q_part],
97
- var2_t,
98
- var3_t,
152
+ var2,
153
+ var3,
99
154
  msg,
100
155
  False,
101
156
  )
102
- else:
103
- val2 += q_part * compute_positive_rep(
104
- q_dict[q_part],
105
- var2_t,
106
- var3_t,
107
- msg,
108
- False,
157
+ except Exception as e:
158
+ if mult:
159
+ print(
160
+ "warning; --display-positive is on but result is not positive",
161
+ file=sys.stderr,
109
162
  )
110
- except Exception as e:
111
- if mult:
112
- print(
113
- "warning; --display-positive is on but result is not positive",
114
- file=sys.stderr,
115
- )
116
- val2 = val
117
- break
118
- else:
163
+ val2 = val
164
+ break
119
165
  print(
120
- f"error; write to schubmult@gmail.com with the case {perms=} {perm=} {val=} {coeff_dict.get(perm,0)=}"
166
+ f"error; write to schubmult@gmail.com with the case {perms=} {perm=} {val=} {coeff_dict.get(perm,0)=}",
121
167
  )
122
168
  print(f"Exception: {e}")
123
169
  import traceback
124
170
 
125
171
  traceback.print_exc()
126
172
  exit(1)
127
- if check and expand(val - val2) != 0:
173
+ if not same and check and expand(val - val2) != 0:
128
174
  if mult:
129
175
  val2 = val
130
176
  else:
131
177
  print(
132
- f"error: value not equal; write to schubmult@gmail.com with the case {perms=} {perm=} {val2=} {coeff_dict.get(perm,0)=}"
178
+ f"error: value not equal; write to schubmult@gmail.com with the case {perms=} {perm=} {val2=} {coeff_dict.get(perm,0)=}",
133
179
  )
134
180
  exit(1)
135
181
  val = val2
136
- if same and display_positive:
137
- if same:
138
- subs_dict = {}
139
- for i in range(1, 100):
140
- sm = var2[1]
141
- for j in range(1, i):
142
- sm += var_r[j]
143
- subs_dict[var2[i]] = sm
144
- val = sympify(coeff_dict[perm]).subs(subs_dict)
145
- elif expa:
182
+ if expa:
146
183
  val = expand(val)
147
184
  if val != 0:
148
185
  if ascode:
149
- print(f"{str(trimcode(perm))} {formatter(val)}")
186
+ raw_result_dict[tuple(trimcode(perm))] = val
187
+ if formatter:
188
+ print(f"{trimcode(perm)!s} {formatter(val)}")
150
189
  else:
151
- print(f"{str(perm)} {formatter(val)}")
190
+ raw_result_dict[tuple(perm)] = val
191
+ if formatter:
192
+ print(f"{perm!s} {formatter(val)}")
193
+ return raw_result_dict
152
194
 
153
195
 
154
- def main():
155
- global var2, var3
196
+ def main(argv=None):
197
+ if argv is None:
198
+ argv = sys.argv
199
+ var2 = tuple(symarray("y", 100))
200
+ var3 = tuple(symarray("z", 100))
156
201
  try:
157
202
  sys.setrecursionlimit(1000000)
158
203
 
@@ -161,17 +206,24 @@ def main():
161
206
  "Compute coefficients of products of quantum double Schubert polynomials in the same or different sets of coefficient variables",
162
207
  yz=True,
163
208
  quantum=True,
209
+ argv=argv[1:],
164
210
  )
211
+ subs_dict2 = {}
212
+ for i in range(1, 100):
213
+ sm = var2[1]
214
+ for j in range(1, i):
215
+ sm += _vars.var_r[j]
216
+ subs_dict2[var2[i]] = sm
165
217
 
166
- mult = args.mult
167
- mulstring = args.mulstring
218
+ mult = args.mult # noqa: F841
219
+ mulstring = args.mulstring # noqa: F841
168
220
 
169
221
  perms = args.perms
170
222
 
171
223
  ascode = args.ascode
172
224
  msg = args.msg
173
225
  display_positive = args.display_positive
174
- pr = args.pr
226
+ pr = args.pr
175
227
  parabolic_index = [int(s) for s in args.parabolic]
176
228
  parabolic = len(parabolic_index) != 0
177
229
  slow = args.slow
@@ -212,18 +264,18 @@ def main():
212
264
  coeff_dict = schubmult_db(coeff_dict, perm, var2, var3)
213
265
  else:
214
266
  coeff_dict = schubmult(coeff_dict, perm, var2, var3)
215
- if mult:
216
- for v in var2:
217
- globals()[str(v)] = v
218
- for v in var3:
219
- globals()[str(v)] = v
220
- for v in var_x:
221
- globals()[str(v)] = v
222
- for v in q_var:
223
- globals()[str(v)] = v
224
-
225
- mul_exp = eval(mulstring)
226
- coeff_dict = mult_poly(coeff_dict, mul_exp)
267
+ # if mult:
268
+ # for v in var2:
269
+ # globals()[str(v)] = v
270
+ # for v in var3:
271
+ # globals()[str(v)] = v
272
+ # for v in var_x:
273
+ # globals()[str(v)] = v
274
+ # for v in q_var:
275
+ # globals()[str(v)] = v
276
+
277
+ # mul_exp = eval(mulstring)
278
+ # coeff_dict = mult_poly(coeff_dict, mul_exp)
227
279
 
228
280
  posified = False
229
281
  if parabolic:
@@ -259,11 +311,10 @@ def main():
259
311
 
260
312
  new_q_part = np.prod(
261
313
  [
262
- q_var[index + 1 - count_less_than(parabolic_index, index + 1)]
263
- ** qv[index]
314
+ q_var[index + 1 - count_less_than(parabolic_index, index + 1)] ** qv[index]
264
315
  for index in range(len(qv))
265
316
  if index + 1 not in parabolic_index
266
- ]
317
+ ],
267
318
  )
268
319
 
269
320
  try:
@@ -271,67 +322,75 @@ def main():
271
322
  except Exception:
272
323
  pass
273
324
  q_val_part = q_dict[q_part]
274
- if display_positive and not same:
325
+ if display_positive:
275
326
  try:
276
327
  q_val_part = int(q_val_part)
277
328
  except Exception:
278
- try:
279
- if len(perms) == 2 and q_part == 1:
280
- u = permtrim([*perms[0]])
281
- v = permtrim([*perms[1]])
282
- q_val_part = posify(
283
- q_dict[q_part],
284
- tuple(u),
285
- tuple(v),
286
- w_1,
287
- var2_t,
288
- var3_t,
289
- msg,
290
- False,
291
- )
292
- else:
293
- qv = q_vector(q_part)
294
- u2, v2, w2 = perms[0], perms[1], w_1
295
- u2, v2, w2, qv, did_one = reduce_q_coeff(u2, v2, w2, qv)
296
- while did_one:
297
- u2, v2, w2, qv, did_one = reduce_q_coeff(u2, v2, w2, qv)
298
- q_part2 = np.prod(
299
- [q_var[i + 1] ** qv[i] for i in range(len(qv))]
300
- )
301
- if q_part2 == 1:
329
+ if same:
330
+ q_val_part = expand(sympify(q_val_part).xreplace(subs_dict2))
331
+ else:
332
+ try:
333
+ if len(perms) == 2 and q_part == 1:
334
+ u = permtrim([*perms[0]])
335
+ v = permtrim([*perms[1]])
302
336
  q_val_part = posify(
303
337
  q_dict[q_part],
304
- u2,
305
- v2,
306
- w2,
307
- var2_t,
308
- var3_t,
338
+ tuple(u),
339
+ tuple(v),
340
+ w_1,
341
+ var2,
342
+ var3,
309
343
  msg,
310
344
  False,
311
345
  )
312
346
  else:
313
- q_val_part = compute_positive_rep(
314
- q_dict[q_part],
315
- var2_t,
316
- var3_t,
317
- msg,
318
- False,
347
+ qv = q_vector(q_part)
348
+ u2, v2, w2 = perms[0], perms[1], w_1
349
+ u2, v2, w2, qv, did_one = reduce_q_coeff(u2, v2, w2, qv)
350
+ while did_one:
351
+ u2, v2, w2, qv, did_one = reduce_q_coeff(u2, v2, w2, qv)
352
+ q_part2 = np.prod(
353
+ [q_var[i + 1] ** qv[i] for i in range(len(qv))],
319
354
  )
320
- except Exception as e:
321
- print(
322
- f"error; write to schubmult@gmail.com with the case {perms=} {perm=} {q_part*q_val_part=} {coeff_dict.get(w_1,0)=}"
323
- )
324
- print(f"Exception: {e}")
325
- exit(1)
326
- coeff_dict_update[w] = coeff_dict_update.get(w, 0) + new_q_part * q_val_part
355
+ if q_part2 == 1:
356
+ q_val_part = posify(
357
+ q_dict[q_part],
358
+ u2,
359
+ v2,
360
+ w2,
361
+ var2,
362
+ var3,
363
+ msg,
364
+ False,
365
+ )
366
+ else:
367
+ q_val_part = compute_positive_rep(
368
+ q_dict[q_part],
369
+ var2,
370
+ var3,
371
+ msg,
372
+ False,
373
+ )
374
+ except Exception as e:
375
+ print(
376
+ f"error; write to schubmult@gmail.com with the case {perms=} {perm=} {q_part*q_val_part=} {coeff_dict.get(w_1,0)=}",
377
+ )
378
+ print(f"Exception: {e}")
379
+ exit(1)
380
+ coeff_dict_update[w] = coeff_dict_update.get(w, 0) + new_q_part * q_val_part
327
381
 
328
382
  coeff_dict = coeff_dict_update
329
383
 
330
- if pr:
331
- _display_full(coeff_dict, args, formatter, posified)
384
+ raw_result_dict = {}
385
+ if pr or formatter is None:
386
+ raw_result_dict = _display_full(coeff_dict, args, formatter, posified)
387
+ if formatter is None:
388
+ return raw_result_dict
332
389
  except BrokenPipeError:
333
390
  pass
334
391
 
335
392
 
336
393
  if __name__ == "__main__":
337
- main()
394
+ import sys
395
+
396
+ sys.exit(main(sys.argv))