ripple-down-rules 0.6.31__py3-none-any.whl → 0.6.41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +1 -1
- ripple_down_rules/datastructures/callable_expression.py +4 -1
- ripple_down_rules/rdr.py +501 -59
- ripple_down_rules/rules.py +4 -0
- ripple_down_rules/utils.py +36 -0
- {ripple_down_rules-0.6.31.dist-info → ripple_down_rules-0.6.41.dist-info}/METADATA +1 -1
- {ripple_down_rules-0.6.31.dist-info → ripple_down_rules-0.6.41.dist-info}/RECORD +10 -10
- {ripple_down_rules-0.6.31.dist-info → ripple_down_rules-0.6.41.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.6.31.dist-info → ripple_down_rules-0.6.41.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.6.31.dist-info → ripple_down_rules-0.6.41.dist-info}/top_level.txt +0 -0
ripple_down_rules/datastructures/callable_expression.py
CHANGED
@@ -129,6 +129,9 @@ class CallableExpression(SubclassJSONSerializer):
|
|
129
129
|
else:
|
130
130
|
conclusion_type = (conclusion_type,)
|
131
131
|
self.conclusion_type = conclusion_type
|
132
|
+
self.expected_types: Set[Type] = set(conclusion_type) if conclusion_type is not None else set()
|
133
|
+
if not mutually_exclusive:
|
134
|
+
self.expected_types.update({list, set})
|
132
135
|
self.scope: Optional[Dict[str, Any]] = scope if scope is not None else {}
|
133
136
|
self.scope = get_used_scope(self.user_input, self.scope)
|
134
137
|
self.expression_tree: AST = expression_tree if expression_tree else parse_string_to_expression(self.user_input)
|
@@ -158,7 +161,7 @@ class CallableExpression(SubclassJSONSerializer):
|
|
158
161
|
raise ValueError(f"Mutually exclusive types cannot be lists or sets, got {type(output)}")
|
159
162
|
output_types = {type(o) for o in make_list(output)}
|
160
163
|
output_types.add(type(output))
|
161
|
-
if not are_results_subclass_of_types(output_types, self.
|
164
|
+
if not are_results_subclass_of_types(output_types, self.expected_types):
|
162
165
|
raise ValueError(f"Not all result types {output_types} are subclasses of expected types"
|
163
166
|
f" {self.conclusion_type}")
|
164
167
|
return output
|
ripple_down_rules/rdr.py
CHANGED
@@ -1,12 +1,15 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import ast
|
3
4
|
import importlib
|
4
5
|
import json
|
5
6
|
import os
|
6
7
|
from abc import ABC, abstractmethod
|
7
8
|
from copy import copy
|
8
9
|
from dataclasses import is_dataclass
|
9
|
-
from
|
10
|
+
from io import TextIOWrapper
|
11
|
+
from pathlib import Path
|
12
|
+
from types import NoneType, ModuleType
|
10
13
|
|
11
14
|
from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
|
12
15
|
from . import logger
|
@@ -26,7 +29,7 @@ from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tupl
|
|
26
29
|
from .datastructures.callable_expression import CallableExpression
|
27
30
|
from .datastructures.case import Case, CaseAttribute, create_case
|
28
31
|
from .datastructures.dataclasses import CaseQuery
|
29
|
-
from .datastructures.enums import MCRDRMode
|
32
|
+
from .datastructures.enums import MCRDRMode, RDREdge
|
30
33
|
from .experts import Expert, Human
|
31
34
|
from .helpers import is_matching, general_rdr_classify, get_an_updated_case_copy
|
32
35
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule, MultiClassRefinementRule, \
|
@@ -38,7 +41,8 @@ except ImportError as e:
|
|
38
41
|
RDRCaseViewer = None
|
39
42
|
from .utils import draw_tree, make_set, SubclassJSONSerializer, make_list, get_type_from_string, \
|
40
43
|
is_value_conflicting, extract_function_source, extract_imports, get_full_class_name, \
|
41
|
-
is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types, render_tree
|
44
|
+
is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types, render_tree, \
|
45
|
+
get_types_to_import_from_func_type_hints, get_function_return_type, get_file_that_ends_with
|
42
46
|
|
43
47
|
|
44
48
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -96,6 +100,30 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
96
100
|
if RDRCaseViewer and any(RDRCaseViewer.instances) else None
|
97
101
|
self.input_node: Optional[Rule] = None
|
98
102
|
|
103
|
+
def write_rdr_metadata_to_pyton_file(self, file: TextIOWrapper):
    """
    Write the metadata of the RDR classifier (name, case type and case name) to a generated python file.

    The values are written as python assignments (``name = ...``, ``case_type = ...``, ``case_name = ...``),
    the attributes that ``update_rdr_metadata_from_python`` reads back. String values are written with
    ``repr`` so that names containing quotes or backslashes still produce valid python source.

    Note: the method name keeps its historical spelling ("pyton") because callers reference it by that name.

    :param file: The text file (or file-like object) to write the metadata to.
    """
    # repr() quotes and escapes the value; str() keeps the previous behaviour of always writing a string.
    file.write(f"name = {str(self.name)!r}\n")
    file.write(f"case_type = {self.case_type.__name__ if self.case_type is not None else None}\n")
    file.write(f"case_name = {str(self.case_name)!r}\n")
|
112
|
+
|
113
|
+
def update_rdr_metadata_from_python(self, module: ModuleType):
    """
    Refresh ``name``, ``case_type`` and ``case_name`` from the module holding a generated classifier.

    When the module defines no ``name``, the start rule's ``conclusion_name`` is used. When it defines no
    ``case_name``, one is built as ``<case type name>.<name>``. Any missing attribute (including a missing
    ``case_type``) is reported through a logged warning instead of being raised.

    :param module: The imported module that contains the RDR classifier function.
    """
    try:
        if hasattr(module, "name"):
            self.name = module.name
        else:
            self.name = self.start_rule.conclusion_name
        self.case_type = module.case_type
        if hasattr(module, "case_name"):
            self.case_name = module.case_name
        else:
            self.case_name = f"{self.case_type.__name__}.{self.name}"
    except AttributeError as e:
        logger.warning(f"Could not update the RDR metadata from the module {module.__name__}. "
                       f"Make sure the module has the required attributes: {e}")
|
126
|
+
|
99
127
|
def render_evaluated_rule_tree(self, filename: str, show_full_tree: bool = False) -> None:
|
100
128
|
if show_full_tree:
|
101
129
|
start_rule = self.start_rule if self.input_node is None else self.input_node
|
@@ -180,20 +208,42 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
180
208
|
:param package_name: The name of the package that contains the RDR classifier function, this
|
181
209
|
is required in case of relative imports in the generated python file.
|
182
210
|
"""
|
211
|
+
rdr: Optional[RippleDownRules] = None
|
183
212
|
model_dir = os.path.join(load_dir, model_name)
|
184
213
|
json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
|
185
|
-
|
186
|
-
|
187
|
-
rdr.model_name = model_name
|
214
|
+
if os.path.exists(json_file + ".json"):
|
215
|
+
rdr = cls.from_json_file(json_file)
|
188
216
|
try:
|
189
|
-
rdr
|
217
|
+
if rdr is None:
|
218
|
+
acronym = cls.get_acronym().lower()
|
219
|
+
python_file_name = get_file_that_ends_with(model_dir, f"_{acronym}.py")
|
220
|
+
python_file_path = os.path.join(model_dir, python_file_name)
|
221
|
+
rdr = cls.from_python(model_dir, parent_package_name=package_name, python_file_path=python_file_path)
|
222
|
+
else:
|
223
|
+
rdr.update_from_python(model_dir, package_name=package_name)
|
190
224
|
rdr.to_json_file(json_file)
|
191
|
-
except (FileNotFoundError, ValueError, SyntaxError) as e:
|
225
|
+
except (FileNotFoundError, ValueError, SyntaxError, ModuleNotFoundError) as e:
|
192
226
|
logger.warning(f"Could not load the python file for the model {model_name} from {model_dir}. "
|
193
227
|
f"Make sure the file exists and is valid.")
|
194
228
|
rdr.save(save_dir=load_dir, model_name=model_name, package_name=package_name)
|
229
|
+
rdr.save_dir = load_dir
|
230
|
+
rdr.model_name = model_name
|
195
231
|
return rdr
|
196
232
|
|
233
|
+
@classmethod
@abstractmethod
def from_python(cls, model_dir: str, python_file_path: Optional[str] = None,
                parent_package_name: Optional[str] = None) -> Self:
    """
    Build a classifier instance from previously generated python source.

    Concrete subclasses must implement this.

    :param model_dir: Directory that holds the generated python file(s).
    :param python_file_path: Explicit path of the generated python file to read; how the file is located when
     this is None is up to the implementation.
    :param parent_package_name: The package containing the RDR classifier function; needed when the generated
     file uses relative imports.
    :return: The reconstructed classifier.
    """
    ...
|
246
|
+
|
197
247
|
@abstractmethod
|
198
248
|
def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
|
199
249
|
"""
|
@@ -424,6 +474,58 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
424
474
|
def type_(self):
|
425
475
|
return self.__class__
|
426
476
|
|
477
|
+
@classmethod
def get_json_file_path(cls, model_path: str) -> str:
    """
    Return where the model's JSON serialization is stored.

    The file is ``<model_path>/<metadata_folder>/<model name>.json``. The model name is the one returned by
    ``get_model_name_from_model_path``.

    :param model_path: The path to the model directory.
    :return: The path to the saved JSON file.
    """
    json_name = f"{cls.get_model_name_from_model_path(model_path)}.json"
    return os.path.join(model_path, cls.metadata_folder, json_name)
|
487
|
+
|
488
|
+
@classmethod
def get_generated_cases_file_path(cls, model_path: str) -> str:
    """
    Get the path to the python file that contains the RDR classifier cases.

    This is the generated classifier file with ``_cases`` inserted before its extension. Only the file
    extension is rewritten, so a ``.py`` that appears elsewhere in the path (e.g. in a directory name) is kept.

    :param model_path: The path to the model directory.
    :return: The path to the generated cases python file.
    """
    root, _ = os.path.splitext(cls.get_generated_python_file_path(model_path))
    return f"{root}_cases.py"
|
497
|
+
|
498
|
+
@classmethod
def get_generated_defs_file_path(cls, model_path: str) -> str:
    """
    Get the path to the python file that contains the RDR classifier function definitions.

    This is the generated classifier file with ``_defs`` inserted before its extension. Only the file
    extension is rewritten, so a ``.py`` that appears elsewhere in the path (e.g. in a directory name) is kept.

    :param model_path: The path to the model directory.
    :return: The path to the generated definitions python file.
    """
    root, _ = os.path.splitext(cls.get_generated_python_file_path(model_path))
    return f"{root}_defs.py"
|
507
|
+
|
508
|
+
@classmethod
def get_generated_python_file_path(cls, model_path: str) -> str:
    """
    Return the path of the generated python file holding the RDR classifier function.

    The file is ``<model_path>/<model name>.py``. The model name is the one returned by
    ``get_model_name_from_model_path``.

    :param model_path: The path to the model directory.
    :return: The path to the generated python file.
    """
    file_name = cls.get_model_name_from_model_path(model_path) + ".py"
    return os.path.join(model_path, file_name)
|
518
|
+
|
519
|
+
@classmethod
def get_model_name_from_model_path(cls, model_path: str) -> str:
    """
    Derive the model name from a model directory path.

    The name is the final path component (``pathlib.Path.name``), so a trailing separator is ignored.

    :param model_path: The path to the model directory.
    :return: The name of the model.
    """
    model_dir = Path(model_path)
    return model_dir.name
|
528
|
+
|
427
529
|
@property
|
428
530
|
def generated_python_file_name(self) -> str:
|
429
531
|
if self._generated_python_file_name is None:
|
@@ -482,66 +584,345 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
482
584
|
return module.classify
|
483
585
|
|
484
586
|
|
587
|
+
class TreeBuilder(ast.NodeVisitor, ABC):
    """
    Parses the AST of a generated ``classify`` function made of nested if/elif/else statements and rebuilds
    the rule tree it encodes.

    Every ``if``/``elif`` whose test calls a function with ``conditions_`` in its name becomes a rule, and the
    text after ``conditions_`` is used as the rule uid. Subclasses decide the rule type, the edge used to
    attach new rules, and how ``else`` branches are handled.
    """

    def __init__(self):
        self.root: Optional[Rule] = None  # first rule created while no parent was set
        self.current_parent: Optional[Rule] = None  # rule that newly created rules get attached to
        self.current_edge: Optional[RDREdge] = None  # how the next rule is attached to current_parent
        self.default_conclusion: Optional[str] = None  # value of a return reached with no current parent

    def visit_FunctionDef(self, node):
        """Visit every top-level statement of the function body."""
        for stmt in node.body:
            self.visit(stmt)

    def visit_If(self, node):
        """
        Create a rule for an ``if``/``elif`` whose test calls a ``conditions_*`` function, then recurse into its
        body (refinement side) and its ``elif``/``else`` branches (alternative side).

        ``if`` statements whose test is not such a call, e.g. ``if not isinstance(case, Case)`` or
        ``if isinstance(case, X)``, are not rules and are skipped together with their bodies.
        """
        condition = self.get_condition_name(node.test)
        # Calls to other functions are not rules and carry no uid to extract; splitting their name on
        # "conditions_" would raise an IndexError, so skip them like any other non-rule test.
        if condition is None or "conditions_" not in condition:
            return
        rule_uid = condition.split("conditions_")[1]

        new_rule_type = self.get_new_rule_type(node)
        new_node = new_rule_type(conditions=condition, parent=self.current_parent, uid=rule_uid)
        if self.current_parent is not None:
            self.update_current_parent(new_node)

        if self.current_parent is None and self.root is None:
            self.root = new_node

        self.current_parent = new_node

        # Parse the body of the if statement.
        for stmt in node.body:
            # NOTE: the edge is computed before current_parent is reset, so on later iterations it sees
            # whatever parent the previous statement's visit left behind.
            self.current_edge = self.get_refinement_edge(node)
            self.current_parent = new_node
            self.visit(stmt)

        # Parse elif/else.
        for stmt in node.orelse:
            self.current_edge = self.get_alternative_edge(node)
            self.current_parent = new_node
            if isinstance(stmt, ast.If):  # elif case
                self.visit_If(stmt)
            else:  # else case
                self.process_else_statement(stmt)
            self.current_parent = new_node
            self.current_edge = None

    @abstractmethod
    def process_else_statement(self, stmt: ast.AST):
        """
        Process the else statement in the if-elif-else block.

        :param stmt: The else statement to process.
        """
        pass

    @abstractmethod
    def get_refinement_edge(self, node: ast.AST) -> RDREdge:
        """
        :param node: The current AST node to determine the edge type from.
        :return: The refinement edge type.
        """
        pass

    @abstractmethod
    def get_alternative_edge(self, node: ast.AST) -> RDREdge:
        """
        :param node: The current AST node to determine the alternative edge type from.
        :return: The alternative edge type.
        """
        pass

    @abstractmethod
    def get_new_rule_type(self, node: ast.AST) -> Type[Rule]:
        """
        Get the new rule type to create.

        :param node: The current AST node to determine the rule type from.
        :return: The new rule type.
        """
        pass

    @abstractmethod
    def update_current_parent(self, new_node: Rule):
        """
        Attach the new node to the current parent rule.

        :param new_node: The newly created rule.
        """
        pass

    def visit_Return(self, node):
        """
        Record the value of a return statement as a conclusion.

        How the value is recorded:
        - A returned call is recorded as the source text of the called function (e.g. ``"conclusion_x"``).
        - Other values are evaluated with ``ast.literal_eval``.
        - Non-literal values (e.g. ``SomeEnum.A``) are kept as their source text.
        - A bare ``return`` is recorded as ``None``.

        With no current parent, the value becomes ``default_conclusion``. Otherwise it becomes the current
        parent's ``conclusion``.
        """
        value = node.value
        if value is None:
            return_value = None
        elif isinstance(value, ast.Call):
            # unparse also handles attribute calls such as ``module.conclusion_x(case)``; for a plain name it is
            # identical to ``func.id``.
            return_value = ast.unparse(value.func)
        else:
            try:
                return_value = ast.literal_eval(value)
            except ValueError:
                return_value = ast.unparse(value)
        if self.current_parent is None:
            self.default_conclusion = return_value
        else:
            self.current_parent.conclusion = return_value

    def get_condition_name(self, node):
        """Return the called function's name if ``node`` is a call to a plain name, otherwise None."""
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            return node.func.id
        return None
|
693
|
+
|
694
|
+
class SingleClassTreeBuilder(TreeBuilder):
    """
    Rebuilds the rule tree of a generated SingleClassRDR ``classify`` function.

    How the generated code maps onto the tree:
    - Every rule is a ``SingleClassRule``.
    - A rule found in an ``if`` body is attached as the refinement of the enclosing rule.
    - An ``elif`` is attached as the enclosing rule's alternative.
    - The final ``else: return ...`` provides the default conclusion.
    """

    def get_new_rule_type(self, node: ast.AST) -> Type[Rule]:
        """Single class classifiers only use one rule type."""
        return SingleClassRule

    def get_refinement_edge(self, node: ast.AST) -> RDREdge:
        """Rules nested in an ``if`` body refine their parent."""
        return RDREdge.Refinement

    def get_alternative_edge(self, node: ast.AST) -> RDREdge:
        """Rules in an ``elif`` are alternatives of their parent."""
        return RDREdge.Alternative

    def update_current_parent(self, new_node: Rule):
        """Hang ``new_node`` under the current parent on the side selected by the current edge."""
        if self.current_edge == RDREdge.Refinement:
            self.current_parent.refinement = new_node
        elif self.current_edge == RDREdge.Alternative:
            self.current_parent.alternative = new_node

    def process_else_statement(self, stmt: ast.AST):
        """The only statement accepted in the final ``else`` is ``return <default conclusion>``."""
        if not isinstance(stmt, ast.Return):
            raise ValueError(f"Unexpected statement in else block: {stmt}")
        # With no current parent, visit_Return stores the value as the default conclusion.
        self.current_parent = None
        self.visit_Return(stmt)
|
719
|
+
|
720
|
+
|
721
|
+
class MultiClassTreeBuilder(TreeBuilder):
    """
    Rebuilds the rule tree of a generated MultiClassRDR ``classify`` function.

    How the generated code maps onto the tree:
    - Top level ``if`` statements become ``MultiClassTopRule`` objects chained through ``Next`` edges.
    - A nested ``if`` whose body is only ``pass`` becomes a ``MultiClassStopRule``.
    - Other nested ``if`` statements become ``MultiClassFilterRule`` objects.
    - Conclusions are not read from return statements. They are named after the conditions function, with
      ``conditions_`` replaced by ``conclusion_``.
    """

    def visit_If(self, stmt: ast.If):
        super().visit_If(stmt)
        parent = self.current_parent
        if isinstance(parent, (MultiClassTopRule, MultiClassFilterRule)):
            parent.conclusion = parent.conditions.replace("conditions_", "conclusion_")

    def visit_Return(self, node):
        """Return statements carry no rule information in multi class classifiers."""
        pass

    def get_new_rule_type(self, node: ast.AST) -> Type[Rule]:
        """Choose the rule class from the edge the new rule will be attached with."""
        edge = self.current_edge
        if edge in (RDREdge.Next, None):
            return MultiClassTopRule
        if edge == RDREdge.Refinement:
            return MultiClassStopRule
        if edge == RDREdge.Filter:
            return MultiClassFilterRule
        if edge == RDREdge.Alternative:
            return self.get_refinement_rule_type(node)
        raise ValueError(f"Unknown edge type: {self.current_edge}")

    def get_alternative_edge(self, node: ast.AST) -> RDREdge:
        """Top rules are chained with ``Next``; every other rule uses ``Alternative``."""
        return RDREdge.Next if isinstance(self.current_parent, MultiClassTopRule) else RDREdge.Alternative

    def get_refinement_edge(self, node: ast.AST) -> RDREdge:
        """Derive the refinement edge from the kind of rule the ``if`` body encodes."""
        return self.get_refinement_edge_from_refinement_rule(self.get_refinement_rule_type(node))

    def get_refinement_edge_from_refinement_rule(self, rule_type: Type[Rule]) -> RDREdge:
        """
        :param rule_type: The type of the rule to determine the refinement edge from.
        :return: ``Alternative`` under a refinement rule, otherwise ``Refinement`` for stop rules and
         ``Filter`` for everything else.
        """
        if isinstance(self.current_parent, MultiClassRefinementRule):
            return RDREdge.Alternative
        return RDREdge.Refinement if rule_type == MultiClassStopRule else RDREdge.Filter

    def get_refinement_rule_type(self, node: ast.AST) -> Type[Rule]:
        """
        Inspect the first statement of ``node``'s body to decide the refinement rule type.

        The cases are:
        - A body made of a single ``pass`` denotes a stop rule.
        - A nested ``if`` is inspected recursively.
        - Anything else denotes a filter rule.

        :param node: The current AST node to determine the rule type from.
        :return: The rule type based on the node body.
        :raises ValueError: If the node body is empty.
        """
        body = node.body
        if not body:
            raise ValueError(f"Could not determine the refinement rule type from the node: {node} as it has an empty body.")
        first = body[0]
        if isinstance(first, ast.Pass) and len(body) == 1:
            return MultiClassStopRule
        if isinstance(first, ast.If):
            return self.get_refinement_rule_type(first)
        return MultiClassFilterRule

    def update_current_parent(self, new_node: Rule):
        """Set the top rule of refinement rules, then attach ``new_node`` to the current parent."""
        if isinstance(new_node, MultiClassRefinementRule):
            parent = self.current_parent
            if isinstance(parent, MultiClassTopRule):
                new_node.top_rule = parent
            elif hasattr(parent, "top_rule"):
                new_node.top_rule = parent.top_rule
            else:
                raise ValueError(f"Could not set the top rule for the refinement rule: {new_node}")
        edge = self.current_edge
        if edge in (RDREdge.Alternative, RDREdge.Next, None):
            self.current_parent.alternative = new_node
        elif edge in (RDREdge.Refinement, RDREdge.Filter):
            self.current_parent.refinement = new_node

    def process_else_statement(self, stmt: ast.AST):
        """Else branches carry no rule information in multi class classifiers."""
        pass
|
796
|
+
|
485
797
|
class RDRWithCodeWriter(RippleDownRules, ABC):
|
486
798
|
|
487
|
-
|
799
|
+
@classmethod
def from_python(cls, model_path: str,
                python_file_path: Optional[str] = None,
                parent_package_name: Optional[str] = None) -> Self:
    """
    Rebuild an RDR classifier from its generated python file.

    The rule tree is first reconstructed from the ``classify`` function. Conditions, conclusions and metadata
    are then refreshed from the generated modules through ``update_from_python``.

    :param model_path: The directory where the generated python file is located.
    :param python_file_path: The path to the generated python file that contains the RDR classifier function.
    :param parent_package_name: The name of the package that contains the RDR classifier function, this
     is required in case of relative imports in the generated python file.
    :return: An instance of the RDR classifier.
    """
    rdr = cls(start_rule=cls.read_rule_tree_from_python(model_path, python_file_path=python_file_path))
    rdr.update_from_python(model_path, package_name=parent_package_name, python_file_path=python_file_path)
    return rdr
|
816
|
+
|
817
|
+
@classmethod
def read_rule_tree_from_python(cls, model_path: str, python_file_path: Optional[str] = None) -> Rule:
    """
    Rebuild the rule tree from the ``classify`` function of a generated python file.

    :param model_path: The model directory, used to locate the generated python file when
     ``python_file_path`` is None.
    :param python_file_path: The path to the generated python file that contains the RDR classifier function.
    :return: The ``root`` built by the classifier's tree builder (None when no ``classify`` function or no
     rule was found).
    :raises SyntaxError: If the file is not valid python; the error names the offending file.
    """
    if python_file_path is None:
        python_file_path = cls.get_generated_python_file_path(model_path)
    # Python source files are UTF-8 by default; don't depend on the platform's locale encoding.
    with open(python_file_path, "r", encoding="utf-8") as f:
        source_code = f.read()

    # Passing the filename makes syntax errors point at the generated file instead of "<unknown>".
    tree = ast.parse(source_code, filename=python_file_path)
    builder = cls.get_tree_builder_class()()

    # Find and process the classifier function.
    for node in tree.body:
        if isinstance(node, ast.FunctionDef) and node.name == "classify":
            builder.visit_FunctionDef(node)

    return builder.root
|
837
|
+
|
838
|
+
@classmethod
@abstractmethod
def get_tree_builder_class(cls) -> Type[TreeBuilder]:
    """
    Return the ``TreeBuilder`` subclass that rebuilds this classifier's rule tree from its generated python
    file, i.e. ``SingleClassTreeBuilder`` or ``MultiClassTreeBuilder``.

    Concrete subclasses must implement this.
    """
    ...
|
846
|
+
|
847
|
+
@property
def all_rules(self) -> List[Rule]:
    """
    Collect the rules of the classifier.

    The start rule and all of its descendants are included, except those whose ``conditions`` is None.

    :return: The matching rules, start rule first; an empty list when there is no start rule.
    """
    root = self.start_rule
    if root is None:
        return []
    return [rule for rule in (root, *root.descendants) if rule.conditions is not None]
|
857
|
+
|
858
|
+
def update_from_python(self, model_dir: str, package_name: Optional[str] = None,
                       python_file_path: Optional[str] = None):
    """
    Update the rules from the generated python file, that might have been modified by the user.

    The generated modules are (re-)imported: the main file, its ``_defs`` file and, when present, its
    ``_cases`` file. Every rule's conditions, conclusion and corner case metadata are then rebuilt from them.

    :param model_dir: The directory where the generated python file is located.
    :param package_name: The name of the package that contains the RDR classifier function, this
     is required in case of relative imports in the generated python file.
    :param python_file_path: The path to the generated python file that contains the RDR classifier function.
     When None, the file in ``model_dir`` whose name ends with ``_<acronym>.py`` is used.
    :raises ModuleNotFoundError: If the main generated python file does not exist.
    """
    all_rules = self.all_rules
    condition_func_names = [rule.generated_conditions_function_name for rule in all_rules]
    conclusion_func_names = [rule.generated_conclusion_function_name for rule in all_rules
                             if not isinstance(rule, MultiClassStopRule)]
    all_func_names = condition_func_names + conclusion_func_names

    if python_file_path is None:
        main_file_name = get_file_that_ends_with(model_dir, f"_{self.get_acronym().lower()}.py")
        main_file_path = os.path.join(model_dir, main_file_name)
    else:
        main_file_path = python_file_path
    if not os.path.exists(main_file_path):
        raise ModuleNotFoundError(main_file_path)

    # Derive the sibling file paths from the extension only: str.replace(".py", ...) would also rewrite a
    # ".py" occurring in a directory name (e.g. "my.pyprojects/model_scrdr.py").
    main_path_root, _ = os.path.splitext(main_file_path)
    main_module_name = os.path.basename(main_path_root)
    self.generated_python_file_name = main_module_name
    defs_file_path = f"{main_path_root}_defs.py"
    cases_path = f"{main_path_root}_cases.py"

    model_import_path = get_import_path_from_path(model_dir)

    def import_fresh(module_file_name: str) -> ModuleType:
        """Import a generated module from ``model_dir`` and reload it so that on-disk edits are picked up."""
        import_path = f"{model_import_path}.{module_file_name}" if model_import_path else module_file_name
        module = importlib.import_module(import_path, package=package_name)
        importlib.reload(module)
        return module

    # Same import order as before: cases (optional), definitions, then the main module.
    cases_module = import_fresh(f"{main_module_name}_cases") if os.path.exists(cases_path) else None
    defs_module = import_fresh(f"{main_module_name}_defs")
    main_module = import_fresh(main_module_name)

    self.start_rule.conclusion_name = main_module.attribute_name
    self.update_rdr_metadata_from_python(main_module)
    functions_source = extract_function_source(defs_file_path, all_func_names, include_signature=False)
    scope = extract_imports(defs_file_path, package_name=package_name)
    for rule in all_rules:
        if rule.conditions is not None:
            conditions_source = functions_source[rule.generated_conditions_function_name]
            rule.conditions = CallableExpression(conditions_source, (bool,), scope=scope)
        # Use the module imported above instead of re-checking the file on every rule.
        if cases_module is not None:
            rule.corner_case_metadata = cases_module.__dict__.get(rule.generated_corner_case_object_name, None)
        if not isinstance(rule, MultiClassStopRule):
            conclusion_func_name = rule.generated_conclusion_function_name
            conclusion_source = functions_source[conclusion_func_name]
            conclusion_func = defs_module.__dict__.get(conclusion_func_name)
            conclusion_type = get_function_return_type(conclusion_func)
            rule.conclusion = CallableExpression(conclusion_source, conclusion_type, scope=scope,
                                                 mutually_exclusive=self.mutually_exclusive)
|
545
926
|
|
546
927
|
@abstractmethod
|
547
928
|
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
|
@@ -604,6 +985,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
604
985
|
f.write(f"attribute_name = '{self.attribute_name}'\n")
|
605
986
|
f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
|
606
987
|
f.write(f"mutually_exclusive = {self.mutually_exclusive}\n")
|
988
|
+
self.write_rdr_metadata_to_pyton_file(f)
|
607
989
|
f.write(f"\n\n{func_def}")
|
608
990
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
609
991
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
@@ -669,9 +1051,8 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
669
1051
|
:return: The type of the conclusion of the RDR classifier.
|
670
1052
|
"""
|
671
1053
|
all_types = []
|
672
|
-
|
673
|
-
|
674
|
-
all_types.extend(list(rule.conclusion.conclusion_type))
|
1054
|
+
for rule in self.all_rules:
|
1055
|
+
all_types.extend(list(rule.conclusion.conclusion_type))
|
675
1056
|
return tuple(set(all_types))
|
676
1057
|
|
677
1058
|
@property
|
@@ -728,6 +1109,10 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
728
1109
|
super(SingleClassRDR, self).__init__(**kwargs)
|
729
1110
|
self.default_conclusion: Optional[Any] = default_conclusion
|
730
1111
|
|
1112
|
+
@classmethod
def get_tree_builder_class(cls) -> Type[TreeBuilder]:
    """
    :return: The tree builder that rebuilds a single class classifier's rule tree from its generated file.
    """
    builder_class = SingleClassTreeBuilder
    return builder_class
|
1115
|
+
|
731
1116
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
732
1117
|
-> Union[CaseAttribute, CallableExpression, None]:
|
733
1118
|
"""
|
@@ -879,6 +1264,10 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
879
1264
|
super(MultiClassRDR, self).__init__(start_rule, **kwargs)
|
880
1265
|
self.mode: MCRDRMode = mode
|
881
1266
|
|
1267
|
+
@classmethod
def get_tree_builder_class(cls) -> Type[TreeBuilder]:
    """
    :return: The tree builder that rebuilds a multi class classifier's rule tree from its generated file.
    """
    builder_class = MultiClassTreeBuilder
    return builder_class
|
1270
|
+
|
882
1271
|
def _classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
|
883
1272
|
case_query: Optional[CaseQuery] = None) -> Set[Any]:
|
884
1273
|
evaluated_rule = self.start_rule
|
@@ -1128,6 +1517,58 @@ class GeneralRDR(RippleDownRules):
|
|
1128
1517
|
super(GeneralRDR, self).__init__(**kwargs)
|
1129
1518
|
self.all_figs: List[Figure] = [sr.fig for sr in self.start_rules_dict.values()]
|
1130
1519
|
|
1520
|
+
@classmethod
def from_python(cls, model_dir: str, python_file_path: Optional[str] = None,
                parent_package_name: Optional[str] = None) -> Self:
    """
    Create an instance of the class from a python file.

    The main generated module is imported, and its ``classifiers_dict`` is used to locate each
    sub-classifier. Each sub-classifier is then loaded from its own generated file.

    :param model_dir: The path to the directory containing the python file.
    :param python_file_path: The path to the main python file. When None, the file in ``model_dir`` whose name
     ends with ``_<acronym>.py`` is used.
    :param parent_package_name: The name of the package that contains the RDR classifier function, this
     is required in case of relative imports in the generated python file.
    :return: An instance of the class.
    """
    if python_file_path is None:
        file_name = get_file_that_ends_with(model_dir, f"_{cls.get_acronym().lower()}.py")
        main_python_file_path = os.path.join(model_dir, file_name)
    else:
        main_python_file_path = python_file_path
    main_python_file_name = Path(main_python_file_path).stem
    model_import_path = get_import_path_from_path(model_dir)
    main_module_import_path = f"{model_import_path}.{main_python_file_name}" if model_import_path \
        else main_python_file_name
    # Pass the package like the other generated-module imports do; importlib only uses it for relative paths.
    main_module = importlib.import_module(main_module_import_path, package=parent_package_name)
    importlib.reload(main_module)

    start_rules_dict = {}
    for rdr_name, rdr_module in main_module.classifiers_dict.items():
        # The acronym is the last "_"-separated part of the classifier module's name.
        rdr_acronym = rdr_module.__name__.split('_')[-1]
        rdr_type = cls.get_rdr_type_from_acronym(rdr_acronym)
        # Swap only the last "_rdr.py" for the sub-classifier suffix; str.replace would also rewrite earlier
        # occurrences (e.g. inside a directory name). The path is left unchanged when it is absent.
        head, sep, tail = main_python_file_path.rpartition('_rdr.py')
        rdr_model_path = f"{head}_{rdr_name}_{rdr_acronym}.py{tail}" if sep else main_python_file_path
        rdr = rdr_type.from_python(model_dir, python_file_path=rdr_model_path,
                                   parent_package_name=parent_package_name)
        start_rules_dict[rdr_name] = rdr
    grdr = cls(category_rdr_map=start_rules_dict)
    grdr.update_rdr_metadata_from_python(main_module)
    return grdr
|
1555
|
+
|
1556
|
+
@classmethod
|
1557
|
+
def get_rdr_type_from_acronym(cls, acronym: str) -> Type[Union[SingleClassRDR, MultiClassRDR]]:
|
1558
|
+
"""
|
1559
|
+
Get the type of the ripple down rules classifier from the acronym.
|
1560
|
+
|
1561
|
+
:param acronym: The acronym of the ripple down rules classifier.
|
1562
|
+
:return: The type of the ripple down rules classifier.
|
1563
|
+
"""
|
1564
|
+
acronym = acronym.lower()
|
1565
|
+
if acronym == "scrdr":
|
1566
|
+
return SingleClassRDR
|
1567
|
+
elif acronym == "mcrdr":
|
1568
|
+
return MultiClassRDR
|
1569
|
+
else:
|
1570
|
+
raise ValueError(f"Unknown RDR type acronym: {acronym}")
|
1571
|
+
|
1131
1572
|
def add_rdr(self, rdr: Union[SingleClassRDR, MultiClassRDR], case_query: Optional[CaseQuery] = None):
|
1132
1573
|
"""
|
1133
1574
|
Add a ripple down rules classifier to the map of classifiers.
|
@@ -1249,7 +1690,7 @@ class GeneralRDR(RippleDownRules):
|
|
1249
1690
|
Write the tree of rules as source code to a file.
|
1250
1691
|
|
1251
1692
|
:param model_dir: The directory where the model is stored.
|
1252
|
-
:param
|
1693
|
+
:param package_name: The name of the package that contains the RDR classifier function.
|
1253
1694
|
"""
|
1254
1695
|
for rdr in self.start_rules_dict.values():
|
1255
1696
|
rdr._write_to_python(model_dir, package_name=package_name)
|
@@ -1257,6 +1698,7 @@ class GeneralRDR(RippleDownRules):
|
|
1257
1698
|
file_path = model_dir + f"/{self.generated_python_file_name}.py"
|
1258
1699
|
with open(file_path, "w") as f:
|
1259
1700
|
f.write(self._get_imports(file_path=file_path, package_name=package_name) + "\n\n")
|
1701
|
+
self.write_rdr_metadata_to_pyton_file(f)
|
1260
1702
|
f.write("classifiers_dict = dict()\n")
|
1261
1703
|
for rdr_key, rdr in self.start_rules_dict.items():
|
1262
1704
|
f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
|
ripple_down_rules/rules.py
CHANGED
@@ -214,6 +214,10 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
214
214
|
def generated_conditions_function_name(self) -> str:
|
215
215
|
return f"conditions_{self.uid}"
|
216
216
|
|
217
|
+
@property
|
218
|
+
def generated_corner_case_object_name(self) -> str:
|
219
|
+
return f"corner_case_{self.uid}"
|
220
|
+
|
217
221
|
def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
|
218
222
|
"""
|
219
223
|
Convert the conclusion of a rule to source code.
|
ripple_down_rules/utils.py
CHANGED
@@ -715,6 +715,42 @@ origin_type_to_hint = {
|
|
715
715
|
}
|
716
716
|
|
717
717
|
|
718
|
+
def get_file_that_ends_with(directory_path: str, suffix: str) -> Optional[str]:
|
719
|
+
"""
|
720
|
+
Get the file that ends with the given suffix in the model directory.
|
721
|
+
|
722
|
+
:param directory_path: The path to the directory where the file is located.
|
723
|
+
:param suffix: The suffix to search for.
|
724
|
+
:return: The path to the file that ends with the given suffix, or None if not found.
|
725
|
+
"""
|
726
|
+
files = [f for f in os.listdir(directory_path) if f.endswith(suffix)]
|
727
|
+
if files:
|
728
|
+
return files[0]
|
729
|
+
return None
|
730
|
+
|
731
|
+
def get_function_return_type(func: Callable) -> Union[Type, None, Tuple[Type, ...]]:
|
732
|
+
"""
|
733
|
+
Get the return type of a function.
|
734
|
+
|
735
|
+
:param func: The function to get the return type for.
|
736
|
+
:return: The return type of the function, or None if not specified.
|
737
|
+
"""
|
738
|
+
sig = inspect.signature(func)
|
739
|
+
if sig.return_annotation == inspect.Signature.empty:
|
740
|
+
return None
|
741
|
+
type_hint = sig.return_annotation
|
742
|
+
origin = get_origin(type_hint)
|
743
|
+
args = get_args(type_hint)
|
744
|
+
if origin not in [list, set, None, Union]:
|
745
|
+
raise TypeError(f"{origin} is not a handled return type for function {func.__name__}")
|
746
|
+
if origin is None:
|
747
|
+
return typing_to_python_type(type_hint)
|
748
|
+
if args is None or len(args) == 0:
|
749
|
+
return typing_to_python_type(type_hint)
|
750
|
+
return args
|
751
|
+
|
752
|
+
|
753
|
+
|
718
754
|
def extract_types(tp, seen: Set = None) -> Set[type]:
|
719
755
|
"""Recursively extract all base types from a type hint."""
|
720
756
|
if seen is None:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.6.
|
3
|
+
Version: 0.6.41
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -1,13 +1,13 @@
|
|
1
|
-
ripple_down_rules/__init__.py,sha256=
|
1
|
+
ripple_down_rules/__init__.py,sha256=nYtH-We95fPZmgC7xF4K09TKVWFy7gOrO7Hv4zidPiI,99
|
2
2
|
ripple_down_rules/experts.py,sha256=KXwWCmDrCffu9HW3yNewqWc1e5rnPI5Rnc981w_5M7U,17896
|
3
3
|
ripple_down_rules/helpers.py,sha256=X1psHOqrb4_xYN4ssQNB8S9aRKKsqgihAyWJurN0dqk,5499
|
4
|
-
ripple_down_rules/rdr.py,sha256=
|
4
|
+
ripple_down_rules/rdr.py,sha256=s9y3ImomYOw3WKHggljfIaNFRxid3qtM0h-yso9oOdk,81327
|
5
5
|
ripple_down_rules/rdr_decorators.py,sha256=xoBGsIJMkJYUdsrsEaPZqoAsGuXkuVZAKCoP-xD2Iv8,11668
|
6
|
-
ripple_down_rules/rules.py,sha256=
|
6
|
+
ripple_down_rules/rules.py,sha256=32apFTxtWXKRQ2MJDnqc1URjRJDnNBe_t5A_gGfKNd0,29349
|
7
7
|
ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
|
8
|
-
ripple_down_rules/utils.py,sha256=
|
8
|
+
ripple_down_rules/utils.py,sha256=SuZKBnGZvfDqlry7PY_N8AzLBOhEeuev57AVS-fJy78,77992
|
9
9
|
ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
|
10
|
-
ripple_down_rules/datastructures/callable_expression.py,sha256=
|
10
|
+
ripple_down_rules/datastructures/callable_expression.py,sha256=rzMrpD5oztaCRlt3hQ2B_xZ09cSuJNkYOCePndfQJRA,13684
|
11
11
|
ripple_down_rules/datastructures/case.py,sha256=dfLnrjsHIVF2bgbz-4ID7OdQvw68V71btCeTK372P-g,15667
|
12
12
|
ripple_down_rules/datastructures/dataclasses.py,sha256=3vX52WrAHgVyw0LUSgSBOVFaQNTSxU8hQpdr7cW-tSg,13278
|
13
13
|
ripple_down_rules/datastructures/enums.py,sha256=CvcROl8fE7A6uTbMfs2lLpyxwS_ZFtFcQlBDDKFfoHc,6059
|
@@ -17,8 +17,8 @@ ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=RLdPqPxx-a0Sh74U
|
|
17
17
|
ripple_down_rules/user_interface/object_diagram.py,sha256=FEa2HaYR9QmTE6NsOwBvZ0jqmu3DKyg6mig2VE5ZP4Y,4956
|
18
18
|
ripple_down_rules/user_interface/prompt.py,sha256=WPbw_8_-8SpF2ISyRZRuFwPKBEuGC4HaX3lbCPFHhh8,10314
|
19
19
|
ripple_down_rules/user_interface/template_file_creator.py,sha256=uSbosZS15MOR3Nv7M3MrFuoiKXyP4cBId-EK3I6stHM,13660
|
20
|
-
ripple_down_rules-0.6.
|
21
|
-
ripple_down_rules-0.6.
|
22
|
-
ripple_down_rules-0.6.
|
23
|
-
ripple_down_rules-0.6.
|
24
|
-
ripple_down_rules-0.6.
|
20
|
+
ripple_down_rules-0.6.41.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
21
|
+
ripple_down_rules-0.6.41.dist-info/METADATA,sha256=DkE-Aey_rdGNqR0UzZM4uT7BMaCcSB0vobs2I1sYYRo,48294
|
22
|
+
ripple_down_rules-0.6.41.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
23
|
+
ripple_down_rules-0.6.41.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
24
|
+
ripple_down_rules-0.6.41.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|