ripple-down-rules 0.1.69__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,8 @@ from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set
9
9
 
10
10
  from .case import create_case, Case
11
11
  from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string, conclusion_to_json, is_iterable, \
12
- build_user_input_from_conclusion, encapsulate_user_input
12
+ build_user_input_from_conclusion, encapsulate_user_input, extract_function_source, are_results_subclass_of_types, \
13
+ make_list
13
14
 
14
15
 
15
16
  class VariableVisitor(ast.NodeVisitor):
@@ -109,7 +110,7 @@ class CallableExpression(SubclassJSONSerializer):
109
110
  if user_input is None:
110
111
  user_input = build_user_input_from_conclusion(conclusion)
111
112
  self.conclusion: Optional[Any] = conclusion
112
- self.user_input: str = encapsulate_user_input(user_input, self.encapsulating_function)
113
+ self._user_input: str = encapsulate_user_input(user_input, self.encapsulating_function)
113
114
  if conclusion_type is not None:
114
115
  if is_iterable(conclusion_type):
115
116
  conclusion_type = tuple(conclusion_type)
@@ -133,15 +134,16 @@ class CallableExpression(SubclassJSONSerializer):
133
134
  if output is None:
134
135
  output = scope['_get_value'](case)
135
136
  if self.conclusion_type is not None:
136
- if is_iterable(output) and not isinstance(output, self.conclusion_type):
137
- assert isinstance(list(output)[0], self.conclusion_type), (f"Expected output type {self.conclusion_type},"
138
- f" got {type(output)}")
139
- else:
140
- assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
141
- f" got {type(output)}")
137
+ output_types = {type(o) for o in make_list(output)}
138
+ output_types.add(type(output))
139
+ if not are_results_subclass_of_types(output_types, self.conclusion_type):
140
+ raise ValueError(f"Not all result types {output_types} are subclasses of expected types"
141
+ f" {self.conclusion_type}")
142
142
  return output
143
- else:
143
+ elif self.conclusion is not None:
144
144
  return self.conclusion
145
+ else:
146
+ raise ValueError("Either user_input or conclusion must be provided.")
145
147
  except Exception as e:
146
148
  raise ValueError(f"Error during evaluation: {e}")
147
149
 
@@ -156,6 +158,35 @@ class CallableExpression(SubclassJSONSerializer):
156
158
  f"return _cond1(case) and _cond2(case)")
157
159
  return CallableExpression(new_user_input, conclusion_type=self.conclusion_type)
158
160
 
161
+ def update_user_input_from_file(self, file_path: str, function_name: str):
162
+ """
163
+ Update the user input from a file.
164
+ """
165
+ new_function_body = extract_function_source(file_path, [function_name])[function_name]
166
+ if new_function_body is None:
167
+ return
168
+ self.user_input = self.encapsulating_function + '\n' + new_function_body
169
+
170
+ @property
171
+ def user_input(self):
172
+ """
173
+ Get the user input.
174
+ """
175
+ return self._user_input
176
+
177
+ @user_input.setter
178
+ def user_input(self, value: str):
179
+ """
180
+ Set the user input.
181
+ """
182
+ if value is not None:
183
+ self._user_input = encapsulate_user_input(value, self.encapsulating_function)
184
+ self.scope = get_used_scope(self.user_input, self.scope)
185
+ self.expression_tree = parse_string_to_expression(self.user_input)
186
+ self.code = compile_expression_to_code(self.expression_tree)
187
+ self.visitor = VariableVisitor()
188
+ self.visitor.visit(self.expression_tree)
189
+
159
190
  def __eq__(self, other):
160
191
  """
161
192
  Check if two callable expressions are equal.
@@ -225,7 +256,7 @@ def parse_string_to_expression(expression_str: str) -> AST:
225
256
  :param expression_str: The string which will be parsed.
226
257
  :return: The parsed expression.
227
258
  """
228
- if not expression_str.startswith('def'):
259
+ if not expression_str.startswith(CallableExpression.encapsulating_function):
229
260
  expression_str = encapsulate_user_input(expression_str, CallableExpression.encapsulating_function)
230
261
  mode = 'exec' if expression_str.startswith('def') else 'eval'
231
262
  tree = ast.parse(expression_str, mode=mode)
@@ -83,6 +83,13 @@ class CaseQuery:
83
83
  raise ValueError("The case must be a Case or SQLTable object.")
84
84
  self._case = value
85
85
 
86
+ @property
87
+ def core_attribute_type(self) -> Tuple[Type]:
88
+ """
89
+ :return: The core type of the attribute.
90
+ """
91
+ return (t for t in self.attribute_type if t not in (set, list))
92
+
86
93
  @property
87
94
  def attribute_type(self) -> Tuple[Type]:
88
95
  """
@@ -59,7 +59,7 @@ class PromptFor(Enum):
59
59
  """
60
60
  Prompt for rule conditions about a case.
61
61
  """
62
- Conclusion: str = "conclusion"
62
+ Conclusion: str = "value"
63
63
  """
64
64
  Prompt for rule conclusion about a case.
65
65
  """
@@ -1,27 +1,149 @@
1
1
  import ast
2
2
  import logging
3
+ import os
4
+ import subprocess
5
+ import tempfile
3
6
  from _ast import AST
7
+ from functools import cached_property
8
+ from textwrap import indent, dedent
4
9
 
10
+ from IPython.core.magic import register_line_magic, line_magic, Magics, magics_class
5
11
  from IPython.terminal.embed import InteractiveShellEmbed
6
12
  from traitlets.config import Config
7
- from typing_extensions import List, Optional, Tuple, Dict
13
+ from typing_extensions import List, Optional, Tuple, Dict, Type, Union, Any
8
14
 
9
15
  from .datastructures.enums import PromptFor
16
+ from .datastructures.case import Case
10
17
  from .datastructures.callable_expression import CallableExpression, parse_string_to_expression
11
18
  from .datastructures.dataclasses import CaseQuery
12
- from .utils import extract_dependencies, contains_return_statement, make_set
19
+ from .utils import extract_dependencies, contains_return_statement, make_set, get_imports_from_scope, make_list, \
20
+ get_import_from_type, get_imports_from_types, is_iterable, extract_function_source, encapsulate_user_input, \
21
+ are_results_subclass_of_types
22
+
23
+
24
+ @magics_class
25
+ class MyMagics(Magics):
26
+ def __init__(self, shell, scope, output_type: Optional[Type] = None, func_name: str = "user_case",
27
+ func_doc: str = "User defined function to be executed on the case.",
28
+ code_to_modify: Optional[str] = None):
29
+ super().__init__(shell)
30
+ self.scope = scope
31
+ self.temp_file_path = None
32
+ self.func_name = func_name
33
+ self.func_doc = func_doc
34
+ self.code_to_modify = code_to_modify
35
+ self.output_type = make_list(output_type) if output_type is not None else None
36
+ self.user_edit_line = 0
37
+ self.function_signature: Optional[str] = None
38
+ self.build_function_signature()
39
+
40
+ @line_magic
41
+ def edit_case(self, line):
42
+
43
+ boilerplate_code = self.build_boilerplate_code()
44
+
45
+ self.write_to_file(boilerplate_code)
46
+
47
+ print(f"Opening {self.temp_file_path} in PyCharm...")
48
+ subprocess.Popen(["pycharm", "--line", str(self.user_edit_line), self.temp_file_path])
49
+
50
+ def build_boilerplate_code(self):
51
+ imports = self.get_imports()
52
+ self.build_function_signature()
53
+ if self.code_to_modify is not None:
54
+ body = indent(dedent(self.code_to_modify), ' ')
55
+ else:
56
+ body = " # Write your code here\n pass"
57
+ boilerplate = f"""{imports}\n\n{self.function_signature}\n \"\"\"{self.func_doc}\"\"\"\n{body}"""
58
+ self.user_edit_line = imports.count('\n')+6
59
+ return boilerplate
60
+
61
+ def build_function_signature(self):
62
+ if self.output_type is None:
63
+ output_type_hint = ""
64
+ elif len(self.output_type) == 1:
65
+ output_type_hint = f" -> {self.output_type[0].__name__}"
66
+ else:
67
+ output_type_hint = f" -> Union[{', '.join([t.__name__ for t in self.output_type])}]"
68
+ self.function_signature = f"def {self.func_name}(case: {self.case_type.__name__}){output_type_hint}:"
69
+
70
+ def write_to_file(self, code: str):
71
+ tmp = tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix=".py",
72
+ dir=os.path.dirname(self.scope['__file__']))
73
+ tmp.write(code)
74
+ tmp.flush()
75
+ self.temp_file_path = tmp.name
76
+ tmp.close()
77
+
78
+ def get_imports(self):
79
+ case_type_import = f"from {self.case_type.__module__} import {self.case_type.__name__}"
80
+ if self.output_type is None:
81
+ output_type_imports = [f"from typing_extensions import Any"]
82
+ else:
83
+ output_type_imports = get_imports_from_types(self.output_type)
84
+ if len(self.output_type) > 1:
85
+ output_type_imports.append("from typing_extensions import Union")
86
+ print(output_type_imports)
87
+ imports = get_imports_from_scope(self.scope)
88
+ imports = [i for i in imports if ("get_ipython" not in i)]
89
+ if case_type_import not in imports:
90
+ imports.append(case_type_import)
91
+ imports.extend([oti for oti in output_type_imports if oti not in imports])
92
+ imports = set(imports)
93
+ return '\n'.join(imports)
94
+
95
+ @cached_property
96
+ def case_type(self) -> Type:
97
+ """
98
+ Get the type of the case object in the current scope.
99
+
100
+ :return: The type of the case object.
101
+ """
102
+ case = self.scope['case']
103
+ return case._obj_type if isinstance(case, Case) else type(case)
104
+
105
+ @line_magic
106
+ def load_case(self, line):
107
+ if not self.temp_file_path:
108
+ print("No file to load. Run %edit_case first.")
109
+ return
110
+
111
+ with open(self.temp_file_path, 'r') as f:
112
+ source = f.read()
113
+
114
+ tree = ast.parse(source)
115
+ for node in tree.body:
116
+ if isinstance(node, ast.FunctionDef) and node.name == self.func_name:
117
+ exec_globals = {}
118
+ exec(source, self.scope, exec_globals)
119
+ user_function = exec_globals[self.func_name]
120
+ self.shell.user_ns[self.func_name] = user_function
121
+ print(f"Loaded `{self.func_name}` function into user namespace.")
122
+ return
123
+
124
+ print(f"Function `{self.func_name}` not found.")
13
125
 
14
126
 
15
127
  class CustomInteractiveShell(InteractiveShellEmbed):
16
- def __init__(self, **kwargs):
128
+ def __init__(self, output_type: Union[Type, Tuple[Type], None] = None, func_name: Optional[str] = None,
129
+ func_doc: Optional[str] = None, code_to_modify: Optional[str] = None, **kwargs):
17
130
  super().__init__(**kwargs)
131
+ keys = ['output_type', 'func_name', 'func_doc', 'code_to_modify']
132
+ values = [output_type, func_name, func_doc, code_to_modify]
133
+ magics_kwargs = {key: value for key, value in zip(keys, values) if value is not None}
134
+ self.my_magics = MyMagics(self, self.user_ns, **magics_kwargs)
135
+ self.register_magics(self.my_magics)
18
136
  self.all_lines = []
19
137
 
20
138
  def run_cell(self, raw_cell: str, **kwargs):
21
139
  """
22
140
  Override the run_cell method to capture return statements.
23
141
  """
24
- if contains_return_statement(raw_cell):
142
+ if contains_return_statement(raw_cell) and 'def ' not in raw_cell:
143
+ if self.my_magics.func_name in raw_cell:
144
+ self.all_lines = extract_function_source(self.my_magics.temp_file_path,
145
+ self.my_magics.func_name,
146
+ join_lines=False)[self.my_magics.func_name]
25
147
  self.all_lines.append(raw_cell)
26
148
  print("Exiting shell on `return` statement.")
27
149
  self.history_manager.store_inputs(line_num=self.execution_count, source=raw_cell)
@@ -38,16 +160,31 @@ class IPythonShell:
38
160
  Create an embedded Ipython shell that can be used to prompt the user for input.
39
161
  """
40
162
 
41
- def __init__(self, scope: Optional[Dict] = None, header: Optional[str] = None):
163
+ def __init__(self, scope: Optional[Dict] = None, header: Optional[str] = None,
164
+ output_type: Optional[Type] = None, prompt_for: Optional[PromptFor] = None,
165
+ attribute_name: Optional[str] = None, attribute_type: Optional[Type] = None,
166
+ code_to_modify: Optional[str] = None):
42
167
  """
43
168
  Initialize the Ipython shell with the given scope and header.
44
169
 
45
170
  :param scope: The scope to use for the shell.
46
171
  :param header: The header to display when the shell is started.
172
+ :param output_type: The type of the output from user input.
173
+ :param prompt_for: The type of information to ask the user about.
174
+ :param attribute_name: The name of the attribute of the case.
175
+ :param attribute_type: The type of the attribute of the case.
176
+ :param code_to_modify: The code to modify. If given, will be used as a start for user to modify.
47
177
  """
48
178
  self.scope: Dict = scope or {}
49
179
  self.header: str = header or ">>> Embedded Ipython Shell"
180
+ self.output_type: Optional[Type] = output_type
181
+ self.prompt_for: Optional[PromptFor] = prompt_for
182
+ self.attribute_name: Optional[str] = attribute_name
183
+ self.attribute_type: Optional[Type] = attribute_type
184
+ self.code_to_modify: Optional[str] = code_to_modify
50
185
  self.user_input: Optional[str] = None
186
+ self.func_name: str = ""
187
+ self.func_doc: str = ""
51
188
  self.shell: CustomInteractiveShell = self._init_shell()
52
189
  self.all_code_lines: List[str] = []
53
190
 
@@ -56,9 +193,49 @@ class IPythonShell:
56
193
  Initialize the Ipython shell with a custom configuration.
57
194
  """
58
195
  cfg = Config()
59
- shell = CustomInteractiveShell(config=cfg, user_ns=self.scope, banner1=self.header)
196
+ self.build_func_name_and_doc()
197
+ shell = CustomInteractiveShell(config=cfg, user_ns=self.scope, banner1=self.header,
198
+ output_type=self.output_type, func_name=self.func_name, func_doc=self.func_doc,
199
+ code_to_modify=self.code_to_modify)
60
200
  return shell
61
201
 
202
+ def build_func_name_and_doc(self) -> Tuple[str, str]:
203
+ """
204
+ Build the function name and docstring for the user-defined function.
205
+
206
+ :return: A tuple containing the function name and docstring.
207
+ """
208
+ case = self.scope['case']
209
+ case_type = case._obj_type if isinstance(case, Case) else type(case)
210
+ self.func_name = self.build_func_name(case_type)
211
+ self.func_doc = self.build_func_doc(case_type)
212
+
213
+ def build_func_doc(self, case_type: Type):
214
+ if self.prompt_for == PromptFor.Conditions:
215
+ func_doc = (f"Get conditions on whether it's possible to conclude a value"
216
+ f" for {case_type.__name__}.{self.attribute_name}")
217
+ else:
218
+ func_doc = f"Get possible value(s) for {case_type.__name__}.{self.attribute_name}"
219
+ if is_iterable(self.attribute_type):
220
+ possible_types = [t.__name__ for t in self.attribute_type if t not in [list, set]]
221
+ func_doc += f" of types list/set of {' and/or '.join(possible_types)}"
222
+ else:
223
+ func_doc += f" of type {self.attribute_type.__name__}"
224
+ return func_doc
225
+
226
+ def build_func_name(self, case_type: Type):
227
+ func_name = f"get_{self.prompt_for.value.lower()}_for"
228
+ func_name += f"_{case_type.__name__}"
229
+ if self.attribute_name is not None:
230
+ func_name += f"_{self.attribute_name}"
231
+ if is_iterable(self.attribute_type):
232
+ output_names = [f"{t.__name__}" for t in self.attribute_type if t not in [list, set]]
233
+ else:
234
+ output_names = [self.attribute_type.__name__] if self.attribute_type is not None else None
235
+ if output_names is not None:
236
+ func_name += '_of_type_' + '_'.join(output_names)
237
+ return func_name.lower()
238
+
62
239
  def run(self):
63
240
  """
64
241
  Run the embedded shell.
@@ -84,6 +261,10 @@ class IPythonShell:
84
261
  self.user_input = None
85
262
  else:
86
263
  self.user_input = '\n'.join(self.all_code_lines)
264
+ self.user_input = encapsulate_user_input(self.user_input, self.shell.my_magics.function_signature,
265
+ self.func_doc)
266
+ if f"return {self.func_name}(case)" not in self.user_input:
267
+ self.user_input = self.user_input.strip() + f"\nreturn {self.func_name}(case)"
87
268
 
88
269
 
89
270
  def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, prompt_str: Optional[str] = None)\
@@ -96,8 +277,12 @@ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, pro
96
277
  :param prompt_str: The prompt string to display to the user.
97
278
  :return: A callable expression that takes a case and executes user expression on it.
98
279
  """
280
+ prev_user_input: Optional[str] = None
281
+ callable_expression: Optional[CallableExpression] = None
99
282
  while True:
100
- user_input, expression_tree = prompt_user_about_case(case_query, prompt_for, prompt_str)
283
+ user_input, expression_tree = prompt_user_about_case(case_query, prompt_for, prompt_str,
284
+ code_to_modify=prev_user_input)
285
+ prev_user_input = '\n'.join(user_input.split('\n')[2:-1])
101
286
  if user_input is None:
102
287
  if prompt_for == PromptFor.Conclusion:
103
288
  print("No conclusion provided. Exiting.")
@@ -109,7 +294,10 @@ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, pro
109
294
  callable_expression = CallableExpression(user_input, conclusion_type, expression_tree=expression_tree,
110
295
  scope=case_query.scope)
111
296
  try:
112
- callable_expression(case_query.case)
297
+ result = callable_expression(case_query.case)
298
+ if len(make_list(result)) == 0:
299
+ print(f"The given expression gave an empty result for case {case_query.name}. Please modify!")
300
+ continue
113
301
  break
114
302
  except Exception as e:
115
303
  logging.error(e)
@@ -118,19 +306,24 @@ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, pro
118
306
 
119
307
 
120
308
  def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor,
121
- prompt_str: Optional[str] = None) -> Tuple[Optional[str], Optional[AST]]:
309
+ prompt_str: Optional[str] = None,
310
+ code_to_modify: Optional[str] = None) -> Tuple[Optional[str], Optional[AST]]:
122
311
  """
123
312
  Prompt the user for input.
124
313
 
125
314
  :param case_query: The case query to prompt the user for.
126
315
  :param prompt_for: The type of information the user should provide for the given case.
127
316
  :param prompt_str: The prompt string to display to the user.
317
+ :param code_to_modify: The code to modify. If given will be used as a start for user to modify.
128
318
  :return: The user input, and the executable expression that was parsed from the user input.
129
319
  """
130
320
  if prompt_str is None:
131
321
  prompt_str = f"Give {prompt_for} for {case_query.name}"
132
322
  scope = {'case': case_query.case, **case_query.scope}
133
- shell = IPythonShell(scope=scope, header=prompt_str)
323
+ output_type = case_query.attribute_type if prompt_for == PromptFor.Conclusion else bool
324
+ shell = IPythonShell(scope=scope, header=prompt_str, output_type=output_type, prompt_for=prompt_for,
325
+ attribute_name=case_query.attribute_name, attribute_type=case_query.attribute_type,
326
+ code_to_modify=code_to_modify)
134
327
  return prompt_user_input_and_parse_to_expression(shell=shell)
135
328
 
136
329
 
ripple_down_rules/rdr.py CHANGED
@@ -4,25 +4,23 @@ import importlib
4
4
  import sys
5
5
  from abc import ABC, abstractmethod
6
6
  from copy import copy
7
- from dataclasses import is_dataclass
8
7
  from io import TextIOWrapper
9
8
  from types import ModuleType
10
9
 
11
10
  from matplotlib import pyplot as plt
12
- from ordered_set import OrderedSet
13
11
  from sqlalchemy.orm import DeclarativeBase as SQLTable
14
12
  from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
15
13
 
16
14
  from .datastructures.callable_expression import CallableExpression
17
15
  from .datastructures.case import Case, CaseAttribute, create_case
18
16
  from .datastructures.dataclasses import CaseQuery
19
- from .datastructures.enums import MCRDRMode, PromptFor
20
- from .experts import Expert, Human
17
+ from .datastructures.enums import MCRDRMode
18
+ from .experts import Expert
21
19
  from .helpers import is_matching
22
20
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
23
21
  from .utils import draw_tree, make_set, copy_case, \
24
- SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
25
- get_case_attribute_type, is_conflicting, update_case
22
+ SubclassJSONSerializer, make_list, get_type_from_string, \
23
+ is_conflicting, update_case, get_imports_from_scope, extract_function_source
26
24
 
27
25
 
28
26
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -102,11 +100,12 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
102
100
  return self.classify(case)
103
101
 
104
102
  @abstractmethod
105
- def classify(self, case: Union[Case, SQLTable]) -> Optional[CaseAttribute]:
103
+ def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) -> Optional[CaseAttribute]:
106
104
  """
107
105
  Classify a case.
108
106
 
109
107
  :param case: The case to classify.
108
+ :param modify_case: Whether to modify the original case attributes with the conclusion or not.
110
109
  :return: The category that the case belongs to.
111
110
  """
112
111
  pass
@@ -124,7 +123,10 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
124
123
  if case_query is None:
125
124
  raise ValueError("The case query cannot be None.")
126
125
  if case_query.target is None:
127
- expert.ask_for_conclusion(case_query)
126
+ case_query_cp = copy(case_query)
127
+ self.classify(case_query_cp.case, modify_case=True)
128
+ expert.ask_for_conclusion(case_query_cp)
129
+ case_query.target = case_query_cp.target
128
130
  if case_query.target is None:
129
131
  return self.classify(case_query.case)
130
132
 
@@ -169,24 +171,62 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
169
171
  self.fig = plt.figure(0)
170
172
  draw_tree(self.start_rule, self.fig)
171
173
 
172
- @staticmethod
173
- def case_has_conclusion(case: Union[Case, SQLTable], conclusion_name: str) -> bool:
174
- """
175
- Check if the case has a conclusion.
174
+ @property
175
+ def type_(self):
176
+ return self.__class__
177
+
178
+ @property
179
+ def generated_python_file_name(self) -> str:
180
+ if self._generated_python_file_name is None:
181
+ self._generated_python_file_name = self._default_generated_python_file_name
182
+ return self._generated_python_file_name
176
183
 
177
- :param case: The case to check.
178
- :param conclusion_name: The target category name to compare the case with.
179
- :return: Whether the case has a conclusion or not.
184
+ @generated_python_file_name.setter
185
+ def generated_python_file_name(self, value: str):
180
186
  """
181
- return hasattr(case, conclusion_name) and getattr(case, conclusion_name) is not None
187
+ Set the generated python file name.
188
+ :param value: The new value for the generated python file name.
189
+ """
190
+ self._generated_python_file_name = value
182
191
 
183
192
  @property
184
- def type_(self):
185
- return self.__class__
193
+ @abstractmethod
194
+ def _default_generated_python_file_name(self) -> str:
195
+ """
196
+ :return: The default generated python file name.
197
+ """
198
+ pass
199
+
200
+ @abstractmethod
201
+ def update_from_python_file(self, package_dir: str):
202
+ """
203
+ Update the rules from the generated python file, that might have been modified by the user.
204
+
205
+ :param package_dir: The directory of the package that contains the generated python file.
206
+ """
207
+ pass
186
208
 
187
209
 
188
210
  class RDRWithCodeWriter(RippleDownRules, ABC):
189
211
 
212
+ def update_from_python_file(self, package_dir: str):
213
+ """
214
+ Update the rules from the generated python file, that might have been modified by the user.
215
+
216
+ :param package_dir: The directory of the package that contains the generated python file.
217
+ """
218
+ rule_ids = [r.uid for r in [self.start_rule] + list(self.start_rule.descendants) if r.conditions is not None]
219
+ condition_func_names = [f'conditions_{rid}' for rid in rule_ids]
220
+ conclusion_func_names = [f'conclusion_{rid}' for rid in rule_ids]
221
+ all_func_names = condition_func_names + conclusion_func_names
222
+ filepath = f"{package_dir}/{self.generated_python_defs_file_name}.py"
223
+ functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
224
+ for rule in [self.start_rule] + list(self.start_rule.descendants):
225
+ if rule.conditions is not None:
226
+ rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
227
+ if rule.conclusion is not None:
228
+ rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
229
+
190
230
  @abstractmethod
191
231
  def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
192
232
  defs_file: Optional[str] = None):
@@ -238,31 +278,26 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
238
278
  """
239
279
  :return: The imports for the generated python file of the RDR as a string.
240
280
  """
241
- defs_imports = ""
281
+ defs_imports_list = []
242
282
  for rule in [self.start_rule] + list(self.start_rule.descendants):
243
283
  if not rule.conditions:
244
284
  continue
245
285
  for scope in [rule.conditions.scope, rule.conclusion.scope]:
246
286
  if scope is None:
247
287
  continue
248
- for k, v in scope.items():
249
- if not hasattr(v, "__module__") or not hasattr(v, "__name__"):
250
- continue
251
- new_imports = f"from {v.__module__} import {v.__name__}\n"
252
- if new_imports in defs_imports:
253
- continue
254
- defs_imports += new_imports
255
- imports = ""
288
+ defs_imports_list.extend(get_imports_from_scope(scope))
289
+ if self.case_type.__module__ != "builtins":
290
+ defs_imports_list.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
291
+ defs_imports = "\n".join(set(defs_imports_list)) + "\n"
292
+ imports = []
256
293
  if self.case_type.__module__ != "builtins":
257
- new_import = f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
258
- if new_import not in defs_imports:
259
- imports += new_import
294
+ imports.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
260
295
  for conclusion_type in self.conclusion_type:
261
296
  if conclusion_type.__module__ != "builtins":
262
- new_import = f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
263
- if new_import not in defs_imports:
264
- imports += new_import
265
- imports += "from ripple_down_rules.datastructures.case import Case, create_case\n"
297
+ imports.append(f"from {conclusion_type.__module__} import {conclusion_type.__name__}")
298
+ imports.append("from ripple_down_rules.datastructures.case import Case, create_case")
299
+ imports = set(imports).difference(defs_imports_list)
300
+ imports = "\n".join(imports) + "\n"
266
301
  return imports, defs_imports
267
302
 
268
303
  def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
@@ -279,20 +314,6 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
279
314
  pass
280
315
  return importlib.import_module(name).classify
281
316
 
282
- @property
283
- def generated_python_file_name(self) -> str:
284
- if self._generated_python_file_name is None:
285
- self._generated_python_file_name = self._default_generated_python_file_name
286
- return self._generated_python_file_name
287
-
288
- @generated_python_file_name.setter
289
- def generated_python_file_name(self, value: str):
290
- """
291
- Set the generated python file name.
292
- :param value: The new value for the generated python file name.
293
- """
294
- self._generated_python_file_name = value
295
-
296
317
  @property
297
318
  def _default_generated_python_file_name(self) -> str:
298
319
  """
@@ -390,12 +411,12 @@ class SingleClassRDR(RDRWithCodeWriter):
390
411
  self.start_rule = SingleClassRule(case_query.conditions, case_query.target, corner_case=case_query.case,
391
412
  conclusion_name=case_query.attribute_name)
392
413
 
393
- def classify(self, case: Case, modify_original_case: bool = False) -> Optional[Any]:
414
+ def classify(self, case: Case, modify_case: bool = False) -> Optional[Any]:
394
415
  """
395
416
  Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
396
417
 
397
418
  :param case: The case to classify.
398
- :param modify_original_case: Whether to modify the original case attributes with the conclusion or not.
419
+ :param modify_case: Whether to modify the original case attributes with the conclusion or not.
399
420
  """
400
421
  pred = self.evaluate(case)
401
422
  return pred.conclusion(case) if pred is not None and pred.fired else self.default_conclusion
@@ -481,7 +502,7 @@ class MultiClassRDR(RDRWithCodeWriter):
481
502
  super(MultiClassRDR, self).__init__(start_rule)
482
503
  self.mode: MCRDRMode = mode
483
504
 
484
- def classify(self, case: Union[Case, SQLTable]) -> Set[Any]:
505
+ def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) -> Set[Any]:
485
506
  evaluated_rule = self.start_rule
486
507
  self.conclusions = []
487
508
  while evaluated_rule:
@@ -492,7 +513,7 @@ class MultiClassRDR(RDRWithCodeWriter):
492
513
  return make_set(self.conclusions)
493
514
 
494
515
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None
495
- , **kwargs) -> Set[Union[CaseAttribute, CallableExpression, None]]:
516
+ , **kwargs) -> Set[Union[CaseAttribute, CallableExpression, None]]:
496
517
  """
497
518
  Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
498
519
  or missing by comparing the case with the target category if provided.
@@ -504,13 +525,13 @@ class MultiClassRDR(RDRWithCodeWriter):
504
525
  self.conclusions = []
505
526
  self.stop_rule_conditions = None
506
527
  evaluated_rule = self.start_rule
507
- target = make_set(case_query.target_value)
528
+ target_value = make_set(case_query.target_value)
508
529
  while evaluated_rule:
509
530
  next_rule = evaluated_rule(case_query.case)
510
531
  rule_conclusion = evaluated_rule.conclusion(case_query.case)
511
532
 
512
533
  if evaluated_rule.fired:
513
- if not make_set(rule_conclusion).issubset(target):
534
+ if not make_set(rule_conclusion).issubset(target_value):
514
535
  # Rule fired and conclusion is different from target
515
536
  self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule)
516
537
  else:
@@ -518,7 +539,7 @@ class MultiClassRDR(RDRWithCodeWriter):
518
539
  self.add_conclusion(evaluated_rule, case_query.case)
519
540
 
520
541
  if not next_rule:
521
- if not make_set(target).issubset(make_set(self.conclusions)):
542
+ if not make_set(target_value).issubset(make_set(self.conclusions)):
522
543
  # Nothing fired and there is a target that should have been in the conclusions
523
544
  self.add_rule_for_case(case_query, expert)
524
545
  # Have to check all rules again to make sure only this new rule fires
@@ -556,12 +577,9 @@ class MultiClassRDR(RDRWithCodeWriter):
556
577
 
557
578
  def _get_imports(self) -> Tuple[str, str]:
558
579
  imports, defs_imports = super()._get_imports()
559
- conclusion_types = [ct for ct in self.conclusion_type if ct not in [list, set]]
560
- if len(conclusion_types) == 1:
561
- imports += f"from typing_extensions import Set\n"
562
- else:
563
- imports += "from typing_extensions import Set, Union\n"
580
+ imports += f"from typing_extensions import Set, Union\n"
564
581
  imports += "from ripple_down_rules.utils import make_set\n"
582
+ defs_imports += "from typing_extensions import Union\n"
565
583
  return imports, defs_imports
566
584
 
567
585
  def update_start_rule(self, case_query: CaseQuery, expert: Expert):
@@ -594,6 +612,8 @@ class MultiClassRDR(RDRWithCodeWriter):
594
612
  rule_conclusion = evaluated_rule.conclusion(case_query.case)
595
613
  if is_conflicting(rule_conclusion, case_query.target_value):
596
614
  self.stop_conclusion(case_query, expert, evaluated_rule)
615
+ else:
616
+ self.add_conclusion(evaluated_rule, case_query.case)
597
617
 
598
618
  def stop_conclusion(self, case_query: CaseQuery,
599
619
  expert: Expert, evaluated_rule: MultiClassTopRule):
@@ -829,6 +849,15 @@ class GeneralRDR(RippleDownRules):
829
849
  start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
830
850
  return cls(start_rules_dict)
831
851
 
852
+ def update_from_python_file(self, package_dir: str) -> None:
853
+ """
854
+ Update the rules from the generated python file, that might have been modified by the user.
855
+
856
+ :param package_dir: The directory of the package that contains the generated python file.
857
+ """
858
+ for rdr in self.start_rules_dict.values():
859
+ rdr.update_from_python_file(package_dir)
860
+
832
861
  def write_to_python_file(self, file_path: str, postfix: str = "") -> None:
833
862
  """
834
863
  Write the tree of rules as source code to a file.
@@ -864,21 +893,10 @@ class GeneralRDR(RippleDownRules):
864
893
  def get_rdr_classifier_from_python_file(self, file_path: str) -> Callable[[Any], Any]:
865
894
  """
866
895
  :param file_path: The path to the file that contains the RDR classifier function.
867
- :param postfix: The postfix to add to the file name.
868
896
  :return: The module that contains the rdr classifier function.
869
897
  """
870
898
  return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
871
899
 
872
- @property
873
- def generated_python_file_name(self) -> str:
874
- if self._generated_python_file_name is None:
875
- self._generated_python_file_name = self._default_generated_python_file_name
876
- return self._generated_python_file_name
877
-
878
- @generated_python_file_name.setter
879
- def generated_python_file_name(self, value: str):
880
- self._generated_python_file_name = value
881
-
882
900
  @property
883
901
  def _default_generated_python_file_name(self) -> str:
884
902
  """
@@ -2,17 +2,16 @@ from __future__ import annotations
2
2
 
3
3
  import re
4
4
  from abc import ABC, abstractmethod
5
- from enum import Enum
5
+ from uuid import uuid4
6
6
 
7
7
  from anytree import NodeMixin
8
+ from sqlalchemy.orm import DeclarativeBase as SQLTable
8
9
  from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
9
10
 
10
11
  from .datastructures.callable_expression import CallableExpression
11
12
  from .datastructures.case import Case
12
- from sqlalchemy.orm import DeclarativeBase as SQLTable
13
13
  from .datastructures.enums import RDREdge, Stop
14
- from .utils import SubclassJSONSerializer, is_iterable, get_full_class_name, conclusion_to_json, \
15
- get_rule_conclusion_as_source_code
14
+ from .utils import SubclassJSONSerializer, conclusion_to_json
16
15
 
17
16
 
18
17
  class Rule(NodeMixin, SubclassJSONSerializer, ABC):
@@ -26,7 +25,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
26
25
  parent: Optional[Rule] = None,
27
26
  corner_case: Optional[Union[Case, SQLTable]] = None,
28
27
  weight: Optional[str] = None,
29
- conclusion_name: Optional[str] = None):
28
+ conclusion_name: Optional[str] = None,
29
+ uid: Optional[str] = None):
30
30
  """
31
31
  A rule in the ripple down rules classifier.
32
32
 
@@ -36,6 +36,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
36
36
  :param corner_case: The corner case that this rule is based on/created from.
37
37
  :param weight: The weight of the rule, which is the type of edge connecting the rule to its parent.
38
38
  :param conclusion_name: The name of the conclusion of the rule.
39
+ :param uid: The unique id of the rule.
39
40
  """
40
41
  super(Rule, self).__init__()
41
42
  self.conclusion = conclusion
@@ -46,6 +47,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
46
47
  self.conclusion_name: Optional[str] = conclusion_name
47
48
  self.json_serialization: Optional[Dict[str, Any]] = None
48
49
  self._name: Optional[str] = None
50
+ # generate a unique id for the rule using uuid4
51
+ self.uid: str = uid if uid else str(uuid4().int)
49
52
 
50
53
  def _post_detach(self, parent):
51
54
  """
@@ -90,15 +93,32 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
90
93
  conclusion = self.conclusion.user_input
91
94
  else:
92
95
  conclusion = self.conclusion.conclusion
93
- conclusion_func, conclusion_func_call = self._conclusion_source_code(conclusion, parent_indent=parent_indent)
96
+ conclusion_func, conclusion_func_call = self.get_conclusion_as_source_code(conclusion,
97
+ parent_indent=parent_indent)
94
98
  if conclusion_func is not None:
95
99
  with open(defs_file, 'a') as f:
96
100
  f.write(conclusion_func.strip() + "\n\n\n")
97
101
  return conclusion_func_call
98
102
 
99
- @abstractmethod
100
- def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
101
- pass
103
+ def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
104
+ """
105
+ Convert the conclusion of a rule to source code.
106
+
107
+ :param conclusion: The conclusion to convert to source code.
108
+ :param parent_indent: The indentation of the parent rule.
109
+ :return: The source code of the conclusion as a tuple of strings, one for the function and one for the call.
110
+ """
111
+ if "def " in conclusion:
112
+ # This means the conclusion is a definition that should be written and then called
113
+ conclusion_lines = conclusion.split('\n')
114
+ # use regex to replace the function name
115
+ new_function_name = f"def conclusion_{self.uid}"
116
+ conclusion_lines[0] = re.sub(r"def (\w+)", new_function_name, conclusion_lines[0])
117
+ func_call = f"{parent_indent} return {new_function_name.replace('def ', '')}(case)\n"
118
+ return "\n".join(conclusion_lines).strip(' '), func_call
119
+ else:
120
+ raise ValueError(f"Conclusion is format is not valid, it should be contain a function definition."
121
+ f" Instead got:\n{conclusion}\n")
102
122
 
103
123
  def write_condition_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
104
124
  """
@@ -116,7 +136,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
116
136
  # This means the conditions are a definition that should be written and then called
117
137
  conditions_lines = self.conditions.user_input.split('\n')
118
138
  # use regex to replace the function name
119
- new_function_name = f"def conditions_{id(self)}"
139
+ new_function_name = f"def conditions_{self.uid}"
120
140
  conditions_lines[0] = re.sub(r"def (\w+)", new_function_name, conditions_lines[0])
121
141
  def_code = "\n".join(conditions_lines)
122
142
  with open(defs_file, 'a') as f:
@@ -133,7 +153,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
133
153
  "parent": self.parent.json_serialization if self.parent else None,
134
154
  "corner_case": SubclassJSONSerializer.to_json_static(self.corner_case),
135
155
  "conclusion_name": self.conclusion_name,
136
- "weight": self.weight}
156
+ "weight": self.weight,
157
+ "uid": self.uid}
137
158
  return json_serialization
138
159
 
139
160
  @classmethod
@@ -143,7 +164,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
143
164
  parent=cls.from_json(data["parent"]),
144
165
  corner_case=Case.from_json(data["corner_case"]),
145
166
  conclusion_name=data["conclusion_name"],
146
- weight=data["weight"])
167
+ weight=data["weight"],
168
+ uid=data["uid"])
147
169
  return loaded_rule
148
170
 
149
171
  @property
@@ -268,14 +290,6 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
268
290
  loaded_rule.alternative = SingleClassRule.from_json(data["alternative"])
269
291
  return loaded_rule
270
292
 
271
- def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
272
- conclusion = str(conclusion)
273
- # indent = parent_indent + " " * 4
274
- # if '\n' not in conclusion:
275
- # return None, f"{indent}return {conclusion}\n"
276
- # else:
277
- return get_rule_conclusion_as_source_code(self, conclusion, parent_indent=parent_indent)
278
-
279
293
  def _if_statement_source_code_clause(self) -> str:
280
294
  return "elif" if self.weight == RDREdge.Alternative.value else "if"
281
295
 
@@ -319,7 +333,7 @@ class MultiClassStopRule(Rule, HasAlternativeRule):
319
333
  loaded_rule.alternative = MultiClassStopRule.from_json(data["alternative"])
320
334
  return loaded_rule
321
335
 
322
- def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[None, str]:
336
+ def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[None, str]:
323
337
  return None, f"{parent_indent}{' ' * 4}pass\n"
324
338
 
325
339
  def _if_statement_source_code_clause(self) -> str:
@@ -364,20 +378,11 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
364
378
  loaded_rule.alternative = MultiClassTopRule.from_json(data["alternative"])
365
379
  return loaded_rule
366
380
 
367
- def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[str, str]:
368
- conclusion_str = str(conclusion)
369
- indent = parent_indent + " " * 4
370
- # if '\n' not in conclusion_str:
371
- # func = None
372
- # if is_iterable(conclusion):
373
- # conclusion_str = "{" + ", ".join([str(c) for c in conclusion]) + "}"
374
- # else:
375
- # conclusion_str = "{" + str(conclusion) + "}"
376
- # else:
377
- func, func_call = get_rule_conclusion_as_source_code(self, conclusion_str, parent_indent=parent_indent)
381
+ def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[str, str]:
382
+ func, func_call = super().get_conclusion_as_source_code(str(conclusion), parent_indent=parent_indent)
378
383
  conclusion_str = func_call.replace("return ", "").strip()
379
384
 
380
- statement = f"{indent}conclusions.update(make_set({conclusion_str}))\n"
385
+ statement = f"{parent_indent} conclusions.update(make_set({conclusion_str}))\n"
381
386
  if self.alternative is None:
382
387
  statement += f"{parent_indent}return conclusions\n"
383
388
  return func, statement
@@ -36,16 +36,145 @@ import ast
36
36
  matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
37
37
 
38
38
 
39
- def encapsulate_user_input(user_input: str, func_signature: str) -> str:
39
+ def are_results_subclass_of_types(result_types: List[Any], types_: List[Type]) -> bool:
40
+ """
41
+ Check if all results are subclasses of the given types.
42
+
43
+ :param result_types: The list of result types to check.
44
+ :param types_: The list of types to check against.
45
+ :return: True if all results are subclasses of the given types, False otherwise.
46
+ """
47
+ for rt in result_types:
48
+ if not any(issubclass(rt, t) for t in types_):
49
+ return False
50
+ return True
51
+
52
+
53
+ def get_imports_from_types(types: List[Type]) -> List[str]:
54
+ """
55
+ Get the import statements for a list of types.
56
+
57
+ :param types: The types to get the import statements for.
58
+ :return: The import statements as a string.
59
+ """
60
+ imports = map(get_import_from_type, types)
61
+ return list({i for i in imports if i is not None})
62
+
63
+
64
+ def get_import_from_type(type_: Type) -> Optional[str]:
65
+ """
66
+ Get the import statement for a given type.
67
+
68
+ :param type_: The type to get the import statement for.
69
+ :return: The import statement as a string.
70
+ """
71
+ if hasattr(type_, "__module__") and hasattr(type_, "__name__"):
72
+ if type_.__module__ == "builtins":
73
+ return
74
+ return f"from {type_.__module__} import {type_.__name__}"
75
+
76
+
77
+ def get_imports_from_scope(scope: Dict[str, Any]) -> List[str]:
78
+ """
79
+ Get the imports from the given scope.
80
+
81
+ :param scope: The scope to get the imports from.
82
+ :return: The imports as a string.
83
+ """
84
+ imports = []
85
+ for k, v in scope.items():
86
+ if not hasattr(v, "__module__") or not hasattr(v, "__name__"):
87
+ continue
88
+ imports.append(f"from {v.__module__} import {v.__name__}")
89
+ return imports
90
+
91
+
92
+ def extract_imports(file_path):
93
+ with open(file_path, "r") as f:
94
+ tree = ast.parse(f.read(), filename=file_path)
95
+
96
+ scope = {}
97
+
98
+ for node in ast.walk(tree):
99
+ if isinstance(node, ast.Import):
100
+ for alias in node.names:
101
+ module_name = alias.name
102
+ asname = alias.asname or alias.name
103
+ try:
104
+ scope[asname] = importlib.import_module(module_name)
105
+ except ImportError as e:
106
+ print(f"Could not import {module_name}: {e}")
107
+ elif isinstance(node, ast.ImportFrom):
108
+ module_name = node.module
109
+ for alias in node.names:
110
+ name = alias.name
111
+ asname = alias.asname or name
112
+ try:
113
+ module = importlib.import_module(module_name)
114
+ scope[asname] = getattr(module, name)
115
+ except (ImportError, AttributeError) as e:
116
+ print(f"Could not import {name} from {module_name}: {e}")
117
+
118
+ return scope
119
+
120
+
121
+ def extract_function_source(file_path: str,
122
+ function_names: List[str], join_lines: bool = True,
123
+ return_line_numbers: bool = False,
124
+ include_signature: bool = True) \
125
+ -> Union[Dict[str, Union[str, List[str]]],
126
+ Tuple[Dict[str, Union[str, List[str]]], List[Tuple[int, int]]]]:
127
+ """
128
+ Extract the source code of a function from a file.
129
+
130
+ :param file_path: The path to the file.
131
+ :param function_names: The names of the functions to extract.
132
+ :param join_lines: Whether to join the lines of the function.
133
+ :param return_line_numbers: Whether to return the line numbers of the function.
134
+ :param include_signature: Whether to include the function signature in the source code.
135
+ :return: A dictionary mapping function names to their source code as a string if join_lines is True,
136
+ otherwise as a list of strings.
137
+ """
138
+ with open(file_path, "r") as f:
139
+ source = f.read()
140
+
141
+ # Parse the source code into an AST
142
+ tree = ast.parse(source)
143
+ function_names = make_list(function_names)
144
+ functions_source: Dict[str, Union[str, List[str]]] = {}
145
+ line_numbers = []
146
+ for node in tree.body:
147
+ if isinstance(node, ast.FunctionDef) and node.name in function_names:
148
+ # Get the line numbers of the function
149
+ lines = source.splitlines()
150
+ func_lines = lines[node.lineno - 1:node.end_lineno]
151
+ if not include_signature:
152
+ func_lines = func_lines[1:]
153
+ line_numbers.append((node.lineno, node.end_lineno))
154
+ functions_source[node.name] = "\n".join(func_lines) if join_lines else func_lines
155
+ if len(functions_source) == len(function_names):
156
+ break
157
+ if len(functions_source) != len(function_names):
158
+ raise ValueError(f"Could not find all functions in {file_path}: {function_names} not found,"
159
+ f"functions not found: {set(function_names) - set(functions_source.keys())}")
160
+ if return_line_numbers:
161
+ return functions_source, line_numbers
162
+ return functions_source
163
+
164
+
165
+ def encapsulate_user_input(user_input: str, func_signature: str, func_doc: Optional[str] = None) -> str:
40
166
  """
41
167
  Encapsulate the user input string with a function definition.
42
168
 
43
169
  :param user_input: The user input string.
44
170
  :param func_signature: The function signature to use for encapsulation.
171
+ :param func_doc: The function docstring to use for encapsulation.
45
172
  :return: The encapsulated user input string.
46
173
  """
47
174
  if func_signature not in user_input:
48
175
  new_user_input = func_signature + "\n "
176
+ if func_doc is not None:
177
+ new_user_input += f"\"\"\"{func_doc}\"\"\"" + "\n "
49
178
  if "return " not in user_input:
50
179
  if '\n' not in user_input:
51
180
  new_user_input += f"return {user_input}"
@@ -173,29 +302,6 @@ def calculate_precision_and_recall(pred_cat: Dict[str, Any], target: Dict[str, A
173
302
  return precision, recall
174
303
 
175
304
 
176
- def get_rule_conclusion_as_source_code(rule: Rule, conclusion: str, parent_indent: str = "") -> Tuple[str, str]:
177
- """
178
- Convert the conclusion of a rule to source code.
179
-
180
- :param rule: The rule to get the conclusion from.
181
- :param conclusion: The conclusion to convert to source code.
182
- :param parent_indent: The indentation to use for the source code.
183
- :return: The source code of the conclusion as a tuple of strings, one for the function and one for the call.
184
- """
185
- indent = f"{parent_indent}{' ' * 4}"
186
- if "def " in conclusion:
187
- # This means the conclusion is a definition that should be written and then called
188
- conclusion_lines = conclusion.split('\n')
189
- # use regex to replace the function name
190
- new_function_name = f"def conclusion_{id(rule)}"
191
- conclusion_lines[0] = re.sub(r"def (\w+)", new_function_name, conclusion_lines[0])
192
- func_call = f"{indent}return {new_function_name.replace('def ', '')}(case)\n"
193
- return "\n".join(conclusion_lines).strip(' '), func_call
194
- else:
195
- raise ValueError(f"Conclusion is format is not valid, it should be a one line string or "
196
- f"contain a function definition. Instead got:\n{conclusion}\n")
197
-
198
-
199
305
  def ask_llm(prompt):
200
306
  try:
201
307
  response = requests.post("http://localhost:11434/api/generate", json={
@@ -317,7 +423,13 @@ def extract_dependencies(code_lines):
317
423
 
318
424
  for stmt in reversed(tree.body[:-1]):
319
425
  if handle_stmt(stmt, needed):
320
- required_lines.insert(0, code_lines[line_map[id(stmt)]])
426
+ # check if the statement is a function, if so then all its lines not just the line in line_map are needed.
427
+ if isinstance(stmt, ast.FunctionDef):
428
+ start_code_line = line_map[id(stmt)]
429
+ end_code_line = start_code_line + stmt.end_lineno
430
+ required_lines.extend(code_lines[start_code_line:end_code_line])
431
+ else:
432
+ required_lines.insert(0, code_lines[line_map[id(stmt)]])
321
433
 
322
434
  required_lines.append(code_lines[-1]) # Always include return
323
435
  return required_lines
@@ -749,9 +861,12 @@ def copy_orm_instance(instance: SQLTable) -> SQLTable:
749
861
  :return: The copied instance.
750
862
  """
751
863
  session: Session = inspect(instance).session
752
- session.expunge(instance)
753
- new_instance = deepcopy(instance)
754
- session.add(instance)
864
+ if session is not None:
865
+ session.expunge(instance)
866
+ new_instance = deepcopy(instance)
867
+ session.add(instance)
868
+ else:
869
+ new_instance = instance
755
870
  return new_instance
756
871
 
757
872
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.1.69
3
+ Version: 0.2.1
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -0,0 +1,20 @@
1
+ ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ ripple_down_rules/datasets.py,sha256=mjJh1GLD_5qMgHaukdDWSGphXS9k_BPEF001ZXPchr8,4687
3
+ ripple_down_rules/experts.py,sha256=JGVvSNiWhm4FpRpg76f98tl8Ii_C7x_aWD9FxD-JDLQ,6130
4
+ ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
+ ripple_down_rules/helpers.py,sha256=TvTJU0BA3dPcAyzvZFvAu7jZqsp8Lu0HAAwvuizlGjg,2018
6
+ ripple_down_rules/prompt.py,sha256=ReXnZ6OraFPqK5kDfAqH8d4SRPYiQ5d4ESk0js_MM9c,16150
7
+ ripple_down_rules/rdr.py,sha256=vxNZckp6sLAUD92JQgfCzhBhg9CXfMZ_7W4VgrIUFjU,43366
8
+ ripple_down_rules/rdr_decorators.py,sha256=8SclpceI3EtrsbuukWJu8HGLh7Q1ZCgYGLX-RPlG-w0,2018
9
+ ripple_down_rules/rules.py,sha256=QQy7NBG6mKiowxVG_LjQJBxLTDW2hMyx5zAgwUxdCMM,17183
10
+ ripple_down_rules/utils.py,sha256=EdVdIf93TAqbxRTzbf_1422FjenRSI4MI_Ecp3e10z8,44007
11
+ ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
12
+ ripple_down_rules/datastructures/callable_expression.py,sha256=qVo8baEI_seeg6V23wmsctYdWj_tJJEOTkHeUc04Wvw,10912
13
+ ripple_down_rules/datastructures/case.py,sha256=nJDKOjyhGLx4gt0sHxJNxBLdy9X2SLcDn89_SmKzwoc,14035
14
+ ripple_down_rules/datastructures/dataclasses.py,sha256=BUr0T0CCh98sdsW2CVAXGk2oWqfemM7w1t91QKvg_KU,6171
15
+ ripple_down_rules/datastructures/enums.py,sha256=hlE6LAa1jUafnG_6UazvaPDfhC1ClI7hKvD89zOyAO8,4661
16
+ ripple_down_rules-0.2.1.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
17
+ ripple_down_rules-0.2.1.dist-info/METADATA,sha256=H-qZ4P2hARM8KV5Dq9duSaWHfdOvUYLRAYYTdnupfwI,42575
18
+ ripple_down_rules-0.2.1.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
19
+ ripple_down_rules-0.2.1.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
20
+ ripple_down_rules-0.2.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.1.0)
2
+ Generator: setuptools (80.3.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,20 +0,0 @@
1
- ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- ripple_down_rules/datasets.py,sha256=mjJh1GLD_5qMgHaukdDWSGphXS9k_BPEF001ZXPchr8,4687
3
- ripple_down_rules/experts.py,sha256=JGVvSNiWhm4FpRpg76f98tl8Ii_C7x_aWD9FxD-JDLQ,6130
4
- ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
- ripple_down_rules/helpers.py,sha256=TvTJU0BA3dPcAyzvZFvAu7jZqsp8Lu0HAAwvuizlGjg,2018
6
- ripple_down_rules/prompt.py,sha256=gdLV2qhq-1kbmlJhVmkCwGGFhIeDxV0p5u2hwrbGUpk,6370
7
- ripple_down_rules/rdr.py,sha256=E8YXSU_VMd3e_LJItjzbBjr0g1KXhhU_HL4FMAgdczc,42203
8
- ripple_down_rules/rdr_decorators.py,sha256=8SclpceI3EtrsbuukWJu8HGLh7Q1ZCgYGLX-RPlG-w0,2018
9
- ripple_down_rules/rules.py,sha256=WfMWzgfI_5Tqv8a2k7jkpGdvwu-zYCG36EmpCl7GiEQ,16576
10
- ripple_down_rules/utils.py,sha256=INGRVIUH-SgssUW2T9gUMRf-XyJ3rxY_kuJUxfbsyOg,39683
11
- ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
12
- ripple_down_rules/datastructures/callable_expression.py,sha256=JLd8ZdIvAGX3mm-tID0buZIlbMF6hW2Z_jn5KA7X_ws,9788
13
- ripple_down_rules/datastructures/case.py,sha256=nJDKOjyhGLx4gt0sHxJNxBLdy9X2SLcDn89_SmKzwoc,14035
14
- ripple_down_rules/datastructures/dataclasses.py,sha256=TAOAeEvh0BeTis3rEHu8rpCeqNNhU0vK3to0JaBwTio,5961
15
- ripple_down_rules/datastructures/enums.py,sha256=RdyPUp9Ls1QuLmkcMMkBbCWrmXIZI4xWuM-cLPYZhR0,4666
16
- ripple_down_rules-0.1.69.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
17
- ripple_down_rules-0.1.69.dist-info/METADATA,sha256=hjuztFZVWpCOPfWULTsRdVCrfeMo8nc2XI4DdW4GJ9g,42576
18
- ripple_down_rules-0.1.69.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
19
- ripple_down_rules-0.1.69.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
20
- ripple_down_rules-0.1.69.dist-info/RECORD,,