ripple-down-rules 0.1.69__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datastructures/callable_expression.py +32 -3
- ripple_down_rules/datastructures/enums.py +1 -1
- ripple_down_rules/prompt.py +203 -10
- ripple_down_rules/rdr.py +89 -71
- ripple_down_rules/rules.py +38 -33
- ripple_down_rules/utils.py +129 -28
- {ripple_down_rules-0.1.69.dist-info → ripple_down_rules-0.2.0.dist-info}/METADATA +1 -1
- ripple_down_rules-0.2.0.dist-info/RECORD +20 -0
- {ripple_down_rules-0.1.69.dist-info → ripple_down_rules-0.2.0.dist-info}/WHEEL +1 -1
- ripple_down_rules-0.1.69.dist-info/RECORD +0 -20
- {ripple_down_rules-0.1.69.dist-info → ripple_down_rules-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.1.69.dist-info → ripple_down_rules-0.2.0.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,7 @@ from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set
|
|
9
9
|
|
10
10
|
from .case import create_case, Case
|
11
11
|
from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string, conclusion_to_json, is_iterable, \
|
12
|
-
build_user_input_from_conclusion, encapsulate_user_input
|
12
|
+
build_user_input_from_conclusion, encapsulate_user_input, extract_function_source
|
13
13
|
|
14
14
|
|
15
15
|
class VariableVisitor(ast.NodeVisitor):
|
@@ -109,7 +109,7 @@ class CallableExpression(SubclassJSONSerializer):
|
|
109
109
|
if user_input is None:
|
110
110
|
user_input = build_user_input_from_conclusion(conclusion)
|
111
111
|
self.conclusion: Optional[Any] = conclusion
|
112
|
-
self.
|
112
|
+
self._user_input: str = encapsulate_user_input(user_input, self.encapsulating_function)
|
113
113
|
if conclusion_type is not None:
|
114
114
|
if is_iterable(conclusion_type):
|
115
115
|
conclusion_type = tuple(conclusion_type)
|
@@ -156,6 +156,35 @@ class CallableExpression(SubclassJSONSerializer):
|
|
156
156
|
f"return _cond1(case) and _cond2(case)")
|
157
157
|
return CallableExpression(new_user_input, conclusion_type=self.conclusion_type)
|
158
158
|
|
159
|
+
def update_user_input_from_file(self, file_path: str, function_name: str):
|
160
|
+
"""
|
161
|
+
Update the user input from a file.
|
162
|
+
"""
|
163
|
+
new_function_body = extract_function_source(file_path, [function_name])[function_name]
|
164
|
+
if new_function_body is None:
|
165
|
+
return
|
166
|
+
self.user_input = self.encapsulating_function + '\n' + new_function_body
|
167
|
+
|
168
|
+
@property
|
169
|
+
def user_input(self):
|
170
|
+
"""
|
171
|
+
Get the user input.
|
172
|
+
"""
|
173
|
+
return self._user_input
|
174
|
+
|
175
|
+
@user_input.setter
|
176
|
+
def user_input(self, value: str):
|
177
|
+
"""
|
178
|
+
Set the user input.
|
179
|
+
"""
|
180
|
+
if value is not None:
|
181
|
+
self._user_input = encapsulate_user_input(value, self.encapsulating_function)
|
182
|
+
self.scope = get_used_scope(self.user_input, self.scope)
|
183
|
+
self.expression_tree = parse_string_to_expression(self.user_input)
|
184
|
+
self.code = compile_expression_to_code(self.expression_tree)
|
185
|
+
self.visitor = VariableVisitor()
|
186
|
+
self.visitor.visit(self.expression_tree)
|
187
|
+
|
159
188
|
def __eq__(self, other):
|
160
189
|
"""
|
161
190
|
Check if two callable expressions are equal.
|
@@ -225,7 +254,7 @@ def parse_string_to_expression(expression_str: str) -> AST:
|
|
225
254
|
:param expression_str: The string which will be parsed.
|
226
255
|
:return: The parsed expression.
|
227
256
|
"""
|
228
|
-
if not expression_str.startswith(
|
257
|
+
if not expression_str.startswith(CallableExpression.encapsulating_function):
|
229
258
|
expression_str = encapsulate_user_input(expression_str, CallableExpression.encapsulating_function)
|
230
259
|
mode = 'exec' if expression_str.startswith('def') else 'eval'
|
231
260
|
tree = ast.parse(expression_str, mode=mode)
|
ripple_down_rules/prompt.py
CHANGED
@@ -1,27 +1,148 @@
|
|
1
1
|
import ast
|
2
2
|
import logging
|
3
|
+
import os
|
4
|
+
import subprocess
|
5
|
+
import tempfile
|
3
6
|
from _ast import AST
|
7
|
+
from functools import cached_property
|
8
|
+
from textwrap import indent, dedent
|
4
9
|
|
10
|
+
from IPython.core.magic import register_line_magic, line_magic, Magics, magics_class
|
5
11
|
from IPython.terminal.embed import InteractiveShellEmbed
|
6
12
|
from traitlets.config import Config
|
7
|
-
from typing_extensions import List, Optional, Tuple, Dict
|
13
|
+
from typing_extensions import List, Optional, Tuple, Dict, Type, Union
|
8
14
|
|
9
15
|
from .datastructures.enums import PromptFor
|
16
|
+
from .datastructures.case import Case
|
10
17
|
from .datastructures.callable_expression import CallableExpression, parse_string_to_expression
|
11
18
|
from .datastructures.dataclasses import CaseQuery
|
12
|
-
from .utils import extract_dependencies, contains_return_statement, make_set
|
19
|
+
from .utils import extract_dependencies, contains_return_statement, make_set, get_imports_from_scope, make_list, \
|
20
|
+
get_import_from_type, get_imports_from_types, is_iterable, extract_function_source, encapsulate_user_input
|
21
|
+
|
22
|
+
|
23
|
+
@magics_class
|
24
|
+
class MyMagics(Magics):
|
25
|
+
def __init__(self, shell, scope, output_type: Optional[Type] = None, func_name: str = "user_case",
|
26
|
+
func_doc: str = "User defined function to be executed on the case.",
|
27
|
+
code_to_modify: Optional[str] = None):
|
28
|
+
super().__init__(shell)
|
29
|
+
self.scope = scope
|
30
|
+
self.temp_file_path = None
|
31
|
+
self.func_name = func_name
|
32
|
+
self.func_doc = func_doc
|
33
|
+
self.code_to_modify = code_to_modify
|
34
|
+
self.output_type = make_list(output_type) if output_type is not None else None
|
35
|
+
self.user_edit_line = 0
|
36
|
+
self.function_signature: Optional[str] = None
|
37
|
+
self.build_function_signature()
|
38
|
+
|
39
|
+
@line_magic
|
40
|
+
def edit_case(self, line):
|
41
|
+
|
42
|
+
boilerplate_code = self.build_boilerplate_code()
|
43
|
+
|
44
|
+
self.write_to_file(boilerplate_code)
|
45
|
+
|
46
|
+
print(f"Opening {self.temp_file_path} in PyCharm...")
|
47
|
+
subprocess.Popen(["pycharm", "--line", str(self.user_edit_line), self.temp_file_path])
|
48
|
+
|
49
|
+
def build_boilerplate_code(self):
|
50
|
+
imports = self.get_imports()
|
51
|
+
self.build_function_signature()
|
52
|
+
if self.code_to_modify is not None:
|
53
|
+
body = indent(dedent(self.code_to_modify), ' ')
|
54
|
+
else:
|
55
|
+
body = " # Write your code here\n pass"
|
56
|
+
boilerplate = f"""{imports}\n\n{self.function_signature}\n \"\"\"{self.func_doc}\"\"\"\n{body}"""
|
57
|
+
self.user_edit_line = imports.count('\n')+6
|
58
|
+
return boilerplate
|
59
|
+
|
60
|
+
def build_function_signature(self):
|
61
|
+
if self.output_type is None:
|
62
|
+
output_type_hint = ""
|
63
|
+
elif len(self.output_type) == 1:
|
64
|
+
output_type_hint = f" -> {self.output_type[0].__name__}"
|
65
|
+
else:
|
66
|
+
output_type_hint = f" -> Union[{', '.join([t.__name__ for t in self.output_type])}]"
|
67
|
+
self.function_signature = f"def {self.func_name}(case: {self.case_type.__name__}){output_type_hint}:"
|
68
|
+
|
69
|
+
def write_to_file(self, code: str):
|
70
|
+
tmp = tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix=".py",
|
71
|
+
dir=os.path.dirname(self.scope['__file__']))
|
72
|
+
tmp.write(code)
|
73
|
+
tmp.flush()
|
74
|
+
self.temp_file_path = tmp.name
|
75
|
+
tmp.close()
|
76
|
+
|
77
|
+
def get_imports(self):
|
78
|
+
case_type_import = f"from {self.case_type.__module__} import {self.case_type.__name__}"
|
79
|
+
if self.output_type is None:
|
80
|
+
output_type_imports = [f"from typing_extensions import Any"]
|
81
|
+
else:
|
82
|
+
output_type_imports = get_imports_from_types(self.output_type)
|
83
|
+
if len(self.output_type) > 1:
|
84
|
+
output_type_imports.append("from typing_extensions import Union")
|
85
|
+
print(output_type_imports)
|
86
|
+
imports = get_imports_from_scope(self.scope)
|
87
|
+
imports = [i for i in imports if ("get_ipython" not in i)]
|
88
|
+
if case_type_import not in imports:
|
89
|
+
imports.append(case_type_import)
|
90
|
+
imports.extend([oti for oti in output_type_imports if oti not in imports])
|
91
|
+
imports = set(imports)
|
92
|
+
return '\n'.join(imports)
|
93
|
+
|
94
|
+
@cached_property
|
95
|
+
def case_type(self) -> Type:
|
96
|
+
"""
|
97
|
+
Get the type of the case object in the current scope.
|
98
|
+
|
99
|
+
:return: The type of the case object.
|
100
|
+
"""
|
101
|
+
case = self.scope['case']
|
102
|
+
return case._obj_type if isinstance(case, Case) else type(case)
|
103
|
+
|
104
|
+
@line_magic
|
105
|
+
def load_case(self, line):
|
106
|
+
if not self.temp_file_path:
|
107
|
+
print("No file to load. Run %edit_case first.")
|
108
|
+
return
|
109
|
+
|
110
|
+
with open(self.temp_file_path, 'r') as f:
|
111
|
+
source = f.read()
|
112
|
+
|
113
|
+
tree = ast.parse(source)
|
114
|
+
for node in tree.body:
|
115
|
+
if isinstance(node, ast.FunctionDef) and node.name == self.func_name:
|
116
|
+
exec_globals = {}
|
117
|
+
exec(source, self.scope, exec_globals)
|
118
|
+
user_function = exec_globals[self.func_name]
|
119
|
+
self.shell.user_ns[self.func_name] = user_function
|
120
|
+
print(f"Loaded `{self.func_name}` function into user namespace.")
|
121
|
+
return
|
122
|
+
|
123
|
+
print(f"Function `{self.func_name}` not found.")
|
13
124
|
|
14
125
|
|
15
126
|
class CustomInteractiveShell(InteractiveShellEmbed):
|
16
|
-
def __init__(self,
|
127
|
+
def __init__(self, output_type: Union[Type, Tuple[Type], None] = None, func_name: Optional[str] = None,
|
128
|
+
func_doc: Optional[str] = None, code_to_modify: Optional[str] = None, **kwargs):
|
17
129
|
super().__init__(**kwargs)
|
130
|
+
keys = ['output_type', 'func_name', 'func_doc', 'code_to_modify']
|
131
|
+
values = [output_type, func_name, func_doc, code_to_modify]
|
132
|
+
magics_kwargs = {key: value for key, value in zip(keys, values) if value is not None}
|
133
|
+
self.my_magics = MyMagics(self, self.user_ns, **magics_kwargs)
|
134
|
+
self.register_magics(self.my_magics)
|
18
135
|
self.all_lines = []
|
19
136
|
|
20
137
|
def run_cell(self, raw_cell: str, **kwargs):
|
21
138
|
"""
|
22
139
|
Override the run_cell method to capture return statements.
|
23
140
|
"""
|
24
|
-
if contains_return_statement(raw_cell):
|
141
|
+
if contains_return_statement(raw_cell) and 'def ' not in raw_cell:
|
142
|
+
if self.my_magics.func_name in raw_cell:
|
143
|
+
self.all_lines = extract_function_source(self.my_magics.temp_file_path,
|
144
|
+
self.my_magics.func_name,
|
145
|
+
join_lines=False)[self.my_magics.func_name]
|
25
146
|
self.all_lines.append(raw_cell)
|
26
147
|
print("Exiting shell on `return` statement.")
|
27
148
|
self.history_manager.store_inputs(line_num=self.execution_count, source=raw_cell)
|
@@ -38,16 +159,31 @@ class IPythonShell:
|
|
38
159
|
Create an embedded Ipython shell that can be used to prompt the user for input.
|
39
160
|
"""
|
40
161
|
|
41
|
-
def __init__(self, scope: Optional[Dict] = None, header: Optional[str] = None
|
162
|
+
def __init__(self, scope: Optional[Dict] = None, header: Optional[str] = None,
|
163
|
+
output_type: Optional[Type] = None, prompt_for: Optional[PromptFor] = None,
|
164
|
+
attribute_name: Optional[str] = None, attribute_type: Optional[Type] = None,
|
165
|
+
code_to_modify: Optional[str] = None):
|
42
166
|
"""
|
43
167
|
Initialize the Ipython shell with the given scope and header.
|
44
168
|
|
45
169
|
:param scope: The scope to use for the shell.
|
46
170
|
:param header: The header to display when the shell is started.
|
171
|
+
:param output_type: The type of the output from user input.
|
172
|
+
:param prompt_for: The type of information to ask the user about.
|
173
|
+
:param attribute_name: The name of the attribute of the case.
|
174
|
+
:param attribute_type: The type of the attribute of the case.
|
175
|
+
:param code_to_modify: The code to modify. If given, will be used as a start for user to modify.
|
47
176
|
"""
|
48
177
|
self.scope: Dict = scope or {}
|
49
178
|
self.header: str = header or ">>> Embedded Ipython Shell"
|
179
|
+
self.output_type: Optional[Type] = output_type
|
180
|
+
self.prompt_for: Optional[PromptFor] = prompt_for
|
181
|
+
self.attribute_name: Optional[str] = attribute_name
|
182
|
+
self.attribute_type: Optional[Type] = attribute_type
|
183
|
+
self.code_to_modify: Optional[str] = code_to_modify
|
50
184
|
self.user_input: Optional[str] = None
|
185
|
+
self.func_name: str = ""
|
186
|
+
self.func_doc: str = ""
|
51
187
|
self.shell: CustomInteractiveShell = self._init_shell()
|
52
188
|
self.all_code_lines: List[str] = []
|
53
189
|
|
@@ -56,9 +192,49 @@ class IPythonShell:
|
|
56
192
|
Initialize the Ipython shell with a custom configuration.
|
57
193
|
"""
|
58
194
|
cfg = Config()
|
59
|
-
|
195
|
+
self.build_func_name_and_doc()
|
196
|
+
shell = CustomInteractiveShell(config=cfg, user_ns=self.scope, banner1=self.header,
|
197
|
+
output_type=self.output_type, func_name=self.func_name, func_doc=self.func_doc,
|
198
|
+
code_to_modify=self.code_to_modify)
|
60
199
|
return shell
|
61
200
|
|
201
|
+
def build_func_name_and_doc(self) -> Tuple[str, str]:
|
202
|
+
"""
|
203
|
+
Build the function name and docstring for the user-defined function.
|
204
|
+
|
205
|
+
:return: A tuple containing the function name and docstring.
|
206
|
+
"""
|
207
|
+
case = self.scope['case']
|
208
|
+
case_type = case._obj_type if isinstance(case, Case) else type(case)
|
209
|
+
self.func_name = self.build_func_name(case_type)
|
210
|
+
self.func_doc = self.build_func_doc(case_type)
|
211
|
+
|
212
|
+
def build_func_doc(self, case_type: Type):
|
213
|
+
if self.prompt_for == PromptFor.Conditions:
|
214
|
+
func_doc = (f"Get conditions on whether it's possible to conclude a value"
|
215
|
+
f" for {case_type.__name__}.{self.attribute_name}")
|
216
|
+
else:
|
217
|
+
func_doc = f"Get possible value(s) for {case_type.__name__}.{self.attribute_name}"
|
218
|
+
if is_iterable(self.attribute_type):
|
219
|
+
possible_types = [t.__name__ for t in self.attribute_type if t not in [list, set]]
|
220
|
+
func_doc += f" of types list/set of {' and/or '.join(possible_types)}"
|
221
|
+
else:
|
222
|
+
func_doc += f" of type {self.attribute_type.__name__}"
|
223
|
+
return func_doc
|
224
|
+
|
225
|
+
def build_func_name(self, case_type: Type):
|
226
|
+
func_name = f"get_{self.prompt_for.value.lower()}_for"
|
227
|
+
func_name += f"_{case_type.__name__}"
|
228
|
+
if self.attribute_name is not None:
|
229
|
+
func_name += f"_{self.attribute_name}"
|
230
|
+
if is_iterable(self.attribute_type):
|
231
|
+
output_names = [f"{t.__name__}" for t in self.attribute_type if t not in [list, set]]
|
232
|
+
else:
|
233
|
+
output_names = [self.attribute_type.__name__] if self.attribute_type is not None else None
|
234
|
+
if output_names is not None:
|
235
|
+
func_name += '_of_type_' + '_'.join(output_names)
|
236
|
+
return func_name.lower()
|
237
|
+
|
62
238
|
def run(self):
|
63
239
|
"""
|
64
240
|
Run the embedded shell.
|
@@ -83,7 +259,12 @@ class IPythonShell:
|
|
83
259
|
if len(self.all_code_lines) == 1 and self.all_code_lines[0].strip() == '':
|
84
260
|
self.user_input = None
|
85
261
|
else:
|
262
|
+
import pdb; pdb.set_trace()
|
86
263
|
self.user_input = '\n'.join(self.all_code_lines)
|
264
|
+
self.user_input = encapsulate_user_input(self.user_input, self.shell.my_magics.function_signature,
|
265
|
+
self.func_doc)
|
266
|
+
if f"return {self.func_name}(case)" not in self.user_input:
|
267
|
+
self.user_input = self.user_input.strip() + f"\nreturn {self.func_name}(case)"
|
87
268
|
|
88
269
|
|
89
270
|
def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, prompt_str: Optional[str] = None)\
|
@@ -96,8 +277,11 @@ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, pro
|
|
96
277
|
:param prompt_str: The prompt string to display to the user.
|
97
278
|
:return: A callable expression that takes a case and executes user expression on it.
|
98
279
|
"""
|
280
|
+
prev_user_input: Optional[str] = None
|
99
281
|
while True:
|
100
|
-
user_input, expression_tree = prompt_user_about_case(case_query, prompt_for, prompt_str
|
282
|
+
user_input, expression_tree = prompt_user_about_case(case_query, prompt_for, prompt_str,
|
283
|
+
code_to_modify=prev_user_input)
|
284
|
+
prev_user_input = '\n'.join(user_input.split('\n')[2:-1])
|
101
285
|
if user_input is None:
|
102
286
|
if prompt_for == PromptFor.Conclusion:
|
103
287
|
print("No conclusion provided. Exiting.")
|
@@ -109,7 +293,11 @@ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, pro
|
|
109
293
|
callable_expression = CallableExpression(user_input, conclusion_type, expression_tree=expression_tree,
|
110
294
|
scope=case_query.scope)
|
111
295
|
try:
|
112
|
-
callable_expression(case_query.case)
|
296
|
+
result = callable_expression(case_query.case)
|
297
|
+
result = make_list(result)
|
298
|
+
if len(result) == 0:
|
299
|
+
print(f"The given expression gave an empty result for case {case_query.name}, please modify")
|
300
|
+
continue
|
113
301
|
break
|
114
302
|
except Exception as e:
|
115
303
|
logging.error(e)
|
@@ -118,19 +306,24 @@ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, pro
|
|
118
306
|
|
119
307
|
|
120
308
|
def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor,
|
121
|
-
prompt_str: Optional[str] = None
|
309
|
+
prompt_str: Optional[str] = None,
|
310
|
+
code_to_modify: Optional[str] = None) -> Tuple[Optional[str], Optional[AST]]:
|
122
311
|
"""
|
123
312
|
Prompt the user for input.
|
124
313
|
|
125
314
|
:param case_query: The case query to prompt the user for.
|
126
315
|
:param prompt_for: The type of information the user should provide for the given case.
|
127
316
|
:param prompt_str: The prompt string to display to the user.
|
317
|
+
:param code_to_modify: The code to modify. If given will be used as a start for user to modify.
|
128
318
|
:return: The user input, and the executable expression that was parsed from the user input.
|
129
319
|
"""
|
130
320
|
if prompt_str is None:
|
131
321
|
prompt_str = f"Give {prompt_for} for {case_query.name}"
|
132
322
|
scope = {'case': case_query.case, **case_query.scope}
|
133
|
-
|
323
|
+
output_type = case_query.attribute_type if prompt_for == PromptFor.Conclusion else bool
|
324
|
+
shell = IPythonShell(scope=scope, header=prompt_str, output_type=output_type, prompt_for=prompt_for,
|
325
|
+
attribute_name=case_query.attribute_name, attribute_type=case_query.attribute_type,
|
326
|
+
code_to_modify=code_to_modify)
|
134
327
|
return prompt_user_input_and_parse_to_expression(shell=shell)
|
135
328
|
|
136
329
|
|
ripple_down_rules/rdr.py
CHANGED
@@ -4,25 +4,23 @@ import importlib
|
|
4
4
|
import sys
|
5
5
|
from abc import ABC, abstractmethod
|
6
6
|
from copy import copy
|
7
|
-
from dataclasses import is_dataclass
|
8
7
|
from io import TextIOWrapper
|
9
8
|
from types import ModuleType
|
10
9
|
|
11
10
|
from matplotlib import pyplot as plt
|
12
|
-
from ordered_set import OrderedSet
|
13
11
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
14
12
|
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
|
15
13
|
|
16
14
|
from .datastructures.callable_expression import CallableExpression
|
17
15
|
from .datastructures.case import Case, CaseAttribute, create_case
|
18
16
|
from .datastructures.dataclasses import CaseQuery
|
19
|
-
from .datastructures.enums import MCRDRMode
|
20
|
-
from .experts import Expert
|
17
|
+
from .datastructures.enums import MCRDRMode
|
18
|
+
from .experts import Expert
|
21
19
|
from .helpers import is_matching
|
22
20
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
23
21
|
from .utils import draw_tree, make_set, copy_case, \
|
24
|
-
SubclassJSONSerializer,
|
25
|
-
|
22
|
+
SubclassJSONSerializer, make_list, get_type_from_string, \
|
23
|
+
is_conflicting, update_case, get_imports_from_scope, extract_function_source
|
26
24
|
|
27
25
|
|
28
26
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -102,11 +100,12 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
102
100
|
return self.classify(case)
|
103
101
|
|
104
102
|
@abstractmethod
|
105
|
-
def classify(self, case: Union[Case, SQLTable]) -> Optional[CaseAttribute]:
|
103
|
+
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) -> Optional[CaseAttribute]:
|
106
104
|
"""
|
107
105
|
Classify a case.
|
108
106
|
|
109
107
|
:param case: The case to classify.
|
108
|
+
:param modify_case: Whether to modify the original case attributes with the conclusion or not.
|
110
109
|
:return: The category that the case belongs to.
|
111
110
|
"""
|
112
111
|
pass
|
@@ -124,7 +123,10 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
124
123
|
if case_query is None:
|
125
124
|
raise ValueError("The case query cannot be None.")
|
126
125
|
if case_query.target is None:
|
127
|
-
|
126
|
+
case_query_cp = copy(case_query)
|
127
|
+
self.classify(case_query_cp.case, modify_case=True)
|
128
|
+
expert.ask_for_conclusion(case_query_cp)
|
129
|
+
case_query.target = case_query_cp.target
|
128
130
|
if case_query.target is None:
|
129
131
|
return self.classify(case_query.case)
|
130
132
|
|
@@ -169,24 +171,62 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
169
171
|
self.fig = plt.figure(0)
|
170
172
|
draw_tree(self.start_rule, self.fig)
|
171
173
|
|
172
|
-
@
|
173
|
-
def
|
174
|
-
|
175
|
-
|
174
|
+
@property
|
175
|
+
def type_(self):
|
176
|
+
return self.__class__
|
177
|
+
|
178
|
+
@property
|
179
|
+
def generated_python_file_name(self) -> str:
|
180
|
+
if self._generated_python_file_name is None:
|
181
|
+
self._generated_python_file_name = self._default_generated_python_file_name
|
182
|
+
return self._generated_python_file_name
|
176
183
|
|
177
|
-
|
178
|
-
|
179
|
-
:return: Whether the case has a conclusion or not.
|
184
|
+
@generated_python_file_name.setter
|
185
|
+
def generated_python_file_name(self, value: str):
|
180
186
|
"""
|
181
|
-
|
187
|
+
Set the generated python file name.
|
188
|
+
:param value: The new value for the generated python file name.
|
189
|
+
"""
|
190
|
+
self._generated_python_file_name = value
|
182
191
|
|
183
192
|
@property
|
184
|
-
|
185
|
-
|
193
|
+
@abstractmethod
|
194
|
+
def _default_generated_python_file_name(self) -> str:
|
195
|
+
"""
|
196
|
+
:return: The default generated python file name.
|
197
|
+
"""
|
198
|
+
pass
|
199
|
+
|
200
|
+
@abstractmethod
|
201
|
+
def update_from_python_file(self, package_dir: str):
|
202
|
+
"""
|
203
|
+
Update the rules from the generated python file, that might have been modified by the user.
|
204
|
+
|
205
|
+
:param package_dir: The directory of the package that contains the generated python file.
|
206
|
+
"""
|
207
|
+
pass
|
186
208
|
|
187
209
|
|
188
210
|
class RDRWithCodeWriter(RippleDownRules, ABC):
|
189
211
|
|
212
|
+
def update_from_python_file(self, package_dir: str):
|
213
|
+
"""
|
214
|
+
Update the rules from the generated python file, that might have been modified by the user.
|
215
|
+
|
216
|
+
:param package_dir: The directory of the package that contains the generated python file.
|
217
|
+
"""
|
218
|
+
rule_ids = [r.uid for r in [self.start_rule] + list(self.start_rule.descendants) if r.conditions is not None]
|
219
|
+
condition_func_names = [f'conditions_{rid}' for rid in rule_ids]
|
220
|
+
conclusion_func_names = [f'conclusion_{rid}' for rid in rule_ids]
|
221
|
+
all_func_names = condition_func_names + conclusion_func_names
|
222
|
+
filepath = f"{package_dir}/{self.generated_python_defs_file_name}.py"
|
223
|
+
functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
|
224
|
+
for rule in [self.start_rule] + list(self.start_rule.descendants):
|
225
|
+
if rule.conditions is not None:
|
226
|
+
rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
|
227
|
+
if rule.conclusion is not None:
|
228
|
+
rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
|
229
|
+
|
190
230
|
@abstractmethod
|
191
231
|
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
|
192
232
|
defs_file: Optional[str] = None):
|
@@ -238,31 +278,26 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
238
278
|
"""
|
239
279
|
:return: The imports for the generated python file of the RDR as a string.
|
240
280
|
"""
|
241
|
-
|
281
|
+
defs_imports_list = []
|
242
282
|
for rule in [self.start_rule] + list(self.start_rule.descendants):
|
243
283
|
if not rule.conditions:
|
244
284
|
continue
|
245
285
|
for scope in [rule.conditions.scope, rule.conclusion.scope]:
|
246
286
|
if scope is None:
|
247
287
|
continue
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
continue
|
254
|
-
defs_imports += new_imports
|
255
|
-
imports = ""
|
288
|
+
defs_imports_list.extend(get_imports_from_scope(scope))
|
289
|
+
if self.case_type.__module__ != "builtins":
|
290
|
+
defs_imports_list.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
|
291
|
+
defs_imports = "\n".join(set(defs_imports_list)) + "\n"
|
292
|
+
imports = []
|
256
293
|
if self.case_type.__module__ != "builtins":
|
257
|
-
|
258
|
-
if new_import not in defs_imports:
|
259
|
-
imports += new_import
|
294
|
+
imports.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
|
260
295
|
for conclusion_type in self.conclusion_type:
|
261
296
|
if conclusion_type.__module__ != "builtins":
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
imports
|
297
|
+
imports.append(f"from {conclusion_type.__module__} import {conclusion_type.__name__}")
|
298
|
+
imports.append("from ripple_down_rules.datastructures.case import Case, create_case")
|
299
|
+
imports = set(imports).difference(defs_imports_list)
|
300
|
+
imports = "\n".join(imports) + "\n"
|
266
301
|
return imports, defs_imports
|
267
302
|
|
268
303
|
def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
|
@@ -279,20 +314,6 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
279
314
|
pass
|
280
315
|
return importlib.import_module(name).classify
|
281
316
|
|
282
|
-
@property
|
283
|
-
def generated_python_file_name(self) -> str:
|
284
|
-
if self._generated_python_file_name is None:
|
285
|
-
self._generated_python_file_name = self._default_generated_python_file_name
|
286
|
-
return self._generated_python_file_name
|
287
|
-
|
288
|
-
@generated_python_file_name.setter
|
289
|
-
def generated_python_file_name(self, value: str):
|
290
|
-
"""
|
291
|
-
Set the generated python file name.
|
292
|
-
:param value: The new value for the generated python file name.
|
293
|
-
"""
|
294
|
-
self._generated_python_file_name = value
|
295
|
-
|
296
317
|
@property
|
297
318
|
def _default_generated_python_file_name(self) -> str:
|
298
319
|
"""
|
@@ -390,12 +411,12 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
390
411
|
self.start_rule = SingleClassRule(case_query.conditions, case_query.target, corner_case=case_query.case,
|
391
412
|
conclusion_name=case_query.attribute_name)
|
392
413
|
|
393
|
-
def classify(self, case: Case,
|
414
|
+
def classify(self, case: Case, modify_case: bool = False) -> Optional[Any]:
|
394
415
|
"""
|
395
416
|
Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
|
396
417
|
|
397
418
|
:param case: The case to classify.
|
398
|
-
:param
|
419
|
+
:param modify_case: Whether to modify the original case attributes with the conclusion or not.
|
399
420
|
"""
|
400
421
|
pred = self.evaluate(case)
|
401
422
|
return pred.conclusion(case) if pred is not None and pred.fired else self.default_conclusion
|
@@ -481,7 +502,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
481
502
|
super(MultiClassRDR, self).__init__(start_rule)
|
482
503
|
self.mode: MCRDRMode = mode
|
483
504
|
|
484
|
-
def classify(self, case: Union[Case, SQLTable]) -> Set[Any]:
|
505
|
+
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) -> Set[Any]:
|
485
506
|
evaluated_rule = self.start_rule
|
486
507
|
self.conclusions = []
|
487
508
|
while evaluated_rule:
|
@@ -492,7 +513,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
492
513
|
return make_set(self.conclusions)
|
493
514
|
|
494
515
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None
|
495
|
-
|
516
|
+
, **kwargs) -> Set[Union[CaseAttribute, CallableExpression, None]]:
|
496
517
|
"""
|
497
518
|
Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
|
498
519
|
or missing by comparing the case with the target category if provided.
|
@@ -504,13 +525,13 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
504
525
|
self.conclusions = []
|
505
526
|
self.stop_rule_conditions = None
|
506
527
|
evaluated_rule = self.start_rule
|
507
|
-
|
528
|
+
target_value = make_set(case_query.target_value)
|
508
529
|
while evaluated_rule:
|
509
530
|
next_rule = evaluated_rule(case_query.case)
|
510
531
|
rule_conclusion = evaluated_rule.conclusion(case_query.case)
|
511
532
|
|
512
533
|
if evaluated_rule.fired:
|
513
|
-
if not make_set(rule_conclusion).issubset(
|
534
|
+
if not make_set(rule_conclusion).issubset(target_value):
|
514
535
|
# Rule fired and conclusion is different from target
|
515
536
|
self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule)
|
516
537
|
else:
|
@@ -518,7 +539,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
518
539
|
self.add_conclusion(evaluated_rule, case_query.case)
|
519
540
|
|
520
541
|
if not next_rule:
|
521
|
-
if not make_set(
|
542
|
+
if not make_set(target_value).issubset(make_set(self.conclusions)):
|
522
543
|
# Nothing fired and there is a target that should have been in the conclusions
|
523
544
|
self.add_rule_for_case(case_query, expert)
|
524
545
|
# Have to check all rules again to make sure only this new rule fires
|
@@ -556,12 +577,9 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
556
577
|
|
557
578
|
def _get_imports(self) -> Tuple[str, str]:
|
558
579
|
imports, defs_imports = super()._get_imports()
|
559
|
-
|
560
|
-
if len(conclusion_types) == 1:
|
561
|
-
imports += f"from typing_extensions import Set\n"
|
562
|
-
else:
|
563
|
-
imports += "from typing_extensions import Set, Union\n"
|
580
|
+
imports += f"from typing_extensions import Set, Union\n"
|
564
581
|
imports += "from ripple_down_rules.utils import make_set\n"
|
582
|
+
defs_imports += "from typing_extensions import Union\n"
|
565
583
|
return imports, defs_imports
|
566
584
|
|
567
585
|
def update_start_rule(self, case_query: CaseQuery, expert: Expert):
|
@@ -594,6 +612,8 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
594
612
|
rule_conclusion = evaluated_rule.conclusion(case_query.case)
|
595
613
|
if is_conflicting(rule_conclusion, case_query.target_value):
|
596
614
|
self.stop_conclusion(case_query, expert, evaluated_rule)
|
615
|
+
else:
|
616
|
+
self.add_conclusion(evaluated_rule, case_query.case)
|
597
617
|
|
598
618
|
def stop_conclusion(self, case_query: CaseQuery,
|
599
619
|
expert: Expert, evaluated_rule: MultiClassTopRule):
|
@@ -829,6 +849,15 @@ class GeneralRDR(RippleDownRules):
|
|
829
849
|
start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
|
830
850
|
return cls(start_rules_dict)
|
831
851
|
|
852
|
+
def update_from_python_file(self, package_dir: str) -> None:
|
853
|
+
"""
|
854
|
+
Update the rules from the generated python file, that might have been modified by the user.
|
855
|
+
|
856
|
+
:param package_dir: The directory of the package that contains the generated python file.
|
857
|
+
"""
|
858
|
+
for rdr in self.start_rules_dict.values():
|
859
|
+
rdr.update_from_python_file(package_dir)
|
860
|
+
|
832
861
|
def write_to_python_file(self, file_path: str, postfix: str = "") -> None:
|
833
862
|
"""
|
834
863
|
Write the tree of rules as source code to a file.
|
@@ -864,21 +893,10 @@ class GeneralRDR(RippleDownRules):
|
|
864
893
|
def get_rdr_classifier_from_python_file(self, file_path: str) -> Callable[[Any], Any]:
|
865
894
|
"""
|
866
895
|
:param file_path: The path to the file that contains the RDR classifier function.
|
867
|
-
:param postfix: The postfix to add to the file name.
|
868
896
|
:return: The module that contains the rdr classifier function.
|
869
897
|
"""
|
870
898
|
return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
|
871
899
|
|
872
|
-
@property
|
873
|
-
def generated_python_file_name(self) -> str:
|
874
|
-
if self._generated_python_file_name is None:
|
875
|
-
self._generated_python_file_name = self._default_generated_python_file_name
|
876
|
-
return self._generated_python_file_name
|
877
|
-
|
878
|
-
@generated_python_file_name.setter
|
879
|
-
def generated_python_file_name(self, value: str):
|
880
|
-
self._generated_python_file_name = value
|
881
|
-
|
882
900
|
@property
|
883
901
|
def _default_generated_python_file_name(self) -> str:
|
884
902
|
"""
|
ripple_down_rules/rules.py
CHANGED
@@ -2,17 +2,16 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import re
|
4
4
|
from abc import ABC, abstractmethod
|
5
|
-
from
|
5
|
+
from uuid import uuid4
|
6
6
|
|
7
7
|
from anytree import NodeMixin
|
8
|
+
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
8
9
|
from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
|
9
10
|
|
10
11
|
from .datastructures.callable_expression import CallableExpression
|
11
12
|
from .datastructures.case import Case
|
12
|
-
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
13
13
|
from .datastructures.enums import RDREdge, Stop
|
14
|
-
from .utils import SubclassJSONSerializer,
|
15
|
-
get_rule_conclusion_as_source_code
|
14
|
+
from .utils import SubclassJSONSerializer, conclusion_to_json
|
16
15
|
|
17
16
|
|
18
17
|
class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
@@ -26,7 +25,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
26
25
|
parent: Optional[Rule] = None,
|
27
26
|
corner_case: Optional[Union[Case, SQLTable]] = None,
|
28
27
|
weight: Optional[str] = None,
|
29
|
-
conclusion_name: Optional[str] = None
|
28
|
+
conclusion_name: Optional[str] = None,
|
29
|
+
uid: Optional[str] = None):
|
30
30
|
"""
|
31
31
|
A rule in the ripple down rules classifier.
|
32
32
|
|
@@ -36,6 +36,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
36
36
|
:param corner_case: The corner case that this rule is based on/created from.
|
37
37
|
:param weight: The weight of the rule, which is the type of edge connecting the rule to its parent.
|
38
38
|
:param conclusion_name: The name of the conclusion of the rule.
|
39
|
+
:param uid: The unique id of the rule.
|
39
40
|
"""
|
40
41
|
super(Rule, self).__init__()
|
41
42
|
self.conclusion = conclusion
|
@@ -46,6 +47,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
46
47
|
self.conclusion_name: Optional[str] = conclusion_name
|
47
48
|
self.json_serialization: Optional[Dict[str, Any]] = None
|
48
49
|
self._name: Optional[str] = None
|
50
|
+
# generate a unique id for the rule using uuid4
|
51
|
+
self.uid: str = uid if uid else str(uuid4().int)
|
49
52
|
|
50
53
|
def _post_detach(self, parent):
|
51
54
|
"""
|
@@ -90,15 +93,32 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
90
93
|
conclusion = self.conclusion.user_input
|
91
94
|
else:
|
92
95
|
conclusion = self.conclusion.conclusion
|
93
|
-
conclusion_func, conclusion_func_call = self.
|
96
|
+
conclusion_func, conclusion_func_call = self.get_conclusion_as_source_code(conclusion,
|
97
|
+
parent_indent=parent_indent)
|
94
98
|
if conclusion_func is not None:
|
95
99
|
with open(defs_file, 'a') as f:
|
96
100
|
f.write(conclusion_func.strip() + "\n\n\n")
|
97
101
|
return conclusion_func_call
|
98
102
|
|
99
|
-
|
100
|
-
|
101
|
-
|
103
|
+
def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
|
104
|
+
"""
|
105
|
+
Convert the conclusion of a rule to source code.
|
106
|
+
|
107
|
+
:param conclusion: The conclusion to convert to source code.
|
108
|
+
:param parent_indent: The indentation of the parent rule.
|
109
|
+
:return: The source code of the conclusion as a tuple of strings, one for the function and one for the call.
|
110
|
+
"""
|
111
|
+
if "def " in conclusion:
|
112
|
+
# This means the conclusion is a definition that should be written and then called
|
113
|
+
conclusion_lines = conclusion.split('\n')
|
114
|
+
# use regex to replace the function name
|
115
|
+
new_function_name = f"def conclusion_{self.uid}"
|
116
|
+
conclusion_lines[0] = re.sub(r"def (\w+)", new_function_name, conclusion_lines[0])
|
117
|
+
func_call = f"{parent_indent} return {new_function_name.replace('def ', '')}(case)\n"
|
118
|
+
return "\n".join(conclusion_lines).strip(' '), func_call
|
119
|
+
else:
|
120
|
+
raise ValueError(f"Conclusion is format is not valid, it should be contain a function definition."
|
121
|
+
f" Instead got:\n{conclusion}\n")
|
102
122
|
|
103
123
|
def write_condition_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
|
104
124
|
"""
|
@@ -116,7 +136,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
116
136
|
# This means the conditions are a definition that should be written and then called
|
117
137
|
conditions_lines = self.conditions.user_input.split('\n')
|
118
138
|
# use regex to replace the function name
|
119
|
-
new_function_name = f"def conditions_{
|
139
|
+
new_function_name = f"def conditions_{self.uid}"
|
120
140
|
conditions_lines[0] = re.sub(r"def (\w+)", new_function_name, conditions_lines[0])
|
121
141
|
def_code = "\n".join(conditions_lines)
|
122
142
|
with open(defs_file, 'a') as f:
|
@@ -133,7 +153,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
133
153
|
"parent": self.parent.json_serialization if self.parent else None,
|
134
154
|
"corner_case": SubclassJSONSerializer.to_json_static(self.corner_case),
|
135
155
|
"conclusion_name": self.conclusion_name,
|
136
|
-
"weight": self.weight
|
156
|
+
"weight": self.weight,
|
157
|
+
"uid": self.uid}
|
137
158
|
return json_serialization
|
138
159
|
|
139
160
|
@classmethod
|
@@ -143,7 +164,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
143
164
|
parent=cls.from_json(data["parent"]),
|
144
165
|
corner_case=Case.from_json(data["corner_case"]),
|
145
166
|
conclusion_name=data["conclusion_name"],
|
146
|
-
weight=data["weight"]
|
167
|
+
weight=data["weight"],
|
168
|
+
uid=data["uid"])
|
147
169
|
return loaded_rule
|
148
170
|
|
149
171
|
@property
|
@@ -268,14 +290,6 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
|
|
268
290
|
loaded_rule.alternative = SingleClassRule.from_json(data["alternative"])
|
269
291
|
return loaded_rule
|
270
292
|
|
271
|
-
def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
|
272
|
-
conclusion = str(conclusion)
|
273
|
-
# indent = parent_indent + " " * 4
|
274
|
-
# if '\n' not in conclusion:
|
275
|
-
# return None, f"{indent}return {conclusion}\n"
|
276
|
-
# else:
|
277
|
-
return get_rule_conclusion_as_source_code(self, conclusion, parent_indent=parent_indent)
|
278
|
-
|
279
293
|
def _if_statement_source_code_clause(self) -> str:
|
280
294
|
return "elif" if self.weight == RDREdge.Alternative.value else "if"
|
281
295
|
|
@@ -319,7 +333,7 @@ class MultiClassStopRule(Rule, HasAlternativeRule):
|
|
319
333
|
loaded_rule.alternative = MultiClassStopRule.from_json(data["alternative"])
|
320
334
|
return loaded_rule
|
321
335
|
|
322
|
-
def
|
336
|
+
def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[None, str]:
|
323
337
|
return None, f"{parent_indent}{' ' * 4}pass\n"
|
324
338
|
|
325
339
|
def _if_statement_source_code_clause(self) -> str:
|
@@ -364,20 +378,11 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
|
364
378
|
loaded_rule.alternative = MultiClassTopRule.from_json(data["alternative"])
|
365
379
|
return loaded_rule
|
366
380
|
|
367
|
-
def
|
368
|
-
|
369
|
-
indent = parent_indent + " " * 4
|
370
|
-
# if '\n' not in conclusion_str:
|
371
|
-
# func = None
|
372
|
-
# if is_iterable(conclusion):
|
373
|
-
# conclusion_str = "{" + ", ".join([str(c) for c in conclusion]) + "}"
|
374
|
-
# else:
|
375
|
-
# conclusion_str = "{" + str(conclusion) + "}"
|
376
|
-
# else:
|
377
|
-
func, func_call = get_rule_conclusion_as_source_code(self, conclusion_str, parent_indent=parent_indent)
|
381
|
+
def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[str, str]:
|
382
|
+
func, func_call = super().get_conclusion_as_source_code(str(conclusion), parent_indent=parent_indent)
|
378
383
|
conclusion_str = func_call.replace("return ", "").strip()
|
379
384
|
|
380
|
-
statement = f"{
|
385
|
+
statement = f"{parent_indent} conclusions.update(make_set({conclusion_str}))\n"
|
381
386
|
if self.alternative is None:
|
382
387
|
statement += f"{parent_indent}return conclusions\n"
|
383
388
|
return func, statement
|
ripple_down_rules/utils.py
CHANGED
@@ -36,16 +36,131 @@ import ast
|
|
36
36
|
matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
|
37
37
|
|
38
38
|
|
39
|
-
def
|
39
|
+
def get_imports_from_types(types: List[Type]) -> List[str]:
|
40
|
+
"""
|
41
|
+
Get the import statements for a list of types.
|
42
|
+
|
43
|
+
:param types: The types to get the import statements for.
|
44
|
+
:return: The import statements as a string.
|
45
|
+
"""
|
46
|
+
imports = map(get_import_from_type, types)
|
47
|
+
return list({i for i in imports if i is not None})
|
48
|
+
|
49
|
+
|
50
|
+
def get_import_from_type(type_: Type) -> Optional[str]:
|
51
|
+
"""
|
52
|
+
Get the import statement for a given type.
|
53
|
+
|
54
|
+
:param type_: The type to get the import statement for.
|
55
|
+
:return: The import statement as a string.
|
56
|
+
"""
|
57
|
+
if hasattr(type_, "__module__") and hasattr(type_, "__name__"):
|
58
|
+
if type_.__module__ == "builtins":
|
59
|
+
return
|
60
|
+
return f"from {type_.__module__} import {type_.__name__}"
|
61
|
+
|
62
|
+
|
63
|
+
def get_imports_from_scope(scope: Dict[str, Any]) -> List[str]:
|
64
|
+
"""
|
65
|
+
Get the imports from the given scope.
|
66
|
+
|
67
|
+
:param scope: The scope to get the imports from.
|
68
|
+
:return: The imports as a string.
|
69
|
+
"""
|
70
|
+
imports = []
|
71
|
+
for k, v in scope.items():
|
72
|
+
if not hasattr(v, "__module__") or not hasattr(v, "__name__"):
|
73
|
+
continue
|
74
|
+
imports.append(f"from {v.__module__} import {v.__name__}")
|
75
|
+
return imports
|
76
|
+
|
77
|
+
|
78
|
+
def extract_imports(file_path):
|
79
|
+
with open(file_path, "r") as f:
|
80
|
+
tree = ast.parse(f.read(), filename=file_path)
|
81
|
+
|
82
|
+
scope = {}
|
83
|
+
|
84
|
+
for node in ast.walk(tree):
|
85
|
+
if isinstance(node, ast.Import):
|
86
|
+
for alias in node.names:
|
87
|
+
module_name = alias.name
|
88
|
+
asname = alias.asname or alias.name
|
89
|
+
try:
|
90
|
+
scope[asname] = importlib.import_module(module_name)
|
91
|
+
except ImportError as e:
|
92
|
+
print(f"Could not import {module_name}: {e}")
|
93
|
+
elif isinstance(node, ast.ImportFrom):
|
94
|
+
module_name = node.module
|
95
|
+
for alias in node.names:
|
96
|
+
name = alias.name
|
97
|
+
asname = alias.asname or name
|
98
|
+
try:
|
99
|
+
module = importlib.import_module(module_name)
|
100
|
+
scope[asname] = getattr(module, name)
|
101
|
+
except (ImportError, AttributeError) as e:
|
102
|
+
print(f"Could not import {name} from {module_name}: {e}")
|
103
|
+
|
104
|
+
return scope
|
105
|
+
|
106
|
+
|
107
|
+
def extract_function_source(file_path: str,
|
108
|
+
function_names: List[str], join_lines: bool = True,
|
109
|
+
return_line_numbers: bool = False,
|
110
|
+
include_signature: bool = True) \
|
111
|
+
-> Union[Dict[str, Union[str, List[str]]],
|
112
|
+
Tuple[Dict[str, Union[str, List[str]]], List[Tuple[int, int]]]]:
|
113
|
+
"""
|
114
|
+
Extract the source code of a function from a file.
|
115
|
+
|
116
|
+
:param file_path: The path to the file.
|
117
|
+
:param function_names: The names of the functions to extract.
|
118
|
+
:param join_lines: Whether to join the lines of the function.
|
119
|
+
:param return_line_numbers: Whether to return the line numbers of the function.
|
120
|
+
:param include_signature: Whether to include the function signature in the source code.
|
121
|
+
:return: A dictionary mapping function names to their source code as a string if join_lines is True,
|
122
|
+
otherwise as a list of strings.
|
123
|
+
"""
|
124
|
+
with open(file_path, "r") as f:
|
125
|
+
source = f.read()
|
126
|
+
|
127
|
+
# Parse the source code into an AST
|
128
|
+
tree = ast.parse(source)
|
129
|
+
function_names = make_list(function_names)
|
130
|
+
functions_source: Dict[str, Union[str, List[str]]] = {}
|
131
|
+
line_numbers = []
|
132
|
+
for node in tree.body:
|
133
|
+
if isinstance(node, ast.FunctionDef) and node.name in function_names:
|
134
|
+
# Get the line numbers of the function
|
135
|
+
lines = source.splitlines()
|
136
|
+
func_lines = lines[node.lineno - 1:node.end_lineno]
|
137
|
+
if not include_signature:
|
138
|
+
func_lines = func_lines[1:]
|
139
|
+
line_numbers.append((node.lineno, node.end_lineno))
|
140
|
+
functions_source[node.name] = "\n".join(func_lines) if join_lines else func_lines
|
141
|
+
if len(functions_source) == len(function_names):
|
142
|
+
break
|
143
|
+
if len(functions_source) != len(function_names):
|
144
|
+
raise ValueError(f"Could not find all functions in {file_path}: {function_names} not found,"
|
145
|
+
f"functions not found: {set(function_names) - set(functions_source.keys())}")
|
146
|
+
if return_line_numbers:
|
147
|
+
return functions_source, line_numbers
|
148
|
+
return functions_source
|
149
|
+
|
150
|
+
|
151
|
+
def encapsulate_user_input(user_input: str, func_signature: str, func_doc: Optional[str] = None) -> str:
|
40
152
|
"""
|
41
153
|
Encapsulate the user input string with a function definition.
|
42
154
|
|
43
155
|
:param user_input: The user input string.
|
44
156
|
:param func_signature: The function signature to use for encapsulation.
|
157
|
+
:param func_doc: The function docstring to use for encapsulation.
|
45
158
|
:return: The encapsulated user input string.
|
46
159
|
"""
|
47
160
|
if func_signature not in user_input:
|
48
161
|
new_user_input = func_signature + "\n "
|
162
|
+
if func_doc is not None:
|
163
|
+
new_user_input += f"\"\"\"{func_doc}\"\"\"" + "\n "
|
49
164
|
if "return " not in user_input:
|
50
165
|
if '\n' not in user_input:
|
51
166
|
new_user_input += f"return {user_input}"
|
@@ -173,29 +288,6 @@ def calculate_precision_and_recall(pred_cat: Dict[str, Any], target: Dict[str, A
|
|
173
288
|
return precision, recall
|
174
289
|
|
175
290
|
|
176
|
-
def get_rule_conclusion_as_source_code(rule: Rule, conclusion: str, parent_indent: str = "") -> Tuple[str, str]:
|
177
|
-
"""
|
178
|
-
Convert the conclusion of a rule to source code.
|
179
|
-
|
180
|
-
:param rule: The rule to get the conclusion from.
|
181
|
-
:param conclusion: The conclusion to convert to source code.
|
182
|
-
:param parent_indent: The indentation to use for the source code.
|
183
|
-
:return: The source code of the conclusion as a tuple of strings, one for the function and one for the call.
|
184
|
-
"""
|
185
|
-
indent = f"{parent_indent}{' ' * 4}"
|
186
|
-
if "def " in conclusion:
|
187
|
-
# This means the conclusion is a definition that should be written and then called
|
188
|
-
conclusion_lines = conclusion.split('\n')
|
189
|
-
# use regex to replace the function name
|
190
|
-
new_function_name = f"def conclusion_{id(rule)}"
|
191
|
-
conclusion_lines[0] = re.sub(r"def (\w+)", new_function_name, conclusion_lines[0])
|
192
|
-
func_call = f"{indent}return {new_function_name.replace('def ', '')}(case)\n"
|
193
|
-
return "\n".join(conclusion_lines).strip(' '), func_call
|
194
|
-
else:
|
195
|
-
raise ValueError(f"Conclusion is format is not valid, it should be a one line string or "
|
196
|
-
f"contain a function definition. Instead got:\n{conclusion}\n")
|
197
|
-
|
198
|
-
|
199
291
|
def ask_llm(prompt):
|
200
292
|
try:
|
201
293
|
response = requests.post("http://localhost:11434/api/generate", json={
|
@@ -317,7 +409,13 @@ def extract_dependencies(code_lines):
|
|
317
409
|
|
318
410
|
for stmt in reversed(tree.body[:-1]):
|
319
411
|
if handle_stmt(stmt, needed):
|
320
|
-
|
412
|
+
# check if the statement is a function, if so then all its lines not just the line in line_map are needed.
|
413
|
+
if isinstance(stmt, ast.FunctionDef):
|
414
|
+
start_code_line = line_map[id(stmt)]
|
415
|
+
end_code_line = start_code_line + stmt.end_lineno
|
416
|
+
required_lines.extend(code_lines[start_code_line:end_code_line])
|
417
|
+
else:
|
418
|
+
required_lines.insert(0, code_lines[line_map[id(stmt)]])
|
321
419
|
|
322
420
|
required_lines.append(code_lines[-1]) # Always include return
|
323
421
|
return required_lines
|
@@ -749,9 +847,12 @@ def copy_orm_instance(instance: SQLTable) -> SQLTable:
|
|
749
847
|
:return: The copied instance.
|
750
848
|
"""
|
751
849
|
session: Session = inspect(instance).session
|
752
|
-
session
|
753
|
-
|
754
|
-
|
850
|
+
if session is not None:
|
851
|
+
session.expunge(instance)
|
852
|
+
new_instance = deepcopy(instance)
|
853
|
+
session.add(instance)
|
854
|
+
else:
|
855
|
+
new_instance = instance
|
755
856
|
return new_instance
|
756
857
|
|
757
858
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -0,0 +1,20 @@
|
|
1
|
+
ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
+
ripple_down_rules/datasets.py,sha256=mjJh1GLD_5qMgHaukdDWSGphXS9k_BPEF001ZXPchr8,4687
|
3
|
+
ripple_down_rules/experts.py,sha256=JGVvSNiWhm4FpRpg76f98tl8Ii_C7x_aWD9FxD-JDLQ,6130
|
4
|
+
ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
|
5
|
+
ripple_down_rules/helpers.py,sha256=TvTJU0BA3dPcAyzvZFvAu7jZqsp8Lu0HAAwvuizlGjg,2018
|
6
|
+
ripple_down_rules/prompt.py,sha256=sab5yqSJaMGgqfOsXybzZhbKx2vyU0h5dfcSqsQuY7E,16118
|
7
|
+
ripple_down_rules/rdr.py,sha256=vxNZckp6sLAUD92JQgfCzhBhg9CXfMZ_7W4VgrIUFjU,43366
|
8
|
+
ripple_down_rules/rdr_decorators.py,sha256=8SclpceI3EtrsbuukWJu8HGLh7Q1ZCgYGLX-RPlG-w0,2018
|
9
|
+
ripple_down_rules/rules.py,sha256=QQy7NBG6mKiowxVG_LjQJBxLTDW2hMyx5zAgwUxdCMM,17183
|
10
|
+
ripple_down_rules/utils.py,sha256=Q5VWgisHCFxEA06Y0ImSfu7cozVJCHTCDYGWbDJNlgA,43516
|
11
|
+
ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
|
12
|
+
ripple_down_rules/datastructures/callable_expression.py,sha256=1hwRjS8-csfZ0HSlhLLOBCcvKJ8W3_N1qHayFLNhl3k,10908
|
13
|
+
ripple_down_rules/datastructures/case.py,sha256=nJDKOjyhGLx4gt0sHxJNxBLdy9X2SLcDn89_SmKzwoc,14035
|
14
|
+
ripple_down_rules/datastructures/dataclasses.py,sha256=TAOAeEvh0BeTis3rEHu8rpCeqNNhU0vK3to0JaBwTio,5961
|
15
|
+
ripple_down_rules/datastructures/enums.py,sha256=hlE6LAa1jUafnG_6UazvaPDfhC1ClI7hKvD89zOyAO8,4661
|
16
|
+
ripple_down_rules-0.2.0.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
17
|
+
ripple_down_rules-0.2.0.dist-info/METADATA,sha256=FZV_PqD3Yc6Qw1YCPAZowkycnAa6OzdcyMlpT7w8AS4,42575
|
18
|
+
ripple_down_rules-0.2.0.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
|
19
|
+
ripple_down_rules-0.2.0.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
20
|
+
ripple_down_rules-0.2.0.dist-info/RECORD,,
|
@@ -1,20 +0,0 @@
|
|
1
|
-
ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
ripple_down_rules/datasets.py,sha256=mjJh1GLD_5qMgHaukdDWSGphXS9k_BPEF001ZXPchr8,4687
|
3
|
-
ripple_down_rules/experts.py,sha256=JGVvSNiWhm4FpRpg76f98tl8Ii_C7x_aWD9FxD-JDLQ,6130
|
4
|
-
ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
|
5
|
-
ripple_down_rules/helpers.py,sha256=TvTJU0BA3dPcAyzvZFvAu7jZqsp8Lu0HAAwvuizlGjg,2018
|
6
|
-
ripple_down_rules/prompt.py,sha256=gdLV2qhq-1kbmlJhVmkCwGGFhIeDxV0p5u2hwrbGUpk,6370
|
7
|
-
ripple_down_rules/rdr.py,sha256=E8YXSU_VMd3e_LJItjzbBjr0g1KXhhU_HL4FMAgdczc,42203
|
8
|
-
ripple_down_rules/rdr_decorators.py,sha256=8SclpceI3EtrsbuukWJu8HGLh7Q1ZCgYGLX-RPlG-w0,2018
|
9
|
-
ripple_down_rules/rules.py,sha256=WfMWzgfI_5Tqv8a2k7jkpGdvwu-zYCG36EmpCl7GiEQ,16576
|
10
|
-
ripple_down_rules/utils.py,sha256=INGRVIUH-SgssUW2T9gUMRf-XyJ3rxY_kuJUxfbsyOg,39683
|
11
|
-
ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
|
12
|
-
ripple_down_rules/datastructures/callable_expression.py,sha256=JLd8ZdIvAGX3mm-tID0buZIlbMF6hW2Z_jn5KA7X_ws,9788
|
13
|
-
ripple_down_rules/datastructures/case.py,sha256=nJDKOjyhGLx4gt0sHxJNxBLdy9X2SLcDn89_SmKzwoc,14035
|
14
|
-
ripple_down_rules/datastructures/dataclasses.py,sha256=TAOAeEvh0BeTis3rEHu8rpCeqNNhU0vK3to0JaBwTio,5961
|
15
|
-
ripple_down_rules/datastructures/enums.py,sha256=RdyPUp9Ls1QuLmkcMMkBbCWrmXIZI4xWuM-cLPYZhR0,4666
|
16
|
-
ripple_down_rules-0.1.69.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
17
|
-
ripple_down_rules-0.1.69.dist-info/METADATA,sha256=hjuztFZVWpCOPfWULTsRdVCrfeMo8nc2XI4DdW4GJ9g,42576
|
18
|
-
ripple_down_rules-0.1.69.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
|
19
|
-
ripple_down_rules-0.1.69.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
20
|
-
ripple_down_rules-0.1.69.dist-info/RECORD,,
|
File without changes
|
File without changes
|