ripple-down-rules 0.0.15__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ripple_down_rules/rdr.py CHANGED
@@ -8,13 +8,13 @@ from types import ModuleType
8
8
  from matplotlib import pyplot as plt
9
9
  from ordered_set import OrderedSet
10
10
  from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
11
- from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable
11
+ from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
12
12
 
13
13
  from .datastructures import Case, MCRDRMode, CallableExpression, CaseAttribute, CaseQuery
14
14
  from .experts import Expert, Human
15
15
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
16
- from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
17
- get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_full_class_name, get_type_from_string
16
+ from .utils import draw_tree, make_set, copy_case, \
17
+ get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string
18
18
 
19
19
 
20
20
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -80,7 +80,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
80
80
  :param kwargs_for_fit_case: The keyword arguments to pass to the fit_case method.
81
81
  """
82
82
  cases = [case_query.case for case_query in case_queries]
83
- targets = [case.target for case in case_queries]
83
+ targets = [{case_query.attribute_name: case_query.target} for case_query in case_queries]
84
84
  if animate_tree:
85
85
  plt.ion()
86
86
  i = 0
@@ -91,11 +91,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
91
91
  if not targets:
92
92
  targets = [None] * len(cases)
93
93
  for case_query in case_queries:
94
- case = case_query.case
95
- target = case_query.target
96
- if not target:
97
- conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
98
- target = expert.ask_for_conclusion(case_query, conclusions)
94
+ target = {case_query.attribute_name: case_query.target}
99
95
  pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
100
96
  match = self.is_matching(pred_cat, target)
101
97
  if not match:
@@ -105,8 +101,9 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
105
101
  num_rules = self.start_rule.size
106
102
  self.update_figures()
107
103
  i += 1
108
- all_predictions = [1 if self.is_matching(self.classify(case), target) else 0
109
- for case, target in zip(cases, targets)]
104
+ all_predictions = [1 if self.is_matching(self.classify(case_query.case), {case_query.attribute_name:
105
+ case_query.target}) else 0
106
+ for case_query in case_queries]
110
107
  all_pred = sum(all_predictions)
111
108
  print(f"Accuracy: {all_pred}/{len(targets)}")
112
109
  all_predicted = targets and all_pred == len(targets)
@@ -129,9 +126,27 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
129
126
  """
130
127
  pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
131
128
  target = target if is_iterable(target) else [target]
132
- recall = [not yi or (yi in pred_cat) for yi in target]
133
- target_types = [type(yi) for yi in target]
134
- precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
129
+ recall = []
130
+ precision = []
131
+ if isinstance(pred_cat, dict):
132
+ for pred_key, pred_value in pred_cat.items():
133
+ if pred_key not in target:
134
+ continue
135
+ precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
136
+ for target_key, target_value in target.items():
137
+ if target_key not in pred_cat:
138
+ recall.append(False)
139
+ continue
140
+ if is_iterable(target_value):
141
+ recall.extend([v in pred_cat[target_key] for v in target_value])
142
+ else:
143
+ recall.append(target_value == pred_cat[target_key])
144
+ else:
145
+ if isinstance(target, dict):
146
+ target = list(target.values())
147
+ recall = [not yi or (yi in pred_cat) for yi in target]
148
+ target_types = [type(yi) for yi in target]
149
+ precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
135
150
  return precision, recall
136
151
 
137
152
  def is_matching(self, pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> bool:
@@ -158,22 +173,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
158
173
  draw_tree(self.start_rule, self.fig)
159
174
 
160
175
  @staticmethod
161
- def case_has_conclusion(case: Union[Case, SQLTable], conclusion_type: Type) -> bool:
176
+ def case_has_conclusion(case: Union[Case, SQLTable], conclusion_name: str) -> bool:
162
177
  """
163
178
  Check if the case has a conclusion.
164
179
 
165
180
  :param case: The case to check.
166
- :param conclusion_type: The target category type to compare the case with.
181
+ :param conclusion_name: The target category name to compare the case with.
167
182
  :return: Whether the case has a conclusion or not.
168
183
  """
169
- if isinstance(case, SQLTable):
170
- prop_name, prop_value = get_attribute_by_type(case, conclusion_type)
171
- if hasattr(prop_value, "__iter__") and not isinstance(prop_value, str):
172
- return len(prop_value) > 0
173
- else:
174
- return prop_value is not None
175
- else:
176
- return conclusion_type in case
184
+ return hasattr(case, conclusion_name) and getattr(case, conclusion_name) is not None
177
185
 
178
186
 
179
187
  class RDRWithCodeWriter(RippleDownRules, ABC):
@@ -200,7 +208,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
200
208
  f.write(self._get_imports() + "\n\n")
201
209
  f.write(func_def)
202
210
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
203
- f"{' ' * 4} case = create_case(case, recursion_idx=3)\n""")
211
+ f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
204
212
  self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)
205
213
 
206
214
  @property
@@ -221,6 +229,11 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
221
229
  if self.conclusion_type.__module__ != "builtins":
222
230
  imports += f"from {self.conclusion_type.__module__} import {self.conclusion_type.__name__}\n"
223
231
  imports += "from ripple_down_rules.datastructures import Case, create_case\n"
232
+ for rule in [self.start_rule] + list(self.start_rule.descendants):
233
+ if rule.conditions:
234
+ if rule.conditions.scope is not None and len(rule.conditions.scope) > 0:
235
+ for k, v in rule.conditions.scope.items():
236
+ imports += f"from {v.__module__} import {v.__name__}\n"
224
237
  return imports
225
238
 
226
239
  def get_rdr_classifier_from_python_file(self, package_name) -> Callable[[Any], Any]:
@@ -232,11 +245,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
232
245
 
233
246
  @property
234
247
  def generated_python_file_name(self) -> str:
235
- return f"{self.conclusion_type.__name__.lower()}_{self.__class__.__name__}"
236
-
237
- @property
238
- def python_file_name(self):
239
- return f"{self.start_rule.conclusion.__name__.lower()}_rdr"
248
+ return f"{self.start_rule.corner_case._name.lower()}_{self.attribute_name}_rdr"
240
249
 
241
250
  @property
242
251
  def case_type(self) -> Type:
@@ -244,7 +253,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
244
253
  :return: The type of the case (input) to the RDR classifier.
245
254
  """
246
255
  if isinstance(self.start_rule.corner_case, Case):
247
- return self.start_rule.corner_case._type
256
+ return self.start_rule.corner_case._obj_type
248
257
  else:
249
258
  return type(self.start_rule.corner_case)
250
259
 
@@ -260,6 +269,13 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
260
269
  return type(list(self.start_rule.conclusion)[0])
261
270
  return type(self.start_rule.conclusion)
262
271
 
272
+ @property
273
+ def attribute_name(self) -> str:
274
+ """
275
+ :return: The name of the attribute that the classifier is classifying.
276
+ """
277
+ return self.start_rule.conclusion_name
278
+
263
279
 
264
280
  class SingleClassRDR(RDRWithCodeWriter):
265
281
 
@@ -274,23 +290,20 @@ class SingleClassRDR(RDRWithCodeWriter):
274
290
  :return: The category that the case belongs to.
275
291
  """
276
292
  expert = expert if expert else Human(session=self.session)
277
- case, attribute = case_query.case, case_query.attribute
278
293
  if case_query.target is None:
279
294
  target = expert.ask_for_conclusion(case_query)
280
- else:
281
- target = case_query.target
282
-
283
295
  if not self.start_rule:
284
- conditions = expert.ask_for_conditions(case, [target])
285
- self.start_rule = SingleClassRule(conditions, target, corner_case=case)
296
+ conditions = expert.ask_for_conditions(case_query)
297
+ self.start_rule = SingleClassRule(conditions, case_query.target, corner_case=case_query.case,
298
+ conclusion_name=case_query.attribute_name)
286
299
 
287
- pred = self.evaluate(case)
300
+ pred = self.evaluate(case_query.case)
288
301
 
289
- if pred.conclusion != target:
290
- conditions = expert.ask_for_conditions(case, [target], pred)
291
- pred.fit_rule(case, target, conditions=conditions)
302
+ if pred.conclusion != case_query.target:
303
+ conditions = expert.ask_for_conditions(case_query, pred)
304
+ pred.fit_rule(case_query.case, case_query.target, conditions=conditions)
292
305
 
293
- return self.classify(case)
306
+ return self.classify(case_query.case)
294
307
 
295
308
  def classify(self, case: Case) -> Optional[CaseAttribute]:
296
309
  """
@@ -388,44 +401,41 @@ class MultiClassRDR(RDRWithCodeWriter):
388
401
  :return: The conclusions that the case belongs to.
389
402
  """
390
403
  expert = expert if expert else Human(session=self.session)
391
- case = case_query.case
392
404
  if case_query.target is None:
393
- targets = [expert.ask_for_conclusion(case_query)]
394
- else:
395
- targets = [case_query.target]
405
+ targets = expert.ask_for_conclusion(case_query)
396
406
  self.expert_accepted_conclusions = []
397
407
  user_conclusions = []
398
- for target in targets:
399
- self.update_start_rule(case, target, expert)
400
- self.conclusions = []
401
- self.stop_rule_conditions = None
402
- evaluated_rule = self.start_rule
403
- while evaluated_rule:
404
- next_rule = evaluated_rule(case)
405
- good_conclusions = targets + user_conclusions + self.expert_accepted_conclusions
406
-
407
- if evaluated_rule.fired:
408
- if target and evaluated_rule.conclusion not in good_conclusions:
409
- # if self.case_has_conclusion(case, evaluated_rule.conclusion):
410
- # Rule fired and conclusion is different from target
411
- self.stop_wrong_conclusion_else_add_it(case, target, expert, evaluated_rule,
412
- add_extra_conclusions)
413
- else:
414
- # Rule fired and target is correct or there is no target to compare
415
- self.add_conclusion(evaluated_rule)
416
-
417
- if not next_rule:
418
- if not make_set(target).intersection(make_set(self.conclusions)):
419
- # Nothing fired and there is a target that should have been in the conclusions
420
- self.add_rule_for_case(case, target, expert)
421
- # Have to check all rules again to make sure only this new rule fires
422
- next_rule = self.start_rule
423
- elif add_extra_conclusions and not user_conclusions:
424
- # No more conclusions can be made, ask the expert for extra conclusions if needed.
425
- user_conclusions.extend(self.ask_expert_for_extra_conclusions(expert, case))
426
- if user_conclusions:
427
- next_rule = self.last_top_rule
428
- evaluated_rule = next_rule
408
+ self.update_start_rule(case_query, expert)
409
+ self.conclusions = []
410
+ self.stop_rule_conditions = None
411
+ evaluated_rule = self.start_rule
412
+ while evaluated_rule:
413
+ next_rule = evaluated_rule(case_query.case)
414
+ good_conclusions = make_list(case_query.target) + user_conclusions + self.expert_accepted_conclusions
415
+ good_conclusions = make_set(good_conclusions)
416
+
417
+ if evaluated_rule.fired:
418
+ if case_query.target and not make_set(evaluated_rule.conclusion).issubset(good_conclusions):
419
+ # if self.case_has_conclusion(case, evaluated_rule.conclusion):
420
+ # Rule fired and conclusion is different from target
421
+ self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
422
+ add_extra_conclusions)
423
+ else:
424
+ # Rule fired and target is correct or there is no target to compare
425
+ self.add_conclusion(evaluated_rule)
426
+
427
+ if not next_rule:
428
+ if not make_set(case_query.target).intersection(make_set(self.conclusions)):
429
+ # Nothing fired and there is a target that should have been in the conclusions
430
+ self.add_rule_for_case(case_query, expert)
431
+ # Have to check all rules again to make sure only this new rule fires
432
+ next_rule = self.start_rule
433
+ elif add_extra_conclusions and not user_conclusions:
434
+ # No more conclusions can be made, ask the expert for extra conclusions if needed.
435
+ user_conclusions.extend(self.ask_expert_for_extra_conclusions(expert, case_query.case))
436
+ if user_conclusions:
437
+ next_rule = self.last_top_rule
438
+ evaluated_rule = next_rule
429
439
  return self.conclusions
430
440
 
431
441
  def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
@@ -458,19 +468,19 @@ class MultiClassRDR(RDRWithCodeWriter):
458
468
  imports += "from typing_extensions import Set\n"
459
469
  return imports
460
470
 
461
- def update_start_rule(self, case: Union[Case, SQLTable], target: Any, expert: Expert):
471
+ def update_start_rule(self, case_query: CaseQuery, expert: Expert):
462
472
  """
463
473
  Update the starting rule of the classifier.
464
474
 
465
- :param case: The case to classify.
466
- :param target: The target category to compare the case with.
475
+ :param case_query: The case query to update the starting rule with.
467
476
  :param expert: The expert to ask for differentiating features as new rule conditions.
468
477
  """
469
478
  if not self.start_rule.conditions:
470
- conditions = expert.ask_for_conditions(case, target)
479
+ conditions = expert.ask_for_conditions(case_query)
471
480
  self.start_rule.conditions = conditions
472
- self.start_rule.conclusion = target
473
- self.start_rule.corner_case = case
481
+ self.start_rule.conclusion = case_query.target
482
+ self.start_rule.corner_case = case_query.case
483
+ self.start_rule.conclusion_name = case_query.attribute_name
474
484
 
475
485
  @property
476
486
  def last_top_rule(self) -> Optional[MultiClassTopRule]:
@@ -482,35 +492,34 @@ class MultiClassRDR(RDRWithCodeWriter):
482
492
  else:
483
493
  return self.start_rule.furthest_alternative[-1]
484
494
 
485
- def stop_wrong_conclusion_else_add_it(self, case: Union[Case, SQLTable], target: Any, expert: Expert,
495
+ def stop_wrong_conclusion_else_add_it(self, case_query: CaseQuery, expert: Expert,
486
496
  evaluated_rule: MultiClassTopRule,
487
497
  add_extra_conclusions: bool):
488
498
  """
489
499
  Stop a wrong conclusion by adding a stopping rule.
490
500
  """
491
- if self.is_same_category_type(evaluated_rule.conclusion, target) \
492
- and self.is_conflicting_with_target(evaluated_rule.conclusion, target):
493
- self.stop_conclusion(case, target, expert, evaluated_rule)
494
- elif not self.conclusion_is_correct(case, target, expert, evaluated_rule, add_extra_conclusions):
495
- self.stop_conclusion(case, target, expert, evaluated_rule)
501
+ if self.is_same_category_type(evaluated_rule.conclusion, case_query.target) \
502
+ and self.is_conflicting_with_target(evaluated_rule.conclusion, case_query.target):
503
+ self.stop_conclusion(case_query, expert, evaluated_rule)
504
+ elif not self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
505
+ self.stop_conclusion(case_query, expert, evaluated_rule)
496
506
 
497
- def stop_conclusion(self, case: Union[Case, SQLTable], target: Any,
507
+ def stop_conclusion(self, case_query: CaseQuery,
498
508
  expert: Expert, evaluated_rule: MultiClassTopRule):
499
509
  """
500
510
  Stop a conclusion by adding a stopping rule.
501
511
 
502
- :param case: The case to classify.
503
- :param target: The target category to compare the case with.
512
+ :param case_query: The case query to stop the conclusion for.
504
513
  :param expert: The expert to ask for differentiating features as new rule conditions.
505
514
  :param evaluated_rule: The evaluated rule to ask the expert about.
506
515
  """
507
- conditions = expert.ask_for_conditions(case, target, evaluated_rule)
508
- evaluated_rule.fit_rule(case, target, conditions=conditions)
516
+ conditions = expert.ask_for_conditions(case_query, evaluated_rule)
517
+ evaluated_rule.fit_rule(case_query.case, case_query.target, conditions=conditions)
509
518
  if self.mode == MCRDRMode.StopPlusRule:
510
519
  self.stop_rule_conditions = conditions
511
520
  if self.mode == MCRDRMode.StopPlusRuleCombined:
512
521
  new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
513
- self.add_top_rule(new_top_rule_conditions, target, case)
522
+ self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
514
523
 
515
524
  @staticmethod
516
525
  def is_conflicting_with_target(conclusion: Any, target: Any) -> bool:
@@ -537,37 +546,40 @@ class MultiClassRDR(RDRWithCodeWriter):
537
546
  """
538
547
  return conclusion.__class__ == target.__class__ and target.__class__ != CaseAttribute
539
548
 
540
- def conclusion_is_correct(self, case: Union[Case, SQLTable], target: Any, expert: Expert, evaluated_rule: Rule,
549
+ def conclusion_is_correct(self, case_query: CaseQuery,
550
+ expert: Expert, evaluated_rule: Rule,
541
551
  add_extra_conclusions: bool) -> bool:
542
552
  """
543
553
  Ask the expert if the conclusion is correct, and add it to the conclusions if it is.
544
554
 
545
- :param case: The case to classify.
546
- :param target: The target category to compare the case with.
555
+ :param case_query: The case query to ask the expert about.
547
556
  :param expert: The expert to ask for differentiating features as new rule conditions.
548
557
  :param evaluated_rule: The evaluated rule to ask the expert about.
549
558
  :param add_extra_conclusions: Whether adding extra conclusions after classification is allowed.
550
559
  :return: Whether the conclusion is correct or not.
551
560
  """
552
- conclusions = list(OrderedSet(self.conclusions))
553
- if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case, evaluated_rule.conclusion,
554
- targets=target,
561
+ conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
562
+ if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case, evaluated_rule.conclusion,
563
+ targets=case_query.target,
555
564
  current_conclusions=conclusions)):
556
565
  self.add_conclusion(evaluated_rule)
557
566
  self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
558
567
  return True
559
568
  return False
560
569
 
561
- def add_rule_for_case(self, case: Union[Case, SQLTable], target: Any, expert: Expert):
570
+ def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
562
571
  """
563
572
  Add a rule for a case that has not been classified with any conclusion.
573
+
574
+ :param case_query: The case query to add the rule for.
575
+ :param expert: The expert to ask for differentiating features as new rule conditions.
564
576
  """
565
577
  if self.stop_rule_conditions and self.mode == MCRDRMode.StopPlusRule:
566
578
  conditions = self.stop_rule_conditions
567
579
  self.stop_rule_conditions = None
568
580
  else:
569
- conditions = expert.ask_for_conditions(case, target)
570
- self.add_top_rule(conditions, target, case)
581
+ conditions = expert.ask_for_conditions(case_query)
582
+ self.add_top_rule(conditions, case_query.target, case_query.case)
571
583
 
572
584
  def ask_expert_for_extra_conclusions(self, expert: Expert, case: Union[Case, SQLTable]) -> List[Any]:
573
585
  """
@@ -641,20 +653,31 @@ class GeneralRDR(RippleDownRules):
641
653
  gets called when the final rule fires.
642
654
  """
643
655
 
644
- def __init__(self, category_rdr_map: Optional[Dict[Type, Union[SingleClassRDR, MultiClassRDR]]] = None):
656
+ def __init__(self, category_rdr_map: Optional[Dict[str, Union[SingleClassRDR, MultiClassRDR]]] = None):
645
657
  """
646
- :param category_rdr_map: A map of categories to ripple down rules classifiers,
658
+ :param category_rdr_map: A map of case attribute names to ripple down rules classifiers,
647
659
  where each category is a parent category that has a set of mutually exclusive (in case of SCRDR) child
648
- categories, e.g. {Species: SCRDR, Habitat: MCRDR}, where Species and Habitat are parent categories and SCRDR
649
- and MCRDR are SingleClass and MultiClass ripple down rules classifiers. Species can have child categories like
650
- Mammal, Bird, Fish, etc. which are mutually exclusive, and Habitat can have child categories like
651
- Land, Water, Air, etc, which are not mutually exclusive due to some animals living more than one habitat.
660
+ categories, e.g. {'species': SCRDR, 'habitats': MCRDR}, where 'species' and 'habitats' are attribute names
661
+ for a case of type Animal, while SCRDR and MCRDR are SingleClass and MultiClass ripple down rules classifiers.
662
+ Species can have values like Mammal, Bird, Fish, etc. which are mutually exclusive, while Habitat can have
663
+ values like Land, Water, Air, etc., which are not mutually exclusive due to some animals living more than one
664
+ habitat.
652
665
  """
653
- self.start_rules_dict: Dict[Type, Union[SingleClassRDR, MultiClassRDR]] \
666
+ self.start_rules_dict: Dict[str, Union[SingleClassRDR, MultiClassRDR]] \
654
667
  = category_rdr_map if category_rdr_map else {}
655
668
  super(GeneralRDR, self).__init__()
656
669
  self.all_figs: List[plt.Figure] = [sr.fig for sr in self.start_rules_dict.values()]
657
670
 
671
+ def add_rdr(self, rdr: Union[SingleClassRDR, MultiClassRDR], attribute_name: Optional[str] = None):
672
+ """
673
+ Add a ripple down rules classifier to the map of classifiers.
674
+
675
+ :param rdr: The ripple down rules classifier to add.
676
+ :param attribute_name: The name of the attribute that the classifier is classifying.
677
+ """
678
+ attribute_name = attribute_name if attribute_name else rdr.attribute_name
679
+ self.start_rules_dict[attribute_name] = rdr
680
+
658
681
  @property
659
682
  def start_rule(self) -> Optional[Union[SingleClassRule, MultiClassTopRule]]:
660
683
  return self.start_rules[0] if self.start_rules_dict else None
@@ -662,7 +685,7 @@ class GeneralRDR(RippleDownRules):
662
685
  @start_rule.setter
663
686
  def start_rule(self, value: Union[SingleClassRDR, MultiClassRDR]):
664
687
  if value:
665
- self.start_rules_dict[value.conclusion_type] = value
688
+ self.start_rules_dict[value.attribute_name] = value
666
689
 
667
690
  @property
668
691
  def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
@@ -679,8 +702,8 @@ class GeneralRDR(RippleDownRules):
679
702
  return self._classify(self.start_rules_dict, case)
680
703
 
681
704
  @staticmethod
682
- def _classify(classifiers_dict: Dict[Type, Union[ModuleType, RippleDownRules]],
683
- case: Union[Case, SQLTable]) -> Optional[List[Any]]:
705
+ def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
706
+ case: Union[Case, SQLTable]) -> Optional[Dict[str, Any]]:
684
707
  """
685
708
  Classify a case by going through all classifiers and adding the categories that are classified,
686
709
  and then restarting the classification until no more categories can be added.
@@ -689,21 +712,31 @@ class GeneralRDR(RippleDownRules):
689
712
  :param case: The case to classify.
690
713
  :return: The categories that the case belongs to.
691
714
  """
692
- conclusions = []
715
+ conclusions = {}
693
716
  case_cp = copy_case(case)
694
717
  while True:
695
- added_attributes = False
696
- for cat_type, rdr in classifiers_dict.items():
697
- if GeneralRDR.case_has_conclusion(case_cp, cat_type):
698
- continue
718
+ new_conclusions = {}
719
+ for attribute_name, rdr in classifiers_dict.items():
699
720
  pred_atts = rdr.classify(case_cp)
700
- if pred_atts:
721
+ if pred_atts is None:
722
+ continue
723
+ if isinstance(rdr, SingleClassRDR):
724
+ if attribute_name not in conclusions or \
725
+ (attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
726
+ conclusions[attribute_name] = pred_atts
727
+ new_conclusions[attribute_name] = pred_atts
728
+ else:
701
729
  pred_atts = make_list(pred_atts)
702
- pred_atts = [p for p in pred_atts if p not in conclusions]
703
- added_attributes = True
704
- conclusions.extend(pred_atts)
705
- GeneralRDR.update_case(case_cp, pred_atts)
706
- if not added_attributes:
730
+ if attribute_name in conclusions:
731
+ pred_atts = [p for p in pred_atts if p not in conclusions[attribute_name]]
732
+ if len(pred_atts) > 0:
733
+ new_conclusions[attribute_name] = pred_atts
734
+ if attribute_name not in conclusions:
735
+ conclusions[attribute_name] = []
736
+ conclusions[attribute_name].extend(pred_atts)
737
+ if attribute_name in new_conclusions:
738
+ GeneralRDR.update_case(case_cp, new_conclusions)
739
+ if len(new_conclusions) == 0:
707
740
  break
708
741
  return conclusions
709
742
 
@@ -728,103 +761,79 @@ class GeneralRDR(RippleDownRules):
728
761
  case = case_queries[0].case
729
762
  assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
730
763
  " for multiple cases use fit instead")
731
- case_query_cp = copy(case_queries[0])
732
- case_cp = case_query_cp.case
764
+ case_cp = copy(case_queries[0]).case
733
765
  for case_query in case_queries:
734
- target = case_query.target
735
- if not target:
766
+ case_query_cp = copy(case_query)
767
+ case_query_cp.case = case_cp
768
+ if case_query.target is None:
769
+ conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
736
770
  target = expert.ask_for_conclusion(case_query)
737
- case_query_cp = CaseQuery(case_cp, attribute_name=case_query.attribute_name, target=target)
738
- if is_iterable(target) and not isinstance(target, CaseAttribute):
739
- target_type = type(make_list(target)[0])
740
- assert all([type(t) is target_type for t in target]), ("All targets of a case query must be of the same"
741
- " type")
742
- else:
743
- target_type = type(target)
744
- if target_type not in self.start_rules_dict:
771
+
772
+ if case_query.attribute_name not in self.start_rules_dict:
745
773
  conclusions = self.classify(case)
746
774
  self.update_case(case_cp, conclusions)
747
- new_rdr = self.initialize_new_rdr_for_attribute(target, case_cp)
775
+
776
+ new_rdr = self.initialize_new_rdr_for_attribute(case_query.attribute_name, case_cp, case_query.target)
777
+ self.add_rdr(new_rdr, case_query.attribute_name)
778
+
748
779
  new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
749
- self.start_rules_dict[target_type] = new_rdr
750
- self.update_case(case_cp, new_conclusions, target_type)
751
- elif not self.case_has_conclusion(case_cp, target_type):
752
- for rdr_type, rdr in self.start_rules_dict.items():
753
- if target_type is not rdr_type:
780
+ self.update_case(case_cp, {case_query.attribute_name: new_conclusions})
781
+ else:
782
+ for rdr_attribute_name, rdr in self.start_rules_dict.items():
783
+ if case_query.attribute_name != rdr_attribute_name:
754
784
  conclusions = rdr.classify(case_cp)
755
785
  else:
756
- conclusions = self.start_rules_dict[target_type].fit_case(case_query_cp,
757
- expert, **kwargs)
758
- self.update_case(case_cp, conclusions, rdr_type)
786
+ conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
787
+ **kwargs)
788
+ if conclusions is not None or (is_iterable(conclusions) and len(conclusions) > 0):
789
+ conclusions = {rdr_attribute_name: conclusions}
790
+ self.update_case(case_cp, conclusions)
759
791
 
760
792
  return self.classify(case)
761
793
 
762
794
  @staticmethod
763
- def initialize_new_rdr_for_attribute(attribute: Any, case: Union[Case, SQLTable]):
795
+ def initialize_new_rdr_for_attribute(attribute_name: str, case: Union[Case, SQLTable], target: Any):
764
796
  """
765
797
  Initialize the appropriate RDR type for the target.
766
798
  """
767
- if isinstance(case, SQLTable):
768
- prop = get_attribute_by_type(case, type(attribute))
769
- if hasattr(prop, "__iter__") and not isinstance(prop, str):
770
- return MultiClassRDR()
771
- else:
772
- return SingleClassRDR()
773
- elif isinstance(attribute, CaseAttribute):
799
+ attribute = getattr(case, attribute_name) if hasattr(case, attribute_name) else target
800
+ if isinstance(attribute, CaseAttribute):
774
801
  return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
775
802
  else:
776
- return MultiClassRDR() if is_iterable(attribute) else SingleClassRDR()
803
+ return MultiClassRDR() if is_iterable(attribute) or (attribute is None) else SingleClassRDR()
777
804
 
778
805
  @staticmethod
779
- def update_case(case: Union[Case, SQLTable],
780
- conclusions: List[Any], attribute_type: Optional[Any] = None):
806
+ def update_case(case: Union[Case, SQLTable], conclusions: Dict[str, Any]):
781
807
  """
782
808
  Update the case with the conclusions.
783
809
 
784
810
  :param case: The case to update.
785
811
  :param conclusions: The conclusions to update the case with.
786
- :param attribute_type: The type of the attribute to update.
787
812
  """
788
813
  if not conclusions:
789
814
  return
790
- conclusions = [conclusions] if not isinstance(conclusions, list) else list(conclusions)
791
815
  if len(conclusions) == 0:
792
816
  return
793
817
  if isinstance(case, SQLTable):
794
- conclusions_type = type(conclusions[0]) if not attribute_type else attribute_type
795
- attr_name, attribute = get_attribute_by_type(case, conclusions_type)
796
- hint, origin, args = get_hint_for_attribute(attr_name, case)
797
- if isinstance(attribute, set) or origin == set:
798
- attribute = set() if attribute is None else attribute
799
- for c in conclusions:
800
- attribute.update(make_set(c))
801
- elif isinstance(attribute, list) or origin == list:
802
- attribute = [] if attribute is None else attribute
803
- attribute.extend(conclusions)
804
- elif len(conclusions) == 1 and hint == conclusions_type:
805
- setattr(case, attr_name, conclusions.pop())
806
- else:
807
- raise ValueError(f"Cannot add multiple conclusions to attribute {attr_name}")
818
+ for conclusion_name, conclusion in conclusions.items():
819
+ hint, origin, args = get_hint_for_attribute(conclusion_name, case)
820
+ attribute = getattr(case, conclusion_name)
821
+ if isinstance(attribute, set) or origin in {Set, set}:
822
+ attribute = set() if attribute is None else attribute
823
+ for c in conclusion:
824
+ attribute.update(make_set(c))
825
+ elif isinstance(attribute, list) or origin in {list, List}:
826
+ attribute = [] if attribute is None else attribute
827
+ attribute.extend(conclusion)
828
+ elif (not is_iterable(conclusion) or (len(conclusion) == 1)) and hint == type(conclusion):
829
+ setattr(case, conclusion_name, conclusion)
830
+ else:
831
+ raise ValueError(f"Cannot add multiple conclusions to attribute {conclusion_name}")
808
832
  else:
809
- for c in make_set(conclusions):
810
- case.update(c.as_dict)
811
-
812
- @property
813
- def names_of_all_types(self) -> List[str]:
814
- """
815
- Get the names of all the types of categories that the GRDR can classify.
816
- """
817
- return [t.__name__ for t in self.start_rules_dict.keys()]
818
-
819
- @property
820
- def all_types(self) -> List[Type]:
821
- """
822
- Get all the types of categories that the GRDR can classify.
823
- """
824
- return list(self.start_rules_dict.keys())
833
+ case.update(conclusions)
825
834
 
826
835
  def _to_json(self) -> Dict[str, Any]:
827
- return {"start_rules": {get_full_class_name(t): rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
836
+ return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
828
837
 
829
838
  @classmethod
830
839
  def _from_json(cls, data: Dict[str, Any]) -> GeneralRDR:
@@ -833,7 +842,6 @@ class GeneralRDR(RippleDownRules):
833
842
  """
834
843
  start_rules_dict = {}
835
844
  for k, v in data["start_rules"].items():
836
- k = get_type_from_string(k)
837
845
  start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
838
846
  return cls(start_rules_dict)
839
847
 
@@ -849,12 +857,12 @@ class GeneralRDR(RippleDownRules):
849
857
  with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
850
858
  f.write(self._get_imports(file_path) + "\n\n")
851
859
  f.write("classifiers_dict = dict()\n")
852
- for t, rdr in self.start_rules_dict.items():
853
- f.write(f"classifiers_dict[{t.__name__}] = {t.__name__.lower()}_classifier\n")
860
+ for rdr_key, rdr in self.start_rules_dict.items():
861
+ f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
854
862
  f.write("\n\n")
855
863
  f.write(func_def)
856
864
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
857
- f"{' ' * 4} case = create_case(case, recursion_idx=3)\n""")
865
+ f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
858
866
  f.write(f"{' ' * 4}return GeneralRDR._classify(classifiers_dict, case)\n")
859
867
 
860
868
  @property
@@ -863,7 +871,7 @@ class GeneralRDR(RippleDownRules):
863
871
  :return: The type of the case (input) to the RDR classifier.
864
872
  """
865
873
  if isinstance(self.start_rule.corner_case, Case):
866
- return self.start_rule.corner_case._type
874
+ return self.start_rule.corner_case._obj_type
867
875
  else:
868
876
  return type(self.start_rule.corner_case)
869
877
 
@@ -876,7 +884,7 @@ class GeneralRDR(RippleDownRules):
876
884
 
877
885
  @property
878
886
  def generated_python_file_name(self) -> str:
879
- return f"{self.case_type.__name__.lower()}_grdr"
887
+ return f"{self.start_rule.corner_case._name.lower()}_rdr"
880
888
 
881
889
  @property
882
890
  def conclusion_type_hint(self) -> str:
@@ -892,10 +900,20 @@ class GeneralRDR(RippleDownRules):
892
900
  imports += f"from ripple_down_rules.datastructures import Case, create_case\n"
893
901
  imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
894
902
  # add conclusion type imports
895
- for conclusion_type in self.start_rules_dict.keys():
896
- imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
903
+ for rdr in self.start_rules_dict.values():
904
+ imports += f"from {rdr.conclusion_type.__module__} import {rdr.conclusion_type.__name__}\n"
897
905
  # add rdr python generated functions.
898
- for conclusion_type, rdr in self.start_rules_dict.items():
906
+ for rdr_key, rdr in self.start_rules_dict.items():
899
907
  imports += (f"from {file_path.strip('./')}"
900
- f" import {rdr.generated_python_file_name} as {conclusion_type.__name__.lower()}_classifier\n")
908
+ f" import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}\n")
901
909
  return imports
910
+
911
+ @staticmethod
912
+ def rdr_key_to_function_name(rdr_key: str) -> str:
913
+ """
914
+ Convert the RDR key to a function name.
915
+
916
+ :param rdr_key: The RDR key to convert.
917
+ :return: The function name.
918
+ """
919
+ return rdr_key.replace(".", "_").lower() + "_classifier"