ripple-down-rules 0.0.15__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datasets.py +2 -2
- ripple_down_rules/datastructures/callable_expression.py +52 -10
- ripple_down_rules/datastructures/case.py +53 -70
- ripple_down_rules/datastructures/dataclasses.py +69 -29
- ripple_down_rules/experts.py +29 -40
- ripple_down_rules/helpers.py +27 -0
- ripple_down_rules/prompt.py +77 -24
- ripple_down_rules/rdr.py +214 -192
- ripple_down_rules/rdr_decorators.py +55 -0
- ripple_down_rules/rules.py +7 -2
- ripple_down_rules/utils.py +154 -3
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.0.dist-info}/METADATA +1 -1
- ripple_down_rules-0.1.0.dist-info/RECORD +20 -0
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.0.dist-info}/WHEEL +1 -1
- ripple_down_rules-0.0.15.dist-info/RECORD +0 -18
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.0.dist-info}/top_level.txt +0 -0
ripple_down_rules/rdr.py
CHANGED
@@ -8,13 +8,13 @@ from types import ModuleType
|
|
8
8
|
from matplotlib import pyplot as plt
|
9
9
|
from ordered_set import OrderedSet
|
10
10
|
from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
|
11
|
-
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable
|
11
|
+
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
|
12
12
|
|
13
13
|
from .datastructures import Case, MCRDRMode, CallableExpression, CaseAttribute, CaseQuery
|
14
14
|
from .experts import Expert, Human
|
15
15
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
16
|
-
from .utils import draw_tree, make_set,
|
17
|
-
get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list,
|
16
|
+
from .utils import draw_tree, make_set, copy_case, \
|
17
|
+
get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string
|
18
18
|
|
19
19
|
|
20
20
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -80,7 +80,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
80
80
|
:param kwargs_for_fit_case: The keyword arguments to pass to the fit_case method.
|
81
81
|
"""
|
82
82
|
cases = [case_query.case for case_query in case_queries]
|
83
|
-
targets = [
|
83
|
+
targets = [{case_query.attribute_name: case_query.target} for case_query in case_queries]
|
84
84
|
if animate_tree:
|
85
85
|
plt.ion()
|
86
86
|
i = 0
|
@@ -91,11 +91,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
91
91
|
if not targets:
|
92
92
|
targets = [None] * len(cases)
|
93
93
|
for case_query in case_queries:
|
94
|
-
|
95
|
-
target = case_query.target
|
96
|
-
if not target:
|
97
|
-
conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
|
98
|
-
target = expert.ask_for_conclusion(case_query, conclusions)
|
94
|
+
target = {case_query.attribute_name: case_query.target}
|
99
95
|
pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
|
100
96
|
match = self.is_matching(pred_cat, target)
|
101
97
|
if not match:
|
@@ -105,8 +101,9 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
105
101
|
num_rules = self.start_rule.size
|
106
102
|
self.update_figures()
|
107
103
|
i += 1
|
108
|
-
all_predictions = [1 if self.is_matching(self.classify(case),
|
109
|
-
|
104
|
+
all_predictions = [1 if self.is_matching(self.classify(case_query.case), {case_query.attribute_name:
|
105
|
+
case_query.target}) else 0
|
106
|
+
for case_query in case_queries]
|
110
107
|
all_pred = sum(all_predictions)
|
111
108
|
print(f"Accuracy: {all_pred}/{len(targets)}")
|
112
109
|
all_predicted = targets and all_pred == len(targets)
|
@@ -129,9 +126,33 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
129
126
|
"""
|
130
127
|
pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
|
131
128
|
target = target if is_iterable(target) else [target]
|
132
|
-
recall = [
|
133
|
-
|
134
|
-
|
129
|
+
recall = []
|
130
|
+
precision = []
|
131
|
+
if isinstance(pred_cat, dict):
|
132
|
+
for pred_key, pred_value in pred_cat.items():
|
133
|
+
if pred_key not in target:
|
134
|
+
continue
|
135
|
+
# if is_iterable(pred_value):
|
136
|
+
# print(pred_value, target[pred_key])
|
137
|
+
# precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
|
138
|
+
precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
|
139
|
+
# else:
|
140
|
+
# precision.append(pred_value == target[pred_key])
|
141
|
+
for target_key, target_value in target.items():
|
142
|
+
if target_key not in pred_cat:
|
143
|
+
recall.append(False)
|
144
|
+
continue
|
145
|
+
if is_iterable(target_value):
|
146
|
+
recall.extend([v in pred_cat[target_key] for v in target_value])
|
147
|
+
else:
|
148
|
+
recall.append(target_value == pred_cat[target_key])
|
149
|
+
print(f"Precision: {precision}, Recall: {recall}")
|
150
|
+
else:
|
151
|
+
if isinstance(target, dict):
|
152
|
+
target = list(target.values())
|
153
|
+
recall = [not yi or (yi in pred_cat) for yi in target]
|
154
|
+
target_types = [type(yi) for yi in target]
|
155
|
+
precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
|
135
156
|
return precision, recall
|
136
157
|
|
137
158
|
def is_matching(self, pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> bool:
|
@@ -158,22 +179,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
158
179
|
draw_tree(self.start_rule, self.fig)
|
159
180
|
|
160
181
|
@staticmethod
|
161
|
-
def case_has_conclusion(case: Union[Case, SQLTable],
|
182
|
+
def case_has_conclusion(case: Union[Case, SQLTable], conclusion_name: str) -> bool:
|
162
183
|
"""
|
163
184
|
Check if the case has a conclusion.
|
164
185
|
|
165
186
|
:param case: The case to check.
|
166
|
-
:param
|
187
|
+
:param conclusion_name: The target category name to compare the case with.
|
167
188
|
:return: Whether the case has a conclusion or not.
|
168
189
|
"""
|
169
|
-
|
170
|
-
prop_name, prop_value = get_attribute_by_type(case, conclusion_type)
|
171
|
-
if hasattr(prop_value, "__iter__") and not isinstance(prop_value, str):
|
172
|
-
return len(prop_value) > 0
|
173
|
-
else:
|
174
|
-
return prop_value is not None
|
175
|
-
else:
|
176
|
-
return conclusion_type in case
|
190
|
+
return hasattr(case, conclusion_name) and getattr(case, conclusion_name) is not None
|
177
191
|
|
178
192
|
|
179
193
|
class RDRWithCodeWriter(RippleDownRules, ABC):
|
@@ -244,7 +258,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
244
258
|
:return: The type of the case (input) to the RDR classifier.
|
245
259
|
"""
|
246
260
|
if isinstance(self.start_rule.corner_case, Case):
|
247
|
-
return self.start_rule.corner_case.
|
261
|
+
return self.start_rule.corner_case._obj_type
|
248
262
|
else:
|
249
263
|
return type(self.start_rule.corner_case)
|
250
264
|
|
@@ -260,6 +274,13 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
260
274
|
return type(list(self.start_rule.conclusion)[0])
|
261
275
|
return type(self.start_rule.conclusion)
|
262
276
|
|
277
|
+
@property
|
278
|
+
def attribute_name(self) -> str:
|
279
|
+
"""
|
280
|
+
:return: The name of the attribute that the classifier is classifying.
|
281
|
+
"""
|
282
|
+
return self.start_rule.conclusion_name
|
283
|
+
|
263
284
|
|
264
285
|
class SingleClassRDR(RDRWithCodeWriter):
|
265
286
|
|
@@ -274,23 +295,20 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
274
295
|
:return: The category that the case belongs to.
|
275
296
|
"""
|
276
297
|
expert = expert if expert else Human(session=self.session)
|
277
|
-
case, attribute = case_query.case, case_query.attribute
|
278
298
|
if case_query.target is None:
|
279
299
|
target = expert.ask_for_conclusion(case_query)
|
280
|
-
else:
|
281
|
-
target = case_query.target
|
282
|
-
|
283
300
|
if not self.start_rule:
|
284
|
-
conditions = expert.ask_for_conditions(
|
285
|
-
self.start_rule = SingleClassRule(conditions, target, corner_case=case
|
301
|
+
conditions = expert.ask_for_conditions(case_query)
|
302
|
+
self.start_rule = SingleClassRule(conditions, case_query.target, corner_case=case_query.case,
|
303
|
+
conclusion_name=case_query.attribute_name)
|
286
304
|
|
287
|
-
pred = self.evaluate(case)
|
305
|
+
pred = self.evaluate(case_query.case)
|
288
306
|
|
289
|
-
if pred.conclusion != target:
|
290
|
-
conditions = expert.ask_for_conditions(
|
291
|
-
pred.fit_rule(case, target, conditions=conditions)
|
307
|
+
if pred.conclusion != case_query.target:
|
308
|
+
conditions = expert.ask_for_conditions(case_query, pred)
|
309
|
+
pred.fit_rule(case_query.case, case_query.target, conditions=conditions)
|
292
310
|
|
293
|
-
return self.classify(case)
|
311
|
+
return self.classify(case_query.case)
|
294
312
|
|
295
313
|
def classify(self, case: Case) -> Optional[CaseAttribute]:
|
296
314
|
"""
|
@@ -388,44 +406,41 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
388
406
|
:return: The conclusions that the case belongs to.
|
389
407
|
"""
|
390
408
|
expert = expert if expert else Human(session=self.session)
|
391
|
-
case = case_query.case
|
392
409
|
if case_query.target is None:
|
393
|
-
targets =
|
394
|
-
else:
|
395
|
-
targets = [case_query.target]
|
410
|
+
targets = expert.ask_for_conclusion(case_query)
|
396
411
|
self.expert_accepted_conclusions = []
|
397
412
|
user_conclusions = []
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
413
|
+
self.update_start_rule(case_query, expert)
|
414
|
+
self.conclusions = []
|
415
|
+
self.stop_rule_conditions = None
|
416
|
+
evaluated_rule = self.start_rule
|
417
|
+
while evaluated_rule:
|
418
|
+
next_rule = evaluated_rule(case_query.case)
|
419
|
+
good_conclusions = make_list(case_query.target) + user_conclusions + self.expert_accepted_conclusions
|
420
|
+
good_conclusions = make_set(good_conclusions)
|
421
|
+
|
422
|
+
if evaluated_rule.fired:
|
423
|
+
if case_query.target and not make_set(evaluated_rule.conclusion).issubset(good_conclusions):
|
424
|
+
# if self.case_has_conclusion(case, evaluated_rule.conclusion):
|
425
|
+
# Rule fired and conclusion is different from target
|
426
|
+
self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
|
427
|
+
add_extra_conclusions)
|
428
|
+
else:
|
429
|
+
# Rule fired and target is correct or there is no target to compare
|
430
|
+
self.add_conclusion(evaluated_rule)
|
431
|
+
|
432
|
+
if not next_rule:
|
433
|
+
if not make_set(case_query.target).intersection(make_set(self.conclusions)):
|
434
|
+
# Nothing fired and there is a target that should have been in the conclusions
|
435
|
+
self.add_rule_for_case(case_query, expert)
|
436
|
+
# Have to check all rules again to make sure only this new rule fires
|
437
|
+
next_rule = self.start_rule
|
438
|
+
elif add_extra_conclusions and not user_conclusions:
|
439
|
+
# No more conclusions can be made, ask the expert for extra conclusions if needed.
|
440
|
+
user_conclusions.extend(self.ask_expert_for_extra_conclusions(expert, case_query.case))
|
441
|
+
if user_conclusions:
|
442
|
+
next_rule = self.last_top_rule
|
443
|
+
evaluated_rule = next_rule
|
429
444
|
return self.conclusions
|
430
445
|
|
431
446
|
def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
|
@@ -458,19 +473,18 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
458
473
|
imports += "from typing_extensions import Set\n"
|
459
474
|
return imports
|
460
475
|
|
461
|
-
def update_start_rule(self,
|
476
|
+
def update_start_rule(self, case_query: CaseQuery, expert: Expert):
|
462
477
|
"""
|
463
478
|
Update the starting rule of the classifier.
|
464
479
|
|
465
|
-
:param
|
466
|
-
:param target: The target category to compare the case with.
|
480
|
+
:param case_query: The case query to update the starting rule with.
|
467
481
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
468
482
|
"""
|
469
483
|
if not self.start_rule.conditions:
|
470
|
-
conditions = expert.ask_for_conditions(
|
484
|
+
conditions = expert.ask_for_conditions(case_query)
|
471
485
|
self.start_rule.conditions = conditions
|
472
|
-
self.start_rule.conclusion = target
|
473
|
-
self.start_rule.corner_case = case
|
486
|
+
self.start_rule.conclusion = case_query.target
|
487
|
+
self.start_rule.corner_case = case_query.case
|
474
488
|
|
475
489
|
@property
|
476
490
|
def last_top_rule(self) -> Optional[MultiClassTopRule]:
|
@@ -482,35 +496,34 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
482
496
|
else:
|
483
497
|
return self.start_rule.furthest_alternative[-1]
|
484
498
|
|
485
|
-
def stop_wrong_conclusion_else_add_it(self,
|
499
|
+
def stop_wrong_conclusion_else_add_it(self, case_query: CaseQuery, expert: Expert,
|
486
500
|
evaluated_rule: MultiClassTopRule,
|
487
501
|
add_extra_conclusions: bool):
|
488
502
|
"""
|
489
503
|
Stop a wrong conclusion by adding a stopping rule.
|
490
504
|
"""
|
491
|
-
if self.is_same_category_type(evaluated_rule.conclusion, target) \
|
492
|
-
and self.is_conflicting_with_target(evaluated_rule.conclusion, target):
|
493
|
-
self.stop_conclusion(
|
494
|
-
elif not self.conclusion_is_correct(
|
495
|
-
self.stop_conclusion(
|
505
|
+
if self.is_same_category_type(evaluated_rule.conclusion, case_query.target) \
|
506
|
+
and self.is_conflicting_with_target(evaluated_rule.conclusion, case_query.target):
|
507
|
+
self.stop_conclusion(case_query, expert, evaluated_rule)
|
508
|
+
elif not self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
|
509
|
+
self.stop_conclusion(case_query, expert, evaluated_rule)
|
496
510
|
|
497
|
-
def stop_conclusion(self,
|
511
|
+
def stop_conclusion(self, case_query: CaseQuery,
|
498
512
|
expert: Expert, evaluated_rule: MultiClassTopRule):
|
499
513
|
"""
|
500
514
|
Stop a conclusion by adding a stopping rule.
|
501
515
|
|
502
|
-
:param
|
503
|
-
:param target: The target category to compare the case with.
|
516
|
+
:param case_query: The case query to stop the conclusion for.
|
504
517
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
505
518
|
:param evaluated_rule: The evaluated rule to ask the expert about.
|
506
519
|
"""
|
507
|
-
conditions = expert.ask_for_conditions(
|
508
|
-
evaluated_rule.fit_rule(case, target, conditions=conditions)
|
520
|
+
conditions = expert.ask_for_conditions(case_query, evaluated_rule)
|
521
|
+
evaluated_rule.fit_rule(case_query.case, case_query.target, conditions=conditions)
|
509
522
|
if self.mode == MCRDRMode.StopPlusRule:
|
510
523
|
self.stop_rule_conditions = conditions
|
511
524
|
if self.mode == MCRDRMode.StopPlusRuleCombined:
|
512
525
|
new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
|
513
|
-
self.add_top_rule(new_top_rule_conditions, target, case)
|
526
|
+
self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
|
514
527
|
|
515
528
|
@staticmethod
|
516
529
|
def is_conflicting_with_target(conclusion: Any, target: Any) -> bool:
|
@@ -537,37 +550,40 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
537
550
|
"""
|
538
551
|
return conclusion.__class__ == target.__class__ and target.__class__ != CaseAttribute
|
539
552
|
|
540
|
-
def conclusion_is_correct(self,
|
553
|
+
def conclusion_is_correct(self, case_query: CaseQuery,
|
554
|
+
expert: Expert, evaluated_rule: Rule,
|
541
555
|
add_extra_conclusions: bool) -> bool:
|
542
556
|
"""
|
543
557
|
Ask the expert if the conclusion is correct, and add it to the conclusions if it is.
|
544
558
|
|
545
|
-
:param
|
546
|
-
:param target: The target category to compare the case with.
|
559
|
+
:param case_query: The case query to ask the expert about.
|
547
560
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
548
561
|
:param evaluated_rule: The evaluated rule to ask the expert about.
|
549
562
|
:param add_extra_conclusions: Whether adding extra conclusions after classification is allowed.
|
550
563
|
:return: Whether the conclusion is correct or not.
|
551
564
|
"""
|
552
|
-
conclusions =
|
553
|
-
if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case, evaluated_rule.conclusion,
|
554
|
-
targets=target,
|
565
|
+
conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
|
566
|
+
if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case, evaluated_rule.conclusion,
|
567
|
+
targets=case_query.target,
|
555
568
|
current_conclusions=conclusions)):
|
556
569
|
self.add_conclusion(evaluated_rule)
|
557
570
|
self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
|
558
571
|
return True
|
559
572
|
return False
|
560
573
|
|
561
|
-
def add_rule_for_case(self,
|
574
|
+
def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
|
562
575
|
"""
|
563
576
|
Add a rule for a case that has not been classified with any conclusion.
|
577
|
+
|
578
|
+
:param case_query: The case query to add the rule for.
|
579
|
+
:param expert: The expert to ask for differentiating features as new rule conditions.
|
564
580
|
"""
|
565
581
|
if self.stop_rule_conditions and self.mode == MCRDRMode.StopPlusRule:
|
566
582
|
conditions = self.stop_rule_conditions
|
567
583
|
self.stop_rule_conditions = None
|
568
584
|
else:
|
569
|
-
conditions = expert.ask_for_conditions(
|
570
|
-
self.add_top_rule(conditions, target, case)
|
585
|
+
conditions = expert.ask_for_conditions(case_query)
|
586
|
+
self.add_top_rule(conditions, case_query.target, case_query.case)
|
571
587
|
|
572
588
|
def ask_expert_for_extra_conclusions(self, expert: Expert, case: Union[Case, SQLTable]) -> List[Any]:
|
573
589
|
"""
|
@@ -641,20 +657,31 @@ class GeneralRDR(RippleDownRules):
|
|
641
657
|
gets called when the final rule fires.
|
642
658
|
"""
|
643
659
|
|
644
|
-
def __init__(self, category_rdr_map: Optional[Dict[
|
660
|
+
def __init__(self, category_rdr_map: Optional[Dict[str, Union[SingleClassRDR, MultiClassRDR]]] = None):
|
645
661
|
"""
|
646
|
-
:param category_rdr_map: A map of
|
662
|
+
:param category_rdr_map: A map of case attribute names to ripple down rules classifiers,
|
647
663
|
where each category is a parent category that has a set of mutually exclusive (in case of SCRDR) child
|
648
|
-
categories, e.g. {
|
649
|
-
and MCRDR are SingleClass and MultiClass ripple down rules classifiers.
|
650
|
-
Mammal, Bird, Fish, etc. which are mutually exclusive,
|
651
|
-
Land, Water, Air, etc
|
664
|
+
categories, e.g. {'species': SCRDR, 'habitats': MCRDR}, where 'species' and 'habitats' are attribute names
|
665
|
+
for a case of type Animal, while SCRDR and MCRDR are SingleClass and MultiClass ripple down rules classifiers.
|
666
|
+
Species can have values like Mammal, Bird, Fish, etc. which are mutually exclusive, while Habitat can have
|
667
|
+
values like Land, Water, Air, etc., which are not mutually exclusive due to some animals living more than one
|
668
|
+
habitat.
|
652
669
|
"""
|
653
|
-
self.start_rules_dict: Dict[
|
670
|
+
self.start_rules_dict: Dict[str, Union[SingleClassRDR, MultiClassRDR]] \
|
654
671
|
= category_rdr_map if category_rdr_map else {}
|
655
672
|
super(GeneralRDR, self).__init__()
|
656
673
|
self.all_figs: List[plt.Figure] = [sr.fig for sr in self.start_rules_dict.values()]
|
657
674
|
|
675
|
+
def add_rdr(self, rdr: Union[SingleClassRDR, MultiClassRDR], attribute_name: Optional[str] = None):
|
676
|
+
"""
|
677
|
+
Add a ripple down rules classifier to the map of classifiers.
|
678
|
+
|
679
|
+
:param rdr: The ripple down rules classifier to add.
|
680
|
+
:param attribute_name: The name of the attribute that the classifier is classifying.
|
681
|
+
"""
|
682
|
+
attribute_name = attribute_name if attribute_name else rdr.attribute_name
|
683
|
+
self.start_rules_dict[attribute_name] = rdr
|
684
|
+
|
658
685
|
@property
|
659
686
|
def start_rule(self) -> Optional[Union[SingleClassRule, MultiClassTopRule]]:
|
660
687
|
return self.start_rules[0] if self.start_rules_dict else None
|
@@ -662,7 +689,7 @@ class GeneralRDR(RippleDownRules):
|
|
662
689
|
@start_rule.setter
|
663
690
|
def start_rule(self, value: Union[SingleClassRDR, MultiClassRDR]):
|
664
691
|
if value:
|
665
|
-
self.start_rules_dict[value.
|
692
|
+
self.start_rules_dict[value.attribute_name] = value
|
666
693
|
|
667
694
|
@property
|
668
695
|
def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
|
@@ -679,8 +706,8 @@ class GeneralRDR(RippleDownRules):
|
|
679
706
|
return self._classify(self.start_rules_dict, case)
|
680
707
|
|
681
708
|
@staticmethod
|
682
|
-
def _classify(classifiers_dict: Dict[
|
683
|
-
case: Union[Case, SQLTable]) -> Optional[
|
709
|
+
def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
|
710
|
+
case: Union[Case, SQLTable]) -> Optional[Dict[str, Any]]:
|
684
711
|
"""
|
685
712
|
Classify a case by going through all classifiers and adding the categories that are classified,
|
686
713
|
and then restarting the classification until no more categories can be added.
|
@@ -689,21 +716,31 @@ class GeneralRDR(RippleDownRules):
|
|
689
716
|
:param case: The case to classify.
|
690
717
|
:return: The categories that the case belongs to.
|
691
718
|
"""
|
692
|
-
conclusions =
|
719
|
+
conclusions = {}
|
693
720
|
case_cp = copy_case(case)
|
694
721
|
while True:
|
695
|
-
|
696
|
-
for
|
697
|
-
if GeneralRDR.case_has_conclusion(case_cp, cat_type):
|
698
|
-
continue
|
722
|
+
new_conclusions = {}
|
723
|
+
for attribute_name, rdr in classifiers_dict.items():
|
699
724
|
pred_atts = rdr.classify(case_cp)
|
700
|
-
if pred_atts:
|
725
|
+
if pred_atts is None:
|
726
|
+
continue
|
727
|
+
if isinstance(rdr, SingleClassRDR):
|
728
|
+
if attribute_name not in conclusions or \
|
729
|
+
(attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
|
730
|
+
conclusions[attribute_name] = pred_atts
|
731
|
+
new_conclusions[attribute_name] = pred_atts
|
732
|
+
else:
|
701
733
|
pred_atts = make_list(pred_atts)
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
734
|
+
if attribute_name in conclusions:
|
735
|
+
pred_atts = [p for p in pred_atts if p not in conclusions[attribute_name]]
|
736
|
+
if len(pred_atts) > 0:
|
737
|
+
new_conclusions[attribute_name] = pred_atts
|
738
|
+
if attribute_name not in conclusions:
|
739
|
+
conclusions[attribute_name] = []
|
740
|
+
conclusions[attribute_name].extend(pred_atts)
|
741
|
+
if attribute_name in new_conclusions:
|
742
|
+
GeneralRDR.update_case(case_cp, new_conclusions)
|
743
|
+
if len(new_conclusions) == 0:
|
707
744
|
break
|
708
745
|
return conclusions
|
709
746
|
|
@@ -728,103 +765,79 @@ class GeneralRDR(RippleDownRules):
|
|
728
765
|
case = case_queries[0].case
|
729
766
|
assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
|
730
767
|
" for multiple cases use fit instead")
|
731
|
-
|
732
|
-
case_cp = case_query_cp.case
|
768
|
+
case_cp = copy(case_queries[0]).case
|
733
769
|
for case_query in case_queries:
|
734
|
-
|
735
|
-
|
770
|
+
case_query_cp = copy(case_query)
|
771
|
+
case_query_cp.case = case_cp
|
772
|
+
if case_query.target is None:
|
773
|
+
conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
|
736
774
|
target = expert.ask_for_conclusion(case_query)
|
737
|
-
|
738
|
-
if
|
739
|
-
target_type = type(make_list(target)[0])
|
740
|
-
assert all([type(t) is target_type for t in target]), ("All targets of a case query must be of the same"
|
741
|
-
" type")
|
742
|
-
else:
|
743
|
-
target_type = type(target)
|
744
|
-
if target_type not in self.start_rules_dict:
|
775
|
+
|
776
|
+
if case_query.attribute_name not in self.start_rules_dict:
|
745
777
|
conclusions = self.classify(case)
|
746
778
|
self.update_case(case_cp, conclusions)
|
747
|
-
|
779
|
+
|
780
|
+
new_rdr = self.initialize_new_rdr_for_attribute(case_query.attribute_name, case_cp, case_query.target)
|
781
|
+
self.add_rdr(new_rdr, case_query.attribute_name)
|
782
|
+
|
748
783
|
new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
|
749
|
-
self.
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
if target_type is not rdr_type:
|
784
|
+
self.update_case(case_cp, {case_query.attribute_name: new_conclusions})
|
785
|
+
else:
|
786
|
+
for rdr_attribute_name, rdr in self.start_rules_dict.items():
|
787
|
+
if case_query.attribute_name != rdr_attribute_name:
|
754
788
|
conclusions = rdr.classify(case_cp)
|
755
789
|
else:
|
756
|
-
conclusions = self.start_rules_dict[
|
757
|
-
|
758
|
-
|
790
|
+
conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
|
791
|
+
**kwargs)
|
792
|
+
if conclusions is not None or (is_iterable(conclusions) and len(conclusions) > 0):
|
793
|
+
conclusions = {rdr_attribute_name: conclusions}
|
794
|
+
self.update_case(case_cp, conclusions)
|
759
795
|
|
760
796
|
return self.classify(case)
|
761
797
|
|
762
798
|
@staticmethod
|
763
|
-
def initialize_new_rdr_for_attribute(
|
799
|
+
def initialize_new_rdr_for_attribute(attribute_name: str, case: Union[Case, SQLTable], target: Any):
|
764
800
|
"""
|
765
801
|
Initialize the appropriate RDR type for the target.
|
766
802
|
"""
|
767
|
-
if
|
768
|
-
|
769
|
-
if hasattr(prop, "__iter__") and not isinstance(prop, str):
|
770
|
-
return MultiClassRDR()
|
771
|
-
else:
|
772
|
-
return SingleClassRDR()
|
773
|
-
elif isinstance(attribute, CaseAttribute):
|
803
|
+
attribute = getattr(case, attribute_name) if hasattr(case, attribute_name) else target
|
804
|
+
if isinstance(attribute, CaseAttribute):
|
774
805
|
return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
|
775
806
|
else:
|
776
|
-
return MultiClassRDR() if is_iterable(attribute) else SingleClassRDR()
|
807
|
+
return MultiClassRDR() if is_iterable(attribute) or (attribute is None) else SingleClassRDR()
|
777
808
|
|
778
809
|
@staticmethod
|
779
|
-
def update_case(case: Union[Case, SQLTable],
|
780
|
-
conclusions: List[Any], attribute_type: Optional[Any] = None):
|
810
|
+
def update_case(case: Union[Case, SQLTable], conclusions: Dict[str, Any]):
|
781
811
|
"""
|
782
812
|
Update the case with the conclusions.
|
783
813
|
|
784
814
|
:param case: The case to update.
|
785
815
|
:param conclusions: The conclusions to update the case with.
|
786
|
-
:param attribute_type: The type of the attribute to update.
|
787
816
|
"""
|
788
817
|
if not conclusions:
|
789
818
|
return
|
790
|
-
conclusions = [conclusions] if not isinstance(conclusions, list) else list(conclusions)
|
791
819
|
if len(conclusions) == 0:
|
792
820
|
return
|
793
821
|
if isinstance(case, SQLTable):
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
822
|
+
for conclusion_name, conclusion in conclusions.items():
|
823
|
+
hint, origin, args = get_hint_for_attribute(conclusion_name, case)
|
824
|
+
attribute = getattr(case, conclusion_name)
|
825
|
+
if isinstance(attribute, set) or origin in {Set, set}:
|
826
|
+
attribute = set() if attribute is None else attribute
|
827
|
+
for c in conclusion:
|
828
|
+
attribute.update(make_set(c))
|
829
|
+
elif isinstance(attribute, list) or origin in {list, List}:
|
830
|
+
attribute = [] if attribute is None else attribute
|
831
|
+
attribute.extend(conclusion)
|
832
|
+
elif (not is_iterable(conclusion) or (len(conclusion) == 1)) and hint == type(conclusion):
|
833
|
+
setattr(case, conclusion_name, conclusion)
|
834
|
+
else:
|
835
|
+
raise ValueError(f"Cannot add multiple conclusions to attribute {conclusion_name}")
|
808
836
|
else:
|
809
|
-
|
810
|
-
case.update(c.as_dict)
|
811
|
-
|
812
|
-
@property
|
813
|
-
def names_of_all_types(self) -> List[str]:
|
814
|
-
"""
|
815
|
-
Get the names of all the types of categories that the GRDR can classify.
|
816
|
-
"""
|
817
|
-
return [t.__name__ for t in self.start_rules_dict.keys()]
|
818
|
-
|
819
|
-
@property
|
820
|
-
def all_types(self) -> List[Type]:
|
821
|
-
"""
|
822
|
-
Get all the types of categories that the GRDR can classify.
|
823
|
-
"""
|
824
|
-
return list(self.start_rules_dict.keys())
|
837
|
+
case.update(conclusions)
|
825
838
|
|
826
839
|
def _to_json(self) -> Dict[str, Any]:
|
827
|
-
return {"start_rules": {
|
840
|
+
return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
|
828
841
|
|
829
842
|
@classmethod
|
830
843
|
def _from_json(cls, data: Dict[str, Any]) -> GeneralRDR:
|
@@ -833,7 +846,6 @@ class GeneralRDR(RippleDownRules):
|
|
833
846
|
"""
|
834
847
|
start_rules_dict = {}
|
835
848
|
for k, v in data["start_rules"].items():
|
836
|
-
k = get_type_from_string(k)
|
837
849
|
start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
|
838
850
|
return cls(start_rules_dict)
|
839
851
|
|
@@ -849,8 +861,8 @@ class GeneralRDR(RippleDownRules):
|
|
849
861
|
with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
|
850
862
|
f.write(self._get_imports(file_path) + "\n\n")
|
851
863
|
f.write("classifiers_dict = dict()\n")
|
852
|
-
for
|
853
|
-
f.write(f"classifiers_dict[{
|
864
|
+
for rdr_key, rdr in self.start_rules_dict.items():
|
865
|
+
f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
|
854
866
|
f.write("\n\n")
|
855
867
|
f.write(func_def)
|
856
868
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
@@ -863,7 +875,7 @@ class GeneralRDR(RippleDownRules):
|
|
863
875
|
:return: The type of the case (input) to the RDR classifier.
|
864
876
|
"""
|
865
877
|
if isinstance(self.start_rule.corner_case, Case):
|
866
|
-
return self.start_rule.corner_case.
|
878
|
+
return self.start_rule.corner_case._obj_type
|
867
879
|
else:
|
868
880
|
return type(self.start_rule.corner_case)
|
869
881
|
|
@@ -892,10 +904,20 @@ class GeneralRDR(RippleDownRules):
|
|
892
904
|
imports += f"from ripple_down_rules.datastructures import Case, create_case\n"
|
893
905
|
imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
|
894
906
|
# add conclusion type imports
|
895
|
-
for
|
896
|
-
imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
|
907
|
+
for rdr in self.start_rules_dict.values():
|
908
|
+
imports += f"from {rdr.conclusion_type.__module__} import {rdr.conclusion_type.__name__}\n"
|
897
909
|
# add rdr python generated functions.
|
898
|
-
for
|
910
|
+
for rdr_key, rdr in self.start_rules_dict.items():
|
899
911
|
imports += (f"from {file_path.strip('./')}"
|
900
|
-
f" import {rdr.generated_python_file_name} as {
|
912
|
+
f" import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}\n")
|
901
913
|
return imports
|
914
|
+
|
915
|
+
@staticmethod
|
916
|
+
def rdr_key_to_function_name(rdr_key: str) -> str:
|
917
|
+
"""
|
918
|
+
Convert the RDR key to a function name.
|
919
|
+
|
920
|
+
:param rdr_key: The RDR key to convert.
|
921
|
+
:return: The function name.
|
922
|
+
"""
|
923
|
+
return rdr_key.replace(".", "_").lower() + "_classifier"
|