ripple-down-rules 0.4.88__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +1 -1
- ripple_down_rules/datastructures/callable_expression.py +20 -1
- ripple_down_rules/datastructures/dataclasses.py +9 -1
- ripple_down_rules/experts.py +189 -32
- ripple_down_rules/rdr.py +162 -76
- ripple_down_rules/rdr_decorators.py +73 -52
- ripple_down_rules/rules.py +5 -4
- ripple_down_rules/user_interface/template_file_creator.py +6 -6
- ripple_down_rules/utils.py +3 -6
- {ripple_down_rules-0.4.88.dist-info → ripple_down_rules-0.5.0.dist-info}/METADATA +1 -1
- ripple_down_rules-0.5.0.dist-info/RECORD +26 -0
- ripple_down_rules-0.4.88.dist-info/RECORD +0 -26
- {ripple_down_rules-0.4.88.dist-info → ripple_down_rules-0.5.0.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.4.88.dist-info → ripple_down_rules-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.4.88.dist-info → ripple_down_rules-0.5.0.dist-info}/top_level.txt +0 -0
ripple_down_rules/datastructures/callable_expression.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import ast
|
4
4
|
import logging
|
5
|
+
import os
|
5
6
|
from _ast import AST
|
6
7
|
from enum import Enum
|
7
8
|
|
@@ -10,7 +11,7 @@ from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set
|
|
10
11
|
from .case import create_case, Case
|
11
12
|
from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string, conclusion_to_json, is_iterable, \
|
12
13
|
build_user_input_from_conclusion, encapsulate_user_input, extract_function_source, are_results_subclass_of_types, \
|
13
|
-
make_list
|
14
|
+
make_list, get_imports_from_scope
|
14
15
|
|
15
16
|
|
16
17
|
class VariableVisitor(ast.NodeVisitor):
|
@@ -175,6 +176,24 @@ class CallableExpression(SubclassJSONSerializer):
|
|
175
176
|
return
|
176
177
|
self.user_input = self.encapsulating_function + '\n' + new_function_body
|
177
178
|
|
179
|
+
def write_to_python_file(self, file_path: str, append: bool = False):
    """
    Persist this callable expression's source, preceded by the import lines its scope needs, to a python file.

    :param file_path: Destination file path.
    :param append: When True and the file already exists, the imports and source are added to the end of it
        (separated from the existing content by blank lines); otherwise the file is created or overwritten.
    """
    import_block = '\n'.join(get_imports_from_scope(self.scope))
    appending = append and os.path.exists(file_path)
    header = ('\n\n\n' + import_block if appending else import_block) + '\n\n\n'
    with open(file_path, 'a' if appending else 'w') as out:
        out.write(header)
        out.write(self.user_input)
|
196
|
+
|
178
197
|
@property
|
179
198
|
def user_input(self):
|
180
199
|
"""
|
ripple_down_rules/datastructures/dataclasses.py
CHANGED
@@ -78,7 +78,15 @@ class CaseQuery:
|
|
78
78
|
"""
|
79
79
|
:return: The type of the case that the attribute belongs to.
|
80
80
|
"""
|
81
|
-
|
81
|
+
if self.is_function:
|
82
|
+
if self.function_args_type_hints is not None:
|
83
|
+
func_args = [arg for name, arg in self.function_args_type_hints.items() if name != 'return']
|
84
|
+
case_type_args = Union[tuple(func_args)]
|
85
|
+
else:
|
86
|
+
case_type_args = Any
|
87
|
+
return Dict[str, case_type_args]
|
88
|
+
else:
|
89
|
+
return self.original_case._obj_type if isinstance(self.original_case, Case) else type(self.original_case)
|
82
90
|
|
83
91
|
@property
|
84
92
|
def case(self) -> Any:
|
ripple_down_rules/experts.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import ast
|
3
4
|
import json
|
4
5
|
import logging
|
6
|
+
import os
|
5
7
|
from abc import ABC, abstractmethod
|
6
8
|
|
7
9
|
from typing_extensions import Optional, TYPE_CHECKING, List
|
@@ -10,6 +12,8 @@ from .datastructures.callable_expression import CallableExpression
|
|
10
12
|
from .datastructures.enums import PromptFor
|
11
13
|
from .datastructures.dataclasses import CaseQuery
|
12
14
|
from .datastructures.case import show_current_and_corner_cases
|
15
|
+
from .utils import extract_imports, extract_function_source, get_imports_from_scope, encapsulate_user_input
|
16
|
+
|
13
17
|
try:
|
14
18
|
from .user_interface.gui import RDRCaseViewer
|
15
19
|
except ImportError as e:
|
@@ -36,10 +40,19 @@ class Expert(ABC):
|
|
36
40
|
A flag to indicate if the expert should use loaded answers or not.
|
37
41
|
"""
|
38
42
|
|
39
|
-
def __init__(self, use_loaded_answers: bool =
|
43
|
+
def __init__(self, use_loaded_answers: bool = True,
             append: bool = False,
             answers_save_path: Optional[str] = None):
    """
    Initialize the expert.

    :param use_loaded_answers: If True, previously saved answers at ``answers_save_path`` are loaded. If False, a
        previously saved python answers file (``<answers_save_path>.py``) is discarded, if there is one.
    :param append: If True, saving answers appends to an existing answers file instead of overwriting it.
    :param answers_save_path: Base path (without extension) used to load/save answers. If None, nothing is loaded
        or removed here.
    """
    self.all_expert_answers = []
    self.use_loaded_answers = use_loaded_answers
    self.append = append
    self.answers_save_path = answers_save_path
    if answers_save_path is not None:
        if use_loaded_answers:
            self.load_answers(answers_save_path)
        else:
            # The answers file may not exist yet (e.g. on a first run), so only remove it when present instead
            # of failing with FileNotFoundError.
            python_answers_file = answers_save_path + '.py'
            if os.path.exists(python_answers_file):
                os.remove(python_answers_file)
            # Subsequent saves add to the freshly started file instead of overwriting it.
            self.append = True
|
43
56
|
|
44
57
|
@abstractmethod
|
45
58
|
def ask_for_conditions(self, case_query: CaseQuery, last_evaluated_rule: Optional[Rule] = None) \
|
@@ -63,46 +76,138 @@ class Expert(ABC):
|
|
63
76
|
:return: A callable expression that can be called with a new case as an argument.
|
64
77
|
"""
|
65
78
|
|
79
|
+
def clear_answers(self, path: Optional[str] = None):
    """
    Forget all recorded expert answers and delete their saved copies.

    :param path: Base path (without extension) whose ``.json`` and ``.py`` answer files are deleted when present.
        Defaults to ``self.answers_save_path`` when omitted.
    :raises ValueError: If neither ``path`` nor ``self.answers_save_path`` is set.
    """
    target = self.answers_save_path if path is None else path
    if target is None:
        raise ValueError("No path provided to clear expert answers, either provide a path or set the "
                         "answers_save_path attribute.")
    for extension in ('.json', '.py'):
        candidate = target + extension
        if os.path.exists(candidate):
            os.remove(candidate)
    self.all_expert_answers = []
|
71
96
|
|
72
|
-
def
|
97
|
+
def save_answers(self, path: Optional[str] = None):
    """
    Persist the recorded expert answers.

    JSON is used when ``<path>.json`` already exists; otherwise the answers are written as python source to
    ``<path>.py``.

    :param path: Base path (without extension). Defaults to ``self.answers_save_path`` when omitted.
    :raises ValueError: If neither ``path`` nor ``self.answers_save_path`` is set.
    """
    target = self.answers_save_path if path is None else path
    if target is None:
        raise ValueError("No path provided to save expert answers, either provide a path or set the "
                         "answers_save_path attribute.")
    writer = self._save_to_json if os.path.exists(target + '.json') else self._save_to_python
    writer(target)
|
80
113
|
|
81
|
-
def
|
114
|
+
def _save_to_json(self, path: str):
    """
    Save the expert answers to a JSON file at ``<path>.json``.

    Only each answer's source (a string or None) is stored. Recorded answers are ``(scope, source)`` tuples, and the
    scope maps names to imported objects (it is turned into import lines when saving to python), which JSON cannot
    encode. The JSON loader pairs every stored entry with an empty scope again. Entries that are not tuples are
    stored unchanged.

    :param path: Base path (without extension) of the JSON file.
    """
    new_answers = [answer[1] if isinstance(answer, tuple) else answer for answer in self.all_expert_answers]
    all_answers = new_answers
    if self.append and os.path.exists(path + '.json'):
        # read the file and append the new answers
        with open(path + '.json', "r") as f:
            all_answers = json.load(f) + new_answers
    with open(path + '.json', "w") as f:
        json.dump(all_answers, f)
|
97
128
|
|
98
|
-
def
|
129
|
+
def _save_to_python(self, path: str):
    """
    Save the expert answers to a Python file at ``<path>.py``.

    Each recorded ``(scope, source)`` answer is written as three parts: the import lines derived from its scope
    (when the scope is non-empty); the source wrapped with ``CallableExpression.encapsulating_function`` (or a
    ``pass`` placeholder when the source is None); and a ``'===New Answer==='`` separator line. The file is
    overwritten unless ``self.append`` is True. When ``path`` has a directory component, that directory is created
    and given an ``__init__.py`` if it lacks one.

    :param path: Base path (without extension) of the python file.
    """
    dir_name = os.path.dirname(path)
    # A bare file name has no directory part: skip package setup instead of probing '/__init__.py' and calling
    # os.makedirs('') (which raises FileNotFoundError).
    if dir_name:
        init_file = os.path.join(dir_name, '__init__.py')
        if not os.path.exists(init_file):
            os.makedirs(dir_name, exist_ok=True)
            with open(init_file, 'w') as f:
                f.write('# This is an empty init file to make the directory a package.\n')
    action = 'a' if self.append else 'w'
    # NOTE(review): the whole in-memory answer list is written on every call, so in append mode repeated saves
    # write earlier answers again - confirm whether only new answers should be appended.
    with open(path + '.py', action) as f:
        for scope, func_source in self.all_expert_answers:
            if len(scope) > 0:
                imports = '\n'.join(get_imports_from_scope(scope)) + '\n\n\n'
            else:
                imports = ''
            if func_source is not None:
                func_source = encapsulate_user_input(func_source, CallableExpression.encapsulating_function)
            else:
                func_source = 'pass # No user input provided for this case.\n'
            f.write(imports + func_source + '\n' + '\n\n\n\'===New Answer===\'\n\n\n')
|
152
|
+
|
153
|
+
def load_answers(self, path: Optional[str] = None):
    """
    Load previously saved expert answers.

    ``<path>.json`` takes precedence over ``<path>.py``. If neither file exists, nothing is loaded.

    :param path: Base path (without extension). Defaults to ``self.answers_save_path`` when omitted.
    :raises ValueError: If neither ``path`` nor ``self.answers_save_path`` is set.
    """
    target = self.answers_save_path if path is None else path
    if target is None:
        raise ValueError("No path provided to load expert answers from, either provide a path or set the "
                         "answers_save_path attribute.")
    if os.path.exists(target + '.json'):
        self._load_answers_from_json(target)
        return
    if os.path.exists(target + '.py'):
        self._load_answers_from_python(target)
|
169
|
+
|
170
|
+
def _load_answers_from_json(self, path: str):
    """
    Replace the recorded answers with those stored in ``<path>.json``.

    Every stored entry is paired with a fresh, empty scope, giving ``({}, entry)`` tuples.

    :param path: Base path (without extension) of the JSON file.
    """
    with open(path + '.json', "r") as json_file:
        stored = json.load(json_file)
    self.all_expert_answers = list(map(lambda entry: ({}, entry), stored))
|
179
|
+
|
180
|
+
def _load_answers_from_python(self, path: str):
    """
    Load the expert answers from ``<path>.py`` and append them to ``self.all_expert_answers``.

    The file is split on the ``'===New Answer==='`` separator line. Each non-empty segment becomes one
    ``(scope, function_source)`` entry: the scope is built from the segment's import statements and the source is
    the first function defined in that same segment. A segment holding only a ``pass`` placeholder becomes
    ``({}, None)``.

    :param path: Base path (without extension) of the python file.
    """
    file_path = path + '.py'
    with open(file_path, "r") as f:
        all_answers = f.read().split('\n\n\n\'===New Answer===\'\n\n\n')
    for answer in all_answers:
        answer = answer.strip('\n').strip()
        if not answer:
            # Text after the final separator (normally nothing) is not an answer.
            continue
        if 'def ' not in answer and 'pass' in answer:
            # Placeholder written for a case where no user input was provided.
            self.all_expert_answers.append(({}, None))
            continue
        tree = ast.parse(answer)
        scope = extract_imports(tree=tree)
        # Take the function from this segment, not from the whole file; otherwise every answer would receive
        # the same function source.
        func_node = next((node for node in tree.body
                          if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))), None)
        if func_node is not None:
            # NOTE(review): this is the full function source including its `def` line (decorators excluded);
            # confirm it matches what extract_function_source returned with its default arguments.
            func_source = ast.get_source_segment(answer, func_node)
        else:
            # NOTE(review): no function found in this segment; the raw segment text is kept as the answer.
            func_source = answer
        self.all_expert_answers.append((scope, func_source))
|
196
|
+
|
197
|
+
|
198
|
+
class Human(Expert):
|
199
|
+
"""
|
200
|
+
The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
|
201
|
+
"""
|
202
|
+
|
203
|
+
def __init__(self, viewer: Optional[RDRCaseViewer] = None, **kwargs):
    """
    Create a human expert that is prompted for answers through a UserPrompt.

    :param viewer: Optional RDRCaseViewer handed to the UserPrompt used for prompting.
    :param kwargs: Forwarded unchanged to the Expert base-class initializer.
    """
    super().__init__(**kwargs)
    prompt = UserPrompt(viewer)
    self.user_prompt = prompt
|
106
211
|
|
107
212
|
def ask_for_conditions(self, case_query: CaseQuery,
|
108
213
|
last_evaluated_rule: Optional[Rule] = None) \
|
@@ -125,13 +230,18 @@ class Human(Expert):
|
|
125
230
|
if self.use_loaded_answers and len(self.all_expert_answers) == 0 and self.append:
|
126
231
|
self.use_loaded_answers = False
|
127
232
|
if self.use_loaded_answers:
|
128
|
-
|
129
|
-
|
233
|
+
try:
|
234
|
+
loaded_scope, user_input = self.all_expert_answers.pop(0)
|
235
|
+
except IndexError:
|
236
|
+
self.use_loaded_answers = False
|
237
|
+
if user_input is not None:
|
130
238
|
condition = CallableExpression(user_input, bool, scope=case_query.scope)
|
131
239
|
else:
|
132
240
|
user_input, condition = self.user_prompt.prompt_user_for_expression(case_query, PromptFor.Conditions)
|
133
241
|
if not self.use_loaded_answers:
|
134
|
-
self.all_expert_answers.append(user_input)
|
242
|
+
self.all_expert_answers.append((condition.scope, user_input))
|
243
|
+
if self.answers_save_path is not None:
|
244
|
+
self.save_answers()
|
135
245
|
case_query.conditions = condition
|
136
246
|
return condition
|
137
247
|
|
@@ -143,18 +253,65 @@ class Human(Expert):
|
|
143
253
|
:return: The conclusion for the case as a callable expression.
|
144
254
|
"""
|
145
255
|
expression: Optional[CallableExpression] = None
|
256
|
+
expert_input: Optional[str] = None
|
146
257
|
if self.use_loaded_answers and len(self.all_expert_answers) == 0 and self.append:
|
147
258
|
self.use_loaded_answers = False
|
148
259
|
if self.use_loaded_answers:
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
260
|
+
try:
|
261
|
+
loaded_scope, expert_input = self.all_expert_answers.pop(0)
|
262
|
+
if expert_input is not None:
|
263
|
+
expression = CallableExpression(expert_input, case_query.attribute_type,
|
264
|
+
scope=case_query.scope,
|
265
|
+
mutually_exclusive=case_query.mutually_exclusive)
|
266
|
+
except IndexError:
|
267
|
+
self.use_loaded_answers = False
|
268
|
+
if not self.use_loaded_answers:
|
155
269
|
if self.user_prompt.viewer is None:
|
156
270
|
show_current_and_corner_cases(case_query.case)
|
157
271
|
expert_input, expression = self.user_prompt.prompt_user_for_expression(case_query, PromptFor.Conclusion)
|
158
|
-
|
272
|
+
if expression is None:
|
273
|
+
self.all_expert_answers.append(({}, None))
|
274
|
+
else:
|
275
|
+
self.all_expert_answers.append((expression.scope, expert_input))
|
276
|
+
if self.answers_save_path is not None:
|
277
|
+
self.save_answers()
|
159
278
|
case_query.target = expression
|
160
279
|
return expression
|
280
|
+
|
281
|
+
|
282
|
+
class File(Expert):
    """
    The File Expert class, an expert that replays answers previously recorded to a file.
    This is used for testing purposes.
    """

    def __init__(self, filename: str, **kwargs):
        """
        Initialize the File expert and load its answers.

        :param filename: Base path (without extension) of the expert answers file, passed to ``load_answers``.
        :param kwargs: Forwarded to the Expert initializer.
        """
        super().__init__(**kwargs)
        self.filename = filename
        self.load_answers(filename)

    def _pop_next_answer(self) -> Optional[str]:
        """
        Remove and return the source of the next recorded answer (answers are replayed in recorded order).

        :return: The recorded answer source, or None if the answer was recorded without user input.
        :raises IndexError: If every recorded answer has already been consumed.
        """
        if not self.all_expert_answers:
            # Same exception type as popping from an empty list, but with a message naming the answers file.
            raise IndexError(f"No expert answers left to replay from '{self.filename}'.")
        _loaded_scope, answer = self.all_expert_answers.pop(0)
        return answer

    def ask_for_conditions(self, case_query: CaseQuery,
                           last_evaluated_rule: Optional[Rule] = None) -> CallableExpression:
        """
        Answer a conditions query with the next recorded answer.

        :param case_query: The query to answer; its ``conditions`` are set to the returned expression.
        :param last_evaluated_rule: Unused; kept for interface compatibility with other experts.
        :return: The condition expression built from the recorded answer.
        :raises ValueError: If the recorded answer is empty.
        :raises IndexError: If no recorded answers are left.
        """
        user_input = self._pop_next_answer()
        if not user_input:
            raise ValueError("No user input found in the expert answers file.")
        condition = CallableExpression(user_input, bool, scope=case_query.scope)
        case_query.conditions = condition
        return condition

    def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
        """
        Answer a conclusion query with the next recorded answer.

        :param case_query: The query to answer; its ``target`` is set to the returned expression.
        :return: The conclusion expression built from the recorded answer.
        :raises ValueError: If the recorded answer is None.
        :raises IndexError: If no recorded answers are left.
        """
        expert_input = self._pop_next_answer()
        if expert_input is None:
            raise ValueError("No expert input found in the expert answers file.")
        expression = CallableExpression(expert_input, case_query.attribute_type,
                                        scope=case_query.scope,
                                        mutually_exclusive=case_query.mutually_exclusive)
        case_query.target = expression
        return expression
|
ripple_down_rules/rdr.py
CHANGED
@@ -2,6 +2,8 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import copyreg
|
4
4
|
import importlib
|
5
|
+
import os
|
6
|
+
|
5
7
|
from . import logger
|
6
8
|
import sys
|
7
9
|
from abc import ABC, abstractmethod
|
@@ -34,7 +36,8 @@ except ImportError as e:
|
|
34
36
|
RDRCaseViewer = None
|
35
37
|
from .utils import draw_tree, make_set, copy_case, \
|
36
38
|
SubclassJSONSerializer, make_list, get_type_from_string, \
|
37
|
-
is_conflicting, update_case, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name
|
39
|
+
is_conflicting, update_case, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
|
40
|
+
is_iterable, str_to_snake_case
|
38
41
|
|
39
42
|
|
40
43
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -61,17 +64,90 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
61
64
|
"""
|
62
65
|
The type of the case (input) to the RDR classifier.
|
63
66
|
"""
|
67
|
+
case_name: Optional[str] = None
|
68
|
+
"""
|
69
|
+
The name of the case type.
|
70
|
+
"""
|
71
|
+
metadata_folder: str = "rdr_metadata"
|
72
|
+
"""
|
73
|
+
The folder to save the metadata of the RDR classifier.
|
74
|
+
"""
|
75
|
+
model_name: Optional[str] = None
|
76
|
+
"""
|
77
|
+
The name of the model. If None, the model name will be the generated python file name.
|
78
|
+
"""
|
64
79
|
|
65
|
-
def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None
|
80
|
+
def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None,
             save_dir: Optional[str] = None, ask_always: bool = True, model_name: Optional[str] = None):
    """
    Set up the state shared by all RDR classifiers.

    :param start_rule: The root rule of the classifier, or None.
    :param viewer: Optional GUI viewer; when given, this classifier's ``save`` method is registered with it as the
        save function.
    :param save_dir: Default directory used by ``save`` when no directory is passed to it.
    :param ask_always: Whether to always ask the expert (True) or only ask when classification fails (False).
    :param model_name: Name used when saving the model; if None, ``save`` falls back to the generated python file
        name.
    """
    self.ask_always: bool = ask_always
    self.model_name: Optional[str] = model_name
    self.save_dir = save_dir
    self.start_rule = start_rule
    self.fig: Optional[Figure] = None
    self.viewer: Optional[RDRCaseViewer] = viewer
    if self.viewer is None:
        return
    self.viewer.set_save_function(self.save)
|
74
96
|
|
97
|
+
def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None) -> str:
    """
    Save the classifier to ``<save_dir>/<model_name>/``.

    The JSON metadata is written through ``to_json_file`` into the ``metadata_folder`` sub-directory of the model
    directory. The generated python source is written by ``_write_to_python``. If ``save_dir`` has no
    ``__init__.py``, one is created so the directory can be imported as a package.

    :param save_dir: The directory to save the classifier to. Defaults to ``self.save_dir``.
    :param model_name: The name of the model to save. If None, the current ``self.model_name`` is reused, falling
        back to the generated python file name. The chosen name is stored in ``self.model_name``.
    :return: The name of the saved model.
    :raises ValueError: If no save directory is available, or no model name can be determined.
    """
    save_dir = save_dir or self.save_dir
    if save_dir is None:
        raise ValueError("The save directory cannot be None. Please provide a valid directory to save"
                         " the classifier.")
    if model_name is not None:
        self.model_name = model_name
    elif self.model_name is None:
        self.model_name = self.generated_python_file_name
    if self.model_name is None:
        # Without a name, os.path.join below would fail with an unhelpful TypeError.
        raise ValueError("Cannot determine a model name: pass model_name, or set the model_name or "
                         "generated_python_file_name of the classifier.")
    init_file = os.path.join(save_dir, '__init__.py')
    if not os.path.exists(init_file):
        os.makedirs(save_dir, exist_ok=True)
        with open(init_file, 'w') as f:
            f.write("# This is an empty __init__.py file to make the directory a package.\n")
    model_dir = os.path.join(save_dir, self.model_name)
    json_dir = os.path.join(model_dir, self.metadata_folder)
    # Creates model_dir as well, since it is the parent of json_dir.
    os.makedirs(json_dir, exist_ok=True)
    self.to_json_file(os.path.join(json_dir, self.model_name))
    self._write_to_python(model_dir)
    return self.model_name
|
125
|
+
|
126
|
+
@classmethod
def load(cls, load_dir: str, model_name: str) -> Self:
    """
    Rebuild a classifier from a model directory laid out like the one ``save`` produces.

    The metadata is read with ``from_json_file`` from ``<load_dir>/<model_name>/<metadata_folder>/<model_name>``.
    The rules are then refreshed from the generated python files in ``<load_dir>/<model_name>``, which may have
    been edited by the user.

    :param load_dir: Directory that contains the model directory.
    :param model_name: Name of the model (and of its directory) to load.
    :return: The loaded classifier, with ``save_dir`` and ``model_name`` set to the given values.
    """
    model_dir = os.path.join(load_dir, model_name)
    rdr = cls.from_json_file(os.path.join(model_dir, cls.metadata_folder, model_name))
    rdr.update_from_python(model_dir)
    rdr.save_dir = load_dir
    rdr.model_name = model_name
    return rdr
|
141
|
+
|
142
|
+
@abstractmethod
def _write_to_python(self, model_dir: str):
    """
    Generate the python source code of the rule tree inside ``model_dir``.

    Subclasses must implement this; ``save`` calls it with the model directory.

    :param model_dir: Directory that receives the generated source file(s).
    """
    ...
|
150
|
+
|
75
151
|
def set_viewer(self, viewer: RDRCaseViewer):
|
76
152
|
"""
|
77
153
|
Set the viewer for the classifier.
|
@@ -160,19 +236,29 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
160
236
|
"""
|
161
237
|
if case_query is None:
|
162
238
|
raise ValueError("The case query cannot be None.")
|
239
|
+
|
163
240
|
self.name = case_query.attribute_name if self.name is None else self.name
|
164
241
|
self.case_type = case_query.case_type if self.case_type is None else self.case_type
|
242
|
+
self.case_name = case_query.case_name if self.case_name is None else self.case_name
|
243
|
+
|
165
244
|
if case_query.target is None:
|
166
245
|
case_query_cp = copy(case_query)
|
167
|
-
self.classify(case_query_cp.case, modify_case=True)
|
168
|
-
|
169
|
-
|
246
|
+
conclusions = self.classify(case_query_cp.case, modify_case=True)
|
247
|
+
if self.ask_always or conclusions is None or is_iterable(conclusions) and len(conclusions) == 0:
|
248
|
+
expert.ask_for_conclusion(case_query_cp)
|
249
|
+
case_query.target = case_query_cp.target
|
170
250
|
if case_query.target is None:
|
171
251
|
return self.classify(case_query.case)
|
172
252
|
|
173
253
|
self.update_start_rule(case_query, expert)
|
174
254
|
|
175
|
-
|
255
|
+
fit_case_result = self._fit_case(case_query, expert=expert, **kwargs)
|
256
|
+
|
257
|
+
if self.save_dir is not None:
|
258
|
+
self.save()
|
259
|
+
expert.clear_answers()
|
260
|
+
|
261
|
+
return fit_case_result
|
176
262
|
|
177
263
|
@abstractmethod
|
178
264
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
@@ -238,28 +324,54 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
238
324
|
pass
|
239
325
|
|
240
326
|
@abstractmethod
def update_from_python(self, model_dir: str):
    """
    Re-synchronize the rules with the generated python file(s), which the user may have edited.

    Subclasses must implement this; ``load`` calls it after reading the JSON metadata.

    :param model_dir: The directory where the generated python file is located.
    """
    ...
|
248
334
|
|
335
|
+
@classmethod
def get_acronym(cls) -> str:
    """
    Short acronym identifying the classifier kind (used, e.g., in the default generated file name).

    :return: "RDR" for GeneralRDR, "MCRDR" for MultiClassRDR and "SCRDR" for any other class.
    """
    acronyms = {"GeneralRDR": "RDR", "MultiClassRDR": "MCRDR"}
    return acronyms.get(cls.__name__, "SCRDR")
|
346
|
+
|
347
|
+
def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
    """
    Import the generated classifier module afresh and return its ``classify`` function.

    :param package_name: Directory-like path of the package containing the generated python file
        (e.g. ``./models/my_model``). Leading/trailing ``.`` and ``/`` characters are stripped and the remaining
        separators become dots.
    :return: The ``classify`` function of the freshly imported module.
    """
    # Accept OS-specific separators (e.g. paths built with os.path.join on Windows) as well as '/'.
    package = package_name.replace(os.sep, '/').strip('./').replace('/', '.')
    name = f"{package}.{self.generated_python_file_name}"
    # Evict any cached copy so that edits to the generated file are picked up, and refresh the import finders'
    # directory caches in case the file was created after the interpreter started. The module is then executed
    # exactly once.
    sys.modules.pop(name, None)
    importlib.invalidate_caches()
    return importlib.import_module(name).classify
|
360
|
+
|
249
361
|
|
250
362
|
class RDRWithCodeWriter(RippleDownRules, ABC):
|
251
363
|
|
252
|
-
def
|
364
|
+
def update_from_python(self, model_dir: str):
|
253
365
|
"""
|
254
366
|
Update the rules from the generated python file, that might have been modified by the user.
|
255
367
|
|
256
|
-
:param
|
368
|
+
:param model_dir: The directory where the generated python file is located.
|
257
369
|
"""
|
258
|
-
|
259
|
-
condition_func_names = [f'conditions_{rid}' for rid in
|
260
|
-
conclusion_func_names = [f'conclusion_{rid}' for rid in
|
370
|
+
rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants) if r.conditions is not None}
|
371
|
+
condition_func_names = [f'conditions_{rid}' for rid in rules_dict.keys()]
|
372
|
+
conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys() if not isinstance(rules_dict[rid], MultiClassStopRule)]
|
261
373
|
all_func_names = condition_func_names + conclusion_func_names
|
262
|
-
filepath = f"{
|
374
|
+
filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
|
263
375
|
functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
|
264
376
|
# get the scope from the imports in the file
|
265
377
|
scope = extract_imports(filepath)
|
@@ -267,7 +379,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
267
379
|
if rule.conditions is not None:
|
268
380
|
rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
|
269
381
|
rule.conditions.scope = scope
|
270
|
-
if rule.conclusion is not None:
|
382
|
+
if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
|
271
383
|
rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
|
272
384
|
rule.conclusion.scope = scope
|
273
385
|
|
@@ -284,17 +396,19 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
284
396
|
"""
|
285
397
|
pass
|
286
398
|
|
287
|
-
def
|
399
|
+
def _write_to_python(self, model_dir: str):
|
288
400
|
"""
|
289
401
|
Write the tree of rules as source code to a file.
|
290
402
|
|
291
|
-
:param
|
292
|
-
:param postfix: The postfix to add to the file name.
|
403
|
+
:param model_dir: The path to the directory to write the source code to.
|
293
404
|
"""
|
294
|
-
|
405
|
+
os.makedirs(model_dir, exist_ok=True)
|
406
|
+
if not os.path.exists(model_dir + '/__init__.py'):
|
407
|
+
with open(model_dir + '/__init__.py', 'w') as f:
|
408
|
+
f.write("# This is an empty __init__.py file to make the directory a package.\n")
|
295
409
|
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
296
|
-
file_name =
|
297
|
-
defs_file_name =
|
410
|
+
file_name = model_dir + f"/{self.generated_python_file_name}.py"
|
411
|
+
defs_file_name = model_dir + f"/{self.generated_python_defs_file_name}.py"
|
298
412
|
imports, defs_imports = self._get_imports()
|
299
413
|
# clear the files first
|
300
414
|
with open(defs_file_name, "w") as f:
|
@@ -345,20 +459,6 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
345
459
|
imports = "\n".join(imports) + "\n"
|
346
460
|
return imports, defs_imports
|
347
461
|
|
348
|
-
def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
|
349
|
-
"""
|
350
|
-
:param package_name: The name of the package that contains the RDR classifier function.
|
351
|
-
:return: The module that contains the rdr classifier function.
|
352
|
-
"""
|
353
|
-
# remove from imports if exists first
|
354
|
-
name = f"{package_name.strip('./')}.{self.generated_python_file_name}"
|
355
|
-
try:
|
356
|
-
module = importlib.import_module(name)
|
357
|
-
del sys.modules[name]
|
358
|
-
except ModuleNotFoundError:
|
359
|
-
pass
|
360
|
-
return importlib.import_module(name).classify
|
361
|
-
|
362
462
|
@property
|
363
463
|
def _default_generated_python_file_name(self) -> Optional[str]:
|
364
464
|
"""
|
@@ -366,23 +466,12 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
366
466
|
"""
|
367
467
|
if self.start_rule is None or self.start_rule.conclusion is None:
|
368
468
|
return None
|
369
|
-
return f"{self.
|
469
|
+
return f"{str_to_snake_case(self.case_name)}_{self.attribute_name}_{self.get_acronym().lower()}"
|
370
470
|
|
371
471
|
@property
def generated_python_defs_file_name(self) -> str:
    """
    :return: Name (without extension) of the companion module that holds the rule condition/conclusion function
        definitions: the generated python file name with a ``_defs`` suffix.
    """
    return "{}_defs".format(self.generated_python_file_name)
|
374
474
|
|
375
|
-
@property
|
376
|
-
def acronym(self) -> str:
|
377
|
-
"""
|
378
|
-
:return: The acronym of the classifier.
|
379
|
-
"""
|
380
|
-
if self.__class__.__name__ == "GeneralRDR":
|
381
|
-
return "GRDR"
|
382
|
-
elif self.__class__.__name__ == "MultiClassRDR":
|
383
|
-
return "MCRDR"
|
384
|
-
else:
|
385
|
-
return "SCRDR"
|
386
475
|
|
387
476
|
@property
|
388
477
|
def conclusion_type(self) -> Tuple[Type]:
|
@@ -403,7 +492,9 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
403
492
|
|
404
493
|
def _to_json(self) -> Dict[str, Any]:
    """
    Serialize the classifier into a dictionary.

    The case type is stored by its full class name, or None when it is unset.
    """
    serialized = {"start_rule": self.start_rule.to_json(),
                  "generated_python_file_name": self.generated_python_file_name,
                  "name": self.name}
    serialized["case_type"] = None if self.case_type is None else get_full_class_name(self.case_type)
    serialized["case_name"] = self.case_name
    return serialized
|
407
498
|
|
408
499
|
@classmethod
|
409
500
|
def _from_json(cls, data: Dict[str, Any]) -> Self:
|
@@ -411,13 +502,15 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
411
502
|
Create an instance of the class from a json
|
412
503
|
"""
|
413
504
|
start_rule = cls.start_rule_type().from_json(data["start_rule"])
|
414
|
-
new_rdr = cls(start_rule)
|
505
|
+
new_rdr = cls(start_rule=start_rule)
|
415
506
|
if "generated_python_file_name" in data:
|
416
507
|
new_rdr.generated_python_file_name = data["generated_python_file_name"]
|
417
508
|
if "name" in data:
|
418
509
|
new_rdr.name = data["name"]
|
419
510
|
if "case_type" in data:
|
420
511
|
new_rdr.case_type = get_type_from_string(data["case_type"])
|
512
|
+
if "case_name" in data:
|
513
|
+
new_rdr.case_name = data["case_name"]
|
421
514
|
return new_rdr
|
422
515
|
|
423
516
|
@staticmethod
|
@@ -431,12 +524,12 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
431
524
|
|
432
525
|
class SingleClassRDR(RDRWithCodeWriter):
|
433
526
|
|
434
|
-
def __init__(self,
|
527
|
+
def __init__(self, default_conclusion: Optional[Any] = None, **kwargs):
    """
    Create a single-class RDR classifier.

    :param default_conclusion: Conclusion emitted in the generated code's final ``else`` branch; None means no
        default branch is written.
    :param kwargs: Forwarded to the base classifier initializer (e.g. ``start_rule``, ``viewer``, ``save_dir``,
        ``ask_always``, ``model_name``).
    """
    super().__init__(**kwargs)
    self.default_conclusion: Optional[Any] = default_conclusion
|
441
534
|
|
442
535
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
@@ -488,10 +581,10 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
488
581
|
matched_rule = self.start_rule(case) if self.start_rule is not None else None
|
489
582
|
return matched_rule if matched_rule is not None else self.start_rule
|
490
583
|
|
491
|
-
def
|
492
|
-
super().
|
584
|
+
def _write_to_python(self, model_dir: str):
    """
    Write the generated source via the base implementation. When a default conclusion is set, an ``else``
    branch returning it is then appended to the end of the generated python file.

    :param model_dir: Directory containing the generated python files.
    """
    super()._write_to_python(model_dir)
    if self.default_conclusion is None:
        return
    # NOTE(review): the conclusion is interpolated as-is (no repr), so e.g. a string conclusion is emitted
    # unquoted - confirm that conclusions always render as valid python expressions.
    generated_file = model_dir + f"/{self.generated_python_file_name}.py"
    with open(generated_file, "a") as f:
        f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
|
496
589
|
|
497
590
|
def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
|
@@ -892,7 +985,8 @@ class GeneralRDR(RippleDownRules):
|
|
892
985
|
return {"start_rules": {name: rdr.to_json() for name, rdr in self.start_rules_dict.items()}
|
893
986
|
, "generated_python_file_name": self.generated_python_file_name,
|
894
987
|
"name": self.name,
|
895
|
-
"case_type": get_full_class_name(self.case_type) if self.case_type is not None else None
|
988
|
+
"case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
|
989
|
+
"case_name": self.case_name}
|
896
990
|
|
897
991
|
@classmethod
|
898
992
|
def _from_json(cls, data: Dict[str, Any]) -> GeneralRDR:
|
@@ -902,37 +996,37 @@ class GeneralRDR(RippleDownRules):
|
|
902
996
|
start_rules_dict = {}
|
903
997
|
for k, v in data["start_rules"].items():
|
904
998
|
start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
|
905
|
-
new_rdr = cls(start_rules_dict)
|
999
|
+
new_rdr = cls(category_rdr_map=start_rules_dict)
|
906
1000
|
if "generated_python_file_name" in data:
|
907
1001
|
new_rdr.generated_python_file_name = data["generated_python_file_name"]
|
908
1002
|
if "name" in data:
|
909
1003
|
new_rdr.name = data["name"]
|
910
1004
|
if "case_type" in data:
|
911
1005
|
new_rdr.case_type = get_type_from_string(data["case_type"])
|
1006
|
+
if "case_name" in data:
|
1007
|
+
new_rdr.case_name = data["case_name"]
|
912
1008
|
return new_rdr
|
913
1009
|
|
914
|
-
def
|
1010
|
+
def update_from_python(self, model_dir: str) -> None:
|
915
1011
|
"""
|
916
1012
|
Update the rules from the generated python file, that might have been modified by the user.
|
917
1013
|
|
918
|
-
:param
|
1014
|
+
:param model_dir: The directory where the model is stored.
|
919
1015
|
"""
|
920
1016
|
for rdr in self.start_rules_dict.values():
|
921
|
-
rdr.
|
1017
|
+
rdr.update_from_python(model_dir)
|
922
1018
|
|
923
|
-
def
|
1019
|
+
def _write_to_python(self, model_dir: str) -> None:
|
924
1020
|
"""
|
925
1021
|
Write the tree of rules as source code to a file.
|
926
1022
|
|
927
|
-
:param
|
928
|
-
:param postfix: The postfix to add to the file name.
|
1023
|
+
:param model_dir: The directory where the model is stored.
|
929
1024
|
"""
|
930
|
-
self.generated_python_file_name = self._default_generated_python_file_name + postfix
|
931
1025
|
for rdr in self.start_rules_dict.values():
|
932
|
-
rdr.
|
1026
|
+
rdr._write_to_python(model_dir)
|
933
1027
|
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
934
|
-
with open(
|
935
|
-
f.write(self._get_imports(
|
1028
|
+
with open(model_dir + f"/{self.generated_python_file_name}.py", "w") as f:
|
1029
|
+
f.write(self._get_imports() + "\n\n")
|
936
1030
|
f.write("classifiers_dict = dict()\n")
|
937
1031
|
for rdr_key, rdr in self.start_rules_dict.items():
|
938
1032
|
f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
|
@@ -942,13 +1036,6 @@ class GeneralRDR(RippleDownRules):
|
|
942
1036
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
943
1037
|
f.write(f"{' ' * 4}return GeneralRDR._classify(classifiers_dict, case)\n")
|
944
1038
|
|
945
|
-
def get_rdr_classifier_from_python_file(self, file_path: str) -> Callable[[Any], Any]:
|
946
|
-
"""
|
947
|
-
:param file_path: The path to the file that contains the RDR classifier function.
|
948
|
-
:return: The module that contains the rdr classifier function.
|
949
|
-
"""
|
950
|
-
return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
|
951
|
-
|
952
1039
|
@property
|
953
1040
|
def _default_generated_python_file_name(self) -> Optional[str]:
|
954
1041
|
"""
|
@@ -956,17 +1043,16 @@ class GeneralRDR(RippleDownRules):
|
|
956
1043
|
"""
|
957
1044
|
if self.start_rule is None or self.start_rule.conclusion is None:
|
958
1045
|
return None
|
959
|
-
return f"{self.
|
1046
|
+
return f"{str_to_snake_case(self.case_name)}_rdr".lower()
|
960
1047
|
|
961
1048
|
@property
|
962
1049
|
def conclusion_type_hint(self) -> str:
|
963
1050
|
return f"Dict[str, Any]"
|
964
1051
|
|
965
|
-
def _get_imports(self
|
1052
|
+
def _get_imports(self) -> str:
|
966
1053
|
"""
|
967
1054
|
Get the imports needed for the generated python file.
|
968
1055
|
|
969
|
-
:param file_path: The path to the file that contains the RDR classifier function.
|
970
1056
|
:return: The imports needed for the generated python file.
|
971
1057
|
"""
|
972
1058
|
imports = ""
|
@@ -5,15 +5,17 @@ of the RDRs.
|
|
5
5
|
"""
|
6
6
|
import os.path
|
7
7
|
from functools import wraps
|
8
|
+
|
9
|
+
from pyparsing.tools.cvt_pyparsing_pep8_names import camel_to_snake
|
8
10
|
from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union
|
9
11
|
|
10
|
-
from ripple_down_rules.datastructures.case import create_case
|
12
|
+
from ripple_down_rules.datastructures.case import create_case, Case
|
11
13
|
from ripple_down_rules.datastructures.dataclasses import CaseQuery
|
12
14
|
from ripple_down_rules.datastructures.enums import Category
|
13
15
|
from ripple_down_rules.experts import Expert, Human
|
14
16
|
from ripple_down_rules.rdr import GeneralRDR, RippleDownRules
|
15
17
|
from ripple_down_rules.utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
|
16
|
-
get_method_class_if_exists
|
18
|
+
get_method_class_if_exists, get_method_name, str_to_snake_case
|
17
19
|
|
18
20
|
|
19
21
|
class RDRDecorator:
|
@@ -41,99 +43,118 @@ class RDRDecorator:
|
|
41
43
|
:return: A decorator to use a GeneralRDR as a classifier that monitors and modifies the function's output.
|
42
44
|
"""
|
43
45
|
self.rdr_models_dir = models_dir
|
46
|
+
self.model_name: Optional[str] = None
|
44
47
|
self.output_type = output_type
|
45
48
|
self.parsed_output_type: List[Type] = []
|
46
49
|
self.mutual_exclusive = mutual_exclusive
|
47
50
|
self.rdr_python_path: Optional[str] = python_dir
|
48
51
|
self.output_name = output_name
|
49
52
|
self.fit: bool = fit
|
50
|
-
self.expert = expert
|
51
|
-
self.rdr_model_path: Optional[str] = None
|
53
|
+
self.expert: Optional[Expert] = expert
|
52
54
|
self.load()
|
53
55
|
|
54
56
|
def decorator(self, func: Callable) -> Callable:
|
55
57
|
|
56
58
|
@wraps(func)
|
57
59
|
def wrapper(*args, **kwargs) -> Optional[Any]:
|
60
|
+
|
58
61
|
if len(self.parsed_output_type) == 0:
|
59
|
-
self.parse_output_type(func, *args)
|
60
|
-
if self.
|
61
|
-
self.
|
62
|
-
|
63
|
-
func_output = func(*args, **kwargs)
|
64
|
-
case_dict.update({self.output_name: func_output})
|
65
|
-
case = create_case(case_dict, obj_name=get_func_rdr_model_name(func), max_recursion_idx=3)
|
62
|
+
self.parsed_output_type = self.parse_output_type(func, self.output_type, *args)
|
63
|
+
if self.model_name is None:
|
64
|
+
self.initialize_rdr_model_name_and_load(func)
|
65
|
+
|
66
66
|
if self.fit:
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
self.mutual_exclusive,
|
73
|
-
scope=scope, is_function=True, function_args_type_hints=func_args_type_hints)
|
67
|
+
expert_answers_path = os.path.join(self.rdr_models_dir, self.model_name, "expert_answers")
|
68
|
+
self.expert = self.expert or Human(answers_save_path=expert_answers_path)
|
69
|
+
case_query = self.create_case_query_from_method(func, self.parsed_output_type,
|
70
|
+
self.mutual_exclusive, self.output_name,
|
71
|
+
*args, **kwargs)
|
74
72
|
output = self.rdr.fit_case(case_query, expert=self.expert)
|
75
73
|
return output[self.output_name]
|
76
74
|
else:
|
75
|
+
case, case_dict = self.create_case_from_method(func, self.output_name, *args, **kwargs)
|
77
76
|
return self.rdr.classify(case)[self.output_name]
|
78
77
|
|
79
78
|
return wrapper
|
80
79
|
|
81
|
-
|
80
|
+
@staticmethod
|
81
|
+
def create_case_query_from_method(func: Callable, output_type, mutual_exclusive: bool,
|
82
|
+
output_name: str = 'output_', *args, **kwargs) -> CaseQuery:
|
83
|
+
"""
|
84
|
+
Create a CaseQuery from the function and its arguments.
|
85
|
+
|
86
|
+
:param func: The function to create a case from.
|
87
|
+
:param output_type: The type of the output.
|
88
|
+
:param mutual_exclusive: If True, the output types are mutually exclusive.
|
89
|
+
:param output_name: The name of the output in the case. Defaults to 'output_'.
|
90
|
+
:param args: The positional arguments of the function.
|
91
|
+
:param kwargs: The keyword arguments of the function.
|
92
|
+
:return: A CaseQuery object representing the case.
|
93
|
+
"""
|
94
|
+
output_type = make_set(output_type)
|
95
|
+
case, case_dict = RDRDecorator.create_case_from_method(func, output_name, *args, **kwargs)
|
96
|
+
scope = func.__globals__
|
97
|
+
scope.update(case_dict)
|
98
|
+
func_args_type_hints = get_type_hints(func)
|
99
|
+
func_args_type_hints.update({output_name: Union[tuple(output_type)]})
|
100
|
+
return CaseQuery(case, output_name, Union[tuple(output_type)],
|
101
|
+
mutual_exclusive, scope=scope,
|
102
|
+
is_function=True, function_args_type_hints=func_args_type_hints)
|
103
|
+
|
104
|
+
@staticmethod
|
105
|
+
def create_case_from_method(func: Callable, output_name: str = "output_", *args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
|
106
|
+
"""
|
107
|
+
Create a Case from the function and its arguments.
|
108
|
+
|
109
|
+
:param func: The function to create a case from.
|
110
|
+
:param output_name: The name of the output in the case. Defaults to 'output_'.
|
111
|
+
:param args: The positional arguments of the function.
|
112
|
+
:param kwargs: The keyword arguments of the function.
|
113
|
+
:return: A Case object representing the case.
|
114
|
+
"""
|
115
|
+
case_dict = get_method_args_as_dict(func, *args, **kwargs)
|
116
|
+
func_output = func(*args, **kwargs)
|
117
|
+
case_dict.update({output_name: func_output})
|
118
|
+
case_name = get_func_rdr_model_name(func)
|
119
|
+
return create_case(case_dict, obj_name=case_name, max_recursion_idx=3), case_dict
|
120
|
+
|
121
|
+
def initialize_rdr_model_name_and_load(self, func: Callable) -> None:
|
82
122
|
model_file_name = get_func_rdr_model_name(func, include_file_name=True)
|
83
|
-
|
84
|
-
.replace('__', '_') + ".json")
|
85
|
-
self.rdr_model_path = os.path.join(self.rdr_models_dir, model_file_name)
|
123
|
+
self.model_name = str_to_snake_case(model_file_name)
|
86
124
|
self.load()
|
87
125
|
|
88
|
-
|
89
|
-
|
126
|
+
@staticmethod
|
127
|
+
def parse_output_type(func: Callable, output_type: Any, *args) -> List[Type]:
|
128
|
+
parsed_output_type = []
|
129
|
+
for ot in make_set(output_type):
|
90
130
|
if ot is Self:
|
91
131
|
func_class = get_method_class_if_exists(func, *args)
|
92
132
|
if func_class is not None:
|
93
|
-
|
133
|
+
parsed_output_type.append(func_class)
|
94
134
|
else:
|
95
135
|
raise ValueError(f"The function {func} is not a method of a class,"
|
96
136
|
f" and the output type is {Self}.")
|
97
137
|
else:
|
98
|
-
|
138
|
+
parsed_output_type.append(ot)
|
139
|
+
return parsed_output_type
|
99
140
|
|
100
141
|
def save(self):
|
101
142
|
"""
|
102
143
|
Save the RDR model to the specified directory.
|
103
144
|
"""
|
104
|
-
self.rdr.save(self.
|
105
|
-
|
106
|
-
if self.rdr_python_path is not None:
|
107
|
-
if not os.path.exists(self.rdr_python_path):
|
108
|
-
os.makedirs(self.rdr_python_path)
|
109
|
-
if not os.path.exists(os.path.join(self.rdr_python_path, "__init__.py")):
|
110
|
-
# add __init__.py file to the directory
|
111
|
-
with open(os.path.join(self.rdr_python_path, "__init__.py"), "w") as f:
|
112
|
-
f.write("# This is an empty __init__.py file to make the directory a package.")
|
113
|
-
# write the RDR model to a python file
|
114
|
-
self.rdr.write_to_python_file(self.rdr_python_path)
|
145
|
+
self.rdr.save(self.rdr_models_dir)
|
115
146
|
|
116
147
|
def load(self):
|
117
148
|
"""
|
118
149
|
Load the RDR model from the specified directory.
|
119
150
|
"""
|
120
|
-
if self.
|
121
|
-
self.rdr = GeneralRDR.load(self.
|
151
|
+
if self.model_name is not None and os.path.exists(os.path.join(self.rdr_models_dir, self.model_name)):
|
152
|
+
self.rdr = GeneralRDR.load(self.rdr_models_dir, self.model_name)
|
122
153
|
else:
|
123
|
-
self.rdr = GeneralRDR()
|
124
|
-
|
125
|
-
def write_to_python_file(self, package_dir: str, file_name_postfix: str = ""):
|
126
|
-
"""
|
127
|
-
Write the RDR model to a python file.
|
128
|
-
|
129
|
-
:param package_dir: The path to the directory to write the python file.
|
130
|
-
"""
|
131
|
-
self.rdr.write_to_python_file(package_dir, postfix=file_name_postfix)
|
154
|
+
self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name)
|
132
155
|
|
133
|
-
def
|
156
|
+
def update_from_python(self):
|
134
157
|
"""
|
135
158
|
Update the RDR model from a python file.
|
136
|
-
|
137
|
-
:param package_dir: The directory of the package that contains the generated python file.
|
138
159
|
"""
|
139
|
-
self.rdr.
|
160
|
+
self.rdr.update_from_python(self.rdr_models_dir, self.model_name)
|
ripple_down_rules/rules.py
CHANGED
@@ -118,7 +118,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
118
118
|
func_call = f"{parent_indent} return {new_function_name.replace('def ', '')}(case)\n"
|
119
119
|
return "\n".join(conclusion_lines).strip(' '), func_call
|
120
120
|
else:
|
121
|
-
raise ValueError(f"Conclusion
|
121
|
+
raise ValueError(f"Conclusion format is not valid, it should contain a function definition."
|
122
122
|
f" Instead got:\n{conclusion}\n")
|
123
123
|
|
124
124
|
def write_condition_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
|
@@ -129,9 +129,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
129
129
|
:param defs_file: The file to write the conditions to if they are a definition.
|
130
130
|
"""
|
131
131
|
if_clause = self._if_statement_source_code_clause()
|
132
|
-
if
|
133
|
-
return f"{parent_indent}{if_clause} {self.conditions.user_input}:\n"
|
134
|
-
elif "def " in self.conditions.user_input:
|
132
|
+
if "def " in self.conditions.user_input:
|
135
133
|
if defs_file is None:
|
136
134
|
raise ValueError("Cannot write conditions to source code as definitions python file was not given.")
|
137
135
|
# This means the conditions are a definition that should be written and then called
|
@@ -143,6 +141,9 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
143
141
|
with open(defs_file, 'a') as f:
|
144
142
|
f.write(def_code.strip() + "\n\n\n")
|
145
143
|
return f"\n{parent_indent}{if_clause} {new_function_name.replace('def ', '')}(case):\n"
|
144
|
+
else:
|
145
|
+
raise ValueError(f"Conditions format is not valid, it should contain a function definition"
|
146
|
+
f" Instead got:\n{self.conditions.user_input}\n")
|
146
147
|
|
147
148
|
@abstractmethod
|
148
149
|
def _if_statement_source_code_clause(self) -> str:
|
@@ -258,12 +258,12 @@ class TemplateFileCreator:
|
|
258
258
|
func_name = f"{prompt_for.value.lower()}_for_"
|
259
259
|
case_name = case_query.name.replace(".", "_")
|
260
260
|
if case_query.is_function:
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
261
|
+
func_name += case_name.replace(f"_{case_query.attribute_name}", "")
|
262
|
+
else:
|
263
|
+
func_name += case_name
|
264
|
+
attribute_types = TemplateFileCreator.get_core_attribute_types(case_query)
|
265
|
+
attribute_type_names = [t.__name__ for t in attribute_types]
|
266
|
+
func_name += f"_of_type_{'_or_'.join(attribute_type_names)}"
|
267
267
|
return str_to_snake_case(func_name)
|
268
268
|
|
269
269
|
@cached_property
|
ripple_down_rules/utils.py
CHANGED
@@ -178,7 +178,7 @@ def extract_function_source(file_path: str,
|
|
178
178
|
functions_source: Dict[str, Union[str, List[str]]] = {}
|
179
179
|
line_numbers = []
|
180
180
|
for node in tree.body:
|
181
|
-
if isinstance(node, ast.FunctionDef) and node.name in function_names:
|
181
|
+
if isinstance(node, ast.FunctionDef) and (node.name in function_names or len(function_names) == 0):
|
182
182
|
# Get the line numbers of the function
|
183
183
|
lines = source.splitlines()
|
184
184
|
func_lines = lines[node.lineno - 1:node.end_lineno]
|
@@ -186,9 +186,9 @@ def extract_function_source(file_path: str,
|
|
186
186
|
func_lines = func_lines[1:]
|
187
187
|
line_numbers.append((node.lineno, node.end_lineno))
|
188
188
|
functions_source[node.name] = dedent("\n".join(func_lines)) if join_lines else func_lines
|
189
|
-
if len(functions_source)
|
189
|
+
if len(functions_source) >= len(function_names):
|
190
190
|
break
|
191
|
-
if len(functions_source)
|
191
|
+
if len(functions_source) < len(function_names):
|
192
192
|
raise ValueError(f"Could not find all functions in {file_path}: {function_names} not found,"
|
193
193
|
f"functions not found: {set(function_names) - set(functions_source.keys())}")
|
194
194
|
if return_line_numbers:
|
@@ -953,9 +953,6 @@ class SubclassJSONSerializer:
|
|
953
953
|
|
954
954
|
raise ValueError("Unknown type {}".format(data["_type"]))
|
955
955
|
|
956
|
-
save = to_json_file
|
957
|
-
load = from_json_file
|
958
|
-
|
959
956
|
|
960
957
|
def _pickle_thread(thread_obj) -> Any:
|
961
958
|
"""Return a plain object with user-defined attributes but no thread behavior."""
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.5.0
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -0,0 +1,26 @@
|
|
1
|
+
ripple_down_rules/__init__.py,sha256=FQAv_KtUXVoz9VCR37DEoK0MC84rKZnIRUy-4pQ95sE,100
|
2
|
+
ripple_down_rules/datasets.py,sha256=fJbZ7V-UUYTu5XVVpFinTbuzN3YePCnUB01L3AyZVM8,6837
|
3
|
+
ripple_down_rules/experts.py,sha256=9Vc3vx0uhDPy3YlNjwKuWJLl_A-kubRPUU6bMvQhaAg,13237
|
4
|
+
ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
|
5
|
+
ripple_down_rules/helpers.py,sha256=TvTJU0BA3dPcAyzvZFvAu7jZqsp8Lu0HAAwvuizlGjg,2018
|
6
|
+
ripple_down_rules/rdr.py,sha256=E1OiiZClQyAfGjL64ID-MWYFO4-h8iUAX-Vm9qrOoeQ,48727
|
7
|
+
ripple_down_rules/rdr_decorators.py,sha256=pYCKLgMKgQ6x_252WQtF2t4ZNjWPBxnaWtJ6TpGdcc0,7820
|
8
|
+
ripple_down_rules/rules.py,sha256=TPNVMqW9T-_46BS4WemrspLg5uG8kP6tsPvWWBAzJxg,17515
|
9
|
+
ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
|
10
|
+
ripple_down_rules/utils.py,sha256=uS38KcFceRMzT_470DCL1M0LzETdP5RLwE7cCmfo7eI,51086
|
11
|
+
ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
|
12
|
+
ripple_down_rules/datastructures/callable_expression.py,sha256=3EucsD3jWzekhjyzL2y0dyUsucd-aqC9glmgPL0Ubb4,12425
|
13
|
+
ripple_down_rules/datastructures/case.py,sha256=r8kjL9xP_wk84ThXusspgPMrAoed2bGQmKi54fzhmH8,15258
|
14
|
+
ripple_down_rules/datastructures/dataclasses.py,sha256=PuD-7zWqWT2p4FnGvnihHvZlZKg9A1ctnFgVYf2cs-8,8554
|
15
|
+
ripple_down_rules/datastructures/enums.py,sha256=ce7tqS0otfSTNAOwsnXlhsvIn4iW_Y_N3TNebF3YoZs,5700
|
16
|
+
ripple_down_rules/user_interface/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
|
+
ripple_down_rules/user_interface/gui.py,sha256=SB0gUhgReJ3yx-NEHRPMGVuNRLPRUwW8-qup-Kd4Cfo,27182
|
18
|
+
ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=24MIFwqnAhC6ofObEO6x5xRWRnyQmPpPmTvxbCKBrzM,6514
|
19
|
+
ripple_down_rules/user_interface/object_diagram.py,sha256=tsB6iuLNEbHxp5lR2WjyejjWbnAX_nHF9xS8jNPOQVk,4548
|
20
|
+
ripple_down_rules/user_interface/prompt.py,sha256=AkkltdDIaioN43lkRKDPKSjJcmdSSGZDMYz7AL7X9lE,8082
|
21
|
+
ripple_down_rules/user_interface/template_file_creator.py,sha256=ycCbddy_BJP8d0Q2Sj21UzamhGtqGZuK_e73VTJqznY,13766
|
22
|
+
ripple_down_rules-0.5.0.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
23
|
+
ripple_down_rules-0.5.0.dist-info/METADATA,sha256=LYiepkd0xlfYVqVMdVrKZNbMJuxybqBheA2b0_CgGsY,43306
|
24
|
+
ripple_down_rules-0.5.0.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
|
25
|
+
ripple_down_rules-0.5.0.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
26
|
+
ripple_down_rules-0.5.0.dist-info/RECORD,,
|
@@ -1,26 +0,0 @@
|
|
1
|
-
ripple_down_rules/__init__.py,sha256=gvXUS_xmUCsWcUwVy5Sd8tyjdLhlPGbjfDrfDImrt7o,100
|
2
|
-
ripple_down_rules/datasets.py,sha256=fJbZ7V-UUYTu5XVVpFinTbuzN3YePCnUB01L3AyZVM8,6837
|
3
|
-
ripple_down_rules/experts.py,sha256=RWDR-xxbeFIrUQiMYLEDr_PLQFdpPZ-hOXo4dpeiUpI,6630
|
4
|
-
ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
|
5
|
-
ripple_down_rules/helpers.py,sha256=TvTJU0BA3dPcAyzvZFvAu7jZqsp8Lu0HAAwvuizlGjg,2018
|
6
|
-
ripple_down_rules/rdr.py,sha256=a7sSxvJewzG5FZvbUW_Ss7VVYQtBnH-H--hni8-pWC4,45528
|
7
|
-
ripple_down_rules/rdr_decorators.py,sha256=VdmE0JrE8j89b6Af1R1tLZiKfy3h1VCvhAUefN_FLLQ,6753
|
8
|
-
ripple_down_rules/rules.py,sha256=7NB8qWW7XEB45tmJRYsKJqBG8DN3v02fzAFYmOkX8ow,17458
|
9
|
-
ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
|
10
|
-
ripple_down_rules/utils.py,sha256=t_yutgZvrOOGb6Wa-uAuoTafLicwovSFRiUa746ALOw,51108
|
11
|
-
ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
|
12
|
-
ripple_down_rules/datastructures/callable_expression.py,sha256=jA7424_mWPbOoPICW3eLMX0-ypxnsW6gOqxrJ7JpDbE,11610
|
13
|
-
ripple_down_rules/datastructures/case.py,sha256=r8kjL9xP_wk84ThXusspgPMrAoed2bGQmKi54fzhmH8,15258
|
14
|
-
ripple_down_rules/datastructures/dataclasses.py,sha256=GWnUF4h4zfNHSsyBIz3L9y8sLkrXRv0FK_OxzzLc8L8,8183
|
15
|
-
ripple_down_rules/datastructures/enums.py,sha256=ce7tqS0otfSTNAOwsnXlhsvIn4iW_Y_N3TNebF3YoZs,5700
|
16
|
-
ripple_down_rules/user_interface/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
|
-
ripple_down_rules/user_interface/gui.py,sha256=SB0gUhgReJ3yx-NEHRPMGVuNRLPRUwW8-qup-Kd4Cfo,27182
|
18
|
-
ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=24MIFwqnAhC6ofObEO6x5xRWRnyQmPpPmTvxbCKBrzM,6514
|
19
|
-
ripple_down_rules/user_interface/object_diagram.py,sha256=tsB6iuLNEbHxp5lR2WjyejjWbnAX_nHF9xS8jNPOQVk,4548
|
20
|
-
ripple_down_rules/user_interface/prompt.py,sha256=AkkltdDIaioN43lkRKDPKSjJcmdSSGZDMYz7AL7X9lE,8082
|
21
|
-
ripple_down_rules/user_interface/template_file_creator.py,sha256=J_bBOJltc1fsrIYeHdrSUA_jep2DhDbTK5NYRbL6QyY,13831
|
22
|
-
ripple_down_rules-0.4.88.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
23
|
-
ripple_down_rules-0.4.88.dist-info/METADATA,sha256=ytWRoIfcAHeBfJMqT1KtQJPsAEGDqXZZegjDaq6YcuM,43307
|
24
|
-
ripple_down_rules-0.4.88.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
|
25
|
-
ripple_down_rules-0.4.88.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
26
|
-
ripple_down_rules-0.4.88.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|