mathai 0.6.4__tar.gz → 0.6.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mathai-0.6.4 → mathai-0.6.6}/PKG-INFO +1 -1
- {mathai-0.6.4 → mathai-0.6.6}/mathai/base.py +25 -22
- mathai-0.6.6/mathai/matrix.py +224 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/parser.py +2 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/simplify.py +6 -1
- {mathai-0.6.4 → mathai-0.6.6}/mathai.egg-info/PKG-INFO +1 -1
- {mathai-0.6.4 → mathai-0.6.6}/setup.py +1 -1
- mathai-0.6.4/mathai/matrix.py +0 -118
- {mathai-0.6.4 → mathai-0.6.6}/README.md +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/__init__.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/apart.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/bivariate_inequality.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/console.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/diff.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/expand.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/factor.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/fraction.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/integrate.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/inverse.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/limit.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/linear.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/logic.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/ode.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/parsetab.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/printeq.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/structure.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/tool.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/trig.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai/univariate_inequality.py +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai.egg-info/SOURCES.txt +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai.egg-info/dependency_links.txt +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai.egg-info/requires.txt +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/mathai.egg-info/top_level.txt +0 -0
- {mathai-0.6.4 → mathai-0.6.6}/setup.cfg +0 -0
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import copy
|
|
2
2
|
from fractions import Fraction
|
|
3
3
|
def contains_list_or_neg(node):
    """Return True if any node in the tree rooted at *node* is an ``f_list``
    or has a name beginning with ``v_-``; otherwise return False.

    The tree is walked with an explicit work list rather than recursion, so
    very deep trees do not run into Python's recursion limit.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        label = current.name
        if label == "f_list" or label.startswith("v_-"):
            return True
        pending += current.children
    return False
|
|
10
11
|
class TreeNode:
|
|
11
12
|
matmul = None
|
|
@@ -18,7 +19,7 @@ class TreeNode:
|
|
|
18
19
|
children = copy.deepcopy(children)
|
|
19
20
|
self.name = name
|
|
20
21
|
|
|
21
|
-
if name == "f_add" or (name == "f_mul" and
|
|
22
|
+
if name == "f_add" or (name == "f_mul" and TreeNode.matmul is None):
|
|
22
23
|
keyed = [(str_form(c), c) for c in children]
|
|
23
24
|
self.children = [c for _, c in sorted(keyed)]
|
|
24
25
|
|
|
@@ -30,31 +31,32 @@ class TreeNode:
|
|
|
30
31
|
sortable.append(c)
|
|
31
32
|
else:
|
|
32
33
|
fixed.append(c)
|
|
34
|
+
|
|
35
|
+
if len(sortable) > 1:
|
|
36
|
+
sortable = TreeNode("f_dmul", list(sorted(sortable, key=lambda x: str_form(x))))
|
|
37
|
+
sortable.name = "f_mul"
|
|
38
|
+
|
|
39
|
+
elif len(sortable) == 1:
|
|
40
|
+
sortable = sortable[0]
|
|
41
|
+
|
|
42
|
+
if isinstance(sortable, TreeNode):
|
|
43
|
+
fixed.append(sortable)
|
|
33
44
|
if len(fixed) > 1:
|
|
34
|
-
|
|
45
|
+
self.name = "f_wmul"
|
|
35
46
|
elif len(fixed) == 1:
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
sortable.append(fixed)
|
|
39
|
-
if len(sortable)==1 and name == "f_mul":
|
|
40
|
-
self.name = sortable[0].name
|
|
41
|
-
if self.name in ["f_add", "f_mul"]:
|
|
42
|
-
self.children = list(sorted(sortable[0].children, key=lambda x: str_form(x)))
|
|
43
|
-
else:
|
|
44
|
-
self.children = sortable[0].children
|
|
47
|
+
self.name = fixed[0].name
|
|
48
|
+
fixed = fixed[0].children
|
|
45
49
|
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
elif name == "f_mul" and TreeNode.matmul == 0:
|
|
49
|
-
|
|
50
|
-
self.children = children
|
|
50
|
+
|
|
51
|
+
self.children = fixed
|
|
51
52
|
else:
|
|
52
53
|
self.children = children
|
|
53
54
|
|
|
54
55
|
|
|
55
56
|
def fx(self, fxname):
|
|
56
57
|
return TreeNode("f_" + fxname, [self])
|
|
57
|
-
|
|
58
|
+
def copy_tree(self):
|
|
59
|
+
return copy.deepcopy(self)
|
|
58
60
|
def __repr__(self):
|
|
59
61
|
return string_equation(str_form(self))
|
|
60
62
|
|
|
@@ -364,7 +366,8 @@ def product(lst):
|
|
|
364
366
|
def flatten_tree(node, add=[]):
|
|
365
367
|
if not node.children:
|
|
366
368
|
return node
|
|
367
|
-
|
|
369
|
+
ad = []
|
|
370
|
+
if node.name in ["f_add", "f_mul", "f_and", "f_or", "f_wmul"]:
|
|
368
371
|
merged_children = []
|
|
369
372
|
for child in node.children:
|
|
370
373
|
flattened_child = flatten_tree(child, add)
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
from .base import *
|
|
2
|
+
import copy
|
|
3
|
+
from .simplify import simplify
|
|
4
|
+
import itertools
|
|
5
|
+
|
|
6
|
+
# ---------- tree <-> python list ----------
|
|
7
|
+
def tree_to_py(node):
    """Convert an ``f_list`` tree into (possibly nested) Python lists.

    Any node that is not an ``f_list`` is returned unchanged, so the leaves
    of the resulting lists are the original node objects.
    """
    if node.name != "f_list":
        return node
    return list(map(tree_to_py, node.children))
|
|
11
|
+
|
|
12
|
+
def py_to_tree(obj):
    """Inverse of ``tree_to_py``: wrap (nested) Python lists back into
    ``f_list`` TreeNodes. Any non-list object is returned as-is."""
    if not isinstance(obj, list):
        return obj
    converted = [py_to_tree(item) for item in obj]
    return TreeNode("f_list", converted)
|
|
16
|
+
|
|
17
|
+
# ---------- shape detection ----------
|
|
18
|
+
def is_vector(x):
    """Return True if *x* is a Python list whose items are all TreeNode
    instances (an empty list counts); False for anything else."""
    if not isinstance(x, list):
        return False
    return all(isinstance(entry, TreeNode) for entry in x)
|
|
20
|
+
def is_mat(x):
    """Return True if *x* is a list whose every element is itself a list
    (an empty list counts); False otherwise."""
    if isinstance(x, list):
        return all(isinstance(row, list) for row in x)
    return False
|
|
22
|
+
def is_matrix(x):
    """Return True if *x* is a list of rows where every row is a list that is
    either a vector (``is_vector``) or a list of lists (``is_mat``).

    An empty list counts as a matrix.
    """
    if not isinstance(x, list):
        return False
    for row in x:
        if not isinstance(row, list):
            return False
        if not (is_mat(row) or is_vector(row)):
            return False
    return True
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# ---------- algebra primitives ----------
|
|
27
|
+
def dot(u,v):
    """Build the symbolic dot product of two equal-length vectors.

    Starting from ``d_0``, each pairwise product ``u[k]*v[k]`` is added on
    with a new ``f_add`` node. ``simplify`` is not called here.

    Raises:
        ValueError: if ``u`` and ``v`` have different lengths.
    """
    if len(u) != len(v):
        raise ValueError("Vector size mismatch")
    acc = tree_form("d_0")
    for left_item, right_item in zip(u, v):
        product = TreeNode("f_mul", [left_item, right_item])
        acc = TreeNode("f_add", [acc, product])
    return acc
|
|
34
|
+
|
|
35
|
+
def matmul(A, B):
    """Symbolic product of nested-list matrices ``A`` (n x m) and ``B`` (m x p).

    Entry ``[i][j]`` of the result starts at ``d_0``. Each term
    ``A[i][k]*B[k][j]`` is added on in turn with an ``f_add`` node, and
    ``simplify`` is not called.

    Raises:
        ValueError: if the column count of ``A`` differs from the row count
            of ``B``. The check runs before ``B[0]`` is read, so an empty
            operand is reported as a mismatch rather than an IndexError.
    """
    n = len(A)
    # Guard the row-0 lookups so empty operands reach the dimension check.
    m = len(A[0]) if A else 0
    if m != len(B):
        raise ValueError("Matrix dimension mismatch")
    p = len(B[0]) if B else 0

    C = [[tree_form("d_0") for _ in range(p)] for _ in range(n)]
    for i in range(n):
        for j in range(p):
            for k in range(m):
                C[i][j] = TreeNode(
                    "f_add",
                    [C[i][j], TreeNode("f_mul", [A[i][k], B[k][j]])]
                )
    return C
|
|
56
|
+
|
|
57
|
+
# ---------- promotion ----------
|
|
58
|
+
def promote(node):
    """Turn an ``f_list`` node into nested Python lists via ``tree_to_py``.
    Any other node is returned unchanged."""
    return tree_to_py(node) if node.name == "f_list" else node
|
|
62
|
+
def contains_neg(node):
    """Scalar test used by ``multiply``.

    Despite the name, this returns True when *node* is a tree in which no
    node name starts with ``v_-``. It returns False if *node* (or any
    descendant) is a Python list, or if any name starts with ``v_-``.
    """
    if isinstance(node, list):
        return False
    if node.name.startswith("v_-"):
        return False
    return all(contains_neg(child) for child in node.children)
|
|
71
|
+
# ---------- multiplication (fully simplified) ----------
|
|
72
|
+
def multiply(left,right):
    """Try to combine two adjacent factors of an ``f_wmul`` product.

    Returns the combined TreeNode, or None when no rule applies (the caller
    then keeps the two factors as they are). Rules, tried in order:

    * equal bases: ``x**a * x**b -> simplify(x**(a+b))``; a factor that is
      not already an ``f_pow`` is treated as ``x**1``;
    * vector · vector -> ``dot``;
    * matrix × matrix -> ``matmul``, wrapped back into an ``f_list`` tree;
    * scalar × vector or scalar × matrix, in either operand order -> every
      entry multiplied by the scalar (scalar written first), where "scalar"
      means ``contains_neg`` returns True for it.
    """
    lhs = left if left.name == "f_pow" else left ** 1
    rhs = right if right.name == "f_pow" else right ** 1
    same_base = (
        lhs.name == "f_pow"
        and rhs.name == "f_pow"
        and lhs.children[0] == rhs.children[0]
    )
    if same_base:
        exponent = lhs.children[1] + rhs.children[1]
        return simplify(lhs.children[0] ** exponent)

    A, B = promote(left), promote(right)

    if is_vector(A) and is_vector(B):
        return dot(A, B)
    if is_matrix(A) and is_matrix(B):
        return py_to_tree(matmul(A, B))

    # Scalar scaling: try (A, B) first, then the swapped pair (B, A).
    for scalar, other in ((A, B), (B, A)):
        if contains_neg(scalar) and is_vector(other):
            return py_to_tree([TreeNode("f_mul", [scalar, x]) for x in other])
        if contains_neg(scalar) and is_matrix(other):
            return py_to_tree(
                [[TreeNode("f_mul", [scalar, x]) for x in row] for row in other]
            )
    return None
|
|
97
|
+
def add_vec(A, B):
    """Element-wise symbolic sum of two vectors.

    Returns a Python list of ``f_add`` nodes, one per position.

    Raises:
        ValueError: if the vectors differ in length.
    """
    if len(A) != len(B):
        raise ValueError("Vector dimension mismatch")
    return [TreeNode("f_add", [a, b]) for a, b in zip(A, B)]
|
|
105
|
+
def matadd(A, B):
    """Element-wise symbolic sum of two nested-list matrices.

    Returns a list of rows of ``f_add`` nodes.

    Raises:
        ValueError: if the matrices differ in row count, or if any pair of
            corresponding rows differs in length. Every row is checked, not
            just row 0, so ragged input cannot silently drop columns or fail
            with an IndexError.
    """
    # Validate the full shape before building anything.
    if len(A) != len(B) or any(len(ra) != len(rb) for ra, rb in zip(A, B)):
        raise ValueError("Matrix dimension mismatch")

    return [
        [TreeNode("f_add", [a, b]) for a, b in zip(ra, rb)]
        for ra, rb in zip(A, B)
    ]
|
|
119
|
+
def addition(left,right):
    """Try to add two adjacent terms of an ``f_add`` node element-wise.

    Returns an ``f_list`` TreeNode with the element-wise sum when both
    operands are vectors, or both are matrices. Otherwise returns None, so
    the caller leaves the terms untouched.
    """
    A, B = promote(left), promote(right)
    # vector + vector
    if is_vector(A) and is_vector(B):
        # Wrap back into a tree like the matrix branch (and multiply) do;
        # returning the bare Python list would put a non-TreeNode into the
        # parent's children.
        return py_to_tree(add_vec(A, B))
    # matrix + matrix
    if is_matrix(A) and is_matrix(B):
        return py_to_tree(matadd(A, B))
    return None
|
|
128
|
+
def _fold_node(eq):
    """Apply the matrix-folding rewrite rules to one node whose children have
    already been folded, and return the resulting node.

    * ``X ** d_n``: n == 1 gives X; n > 1 with X a literal matrix gives the
      explicit product of n copies of X (via ``matmul``).
    * ``f_wmul`` / ``f_add``: adjacent operands are combined right-to-left
      with ``multiply`` / ``addition`` wherever those return a result. A node
      left with a single operand is replaced by that operand. Otherwise an
      ``f_add`` whose term is a wrapped ``f_wmul[M·N]`` could not be added by
      its parent.
    """
    if eq.name == "f_pow" and eq.children[1].name.startswith("d_"):
        n = int(eq.children[1].name[2:])
        if n == 1:
            return eq.children[0]
        if n > 1:
            base = promote(eq.children[0])
            if is_matrix(base):
                power = base
                for _ in range(n - 1):
                    power = matmul(base, power)
                return py_to_tree(power)
        return eq

    if eq.name in ("f_wmul", "f_add"):
        combine = multiply if eq.name == "f_wmul" else addition
        i = len(eq.children) - 1
        while i > 0:
            out = combine(eq.children[i - 1], eq.children[i])
            if out is not None:
                # multiply() merges equal factors x*x into a power; when that
                # result is a power, expand it again so a literal matrix M*M
                # still yields an explicit product.
                if isinstance(out, TreeNode) and out.name == "f_pow":
                    out = _fold_node(out)
                # Replace the two operands with their combination.
                eq.children[i - 1:i + 1] = [out]
            i -= 1
        if len(eq.children) == 1:
            return eq.children[0]
    return eq


def fold_wmul(root):
    """Fold matrix and vector arithmetic throughout the tree rooted at *root*.

    Nodes are rebuilt bottom-up (post-order, explicit stack rather than
    recursion), and each rebuilt node goes through ``_fold_node``. Returns the
    rewritten tree.
    """
    stack = [(root, False)]
    # Results keyed by id() of the original node object, so lookups do not
    # depend on how TreeNode implements hashing/equality.
    rewritten = {}

    while stack:
        node, visited = stack.pop()
        if not visited:
            # First visit: come back after every child has been rewritten.
            stack.append((node, True))
            stack.extend((child, False) for child in node.children)
            continue
        children = [rewritten[id(child)] for child in node.children]
        rewritten[id(node)] = _fold_node(TreeNode(node.name, children))

    return rewritten[id(root)]
|
|
212
|
+
|
|
213
|
+
def flat(eq):
    """Flatten *eq* with ``flatten_tree``, passing ``["f_wmul"]`` as its
    extra ``add`` argument."""
    extra = ["f_wmul"]
    return flatten_tree(eq, extra)
|
|
215
|
+
def use(eq):
    """Rebuild *eq* bottom-up, re-running the TreeNode constructor on every
    node (children first, in their original order)."""
    rebuilt_children = list(map(use, eq.children))
    return TreeNode(eq.name, rebuilt_children)
|
|
217
|
+
def _matrix_solve(eq):
    """Fold matrix/vector arithmetic in *eq* when matrix mode is enabled.

    Only acts when ``TreeNode.matmul`` is True; otherwise *eq* is returned
    unchanged. While folding, the class-level flag is set to False. It is
    restored in a ``finally`` clause, so an exception during folding (for
    example the ValueError that matmul/dot/matadd raise on a dimension
    mismatch) cannot leave matrix mode switched off for later calls.
    """
    if TreeNode.matmul == True:
        TreeNode.matmul = False
        try:
            eq = flat(fold_wmul(flat(eq)))
        finally:
            TreeNode.matmul = True
    return eq
|
|
223
|
+
def matrix_solve(eq):
    """Public entry point for matrix folding; delegates to ``_matrix_solve``
    and returns its result."""
    result = _matrix_solve(eq)
    return result
|
|
@@ -500,6 +500,9 @@ def solve3(eq):
|
|
|
500
500
|
def simplify(eq, basic=True):
|
|
501
501
|
if eq is None:
|
|
502
502
|
return None
|
|
503
|
+
orig = TreeNode.matmul
|
|
504
|
+
if TreeNode.matmul == True:
|
|
505
|
+
TreeNode.matmul = False
|
|
503
506
|
if eq.name == "f_and" or eq.name == "f_not" or eq.name == "f_or":
|
|
504
507
|
new_children = []
|
|
505
508
|
for child in eq.children:
|
|
@@ -518,4 +521,6 @@ def simplify(eq, basic=True):
|
|
|
518
521
|
eq = flatten_tree(eq)
|
|
519
522
|
if basic:
|
|
520
523
|
eq = convert_to_basic(eq)
|
|
521
|
-
|
|
524
|
+
eq = solve3(eq)
|
|
525
|
+
TreeNode.matmul = orig
|
|
526
|
+
return eq
|
mathai-0.6.4/mathai/matrix.py
DELETED
|
@@ -1,118 +0,0 @@
|
|
|
1
|
-
from .base import *
|
|
2
|
-
import copy
|
|
3
|
-
from .simplify import simplify
|
|
4
|
-
import itertools
|
|
5
|
-
|
|
6
|
-
# ---------- tree <-> python list ----------
|
|
7
|
-
def tree_to_py(node):
|
|
8
|
-
if node.name=="f_list":
|
|
9
|
-
return [tree_to_py(c) for c in node.children]
|
|
10
|
-
return node
|
|
11
|
-
|
|
12
|
-
def py_to_tree(obj):
|
|
13
|
-
if isinstance(obj,list):
|
|
14
|
-
return TreeNode("f_list",[py_to_tree(x) for x in obj])
|
|
15
|
-
return obj
|
|
16
|
-
|
|
17
|
-
# ---------- shape detection ----------
|
|
18
|
-
def is_vector(x):
|
|
19
|
-
return isinstance(x,list) and all(isinstance(item,TreeNode) for item in x)
|
|
20
|
-
def is_mat(x):
|
|
21
|
-
return isinstance(x,list) and all(isinstance(item,list) for item in x)
|
|
22
|
-
def is_matrix(x):
|
|
23
|
-
return isinstance(x, list) and all(isinstance(item, list) and (is_mat(item) or is_vector(item)) for item in x)
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
# ---------- algebra primitives ----------
|
|
27
|
-
def dot(u,v):
|
|
28
|
-
if len(u)!=len(v):
|
|
29
|
-
raise ValueError("Vector size mismatch")
|
|
30
|
-
s = tree_form("d_0")
|
|
31
|
-
for a,b in zip(u,v):
|
|
32
|
-
s = TreeNode("f_add",[s,TreeNode("f_mul",[a,b])])
|
|
33
|
-
return s
|
|
34
|
-
|
|
35
|
-
def matmul(A,B):
|
|
36
|
-
n,m,p = len(A), len(A[0]), len(B[0])
|
|
37
|
-
if m!=len(B):
|
|
38
|
-
raise ValueError("Matrix dimension mismatch")
|
|
39
|
-
z = tree_form("d_0")
|
|
40
|
-
C = [[z for _ in range(p)] for _ in range(n)]
|
|
41
|
-
for i in range(n):
|
|
42
|
-
for j in range(p):
|
|
43
|
-
for k in range(m):
|
|
44
|
-
C[i][j] = TreeNode("f_add",[C[i][j], TreeNode("f_mul",[A[i][k], B[k][j]])])
|
|
45
|
-
return C
|
|
46
|
-
|
|
47
|
-
# ---------- promotion ----------
|
|
48
|
-
def promote(node):
|
|
49
|
-
if node.name=="f_list":
|
|
50
|
-
return tree_to_py(node)
|
|
51
|
-
return node
|
|
52
|
-
def contains_neg(node):
|
|
53
|
-
if isinstance(node, list):
|
|
54
|
-
return False
|
|
55
|
-
if node.name.startswith("v_-"):
|
|
56
|
-
return False
|
|
57
|
-
for child in node.children:
|
|
58
|
-
if not contains_neg(child):
|
|
59
|
-
return False
|
|
60
|
-
return True
|
|
61
|
-
# ---------- multiplication (fully simplified) ----------
|
|
62
|
-
def multiply(left,right):
|
|
63
|
-
left2, right2 = left, right
|
|
64
|
-
if left2.name != "f_pow":
|
|
65
|
-
left2 = left2 ** 1
|
|
66
|
-
if right2.name != "f_pow":
|
|
67
|
-
right2 = right2 ** 1
|
|
68
|
-
if left2.name == "f_pow" and right2.name == "f_pow" and left2.children[0]==right2.children[0]:
|
|
69
|
-
return simplify(left2.children[0]**(left2.children[1]+right2.children[1]))
|
|
70
|
-
A,B = promote(left), promote(right)
|
|
71
|
-
|
|
72
|
-
# vector · vector
|
|
73
|
-
if is_vector(A) and is_vector(B):
|
|
74
|
-
return dot(A,B)
|
|
75
|
-
# matrix × matrix
|
|
76
|
-
if is_matrix(A) and is_matrix(B):
|
|
77
|
-
return py_to_tree(matmul(A,B))
|
|
78
|
-
# scalar × vector
|
|
79
|
-
for _ in range(2):
|
|
80
|
-
if contains_neg(A) and is_vector(B):
|
|
81
|
-
return py_to_tree([TreeNode("f_mul",[A,x]) for x in B])
|
|
82
|
-
# scalar × matrix
|
|
83
|
-
if contains_neg(A) and is_matrix(B):
|
|
84
|
-
return py_to_tree([[TreeNode("f_mul",[A,x]) for x in row] for row in B])
|
|
85
|
-
A, B = B, A
|
|
86
|
-
return None
|
|
87
|
-
|
|
88
|
-
def fold_wmul(eq):
|
|
89
|
-
if eq.name == "f_pow" and eq.children[1].name.startswith("d_"):
|
|
90
|
-
n = int(eq.children[1].name[2:])
|
|
91
|
-
if n == 1:
|
|
92
|
-
eq = eq.children[0]
|
|
93
|
-
elif n > 1:
|
|
94
|
-
tmp = promote(eq.children[0])
|
|
95
|
-
if is_matrix(tmp):
|
|
96
|
-
orig =tmp
|
|
97
|
-
for i in range(n-1):
|
|
98
|
-
tmp = matmul(orig, tmp)
|
|
99
|
-
eq = py_to_tree(tmp)
|
|
100
|
-
elif eq.name=="f_wmul":
|
|
101
|
-
|
|
102
|
-
i = len(eq.children)-1
|
|
103
|
-
while i>0:
|
|
104
|
-
out = multiply(eq.children[i], eq.children[i-1])
|
|
105
|
-
if out is not None:
|
|
106
|
-
eq.children.pop(i)
|
|
107
|
-
eq.children.pop(i-1)
|
|
108
|
-
eq.children.insert(i-1,out)
|
|
109
|
-
i = i-1
|
|
110
|
-
return TreeNode(eq.name, [fold_wmul(child) for child in eq.children])
|
|
111
|
-
def flat(eq):
|
|
112
|
-
return flatten_tree(eq, ["f_wmul"])
|
|
113
|
-
def matrix_solve(eq):
|
|
114
|
-
if TreeNode.matmul == True:
|
|
115
|
-
TreeNode.matmul = False
|
|
116
|
-
eq = flat(dowhile(eq, lambda x: fold_wmul(flat(x))))
|
|
117
|
-
TreeNode.matmul = True
|
|
118
|
-
return eq
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|