ripple-down-rules 0.4.86__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1 +1,5 @@
1
- __version__ = "0.4.86"
1
+ __version__ = "0.5.00"
2
+
3
+ import logging
4
+ logger = logging.Logger("rdr")
5
+ logger.setLevel(logging.INFO)
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import ast
4
4
  import logging
5
+ import os
5
6
  from _ast import AST
6
7
  from enum import Enum
7
8
 
@@ -10,7 +11,7 @@ from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set
10
11
  from .case import create_case, Case
11
12
  from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string, conclusion_to_json, is_iterable, \
12
13
  build_user_input_from_conclusion, encapsulate_user_input, extract_function_source, are_results_subclass_of_types, \
13
- make_list
14
+ make_list, get_imports_from_scope
14
15
 
15
16
 
16
17
  class VariableVisitor(ast.NodeVisitor):
@@ -175,6 +176,24 @@ class CallableExpression(SubclassJSONSerializer):
175
176
  return
176
177
  self.user_input = self.encapsulating_function + '\n' + new_function_body
177
178
 
179
+ def write_to_python_file(self, file_path: str, append: bool = False):
180
+ """
181
+ Write the callable expression to a python file.
182
+
183
+ :param file_path: The path to the file where the callable expression will be written.
184
+ :param append: If True, the callable expression will be appended to the file. If False,
185
+ the file will be overwritten.
186
+ """
187
+ imports = '\n'.join(get_imports_from_scope(self.scope))
188
+ if append and os.path.exists(file_path):
189
+ with open(file_path, 'a') as f:
190
+ f.write('\n\n\n' + imports + '\n\n\n')
191
+ f.write(self.user_input)
192
+ else:
193
+ with open(file_path, 'w') as f:
194
+ f.write(imports + '\n\n\n')
195
+ f.write(self.user_input)
196
+
178
197
  @property
179
198
  def user_input(self):
180
199
  """
@@ -354,11 +354,13 @@ def show_current_and_corner_cases(case: Any, targets: Optional[Dict[str, Any]] =
354
354
  if last_evaluated_rule and last_evaluated_rule.fired:
355
355
  corner_row_dict = copy_case(corner_case)
356
356
 
357
+ case_dict.update(targets)
358
+ case_dict.update(current_conclusions)
359
+ all_table_rows = [case_dict]
357
360
  if corner_row_dict:
358
361
  corner_conclusion = last_evaluated_rule.conclusion(case)
359
362
  corner_row_dict.update({corner_conclusion.__class__.__name__: corner_conclusion})
360
- print(table_rows_as_str(corner_row_dict))
361
- print("=" * 50)
362
- case_dict.update(targets)
363
- case_dict.update(current_conclusions)
364
- print(table_rows_as_str(case_dict))
363
+ all_table_rows.append(corner_row_dict)
364
+ # print(table_rows_as_str(corner_row_dict))
365
+ print("\n" + "=" * 50)
366
+ print(table_rows_as_str(all_table_rows))
@@ -78,7 +78,15 @@ class CaseQuery:
78
78
  """
79
79
  :return: The type of the case that the attribute belongs to.
80
80
  """
81
- return self.original_case._obj_type if isinstance(self.original_case, Case) else type(self.original_case)
81
+ if self.is_function:
82
+ if self.function_args_type_hints is not None:
83
+ func_args = [arg for name, arg in self.function_args_type_hints.items() if name != 'return']
84
+ case_type_args = Union[tuple(func_args)]
85
+ else:
86
+ case_type_args = Any
87
+ return Dict[str, case_type_args]
88
+ else:
89
+ return self.original_case._obj_type if isinstance(self.original_case, Case) else type(self.original_case)
82
90
 
83
91
  @property
84
92
  def case(self) -> Any:
@@ -1,7 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import ast
3
4
  import json
4
5
  import logging
6
+ import os
5
7
  from abc import ABC, abstractmethod
6
8
 
7
9
  from typing_extensions import Optional, TYPE_CHECKING, List
@@ -10,6 +12,8 @@ from .datastructures.callable_expression import CallableExpression
10
12
  from .datastructures.enums import PromptFor
11
13
  from .datastructures.dataclasses import CaseQuery
12
14
  from .datastructures.case import show_current_and_corner_cases
15
+ from .utils import extract_imports, extract_function_source, get_imports_from_scope, encapsulate_user_input
16
+
13
17
  try:
14
18
  from .user_interface.gui import RDRCaseViewer
15
19
  except ImportError as e:
@@ -36,10 +40,19 @@ class Expert(ABC):
36
40
  A flag to indicate if the expert should use loaded answers or not.
37
41
  """
38
42
 
39
- def __init__(self, use_loaded_answers: bool = False, append: bool = False):
43
+ def __init__(self, use_loaded_answers: bool = True,
44
+ append: bool = False,
45
+ answers_save_path: Optional[str] = None):
40
46
  self.all_expert_answers = []
41
47
  self.use_loaded_answers = use_loaded_answers
42
48
  self.append = append
49
+ self.answers_save_path = answers_save_path
50
+ if answers_save_path is not None:
51
+ if use_loaded_answers:
52
+ self.load_answers(answers_save_path)
53
+ else:
54
+ os.remove(answers_save_path + '.py')
55
+ self.append = True
43
56
 
44
57
  @abstractmethod
45
58
  def ask_for_conditions(self, case_query: CaseQuery, last_evaluated_rule: Optional[Rule] = None) \
@@ -63,46 +76,138 @@ class Expert(ABC):
63
76
  :return: A callable expression that can be called with a new case as an argument.
64
77
  """
65
78
 
79
+ def clear_answers(self, path: Optional[str] = None):
80
+ """
81
+ Clear the expert answers.
66
82
 
67
- class Human(Expert):
68
- """
69
- The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
70
- """
83
+ :param path: The path to clear the answers from. If None, the answers will be cleared from the
84
+ answers_save_path attribute.
85
+ """
86
+ if path is None and self.answers_save_path is None:
87
+ raise ValueError("No path provided to clear expert answers, either provide a path or set the "
88
+ "answers_save_path attribute.")
89
+ if path is None:
90
+ path = self.answers_save_path
91
+ if os.path.exists(path + '.json'):
92
+ os.remove(path + '.json')
93
+ if os.path.exists(path + '.py'):
94
+ os.remove(path + '.py')
95
+ self.all_expert_answers = []
71
96
 
72
- def __init__(self, use_loaded_answers: bool = False, append: bool = False, viewer: Optional[RDRCaseViewer] = None):
97
+ def save_answers(self, path: Optional[str] = None):
73
98
  """
74
- Initialize the Human expert.
99
+ Save the expert answers to a file.
75
100
 
76
- :param viewer: The RDRCaseViewer instance to use for prompting the user.
101
+ :param path: The path to save the answers to.
77
102
  """
78
- super().__init__(use_loaded_answers=use_loaded_answers, append=append)
79
- self.user_prompt = UserPrompt(viewer)
103
+ if path is None and self.answers_save_path is None:
104
+ raise ValueError("No path provided to save expert answers, either provide a path or set the "
105
+ "answers_save_path attribute.")
106
+ if path is None:
107
+ path = self.answers_save_path
108
+ is_json = os.path.exists(path + '.json')
109
+ if is_json:
110
+ self._save_to_json(path)
111
+ else:
112
+ self._save_to_python(path)
80
113
 
81
- def save_answers(self, path: str):
114
+ def _save_to_json(self, path: str):
82
115
  """
83
- Save the expert answers to a file.
116
+ Save the expert answers to a JSON file.
84
117
 
85
118
  :param path: The path to save the answers to.
86
119
  """
87
- if self.append:
120
+ all_answers = self.all_expert_answers
121
+ if self.append and os.path.exists(path + '.json'):
88
122
  # read the file and append the new answers
89
123
  with open(path + '.json', "r") as f:
90
- all_answers = json.load(f)
91
- all_answers.extend(self.all_expert_answers)
92
- with open(path + '.json', "w") as f:
93
- json.dump(all_answers, f)
94
- else:
95
- with open(path + '.json', "w") as f:
96
- json.dump(self.all_expert_answers, f)
124
+ old_answers = json.load(f)
125
+ all_answers = old_answers + all_answers
126
+ with open(path + '.json', "w") as f:
127
+ json.dump(all_answers, f)
97
128
 
98
- def load_answers(self, path: str):
129
+ def _save_to_python(self, path: str):
130
+ """
131
+ Save the expert answers to a Python file.
132
+
133
+ :param path: The path to save the answers to.
134
+ """
135
+ dir_name = os.path.dirname(path)
136
+ if not os.path.exists(dir_name + '/__init__.py'):
137
+ os.makedirs(dir_name, exist_ok=True)
138
+ with open(dir_name + '/__init__.py', 'w') as f:
139
+ f.write('# This is an empty init file to make the directory a package.\n')
140
+ action = 'w' if not self.append else 'a'
141
+ with open(path + '.py', action) as f:
142
+ for scope, func_source in self.all_expert_answers:
143
+ if len(scope) > 0:
144
+ imports = '\n'.join(get_imports_from_scope(scope)) + '\n\n\n'
145
+ else:
146
+ imports = ''
147
+ if func_source is not None:
148
+ func_source = encapsulate_user_input(func_source, CallableExpression.encapsulating_function)
149
+ else:
150
+ func_source = 'pass # No user input provided for this case.\n'
151
+ f.write(imports + func_source + '\n' + '\n\n\n\'===New Answer===\'\n\n\n')
152
+
153
+ def load_answers(self, path: Optional[str] = None):
99
154
  """
100
155
  Load the expert answers from a file.
101
156
 
157
+ :param path: The path to load the answers from.
158
+ """
159
+ if path is None and self.answers_save_path is None:
160
+ raise ValueError("No path provided to load expert answers from, either provide a path or set the "
161
+ "answers_save_path attribute.")
162
+ if path is None:
163
+ path = self.answers_save_path
164
+ is_json = os.path.exists(path + '.json')
165
+ if is_json:
166
+ self._load_answers_from_json(path)
167
+ elif os.path.exists(path + '.py'):
168
+ self._load_answers_from_python(path)
169
+
170
+ def _load_answers_from_json(self, path: str):
171
+ """
172
+ Load the expert answers from a JSON file.
173
+
102
174
  :param path: The path to load the answers from.
103
175
  """
104
176
  with open(path + '.json', "r") as f:
105
- self.all_expert_answers = json.load(f)
177
+ all_answers = json.load(f)
178
+ self.all_expert_answers = [({}, answer) for answer in all_answers]
179
+
180
+ def _load_answers_from_python(self, path: str):
181
+ """
182
+ Load the expert answers from a Python file.
183
+
184
+ :param path: The path to load the answers from.
185
+ """
186
+ file_path = path + '.py'
187
+ with open(file_path, "r") as f:
188
+ all_answers = f.read().split('\n\n\n\'===New Answer===\'\n\n\n')
189
+ for answer in all_answers:
190
+ answer = answer.strip('\n').strip()
191
+ if 'def ' not in answer and 'pass' in answer:
192
+ self.all_expert_answers.append(({}, None))
193
+ scope = extract_imports(tree=ast.parse(answer))
194
+ func_source = list(extract_function_source(file_path, []).values())[0]
195
+ self.all_expert_answers.append((scope, func_source))
196
+
197
+
198
+ class Human(Expert):
199
+ """
200
+ The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
201
+ """
202
+
203
+ def __init__(self, viewer: Optional[RDRCaseViewer] = None, **kwargs):
204
+ """
205
+ Initialize the Human expert.
206
+
207
+ :param viewer: The RDRCaseViewer instance to use for prompting the user.
208
+ """
209
+ super().__init__(**kwargs)
210
+ self.user_prompt = UserPrompt(viewer)
106
211
 
107
212
  def ask_for_conditions(self, case_query: CaseQuery,
108
213
  last_evaluated_rule: Optional[Rule] = None) \
@@ -125,13 +230,18 @@ class Human(Expert):
125
230
  if self.use_loaded_answers and len(self.all_expert_answers) == 0 and self.append:
126
231
  self.use_loaded_answers = False
127
232
  if self.use_loaded_answers:
128
- user_input = self.all_expert_answers.pop(0)
129
- if user_input:
233
+ try:
234
+ loaded_scope, user_input = self.all_expert_answers.pop(0)
235
+ except IndexError:
236
+ self.use_loaded_answers = False
237
+ if user_input is not None:
130
238
  condition = CallableExpression(user_input, bool, scope=case_query.scope)
131
239
  else:
132
240
  user_input, condition = self.user_prompt.prompt_user_for_expression(case_query, PromptFor.Conditions)
133
241
  if not self.use_loaded_answers:
134
- self.all_expert_answers.append(user_input)
242
+ self.all_expert_answers.append((condition.scope, user_input))
243
+ if self.answers_save_path is not None:
244
+ self.save_answers()
135
245
  case_query.conditions = condition
136
246
  return condition
137
247
 
@@ -143,18 +253,65 @@ class Human(Expert):
143
253
  :return: The conclusion for the case as a callable expression.
144
254
  """
145
255
  expression: Optional[CallableExpression] = None
256
+ expert_input: Optional[str] = None
146
257
  if self.use_loaded_answers and len(self.all_expert_answers) == 0 and self.append:
147
258
  self.use_loaded_answers = False
148
259
  if self.use_loaded_answers:
149
- expert_input = self.all_expert_answers.pop(0)
150
- if expert_input is not None:
151
- expression = CallableExpression(expert_input, case_query.attribute_type,
152
- scope=case_query.scope,
153
- mutually_exclusive=case_query.mutually_exclusive)
154
- else:
260
+ try:
261
+ loaded_scope, expert_input = self.all_expert_answers.pop(0)
262
+ if expert_input is not None:
263
+ expression = CallableExpression(expert_input, case_query.attribute_type,
264
+ scope=case_query.scope,
265
+ mutually_exclusive=case_query.mutually_exclusive)
266
+ except IndexError:
267
+ self.use_loaded_answers = False
268
+ if not self.use_loaded_answers:
155
269
  if self.user_prompt.viewer is None:
156
270
  show_current_and_corner_cases(case_query.case)
157
271
  expert_input, expression = self.user_prompt.prompt_user_for_expression(case_query, PromptFor.Conclusion)
158
- self.all_expert_answers.append(expert_input)
272
+ if expression is None:
273
+ self.all_expert_answers.append(({}, None))
274
+ else:
275
+ self.all_expert_answers.append((expression.scope, expert_input))
276
+ if self.answers_save_path is not None:
277
+ self.save_answers()
159
278
  case_query.target = expression
160
279
  return expression
280
+
281
+
282
+ class File(Expert):
283
+ """
284
+ The File Expert class, an expert that reads the answers from a file.
285
+ This is used for testing purposes.
286
+ """
287
+
288
+ def __init__(self, filename: str, **kwargs):
289
+ """
290
+ Initialize the File expert.
291
+
292
+ :param filename: The path to the file containing the expert answers.
293
+ """
294
+ super().__init__(**kwargs)
295
+ self.filename = filename
296
+ self.load_answers(filename)
297
+
298
+ def ask_for_conditions(self, case_query: CaseQuery,
299
+ last_evaluated_rule: Optional[Rule] = None) -> CallableExpression:
300
+ loaded_scope, user_input = self.all_expert_answers.pop(0)
301
+ if user_input:
302
+ condition = CallableExpression(user_input, bool, scope=case_query.scope)
303
+ else:
304
+ raise ValueError("No user input found in the expert answers file.")
305
+ case_query.conditions = condition
306
+ return condition
307
+
308
+ def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
309
+ loaded_scope, expert_input = self.all_expert_answers.pop(0)
310
+ if expert_input is not None:
311
+ expression = CallableExpression(expert_input, case_query.attribute_type,
312
+ scope=case_query.scope,
313
+ mutually_exclusive=case_query.mutually_exclusive)
314
+ else:
315
+ raise ValueError("No expert input found in the expert answers file.")
316
+ case_query.target = expression
317
+ return expression
ripple_down_rules/rdr.py CHANGED
@@ -2,7 +2,9 @@ from __future__ import annotations
2
2
 
3
3
  import copyreg
4
4
  import importlib
5
- import logging
5
+ import os
6
+
7
+ from . import logger
6
8
  import sys
7
9
  from abc import ABC, abstractmethod
8
10
  from copy import copy
@@ -13,7 +15,7 @@ try:
13
15
  from matplotlib import pyplot as plt
14
16
  Figure = plt.Figure
15
17
  except ImportError as e:
16
- logging.debug(f"{e}: matplotlib is not installed")
18
+ logger.debug(f"{e}: matplotlib is not installed")
17
19
  matplotlib = None
18
20
  Figure = None
19
21
  plt = None
@@ -34,7 +36,8 @@ except ImportError as e:
34
36
  RDRCaseViewer = None
35
37
  from .utils import draw_tree, make_set, copy_case, \
36
38
  SubclassJSONSerializer, make_list, get_type_from_string, \
37
- is_conflicting, update_case, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name
39
+ is_conflicting, update_case, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
40
+ is_iterable, str_to_snake_case
38
41
 
39
42
 
40
43
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -61,17 +64,90 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
61
64
  """
62
65
  The type of the case (input) to the RDR classifier.
63
66
  """
67
+ case_name: Optional[str] = None
68
+ """
69
+ The name of the case type.
70
+ """
71
+ metadata_folder: str = "rdr_metadata"
72
+ """
73
+ The folder to save the metadata of the RDR classifier.
74
+ """
75
+ model_name: Optional[str] = None
76
+ """
77
+ The name of the model. If None, the model name will be the generated python file name.
78
+ """
64
79
 
65
- def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None):
80
+ def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None,
81
+ save_dir: Optional[str] = None, ask_always: bool = True, model_name: Optional[str] = None):
66
82
  """
67
83
  :param start_rule: The starting rule for the classifier.
84
+ :param viewer: The viewer gui to use for the classifier. If None, no viewer is used.
85
+ :param save_dir: The directory to save the classifier to.
86
+ :param ask_always: Whether to always ask the expert (True) or only ask when classification fails (False).
68
87
  """
88
+ self.ask_always: bool = ask_always
89
+ self.model_name: Optional[str] = model_name
90
+ self.save_dir = save_dir
69
91
  self.start_rule = start_rule
70
92
  self.fig: Optional[Figure] = None
71
93
  self.viewer: Optional[RDRCaseViewer] = viewer
72
94
  if self.viewer is not None:
73
95
  self.viewer.set_save_function(self.save)
74
96
 
97
+ def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None) -> str:
98
+ """
99
+ Save the classifier to a file.
100
+
101
+ :param save_dir: The directory to save the classifier to.
102
+ :param model_name: The name of the model to save. If None, a default name is generated.
103
+ :param postfix: The postfix to add to the file name.
104
+ :return: The name of the saved model.
105
+ """
106
+ save_dir = save_dir or self.save_dir
107
+ if save_dir is None:
108
+ raise ValueError("The save directory cannot be None. Please provide a valid directory to save"
109
+ " the classifier.")
110
+ if not os.path.exists(save_dir + '/__init__.py'):
111
+ os.makedirs(save_dir, exist_ok=True)
112
+ with open(save_dir + '/__init__.py', 'w') as f:
113
+ f.write("# This is an empty __init__.py file to make the directory a package.\n")
114
+ if model_name is not None:
115
+ self.model_name = model_name
116
+ elif self.model_name is None:
117
+ self.model_name = self.generated_python_file_name
118
+ model_dir = os.path.join(save_dir, self.model_name)
119
+ os.makedirs(model_dir, exist_ok=True)
120
+ json_dir = os.path.join(model_dir, self.metadata_folder)
121
+ os.makedirs(json_dir, exist_ok=True)
122
+ self.to_json_file(os.path.join(json_dir, self.model_name))
123
+ self._write_to_python(model_dir)
124
+ return self.model_name
125
+
126
+ @classmethod
127
+ def load(cls, load_dir: str, model_name: str) -> Self:
128
+ """
129
+ Load the classifier from a file.
130
+
131
+ :param load_dir: The path to the model directory to load the classifier from.
132
+ :param model_name: The name of the model to load.
133
+ """
134
+ model_dir = os.path.join(load_dir, model_name)
135
+ json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
136
+ rdr = cls.from_json_file(json_file)
137
+ rdr.update_from_python(model_dir)
138
+ rdr.save_dir = load_dir
139
+ rdr.model_name = model_name
140
+ return rdr
141
+
142
+ @abstractmethod
143
+ def _write_to_python(self, model_dir: str):
144
+ """
145
+ Write the tree of rules as source code to a file.
146
+
147
+ :param model_dir: The path to the directory to write the source code to.
148
+ """
149
+ pass
150
+
75
151
  def set_viewer(self, viewer: RDRCaseViewer):
76
152
  """
77
153
  Set the viewer for the classifier.
@@ -122,13 +198,13 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
122
198
  all_predictions = [1 if is_matching(self.classify, case_query) else 0 for case_query in case_queries
123
199
  if case_query.target is not None]
124
200
  all_pred = sum(all_predictions)
125
- print(f"Accuracy: {all_pred}/{len(targets)}")
201
+ logger.info(f"Accuracy: {all_pred}/{len(targets)}")
126
202
  all_predicted = targets and all_pred == len(targets)
127
203
  num_iter_reached = n_iter and i >= n_iter
128
204
  stop_iterating = all_predicted or num_iter_reached
129
205
  if stop_iterating:
130
206
  break
131
- print(f"Finished training in {i} iterations")
207
+ logger.info(f"Finished training in {i} iterations")
132
208
  if animate_tree:
133
209
  plt.ioff()
134
210
  plt.show()
@@ -160,19 +236,29 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
160
236
  """
161
237
  if case_query is None:
162
238
  raise ValueError("The case query cannot be None.")
239
+
163
240
  self.name = case_query.attribute_name if self.name is None else self.name
164
241
  self.case_type = case_query.case_type if self.case_type is None else self.case_type
242
+ self.case_name = case_query.case_name if self.case_name is None else self.case_name
243
+
165
244
  if case_query.target is None:
166
245
  case_query_cp = copy(case_query)
167
- self.classify(case_query_cp.case, modify_case=True)
168
- expert.ask_for_conclusion(case_query_cp)
169
- case_query.target = case_query_cp.target
246
+ conclusions = self.classify(case_query_cp.case, modify_case=True)
247
+ if self.ask_always or conclusions is None or is_iterable(conclusions) and len(conclusions) == 0:
248
+ expert.ask_for_conclusion(case_query_cp)
249
+ case_query.target = case_query_cp.target
170
250
  if case_query.target is None:
171
251
  return self.classify(case_query.case)
172
252
 
173
253
  self.update_start_rule(case_query, expert)
174
254
 
175
- return self._fit_case(case_query, expert=expert, **kwargs)
255
+ fit_case_result = self._fit_case(case_query, expert=expert, **kwargs)
256
+
257
+ if self.save_dir is not None:
258
+ self.save()
259
+ expert.clear_answers()
260
+
261
+ return fit_case_result
176
262
 
177
263
  @abstractmethod
178
264
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
@@ -238,28 +324,54 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
238
324
  pass
239
325
 
240
326
  @abstractmethod
241
- def update_from_python_file(self, package_dir: str):
327
+ def update_from_python(self, model_dir: str):
242
328
  """
243
329
  Update the rules from the generated python file, that might have been modified by the user.
244
330
 
245
- :param package_dir: The directory of the package that contains the generated python file.
331
+ :param model_dir: The directory where the generated python file is located.
246
332
  """
247
333
  pass
248
334
 
335
+ @classmethod
336
+ def get_acronym(cls) -> str:
337
+ """
338
+ :return: The acronym of the classifier.
339
+ """
340
+ if cls.__name__ == "GeneralRDR":
341
+ return "RDR"
342
+ elif cls.__name__ == "MultiClassRDR":
343
+ return "MCRDR"
344
+ else:
345
+ return "SCRDR"
346
+
347
+ def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
348
+ """
349
+ :param package_name: The name of the package that contains the RDR classifier function.
350
+ :return: The module that contains the rdr classifier function.
351
+ """
352
+ # remove from imports if exists first
353
+ name = f"{package_name.strip('./').replace('/', '.')}.{self.generated_python_file_name}"
354
+ try:
355
+ module = importlib.import_module(name)
356
+ del sys.modules[name]
357
+ except ModuleNotFoundError:
358
+ pass
359
+ return importlib.import_module(name).classify
360
+
249
361
 
250
362
  class RDRWithCodeWriter(RippleDownRules, ABC):
251
363
 
252
- def update_from_python_file(self, package_dir: str):
364
+ def update_from_python(self, model_dir: str):
253
365
  """
254
366
  Update the rules from the generated python file, that might have been modified by the user.
255
367
 
256
- :param package_dir: The directory of the package that contains the generated python file.
368
+ :param model_dir: The directory where the generated python file is located.
257
369
  """
258
- rule_ids = [r.uid for r in [self.start_rule] + list(self.start_rule.descendants) if r.conditions is not None]
259
- condition_func_names = [f'conditions_{rid}' for rid in rule_ids]
260
- conclusion_func_names = [f'conclusion_{rid}' for rid in rule_ids]
370
+ rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants) if r.conditions is not None}
371
+ condition_func_names = [f'conditions_{rid}' for rid in rules_dict.keys()]
372
+ conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys() if not isinstance(rules_dict[rid], MultiClassStopRule)]
261
373
  all_func_names = condition_func_names + conclusion_func_names
262
- filepath = f"{package_dir}/{self.generated_python_defs_file_name}.py"
374
+ filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
263
375
  functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
264
376
  # get the scope from the imports in the file
265
377
  scope = extract_imports(filepath)
@@ -267,7 +379,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
267
379
  if rule.conditions is not None:
268
380
  rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
269
381
  rule.conditions.scope = scope
270
- if rule.conclusion is not None:
382
+ if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
271
383
  rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
272
384
  rule.conclusion.scope = scope
273
385
 
@@ -284,17 +396,19 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
284
396
  """
285
397
  pass
286
398
 
287
- def write_to_python_file(self, file_path: str, postfix: str = ""):
399
+ def _write_to_python(self, model_dir: str):
288
400
  """
289
401
  Write the tree of rules as source code to a file.
290
402
 
291
- :param file_path: The path to the file to write the source code to.
292
- :param postfix: The postfix to add to the file name.
403
+ :param model_dir: The path to the directory to write the source code to.
293
404
  """
294
- self.generated_python_file_name = self._default_generated_python_file_name + postfix
405
+ os.makedirs(model_dir, exist_ok=True)
406
+ if not os.path.exists(model_dir + '/__init__.py'):
407
+ with open(model_dir + '/__init__.py', 'w') as f:
408
+ f.write("# This is an empty __init__.py file to make the directory a package.\n")
295
409
  func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
296
- file_name = file_path + f"/{self.generated_python_file_name}.py"
297
- defs_file_name = file_path + f"/{self.generated_python_defs_file_name}.py"
410
+ file_name = model_dir + f"/{self.generated_python_file_name}.py"
411
+ defs_file_name = model_dir + f"/{self.generated_python_defs_file_name}.py"
298
412
  imports, defs_imports = self._get_imports()
299
413
  # clear the files first
300
414
  with open(defs_file_name, "w") as f:
@@ -345,20 +459,6 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
345
459
  imports = "\n".join(imports) + "\n"
346
460
  return imports, defs_imports
347
461
 
348
- def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
349
- """
350
- :param package_name: The name of the package that contains the RDR classifier function.
351
- :return: The module that contains the rdr classifier function.
352
- """
353
- # remove from imports if exists first
354
- name = f"{package_name.strip('./')}.{self.generated_python_file_name}"
355
- try:
356
- module = importlib.import_module(name)
357
- del sys.modules[name]
358
- except ModuleNotFoundError:
359
- pass
360
- return importlib.import_module(name).classify
361
-
362
462
  @property
363
463
  def _default_generated_python_file_name(self) -> Optional[str]:
364
464
  """
@@ -366,23 +466,12 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
366
466
  """
367
467
  if self.start_rule is None or self.start_rule.conclusion is None:
368
468
  return None
369
- return f"{self.case_type.__name__.lower()}_{self.attribute_name}_{self.acronym.lower()}"
469
+ return f"{str_to_snake_case(self.case_name)}_{self.attribute_name}_{self.get_acronym().lower()}"
370
470
 
371
471
  @property
372
472
  def generated_python_defs_file_name(self) -> str:
373
473
  return f"{self.generated_python_file_name}_defs"
374
474
 
375
- @property
376
- def acronym(self) -> str:
377
- """
378
- :return: The acronym of the classifier.
379
- """
380
- if self.__class__.__name__ == "GeneralRDR":
381
- return "GRDR"
382
- elif self.__class__.__name__ == "MultiClassRDR":
383
- return "MCRDR"
384
- else:
385
- return "SCRDR"
386
475
 
387
476
  @property
388
477
  def conclusion_type(self) -> Tuple[Type]:
@@ -403,7 +492,9 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
403
492
 
404
493
  def _to_json(self) -> Dict[str, Any]:
405
494
  return {"start_rule": self.start_rule.to_json(), "generated_python_file_name": self.generated_python_file_name,
406
- "name": self.name, "case_type": get_full_class_name(self.case_type) if self.case_type is not None else None}
495
+ "name": self.name,
496
+ "case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
497
+ "case_name": self.case_name}
407
498
 
408
499
  @classmethod
409
500
  def _from_json(cls, data: Dict[str, Any]) -> Self:
@@ -411,13 +502,15 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
411
502
  Create an instance of the class from a json
412
503
  """
413
504
  start_rule = cls.start_rule_type().from_json(data["start_rule"])
414
- new_rdr = cls(start_rule)
505
+ new_rdr = cls(start_rule=start_rule)
415
506
  if "generated_python_file_name" in data:
416
507
  new_rdr.generated_python_file_name = data["generated_python_file_name"]
417
508
  if "name" in data:
418
509
  new_rdr.name = data["name"]
419
510
  if "case_type" in data:
420
511
  new_rdr.case_type = get_type_from_string(data["case_type"])
512
+ if "case_name" in data:
513
+ new_rdr.case_name = data["case_name"]
421
514
  return new_rdr
422
515
 
423
516
  @staticmethod
@@ -431,12 +524,12 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
431
524
 
432
525
  class SingleClassRDR(RDRWithCodeWriter):
433
526
 
434
- def __init__(self, start_rule: Optional[SingleClassRule] = None, default_conclusion: Optional[Any] = None):
527
+ def __init__(self, default_conclusion: Optional[Any] = None, **kwargs):
435
528
  """
436
529
  :param start_rule: The starting rule for the classifier.
437
530
  :param default_conclusion: The default conclusion for the classifier if no rules fire.
438
531
  """
439
- super(SingleClassRDR, self).__init__(start_rule)
532
+ super(SingleClassRDR, self).__init__(**kwargs)
440
533
  self.default_conclusion: Optional[Any] = default_conclusion
441
534
 
442
535
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
@@ -488,10 +581,10 @@ class SingleClassRDR(RDRWithCodeWriter):
488
581
  matched_rule = self.start_rule(case) if self.start_rule is not None else None
489
582
  return matched_rule if matched_rule is not None else self.start_rule
490
583
 
491
- def write_to_python_file(self, file_path: str, postfix: str = ""):
492
- super().write_to_python_file(file_path, postfix)
584
+ def _write_to_python(self, model_dir: str):
585
+ super()._write_to_python(model_dir)
493
586
  if self.default_conclusion is not None:
494
- with open(file_path + f"/{self.generated_python_file_name}.py", "a") as f:
587
+ with open(model_dir + f"/{self.generated_python_file_name}.py", "a") as f:
495
588
  f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
496
589
 
497
590
  def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
@@ -892,7 +985,8 @@ class GeneralRDR(RippleDownRules):
892
985
  return {"start_rules": {name: rdr.to_json() for name, rdr in self.start_rules_dict.items()}
893
986
  , "generated_python_file_name": self.generated_python_file_name,
894
987
  "name": self.name,
895
- "case_type": get_full_class_name(self.case_type) if self.case_type is not None else None}
988
+ "case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
989
+ "case_name": self.case_name}
896
990
 
897
991
  @classmethod
898
992
  def _from_json(cls, data: Dict[str, Any]) -> GeneralRDR:
@@ -902,37 +996,37 @@ class GeneralRDR(RippleDownRules):
902
996
  start_rules_dict = {}
903
997
  for k, v in data["start_rules"].items():
904
998
  start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
905
- new_rdr = cls(start_rules_dict)
999
+ new_rdr = cls(category_rdr_map=start_rules_dict)
906
1000
  if "generated_python_file_name" in data:
907
1001
  new_rdr.generated_python_file_name = data["generated_python_file_name"]
908
1002
  if "name" in data:
909
1003
  new_rdr.name = data["name"]
910
1004
  if "case_type" in data:
911
1005
  new_rdr.case_type = get_type_from_string(data["case_type"])
1006
+ if "case_name" in data:
1007
+ new_rdr.case_name = data["case_name"]
912
1008
  return new_rdr
913
1009
 
914
- def update_from_python_file(self, package_dir: str) -> None:
1010
+ def update_from_python(self, model_dir: str) -> None:
915
1011
  """
916
1012
  Update the rules from the generated python file, that might have been modified by the user.
917
1013
 
918
- :param package_dir: The directory of the package that contains the generated python file.
1014
+ :param model_dir: The directory where the model is stored.
919
1015
  """
920
1016
  for rdr in self.start_rules_dict.values():
921
- rdr.update_from_python_file(package_dir)
1017
+ rdr.update_from_python(model_dir)
922
1018
 
923
- def write_to_python_file(self, file_path: str, postfix: str = "") -> None:
1019
+ def _write_to_python(self, model_dir: str) -> None:
924
1020
  """
925
1021
  Write the tree of rules as source code to a file.
926
1022
 
927
- :param file_path: The path to the file to write the source code to.
928
- :param postfix: The postfix to add to the file name.
1023
+ :param model_dir: The directory where the model is stored.
929
1024
  """
930
- self.generated_python_file_name = self._default_generated_python_file_name + postfix
931
1025
  for rdr in self.start_rules_dict.values():
932
- rdr.write_to_python_file(file_path, postfix=f"_of_grdr{postfix}")
1026
+ rdr._write_to_python(model_dir)
933
1027
  func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
934
- with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
935
- f.write(self._get_imports(file_path) + "\n\n")
1028
+ with open(model_dir + f"/{self.generated_python_file_name}.py", "w") as f:
1029
+ f.write(self._get_imports() + "\n\n")
936
1030
  f.write("classifiers_dict = dict()\n")
937
1031
  for rdr_key, rdr in self.start_rules_dict.items():
938
1032
  f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
@@ -942,13 +1036,6 @@ class GeneralRDR(RippleDownRules):
942
1036
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
943
1037
  f.write(f"{' ' * 4}return GeneralRDR._classify(classifiers_dict, case)\n")
944
1038
 
945
- def get_rdr_classifier_from_python_file(self, file_path: str) -> Callable[[Any], Any]:
946
- """
947
- :param file_path: The path to the file that contains the RDR classifier function.
948
- :return: The module that contains the rdr classifier function.
949
- """
950
- return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
951
-
952
1039
  @property
953
1040
  def _default_generated_python_file_name(self) -> Optional[str]:
954
1041
  """
@@ -956,17 +1043,16 @@ class GeneralRDR(RippleDownRules):
956
1043
  """
957
1044
  if self.start_rule is None or self.start_rule.conclusion is None:
958
1045
  return None
959
- return f"{self.case_type.__name__.lower()}_rdr".lower()
1046
+ return f"{str_to_snake_case(self.case_name)}_rdr".lower()
960
1047
 
961
1048
  @property
962
1049
  def conclusion_type_hint(self) -> str:
963
1050
  return f"Dict[str, Any]"
964
1051
 
965
- def _get_imports(self, file_path: str) -> str:
1052
+ def _get_imports(self) -> str:
966
1053
  """
967
1054
  Get the imports needed for the generated python file.
968
1055
 
969
- :param file_path: The path to the file that contains the RDR classifier function.
970
1056
  :return: The imports needed for the generated python file.
971
1057
  """
972
1058
  imports = ""
@@ -5,15 +5,17 @@ of the RDRs.
5
5
  """
6
6
  import os.path
7
7
  from functools import wraps
8
+
9
+ from pyparsing.tools.cvt_pyparsing_pep8_names import camel_to_snake
8
10
  from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union
9
11
 
10
- from ripple_down_rules.datastructures.case import create_case
12
+ from ripple_down_rules.datastructures.case import create_case, Case
11
13
  from ripple_down_rules.datastructures.dataclasses import CaseQuery
12
14
  from ripple_down_rules.datastructures.enums import Category
13
15
  from ripple_down_rules.experts import Expert, Human
14
16
  from ripple_down_rules.rdr import GeneralRDR, RippleDownRules
15
17
  from ripple_down_rules.utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
16
- get_method_class_if_exists
18
+ get_method_class_if_exists, get_method_name, str_to_snake_case
17
19
 
18
20
 
19
21
  class RDRDecorator:
@@ -41,99 +43,118 @@ class RDRDecorator:
41
43
  :return: A decorator to use a GeneralRDR as a classifier that monitors and modifies the function's output.
42
44
  """
43
45
  self.rdr_models_dir = models_dir
46
+ self.model_name: Optional[str] = None
44
47
  self.output_type = output_type
45
48
  self.parsed_output_type: List[Type] = []
46
49
  self.mutual_exclusive = mutual_exclusive
47
50
  self.rdr_python_path: Optional[str] = python_dir
48
51
  self.output_name = output_name
49
52
  self.fit: bool = fit
50
- self.expert = expert if expert else Human()
51
- self.rdr_model_path: Optional[str] = None
53
+ self.expert: Optional[Expert] = expert
52
54
  self.load()
53
55
 
54
56
  def decorator(self, func: Callable) -> Callable:
55
57
 
56
58
  @wraps(func)
57
59
  def wrapper(*args, **kwargs) -> Optional[Any]:
60
+
58
61
  if len(self.parsed_output_type) == 0:
59
- self.parse_output_type(func, *args)
60
- if self.rdr_model_path is None:
61
- self.initialize_rdr_model_path_and_load(func)
62
- case_dict = get_method_args_as_dict(func, *args, **kwargs)
63
- func_output = func(*args, **kwargs)
64
- case_dict.update({self.output_name: func_output})
65
- case = create_case(case_dict, obj_name=get_func_rdr_model_name(func), max_recursion_idx=3)
62
+ self.parsed_output_type = self.parse_output_type(func, self.output_type, *args)
63
+ if self.model_name is None:
64
+ self.initialize_rdr_model_name_and_load(func)
65
+
66
66
  if self.fit:
67
- scope = func.__globals__
68
- scope.update(case_dict)
69
- func_args_type_hints = get_type_hints(func)
70
- func_args_type_hints.update({self.output_name: Union[tuple(self.parsed_output_type)]})
71
- case_query = CaseQuery(case, self.output_name, Union[tuple(self.parsed_output_type)],
72
- self.mutual_exclusive,
73
- scope=scope, is_function=True, function_args_type_hints=func_args_type_hints)
67
+ expert_answers_path = os.path.join(self.rdr_models_dir, self.model_name, "expert_answers")
68
+ self.expert = self.expert or Human(answers_save_path=expert_answers_path)
69
+ case_query = self.create_case_query_from_method(func, self.parsed_output_type,
70
+ self.mutual_exclusive, self.output_name,
71
+ *args, **kwargs)
74
72
  output = self.rdr.fit_case(case_query, expert=self.expert)
75
73
  return output[self.output_name]
76
74
  else:
75
+ case, case_dict = self.create_case_from_method(func, self.output_name, *args, **kwargs)
77
76
  return self.rdr.classify(case)[self.output_name]
78
77
 
79
78
  return wrapper
80
79
 
81
- def initialize_rdr_model_path_and_load(self, func: Callable) -> None:
80
+ @staticmethod
81
+ def create_case_query_from_method(func: Callable, output_type, mutual_exclusive: bool,
82
+ output_name: str = 'output_', *args, **kwargs) -> CaseQuery:
83
+ """
84
+ Create a CaseQuery from the function and its arguments.
85
+
86
+ :param func: The function to create a case from.
87
+ :param output_type: The type of the output.
88
+ :param mutual_exclusive: If True, the output types are mutually exclusive.
89
+ :param output_name: The name of the output in the case. Defaults to 'output_'.
90
+ :param args: The positional arguments of the function.
91
+ :param kwargs: The keyword arguments of the function.
92
+ :return: A CaseQuery object representing the case.
93
+ """
94
+ output_type = make_set(output_type)
95
+ case, case_dict = RDRDecorator.create_case_from_method(func, output_name, *args, **kwargs)
96
+ scope = func.__globals__
97
+ scope.update(case_dict)
98
+ func_args_type_hints = get_type_hints(func)
99
+ func_args_type_hints.update({output_name: Union[tuple(output_type)]})
100
+ return CaseQuery(case, output_name, Union[tuple(output_type)],
101
+ mutual_exclusive, scope=scope,
102
+ is_function=True, function_args_type_hints=func_args_type_hints)
103
+
104
+ @staticmethod
105
+ def create_case_from_method(func: Callable, output_name: str = "output_", *args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
106
+ """
107
+ Create a Case from the function and its arguments.
108
+
109
+ :param func: The function to create a case from.
110
+ :param output_name: The name of the output in the case. Defaults to 'output_'.
111
+ :param args: The positional arguments of the function.
112
+ :param kwargs: The keyword arguments of the function.
113
+ :return: A Case object representing the case.
114
+ """
115
+ case_dict = get_method_args_as_dict(func, *args, **kwargs)
116
+ func_output = func(*args, **kwargs)
117
+ case_dict.update({output_name: func_output})
118
+ case_name = get_func_rdr_model_name(func)
119
+ return create_case(case_dict, obj_name=case_name, max_recursion_idx=3), case_dict
120
+
121
+ def initialize_rdr_model_name_and_load(self, func: Callable) -> None:
82
122
  model_file_name = get_func_rdr_model_name(func, include_file_name=True)
83
- model_file_name = (''.join(['_' + c.lower() if c.isupper() else c for c in model_file_name]).lstrip('_')
84
- .replace('__', '_') + ".json")
85
- self.rdr_model_path = os.path.join(self.rdr_models_dir, model_file_name)
123
+ self.model_name = str_to_snake_case(model_file_name)
86
124
  self.load()
87
125
 
88
- def parse_output_type(self, func: Callable, *args) -> None:
89
- for ot in make_set(self.output_type):
126
+ @staticmethod
127
+ def parse_output_type(func: Callable, output_type: Any, *args) -> List[Type]:
128
+ parsed_output_type = []
129
+ for ot in make_set(output_type):
90
130
  if ot is Self:
91
131
  func_class = get_method_class_if_exists(func, *args)
92
132
  if func_class is not None:
93
- self.parsed_output_type.append(func_class)
133
+ parsed_output_type.append(func_class)
94
134
  else:
95
135
  raise ValueError(f"The function {func} is not a method of a class,"
96
136
  f" and the output type is {Self}.")
97
137
  else:
98
- self.parsed_output_type.append(ot)
138
+ parsed_output_type.append(ot)
139
+ return parsed_output_type
99
140
 
100
141
  def save(self):
101
142
  """
102
143
  Save the RDR model to the specified directory.
103
144
  """
104
- self.rdr.save(self.rdr_model_path)
105
-
106
- if self.rdr_python_path is not None:
107
- if not os.path.exists(self.rdr_python_path):
108
- os.makedirs(self.rdr_python_path)
109
- if not os.path.exists(os.path.join(self.rdr_python_path, "__init__.py")):
110
- # add __init__.py file to the directory
111
- with open(os.path.join(self.rdr_python_path, "__init__.py"), "w") as f:
112
- f.write("# This is an empty __init__.py file to make the directory a package.")
113
- # write the RDR model to a python file
114
- self.rdr.write_to_python_file(self.rdr_python_path)
145
+ self.rdr.save(self.rdr_models_dir)
115
146
 
116
147
  def load(self):
117
148
  """
118
149
  Load the RDR model from the specified directory.
119
150
  """
120
- if self.rdr_model_path is not None and os.path.exists(self.rdr_model_path):
121
- self.rdr = GeneralRDR.load(self.rdr_model_path)
151
+ if self.model_name is not None and os.path.exists(os.path.join(self.rdr_models_dir, self.model_name)):
152
+ self.rdr = GeneralRDR.load(self.rdr_models_dir, self.model_name)
122
153
  else:
123
- self.rdr = GeneralRDR()
124
-
125
- def write_to_python_file(self, package_dir: str, file_name_postfix: str = ""):
126
- """
127
- Write the RDR model to a python file.
128
-
129
- :param package_dir: The path to the directory to write the python file.
130
- """
131
- self.rdr.write_to_python_file(package_dir, postfix=file_name_postfix)
154
+ self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name)
132
155
 
133
- def update_from_python_file(self, package_dir: str):
156
+ def update_from_python(self):
134
157
  """
135
158
  Update the RDR model from a python file.
136
-
137
- :param package_dir: The directory of the package that contains the generated python file.
138
159
  """
139
- self.rdr.update_from_python_file(package_dir)
160
+ self.rdr.update_from_python(self.rdr_models_dir, self.model_name)
@@ -118,7 +118,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
118
118
  func_call = f"{parent_indent} return {new_function_name.replace('def ', '')}(case)\n"
119
119
  return "\n".join(conclusion_lines).strip(' '), func_call
120
120
  else:
121
- raise ValueError(f"Conclusion is format is not valid, it should be contain a function definition."
121
+ raise ValueError(f"Conclusion format is not valid, it should contain a function definition."
122
122
  f" Instead got:\n{conclusion}\n")
123
123
 
124
124
  def write_condition_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
@@ -129,9 +129,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
129
129
  :param defs_file: The file to write the conditions to if they are a definition.
130
130
  """
131
131
  if_clause = self._if_statement_source_code_clause()
132
- if '\n' not in self.conditions.user_input:
133
- return f"{parent_indent}{if_clause} {self.conditions.user_input}:\n"
134
- elif "def " in self.conditions.user_input:
132
+ if "def " in self.conditions.user_input:
135
133
  if defs_file is None:
136
134
  raise ValueError("Cannot write conditions to source code as definitions python file was not given.")
137
135
  # This means the conditions are a definition that should be written and then called
@@ -143,6 +141,9 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
143
141
  with open(defs_file, 'a') as f:
144
142
  f.write(def_code.strip() + "\n\n\n")
145
143
  return f"\n{parent_indent}{if_clause} {new_function_name.replace('def ', '')}(case):\n"
144
+ else:
145
+ raise ValueError(f"Conditions format is not valid, it should contain a function definition"
146
+ f" Instead got:\n{self.conditions.user_input}\n")
146
147
 
147
148
  @abstractmethod
148
149
  def _if_statement_source_code_clause(self) -> str:
@@ -258,12 +258,12 @@ class TemplateFileCreator:
258
258
  func_name = f"{prompt_for.value.lower()}_for_"
259
259
  case_name = case_query.name.replace(".", "_")
260
260
  if case_query.is_function:
261
- # convert any CamelCase word into snake_case by adding _ before each capital letter
262
- case_name = case_name.replace(f"_{case_query.attribute_name}", "")
263
- func_name += case_name
264
- attribute_types = TemplateFileCreator.get_core_attribute_types(case_query)
265
- attribute_type_names = [t.__name__ for t in attribute_types]
266
- func_name += f"_of_type_{'_or_'.join(attribute_type_names)}"
261
+ func_name += case_name.replace(f"_{case_query.attribute_name}", "")
262
+ else:
263
+ func_name += case_name
264
+ attribute_types = TemplateFileCreator.get_core_attribute_types(case_query)
265
+ attribute_type_names = [t.__name__ for t in attribute_types]
266
+ func_name += f"_of_type_{'_or_'.join(attribute_type_names)}"
267
267
  return str_to_snake_case(func_name)
268
268
 
269
269
  @cached_property
@@ -178,7 +178,7 @@ def extract_function_source(file_path: str,
178
178
  functions_source: Dict[str, Union[str, List[str]]] = {}
179
179
  line_numbers = []
180
180
  for node in tree.body:
181
- if isinstance(node, ast.FunctionDef) and node.name in function_names:
181
+ if isinstance(node, ast.FunctionDef) and (node.name in function_names or len(function_names) == 0):
182
182
  # Get the line numbers of the function
183
183
  lines = source.splitlines()
184
184
  func_lines = lines[node.lineno - 1:node.end_lineno]
@@ -186,9 +186,9 @@ def extract_function_source(file_path: str,
186
186
  func_lines = func_lines[1:]
187
187
  line_numbers.append((node.lineno, node.end_lineno))
188
188
  functions_source[node.name] = dedent("\n".join(func_lines)) if join_lines else func_lines
189
- if len(functions_source) == len(function_names):
189
+ if len(functions_source) >= len(function_names):
190
190
  break
191
- if len(functions_source) != len(function_names):
191
+ if len(functions_source) < len(function_names):
192
192
  raise ValueError(f"Could not find all functions in {file_path}: {function_names} not found,"
193
193
  f"functions not found: {set(function_names) - set(functions_source.keys())}")
194
194
  if return_line_numbers:
@@ -953,9 +953,6 @@ class SubclassJSONSerializer:
953
953
 
954
954
  raise ValueError("Unknown type {}".format(data["_type"]))
955
955
 
956
- save = to_json_file
957
- load = from_json_file
958
-
959
956
 
960
957
  def _pickle_thread(thread_obj) -> Any:
961
958
  """Return a plain object with user-defined attributes but no thread behavior."""
@@ -1098,24 +1095,29 @@ def get_origin_and_args_from_type_hint(type_hint: Type) -> Tuple[Optional[Type],
1098
1095
  return origin, args
1099
1096
 
1100
1097
 
1101
- def table_rows_as_str(row_dict: Dict[str, Any], columns_per_row: int = 9):
1098
+ def table_rows_as_str(row_dicts: List[Dict[str, Any]], columns_per_row: int = 20):
1102
1099
  """
1103
1100
  Print a table row.
1104
1101
 
1105
- :param row_dict: The row to print.
1102
+ :param row_dicts: The rows to print.
1106
1103
  :param columns_per_row: The maximum number of columns per row.
1107
1104
  """
1108
- all_items = list(row_dict.items())
1105
+ all_row_dicts_items = [list(row_dict.items()) for row_dict in row_dicts]
1109
1106
  # make items a list of n rows such that each row has a max size of 4
1110
- all_items = [all_items[i:i + columns_per_row] for i in range(0, len(all_items), columns_per_row)]
1107
+ all_items = [all_items[i:i + columns_per_row] for all_items in all_row_dicts_items
1108
+ for i in range(0, len(all_items), columns_per_row)]
1111
1109
  keys = [list(map(lambda i: i[0], row)) for row in all_items]
1112
1110
  values = [list(map(lambda i: i[1], row)) for row in all_items]
1111
+ zipped_keys = list(zip(*keys))
1112
+ zipped_values = list(zip(*values))
1113
+ keys_values = [list(zip(zipped_keys[i], zipped_values[i])) for i in range(len(zipped_keys))]
1114
+ keys_values = [list(r[0]) + list(r[1]) if len(r) > 1 else r[0] for r in keys_values]
1113
1115
  all_table_rows = []
1114
- for row_keys, row_values in zip(keys, values):
1115
- row_values = [str(v) if v is not None else "" for v in row_values]
1116
- row_values = [v.lower() if v in ["True", "False"] else v for v in row_values]
1117
- table = tabulate([row_values], headers=row_keys, tablefmt='plain', maxcolwidths=[20] * len(row_keys))
1118
- all_table_rows.append(table)
1116
+ row_values = [list(map(lambda v: str(v) if v is not None else "", row)) for row in keys_values]
1117
+ row_values = [list(map(lambda v: v[:150] + '...' if len(v) > 150 else v, row)) for row in row_values]
1118
+ row_values = [list(map(lambda v: v.lower() if v in ["True", "False"] else v, row)) for row in row_values]
1119
+ table = tabulate(row_values, tablefmt='simple_grid', maxcolwidths=[150] * 2)
1120
+ all_table_rows.append(table)
1119
1121
  return "\n".join(all_table_rows)
1120
1122
 
1121
1123
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.4.86
3
+ Version: 0.5.0
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -693,20 +693,12 @@ Requires-Dist: colorama
693
693
  Requires-Dist: pygments
694
694
  Requires-Dist: sqlalchemy
695
695
  Requires-Dist: pandas
696
+ Provides-Extra: viz
697
+ Requires-Dist: networkx>=3.1; extra == "viz"
698
+ Requires-Dist: matplotlib>=3.7.5; extra == "viz"
696
699
  Provides-Extra: gui
697
700
  Requires-Dist: pyqt6; extra == "gui"
698
701
  Requires-Dist: qtconsole; extra == "gui"
699
- Provides-Extra: viz
700
- Requires-Dist: matplotlib; extra == "viz"
701
- Requires-Dist: networkx; extra == "viz"
702
- Provides-Extra: dev
703
- Requires-Dist: pyqt6; extra == "dev"
704
- Requires-Dist: qtconsole; extra == "dev"
705
- Requires-Dist: matplotlib; extra == "dev"
706
- Requires-Dist: networkx; extra == "dev"
707
- Requires-Dist: pytest; extra == "dev"
708
- Requires-Dist: ucimlrepo>=0.0.7; extra == "dev"
709
- Requires-Dist: pdbpp; extra == "dev"
710
702
  Dynamic: license-file
711
703
 
712
704
  # Ripple Down Rules (RDR)
@@ -0,0 +1,26 @@
1
+ ripple_down_rules/__init__.py,sha256=FQAv_KtUXVoz9VCR37DEoK0MC84rKZnIRUy-4pQ95sE,100
2
+ ripple_down_rules/datasets.py,sha256=fJbZ7V-UUYTu5XVVpFinTbuzN3YePCnUB01L3AyZVM8,6837
3
+ ripple_down_rules/experts.py,sha256=9Vc3vx0uhDPy3YlNjwKuWJLl_A-kubRPUU6bMvQhaAg,13237
4
+ ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
+ ripple_down_rules/helpers.py,sha256=TvTJU0BA3dPcAyzvZFvAu7jZqsp8Lu0HAAwvuizlGjg,2018
6
+ ripple_down_rules/rdr.py,sha256=E1OiiZClQyAfGjL64ID-MWYFO4-h8iUAX-Vm9qrOoeQ,48727
7
+ ripple_down_rules/rdr_decorators.py,sha256=pYCKLgMKgQ6x_252WQtF2t4ZNjWPBxnaWtJ6TpGdcc0,7820
8
+ ripple_down_rules/rules.py,sha256=TPNVMqW9T-_46BS4WemrspLg5uG8kP6tsPvWWBAzJxg,17515
9
+ ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
10
+ ripple_down_rules/utils.py,sha256=uS38KcFceRMzT_470DCL1M0LzETdP5RLwE7cCmfo7eI,51086
11
+ ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
12
+ ripple_down_rules/datastructures/callable_expression.py,sha256=3EucsD3jWzekhjyzL2y0dyUsucd-aqC9glmgPL0Ubb4,12425
13
+ ripple_down_rules/datastructures/case.py,sha256=r8kjL9xP_wk84ThXusspgPMrAoed2bGQmKi54fzhmH8,15258
14
+ ripple_down_rules/datastructures/dataclasses.py,sha256=PuD-7zWqWT2p4FnGvnihHvZlZKg9A1ctnFgVYf2cs-8,8554
15
+ ripple_down_rules/datastructures/enums.py,sha256=ce7tqS0otfSTNAOwsnXlhsvIn4iW_Y_N3TNebF3YoZs,5700
16
+ ripple_down_rules/user_interface/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
+ ripple_down_rules/user_interface/gui.py,sha256=SB0gUhgReJ3yx-NEHRPMGVuNRLPRUwW8-qup-Kd4Cfo,27182
18
+ ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=24MIFwqnAhC6ofObEO6x5xRWRnyQmPpPmTvxbCKBrzM,6514
19
+ ripple_down_rules/user_interface/object_diagram.py,sha256=tsB6iuLNEbHxp5lR2WjyejjWbnAX_nHF9xS8jNPOQVk,4548
20
+ ripple_down_rules/user_interface/prompt.py,sha256=AkkltdDIaioN43lkRKDPKSjJcmdSSGZDMYz7AL7X9lE,8082
21
+ ripple_down_rules/user_interface/template_file_creator.py,sha256=ycCbddy_BJP8d0Q2Sj21UzamhGtqGZuK_e73VTJqznY,13766
22
+ ripple_down_rules-0.5.0.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
23
+ ripple_down_rules-0.5.0.dist-info/METADATA,sha256=LYiepkd0xlfYVqVMdVrKZNbMJuxybqBheA2b0_CgGsY,43306
24
+ ripple_down_rules-0.5.0.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
25
+ ripple_down_rules-0.5.0.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
26
+ ripple_down_rules-0.5.0.dist-info/RECORD,,
@@ -1,26 +0,0 @@
1
- ripple_down_rules/__init__.py,sha256=v-1F6x1Q2QlDiOWgzsCkkSZVyHAeRboirfjTWeRqJxo,22
2
- ripple_down_rules/datasets.py,sha256=fJbZ7V-UUYTu5XVVpFinTbuzN3YePCnUB01L3AyZVM8,6837
3
- ripple_down_rules/experts.py,sha256=RWDR-xxbeFIrUQiMYLEDr_PLQFdpPZ-hOXo4dpeiUpI,6630
4
- ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
- ripple_down_rules/helpers.py,sha256=TvTJU0BA3dPcAyzvZFvAu7jZqsp8Lu0HAAwvuizlGjg,2018
6
- ripple_down_rules/rdr.py,sha256=zQiWHUw8Qq-FPuKHdBErZUhh8zqiB6UCBi1kWZx5cJE,45511
7
- ripple_down_rules/rdr_decorators.py,sha256=VdmE0JrE8j89b6Af1R1tLZiKfy3h1VCvhAUefN_FLLQ,6753
8
- ripple_down_rules/rules.py,sha256=7NB8qWW7XEB45tmJRYsKJqBG8DN3v02fzAFYmOkX8ow,17458
9
- ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
10
- ripple_down_rules/utils.py,sha256=iM2bYRYanuWTnq7dflRar8tMwRxL88B__hWkayGLVz4,50675
11
- ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
12
- ripple_down_rules/datastructures/callable_expression.py,sha256=jA7424_mWPbOoPICW3eLMX0-ypxnsW6gOqxrJ7JpDbE,11610
13
- ripple_down_rules/datastructures/case.py,sha256=3nGeTv9JYpnRSmsHCt61qo1QT4d9Dfkl9KrZ7I6E2kk,15164
14
- ripple_down_rules/datastructures/dataclasses.py,sha256=GWnUF4h4zfNHSsyBIz3L9y8sLkrXRv0FK_OxzzLc8L8,8183
15
- ripple_down_rules/datastructures/enums.py,sha256=ce7tqS0otfSTNAOwsnXlhsvIn4iW_Y_N3TNebF3YoZs,5700
16
- ripple_down_rules/user_interface/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- ripple_down_rules/user_interface/gui.py,sha256=SB0gUhgReJ3yx-NEHRPMGVuNRLPRUwW8-qup-Kd4Cfo,27182
18
- ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=24MIFwqnAhC6ofObEO6x5xRWRnyQmPpPmTvxbCKBrzM,6514
19
- ripple_down_rules/user_interface/object_diagram.py,sha256=tsB6iuLNEbHxp5lR2WjyejjWbnAX_nHF9xS8jNPOQVk,4548
20
- ripple_down_rules/user_interface/prompt.py,sha256=AkkltdDIaioN43lkRKDPKSjJcmdSSGZDMYz7AL7X9lE,8082
21
- ripple_down_rules/user_interface/template_file_creator.py,sha256=J_bBOJltc1fsrIYeHdrSUA_jep2DhDbTK5NYRbL6QyY,13831
22
- ripple_down_rules-0.4.86.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
23
- ripple_down_rules-0.4.86.dist-info/METADATA,sha256=OZ1QS5AnJFolHEoUSn_b88yB7jX2SN3Kanyr6u_Y9oY,43598
24
- ripple_down_rules-0.4.86.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
25
- ripple_down_rules-0.4.86.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
26
- ripple_down_rules-0.4.86.dist-info/RECORD,,