ripple-down-rules 0.1.21__py3-none-any.whl → 0.1.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datasets.py +2 -1
- ripple_down_rules/datastructures/__init__.py +4 -4
- ripple_down_rules/datastructures/callable_expression.py +74 -129
- ripple_down_rules/datastructures/case.py +8 -6
- ripple_down_rules/datastructures/dataclasses.py +102 -48
- ripple_down_rules/datastructures/enums.py +5 -1
- ripple_down_rules/experts.py +61 -68
- ripple_down_rules/helpers.py +27 -3
- ripple_down_rules/prompt.py +87 -74
- ripple_down_rules/rdr.py +291 -206
- ripple_down_rules/rules.py +64 -32
- ripple_down_rules/utils.py +209 -4
- {ripple_down_rules-0.1.21.dist-info → ripple_down_rules-0.1.62.dist-info}/METADATA +5 -4
- ripple_down_rules-0.1.62.dist-info/RECORD +20 -0
- {ripple_down_rules-0.1.21.dist-info → ripple_down_rules-0.1.62.dist-info}/WHEEL +1 -1
- ripple_down_rules-0.1.21.dist-info/RECORD +0 -20
- {ripple_down_rules-0.1.21.dist-info → ripple_down_rules-0.1.62.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.1.21.dist-info → ripple_down_rules-0.1.62.dist-info}/top_level.txt +0 -0
ripple_down_rules/rdr.py
CHANGED
@@ -1,20 +1,28 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import importlib
|
4
|
+
import sys
|
4
5
|
from abc import ABC, abstractmethod
|
5
6
|
from copy import copy
|
7
|
+
from dataclasses import is_dataclass
|
8
|
+
from io import TextIOWrapper
|
6
9
|
from types import ModuleType
|
7
10
|
|
8
11
|
from matplotlib import pyplot as plt
|
9
12
|
from ordered_set import OrderedSet
|
10
|
-
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
13
|
+
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
11
14
|
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
|
12
15
|
|
13
|
-
from .datastructures import
|
16
|
+
from .datastructures.callable_expression import CallableExpression
|
17
|
+
from .datastructures.case import Case, CaseAttribute, create_case
|
18
|
+
from .datastructures.dataclasses import CaseQuery
|
19
|
+
from .datastructures.enums import MCRDRMode, PromptFor
|
14
20
|
from .experts import Expert, Human
|
21
|
+
from .helpers import is_matching
|
15
22
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
16
23
|
from .utils import draw_tree, make_set, copy_case, \
|
17
|
-
|
24
|
+
SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
|
25
|
+
get_case_attribute_type, is_conflicting
|
18
26
|
|
19
27
|
|
20
28
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -29,14 +37,16 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
29
37
|
"""
|
30
38
|
The conclusions that the expert has accepted, such that they are not asked again.
|
31
39
|
"""
|
40
|
+
_generated_python_file_name: Optional[str] = None
|
41
|
+
"""
|
42
|
+
The name of the generated python file.
|
43
|
+
"""
|
32
44
|
|
33
|
-
def __init__(self, start_rule: Optional[Rule] = None
|
45
|
+
def __init__(self, start_rule: Optional[Rule] = None):
|
34
46
|
"""
|
35
47
|
:param start_rule: The starting rule for the classifier.
|
36
|
-
:param session: The sqlalchemy orm session.
|
37
48
|
"""
|
38
49
|
self.start_rule = start_rule
|
39
|
-
self.session = session
|
40
50
|
self.fig: Optional[plt.Figure] = None
|
41
51
|
|
42
52
|
def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
|
@@ -79,31 +89,29 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
79
89
|
:param animate_tree: Whether to draw the tree while fitting the classifier.
|
80
90
|
:param kwargs_for_fit_case: The keyword arguments to pass to the fit_case method.
|
81
91
|
"""
|
82
|
-
|
83
|
-
targets = [{case_query.attribute_name: case_query.target} for case_query in case_queries]
|
92
|
+
targets = []
|
84
93
|
if animate_tree:
|
85
94
|
plt.ion()
|
86
95
|
i = 0
|
87
96
|
stop_iterating = False
|
88
97
|
num_rules: int = 0
|
89
98
|
while not stop_iterating:
|
90
|
-
all_pred = 0
|
91
|
-
if not targets:
|
92
|
-
targets = [None] * len(cases)
|
93
99
|
for case_query in case_queries:
|
94
|
-
target = {case_query.attribute_name: case_query.target}
|
95
100
|
pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
|
96
|
-
|
101
|
+
if case_query.target is None:
|
102
|
+
continue
|
103
|
+
target = {case_query.attribute_name: case_query.target(case_query.case)}
|
104
|
+
if len(targets) < len(case_queries):
|
105
|
+
targets.append(target)
|
106
|
+
match = is_matching(self.classify, case_query, pred_cat)
|
97
107
|
if not match:
|
98
108
|
print(f"Predicted: {pred_cat} but expected: {target}")
|
99
|
-
all_pred += int(match)
|
100
109
|
if animate_tree and self.start_rule.size > num_rules:
|
101
110
|
num_rules = self.start_rule.size
|
102
111
|
self.update_figures()
|
103
112
|
i += 1
|
104
|
-
all_predictions = [1 if
|
105
|
-
|
106
|
-
for case_query in case_queries]
|
113
|
+
all_predictions = [1 if is_matching(self.classify, case_query) else 0 for case_query in case_queries
|
114
|
+
if case_query.target is not None]
|
107
115
|
all_pred = sum(all_predictions)
|
108
116
|
print(f"Accuracy: {all_pred}/{len(targets)}")
|
109
117
|
all_predicted = targets and all_pred == len(targets)
|
@@ -116,48 +124,6 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
116
124
|
plt.ioff()
|
117
125
|
plt.show()
|
118
126
|
|
119
|
-
@staticmethod
|
120
|
-
def calculate_precision_and_recall(pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> Tuple[
|
121
|
-
List[bool], List[bool]]:
|
122
|
-
"""
|
123
|
-
:param pred_cat: The predicted category.
|
124
|
-
:param target: The target category.
|
125
|
-
:return: The precision and recall of the classifier.
|
126
|
-
"""
|
127
|
-
pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
|
128
|
-
target = target if is_iterable(target) else [target]
|
129
|
-
recall = []
|
130
|
-
precision = []
|
131
|
-
if isinstance(pred_cat, dict):
|
132
|
-
for pred_key, pred_value in pred_cat.items():
|
133
|
-
if pred_key not in target:
|
134
|
-
continue
|
135
|
-
precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
|
136
|
-
for target_key, target_value in target.items():
|
137
|
-
if target_key not in pred_cat:
|
138
|
-
recall.append(False)
|
139
|
-
continue
|
140
|
-
if is_iterable(target_value):
|
141
|
-
recall.extend([v in pred_cat[target_key] for v in target_value])
|
142
|
-
else:
|
143
|
-
recall.append(target_value == pred_cat[target_key])
|
144
|
-
else:
|
145
|
-
if isinstance(target, dict):
|
146
|
-
target = list(target.values())
|
147
|
-
recall = [not yi or (yi in pred_cat) for yi in target]
|
148
|
-
target_types = [type(yi) for yi in target]
|
149
|
-
precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
|
150
|
-
return precision, recall
|
151
|
-
|
152
|
-
def is_matching(self, pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> bool:
|
153
|
-
"""
|
154
|
-
:param pred_cat: The predicted category.
|
155
|
-
:param target: The target category.
|
156
|
-
:return: Whether the classifier is matching or not.
|
157
|
-
"""
|
158
|
-
precision, recall = self.calculate_precision_and_recall(pred_cat, target)
|
159
|
-
return all(recall) and all(precision)
|
160
|
-
|
161
127
|
def update_figures(self):
|
162
128
|
"""
|
163
129
|
Update the figures of the classifier.
|
@@ -183,33 +149,51 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
183
149
|
"""
|
184
150
|
return hasattr(case, conclusion_name) and getattr(case, conclusion_name) is not None
|
185
151
|
|
152
|
+
@property
|
153
|
+
def type_(self):
|
154
|
+
return self.__class__
|
155
|
+
|
186
156
|
|
187
157
|
class RDRWithCodeWriter(RippleDownRules, ABC):
|
188
158
|
|
189
159
|
@abstractmethod
|
190
|
-
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = ""
|
160
|
+
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
|
161
|
+
defs_file: Optional[str] = None):
|
191
162
|
"""
|
192
163
|
Write the rules as source code to a file.
|
193
164
|
|
194
165
|
:param rule: The rule to write as source code.
|
195
166
|
:param file: The file to write the source code to.
|
196
167
|
:param parent_indent: The indentation of the parent rule.
|
168
|
+
:param defs_file: The file to write the definitions to.
|
197
169
|
"""
|
198
170
|
pass
|
199
171
|
|
200
|
-
def write_to_python_file(self, file_path: str):
|
172
|
+
def write_to_python_file(self, file_path: str, postfix: str = ""):
|
201
173
|
"""
|
202
174
|
Write the tree of rules as source code to a file.
|
203
175
|
|
204
176
|
:param file_path: The path to the file to write the source code to.
|
177
|
+
:param postfix: The postfix to add to the file name.
|
205
178
|
"""
|
179
|
+
self.generated_python_file_name = self._default_generated_python_file_name + postfix
|
206
180
|
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
207
|
-
|
208
|
-
|
181
|
+
file_name = file_path + f"/{self.generated_python_file_name}.py"
|
182
|
+
defs_file_name = file_path + f"/{self.generated_python_defs_file_name}.py"
|
183
|
+
imports = self._get_imports()
|
184
|
+
# clear the files first
|
185
|
+
with open(defs_file_name, "w") as f:
|
186
|
+
f.write(imports + "\n\n")
|
187
|
+
with open(file_name, "w") as f:
|
188
|
+
imports += f"from .{self.generated_python_defs_file_name} import *\n"
|
189
|
+
imports += f"from ripple_down_rules.rdr import {self.__class__.__name__}\n"
|
190
|
+
f.write(imports + "\n\n")
|
191
|
+
f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n\n")
|
192
|
+
f.write(f"type_ = {self.__class__.__name__}\n\n")
|
209
193
|
f.write(func_def)
|
210
194
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
211
195
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
212
|
-
self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)
|
196
|
+
self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4, defs_file=defs_file_name)
|
213
197
|
|
214
198
|
@property
|
215
199
|
@abstractmethod
|
@@ -226,26 +210,72 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
226
210
|
imports = ""
|
227
211
|
if self.case_type.__module__ != "builtins":
|
228
212
|
imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
|
229
|
-
|
230
|
-
|
231
|
-
|
213
|
+
for conclusion_type in self.conclusion_type:
|
214
|
+
if conclusion_type.__module__ != "builtins":
|
215
|
+
imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
|
216
|
+
imports += "from ripple_down_rules.datastructures.case import Case, create_case\n"
|
232
217
|
for rule in [self.start_rule] + list(self.start_rule.descendants):
|
233
|
-
if rule.conditions:
|
234
|
-
|
235
|
-
|
236
|
-
|
218
|
+
if not rule.conditions:
|
219
|
+
continue
|
220
|
+
if rule.conditions.scope is None or len(rule.conditions.scope) == 0:
|
221
|
+
continue
|
222
|
+
for k, v in rule.conditions.scope.items():
|
223
|
+
new_imports = f"from {v.__module__} import {v.__name__}\n"
|
224
|
+
if new_imports in imports:
|
225
|
+
continue
|
226
|
+
imports += new_imports
|
237
227
|
return imports
|
238
228
|
|
239
|
-
def get_rdr_classifier_from_python_file(self, package_name) -> Callable[[Any], Any]:
|
229
|
+
def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
|
240
230
|
"""
|
241
231
|
:param package_name: The name of the package that contains the RDR classifier function.
|
242
232
|
:return: The module that contains the rdr classifier function.
|
243
233
|
"""
|
244
|
-
|
234
|
+
# remove from imports if exists first
|
235
|
+
name = f"{package_name.strip('./')}.{self.generated_python_file_name}"
|
236
|
+
try:
|
237
|
+
module = importlib.import_module(name)
|
238
|
+
del sys.modules[name]
|
239
|
+
except ModuleNotFoundError:
|
240
|
+
pass
|
241
|
+
return importlib.import_module(name).classify
|
245
242
|
|
246
243
|
@property
|
247
244
|
def generated_python_file_name(self) -> str:
|
248
|
-
|
245
|
+
if self._generated_python_file_name is None:
|
246
|
+
self._generated_python_file_name = self._default_generated_python_file_name
|
247
|
+
return self._generated_python_file_name
|
248
|
+
|
249
|
+
@generated_python_file_name.setter
|
250
|
+
def generated_python_file_name(self, value: str):
|
251
|
+
"""
|
252
|
+
Set the generated python file name.
|
253
|
+
:param value: The new value for the generated python file name.
|
254
|
+
"""
|
255
|
+
self._generated_python_file_name = value
|
256
|
+
|
257
|
+
@property
|
258
|
+
def _default_generated_python_file_name(self) -> str:
|
259
|
+
"""
|
260
|
+
:return: The default generated python file name.
|
261
|
+
"""
|
262
|
+
return f"{self.start_rule.corner_case._name.lower()}_{self.attribute_name}_{self.acronym.lower()}"
|
263
|
+
|
264
|
+
@property
|
265
|
+
def generated_python_defs_file_name(self) -> str:
|
266
|
+
return f"{self.generated_python_file_name}_defs"
|
267
|
+
|
268
|
+
@property
|
269
|
+
def acronym(self) -> str:
|
270
|
+
"""
|
271
|
+
:return: The acronym of the classifier.
|
272
|
+
"""
|
273
|
+
if self.__class__.__name__ == "GeneralRDR":
|
274
|
+
return "GRDR"
|
275
|
+
elif self.__class__.__name__ == "MultiClassRDR":
|
276
|
+
return "MCRDR"
|
277
|
+
else:
|
278
|
+
return "SCRDR"
|
249
279
|
|
250
280
|
@property
|
251
281
|
def case_type(self) -> Type:
|
@@ -258,16 +288,17 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
258
288
|
return type(self.start_rule.corner_case)
|
259
289
|
|
260
290
|
@property
|
261
|
-
def conclusion_type(self) -> Type:
|
291
|
+
def conclusion_type(self) -> Tuple[Type]:
|
262
292
|
"""
|
263
293
|
:return: The type of the conclusion of the RDR classifier.
|
264
294
|
"""
|
265
295
|
if isinstance(self.start_rule.conclusion, CallableExpression):
|
266
296
|
return self.start_rule.conclusion.conclusion_type
|
267
297
|
else:
|
268
|
-
|
269
|
-
|
270
|
-
|
298
|
+
conclusion = self.start_rule.conclusion
|
299
|
+
if isinstance(conclusion, set):
|
300
|
+
return type(list(conclusion)[0]), set
|
301
|
+
return (type(conclusion),)
|
271
302
|
|
272
303
|
@property
|
273
304
|
def attribute_name(self) -> str:
|
@@ -279,8 +310,16 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
279
310
|
|
280
311
|
class SingleClassRDR(RDRWithCodeWriter):
|
281
312
|
|
313
|
+
def __init__(self, start_rule: Optional[SingleClassRule] = None, default_conclusion: Optional[Any] = None):
|
314
|
+
"""
|
315
|
+
:param start_rule: The starting rule for the classifier.
|
316
|
+
:param default_conclusion: The default conclusion for the classifier if no rules fire.
|
317
|
+
"""
|
318
|
+
super(SingleClassRDR, self).__init__(start_rule)
|
319
|
+
self.default_conclusion: Optional[Any] = default_conclusion
|
320
|
+
|
282
321
|
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
283
|
-
-> Union[CaseAttribute, CallableExpression]:
|
322
|
+
-> Union[CaseAttribute, CallableExpression, None]:
|
284
323
|
"""
|
285
324
|
Classify a case, and ask the user for refinements or alternatives if the classification is incorrect by
|
286
325
|
comparing the case with the target category if provided.
|
@@ -289,28 +328,31 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
289
328
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
290
329
|
:return: The category that the case belongs to.
|
291
330
|
"""
|
292
|
-
expert = expert if expert else Human(
|
293
|
-
if case_query.
|
294
|
-
|
331
|
+
expert = expert if expert else Human()
|
332
|
+
if case_query.default_value is not None and self.default_conclusion != case_query.default_value:
|
333
|
+
self.default_conclusion = case_query.default_value
|
334
|
+
case = case_query.case
|
335
|
+
target = expert.ask_for_conclusion(case_query) if case_query.target is None else case_query.target
|
336
|
+
if target is None:
|
337
|
+
return self.classify(case)
|
295
338
|
if not self.start_rule:
|
296
339
|
conditions = expert.ask_for_conditions(case_query)
|
297
|
-
self.start_rule = SingleClassRule(conditions,
|
340
|
+
self.start_rule = SingleClassRule(conditions, target, corner_case=case,
|
298
341
|
conclusion_name=case_query.attribute_name)
|
299
342
|
|
300
343
|
pred = self.evaluate(case_query.case)
|
301
|
-
|
302
|
-
if pred.conclusion != case_query.target:
|
344
|
+
if pred.conclusion(case) != target(case):
|
303
345
|
conditions = expert.ask_for_conditions(case_query, pred)
|
304
|
-
pred.fit_rule(case_query.case,
|
346
|
+
pred.fit_rule(case_query.case, target, conditions=conditions)
|
305
347
|
|
306
348
|
return self.classify(case_query.case)
|
307
349
|
|
308
|
-
def classify(self, case: Case) -> Optional[
|
350
|
+
def classify(self, case: Case) -> Optional[Any]:
|
309
351
|
"""
|
310
352
|
Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
|
311
353
|
"""
|
312
354
|
pred = self.evaluate(case)
|
313
|
-
return pred.conclusion if pred.fired else
|
355
|
+
return pred.conclusion(case) if pred.fired else self.default_conclusion
|
314
356
|
|
315
357
|
def evaluate(self, case: Case) -> SingleClassRule:
|
316
358
|
"""
|
@@ -319,23 +361,32 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
319
361
|
matched_rule = self.start_rule(case)
|
320
362
|
return matched_rule if matched_rule else self.start_rule
|
321
363
|
|
322
|
-
def
|
364
|
+
def write_to_python_file(self, file_path: str, postfix: str = ""):
|
365
|
+
super().write_to_python_file(file_path, postfix)
|
366
|
+
if self.default_conclusion is not None:
|
367
|
+
with open(file_path + f"/{self.generated_python_file_name}.py", "a") as f:
|
368
|
+
f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
|
369
|
+
|
370
|
+
def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
|
371
|
+
defs_file: Optional[str] = None):
|
323
372
|
"""
|
324
373
|
Write the rules as source code to a file.
|
325
374
|
"""
|
326
375
|
if rule.conditions:
|
327
|
-
|
376
|
+
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
377
|
+
file.write(if_clause)
|
328
378
|
if rule.refinement:
|
329
|
-
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " "
|
379
|
+
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
|
380
|
+
defs_file=defs_file)
|
330
381
|
|
331
382
|
file.write(rule.write_conclusion_as_source_code(parent_indent))
|
332
383
|
|
333
384
|
if rule.alternative:
|
334
|
-
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
|
385
|
+
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
|
335
386
|
|
336
387
|
@property
|
337
388
|
def conclusion_type_hint(self) -> str:
|
338
|
-
return self.conclusion_type.__name__
|
389
|
+
return self.conclusion_type[0].__name__
|
339
390
|
|
340
391
|
def _to_json(self) -> Dict[str, Any]:
|
341
392
|
return {"start_rule": self.start_rule.to_json()}
|
@@ -369,28 +420,27 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
369
420
|
"""
|
370
421
|
|
371
422
|
def __init__(self, start_rule: Optional[Rule] = None,
|
372
|
-
mode: MCRDRMode = MCRDRMode.StopOnly
|
423
|
+
mode: MCRDRMode = MCRDRMode.StopOnly):
|
373
424
|
"""
|
374
425
|
:param start_rule: The starting rules for the classifier.
|
375
426
|
:param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
|
376
|
-
:param session: The sqlalchemy orm session.
|
377
427
|
"""
|
378
428
|
start_rule = MultiClassTopRule() if not start_rule else start_rule
|
379
|
-
super(MultiClassRDR, self).__init__(start_rule
|
429
|
+
super(MultiClassRDR, self).__init__(start_rule)
|
380
430
|
self.mode: MCRDRMode = mode
|
381
431
|
|
382
|
-
def classify(self, case: Union[Case, SQLTable]) ->
|
432
|
+
def classify(self, case: Union[Case, SQLTable]) -> Set[Any]:
|
383
433
|
evaluated_rule = self.start_rule
|
384
434
|
self.conclusions = []
|
385
435
|
while evaluated_rule:
|
386
436
|
next_rule = evaluated_rule(case)
|
387
437
|
if evaluated_rule.fired:
|
388
|
-
self.add_conclusion(evaluated_rule)
|
438
|
+
self.add_conclusion(evaluated_rule, case)
|
389
439
|
evaluated_rule = next_rule
|
390
|
-
return self.conclusions
|
440
|
+
return make_set(self.conclusions)
|
391
441
|
|
392
442
|
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None,
|
393
|
-
add_extra_conclusions: bool = False) ->
|
443
|
+
add_extra_conclusions: bool = False) -> Set[Union[CaseAttribute, CallableExpression, None]]:
|
394
444
|
"""
|
395
445
|
Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
|
396
446
|
or missing by comparing the case with the target category if provided.
|
@@ -400,72 +450,76 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
400
450
|
:param add_extra_conclusions: Whether to add extra conclusions after classification is done.
|
401
451
|
:return: The conclusions that the case belongs to.
|
402
452
|
"""
|
403
|
-
expert = expert if expert else Human(
|
453
|
+
expert = expert if expert else Human()
|
454
|
+
if case_query.target is None:
|
455
|
+
expert.ask_for_conclusion(case_query)
|
404
456
|
if case_query.target is None:
|
405
|
-
|
457
|
+
return self.classify(case_query.case)
|
458
|
+
self.update_start_rule(case_query, expert)
|
406
459
|
self.expert_accepted_conclusions = []
|
407
460
|
user_conclusions = []
|
408
|
-
self.update_start_rule(case_query, expert)
|
409
461
|
self.conclusions = []
|
410
462
|
self.stop_rule_conditions = None
|
411
463
|
evaluated_rule = self.start_rule
|
464
|
+
target = case_query.target(case_query.case)
|
412
465
|
while evaluated_rule:
|
413
466
|
next_rule = evaluated_rule(case_query.case)
|
414
|
-
|
467
|
+
rule_conclusion = evaluated_rule.conclusion(case_query.case)
|
468
|
+
good_conclusions = make_list(target) + user_conclusions + self.expert_accepted_conclusions
|
415
469
|
good_conclusions = make_set(good_conclusions)
|
416
470
|
|
417
471
|
if evaluated_rule.fired:
|
418
|
-
if
|
419
|
-
# if self.case_has_conclusion(case, evaluated_rule.conclusion):
|
472
|
+
if target and not make_set(rule_conclusion).issubset(good_conclusions):
|
420
473
|
# Rule fired and conclusion is different from target
|
421
474
|
self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
|
422
|
-
|
475
|
+
len(user_conclusions) > 0)
|
423
476
|
else:
|
424
477
|
# Rule fired and target is correct or there is no target to compare
|
425
|
-
self.add_conclusion(evaluated_rule)
|
478
|
+
self.add_conclusion(evaluated_rule, case_query.case)
|
426
479
|
|
427
480
|
if not next_rule:
|
428
|
-
if not make_set(
|
481
|
+
if not make_set(target).issubset(make_set(self.conclusions)):
|
429
482
|
# Nothing fired and there is a target that should have been in the conclusions
|
430
483
|
self.add_rule_for_case(case_query, expert)
|
431
484
|
# Have to check all rules again to make sure only this new rule fires
|
432
485
|
next_rule = self.start_rule
|
433
|
-
elif add_extra_conclusions
|
486
|
+
elif add_extra_conclusions:
|
434
487
|
# No more conclusions can be made, ask the expert for extra conclusions if needed.
|
435
|
-
|
436
|
-
|
488
|
+
new_user_conclusions = self.ask_expert_for_extra_rules(expert, case_query)
|
489
|
+
user_conclusions.extend(new_user_conclusions)
|
490
|
+
if len(new_user_conclusions) > 0:
|
437
491
|
next_rule = self.last_top_rule
|
492
|
+
else:
|
493
|
+
add_extra_conclusions = False
|
438
494
|
evaluated_rule = next_rule
|
439
495
|
return self.conclusions
|
440
496
|
|
441
497
|
def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
|
442
|
-
file, parent_indent: str = ""):
|
443
|
-
"""
|
444
|
-
Write the rules as source code to a file.
|
445
|
-
|
446
|
-
:
|
447
|
-
"""
|
498
|
+
file, parent_indent: str = "", defs_file: Optional[str] = None):
|
448
499
|
if rule == self.start_rule:
|
449
500
|
file.write(f"{parent_indent}conclusions = set()\n")
|
450
501
|
if rule.conditions:
|
451
|
-
|
502
|
+
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
503
|
+
file.write(if_clause)
|
452
504
|
conclusion_indent = parent_indent
|
453
505
|
if hasattr(rule, "refinement") and rule.refinement:
|
454
|
-
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " "
|
506
|
+
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
|
507
|
+
defs_file=defs_file)
|
455
508
|
conclusion_indent = parent_indent + " " * 4
|
456
509
|
file.write(f"{conclusion_indent}else:\n")
|
457
510
|
file.write(rule.write_conclusion_as_source_code(conclusion_indent))
|
458
511
|
|
459
512
|
if rule.alternative:
|
460
|
-
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
|
513
|
+
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
|
461
514
|
|
462
515
|
@property
|
463
516
|
def conclusion_type_hint(self) -> str:
|
464
|
-
return f"Set[{self.conclusion_type.__name__}]"
|
517
|
+
return f"Set[{self.conclusion_type[0].__name__}]"
|
465
518
|
|
466
519
|
def _get_imports(self) -> str:
|
467
520
|
imports = super()._get_imports()
|
468
521
|
imports += "from typing_extensions import Set\n"
|
522
|
+
imports += "from ripple_down_rules.utils import make_set\n"
|
469
523
|
return imports
|
470
524
|
|
471
525
|
def update_start_rule(self, case_query: CaseQuery, expert: Expert):
|
@@ -498,11 +552,12 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
498
552
|
"""
|
499
553
|
Stop a wrong conclusion by adding a stopping rule.
|
500
554
|
"""
|
501
|
-
|
502
|
-
|
503
|
-
self.
|
504
|
-
|
505
|
-
|
555
|
+
rule_conclusion = evaluated_rule.conclusion(case_query.case)
|
556
|
+
if is_conflicting(rule_conclusion, case_query.target_value):
|
557
|
+
if self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
|
558
|
+
return
|
559
|
+
else:
|
560
|
+
self.stop_conclusion(case_query, expert, evaluated_rule)
|
506
561
|
|
507
562
|
def stop_conclusion(self, case_query: CaseQuery,
|
508
563
|
expert: Expert, evaluated_rule: MultiClassTopRule):
|
@@ -521,31 +576,6 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
521
576
|
new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
|
522
577
|
self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
|
523
578
|
|
524
|
-
@staticmethod
|
525
|
-
def is_conflicting_with_target(conclusion: Any, target: Any) -> bool:
|
526
|
-
"""
|
527
|
-
Check if the conclusion is conflicting with the target category.
|
528
|
-
|
529
|
-
:param conclusion: The conclusion to check.
|
530
|
-
:param target: The target category to compare the conclusion with.
|
531
|
-
:return: Whether the conclusion is conflicting with the target category.
|
532
|
-
"""
|
533
|
-
if hasattr(conclusion, "mutually_exclusive") and conclusion.mutually_exclusive:
|
534
|
-
return True
|
535
|
-
else:
|
536
|
-
return not make_set(conclusion).issubset(make_set(target))
|
537
|
-
|
538
|
-
@staticmethod
|
539
|
-
def is_same_category_type(conclusion: Any, target: Any) -> bool:
|
540
|
-
"""
|
541
|
-
Check if the conclusion is of the same class as the target category.
|
542
|
-
|
543
|
-
:param conclusion: The conclusion to check.
|
544
|
-
:param target: The target category to compare the conclusion with.
|
545
|
-
:return: Whether the conclusion is of the same class as the target category but has a different value.
|
546
|
-
"""
|
547
|
-
return conclusion.__class__ == target.__class__ and target.__class__ != CaseAttribute
|
548
|
-
|
549
579
|
def conclusion_is_correct(self, case_query: CaseQuery,
|
550
580
|
expert: Expert, evaluated_rule: Rule,
|
551
581
|
add_extra_conclusions: bool) -> bool:
|
@@ -559,10 +589,10 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
559
589
|
:return: Whether the conclusion is correct or not.
|
560
590
|
"""
|
561
591
|
conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
|
562
|
-
if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
|
563
|
-
|
592
|
+
if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
|
593
|
+
evaluated_rule.conclusion(case_query.case),
|
564
594
|
current_conclusions=conclusions)):
|
565
|
-
self.add_conclusion(evaluated_rule)
|
595
|
+
self.add_conclusion(evaluated_rule, case_query.case)
|
566
596
|
self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
|
567
597
|
return True
|
568
598
|
return False
|
@@ -581,38 +611,39 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
581
611
|
conditions = expert.ask_for_conditions(case_query)
|
582
612
|
self.add_top_rule(conditions, case_query.target, case_query.case)
|
583
613
|
|
584
|
-
def
|
614
|
+
def ask_expert_for_extra_rules(self, expert: Expert, case_query: CaseQuery) -> List[Any]:
|
585
615
|
"""
|
586
|
-
Ask the expert for extra
|
616
|
+
Ask the expert for extra rules when no more conclusions can be made for a case.
|
587
617
|
|
588
618
|
:param expert: The expert to ask for extra conclusions.
|
589
|
-
:param
|
590
|
-
:return: The extra conclusions that the expert has provided.
|
619
|
+
:param case_query: The case query to ask the expert about.
|
620
|
+
:return: The extra conclusions for the rules that the expert has provided.
|
591
621
|
"""
|
592
622
|
extra_conclusions = []
|
593
623
|
conclusions = list(OrderedSet(self.conclusions))
|
594
624
|
if not expert.use_loaded_answers:
|
595
625
|
print("current conclusions:", conclusions)
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
extra_conclusions.append(conclusion)
|
626
|
+
extra_rules = expert.ask_for_extra_rules(case_query)
|
627
|
+
for rule in extra_rules:
|
628
|
+
self.add_top_rule(rule[PromptFor.Conditions], rule[PromptFor.Conclusion], case_query.case)
|
629
|
+
extra_conclusions.extend(rule[PromptFor.Conclusion](case_query.case))
|
601
630
|
return extra_conclusions
|
602
631
|
|
603
|
-
def add_conclusion(self, evaluated_rule: Rule) -> None:
|
632
|
+
def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
|
604
633
|
"""
|
605
634
|
Add the conclusion of the evaluated rule to the list of conclusions.
|
606
635
|
|
607
636
|
:param evaluated_rule: The evaluated rule to add the conclusion of.
|
637
|
+
:param case: The case to add the conclusion for.
|
608
638
|
"""
|
609
639
|
conclusion_types = [type(c) for c in self.conclusions]
|
610
|
-
|
611
|
-
|
640
|
+
rule_conclusion = evaluated_rule.conclusion(case)
|
641
|
+
if type(rule_conclusion) not in conclusion_types:
|
642
|
+
self.conclusions.extend(make_list(rule_conclusion))
|
612
643
|
else:
|
613
|
-
same_type_conclusions = [c for c in self.conclusions if type(c) == type(
|
614
|
-
combined_conclusion =
|
615
|
-
else {
|
644
|
+
same_type_conclusions = [c for c in self.conclusions if type(c) == type(rule_conclusion)]
|
645
|
+
combined_conclusion = rule_conclusion if isinstance(rule_conclusion, set) \
|
646
|
+
else {rule_conclusion}
|
616
647
|
combined_conclusion = copy(combined_conclusion)
|
617
648
|
for c in same_type_conclusions:
|
618
649
|
combined_conclusion.update(c if isinstance(c, set) else make_set(c))
|
@@ -713,6 +744,7 @@ class GeneralRDR(RippleDownRules):
|
|
713
744
|
:return: The categories that the case belongs to.
|
714
745
|
"""
|
715
746
|
conclusions = {}
|
747
|
+
case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
|
716
748
|
case_cp = copy_case(case)
|
717
749
|
while True:
|
718
750
|
new_conclusions = {}
|
@@ -720,22 +752,24 @@ class GeneralRDR(RippleDownRules):
|
|
720
752
|
pred_atts = rdr.classify(case_cp)
|
721
753
|
if pred_atts is None:
|
722
754
|
continue
|
723
|
-
if
|
755
|
+
if rdr.type_ is SingleClassRDR:
|
724
756
|
if attribute_name not in conclusions or \
|
725
757
|
(attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
|
726
758
|
conclusions[attribute_name] = pred_atts
|
727
759
|
new_conclusions[attribute_name] = pred_atts
|
728
760
|
else:
|
729
|
-
pred_atts =
|
761
|
+
pred_atts = make_set(pred_atts)
|
730
762
|
if attribute_name in conclusions:
|
731
|
-
pred_atts =
|
763
|
+
pred_atts = {p for p in pred_atts if p not in conclusions[attribute_name]}
|
732
764
|
if len(pred_atts) > 0:
|
733
765
|
new_conclusions[attribute_name] = pred_atts
|
734
766
|
if attribute_name not in conclusions:
|
735
|
-
conclusions[attribute_name] =
|
736
|
-
conclusions[attribute_name].
|
767
|
+
conclusions[attribute_name] = set()
|
768
|
+
conclusions[attribute_name].update(pred_atts)
|
737
769
|
if attribute_name in new_conclusions:
|
738
|
-
|
770
|
+
mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
|
771
|
+
GeneralRDR.update_case(CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive),
|
772
|
+
new_conclusions)
|
739
773
|
if len(new_conclusions) == 0:
|
740
774
|
break
|
741
775
|
return conclusions
|
@@ -761,76 +795,103 @@ class GeneralRDR(RippleDownRules):
|
|
761
795
|
case = case_queries[0].case
|
762
796
|
assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
|
763
797
|
" for multiple cases use fit instead")
|
764
|
-
|
798
|
+
original_case_query_cp = copy(case_queries[0])
|
765
799
|
for case_query in case_queries:
|
766
800
|
case_query_cp = copy(case_query)
|
767
|
-
case_query_cp.case =
|
768
|
-
if
|
801
|
+
case_query_cp.case = original_case_query_cp.case
|
802
|
+
if case_query_cp.target is None:
|
769
803
|
conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
|
770
|
-
|
804
|
+
self.update_case(case_query_cp, conclusions)
|
805
|
+
expert.ask_for_conclusion(case_query_cp)
|
806
|
+
if case_query_cp.target is None:
|
807
|
+
continue
|
808
|
+
case_query.target = case_query_cp.target
|
771
809
|
|
772
810
|
if case_query.attribute_name not in self.start_rules_dict:
|
773
811
|
conclusions = self.classify(case)
|
774
|
-
self.update_case(
|
812
|
+
self.update_case(case_query_cp, conclusions)
|
775
813
|
|
776
|
-
new_rdr = self.initialize_new_rdr_for_attribute(
|
814
|
+
new_rdr = self.initialize_new_rdr_for_attribute(case_query_cp)
|
777
815
|
self.add_rdr(new_rdr, case_query.attribute_name)
|
778
816
|
|
779
817
|
new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
|
780
|
-
self.update_case(
|
818
|
+
self.update_case(case_query_cp, {case_query.attribute_name: new_conclusions})
|
781
819
|
else:
|
782
820
|
for rdr_attribute_name, rdr in self.start_rules_dict.items():
|
783
821
|
if case_query.attribute_name != rdr_attribute_name:
|
784
|
-
conclusions = rdr.classify(
|
822
|
+
conclusions = rdr.classify(case_query_cp.case)
|
785
823
|
else:
|
786
824
|
conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
|
787
825
|
**kwargs)
|
788
826
|
if conclusions is not None or (is_iterable(conclusions) and len(conclusions) > 0):
|
789
827
|
conclusions = {rdr_attribute_name: conclusions}
|
790
|
-
|
828
|
+
case_query_cp.mutually_exclusive = True if isinstance(rdr, SingleClassRDR) else False
|
829
|
+
self.update_case(case_query_cp, conclusions)
|
830
|
+
case_query.conditions = case_query_cp.conditions
|
791
831
|
|
792
832
|
return self.classify(case)
|
793
833
|
|
794
834
|
@staticmethod
|
795
|
-
def initialize_new_rdr_for_attribute(
|
835
|
+
def initialize_new_rdr_for_attribute(case_query: CaseQuery):
|
796
836
|
"""
|
797
837
|
Initialize the appropriate RDR type for the target.
|
798
838
|
"""
|
799
|
-
|
839
|
+
if case_query.mutually_exclusive is not None:
|
840
|
+
return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive \
|
841
|
+
else MultiClassRDR()
|
842
|
+
if case_query.attribute_type in [list, set]:
|
843
|
+
return MultiClassRDR()
|
844
|
+
attribute = getattr(case_query.case, case_query.attribute_name) \
|
845
|
+
if hasattr(case_query.case, case_query.attribute_name) else case_query.target(case_query.case)
|
800
846
|
if isinstance(attribute, CaseAttribute):
|
801
|
-
return SingleClassRDR() if attribute.mutually_exclusive
|
847
|
+
return SingleClassRDR(default_conclusion=case_query.default_value) if attribute.mutually_exclusive \
|
848
|
+
else MultiClassRDR()
|
802
849
|
else:
|
803
|
-
return MultiClassRDR() if is_iterable(attribute) or (attribute is None)
|
850
|
+
return MultiClassRDR() if is_iterable(attribute) or (attribute is None) \
|
851
|
+
else SingleClassRDR(default_conclusion=case_query.default_value)
|
804
852
|
|
805
853
|
@staticmethod
|
806
|
-
def update_case(
|
854
|
+
def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
|
807
855
|
"""
|
808
856
|
Update the case with the conclusions.
|
809
857
|
|
810
|
-
:param
|
858
|
+
:param case_query: The case query that contains the case to update.
|
811
859
|
:param conclusions: The conclusions to update the case with.
|
812
860
|
"""
|
813
861
|
if not conclusions:
|
814
862
|
return
|
815
863
|
if len(conclusions) == 0:
|
816
864
|
return
|
817
|
-
if isinstance(
|
865
|
+
if isinstance(case_query.original_case, SQLTable) or is_dataclass(case_query.original_case):
|
818
866
|
for conclusion_name, conclusion in conclusions.items():
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
867
|
+
attribute = getattr(case_query.case, conclusion_name)
|
868
|
+
if conclusion_name == case_query.attribute_name:
|
869
|
+
attribute_type = case_query.attribute_type
|
870
|
+
else:
|
871
|
+
attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
|
872
|
+
if isinstance(attribute, set):
|
823
873
|
for c in conclusion:
|
824
874
|
attribute.update(make_set(c))
|
825
|
-
elif isinstance(attribute, list)
|
875
|
+
elif isinstance(attribute, list):
|
876
|
+
attribute.extend(conclusion)
|
877
|
+
elif any(at in {List, list} for at in attribute_type):
|
826
878
|
attribute = [] if attribute is None else attribute
|
827
879
|
attribute.extend(conclusion)
|
828
|
-
elif (
|
829
|
-
|
880
|
+
elif any(at in {Set, set} for at in attribute_type):
|
881
|
+
attribute = set() if attribute is None else attribute
|
882
|
+
for c in conclusion:
|
883
|
+
attribute.update(make_set(c))
|
884
|
+
elif is_iterable(conclusion) and len(conclusion) == 1 \
|
885
|
+
and any(at is type(list(conclusion)[0]) for at in attribute_type):
|
886
|
+
setattr(case_query.case, conclusion_name, list(conclusion)[0])
|
887
|
+
elif not is_iterable(conclusion) and any(at is type(conclusion) for at in attribute_type):
|
888
|
+
setattr(case_query.case, conclusion_name, conclusion)
|
830
889
|
else:
|
831
|
-
raise ValueError(f"
|
890
|
+
raise ValueError(f"Unknown type or type mismatch for attribute {conclusion_name} with type "
|
891
|
+
f"{case_query.attribute_type} with conclusion "
|
892
|
+
f"{conclusion} of type {type(conclusion)}")
|
832
893
|
else:
|
833
|
-
case.update(conclusions)
|
894
|
+
case_query.case.update(conclusions)
|
834
895
|
|
835
896
|
def _to_json(self) -> Dict[str, Any]:
|
836
897
|
return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
|
@@ -845,14 +906,16 @@ class GeneralRDR(RippleDownRules):
|
|
845
906
|
start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
|
846
907
|
return cls(start_rules_dict)
|
847
908
|
|
848
|
-
def write_to_python_file(self, file_path: str):
|
909
|
+
def write_to_python_file(self, file_path: str, postfix: str = "") -> None:
|
849
910
|
"""
|
850
911
|
Write the tree of rules as source code to a file.
|
851
912
|
|
852
913
|
:param file_path: The path to the file to write the source code to.
|
914
|
+
:param postfix: The postfix to add to the file name.
|
853
915
|
"""
|
916
|
+
self.generated_python_file_name = self._default_generated_python_file_name + postfix
|
854
917
|
for rdr in self.start_rules_dict.values():
|
855
|
-
rdr.write_to_python_file(file_path)
|
918
|
+
rdr.write_to_python_file(file_path, postfix=f"_of_grdr{postfix}")
|
856
919
|
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
857
920
|
with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
|
858
921
|
f.write(self._get_imports(file_path) + "\n\n")
|
@@ -875,15 +938,29 @@ class GeneralRDR(RippleDownRules):
|
|
875
938
|
else:
|
876
939
|
return type(self.start_rule.corner_case)
|
877
940
|
|
878
|
-
def get_rdr_classifier_from_python_file(self, file_path: str):
|
941
|
+
def get_rdr_classifier_from_python_file(self, file_path: str) -> Callable[[Any], Any]:
|
879
942
|
"""
|
880
943
|
:param file_path: The path to the file that contains the RDR classifier function.
|
944
|
+
:param postfix: The postfix to add to the file name.
|
881
945
|
:return: The module that contains the rdr classifier function.
|
882
946
|
"""
|
883
947
|
return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
|
884
948
|
|
885
949
|
@property
|
886
950
|
def generated_python_file_name(self) -> str:
|
951
|
+
if self._generated_python_file_name is None:
|
952
|
+
self._generated_python_file_name = self._default_generated_python_file_name
|
953
|
+
return self._generated_python_file_name
|
954
|
+
|
955
|
+
@generated_python_file_name.setter
|
956
|
+
def generated_python_file_name(self, value: str):
|
957
|
+
self._generated_python_file_name = value
|
958
|
+
|
959
|
+
@property
|
960
|
+
def _default_generated_python_file_name(self) -> str:
|
961
|
+
"""
|
962
|
+
:return: The default generated python file name.
|
963
|
+
"""
|
887
964
|
return f"{self.start_rule.corner_case._name.lower()}_rdr"
|
888
965
|
|
889
966
|
@property
|
@@ -891,17 +968,25 @@ class GeneralRDR(RippleDownRules):
|
|
891
968
|
return f"List[Union[{', '.join([rdr.conclusion_type_hint for rdr in self.start_rules_dict.values()])}]]"
|
892
969
|
|
893
970
|
def _get_imports(self, file_path: str) -> str:
|
971
|
+
"""
|
972
|
+
Get the imports needed for the generated python file.
|
973
|
+
|
974
|
+
:param file_path: The path to the file that contains the RDR classifier function.
|
975
|
+
:return: The imports needed for the generated python file.
|
976
|
+
"""
|
894
977
|
imports = ""
|
895
978
|
# add type hints
|
896
979
|
imports += f"from typing_extensions import List, Union, Set\n"
|
897
980
|
# import rdr type
|
898
981
|
imports += f"from ripple_down_rules.rdr import GeneralRDR\n"
|
899
982
|
# add case type
|
900
|
-
imports += f"from ripple_down_rules.datastructures import Case, create_case\n"
|
983
|
+
imports += f"from ripple_down_rules.datastructures.case import Case, create_case\n"
|
901
984
|
imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
|
902
985
|
# add conclusion type imports
|
903
986
|
for rdr in self.start_rules_dict.values():
|
904
|
-
|
987
|
+
for conclusion_type in rdr.conclusion_type:
|
988
|
+
if conclusion_type.__module__ != "builtins":
|
989
|
+
imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
|
905
990
|
# add rdr python generated functions.
|
906
991
|
for rdr_key, rdr in self.start_rules_dict.items():
|
907
992
|
imports += (f"from {file_path.strip('./')}"
|