mathai 0.2.9__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mathai/parser.py CHANGED
@@ -1,154 +1,158 @@
1
- import copy
2
- from lark import Lark, Tree
3
- from .base import *
4
- import re
5
-
6
- grammar = """
7
- ?start: expr
8
-
9
- ?expr: logic_equiv
10
-
11
- ?logic_equiv: logic_imply
12
- | logic_equiv "<->" logic_imply -> equiv
13
-
14
- ?logic_imply: logic_or
15
- | logic_or "->" logic_imply -> imply
16
-
17
- ?logic_or: logic_and
18
- | logic_or "|" logic_and -> or
19
- | logic_or "||" logic_and -> or
20
-
21
- ?logic_and: comparison
22
- | logic_and "&" comparison -> and
23
- | logic_and "&&" comparison -> and
24
-
25
- ?comparison: arithmetic
26
- | comparison "=" arithmetic -> eq
27
- | comparison "<" arithmetic -> lt
28
- | comparison ">" arithmetic -> gt
29
- | comparison "<=" arithmetic -> le
30
- | comparison ">=" arithmetic -> ge
31
-
32
- ?arithmetic: arithmetic "+" term -> add
33
- | arithmetic "-" term -> sub
34
- | term
35
-
36
- ?term: term "*" power -> mul
37
- | term "/" power -> div
38
- | term "." power -> dot
39
- | power
40
-
41
- ?power: power "^" factor -> pow
42
- | power "**" factor -> pow
43
- | factor
44
-
45
- ?factor: "-" factor -> neg
46
- | "+" factor -> pass_through
47
- | atom
48
-
49
- ?atom: NUMBER -> number
50
- | VARIABLE -> variable
51
- | FUNC_NAME "(" [expr ("," expr)*] ")" -> func
52
- | "[" [expr ("," expr)*] "]" -> list
53
- | "(" expr ")" -> paren
54
- | CNUMBER -> cnumber
55
- | ESCAPED_STRING -> string
56
- | CAPITAL_ID -> matrix
57
-
58
- FUNC_NAME: "midpoint" | "forall" | "imply" | "exist" | "len" | "sum" | "angle" | "line" | "sum2" | "charge" | "electricfield" | "perm" | "point" | "equationrhs" | "transpose" | "equationlhs" | "equation" | "error" | "covariance" | "variance" | "expect" | "mag" | "rad" | "laplace" | "diverge" | "pdif" | "gradient" | "curl" | "point1" | "point2" | "dot" | "point3" | "line1" | "line2" | "line3" | "sin" | "circumcenter" | "eqtri" | "linesegment" | "cos" | "tan" | "log" | "sqrt" | "integrate" | "dif" | "abs" | "cosec" | "sec" | "cot" | "arctan" | "arcsin" | "arccos" | "log10"
59
-
60
- VARIABLE: /[a-z]/ | "nabla" | "pi" | "kc" | "hbar" | "em" | "ec" | "anot" | "false" | "true"
61
-
62
- CAPITAL_ID: /[A-Z]/
63
-
64
- CNUMBER: /c[0-9]+/
65
-
66
- %import common.NUMBER
67
- %import common.ESCAPED_STRING
68
- %import common.WS_INLINE
69
- %ignore WS_INLINE
70
- """
71
-
72
- def parse(equation, funclist=None):
73
- equation = copy.copy(equation.replace(" ", ""))
74
- grammar2 = copy.deepcopy(grammar)
75
- if funclist is not None:
76
- output = grammar2.split("\n")
77
- for i in range(len(output)):
78
- if "FUNC_NAME:" in output[i]:
79
- output[i] = output[i].replace("FUNC_NAME: ", "FUNC_NAME: " + " | ".join(['"' + x + '"' for x in funclist]) + " | ")
80
- grammar2 = "\n".join(output)
81
-
82
- parser_main = Lark(grammar2, start='start', parser='lalr')
83
- parse_tree = parser_main.parse(equation)
84
-
85
- # Convert Lark tree to TreeNode
86
- def convert_to_treenode(parse_tree):
87
- if isinstance(parse_tree, Tree):
88
- node = TreeNode(parse_tree.data)
89
- node.children = [convert_to_treenode(child) for child in parse_tree.children]
90
- return node
91
- else:
92
- return TreeNode(str(parse_tree))
93
-
94
- # Flatten unnecessary nodes like pass_through
95
- def remove_past(equation):
96
- if equation.name in {"number", "paren", "func", "variable", "pass_through", "cnumber", "string", "matrix"}:
97
- if len(equation.children) == 1:
98
- return remove_past(equation.children[0])
99
- else:
100
- equation.children = [remove_past(child) for child in equation.children]
101
- return TreeNode(equation.children[0].name, equation.children[1:])
102
- equation.children = [remove_past(child) for child in equation.children]
103
- return equation
104
-
105
- # Handle indices if any
106
- def prefixindex(equation):
107
- if equation.name == "base" and len(equation.children) > 1:
108
- return TreeNode("index", [equation.children[0]] + equation.children[1].children)
109
- return TreeNode(equation.name, [prefixindex(child) for child in equation.children])
110
-
111
- tree_node = convert_to_treenode(parse_tree)
112
- tree_node = remove_past(tree_node)
113
- tree_node = prefixindex(tree_node)
114
-
115
- # Convert function names and constants
116
- def fxchange(tree_node):
117
- tmp3 = funclist if funclist is not None else []
118
- if tree_node.name == "neg":
119
- child = fxchange(tree_node.children[0])
120
- # if the child is a number, make it negative
121
- if child.name.startswith("d_") and re.match(r"d_\d+(\.\d+)?$", child.name):
122
- return TreeNode("d_" + str(-int(child.name[2:])))
123
- else:
124
- # otherwise subtract from zero
125
- return TreeNode("f_sub", [tree_form("d_0"), child])
126
- if tree_node.name == "pass_through":
127
- return fxchange(tree_node.children[0])
128
- return TreeNode(
129
- "f_" + tree_node.name if tree_node.name in tmp3 + ["sqrt","imply","forall","exist","exclude","union","intersection","len","index","angle","charge","sum2","electricfield","line","point","sum","transpose","equationrhs","equationlhs","equation","covariance","variance","expect","error","laplace","dot","curl","pdif","diverge","gradient","rad","ge","le","gt","lt","eqtri","linesegment","midpoint","mag","point1","point2","point3","line1","line2","line3","log10","arcsin","arccos","arctan","list","cosec","sec","cot","equiv","or","not","and","circumcenter","eq","sub","add","sin","cos","tan","mul","integrate","dif","pow","div","log","abs"] else "d_" + tree_node.name,
130
- [fxchange(child) for child in tree_node.children]
131
- )
132
-
133
- tree_node = fxchange(tree_node)
134
-
135
- # Replace common constants
136
- for const in ["e","pi","kc","em","ec","anot","hbar","false","true","i","nabla"]:
137
- tree_node = replace(tree_node, tree_form("d_"+const), tree_form("s_"+const))
138
-
139
- # Map letters to variables
140
- for i, c in enumerate(["x","y","z"] + [chr(x+ord("a")) for x in range(0,23)]):
141
- tree_node = replace(tree_node, tree_form("d_"+c), tree_form("v_"+str(i)))
142
- for i, c in enumerate([chr(x+ord("A")) for x in range(0,26)]):
143
- tree_node = replace(tree_node, tree_form("d_"+c), tree_form("v_-"+str(i+1)))
144
- tree_node = replace(tree_node, tree_form("f_"+c), tree_form("v_-"+str(i+1)))
145
-
146
- # Final recursive replacements
147
- def rfx(tree_node):
148
- if tree_node.name[:3] == "d_c":
149
- return tree_form("v_" + str(int(tree_node.name[3:])+100))
150
- tree_node.children = [rfx(child) for child in tree_node.children]
151
- return tree_node
152
-
153
- tree_node = rfx(tree_node)
154
- return tree_node
1
+ import copy
2
+ from lark import Lark, Tree
3
+ from .base import *
4
+ import re
5
+
6
+ grammar = """
7
+ ?start: expr
8
+
9
+ ?expr: logic_equiv
10
+
11
+ ?logic_equiv: logic_imply
12
+ | logic_equiv "<->" logic_imply -> equiv
13
+
14
+ ?logic_imply: logic_or
15
+ | logic_or "->" logic_imply -> imply
16
+
17
+ ?logic_or: logic_and
18
+ | logic_or "|" logic_and -> or
19
+ | logic_or "||" logic_and -> or
20
+
21
+ ?logic_and: logic_not
22
+ | logic_and "&" logic_not -> and
23
+ | logic_and "&&" logic_not -> and
24
+
25
+ ?logic_not: comparison
26
+ | "!" logic_not -> not
27
+ | "~" logic_not -> not
28
+
29
+ ?comparison: arithmetic
30
+ | comparison "=" arithmetic -> eq
31
+ | comparison "<" arithmetic -> lt
32
+ | comparison ">" arithmetic -> gt
33
+ | comparison "<=" arithmetic -> le
34
+ | comparison ">=" arithmetic -> ge
35
+
36
+ ?arithmetic: arithmetic "+" term -> add
37
+ | arithmetic "-" term -> sub
38
+ | term
39
+
40
+ ?term: term "*" power -> mul
41
+ | term "/" power -> div
42
+ | term "." power -> dot
43
+ | power
44
+
45
+ ?power: power "^" factor -> pow
46
+ | power "**" factor -> pow
47
+ | factor
48
+
49
+ ?factor: "-" factor -> neg
50
+ | "+" factor -> pass_through
51
+ | atom
52
+
53
+ ?atom: NUMBER -> number
54
+ | VARIABLE -> variable
55
+ | FUNC_NAME "(" [expr ("," expr)*] ")" -> func
56
+ | "[" [expr ("," expr)*] "]" -> list
57
+ | "(" expr ")" -> paren
58
+ | CNUMBER -> cnumber
59
+ | ESCAPED_STRING -> string
60
+ | CAPITAL_ID -> matrix
61
+
62
+ FUNC_NAME: "midpoint" | "forall" | "imply" | "exist" | "len" | "sum" | "angle" | "line" | "sum2" | "charge" | "electricfield" | "perm" | "point" | "equationrhs" | "transpose" | "equationlhs" | "equation" | "error" | "covariance" | "variance" | "expect" | "mag" | "rad" | "laplace" | "diverge" | "pdif" | "gradient" | "curl" | "point1" | "point2" | "dot" | "point3" | "line1" | "line2" | "line3" | "sin" | "circumcenter" | "eqtri" | "linesegment" | "cos" | "tan" | "log" | "sqrt" | "integrate" | "dif" | "abs" | "cosec" | "sec" | "cot" | "arctan" | "arcsin" | "arccos" | "log10"
63
+
64
+ VARIABLE: /[a-z]/ | "nabla" | "pi" | "kc" | "hbar" | "em" | "ec" | "anot" | "false" | "true"
65
+
66
+ CAPITAL_ID: /[A-Z]/
67
+
68
+ CNUMBER: /c[0-9]+/
69
+
70
+ %import common.NUMBER
71
+ %import common.ESCAPED_STRING
72
+ %import common.WS_INLINE
73
+ %ignore WS_INLINE
74
+ """
75
+
76
+ def parse(equation, funclist=None):
77
+ equation = copy.copy(equation.replace(" ", ""))
78
+ grammar2 = copy.deepcopy(grammar)
79
+ if funclist is not None:
80
+ output = grammar2.split("\n")
81
+ for i in range(len(output)):
82
+ if "FUNC_NAME:" in output[i]:
83
+ output[i] = output[i].replace("FUNC_NAME: ", "FUNC_NAME: " + " | ".join(['"' + x + '"' for x in funclist]) + " | ")
84
+ grammar2 = "\n".join(output)
85
+
86
+ parser_main = Lark(grammar2, start='start', parser='lalr')
87
+ parse_tree = parser_main.parse(equation)
88
+
89
+ # Convert Lark tree to TreeNode
90
+ def convert_to_treenode(parse_tree):
91
+ if isinstance(parse_tree, Tree):
92
+ node = TreeNode(parse_tree.data)
93
+ node.children = [convert_to_treenode(child) for child in parse_tree.children]
94
+ return node
95
+ else:
96
+ return TreeNode(str(parse_tree))
97
+
98
+ # Flatten unnecessary nodes like pass_through
99
+ def remove_past(equation):
100
+ if equation.name in {"number", "paren", "func", "variable", "pass_through", "cnumber", "string", "matrix"}:
101
+ if len(equation.children) == 1:
102
+ return remove_past(equation.children[0])
103
+ else:
104
+ equation.children = [remove_past(child) for child in equation.children]
105
+ return TreeNode(equation.children[0].name, equation.children[1:])
106
+ equation.children = [remove_past(child) for child in equation.children]
107
+ return equation
108
+
109
+ # Handle indices if any
110
+ def prefixindex(equation):
111
+ if equation.name == "base" and len(equation.children) > 1:
112
+ return TreeNode("index", [equation.children[0]] + equation.children[1].children)
113
+ return TreeNode(equation.name, [prefixindex(child) for child in equation.children])
114
+
115
+ tree_node = convert_to_treenode(parse_tree)
116
+ tree_node = remove_past(tree_node)
117
+ tree_node = prefixindex(tree_node)
118
+
119
+ # Convert function names and constants
120
+ def fxchange(tree_node):
121
+ tmp3 = funclist if funclist is not None else []
122
+ if tree_node.name == "neg":
123
+ child = fxchange(tree_node.children[0])
124
+ # if the child is a number, make it negative
125
+ if child.name.startswith("d_") and re.match(r"d_\d+(\.\d+)?$", child.name):
126
+ return TreeNode("d_" + str(-int(child.name[2:])))
127
+ else:
128
+ # otherwise subtract from zero
129
+ return TreeNode("f_sub", [tree_form("d_0"), child])
130
+ if tree_node.name == "pass_through":
131
+ return fxchange(tree_node.children[0])
132
+ return TreeNode(
133
+ "f_" + tree_node.name if tree_node.name in tmp3 + ["sqrt","imply","forall","exist","exclude","union","intersection","len","index","angle","charge","sum2","electricfield","line","point","sum","transpose","equationrhs","equationlhs","equation","covariance","variance","expect","error","laplace","dot","curl","pdif","diverge","gradient","rad","ge","le","gt","lt","eqtri","linesegment","midpoint","mag","point1","point2","point3","line1","line2","line3","log10","arcsin","arccos","arctan","list","cosec","sec","cot","equiv","or","not","and","circumcenter","eq","sub","add","sin","cos","tan","mul","integrate","dif","pow","div","log","abs"] else "d_" + tree_node.name,
134
+ [fxchange(child) for child in tree_node.children]
135
+ )
136
+
137
+ tree_node = fxchange(tree_node)
138
+
139
+ # Replace common constants
140
+ for const in ["e","pi","kc","em","ec","anot","hbar","false","true","i","nabla"]:
141
+ tree_node = replace(tree_node, tree_form("d_"+const), tree_form("s_"+const))
142
+
143
+ # Map letters to variables
144
+ for i, c in enumerate(["x","y","z"] + [chr(x+ord("a")) for x in range(0,23)]):
145
+ tree_node = replace(tree_node, tree_form("d_"+c), tree_form("v_"+str(i)))
146
+ for i, c in enumerate([chr(x+ord("A")) for x in range(0,26)]):
147
+ tree_node = replace(tree_node, tree_form("d_"+c), tree_form("v_-"+str(i+1)))
148
+ tree_node = replace(tree_node, tree_form("f_"+c), tree_form("v_-"+str(i+1)))
149
+
150
+ # Final recursive replacements
151
+ def rfx(tree_node):
152
+ if tree_node.name[:3] == "d_c":
153
+ return tree_form("v_" + str(int(tree_node.name[3:])+100))
154
+ tree_node.children = [rfx(child) for child in tree_node.children]
155
+ return tree_node
156
+
157
+ tree_node = rfx(tree_node)
158
+ return tree_node
mathai/printeq.py CHANGED
@@ -1,34 +1,34 @@
1
- from .base import *
2
- from .simplify import solve
3
- import copy
4
- from fractions import Fraction
5
- def abstractexpr(eq):
6
- if eq.name == "f_pow" and frac(eq.children[1])==Fraction(1,2):
7
- eq = eq.children[0].fx("sqrt")
8
- if eq.name == "f_pow" and frac(eq.children[1])==Fraction(-1,2):
9
- eq = eq.children[0].fx("sqrt")**-1
10
- if eq.name in ["f_mul", "f_pow"]:
11
-
12
- lst = factor_generation(eq)
13
- deno = [item.children[0]**int(item.children[1].name[3:]) for item in lst if item.name == "f_pow" and item.children[1].name[:3] == "d_-"]
14
- if eq.name == "f_mul" and any(item.name[:2] == "d_" and int(item.name[2:]) < 0 for item in lst):
15
- return solve(-eq).fx("neg")
16
- if deno != []:
17
-
18
- num = [item for item in lst if item.name != "f_pow" or item.children[1].name[:3] != "d_-"]
19
- if num == []:
20
- num = [tree_form("d_1")]
21
- return TreeNode("f_div", [solve(product(num)), solve(product(deno))])
22
-
23
-
24
- return TreeNode(eq.name, [abstractexpr(child) for child in eq.children])
25
-
26
- def printeq_str(eq):
27
- return str(dowhile(copy.deepcopy(eq), abstractexpr))
28
-
29
- def printeq(eq):
30
- print(printeq_str(eq))
31
-
32
- def printeq_log(lst):
33
- for item in lst:
34
- print(" "*item[0] + item[1])
1
+ from .base import *
2
+ from .simplify import solve
3
+ import copy
4
+ from fractions import Fraction
5
+ def abstractexpr(eq):
6
+ if eq.name == "f_pow" and frac(eq.children[1])==Fraction(1,2):
7
+ eq = eq.children[0].fx("sqrt")
8
+ if eq.name == "f_pow" and frac(eq.children[1])==Fraction(-1,2):
9
+ eq = eq.children[0].fx("sqrt")**-1
10
+ if eq.name in ["f_mul", "f_pow"]:
11
+
12
+ lst = factor_generation(eq)
13
+ deno = [item.children[0]**int(item.children[1].name[3:]) for item in lst if item.name == "f_pow" and item.children[1].name[:3] == "d_-"]
14
+ if eq.name == "f_mul" and any(item.name[:2] == "d_" and int(item.name[2:]) < 0 for item in lst):
15
+ return solve(-eq).fx("neg")
16
+ if deno != []:
17
+
18
+ num = [item for item in lst if item.name != "f_pow" or item.children[1].name[:3] != "d_-"]
19
+ if num == []:
20
+ num = [tree_form("d_1")]
21
+ return TreeNode("f_div", [solve(product(num)), solve(product(deno))])
22
+
23
+
24
+ return TreeNode(eq.name, [abstractexpr(child) for child in eq.children])
25
+
26
def printeq_str(eq):
    """Return *eq* rendered as a human-readable string.

    The expression is deep-copied first, so the caller's tree is never
    handed to the rewriter. ``dowhile`` (from ``.base``) then drives
    ``abstractexpr`` over the copy -- presumably until it stops changing;
    see ``dowhile`` to confirm.
    """
    working = copy.deepcopy(eq)
    rendered = dowhile(working, abstractexpr)
    return str(rendered)
28
+
29
def printeq(eq):
    """Print the readable form of *eq* (as built by printeq_str) to stdout."""
    text = printeq_str(eq)
    print(text)
31
+
32
def printeq_log(lst):
    """Print each log entry of *lst* on its own line.

    Each entry is indexable: ``entry[0]`` is the number of leading spaces
    and ``entry[1]`` is the text (a str) to print after them.
    """
    for entry in lst:
        padding = " " * entry[0]
        print(padding + entry[1])
mathai/search.py ADDED
@@ -0,0 +1,111 @@
1
+ from mathai import *
2
+ import copy
3
+ from concurrent.futures import ThreadPoolExecutor, TimeoutError
4
+
5
def dfs_simplify(equation, functions, true_expr, false_expr,
                 max_timeout=25, max_small=4,
                 base_timeout=1, time_per_char=0.1, timeout_increase=0.5):
    """
    Perform DFS simplification on a given equation using provided functions.

    Each explored expression is printed (via printeq) together with
    "[Thinking]" progress messages.

    Args:
        equation: The starting expression (TreeNode or parsed equation)
        functions: List of simplification functions, each called as fx(expr)
        true_expr: Expression representing True (immediate termination)
        false_expr: Expression representing False (immediate termination)
        max_timeout: Maximum timeout allowed for any function
        max_small: Number of smallest expressions to track
        base_timeout: Base timeout in seconds
        time_per_char: Additional timeout per character of expression
        timeout_increase: Factor to increase timeout for consecutive timeouts

    Returns:
        tuple(found_boolean, boolean_path, smallest_expressions) where
        smallest_expressions is a list of (expression, path) pairs ordered by
        the length of str(expression), shortest first.

    Exceptions raised by a simplification function propagate; the worker
    pool is released either way.
    """
    original_eq = simplify(equation)
    smallest_four = []

    stack = [(copy.deepcopy(original_eq), [copy.deepcopy(original_eq)])]
    visited = set()

    found_boolean = False
    boolean_path = None

    timed_out = object()  # sentinel: the call did not finish in time
    executor = ThreadPoolExecutor(max_workers=3)
    consecutive_timeouts = 0

    def run_with_timeout(fx, expr, timeout):
        """Run fx(expr) on the pool; return its result or ``timed_out``."""
        nonlocal executor
        future = executor.submit(fx, expr)
        try:
            return future.result(timeout=timeout)
        except TimeoutError:
            # A running thread cannot be interrupted. Abandon this pool so
            # the stuck worker cannot occupy a slot that later calls would
            # queue behind (making them "time out" without ever running).
            executor.shutdown(wait=False)
            executor = ThreadPoolExecutor(max_workers=3)
            return timed_out

    try:
        while stack:
            current_eq, path = stack.pop()
            expr_str = str(current_eq)

            if expr_str in visited:
                continue
            visited.add(expr_str)

            # Thinking message
            printeq(current_eq)

            # Immediate termination for provided boolean expressions
            if current_eq == true_expr or current_eq == false_expr:
                found_boolean = True
                boolean_path = path
                break

            # Keep the max_small shortest expressions (earlier wins on ties).
            for j, (small_eq, _) in enumerate(smallest_four):
                if len(expr_str) < len(str(small_eq)):
                    smallest_four.insert(j, (copy.deepcopy(current_eq), copy.deepcopy(path)))
                    break
            else:
                if len(smallest_four) < max_small:
                    smallest_four.append((copy.deepcopy(current_eq), copy.deepcopy(path)))
            del smallest_four[max_small:]

            # Adaptive timeout: grows with expression size and with the run
            # of consecutive timeouts, capped at max_timeout.
            timeout = min(
                max_timeout,
                (base_timeout + time_per_char * len(expr_str))
                * (1 + timeout_increase * consecutive_timeouts),
            )

            # Apply every function once. Results that do not grow the
            # expression are explored; the first growing result is kept as
            # a fallback so the functions need not be re-run below.
            reduced_any = False
            growing = None
            for fx in functions:
                fx_name = getattr(fx, "__name__", repr(fx))
                print(f"[Thinking] Executing {fx_name} on current expression (timeout={timeout:.2f}s):")
                printeq(current_eq)
                new_expr = run_with_timeout(fx, current_eq, timeout)
                if new_expr is timed_out:
                    print(f"[Thinking] {fx_name} timed out, skipping.")
                    consecutive_timeouts += 1
                    continue
                new_expr_str = str(new_expr)
                if new_expr_str == expr_str:
                    continue
                if len(new_expr_str) <= len(expr_str):
                    reduced_any = True
                    stack.append((new_expr, path + [copy.deepcopy(new_expr)]))
                    consecutive_timeouts = 0  # reset after success
                elif growing is None:
                    growing = (fx_name, new_expr)

            # If no reducing function worked, follow one growing rewrite.
            if not reduced_any and growing is not None:
                fx_name, new_expr = growing
                print(f"[Thinking] Taking growing result of {fx_name}.")
                stack.append((new_expr, path + [copy.deepcopy(new_expr)]))
                consecutive_timeouts = 0
    finally:
        # Don't block on abandoned workers that may still be running.
        executor.shutdown(wait=False)

    return found_boolean, boolean_path, smallest_four