ripple-down-rules 0.6.1__py3-none-any.whl → 0.6.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +21 -1
- ripple_down_rules/datastructures/callable_expression.py +24 -7
- ripple_down_rules/datastructures/case.py +12 -11
- ripple_down_rules/datastructures/dataclasses.py +135 -14
- ripple_down_rules/datastructures/enums.py +29 -86
- ripple_down_rules/datastructures/field_info.py +177 -0
- ripple_down_rules/datastructures/tracked_object.py +208 -0
- ripple_down_rules/experts.py +141 -50
- ripple_down_rules/failures.py +4 -0
- ripple_down_rules/helpers.py +75 -8
- ripple_down_rules/predicates.py +97 -0
- ripple_down_rules/rdr.py +712 -96
- ripple_down_rules/rdr_decorators.py +164 -112
- ripple_down_rules/rules.py +351 -114
- ripple_down_rules/user_interface/gui.py +66 -41
- ripple_down_rules/user_interface/ipython_custom_shell.py +46 -9
- ripple_down_rules/user_interface/prompt.py +80 -60
- ripple_down_rules/user_interface/template_file_creator.py +13 -8
- ripple_down_rules/utils.py +537 -53
- {ripple_down_rules-0.6.1.dist-info → ripple_down_rules-0.6.6.dist-info}/METADATA +4 -1
- ripple_down_rules-0.6.6.dist-info/RECORD +28 -0
- ripple_down_rules-0.6.1.dist-info/RECORD +0 -24
- {ripple_down_rules-0.6.1.dist-info → ripple_down_rules-0.6.6.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.6.1.dist-info → ripple_down_rules-0.6.6.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.6.1.dist-info → ripple_down_rules-0.6.6.dist-info}/top_level.txt +0 -0
ripple_down_rules/rdr.py
CHANGED
@@ -1,17 +1,24 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import ast
|
3
4
|
import importlib
|
5
|
+
import json
|
4
6
|
import os
|
7
|
+
import sys
|
5
8
|
from abc import ABC, abstractmethod
|
6
9
|
from copy import copy
|
7
|
-
from
|
10
|
+
from dataclasses import is_dataclass
|
11
|
+
from io import TextIOWrapper
|
12
|
+
from os.path import dirname
|
13
|
+
from pathlib import Path
|
14
|
+
from types import NoneType, ModuleType
|
8
15
|
|
9
16
|
from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
|
10
17
|
from . import logger
|
18
|
+
from .failures import RDRLoadError
|
11
19
|
|
12
20
|
try:
|
13
21
|
from matplotlib import pyplot as plt
|
14
|
-
|
15
22
|
Figure = plt.Figure
|
16
23
|
except ImportError as e:
|
17
24
|
logger.debug(f"{e}: matplotlib is not installed")
|
@@ -25,18 +32,22 @@ from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tupl
|
|
25
32
|
from .datastructures.callable_expression import CallableExpression
|
26
33
|
from .datastructures.case import Case, CaseAttribute, create_case
|
27
34
|
from .datastructures.dataclasses import CaseQuery
|
28
|
-
from .datastructures.enums import MCRDRMode
|
35
|
+
from .datastructures.enums import MCRDRMode, RDREdge
|
29
36
|
from .experts import Expert, Human
|
30
|
-
from .helpers import is_matching, general_rdr_classify
|
31
|
-
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
37
|
+
from .helpers import is_matching, general_rdr_classify, get_an_updated_case_copy
|
38
|
+
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule, MultiClassRefinementRule, \
|
39
|
+
MultiClassFilterRule
|
32
40
|
|
33
41
|
try:
|
34
42
|
from .user_interface.gui import RDRCaseViewer
|
35
43
|
except ImportError as e:
|
36
44
|
RDRCaseViewer = None
|
37
45
|
from .utils import draw_tree, make_set, SubclassJSONSerializer, make_list, get_type_from_string, \
|
38
|
-
|
39
|
-
is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types
|
46
|
+
is_value_conflicting, extract_function_source, extract_imports, get_full_class_name, \
|
47
|
+
is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types, render_tree, \
|
48
|
+
get_types_to_import_from_func_type_hints, get_function_return_type, get_file_that_ends_with, \
|
49
|
+
get_and_import_python_module, get_and_import_python_modules_in_a_package, get_type_from_type_hint, \
|
50
|
+
are_results_subclass_of_types
|
40
51
|
|
41
52
|
|
42
53
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -80,20 +91,98 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
80
91
|
Whether the output of the classification of this rdr allows only one possible conclusion or not.
|
81
92
|
"""
|
82
93
|
|
83
|
-
def __init__(self, start_rule: Optional[Rule] = None,
|
94
|
+
def __init__(self, start_rule: Optional[Rule] = None,
             save_dir: Optional[str] = None, model_name: Optional[str] = None):
    """
    Create the classifier and try to load a previously saved model (see ``update_model``).

    :param start_rule: The root rule of the classifier's rule tree, if one already exists.
    :param save_dir: The directory the classifier is saved to and loaded from.
    :param model_name: The name of the model inside ``save_dir``.
    """
    existing_viewer = None
    if RDRCaseViewer and any(RDRCaseViewer.instances):
        # Reuse the first viewer instance that was already created instead of owning a new one.
        existing_viewer = RDRCaseViewer.instances[0]

    self.model_name: Optional[str] = model_name
    self.save_dir: Optional[str] = save_dir
    self.start_rule: Optional[Rule] = start_rule
    self.fig: Optional[Figure] = None
    self.viewer: Optional[RDRCaseViewer] = existing_viewer
    # Node placed above the start rule by ``classify`` to describe the classified case.
    self.input_node: Optional[Rule] = None
    self.update_model()
|
108
|
+
|
109
|
+
def update_model(self):
    """
    Update the model from the model directory if it exists.

    Nothing happens unless both ``save_dir`` and ``model_name`` are set and the directory
    ``save_dir/model_name`` exists. In that case the rule tree is re-read from the generated python
    file. Loading is best-effort: a missing file or module is logged as a warning instead of raised.
    """
    if self.save_dir is None or self.model_name is None:
        return
    model_path = os.path.join(self.save_dir, self.model_name)
    if not os.path.exists(model_path):
        return
    try:
        self.update_from_python(model_path, update_rule_tree=True)
    except (FileNotFoundError, ModuleNotFoundError) as e:
        # Keep this best-effort, but do not hide why the saved model was not loaded.
        logger.warning(f"Could not update the model from {model_path}, the saved model was not loaded: {e}")
|
120
|
+
|
121
|
+
def write_rdr_metadata_to_pyton_file(self, file: TextIOWrapper):
    """
    Write the metadata of the RDR classifier to a python file as module-level assignments.

    Three lines are written: ``name``, ``case_type`` and ``case_name``. ``name`` and ``case_name`` are
    written with ``repr`` so that the result is always a valid Python literal. This holds even when the
    value contains a quote or is ``None``; previously ``None`` became the string ``'None'``.
    ``case_type`` is written as the bare class name (or ``None``). The generated file presumably
    imports that class elsewhere — TODO confirm.

    :param file: The file to write the metadata to.
    """
    case_type_name = self.case_type.__name__ if self.case_type is not None else None
    file.write(f"name = {self.name!r}\n")
    file.write(f"case_type = {case_type_name}\n")
    file.write(f"case_name = {self.case_name!r}\n")
|
130
|
+
|
131
|
+
def update_rdr_metadata_from_python(self, module: ModuleType):
    """
    Refresh ``name``, ``case_type`` and ``case_name`` from the module holding the generated classifier.

    Fallbacks:

    - If the module has no ``name``, the start rule's ``conclusion_name`` is used.
    - If the module has no ``case_name``, it is built as ``"<case type name>.<name>"``.

    A missing required attribute (e.g. ``case_type``) raises AttributeError, which is logged as a
    warning. Attributes assigned before the failure keep their new values.

    :param module: The module that contains the RDR classifier function.
    """
    try:
        if hasattr(module, "name"):
            self.name = module.name
        else:
            self.name = self.start_rule.conclusion_name

        self.case_type = module.case_type

        if hasattr(module, "case_name"):
            self.case_name = module.case_name
        else:
            self.case_name = f"{self.case_type.__name__}.{self.name}"
    except AttributeError as e:
        logger.warning(f"Could not update the RDR metadata from the module {module.__name__}. "
                       f"Make sure the module has the required attributes: {e}")
|
144
|
+
|
145
|
+
def render_evaluated_rule_tree(self, filename: str, show_full_tree: bool = False) -> None:
    """
    Render the rule tree to ``filename`` using ``render_tree`` with the dot exporter.

    :param filename: The file name to render the tree to.
    :param show_full_tree: If True, render the whole tree. The tree starts at the input node when one
        exists, otherwise at the start rule. If False, render only the rules currently marked as
        evaluated (see ``get_evaluated_rule_tree``), rooted at the first of them. Nothing is rendered
        when there are none.
    """
    if not show_full_tree:
        evaluated = self.get_evaluated_rule_tree()
        if evaluated:
            render_tree(evaluated[0], use_dot_exporter=True, filename=filename,
                        only_nodes=evaluated)
        return

    root = self.input_node if self.input_node is not None else self.start_rule
    render_tree(root, use_dot_exporter=True, filename=filename)
|
154
|
+
|
155
|
+
def get_contributing_rules(self) -> Optional[List[Rule]]:
    """
    Get the fired rules that contributed to the conclusion of the classifier.

    :return: The fired rules whose ``contributed`` flag is set, or None when there is no start rule.
    """
    if self.start_rule is None:
        return None
    fired_rules = self.get_fired_rule_tree()
    return list(filter(lambda rule: rule.contributed, fired_rules))
|
164
|
+
|
165
|
+
def get_fired_rule_tree(self) -> Optional[List[Rule]]:
    """
    Get the evaluated rules that also fired.

    :return: The rules from ``get_evaluated_rule_tree`` whose ``fired`` flag is set (in the same order),
        or None when there is no start rule.
    """
    if self.start_rule is None:
        return None
    evaluated_rules = self.get_evaluated_rule_tree()
    return [rule for rule in evaluated_rules if rule.fired]
|
174
|
+
|
175
|
+
def get_evaluated_rule_tree(self) -> Optional[List[Rule]]:
    """
    Get the rules of the tree that are marked as evaluated.

    :return: The start rule and its descendants whose ``evaluated`` flag is set, with the start rule
        first, or None when there is no start rule.
    """
    root = self.start_rule
    if root is None:
        return None
    candidates = [root, *root.descendants]
    return [rule for rule in candidates if rule.evaluated]
|
97
186
|
|
98
187
|
def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None,
|
99
188
|
package_name: Optional[str] = None) -> str:
|
@@ -137,18 +226,43 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
137
226
|
:param package_name: The name of the package that contains the RDR classifier function, this
|
138
227
|
is required in case of relative imports in the generated python file.
|
139
228
|
"""
|
229
|
+
rdr: Optional[RippleDownRules] = None
|
140
230
|
model_dir = os.path.join(load_dir, model_name)
|
141
231
|
json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
|
142
|
-
|
143
|
-
|
144
|
-
rdr.model_name = model_name
|
232
|
+
if os.path.exists(json_file + ".json"):
|
233
|
+
rdr = cls.from_json_file(json_file)
|
145
234
|
try:
|
146
|
-
rdr
|
235
|
+
if rdr is None:
|
236
|
+
rdr = cls.from_python(model_dir, parent_package_name=package_name)
|
237
|
+
else:
|
238
|
+
rdr.update_from_python(model_dir, parent_package_name=package_name)
|
147
239
|
rdr.to_json_file(json_file)
|
148
|
-
except (FileNotFoundError, ValueError, SyntaxError) as e:
|
240
|
+
except (FileNotFoundError, ValueError, SyntaxError, ModuleNotFoundError) as e:
|
149
241
|
logger.warning(f"Could not load the python file for the model {model_name} from {model_dir}. "
|
150
242
|
f"Make sure the file exists and is valid.")
|
243
|
+
if rdr is None:
|
244
|
+
raise RDRLoadError(f"Could not load the rdr model {model_name} from {model_dir}, error is {e}")
|
151
245
|
rdr.save(save_dir=load_dir, model_name=model_name, package_name=package_name)
|
246
|
+
rdr.save_dir = load_dir
|
247
|
+
rdr.model_name = model_name
|
248
|
+
return rdr
|
249
|
+
|
250
|
+
@classmethod
def from_python(cls, model_path: str,
                python_file_path: Optional[str] = None,
                parent_package_name: Optional[str] = None) -> Self:
    """
    Build a new classifier whose rule tree is read from a generated python file.

    :param model_path: The directory where the generated python file is located.
    :param python_file_path: The path to the generated python file that contains the RDR classifier function.
    :param parent_package_name: The name of the package that contains the RDR classifier function; this is
        required in case of relative imports in the generated python file.
    :return: A freshly constructed classifier updated from the python file, including its rule tree.
    """
    instance = cls()
    instance.update_from_python(model_path,
                                python_file_path=python_file_path,
                                parent_package_name=parent_package_name,
                                update_rule_tree=True)
    return instance
|
153
267
|
|
154
268
|
@abstractmethod
|
@@ -162,16 +276,6 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
162
276
|
"""
|
163
277
|
pass
|
164
278
|
|
165
|
-
def set_viewer(self, viewer: RDRCaseViewer):
|
166
|
-
"""
|
167
|
-
Set the viewer for the classifier.
|
168
|
-
|
169
|
-
:param viewer: The viewer to set.
|
170
|
-
"""
|
171
|
-
self.viewer = viewer
|
172
|
-
if self.viewer is not None:
|
173
|
-
self.viewer.set_save_function(self.save)
|
174
|
-
|
175
279
|
def fit(self, case_queries: List[CaseQuery],
|
176
280
|
expert: Optional[Expert] = None,
|
177
281
|
n_iter: int = None,
|
@@ -196,7 +300,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
196
300
|
num_rules: int = 0
|
197
301
|
while not stop_iterating:
|
198
302
|
for case_query in case_queries:
|
199
|
-
pred_cat = self.fit_case(case_query, expert=expert,
|
303
|
+
pred_cat = self.fit_case(case_query, expert=expert, clear_expert_answers=False,
|
304
|
+
**kwargs_for_fit_case)
|
200
305
|
if case_query.target is None:
|
201
306
|
continue
|
202
307
|
target = {case_query.attribute_name: case_query.target(case_query.case)}
|
@@ -226,9 +331,37 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
226
331
|
def __call__(self, case: Union[Case, SQLTable]) -> Union[CallableExpression, Dict[str, CallableExpression]]:
|
227
332
|
return self.classify(case)
|
228
333
|
|
334
|
+
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False, case_query: Optional[CaseQuery] = None) \
        -> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
    """
    Classify a case using the RDR classifier.

    Before delegating to ``_classify``, this method:

    - resets every rule of the tree;
    - attaches an input node (uid ``'0'``, created on first use) as the parent of a parentless start rule;
    - sets that input node's ``name`` to a description of the case. The description is a JSON dump of
      the case's attributes when they form a mapping (a dict-like case or a dataclass instance's
      ``__dict__``), and ``str(case)`` otherwise.

    :param case: The case to classify.
    :param modify_case: Whether to modify the original case attributes with the conclusion or not.
    :param case_query: The case query containing the case to classify and the target category to compare the case with.
    :return: The category that the case belongs to.
    """
    if self.start_rule is not None:
        for rule in [self.start_rule] + list(self.start_rule.descendants):
            rule.reset()
    if self.start_rule is not None and self.start_rule.parent is None:
        if self.input_node is None:
            self.input_node = type(self.start_rule)(_parent=None, uid='0')
            self.input_node.evaluated = False
            self.input_node.fired = False
        self.start_rule.parent = self.input_node
        self.start_rule.weight = RDREdge.Empty
    if self.input_node is not None:
        data = case.__dict__ if is_dataclass(case) else case
        # Test the extracted data, not the case: a dataclass instance has no ``items`` but its
        # ``__dict__`` does, so it must also be rendered as JSON.
        if hasattr(data, "items"):
            self.input_node.name = json.dumps({k: str(v) for k, v in data.items()}, indent=4)
        else:
            self.input_node.name = str(data)
    return self._classify(case, modify_case=modify_case, case_query=case_query)
|
361
|
+
|
229
362
|
@abstractmethod
|
230
|
-
def
|
231
|
-
|
363
|
+
def _classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
|
364
|
+
case_query: Optional[CaseQuery] = None) \
|
232
365
|
-> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
|
233
366
|
"""
|
234
367
|
Classify a case.
|
@@ -244,6 +377,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
244
377
|
expert: Optional[Expert] = None,
|
245
378
|
update_existing_rules: bool = True,
|
246
379
|
scenario: Optional[Callable] = None,
|
380
|
+
ask_now: Callable = lambda _: True,
|
381
|
+
clear_expert_answers: bool = True,
|
247
382
|
**kwargs) \
|
248
383
|
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
249
384
|
"""
|
@@ -255,6 +390,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
255
390
|
:param update_existing_rules: Whether to update the existing same conclusion type rules that already gave
|
256
391
|
some conclusions with the type required by the case query.
|
257
392
|
:param scenario: The scenario at which the case was created, this is used to recreate the case if needed.
|
393
|
+
:param ask_now: Whether to ask the expert for refinements or alternatives.
|
394
|
+
:param clear_expert_answers: Whether to clear expert answers after saving the new rule.
|
258
395
|
:return: The category that the case belongs to.
|
259
396
|
"""
|
260
397
|
if case_query is None:
|
@@ -264,14 +401,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
264
401
|
self.case_type = case_query.case_type if self.case_type is None else self.case_type
|
265
402
|
self.case_name = case_query.case_name if self.case_name is None else self.case_name
|
266
403
|
case_query.scenario = scenario if case_query.scenario is None else case_query.scenario
|
404
|
+
case_query.rdr = self
|
267
405
|
|
268
|
-
expert = expert or Human(
|
269
|
-
answers_save_path=self.save_dir + '/expert_answers'
|
406
|
+
expert = expert or Human(answers_save_path=self.save_dir + '/expert_answers'
|
270
407
|
if self.save_dir else None)
|
271
408
|
if case_query.target is None:
|
272
409
|
case_query_cp = copy(case_query)
|
273
410
|
conclusions = self.classify(case_query_cp.case, modify_case=True, case_query=case_query_cp)
|
274
|
-
if self.should_i_ask_the_expert_for_a_target(conclusions, case_query_cp, update_existing_rules)
|
411
|
+
if (self.should_i_ask_the_expert_for_a_target(conclusions, case_query_cp, update_existing_rules)
|
412
|
+
and ask_now(case_query_cp.case)):
|
275
413
|
expert.ask_for_conclusion(case_query_cp)
|
276
414
|
case_query.target = case_query_cp.target
|
277
415
|
if case_query.target is None:
|
@@ -283,7 +421,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
283
421
|
|
284
422
|
if self.save_dir is not None:
|
285
423
|
self.save()
|
286
|
-
|
424
|
+
if clear_expert_answers:
|
425
|
+
expert.clear_answers()
|
287
426
|
|
288
427
|
return fit_case_result
|
289
428
|
|
@@ -356,6 +495,61 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
356
495
|
def type_(self):
|
357
496
|
return self.__class__
|
358
497
|
|
498
|
+
@classmethod
def get_json_file_path(cls, model_path: str) -> str:
    """
    Get the path of the JSON file the model is saved to.

    :param model_path: The path to the model directory.
    :return: ``<model_path>/<metadata_folder>/<model name>.json``.
    """
    json_name = cls.get_model_name_from_model_path(model_path) + ".json"
    return os.path.join(model_path, cls.metadata_folder, json_name)
|
508
|
+
|
509
|
+
@classmethod
def get_generated_cases_file_path(cls, model_path: str) -> str:
    """
    Get the path to the python file that contains the RDR classifier cases.

    The cases file sits next to the generated classifier file. It has the same name with a ``_cases``
    suffix before the extension, e.g. ``model.py`` -> ``model_cases.py``.

    :param model_path: The path to the model directory.
    :return: The path to the generated cases python file.
    """
    root, extension = os.path.splitext(cls.get_generated_python_file_path(model_path))
    # Only rewrite the extension: str.replace(".py", ...) would also rewrite any ".py" elsewhere in
    # the path, e.g. inside a "~/.pyenv/..." directory.
    return f"{root}_cases{extension}"
|
518
|
+
|
519
|
+
@classmethod
def get_generated_defs_file_path(cls, model_path: str) -> str:
    """
    Get the path to the python file that contains the RDR classifier function definitions.

    The definitions file sits next to the generated classifier file. It has the same name with a
    ``_defs`` suffix before the extension, e.g. ``model.py`` -> ``model_defs.py``.

    :param model_path: The path to the model directory.
    :return: The path to the generated definitions python file.
    """
    root, extension = os.path.splitext(cls.get_generated_python_file_path(model_path))
    # Only rewrite the extension: str.replace(".py", ...) would also rewrite any ".py" elsewhere in
    # the path, e.g. inside a "~/.pyenv/..." directory.
    return f"{root}_defs{extension}"
|
528
|
+
|
529
|
+
@classmethod
def get_generated_python_file_path(cls, model_path: str) -> str:
    """
    Get the path to the generated python file that contains the RDR classifier function.

    :param model_path: The path to the model directory.
    :return: ``<model_path>/<model name>.py``.
    """
    file_name = cls.get_model_name_from_model_path(model_path) + ".py"
    return os.path.join(model_path, file_name)
|
539
|
+
|
540
|
+
@classmethod
def get_model_name_from_model_path(cls, model_path: str) -> str:
    """
    Get the model name from the model path.

    The name is taken from the python file in ``model_path`` whose name ends with ``_<acronym>.py``
    (the lower-cased ``cls.get_acronym()``), without its ``.py`` extension.

    :param model_path: The path to the model directory.
    :return: The name of the model.
    :raises FileNotFoundError: If no matching python file is found in ``model_path``.
    """
    file_name = get_file_that_ends_with(model_path, f"_{cls.get_acronym().lower()}.py")
    if file_name is None:
        raise FileNotFoundError(f"Could not find the python file for the model in the given path: {model_path}.")
    # Strip only the trailing extension; str.replace('.py', '') would also delete any ".py"
    # occurring earlier in the name.
    return file_name.removesuffix('.py')
|
552
|
+
|
359
553
|
@property
|
360
554
|
def generated_python_file_name(self) -> str:
|
361
555
|
if self._generated_python_file_name is None:
|
@@ -379,13 +573,17 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
379
573
|
pass
|
380
574
|
|
381
575
|
@abstractmethod
def update_from_python(self, model_dir: str, parent_package_name: Optional[str] = None,
                       python_file_path: Optional[str] = None,
                       update_rule_tree: bool = False):
    """
    Re-read the classifier from its generated python file, which the user may have edited.

    :param model_dir: The directory that contains the generated python file.
    :param parent_package_name: The package that contains the generated classifier function; needed
        when the generated python file uses relative imports.
    :param python_file_path: Explicit path to the generated python file that contains the RDR classifier function.
    :param update_rule_tree: If True, the rule tree is rebuilt from the python file as well.
    """
    ...
|
391
589
|
|
@@ -414,41 +612,363 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
414
612
|
return module.classify
|
415
613
|
|
416
614
|
|
615
|
+
class TreeBuilder(ast.NodeVisitor, ABC):
    """
    Rebuild an RDR rule tree from the AST of a generated ``classify`` function.

    The generated function is a nest of ``if``/``elif``/``else`` statements whose tests call
    ``conditions_<uid>(...)`` functions. Every such ``if`` becomes a rule. Subclasses decide which rule
    type to create and how a new rule is attached to its parent.
    """

    def __init__(self):
        # Root of the reconstructed tree: the first rule created while no parent was set.
        self.root: Optional[Rule] = None
        # Rule that newly created rules are attached to.
        self.current_parent: Optional[Rule] = None
        # Edge kind used when attaching the next rule to ``current_parent``.
        self.current_edge: Optional[RDREdge] = None
        # Conclusion of a ``return`` visited while there is no current parent rule.
        self.default_conclusion: Optional[str] = None

    def visit_FunctionDef(self, node):
        """Walk every statement of the classifier function's body."""
        for statement in node.body:
            self.visit(statement)

    def visit_If(self, node):
        """Create a rule for an ``if``/``elif`` and recurse into its body and its ``elif``/``else`` branches."""
        condition_name = self.get_condition_name(node.test)
        if condition_name is None:
            # Not a call to a generated condition function: skip the whole statement.
            return
        uid = condition_name.split("conditions_")[1]

        rule_cls = self.get_new_rule_type(node)
        rule = rule_cls(conditions=condition_name, _parent=self.current_parent, uid=uid)
        if self.current_parent is not None:
            self.update_current_parent(rule)
        if self.current_parent is None and self.root is None:
            self.root = rule
        self.current_parent = rule

        # Statements of the if body hang off the new rule through the refinement edge.
        for statement in node.body:
            self.current_edge = self.get_refinement_edge(node)
            self.current_parent = rule
            self.visit(statement)

        # An ``elif`` appears as a nested If in ``orelse``; anything else belongs to the ``else`` block.
        for statement in node.orelse:
            self.current_edge = self.get_alternative_edge(node)
            self.current_parent = rule
            if isinstance(statement, ast.If):
                self.visit_If(statement)
            else:
                self.process_else_statement(statement)

        self.current_parent = rule
        self.current_edge = None

    @abstractmethod
    def process_else_statement(self, stmt: ast.AST):
        """
        Handle one statement of the final ``else`` block of an if-elif-else chain.

        :param stmt: The else statement to process.
        """
        ...

    @abstractmethod
    def get_refinement_edge(self, node: ast.AST) -> RDREdge:
        """
        :param node: The current AST node to determine the edge type from.
        :return: The edge used for rules nested inside the body of ``node``.
        """
        ...

    @abstractmethod
    def get_alternative_edge(self, node: ast.AST) -> RDREdge:
        """
        :param node: The current AST node to determine the alternative edge type from.
        :return: The edge used for the ``elif``/``else`` branches of ``node``.
        """
        ...

    @abstractmethod
    def get_new_rule_type(self, node: ast.AST) -> Type[Rule]:
        """
        Choose the rule class to instantiate for ``node``.

        :param node: The current AST node to determine the rule type from.
        :return: The new rule type.
        """
        ...

    @abstractmethod
    def update_current_parent(self, new_node: Rule):
        """
        Attach ``new_node`` to the current parent rule.

        :param new_node: The newly created rule.
        """
        ...

    def visit_Return(self, node):
        """Record a returned conclusion on the current rule, or as the default conclusion when there is none."""
        value = node.value
        # A call such as ``conclusion_x(case)`` is stored by function name; anything else must be a literal.
        conclusion = value.func.id if isinstance(value, ast.Call) else ast.literal_eval(value)
        if self.current_parent is not None:
            self.current_parent.conclusion = conclusion
        else:
            self.default_conclusion = conclusion

    def get_condition_name(self, node):
        """Return the called function name if ``node`` is a plain-name call, otherwise None."""
        if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Name):
            return None
        return node.func.id
|
721
|
+
|
722
|
+
class SingleClassTreeBuilder(TreeBuilder):
    """Rebuild the rule tree of a generated single-class RDR ``classify`` function."""

    def get_new_rule_type(self, node: ast.AST) -> Type[Rule]:
        # Every if/elif of a single-class classifier becomes a SingleClassRule.
        return SingleClassRule

    def get_refinement_edge(self, node: ast.AST) -> RDREdge:
        return RDREdge.Refinement

    def get_alternative_edge(self, node: ast.AST) -> RDREdge:
        return RDREdge.Alternative

    def update_current_parent(self, new_node: Rule):
        """Link ``new_node`` to the current parent through the current edge (any other edge is ignored)."""
        edge = self.current_edge
        if edge == RDREdge.Alternative:
            self.current_parent.alternative = new_node
        elif edge == RDREdge.Refinement:
            self.current_parent.refinement = new_node

    def process_else_statement(self, stmt: ast.AST):
        """
        Handle the final ``else`` block: its ``return`` is recorded as the default conclusion.

        :raises ValueError: If the else block contains anything other than a return statement.
        """
        if not isinstance(stmt, ast.Return):
            raise ValueError(f"Unexpected statement in else block: {stmt}")
        # Without a current parent, visit_Return stores the value as the default conclusion.
        self.current_parent = None
        self.visit_Return(stmt)
|
747
|
+
|
748
|
+
|
749
|
+
class MultiClassTreeBuilder(TreeBuilder):
    """Rebuild the rule tree of a generated multi-class RDR ``classify`` function."""

    def visit_If(self, stmt: ast.If):
        super().visit_If(stmt)
        # Top and filter rules carry a conclusion; its generated function name is the rule's
        # condition name with the "conditions_" prefix swapped for "conclusion_".
        parent = self.current_parent
        if isinstance(parent, (MultiClassTopRule, MultiClassFilterRule)):
            parent.conclusion = parent.conditions.replace("conditions_", "conclusion_")

    def visit_Return(self, node):
        # Return statements are ignored; conclusions come from the condition names (see visit_If).
        pass

    def get_new_rule_type(self, node: ast.AST) -> Type[Rule]:
        """
        Pick the rule class from the edge the new rule hangs on.

        :raises ValueError: If the current edge is not a known edge type.
        """
        edge = self.current_edge
        if edge == RDREdge.Refinement:
            return MultiClassStopRule
        if edge == RDREdge.Filter:
            return MultiClassFilterRule
        if edge in [RDREdge.Next, None]:
            return MultiClassTopRule
        if edge == RDREdge.Alternative:
            return self.get_refinement_rule_type(node)
        raise ValueError(f"Unknown edge type: {edge}")

    def get_alternative_edge(self, node: ast.AST) -> RDREdge:
        # Siblings of a top rule are the next top rules; siblings of refinements are alternatives.
        return RDREdge.Next if isinstance(self.current_parent, MultiClassTopRule) else RDREdge.Alternative

    def get_refinement_edge(self, node: ast.AST) -> RDREdge:
        return self.get_refinement_edge_from_refinement_rule(self.get_refinement_rule_type(node))

    def get_refinement_edge_from_refinement_rule(self, rule_type: Type[Rule]) -> RDREdge:
        """
        :param rule_type: The type of the rule to determine the refinement edge from.
        :return: Alternative when the current parent is itself a refinement rule; otherwise Refinement
            for a stop rule and Filter for anything else.
        """
        if isinstance(self.current_parent, MultiClassRefinementRule):
            return RDREdge.Alternative
        return RDREdge.Refinement if rule_type == MultiClassStopRule else RDREdge.Filter

    def get_refinement_rule_type(self, node: ast.AST) -> Type[Rule]:
        """
        Infer the refinement rule type from the first statement of ``node``'s body.

        A body that is a lone ``pass`` means a stop rule. A nested ``if`` is inspected recursively.
        Anything else means a filter rule.

        :param node: The current AST node to determine the rule type from.
        :return: The rule type based on the node body.
        :raises ValueError: If the node body is empty.
        """
        if not node.body:
            raise ValueError(f"Could not determine the refinement rule type from the node: {node} as it has an empty body.")
        first = node.body[0]
        if isinstance(first, ast.Pass) and len(node.body) == 1:
            return MultiClassStopRule
        if isinstance(first, ast.If):
            return self.get_refinement_rule_type(first)
        return MultiClassFilterRule

    def update_current_parent(self, new_node: Rule):
        """
        Link ``new_node`` to the current parent and, for refinement rules, record their top rule.

        :raises ValueError: If a refinement rule's top rule cannot be determined from its parent.
        """
        parent = self.current_parent
        if isinstance(new_node, MultiClassRefinementRule):
            if isinstance(parent, MultiClassTopRule):
                new_node.top_rule = parent
            elif hasattr(parent, "top_rule"):
                new_node.top_rule = parent.top_rule
            else:
                raise ValueError(f"Could not set the top rule for the refinement rule: {new_node}")
        edge = self.current_edge
        if edge in [RDREdge.Alternative, RDREdge.Next, None]:
            parent.alternative = new_node
        elif edge in [RDREdge.Refinement, RDREdge.Filter]:
            parent.refinement = new_node

    def process_else_statement(self, stmt: ast.AST):
        """The final ``else`` block is not used when rebuilding a multi-class tree."""
        pass
|
824
|
+
|
417
825
|
class RDRWithCodeWriter(RippleDownRules, ABC):
|
418
826
|
|
419
|
-
|
827
|
+
@classmethod
def read_rule_tree_from_python(cls, model_path: str, python_file_path: Optional[str] = None) -> Rule:
    """
    Rebuild the rule tree by parsing the ``classify`` function of the generated python file.

    :param model_path: The model directory, used to locate the generated python file when
        ``python_file_path`` is not given.
    :param python_file_path: The path to the generated python file that contains the RDR classifier function.
    :return: The root rule built by the class's tree builder. This is None when the file defines no
        ``classify`` function or it contains no rules.
    """
    if python_file_path is None:
        python_file_path = cls.get_generated_python_file_path(model_path)
    # Read raw bytes and let ast.parse decode them per PEP 263 (UTF-8 by default, honouring any
    # coding cookie). This matches how the module is decoded on import, instead of depending on the
    # locale's default text encoding.
    with open(python_file_path, "rb") as f:
        source_code = f.read()

    tree = ast.parse(source_code)
    builder = cls.get_tree_builder_class()()

    # Only the top-level ``classify`` function describes the rule tree.
    for node in tree.body:
        if isinstance(node, ast.FunctionDef) and node.name == "classify":
            builder.visit_FunctionDef(node)

    return builder.root
|
847
|
+
|
848
|
+
@classmethod
@abstractmethod
def get_tree_builder_class(cls) -> Type[TreeBuilder]:
    """
    Return the TreeBuilder subclass used to rebuild this classifier's rule tree from its generated python file.

    :return: The tree builder class; expected to be either SingleClassTreeBuilder or MultiClassTreeBuilder.
    """
    ...
|
856
|
+
|
857
|
+
@property
def all_rules(self) -> List[Rule]:
    """
    All rules of the classifier that have conditions.

    :return: The start rule and its descendants whose ``conditions`` is not None, with the start rule
        first; an empty list when there is no start rule.
    """
    root = self.start_rule
    if root is None:
        return []
    candidates = [root]
    candidates.extend(root.descendants)
    return [rule for rule in candidates if rule.conditions is not None]
|
867
|
+
|
868
|
+
def update_from_python(self, model_dir: str, parent_package_name: Optional[str] = None,
                       python_file_path: Optional[str] = None, update_rule_tree: bool = False):
    """
    Update the rules from the generated python file, that might have been modified by the user.

    Rule conditions and conclusions are re-created from the function sources in the ``*_defs`` module.
    Corner-case metadata is reloaded from the ``*_cases`` module when one exists.
    The conclusion name is re-read from the ``attribute_name = ...`` line of the main module.

    :param model_dir: The directory where the generated python file is located.
    :param parent_package_name: The name of the package that contains the RDR classifier function, this
     is required in case of relative imports in the generated python file.
    :param python_file_path: The path to the generated python file that contains the RDR classifier function.
    :param update_rule_tree: Whether to update the rule tree from the python file or not.
    """

    def _assigned_expression(line: str) -> str:
        # Keep everything after the FIRST '=' so that values which themselves contain '='
        # (e.g. keyword arguments of a constructor call) are not truncated.
        return line.split('=', 1)[1].strip()

    if update_rule_tree:
        self.start_rule = self.read_rule_tree_from_python(model_dir, python_file_path=python_file_path)
    all_rules = self.all_rules
    condition_func_names = [rule.generated_conditions_function_name for rule in all_rules]
    conclusion_func_names = [rule.generated_conclusion_function_name for rule in all_rules
                             if not isinstance(rule, MultiClassStopRule)]
    all_func_names = condition_func_names + conclusion_func_names

    main_module, defs_module, cases_module = self.get_and_import_model_python_modules(
        model_dir,
        python_file_path=python_file_path,
        parent_package_name=parent_package_name)
    # ``stem`` drops only the final suffix; ``str.replace(".py", "")`` would also strip ".py"
    # occurring elsewhere in the file name (e.g. turn "x.pyc" into "xc").
    self.generated_python_file_name = Path(main_module.__file__).stem

    self.update_rdr_metadata_from_python(main_module)

    functions_source = extract_function_source(defs_module.__file__,
                                               all_func_names, include_signature=True)
    scope = extract_imports(defs_module.__file__, package_name=parent_package_name)

    cases_source, cases_scope = None, None
    if cases_module:
        with open(cases_module.__file__, "r") as f:
            cases_source = f.read()
        cases_scope = extract_imports(cases_module.__file__, package_name=parent_package_name)

    with open(main_module.__file__, "r") as f:
        main_source = f.read()
    main_scope = extract_imports(main_module.__file__, package_name=parent_package_name)
    attribute_name_line = next((l for l in main_source.split('\n') if "attribute_name = " in l), None)
    conclusion_name = None
    if attribute_name_line is not None:
        # NOTE: eval runs text from the generated (possibly user-edited) model file; only load trusted models.
        conclusion_name = eval(_assigned_expression(attribute_name_line), main_scope)

    for rule in all_rules:
        if rule.conditions is not None:
            conditions_wrapper_func_name = rule.generated_conditions_function_name
            user_input = functions_source[conditions_wrapper_func_name]
            user_input = '\n'.join(user_input.split("\n")[1:])  # Remove the function signature line
            rule.conditions = CallableExpression(user_input, (bool,), scope=scope)
            if cases_module:
                try:
                    rule.corner_case_metadata = cases_module.__dict__[rule.generated_corner_case_object_name]
                except KeyError:
                    # Fall back to evaluating the right-hand side of the line that defines the object.
                    case_def_line = next((l for l in cases_source.split('\n')
                                          if rule.generated_corner_case_object_name in l), None)
                    if case_def_line is not None:
                        rule.corner_case_metadata = eval(_assigned_expression(case_def_line), cases_scope)

        if not isinstance(rule, MultiClassStopRule):
            if conclusion_name:
                rule.conclusion_name = conclusion_name
            user_input = functions_source[rule.generated_conclusion_function_name]
            split_user_input = user_input.split("\n")
            user_input = '\n'.join(split_user_input[1:])  # Remove the function signature line
            conclusion_func = defs_module.__dict__.get(rule.generated_conclusion_function_name)
            if conclusion_func is None:
                # The function is not importable, so recover its type from the "-> Type:" annotation text.
                function_signature = split_user_input[0]
                return_type_hint_str = function_signature.split('->')[-1].strip(' :')
                return_type_hint = eval(return_type_hint_str, scope)
                conclusion_type = get_type_from_type_hint(return_type_hint)
            else:
                conclusion_type = get_function_return_type(conclusion_func)
            rule.conclusion = CallableExpression(user_input, conclusion_type, scope=scope,
                                                 mutually_exclusive=self.mutually_exclusive)
|
944
|
+
|
945
|
+
@classmethod
def get_and_import_model_python_modules(cls, model_dir: str,
                                        python_file_path: Optional[str] = None,
                                        parent_package_name: Optional[str] = None) \
        -> Tuple[ModuleType, ModuleType, ModuleType]:
    """
    Get and import the python modules that contain the RDR classifier function, definitions, and corner cases.

    The definitions and corner-case files are expected next to the main file, named
    ``<main>_defs.py`` and ``<main>_cases.py``.

    :param model_dir: The path to the directory where the generated python files are located.
    :param python_file_path: The path to the generated python file that contains the RDR classifier function.
    :param parent_package_name: The name of the package that contains the RDR classifier function, this
     is required in case of relative imports in the generated python file.
    :return: A tuple containing the main module, defs module, and cases module.
    :raises ModuleNotFoundError: If the main python file does not exist.
    """
    if python_file_path is None:
        main_file_path = cls.get_generated_python_file_path(model_dir)
    else:
        main_file_path = python_file_path
    if not os.path.exists(main_file_path):
        raise ModuleNotFoundError(main_file_path)

    # Only replace the file extension. ``str.replace(".py", ...)`` would also rewrite every ".py"
    # earlier in the path (e.g. a "~/.pyenv/..." directory), producing paths that do not exist.
    base_path, _ = os.path.splitext(main_file_path)
    defs_file_path = f"{base_path}_defs.py"
    cases_path = f"{base_path}_cases.py"

    main_module, defs_module, cases_module = get_and_import_python_modules_in_a_package(
        [main_file_path, defs_file_path, cases_path], parent_package_name=parent_package_name)
    return main_module, defs_module, cases_module
|
452
972
|
|
453
973
|
@abstractmethod
|
454
974
|
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
|
@@ -492,6 +1012,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
492
1012
|
defs_imports = get_imports_from_types(defs_types, defs_file_name, package_name)
|
493
1013
|
corner_cases_imports = get_imports_from_types(corner_cases_types, cases_file_name, package_name)
|
494
1014
|
|
1015
|
+
defs_imports.append(f"from ripple_down_rules import *")
|
495
1016
|
# Add the imports to the defs file
|
496
1017
|
with open(defs_file_name, "w") as f:
|
497
1018
|
f.write('\n'.join(defs_imports) + "\n\n\n")
|
@@ -511,6 +1032,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
511
1032
|
f.write(f"attribute_name = '{self.attribute_name}'\n")
|
512
1033
|
f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
|
513
1034
|
f.write(f"mutually_exclusive = {self.mutually_exclusive}\n")
|
1035
|
+
self.write_rdr_metadata_to_pyton_file(f)
|
514
1036
|
f.write(f"\n\n{func_def}")
|
515
1037
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
516
1038
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
@@ -527,7 +1049,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
527
1049
|
"""
|
528
1050
|
pass
|
529
1051
|
|
530
|
-
def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
|
1052
|
+
def _get_types_to_import(self) -> Tuple[Set[Union[Type, Callable]], Set[Type], Set[Type]]:
|
531
1053
|
"""
|
532
1054
|
:return: The types of the main, defs, and corner cases files of the RDR classifier that will be imported.
|
533
1055
|
"""
|
@@ -550,7 +1072,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
550
1072
|
main_types.update({Union, Optional})
|
551
1073
|
defs_types.add(Union)
|
552
1074
|
main_types.update({Case, create_case})
|
553
|
-
main_types = main_types.difference(defs_types)
|
1075
|
+
# main_types = main_types.difference(defs_types)
|
554
1076
|
return main_types, defs_types, cases_types
|
555
1077
|
|
556
1078
|
@property
|
@@ -576,9 +1098,8 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
576
1098
|
:return: The type of the conclusion of the RDR classifier.
|
577
1099
|
"""
|
578
1100
|
all_types = []
|
579
|
-
|
580
|
-
|
581
|
-
all_types.extend(list(rule.conclusion.conclusion_type))
|
1101
|
+
for rule in self.all_rules:
|
1102
|
+
all_types.extend(list(rule.conclusion.conclusion_type))
|
582
1103
|
return tuple(set(all_types))
|
583
1104
|
|
584
1105
|
@property
|
@@ -635,6 +1156,10 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
635
1156
|
super(SingleClassRDR, self).__init__(**kwargs)
|
636
1157
|
self.default_conclusion: Optional[Any] = default_conclusion
|
637
1158
|
|
1159
|
+
@classmethod
def get_tree_builder_class(cls) -> Type[TreeBuilder]:
    """
    Builder used to rebuild a single-class rule tree from its generated python file.

    :return: The SingleClassTreeBuilder class (not an instance).
    """
    builder_cls = SingleClassTreeBuilder
    return builder_cls
|
1162
|
+
|
638
1163
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
639
1164
|
-> Union[CaseAttribute, CallableExpression, None]:
|
640
1165
|
"""
|
@@ -649,7 +1174,7 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
649
1174
|
self.default_conclusion = case_query.default_value
|
650
1175
|
|
651
1176
|
pred = self.evaluate(case_query.case)
|
652
|
-
if pred.conclusion(case_query.case) != case_query.target_value:
|
1177
|
+
if (not pred.fired and self.default_conclusion is None) or pred.conclusion(case_query.case) != case_query.target_value:
|
653
1178
|
expert.ask_for_conditions(case_query, pred)
|
654
1179
|
pred.fit_rule(case_query)
|
655
1180
|
|
@@ -666,8 +1191,8 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
666
1191
|
expert.ask_for_conditions(case_query)
|
667
1192
|
self.start_rule = SingleClassRule.from_case_query(case_query)
|
668
1193
|
|
669
|
-
def
|
670
|
-
|
1194
|
+
def _classify(self, case: Case, modify_case: bool = False,
|
1195
|
+
case_query: Optional[CaseQuery] = None) -> Optional[Any]:
|
671
1196
|
"""
|
672
1197
|
Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
|
673
1198
|
|
@@ -677,6 +1202,11 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
677
1202
|
"""
|
678
1203
|
pred = self.evaluate(case)
|
679
1204
|
conclusion = pred.conclusion(case) if pred is not None and pred.fired else self.default_conclusion
|
1205
|
+
if pred is not None and pred.fired:
|
1206
|
+
pred.contributed = True
|
1207
|
+
pred.last_conclusion = conclusion
|
1208
|
+
if case_query is not None:
|
1209
|
+
pred.contributed_to_case_query = True
|
680
1210
|
if pred is not None and pred.fired and case_query is not None:
|
681
1211
|
if pred.corner_case_metadata is None and conclusion is not None \
|
682
1212
|
and type(conclusion) in case_query.core_attribute_type:
|
@@ -781,8 +1311,12 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
781
1311
|
super(MultiClassRDR, self).__init__(start_rule, **kwargs)
|
782
1312
|
self.mode: MCRDRMode = mode
|
783
1313
|
|
784
|
-
|
785
|
-
|
1314
|
+
@classmethod
def get_tree_builder_class(cls) -> Type[TreeBuilder]:
    """
    Builder used to rebuild a multi-class rule tree from its generated python file.

    :return: The MultiClassTreeBuilder class (not an instance).
    """
    builder_cls = MultiClassTreeBuilder
    return builder_cls
|
1317
|
+
|
1318
|
+
def _classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
|
1319
|
+
case_query: Optional[CaseQuery] = None) -> Set[Any]:
|
786
1320
|
evaluated_rule = self.start_rule
|
787
1321
|
self.conclusions = []
|
788
1322
|
while evaluated_rule:
|
@@ -794,6 +1328,13 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
794
1328
|
and any(
|
795
1329
|
ct in case_query.core_attribute_type for ct in map(type, make_list(rule_conclusion))):
|
796
1330
|
evaluated_rule.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
1331
|
+
if rule_conclusion is not None and any(make_list(rule_conclusion)):
|
1332
|
+
evaluated_rule.contributed = True
|
1333
|
+
evaluated_rule.last_conclusion = rule_conclusion
|
1334
|
+
if case_query is not None:
|
1335
|
+
rule_conclusion_types = set(map(type, make_list(rule_conclusion)))
|
1336
|
+
if are_results_subclass_of_types(rule_conclusion_types, case_query.core_attribute_type):
|
1337
|
+
evaluated_rule.contributed_to_case_query = True
|
797
1338
|
self.add_conclusion(rule_conclusion)
|
798
1339
|
evaluated_rule = next_rule
|
799
1340
|
return make_set(self.conclusions)
|
@@ -860,6 +1401,9 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
860
1401
|
if rule.alternative:
|
861
1402
|
self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
|
862
1403
|
cases_file=cases_file, package_name=package_name)
|
1404
|
+
elif isinstance(rule, MultiClassTopRule):
|
1405
|
+
with open(filename, "a") as file:
|
1406
|
+
file.write(f"{parent_indent}return conclusions\n")
|
863
1407
|
|
864
1408
|
@property
|
865
1409
|
def conclusion_type_hint(self) -> str:
|
@@ -869,8 +1413,9 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
869
1413
|
else:
|
870
1414
|
return f"Set[Union[{', '.join(conclusion_types)}]]"
|
871
1415
|
|
872
|
-
def _get_types_to_import(self) -> Tuple[Set[Union[Type, Callable]], Set[Type], Set[Type]]:
    """
    Extend the parent's import sets with the names needed by multi-class generated code.

    :return: The (main, defs, cases) sets produced by the parent class. The main set additionally
     holds ``get_an_updated_case_copy``, ``Set`` and ``make_set``, and the defs set additionally holds
     ``List`` and ``Set``.
    """
    main_types, defs_types, cases_types = super()._get_types_to_import()
    # In-place unions: the parent's set objects are extended, not replaced.
    main_types |= {get_an_updated_case_copy, Set, make_set}
    defs_types |= {List, Set}
    return main_types, defs_types, cases_types
|
@@ -884,7 +1429,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
884
1429
|
"""
|
885
1430
|
if not self.start_rule:
|
886
1431
|
conditions = expert.ask_for_conditions(case_query)
|
887
|
-
self.start_rule = MultiClassTopRule.from_case_query(case_query)
|
1432
|
+
self.start_rule: MultiClassTopRule = MultiClassTopRule.from_case_query(case_query)
|
888
1433
|
|
889
1434
|
@property
|
890
1435
|
def last_top_rule(self) -> Optional[MultiClassTopRule]:
|
@@ -902,28 +1447,43 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
902
1447
|
Stop a wrong conclusion by adding a stopping rule.
|
903
1448
|
"""
|
904
1449
|
rule_conclusion = evaluated_rule.conclusion(case_query.case)
|
905
|
-
|
906
|
-
|
907
|
-
|
1450
|
+
stop: bool = False
|
1451
|
+
add_filter_rule: bool = False
|
1452
|
+
if is_value_conflicting(rule_conclusion, case_query.target_value):
|
1453
|
+
if make_set(case_query.target_value).issubset(rule_conclusion):
|
1454
|
+
add_filter_rule = True
|
1455
|
+
else:
|
1456
|
+
stop = True
|
1457
|
+
elif make_set(case_query.core_attribute_type).issubset(make_set(evaluated_rule.conclusion.conclusion_type)):
|
1458
|
+
if make_set(case_query.target_value).issubset(rule_conclusion):
|
1459
|
+
add_filter_rule = True
|
1460
|
+
|
1461
|
+
if not stop:
|
908
1462
|
self.add_conclusion(rule_conclusion)
|
1463
|
+
if stop or add_filter_rule:
|
1464
|
+
refinement_type = MultiClassStopRule if stop else MultiClassFilterRule
|
1465
|
+
self.stop_or_filter_conclusion(case_query, expert, evaluated_rule, refinement_type=refinement_type)
|
909
1466
|
|
910
|
-
def stop_or_filter_conclusion(self, case_query: CaseQuery,
                              expert: Expert, evaluated_rule: MultiClassTopRule,
                              refinement_type: Type[MultiClassRefinementRule] = MultiClassStopRule):
    """
    Refine a top rule whose conclusion does not fit the case, either by stopping or by filtering it.

    The expert is asked for conditions that distinguish the case, and ``evaluated_rule`` is refined with a
    rule of ``refinement_type``. The MCRDR mode only matters for stopping refinements:

    - in ``StopPlusRule`` mode, the conditions are stored in ``self.stop_rule_conditions``;
    - in ``StopPlusRuleCombined`` mode, a new top rule is added whose conditions combine the new
      conditions with those of ``evaluated_rule``.

    :param case_query: The case query to stop or filter the conclusion for.
    :param expert: The expert to ask for differentiating features as new rule conditions.
    :param evaluated_rule: The evaluated rule to ask the expert about.
    :param refinement_type: The refinement rule type to attach to ``evaluated_rule``
     (``MultiClassStopRule`` by default; a filter rule type adds a filter instead).
    """
    conditions = expert.ask_for_conditions(case_query, evaluated_rule)
    evaluated_rule.fit_rule(case_query, refinement_type=refinement_type)
    if refinement_type is not MultiClassStopRule:
        # Filter refinements do not depend on the MCRDR mode.
        return
    if self.mode == MCRDRMode.StopPlusRule:
        self.stop_rule_conditions = conditions
    elif self.mode == MCRDRMode.StopPlusRuleCombined:
        new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
        case_query.conditions = new_top_rule_conditions
        self.add_top_rule(case_query)
|
927
1487
|
|
928
1488
|
def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
|
929
1489
|
"""
|
@@ -1004,6 +1564,39 @@ class GeneralRDR(RippleDownRules):
|
|
1004
1564
|
super(GeneralRDR, self).__init__(**kwargs)
|
1005
1565
|
self.all_figs: List[Figure] = [sr.fig for sr in self.start_rules_dict.values()]
|
1006
1566
|
|
1567
|
+
@classmethod
def from_python(cls, model_dir: str, python_file_path: Optional[str] = None,
                parent_package_name: Optional[str] = None) -> Self:
    """
    Load a classifier from its generated python source.

    A default-constructed instance is populated through ``update_from_python`` with
    ``update_rule_tree=True``, so the rule trees are rebuilt from the files.

    :param model_dir: The path to the directory containing the python file.
    :param python_file_path: The path to the python file, if not provided, it will be generated from the model_dir.
    :param parent_package_name: The name of the package that contains the RDR classifier function, this
     is required in case of relative imports in the generated python file.
    :return: The newly created and populated instance.
    """
    loaded = cls()
    loaded.update_from_python(model_dir,
                              parent_package_name=parent_package_name,
                              python_file_path=python_file_path,
                              update_rule_tree=True)
    return loaded
|
1583
|
+
|
1584
|
+
@classmethod
def get_rdr_type_from_acronym(cls, acronym: str) -> Type[Union[SingleClassRDR, MultiClassRDR]]:
    """
    Map a classifier acronym (case-insensitive) to its RDR class.

    :param acronym: The acronym of the ripple down rules classifier ("scrdr" or "mcrdr", any casing).
    :return: The type of the ripple down rules classifier.
    :raises ValueError: If the acronym is not recognised.
    """
    normalized = acronym.lower()
    rdr_type = {"scrdr": SingleClassRDR, "mcrdr": MultiClassRDR}.get(normalized)
    if rdr_type is None:
        raise ValueError(f"Unknown RDR type acronym: {normalized}")
    return rdr_type
|
1599
|
+
|
1007
1600
|
def add_rdr(self, rdr: Union[SingleClassRDR, MultiClassRDR], case_query: Optional[CaseQuery] = None):
|
1008
1601
|
"""
|
1009
1602
|
Add a ripple down rules classifier to the map of classifiers.
|
@@ -1027,8 +1620,8 @@ class GeneralRDR(RippleDownRules):
|
|
1027
1620
|
def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
|
1028
1621
|
return [rdr.start_rule for rdr in self.start_rules_dict.values()]
|
1029
1622
|
|
1030
|
-
def
|
1031
|
-
|
1623
|
+
def _classify(self, case: Any, modify_case: bool = False,
|
1624
|
+
case_query: Optional[CaseQuery] = None) -> Optional[Dict[str, Any]]:
|
1032
1625
|
"""
|
1033
1626
|
Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
|
1034
1627
|
the classification until no more categories can be added.
|
@@ -1109,23 +1702,45 @@ class GeneralRDR(RippleDownRules):
|
|
1109
1702
|
new_rdr.case_name = data["case_name"]
|
1110
1703
|
return new_rdr
|
1111
1704
|
|
1112
|
-
def update_from_python(self, model_dir: str, parent_package_name: Optional[str] = None,
                       python_file_path: Optional[str] = None,
                       update_rule_tree: bool = False) -> None:
    """
    Update the rules from the generated python file, that might have been modified by the user.

    With ``update_rule_tree``, the classifiers are rebuilt from scratch:

    - the general python file is imported;
    - for every entry of its ``classifiers_dict``, the per-classifier file
      ``<prefix>_<name>_<acronym>.py`` is loaded;
    - loading uses ``from_python`` of the RDR type named by the acronym at the end of the entry's ``__name__``.

    Otherwise each existing classifier updates itself from its own generated file.

    :param model_dir: The directory where the model is stored.
    :param parent_package_name: The name of the package that contains the RDR classifier function, this
     is required in case of relative imports in the generated python file.
    :param python_file_path: The path to the python file, if not provided, it will be generated from the model_dir.
    :param update_rule_tree: Whether to update the rule tree from the python file or not.
    """
    if not update_rule_tree:
        for rdr in self.start_rules_dict.values():
            rdr.update_from_python(model_dir, parent_package_name=parent_package_name)
        return

    if python_file_path is None:
        main_python_file_path = self.get_generated_python_file_path(model_dir)
    else:
        main_python_file_path = python_file_path
    main_module = get_and_import_python_module(main_python_file_path, parent_package_name=parent_package_name)
    classifiers_dict = main_module.classifiers_dict
    main_file_suffix = '_rdr.py'
    self.start_rules_dict = {}
    for rdr_name, rdr_module in classifiers_dict.items():
        rdr_acronym = rdr_module.__name__.split('_')[-1]
        rdr_type = self.get_rdr_type_from_acronym(rdr_acronym)
        if main_python_file_path.endswith(main_file_suffix):
            # Rewrite only the trailing suffix; str.replace would also rewrite any "_rdr.py"
            # appearing earlier in the path.
            rdr_model_path = (main_python_file_path[:-len(main_file_suffix)]
                              + f'_{rdr_name}_{rdr_acronym}.py')
        else:
            # NOTE(review): as before, without the expected suffix the path is left unchanged
            # (pointing back at the general file); confirm whether this should raise instead.
            rdr_model_path = main_python_file_path
        rdr = rdr_type.from_python(model_dir, python_file_path=rdr_model_path,
                                   parent_package_name=parent_package_name)
        self.start_rules_dict[rdr_name] = rdr

    self.update_rdr_metadata_from_python(main_module)
|
1122
1737
|
|
1123
1738
|
def _write_to_python(self, model_dir: str, package_name: Optional[str] = None) -> None:
|
1124
1739
|
"""
|
1125
1740
|
Write the tree of rules as source code to a file.
|
1126
1741
|
|
1127
1742
|
:param model_dir: The directory where the model is stored.
|
1128
|
-
:param
|
1743
|
+
:param package_name: The name of the package that contains the RDR classifier function.
|
1129
1744
|
"""
|
1130
1745
|
for rdr in self.start_rules_dict.values():
|
1131
1746
|
rdr._write_to_python(model_dir, package_name=package_name)
|
@@ -1133,6 +1748,7 @@ class GeneralRDR(RippleDownRules):
|
|
1133
1748
|
file_path = model_dir + f"/{self.generated_python_file_name}.py"
|
1134
1749
|
with open(file_path, "w") as f:
|
1135
1750
|
f.write(self._get_imports(file_path=file_path, package_name=package_name) + "\n\n")
|
1751
|
+
self.write_rdr_metadata_to_pyton_file(f)
|
1136
1752
|
f.write("classifiers_dict = dict()\n")
|
1137
1753
|
for rdr_key, rdr in self.start_rules_dict.items():
|
1138
1754
|
f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
|