mathai 0.7.8__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mathai/ode.py CHANGED
@@ -1,6 +1,7 @@
1
+ import itertools
1
2
  from collections import Counter
2
- from .diff import diff
3
- from .factor import factor, factor2
3
+ from .diff import diff, diff2
4
+ from .factor import factor, factor2, term_common2
4
5
  from .expand import expand
5
6
  from .base import *
6
7
  from .fraction import fraction
@@ -9,71 +10,110 @@ import copy
9
10
  from .inverse import inverse
10
11
  from .parser import parse
11
12
 
12
-
13
def rev(eq):
    """Return True when no node of *eq* has ``1/d(v_0)`` or ``1/d(v_1)`` as a factor.

    The check looks at ``factor_generation`` of *eq* itself and then recurses
    into every child. ``inversediff`` inverts both separated sides when
    neither of them passes this check.
    """
    reciprocal_diffs = [tree_form(name).fx("dif") ** -1 for name in ("v_0", "v_1")]
    factors = factor_generation(eq)
    if any(r in factors for r in reciprocal_diffs):
        return False
    return all(rev(child) for child in eq.children)
21
node_count = 100  # search budget consumed by kkk's recursion; inversediff resets it


def kkk(lhs, rhs, depth=5):
    """Search for a rearrangement of ``lhs = rhs`` with the variables separated.

    The sides count as separated when the left side does not contain ``v_1``
    and the right side does not contain ``v_0``, and at least one side still
    contains its own variable. Otherwise the search tries each child of a
    sum/product side in turn:
    - a term of a sum is subtracted from both sides;
    - a factor of a product is divided out of both sides.
    It then recurses on the result.

    The search stops once ``depth`` drops below 0 or the module-level
    ``node_count`` budget is exhausted.

    Returns ``(sides, found)``. ``sides`` is the two-element list
    ``[lhs, rhs]``, and is the separated form when ``found`` is True.
    """
    global node_count
    v = [tree_form("v_0"), tree_form("v_1")]
    lst = [simplify(lhs), simplify(rhs)]
    orig = copy.deepcopy(lst)
    if not contain(lst[0], v[1]) and not contain(lst[1], v[0]):
        # Already separated; it only counts if some side still has its variable.
        if not contain(lst[0], v[0]) and not contain(lst[1], v[1]):
            return lst, False
        return lst, True
    node_count -= 1

    if depth < 0 or node_count < 0:
        return lst, False
    for j in range(2):
        for i in range(2):
            if lst[i].name not in ("f_mul", "f_add"):
                continue
            own, other = v[i], v[1 - i]
            for child in lst[i].children:
                # Pass 0 only moves pieces that belong purely to the other side.
                if j == 0 and (contain(child, own) or not contain(child, other)):
                    continue
                # Never move a piece that belongs purely on this side.
                if contain(child, own) and not contain(child, other):
                    continue

                if lst[i].name == "f_add":
                    lst[i] = lst[i] - child
                    lst[1 - i] = lst[1 - i] - child
                else:  # f_mul
                    lst[i] = lst[i] / child
                    lst[1 - i] = lst[1 - i] / child

                output = kkk(lst[0], lst[1], depth - 1)
                # Start the next attempt from a fresh copy of the original sides.
                # Assigning ``orig`` directly aliased it: later attempts then
                # mutated ``orig`` and accumulated earlier moves.
                lst = copy.deepcopy(orig)

                if output[1]:
                    return output

    return lst, False
67
-
61
def clr(eq):
    """Keep only the multiplicative factors of *eq* whose string form mentions ``f_add``.

    The factors come from ``factor_generation``. The kept factors are
    multiplied back together and the simplified product is returned.
    """
    kept = []
    for term in factor_generation(eq):
        if "f_add" in str_form(term):
            kept.append(term)
    return simplify(product(kept))
68
63
def _regrouped_candidates(eq):
    """Return *eq* followed by regroupings of its terms (helper for inversediff).

    *eq* is expected to be an ``f_add`` node. Each subset of 2..len-2 terms
    whose sum ``term_common2`` turns into a product is recorded. Each pair of
    disjoint recorded subsets then yields the candidate
    ``(group1 + group2 re-factored) + remaining terms``.
    """
    terms = eq.children
    count = len(terms)
    grouped = {}
    for size in range(count - 2, 1, -1):
        # combinations() of range(count) already yields sorted index tuples.
        for idx in itertools.combinations(range(count), size):
            tmp = simplify(term_common2(simplify(summation([terms[x] for x in idx]))))
            if tmp.name == "f_mul":
                grouped[idx] = tmp

    candidates = [eq]
    for first, second in itertools.combinations(list(grouped), 2):
        if not set(first).isdisjoint(second):
            continue
        used = set(first) | set(second)
        rest = summation([terms[x] for x in range(count) if x not in used])
        candidates.append(
            simplify(term_common2(simplify(grouped[first] + grouped[second])) + rest)
        )
    return candidates


def inversediff(lhs, rhs):
    """Try to rewrite ``lhs - rhs = 0`` into a separated form ``f(v_0) - g(v_1) = 0``.

    For a sum, several regroupings of its terms are tried with ``kkk``, each
    with a fresh ``node_count`` budget. Any other expression is passed to
    ``kkk`` directly. When neither separated side passes ``rev`` (both still
    hold reciprocal differentials), both sides are inverted.

    Returns the ``f_eq`` node produced by ``e0``, or None if no separation
    was found.
    """
    global node_count
    eq = simplify(fraction(TreeNode("f_eq", [lhs - rhs, tree_form("d_0")]))).children[0]
    eq = simplify(term_common2(eq))
    eq = clr(eq)

    out = None
    if eq.name == "f_add":
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # search order (and therefore the returned separation) is reproducible.
        # The previous list(set(...)) followed hash order.
        for candidate in dict.fromkeys(_regrouped_candidates(eq)):
            candidate = clr(candidate)
            node_count = 100
            tmp = kkk(candidate, tree_form("d_0"))
            if tmp[1]:
                out = tmp[0]
                break
    else:
        node_count = 100
        tmp = kkk(eq, tree_form("d_0"))
        if tmp[1]:
            out = tmp[0]
    if out is None:
        return None
    out = [simplify(fraction(item)) for item in out]

    if not rev(out[0]) and not rev(out[1]):
        out[0] = fraction(1 / out[0])
        out[1] = fraction(1 / out[1])
    return simplify(e0(out[0] - out[1]))
78
118
 
79
119
  def allocvar():
@@ -88,38 +128,28 @@ def esolve(s):
88
128
  return product([tree_form("s_e")**child for child in s.children]) - tree_form("d_1")
89
129
  return TreeNode(s.name, [esolve(child) for child in s.children])
90
130
def diffsolve_sep2(eq):
    """Integrate a separated differential equation term by term.

    *eq* is an ``f_eq`` node (as returned by ``inversediff``) or None. Its
    left side is treated as a sum of terms. Each term must be a multiple of
    exactly one of ``d(v_0)`` / ``d(v_1)``; its remaining factors (everything
    that is not an ``f_dif`` node) are integrated with respect to that
    variable.

    Returns ``C + sum(integrals) = 0`` as an ``f_eq`` node, with ``C`` taken
    from ``allocvar()``. Returns None when *eq* is None or a term does not
    carry exactly one of the two differentials. Silently dropping such a term
    would produce a wrong solution.
    """
    if eq is None:
        return None
    eq = eq.children[0]
    terms = list(eq.children) if eq.name == "f_add" else [eq]
    dx = tree_form("v_0").fx("dif")
    dy = tree_form("v_1").fx("dif")
    s = [allocvar()]

    for item in terms:
        item = simplify(item)
        factors = factor_generation(item)
        has_dx = dx in factors
        has_dy = dy in factors
        if has_dx == has_dy:
            # Neither differential (or both): this term is not separable on its own.
            return None
        coeff = product([k for k in factors if k.name != "f_dif"])
        var = tree_form("v_0") if has_dx else tree_form("v_1")
        s.append(TreeNode("f_integrate", [coeff, var]))

    return TreeNode("f_eq", [summation(s), tree_form("d_0")])
123
153
def e0(eq):
    """Build the equation node ``eq = 0``."""
    zero = tree_form("d_0")
    return TreeNode("f_eq", [eq, zero])
125
155
  def e1(eq):
@@ -136,11 +166,17 @@ def groupe(eq):
136
166
 
137
167
def diffsolve_sep(eq):
    """Attempt variable separation on the ``f_eq`` node *eq*.

    ``epowersplit`` is applied first. The left side of the result is then
    handed to ``inversediff`` as ``0 - lhs``; the right side of the equation
    is not read here. Returns whatever ``inversediff`` returns (an ``f_eq``
    node or None).
    """
    split = epowersplit(eq)
    lhs_expr = split.children[0].copy_tree()
    return inversediff(tree_form("d_0"), lhs_expr)
141
173
 
142
174
  def diffsolve(eq):
143
175
  orig = eq.copy_tree()
176
+ eq = diff2(eq)
177
+ eq = subs2(eq, order(eq))
178
+ eq = fraction(simplify(fraction(eq)))
179
+
144
180
  if order(eq) == 2:
145
181
  for i in range(2):
146
182
  out = second_order_dif(eq, tree_form(f"v_{i}"), tree_form(f"v_{1-i}"))
@@ -149,28 +185,38 @@ def diffsolve(eq):
149
185
  return orig
150
186
 
151
187
  eq = diffsolve_sep2(diffsolve_sep(eq))
188
+
152
189
  if eq is None:
153
190
  for i in range(2):
154
191
  a = tree_form(f"v_{i}")
155
192
  b = tree_form(f"v_{1-i}")
156
193
  c = tree_form("v_2")
157
- eq2 = replace(orig, b,b*a)
158
- eq2 = replace(eq2, (a*b).fx("dif"), a.fx("dif")*b + b.fx("dif")*a)
159
- eq2 = expand(simplify(fraction(simplify(eq2))))
194
+ eq2 = orig
195
+
196
+ eq2 = subs2(eq2, 1)
197
+ eq2 = replace(eq2, b, b*a)
198
+ eq2 = subs3(eq2)
199
+
200
+ eq2 = simplify(fraction(simplify(eq2)))
201
+
160
202
  eq2 = diffsolve_sep(eq2)
203
+
161
204
  eq2 = diffsolve_sep2(eq2)
162
205
  if eq2 is not None:
163
206
  return e0(TreeNode("f_subs", [replace(eq2.children[0],b,c), c,b/a]).fx("try"))
164
207
  eq = orig
165
-
166
- eq = fraction(eq)
167
- eq = simplify(eq)
168
- for i in range(2):
169
-
170
- out = linear_dif(eq, tree_form(f"v_{i}"), tree_form(f"v_{1-i}"))
171
- if out is not None:
172
- return out
173
- return eq
208
+ eq = simplify(eq)
209
+ eq = subs2(eq, 1)
210
+ eq = fraction(eq)
211
+ for i in range(2):
212
+
213
+ out = linear_dif(eq, tree_form(f"v_{i}"), tree_form(f"v_{1-i}"))
214
+ if out is not None:
215
+ return out
216
+ return orig
217
+ else:
218
+ return eq
219
+
174
220
def clist(x):
    """Flatten Counter *x* into a list, repeating each key as many times as its count."""
    return [*x.elements()]
176
222
  def collect_term(eq, term_lst):
@@ -219,6 +265,19 @@ def order(eq,m=0):
219
265
  out = order(child, m)
220
266
  best = max(out, best)
221
267
  return best
268
def subs2(eq, orde):
    """Rewrite two-argument derivative nodes (``f_dif`` / ``f_pdif``) throughout *eq*.

    With ``orde == 1``, a derivative ``f_dif(y, x)`` or ``f_pdif(y, x)``
    becomes the ratio of differentials ``d(y) / d(x)``. For any other order,
    the node is rebuilt as ``f_dif`` (so ``f_pdif`` becomes ``f_dif``) and
    the rewrite continues into its children. Every other node is rebuilt
    with its children rewritten.
    """
    if eq.name in ["f_dif", "f_pdif"] and len(eq.children) == 2:
        if orde == 1:
            return eq.children[0].fx("dif") / eq.children[1].fx("dif")
        # Recurse into the children, not into the rebuilt node itself: the
        # rebuilt two-child f_dif matches this branch again, so the old
        # ``subs2(TreeNode("f_dif", eq.children), orde)`` never terminated.
        return TreeNode("f_dif", [subs2(child, orde) for child in eq.children])
    return TreeNode(eq.name, [subs2(child, orde) for child in eq.children])
275
def subs3(eq):
    """Expand differentials of sums and products, recursively.

    ``d(u + v + ...)`` becomes ``d(u) + d(v) + ...``. ``d(u*v*...)`` becomes
    the product-rule sum, in which each term differentiates one factor and
    keeps the others. Only single-argument ``f_dif`` nodes (differentials)
    are rewritten. A two-argument ``f_dif`` (derivative with respect to a
    variable) is only recursed into, because expanding it this way would
    drop the variable.
    """
    if eq.name == "f_dif" and len(eq.children) == 1:
        inner = eq.children[0]
        if inner.name == "f_add":
            return summation([subs3(term.fx("dif")) for term in inner.children])
        if inner.name == "f_mul":
            factors = inner.children
            return summation([
                product([
                    subs3(f.fx("dif")) if k == index else f
                    for k, f in enumerate(factors)
                ])
                for index in range(len(factors))
            ])
    return TreeNode(eq.name, [subs3(child) for child in eq.children])
222
281
  def second_order_dif(eq, a, b):
223
282
  eq = simplify(eq)
224
283
  nn = [TreeNode("f_dif", [TreeNode("f_dif", [b,a]),a]), TreeNode("f_dif", [b,a]), b]
@@ -260,7 +319,8 @@ def linear_dif(eq, a, b):
260
319
  for key in out[0].keys():
261
320
  out[0][key] = simplify(out[0][key]/tmp)
262
321
  p, q = out[0][b*a.fx("dif")], -out[0][a.fx("dif")]
263
-
322
+ if contain(p, b) or contain(q, b):
323
+ return None
264
324
  f = tree_form("s_e") ** TreeNode("f_integrate", [p, a])
265
325
  return simplify(TreeNode("f_eq", [b*f, TreeNode("f_integrate", [q*f, a])+allocvar()]))
266
326
  return None
mathai/parser.py CHANGED
@@ -87,7 +87,6 @@ def parse(equation, funclist=None):
87
87
  parser_main = Lark(grammar2, start='start', parser='lalr')
88
88
  parse_tree = parser_main.parse(equation)
89
89
 
90
- # Convert Lark tree to TreeNode
91
90
  def convert_to_treenode(parse_tree):
92
91
  if isinstance(parse_tree, Tree):
93
92
  node = TreeNode(parse_tree.data)
@@ -96,7 +95,6 @@ def parse(equation, funclist=None):
96
95
  else:
97
96
  return TreeNode(str(parse_tree))
98
97
 
99
- # Flatten unnecessary nodes like pass_through
100
98
  def remove_past(equation):
101
99
  if equation.name in {"number", "paren", "func", "variable", "pass_through", "cnumber", "string", "matrix"}:
102
100
  if len(equation.children) == 1:
@@ -107,7 +105,6 @@ def parse(equation, funclist=None):
107
105
  equation.children = [remove_past(child) for child in equation.children]
108
106
  return equation
109
107
 
110
- # Handle indices if any
111
108
  def prefixindex(equation):
112
109
  if equation.name == "base" and len(equation.children) > 1:
113
110
  return TreeNode("index", [equation.children[0]] + equation.children[1].children)
@@ -117,16 +114,15 @@ def parse(equation, funclist=None):
117
114
  tree_node = remove_past(tree_node)
118
115
  tree_node = prefixindex(tree_node)
119
116
 
120
- # Convert function names and constants
121
117
  def fxchange(tree_node):
122
118
  tmp3 = funclist if funclist is not None else []
123
119
  if tree_node.name == "neg":
124
120
  child = fxchange(tree_node.children[0])
125
- # if the child is a number, make it negative
121
+
126
122
  if child.name.startswith("d_") and re.match(r"d_\d+(\.\d+)?$", child.name):
127
123
  return TreeNode("d_" + str(-int(child.name[2:])))
128
124
  else:
129
- # otherwise subtract from zero
125
+
130
126
  return TreeNode("f_sub", [tree_form("d_0"), child])
131
127
  if tree_node.name == "pass_through":
132
128
  return fxchange(tree_node.children[0])
@@ -137,11 +133,9 @@ def parse(equation, funclist=None):
137
133
 
138
134
  tree_node = fxchange(tree_node)
139
135
 
140
- # Replace common constants
141
136
  for const in ["e","pi","kc","em","ec","anot","hbar","false","true","i","nabla"]:
142
137
  tree_node = replace(tree_node, tree_form("d_"+const), tree_form("s_"+const))
143
138
 
144
- # Map letters to variables
145
139
  for i, c in enumerate(["x","y","z"] + [chr(x+ord("a")) for x in range(0,23)]):
146
140
  tree_node = replace(tree_node, tree_form("d_"+c), tree_form("v_"+str(i)))
147
141
  for i, c in enumerate([chr(x+ord("A")) for x in range(0,26)]):
mathai/pde.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from .expand import expand
2
- from .ode import diffsolve, inversediff, order, groupe, epowersplit
2
+ from .ode import diffsolve, inversediff, order, groupe, epowersplit, subs2
3
3
  from .base import *
4
4
  from .simplify import simplify
5
5
  from .diff import diff2
@@ -16,13 +16,7 @@ def capital2(eq):
16
16
  if out is not None:
17
17
  return out
18
18
  return None
19
- def subs2(eq, orde):
20
- if eq.name == "f_pdif":
21
- if orde == 1:
22
- return eq.children[0].fx("dif")/eq.children[1].fx("dif")
23
- else:
24
- return subs2(TreeNode("f_dif", eq.children), orde)
25
- return TreeNode(eq.name, [subs2(child, orde) for child in eq.children])
19
+
26
20
  def capital(eq):
27
21
  if eq.name[:2] == "f_" and eq.name != eq.name.lower():
28
22
  return eq
@@ -36,6 +30,8 @@ def abs_const(eq):
36
30
  return tree_form("v_101")*eq.children[0]
37
31
  return TreeNode(eq.name, [abs_const(child) for child in eq.children])
38
32
  def want(eq):
33
+ if eq is None:
34
+ return None
39
35
  if eq.name == "f_want":
40
36
 
41
37
  eq2 = eq.children[0]
@@ -95,8 +91,7 @@ def pde_sep(eq):
95
91
  r = capital(out[i])
96
92
  lst.append(r)
97
93
  out[i] = replace(out[i], r, tree_form(f"v_{1-i}"))
98
- out[i] = subs2(out[i], order(out[i]))
99
-
94
+ out[i] = tree_form(str_form(out[i]).replace("f_pdif", "f_dif"))
100
95
  out[i] = diffsolve(out[i])
101
96
 
102
97
  out[i] = replace(out[i], tree_form(f"v_{1-i}"), r)