ripple-down-rules 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datastructures/callable_expression.py +8 -3
- ripple_down_rules/prompt.py +32 -26
- ripple_down_rules/rdr.py +2 -2
- ripple_down_rules/utils.py +75 -2
- {ripple_down_rules-0.1.1.dist-info → ripple_down_rules-0.1.3.dist-info}/METADATA +1 -1
- {ripple_down_rules-0.1.1.dist-info → ripple_down_rules-0.1.3.dist-info}/RECORD +9 -9
- {ripple_down_rules-0.1.1.dist-info → ripple_down_rules-0.1.3.dist-info}/WHEEL +1 -1
- {ripple_down_rules-0.1.1.dist-info → ripple_down_rules-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.1.1.dist-info → ripple_down_rules-0.1.3.dist-info}/top_level.txt +0 -0
@@ -65,7 +65,8 @@ class VariableVisitor(ast.NodeVisitor):
|
|
65
65
|
|
66
66
|
def get_used_scope(code_str, scope):
|
67
67
|
# Parse the code into an AST
|
68
|
-
|
68
|
+
mode = 'exec' if code_str.startswith('def') else 'eval'
|
69
|
+
tree = ast.parse(code_str, mode=mode)
|
69
70
|
|
70
71
|
# Walk the AST to collect used variable names
|
71
72
|
class NameCollector(ast.NodeVisitor):
|
@@ -169,6 +170,8 @@ class CallableExpression(SubclassJSONSerializer):
|
|
169
170
|
case = create_case(case, max_recursion_idx=3)
|
170
171
|
scope = {'case': case, **self.scope}
|
171
172
|
output = eval(self.code, scope)
|
173
|
+
if output is None:
|
174
|
+
output = scope['_get_value'](case)
|
172
175
|
if self.conclusion_type is not None:
|
173
176
|
assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
|
174
177
|
f" got {type(output)}")
|
@@ -219,7 +222,8 @@ def compile_expression_to_code(expression_tree: AST) -> Any:
|
|
219
222
|
:param expression_tree: The parsed expression tree.
|
220
223
|
:return: The code that was compiled from the expression tree.
|
221
224
|
"""
|
222
|
-
|
225
|
+
mode = 'exec' if isinstance(expression_tree, ast.Module) else 'eval'
|
226
|
+
return compile(expression_tree, filename="<string>", mode=mode)
|
223
227
|
|
224
228
|
|
225
229
|
def assert_context_contains_needed_information(case: Any, context: Dict[str, Any],
|
@@ -273,6 +277,7 @@ def parse_string_to_expression(expression_str: str) -> AST:
|
|
273
277
|
:param expression_str: The string which will be parsed.
|
274
278
|
:return: The parsed expression.
|
275
279
|
"""
|
276
|
-
|
280
|
+
mode = 'exec' if expression_str.startswith('def') else 'eval'
|
281
|
+
tree = ast.parse(expression_str, mode=mode)
|
277
282
|
logging.debug(f"AST parsed successfully: {ast.dump(tree)}")
|
278
283
|
return tree
|
ripple_down_rules/prompt.py
CHANGED
@@ -11,58 +11,64 @@ from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
|
|
11
11
|
from typing_extensions import Any, List, Optional, Tuple, Dict, Union, Type
|
12
12
|
|
13
13
|
from .datastructures import Case, PromptFor, CallableExpression, create_case, parse_string_to_expression, CaseQuery
|
14
|
-
from .utils import capture_variable_assignment
|
14
|
+
from .utils import capture_variable_assignment, extract_dependencies, contains_return_statement
|
15
|
+
|
16
|
+
|
17
|
+
class CustomInteractiveShell(InteractiveShellEmbed):
|
18
|
+
def __init__(self, **kwargs):
|
19
|
+
super().__init__(**kwargs)
|
20
|
+
self.all_lines = []
|
21
|
+
|
22
|
+
def run_cell(self, raw_cell: str, **kwargs):
|
23
|
+
"""
|
24
|
+
Override the run_cell method to capture return statements.
|
25
|
+
"""
|
26
|
+
if contains_return_statement(raw_cell):
|
27
|
+
self.all_lines.append(raw_cell)
|
28
|
+
print("Exiting shell on `return` statement.")
|
29
|
+
self.history_manager.store_inputs(line_num=self.execution_count, source=raw_cell)
|
30
|
+
self.ask_exit()
|
31
|
+
return None
|
32
|
+
result = super().run_cell(raw_cell, store_history=True, **kwargs)
|
33
|
+
if not result.error_in_exec:
|
34
|
+
self.all_lines.append(raw_cell)
|
35
|
+
return result
|
15
36
|
|
16
37
|
|
17
38
|
class IpythonShell:
|
18
39
|
"""
|
19
40
|
Create an embedded Ipython shell that can be used to prompt the user for input.
|
20
41
|
"""
|
21
|
-
def __init__(self,
|
42
|
+
def __init__(self, prompt_for: PromptFor, scope: Optional[Dict] = None, header: Optional[str] = None):
|
22
43
|
"""
|
23
44
|
Initialize the Ipython shell with the given scope and header.
|
24
45
|
|
25
|
-
:param variable_to_capture: The variable to capture from the user input.
|
26
46
|
:param scope: The scope to use for the shell.
|
27
47
|
:param header: The header to display when the shell is started.
|
28
48
|
"""
|
29
|
-
self.
|
49
|
+
self.prompt_for: PromptFor = prompt_for
|
30
50
|
self.scope: Dict = scope or {}
|
31
51
|
self.header: str = header or ">>> Embedded Ipython Shell"
|
32
52
|
self.user_input: Optional[str] = None
|
33
|
-
self.shell:
|
34
|
-
self.
|
53
|
+
self.shell: CustomInteractiveShell = self._init_shell()
|
54
|
+
self.all_code_lines: List[str] = []
|
35
55
|
|
36
56
|
def _init_shell(self):
|
37
57
|
"""
|
38
58
|
Initialize the Ipython shell with a custom configuration.
|
39
59
|
"""
|
40
60
|
cfg = Config()
|
41
|
-
shell =
|
61
|
+
shell = CustomInteractiveShell(config=cfg, user_ns=self.scope, banner1=self.header)
|
42
62
|
return shell
|
43
63
|
|
44
|
-
def _register_hooks(self):
|
45
|
-
"""
|
46
|
-
Register hooks to capture specific events in the Ipython shell.
|
47
|
-
"""
|
48
|
-
def capture_variable(exec_info: ExecutionInfo):
|
49
|
-
code = exec_info.raw_cell
|
50
|
-
if self.variable_to_capture not in code:
|
51
|
-
return
|
52
|
-
# use ast to find if the user is assigning a value to the variable "condition"
|
53
|
-
assignment = capture_variable_assignment(code, self.variable_to_capture)
|
54
|
-
if assignment:
|
55
|
-
# if the user is assigning a value to the variable "condition", update the raw_condition
|
56
|
-
self.user_input = assignment
|
57
|
-
print(f"[Captured {self.variable_to_capture}]:\n{self.user_input}")
|
58
|
-
|
59
|
-
self.shell.events.register('pre_run_cell', capture_variable)
|
60
|
-
|
61
64
|
def run(self):
|
62
65
|
"""
|
63
66
|
Run the embedded shell.
|
64
67
|
"""
|
65
68
|
self.shell()
|
69
|
+
self.all_code_lines = extract_dependencies(self.shell.all_lines)
|
70
|
+
self.user_input = f"def _get_value(case):\n "
|
71
|
+
self.user_input += '\n '.join(self.all_code_lines)
|
66
72
|
|
67
73
|
|
68
74
|
def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor,
|
@@ -99,7 +105,7 @@ def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor) -> Tupl
|
|
99
105
|
"""
|
100
106
|
prompt_str = f"Give {prompt_for} for {case_query.name}"
|
101
107
|
scope = {'case': case_query.case, **case_query.scope}
|
102
|
-
shell = IpythonShell(prompt_for
|
108
|
+
shell = IpythonShell(prompt_for, scope=scope, header=prompt_str)
|
103
109
|
user_input, expression_tree = prompt_user_input_and_parse_to_expression(shell=shell)
|
104
110
|
return user_input, expression_tree
|
105
111
|
|
@@ -132,12 +138,12 @@ def prompt_user_input_and_parse_to_expression(shell: Optional[IpythonShell] = No
|
|
132
138
|
shell = IpythonShell() if shell is None else shell
|
133
139
|
shell.run()
|
134
140
|
user_input = shell.user_input
|
141
|
+
print(user_input)
|
135
142
|
try:
|
136
143
|
return user_input, parse_string_to_expression(user_input)
|
137
144
|
except Exception as e:
|
138
145
|
msg = f"Error parsing expression: {e}"
|
139
146
|
logging.error(msg)
|
140
|
-
print(msg)
|
141
147
|
user_input = None
|
142
148
|
|
143
149
|
|
ripple_down_rules/rdr.py
CHANGED
@@ -163,9 +163,9 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
163
163
|
Update the figures of the classifier.
|
164
164
|
"""
|
165
165
|
if isinstance(self, GeneralRDR):
|
166
|
-
for i, (
|
166
|
+
for i, (rdr_name, rdr) in enumerate(self.start_rules_dict.items()):
|
167
167
|
if not rdr.fig:
|
168
|
-
rdr.fig = plt.figure(f"Rule {i}: {
|
168
|
+
rdr.fig = plt.figure(f"Rule {i}: {rdr_name}")
|
169
169
|
draw_tree(rdr.start_rule, rdr.fig)
|
170
170
|
else:
|
171
171
|
if not self.fig:
|
ripple_down_rules/utils.py
CHANGED
@@ -5,10 +5,9 @@ import importlib
|
|
5
5
|
import json
|
6
6
|
import logging
|
7
7
|
import os
|
8
|
-
from abc import abstractmethod
|
9
8
|
from collections import UserDict
|
10
9
|
from copy import deepcopy
|
11
|
-
from dataclasses import
|
10
|
+
from dataclasses import is_dataclass, fields
|
12
11
|
|
13
12
|
import matplotlib
|
14
13
|
import networkx as nx
|
@@ -24,9 +23,82 @@ from typing_extensions import Callable, Set, Any, Type, Dict, TYPE_CHECKING, get
|
|
24
23
|
if TYPE_CHECKING:
|
25
24
|
from .datastructures import Case
|
26
25
|
|
26
|
+
import ast
|
27
|
+
|
27
28
|
matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
|
28
29
|
|
29
30
|
|
31
|
+
def contains_return_statement(source: str) -> bool:
|
32
|
+
"""
|
33
|
+
:param source: The source code to check.
|
34
|
+
:return: True if the source code contains a return statement, False otherwise.
|
35
|
+
"""
|
36
|
+
try:
|
37
|
+
tree = ast.parse(source)
|
38
|
+
for node in tree.body:
|
39
|
+
if isinstance(node, ast.Return):
|
40
|
+
return True
|
41
|
+
return False
|
42
|
+
except SyntaxError:
|
43
|
+
return False
|
44
|
+
|
45
|
+
|
46
|
+
def get_names_used(node):
|
47
|
+
return {n.id for n in ast.walk(node) if isinstance(n, ast.Name)}
|
48
|
+
|
49
|
+
|
50
|
+
def extract_dependencies(code_lines):
|
51
|
+
full_code = '\n'.join(code_lines)
|
52
|
+
tree = ast.parse(full_code)
|
53
|
+
final_stmt = tree.body[-1]
|
54
|
+
|
55
|
+
if not isinstance(final_stmt, ast.Return):
|
56
|
+
raise ValueError("Last line is not a return statement")
|
57
|
+
|
58
|
+
needed = get_names_used(final_stmt.value)
|
59
|
+
required_lines = []
|
60
|
+
line_map = {id(node): i for i, node in enumerate(tree.body)}
|
61
|
+
|
62
|
+
def handle_stmt(stmt, needed):
|
63
|
+
keep = False
|
64
|
+
if isinstance(stmt, ast.Assign):
|
65
|
+
targets = [t.id for t in stmt.targets if isinstance(t, ast.Name)]
|
66
|
+
if any(t in needed for t in targets):
|
67
|
+
needed.update(get_names_used(stmt.value))
|
68
|
+
keep = True
|
69
|
+
elif isinstance(stmt, ast.AugAssign):
|
70
|
+
if isinstance(stmt.target, ast.Name) and stmt.target.id in needed:
|
71
|
+
needed.update(get_names_used(stmt.value))
|
72
|
+
keep = True
|
73
|
+
elif isinstance(stmt, ast.FunctionDef):
|
74
|
+
if stmt.name in needed:
|
75
|
+
for n in ast.walk(stmt):
|
76
|
+
if isinstance(n, ast.Name):
|
77
|
+
needed.add(n.id)
|
78
|
+
keep = True
|
79
|
+
elif isinstance(stmt, (ast.For, ast.While, ast.If)):
|
80
|
+
# Check if any of the body statements interact with needed variables
|
81
|
+
for substmt in stmt.body + getattr(stmt, 'orelse', []):
|
82
|
+
if handle_stmt(substmt, needed):
|
83
|
+
keep = True
|
84
|
+
# Also check the condition (test or iter)
|
85
|
+
if isinstance(stmt, ast.For):
|
86
|
+
if isinstance(stmt.target, ast.Name) and stmt.target.id in needed:
|
87
|
+
keep = True
|
88
|
+
needed.update(get_names_used(stmt.iter))
|
89
|
+
elif isinstance(stmt, ast.If) or isinstance(stmt, ast.While):
|
90
|
+
needed.update(get_names_used(stmt.test))
|
91
|
+
|
92
|
+
return keep
|
93
|
+
|
94
|
+
for stmt in reversed(tree.body[:-1]):
|
95
|
+
if handle_stmt(stmt, needed):
|
96
|
+
required_lines.insert(0, code_lines[line_map[id(stmt)]])
|
97
|
+
|
98
|
+
required_lines.append(code_lines[-1]) # Always include return
|
99
|
+
return required_lines
|
100
|
+
|
101
|
+
|
30
102
|
def serialize_dataclass(obj: Any) -> Union[Dict, Any]:
|
31
103
|
"""
|
32
104
|
Recursively serialize a dataclass to a dictionary. If the dataclass contains any nested dataclasses, they will be
|
@@ -61,6 +133,7 @@ def deserialize_dataclass(data: dict) -> Any:
|
|
61
133
|
:param data: The dictionary to deserialize.
|
62
134
|
:return: The deserialized dataclass.
|
63
135
|
"""
|
136
|
+
|
64
137
|
def recursive_load(obj):
|
65
138
|
if isinstance(obj, dict) and "__dataclass__" in obj:
|
66
139
|
module_name, class_name = obj["__dataclass__"].rsplit(".", 1)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.3
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -3,18 +3,18 @@ ripple_down_rules/datasets.py,sha256=AzPtqUXuR1qLQNtRsWLsJ3gX2oIf8nIkFvmsmz7fHlw
|
|
3
3
|
ripple_down_rules/experts.py,sha256=Xz1U1Tdq7jrFlcVuSusaMB241AG9TEs7q101i59Xijs,10683
|
4
4
|
ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
|
5
5
|
ripple_down_rules/helpers.py,sha256=AhqerAQoCdSovJ7SdQrNtAI_hYagKpLsy2nJQGA0bl0,1062
|
6
|
-
ripple_down_rules/prompt.py,sha256=
|
7
|
-
ripple_down_rules/rdr.py,sha256=
|
6
|
+
ripple_down_rules/prompt.py,sha256=kXQAiNDCayB6Ijecxx487eOqqWLcfvmp0q7FbyfuQM0,6433
|
7
|
+
ripple_down_rules/rdr.py,sha256=NXVYIflUxcDzC5DDrK-l_ZT-sBmUV1ZgkznshSsJZYc,43508
|
8
8
|
ripple_down_rules/rdr_decorators.py,sha256=8SclpceI3EtrsbuukWJu8HGLh7Q1ZCgYGLX-RPlG-w0,2018
|
9
9
|
ripple_down_rules/rules.py,sha256=aM3Im4ePuFDlkuD2EKRtiVmYgoQ_sxlwcbzrDKqXAfs,14578
|
10
|
-
ripple_down_rules/utils.py,sha256=
|
10
|
+
ripple_down_rules/utils.py,sha256=HWu5rvAaV2SXIORHR0c2RBdWNc9q4B7DjWIskuyDTA8,26877
|
11
11
|
ripple_down_rules/datastructures/__init__.py,sha256=zpmiYm4WkwNHaGdTIfacS7llN5d2xyU6U-saH_TpydI,103
|
12
|
-
ripple_down_rules/datastructures/callable_expression.py,sha256=
|
12
|
+
ripple_down_rules/datastructures/callable_expression.py,sha256=noukQQh1Loto9s8EJWnaS-7kaPaKuUccyEz5xHp7KtI,10790
|
13
13
|
ripple_down_rules/datastructures/case.py,sha256=3Pl07jmYn94wdCVTaRZDmBPgyAsN1TjebvrE6-68MVU,13606
|
14
14
|
ripple_down_rules/datastructures/dataclasses.py,sha256=AI-wqNy8y9QPg6lov0P-c5b8JXemuM4X62tIRhW-Gqs,4231
|
15
15
|
ripple_down_rules/datastructures/enums.py,sha256=l0Eu-TeJ6qB2XHoJycXmUgLw-3yUebQ8SsEbW8bBZdM,4543
|
16
|
-
ripple_down_rules-0.1.
|
17
|
-
ripple_down_rules-0.1.
|
18
|
-
ripple_down_rules-0.1.
|
19
|
-
ripple_down_rules-0.1.
|
20
|
-
ripple_down_rules-0.1.
|
16
|
+
ripple_down_rules-0.1.3.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
17
|
+
ripple_down_rules-0.1.3.dist-info/METADATA,sha256=k5i5LEQ0cEKeuTQvwRIBNYy3YAVG160mvh8DYsJKxJ0,42518
|
18
|
+
ripple_down_rules-0.1.3.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
|
19
|
+
ripple_down_rules-0.1.3.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
20
|
+
ripple_down_rules-0.1.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|