ripple-down-rules 0.6.0__py3-none-any.whl → 0.6.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +21 -1
- ripple_down_rules/datastructures/callable_expression.py +24 -7
- ripple_down_rules/datastructures/case.py +12 -11
- ripple_down_rules/datastructures/dataclasses.py +135 -14
- ripple_down_rules/datastructures/enums.py +29 -86
- ripple_down_rules/datastructures/field_info.py +177 -0
- ripple_down_rules/datastructures/tracked_object.py +208 -0
- ripple_down_rules/experts.py +141 -50
- ripple_down_rules/failures.py +4 -0
- ripple_down_rules/helpers.py +75 -8
- ripple_down_rules/predicates.py +97 -0
- ripple_down_rules/rdr.py +712 -96
- ripple_down_rules/rdr_decorators.py +164 -112
- ripple_down_rules/rules.py +351 -114
- ripple_down_rules/user_interface/gui.py +66 -41
- ripple_down_rules/user_interface/ipython_custom_shell.py +46 -9
- ripple_down_rules/user_interface/prompt.py +80 -60
- ripple_down_rules/user_interface/template_file_creator.py +13 -8
- ripple_down_rules/utils.py +537 -53
- {ripple_down_rules-0.6.0.dist-info → ripple_down_rules-0.6.6.dist-info}/METADATA +4 -1
- ripple_down_rules-0.6.6.dist-info/RECORD +28 -0
- ripple_down_rules-0.6.0.dist-info/RECORD +0 -24
- {ripple_down_rules-0.6.0.dist-info → ripple_down_rules-0.6.6.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.6.0.dist-info → ripple_down_rules-0.6.6.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.6.0.dist-info → ripple_down_rules-0.6.6.dist-info}/top_level.txt +0 -0
ripple_down_rules/experts.py
CHANGED
@@ -1,19 +1,23 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import ast
|
4
|
+
import importlib
|
4
5
|
import json
|
5
|
-
import logging
|
6
6
|
import os
|
7
|
-
import
|
7
|
+
import sys
|
8
8
|
from abc import ABC, abstractmethod
|
9
|
+
from dataclasses import is_dataclass
|
10
|
+
from textwrap import dedent, indent
|
11
|
+
from typing import Tuple, Dict
|
9
12
|
|
10
13
|
from typing_extensions import Optional, TYPE_CHECKING, List
|
11
14
|
|
12
15
|
from .datastructures.callable_expression import CallableExpression
|
13
|
-
from .datastructures.enums import PromptFor
|
14
|
-
from .datastructures.dataclasses import CaseQuery
|
15
16
|
from .datastructures.case import show_current_and_corner_cases
|
16
|
-
from .
|
17
|
+
from .datastructures.dataclasses import CaseQuery
|
18
|
+
from .datastructures.enums import PromptFor
|
19
|
+
from .user_interface.template_file_creator import TemplateFileCreator
|
20
|
+
from .utils import extract_imports, extract_function_source, get_imports_from_scope, get_class_file_path
|
17
21
|
|
18
22
|
try:
|
19
23
|
from .user_interface.gui import RDRCaseViewer
|
@@ -41,19 +45,18 @@ class Expert(ABC):
|
|
41
45
|
A flag to indicate if the expert should use loaded answers or not.
|
42
46
|
"""
|
43
47
|
|
44
|
-
def __init__(self, use_loaded_answers: bool =
|
48
|
+
def __init__(self, use_loaded_answers: bool = False,
|
45
49
|
append: bool = False,
|
46
50
|
answers_save_path: Optional[str] = None):
|
47
51
|
self.all_expert_answers = []
|
48
52
|
self.use_loaded_answers = use_loaded_answers
|
49
|
-
self.append = append
|
50
53
|
self.answers_save_path = answers_save_path
|
51
|
-
if answers_save_path is not None:
|
54
|
+
if answers_save_path is not None and os.path.exists(answers_save_path + '.py'):
|
52
55
|
if use_loaded_answers:
|
53
56
|
self.load_answers(answers_save_path)
|
54
|
-
|
57
|
+
if not append:
|
55
58
|
os.remove(answers_save_path + '.py')
|
56
|
-
|
59
|
+
self.append = True
|
57
60
|
|
58
61
|
@abstractmethod
|
59
62
|
def ask_for_conditions(self, case_query: CaseQuery, last_evaluated_rule: Optional[Rule] = None) \
|
@@ -86,31 +89,29 @@ class Expert(ABC):
|
|
86
89
|
"""
|
87
90
|
if path is None and self.answers_save_path is None:
|
88
91
|
raise ValueError("No path provided to clear expert answers, either provide a path or set the "
|
89
|
-
|
92
|
+
"answers_save_path attribute.")
|
90
93
|
if path is None:
|
91
94
|
path = self.answers_save_path
|
92
|
-
if os.path.exists(path + '.json'):
|
93
|
-
os.remove(path + '.json')
|
94
95
|
if os.path.exists(path + '.py'):
|
95
96
|
os.remove(path + '.py')
|
96
97
|
self.all_expert_answers = []
|
97
98
|
|
98
|
-
def save_answers(self, path: Optional[str] = None):
|
99
|
+
def save_answers(self, path: Optional[str] = None, expert_answers: Optional[List[Tuple[Dict, str]]] = None):
|
99
100
|
"""
|
100
101
|
Save the expert answers to a file.
|
101
102
|
|
102
103
|
:param path: The path to save the answers to.
|
104
|
+
:param expert_answers: The expert answers to save.
|
103
105
|
"""
|
106
|
+
expert_answers = expert_answers if expert_answers else self.all_expert_answers
|
107
|
+
if not any(expert_answers):
|
108
|
+
return
|
104
109
|
if path is None and self.answers_save_path is None:
|
105
110
|
raise ValueError("No path provided to save expert answers, either provide a path or set the "
|
106
|
-
|
111
|
+
"answers_save_path attribute.")
|
107
112
|
if path is None:
|
108
113
|
path = self.answers_save_path
|
109
|
-
|
110
|
-
if is_json:
|
111
|
-
self._save_to_json(path)
|
112
|
-
else:
|
113
|
-
self._save_to_python(path)
|
114
|
+
self._save_to_python(path, expert_answers=expert_answers)
|
114
115
|
|
115
116
|
def _save_to_json(self, path: str):
|
116
117
|
"""
|
@@ -127,12 +128,14 @@ class Expert(ABC):
|
|
127
128
|
with open(path + '.json', "w") as f:
|
128
129
|
json.dump(all_answers, f)
|
129
130
|
|
130
|
-
def _save_to_python(self, path: str):
|
131
|
+
def _save_to_python(self, path: str, expert_answers: Optional[List[Tuple[Dict, str]]] = None):
|
131
132
|
"""
|
132
133
|
Save the expert answers to a Python file.
|
133
134
|
|
134
135
|
:param path: The path to save the answers to.
|
136
|
+
:param expert_answers: The expert answers to save.
|
135
137
|
"""
|
138
|
+
expert_answers = expert_answers if expert_answers else self.all_expert_answers
|
136
139
|
dir_name = os.path.dirname(path)
|
137
140
|
if not os.path.exists(dir_name + '/__init__.py'):
|
138
141
|
os.makedirs(dir_name, exist_ok=True)
|
@@ -145,18 +148,13 @@ class Expert(ABC):
|
|
145
148
|
current_file_data = f.read()
|
146
149
|
action = 'a' if self.append and current_file_data is not None else 'w'
|
147
150
|
with open(path + '.py', action) as f:
|
148
|
-
for scope, func_source in
|
151
|
+
for scope, func_source in expert_answers:
|
149
152
|
if len(scope) > 0:
|
150
153
|
imports = '\n'.join(get_imports_from_scope(scope)) + '\n\n\n'
|
151
154
|
else:
|
152
155
|
imports = ''
|
153
|
-
if func_source is
|
154
|
-
uid = uuid.uuid4().hex
|
155
|
-
func_source = encapsulate_user_input(func_source, CallableExpression.get_encapsulating_function(f'_{uid}'))
|
156
|
-
else:
|
156
|
+
if func_source is None:
|
157
157
|
func_source = 'pass # No user input provided for this case.\n'
|
158
|
-
if current_file_data is not None and func_source[1:] in current_file_data:
|
159
|
-
continue
|
160
158
|
f.write(imports + func_source + '\n' + '\n\n\n\'===New Answer===\'\n\n\n')
|
161
159
|
|
162
160
|
def load_answers(self, path: Optional[str] = None):
|
@@ -167,14 +165,13 @@ class Expert(ABC):
|
|
167
165
|
"""
|
168
166
|
if path is None and self.answers_save_path is None:
|
169
167
|
raise ValueError("No path provided to load expert answers from, either provide a path or set the "
|
170
|
-
|
168
|
+
"answers_save_path attribute.")
|
171
169
|
if path is None:
|
172
170
|
path = self.answers_save_path
|
173
|
-
|
174
|
-
if is_json:
|
175
|
-
self._load_answers_from_json(path)
|
176
|
-
elif os.path.exists(path + '.py'):
|
171
|
+
if os.path.exists(path + '.py'):
|
177
172
|
self._load_answers_from_python(path)
|
173
|
+
elif os.path.exists(path + '.json'):
|
174
|
+
self._load_answers_from_json(path)
|
178
175
|
|
179
176
|
def _load_answers_from_json(self, path: str):
|
180
177
|
"""
|
@@ -195,40 +192,114 @@ class Expert(ABC):
|
|
195
192
|
file_path = path + '.py'
|
196
193
|
with open(file_path, "r") as f:
|
197
194
|
all_answers = f.read().split('\n\n\n\'===New Answer===\'\n\n\n')[:-1]
|
198
|
-
all_function_sources =
|
199
|
-
all_function_sources_names = list(extract_function_source(file_path, []).keys())
|
195
|
+
all_function_sources = extract_function_source(file_path, [], as_list=True)
|
200
196
|
for i, answer in enumerate(all_answers):
|
201
197
|
answer = answer.strip('\n').strip()
|
202
198
|
if 'def ' not in answer and 'pass' in answer:
|
203
199
|
self.all_expert_answers.append(({}, None))
|
204
200
|
continue
|
205
201
|
scope = extract_imports(tree=ast.parse(answer))
|
206
|
-
|
202
|
+
func_name = all_function_sources[i].split('def ')[1].split('(')[0]
|
203
|
+
function_source = all_function_sources[i].replace(func_name,
|
207
204
|
CallableExpression.encapsulating_function_name)
|
208
205
|
self.all_expert_answers.append((scope, function_source))
|
209
206
|
|
210
207
|
|
208
|
+
class AI(Expert):
    """
    The AI Expert class, an expert that uses AI to provide differentiating features and conclusions.

    NOTE(review): no AI backend is called yet; both ``ask_for_*`` methods only print the prompt that would be sent
    to an AI model and then terminate the process with ``sys.exit()``.
    """

    def __init__(self, **kwargs):
        """
        Initialize the AI expert.

        :param kwargs: Forwarded unchanged to :class:`Expert` (``use_loaded_answers``, ``append``,
         ``answers_save_path``).
        """
        super().__init__(**kwargs)
        self.user_prompt = UserPrompt()

    def ask_for_conditions(self, case_query: CaseQuery,
                           last_evaluated_rule: Optional[Rule] = None) \
            -> CallableExpression:
        """
        Print the AI prompt asking for rule conditions for the given case, then exit the process.

        :param case_query: The case query containing the case to classify.
        :param last_evaluated_rule: Not used by this expert; kept for interface compatibility with :class:`Expert`.
        """
        prompt_str = self.get_prompt_for_ai(case_query, PromptFor.Conditions)
        print(prompt_str)
        sys.exit()

    def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
        """
        Print the AI prompt asking for a conclusion, prefixed with the source code of the output type classes,
        then exit the process.

        :param case_query: The case query containing the case to classify.
        """
        prompt_str = self.get_prompt_for_ai(case_query, PromptFor.Conclusion)
        output_type_source = self.get_output_type_class_source(case_query)
        prompt_str = f"\n\n\nOutput type(s) class source:\n{output_type_source}\n\n" + prompt_str
        print(prompt_str)
        sys.exit()

    def get_output_type_class_source(self, case_query: CaseQuery) -> str:
        """
        Get the output type class source for the AI expert.

        For every type in ``case_query.core_attribute_type`` the class source is extracted from its file, preceded
        by the (recursively extracted) sources of the parent classes named in its class signature.

        :param case_query: The case query containing the case to classify.
        :return: The output type class source, one type per section, sections separated by blank lines.
        """

        def get_class_source(cls) -> str:
            cls_source_file = get_class_file_path(cls)
            found_class_source = extract_function_source(cls_source_file, function_names=[cls.__name__],
                                                         is_class=True,
                                                         as_list=True)[0]
            class_signature = found_class_source.split('\n')[0]
            parent_classes = []
            if '(' in class_signature:
                bases_str = class_signature.split('(')[1].split(')')[0]
                cls_namespace = importlib.import_module(cls.__module__).__dict__
                for base_name in (name.strip() for name in bases_str.split(',')):
                    # Skip empty entries (trailing comma, or a base list wrapped onto the next line) and keyword
                    # arguments such as ``metaclass=...``: neither names a parent class.
                    if not base_name or '=' in base_name:
                        continue
                    parent_cls = cls_namespace.get(base_name)
                    # Names that are not plain module-level classes (``object``, ``Generic[T]``, ``module.Base``)
                    # resolve to None here; skip them instead of crashing on get_class_file_path(None).
                    if isinstance(parent_cls, type):
                        parent_classes.append(parent_cls)
            if is_dataclass(cls):
                found_class_source = f"@dataclass\n{found_class_source}"
            return '\n'.join([get_class_source(pcls) for pcls in parent_classes] + [found_class_source])

        return '\n\n\n'.join(get_class_source(output_type) for output_type in case_query.core_attribute_type)

    def get_prompt_for_ai(self, case_query: CaseQuery, prompt_for: PromptFor) -> str:
        """
        Get the prompt for the AI expert.

        The prompt starts with the case name and the case's attribute dictionary, followed by the boilerplate code
        from :class:`TemplateFileCreator`, and is finished by ``UserPrompt.build_prompt_str_for_ai``.

        :param case_query: The case query containing the case to classify.
        :param prompt_for: The type of prompt to get.
        :return: The prompt for the AI expert.
        """
        data_to_show = f"\nCase ({case_query.case_name}):\n {case_query.case.__dict__}"
        template_file_creator = TemplateFileCreator(case_query, prompt_for=prompt_for)
        boilerplate_code = template_file_creator.build_boilerplate_code()
        initial_prompt_str = data_to_show + "\n\n" + boilerplate_code + "\n\n"
        return self.user_prompt.build_prompt_str_for_ai(case_query, prompt_for=prompt_for,
                                                        initial_prompt_str=initial_prompt_str)
|
281
|
+
|
282
|
+
|
211
283
|
class Human(Expert):
|
212
284
|
"""
|
213
285
|
The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
|
214
286
|
"""
|
215
287
|
|
216
|
-
def __init__(self,
|
288
|
+
def __init__(self, **kwargs):
|
217
289
|
"""
|
218
290
|
Initialize the Human expert.
|
219
|
-
|
220
|
-
:param viewer: The RDRCaseViewer instance to use for prompting the user.
|
221
291
|
"""
|
222
292
|
super().__init__(**kwargs)
|
223
|
-
self.user_prompt = UserPrompt(
|
293
|
+
self.user_prompt = UserPrompt()
|
224
294
|
|
225
295
|
def ask_for_conditions(self, case_query: CaseQuery,
|
226
296
|
last_evaluated_rule: Optional[Rule] = None) \
|
227
297
|
-> CallableExpression:
|
228
298
|
data_to_show = None
|
229
299
|
if (not self.use_loaded_answers or len(self.all_expert_answers) == 0) and self.user_prompt.viewer is None:
|
230
|
-
data_to_show = show_current_and_corner_cases(case_query.case,
|
231
|
-
|
300
|
+
data_to_show = show_current_and_corner_cases(case_query.case,
|
301
|
+
{case_query.attribute_name: case_query.target_value},
|
302
|
+
last_evaluated_rule=last_evaluated_rule)
|
232
303
|
return self._get_conditions(case_query, data_to_show)
|
233
304
|
|
234
305
|
def _get_conditions(self, case_query: CaseQuery, data_to_show: Optional[str] = None) \
|
@@ -251,10 +322,13 @@ class Human(Expert):
|
|
251
322
|
if user_input is not None:
|
252
323
|
case_query.scope.update(loaded_scope)
|
253
324
|
condition = CallableExpression(user_input, bool, scope=case_query.scope)
|
325
|
+
if self.answers_save_path is not None and not any(loaded_scope):
|
326
|
+
self.convert_json_answer_to_python_answer(case_query, user_input, condition, PromptFor.Conditions)
|
254
327
|
else:
|
255
|
-
user_input, condition = self.user_prompt.prompt_user_for_expression(case_query, PromptFor.Conditions,
|
256
|
-
|
257
|
-
|
328
|
+
user_input, condition = self.user_prompt.prompt_user_for_expression(case_query, PromptFor.Conditions,
|
329
|
+
prompt_str=data_to_show)
|
330
|
+
if user_input in ['exit', 'quit']:
|
331
|
+
sys.exit()
|
258
332
|
if not self.use_loaded_answers:
|
259
333
|
self.all_expert_answers.append((condition.scope, user_input))
|
260
334
|
if self.answers_save_path is not None:
|
@@ -262,6 +336,19 @@ class Human(Expert):
|
|
262
336
|
case_query.conditions = condition
|
263
337
|
return condition
|
264
338
|
|
339
|
+
def convert_json_answer_to_python_answer(self, case_query: CaseQuery, user_input: str,
                                         callable_expression: CallableExpression,
                                         prompt_for: PromptFor):
    """
    Convert an answer given as a raw expression/function string into the Python answer-file format and save it.

    The answer is inserted into the boilerplate generated by :class:`TemplateFileCreator` in place of its ``pass``
    placeholder, and the resulting code is saved through :meth:`save_answers`.

    :param case_query: The case query the answer was given for (used to build the template).
    :param user_input: The expert answer: either a full ``def`` (its header line is dropped and its body is used)
     or a single expression (inserted as ``return <expression>``).
    :param callable_expression: The expression built from ``user_input``; not used here.
    :param prompt_for: Whether the answer is for conditions or for a conclusion (selects the template).
    """
    import re

    tfc = TemplateFileCreator(case_query, prompt_for=prompt_for)
    code = tfc.build_boilerplate_code()
    if user_input.startswith('def'):
        body = '\n'.join(user_input.split('\n')[1:])
        replacement = indent(dedent(body), " " * 4).strip()
    else:
        replacement = f"return {user_input}"
    # Replace only standalone ``pass`` tokens: a plain str.replace would also rewrite any identifier that contains
    # "pass" (e.g. ``passenger``). A callable replacement keeps backslashes in the user's code literal instead of
    # having re treat them as escape sequences.
    code = re.sub(r'\bpass\b', lambda _match: replacement, code)
    self.save_answers(expert_answers=[({}, code)])
|
351
|
+
|
265
352
|
def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
|
266
353
|
"""
|
267
354
|
Ask the expert to provide a conclusion for the case.
|
@@ -281,20 +368,24 @@ class Human(Expert):
|
|
281
368
|
expression = CallableExpression(expert_input, case_query.attribute_type,
|
282
369
|
scope=case_query.scope,
|
283
370
|
mutually_exclusive=case_query.mutually_exclusive)
|
371
|
+
if self.answers_save_path is not None and not any(loaded_scope):
|
372
|
+
self.convert_json_answer_to_python_answer(case_query, expert_input, expression,
|
373
|
+
PromptFor.Conclusion)
|
284
374
|
except IndexError:
|
285
375
|
self.use_loaded_answers = False
|
286
376
|
if not self.use_loaded_answers:
|
287
377
|
data_to_show = None
|
288
378
|
if self.user_prompt.viewer is None:
|
289
379
|
data_to_show = show_current_and_corner_cases(case_query.case)
|
290
|
-
expert_input, expression = self.user_prompt.prompt_user_for_expression(case_query, PromptFor.Conclusion,
|
380
|
+
expert_input, expression = self.user_prompt.prompt_user_for_expression(case_query, PromptFor.Conclusion,
|
381
|
+
prompt_str=data_to_show)
|
291
382
|
if expert_input is None:
|
292
383
|
self.all_expert_answers.append(({}, None))
|
293
|
-
elif expert_input
|
384
|
+
elif expert_input not in ['exit', 'quit']:
|
294
385
|
self.all_expert_answers.append((expression.scope, expert_input))
|
295
|
-
if self.answers_save_path is not None and expert_input
|
386
|
+
if self.answers_save_path is not None and expert_input not in ['exit', 'quit']:
|
296
387
|
self.save_answers()
|
297
|
-
if expert_input
|
298
|
-
exit()
|
388
|
+
if expert_input in ['exit', 'quit']:
|
389
|
+
sys.exit()
|
299
390
|
case_query.target = expression
|
300
391
|
return expression
|
ripple_down_rules/helpers.py
CHANGED
@@ -1,16 +1,18 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import importlib
|
3
4
|
import os
|
5
|
+
import sys
|
6
|
+
from functools import wraps
|
4
7
|
from types import ModuleType
|
8
|
+
from typing import Tuple, Callable, Dict, Any, Optional
|
5
9
|
|
6
|
-
from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
|
7
|
-
|
8
|
-
from .datastructures.case import create_case
|
9
|
-
from .datastructures.dataclasses import CaseQuery
|
10
10
|
from typing_extensions import Type, Optional, Callable, Any, Dict, TYPE_CHECKING, Union
|
11
11
|
|
12
|
+
from .datastructures.case import create_case, Case
|
13
|
+
from .datastructures.dataclasses import CaseQuery
|
14
|
+
from .utils import calculate_precision_and_recall, get_method_args_as_dict, get_func_rdr_model_name
|
12
15
|
from .utils import get_func_rdr_model_name, copy_case, make_set, update_case
|
13
|
-
from .utils import calculate_precision_and_recall
|
14
16
|
|
15
17
|
if TYPE_CHECKING:
|
16
18
|
from .rdr import RippleDownRules
|
@@ -55,12 +57,14 @@ def general_rdr_classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDow
|
|
55
57
|
if attribute_name in new_conclusions:
|
56
58
|
temp_case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, rdr.mutually_exclusive)
|
57
59
|
update_case(temp_case_query, new_conclusions)
|
58
|
-
if len(new_conclusions) == 0
|
60
|
+
if len(new_conclusions) == 0 or len(classifiers_dict) == 1 and list(classifiers_dict.values())[
|
61
|
+
0].mutually_exclusive:
|
59
62
|
break
|
60
63
|
return conclusions
|
61
64
|
|
62
65
|
|
63
|
-
def is_matching(classifier: Callable[[Any], Any], case_query: CaseQuery,
|
66
|
+
def is_matching(classifier: Callable[[Any], Any], case_query: CaseQuery,
|
67
|
+
pred_cat: Optional[Dict[str, Any]] = None) -> bool:
|
64
68
|
"""
|
65
69
|
:param classifier: The RDR classifier to check the prediction of.
|
66
70
|
:param case_query: The case query to check.
|
@@ -89,9 +93,72 @@ def load_or_create_func_rdr_model(func, model_dir: str, rdr_type: Type[RippleDow
|
|
89
93
|
:param rdr_kwargs: Additional arguments to pass to the RDR constructor in the case of a new model.
|
90
94
|
"""
|
91
95
|
model_name = get_func_rdr_model_name(func)
|
92
|
-
model_path = os.path.join(model_dir, model_name,
|
96
|
+
model_path = os.path.join(model_dir, model_name, f"{model_name}.py")
|
93
97
|
if os.path.exists(model_path):
|
94
98
|
rdr = rdr_type.load(load_dir=model_dir, model_name=model_name)
|
95
99
|
else:
|
96
100
|
rdr = rdr_type(**rdr_kwargs)
|
97
101
|
return rdr
|
102
|
+
|
103
|
+
|
104
|
+
def get_an_updated_case_copy(case: Case, conclusion: Callable, attribute_name: str, conclusion_type: Tuple[Type, ...],
                             mutually_exclusive: bool) -> Case:
    """
    Apply ``conclusion`` to a copy of ``case`` (made with ``copy_case``) and return that updated copy.

    :param case: The case to copy and update.
    :param conclusion: A callable evaluated on the copy; it may return a dict of attribute updates or a single
     value for ``attribute_name``.
    :param attribute_name: The name of the attribute to update.
    :param conclusion_type: The type of the conclusion to update.
    :param mutually_exclusive: Whether the rule belongs to a mutually exclusive RDR.
    :return: A copy of the case updated with the given conclusion.
    """
    updated_case = copy_case(case)
    case_query = CaseQuery(updated_case, attribute_name, conclusion_type,
                           mutually_exclusive=mutually_exclusive)
    conclusion_output = conclusion(updated_case)
    # A non-dict result is the value for ``attribute_name`` itself, so key it by that name.
    if isinstance(conclusion_output, dict):
        conclusions = conclusion_output
    else:
        conclusions = {attribute_name: conclusion_output}
    update_case(case_query, conclusions)
    return updated_case
|
122
|
+
|
123
|
+
def enable_gui() -> Optional[Any]:
    """
    Enable the GUI for Ripple Down Rules if available.

    Imports and instantiates ``RDRCaseViewer`` from ``user_interface.gui``. If that raises ``ImportError``
    (e.g. the optional GUI dependencies are missing), the GUI stays disabled and nothing is raised.

    :return: The created ``RDRCaseViewer`` instance, or ``None`` when the GUI is unavailable. Returning it lets
     callers keep a reference to the viewer instead of it being created and immediately discarded.
    """
    try:
        from .user_interface.gui import RDRCaseViewer
        viewer = RDRCaseViewer()
    except ImportError:
        # The GUI is optional; run without a viewer.
        return None
    return viewer
|
132
|
+
|
133
|
+
|
134
|
+
def create_case_from_method(func: Callable,
                            func_output: Dict[str, Any],
                            *args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
    """
    Build a :class:`Case` describing a single call of ``func``.

    The case attributes are the call arguments (as returned by ``get_method_args_as_dict``) merged with the entries
    of ``func_output``; when a name appears in both, the ``func_output`` value wins.

    :param func: The function to create a case from.
    :param func_output: A dictionary containing the output of the function, where the key is the output name.
    :param args: The positional arguments of the function.
    :param kwargs: The keyword arguments of the function.
    :return: A ``(case, case_dict)`` pair, where ``case_dict`` is the merged attribute dictionary of the case.
    """
    attributes = get_method_args_as_dict(func, *args, **kwargs)
    attributes.update(func_output)  # outputs override same-named arguments
    case = Case(dict, id(attributes), get_func_rdr_model_name(func), attributes, **attributes)
    return case, attributes
|
150
|
+
|
151
|
+
|
152
|
+
class MockRDRDecorator:
    """
    A lightweight stand-in for the RDR decorator: each call of the decorated function is classified by an
    already-generated RDR model module located under ``models_dir``.
    """

    def __init__(self, models_dir: str):
        """
        :param models_dir: The directory that contains the generated RDR model directories.
        """
        self.models_dir = models_dir

    def decorator(self, func: Callable) -> Callable:
        """
        Decorate ``func`` so that each call runs it and returns ``classify(case)`` from the model module
        ``<models_dir>/<model_dir>/<model_name>_rdr.py``, where ``case`` is built from the call arguments and the
        function output (stored under ``"output_"``).

        :param func: The function to decorate.
        :return: The wrapped function.
        """
        # Loaded model modules, keyed by file path, so the file is executed only once per decorated function.
        loaded_modules: Dict[str, ModuleType] = {}

        @wraps(func)
        def wrapper(*args, **kwargs) -> Optional[Any]:
            model_dir = get_func_rdr_model_name(func, include_file_name=True)
            model_name = get_func_rdr_model_name(func, include_file_name=False)
            model_path = os.path.join(self.models_dir, model_dir, f"{model_name}_rdr.py")
            if model_path not in loaded_modules:
                loaded_modules[model_path] = self._load_module_from_path(model_path, f"{model_name}_rdr")
            rdr = loaded_modules[model_path]
            func_output = {"output_": func(*args, **kwargs)}
            case, _ = create_case_from_method(func, func_output, *args, **kwargs)
            return rdr.classify(case)

        return wrapper

    @staticmethod
    def _load_module_from_path(file_path: str, module_name: str) -> ModuleType:
        """
        Import a Python source file given by its path and register it in ``sys.modules`` under ``module_name``.

        ``importlib.import_module`` expects a dotted module name, not a file path, so it cannot load
        ``.../<model_name>_rdr.py``; an explicit spec is used instead.

        :param file_path: Path of the ``.py`` file to load.
        :param module_name: Name to give the loaded module.
        :return: The executed module.
        :raises ImportError: If no loader can be created for ``file_path``.
        """
        import importlib.util
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load RDR model module from {file_path}")
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        return module
|
@@ -0,0 +1,97 @@
|
|
1
|
+
import os.path
|
2
|
+
from abc import ABC, abstractmethod
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from enum import Enum
|
5
|
+
from os.path import dirname
|
6
|
+
|
7
|
+
from typing_extensions import Type, ClassVar, TYPE_CHECKING, Tuple, Optional, Callable
|
8
|
+
|
9
|
+
from .datastructures.tracked_object import TrackedObjectMixin, Direction, Relation
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from .rdr_decorators import RDRDecorator
|
13
|
+
|
14
|
+
|
15
|
+
@dataclass(eq=False)
class Predicate(TrackedObjectMixin, ABC):
    """
    Base class for predicates over tracked object types.

    Instances are callable: ``predicate(*args, **kwargs)`` forwards to the class-level :meth:`evaluate`.
    Equality and hashing are by predicate class, so all instances of the same predicate class compare equal.
    Subclasses that are dataclasses need ``eq=False`` as well; otherwise ``@dataclass`` regenerates ``__eq__``
    and sets ``__hash__`` to ``None``.
    """

    def __call__(self, *args, **kwargs):
        return self.evaluate(*args, **kwargs)

    @classmethod
    @abstractmethod
    def evaluate(cls, *args, **kwargs):
        """
        Evaluate the predicate with the given arguments.
        This method should be implemented by subclasses.
        """
        pass

    @classmethod
    def rdr_decorator(cls, output_types: Tuple[Type, ...], mutually_exclusive: bool,
                      package_name: Optional[str] = None) -> Callable[[Callable], Callable]:
        """
        Returns the RDRDecorator to decorate the predicate evaluate method with.

        :param output_types: Forwarded to ``RDRDecorator``.
        :param mutually_exclusive: Forwarded to ``RDRDecorator``.
        :param package_name: Forwarded to ``RDRDecorator``.
        :return: The ``decorator`` of a new ``RDRDecorator`` built from ``cls.models_dir`` and the given arguments.
        """
        # RDRDecorator is imported only under TYPE_CHECKING at module level (presumably to avoid an import cycle
        # with rdr_decorators), so it is not bound at runtime; import it here, where it is actually needed.
        from .rdr_decorators import RDRDecorator
        rdr_decorator = RDRDecorator(cls.models_dir, output_types, mutually_exclusive,
                                     package_name=package_name)
        return rdr_decorator.decorator

    def __hash__(self):
        return hash(self.__class__.__name__)

    def __eq__(self, other):
        if not isinstance(other, Predicate):
            return False
        return self.__class__ == other.__class__
|
47
|
+
|
48
|
+
|
49
|
+
@dataclass(eq=False)
class IsA(Predicate):
    """
    A predicate that checks if an object type is a subclass of another object type.
    """
    # eq=False: a plain @dataclass would generate its own __eq__ and set __hash__ to None, replacing Predicate's
    # class-based equality/hashing and making instances such as ``isA`` unhashable.

    @classmethod
    def evaluate(cls, child_type: Type[TrackedObjectMixin], parent_type: Type[TrackedObjectMixin]) -> bool:
        """
        :param child_type: The type to test.
        :param parent_type: The candidate parent type.
        :return: ``issubclass(child_type, parent_type)``.
        """
        return issubclass(child_type, parent_type)
|
58
|
+
|
59
|
+
# Module-level instance so the predicate can be used like a function: ``isA(child_type, parent_type)``.
isA = IsA()
|
60
|
+
|
61
|
+
|
62
|
+
@dataclass(eq=False)
class Has(Predicate):
    """
    A predicate that checks if an object type has a certain member object type.
    """
    # eq=False: a plain @dataclass would generate its own __eq__ and set __hash__ to None, replacing Predicate's
    # class-based equality/hashing and making instances such as ``has`` unhashable.

    @classmethod
    def evaluate(cls, owner_type: Type[TrackedObjectMixin],
                 member_type: Type[TrackedObjectMixin], recursive: bool = False) -> bool:
        """
        Check, using the dependency graph, whether ``owner_type`` has a member of type ``member_type``.

        Outbound edges of ``owner_type`` count as a match when either:

        - a ``Relation.has`` edge leads to a type ``t`` with ``isA(t, member_type)``, or
        - a ``Relation.isA`` edge leads to a type that itself has such a member (i.e. the member is inherited).

        With ``recursive=True``, members of members (following ``Relation.has`` edges) are searched too. This also
        applies to inherited members, because ``recursive`` is forwarded along ``Relation.isA`` edges.

        :param owner_type: The type whose members are searched.
        :param member_type: The member type to look for (subclasses match via ``isA``).
        :param recursive: Whether to also search the members of members.
        :return: True if a matching member was found.
        """
        graph = cls._dependency_graph
        neighbors = graph.adj_direction(owner_type._my_graph_idx(), Direction.OUTBOUND.value)
        curr_val = any(e == Relation.has and isA(graph.get_node_data(n), member_type)
                       # Forward ``recursive`` through inheritance so an inherited member's members are searched too.
                       or e == Relation.isA and cls.evaluate(graph.get_node_data(n), member_type,
                                                             recursive=recursive)
                       for n, e in neighbors.items())
        if recursive:
            return curr_val or any((e == Relation.has
                                    and cls.evaluate(graph.get_node_data(n), member_type, recursive=True))
                                   for n, e in neighbors.items())
        return curr_val
|
81
|
+
|
82
|
+
# Module-level instance so the predicate can be used like a function:
# ``has(owner_type, member_type, recursive=False)``.
has = Has()
|
83
|
+
|
84
|
+
|
85
|
+
@dataclass(eq=False)
class DependsOn(Predicate):
    """
    A predicate that checks if an object type depends on another object type.

    ``evaluate`` is a placeholder here and always raises ``NotImplementedError``; its message says it is to be
    overridden "in rdr meta" — NOTE(review): confirm where that override lives.
    """
    # eq=False: a plain @dataclass would generate its own __eq__ and set __hash__ to None, replacing Predicate's
    # class-based equality/hashing and making instances such as ``dependsOn`` unhashable.

    @classmethod
    def evaluate(cls, dependent: Type[TrackedObjectMixin],
                 dependency: Type[TrackedObjectMixin], recursive: bool = False) -> bool:
        raise NotImplementedError("Should be overridden in rdr meta")
|
95
|
+
|
96
|
+
|
97
|
+
# Module-level instance so the predicate can be used like a function: ``dependsOn(dependent, dependency)``.
# As defined in this module, calling it raises NotImplementedError (see DependsOn.evaluate).
dependsOn = DependsOn()
|