mathai 0.4.0__py3-none-any.whl → 0.6.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,317 @@
1
+ import math
2
+ from .linear import linear_or
3
+ from functools import reduce
4
+ import operator
5
+ from .base import *
6
+ from .simplify import simplify
7
+ from .expand import expand
8
+ from .logic import logic0
9
+
10
def shoelace_area(vertices):
    """Return the absolute area enclosed by *vertices* (shoelace formula).

    *vertices* is a sequence of (x, y) pairs in traversal order; the last
    vertex is implicitly joined back to the first. An empty sequence yields
    0.0.
    """
    count = len(vertices)
    twice_signed = 0.0
    for idx, current in enumerate(vertices):
        following = vertices[(idx + 1) % count]
        twice_signed += current[0] * following[1]
        twice_signed -= following[0] * current[1]
    return abs(twice_signed) / 2.0
18
+
19
def triangle_area(p1, p2, p3):
    """Return the unsigned area of the triangle with corners p1, p2, p3.

    Each corner is indexed as ``[0]`` (x) and ``[1]`` (y); collinear corners
    give 0.0.
    """
    corners = (p1, p2, p3)
    # Cyclic form of the determinant: x_k * (y_{k+1} - y_{k+2}).
    terms = (
        corners[k][0] * (corners[(k + 1) % 3][1] - corners[(k + 2) % 3][1])
        for k in range(3)
    )
    return abs(sum(terms, 0.0)) / 2.0
25
+
26
def is_point_inside_polygon(point, vertices, tolerance=1e-5):
    """Return True if *point* lies inside the polygon or on its boundary.

    Uses the even-odd (ray-casting) rule, which is correct for concave as
    well as convex simple polygons. Comparing the sum of fan-triangle areas
    with the polygon area (the former approach) only works when the polygon
    is star-shaped as seen from *point*. It also used an absolute tolerance
    on *areas*, which breaks down for very large or very small coordinates.

    Args:
        point: (x, y) pair (indexed as ``[0]`` / ``[1]``).
        vertices: sequence of (x, y) pairs in traversal order; the polygon is
            implicitly closed.
        tolerance: a point whose distance to some edge is at most this value
            counts as on the boundary, and therefore as inside.

    Returns:
        bool. Fewer than three vertices never contain a point.
    """
    n = len(vertices)
    if n < 3:
        return False

    px, py = point[0], point[1]
    tol_sq = tolerance * tolerance
    inside = False
    for i in range(n):
        x1, y1 = vertices[i][0], vertices[i][1]
        x2, y2 = vertices[(i + 1) % n][0], vertices[(i + 1) % n][1]
        dx, dy = x2 - x1, y2 - y1

        # Boundary test: distance from the point to the closest point of the edge.
        length_sq = dx * dx + dy * dy
        if length_sq == 0:
            cx, cy = x1, y1
        else:
            t = max(0.0, min(1.0, ((px - x1) * dx + (py - y1) * dy) / length_sq))
            cx, cy = x1 + t * dx, y1 + t * dy
        if (px - cx) ** 2 + (py - cy) ** 2 <= tol_sq:
            return True

        # Even-odd rule: toggle for every edge that straddles the horizontal
        # line y = py and crosses it strictly to the right of the point.
        # Straddling guarantees dy != 0, so the division is safe.
        if (y1 > py) != (y2 > py):
            x_cross = x1 + (py - y1) * dx / dy
            if px < x_cross:
                inside = not inside
    return inside
40
+
41
def distance_point_to_segment(px, py, x1, y1, x2, y2):
    """Return the Euclidean distance from (px, py) to segment (x1, y1)-(x2, y2).

    The point is projected onto the segment's supporting line, and the
    projection parameter is clamped to [0, 1] so the nearest point stays on
    the segment. A zero-length segment is treated as the single point
    (x1, y1).
    """
    span_x, span_y = x2 - x1, y2 - y1
    if span_x == 0 and span_y == 0:
        foot_x, foot_y = x1, y1
    else:
        ratio = ((px - x1) * span_x + (py - y1) * span_y) / (span_x * span_x + span_y * span_y)
        ratio = max(0, min(1, ratio))
        foot_x = x1 + ratio * span_x
        foot_y = y1 + ratio * span_y
    return ((px - foot_x) ** 2 + (py - foot_y) ** 2) ** 0.5
49
+
50
def deterministic_middle_point(vertices, grid_resolution=100):
    """Return a reproducible point well inside the polygon *vertices*.

    Samples a (grid_resolution + 1) x (grid_resolution + 1) grid over the
    polygon's bounding box. Among the samples accepted by
    ``is_point_inside_polygon``, it returns the one whose distance to the
    nearest edge is largest; the first such sample in scan order wins ties.

    If no sample is accepted (e.g. the polygon is thinner than the grid
    spacing), the vertex centroid is returned instead of ``None``. Callers
    evaluate constraints at the returned point by indexing it, so ``None``
    would crash them.

    Args:
        vertices: non-empty sequence of (x, y) pairs in traversal order.
        grid_resolution: number of grid intervals per axis (must be > 0).

    Returns:
        An (x, y) tuple.
    """
    xs = [v[0] for v in vertices]
    ys = [v[1] for v in vertices]
    xmin, xmax = min(xs), max(xs)
    ymin, ymax = min(ys), max(ys)

    n = len(vertices)
    # The edge list does not depend on the sample, so build it once.
    edges = [(vertices[k], vertices[(k + 1) % n]) for k in range(n)]

    best_point = None
    max_dist = -1
    for i in range(grid_resolution + 1):
        px = xmin + (xmax - xmin) * i / grid_resolution
        for j in range(grid_resolution + 1):
            py = ymin + (ymax - ymin) * j / grid_resolution
            if not is_point_inside_polygon((px, py), vertices):
                continue
            min_edge_dist = float('inf')
            for (x1, y1), (x2, y2) in edges:
                d = distance_point_to_segment(px, py, x1, y1, x2, y2)
                min_edge_dist = min(min_edge_dist, d)
            if min_edge_dist > max_dist:
                max_dist = min_edge_dist
                best_point = (px, py)

    if best_point is None:
        # No grid sample was accepted: fall back to the vertex centroid so a
        # coordinate pair is always returned.
        best_point = (sum(xs) / n, sum(ys) / n)
    return best_point
77
+
78
def build(eq):
    """Build the arrangement of a list of linear relations and label its parts.

    Every relation is replaced by the equation of its boundary line
    (``le``/``ge``/``lt``/``gt`` -> ``f_eq``) and handed to ``linear_or``
    together with four bounding lines ``v_0 = ±maxnum`` and
    ``v_1 = ±maxnum``. From the resulting vertices it derives an adjacency
    graph, a set of edges and the minimal cycles used as regions. It then
    records, for every original relation, which points, edges and regions
    satisfy it.

    Args:
        eq: list of relation nodes. NOTE(review): the bounding lines use only
            ``v_0``/``v_1``, so two variables are presumably assumed.

    Returns:
        ``None`` if either ``linear_or`` call returns ``None``. Otherwise
        ``(final, total, vertices)``, where:

        * ``vertices`` is ``result2[1]`` from ``linear_or``: one symbolic
          coordinate pair per vertex, referred to by index everywhere else;
        * ``total`` is ``(points, edges, regions)``: indices of vertices that
          are intersection points of the original relations, sorted index
          pairs for edges, and index tuples for region cycles;
        * ``final`` maps each relation node to the ``(points, edges,
          regions)`` subset on which that relation holds.
    """
    # Wrap the relations in one disjunction so linear_or sees a single tree.
    eq = flatten_tree(TreeNode("f_or", eq))
    orig = eq.copy_tree()

    def to_equalities(node):
        # Replace every inequality by the equation of its boundary line.
        if node.name[2:] in ("le", "ge", "lt", "gt"):
            return TreeNode("f_eq", [child.copy_tree() for child in node.children])
        return TreeNode(node.name, [to_equalities(child) for child in node.children])

    eq = to_equalities(eq)

    result = linear_or(eq)
    if result is None:
        return None

    # Half-width of the bounding box: one more than the largest |coordinate|
    # among the intersection points of the relations (2 when there are none).
    maxnum = tree_form("d_2")
    if len(result[1]) != 0:
        maxnum = max(
            [max([simplify(coord.fx("abs")) for coord in vertex], key=compute) for vertex in result[1]],
            key=compute,
        )
        maxnum += 1
        maxnum = simplify(maxnum)

    # Add the four bounding lines and rebuild the arrangement.
    box = TreeNode("f_or", [
        TreeNode("f_eq", [tree_form(var) + maxnum, tree_form("d_0")])
        | TreeNode("f_eq", [tree_form(var) - maxnum, tree_form("d_0")])
        for var in ["v_0", "v_1"]
    ])
    eq = flatten_tree(eq | simplify(box))
    result2 = linear_or(eq)
    if result2 is None:
        return None

    vertex_count = len(result2[1])
    all_indices = list(range(vertex_count))

    # Numeric coordinates of every vertex, computed once and reused below.
    points = {index: [compute(coord) for coord in vertex] for index, vertex in enumerate(result2[1])}

    def neighbours(point):
        # Vertices are adjacent when they are consecutive entries of one of
        # the per-line index lists in result2[2].
        out = []
        for ordered in result2[2]:
            if point not in ordered:
                continue
            p = ordered.index(point)
            if p < len(ordered) - 1:
                out.append(ordered[p + 1])
            if p > 0:
                out.append(ordered[p - 1])
        return sorted(set(out))

    graph = {index: neighbours(index) for index in all_indices}

    # Vertices outside the bounding box: some coordinate whose |value| is not
    # symbolically maxnum and is numerically larger than it.
    limit = compute(maxnum)
    outside = [
        index for index, vertex in enumerate(result2[1])
        if any(simplify(coord.fx("abs") - maxnum) != 0 and abs(compute(coord)) > limit for coord in vertex)
    ]

    def dfs(current, parent, path, visited, cycles):
        # Enumerate cycles: reaching a vertex already on the current path
        # (other than the one we came from) closes a cycle.
        path.append(current)
        visited.add(current)
        for neighbor in graph[current]:
            if neighbor == parent:
                continue
            if neighbor in visited:
                idx = path.index(neighbor)
                cycles.append(path[idx:])
            else:
                dfs(neighbor, current, path, visited, cycles)
        path.pop()
        visited.remove(current)

    cycles = []
    for root in sorted(graph):
        dfs(root, -1, [], set(), cycles)

    def normalize(cycle):
        # Canonical form: the smallest rotation over both directions, so the
        # same cycle found from different starts/directions compares equal.
        # Cycles shorter than 3 are not polygons.
        k = len(cycle)
        if k < 3:
            return None
        candidates = []
        for direction in [cycle, list(reversed(cycle))]:
            doubled = direction + direction[:-1]
            for i in range(k):
                candidates.append(tuple(doubled[i:i + k]))
        return min(candidates)

    unique = set()
    for c in cycles:
        norm = normalize(c)
        if norm:
            unique.add(norm)
    cycles = sorted(unique, key=lambda x: (len(x), x))

    # Keep only cycles that touch no outside vertex, contain no other vertex,
    # and have no chord (a member with more than two neighbours in the cycle).
    for i in range(len(cycles) - 1, -1, -1):
        cycle = cycles[i]
        polygon = [points[v] for v in cycle]
        if (
            any(v in cycle for v in outside)
            or any(is_point_inside_polygon(points[p], polygon) for p in list(set(all_indices) - set(cycle)))
            or any(len(set(graph[v]) & set(cycle)) > 2 for v in cycle)
        ):
            cycles.pop(i)

    # Vertices that are intersection points of the original relations
    # (i.e. present in the first, box-free linear_or result).
    point_lst = [index for index, vertex in enumerate(result2[1]) if vertex in result[1]]

    # Edges not lying along a bounding-box side (both endpoints sharing a
    # coordinate whose absolute value is maxnum).
    border = []
    for a_index in all_indices:
        for b_index in graph[a_index]:
            a = result2[1][a_index]
            b = result2[1][b_index]
            if a[0] == b[0] and simplify(a[0].fx("abs") - maxnum) == 0:
                continue
            if a[1] == b[1] and simplify(a[1].fx("abs") - maxnum) == 0:
                continue
            border.append(tuple(sorted([a_index, b_index])))

    # Edges incident to an intersection point, plus the non-box edges above.
    line = []
    point_set = set(point_lst)
    for key in graph:
        for item in list(point_set & set(graph[key])):
            line.append(tuple(sorted([item, key])))
    line = list(set(line + border))

    # One interior sample point per region, used to test relations on it.
    point_in = [deterministic_middle_point([points[v] for v in cycle]) for cycle in cycles]

    # Numeric relation checks with an absolute tolerance of 0.001.
    tol = 0.001
    relations = {
        "eq": lambda a, b: abs(a - b) < tol,
        "gt": lambda a, b: False if abs(a - b) < tol else a > b,
        "lt": lambda a, b: False if abs(a - b) < tol else a < b,
        "ge": lambda a, b: abs(a - b) < tol or a > b,
        "le": lambda a, b: abs(a - b) < tol or a < b,
    }

    def work(node, point):
        # Evaluate a numeric expression, or a relation, at the coordinates
        # *point* (ordered like the variable list result2[0]).
        if node.name[:2] == "d_":
            return float(node.name[2:])
        if node.name in result2[0]:
            return point[result2[0].index(node.name)]
        if node.name == "f_add":
            return sum(work(child, point) for child in node.children)
        if node.name == "f_mul":
            return math.prod(work(child, point) for child in node.children)
        if node.name == "f_sub":
            return work(node.children[0], point) - work(node.children[1], point)
        return relations[node.name[2:]](work(node.children[0], point), work(node.children[1], point))

    # One entry per original relation; the last four per-line lists presumably
    # belong to the bounding lines added above.
    data = []
    for index in range(len(result2[2][:-4])):
        relation = orig.children[index]
        a = tuple(v for v in point_lst if work(relation, points[v]))
        if relation.name == "f_eq":
            # An edge lies on an equation's line only if both endpoints do.
            b = tuple(
                tuple(edge) for edge in line
                if work(relation, points[edge[1]]) and work(relation, points[edge[0]])
            )
        else:
            # One satisfying endpoint suffices, presumably because no edge
            # crosses a relation's boundary line (each is in the arrangement).
            b = tuple(
                tuple(edge) for edge in line
                if work(relation, points[edge[1]]) or work(relation, points[edge[0]])
            )
        c = tuple(tuple(cycle) for k, cycle in enumerate(cycles) if work(relation, point_in[k]))
        data.append((a, b, c))

    total = (tuple(point_lst), tuple(line), tuple(cycles))
    final = {relation: tuple(data[index]) for index, relation in enumerate(orig.children)}
    return final, total, result2[1]
237
+
238
def inequality_solve(eq):
    """Decide whether a logical formula over linear relations is always/never true.

    The formula is normalised with ``logic0``. Every relation node
    (``le``/``ge``/``lt``/``gt``/``eq``) mentioning a variable is collected
    and passed to ``build``. The formula is then evaluated as three sets (the
    points, edges and regions of the arrangement on which it holds).

    Returns:
        ``tree_form("s_true")`` if the formula holds on the whole
        arrangement, ``tree_form("s_false")`` if it holds nowhere, and
        otherwise the ``logic0``-normalised formula unchanged. That includes
        the cases where ``build`` fails or a sub-formula cannot be reduced to
        known pieces.
    """
    eq = logic0(eq)

    # Collect, in pre-order, every relation node that mentions a variable.
    element = []

    def collect(node):
        if node.name[2:] in ("le", "ge", "lt", "gt", "eq") and "v_" in str_form(node):
            element.append(node)
        for child in node.children:
            collect(child)

    collect(eq)

    # dict.fromkeys de-duplicates like set() but keeps first-seen order, so
    # the relations reach build() in the same order on every run.
    out = build(list(dict.fromkeys(element)))
    if out is None:
        return eq

    class _Unresolved(Exception):
        """Raised when a sub-formula cannot be rewritten into a known form."""

    def evaluate(node):
        # Return [points, edges, regions] (as sets) on which *node* holds.
        if node == tree_form("s_true"):
            return [set(item) for item in out[1]]
        if node == tree_form("s_false"):
            return [set(), set(), set()]
        if node in out[0].keys():
            return [set(item) for item in out[0][node]]
        if node.name in ("f_or", "f_and"):
            combine = operator.or_ if node.name == "f_or" else operator.and_
            parts = [evaluate(child) for child in node.children]
            return [reduce(combine, [part[k] for part in parts]) for k in range(3)]
        if node.name == "f_not":
            inner = evaluate(node.children[0])
            universe = [set(item) for item in out[1]]
            return [universe[k] - inner[k] for k in range(3)]
        # Unknown shape: apply one normalisation step and retry. If the step
        # makes no progress the recursion would never end, so give up.
        rewritten = logic0(expand(simplify(node)))
        if rewritten is None or rewritten == node:
            raise _Unresolved
        return evaluate(rewritten)

    try:
        out2 = evaluate(eq)
    except _Unresolved:
        return eq

    universe = tuple(set(item) for item in out[1])
    empty = (set(), set(), set())
    if universe == empty:
        return eq
    if universe == tuple(out2):
        return tree_form("s_true")
    if tuple(out2) == empty:
        return tree_form("s_false")
    return eq
mathai/diff.py CHANGED
@@ -1,9 +1,9 @@
1
- from .simplify import solve
1
+ from .simplify import simplify
2
2
  from .base import *
3
3
  from .trig import trig0
4
4
  def diff(equation, var="v_0"):
5
5
  def diffeq(eq):
6
- eq = solve(eq)
6
+ eq = simplify(eq)
7
7
  if "v_" not in str_form(eq):
8
8
  return tree_form("d_0")
9
9
  if eq.name == "f_add":
@@ -65,4 +65,4 @@ def diff(equation, var="v_0"):
65
65
  return TreeNode(equation.name, [helper(child, var) for child in equation.children])
66
66
  equation = diffeq(trig0(equation))
67
67
  equation = helper(equation, var)
68
- return solve(equation)
68
+ return simplify(equation)
mathai/expand.py CHANGED
@@ -1,96 +1,95 @@
1
- import itertools
2
1
  from .base import *
3
- from .simplify import solve, simplify
2
+ from .simplify import simplify
3
+ import itertools
4
+
5
def eliminate_powers(node):
    """Rewrite ``base ** k`` (k a literal integer > 1) as explicit products.

    Works bottom-up and mutates *node* in place: every internal node's
    ``children`` list is replaced by its processed children. For an ``f_pow``
    whose exponent ``frac`` reports as an integer greater than 1:

    * a sum base is expanded multinomially: one product per ordered choice
      of k summands (``itertools.product``), and the resulting ``f_add`` is
      passed through ``simplify``;
    * any other base becomes an ``f_mul`` of k references to that same base
      object.

    Other powers (non-integer, <= 1, or non-numeric exponents) are kept.

    Returns the node that should replace *node* in its parent.
    """
    if not node.children:
        return node

    node.children = [eliminate_powers(child) for child in node.children]
    if node.name != "f_pow":
        return node

    base, exponent = node.children
    ratio = frac(exponent)
    if not ratio or ratio.denominator != 1 or ratio.numerator <= 1:
        return node
    power = ratio.numerator

    if base.name != "f_add":
        return TreeNode("f_mul", [base] * power)

    expanded = []
    for factors in itertools.product(base.children, repeat=power):
        term = factors[0]
        for factor in factors[1:]:
            term = term * factor
        expanded.append(term)
    return simplify(TreeNode("f_add", expanded))
4
35
 
5
- def expand(eq):
6
- if eq is None:
7
- return None
8
-
9
- stack = [(eq, 0)] # (node, stage)
10
- result_map = {} # id(node) -> expanded TreeNode
11
-
12
- while stack:
13
- node, stage = stack.pop()
14
- node_id = id(node)
15
-
16
- # Leaf node
17
- if not node.children and stage == 0:
18
- result_map[node_id] = TreeNode(node.name, [])
19
- continue
20
-
21
- if stage == 0:
22
- # Stage 0: push node back for stage 1 after children
23
- stack.append((node, 1))
24
- # Push children to stack
25
- for child in reversed(node.children):
26
- if id(child) not in result_map:
27
- stack.append((child, 0))
28
- else:
29
- # Stage 1: all children processed
30
- children_expanded = [result_map[id(child)] for child in node.children]
31
-
32
- # Only f_mul or f_pow need special expansion
33
- if node.name in ["f_mul", "f_pow"]:
34
- current_eq = TreeNode(node.name, children_expanded)
35
-
36
- if node.name == "f_pow":
37
- current_eq = TreeNode("f_pow", [current_eq])
38
-
39
- ac = []
40
- addchild = []
41
-
42
- for child in current_eq.children:
43
- tmp5 = [solve(x) for x in factor_generation(child)]
44
- ac += tmp5
45
-
46
- tmp3 = []
47
- for child in ac:
48
- tmp2 = []
49
- if child.name == "f_add":
50
- if child.children != []:
51
- tmp2.extend(child.children)
52
- else:
53
- tmp2 = [child]
54
- else:
55
- tmp3.append(child)
56
- if tmp2 != []:
57
- addchild.append(tmp2)
58
-
59
- tmp4 = 1
60
- for item in tmp3:
61
- tmp4 = tmp4 * item
62
- addchild.append([tmp4])
63
-
64
- def flatten(lst):
65
- flat_list = []
66
- for item in lst:
67
- if isinstance(item, list) and item == []:
68
- continue
69
- if isinstance(item, list):
70
- flat_list.extend(flatten(item))
71
- else:
72
- flat_list.append(item)
73
- return flat_list
74
-
75
- if len(flatten(addchild)) > 0:
76
- add = 0
77
- for prod_items in itertools.product(*addchild):
78
- mul = 1
79
- for item2 in prod_items:
80
- mul = mul * item2
81
- mul = simplify(mul)
82
- add = add + mul
83
- add = simplify(add)
84
- current_eq = simplify(add)
85
- else:
86
- current_eq = simplify(current_eq)
87
-
88
- # Store expanded result
89
- result_map[node_id] = current_eq
90
- else:
91
- # Default: reconstruct node with children
92
- result_map[node_id] = TreeNode(node.name, children_expanded)
93
-
94
- # Return final expanded eq
95
- return result_map[id(eq)]
96
36
 
37
+
38
+ # =====================================================
39
+ # Phase 2: Single distributive rewrite (DEEPEST FIRST)
40
+ # =====================================================
41
+
42
def expand_once(node):
    """
    Apply a single distributive rewrite, searching deepest-first.

    Children are visited left to right, recursively. As soon as one subtree
    reports a rewrite, that child is replaced in place and ``(node, True)``
    is returned without looking further. Only when no descendant changed is
    *node* itself considered: for an ``f_mul``, the first ``f_add`` factor is
    distributed over. Each of its terms is multiplied by the factors to its
    right (left to right), then by the factors to its left (innermost
    first), preserving factor order.

    Returns ``(replacement, changed)``.
    """
    for position, child in enumerate(node.children):
        replacement, changed = expand_once(child)
        if changed:
            node.children[position] = replacement
            return node, True

    if node.name != "f_mul":
        return node, False

    split = next((k for k, child in enumerate(node.children) if child.name == "f_add"), None)
    if split is None:
        return node, False

    before = node.children[:split]
    after = node.children[split + 1:]
    distributed = []
    for summand in node.children[split].children:
        term = summand
        for factor in after:
            term = term * factor
        for factor in before[::-1]:
            term = factor * term
        distributed.append(term)
    return TreeNode("f_add", distributed), True
74
+
75
+
76
+ # =====================================================
77
+ # Phase 3: Global fixed-point driver
78
+ # =====================================================
79
+
80
def expand(eq):
    """Return *eq* with every product distributed over sums, then simplified.

    Steps:

    * simplify;
    * in matmul mode (``TreeNode.matmul`` not None), set the flag to True and
      rewrite ``f_wmul`` nodes as ``f_mul``;
    * flatten, then expand literal integer powers (``eliminate_powers``);
    * repeatedly flatten and apply one distributive step (``expand_once``)
      until nothing changes;
    * simplify once more.

    ``TreeNode.matmul`` is class-level state seen by every other caller, so
    its original value is restored in a ``finally`` block. Otherwise an
    exception during simplification or expansion would leave it set to True.
    """
    saved_matmul = TreeNode.matmul
    try:
        eq = simplify(eq)
        if TreeNode.matmul is not None:
            TreeNode.matmul = True
            eq = tree_form(str_form(eq).replace("f_wmul", "f_mul"))
        eq = flatten_tree(eq)
        eq = eliminate_powers(eq)
        while True:
            eq = flatten_tree(eq)
            eq, changed = expand_once(eq)
            if not changed:
                break
        eq = simplify(eq)
    finally:
        TreeNode.matmul = saved_matmul
    return eq