ripple-down-rules 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,687 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from copy import copy
5
+
6
+ from matplotlib import pyplot as plt
7
+ from ordered_set import OrderedSet
8
+ from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
9
+ from typing_extensions import List, Optional, Dict, Type, Union, Any, Self
10
+
11
+ from .datastructures import Case, MCRDRMode, CallableExpression, Column, CaseQuery
12
+ from .experts import Expert, Human
13
+ from .rules import Rule, SingleClassRule, MultiClassTopRule
14
+ from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
15
+ get_hint_for_attribute, SubclassJSONSerializer
16
+
17
+
18
class RippleDownRules(ABC):
    """
    The abstract base class for the ripple down rules classifiers.

    Subclasses implement :meth:`classify` and :meth:`fit_case`; this base class provides batch fitting
    (:meth:`fit`), tree drawing (:meth:`update_figures`) and a helper that checks whether a case already
    carries a conclusion of a given type (:meth:`case_has_conclusion`).
    """
    # The figure to draw the tree on.
    fig: Optional[plt.Figure] = None
    # The conclusions that the expert has accepted, such that they are not asked again.
    expert_accepted_conclusions: Optional[List[Column]] = None

    def __init__(self, start_rule: Optional[Rule] = None, session: Optional[Session] = None):
        """
        :param start_rule: The starting rule for the classifier.
        :param session: The sqlalchemy orm session.
        """
        self.start_rule = start_rule
        self.session = session
        self.fig = None

    def __call__(self, case: Union[Case, SQLTable]) -> Column:
        """
        Shortcut for :meth:`classify`.
        """
        return self.classify(case)

    @abstractmethod
    def classify(self, case: Union[Case, SQLTable]) -> Optional[Column]:
        """
        Classify a case.

        :param case: The case to classify.
        :return: The category that the case belongs to.
        """
        pass

    @abstractmethod
    def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
            -> Union[Column, CallableExpression]:
        """
        Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
        comparing the case with the target category.

        :param case_query: The query containing the case to classify and the target category to compare the case with.
        :param expert: The expert to ask for differentiating features as new rule conditions.
        :return: The category that the case belongs to.
        """
        pass

    def fit(self, case_queries: List[CaseQuery],
            expert: Optional[Expert] = None,
            n_iter: int = None,
            animate_tree: bool = False,
            **kwargs_for_fit_case):
        """
        Fit the classifier to a batch of cases and categories.

        Every iteration calls :meth:`fit_case` on each query. Iteration stops once a whole pass classifies every
        case correctly, or once ``n_iter`` passes have run (when ``n_iter`` is truthy). Mismatches are printed per
        case. At the end, recall, precision and accuracy of the last pass are printed.

        :param case_queries: The cases and categories to fit the classifier to.
        :param expert: The expert to ask for differentiating features as new rule conditions. When None and a query
            has no target, a :class:`Human` expert bound to this classifier's session is created.
        :param n_iter: The number of iterations to fit the classifier for.
        :param animate_tree: Whether to draw the tree while fitting the classifier.
        :param kwargs_for_fit_case: The keyword arguments to pass to the fit_case method.
        """
        if not case_queries:
            # Nothing to fit. Without this guard the loop below never terminates when n_iter is None, and the
            # metric summary divides by zero.
            return
        if animate_tree:
            plt.ion()
        i = 0
        num_rules: int = 0
        all_pred = 0
        all_recall: List[bool] = []
        all_precision: List[bool] = []
        while True:
            all_pred = 0
            all_recall = []
            all_precision = []
            for case_query in case_queries:
                case = case_query.case
                target = case_query.target
                if not target:
                    if expert is None:
                        # Same default expert as the subclasses' fit_case; avoids calling a method on None.
                        expert = Human(session=self.session)
                    conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
                    target = expert.ask_for_conclusion(case_query, conclusions)
                pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
                pred_cat = pred_cat if isinstance(pred_cat, list) else [pred_cat]
                target = target if isinstance(target, list) else [target]
                # Every (non-empty) target must be predicted ...
                recall = [not yi or (yi in pred_cat) for yi in target]
                # ... and every prediction of a target's type must be one of the targets.
                y_type = [type(yi) for yi in target]
                precision = [(pred in target) or (type(pred) not in y_type) for pred in pred_cat]
                match = all(recall) and all(precision)
                all_recall.extend(recall)
                all_precision.extend(precision)
                if not match:
                    print(f"Predicted: {pred_cat} but expected: {target}")
                all_pred += int(match)
                if animate_tree and self.start_rule.size > num_rules:
                    num_rules = self.start_rule.size
                    self.update_figures()
            i += 1
            all_predicted = all_pred == len(case_queries)
            num_iter_reached = n_iter and i >= n_iter
            if all_predicted or num_iter_reached:
                break
        recall_value = sum(all_recall) / len(all_recall) if all_recall else float("nan")
        precision_value = sum(all_precision) / len(all_precision) if all_precision else float("nan")
        print(f"Recall: {recall_value}")
        print(f"Precision: {precision_value}")
        # Accuracy is the number of correctly classified cases in the last pass over all cases.
        print(f"Accuracy: {all_pred}/{len(case_queries)}")
        print(f"Finished training in {i} iterations")
        if animate_tree:
            plt.ioff()
            plt.show()

    def update_figures(self):
        """
        Update the figures of the classifier.

        For a :class:`GeneralRDR`, one figure per category RDR is (lazily) created and redrawn. Otherwise a single
        figure is used for the starting rule's tree.
        """
        if isinstance(self, GeneralRDR):
            for i, (_type, rdr) in enumerate(self.start_rules_dict.items()):
                if not rdr.fig:
                    rdr.fig = plt.figure(f"Rule {i}: {_type.__name__}")
                draw_tree(rdr.start_rule, rdr.fig)
        else:
            if not self.fig:
                self.fig = plt.figure(0)
            draw_tree(self.start_rule, self.fig)

    @staticmethod
    def case_has_conclusion(case: Union[Case, SQLTable], conclusion_type: Type) -> bool:
        """
        Check if the case has a conclusion.

        :param case: The case to check.
        :param conclusion_type: The target category type to compare the case with.
        :return: Whether the case has a conclusion or not. For SQL tables, an iterable (non-string) attribute counts
            when non-empty and any other attribute counts when not None. For other cases, membership of the type is
            checked.
        """
        if isinstance(case, SQLTable):
            prop_name, prop_value = get_attribute_by_type(case, conclusion_type)
            if hasattr(prop_value, "__iter__") and not isinstance(prop_value, str):
                return len(prop_value) > 0
            else:
                return prop_value is not None
        else:
            return conclusion_type in case
159
+
160
+
161
# Short alias for the abstract RippleDownRules base class.
RDR = RippleDownRules
162
+
163
+
164
class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):
    """
    A single class ripple down rules classifier: a case receives at most one conclusion. That conclusion belongs
    to the rule obtained by evaluating the starting rule on the case, and only if that rule fired.
    """

    def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
            -> Union[Column, CallableExpression]:
        """
        Classify a case, and ask the user for refinements or alternatives if the classification is incorrect by
        comparing the case with the target category if provided.

        :param case_query: The case to classify and the target category to compare the case with.
        :param expert: The expert to ask for differentiating features as new rule conditions.
        :return: The category that the case belongs to.
        """
        expert = expert if expert else Human(session=self.session)
        case = case_query.case
        if case_query.target is None:
            target = expert.ask_for_conclusion(case_query)
        else:
            target = case_query.target

        if not self.start_rule:
            conditions = expert.ask_for_conditions(case, [target])
            self.start_rule = SingleClassRule(conditions, target, corner_case=case)

        pred = self.evaluate(case)

        # The case is misclassified when the returned rule concludes something else, and also when it did not fire
        # at all: classify() then returns None even if that rule's conclusion happens to equal the target.
        if not pred.fired or pred.conclusion != target:
            conditions = expert.ask_for_conditions(case, [target], pred)
            pred.fit_rule(case, target, conditions=conditions)

        return self.classify(case)

    def classify(self, case: Case) -> Optional[Column]:
        """
        Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.

        :return: The conclusion of the evaluated rule if it fired, else None.
        """
        pred = self.evaluate(case)
        return pred.conclusion if pred.fired else None

    def evaluate(self, case: Case) -> SingleClassRule:
        """
        Evaluate the starting rule on a case.

        :return: The rule returned by the starting rule's evaluation, or the starting rule itself if that is falsy.
        """
        matched_rule = self.start_rule(case)
        return matched_rule if matched_rule else self.start_rule

    def write_to_python_file(self, filename: str):
        """
        Write the tree of rules as source code to a file.

        The generated file contains the imports needed for the case type and the conclusion type(s), followed by a
        ``classify_<type names>`` function whose body is the rule tree as nested if-statements.

        :param filename: The path of the file to (over)write.
        """
        case_type = self.start_rule.corner_case.__class__.__name__
        case_module = self.start_rule.corner_case.__class__.__module__
        conclusion = self.start_rule.conclusion
        if isinstance(conclusion, CallableExpression):
            conclusion_types = [conclusion.conclusion_type]
        elif isinstance(conclusion, Column):
            conclusion_types = list(conclusion._value_range)
        else:
            conclusion_types = [type(conclusion)]
        type_names = [c.__name__ for c in conclusion_types]
        imports = ""
        if case_module != "builtins":
            imports += f"from {case_module} import {case_type}\n"
        if len(conclusion_types) > 1:
            # The return annotation uses Union, so the generated file must import it.
            imports += "from typing import Union\n"
            conclusion_name = "Union[" + ", ".join(type_names) + "]"
        else:
            conclusion_name = type_names[0]
        for conclusion_type in conclusion_types:
            if conclusion_type.__module__ != "builtins":
                # Import each concrete type by its own name; conclusion_name may be a Union[...] annotation.
                imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
        imports += "\n\n"
        # Build the function name from the bare type names so it is a valid identifier even for unions.
        func_name = "classify_" + "_".join(type_names).lower()
        func_def = f"def {func_name}(case: {case_type}) -> {conclusion_name}:\n"
        with open(filename, "w") as f:
            f.write(imports)
            f.write(func_def)
            self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)

    def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file, parent_indent: str = ""):
        """
        Write the rules as source code to a file.

        A rule with conditions is written as an if-clause. Its refinement is nested one indentation level deeper,
        then its conclusion is written, and its alternative follows at the same indentation level.

        :param rule: The rule to write.
        :param file: An open, writable text file.
        :param parent_indent: The indentation prefix for this rule's code.
        """
        if rule.conditions:
            file.write(rule.write_condition_as_source_code(parent_indent))
            if rule.refinement:
                # One indentation level, matching the 4-space function-body indent used by write_to_python_file.
                self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " " * 4)

            file.write(rule.write_conclusion_as_source_code(parent_indent))

            if rule.alternative:
                self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)

    def to_json(self) -> Dict[str, Any]:
        """
        Serialize the classifier, including its starting rule tree, to a json-compatible dict.
        """
        return {**SubclassJSONSerializer.to_json(self), "start_rule": self.start_rule.to_json()}

    @classmethod
    def _from_json(cls, data: Dict[str, Any]) -> Self:
        """
        Create an instance of the class from a json
        """
        start_rule = SingleClassRule.from_json(data["start_rule"])
        return cls(start_rule)
263
+
264
+
265
class MultiClassRDR(RippleDownRules):
    """
    A multi class ripple down rules classifier, which can draw multiple conclusions for a case.
    This is done by going through all rules and checking if they fire or not, and adding stopping rules if needed,
    when wrong conclusions are made to stop these rules from firing again for similar cases.
    """
    # The evaluated rules in the classifier for one case.
    evaluated_rules: Optional[List[Rule]] = None
    # The conclusions that the case belongs to.
    conclusions: Optional[List[Column]] = None
    # The conditions of the stopping rule if needed.
    stop_rule_conditions: Optional[CallableExpression] = None

    def __init__(self, start_rules: Optional[List[Rule]] = None,
                 mode: MCRDRMode = MCRDRMode.StopOnly, session: Optional[Session] = None):
        """
        :param start_rules: The starting rules for the classifier, these are the rules that are at the top of the tree
        and are always checked, in contrast to the refinement and alternative rules which are only checked if the
        starting rules fire or not.
        :param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
        :param session: The sqlalchemy orm session.
        """
        self.start_rules = [MultiClassTopRule()] if not start_rules else start_rules
        super(MultiClassRDR, self).__init__(self.start_rules[0], session=session)
        self.mode: MCRDRMode = mode

    def classify(self, case: Union[Case, SQLTable]) -> List[Any]:
        """
        Walk the rule chain from the starting rule, collecting the conclusion of every rule that fires.

        :param case: The case to classify.
        :return: The collected conclusions (the same list object as ``self.conclusions``).
        """
        evaluated_rule = self.start_rule
        self.conclusions = []
        while evaluated_rule:
            next_rule = evaluated_rule(case)
            if evaluated_rule.fired:
                self.add_conclusion(evaluated_rule)
            evaluated_rule = next_rule
        return self.conclusions

    def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None,
                 add_extra_conclusions: bool = False, **kwargs) -> List[Union[Column, CallableExpression]]:
        """
        Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
        or missing by comparing the case with the target category if provided.

        :param case_query: The query containing the case to classify and the target category to compare the case with.
        :param expert: The expert to ask for differentiating features as new rule conditions or for extra conclusions.
        :param add_extra_conclusions: Whether to add extra conclusions after classification is done.
        :param kwargs: Accepted for compatibility with the RippleDownRules.fit_case signature (fit and GeneralRDR
            forward their extra keyword arguments); unused here.
        :return: The conclusions that the case belongs to.
        """
        expert = expert if expert else Human(session=self.session)
        case = case_query.case
        if case_query.target is None:
            targets = [expert.ask_for_conclusion(case_query)]
        else:
            targets = [case_query.target]
        self.expert_accepted_conclusions = []
        user_conclusions = []
        for target in targets:
            self.update_start_rule(case, target, expert)
            self.conclusions = []
            self.stop_rule_conditions = None
            evaluated_rule = self.start_rule
            while evaluated_rule:
                next_rule = evaluated_rule(case)
                good_conclusions = targets + user_conclusions + self.expert_accepted_conclusions

                if evaluated_rule.fired:
                    if target and evaluated_rule.conclusion not in good_conclusions:
                        # Rule fired and conclusion is different from target
                        self.stop_wrong_conclusion_else_add_it(case, target, expert, evaluated_rule,
                                                               add_extra_conclusions)
                    else:
                        # Rule fired and target is correct or there is no target to compare
                        self.add_conclusion(evaluated_rule)

                if not next_rule:
                    if target not in self.conclusions:
                        # Nothing fired and there is a target that should have been in the conclusions
                        self.add_rule_for_case(case, target, expert)
                        # Have to check all rules again to make sure only this new rule fires
                        next_rule = self.start_rule
                    elif add_extra_conclusions and not user_conclusions:
                        # No more conclusions can be made, ask the expert for extra conclusions if needed.
                        user_conclusions.extend(self.ask_expert_for_extra_conclusions(expert, case))
                        if user_conclusions:
                            # Only the newly added top rules still need to be evaluated.
                            next_rule = self.last_top_rule
                evaluated_rule = next_rule
        return self.conclusions

    def update_start_rule(self, case: Union[Case, SQLTable], target: Any, expert: Expert):
        """
        Update the starting rule of the classifier.

        The starting rule is only initialised (conditions, conclusion and corner case) while it has no conditions.

        :param case: The case to classify.
        :param target: The target category to compare the case with.
        :param expert: The expert to ask for differentiating features as new rule conditions.
        """
        if not self.start_rule.conditions:
            conditions = expert.ask_for_conditions(case, target)
            self.start_rule.conditions = conditions
            self.start_rule.conclusion = target
            self.start_rule.corner_case = case

    @property
    def last_top_rule(self) -> Optional[MultiClassTopRule]:
        """
        Get the last top rule in the tree.
        """
        if not self.start_rule.furthest_alternative:
            return self.start_rule
        else:
            return self.start_rule.furthest_alternative[-1]

    def stop_wrong_conclusion_else_add_it(self, case: Union[Case, SQLTable], target: Any, expert: Expert,
                                          evaluated_rule: MultiClassTopRule,
                                          add_extra_conclusions: bool):
        """
        Stop a wrong conclusion by adding a stopping rule.

        A conclusion of the same category type that conflicts with the target is always stopped. Otherwise it is
        stopped unless the expert accepts it as correct.
        """
        if self.is_same_category_type(evaluated_rule.conclusion, target) \
                and self.is_conflicting_with_target(evaluated_rule.conclusion, target):
            self.stop_conclusion(case, target, expert, evaluated_rule)
        elif not self.conclusion_is_correct(case, target, expert, evaluated_rule, add_extra_conclusions):
            self.stop_conclusion(case, target, expert, evaluated_rule)

    def stop_conclusion(self, case: Union[Case, SQLTable], target: Any,
                        expert: Expert, evaluated_rule: MultiClassTopRule):
        """
        Stop a conclusion by adding a stopping rule.

        :param case: The case to classify.
        :param target: The target category to compare the case with.
        :param expert: The expert to ask for differentiating features as new rule conditions.
        :param evaluated_rule: The evaluated rule to ask the expert about.
        """
        conditions = expert.ask_for_conditions(case, target, evaluated_rule)
        evaluated_rule.fit_rule(case, target, conditions=conditions)
        if self.mode == MCRDRMode.StopPlusRule:
            # Remembered so add_rule_for_case can reuse them instead of asking the expert again.
            self.stop_rule_conditions = conditions
        if self.mode == MCRDRMode.StopPlusRuleCombined:
            new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
            self.add_top_rule(new_top_rule_conditions, target, case)

    @staticmethod
    def is_conflicting_with_target(conclusion: Any, target: Any) -> bool:
        """
        Check if the conclusion is conflicting with the target category.

        :param conclusion: The conclusion to check.
        :param target: The target category to compare the conclusion with.
        :return: Whether the conclusion is conflicting with the target category.
        """
        if hasattr(conclusion, "mutually_exclusive") and conclusion.mutually_exclusive:
            return True
        else:
            return not make_set(conclusion).issubset(make_set(target))

    @staticmethod
    def is_same_category_type(conclusion: Any, target: Any) -> bool:
        """
        Check if the conclusion is of the same class as the target category.

        :param conclusion: The conclusion to check.
        :param target: The target category to compare the conclusion with.
        :return: Whether the conclusion is of the same class as the target category (excluding plain Column).
        """
        return conclusion.__class__ == target.__class__ and target.__class__ != Column

    def conclusion_is_correct(self, case: Union[Case, SQLTable], target: Any, expert: Expert, evaluated_rule: Rule,
                              add_extra_conclusions: bool) -> bool:
        """
        Ask the expert if the conclusion is correct, and add it to the conclusions if it is.

        :param case: The case to classify.
        :param target: The target category to compare the case with.
        :param expert: The expert to ask for differentiating features as new rule conditions.
        :param evaluated_rule: The evaluated rule to ask the expert about.
        :param add_extra_conclusions: Whether adding extra conclusions after classification is allowed.
        :return: Whether the conclusion is correct or not.
        """
        conclusions = list(OrderedSet(self.conclusions))
        if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case, evaluated_rule.conclusion,
                                                                          targets=target,
                                                                          current_conclusions=conclusions)):
            self.add_conclusion(evaluated_rule)
            self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
            return True
        return False

    def add_rule_for_case(self, case: Union[Case, SQLTable], target: Any, expert: Expert):
        """
        Add a rule for a case that has not been classified with any conclusion.

        In StopPlusRule mode, pending stopping-rule conditions are reused (and consumed) instead of asking the expert.
        """
        if self.stop_rule_conditions and self.mode == MCRDRMode.StopPlusRule:
            conditions = self.stop_rule_conditions
            self.stop_rule_conditions = None
        else:
            conditions = expert.ask_for_conditions(case, target)
        self.add_top_rule(conditions, target, case)

    def ask_expert_for_extra_conclusions(self, expert: Expert, case: Union[Case, SQLTable]) -> List[Any]:
        """
        Ask the expert for extra conclusions when no more conclusions can be made.

        :param expert: The expert to ask for extra conclusions.
        :param case: The case to ask extra conclusions for.
        :return: The extra conclusions that the expert has provided.
        """
        extra_conclusions = []
        conclusions = list(OrderedSet(self.conclusions))
        if not expert.use_loaded_answers:
            print("current conclusions:", conclusions)
        extra_conclusions_dict = expert.ask_for_extra_conclusions(case, conclusions)
        if extra_conclusions_dict:
            for conclusion, conditions in extra_conclusions_dict.items():
                self.add_top_rule(conditions, conclusion, case)
                extra_conclusions.append(conclusion)
        return extra_conclusions

    def add_conclusion(self, evaluated_rule: Rule) -> None:
        """
        Add the conclusion of the evaluated rule to the list of conclusions.

        A conclusion of a type not yet present is appended. Otherwise it is merged with all existing conclusions of
        the same type, which are replaced by the individual elements of the merged set.

        :param evaluated_rule: The evaluated rule to add the conclusion of.
        """
        conclusion = evaluated_rule.conclusion
        conclusion_types = [type(c) for c in self.conclusions]
        if type(conclusion) not in conclusion_types:
            self.conclusions.append(conclusion)
            return
        same_type_conclusions = [c for c in self.conclusions if type(c) == type(conclusion)]
        # Copy a set conclusion so that merging does not mutate the conclusion stored on the rule itself.
        combined_conclusion = set(conclusion) if isinstance(conclusion, set) else {conclusion}
        for c in same_type_conclusions:
            combined_conclusion.update(c if isinstance(c, set) else make_set(c))
            self.conclusions.remove(c)
        self.conclusions.extend(combined_conclusion)

    def add_top_rule(self, conditions: CallableExpression, conclusion: Any, corner_case: Union[Case, SQLTable]):
        """
        Add a top rule to the classifier, which is a rule that is always checked and is part of the start_rules list.

        :param conditions: The conditions of the rule.
        :param conclusion: The conclusion of the rule.
        :param corner_case: The corner case of the rule.
        """
        self.start_rule.alternative = MultiClassTopRule(conditions, conclusion, corner_case=corner_case)
516
+
517
+
518
class GeneralRDR(RippleDownRules):
    """
    A general ripple down rules classifier, which can draw multiple conclusions for a case, but each conclusion is part
    of a set of mutually exclusive conclusions. Whenever a conclusion is made, the classification restarts from the
    starting rule, and all the rules that belong to the class of the made conclusion are not checked again. This
    continues until no more rules can be fired. In addition, previous conclusions can be used as conditions or input to
    the next classification/cycle.
    Another possible mode is to have rules that are considered final, when fired, inference will not be restarted,
    and only a refinement can be made to the final rule, those can also be used in another SCRDR of their own that
    gets called when the final rule fires.
    """

    def __init__(self, category_rdr_map: Optional[Dict[Type, Union[SingleClassRDR, MultiClassRDR]]] = None):
        """
        :param category_rdr_map: A map of categories to ripple down rules classifiers,
        where each category is a parent category that has a set of mutually exclusive (in case of SCRDR) child
        categories, e.g. {Species: SCRDR, Habitat: MCRDR}, where Species and Habitat are parent categories and SCRDR
        and MCRDR are SingleClass and MultiClass ripple down rules classifiers. Species can have child categories like
        Mammal, Bird, Fish, etc. which are mutually exclusive, and Habitat can have child categories like
        Land, Water, Air, etc, which are not mutually exclusive due to some animals living more than one habitat.
        """
        # Must exist before the base __init__ runs, because it assigns through the start_rule property setter.
        self.start_rules_dict: Dict[Type, Union[SingleClassRDR, MultiClassRDR]] \
            = category_rdr_map if category_rdr_map else {}
        super(GeneralRDR, self).__init__()
        self.all_figs: List[plt.Figure] = [sr.fig for sr in self.start_rules_dict.values()]

    @property
    def start_rule(self) -> Optional[Union[SingleClassRule, MultiClassTopRule]]:
        """
        The starting rule of the first category RDR, or None when there are no category RDRs.
        """
        return self.start_rules[0] if self.start_rules_dict else None

    @start_rule.setter
    def start_rule(self, value: Union[SingleClassRDR, MultiClassRDR]):
        # Registers an RDR under the type of its starting rule's conclusion; falsy values (e.g. None) are ignored.
        if value:
            self.start_rules_dict[type(value.start_rule.conclusion)] = value

    @property
    def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
        """
        The starting rules of all category RDRs, in dict order.
        """
        return [rdr.start_rule for rdr in self.start_rules_dict.values()]

    def classify(self, case: Union[Case, SQLTable]) -> Optional[List[Any]]:
        """
        Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
        the classification until no more categories can be added.

        The input case is not modified; a copy is updated with the conclusions between passes.

        :param case: The case to classify.
        :return: The categories that the case belongs to.
        """
        conclusions = []
        case_cp = copy_case(case)
        while True:
            added_attributes = False
            for cat_type, rdr in self.start_rules_dict.items():
                if self.case_has_conclusion(case_cp, cat_type):
                    continue
                pred_atts = rdr.classify(case_cp)
                if not pred_atts:
                    continue
                pred_atts = pred_atts if isinstance(pred_atts, list) else [pred_atts]
                new_atts = [p for p in pred_atts if p not in conclusions]
                if not new_atts:
                    # Only already-known conclusions: the case copy is unchanged, so this must not force another
                    # pass (that would repeat the same result forever).
                    continue
                added_attributes = True
                conclusions.extend(new_atts)
                self.update_case_with_same_type_conclusions(case_cp, new_atts)
            if not added_attributes:
                break
        return conclusions

    def fit_case(self, case_queries: List[CaseQuery], expert: Optional[Expert] = None, **kwargs)\
            -> List[Union[Column, CallableExpression]]:
        """
        Fit the GRDR on a case, if the target is a new type of category, a new RDR is created for it,
        else the existing RDR of that type will be fitted on the case, and then classification is done and all
        concluded categories are returned. If the category is mutually exclusive, an SCRDR is created, else an MCRDR.
        In case of SCRDR, multiple conclusions of the same type replace each other, in case of MCRDR, they are added if
        they are accepted by the expert, and the attribute of that category is represented in the case as a set of
        values.

        :param case_queries: The queries containing the case to classify and the target categories to compare the case
         with.
        :param expert: The expert to ask for differentiating features as new rule conditions.
        :return: The categories that the case belongs to.
        """
        expert = expert if expert else Human()
        case_queries = [case_queries] if not isinstance(case_queries, list) else case_queries
        assert len(case_queries) > 0, "No case queries provided"
        case = case_queries[0].case
        assert all(case is case_query.case for case_query in case_queries), ("fit_case requires only one case,"
                                                                             " for multiple cases use fit instead")
        case_query_cp = copy(case_queries[0])
        case_cp = case_query_cp.case
        for case_query in case_queries:
            target = case_query.target
            if not target:
                target = expert.ask_for_conclusion(case_query)
            case_query_cp = CaseQuery(case_cp, attribute_name=case_query.attribute_name, target=target)
            if type(target) not in self.start_rules_dict:
                # New category type: expose the existing conclusions on the case, then create and fit a new RDR.
                conclusions = self.classify(case)
                self.update_case_with_same_type_conclusions(case_cp, conclusions)
                new_rdr = self.initialize_new_rdr_for_attribute(target, case_cp)
                new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
                self.start_rules_dict[type(target)] = new_rdr
                self.update_case_with_same_type_conclusions(case_cp, new_conclusions, type(target))
            elif not self.case_has_conclusion(case_cp, type(target)):
                for rdr_type, rdr in self.start_rules_dict.items():
                    if type(target) is not rdr_type:
                        conclusions = rdr.classify(case_cp)
                    else:
                        conclusions = rdr.fit_case(case_query_cp, expert, **kwargs)
                    self.update_case_with_same_type_conclusions(case_cp, conclusions, rdr_type)

        return self.classify(case)

    @staticmethod
    def initialize_new_rdr_for_attribute(attribute: Any, case: Union[Case, SQLTable]):
        """
        Initialize the appropriate RDR type for the target.

        For SQL tables, a collection-valued (iterable, non-string) attribute of the target's type gets an MCRDR and
        anything else gets an SCRDR. For other cases, the target's ``mutually_exclusive`` flag decides.
        """
        if isinstance(case, SQLTable):
            # get_attribute_by_type returns (attribute_name, attribute_value); only the value matters here.
            _, prop = get_attribute_by_type(case, type(attribute))
            if hasattr(prop, "__iter__") and not isinstance(prop, str):
                return MultiClassRDR()
            else:
                return SingleClassRDR()
        else:
            return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()

    @staticmethod
    def update_case_with_same_type_conclusions(case: Union[Case, SQLTable],
                                               conclusions: List[Any], attribute_type: Optional[Any] = None):
        """
        Update the case with the conclusions.

        For SQL tables, the attribute matching ``attribute_type`` (or the first conclusion's type) is updated:
        set-typed attributes are unioned with the conclusions, list-typed attributes are extended, and a scalar
        attribute whose hint equals the type is assigned when there is exactly one conclusion. Missing (None) set/list
        attributes are created and assigned to the case. For other cases, ``case.update`` is called with each
        conclusion's ``as_dict``.

        :param case: The case to update.
        :param conclusions: The conclusions to update the case with. The list itself is not modified.
        :param attribute_type: The type of the attribute to update.
        :raises ValueError: If several conclusions would have to be stored in a scalar attribute.
        """
        if not conclusions:
            return
        conclusions = [conclusions] if not isinstance(conclusions, list) else conclusions
        if len(conclusions) == 0:
            return
        if isinstance(case, SQLTable):
            conclusions_type = type(conclusions[0]) if not attribute_type else attribute_type
            attr_name, attribute = get_attribute_by_type(case, conclusions_type)
            hint, origin, args = get_hint_for_attribute(attr_name, case)
            if isinstance(attribute, set) or origin == set:
                new_value = set() if attribute is None else attribute
                new_value.update(*[make_set(c) for c in conclusions])
                if attribute is None:
                    # A freshly created collection must be stored on the case, otherwise the update is lost.
                    setattr(case, attr_name, new_value)
            elif isinstance(attribute, list) or origin == list:
                new_value = [] if attribute is None else attribute
                new_value.extend(conclusions)
                if attribute is None:
                    setattr(case, attr_name, new_value)
            elif len(conclusions) == 1 and hint == conclusions_type:
                # Read without popping: the list may belong to the caller (e.g. an MCRDR's conclusions).
                setattr(case, attr_name, conclusions[0])
            else:
                raise ValueError(f"Cannot add multiple conclusions to attribute {attr_name}")
        else:
            case.update(*[c.as_dict for c in make_set(conclusions)])

    @property
    def names_of_all_types(self) -> List[str]:
        """
        Get the names of all the types of categories that the GRDR can classify.
        """
        return [t.__name__ for t in self.start_rules_dict.keys()]

    @property
    def all_types(self) -> List[Type]:
        """
        Get all the types of categories that the GRDR can classify.
        """
        return list(self.start_rules_dict.keys())