ripple-down-rules 0.0.14__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ripple_down_rules/rdr.py CHANGED
@@ -2,19 +2,19 @@ from __future__ import annotations
2
2
 
3
3
  import importlib
4
4
  from abc import ABC, abstractmethod
5
- from copy import copy, deepcopy
5
+ from copy import copy
6
6
  from types import ModuleType
7
7
 
8
8
  from matplotlib import pyplot as plt
9
9
  from ordered_set import OrderedSet
10
10
  from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
11
- from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable
11
+ from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
12
12
 
13
13
  from .datastructures import Case, MCRDRMode, CallableExpression, CaseAttribute, CaseQuery
14
14
  from .experts import Expert, Human
15
15
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
16
- from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
17
- get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_full_class_name, get_type_from_string
16
+ from .utils import draw_tree, make_set, copy_case, \
17
+ get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string
18
18
 
19
19
 
20
20
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -80,7 +80,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
80
80
  :param kwargs_for_fit_case: The keyword arguments to pass to the fit_case method.
81
81
  """
82
82
  cases = [case_query.case for case_query in case_queries]
83
- targets = [case.target for case in case_queries]
83
+ targets = [{case_query.attribute_name: case_query.target} for case_query in case_queries]
84
84
  if animate_tree:
85
85
  plt.ion()
86
86
  i = 0
@@ -91,11 +91,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
91
91
  if not targets:
92
92
  targets = [None] * len(cases)
93
93
  for case_query in case_queries:
94
- case = case_query.case
95
- target = case_query.target
96
- if not target:
97
- conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
98
- target = expert.ask_for_conclusion(case_query, conclusions)
94
+ target = {case_query.attribute_name: case_query.target}
99
95
  pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
100
96
  match = self.is_matching(pred_cat, target)
101
97
  if not match:
@@ -105,8 +101,9 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
105
101
  num_rules = self.start_rule.size
106
102
  self.update_figures()
107
103
  i += 1
108
- all_predictions = [1 if self.is_matching(self.classify(case), target) else 0
109
- for case, target in zip(cases, targets)]
104
+ all_predictions = [1 if self.is_matching(self.classify(case_query.case), {case_query.attribute_name:
105
+ case_query.target}) else 0
106
+ for case_query in case_queries]
110
107
  all_pred = sum(all_predictions)
111
108
  print(f"Accuracy: {all_pred}/{len(targets)}")
112
109
  all_predicted = targets and all_pred == len(targets)
@@ -120,7 +117,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
120
117
  plt.show()
121
118
 
122
119
  @staticmethod
123
- def calculate_precision_and_recall(pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> Tuple[List[bool], List[bool]]:
120
+ def calculate_precision_and_recall(pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> Tuple[
121
+ List[bool], List[bool]]:
124
122
  """
125
123
  :param pred_cat: The predicted category.
126
124
  :param target: The target category.
@@ -128,9 +126,33 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
128
126
  """
129
127
  pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
130
128
  target = target if is_iterable(target) else [target]
131
- recall = [not yi or (yi in pred_cat) for yi in target]
132
- target_types = [type(yi) for yi in target]
133
- precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
129
+ recall = []
130
+ precision = []
131
+ if isinstance(pred_cat, dict):
132
+ for pred_key, pred_value in pred_cat.items():
133
+ if pred_key not in target:
134
+ continue
135
+ # if is_iterable(pred_value):
136
+ # print(pred_value, target[pred_key])
137
+ # precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
138
+ precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
139
+ # else:
140
+ # precision.append(pred_value == target[pred_key])
141
+ for target_key, target_value in target.items():
142
+ if target_key not in pred_cat:
143
+ recall.append(False)
144
+ continue
145
+ if is_iterable(target_value):
146
+ recall.extend([v in pred_cat[target_key] for v in target_value])
147
+ else:
148
+ recall.append(target_value == pred_cat[target_key])
149
+ print(f"Precision: {precision}, Recall: {recall}")
150
+ else:
151
+ if isinstance(target, dict):
152
+ target = list(target.values())
153
+ recall = [not yi or (yi in pred_cat) for yi in target]
154
+ target_types = [type(yi) for yi in target]
155
+ precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
134
156
  return precision, recall
135
157
 
136
158
  def is_matching(self, pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> bool:
@@ -157,22 +179,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
157
179
  draw_tree(self.start_rule, self.fig)
158
180
 
159
181
  @staticmethod
160
- def case_has_conclusion(case: Union[Case, SQLTable], conclusion_type: Type) -> bool:
182
+ def case_has_conclusion(case: Union[Case, SQLTable], conclusion_name: str) -> bool:
161
183
  """
162
184
  Check if the case has a conclusion.
163
185
 
164
186
  :param case: The case to check.
165
- :param conclusion_type: The target category type to compare the case with.
187
+ :param conclusion_name: The target category name to compare the case with.
166
188
  :return: Whether the case has a conclusion or not.
167
189
  """
168
- if isinstance(case, SQLTable):
169
- prop_name, prop_value = get_attribute_by_type(case, conclusion_type)
170
- if hasattr(prop_value, "__iter__") and not isinstance(prop_value, str):
171
- return len(prop_value) > 0
172
- else:
173
- return prop_value is not None
174
- else:
175
- return conclusion_type in case
190
+ return hasattr(case, conclusion_name) and getattr(case, conclusion_name) is not None
176
191
 
177
192
 
178
193
  class RDRWithCodeWriter(RippleDownRules, ABC):
@@ -194,16 +209,17 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
194
209
 
195
210
  :param file_path: The path to the file to write the source code to.
196
211
  """
197
- func_def = f"def classify(case: {self.case_type.__name__}) -> {self._get_conclusion_type_hint()}:\n"
212
+ func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
198
213
  with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
199
214
  f.write(self._get_imports() + "\n\n")
200
215
  f.write(func_def)
201
- f.write(f"{' '*4}if not isinstance(case, Case):\n"
202
- f"{' '*4} case = create_case(case, recursion_idx=3)\n""")
216
+ f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
217
+ f"{' ' * 4} case = create_case(case, recursion_idx=3)\n""")
203
218
  self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)
204
219
 
220
+ @property
205
221
  @abstractmethod
206
- def _get_conclusion_type_hint(self) -> str:
222
+ def conclusion_type_hint(self) -> str:
207
223
  """
208
224
  :return: The type hint of the conclusion of the rdr as a string.
209
225
  """
@@ -242,7 +258,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
242
258
  :return: The type of the case (input) to the RDR classifier.
243
259
  """
244
260
  if isinstance(self.start_rule.corner_case, Case):
245
- return self.start_rule.corner_case._type
261
+ return self.start_rule.corner_case._obj_type
246
262
  else:
247
263
  return type(self.start_rule.corner_case)
248
264
 
@@ -254,8 +270,17 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
254
270
  if isinstance(self.start_rule.conclusion, CallableExpression):
255
271
  return self.start_rule.conclusion.conclusion_type
256
272
  else:
273
+ if isinstance(self.start_rule.conclusion, set):
274
+ return type(list(self.start_rule.conclusion)[0])
257
275
  return type(self.start_rule.conclusion)
258
276
 
277
+ @property
278
+ def attribute_name(self) -> str:
279
+ """
280
+ :return: The name of the attribute that the classifier is classifying.
281
+ """
282
+ return self.start_rule.conclusion_name
283
+
259
284
 
260
285
  class SingleClassRDR(RDRWithCodeWriter):
261
286
 
@@ -270,23 +295,20 @@ class SingleClassRDR(RDRWithCodeWriter):
270
295
  :return: The category that the case belongs to.
271
296
  """
272
297
  expert = expert if expert else Human(session=self.session)
273
- case, attribute = case_query.case, case_query.attribute
274
298
  if case_query.target is None:
275
299
  target = expert.ask_for_conclusion(case_query)
276
- else:
277
- target = case_query.target
278
-
279
300
  if not self.start_rule:
280
- conditions = expert.ask_for_conditions(case, [target])
281
- self.start_rule = SingleClassRule(conditions, target, corner_case=case)
301
+ conditions = expert.ask_for_conditions(case_query)
302
+ self.start_rule = SingleClassRule(conditions, case_query.target, corner_case=case_query.case,
303
+ conclusion_name=case_query.attribute_name)
282
304
 
283
- pred = self.evaluate(case)
305
+ pred = self.evaluate(case_query.case)
284
306
 
285
- if pred.conclusion != target:
286
- conditions = expert.ask_for_conditions(case, [target], pred)
287
- pred.fit_rule(case, target, conditions=conditions)
307
+ if pred.conclusion != case_query.target:
308
+ conditions = expert.ask_for_conditions(case_query, pred)
309
+ pred.fit_rule(case_query.case, case_query.target, conditions=conditions)
288
310
 
289
- return self.classify(case)
311
+ return self.classify(case_query.case)
290
312
 
291
313
  def classify(self, case: Case) -> Optional[CaseAttribute]:
292
314
  """
@@ -316,7 +338,8 @@ class SingleClassRDR(RDRWithCodeWriter):
316
338
  if rule.alternative:
317
339
  self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
318
340
 
319
- def _get_conclusion_type_hint(self) -> str:
341
+ @property
342
+ def conclusion_type_hint(self) -> str:
320
343
  return self.conclusion_type.__name__
321
344
 
322
345
  def _to_json(self) -> Dict[str, Any]:
@@ -383,50 +406,49 @@ class MultiClassRDR(RDRWithCodeWriter):
383
406
  :return: The conclusions that the case belongs to.
384
407
  """
385
408
  expert = expert if expert else Human(session=self.session)
386
- case = case_query.case
387
409
  if case_query.target is None:
388
- targets = [expert.ask_for_conclusion(case_query)]
389
- else:
390
- targets = [case_query.target]
410
+ targets = expert.ask_for_conclusion(case_query)
391
411
  self.expert_accepted_conclusions = []
392
412
  user_conclusions = []
393
- for target in targets:
394
- self.update_start_rule(case, target, expert)
395
- self.conclusions = []
396
- self.stop_rule_conditions = None
397
- evaluated_rule = self.start_rule
398
- while evaluated_rule:
399
- next_rule = evaluated_rule(case)
400
- good_conclusions = targets + user_conclusions + self.expert_accepted_conclusions
401
-
402
- if evaluated_rule.fired:
403
- if target and evaluated_rule.conclusion not in good_conclusions:
404
- # if self.case_has_conclusion(case, evaluated_rule.conclusion):
405
- # Rule fired and conclusion is different from target
406
- self.stop_wrong_conclusion_else_add_it(case, target, expert, evaluated_rule,
407
- add_extra_conclusions)
408
- else:
409
- # Rule fired and target is correct or there is no target to compare
410
- self.add_conclusion(evaluated_rule)
411
-
412
- if not next_rule:
413
- if not make_set(target).intersection(make_set(self.conclusions)):
414
- # Nothing fired and there is a target that should have been in the conclusions
415
- self.add_rule_for_case(case, target, expert)
416
- # Have to check all rules again to make sure only this new rule fires
417
- next_rule = self.start_rule
418
- elif add_extra_conclusions and not user_conclusions:
419
- # No more conclusions can be made, ask the expert for extra conclusions if needed.
420
- user_conclusions.extend(self.ask_expert_for_extra_conclusions(expert, case))
421
- if user_conclusions:
422
- next_rule = self.last_top_rule
423
- evaluated_rule = next_rule
413
+ self.update_start_rule(case_query, expert)
414
+ self.conclusions = []
415
+ self.stop_rule_conditions = None
416
+ evaluated_rule = self.start_rule
417
+ while evaluated_rule:
418
+ next_rule = evaluated_rule(case_query.case)
419
+ good_conclusions = make_list(case_query.target) + user_conclusions + self.expert_accepted_conclusions
420
+ good_conclusions = make_set(good_conclusions)
421
+
422
+ if evaluated_rule.fired:
423
+ if case_query.target and not make_set(evaluated_rule.conclusion).issubset(good_conclusions):
424
+ # if self.case_has_conclusion(case, evaluated_rule.conclusion):
425
+ # Rule fired and conclusion is different from target
426
+ self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
427
+ add_extra_conclusions)
428
+ else:
429
+ # Rule fired and target is correct or there is no target to compare
430
+ self.add_conclusion(evaluated_rule)
431
+
432
+ if not next_rule:
433
+ if not make_set(case_query.target).intersection(make_set(self.conclusions)):
434
+ # Nothing fired and there is a target that should have been in the conclusions
435
+ self.add_rule_for_case(case_query, expert)
436
+ # Have to check all rules again to make sure only this new rule fires
437
+ next_rule = self.start_rule
438
+ elif add_extra_conclusions and not user_conclusions:
439
+ # No more conclusions can be made, ask the expert for extra conclusions if needed.
440
+ user_conclusions.extend(self.ask_expert_for_extra_conclusions(expert, case_query.case))
441
+ if user_conclusions:
442
+ next_rule = self.last_top_rule
443
+ evaluated_rule = next_rule
424
444
  return self.conclusions
425
445
 
426
446
  def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
427
447
  file, parent_indent: str = ""):
428
448
  """
429
449
  Write the rules as source code to a file.
450
+
451
+ :
430
452
  """
431
453
  if rule == self.start_rule:
432
454
  file.write(f"{parent_indent}conclusions = set()\n")
@@ -435,14 +457,15 @@ class MultiClassRDR(RDRWithCodeWriter):
435
457
  conclusion_indent = parent_indent
436
458
  if hasattr(rule, "refinement") and rule.refinement:
437
459
  self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ")
438
- conclusion_indent = parent_indent + " "*4
460
+ conclusion_indent = parent_indent + " " * 4
439
461
  file.write(f"{conclusion_indent}else:\n")
440
462
  file.write(rule.write_conclusion_as_source_code(conclusion_indent))
441
463
 
442
464
  if rule.alternative:
443
465
  self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
444
466
 
445
- def _get_conclusion_type_hint(self) -> str:
467
+ @property
468
+ def conclusion_type_hint(self) -> str:
446
469
  return f"Set[{self.conclusion_type.__name__}]"
447
470
 
448
471
  def _get_imports(self) -> str:
@@ -450,19 +473,18 @@ class MultiClassRDR(RDRWithCodeWriter):
450
473
  imports += "from typing_extensions import Set\n"
451
474
  return imports
452
475
 
453
- def update_start_rule(self, case: Union[Case, SQLTable], target: Any, expert: Expert):
476
+ def update_start_rule(self, case_query: CaseQuery, expert: Expert):
454
477
  """
455
478
  Update the starting rule of the classifier.
456
479
 
457
- :param case: The case to classify.
458
- :param target: The target category to compare the case with.
480
+ :param case_query: The case query to update the starting rule with.
459
481
  :param expert: The expert to ask for differentiating features as new rule conditions.
460
482
  """
461
483
  if not self.start_rule.conditions:
462
- conditions = expert.ask_for_conditions(case, target)
484
+ conditions = expert.ask_for_conditions(case_query)
463
485
  self.start_rule.conditions = conditions
464
- self.start_rule.conclusion = target
465
- self.start_rule.corner_case = case
486
+ self.start_rule.conclusion = case_query.target
487
+ self.start_rule.corner_case = case_query.case
466
488
 
467
489
  @property
468
490
  def last_top_rule(self) -> Optional[MultiClassTopRule]:
@@ -474,35 +496,34 @@ class MultiClassRDR(RDRWithCodeWriter):
474
496
  else:
475
497
  return self.start_rule.furthest_alternative[-1]
476
498
 
477
- def stop_wrong_conclusion_else_add_it(self, case: Union[Case, SQLTable], target: Any, expert: Expert,
499
+ def stop_wrong_conclusion_else_add_it(self, case_query: CaseQuery, expert: Expert,
478
500
  evaluated_rule: MultiClassTopRule,
479
501
  add_extra_conclusions: bool):
480
502
  """
481
503
  Stop a wrong conclusion by adding a stopping rule.
482
504
  """
483
- if self.is_same_category_type(evaluated_rule.conclusion, target) \
484
- and self.is_conflicting_with_target(evaluated_rule.conclusion, target):
485
- self.stop_conclusion(case, target, expert, evaluated_rule)
486
- elif not self.conclusion_is_correct(case, target, expert, evaluated_rule, add_extra_conclusions):
487
- self.stop_conclusion(case, target, expert, evaluated_rule)
505
+ if self.is_same_category_type(evaluated_rule.conclusion, case_query.target) \
506
+ and self.is_conflicting_with_target(evaluated_rule.conclusion, case_query.target):
507
+ self.stop_conclusion(case_query, expert, evaluated_rule)
508
+ elif not self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
509
+ self.stop_conclusion(case_query, expert, evaluated_rule)
488
510
 
489
- def stop_conclusion(self, case: Union[Case, SQLTable], target: Any,
511
+ def stop_conclusion(self, case_query: CaseQuery,
490
512
  expert: Expert, evaluated_rule: MultiClassTopRule):
491
513
  """
492
514
  Stop a conclusion by adding a stopping rule.
493
515
 
494
- :param case: The case to classify.
495
- :param target: The target category to compare the case with.
516
+ :param case_query: The case query to stop the conclusion for.
496
517
  :param expert: The expert to ask for differentiating features as new rule conditions.
497
518
  :param evaluated_rule: The evaluated rule to ask the expert about.
498
519
  """
499
- conditions = expert.ask_for_conditions(case, target, evaluated_rule)
500
- evaluated_rule.fit_rule(case, target, conditions=conditions)
520
+ conditions = expert.ask_for_conditions(case_query, evaluated_rule)
521
+ evaluated_rule.fit_rule(case_query.case, case_query.target, conditions=conditions)
501
522
  if self.mode == MCRDRMode.StopPlusRule:
502
523
  self.stop_rule_conditions = conditions
503
524
  if self.mode == MCRDRMode.StopPlusRuleCombined:
504
525
  new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
505
- self.add_top_rule(new_top_rule_conditions, target, case)
526
+ self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
506
527
 
507
528
  @staticmethod
508
529
  def is_conflicting_with_target(conclusion: Any, target: Any) -> bool:
@@ -529,37 +550,40 @@ class MultiClassRDR(RDRWithCodeWriter):
529
550
  """
530
551
  return conclusion.__class__ == target.__class__ and target.__class__ != CaseAttribute
531
552
 
532
- def conclusion_is_correct(self, case: Union[Case, SQLTable], target: Any, expert: Expert, evaluated_rule: Rule,
553
+ def conclusion_is_correct(self, case_query: CaseQuery,
554
+ expert: Expert, evaluated_rule: Rule,
533
555
  add_extra_conclusions: bool) -> bool:
534
556
  """
535
557
  Ask the expert if the conclusion is correct, and add it to the conclusions if it is.
536
558
 
537
- :param case: The case to classify.
538
- :param target: The target category to compare the case with.
559
+ :param case_query: The case query to ask the expert about.
539
560
  :param expert: The expert to ask for differentiating features as new rule conditions.
540
561
  :param evaluated_rule: The evaluated rule to ask the expert about.
541
562
  :param add_extra_conclusions: Whether adding extra conclusions after classification is allowed.
542
563
  :return: Whether the conclusion is correct or not.
543
564
  """
544
- conclusions = list(OrderedSet(self.conclusions))
545
- if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case, evaluated_rule.conclusion,
546
- targets=target,
565
+ conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
566
+ if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case, evaluated_rule.conclusion,
567
+ targets=case_query.target,
547
568
  current_conclusions=conclusions)):
548
569
  self.add_conclusion(evaluated_rule)
549
570
  self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
550
571
  return True
551
572
  return False
552
573
 
553
- def add_rule_for_case(self, case: Union[Case, SQLTable], target: Any, expert: Expert):
574
+ def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
554
575
  """
555
576
  Add a rule for a case that has not been classified with any conclusion.
577
+
578
+ :param case_query: The case query to add the rule for.
579
+ :param expert: The expert to ask for differentiating features as new rule conditions.
556
580
  """
557
581
  if self.stop_rule_conditions and self.mode == MCRDRMode.StopPlusRule:
558
582
  conditions = self.stop_rule_conditions
559
583
  self.stop_rule_conditions = None
560
584
  else:
561
- conditions = expert.ask_for_conditions(case, target)
562
- self.add_top_rule(conditions, target, case)
585
+ conditions = expert.ask_for_conditions(case_query)
586
+ self.add_top_rule(conditions, case_query.target, case_query.case)
563
587
 
564
588
  def ask_expert_for_extra_conclusions(self, expert: Expert, case: Union[Case, SQLTable]) -> List[Any]:
565
589
  """
@@ -633,20 +657,31 @@ class GeneralRDR(RippleDownRules):
633
657
  gets called when the final rule fires.
634
658
  """
635
659
 
636
- def __init__(self, category_rdr_map: Optional[Dict[Type, Union[SingleClassRDR, MultiClassRDR]]] = None):
660
+ def __init__(self, category_rdr_map: Optional[Dict[str, Union[SingleClassRDR, MultiClassRDR]]] = None):
637
661
  """
638
- :param category_rdr_map: A map of categories to ripple down rules classifiers,
662
+ :param category_rdr_map: A map of case attribute names to ripple down rules classifiers,
639
663
  where each category is a parent category that has a set of mutually exclusive (in case of SCRDR) child
640
- categories, e.g. {Species: SCRDR, Habitat: MCRDR}, where Species and Habitat are parent categories and SCRDR
641
- and MCRDR are SingleClass and MultiClass ripple down rules classifiers. Species can have child categories like
642
- Mammal, Bird, Fish, etc. which are mutually exclusive, and Habitat can have child categories like
643
- Land, Water, Air, etc, which are not mutually exclusive due to some animals living more than one habitat.
664
+ categories, e.g. {'species': SCRDR, 'habitats': MCRDR}, where 'species' and 'habitats' are attribute names
665
+ for a case of type Animal, while SCRDR and MCRDR are SingleClass and MultiClass ripple down rules classifiers.
666
+ Species can have values like Mammal, Bird, Fish, etc. which are mutually exclusive, while Habitat can have
667
+ values like Land, Water, Air, etc., which are not mutually exclusive due to some animals living more than one
668
+ habitat.
644
669
  """
645
- self.start_rules_dict: Dict[Type, Union[SingleClassRDR, MultiClassRDR]] \
670
+ self.start_rules_dict: Dict[str, Union[SingleClassRDR, MultiClassRDR]] \
646
671
  = category_rdr_map if category_rdr_map else {}
647
672
  super(GeneralRDR, self).__init__()
648
673
  self.all_figs: List[plt.Figure] = [sr.fig for sr in self.start_rules_dict.values()]
649
674
 
675
+ def add_rdr(self, rdr: Union[SingleClassRDR, MultiClassRDR], attribute_name: Optional[str] = None):
676
+ """
677
+ Add a ripple down rules classifier to the map of classifiers.
678
+
679
+ :param rdr: The ripple down rules classifier to add.
680
+ :param attribute_name: The name of the attribute that the classifier is classifying.
681
+ """
682
+ attribute_name = attribute_name if attribute_name else rdr.attribute_name
683
+ self.start_rules_dict[attribute_name] = rdr
684
+
650
685
  @property
651
686
  def start_rule(self) -> Optional[Union[SingleClassRule, MultiClassTopRule]]:
652
687
  return self.start_rules[0] if self.start_rules_dict else None
@@ -654,7 +689,7 @@ class GeneralRDR(RippleDownRules):
654
689
  @start_rule.setter
655
690
  def start_rule(self, value: Union[SingleClassRDR, MultiClassRDR]):
656
691
  if value:
657
- self.start_rules_dict[type(value.start_rule.conclusion)] = value
692
+ self.start_rules_dict[value.attribute_name] = value
658
693
 
659
694
  @property
660
695
  def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
@@ -668,21 +703,44 @@ class GeneralRDR(RippleDownRules):
668
703
  :param case: The case to classify.
669
704
  :return: The categories that the case belongs to.
670
705
  """
671
- conclusions = []
706
+ return self._classify(self.start_rules_dict, case)
707
+
708
+ @staticmethod
709
+ def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
710
+ case: Union[Case, SQLTable]) -> Optional[Dict[str, Any]]:
711
+ """
712
+ Classify a case by going through all classifiers and adding the categories that are classified,
713
+ and then restarting the classification until no more categories can be added.
714
+
715
+ :param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
716
+ :param case: The case to classify.
717
+ :return: The categories that the case belongs to.
718
+ """
719
+ conclusions = {}
672
720
  case_cp = copy_case(case)
673
721
  while True:
674
- added_attributes = False
675
- for cat_type, rdr in self.start_rules_dict.items():
676
- if self.case_has_conclusion(case_cp, cat_type):
677
- continue
722
+ new_conclusions = {}
723
+ for attribute_name, rdr in classifiers_dict.items():
678
724
  pred_atts = rdr.classify(case_cp)
679
- if pred_atts:
725
+ if pred_atts is None:
726
+ continue
727
+ if isinstance(rdr, SingleClassRDR):
728
+ if attribute_name not in conclusions or \
729
+ (attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
730
+ conclusions[attribute_name] = pred_atts
731
+ new_conclusions[attribute_name] = pred_atts
732
+ else:
680
733
  pred_atts = make_list(pred_atts)
681
- pred_atts = [p for p in pred_atts if p not in conclusions]
682
- added_attributes = True
683
- conclusions.extend(pred_atts)
684
- GeneralRDR.update_case(case_cp, pred_atts)
685
- if not added_attributes:
734
+ if attribute_name in conclusions:
735
+ pred_atts = [p for p in pred_atts if p not in conclusions[attribute_name]]
736
+ if len(pred_atts) > 0:
737
+ new_conclusions[attribute_name] = pred_atts
738
+ if attribute_name not in conclusions:
739
+ conclusions[attribute_name] = []
740
+ conclusions[attribute_name].extend(pred_atts)
741
+ if attribute_name in new_conclusions:
742
+ GeneralRDR.update_case(case_cp, new_conclusions)
743
+ if len(new_conclusions) == 0:
686
744
  break
687
745
  return conclusions
688
746
 
@@ -707,103 +765,79 @@ class GeneralRDR(RippleDownRules):
707
765
  case = case_queries[0].case
708
766
  assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
709
767
  " for multiple cases use fit instead")
710
- case_query_cp = copy(case_queries[0])
711
- case_cp = case_query_cp.case
768
+ case_cp = copy(case_queries[0]).case
712
769
  for case_query in case_queries:
713
- target = case_query.target
714
- if not target:
770
+ case_query_cp = copy(case_query)
771
+ case_query_cp.case = case_cp
772
+ if case_query.target is None:
773
+ conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
715
774
  target = expert.ask_for_conclusion(case_query)
716
- case_query_cp = CaseQuery(case_cp, attribute_name=case_query.attribute_name, target=target)
717
- if is_iterable(target) and not isinstance(target, CaseAttribute):
718
- target_type = type(make_list(target)[0])
719
- assert all([type(t) is target_type for t in target]), ("All targets of a case query must be of the same"
720
- " type")
721
- else:
722
- target_type = type(target)
723
- if target_type not in self.start_rules_dict:
775
+
776
+ if case_query.attribute_name not in self.start_rules_dict:
724
777
  conclusions = self.classify(case)
725
778
  self.update_case(case_cp, conclusions)
726
- new_rdr = self.initialize_new_rdr_for_attribute(target, case_cp)
779
+
780
+ new_rdr = self.initialize_new_rdr_for_attribute(case_query.attribute_name, case_cp, case_query.target)
781
+ self.add_rdr(new_rdr, case_query.attribute_name)
782
+
727
783
  new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
728
- self.start_rules_dict[target_type] = new_rdr
729
- self.update_case(case_cp, new_conclusions, target_type)
730
- elif not self.case_has_conclusion(case_cp, target_type):
731
- for rdr_type, rdr in self.start_rules_dict.items():
732
- if target_type is not rdr_type:
784
+ self.update_case(case_cp, {case_query.attribute_name: new_conclusions})
785
+ else:
786
+ for rdr_attribute_name, rdr in self.start_rules_dict.items():
787
+ if case_query.attribute_name != rdr_attribute_name:
733
788
  conclusions = rdr.classify(case_cp)
734
789
  else:
735
- conclusions = self.start_rules_dict[target_type].fit_case(case_query_cp,
736
- expert, **kwargs)
737
- self.update_case(case_cp, conclusions, rdr_type)
790
+ conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
791
+ **kwargs)
792
+ if conclusions is not None or (is_iterable(conclusions) and len(conclusions) > 0):
793
+ conclusions = {rdr_attribute_name: conclusions}
794
+ self.update_case(case_cp, conclusions)
738
795
 
739
796
  return self.classify(case)
740
797
 
741
798
  @staticmethod
742
- def initialize_new_rdr_for_attribute(attribute: Any, case: Union[Case, SQLTable]):
799
+ def initialize_new_rdr_for_attribute(attribute_name: str, case: Union[Case, SQLTable], target: Any):
743
800
  """
744
801
  Initialize the appropriate RDR type for the target.
745
802
  """
746
- if isinstance(case, SQLTable):
747
- prop = get_attribute_by_type(case, type(attribute))
748
- if hasattr(prop, "__iter__") and not isinstance(prop, str):
749
- return MultiClassRDR()
750
- else:
751
- return SingleClassRDR()
752
- elif isinstance(attribute, CaseAttribute):
803
+ attribute = getattr(case, attribute_name) if hasattr(case, attribute_name) else target
804
+ if isinstance(attribute, CaseAttribute):
753
805
  return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
754
806
  else:
755
- return MultiClassRDR() if is_iterable(attribute) else SingleClassRDR()
807
+ return MultiClassRDR() if is_iterable(attribute) or (attribute is None) else SingleClassRDR()
756
808
 
757
809
  @staticmethod
758
- def update_case(case: Union[Case, SQLTable],
759
- conclusions: List[Any], attribute_type: Optional[Any] = None):
810
+ def update_case(case: Union[Case, SQLTable], conclusions: Dict[str, Any]):
760
811
  """
761
812
  Update the case with the conclusions.
762
813
 
763
814
  :param case: The case to update.
764
815
  :param conclusions: The conclusions to update the case with.
765
- :param attribute_type: The type of the attribute to update.
766
816
  """
767
817
  if not conclusions:
768
818
  return
769
- conclusions = [conclusions] if not isinstance(conclusions, list) else list(conclusions)
770
819
  if len(conclusions) == 0:
771
820
  return
772
821
  if isinstance(case, SQLTable):
773
- conclusions_type = type(conclusions[0]) if not attribute_type else attribute_type
774
- attr_name, attribute = get_attribute_by_type(case, conclusions_type)
775
- hint, origin, args = get_hint_for_attribute(attr_name, case)
776
- if isinstance(attribute, set) or origin == set:
777
- attribute = set() if attribute is None else attribute
778
- for c in conclusions:
779
- attribute.update(make_set(c))
780
- elif isinstance(attribute, list) or origin == list:
781
- attribute = [] if attribute is None else attribute
782
- attribute.extend(conclusions)
783
- elif len(conclusions) == 1 and hint == conclusions_type:
784
- setattr(case, attr_name, conclusions.pop())
785
- else:
786
- raise ValueError(f"Cannot add multiple conclusions to attribute {attr_name}")
822
+ for conclusion_name, conclusion in conclusions.items():
823
+ hint, origin, args = get_hint_for_attribute(conclusion_name, case)
824
+ attribute = getattr(case, conclusion_name)
825
+ if isinstance(attribute, set) or origin in {Set, set}:
826
+ attribute = set() if attribute is None else attribute
827
+ for c in conclusion:
828
+ attribute.update(make_set(c))
829
+ elif isinstance(attribute, list) or origin in {list, List}:
830
+ attribute = [] if attribute is None else attribute
831
+ attribute.extend(conclusion)
832
+ elif (not is_iterable(conclusion) or (len(conclusion) == 1)) and hint == type(conclusion):
833
+ setattr(case, conclusion_name, conclusion)
834
+ else:
835
+ raise ValueError(f"Cannot add multiple conclusions to attribute {conclusion_name}")
787
836
  else:
788
- for c in make_set(conclusions):
789
- case.update(c.as_dict)
790
-
791
- @property
792
- def names_of_all_types(self) -> List[str]:
793
- """
794
- Get the names of all the types of categories that the GRDR can classify.
795
- """
796
- return [t.__name__ for t in self.start_rules_dict.keys()]
797
-
798
- @property
799
- def all_types(self) -> List[Type]:
800
- """
801
- Get all the types of categories that the GRDR can classify.
802
- """
803
- return list(self.start_rules_dict.keys())
837
+ case.update(conclusions)
804
838
 
805
839
  def _to_json(self) -> Dict[str, Any]:
806
- return {"start_rules": {get_full_class_name(t): rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
840
+ return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
807
841
 
808
842
  @classmethod
809
843
  def _from_json(cls, data: Dict[str, Any]) -> GeneralRDR:
@@ -812,6 +846,78 @@ class GeneralRDR(RippleDownRules):
812
846
  """
813
847
  start_rules_dict = {}
814
848
  for k, v in data["start_rules"].items():
815
- k = get_type_from_string(k)
816
849
  start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
817
850
  return cls(start_rules_dict)
851
+
852
+ def write_to_python_file(self, file_path: str):
853
+ """
854
+ Write the tree of rules as source code to a file.
855
+
856
+ :param file_path: The path to the directory in which the generated Python source file is written.
857
+ """
858
+ for rdr in self.start_rules_dict.values():
859
+ rdr.write_to_python_file(file_path)
860
+ func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
861
+ with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
862
+ f.write(self._get_imports(file_path) + "\n\n")
863
+ f.write("classifiers_dict = dict()\n")
864
+ for rdr_key, rdr in self.start_rules_dict.items():
865
+ f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
866
+ f.write("\n\n")
867
+ f.write(func_def)
868
+ f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
869
+ f"{' ' * 4} case = create_case(case, recursion_idx=3)\n""")
870
+ f.write(f"{' ' * 4}return GeneralRDR._classify(classifiers_dict, case)\n")
871
+
872
+ @property
873
+ def case_type(self) -> Type:
874
+ """
875
+ :return: The type of the case (input) to the RDR classifier.
876
+ """
877
+ if isinstance(self.start_rule.corner_case, Case):
878
+ return self.start_rule.corner_case._obj_type
879
+ else:
880
+ return type(self.start_rule.corner_case)
881
+
882
+ def get_rdr_classifier_from_python_file(self, file_path: str):
883
+ """
884
+ :param file_path: The path to the directory (package) that contains the generated RDR classifier file.
885
+ :return: The classify function defined in the generated RDR classifier module.
886
+ """
887
+ return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
888
+
889
+ @property
890
+ def generated_python_file_name(self) -> str:
891
+ return f"{self.case_type.__name__.lower()}_grdr"
892
+
893
+ @property
894
+ def conclusion_type_hint(self) -> str:
895
+ return f"List[Union[{', '.join([rdr.conclusion_type_hint for rdr in self.start_rules_dict.values()])}]]"
896
+
897
+ def _get_imports(self, file_path: str) -> str:
898
+ imports = ""
899
+ # add type hints
900
+ imports += f"from typing_extensions import List, Union, Set\n"
901
+ # import rdr type
902
+ imports += f"from ripple_down_rules.rdr import GeneralRDR\n"
903
+ # add case type
904
+ imports += f"from ripple_down_rules.datastructures import Case, create_case\n"
905
+ imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
906
+ # add conclusion type imports
907
+ for rdr in self.start_rules_dict.values():
908
+ imports += f"from {rdr.conclusion_type.__module__} import {rdr.conclusion_type.__name__}\n"
909
+ # add rdr python generated functions.
910
+ for rdr_key, rdr in self.start_rules_dict.items():
911
+ imports += (f"from {file_path.strip('./')}"
912
+ f" import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}\n")
913
+ return imports
914
+
915
+ @staticmethod
916
+ def rdr_key_to_function_name(rdr_key: str) -> str:
917
+ """
918
+ Convert the RDR key to a function name.
919
+
920
+ :param rdr_key: The RDR key to convert.
921
+ :return: The function name.
922
+ """
923
+ return rdr_key.replace(".", "_").lower() + "_classifier"