ripple-down-rules 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datasets.py +2 -1
- ripple_down_rules/datastructures/__init__.py +4 -4
- ripple_down_rules/datastructures/callable_expression.py +68 -128
- ripple_down_rules/datastructures/case.py +1 -1
- ripple_down_rules/datastructures/dataclasses.py +102 -48
- ripple_down_rules/experts.py +24 -22
- ripple_down_rules/prompt.py +44 -50
- ripple_down_rules/rdr.py +290 -153
- ripple_down_rules/rules.py +64 -32
- ripple_down_rules/utils.py +91 -2
- {ripple_down_rules-0.1.3.dist-info → ripple_down_rules-0.1.5.dist-info}/METADATA +1 -1
- ripple_down_rules-0.1.5.dist-info/RECORD +20 -0
- {ripple_down_rules-0.1.3.dist-info → ripple_down_rules-0.1.5.dist-info}/WHEEL +1 -1
- ripple_down_rules-0.1.3.dist-info/RECORD +0 -20
- {ripple_down_rules-0.1.3.dist-info → ripple_down_rules-0.1.5.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.1.3.dist-info → ripple_down_rules-0.1.5.dist-info}/top_level.txt +0 -0
ripple_down_rules/rdr.py
CHANGED
@@ -1,20 +1,28 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import importlib
|
4
|
+
import re
|
5
|
+
import sys
|
4
6
|
from abc import ABC, abstractmethod
|
5
7
|
from copy import copy
|
8
|
+
from dataclasses import is_dataclass
|
9
|
+
from io import TextIOWrapper
|
6
10
|
from types import ModuleType
|
7
11
|
|
8
12
|
from matplotlib import pyplot as plt
|
9
13
|
from ordered_set import OrderedSet
|
10
|
-
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
14
|
+
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
11
15
|
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
|
12
16
|
|
13
|
-
from .datastructures import Case,
|
17
|
+
from .datastructures.case import Case, CaseAttribute
|
18
|
+
from .datastructures.callable_expression import CallableExpression
|
19
|
+
from .datastructures.dataclasses import CaseQuery
|
20
|
+
from .datastructures.enums import MCRDRMode
|
14
21
|
from .experts import Expert, Human
|
15
22
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
16
23
|
from .utils import draw_tree, make_set, copy_case, \
|
17
|
-
get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string
|
24
|
+
get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
|
25
|
+
get_case_attribute_type, ask_llm
|
18
26
|
|
19
27
|
|
20
28
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -29,14 +37,16 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
29
37
|
"""
|
30
38
|
The conclusions that the expert has accepted, such that they are not asked again.
|
31
39
|
"""
|
40
|
+
_generated_python_file_name: Optional[str] = None
|
41
|
+
"""
|
42
|
+
The name of the generated python file.
|
43
|
+
"""
|
32
44
|
|
33
|
-
def __init__(self, start_rule: Optional[Rule] = None
|
45
|
+
def __init__(self, start_rule: Optional[Rule] = None):
|
34
46
|
"""
|
35
47
|
:param start_rule: The starting rule for the classifier.
|
36
|
-
:param session: The sqlalchemy orm session.
|
37
48
|
"""
|
38
49
|
self.start_rule = start_rule
|
39
|
-
self.session = session
|
40
50
|
self.fig: Optional[plt.Figure] = None
|
41
51
|
|
42
52
|
def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
|
@@ -79,31 +89,29 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
79
89
|
:param animate_tree: Whether to draw the tree while fitting the classifier.
|
80
90
|
:param kwargs_for_fit_case: The keyword arguments to pass to the fit_case method.
|
81
91
|
"""
|
82
|
-
|
83
|
-
targets = [{case_query.attribute_name: case_query.target} for case_query in case_queries]
|
92
|
+
targets = []
|
84
93
|
if animate_tree:
|
85
94
|
plt.ion()
|
86
95
|
i = 0
|
87
96
|
stop_iterating = False
|
88
97
|
num_rules: int = 0
|
89
98
|
while not stop_iterating:
|
90
|
-
all_pred = 0
|
91
|
-
if not targets:
|
92
|
-
targets = [None] * len(cases)
|
93
99
|
for case_query in case_queries:
|
94
|
-
target = {case_query.attribute_name: case_query.target}
|
95
100
|
pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
|
96
|
-
|
101
|
+
if case_query.target is None:
|
102
|
+
continue
|
103
|
+
target = {case_query.attribute_name: case_query.target(case_query.case)}
|
104
|
+
if len(targets) < len(case_queries):
|
105
|
+
targets.append(target)
|
106
|
+
match = self.is_matching(case_query, pred_cat)
|
97
107
|
if not match:
|
98
108
|
print(f"Predicted: {pred_cat} but expected: {target}")
|
99
|
-
all_pred += int(match)
|
100
109
|
if animate_tree and self.start_rule.size > num_rules:
|
101
110
|
num_rules = self.start_rule.size
|
102
111
|
self.update_figures()
|
103
112
|
i += 1
|
104
|
-
all_predictions = [1 if self.is_matching(
|
105
|
-
|
106
|
-
for case_query in case_queries]
|
113
|
+
all_predictions = [1 if self.is_matching(case_query) else 0 for case_query in case_queries
|
114
|
+
if case_query.target is not None]
|
107
115
|
all_pred = sum(all_predictions)
|
108
116
|
print(f"Accuracy: {all_pred}/{len(targets)}")
|
109
117
|
all_predicted = targets and all_pred == len(targets)
|
@@ -116,48 +124,43 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
116
124
|
plt.ioff()
|
117
125
|
plt.show()
|
118
126
|
|
127
|
+
def is_matching(self, case_query: CaseQuery, pred_cat: Optional[Dict[str, Any]] = None) -> bool:
|
128
|
+
"""
|
129
|
+
:param case_query: The case query to check.
|
130
|
+
:param pred_cat: The predicted category.
|
131
|
+
:return: Whether the classifier prediction is matching case_query target or not.
|
132
|
+
"""
|
133
|
+
if case_query.target is None:
|
134
|
+
return False
|
135
|
+
if pred_cat is None:
|
136
|
+
pred_cat = self.classify(case_query.case)
|
137
|
+
if not isinstance(pred_cat, dict):
|
138
|
+
pred_cat = {case_query.attribute_name: pred_cat}
|
139
|
+
target = {case_query.attribute_name: case_query.target_value}
|
140
|
+
precision, recall = self.calculate_precision_and_recall(pred_cat, target)
|
141
|
+
return all(recall) and all(precision)
|
142
|
+
|
119
143
|
@staticmethod
|
120
|
-
def calculate_precision_and_recall(pred_cat:
|
144
|
+
def calculate_precision_and_recall(pred_cat: Dict[str, Any], target: Dict[str, Any]) -> Tuple[
|
121
145
|
List[bool], List[bool]]:
|
122
146
|
"""
|
123
147
|
:param pred_cat: The predicted category.
|
124
148
|
:param target: The target category.
|
125
149
|
:return: The precision and recall of the classifier.
|
126
150
|
"""
|
127
|
-
pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
|
128
|
-
target = target if is_iterable(target) else [target]
|
129
151
|
recall = []
|
130
152
|
precision = []
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
if is_iterable(target_value):
|
141
|
-
recall.extend([v in pred_cat[target_key] for v in target_value])
|
142
|
-
else:
|
143
|
-
recall.append(target_value == pred_cat[target_key])
|
144
|
-
else:
|
145
|
-
if isinstance(target, dict):
|
146
|
-
target = list(target.values())
|
147
|
-
recall = [not yi or (yi in pred_cat) for yi in target]
|
148
|
-
target_types = [type(yi) for yi in target]
|
149
|
-
precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
|
153
|
+
for pred_key, pred_value in pred_cat.items():
|
154
|
+
if pred_key not in target:
|
155
|
+
continue
|
156
|
+
precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
|
157
|
+
for target_key, target_value in target.items():
|
158
|
+
if target_key not in pred_cat:
|
159
|
+
recall.append(False)
|
160
|
+
continue
|
161
|
+
recall.extend([v in make_set(pred_cat[target_key]) for v in make_set(target_value)])
|
150
162
|
return precision, recall
|
151
163
|
|
152
|
-
def is_matching(self, pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> bool:
|
153
|
-
"""
|
154
|
-
:param pred_cat: The predicted category.
|
155
|
-
:param target: The target category.
|
156
|
-
:return: Whether the classifier is matching or not.
|
157
|
-
"""
|
158
|
-
precision, recall = self.calculate_precision_and_recall(pred_cat, target)
|
159
|
-
return all(recall) and all(precision)
|
160
|
-
|
161
164
|
def update_figures(self):
|
162
165
|
"""
|
163
166
|
Update the figures of the classifier.
|
@@ -183,33 +186,51 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
183
186
|
"""
|
184
187
|
return hasattr(case, conclusion_name) and getattr(case, conclusion_name) is not None
|
185
188
|
|
189
|
+
@property
|
190
|
+
def type_(self):
|
191
|
+
return self.__class__
|
192
|
+
|
186
193
|
|
187
194
|
class RDRWithCodeWriter(RippleDownRules, ABC):
|
188
195
|
|
189
196
|
@abstractmethod
|
190
|
-
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = ""
|
197
|
+
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
|
198
|
+
defs_file: Optional[str] = None):
|
191
199
|
"""
|
192
200
|
Write the rules as source code to a file.
|
193
201
|
|
194
202
|
:param rule: The rule to write as source code.
|
195
203
|
:param file: The file to write the source code to.
|
196
204
|
:param parent_indent: The indentation of the parent rule.
|
205
|
+
:param defs_file: The file to write the definitions to.
|
197
206
|
"""
|
198
207
|
pass
|
199
208
|
|
200
|
-
def write_to_python_file(self, file_path: str):
|
209
|
+
def write_to_python_file(self, file_path: str, postfix: str = ""):
|
201
210
|
"""
|
202
211
|
Write the tree of rules as source code to a file.
|
203
212
|
|
204
213
|
:param file_path: The path to the file to write the source code to.
|
214
|
+
:param postfix: The postfix to add to the file name.
|
205
215
|
"""
|
216
|
+
self.generated_python_file_name = self._default_generated_python_file_name + postfix
|
206
217
|
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
207
|
-
|
208
|
-
|
218
|
+
file_name = file_path + f"/{self.generated_python_file_name}.py"
|
219
|
+
defs_file_name = file_path + f"/{self.generated_python_defs_file_name}.py"
|
220
|
+
imports = self._get_imports()
|
221
|
+
# clear the files first
|
222
|
+
with open(defs_file_name, "w") as f:
|
223
|
+
f.write(imports + "\n\n")
|
224
|
+
with open(file_name, "w") as f:
|
225
|
+
imports += f"from .{self.generated_python_defs_file_name} import *\n"
|
226
|
+
imports += f"from ripple_down_rules.rdr import {self.__class__.__name__}\n"
|
227
|
+
f.write(imports + "\n\n")
|
228
|
+
f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n\n")
|
229
|
+
f.write(f"type_ = {self.__class__.__name__}\n\n")
|
209
230
|
f.write(func_def)
|
210
231
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
211
232
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
212
|
-
self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)
|
233
|
+
self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4, defs_file=defs_file_name)
|
213
234
|
|
214
235
|
@property
|
215
236
|
@abstractmethod
|
@@ -226,26 +247,73 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
226
247
|
imports = ""
|
227
248
|
if self.case_type.__module__ != "builtins":
|
228
249
|
imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
|
229
|
-
|
230
|
-
|
231
|
-
|
250
|
+
for conclusion_type in self.conclusion_type:
|
251
|
+
if conclusion_type.__module__ != "builtins":
|
252
|
+
imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
|
253
|
+
imports += "from ripple_down_rules.datastructures.case import Case, create_case\n"
|
232
254
|
for rule in [self.start_rule] + list(self.start_rule.descendants):
|
233
|
-
if rule.conditions:
|
234
|
-
|
235
|
-
|
236
|
-
|
255
|
+
if not rule.conditions:
|
256
|
+
continue
|
257
|
+
if rule.conditions.scope is None or len(rule.conditions.scope) == 0:
|
258
|
+
continue
|
259
|
+
for k, v in rule.conditions.scope.items():
|
260
|
+
new_imports = f"from {v.__module__} import {v.__name__}\n"
|
261
|
+
if new_imports in imports:
|
262
|
+
continue
|
263
|
+
imports += new_imports
|
237
264
|
return imports
|
238
265
|
|
239
|
-
def get_rdr_classifier_from_python_file(self, package_name) -> Callable[[Any], Any]:
|
266
|
+
def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
|
240
267
|
"""
|
241
268
|
:param package_name: The name of the package that contains the RDR classifier function.
|
242
269
|
:return: The module that contains the rdr classifier function.
|
243
270
|
"""
|
244
|
-
|
271
|
+
# remove from imports if exists first
|
272
|
+
name = f"{package_name.strip('./')}.{self.generated_python_file_name}"
|
273
|
+
try:
|
274
|
+
module = importlib.import_module(name)
|
275
|
+
del sys.modules[name]
|
276
|
+
except ModuleNotFoundError:
|
277
|
+
pass
|
278
|
+
return importlib.import_module(name).classify
|
245
279
|
|
246
280
|
@property
|
247
281
|
def generated_python_file_name(self) -> str:
|
248
|
-
|
282
|
+
if self._generated_python_file_name is None:
|
283
|
+
self._generated_python_file_name = self._default_generated_python_file_name
|
284
|
+
return self._generated_python_file_name
|
285
|
+
|
286
|
+
@generated_python_file_name.setter
|
287
|
+
def generated_python_file_name(self, value: str):
|
288
|
+
"""
|
289
|
+
Set the generated python file name.
|
290
|
+
:param value: The new value for the generated python file name.
|
291
|
+
"""
|
292
|
+
self._generated_python_file_name = value
|
293
|
+
|
294
|
+
@property
|
295
|
+
def _default_generated_python_file_name(self) -> str:
|
296
|
+
"""
|
297
|
+
:return: The default generated python file name.
|
298
|
+
"""
|
299
|
+
return f"{self.start_rule.corner_case._name.lower()}_{self.attribute_name}_{self.acronym.lower()}"
|
300
|
+
|
301
|
+
@property
|
302
|
+
def generated_python_defs_file_name(self) -> str:
|
303
|
+
return f"{self.generated_python_file_name}_defs"
|
304
|
+
|
305
|
+
@property
|
306
|
+
def acronym(self) -> str:
|
307
|
+
"""
|
308
|
+
:return: The acronym of the classifier.
|
309
|
+
"""
|
310
|
+
if self.__class__.__name__ == "GeneralRDR":
|
311
|
+
return "GRDR"
|
312
|
+
elif self.__class__.__name__ == "MultiClassRDR":
|
313
|
+
return "MCRDR"
|
314
|
+
else:
|
315
|
+
return "SCRDR"
|
316
|
+
|
249
317
|
|
250
318
|
@property
|
251
319
|
def case_type(self) -> Type:
|
@@ -258,16 +326,17 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
258
326
|
return type(self.start_rule.corner_case)
|
259
327
|
|
260
328
|
@property
|
261
|
-
def conclusion_type(self) -> Type:
|
329
|
+
def conclusion_type(self) -> Tuple[Type]:
|
262
330
|
"""
|
263
331
|
:return: The type of the conclusion of the RDR classifier.
|
264
332
|
"""
|
265
333
|
if isinstance(self.start_rule.conclusion, CallableExpression):
|
266
334
|
return self.start_rule.conclusion.conclusion_type
|
267
335
|
else:
|
268
|
-
|
269
|
-
|
270
|
-
|
336
|
+
conclusion = self.start_rule.conclusion
|
337
|
+
if isinstance(conclusion, set):
|
338
|
+
return type(list(conclusion)[0]), set
|
339
|
+
return (type(conclusion),)
|
271
340
|
|
272
341
|
@property
|
273
342
|
def attribute_name(self) -> str:
|
@@ -279,8 +348,16 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
279
348
|
|
280
349
|
class SingleClassRDR(RDRWithCodeWriter):
|
281
350
|
|
351
|
+
def __init__(self, start_rule: Optional[SingleClassRule] = None, default_conclusion: Optional[Any] = None):
|
352
|
+
"""
|
353
|
+
:param start_rule: The starting rule for the classifier.
|
354
|
+
:param default_conclusion: The default conclusion for the classifier if no rules fire.
|
355
|
+
"""
|
356
|
+
super(SingleClassRDR, self).__init__(start_rule)
|
357
|
+
self.default_conclusion: Optional[Any] = default_conclusion
|
358
|
+
|
282
359
|
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
283
|
-
-> Union[CaseAttribute, CallableExpression]:
|
360
|
+
-> Union[CaseAttribute, CallableExpression, None]:
|
284
361
|
"""
|
285
362
|
Classify a case, and ask the user for refinements or alternatives if the classification is incorrect by
|
286
363
|
comparing the case with the target category if provided.
|
@@ -289,28 +366,31 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
289
366
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
290
367
|
:return: The category that the case belongs to.
|
291
368
|
"""
|
292
|
-
expert = expert if expert else Human(
|
293
|
-
if case_query.
|
294
|
-
|
369
|
+
expert = expert if expert else Human()
|
370
|
+
if case_query.default_value is not None and self.default_conclusion != case_query.default_value:
|
371
|
+
self.default_conclusion = case_query.default_value
|
372
|
+
case = case_query.case
|
373
|
+
target = expert.ask_for_conclusion(case_query) if case_query.target is None else case_query.target
|
374
|
+
if target is None:
|
375
|
+
return self.classify(case)
|
295
376
|
if not self.start_rule:
|
296
377
|
conditions = expert.ask_for_conditions(case_query)
|
297
|
-
self.start_rule = SingleClassRule(conditions,
|
378
|
+
self.start_rule = SingleClassRule(conditions, target, corner_case=case,
|
298
379
|
conclusion_name=case_query.attribute_name)
|
299
380
|
|
300
381
|
pred = self.evaluate(case_query.case)
|
301
|
-
|
302
|
-
if pred.conclusion != case_query.target:
|
382
|
+
if pred.conclusion(case) != target(case):
|
303
383
|
conditions = expert.ask_for_conditions(case_query, pred)
|
304
|
-
pred.fit_rule(case_query.case,
|
384
|
+
pred.fit_rule(case_query.case, target, conditions=conditions)
|
305
385
|
|
306
386
|
return self.classify(case_query.case)
|
307
387
|
|
308
|
-
def classify(self, case: Case) -> Optional[
|
388
|
+
def classify(self, case: Case) -> Optional[Any]:
|
309
389
|
"""
|
310
390
|
Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
|
311
391
|
"""
|
312
392
|
pred = self.evaluate(case)
|
313
|
-
return pred.conclusion if pred.fired else
|
393
|
+
return pred.conclusion(case) if pred.fired else self.default_conclusion
|
314
394
|
|
315
395
|
def evaluate(self, case: Case) -> SingleClassRule:
|
316
396
|
"""
|
@@ -319,23 +399,27 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
319
399
|
matched_rule = self.start_rule(case)
|
320
400
|
return matched_rule if matched_rule else self.start_rule
|
321
401
|
|
322
|
-
def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file, parent_indent: str = ""
|
402
|
+
def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
|
403
|
+
defs_file: Optional[str] = None):
|
323
404
|
"""
|
324
405
|
Write the rules as source code to a file.
|
325
406
|
"""
|
326
407
|
if rule.conditions:
|
327
|
-
|
408
|
+
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
409
|
+
file.write(if_clause)
|
328
410
|
if rule.refinement:
|
329
|
-
|
411
|
+
|
412
|
+
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
|
413
|
+
defs_file=defs_file)
|
330
414
|
|
331
415
|
file.write(rule.write_conclusion_as_source_code(parent_indent))
|
332
416
|
|
333
417
|
if rule.alternative:
|
334
|
-
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
|
418
|
+
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
|
335
419
|
|
336
420
|
@property
|
337
421
|
def conclusion_type_hint(self) -> str:
|
338
|
-
return self.conclusion_type.__name__
|
422
|
+
return self.conclusion_type[0].__name__
|
339
423
|
|
340
424
|
def _to_json(self) -> Dict[str, Any]:
|
341
425
|
return {"start_rule": self.start_rule.to_json()}
|
@@ -369,28 +453,27 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
369
453
|
"""
|
370
454
|
|
371
455
|
def __init__(self, start_rule: Optional[Rule] = None,
|
372
|
-
mode: MCRDRMode = MCRDRMode.StopOnly
|
456
|
+
mode: MCRDRMode = MCRDRMode.StopOnly):
|
373
457
|
"""
|
374
458
|
:param start_rule: The starting rules for the classifier.
|
375
459
|
:param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
|
376
|
-
:param session: The sqlalchemy orm session.
|
377
460
|
"""
|
378
461
|
start_rule = MultiClassTopRule() if not start_rule else start_rule
|
379
|
-
super(MultiClassRDR, self).__init__(start_rule
|
462
|
+
super(MultiClassRDR, self).__init__(start_rule)
|
380
463
|
self.mode: MCRDRMode = mode
|
381
464
|
|
382
|
-
def classify(self, case: Union[Case, SQLTable]) ->
|
465
|
+
def classify(self, case: Union[Case, SQLTable]) -> Set[Any]:
|
383
466
|
evaluated_rule = self.start_rule
|
384
467
|
self.conclusions = []
|
385
468
|
while evaluated_rule:
|
386
469
|
next_rule = evaluated_rule(case)
|
387
470
|
if evaluated_rule.fired:
|
388
|
-
self.add_conclusion(evaluated_rule)
|
471
|
+
self.add_conclusion(evaluated_rule, case)
|
389
472
|
evaluated_rule = next_rule
|
390
|
-
return self.conclusions
|
473
|
+
return make_set(self.conclusions)
|
391
474
|
|
392
475
|
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None,
|
393
|
-
add_extra_conclusions: bool = False) ->
|
476
|
+
add_extra_conclusions: bool = False) -> Set[Union[CaseAttribute, CallableExpression, None]]:
|
394
477
|
"""
|
395
478
|
Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
|
396
479
|
or missing by comparing the case with the target category if provided.
|
@@ -400,32 +483,35 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
400
483
|
:param add_extra_conclusions: Whether to add extra conclusions after classification is done.
|
401
484
|
:return: The conclusions that the case belongs to.
|
402
485
|
"""
|
403
|
-
expert = expert if expert else Human(
|
486
|
+
expert = expert if expert else Human()
|
487
|
+
if case_query.target is None:
|
488
|
+
expert.ask_for_conclusion(case_query)
|
404
489
|
if case_query.target is None:
|
405
|
-
|
490
|
+
return self.classify(case_query.case)
|
491
|
+
self.update_start_rule(case_query, expert)
|
406
492
|
self.expert_accepted_conclusions = []
|
407
493
|
user_conclusions = []
|
408
|
-
self.update_start_rule(case_query, expert)
|
409
494
|
self.conclusions = []
|
410
495
|
self.stop_rule_conditions = None
|
411
496
|
evaluated_rule = self.start_rule
|
497
|
+
target = case_query.target(case_query.case)
|
412
498
|
while evaluated_rule:
|
413
499
|
next_rule = evaluated_rule(case_query.case)
|
414
|
-
|
500
|
+
rule_conclusion = evaluated_rule.conclusion(case_query.case)
|
501
|
+
good_conclusions = make_list(target) + user_conclusions + self.expert_accepted_conclusions
|
415
502
|
good_conclusions = make_set(good_conclusions)
|
416
503
|
|
417
504
|
if evaluated_rule.fired:
|
418
|
-
if
|
419
|
-
# if self.case_has_conclusion(case, evaluated_rule.conclusion):
|
505
|
+
if target and not make_set(rule_conclusion).issubset(good_conclusions):
|
420
506
|
# Rule fired and conclusion is different from target
|
421
507
|
self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
|
422
508
|
add_extra_conclusions)
|
423
509
|
else:
|
424
510
|
# Rule fired and target is correct or there is no target to compare
|
425
|
-
self.add_conclusion(evaluated_rule)
|
511
|
+
self.add_conclusion(evaluated_rule, case_query.case)
|
426
512
|
|
427
513
|
if not next_rule:
|
428
|
-
if not make_set(
|
514
|
+
if not make_set(target).issubset(make_set(self.conclusions)):
|
429
515
|
# Nothing fired and there is a target that should have been in the conclusions
|
430
516
|
self.add_rule_for_case(case_query, expert)
|
431
517
|
# Have to check all rules again to make sure only this new rule fires
|
@@ -439,33 +525,31 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
439
525
|
return self.conclusions
|
440
526
|
|
441
527
|
def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
|
442
|
-
file, parent_indent: str = ""):
|
443
|
-
"""
|
444
|
-
Write the rules as source code to a file.
|
445
|
-
|
446
|
-
:
|
447
|
-
"""
|
528
|
+
file, parent_indent: str = "", defs_file: Optional[str] = None):
|
448
529
|
if rule == self.start_rule:
|
449
530
|
file.write(f"{parent_indent}conclusions = set()\n")
|
450
531
|
if rule.conditions:
|
451
|
-
|
532
|
+
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
533
|
+
file.write(if_clause)
|
452
534
|
conclusion_indent = parent_indent
|
453
535
|
if hasattr(rule, "refinement") and rule.refinement:
|
454
|
-
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " "
|
536
|
+
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
|
537
|
+
defs_file=defs_file)
|
455
538
|
conclusion_indent = parent_indent + " " * 4
|
456
539
|
file.write(f"{conclusion_indent}else:\n")
|
457
540
|
file.write(rule.write_conclusion_as_source_code(conclusion_indent))
|
458
541
|
|
459
542
|
if rule.alternative:
|
460
|
-
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
|
543
|
+
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
|
461
544
|
|
462
545
|
@property
|
463
546
|
def conclusion_type_hint(self) -> str:
|
464
|
-
return f"Set[{self.conclusion_type.__name__}]"
|
547
|
+
return f"Set[{self.conclusion_type[0].__name__}]"
|
465
548
|
|
466
549
|
def _get_imports(self) -> str:
|
467
550
|
imports = super()._get_imports()
|
468
551
|
imports += "from typing_extensions import Set\n"
|
552
|
+
imports += "from ripple_down_rules.utils import make_set\n"
|
469
553
|
return imports
|
470
554
|
|
471
555
|
def update_start_rule(self, case_query: CaseQuery, expert: Expert):
|
@@ -498,8 +582,10 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
498
582
|
"""
|
499
583
|
Stop a wrong conclusion by adding a stopping rule.
|
500
584
|
"""
|
501
|
-
|
502
|
-
|
585
|
+
target = case_query.target(case_query.case)
|
586
|
+
rule_conclusion = evaluated_rule.conclusion(case_query.case)
|
587
|
+
if self.is_same_category_type(rule_conclusion, target) \
|
588
|
+
and self.is_conflicting_with_target(rule_conclusion, target):
|
503
589
|
self.stop_conclusion(case_query, expert, evaluated_rule)
|
504
590
|
elif not self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
|
505
591
|
self.stop_conclusion(case_query, expert, evaluated_rule)
|
@@ -559,10 +645,11 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
559
645
|
:return: Whether the conclusion is correct or not.
|
560
646
|
"""
|
561
647
|
conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
|
562
|
-
if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
|
563
|
-
|
648
|
+
if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
|
649
|
+
evaluated_rule.conclusion(case_query.case),
|
650
|
+
targets=case_query.target(case_query.case),
|
564
651
|
current_conclusions=conclusions)):
|
565
|
-
self.add_conclusion(evaluated_rule)
|
652
|
+
self.add_conclusion(evaluated_rule, case_query.case)
|
566
653
|
self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
|
567
654
|
return True
|
568
655
|
return False
|
@@ -600,19 +687,21 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
600
687
|
extra_conclusions.append(conclusion)
|
601
688
|
return extra_conclusions
|
602
689
|
|
603
|
-
def add_conclusion(self, evaluated_rule: Rule) -> None:
|
690
|
+
def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
|
604
691
|
"""
|
605
692
|
Add the conclusion of the evaluated rule to the list of conclusions.
|
606
693
|
|
607
694
|
:param evaluated_rule: The evaluated rule to add the conclusion of.
|
695
|
+
:param case: The case to add the conclusion for.
|
608
696
|
"""
|
609
697
|
conclusion_types = [type(c) for c in self.conclusions]
|
610
|
-
|
611
|
-
|
698
|
+
rule_conclusion = evaluated_rule.conclusion(case)
|
699
|
+
if type(rule_conclusion) not in conclusion_types:
|
700
|
+
self.conclusions.extend(make_list(rule_conclusion))
|
612
701
|
else:
|
613
|
-
same_type_conclusions = [c for c in self.conclusions if type(c) == type(
|
614
|
-
combined_conclusion =
|
615
|
-
else {
|
702
|
+
same_type_conclusions = [c for c in self.conclusions if type(c) == type(rule_conclusion)]
|
703
|
+
combined_conclusion = rule_conclusion if isinstance(rule_conclusion, set) \
|
704
|
+
else {rule_conclusion}
|
616
705
|
combined_conclusion = copy(combined_conclusion)
|
617
706
|
for c in same_type_conclusions:
|
618
707
|
combined_conclusion.update(c if isinstance(c, set) else make_set(c))
|
@@ -720,22 +809,24 @@ class GeneralRDR(RippleDownRules):
|
|
720
809
|
pred_atts = rdr.classify(case_cp)
|
721
810
|
if pred_atts is None:
|
722
811
|
continue
|
723
|
-
if
|
812
|
+
if rdr.type_ is SingleClassRDR:
|
724
813
|
if attribute_name not in conclusions or \
|
725
814
|
(attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
|
726
815
|
conclusions[attribute_name] = pred_atts
|
727
816
|
new_conclusions[attribute_name] = pred_atts
|
728
817
|
else:
|
729
|
-
pred_atts =
|
818
|
+
pred_atts = make_set(pred_atts)
|
730
819
|
if attribute_name in conclusions:
|
731
|
-
pred_atts =
|
820
|
+
pred_atts = {p for p in pred_atts if p not in conclusions[attribute_name]}
|
732
821
|
if len(pred_atts) > 0:
|
733
822
|
new_conclusions[attribute_name] = pred_atts
|
734
823
|
if attribute_name not in conclusions:
|
735
|
-
conclusions[attribute_name] =
|
736
|
-
conclusions[attribute_name].
|
824
|
+
conclusions[attribute_name] = set()
|
825
|
+
conclusions[attribute_name].update(pred_atts)
|
737
826
|
if attribute_name in new_conclusions:
|
738
|
-
|
827
|
+
mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
|
828
|
+
GeneralRDR.update_case(CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive),
|
829
|
+
new_conclusions)
|
739
830
|
if len(new_conclusions) == 0:
|
740
831
|
break
|
741
832
|
return conclusions
|
@@ -761,76 +852,98 @@ class GeneralRDR(RippleDownRules):
|
|
761
852
|
case = case_queries[0].case
|
762
853
|
assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
|
763
854
|
" for multiple cases use fit instead")
|
764
|
-
|
855
|
+
original_case_query_cp = copy(case_queries[0])
|
765
856
|
for case_query in case_queries:
|
766
857
|
case_query_cp = copy(case_query)
|
767
|
-
case_query_cp.case =
|
768
|
-
if
|
858
|
+
case_query_cp.case = original_case_query_cp.case
|
859
|
+
if case_query_cp.target is None:
|
769
860
|
conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
|
770
|
-
|
861
|
+
self.update_case(case_query_cp, conclusions)
|
862
|
+
expert.ask_for_conclusion(case_query_cp)
|
863
|
+
if case_query_cp.target is None:
|
864
|
+
continue
|
865
|
+
case_query.target = case_query_cp.target
|
771
866
|
|
772
867
|
if case_query.attribute_name not in self.start_rules_dict:
|
773
868
|
conclusions = self.classify(case)
|
774
|
-
self.update_case(
|
869
|
+
self.update_case(case_query_cp, conclusions)
|
775
870
|
|
776
|
-
new_rdr = self.initialize_new_rdr_for_attribute(
|
871
|
+
new_rdr = self.initialize_new_rdr_for_attribute(case_query_cp)
|
777
872
|
self.add_rdr(new_rdr, case_query.attribute_name)
|
778
873
|
|
779
874
|
new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
|
780
|
-
self.update_case(
|
875
|
+
self.update_case(case_query_cp, {case_query.attribute_name: new_conclusions})
|
781
876
|
else:
|
782
877
|
for rdr_attribute_name, rdr in self.start_rules_dict.items():
|
783
878
|
if case_query.attribute_name != rdr_attribute_name:
|
784
|
-
conclusions = rdr.classify(
|
879
|
+
conclusions = rdr.classify(case_query_cp.case)
|
785
880
|
else:
|
786
881
|
conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
|
787
882
|
**kwargs)
|
788
883
|
if conclusions is not None or (is_iterable(conclusions) and len(conclusions) > 0):
|
789
884
|
conclusions = {rdr_attribute_name: conclusions}
|
790
|
-
|
885
|
+
case_query_cp.mutually_exclusive = True if isinstance(rdr, SingleClassRDR) else False
|
886
|
+
self.update_case(case_query_cp, conclusions)
|
887
|
+
case_query.conditions = case_query_cp.conditions
|
791
888
|
|
792
889
|
return self.classify(case)
|
793
890
|
|
794
891
|
@staticmethod
|
795
|
-
def initialize_new_rdr_for_attribute(
|
892
|
+
def initialize_new_rdr_for_attribute(case_query: CaseQuery):
|
796
893
|
"""
|
797
894
|
Initialize the appropriate RDR type for the target.
|
798
895
|
"""
|
799
|
-
|
896
|
+
if case_query.mutually_exclusive is not None:
|
897
|
+
return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive\
|
898
|
+
else MultiClassRDR()
|
899
|
+
if case_query.attribute_type in [list, set]:
|
900
|
+
return MultiClassRDR()
|
901
|
+
attribute = getattr(case_query.case, case_query.attribute_name)\
|
902
|
+
if hasattr(case_query.case, case_query.attribute_name) else case_query.target(case_query.case)
|
800
903
|
if isinstance(attribute, CaseAttribute):
|
801
|
-
return SingleClassRDR() if attribute.mutually_exclusive
|
904
|
+
return SingleClassRDR(default_conclusion=case_query.default_value) if attribute.mutually_exclusive \
|
905
|
+
else MultiClassRDR()
|
802
906
|
else:
|
803
|
-
return MultiClassRDR() if is_iterable(attribute) or (attribute is None)
|
907
|
+
return MultiClassRDR() if is_iterable(attribute) or (attribute is None)\
|
908
|
+
else SingleClassRDR(default_conclusion=case_query.default_value)
|
804
909
|
|
805
910
|
@staticmethod
|
806
|
-
def update_case(
|
911
|
+
def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
|
807
912
|
"""
|
808
913
|
Update the case with the conclusions.
|
809
914
|
|
810
|
-
:param
|
915
|
+
:param case_query: The case query that contains the case to update.
|
811
916
|
:param conclusions: The conclusions to update the case with.
|
812
917
|
"""
|
813
918
|
if not conclusions:
|
814
919
|
return
|
815
920
|
if len(conclusions) == 0:
|
816
921
|
return
|
817
|
-
if isinstance(
|
922
|
+
if isinstance(case_query.original_case, SQLTable) or is_dataclass(case_query.original_case):
|
818
923
|
for conclusion_name, conclusion in conclusions.items():
|
819
|
-
|
820
|
-
|
821
|
-
|
924
|
+
attribute = getattr(case_query.case, conclusion_name)
|
925
|
+
if conclusion_name == case_query.attribute_name:
|
926
|
+
attribute_type = case_query.attribute_type
|
927
|
+
else:
|
928
|
+
attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
|
929
|
+
if isinstance(attribute, set) or any(at in {Set, set} for at in attribute_type):
|
822
930
|
attribute = set() if attribute is None else attribute
|
823
931
|
for c in conclusion:
|
824
932
|
attribute.update(make_set(c))
|
825
|
-
elif isinstance(attribute, list) or
|
933
|
+
elif isinstance(attribute, list) or any(at in {List, list} for at in attribute_type):
|
826
934
|
attribute = [] if attribute is None else attribute
|
827
935
|
attribute.extend(conclusion)
|
828
|
-
elif
|
829
|
-
|
936
|
+
elif is_iterable(conclusion) and len(conclusion) == 1 \
|
937
|
+
and any(at is type(list(conclusion)[0]) for at in attribute_type):
|
938
|
+
setattr(case_query.case, conclusion_name, list(conclusion)[0])
|
939
|
+
elif not is_iterable(conclusion) and any(at is type(conclusion) for at in attribute_type):
|
940
|
+
setattr(case_query.case, conclusion_name, conclusion)
|
830
941
|
else:
|
831
|
-
raise ValueError(f"
|
942
|
+
raise ValueError(f"Unknown type or type mismatch for attribute {conclusion_name} with type "
|
943
|
+
f"{case_query.attribute_type} with conclusion "
|
944
|
+
f"{conclusion} of type {type(conclusion)}")
|
832
945
|
else:
|
833
|
-
case.update(conclusions)
|
946
|
+
case_query.case.update(conclusions)
|
834
947
|
|
835
948
|
def _to_json(self) -> Dict[str, Any]:
|
836
949
|
return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
|
@@ -845,14 +958,16 @@ class GeneralRDR(RippleDownRules):
|
|
845
958
|
start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
|
846
959
|
return cls(start_rules_dict)
|
847
960
|
|
848
|
-
def write_to_python_file(self, file_path: str):
|
961
|
+
def write_to_python_file(self, file_path: str, postfix: str = "") -> None:
|
849
962
|
"""
|
850
963
|
Write the tree of rules as source code to a file.
|
851
964
|
|
852
965
|
:param file_path: The path to the file to write the source code to.
|
966
|
+
:param postfix: The postfix to add to the file name.
|
853
967
|
"""
|
968
|
+
self.generated_python_file_name = self._default_generated_python_file_name + postfix
|
854
969
|
for rdr in self.start_rules_dict.values():
|
855
|
-
rdr.write_to_python_file(file_path)
|
970
|
+
rdr.write_to_python_file(file_path, postfix=f"_of_grdr{postfix}")
|
856
971
|
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
857
972
|
with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
|
858
973
|
f.write(self._get_imports(file_path) + "\n\n")
|
@@ -875,15 +990,29 @@ class GeneralRDR(RippleDownRules):
|
|
875
990
|
else:
|
876
991
|
return type(self.start_rule.corner_case)
|
877
992
|
|
878
|
-
def get_rdr_classifier_from_python_file(self, file_path: str):
|
993
|
+
def get_rdr_classifier_from_python_file(self, file_path: str) -> Callable[[Any], Any]:
|
879
994
|
"""
|
880
995
|
:param file_path: The path to the file that contains the RDR classifier function.
|
996
|
+
:param postfix: The postfix to add to the file name.
|
881
997
|
:return: The module that contains the rdr classifier function.
|
882
998
|
"""
|
883
999
|
return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
|
884
1000
|
|
885
1001
|
@property
|
886
1002
|
def generated_python_file_name(self) -> str:
|
1003
|
+
if self._generated_python_file_name is None:
|
1004
|
+
self._generated_python_file_name = self._default_generated_python_file_name
|
1005
|
+
return self._generated_python_file_name
|
1006
|
+
|
1007
|
+
@generated_python_file_name.setter
|
1008
|
+
def generated_python_file_name(self, value: str):
|
1009
|
+
self._generated_python_file_name = value
|
1010
|
+
|
1011
|
+
@property
|
1012
|
+
def _default_generated_python_file_name(self) -> str:
|
1013
|
+
"""
|
1014
|
+
:return: The default generated python file name.
|
1015
|
+
"""
|
887
1016
|
return f"{self.start_rule.corner_case._name.lower()}_rdr"
|
888
1017
|
|
889
1018
|
@property
|
@@ -891,17 +1020,25 @@ class GeneralRDR(RippleDownRules):
|
|
891
1020
|
return f"List[Union[{', '.join([rdr.conclusion_type_hint for rdr in self.start_rules_dict.values()])}]]"
|
892
1021
|
|
893
1022
|
def _get_imports(self, file_path: str) -> str:
|
1023
|
+
"""
|
1024
|
+
Get the imports needed for the generated python file.
|
1025
|
+
|
1026
|
+
:param file_path: The path to the file that contains the RDR classifier function.
|
1027
|
+
:return: The imports needed for the generated python file.
|
1028
|
+
"""
|
894
1029
|
imports = ""
|
895
1030
|
# add type hints
|
896
1031
|
imports += f"from typing_extensions import List, Union, Set\n"
|
897
1032
|
# import rdr type
|
898
1033
|
imports += f"from ripple_down_rules.rdr import GeneralRDR\n"
|
899
1034
|
# add case type
|
900
|
-
imports += f"from ripple_down_rules.datastructures import Case, create_case\n"
|
1035
|
+
imports += f"from ripple_down_rules.datastructures.case import Case, create_case\n"
|
901
1036
|
imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
|
902
1037
|
# add conclusion type imports
|
903
1038
|
for rdr in self.start_rules_dict.values():
|
904
|
-
|
1039
|
+
for conclusion_type in rdr.conclusion_type:
|
1040
|
+
if conclusion_type.__module__ != "builtins":
|
1041
|
+
imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
|
905
1042
|
# add rdr python generated functions.
|
906
1043
|
for rdr_key, rdr in self.start_rules_dict.items():
|
907
1044
|
imports += (f"from {file_path.strip('./')}"
|