ripple-down-rules 0.5.5__py3-none-any.whl → 0.5.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +1 -1
- ripple_down_rules/datastructures/callable_expression.py +16 -9
- ripple_down_rules/datastructures/case.py +10 -4
- ripple_down_rules/datastructures/dataclasses.py +62 -3
- ripple_down_rules/experts.py +12 -2
- ripple_down_rules/helpers.py +55 -9
- ripple_down_rules/rdr.py +269 -180
- ripple_down_rules/rdr_decorators.py +60 -31
- ripple_down_rules/rules.py +69 -13
- ripple_down_rules/user_interface/gui.py +10 -7
- ripple_down_rules/user_interface/ipython_custom_shell.py +1 -1
- ripple_down_rules/user_interface/object_diagram.py +9 -1
- ripple_down_rules/user_interface/template_file_creator.py +25 -25
- ripple_down_rules/utils.py +330 -79
- {ripple_down_rules-0.5.5.dist-info → ripple_down_rules-0.5.8.dist-info}/METADATA +2 -1
- ripple_down_rules-0.5.8.dist-info/RECORD +24 -0
- ripple_down_rules-0.5.5.dist-info/RECORD +0 -24
- {ripple_down_rules-0.5.5.dist-info → ripple_down_rules-0.5.8.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.5.5.dist-info → ripple_down_rules-0.5.8.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.5.5.dist-info → ripple_down_rules-0.5.8.dist-info}/top_level.txt +0 -0
ripple_down_rules/rdr.py
CHANGED
@@ -1,18 +1,16 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import copyreg
|
4
3
|
import importlib
|
5
4
|
import os
|
6
|
-
|
7
|
-
from . import logger
|
8
|
-
import sys
|
9
5
|
from abc import ABC, abstractmethod
|
10
6
|
from copy import copy
|
11
|
-
|
12
|
-
from
|
7
|
+
|
8
|
+
from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
|
9
|
+
from . import logger
|
13
10
|
|
14
11
|
try:
|
15
12
|
from matplotlib import pyplot as plt
|
13
|
+
|
16
14
|
Figure = plt.Figure
|
17
15
|
except ImportError as e:
|
18
16
|
logger.debug(f"{e}: matplotlib is not installed")
|
@@ -28,16 +26,16 @@ from .datastructures.case import Case, CaseAttribute, create_case
|
|
28
26
|
from .datastructures.dataclasses import CaseQuery
|
29
27
|
from .datastructures.enums import MCRDRMode
|
30
28
|
from .experts import Expert, Human
|
31
|
-
from .helpers import is_matching
|
29
|
+
from .helpers import is_matching, general_rdr_classify
|
32
30
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
31
|
+
|
33
32
|
try:
|
34
33
|
from .user_interface.gui import RDRCaseViewer
|
35
34
|
except ImportError as e:
|
36
35
|
RDRCaseViewer = None
|
37
|
-
from .utils import draw_tree, make_set,
|
38
|
-
|
39
|
-
|
40
|
-
is_iterable, str_to_snake_case
|
36
|
+
from .utils import draw_tree, make_set, SubclassJSONSerializer, make_list, get_type_from_string, \
|
37
|
+
is_conflicting, extract_function_source, extract_imports, get_full_class_name, \
|
38
|
+
is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types
|
41
39
|
|
42
40
|
|
43
41
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -76,16 +74,18 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
76
74
|
"""
|
77
75
|
The name of the model. If None, the model name will be the generated python file name.
|
78
76
|
"""
|
77
|
+
mutually_exclusive: Optional[bool] = None
|
78
|
+
"""
|
79
|
+
Whether the output of the classification of this rdr allows only one possible conclusion or not.
|
80
|
+
"""
|
79
81
|
|
80
82
|
def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None,
|
81
|
-
save_dir: Optional[str] = None,
|
83
|
+
save_dir: Optional[str] = None, model_name: Optional[str] = None):
|
82
84
|
"""
|
83
85
|
:param start_rule: The starting rule for the classifier.
|
84
86
|
:param viewer: The viewer gui to use for the classifier. If None, no viewer is used.
|
85
87
|
:param save_dir: The directory to save the classifier to.
|
86
|
-
:param ask_always: Whether to always ask the expert (True) or only ask when classification fails (False).
|
87
88
|
"""
|
88
|
-
self.ask_always: bool = ask_always
|
89
89
|
self.model_name: Optional[str] = model_name
|
90
90
|
self.save_dir = save_dir
|
91
91
|
self.start_rule = start_rule
|
@@ -94,13 +94,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
94
94
|
if self.viewer is not None:
|
95
95
|
self.viewer.set_save_function(self.save)
|
96
96
|
|
97
|
-
def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None
|
97
|
+
def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None,
|
98
|
+
package_name: Optional[str] = None) -> str:
|
98
99
|
"""
|
99
100
|
Save the classifier to a file.
|
100
101
|
|
101
102
|
:param save_dir: The directory to save the classifier to.
|
102
103
|
:param model_name: The name of the model to save. If None, a default name is generated.
|
103
|
-
:param
|
104
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
105
|
+
is required in case of relative imports in the generated python file.
|
104
106
|
:return: The name of the saved model.
|
105
107
|
"""
|
106
108
|
save_dir = save_dir or self.save_dir
|
@@ -110,7 +112,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
110
112
|
if not os.path.exists(save_dir + '/__init__.py'):
|
111
113
|
os.makedirs(save_dir, exist_ok=True)
|
112
114
|
with open(save_dir + '/__init__.py', 'w') as f:
|
113
|
-
f.write("
|
115
|
+
f.write("from . import *\n")
|
114
116
|
if model_name is not None:
|
115
117
|
self.model_name = model_name
|
116
118
|
elif self.model_name is None:
|
@@ -120,31 +122,40 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
120
122
|
json_dir = os.path.join(model_dir, self.metadata_folder)
|
121
123
|
os.makedirs(json_dir, exist_ok=True)
|
122
124
|
self.to_json_file(os.path.join(json_dir, self.model_name))
|
123
|
-
self._write_to_python(model_dir)
|
125
|
+
self._write_to_python(model_dir, package_name=package_name)
|
124
126
|
return self.model_name
|
125
127
|
|
126
128
|
@classmethod
|
127
|
-
def load(cls, load_dir: str, model_name: str
|
129
|
+
def load(cls, load_dir: str, model_name: str,
|
130
|
+
package_name: Optional[str] = None) -> Self:
|
128
131
|
"""
|
129
132
|
Load the classifier from a file.
|
130
133
|
|
131
134
|
:param load_dir: The path to the model directory to load the classifier from.
|
132
135
|
:param model_name: The name of the model to load.
|
136
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
137
|
+
is required in case of relative imports in the generated python file.
|
133
138
|
"""
|
134
139
|
model_dir = os.path.join(load_dir, model_name)
|
135
140
|
json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
|
136
141
|
rdr = cls.from_json_file(json_file)
|
137
|
-
|
142
|
+
try:
|
143
|
+
rdr.update_from_python(model_dir, package_name=package_name)
|
144
|
+
except (FileNotFoundError, ValueError) as e:
|
145
|
+
logger.warning(f"Could not load the python file for the model {model_name} from {model_dir}. "
|
146
|
+
f"Make sure the file exists and is valid.")
|
138
147
|
rdr.save_dir = load_dir
|
139
148
|
rdr.model_name = model_name
|
140
149
|
return rdr
|
141
150
|
|
142
151
|
@abstractmethod
|
143
|
-
def _write_to_python(self, model_dir: str):
|
152
|
+
def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
|
144
153
|
"""
|
145
154
|
Write the tree of rules as source code to a file.
|
146
155
|
|
147
156
|
:param model_dir: The path to the directory to write the source code to.
|
157
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
158
|
+
is required in case of relative imports in the generated python file.
|
148
159
|
"""
|
149
160
|
pass
|
150
161
|
|
@@ -213,18 +224,24 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
213
224
|
return self.classify(case)
|
214
225
|
|
215
226
|
@abstractmethod
|
216
|
-
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False
|
227
|
+
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
|
228
|
+
case_query: Optional[CaseQuery] = None) \
|
217
229
|
-> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
|
218
230
|
"""
|
219
231
|
Classify a case.
|
220
232
|
|
221
233
|
:param case: The case to classify.
|
222
234
|
:param modify_case: Whether to modify the original case attributes with the conclusion or not.
|
235
|
+
:param case_query: The case query containing the case to classify and the target category to compare the case with.
|
223
236
|
:return: The category that the case belongs to.
|
224
237
|
"""
|
225
238
|
pass
|
226
239
|
|
227
|
-
def fit_case(self, case_query: CaseQuery,
|
240
|
+
def fit_case(self, case_query: CaseQuery,
|
241
|
+
expert: Optional[Expert] = None,
|
242
|
+
update_existing_rules: bool = True,
|
243
|
+
scenario: Optional[Callable] = None,
|
244
|
+
**kwargs) \
|
228
245
|
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
229
246
|
"""
|
230
247
|
Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
|
@@ -232,6 +249,9 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
232
249
|
|
233
250
|
:param case_query: The query containing the case to classify and the target category to compare the case with.
|
234
251
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
252
|
+
:param update_existing_rules: Whether to update the existing same conclusion type rules that already gave
|
253
|
+
some conclusions with the type required by the case query.
|
254
|
+
:param scenario: The scenario at which the case was created, this is used to recreate the case if needed.
|
235
255
|
:return: The category that the case belongs to.
|
236
256
|
"""
|
237
257
|
if case_query is None:
|
@@ -240,13 +260,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
240
260
|
self.name = case_query.attribute_name if self.name is None else self.name
|
241
261
|
self.case_type = case_query.case_type if self.case_type is None else self.case_type
|
242
262
|
self.case_name = case_query.case_name if self.case_name is None else self.case_name
|
263
|
+
case_query.scenario = scenario if case_query.scenario is None else case_query.scenario
|
243
264
|
|
244
|
-
expert = expert or Human(
|
245
|
-
|
265
|
+
expert = expert or Human(viewer=self.viewer,
|
266
|
+
answers_save_path=self.save_dir + '/expert_answers'
|
267
|
+
if self.save_dir else None)
|
246
268
|
if case_query.target is None:
|
247
269
|
case_query_cp = copy(case_query)
|
248
|
-
conclusions = self.classify(case_query_cp.case, modify_case=True)
|
249
|
-
if self.
|
270
|
+
conclusions = self.classify(case_query_cp.case, modify_case=True, case_query=case_query_cp)
|
271
|
+
if self.should_i_ask_the_expert_for_a_target(conclusions, case_query_cp, update_existing_rules):
|
250
272
|
expert.ask_for_conclusion(case_query_cp)
|
251
273
|
case_query.target = case_query_cp.target
|
252
274
|
if case_query.target is None:
|
@@ -262,6 +284,34 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
262
284
|
|
263
285
|
return fit_case_result
|
264
286
|
|
287
|
+
@staticmethod
|
288
|
+
def should_i_ask_the_expert_for_a_target(conclusions: Union[Any, Dict[str, Any]],
|
289
|
+
case_query: CaseQuery,
|
290
|
+
update_existing: bool) -> bool:
|
291
|
+
"""
|
292
|
+
Determine if the rdr should ask the expert for the target of a given case query.
|
293
|
+
|
294
|
+
:param conclusions: The conclusions of the case.
|
295
|
+
:param case_query: The query containing the case to classify.
|
296
|
+
:param update_existing: Whether to update rules that gave the required type of conclusions.
|
297
|
+
:return: True if the rdr should ask the expert, False otherwise.
|
298
|
+
"""
|
299
|
+
if conclusions is None:
|
300
|
+
return True
|
301
|
+
elif is_iterable(conclusions) and len(conclusions) == 0:
|
302
|
+
return True
|
303
|
+
elif isinstance(conclusions, dict):
|
304
|
+
if case_query.attribute_name not in conclusions:
|
305
|
+
return True
|
306
|
+
conclusions = conclusions[case_query.attribute_name]
|
307
|
+
conclusion_types = map(type, make_list(conclusions))
|
308
|
+
if not any(ct in case_query.core_attribute_type for ct in conclusion_types):
|
309
|
+
return True
|
310
|
+
elif update_existing:
|
311
|
+
return True
|
312
|
+
else:
|
313
|
+
return False
|
314
|
+
|
265
315
|
@abstractmethod
|
266
316
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
267
317
|
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
@@ -326,11 +376,13 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
326
376
|
pass
|
327
377
|
|
328
378
|
@abstractmethod
|
329
|
-
def update_from_python(self, model_dir: str):
|
379
|
+
def update_from_python(self, model_dir: str, package_name: Optional[str] = None):
|
330
380
|
"""
|
331
381
|
Update the rules from the generated python file, that might have been modified by the user.
|
332
382
|
|
333
383
|
:param model_dir: The directory where the generated python file is located.
|
384
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
385
|
+
is required in case of relative imports in the generated python file.
|
334
386
|
"""
|
335
387
|
pass
|
336
388
|
|
@@ -352,42 +404,53 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
352
404
|
:return: The module that contains the rdr classifier function.
|
353
405
|
"""
|
354
406
|
# remove from imports if exists first
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
pass
|
361
|
-
return importlib.import_module(name).classify
|
407
|
+
package_name = get_import_path_from_path(package_name)
|
408
|
+
name = f"{package_name}.{self.generated_python_file_name}" if package_name else self.generated_python_file_name
|
409
|
+
module = importlib.import_module(name)
|
410
|
+
importlib.reload(module)
|
411
|
+
return module.classify
|
362
412
|
|
363
413
|
|
364
414
|
class RDRWithCodeWriter(RippleDownRules, ABC):
|
365
415
|
|
366
|
-
def update_from_python(self, model_dir: str):
|
416
|
+
def update_from_python(self, model_dir: str, package_name: Optional[str] = None):
|
367
417
|
"""
|
368
418
|
Update the rules from the generated python file, that might have been modified by the user.
|
369
419
|
|
370
420
|
:param model_dir: The directory where the generated python file is located.
|
421
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
422
|
+
is required in case of relative imports in the generated python file.
|
371
423
|
"""
|
372
|
-
rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants)
|
424
|
+
rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants)
|
425
|
+
if r.conditions is not None}
|
373
426
|
condition_func_names = [f'conditions_{rid}' for rid in rules_dict.keys()]
|
374
|
-
conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys()
|
427
|
+
conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys()
|
428
|
+
if not isinstance(rules_dict[rid], MultiClassStopRule)]
|
375
429
|
all_func_names = condition_func_names + conclusion_func_names
|
376
430
|
filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
|
431
|
+
cases_path = f"{model_dir}/{self.generated_python_cases_file_name}.py"
|
432
|
+
cases_import_path = get_import_path_from_path(model_dir)
|
433
|
+
cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path \
|
434
|
+
else self.generated_python_cases_file_name
|
377
435
|
functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
|
378
436
|
# get the scope from the imports in the file
|
379
|
-
scope = extract_imports(filepath)
|
437
|
+
scope = extract_imports(filepath, package_name=package_name)
|
380
438
|
for rule in [self.start_rule] + list(self.start_rule.descendants):
|
381
439
|
if rule.conditions is not None:
|
382
440
|
rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
|
383
441
|
rule.conditions.scope = scope
|
442
|
+
if os.path.exists(cases_path):
|
443
|
+
module = importlib.import_module(cases_import_path, package=package_name)
|
444
|
+
importlib.reload(module)
|
445
|
+
rule.corner_case_metadata = module.__dict__.get(f"corner_case_{rule.uid}", None)
|
384
446
|
if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
|
385
447
|
rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
|
386
448
|
rule.conclusion.scope = scope
|
387
449
|
|
388
450
|
@abstractmethod
|
389
451
|
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
|
390
|
-
defs_file: Optional[str] = None
|
452
|
+
defs_file: Optional[str] = None, cases_file: Optional[str] = None,
|
453
|
+
package_name: Optional[str] = None):
|
391
454
|
"""
|
392
455
|
Write the rules as source code to a file.
|
393
456
|
|
@@ -395,37 +458,63 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
395
458
|
:param file: The file to write the source code to.
|
396
459
|
:param parent_indent: The indentation of the parent rule.
|
397
460
|
:param defs_file: The file to write the definitions to.
|
461
|
+
:param cases_file: The file to write the cases to.
|
462
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
463
|
+
is required in case of relative imports in the generated python file.
|
398
464
|
"""
|
399
465
|
pass
|
400
466
|
|
401
|
-
def _write_to_python(self, model_dir: str):
|
467
|
+
def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
|
402
468
|
"""
|
403
469
|
Write the tree of rules as source code to a file.
|
404
470
|
|
405
471
|
:param model_dir: The path to the directory to write the source code to.
|
472
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
473
|
+
is required in case of relative imports in the generated python file.
|
406
474
|
"""
|
475
|
+
# Make sure the model directory exists and create an __init__.py file if it doesn't exist
|
407
476
|
os.makedirs(model_dir, exist_ok=True)
|
408
477
|
if not os.path.exists(model_dir + '/__init__.py'):
|
409
478
|
with open(model_dir + '/__init__.py', 'w') as f:
|
410
|
-
f.write("
|
411
|
-
|
479
|
+
f.write("from . import *\n")
|
480
|
+
|
481
|
+
# Set the file names for the generated python files
|
412
482
|
file_name = model_dir + f"/{self.generated_python_file_name}.py"
|
413
483
|
defs_file_name = model_dir + f"/{self.generated_python_defs_file_name}.py"
|
414
|
-
|
415
|
-
|
484
|
+
cases_file_name = model_dir + f"/{self.generated_python_cases_file_name}.py"
|
485
|
+
|
486
|
+
# Get the required imports for the main file and the defs file
|
487
|
+
main_types, defs_types, corner_cases_types = self._get_types_to_import()
|
488
|
+
imports = get_imports_from_types(main_types, file_name, package_name)
|
489
|
+
defs_imports = get_imports_from_types(defs_types, defs_file_name, package_name)
|
490
|
+
corner_cases_imports = get_imports_from_types(corner_cases_types, cases_file_name, package_name)
|
491
|
+
|
492
|
+
# Add the imports to the defs file
|
416
493
|
with open(defs_file_name, "w") as f:
|
417
|
-
f.write(defs_imports + "\n\n")
|
494
|
+
f.write('\n'.join(defs_imports) + "\n\n\n")
|
495
|
+
|
496
|
+
# Add the imports to the cases file
|
497
|
+
case_factory_import = get_imports_from_types([CaseFactoryMetaData], cases_file_name, package_name)
|
498
|
+
corner_cases_imports.extend(case_factory_import)
|
499
|
+
with open(cases_file_name, "w") as cases_f:
|
500
|
+
cases_f.write("# This file contains the corner cases for the rules.\n")
|
501
|
+
cases_f.write('\n'.join(corner_cases_imports) + "\n\n\n")
|
502
|
+
|
503
|
+
# Add the imports, the attributes, and the function definition to the main file
|
504
|
+
func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
|
418
505
|
with open(file_name, "w") as f:
|
419
|
-
imports
|
420
|
-
imports
|
421
|
-
f.write(imports + "\n\n")
|
506
|
+
imports.append(f"from .{self.generated_python_defs_file_name} import *")
|
507
|
+
f.write('\n'.join(imports) + "\n\n\n")
|
422
508
|
f.write(f"attribute_name = '{self.attribute_name}'\n")
|
423
509
|
f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
|
424
|
-
f.write(f"
|
510
|
+
f.write(f"mutually_exclusive = {self.mutually_exclusive}\n")
|
425
511
|
f.write(f"\n\n{func_def}")
|
426
512
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
427
513
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
428
|
-
|
514
|
+
|
515
|
+
# Write the rules as source code to the main file
|
516
|
+
self.write_rules_as_source_code_to_file(self.start_rule, file_name, " " * 4, defs_file=defs_file_name,
|
517
|
+
cases_file=cases_file_name, package_name=package_name)
|
429
518
|
|
430
519
|
@property
|
431
520
|
@abstractmethod
|
@@ -435,31 +524,27 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
435
524
|
"""
|
436
525
|
pass
|
437
526
|
|
438
|
-
def
|
527
|
+
def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
|
439
528
|
"""
|
440
|
-
:return: The
|
529
|
+
:return: The types of the main, defs, and corner cases files of the RDR classifier that will be imported.
|
441
530
|
"""
|
442
|
-
|
531
|
+
defs_types = set()
|
532
|
+
cases_types = set()
|
443
533
|
for rule in [self.start_rule] + list(self.start_rule.descendants):
|
444
534
|
if not rule.conditions:
|
445
535
|
continue
|
446
536
|
for scope in [rule.conditions.scope, rule.conclusion.scope]:
|
447
537
|
if scope is None:
|
448
538
|
continue
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
imports.append(f"from {conclusion_type.__module__} import {conclusion_type.__name__}")
|
459
|
-
imports.append("from ripple_down_rules.datastructures.case import Case, create_case")
|
460
|
-
imports = set(imports).difference(defs_imports_list)
|
461
|
-
imports = "\n".join(imports) + "\n"
|
462
|
-
return imports, defs_imports
|
539
|
+
defs_types.update(make_set(scope.values()))
|
540
|
+
cases_types.update(rule.get_corner_case_types_to_import())
|
541
|
+
defs_types.add(self.case_type)
|
542
|
+
main_types = set()
|
543
|
+
main_types.add(self.case_type)
|
544
|
+
main_types.update(make_set(self.conclusion_type))
|
545
|
+
main_types.update({Case, create_case})
|
546
|
+
main_types = main_types.difference(defs_types)
|
547
|
+
return main_types, defs_types, cases_types
|
463
548
|
|
464
549
|
@property
|
465
550
|
def _default_generated_python_file_name(self) -> Optional[str]:
|
@@ -474,6 +559,9 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
474
559
|
def generated_python_defs_file_name(self) -> str:
|
475
560
|
return f"{self.generated_python_file_name}_defs"
|
476
561
|
|
562
|
+
@property
|
563
|
+
def generated_python_cases_file_name(self) -> str:
|
564
|
+
return f"{self.generated_python_file_name}_cases"
|
477
565
|
|
478
566
|
@property
|
479
567
|
def conclusion_type(self) -> Tuple[Type]:
|
@@ -493,7 +581,8 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
493
581
|
return self.start_rule.conclusion_name
|
494
582
|
|
495
583
|
def _to_json(self) -> Dict[str, Any]:
|
496
|
-
return {"start_rule": self.start_rule.to_json(),
|
584
|
+
return {"start_rule": self.start_rule.to_json(),
|
585
|
+
"generated_python_file_name": self.generated_python_file_name,
|
497
586
|
"name": self.name,
|
498
587
|
"case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
|
499
588
|
"case_name": self.case_name}
|
@@ -525,6 +614,10 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
525
614
|
|
526
615
|
|
527
616
|
class SingleClassRDR(RDRWithCodeWriter):
|
617
|
+
mutually_exclusive: bool = True
|
618
|
+
"""
|
619
|
+
The output of the classification of this rdr negates all other possible outputs, there can only be one true value.
|
620
|
+
"""
|
528
621
|
|
529
622
|
def __init__(self, default_conclusion: Optional[Any] = None, **kwargs):
|
530
623
|
"""
|
@@ -550,7 +643,7 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
550
643
|
pred = self.evaluate(case_query.case)
|
551
644
|
if pred.conclusion(case_query.case) != case_query.target_value:
|
552
645
|
expert.ask_for_conditions(case_query, pred)
|
553
|
-
pred.fit_rule(case_query
|
646
|
+
pred.fit_rule(case_query)
|
554
647
|
|
555
648
|
return self.classify(case_query.case)
|
556
649
|
|
@@ -563,18 +656,24 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
563
656
|
"""
|
564
657
|
if not self.start_rule:
|
565
658
|
expert.ask_for_conditions(case_query)
|
566
|
-
self.start_rule = SingleClassRule(case_query
|
567
|
-
conclusion_name=case_query.attribute_name)
|
659
|
+
self.start_rule = SingleClassRule.from_case_query(case_query)
|
568
660
|
|
569
|
-
def classify(self, case: Case, modify_case: bool = False
|
661
|
+
def classify(self, case: Case, modify_case: bool = False,
|
662
|
+
case_query: Optional[CaseQuery] = None) -> Optional[Any]:
|
570
663
|
"""
|
571
664
|
Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
|
572
665
|
|
573
666
|
:param case: The case to classify.
|
574
667
|
:param modify_case: Whether to modify the original case attributes with the conclusion or not.
|
668
|
+
:param case_query: The case query containing the case and the target category to compare the case with.
|
575
669
|
"""
|
576
670
|
pred = self.evaluate(case)
|
577
|
-
|
671
|
+
conclusion = pred.conclusion(case) if pred is not None else None
|
672
|
+
if pred is not None and pred.fired and case_query is not None:
|
673
|
+
if pred.corner_case_metadata is None and conclusion is not None \
|
674
|
+
and type(conclusion) in case_query.core_attribute_type:
|
675
|
+
pred.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
676
|
+
return conclusion if pred is not None and pred.fired else self.default_conclusion
|
578
677
|
|
579
678
|
def evaluate(self, case: Case) -> SingleClassRule:
|
580
679
|
"""
|
@@ -583,29 +682,35 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
583
682
|
matched_rule = self.start_rule(case) if self.start_rule is not None else None
|
584
683
|
return matched_rule if matched_rule is not None else self.start_rule
|
585
684
|
|
586
|
-
def _write_to_python(self, model_dir: str):
|
587
|
-
super()._write_to_python(model_dir)
|
685
|
+
def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
|
686
|
+
super()._write_to_python(model_dir, package_name=package_name)
|
588
687
|
if self.default_conclusion is not None:
|
589
688
|
with open(model_dir + f"/{self.generated_python_file_name}.py", "a") as f:
|
590
689
|
f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
|
591
690
|
|
592
|
-
def write_rules_as_source_code_to_file(self, rule: SingleClassRule,
|
593
|
-
defs_file: Optional[str] = None
|
691
|
+
def write_rules_as_source_code_to_file(self, rule: SingleClassRule, filename: str, parent_indent: str = "",
|
692
|
+
defs_file: Optional[str] = None, cases_file: Optional[str] = None,
|
693
|
+
package_name: Optional[str] = None):
|
594
694
|
"""
|
595
695
|
Write the rules as source code to a file.
|
596
696
|
"""
|
597
697
|
if rule.conditions:
|
698
|
+
rule.write_corner_case_as_source_code(cases_file, package_name=package_name)
|
598
699
|
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
599
|
-
|
700
|
+
with open(filename, "a") as file:
|
701
|
+
file.write(if_clause)
|
600
702
|
if rule.refinement:
|
601
|
-
self.write_rules_as_source_code_to_file(rule.refinement,
|
602
|
-
defs_file=defs_file
|
703
|
+
self.write_rules_as_source_code_to_file(rule.refinement, filename, parent_indent + " ",
|
704
|
+
defs_file=defs_file, cases_file=cases_file,
|
705
|
+
package_name=package_name)
|
603
706
|
|
604
707
|
conclusion_call = rule.write_conclusion_as_source_code(parent_indent, defs_file)
|
605
|
-
|
708
|
+
with open(filename, "a") as file:
|
709
|
+
file.write(conclusion_call)
|
606
710
|
|
607
711
|
if rule.alternative:
|
608
|
-
self.write_rules_as_source_code_to_file(rule.alternative,
|
712
|
+
self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
|
713
|
+
cases_file=cases_file, package_name=package_name)
|
609
714
|
|
610
715
|
@property
|
611
716
|
def conclusion_type_hint(self) -> str:
|
@@ -643,23 +748,34 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
643
748
|
"""
|
644
749
|
The conditions of the stopping rule if needed.
|
645
750
|
"""
|
751
|
+
mutually_exclusive: bool = False
|
752
|
+
"""
|
753
|
+
The output of the classification of this rdr allows for more than one true value as conclusion.
|
754
|
+
"""
|
646
755
|
|
647
756
|
def __init__(self, start_rule: Optional[MultiClassTopRule] = None,
|
648
|
-
mode: MCRDRMode = MCRDRMode.StopOnly):
|
757
|
+
mode: MCRDRMode = MCRDRMode.StopOnly, **kwargs):
|
649
758
|
"""
|
650
759
|
:param start_rule: The starting rules for the classifier.
|
651
760
|
:param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
|
652
761
|
"""
|
653
|
-
super(MultiClassRDR, self).__init__(start_rule)
|
762
|
+
super(MultiClassRDR, self).__init__(start_rule, **kwargs)
|
654
763
|
self.mode: MCRDRMode = mode
|
655
764
|
|
656
|
-
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False
|
765
|
+
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
|
766
|
+
case_query: Optional[CaseQuery] = None) -> Set[Any]:
|
657
767
|
evaluated_rule = self.start_rule
|
658
768
|
self.conclusions = []
|
659
769
|
while evaluated_rule:
|
660
770
|
next_rule = evaluated_rule(case)
|
661
771
|
if evaluated_rule.fired:
|
662
|
-
|
772
|
+
rule_conclusion = evaluated_rule.conclusion(case)
|
773
|
+
if evaluated_rule.corner_case_metadata is None and case_query is not None:
|
774
|
+
if rule_conclusion is not None and len(make_list(rule_conclusion)) > 0 \
|
775
|
+
and any(
|
776
|
+
ct in case_query.core_attribute_type for ct in map(type, make_list(rule_conclusion))):
|
777
|
+
evaluated_rule.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
778
|
+
self.add_conclusion(rule_conclusion)
|
663
779
|
evaluated_rule = next_rule
|
664
780
|
return make_set(self.conclusions)
|
665
781
|
|
@@ -687,7 +803,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
687
803
|
self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule)
|
688
804
|
else:
|
689
805
|
# Rule fired and target is correct or there is no target to compare
|
690
|
-
self.add_conclusion(
|
806
|
+
self.add_conclusion(rule_conclusion)
|
691
807
|
|
692
808
|
if not next_rule:
|
693
809
|
if not make_set(target_value).issubset(make_set(self.conclusions)):
|
@@ -699,24 +815,32 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
699
815
|
return self.conclusions
|
700
816
|
|
701
817
|
def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
|
702
|
-
|
818
|
+
filename: str, parent_indent: str = "", defs_file: Optional[str] = None,
|
819
|
+
cases_file: Optional[str] = None, package_name: Optional[str] = None):
|
703
820
|
if rule == self.start_rule:
|
704
|
-
|
821
|
+
with open(filename, "a") as file:
|
822
|
+
file.write(f"{parent_indent}conclusions = set()\n")
|
705
823
|
if rule.conditions:
|
824
|
+
rule.write_corner_case_as_source_code(cases_file, package_name=package_name)
|
706
825
|
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
707
|
-
|
826
|
+
with open(filename, "a") as file:
|
827
|
+
file.write(if_clause)
|
708
828
|
conclusion_indent = parent_indent
|
709
829
|
if hasattr(rule, "refinement") and rule.refinement:
|
710
|
-
self.write_rules_as_source_code_to_file(rule.refinement,
|
711
|
-
defs_file=defs_file
|
830
|
+
self.write_rules_as_source_code_to_file(rule.refinement, filename, parent_indent + " ",
|
831
|
+
defs_file=defs_file, cases_file=cases_file,
|
832
|
+
package_name=package_name)
|
712
833
|
conclusion_indent = parent_indent + " " * 4
|
713
|
-
|
834
|
+
with open(filename, "a") as file:
|
835
|
+
file.write(f"{conclusion_indent}else:\n")
|
714
836
|
|
715
837
|
conclusion_call = rule.write_conclusion_as_source_code(conclusion_indent, defs_file)
|
716
|
-
|
838
|
+
with open(filename, "a") as file:
|
839
|
+
file.write(conclusion_call)
|
717
840
|
|
718
841
|
if rule.alternative:
|
719
|
-
self.write_rules_as_source_code_to_file(rule.alternative,
|
842
|
+
self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
|
843
|
+
cases_file=cases_file, package_name=package_name)
|
720
844
|
|
721
845
|
@property
|
722
846
|
def conclusion_type_hint(self) -> str:
|
@@ -726,12 +850,11 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
726
850
|
else:
|
727
851
|
return f"Set[Union[{', '.join(conclusion_types)}]]"
|
728
852
|
|
729
|
-
def
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
return imports, defs_imports
|
853
|
+
def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
|
854
|
+
main_types, defs_types, cases_types = super()._get_types_to_import()
|
855
|
+
main_types.update({Set, Union, make_set})
|
856
|
+
defs_types.add(Union)
|
857
|
+
return main_types, defs_types, cases_types
|
735
858
|
|
736
859
|
def update_start_rule(self, case_query: CaseQuery, expert: Expert):
|
737
860
|
"""
|
@@ -742,8 +865,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
742
865
|
"""
|
743
866
|
if not self.start_rule:
|
744
867
|
conditions = expert.ask_for_conditions(case_query)
|
745
|
-
self.start_rule = MultiClassTopRule(
|
746
|
-
conclusion_name=case_query.attribute_name)
|
868
|
+
self.start_rule = MultiClassTopRule.from_case_query(case_query)
|
747
869
|
|
748
870
|
@property
|
749
871
|
def last_top_rule(self) -> Optional[MultiClassTopRule]:
|
@@ -764,7 +886,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
764
886
|
if is_conflicting(rule_conclusion, case_query.target_value):
|
765
887
|
self.stop_conclusion(case_query, expert, evaluated_rule)
|
766
888
|
else:
|
767
|
-
self.add_conclusion(
|
889
|
+
self.add_conclusion(rule_conclusion)
|
768
890
|
|
769
891
|
def stop_conclusion(self, case_query: CaseQuery,
|
770
892
|
expert: Expert, evaluated_rule: MultiClassTopRule):
|
@@ -776,12 +898,13 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
776
898
|
:param evaluated_rule: The evaluated rule to ask the expert about.
|
777
899
|
"""
|
778
900
|
conditions = expert.ask_for_conditions(case_query, evaluated_rule)
|
779
|
-
evaluated_rule.fit_rule(case_query
|
901
|
+
evaluated_rule.fit_rule(case_query)
|
780
902
|
if self.mode == MCRDRMode.StopPlusRule:
|
781
903
|
self.stop_rule_conditions = conditions
|
782
904
|
if self.mode == MCRDRMode.StopPlusRuleCombined:
|
783
905
|
new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
|
784
|
-
|
906
|
+
case_query.conditions = new_top_rule_conditions
|
907
|
+
self.add_top_rule(case_query)
|
785
908
|
|
786
909
|
def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
|
787
910
|
"""
|
@@ -793,19 +916,19 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
793
916
|
if self.stop_rule_conditions and self.mode == MCRDRMode.StopPlusRule:
|
794
917
|
conditions = self.stop_rule_conditions
|
795
918
|
self.stop_rule_conditions = None
|
919
|
+
case_query.conditions = conditions
|
796
920
|
else:
|
797
921
|
conditions = expert.ask_for_conditions(case_query)
|
798
|
-
self.add_top_rule(
|
922
|
+
self.add_top_rule(case_query)
|
799
923
|
|
800
|
-
def add_conclusion(self,
|
924
|
+
def add_conclusion(self, rule_conclusion: List[Any]) -> None:
|
801
925
|
"""
|
802
926
|
Add the conclusion of the evaluated rule to the list of conclusions.
|
803
927
|
|
804
|
-
:param
|
805
|
-
|
928
|
+
:param rule_conclusion: The conclusion of the evaluated rule, which can be a single conclusion
|
929
|
+
or a set of conclusions.
|
806
930
|
"""
|
807
931
|
conclusion_types = [type(c) for c in self.conclusions]
|
808
|
-
rule_conclusion = evaluated_rule.conclusion(case)
|
809
932
|
if type(rule_conclusion) not in conclusion_types:
|
810
933
|
self.conclusions.extend(make_list(rule_conclusion))
|
811
934
|
else:
|
@@ -818,15 +941,13 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
818
941
|
self.conclusions.remove(c)
|
819
942
|
self.conclusions.extend(make_list(combined_conclusion))
|
820
943
|
|
821
|
-
def add_top_rule(self,
|
944
|
+
def add_top_rule(self, case_query: CaseQuery):
|
822
945
|
"""
|
823
946
|
Add a top rule to the classifier, which is a rule that is always checked and is part of the start_rules list.
|
824
947
|
|
825
|
-
:param
|
826
|
-
:param conclusion: The conclusion of the rule.
|
827
|
-
:param corner_case: The corner case of the rule.
|
948
|
+
:param case_query: The case query to add the top rule for.
|
828
949
|
"""
|
829
|
-
self.start_rule.alternative = MultiClassTopRule(
|
950
|
+
self.start_rule.alternative = MultiClassTopRule.from_case_query(case_query)
|
830
951
|
|
831
952
|
@staticmethod
|
832
953
|
def start_rule_type() -> Type[Rule]:
|
@@ -887,59 +1008,19 @@ class GeneralRDR(RippleDownRules):
|
|
887
1008
|
def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
|
888
1009
|
return [rdr.start_rule for rdr in self.start_rules_dict.values()]
|
889
1010
|
|
890
|
-
def classify(self, case: Any, modify_case: bool = False
|
1011
|
+
def classify(self, case: Any, modify_case: bool = False,
|
1012
|
+
case_query: Optional[CaseQuery] = None) -> Optional[Dict[str, Any]]:
|
891
1013
|
"""
|
892
1014
|
Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
|
893
1015
|
the classification until no more categories can be added.
|
894
1016
|
|
895
1017
|
:param case: The case to classify.
|
896
1018
|
:param modify_case: Whether to modify the original case or create a copy and modify it.
|
1019
|
+
:param case_query: The case query containing the case and the target category to compare the case with.
|
897
1020
|
:return: The categories that the case belongs to.
|
898
1021
|
"""
|
899
|
-
return
|
900
|
-
|
901
|
-
@staticmethod
|
902
|
-
def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
|
903
|
-
case: Any, modify_original_case: bool = False) -> Dict[str, Any]:
|
904
|
-
"""
|
905
|
-
Classify a case by going through all classifiers and adding the categories that are classified,
|
906
|
-
and then restarting the classification until no more categories can be added.
|
907
|
-
|
908
|
-
:param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
|
909
|
-
:param case: The case to classify.
|
910
|
-
:param modify_original_case: Whether to modify the original case or create a copy and modify it.
|
911
|
-
:return: The categories that the case belongs to.
|
912
|
-
"""
|
913
|
-
conclusions = {}
|
914
|
-
case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
|
915
|
-
case_cp = copy_case(case) if not modify_original_case else case
|
916
|
-
while True:
|
917
|
-
new_conclusions = {}
|
918
|
-
for attribute_name, rdr in classifiers_dict.items():
|
919
|
-
pred_atts = rdr.classify(case_cp)
|
920
|
-
if pred_atts is None:
|
921
|
-
continue
|
922
|
-
if rdr.type_ is SingleClassRDR:
|
923
|
-
if attribute_name not in conclusions or \
|
924
|
-
(attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
|
925
|
-
conclusions[attribute_name] = pred_atts
|
926
|
-
new_conclusions[attribute_name] = pred_atts
|
927
|
-
else:
|
928
|
-
pred_atts = make_set(pred_atts)
|
929
|
-
if attribute_name in conclusions:
|
930
|
-
pred_atts = {p for p in pred_atts if p not in conclusions[attribute_name]}
|
931
|
-
if len(pred_atts) > 0:
|
932
|
-
new_conclusions[attribute_name] = pred_atts
|
933
|
-
if attribute_name not in conclusions:
|
934
|
-
conclusions[attribute_name] = set()
|
935
|
-
conclusions[attribute_name].update(pred_atts)
|
936
|
-
if attribute_name in new_conclusions:
|
937
|
-
mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
|
938
|
-
case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive)
|
939
|
-
update_case(case_query, new_conclusions)
|
940
|
-
if len(new_conclusions) == 0:
|
941
|
-
break
|
942
|
-
return conclusions
|
1022
|
+
return general_rdr_classify(self.start_rules_dict, case, modify_original_case=modify_case,
|
1023
|
+
case_query=case_query)
|
943
1024
|
|
944
1025
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
945
1026
|
-> Dict[str, Any]:
|
@@ -985,7 +1066,7 @@ class GeneralRDR(RippleDownRules):
|
|
985
1066
|
|
986
1067
|
def _to_json(self) -> Dict[str, Any]:
|
987
1068
|
return {"start_rules": {name: rdr.to_json() for name, rdr in self.start_rules_dict.items()}
|
988
|
-
|
1069
|
+
, "generated_python_file_name": self.generated_python_file_name,
|
989
1070
|
"name": self.name,
|
990
1071
|
"case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
|
991
1072
|
"case_name": self.case_name}
|
@@ -1009,26 +1090,30 @@ class GeneralRDR(RippleDownRules):
|
|
1009
1090
|
new_rdr.case_name = data["case_name"]
|
1010
1091
|
return new_rdr
|
1011
1092
|
|
1012
|
-
def update_from_python(self, model_dir: str) -> None:
|
1093
|
+
def update_from_python(self, model_dir: str, package_name: Optional[str] = None) -> None:
|
1013
1094
|
"""
|
1014
1095
|
Update the rules from the generated python file, that might have been modified by the user.
|
1015
1096
|
|
1016
1097
|
:param model_dir: The directory where the model is stored.
|
1098
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
1099
|
+
is required in case of relative imports in the generated python file.
|
1017
1100
|
"""
|
1018
1101
|
for rdr in self.start_rules_dict.values():
|
1019
|
-
rdr.update_from_python(model_dir)
|
1102
|
+
rdr.update_from_python(model_dir, package_name=package_name)
|
1020
1103
|
|
1021
|
-
def _write_to_python(self, model_dir: str) -> None:
|
1104
|
+
def _write_to_python(self, model_dir: str, package_name: Optional[str] = None) -> None:
|
1022
1105
|
"""
|
1023
1106
|
Write the tree of rules as source code to a file.
|
1024
1107
|
|
1025
1108
|
:param model_dir: The directory where the model is stored.
|
1109
|
+
:param relative_imports: Whether to use relative imports in the generated python file.
|
1026
1110
|
"""
|
1027
1111
|
for rdr in self.start_rules_dict.values():
|
1028
|
-
rdr._write_to_python(model_dir)
|
1029
|
-
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
1030
|
-
|
1031
|
-
|
1112
|
+
rdr._write_to_python(model_dir, package_name=package_name)
|
1113
|
+
func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
|
1114
|
+
file_path = model_dir + f"/{self.generated_python_file_name}.py"
|
1115
|
+
with open(file_path, "w") as f:
|
1116
|
+
f.write(self._get_imports(file_path=file_path, package_name=package_name) + "\n\n")
|
1032
1117
|
f.write("classifiers_dict = dict()\n")
|
1033
1118
|
for rdr_key, rdr in self.start_rules_dict.items():
|
1034
1119
|
f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
|
@@ -1036,7 +1121,7 @@ class GeneralRDR(RippleDownRules):
|
|
1036
1121
|
f.write(func_def)
|
1037
1122
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
1038
1123
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
1039
|
-
f.write(f"{' ' * 4}return
|
1124
|
+
f.write(f"{' ' * 4}return general_rdr_classify(classifiers_dict, case, **kwargs)\n")
|
1040
1125
|
|
1041
1126
|
@property
|
1042
1127
|
def _default_generated_python_file_name(self) -> Optional[str]:
|
@@ -1051,25 +1136,29 @@ class GeneralRDR(RippleDownRules):
|
|
1051
1136
|
def conclusion_type_hint(self) -> str:
|
1052
1137
|
return f"Dict[str, Any]"
|
1053
1138
|
|
1054
|
-
def _get_imports(self) -> str:
|
1139
|
+
def _get_imports(self, file_path: Optional[str] = None, package_name: Optional[str] = None) -> str:
|
1055
1140
|
"""
|
1056
1141
|
Get the imports needed for the generated python file.
|
1057
1142
|
|
1143
|
+
:param file_path: The path to the file where the imports will be written, if None, the imports will be absolute.
|
1144
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
1145
|
+
is required in case of relative imports in the generated python file.
|
1058
1146
|
:return: The imports needed for the generated python file.
|
1059
1147
|
"""
|
1060
|
-
|
1148
|
+
all_types = set()
|
1061
1149
|
# add type hints
|
1062
|
-
|
1150
|
+
all_types.update({Dict, Any})
|
1063
1151
|
# import rdr type
|
1064
|
-
|
1152
|
+
all_types.add(general_rdr_classify)
|
1065
1153
|
# add case type
|
1066
|
-
|
1067
|
-
imports
|
1154
|
+
all_types.update({Case, create_case, self.case_type})
|
1155
|
+
# get the imports from the types
|
1156
|
+
imports = get_imports_from_types(all_types, target_file_path=file_path, package_name=package_name)
|
1068
1157
|
# add rdr python generated functions.
|
1069
1158
|
for rdr_key, rdr in self.start_rules_dict.items():
|
1070
|
-
imports
|
1071
|
-
|
1072
|
-
return imports
|
1159
|
+
imports.append(
|
1160
|
+
f"from . import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}")
|
1161
|
+
return '\n'.join(imports)
|
1073
1162
|
|
1074
1163
|
@staticmethod
|
1075
1164
|
def rdr_key_to_function_name(rdr_key: str) -> str:
|