ripple-down-rules 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datasets.py +66 -6
- ripple_down_rules/datastructures/callable_expression.py +13 -3
- ripple_down_rules/datastructures/case.py +33 -5
- ripple_down_rules/datastructures/dataclasses.py +53 -9
- ripple_down_rules/datastructures/enums.py +30 -1
- ripple_down_rules/experts.py +2 -1
- ripple_down_rules/prompt.py +274 -118
- ripple_down_rules/rdr.py +7 -5
- ripple_down_rules/rdr_decorators.py +122 -38
- ripple_down_rules/utils.py +162 -18
- {ripple_down_rules-0.2.3.dist-info → ripple_down_rules-0.3.0.dist-info}/METADATA +1 -1
- ripple_down_rules-0.3.0.dist-info/RECORD +20 -0
- {ripple_down_rules-0.2.3.dist-info → ripple_down_rules-0.3.0.dist-info}/WHEEL +1 -1
- ripple_down_rules-0.2.3.dist-info/RECORD +0 -20
- {ripple_down_rules-0.2.3.dist-info → ripple_down_rules-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.2.3.dist-info → ripple_down_rules-0.3.0.dist-info}/top_level.txt +0 -0
ripple_down_rules/prompt.py
CHANGED
@@ -1,97 +1,258 @@
|
|
1
1
|
import ast
|
2
2
|
import logging
|
3
3
|
import os
|
4
|
+
import shutil
|
5
|
+
import socket
|
4
6
|
import subprocess
|
5
7
|
import tempfile
|
6
8
|
from _ast import AST
|
7
9
|
from functools import cached_property
|
8
10
|
from textwrap import indent, dedent
|
9
11
|
|
10
|
-
from IPython.core.magic import
|
12
|
+
from IPython.core.magic import line_magic, Magics, magics_class
|
11
13
|
from IPython.terminal.embed import InteractiveShellEmbed
|
14
|
+
from colorama import Fore, Style
|
15
|
+
from pygments import highlight
|
16
|
+
from pygments.formatters.terminal import TerminalFormatter
|
17
|
+
from pygments.lexers.python import PythonLexer
|
12
18
|
from traitlets.config import Config
|
13
|
-
from typing_extensions import List, Optional, Tuple, Dict, Type
|
19
|
+
from typing_extensions import List, Optional, Tuple, Dict, Type
|
14
20
|
|
15
|
-
from .datastructures.enums import PromptFor
|
16
|
-
from .datastructures.case import Case
|
17
21
|
from .datastructures.callable_expression import CallableExpression, parse_string_to_expression
|
22
|
+
from .datastructures.case import Case
|
18
23
|
from .datastructures.dataclasses import CaseQuery
|
19
|
-
from .
|
20
|
-
|
21
|
-
|
24
|
+
from .datastructures.enums import PromptFor, Editor
|
25
|
+
from .utils import extract_dependencies, contains_return_statement, get_imports_from_scope, make_list, \
|
26
|
+
get_imports_from_types, extract_function_source, encapsulate_user_input, str_to_snake_case, typing_hint_to_str
|
27
|
+
|
28
|
+
|
29
|
+
def detect_available_editor() -> Optional[Editor]:
|
30
|
+
"""
|
31
|
+
Detect the available editor on the system.
|
32
|
+
|
33
|
+
:return: The first found editor that is available on the system.
|
34
|
+
"""
|
35
|
+
editor_env = os.environ.get("RDR_EDITOR")
|
36
|
+
if editor_env:
|
37
|
+
return Editor.from_str(editor_env)
|
38
|
+
for editor in [Editor.Pycharm, Editor.Code, Editor.CodeServer]:
|
39
|
+
if shutil.which(editor.value):
|
40
|
+
return editor
|
41
|
+
return None
|
42
|
+
|
43
|
+
|
44
|
+
def is_port_in_use(port: int = 8080) -> bool:
|
45
|
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
46
|
+
return s.connect_ex(("localhost", port)) == 0
|
47
|
+
|
48
|
+
|
49
|
+
def start_code_server(workspace):
|
50
|
+
"""
|
51
|
+
Start the code-server in the given workspace.
|
52
|
+
"""
|
53
|
+
filename = os.path.join(os.path.dirname(__file__), "start-code-server.sh")
|
54
|
+
os.system(f"chmod +x {filename}")
|
55
|
+
print(f"Starting code-server at {filename}")
|
56
|
+
return subprocess.Popen(["/bin/bash", filename, workspace], stdout=subprocess.PIPE,
|
57
|
+
stderr=subprocess.PIPE, text=True)
|
22
58
|
|
23
59
|
|
24
60
|
@magics_class
|
25
61
|
class MyMagics(Magics):
|
26
|
-
|
27
|
-
|
28
|
-
|
62
|
+
temp_file_path: Optional[str] = None
|
63
|
+
"""
|
64
|
+
The path to the temporary file that is created for the user to edit.
|
65
|
+
"""
|
66
|
+
port: int = int(os.environ.get("RDR_EDITOR_PORT", 8080))
|
67
|
+
"""
|
68
|
+
The port to use for the code-server.
|
69
|
+
"""
|
70
|
+
process: Optional[subprocess.Popen] = None
|
71
|
+
"""
|
72
|
+
The process of the code-server.
|
73
|
+
"""
|
74
|
+
|
75
|
+
def __init__(self, shell, scope,
|
76
|
+
code_to_modify: Optional[str] = None,
|
77
|
+
prompt_for: Optional[PromptFor] = None,
|
78
|
+
case_query: Optional[CaseQuery] = None):
|
29
79
|
super().__init__(shell)
|
30
80
|
self.scope = scope
|
31
|
-
self.temp_file_path = None
|
32
|
-
self.func_name = func_name
|
33
|
-
self.func_doc = func_doc
|
34
81
|
self.code_to_modify = code_to_modify
|
35
|
-
self.
|
82
|
+
self.prompt_for = prompt_for
|
83
|
+
self.case_query = case_query
|
84
|
+
self.output_type = self.get_output_type()
|
36
85
|
self.user_edit_line = 0
|
37
|
-
self.
|
38
|
-
self.
|
86
|
+
self.func_name: str = self.get_func_name()
|
87
|
+
self.func_doc: str = self.get_func_doc()
|
88
|
+
self.function_signature: str = self.get_function_signature()
|
89
|
+
self.editor: Optional[Editor] = detect_available_editor()
|
90
|
+
self.workspace: str = os.environ.get("RDR_EDITOR_WORKSPACE", os.path.dirname(self.scope['__file__']))
|
91
|
+
self.temp_file_path: str = os.path.join(self.workspace, "edit_code_here.py")
|
92
|
+
|
93
|
+
def get_output_type(self) -> List[Type]:
|
94
|
+
"""
|
95
|
+
:return: The output type of the function as a list of types.
|
96
|
+
"""
|
97
|
+
if self.prompt_for == PromptFor.Conditions:
|
98
|
+
output_type = bool
|
99
|
+
else:
|
100
|
+
output_type = self.case_query.attribute_type
|
101
|
+
return make_list(output_type) if output_type is not None else None
|
39
102
|
|
40
103
|
@line_magic
|
41
|
-
def
|
104
|
+
def edit(self, line):
|
105
|
+
if self.editor is None:
|
106
|
+
print(f"{Fore.RED}ERROR:: No editor found. Please install PyCharm, VSCode or code-server.{Style.RESET_ALL}")
|
107
|
+
return
|
42
108
|
|
43
109
|
boilerplate_code = self.build_boilerplate_code()
|
44
|
-
|
45
110
|
self.write_to_file(boilerplate_code)
|
46
111
|
|
47
|
-
|
48
|
-
|
112
|
+
self.open_file_in_editor()
|
113
|
+
|
114
|
+
def open_file_in_editor(self):
|
115
|
+
"""
|
116
|
+
Open the file in the available editor.
|
117
|
+
"""
|
118
|
+
if self.editor == Editor.Pycharm:
|
119
|
+
subprocess.Popen(["pycharm", "--line", str(self.user_edit_line), self.temp_file_path],
|
120
|
+
stdout=subprocess.DEVNULL,
|
121
|
+
stderr=subprocess.DEVNULL)
|
122
|
+
elif self.editor == Editor.Code:
|
123
|
+
subprocess.Popen(["code", self.temp_file_path])
|
124
|
+
elif self.editor == Editor.CodeServer:
|
125
|
+
try:
|
126
|
+
subprocess.check_output(["pgrep", "-f", "code-server"])
|
127
|
+
# check if same port is in use
|
128
|
+
if is_port_in_use(self.port):
|
129
|
+
print(f"Code-server is already running on port {self.port}.")
|
130
|
+
else:
|
131
|
+
raise ValueError("Port is not in use.")
|
132
|
+
except (subprocess.CalledProcessError, ValueError) as e:
|
133
|
+
self.process = start_code_server(self.workspace)
|
134
|
+
print(f"Open code-server in your browser at http://localhost:{self.port}?folder={self.workspace}")
|
135
|
+
print(f"Edit the file: {Fore.BLUE}{self.temp_file_path}")
|
49
136
|
|
50
137
|
def build_boilerplate_code(self):
|
51
138
|
imports = self.get_imports()
|
52
|
-
self.
|
139
|
+
if self.function_signature is None:
|
140
|
+
self.function_signature = self.get_function_signature()
|
141
|
+
if self.func_doc is None:
|
142
|
+
self.func_doc = self.get_func_doc()
|
53
143
|
if self.code_to_modify is not None:
|
54
144
|
body = indent(dedent(self.code_to_modify), ' ')
|
55
145
|
else:
|
56
146
|
body = " # Write your code here\n pass"
|
57
147
|
boilerplate = f"""{imports}\n\n{self.function_signature}\n \"\"\"{self.func_doc}\"\"\"\n{body}"""
|
58
|
-
self.user_edit_line = imports.count('\n')+6
|
148
|
+
self.user_edit_line = imports.count('\n') + 6
|
59
149
|
return boilerplate
|
60
150
|
|
61
|
-
def
|
62
|
-
if self.
|
63
|
-
|
64
|
-
|
65
|
-
|
151
|
+
def get_function_signature(self) -> str:
|
152
|
+
if self.func_name is None:
|
153
|
+
self.func_name = self.get_func_name()
|
154
|
+
output_type_hint = self.get_output_type_hint()
|
155
|
+
func_args = self.get_func_args()
|
156
|
+
return f"def {self.func_name}({func_args}){output_type_hint}:"
|
157
|
+
|
158
|
+
def get_output_type_hint(self) -> str:
|
159
|
+
"""
|
160
|
+
:return: A string containing the output type hint for the function.
|
161
|
+
"""
|
162
|
+
output_type_hint = ""
|
163
|
+
if self.prompt_for == PromptFor.Conditions:
|
164
|
+
output_type_hint = " -> bool"
|
165
|
+
elif self.prompt_for == PromptFor.Conclusion:
|
166
|
+
output_type_hint = f" -> {self.case_query.attribute_type_hint}"
|
167
|
+
return output_type_hint
|
168
|
+
|
169
|
+
def get_func_args(self) -> str:
|
170
|
+
"""
|
171
|
+
:return: A string containing the function arguments.
|
172
|
+
"""
|
173
|
+
if self.case_query.is_function:
|
174
|
+
func_args = {}
|
175
|
+
for k, v in self.case_query.case.items():
|
176
|
+
if (self.case_query.function_args_type_hints is not None
|
177
|
+
and k in self.case_query.function_args_type_hints):
|
178
|
+
func_args[k] = typing_hint_to_str(self.case_query.function_args_type_hints[k])[0]
|
179
|
+
else:
|
180
|
+
func_args[k] = type(v).__name__ if not isinstance(v, type) else f"Type[{v.__name__}]"
|
181
|
+
func_args = ', '.join([f"{k}: {v}" if str(v) not in ["NoneType", "None"] else str(k)
|
182
|
+
for k, v in func_args.items()])
|
66
183
|
else:
|
67
|
-
|
68
|
-
|
184
|
+
func_args = f"case: {self.case_type.__name__}"
|
185
|
+
return func_args
|
69
186
|
|
70
187
|
def write_to_file(self, code: str):
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
188
|
+
if self.temp_file_path is None:
|
189
|
+
tmp = tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix=".py",
|
190
|
+
dir=self.workspace)
|
191
|
+
tmp.write(code)
|
192
|
+
tmp.flush()
|
193
|
+
self.temp_file_path = tmp.name
|
194
|
+
tmp.close()
|
195
|
+
else:
|
196
|
+
with open(self.temp_file_path, 'w+') as f:
|
197
|
+
f.write(code)
|
77
198
|
|
78
199
|
def get_imports(self):
|
79
|
-
|
200
|
+
"""
|
201
|
+
:return: A string containing the imports for the function.
|
202
|
+
"""
|
203
|
+
case_type_imports = []
|
204
|
+
if self.case_query.is_function:
|
205
|
+
for k, v in self.case_query.case.items():
|
206
|
+
if (self.case_query.function_args_type_hints is not None
|
207
|
+
and k in self.case_query.function_args_type_hints):
|
208
|
+
hint_list = typing_hint_to_str(self.case_query.function_args_type_hints[k])[1]
|
209
|
+
for hint in hint_list:
|
210
|
+
hint_split = hint.split('.')
|
211
|
+
if len(hint_split) > 1:
|
212
|
+
case_type_imports.append(f"from {'.'.join(hint_split[:-1])} import {hint_split[-1]}")
|
213
|
+
else:
|
214
|
+
if isinstance(v, type):
|
215
|
+
case_type_imports.append(f"from {v.__module__} import {v.__name__}")
|
216
|
+
elif hasattr(v, "__module__") and not v.__module__.startswith("__"):
|
217
|
+
case_type_imports.append(f"\nfrom {type(v).__module__} import {type(v).__name__}")
|
218
|
+
else:
|
219
|
+
case_type_imports.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
|
80
220
|
if self.output_type is None:
|
81
221
|
output_type_imports = [f"from typing_extensions import Any"]
|
82
222
|
else:
|
83
223
|
output_type_imports = get_imports_from_types(self.output_type)
|
84
224
|
if len(self.output_type) > 1:
|
85
225
|
output_type_imports.append("from typing_extensions import Union")
|
86
|
-
|
226
|
+
if list in self.output_type:
|
227
|
+
output_type_imports.append("from typing_extensions import List")
|
87
228
|
imports = get_imports_from_scope(self.scope)
|
88
229
|
imports = [i for i in imports if ("get_ipython" not in i)]
|
89
|
-
|
90
|
-
imports.append(case_type_import)
|
230
|
+
imports.extend(case_type_imports)
|
91
231
|
imports.extend([oti for oti in output_type_imports if oti not in imports])
|
92
232
|
imports = set(imports)
|
93
233
|
return '\n'.join(imports)
|
94
234
|
|
235
|
+
def get_func_doc(self) -> Optional[str]:
|
236
|
+
"""
|
237
|
+
:return: A string containing the function docstring.
|
238
|
+
"""
|
239
|
+
if self.prompt_for == PromptFor.Conditions:
|
240
|
+
return (f"Get conditions on whether it's possible to conclude a value"
|
241
|
+
f" for {self.case_query.name}")
|
242
|
+
else:
|
243
|
+
return f"Get possible value(s) for {self.case_query.name}"
|
244
|
+
|
245
|
+
def get_func_name(self) -> Optional[str]:
|
246
|
+
func_name = ""
|
247
|
+
if self.prompt_for == PromptFor.Conditions:
|
248
|
+
func_name = f"{self.prompt_for.value.lower()}_for_"
|
249
|
+
case_name = self.case_query.name.replace(".", "_")
|
250
|
+
if self.case_query.is_function:
|
251
|
+
# convert any CamelCase word into snake_case by adding _ before each capital letter
|
252
|
+
case_name = case_name.replace(f"_{self.case_query.attribute_name}", "")
|
253
|
+
func_name += case_name
|
254
|
+
return str_to_snake_case(func_name)
|
255
|
+
|
95
256
|
@cached_property
|
96
257
|
def case_type(self) -> Type:
|
97
258
|
"""
|
@@ -103,9 +264,9 @@ class MyMagics(Magics):
|
|
103
264
|
return case._obj_type if isinstance(case, Case) else type(case)
|
104
265
|
|
105
266
|
@line_magic
|
106
|
-
def
|
267
|
+
def load(self, line):
|
107
268
|
if not self.temp_file_path:
|
108
|
-
print("No file to load. Run %
|
269
|
+
print(f"{Fore.RED}ERROR:: No file to load. Run %edit first.{Style.RESET_ALL}")
|
109
270
|
return
|
110
271
|
|
111
272
|
with open(self.temp_file_path, 'r') as f:
|
@@ -118,20 +279,42 @@ class MyMagics(Magics):
|
|
118
279
|
exec(source, self.scope, exec_globals)
|
119
280
|
user_function = exec_globals[self.func_name]
|
120
281
|
self.shell.user_ns[self.func_name] = user_function
|
121
|
-
print(f"Loaded `{self.func_name}` function into user namespace.")
|
282
|
+
print(f"{Fore.BLUE}Loaded `{self.func_name}` function into user namespace.{Style.RESET_ALL}")
|
122
283
|
return
|
123
284
|
|
124
|
-
print(f"Function `{self.func_name}` not found.")
|
285
|
+
print(f"{Fore.RED}ERROR:: Function `{self.func_name}` not found.{Style.RESET_ALL}")
|
286
|
+
|
287
|
+
@line_magic
|
288
|
+
def help(self, line):
|
289
|
+
"""
|
290
|
+
Display help information for the Ipython shell.
|
291
|
+
"""
|
292
|
+
help_text = f"""
|
293
|
+
Directly write python code in the shell, and then `{Fore.GREEN}return {Fore.RESET}output`. Or use
|
294
|
+
the magic commands to write the code in a temporary file and edit it in PyCharm:
|
295
|
+
{Fore.MAGENTA}Usage: %edit{Style.RESET_ALL}
|
296
|
+
Opens a temporary file in PyCharm for editing a function (conclusion or conditions for case)
|
297
|
+
that will be executed on the case object.
|
298
|
+
{Fore.MAGENTA}Usage: %load{Style.RESET_ALL}
|
299
|
+
Loads the function defined in the temporary file into the user namespace, that can then be used inside the
|
300
|
+
Ipython shell. You can then do `{Fore.GREEN}return {Fore.RESET}function_name(case)`.
|
301
|
+
"""
|
302
|
+
print(help_text)
|
303
|
+
|
304
|
+
def __del__(self):
|
305
|
+
if hasattr(self, 'process') and self.process is not None and self.process.poll() is None:
|
306
|
+
self.process.terminate() # Graceful shutdown
|
307
|
+
self.process.wait() # Ensure cleanup
|
125
308
|
|
126
309
|
|
127
310
|
class CustomInteractiveShell(InteractiveShellEmbed):
|
128
|
-
def __init__(self,
|
129
|
-
|
311
|
+
def __init__(self, code_to_modify: Optional[str] = None,
|
312
|
+
prompt_for: Optional[PromptFor] = None,
|
313
|
+
case_query: Optional[CaseQuery] = None,
|
314
|
+
**kwargs):
|
130
315
|
super().__init__(**kwargs)
|
131
|
-
|
132
|
-
|
133
|
-
magics_kwargs = {key: value for key, value in zip(keys, values) if value is not None}
|
134
|
-
self.my_magics = MyMagics(self, self.user_ns, **magics_kwargs)
|
316
|
+
self.my_magics = MyMagics(self, self.user_ns, code_to_modify=code_to_modify,
|
317
|
+
prompt_for=prompt_for, case_query=case_query)
|
135
318
|
self.register_magics(self.my_magics)
|
136
319
|
self.all_lines = []
|
137
320
|
|
@@ -145,7 +328,6 @@ class CustomInteractiveShell(InteractiveShellEmbed):
|
|
145
328
|
self.my_magics.func_name,
|
146
329
|
join_lines=False)[self.my_magics.func_name]
|
147
330
|
self.all_lines.append(raw_cell)
|
148
|
-
print("Exiting shell on `return` statement.")
|
149
331
|
self.history_manager.store_inputs(line_num=self.execution_count, source=raw_cell)
|
150
332
|
self.ask_exit()
|
151
333
|
return None
|
@@ -161,30 +343,23 @@ class IPythonShell:
|
|
161
343
|
"""
|
162
344
|
|
163
345
|
def __init__(self, scope: Optional[Dict] = None, header: Optional[str] = None,
|
164
|
-
|
165
|
-
attribute_name: Optional[str] = None, attribute_type: Optional[Type] = None,
|
346
|
+
prompt_for: Optional[PromptFor] = None, case_query: Optional[CaseQuery] = None,
|
166
347
|
code_to_modify: Optional[str] = None):
|
167
348
|
"""
|
168
349
|
Initialize the Ipython shell with the given scope and header.
|
169
350
|
|
170
351
|
:param scope: The scope to use for the shell.
|
171
352
|
:param header: The header to display when the shell is started.
|
172
|
-
:param output_type: The type of the output from user input.
|
173
353
|
:param prompt_for: The type of information to ask the user about.
|
174
|
-
:param
|
175
|
-
:param attribute_type: The type of the attribute of the case.
|
354
|
+
:param case_query: The case query which contains the case and the attribute to ask about.
|
176
355
|
:param code_to_modify: The code to modify. If given, will be used as a start for user to modify.
|
177
356
|
"""
|
178
357
|
self.scope: Dict = scope or {}
|
179
358
|
self.header: str = header or ">>> Embedded Ipython Shell"
|
180
|
-
self.
|
359
|
+
self.case_query: Optional[CaseQuery] = case_query
|
181
360
|
self.prompt_for: Optional[PromptFor] = prompt_for
|
182
|
-
self.attribute_name: Optional[str] = attribute_name
|
183
|
-
self.attribute_type: Optional[Type] = attribute_type
|
184
361
|
self.code_to_modify: Optional[str] = code_to_modify
|
185
362
|
self.user_input: Optional[str] = None
|
186
|
-
self.func_name: str = ""
|
187
|
-
self.func_doc: str = ""
|
188
363
|
self.shell: CustomInteractiveShell = self._init_shell()
|
189
364
|
self.all_code_lines: List[str] = []
|
190
365
|
|
@@ -193,49 +368,13 @@ class IPythonShell:
|
|
193
368
|
Initialize the Ipython shell with a custom configuration.
|
194
369
|
"""
|
195
370
|
cfg = Config()
|
196
|
-
self.build_func_name_and_doc()
|
197
371
|
shell = CustomInteractiveShell(config=cfg, user_ns=self.scope, banner1=self.header,
|
198
|
-
|
199
|
-
|
372
|
+
code_to_modify=self.code_to_modify,
|
373
|
+
prompt_for=self.prompt_for,
|
374
|
+
case_query=self.case_query,
|
375
|
+
)
|
200
376
|
return shell
|
201
377
|
|
202
|
-
def build_func_name_and_doc(self) -> Tuple[str, str]:
|
203
|
-
"""
|
204
|
-
Build the function name and docstring for the user-defined function.
|
205
|
-
|
206
|
-
:return: A tuple containing the function name and docstring.
|
207
|
-
"""
|
208
|
-
case = self.scope['case']
|
209
|
-
case_type = case._obj_type if isinstance(case, Case) else type(case)
|
210
|
-
self.func_name = self.build_func_name(case_type)
|
211
|
-
self.func_doc = self.build_func_doc(case_type)
|
212
|
-
|
213
|
-
def build_func_doc(self, case_type: Type):
|
214
|
-
if self.prompt_for == PromptFor.Conditions:
|
215
|
-
func_doc = (f"Get conditions on whether it's possible to conclude a value"
|
216
|
-
f" for {case_type.__name__}.{self.attribute_name}")
|
217
|
-
else:
|
218
|
-
func_doc = f"Get possible value(s) for {case_type.__name__}.{self.attribute_name}"
|
219
|
-
if is_iterable(self.attribute_type):
|
220
|
-
possible_types = [t.__name__ for t in self.attribute_type if t not in [list, set]]
|
221
|
-
func_doc += f" of types list/set of {' and/or '.join(possible_types)}"
|
222
|
-
else:
|
223
|
-
func_doc += f" of type {self.attribute_type.__name__}"
|
224
|
-
return func_doc
|
225
|
-
|
226
|
-
def build_func_name(self, case_type: Type):
|
227
|
-
func_name = f"get_{self.prompt_for.value.lower()}_for"
|
228
|
-
func_name += f"_{case_type.__name__}"
|
229
|
-
if self.attribute_name is not None:
|
230
|
-
func_name += f"_{self.attribute_name}"
|
231
|
-
if is_iterable(self.attribute_type):
|
232
|
-
output_names = [f"{t.__name__}" for t in self.attribute_type if t not in [list, set]]
|
233
|
-
else:
|
234
|
-
output_names = [self.attribute_type.__name__] if self.attribute_type is not None else None
|
235
|
-
if output_names is not None:
|
236
|
-
func_name += '_of_type_' + '_'.join(output_names)
|
237
|
-
return func_name.lower()
|
238
|
-
|
239
378
|
def run(self):
|
240
379
|
"""
|
241
380
|
Run the embedded shell.
|
@@ -247,7 +386,7 @@ class IPythonShell:
|
|
247
386
|
break
|
248
387
|
except Exception as e:
|
249
388
|
logging.error(e)
|
250
|
-
print(e)
|
389
|
+
print(f"{Fore.RED}ERROR::{e}{Style.RESET_ALL}")
|
251
390
|
|
252
391
|
def update_user_input_from_code_lines(self):
|
253
392
|
"""
|
@@ -262,12 +401,16 @@ class IPythonShell:
|
|
262
401
|
else:
|
263
402
|
self.user_input = '\n'.join(self.all_code_lines)
|
264
403
|
self.user_input = encapsulate_user_input(self.user_input, self.shell.my_magics.function_signature,
|
265
|
-
self.func_doc)
|
266
|
-
if
|
267
|
-
|
404
|
+
self.shell.my_magics.func_doc)
|
405
|
+
if self.case_query.is_function:
|
406
|
+
args = "**case"
|
407
|
+
else:
|
408
|
+
args = "case"
|
409
|
+
if f"return {self.shell.my_magics.func_name}({args})" not in self.user_input:
|
410
|
+
self.user_input = self.user_input.strip() + f"\nreturn {self.shell.my_magics.func_name}({args})"
|
268
411
|
|
269
412
|
|
270
|
-
def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, prompt_str: Optional[str] = None)\
|
413
|
+
def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, prompt_str: Optional[str] = None) \
|
271
414
|
-> Tuple[Optional[str], Optional[CallableExpression]]:
|
272
415
|
"""
|
273
416
|
Prompt the user for an executable python expression to the given case query.
|
@@ -282,26 +425,28 @@ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, pro
|
|
282
425
|
while True:
|
283
426
|
user_input, expression_tree = prompt_user_about_case(case_query, prompt_for, prompt_str,
|
284
427
|
code_to_modify=prev_user_input)
|
285
|
-
prev_user_input = '\n'.join(user_input.split('\n')[2:-1])
|
286
428
|
if user_input is None:
|
287
429
|
if prompt_for == PromptFor.Conclusion:
|
288
|
-
print("No conclusion provided. Exiting.")
|
430
|
+
print(f"{Fore.YELLOW}No conclusion provided. Exiting.{Style.RESET_ALL}")
|
289
431
|
return None, None
|
290
432
|
else:
|
291
|
-
print("Conditions must be provided. Please try again.")
|
433
|
+
print(f"{Fore.RED}Conditions must be provided. Please try again.{Style.RESET_ALL}")
|
292
434
|
continue
|
435
|
+
prev_user_input = '\n'.join(user_input.split('\n')[2:-1])
|
293
436
|
conclusion_type = bool if prompt_for == PromptFor.Conditions else case_query.attribute_type
|
294
437
|
callable_expression = CallableExpression(user_input, conclusion_type, expression_tree=expression_tree,
|
295
|
-
scope=case_query.scope
|
438
|
+
scope=case_query.scope,
|
439
|
+
mutually_exclusive=case_query.mutually_exclusive)
|
296
440
|
try:
|
297
441
|
result = callable_expression(case_query.case)
|
298
442
|
if len(make_list(result)) == 0:
|
299
|
-
print(f"The given expression gave an empty result for case {case_query.name}.
|
443
|
+
print(f"{Fore.YELLOW}The given expression gave an empty result for case {case_query.name}."
|
444
|
+
f" Please modify!{Style.RESET_ALL}")
|
300
445
|
continue
|
301
446
|
break
|
302
447
|
except Exception as e:
|
303
448
|
logging.error(e)
|
304
|
-
print(e)
|
449
|
+
print(f"{Fore.RED}{e}{Style.RESET_ALL}")
|
305
450
|
return user_input, callable_expression
|
306
451
|
|
307
452
|
|
@@ -318,17 +463,26 @@ def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor,
|
|
318
463
|
:return: The user input, and the executable expression that was parsed from the user input.
|
319
464
|
"""
|
320
465
|
if prompt_str is None:
|
321
|
-
|
466
|
+
if prompt_for == PromptFor.Conclusion:
|
467
|
+
prompt_str = f"Give possible value(s) for:\n"
|
468
|
+
else:
|
469
|
+
prompt_str = f"Give conditions on when can the rule be evaluated for:\n"
|
470
|
+
prompt_str += (f"{Fore.CYAN}{case_query.name}{Fore.MAGENTA} of type(s) "
|
471
|
+
f"{Fore.CYAN}({', '.join(map(lambda x: x.__name__, case_query.core_attribute_type))}){Fore.MAGENTA}")
|
472
|
+
if prompt_for == PromptFor.Conditions:
|
473
|
+
prompt_str += (f"\ne.g. `{Fore.GREEN}return {Fore.BLUE}len{Fore.RESET}(case.attribute) > {Fore.BLUE}0` "
|
474
|
+
f"{Fore.MAGENTA}\nOR `{Fore.GREEN}return {Fore.YELLOW}True`{Fore.MAGENTA} (If you want the"
|
475
|
+
f" rule to be always evaluated) \n"
|
476
|
+
f"You can also do {Fore.YELLOW}%edit{Fore.MAGENTA} for more complex conditions.")
|
477
|
+
prompt_str = f"{Fore.MAGENTA}{prompt_str}{Fore.YELLOW}\n(Write %help for guide){Fore.RESET}"
|
322
478
|
scope = {'case': case_query.case, **case_query.scope}
|
323
|
-
|
324
|
-
shell = IPythonShell(scope=scope, header=prompt_str, output_type=output_type, prompt_for=prompt_for,
|
325
|
-
attribute_name=case_query.attribute_name, attribute_type=case_query.attribute_type,
|
479
|
+
shell = IPythonShell(scope=scope, header=prompt_str, prompt_for=prompt_for, case_query=case_query,
|
326
480
|
code_to_modify=code_to_modify)
|
327
481
|
return prompt_user_input_and_parse_to_expression(shell=shell)
|
328
482
|
|
329
483
|
|
330
484
|
def prompt_user_input_and_parse_to_expression(shell: Optional[IPythonShell] = None,
|
331
|
-
user_input: Optional[str] = None)\
|
485
|
+
user_input: Optional[str] = None) \
|
332
486
|
-> Tuple[Optional[str], Optional[ast.AST]]:
|
333
487
|
"""
|
334
488
|
Prompt the user for input.
|
@@ -344,11 +498,13 @@ def prompt_user_input_and_parse_to_expression(shell: Optional[IPythonShell] = No
|
|
344
498
|
user_input = shell.user_input
|
345
499
|
if user_input is None:
|
346
500
|
return None, None
|
347
|
-
print(
|
501
|
+
print(f"{Fore.BLUE}Captured User input: {Style.RESET_ALL}")
|
502
|
+
highlighted_code = highlight(user_input, PythonLexer(), TerminalFormatter())
|
503
|
+
print(highlighted_code)
|
348
504
|
try:
|
349
505
|
return user_input, parse_string_to_expression(user_input)
|
350
506
|
except Exception as e:
|
351
507
|
msg = f"Error parsing expression: {e}"
|
352
508
|
logging.error(msg)
|
353
|
-
print(msg)
|
509
|
+
print(f"{Fore.RED}{msg}{Style.RESET_ALL}")
|
354
510
|
user_input = None
|
ripple_down_rules/rdr.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import copyreg
|
3
4
|
import importlib
|
4
5
|
import sys
|
5
6
|
from abc import ABC, abstractmethod
|
@@ -96,11 +97,12 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
96
97
|
plt.ioff()
|
97
98
|
plt.show()
|
98
99
|
|
99
|
-
def __call__(self, case: Union[Case, SQLTable]) ->
|
100
|
+
def __call__(self, case: Union[Case, SQLTable]) -> Union[CallableExpression, Dict[str, CallableExpression]]:
|
100
101
|
return self.classify(case)
|
101
102
|
|
102
103
|
@abstractmethod
|
103
|
-
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False)
|
104
|
+
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) \
|
105
|
+
-> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
|
104
106
|
"""
|
105
107
|
Classify a case.
|
106
108
|
|
@@ -111,7 +113,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
111
113
|
pass
|
112
114
|
|
113
115
|
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
114
|
-
-> Union[
|
116
|
+
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
115
117
|
"""
|
116
118
|
Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
|
117
119
|
incorrect by comparing the case with the target category.
|
@@ -136,7 +138,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
136
138
|
|
137
139
|
@abstractmethod
|
138
140
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
139
|
-
-> Union[
|
141
|
+
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
140
142
|
"""
|
141
143
|
Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
|
142
144
|
comparing the case with the target category.
|
@@ -929,7 +931,7 @@ class GeneralRDR(RippleDownRules):
|
|
929
931
|
imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
|
930
932
|
# add rdr python generated functions.
|
931
933
|
for rdr_key, rdr in self.start_rules_dict.items():
|
932
|
-
imports += (f"from
|
934
|
+
imports += (f"from ."
|
933
935
|
f" import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}\n")
|
934
936
|
return imports
|
935
937
|
|