ripple-down-rules 0.6.23__py3-none-any.whl → 0.6.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +1 -1
- ripple_down_rules/datastructures/callable_expression.py +3 -2
- ripple_down_rules/datastructures/dataclasses.py +2 -2
- ripple_down_rules/datastructures/enums.py +4 -1
- ripple_down_rules/experts.py +2 -2
- ripple_down_rules/helpers.py +28 -7
- ripple_down_rules/rdr.py +113 -34
- ripple_down_rules/rdr_decorators.py +7 -6
- ripple_down_rules/rules.py +160 -56
- ripple_down_rules/user_interface/prompt.py +6 -3
- ripple_down_rules/utils.py +30 -15
- {ripple_down_rules-0.6.23.dist-info → ripple_down_rules-0.6.25.dist-info}/METADATA +1 -1
- ripple_down_rules-0.6.25.dist-info/RECORD +24 -0
- ripple_down_rules-0.6.23.dist-info/RECORD +0 -24
- {ripple_down_rules-0.6.23.dist-info → ripple_down_rules-0.6.25.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.6.23.dist-info → ripple_down_rules-0.6.25.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.6.23.dist-info → ripple_down_rules-0.6.25.dist-info}/top_level.txt +0 -0
ripple_down_rules/__init__.py
CHANGED
@@ -95,7 +95,7 @@ class CallableExpression(SubclassJSONSerializer):
|
|
95
95
|
encapsulating_function_name: str = "_get_value"
|
96
96
|
|
97
97
|
def __init__(self, user_input: Optional[str] = None,
|
98
|
-
conclusion_type: Optional[Tuple[Type]] = None,
|
98
|
+
conclusion_type: Optional[Tuple[Type, ...]] = None,
|
99
99
|
expression_tree: Optional[AST] = None,
|
100
100
|
scope: Optional[Dict[str, Any]] = None,
|
101
101
|
conclusion: Optional[Any] = None,
|
@@ -176,7 +176,8 @@ class CallableExpression(SubclassJSONSerializer):
|
|
176
176
|
new_user_input = (f"{cond1_user_input}\n"
|
177
177
|
f"{cond2_user_input}\n"
|
178
178
|
f"return _cond1(case) and _cond2(case)")
|
179
|
-
return CallableExpression(new_user_input, conclusion_type=self.conclusion_type
|
179
|
+
return CallableExpression(new_user_input, conclusion_type=self.conclusion_type,
|
180
|
+
mutually_exclusive=self.mutually_exclusive)
|
180
181
|
|
181
182
|
def update_user_input_from_file(self, file_path: str, function_name: str):
|
182
183
|
"""
|
@@ -31,7 +31,7 @@ class CaseQuery:
|
|
31
31
|
"""
|
32
32
|
The name of the attribute.
|
33
33
|
"""
|
34
|
-
_attribute_types: Tuple[Type]
|
34
|
+
_attribute_types: Tuple[Type, ...]
|
35
35
|
"""
|
36
36
|
The type(s) of the attribute.
|
37
37
|
"""
|
@@ -139,7 +139,7 @@ class CaseQuery:
|
|
139
139
|
attribute_types_str = f"Union[{', '.join([t.__name__ for t in self.core_attribute_type])}]"
|
140
140
|
else:
|
141
141
|
attribute_types_str = self.core_attribute_type[0].__name__
|
142
|
-
if
|
142
|
+
if not self.mutually_exclusive:
|
143
143
|
return f"List[{attribute_types_str}]"
|
144
144
|
else:
|
145
145
|
return attribute_types_str
|
@@ -196,7 +196,10 @@ class RDREdge(Enum):
|
|
196
196
|
"""
|
197
197
|
Next edge, the edge that represents the next rule to be evaluated.
|
198
198
|
"""
|
199
|
-
|
199
|
+
Filter = "filter if"
|
200
|
+
"""
|
201
|
+
Filter edge, the edge that represents the filter condition.
|
202
|
+
"""
|
200
203
|
|
201
204
|
class ValueType(Enum):
|
202
205
|
Unary = auto()
|
ripple_down_rules/experts.py
CHANGED
@@ -41,14 +41,14 @@ class Expert(ABC):
|
|
41
41
|
A flag to indicate if the expert should use loaded answers or not.
|
42
42
|
"""
|
43
43
|
|
44
|
-
def __init__(self, use_loaded_answers: bool =
|
44
|
+
def __init__(self, use_loaded_answers: bool = False,
|
45
45
|
append: bool = False,
|
46
46
|
answers_save_path: Optional[str] = None):
|
47
47
|
self.all_expert_answers = []
|
48
48
|
self.use_loaded_answers = use_loaded_answers
|
49
49
|
self.append = append
|
50
50
|
self.answers_save_path = answers_save_path
|
51
|
-
if answers_save_path is not None:
|
51
|
+
if answers_save_path is not None and os.path.exists(answers_save_path + '.py'):
|
52
52
|
if use_loaded_answers:
|
53
53
|
self.load_answers(answers_save_path)
|
54
54
|
else:
|
ripple_down_rules/helpers.py
CHANGED
@@ -2,15 +2,14 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import os
|
4
4
|
from types import ModuleType
|
5
|
+
from typing import Tuple
|
5
6
|
|
6
|
-
from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
|
7
|
-
|
8
|
-
from .datastructures.case import create_case
|
9
|
-
from .datastructures.dataclasses import CaseQuery
|
10
7
|
from typing_extensions import Type, Optional, Callable, Any, Dict, TYPE_CHECKING, Union
|
11
8
|
|
12
|
-
from .
|
9
|
+
from .datastructures.case import create_case, Case
|
10
|
+
from .datastructures.dataclasses import CaseQuery
|
13
11
|
from .utils import calculate_precision_and_recall
|
12
|
+
from .utils import get_func_rdr_model_name, copy_case, make_set, update_case
|
14
13
|
|
15
14
|
if TYPE_CHECKING:
|
16
15
|
from .rdr import RippleDownRules
|
@@ -55,12 +54,14 @@ def general_rdr_classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDow
|
|
55
54
|
if attribute_name in new_conclusions:
|
56
55
|
temp_case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, rdr.mutually_exclusive)
|
57
56
|
update_case(temp_case_query, new_conclusions)
|
58
|
-
if len(new_conclusions) == 0 or len(classifiers_dict) == 1 and list(classifiers_dict.values())[
|
57
|
+
if len(new_conclusions) == 0 or len(classifiers_dict) == 1 and list(classifiers_dict.values())[
|
58
|
+
0].mutually_exclusive:
|
59
59
|
break
|
60
60
|
return conclusions
|
61
61
|
|
62
62
|
|
63
|
-
def is_matching(classifier: Callable[[Any], Any], case_query: CaseQuery,
|
63
|
+
def is_matching(classifier: Callable[[Any], Any], case_query: CaseQuery,
|
64
|
+
pred_cat: Optional[Dict[str, Any]] = None) -> bool:
|
64
65
|
"""
|
65
66
|
:param classifier: The RDR classifier to check the prediction of.
|
66
67
|
:param case_query: The case query to check.
|
@@ -95,3 +96,23 @@ def load_or_create_func_rdr_model(func, model_dir: str, rdr_type: Type[RippleDow
|
|
95
96
|
else:
|
96
97
|
rdr = rdr_type(**rdr_kwargs)
|
97
98
|
return rdr
|
99
|
+
|
100
|
+
|
101
|
+
def get_an_updated_case_copy(case: Case, conclusion: Callable, attribute_name: str, conclusion_type: Tuple[Type, ...],
|
102
|
+
mutually_exclusive: bool) -> Case:
|
103
|
+
"""
|
104
|
+
:param case: The case to copy and update.
|
105
|
+
:param conclusion: The conclusion to add to the case.
|
106
|
+
:param attribute_name: The name of the attribute to update.
|
107
|
+
:param conclusion_type: The type of the conclusion to update.
|
108
|
+
:param mutually_exclusive: Whether the rule belongs to a mutually exclusive RDR.
|
109
|
+
:return: A copy of the case updated with the given conclusion.
|
110
|
+
"""
|
111
|
+
case_cp = copy_case(case)
|
112
|
+
temp_case_query = CaseQuery(case_cp, attribute_name, conclusion_type,
|
113
|
+
mutually_exclusive=mutually_exclusive)
|
114
|
+
output = conclusion(case_cp)
|
115
|
+
if not isinstance(output, Dict):
|
116
|
+
output = {attribute_name: output}
|
117
|
+
update_case(temp_case_query, output)
|
118
|
+
return case_cp
|
ripple_down_rules/rdr.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import importlib
|
4
|
+
import json
|
4
5
|
import os
|
5
6
|
from abc import ABC, abstractmethod
|
6
7
|
from copy import copy
|
8
|
+
from dataclasses import is_dataclass
|
7
9
|
from types import NoneType
|
8
10
|
|
9
11
|
from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
|
@@ -27,15 +29,16 @@ from .datastructures.case import Case, CaseAttribute, create_case
|
|
27
29
|
from .datastructures.dataclasses import CaseQuery
|
28
30
|
from .datastructures.enums import MCRDRMode
|
29
31
|
from .experts import Expert, Human
|
30
|
-
from .helpers import is_matching, general_rdr_classify
|
31
|
-
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
32
|
+
from .helpers import is_matching, general_rdr_classify, get_an_updated_case_copy
|
33
|
+
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule, MultiClassRefinementRule, \
|
34
|
+
MultiClassFilterRule
|
32
35
|
|
33
36
|
try:
|
34
37
|
from .user_interface.gui import RDRCaseViewer
|
35
38
|
except ImportError as e:
|
36
39
|
RDRCaseViewer = None
|
37
40
|
from .utils import draw_tree, make_set, SubclassJSONSerializer, make_list, get_type_from_string, \
|
38
|
-
|
41
|
+
is_value_conflicting, extract_function_source, extract_imports, get_full_class_name, \
|
39
42
|
is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types, render_tree
|
40
43
|
|
41
44
|
|
@@ -97,11 +100,12 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
97
100
|
self.viewer = RDRCaseViewer.instances[0]
|
98
101
|
logger.error("No viewer was provided, but there is already an existing viewer. "
|
99
102
|
"Using the existing viewer.")
|
103
|
+
self.input_node: Optional[Rule] = None
|
100
104
|
|
101
105
|
@property
|
102
106
|
def viewer(self):
|
103
107
|
return self._viewer
|
104
|
-
|
108
|
+
|
105
109
|
@viewer.setter
|
106
110
|
def viewer(self, value):
|
107
111
|
self._viewer = value
|
@@ -110,23 +114,24 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
110
114
|
|
111
115
|
def render_evaluated_rule_tree(self, filename: str, show_full_tree: bool = False) -> None:
|
112
116
|
if show_full_tree:
|
113
|
-
|
117
|
+
start_rule = self.start_rule if self.input_node is None else self.input_node
|
118
|
+
render_tree(start_rule, use_dot_exporter=True, filename=filename)
|
114
119
|
else:
|
115
120
|
evaluated_rules = self.get_evaluated_rule_tree()
|
116
121
|
if evaluated_rules is not None and len(evaluated_rules) > 0:
|
117
122
|
render_tree(evaluated_rules[0], use_dot_exporter=True, filename=filename,
|
118
123
|
only_nodes=evaluated_rules)
|
119
124
|
|
120
|
-
def get_evaluated_rule_tree(self) -> List[Rule]:
|
125
|
+
def get_evaluated_rule_tree(self) -> Optional[List[Rule]]:
|
121
126
|
"""
|
122
127
|
Get the evaluated rule tree of the classifier.
|
123
128
|
|
124
129
|
:return: The evaluated rule tree.
|
125
130
|
"""
|
126
131
|
if self.start_rule is None:
|
127
|
-
return
|
128
|
-
|
129
|
-
evaluated_rule_tree = [r for r in [
|
132
|
+
return None
|
133
|
+
start_rule = self.start_rule
|
134
|
+
evaluated_rule_tree = [r for r in [start_rule] + list(start_rule.descendants) if r.evaluated]
|
130
135
|
return evaluated_rule_tree
|
131
136
|
|
132
137
|
def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None,
|
@@ -260,9 +265,38 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
260
265
|
def __call__(self, case: Union[Case, SQLTable]) -> Union[CallableExpression, Dict[str, CallableExpression]]:
|
261
266
|
return self.classify(case)
|
262
267
|
|
268
|
+
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False, case_query: Optional[CaseQuery] = None) \
|
269
|
+
-> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
|
270
|
+
"""
|
271
|
+
Classify a case using the RDR classifier.
|
272
|
+
|
273
|
+
:param case: The case to classify.
|
274
|
+
:param modify_case: Whether to modify the original case attributes with the conclusion or not.
|
275
|
+
:param case_query: The case query containing the case to classify and the target category to compare the case with.
|
276
|
+
:return: The category that the case belongs to.
|
277
|
+
"""
|
278
|
+
if self.start_rule is not None:
|
279
|
+
for rule in [self.start_rule] + list(self.start_rule.descendants):
|
280
|
+
rule.evaluated = False
|
281
|
+
rule.fired = False
|
282
|
+
if self.start_rule is not None and self.start_rule.parent is None:
|
283
|
+
if self.input_node is None:
|
284
|
+
self.input_node = type(self.start_rule)(parent=None, uid='0')
|
285
|
+
self.input_node.evaluated = False
|
286
|
+
self.input_node.fired = False
|
287
|
+
self.start_rule.parent = self.input_node
|
288
|
+
self.start_rule.weight = ""
|
289
|
+
if self.input_node is not None:
|
290
|
+
data = case.__dict__ if is_dataclass(case) else case
|
291
|
+
if hasattr(case, "items"):
|
292
|
+
self.input_node.name = json.dumps({k: str(v) for k, v in data.items()}, indent=4)
|
293
|
+
else:
|
294
|
+
self.input_node.name = str(data)
|
295
|
+
return self._classify(case, modify_case=modify_case, case_query=case_query)
|
296
|
+
|
263
297
|
@abstractmethod
|
264
|
-
def
|
265
|
-
|
298
|
+
def _classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
|
299
|
+
case_query: Optional[CaseQuery] = None) \
|
266
300
|
-> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
|
267
301
|
"""
|
268
302
|
Classify a case.
|
@@ -467,25 +501,51 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
467
501
|
conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys()
|
468
502
|
if not isinstance(rules_dict[rid], MultiClassStopRule)]
|
469
503
|
all_func_names = condition_func_names + conclusion_func_names
|
504
|
+
rule_tree_file_path = f"{model_dir}/{self.generated_python_file_name}.py"
|
470
505
|
filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
|
471
506
|
cases_path = f"{model_dir}/{self.generated_python_cases_file_name}.py"
|
472
507
|
cases_import_path = get_import_path_from_path(model_dir)
|
473
508
|
cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path \
|
474
509
|
else self.generated_python_cases_file_name
|
475
510
|
functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
|
511
|
+
python_rule_tree_source = ""
|
512
|
+
with open(rule_tree_file_path, "r") as rule_tree_source:
|
513
|
+
python_rule_tree_source = rule_tree_source.read()
|
476
514
|
# get the scope from the imports in the file
|
477
515
|
scope = extract_imports(filepath, package_name=package_name)
|
516
|
+
rules_not_found = set()
|
478
517
|
for rule in [self.start_rule] + list(self.start_rule.descendants):
|
479
518
|
if rule.conditions is not None:
|
480
|
-
|
519
|
+
conditions_name = rule.generated_conditions_function_name
|
520
|
+
if conditions_name not in functions_source or conditions_name not in python_rule_tree_source:
|
521
|
+
rules_not_found.add(rule)
|
522
|
+
continue
|
523
|
+
rule.conditions.user_input = functions_source[conditions_name]
|
481
524
|
rule.conditions.scope = scope
|
482
525
|
if os.path.exists(cases_path):
|
483
526
|
module = importlib.import_module(cases_import_path, package=package_name)
|
484
527
|
importlib.reload(module)
|
485
528
|
rule.corner_case_metadata = module.__dict__.get(f"corner_case_{rule.uid}", None)
|
486
|
-
if
|
487
|
-
|
529
|
+
if not isinstance(rule, MultiClassStopRule):
|
530
|
+
conclusion_name = rule.generated_conclusion_function_name
|
531
|
+
if conclusion_name not in functions_source or conclusion_name not in python_rule_tree_source:
|
532
|
+
rules_not_found.add(rule)
|
533
|
+
rule.conclusion.user_input = functions_source[conclusion_name]
|
488
534
|
rule.conclusion.scope = scope
|
535
|
+
for rule in rules_not_found:
|
536
|
+
if isinstance(rule, MultiClassTopRule):
|
537
|
+
import pdb; pdb.set_trace()
|
538
|
+
rule.parent.set_immediate_alternative(rule.alternative)
|
539
|
+
if rule.refinement is not None:
|
540
|
+
ref_rules = [ref_rule for ref_rule in [rule.refinement] + list(rule.refinement.descendants)]
|
541
|
+
for ref_rule in ref_rules:
|
542
|
+
del ref_rule
|
543
|
+
else:
|
544
|
+
rule.parent.refinement = rule.alternative
|
545
|
+
if rule.alternative is not None:
|
546
|
+
rule.alternative = None
|
547
|
+
rule.parent = None
|
548
|
+
del rule
|
489
549
|
|
490
550
|
@abstractmethod
|
491
551
|
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
|
@@ -564,7 +624,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
564
624
|
"""
|
565
625
|
pass
|
566
626
|
|
567
|
-
def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
|
627
|
+
def _get_types_to_import(self) -> Tuple[Set[Union[Type, Callable]], Set[Type], Set[Type]]:
|
568
628
|
"""
|
569
629
|
:return: The types of the main, defs, and corner cases files of the RDR classifier that will be imported.
|
570
630
|
"""
|
@@ -703,8 +763,8 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
703
763
|
expert.ask_for_conditions(case_query)
|
704
764
|
self.start_rule = SingleClassRule.from_case_query(case_query)
|
705
765
|
|
706
|
-
def
|
707
|
-
|
766
|
+
def _classify(self, case: Case, modify_case: bool = False,
|
767
|
+
case_query: Optional[CaseQuery] = None) -> Optional[Any]:
|
708
768
|
"""
|
709
769
|
Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
|
710
770
|
|
@@ -818,8 +878,8 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
818
878
|
super(MultiClassRDR, self).__init__(start_rule, **kwargs)
|
819
879
|
self.mode: MCRDRMode = mode
|
820
880
|
|
821
|
-
def
|
822
|
-
|
881
|
+
def _classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
|
882
|
+
case_query: Optional[CaseQuery] = None) -> Set[Any]:
|
823
883
|
evaluated_rule = self.start_rule
|
824
884
|
self.conclusions = []
|
825
885
|
while evaluated_rule:
|
@@ -897,6 +957,9 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
897
957
|
if rule.alternative:
|
898
958
|
self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
|
899
959
|
cases_file=cases_file, package_name=package_name)
|
960
|
+
elif isinstance(rule, MultiClassTopRule):
|
961
|
+
with open(filename, "a") as file:
|
962
|
+
file.write(f"{parent_indent}return conclusions\n")
|
900
963
|
|
901
964
|
@property
|
902
965
|
def conclusion_type_hint(self) -> str:
|
@@ -906,8 +969,9 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
906
969
|
else:
|
907
970
|
return f"Set[Union[{', '.join(conclusion_types)}]]"
|
908
971
|
|
909
|
-
def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
|
972
|
+
def _get_types_to_import(self) -> Tuple[Set[Union[Type, Callable]], Set[Type], Set[Type]]:
|
910
973
|
main_types, defs_types, cases_types = super()._get_types_to_import()
|
974
|
+
main_types.add(get_an_updated_case_copy)
|
911
975
|
main_types.update({Set, make_set})
|
912
976
|
defs_types.update({List, Set})
|
913
977
|
return main_types, defs_types, cases_types
|
@@ -939,28 +1003,43 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
939
1003
|
Stop a wrong conclusion by adding a stopping rule.
|
940
1004
|
"""
|
941
1005
|
rule_conclusion = evaluated_rule.conclusion(case_query.case)
|
942
|
-
|
943
|
-
|
944
|
-
|
1006
|
+
stop: bool = False
|
1007
|
+
add_filter_rule: bool = False
|
1008
|
+
if is_value_conflicting(rule_conclusion, case_query.target_value):
|
1009
|
+
if make_set(case_query.target_value).issubset(rule_conclusion):
|
1010
|
+
add_filter_rule = True
|
1011
|
+
else:
|
1012
|
+
stop = True
|
1013
|
+
elif make_set(case_query.core_attribute_type).issubset(make_set(evaluated_rule.conclusion.conclusion_type)):
|
1014
|
+
if make_set(case_query.target_value).issubset(rule_conclusion):
|
1015
|
+
add_filter_rule = True
|
1016
|
+
|
1017
|
+
if not stop:
|
945
1018
|
self.add_conclusion(rule_conclusion)
|
1019
|
+
if stop or add_filter_rule:
|
1020
|
+
refinement_type = MultiClassStopRule if stop else MultiClassFilterRule
|
1021
|
+
self.stop_or_filter_conclusion(case_query, expert, evaluated_rule, refinement_type=refinement_type)
|
946
1022
|
|
947
|
-
def
|
948
|
-
|
1023
|
+
def stop_or_filter_conclusion(self, case_query: CaseQuery,
|
1024
|
+
expert: Expert, evaluated_rule: MultiClassTopRule,
|
1025
|
+
refinement_type: Type[MultiClassRefinementRule] = MultiClassStopRule):
|
949
1026
|
"""
|
950
1027
|
Stop a conclusion by adding a stopping rule.
|
951
1028
|
|
952
1029
|
:param case_query: The case query to stop the conclusion for.
|
953
1030
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
954
1031
|
:param evaluated_rule: The evaluated rule to ask the expert about.
|
1032
|
+
:param refinement_type: The refinement type to use.
|
955
1033
|
"""
|
956
1034
|
conditions = expert.ask_for_conditions(case_query, evaluated_rule)
|
957
|
-
evaluated_rule.fit_rule(case_query)
|
958
|
-
if
|
959
|
-
self.
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
1035
|
+
evaluated_rule.fit_rule(case_query, refinement_type=refinement_type)
|
1036
|
+
if refinement_type is MultiClassStopRule:
|
1037
|
+
if self.mode == MCRDRMode.StopPlusRule:
|
1038
|
+
self.stop_rule_conditions = conditions
|
1039
|
+
if self.mode == MCRDRMode.StopPlusRuleCombined:
|
1040
|
+
new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
|
1041
|
+
case_query.conditions = new_top_rule_conditions
|
1042
|
+
self.add_top_rule(case_query)
|
964
1043
|
|
965
1044
|
def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
|
966
1045
|
"""
|
@@ -1064,8 +1143,8 @@ class GeneralRDR(RippleDownRules):
|
|
1064
1143
|
def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
|
1065
1144
|
return [rdr.start_rule for rdr in self.start_rules_dict.values()]
|
1066
1145
|
|
1067
|
-
def
|
1068
|
-
|
1146
|
+
def _classify(self, case: Any, modify_case: bool = False,
|
1147
|
+
case_query: Optional[CaseQuery] = None) -> Optional[Dict[str, Any]]:
|
1069
1148
|
"""
|
1070
1149
|
Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
|
1071
1150
|
the classification until no more categories can be added.
|
@@ -8,15 +8,16 @@ from functools import wraps
|
|
8
8
|
|
9
9
|
from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union, Sequence
|
10
10
|
|
11
|
-
from
|
12
|
-
from
|
13
|
-
from
|
14
|
-
from
|
11
|
+
from .datastructures.case import Case
|
12
|
+
from .datastructures.dataclasses import CaseQuery
|
13
|
+
from .experts import Expert, Human
|
14
|
+
from .rdr import GeneralRDR
|
15
|
+
from . import logger
|
15
16
|
try:
|
16
|
-
from
|
17
|
+
from .user_interface.gui import RDRCaseViewer
|
17
18
|
except ImportError:
|
18
19
|
RDRCaseViewer = None
|
19
|
-
from
|
20
|
+
from .utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
|
20
21
|
get_method_class_if_exists, str_to_snake_case
|
21
22
|
|
22
23
|
|
ripple_down_rules/rules.py
CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|
3
3
|
import logging
|
4
4
|
import re
|
5
5
|
from abc import ABC, abstractmethod
|
6
|
-
from pathlib import Path
|
7
6
|
from types import NoneType
|
8
7
|
from uuid import uuid4
|
9
8
|
|
@@ -15,7 +14,8 @@ from .datastructures.callable_expression import CallableExpression
|
|
15
14
|
from .datastructures.case import Case
|
16
15
|
from .datastructures.dataclasses import CaseFactoryMetaData, CaseQuery
|
17
16
|
from .datastructures.enums import RDREdge, Stop
|
18
|
-
from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name,
|
17
|
+
from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_type_from_string
|
18
|
+
from .helpers import get_an_updated_case_copy
|
19
19
|
|
20
20
|
|
21
21
|
class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
@@ -23,6 +23,10 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
23
23
|
"""
|
24
24
|
Whether the rule has fired or not.
|
25
25
|
"""
|
26
|
+
mutually_exclusive: bool
|
27
|
+
"""
|
28
|
+
Whether the rule is mutually exclusive with other rules.
|
29
|
+
"""
|
26
30
|
|
27
31
|
def __init__(self, conditions: Optional[CallableExpression] = None,
|
28
32
|
conclusion: Optional[CallableExpression] = None,
|
@@ -60,6 +64,14 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
60
64
|
self.evaluated: bool = False
|
61
65
|
self._user_defined_name: Optional[str] = None
|
62
66
|
|
67
|
+
def get_an_updated_case_copy(self, case: Case) -> Case:
|
68
|
+
"""
|
69
|
+
:param case: The case to copy and update.
|
70
|
+
:return: A copy of the case updated with this rule conclusion.
|
71
|
+
"""
|
72
|
+
return get_an_updated_case_copy(case, self.conclusion, self.conclusion_name, self.conclusion.conclusion_type,
|
73
|
+
self.mutually_exclusive)
|
74
|
+
|
63
75
|
@property
|
64
76
|
def color(self) -> str:
|
65
77
|
if self.evaluated:
|
@@ -78,22 +90,27 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
78
90
|
if self._user_defined_name is None:
|
79
91
|
if self.conditions and self.conditions.user_input and "def " in self.conditions.user_input:
|
80
92
|
# If the conditions have a user input, use it as the name
|
81
|
-
|
93
|
+
func_name = self.conditions.user_input.split('(')[0].replace('def ', '').strip()
|
94
|
+
if func_name == self.conditions.encapsulating_function_name:
|
95
|
+
self._user_defined_name = str(self.conditions)
|
96
|
+
else:
|
97
|
+
self._user_defined_name = func_name
|
82
98
|
else:
|
83
99
|
self._user_defined_name = f"Rule_{self.uid}"
|
84
100
|
return self._user_defined_name
|
85
101
|
|
86
102
|
@classmethod
|
87
|
-
def from_case_query(cls, case_query: CaseQuery) -> Rule:
|
103
|
+
def from_case_query(cls, case_query: CaseQuery, parent: Optional[Rule] = None) -> Rule:
|
88
104
|
"""
|
89
105
|
Create a SingleClassRule from a CaseQuery.
|
90
106
|
|
91
107
|
:param case_query: The CaseQuery to create the rule from.
|
108
|
+
:param parent: The parent rule of this rule.
|
92
109
|
:return: A SingleClassRule instance.
|
93
110
|
"""
|
94
111
|
corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
95
112
|
return cls(conditions=case_query.conditions, conclusion=case_query.target,
|
96
|
-
corner_case=case_query.case, parent=
|
113
|
+
corner_case=case_query.case, parent=parent,
|
97
114
|
corner_case_metadata=corner_case_metadata,
|
98
115
|
conclusion_name=case_query.attribute_name)
|
99
116
|
|
@@ -116,9 +133,6 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
116
133
|
:param x: The case to evaluate the rule on.
|
117
134
|
:return: The rule that fired or the last evaluated rule if no rule fired.
|
118
135
|
"""
|
119
|
-
if self.root is self:
|
120
|
-
for descendant in self.descendants:
|
121
|
-
descendant.evaluated = False
|
122
136
|
self.evaluated = True
|
123
137
|
if not self.conditions:
|
124
138
|
raise ValueError("Rule has no conditions")
|
@@ -178,6 +192,14 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
178
192
|
f.write(conclusion_func.strip() + "\n\n\n")
|
179
193
|
return conclusion_func_call
|
180
194
|
|
195
|
+
@property
|
196
|
+
def generated_conclusion_function_name(self) -> str:
|
197
|
+
return f"conclusion_{self.uid}"
|
198
|
+
|
199
|
+
@property
|
200
|
+
def generated_conditions_function_name(self) -> str:
|
201
|
+
return f"conditions_{self.uid}"
|
202
|
+
|
181
203
|
def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
|
182
204
|
"""
|
183
205
|
Convert the conclusion of a rule to source code.
|
@@ -190,23 +212,24 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
190
212
|
# This means the conclusion is a definition that should be written and then called
|
191
213
|
conclusion_lines = conclusion.split('\n')
|
192
214
|
# use regex to replace the function name
|
193
|
-
new_function_name = f"def
|
215
|
+
new_function_name = f"def {self.generated_conclusion_function_name}"
|
194
216
|
conclusion_lines[0] = re.sub(r"def (\w+)", new_function_name, conclusion_lines[0])
|
195
217
|
# add type hint
|
196
|
-
if
|
197
|
-
|
218
|
+
if not self.conclusion.mutually_exclusive:
|
219
|
+
type_names = [t.__name__ for t in self.conclusion.conclusion_type if t not in [list, set]]
|
220
|
+
if len(type_names) == 1:
|
221
|
+
hint = f"List[{type_names[0]}]"
|
222
|
+
else:
|
223
|
+
hint = f"List[Union[{', '.join(type_names)}]]"
|
198
224
|
else:
|
199
|
-
if
|
200
|
-
|
201
|
-
|
202
|
-
|
225
|
+
if NoneType in self.conclusion.conclusion_type:
|
226
|
+
type_names = [t.__name__ for t in self.conclusion.conclusion_type if t is not NoneType]
|
227
|
+
hint = f"Optional[{', '.join(type_names)}]"
|
228
|
+
elif len(self.conclusion.conclusion_type) == 1:
|
229
|
+
hint = self.conclusion.conclusion_type[0].__name__
|
203
230
|
else:
|
204
|
-
|
205
|
-
|
206
|
-
hint = f"Optional[{', '.join(type_names)}]"
|
207
|
-
else:
|
208
|
-
type_names = [t.__name__ for t in self.conclusion.conclusion_type]
|
209
|
-
hint = f"Union[{', '.join(type_names)}]"
|
231
|
+
type_names = [t.__name__ for t in self.conclusion.conclusion_type]
|
232
|
+
hint = f"Union[{', '.join(type_names)}]"
|
210
233
|
conclusion_lines[0] = conclusion_lines[0].replace("):", f") -> {hint}:")
|
211
234
|
func_call = f"{parent_indent} return {new_function_name.replace('def ', '')}(case)\n"
|
212
235
|
return "\n".join(conclusion_lines).strip(' '), func_call
|
@@ -228,7 +251,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
228
251
|
# This means the conditions are a definition that should be written and then called
|
229
252
|
conditions_lines = self.conditions.user_input.split('\n')
|
230
253
|
# use regex to replace the function name
|
231
|
-
new_function_name = f"def
|
254
|
+
new_function_name = f"def {self.generated_conditions_function_name}"
|
232
255
|
conditions_lines[0] = re.sub(r"def (\w+)", new_function_name, conditions_lines[0])
|
233
256
|
# add type hint
|
234
257
|
conditions_lines[0] = conditions_lines[0].replace('):', ') -> bool:')
|
@@ -310,10 +333,11 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
310
333
|
Get the name of the expression, which is the user input of the expression if it exists,
|
311
334
|
otherwise it is the conclusion or conditions of the rule.
|
312
335
|
"""
|
313
|
-
if expression.user_defined_name is not None:
|
336
|
+
if expression.user_defined_name is not None and expression.user_defined_name != expression.encapsulating_function_name:
|
314
337
|
return expression.user_defined_name.strip()
|
315
|
-
|
316
|
-
|
338
|
+
func_name = expression.user_input.split('(')[0].replace('def ', '').strip() if "def " in expression.user_input else None
|
339
|
+
if func_name is not None and func_name != expression.encapsulating_function_name:
|
340
|
+
return func_name
|
317
341
|
elif expression.user_input:
|
318
342
|
return expression.user_input.strip()
|
319
343
|
else:
|
@@ -350,6 +374,9 @@ class HasAlternativeRule:
|
|
350
374
|
def alternative(self) -> Optional[Rule]:
|
351
375
|
return self._alternative
|
352
376
|
|
377
|
+
def set_immediate_alternative(self, alternative: Optional[Rule]):
|
378
|
+
self._alternative = alternative
|
379
|
+
|
353
380
|
@alternative.setter
|
354
381
|
def alternative(self, new_rule: Rule):
|
355
382
|
"""
|
@@ -385,12 +412,11 @@ class HasRefinementRule:
|
|
385
412
|
"""
|
386
413
|
if new_rule is None:
|
387
414
|
return
|
388
|
-
new_rule.top_rule = self
|
389
415
|
if self.refinement:
|
390
416
|
self.refinement.alternative = new_rule
|
391
417
|
else:
|
392
418
|
new_rule.parent = self
|
393
|
-
new_rule.weight = RDREdge.Refinement.value
|
419
|
+
new_rule.weight = RDREdge.Refinement.value if not isinstance(new_rule, MultiClassFilterRule) else new_rule.weight
|
394
420
|
self._refinement = new_rule
|
395
421
|
|
396
422
|
|
@@ -399,6 +425,8 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
|
|
399
425
|
A rule in the SingleClassRDR classifier, it can have a refinement or an alternative rule or both.
|
400
426
|
"""
|
401
427
|
|
428
|
+
mutually_exclusive: bool = True
|
429
|
+
|
402
430
|
def evaluate_next_rule(self, x: Case) -> SingleClassRule:
|
403
431
|
if self.fired:
|
404
432
|
returned_rule = self.refinement(x) if self.refinement else self
|
@@ -434,28 +462,15 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
|
|
434
462
|
return "elif" if self.weight == RDREdge.Alternative.value else "if"
|
435
463
|
|
436
464
|
|
437
|
-
class
|
465
|
+
class MultiClassRefinementRule(Rule, HasAlternativeRule, ABC):
|
438
466
|
"""
|
439
|
-
A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule
|
440
|
-
the conclusion of the rule is a Stop category meant to stop the parent conclusion from being made.
|
467
|
+
A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule.
|
441
468
|
"""
|
442
469
|
top_rule: Optional[MultiClassTopRule] = None
|
443
470
|
"""
|
444
471
|
The top rule of the rule, which is the nearest ancestor that fired and this rule is a refinement of.
|
445
472
|
"""
|
446
|
-
|
447
|
-
def __init__(self, *args, **kwargs):
|
448
|
-
super(MultiClassStopRule, self).__init__(*args, **kwargs)
|
449
|
-
self.conclusion = CallableExpression(conclusion_type=(Stop,), conclusion=Stop.stop)
|
450
|
-
|
451
|
-
def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassStopRule, MultiClassTopRule]]:
|
452
|
-
if self.fired:
|
453
|
-
self.top_rule.fired = False
|
454
|
-
return self.top_rule.alternative
|
455
|
-
elif self.alternative:
|
456
|
-
return self.alternative(x)
|
457
|
-
else:
|
458
|
-
return self.top_rule.alternative
|
473
|
+
mutually_exclusive: bool = False
|
459
474
|
|
460
475
|
def _to_json(self) -> Dict[str, Any]:
|
461
476
|
self.json_serialization = {**Rule._to_json(self),
|
@@ -463,27 +478,106 @@ class MultiClassStopRule(Rule, HasAlternativeRule):
|
|
463
478
|
return self.json_serialization
|
464
479
|
|
465
480
|
@classmethod
|
466
|
-
def _from_json(cls, data: Dict[str, Any]) ->
|
467
|
-
loaded_rule = super(
|
481
|
+
def _from_json(cls, data: Dict[str, Any]) -> MultiClassRefinementRule:
|
482
|
+
loaded_rule = super(MultiClassRefinementRule, cls)._from_json(data)
|
468
483
|
# The following is done to prevent re-initialization of the top rule,
|
469
484
|
# so the top rule that is already initialized is passed in the data instead of its json serialization.
|
470
485
|
loaded_rule.top_rule = data['top_rule']
|
471
486
|
if data['alternative'] is not None:
|
472
487
|
data['alternative']['top_rule'] = data['top_rule']
|
473
|
-
|
488
|
+
loaded_rule.alternative = SubclassJSONSerializer.from_json(data["alternative"])
|
474
489
|
return loaded_rule
|
475
490
|
|
491
|
+
def _if_statement_source_code_clause(self) -> str:
|
492
|
+
return "elif" if self.weight == RDREdge.Alternative.value else "if"
|
493
|
+
|
494
|
+
|
495
|
+
class MultiClassStopRule(MultiClassRefinementRule):
|
496
|
+
"""
|
497
|
+
A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule,
|
498
|
+
the conclusion of the rule is a Stop category meant to stop the parent conclusion from being made.
|
499
|
+
"""
|
500
|
+
|
501
|
+
def __init__(self, *args, **kwargs):
|
502
|
+
super(MultiClassRefinementRule, self).__init__(*args, **kwargs)
|
503
|
+
self.conclusion = CallableExpression(conclusion_type=(Stop,), conclusion=Stop.stop)
|
504
|
+
|
505
|
+
def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassRefinementRule, MultiClassTopRule]]:
|
506
|
+
if self.fired:
|
507
|
+
self.top_rule.fired = False
|
508
|
+
return self.top_rule.alternative
|
509
|
+
elif self.alternative:
|
510
|
+
return self.alternative(x)
|
511
|
+
else:
|
512
|
+
return self.top_rule.alternative
|
513
|
+
|
476
514
|
def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[None, str]:
|
477
515
|
return None, f"{parent_indent}{' ' * 4}pass\n"
|
478
516
|
|
479
|
-
|
480
|
-
|
517
|
+
|
518
|
+
class MultiClassFilterRule(MultiClassRefinementRule, HasRefinementRule):
|
519
|
+
"""
|
520
|
+
A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule,
|
521
|
+
the conclusion of the rule is a Filter category meant to filter the parent conclusion.
|
522
|
+
"""
|
523
|
+
|
524
|
+
def __init__(self, *args, **kwargs):
|
525
|
+
super(MultiClassRefinementRule, self).__init__(*args, **kwargs)
|
526
|
+
self.weight = RDREdge.Filter.value
|
527
|
+
|
528
|
+
def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassRefinementRule, MultiClassTopRule]]:
|
529
|
+
if self.fired:
|
530
|
+
if self.refinement:
|
531
|
+
case_cp = x
|
532
|
+
if isinstance(self.refinement, MultiClassFilterRule):
|
533
|
+
case_cp = self.get_an_updated_case_copy(case_cp)
|
534
|
+
return self.refinement(case_cp)
|
535
|
+
else:
|
536
|
+
return self.top_rule.alternative
|
537
|
+
elif self.alternative:
|
538
|
+
return self.alternative(x)
|
539
|
+
else:
|
540
|
+
return self.top_rule.alternative
|
541
|
+
|
542
|
+
def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[None, str]:
|
543
|
+
func, func_call = super().get_conclusion_as_source_code(str(conclusion), parent_indent=parent_indent)
|
544
|
+
conclusion_str = func_call.replace("return ", "").strip()
|
545
|
+
conclusion_str = conclusion_str.replace("(case)", "(case_cp)")
|
546
|
+
|
547
|
+
parent_to_filter = self.get_parent_to_filter()
|
548
|
+
statement = (
|
549
|
+
f"\n{parent_indent} case_cp = get_an_updated_case_copy(case, {parent_to_filter.generated_conclusion_function_name},"
|
550
|
+
f" attribute_name, conclusion_type, mutually_exclusive)")
|
551
|
+
statement += f"\n{parent_indent} conclusions.update(make_set({conclusion_str}))\n"
|
552
|
+
return func, statement
|
553
|
+
|
554
|
+
def get_parent_to_filter(self, parent: Union[None, MultiClassRefinementRule, MultiClassTopRule] = None) \
|
555
|
+
-> Union[MultiClassFilterRule, MultiClassTopRule]:
|
556
|
+
parent = self.parent if parent is None else parent
|
557
|
+
if isinstance(parent, (MultiClassFilterRule, MultiClassTopRule)) and parent.fired:
|
558
|
+
return parent
|
559
|
+
else:
|
560
|
+
return parent.parent
|
561
|
+
|
562
|
+
def _to_json(self) -> Dict[str, Any]:
|
563
|
+
self.json_serialization = super(MultiClassFilterRule, self)._to_json()
|
564
|
+
self.json_serialization['refinement'] = self.refinement.to_json() if self.refinement is not None else None
|
565
|
+
return self.json_serialization
|
566
|
+
|
567
|
+
@classmethod
|
568
|
+
def _from_json(cls, data: Dict[str, Any]) -> MultiClassFilterRule:
|
569
|
+
loaded_rule = super(MultiClassFilterRule, cls)._from_json(data)
|
570
|
+
if data['refinement'] is not None:
|
571
|
+
data['refinement']['top_rule'] = data['top_rule']
|
572
|
+
loaded_rule.refinement = cls.from_json(data["refinement"]) if data["refinement"] is not None else None
|
573
|
+
return loaded_rule
|
481
574
|
|
482
575
|
|
483
576
|
class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
484
577
|
"""
|
485
578
|
A rule in the MultiClassRDR classifier, it can have a refinement and a next rule.
|
486
579
|
"""
|
580
|
+
mutually_exclusive: bool = False
|
487
581
|
|
488
582
|
def __init__(self, *args, **kwargs):
|
489
583
|
super(MultiClassTopRule, self).__init__(*args, **kwargs)
|
@@ -491,16 +585,27 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
|
491
585
|
|
492
586
|
def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassStopRule, MultiClassTopRule]]:
|
493
587
|
if self.fired and self.refinement:
|
494
|
-
|
588
|
+
case_cp = x
|
589
|
+
if isinstance(self.refinement, MultiClassFilterRule):
|
590
|
+
case_cp = self.get_an_updated_case_copy(case_cp)
|
591
|
+
return self.refinement(case_cp)
|
495
592
|
elif self.alternative: # Here alternative refers to next rule in MultiClassRDR
|
496
593
|
return self.alternative
|
594
|
+
return None
|
497
595
|
|
498
|
-
def fit_rule(self, case_query: CaseQuery):
|
596
|
+
def fit_rule(self, case_query: CaseQuery, refinement_type: Optional[Type[MultiClassRefinementRule]] = None):
|
499
597
|
if self.fired and case_query.target != self.conclusion:
|
500
|
-
|
598
|
+
if refinement_type in [None, MultiClassStopRule]:
|
599
|
+
new_rule = MultiClassStopRule(case_query.conditions, corner_case=case_query.case,
|
600
|
+
parent=self)
|
601
|
+
elif refinement_type is MultiClassFilterRule:
|
602
|
+
new_rule = MultiClassFilterRule.from_case_query(case_query, parent=self)
|
603
|
+
else:
|
604
|
+
raise ValueError(f"Unknown refinement type {refinement_type}")
|
605
|
+
new_rule.top_rule = self
|
606
|
+
self.refinement = new_rule
|
501
607
|
elif not self.fired:
|
502
|
-
self.alternative = MultiClassTopRule(case_query
|
503
|
-
corner_case=case_query.case, parent=self)
|
608
|
+
self.alternative = MultiClassTopRule.from_case_query(case_query, parent=self)
|
504
609
|
|
505
610
|
def _to_json(self) -> Dict[str, Any]:
|
506
611
|
self.json_serialization = {**Rule._to_json(self),
|
@@ -515,7 +620,8 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
|
515
620
|
# so the top rule that is already initialized is passed in the data instead of its json serialization.
|
516
621
|
if data['refinement'] is not None:
|
517
622
|
data['refinement']['top_rule'] = loaded_rule
|
518
|
-
|
623
|
+
data_type = get_type_from_string(data["refinement"]["_type"])
|
624
|
+
loaded_rule.refinement = data_type.from_json(data["refinement"])
|
519
625
|
loaded_rule.alternative = MultiClassTopRule.from_json(data["alternative"])
|
520
626
|
return loaded_rule
|
521
627
|
|
@@ -524,8 +630,6 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
|
524
630
|
conclusion_str = func_call.replace("return ", "").strip()
|
525
631
|
|
526
632
|
statement = f"{parent_indent} conclusions.update(make_set({conclusion_str}))\n"
|
527
|
-
if self.alternative is None:
|
528
|
-
statement += f"{parent_indent}return conclusions\n"
|
529
633
|
return func, statement
|
530
634
|
|
531
635
|
def _if_statement_source_code_clause(self) -> str:
|
@@ -50,6 +50,7 @@ class UserPrompt:
|
|
50
50
|
:return: A callable expression that takes a case and executes user expression on it.
|
51
51
|
"""
|
52
52
|
prev_user_input: Optional[str] = None
|
53
|
+
user_input_to_modify: Optional[str] = None
|
53
54
|
callable_expression: Optional[CallableExpression] = None
|
54
55
|
while True:
|
55
56
|
with self.shell_lock:
|
@@ -69,12 +70,14 @@ class UserPrompt:
|
|
69
70
|
conclusion_type = bool if prompt_for == PromptFor.Conditions else case_query.attribute_type
|
70
71
|
callable_expression = CallableExpression(user_input, conclusion_type, expression_tree=expression_tree,
|
71
72
|
scope=case_query.scope,
|
72
|
-
|
73
|
+
mutually_exclusive=case_query.mutually_exclusive)
|
73
74
|
try:
|
74
75
|
result = callable_expression(case_query.case)
|
75
|
-
if len(make_list(result)) == 0
|
76
|
+
if len(make_list(result)) == 0 and (user_input_to_modify is not None
|
77
|
+
and (prev_user_input != user_input_to_modify)):
|
78
|
+
user_input_to_modify = prev_user_input
|
76
79
|
self.print_func(f"{Fore.YELLOW}The given expression gave an empty result for case {case_query.name}."
|
77
|
-
f" Please modify!{Style.RESET_ALL}")
|
80
|
+
f" Please accept or modify!{Style.RESET_ALL}")
|
78
81
|
continue
|
79
82
|
break
|
80
83
|
except Exception as e:
|
ripple_down_rules/utils.py
CHANGED
@@ -8,6 +8,7 @@ import json
|
|
8
8
|
import logging
|
9
9
|
import os
|
10
10
|
import re
|
11
|
+
import shutil
|
11
12
|
import sys
|
12
13
|
import threading
|
13
14
|
import uuid
|
@@ -24,6 +25,7 @@ from types import NoneType
|
|
24
25
|
|
25
26
|
import six
|
26
27
|
from sqlalchemy.exc import NoInspectionAvailable
|
28
|
+
from src.pycram.ros import logwarn
|
27
29
|
|
28
30
|
try:
|
29
31
|
import matplotlib
|
@@ -44,12 +46,12 @@ except ImportError as e:
|
|
44
46
|
|
45
47
|
import requests
|
46
48
|
from anytree import Node, RenderTree, PreOrderIter
|
47
|
-
from anytree.exporter import DotExporter
|
48
49
|
from sqlalchemy import MetaData, inspect
|
49
50
|
from sqlalchemy.orm import Mapped, registry, class_mapper, DeclarativeBase as SQLTable, Session
|
50
51
|
from tabulate import tabulate
|
51
52
|
from typing_extensions import Callable, Set, Any, Type, Dict, TYPE_CHECKING, get_type_hints, \
|
52
53
|
get_origin, get_args, Tuple, Optional, List, Union, Self, ForwardRef, Iterable
|
54
|
+
from . import logger
|
53
55
|
|
54
56
|
if TYPE_CHECKING:
|
55
57
|
from .datastructures.case import Case
|
@@ -180,8 +182,8 @@ def extract_function_source(file_path: str,
|
|
180
182
|
if (len(functions_source) >= len(function_names)) and (not len(function_names) == 0):
|
181
183
|
break
|
182
184
|
if len(functions_source) < len(function_names):
|
183
|
-
|
184
|
-
|
185
|
+
logwarn(f"Could not find all functions in {file_path}: {function_names} not found, "
|
186
|
+
f"functions not found: {set(function_names) - set(functions_source.keys())}")
|
185
187
|
if return_line_numbers:
|
186
188
|
return functions_source, line_numbers
|
187
189
|
return functions_source
|
@@ -285,7 +287,7 @@ def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
|
|
285
287
|
case_query.case.update(conclusions)
|
286
288
|
|
287
289
|
|
288
|
-
def
|
290
|
+
def is_value_conflicting(conclusion: Any, target: Any) -> bool:
|
289
291
|
"""
|
290
292
|
:param conclusion: The conclusion to check.
|
291
293
|
:param target: The target to compare the conclusion with.
|
@@ -845,10 +847,12 @@ def get_relative_import(target_file_path, imported_module_path: Optional[str] =
|
|
845
847
|
imported_file_name = Path(imported_module_path).name
|
846
848
|
target_file_name = Path(target_file_path).name
|
847
849
|
if package_name is not None:
|
848
|
-
target_path = Path(
|
850
|
+
target_path = Path(
|
851
|
+
get_path_starting_from_latest_encounter_of(str(target_path), package_name, [target_file_name]))
|
849
852
|
imported_path = Path(imported_module_path).resolve()
|
850
853
|
if package_name is not None:
|
851
|
-
imported_path = Path(
|
854
|
+
imported_path = Path(
|
855
|
+
get_path_starting_from_latest_encounter_of(str(imported_path), package_name, [imported_file_name]))
|
852
856
|
|
853
857
|
# Compute relative path from target to imported module
|
854
858
|
rel_path = os.path.relpath(imported_path.parent, target_path.parent)
|
@@ -925,8 +929,8 @@ def get_imports_from_types(type_objs: Iterable[Type],
|
|
925
929
|
continue
|
926
930
|
if name == "NoneType":
|
927
931
|
module = "types"
|
928
|
-
if module is None or module == 'builtins' or module.startswith('_')\
|
929
|
-
|
932
|
+
if module is None or module == 'builtins' or module.startswith('_') \
|
933
|
+
or module in sys.builtin_module_names or module in excluded_modules or "<" in module \
|
930
934
|
or name in exclueded_names:
|
931
935
|
continue
|
932
936
|
if module == "typing":
|
@@ -1216,7 +1220,8 @@ class SubclassJSONSerializer:
|
|
1216
1220
|
return cls._from_json(data)
|
1217
1221
|
for subclass in recursive_subclasses(SubclassJSONSerializer):
|
1218
1222
|
if get_full_class_name(subclass) == data["_type"]:
|
1219
|
-
subclass_data = deepcopy(data)
|
1223
|
+
# subclass_data = deepcopy(data)
|
1224
|
+
subclass_data = data
|
1220
1225
|
subclass_data.pop("_type")
|
1221
1226
|
return subclass._from_json(subclass_data)
|
1222
1227
|
|
@@ -1408,7 +1413,11 @@ def table_rows_as_str(row_dicts: List[Dict[str, Any]], columns_per_row: int = 20
|
|
1408
1413
|
row_values = [list(map(lambda v: v[:max_line_sze] + '...' if len(v) > max_line_sze else v, row)) for row in
|
1409
1414
|
row_values]
|
1410
1415
|
row_values = [list(map(lambda v: v.lower() if v in ["True", "False"] else v, row)) for row in row_values]
|
1411
|
-
|
1416
|
+
# Step 1: Get terminal size
|
1417
|
+
terminal_width = shutil.get_terminal_size((80, 20)).columns
|
1418
|
+
# Step 2: Dynamically calculate max width per column (simple approximation)
|
1419
|
+
max_col_width = terminal_width // len(row_values[0])
|
1420
|
+
table = tabulate(row_values, tablefmt='simple_grid', maxcolwidths=max_col_width) # [max_line_sze] * 2)
|
1412
1421
|
all_table_rows.append(table)
|
1413
1422
|
return "\n".join(all_table_rows)
|
1414
1423
|
|
@@ -1628,12 +1637,14 @@ def edge_attr_setter(parent, child):
|
|
1628
1637
|
"""
|
1629
1638
|
Set the edge attributes for the dot exporter.
|
1630
1639
|
"""
|
1631
|
-
if child and hasattr(child, "weight") and child.weight:
|
1640
|
+
if child and hasattr(child, "weight") and child.weight is not None:
|
1632
1641
|
return f'style="bold", label=" {child.weight}"'
|
1633
1642
|
return ""
|
1634
1643
|
|
1635
1644
|
|
1636
1645
|
_RE_ESC = re.compile(r'["\\]')
|
1646
|
+
|
1647
|
+
|
1637
1648
|
class FilteredDotExporter(object):
|
1638
1649
|
|
1639
1650
|
def __init__(self, node, include_nodes=None, graph="digraph", name="tree", options=None,
|
@@ -1913,7 +1924,7 @@ class FilteredDotExporter(object):
|
|
1913
1924
|
|
1914
1925
|
|
1915
1926
|
def render_tree(root: Node, use_dot_exporter: bool = False,
|
1916
|
-
filename: str = "scrdr", only_nodes: List[Node] = None):
|
1927
|
+
filename: str = "scrdr", only_nodes: List[Node] = None, show_in_console: bool = False):
|
1917
1928
|
"""
|
1918
1929
|
Render the tree using the console and optionally export it to a dot file.
|
1919
1930
|
|
@@ -1921,12 +1932,16 @@ def render_tree(root: Node, use_dot_exporter: bool = False,
|
|
1921
1932
|
:param use_dot_exporter: Whether to export the tree to a dot file.
|
1922
1933
|
:param filename: The name of the file to export the tree to.
|
1923
1934
|
:param only_nodes: A list of nodes to include in the dot export.
|
1935
|
+
:param show_in_console: Whether to print the tree to the console.
|
1924
1936
|
"""
|
1925
1937
|
if not root:
|
1926
|
-
|
1938
|
+
logger.warning("No rules to render")
|
1927
1939
|
return
|
1928
|
-
|
1929
|
-
|
1940
|
+
if show_in_console:
|
1941
|
+
for pre, _, node in RenderTree(root):
|
1942
|
+
if only_nodes is not None and node not in only_nodes:
|
1943
|
+
continue
|
1944
|
+
print(f"{pre}{node.weight if hasattr(node, 'weight') and node.weight else ''} {node.__str__()}")
|
1930
1945
|
if use_dot_exporter:
|
1931
1946
|
unique_node_names = get_unique_node_names_func(root)
|
1932
1947
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.6.
|
3
|
+
Version: 0.6.25
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -0,0 +1,24 @@
|
|
1
|
+
ripple_down_rules/__init__.py,sha256=hhuChAaI4pksG6FgW6-mLaJhifqkw_2Dy21ccjHfIFs,100
|
2
|
+
ripple_down_rules/experts.py,sha256=MYK1-vuvU1Stp82YZa8TcwOzvriIiYb0WrPFpWUNnXc,13005
|
3
|
+
ripple_down_rules/helpers.py,sha256=X1psHOqrb4_xYN4ssQNB8S9aRKKsqgihAyWJurN0dqk,5499
|
4
|
+
ripple_down_rules/rdr.py,sha256=avufEkijvJQCji1S5O1zaDSS7aIoP7Yjefy_4uK1TII,62125
|
5
|
+
ripple_down_rules/rdr_decorators.py,sha256=TRhbaB_ZIXN0n8Up19NI43_mMjmTm24qo8axAAOzbxM,11728
|
6
|
+
ripple_down_rules/rules.py,sha256=N4dEx-xyqxGZpoEYzRd9P5u97_DcDEVLY_UiNhZ4E7g,28726
|
7
|
+
ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
|
8
|
+
ripple_down_rules/utils.py,sha256=CD-J33gKHUSGU-hZdVvlGtCvsuID7b1pirv5ejpkEP0,73891
|
9
|
+
ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
|
10
|
+
ripple_down_rules/datastructures/callable_expression.py,sha256=P3o-z54Jt4rtIczeFWiuHFTNqMzYEOm94OyOP535D6Q,13378
|
11
|
+
ripple_down_rules/datastructures/case.py,sha256=PJ7_-AdxYic6BO5z816piFODj6nU5J6Jt1YzTFH-dds,15510
|
12
|
+
ripple_down_rules/datastructures/dataclasses.py,sha256=kI3Kv8GiVR8igMgA_BlKN6djUYxC2mLecvyh19pqQQA,10998
|
13
|
+
ripple_down_rules/datastructures/enums.py,sha256=R9AkhMKTDErOSZ8J6gEdh2lQ0Bjsxs22eMBtCPrXosI,5804
|
14
|
+
ripple_down_rules/user_interface/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
|
+
ripple_down_rules/user_interface/gui.py,sha256=druufu9cVeVUajPW-RqGW3ZiEbgdgNBQD2CLhadQo18,27486
|
16
|
+
ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=yp-F8YRWGhj1PLB33HE6vJkdYWFN5Zn2244S2DUWRTM,6576
|
17
|
+
ripple_down_rules/user_interface/object_diagram.py,sha256=FEa2HaYR9QmTE6NsOwBvZ0jqmu3DKyg6mig2VE5ZP4Y,4956
|
18
|
+
ripple_down_rules/user_interface/prompt.py,sha256=e5FzVfiIagwKTK3WCKsHvWaWZ4kb8FP8X-SgieTln6E,9156
|
19
|
+
ripple_down_rules/user_interface/template_file_creator.py,sha256=kwBbFLyN6Yx2NTIHPSwOoytWgbJDYhgrUOVFw_jkDQ4,13522
|
20
|
+
ripple_down_rules-0.6.25.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
21
|
+
ripple_down_rules-0.6.25.dist-info/METADATA,sha256=OCHGDEQ_jkpOq8iUIZONmvsMGXHPdsfzfuTSZOJUb_k,48294
|
22
|
+
ripple_down_rules-0.6.25.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
23
|
+
ripple_down_rules-0.6.25.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
24
|
+
ripple_down_rules-0.6.25.dist-info/RECORD,,
|
@@ -1,24 +0,0 @@
|
|
1
|
-
ripple_down_rules/__init__.py,sha256=Y8dJbmLSU5jLm9E0jiReo5eIcp03jmLvx6sNfRfi49M,100
|
2
|
-
ripple_down_rules/experts.py,sha256=4-dMIVeMzFXCLYl_XBG_P7_Xs4sZih9-vZxCIPri6dA,12958
|
3
|
-
ripple_down_rules/helpers.py,sha256=RUdfiSWMZjGwCxuCy44TcEJf2UNAFlPJusgHzuAs6qI,4583
|
4
|
-
ripple_down_rules/rdr.py,sha256=sibe0amvb2MuVFRGjADgel6a-K0bp8fpJQRLnRuYW-k,57803
|
5
|
-
ripple_down_rules/rdr_decorators.py,sha256=0xVM-yRZJ6BrpyMfQNBuFUKT_-me6tbd4UTgG5Exx2g,11809
|
6
|
-
ripple_down_rules/rules.py,sha256=2d2Fgo5urVvUoIJWdRyaojcFBj_lfhLZaBVrXwb5KLA,23764
|
7
|
-
ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
|
8
|
-
ripple_down_rules/utils.py,sha256=VM8vrshOJIqXAY48dghu86KSGpYQoGGDMEiq3GrLXsQ,73319
|
9
|
-
ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
|
10
|
-
ripple_down_rules/datastructures/callable_expression.py,sha256=ysK-4JmZ4oSUTJC7zpo_o77g4ONxPDEcIpSWggsnx3c,13320
|
11
|
-
ripple_down_rules/datastructures/case.py,sha256=PJ7_-AdxYic6BO5z816piFODj6nU5J6Jt1YzTFH-dds,15510
|
12
|
-
ripple_down_rules/datastructures/dataclasses.py,sha256=D-nrVEW_27njmDGkyiHRnq5lmqEdO8RHKnLb1mdnwrA,11054
|
13
|
-
ripple_down_rules/datastructures/enums.py,sha256=ce7tqS0otfSTNAOwsnXlhsvIn4iW_Y_N3TNebF3YoZs,5700
|
14
|
-
ripple_down_rules/user_interface/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
|
-
ripple_down_rules/user_interface/gui.py,sha256=druufu9cVeVUajPW-RqGW3ZiEbgdgNBQD2CLhadQo18,27486
|
16
|
-
ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=yp-F8YRWGhj1PLB33HE6vJkdYWFN5Zn2244S2DUWRTM,6576
|
17
|
-
ripple_down_rules/user_interface/object_diagram.py,sha256=FEa2HaYR9QmTE6NsOwBvZ0jqmu3DKyg6mig2VE5ZP4Y,4956
|
18
|
-
ripple_down_rules/user_interface/prompt.py,sha256=JceEUGYsd0lIvd-v2y3D3swoo96_C0lxfp3CxM7Vfts,8900
|
19
|
-
ripple_down_rules/user_interface/template_file_creator.py,sha256=kwBbFLyN6Yx2NTIHPSwOoytWgbJDYhgrUOVFw_jkDQ4,13522
|
20
|
-
ripple_down_rules-0.6.23.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
21
|
-
ripple_down_rules-0.6.23.dist-info/METADATA,sha256=5B9P3h5EyJEZvyuj7CtFVNeJxQOg8kjG7pQJfoVlAqs,48294
|
22
|
-
ripple_down_rules-0.6.23.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
23
|
-
ripple_down_rules-0.6.23.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
24
|
-
ripple_down_rules-0.6.23.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|