ripple-down-rules 0.1.21__py3-none-any.whl → 0.1.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ripple_down_rules/rdr.py CHANGED
@@ -1,20 +1,28 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import importlib
4
+ import sys
4
5
  from abc import ABC, abstractmethod
5
6
  from copy import copy
7
+ from dataclasses import is_dataclass
8
+ from io import TextIOWrapper
6
9
  from types import ModuleType
7
10
 
8
11
  from matplotlib import pyplot as plt
9
12
  from ordered_set import OrderedSet
10
- from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
13
+ from sqlalchemy.orm import DeclarativeBase as SQLTable
11
14
  from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
12
15
 
13
- from .datastructures import Case, MCRDRMode, CallableExpression, CaseAttribute, CaseQuery
16
+ from .datastructures.callable_expression import CallableExpression
17
+ from .datastructures.case import Case, CaseAttribute, create_case
18
+ from .datastructures.dataclasses import CaseQuery
19
+ from .datastructures.enums import MCRDRMode, PromptFor
14
20
  from .experts import Expert, Human
21
+ from .helpers import is_matching
15
22
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
16
23
  from .utils import draw_tree, make_set, copy_case, \
17
- get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string
24
+ SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
25
+ get_case_attribute_type, is_conflicting
18
26
 
19
27
 
20
28
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -29,14 +37,16 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
29
37
  """
30
38
  The conclusions that the expert has accepted, such that they are not asked again.
31
39
  """
40
+ _generated_python_file_name: Optional[str] = None
41
+ """
42
+ The name of the generated python file.
43
+ """
32
44
 
33
- def __init__(self, start_rule: Optional[Rule] = None, session: Optional[Session] = None):
45
+ def __init__(self, start_rule: Optional[Rule] = None):
34
46
  """
35
47
  :param start_rule: The starting rule for the classifier.
36
- :param session: The sqlalchemy orm session.
37
48
  """
38
49
  self.start_rule = start_rule
39
- self.session = session
40
50
  self.fig: Optional[plt.Figure] = None
41
51
 
42
52
  def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
@@ -79,31 +89,29 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
79
89
  :param animate_tree: Whether to draw the tree while fitting the classifier.
80
90
  :param kwargs_for_fit_case: The keyword arguments to pass to the fit_case method.
81
91
  """
82
- cases = [case_query.case for case_query in case_queries]
83
- targets = [{case_query.attribute_name: case_query.target} for case_query in case_queries]
92
+ targets = []
84
93
  if animate_tree:
85
94
  plt.ion()
86
95
  i = 0
87
96
  stop_iterating = False
88
97
  num_rules: int = 0
89
98
  while not stop_iterating:
90
- all_pred = 0
91
- if not targets:
92
- targets = [None] * len(cases)
93
99
  for case_query in case_queries:
94
- target = {case_query.attribute_name: case_query.target}
95
100
  pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
96
- match = self.is_matching(pred_cat, target)
101
+ if case_query.target is None:
102
+ continue
103
+ target = {case_query.attribute_name: case_query.target(case_query.case)}
104
+ if len(targets) < len(case_queries):
105
+ targets.append(target)
106
+ match = is_matching(self.classify, case_query, pred_cat)
97
107
  if not match:
98
108
  print(f"Predicted: {pred_cat} but expected: {target}")
99
- all_pred += int(match)
100
109
  if animate_tree and self.start_rule.size > num_rules:
101
110
  num_rules = self.start_rule.size
102
111
  self.update_figures()
103
112
  i += 1
104
- all_predictions = [1 if self.is_matching(self.classify(case_query.case), {case_query.attribute_name:
105
- case_query.target}) else 0
106
- for case_query in case_queries]
113
+ all_predictions = [1 if is_matching(self.classify, case_query) else 0 for case_query in case_queries
114
+ if case_query.target is not None]
107
115
  all_pred = sum(all_predictions)
108
116
  print(f"Accuracy: {all_pred}/{len(targets)}")
109
117
  all_predicted = targets and all_pred == len(targets)
@@ -116,48 +124,6 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
116
124
  plt.ioff()
117
125
  plt.show()
118
126
 
119
- @staticmethod
120
- def calculate_precision_and_recall(pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> Tuple[
121
- List[bool], List[bool]]:
122
- """
123
- :param pred_cat: The predicted category.
124
- :param target: The target category.
125
- :return: The precision and recall of the classifier.
126
- """
127
- pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
128
- target = target if is_iterable(target) else [target]
129
- recall = []
130
- precision = []
131
- if isinstance(pred_cat, dict):
132
- for pred_key, pred_value in pred_cat.items():
133
- if pred_key not in target:
134
- continue
135
- precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
136
- for target_key, target_value in target.items():
137
- if target_key not in pred_cat:
138
- recall.append(False)
139
- continue
140
- if is_iterable(target_value):
141
- recall.extend([v in pred_cat[target_key] for v in target_value])
142
- else:
143
- recall.append(target_value == pred_cat[target_key])
144
- else:
145
- if isinstance(target, dict):
146
- target = list(target.values())
147
- recall = [not yi or (yi in pred_cat) for yi in target]
148
- target_types = [type(yi) for yi in target]
149
- precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
150
- return precision, recall
151
-
152
- def is_matching(self, pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> bool:
153
- """
154
- :param pred_cat: The predicted category.
155
- :param target: The target category.
156
- :return: Whether the classifier is matching or not.
157
- """
158
- precision, recall = self.calculate_precision_and_recall(pred_cat, target)
159
- return all(recall) and all(precision)
160
-
161
127
  def update_figures(self):
162
128
  """
163
129
  Update the figures of the classifier.
@@ -183,33 +149,51 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
183
149
  """
184
150
  return hasattr(case, conclusion_name) and getattr(case, conclusion_name) is not None
185
151
 
152
+ @property
153
+ def type_(self):
154
+ return self.__class__
155
+
186
156
 
187
157
  class RDRWithCodeWriter(RippleDownRules, ABC):
188
158
 
189
159
  @abstractmethod
190
- def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = ""):
160
+ def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
161
+ defs_file: Optional[str] = None):
191
162
  """
192
163
  Write the rules as source code to a file.
193
164
 
194
165
  :param rule: The rule to write as source code.
195
166
  :param file: The file to write the source code to.
196
167
  :param parent_indent: The indentation of the parent rule.
168
+ :param defs_file: The file to write the definitions to.
197
169
  """
198
170
  pass
199
171
 
200
- def write_to_python_file(self, file_path: str):
172
+ def write_to_python_file(self, file_path: str, postfix: str = ""):
201
173
  """
202
174
  Write the tree of rules as source code to a file.
203
175
 
204
176
  :param file_path: The path to the file to write the source code to.
177
+ :param postfix: The postfix to add to the file name.
205
178
  """
179
+ self.generated_python_file_name = self._default_generated_python_file_name + postfix
206
180
  func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
207
- with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
208
- f.write(self._get_imports() + "\n\n")
181
+ file_name = file_path + f"/{self.generated_python_file_name}.py"
182
+ defs_file_name = file_path + f"/{self.generated_python_defs_file_name}.py"
183
+ imports = self._get_imports()
184
+ # clear the files first
185
+ with open(defs_file_name, "w") as f:
186
+ f.write(imports + "\n\n")
187
+ with open(file_name, "w") as f:
188
+ imports += f"from .{self.generated_python_defs_file_name} import *\n"
189
+ imports += f"from ripple_down_rules.rdr import {self.__class__.__name__}\n"
190
+ f.write(imports + "\n\n")
191
+ f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n\n")
192
+ f.write(f"type_ = {self.__class__.__name__}\n\n")
209
193
  f.write(func_def)
210
194
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
211
195
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
212
- self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)
196
+ self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4, defs_file=defs_file_name)
213
197
 
214
198
  @property
215
199
  @abstractmethod
@@ -226,26 +210,72 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
226
210
  imports = ""
227
211
  if self.case_type.__module__ != "builtins":
228
212
  imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
229
- if self.conclusion_type.__module__ != "builtins":
230
- imports += f"from {self.conclusion_type.__module__} import {self.conclusion_type.__name__}\n"
231
- imports += "from ripple_down_rules.datastructures import Case, create_case\n"
213
+ for conclusion_type in self.conclusion_type:
214
+ if conclusion_type.__module__ != "builtins":
215
+ imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
216
+ imports += "from ripple_down_rules.datastructures.case import Case, create_case\n"
232
217
  for rule in [self.start_rule] + list(self.start_rule.descendants):
233
- if rule.conditions:
234
- if rule.conditions.scope is not None and len(rule.conditions.scope) > 0:
235
- for k, v in rule.conditions.scope.items():
236
- imports += f"from {v.__module__} import {v.__name__}\n"
218
+ if not rule.conditions:
219
+ continue
220
+ if rule.conditions.scope is None or len(rule.conditions.scope) == 0:
221
+ continue
222
+ for k, v in rule.conditions.scope.items():
223
+ new_imports = f"from {v.__module__} import {v.__name__}\n"
224
+ if new_imports in imports:
225
+ continue
226
+ imports += new_imports
237
227
  return imports
238
228
 
239
- def get_rdr_classifier_from_python_file(self, package_name) -> Callable[[Any], Any]:
229
+ def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
240
230
  """
241
231
  :param package_name: The name of the package that contains the RDR classifier function.
242
232
  :return: The module that contains the rdr classifier function.
243
233
  """
244
- return importlib.import_module(f"{package_name.strip('./')}.{self.generated_python_file_name}").classify
234
+ # remove from imports if exists first
235
+ name = f"{package_name.strip('./')}.{self.generated_python_file_name}"
236
+ try:
237
+ module = importlib.import_module(name)
238
+ del sys.modules[name]
239
+ except ModuleNotFoundError:
240
+ pass
241
+ return importlib.import_module(name).classify
245
242
 
246
243
  @property
247
244
  def generated_python_file_name(self) -> str:
248
- return f"{self.start_rule.corner_case._name.lower()}_{self.attribute_name}_rdr"
245
+ if self._generated_python_file_name is None:
246
+ self._generated_python_file_name = self._default_generated_python_file_name
247
+ return self._generated_python_file_name
248
+
249
+ @generated_python_file_name.setter
250
+ def generated_python_file_name(self, value: str):
251
+ """
252
+ Set the generated python file name.
253
+ :param value: The new value for the generated python file name.
254
+ """
255
+ self._generated_python_file_name = value
256
+
257
+ @property
258
+ def _default_generated_python_file_name(self) -> str:
259
+ """
260
+ :return: The default generated python file name.
261
+ """
262
+ return f"{self.start_rule.corner_case._name.lower()}_{self.attribute_name}_{self.acronym.lower()}"
263
+
264
+ @property
265
+ def generated_python_defs_file_name(self) -> str:
266
+ return f"{self.generated_python_file_name}_defs"
267
+
268
+ @property
269
+ def acronym(self) -> str:
270
+ """
271
+ :return: The acronym of the classifier.
272
+ """
273
+ if self.__class__.__name__ == "GeneralRDR":
274
+ return "GRDR"
275
+ elif self.__class__.__name__ == "MultiClassRDR":
276
+ return "MCRDR"
277
+ else:
278
+ return "SCRDR"
249
279
 
250
280
  @property
251
281
  def case_type(self) -> Type:
@@ -258,16 +288,17 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
258
288
  return type(self.start_rule.corner_case)
259
289
 
260
290
  @property
261
- def conclusion_type(self) -> Type:
291
+ def conclusion_type(self) -> Tuple[Type]:
262
292
  """
263
293
  :return: The type of the conclusion of the RDR classifier.
264
294
  """
265
295
  if isinstance(self.start_rule.conclusion, CallableExpression):
266
296
  return self.start_rule.conclusion.conclusion_type
267
297
  else:
268
- if isinstance(self.start_rule.conclusion, set):
269
- return type(list(self.start_rule.conclusion)[0])
270
- return type(self.start_rule.conclusion)
298
+ conclusion = self.start_rule.conclusion
299
+ if isinstance(conclusion, set):
300
+ return type(list(conclusion)[0]), set
301
+ return (type(conclusion),)
271
302
 
272
303
  @property
273
304
  def attribute_name(self) -> str:
@@ -279,8 +310,16 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
279
310
 
280
311
  class SingleClassRDR(RDRWithCodeWriter):
281
312
 
313
+ def __init__(self, start_rule: Optional[SingleClassRule] = None, default_conclusion: Optional[Any] = None):
314
+ """
315
+ :param start_rule: The starting rule for the classifier.
316
+ :param default_conclusion: The default conclusion for the classifier if no rules fire.
317
+ """
318
+ super(SingleClassRDR, self).__init__(start_rule)
319
+ self.default_conclusion: Optional[Any] = default_conclusion
320
+
282
321
  def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
283
- -> Union[CaseAttribute, CallableExpression]:
322
+ -> Union[CaseAttribute, CallableExpression, None]:
284
323
  """
285
324
  Classify a case, and ask the user for refinements or alternatives if the classification is incorrect by
286
325
  comparing the case with the target category if provided.
@@ -289,28 +328,31 @@ class SingleClassRDR(RDRWithCodeWriter):
289
328
  :param expert: The expert to ask for differentiating features as new rule conditions.
290
329
  :return: The category that the case belongs to.
291
330
  """
292
- expert = expert if expert else Human(session=self.session)
293
- if case_query.target is None:
294
- target = expert.ask_for_conclusion(case_query)
331
+ expert = expert if expert else Human()
332
+ if case_query.default_value is not None and self.default_conclusion != case_query.default_value:
333
+ self.default_conclusion = case_query.default_value
334
+ case = case_query.case
335
+ target = expert.ask_for_conclusion(case_query) if case_query.target is None else case_query.target
336
+ if target is None:
337
+ return self.classify(case)
295
338
  if not self.start_rule:
296
339
  conditions = expert.ask_for_conditions(case_query)
297
- self.start_rule = SingleClassRule(conditions, case_query.target, corner_case=case_query.case,
340
+ self.start_rule = SingleClassRule(conditions, target, corner_case=case,
298
341
  conclusion_name=case_query.attribute_name)
299
342
 
300
343
  pred = self.evaluate(case_query.case)
301
-
302
- if pred.conclusion != case_query.target:
344
+ if pred.conclusion(case) != target(case):
303
345
  conditions = expert.ask_for_conditions(case_query, pred)
304
- pred.fit_rule(case_query.case, case_query.target, conditions=conditions)
346
+ pred.fit_rule(case_query.case, target, conditions=conditions)
305
347
 
306
348
  return self.classify(case_query.case)
307
349
 
308
- def classify(self, case: Case) -> Optional[CaseAttribute]:
350
+ def classify(self, case: Case) -> Optional[Any]:
309
351
  """
310
352
  Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
311
353
  """
312
354
  pred = self.evaluate(case)
313
- return pred.conclusion if pred.fired else None
355
+ return pred.conclusion(case) if pred.fired else self.default_conclusion
314
356
 
315
357
  def evaluate(self, case: Case) -> SingleClassRule:
316
358
  """
@@ -319,23 +361,32 @@ class SingleClassRDR(RDRWithCodeWriter):
319
361
  matched_rule = self.start_rule(case)
320
362
  return matched_rule if matched_rule else self.start_rule
321
363
 
322
- def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file, parent_indent: str = ""):
364
+ def write_to_python_file(self, file_path: str, postfix: str = ""):
365
+ super().write_to_python_file(file_path, postfix)
366
+ if self.default_conclusion is not None:
367
+ with open(file_path + f"/{self.generated_python_file_name}.py", "a") as f:
368
+ f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
369
+
370
+ def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
371
+ defs_file: Optional[str] = None):
323
372
  """
324
373
  Write the rules as source code to a file.
325
374
  """
326
375
  if rule.conditions:
327
- file.write(rule.write_condition_as_source_code(parent_indent))
376
+ if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
377
+ file.write(if_clause)
328
378
  if rule.refinement:
329
- self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ")
379
+ self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
380
+ defs_file=defs_file)
330
381
 
331
382
  file.write(rule.write_conclusion_as_source_code(parent_indent))
332
383
 
333
384
  if rule.alternative:
334
- self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
385
+ self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
335
386
 
336
387
  @property
337
388
  def conclusion_type_hint(self) -> str:
338
- return self.conclusion_type.__name__
389
+ return self.conclusion_type[0].__name__
339
390
 
340
391
  def _to_json(self) -> Dict[str, Any]:
341
392
  return {"start_rule": self.start_rule.to_json()}
@@ -369,28 +420,27 @@ class MultiClassRDR(RDRWithCodeWriter):
369
420
  """
370
421
 
371
422
  def __init__(self, start_rule: Optional[Rule] = None,
372
- mode: MCRDRMode = MCRDRMode.StopOnly, session: Optional[Session] = None):
423
+ mode: MCRDRMode = MCRDRMode.StopOnly):
373
424
  """
374
425
  :param start_rule: The starting rules for the classifier.
375
426
  :param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
376
- :param session: The sqlalchemy orm session.
377
427
  """
378
428
  start_rule = MultiClassTopRule() if not start_rule else start_rule
379
- super(MultiClassRDR, self).__init__(start_rule, session=session)
429
+ super(MultiClassRDR, self).__init__(start_rule)
380
430
  self.mode: MCRDRMode = mode
381
431
 
382
- def classify(self, case: Union[Case, SQLTable]) -> List[Any]:
432
+ def classify(self, case: Union[Case, SQLTable]) -> Set[Any]:
383
433
  evaluated_rule = self.start_rule
384
434
  self.conclusions = []
385
435
  while evaluated_rule:
386
436
  next_rule = evaluated_rule(case)
387
437
  if evaluated_rule.fired:
388
- self.add_conclusion(evaluated_rule)
438
+ self.add_conclusion(evaluated_rule, case)
389
439
  evaluated_rule = next_rule
390
- return self.conclusions
440
+ return make_set(self.conclusions)
391
441
 
392
442
  def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None,
393
- add_extra_conclusions: bool = False) -> List[Union[CaseAttribute, CallableExpression]]:
443
+ add_extra_conclusions: bool = False) -> Set[Union[CaseAttribute, CallableExpression, None]]:
394
444
  """
395
445
  Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
396
446
  or missing by comparing the case with the target category if provided.
@@ -400,72 +450,76 @@ class MultiClassRDR(RDRWithCodeWriter):
400
450
  :param add_extra_conclusions: Whether to add extra conclusions after classification is done.
401
451
  :return: The conclusions that the case belongs to.
402
452
  """
403
- expert = expert if expert else Human(session=self.session)
453
+ expert = expert if expert else Human()
454
+ if case_query.target is None:
455
+ expert.ask_for_conclusion(case_query)
404
456
  if case_query.target is None:
405
- targets = expert.ask_for_conclusion(case_query)
457
+ return self.classify(case_query.case)
458
+ self.update_start_rule(case_query, expert)
406
459
  self.expert_accepted_conclusions = []
407
460
  user_conclusions = []
408
- self.update_start_rule(case_query, expert)
409
461
  self.conclusions = []
410
462
  self.stop_rule_conditions = None
411
463
  evaluated_rule = self.start_rule
464
+ target = case_query.target(case_query.case)
412
465
  while evaluated_rule:
413
466
  next_rule = evaluated_rule(case_query.case)
414
- good_conclusions = make_list(case_query.target) + user_conclusions + self.expert_accepted_conclusions
467
+ rule_conclusion = evaluated_rule.conclusion(case_query.case)
468
+ good_conclusions = make_list(target) + user_conclusions + self.expert_accepted_conclusions
415
469
  good_conclusions = make_set(good_conclusions)
416
470
 
417
471
  if evaluated_rule.fired:
418
- if case_query.target and not make_set(evaluated_rule.conclusion).issubset(good_conclusions):
419
- # if self.case_has_conclusion(case, evaluated_rule.conclusion):
472
+ if target and not make_set(rule_conclusion).issubset(good_conclusions):
420
473
  # Rule fired and conclusion is different from target
421
474
  self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
422
- add_extra_conclusions)
475
+ len(user_conclusions) > 0)
423
476
  else:
424
477
  # Rule fired and target is correct or there is no target to compare
425
- self.add_conclusion(evaluated_rule)
478
+ self.add_conclusion(evaluated_rule, case_query.case)
426
479
 
427
480
  if not next_rule:
428
- if not make_set(case_query.target).intersection(make_set(self.conclusions)):
481
+ if not make_set(target).issubset(make_set(self.conclusions)):
429
482
  # Nothing fired and there is a target that should have been in the conclusions
430
483
  self.add_rule_for_case(case_query, expert)
431
484
  # Have to check all rules again to make sure only this new rule fires
432
485
  next_rule = self.start_rule
433
- elif add_extra_conclusions and not user_conclusions:
486
+ elif add_extra_conclusions:
434
487
  # No more conclusions can be made, ask the expert for extra conclusions if needed.
435
- user_conclusions.extend(self.ask_expert_for_extra_conclusions(expert, case_query.case))
436
- if user_conclusions:
488
+ new_user_conclusions = self.ask_expert_for_extra_rules(expert, case_query)
489
+ user_conclusions.extend(new_user_conclusions)
490
+ if len(new_user_conclusions) > 0:
437
491
  next_rule = self.last_top_rule
492
+ else:
493
+ add_extra_conclusions = False
438
494
  evaluated_rule = next_rule
439
495
  return self.conclusions
440
496
 
441
497
  def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
442
- file, parent_indent: str = ""):
443
- """
444
- Write the rules as source code to a file.
445
-
446
- :
447
- """
498
+ file, parent_indent: str = "", defs_file: Optional[str] = None):
448
499
  if rule == self.start_rule:
449
500
  file.write(f"{parent_indent}conclusions = set()\n")
450
501
  if rule.conditions:
451
- file.write(rule.write_condition_as_source_code(parent_indent))
502
+ if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
503
+ file.write(if_clause)
452
504
  conclusion_indent = parent_indent
453
505
  if hasattr(rule, "refinement") and rule.refinement:
454
- self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ")
506
+ self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
507
+ defs_file=defs_file)
455
508
  conclusion_indent = parent_indent + " " * 4
456
509
  file.write(f"{conclusion_indent}else:\n")
457
510
  file.write(rule.write_conclusion_as_source_code(conclusion_indent))
458
511
 
459
512
  if rule.alternative:
460
- self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
513
+ self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
461
514
 
462
515
  @property
463
516
  def conclusion_type_hint(self) -> str:
464
- return f"Set[{self.conclusion_type.__name__}]"
517
+ return f"Set[{self.conclusion_type[0].__name__}]"
465
518
 
466
519
  def _get_imports(self) -> str:
467
520
  imports = super()._get_imports()
468
521
  imports += "from typing_extensions import Set\n"
522
+ imports += "from ripple_down_rules.utils import make_set\n"
469
523
  return imports
470
524
 
471
525
  def update_start_rule(self, case_query: CaseQuery, expert: Expert):
@@ -498,11 +552,12 @@ class MultiClassRDR(RDRWithCodeWriter):
498
552
  """
499
553
  Stop a wrong conclusion by adding a stopping rule.
500
554
  """
501
- if self.is_same_category_type(evaluated_rule.conclusion, case_query.target) \
502
- and self.is_conflicting_with_target(evaluated_rule.conclusion, case_query.target):
503
- self.stop_conclusion(case_query, expert, evaluated_rule)
504
- elif not self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
505
- self.stop_conclusion(case_query, expert, evaluated_rule)
555
+ rule_conclusion = evaluated_rule.conclusion(case_query.case)
556
+ if is_conflicting(rule_conclusion, case_query.target_value):
557
+ if self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
558
+ return
559
+ else:
560
+ self.stop_conclusion(case_query, expert, evaluated_rule)
506
561
 
507
562
  def stop_conclusion(self, case_query: CaseQuery,
508
563
  expert: Expert, evaluated_rule: MultiClassTopRule):
@@ -521,31 +576,6 @@ class MultiClassRDR(RDRWithCodeWriter):
521
576
  new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
522
577
  self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
523
578
 
524
- @staticmethod
525
- def is_conflicting_with_target(conclusion: Any, target: Any) -> bool:
526
- """
527
- Check if the conclusion is conflicting with the target category.
528
-
529
- :param conclusion: The conclusion to check.
530
- :param target: The target category to compare the conclusion with.
531
- :return: Whether the conclusion is conflicting with the target category.
532
- """
533
- if hasattr(conclusion, "mutually_exclusive") and conclusion.mutually_exclusive:
534
- return True
535
- else:
536
- return not make_set(conclusion).issubset(make_set(target))
537
-
538
- @staticmethod
539
- def is_same_category_type(conclusion: Any, target: Any) -> bool:
540
- """
541
- Check if the conclusion is of the same class as the target category.
542
-
543
- :param conclusion: The conclusion to check.
544
- :param target: The target category to compare the conclusion with.
545
- :return: Whether the conclusion is of the same class as the target category but has a different value.
546
- """
547
- return conclusion.__class__ == target.__class__ and target.__class__ != CaseAttribute
548
-
549
579
  def conclusion_is_correct(self, case_query: CaseQuery,
550
580
  expert: Expert, evaluated_rule: Rule,
551
581
  add_extra_conclusions: bool) -> bool:
@@ -559,10 +589,10 @@ class MultiClassRDR(RDRWithCodeWriter):
559
589
  :return: Whether the conclusion is correct or not.
560
590
  """
561
591
  conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
562
- if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case, evaluated_rule.conclusion,
563
- targets=case_query.target,
592
+ if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
593
+ evaluated_rule.conclusion(case_query.case),
564
594
  current_conclusions=conclusions)):
565
- self.add_conclusion(evaluated_rule)
595
+ self.add_conclusion(evaluated_rule, case_query.case)
566
596
  self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
567
597
  return True
568
598
  return False
@@ -581,38 +611,39 @@ class MultiClassRDR(RDRWithCodeWriter):
581
611
  conditions = expert.ask_for_conditions(case_query)
582
612
  self.add_top_rule(conditions, case_query.target, case_query.case)
583
613
 
584
- def ask_expert_for_extra_conclusions(self, expert: Expert, case: Union[Case, SQLTable]) -> List[Any]:
614
+ def ask_expert_for_extra_rules(self, expert: Expert, case_query: CaseQuery) -> List[Any]:
585
615
  """
586
- Ask the expert for extra conclusions when no more conclusions can be made.
616
+ Ask the expert for extra rules when no more conclusions can be made for a case.
587
617
 
588
618
  :param expert: The expert to ask for extra conclusions.
589
- :param case: The case to ask extra conclusions for.
590
- :return: The extra conclusions that the expert has provided.
619
+ :param case_query: The case query to ask the expert about.
620
+ :return: The extra conclusions for the rules that the expert has provided.
591
621
  """
592
622
  extra_conclusions = []
593
623
  conclusions = list(OrderedSet(self.conclusions))
594
624
  if not expert.use_loaded_answers:
595
625
  print("current conclusions:", conclusions)
596
- extra_conclusions_dict = expert.ask_for_extra_conclusions(case, conclusions)
597
- if extra_conclusions_dict:
598
- for conclusion, conditions in extra_conclusions_dict.items():
599
- self.add_top_rule(conditions, conclusion, case)
600
- extra_conclusions.append(conclusion)
626
+ extra_rules = expert.ask_for_extra_rules(case_query)
627
+ for rule in extra_rules:
628
+ self.add_top_rule(rule[PromptFor.Conditions], rule[PromptFor.Conclusion], case_query.case)
629
+ extra_conclusions.extend(rule[PromptFor.Conclusion](case_query.case))
601
630
  return extra_conclusions
602
631
 
603
- def add_conclusion(self, evaluated_rule: Rule) -> None:
632
+ def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
604
633
  """
605
634
  Add the conclusion of the evaluated rule to the list of conclusions.
606
635
 
607
636
  :param evaluated_rule: The evaluated rule to add the conclusion of.
637
+ :param case: The case to add the conclusion for.
608
638
  """
609
639
  conclusion_types = [type(c) for c in self.conclusions]
610
- if type(evaluated_rule.conclusion) not in conclusion_types:
611
- self.conclusions.extend(make_list(evaluated_rule.conclusion))
640
+ rule_conclusion = evaluated_rule.conclusion(case)
641
+ if type(rule_conclusion) not in conclusion_types:
642
+ self.conclusions.extend(make_list(rule_conclusion))
612
643
  else:
613
- same_type_conclusions = [c for c in self.conclusions if type(c) == type(evaluated_rule.conclusion)]
614
- combined_conclusion = evaluated_rule.conclusion if isinstance(evaluated_rule.conclusion, set) \
615
- else {evaluated_rule.conclusion}
644
+ same_type_conclusions = [c for c in self.conclusions if type(c) == type(rule_conclusion)]
645
+ combined_conclusion = rule_conclusion if isinstance(rule_conclusion, set) \
646
+ else {rule_conclusion}
616
647
  combined_conclusion = copy(combined_conclusion)
617
648
  for c in same_type_conclusions:
618
649
  combined_conclusion.update(c if isinstance(c, set) else make_set(c))
@@ -713,6 +744,7 @@ class GeneralRDR(RippleDownRules):
713
744
  :return: The categories that the case belongs to.
714
745
  """
715
746
  conclusions = {}
747
+ case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
716
748
  case_cp = copy_case(case)
717
749
  while True:
718
750
  new_conclusions = {}
@@ -720,22 +752,24 @@ class GeneralRDR(RippleDownRules):
720
752
  pred_atts = rdr.classify(case_cp)
721
753
  if pred_atts is None:
722
754
  continue
723
- if isinstance(rdr, SingleClassRDR):
755
+ if rdr.type_ is SingleClassRDR:
724
756
  if attribute_name not in conclusions or \
725
757
  (attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
726
758
  conclusions[attribute_name] = pred_atts
727
759
  new_conclusions[attribute_name] = pred_atts
728
760
  else:
729
- pred_atts = make_list(pred_atts)
761
+ pred_atts = make_set(pred_atts)
730
762
  if attribute_name in conclusions:
731
- pred_atts = [p for p in pred_atts if p not in conclusions[attribute_name]]
763
+ pred_atts = {p for p in pred_atts if p not in conclusions[attribute_name]}
732
764
  if len(pred_atts) > 0:
733
765
  new_conclusions[attribute_name] = pred_atts
734
766
  if attribute_name not in conclusions:
735
- conclusions[attribute_name] = []
736
- conclusions[attribute_name].extend(pred_atts)
767
+ conclusions[attribute_name] = set()
768
+ conclusions[attribute_name].update(pred_atts)
737
769
  if attribute_name in new_conclusions:
738
- GeneralRDR.update_case(case_cp, new_conclusions)
770
+ mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
771
+ GeneralRDR.update_case(CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive),
772
+ new_conclusions)
739
773
  if len(new_conclusions) == 0:
740
774
  break
741
775
  return conclusions
@@ -761,76 +795,103 @@ class GeneralRDR(RippleDownRules):
761
795
  case = case_queries[0].case
762
796
  assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
763
797
  " for multiple cases use fit instead")
764
- case_cp = copy(case_queries[0]).case
798
+ original_case_query_cp = copy(case_queries[0])
765
799
  for case_query in case_queries:
766
800
  case_query_cp = copy(case_query)
767
- case_query_cp.case = case_cp
768
- if case_query.target is None:
801
+ case_query_cp.case = original_case_query_cp.case
802
+ if case_query_cp.target is None:
769
803
  conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
770
- target = expert.ask_for_conclusion(case_query)
804
+ self.update_case(case_query_cp, conclusions)
805
+ expert.ask_for_conclusion(case_query_cp)
806
+ if case_query_cp.target is None:
807
+ continue
808
+ case_query.target = case_query_cp.target
771
809
 
772
810
  if case_query.attribute_name not in self.start_rules_dict:
773
811
  conclusions = self.classify(case)
774
- self.update_case(case_cp, conclusions)
812
+ self.update_case(case_query_cp, conclusions)
775
813
 
776
- new_rdr = self.initialize_new_rdr_for_attribute(case_query.attribute_name, case_cp, case_query.target)
814
+ new_rdr = self.initialize_new_rdr_for_attribute(case_query_cp)
777
815
  self.add_rdr(new_rdr, case_query.attribute_name)
778
816
 
779
817
  new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
780
- self.update_case(case_cp, {case_query.attribute_name: new_conclusions})
818
+ self.update_case(case_query_cp, {case_query.attribute_name: new_conclusions})
781
819
  else:
782
820
  for rdr_attribute_name, rdr in self.start_rules_dict.items():
783
821
  if case_query.attribute_name != rdr_attribute_name:
784
- conclusions = rdr.classify(case_cp)
822
+ conclusions = rdr.classify(case_query_cp.case)
785
823
  else:
786
824
  conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
787
825
  **kwargs)
788
826
  if conclusions is not None or (is_iterable(conclusions) and len(conclusions) > 0):
789
827
  conclusions = {rdr_attribute_name: conclusions}
790
- self.update_case(case_cp, conclusions)
828
+ case_query_cp.mutually_exclusive = True if isinstance(rdr, SingleClassRDR) else False
829
+ self.update_case(case_query_cp, conclusions)
830
+ case_query.conditions = case_query_cp.conditions
791
831
 
792
832
  return self.classify(case)
793
833
 
794
834
  @staticmethod
795
- def initialize_new_rdr_for_attribute(attribute_name: str, case: Union[Case, SQLTable], target: Any):
835
+ def initialize_new_rdr_for_attribute(case_query: CaseQuery):
796
836
  """
797
837
  Initialize the appropriate RDR type for the target.
798
838
  """
799
- attribute = getattr(case, attribute_name) if hasattr(case, attribute_name) else target
839
+ if case_query.mutually_exclusive is not None:
840
+ return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive \
841
+ else MultiClassRDR()
842
+ if case_query.attribute_type in [list, set]:
843
+ return MultiClassRDR()
844
+ attribute = getattr(case_query.case, case_query.attribute_name) \
845
+ if hasattr(case_query.case, case_query.attribute_name) else case_query.target(case_query.case)
800
846
  if isinstance(attribute, CaseAttribute):
801
- return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
847
+ return SingleClassRDR(default_conclusion=case_query.default_value) if attribute.mutually_exclusive \
848
+ else MultiClassRDR()
802
849
  else:
803
- return MultiClassRDR() if is_iterable(attribute) or (attribute is None) else SingleClassRDR()
850
+ return MultiClassRDR() if is_iterable(attribute) or (attribute is None) \
851
+ else SingleClassRDR(default_conclusion=case_query.default_value)
804
852
 
805
853
  @staticmethod
806
- def update_case(case: Union[Case, SQLTable], conclusions: Dict[str, Any]):
854
+ def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
807
855
  """
808
856
  Update the case with the conclusions.
809
857
 
810
- :param case: The case to update.
858
+ :param case_query: The case query that contains the case to update.
811
859
  :param conclusions: The conclusions to update the case with.
812
860
  """
813
861
  if not conclusions:
814
862
  return
815
863
  if len(conclusions) == 0:
816
864
  return
817
- if isinstance(case, SQLTable):
865
+ if isinstance(case_query.original_case, SQLTable) or is_dataclass(case_query.original_case):
818
866
  for conclusion_name, conclusion in conclusions.items():
819
- hint, origin, args = get_hint_for_attribute(conclusion_name, case)
820
- attribute = getattr(case, conclusion_name)
821
- if isinstance(attribute, set) or origin in {Set, set}:
822
- attribute = set() if attribute is None else attribute
867
+ attribute = getattr(case_query.case, conclusion_name)
868
+ if conclusion_name == case_query.attribute_name:
869
+ attribute_type = case_query.attribute_type
870
+ else:
871
+ attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
872
+ if isinstance(attribute, set):
823
873
  for c in conclusion:
824
874
  attribute.update(make_set(c))
825
- elif isinstance(attribute, list) or origin in {list, List}:
875
+ elif isinstance(attribute, list):
876
+ attribute.extend(conclusion)
877
+ elif any(at in {List, list} for at in attribute_type):
826
878
  attribute = [] if attribute is None else attribute
827
879
  attribute.extend(conclusion)
828
- elif (not is_iterable(conclusion) or (len(conclusion) == 1)) and hint == type(conclusion):
829
- setattr(case, conclusion_name, conclusion)
880
+ elif any(at in {Set, set} for at in attribute_type):
881
+ attribute = set() if attribute is None else attribute
882
+ for c in conclusion:
883
+ attribute.update(make_set(c))
884
+ elif is_iterable(conclusion) and len(conclusion) == 1 \
885
+ and any(at is type(list(conclusion)[0]) for at in attribute_type):
886
+ setattr(case_query.case, conclusion_name, list(conclusion)[0])
887
+ elif not is_iterable(conclusion) and any(at is type(conclusion) for at in attribute_type):
888
+ setattr(case_query.case, conclusion_name, conclusion)
830
889
  else:
831
- raise ValueError(f"Cannot add multiple conclusions to attribute {conclusion_name}")
890
+ raise ValueError(f"Unknown type or type mismatch for attribute {conclusion_name} with type "
891
+ f"{case_query.attribute_type} with conclusion "
892
+ f"{conclusion} of type {type(conclusion)}")
832
893
  else:
833
- case.update(conclusions)
894
+ case_query.case.update(conclusions)
834
895
 
835
896
  def _to_json(self) -> Dict[str, Any]:
836
897
  return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
@@ -845,14 +906,16 @@ class GeneralRDR(RippleDownRules):
845
906
  start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
846
907
  return cls(start_rules_dict)
847
908
 
848
- def write_to_python_file(self, file_path: str):
909
+ def write_to_python_file(self, file_path: str, postfix: str = "") -> None:
849
910
  """
850
911
  Write the tree of rules as source code to a file.
851
912
 
852
913
  :param file_path: The path to the file to write the source code to.
914
+ :param postfix: The postfix to add to the file name.
853
915
  """
916
+ self.generated_python_file_name = self._default_generated_python_file_name + postfix
854
917
  for rdr in self.start_rules_dict.values():
855
- rdr.write_to_python_file(file_path)
918
+ rdr.write_to_python_file(file_path, postfix=f"_of_grdr{postfix}")
856
919
  func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
857
920
  with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
858
921
  f.write(self._get_imports(file_path) + "\n\n")
@@ -875,15 +938,29 @@ class GeneralRDR(RippleDownRules):
875
938
  else:
876
939
  return type(self.start_rule.corner_case)
877
940
 
878
- def get_rdr_classifier_from_python_file(self, file_path: str):
941
+ def get_rdr_classifier_from_python_file(self, file_path: str) -> Callable[[Any], Any]:
879
942
  """
880
943
  :param file_path: The path to the file that contains the RDR classifier function.
944
+ :param postfix: The postfix to add to the file name.
881
945
  :return: The module that contains the rdr classifier function.
882
946
  """
883
947
  return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
884
948
 
885
949
  @property
886
950
  def generated_python_file_name(self) -> str:
951
+ if self._generated_python_file_name is None:
952
+ self._generated_python_file_name = self._default_generated_python_file_name
953
+ return self._generated_python_file_name
954
+
955
+ @generated_python_file_name.setter
956
+ def generated_python_file_name(self, value: str):
957
+ self._generated_python_file_name = value
958
+
959
+ @property
960
+ def _default_generated_python_file_name(self) -> str:
961
+ """
962
+ :return: The default generated python file name.
963
+ """
887
964
  return f"{self.start_rule.corner_case._name.lower()}_rdr"
888
965
 
889
966
  @property
@@ -891,17 +968,25 @@ class GeneralRDR(RippleDownRules):
891
968
  return f"List[Union[{', '.join([rdr.conclusion_type_hint for rdr in self.start_rules_dict.values()])}]]"
892
969
 
893
970
  def _get_imports(self, file_path: str) -> str:
971
+ """
972
+ Get the imports needed for the generated python file.
973
+
974
+ :param file_path: The path to the file that contains the RDR classifier function.
975
+ :return: The imports needed for the generated python file.
976
+ """
894
977
  imports = ""
895
978
  # add type hints
896
979
  imports += f"from typing_extensions import List, Union, Set\n"
897
980
  # import rdr type
898
981
  imports += f"from ripple_down_rules.rdr import GeneralRDR\n"
899
982
  # add case type
900
- imports += f"from ripple_down_rules.datastructures import Case, create_case\n"
983
+ imports += f"from ripple_down_rules.datastructures.case import Case, create_case\n"
901
984
  imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
902
985
  # add conclusion type imports
903
986
  for rdr in self.start_rules_dict.values():
904
- imports += f"from {rdr.conclusion_type.__module__} import {rdr.conclusion_type.__name__}\n"
987
+ for conclusion_type in rdr.conclusion_type:
988
+ if conclusion_type.__module__ != "builtins":
989
+ imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
905
990
  # add rdr python generated functions.
906
991
  for rdr_key, rdr in self.start_rules_dict.items():
907
992
  imports += (f"from {file_path.strip('./')}"