ripple-down-rules 0.5.63__py3-none-any.whl → 0.5.71__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +1 -1
- ripple_down_rules/datastructures/case.py +10 -4
- ripple_down_rules/datastructures/dataclasses.py +62 -3
- ripple_down_rules/helpers.py +55 -9
- ripple_down_rules/rdr.py +141 -101
- ripple_down_rules/rdr_decorators.py +54 -23
- ripple_down_rules/rules.py +63 -13
- ripple_down_rules/user_interface/gui.py +9 -7
- ripple_down_rules/user_interface/ipython_custom_shell.py +1 -1
- ripple_down_rules/user_interface/object_diagram.py +9 -1
- ripple_down_rules/user_interface/template_file_creator.py +17 -22
- ripple_down_rules/utils.py +235 -62
- {ripple_down_rules-0.5.63.dist-info → ripple_down_rules-0.5.71.dist-info}/METADATA +2 -1
- ripple_down_rules-0.5.71.dist-info/RECORD +24 -0
- ripple_down_rules-0.5.63.dist-info/RECORD +0 -24
- {ripple_down_rules-0.5.63.dist-info → ripple_down_rules-0.5.71.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.5.63.dist-info → ripple_down_rules-0.5.71.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.5.63.dist-info → ripple_down_rules-0.5.71.dist-info}/top_level.txt +0 -0
ripple_down_rules/rdr.py
CHANGED
@@ -4,6 +4,8 @@ import copyreg
|
|
4
4
|
import importlib
|
5
5
|
import os
|
6
6
|
|
7
|
+
from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
|
8
|
+
|
7
9
|
from . import logger
|
8
10
|
import sys
|
9
11
|
from abc import ABC, abstractmethod
|
@@ -28,7 +30,7 @@ from .datastructures.case import Case, CaseAttribute, create_case
|
|
28
30
|
from .datastructures.dataclasses import CaseQuery
|
29
31
|
from .datastructures.enums import MCRDRMode
|
30
32
|
from .experts import Expert, Human
|
31
|
-
from .helpers import is_matching
|
33
|
+
from .helpers import is_matching, general_rdr_classify
|
32
34
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
33
35
|
try:
|
34
36
|
from .user_interface.gui import RDRCaseViewer
|
@@ -36,8 +38,8 @@ except ImportError as e:
|
|
36
38
|
RDRCaseViewer = None
|
37
39
|
from .utils import draw_tree, make_set, copy_case, \
|
38
40
|
SubclassJSONSerializer, make_list, get_type_from_string, \
|
39
|
-
is_conflicting,
|
40
|
-
is_iterable, str_to_snake_case
|
41
|
+
is_conflicting, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
|
42
|
+
is_iterable, str_to_snake_case, get_import_path_from_path
|
41
43
|
|
42
44
|
|
43
45
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -76,16 +78,18 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
76
78
|
"""
|
77
79
|
The name of the model. If None, the model name will be the generated python file name.
|
78
80
|
"""
|
81
|
+
mutually_exclusive: Optional[bool] = None
|
82
|
+
"""
|
83
|
+
Whether the output of the classification of this rdr allows only one possible conclusion or not.
|
84
|
+
"""
|
79
85
|
|
80
86
|
def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None,
|
81
|
-
save_dir: Optional[str] = None,
|
87
|
+
save_dir: Optional[str] = None, model_name: Optional[str] = None):
|
82
88
|
"""
|
83
89
|
:param start_rule: The starting rule for the classifier.
|
84
90
|
:param viewer: The viewer gui to use for the classifier. If None, no viewer is used.
|
85
91
|
:param save_dir: The directory to save the classifier to.
|
86
|
-
:param ask_always: Whether to always ask the expert (True) or only ask when classification fails (False).
|
87
92
|
"""
|
88
|
-
self.ask_always: bool = ask_always
|
89
93
|
self.model_name: Optional[str] = model_name
|
90
94
|
self.save_dir = save_dir
|
91
95
|
self.start_rule = start_rule
|
@@ -110,7 +114,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
110
114
|
if not os.path.exists(save_dir + '/__init__.py'):
|
111
115
|
os.makedirs(save_dir, exist_ok=True)
|
112
116
|
with open(save_dir + '/__init__.py', 'w') as f:
|
113
|
-
f.write("
|
117
|
+
f.write("from . import *\n")
|
114
118
|
if model_name is not None:
|
115
119
|
self.model_name = model_name
|
116
120
|
elif self.model_name is None:
|
@@ -134,7 +138,11 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
134
138
|
model_dir = os.path.join(load_dir, model_name)
|
135
139
|
json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
|
136
140
|
rdr = cls.from_json_file(json_file)
|
137
|
-
|
141
|
+
try:
|
142
|
+
rdr.update_from_python(model_dir)
|
143
|
+
except (FileNotFoundError, ValueError) as e:
|
144
|
+
logger.warning(f"Could not load the python file for the model {model_name} from {model_dir}. "
|
145
|
+
f"Make sure the file exists and is valid.")
|
138
146
|
rdr.save_dir = load_dir
|
139
147
|
rdr.model_name = model_name
|
140
148
|
return rdr
|
@@ -213,18 +221,24 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
213
221
|
return self.classify(case)
|
214
222
|
|
215
223
|
@abstractmethod
|
216
|
-
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False
|
224
|
+
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
|
225
|
+
case_query: Optional[CaseQuery] = None) \
|
217
226
|
-> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
|
218
227
|
"""
|
219
228
|
Classify a case.
|
220
229
|
|
221
230
|
:param case: The case to classify.
|
222
231
|
:param modify_case: Whether to modify the original case attributes with the conclusion or not.
|
232
|
+
:param case_query: The case query containing the case to classify and the target category to compare the case with.
|
223
233
|
:return: The category that the case belongs to.
|
224
234
|
"""
|
225
235
|
pass
|
226
236
|
|
227
|
-
def fit_case(self, case_query: CaseQuery,
|
237
|
+
def fit_case(self, case_query: CaseQuery,
|
238
|
+
expert: Optional[Expert] = None,
|
239
|
+
update_existing_rules: bool = True,
|
240
|
+
scenario: Optional[Callable] = None,
|
241
|
+
**kwargs) \
|
228
242
|
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
229
243
|
"""
|
230
244
|
Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
|
@@ -232,6 +246,9 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
232
246
|
|
233
247
|
:param case_query: The query containing the case to classify and the target category to compare the case with.
|
234
248
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
249
|
+
:param update_existing_rules: Whether to update the existing same conclusion type rules that already gave
|
250
|
+
some conclusions with the type required by the case query.
|
251
|
+
:param scenario: The scenario at which the case was created, this is used to recreate the case if needed.
|
235
252
|
:return: The category that the case belongs to.
|
236
253
|
"""
|
237
254
|
if case_query is None:
|
@@ -240,19 +257,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
240
257
|
self.name = case_query.attribute_name if self.name is None else self.name
|
241
258
|
self.case_type = case_query.case_type if self.case_type is None else self.case_type
|
242
259
|
self.case_name = case_query.case_name if self.case_name is None else self.case_name
|
260
|
+
case_query.scenario = scenario if case_query.scenario is None else case_query.scenario
|
243
261
|
|
244
262
|
expert = expert or Human(viewer=self.viewer,
|
245
263
|
answers_save_path=self.save_dir + '/expert_answers'
|
246
264
|
if self.save_dir else None)
|
247
|
-
|
248
265
|
if case_query.target is None:
|
249
266
|
case_query_cp = copy(case_query)
|
250
|
-
conclusions = self.classify(case_query_cp.case, modify_case=True)
|
251
|
-
if
|
252
|
-
or is_iterable(conclusions) and len(conclusions) == 0
|
253
|
-
or (isinstance(conclusions, dict) and (case_query_cp.attribute_name not in conclusions
|
254
|
-
or not any(type(c) in case_query_cp.core_attribute_type
|
255
|
-
for c in make_list(conclusions[case_query_cp.attribute_name]))))):
|
267
|
+
conclusions = self.classify(case_query_cp.case, modify_case=True, case_query=case_query_cp)
|
268
|
+
if self.should_i_ask_the_expert_for_a_target(conclusions, case_query_cp, update_existing_rules):
|
256
269
|
expert.ask_for_conclusion(case_query_cp)
|
257
270
|
case_query.target = case_query_cp.target
|
258
271
|
if case_query.target is None:
|
@@ -268,6 +281,34 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
268
281
|
|
269
282
|
return fit_case_result
|
270
283
|
|
284
|
+
@staticmethod
|
285
|
+
def should_i_ask_the_expert_for_a_target(conclusions: Union[Any, Dict[str, Any]],
|
286
|
+
case_query: CaseQuery,
|
287
|
+
update_existing: bool) -> bool:
|
288
|
+
"""
|
289
|
+
Determine if the rdr should ask the expert for the target of a given case query.
|
290
|
+
|
291
|
+
:param conclusions: The conclusions of the case.
|
292
|
+
:param case_query: The query containing the case to classify.
|
293
|
+
:param update_existing: Whether to update rules that gave the required type of conclusions.
|
294
|
+
:return: True if the rdr should ask the expert, False otherwise.
|
295
|
+
"""
|
296
|
+
if conclusions is None:
|
297
|
+
return True
|
298
|
+
elif is_iterable(conclusions) and len(conclusions) == 0:
|
299
|
+
return True
|
300
|
+
elif isinstance(conclusions, dict):
|
301
|
+
if case_query.attribute_name not in conclusions:
|
302
|
+
return True
|
303
|
+
conclusions = conclusions[case_query.attribute_name]
|
304
|
+
conclusion_types = map(type, make_list(conclusions))
|
305
|
+
if not any(ct in case_query.core_attribute_type for ct in conclusion_types):
|
306
|
+
return True
|
307
|
+
elif update_existing:
|
308
|
+
return True
|
309
|
+
else:
|
310
|
+
return False
|
311
|
+
|
271
312
|
@abstractmethod
|
272
313
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
273
314
|
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
@@ -358,7 +399,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
358
399
|
:return: The module that contains the rdr classifier function.
|
359
400
|
"""
|
360
401
|
# remove from imports if exists first
|
361
|
-
|
402
|
+
package_name = get_import_path_from_path(package_name)
|
403
|
+
name = f"{package_name}.{self.generated_python_file_name}" if package_name else self.generated_python_file_name
|
362
404
|
try:
|
363
405
|
module = importlib.import_module(name)
|
364
406
|
del sys.modules[name]
|
@@ -380,6 +422,10 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
380
422
|
conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys() if not isinstance(rules_dict[rid], MultiClassStopRule)]
|
381
423
|
all_func_names = condition_func_names + conclusion_func_names
|
382
424
|
filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
|
425
|
+
cases_path = f"{model_dir}/{self.generated_python_cases_file_name}.py"
|
426
|
+
cases_import_path = get_import_path_from_path(model_dir)
|
427
|
+
cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path\
|
428
|
+
else self.generated_python_cases_file_name
|
383
429
|
functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
|
384
430
|
# get the scope from the imports in the file
|
385
431
|
scope = extract_imports(filepath)
|
@@ -387,13 +433,15 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
387
433
|
if rule.conditions is not None:
|
388
434
|
rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
|
389
435
|
rule.conditions.scope = scope
|
436
|
+
if os.path.exists(cases_path):
|
437
|
+
rule.corner_case_metadata = importlib.import_module(cases_import_path).__dict__.get(f"corner_case_{rule.uid}", None)
|
390
438
|
if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
|
391
439
|
rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
|
392
440
|
rule.conclusion.scope = scope
|
393
441
|
|
394
442
|
@abstractmethod
|
395
443
|
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
|
396
|
-
defs_file: Optional[str] = None):
|
444
|
+
defs_file: Optional[str] = None, cases_file: Optional[str] = None):
|
397
445
|
"""
|
398
446
|
Write the rules as source code to a file.
|
399
447
|
|
@@ -401,6 +449,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
401
449
|
:param file: The file to write the source code to.
|
402
450
|
:param parent_indent: The indentation of the parent rule.
|
403
451
|
:param defs_file: The file to write the definitions to.
|
452
|
+
:param cases_file: The file to write the cases to.
|
404
453
|
"""
|
405
454
|
pass
|
406
455
|
|
@@ -413,25 +462,28 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
413
462
|
os.makedirs(model_dir, exist_ok=True)
|
414
463
|
if not os.path.exists(model_dir + '/__init__.py'):
|
415
464
|
with open(model_dir + '/__init__.py', 'w') as f:
|
416
|
-
f.write("
|
417
|
-
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
465
|
+
f.write("from . import *\n")
|
466
|
+
func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
|
418
467
|
file_name = model_dir + f"/{self.generated_python_file_name}.py"
|
419
468
|
defs_file_name = model_dir + f"/{self.generated_python_defs_file_name}.py"
|
469
|
+
cases_file_name = model_dir + f"/{self.generated_python_cases_file_name}.py"
|
420
470
|
imports, defs_imports = self._get_imports()
|
421
471
|
# clear the files first
|
422
472
|
with open(defs_file_name, "w") as f:
|
423
473
|
f.write(defs_imports + "\n\n")
|
474
|
+
with open(cases_file_name, "w") as cases_f:
|
475
|
+
cases_f.write("# This file contains the corner cases for the rules.\n")
|
424
476
|
with open(file_name, "w") as f:
|
425
477
|
imports += f"from .{self.generated_python_defs_file_name} import *\n"
|
426
|
-
imports += f"from ripple_down_rules.rdr import {self.__class__.__name__}\n"
|
427
478
|
f.write(imports + "\n\n")
|
428
479
|
f.write(f"attribute_name = '{self.attribute_name}'\n")
|
429
480
|
f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
|
430
|
-
f.write(f"
|
481
|
+
f.write(f"mutually_exclusive = {self.mutually_exclusive}\n")
|
431
482
|
f.write(f"\n\n{func_def}")
|
432
483
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
433
484
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
434
|
-
self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4, defs_file=defs_file_name
|
485
|
+
self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4, defs_file=defs_file_name,
|
486
|
+
cases_file=cases_file_name)
|
435
487
|
|
436
488
|
@property
|
437
489
|
@abstractmethod
|
@@ -480,6 +532,10 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
480
532
|
def generated_python_defs_file_name(self) -> str:
|
481
533
|
return f"{self.generated_python_file_name}_defs"
|
482
534
|
|
535
|
+
@property
|
536
|
+
def generated_python_cases_file_name(self) -> str:
|
537
|
+
return f"{self.generated_python_file_name}_cases"
|
538
|
+
|
483
539
|
|
484
540
|
@property
|
485
541
|
def conclusion_type(self) -> Tuple[Type]:
|
@@ -533,6 +589,11 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
533
589
|
|
534
590
|
class SingleClassRDR(RDRWithCodeWriter):
|
535
591
|
|
592
|
+
mutually_exclusive: bool = True
|
593
|
+
"""
|
594
|
+
The output of the classification of this rdr negates all other possible outputs, there can only be one true value.
|
595
|
+
"""
|
596
|
+
|
536
597
|
def __init__(self, default_conclusion: Optional[Any] = None, **kwargs):
|
537
598
|
"""
|
538
599
|
:param start_rule: The starting rule for the classifier.
|
@@ -557,7 +618,7 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
557
618
|
pred = self.evaluate(case_query.case)
|
558
619
|
if pred.conclusion(case_query.case) != case_query.target_value:
|
559
620
|
expert.ask_for_conditions(case_query, pred)
|
560
|
-
pred.fit_rule(case_query
|
621
|
+
pred.fit_rule(case_query)
|
561
622
|
|
562
623
|
return self.classify(case_query.case)
|
563
624
|
|
@@ -570,18 +631,24 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
570
631
|
"""
|
571
632
|
if not self.start_rule:
|
572
633
|
expert.ask_for_conditions(case_query)
|
573
|
-
self.start_rule = SingleClassRule(case_query
|
574
|
-
conclusion_name=case_query.attribute_name)
|
634
|
+
self.start_rule = SingleClassRule.from_case_query(case_query)
|
575
635
|
|
576
|
-
def classify(self, case: Case, modify_case: bool = False
|
636
|
+
def classify(self, case: Case, modify_case: bool = False,
|
637
|
+
case_query: Optional[CaseQuery] = None) -> Optional[Any]:
|
577
638
|
"""
|
578
639
|
Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
|
579
640
|
|
580
641
|
:param case: The case to classify.
|
581
642
|
:param modify_case: Whether to modify the original case attributes with the conclusion or not.
|
643
|
+
:param case_query: The case query containing the case and the target category to compare the case with.
|
582
644
|
"""
|
583
645
|
pred = self.evaluate(case)
|
584
|
-
|
646
|
+
conclusion = pred.conclusion(case) if pred is not None else None
|
647
|
+
if pred is not None and pred.fired and case_query is not None:
|
648
|
+
if pred.corner_case_metadata is None and conclusion is not None\
|
649
|
+
and type(conclusion) in case_query.core_attribute_type:
|
650
|
+
pred.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
651
|
+
return conclusion if pred is not None and pred.fired else self.default_conclusion
|
585
652
|
|
586
653
|
def evaluate(self, case: Case) -> SingleClassRule:
|
587
654
|
"""
|
@@ -597,22 +664,24 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
597
664
|
f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
|
598
665
|
|
599
666
|
def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
|
600
|
-
defs_file: Optional[str] = None):
|
667
|
+
defs_file: Optional[str] = None, cases_file: Optional[str] = None):
|
601
668
|
"""
|
602
669
|
Write the rules as source code to a file.
|
603
670
|
"""
|
604
671
|
if rule.conditions:
|
672
|
+
rule.write_corner_case_as_source_code(cases_file)
|
605
673
|
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
606
674
|
file.write(if_clause)
|
607
675
|
if rule.refinement:
|
608
676
|
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
|
609
|
-
defs_file=defs_file)
|
677
|
+
defs_file=defs_file, cases_file=cases_file)
|
610
678
|
|
611
679
|
conclusion_call = rule.write_conclusion_as_source_code(parent_indent, defs_file)
|
612
680
|
file.write(conclusion_call)
|
613
681
|
|
614
682
|
if rule.alternative:
|
615
|
-
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file
|
683
|
+
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file,
|
684
|
+
cases_file=cases_file)
|
616
685
|
|
617
686
|
@property
|
618
687
|
def conclusion_type_hint(self) -> str:
|
@@ -650,6 +719,10 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
650
719
|
"""
|
651
720
|
The conditions of the stopping rule if needed.
|
652
721
|
"""
|
722
|
+
mutually_exclusive: bool = False
|
723
|
+
"""
|
724
|
+
The output of the classification of this rdr allows for more than one true value as conclusion.
|
725
|
+
"""
|
653
726
|
|
654
727
|
def __init__(self, start_rule: Optional[MultiClassTopRule] = None,
|
655
728
|
mode: MCRDRMode = MCRDRMode.StopOnly, **kwargs):
|
@@ -660,13 +733,19 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
660
733
|
super(MultiClassRDR, self).__init__(start_rule, **kwargs)
|
661
734
|
self.mode: MCRDRMode = mode
|
662
735
|
|
663
|
-
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False
|
736
|
+
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
|
737
|
+
case_query: Optional[CaseQuery] = None) -> Set[Any]:
|
664
738
|
evaluated_rule = self.start_rule
|
665
739
|
self.conclusions = []
|
666
740
|
while evaluated_rule:
|
667
741
|
next_rule = evaluated_rule(case)
|
668
742
|
if evaluated_rule.fired:
|
669
|
-
|
743
|
+
rule_conclusion = evaluated_rule.conclusion(case)
|
744
|
+
if evaluated_rule.corner_case_metadata is None and case_query is not None:
|
745
|
+
if rule_conclusion is not None and len(make_list(rule_conclusion)) > 0\
|
746
|
+
and any(ct in case_query.core_attribute_type for ct in map(type, make_list(rule_conclusion))):
|
747
|
+
evaluated_rule.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
748
|
+
self.add_conclusion(rule_conclusion)
|
670
749
|
evaluated_rule = next_rule
|
671
750
|
return make_set(self.conclusions)
|
672
751
|
|
@@ -694,7 +773,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
694
773
|
self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule)
|
695
774
|
else:
|
696
775
|
# Rule fired and target is correct or there is no target to compare
|
697
|
-
self.add_conclusion(
|
776
|
+
self.add_conclusion(rule_conclusion)
|
698
777
|
|
699
778
|
if not next_rule:
|
700
779
|
if not make_set(target_value).issubset(make_set(self.conclusions)):
|
@@ -706,16 +785,18 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
706
785
|
return self.conclusions
|
707
786
|
|
708
787
|
def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
|
709
|
-
file, parent_indent: str = "", defs_file: Optional[str] = None
|
788
|
+
file, parent_indent: str = "", defs_file: Optional[str] = None,
|
789
|
+
cases_file: Optional[str] = None):
|
710
790
|
if rule == self.start_rule:
|
711
791
|
file.write(f"{parent_indent}conclusions = set()\n")
|
712
792
|
if rule.conditions:
|
793
|
+
rule.write_corner_case_as_source_code(cases_file)
|
713
794
|
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
714
795
|
file.write(if_clause)
|
715
796
|
conclusion_indent = parent_indent
|
716
797
|
if hasattr(rule, "refinement") and rule.refinement:
|
717
798
|
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
|
718
|
-
defs_file=defs_file)
|
799
|
+
defs_file=defs_file, cases_file=cases_file)
|
719
800
|
conclusion_indent = parent_indent + " " * 4
|
720
801
|
file.write(f"{conclusion_indent}else:\n")
|
721
802
|
|
@@ -723,7 +804,8 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
723
804
|
file.write(conclusion_call)
|
724
805
|
|
725
806
|
if rule.alternative:
|
726
|
-
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file
|
807
|
+
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file,
|
808
|
+
cases_file=cases_file)
|
727
809
|
|
728
810
|
@property
|
729
811
|
def conclusion_type_hint(self) -> str:
|
@@ -749,8 +831,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
749
831
|
"""
|
750
832
|
if not self.start_rule:
|
751
833
|
conditions = expert.ask_for_conditions(case_query)
|
752
|
-
self.start_rule = MultiClassTopRule(
|
753
|
-
conclusion_name=case_query.attribute_name)
|
834
|
+
self.start_rule = MultiClassTopRule.from_case_query(case_query)
|
754
835
|
|
755
836
|
@property
|
756
837
|
def last_top_rule(self) -> Optional[MultiClassTopRule]:
|
@@ -771,7 +852,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
771
852
|
if is_conflicting(rule_conclusion, case_query.target_value):
|
772
853
|
self.stop_conclusion(case_query, expert, evaluated_rule)
|
773
854
|
else:
|
774
|
-
self.add_conclusion(
|
855
|
+
self.add_conclusion(rule_conclusion)
|
775
856
|
|
776
857
|
def stop_conclusion(self, case_query: CaseQuery,
|
777
858
|
expert: Expert, evaluated_rule: MultiClassTopRule):
|
@@ -783,12 +864,13 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
783
864
|
:param evaluated_rule: The evaluated rule to ask the expert about.
|
784
865
|
"""
|
785
866
|
conditions = expert.ask_for_conditions(case_query, evaluated_rule)
|
786
|
-
evaluated_rule.fit_rule(case_query
|
867
|
+
evaluated_rule.fit_rule(case_query)
|
787
868
|
if self.mode == MCRDRMode.StopPlusRule:
|
788
869
|
self.stop_rule_conditions = conditions
|
789
870
|
if self.mode == MCRDRMode.StopPlusRuleCombined:
|
790
871
|
new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
|
791
|
-
|
872
|
+
case_query.conditions = new_top_rule_conditions
|
873
|
+
self.add_top_rule(case_query)
|
792
874
|
|
793
875
|
def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
|
794
876
|
"""
|
@@ -800,19 +882,19 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
800
882
|
if self.stop_rule_conditions and self.mode == MCRDRMode.StopPlusRule:
|
801
883
|
conditions = self.stop_rule_conditions
|
802
884
|
self.stop_rule_conditions = None
|
885
|
+
case_query.conditions = conditions
|
803
886
|
else:
|
804
887
|
conditions = expert.ask_for_conditions(case_query)
|
805
|
-
self.add_top_rule(
|
888
|
+
self.add_top_rule(case_query)
|
806
889
|
|
807
|
-
def add_conclusion(self,
|
890
|
+
def add_conclusion(self, rule_conclusion: List[Any]) -> None:
|
808
891
|
"""
|
809
892
|
Add the conclusion of the evaluated rule to the list of conclusions.
|
810
893
|
|
811
|
-
:param
|
812
|
-
|
894
|
+
:param rule_conclusion: The conclusion of the evaluated rule, which can be a single conclusion
|
895
|
+
or a set of conclusions.
|
813
896
|
"""
|
814
897
|
conclusion_types = [type(c) for c in self.conclusions]
|
815
|
-
rule_conclusion = evaluated_rule.conclusion(case)
|
816
898
|
if type(rule_conclusion) not in conclusion_types:
|
817
899
|
self.conclusions.extend(make_list(rule_conclusion))
|
818
900
|
else:
|
@@ -825,15 +907,13 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
825
907
|
self.conclusions.remove(c)
|
826
908
|
self.conclusions.extend(make_list(combined_conclusion))
|
827
909
|
|
828
|
-
def add_top_rule(self,
|
910
|
+
def add_top_rule(self, case_query: CaseQuery):
|
829
911
|
"""
|
830
912
|
Add a top rule to the classifier, which is a rule that is always checked and is part of the start_rules list.
|
831
913
|
|
832
|
-
:param
|
833
|
-
:param conclusion: The conclusion of the rule.
|
834
|
-
:param corner_case: The corner case of the rule.
|
914
|
+
:param case_query: The case query to add the top rule for.
|
835
915
|
"""
|
836
|
-
self.start_rule.alternative = MultiClassTopRule(
|
916
|
+
self.start_rule.alternative = MultiClassTopRule.from_case_query(case_query)
|
837
917
|
|
838
918
|
@staticmethod
|
839
919
|
def start_rule_type() -> Type[Rule]:
|
@@ -894,59 +974,19 @@ class GeneralRDR(RippleDownRules):
|
|
894
974
|
def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
|
895
975
|
return [rdr.start_rule for rdr in self.start_rules_dict.values()]
|
896
976
|
|
897
|
-
def classify(self, case: Any, modify_case: bool = False
|
977
|
+
def classify(self, case: Any, modify_case: bool = False,
|
978
|
+
case_query: Optional[CaseQuery] = None) -> Optional[Dict[str, Any]]:
|
898
979
|
"""
|
899
980
|
Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
|
900
981
|
the classification until no more categories can be added.
|
901
982
|
|
902
983
|
:param case: The case to classify.
|
903
984
|
:param modify_case: Whether to modify the original case or create a copy and modify it.
|
985
|
+
:param case_query: The case query containing the case and the target category to compare the case with.
|
904
986
|
:return: The categories that the case belongs to.
|
905
987
|
"""
|
906
|
-
return
|
907
|
-
|
908
|
-
@staticmethod
|
909
|
-
def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
|
910
|
-
case: Any, modify_original_case: bool = False) -> Dict[str, Any]:
|
911
|
-
"""
|
912
|
-
Classify a case by going through all classifiers and adding the categories that are classified,
|
913
|
-
and then restarting the classification until no more categories can be added.
|
914
|
-
|
915
|
-
:param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
|
916
|
-
:param case: The case to classify.
|
917
|
-
:param modify_original_case: Whether to modify the original case or create a copy and modify it.
|
918
|
-
:return: The categories that the case belongs to.
|
919
|
-
"""
|
920
|
-
conclusions = {}
|
921
|
-
case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
|
922
|
-
case_cp = copy_case(case) if not modify_original_case else case
|
923
|
-
while True:
|
924
|
-
new_conclusions = {}
|
925
|
-
for attribute_name, rdr in classifiers_dict.items():
|
926
|
-
pred_atts = rdr.classify(case_cp)
|
927
|
-
if pred_atts is None:
|
928
|
-
continue
|
929
|
-
if rdr.type_ is SingleClassRDR:
|
930
|
-
if attribute_name not in conclusions or \
|
931
|
-
(attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
|
932
|
-
conclusions[attribute_name] = pred_atts
|
933
|
-
new_conclusions[attribute_name] = pred_atts
|
934
|
-
else:
|
935
|
-
pred_atts = make_set(pred_atts)
|
936
|
-
if attribute_name in conclusions:
|
937
|
-
pred_atts = {p for p in pred_atts if p not in conclusions[attribute_name]}
|
938
|
-
if len(pred_atts) > 0:
|
939
|
-
new_conclusions[attribute_name] = pred_atts
|
940
|
-
if attribute_name not in conclusions:
|
941
|
-
conclusions[attribute_name] = set()
|
942
|
-
conclusions[attribute_name].update(pred_atts)
|
943
|
-
if attribute_name in new_conclusions:
|
944
|
-
mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
|
945
|
-
case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive)
|
946
|
-
update_case(case_query, new_conclusions)
|
947
|
-
if len(new_conclusions) == 0:
|
948
|
-
break
|
949
|
-
return conclusions
|
988
|
+
return general_rdr_classify(self.start_rules_dict, case, modify_original_case=modify_case,
|
989
|
+
case_query=case_query)
|
950
990
|
|
951
991
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
952
992
|
-> Dict[str, Any]:
|
@@ -1033,7 +1073,7 @@ class GeneralRDR(RippleDownRules):
|
|
1033
1073
|
"""
|
1034
1074
|
for rdr in self.start_rules_dict.values():
|
1035
1075
|
rdr._write_to_python(model_dir)
|
1036
|
-
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
1076
|
+
func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
|
1037
1077
|
with open(model_dir + f"/{self.generated_python_file_name}.py", "w") as f:
|
1038
1078
|
f.write(self._get_imports() + "\n\n")
|
1039
1079
|
f.write("classifiers_dict = dict()\n")
|
@@ -1043,7 +1083,7 @@ class GeneralRDR(RippleDownRules):
|
|
1043
1083
|
f.write(func_def)
|
1044
1084
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
1045
1085
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
1046
|
-
f.write(f"{' ' * 4}return
|
1086
|
+
f.write(f"{' ' * 4}return general_rdr_classify(classifiers_dict, case, **kwargs)\n")
|
1047
1087
|
|
1048
1088
|
@property
|
1049
1089
|
def _default_generated_python_file_name(self) -> Optional[str]:
|
@@ -1068,7 +1108,7 @@ class GeneralRDR(RippleDownRules):
|
|
1068
1108
|
# add type hints
|
1069
1109
|
imports += f"from typing_extensions import Dict, Any\n"
|
1070
1110
|
# import rdr type
|
1071
|
-
imports += f"from ripple_down_rules.
|
1111
|
+
imports += f"from ripple_down_rules.helpers import general_rdr_classify\n"
|
1072
1112
|
# add case type
|
1073
1113
|
imports += f"from ripple_down_rules.datastructures.case import Case, create_case\n"
|
1074
1114
|
imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
|