ripple-down-rules 0.5.4__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +1 -1
- ripple_down_rules/datastructures/callable_expression.py +16 -9
- ripple_down_rules/datastructures/case.py +10 -4
- ripple_down_rules/datastructures/dataclasses.py +62 -3
- ripple_down_rules/experts.py +12 -2
- ripple_down_rules/helpers.py +55 -9
- ripple_down_rules/rdr.py +148 -101
- ripple_down_rules/rdr_decorators.py +54 -26
- ripple_down_rules/rules.py +63 -13
- ripple_down_rules/user_interface/gui.py +10 -7
- ripple_down_rules/user_interface/ipython_custom_shell.py +1 -1
- ripple_down_rules/user_interface/object_diagram.py +9 -1
- ripple_down_rules/user_interface/template_file_creator.py +25 -24
- ripple_down_rules/utils.py +260 -78
- {ripple_down_rules-0.5.4.dist-info → ripple_down_rules-0.5.7.dist-info}/METADATA +16 -2
- ripple_down_rules-0.5.7.dist-info/RECORD +24 -0
- {ripple_down_rules-0.5.4.dist-info → ripple_down_rules-0.5.7.dist-info}/WHEEL +1 -1
- ripple_down_rules-0.5.4.dist-info/RECORD +0 -24
- {ripple_down_rules-0.5.4.dist-info → ripple_down_rules-0.5.7.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.5.4.dist-info → ripple_down_rules-0.5.7.dist-info}/top_level.txt +0 -0
ripple_down_rules/rdr.py
CHANGED
@@ -4,6 +4,8 @@ import copyreg
|
|
4
4
|
import importlib
|
5
5
|
import os
|
6
6
|
|
7
|
+
from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
|
8
|
+
|
7
9
|
from . import logger
|
8
10
|
import sys
|
9
11
|
from abc import ABC, abstractmethod
|
@@ -28,7 +30,7 @@ from .datastructures.case import Case, CaseAttribute, create_case
|
|
28
30
|
from .datastructures.dataclasses import CaseQuery
|
29
31
|
from .datastructures.enums import MCRDRMode
|
30
32
|
from .experts import Expert, Human
|
31
|
-
from .helpers import is_matching
|
33
|
+
from .helpers import is_matching, general_rdr_classify
|
32
34
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
33
35
|
try:
|
34
36
|
from .user_interface.gui import RDRCaseViewer
|
@@ -36,8 +38,8 @@ except ImportError as e:
|
|
36
38
|
RDRCaseViewer = None
|
37
39
|
from .utils import draw_tree, make_set, copy_case, \
|
38
40
|
SubclassJSONSerializer, make_list, get_type_from_string, \
|
39
|
-
is_conflicting,
|
40
|
-
is_iterable, str_to_snake_case
|
41
|
+
is_conflicting, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
|
42
|
+
is_iterable, str_to_snake_case, get_import_path_from_path
|
41
43
|
|
42
44
|
|
43
45
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -76,16 +78,18 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
76
78
|
"""
|
77
79
|
The name of the model. If None, the model name will be the generated python file name.
|
78
80
|
"""
|
81
|
+
mutually_exclusive: Optional[bool] = None
|
82
|
+
"""
|
83
|
+
Whether the output of the classification of this rdr allows only one possible conclusion or not.
|
84
|
+
"""
|
79
85
|
|
80
86
|
def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None,
|
81
|
-
save_dir: Optional[str] = None,
|
87
|
+
save_dir: Optional[str] = None, model_name: Optional[str] = None):
|
82
88
|
"""
|
83
89
|
:param start_rule: The starting rule for the classifier.
|
84
90
|
:param viewer: The viewer gui to use for the classifier. If None, no viewer is used.
|
85
91
|
:param save_dir: The directory to save the classifier to.
|
86
|
-
:param ask_always: Whether to always ask the expert (True) or only ask when classification fails (False).
|
87
92
|
"""
|
88
|
-
self.ask_always: bool = ask_always
|
89
93
|
self.model_name: Optional[str] = model_name
|
90
94
|
self.save_dir = save_dir
|
91
95
|
self.start_rule = start_rule
|
@@ -110,7 +114,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
110
114
|
if not os.path.exists(save_dir + '/__init__.py'):
|
111
115
|
os.makedirs(save_dir, exist_ok=True)
|
112
116
|
with open(save_dir + '/__init__.py', 'w') as f:
|
113
|
-
f.write("
|
117
|
+
f.write("from . import *\n")
|
114
118
|
if model_name is not None:
|
115
119
|
self.model_name = model_name
|
116
120
|
elif self.model_name is None:
|
@@ -134,7 +138,11 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
134
138
|
model_dir = os.path.join(load_dir, model_name)
|
135
139
|
json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
|
136
140
|
rdr = cls.from_json_file(json_file)
|
137
|
-
|
141
|
+
try:
|
142
|
+
rdr.update_from_python(model_dir)
|
143
|
+
except (FileNotFoundError, ValueError) as e:
|
144
|
+
logger.warning(f"Could not load the python file for the model {model_name} from {model_dir}. "
|
145
|
+
f"Make sure the file exists and is valid.")
|
138
146
|
rdr.save_dir = load_dir
|
139
147
|
rdr.model_name = model_name
|
140
148
|
return rdr
|
@@ -213,18 +221,24 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
213
221
|
return self.classify(case)
|
214
222
|
|
215
223
|
@abstractmethod
|
216
|
-
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False
|
224
|
+
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
|
225
|
+
case_query: Optional[CaseQuery] = None) \
|
217
226
|
-> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
|
218
227
|
"""
|
219
228
|
Classify a case.
|
220
229
|
|
221
230
|
:param case: The case to classify.
|
222
231
|
:param modify_case: Whether to modify the original case attributes with the conclusion or not.
|
232
|
+
:param case_query: The case query containing the case to classify and the target category to compare the case with.
|
223
233
|
:return: The category that the case belongs to.
|
224
234
|
"""
|
225
235
|
pass
|
226
236
|
|
227
|
-
def fit_case(self, case_query: CaseQuery,
|
237
|
+
def fit_case(self, case_query: CaseQuery,
|
238
|
+
expert: Optional[Expert] = None,
|
239
|
+
update_existing_rules: bool = True,
|
240
|
+
scenario: Optional[Callable] = None,
|
241
|
+
**kwargs) \
|
228
242
|
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
229
243
|
"""
|
230
244
|
Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
|
@@ -232,6 +246,9 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
232
246
|
|
233
247
|
:param case_query: The query containing the case to classify and the target category to compare the case with.
|
234
248
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
249
|
+
:param update_existing_rules: Whether to update the existing same conclusion type rules that already gave
|
250
|
+
some conclusions with the type required by the case query.
|
251
|
+
:param scenario: The scenario at which the case was created, this is used to recreate the case if needed.
|
235
252
|
:return: The category that the case belongs to.
|
236
253
|
"""
|
237
254
|
if case_query is None:
|
@@ -240,13 +257,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
240
257
|
self.name = case_query.attribute_name if self.name is None else self.name
|
241
258
|
self.case_type = case_query.case_type if self.case_type is None else self.case_type
|
242
259
|
self.case_name = case_query.case_name if self.case_name is None else self.case_name
|
260
|
+
case_query.scenario = scenario if case_query.scenario is None else case_query.scenario
|
243
261
|
|
244
|
-
expert = expert or Human(
|
245
|
-
|
262
|
+
expert = expert or Human(viewer=self.viewer,
|
263
|
+
answers_save_path=self.save_dir + '/expert_answers'
|
264
|
+
if self.save_dir else None)
|
246
265
|
if case_query.target is None:
|
247
266
|
case_query_cp = copy(case_query)
|
248
|
-
conclusions = self.classify(case_query_cp.case, modify_case=True)
|
249
|
-
if self.
|
267
|
+
conclusions = self.classify(case_query_cp.case, modify_case=True, case_query=case_query_cp)
|
268
|
+
if self.should_i_ask_the_expert_for_a_target(conclusions, case_query_cp, update_existing_rules):
|
250
269
|
expert.ask_for_conclusion(case_query_cp)
|
251
270
|
case_query.target = case_query_cp.target
|
252
271
|
if case_query.target is None:
|
@@ -262,6 +281,34 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
262
281
|
|
263
282
|
return fit_case_result
|
264
283
|
|
284
|
+
@staticmethod
|
285
|
+
def should_i_ask_the_expert_for_a_target(conclusions: Union[Any, Dict[str, Any]],
|
286
|
+
case_query: CaseQuery,
|
287
|
+
update_existing: bool) -> bool:
|
288
|
+
"""
|
289
|
+
Determine if the rdr should ask the expert for the target of a given case query.
|
290
|
+
|
291
|
+
:param conclusions: The conclusions of the case.
|
292
|
+
:param case_query: The query containing the case to classify.
|
293
|
+
:param update_existing: Whether to update rules that gave the required type of conclusions.
|
294
|
+
:return: True if the rdr should ask the expert, False otherwise.
|
295
|
+
"""
|
296
|
+
if conclusions is None:
|
297
|
+
return True
|
298
|
+
elif is_iterable(conclusions) and len(conclusions) == 0:
|
299
|
+
return True
|
300
|
+
elif isinstance(conclusions, dict):
|
301
|
+
if case_query.attribute_name not in conclusions:
|
302
|
+
return True
|
303
|
+
conclusions = conclusions[case_query.attribute_name]
|
304
|
+
conclusion_types = map(type, make_list(conclusions))
|
305
|
+
if not any(ct in case_query.core_attribute_type for ct in conclusion_types):
|
306
|
+
return True
|
307
|
+
elif update_existing:
|
308
|
+
return True
|
309
|
+
else:
|
310
|
+
return False
|
311
|
+
|
265
312
|
@abstractmethod
|
266
313
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
267
314
|
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
@@ -352,7 +399,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
352
399
|
:return: The module that contains the rdr classifier function.
|
353
400
|
"""
|
354
401
|
# remove from imports if exists first
|
355
|
-
|
402
|
+
package_name = get_import_path_from_path(package_name)
|
403
|
+
name = f"{package_name}.{self.generated_python_file_name}" if package_name else self.generated_python_file_name
|
356
404
|
try:
|
357
405
|
module = importlib.import_module(name)
|
358
406
|
del sys.modules[name]
|
@@ -374,6 +422,10 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
374
422
|
conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys() if not isinstance(rules_dict[rid], MultiClassStopRule)]
|
375
423
|
all_func_names = condition_func_names + conclusion_func_names
|
376
424
|
filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
|
425
|
+
cases_path = f"{model_dir}/{self.generated_python_cases_file_name}.py"
|
426
|
+
cases_import_path = get_import_path_from_path(model_dir)
|
427
|
+
cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path\
|
428
|
+
else self.generated_python_cases_file_name
|
377
429
|
functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
|
378
430
|
# get the scope from the imports in the file
|
379
431
|
scope = extract_imports(filepath)
|
@@ -381,13 +433,15 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
381
433
|
if rule.conditions is not None:
|
382
434
|
rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
|
383
435
|
rule.conditions.scope = scope
|
436
|
+
if os.path.exists(cases_path):
|
437
|
+
rule.corner_case_metadata = importlib.import_module(cases_import_path).__dict__.get(f"corner_case_{rule.uid}", None)
|
384
438
|
if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
|
385
439
|
rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
|
386
440
|
rule.conclusion.scope = scope
|
387
441
|
|
388
442
|
@abstractmethod
|
389
443
|
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
|
390
|
-
defs_file: Optional[str] = None):
|
444
|
+
defs_file: Optional[str] = None, cases_file: Optional[str] = None):
|
391
445
|
"""
|
392
446
|
Write the rules as source code to a file.
|
393
447
|
|
@@ -395,6 +449,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
395
449
|
:param file: The file to write the source code to.
|
396
450
|
:param parent_indent: The indentation of the parent rule.
|
397
451
|
:param defs_file: The file to write the definitions to.
|
452
|
+
:param cases_file: The file to write the cases to.
|
398
453
|
"""
|
399
454
|
pass
|
400
455
|
|
@@ -407,25 +462,28 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
407
462
|
os.makedirs(model_dir, exist_ok=True)
|
408
463
|
if not os.path.exists(model_dir + '/__init__.py'):
|
409
464
|
with open(model_dir + '/__init__.py', 'w') as f:
|
410
|
-
f.write("
|
411
|
-
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
465
|
+
f.write("from . import *\n")
|
466
|
+
func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
|
412
467
|
file_name = model_dir + f"/{self.generated_python_file_name}.py"
|
413
468
|
defs_file_name = model_dir + f"/{self.generated_python_defs_file_name}.py"
|
469
|
+
cases_file_name = model_dir + f"/{self.generated_python_cases_file_name}.py"
|
414
470
|
imports, defs_imports = self._get_imports()
|
415
471
|
# clear the files first
|
416
472
|
with open(defs_file_name, "w") as f:
|
417
473
|
f.write(defs_imports + "\n\n")
|
474
|
+
with open(cases_file_name, "w") as cases_f:
|
475
|
+
cases_f.write("# This file contains the corner cases for the rules.\n")
|
418
476
|
with open(file_name, "w") as f:
|
419
477
|
imports += f"from .{self.generated_python_defs_file_name} import *\n"
|
420
|
-
imports += f"from ripple_down_rules.rdr import {self.__class__.__name__}\n"
|
421
478
|
f.write(imports + "\n\n")
|
422
479
|
f.write(f"attribute_name = '{self.attribute_name}'\n")
|
423
480
|
f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
|
424
|
-
f.write(f"
|
481
|
+
f.write(f"mutually_exclusive = {self.mutually_exclusive}\n")
|
425
482
|
f.write(f"\n\n{func_def}")
|
426
483
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
427
484
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
428
|
-
self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4, defs_file=defs_file_name
|
485
|
+
self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4, defs_file=defs_file_name,
|
486
|
+
cases_file=cases_file_name)
|
429
487
|
|
430
488
|
@property
|
431
489
|
@abstractmethod
|
@@ -474,6 +532,10 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
474
532
|
def generated_python_defs_file_name(self) -> str:
|
475
533
|
return f"{self.generated_python_file_name}_defs"
|
476
534
|
|
535
|
+
@property
|
536
|
+
def generated_python_cases_file_name(self) -> str:
|
537
|
+
return f"{self.generated_python_file_name}_cases"
|
538
|
+
|
477
539
|
|
478
540
|
@property
|
479
541
|
def conclusion_type(self) -> Tuple[Type]:
|
@@ -493,7 +555,8 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
493
555
|
return self.start_rule.conclusion_name
|
494
556
|
|
495
557
|
def _to_json(self) -> Dict[str, Any]:
|
496
|
-
return {"start_rule": self.start_rule.to_json(),
|
558
|
+
return {"start_rule": self.start_rule.to_json(),
|
559
|
+
"generated_python_file_name": self.generated_python_file_name,
|
497
560
|
"name": self.name,
|
498
561
|
"case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
|
499
562
|
"case_name": self.case_name}
|
@@ -526,6 +589,11 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
526
589
|
|
527
590
|
class SingleClassRDR(RDRWithCodeWriter):
|
528
591
|
|
592
|
+
mutually_exclusive: bool = True
|
593
|
+
"""
|
594
|
+
The output of the classification of this rdr negates all other possible outputs, there can only be one true value.
|
595
|
+
"""
|
596
|
+
|
529
597
|
def __init__(self, default_conclusion: Optional[Any] = None, **kwargs):
|
530
598
|
"""
|
531
599
|
:param start_rule: The starting rule for the classifier.
|
@@ -550,7 +618,7 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
550
618
|
pred = self.evaluate(case_query.case)
|
551
619
|
if pred.conclusion(case_query.case) != case_query.target_value:
|
552
620
|
expert.ask_for_conditions(case_query, pred)
|
553
|
-
pred.fit_rule(case_query
|
621
|
+
pred.fit_rule(case_query)
|
554
622
|
|
555
623
|
return self.classify(case_query.case)
|
556
624
|
|
@@ -563,18 +631,24 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
563
631
|
"""
|
564
632
|
if not self.start_rule:
|
565
633
|
expert.ask_for_conditions(case_query)
|
566
|
-
self.start_rule = SingleClassRule(case_query
|
567
|
-
conclusion_name=case_query.attribute_name)
|
634
|
+
self.start_rule = SingleClassRule.from_case_query(case_query)
|
568
635
|
|
569
|
-
def classify(self, case: Case, modify_case: bool = False
|
636
|
+
def classify(self, case: Case, modify_case: bool = False,
|
637
|
+
case_query: Optional[CaseQuery] = None) -> Optional[Any]:
|
570
638
|
"""
|
571
639
|
Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
|
572
640
|
|
573
641
|
:param case: The case to classify.
|
574
642
|
:param modify_case: Whether to modify the original case attributes with the conclusion or not.
|
643
|
+
:param case_query: The case query containing the case and the target category to compare the case with.
|
575
644
|
"""
|
576
645
|
pred = self.evaluate(case)
|
577
|
-
|
646
|
+
conclusion = pred.conclusion(case) if pred is not None else None
|
647
|
+
if pred is not None and pred.fired and case_query is not None:
|
648
|
+
if pred.corner_case_metadata is None and conclusion is not None\
|
649
|
+
and type(conclusion) in case_query.core_attribute_type:
|
650
|
+
pred.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
651
|
+
return conclusion if pred is not None and pred.fired else self.default_conclusion
|
578
652
|
|
579
653
|
def evaluate(self, case: Case) -> SingleClassRule:
|
580
654
|
"""
|
@@ -590,22 +664,24 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
590
664
|
f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
|
591
665
|
|
592
666
|
def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
|
593
|
-
defs_file: Optional[str] = None):
|
667
|
+
defs_file: Optional[str] = None, cases_file: Optional[str] = None):
|
594
668
|
"""
|
595
669
|
Write the rules as source code to a file.
|
596
670
|
"""
|
597
671
|
if rule.conditions:
|
672
|
+
rule.write_corner_case_as_source_code(cases_file)
|
598
673
|
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
599
674
|
file.write(if_clause)
|
600
675
|
if rule.refinement:
|
601
676
|
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
|
602
|
-
defs_file=defs_file)
|
677
|
+
defs_file=defs_file, cases_file=cases_file)
|
603
678
|
|
604
679
|
conclusion_call = rule.write_conclusion_as_source_code(parent_indent, defs_file)
|
605
680
|
file.write(conclusion_call)
|
606
681
|
|
607
682
|
if rule.alternative:
|
608
|
-
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file
|
683
|
+
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file,
|
684
|
+
cases_file=cases_file)
|
609
685
|
|
610
686
|
@property
|
611
687
|
def conclusion_type_hint(self) -> str:
|
@@ -643,23 +719,33 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
643
719
|
"""
|
644
720
|
The conditions of the stopping rule if needed.
|
645
721
|
"""
|
722
|
+
mutually_exclusive: bool = False
|
723
|
+
"""
|
724
|
+
The output of the classification of this rdr allows for more than one true value as conclusion.
|
725
|
+
"""
|
646
726
|
|
647
727
|
def __init__(self, start_rule: Optional[MultiClassTopRule] = None,
|
648
|
-
mode: MCRDRMode = MCRDRMode.StopOnly):
|
728
|
+
mode: MCRDRMode = MCRDRMode.StopOnly, **kwargs):
|
649
729
|
"""
|
650
730
|
:param start_rule: The starting rules for the classifier.
|
651
731
|
:param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
|
652
732
|
"""
|
653
|
-
super(MultiClassRDR, self).__init__(start_rule)
|
733
|
+
super(MultiClassRDR, self).__init__(start_rule, **kwargs)
|
654
734
|
self.mode: MCRDRMode = mode
|
655
735
|
|
656
|
-
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False
|
736
|
+
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
|
737
|
+
case_query: Optional[CaseQuery] = None) -> Set[Any]:
|
657
738
|
evaluated_rule = self.start_rule
|
658
739
|
self.conclusions = []
|
659
740
|
while evaluated_rule:
|
660
741
|
next_rule = evaluated_rule(case)
|
661
742
|
if evaluated_rule.fired:
|
662
|
-
|
743
|
+
rule_conclusion = evaluated_rule.conclusion(case)
|
744
|
+
if evaluated_rule.corner_case_metadata is None and case_query is not None:
|
745
|
+
if rule_conclusion is not None and len(make_list(rule_conclusion)) > 0\
|
746
|
+
and any(ct in case_query.core_attribute_type for ct in map(type, make_list(rule_conclusion))):
|
747
|
+
evaluated_rule.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
748
|
+
self.add_conclusion(rule_conclusion)
|
663
749
|
evaluated_rule = next_rule
|
664
750
|
return make_set(self.conclusions)
|
665
751
|
|
@@ -687,7 +773,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
687
773
|
self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule)
|
688
774
|
else:
|
689
775
|
# Rule fired and target is correct or there is no target to compare
|
690
|
-
self.add_conclusion(
|
776
|
+
self.add_conclusion(rule_conclusion)
|
691
777
|
|
692
778
|
if not next_rule:
|
693
779
|
if not make_set(target_value).issubset(make_set(self.conclusions)):
|
@@ -699,16 +785,18 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
699
785
|
return self.conclusions
|
700
786
|
|
701
787
|
def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
|
702
|
-
file, parent_indent: str = "", defs_file: Optional[str] = None
|
788
|
+
file, parent_indent: str = "", defs_file: Optional[str] = None,
|
789
|
+
cases_file: Optional[str] = None):
|
703
790
|
if rule == self.start_rule:
|
704
791
|
file.write(f"{parent_indent}conclusions = set()\n")
|
705
792
|
if rule.conditions:
|
793
|
+
rule.write_corner_case_as_source_code(cases_file)
|
706
794
|
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
707
795
|
file.write(if_clause)
|
708
796
|
conclusion_indent = parent_indent
|
709
797
|
if hasattr(rule, "refinement") and rule.refinement:
|
710
798
|
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
|
711
|
-
defs_file=defs_file)
|
799
|
+
defs_file=defs_file, cases_file=cases_file)
|
712
800
|
conclusion_indent = parent_indent + " " * 4
|
713
801
|
file.write(f"{conclusion_indent}else:\n")
|
714
802
|
|
@@ -716,7 +804,8 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
716
804
|
file.write(conclusion_call)
|
717
805
|
|
718
806
|
if rule.alternative:
|
719
|
-
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file
|
807
|
+
self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file,
|
808
|
+
cases_file=cases_file)
|
720
809
|
|
721
810
|
@property
|
722
811
|
def conclusion_type_hint(self) -> str:
|
@@ -742,8 +831,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
742
831
|
"""
|
743
832
|
if not self.start_rule:
|
744
833
|
conditions = expert.ask_for_conditions(case_query)
|
745
|
-
self.start_rule = MultiClassTopRule(
|
746
|
-
conclusion_name=case_query.attribute_name)
|
834
|
+
self.start_rule = MultiClassTopRule.from_case_query(case_query)
|
747
835
|
|
748
836
|
@property
|
749
837
|
def last_top_rule(self) -> Optional[MultiClassTopRule]:
|
@@ -764,7 +852,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
764
852
|
if is_conflicting(rule_conclusion, case_query.target_value):
|
765
853
|
self.stop_conclusion(case_query, expert, evaluated_rule)
|
766
854
|
else:
|
767
|
-
self.add_conclusion(
|
855
|
+
self.add_conclusion(rule_conclusion)
|
768
856
|
|
769
857
|
def stop_conclusion(self, case_query: CaseQuery,
|
770
858
|
expert: Expert, evaluated_rule: MultiClassTopRule):
|
@@ -776,12 +864,13 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
776
864
|
:param evaluated_rule: The evaluated rule to ask the expert about.
|
777
865
|
"""
|
778
866
|
conditions = expert.ask_for_conditions(case_query, evaluated_rule)
|
779
|
-
evaluated_rule.fit_rule(case_query
|
867
|
+
evaluated_rule.fit_rule(case_query)
|
780
868
|
if self.mode == MCRDRMode.StopPlusRule:
|
781
869
|
self.stop_rule_conditions = conditions
|
782
870
|
if self.mode == MCRDRMode.StopPlusRuleCombined:
|
783
871
|
new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
|
784
|
-
|
872
|
+
case_query.conditions = new_top_rule_conditions
|
873
|
+
self.add_top_rule(case_query)
|
785
874
|
|
786
875
|
def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
|
787
876
|
"""
|
@@ -793,19 +882,19 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
793
882
|
if self.stop_rule_conditions and self.mode == MCRDRMode.StopPlusRule:
|
794
883
|
conditions = self.stop_rule_conditions
|
795
884
|
self.stop_rule_conditions = None
|
885
|
+
case_query.conditions = conditions
|
796
886
|
else:
|
797
887
|
conditions = expert.ask_for_conditions(case_query)
|
798
|
-
self.add_top_rule(
|
888
|
+
self.add_top_rule(case_query)
|
799
889
|
|
800
|
-
def add_conclusion(self,
|
890
|
+
def add_conclusion(self, rule_conclusion: List[Any]) -> None:
|
801
891
|
"""
|
802
892
|
Add the conclusion of the evaluated rule to the list of conclusions.
|
803
893
|
|
804
|
-
:param
|
805
|
-
|
894
|
+
:param rule_conclusion: The conclusion of the evaluated rule, which can be a single conclusion
|
895
|
+
or a set of conclusions.
|
806
896
|
"""
|
807
897
|
conclusion_types = [type(c) for c in self.conclusions]
|
808
|
-
rule_conclusion = evaluated_rule.conclusion(case)
|
809
898
|
if type(rule_conclusion) not in conclusion_types:
|
810
899
|
self.conclusions.extend(make_list(rule_conclusion))
|
811
900
|
else:
|
@@ -818,15 +907,13 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
818
907
|
self.conclusions.remove(c)
|
819
908
|
self.conclusions.extend(make_list(combined_conclusion))
|
820
909
|
|
821
|
-
def add_top_rule(self,
|
910
|
+
def add_top_rule(self, case_query: CaseQuery):
|
822
911
|
"""
|
823
912
|
Add a top rule to the classifier, which is a rule that is always checked and is part of the start_rules list.
|
824
913
|
|
825
|
-
:param
|
826
|
-
:param conclusion: The conclusion of the rule.
|
827
|
-
:param corner_case: The corner case of the rule.
|
914
|
+
:param case_query: The case query to add the top rule for.
|
828
915
|
"""
|
829
|
-
self.start_rule.alternative = MultiClassTopRule(
|
916
|
+
self.start_rule.alternative = MultiClassTopRule.from_case_query(case_query)
|
830
917
|
|
831
918
|
@staticmethod
|
832
919
|
def start_rule_type() -> Type[Rule]:
|
@@ -887,59 +974,19 @@ class GeneralRDR(RippleDownRules):
|
|
887
974
|
def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
|
888
975
|
return [rdr.start_rule for rdr in self.start_rules_dict.values()]
|
889
976
|
|
890
|
-
def classify(self, case: Any, modify_case: bool = False
|
977
|
+
def classify(self, case: Any, modify_case: bool = False,
|
978
|
+
case_query: Optional[CaseQuery] = None) -> Optional[Dict[str, Any]]:
|
891
979
|
"""
|
892
980
|
Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
|
893
981
|
the classification until no more categories can be added.
|
894
982
|
|
895
983
|
:param case: The case to classify.
|
896
984
|
:param modify_case: Whether to modify the original case or create a copy and modify it.
|
985
|
+
:param case_query: The case query containing the case and the target category to compare the case with.
|
897
986
|
:return: The categories that the case belongs to.
|
898
987
|
"""
|
899
|
-
return
|
900
|
-
|
901
|
-
@staticmethod
|
902
|
-
def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
|
903
|
-
case: Any, modify_original_case: bool = False) -> Dict[str, Any]:
|
904
|
-
"""
|
905
|
-
Classify a case by going through all classifiers and adding the categories that are classified,
|
906
|
-
and then restarting the classification until no more categories can be added.
|
907
|
-
|
908
|
-
:param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
|
909
|
-
:param case: The case to classify.
|
910
|
-
:param modify_original_case: Whether to modify the original case or create a copy and modify it.
|
911
|
-
:return: The categories that the case belongs to.
|
912
|
-
"""
|
913
|
-
conclusions = {}
|
914
|
-
case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
|
915
|
-
case_cp = copy_case(case) if not modify_original_case else case
|
916
|
-
while True:
|
917
|
-
new_conclusions = {}
|
918
|
-
for attribute_name, rdr in classifiers_dict.items():
|
919
|
-
pred_atts = rdr.classify(case_cp)
|
920
|
-
if pred_atts is None:
|
921
|
-
continue
|
922
|
-
if rdr.type_ is SingleClassRDR:
|
923
|
-
if attribute_name not in conclusions or \
|
924
|
-
(attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
|
925
|
-
conclusions[attribute_name] = pred_atts
|
926
|
-
new_conclusions[attribute_name] = pred_atts
|
927
|
-
else:
|
928
|
-
pred_atts = make_set(pred_atts)
|
929
|
-
if attribute_name in conclusions:
|
930
|
-
pred_atts = {p for p in pred_atts if p not in conclusions[attribute_name]}
|
931
|
-
if len(pred_atts) > 0:
|
932
|
-
new_conclusions[attribute_name] = pred_atts
|
933
|
-
if attribute_name not in conclusions:
|
934
|
-
conclusions[attribute_name] = set()
|
935
|
-
conclusions[attribute_name].update(pred_atts)
|
936
|
-
if attribute_name in new_conclusions:
|
937
|
-
mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
|
938
|
-
case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive)
|
939
|
-
update_case(case_query, new_conclusions)
|
940
|
-
if len(new_conclusions) == 0:
|
941
|
-
break
|
942
|
-
return conclusions
|
988
|
+
return general_rdr_classify(self.start_rules_dict, case, modify_original_case=modify_case,
|
989
|
+
case_query=case_query)
|
943
990
|
|
944
991
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
945
992
|
-> Dict[str, Any]:
|
@@ -1026,7 +1073,7 @@ class GeneralRDR(RippleDownRules):
|
|
1026
1073
|
"""
|
1027
1074
|
for rdr in self.start_rules_dict.values():
|
1028
1075
|
rdr._write_to_python(model_dir)
|
1029
|
-
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
1076
|
+
func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
|
1030
1077
|
with open(model_dir + f"/{self.generated_python_file_name}.py", "w") as f:
|
1031
1078
|
f.write(self._get_imports() + "\n\n")
|
1032
1079
|
f.write("classifiers_dict = dict()\n")
|
@@ -1036,7 +1083,7 @@ class GeneralRDR(RippleDownRules):
|
|
1036
1083
|
f.write(func_def)
|
1037
1084
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
1038
1085
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
1039
|
-
f.write(f"{' ' * 4}return
|
1086
|
+
f.write(f"{' ' * 4}return general_rdr_classify(classifiers_dict, case, **kwargs)\n")
|
1040
1087
|
|
1041
1088
|
@property
|
1042
1089
|
def _default_generated_python_file_name(self) -> Optional[str]:
|
@@ -1061,7 +1108,7 @@ class GeneralRDR(RippleDownRules):
|
|
1061
1108
|
# add type hints
|
1062
1109
|
imports += f"from typing_extensions import Dict, Any\n"
|
1063
1110
|
# import rdr type
|
1064
|
-
imports += f"from ripple_down_rules.
|
1111
|
+
imports += f"from ripple_down_rules.helpers import general_rdr_classify\n"
|
1065
1112
|
# add case type
|
1066
1113
|
imports += f"from ripple_down_rules.datastructures.case import Case, create_case\n"
|
1067
1114
|
imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
|