mathai 0.7.9__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mathai/apart.py +5 -2
- mathai/base.py +1 -1
- mathai/expand.py +2 -20
- mathai/fraction.py +5 -24
- mathai/linear.py +16 -21
- mathai/logic.py +5 -8
- mathai/matrix.py +12 -52
- mathai/parser.py +2 -8
- mathai/simplify.py +26 -175
- mathai/tool.py +19 -28
- mathai/trig.py +44 -33
- mathai/univariate_inequality.py +1 -2
- {mathai-0.7.9.dist-info → mathai-0.8.1.dist-info}/METADATA +1 -1
- mathai-0.8.1.dist-info/RECORD +28 -0
- mathai/console.py +0 -84
- mathai/parsetab.py +0 -61
- mathai-0.7.9.dist-info/RECORD +0 -30
- {mathai-0.7.9.dist-info → mathai-0.8.1.dist-info}/WHEEL +0 -0
- {mathai-0.7.9.dist-info → mathai-0.8.1.dist-info}/top_level.txt +0 -0
mathai/apart.py
CHANGED
|
@@ -107,13 +107,14 @@ def _apart(eq, v=None):
|
|
|
107
107
|
|
|
108
108
|
lst = poly(s.children[0], v)
|
|
109
109
|
|
|
110
|
-
lst = [TreeNode("f_eq", [item, tree_form("d_0")]) for item in lst if "v_" in str_form(item)]
|
|
110
|
+
lst = [simplify(TreeNode("f_eq", [item, tree_form("d_0")])) for item in lst if "v_" in str_form(item)]
|
|
111
111
|
lst2 = []
|
|
112
112
|
for item in lst:
|
|
113
113
|
lst2+=vlist(item)
|
|
114
114
|
origv = list(set(lst2)-set(origv))
|
|
115
115
|
|
|
116
116
|
out = linear_solve(TreeNode("f_and", lst), [tree_form(item) for item in origv])
|
|
117
|
+
|
|
117
118
|
for item in out.children:
|
|
118
119
|
|
|
119
120
|
final3 = replace(final3, tree_form(list(set(vlist(item))&set(origv))[0]), inverse(item.children[0], list(set(vlist(item))&set(origv))[0]))
|
|
@@ -139,4 +140,6 @@ def apart(eq):
|
|
|
139
140
|
return eq2
|
|
140
141
|
|
|
141
142
|
return TreeNode(eq.name, [helper(child) for child in eq.children])
|
|
142
|
-
|
|
143
|
+
eq = helper(eq)
|
|
144
|
+
eq = fx(eq)
|
|
145
|
+
return eq
|
mathai/base.py
CHANGED
|
@@ -334,7 +334,7 @@ def vlist(eq):
|
|
|
334
334
|
out.append(eq.name)
|
|
335
335
|
for child in eq.children:
|
|
336
336
|
out += vlist(child)
|
|
337
|
-
return sorted(list(set(out)), key=lambda x: int(x[2:]))
|
|
337
|
+
return list(sorted(list(set(out)), key=lambda x: int(x[2:])))
|
|
338
338
|
def product(lst):
|
|
339
339
|
if lst == []:
|
|
340
340
|
return tree_form("d_1")
|
mathai/expand.py
CHANGED
|
@@ -3,21 +3,12 @@ from .simplify import simplify
|
|
|
3
3
|
import itertools
|
|
4
4
|
|
|
5
5
|
def expand_nc(expr, label="f_mul"):
|
|
6
|
-
|
|
7
|
-
Expand expression where:
|
|
8
|
-
- f_add is commutative
|
|
9
|
-
- label (@) is NON-commutative
|
|
10
|
-
"""
|
|
11
|
-
# --- base cases ---
|
|
6
|
+
|
|
12
7
|
if expr.name not in {"f_add", label, "f_pow"}:
|
|
13
8
|
return expr
|
|
14
9
|
|
|
15
|
-
# --- expand children first ---
|
|
16
10
|
expr.children = [expand_nc(c, label) for c in expr.children]
|
|
17
11
|
|
|
18
|
-
# ==========================================================
|
|
19
|
-
# POWER: (A + B)^n only if n is positive integer
|
|
20
|
-
# ==========================================================
|
|
21
12
|
if expr.name == "f_pow":
|
|
22
13
|
base, exp = expr.children
|
|
23
14
|
n = frac(exp)
|
|
@@ -26,9 +17,6 @@ def expand_nc(expr, label="f_mul"):
|
|
|
26
17
|
return expand_nc(TreeNode(label, factors), label)
|
|
27
18
|
return expr
|
|
28
19
|
|
|
29
|
-
# ==========================================================
|
|
30
|
-
# ADDITION (commutative)
|
|
31
|
-
# ==========================================================
|
|
32
20
|
if expr.name == "f_add":
|
|
33
21
|
out = []
|
|
34
22
|
for c in expr.children:
|
|
@@ -38,20 +26,15 @@ def expand_nc(expr, label="f_mul"):
|
|
|
38
26
|
out.append(c)
|
|
39
27
|
return TreeNode("f_add", out)
|
|
40
28
|
|
|
41
|
-
# ==========================================================
|
|
42
|
-
# NON-COMMUTATIVE MULTIPLICATION (@)
|
|
43
|
-
# ==========================================================
|
|
44
29
|
if expr.name == label:
|
|
45
30
|
factors = []
|
|
46
31
|
|
|
47
|
-
# flatten only (NO reordering)
|
|
48
32
|
for c in expr.children:
|
|
49
33
|
if c.name == label:
|
|
50
34
|
factors.extend(c.children)
|
|
51
35
|
else:
|
|
52
36
|
factors.append(c)
|
|
53
37
|
|
|
54
|
-
# find first additive factor
|
|
55
38
|
for i, f in enumerate(factors):
|
|
56
39
|
if f.name == "f_add":
|
|
57
40
|
left = factors[:i]
|
|
@@ -66,13 +49,12 @@ def expand_nc(expr, label="f_mul"):
|
|
|
66
49
|
|
|
67
50
|
return TreeNode("f_add", terms)
|
|
68
51
|
|
|
69
|
-
# no addition inside → return as-is
|
|
70
52
|
return TreeNode(label, factors)
|
|
71
53
|
|
|
72
|
-
|
|
73
54
|
def expand2(eq, over="*"):
|
|
74
55
|
over = {"@": "f_wmul", ".":"f_dot", "*":"f_mul"}[over]
|
|
75
56
|
return expand_nc(eq, over)
|
|
76
57
|
def expand(eq, over="*"):
|
|
77
58
|
eq = expand2(eq, over)
|
|
78
59
|
return TreeNode(eq.name, [expand(child, over) for child in eq.children])
|
|
60
|
+
|
mathai/fraction.py
CHANGED
|
@@ -2,30 +2,22 @@ from .base import *
|
|
|
2
2
|
from .simplify import simplify
|
|
3
3
|
from .expand import expand
|
|
4
4
|
|
|
5
|
-
|
|
6
5
|
def fraction(expr):
|
|
7
6
|
if expr is None:
|
|
8
7
|
return None
|
|
9
|
-
|
|
8
|
+
|
|
10
9
|
expr = simplify(expr)
|
|
11
|
-
|
|
12
|
-
# -----------------------------
|
|
13
|
-
# leaf
|
|
14
|
-
# -----------------------------
|
|
10
|
+
|
|
15
11
|
if expr.children == []:
|
|
16
12
|
return expr
|
|
17
13
|
|
|
18
|
-
# recurse first (inner-most first)
|
|
19
14
|
children = [fraction(c) for c in expr.children]
|
|
20
15
|
|
|
21
|
-
# -----------------------------
|
|
22
|
-
# ADDITION: collect denominators
|
|
23
|
-
# -----------------------------
|
|
24
16
|
if expr.name == "f_add":
|
|
25
17
|
terms = []
|
|
26
18
|
|
|
27
19
|
for c in children:
|
|
28
|
-
|
|
20
|
+
|
|
29
21
|
if c.name == "f_mul":
|
|
30
22
|
num = []
|
|
31
23
|
den = []
|
|
@@ -45,7 +37,6 @@ def fraction(expr):
|
|
|
45
37
|
num.append(f)
|
|
46
38
|
terms.append((num, den))
|
|
47
39
|
|
|
48
|
-
# pure reciprocal
|
|
49
40
|
elif (
|
|
50
41
|
c.name == "f_pow"
|
|
51
42
|
and c.children[1].name.startswith("d_")
|
|
@@ -58,17 +49,12 @@ def fraction(expr):
|
|
|
58
49
|
else TreeNode("f_pow", [c.children[0], tree_form(f"d_{-n}")])
|
|
59
50
|
]))
|
|
60
51
|
|
|
61
|
-
# normal term
|
|
62
52
|
else:
|
|
63
53
|
terms.append(([c], []))
|
|
64
54
|
|
|
65
|
-
# if no denominators → rebuild normally
|
|
66
55
|
if not any(den for _, den in terms):
|
|
67
56
|
return TreeNode("f_add", children)
|
|
68
57
|
|
|
69
|
-
# -----------------------------
|
|
70
|
-
# build numerator
|
|
71
|
-
# -----------------------------
|
|
72
58
|
num_terms = []
|
|
73
59
|
for i, (num_i, _) in enumerate(terms):
|
|
74
60
|
acc = list(num_i)
|
|
@@ -83,16 +69,13 @@ def fraction(expr):
|
|
|
83
69
|
|
|
84
70
|
numerator = TreeNode("f_add", num_terms)
|
|
85
71
|
|
|
86
|
-
# -----------------------------
|
|
87
|
-
# build denominator
|
|
88
|
-
# -----------------------------
|
|
89
72
|
den_all = []
|
|
90
73
|
for _, den in terms:
|
|
91
74
|
den_all += den
|
|
92
75
|
|
|
93
76
|
denom = den_all[0] if len(den_all) == 1 else TreeNode("f_mul", den_all)
|
|
94
77
|
denom = TreeNode("f_pow", [denom, tree_form("d_-1")])
|
|
95
|
-
|
|
78
|
+
|
|
96
79
|
return simplify(
|
|
97
80
|
TreeNode(
|
|
98
81
|
"f_mul",
|
|
@@ -100,7 +83,5 @@ def fraction(expr):
|
|
|
100
83
|
)
|
|
101
84
|
)
|
|
102
85
|
|
|
103
|
-
# -----------------------------
|
|
104
|
-
# default reconstruction
|
|
105
|
-
# -----------------------------
|
|
106
86
|
return TreeNode(expr.name, children)
|
|
87
|
+
|
mathai/linear.py
CHANGED
|
@@ -8,7 +8,7 @@ from .base import *
|
|
|
8
8
|
from .factor import factorconst
|
|
9
9
|
from .tool import poly
|
|
10
10
|
def ss(eq):
|
|
11
|
-
return dowhile(eq, lambda x: fraction(
|
|
11
|
+
return dowhile(eq, lambda x: fraction(simplify(x)))
|
|
12
12
|
def rref(matrix):
|
|
13
13
|
rows, cols = len(matrix), len(matrix[0])
|
|
14
14
|
lead = 0
|
|
@@ -34,16 +34,17 @@ def rref(matrix):
|
|
|
34
34
|
return matrix
|
|
35
35
|
def islinear(eq, fxconst):
|
|
36
36
|
eq =simplify(eq)
|
|
37
|
-
if all(fxconst(tree_form(item)) and poly(eq, item) is not None and len(poly(eq, item)) <= 2
|
|
37
|
+
if all(not fxconst(tree_form(item)) or (fxconst(tree_form(item)) and poly(eq, item) is not None and len(poly(eq, item)) <= 2)for item in vlist(eq)):
|
|
38
38
|
return True
|
|
39
|
+
else:
|
|
40
|
+
pass
|
|
39
41
|
return False
|
|
40
42
|
def linear(eqlist, fxconst):
|
|
41
43
|
orig = [item.copy_tree() for item in eqlist]
|
|
42
|
-
#eqlist = [eq for eq in eqlist if fxconst(eq)]
|
|
43
44
|
|
|
44
45
|
if eqlist == [] or not all(islinear(eq, fxconst) for eq in eqlist):
|
|
45
46
|
return None
|
|
46
|
-
|
|
47
|
+
|
|
47
48
|
vl = []
|
|
48
49
|
def varlist(eq, fxconst):
|
|
49
50
|
nonlocal vl
|
|
@@ -58,6 +59,7 @@ def linear(eqlist, fxconst):
|
|
|
58
59
|
if len(vl) > len(eqlist):
|
|
59
60
|
return TreeNode("f_and", [TreeNode("f_eq", [x, tree_form("d_0")]) for x in eqlist])
|
|
60
61
|
m = []
|
|
62
|
+
|
|
61
63
|
for eq in eqlist:
|
|
62
64
|
s = copy.deepcopy(eq)
|
|
63
65
|
row = []
|
|
@@ -71,11 +73,11 @@ def linear(eqlist, fxconst):
|
|
|
71
73
|
m[i][j] = simplify(expand(m[i][j]))
|
|
72
74
|
|
|
73
75
|
m = rref(m)
|
|
74
|
-
|
|
76
|
+
|
|
75
77
|
for i in range(len(m)):
|
|
76
78
|
for j in range(len(m[i])):
|
|
77
79
|
m[i][j] = fraction(m[i][j])
|
|
78
|
-
|
|
80
|
+
|
|
79
81
|
output = []
|
|
80
82
|
for index, row in enumerate(m):
|
|
81
83
|
if not all(item == 0 for item in row[:-1]):
|
|
@@ -88,26 +90,17 @@ def linear(eqlist, fxconst):
|
|
|
88
90
|
return tree_form("s_false")
|
|
89
91
|
return TreeNode("f_and", [TreeNode("f_eq", [x, tree_form("d_0")]) for x in output])
|
|
90
92
|
def order_collinear_indices(points, idx):
|
|
91
|
-
"""
|
|
92
|
-
Arrange a subset of collinear points (given by indices) along their line.
|
|
93
93
|
|
|
94
|
-
points: list of (x, y) tuples
|
|
95
|
-
idx: list of indices referring to points
|
|
96
|
-
Returns: list of indices sorted along the line
|
|
97
|
-
"""
|
|
98
94
|
if len(idx) <= 1:
|
|
99
95
|
return idx[:]
|
|
100
|
-
|
|
101
|
-
# Take first two points from the subset to define the line
|
|
96
|
+
|
|
102
97
|
p0, p1 = points[idx[0]], points[idx[1]]
|
|
103
98
|
dx, dy = p1[0] - p0[0], p1[1] - p0[1]
|
|
104
|
-
|
|
105
|
-
# Projection factor for sorting
|
|
99
|
+
|
|
106
100
|
def projection_factor(i):
|
|
107
101
|
vx, vy = points[i][0] - p0[0], points[i][1] - p0[1]
|
|
108
102
|
return compute((vx * dx + vy * dy) / (dx**2 + dy**2))
|
|
109
|
-
|
|
110
|
-
# Sort indices by projection
|
|
103
|
+
|
|
111
104
|
sorted_idx = sorted(idx, key=projection_factor)
|
|
112
105
|
return list(sorted_idx)
|
|
113
106
|
def linear_or(eq):
|
|
@@ -124,12 +117,12 @@ def linear_or(eq):
|
|
|
124
117
|
for item in itertools.combinations(enumerate(eqlst), 2):
|
|
125
118
|
x, y = item[0][0], item[1][0]
|
|
126
119
|
item = [item[0][1], item[1][1]]
|
|
127
|
-
|
|
120
|
+
|
|
128
121
|
out = linear_solve(TreeNode("f_and", list(item)))
|
|
129
122
|
|
|
130
123
|
if out is None:
|
|
131
124
|
return None
|
|
132
|
-
|
|
125
|
+
|
|
133
126
|
if out.name == "f_and" and all(len(vlist(child)) == 1 for child in out.children) and set(vlist(out)) == set(v) and all(len(vlist(simplify(child))) >0 for child in out.children):
|
|
134
127
|
t = {}
|
|
135
128
|
for child in out.children:
|
|
@@ -151,7 +144,7 @@ def linear_solve(eq, lst=None):
|
|
|
151
144
|
eq = simplify(eq)
|
|
152
145
|
eqlist = []
|
|
153
146
|
if eq.name =="f_and" and all(child.name == "f_eq" and child.children[1] == 0 for child in eq.children):
|
|
154
|
-
|
|
147
|
+
|
|
155
148
|
eqlist = [child.children[0] for child in eq.children]
|
|
156
149
|
else:
|
|
157
150
|
return eq
|
|
@@ -159,7 +152,9 @@ def linear_solve(eq, lst=None):
|
|
|
159
152
|
if lst is None:
|
|
160
153
|
out = linear(copy.deepcopy(eqlist), lambda x: "v_" in str_form(x))
|
|
161
154
|
else:
|
|
155
|
+
|
|
162
156
|
out = linear(copy.deepcopy(eqlist), lambda x: any(contain(x, item) for item in lst))
|
|
163
157
|
if out is None:
|
|
164
158
|
return None
|
|
165
159
|
return simplify(out)
|
|
160
|
+
|
mathai/logic.py
CHANGED
|
@@ -108,7 +108,7 @@ def logic2(eq):
|
|
|
108
108
|
if len(lst) == 1:
|
|
109
109
|
return lst[0]
|
|
110
110
|
return TreeNode(eq.name, lst)
|
|
111
|
-
|
|
111
|
+
|
|
112
112
|
if eq.name in ["f_and", "f_or"] and any(child.children is not None and len(child.children)!=0 for child in eq.children):
|
|
113
113
|
for i in range(len(eq.children),1,-1):
|
|
114
114
|
for item in itertools.combinations(enumerate(eq.children), i):
|
|
@@ -159,7 +159,7 @@ def logic1(eq):
|
|
|
159
159
|
A, B = dowhile(A, logic2), dowhile(B, logic2)
|
|
160
160
|
return flatten_tree((A & B) | (A.fx("not") & B.fx("not")))
|
|
161
161
|
if eq.name == "f_imply":
|
|
162
|
-
|
|
162
|
+
|
|
163
163
|
A, B = eq.children
|
|
164
164
|
A, B = logic1(A), logic1(B)
|
|
165
165
|
A, B = dowhile(A, logic2), dowhile(B, logic2)
|
|
@@ -171,32 +171,28 @@ def logic1(eq):
|
|
|
171
171
|
return eq
|
|
172
172
|
eq = helper(eq)
|
|
173
173
|
eq = flatten_tree(eq)
|
|
174
|
-
|
|
174
|
+
|
|
175
175
|
if len(eq.children) > 2:
|
|
176
176
|
lst = []
|
|
177
177
|
l = len(eq.children)
|
|
178
178
|
|
|
179
|
-
# Handle last odd child directly
|
|
180
179
|
if l % 2 == 1:
|
|
181
180
|
last_child = eq.children[-1]
|
|
182
|
-
|
|
181
|
+
|
|
183
182
|
if isinstance(last_child, TreeNode):
|
|
184
183
|
last_child = dowhile(last_child, logic2)
|
|
185
184
|
lst.append(last_child)
|
|
186
185
|
l -= 1
|
|
187
186
|
|
|
188
|
-
# Pairwise combine children
|
|
189
187
|
for i in range(0, l, 2):
|
|
190
188
|
left, right = eq.children[i], eq.children[i+1]
|
|
191
189
|
pair = TreeNode(eq.name, [left, right])
|
|
192
190
|
simplified = dowhile(logic1(pair), logic2)
|
|
193
191
|
lst.append(simplified)
|
|
194
192
|
|
|
195
|
-
# If only one element left, just return it instead of nesting
|
|
196
193
|
if len(lst) == 1:
|
|
197
194
|
return flatten_tree(lst[0])
|
|
198
195
|
|
|
199
|
-
# Otherwise rewrap
|
|
200
196
|
return flatten_tree(TreeNode(eq.name, lst))
|
|
201
197
|
|
|
202
198
|
if eq.name == "f_and":
|
|
@@ -228,3 +224,4 @@ def logic1(eq):
|
|
|
228
224
|
out = out.children[0]
|
|
229
225
|
return flatten_tree(out)
|
|
230
226
|
return TreeNode(eq.name, [logic1(child) for child in eq.children])
|
|
227
|
+
|
mathai/matrix.py
CHANGED
|
@@ -3,7 +3,6 @@ import copy
|
|
|
3
3
|
from .simplify import simplify
|
|
4
4
|
import itertools
|
|
5
5
|
|
|
6
|
-
# ---------- tree <-> python list ----------
|
|
7
6
|
def tree_to_py(node):
|
|
8
7
|
if node.name=="f_list":
|
|
9
8
|
return [tree_to_py(c) for c in node.children]
|
|
@@ -14,16 +13,13 @@ def py_to_tree(obj):
|
|
|
14
13
|
return TreeNode("f_list",[py_to_tree(x) for x in obj])
|
|
15
14
|
return obj
|
|
16
15
|
|
|
17
|
-
# ---------- shape detection ----------
|
|
18
16
|
def is_vector(x):
|
|
19
17
|
return isinstance(x,list) and all(isinstance(item,TreeNode) for item in x)
|
|
20
18
|
def is_mat(x):
|
|
21
19
|
return isinstance(x,list) and all(isinstance(item,list) for item in x)
|
|
22
20
|
def is_matrix(x):
|
|
23
21
|
return isinstance(x, list) and all(isinstance(item, list) and (is_mat(item) or is_vector(item)) for item in x)
|
|
24
|
-
|
|
25
22
|
|
|
26
|
-
# ---------- algebra primitives ----------
|
|
27
23
|
def dot(u,v):
|
|
28
24
|
if len(u)!=len(v):
|
|
29
25
|
raise ValueError("Vector size mismatch")
|
|
@@ -33,9 +29,7 @@ def dot(u,v):
|
|
|
33
29
|
return s
|
|
34
30
|
|
|
35
31
|
def matmul(A, B):
|
|
36
|
-
|
|
37
|
-
# B: m × p
|
|
38
|
-
|
|
32
|
+
|
|
39
33
|
n = len(A)
|
|
40
34
|
m = len(A[0])
|
|
41
35
|
p = len(B[0])
|
|
@@ -54,7 +48,6 @@ def matmul(A, B):
|
|
|
54
48
|
)
|
|
55
49
|
return C
|
|
56
50
|
|
|
57
|
-
# ---------- promotion ----------
|
|
58
51
|
def promote(node):
|
|
59
52
|
if node.name=="f_list":
|
|
60
53
|
return tree_to_py(node)
|
|
@@ -68,7 +61,7 @@ def contains_neg(node):
|
|
|
68
61
|
if not contains_neg(child):
|
|
69
62
|
return False
|
|
70
63
|
return True
|
|
71
|
-
|
|
64
|
+
|
|
72
65
|
def multiply(left,right):
|
|
73
66
|
if left == tree_form("d_1"):
|
|
74
67
|
return right
|
|
@@ -83,17 +76,16 @@ def multiply(left,right):
|
|
|
83
76
|
return simplify(left2.children[0]**(left2.children[1]+right2.children[1]))
|
|
84
77
|
A,B = promote(left), promote(right)
|
|
85
78
|
|
|
86
|
-
# vector · vector
|
|
87
79
|
if is_vector(A) and is_vector(B):
|
|
88
80
|
return dot(A,B)
|
|
89
|
-
|
|
81
|
+
|
|
90
82
|
if is_matrix(A) and is_matrix(B):
|
|
91
83
|
return py_to_tree(matmul(A,B))
|
|
92
|
-
|
|
84
|
+
|
|
93
85
|
for _ in range(2):
|
|
94
86
|
if contains_neg(A) and is_vector(B):
|
|
95
87
|
return py_to_tree([TreeNode("f_mul",[A,x]) for x in B])
|
|
96
|
-
|
|
88
|
+
|
|
97
89
|
if contains_neg(A) and is_matrix(B):
|
|
98
90
|
return py_to_tree([[TreeNode("f_mul",[A,x]) for x in row] for row in B])
|
|
99
91
|
A, B = B, A
|
|
@@ -122,45 +114,16 @@ def matadd(A, B):
|
|
|
122
114
|
]
|
|
123
115
|
def addition(left,right):
|
|
124
116
|
A,B = promote(left), promote(right)
|
|
125
|
-
|
|
117
|
+
|
|
126
118
|
if is_vector(A) and is_vector(B):
|
|
127
119
|
return add_vec(A,B)
|
|
128
|
-
|
|
120
|
+
|
|
129
121
|
if is_matrix(A) and is_matrix(B):
|
|
130
122
|
return py_to_tree(matadd(A,B))
|
|
131
123
|
return None
|
|
132
|
-
|
|
133
|
-
def fold_wmul(eq):
|
|
134
|
-
if eq.name == "f_pow" and eq.children[1].name.startswith("d_"):
|
|
135
|
-
n = int(eq.children[1].name[2:])
|
|
136
|
-
if n == 1:
|
|
137
|
-
eq = eq.children[0]
|
|
138
|
-
elif n > 1:
|
|
139
|
-
tmp = promote(eq.children[0])
|
|
140
|
-
if is_matrix(tmp):
|
|
141
|
-
orig =tmp
|
|
142
|
-
for i in range(n-1):
|
|
143
|
-
tmp = matmul(orig, tmp)
|
|
144
|
-
eq = py_to_tree(tmp)
|
|
145
|
-
elif eq.name in ["f_wmul", "f_add"]:
|
|
146
|
-
if len(eq.children) == 1:
|
|
147
|
-
eq = eq.children[0]
|
|
148
|
-
else:
|
|
149
|
-
i = len(eq.children)-1
|
|
150
|
-
while i>0:
|
|
151
|
-
if eq.name == "f_wmul":
|
|
152
|
-
out = multiply(eq.children[i-1], eq.children[i])
|
|
153
|
-
else:
|
|
154
|
-
out = addition(eq.children[i-1], eq.children[i])
|
|
155
|
-
if out is not None:
|
|
156
|
-
eq.children.pop(i)
|
|
157
|
-
eq.children.pop(i-1)
|
|
158
|
-
eq.children.insert(i-1,out)
|
|
159
|
-
i = i-1
|
|
160
|
-
return TreeNode(eq.name, [fold_wmul(child) for child in eq.children])
|
|
161
|
-
'''
|
|
124
|
+
|
|
162
125
|
def fold_wmul(root):
|
|
163
|
-
|
|
126
|
+
|
|
164
127
|
stack = [(root, False)]
|
|
165
128
|
newnode = {}
|
|
166
129
|
|
|
@@ -168,17 +131,15 @@ def fold_wmul(root):
|
|
|
168
131
|
node, visited = stack.pop()
|
|
169
132
|
|
|
170
133
|
if not visited:
|
|
171
|
-
|
|
134
|
+
|
|
172
135
|
stack.append((node, True))
|
|
173
136
|
for child in node.children:
|
|
174
137
|
stack.append((child, False))
|
|
175
138
|
else:
|
|
176
|
-
|
|
139
|
+
|
|
177
140
|
children = [newnode[c] for c in node.children]
|
|
178
141
|
eq = TreeNode(node.name, children)
|
|
179
142
|
|
|
180
|
-
# ---- original rewrite logic ----
|
|
181
|
-
|
|
182
143
|
if eq.name == "f_pow" and eq.children[1].name.startswith("d_"):
|
|
183
144
|
n = int(eq.children[1].name[2:])
|
|
184
145
|
if n == 1:
|
|
@@ -208,8 +169,6 @@ def fold_wmul(root):
|
|
|
208
169
|
eq.children.insert(i - 1, out)
|
|
209
170
|
i -= 1
|
|
210
171
|
|
|
211
|
-
# --------------------------------
|
|
212
|
-
|
|
213
172
|
newnode[node] = eq
|
|
214
173
|
|
|
215
174
|
return newnode[root]
|
|
@@ -223,3 +182,4 @@ def _matrix_solve(eq):
|
|
|
223
182
|
return eq
|
|
224
183
|
def matrix_solve(eq):
|
|
225
184
|
return _matrix_solve(eq)
|
|
185
|
+
|
mathai/parser.py
CHANGED
|
@@ -87,7 +87,6 @@ def parse(equation, funclist=None):
|
|
|
87
87
|
parser_main = Lark(grammar2, start='start', parser='lalr')
|
|
88
88
|
parse_tree = parser_main.parse(equation)
|
|
89
89
|
|
|
90
|
-
# Convert Lark tree to TreeNode
|
|
91
90
|
def convert_to_treenode(parse_tree):
|
|
92
91
|
if isinstance(parse_tree, Tree):
|
|
93
92
|
node = TreeNode(parse_tree.data)
|
|
@@ -96,7 +95,6 @@ def parse(equation, funclist=None):
|
|
|
96
95
|
else:
|
|
97
96
|
return TreeNode(str(parse_tree))
|
|
98
97
|
|
|
99
|
-
# Flatten unnecessary nodes like pass_through
|
|
100
98
|
def remove_past(equation):
|
|
101
99
|
if equation.name in {"number", "paren", "func", "variable", "pass_through", "cnumber", "string", "matrix"}:
|
|
102
100
|
if len(equation.children) == 1:
|
|
@@ -107,7 +105,6 @@ def parse(equation, funclist=None):
|
|
|
107
105
|
equation.children = [remove_past(child) for child in equation.children]
|
|
108
106
|
return equation
|
|
109
107
|
|
|
110
|
-
# Handle indices if any
|
|
111
108
|
def prefixindex(equation):
|
|
112
109
|
if equation.name == "base" and len(equation.children) > 1:
|
|
113
110
|
return TreeNode("index", [equation.children[0]] + equation.children[1].children)
|
|
@@ -117,16 +114,15 @@ def parse(equation, funclist=None):
|
|
|
117
114
|
tree_node = remove_past(tree_node)
|
|
118
115
|
tree_node = prefixindex(tree_node)
|
|
119
116
|
|
|
120
|
-
# Convert function names and constants
|
|
121
117
|
def fxchange(tree_node):
|
|
122
118
|
tmp3 = funclist if funclist is not None else []
|
|
123
119
|
if tree_node.name == "neg":
|
|
124
120
|
child = fxchange(tree_node.children[0])
|
|
125
|
-
|
|
121
|
+
|
|
126
122
|
if child.name.startswith("d_") and re.match(r"d_\d+(\.\d+)?$", child.name):
|
|
127
123
|
return TreeNode("d_" + str(-int(child.name[2:])))
|
|
128
124
|
else:
|
|
129
|
-
|
|
125
|
+
|
|
130
126
|
return TreeNode("f_sub", [tree_form("d_0"), child])
|
|
131
127
|
if tree_node.name == "pass_through":
|
|
132
128
|
return fxchange(tree_node.children[0])
|
|
@@ -137,11 +133,9 @@ def parse(equation, funclist=None):
|
|
|
137
133
|
|
|
138
134
|
tree_node = fxchange(tree_node)
|
|
139
135
|
|
|
140
|
-
# Replace common constants
|
|
141
136
|
for const in ["e","pi","kc","em","ec","anot","hbar","false","true","i","nabla"]:
|
|
142
137
|
tree_node = replace(tree_node, tree_form("d_"+const), tree_form("s_"+const))
|
|
143
138
|
|
|
144
|
-
# Map letters to variables
|
|
145
139
|
for i, c in enumerate(["x","y","z"] + [chr(x+ord("a")) for x in range(0,23)]):
|
|
146
140
|
tree_node = replace(tree_node, tree_form("d_"+c), tree_form("v_"+str(i)))
|
|
147
141
|
for i, c in enumerate([chr(x+ord("A")) for x in range(0,26)]):
|