ripple-down-rules 0.1.3__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datasets.py +2 -1
- ripple_down_rules/datastructures/__init__.py +4 -4
- ripple_down_rules/datastructures/callable_expression.py +68 -128
- ripple_down_rules/datastructures/case.py +1 -1
- ripple_down_rules/datastructures/dataclasses.py +101 -48
- ripple_down_rules/experts.py +24 -22
- ripple_down_rules/prompt.py +44 -50
- ripple_down_rules/rdr.py +270 -164
- ripple_down_rules/rules.py +64 -32
- ripple_down_rules/utils.py +130 -2
- {ripple_down_rules-0.1.3.dist-info → ripple_down_rules-0.1.6.dist-info}/METADATA +1 -1
- ripple_down_rules-0.1.6.dist-info/RECORD +20 -0
- {ripple_down_rules-0.1.3.dist-info → ripple_down_rules-0.1.6.dist-info}/WHEEL +1 -1
- ripple_down_rules-0.1.3.dist-info/RECORD +0 -20
- {ripple_down_rules-0.1.3.dist-info → ripple_down_rules-0.1.6.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.1.3.dist-info → ripple_down_rules-0.1.6.dist-info}/top_level.txt +0 -0
ripple_down_rules/rdr.py
CHANGED
@@ -1,20 +1,28 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import importlib
|
4
|
+
import re
|
5
|
+
import sys
|
4
6
|
from abc import ABC, abstractmethod
|
5
7
|
from copy import copy
|
8
|
+
from dataclasses import is_dataclass
|
9
|
+
from io import TextIOWrapper
|
6
10
|
from types import ModuleType
|
7
11
|
|
8
12
|
from matplotlib import pyplot as plt
|
9
13
|
from ordered_set import OrderedSet
|
10
|
-
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
14
|
+
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
11
15
|
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
|
12
16
|
|
13
|
-
from .datastructures import Case,
|
17
|
+
from .datastructures.case import Case, CaseAttribute
|
18
|
+
from .datastructures.callable_expression import CallableExpression
|
19
|
+
from .datastructures.dataclasses import CaseQuery
|
20
|
+
from .datastructures.enums import MCRDRMode
|
14
21
|
from .experts import Expert, Human
|
15
22
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
16
23
|
from .utils import draw_tree, make_set, copy_case, \
|
17
|
-
get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string
|
24
|
+
get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
|
25
|
+
get_case_attribute_type, ask_llm, is_matching
|
18
26
|
|
19
27
|
|
20
28
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -29,14 +37,16 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
29
37
|
"""
|
30
38
|
The conclusions that the expert has accepted, such that they are not asked again.
|
31
39
|
"""
|
40
|
+
_generated_python_file_name: Optional[str] = None
|
41
|
+
"""
|
42
|
+
The name of the generated python file.
|
43
|
+
"""
|
32
44
|
|
33
|
-
def __init__(self, start_rule: Optional[Rule] = None
|
45
|
+
def __init__(self, start_rule: Optional[Rule] = None):
|
34
46
|
"""
|
35
47
|
:param start_rule: The starting rule for the classifier.
|
36
|
-
:param session: The sqlalchemy orm session.
|
37
48
|
"""
|
38
49
|
self.start_rule = start_rule
|
39
|
-
self.session = session
|
40
50
|
self.fig: Optional[plt.Figure] = None
|
41
51
|
|
42
52
|
def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
|
@@ -79,31 +89,29 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
79
89
|
:param animate_tree: Whether to draw the tree while fitting the classifier.
|
80
90
|
:param kwargs_for_fit_case: The keyword arguments to pass to the fit_case method.
|
81
91
|
"""
|
82
|
-
|
83
|
-
targets = [{case_query.attribute_name: case_query.target} for case_query in case_queries]
|
92
|
+
targets = []
|
84
93
|
if animate_tree:
|
85
94
|
plt.ion()
|
86
95
|
i = 0
|
87
96
|
stop_iterating = False
|
88
97
|
num_rules: int = 0
|
89
98
|
while not stop_iterating:
|
90
|
-
all_pred = 0
|
91
|
-
if not targets:
|
92
|
-
targets = [None] * len(cases)
|
93
99
|
for case_query in case_queries:
|
94
|
-
target = {case_query.attribute_name: case_query.target}
|
95
100
|
pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
|
96
|
-
|
101
|
+
if case_query.target is None:
|
102
|
+
continue
|
103
|
+
target = {case_query.attribute_name: case_query.target(case_query.case)}
|
104
|
+
if len(targets) < len(case_queries):
|
105
|
+
targets.append(target)
|
106
|
+
match = is_matching(self.classify, case_query, pred_cat)
|
97
107
|
if not match:
|
98
108
|
print(f"Predicted: {pred_cat} but expected: {target}")
|
99
|
-
all_pred += int(match)
|
100
109
|
if animate_tree and self.start_rule.size > num_rules:
|
101
110
|
num_rules = self.start_rule.size
|
102
111
|
self.update_figures()
|
103
112
|
i += 1
|
104
|
-
all_predictions = [1 if
|
105
|
-
|
106
|
-
for case_query in case_queries]
|
113
|
+
all_predictions = [1 if is_matching(self.classify, case_query) else 0 for case_query in case_queries
|
114
|
+
if case_query.target is not None]
|
107
115
|
all_pred = sum(all_predictions)
|
108
116
|
print(f"Accuracy: {all_pred}/{len(targets)}")
|
109
117
|
all_predicted = targets and all_pred == len(targets)
|
@@ -116,48 +124,6 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
116
124
|
plt.ioff()
|
117
125
|
plt.show()
|
118
126
|
|
119
|
-
@staticmethod
|
120
|
-
def calculate_precision_and_recall(pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> Tuple[
|
121
|
-
List[bool], List[bool]]:
|
122
|
-
"""
|
123
|
-
:param pred_cat: The predicted category.
|
124
|
-
:param target: The target category.
|
125
|
-
:return: The precision and recall of the classifier.
|
126
|
-
"""
|
127
|
-
pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
|
128
|
-
target = target if is_iterable(target) else [target]
|
129
|
-
recall = []
|
130
|
-
precision = []
|
131
|
-
if isinstance(pred_cat, dict):
|
132
|
-
for pred_key, pred_value in pred_cat.items():
|
133
|
-
if pred_key not in target:
|
134
|
-
continue
|
135
|
-
precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
|
136
|
-
for target_key, target_value in target.items():
|
137
|
-
if target_key not in pred_cat:
|
138
|
-
recall.append(False)
|
139
|
-
continue
|
140
|
-
if is_iterable(target_value):
|
141
|
-
recall.extend([v in pred_cat[target_key] for v in target_value])
|
142
|
-
else:
|
143
|
-
recall.append(target_value == pred_cat[target_key])
|
144
|
-
else:
|
145
|
-
if isinstance(target, dict):
|
146
|
-
target = list(target.values())
|
147
|
-
recall = [not yi or (yi in pred_cat) for yi in target]
|
148
|
-
target_types = [type(yi) for yi in target]
|
149
|
-
precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
|
150
|
-
return precision, recall
|
151
|
-
|
152
|
-
def is_matching(self, pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> bool:
|
153
|
-
"""
|
154
|
-
:param pred_cat: The predicted category.
|
155
|
-
:param target: The target category.
|
156
|
-
:return: Whether the classifier is matching or not.
|
157
|
-
"""
|
158
|
-
precision, recall = self.calculate_precision_and_recall(pred_cat, target)
|
159
|
-
return all(recall) and all(precision)
|
160
|
-
|
161
127
|
def update_figures(self):
|
162
128
|
"""
|
163
129
|
Update the figures of the classifier.
|
@@ -183,33 +149,51 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
183
149
|
"""
|
184
150
|
return hasattr(case, conclusion_name) and getattr(case, conclusion_name) is not None
|
185
151
|
|
152
|
+
@property
|
153
|
+
def type_(self):
|
154
|
+
return self.__class__
|
155
|
+
|
186
156
|
|
187
157
|
class RDRWithCodeWriter(RippleDownRules, ABC):
|
188
158
|
|
189
159
|
@abstractmethod
|
190
|
-
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = ""
|
160
|
+
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
|
161
|
+
defs_file: Optional[str] = None):
|
191
162
|
"""
|
192
163
|
Write the rules as source code to a file.
|
193
164
|
|
194
165
|
:param rule: The rule to write as source code.
|
195
166
|
:param file: The file to write the source code to.
|
196
167
|
:param parent_indent: The indentation of the parent rule.
|
168
|
+
:param defs_file: The file to write the definitions to.
|
197
169
|
"""
|
198
170
|
pass
|
199
171
|
|
200
|
-
def write_to_python_file(self, file_path: str):
|
172
|
+
def write_to_python_file(self, file_path: str, postfix: str = ""):
|
201
173
|
"""
|
202
174
|
Write the tree of rules as source code to a file.
|
203
175
|
|
204
176
|
:param file_path: The path to the file to write the source code to.
|
177
|
+
:param postfix: The postfix to add to the file name.
|
205
178
|
"""
|
179
|
+
self.generated_python_file_name = self._default_generated_python_file_name + postfix
|
206
180
|
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
207
|
-
|
208
|
-
|
181
|
+
file_name = file_path + f"/{self.generated_python_file_name}.py"
|
182
|
+
defs_file_name = file_path + f"/{self.generated_python_defs_file_name}.py"
|
183
|
+
imports = self._get_imports()
|
184
|
+
# clear the files first
|
185
|
+
with open(defs_file_name, "w") as f:
|
186
|
+
f.write(imports + "\n\n")
|
187
|
+
with open(file_name, "w") as f:
|
188
|
+
imports += f"from .{self.generated_python_defs_file_name} import *\n"
|
189
|
+
imports += f"from ripple_down_rules.rdr import {self.__class__.__name__}\n"
|
190
|
+
f.write(imports + "\n\n")
|
191
|
+
f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n\n")
|
192
|
+
f.write(f"type_ = {self.__class__.__name__}\n\n")
|
209
193
|
f.write(func_def)
|
210
194
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
211
195
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
212
|
-
self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)
|
196
|
+
self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4, defs_file=defs_file_name)
|
213
197
|
|
214
198
|
@property
|
215
199
|
@abstractmethod
|
@@ -226,26 +210,73 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
226
210
|
imports = ""
|
227
211
|
if self.case_type.__module__ != "builtins":
|
228
212
|
imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
|
229
|
-
|
230
|
-
|
231
|
-
|
213
|
+
for conclusion_type in self.conclusion_type:
|
214
|
+
if conclusion_type.__module__ != "builtins":
|
215
|
+
imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
|
216
|
+
imports += "from ripple_down_rules.datastructures.case import Case, create_case\n"
|
232
217
|
for rule in [self.start_rule] + list(self.start_rule.descendants):
|
233
|
-
if rule.conditions:
|
234
|
-
|
235
|
-
|
236
|
-
|
218
|
+
if not rule.conditions:
|
219
|
+
continue
|
220
|
+
if rule.conditions.scope is None or len(rule.conditions.scope) == 0:
|
221
|
+
continue
|
222
|
+
for k, v in rule.conditions.scope.items():
|
223
|
+
new_imports = f"from {v.__module__} import {v.__name__}\n"
|
224
|
+
if new_imports in imports:
|
225
|
+
continue
|
226
|
+
imports += new_imports
|
237
227
|
return imports
|
238
228
|
|
239
|
-
def get_rdr_classifier_from_python_file(self, package_name) -> Callable[[Any], Any]:
|
229
|
+
def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
|
240
230
|
"""
|
241
231
|
:param package_name: The name of the package that contains the RDR classifier function.
|
242
232
|
:return: The module that contains the rdr classifier function.
|
243
233
|
"""
|
244
|
-
|
234
|
+
# remove from imports if exists first
|
235
|
+
name = f"{package_name.strip('./')}.{self.generated_python_file_name}"
|
236
|
+
try:
|
237
|
+
module = importlib.import_module(name)
|
238
|
+
del sys.modules[name]
|
239
|
+
except ModuleNotFoundError:
|
240
|
+
pass
|
241
|
+
return importlib.import_module(name).classify
|
245
242
|
|
246
243
|
@property
|
247
244
|
def generated_python_file_name(self) -> str:
|
248
|
-
|
245
|
+
if self._generated_python_file_name is None:
|
246
|
+
self._generated_python_file_name = self._default_generated_python_file_name
|
247
|
+
return self._generated_python_file_name
|
248
|
+
|
249
|
+
@generated_python_file_name.setter
|
250
|
+
def generated_python_file_name(self, value: str):
|
251
|
+
"""
|
252
|
+
Set the generated python file name.
|
253
|
+
:param value: The new value for the generated python file name.
|
254
|
+
"""
|
255
|
+
self._generated_python_file_name = value
|
256
|
+
|
257
|
+
@property
|
258
|
+
def _default_generated_python_file_name(self) -> str:
|
259
|
+
"""
|
260
|
+
:return: The default generated python file name.
|
261
|
+
"""
|
262
|
+
return f"{self.start_rule.corner_case._name.lower()}_{self.attribute_name}_{self.acronym.lower()}"
|
263
|
+
|
264
|
+
@property
|
265
|
+
def generated_python_defs_file_name(self) -> str:
|
266
|
+
return f"{self.generated_python_file_name}_defs"
|
267
|
+
|
268
|
+
@property
|
269
|
+
def acronym(self) -> str:
|
270
|
+
"""
|
271
|
+
:return: The acronym of the classifier.
|
272
|
+
"""
|
273
|
+
if self.__class__.__name__ == "GeneralRDR":
|
274
|
+
return "GRDR"
|
275
|
+
elif self.__class__.__name__ == "MultiClassRDR":
|
276
|
+
return "MCRDR"
|
277
|
+
else:
|
278
|
+
return "SCRDR"
|
279
|
+
|
249
280
|
|
250
281
|
@property
|
251
282
|
def case_type(self) -> Type:
|
@@ -258,16 +289,17 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
258
289
|
return type(self.start_rule.corner_case)
|
259
290
|
|
260
291
|
@property
|
261
|
-
def conclusion_type(self) -> Type:
|
292
|
+
def conclusion_type(self) -> Tuple[Type]:
|
262
293
|
"""
|
263
294
|
:return: The type of the conclusion of the RDR classifier.
|
264
295
|
"""
|
265
296
|
if isinstance(self.start_rule.conclusion, CallableExpression):
|
266
297
|
return self.start_rule.conclusion.conclusion_type
|
267
298
|
else:
|
268
|
-
|
269
|
-
|
270
|
-
|
299
|
+
conclusion = self.start_rule.conclusion
|
300
|
+
if isinstance(conclusion, set):
|
301
|
+
return type(list(conclusion)[0]), set
|
302
|
+
return (type(conclusion),)
|
271
303
|
|
272
304
|
@property
|
273
305
|
def attribute_name(self) -> str:
|
@@ -279,8 +311,16 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
279
311
|
|
280
312
|
class SingleClassRDR(RDRWithCodeWriter):
|
281
313
|
|
314
|
+
def __init__(self, start_rule: Optional[SingleClassRule] = None, default_conclusion: Optional[Any] = None):
|
315
|
+
"""
|
316
|
+
:param start_rule: The starting rule for the classifier.
|
317
|
+
:param default_conclusion: The default conclusion for the classifier if no rules fire.
|
318
|
+
"""
|
319
|
+
super(SingleClassRDR, self).__init__(start_rule)
|
320
|
+
self.default_conclusion: Optional[Any] = default_conclusion
|
321
|
+
|
282
322
|
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
283
|
-
-> Union[CaseAttribute, CallableExpression]:
|
323
|
+
-> Union[CaseAttribute, CallableExpression, None]:
|
284
324
|
"""
|
285
325
|
Classify a case, and ask the user for refinements or alternatives if the classification is incorrect by
|
286
326
|
comparing the case with the target category if provided.
|
@@ -289,28 +329,31 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
289
329
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
290
330
|
:return: The category that the case belongs to.
|
291
331
|
"""
|
292
|
-
expert = expert if expert else Human(
|
293
|
-
if case_query.
|
294
|
-
|
332
|
+
expert = expert if expert else Human()
|
333
|
+
if case_query.default_value is not None and self.default_conclusion != case_query.default_value:
|
334
|
+
self.default_conclusion = case_query.default_value
|
335
|
+
case = case_query.case
|
336
|
+
target = expert.ask_for_conclusion(case_query) if case_query.target is None else case_query.target
|
337
|
+
if target is None:
|
338
|
+
return self.classify(case)
|
295
339
|
if not self.start_rule:
|
296
340
|
conditions = expert.ask_for_conditions(case_query)
|
297
|
-
self.start_rule = SingleClassRule(conditions,
|
341
|
+
self.start_rule = SingleClassRule(conditions, target, corner_case=case,
|
298
342
|
conclusion_name=case_query.attribute_name)
|
299
343
|
|
300
344
|
pred = self.evaluate(case_query.case)
|
301
|
-
|
302
|
-
if pred.conclusion != case_query.target:
|
345
|
+
if pred.conclusion(case) != target(case):
|
303
346
|
conditions = expert.ask_for_conditions(case_query, pred)
|
304
|
-
pred.fit_rule(case_query.case,
|
347
|
+
pred.fit_rule(case_query.case, target, conditions=conditions)
|
305
348
|
|
306
349
|
return self.classify(case_query.case)
|
307
350
|
|
308
|
-
def classify(self, case: Case) -> Optional[
|
351
|
+
def classify(self, case: Case) -> Optional[Any]:
|
309
352
|
"""
|
310
353
|
Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
|
311
354
|
"""
|
312
355
|
pred = self.evaluate(case)
|
313
|
-
return pred.conclusion if pred.fired else
|
356
|
+
return pred.conclusion(case) if pred.fired else self.default_conclusion
|
314
357
|
|
315
358
|
def evaluate(self, case: Case) -> SingleClassRule:
|
316
359
|
"""
|
@@ -319,23 +362,33 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
319
362
|
matched_rule = self.start_rule(case)
|
320
363
|
return matched_rule if matched_rule else self.start_rule
|
321
364
|
|
322
|
-
def
|
365
|
+
def write_to_python_file(self, file_path: str, postfix: str = ""):
|
366
|
+
super().write_to_python_file(file_path, postfix)
|
367
|
+
if self.default_conclusion is not None:
|
368
|
+
with open(file_path + f"/{self.generated_python_file_name}.py", "a") as f:
|
369
|
+
f.write(f"{' '*4}else:\n{' '*8}return {self.default_conclusion}\n")
|
370
|
+
|
371
|
+
def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
|
372
|
+
defs_file: Optional[str] = None):
|
323
373
|
"""
|
324
374
|
Write the rules as source code to a file.
|
325
375
|
"""
|
326
376
|
if rule.conditions:
|
327
|
-
|
377
|
+
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
378
|
+
file.write(if_clause)
|
328
379
|
if rule.refinement:
|
329
|
-
|
380
|
+
|
381
|
+
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
|
382
|
+
defs_file=defs_file)
|
330
383
|
|
331
384
|
file.write(rule.write_conclusion_as_source_code(parent_indent))
|
332
385
|
|
333
386
|
if rule.alternative:
|
334
|
-
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
|
387
|
+
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
|
335
388
|
|
336
389
|
@property
|
337
390
|
def conclusion_type_hint(self) -> str:
|
338
|
-
return self.conclusion_type.__name__
|
391
|
+
return self.conclusion_type[0].__name__
|
339
392
|
|
340
393
|
def _to_json(self) -> Dict[str, Any]:
|
341
394
|
return {"start_rule": self.start_rule.to_json()}
|
@@ -369,28 +422,27 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
369
422
|
"""
|
370
423
|
|
371
424
|
def __init__(self, start_rule: Optional[Rule] = None,
|
372
|
-
mode: MCRDRMode = MCRDRMode.StopOnly
|
425
|
+
mode: MCRDRMode = MCRDRMode.StopOnly):
|
373
426
|
"""
|
374
427
|
:param start_rule: The starting rules for the classifier.
|
375
428
|
:param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
|
376
|
-
:param session: The sqlalchemy orm session.
|
377
429
|
"""
|
378
430
|
start_rule = MultiClassTopRule() if not start_rule else start_rule
|
379
|
-
super(MultiClassRDR, self).__init__(start_rule
|
431
|
+
super(MultiClassRDR, self).__init__(start_rule)
|
380
432
|
self.mode: MCRDRMode = mode
|
381
433
|
|
382
|
-
def classify(self, case: Union[Case, SQLTable]) ->
|
434
|
+
def classify(self, case: Union[Case, SQLTable]) -> Set[Any]:
|
383
435
|
evaluated_rule = self.start_rule
|
384
436
|
self.conclusions = []
|
385
437
|
while evaluated_rule:
|
386
438
|
next_rule = evaluated_rule(case)
|
387
439
|
if evaluated_rule.fired:
|
388
|
-
self.add_conclusion(evaluated_rule)
|
440
|
+
self.add_conclusion(evaluated_rule, case)
|
389
441
|
evaluated_rule = next_rule
|
390
|
-
return self.conclusions
|
442
|
+
return make_set(self.conclusions)
|
391
443
|
|
392
444
|
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None,
|
393
|
-
add_extra_conclusions: bool = False) ->
|
445
|
+
add_extra_conclusions: bool = False) -> Set[Union[CaseAttribute, CallableExpression, None]]:
|
394
446
|
"""
|
395
447
|
Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
|
396
448
|
or missing by comparing the case with the target category if provided.
|
@@ -400,32 +452,35 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
400
452
|
:param add_extra_conclusions: Whether to add extra conclusions after classification is done.
|
401
453
|
:return: The conclusions that the case belongs to.
|
402
454
|
"""
|
403
|
-
expert = expert if expert else Human(
|
455
|
+
expert = expert if expert else Human()
|
404
456
|
if case_query.target is None:
|
405
|
-
|
457
|
+
expert.ask_for_conclusion(case_query)
|
458
|
+
if case_query.target is None:
|
459
|
+
return self.classify(case_query.case)
|
460
|
+
self.update_start_rule(case_query, expert)
|
406
461
|
self.expert_accepted_conclusions = []
|
407
462
|
user_conclusions = []
|
408
|
-
self.update_start_rule(case_query, expert)
|
409
463
|
self.conclusions = []
|
410
464
|
self.stop_rule_conditions = None
|
411
465
|
evaluated_rule = self.start_rule
|
466
|
+
target = case_query.target(case_query.case)
|
412
467
|
while evaluated_rule:
|
413
468
|
next_rule = evaluated_rule(case_query.case)
|
414
|
-
|
469
|
+
rule_conclusion = evaluated_rule.conclusion(case_query.case)
|
470
|
+
good_conclusions = make_list(target) + user_conclusions + self.expert_accepted_conclusions
|
415
471
|
good_conclusions = make_set(good_conclusions)
|
416
472
|
|
417
473
|
if evaluated_rule.fired:
|
418
|
-
if
|
419
|
-
# if self.case_has_conclusion(case, evaluated_rule.conclusion):
|
474
|
+
if target and not make_set(rule_conclusion).issubset(good_conclusions):
|
420
475
|
# Rule fired and conclusion is different from target
|
421
476
|
self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
|
422
477
|
add_extra_conclusions)
|
423
478
|
else:
|
424
479
|
# Rule fired and target is correct or there is no target to compare
|
425
|
-
self.add_conclusion(evaluated_rule)
|
480
|
+
self.add_conclusion(evaluated_rule, case_query.case)
|
426
481
|
|
427
482
|
if not next_rule:
|
428
|
-
if not make_set(
|
483
|
+
if not make_set(target).issubset(make_set(self.conclusions)):
|
429
484
|
# Nothing fired and there is a target that should have been in the conclusions
|
430
485
|
self.add_rule_for_case(case_query, expert)
|
431
486
|
# Have to check all rules again to make sure only this new rule fires
|
@@ -439,33 +494,31 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
439
494
|
return self.conclusions
|
440
495
|
|
441
496
|
def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
|
442
|
-
file, parent_indent: str = ""):
|
443
|
-
"""
|
444
|
-
Write the rules as source code to a file.
|
445
|
-
|
446
|
-
:
|
447
|
-
"""
|
497
|
+
file, parent_indent: str = "", defs_file: Optional[str] = None):
|
448
498
|
if rule == self.start_rule:
|
449
499
|
file.write(f"{parent_indent}conclusions = set()\n")
|
450
500
|
if rule.conditions:
|
451
|
-
|
501
|
+
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
502
|
+
file.write(if_clause)
|
452
503
|
conclusion_indent = parent_indent
|
453
504
|
if hasattr(rule, "refinement") and rule.refinement:
|
454
|
-
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " "
|
505
|
+
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
|
506
|
+
defs_file=defs_file)
|
455
507
|
conclusion_indent = parent_indent + " " * 4
|
456
508
|
file.write(f"{conclusion_indent}else:\n")
|
457
509
|
file.write(rule.write_conclusion_as_source_code(conclusion_indent))
|
458
510
|
|
459
511
|
if rule.alternative:
|
460
|
-
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
|
512
|
+
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
|
461
513
|
|
462
514
|
@property
|
463
515
|
def conclusion_type_hint(self) -> str:
|
464
|
-
return f"Set[{self.conclusion_type.__name__}]"
|
516
|
+
return f"Set[{self.conclusion_type[0].__name__}]"
|
465
517
|
|
466
518
|
def _get_imports(self) -> str:
|
467
519
|
imports = super()._get_imports()
|
468
520
|
imports += "from typing_extensions import Set\n"
|
521
|
+
imports += "from ripple_down_rules.utils import make_set\n"
|
469
522
|
return imports
|
470
523
|
|
471
524
|
def update_start_rule(self, case_query: CaseQuery, expert: Expert):
|
@@ -498,8 +551,10 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
498
551
|
"""
|
499
552
|
Stop a wrong conclusion by adding a stopping rule.
|
500
553
|
"""
|
501
|
-
|
502
|
-
|
554
|
+
target = case_query.target(case_query.case)
|
555
|
+
rule_conclusion = evaluated_rule.conclusion(case_query.case)
|
556
|
+
if self.is_same_category_type(rule_conclusion, target) \
|
557
|
+
and self.is_conflicting_with_target(rule_conclusion, target):
|
503
558
|
self.stop_conclusion(case_query, expert, evaluated_rule)
|
504
559
|
elif not self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
|
505
560
|
self.stop_conclusion(case_query, expert, evaluated_rule)
|
@@ -559,10 +614,11 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
559
614
|
:return: Whether the conclusion is correct or not.
|
560
615
|
"""
|
561
616
|
conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
|
562
|
-
if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
|
563
|
-
|
617
|
+
if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
|
618
|
+
evaluated_rule.conclusion(case_query.case),
|
619
|
+
targets=case_query.target(case_query.case),
|
564
620
|
current_conclusions=conclusions)):
|
565
|
-
self.add_conclusion(evaluated_rule)
|
621
|
+
self.add_conclusion(evaluated_rule, case_query.case)
|
566
622
|
self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
|
567
623
|
return True
|
568
624
|
return False
|
@@ -600,19 +656,21 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
600
656
|
extra_conclusions.append(conclusion)
|
601
657
|
return extra_conclusions
|
602
658
|
|
603
|
-
def add_conclusion(self, evaluated_rule: Rule) -> None:
|
659
|
+
def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
|
604
660
|
"""
|
605
661
|
Add the conclusion of the evaluated rule to the list of conclusions.
|
606
662
|
|
607
663
|
:param evaluated_rule: The evaluated rule to add the conclusion of.
|
664
|
+
:param case: The case to add the conclusion for.
|
608
665
|
"""
|
609
666
|
conclusion_types = [type(c) for c in self.conclusions]
|
610
|
-
|
611
|
-
|
667
|
+
rule_conclusion = evaluated_rule.conclusion(case)
|
668
|
+
if type(rule_conclusion) not in conclusion_types:
|
669
|
+
self.conclusions.extend(make_list(rule_conclusion))
|
612
670
|
else:
|
613
|
-
same_type_conclusions = [c for c in self.conclusions if type(c) == type(
|
614
|
-
combined_conclusion =
|
615
|
-
else {
|
671
|
+
same_type_conclusions = [c for c in self.conclusions if type(c) == type(rule_conclusion)]
|
672
|
+
combined_conclusion = rule_conclusion if isinstance(rule_conclusion, set) \
|
673
|
+
else {rule_conclusion}
|
616
674
|
combined_conclusion = copy(combined_conclusion)
|
617
675
|
for c in same_type_conclusions:
|
618
676
|
combined_conclusion.update(c if isinstance(c, set) else make_set(c))
|
@@ -720,22 +778,24 @@ class GeneralRDR(RippleDownRules):
|
|
720
778
|
pred_atts = rdr.classify(case_cp)
|
721
779
|
if pred_atts is None:
|
722
780
|
continue
|
723
|
-
if
|
781
|
+
if rdr.type_ is SingleClassRDR:
|
724
782
|
if attribute_name not in conclusions or \
|
725
783
|
(attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
|
726
784
|
conclusions[attribute_name] = pred_atts
|
727
785
|
new_conclusions[attribute_name] = pred_atts
|
728
786
|
else:
|
729
|
-
pred_atts =
|
787
|
+
pred_atts = make_set(pred_atts)
|
730
788
|
if attribute_name in conclusions:
|
731
|
-
pred_atts =
|
789
|
+
pred_atts = {p for p in pred_atts if p not in conclusions[attribute_name]}
|
732
790
|
if len(pred_atts) > 0:
|
733
791
|
new_conclusions[attribute_name] = pred_atts
|
734
792
|
if attribute_name not in conclusions:
|
735
|
-
conclusions[attribute_name] =
|
736
|
-
conclusions[attribute_name].
|
793
|
+
conclusions[attribute_name] = set()
|
794
|
+
conclusions[attribute_name].update(pred_atts)
|
737
795
|
if attribute_name in new_conclusions:
|
738
|
-
|
796
|
+
mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
|
797
|
+
GeneralRDR.update_case(CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive),
|
798
|
+
new_conclusions)
|
739
799
|
if len(new_conclusions) == 0:
|
740
800
|
break
|
741
801
|
return conclusions
|
@@ -761,76 +821,98 @@ class GeneralRDR(RippleDownRules):
|
|
761
821
|
case = case_queries[0].case
|
762
822
|
assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
|
763
823
|
" for multiple cases use fit instead")
|
764
|
-
|
824
|
+
original_case_query_cp = copy(case_queries[0])
|
765
825
|
for case_query in case_queries:
|
766
826
|
case_query_cp = copy(case_query)
|
767
|
-
case_query_cp.case =
|
768
|
-
if
|
827
|
+
case_query_cp.case = original_case_query_cp.case
|
828
|
+
if case_query_cp.target is None:
|
769
829
|
conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
|
770
|
-
|
830
|
+
self.update_case(case_query_cp, conclusions)
|
831
|
+
expert.ask_for_conclusion(case_query_cp)
|
832
|
+
if case_query_cp.target is None:
|
833
|
+
continue
|
834
|
+
case_query.target = case_query_cp.target
|
771
835
|
|
772
836
|
if case_query.attribute_name not in self.start_rules_dict:
|
773
837
|
conclusions = self.classify(case)
|
774
|
-
self.update_case(
|
838
|
+
self.update_case(case_query_cp, conclusions)
|
775
839
|
|
776
|
-
new_rdr = self.initialize_new_rdr_for_attribute(
|
840
|
+
new_rdr = self.initialize_new_rdr_for_attribute(case_query_cp)
|
777
841
|
self.add_rdr(new_rdr, case_query.attribute_name)
|
778
842
|
|
779
843
|
new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
|
780
|
-
self.update_case(
|
844
|
+
self.update_case(case_query_cp, {case_query.attribute_name: new_conclusions})
|
781
845
|
else:
|
782
846
|
for rdr_attribute_name, rdr in self.start_rules_dict.items():
|
783
847
|
if case_query.attribute_name != rdr_attribute_name:
|
784
|
-
conclusions = rdr.classify(
|
848
|
+
conclusions = rdr.classify(case_query_cp.case)
|
785
849
|
else:
|
786
850
|
conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
|
787
851
|
**kwargs)
|
788
852
|
if conclusions is not None or (is_iterable(conclusions) and len(conclusions) > 0):
|
789
853
|
conclusions = {rdr_attribute_name: conclusions}
|
790
|
-
|
854
|
+
case_query_cp.mutually_exclusive = True if isinstance(rdr, SingleClassRDR) else False
|
855
|
+
self.update_case(case_query_cp, conclusions)
|
856
|
+
case_query.conditions = case_query_cp.conditions
|
791
857
|
|
792
858
|
return self.classify(case)
|
793
859
|
|
794
860
|
@staticmethod
|
795
|
-
def initialize_new_rdr_for_attribute(
|
861
|
+
def initialize_new_rdr_for_attribute(case_query: CaseQuery):
|
796
862
|
"""
|
797
863
|
Initialize the appropriate RDR type for the target.
|
798
864
|
"""
|
799
|
-
|
865
|
+
if case_query.mutually_exclusive is not None:
|
866
|
+
return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive\
|
867
|
+
else MultiClassRDR()
|
868
|
+
if case_query.attribute_type in [list, set]:
|
869
|
+
return MultiClassRDR()
|
870
|
+
attribute = getattr(case_query.case, case_query.attribute_name)\
|
871
|
+
if hasattr(case_query.case, case_query.attribute_name) else case_query.target(case_query.case)
|
800
872
|
if isinstance(attribute, CaseAttribute):
|
801
|
-
return SingleClassRDR() if attribute.mutually_exclusive
|
873
|
+
return SingleClassRDR(default_conclusion=case_query.default_value) if attribute.mutually_exclusive \
|
874
|
+
else MultiClassRDR()
|
802
875
|
else:
|
803
|
-
return MultiClassRDR() if is_iterable(attribute) or (attribute is None)
|
876
|
+
return MultiClassRDR() if is_iterable(attribute) or (attribute is None)\
|
877
|
+
else SingleClassRDR(default_conclusion=case_query.default_value)
|
804
878
|
|
805
879
|
@staticmethod
|
806
|
-
def update_case(
|
880
|
+
def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
|
807
881
|
"""
|
808
882
|
Update the case with the conclusions.
|
809
883
|
|
810
|
-
:param
|
884
|
+
:param case_query: The case query that contains the case to update.
|
811
885
|
:param conclusions: The conclusions to update the case with.
|
812
886
|
"""
|
813
887
|
if not conclusions:
|
814
888
|
return
|
815
889
|
if len(conclusions) == 0:
|
816
890
|
return
|
817
|
-
if isinstance(
|
891
|
+
if isinstance(case_query.original_case, SQLTable) or is_dataclass(case_query.original_case):
|
818
892
|
for conclusion_name, conclusion in conclusions.items():
|
819
|
-
|
820
|
-
|
821
|
-
|
893
|
+
attribute = getattr(case_query.case, conclusion_name)
|
894
|
+
if conclusion_name == case_query.attribute_name:
|
895
|
+
attribute_type = case_query.attribute_type
|
896
|
+
else:
|
897
|
+
attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
|
898
|
+
if isinstance(attribute, set) or any(at in {Set, set} for at in attribute_type):
|
822
899
|
attribute = set() if attribute is None else attribute
|
823
900
|
for c in conclusion:
|
824
901
|
attribute.update(make_set(c))
|
825
|
-
elif isinstance(attribute, list) or
|
902
|
+
elif isinstance(attribute, list) or any(at in {List, list} for at in attribute_type):
|
826
903
|
attribute = [] if attribute is None else attribute
|
827
904
|
attribute.extend(conclusion)
|
828
|
-
elif
|
829
|
-
|
905
|
+
elif is_iterable(conclusion) and len(conclusion) == 1 \
|
906
|
+
and any(at is type(list(conclusion)[0]) for at in attribute_type):
|
907
|
+
setattr(case_query.case, conclusion_name, list(conclusion)[0])
|
908
|
+
elif not is_iterable(conclusion) and any(at is type(conclusion) for at in attribute_type):
|
909
|
+
setattr(case_query.case, conclusion_name, conclusion)
|
830
910
|
else:
|
831
|
-
raise ValueError(f"
|
911
|
+
raise ValueError(f"Unknown type or type mismatch for attribute {conclusion_name} with type "
|
912
|
+
f"{case_query.attribute_type} with conclusion "
|
913
|
+
f"{conclusion} of type {type(conclusion)}")
|
832
914
|
else:
|
833
|
-
case.update(conclusions)
|
915
|
+
case_query.case.update(conclusions)
|
834
916
|
|
835
917
|
def _to_json(self) -> Dict[str, Any]:
|
836
918
|
return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
|
@@ -845,14 +927,16 @@ class GeneralRDR(RippleDownRules):
|
|
845
927
|
start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
|
846
928
|
return cls(start_rules_dict)
|
847
929
|
|
848
|
-
def write_to_python_file(self, file_path: str):
|
930
|
+
def write_to_python_file(self, file_path: str, postfix: str = "") -> None:
|
849
931
|
"""
|
850
932
|
Write the tree of rules as source code to a file.
|
851
933
|
|
852
934
|
:param file_path: The path to the file to write the source code to.
|
935
|
+
:param postfix: The postfix to add to the file name.
|
853
936
|
"""
|
937
|
+
self.generated_python_file_name = self._default_generated_python_file_name + postfix
|
854
938
|
for rdr in self.start_rules_dict.values():
|
855
|
-
rdr.write_to_python_file(file_path)
|
939
|
+
rdr.write_to_python_file(file_path, postfix=f"_of_grdr{postfix}")
|
856
940
|
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
857
941
|
with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
|
858
942
|
f.write(self._get_imports(file_path) + "\n\n")
|
@@ -875,15 +959,29 @@ class GeneralRDR(RippleDownRules):
|
|
875
959
|
else:
|
876
960
|
return type(self.start_rule.corner_case)
|
877
961
|
|
878
|
-
def get_rdr_classifier_from_python_file(self, file_path: str):
|
962
|
+
def get_rdr_classifier_from_python_file(self, file_path: str) -> Callable[[Any], Any]:
|
879
963
|
"""
|
880
964
|
:param file_path: The path to the file that contains the RDR classifier function.
|
965
|
+
:param postfix: The postfix to add to the file name.
|
881
966
|
:return: The module that contains the rdr classifier function.
|
882
967
|
"""
|
883
968
|
return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
|
884
969
|
|
885
970
|
@property
|
886
971
|
def generated_python_file_name(self) -> str:
|
972
|
+
if self._generated_python_file_name is None:
|
973
|
+
self._generated_python_file_name = self._default_generated_python_file_name
|
974
|
+
return self._generated_python_file_name
|
975
|
+
|
976
|
+
@generated_python_file_name.setter
|
977
|
+
def generated_python_file_name(self, value: str):
|
978
|
+
self._generated_python_file_name = value
|
979
|
+
|
980
|
+
@property
|
981
|
+
def _default_generated_python_file_name(self) -> str:
|
982
|
+
"""
|
983
|
+
:return: The default generated python file name.
|
984
|
+
"""
|
887
985
|
return f"{self.start_rule.corner_case._name.lower()}_rdr"
|
888
986
|
|
889
987
|
@property
|
@@ -891,17 +989,25 @@ class GeneralRDR(RippleDownRules):
|
|
891
989
|
return f"List[Union[{', '.join([rdr.conclusion_type_hint for rdr in self.start_rules_dict.values()])}]]"
|
892
990
|
|
893
991
|
def _get_imports(self, file_path: str) -> str:
|
992
|
+
"""
|
993
|
+
Get the imports needed for the generated python file.
|
994
|
+
|
995
|
+
:param file_path: The path to the file that contains the RDR classifier function.
|
996
|
+
:return: The imports needed for the generated python file.
|
997
|
+
"""
|
894
998
|
imports = ""
|
895
999
|
# add type hints
|
896
1000
|
imports += f"from typing_extensions import List, Union, Set\n"
|
897
1001
|
# import rdr type
|
898
1002
|
imports += f"from ripple_down_rules.rdr import GeneralRDR\n"
|
899
1003
|
# add case type
|
900
|
-
imports += f"from ripple_down_rules.datastructures import Case, create_case\n"
|
1004
|
+
imports += f"from ripple_down_rules.datastructures.case import Case, create_case\n"
|
901
1005
|
imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
|
902
1006
|
# add conclusion type imports
|
903
1007
|
for rdr in self.start_rules_dict.values():
|
904
|
-
|
1008
|
+
for conclusion_type in rdr.conclusion_type:
|
1009
|
+
if conclusion_type.__module__ != "builtins":
|
1010
|
+
imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
|
905
1011
|
# add rdr python generated functions.
|
906
1012
|
for rdr_key, rdr in self.start_rules_dict.items():
|
907
1013
|
imports += (f"from {file_path.strip('./')}"
|