ripple-down-rules 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,97 +1,258 @@
1
1
  import ast
2
2
  import logging
3
3
  import os
4
+ import shutil
5
+ import socket
4
6
  import subprocess
5
7
  import tempfile
6
8
  from _ast import AST
7
9
  from functools import cached_property
8
10
  from textwrap import indent, dedent
9
11
 
10
- from IPython.core.magic import register_line_magic, line_magic, Magics, magics_class
12
+ from IPython.core.magic import line_magic, Magics, magics_class
11
13
  from IPython.terminal.embed import InteractiveShellEmbed
14
+ from colorama import Fore, Style
15
+ from pygments import highlight
16
+ from pygments.formatters.terminal import TerminalFormatter
17
+ from pygments.lexers.python import PythonLexer
12
18
  from traitlets.config import Config
13
- from typing_extensions import List, Optional, Tuple, Dict, Type, Union, Any
19
+ from typing_extensions import List, Optional, Tuple, Dict, Type
14
20
 
15
- from .datastructures.enums import PromptFor
16
- from .datastructures.case import Case
17
21
  from .datastructures.callable_expression import CallableExpression, parse_string_to_expression
22
+ from .datastructures.case import Case
18
23
  from .datastructures.dataclasses import CaseQuery
19
- from .utils import extract_dependencies, contains_return_statement, make_set, get_imports_from_scope, make_list, \
20
- get_import_from_type, get_imports_from_types, is_iterable, extract_function_source, encapsulate_user_input, \
21
- are_results_subclass_of_types
24
+ from .datastructures.enums import PromptFor, Editor
25
+ from .utils import extract_dependencies, contains_return_statement, get_imports_from_scope, make_list, \
26
+ get_imports_from_types, extract_function_source, encapsulate_user_input, str_to_snake_case, typing_hint_to_str
27
+
28
+
29
def detect_available_editor() -> Optional[Editor]:
    """
    Pick the editor the user will edit generated code in.

    A non-empty ``RDR_EDITOR`` environment variable wins and is converted with ``Editor.from_str``.
    Otherwise PyCharm, VSCode and code-server are probed in that order by looking up each ``Editor``
    value as an executable with ``shutil.which``.

    :return: The chosen editor, or None when none of the candidate executables is found.
    """
    forced_choice = os.environ.get("RDR_EDITOR")
    if forced_choice:
        return Editor.from_str(forced_choice)
    probe_order = (Editor.Pycharm, Editor.Code, Editor.CodeServer)
    return next((candidate for candidate in probe_order if shutil.which(candidate.value)), None)
42
+
43
+
44
def is_port_in_use(port: int = 8080, host: str = "localhost", timeout: Optional[float] = 1.0) -> bool:
    """
    Check whether something accepts TCP (IPv4) connections on ``host:port``.

    :param port: The TCP port to probe.
    :param host: The host to probe; defaults to ``localhost``.
    :param timeout: Seconds to wait for the connection attempt. None blocks for the OS default, which was
     the previous behaviour.
    :return: True if the connection succeeded, False otherwise.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        # Without a timeout, a host that silently drops packets would stall the caller.
        probe.settimeout(timeout)
        return probe.connect_ex((host, port)) == 0
47
+
48
+
49
def start_code_server(workspace: str) -> subprocess.Popen:
    """
    Start code-server for the given workspace by running the bundled ``start-code-server.sh`` script.

    The script is executed through ``/bin/bash`` with the workspace as its only argument. Its stdout and
    stderr are appended to ``rdr-code-server.log`` in the system temp directory. They are not captured in
    pipes, because nobody reads them and an undrained pipe blocks the server once the OS buffer fills up.

    :param workspace: Directory to open in code-server.
    :return: The handle of the started process.
    """
    import stat

    script_path = os.path.join(os.path.dirname(__file__), "start-code-server.sh")
    # Best effort, like the previous shell ``chmod +x``. The script is run via /bin/bash anyway, so a
    # failure (e.g. a read-only install) must not prevent the server from starting. Calling os.chmod
    # directly avoids passing an unquoted path through a shell.
    try:
        mode = os.stat(script_path).st_mode
        os.chmod(script_path, mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
    except OSError as err:
        logging.debug("Could not make %s executable: %s", script_path, err)
    log_path = os.path.join(tempfile.gettempdir(), "rdr-code-server.log")
    print(f"Starting code-server for workspace {workspace} (script: {script_path}, log: {log_path})")
    with open(log_path, "a") as log_file:
        # The child inherits its own copy of the descriptor, so closing ours afterwards is safe.
        return subprocess.Popen(["/bin/bash", script_path, workspace],
                                stdout=log_file, stderr=subprocess.STDOUT)
22
58
 
23
59
 
24
60
  @magics_class
25
61
  class MyMagics(Magics):
26
- def __init__(self, shell, scope, output_type: Optional[Type] = None, func_name: str = "user_case",
27
- func_doc: str = "User defined function to be executed on the case.",
28
- code_to_modify: Optional[str] = None):
62
# Path of the scratch file the user edits; the class-level default is None.
temp_file_path: Optional[str] = None
# Port used to look for / advertise code-server. Read once from RDR_EDITOR_PORT when the class body runs.
port: int = int(os.environ.get("RDR_EDITOR_PORT", 8080))
# Handle of a code-server process started by this object, if any.
process: Optional[subprocess.Popen] = None
74
+
75
def __init__(self, shell, scope,
             code_to_modify: Optional[str] = None,
             prompt_for: Optional[PromptFor] = None,
             case_query: Optional[CaseQuery] = None):
    """
    :param shell: The IPython shell these magics belong to.
    :param scope: The shell namespace. Its ``__file__`` entry, when present, gives the default workspace.
    :param code_to_modify: Existing code used to pre-fill the generated function body, if any.
    :param prompt_for: Whether conditions or a conclusion is being asked for.
    :param case_query: The case query the generated function is about. ``get_func_name`` dereferences it
     unconditionally, so it is effectively required (NOTE(review): confirm no caller passes None).
    """
    super().__init__(shell)
    self.scope = scope
    self.code_to_modify = code_to_modify
    self.prompt_for = prompt_for
    self.case_query = case_query
    self.output_type = self.get_output_type()
    self.user_edit_line = 0
    self.func_name: str = self.get_func_name()
    self.func_doc: str = self.get_func_doc()
    self.function_signature: str = self.get_function_signature()
    self.editor: Optional[Editor] = detect_available_editor()
    workspace = os.environ.get("RDR_EDITOR_WORKSPACE")
    if workspace is None:
        # Resolved lazily. Passing ``self.scope['__file__']`` as the ``.get`` default evaluated it even when
        # RDR_EDITOR_WORKSPACE was set, raising KeyError for scopes without ``__file__``.
        scope_file = self.scope.get('__file__')
        workspace = os.path.dirname(scope_file) if scope_file else os.getcwd()
    self.workspace: str = workspace
    self.temp_file_path: str = os.path.join(self.workspace, "edit_code_here.py")
92
+
93
def get_output_type(self) -> Optional[List[Type]]:
    """
    Determine the type(s) the user-defined function is expected to return.

    :return: ``[bool]`` when prompting for conditions. Otherwise the case query's ``attribute_type`` passed
     through ``make_list``, or None when that attribute type is None.
    """
    output_type = bool if self.prompt_for == PromptFor.Conditions else self.case_query.attribute_type
    return None if output_type is None else make_list(output_type)
39
102
 
40
103
  @line_magic
41
- def edit_case(self, line):
104
+ def edit(self, line):
105
+ if self.editor is None:
106
+ print(f"{Fore.RED}ERROR:: No editor found. Please install PyCharm, VSCode or code-server.{Style.RESET_ALL}")
107
+ return
42
108
 
43
109
  boilerplate_code = self.build_boilerplate_code()
44
-
45
110
  self.write_to_file(boilerplate_code)
46
111
 
47
- print(f"Opening {self.temp_file_path} in PyCharm...")
48
- subprocess.Popen(["pycharm", "--line", str(self.user_edit_line), self.temp_file_path])
112
+ self.open_file_in_editor()
113
+
114
def open_file_in_editor(self):
    """
    Open the scratch file (``self.temp_file_path``) in the editor detected at construction time.

    PyCharm and VSCode are launched directly and jump to ``self.user_edit_line``. For code-server, a server
    is started with ``start_code_server`` unless one is found running and ``self.port`` accepts connections.
    In both cases the browser URL and the file to edit are printed.
    """
    from urllib.parse import quote

    if self.editor == Editor.Pycharm:
        subprocess.Popen(["pycharm", "--line", str(self.user_edit_line), self.temp_file_path],
                         stdout=subprocess.DEVNULL,
                         stderr=subprocess.DEVNULL)
    elif self.editor == Editor.Code:
        # ``--goto file:line`` mirrors PyCharm's ``--line`` so the cursor lands on the body to edit.
        subprocess.Popen(["code", "--goto", f"{self.temp_file_path}:{self.user_edit_line}"])
    elif self.editor == Editor.CodeServer:
        if self._code_server_running():
            print(f"Code-server is already running on port {self.port}.")
        else:
            self.process = start_code_server(self.workspace)
        print(f"Open code-server in your browser at http://localhost:{self.port}?folder={quote(self.workspace)}")
        # Reset the colour so later terminal output is not left blue.
        print(f"Edit the file: {Fore.BLUE}{self.temp_file_path}{Style.RESET_ALL}")


def _code_server_running(self) -> bool:
    """
    :return: True if a process matching ``code-server`` exists (per ``pgrep -f``) and ``self.port`` accepts
     connections.
    """
    try:
        subprocess.check_output(["pgrep", "-f", "code-server"])
    except (subprocess.CalledProcessError, FileNotFoundError):
        # pgrep exits non-zero when nothing matches; a missing pgrep binary is treated the same way.
        return False
    return is_port_in_use(self.port)
49
136
 
50
137
  def build_boilerplate_code(self):
51
138
  imports = self.get_imports()
52
- self.build_function_signature()
139
+ if self.function_signature is None:
140
+ self.function_signature = self.get_function_signature()
141
+ if self.func_doc is None:
142
+ self.func_doc = self.get_func_doc()
53
143
  if self.code_to_modify is not None:
54
144
  body = indent(dedent(self.code_to_modify), ' ')
55
145
  else:
56
146
  body = " # Write your code here\n pass"
57
147
  boilerplate = f"""{imports}\n\n{self.function_signature}\n \"\"\"{self.func_doc}\"\"\"\n{body}"""
58
- self.user_edit_line = imports.count('\n')+6
148
+ self.user_edit_line = imports.count('\n') + 6
59
149
  return boilerplate
60
150
 
61
- def build_function_signature(self):
62
- if self.output_type is None:
63
- output_type_hint = ""
64
- elif len(self.output_type) == 1:
65
- output_type_hint = f" -> {self.output_type[0].__name__}"
151
def get_function_signature(self) -> str:
    """
    Build the ``def`` line of the user-defined function, e.g. ``def name(case: CaseType) -> bool:``.

    Fills ``self.func_name`` first if it is still None.

    :return: The signature line, ending with a colon.
    """
    if self.func_name is None:
        self.func_name = self.get_func_name()
    return_hint = self.get_output_type_hint()
    parameters = self.get_func_args()
    return "def {}({}){}:".format(self.func_name, parameters, return_hint)
157
+
158
def get_output_type_hint(self) -> str:
    """
    :return: The return-annotation suffix of the signature. This is ``" -> bool"`` for conditions,
     ``" -> <case_query.attribute_type_hint>"`` for a conclusion, and an empty string otherwise.
    """
    if self.prompt_for == PromptFor.Conditions:
        return " -> bool"
    if self.prompt_for == PromptFor.Conclusion:
        return f" -> {self.case_query.attribute_type_hint}"
    return ""
168
+
169
def get_func_args(self) -> str:
    """
    Build the parameter list of the user-defined function.

    For a function case query there is one parameter per item of ``case_query.case``. Its annotation is the
    declared hint from ``function_args_type_hints`` when present. Otherwise it is the value's type name, or
    ``Type[X]`` when the value is itself a class. Annotations rendering as ``NoneType``/``None`` are
    dropped. For any other case query the single parameter is ``case: <case type name>``.

    :return: The comma-separated parameter list, without parentheses.
    """
    query = self.case_query
    if not query.is_function:
        return f"case: {self.case_type.__name__}"
    declared_hints = query.function_args_type_hints
    params = []
    for arg_name, arg_value in query.case.items():
        if declared_hints is not None and arg_name in declared_hints:
            annotation = typing_hint_to_str(declared_hints[arg_name])[0]
        elif isinstance(arg_value, type):
            annotation = f"Type[{arg_value.__name__}]"
        else:
            annotation = type(arg_value).__name__
        if str(annotation) in ("NoneType", "None"):
            params.append(str(arg_name))
        else:
            params.append(f"{arg_name}: {annotation}")
    return ', '.join(params)
69
186
 
70
187
  def write_to_file(self, code: str):
71
- tmp = tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix=".py",
72
- dir=os.path.dirname(self.scope['__file__']))
73
- tmp.write(code)
74
- tmp.flush()
75
- self.temp_file_path = tmp.name
76
- tmp.close()
188
+ if self.temp_file_path is None:
189
+ tmp = tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix=".py",
190
+ dir=self.workspace)
191
+ tmp.write(code)
192
+ tmp.flush()
193
+ self.temp_file_path = tmp.name
194
+ tmp.close()
195
+ else:
196
+ with open(self.temp_file_path, 'w+') as f:
197
+ f.write(code)
77
198
 
78
199
  def get_imports(self):
79
- case_type_import = f"from {self.case_type.__module__} import {self.case_type.__name__}"
200
+ """
201
+ :return: A string containing the imports for the function.
202
+ """
203
+ case_type_imports = []
204
+ if self.case_query.is_function:
205
+ for k, v in self.case_query.case.items():
206
+ if (self.case_query.function_args_type_hints is not None
207
+ and k in self.case_query.function_args_type_hints):
208
+ hint_list = typing_hint_to_str(self.case_query.function_args_type_hints[k])[1]
209
+ for hint in hint_list:
210
+ hint_split = hint.split('.')
211
+ if len(hint_split) > 1:
212
+ case_type_imports.append(f"from {'.'.join(hint_split[:-1])} import {hint_split[-1]}")
213
+ else:
214
+ if isinstance(v, type):
215
+ case_type_imports.append(f"from {v.__module__} import {v.__name__}")
216
+ elif hasattr(v, "__module__") and not v.__module__.startswith("__"):
217
+ case_type_imports.append(f"\nfrom {type(v).__module__} import {type(v).__name__}")
218
+ else:
219
+ case_type_imports.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
80
220
  if self.output_type is None:
81
221
  output_type_imports = [f"from typing_extensions import Any"]
82
222
  else:
83
223
  output_type_imports = get_imports_from_types(self.output_type)
84
224
  if len(self.output_type) > 1:
85
225
  output_type_imports.append("from typing_extensions import Union")
86
- print(output_type_imports)
226
+ if list in self.output_type:
227
+ output_type_imports.append("from typing_extensions import List")
87
228
  imports = get_imports_from_scope(self.scope)
88
229
  imports = [i for i in imports if ("get_ipython" not in i)]
89
- if case_type_import not in imports:
90
- imports.append(case_type_import)
230
+ imports.extend(case_type_imports)
91
231
  imports.extend([oti for oti in output_type_imports if oti not in imports])
92
232
  imports = set(imports)
93
233
  return '\n'.join(imports)
94
234
 
235
def get_func_doc(self) -> Optional[str]:
    """
    :return: The docstring placed in the generated function, worded according to ``self.prompt_for``.
    """
    if self.prompt_for == PromptFor.Conditions:
        lead_in = "Get conditions on whether it's possible to conclude a value for "
    else:
        lead_in = "Get possible value(s) for "
    return lead_in + f"{self.case_query.name}"
244
+
245
def get_func_name(self) -> Optional[str]:
    """
    Derive the generated function's name from the case query.

    Dots in ``case_query.name`` become underscores. For function case queries, every ``_<attribute_name>``
    occurrence is removed. When prompting for conditions, the name is prefixed with
    ``<prompt_for value, lower-cased>_for_``. The result is passed through ``str_to_snake_case``.

    :return: The function name.
    """
    prefix = f"{self.prompt_for.value.lower()}_for_" if self.prompt_for == PromptFor.Conditions else ""
    base_name = self.case_query.name.replace(".", "_")
    if self.case_query.is_function:
        base_name = base_name.replace(f"_{self.case_query.attribute_name}", "")
    return str_to_snake_case(prefix + base_name)
255
+
95
256
  @cached_property
96
257
  def case_type(self) -> Type:
97
258
  """
@@ -103,9 +264,9 @@ class MyMagics(Magics):
103
264
  return case._obj_type if isinstance(case, Case) else type(case)
104
265
 
105
266
  @line_magic
106
- def load_case(self, line):
267
+ def load(self, line):
107
268
  if not self.temp_file_path:
108
- print("No file to load. Run %edit_case first.")
269
+ print(f"{Fore.RED}ERROR:: No file to load. Run %edit first.{Style.RESET_ALL}")
109
270
  return
110
271
 
111
272
  with open(self.temp_file_path, 'r') as f:
@@ -118,20 +279,42 @@ class MyMagics(Magics):
118
279
  exec(source, self.scope, exec_globals)
119
280
  user_function = exec_globals[self.func_name]
120
281
  self.shell.user_ns[self.func_name] = user_function
121
- print(f"Loaded `{self.func_name}` function into user namespace.")
282
+ print(f"{Fore.BLUE}Loaded `{self.func_name}` function into user namespace.{Style.RESET_ALL}")
122
283
  return
123
284
 
124
- print(f"Function `{self.func_name}` not found.")
285
+ print(f"{Fore.RED}ERROR:: Function `{self.func_name}` not found.{Style.RESET_ALL}")
286
+
287
@line_magic
def help(self, line):
    """
    ``%help`` magic: print how to answer directly in the shell and how to use ``%edit`` / ``%load``.

    :param line: The remainder of the magic's input line; not used.
    """
    # The text previously named PyCharm only, although VSCode and code-server are supported as well.
    help_text = f"""
Directly write python code in the shell, and then `{Fore.GREEN}return {Fore.RESET}output`. Or use
the magic commands to write the code in a temporary file and edit it in your editor (PyCharm, VSCode or code-server):
{Fore.MAGENTA}Usage: %edit{Style.RESET_ALL}
Opens a temporary file in the detected editor for editing a function (conclusion or conditions for case)
that will be executed on the case object.
{Fore.MAGENTA}Usage: %load{Style.RESET_ALL}
Loads the function defined in the temporary file into the user namespace, that can then be used inside the
Ipython shell. You can then do `{Fore.GREEN}return {Fore.RESET}function_name(case)`
(or `{Fore.GREEN}return {Fore.RESET}function_name(**case)` when the case holds a function's arguments).
"""
    print(help_text)
303
+
304
def __del__(self):
    """
    Stop the code-server process started by this object, if it is still running.

    Termination is requested with ``terminate()``. If the process has not exited within 5 seconds it is
    killed, so garbage collection / interpreter shutdown cannot hang on an unresponsive server.
    """
    process = getattr(self, 'process', None)
    if process is None or process.poll() is not None:
        return
    process.terminate()  # Graceful shutdown
    try:
        process.wait(timeout=5)
    except subprocess.TimeoutExpired:
        process.kill()  # Did not exit after the graceful request
        process.wait()
125
308
 
126
309
 
127
310
  class CustomInteractiveShell(InteractiveShellEmbed):
128
- def __init__(self, output_type: Union[Type, Tuple[Type], None] = None, func_name: Optional[str] = None,
129
- func_doc: Optional[str] = None, code_to_modify: Optional[str] = None, **kwargs):
311
def __init__(self, code_to_modify: Optional[str] = None,
             prompt_for: Optional[PromptFor] = None,
             case_query: Optional[CaseQuery] = None,
             **kwargs):
    """
    Create the embedded shell and register a ``MyMagics`` instance (``%edit``, ``%load``, ``%help``) on it.

    :param code_to_modify: Code used to pre-fill the editable function body; forwarded to ``MyMagics``.
    :param prompt_for: What the user is asked for (conditions or conclusion); forwarded to ``MyMagics``.
    :param case_query: The case query being answered; forwarded to ``MyMagics``.
    :param kwargs: Passed unchanged to ``InteractiveShellEmbed``.
    """
    super().__init__(**kwargs)
    magics = MyMagics(self, self.user_ns,
                      code_to_modify=code_to_modify,
                      prompt_for=prompt_for,
                      case_query=case_query)
    self.my_magics = magics
    self.register_magics(magics)
    # Accumulates raw input cells; appended to elsewhere in this class.
    self.all_lines = []
137
320
 
@@ -145,7 +328,6 @@ class CustomInteractiveShell(InteractiveShellEmbed):
145
328
  self.my_magics.func_name,
146
329
  join_lines=False)[self.my_magics.func_name]
147
330
  self.all_lines.append(raw_cell)
148
- print("Exiting shell on `return` statement.")
149
331
  self.history_manager.store_inputs(line_num=self.execution_count, source=raw_cell)
150
332
  self.ask_exit()
151
333
  return None
@@ -161,30 +343,23 @@ class IPythonShell:
161
343
  """
162
344
 
163
345
  def __init__(self, scope: Optional[Dict] = None, header: Optional[str] = None,
164
- output_type: Optional[Type] = None, prompt_for: Optional[PromptFor] = None,
165
- attribute_name: Optional[str] = None, attribute_type: Optional[Type] = None,
346
+ prompt_for: Optional[PromptFor] = None, case_query: Optional[CaseQuery] = None,
166
347
  code_to_modify: Optional[str] = None):
167
348
  """
168
349
  Initialize the Ipython shell with the given scope and header.
169
350
 
170
351
  :param scope: The scope to use for the shell.
171
352
  :param header: The header to display when the shell is started.
172
- :param output_type: The type of the output from user input.
173
353
  :param prompt_for: The type of information to ask the user about.
174
- :param attribute_name: The name of the attribute of the case.
175
- :param attribute_type: The type of the attribute of the case.
354
+ :param case_query: The case query which contains the case and the attribute to ask about.
176
355
  :param code_to_modify: The code to modify. If given, will be used as a start for user to modify.
177
356
  """
178
357
  self.scope: Dict = scope or {}
179
358
  self.header: str = header or ">>> Embedded Ipython Shell"
180
- self.output_type: Optional[Type] = output_type
359
+ self.case_query: Optional[CaseQuery] = case_query
181
360
  self.prompt_for: Optional[PromptFor] = prompt_for
182
- self.attribute_name: Optional[str] = attribute_name
183
- self.attribute_type: Optional[Type] = attribute_type
184
361
  self.code_to_modify: Optional[str] = code_to_modify
185
362
  self.user_input: Optional[str] = None
186
- self.func_name: str = ""
187
- self.func_doc: str = ""
188
363
  self.shell: CustomInteractiveShell = self._init_shell()
189
364
  self.all_code_lines: List[str] = []
190
365
 
@@ -193,49 +368,13 @@ class IPythonShell:
193
368
  Initialize the Ipython shell with a custom configuration.
194
369
  """
195
370
  cfg = Config()
196
- self.build_func_name_and_doc()
197
371
  shell = CustomInteractiveShell(config=cfg, user_ns=self.scope, banner1=self.header,
198
- output_type=self.output_type, func_name=self.func_name, func_doc=self.func_doc,
199
- code_to_modify=self.code_to_modify)
372
+ code_to_modify=self.code_to_modify,
373
+ prompt_for=self.prompt_for,
374
+ case_query=self.case_query,
375
+ )
200
376
  return shell
201
377
 
202
- def build_func_name_and_doc(self) -> Tuple[str, str]:
203
- """
204
- Build the function name and docstring for the user-defined function.
205
-
206
- :return: A tuple containing the function name and docstring.
207
- """
208
- case = self.scope['case']
209
- case_type = case._obj_type if isinstance(case, Case) else type(case)
210
- self.func_name = self.build_func_name(case_type)
211
- self.func_doc = self.build_func_doc(case_type)
212
-
213
- def build_func_doc(self, case_type: Type):
214
- if self.prompt_for == PromptFor.Conditions:
215
- func_doc = (f"Get conditions on whether it's possible to conclude a value"
216
- f" for {case_type.__name__}.{self.attribute_name}")
217
- else:
218
- func_doc = f"Get possible value(s) for {case_type.__name__}.{self.attribute_name}"
219
- if is_iterable(self.attribute_type):
220
- possible_types = [t.__name__ for t in self.attribute_type if t not in [list, set]]
221
- func_doc += f" of types list/set of {' and/or '.join(possible_types)}"
222
- else:
223
- func_doc += f" of type {self.attribute_type.__name__}"
224
- return func_doc
225
-
226
- def build_func_name(self, case_type: Type):
227
- func_name = f"get_{self.prompt_for.value.lower()}_for"
228
- func_name += f"_{case_type.__name__}"
229
- if self.attribute_name is not None:
230
- func_name += f"_{self.attribute_name}"
231
- if is_iterable(self.attribute_type):
232
- output_names = [f"{t.__name__}" for t in self.attribute_type if t not in [list, set]]
233
- else:
234
- output_names = [self.attribute_type.__name__] if self.attribute_type is not None else None
235
- if output_names is not None:
236
- func_name += '_of_type_' + '_'.join(output_names)
237
- return func_name.lower()
238
-
239
378
  def run(self):
240
379
  """
241
380
  Run the embedded shell.
@@ -247,7 +386,7 @@ class IPythonShell:
247
386
  break
248
387
  except Exception as e:
249
388
  logging.error(e)
250
- print(e)
389
+ print(f"{Fore.RED}ERROR::{e}{Style.RESET_ALL}")
251
390
 
252
391
  def update_user_input_from_code_lines(self):
253
392
  """
@@ -262,12 +401,16 @@ class IPythonShell:
262
401
  else:
263
402
  self.user_input = '\n'.join(self.all_code_lines)
264
403
  self.user_input = encapsulate_user_input(self.user_input, self.shell.my_magics.function_signature,
265
- self.func_doc)
266
- if f"return {self.func_name}(case)" not in self.user_input:
267
- self.user_input = self.user_input.strip() + f"\nreturn {self.func_name}(case)"
404
+ self.shell.my_magics.func_doc)
405
+ if self.case_query.is_function:
406
+ args = "**case"
407
+ else:
408
+ args = "case"
409
+ if f"return {self.shell.my_magics.func_name}({args})" not in self.user_input:
410
+ self.user_input = self.user_input.strip() + f"\nreturn {self.shell.my_magics.func_name}({args})"
268
411
 
269
412
 
270
- def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, prompt_str: Optional[str] = None)\
413
+ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, prompt_str: Optional[str] = None) \
271
414
  -> Tuple[Optional[str], Optional[CallableExpression]]:
272
415
  """
273
416
  Prompt the user for an executable python expression to the given case query.
@@ -282,26 +425,28 @@ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, pro
282
425
  while True:
283
426
  user_input, expression_tree = prompt_user_about_case(case_query, prompt_for, prompt_str,
284
427
  code_to_modify=prev_user_input)
285
- prev_user_input = '\n'.join(user_input.split('\n')[2:-1])
286
428
  if user_input is None:
287
429
  if prompt_for == PromptFor.Conclusion:
288
- print("No conclusion provided. Exiting.")
430
+ print(f"{Fore.YELLOW}No conclusion provided. Exiting.{Style.RESET_ALL}")
289
431
  return None, None
290
432
  else:
291
- print("Conditions must be provided. Please try again.")
433
+ print(f"{Fore.RED}Conditions must be provided. Please try again.{Style.RESET_ALL}")
292
434
  continue
435
+ prev_user_input = '\n'.join(user_input.split('\n')[2:-1])
293
436
  conclusion_type = bool if prompt_for == PromptFor.Conditions else case_query.attribute_type
294
437
  callable_expression = CallableExpression(user_input, conclusion_type, expression_tree=expression_tree,
295
- scope=case_query.scope)
438
+ scope=case_query.scope,
439
+ mutually_exclusive=case_query.mutually_exclusive)
296
440
  try:
297
441
  result = callable_expression(case_query.case)
298
442
  if len(make_list(result)) == 0:
299
- print(f"The given expression gave an empty result for case {case_query.name}. Please modify!")
443
+ print(f"{Fore.YELLOW}The given expression gave an empty result for case {case_query.name}."
444
+ f" Please modify!{Style.RESET_ALL}")
300
445
  continue
301
446
  break
302
447
  except Exception as e:
303
448
  logging.error(e)
304
- print(e)
449
+ print(f"{Fore.RED}{e}{Style.RESET_ALL}")
305
450
  return user_input, callable_expression
306
451
 
307
452
 
@@ -318,17 +463,26 @@ def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor,
318
463
  :return: The user input, and the executable expression that was parsed from the user input.
319
464
  """
320
465
  if prompt_str is None:
321
- prompt_str = f"Give {prompt_for} for {case_query.name}"
466
+ if prompt_for == PromptFor.Conclusion:
467
+ prompt_str = f"Give possible value(s) for:\n"
468
+ else:
469
+ prompt_str = f"Give conditions on when can the rule be evaluated for:\n"
470
+ prompt_str += (f"{Fore.CYAN}{case_query.name}{Fore.MAGENTA} of type(s) "
471
+ f"{Fore.CYAN}({', '.join(map(lambda x: x.__name__, case_query.core_attribute_type))}){Fore.MAGENTA}")
472
+ if prompt_for == PromptFor.Conditions:
473
+ prompt_str += (f"\ne.g. `{Fore.GREEN}return {Fore.BLUE}len{Fore.RESET}(case.attribute) > {Fore.BLUE}0` "
474
+ f"{Fore.MAGENTA}\nOR `{Fore.GREEN}return {Fore.YELLOW}True`{Fore.MAGENTA} (If you want the"
475
+ f" rule to be always evaluated) \n"
476
+ f"You can also do {Fore.YELLOW}%edit{Fore.MAGENTA} for more complex conditions.")
477
+ prompt_str = f"{Fore.MAGENTA}{prompt_str}{Fore.YELLOW}\n(Write %help for guide){Fore.RESET}"
322
478
  scope = {'case': case_query.case, **case_query.scope}
323
- output_type = case_query.attribute_type if prompt_for == PromptFor.Conclusion else bool
324
- shell = IPythonShell(scope=scope, header=prompt_str, output_type=output_type, prompt_for=prompt_for,
325
- attribute_name=case_query.attribute_name, attribute_type=case_query.attribute_type,
479
+ shell = IPythonShell(scope=scope, header=prompt_str, prompt_for=prompt_for, case_query=case_query,
326
480
  code_to_modify=code_to_modify)
327
481
  return prompt_user_input_and_parse_to_expression(shell=shell)
328
482
 
329
483
 
330
484
  def prompt_user_input_and_parse_to_expression(shell: Optional[IPythonShell] = None,
331
- user_input: Optional[str] = None)\
485
+ user_input: Optional[str] = None) \
332
486
  -> Tuple[Optional[str], Optional[ast.AST]]:
333
487
  """
334
488
  Prompt the user for input.
@@ -344,11 +498,13 @@ def prompt_user_input_and_parse_to_expression(shell: Optional[IPythonShell] = No
344
498
  user_input = shell.user_input
345
499
  if user_input is None:
346
500
  return None, None
347
- print(user_input)
501
+ print(f"{Fore.BLUE}Captured User input: {Style.RESET_ALL}")
502
+ highlighted_code = highlight(user_input, PythonLexer(), TerminalFormatter())
503
+ print(highlighted_code)
348
504
  try:
349
505
  return user_input, parse_string_to_expression(user_input)
350
506
  except Exception as e:
351
507
  msg = f"Error parsing expression: {e}"
352
508
  logging.error(msg)
353
- print(msg)
509
+ print(f"{Fore.RED}{msg}{Style.RESET_ALL}")
354
510
  user_input = None
ripple_down_rules/rdr.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import copyreg
3
4
  import importlib
4
5
  import sys
5
6
  from abc import ABC, abstractmethod
@@ -96,11 +97,12 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
96
97
  plt.ioff()
97
98
  plt.show()
98
99
 
99
- def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
100
+ def __call__(self, case: Union[Case, SQLTable]) -> Union[CallableExpression, Dict[str, CallableExpression]]:
100
101
  return self.classify(case)
101
102
 
102
103
  @abstractmethod
103
- def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) -> Optional[CaseAttribute]:
104
+ def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) \
105
+ -> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
104
106
  """
105
107
  Classify a case.
106
108
 
@@ -111,7 +113,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
111
113
  pass
112
114
 
113
115
  def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
114
- -> Union[CaseAttribute, CallableExpression]:
116
+ -> Union[CallableExpression, Dict[str, CallableExpression]]:
115
117
  """
116
118
  Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
117
119
  incorrect by comparing the case with the target category.
@@ -136,7 +138,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
136
138
 
137
139
  @abstractmethod
138
140
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
139
- -> Union[CaseAttribute, CallableExpression]:
141
+ -> Union[CallableExpression, Dict[str, CallableExpression]]:
140
142
  """
141
143
  Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
142
144
  comparing the case with the target category.
@@ -929,7 +931,7 @@ class GeneralRDR(RippleDownRules):
929
931
  imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
930
932
  # add rdr python generated functions.
931
933
  for rdr_key, rdr in self.start_rules_dict.items():
932
- imports += (f"from {file_path.strip('./')}"
934
+ imports += (f"from ."
933
935
  f" import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}\n")
934
936
  return imports
935
937