mathai 0.4.0__py3-none-any.whl → 0.6.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mathai/__init__.py +18 -11
- mathai/apart.py +41 -12
- mathai/base.py +141 -38
- mathai/bivariate_inequality.py +317 -0
- mathai/diff.py +3 -3
- mathai/expand.py +92 -93
- mathai/factor.py +213 -41
- mathai/fraction.py +2 -2
- mathai/integrate.py +96 -34
- mathai/inverse.py +4 -4
- mathai/limit.py +96 -70
- mathai/linear.py +96 -84
- mathai/logic.py +7 -1
- mathai/matrix.py +228 -0
- mathai/ode.py +124 -0
- mathai/parser.py +13 -7
- mathai/parsetab.py +61 -0
- mathai/printeq.py +12 -9
- mathai/simplify.py +511 -333
- mathai/structure.py +2 -2
- mathai/tool.py +105 -4
- mathai/trig.py +134 -72
- mathai/univariate_inequality.py +78 -30
- {mathai-0.4.0.dist-info → mathai-0.6.9.dist-info}/METADATA +4 -1
- mathai-0.6.9.dist-info/RECORD +28 -0
- {mathai-0.4.0.dist-info → mathai-0.6.9.dist-info}/WHEEL +1 -1
- mathai/search.py +0 -117
- mathai-0.4.0.dist-info/RECORD +0 -25
- {mathai-0.4.0.dist-info → mathai-0.6.9.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from .linear import linear_or
|
|
3
|
+
from functools import reduce
|
|
4
|
+
import operator
|
|
5
|
+
from .base import *
|
|
6
|
+
from .simplify import simplify
|
|
7
|
+
from .expand import expand
|
|
8
|
+
from .logic import logic0
|
|
9
|
+
|
|
10
|
+
def shoelace_area(vertices):
    """Return the unsigned area of a polygon using the shoelace formula.

    ``vertices`` is an indexable sequence of (x, y) pairs listed in boundary
    order; the polygon is closed implicitly (the last vertex joins the
    first).  An empty sequence has area 0.0.
    """
    twice_area = 0.0
    count = len(vertices)
    for idx, current in enumerate(vertices):
        following = vertices[(idx + 1) % count]
        # Cross term x_i * y_{i+1} - x_{i+1} * y_i, applied as two steps.
        twice_area += current[0] * following[1]
        twice_area -= following[0] * current[1]
    return abs(twice_area) / 2.0
|
|
18
|
+
|
|
19
|
+
def triangle_area(p1, p2, p3):
    """Return the unsigned area of the triangle with corners p1, p2, p3.

    Each corner is an indexable (x, y) pair.  The determinant form is used,
    so the orientation (vertex order) does not affect the result.
    """
    corners = (p1, p2, p3)
    doubled = 0.0
    for k in range(3):
        x_here = corners[k][0]
        y_next = corners[(k + 1) % 3][1]
        y_prev = corners[(k + 2) % 3][1]
        doubled += x_here * (y_next - y_prev)
    return abs(doubled) / 2.0
|
|
25
|
+
|
|
26
|
+
def is_point_inside_polygon(point, vertices, tolerance=1e-5):
    """Return True if ``point`` lies inside ``vertices`` or on its boundary.

    ``point`` is an (x, y) pair and ``vertices`` a sequence of (x, y) pairs
    in boundary order; the last vertex is joined back to the first.  With
    fewer than three vertices the result is always False.

    A point closer than ``tolerance`` to any edge counts as inside, so
    boundary points are reported as inside.  Every other point is
    classified by even-odd ray casting, which is valid for concave as well
    as convex simple polygons.  (Comparing the summed areas of the
    triangles point-v[i]-v[i+1] with the polygon's area only works when
    every edge is visible from the point, so it wrongly rejects interior
    points of concave polygons.)
    """
    count = len(vertices)
    if count < 3:
        return False

    px, py = point[0], point[1]
    edges = [
        (vertices[i][0], vertices[i][1],
         vertices[(i + 1) % count][0], vertices[(i + 1) % count][1])
        for i in range(count)
    ]

    def distance_to_edge(x1, y1, x2, y2):
        # Distance from (px, py) to the closed segment (x1, y1)-(x2, y2).
        dx, dy = x2 - x1, y2 - y1
        length_sq = dx * dx + dy * dy
        t = 0.0
        if length_sq:
            t = max(0.0, min(1.0, ((px - x1) * dx + (py - y1) * dy) / length_sq))
        return math.hypot(px - (x1 + t * dx), py - (y1 + t * dy))

    # Boundary first: ray casting is ambiguous exactly on the edges.
    if any(distance_to_edge(*edge) < tolerance for edge in edges):
        return True

    inside = False
    for x1, y1, x2, y2 in edges:
        # Count crossings of the horizontal ray pointing to +x.  The
        # half-open test counts a vertex on the ray only once and skips
        # horizontal edges (which would otherwise divide by zero).
        if (y1 > py) != (y2 > py):
            x_cross = x1 + (py - y1) * (x2 - x1) / (y2 - y1)
            if x_cross > px:
                inside = not inside
    return inside
|
|
40
|
+
|
|
41
|
+
def distance_point_to_segment(px, py, x1, y1, x2, y2):
    """Return the Euclidean distance from (px, py) to segment (x1, y1)-(x2, y2).

    The point is projected onto the segment's supporting line and the
    projection parameter is clamped to [0, 1], so the nearest point always
    lies on the segment.  A zero-length segment reduces to point distance.
    """
    seg_dx = x2 - x1
    seg_dy = y2 - y1
    rel_x = px - x1
    rel_y = py - y1
    if seg_dx == 0 and seg_dy == 0:
        return (rel_x ** 2 + rel_y ** 2) ** 0.5

    ratio = (rel_x * seg_dx + rel_y * seg_dy) / (seg_dx * seg_dx + seg_dy * seg_dy)
    clamped = max(0, min(1, ratio))
    off_x = px - (x1 + clamped * seg_dx)
    off_y = py - (y1 + clamped * seg_dy)
    return (off_x ** 2 + off_y ** 2) ** 0.5
|
|
49
|
+
|
|
50
|
+
def deterministic_middle_point(vertices, grid_resolution=100):
    """Return a point well inside a polygon, chosen deterministically.

    A (grid_resolution + 1) x (grid_resolution + 1) lattice is laid over the
    polygon's bounding box.  Among the samples that ``is_point_inside_polygon``
    accepts, the one with the largest distance to its nearest edge is
    returned -- an approximation of the interior point farthest from the
    boundary.  Ties keep the first sample visited (x-major, then y).

    If no lattice sample is accepted (e.g. a very thin polygon), the average
    of the vertices is returned instead of None; it lies inside any convex
    polygon.  ``build`` evaluates constraints at the returned point, so a
    None result would fail there.
    """
    xs = [v[0] for v in vertices]
    ys = [v[1] for v in vertices]
    xmin, xmax = min(xs), max(xs)
    ymin, ymax = min(ys), max(ys)

    count = len(vertices)
    # The edge list does not depend on the sample, so build it once.
    edges = [
        (vertices[k][0], vertices[k][1],
         vertices[(k + 1) % count][0], vertices[(k + 1) % count][1])
        for k in range(count)
    ]

    best_point = None
    max_dist = -1
    for i in range(grid_resolution + 1):
        px = xmin + (xmax - xmin) * i / grid_resolution
        for j in range(grid_resolution + 1):
            py = ymin + (ymax - ymin) * j / grid_resolution
            if not is_point_inside_polygon((px, py), vertices):
                continue
            min_edge_dist = min(
                distance_point_to_segment(px, py, x1, y1, x2, y2)
                for x1, y1, x2, y2 in edges
            )
            # Strict '>' keeps the first best sample on ties.
            if min_edge_dist > max_dist:
                max_dist = min_edge_dist
                best_point = (px, py)

    if best_point is None:
        best_point = (sum(xs) / count, sum(ys) / count)
    return best_point
|
|
77
|
+
|
|
78
|
+
def build(eq):
    """Classify the cells of the line arrangement defined by 2-D constraints.

    ``eq`` is a list of comparison nodes (f_eq, f_lt, f_gt, f_le, f_ge) in
    the variables v_0 and v_1.  Each comparison is replaced by the equation
    of its two sides and the resulting lines are intersected with
    ``linear_or``.  Four extra lines v_0 = +-maxnum and v_1 = +-maxnum, where
    maxnum exceeds every |coordinate| of the original intersection points,
    close the arrangement.  ``linear_or`` is then run again on the enlarged
    system (``result2``).

    The enlarged arrangement is handled as a planar graph:

    * vertices are the points in ``result2[1]``;
    * edges join points adjacent in one of the lists of ``result2[2]``;
    * faces are the minimal cycles of that graph.

    Returns None when ``linear_or`` fails.  Otherwise returns
    ``(final, total, points)``:

    * ``final`` maps every original constraint to the triple
      ``(vertices, edges, faces)`` of cells on which it holds;
    * ``total`` is the same triple for the whole arrangement (vertices
      created only by the box are left out);
    * ``points`` is ``result2[1]``.  Vertices and edges are indices into it.
    """
    eq = flatten_tree(TreeNode("f_or", eq))
    orig = eq.copy_tree()

    def to_equality(node):
        # Replace every inequality by the equation of its boundary line so
        # that linear_or only sees lines.
        if node.name[2:] in ("le", "ge", "lt", "gt"):
            return TreeNode("f_eq", [child.copy_tree() for child in node.children])
        return TreeNode(node.name, [to_equality(child) for child in node.children])

    eq = to_equality(eq)

    result = linear_or(eq)
    if result is None:
        return None

    # Half-width of the bounding box: the largest |coordinate| among the
    # original intersection points plus one (2 when there are none).
    maxnum = tree_form("d_2")
    if len(result[1]) != 0:
        maxnum = max(
            (max((simplify(coord.fx("abs")) for coord in point), key=compute)
             for point in result[1]),
            key=compute,
        )
        maxnum += 1
        maxnum = simplify(maxnum)

    # Close the arrangement with v + maxnum = 0 and v - maxnum = 0 for
    # v in (v_0, v_1), then intersect again.
    box = TreeNode("f_or", [
        TreeNode("f_eq", [tree_form(var) + maxnum, tree_form("d_0")])
        | TreeNode("f_eq", [tree_form(var) - maxnum, tree_form("d_0")])
        for var in ["v_0", "v_1"]
    ])
    eq = flatten_tree(eq | simplify(box))
    result2 = linear_or(eq)
    if result2 is None:
        return None

    vertex_count = len(result2[1])
    # result2[2] appears to hold one list of point indices per line, in
    # order along that line (TODO confirm against linear_or).
    per_line = result2[2]

    def neighbours(vertex):
        # Indices adjacent to ``vertex`` in any of the per-line lists.
        found = set()
        for on_line in per_line:
            if vertex not in on_line:
                continue
            pos = on_line.index(vertex)
            if pos < len(on_line) - 1:
                found.add(on_line[pos + 1])
            if pos > 0:
                found.add(on_line[pos - 1])
        return sorted(found)

    graph = {vertex: neighbours(vertex) for vertex in range(vertex_count)}

    # Numeric coordinates of every vertex, ordered like the variable names
    # in result2[0]; evaluated once and reused below.
    points = {index: [compute(coord) for coord in point]
              for index, point in enumerate(result2[1])}

    # Vertices beyond the box (some |coordinate| > maxnum); no face inside
    # the box may pass through them.
    limit = compute(maxnum)
    outside = [
        index for index, point in enumerate(result2[1])
        if any(simplify(coord.fx("abs") - maxnum) != 0 and abs(compute(coord)) > limit
               for coord in point)
    ]

    def dfs(current, parent, path, visited, cycles):
        # Backtracking enumeration of simple cycles: reaching a vertex that is
        # already on ``path`` closes the cycle path[idx:].  Every simple path
        # is explored, so the cost can grow exponentially with the graph, and
        # the recursion depth grows with the path length.
        path.append(current)
        visited.add(current)
        for neighbor in graph[current]:
            if neighbor == parent:
                continue
            if neighbor in visited:
                cycles.append(path[path.index(neighbor):])
            else:
                dfs(neighbor, current, path, visited, cycles)
        path.pop()
        visited.remove(current)

    cycles = []
    for root in sorted(graph):
        dfs(root, -1, [], set(), cycles)

    def normalize(cycle):
        # Canonical key for a cycle: the smallest rotation of either
        # direction, so one cycle found from different starts collapses.
        # Cycles with fewer than 3 vertices are not polygons.
        k = len(cycle)
        if k < 3:
            return None
        candidates = []
        for direction in (cycle, cycle[::-1]):
            doubled = direction + direction[:-1]
            candidates.extend(tuple(doubled[i:i + k]) for i in range(k))
        return min(candidates)

    unique = {norm for norm in map(normalize, cycles) if norm}
    cycles = sorted(unique, key=lambda c: (len(c), c))

    all_vertices = set(range(vertex_count))

    def is_face(cycle):
        # Keep only cycles that avoid out-of-box vertices, enclose no other
        # vertex and have no chord (such cycles are presumably unions of
        # several faces).
        members = set(cycle)
        if any(v in cycle for v in outside):
            return False
        polygon = [points[v] for v in cycle]
        if any(is_point_inside_polygon(points[p], polygon) for p in all_vertices - members):
            return False
        return not any(len(set(graph[v]) & members) > 2 for v in cycle)

    cycles = [cycle for cycle in cycles if is_face(cycle)]

    # Vertices that are intersections of the original constraint lines,
    # i.e. not created only by the box lines.
    point_lst = [index for index, point in enumerate(result2[1]) if point in result[1]]

    # Every graph edge except those running along the box itself.
    border = []
    for u in range(vertex_count):
        for v in graph[u]:
            a = result2[1][u]
            b = result2[1][v]
            if a[0] == b[0] and simplify(a[0].fx("abs") - maxnum) == 0:
                continue
            if a[1] == b[1] and simplify(a[1].fx("abs") - maxnum) == 0:
                continue
            border.append(tuple(sorted([u, v])))

    # Edges touching an original intersection point, plus the border edges.
    line = []
    for key in graph:
        for item in list(set(point_lst) & set(graph[key])):
            line.append(tuple(sorted([item, key])))
    line = list(set(line + border))

    # One interior sample point per face, used to test the face as a whole.
    point_in = [deterministic_middle_point([points[v] for v in cycle]) for cycle in cycles]

    def work(node, point):
        # Numerically evaluate ``node`` at ``point`` (coordinates ordered as
        # the names in result2[0]).  Comparisons use an absolute tolerance of
        # 0.001: values that close count as equal.
        if node.name[:2] == "d_":
            return float(node.name[2:])
        if node.name in result2[0]:
            return point[result2[0].index(node.name)]
        if node.name == "f_add":
            return sum(work(child, point) for child in node.children)
        if node.name == "f_mul":
            return math.prod(work(child, point) for child in node.children)
        if node.name == "f_sub":
            return work(node.children[0], point) - work(node.children[1], point)
        compare = {
            "eq": lambda a, b: abs(a - b) < 0.001,
            "gt": lambda a, b: False if abs(a - b) < 0.001 else a > b,
            "lt": lambda a, b: False if abs(a - b) < 0.001 else a < b,
            "ge": lambda a, b: True if abs(a - b) < 0.001 else a > b,
            "le": lambda a, b: True if abs(a - b) < 0.001 else a < b,
        }[node.name[2:]]
        return compare(work(node.children[0], point), work(node.children[1], point))

    final = {}
    for constraint in orig.children:
        a = tuple(v for v in point_lst if work(constraint, points[v]))
        # An edge joins two consecutive points on one line, so it presumably
        # crosses no other line.
        # - Non-strict constraint: the closed edge satisfies it only if both
        #   endpoints do.
        # - Strict constraint: the open edge satisfies it if either endpoint
        #   does (an endpoint on the boundary line fails the strict test).
        if constraint.name in ("f_eq", "f_le", "f_ge"):
            b = tuple(tuple(edge) for edge in line
                      if work(constraint, points[edge[1]]) and work(constraint, points[edge[0]]))
        else:
            b = tuple(tuple(edge) for edge in line
                      if work(constraint, points[edge[1]]) or work(constraint, points[edge[0]]))
        c = tuple(tuple(cycle) for cycle, inner in zip(cycles, point_in) if work(constraint, inner))
        final[constraint] = (a, b, c)

    total = (tuple(point_lst), tuple(line), tuple(cycles))
    return final, total, result2[1]
|
|
237
|
+
|
|
238
|
+
def inequality_solve(eq):
    """Try to decide a boolean formula of 2-D linear (in)equalities.

    ``eq`` is first normalised with ``logic0``.  Every comparison in it that
    mentions a variable is handed to ``build``, which classifies the
    vertices, edges and faces of the resulting line arrangement.  The
    and/or/not structure of ``eq`` is then evaluated on those cells with
    set intersection, union and complement.

    Returns ``s_true`` if the formula holds on every cell and ``s_false`` if
    it holds on none.  Otherwise it returns the ``logic0``-normalised ``eq``
    unchanged; this includes the cases where ``build`` fails or a
    sub-formula cannot be interpreted.
    """
    eq = logic0(eq)

    element = []

    def collect(node):
        # Gather every comparison that mentions a variable (pre-order).
        if node.name[2:] in ("le", "ge", "lt", "gt", "eq") and "v_" in str_form(node):
            element.append(node)
        for child in node.children:
            collect(child)

    collect(eq)

    out = build(list(set(element)))
    if out is None:
        return eq

    final, total, _ = out

    class _Undecidable(Exception):
        # Raised when a sub-formula cannot be mapped onto the arrangement.
        pass

    def evaluate(node):
        # Return [vertices, edges, faces] (as sets) on which ``node`` holds.
        if node == tree_form("s_true"):
            return [set(part) for part in total]
        if node == tree_form("s_false"):
            return [set(), set(), set()]
        if node in final:
            return [set(part) for part in final[node]]
        if node.name in ("f_or", "f_and"):
            parts = [evaluate(child) for child in node.children]
            merge = set.union if node.name == "f_or" else set.intersection
            return [merge(*(part[k] for part in parts)) for k in range(3)]
        if node.name == "f_not":
            inner = evaluate(node.children[0])
            return [set(whole) - part for whole, part in zip(total, inner)]
        # Unrecognised shape: normalise it and retry.  dowhile presumably
        # re-applies the lambda to its own output until it stops changing,
        # so the lambda must use its argument ``x`` rather than ``node``.
        rewritten = dowhile(node, lambda x: logic0(expand(simplify(x))))
        if rewritten is None or rewritten == node:
            # Nothing changed, so retrying would recurse forever.
            raise _Undecidable
        return evaluate(rewritten)

    universe = [set(part) for part in total]
    if universe == [set(), set(), set()]:
        return eq

    try:
        holds = evaluate(eq)
    except _Undecidable:
        return eq

    if holds == universe:
        return tree_form("s_true")
    if holds == [set(), set(), set()]:
        return tree_form("s_false")
    return eq
|
mathai/diff.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
from .simplify import
|
|
1
|
+
from .simplify import simplify
|
|
2
2
|
from .base import *
|
|
3
3
|
from .trig import trig0
|
|
4
4
|
def diff(equation, var="v_0"):
|
|
5
5
|
def diffeq(eq):
|
|
6
|
-
eq =
|
|
6
|
+
eq = simplify(eq)
|
|
7
7
|
if "v_" not in str_form(eq):
|
|
8
8
|
return tree_form("d_0")
|
|
9
9
|
if eq.name == "f_add":
|
|
@@ -65,4 +65,4 @@ def diff(equation, var="v_0"):
|
|
|
65
65
|
return TreeNode(equation.name, [helper(child, var) for child in equation.children])
|
|
66
66
|
equation = diffeq(trig0(equation))
|
|
67
67
|
equation = helper(equation, var)
|
|
68
|
-
return
|
|
68
|
+
return simplify(equation)
|
mathai/expand.py
CHANGED
|
@@ -1,96 +1,95 @@
|
|
|
1
|
-
import itertools
|
|
2
1
|
from .base import *
|
|
3
|
-
from .simplify import
|
|
2
|
+
from .simplify import simplify
|
|
3
|
+
import itertools
|
|
4
|
+
|
|
5
|
+
def eliminate_powers(node):
    """Rewrite ``base ** n`` (n an integer greater than 1) as explicit products.

    The tree is processed bottom-up: every non-leaf node gets a freshly
    built ``children`` list holding the processed children.  The rewrite
    depends on the base:

    * A power whose base is an f_add is multiplied out term by term (one
      product per ordered n-tuple of addends) and then simplified.
    * Any other base becomes an f_mul holding n references to the same
      base node.

    Powers whose exponent ``frac`` does not turn into such an integer are
    returned unchanged.
    """
    if not node.children:
        return node

    node.children = list(map(eliminate_powers, node.children))

    if node.name != "f_pow":
        return node

    base, exponent = node.children
    power = frac(exponent)
    # Only positive integer exponents greater than one are expanded.
    if not power or power.denominator != 1 or power.numerator <= 1:
        return node

    count = power.numerator

    if base.name != "f_add":
        return TreeNode("f_mul", [base] * count)

    expanded_terms = []
    for choice in itertools.product(base.children, repeat=count):
        term = choice[0]
        for factor in choice[1:]:
            term = term * factor
        expanded_terms.append(term)
    return simplify(TreeNode("f_add", expanded_terms))
|
|
4
35
|
|
|
5
|
-
def expand(eq):
|
|
6
|
-
if eq is None:
|
|
7
|
-
return None
|
|
8
|
-
|
|
9
|
-
stack = [(eq, 0)] # (node, stage)
|
|
10
|
-
result_map = {} # id(node) -> expanded TreeNode
|
|
11
|
-
|
|
12
|
-
while stack:
|
|
13
|
-
node, stage = stack.pop()
|
|
14
|
-
node_id = id(node)
|
|
15
|
-
|
|
16
|
-
# Leaf node
|
|
17
|
-
if not node.children and stage == 0:
|
|
18
|
-
result_map[node_id] = TreeNode(node.name, [])
|
|
19
|
-
continue
|
|
20
|
-
|
|
21
|
-
if stage == 0:
|
|
22
|
-
# Stage 0: push node back for stage 1 after children
|
|
23
|
-
stack.append((node, 1))
|
|
24
|
-
# Push children to stack
|
|
25
|
-
for child in reversed(node.children):
|
|
26
|
-
if id(child) not in result_map:
|
|
27
|
-
stack.append((child, 0))
|
|
28
|
-
else:
|
|
29
|
-
# Stage 1: all children processed
|
|
30
|
-
children_expanded = [result_map[id(child)] for child in node.children]
|
|
31
|
-
|
|
32
|
-
# Only f_mul or f_pow need special expansion
|
|
33
|
-
if node.name in ["f_mul", "f_pow"]:
|
|
34
|
-
current_eq = TreeNode(node.name, children_expanded)
|
|
35
|
-
|
|
36
|
-
if node.name == "f_pow":
|
|
37
|
-
current_eq = TreeNode("f_pow", [current_eq])
|
|
38
|
-
|
|
39
|
-
ac = []
|
|
40
|
-
addchild = []
|
|
41
|
-
|
|
42
|
-
for child in current_eq.children:
|
|
43
|
-
tmp5 = [solve(x) for x in factor_generation(child)]
|
|
44
|
-
ac += tmp5
|
|
45
|
-
|
|
46
|
-
tmp3 = []
|
|
47
|
-
for child in ac:
|
|
48
|
-
tmp2 = []
|
|
49
|
-
if child.name == "f_add":
|
|
50
|
-
if child.children != []:
|
|
51
|
-
tmp2.extend(child.children)
|
|
52
|
-
else:
|
|
53
|
-
tmp2 = [child]
|
|
54
|
-
else:
|
|
55
|
-
tmp3.append(child)
|
|
56
|
-
if tmp2 != []:
|
|
57
|
-
addchild.append(tmp2)
|
|
58
|
-
|
|
59
|
-
tmp4 = 1
|
|
60
|
-
for item in tmp3:
|
|
61
|
-
tmp4 = tmp4 * item
|
|
62
|
-
addchild.append([tmp4])
|
|
63
|
-
|
|
64
|
-
def flatten(lst):
|
|
65
|
-
flat_list = []
|
|
66
|
-
for item in lst:
|
|
67
|
-
if isinstance(item, list) and item == []:
|
|
68
|
-
continue
|
|
69
|
-
if isinstance(item, list):
|
|
70
|
-
flat_list.extend(flatten(item))
|
|
71
|
-
else:
|
|
72
|
-
flat_list.append(item)
|
|
73
|
-
return flat_list
|
|
74
|
-
|
|
75
|
-
if len(flatten(addchild)) > 0:
|
|
76
|
-
add = 0
|
|
77
|
-
for prod_items in itertools.product(*addchild):
|
|
78
|
-
mul = 1
|
|
79
|
-
for item2 in prod_items:
|
|
80
|
-
mul = mul * item2
|
|
81
|
-
mul = simplify(mul)
|
|
82
|
-
add = add + mul
|
|
83
|
-
add = simplify(add)
|
|
84
|
-
current_eq = simplify(add)
|
|
85
|
-
else:
|
|
86
|
-
current_eq = simplify(current_eq)
|
|
87
|
-
|
|
88
|
-
# Store expanded result
|
|
89
|
-
result_map[node_id] = current_eq
|
|
90
|
-
else:
|
|
91
|
-
# Default: reconstruct node with children
|
|
92
|
-
result_map[node_id] = TreeNode(node.name, children_expanded)
|
|
93
|
-
|
|
94
|
-
# Return final expanded eq
|
|
95
|
-
return result_map[id(eq)]
|
|
96
36
|
|
|
37
|
+
|
|
38
|
+
# =====================================================
|
|
39
|
+
# Phase 2: Single distributive rewrite (DEEPEST FIRST)
|
|
40
|
+
# =====================================================
|
|
41
|
+
|
|
42
|
+
def expand_once(node):
    """Apply a single distributive rewrite ``x * (a + b) -> x*a + x*b``.

    Children are searched before ``node`` itself, so the deepest rewritable
    product is handled first.  The result is one of three cases:

    * A child subtree changed: it is stored back into ``node.children`` in
      place and ``(node, True)`` is returned at once.
    * ``node`` is an f_mul with an f_add factor: the first such factor is
      distributed and ``(new_f_add_node, True)`` is returned.  The other
      factors keep their left/right order, which matters when
      multiplication is non-commutative (the matmul mode of ``expand``).
    * Nothing applies: ``(node, False)`` is returned.
    """
    for position, child in enumerate(node.children):
        replacement, did_change = expand_once(child)
        if did_change:
            node.children[position] = replacement
            return node, True

    if node.name != "f_mul":
        return node, False

    factors = node.children
    for position, factor in enumerate(factors):
        if factor.name != "f_add":
            continue
        before = factors[:position]
        after = factors[position + 1:]
        distributed = []
        for addend in factor.children:
            piece = addend
            for right_factor in after:
                piece = piece * right_factor
            for left_factor in reversed(before):
                piece = left_factor * piece
            distributed.append(piece)
        return TreeNode("f_add", distributed), True

    return node, False
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# =====================================================
|
|
77
|
+
# Phase 3: Global fixed-point driver
|
|
78
|
+
# =====================================================
|
|
79
|
+
|
|
80
|
+
def expand(eq):
    """Fully distribute products over sums in ``eq``.

    The steps are:

    1. Simplify ``eq``.
    2. If ``TreeNode.matmul`` is not None, set it to True and rewrite every
       f_wmul as f_mul.
    3. Flatten and expand positive integer powers (``eliminate_powers``).
    4. Apply single distributive rewrites (``expand_once``) until none
       applies.
    5. Simplify the result.

    ``TreeNode.matmul`` is class-level state, so its previous value is
    restored in a ``finally`` block.  Without that, an exception raised by
    any step would leave the flag changed for every later caller.
    """
    saved_matmul = TreeNode.matmul
    try:
        eq = simplify(eq)
        if TreeNode.matmul is not None:
            TreeNode.matmul = True
            eq = tree_form(str_form(eq).replace("f_wmul", "f_mul"))
        eq = flatten_tree(eq)
        eq = eliminate_powers(eq)
        while True:
            eq = flatten_tree(eq)
            eq, changed = expand_once(eq)
            if not changed:
                break
        # Simplify while the matmul flag is still in effect, as before.
        return simplify(eq)
    finally:
        TreeNode.matmul = saved_matmul
|