ripple-down-rules 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +0 -0
- ripple_down_rules/datasets.py +148 -0
- ripple_down_rules/datastructures/__init__.py +4 -0
- ripple_down_rules/datastructures/callable_expression.py +237 -0
- ripple_down_rules/datastructures/dataclasses.py +76 -0
- ripple_down_rules/datastructures/enums.py +173 -0
- ripple_down_rules/datastructures/generated/__init__.py +0 -0
- ripple_down_rules/datastructures/generated/column/__init__.py +0 -0
- ripple_down_rules/datastructures/generated/row/__init__.py +0 -0
- ripple_down_rules/datastructures/table.py +544 -0
- ripple_down_rules/experts.py +281 -0
- ripple_down_rules/failures.py +10 -0
- ripple_down_rules/prompt.py +101 -0
- ripple_down_rules/rdr.py +687 -0
- ripple_down_rules/rules.py +260 -0
- ripple_down_rules/utils.py +463 -0
- ripple_down_rules-0.0.0.dist-info/METADATA +54 -0
- ripple_down_rules-0.0.0.dist-info/RECORD +20 -0
- ripple_down_rules-0.0.0.dist-info/WHEEL +5 -0
- ripple_down_rules-0.0.0.dist-info/top_level.txt +1 -0
ripple_down_rules/rdr.py
ADDED
@@ -0,0 +1,687 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from copy import copy
|
5
|
+
|
6
|
+
from matplotlib import pyplot as plt
|
7
|
+
from ordered_set import OrderedSet
|
8
|
+
from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
|
9
|
+
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self
|
10
|
+
|
11
|
+
from .datastructures import Case, MCRDRMode, CallableExpression, Column, CaseQuery
|
12
|
+
from .experts import Expert, Human
|
13
|
+
from .rules import Rule, SingleClassRule, MultiClassTopRule
|
14
|
+
from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
|
15
|
+
get_hint_for_attribute, SubclassJSONSerializer
|
16
|
+
|
17
|
+
|
18
|
+
class RippleDownRules(ABC):
|
19
|
+
"""
|
20
|
+
The abstract base class for the ripple down rules classifiers.
|
21
|
+
"""
|
22
|
+
fig: Optional[plt.Figure] = None
|
23
|
+
"""
|
24
|
+
The figure to draw the tree on.
|
25
|
+
"""
|
26
|
+
expert_accepted_conclusions: Optional[List[Column]] = None
|
27
|
+
"""
|
28
|
+
The conclusions that the expert has accepted, such that they are not asked again.
|
29
|
+
"""
|
30
|
+
|
31
|
+
def __init__(self, start_rule: Optional[Rule] = None, session: Optional[Session] = None):
    """
    Store the root rule and the database session, and start without a figure.

    :param start_rule: The rule at which classification starts, if one already exists.
    :param session: The sqlalchemy orm session, if any.
    """
    self.start_rule = start_rule
    self.session = session
    # The figure is only created once the tree is drawn for the first time.
    self.fig: Optional[plt.Figure] = None
|
39
|
+
|
40
|
+
def __call__(self, case: Union[Case, SQLTable]) -> Column:
    """
    Make the classifier callable; delegates to :meth:`classify`.

    :param case: The case to classify.
    :return: Whatever :meth:`classify` returns for the case.
    """
    result = self.classify(case)
    return result
|
42
|
+
|
43
|
+
@abstractmethod
def classify(self, case: Union[Case, SQLTable]) -> Optional[Column]:
    """
    Classify a case. Concrete classifiers must override this method.

    :param case: The case to classify.
    :return: The category that the case belongs to.
    """
|
52
|
+
|
53
|
+
@abstractmethod
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None,
             **kwargs) -> Union[Column, CallableExpression]:
    """
    Fit the RDR on a single case. Concrete classifiers must override this method.

    The classification is compared with the target category and, when it is incorrect, the expert is asked for
    refinements or alternatives.

    :param case_query: The query containing the case to classify and the target category to compare the case with.
    :param expert: The expert to ask for differentiating features as new rule conditions.
    :return: The category that the case belongs to.
    """
|
65
|
+
|
66
|
+
def fit(self, case_queries: List[CaseQuery],
        expert: Optional[Expert] = None,
        n_iter: Optional[int] = None,
        animate_tree: bool = False,
        **kwargs_for_fit_case):
    """
    Fit the classifier to a batch of cases and categories.

    The whole batch is fitted repeatedly until every case is predicted correctly in one pass, or until
    ``n_iter`` passes have been made (when ``n_iter`` is given).

    :param case_queries: The cases and categories to fit the classifier to.
    :param expert: The expert to ask for differentiating features as new rule conditions. When omitted, a
     ``Human`` expert bound to this classifier's session is used.
    :param n_iter: The maximum number of passes over the batch, or None for no limit.
    :param animate_tree: Whether to draw the tree while fitting the classifier.
    :param kwargs_for_fit_case: The keyword arguments to pass to the fit_case method.
    """
    if not case_queries:
        # Nothing to learn from. Without this guard the loop below never reaches a stop condition when n_iter
        # is not given, and the recall/precision averages would divide by zero.
        return
    if expert is None:
        # A missing target is resolved by asking the expert below, so an expert must exist here.
        expert = Human(session=self.session)
    if animate_tree:
        plt.ion()
    i = 0
    num_rules: int = 0
    while True:
        all_pred = 0
        all_recall = []
        all_precision = []
        for case_query in case_queries:
            case = case_query.case
            target = case_query.target
            if not target:
                conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
                target = expert.ask_for_conclusion(case_query, conclusions)
            pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
            pred_cat = pred_cat if isinstance(pred_cat, list) else [pred_cat]
            target = target if isinstance(target, list) else [target]
            # A target counts as recalled when it is empty or is among the predictions.
            recall = [not yi or (yi in pred_cat) for yi in target]
            y_type = [type(yi) for yi in target]
            # Predictions of a type that has no target are not counted against precision.
            precision = [(pred in target) or (type(pred) not in y_type) for pred in pred_cat]
            match = all(recall) and all(precision)
            all_recall.extend(recall)
            all_precision.extend(precision)
            if not match:
                print(f"Predicted: {pred_cat} but expected: {target}")
            all_pred += int(match)
            if animate_tree and self.start_rule.size > num_rules:
                num_rules = self.start_rule.size
                self.update_figures()
        i += 1
        all_predicted = all_pred == len(case_queries)
        num_iter_reached = bool(n_iter) and i >= n_iter
        if all_predicted or num_iter_reached:
            break
    print(f"Recall: {sum(all_recall) / len(all_recall)}")
    print(f"Precision: {sum(all_precision) / len(all_precision)}")
    # Accuracy is the share of correctly predicted cases in the last pass, not a share of iterations.
    print(f"Accuracy: {all_pred}/{len(case_queries)}")
    print(f"Finished training in {i} iterations")
    if animate_tree:
        plt.ioff()
        plt.show()
|
127
|
+
|
128
|
+
def update_figures(self):
    """
    Redraw the rule tree(s) of the classifier.

    A GeneralRDR draws one figure per category RDR it holds; any other classifier draws its own single tree.
    Figures are created lazily the first time they are needed.
    """
    if not isinstance(self, GeneralRDR):
        if not self.fig:
            self.fig = plt.figure(0)
        draw_tree(self.start_rule, self.fig)
        return
    for index, (category_type, category_rdr) in enumerate(self.start_rules_dict.items()):
        if not category_rdr.fig:
            category_rdr.fig = plt.figure(f"Rule {index}: {category_type.__name__}")
        draw_tree(category_rdr.start_rule, category_rdr.fig)
|
141
|
+
|
142
|
+
@staticmethod
def case_has_conclusion(case: Union[Case, SQLTable], conclusion_type: Type) -> bool:
    """
    Check if the case has a conclusion.

    :param case: The case to check.
    :param conclusion_type: The target category type to compare the case with.
    :return: Whether the case has a conclusion or not.
    """
    if not isinstance(case, SQLTable):
        # Non-SQL cases answer membership tests for the conclusion type directly.
        return conclusion_type in case
    _, prop_value = get_attribute_by_type(case, conclusion_type)
    is_collection = hasattr(prop_value, "__iter__") and not isinstance(prop_value, str)
    if is_collection:
        # A collection attribute only counts as a conclusion when it is not empty.
        return len(prop_value) > 0
    return prop_value is not None
|
159
|
+
|
160
|
+
|
161
|
+
RDR = RippleDownRules
|
162
|
+
|
163
|
+
|
164
|
+
class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):
|
165
|
+
|
166
|
+
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None,
             **kwargs) -> Union[Column, CallableExpression]:
    """
    Classify a case and, when the prediction differs from the target, ask the expert for the conditions of a new
    rule that corrects it.

    :param case_query: The case to classify and the target category to compare the case with.
    :param expert: The expert to ask for differentiating features as new rule conditions.
    :return: The category that the case belongs to.
    """
    if not expert:
        expert = Human(session=self.session)
    # The attribute is read together with the case as before, although it is not used further in this method.
    case, attribute = case_query.case, case_query.attribute
    target = case_query.target
    if target is None:
        target = expert.ask_for_conclusion(case_query)

    if not self.start_rule:
        # First case ever seen: it becomes the corner case of the root rule.
        root_conditions = expert.ask_for_conditions(case, [target])
        self.start_rule = SingleClassRule(root_conditions, target, corner_case=case)

    matched = self.evaluate(case)
    if matched.conclusion != target:
        fix_conditions = expert.ask_for_conditions(case, [target], matched)
        matched.fit_rule(case, target, conditions=fix_conditions)

    return self.classify(case)
|
194
|
+
|
195
|
+
def classify(self, case: Case) -> Optional[Column]:
    """
    Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.

    :param case: The case to classify.
    :return: The conclusion of the matched rule if it fired, otherwise None. None is also returned when the
     classifier has no start rule yet, since there is nothing to evaluate.
    """
    if self.start_rule is None:
        # An unfitted classifier cannot be evaluated; report "no conclusion" instead of failing.
        return None
    pred = self.evaluate(case)
    return pred.conclusion if pred.fired else None
|
201
|
+
|
202
|
+
def evaluate(self, case: Case) -> SingleClassRule:
    """
    Evaluate the starting rule on a case.

    :param case: The case to evaluate.
    :return: The rule returned by calling the start rule on the case, or the start rule itself when that call
     returns nothing.
    """
    return self.start_rule(case) or self.start_rule
|
208
|
+
|
209
|
+
def write_to_python_file(self, filename: str):
    """
    Write the tree of rules as source code to a file.

    The generated file contains the imports needed for the case type and conclusion type(s), followed by a
    single ``classify_<conclusion>`` function whose body is produced from the rule tree.

    :param filename: The path of the python file to write.
    """
    corner_case_class = self.start_rule.corner_case.__class__
    case_type = corner_case_class.__name__
    case_module = corner_case_class.__module__
    conclusion = self.start_rule.conclusion
    if isinstance(conclusion, CallableExpression):
        conclusion_types = [conclusion.conclusion_type]
    elif isinstance(conclusion, Column):
        conclusion_types = list(conclusion._value_range)
    else:
        conclusion_types = [type(conclusion)]
    type_names = [c.__name__ for c in conclusion_types]

    import_lines = []
    if case_module != "builtins":
        import_lines.append(f"from {case_module} import {case_type}\n")
    if len(type_names) > 1:
        conclusion_name = "Union[" + ", ".join(type_names) + "]"
        # The annotation uses Union, so the generated module must import it.
        import_lines.append("from typing import Union\n")
        # "Union[A, B]" is not a valid identifier, so the function name is built from the plain type names.
        function_suffix = "_or_".join(name.lower() for name in type_names)
    else:
        conclusion_name = type_names[0]
        function_suffix = conclusion_name.lower()
    for conclusion_type in conclusion_types:
        if conclusion_type.__module__ != "builtins":
            # Import each type by its own name; the Union[...] spelling is only valid in the annotation.
            import_lines.append(f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n")
    imports = "".join(import_lines) + "\n\n"
    func_def = f"def classify_{function_suffix}(case: {case_type}) -> {conclusion_name}:\n"
    with open(filename, "w", encoding="utf-8") as f:
        f.write(imports)
        f.write(func_def)
        self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)
|
238
|
+
|
239
|
+
def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file, parent_indent: str = ""):
    """
    Write a rule and the rules below it as source code to a file.

    For a rule with conditions, the output order is: the rule's condition, its refinement (nested one level
    deeper), the rule's conclusion, and then its alternative at the same level. A rule without conditions
    writes nothing.

    :param rule: The rule to write.
    :param file: The writable text file object to write to.
    :param parent_indent: The indentation prefix of the enclosing block.
    """
    if not rule.conditions:
        return
    file.write(rule.write_condition_as_source_code(parent_indent))
    if rule.refinement:
        # Nest by one full indentation level (4 spaces), the same unit write_to_python_file starts with.
        self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " " * 4)

    file.write(rule.write_conclusion_as_source_code(parent_indent))

    if rule.alternative:
        self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
|
252
|
+
|
253
|
+
def to_json(self) -> Dict[str, Any]:
    """
    Serialize the classifier to a json-compatible dictionary.

    :return: The entries produced by ``SubclassJSONSerializer.to_json`` plus the serialized start rule under
     the ``"start_rule"`` key.
    """
    serialized: Dict[str, Any] = {}
    serialized.update(SubclassJSONSerializer.to_json(self))
    serialized["start_rule"] = self.start_rule.to_json()
    return serialized
|
255
|
+
|
256
|
+
@classmethod
def _from_json(cls, data: Dict[str, Any]) -> Self:
    """
    Create an instance of the class from a json dictionary.

    :param data: The dictionary holding the serialized start rule under the ``"start_rule"`` key.
    :return: A new classifier whose start rule is the deserialized rule.
    """
    restored_rule = SingleClassRule.from_json(data["start_rule"])
    return cls(restored_rule)
|
263
|
+
|
264
|
+
|
265
|
+
class MultiClassRDR(RippleDownRules):
|
266
|
+
"""
|
267
|
+
A multi class ripple down rules classifier, which can draw multiple conclusions for a case.
|
268
|
+
This is done by going through all rules and checking if they fire or not, and adding stopping rules if needed,
|
269
|
+
when wrong conclusions are made to stop these rules from firing again for similar cases.
|
270
|
+
"""
|
271
|
+
evaluated_rules: Optional[List[Rule]] = None
|
272
|
+
"""
|
273
|
+
The evaluated rules in the classifier for one case.
|
274
|
+
"""
|
275
|
+
conclusions: Optional[List[Column]] = None
|
276
|
+
"""
|
277
|
+
The conclusions that the case belongs to.
|
278
|
+
"""
|
279
|
+
stop_rule_conditions: Optional[CallableExpression] = None
|
280
|
+
"""
|
281
|
+
The conditions of the stopping rule if needed.
|
282
|
+
"""
|
283
|
+
|
284
|
+
def __init__(self, start_rules: Optional[List[Rule]] = None,
             mode: MCRDRMode = MCRDRMode.StopOnly, session: Optional[Session] = None):
    """
    :param start_rules: The starting rules for the classifier, these are the rules that are at the top of the tree
     and are always checked, in contrast to the refinement and alternative rules which are only checked if the
     starting rules fire or not. When omitted, a single empty top rule is created.
    :param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
    :param session: The sqlalchemy orm session.
    """
    if not start_rules:
        start_rules = [MultiClassTopRule()]
    self.start_rules = start_rules
    # The first top rule is the entry point of the base classifier.
    super(MultiClassRDR, self).__init__(self.start_rules[0], session=session)
    self.mode: MCRDRMode = mode
|
296
|
+
|
297
|
+
def classify(self, case: Union[Case, SQLTable]) -> List[Any]:
    """
    Classify a case by walking the chain of rules and collecting the conclusion of every rule that fires.

    :param case: The case to classify.
    :return: The conclusions collected for the case.
    """
    current_rule = self.start_rule
    self.conclusions = []
    while current_rule:
        # Calling a rule evaluates it on the case and yields the rule to visit next.
        following_rule = current_rule(case)
        if current_rule.fired:
            self.add_conclusion(current_rule)
        current_rule = following_rule
    return self.conclusions
|
306
|
+
|
307
|
+
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None,
             add_extra_conclusions: bool = False) -> List[Union[Column, CallableExpression]]:
    """
    Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
    or missing by comparing the case with the target category if provided.

    :param case_query: The query containing the case to classify and the target category to compare the case with.
    :param expert: The expert to ask for differentiating features as new rule conditions or for extra conclusions.
    :param add_extra_conclusions: Whether to add extra conclusions after classification is done.
    :return: The conclusions that the case belongs to.
    """
    if not expert:
        expert = Human(session=self.session)
    case = case_query.case
    if case_query.target is None:
        targets = [expert.ask_for_conclusion(case_query)]
    else:
        targets = [case_query.target]
    self.expert_accepted_conclusions = []
    user_conclusions = []
    for target in targets:
        self.update_start_rule(case, target, expert)
        self.conclusions = []
        self.stop_rule_conditions = None
        current_rule = self.start_rule
        while current_rule:
            upcoming_rule = current_rule(case)
            accepted = targets + user_conclusions + self.expert_accepted_conclusions

            if current_rule.fired:
                if target and current_rule.conclusion not in accepted:
                    # The rule fired with a conclusion that is neither a target nor accepted earlier.
                    self.stop_wrong_conclusion_else_add_it(case, target, expert, current_rule,
                                                           add_extra_conclusions)
                else:
                    # The rule fired with an acceptable conclusion, or there is no target to compare with.
                    self.add_conclusion(current_rule)

            if not upcoming_rule:
                if target not in self.conclusions:
                    # The end of the chain was reached without concluding the target: add a rule for it and
                    # walk all rules again so the new rule is evaluated too.
                    self.add_rule_for_case(case, target, expert)
                    upcoming_rule = self.start_rule
                elif add_extra_conclusions and not user_conclusions:
                    # No more conclusions can be made, ask the expert for extra conclusions if needed.
                    user_conclusions.extend(self.ask_expert_for_extra_conclusions(expert, case))
                    if user_conclusions:
                        upcoming_rule = self.last_top_rule
            current_rule = upcoming_rule
    return self.conclusions
|
358
|
+
|
359
|
+
def update_start_rule(self, case: Union[Case, SQLTable], target: Any, expert: Expert):
    """
    Fill in the starting rule of the classifier when it has no conditions yet.

    :param case: The case to classify; it becomes the corner case of the starting rule.
    :param target: The target category to compare the case with; it becomes the conclusion of the rule.
    :param expert: The expert to ask for differentiating features as new rule conditions.
    """
    root = self.start_rule
    if root.conditions:
        # The starting rule is already defined; nothing to do.
        return
    root.conditions = expert.ask_for_conditions(case, target)
    root.conclusion = target
    root.corner_case = case
|
372
|
+
|
373
|
+
@property
def last_top_rule(self) -> Optional[MultiClassTopRule]:
    """
    Get the last top rule in the tree.

    :return: The last element of the start rule's ``furthest_alternative`` when that is non-empty, otherwise
     the start rule itself.
    """
    furthest = self.start_rule.furthest_alternative
    return furthest[-1] if furthest else self.start_rule
|
382
|
+
|
383
|
+
def stop_wrong_conclusion_else_add_it(self, case: Union[Case, SQLTable], target: Any, expert: Expert,
                                      evaluated_rule: MultiClassTopRule,
                                      add_extra_conclusions: bool):
    """
    Stop a wrong conclusion by adding a stopping rule, unless the conclusion turns out to be acceptable.

    The conclusion is stopped when it has the same category type as the target and conflicts with it. Otherwise
    ``conclusion_is_correct`` decides: the conclusion is stopped only when that check does not accept it.

    :param case: The case to classify.
    :param target: The target category to compare the case with.
    :param expert: The expert to ask for differentiating features as new rule conditions.
    :param evaluated_rule: The rule that fired with the questioned conclusion.
    :param add_extra_conclusions: Whether adding extra conclusions after classification is allowed.
    """
    conclusion = evaluated_rule.conclusion
    clashes_with_target = (self.is_same_category_type(conclusion, target)
                           and self.is_conflicting_with_target(conclusion, target))
    # The correctness check is only consulted when the conclusion does not clash outright.
    if clashes_with_target or not self.conclusion_is_correct(case, target, expert, evaluated_rule,
                                                             add_extra_conclusions):
        self.stop_conclusion(case, target, expert, evaluated_rule)
|
394
|
+
|
395
|
+
def stop_conclusion(self, case: Union[Case, SQLTable], target: Any,
                    expert: Expert, evaluated_rule: MultiClassTopRule):
    """
    Stop a conclusion by adding a stopping rule.

    In StopPlusRule mode the expert's conditions are kept in ``stop_rule_conditions``. In StopPlusRuleCombined
    mode a new top rule is added for the target, using the expert's conditions combined with the conditions of
    the evaluated rule.

    :param case: The case to classify.
    :param target: The target category to compare the case with.
    :param expert: The expert to ask for differentiating features as new rule conditions.
    :param evaluated_rule: The evaluated rule to ask the expert about.
    """
    conditions = expert.ask_for_conditions(case, target, evaluated_rule)
    evaluated_rule.fit_rule(case, target, conditions=conditions)
    mode = self.mode
    if mode == MCRDRMode.StopPlusRule:
        self.stop_rule_conditions = conditions
    elif mode == MCRDRMode.StopPlusRuleCombined:
        combined_conditions = conditions.combine_with(evaluated_rule.conditions)
        self.add_top_rule(combined_conditions, target, case)
|
412
|
+
|
413
|
+
@staticmethod
|
414
|
+
def is_conflicting_with_target(conclusion: Any, target: Any) -> bool:
|
415
|
+
"""
|
416
|
+
Check if the conclusion is conflicting with the target category.
|
417
|
+
|
418
|
+
:param conclusion: The conclusion to check.
|
419
|
+
:param target: The target category to compare the conclusion with.
|
420
|
+
:return: Whether the conclusion is conflicting with the target category.
|
421
|
+
"""
|
422
|
+
if hasattr(conclusion, "mutually_exclusive") and conclusion.mutually_exclusive:
|
423
|
+
return True
|
424
|
+
else:
|
425
|
+
return not make_set(conclusion).issubset(make_set(target))
|
426
|
+
|
427
|
+
@staticmethod
|
428
|
+
def is_same_category_type(conclusion: Any, target: Any) -> bool:
|
429
|
+
"""
|
430
|
+
Check if the conclusion is of the same class as the target category.
|
431
|
+
|
432
|
+
:param conclusion: The conclusion to check.
|
433
|
+
:param target: The target category to compare the conclusion with.
|
434
|
+
:return: Whether the conclusion is of the same class as the target category but has a different value.
|
435
|
+
"""
|
436
|
+
return conclusion.__class__ == target.__class__ and target.__class__ != Column
|
437
|
+
|
438
|
+
def conclusion_is_correct(self, case: Union[Case, SQLTable], target: Any, expert: Expert, evaluated_rule: Rule,
                          add_extra_conclusions: bool) -> bool:
    """
    Ask the expert if the conclusion is correct, and add it to the conclusions if it is.

    The expert is only asked when extra conclusions are allowed.

    :param case: The case to classify.
    :param target: The target category to compare the case with.
    :param expert: The expert to ask whether the conclusion is correct.
    :param evaluated_rule: The evaluated rule to ask the expert about.
    :param add_extra_conclusions: Whether adding extra conclusions after classification is allowed.
    :return: Whether the conclusion is correct or not.
    """
    # Duplicates are removed while keeping order before the conclusions are shown to the expert.
    current = list(OrderedSet(self.conclusions))
    if not add_extra_conclusions:
        return False
    accepted = expert.ask_if_conclusion_is_correct(case, evaluated_rule.conclusion,
                                                   targets=target,
                                                   current_conclusions=current)
    if not accepted:
        return False
    self.add_conclusion(evaluated_rule)
    # Remember the acceptance so the expert is not asked about this conclusion again.
    self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
    return True
|
458
|
+
|
459
|
+
def add_rule_for_case(self, case: Union[Case, SQLTable], target: Any, expert: Expert):
    """
    Add a top rule for a case whose target was not among the conclusions.

    In StopPlusRule mode, conditions stored by a previous stopping rule are reused (and then cleared);
    otherwise the expert is asked for the conditions.

    :param case: The case to add the rule for; it becomes the corner case of the new rule.
    :param target: The conclusion of the new rule.
    :param expert: The expert to ask for the rule conditions when none are stored.
    """
    reuse_stored = self.stop_rule_conditions and self.mode == MCRDRMode.StopPlusRule
    if reuse_stored:
        conditions, self.stop_rule_conditions = self.stop_rule_conditions, None
    else:
        conditions = expert.ask_for_conditions(case, target)
    self.add_top_rule(conditions, target, case)
|
469
|
+
|
470
|
+
def ask_expert_for_extra_conclusions(self, expert: Expert, case: Union[Case, SQLTable]) -> List[Any]:
    """
    Ask the expert for extra conclusions when no more conclusions can be made.

    A top rule is added for every extra conclusion the expert provides.

    :param expert: The expert to ask for extra conclusions.
    :param case: The case to ask extra conclusions for.
    :return: The extra conclusions that the expert has provided.
    """
    known_conclusions = list(OrderedSet(self.conclusions))
    if not expert.use_loaded_answers:
        print("current conclusions:", known_conclusions)
    answers = expert.ask_for_extra_conclusions(case, known_conclusions)
    if not answers:
        return []
    added = []
    for conclusion, conditions in answers.items():
        self.add_top_rule(conditions, conclusion, case)
        added.append(conclusion)
    return added
|
488
|
+
|
489
|
+
def add_conclusion(self, evaluated_rule: Rule) -> None:
    """
    Add the conclusion of the evaluated rule to the list of conclusions.

    When no conclusion of the same type exists yet, the conclusion is appended as is. Otherwise it is merged
    with the existing conclusions of that type: those are removed and the distinct values of the merge are
    added back individually.

    :param evaluated_rule: The evaluated rule to add the conclusion of.
    """
    new_conclusion = evaluated_rule.conclusion
    new_type = type(new_conclusion)
    same_type_conclusions = [c for c in self.conclusions if type(c) is new_type]
    if not same_type_conclusions:
        self.conclusions.append(new_conclusion)
        return
    # Merge into a fresh set: updating the rule's own set in place would permanently change the rule's
    # conclusion every time it is merged with other conclusions.
    combined_conclusion = set(new_conclusion) if isinstance(new_conclusion, set) else {new_conclusion}
    for c in same_type_conclusions:
        combined_conclusion.update(c if isinstance(c, set) else make_set(c))
        self.conclusions.remove(c)
    self.conclusions.extend(combined_conclusion)
|
506
|
+
|
507
|
+
def add_top_rule(self, conditions: CallableExpression, conclusion: Any, corner_case: Union[Case, SQLTable]):
    """
    Add a top rule to the classifier by assigning it as the alternative of the start rule.

    :param conditions: The conditions of the rule.
    :param conclusion: The conclusion of the rule.
    :param corner_case: The corner case of the rule.
    """
    new_top_rule = MultiClassTopRule(conditions, conclusion, corner_case=corner_case)
    # NOTE(review): presumably the `alternative` setter places the rule at the end of the alternatives chain
    # rather than overwriting an existing alternative - confirm in rules.py.
    self.start_rule.alternative = new_top_rule
|
516
|
+
|
517
|
+
|
518
|
+
class GeneralRDR(RippleDownRules):
|
519
|
+
"""
|
520
|
+
A general ripple down rules classifier, which can draw multiple conclusions for a case, but each conclusion is part
|
521
|
+
of a set of mutually exclusive conclusions. Whenever a conclusion is made, the classification restarts from the
|
522
|
+
starting rule, and all the rules that belong to the class of the made conclusion are not checked again. This
|
523
|
+
continues until no more rules can be fired. In addition, previous conclusions can be used as conditions or input to
|
524
|
+
the next classification/cycle.
|
525
|
+
Another possible mode is to have rules that are considered final, when fired, inference will not be restarted,
|
526
|
+
and only a refinement can be made to the final rule, those can also be used in another SCRDR of their own that
|
527
|
+
gets called when the final rule fires.
|
528
|
+
"""
|
529
|
+
|
530
|
+
def __init__(self, category_rdr_map: Optional[Dict[Type, Union[SingleClassRDR, MultiClassRDR]]] = None):
    """
    :param category_rdr_map: A map of categories to ripple down rules classifiers,
     where each category is a parent category that has a set of mutually exclusive (in case of SCRDR) child
     categories, e.g. {Species: SCRDR, Habitat: MCRDR}, where Species and Habitat are parent categories and SCRDR
     and MCRDR are SingleClass and MultiClass ripple down rules classifiers. Species can have child categories like
     Mammal, Bird, Fish, etc. which are mutually exclusive, and Habitat can have child categories like
     Land, Water, Air, etc, which are not mutually exclusive due to some animals living in more than one habitat.
    """
    # The map must exist before the base initializer runs, because the base initializer assigns start_rule,
    # which is a property backed by this map.
    self.start_rules_dict: Dict[Type, Union[SingleClassRDR, MultiClassRDR]] = category_rdr_map or {}
    super(GeneralRDR, self).__init__()
    self.all_figs: List[plt.Figure] = [category_rdr.fig for category_rdr in self.start_rules_dict.values()]
|
543
|
+
|
544
|
+
@property
def start_rule(self) -> Optional[Union[SingleClassRule, MultiClassTopRule]]:
    """
    The start rule of the first category RDR, or None when there are no category RDRs.
    """
    if not self.start_rules_dict:
        return None
    return self.start_rules[0]

@start_rule.setter
def start_rule(self, value: Union[SingleClassRDR, MultiClassRDR]):
    """
    Register an RDR under the type of the conclusion of its start rule; falsy values are ignored.

    :param value: The RDR to register.
    """
    if not value:
        return
    conclusion_type = type(value.start_rule.conclusion)
    self.start_rules_dict[conclusion_type] = value
|
552
|
+
|
553
|
+
@property
def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
    """
    The start rules of all category RDRs, in the order of the category map.
    """
    rules = []
    for category_rdr in self.start_rules_dict.values():
        rules.append(category_rdr.start_rule)
    return rules
|
556
|
+
|
557
|
+
def classify(self, case: Union[Case, SQLTable]) -> Optional[List[Any]]:
    """
    Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
    the classification until no more categories can be added.

    The classification works on a copy of the case, which is updated with every new conclusion so that later
    RDRs can use earlier conclusions.

    :param case: The case to classify.
    :return: The categories that the case belongs to.
    """
    conclusions = []
    case_cp = copy_case(case)
    while True:
        added_attributes = False
        for cat_type, rdr in self.start_rules_dict.items():
            if self.case_has_conclusion(case_cp, cat_type):
                continue
            pred_atts = rdr.classify(case_cp)
            if not pred_atts:
                continue
            pred_atts = pred_atts if isinstance(pred_atts, list) else [pred_atts]
            new_atts = [p for p in pred_atts if p not in conclusions]
            if not new_atts:
                # Only already-known conclusions were produced; that is not progress, and counting it as
                # such could keep the outer loop running forever.
                continue
            added_attributes = True
            conclusions.extend(new_atts)
            self.update_case_with_same_type_conclusions(case_cp, new_atts)
        if not added_attributes:
            break
    return conclusions
|
582
|
+
|
583
|
+
def fit_case(self, case_queries: List[CaseQuery], expert: Optional[Expert] = None,
             **kwargs) -> List[Union[Column, CallableExpression]]:
    """
    Fit the GRDR on a case, if the target is a new type of category, a new RDR is created for it,
    else the existing RDR of that type will be fitted on the case, and then classification is done and all
    concluded categories are returned. If the category is mutually exclusive, an SCRDR is created, else an MCRDR.
    In case of SCRDR, multiple conclusions of the same type replace each other, in case of MCRDR, they are added if
    they are accepted by the expert, and the attribute of that category is represented in the case as a set of
    values.

    :param case_queries: The queries containing the case to classify and the target categories to compare the case
     with. A single query is accepted as well. All queries must refer to the same case object.
    :param expert: The expert to ask for differentiating features as new rule conditions.
    :return: The categories that the case belongs to.
    """
    if not expert:
        expert = Human()
    if not isinstance(case_queries, list):
        case_queries = [case_queries]
    assert len(case_queries) > 0, "No case queries provided"
    case = case_queries[0].case
    assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
                                                                           " for multiple cases use fit instead")
    case_query_cp = copy(case_queries[0])
    case_cp = case_query_cp.case
    for case_query in case_queries:
        target = case_query.target
        if not target:
            target = expert.ask_for_conclusion(case_query)
        case_query_cp = CaseQuery(case_cp, attribute_name=case_query.attribute_name, target=target)
        target_type = type(target)
        if target_type not in self.start_rules_dict:
            # New category type: bring the working case up to date, then create and fit a dedicated RDR.
            conclusions = self.classify(case)
            self.update_case_with_same_type_conclusions(case_cp, conclusions)
            new_rdr = self.initialize_new_rdr_for_attribute(target, case_cp)
            new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
            self.start_rules_dict[target_type] = new_rdr
            self.update_case_with_same_type_conclusions(case_cp, new_conclusions, target_type)
        elif not self.case_has_conclusion(case_cp, target_type):
            for rdr_type, rdr in self.start_rules_dict.items():
                if target_type is rdr_type:
                    # The RDR of the target's type is fitted; `rdr` is the entry stored under that type.
                    conclusions = rdr.fit_case(case_query_cp, expert, **kwargs)
                else:
                    # RDRs of other types are only asked to classify.
                    conclusions = rdr.classify(case_cp)
                self.update_case_with_same_type_conclusions(case_cp, conclusions, rdr_type)

    return self.classify(case)
|
628
|
+
|
629
|
+
@staticmethod
def initialize_new_rdr_for_attribute(attribute: Any, case: Union[Case, SQLTable]):
    """
    Initialize the appropriate RDR type for the target.

    For an SQL case, a MultiClassRDR is created when the case attribute of the target's type holds a non-string
    iterable, otherwise a SingleClassRDR. For other cases, the target's ``mutually_exclusive`` flag decides.

    :param attribute: The target value the new RDR is created for.
    :param case: The case that holds the attribute.
    :return: A new SingleClassRDR or MultiClassRDR.
    """
    if isinstance(case, SQLTable):
        # get_attribute_by_type yields a (name, value) pair; testing the pair itself for iterability would
        # always succeed and select a MultiClassRDR for every attribute.
        _, prop_value = get_attribute_by_type(case, type(attribute))
        if hasattr(prop_value, "__iter__") and not isinstance(prop_value, str):
            return MultiClassRDR()
        return SingleClassRDR()
    return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
|
642
|
+
|
643
|
+
@staticmethod
def update_case_with_same_type_conclusions(case: Union[Case, SQLTable],
                                           conclusions: List[Any], attribute_type: Optional[Any] = None):
    """
    Update the case with the conclusions.

    For an SQL case the attribute is found by type: set and list attributes are extended with the conclusions,
    and a scalar attribute is assigned the single conclusion. Other cases are updated through their ``update``
    method.

    :param case: The case to update.
    :param conclusions: The conclusions to update the case with; a single conclusion is accepted as well.
     The given list is not modified.
    :param attribute_type: The type of the attribute to update; defaults to the type of the first conclusion.
    :raises ValueError: If the conclusions cannot be stored in a scalar attribute of an SQL case.
    """
    if not conclusions:
        return
    conclusions = conclusions if isinstance(conclusions, list) else [conclusions]
    if not isinstance(case, SQLTable):
        case.update(*[c.as_dict for c in make_set(conclusions)])
        return
    conclusions_type = attribute_type if attribute_type else type(conclusions[0])
    attr_name, attribute = get_attribute_by_type(case, conclusions_type)
    hint, origin, _ = get_hint_for_attribute(attr_name, case)
    if isinstance(attribute, set) or origin == set:
        if attribute is None:
            # A fresh container must be attached to the case, otherwise the conclusions are lost.
            attribute = set()
            setattr(case, attr_name, attribute)
        attribute.update(*[make_set(c) for c in conclusions])
    elif isinstance(attribute, list) or origin == list:
        if attribute is None:
            attribute = []
            setattr(case, attr_name, attribute)
        attribute.extend(conclusions)
    elif len(conclusions) == 1 and hint == conclusions_type:
        # Read rather than pop: the list may be owned by the caller (e.g. an RDR's own conclusions).
        setattr(case, attr_name, conclusions[0])
    else:
        raise ValueError(f"Cannot add multiple conclusions to attribute {attr_name}")
|
674
|
+
|
675
|
+
@property
def names_of_all_types(self) -> List[str]:
    """
    Get the names of all the types of categories that the GRDR can classify.
    """
    return [category_type.__name__ for category_type in self.start_rules_dict]
|
681
|
+
|
682
|
+
@property
def all_types(self) -> List[Type]:
    """
    Get all the types of categories that the GRDR can classify.
    """
    category_types = [*self.start_rules_dict]
    return category_types
|