ripple-down-rules 0.0.14__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datasets.py +2 -2
- ripple_down_rules/datastructures/callable_expression.py +52 -10
- ripple_down_rules/datastructures/case.py +53 -70
- ripple_down_rules/datastructures/dataclasses.py +69 -29
- ripple_down_rules/experts.py +29 -40
- ripple_down_rules/helpers.py +27 -0
- ripple_down_rules/prompt.py +77 -24
- ripple_down_rules/rdr.py +298 -192
- ripple_down_rules/rdr_decorators.py +55 -0
- ripple_down_rules/rules.py +12 -3
- ripple_down_rules/utils.py +154 -3
- {ripple_down_rules-0.0.14.dist-info → ripple_down_rules-0.1.0.dist-info}/METADATA +1 -1
- ripple_down_rules-0.1.0.dist-info/RECORD +20 -0
- {ripple_down_rules-0.0.14.dist-info → ripple_down_rules-0.1.0.dist-info}/WHEEL +1 -1
- ripple_down_rules-0.0.14.dist-info/RECORD +0 -18
- {ripple_down_rules-0.0.14.dist-info → ripple_down_rules-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.0.14.dist-info → ripple_down_rules-0.1.0.dist-info}/top_level.txt +0 -0
ripple_down_rules/rdr.py
CHANGED
@@ -2,19 +2,19 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import importlib
|
4
4
|
from abc import ABC, abstractmethod
|
5
|
-
from copy import copy
|
5
|
+
from copy import copy
|
6
6
|
from types import ModuleType
|
7
7
|
|
8
8
|
from matplotlib import pyplot as plt
|
9
9
|
from ordered_set import OrderedSet
|
10
10
|
from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
|
11
|
-
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable
|
11
|
+
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
|
12
12
|
|
13
13
|
from .datastructures import Case, MCRDRMode, CallableExpression, CaseAttribute, CaseQuery
|
14
14
|
from .experts import Expert, Human
|
15
15
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
16
|
-
from .utils import draw_tree, make_set,
|
17
|
-
get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list,
|
16
|
+
from .utils import draw_tree, make_set, copy_case, \
|
17
|
+
get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string
|
18
18
|
|
19
19
|
|
20
20
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -80,7 +80,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
80
80
|
:param kwargs_for_fit_case: The keyword arguments to pass to the fit_case method.
|
81
81
|
"""
|
82
82
|
cases = [case_query.case for case_query in case_queries]
|
83
|
-
targets = [
|
83
|
+
targets = [{case_query.attribute_name: case_query.target} for case_query in case_queries]
|
84
84
|
if animate_tree:
|
85
85
|
plt.ion()
|
86
86
|
i = 0
|
@@ -91,11 +91,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
91
91
|
if not targets:
|
92
92
|
targets = [None] * len(cases)
|
93
93
|
for case_query in case_queries:
|
94
|
-
|
95
|
-
target = case_query.target
|
96
|
-
if not target:
|
97
|
-
conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
|
98
|
-
target = expert.ask_for_conclusion(case_query, conclusions)
|
94
|
+
target = {case_query.attribute_name: case_query.target}
|
99
95
|
pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
|
100
96
|
match = self.is_matching(pred_cat, target)
|
101
97
|
if not match:
|
@@ -105,8 +101,9 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
105
101
|
num_rules = self.start_rule.size
|
106
102
|
self.update_figures()
|
107
103
|
i += 1
|
108
|
-
all_predictions = [1 if self.is_matching(self.classify(case),
|
109
|
-
|
104
|
+
all_predictions = [1 if self.is_matching(self.classify(case_query.case), {case_query.attribute_name:
|
105
|
+
case_query.target}) else 0
|
106
|
+
for case_query in case_queries]
|
110
107
|
all_pred = sum(all_predictions)
|
111
108
|
print(f"Accuracy: {all_pred}/{len(targets)}")
|
112
109
|
all_predicted = targets and all_pred == len(targets)
|
@@ -120,7 +117,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
120
117
|
plt.show()
|
121
118
|
|
122
119
|
@staticmethod
|
123
|
-
def calculate_precision_and_recall(pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> Tuple[
|
120
|
+
def calculate_precision_and_recall(pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> Tuple[
|
121
|
+
List[bool], List[bool]]:
|
124
122
|
"""
|
125
123
|
:param pred_cat: The predicted category.
|
126
124
|
:param target: The target category.
|
@@ -128,9 +126,33 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
128
126
|
"""
|
129
127
|
pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
|
130
128
|
target = target if is_iterable(target) else [target]
|
131
|
-
recall = [
|
132
|
-
|
133
|
-
|
129
|
+
recall = []
|
130
|
+
precision = []
|
131
|
+
if isinstance(pred_cat, dict):
|
132
|
+
for pred_key, pred_value in pred_cat.items():
|
133
|
+
if pred_key not in target:
|
134
|
+
continue
|
135
|
+
# if is_iterable(pred_value):
|
136
|
+
# print(pred_value, target[pred_key])
|
137
|
+
# precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
|
138
|
+
precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
|
139
|
+
# else:
|
140
|
+
# precision.append(pred_value == target[pred_key])
|
141
|
+
for target_key, target_value in target.items():
|
142
|
+
if target_key not in pred_cat:
|
143
|
+
recall.append(False)
|
144
|
+
continue
|
145
|
+
if is_iterable(target_value):
|
146
|
+
recall.extend([v in pred_cat[target_key] for v in target_value])
|
147
|
+
else:
|
148
|
+
recall.append(target_value == pred_cat[target_key])
|
149
|
+
print(f"Precision: {precision}, Recall: {recall}")
|
150
|
+
else:
|
151
|
+
if isinstance(target, dict):
|
152
|
+
target = list(target.values())
|
153
|
+
recall = [not yi or (yi in pred_cat) for yi in target]
|
154
|
+
target_types = [type(yi) for yi in target]
|
155
|
+
precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
|
134
156
|
return precision, recall
|
135
157
|
|
136
158
|
def is_matching(self, pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> bool:
|
@@ -157,22 +179,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
157
179
|
draw_tree(self.start_rule, self.fig)
|
158
180
|
|
159
181
|
@staticmethod
|
160
|
-
def case_has_conclusion(case: Union[Case, SQLTable],
|
182
|
+
def case_has_conclusion(case: Union[Case, SQLTable], conclusion_name: str) -> bool:
|
161
183
|
"""
|
162
184
|
Check if the case has a conclusion.
|
163
185
|
|
164
186
|
:param case: The case to check.
|
165
|
-
:param
|
187
|
+
:param conclusion_name: The target category name to compare the case with.
|
166
188
|
:return: Whether the case has a conclusion or not.
|
167
189
|
"""
|
168
|
-
|
169
|
-
prop_name, prop_value = get_attribute_by_type(case, conclusion_type)
|
170
|
-
if hasattr(prop_value, "__iter__") and not isinstance(prop_value, str):
|
171
|
-
return len(prop_value) > 0
|
172
|
-
else:
|
173
|
-
return prop_value is not None
|
174
|
-
else:
|
175
|
-
return conclusion_type in case
|
190
|
+
return hasattr(case, conclusion_name) and getattr(case, conclusion_name) is not None
|
176
191
|
|
177
192
|
|
178
193
|
class RDRWithCodeWriter(RippleDownRules, ABC):
|
@@ -194,16 +209,17 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
194
209
|
|
195
210
|
:param file_path: The path to the file to write the source code to.
|
196
211
|
"""
|
197
|
-
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.
|
212
|
+
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
198
213
|
with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
|
199
214
|
f.write(self._get_imports() + "\n\n")
|
200
215
|
f.write(func_def)
|
201
|
-
f.write(f"{' '*4}if not isinstance(case, Case):\n"
|
202
|
-
f"{' '*4} case = create_case(case, recursion_idx=3)\n""")
|
216
|
+
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
217
|
+
f"{' ' * 4} case = create_case(case, recursion_idx=3)\n""")
|
203
218
|
self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)
|
204
219
|
|
220
|
+
@property
|
205
221
|
@abstractmethod
|
206
|
-
def
|
222
|
+
def conclusion_type_hint(self) -> str:
|
207
223
|
"""
|
208
224
|
:return: The type hint of the conclusion of the rdr as a string.
|
209
225
|
"""
|
@@ -242,7 +258,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
242
258
|
:return: The type of the case (input) to the RDR classifier.
|
243
259
|
"""
|
244
260
|
if isinstance(self.start_rule.corner_case, Case):
|
245
|
-
return self.start_rule.corner_case.
|
261
|
+
return self.start_rule.corner_case._obj_type
|
246
262
|
else:
|
247
263
|
return type(self.start_rule.corner_case)
|
248
264
|
|
@@ -254,8 +270,17 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
254
270
|
if isinstance(self.start_rule.conclusion, CallableExpression):
|
255
271
|
return self.start_rule.conclusion.conclusion_type
|
256
272
|
else:
|
273
|
+
if isinstance(self.start_rule.conclusion, set):
|
274
|
+
return type(list(self.start_rule.conclusion)[0])
|
257
275
|
return type(self.start_rule.conclusion)
|
258
276
|
|
277
|
+
@property
|
278
|
+
def attribute_name(self) -> str:
|
279
|
+
"""
|
280
|
+
:return: The name of the attribute that the classifier is classifying.
|
281
|
+
"""
|
282
|
+
return self.start_rule.conclusion_name
|
283
|
+
|
259
284
|
|
260
285
|
class SingleClassRDR(RDRWithCodeWriter):
|
261
286
|
|
@@ -270,23 +295,20 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
270
295
|
:return: The category that the case belongs to.
|
271
296
|
"""
|
272
297
|
expert = expert if expert else Human(session=self.session)
|
273
|
-
case, attribute = case_query.case, case_query.attribute
|
274
298
|
if case_query.target is None:
|
275
299
|
target = expert.ask_for_conclusion(case_query)
|
276
|
-
else:
|
277
|
-
target = case_query.target
|
278
|
-
|
279
300
|
if not self.start_rule:
|
280
|
-
conditions = expert.ask_for_conditions(
|
281
|
-
self.start_rule = SingleClassRule(conditions, target, corner_case=case
|
301
|
+
conditions = expert.ask_for_conditions(case_query)
|
302
|
+
self.start_rule = SingleClassRule(conditions, case_query.target, corner_case=case_query.case,
|
303
|
+
conclusion_name=case_query.attribute_name)
|
282
304
|
|
283
|
-
pred = self.evaluate(case)
|
305
|
+
pred = self.evaluate(case_query.case)
|
284
306
|
|
285
|
-
if pred.conclusion != target:
|
286
|
-
conditions = expert.ask_for_conditions(
|
287
|
-
pred.fit_rule(case, target, conditions=conditions)
|
307
|
+
if pred.conclusion != case_query.target:
|
308
|
+
conditions = expert.ask_for_conditions(case_query, pred)
|
309
|
+
pred.fit_rule(case_query.case, case_query.target, conditions=conditions)
|
288
310
|
|
289
|
-
return self.classify(case)
|
311
|
+
return self.classify(case_query.case)
|
290
312
|
|
291
313
|
def classify(self, case: Case) -> Optional[CaseAttribute]:
|
292
314
|
"""
|
@@ -316,7 +338,8 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
316
338
|
if rule.alternative:
|
317
339
|
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
|
318
340
|
|
319
|
-
|
341
|
+
@property
|
342
|
+
def conclusion_type_hint(self) -> str:
|
320
343
|
return self.conclusion_type.__name__
|
321
344
|
|
322
345
|
def _to_json(self) -> Dict[str, Any]:
|
@@ -383,50 +406,49 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
383
406
|
:return: The conclusions that the case belongs to.
|
384
407
|
"""
|
385
408
|
expert = expert if expert else Human(session=self.session)
|
386
|
-
case = case_query.case
|
387
409
|
if case_query.target is None:
|
388
|
-
targets =
|
389
|
-
else:
|
390
|
-
targets = [case_query.target]
|
410
|
+
targets = expert.ask_for_conclusion(case_query)
|
391
411
|
self.expert_accepted_conclusions = []
|
392
412
|
user_conclusions = []
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
413
|
+
self.update_start_rule(case_query, expert)
|
414
|
+
self.conclusions = []
|
415
|
+
self.stop_rule_conditions = None
|
416
|
+
evaluated_rule = self.start_rule
|
417
|
+
while evaluated_rule:
|
418
|
+
next_rule = evaluated_rule(case_query.case)
|
419
|
+
good_conclusions = make_list(case_query.target) + user_conclusions + self.expert_accepted_conclusions
|
420
|
+
good_conclusions = make_set(good_conclusions)
|
421
|
+
|
422
|
+
if evaluated_rule.fired:
|
423
|
+
if case_query.target and not make_set(evaluated_rule.conclusion).issubset(good_conclusions):
|
424
|
+
# if self.case_has_conclusion(case, evaluated_rule.conclusion):
|
425
|
+
# Rule fired and conclusion is different from target
|
426
|
+
self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
|
427
|
+
add_extra_conclusions)
|
428
|
+
else:
|
429
|
+
# Rule fired and target is correct or there is no target to compare
|
430
|
+
self.add_conclusion(evaluated_rule)
|
431
|
+
|
432
|
+
if not next_rule:
|
433
|
+
if not make_set(case_query.target).intersection(make_set(self.conclusions)):
|
434
|
+
# Nothing fired and there is a target that should have been in the conclusions
|
435
|
+
self.add_rule_for_case(case_query, expert)
|
436
|
+
# Have to check all rules again to make sure only this new rule fires
|
437
|
+
next_rule = self.start_rule
|
438
|
+
elif add_extra_conclusions and not user_conclusions:
|
439
|
+
# No more conclusions can be made, ask the expert for extra conclusions if needed.
|
440
|
+
user_conclusions.extend(self.ask_expert_for_extra_conclusions(expert, case_query.case))
|
441
|
+
if user_conclusions:
|
442
|
+
next_rule = self.last_top_rule
|
443
|
+
evaluated_rule = next_rule
|
424
444
|
return self.conclusions
|
425
445
|
|
426
446
|
def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
|
427
447
|
file, parent_indent: str = ""):
|
428
448
|
"""
|
429
449
|
Write the rules as source code to a file.
|
450
|
+
|
451
|
+
:
|
430
452
|
"""
|
431
453
|
if rule == self.start_rule:
|
432
454
|
file.write(f"{parent_indent}conclusions = set()\n")
|
@@ -435,14 +457,15 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
435
457
|
conclusion_indent = parent_indent
|
436
458
|
if hasattr(rule, "refinement") and rule.refinement:
|
437
459
|
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ")
|
438
|
-
conclusion_indent = parent_indent + " "*4
|
460
|
+
conclusion_indent = parent_indent + " " * 4
|
439
461
|
file.write(f"{conclusion_indent}else:\n")
|
440
462
|
file.write(rule.write_conclusion_as_source_code(conclusion_indent))
|
441
463
|
|
442
464
|
if rule.alternative:
|
443
465
|
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
|
444
466
|
|
445
|
-
|
467
|
+
@property
|
468
|
+
def conclusion_type_hint(self) -> str:
|
446
469
|
return f"Set[{self.conclusion_type.__name__}]"
|
447
470
|
|
448
471
|
def _get_imports(self) -> str:
|
@@ -450,19 +473,18 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
450
473
|
imports += "from typing_extensions import Set\n"
|
451
474
|
return imports
|
452
475
|
|
453
|
-
def update_start_rule(self,
|
476
|
+
def update_start_rule(self, case_query: CaseQuery, expert: Expert):
|
454
477
|
"""
|
455
478
|
Update the starting rule of the classifier.
|
456
479
|
|
457
|
-
:param
|
458
|
-
:param target: The target category to compare the case with.
|
480
|
+
:param case_query: The case query to update the starting rule with.
|
459
481
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
460
482
|
"""
|
461
483
|
if not self.start_rule.conditions:
|
462
|
-
conditions = expert.ask_for_conditions(
|
484
|
+
conditions = expert.ask_for_conditions(case_query)
|
463
485
|
self.start_rule.conditions = conditions
|
464
|
-
self.start_rule.conclusion = target
|
465
|
-
self.start_rule.corner_case = case
|
486
|
+
self.start_rule.conclusion = case_query.target
|
487
|
+
self.start_rule.corner_case = case_query.case
|
466
488
|
|
467
489
|
@property
|
468
490
|
def last_top_rule(self) -> Optional[MultiClassTopRule]:
|
@@ -474,35 +496,34 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
474
496
|
else:
|
475
497
|
return self.start_rule.furthest_alternative[-1]
|
476
498
|
|
477
|
-
def stop_wrong_conclusion_else_add_it(self,
|
499
|
+
def stop_wrong_conclusion_else_add_it(self, case_query: CaseQuery, expert: Expert,
|
478
500
|
evaluated_rule: MultiClassTopRule,
|
479
501
|
add_extra_conclusions: bool):
|
480
502
|
"""
|
481
503
|
Stop a wrong conclusion by adding a stopping rule.
|
482
504
|
"""
|
483
|
-
if self.is_same_category_type(evaluated_rule.conclusion, target) \
|
484
|
-
and self.is_conflicting_with_target(evaluated_rule.conclusion, target):
|
485
|
-
self.stop_conclusion(
|
486
|
-
elif not self.conclusion_is_correct(
|
487
|
-
self.stop_conclusion(
|
505
|
+
if self.is_same_category_type(evaluated_rule.conclusion, case_query.target) \
|
506
|
+
and self.is_conflicting_with_target(evaluated_rule.conclusion, case_query.target):
|
507
|
+
self.stop_conclusion(case_query, expert, evaluated_rule)
|
508
|
+
elif not self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
|
509
|
+
self.stop_conclusion(case_query, expert, evaluated_rule)
|
488
510
|
|
489
|
-
def stop_conclusion(self,
|
511
|
+
def stop_conclusion(self, case_query: CaseQuery,
|
490
512
|
expert: Expert, evaluated_rule: MultiClassTopRule):
|
491
513
|
"""
|
492
514
|
Stop a conclusion by adding a stopping rule.
|
493
515
|
|
494
|
-
:param
|
495
|
-
:param target: The target category to compare the case with.
|
516
|
+
:param case_query: The case query to stop the conclusion for.
|
496
517
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
497
518
|
:param evaluated_rule: The evaluated rule to ask the expert about.
|
498
519
|
"""
|
499
|
-
conditions = expert.ask_for_conditions(
|
500
|
-
evaluated_rule.fit_rule(case, target, conditions=conditions)
|
520
|
+
conditions = expert.ask_for_conditions(case_query, evaluated_rule)
|
521
|
+
evaluated_rule.fit_rule(case_query.case, case_query.target, conditions=conditions)
|
501
522
|
if self.mode == MCRDRMode.StopPlusRule:
|
502
523
|
self.stop_rule_conditions = conditions
|
503
524
|
if self.mode == MCRDRMode.StopPlusRuleCombined:
|
504
525
|
new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
|
505
|
-
self.add_top_rule(new_top_rule_conditions, target, case)
|
526
|
+
self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
|
506
527
|
|
507
528
|
@staticmethod
|
508
529
|
def is_conflicting_with_target(conclusion: Any, target: Any) -> bool:
|
@@ -529,37 +550,40 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
529
550
|
"""
|
530
551
|
return conclusion.__class__ == target.__class__ and target.__class__ != CaseAttribute
|
531
552
|
|
532
|
-
def conclusion_is_correct(self,
|
553
|
+
def conclusion_is_correct(self, case_query: CaseQuery,
|
554
|
+
expert: Expert, evaluated_rule: Rule,
|
533
555
|
add_extra_conclusions: bool) -> bool:
|
534
556
|
"""
|
535
557
|
Ask the expert if the conclusion is correct, and add it to the conclusions if it is.
|
536
558
|
|
537
|
-
:param
|
538
|
-
:param target: The target category to compare the case with.
|
559
|
+
:param case_query: The case query to ask the expert about.
|
539
560
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
540
561
|
:param evaluated_rule: The evaluated rule to ask the expert about.
|
541
562
|
:param add_extra_conclusions: Whether adding extra conclusions after classification is allowed.
|
542
563
|
:return: Whether the conclusion is correct or not.
|
543
564
|
"""
|
544
|
-
conclusions =
|
545
|
-
if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case, evaluated_rule.conclusion,
|
546
|
-
targets=target,
|
565
|
+
conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
|
566
|
+
if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case, evaluated_rule.conclusion,
|
567
|
+
targets=case_query.target,
|
547
568
|
current_conclusions=conclusions)):
|
548
569
|
self.add_conclusion(evaluated_rule)
|
549
570
|
self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
|
550
571
|
return True
|
551
572
|
return False
|
552
573
|
|
553
|
-
def add_rule_for_case(self,
|
574
|
+
def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
|
554
575
|
"""
|
555
576
|
Add a rule for a case that has not been classified with any conclusion.
|
577
|
+
|
578
|
+
:param case_query: The case query to add the rule for.
|
579
|
+
:param expert: The expert to ask for differentiating features as new rule conditions.
|
556
580
|
"""
|
557
581
|
if self.stop_rule_conditions and self.mode == MCRDRMode.StopPlusRule:
|
558
582
|
conditions = self.stop_rule_conditions
|
559
583
|
self.stop_rule_conditions = None
|
560
584
|
else:
|
561
|
-
conditions = expert.ask_for_conditions(
|
562
|
-
self.add_top_rule(conditions, target, case)
|
585
|
+
conditions = expert.ask_for_conditions(case_query)
|
586
|
+
self.add_top_rule(conditions, case_query.target, case_query.case)
|
563
587
|
|
564
588
|
def ask_expert_for_extra_conclusions(self, expert: Expert, case: Union[Case, SQLTable]) -> List[Any]:
|
565
589
|
"""
|
@@ -633,20 +657,31 @@ class GeneralRDR(RippleDownRules):
|
|
633
657
|
gets called when the final rule fires.
|
634
658
|
"""
|
635
659
|
|
636
|
-
def __init__(self, category_rdr_map: Optional[Dict[
|
660
|
+
def __init__(self, category_rdr_map: Optional[Dict[str, Union[SingleClassRDR, MultiClassRDR]]] = None):
|
637
661
|
"""
|
638
|
-
:param category_rdr_map: A map of
|
662
|
+
:param category_rdr_map: A map of case attribute names to ripple down rules classifiers,
|
639
663
|
where each category is a parent category that has a set of mutually exclusive (in case of SCRDR) child
|
640
|
-
categories, e.g. {
|
641
|
-
and MCRDR are SingleClass and MultiClass ripple down rules classifiers.
|
642
|
-
Mammal, Bird, Fish, etc. which are mutually exclusive,
|
643
|
-
Land, Water, Air, etc
|
664
|
+
categories, e.g. {'species': SCRDR, 'habitats': MCRDR}, where 'species' and 'habitats' are attribute names
|
665
|
+
for a case of type Animal, while SCRDR and MCRDR are SingleClass and MultiClass ripple down rules classifiers.
|
666
|
+
Species can have values like Mammal, Bird, Fish, etc. which are mutually exclusive, while Habitat can have
|
667
|
+
values like Land, Water, Air, etc., which are not mutually exclusive due to some animals living more than one
|
668
|
+
habitat.
|
644
669
|
"""
|
645
|
-
self.start_rules_dict: Dict[
|
670
|
+
self.start_rules_dict: Dict[str, Union[SingleClassRDR, MultiClassRDR]] \
|
646
671
|
= category_rdr_map if category_rdr_map else {}
|
647
672
|
super(GeneralRDR, self).__init__()
|
648
673
|
self.all_figs: List[plt.Figure] = [sr.fig for sr in self.start_rules_dict.values()]
|
649
674
|
|
675
|
+
def add_rdr(self, rdr: Union[SingleClassRDR, MultiClassRDR], attribute_name: Optional[str] = None):
|
676
|
+
"""
|
677
|
+
Add a ripple down rules classifier to the map of classifiers.
|
678
|
+
|
679
|
+
:param rdr: The ripple down rules classifier to add.
|
680
|
+
:param attribute_name: The name of the attribute that the classifier is classifying.
|
681
|
+
"""
|
682
|
+
attribute_name = attribute_name if attribute_name else rdr.attribute_name
|
683
|
+
self.start_rules_dict[attribute_name] = rdr
|
684
|
+
|
650
685
|
@property
|
651
686
|
def start_rule(self) -> Optional[Union[SingleClassRule, MultiClassTopRule]]:
|
652
687
|
return self.start_rules[0] if self.start_rules_dict else None
|
@@ -654,7 +689,7 @@ class GeneralRDR(RippleDownRules):
|
|
654
689
|
@start_rule.setter
|
655
690
|
def start_rule(self, value: Union[SingleClassRDR, MultiClassRDR]):
|
656
691
|
if value:
|
657
|
-
self.start_rules_dict[
|
692
|
+
self.start_rules_dict[value.attribute_name] = value
|
658
693
|
|
659
694
|
@property
|
660
695
|
def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
|
@@ -668,21 +703,44 @@ class GeneralRDR(RippleDownRules):
|
|
668
703
|
:param case: The case to classify.
|
669
704
|
:return: The categories that the case belongs to.
|
670
705
|
"""
|
671
|
-
|
706
|
+
return self._classify(self.start_rules_dict, case)
|
707
|
+
|
708
|
+
@staticmethod
|
709
|
+
def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
|
710
|
+
case: Union[Case, SQLTable]) -> Optional[Dict[str, Any]]:
|
711
|
+
"""
|
712
|
+
Classify a case by going through all classifiers and adding the categories that are classified,
|
713
|
+
and then restarting the classification until no more categories can be added.
|
714
|
+
|
715
|
+
:param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
|
716
|
+
:param case: The case to classify.
|
717
|
+
:return: The categories that the case belongs to.
|
718
|
+
"""
|
719
|
+
conclusions = {}
|
672
720
|
case_cp = copy_case(case)
|
673
721
|
while True:
|
674
|
-
|
675
|
-
for
|
676
|
-
if self.case_has_conclusion(case_cp, cat_type):
|
677
|
-
continue
|
722
|
+
new_conclusions = {}
|
723
|
+
for attribute_name, rdr in classifiers_dict.items():
|
678
724
|
pred_atts = rdr.classify(case_cp)
|
679
|
-
if pred_atts:
|
725
|
+
if pred_atts is None:
|
726
|
+
continue
|
727
|
+
if isinstance(rdr, SingleClassRDR):
|
728
|
+
if attribute_name not in conclusions or \
|
729
|
+
(attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
|
730
|
+
conclusions[attribute_name] = pred_atts
|
731
|
+
new_conclusions[attribute_name] = pred_atts
|
732
|
+
else:
|
680
733
|
pred_atts = make_list(pred_atts)
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
734
|
+
if attribute_name in conclusions:
|
735
|
+
pred_atts = [p for p in pred_atts if p not in conclusions[attribute_name]]
|
736
|
+
if len(pred_atts) > 0:
|
737
|
+
new_conclusions[attribute_name] = pred_atts
|
738
|
+
if attribute_name not in conclusions:
|
739
|
+
conclusions[attribute_name] = []
|
740
|
+
conclusions[attribute_name].extend(pred_atts)
|
741
|
+
if attribute_name in new_conclusions:
|
742
|
+
GeneralRDR.update_case(case_cp, new_conclusions)
|
743
|
+
if len(new_conclusions) == 0:
|
686
744
|
break
|
687
745
|
return conclusions
|
688
746
|
|
@@ -707,103 +765,79 @@ class GeneralRDR(RippleDownRules):
|
|
707
765
|
case = case_queries[0].case
|
708
766
|
assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
|
709
767
|
" for multiple cases use fit instead")
|
710
|
-
|
711
|
-
case_cp = case_query_cp.case
|
768
|
+
case_cp = copy(case_queries[0]).case
|
712
769
|
for case_query in case_queries:
|
713
|
-
|
714
|
-
|
770
|
+
case_query_cp = copy(case_query)
|
771
|
+
case_query_cp.case = case_cp
|
772
|
+
if case_query.target is None:
|
773
|
+
conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
|
715
774
|
target = expert.ask_for_conclusion(case_query)
|
716
|
-
|
717
|
-
if
|
718
|
-
target_type = type(make_list(target)[0])
|
719
|
-
assert all([type(t) is target_type for t in target]), ("All targets of a case query must be of the same"
|
720
|
-
" type")
|
721
|
-
else:
|
722
|
-
target_type = type(target)
|
723
|
-
if target_type not in self.start_rules_dict:
|
775
|
+
|
776
|
+
if case_query.attribute_name not in self.start_rules_dict:
|
724
777
|
conclusions = self.classify(case)
|
725
778
|
self.update_case(case_cp, conclusions)
|
726
|
-
|
779
|
+
|
780
|
+
new_rdr = self.initialize_new_rdr_for_attribute(case_query.attribute_name, case_cp, case_query.target)
|
781
|
+
self.add_rdr(new_rdr, case_query.attribute_name)
|
782
|
+
|
727
783
|
new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
|
728
|
-
self.
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
if target_type is not rdr_type:
|
784
|
+
self.update_case(case_cp, {case_query.attribute_name: new_conclusions})
|
785
|
+
else:
|
786
|
+
for rdr_attribute_name, rdr in self.start_rules_dict.items():
|
787
|
+
if case_query.attribute_name != rdr_attribute_name:
|
733
788
|
conclusions = rdr.classify(case_cp)
|
734
789
|
else:
|
735
|
-
conclusions = self.start_rules_dict[
|
736
|
-
|
737
|
-
|
790
|
+
conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
|
791
|
+
**kwargs)
|
792
|
+
if conclusions is not None or (is_iterable(conclusions) and len(conclusions) > 0):
|
793
|
+
conclusions = {rdr_attribute_name: conclusions}
|
794
|
+
self.update_case(case_cp, conclusions)
|
738
795
|
|
739
796
|
return self.classify(case)
|
740
797
|
|
741
798
|
@staticmethod
|
742
|
-
def initialize_new_rdr_for_attribute(
|
799
|
+
def initialize_new_rdr_for_attribute(attribute_name: str, case: Union[Case, SQLTable], target: Any):
|
743
800
|
"""
|
744
801
|
Initialize the appropriate RDR type for the target.
|
745
802
|
"""
|
746
|
-
if
|
747
|
-
|
748
|
-
if hasattr(prop, "__iter__") and not isinstance(prop, str):
|
749
|
-
return MultiClassRDR()
|
750
|
-
else:
|
751
|
-
return SingleClassRDR()
|
752
|
-
elif isinstance(attribute, CaseAttribute):
|
803
|
+
attribute = getattr(case, attribute_name) if hasattr(case, attribute_name) else target
|
804
|
+
if isinstance(attribute, CaseAttribute):
|
753
805
|
return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
|
754
806
|
else:
|
755
|
-
return MultiClassRDR() if is_iterable(attribute) else SingleClassRDR()
|
807
|
+
return MultiClassRDR() if is_iterable(attribute) or (attribute is None) else SingleClassRDR()
|
756
808
|
|
757
809
|
@staticmethod
|
758
|
-
def update_case(case: Union[Case, SQLTable],
|
759
|
-
conclusions: List[Any], attribute_type: Optional[Any] = None):
|
810
|
+
def update_case(case: Union[Case, SQLTable], conclusions: Dict[str, Any]):
|
760
811
|
"""
|
761
812
|
Update the case with the conclusions.
|
762
813
|
|
763
814
|
:param case: The case to update.
|
764
815
|
:param conclusions: The conclusions to update the case with.
|
765
|
-
:param attribute_type: The type of the attribute to update.
|
766
816
|
"""
|
767
817
|
if not conclusions:
|
768
818
|
return
|
769
|
-
conclusions = [conclusions] if not isinstance(conclusions, list) else list(conclusions)
|
770
819
|
if len(conclusions) == 0:
|
771
820
|
return
|
772
821
|
if isinstance(case, SQLTable):
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
822
|
+
for conclusion_name, conclusion in conclusions.items():
|
823
|
+
hint, origin, args = get_hint_for_attribute(conclusion_name, case)
|
824
|
+
attribute = getattr(case, conclusion_name)
|
825
|
+
if isinstance(attribute, set) or origin in {Set, set}:
|
826
|
+
attribute = set() if attribute is None else attribute
|
827
|
+
for c in conclusion:
|
828
|
+
attribute.update(make_set(c))
|
829
|
+
elif isinstance(attribute, list) or origin in {list, List}:
|
830
|
+
attribute = [] if attribute is None else attribute
|
831
|
+
attribute.extend(conclusion)
|
832
|
+
elif (not is_iterable(conclusion) or (len(conclusion) == 1)) and hint == type(conclusion):
|
833
|
+
setattr(case, conclusion_name, conclusion)
|
834
|
+
else:
|
835
|
+
raise ValueError(f"Cannot add multiple conclusions to attribute {conclusion_name}")
|
787
836
|
else:
|
788
|
-
|
789
|
-
case.update(c.as_dict)
|
790
|
-
|
791
|
-
@property
|
792
|
-
def names_of_all_types(self) -> List[str]:
|
793
|
-
"""
|
794
|
-
Get the names of all the types of categories that the GRDR can classify.
|
795
|
-
"""
|
796
|
-
return [t.__name__ for t in self.start_rules_dict.keys()]
|
797
|
-
|
798
|
-
@property
|
799
|
-
def all_types(self) -> List[Type]:
|
800
|
-
"""
|
801
|
-
Get all the types of categories that the GRDR can classify.
|
802
|
-
"""
|
803
|
-
return list(self.start_rules_dict.keys())
|
837
|
+
case.update(conclusions)
|
804
838
|
|
805
839
|
def _to_json(self) -> Dict[str, Any]:
|
806
|
-
return {"start_rules": {
|
840
|
+
return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
|
807
841
|
|
808
842
|
@classmethod
|
809
843
|
def _from_json(cls, data: Dict[str, Any]) -> GeneralRDR:
|
@@ -812,6 +846,78 @@ class GeneralRDR(RippleDownRules):
|
|
812
846
|
"""
|
813
847
|
start_rules_dict = {}
|
814
848
|
for k, v in data["start_rules"].items():
|
815
|
-
k = get_type_from_string(k)
|
816
849
|
start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
|
817
850
|
return cls(start_rules_dict)
|
851
|
+
|
852
|
+
def write_to_python_file(self, file_path: str):
|
853
|
+
"""
|
854
|
+
Write the tree of rules as source code to a file.
|
855
|
+
|
856
|
+
:param file_path: The path to the file to write the source code to.
|
857
|
+
"""
|
858
|
+
for rdr in self.start_rules_dict.values():
|
859
|
+
rdr.write_to_python_file(file_path)
|
860
|
+
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
861
|
+
with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
|
862
|
+
f.write(self._get_imports(file_path) + "\n\n")
|
863
|
+
f.write("classifiers_dict = dict()\n")
|
864
|
+
for rdr_key, rdr in self.start_rules_dict.items():
|
865
|
+
f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
|
866
|
+
f.write("\n\n")
|
867
|
+
f.write(func_def)
|
868
|
+
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
869
|
+
f"{' ' * 4} case = create_case(case, recursion_idx=3)\n""")
|
870
|
+
f.write(f"{' ' * 4}return GeneralRDR._classify(classifiers_dict, case)\n")
|
871
|
+
|
872
|
+
@property
|
873
|
+
def case_type(self) -> Type:
|
874
|
+
"""
|
875
|
+
:return: The type of the case (input) to the RDR classifier.
|
876
|
+
"""
|
877
|
+
if isinstance(self.start_rule.corner_case, Case):
|
878
|
+
return self.start_rule.corner_case._obj_type
|
879
|
+
else:
|
880
|
+
return type(self.start_rule.corner_case)
|
881
|
+
|
882
|
+
def get_rdr_classifier_from_python_file(self, file_path: str):
|
883
|
+
"""
|
884
|
+
:param file_path: The path to the file that contains the RDR classifier function.
|
885
|
+
:return: The module that contains the rdr classifier function.
|
886
|
+
"""
|
887
|
+
return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
|
888
|
+
|
889
|
+
@property
|
890
|
+
def generated_python_file_name(self) -> str:
|
891
|
+
return f"{self.case_type.__name__.lower()}_grdr"
|
892
|
+
|
893
|
+
@property
|
894
|
+
def conclusion_type_hint(self) -> str:
|
895
|
+
return f"List[Union[{', '.join([rdr.conclusion_type_hint for rdr in self.start_rules_dict.values()])}]]"
|
896
|
+
|
897
|
+
def _get_imports(self, file_path: str) -> str:
|
898
|
+
imports = ""
|
899
|
+
# add type hints
|
900
|
+
imports += f"from typing_extensions import List, Union, Set\n"
|
901
|
+
# import rdr type
|
902
|
+
imports += f"from ripple_down_rules.rdr import GeneralRDR\n"
|
903
|
+
# add case type
|
904
|
+
imports += f"from ripple_down_rules.datastructures import Case, create_case\n"
|
905
|
+
imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
|
906
|
+
# add conclusion type imports
|
907
|
+
for rdr in self.start_rules_dict.values():
|
908
|
+
imports += f"from {rdr.conclusion_type.__module__} import {rdr.conclusion_type.__name__}\n"
|
909
|
+
# add rdr python generated functions.
|
910
|
+
for rdr_key, rdr in self.start_rules_dict.items():
|
911
|
+
imports += (f"from {file_path.strip('./')}"
|
912
|
+
f" import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}\n")
|
913
|
+
return imports
|
914
|
+
|
915
|
+
@staticmethod
|
916
|
+
def rdr_key_to_function_name(rdr_key: str) -> str:
|
917
|
+
"""
|
918
|
+
Convert the RDR key to a function name.
|
919
|
+
|
920
|
+
:param rdr_key: The RDR key to convert.
|
921
|
+
:return: The function name.
|
922
|
+
"""
|
923
|
+
return rdr_key.replace(".", "_").lower() + "_classifier"
|