ripple-down-rules 0.0.15__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datasets.py +2 -2
- ripple_down_rules/datastructures/callable_expression.py +52 -10
- ripple_down_rules/datastructures/case.py +54 -70
- ripple_down_rules/datastructures/dataclasses.py +69 -29
- ripple_down_rules/experts.py +29 -40
- ripple_down_rules/helpers.py +27 -0
- ripple_down_rules/prompt.py +77 -24
- ripple_down_rules/rdr.py +218 -200
- ripple_down_rules/rdr_decorators.py +55 -0
- ripple_down_rules/rules.py +7 -2
- ripple_down_rules/utils.py +167 -3
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/METADATA +1 -1
- ripple_down_rules-0.1.1.dist-info/RECORD +20 -0
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/WHEEL +1 -1
- ripple_down_rules-0.0.15.dist-info/RECORD +0 -18
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/top_level.txt +0 -0
ripple_down_rules/rdr.py
CHANGED
@@ -8,13 +8,13 @@ from types import ModuleType
|
|
8
8
|
from matplotlib import pyplot as plt
|
9
9
|
from ordered_set import OrderedSet
|
10
10
|
from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
|
11
|
-
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable
|
11
|
+
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
|
12
12
|
|
13
13
|
from .datastructures import Case, MCRDRMode, CallableExpression, CaseAttribute, CaseQuery
|
14
14
|
from .experts import Expert, Human
|
15
15
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
16
|
-
from .utils import draw_tree, make_set,
|
17
|
-
get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list,
|
16
|
+
from .utils import draw_tree, make_set, copy_case, \
|
17
|
+
get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string
|
18
18
|
|
19
19
|
|
20
20
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -80,7 +80,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
80
80
|
:param kwargs_for_fit_case: The keyword arguments to pass to the fit_case method.
|
81
81
|
"""
|
82
82
|
cases = [case_query.case for case_query in case_queries]
|
83
|
-
targets = [
|
83
|
+
targets = [{case_query.attribute_name: case_query.target} for case_query in case_queries]
|
84
84
|
if animate_tree:
|
85
85
|
plt.ion()
|
86
86
|
i = 0
|
@@ -91,11 +91,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
91
91
|
if not targets:
|
92
92
|
targets = [None] * len(cases)
|
93
93
|
for case_query in case_queries:
|
94
|
-
|
95
|
-
target = case_query.target
|
96
|
-
if not target:
|
97
|
-
conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
|
98
|
-
target = expert.ask_for_conclusion(case_query, conclusions)
|
94
|
+
target = {case_query.attribute_name: case_query.target}
|
99
95
|
pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
|
100
96
|
match = self.is_matching(pred_cat, target)
|
101
97
|
if not match:
|
@@ -105,8 +101,9 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
105
101
|
num_rules = self.start_rule.size
|
106
102
|
self.update_figures()
|
107
103
|
i += 1
|
108
|
-
all_predictions = [1 if self.is_matching(self.classify(case),
|
109
|
-
|
104
|
+
all_predictions = [1 if self.is_matching(self.classify(case_query.case), {case_query.attribute_name:
|
105
|
+
case_query.target}) else 0
|
106
|
+
for case_query in case_queries]
|
110
107
|
all_pred = sum(all_predictions)
|
111
108
|
print(f"Accuracy: {all_pred}/{len(targets)}")
|
112
109
|
all_predicted = targets and all_pred == len(targets)
|
@@ -129,9 +126,27 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
129
126
|
"""
|
130
127
|
pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
|
131
128
|
target = target if is_iterable(target) else [target]
|
132
|
-
recall = [
|
133
|
-
|
134
|
-
|
129
|
+
recall = []
|
130
|
+
precision = []
|
131
|
+
if isinstance(pred_cat, dict):
|
132
|
+
for pred_key, pred_value in pred_cat.items():
|
133
|
+
if pred_key not in target:
|
134
|
+
continue
|
135
|
+
precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
|
136
|
+
for target_key, target_value in target.items():
|
137
|
+
if target_key not in pred_cat:
|
138
|
+
recall.append(False)
|
139
|
+
continue
|
140
|
+
if is_iterable(target_value):
|
141
|
+
recall.extend([v in pred_cat[target_key] for v in target_value])
|
142
|
+
else:
|
143
|
+
recall.append(target_value == pred_cat[target_key])
|
144
|
+
else:
|
145
|
+
if isinstance(target, dict):
|
146
|
+
target = list(target.values())
|
147
|
+
recall = [not yi or (yi in pred_cat) for yi in target]
|
148
|
+
target_types = [type(yi) for yi in target]
|
149
|
+
precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
|
135
150
|
return precision, recall
|
136
151
|
|
137
152
|
def is_matching(self, pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> bool:
|
@@ -158,22 +173,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
158
173
|
draw_tree(self.start_rule, self.fig)
|
159
174
|
|
160
175
|
@staticmethod
|
161
|
-
def case_has_conclusion(case: Union[Case, SQLTable],
|
176
|
+
def case_has_conclusion(case: Union[Case, SQLTable], conclusion_name: str) -> bool:
|
162
177
|
"""
|
163
178
|
Check if the case has a conclusion.
|
164
179
|
|
165
180
|
:param case: The case to check.
|
166
|
-
:param
|
181
|
+
:param conclusion_name: The target category name to compare the case with.
|
167
182
|
:return: Whether the case has a conclusion or not.
|
168
183
|
"""
|
169
|
-
|
170
|
-
prop_name, prop_value = get_attribute_by_type(case, conclusion_type)
|
171
|
-
if hasattr(prop_value, "__iter__") and not isinstance(prop_value, str):
|
172
|
-
return len(prop_value) > 0
|
173
|
-
else:
|
174
|
-
return prop_value is not None
|
175
|
-
else:
|
176
|
-
return conclusion_type in case
|
184
|
+
return hasattr(case, conclusion_name) and getattr(case, conclusion_name) is not None
|
177
185
|
|
178
186
|
|
179
187
|
class RDRWithCodeWriter(RippleDownRules, ABC):
|
@@ -200,7 +208,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
200
208
|
f.write(self._get_imports() + "\n\n")
|
201
209
|
f.write(func_def)
|
202
210
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
203
|
-
f"{' ' * 4} case = create_case(case,
|
211
|
+
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
204
212
|
self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)
|
205
213
|
|
206
214
|
@property
|
@@ -221,6 +229,11 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
221
229
|
if self.conclusion_type.__module__ != "builtins":
|
222
230
|
imports += f"from {self.conclusion_type.__module__} import {self.conclusion_type.__name__}\n"
|
223
231
|
imports += "from ripple_down_rules.datastructures import Case, create_case\n"
|
232
|
+
for rule in [self.start_rule] + list(self.start_rule.descendants):
|
233
|
+
if rule.conditions:
|
234
|
+
if rule.conditions.scope is not None and len(rule.conditions.scope) > 0:
|
235
|
+
for k, v in rule.conditions.scope.items():
|
236
|
+
imports += f"from {v.__module__} import {v.__name__}\n"
|
224
237
|
return imports
|
225
238
|
|
226
239
|
def get_rdr_classifier_from_python_file(self, package_name) -> Callable[[Any], Any]:
|
@@ -232,11 +245,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
232
245
|
|
233
246
|
@property
|
234
247
|
def generated_python_file_name(self) -> str:
|
235
|
-
return f"{self.
|
236
|
-
|
237
|
-
@property
|
238
|
-
def python_file_name(self):
|
239
|
-
return f"{self.start_rule.conclusion.__name__.lower()}_rdr"
|
248
|
+
return f"{self.start_rule.corner_case._name.lower()}_{self.attribute_name}_rdr"
|
240
249
|
|
241
250
|
@property
|
242
251
|
def case_type(self) -> Type:
|
@@ -244,7 +253,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
244
253
|
:return: The type of the case (input) to the RDR classifier.
|
245
254
|
"""
|
246
255
|
if isinstance(self.start_rule.corner_case, Case):
|
247
|
-
return self.start_rule.corner_case.
|
256
|
+
return self.start_rule.corner_case._obj_type
|
248
257
|
else:
|
249
258
|
return type(self.start_rule.corner_case)
|
250
259
|
|
@@ -260,6 +269,13 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
260
269
|
return type(list(self.start_rule.conclusion)[0])
|
261
270
|
return type(self.start_rule.conclusion)
|
262
271
|
|
272
|
+
@property
|
273
|
+
def attribute_name(self) -> str:
|
274
|
+
"""
|
275
|
+
:return: The name of the attribute that the classifier is classifying.
|
276
|
+
"""
|
277
|
+
return self.start_rule.conclusion_name
|
278
|
+
|
263
279
|
|
264
280
|
class SingleClassRDR(RDRWithCodeWriter):
|
265
281
|
|
@@ -274,23 +290,20 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
274
290
|
:return: The category that the case belongs to.
|
275
291
|
"""
|
276
292
|
expert = expert if expert else Human(session=self.session)
|
277
|
-
case, attribute = case_query.case, case_query.attribute
|
278
293
|
if case_query.target is None:
|
279
294
|
target = expert.ask_for_conclusion(case_query)
|
280
|
-
else:
|
281
|
-
target = case_query.target
|
282
|
-
|
283
295
|
if not self.start_rule:
|
284
|
-
conditions = expert.ask_for_conditions(
|
285
|
-
self.start_rule = SingleClassRule(conditions, target, corner_case=case
|
296
|
+
conditions = expert.ask_for_conditions(case_query)
|
297
|
+
self.start_rule = SingleClassRule(conditions, case_query.target, corner_case=case_query.case,
|
298
|
+
conclusion_name=case_query.attribute_name)
|
286
299
|
|
287
|
-
pred = self.evaluate(case)
|
300
|
+
pred = self.evaluate(case_query.case)
|
288
301
|
|
289
|
-
if pred.conclusion != target:
|
290
|
-
conditions = expert.ask_for_conditions(
|
291
|
-
pred.fit_rule(case, target, conditions=conditions)
|
302
|
+
if pred.conclusion != case_query.target:
|
303
|
+
conditions = expert.ask_for_conditions(case_query, pred)
|
304
|
+
pred.fit_rule(case_query.case, case_query.target, conditions=conditions)
|
292
305
|
|
293
|
-
return self.classify(case)
|
306
|
+
return self.classify(case_query.case)
|
294
307
|
|
295
308
|
def classify(self, case: Case) -> Optional[CaseAttribute]:
|
296
309
|
"""
|
@@ -388,44 +401,41 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
388
401
|
:return: The conclusions that the case belongs to.
|
389
402
|
"""
|
390
403
|
expert = expert if expert else Human(session=self.session)
|
391
|
-
case = case_query.case
|
392
404
|
if case_query.target is None:
|
393
|
-
targets =
|
394
|
-
else:
|
395
|
-
targets = [case_query.target]
|
405
|
+
targets = expert.ask_for_conclusion(case_query)
|
396
406
|
self.expert_accepted_conclusions = []
|
397
407
|
user_conclusions = []
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
408
|
+
self.update_start_rule(case_query, expert)
|
409
|
+
self.conclusions = []
|
410
|
+
self.stop_rule_conditions = None
|
411
|
+
evaluated_rule = self.start_rule
|
412
|
+
while evaluated_rule:
|
413
|
+
next_rule = evaluated_rule(case_query.case)
|
414
|
+
good_conclusions = make_list(case_query.target) + user_conclusions + self.expert_accepted_conclusions
|
415
|
+
good_conclusions = make_set(good_conclusions)
|
416
|
+
|
417
|
+
if evaluated_rule.fired:
|
418
|
+
if case_query.target and not make_set(evaluated_rule.conclusion).issubset(good_conclusions):
|
419
|
+
# if self.case_has_conclusion(case, evaluated_rule.conclusion):
|
420
|
+
# Rule fired and conclusion is different from target
|
421
|
+
self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
|
422
|
+
add_extra_conclusions)
|
423
|
+
else:
|
424
|
+
# Rule fired and target is correct or there is no target to compare
|
425
|
+
self.add_conclusion(evaluated_rule)
|
426
|
+
|
427
|
+
if not next_rule:
|
428
|
+
if not make_set(case_query.target).intersection(make_set(self.conclusions)):
|
429
|
+
# Nothing fired and there is a target that should have been in the conclusions
|
430
|
+
self.add_rule_for_case(case_query, expert)
|
431
|
+
# Have to check all rules again to make sure only this new rule fires
|
432
|
+
next_rule = self.start_rule
|
433
|
+
elif add_extra_conclusions and not user_conclusions:
|
434
|
+
# No more conclusions can be made, ask the expert for extra conclusions if needed.
|
435
|
+
user_conclusions.extend(self.ask_expert_for_extra_conclusions(expert, case_query.case))
|
436
|
+
if user_conclusions:
|
437
|
+
next_rule = self.last_top_rule
|
438
|
+
evaluated_rule = next_rule
|
429
439
|
return self.conclusions
|
430
440
|
|
431
441
|
def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
|
@@ -458,19 +468,19 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
458
468
|
imports += "from typing_extensions import Set\n"
|
459
469
|
return imports
|
460
470
|
|
461
|
-
def update_start_rule(self,
|
471
|
+
def update_start_rule(self, case_query: CaseQuery, expert: Expert):
|
462
472
|
"""
|
463
473
|
Update the starting rule of the classifier.
|
464
474
|
|
465
|
-
:param
|
466
|
-
:param target: The target category to compare the case with.
|
475
|
+
:param case_query: The case query to update the starting rule with.
|
467
476
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
468
477
|
"""
|
469
478
|
if not self.start_rule.conditions:
|
470
|
-
conditions = expert.ask_for_conditions(
|
479
|
+
conditions = expert.ask_for_conditions(case_query)
|
471
480
|
self.start_rule.conditions = conditions
|
472
|
-
self.start_rule.conclusion = target
|
473
|
-
self.start_rule.corner_case = case
|
481
|
+
self.start_rule.conclusion = case_query.target
|
482
|
+
self.start_rule.corner_case = case_query.case
|
483
|
+
self.start_rule.conclusion_name = case_query.attribute_name
|
474
484
|
|
475
485
|
@property
|
476
486
|
def last_top_rule(self) -> Optional[MultiClassTopRule]:
|
@@ -482,35 +492,34 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
482
492
|
else:
|
483
493
|
return self.start_rule.furthest_alternative[-1]
|
484
494
|
|
485
|
-
def stop_wrong_conclusion_else_add_it(self,
|
495
|
+
def stop_wrong_conclusion_else_add_it(self, case_query: CaseQuery, expert: Expert,
|
486
496
|
evaluated_rule: MultiClassTopRule,
|
487
497
|
add_extra_conclusions: bool):
|
488
498
|
"""
|
489
499
|
Stop a wrong conclusion by adding a stopping rule.
|
490
500
|
"""
|
491
|
-
if self.is_same_category_type(evaluated_rule.conclusion, target) \
|
492
|
-
and self.is_conflicting_with_target(evaluated_rule.conclusion, target):
|
493
|
-
self.stop_conclusion(
|
494
|
-
elif not self.conclusion_is_correct(
|
495
|
-
self.stop_conclusion(
|
501
|
+
if self.is_same_category_type(evaluated_rule.conclusion, case_query.target) \
|
502
|
+
and self.is_conflicting_with_target(evaluated_rule.conclusion, case_query.target):
|
503
|
+
self.stop_conclusion(case_query, expert, evaluated_rule)
|
504
|
+
elif not self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
|
505
|
+
self.stop_conclusion(case_query, expert, evaluated_rule)
|
496
506
|
|
497
|
-
def stop_conclusion(self,
|
507
|
+
def stop_conclusion(self, case_query: CaseQuery,
|
498
508
|
expert: Expert, evaluated_rule: MultiClassTopRule):
|
499
509
|
"""
|
500
510
|
Stop a conclusion by adding a stopping rule.
|
501
511
|
|
502
|
-
:param
|
503
|
-
:param target: The target category to compare the case with.
|
512
|
+
:param case_query: The case query to stop the conclusion for.
|
504
513
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
505
514
|
:param evaluated_rule: The evaluated rule to ask the expert about.
|
506
515
|
"""
|
507
|
-
conditions = expert.ask_for_conditions(
|
508
|
-
evaluated_rule.fit_rule(case, target, conditions=conditions)
|
516
|
+
conditions = expert.ask_for_conditions(case_query, evaluated_rule)
|
517
|
+
evaluated_rule.fit_rule(case_query.case, case_query.target, conditions=conditions)
|
509
518
|
if self.mode == MCRDRMode.StopPlusRule:
|
510
519
|
self.stop_rule_conditions = conditions
|
511
520
|
if self.mode == MCRDRMode.StopPlusRuleCombined:
|
512
521
|
new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
|
513
|
-
self.add_top_rule(new_top_rule_conditions, target, case)
|
522
|
+
self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
|
514
523
|
|
515
524
|
@staticmethod
|
516
525
|
def is_conflicting_with_target(conclusion: Any, target: Any) -> bool:
|
@@ -537,37 +546,40 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
537
546
|
"""
|
538
547
|
return conclusion.__class__ == target.__class__ and target.__class__ != CaseAttribute
|
539
548
|
|
540
|
-
def conclusion_is_correct(self,
|
549
|
+
def conclusion_is_correct(self, case_query: CaseQuery,
|
550
|
+
expert: Expert, evaluated_rule: Rule,
|
541
551
|
add_extra_conclusions: bool) -> bool:
|
542
552
|
"""
|
543
553
|
Ask the expert if the conclusion is correct, and add it to the conclusions if it is.
|
544
554
|
|
545
|
-
:param
|
546
|
-
:param target: The target category to compare the case with.
|
555
|
+
:param case_query: The case query to ask the expert about.
|
547
556
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
548
557
|
:param evaluated_rule: The evaluated rule to ask the expert about.
|
549
558
|
:param add_extra_conclusions: Whether adding extra conclusions after classification is allowed.
|
550
559
|
:return: Whether the conclusion is correct or not.
|
551
560
|
"""
|
552
|
-
conclusions =
|
553
|
-
if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case, evaluated_rule.conclusion,
|
554
|
-
targets=target,
|
561
|
+
conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
|
562
|
+
if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case, evaluated_rule.conclusion,
|
563
|
+
targets=case_query.target,
|
555
564
|
current_conclusions=conclusions)):
|
556
565
|
self.add_conclusion(evaluated_rule)
|
557
566
|
self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
|
558
567
|
return True
|
559
568
|
return False
|
560
569
|
|
561
|
-
def add_rule_for_case(self,
|
570
|
+
def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
|
562
571
|
"""
|
563
572
|
Add a rule for a case that has not been classified with any conclusion.
|
573
|
+
|
574
|
+
:param case_query: The case query to add the rule for.
|
575
|
+
:param expert: The expert to ask for differentiating features as new rule conditions.
|
564
576
|
"""
|
565
577
|
if self.stop_rule_conditions and self.mode == MCRDRMode.StopPlusRule:
|
566
578
|
conditions = self.stop_rule_conditions
|
567
579
|
self.stop_rule_conditions = None
|
568
580
|
else:
|
569
|
-
conditions = expert.ask_for_conditions(
|
570
|
-
self.add_top_rule(conditions, target, case)
|
581
|
+
conditions = expert.ask_for_conditions(case_query)
|
582
|
+
self.add_top_rule(conditions, case_query.target, case_query.case)
|
571
583
|
|
572
584
|
def ask_expert_for_extra_conclusions(self, expert: Expert, case: Union[Case, SQLTable]) -> List[Any]:
|
573
585
|
"""
|
@@ -641,20 +653,31 @@ class GeneralRDR(RippleDownRules):
|
|
641
653
|
gets called when the final rule fires.
|
642
654
|
"""
|
643
655
|
|
644
|
-
def __init__(self, category_rdr_map: Optional[Dict[
|
656
|
+
def __init__(self, category_rdr_map: Optional[Dict[str, Union[SingleClassRDR, MultiClassRDR]]] = None):
|
645
657
|
"""
|
646
|
-
:param category_rdr_map: A map of
|
658
|
+
:param category_rdr_map: A map of case attribute names to ripple down rules classifiers,
|
647
659
|
where each category is a parent category that has a set of mutually exclusive (in case of SCRDR) child
|
648
|
-
categories, e.g. {
|
649
|
-
and MCRDR are SingleClass and MultiClass ripple down rules classifiers.
|
650
|
-
Mammal, Bird, Fish, etc. which are mutually exclusive,
|
651
|
-
Land, Water, Air, etc
|
660
|
+
categories, e.g. {'species': SCRDR, 'habitats': MCRDR}, where 'species' and 'habitats' are attribute names
|
661
|
+
for a case of type Animal, while SCRDR and MCRDR are SingleClass and MultiClass ripple down rules classifiers.
|
662
|
+
Species can have values like Mammal, Bird, Fish, etc. which are mutually exclusive, while Habitat can have
|
663
|
+
values like Land, Water, Air, etc., which are not mutually exclusive due to some animals living more than one
|
664
|
+
habitat.
|
652
665
|
"""
|
653
|
-
self.start_rules_dict: Dict[
|
666
|
+
self.start_rules_dict: Dict[str, Union[SingleClassRDR, MultiClassRDR]] \
|
654
667
|
= category_rdr_map if category_rdr_map else {}
|
655
668
|
super(GeneralRDR, self).__init__()
|
656
669
|
self.all_figs: List[plt.Figure] = [sr.fig for sr in self.start_rules_dict.values()]
|
657
670
|
|
671
|
+
def add_rdr(self, rdr: Union[SingleClassRDR, MultiClassRDR], attribute_name: Optional[str] = None):
|
672
|
+
"""
|
673
|
+
Add a ripple down rules classifier to the map of classifiers.
|
674
|
+
|
675
|
+
:param rdr: The ripple down rules classifier to add.
|
676
|
+
:param attribute_name: The name of the attribute that the classifier is classifying.
|
677
|
+
"""
|
678
|
+
attribute_name = attribute_name if attribute_name else rdr.attribute_name
|
679
|
+
self.start_rules_dict[attribute_name] = rdr
|
680
|
+
|
658
681
|
@property
|
659
682
|
def start_rule(self) -> Optional[Union[SingleClassRule, MultiClassTopRule]]:
|
660
683
|
return self.start_rules[0] if self.start_rules_dict else None
|
@@ -662,7 +685,7 @@ class GeneralRDR(RippleDownRules):
|
|
662
685
|
@start_rule.setter
|
663
686
|
def start_rule(self, value: Union[SingleClassRDR, MultiClassRDR]):
|
664
687
|
if value:
|
665
|
-
self.start_rules_dict[value.
|
688
|
+
self.start_rules_dict[value.attribute_name] = value
|
666
689
|
|
667
690
|
@property
|
668
691
|
def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
|
@@ -679,8 +702,8 @@ class GeneralRDR(RippleDownRules):
|
|
679
702
|
return self._classify(self.start_rules_dict, case)
|
680
703
|
|
681
704
|
@staticmethod
|
682
|
-
def _classify(classifiers_dict: Dict[
|
683
|
-
case: Union[Case, SQLTable]) -> Optional[
|
705
|
+
def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
|
706
|
+
case: Union[Case, SQLTable]) -> Optional[Dict[str, Any]]:
|
684
707
|
"""
|
685
708
|
Classify a case by going through all classifiers and adding the categories that are classified,
|
686
709
|
and then restarting the classification until no more categories can be added.
|
@@ -689,21 +712,31 @@ class GeneralRDR(RippleDownRules):
|
|
689
712
|
:param case: The case to classify.
|
690
713
|
:return: The categories that the case belongs to.
|
691
714
|
"""
|
692
|
-
conclusions =
|
715
|
+
conclusions = {}
|
693
716
|
case_cp = copy_case(case)
|
694
717
|
while True:
|
695
|
-
|
696
|
-
for
|
697
|
-
if GeneralRDR.case_has_conclusion(case_cp, cat_type):
|
698
|
-
continue
|
718
|
+
new_conclusions = {}
|
719
|
+
for attribute_name, rdr in classifiers_dict.items():
|
699
720
|
pred_atts = rdr.classify(case_cp)
|
700
|
-
if pred_atts:
|
721
|
+
if pred_atts is None:
|
722
|
+
continue
|
723
|
+
if isinstance(rdr, SingleClassRDR):
|
724
|
+
if attribute_name not in conclusions or \
|
725
|
+
(attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
|
726
|
+
conclusions[attribute_name] = pred_atts
|
727
|
+
new_conclusions[attribute_name] = pred_atts
|
728
|
+
else:
|
701
729
|
pred_atts = make_list(pred_atts)
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
730
|
+
if attribute_name in conclusions:
|
731
|
+
pred_atts = [p for p in pred_atts if p not in conclusions[attribute_name]]
|
732
|
+
if len(pred_atts) > 0:
|
733
|
+
new_conclusions[attribute_name] = pred_atts
|
734
|
+
if attribute_name not in conclusions:
|
735
|
+
conclusions[attribute_name] = []
|
736
|
+
conclusions[attribute_name].extend(pred_atts)
|
737
|
+
if attribute_name in new_conclusions:
|
738
|
+
GeneralRDR.update_case(case_cp, new_conclusions)
|
739
|
+
if len(new_conclusions) == 0:
|
707
740
|
break
|
708
741
|
return conclusions
|
709
742
|
|
@@ -728,103 +761,79 @@ class GeneralRDR(RippleDownRules):
|
|
728
761
|
case = case_queries[0].case
|
729
762
|
assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
|
730
763
|
" for multiple cases use fit instead")
|
731
|
-
|
732
|
-
case_cp = case_query_cp.case
|
764
|
+
case_cp = copy(case_queries[0]).case
|
733
765
|
for case_query in case_queries:
|
734
|
-
|
735
|
-
|
766
|
+
case_query_cp = copy(case_query)
|
767
|
+
case_query_cp.case = case_cp
|
768
|
+
if case_query.target is None:
|
769
|
+
conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
|
736
770
|
target = expert.ask_for_conclusion(case_query)
|
737
|
-
|
738
|
-
if
|
739
|
-
target_type = type(make_list(target)[0])
|
740
|
-
assert all([type(t) is target_type for t in target]), ("All targets of a case query must be of the same"
|
741
|
-
" type")
|
742
|
-
else:
|
743
|
-
target_type = type(target)
|
744
|
-
if target_type not in self.start_rules_dict:
|
771
|
+
|
772
|
+
if case_query.attribute_name not in self.start_rules_dict:
|
745
773
|
conclusions = self.classify(case)
|
746
774
|
self.update_case(case_cp, conclusions)
|
747
|
-
|
775
|
+
|
776
|
+
new_rdr = self.initialize_new_rdr_for_attribute(case_query.attribute_name, case_cp, case_query.target)
|
777
|
+
self.add_rdr(new_rdr, case_query.attribute_name)
|
778
|
+
|
748
779
|
new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
|
749
|
-
self.
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
if target_type is not rdr_type:
|
780
|
+
self.update_case(case_cp, {case_query.attribute_name: new_conclusions})
|
781
|
+
else:
|
782
|
+
for rdr_attribute_name, rdr in self.start_rules_dict.items():
|
783
|
+
if case_query.attribute_name != rdr_attribute_name:
|
754
784
|
conclusions = rdr.classify(case_cp)
|
755
785
|
else:
|
756
|
-
conclusions = self.start_rules_dict[
|
757
|
-
|
758
|
-
|
786
|
+
conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
|
787
|
+
**kwargs)
|
788
|
+
if conclusions is not None or (is_iterable(conclusions) and len(conclusions) > 0):
|
789
|
+
conclusions = {rdr_attribute_name: conclusions}
|
790
|
+
self.update_case(case_cp, conclusions)
|
759
791
|
|
760
792
|
return self.classify(case)
|
761
793
|
|
762
794
|
@staticmethod
|
763
|
-
def initialize_new_rdr_for_attribute(
|
795
|
+
def initialize_new_rdr_for_attribute(attribute_name: str, case: Union[Case, SQLTable], target: Any):
|
764
796
|
"""
|
765
797
|
Initialize the appropriate RDR type for the target.
|
766
798
|
"""
|
767
|
-
if
|
768
|
-
|
769
|
-
if hasattr(prop, "__iter__") and not isinstance(prop, str):
|
770
|
-
return MultiClassRDR()
|
771
|
-
else:
|
772
|
-
return SingleClassRDR()
|
773
|
-
elif isinstance(attribute, CaseAttribute):
|
799
|
+
attribute = getattr(case, attribute_name) if hasattr(case, attribute_name) else target
|
800
|
+
if isinstance(attribute, CaseAttribute):
|
774
801
|
return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
|
775
802
|
else:
|
776
|
-
return MultiClassRDR() if is_iterable(attribute) else SingleClassRDR()
|
803
|
+
return MultiClassRDR() if is_iterable(attribute) or (attribute is None) else SingleClassRDR()
|
777
804
|
|
778
805
|
@staticmethod
|
779
|
-
def update_case(case: Union[Case, SQLTable],
|
780
|
-
conclusions: List[Any], attribute_type: Optional[Any] = None):
|
806
|
+
def update_case(case: Union[Case, SQLTable], conclusions: Dict[str, Any]):
|
781
807
|
"""
|
782
808
|
Update the case with the conclusions.
|
783
809
|
|
784
810
|
:param case: The case to update.
|
785
811
|
:param conclusions: The conclusions to update the case with.
|
786
|
-
:param attribute_type: The type of the attribute to update.
|
787
812
|
"""
|
788
813
|
if not conclusions:
|
789
814
|
return
|
790
|
-
conclusions = [conclusions] if not isinstance(conclusions, list) else list(conclusions)
|
791
815
|
if len(conclusions) == 0:
|
792
816
|
return
|
793
817
|
if isinstance(case, SQLTable):
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
818
|
+
for conclusion_name, conclusion in conclusions.items():
|
819
|
+
hint, origin, args = get_hint_for_attribute(conclusion_name, case)
|
820
|
+
attribute = getattr(case, conclusion_name)
|
821
|
+
if isinstance(attribute, set) or origin in {Set, set}:
|
822
|
+
attribute = set() if attribute is None else attribute
|
823
|
+
for c in conclusion:
|
824
|
+
attribute.update(make_set(c))
|
825
|
+
elif isinstance(attribute, list) or origin in {list, List}:
|
826
|
+
attribute = [] if attribute is None else attribute
|
827
|
+
attribute.extend(conclusion)
|
828
|
+
elif (not is_iterable(conclusion) or (len(conclusion) == 1)) and hint == type(conclusion):
|
829
|
+
setattr(case, conclusion_name, conclusion)
|
830
|
+
else:
|
831
|
+
raise ValueError(f"Cannot add multiple conclusions to attribute {conclusion_name}")
|
808
832
|
else:
|
809
|
-
|
810
|
-
case.update(c.as_dict)
|
811
|
-
|
812
|
-
@property
|
813
|
-
def names_of_all_types(self) -> List[str]:
|
814
|
-
"""
|
815
|
-
Get the names of all the types of categories that the GRDR can classify.
|
816
|
-
"""
|
817
|
-
return [t.__name__ for t in self.start_rules_dict.keys()]
|
818
|
-
|
819
|
-
@property
|
820
|
-
def all_types(self) -> List[Type]:
|
821
|
-
"""
|
822
|
-
Get all the types of categories that the GRDR can classify.
|
823
|
-
"""
|
824
|
-
return list(self.start_rules_dict.keys())
|
833
|
+
case.update(conclusions)
|
825
834
|
|
826
835
|
def _to_json(self) -> Dict[str, Any]:
|
827
|
-
return {"start_rules": {
|
836
|
+
return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
|
828
837
|
|
829
838
|
@classmethod
|
830
839
|
def _from_json(cls, data: Dict[str, Any]) -> GeneralRDR:
|
@@ -833,7 +842,6 @@ class GeneralRDR(RippleDownRules):
|
|
833
842
|
"""
|
834
843
|
start_rules_dict = {}
|
835
844
|
for k, v in data["start_rules"].items():
|
836
|
-
k = get_type_from_string(k)
|
837
845
|
start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
|
838
846
|
return cls(start_rules_dict)
|
839
847
|
|
@@ -849,12 +857,12 @@ class GeneralRDR(RippleDownRules):
|
|
849
857
|
with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
|
850
858
|
f.write(self._get_imports(file_path) + "\n\n")
|
851
859
|
f.write("classifiers_dict = dict()\n")
|
852
|
-
for
|
853
|
-
f.write(f"classifiers_dict[{
|
860
|
+
for rdr_key, rdr in self.start_rules_dict.items():
|
861
|
+
f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
|
854
862
|
f.write("\n\n")
|
855
863
|
f.write(func_def)
|
856
864
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
857
|
-
f"{' ' * 4} case = create_case(case,
|
865
|
+
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
858
866
|
f.write(f"{' ' * 4}return GeneralRDR._classify(classifiers_dict, case)\n")
|
859
867
|
|
860
868
|
@property
|
@@ -863,7 +871,7 @@ class GeneralRDR(RippleDownRules):
|
|
863
871
|
:return: The type of the case (input) to the RDR classifier.
|
864
872
|
"""
|
865
873
|
if isinstance(self.start_rule.corner_case, Case):
|
866
|
-
return self.start_rule.corner_case.
|
874
|
+
return self.start_rule.corner_case._obj_type
|
867
875
|
else:
|
868
876
|
return type(self.start_rule.corner_case)
|
869
877
|
|
@@ -876,7 +884,7 @@ class GeneralRDR(RippleDownRules):
|
|
876
884
|
|
877
885
|
@property
|
878
886
|
def generated_python_file_name(self) -> str:
|
879
|
-
return f"{self.
|
887
|
+
return f"{self.start_rule.corner_case._name.lower()}_rdr"
|
880
888
|
|
881
889
|
@property
|
882
890
|
def conclusion_type_hint(self) -> str:
|
@@ -892,10 +900,20 @@ class GeneralRDR(RippleDownRules):
|
|
892
900
|
imports += f"from ripple_down_rules.datastructures import Case, create_case\n"
|
893
901
|
imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
|
894
902
|
# add conclusion type imports
|
895
|
-
for
|
896
|
-
imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
|
903
|
+
for rdr in self.start_rules_dict.values():
|
904
|
+
imports += f"from {rdr.conclusion_type.__module__} import {rdr.conclusion_type.__name__}\n"
|
897
905
|
# add rdr python generated functions.
|
898
|
-
for
|
906
|
+
for rdr_key, rdr in self.start_rules_dict.items():
|
899
907
|
imports += (f"from {file_path.strip('./')}"
|
900
|
-
f" import {rdr.generated_python_file_name} as {
|
908
|
+
f" import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}\n")
|
901
909
|
return imports
|
910
|
+
|
911
|
+
@staticmethod
|
912
|
+
def rdr_key_to_function_name(rdr_key: str) -> str:
|
913
|
+
"""
|
914
|
+
Convert the RDR key to a function name.
|
915
|
+
|
916
|
+
:param rdr_key: The RDR key to convert.
|
917
|
+
:return: The function name.
|
918
|
+
"""
|
919
|
+
return rdr_key.replace(".", "_").lower() + "_classifier"
|