mathai 0.4.8__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mathai/limit.py CHANGED
@@ -1,12 +1,12 @@
1
1
  from .structure import structure
2
2
  from .base import *
3
3
  from .parser import parse
4
- from .simplify import simplify, solve
4
+ from .simplify import simplify
5
5
  from .expand import expand
6
6
  from .diff import diff
7
7
  from .trig import trig0
8
8
  from .fraction import fraction
9
- from .printeq import printeq_str
9
+ from .printeq import printeq
10
10
  tab=0
11
11
  def substitute_val(eq, val, var="v_0"):
12
12
  eq = replace(eq, tree_form(var), tree_form("d_"+str(val)))
@@ -33,100 +33,124 @@ def check(num, den, var):
33
33
  return simplify(n/d)
34
34
  return False
35
35
  def lhospital(num, den, steps,var):
36
- logs = []
37
36
 
38
37
  out = check(num, den, var)
39
38
 
40
39
  if isinstance(out, TreeNode):
41
- return out,[]
40
+ return out
42
41
  for _ in range(steps):
43
42
  num2, den2 = map(lambda e: simplify(diff(e, var)), (num, den))
44
43
  out = check(num2, den2, var)
45
44
  if out is True:
46
45
  num, den = num2, den2
47
- logs += [(0,"lim x->0 "+printeq_str(simplify(num/den)))]
48
46
  continue
49
47
  if out is False:
50
48
  eq2 = simplify(fraction(simplify(num/den)))
51
- return eq2,logs
52
- return out,logs
49
+ return eq2
50
+ return out
53
51
  def lhospital2(eq, var):
54
52
  eq= simplify(eq)
55
53
  if eq is None:
56
54
  return None
57
55
  if not contain(eq, tree_form(var)):
58
- return eq,[]
56
+ return eq
59
57
  num, dem = [simplify(item) for item in num_dem(eq)]
60
58
  if num is None or dem is None:
61
- return eq,[]
59
+ return eq
62
60
 
63
61
  return lhospital(num, dem, 10,var)
64
- ls = [parse("sin(A)"), parse("A^B-1"),parse("log(1+A)"), parse("cos(A)")]
65
- ls= [simplify(item) for item in ls]
66
-
67
- def approx(eq, var):
68
- n, d= num_dem(eq)
69
- n, d = solve(n), solve(d)
70
- n, d = expand(n), expand(d)
71
- out = []
72
- for equation in [n, d]:
73
- for item in factor_generation(equation):
74
- tmp = structure(item, ls[0])
75
- if tmp is not None and contain(tmp["v_-1"], var):
76
- item2 = substitute_val(tmp["v_-1"], 0, var.name)
77
- if tree_form("d_0") == expand(simplify(item2)):
78
- equation = equation/item
79
- equation = equation*tmp["v_-1"]
80
- break
81
- elif tree_form("d_0") == expand(simplify(tree_form("s_pi") - item2)):
82
- equation = equation/item
83
- equation = equation*(tree_form("s_pi") - tmp["v_-1"])
84
- break
85
- tmp = structure(item, ls[1])
86
- if tmp is not None and contain(tmp["v_-1"], var) and not contain(tmp["v_-2"], var):
87
- item2 = substitute_val(tmp["v_-1"], 0, var.name)
88
- item2 = expand(solve(item2))
89
- if tree_form("d_0") == item2:
90
- equation = equation/item
91
- equation = solve(equation*tmp["v_-1"]*tmp["v_-2"].fx("log"))
92
- break
93
- tmp = structure(item, ls[2])
94
- if tmp is not None and contain(tmp["v_-1"], var):
95
-
96
- item2 = substitute_val(tmp["v_-1"], 0, var.name)
97
- item2 = expand(solve(item2))
98
- if tree_form("d_0") == item2:
99
- equation = equation/item
100
- equation = solve(equation*tmp["v_-1"])
101
- break
102
- tmp = structure(item, ls[3])
103
- if tmp is not None and contain(tmp["v_-1"], var):
104
- item2 = substitute_val(item, 0, var.name)
105
-
106
- if tree_form("d_0") == expand(solve(item2)):
107
-
108
- equation = equation/item
109
- equation = equation*(tree_form("d_1") - tmp["v_-1"]**tree_form("d_2"))
110
- break
62
+ def limit0(equation):
63
+ if equation.name == "f_ref":
64
+ return equation
65
+ eq2 = equation
66
+ g = ["f_limit", "f_limitpinf", "f_limitninf"]
67
+ if eq2.name in g and contain(eq2.children[0], eq2.children[1]):
68
+ equation = eq2.children[0]
69
+ wrt = eq2.children[1]
70
+ lst = factor_generation(equation)
71
+
72
+ lst_const = [item for item in lst if not contain(item, wrt)]
73
+ if lst_const != []:
74
+
75
+ equation = product([item for item in lst if contain(item, wrt)]).copy_tree()
76
+ const = product(lst_const)
77
+ const = simplify(const)
78
+
79
+ if not contain(const, tree_form("s_i")):
111
80
 
112
- equation = solve(equation)
113
- out.append(equation)
114
- return simplify(out[0]/out[1])
115
- def approx_limit(equation, var):
116
- return dowhile(equation, lambda x: approx(x, var))
81
+ return limit0(TreeNode(equation.name,[equation, wrt])) *const
82
+ equation = eq2
83
+ return TreeNode(equation.name, [limit0(child) for child in equation.children])
84
+ def limit2(eq):
85
+ g = ["f_limit", "f_limitpinf", "f_limitninf"]
86
+ if eq.name in g and eq.children[0].name == "f_add":
87
+ eq = summation([TreeNode(eq.name, [child, eq.children[1]]) for child in eq.children[0].children])
88
+ return TreeNode(eq.name, [limit2(child) for child in eq.children])
89
+ def limit1(eq):
90
+ if eq.name == "f_limit":
91
+ a, b = limit(eq.children[0], eq.children[1].name)
92
+ if b:
93
+ return a
94
+ else:
95
+ return TreeNode(eq.name, [a, eq.children[1]])
96
+ return TreeNode(eq.name, [limit1(child) for child in eq.children])
97
+ def fxinf(eq):
98
+ if eq is None:
99
+ return None
100
+ if eq.name == "f_add":
101
+ if tree_form("s_inf") in eq.children and -tree_form("s_inf") in eq.children:
102
+ return None
103
+ if tree_form("s_inf") in eq.children:
104
+ return tree_form("s_inf")
105
+ if -tree_form("s_inf") in eq.children:
106
+ return -tree_form("s_inf")
107
+ if eq.name == "f_mul":
108
+ lst = factor_generation(eq)
109
+ if tree_form("s_inf") in lst:
110
+ eq = TreeNode(eq.name, [dowhile(child, fxinf) for child in eq.children])
111
+ if None in eq.children:
112
+ return None
113
+ lst = factor_generation(eq)
114
+ if tree_form("d_0") in lst:
115
+ return tree_form("d_0")
116
+ lst2 = [item for item in lst if "v_" in str_form(item)]
117
+ sign = True
118
+ if len([item for item in lst if "v_" not in str_form(item) and not contain(item, tree_form("s_inf")) and compute(item)<0]) % 2==1:
119
+ sign = False
120
+ if lst2 == []:
121
+ if sign:
122
+ return tree_form("s_inf")
123
+ else:
124
+ return -tree_form("s_inf")
125
+ if eq.name == "f_pow":
126
+ if "v_" not in str_form(eq.children[0]) and not contain(eq.children[0], tree_form("s_inf")) and compute(eq.children[0])>0:
127
+ if eq.children[1] == -tree_form("s_inf"):
128
+ return tree_form("d_0")
129
+
130
+ eq = TreeNode(eq.name, [fxinf(child) for child in eq.children])
131
+ if None in eq.children:
132
+ return None
133
+ return eq
134
+ def limit3(eq):
135
+
136
+ if eq.name == "f_limitpinf":
137
+ if not contain(eq, eq.children[1]):
138
+ return eq.children[0]
139
+ eq2 = replace(eq.children[0], eq.children[1], tree_form("s_inf"))
140
+ eq2 = dowhile(eq2, fxinf)
141
+ if not contain(eq2, tree_form("s_inf")) and not contain(eq2, eq.children[1]):
142
+ return simplify(eq2)
143
+ return TreeNode(eq.name, [limit3(child) for child in eq.children])
117
144
 
118
145
  def limit(equation, var="v_0"):
119
- logs = [(0,"lim x->0 "+printeq_str(simplify(equation)))]
146
+
120
147
  eq2 = dowhile(replace(equation, tree_form(var), tree_form("d_0")), lambda x: trig0(simplify(x)))
121
148
  if eq2 is not None and not contain(equation, tree_form(var)):
122
- return eq2,logs
149
+ return eq2, True
123
150
 
124
- equation, tmp = lhospital2(equation, var)
151
+ equation = lhospital2(equation, var)
125
152
  equation = simplify(expand(simplify(equation)))
126
153
  if not contain(equation, tree_form(var)):
127
- return equation,logs+tmp
128
- '''
129
- if equation.name == "f_add":
130
- return simplify(summation([limit(child, var) for child in equation.children]))
131
- '''
132
- return equation,logs+tmp
154
+ return equation, True
155
+
156
+ return equation, False
mathai/linear.py CHANGED
@@ -1,9 +1,12 @@
1
+ from .inverse import inverse
2
+ import itertools
1
3
  from .diff import diff
2
- from .simplify import simplify, solve
4
+ from .simplify import simplify
3
5
  from .fraction import fraction
4
6
  from .expand import expand
5
7
  from .base import *
6
8
  from .factor import factorconst
9
+ from .tool import poly
7
10
  def ss(eq):
8
11
  return dowhile(eq, lambda x: fraction(expand(simplify(x))))
9
12
  def rref(matrix):
@@ -31,38 +34,16 @@ def rref(matrix):
31
34
  return matrix
32
35
  def islinear(eq, fxconst):
33
36
  eq =simplify(eq)
34
- if eq.name == "f_pow" and fxconst(eq):#"v_" in str_form(eq):
35
- return False
36
- for child in eq.children:
37
- out = islinear(child, fxconst)
38
- if not out:
39
- return out
40
- return True
37
+ if all(fxconst(tree_form(item)) and poly(eq, item) is not None and len(poly(eq, item)) <= 2 for item in vlist(eq)):
38
+ return True
39
+ return False
41
40
  def linear(eqlist, fxconst):
42
- final = []
43
- extra = []
44
- for i in range(len(eqlist)-1,-1,-1):
45
- if eqlist[i].name == "f_mul" and not islinear(expand2(eqlist[i]), fxconst):
46
- if "v_" in str_form(eqlist[i]):
47
- eqlist[i] = TreeNode("f_mul", [child for child in eqlist[i].children if fxconst(child)])
48
- if all(islinear(child, fxconst) for child in eqlist[i].children):
49
- for child in eqlist[i].children:
50
- extra.append(TreeNode("f_eq", [child, tree_form("d_0")]))
51
- eqlist.pop(i)
52
- else:
53
- final.append(TreeNode("f_eq", [eqlist[i], tree_form("d_0")]))
54
- eqlist.pop(i)
41
+ orig = [item.copy_tree() for item in eqlist]
42
+ #eqlist = [eq for eq in eqlist if fxconst(eq)]
55
43
 
56
- if extra != []:
57
- final.append(TreeNode("f_or", extra))
58
- if eqlist == []:
59
- if len(final)==1:
60
-
61
- return final[0]
62
- return TreeNode("f_and", final)
63
- eqlist = [eq for eq in eqlist if fxconst(eq)]
64
- if not all(islinear(eq, fxconst) for eq in eqlist):
65
- return TreeNode("f_and", copy.deepcopy(final+eqlist))
44
+ if eqlist == [] or not all(islinear(eq, fxconst) for eq in eqlist):
45
+ return None
46
+ #return TreeNode("f_and", [TreeNode("f_eq", [x, tree_form("d_0")]) for x in orig])
66
47
  vl = []
67
48
  def varlist(eq, fxconst):
68
49
  nonlocal vl
@@ -75,7 +56,7 @@ def linear(eqlist, fxconst):
75
56
  vl = list(set(vl))
76
57
 
77
58
  if len(vl) > len(eqlist):
78
- return TreeNode("f_and", final+[TreeNode("f_eq", [x, tree_form("d_0")]) for x in eqlist])
59
+ return TreeNode("f_and", [TreeNode("f_eq", [x, tree_form("d_0")]) for x in eqlist])
79
60
  m = []
80
61
  for eq in eqlist:
81
62
  s = copy.deepcopy(eq)
@@ -94,63 +75,91 @@ def linear(eqlist, fxconst):
94
75
  for i in range(len(m)):
95
76
  for j in range(len(m[i])):
96
77
  m[i][j] = fraction(m[i][j])
97
-
98
- for item in m:
99
- if all(item2==tree_form("d_0") for item2 in item[:-1]) and item[-1] != tree_form("d_0"):
100
- return tree_form("s_false")
101
78
 
102
79
  output = []
103
80
  for index, row in enumerate(m):
104
- count = 0
105
- for item in row[:-1]:
106
- if item == tree_form("d_1"):
107
- count += 1
108
- if count == 2:
109
- break
110
- elif item == tree_form("d_0") and count == 1:
111
- break
112
- if count == 0:
113
- continue
114
- output.append(tree_form(vl[index])+row[-1])
115
- if len(output) == 1 and len(final)==0:
81
+ if not all(item == 0 for item in row[:-1]):
82
+ output.append(summation([tree_form(vl[index2])*coeff for index2, coeff in enumerate(row[:-1])])+row[-1])
83
+ elif row[-1] != 0:
84
+ return tree_form("s_false")
85
+ if len(output) == 1:
116
86
  return TreeNode("f_eq", [output[0], tree_form("d_0")])
117
- return TreeNode("f_and", final+[TreeNode("f_eq", [x, tree_form("d_0")]) for x in output])
87
+ if len(output) == 0:
88
+ return tree_form("s_false")
89
+ return TreeNode("f_and", [TreeNode("f_eq", [x, tree_form("d_0")]) for x in output])
90
+ def order_collinear_indices(points, idx):
91
+ """
92
+ Arrange a subset of collinear points (given by indices) along their line.
93
+
94
+ points: list of (x, y) tuples
95
+ idx: list of indices referring to points
96
+ Returns: list of indices sorted along the line
97
+ """
98
+ if len(idx) <= 1:
99
+ return idx[:]
100
+
101
+ # Take first two points from the subset to define the line
102
+ p0, p1 = points[idx[0]], points[idx[1]]
103
+ dx, dy = p1[0] - p0[0], p1[1] - p0[1]
104
+
105
+ # Projection factor for sorting
106
+ def projection_factor(i):
107
+ vx, vy = points[i][0] - p0[0], points[i][1] - p0[1]
108
+ return compute((vx * dx + vy * dy) / (dx**2 + dy**2))
109
+
110
+ # Sort indices by projection
111
+ sorted_idx = sorted(idx, key=projection_factor)
112
+ return list(sorted_idx)
113
+ def linear_or(eq):
114
+ eqlst =[]
115
+ if eq.name != "f_or":
116
+ eqlst = [eq]
117
+ else:
118
+ eqlst = eq.children
119
+ v = vlist(eq)
120
+ p = []
121
+ line = {}
122
+ for i in range(len(eqlst)):
123
+ line[i] = []
124
+ for item in itertools.combinations(enumerate(eqlst), 2):
125
+ x, y = item[0][0], item[1][0]
126
+ item = [item[0][1], item[1][1]]
127
+
128
+ out = linear_solve(TreeNode("f_and", list(item)))
118
129
 
119
- def rmeq(eq):
120
- if eq.name == "f_eq":
121
- return rmeq(eq.children[0])
122
- return TreeNode(eq.name, [rmeq(child) for child in eq.children])
130
+ if out is None:
131
+ return None
132
+
133
+ if out.name == "f_and" and all(len(vlist(child)) == 1 for child in out.children) and set(vlist(out)) == set(v) and all(len(vlist(simplify(child))) >0 for child in out.children):
134
+ t = {}
135
+ for child in out.children:
136
+ t[v.index(vlist(child)[0])] = simplify(inverse(child.children[0], vlist(child)[0]))
137
+ t2 = []
138
+ for key in sorted(t.keys()):
139
+ t2.append(t[key])
140
+ t2 = tuple(t2)
141
+ if t2 not in p:
142
+ p.append(t2)
143
+ line[x] += [p.index(t2)]
144
+ line[y] += [p.index(t2)]
145
+ line2 = []
146
+ for key in sorted(line.keys()):
147
+ line2.append(order_collinear_indices(p, list(set(line[key]))))
123
148
 
124
- def mat0(eq, lst=None):
125
- def findeq(eq):
126
- out = []
127
- if "f_list" not in str_form(eq) and "f_eq" not in str_form(eq):
128
- return [str_form(eq)]
129
- else:
130
- for child in eq.children:
131
- out += findeq(child)
132
- return out
133
- eqlist = findeq(eq)
134
- eqlist = [tree_form(x) for x in eqlist]
135
- eqlist = [rmeq(x) for x in eqlist]
136
- eqlist = [TreeNode("f_mul", factor_generation(x)) for x in eqlist if x != tree_form("d_0")]
137
- eqlist = [x.children[0] if len(x.children) == 1 else x for x in eqlist]
149
+ return v, p, line2, eqlst
150
+ def linear_solve(eq, lst=None):
151
+ eq = simplify(eq)
152
+ eqlist = []
153
+ if eq.name =="f_and" and all(child.name == "f_eq" and child.children[1] == 0 for child in eq.children):
154
+
155
+ eqlist = [child.children[0] for child in eq.children]
156
+ else:
157
+ return eq
138
158
  out = None
139
-
140
159
  if lst is None:
141
160
  out = linear(copy.deepcopy(eqlist), lambda x: "v_" in str_form(x))
142
161
  else:
143
162
  out = linear(copy.deepcopy(eqlist), lambda x: any(contain(x, item) for item in lst))
144
- def rms(eq):
145
- if eq.name in ["f_and", "f_or"] and len(eq.children) == 1:
146
- return eq.children[0]
147
- return TreeNode(eq.name, [rms(child) for child in eq.children])
148
- return rms(out)
149
- def linear_solve(eq, lst=None):
150
- if eq.name == "f_and":
151
- eq2 = copy.deepcopy(eq)
152
- eq2.name = "f_list"
153
- return mat0(eq2, lst)
154
- elif eq.name == "f_eq":
155
- return mat0(eq, lst)
156
- return TreeNode(eq.name, [linear_solve(child, lst) for child in eq.children])
163
+ if out is None:
164
+ return None
165
+ return simplify(out)
mathai/logic.py CHANGED
@@ -1,6 +1,12 @@
1
1
  import itertools
2
2
  from .base import *
3
-
3
+ def c(eq):
4
+ eq = logic1(eq)
5
+ eq = dowhile(eq, logic0)
6
+ eq = dowhile(eq, logic2)
7
+ return eq
8
+ def logic_n(eq):
9
+ return dowhile(eq, c)
4
10
  def logic0(eq):
5
11
  if eq.children is None or len(eq.children)==0:
6
12
  return eq
mathai/matrix.py ADDED
@@ -0,0 +1,228 @@
1
+ from .base import *
2
+ import copy
3
+ from .simplify import simplify
4
+ import itertools
5
+
6
+ # ---------- tree <-> python list ----------
7
+ def tree_to_py(node):
8
+ if node.name=="f_list":
9
+ return [tree_to_py(c) for c in node.children]
10
+ return node
11
+
12
+ def py_to_tree(obj):
13
+ if isinstance(obj,list):
14
+ return TreeNode("f_list",[py_to_tree(x) for x in obj])
15
+ return obj
16
+
17
+ # ---------- shape detection ----------
18
+ def is_vector(x):
19
+ return isinstance(x,list) and all(isinstance(item,TreeNode) for item in x)
20
+ def is_mat(x):
21
+ return isinstance(x,list) and all(isinstance(item,list) for item in x)
22
+ def is_matrix(x):
23
+ return isinstance(x, list) and all(isinstance(item, list) and (is_mat(item) or is_vector(item)) for item in x)
24
+
25
+
26
+ # ---------- algebra primitives ----------
27
+ def dot(u,v):
28
+ if len(u)!=len(v):
29
+ raise ValueError("Vector size mismatch")
30
+ s = tree_form("d_0")
31
+ for a,b in zip(u,v):
32
+ s = TreeNode("f_add",[s,TreeNode("f_mul",[a,b])])
33
+ return s
34
+
35
+ def matmul(A, B):
36
+ # A: n × m
37
+ # B: m × p
38
+
39
+ n = len(A)
40
+ m = len(A[0])
41
+ p = len(B[0])
42
+
43
+ if m != len(B):
44
+ raise ValueError("Matrix dimension mismatch")
45
+
46
+ C = [[tree_form("d_0") for _ in range(p)] for _ in range(n)]
47
+
48
+ for i in range(n):
49
+ for j in range(p):
50
+ for k in range(m):
51
+ C[i][j] = TreeNode(
52
+ "f_add",
53
+ [C[i][j], TreeNode("f_mul", [A[i][k], B[k][j]])]
54
+ )
55
+ return C
56
+
57
+ # ---------- promotion ----------
58
+ def promote(node):
59
+ if node.name=="f_list":
60
+ return tree_to_py(node)
61
+ return node
62
+ def contains_neg(node):
63
+ if isinstance(node, list):
64
+ return False
65
+ if node.name.startswith("v_-"):
66
+ return False
67
+ for child in node.children:
68
+ if not contains_neg(child):
69
+ return False
70
+ return True
71
+ # ---------- multiplication (fully simplified) ----------
72
+ def multiply(left,right):
73
+ if left == tree_form("d_1"):
74
+ return right
75
+ if right == tree_form("d_1"):
76
+ return left
77
+ left2, right2 = left, right
78
+ if left2.name != "f_pow":
79
+ left2 = left2 ** 1
80
+ if right2.name != "f_pow":
81
+ right2 = right2 ** 1
82
+ if left2.name == "f_pow" and right2.name == "f_pow" and left2.children[0]==right2.children[0]:
83
+ return simplify(left2.children[0]**(left2.children[1]+right2.children[1]))
84
+ A,B = promote(left), promote(right)
85
+
86
+ # vector · vector
87
+ if is_vector(A) and is_vector(B):
88
+ return dot(A,B)
89
+ # matrix × matrix
90
+ if is_matrix(A) and is_matrix(B):
91
+ return py_to_tree(matmul(A,B))
92
+ # scalar × vector
93
+ for _ in range(2):
94
+ if contains_neg(A) and is_vector(B):
95
+ return py_to_tree([TreeNode("f_mul",[A,x]) for x in B])
96
+ # scalar × matrix
97
+ if contains_neg(A) and is_matrix(B):
98
+ return py_to_tree([[TreeNode("f_mul",[A,x]) for x in row] for row in B])
99
+ A, B = B, A
100
+ return None
101
+ def add_vec(A, B):
102
+ if len(A) != len(B):
103
+ raise ValueError("Vector dimension mismatch")
104
+
105
+ return [
106
+ TreeNode("f_add", [A[i], B[i]])
107
+ for i in range(len(A))
108
+ ]
109
+ def matadd(A, B):
110
+ if len(A) != len(B) or len(A[0]) != len(B[0]):
111
+ raise ValueError("Matrix dimension mismatch")
112
+
113
+ n = len(A)
114
+ m = len(A[0])
115
+
116
+ return [
117
+ [
118
+ TreeNode("f_add", [A[i][j], B[i][j]])
119
+ for j in range(m)
120
+ ]
121
+ for i in range(n)
122
+ ]
123
+ def addition(left,right):
124
+ A,B = promote(left), promote(right)
125
+ # vector + vector
126
+ if is_vector(A) and is_vector(B):
127
+ return add_vec(A,B)
128
+ # matrix + matrix
129
+ if is_matrix(A) and is_matrix(B):
130
+ return py_to_tree(matadd(A,B))
131
+ return None
132
+ '''
133
+ def fold_wmul(eq):
134
+ if eq.name == "f_pow" and eq.children[1].name.startswith("d_"):
135
+ n = int(eq.children[1].name[2:])
136
+ if n == 1:
137
+ eq = eq.children[0]
138
+ elif n > 1:
139
+ tmp = promote(eq.children[0])
140
+ if is_matrix(tmp):
141
+ orig =tmp
142
+ for i in range(n-1):
143
+ tmp = matmul(orig, tmp)
144
+ eq = py_to_tree(tmp)
145
+ elif eq.name in ["f_wmul", "f_add"]:
146
+ if len(eq.children) == 1:
147
+ eq = eq.children[0]
148
+ else:
149
+ i = len(eq.children)-1
150
+ while i>0:
151
+ if eq.name == "f_wmul":
152
+ out = multiply(eq.children[i-1], eq.children[i])
153
+ else:
154
+ out = addition(eq.children[i-1], eq.children[i])
155
+ if out is not None:
156
+ eq.children.pop(i)
157
+ eq.children.pop(i-1)
158
+ eq.children.insert(i-1,out)
159
+ i = i-1
160
+ return TreeNode(eq.name, [fold_wmul(child) for child in eq.children])
161
+ '''
162
+ def fold_wmul(root):
163
+ # Post-order traversal using explicit stack
164
+ stack = [(root, False)]
165
+ newnode = {}
166
+
167
+ while stack:
168
+ node, visited = stack.pop()
169
+
170
+ if not visited:
171
+ # First time: push back as visited, then children
172
+ stack.append((node, True))
173
+ for child in node.children:
174
+ stack.append((child, False))
175
+ else:
176
+ # All children already processed
177
+ children = [newnode[c] for c in node.children]
178
+ eq = TreeNode(node.name, children)
179
+
180
+ # ---- original rewrite logic ----
181
+
182
+ if eq.name == "f_pow" and eq.children[1].name.startswith("d_"):
183
+ n = int(eq.children[1].name[2:])
184
+ if n == 1:
185
+ eq = eq.children[0]
186
+ elif n > 1:
187
+ tmp = promote(eq.children[0])
188
+ if is_matrix(tmp):
189
+ orig = tmp
190
+ for _ in range(n - 1):
191
+ tmp = matmul(orig, tmp)
192
+ eq = py_to_tree(tmp)
193
+
194
+ elif eq.name in ["f_wmul", "f_add"]:
195
+ if len(eq.children) == 1:
196
+ eq = eq.children[0]
197
+ else:
198
+ i = len(eq.children) - 1
199
+ while i > 0:
200
+ if eq.name == "f_wmul":
201
+ out = multiply(eq.children[i - 1], eq.children[i])
202
+ else:
203
+ out = addition(eq.children[i - 1], eq.children[i])
204
+
205
+ if out is not None:
206
+ eq.children.pop(i)
207
+ eq.children.pop(i - 1)
208
+ eq.children.insert(i - 1, out)
209
+ i -= 1
210
+
211
+ # --------------------------------
212
+
213
+ newnode[node] = eq
214
+
215
+ return newnode[root]
216
+
217
+ def flat(eq):
218
+ return flatten_tree(eq)
219
+ def use(eq):
220
+ return TreeNode(eq.name, [use(child) for child in eq.children])
221
+ def _matrix_solve(eq):
222
+ if TreeNode.matmul == True:
223
+ TreeNode.matmul = False
224
+ eq = dowhile(eq, lambda x: fold_wmul(use(flat(x))))
225
+ TreeNode.matmul = True
226
+ return eq
227
+ def matrix_solve(eq):
228
+ return _matrix_solve(eq)