mathai 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mathai/__init__.py +53 -0
- mathai/apart.py +142 -0
- mathai/base.py +419 -0
- mathai/bivariate_inequality.py +317 -0
- mathai/console.py +84 -0
- mathai/diff.py +68 -0
- mathai/expand.py +124 -0
- mathai/factor.py +304 -0
- mathai/fraction.py +103 -0
- mathai/integrate.py +459 -0
- mathai/inverse.py +65 -0
- mathai/limit.py +156 -0
- mathai/linear.py +165 -0
- mathai/logic.py +230 -0
- mathai/matrix.py +22 -0
- mathai/ode.py +124 -0
- mathai/parser.py +158 -0
- mathai/printeq.py +34 -0
- mathai/simplify.py +521 -0
- mathai/structure.py +103 -0
- mathai/tool.py +163 -0
- mathai/trig.py +276 -0
- mathai/univariate_inequality.py +458 -0
- mathai-0.6.0.dist-info/METADATA +234 -0
- mathai-0.6.0.dist-info/RECORD +27 -0
- mathai-0.6.0.dist-info/WHEEL +5 -0
- mathai-0.6.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from .linear import linear_or
|
|
3
|
+
from functools import reduce
|
|
4
|
+
import operator
|
|
5
|
+
from .base import *
|
|
6
|
+
from .simplify import simplify
|
|
7
|
+
from .expand import expand
|
|
8
|
+
from .logic import logic0
|
|
9
|
+
|
|
10
|
+
def shoelace_area(vertices):
    """Return the unsigned area of a polygon using the shoelace formula.

    Args:
        vertices: sequence of points; element ``[0]`` of each point is used
            as x and element ``[1]`` as y.  Consecutive entries (and the
            last/first pair) are treated as edges.

    Returns:
        The absolute area as a float (``0.0`` for an empty sequence).
    """
    count = len(vertices)
    twice_signed = 0.0
    for idx, here in enumerate(vertices):
        # Wrap around so the last vertex is paired with the first one.
        following = vertices[(idx + 1) % count]
        twice_signed += here[0] * following[1]
        twice_signed -= following[0] * here[1]
    return abs(twice_signed) / 2.0
|
|
18
|
+
|
|
19
|
+
def triangle_area(p1, p2, p3):
    """Return the unsigned area of the triangle with corners p1, p2, p3.

    Each corner is indexed with ``[0]`` (x) and ``[1]`` (y).  The result is
    a float; collinear corners give ``0.0``.
    """
    corners = (p1, p2, p3)
    twice_signed = 0.0
    for k in range(3):
        # Cyclic rotation of the corners: (p1,p2,p3), (p2,p3,p1), (p3,p1,p2).
        first = corners[k]
        second = corners[(k + 1) % 3]
        third = corners[(k + 2) % 3]
        twice_signed += first[0] * (second[1] - third[1])
    return abs(twice_signed) / 2.0
|
|
25
|
+
|
|
26
|
+
def is_point_inside_polygon(point, vertices):
    """Test whether ``point`` lies in the polygon by comparing areas.

    The triangles spanned by ``point`` and every polygon edge are summed and
    compared with the polygon's shoelace area; the point is reported inside
    when the two areas differ by less than ``1e-5``.

    Args:
        point: indexable ``(x, y)``.
        vertices: sequence of polygon corners in order.

    Returns:
        ``False`` when fewer than three vertices are given, otherwise the
        result of the area comparison.
    """
    corner_count = len(vertices)
    if corner_count < 3:
        return False

    outline_area = shoelace_area(vertices)

    # Keep a plain running sum so the rounding matches a sequential add.
    fan_area = 0.0
    for idx in range(corner_count):
        following = (idx + 1) % corner_count
        fan_area += triangle_area(point, vertices[idx], vertices[following])

    tolerance = 1e-5
    return abs(fan_area - outline_area) < tolerance
|
|
40
|
+
|
|
41
|
+
def distance_point_to_segment(px, py, x1, y1, x2, y2):
    """Return the Euclidean distance from (px, py) to segment (x1, y1)-(x2, y2).

    The point is projected onto the segment's supporting line and the
    projection parameter is clamped to ``[0, 1]`` so that the nearest point
    always lies on the segment.  A zero-length segment is treated as the
    single point (x1, y1).
    """
    def gap_to(target_x, target_y):
        # Straight-line distance from the query point to a target point.
        return ((px - target_x)**2 + (py - target_y)**2)**0.5

    dx = x2 - x1
    dy = y2 - y1
    if dx == dy == 0:
        return gap_to(x1, y1)

    along = ((px - x1) * dx + (py - y1) * dy) / (dx*dx + dy*dy)
    capped = min(1, along)
    t = max(0, capped)
    return gap_to(x1 + t * dx, y1 + t * dy)
|
|
49
|
+
|
|
50
|
+
def deterministic_middle_point(vertices, grid_resolution=100):
    """Pick an interior point of a polygon by scanning a regular grid.

    A ``(grid_resolution + 1)`` x ``(grid_resolution + 1)`` grid is laid over
    the bounding box of ``vertices``.  Among the grid points accepted by
    ``is_point_inside_polygon`` the one whose smallest distance to any edge
    is largest is returned; on ties the first one found wins.

    Args:
        vertices: sequence of ``(x, y)`` pairs in polygon order.
        grid_resolution: number of grid intervals per axis.

    Returns:
        The chosen ``(x, y)`` tuple, or ``None`` when no grid point is
        accepted.
    """
    x_coords = [vertex[0] for vertex in vertices]
    y_coords = [vertex[1] for vertex in vertices]
    x_low, x_high = min(x_coords), max(x_coords)
    y_low, y_high = min(y_coords), max(y_coords)
    edge_count = len(vertices)
    steps = range(grid_resolution + 1)

    winner = None
    winner_clearance = -1

    for col in steps:
        px = x_low + (x_high - x_low) * col / grid_resolution
        for row in steps:
            py = y_low + (y_high - y_low) * row / grid_resolution
            if not is_point_inside_polygon((px, py), vertices):
                continue

            # Clearance = distance to the nearest polygon edge.
            clearance = float('inf')
            for k in range(edge_count):
                ax, ay = vertices[k]
                bx, by = vertices[(k + 1) % edge_count]
                clearance = min(clearance, distance_point_to_segment(px, py, ax, ay, bx, by))

            if clearance > winner_clearance:
                winner_clearance = clearance
                winner = (px, py)

    return winner
|
|
77
|
+
|
|
78
|
+
def build(eq):
    """Build the point/segment/region decomposition used by ``inequality_solve``.

    The relations in ``eq`` are OR-ed together, every inequality is replaced
    by an equation with the same two sides, and ``linear_or`` is run on the
    result.  Four extra equations ``v_0 = +/-maxnum`` and ``v_1 = +/-maxnum``
    are then added (presumably a bounding box that closes otherwise
    unbounded regions) and ``linear_or`` is run again.

    NOTE(review): the layout of the value returned by ``linear_or`` is not
    visible here.  From the way it is indexed below it appears to be
    ``(variable names, points, lines)`` where each line is a list of point
    indices -- confirm against ``linear.py``.

    Args:
        eq: list of relation nodes; they become the children of an ``f_or``.

    Returns:
        ``None`` when either ``linear_or`` call returns ``None``.  Otherwise
        a tuple ``(final, total, points)``: ``final`` maps each child of the
        flattened OR to ``(point indices, segments, cycles)`` for which that
        relation evaluates true, ``total`` is ``(all point indices, all
        segments, all cycles)`` and ``points`` is ``result2[1]``.
    """
    eq = TreeNode("f_or", eq)
    eq = flatten_tree(eq)
    # Keep the untouched relations; they are evaluated numerically later.
    orig = eq.copy_tree()

    def fxhelper3(eq):
        # Turn every le/ge/lt/gt node into an f_eq node with the same children.
        if eq.name[2:] in "le ge lt gt".split(" "):
            return TreeNode("f_eq", [child.copy_tree() for child in eq.children])
        return TreeNode(eq.name, [fxhelper3(child) for child in eq.children])
    eq = fxhelper3(eq)

    result = linear_or(eq)

    if result is None:
        return None

    # maxnum: one more than the largest absolute coordinate among result[1].
    maxnum = tree_form("d_2")
    if len(result[1]) != 0:
        maxnum = max([max([simplify(item2.fx("abs")) for item2 in item], key=lambda x: compute(x)) for item in result[1]], key=lambda x: compute(x))
    # NOTE(review): nesting of the next statement relative to the `if` above
    # could not be recovered from the stripped source; it only matters when
    # result[1] is empty -- confirm against the package.
    maxnum += 1
    maxnum = simplify(maxnum)
    eq = flatten_tree(eq | simplify(TreeNode("f_or", [TreeNode("f_eq", [tree_form(item)+maxnum, tree_form("d_0")])|\
                                                      TreeNode("f_eq", [tree_form(item)-maxnum, tree_form("d_0")]) for item in ["v_0","v_1"]])))
    result2 = linear_or(eq)
    if result2 is None:
        return None

    point_lst = result2[2]

    def gen(point):
        # Neighbours of `point`: the entries directly before/after it in
        # every list of point_lst that contains it.
        nonlocal point_lst
        out = []
        for item in point_lst:
            p = None
            if point in item:
                p = item.index(point)
            else:
                continue
            if p < len(item)-1:
                out.append(item[p+1])
            if p > 0:
                out.append(item[p-1])
        return list(set(out))
    start = list(range(len(result2[1])))
    graph= {}
    for item in start:
        graph[item] = gen(item)

    # Numeric coordinates per point index (not read again in this function).
    points = {}
    for index, item in enumerate(result2[1]):
        points[index] = [compute(item2) for item2 in item]

    # res: points with a coordinate strictly beyond +/-maxnum.
    res = []
    for index, item in enumerate(result2[1]):
        if any(simplify(item2.fx("abs")-maxnum)!=0 and abs(compute(item2))>compute(maxnum) for item2 in item):
            res.append(index)

    graph = {k: sorted(v) for k, v in graph.items()}

    def dfs(current, parent, path, visited, cycles):
        # Depth-first walk that records a cycle whenever the walk meets a
        # vertex that is already on the current path.
        path.append(current)
        visited.add(current)
        for neighbor in graph[current]:
            if neighbor == parent:
                continue
            if neighbor in visited:
                idx = path.index(neighbor)
                cycle = path[idx:]
                cycles.append(cycle)
            else:
                dfs(neighbor, current, path, visited, cycles)
        path.pop()
        visited.remove(current)

    cycles = []
    for start in sorted(graph.keys()):
        path = []
        visited = set()
        dfs(start, -1, path, visited, cycles)

    def normalize(cycle):
        # Canonical form of a cycle: the smallest rotation over both
        # directions, so duplicates found from different starts collapse.
        k = len(cycle)
        if k < 3:
            return None
        candidates = []
        for direction in [cycle, list(reversed(cycle))]:
            doubled = direction + direction[:-1]
            for i in range(k):
                rot = tuple(doubled[i:i + k])
                candidates.append(rot)
        return min(candidates)

    unique = set()
    for c in cycles:
        norm = normalize(c)
        if norm:
            unique.add(norm)

    cycles = sorted(list(unique), key=lambda x: (len(x), x))

    # Drop cycles that touch a point in res, contain another point, or have
    # a vertex with more than two neighbours inside the same cycle.
    start = list(range(len(result2[1])))
    for i in range(len(cycles)-1,-1,-1):
        if any(item in cycles[i] for item in res) or\
           any(is_point_inside_polygon([compute(item2) for item2 in list(result2[1][p])], [[compute(item2) for item2 in result2[1][item]] for item in cycles[i]]) for p in list(set(start) - set(cycles[i]))) or\
           any(len(set(graph[item]) & set(cycles[i]))>2 for item in cycles[i]):
            cycles.pop(i)

    # From here on point_lst holds the indices of the points that were
    # already present before the +/-maxnum equations were added.
    point_lst = [index for index, item in enumerate(result2[1]) if item in result[1]]

    border = []
    for item in start:
        for item2 in graph[item]:
            a = result2[1][item]
            b = result2[1][item2]

            # Skip edges lying on one of the four +/-maxnum lines.
            if a[0] == b[0] and simplify(a[0].fx("abs") - maxnum) == 0:
                continue
            if a[1] == b[1] and simplify(a[1].fx("abs") - maxnum) == 0:
                continue

            border.append(tuple(sorted([item, item2])))

    line = []
    for key in graph.keys():
        for item in list(set(point_lst)&set(graph[key])):
            line.append(tuple(sorted([item, key])))
    line = list(set(line+border))
    # One sample point per surviving cycle, used to evaluate relations on it.
    point_in = [deterministic_middle_point([[compute(item3) for item3 in result2[1][item2]] for item2 in item]) for item in cycles]
    def work(eq, point):
        # Numerically evaluate `eq` at `point`; relations use a 0.001 tolerance.
        nonlocal result2
        if eq.name[:2] == "d_":
            return float(eq.name[2:])
        if eq.name in result2[0]:
            return point[result2[0].index(eq.name)]
        if eq.name == "f_add":
            return sum(work(item, point) for item in eq.children)
        if eq.name == "f_mul":
            return math.prod(work(item, point) for item in eq.children)
        if eq.name == "f_sub":
            return work(eq.children[0], point) - work(eq.children[1], point)
        return {"eq": lambda a,b: abs(a-b)<0.001, "gt":lambda a,b: False if abs(a-b)<0.001 else a>b, "lt":lambda a,b: False if abs(a-b)<0.001 else a<b}[eq.name[2:]](work(eq.children[0], point), work(eq.children[1], point))

    data = []
    # The last four entries of result2[2] are skipped -- presumably the four
    # +/-maxnum lines appended above (TODO confirm linear_or keeps order).
    for index, item in enumerate(result2[2][:-4]):
        a = tuple([item for item in point_lst if work(orig.children[index], [compute(item2) for item2 in result2[1][item]])])
        #a = tuple(set(item) & set(point_lst))
        #b = tuple(set([tuple(sorted([item[i], item[i+1]])) for i in range(len(item)-1)]) & set(line))
        b = None
        # Compare the node *name*; comparing the node object itself with the
        # string "f_eq" presumably never selected this branch.  For an
        # equation a segment counts only when both of its endpoints satisfy
        # it.
        if orig.children[index].name == "f_eq":
            b = tuple([tuple(item) for item in line if work(orig.children[index], [compute(item2) for item2 in result2[1][item[1]]]) and work(orig.children[index], [compute(item2) for item2 in result2[1][item[0]]])])
        else:
            b = tuple([tuple(item) for item in line if work(orig.children[index], [compute(item2) for item2 in result2[1][item[1]]]) or work(orig.children[index], [compute(item2) for item2 in result2[1][item[0]]])])
        c = tuple([tuple(item) for index2, item in enumerate(cycles) if work(orig.children[index], point_in[index2])])
        data.append((a,b,c))

    total = tuple([tuple(point_lst), tuple(line), tuple(cycles)])
    final = {}
    for index, item in enumerate(orig.children):
        final[item] = tuple(data[index])
    return final, total, result2[1]
|
|
237
|
+
|
|
238
|
+
def inequality_solve(eq):
    """Try to reduce a logical combination of relations to true or false.

    Every relation (``le ge lt gt eq``) that mentions a variable is collected
    and handed to ``build``.  Each relation is then represented by three sets
    (points, segments, regions) and the logical structure of ``eq`` is
    evaluated with set union (``f_or``), intersection (``f_and``) and
    complement against the full sets (``f_not``).

    Args:
        eq: expression tree; it is first normalised with ``logic0``.

    Returns:
        ``s_true`` when the resulting sets equal the full sets, ``s_false``
        when all three are empty, and otherwise the ``logic0``-normalised
        input unchanged (also when ``build`` returns ``None`` or the full
        sets are all empty).
    """

    eq = logic0(eq)
    element = []
    def helper(eq):
        # Collect every relation node that mentions a variable.
        nonlocal element

        if eq.name[2:] in "le ge lt gt eq".split(" ") and "v_" in str_form(eq):
            element.append(eq)
        return TreeNode(eq.name, [helper(child) for child in eq.children])
    helper(eq)

    out = build(list(set(element)))

    if out is None:
        return eq

    def merge(parts, combine):
        # Combine the three sets of all operands slot by slot.  Raises when
        # `parts` is empty because `combine` then receives no operands.
        return [combine(*(part[slot] for part in parts)) for slot in range(3)]

    def helper2(eq):
        # Map a logical expression to [points, segments, regions] sets.
        nonlocal out
        if eq == tree_form("s_true"):
            return [set(item) for item in out[1]]
        if eq == tree_form("s_false"):
            return [set(), set(), set()]
        if eq in out[0].keys():
            return [set(item) for item in out[0][eq]]
        if eq.name == "f_or":
            return merge([helper2(child) for child in eq.children], set.union)
        if eq.name == "f_and":
            return merge([helper2(child) for child in eq.children], set.intersection)
        if eq.name == "f_not":
            a, b, c = helper2(eq.children[0])
            d, e, f = [set(item) for item in out[1]]
            return [d-a, e-b, f-c]
        # Unknown shape: normalise and retry.  The callback must transform
        # its own argument; using the enclosing `eq` instead would hand
        # `dowhile` the same result on every call.
        return helper2(dowhile(eq, lambda x: logic0(expand(simplify(x)))))
    out2 = helper2(eq)

    out = list(out)
    out[1] = [set(item) for item in out[1]]
    if tuple(out[1]) == (set(), set(), set()):
        return eq
    if tuple(out[1]) == tuple(out2):
        return tree_form("s_true")
    if tuple(out2) == (set(), set(), set()):
        return tree_form("s_false")
    return eq
|
mathai/console.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
from .expand import expand
|
|
3
|
+
from .parser import parse
|
|
4
|
+
from .printeq import printeq, printeq_log
|
|
5
|
+
from .simplify import solve, simplify
|
|
6
|
+
|
|
7
|
+
from .diff import diff
|
|
8
|
+
from .base import *
|
|
9
|
+
from .factor import _factorconst, factor
|
|
10
|
+
from .fraction import fraction
|
|
11
|
+
from .inverse import inverse
|
|
12
|
+
from .trig import trig0, trig1, trig2, trig3, trig4
|
|
13
|
+
from .logic import logic0, logic1, logic2, logic3
|
|
14
|
+
from .apart import apart
|
|
15
|
+
|
|
16
|
+
def console():
    """Run an interactive read-evaluate-print loop on standard input.

    Each line is either a command that transforms the current expression
    (``expand``, ``simplify``, ``factor``, ``fraction``, ``apart``,
    ``logic0``-``logic3``, ``trig0``-``trig4``, ``inverse <var>``,
    ``diff <var>``, ``integrate|sqint|byparts <var>``, ``rawprint``) or, if
    it matches none of them, text that is parsed as a new expression.  After
    every successful step the current expression is printed with
    ``printeq``.  When a step raises, the previous expression is restored
    and ``error`` is printed.

    The loop never returns on its own; it ends when ``input`` raises (for
    example on end of input) or on ``KeyboardInterrupt``/``SystemExit``.
    """
    eq = None
    orig = None
    while True:
        command = input(">>> ")
        try:
            # Snapshot used to roll back when the command fails.
            orig = copy.deepcopy(eq)
            words = command.split(" ")
            verb = words[0]
            if command == "expand":
                eq = expand(eq)
            elif verb == "inverse":
                eq=simplify(eq)
                if eq.name == "f_eq":
                    eq3 = eq.children[0]-eq.children[1]
                    eq2 = parse(words[1])
                    out = inverse(eq3, str_form(eq2))
                    eq = TreeNode(eq.name, [eq2,out])
            elif command == "apart":
                eq = apart(eq, vlist(eq)[0])
            elif command == "rawprint":
                print(eq)
            elif command == "logic0":
                eq = logic0(eq)
            elif command == "logic1":
                eq = logic1(eq)
            elif command == "logic2":
                eq = logic2(eq)
            elif command == "logic3":
                eq = logic3(eq)
            elif command == "trig0":
                eq = trig0(eq)
            elif command == "trig1":
                eq = trig1(eq)
            elif command == "factor":
                eq = factor(eq)
            elif command == "trig2":
                eq = trig2(eq)
            elif command == "trig3":
                eq = trig3(eq)
            elif command == "trig4":
                eq = trig4(eq)
            elif command == "simplify":
                eq = _factorconst(eq)
                eq = simplify(eq)
            elif command == "fraction":
                eq = fraction(eq)
            elif verb in ["integrate", "sqint", "byparts"]:
                # NOTE(review): `integrate`, `typesqint`, `typebyparts` and
                # `typeintegrate` are not imported explicitly by this module;
                # unless `from .base import *` supplies them this branch
                # raises NameError and is reported as "error" -- confirm.
                if verb == "sqint":
                    typesqint()
                elif verb == "byparts":
                    typebyparts()
                elif verb == "integrate":
                    typeintegrate()
                out = integrate(eq, parse(words[1]).name)
                if out is None:
                    print("failed to integrate")
                else:
                    eq, logs = out
                    eq = simplify(eq)
                    printeq_log(logs)
                    print()
            elif verb == "diff":
                eq = diff(eq, parse(words[1]).name)
            else:
                eq = parse(command)
                eq = copy.deepcopy(eq)
            printeq(eq)
        except Exception:
            # Only ordinary errors are swallowed; KeyboardInterrupt and
            # SystemExit propagate, exactly as they already do while the
            # loop is waiting in input().
            eq = copy.deepcopy(orig)
            print("error")
|
mathai/diff.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from .simplify import simplify
|
|
2
|
+
from .base import *
|
|
3
|
+
from .trig import trig0
|
|
4
|
+
def diff(equation, var="v_0"):
    """Differentiate ``equation`` with respect to the variable named ``var``.

    The expression is first passed through ``trig0``.  ``diffeq`` applies the
    differentiation rules and leaves an ``f_dif`` node for every variable it
    meets; ``resolve`` then turns ``f_dif(var)`` into 1 and ``f_dif`` of
    anything whose string form does not contain ``var`` into 0.  The result
    is simplified before it is returned.
    """
    def diffeq(eq):
        eq = simplify(eq)
        # An expression without any variable is a constant.
        if "v_" not in str_form(eq):
            return tree_form("d_0")

        if eq.name == "f_add":
            # Sum rule.
            total = tree_form("d_0")
            for term in eq.children:
                total += diffeq(term)
            return total

        if eq.name == "f_abs":
            inner = eq.children[0]
            return diffeq(inner)*inner/eq

        if eq.name == "f_pow" and eq.children[0].name == "s_e":
            # e**u -> u' * e**u
            return diffeq(eq.children[1])*eq

        if eq.name == "f_tan":
            inner = eq.children[0]
            return diffeq(inner)/(inner.fx("cos")*inner.fx("cos"))

        if eq.name == "f_log":
            inner = eq.children[0]
            return diffeq(inner)*(tree_form("d_1")/inner)

        if eq.name == "f_arcsin":
            inner = eq.children[0]
            return diffeq(inner)/(tree_form("d_1")-inner*inner)**(tree_form("d_2")**-1)

        if eq.name == "f_arccos":
            inner = eq.children[0]
            return tree_form("d_-1")*diffeq(inner)/(tree_form("d_1")-inner*inner)**(tree_form("d_2")**-1)

        if eq.name == "f_arctan":
            inner = eq.children[0]
            return diffeq(inner)/(tree_form("d_1")+inner*inner)

        if eq.name == "f_pow" and "v_" in str_form(eq.children[1]):
            # Exponent depends on a variable: general power rule.
            a, b = eq.children
            return a**b * ((b/a) * diffeq(a) + a.fx("log") * diffeq(b))

        if eq.name == "f_mul":
            # Product rule: take each factor out in turn, differentiate it
            # and multiply by what is left, then put the factor back.
            total = tree_form("d_0")
            for position in range(len(eq.children)):
                taken = eq.children.pop(position)
                if len(eq.children)==1:
                    remainder = eq.children[0]
                else:
                    remainder = eq
                total += diffeq(taken)*remainder
                eq.children.insert(position, taken)
            return total

        if eq.name == "f_sin":
            # The node is renamed in place and reused as the cosine factor.
            eq.name = "f_cos"
            return diffeq(eq.children[0])*eq

        if eq.name == "f_cos":
            # The node is renamed in place and reused as the sine factor.
            eq.name = "f_sin"
            return tree_form("d_-1")*diffeq(eq.children[0])*eq

        if eq.name[:2] == "v_":
            # Placeholder; resolved to 0 or 1 by `resolve` below.
            return TreeNode("f_dif", [eq])

        if eq.name == "f_pow" and "v_" not in str_form(eq.children[1]):
            # Variable-free exponent: n * u**(n-1) * u'
            base, power = eq.children
            dbase = diffeq(base)
            lowered = power - tree_form("d_1")
            reduced = TreeNode("f_pow", [base, lowered])
            return power * reduced * dbase

        # No rule applies: keep an unevaluated derivative.
        return TreeNode("f_dif", [eq, tree_form(var)])

    def resolve(node, var="v_0"):
        if node.name == "f_dif":
            target = node.children[0]
            if target.name == var:
                return tree_form("d_1")
            if var not in str_form(target):
                return tree_form("d_0")
            return node
        return TreeNode(node.name, [resolve(child, var) for child in node.children])

    derived = diffeq(trig0(equation))
    derived = resolve(derived, var)
    return simplify(derived)
|
mathai/expand.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
from .base import *
|
|
3
|
+
from .simplify import simplify
|
|
4
|
+
'''
|
|
5
|
+
def _expand(equation):
|
|
6
|
+
eq = equation
|
|
7
|
+
eq.children = [_expand(flatten_tree(child)) for child in eq.children]
|
|
8
|
+
if eq.name == "f_pow":
|
|
9
|
+
n = frac(eq.children[1])
|
|
10
|
+
if n is not None and n.denominator == 1 and n.numerator > 1:
|
|
11
|
+
power_children = []
|
|
12
|
+
for i in range(n.numerator):
|
|
13
|
+
power_children.append(eq.children[0])
|
|
14
|
+
return _expand(flatten_tree(TreeNode("f_mul", power_children)))
|
|
15
|
+
if eq.name == "f_mul":
|
|
16
|
+
lone_children = tree_form("d_1")
|
|
17
|
+
bracket_children = []
|
|
18
|
+
for i in range(len(eq.children)-1,-1,-1):
|
|
19
|
+
if eq.children[i].name == "f_add":
|
|
20
|
+
bracket_children.append(eq.children[i])
|
|
21
|
+
elif eq.children[i].name == "f_pow" and eq.children[i].children[0].name == "f_add":
|
|
22
|
+
n = frac(eq.children[i].children[1])
|
|
23
|
+
if n is not None and n.denominator == 1 and n.numerator > 1:
|
|
24
|
+
for j in range(n.numerator):
|
|
25
|
+
bracket_children.append(eq.children[i].children[0])
|
|
26
|
+
else:
|
|
27
|
+
lone_children = lone_children * eq.children[i]
|
|
28
|
+
else:
|
|
29
|
+
lone_children = lone_children * eq.children[i]
|
|
30
|
+
lone_children = simplify(lone_children)
|
|
31
|
+
while bracket_children != []:
|
|
32
|
+
tmp = tree_form("d_0")
|
|
33
|
+
for i in range(len(bracket_children[0].children)):
|
|
34
|
+
if lone_children.name == "f_add":
|
|
35
|
+
for j in range(len(lone_children.children)):
|
|
36
|
+
tmp = tmp + bracket_children[0].children[i] * lone_children.children[j]
|
|
37
|
+
else:
|
|
38
|
+
tmp = tmp + lone_children * bracket_children[0].children[i]
|
|
39
|
+
lone_children = flatten_tree(simplify(tmp))
|
|
40
|
+
bracket_children.pop(0)
|
|
41
|
+
return lone_children
|
|
42
|
+
return eq
|
|
43
|
+
'''
|
|
44
|
+
def _expand(equation):
|
|
45
|
+
"""Iterative version of _expand without recursion."""
|
|
46
|
+
# Stack: (node, child_index, partially_processed_children)
|
|
47
|
+
stack = [(equation, 0, [])]
|
|
48
|
+
|
|
49
|
+
while stack:
|
|
50
|
+
node, child_index, processed_children = stack.pop()
|
|
51
|
+
|
|
52
|
+
# If all children are processed
|
|
53
|
+
if child_index >= len(node.children):
|
|
54
|
+
# Replace children with processed versions
|
|
55
|
+
node.children = processed_children
|
|
56
|
+
|
|
57
|
+
# === Handle f_pow ===
|
|
58
|
+
if node.name == "f_pow":
|
|
59
|
+
n = frac(node.children[1])
|
|
60
|
+
if n is not None and n.denominator == 1 and n.numerator > 1:
|
|
61
|
+
# Convert power to repeated multiplication
|
|
62
|
+
power_children = [node.children[0] for _ in range(n.numerator)]
|
|
63
|
+
new_node = TreeNode("f_mul", power_children)
|
|
64
|
+
# Flatten tree
|
|
65
|
+
node = flatten_tree(new_node)
|
|
66
|
+
# Push it back for further processing
|
|
67
|
+
stack.append((node, 0, []))
|
|
68
|
+
continue
|
|
69
|
+
|
|
70
|
+
# === Handle f_mul ===
|
|
71
|
+
elif node.name == "f_mul":
|
|
72
|
+
# Separate lone children and bracket children
|
|
73
|
+
lone_children = tree_form("d_1")
|
|
74
|
+
bracket_children = []
|
|
75
|
+
|
|
76
|
+
# Iterate in reverse (like original)
|
|
77
|
+
for child in reversed(node.children):
|
|
78
|
+
if child.name == "f_add":
|
|
79
|
+
bracket_children.append(child)
|
|
80
|
+
elif child.name == "f_pow" and child.children[0].name == "f_add":
|
|
81
|
+
n = frac(child.children[1])
|
|
82
|
+
if n is not None and n.denominator == 1 and n.numerator > 1:
|
|
83
|
+
for _ in range(n.numerator):
|
|
84
|
+
bracket_children.append(child.children[0])
|
|
85
|
+
else:
|
|
86
|
+
lone_children = lone_children * child
|
|
87
|
+
else:
|
|
88
|
+
lone_children = lone_children * child
|
|
89
|
+
|
|
90
|
+
lone_children = simplify(lone_children)
|
|
91
|
+
|
|
92
|
+
# Distribute bracket children over lone children iteratively
|
|
93
|
+
while bracket_children:
|
|
94
|
+
tmp = tree_form("d_0")
|
|
95
|
+
bracket = bracket_children.pop(0)
|
|
96
|
+
for bc in bracket.children:
|
|
97
|
+
if lone_children.name == "f_add":
|
|
98
|
+
for lc in lone_children.children:
|
|
99
|
+
tmp = tmp + bc * lc
|
|
100
|
+
else:
|
|
101
|
+
tmp = tmp + bc * lone_children
|
|
102
|
+
# Simplify after each distribution
|
|
103
|
+
lone_children = flatten_tree(simplify(tmp))
|
|
104
|
+
|
|
105
|
+
node = lone_children
|
|
106
|
+
|
|
107
|
+
# === Return node to parent ===
|
|
108
|
+
if stack:
|
|
109
|
+
parent, idx, parent_children = stack.pop()
|
|
110
|
+
parent_children.append(node)
|
|
111
|
+
stack.append((parent, idx + 1, parent_children))
|
|
112
|
+
else:
|
|
113
|
+
# Root node fully expanded
|
|
114
|
+
return node
|
|
115
|
+
|
|
116
|
+
else:
|
|
117
|
+
# Push current node back for next child
|
|
118
|
+
stack.append((node, child_index, processed_children))
|
|
119
|
+
# Push the child to process next
|
|
120
|
+
child = flatten_tree(node.children[child_index])
|
|
121
|
+
stack.append((child, 0, []))
|
|
122
|
+
|
|
123
|
+
def expand(eq):
    """Public entry point: return the result of ``_expand`` for ``eq``."""
    expanded = _expand(eq)
    return expanded
|