mathai 0.4.8__py3-none-any.whl → 0.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mathai/__init__.py +9 -8
- mathai/base.py +132 -34
- mathai/bivariate_inequality.py +317 -0
- mathai/diff.py +3 -3
- mathai/expand.py +159 -80
- mathai/factor.py +45 -21
- mathai/fraction.py +2 -2
- mathai/integrate.py +38 -19
- mathai/inverse.py +4 -4
- mathai/limit.py +94 -70
- mathai/linear.py +90 -81
- mathai/logic.py +7 -1
- mathai/matrix.py +228 -0
- mathai/parser.py +13 -7
- mathai/parsetab.py +61 -0
- mathai/printeq.py +12 -9
- mathai/simplify.py +511 -369
- mathai/structure.py +2 -2
- mathai/tool.py +2 -2
- mathai/trig.py +42 -25
- mathai/univariate_inequality.py +78 -30
- mathai-0.7.2.dist-info/METADATA +293 -0
- mathai-0.7.2.dist-info/RECORD +28 -0
- {mathai-0.4.8.dist-info → mathai-0.7.2.dist-info}/WHEEL +1 -1
- mathai-0.4.8.dist-info/METADATA +0 -234
- mathai-0.4.8.dist-info/RECORD +0 -25
- {mathai-0.4.8.dist-info → mathai-0.7.2.dist-info}/top_level.txt +0 -0
mathai/limit.py
CHANGED
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
from .structure import structure
|
|
2
2
|
from .base import *
|
|
3
3
|
from .parser import parse
|
|
4
|
-
from .simplify import simplify
|
|
4
|
+
from .simplify import simplify
|
|
5
5
|
from .expand import expand
|
|
6
6
|
from .diff import diff
|
|
7
7
|
from .trig import trig0
|
|
8
8
|
from .fraction import fraction
|
|
9
|
-
from .printeq import
|
|
9
|
+
from .printeq import printeq
|
|
10
10
|
tab=0
|
|
11
11
|
def substitute_val(eq, val, var="v_0"):
|
|
12
12
|
eq = replace(eq, tree_form(var), tree_form("d_"+str(val)))
|
|
@@ -33,100 +33,124 @@ def check(num, den, var):
|
|
|
33
33
|
return simplify(n/d)
|
|
34
34
|
return False
|
|
35
35
|
def lhospital(num, den, steps,var):
|
|
36
|
-
logs = []
|
|
37
36
|
|
|
38
37
|
out = check(num, den, var)
|
|
39
38
|
|
|
40
39
|
if isinstance(out, TreeNode):
|
|
41
|
-
return out
|
|
40
|
+
return out
|
|
42
41
|
for _ in range(steps):
|
|
43
42
|
num2, den2 = map(lambda e: simplify(diff(e, var)), (num, den))
|
|
44
43
|
out = check(num2, den2, var)
|
|
45
44
|
if out is True:
|
|
46
45
|
num, den = num2, den2
|
|
47
|
-
logs += [(0,"lim x->0 "+printeq_str(simplify(num/den)))]
|
|
48
46
|
continue
|
|
49
47
|
if out is False:
|
|
50
48
|
eq2 = simplify(fraction(simplify(num/den)))
|
|
51
|
-
return eq2
|
|
52
|
-
return out
|
|
49
|
+
return eq2
|
|
50
|
+
return out
|
|
53
51
|
def lhospital2(eq, var):
|
|
54
52
|
eq= simplify(eq)
|
|
55
53
|
if eq is None:
|
|
56
54
|
return None
|
|
57
55
|
if not contain(eq, tree_form(var)):
|
|
58
|
-
return eq
|
|
56
|
+
return eq
|
|
59
57
|
num, dem = [simplify(item) for item in num_dem(eq)]
|
|
60
58
|
if num is None or dem is None:
|
|
61
|
-
return eq
|
|
59
|
+
return eq
|
|
62
60
|
|
|
63
61
|
return lhospital(num, dem, 10,var)
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
equation = equation/item
|
|
83
|
-
equation = equation*(tree_form("s_pi") - tmp["v_-1"])
|
|
84
|
-
break
|
|
85
|
-
tmp = structure(item, ls[1])
|
|
86
|
-
if tmp is not None and contain(tmp["v_-1"], var) and not contain(tmp["v_-2"], var):
|
|
87
|
-
item2 = substitute_val(tmp["v_-1"], 0, var.name)
|
|
88
|
-
item2 = expand(solve(item2))
|
|
89
|
-
if tree_form("d_0") == item2:
|
|
90
|
-
equation = equation/item
|
|
91
|
-
equation = solve(equation*tmp["v_-1"]*tmp["v_-2"].fx("log"))
|
|
92
|
-
break
|
|
93
|
-
tmp = structure(item, ls[2])
|
|
94
|
-
if tmp is not None and contain(tmp["v_-1"], var):
|
|
95
|
-
|
|
96
|
-
item2 = substitute_val(tmp["v_-1"], 0, var.name)
|
|
97
|
-
item2 = expand(solve(item2))
|
|
98
|
-
if tree_form("d_0") == item2:
|
|
99
|
-
equation = equation/item
|
|
100
|
-
equation = solve(equation*tmp["v_-1"])
|
|
101
|
-
break
|
|
102
|
-
tmp = structure(item, ls[3])
|
|
103
|
-
if tmp is not None and contain(tmp["v_-1"], var):
|
|
104
|
-
item2 = substitute_val(item, 0, var.name)
|
|
105
|
-
|
|
106
|
-
if tree_form("d_0") == expand(solve(item2)):
|
|
107
|
-
|
|
108
|
-
equation = equation/item
|
|
109
|
-
equation = equation*(tree_form("d_1") - tmp["v_-1"]**tree_form("d_2"))
|
|
110
|
-
break
|
|
62
|
+
def limit0(equation):
|
|
63
|
+
if equation.name == "f_ref":
|
|
64
|
+
return equation
|
|
65
|
+
eq2 = equation
|
|
66
|
+
g = ["f_limit", "f_limitpinf", "f_limitninf"]
|
|
67
|
+
if eq2.name in g and contain(eq2.children[0], eq2.children[1]):
|
|
68
|
+
equation = eq2.children[0]
|
|
69
|
+
wrt = eq2.children[1]
|
|
70
|
+
lst = factor_generation(equation)
|
|
71
|
+
|
|
72
|
+
lst_const = [item for item in lst if not contain(item, wrt)]
|
|
73
|
+
if lst_const != []:
|
|
74
|
+
|
|
75
|
+
equation = product([item for item in lst if contain(item, wrt)]).copy_tree()
|
|
76
|
+
const = product(lst_const)
|
|
77
|
+
const = simplify(const)
|
|
78
|
+
|
|
79
|
+
if not contain(const, tree_form("s_i")):
|
|
111
80
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
return
|
|
115
|
-
def
|
|
116
|
-
|
|
81
|
+
return limit0(TreeNode(equation.name,[equation, wrt])) *const
|
|
82
|
+
equation = eq2
|
|
83
|
+
return TreeNode(equation.name, [limit0(child) for child in equation.children])
|
|
84
|
+
def limit2(eq):
|
|
85
|
+
g = ["f_limit", "f_limitpinf", "f_limitninf"]
|
|
86
|
+
if eq.name in g and eq.children[0].name == "f_add":
|
|
87
|
+
eq = summation([TreeNode(eq.name, [child, eq.children[1]]) for child in eq.children[0].children])
|
|
88
|
+
return TreeNode(eq.name, [limit2(child) for child in eq.children])
|
|
89
|
+
def limit1(eq):
|
|
90
|
+
if eq.name == "f_limit":
|
|
91
|
+
a, b = limit(eq.children[0], eq.children[1].name)
|
|
92
|
+
if b:
|
|
93
|
+
return a
|
|
94
|
+
else:
|
|
95
|
+
return TreeNode(eq.name, [a, eq.children[1]])
|
|
96
|
+
return TreeNode(eq.name, [limit1(child) for child in eq.children])
|
|
97
|
+
def fxinf(eq):
|
|
98
|
+
if eq is None:
|
|
99
|
+
return None
|
|
100
|
+
if eq.name == "f_add":
|
|
101
|
+
if tree_form("s_inf") in eq.children and -tree_form("s_inf") in eq.children:
|
|
102
|
+
return None
|
|
103
|
+
if tree_form("s_inf") in eq.children:
|
|
104
|
+
return tree_form("s_inf")
|
|
105
|
+
if -tree_form("s_inf") in eq.children:
|
|
106
|
+
return -tree_form("s_inf")
|
|
107
|
+
if eq.name == "f_mul":
|
|
108
|
+
lst = factor_generation(eq)
|
|
109
|
+
if tree_form("s_inf") in lst:
|
|
110
|
+
eq = TreeNode(eq.name, [dowhile(child, fxinf) for child in eq.children])
|
|
111
|
+
if None in eq.children:
|
|
112
|
+
return None
|
|
113
|
+
lst = factor_generation(eq)
|
|
114
|
+
if tree_form("d_0") in lst:
|
|
115
|
+
return tree_form("d_0")
|
|
116
|
+
lst2 = [item for item in lst if "v_" in str_form(item)]
|
|
117
|
+
sign = True
|
|
118
|
+
if len([item for item in lst if "v_" not in str_form(item) and not contain(item, tree_form("s_inf")) and compute(item)<0]) % 2==1:
|
|
119
|
+
sign = False
|
|
120
|
+
if lst2 == []:
|
|
121
|
+
if sign:
|
|
122
|
+
return tree_form("s_inf")
|
|
123
|
+
else:
|
|
124
|
+
return -tree_form("s_inf")
|
|
125
|
+
if eq.name == "f_pow":
|
|
126
|
+
if "v_" not in str_form(eq.children[0]) and not contain(eq.children[0], tree_form("s_inf")) and compute(eq.children[0])>0:
|
|
127
|
+
if eq.children[1] == -tree_form("s_inf"):
|
|
128
|
+
return tree_form("d_0")
|
|
129
|
+
|
|
130
|
+
eq = TreeNode(eq.name, [fxinf(child) for child in eq.children])
|
|
131
|
+
if None in eq.children:
|
|
132
|
+
return None
|
|
133
|
+
return eq
|
|
134
|
+
def limit3(eq):
|
|
135
|
+
|
|
136
|
+
if eq.name == "f_limitpinf":
|
|
137
|
+
if not contain(eq, eq.children[1]):
|
|
138
|
+
return eq.children[0]
|
|
139
|
+
eq2 = replace(eq.children[0], eq.children[1], tree_form("s_inf"))
|
|
140
|
+
eq2 = dowhile(eq2, fxinf)
|
|
141
|
+
if not contain(eq2, tree_form("s_inf")) and not contain(eq2, eq.children[1]):
|
|
142
|
+
return simplify(eq2)
|
|
143
|
+
return TreeNode(eq.name, [limit3(child) for child in eq.children])
|
|
117
144
|
|
|
118
145
|
def limit(equation, var="v_0"):
|
|
119
|
-
|
|
146
|
+
|
|
120
147
|
eq2 = dowhile(replace(equation, tree_form(var), tree_form("d_0")), lambda x: trig0(simplify(x)))
|
|
121
148
|
if eq2 is not None and not contain(equation, tree_form(var)):
|
|
122
|
-
return eq2,
|
|
149
|
+
return eq2, True
|
|
123
150
|
|
|
124
|
-
equation
|
|
151
|
+
equation = lhospital2(equation, var)
|
|
125
152
|
equation = simplify(expand(simplify(equation)))
|
|
126
153
|
if not contain(equation, tree_form(var)):
|
|
127
|
-
return equation,
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
return simplify(summation([limit(child, var) for child in equation.children]))
|
|
131
|
-
'''
|
|
132
|
-
return equation,logs+tmp
|
|
154
|
+
return equation, True
|
|
155
|
+
|
|
156
|
+
return equation, False
|
mathai/linear.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
|
1
|
+
from .inverse import inverse
|
|
2
|
+
import itertools
|
|
1
3
|
from .diff import diff
|
|
2
|
-
from .simplify import simplify
|
|
4
|
+
from .simplify import simplify
|
|
3
5
|
from .fraction import fraction
|
|
4
6
|
from .expand import expand
|
|
5
7
|
from .base import *
|
|
6
8
|
from .factor import factorconst
|
|
9
|
+
from .tool import poly
|
|
7
10
|
def ss(eq):
|
|
8
11
|
return dowhile(eq, lambda x: fraction(expand(simplify(x))))
|
|
9
12
|
def rref(matrix):
|
|
@@ -31,38 +34,16 @@ def rref(matrix):
|
|
|
31
34
|
return matrix
|
|
32
35
|
def islinear(eq, fxconst):
|
|
33
36
|
eq =simplify(eq)
|
|
34
|
-
if eq
|
|
35
|
-
return
|
|
36
|
-
|
|
37
|
-
out = islinear(child, fxconst)
|
|
38
|
-
if not out:
|
|
39
|
-
return out
|
|
40
|
-
return True
|
|
37
|
+
if all(fxconst(tree_form(item)) and poly(eq, item) is not None and len(poly(eq, item)) <= 2 for item in vlist(eq)):
|
|
38
|
+
return True
|
|
39
|
+
return False
|
|
41
40
|
def linear(eqlist, fxconst):
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
for i in range(len(eqlist)-1,-1,-1):
|
|
45
|
-
if eqlist[i].name == "f_mul" and not islinear(expand2(eqlist[i]), fxconst):
|
|
46
|
-
if "v_" in str_form(eqlist[i]):
|
|
47
|
-
eqlist[i] = TreeNode("f_mul", [child for child in eqlist[i].children if fxconst(child)])
|
|
48
|
-
if all(islinear(child, fxconst) for child in eqlist[i].children):
|
|
49
|
-
for child in eqlist[i].children:
|
|
50
|
-
extra.append(TreeNode("f_eq", [child, tree_form("d_0")]))
|
|
51
|
-
eqlist.pop(i)
|
|
52
|
-
else:
|
|
53
|
-
final.append(TreeNode("f_eq", [eqlist[i], tree_form("d_0")]))
|
|
54
|
-
eqlist.pop(i)
|
|
41
|
+
orig = [item.copy_tree() for item in eqlist]
|
|
42
|
+
#eqlist = [eq for eq in eqlist if fxconst(eq)]
|
|
55
43
|
|
|
56
|
-
if
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
if len(final)==1:
|
|
60
|
-
|
|
61
|
-
return final[0]
|
|
62
|
-
return TreeNode("f_and", final)
|
|
63
|
-
eqlist = [eq for eq in eqlist if fxconst(eq)]
|
|
64
|
-
if not all(islinear(eq, fxconst) for eq in eqlist):
|
|
65
|
-
return TreeNode("f_and", copy.deepcopy(final+eqlist))
|
|
44
|
+
if eqlist == [] or not all(islinear(eq, fxconst) for eq in eqlist):
|
|
45
|
+
return None
|
|
46
|
+
#return TreeNode("f_and", [TreeNode("f_eq", [x, tree_form("d_0")]) for x in orig])
|
|
66
47
|
vl = []
|
|
67
48
|
def varlist(eq, fxconst):
|
|
68
49
|
nonlocal vl
|
|
@@ -75,7 +56,7 @@ def linear(eqlist, fxconst):
|
|
|
75
56
|
vl = list(set(vl))
|
|
76
57
|
|
|
77
58
|
if len(vl) > len(eqlist):
|
|
78
|
-
return TreeNode("f_and",
|
|
59
|
+
return TreeNode("f_and", [TreeNode("f_eq", [x, tree_form("d_0")]) for x in eqlist])
|
|
79
60
|
m = []
|
|
80
61
|
for eq in eqlist:
|
|
81
62
|
s = copy.deepcopy(eq)
|
|
@@ -94,63 +75,91 @@ def linear(eqlist, fxconst):
|
|
|
94
75
|
for i in range(len(m)):
|
|
95
76
|
for j in range(len(m[i])):
|
|
96
77
|
m[i][j] = fraction(m[i][j])
|
|
97
|
-
|
|
98
|
-
for item in m:
|
|
99
|
-
if all(item2==tree_form("d_0") for item2 in item[:-1]) and item[-1] != tree_form("d_0"):
|
|
100
|
-
return tree_form("s_false")
|
|
101
78
|
|
|
102
79
|
output = []
|
|
103
80
|
for index, row in enumerate(m):
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
break
|
|
110
|
-
elif item == tree_form("d_0") and count == 1:
|
|
111
|
-
break
|
|
112
|
-
if count == 0:
|
|
113
|
-
continue
|
|
114
|
-
output.append(tree_form(vl[index])+row[-1])
|
|
115
|
-
if len(output) == 1 and len(final)==0:
|
|
81
|
+
if not all(item == 0 for item in row[:-1]):
|
|
82
|
+
output.append(summation([tree_form(vl[index2])*coeff for index2, coeff in enumerate(row[:-1])])+row[-1])
|
|
83
|
+
elif row[-1] != 0:
|
|
84
|
+
return tree_form("s_false")
|
|
85
|
+
if len(output) == 1:
|
|
116
86
|
return TreeNode("f_eq", [output[0], tree_form("d_0")])
|
|
117
|
-
|
|
87
|
+
if len(output) == 0:
|
|
88
|
+
return tree_form("s_false")
|
|
89
|
+
return TreeNode("f_and", [TreeNode("f_eq", [x, tree_form("d_0")]) for x in output])
|
|
90
|
+
def order_collinear_indices(points, idx):
|
|
91
|
+
"""
|
|
92
|
+
Arrange a subset of collinear points (given by indices) along their line.
|
|
93
|
+
|
|
94
|
+
points: list of (x, y) tuples
|
|
95
|
+
idx: list of indices referring to points
|
|
96
|
+
Returns: list of indices sorted along the line
|
|
97
|
+
"""
|
|
98
|
+
if len(idx) <= 1:
|
|
99
|
+
return idx[:]
|
|
100
|
+
|
|
101
|
+
# Take first two points from the subset to define the line
|
|
102
|
+
p0, p1 = points[idx[0]], points[idx[1]]
|
|
103
|
+
dx, dy = p1[0] - p0[0], p1[1] - p0[1]
|
|
104
|
+
|
|
105
|
+
# Projection factor for sorting
|
|
106
|
+
def projection_factor(i):
|
|
107
|
+
vx, vy = points[i][0] - p0[0], points[i][1] - p0[1]
|
|
108
|
+
return compute((vx * dx + vy * dy) / (dx**2 + dy**2))
|
|
109
|
+
|
|
110
|
+
# Sort indices by projection
|
|
111
|
+
sorted_idx = sorted(idx, key=projection_factor)
|
|
112
|
+
return list(sorted_idx)
|
|
113
|
+
def linear_or(eq):
|
|
114
|
+
eqlst =[]
|
|
115
|
+
if eq.name != "f_or":
|
|
116
|
+
eqlst = [eq]
|
|
117
|
+
else:
|
|
118
|
+
eqlst = eq.children
|
|
119
|
+
v = vlist(eq)
|
|
120
|
+
p = []
|
|
121
|
+
line = {}
|
|
122
|
+
for i in range(len(eqlst)):
|
|
123
|
+
line[i] = []
|
|
124
|
+
for item in itertools.combinations(enumerate(eqlst), 2):
|
|
125
|
+
x, y = item[0][0], item[1][0]
|
|
126
|
+
item = [item[0][1], item[1][1]]
|
|
127
|
+
|
|
128
|
+
out = linear_solve(TreeNode("f_and", list(item)))
|
|
118
129
|
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
130
|
+
if out is None:
|
|
131
|
+
return None
|
|
132
|
+
|
|
133
|
+
if out.name == "f_and" and all(len(vlist(child)) == 1 for child in out.children) and set(vlist(out)) == set(v) and all(len(vlist(simplify(child))) >0 for child in out.children):
|
|
134
|
+
t = {}
|
|
135
|
+
for child in out.children:
|
|
136
|
+
t[v.index(vlist(child)[0])] = simplify(inverse(child.children[0], vlist(child)[0]))
|
|
137
|
+
t2 = []
|
|
138
|
+
for key in sorted(t.keys()):
|
|
139
|
+
t2.append(t[key])
|
|
140
|
+
t2 = tuple(t2)
|
|
141
|
+
if t2 not in p:
|
|
142
|
+
p.append(t2)
|
|
143
|
+
line[x] += [p.index(t2)]
|
|
144
|
+
line[y] += [p.index(t2)]
|
|
145
|
+
line2 = []
|
|
146
|
+
for key in sorted(line.keys()):
|
|
147
|
+
line2.append(order_collinear_indices(p, list(set(line[key]))))
|
|
123
148
|
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
return
|
|
133
|
-
eqlist = findeq(eq)
|
|
134
|
-
eqlist = [tree_form(x) for x in eqlist]
|
|
135
|
-
eqlist = [rmeq(x) for x in eqlist]
|
|
136
|
-
eqlist = [TreeNode("f_mul", factor_generation(x)) for x in eqlist if x != tree_form("d_0")]
|
|
137
|
-
eqlist = [x.children[0] if len(x.children) == 1 else x for x in eqlist]
|
|
149
|
+
return v, p, line2, eqlst
|
|
150
|
+
def linear_solve(eq, lst=None):
|
|
151
|
+
eq = simplify(eq)
|
|
152
|
+
eqlist = []
|
|
153
|
+
if eq.name =="f_and" and all(child.name == "f_eq" and child.children[1] == 0 for child in eq.children):
|
|
154
|
+
|
|
155
|
+
eqlist = [child.children[0] for child in eq.children]
|
|
156
|
+
else:
|
|
157
|
+
return eq
|
|
138
158
|
out = None
|
|
139
|
-
|
|
140
159
|
if lst is None:
|
|
141
160
|
out = linear(copy.deepcopy(eqlist), lambda x: "v_" in str_form(x))
|
|
142
161
|
else:
|
|
143
162
|
out = linear(copy.deepcopy(eqlist), lambda x: any(contain(x, item) for item in lst))
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
return TreeNode(eq.name, [rms(child) for child in eq.children])
|
|
148
|
-
return rms(out)
|
|
149
|
-
def linear_solve(eq, lst=None):
|
|
150
|
-
if eq.name == "f_and":
|
|
151
|
-
eq2 = copy.deepcopy(eq)
|
|
152
|
-
eq2.name = "f_list"
|
|
153
|
-
return mat0(eq2, lst)
|
|
154
|
-
elif eq.name == "f_eq":
|
|
155
|
-
return mat0(eq, lst)
|
|
156
|
-
return TreeNode(eq.name, [linear_solve(child, lst) for child in eq.children])
|
|
163
|
+
if out is None:
|
|
164
|
+
return None
|
|
165
|
+
return simplify(out)
|
mathai/logic.py
CHANGED
|
@@ -1,6 +1,12 @@
|
|
|
1
1
|
import itertools
|
|
2
2
|
from .base import *
|
|
3
|
-
|
|
3
|
+
def c(eq):
|
|
4
|
+
eq = logic1(eq)
|
|
5
|
+
eq = dowhile(eq, logic0)
|
|
6
|
+
eq = dowhile(eq, logic2)
|
|
7
|
+
return eq
|
|
8
|
+
def logic_n(eq):
|
|
9
|
+
return dowhile(eq, c)
|
|
4
10
|
def logic0(eq):
|
|
5
11
|
if eq.children is None or len(eq.children)==0:
|
|
6
12
|
return eq
|
mathai/matrix.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
from .base import *
|
|
2
|
+
import copy
|
|
3
|
+
from .simplify import simplify
|
|
4
|
+
import itertools
|
|
5
|
+
|
|
6
|
+
# ---------- tree <-> python list ----------
|
|
7
|
+
def tree_to_py(node):
|
|
8
|
+
if node.name=="f_list":
|
|
9
|
+
return [tree_to_py(c) for c in node.children]
|
|
10
|
+
return node
|
|
11
|
+
|
|
12
|
+
def py_to_tree(obj):
|
|
13
|
+
if isinstance(obj,list):
|
|
14
|
+
return TreeNode("f_list",[py_to_tree(x) for x in obj])
|
|
15
|
+
return obj
|
|
16
|
+
|
|
17
|
+
# ---------- shape detection ----------
|
|
18
|
+
def is_vector(x):
|
|
19
|
+
return isinstance(x,list) and all(isinstance(item,TreeNode) for item in x)
|
|
20
|
+
def is_mat(x):
|
|
21
|
+
return isinstance(x,list) and all(isinstance(item,list) for item in x)
|
|
22
|
+
def is_matrix(x):
|
|
23
|
+
return isinstance(x, list) and all(isinstance(item, list) and (is_mat(item) or is_vector(item)) for item in x)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# ---------- algebra primitives ----------
|
|
27
|
+
def dot(u,v):
|
|
28
|
+
if len(u)!=len(v):
|
|
29
|
+
raise ValueError("Vector size mismatch")
|
|
30
|
+
s = tree_form("d_0")
|
|
31
|
+
for a,b in zip(u,v):
|
|
32
|
+
s = TreeNode("f_add",[s,TreeNode("f_mul",[a,b])])
|
|
33
|
+
return s
|
|
34
|
+
|
|
35
|
+
def matmul(A, B):
|
|
36
|
+
# A: n × m
|
|
37
|
+
# B: m × p
|
|
38
|
+
|
|
39
|
+
n = len(A)
|
|
40
|
+
m = len(A[0])
|
|
41
|
+
p = len(B[0])
|
|
42
|
+
|
|
43
|
+
if m != len(B):
|
|
44
|
+
raise ValueError("Matrix dimension mismatch")
|
|
45
|
+
|
|
46
|
+
C = [[tree_form("d_0") for _ in range(p)] for _ in range(n)]
|
|
47
|
+
|
|
48
|
+
for i in range(n):
|
|
49
|
+
for j in range(p):
|
|
50
|
+
for k in range(m):
|
|
51
|
+
C[i][j] = TreeNode(
|
|
52
|
+
"f_add",
|
|
53
|
+
[C[i][j], TreeNode("f_mul", [A[i][k], B[k][j]])]
|
|
54
|
+
)
|
|
55
|
+
return C
|
|
56
|
+
|
|
57
|
+
# ---------- promotion ----------
|
|
58
|
+
def promote(node):
|
|
59
|
+
if node.name=="f_list":
|
|
60
|
+
return tree_to_py(node)
|
|
61
|
+
return node
|
|
62
|
+
def contains_neg(node):
|
|
63
|
+
if isinstance(node, list):
|
|
64
|
+
return False
|
|
65
|
+
if node.name.startswith("v_-"):
|
|
66
|
+
return False
|
|
67
|
+
for child in node.children:
|
|
68
|
+
if not contains_neg(child):
|
|
69
|
+
return False
|
|
70
|
+
return True
|
|
71
|
+
# ---------- multiplication (fully simplified) ----------
|
|
72
|
+
def multiply(left,right):
|
|
73
|
+
if left == tree_form("d_1"):
|
|
74
|
+
return right
|
|
75
|
+
if right == tree_form("d_1"):
|
|
76
|
+
return left
|
|
77
|
+
left2, right2 = left, right
|
|
78
|
+
if left2.name != "f_pow":
|
|
79
|
+
left2 = left2 ** 1
|
|
80
|
+
if right2.name != "f_pow":
|
|
81
|
+
right2 = right2 ** 1
|
|
82
|
+
if left2.name == "f_pow" and right2.name == "f_pow" and left2.children[0]==right2.children[0]:
|
|
83
|
+
return simplify(left2.children[0]**(left2.children[1]+right2.children[1]))
|
|
84
|
+
A,B = promote(left), promote(right)
|
|
85
|
+
|
|
86
|
+
# vector · vector
|
|
87
|
+
if is_vector(A) and is_vector(B):
|
|
88
|
+
return dot(A,B)
|
|
89
|
+
# matrix × matrix
|
|
90
|
+
if is_matrix(A) and is_matrix(B):
|
|
91
|
+
return py_to_tree(matmul(A,B))
|
|
92
|
+
# scalar × vector
|
|
93
|
+
for _ in range(2):
|
|
94
|
+
if contains_neg(A) and is_vector(B):
|
|
95
|
+
return py_to_tree([TreeNode("f_mul",[A,x]) for x in B])
|
|
96
|
+
# scalar × matrix
|
|
97
|
+
if contains_neg(A) and is_matrix(B):
|
|
98
|
+
return py_to_tree([[TreeNode("f_mul",[A,x]) for x in row] for row in B])
|
|
99
|
+
A, B = B, A
|
|
100
|
+
return None
|
|
101
|
+
def add_vec(A, B):
|
|
102
|
+
if len(A) != len(B):
|
|
103
|
+
raise ValueError("Vector dimension mismatch")
|
|
104
|
+
|
|
105
|
+
return [
|
|
106
|
+
TreeNode("f_add", [A[i], B[i]])
|
|
107
|
+
for i in range(len(A))
|
|
108
|
+
]
|
|
109
|
+
def matadd(A, B):
|
|
110
|
+
if len(A) != len(B) or len(A[0]) != len(B[0]):
|
|
111
|
+
raise ValueError("Matrix dimension mismatch")
|
|
112
|
+
|
|
113
|
+
n = len(A)
|
|
114
|
+
m = len(A[0])
|
|
115
|
+
|
|
116
|
+
return [
|
|
117
|
+
[
|
|
118
|
+
TreeNode("f_add", [A[i][j], B[i][j]])
|
|
119
|
+
for j in range(m)
|
|
120
|
+
]
|
|
121
|
+
for i in range(n)
|
|
122
|
+
]
|
|
123
|
+
def addition(left,right):
|
|
124
|
+
A,B = promote(left), promote(right)
|
|
125
|
+
# vector + vector
|
|
126
|
+
if is_vector(A) and is_vector(B):
|
|
127
|
+
return add_vec(A,B)
|
|
128
|
+
# matrix + matrix
|
|
129
|
+
if is_matrix(A) and is_matrix(B):
|
|
130
|
+
return py_to_tree(matadd(A,B))
|
|
131
|
+
return None
|
|
132
|
+
'''
|
|
133
|
+
def fold_wmul(eq):
|
|
134
|
+
if eq.name == "f_pow" and eq.children[1].name.startswith("d_"):
|
|
135
|
+
n = int(eq.children[1].name[2:])
|
|
136
|
+
if n == 1:
|
|
137
|
+
eq = eq.children[0]
|
|
138
|
+
elif n > 1:
|
|
139
|
+
tmp = promote(eq.children[0])
|
|
140
|
+
if is_matrix(tmp):
|
|
141
|
+
orig =tmp
|
|
142
|
+
for i in range(n-1):
|
|
143
|
+
tmp = matmul(orig, tmp)
|
|
144
|
+
eq = py_to_tree(tmp)
|
|
145
|
+
elif eq.name in ["f_wmul", "f_add"]:
|
|
146
|
+
if len(eq.children) == 1:
|
|
147
|
+
eq = eq.children[0]
|
|
148
|
+
else:
|
|
149
|
+
i = len(eq.children)-1
|
|
150
|
+
while i>0:
|
|
151
|
+
if eq.name == "f_wmul":
|
|
152
|
+
out = multiply(eq.children[i-1], eq.children[i])
|
|
153
|
+
else:
|
|
154
|
+
out = addition(eq.children[i-1], eq.children[i])
|
|
155
|
+
if out is not None:
|
|
156
|
+
eq.children.pop(i)
|
|
157
|
+
eq.children.pop(i-1)
|
|
158
|
+
eq.children.insert(i-1,out)
|
|
159
|
+
i = i-1
|
|
160
|
+
return TreeNode(eq.name, [fold_wmul(child) for child in eq.children])
|
|
161
|
+
'''
|
|
162
|
+
def fold_wmul(root):
|
|
163
|
+
# Post-order traversal using explicit stack
|
|
164
|
+
stack = [(root, False)]
|
|
165
|
+
newnode = {}
|
|
166
|
+
|
|
167
|
+
while stack:
|
|
168
|
+
node, visited = stack.pop()
|
|
169
|
+
|
|
170
|
+
if not visited:
|
|
171
|
+
# First time: push back as visited, then children
|
|
172
|
+
stack.append((node, True))
|
|
173
|
+
for child in node.children:
|
|
174
|
+
stack.append((child, False))
|
|
175
|
+
else:
|
|
176
|
+
# All children already processed
|
|
177
|
+
children = [newnode[c] for c in node.children]
|
|
178
|
+
eq = TreeNode(node.name, children)
|
|
179
|
+
|
|
180
|
+
# ---- original rewrite logic ----
|
|
181
|
+
|
|
182
|
+
if eq.name == "f_pow" and eq.children[1].name.startswith("d_"):
|
|
183
|
+
n = int(eq.children[1].name[2:])
|
|
184
|
+
if n == 1:
|
|
185
|
+
eq = eq.children[0]
|
|
186
|
+
elif n > 1:
|
|
187
|
+
tmp = promote(eq.children[0])
|
|
188
|
+
if is_matrix(tmp):
|
|
189
|
+
orig = tmp
|
|
190
|
+
for _ in range(n - 1):
|
|
191
|
+
tmp = matmul(orig, tmp)
|
|
192
|
+
eq = py_to_tree(tmp)
|
|
193
|
+
|
|
194
|
+
elif eq.name in ["f_wmul", "f_add"]:
|
|
195
|
+
if len(eq.children) == 1:
|
|
196
|
+
eq = eq.children[0]
|
|
197
|
+
else:
|
|
198
|
+
i = len(eq.children) - 1
|
|
199
|
+
while i > 0:
|
|
200
|
+
if eq.name == "f_wmul":
|
|
201
|
+
out = multiply(eq.children[i - 1], eq.children[i])
|
|
202
|
+
else:
|
|
203
|
+
out = addition(eq.children[i - 1], eq.children[i])
|
|
204
|
+
|
|
205
|
+
if out is not None:
|
|
206
|
+
eq.children.pop(i)
|
|
207
|
+
eq.children.pop(i - 1)
|
|
208
|
+
eq.children.insert(i - 1, out)
|
|
209
|
+
i -= 1
|
|
210
|
+
|
|
211
|
+
# --------------------------------
|
|
212
|
+
|
|
213
|
+
newnode[node] = eq
|
|
214
|
+
|
|
215
|
+
return newnode[root]
|
|
216
|
+
|
|
217
|
+
def flat(eq):
|
|
218
|
+
return flatten_tree(eq)
|
|
219
|
+
def use(eq):
|
|
220
|
+
return TreeNode(eq.name, [use(child) for child in eq.children])
|
|
221
|
+
def _matrix_solve(eq):
|
|
222
|
+
if TreeNode.matmul == True:
|
|
223
|
+
TreeNode.matmul = False
|
|
224
|
+
eq = dowhile(eq, lambda x: fold_wmul(use(flat(x))))
|
|
225
|
+
TreeNode.matmul = True
|
|
226
|
+
return eq
|
|
227
|
+
def matrix_solve(eq):
|
|
228
|
+
return _matrix_solve(eq)
|