ripple-down-rules 0.0.15__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ripple_down_rules/rdr.py CHANGED
@@ -8,13 +8,13 @@ from types import ModuleType
8
8
  from matplotlib import pyplot as plt
9
9
  from ordered_set import OrderedSet
10
10
  from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
11
- from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable
11
+ from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
12
12
 
13
13
  from .datastructures import Case, MCRDRMode, CallableExpression, CaseAttribute, CaseQuery
14
14
  from .experts import Expert, Human
15
15
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
16
- from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
17
- get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_full_class_name, get_type_from_string
16
+ from .utils import draw_tree, make_set, copy_case, \
17
+ get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string
18
18
 
19
19
 
20
20
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -80,7 +80,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
80
80
  :param kwargs_for_fit_case: The keyword arguments to pass to the fit_case method.
81
81
  """
82
82
  cases = [case_query.case for case_query in case_queries]
83
- targets = [case.target for case in case_queries]
83
+ targets = [{case_query.attribute_name: case_query.target} for case_query in case_queries]
84
84
  if animate_tree:
85
85
  plt.ion()
86
86
  i = 0
@@ -91,11 +91,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
91
91
  if not targets:
92
92
  targets = [None] * len(cases)
93
93
  for case_query in case_queries:
94
- case = case_query.case
95
- target = case_query.target
96
- if not target:
97
- conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
98
- target = expert.ask_for_conclusion(case_query, conclusions)
94
+ target = {case_query.attribute_name: case_query.target}
99
95
  pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
100
96
  match = self.is_matching(pred_cat, target)
101
97
  if not match:
@@ -105,8 +101,9 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
105
101
  num_rules = self.start_rule.size
106
102
  self.update_figures()
107
103
  i += 1
108
- all_predictions = [1 if self.is_matching(self.classify(case), target) else 0
109
- for case, target in zip(cases, targets)]
104
+ all_predictions = [1 if self.is_matching(self.classify(case_query.case), {case_query.attribute_name:
105
+ case_query.target}) else 0
106
+ for case_query in case_queries]
110
107
  all_pred = sum(all_predictions)
111
108
  print(f"Accuracy: {all_pred}/{len(targets)}")
112
109
  all_predicted = targets and all_pred == len(targets)
@@ -129,9 +126,33 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
129
126
  """
130
127
  pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
131
128
  target = target if is_iterable(target) else [target]
132
- recall = [not yi or (yi in pred_cat) for yi in target]
133
- target_types = [type(yi) for yi in target]
134
- precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
129
+ recall = []
130
+ precision = []
131
+ if isinstance(pred_cat, dict):
132
+ for pred_key, pred_value in pred_cat.items():
133
+ if pred_key not in target:
134
+ continue
135
+ # if is_iterable(pred_value):
136
+ # print(pred_value, target[pred_key])
137
+ # precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
138
+ precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
139
+ # else:
140
+ # precision.append(pred_value == target[pred_key])
141
+ for target_key, target_value in target.items():
142
+ if target_key not in pred_cat:
143
+ recall.append(False)
144
+ continue
145
+ if is_iterable(target_value):
146
+ recall.extend([v in pred_cat[target_key] for v in target_value])
147
+ else:
148
+ recall.append(target_value == pred_cat[target_key])
149
+ print(f"Precision: {precision}, Recall: {recall}")
150
+ else:
151
+ if isinstance(target, dict):
152
+ target = list(target.values())
153
+ recall = [not yi or (yi in pred_cat) for yi in target]
154
+ target_types = [type(yi) for yi in target]
155
+ precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
135
156
  return precision, recall
136
157
 
137
158
  def is_matching(self, pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> bool:
@@ -158,22 +179,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
158
179
  draw_tree(self.start_rule, self.fig)
159
180
 
160
181
  @staticmethod
161
- def case_has_conclusion(case: Union[Case, SQLTable], conclusion_type: Type) -> bool:
182
+ def case_has_conclusion(case: Union[Case, SQLTable], conclusion_name: str) -> bool:
162
183
  """
163
184
  Check if the case has a conclusion.
164
185
 
165
186
  :param case: The case to check.
166
- :param conclusion_type: The target category type to compare the case with.
187
+ :param conclusion_name: The target category name to compare the case with.
167
188
  :return: Whether the case has a conclusion or not.
168
189
  """
169
- if isinstance(case, SQLTable):
170
- prop_name, prop_value = get_attribute_by_type(case, conclusion_type)
171
- if hasattr(prop_value, "__iter__") and not isinstance(prop_value, str):
172
- return len(prop_value) > 0
173
- else:
174
- return prop_value is not None
175
- else:
176
- return conclusion_type in case
190
+ return hasattr(case, conclusion_name) and getattr(case, conclusion_name) is not None
177
191
 
178
192
 
179
193
  class RDRWithCodeWriter(RippleDownRules, ABC):
@@ -244,7 +258,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
244
258
  :return: The type of the case (input) to the RDR classifier.
245
259
  """
246
260
  if isinstance(self.start_rule.corner_case, Case):
247
- return self.start_rule.corner_case._type
261
+ return self.start_rule.corner_case._obj_type
248
262
  else:
249
263
  return type(self.start_rule.corner_case)
250
264
 
@@ -260,6 +274,13 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
260
274
  return type(list(self.start_rule.conclusion)[0])
261
275
  return type(self.start_rule.conclusion)
262
276
 
277
+ @property
278
+ def attribute_name(self) -> str:
279
+ """
280
+ :return: The name of the attribute that the classifier is classifying.
281
+ """
282
+ return self.start_rule.conclusion_name
283
+
263
284
 
264
285
  class SingleClassRDR(RDRWithCodeWriter):
265
286
 
@@ -274,23 +295,20 @@ class SingleClassRDR(RDRWithCodeWriter):
274
295
  :return: The category that the case belongs to.
275
296
  """
276
297
  expert = expert if expert else Human(session=self.session)
277
- case, attribute = case_query.case, case_query.attribute
278
298
  if case_query.target is None:
279
299
  target = expert.ask_for_conclusion(case_query)
280
- else:
281
- target = case_query.target
282
-
283
300
  if not self.start_rule:
284
- conditions = expert.ask_for_conditions(case, [target])
285
- self.start_rule = SingleClassRule(conditions, target, corner_case=case)
301
+ conditions = expert.ask_for_conditions(case_query)
302
+ self.start_rule = SingleClassRule(conditions, case_query.target, corner_case=case_query.case,
303
+ conclusion_name=case_query.attribute_name)
286
304
 
287
- pred = self.evaluate(case)
305
+ pred = self.evaluate(case_query.case)
288
306
 
289
- if pred.conclusion != target:
290
- conditions = expert.ask_for_conditions(case, [target], pred)
291
- pred.fit_rule(case, target, conditions=conditions)
307
+ if pred.conclusion != case_query.target:
308
+ conditions = expert.ask_for_conditions(case_query, pred)
309
+ pred.fit_rule(case_query.case, case_query.target, conditions=conditions)
292
310
 
293
- return self.classify(case)
311
+ return self.classify(case_query.case)
294
312
 
295
313
  def classify(self, case: Case) -> Optional[CaseAttribute]:
296
314
  """
@@ -388,44 +406,41 @@ class MultiClassRDR(RDRWithCodeWriter):
388
406
  :return: The conclusions that the case belongs to.
389
407
  """
390
408
  expert = expert if expert else Human(session=self.session)
391
- case = case_query.case
392
409
  if case_query.target is None:
393
- targets = [expert.ask_for_conclusion(case_query)]
394
- else:
395
- targets = [case_query.target]
410
+ targets = expert.ask_for_conclusion(case_query)
396
411
  self.expert_accepted_conclusions = []
397
412
  user_conclusions = []
398
- for target in targets:
399
- self.update_start_rule(case, target, expert)
400
- self.conclusions = []
401
- self.stop_rule_conditions = None
402
- evaluated_rule = self.start_rule
403
- while evaluated_rule:
404
- next_rule = evaluated_rule(case)
405
- good_conclusions = targets + user_conclusions + self.expert_accepted_conclusions
406
-
407
- if evaluated_rule.fired:
408
- if target and evaluated_rule.conclusion not in good_conclusions:
409
- # if self.case_has_conclusion(case, evaluated_rule.conclusion):
410
- # Rule fired and conclusion is different from target
411
- self.stop_wrong_conclusion_else_add_it(case, target, expert, evaluated_rule,
412
- add_extra_conclusions)
413
- else:
414
- # Rule fired and target is correct or there is no target to compare
415
- self.add_conclusion(evaluated_rule)
416
-
417
- if not next_rule:
418
- if not make_set(target).intersection(make_set(self.conclusions)):
419
- # Nothing fired and there is a target that should have been in the conclusions
420
- self.add_rule_for_case(case, target, expert)
421
- # Have to check all rules again to make sure only this new rule fires
422
- next_rule = self.start_rule
423
- elif add_extra_conclusions and not user_conclusions:
424
- # No more conclusions can be made, ask the expert for extra conclusions if needed.
425
- user_conclusions.extend(self.ask_expert_for_extra_conclusions(expert, case))
426
- if user_conclusions:
427
- next_rule = self.last_top_rule
428
- evaluated_rule = next_rule
413
+ self.update_start_rule(case_query, expert)
414
+ self.conclusions = []
415
+ self.stop_rule_conditions = None
416
+ evaluated_rule = self.start_rule
417
+ while evaluated_rule:
418
+ next_rule = evaluated_rule(case_query.case)
419
+ good_conclusions = make_list(case_query.target) + user_conclusions + self.expert_accepted_conclusions
420
+ good_conclusions = make_set(good_conclusions)
421
+
422
+ if evaluated_rule.fired:
423
+ if case_query.target and not make_set(evaluated_rule.conclusion).issubset(good_conclusions):
424
+ # if self.case_has_conclusion(case, evaluated_rule.conclusion):
425
+ # Rule fired and conclusion is different from target
426
+ self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
427
+ add_extra_conclusions)
428
+ else:
429
+ # Rule fired and target is correct or there is no target to compare
430
+ self.add_conclusion(evaluated_rule)
431
+
432
+ if not next_rule:
433
+ if not make_set(case_query.target).intersection(make_set(self.conclusions)):
434
+ # Nothing fired and there is a target that should have been in the conclusions
435
+ self.add_rule_for_case(case_query, expert)
436
+ # Have to check all rules again to make sure only this new rule fires
437
+ next_rule = self.start_rule
438
+ elif add_extra_conclusions and not user_conclusions:
439
+ # No more conclusions can be made, ask the expert for extra conclusions if needed.
440
+ user_conclusions.extend(self.ask_expert_for_extra_conclusions(expert, case_query.case))
441
+ if user_conclusions:
442
+ next_rule = self.last_top_rule
443
+ evaluated_rule = next_rule
429
444
  return self.conclusions
430
445
 
431
446
  def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
@@ -458,19 +473,18 @@ class MultiClassRDR(RDRWithCodeWriter):
458
473
  imports += "from typing_extensions import Set\n"
459
474
  return imports
460
475
 
461
- def update_start_rule(self, case: Union[Case, SQLTable], target: Any, expert: Expert):
476
+ def update_start_rule(self, case_query: CaseQuery, expert: Expert):
462
477
  """
463
478
  Update the starting rule of the classifier.
464
479
 
465
- :param case: The case to classify.
466
- :param target: The target category to compare the case with.
480
+ :param case_query: The case query to update the starting rule with.
467
481
  :param expert: The expert to ask for differentiating features as new rule conditions.
468
482
  """
469
483
  if not self.start_rule.conditions:
470
- conditions = expert.ask_for_conditions(case, target)
484
+ conditions = expert.ask_for_conditions(case_query)
471
485
  self.start_rule.conditions = conditions
472
- self.start_rule.conclusion = target
473
- self.start_rule.corner_case = case
486
+ self.start_rule.conclusion = case_query.target
487
+ self.start_rule.corner_case = case_query.case
474
488
 
475
489
  @property
476
490
  def last_top_rule(self) -> Optional[MultiClassTopRule]:
@@ -482,35 +496,34 @@ class MultiClassRDR(RDRWithCodeWriter):
482
496
  else:
483
497
  return self.start_rule.furthest_alternative[-1]
484
498
 
485
- def stop_wrong_conclusion_else_add_it(self, case: Union[Case, SQLTable], target: Any, expert: Expert,
499
+ def stop_wrong_conclusion_else_add_it(self, case_query: CaseQuery, expert: Expert,
486
500
  evaluated_rule: MultiClassTopRule,
487
501
  add_extra_conclusions: bool):
488
502
  """
489
503
  Stop a wrong conclusion by adding a stopping rule.
490
504
  """
491
- if self.is_same_category_type(evaluated_rule.conclusion, target) \
492
- and self.is_conflicting_with_target(evaluated_rule.conclusion, target):
493
- self.stop_conclusion(case, target, expert, evaluated_rule)
494
- elif not self.conclusion_is_correct(case, target, expert, evaluated_rule, add_extra_conclusions):
495
- self.stop_conclusion(case, target, expert, evaluated_rule)
505
+ if self.is_same_category_type(evaluated_rule.conclusion, case_query.target) \
506
+ and self.is_conflicting_with_target(evaluated_rule.conclusion, case_query.target):
507
+ self.stop_conclusion(case_query, expert, evaluated_rule)
508
+ elif not self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
509
+ self.stop_conclusion(case_query, expert, evaluated_rule)
496
510
 
497
- def stop_conclusion(self, case: Union[Case, SQLTable], target: Any,
511
+ def stop_conclusion(self, case_query: CaseQuery,
498
512
  expert: Expert, evaluated_rule: MultiClassTopRule):
499
513
  """
500
514
  Stop a conclusion by adding a stopping rule.
501
515
 
502
- :param case: The case to classify.
503
- :param target: The target category to compare the case with.
516
+ :param case_query: The case query to stop the conclusion for.
504
517
  :param expert: The expert to ask for differentiating features as new rule conditions.
505
518
  :param evaluated_rule: The evaluated rule to ask the expert about.
506
519
  """
507
- conditions = expert.ask_for_conditions(case, target, evaluated_rule)
508
- evaluated_rule.fit_rule(case, target, conditions=conditions)
520
+ conditions = expert.ask_for_conditions(case_query, evaluated_rule)
521
+ evaluated_rule.fit_rule(case_query.case, case_query.target, conditions=conditions)
509
522
  if self.mode == MCRDRMode.StopPlusRule:
510
523
  self.stop_rule_conditions = conditions
511
524
  if self.mode == MCRDRMode.StopPlusRuleCombined:
512
525
  new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
513
- self.add_top_rule(new_top_rule_conditions, target, case)
526
+ self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
514
527
 
515
528
  @staticmethod
516
529
  def is_conflicting_with_target(conclusion: Any, target: Any) -> bool:
@@ -537,37 +550,40 @@ class MultiClassRDR(RDRWithCodeWriter):
537
550
  """
538
551
  return conclusion.__class__ == target.__class__ and target.__class__ != CaseAttribute
539
552
 
540
- def conclusion_is_correct(self, case: Union[Case, SQLTable], target: Any, expert: Expert, evaluated_rule: Rule,
553
+ def conclusion_is_correct(self, case_query: CaseQuery,
554
+ expert: Expert, evaluated_rule: Rule,
541
555
  add_extra_conclusions: bool) -> bool:
542
556
  """
543
557
  Ask the expert if the conclusion is correct, and add it to the conclusions if it is.
544
558
 
545
- :param case: The case to classify.
546
- :param target: The target category to compare the case with.
559
+ :param case_query: The case query to ask the expert about.
547
560
  :param expert: The expert to ask for differentiating features as new rule conditions.
548
561
  :param evaluated_rule: The evaluated rule to ask the expert about.
549
562
  :param add_extra_conclusions: Whether adding extra conclusions after classification is allowed.
550
563
  :return: Whether the conclusion is correct or not.
551
564
  """
552
- conclusions = list(OrderedSet(self.conclusions))
553
- if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case, evaluated_rule.conclusion,
554
- targets=target,
565
+ conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
566
+ if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case, evaluated_rule.conclusion,
567
+ targets=case_query.target,
555
568
  current_conclusions=conclusions)):
556
569
  self.add_conclusion(evaluated_rule)
557
570
  self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
558
571
  return True
559
572
  return False
560
573
 
561
- def add_rule_for_case(self, case: Union[Case, SQLTable], target: Any, expert: Expert):
574
+ def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
562
575
  """
563
576
  Add a rule for a case that has not been classified with any conclusion.
577
+
578
+ :param case_query: The case query to add the rule for.
579
+ :param expert: The expert to ask for differentiating features as new rule conditions.
564
580
  """
565
581
  if self.stop_rule_conditions and self.mode == MCRDRMode.StopPlusRule:
566
582
  conditions = self.stop_rule_conditions
567
583
  self.stop_rule_conditions = None
568
584
  else:
569
- conditions = expert.ask_for_conditions(case, target)
570
- self.add_top_rule(conditions, target, case)
585
+ conditions = expert.ask_for_conditions(case_query)
586
+ self.add_top_rule(conditions, case_query.target, case_query.case)
571
587
 
572
588
  def ask_expert_for_extra_conclusions(self, expert: Expert, case: Union[Case, SQLTable]) -> List[Any]:
573
589
  """
@@ -641,20 +657,31 @@ class GeneralRDR(RippleDownRules):
641
657
  gets called when the final rule fires.
642
658
  """
643
659
 
644
- def __init__(self, category_rdr_map: Optional[Dict[Type, Union[SingleClassRDR, MultiClassRDR]]] = None):
660
+ def __init__(self, category_rdr_map: Optional[Dict[str, Union[SingleClassRDR, MultiClassRDR]]] = None):
645
661
  """
646
- :param category_rdr_map: A map of categories to ripple down rules classifiers,
662
+ :param category_rdr_map: A map of case attribute names to ripple down rules classifiers,
647
663
  where each category is a parent category that has a set of mutually exclusive (in case of SCRDR) child
648
- categories, e.g. {Species: SCRDR, Habitat: MCRDR}, where Species and Habitat are parent categories and SCRDR
649
- and MCRDR are SingleClass and MultiClass ripple down rules classifiers. Species can have child categories like
650
- Mammal, Bird, Fish, etc. which are mutually exclusive, and Habitat can have child categories like
651
- Land, Water, Air, etc, which are not mutually exclusive due to some animals living more than one habitat.
664
+ categories, e.g. {'species': SCRDR, 'habitats': MCRDR}, where 'species' and 'habitats' are attribute names
665
+ for a case of type Animal, while SCRDR and MCRDR are SingleClass and MultiClass ripple down rules classifiers.
666
+ Species can have values like Mammal, Bird, Fish, etc. which are mutually exclusive, while Habitat can have
667
+ values like Land, Water, Air, etc., which are not mutually exclusive due to some animals living more than one
668
+ habitat.
652
669
  """
653
- self.start_rules_dict: Dict[Type, Union[SingleClassRDR, MultiClassRDR]] \
670
+ self.start_rules_dict: Dict[str, Union[SingleClassRDR, MultiClassRDR]] \
654
671
  = category_rdr_map if category_rdr_map else {}
655
672
  super(GeneralRDR, self).__init__()
656
673
  self.all_figs: List[plt.Figure] = [sr.fig for sr in self.start_rules_dict.values()]
657
674
 
675
+ def add_rdr(self, rdr: Union[SingleClassRDR, MultiClassRDR], attribute_name: Optional[str] = None):
676
+ """
677
+ Add a ripple down rules classifier to the map of classifiers.
678
+
679
+ :param rdr: The ripple down rules classifier to add.
680
+ :param attribute_name: The name of the attribute that the classifier is classifying.
681
+ """
682
+ attribute_name = attribute_name if attribute_name else rdr.attribute_name
683
+ self.start_rules_dict[attribute_name] = rdr
684
+
658
685
  @property
659
686
  def start_rule(self) -> Optional[Union[SingleClassRule, MultiClassTopRule]]:
660
687
  return self.start_rules[0] if self.start_rules_dict else None
@@ -662,7 +689,7 @@ class GeneralRDR(RippleDownRules):
662
689
  @start_rule.setter
663
690
  def start_rule(self, value: Union[SingleClassRDR, MultiClassRDR]):
664
691
  if value:
665
- self.start_rules_dict[value.conclusion_type] = value
692
+ self.start_rules_dict[value.attribute_name] = value
666
693
 
667
694
  @property
668
695
  def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
@@ -679,8 +706,8 @@ class GeneralRDR(RippleDownRules):
679
706
  return self._classify(self.start_rules_dict, case)
680
707
 
681
708
  @staticmethod
682
- def _classify(classifiers_dict: Dict[Type, Union[ModuleType, RippleDownRules]],
683
- case: Union[Case, SQLTable]) -> Optional[List[Any]]:
709
+ def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
710
+ case: Union[Case, SQLTable]) -> Optional[Dict[str, Any]]:
684
711
  """
685
712
  Classify a case by going through all classifiers and adding the categories that are classified,
686
713
  and then restarting the classification until no more categories can be added.
@@ -689,21 +716,31 @@ class GeneralRDR(RippleDownRules):
689
716
  :param case: The case to classify.
690
717
  :return: The categories that the case belongs to.
691
718
  """
692
- conclusions = []
719
+ conclusions = {}
693
720
  case_cp = copy_case(case)
694
721
  while True:
695
- added_attributes = False
696
- for cat_type, rdr in classifiers_dict.items():
697
- if GeneralRDR.case_has_conclusion(case_cp, cat_type):
698
- continue
722
+ new_conclusions = {}
723
+ for attribute_name, rdr in classifiers_dict.items():
699
724
  pred_atts = rdr.classify(case_cp)
700
- if pred_atts:
725
+ if pred_atts is None:
726
+ continue
727
+ if isinstance(rdr, SingleClassRDR):
728
+ if attribute_name not in conclusions or \
729
+ (attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
730
+ conclusions[attribute_name] = pred_atts
731
+ new_conclusions[attribute_name] = pred_atts
732
+ else:
701
733
  pred_atts = make_list(pred_atts)
702
- pred_atts = [p for p in pred_atts if p not in conclusions]
703
- added_attributes = True
704
- conclusions.extend(pred_atts)
705
- GeneralRDR.update_case(case_cp, pred_atts)
706
- if not added_attributes:
734
+ if attribute_name in conclusions:
735
+ pred_atts = [p for p in pred_atts if p not in conclusions[attribute_name]]
736
+ if len(pred_atts) > 0:
737
+ new_conclusions[attribute_name] = pred_atts
738
+ if attribute_name not in conclusions:
739
+ conclusions[attribute_name] = []
740
+ conclusions[attribute_name].extend(pred_atts)
741
+ if attribute_name in new_conclusions:
742
+ GeneralRDR.update_case(case_cp, new_conclusions)
743
+ if len(new_conclusions) == 0:
707
744
  break
708
745
  return conclusions
709
746
 
@@ -728,103 +765,79 @@ class GeneralRDR(RippleDownRules):
728
765
  case = case_queries[0].case
729
766
  assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
730
767
  " for multiple cases use fit instead")
731
- case_query_cp = copy(case_queries[0])
732
- case_cp = case_query_cp.case
768
+ case_cp = copy(case_queries[0]).case
733
769
  for case_query in case_queries:
734
- target = case_query.target
735
- if not target:
770
+ case_query_cp = copy(case_query)
771
+ case_query_cp.case = case_cp
772
+ if case_query.target is None:
773
+ conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
736
774
  target = expert.ask_for_conclusion(case_query)
737
- case_query_cp = CaseQuery(case_cp, attribute_name=case_query.attribute_name, target=target)
738
- if is_iterable(target) and not isinstance(target, CaseAttribute):
739
- target_type = type(make_list(target)[0])
740
- assert all([type(t) is target_type for t in target]), ("All targets of a case query must be of the same"
741
- " type")
742
- else:
743
- target_type = type(target)
744
- if target_type not in self.start_rules_dict:
775
+
776
+ if case_query.attribute_name not in self.start_rules_dict:
745
777
  conclusions = self.classify(case)
746
778
  self.update_case(case_cp, conclusions)
747
- new_rdr = self.initialize_new_rdr_for_attribute(target, case_cp)
779
+
780
+ new_rdr = self.initialize_new_rdr_for_attribute(case_query.attribute_name, case_cp, case_query.target)
781
+ self.add_rdr(new_rdr, case_query.attribute_name)
782
+
748
783
  new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
749
- self.start_rules_dict[target_type] = new_rdr
750
- self.update_case(case_cp, new_conclusions, target_type)
751
- elif not self.case_has_conclusion(case_cp, target_type):
752
- for rdr_type, rdr in self.start_rules_dict.items():
753
- if target_type is not rdr_type:
784
+ self.update_case(case_cp, {case_query.attribute_name: new_conclusions})
785
+ else:
786
+ for rdr_attribute_name, rdr in self.start_rules_dict.items():
787
+ if case_query.attribute_name != rdr_attribute_name:
754
788
  conclusions = rdr.classify(case_cp)
755
789
  else:
756
- conclusions = self.start_rules_dict[target_type].fit_case(case_query_cp,
757
- expert, **kwargs)
758
- self.update_case(case_cp, conclusions, rdr_type)
790
+ conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
791
+ **kwargs)
792
+ if conclusions is not None or (is_iterable(conclusions) and len(conclusions) > 0):
793
+ conclusions = {rdr_attribute_name: conclusions}
794
+ self.update_case(case_cp, conclusions)
759
795
 
760
796
  return self.classify(case)
761
797
 
762
798
  @staticmethod
763
- def initialize_new_rdr_for_attribute(attribute: Any, case: Union[Case, SQLTable]):
799
+ def initialize_new_rdr_for_attribute(attribute_name: str, case: Union[Case, SQLTable], target: Any):
764
800
  """
765
801
  Initialize the appropriate RDR type for the target.
766
802
  """
767
- if isinstance(case, SQLTable):
768
- prop = get_attribute_by_type(case, type(attribute))
769
- if hasattr(prop, "__iter__") and not isinstance(prop, str):
770
- return MultiClassRDR()
771
- else:
772
- return SingleClassRDR()
773
- elif isinstance(attribute, CaseAttribute):
803
+ attribute = getattr(case, attribute_name) if hasattr(case, attribute_name) else target
804
+ if isinstance(attribute, CaseAttribute):
774
805
  return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
775
806
  else:
776
- return MultiClassRDR() if is_iterable(attribute) else SingleClassRDR()
807
+ return MultiClassRDR() if is_iterable(attribute) or (attribute is None) else SingleClassRDR()
777
808
 
778
809
  @staticmethod
779
- def update_case(case: Union[Case, SQLTable],
780
- conclusions: List[Any], attribute_type: Optional[Any] = None):
810
+ def update_case(case: Union[Case, SQLTable], conclusions: Dict[str, Any]):
781
811
  """
782
812
  Update the case with the conclusions.
783
813
 
784
814
  :param case: The case to update.
785
815
  :param conclusions: The conclusions to update the case with.
786
- :param attribute_type: The type of the attribute to update.
787
816
  """
788
817
  if not conclusions:
789
818
  return
790
- conclusions = [conclusions] if not isinstance(conclusions, list) else list(conclusions)
791
819
  if len(conclusions) == 0:
792
820
  return
793
821
  if isinstance(case, SQLTable):
794
- conclusions_type = type(conclusions[0]) if not attribute_type else attribute_type
795
- attr_name, attribute = get_attribute_by_type(case, conclusions_type)
796
- hint, origin, args = get_hint_for_attribute(attr_name, case)
797
- if isinstance(attribute, set) or origin == set:
798
- attribute = set() if attribute is None else attribute
799
- for c in conclusions:
800
- attribute.update(make_set(c))
801
- elif isinstance(attribute, list) or origin == list:
802
- attribute = [] if attribute is None else attribute
803
- attribute.extend(conclusions)
804
- elif len(conclusions) == 1 and hint == conclusions_type:
805
- setattr(case, attr_name, conclusions.pop())
806
- else:
807
- raise ValueError(f"Cannot add multiple conclusions to attribute {attr_name}")
822
+ for conclusion_name, conclusion in conclusions.items():
823
+ hint, origin, args = get_hint_for_attribute(conclusion_name, case)
824
+ attribute = getattr(case, conclusion_name)
825
+ if isinstance(attribute, set) or origin in {Set, set}:
826
+ attribute = set() if attribute is None else attribute
827
+ for c in conclusion:
828
+ attribute.update(make_set(c))
829
+ elif isinstance(attribute, list) or origin in {list, List}:
830
+ attribute = [] if attribute is None else attribute
831
+ attribute.extend(conclusion)
832
+ elif (not is_iterable(conclusion) or (len(conclusion) == 1)) and hint == type(conclusion):
833
+ setattr(case, conclusion_name, conclusion)
834
+ else:
835
+ raise ValueError(f"Cannot add multiple conclusions to attribute {conclusion_name}")
808
836
  else:
809
- for c in make_set(conclusions):
810
- case.update(c.as_dict)
811
-
812
- @property
813
- def names_of_all_types(self) -> List[str]:
814
- """
815
- Get the names of all the types of categories that the GRDR can classify.
816
- """
817
- return [t.__name__ for t in self.start_rules_dict.keys()]
818
-
819
- @property
820
- def all_types(self) -> List[Type]:
821
- """
822
- Get all the types of categories that the GRDR can classify.
823
- """
824
- return list(self.start_rules_dict.keys())
837
+ case.update(conclusions)
825
838
 
826
839
  def _to_json(self) -> Dict[str, Any]:
827
- return {"start_rules": {get_full_class_name(t): rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
840
+ return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
828
841
 
829
842
  @classmethod
830
843
  def _from_json(cls, data: Dict[str, Any]) -> GeneralRDR:
@@ -833,7 +846,6 @@ class GeneralRDR(RippleDownRules):
833
846
  """
834
847
  start_rules_dict = {}
835
848
  for k, v in data["start_rules"].items():
836
- k = get_type_from_string(k)
837
849
  start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
838
850
  return cls(start_rules_dict)
839
851
 
@@ -849,8 +861,8 @@ class GeneralRDR(RippleDownRules):
849
861
  with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
850
862
  f.write(self._get_imports(file_path) + "\n\n")
851
863
  f.write("classifiers_dict = dict()\n")
852
- for t, rdr in self.start_rules_dict.items():
853
- f.write(f"classifiers_dict[{t.__name__}] = {t.__name__.lower()}_classifier\n")
864
+ for rdr_key, rdr in self.start_rules_dict.items():
865
+ f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
854
866
  f.write("\n\n")
855
867
  f.write(func_def)
856
868
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
@@ -863,7 +875,7 @@ class GeneralRDR(RippleDownRules):
863
875
  :return: The type of the case (input) to the RDR classifier.
864
876
  """
865
877
  if isinstance(self.start_rule.corner_case, Case):
866
- return self.start_rule.corner_case._type
878
+ return self.start_rule.corner_case._obj_type
867
879
  else:
868
880
  return type(self.start_rule.corner_case)
869
881
 
@@ -892,10 +904,20 @@ class GeneralRDR(RippleDownRules):
892
904
  imports += f"from ripple_down_rules.datastructures import Case, create_case\n"
893
905
  imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
894
906
  # add conclusion type imports
895
- for conclusion_type in self.start_rules_dict.keys():
896
- imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
907
+ for rdr in self.start_rules_dict.values():
908
+ imports += f"from {rdr.conclusion_type.__module__} import {rdr.conclusion_type.__name__}\n"
897
909
  # add rdr python generated functions.
898
- for conclusion_type, rdr in self.start_rules_dict.items():
910
+ for rdr_key, rdr in self.start_rules_dict.items():
899
911
  imports += (f"from {file_path.strip('./')}"
900
- f" import {rdr.generated_python_file_name} as {conclusion_type.__name__.lower()}_classifier\n")
912
+ f" import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}\n")
901
913
  return imports
914
+
915
+ @staticmethod
916
+ def rdr_key_to_function_name(rdr_key: str) -> str:
917
+ """
918
+ Convert the RDR key to a function name.
919
+
920
+ :param rdr_key: The RDR key to convert.
921
+ :return: The function name.
922
+ """
923
+ return rdr_key.replace(".", "_").lower() + "_classifier"