ripple-down-rules 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datasets.py +66 -6
- ripple_down_rules/datastructures/callable_expression.py +13 -5
- ripple_down_rules/datastructures/case.py +33 -5
- ripple_down_rules/datastructures/dataclasses.py +30 -8
- ripple_down_rules/datastructures/enums.py +30 -1
- ripple_down_rules/experts.py +2 -1
- ripple_down_rules/prompt.py +215 -109
- ripple_down_rules/rdr.py +7 -5
- ripple_down_rules/rdr_decorators.py +122 -38
- ripple_down_rules/utils.py +162 -18
- {ripple_down_rules-0.2.4.dist-info → ripple_down_rules-0.3.0.dist-info}/METADATA +1 -1
- ripple_down_rules-0.3.0.dist-info/RECORD +20 -0
- {ripple_down_rules-0.2.4.dist-info → ripple_down_rules-0.3.0.dist-info}/WHEEL +1 -1
- ripple_down_rules-0.2.4.dist-info/RECORD +0 -20
- {ripple_down_rules-0.2.4.dist-info → ripple_down_rules-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.2.4.dist-info → ripple_down_rules-0.3.0.dist-info}/top_level.txt +0 -0
ripple_down_rules/prompt.py
CHANGED
@@ -1,91 +1,222 @@
|
|
1
1
|
import ast
|
2
2
|
import logging
|
3
3
|
import os
|
4
|
+
import shutil
|
5
|
+
import socket
|
4
6
|
import subprocess
|
5
7
|
import tempfile
|
6
8
|
from _ast import AST
|
7
9
|
from functools import cached_property
|
8
10
|
from textwrap import indent, dedent
|
9
11
|
|
10
|
-
from IPython.core.magic import
|
12
|
+
from IPython.core.magic import line_magic, Magics, magics_class
|
11
13
|
from IPython.terminal.embed import InteractiveShellEmbed
|
14
|
+
from colorama import Fore, Style
|
12
15
|
from pygments import highlight
|
13
16
|
from pygments.formatters.terminal import TerminalFormatter
|
14
17
|
from pygments.lexers.python import PythonLexer
|
15
18
|
from traitlets.config import Config
|
16
|
-
from typing_extensions import List, Optional, Tuple, Dict, Type
|
19
|
+
from typing_extensions import List, Optional, Tuple, Dict, Type
|
17
20
|
|
18
|
-
from .datastructures.enums import PromptFor
|
19
|
-
from .datastructures.case import Case
|
20
21
|
from .datastructures.callable_expression import CallableExpression, parse_string_to_expression
|
22
|
+
from .datastructures.case import Case
|
21
23
|
from .datastructures.dataclasses import CaseQuery
|
22
|
-
from .
|
23
|
-
|
24
|
-
|
25
|
-
|
24
|
+
from .datastructures.enums import PromptFor, Editor
|
25
|
+
from .utils import extract_dependencies, contains_return_statement, get_imports_from_scope, make_list, \
|
26
|
+
get_imports_from_types, extract_function_source, encapsulate_user_input, str_to_snake_case, typing_hint_to_str
|
27
|
+
|
28
|
+
|
29
|
+
def detect_available_editor() -> Optional[Editor]:
    """
    Find an editor that can be used to edit the user's code.

    The ``RDR_EDITOR`` environment variable takes precedence when it is set to a non-empty value; otherwise
    PyCharm, VSCode and code-server are tried, in that order, by looking their executables up on ``PATH``.

    :return: The chosen editor, or None when no candidate executable is found.
    """
    requested = os.environ.get("RDR_EDITOR")
    if requested:
        return Editor.from_str(requested)
    candidates = (Editor.Pycharm, Editor.Code, Editor.CodeServer)
    return next((candidate for candidate in candidates if shutil.which(candidate.value)), None)
|
42
|
+
|
43
|
+
|
44
|
+
def is_port_in_use(port: int = 8080) -> bool:
    """
    Check whether something accepts TCP connections on ``localhost:port``.

    :param port: The TCP port to probe.
    :return: True if a connection could be established, False otherwise.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        error_code = probe.connect_ex(("localhost", port))
    return error_code == 0
|
47
|
+
|
48
|
+
|
49
|
+
def start_code_server(workspace: str) -> subprocess.Popen:
    """
    Start the code-server in the given workspace.

    The server is launched by running the ``start-code-server.sh`` script located next to this module with bash.

    :param workspace: The directory passed to the start script as its only argument.
    :return: The handle of the started process (stdout and stderr are piped, text mode).
    """
    filename = os.path.join(os.path.dirname(__file__), "start-code-server.sh")
    # Set the executable bit without going through a shell: the previous ``os.system(f"chmod +x {filename}")``
    # broke on paths containing spaces or shell metacharacters. The script is run through /bin/bash below, so the
    # bit is not strictly required and failing to set it (e.g. a read-only install) is deliberately not fatal.
    try:
        os.chmod(filename, os.stat(filename).st_mode | 0o111)
    except OSError:
        pass
    print(f"Starting code-server at {filename}")
    # NOTE(review): stdout/stderr are piped, but the code visible alongside this function never reads them; a
    # long-running, chatty server may block once a pipe buffer fills. Consider draining them or redirecting.
    return subprocess.Popen(["/bin/bash", filename, workspace], stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE, text=True)
|
26
58
|
|
27
59
|
|
28
60
|
@magics_class
|
29
61
|
class MyMagics(Magics):
|
30
|
-
|
31
|
-
|
62
|
+
# Path of the file the user edits; the instance value is assigned in __init__.
temp_file_path: Optional[str] = None
# Port used for the code-server URL; read from RDR_EDITOR_PORT (default 8080) once, when the class is created.
port: int = int(os.environ.get("RDR_EDITOR_PORT", 8080))
# Handle of the code-server process started by open_file_in_editor, if any.
process: Optional[subprocess.Popen] = None
|
74
|
+
|
75
|
+
def __init__(self, shell, scope,
             code_to_modify: Optional[str] = None,
             prompt_for: Optional[PromptFor] = None,
             case_query: Optional[CaseQuery] = None):
    """
    Set up the magics used to write the user's function in an external editor.

    :param shell: The IPython shell these magics belong to.
    :param scope: The shell's namespace; used for imports and to locate the default workspace.
    :param code_to_modify: Existing code pre-filled as the function body, if any.
    :param prompt_for: Whether conditions or a conclusion are being asked for.
    :param case_query: The case query the function is written for.
    """
    super().__init__(shell)
    self.scope = scope
    self.code_to_modify = code_to_modify
    self.prompt_for = prompt_for
    self.case_query = case_query
    self.output_type = self.get_output_type()
    self.user_edit_line = 0
    self.func_name: str = self.get_func_name()
    self.func_doc: str = self.get_func_doc()
    self.function_signature: str = self.get_function_signature()
    self.editor: Optional[Editor] = detect_available_editor()
    self.workspace: str = self._get_workspace()
    self.temp_file_path: str = os.path.join(self.workspace, "edit_code_here.py")

def _get_workspace(self) -> str:
    """
    Return the directory in which the file to edit is created.

    :return: ``RDR_EDITOR_WORKSPACE`` if set, else the directory of the scope's ``__file__``, else the current
     working directory.
    """
    # Look at the scope only when the variable is unset: passing the scope lookup as the default of
    # os.environ.get evaluated it eagerly and raised KeyError even when RDR_EDITOR_WORKSPACE was set.
    workspace = os.environ.get("RDR_EDITOR_WORKSPACE")
    if workspace is not None:
        return workspace
    script_file = self.scope.get('__file__')
    if script_file is None:
        # e.g. an interactive namespace without __file__.
        return os.getcwd()
    return os.path.dirname(script_file)
|
92
|
+
|
93
|
+
def get_output_type(self) -> Optional[List[Type]]:
    """
    Determine the type(s) the user's function is expected to return.

    :return: ``[bool]`` when prompting for conditions; otherwise the case query's attribute type(s) as a list.
     None when there is no case query or it has no attribute type.
    """
    if self.prompt_for == PromptFor.Conditions:
        output_type = bool
    elif self.case_query is not None:
        output_type = self.case_query.attribute_type
    else:
        # Nothing to infer a conclusion type from; previously the shell also used None in this case.
        output_type = None
    return make_list(output_type) if output_type is not None else None
|
47
102
|
|
48
103
|
@line_magic
def edit(self, line):
    """
    ``%edit`` line magic: write the boilerplate function to the edit file and open it in the detected editor.

    :param line: The remainder of the magic's command line (not used).
    """
    if self.editor is not None:
        code = self.build_boilerplate_code()
        self.write_to_file(code)
        self.open_file_in_editor()
        return
    print(f"{Fore.RED}ERROR:: No editor found. Please install PyCharm, VSCode or code-server.{Style.RESET_ALL}")
|
113
|
+
|
114
|
+
def open_file_in_editor(self):
    """
    Open the file in the available editor.

    PyCharm (at ``self.user_edit_line``) and VSCode are launched directly on the file. For code-server, a server is
    started in the workspace unless one is already running and ``self.port`` is in use, and the URL to open in a
    browser is printed.
    """
    if self.editor == Editor.Pycharm:
        subprocess.Popen(["pycharm", "--line", str(self.user_edit_line), self.temp_file_path],
                         stdout=subprocess.DEVNULL,
                         stderr=subprocess.DEVNULL)
    elif self.editor == Editor.Code:
        subprocess.Popen(["code", self.temp_file_path])
    elif self.editor == Editor.CodeServer:
        if self._is_code_server_running():
            print(f"Code-server is already running on port {self.port}.")
        else:
            self.process = start_code_server(self.workspace)
        print(f"Open code-server in your browser at http://localhost:{self.port}?folder={self.workspace}")
    # Reset the colour so the rest of the terminal output is not printed in blue.
    print(f"Edit the file: {Fore.BLUE}{self.temp_file_path}{Style.RESET_ALL}")

def _is_code_server_running(self) -> bool:
    """
    Return True if a code-server process exists and ``self.port`` is in use.
    """
    try:
        # pgrep exits with a non-zero status (CalledProcessError) when no process matches.
        subprocess.check_output(["pgrep", "-f", "code-server"])
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError: pgrep itself is unavailable; treat as "not running" instead of crashing.
        return False
    return is_port_in_use(self.port)
|
59
136
|
|
60
137
|
def build_boilerplate_code(self):
    """
    Assemble the source shown to the user: imports, function signature, docstring and body.

    The signature and docstring are (re)computed if they are None. As a side effect ``self.user_edit_line`` is set
    from the number of import lines; it is the line passed to the editor.

    :return: The boilerplate source code.
    """
    imports = self.get_imports()
    if self.function_signature is None:
        self.function_signature = self.get_function_signature()
    if self.func_doc is None:
        self.func_doc = self.get_func_doc()
    body = (" # Write your code here\n pass" if self.code_to_modify is None
            else indent(dedent(self.code_to_modify), ' '))
    self.user_edit_line = imports.count('\n') + 6
    return f"""{imports}\n\n{self.function_signature}\n \"\"\"{self.func_doc}\"\"\"\n{body}"""
|
70
150
|
|
71
|
-
def
|
151
|
+
def get_function_signature(self) -> str:
    """
    Build the ``def`` line of the function the user writes.

    :return: ``def <name>(<args>)<return hint>:``; the return hint part may be an empty string.
    """
    if self.func_name is None:
        self.func_name = self.get_func_name()
    return_hint = self.get_output_type_hint()
    arguments = self.get_func_args()
    signature = f"def {self.func_name}({arguments}){return_hint}:"
    return signature
|
157
|
+
|
158
|
+
def get_output_type_hint(self) -> str:
    """
    Build the return annotation of the function the user writes.

    :return: `` -> bool`` for conditions, `` -> <attribute type hint>`` for a conclusion, otherwise ``""``.
    """
    if self.prompt_for == PromptFor.Conditions:
        return " -> bool"
    if self.prompt_for == PromptFor.Conclusion:
        return f" -> {self.case_query.attribute_type_hint}"
    return ""
|
168
|
+
|
169
|
+
def get_func_args(self) -> str:
    """
    Build the parameter list of the function the user writes.

    For a function case query, each entry of the case becomes a parameter. It is annotated with its declared type
    hint when one exists; otherwise with its value's type name, or ``Type[...]`` for class objects. Parameters whose
    annotation would be ``None``/``NoneType`` are left unannotated. For other case queries the single parameter
    ``case`` annotated with the case type is used.

    :return: The comma-separated parameter list, without parentheses.
    """
    if not self.case_query.is_function:
        return f"case: {self.case_type.__name__}"
    declared_hints = self.case_query.function_args_type_hints
    parameters = []
    for arg_name, arg_value in self.case_query.case.items():
        if declared_hints is not None and arg_name in declared_hints:
            annotation = typing_hint_to_str(declared_hints[arg_name])[0]
        elif isinstance(arg_value, type):
            annotation = f"Type[{arg_value.__name__}]"
        else:
            annotation = type(arg_value).__name__
        if str(annotation) in ("NoneType", "None"):
            parameters.append(str(arg_name))
        else:
            parameters.append(f"{arg_name}: {annotation}")
    return ', '.join(parameters)
|
78
186
|
|
79
187
|
def write_to_file(self, code: str):
    """
    Write ``code`` to the file the user edits.

    If no path is set yet, a new ``.py`` file is created in the workspace (not deleted on close) and its path is
    stored in ``self.temp_file_path``; otherwise the file at that path is overwritten.

    :param code: The source code to write.
    """
    if self.temp_file_path is None:
        # The context manager closes the handle even if writing fails; delete=False keeps the file afterwards.
        with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix=".py", dir=self.workspace) as tmp:
            tmp.write(code)
            self.temp_file_path = tmp.name
    else:
        with open(self.temp_file_path, 'w+') as f:
            f.write(code)
|
86
198
|
|
87
199
|
def get_imports(self):
|
88
|
-
|
200
|
+
"""
|
201
|
+
:return: A string containing the imports for the function.
|
202
|
+
"""
|
203
|
+
case_type_imports = []
|
204
|
+
if self.case_query.is_function:
|
205
|
+
for k, v in self.case_query.case.items():
|
206
|
+
if (self.case_query.function_args_type_hints is not None
|
207
|
+
and k in self.case_query.function_args_type_hints):
|
208
|
+
hint_list = typing_hint_to_str(self.case_query.function_args_type_hints[k])[1]
|
209
|
+
for hint in hint_list:
|
210
|
+
hint_split = hint.split('.')
|
211
|
+
if len(hint_split) > 1:
|
212
|
+
case_type_imports.append(f"from {'.'.join(hint_split[:-1])} import {hint_split[-1]}")
|
213
|
+
else:
|
214
|
+
if isinstance(v, type):
|
215
|
+
case_type_imports.append(f"from {v.__module__} import {v.__name__}")
|
216
|
+
elif hasattr(v, "__module__") and not v.__module__.startswith("__"):
|
217
|
+
case_type_imports.append(f"\nfrom {type(v).__module__} import {type(v).__name__}")
|
218
|
+
else:
|
219
|
+
case_type_imports.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
|
89
220
|
if self.output_type is None:
|
90
221
|
output_type_imports = [f"from typing_extensions import Any"]
|
91
222
|
else:
|
@@ -96,12 +227,32 @@ class MyMagics(Magics):
|
|
96
227
|
output_type_imports.append("from typing_extensions import List")
|
97
228
|
imports = get_imports_from_scope(self.scope)
|
98
229
|
imports = [i for i in imports if ("get_ipython" not in i)]
|
99
|
-
|
100
|
-
imports.append(case_type_import)
|
230
|
+
imports.extend(case_type_imports)
|
101
231
|
imports.extend([oti for oti in output_type_imports if oti not in imports])
|
102
232
|
imports = set(imports)
|
103
233
|
return '\n'.join(imports)
|
104
234
|
|
235
|
+
def get_func_doc(self) -> Optional[str]:
    """
    Build the docstring of the function the user writes.

    :return: A sentence naming the case query, phrased for conditions or for conclusion values.
    """
    target = self.case_query.name
    if self.prompt_for != PromptFor.Conditions:
        return f"Get possible value(s) for {target}"
    return f"Get conditions on whether it's possible to conclude a value for {target}"
|
244
|
+
|
245
|
+
def get_func_name(self) -> str:
    """
    Build the name of the function the user writes.

    The name is the case query name with dots replaced by underscores. When prompting for conditions it is prefixed
    with ``<prompt_for value, lower case>_for_``. The result is converted to snake_case.

    :return: The snake_case function name.
    """
    func_name = ""
    if self.prompt_for == PromptFor.Conditions:
        func_name = f"{self.prompt_for.value.lower()}_for_"
    case_name = self.case_query.name.replace(".", "_")
    if self.case_query.is_function:
        # Remove only a trailing "_<attribute_name>"; str.replace also stripped matching text inside the function's
        # own name. NOTE(review): assumes the attribute name is appended at the end of case_query.name for function
        # queries — confirm against CaseQuery.name.
        suffix = f"_{self.case_query.attribute_name}"
        if case_name.endswith(suffix):
            case_name = case_name[:-len(suffix)]
    func_name += case_name
    # CamelCase -> snake_case conversion happens here.
    return str_to_snake_case(func_name)
|
255
|
+
|
105
256
|
@cached_property
|
106
257
|
def case_type(self) -> Type:
|
107
258
|
"""
|
@@ -150,16 +301,20 @@ Loads the function defined in the temporary file into the user namespace, that c
|
|
150
301
|
"""
|
151
302
|
print(help_text)
|
152
303
|
|
304
|
+
def __del__(self):
    """
    Stop the code-server process started by these magics if it is still running.

    The process is asked to terminate and given a bounded time to exit, then killed, so that finalization cannot
    block forever on a server that ignores the termination request.
    """
    process = getattr(self, 'process', None)
    if process is None or process.poll() is not None:
        return
    process.terminate()  # Graceful shutdown
    try:
        process.wait(timeout=5)
    except subprocess.TimeoutExpired:
        process.kill()  # Did not exit in time; force it.
        process.wait()  # Ensure cleanup
|
308
|
+
|
153
309
|
|
154
310
|
class CustomInteractiveShell(InteractiveShellEmbed):
|
155
|
-
def __init__(self, code_to_modify: Optional[str] = None,
             prompt_for: Optional[PromptFor] = None,
             case_query: Optional[CaseQuery] = None,
             **kwargs):
    """
    Create the embedded shell and register the editing magics on it.

    :param code_to_modify: Code pre-filled as the body of the function the user edits, if any.
    :param prompt_for: Whether conditions or a conclusion are being asked for.
    :param case_query: The case query the user writes the function for.
    :param kwargs: Forwarded unchanged to ``InteractiveShellEmbed.__init__``.
    """
    super().__init__(**kwargs)
    magics = MyMagics(self, self.user_ns,
                      code_to_modify=code_to_modify,
                      prompt_for=prompt_for,
                      case_query=case_query)
    self.my_magics = magics
    self.register_magics(magics)
    self.all_lines = []
|
165
320
|
|
@@ -201,19 +356,10 @@ class IPythonShell:
|
|
201
356
|
"""
|
202
357
|
self.scope: Dict = scope or {}
|
203
358
|
self.header: str = header or ">>> Embedded Ipython Shell"
|
204
|
-
output_type = None
|
205
|
-
if prompt_for is not None:
|
206
|
-
if prompt_for == PromptFor.Conclusion and case_query is not None:
|
207
|
-
output_type = case_query.attribute_type
|
208
|
-
elif prompt_for == PromptFor.Conditions:
|
209
|
-
output_type = bool
|
210
359
|
self.case_query: Optional[CaseQuery] = case_query
|
211
|
-
self.output_type: Optional[Type] = output_type
|
212
360
|
self.prompt_for: Optional[PromptFor] = prompt_for
|
213
361
|
self.code_to_modify: Optional[str] = code_to_modify
|
214
362
|
self.user_input: Optional[str] = None
|
215
|
-
self.func_name: str = ""
|
216
|
-
self.func_doc: str = ""
|
217
363
|
self.shell: CustomInteractiveShell = self._init_shell()
|
218
364
|
self.all_code_lines: List[str] = []
|
219
365
|
|
@@ -222,58 +368,13 @@ class IPythonShell:
|
|
222
368
|
Initialize the Ipython shell with a custom configuration.
|
223
369
|
"""
|
224
370
|
cfg = Config()
|
225
|
-
self.build_func_name_and_doc()
|
226
371
|
shell = CustomInteractiveShell(config=cfg, user_ns=self.scope, banner1=self.header,
|
227
|
-
output_type=self.output_type, func_name=self.func_name, func_doc=self.func_doc,
|
228
372
|
code_to_modify=self.code_to_modify,
|
229
|
-
|
230
|
-
|
373
|
+
prompt_for=self.prompt_for,
|
374
|
+
case_query=self.case_query,
|
375
|
+
)
|
231
376
|
return shell
|
232
377
|
|
233
|
-
def build_func_name_and_doc(self) -> Tuple[str, str]:
|
234
|
-
"""
|
235
|
-
Build the function name and docstring for the user-defined function.
|
236
|
-
|
237
|
-
:return: A tuple containing the function name and docstring.
|
238
|
-
"""
|
239
|
-
case = self.scope['case']
|
240
|
-
case_type = case._obj_type if isinstance(case, Case) else type(case)
|
241
|
-
self.func_name = self.build_func_name(case_type)
|
242
|
-
self.func_doc = self.build_func_doc(case_type)
|
243
|
-
|
244
|
-
def build_func_doc(self, case_type: Type) -> Optional[str]:
|
245
|
-
if self.case_query is None or self.prompt_for is None:
|
246
|
-
return
|
247
|
-
|
248
|
-
if self.prompt_for == PromptFor.Conditions:
|
249
|
-
func_doc = (f"Get conditions on whether it's possible to conclude a value"
|
250
|
-
f" for {case_type.__name__}.{self.case_query.attribute_name}")
|
251
|
-
elif self.prompt_for == PromptFor.Conclusion:
|
252
|
-
func_doc = f"Get possible value(s) for {case_type.__name__}.{self.case_query.attribute_name}"
|
253
|
-
else:
|
254
|
-
return
|
255
|
-
|
256
|
-
possible_types = [t.__name__ for t in self.case_query.attribute_type if t not in [list, set]]
|
257
|
-
if list in self.case_query.attribute_type:
|
258
|
-
func_doc += f" of type list of {' and/or '.join(possible_types)}"
|
259
|
-
else:
|
260
|
-
func_doc += f" of type(s) {', '.join(possible_types)}"
|
261
|
-
|
262
|
-
return func_doc
|
263
|
-
|
264
|
-
def build_func_name(self, case_type: Type) -> Optional[str]:
|
265
|
-
func_name = None
|
266
|
-
if self.prompt_for is not None:
|
267
|
-
func_name = f"get_{self.prompt_for.value.lower()}_for"
|
268
|
-
func_name += f"_{case_type.__name__}"
|
269
|
-
|
270
|
-
if self.case_query is not None:
|
271
|
-
func_name += f"_{self.case_query.attribute_name}"
|
272
|
-
output_names = [f"{t.__name__}" for t in self.case_query.attribute_type if t not in [list, set]]
|
273
|
-
func_name += '_of_type_' + '_'.join(output_names)
|
274
|
-
|
275
|
-
return func_name.lower() if func_name is not None else None
|
276
|
-
|
277
378
|
def run(self):
|
278
379
|
"""
|
279
380
|
Run the embedded shell.
|
@@ -300,12 +401,16 @@ class IPythonShell:
|
|
300
401
|
else:
|
301
402
|
self.user_input = '\n'.join(self.all_code_lines)
|
302
403
|
self.user_input = encapsulate_user_input(self.user_input, self.shell.my_magics.function_signature,
|
303
|
-
self.func_doc)
|
304
|
-
if
|
305
|
-
|
404
|
+
self.shell.my_magics.func_doc)
|
405
|
+
if self.case_query.is_function:
|
406
|
+
args = "**case"
|
407
|
+
else:
|
408
|
+
args = "case"
|
409
|
+
if f"return {self.shell.my_magics.func_name}({args})" not in self.user_input:
|
410
|
+
self.user_input = self.user_input.strip() + f"\nreturn {self.shell.my_magics.func_name}({args})"
|
306
411
|
|
307
412
|
|
308
|
-
def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, prompt_str: Optional[str] = None)\
|
413
|
+
def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, prompt_str: Optional[str] = None) \
|
309
414
|
-> Tuple[Optional[str], Optional[CallableExpression]]:
|
310
415
|
"""
|
311
416
|
Prompt the user for an executable python expression to the given case query.
|
@@ -320,7 +425,6 @@ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, pro
|
|
320
425
|
while True:
|
321
426
|
user_input, expression_tree = prompt_user_about_case(case_query, prompt_for, prompt_str,
|
322
427
|
code_to_modify=prev_user_input)
|
323
|
-
prev_user_input = '\n'.join(user_input.split('\n')[2:-1])
|
324
428
|
if user_input is None:
|
325
429
|
if prompt_for == PromptFor.Conclusion:
|
326
430
|
print(f"{Fore.YELLOW}No conclusion provided. Exiting.{Style.RESET_ALL}")
|
@@ -328,9 +432,11 @@ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, pro
|
|
328
432
|
else:
|
329
433
|
print(f"{Fore.RED}Conditions must be provided. Please try again.{Style.RESET_ALL}")
|
330
434
|
continue
|
435
|
+
prev_user_input = '\n'.join(user_input.split('\n')[2:-1])
|
331
436
|
conclusion_type = bool if prompt_for == PromptFor.Conditions else case_query.attribute_type
|
332
437
|
callable_expression = CallableExpression(user_input, conclusion_type, expression_tree=expression_tree,
|
333
|
-
scope=case_query.scope
|
438
|
+
scope=case_query.scope,
|
439
|
+
mutually_exclusive=case_query.mutually_exclusive)
|
334
440
|
try:
|
335
441
|
result = callable_expression(case_query.case)
|
336
442
|
if len(make_list(result)) == 0:
|
@@ -376,7 +482,7 @@ def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor,
|
|
376
482
|
|
377
483
|
|
378
484
|
def prompt_user_input_and_parse_to_expression(shell: Optional[IPythonShell] = None,
|
379
|
-
user_input: Optional[str] = None)\
|
485
|
+
user_input: Optional[str] = None) \
|
380
486
|
-> Tuple[Optional[str], Optional[ast.AST]]:
|
381
487
|
"""
|
382
488
|
Prompt the user for input.
|
ripple_down_rules/rdr.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import copyreg
|
3
4
|
import importlib
|
4
5
|
import sys
|
5
6
|
from abc import ABC, abstractmethod
|
@@ -96,11 +97,12 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
96
97
|
plt.ioff()
|
97
98
|
plt.show()
|
98
99
|
|
99
|
-
def __call__(self, case: Union[Case, SQLTable]) \
        -> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
    """
    Classify a case; calling the RDR object is shorthand for :meth:`classify`.

    :param case: The case to classify.
    :return: The result of :meth:`classify`, which is annotated as possibly None.
    """
    return self.classify(case)
|
101
102
|
|
102
103
|
@abstractmethod
|
103
|
-
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False)
|
104
|
+
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) \
|
105
|
+
-> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
|
104
106
|
"""
|
105
107
|
Classify a case.
|
106
108
|
|
@@ -111,7 +113,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
111
113
|
pass
|
112
114
|
|
113
115
|
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
114
|
-
-> Union[
|
116
|
+
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
115
117
|
"""
|
116
118
|
Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
|
117
119
|
incorrect by comparing the case with the target category.
|
@@ -136,7 +138,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
136
138
|
|
137
139
|
@abstractmethod
|
138
140
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
139
|
-
-> Union[
|
141
|
+
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
140
142
|
"""
|
141
143
|
Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
|
142
144
|
comparing the case with the target category.
|
@@ -929,7 +931,7 @@ class GeneralRDR(RippleDownRules):
|
|
929
931
|
imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
|
930
932
|
# add rdr python generated functions.
|
931
933
|
for rdr_key, rdr in self.start_rules_dict.items():
|
932
|
-
imports += (f"from
|
934
|
+
imports += (f"from ."
|
933
935
|
f" import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}\n")
|
934
936
|
return imports
|
935
937
|
|