ripple-down-rules 0.4.7__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +5 -0
- ripple_down_rules/datastructures/callable_expression.py +20 -1
- ripple_down_rules/datastructures/case.py +8 -6
- ripple_down_rules/datastructures/dataclasses.py +9 -1
- ripple_down_rules/experts.py +194 -33
- ripple_down_rules/rdr.py +196 -114
- ripple_down_rules/rdr_decorators.py +73 -52
- ripple_down_rules/rules.py +7 -6
- ripple_down_rules/start-code-server.sh +27 -0
- ripple_down_rules/user_interface/gui.py +21 -35
- ripple_down_rules/user_interface/ipython_custom_shell.py +6 -4
- ripple_down_rules/user_interface/object_diagram.py +7 -1
- ripple_down_rules/user_interface/prompt.py +9 -4
- ripple_down_rules/user_interface/template_file_creator.py +55 -32
- ripple_down_rules/utils.py +66 -26
- {ripple_down_rules-0.4.7.dist-info → ripple_down_rules-0.4.9.dist-info}/METADATA +10 -8
- ripple_down_rules-0.4.9.dist-info/RECORD +26 -0
- ripple_down_rules-0.4.7.dist-info/RECORD +0 -25
- {ripple_down_rules-0.4.7.dist-info → ripple_down_rules-0.4.9.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.4.7.dist-info → ripple_down_rules-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.4.7.dist-info → ripple_down_rules-0.4.9.dist-info}/top_level.txt +0 -0
ripple_down_rules/rdr.py
CHANGED
@@ -2,13 +2,24 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import copyreg
|
4
4
|
import importlib
|
5
|
+
import os
|
6
|
+
|
7
|
+
from . import logger
|
5
8
|
import sys
|
6
9
|
from abc import ABC, abstractmethod
|
7
10
|
from copy import copy
|
8
11
|
from io import TextIOWrapper
|
9
12
|
from types import ModuleType
|
10
13
|
|
11
|
-
|
14
|
+
try:
|
15
|
+
from matplotlib import pyplot as plt
|
16
|
+
Figure = plt.Figure
|
17
|
+
except ImportError as e:
|
18
|
+
logger.debug(f"{e}: matplotlib is not installed")
|
19
|
+
matplotlib = None
|
20
|
+
Figure = None
|
21
|
+
plt = None
|
22
|
+
|
12
23
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
13
24
|
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
|
14
25
|
|
@@ -19,17 +30,21 @@ from .datastructures.enums import MCRDRMode
|
|
19
30
|
from .experts import Expert, Human
|
20
31
|
from .helpers import is_matching
|
21
32
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
22
|
-
|
33
|
+
try:
|
34
|
+
from .user_interface.gui import RDRCaseViewer
|
35
|
+
except ImportError as e:
|
36
|
+
RDRCaseViewer = None
|
23
37
|
from .utils import draw_tree, make_set, copy_case, \
|
24
38
|
SubclassJSONSerializer, make_list, get_type_from_string, \
|
25
|
-
is_conflicting, update_case, get_imports_from_scope, extract_function_source, extract_imports
|
39
|
+
is_conflicting, update_case, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
|
40
|
+
is_iterable, str_to_snake_case
|
26
41
|
|
27
42
|
|
28
43
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
29
44
|
"""
|
30
45
|
The abstract base class for the ripple down rules classifiers.
|
31
46
|
"""
|
32
|
-
fig: Optional[
|
47
|
+
fig: Optional[Figure] = None
|
33
48
|
"""
|
34
49
|
The figure to draw the tree on.
|
35
50
|
"""
|
@@ -45,17 +60,94 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
45
60
|
"""
|
46
61
|
The name of the classifier.
|
47
62
|
"""
|
63
|
+
case_type: Optional[Type] = None
|
64
|
+
"""
|
65
|
+
The type of the case (input) to the RDR classifier.
|
66
|
+
"""
|
67
|
+
case_name: Optional[str] = None
|
68
|
+
"""
|
69
|
+
The name of the case type.
|
70
|
+
"""
|
71
|
+
metadata_folder: str = "rdr_metadata"
|
72
|
+
"""
|
73
|
+
The folder to save the metadata of the RDR classifier.
|
74
|
+
"""
|
75
|
+
model_name: Optional[str] = None
|
76
|
+
"""
|
77
|
+
The name of the model. If None, the model name will be the generated python file name.
|
78
|
+
"""
|
48
79
|
|
49
|
-
def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None
|
80
|
+
def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None,
|
81
|
+
save_dir: Optional[str] = None, ask_always: bool = True, model_name: Optional[str] = None):
|
50
82
|
"""
|
51
83
|
:param start_rule: The starting rule for the classifier.
|
84
|
+
:param viewer: The viewer gui to use for the classifier. If None, no viewer is used.
|
85
|
+
:param save_dir: The directory to save the classifier to.
|
86
|
+
:param ask_always: Whether to always ask the expert (True) or only ask when classification fails (False).
|
52
87
|
"""
|
88
|
+
self.ask_always: bool = ask_always
|
89
|
+
self.model_name: Optional[str] = model_name
|
90
|
+
self.save_dir = save_dir
|
53
91
|
self.start_rule = start_rule
|
54
|
-
self.fig: Optional[
|
92
|
+
self.fig: Optional[Figure] = None
|
55
93
|
self.viewer: Optional[RDRCaseViewer] = viewer
|
56
94
|
if self.viewer is not None:
|
57
95
|
self.viewer.set_save_function(self.save)
|
58
96
|
|
97
|
+
def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None) -> str:
|
98
|
+
"""
|
99
|
+
Save the classifier to a file.
|
100
|
+
|
101
|
+
:param save_dir: The directory to save the classifier to.
|
102
|
+
:param model_name: The name of the model to save. If None, a default name is generated.
|
103
|
+
:param postfix: The postfix to add to the file name.
|
104
|
+
:return: The name of the saved model.
|
105
|
+
"""
|
106
|
+
save_dir = save_dir or self.save_dir
|
107
|
+
if save_dir is None:
|
108
|
+
raise ValueError("The save directory cannot be None. Please provide a valid directory to save"
|
109
|
+
" the classifier.")
|
110
|
+
if not os.path.exists(save_dir + '/__init__.py'):
|
111
|
+
os.makedirs(save_dir, exist_ok=True)
|
112
|
+
with open(save_dir + '/__init__.py', 'w') as f:
|
113
|
+
f.write("# This is an empty __init__.py file to make the directory a package.\n")
|
114
|
+
if model_name is not None:
|
115
|
+
self.model_name = model_name
|
116
|
+
elif self.model_name is None:
|
117
|
+
self.model_name = self.generated_python_file_name
|
118
|
+
model_dir = os.path.join(save_dir, self.model_name)
|
119
|
+
os.makedirs(model_dir, exist_ok=True)
|
120
|
+
json_dir = os.path.join(model_dir, self.metadata_folder)
|
121
|
+
os.makedirs(json_dir, exist_ok=True)
|
122
|
+
self.to_json_file(os.path.join(json_dir, self.model_name))
|
123
|
+
self._write_to_python(model_dir)
|
124
|
+
return self.model_name
|
125
|
+
|
126
|
+
@classmethod
|
127
|
+
def load(cls, load_dir: str, model_name: str) -> Self:
|
128
|
+
"""
|
129
|
+
Load the classifier from a file.
|
130
|
+
|
131
|
+
:param load_dir: The path to the model directory to load the classifier from.
|
132
|
+
:param model_name: The name of the model to load.
|
133
|
+
"""
|
134
|
+
model_dir = os.path.join(load_dir, model_name)
|
135
|
+
json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
|
136
|
+
rdr = cls.from_json_file(json_file)
|
137
|
+
rdr.update_from_python(model_dir)
|
138
|
+
rdr.save_dir = load_dir
|
139
|
+
rdr.model_name = model_name
|
140
|
+
return rdr
|
141
|
+
|
142
|
+
@abstractmethod
|
143
|
+
def _write_to_python(self, model_dir: str):
|
144
|
+
"""
|
145
|
+
Write the tree of rules as source code to a file.
|
146
|
+
|
147
|
+
:param model_dir: The path to the directory to write the source code to.
|
148
|
+
"""
|
149
|
+
pass
|
150
|
+
|
59
151
|
def set_viewer(self, viewer: RDRCaseViewer):
|
60
152
|
"""
|
61
153
|
Set the viewer for the classifier.
|
@@ -82,6 +174,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
82
174
|
"""
|
83
175
|
targets = []
|
84
176
|
if animate_tree:
|
177
|
+
if plt is None:
|
178
|
+
raise ImportError("matplotlib is not installed, cannot animate the tree.")
|
85
179
|
plt.ion()
|
86
180
|
i = 0
|
87
181
|
stop_iterating = False
|
@@ -104,13 +198,13 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
104
198
|
all_predictions = [1 if is_matching(self.classify, case_query) else 0 for case_query in case_queries
|
105
199
|
if case_query.target is not None]
|
106
200
|
all_pred = sum(all_predictions)
|
107
|
-
|
201
|
+
logger.info(f"Accuracy: {all_pred}/{len(targets)}")
|
108
202
|
all_predicted = targets and all_pred == len(targets)
|
109
203
|
num_iter_reached = n_iter and i >= n_iter
|
110
204
|
stop_iterating = all_predicted or num_iter_reached
|
111
205
|
if stop_iterating:
|
112
206
|
break
|
113
|
-
|
207
|
+
logger.info(f"Finished training in {i} iterations")
|
114
208
|
if animate_tree:
|
115
209
|
plt.ioff()
|
116
210
|
plt.show()
|
@@ -142,18 +236,29 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
142
236
|
"""
|
143
237
|
if case_query is None:
|
144
238
|
raise ValueError("The case query cannot be None.")
|
239
|
+
|
145
240
|
self.name = case_query.attribute_name if self.name is None else self.name
|
241
|
+
self.case_type = case_query.case_type if self.case_type is None else self.case_type
|
242
|
+
self.case_name = case_query.case_name if self.case_name is None else self.case_name
|
243
|
+
|
146
244
|
if case_query.target is None:
|
147
245
|
case_query_cp = copy(case_query)
|
148
|
-
self.classify(case_query_cp.case, modify_case=True)
|
149
|
-
|
150
|
-
|
246
|
+
conclusions = self.classify(case_query_cp.case, modify_case=True)
|
247
|
+
if self.ask_always or conclusions is None or is_iterable(conclusions) and len(conclusions) == 0:
|
248
|
+
expert.ask_for_conclusion(case_query_cp)
|
249
|
+
case_query.target = case_query_cp.target
|
151
250
|
if case_query.target is None:
|
152
251
|
return self.classify(case_query.case)
|
153
252
|
|
154
253
|
self.update_start_rule(case_query, expert)
|
155
254
|
|
156
|
-
|
255
|
+
fit_case_result = self._fit_case(case_query, expert=expert, **kwargs)
|
256
|
+
|
257
|
+
if self.save_dir is not None:
|
258
|
+
self.save()
|
259
|
+
expert.clear_answers()
|
260
|
+
|
261
|
+
return fit_case_result
|
157
262
|
|
158
263
|
@abstractmethod
|
159
264
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
@@ -219,28 +324,54 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
219
324
|
pass
|
220
325
|
|
221
326
|
@abstractmethod
|
222
|
-
def
|
327
|
+
def update_from_python(self, model_dir: str):
|
223
328
|
"""
|
224
329
|
Update the rules from the generated python file, that might have been modified by the user.
|
225
330
|
|
226
|
-
:param
|
331
|
+
:param model_dir: The directory where the generated python file is located.
|
227
332
|
"""
|
228
333
|
pass
|
229
334
|
|
335
|
+
@classmethod
|
336
|
+
def get_acronym(cls) -> str:
|
337
|
+
"""
|
338
|
+
:return: The acronym of the classifier.
|
339
|
+
"""
|
340
|
+
if cls.__name__ == "GeneralRDR":
|
341
|
+
return "RDR"
|
342
|
+
elif cls.__name__ == "MultiClassRDR":
|
343
|
+
return "MCRDR"
|
344
|
+
else:
|
345
|
+
return "SCRDR"
|
346
|
+
|
347
|
+
def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
|
348
|
+
"""
|
349
|
+
:param package_name: The name of the package that contains the RDR classifier function.
|
350
|
+
:return: The module that contains the rdr classifier function.
|
351
|
+
"""
|
352
|
+
# remove from imports if exists first
|
353
|
+
name = f"{package_name.strip('./').replace('/', '.')}.{self.generated_python_file_name}"
|
354
|
+
try:
|
355
|
+
module = importlib.import_module(name)
|
356
|
+
del sys.modules[name]
|
357
|
+
except ModuleNotFoundError:
|
358
|
+
pass
|
359
|
+
return importlib.import_module(name).classify
|
360
|
+
|
230
361
|
|
231
362
|
class RDRWithCodeWriter(RippleDownRules, ABC):
|
232
363
|
|
233
|
-
def
|
364
|
+
def update_from_python(self, model_dir: str):
|
234
365
|
"""
|
235
366
|
Update the rules from the generated python file, that might have been modified by the user.
|
236
367
|
|
237
|
-
:param
|
368
|
+
:param model_dir: The directory where the generated python file is located.
|
238
369
|
"""
|
239
|
-
|
240
|
-
condition_func_names = [f'conditions_{rid}' for rid in
|
241
|
-
conclusion_func_names = [f'conclusion_{rid}' for rid in
|
370
|
+
rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants) if r.conditions is not None}
|
371
|
+
condition_func_names = [f'conditions_{rid}' for rid in rules_dict.keys()]
|
372
|
+
conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys() if not isinstance(rules_dict[rid], MultiClassStopRule)]
|
242
373
|
all_func_names = condition_func_names + conclusion_func_names
|
243
|
-
filepath = f"{
|
374
|
+
filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
|
244
375
|
functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
|
245
376
|
# get the scope from the imports in the file
|
246
377
|
scope = extract_imports(filepath)
|
@@ -248,7 +379,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
248
379
|
if rule.conditions is not None:
|
249
380
|
rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
|
250
381
|
rule.conditions.scope = scope
|
251
|
-
if rule.conclusion is not None:
|
382
|
+
if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
|
252
383
|
rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
|
253
384
|
rule.conclusion.scope = scope
|
254
385
|
|
@@ -265,17 +396,19 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
265
396
|
"""
|
266
397
|
pass
|
267
398
|
|
268
|
-
def
|
399
|
+
def _write_to_python(self, model_dir: str):
|
269
400
|
"""
|
270
401
|
Write the tree of rules as source code to a file.
|
271
402
|
|
272
|
-
:param
|
273
|
-
:param postfix: The postfix to add to the file name.
|
403
|
+
:param model_dir: The path to the directory to write the source code to.
|
274
404
|
"""
|
275
|
-
|
405
|
+
os.makedirs(model_dir, exist_ok=True)
|
406
|
+
if not os.path.exists(model_dir + '/__init__.py'):
|
407
|
+
with open(model_dir + '/__init__.py', 'w') as f:
|
408
|
+
f.write("# This is an empty __init__.py file to make the directory a package.\n")
|
276
409
|
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
277
|
-
file_name =
|
278
|
-
defs_file_name =
|
410
|
+
file_name = model_dir + f"/{self.generated_python_file_name}.py"
|
411
|
+
defs_file_name = model_dir + f"/{self.generated_python_defs_file_name}.py"
|
279
412
|
imports, defs_imports = self._get_imports()
|
280
413
|
# clear the files first
|
281
414
|
with open(defs_file_name, "w") as f:
|
@@ -326,56 +459,19 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
326
459
|
imports = "\n".join(imports) + "\n"
|
327
460
|
return imports, defs_imports
|
328
461
|
|
329
|
-
def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
|
330
|
-
"""
|
331
|
-
:param package_name: The name of the package that contains the RDR classifier function.
|
332
|
-
:return: The module that contains the rdr classifier function.
|
333
|
-
"""
|
334
|
-
# remove from imports if exists first
|
335
|
-
name = f"{package_name.strip('./')}.{self.generated_python_file_name}"
|
336
|
-
try:
|
337
|
-
module = importlib.import_module(name)
|
338
|
-
del sys.modules[name]
|
339
|
-
except ModuleNotFoundError:
|
340
|
-
pass
|
341
|
-
return importlib.import_module(name).classify
|
342
|
-
|
343
462
|
@property
|
344
|
-
def _default_generated_python_file_name(self) -> str:
|
463
|
+
def _default_generated_python_file_name(self) -> Optional[str]:
|
345
464
|
"""
|
346
465
|
:return: The default generated python file name.
|
347
466
|
"""
|
348
|
-
if
|
349
|
-
|
350
|
-
|
351
|
-
name = self.start_rule.corner_case.__class__.__name__
|
352
|
-
return f"{name.lower()}_{self.attribute_name}_{self.acronym.lower()}"
|
467
|
+
if self.start_rule is None or self.start_rule.conclusion is None:
|
468
|
+
return None
|
469
|
+
return f"{str_to_snake_case(self.case_name)}_{self.attribute_name}_{self.get_acronym().lower()}"
|
353
470
|
|
354
471
|
@property
|
355
472
|
def generated_python_defs_file_name(self) -> str:
|
356
473
|
return f"{self.generated_python_file_name}_defs"
|
357
474
|
|
358
|
-
@property
|
359
|
-
def acronym(self) -> str:
|
360
|
-
"""
|
361
|
-
:return: The acronym of the classifier.
|
362
|
-
"""
|
363
|
-
if self.__class__.__name__ == "GeneralRDR":
|
364
|
-
return "GRDR"
|
365
|
-
elif self.__class__.__name__ == "MultiClassRDR":
|
366
|
-
return "MCRDR"
|
367
|
-
else:
|
368
|
-
return "SCRDR"
|
369
|
-
|
370
|
-
@property
|
371
|
-
def case_type(self) -> Type:
|
372
|
-
"""
|
373
|
-
:return: The type of the case (input) to the RDR classifier.
|
374
|
-
"""
|
375
|
-
if isinstance(self.start_rule.corner_case, Case):
|
376
|
-
return self.start_rule.corner_case._obj_type
|
377
|
-
else:
|
378
|
-
return type(self.start_rule.corner_case)
|
379
475
|
|
380
476
|
@property
|
381
477
|
def conclusion_type(self) -> Tuple[Type]:
|
@@ -396,7 +492,9 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
396
492
|
|
397
493
|
def _to_json(self) -> Dict[str, Any]:
|
398
494
|
return {"start_rule": self.start_rule.to_json(), "generated_python_file_name": self.generated_python_file_name,
|
399
|
-
"name": self.name
|
495
|
+
"name": self.name,
|
496
|
+
"case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
|
497
|
+
"case_name": self.case_name}
|
400
498
|
|
401
499
|
@classmethod
|
402
500
|
def _from_json(cls, data: Dict[str, Any]) -> Self:
|
@@ -404,11 +502,15 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
404
502
|
Create an instance of the class from a json
|
405
503
|
"""
|
406
504
|
start_rule = cls.start_rule_type().from_json(data["start_rule"])
|
407
|
-
new_rdr = cls(start_rule)
|
505
|
+
new_rdr = cls(start_rule=start_rule)
|
408
506
|
if "generated_python_file_name" in data:
|
409
507
|
new_rdr.generated_python_file_name = data["generated_python_file_name"]
|
410
508
|
if "name" in data:
|
411
509
|
new_rdr.name = data["name"]
|
510
|
+
if "case_type" in data:
|
511
|
+
new_rdr.case_type = get_type_from_string(data["case_type"])
|
512
|
+
if "case_name" in data:
|
513
|
+
new_rdr.case_name = data["case_name"]
|
412
514
|
return new_rdr
|
413
515
|
|
414
516
|
@staticmethod
|
@@ -422,12 +524,12 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
422
524
|
|
423
525
|
class SingleClassRDR(RDRWithCodeWriter):
|
424
526
|
|
425
|
-
def __init__(self,
|
527
|
+
def __init__(self, default_conclusion: Optional[Any] = None, **kwargs):
|
426
528
|
"""
|
427
529
|
:param start_rule: The starting rule for the classifier.
|
428
530
|
:param default_conclusion: The default conclusion for the classifier if no rules fire.
|
429
531
|
"""
|
430
|
-
super(SingleClassRDR, self).__init__(
|
532
|
+
super(SingleClassRDR, self).__init__(**kwargs)
|
431
533
|
self.default_conclusion: Optional[Any] = default_conclusion
|
432
534
|
|
433
535
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
@@ -479,10 +581,10 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
479
581
|
matched_rule = self.start_rule(case) if self.start_rule is not None else None
|
480
582
|
return matched_rule if matched_rule is not None else self.start_rule
|
481
583
|
|
482
|
-
def
|
483
|
-
super().
|
584
|
+
def _write_to_python(self, model_dir: str):
|
585
|
+
super()._write_to_python(model_dir)
|
484
586
|
if self.default_conclusion is not None:
|
485
|
-
with open(
|
587
|
+
with open(model_dir + f"/{self.generated_python_file_name}.py", "a") as f:
|
486
588
|
f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
|
487
589
|
|
488
590
|
def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
|
@@ -758,7 +860,7 @@ class GeneralRDR(RippleDownRules):
|
|
758
860
|
self.start_rules_dict: Dict[str, Union[SingleClassRDR, MultiClassRDR]] \
|
759
861
|
= category_rdr_map if category_rdr_map else {}
|
760
862
|
super(GeneralRDR, self).__init__(**kwargs)
|
761
|
-
self.all_figs: List[
|
863
|
+
self.all_figs: List[Figure] = [sr.fig for sr in self.start_rules_dict.values()]
|
762
864
|
|
763
865
|
def add_rdr(self, rdr: Union[SingleClassRDR, MultiClassRDR], case_query: Optional[CaseQuery] = None):
|
764
866
|
"""
|
@@ -882,7 +984,9 @@ class GeneralRDR(RippleDownRules):
|
|
882
984
|
def _to_json(self) -> Dict[str, Any]:
|
883
985
|
return {"start_rules": {name: rdr.to_json() for name, rdr in self.start_rules_dict.items()}
|
884
986
|
, "generated_python_file_name": self.generated_python_file_name,
|
885
|
-
"name": self.name
|
987
|
+
"name": self.name,
|
988
|
+
"case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
|
989
|
+
"case_name": self.case_name}
|
886
990
|
|
887
991
|
@classmethod
|
888
992
|
def _from_json(cls, data: Dict[str, Any]) -> GeneralRDR:
|
@@ -892,35 +996,37 @@ class GeneralRDR(RippleDownRules):
|
|
892
996
|
start_rules_dict = {}
|
893
997
|
for k, v in data["start_rules"].items():
|
894
998
|
start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
|
895
|
-
new_rdr = cls(start_rules_dict)
|
999
|
+
new_rdr = cls(category_rdr_map=start_rules_dict)
|
896
1000
|
if "generated_python_file_name" in data:
|
897
1001
|
new_rdr.generated_python_file_name = data["generated_python_file_name"]
|
898
1002
|
if "name" in data:
|
899
1003
|
new_rdr.name = data["name"]
|
1004
|
+
if "case_type" in data:
|
1005
|
+
new_rdr.case_type = get_type_from_string(data["case_type"])
|
1006
|
+
if "case_name" in data:
|
1007
|
+
new_rdr.case_name = data["case_name"]
|
900
1008
|
return new_rdr
|
901
1009
|
|
902
|
-
def
|
1010
|
+
def update_from_python(self, model_dir: str) -> None:
|
903
1011
|
"""
|
904
1012
|
Update the rules from the generated python file, that might have been modified by the user.
|
905
1013
|
|
906
|
-
:param
|
1014
|
+
:param model_dir: The directory where the model is stored.
|
907
1015
|
"""
|
908
1016
|
for rdr in self.start_rules_dict.values():
|
909
|
-
rdr.
|
1017
|
+
rdr.update_from_python(model_dir)
|
910
1018
|
|
911
|
-
def
|
1019
|
+
def _write_to_python(self, model_dir: str) -> None:
|
912
1020
|
"""
|
913
1021
|
Write the tree of rules as source code to a file.
|
914
1022
|
|
915
|
-
:param
|
916
|
-
:param postfix: The postfix to add to the file name.
|
1023
|
+
:param model_dir: The directory where the model is stored.
|
917
1024
|
"""
|
918
|
-
self.generated_python_file_name = self._default_generated_python_file_name + postfix
|
919
1025
|
for rdr in self.start_rules_dict.values():
|
920
|
-
rdr.
|
1026
|
+
rdr._write_to_python(model_dir)
|
921
1027
|
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
922
|
-
with open(
|
923
|
-
f.write(self._get_imports(
|
1028
|
+
with open(model_dir + f"/{self.generated_python_file_name}.py", "w") as f:
|
1029
|
+
f.write(self._get_imports() + "\n\n")
|
924
1030
|
f.write("classifiers_dict = dict()\n")
|
925
1031
|
for rdr_key, rdr in self.start_rules_dict.items():
|
926
1032
|
f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
|
@@ -930,25 +1036,6 @@ class GeneralRDR(RippleDownRules):
|
|
930
1036
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
931
1037
|
f.write(f"{' ' * 4}return GeneralRDR._classify(classifiers_dict, case)\n")
|
932
1038
|
|
933
|
-
@property
|
934
|
-
def case_type(self) -> Optional[Type]:
|
935
|
-
"""
|
936
|
-
:return: The type of the case (input) to the RDR classifier.
|
937
|
-
"""
|
938
|
-
if self.start_rule is None or self.start_rule.corner_case is None:
|
939
|
-
return None
|
940
|
-
if isinstance(self.start_rule.corner_case, Case):
|
941
|
-
return self.start_rule.corner_case._obj_type
|
942
|
-
else:
|
943
|
-
return type(self.start_rule.corner_case)
|
944
|
-
|
945
|
-
def get_rdr_classifier_from_python_file(self, file_path: str) -> Callable[[Any], Any]:
|
946
|
-
"""
|
947
|
-
:param file_path: The path to the file that contains the RDR classifier function.
|
948
|
-
:return: The module that contains the rdr classifier function.
|
949
|
-
"""
|
950
|
-
return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
|
951
|
-
|
952
1039
|
@property
|
953
1040
|
def _default_generated_python_file_name(self) -> Optional[str]:
|
954
1041
|
"""
|
@@ -956,21 +1043,16 @@ class GeneralRDR(RippleDownRules):
|
|
956
1043
|
"""
|
957
1044
|
if self.start_rule is None or self.start_rule.conclusion is None:
|
958
1045
|
return None
|
959
|
-
|
960
|
-
name = self.start_rule.corner_case._name
|
961
|
-
else:
|
962
|
-
name = self.start_rule.corner_case.__class__.__name__
|
963
|
-
return f"{name}_rdr".lower()
|
1046
|
+
return f"{str_to_snake_case(self.case_name)}_rdr".lower()
|
964
1047
|
|
965
1048
|
@property
|
966
1049
|
def conclusion_type_hint(self) -> str:
|
967
1050
|
return f"Dict[str, Any]"
|
968
1051
|
|
969
|
-
def _get_imports(self
|
1052
|
+
def _get_imports(self) -> str:
|
970
1053
|
"""
|
971
1054
|
Get the imports needed for the generated python file.
|
972
1055
|
|
973
|
-
:param file_path: The path to the file that contains the RDR classifier function.
|
974
1056
|
:return: The imports needed for the generated python file.
|
975
1057
|
"""
|
976
1058
|
imports = ""
|