mathai 0.4.8__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mathai/__init__.py CHANGED
@@ -1,15 +1,15 @@
1
1
  from .ode import diffsolve as ode_solve
2
2
  from .ode import diffsolve_sep as ode_shift_term
3
3
 
4
- from .linear import linear_solve
4
+ from .linear import linear_solve, linear_or
5
5
 
6
6
  from .expand import expand
7
7
 
8
8
  from .parser import parse
9
9
 
10
- from .printeq import printeq, printeq_log, printeq_str
10
+ from .printeq import printeq, printeq_str
11
11
 
12
- from .simplify import solve, simplify, solve2
12
+ from .simplify import simplify
13
13
 
14
14
  from .integrate import ref as integrate_save
15
15
  from .integrate import integrate_subs_main as integrate_subs
@@ -24,7 +24,7 @@ from .integrate import integrate_formula
24
24
  from .diff import diff
25
25
 
26
26
  from .factor import factor as factor1
27
- from .factor import factor2
27
+ from .factor import factor2, factor3
28
28
  from .factor import rationalize_sqrt as rationalize
29
29
  from .factor import merge_sqrt
30
30
  from .factor import factorconst as factor0
@@ -35,15 +35,16 @@ from .inverse import inverse
35
35
 
36
36
  from .trig import trig0, trig1, trig2, trig3, trig4
37
37
 
38
- from .logic import logic0, logic1, logic2, logic3
38
+ from .logic import logic0, logic1, logic2, logic3, logic_n
39
39
 
40
40
  from .apart import apart, apart2
41
41
 
42
- from .console import console
42
+ from .limit import limit1, limit2, limit0, limit3
43
43
 
44
- from .limit import limit
44
+ from .univariate_inequality import wavycurvy, absolute, domain, handle_sqrt, prepare
45
+ from .bivariate_inequality import inequality_solve
45
46
 
46
- from .univariate_inequality import wavycurvy, absolute, domain, handle_sqrt
47
+ from .matrix import matrix_solve
47
48
 
48
49
  from .base import *
49
50
 
mathai/base.py CHANGED
@@ -1,24 +1,67 @@
1
1
  import copy
2
2
  from fractions import Fraction
3
+ def use(eq):
4
+ return TreeNode(eq.name, [use(child) for child in eq.children])
5
+ def use2(eq):
6
+ return use(tree_form(str_form(eq).replace("f_wmul", "f_mul")))
7
+ def contains_list_or_neg(node):
8
+ stack = [node]
9
+ while stack:
10
+ n = stack.pop()
11
+ if n.name == "f_list" or n.name.startswith("v_-"):
12
+ return True
13
+ stack.extend(n.children)
14
+ return False
3
15
 
4
16
  class TreeNode:
5
- def __init__(self, name, children=[]):
6
-
7
- children = [child.copy_tree() for child in children]
17
+ matmul = None
18
+
19
+ def __init__(self, name, children=None):
20
+ if children is None:
21
+ children = []
22
+
23
+ # copy once
24
+ children = copy.deepcopy(children)
8
25
  self.name = name
9
- if name in ["f_add", "f_mul"]:
10
- self.children = sorted(children, key=lambda x: str_form(x))
26
+
27
+ if name == "f_add" or (name == "f_mul" and TreeNode.matmul is None):
28
+ keyed = [(str_form(c), c) for c in children]
29
+ self.children = [c for _, c in sorted(keyed)]
30
+
31
+ elif name == "f_mul" and TreeNode.matmul == False:
32
+ sortable = []
33
+ fixed = []
34
+ for c in children:
35
+ if not contains_list_or_neg(c):
36
+ sortable.append(c)
37
+ else:
38
+ fixed.append(c)
39
+
40
+ if len(sortable) > 1:
41
+ sortable = TreeNode("f_dmul", list(sorted(sortable, key=lambda x: str_form(x))))
42
+ sortable.name = "f_mul"
43
+
44
+ elif len(sortable) == 1:
45
+ sortable = sortable[0]
46
+
47
+ if isinstance(sortable, TreeNode):
48
+ fixed.append(sortable)
49
+ if len(fixed) > 1:
50
+ self.name = "f_wmul"
51
+ elif len(fixed) == 1:
52
+ self.name = fixed[0].name
53
+ fixed = fixed[0].children
54
+
55
+
56
+ self.children = fixed
11
57
  else:
12
58
  self.children = children
13
59
 
60
+
14
61
  def fx(self, fxname):
15
62
  return TreeNode("f_" + fxname, [self])
16
- def copy_tree(node):
17
- if node is None:
18
- return None
19
-
20
- return tree_form(str_form(node))
21
-
63
+ def copy_tree(self):
64
+ return copy.deepcopy(self)
22
65
  def __repr__(self):
23
66
  return string_equation(str_form(self))
24
67
 
@@ -126,6 +169,44 @@ def remove_duplicates_custom(lst, rcustom):
126
169
  if not any(rcustom(item, x) for x in result):
127
170
  result.append(item)
128
171
  return result
172
+ def frac_to_tree(f):
173
+ if isinstance(f, int):
174
+ f = Fraction(f)
175
+ if f.numerator == 0:
176
+ return tree_form("d_0")
177
+ if f.numerator == 1:
178
+ if f.denominator == 1:
179
+ return tree_form("d_1")
180
+ return tree_form("d_"+str(f.denominator))**tree_form("d_-1")
181
+ if f.denominator == 1:
182
+ return tree_form("d_"+str(f.numerator))
183
+ else:
184
+ return tree_form("d_"+str(f.numerator))/tree_form("d_"+str(f.denominator))
185
+ def perfect_root(n, r):
186
+ if r <= 0 or (n < 0 and r % 2 == 0):
187
+ return False, None
188
+
189
+ lo = 0
190
+ hi = n if n > 1 else 1
191
+
192
+ while lo <= hi:
193
+ mid = lo + (hi - lo) // 2
194
+ pow_val = 1
195
+
196
+ for _ in range(r):
197
+ pow_val *= mid
198
+ if pow_val > n:
199
+ break
200
+
201
+ if pow_val == n:
202
+ return True, mid
203
+ elif pow_val < n:
204
+ lo = mid + 1
205
+ else:
206
+ hi = mid - 1
207
+
208
+ return False, None
209
+
129
210
  def frac(eq):
130
211
  if eq.name[:2] == "d_":
131
212
  return Fraction(int(eq.name[2:]))
@@ -154,23 +235,27 @@ def frac(eq):
154
235
  if eq.name == "f_pow":
155
236
  a = frac(eq.children[0])
156
237
  b = frac(eq.children[1])
157
- if isinstance(a, Fraction) and isinstance(b, Fraction) and b.denominator==1:
158
- if a == 0 and b <= 0:
159
- return None
160
- return a**b
161
- else:
238
+ if a is None or b is None:
162
239
  return None
240
+ if a == 0 and b <= 0:
241
+ return None
242
+ if b.denominator == 1:
243
+ return a ** b.numerator
244
+ found_c, c = perfect_root(a.numerator, b.denominator)
245
+ found_d, d = perfect_root(a.denominator, b.denominator)
246
+ if found_c and found_d:
247
+ return Fraction(c,d) ** b.numerator
248
+ return None
163
249
  return None
164
250
  def factor_generation(eq):
165
251
  output = []
166
252
  if eq.name != "f_mul":
167
- eq = TreeNode("f_mul", [eq])
253
+ tmp = TreeNode("f_mul", [])
254
+ tmp.children.append(eq)
255
+ eq = tmp
168
256
  if eq.name == "f_mul":
169
257
  for child in eq.children:
170
258
  if child.name == "f_pow":
171
- if child.children[0].name[:2] == "s_":
172
- output.append(child)
173
- continue
174
259
  if child.children[1].name[:2] != "d_":
175
260
  output.append(child)
176
261
  continue
@@ -178,7 +263,10 @@ def factor_generation(eq):
178
263
  n = int(child.children[1].name[2:])
179
264
  if n < 0:
180
265
  for i in range(-n):
181
- output.append(child.children[0]**-1)
266
+ out = factor_generation(child.children[0])
267
+ out = [x**-1 for x in out]
268
+ output += out
269
+ #output.append(child.children[0]**-1)
182
270
  else:
183
271
  for i in range(n):
184
272
  output.append(child.children[0])
@@ -209,11 +297,13 @@ def compute(eq):
209
297
  # Evaluate based on node type
210
298
  if eq.name == "f_add":
211
299
  return sum(values)
300
+ elif eq.name == "f_abs":
301
+ return math.fabs(values[0])
212
302
  elif eq.name == "f_sub":
213
303
  return values[0] - values[1]
214
304
  elif eq.name == "f_rad":
215
305
  return values[0] * math.pi / 180
216
- elif eq.name == "f_mul":
306
+ elif eq.name in ["f_wmul", "f_mul"]:
217
307
  result = 1.0
218
308
  for v in values:
219
309
  result *= v
@@ -245,14 +335,19 @@ def num_dem(equation):
245
335
  num = tree_form("d_1")
246
336
  den = tree_form("d_1")
247
337
  for item in factor_generation(equation):
248
-
249
- t = item
250
- if t.name == "f_pow" and "v_" not in str_form(t.children[1]) and compute(t.children[1]) < 0:
251
-
252
- den = den*item
338
+ if item.name == "f_pow":
339
+ c = frac(item.children[1])
340
+ if c is not None and c < 0:
341
+ t = frac_to_tree(-c)
342
+ if t == tree_form("d_1"):
343
+ den = den * item.children[0]
344
+ else:
345
+ den = den * item.children[0]**t
346
+ else:
347
+ den = den * item
253
348
  else:
254
349
  num = num*item
255
- return [num, tree_form("d_1")/den]
350
+ return num, den
256
351
 
257
352
  def summation(lst):
258
353
  if len(lst) == 0:
@@ -276,9 +371,12 @@ def product(lst):
276
371
  s *= item
277
372
  return s
278
373
  def flatten_tree(node):
374
+ if node is None:
375
+ return None
279
376
  if not node.children:
280
377
  return node
281
- if node.name in ("f_add", "f_mul", "f_and", "f_or"):
378
+ ad = []
379
+ if node.name in ["f_add", "f_mul", "f_and", "f_or", "f_wmul"]:
282
380
  merged_children = []
283
381
  for child in node.children:
284
382
  flattened_child = flatten_tree(child)
@@ -294,11 +392,11 @@ def dowhile(eq, fx):
294
392
  if eq is None:
295
393
  return None
296
394
  while True:
297
- orig = eq.copy_tree()
395
+ orig = copy.deepcopy(eq)
298
396
  eq2 = fx(eq)
299
397
  if eq2 is None:
300
398
  return None
301
- eq = eq2.copy_tree()
399
+ eq = copy.deepcopy(eq2)
302
400
  if eq == orig:
303
401
  return orig
304
402
  def tree_form(tabbed_strings):
@@ -332,9 +430,9 @@ def string_equation_helper(equation_tree):
332
430
  if equation_tree.name == "f_index":
333
431
  return string_equation_helper(equation_tree.children[0])+"["+",".join([string_equation_helper(child) for child in equation_tree.children[1:]])+"]"
334
432
  s = "("
335
- if len(equation_tree.children) == 1 or equation_tree.name[2:] in [chr(ord("A")+i) for i in range(26)]+["subs", "try", "ref", "integrate", "exist", "forall", "sum2", "int", "pdif", "dif", "A", "B", "C", "covariance", "sum"]:
433
+ if len(equation_tree.children) == 1 or equation_tree.name[2:] in [chr(ord("A")+i) for i in range(26)]+["limitpinf", "subs", "try", "ref","limit", "integrate", "exist", "forall", "sum2", "int", "pdif", "dif", "A", "B", "C", "covariance", "sum"]:
336
434
  s = equation_tree.name[2:] + s
337
- sign = {"f_not":"~", "f_addw":"+", "f_mulw":"*", "f_intersection":"&", "f_union":"|", "f_sum2":",", "f_exist":",", "f_forall":",", "f_sum":",","f_covariance": ",", "f_B":",", "f_imply":"->", "f_ge":">=", "f_le":"<=", "f_gt":">", "f_lt":"<", "f_cosec":"?" , "f_equiv": "<->", "f_sec":"?", "f_cot": "?", "f_dot": ".", "f_circumcenter":"?", "f_transpose":"?", "f_exp":"?", "f_abs":"?", "f_log":"?", "f_and":"&", "f_or":"|", "f_sub":"-", "f_neg":"?", "f_inv":"?", "f_add": "+", "f_mul": "*", "f_pow": "^", "f_poly": ",", "f_div": "/", "f_sub": "-", "f_dif": ",", "f_sin": "?", "f_cos": "?", "f_tan": "?", "f_eq": "=", "f_sqrt": "?"}
435
+ sign = {"f_not":"~", "f_wadd":"+", "f_wmul":"*", "f_intersection":"&", "f_union":"|", "f_sum2":",", "f_exist":",", "f_forall":",", "f_sum":",","f_covariance": ",", "f_B":",", "f_imply":"->", "f_ge":">=", "f_le":"<=", "f_gt":">", "f_lt":"<", "f_cosec":"?" , "f_equiv": "<->", "f_sec":"?", "f_cot": "?", "f_dot": ".", "f_circumcenter":"?", "f_transpose":"?", "f_exp":"?", "f_abs":"?", "f_log":"?", "f_and":"&", "f_or":"|", "f_sub":"-", "f_neg":"?", "f_inv":"?", "f_add": "+", "f_mul": "*", "f_pow": "^", "f_poly": ",", "f_div": "/", "f_sub": "-", "f_dif": ",", "f_sin": "?", "f_cos": "?", "f_tan": "?", "f_eq": "=", "f_sqrt": "?"}
338
436
  arr = []
339
437
  k = None
340
438
  if equation_tree.name not in sign.keys():
@@ -342,7 +440,7 @@ def string_equation_helper(equation_tree):
342
440
  else:
343
441
  k = sign[equation_tree.name]
344
442
  for child in equation_tree.children:
345
- arr.append(string_equation_helper(child.copy_tree()))
443
+ arr.append(string_equation_helper(copy.deepcopy(child)))
346
444
  outfinal = s + k.join(arr) + ")"+extra
347
445
 
348
446
  return outfinal.replace("+-", "-")
@@ -0,0 +1,317 @@
1
+ import math
2
+ from .linear import linear_or
3
+ from functools import reduce
4
+ import operator
5
+ from .base import *
6
+ from .simplify import simplify
7
+ from .expand import expand
8
+ from .logic import logic0
9
+
10
+ def shoelace_area(vertices):
11
+ n = len(vertices)
12
+ area = 0.0
13
+ for i in range(n):
14
+ j = (i + 1) % n
15
+ area += vertices[i][0] * vertices[j][1]
16
+ area -= vertices[j][0] * vertices[i][1]
17
+ return abs(area) / 2.0
18
+
19
+ def triangle_area(p1, p2, p3):
20
+ area = 0.0
21
+ area += p1[0] * (p2[1] - p3[1])
22
+ area += p2[0] * (p3[1] - p1[1])
23
+ area += p3[0] * (p1[1] - p2[1])
24
+ return abs(area) / 2.0
25
+
26
+ def is_point_inside_polygon(point, vertices):
27
+ if len(vertices) < 3:
28
+ return False
29
+
30
+ polygon_area = shoelace_area(vertices)
31
+
32
+ total_triangle_area = 0.0
33
+ n = len(vertices)
34
+ for i in range(n):
35
+ j = (i + 1) % n
36
+ total_triangle_area += triangle_area(point, vertices[i], vertices[j])
37
+
38
+ tolerance = 1e-5
39
+ return abs(total_triangle_area - polygon_area) < tolerance
40
+
41
+ def distance_point_to_segment(px, py, x1, y1, x2, y2):
42
+ dx, dy = x2 - x1, y2 - y1
43
+ if dx == dy == 0:
44
+ return ((px - x1)**2 + (py - y1)**2)**0.5
45
+ t = max(0, min(1, ((px - x1) * dx + (py - y1) * dy) / (dx*dx + dy*dy)))
46
+ proj_x = x1 + t * dx
47
+ proj_y = y1 + t * dy
48
+ return ((px - proj_x)**2 + (py - proj_y)**2)**0.5
49
+
50
+ def deterministic_middle_point(vertices, grid_resolution=100):
51
+ xs = [v[0] for v in vertices]
52
+ ys = [v[1] for v in vertices]
53
+ xmin, xmax = min(xs), max(xs)
54
+ ymin, ymax = min(ys), max(ys)
55
+
56
+ best_point = None
57
+ max_dist = -1
58
+
59
+ for i in range(grid_resolution + 1):
60
+ for j in range(grid_resolution + 1):
61
+ px = xmin + (xmax - xmin) * i / grid_resolution
62
+ py = ymin + (ymax - ymin) * j / grid_resolution
63
+ if not is_point_inside_polygon((px, py), vertices):
64
+ continue
65
+ min_edge_dist = float('inf')
66
+ n = len(vertices)
67
+ for k in range(n):
68
+ x1, y1 = vertices[k]
69
+ x2, y2 = vertices[(k + 1) % n]
70
+ d = distance_point_to_segment(px, py, x1, y1, x2, y2)
71
+ min_edge_dist = min(min_edge_dist, d)
72
+ if min_edge_dist > max_dist:
73
+ max_dist = min_edge_dist
74
+ best_point = (px, py)
75
+
76
+ return best_point
77
+
78
+ def build(eq):
79
+ eq = TreeNode("f_or", eq)
80
+ eq = flatten_tree(eq)
81
+ orig = eq.copy_tree()
82
+ def fxhelper3(eq):
83
+ if eq.name[2:] in "le ge lt gt".split(" "):
84
+ return TreeNode("f_eq", [child.copy_tree() for child in eq.children])
85
+ return TreeNode(eq.name, [fxhelper3(child) for child in eq.children])
86
+ eq = fxhelper3(eq)
87
+
88
+ result = linear_or(eq)
89
+
90
+ if result is None:
91
+ return None
92
+
93
+ maxnum = tree_form("d_2")
94
+ if len(result[1]) != 0:
95
+ maxnum = max([max([simplify(item2.fx("abs")) for item2 in item], key=lambda x: compute(x)) for item in result[1]], key=lambda x: compute(x))
96
+ maxnum += 1
97
+ maxnum = simplify(maxnum)
98
+ eq = flatten_tree(eq | simplify(TreeNode("f_or", [TreeNode("f_eq", [tree_form(item)+maxnum, tree_form("d_0")])|\
99
+ TreeNode("f_eq", [tree_form(item)-maxnum, tree_form("d_0")]) for item in ["v_0","v_1"]])))
100
+ result2 = linear_or(eq)
101
+ if result2 is None:
102
+ return None
103
+
104
+ point_lst = result2[2]
105
+
106
+ def gen(point):
107
+ nonlocal point_lst
108
+ out = []
109
+ for item in point_lst:
110
+ p = None
111
+ if point in item:
112
+ p = item.index(point)
113
+ else:
114
+ continue
115
+ if p < len(item)-1:
116
+ out.append(item[p+1])
117
+ if p > 0:
118
+ out.append(item[p-1])
119
+ return list(set(out))
120
+ start = list(range(len(result2[1])))
121
+ graph= {}
122
+ for item in start:
123
+ graph[item] = gen(item)
124
+
125
+ points = {}
126
+ for index, item in enumerate(result2[1]):
127
+ points[index] = [compute(item2) for item2 in item]
128
+
129
+ res = []
130
+ for index, item in enumerate(result2[1]):
131
+ if any(simplify(item2.fx("abs")-maxnum)!=0 and abs(compute(item2))>compute(maxnum) for item2 in item):
132
+ res.append(index)
133
+
134
+ graph = {k: sorted(v) for k, v in graph.items()}
135
+
136
+ def dfs(current, parent, path, visited, cycles):
137
+ path.append(current)
138
+ visited.add(current)
139
+ for neighbor in graph[current]:
140
+ if neighbor == parent:
141
+ continue
142
+ if neighbor in visited:
143
+ idx = path.index(neighbor)
144
+ cycle = path[idx:]
145
+ cycles.append(cycle)
146
+ else:
147
+ dfs(neighbor, current, path, visited, cycles)
148
+ path.pop()
149
+ visited.remove(current)
150
+
151
+ cycles = []
152
+ for start in sorted(graph.keys()):
153
+ path = []
154
+ visited = set()
155
+ dfs(start, -1, path, visited, cycles)
156
+
157
+ def normalize(cycle):
158
+ k = len(cycle)
159
+ if k < 3:
160
+ return None
161
+ candidates = []
162
+ for direction in [cycle, list(reversed(cycle))]:
163
+ doubled = direction + direction[:-1]
164
+ for i in range(k):
165
+ rot = tuple(doubled[i:i + k])
166
+ candidates.append(rot)
167
+ return min(candidates)
168
+
169
+ unique = set()
170
+ for c in cycles:
171
+ norm = normalize(c)
172
+ if norm:
173
+ unique.add(norm)
174
+
175
+ cycles = sorted(list(unique), key=lambda x: (len(x), x))
176
+
177
+ start = list(range(len(result2[1])))
178
+ for i in range(len(cycles)-1,-1,-1):
179
+ if any(item in cycles[i] for item in res) or\
180
+ any(is_point_inside_polygon([compute(item2) for item2 in list(result2[1][p])], [[compute(item2) for item2 in result2[1][item]] for item in cycles[i]]) for p in list(set(start) - set(cycles[i]))) or\
181
+ any(len(set(graph[item]) & set(cycles[i]))>2 for item in cycles[i]):
182
+ cycles.pop(i)
183
+
184
+ point_lst = [index for index, item in enumerate(result2[1]) if item in result[1]]
185
+
186
+ border = []
187
+ for item in start:
188
+ for item2 in graph[item]:
189
+ a = result2[1][item]
190
+ b = result2[1][item2]
191
+
192
+ if a[0] == b[0] and simplify(a[0].fx("abs") - maxnum) == 0:
193
+ continue
194
+ if a[1] == b[1] and simplify(a[1].fx("abs") - maxnum) == 0:
195
+ continue
196
+
197
+ border.append(tuple(sorted([item, item2])))
198
+
199
+ line = []
200
+ for key in graph.keys():
201
+ for item in list(set(point_lst)&set(graph[key])):
202
+ line.append(tuple(sorted([item, key])))
203
+ line = list(set(line+border))
204
+ point_in = [deterministic_middle_point([[compute(item3) for item3 in result2[1][item2]] for item2 in item]) for item in cycles]
205
+ def work(eq, point):
206
+ nonlocal result2
207
+ if eq.name[:2] == "d_":
208
+ return float(eq.name[2:])
209
+ if eq.name in result2[0]:
210
+ return point[result2[0].index(eq.name)]
211
+ if eq.name == "f_add":
212
+ return sum(work(item, point) for item in eq.children)
213
+ if eq.name == "f_mul":
214
+ return math.prod(work(item, point) for item in eq.children)
215
+ if eq.name == "f_sub":
216
+ return work(eq.children[0], point) - work(eq.children[1], point)
217
+ return {"eq": lambda a,b: abs(a-b)<0.001, "gt":lambda a,b: False if abs(a-b)<0.001 else a>b, "lt":lambda a,b: False if abs(a-b)<0.001 else a<b}[eq.name[2:]](work(eq.children[0], point), work(eq.children[1], point))
218
+
219
+ data = []
220
+ for index, item in enumerate(result2[2][:-4]):
221
+ a = tuple([item for item in point_lst if work(orig.children[index], [compute(item2) for item2 in result2[1][item]])])
222
+ #a = tuple(set(item) & set(point_lst))
223
+ #b = tuple(set([tuple(sorted([item[i], item[i+1]])) for i in range(len(item)-1)]) & set(line))
224
+ b = None
225
+ if orig.children[index] == "f_eq":
226
+ b = tuple([tuple(item) for item in line if work(orig.children[index], [compute(item2) for item2 in result2[1][item[1]]]) and work(orig.children[index], [compute(item2) for item2 in result2[1][item[0]]])])
227
+ else:
228
+ b = tuple([tuple(item) for item in line if work(orig.children[index], [compute(item2) for item2 in result2[1][item[1]]]) or work(orig.children[index], [compute(item2) for item2 in result2[1][item[0]]])])
229
+ c = tuple([tuple(item) for index2, item in enumerate(cycles) if work(orig.children[index], point_in[index2])])
230
+ data.append((a,b,c))
231
+
232
+ total = tuple([tuple(point_lst), tuple(line), tuple(cycles)])
233
+ final = {}
234
+ for index, item in enumerate(orig.children):
235
+ final[item] = tuple(data[index])
236
+ return final, total, result2[1]
237
+
238
+ def inequality_solve(eq):
239
+
240
+ eq = logic0(eq)
241
+ element = []
242
+ def helper(eq):
243
+ nonlocal element
244
+
245
+ if eq.name[2:] in "le ge lt gt eq".split(" ") and "v_" in str_form(eq):
246
+ element.append(eq)
247
+ return TreeNode(eq.name, [helper(child) for child in eq.children])
248
+ helper(eq)
249
+
250
+ out = build(list(set(element)))
251
+
252
+ if out is None:
253
+ return eq
254
+
255
+ def helper2(eq):
256
+ nonlocal out
257
+ if eq == tree_form("s_true"):
258
+ return [set(item) for item in out[1]]
259
+ if eq == tree_form("s_false"):
260
+ return [set(), set(), set()]
261
+ if eq in out[0].keys():
262
+ return [set(item) for item in out[0][eq]]
263
+ if eq.name == "f_or":
264
+ result = [helper2(child) for child in eq.children]
265
+ a = []
266
+ b = []
267
+ c = []
268
+ for item in result:
269
+ a += [item[0]]
270
+ b += [item[1]]
271
+ c += [item[2]]
272
+ x = a[0]
273
+ for item in a[1:]:
274
+ x |= item
275
+ y = b[0]
276
+ for item in b[1:]:
277
+ y |= item
278
+ z = c[0]
279
+ for item in c[1:]:
280
+ z |= item
281
+ return [x, y, z]
282
+ if eq.name == "f_and":
283
+ result = [helper2(child) for child in eq.children]
284
+ a = []
285
+ b = []
286
+ c = []
287
+ for item in result:
288
+ a += [item[0]]
289
+ b += [item[1]]
290
+ c += [item[2]]
291
+ x = a[0]
292
+ for item in a[1:]:
293
+ x &= item
294
+ y = b[0]
295
+ for item in b[1:]:
296
+ y &= item
297
+ z = c[0]
298
+ for item in c[1:]:
299
+ z &= item
300
+ return [x, y, z]
301
+ if eq.name == "f_not":
302
+ eq2 = helper2(eq.children[0])
303
+ a,b,c= eq2
304
+ d,e,f= [set(item) for item in out[1]]
305
+ return [d-a,e-b,f-c]
306
+ return helper2(dowhile(eq, lambda x: logic0(expand(simplify(eq)))))
307
+ out2 = helper2(eq)
308
+
309
+ out = list(out)
310
+ out[1] = [set(item) for item in out[1]]
311
+ if tuple(out[1]) == (set(), set(), set()):
312
+ return eq
313
+ if tuple(out[1]) == tuple(out2):
314
+ return tree_form("s_true")
315
+ if tuple(out2) == (set(), set(), set()):
316
+ return tree_form("s_false")
317
+ return eq
mathai/diff.py CHANGED
@@ -1,9 +1,9 @@
1
- from .simplify import solve
1
+ from .simplify import simplify
2
2
  from .base import *
3
3
  from .trig import trig0
4
4
  def diff(equation, var="v_0"):
5
5
  def diffeq(eq):
6
- eq = solve(eq)
6
+ eq = simplify(eq)
7
7
  if "v_" not in str_form(eq):
8
8
  return tree_form("d_0")
9
9
  if eq.name == "f_add":
@@ -65,4 +65,4 @@ def diff(equation, var="v_0"):
65
65
  return TreeNode(equation.name, [helper(child, var) for child in equation.children])
66
66
  equation = diffeq(trig0(equation))
67
67
  equation = helper(equation, var)
68
- return solve(equation)
68
+ return simplify(equation)