ripple-down-rules 0.5.75__py3-none-any.whl → 0.5.80__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +1 -1
- ripple_down_rules/rdr.py +123 -88
- ripple_down_rules/rules.py +14 -7
- ripple_down_rules/user_interface/template_file_creator.py +1 -1
- ripple_down_rules/utils.py +66 -6
- {ripple_down_rules-0.5.75.dist-info → ripple_down_rules-0.5.80.dist-info}/METADATA +1 -1
- {ripple_down_rules-0.5.75.dist-info → ripple_down_rules-0.5.80.dist-info}/RECORD +10 -10
- {ripple_down_rules-0.5.75.dist-info → ripple_down_rules-0.5.80.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.5.75.dist-info → ripple_down_rules-0.5.80.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.5.75.dist-info → ripple_down_rules-0.5.80.dist-info}/top_level.txt +0 -0
ripple_down_rules/__init__.py
CHANGED
ripple_down_rules/rdr.py
CHANGED
@@ -1,20 +1,16 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import copyreg
|
4
3
|
import importlib
|
5
4
|
import os
|
5
|
+
from abc import ABC, abstractmethod
|
6
|
+
from copy import copy
|
6
7
|
|
7
8
|
from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
|
8
|
-
|
9
9
|
from . import logger
|
10
|
-
import sys
|
11
|
-
from abc import ABC, abstractmethod
|
12
|
-
from copy import copy
|
13
|
-
from io import TextIOWrapper
|
14
|
-
from types import ModuleType
|
15
10
|
|
16
11
|
try:
|
17
12
|
from matplotlib import pyplot as plt
|
13
|
+
|
18
14
|
Figure = plt.Figure
|
19
15
|
except ImportError as e:
|
20
16
|
logger.debug(f"{e}: matplotlib is not installed")
|
@@ -32,13 +28,13 @@ from .datastructures.enums import MCRDRMode
|
|
32
28
|
from .experts import Expert, Human
|
33
29
|
from .helpers import is_matching, general_rdr_classify
|
34
30
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
31
|
+
|
35
32
|
try:
|
36
33
|
from .user_interface.gui import RDRCaseViewer
|
37
34
|
except ImportError as e:
|
38
35
|
RDRCaseViewer = None
|
39
|
-
from .utils import draw_tree, make_set,
|
40
|
-
|
41
|
-
is_conflicting, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
|
36
|
+
from .utils import draw_tree, make_set, SubclassJSONSerializer, make_list, get_type_from_string, \
|
37
|
+
is_conflicting, extract_function_source, extract_imports, get_full_class_name, \
|
42
38
|
is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types
|
43
39
|
|
44
40
|
|
@@ -98,13 +94,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
98
94
|
if self.viewer is not None:
|
99
95
|
self.viewer.set_save_function(self.save)
|
100
96
|
|
101
|
-
def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None
|
97
|
+
def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None,
|
98
|
+
package_name: Optional[str] = None) -> str:
|
102
99
|
"""
|
103
100
|
Save the classifier to a file.
|
104
101
|
|
105
102
|
:param save_dir: The directory to save the classifier to.
|
106
103
|
:param model_name: The name of the model to save. If None, a default name is generated.
|
107
|
-
:param
|
104
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
105
|
+
is required in case of relative imports in the generated python file.
|
108
106
|
:return: The name of the saved model.
|
109
107
|
"""
|
110
108
|
save_dir = save_dir or self.save_dir
|
@@ -124,22 +122,25 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
124
122
|
json_dir = os.path.join(model_dir, self.metadata_folder)
|
125
123
|
os.makedirs(json_dir, exist_ok=True)
|
126
124
|
self.to_json_file(os.path.join(json_dir, self.model_name))
|
127
|
-
self._write_to_python(model_dir)
|
125
|
+
self._write_to_python(model_dir, package_name=package_name)
|
128
126
|
return self.model_name
|
129
127
|
|
130
128
|
@classmethod
|
131
|
-
def load(cls, load_dir: str, model_name: str
|
129
|
+
def load(cls, load_dir: str, model_name: str,
|
130
|
+
package_name: Optional[str] = None) -> Self:
|
132
131
|
"""
|
133
132
|
Load the classifier from a file.
|
134
133
|
|
135
134
|
:param load_dir: The path to the model directory to load the classifier from.
|
136
135
|
:param model_name: The name of the model to load.
|
136
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
137
|
+
is required in case of relative imports in the generated python file.
|
137
138
|
"""
|
138
139
|
model_dir = os.path.join(load_dir, model_name)
|
139
140
|
json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
|
140
141
|
rdr = cls.from_json_file(json_file)
|
141
142
|
try:
|
142
|
-
rdr.update_from_python(model_dir)
|
143
|
+
rdr.update_from_python(model_dir, package_name=package_name)
|
143
144
|
except (FileNotFoundError, ValueError) as e:
|
144
145
|
logger.warning(f"Could not load the python file for the model {model_name} from {model_dir}. "
|
145
146
|
f"Make sure the file exists and is valid.")
|
@@ -148,11 +149,13 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
148
149
|
return rdr
|
149
150
|
|
150
151
|
@abstractmethod
|
151
|
-
def _write_to_python(self, model_dir: str):
|
152
|
+
def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
|
152
153
|
"""
|
153
154
|
Write the tree of rules as source code to a file.
|
154
155
|
|
155
156
|
:param model_dir: The path to the directory to write the source code to.
|
157
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
158
|
+
is required in case of relative imports in the generated python file.
|
156
159
|
"""
|
157
160
|
pass
|
158
161
|
|
@@ -373,11 +376,13 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
373
376
|
pass
|
374
377
|
|
375
378
|
@abstractmethod
|
376
|
-
def update_from_python(self, model_dir: str):
|
379
|
+
def update_from_python(self, model_dir: str, package_name: Optional[str] = None):
|
377
380
|
"""
|
378
381
|
Update the rules from the generated python file, that might have been modified by the user.
|
379
382
|
|
380
383
|
:param model_dir: The directory where the generated python file is located.
|
384
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
385
|
+
is required in case of relative imports in the generated python file.
|
381
386
|
"""
|
382
387
|
pass
|
383
388
|
|
@@ -408,30 +413,34 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
408
413
|
|
409
414
|
class RDRWithCodeWriter(RippleDownRules, ABC):
|
410
415
|
|
411
|
-
def update_from_python(self, model_dir: str):
|
416
|
+
def update_from_python(self, model_dir: str, package_name: Optional[str] = None):
|
412
417
|
"""
|
413
418
|
Update the rules from the generated python file, that might have been modified by the user.
|
414
419
|
|
415
420
|
:param model_dir: The directory where the generated python file is located.
|
421
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
422
|
+
is required in case of relative imports in the generated python file.
|
416
423
|
"""
|
417
|
-
rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants)
|
424
|
+
rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants)
|
425
|
+
if r.conditions is not None}
|
418
426
|
condition_func_names = [f'conditions_{rid}' for rid in rules_dict.keys()]
|
419
|
-
conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys()
|
427
|
+
conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys()
|
428
|
+
if not isinstance(rules_dict[rid], MultiClassStopRule)]
|
420
429
|
all_func_names = condition_func_names + conclusion_func_names
|
421
430
|
filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
|
422
431
|
cases_path = f"{model_dir}/{self.generated_python_cases_file_name}.py"
|
423
432
|
cases_import_path = get_import_path_from_path(model_dir)
|
424
|
-
cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path\
|
433
|
+
cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path \
|
425
434
|
else self.generated_python_cases_file_name
|
426
435
|
functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
|
427
436
|
# get the scope from the imports in the file
|
428
|
-
scope = extract_imports(filepath)
|
437
|
+
scope = extract_imports(filepath, package_name=package_name)
|
429
438
|
for rule in [self.start_rule] + list(self.start_rule.descendants):
|
430
439
|
if rule.conditions is not None:
|
431
440
|
rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
|
432
441
|
rule.conditions.scope = scope
|
433
442
|
if os.path.exists(cases_path):
|
434
|
-
module = importlib.import_module(cases_import_path)
|
443
|
+
module = importlib.import_module(cases_import_path, package=package_name)
|
435
444
|
importlib.reload(module)
|
436
445
|
rule.corner_case_metadata = module.__dict__.get(f"corner_case_{rule.uid}", None)
|
437
446
|
if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
|
@@ -440,7 +449,8 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
440
449
|
|
441
450
|
@abstractmethod
|
442
451
|
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
|
443
|
-
defs_file: Optional[str] = None, cases_file: Optional[str] = None
|
452
|
+
defs_file: Optional[str] = None, cases_file: Optional[str] = None,
|
453
|
+
package_name: Optional[str] = None):
|
444
454
|
"""
|
445
455
|
Write the rules as source code to a file.
|
446
456
|
|
@@ -449,42 +459,62 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
449
459
|
:param parent_indent: The indentation of the parent rule.
|
450
460
|
:param defs_file: The file to write the definitions to.
|
451
461
|
:param cases_file: The file to write the cases to.
|
462
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
463
|
+
is required in case of relative imports in the generated python file.
|
452
464
|
"""
|
453
465
|
pass
|
454
466
|
|
455
|
-
def _write_to_python(self, model_dir: str):
|
467
|
+
def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
|
456
468
|
"""
|
457
469
|
Write the tree of rules as source code to a file.
|
458
470
|
|
459
471
|
:param model_dir: The path to the directory to write the source code to.
|
472
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
473
|
+
is required in case of relative imports in the generated python file.
|
460
474
|
"""
|
475
|
+
# Make sure the model directory exists and create an __init__.py file if it doesn't exist
|
461
476
|
os.makedirs(model_dir, exist_ok=True)
|
462
477
|
if not os.path.exists(model_dir + '/__init__.py'):
|
463
478
|
with open(model_dir + '/__init__.py', 'w') as f:
|
464
479
|
f.write("from . import *\n")
|
465
|
-
|
480
|
+
|
481
|
+
# Set the file names for the generated python files
|
466
482
|
file_name = model_dir + f"/{self.generated_python_file_name}.py"
|
467
483
|
defs_file_name = model_dir + f"/{self.generated_python_defs_file_name}.py"
|
468
484
|
cases_file_name = model_dir + f"/{self.generated_python_cases_file_name}.py"
|
469
|
-
|
470
|
-
#
|
485
|
+
|
486
|
+
# Get the required imports for the main file and the defs file
|
487
|
+
main_types, defs_types, corner_cases_types = self._get_types_to_import()
|
488
|
+
imports = get_imports_from_types(main_types, file_name, package_name)
|
489
|
+
defs_imports = get_imports_from_types(defs_types, defs_file_name, package_name)
|
490
|
+
corner_cases_imports = get_imports_from_types(corner_cases_types, cases_file_name, package_name)
|
491
|
+
|
492
|
+
# Add the imports to the defs file
|
471
493
|
with open(defs_file_name, "w") as f:
|
472
|
-
f.write(defs_imports + "\n\n")
|
473
|
-
|
494
|
+
f.write('\n'.join(defs_imports) + "\n\n\n")
|
495
|
+
|
496
|
+
# Add the imports to the cases file
|
497
|
+
case_factory_import = get_imports_from_types([CaseFactoryMetaData], cases_file_name, package_name)
|
498
|
+
corner_cases_imports.extend(case_factory_import)
|
474
499
|
with open(cases_file_name, "w") as cases_f:
|
475
500
|
cases_f.write("# This file contains the corner cases for the rules.\n")
|
476
|
-
cases_f.write('\n'.join(
|
501
|
+
cases_f.write('\n'.join(corner_cases_imports) + "\n\n\n")
|
502
|
+
|
503
|
+
# Add the imports, the attributes, and the function definition to the main file
|
504
|
+
func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
|
477
505
|
with open(file_name, "w") as f:
|
478
|
-
imports
|
479
|
-
f.write(imports + "\n\n")
|
506
|
+
imports.append(f"from .{self.generated_python_defs_file_name} import *")
|
507
|
+
f.write('\n'.join(imports) + "\n\n\n")
|
480
508
|
f.write(f"attribute_name = '{self.attribute_name}'\n")
|
481
509
|
f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
|
482
510
|
f.write(f"mutually_exclusive = {self.mutually_exclusive}\n")
|
483
511
|
f.write(f"\n\n{func_def}")
|
484
512
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
485
513
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
514
|
+
|
515
|
+
# Write the rules as source code to the main file
|
486
516
|
self.write_rules_as_source_code_to_file(self.start_rule, file_name, " " * 4, defs_file=defs_file_name,
|
487
|
-
cases_file=cases_file_name)
|
517
|
+
cases_file=cases_file_name, package_name=package_name)
|
488
518
|
|
489
519
|
@property
|
490
520
|
@abstractmethod
|
@@ -494,31 +524,27 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
494
524
|
"""
|
495
525
|
pass
|
496
526
|
|
497
|
-
def
|
527
|
+
def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
|
498
528
|
"""
|
499
|
-
:return: The
|
529
|
+
:return: The types of the main, defs, and corner cases files of the RDR classifier that will be imported.
|
500
530
|
"""
|
501
|
-
|
531
|
+
defs_types = set()
|
532
|
+
cases_types = set()
|
502
533
|
for rule in [self.start_rule] + list(self.start_rule.descendants):
|
503
534
|
if not rule.conditions:
|
504
535
|
continue
|
505
536
|
for scope in [rule.conditions.scope, rule.conclusion.scope]:
|
506
537
|
if scope is None:
|
507
538
|
continue
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
imports.append(f"from {conclusion_type.__module__} import {conclusion_type.__name__}")
|
518
|
-
imports.append("from ripple_down_rules.datastructures.case import Case, create_case")
|
519
|
-
imports = set(imports).difference(defs_imports_list)
|
520
|
-
imports = "\n".join(imports) + "\n"
|
521
|
-
return imports, defs_imports
|
539
|
+
defs_types.update(make_set(scope.values()))
|
540
|
+
cases_types.update(rule.get_corner_case_types_to_import())
|
541
|
+
defs_types.add(self.case_type)
|
542
|
+
main_types = set()
|
543
|
+
main_types.add(self.case_type)
|
544
|
+
main_types.update(make_set(self.conclusion_type))
|
545
|
+
main_types.update({Case, create_case})
|
546
|
+
main_types = main_types.difference(defs_types)
|
547
|
+
return main_types, defs_types, cases_types
|
522
548
|
|
523
549
|
@property
|
524
550
|
def _default_generated_python_file_name(self) -> Optional[str]:
|
@@ -537,7 +563,6 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
537
563
|
def generated_python_cases_file_name(self) -> str:
|
538
564
|
return f"{self.generated_python_file_name}_cases"
|
539
565
|
|
540
|
-
|
541
566
|
@property
|
542
567
|
def conclusion_type(self) -> Tuple[Type]:
|
543
568
|
"""
|
@@ -589,7 +614,6 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
589
614
|
|
590
615
|
|
591
616
|
class SingleClassRDR(RDRWithCodeWriter):
|
592
|
-
|
593
617
|
mutually_exclusive: bool = True
|
594
618
|
"""
|
595
619
|
The output of the classification of this rdr negates all other possible outputs, there can only be one true value.
|
@@ -646,7 +670,7 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
646
670
|
pred = self.evaluate(case)
|
647
671
|
conclusion = pred.conclusion(case) if pred is not None else None
|
648
672
|
if pred is not None and pred.fired and case_query is not None:
|
649
|
-
if pred.corner_case_metadata is None and conclusion is not None\
|
673
|
+
if pred.corner_case_metadata is None and conclusion is not None \
|
650
674
|
and type(conclusion) in case_query.core_attribute_type:
|
651
675
|
pred.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
652
676
|
return conclusion if pred is not None and pred.fired else self.default_conclusion
|
@@ -658,25 +682,27 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
658
682
|
matched_rule = self.start_rule(case) if self.start_rule is not None else None
|
659
683
|
return matched_rule if matched_rule is not None else self.start_rule
|
660
684
|
|
661
|
-
def _write_to_python(self, model_dir: str):
|
662
|
-
super()._write_to_python(model_dir)
|
685
|
+
def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
|
686
|
+
super()._write_to_python(model_dir, package_name=package_name)
|
663
687
|
if self.default_conclusion is not None:
|
664
688
|
with open(model_dir + f"/{self.generated_python_file_name}.py", "a") as f:
|
665
689
|
f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
|
666
690
|
|
667
691
|
def write_rules_as_source_code_to_file(self, rule: SingleClassRule, filename: str, parent_indent: str = "",
|
668
|
-
defs_file: Optional[str] = None, cases_file: Optional[str] = None
|
692
|
+
defs_file: Optional[str] = None, cases_file: Optional[str] = None,
|
693
|
+
package_name: Optional[str] = None):
|
669
694
|
"""
|
670
695
|
Write the rules as source code to a file.
|
671
696
|
"""
|
672
697
|
if rule.conditions:
|
673
|
-
rule.write_corner_case_as_source_code(cases_file)
|
698
|
+
rule.write_corner_case_as_source_code(cases_file, package_name=package_name)
|
674
699
|
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
675
700
|
with open(filename, "a") as file:
|
676
701
|
file.write(if_clause)
|
677
702
|
if rule.refinement:
|
678
703
|
self.write_rules_as_source_code_to_file(rule.refinement, filename, parent_indent + " ",
|
679
|
-
defs_file=defs_file, cases_file=cases_file
|
704
|
+
defs_file=defs_file, cases_file=cases_file,
|
705
|
+
package_name=package_name)
|
680
706
|
|
681
707
|
conclusion_call = rule.write_conclusion_as_source_code(parent_indent, defs_file)
|
682
708
|
with open(filename, "a") as file:
|
@@ -684,7 +710,7 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
684
710
|
|
685
711
|
if rule.alternative:
|
686
712
|
self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
|
687
|
-
cases_file=cases_file)
|
713
|
+
cases_file=cases_file, package_name=package_name)
|
688
714
|
|
689
715
|
@property
|
690
716
|
def conclusion_type_hint(self) -> str:
|
@@ -745,8 +771,9 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
745
771
|
if evaluated_rule.fired:
|
746
772
|
rule_conclusion = evaluated_rule.conclusion(case)
|
747
773
|
if evaluated_rule.corner_case_metadata is None and case_query is not None:
|
748
|
-
if rule_conclusion is not None and len(make_list(rule_conclusion)) > 0\
|
749
|
-
and any(
|
774
|
+
if rule_conclusion is not None and len(make_list(rule_conclusion)) > 0 \
|
775
|
+
and any(
|
776
|
+
ct in case_query.core_attribute_type for ct in map(type, make_list(rule_conclusion))):
|
750
777
|
evaluated_rule.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
751
778
|
self.add_conclusion(rule_conclusion)
|
752
779
|
evaluated_rule = next_rule
|
@@ -789,19 +816,20 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
789
816
|
|
790
817
|
def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
|
791
818
|
filename: str, parent_indent: str = "", defs_file: Optional[str] = None,
|
792
|
-
cases_file: Optional[str] = None):
|
819
|
+
cases_file: Optional[str] = None, package_name: Optional[str] = None):
|
793
820
|
if rule == self.start_rule:
|
794
821
|
with open(filename, "a") as file:
|
795
822
|
file.write(f"{parent_indent}conclusions = set()\n")
|
796
823
|
if rule.conditions:
|
797
|
-
rule.write_corner_case_as_source_code(cases_file)
|
824
|
+
rule.write_corner_case_as_source_code(cases_file, package_name=package_name)
|
798
825
|
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
799
826
|
with open(filename, "a") as file:
|
800
827
|
file.write(if_clause)
|
801
828
|
conclusion_indent = parent_indent
|
802
829
|
if hasattr(rule, "refinement") and rule.refinement:
|
803
830
|
self.write_rules_as_source_code_to_file(rule.refinement, filename, parent_indent + " ",
|
804
|
-
defs_file=defs_file, cases_file=cases_file
|
831
|
+
defs_file=defs_file, cases_file=cases_file,
|
832
|
+
package_name=package_name)
|
805
833
|
conclusion_indent = parent_indent + " " * 4
|
806
834
|
with open(filename, "a") as file:
|
807
835
|
file.write(f"{conclusion_indent}else:\n")
|
@@ -812,7 +840,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
812
840
|
|
813
841
|
if rule.alternative:
|
814
842
|
self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
|
815
|
-
cases_file=cases_file)
|
843
|
+
cases_file=cases_file, package_name=package_name)
|
816
844
|
|
817
845
|
@property
|
818
846
|
def conclusion_type_hint(self) -> str:
|
@@ -822,12 +850,11 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
822
850
|
else:
|
823
851
|
return f"Set[Union[{', '.join(conclusion_types)}]]"
|
824
852
|
|
825
|
-
def
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
return imports, defs_imports
|
853
|
+
def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
|
854
|
+
main_types, defs_types, cases_types = super()._get_types_to_import()
|
855
|
+
main_types.update({Set, Union, make_set})
|
856
|
+
defs_types.add(Union)
|
857
|
+
return main_types, defs_types, cases_types
|
831
858
|
|
832
859
|
def update_start_rule(self, case_query: CaseQuery, expert: Expert):
|
833
860
|
"""
|
@@ -1039,7 +1066,7 @@ class GeneralRDR(RippleDownRules):
|
|
1039
1066
|
|
1040
1067
|
def _to_json(self) -> Dict[str, Any]:
|
1041
1068
|
return {"start_rules": {name: rdr.to_json() for name, rdr in self.start_rules_dict.items()}
|
1042
|
-
|
1069
|
+
, "generated_python_file_name": self.generated_python_file_name,
|
1043
1070
|
"name": self.name,
|
1044
1071
|
"case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
|
1045
1072
|
"case_name": self.case_name}
|
@@ -1063,26 +1090,30 @@ class GeneralRDR(RippleDownRules):
|
|
1063
1090
|
new_rdr.case_name = data["case_name"]
|
1064
1091
|
return new_rdr
|
1065
1092
|
|
1066
|
-
def update_from_python(self, model_dir: str) -> None:
|
1093
|
+
def update_from_python(self, model_dir: str, package_name: Optional[str] = None) -> None:
|
1067
1094
|
"""
|
1068
1095
|
Update the rules from the generated python file, that might have been modified by the user.
|
1069
1096
|
|
1070
1097
|
:param model_dir: The directory where the model is stored.
|
1098
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
1099
|
+
is required in case of relative imports in the generated python file.
|
1071
1100
|
"""
|
1072
1101
|
for rdr in self.start_rules_dict.values():
|
1073
|
-
rdr.update_from_python(model_dir)
|
1102
|
+
rdr.update_from_python(model_dir, package_name=package_name)
|
1074
1103
|
|
1075
|
-
def _write_to_python(self, model_dir: str) -> None:
|
1104
|
+
def _write_to_python(self, model_dir: str, package_name: Optional[str] = None) -> None:
|
1076
1105
|
"""
|
1077
1106
|
Write the tree of rules as source code to a file.
|
1078
1107
|
|
1079
1108
|
:param model_dir: The directory where the model is stored.
|
1109
|
+
:param relative_imports: Whether to use relative imports in the generated python file.
|
1080
1110
|
"""
|
1081
1111
|
for rdr in self.start_rules_dict.values():
|
1082
|
-
rdr._write_to_python(model_dir)
|
1112
|
+
rdr._write_to_python(model_dir, package_name=package_name)
|
1083
1113
|
func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
|
1084
|
-
|
1085
|
-
|
1114
|
+
file_path = model_dir + f"/{self.generated_python_file_name}.py"
|
1115
|
+
with open(file_path, "w") as f:
|
1116
|
+
f.write(self._get_imports(file_path=file_path, package_name=package_name) + "\n\n")
|
1086
1117
|
f.write("classifiers_dict = dict()\n")
|
1087
1118
|
for rdr_key, rdr in self.start_rules_dict.items():
|
1088
1119
|
f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
|
@@ -1105,25 +1136,29 @@ class GeneralRDR(RippleDownRules):
|
|
1105
1136
|
def conclusion_type_hint(self) -> str:
|
1106
1137
|
return f"Dict[str, Any]"
|
1107
1138
|
|
1108
|
-
def _get_imports(self) -> str:
|
1139
|
+
def _get_imports(self, file_path: Optional[str] = None, package_name: Optional[str] = None) -> str:
|
1109
1140
|
"""
|
1110
1141
|
Get the imports needed for the generated python file.
|
1111
1142
|
|
1143
|
+
:param file_path: The path to the file where the imports will be written, if None, the imports will be absolute.
|
1144
|
+
:param package_name: The name of the package that contains the RDR classifier function, this
|
1145
|
+
is required in case of relative imports in the generated python file.
|
1112
1146
|
:return: The imports needed for the generated python file.
|
1113
1147
|
"""
|
1114
|
-
|
1148
|
+
all_types = set()
|
1115
1149
|
# add type hints
|
1116
|
-
|
1150
|
+
all_types.update({Dict, Any})
|
1117
1151
|
# import rdr type
|
1118
|
-
|
1152
|
+
all_types.add(general_rdr_classify)
|
1119
1153
|
# add case type
|
1120
|
-
|
1121
|
-
imports
|
1154
|
+
all_types.update({Case, create_case, self.case_type})
|
1155
|
+
# get the imports from the types
|
1156
|
+
imports = get_imports_from_types(all_types, target_file_path=file_path, package_name=package_name)
|
1122
1157
|
# add rdr python generated functions.
|
1123
1158
|
for rdr_key, rdr in self.start_rules_dict.items():
|
1124
|
-
imports
|
1125
|
-
|
1126
|
-
return imports
|
1159
|
+
imports.append(
|
1160
|
+
f"from . import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}")
|
1161
|
+
return '\n'.join(imports)
|
1127
1162
|
|
1128
1163
|
@staticmethod
|
1129
1164
|
def rdr_key_to_function_name(rdr_key: str) -> str:
|
ripple_down_rules/rules.py
CHANGED
@@ -8,7 +8,7 @@ from uuid import uuid4
|
|
8
8
|
|
9
9
|
from anytree import NodeMixin
|
10
10
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
11
|
-
from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
|
11
|
+
from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple, Type, Set
|
12
12
|
|
13
13
|
from .datastructures.callable_expression import CallableExpression
|
14
14
|
from .datastructures.case import Case
|
@@ -102,11 +102,21 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
102
102
|
"""
|
103
103
|
pass
|
104
104
|
|
105
|
-
def write_corner_case_as_source_code(self, cases_file:
|
105
|
+
def write_corner_case_as_source_code(self, cases_file: str, package_name: Optional[str] = None) -> None:
|
106
106
|
"""
|
107
107
|
Write the source code representation of the corner case of the rule to a file.
|
108
108
|
|
109
|
-
:param cases_file: The file to write the corner case to
|
109
|
+
:param cases_file: The file to write the corner case to.
|
110
|
+
:param package_name: The package name to use for relative imports.
|
111
|
+
"""
|
112
|
+
if self.corner_case_metadata is None:
|
113
|
+
return
|
114
|
+
with open(cases_file, 'a') as f:
|
115
|
+
f.write(f"corner_case_{self.uid} = {self.corner_case_metadata}" + "\n\n\n")
|
116
|
+
|
117
|
+
def get_corner_case_types_to_import(self) -> Set[Type]:
|
118
|
+
"""
|
119
|
+
Get the types that need to be imported for the corner case of the rule.
|
110
120
|
"""
|
111
121
|
if self.corner_case_metadata is None:
|
112
122
|
return
|
@@ -117,10 +127,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
117
127
|
types_to_import.add(self.corner_case_metadata.scenario)
|
118
128
|
if self.corner_case_metadata.case_conf is not None:
|
119
129
|
types_to_import.add(self.corner_case_metadata.case_conf)
|
120
|
-
|
121
|
-
with open(cases_file, 'a') as f:
|
122
|
-
f.write("\n".join(imports) + "\n\n\n")
|
123
|
-
f.write(f"corner_case_{self.uid} = {self.corner_case_metadata}" + "\n\n\n")
|
130
|
+
return types_to_import
|
124
131
|
|
125
132
|
def write_conclusion_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
|
126
133
|
"""
|
ripple_down_rules/utils.py
CHANGED
@@ -8,6 +8,7 @@ import json
|
|
8
8
|
import logging
|
9
9
|
import os
|
10
10
|
import re
|
11
|
+
import sys
|
11
12
|
import threading
|
12
13
|
import uuid
|
13
14
|
from collections import UserDict, defaultdict
|
@@ -15,6 +16,7 @@ from copy import deepcopy, copy
|
|
15
16
|
from dataclasses import is_dataclass, fields
|
16
17
|
from enum import Enum
|
17
18
|
from os.path import dirname
|
19
|
+
from pathlib import Path
|
18
20
|
from textwrap import dedent
|
19
21
|
from types import NoneType
|
20
22
|
|
@@ -44,7 +46,7 @@ from sqlalchemy import MetaData, inspect
|
|
44
46
|
from sqlalchemy.orm import Mapped, registry, class_mapper, DeclarativeBase as SQLTable, Session
|
45
47
|
from tabulate import tabulate
|
46
48
|
from typing_extensions import Callable, Set, Any, Type, Dict, TYPE_CHECKING, get_type_hints, \
|
47
|
-
get_origin, get_args, Tuple, Optional, List, Union, Self, ForwardRef
|
49
|
+
get_origin, get_args, Tuple, Optional, List, Union, Self, ForwardRef, Sequence, Iterable
|
48
50
|
|
49
51
|
if TYPE_CHECKING:
|
50
52
|
from .datastructures.case import Case
|
@@ -122,7 +124,15 @@ def get_imports_from_scope(scope: Dict[str, Any]) -> List[str]:
|
|
122
124
|
return imports
|
123
125
|
|
124
126
|
|
125
|
-
def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = None
|
127
|
+
def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = None,
|
128
|
+
package_name: Optional[str] = None) -> Dict[str, Any]:
|
129
|
+
"""
|
130
|
+
Extract imports from a Python file or an AST tree.
|
131
|
+
|
132
|
+
:param file_path: The path to the Python file to extract imports from.
|
133
|
+
:param tree: An AST tree to extract imports from. If provided, file_path is ignored.
|
134
|
+
:param package_name: The name of the package to use for relative imports.
|
135
|
+
"""
|
126
136
|
if tree is None:
|
127
137
|
if file_path is None:
|
128
138
|
raise ValueError("Either file_path or tree must be provided")
|
@@ -137,7 +147,7 @@ def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = N
|
|
137
147
|
module_name = alias.name
|
138
148
|
asname = alias.asname or alias.name
|
139
149
|
try:
|
140
|
-
scope[asname] = importlib.import_module(module_name)
|
150
|
+
scope[asname] = importlib.import_module(module_name, package=package_name)
|
141
151
|
except ImportError as e:
|
142
152
|
print(f"Could not import {module_name}: {e}")
|
143
153
|
elif isinstance(node, ast.ImportFrom):
|
@@ -146,7 +156,12 @@ def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = N
|
|
146
156
|
name = alias.name
|
147
157
|
asname = alias.asname or name
|
148
158
|
try:
|
149
|
-
|
159
|
+
if package_name is not None and node.level > 0: # Handle relative imports
|
160
|
+
module_rel_path = Path(os.path.join(file_path, *['..'] * node.level, module_name)).resolve()
|
161
|
+
idx = str(module_rel_path).rfind(package_name)
|
162
|
+
if idx != -1:
|
163
|
+
module_name = str(module_rel_path)[idx:].replace(os.path.sep, '.')
|
164
|
+
module = importlib.import_module(module_name, package=package_name)
|
150
165
|
scope[asname] = getattr(module, name)
|
151
166
|
except (ImportError, AttributeError) as e:
|
152
167
|
logging.warning(f"Could not import {module_name}: {e} while extracting imports from {file_path}")
|
@@ -833,37 +848,82 @@ def get_function_representation(func: Callable) -> str:
|
|
833
848
|
return func_name
|
834
849
|
|
835
850
|
|
836
|
-
def
|
851
|
+
def get_relative_import(target_file_path, imported_module_path: Optional[str] = None,
|
852
|
+
module: Optional[str] = None) -> str:
|
853
|
+
"""
|
854
|
+
Get a relative import path from the target file to the imported module.
|
855
|
+
|
856
|
+
:param target_file_path: The file path of the target file.
|
857
|
+
:param imported_module_path: The file path of the module being imported.
|
858
|
+
:param module: The module name, if available.
|
859
|
+
:return: A relative import path as a string.
|
860
|
+
"""
|
861
|
+
# Convert to absolute paths
|
862
|
+
if module is not None:
|
863
|
+
imported_module_path = sys.modules[module].__file__
|
864
|
+
if imported_module_path is None:
|
865
|
+
raise ValueError("Either imported_module_path or module must be provided")
|
866
|
+
target_path = Path(target_file_path).resolve()
|
867
|
+
imported_path = Path(imported_module_path).resolve()
|
868
|
+
|
869
|
+
# Compute relative path from target to imported module
|
870
|
+
rel_path = os.path.relpath(imported_path.parent, target_path.parent)
|
871
|
+
|
872
|
+
# Convert path to Python import format
|
873
|
+
rel_parts = [part.replace('..', '.') for part in Path(rel_path).parts]
|
874
|
+
rel_parts = rel_parts if rel_parts else ['']
|
875
|
+
|
876
|
+
# Join the parts and add the module name
|
877
|
+
joined_parts = "".join(rel_parts) + f".{imported_path.stem}"
|
878
|
+
joined_parts = f".{joined_parts}" if not joined_parts.startswith(".") else joined_parts
|
879
|
+
|
880
|
+
return joined_parts
|
881
|
+
|
882
|
+
|
883
|
+
def get_imports_from_types(type_objs: Iterable[Type],
|
884
|
+
target_file_path: Optional[str] = None,
|
885
|
+
package_name: Optional[str] = None) -> List[str]:
|
837
886
|
"""
|
838
887
|
Format import lines from type objects.
|
839
888
|
|
840
889
|
:param type_objs: A list of type objects to format.
|
890
|
+
:param target_file_path: The file path to which the imports should be relative.
|
891
|
+
:param package_name: The name of the package to use for relative imports.
|
841
892
|
"""
|
842
893
|
|
843
894
|
module_to_types = defaultdict(list)
|
895
|
+
module_to_path = {}
|
844
896
|
other_imports = []
|
845
897
|
for tp in type_objs:
|
846
898
|
try:
|
847
899
|
if isinstance(tp, type) or is_typing_type(tp):
|
848
900
|
module = tp.__module__
|
901
|
+
file = getattr(tp, '__file__', None)
|
849
902
|
name = tp.__qualname__
|
850
903
|
elif callable(tp):
|
851
904
|
module, name = get_function_import_data(tp)
|
905
|
+
file = get_method_file_name(tp)
|
852
906
|
elif hasattr(type(tp), "__module__"):
|
853
907
|
module = type(tp).__module__
|
908
|
+
file = getattr(tp, '__file__', None)
|
854
909
|
name = type(tp).__qualname__
|
855
910
|
else:
|
856
911
|
continue
|
857
912
|
if module is None or module == 'builtins' or module.startswith('_'):
|
858
913
|
continue
|
859
914
|
module_to_types[module].append(name)
|
915
|
+
if file:
|
916
|
+
module_to_path[module] = file
|
860
917
|
except AttributeError:
|
861
918
|
continue
|
862
919
|
|
863
920
|
lines = []
|
864
921
|
for module, names in module_to_types.items():
|
865
922
|
joined = ", ".join(sorted(set(names)))
|
866
|
-
|
923
|
+
import_path = module
|
924
|
+
if (target_file_path is not None) and (package_name is not None) and (package_name in module):
|
925
|
+
import_path = get_relative_import(target_file_path, module=module)
|
926
|
+
lines.append(f"from {import_path} import {joined}")
|
867
927
|
if other_imports:
|
868
928
|
lines.extend(other_imports)
|
869
929
|
return sorted(lines)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.5.
|
3
|
+
Version: 0.5.80
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -1,11 +1,11 @@
|
|
1
|
-
ripple_down_rules/__init__.py,sha256=
|
1
|
+
ripple_down_rules/__init__.py,sha256=OCXYavUU_yIcymr1OaTIQ9uNk3ZwDNlg1bX6wK7sZLQ,100
|
2
2
|
ripple_down_rules/experts.py,sha256=bwozulI1rv0uyaMZQqEgapDO-s8wvW0D6Jqxmvu5fik,12610
|
3
3
|
ripple_down_rules/helpers.py,sha256=v4oE7C5PfQUVJfSUs1FfLHEwrJXEHJLn4vJhJMvyCR8,4453
|
4
|
-
ripple_down_rules/rdr.py,sha256=
|
4
|
+
ripple_down_rules/rdr.py,sha256=2t04Qj931dh6UypLBy4RxJ5hFIui0ejFgvm7D5P8b-E,55118
|
5
5
|
ripple_down_rules/rdr_decorators.py,sha256=0sk7izDB53lTKSB9fm33vQahmY_05FyCOWljyQOMB0U,9072
|
6
|
-
ripple_down_rules/rules.py,sha256=
|
6
|
+
ripple_down_rules/rules.py,sha256=iVevv6iZ-6L2IPI0ZYbBjxBymXEQMmJGRFhiKUS-NmA,20352
|
7
7
|
ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
|
8
|
-
ripple_down_rules/utils.py,sha256=
|
8
|
+
ripple_down_rules/utils.py,sha256=mObBszTruGrRvD4MgD8tS1AnMtoyKrPl4RCciinhzY4,60132
|
9
9
|
ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
|
10
10
|
ripple_down_rules/datastructures/callable_expression.py,sha256=f3wUPTrLa1INO-1qfgVz87ryrCABronfyq0_JKWoZCs,12800
|
11
11
|
ripple_down_rules/datastructures/case.py,sha256=1zSaXUljaH6z3SgMGzYPoDyjotNam791KpYgvxuMh90,15463
|
@@ -16,9 +16,9 @@ ripple_down_rules/user_interface/gui.py,sha256=_lgZAUXxxaBUFQJAHjA5TBPp6XEvJ62t-
|
|
16
16
|
ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=Jrf7NxOdlrwGXH0Xyz3vzQprY-PNx9etfePOTpm2Gu8,6479
|
17
17
|
ripple_down_rules/user_interface/object_diagram.py,sha256=FEa2HaYR9QmTE6NsOwBvZ0jqmu3DKyg6mig2VE5ZP4Y,4956
|
18
18
|
ripple_down_rules/user_interface/prompt.py,sha256=AkkltdDIaioN43lkRKDPKSjJcmdSSGZDMYz7AL7X9lE,8082
|
19
|
-
ripple_down_rules/user_interface/template_file_creator.py,sha256=
|
20
|
-
ripple_down_rules-0.5.
|
21
|
-
ripple_down_rules-0.5.
|
22
|
-
ripple_down_rules-0.5.
|
23
|
-
ripple_down_rules-0.5.
|
24
|
-
ripple_down_rules-0.5.
|
19
|
+
ripple_down_rules/user_interface/template_file_creator.py,sha256=xWUt-RHRqrvoMo74o-bMLo8xNxil68wgbOZAMADZp2A,13570
|
20
|
+
ripple_down_rules-0.5.80.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
21
|
+
ripple_down_rules-0.5.80.dist-info/METADATA,sha256=slyVN82WJd9yd4efs415WgtiiCM__5BJeKneQhWl5zY,48214
|
22
|
+
ripple_down_rules-0.5.80.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
23
|
+
ripple_down_rules-0.5.80.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
24
|
+
ripple_down_rules-0.5.80.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|