ripple-down-rules 0.1.64__py3-none-any.whl → 0.1.66__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datastructures/dataclasses.py +3 -3
- ripple_down_rules/experts.py +11 -110
- ripple_down_rules/rdr.py +141 -239
- ripple_down_rules/rules.py +27 -25
- ripple_down_rules/utils.py +48 -5
- {ripple_down_rules-0.1.64.dist-info → ripple_down_rules-0.1.66.dist-info}/METADATA +1 -1
- {ripple_down_rules-0.1.64.dist-info → ripple_down_rules-0.1.66.dist-info}/RECORD +10 -10
- {ripple_down_rules-0.1.64.dist-info → ripple_down_rules-0.1.66.dist-info}/WHEEL +1 -1
- {ripple_down_rules-0.1.64.dist-info → ripple_down_rules-0.1.66.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.1.64.dist-info → ripple_down_rules-0.1.66.dist-info}/top_level.txt +0 -0
@@ -133,7 +133,7 @@ class CaseQuery:
|
|
133
133
|
if value is not None and not isinstance(value, (CallableExpression, str)):
|
134
134
|
raise ValueError("The target must be a CallableExpression or a string.")
|
135
135
|
self._target = value
|
136
|
-
self.
|
136
|
+
self.update_target_value()
|
137
137
|
|
138
138
|
@property
|
139
139
|
def target_value(self) -> Any:
|
@@ -141,10 +141,10 @@ class CaseQuery:
|
|
141
141
|
:return: The target value of the case query.
|
142
142
|
"""
|
143
143
|
if self._target_value is None:
|
144
|
-
self.
|
144
|
+
self.update_target_value()
|
145
145
|
return self._target_value
|
146
146
|
|
147
|
-
def
|
147
|
+
def update_target_value(self):
|
148
148
|
"""
|
149
149
|
Update the target value of the case query.
|
150
150
|
"""
|
ripple_down_rules/experts.py
CHANGED
@@ -32,10 +32,11 @@ class Expert(ABC):
|
|
32
32
|
"""
|
33
33
|
A flag to indicate if the expert should use loaded answers or not.
|
34
34
|
"""
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
35
|
+
|
36
|
+
def __init__(self, use_loaded_answers: bool = False, append: bool = False):
|
37
|
+
self.all_expert_answers = []
|
38
|
+
self.use_loaded_answers = use_loaded_answers
|
39
|
+
self.append = append
|
39
40
|
|
40
41
|
@abstractmethod
|
41
42
|
def ask_for_conditions(self, case_query: CaseQuery, last_evaluated_rule: Optional[Rule] = None) \
|
@@ -51,28 +52,6 @@ class Expert(ABC):
|
|
51
52
|
pass
|
52
53
|
|
53
54
|
@abstractmethod
|
54
|
-
def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
|
55
|
-
"""
|
56
|
-
Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
|
57
|
-
|
58
|
-
:param case_query: The case query containing the case to classify.
|
59
|
-
:return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
|
60
|
-
conclusion and conditions for the rule.
|
61
|
-
"""
|
62
|
-
pass
|
63
|
-
|
64
|
-
@abstractmethod
|
65
|
-
def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
|
66
|
-
current_conclusions: Any) -> bool:
|
67
|
-
"""
|
68
|
-
Ask the expert if the conclusion is correct.
|
69
|
-
|
70
|
-
:param case_query: The case query about which the expert should answer.
|
71
|
-
:param conclusion: The conclusion to check.
|
72
|
-
:param current_conclusions: The current conclusions for the case.
|
73
|
-
"""
|
74
|
-
pass
|
75
|
-
|
76
55
|
def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
|
77
56
|
"""
|
78
57
|
Ask the expert to provide a relational conclusion for the case.
|
@@ -87,18 +66,14 @@ class Human(Expert):
|
|
87
66
|
The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
|
88
67
|
"""
|
89
68
|
|
90
|
-
def
|
91
|
-
self.all_expert_answers = []
|
92
|
-
self.use_loaded_answers = use_loaded_answers
|
93
|
-
|
94
|
-
def save_answers(self, path: str, append: bool = False):
|
69
|
+
def save_answers(self, path: str):
|
95
70
|
"""
|
96
71
|
Save the expert answers to a file.
|
97
72
|
|
98
73
|
:param path: The path to save the answers to.
|
99
74
|
:param append: A flag to indicate if the answers should be appended to the file or not.
|
100
75
|
"""
|
101
|
-
if append:
|
76
|
+
if self.append:
|
102
77
|
# read the file and append the new answers
|
103
78
|
with open(path + '.json', "r") as f:
|
104
79
|
all_answers = json.load(f)
|
@@ -126,24 +101,6 @@ class Human(Expert):
|
|
126
101
|
last_evaluated_rule=last_evaluated_rule)
|
127
102
|
return self._get_conditions(case_query)
|
128
103
|
|
129
|
-
def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
|
130
|
-
"""
|
131
|
-
Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
|
132
|
-
|
133
|
-
:param case_query: The case query containing the case to classify.
|
134
|
-
:return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
|
135
|
-
conclusion and conditions for the rule.
|
136
|
-
"""
|
137
|
-
rules = []
|
138
|
-
while True:
|
139
|
-
conclusion = self.ask_for_conclusion(case_query)
|
140
|
-
if conclusion is None:
|
141
|
-
break
|
142
|
-
conditions = self._get_conditions(case_query)
|
143
|
-
rules.append({PromptFor.Conclusion: conclusion,
|
144
|
-
PromptFor.Conditions: conditions})
|
145
|
-
return rules
|
146
|
-
|
147
104
|
def _get_conditions(self, case_query: CaseQuery) \
|
148
105
|
-> CallableExpression:
|
149
106
|
"""
|
@@ -154,6 +111,8 @@ class Human(Expert):
|
|
154
111
|
:return: The differentiating features as new rule conditions.
|
155
112
|
"""
|
156
113
|
user_input = None
|
114
|
+
if self.use_loaded_answers and len(self.all_expert_answers) == 0 and self.append:
|
115
|
+
self.use_loaded_answers = False
|
157
116
|
if self.use_loaded_answers:
|
158
117
|
user_input = self.all_expert_answers.pop(0)
|
159
118
|
if user_input:
|
@@ -173,6 +132,8 @@ class Human(Expert):
|
|
173
132
|
:return: The conclusion for the case as a callable expression.
|
174
133
|
"""
|
175
134
|
expression: Optional[CallableExpression] = None
|
135
|
+
if self.use_loaded_answers and len(self.all_expert_answers) == 0 and self.append:
|
136
|
+
self.use_loaded_answers = False
|
176
137
|
if self.use_loaded_answers:
|
177
138
|
expert_input = self.all_expert_answers.pop(0)
|
178
139
|
if expert_input is not None:
|
@@ -184,63 +145,3 @@ class Human(Expert):
|
|
184
145
|
self.all_expert_answers.append(expert_input)
|
185
146
|
case_query.target = expression
|
186
147
|
return expression
|
187
|
-
|
188
|
-
def get_category_type(self, cat_name: str) -> Optional[Type[CaseAttribute]]:
|
189
|
-
"""
|
190
|
-
Get the category type from the known categories.
|
191
|
-
|
192
|
-
:param cat_name: The name of the category.
|
193
|
-
:return: The category type.
|
194
|
-
"""
|
195
|
-
cat_name = cat_name.lower()
|
196
|
-
self.known_categories = get_all_subclasses(
|
197
|
-
CaseAttribute) if not self.known_categories else self.known_categories
|
198
|
-
self.known_categories.update(CaseAttribute.registry)
|
199
|
-
category_type = None
|
200
|
-
if cat_name in self.known_categories:
|
201
|
-
category_type = self.known_categories[cat_name]
|
202
|
-
return category_type
|
203
|
-
|
204
|
-
def ask_if_category_is_mutually_exclusive(self, category_name: str) -> bool:
|
205
|
-
"""
|
206
|
-
Ask the expert if the new category can have multiple values.
|
207
|
-
|
208
|
-
:param category_name: The name of the category to ask about.
|
209
|
-
"""
|
210
|
-
question = f"Can a case have multiple values of the new category {category_name}? (y/n):"
|
211
|
-
return not self.ask_for_affirmation(question)
|
212
|
-
|
213
|
-
def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
|
214
|
-
current_conclusions: Any) -> bool:
|
215
|
-
"""
|
216
|
-
Ask the expert if the conclusion is correct.
|
217
|
-
|
218
|
-
:param case_query: The case query about which the expert should answer.
|
219
|
-
:param conclusion: The conclusion to check.
|
220
|
-
:param current_conclusions: The current conclusions for the case.
|
221
|
-
"""
|
222
|
-
if not self.use_loaded_answers:
|
223
|
-
print(f"Current conclusions: {current_conclusions}")
|
224
|
-
return self.ask_for_affirmation(case_query,
|
225
|
-
f"Is the conclusion {conclusion} correct for the case (True/False):")
|
226
|
-
|
227
|
-
def ask_for_affirmation(self, case_query: CaseQuery, question: str) -> bool:
|
228
|
-
"""
|
229
|
-
Ask the expert a yes or no question.
|
230
|
-
|
231
|
-
:param case_query: The case query about which the expert should answer.
|
232
|
-
:param question: The question to ask the expert.
|
233
|
-
:return: The answer to the question.
|
234
|
-
"""
|
235
|
-
while True:
|
236
|
-
if self.use_loaded_answers:
|
237
|
-
answer = self.all_expert_answers.pop(0)
|
238
|
-
else:
|
239
|
-
_, expression = prompt_user_for_expression(case_query, PromptFor.Affirmation, question)
|
240
|
-
answer = expression(case_query.case)
|
241
|
-
if answer:
|
242
|
-
self.all_expert_answers.append(True)
|
243
|
-
return True
|
244
|
-
else:
|
245
|
-
self.all_expert_answers.append(False)
|
246
|
-
return False
|
ripple_down_rules/rdr.py
CHANGED
@@ -22,7 +22,7 @@ from .helpers import is_matching
|
|
22
22
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
23
23
|
from .utils import draw_tree, make_set, copy_case, \
|
24
24
|
SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
|
25
|
-
get_case_attribute_type, is_conflicting
|
25
|
+
get_case_attribute_type, is_conflicting, update_case
|
26
26
|
|
27
27
|
|
28
28
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -49,32 +49,6 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
49
49
|
self.start_rule = start_rule
|
50
50
|
self.fig: Optional[plt.Figure] = None
|
51
51
|
|
52
|
-
def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
|
53
|
-
return self.classify(case)
|
54
|
-
|
55
|
-
@abstractmethod
|
56
|
-
def classify(self, case: Union[Case, SQLTable]) -> Optional[CaseAttribute]:
|
57
|
-
"""
|
58
|
-
Classify a case.
|
59
|
-
|
60
|
-
:param case: The case to classify.
|
61
|
-
:return: The category that the case belongs to.
|
62
|
-
"""
|
63
|
-
pass
|
64
|
-
|
65
|
-
@abstractmethod
|
66
|
-
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
67
|
-
-> Union[CaseAttribute, CallableExpression]:
|
68
|
-
"""
|
69
|
-
Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
|
70
|
-
comparing the case with the target category.
|
71
|
-
|
72
|
-
:param case_query: The query containing the case to classify and the target category to compare the case with.
|
73
|
-
:param expert: The expert to ask for differentiating features as new rule conditions.
|
74
|
-
:return: The category that the case belongs to.
|
75
|
-
"""
|
76
|
-
pass
|
77
|
-
|
78
52
|
def fit(self, case_queries: List[CaseQuery],
|
79
53
|
expert: Optional[Expert] = None,
|
80
54
|
n_iter: int = None,
|
@@ -124,6 +98,63 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
124
98
|
plt.ioff()
|
125
99
|
plt.show()
|
126
100
|
|
101
|
+
def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
|
102
|
+
return self.classify(case)
|
103
|
+
|
104
|
+
@abstractmethod
|
105
|
+
def classify(self, case: Union[Case, SQLTable]) -> Optional[CaseAttribute]:
|
106
|
+
"""
|
107
|
+
Classify a case.
|
108
|
+
|
109
|
+
:param case: The case to classify.
|
110
|
+
:return: The category that the case belongs to.
|
111
|
+
"""
|
112
|
+
pass
|
113
|
+
|
114
|
+
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
115
|
+
-> Union[CaseAttribute, CallableExpression]:
|
116
|
+
"""
|
117
|
+
Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
|
118
|
+
incorrect by comparing the case with the target category.
|
119
|
+
|
120
|
+
:param case_query: The query containing the case to classify and the target category to compare the case with.
|
121
|
+
:param expert: The expert to ask for differentiating features as new rule conditions.
|
122
|
+
:return: The category that the case belongs to.
|
123
|
+
"""
|
124
|
+
if case_query is None:
|
125
|
+
raise ValueError("The case query cannot be None.")
|
126
|
+
if case_query.target is None:
|
127
|
+
expert.ask_for_conclusion(case_query)
|
128
|
+
if case_query.target is None:
|
129
|
+
return self.classify(case_query.case)
|
130
|
+
|
131
|
+
self.update_start_rule(case_query, expert)
|
132
|
+
|
133
|
+
return self._fit_case(case_query, expert=expert, **kwargs)
|
134
|
+
|
135
|
+
@abstractmethod
|
136
|
+
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
137
|
+
-> Union[CaseAttribute, CallableExpression]:
|
138
|
+
"""
|
139
|
+
Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
|
140
|
+
comparing the case with the target category.
|
141
|
+
|
142
|
+
:param case_query: The query containing the case to classify and the target category to compare the case with.
|
143
|
+
:param expert: The expert to ask for differentiating features as new rule conditions.
|
144
|
+
:return: The category that the case belongs to.
|
145
|
+
"""
|
146
|
+
pass
|
147
|
+
|
148
|
+
@abstractmethod
|
149
|
+
def update_start_rule(self, case_query: CaseQuery, expert: Expert):
|
150
|
+
"""
|
151
|
+
Update the starting rule of the classifier.
|
152
|
+
|
153
|
+
:param case_query: The case query to update the starting rule with.
|
154
|
+
:param expert: The expert to ask for differentiating features as new rule conditions.
|
155
|
+
"""
|
156
|
+
pass
|
157
|
+
|
127
158
|
def update_figures(self):
|
128
159
|
"""
|
129
160
|
Update the figures of the classifier.
|
@@ -295,13 +326,10 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
295
326
|
"""
|
296
327
|
:return: The type of the conclusion of the RDR classifier.
|
297
328
|
"""
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
if isinstance(conclusion, set):
|
303
|
-
return type(list(conclusion)[0]), set
|
304
|
-
return (type(conclusion),)
|
329
|
+
all_types = []
|
330
|
+
for rule in [self.start_rule] + list(self.start_rule.descendants):
|
331
|
+
all_types.extend(list(rule.conclusion.conclusion_type))
|
332
|
+
return tuple(set(all_types))
|
305
333
|
|
306
334
|
@property
|
307
335
|
def attribute_name(self) -> str:
|
@@ -321,7 +349,7 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
321
349
|
super(SingleClassRDR, self).__init__(start_rule)
|
322
350
|
self.default_conclusion: Optional[Any] = default_conclusion
|
323
351
|
|
324
|
-
def
|
352
|
+
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
325
353
|
-> Union[CaseAttribute, CallableExpression, None]:
|
326
354
|
"""
|
327
355
|
Classify a case, and ask the user for refinements or alternatives if the classification is incorrect by
|
@@ -331,38 +359,44 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
331
359
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
332
360
|
:return: The category that the case belongs to.
|
333
361
|
"""
|
334
|
-
expert = expert if expert else Human()
|
335
362
|
if case_query.default_value is not None and self.default_conclusion != case_query.default_value:
|
336
363
|
self.default_conclusion = case_query.default_value
|
337
|
-
case = case_query.case
|
338
|
-
target = expert.ask_for_conclusion(case_query) if case_query.target is None else case_query.target
|
339
|
-
if target is None:
|
340
|
-
return self.classify(case)
|
341
|
-
if not self.start_rule:
|
342
|
-
conditions = expert.ask_for_conditions(case_query)
|
343
|
-
self.start_rule = SingleClassRule(conditions, target, corner_case=case,
|
344
|
-
conclusion_name=case_query.attribute_name)
|
345
364
|
|
346
365
|
pred = self.evaluate(case_query.case)
|
347
|
-
if pred.conclusion(case) !=
|
348
|
-
|
349
|
-
pred.fit_rule(case_query.case, target, conditions=conditions)
|
366
|
+
if pred.conclusion(case_query.case) != case_query.target_value:
|
367
|
+
expert.ask_for_conditions(case_query, pred)
|
368
|
+
pred.fit_rule(case_query.case, case_query.target, conditions=case_query.conditions)
|
350
369
|
|
351
370
|
return self.classify(case_query.case)
|
352
371
|
|
353
|
-
def
|
372
|
+
def update_start_rule(self, case_query: CaseQuery, expert: Expert):
|
373
|
+
"""
|
374
|
+
Update the starting rule of the classifier.
|
375
|
+
|
376
|
+
:param case_query: The case query to update the starting rule with.
|
377
|
+
:param expert: The expert to ask for differentiating features as new rule conditions.
|
378
|
+
"""
|
379
|
+
if not self.start_rule:
|
380
|
+
expert.ask_for_conditions(case_query)
|
381
|
+
self.start_rule = SingleClassRule(case_query.conditions, case_query.target, corner_case=case_query.case,
|
382
|
+
conclusion_name=case_query.attribute_name)
|
383
|
+
|
384
|
+
def classify(self, case: Case, modify_original_case: bool = False) -> Optional[Any]:
|
354
385
|
"""
|
355
386
|
Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
|
387
|
+
|
388
|
+
:param case: The case to classify.
|
389
|
+
:param modify_original_case: Whether to modify the original case attributes with the conclusion or not.
|
356
390
|
"""
|
357
391
|
pred = self.evaluate(case)
|
358
|
-
return pred.conclusion(case) if pred.fired else self.default_conclusion
|
392
|
+
return pred.conclusion(case) if pred is not None and pred.fired else self.default_conclusion
|
359
393
|
|
360
394
|
def evaluate(self, case: Case) -> SingleClassRule:
|
361
395
|
"""
|
362
396
|
Evaluate the starting rule on a case.
|
363
397
|
"""
|
364
|
-
matched_rule = self.start_rule(case)
|
365
|
-
return matched_rule if matched_rule else self.start_rule
|
398
|
+
matched_rule = self.start_rule(case) if self.start_rule is not None else None
|
399
|
+
return matched_rule if matched_rule is not None else self.start_rule
|
366
400
|
|
367
401
|
def write_to_python_file(self, file_path: str, postfix: str = ""):
|
368
402
|
super().write_to_python_file(file_path, postfix)
|
@@ -382,7 +416,8 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
382
416
|
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
|
383
417
|
defs_file=defs_file)
|
384
418
|
|
385
|
-
|
419
|
+
conclusion_call = rule.write_conclusion_as_source_code(parent_indent, defs_file)
|
420
|
+
file.write(conclusion_call)
|
386
421
|
|
387
422
|
if rule.alternative:
|
388
423
|
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
|
@@ -391,6 +426,12 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
391
426
|
def conclusion_type_hint(self) -> str:
|
392
427
|
return self.conclusion_type[0].__name__
|
393
428
|
|
429
|
+
@property
|
430
|
+
def conclusion_type(self) -> Tuple[Type]:
|
431
|
+
if self.default_conclusion is not None:
|
432
|
+
return (type(self.default_conclusion),)
|
433
|
+
return super().conclusion_type
|
434
|
+
|
394
435
|
def _to_json(self) -> Dict[str, Any]:
|
395
436
|
return {"start_rule": self.start_rule.to_json()}
|
396
437
|
|
@@ -422,13 +463,12 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
422
463
|
The conditions of the stopping rule if needed.
|
423
464
|
"""
|
424
465
|
|
425
|
-
def __init__(self, start_rule: Optional[
|
466
|
+
def __init__(self, start_rule: Optional[MultiClassTopRule] = None,
|
426
467
|
mode: MCRDRMode = MCRDRMode.StopOnly):
|
427
468
|
"""
|
428
469
|
:param start_rule: The starting rules for the classifier.
|
429
470
|
:param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
|
430
471
|
"""
|
431
|
-
start_rule = MultiClassTopRule() if not start_rule else start_rule
|
432
472
|
super(MultiClassRDR, self).__init__(start_rule)
|
433
473
|
self.mode: MCRDRMode = mode
|
434
474
|
|
@@ -442,40 +482,28 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
442
482
|
evaluated_rule = next_rule
|
443
483
|
return make_set(self.conclusions)
|
444
484
|
|
445
|
-
def
|
446
|
-
|
485
|
+
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None
|
486
|
+
, **kwargs) -> Set[Union[CaseAttribute, CallableExpression, None]]:
|
447
487
|
"""
|
448
488
|
Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
|
449
489
|
or missing by comparing the case with the target category if provided.
|
450
490
|
|
451
491
|
:param case_query: The query containing the case to classify and the target category to compare the case with.
|
452
492
|
:param expert: The expert to ask for differentiating features as new rule conditions or for extra conclusions.
|
453
|
-
:param add_extra_conclusions: Whether to add extra conclusions after classification is done.
|
454
493
|
:return: The conclusions that the case belongs to.
|
455
494
|
"""
|
456
|
-
expert = expert if expert else Human()
|
457
|
-
if case_query.target is None:
|
458
|
-
expert.ask_for_conclusion(case_query)
|
459
|
-
if case_query.target is None:
|
460
|
-
return self.classify(case_query.case)
|
461
|
-
self.update_start_rule(case_query, expert)
|
462
|
-
self.expert_accepted_conclusions = []
|
463
|
-
user_conclusions = []
|
464
495
|
self.conclusions = []
|
465
496
|
self.stop_rule_conditions = None
|
466
497
|
evaluated_rule = self.start_rule
|
467
|
-
target =
|
498
|
+
target = make_set(case_query.target_value)
|
468
499
|
while evaluated_rule:
|
469
500
|
next_rule = evaluated_rule(case_query.case)
|
470
501
|
rule_conclusion = evaluated_rule.conclusion(case_query.case)
|
471
|
-
good_conclusions = make_list(target) + user_conclusions + self.expert_accepted_conclusions
|
472
|
-
good_conclusions = make_set(good_conclusions)
|
473
502
|
|
474
503
|
if evaluated_rule.fired:
|
475
|
-
if
|
504
|
+
if not make_set(rule_conclusion).issubset(target):
|
476
505
|
# Rule fired and conclusion is different from target
|
477
|
-
self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule
|
478
|
-
len(user_conclusions) > 0)
|
506
|
+
self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule)
|
479
507
|
else:
|
480
508
|
# Rule fired and target is correct or there is no target to compare
|
481
509
|
self.add_conclusion(evaluated_rule, case_query.case)
|
@@ -486,14 +514,6 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
486
514
|
self.add_rule_for_case(case_query, expert)
|
487
515
|
# Have to check all rules again to make sure only this new rule fires
|
488
516
|
next_rule = self.start_rule
|
489
|
-
elif add_extra_conclusions:
|
490
|
-
# No more conclusions can be made, ask the expert for extra conclusions if needed.
|
491
|
-
new_user_conclusions = self.ask_expert_for_extra_rules(expert, case_query)
|
492
|
-
user_conclusions.extend(new_user_conclusions)
|
493
|
-
if len(new_user_conclusions) > 0:
|
494
|
-
next_rule = self.last_top_rule
|
495
|
-
else:
|
496
|
-
add_extra_conclusions = False
|
497
517
|
evaluated_rule = next_rule
|
498
518
|
return self.conclusions
|
499
519
|
|
@@ -510,18 +530,20 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
510
530
|
defs_file=defs_file)
|
511
531
|
conclusion_indent = parent_indent + " " * 4
|
512
532
|
file.write(f"{conclusion_indent}else:\n")
|
513
|
-
|
533
|
+
|
534
|
+
conclusion_call = rule.write_conclusion_as_source_code(conclusion_indent, defs_file)
|
535
|
+
file.write(conclusion_call)
|
514
536
|
|
515
537
|
if rule.alternative:
|
516
538
|
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
|
517
539
|
|
518
540
|
@property
|
519
541
|
def conclusion_type_hint(self) -> str:
|
520
|
-
return f"Set[{self.conclusion_type[
|
542
|
+
return f"Set[Union[{', '.join([ct.__name__ for ct in self.conclusion_type if ct not in [list, set]])}]]"
|
521
543
|
|
522
544
|
def _get_imports(self) -> str:
|
523
545
|
imports = super()._get_imports()
|
524
|
-
imports += "from typing_extensions import Set\n"
|
546
|
+
imports += "from typing_extensions import Set, Union\n"
|
525
547
|
imports += "from ripple_down_rules.utils import make_set\n"
|
526
548
|
return imports
|
527
549
|
|
@@ -532,12 +554,10 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
532
554
|
:param case_query: The case query to update the starting rule with.
|
533
555
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
534
556
|
"""
|
535
|
-
if not self.start_rule
|
557
|
+
if not self.start_rule:
|
536
558
|
conditions = expert.ask_for_conditions(case_query)
|
537
|
-
self.start_rule
|
538
|
-
|
539
|
-
self.start_rule.corner_case = case_query.case
|
540
|
-
self.start_rule.conclusion_name = case_query.attribute_name
|
559
|
+
self.start_rule = MultiClassTopRule(conditions, case_query.target, corner_case=case_query.case,
|
560
|
+
conclusion_name=case_query.attribute_name)
|
541
561
|
|
542
562
|
@property
|
543
563
|
def last_top_rule(self) -> Optional[MultiClassTopRule]:
|
@@ -550,17 +570,13 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
550
570
|
return self.start_rule.furthest_alternative[-1]
|
551
571
|
|
552
572
|
def stop_wrong_conclusion_else_add_it(self, case_query: CaseQuery, expert: Expert,
|
553
|
-
evaluated_rule: MultiClassTopRule
|
554
|
-
add_extra_conclusions: bool):
|
573
|
+
evaluated_rule: MultiClassTopRule):
|
555
574
|
"""
|
556
575
|
Stop a wrong conclusion by adding a stopping rule.
|
557
576
|
"""
|
558
577
|
rule_conclusion = evaluated_rule.conclusion(case_query.case)
|
559
578
|
if is_conflicting(rule_conclusion, case_query.target_value):
|
560
|
-
|
561
|
-
return
|
562
|
-
else:
|
563
|
-
self.stop_conclusion(case_query, expert, evaluated_rule)
|
579
|
+
self.stop_conclusion(case_query, expert, evaluated_rule)
|
564
580
|
|
565
581
|
def stop_conclusion(self, case_query: CaseQuery,
|
566
582
|
expert: Expert, evaluated_rule: MultiClassTopRule):
|
@@ -579,27 +595,6 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
579
595
|
new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
|
580
596
|
self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
|
581
597
|
|
582
|
-
def conclusion_is_correct(self, case_query: CaseQuery,
|
583
|
-
expert: Expert, evaluated_rule: Rule,
|
584
|
-
add_extra_conclusions: bool) -> bool:
|
585
|
-
"""
|
586
|
-
Ask the expert if the conclusion is correct, and add it to the conclusions if it is.
|
587
|
-
|
588
|
-
:param case_query: The case query to ask the expert about.
|
589
|
-
:param expert: The expert to ask for differentiating features as new rule conditions.
|
590
|
-
:param evaluated_rule: The evaluated rule to ask the expert about.
|
591
|
-
:param add_extra_conclusions: Whether adding extra conclusions after classification is allowed.
|
592
|
-
:return: Whether the conclusion is correct or not.
|
593
|
-
"""
|
594
|
-
conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
|
595
|
-
if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
|
596
|
-
evaluated_rule.conclusion(case_query.case),
|
597
|
-
current_conclusions=conclusions)):
|
598
|
-
self.add_conclusion(evaluated_rule, case_query.case)
|
599
|
-
self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
|
600
|
-
return True
|
601
|
-
return False
|
602
|
-
|
603
598
|
def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
|
604
599
|
"""
|
605
600
|
Add a rule for a case that has not been classified with any conclusion.
|
@@ -614,24 +609,6 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
614
609
|
conditions = expert.ask_for_conditions(case_query)
|
615
610
|
self.add_top_rule(conditions, case_query.target, case_query.case)
|
616
611
|
|
617
|
-
def ask_expert_for_extra_rules(self, expert: Expert, case_query: CaseQuery) -> List[Any]:
|
618
|
-
"""
|
619
|
-
Ask the expert for extra rules when no more conclusions can be made for a case.
|
620
|
-
|
621
|
-
:param expert: The expert to ask for extra conclusions.
|
622
|
-
:param case_query: The case query to ask the expert about.
|
623
|
-
:return: The extra conclusions for the rules that the expert has provided.
|
624
|
-
"""
|
625
|
-
extra_conclusions = []
|
626
|
-
conclusions = list(OrderedSet(self.conclusions))
|
627
|
-
if not expert.use_loaded_answers:
|
628
|
-
print("current conclusions:", conclusions)
|
629
|
-
extra_rules = expert.ask_for_extra_rules(case_query)
|
630
|
-
for rule in extra_rules:
|
631
|
-
self.add_top_rule(rule[PromptFor.Conditions], rule[PromptFor.Conclusion], case_query.case)
|
632
|
-
extra_conclusions.extend(rule[PromptFor.Conclusion](case_query.case))
|
633
|
-
return extra_conclusions
|
634
|
-
|
635
612
|
def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
|
636
613
|
"""
|
637
614
|
Add the conclusion of the evaluated rule to the list of conclusions.
|
@@ -725,30 +702,32 @@ class GeneralRDR(RippleDownRules):
|
|
725
702
|
def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
|
726
703
|
return [rdr.start_rule for rdr in self.start_rules_dict.values()]
|
727
704
|
|
728
|
-
def classify(self, case: Any) -> Optional[Dict[str, Any]]:
|
705
|
+
def classify(self, case: Any, modify_case: bool = False) -> Optional[Dict[str, Any]]:
|
729
706
|
"""
|
730
707
|
Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
|
731
708
|
the classification until no more categories can be added.
|
732
709
|
|
733
710
|
:param case: The case to classify.
|
711
|
+
:param modify_case: Whether to modify the original case or create a copy and modify it.
|
734
712
|
:return: The categories that the case belongs to.
|
735
713
|
"""
|
736
|
-
return self._classify(self.start_rules_dict, case)
|
714
|
+
return self._classify(self.start_rules_dict, case, modify_original_case=modify_case)
|
737
715
|
|
738
716
|
@staticmethod
|
739
717
|
def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
|
740
|
-
case: Any) -> Dict[str, Any]:
|
718
|
+
case: Any, modify_original_case: bool = True) -> Dict[str, Any]:
|
741
719
|
"""
|
742
720
|
Classify a case by going through all classifiers and adding the categories that are classified,
|
743
721
|
and then restarting the classification until no more categories can be added.
|
744
722
|
|
745
723
|
:param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
|
746
724
|
:param case: The case to classify.
|
725
|
+
:param modify_original_case: Whether to modify the original case or create a copy and modify it.
|
747
726
|
:return: The categories that the case belongs to.
|
748
727
|
"""
|
749
728
|
conclusions = {}
|
750
729
|
case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
|
751
|
-
case_cp = copy_case(case)
|
730
|
+
case_cp = copy_case(case) if not modify_original_case else case
|
752
731
|
while True:
|
753
732
|
new_conclusions = {}
|
754
733
|
for attribute_name, rdr in classifiers_dict.items():
|
@@ -771,13 +750,13 @@ class GeneralRDR(RippleDownRules):
|
|
771
750
|
conclusions[attribute_name].update(pred_atts)
|
772
751
|
if attribute_name in new_conclusions:
|
773
752
|
mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
|
774
|
-
|
775
|
-
|
753
|
+
case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive)
|
754
|
+
update_case(case_query, new_conclusions)
|
776
755
|
if len(new_conclusions) == 0:
|
777
756
|
break
|
778
757
|
return conclusions
|
779
758
|
|
780
|
-
def
|
759
|
+
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
781
760
|
-> Dict[str, Any]:
|
782
761
|
"""
|
783
762
|
Fit the GRDR on a case, if the target is a new type of category, a new RDR is created for it,
|
@@ -787,115 +766,38 @@ class GeneralRDR(RippleDownRules):
|
|
787
766
|
they are accepted by the expert, and the attribute of that category is represented in the case as a set of
|
788
767
|
values.
|
789
768
|
|
790
|
-
:param
|
769
|
+
:param case_query: The query containing the case to classify and the target category to compare the case
|
791
770
|
with.
|
792
771
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
793
772
|
:return: The categories that the case belongs to.
|
794
773
|
"""
|
795
|
-
expert = expert if expert else Human()
|
796
|
-
case_queries = make_list(case_queries)
|
797
|
-
assert len(case_queries) > 0, "No case queries provided"
|
798
|
-
case = case_queries[0].case
|
799
|
-
assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
|
800
|
-
" for multiple cases use fit instead")
|
801
|
-
original_case_query_cp = copy(case_queries[0])
|
802
|
-
for case_query in case_queries:
|
803
|
-
case_query_cp = copy(case_query)
|
804
|
-
case_query_cp.case = original_case_query_cp.case
|
805
|
-
if case_query_cp.target is None:
|
806
|
-
conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
|
807
|
-
self.update_case(case_query_cp, conclusions)
|
808
|
-
expert.ask_for_conclusion(case_query_cp)
|
809
|
-
if case_query_cp.target is None:
|
810
|
-
continue
|
811
|
-
case_query.target = case_query_cp.target
|
812
|
-
|
813
|
-
if case_query.attribute_name not in self.start_rules_dict:
|
814
|
-
conclusions = self.classify(case)
|
815
|
-
self.update_case(case_query_cp, conclusions)
|
816
|
-
|
817
|
-
new_rdr = self.initialize_new_rdr_for_attribute(case_query_cp)
|
818
|
-
self.add_rdr(new_rdr, case_query.attribute_name)
|
819
|
-
|
820
|
-
new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
|
821
|
-
self.update_case(case_query_cp, {case_query.attribute_name: new_conclusions})
|
822
|
-
else:
|
823
|
-
for rdr_attribute_name, rdr in self.start_rules_dict.items():
|
824
|
-
if case_query.attribute_name != rdr_attribute_name:
|
825
|
-
conclusions = rdr.classify(case_query_cp.case)
|
826
|
-
else:
|
827
|
-
conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
|
828
|
-
**kwargs)
|
829
|
-
if conclusions is not None:
|
830
|
-
if (not is_iterable(conclusions)) or len(conclusions) > 0:
|
831
|
-
conclusions = {rdr_attribute_name: conclusions}
|
832
|
-
case_query_cp.mutually_exclusive = True if isinstance(rdr, SingleClassRDR) else False
|
833
|
-
self.update_case(case_query_cp, conclusions)
|
834
|
-
case_query.conditions = case_query_cp.conditions
|
835
774
|
|
836
|
-
|
775
|
+
case_query_cp = copy(case_query)
|
776
|
+
self.classify(case_query_cp.case, modify_case=True)
|
777
|
+
case_query_cp.update_target_value()
|
778
|
+
|
779
|
+
self.start_rules_dict[case_query_cp.attribute_name].fit_case(case_query_cp, expert, **kwargs)
|
780
|
+
|
781
|
+
return self.classify(case_query.case)
|
782
|
+
|
783
|
+
def update_start_rule(self, case_query: CaseQuery, expert: Expert):
|
784
|
+
"""
|
785
|
+
Update the starting rule of the classifier.
|
786
|
+
|
787
|
+
:param case_query: The case query to update the starting rule with.
|
788
|
+
:param expert: The expert to ask for differentiating features as new rule conditions.
|
789
|
+
"""
|
790
|
+
if case_query.attribute_name not in self.start_rules_dict:
|
791
|
+
new_rdr = self.initialize_new_rdr_for_attribute(case_query)
|
792
|
+
self.add_rdr(new_rdr, case_query.attribute_name)
|
837
793
|
|
838
794
|
@staticmethod
|
839
795
|
def initialize_new_rdr_for_attribute(case_query: CaseQuery):
|
840
796
|
"""
|
841
797
|
Initialize the appropriate RDR type for the target.
|
842
798
|
"""
|
843
|
-
if case_query.mutually_exclusive
|
844
|
-
|
845
|
-
else MultiClassRDR()
|
846
|
-
if case_query.attribute_type in [list, set]:
|
847
|
-
return MultiClassRDR()
|
848
|
-
attribute = getattr(case_query.case, case_query.attribute_name) \
|
849
|
-
if hasattr(case_query.case, case_query.attribute_name) else case_query.target(case_query.case)
|
850
|
-
if isinstance(attribute, CaseAttribute):
|
851
|
-
return SingleClassRDR(default_conclusion=case_query.default_value) if attribute.mutually_exclusive \
|
852
|
-
else MultiClassRDR()
|
853
|
-
else:
|
854
|
-
return MultiClassRDR() if is_iterable(attribute) or (attribute is None) \
|
855
|
-
else SingleClassRDR(default_conclusion=case_query.default_value)
|
856
|
-
|
857
|
-
@staticmethod
|
858
|
-
def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
|
859
|
-
"""
|
860
|
-
Update the case with the conclusions.
|
861
|
-
|
862
|
-
:param case_query: The case query that contains the case to update.
|
863
|
-
:param conclusions: The conclusions to update the case with.
|
864
|
-
"""
|
865
|
-
if not conclusions:
|
866
|
-
return
|
867
|
-
if len(conclusions) == 0:
|
868
|
-
return
|
869
|
-
if isinstance(case_query.original_case, SQLTable) or is_dataclass(case_query.original_case):
|
870
|
-
for conclusion_name, conclusion in conclusions.items():
|
871
|
-
attribute = getattr(case_query.case, conclusion_name)
|
872
|
-
if conclusion_name == case_query.attribute_name:
|
873
|
-
attribute_type = case_query.attribute_type
|
874
|
-
else:
|
875
|
-
attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
|
876
|
-
if isinstance(attribute, set):
|
877
|
-
for c in conclusion:
|
878
|
-
attribute.update(make_set(c))
|
879
|
-
elif isinstance(attribute, list):
|
880
|
-
attribute.extend(conclusion)
|
881
|
-
elif any(at in {List, list} for at in attribute_type):
|
882
|
-
attribute = [] if attribute is None else attribute
|
883
|
-
attribute.extend(conclusion)
|
884
|
-
elif any(at in {Set, set} for at in attribute_type):
|
885
|
-
attribute = set() if attribute is None else attribute
|
886
|
-
for c in conclusion:
|
887
|
-
attribute.update(make_set(c))
|
888
|
-
elif is_iterable(conclusion) and len(conclusion) == 1 \
|
889
|
-
and any(at is type(list(conclusion)[0]) for at in attribute_type):
|
890
|
-
setattr(case_query.case, conclusion_name, list(conclusion)[0])
|
891
|
-
elif not is_iterable(conclusion) and any(at is type(conclusion) for at in attribute_type):
|
892
|
-
setattr(case_query.case, conclusion_name, conclusion)
|
893
|
-
else:
|
894
|
-
raise ValueError(f"Unknown type or type mismatch for attribute {conclusion_name} with type "
|
895
|
-
f"{case_query.attribute_type} with conclusion "
|
896
|
-
f"{conclusion} of type {type(conclusion)}")
|
897
|
-
else:
|
898
|
-
case_query.case.update(conclusions)
|
799
|
+
return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive \
|
800
|
+
else MultiClassRDR()
|
899
801
|
|
900
802
|
def _to_json(self) -> Dict[str, Any]:
|
901
803
|
return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
|
@@ -969,7 +871,7 @@ class GeneralRDR(RippleDownRules):
|
|
969
871
|
|
970
872
|
@property
|
971
873
|
def conclusion_type_hint(self) -> str:
|
972
|
-
return f"
|
874
|
+
return f"Dict[str, Any]"
|
973
875
|
|
974
876
|
def _get_imports(self, file_path: str) -> str:
|
975
877
|
"""
|
@@ -980,7 +882,7 @@ class GeneralRDR(RippleDownRules):
|
|
980
882
|
"""
|
981
883
|
imports = ""
|
982
884
|
# add type hints
|
983
|
-
imports += f"from typing_extensions import
|
885
|
+
imports += f"from typing_extensions import Dict, Any, Union, Set\n"
|
984
886
|
# import rdr type
|
985
887
|
imports += f"from ripple_down_rules.rdr import GeneralRDR\n"
|
986
888
|
# add case type
|
ripple_down_rules/rules.py
CHANGED
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
|
5
5
|
from enum import Enum
|
6
6
|
|
7
7
|
from anytree import NodeMixin
|
8
|
-
from typing_extensions import List, Optional, Self, Union, Dict, Any
|
8
|
+
from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
|
9
9
|
|
10
10
|
from .datastructures.callable_expression import CallableExpression
|
11
11
|
from .datastructures.case import Case
|
@@ -78,24 +78,26 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
78
78
|
"""
|
79
79
|
pass
|
80
80
|
|
81
|
-
def write_conclusion_as_source_code(self, parent_indent: str = "") -> str:
|
81
|
+
def write_conclusion_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
|
82
82
|
"""
|
83
83
|
Get the source code representation of the conclusion of the rule.
|
84
84
|
|
85
85
|
:param parent_indent: The indentation of the parent rule.
|
86
|
+
:param defs_file: The file to write the conclusion to if it is a definition.
|
87
|
+
:return: The source code representation of the conclusion of the rule.
|
86
88
|
"""
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
return
|
89
|
+
if self.conclusion.user_input is not None:
|
90
|
+
conclusion = self.conclusion.user_input
|
91
|
+
else:
|
92
|
+
conclusion = self.conclusion.conclusion
|
93
|
+
conclusion_func, conclusion_func_call = self._conclusion_source_code(conclusion, parent_indent=parent_indent)
|
94
|
+
if conclusion_func is not None:
|
95
|
+
with open(defs_file, 'a') as f:
|
96
|
+
f.write(conclusion_func + "\n\n")
|
97
|
+
return conclusion_func_call
|
96
98
|
|
97
99
|
@abstractmethod
|
98
|
-
def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> str:
|
100
|
+
def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
|
99
101
|
pass
|
100
102
|
|
101
103
|
def write_condition_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
|
@@ -118,7 +120,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
118
120
|
conditions_lines[0] = re.sub(r"def (\w+)", new_function_name, conditions_lines[0])
|
119
121
|
def_code = "\n".join(conditions_lines)
|
120
122
|
with open(defs_file, 'a') as f:
|
121
|
-
f.write(def_code + "\n")
|
123
|
+
f.write(def_code + "\n\n")
|
122
124
|
return f"\n{parent_indent}{if_clause} {new_function_name.replace('def ', '')}(case):\n"
|
123
125
|
|
124
126
|
@abstractmethod
|
@@ -130,6 +132,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
130
132
|
"conclusion": conclusion_to_json(self.conclusion),
|
131
133
|
"parent": self.parent.json_serialization if self.parent else None,
|
132
134
|
"corner_case": self.corner_case.to_json() if self.corner_case else None,
|
135
|
+
"conclusion_name": self.conclusion_name,
|
133
136
|
"weight": self.weight}
|
134
137
|
return json_serialization
|
135
138
|
|
@@ -139,6 +142,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
139
142
|
conclusion=CallableExpression.from_json(data["conclusion"]),
|
140
143
|
parent=cls.from_json(data["parent"]),
|
141
144
|
corner_case=Case.from_json(data["corner_case"]),
|
145
|
+
conclusion_name=data["conclusion_name"],
|
142
146
|
weight=data["weight"])
|
143
147
|
return loaded_rule
|
144
148
|
|
@@ -264,11 +268,11 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
|
|
264
268
|
loaded_rule.alternative = SingleClassRule.from_json(data["alternative"])
|
265
269
|
return loaded_rule
|
266
270
|
|
267
|
-
def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> str:
|
271
|
+
def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
|
268
272
|
conclusion = str(conclusion)
|
269
273
|
indent = parent_indent + " " * 4
|
270
274
|
if '\n' not in conclusion:
|
271
|
-
return f"{indent}return {conclusion}\n"
|
275
|
+
return None, f"{indent}return {conclusion}\n"
|
272
276
|
else:
|
273
277
|
return get_rule_conclusion_as_source_code(self, conclusion, parent_indent=parent_indent)
|
274
278
|
|
@@ -315,8 +319,8 @@ class MultiClassStopRule(Rule, HasAlternativeRule):
|
|
315
319
|
loaded_rule.alternative = MultiClassStopRule.from_json(data["alternative"])
|
316
320
|
return loaded_rule
|
317
321
|
|
318
|
-
def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> str:
|
319
|
-
return f"{parent_indent}{' ' * 4}pass\n"
|
322
|
+
def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[None, str]:
|
323
|
+
return None, f"{parent_indent}{' ' * 4}pass\n"
|
320
324
|
|
321
325
|
def _if_statement_source_code_clause(self) -> str:
|
322
326
|
return "elif" if self.weight == RDREdge.Alternative.value else "if"
|
@@ -360,25 +364,23 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
|
360
364
|
loaded_rule.alternative = MultiClassTopRule.from_json(data["alternative"])
|
361
365
|
return loaded_rule
|
362
366
|
|
363
|
-
def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> str:
|
367
|
+
def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[str, str]:
|
364
368
|
conclusion_str = str(conclusion)
|
365
369
|
indent = parent_indent + " " * 4
|
366
|
-
statement = ""
|
367
370
|
if '\n' not in conclusion_str:
|
371
|
+
func = None
|
368
372
|
if is_iterable(conclusion):
|
369
373
|
conclusion_str = "{" + ", ".join([str(c) for c in conclusion]) + "}"
|
370
374
|
else:
|
371
375
|
conclusion_str = "{" + str(conclusion) + "}"
|
372
376
|
else:
|
373
|
-
|
374
|
-
|
375
|
-
conclusion_str = lines[-2].replace("return ", "").strip()
|
376
|
-
statement += "\n".join(lines[:-2]) + "\n"
|
377
|
+
func, func_call = get_rule_conclusion_as_source_code(self, conclusion_str, parent_indent=parent_indent)
|
378
|
+
conclusion_str = func_call.replace("return ", "").strip()
|
377
379
|
|
378
|
-
statement
|
380
|
+
statement = f"{indent}conclusions.update(make_set({conclusion_str}))\n"
|
379
381
|
if self.alternative is None:
|
380
382
|
statement += f"{parent_indent}return conclusions\n"
|
381
|
-
return statement
|
383
|
+
return func, statement
|
382
384
|
|
383
385
|
def _if_statement_source_code_clause(self) -> str:
|
384
386
|
return "if"
|
ripple_down_rules/utils.py
CHANGED
@@ -26,6 +26,7 @@ from typing_extensions import Callable, Set, Any, Type, Dict, TYPE_CHECKING, get
|
|
26
26
|
|
27
27
|
if TYPE_CHECKING:
|
28
28
|
from .datastructures.case import Case
|
29
|
+
from .datastructures.dataclasses import CaseQuery
|
29
30
|
from .rules import Rule
|
30
31
|
|
31
32
|
import ast
|
@@ -33,6 +34,49 @@ import ast
|
|
33
34
|
matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
|
34
35
|
|
35
36
|
|
37
|
+
def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
|
38
|
+
"""
|
39
|
+
Update the case with the conclusions.
|
40
|
+
|
41
|
+
:param case_query: The case query that contains the case to update.
|
42
|
+
:param conclusions: The conclusions to update the case with.
|
43
|
+
"""
|
44
|
+
if not conclusions:
|
45
|
+
return
|
46
|
+
if len(conclusions) == 0:
|
47
|
+
return
|
48
|
+
if isinstance(case_query.original_case, SQLTable) or is_dataclass(case_query.original_case):
|
49
|
+
for conclusion_name, conclusion in conclusions.items():
|
50
|
+
attribute = getattr(case_query.case, conclusion_name)
|
51
|
+
if conclusion_name == case_query.attribute_name:
|
52
|
+
attribute_type = case_query.attribute_type
|
53
|
+
else:
|
54
|
+
attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
|
55
|
+
if isinstance(attribute, set):
|
56
|
+
for c in conclusion:
|
57
|
+
attribute.update(make_set(c))
|
58
|
+
elif isinstance(attribute, list):
|
59
|
+
attribute.extend(conclusion)
|
60
|
+
elif any(at in {List, list} for at in attribute_type):
|
61
|
+
attribute = [] if attribute is None else attribute
|
62
|
+
attribute.extend(conclusion)
|
63
|
+
elif any(at in {Set, set} for at in attribute_type):
|
64
|
+
attribute = set() if attribute is None else attribute
|
65
|
+
for c in conclusion:
|
66
|
+
attribute.update(make_set(c))
|
67
|
+
elif is_iterable(conclusion) and len(conclusion) == 1 \
|
68
|
+
and any(at is type(list(conclusion)[0]) for at in attribute_type):
|
69
|
+
setattr(case_query.case, conclusion_name, list(conclusion)[0])
|
70
|
+
elif not is_iterable(conclusion) and any(at is type(conclusion) for at in attribute_type):
|
71
|
+
setattr(case_query.case, conclusion_name, conclusion)
|
72
|
+
else:
|
73
|
+
raise ValueError(f"Unknown type or type mismatch for attribute {conclusion_name} with type "
|
74
|
+
f"{case_query.attribute_type} with conclusion "
|
75
|
+
f"{conclusion} of type {type(conclusion)}")
|
76
|
+
else:
|
77
|
+
case_query.case.update(conclusions)
|
78
|
+
|
79
|
+
|
36
80
|
def is_conflicting(conclusion: Any, target: Any) -> bool:
|
37
81
|
"""
|
38
82
|
:param conclusion: The conclusion to check.
|
@@ -75,14 +119,14 @@ def calculate_precision_and_recall(pred_cat: Dict[str, Any], target: Dict[str, A
|
|
75
119
|
return precision, recall
|
76
120
|
|
77
121
|
|
78
|
-
def get_rule_conclusion_as_source_code(rule: Rule, conclusion: str, parent_indent: str = "") -> str:
|
122
|
+
def get_rule_conclusion_as_source_code(rule: Rule, conclusion: str, parent_indent: str = "") -> Tuple[str, str]:
|
79
123
|
"""
|
80
124
|
Convert the conclusion of a rule to source code.
|
81
125
|
|
82
126
|
:param rule: The rule to get the conclusion from.
|
83
127
|
:param conclusion: The conclusion to convert to source code.
|
84
128
|
:param parent_indent: The indentation to use for the source code.
|
85
|
-
:return: The source code of the conclusion.
|
129
|
+
:return: The source code of the conclusion as a tuple of strings, one for the function and one for the call.
|
86
130
|
"""
|
87
131
|
indent = f"{parent_indent}{' ' * 4}"
|
88
132
|
if "def " in conclusion:
|
@@ -91,9 +135,8 @@ def get_rule_conclusion_as_source_code(rule: Rule, conclusion: str, parent_inden
|
|
91
135
|
# use regex to replace the function name
|
92
136
|
new_function_name = f"def conclusion_{id(rule)}"
|
93
137
|
conclusion_lines[0] = re.sub(r"def (\w+)", new_function_name, conclusion_lines[0])
|
94
|
-
|
95
|
-
|
96
|
-
return "\n".join(conclusion_lines)
|
138
|
+
func_call = f"{indent}return {new_function_name.replace('def ', '')}(case)\n"
|
139
|
+
return "\n".join(conclusion_lines).strip(' '), func_call
|
97
140
|
else:
|
98
141
|
raise ValueError(f"Conclusion is format is not valid, it should be a one line string or "
|
99
142
|
f"contain a function definition. Instead got:\n{conclusion}\n")
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.66
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -1,20 +1,20 @@
|
|
1
1
|
ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
ripple_down_rules/datasets.py,sha256=rCSpeFeu1gTuKESwjHUdQkPPvomI5OMRNGpbdKmHwMg,4639
|
3
|
-
ripple_down_rules/experts.py,sha256=
|
3
|
+
ripple_down_rules/experts.py,sha256=JGVvSNiWhm4FpRpg76f98tl8Ii_C7x_aWD9FxD-JDLQ,6130
|
4
4
|
ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
|
5
5
|
ripple_down_rules/helpers.py,sha256=TvTJU0BA3dPcAyzvZFvAu7jZqsp8Lu0HAAwvuizlGjg,2018
|
6
6
|
ripple_down_rules/prompt.py,sha256=cHqhMJqubGhfGpOOY_uXv5L7PBNb64O0IBWSfiY0ui0,6682
|
7
|
-
ripple_down_rules/rdr.py,sha256=
|
7
|
+
ripple_down_rules/rdr.py,sha256=iRoQcfXlCtAP9e_TxWY26ZZOv-Ki4MVRkmfPjlJ2vYY,41535
|
8
8
|
ripple_down_rules/rdr_decorators.py,sha256=8SclpceI3EtrsbuukWJu8HGLh7Q1ZCgYGLX-RPlG-w0,2018
|
9
|
-
ripple_down_rules/rules.py,sha256=
|
10
|
-
ripple_down_rules/utils.py,sha256=
|
9
|
+
ripple_down_rules/rules.py,sha256=Ms9uDDM7QaFrzzPB-nsPOWbIrT-QVqvhmHA4z8HQOgI,16547
|
10
|
+
ripple_down_rules/utils.py,sha256=A4ArFvF4lTsjicfN2fV2oIIo2HSapA9LXb92sY6n-Rs,34538
|
11
11
|
ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
|
12
12
|
ripple_down_rules/datastructures/callable_expression.py,sha256=TW_u6CJfelW2CiJj9pWFpdOBNIxeEuhhsQEz_pLpFVE,9092
|
13
13
|
ripple_down_rules/datastructures/case.py,sha256=uM5YkJOfYfARrZZ3oAxKMWE5QBSnvHLLZa9Atoxb7eY,13800
|
14
|
-
ripple_down_rules/datastructures/dataclasses.py,sha256=
|
14
|
+
ripple_down_rules/datastructures/dataclasses.py,sha256=V757VwxROlevXh5ZVFLVuzwY4JIJKG8ARlCfjhubfy8,5957
|
15
15
|
ripple_down_rules/datastructures/enums.py,sha256=RdyPUp9Ls1QuLmkcMMkBbCWrmXIZI4xWuM-cLPYZhR0,4666
|
16
|
-
ripple_down_rules-0.1.
|
17
|
-
ripple_down_rules-0.1.
|
18
|
-
ripple_down_rules-0.1.
|
19
|
-
ripple_down_rules-0.1.
|
20
|
-
ripple_down_rules-0.1.
|
16
|
+
ripple_down_rules-0.1.66.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
17
|
+
ripple_down_rules-0.1.66.dist-info/METADATA,sha256=PFdnyrcmNvbrVchvwNW0zl19IkCPhmsCAcpJ3VYO49Y,42576
|
18
|
+
ripple_down_rules-0.1.66.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
|
19
|
+
ripple_down_rules-0.1.66.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
20
|
+
ripple_down_rules-0.1.66.dist-info/RECORD,,
|
File without changes
|
File without changes
|