ripple-down-rules 0.6.31__py3-none-any.whl → 0.6.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
__version__ = "0.6.41"

import logging

# getLogger registers the logger in logging's hierarchy, so its records propagate to handlers configured on the
# root logger (e.g. via logging.basicConfig) and every caller asking for "rdr" gets this same instance.
# Instantiating logging.Logger directly would create a detached logger (no parent, not shared).
logger = logging.getLogger("rdr")
@@ -129,6 +129,9 @@ class CallableExpression(SubclassJSONSerializer):
129
129
  else:
130
130
  conclusion_type = (conclusion_type,)
131
131
  self.conclusion_type = conclusion_type
132
+ self.expected_types: Set[Type] = set(conclusion_type) if conclusion_type is not None else set()
133
+ if not mutually_exclusive:
134
+ self.expected_types.update({list, set})
132
135
  self.scope: Optional[Dict[str, Any]] = scope if scope is not None else {}
133
136
  self.scope = get_used_scope(self.user_input, self.scope)
134
137
  self.expression_tree: AST = expression_tree if expression_tree else parse_string_to_expression(self.user_input)
@@ -158,7 +161,7 @@ class CallableExpression(SubclassJSONSerializer):
158
161
  raise ValueError(f"Mutually exclusive types cannot be lists or sets, got {type(output)}")
159
162
  output_types = {type(o) for o in make_list(output)}
160
163
  output_types.add(type(output))
161
- if not are_results_subclass_of_types(output_types, self.conclusion_type):
164
+ if not are_results_subclass_of_types(output_types, self.expected_types):
162
165
  raise ValueError(f"Not all result types {output_types} are subclasses of expected types"
163
166
  f" {self.conclusion_type}")
164
167
  return output
ripple_down_rules/rdr.py CHANGED
@@ -1,12 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import ast
3
4
  import importlib
4
5
  import json
5
6
  import os
6
7
  from abc import ABC, abstractmethod
7
8
  from copy import copy
8
9
  from dataclasses import is_dataclass
9
- from types import NoneType
10
+ from io import TextIOWrapper
11
+ from pathlib import Path
12
+ from types import NoneType, ModuleType
10
13
 
11
14
  from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
12
15
  from . import logger
@@ -26,7 +29,7 @@ from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tupl
26
29
  from .datastructures.callable_expression import CallableExpression
27
30
  from .datastructures.case import Case, CaseAttribute, create_case
28
31
  from .datastructures.dataclasses import CaseQuery
29
- from .datastructures.enums import MCRDRMode
32
+ from .datastructures.enums import MCRDRMode, RDREdge
30
33
  from .experts import Expert, Human
31
34
  from .helpers import is_matching, general_rdr_classify, get_an_updated_case_copy
32
35
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule, MultiClassRefinementRule, \
@@ -38,7 +41,8 @@ except ImportError as e:
38
41
  RDRCaseViewer = None
39
42
  from .utils import draw_tree, make_set, SubclassJSONSerializer, make_list, get_type_from_string, \
40
43
  is_value_conflicting, extract_function_source, extract_imports, get_full_class_name, \
41
- is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types, render_tree
44
+ is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types, render_tree, \
45
+ get_types_to_import_from_func_type_hints, get_function_return_type, get_file_that_ends_with
42
46
 
43
47
 
44
48
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -96,6 +100,30 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
96
100
  if RDRCaseViewer and any(RDRCaseViewer.instances) else None
97
101
  self.input_node: Optional[Rule] = None
98
102
 
103
def write_rdr_metadata_to_pyton_file(self, file: TextIOWrapper):
    """
    Write the metadata of the RDR classifier to a python file.

    Three module-level assignments are emitted: ``name``, ``case_type`` and ``case_name``. The string values are
    written with ``repr`` so that names containing quotes or backslashes still produce valid python source.
    ``case_type`` is written as a bare class name (or ``None``), so that name must be in scope in the generated
    file for it to be importable.

    :param file: The file to write the metadata to.
    """
    # str() keeps the previous output for non-string values (e.g. a None name is still written as 'None').
    file.write(f"name = {str(self.name)!r}\n")
    file.write(f"case_type = {self.case_type.__name__ if self.case_type is not None else None}\n")
    file.write(f"case_name = {str(self.case_name)!r}\n")
112
+
113
def update_rdr_metadata_from_python(self, module: ModuleType):
    """
    Refresh ``name``, ``case_type`` and ``case_name`` from the module of a generated classifier.

    When the module does not define ``name``, the start rule's ``conclusion_name`` is used instead; when it does
    not define ``case_name``, ``"<case_type name>.<name>"`` is used. ``case_type`` has no fallback. An
    ``AttributeError`` raised along the way is reported as a warning, and attributes assigned before the failure
    keep their new values.

    :param module: The module that contains the RDR classifier function.
    """
    try:
        if hasattr(module, "name"):
            self.name = module.name
        else:
            self.name = self.start_rule.conclusion_name
        self.case_type = module.case_type
        if hasattr(module, "case_name"):
            self.case_name = module.case_name
        else:
            self.case_name = f"{self.case_type.__name__}.{self.name}"
    except AttributeError as error:
        logger.warning(f"Could not update the RDR metadata from the module {module.__name__}. "
                       f"Make sure the module has the required attributes: {error}")
126
+
99
127
  def render_evaluated_rule_tree(self, filename: str, show_full_tree: bool = False) -> None:
100
128
  if show_full_tree:
101
129
  start_rule = self.start_rule if self.input_node is None else self.input_node
@@ -180,20 +208,42 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
180
208
  :param package_name: The name of the package that contains the RDR classifier function, this
181
209
  is required in case of relative imports in the generated python file.
182
210
  """
211
+ rdr: Optional[RippleDownRules] = None
183
212
  model_dir = os.path.join(load_dir, model_name)
184
213
  json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
185
- rdr = cls.from_json_file(json_file)
186
- rdr.save_dir = load_dir
187
- rdr.model_name = model_name
214
+ if os.path.exists(json_file + ".json"):
215
+ rdr = cls.from_json_file(json_file)
188
216
  try:
189
- rdr.update_from_python(model_dir, package_name=package_name)
217
+ if rdr is None:
218
+ acronym = cls.get_acronym().lower()
219
+ python_file_name = get_file_that_ends_with(model_dir, f"_{acronym}.py")
220
+ python_file_path = os.path.join(model_dir, python_file_name)
221
+ rdr = cls.from_python(model_dir, parent_package_name=package_name, python_file_path=python_file_path)
222
+ else:
223
+ rdr.update_from_python(model_dir, package_name=package_name)
190
224
  rdr.to_json_file(json_file)
191
- except (FileNotFoundError, ValueError, SyntaxError) as e:
225
+ except (FileNotFoundError, ValueError, SyntaxError, ModuleNotFoundError) as e:
192
226
  logger.warning(f"Could not load the python file for the model {model_name} from {model_dir}. "
193
227
  f"Make sure the file exists and is valid.")
194
228
  rdr.save(save_dir=load_dir, model_name=model_name, package_name=package_name)
229
+ rdr.save_dir = load_dir
230
+ rdr.model_name = model_name
195
231
  return rdr
196
232
 
233
@classmethod
@abstractmethod
def from_python(cls, model_dir: str, python_file_path: Optional[str] = None,
                parent_package_name: Optional[str] = None) -> Self:
    """
    Build a classifier of this type from its generated python source; concrete subclasses implement this.

    :param model_dir: The directory that holds the generated python file(s).
    :param python_file_path: Optional explicit path of the python file to load the classifier from.
    :param parent_package_name: The package that contains the generated RDR classifier function; needed when
     the generated file uses relative imports.
    :return: The loaded classifier.
    """
    ...
246
+
197
247
  @abstractmethod
198
248
  def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
199
249
  """
@@ -424,6 +474,58 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
424
474
  def type_(self):
425
475
  return self.__class__
426
476
 
477
@classmethod
def get_json_file_path(cls, model_path: str) -> str:
    """
    Build the path of the JSON file that stores the serialized classifier.

    The file is ``<model_path>/<metadata_folder>/<model name>.json``, where the model name is the one returned by
    :meth:`get_model_name_from_model_path`.

    :param model_path: The path to the model directory.
    :return: The path to the saved model.
    """
    json_file_name = f"{cls.get_model_name_from_model_path(model_path)}.json"
    return os.path.join(model_path, cls.metadata_folder, json_file_name)
487
+
488
@classmethod
def get_generated_cases_file_path(cls, model_path: str) -> str:
    """
    Get the path to the python file that contains the RDR classifier cases.

    It is the generated classifier file with ``_cases`` appended to its stem (``m.py`` -> ``m_cases.py``).
    Only the final extension is rewritten, so directories whose names contain ``.py`` are left intact.

    :param model_path: The path to the model directory.
    :return: The path to the generated cases python file.
    """
    stem, extension = os.path.splitext(cls.get_generated_python_file_path(model_path))
    return f"{stem}_cases{extension}"
497
+
498
@classmethod
def get_generated_defs_file_path(cls, model_path: str) -> str:
    """
    Get the path to the python file that contains the RDR classifier function definitions.

    It is the generated classifier file with ``_defs`` appended to its stem (``m.py`` -> ``m_defs.py``).
    Only the final extension is rewritten, so directories whose names contain ``.py`` are left intact.

    :param model_path: The path to the model directory.
    :return: The path to the generated definitions python file.
    """
    stem, extension = os.path.splitext(cls.get_generated_python_file_path(model_path))
    return f"{stem}_defs{extension}"
507
+
508
@classmethod
def get_generated_python_file_path(cls, model_path: str) -> str:
    """
    Locate the generated python file that contains the RDR classifier function.

    The file lives directly inside ``model_path`` and is named after the model:
    ``<model_path>/<model name>.py``.

    :param model_path: The path to the model directory.
    :return: The path to the generated python file.
    """
    python_file_name = "{}.py".format(cls.get_model_name_from_model_path(model_path))
    return os.path.join(model_path, python_file_name)
518
+
519
@classmethod
def get_model_name_from_model_path(cls, model_path: str) -> str:
    """
    Derive the model name from the model directory path: its final path component.

    :param model_path: The path to the model directory.
    :return: The name of the model.
    """
    model_directory = Path(model_path)
    return model_directory.name
528
+
427
529
  @property
428
530
  def generated_python_file_name(self) -> str:
429
531
  if self._generated_python_file_name is None:
@@ -482,66 +584,345 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
482
584
  return module.classify
483
585
 
484
586
 
587
class TreeBuilder(ast.NodeVisitor, ABC):
    """
    Parse the AST of a generated ``classify`` function made of nested if/elif/else statements and rebuild the
    rule tree it encodes.

    Each ``if conditions_<uid>(...)`` statement becomes a rule with uid ``<uid>``. Statements in its body are
    visited with the edge from :meth:`get_refinement_edge`, its ``elif``/``else`` branches with the edge from
    :meth:`get_alternative_edge`; subclasses decide the rule types and how a new rule is attached to its parent.
    """

    def __init__(self):
        self.root: Optional[Rule] = None  # first rule created while there was no current parent
        self.current_parent: Optional[Rule] = None  # rule that the next created rule gets attached to
        self.current_edge: Optional[RDREdge] = None  # relation between current_parent and the next rule
        self.default_conclusion: Optional[str] = None  # value of a return visited with no current parent

    def visit_FunctionDef(self, node):
        """Visit every statement of the function body in order."""
        for stmt in node.body:
            self.visit(stmt)

    def visit_If(self, node):
        """
        Create a rule for an ``if`` whose test calls a ``conditions_<uid>`` function, then visit its body and its
        ``elif``/``else`` branches.

        An ``if`` whose test is not a plain function call (e.g. ``if not isinstance(case, Case):``), or whose
        called function name does not contain ``conditions_``, is skipped together with its body.
        """
        condition = self.get_condition_name(node.test)
        if condition is None:
            return
        name_parts = condition.split("conditions_")
        if len(name_parts) < 2:
            # Not a generated condition function (e.g. a helper call in hand-edited code): indexing [1] would
            # raise an IndexError, so ignore this statement instead.
            return
        rule_uid = name_parts[1]

        new_rule_type = self.get_new_rule_type(node)
        new_node = new_rule_type(conditions=condition, parent=self.current_parent, uid=rule_uid)
        if self.current_parent is not None:
            self.update_current_parent(new_node)

        if self.current_parent is None and self.root is None:
            self.root = new_node

        self.current_parent = new_node

        # The body of the if statement.
        for stmt in node.body:
            # NOTE(review): the edge is computed before current_parent is reset, so after a nested `if` it sees
            # that nested rule as current_parent; kept as-is since subclass edge logic may depend on it.
            self.current_edge = self.get_refinement_edge(node)
            self.current_parent = new_node
            self.visit(stmt)

        # The elif / else branches.
        for stmt in node.orelse:
            self.current_edge = self.get_alternative_edge(node)
            self.current_parent = new_node
            if isinstance(stmt, ast.If):  # elif
                self.visit_If(stmt)
            else:  # else
                self.process_else_statement(stmt)
        self.current_parent = new_node
        self.current_edge = None

    @abstractmethod
    def process_else_statement(self, stmt: ast.AST):
        """
        Process the else statement in the if-elif-else block.

        :param stmt: The else statement to process.
        """
        pass

    @abstractmethod
    def get_refinement_edge(self, node: ast.AST) -> RDREdge:
        """
        :param node: The current AST node to determine the edge type from.
        :return: The refinement edge type.
        """
        pass

    @abstractmethod
    def get_alternative_edge(self, node: ast.AST) -> RDREdge:
        """
        :param node: The current AST node to determine the alternative edge type from.
        :return: The alternative edge type.
        """
        pass

    @abstractmethod
    def get_new_rule_type(self, node: ast.AST) -> Type[Rule]:
        """
        Get the new rule type to create.

        :param node: The current AST node to determine the rule type from.
        :return: The new rule type.
        """
        pass

    @abstractmethod
    def update_current_parent(self, new_node: Rule):
        """
        Attach the new node to the current parent rule.

        :param new_node: The newly created rule.
        """
        pass

    def visit_Return(self, node):
        """
        Store the returned value as the conclusion of the current rule, or as ``default_conclusion`` when there
        is no current rule.

        A returned call of a plain function name (e.g. ``conclusion_<uid>(case)``) is stored as that name, a bare
        ``return`` as ``None``, and any other value must be a python literal (evaluated with ``ast.literal_eval``).
        """
        if node.value is None:
            # A bare `return` returns None; ast.literal_eval(None) would raise a ValueError.
            return_value = None
        elif isinstance(node.value, ast.Call):
            return_value = node.value.func.id
        else:
            return_value = ast.literal_eval(node.value)
        if self.current_parent is None:
            self.default_conclusion = return_value
        else:
            self.current_parent.conclusion = return_value

    def get_condition_name(self, node):
        """
        Extract the called function name from an ``if`` test.

        :return: The function name when ``node`` is a call of a plain name, otherwise ``None``.
        """
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            return node.func.id
        return None
693
+
694
class SingleClassTreeBuilder(TreeBuilder):
    """
    Rebuild the rule tree of a generated SingleClassRDR classifier from its AST.

    Every ``if`` becomes a :class:`SingleClassRule`; statements nested in an ``if`` body are attached as
    refinements and ``elif`` branches as alternatives. The final ``else`` must be a ``return`` whose value becomes
    the builder's default conclusion.
    """

    def get_new_rule_type(self, node: ast.AST) -> Type[Rule]:
        """:return: Always :class:`SingleClassRule`."""
        return SingleClassRule

    def get_refinement_edge(self, node: ast.AST) -> RDREdge:
        """:return: Always ``RDREdge.Refinement``."""
        return RDREdge.Refinement

    def get_alternative_edge(self, node: ast.AST) -> RDREdge:
        """:return: Always ``RDREdge.Alternative``."""
        return RDREdge.Alternative

    def update_current_parent(self, new_node: Rule):
        """Attach ``new_node`` to the current parent through the current edge; any other edge is a no-op."""
        parent = self.current_parent
        edge = self.current_edge
        if edge == RDREdge.Alternative:
            parent.alternative = new_node
        elif edge == RDREdge.Refinement:
            parent.refinement = new_node

    def process_else_statement(self, stmt: ast.AST):
        """
        Handle the ``else`` branch, which must be a ``return`` giving the default conclusion.

        :raises ValueError: If the else branch holds anything other than a return statement.
        """
        if not isinstance(stmt, ast.Return):
            raise ValueError(f"Unexpected statement in else block: {stmt}")
        # With no current parent, visit_Return stores the value as the default conclusion.
        self.current_parent = None
        self.visit_Return(stmt)
719
+
720
+
721
class MultiClassTreeBuilder(TreeBuilder):
    """
    Rebuild the rule tree of a generated MultiClassRDR classifier from its AST.

    An ``if`` visited with no edge (or a ``Next`` edge) becomes a :class:`MultiClassTopRule`. A nested ``if``
    whose body is a lone ``pass`` leads to a :class:`MultiClassStopRule`; any other nested body leads to a
    :class:`MultiClassFilterRule`. Conclusions are not read from ``return`` statements: top and filter rules get
    the ``conclusion_<uid>`` name matching their ``conditions_<uid>`` name.
    """

    def visit_If(self, stmt: ast.If):
        """Build the rule for ``stmt``; a top or filter rule also gets its conclusion function name."""
        super().visit_If(stmt)
        rule = self.current_parent
        if isinstance(rule, (MultiClassTopRule, MultiClassFilterRule)):
            rule.conclusion = rule.conditions.replace("conditions_", "conclusion_")

    def visit_Return(self, node):
        """Ignore ``return`` statements; conclusions are derived in :meth:`visit_If`."""
        pass

    def get_new_rule_type(self, node: ast.AST) -> Type[Rule]:
        """
        Choose the rule type from the edge that links the new rule to the current parent.

        :raises ValueError: If the current edge is not one of the handled edge types.
        """
        edge = self.current_edge
        if edge == RDREdge.Refinement:
            return MultiClassStopRule
        if edge == RDREdge.Filter:
            return MultiClassFilterRule
        if edge in (RDREdge.Next, None):
            return MultiClassTopRule
        if edge == RDREdge.Alternative:
            return self.get_refinement_rule_type(node)
        raise ValueError(f"Unknown edge type: {self.current_edge}")

    def get_alternative_edge(self, node: ast.AST) -> RDREdge:
        """:return: ``Next`` when the current parent is a top rule, ``Alternative`` otherwise."""
        return RDREdge.Next if isinstance(self.current_parent, MultiClassTopRule) else RDREdge.Alternative

    def get_refinement_edge(self, node: ast.AST) -> RDREdge:
        """:return: The edge for statements nested in ``node``, based on the refinement rule type it implies."""
        return self.get_refinement_edge_from_refinement_rule(self.get_refinement_rule_type(node))

    def get_refinement_edge_from_refinement_rule(self, rule_type: Type[Rule]) -> RDREdge:
        """
        :param rule_type: The type of the rule to determine the refinement edge from.
        :return: ``Alternative`` when the current parent is already a refinement rule; otherwise ``Refinement``
         for stop rules and ``Filter`` for any other rule type.
        """
        if isinstance(self.current_parent, MultiClassRefinementRule):
            return RDREdge.Alternative
        return RDREdge.Refinement if rule_type == MultiClassStopRule else RDREdge.Filter

    def get_refinement_rule_type(self, node: ast.AST) -> Type[Rule]:
        """
        Decide the refinement rule type from the first statement of ``node``'s body.

        :param node: The current AST node to determine the rule type from.
        :return: ``MultiClassStopRule`` for a body that is a single ``pass``; for a nested ``if`` the type is
         decided recursively from that ``if``; ``MultiClassFilterRule`` otherwise.
        :raises ValueError: If the node has an empty body.
        """
        body = node.body
        if not body:
            raise ValueError(f"Could not determine the refinement rule type from the node: {node} as it has an empty body.")
        first_stmt = body[0]
        if len(body) == 1 and isinstance(first_stmt, ast.Pass):
            return MultiClassStopRule
        if isinstance(first_stmt, ast.If):
            return self.get_refinement_rule_type(first_stmt)
        return MultiClassFilterRule

    def update_current_parent(self, new_node: Rule):
        """
        Attach ``new_node`` to the current parent; refinement rules also get their ``top_rule`` set.

        :raises ValueError: If a refinement rule's top rule cannot be determined from the current parent.
        """
        parent = self.current_parent
        if isinstance(new_node, MultiClassRefinementRule):
            if isinstance(parent, MultiClassTopRule):
                new_node.top_rule = parent
            elif hasattr(parent, "top_rule"):
                new_node.top_rule = parent.top_rule
            else:
                raise ValueError(f"Could not set the top rule for the refinement rule: {new_node}")
        edge = self.current_edge
        if edge in (RDREdge.Alternative, RDREdge.Next, None):
            parent.alternative = new_node
        elif edge in (RDREdge.Refinement, RDREdge.Filter):
            parent.refinement = new_node

    def process_else_statement(self, stmt: ast.AST):
        """Ignore ``else`` branches; they carry no rule information for multi-class classifiers."""
        pass
796
+
485
797
  class RDRWithCodeWriter(RippleDownRules, ABC):
486
798
 
487
- def update_from_python(self, model_dir: str, package_name: Optional[str] = None):
799
@classmethod
def from_python(cls, model_path: str,
                python_file_path: Optional[str] = None,
                parent_package_name: Optional[str] = None) -> Self:
    """
    Load the RDR classifier from a generated python file.

    The rule tree is first rebuilt from the generated ``classify`` function, then the rules' conditions,
    conclusions and the classifier metadata are filled in by :meth:`update_from_python`.

    :param model_path: The directory where the generated python file is located.
    :param python_file_path: The path to the generated python file that contains the RDR classifier function.
    :param parent_package_name: The name of the package that contains the RDR classifier function, this
     is required in case of relative imports in the generated python file.
    :return: An instance of the RDR classifier.
    """
    root_rule = cls.read_rule_tree_from_python(model_path, python_file_path=python_file_path)
    classifier = cls(start_rule=root_rule)
    classifier.update_from_python(model_path, package_name=parent_package_name,
                                  python_file_path=python_file_path)
    return classifier
816
+
817
@classmethod
def read_rule_tree_from_python(cls, model_path: str, python_file_path: Optional[str] = None) -> Rule:
    """
    Rebuild the rule tree from the ``classify`` function of a generated python file.

    The file is parsed (not imported), and every top-level function named ``classify`` is walked by the tree
    builder returned by :meth:`get_tree_builder_class`.

    :param model_path: The model directory, used to locate the generated python file when ``python_file_path``
     is not given.
    :param python_file_path: The path to the generated python file that contains the RDR classifier function.
    :return: The root rule built by the tree builder.
    """
    if python_file_path is None:
        python_file_path = cls.get_generated_python_file_path(model_path)
    # Read raw bytes so that ast.parse/compile detects the source encoding the way the interpreter does
    # (PEP 263 coding cookie or BOM, UTF-8 by default) instead of using the platform default text encoding.
    with open(python_file_path, "rb") as f:
        source_code = f.read()

    # Passing the file name makes SyntaxErrors point at the generated file.
    tree = ast.parse(source_code, filename=python_file_path)
    builder = cls.get_tree_builder_class()()

    # Find and process the classify function.
    for node in tree.body:
        if isinstance(node, ast.FunctionDef) and node.name == "classify":
            builder.visit_FunctionDef(node)

    return builder.root
837
+
838
@classmethod
@abstractmethod
def get_tree_builder_class(cls) -> Type[TreeBuilder]:
    """
    :return: The :class:`TreeBuilder` subclass that rebuilds this classifier's rule tree from its generated
     python file (``SingleClassTreeBuilder`` or ``MultiClassTreeBuilder``).
    """
    ...
846
+
847
@property
def all_rules(self) -> List[Rule]:
    """
    Collect the start rule and all of its descendants that have conditions.

    :return: The rules with non-``None`` conditions, start rule first, or an empty list without a start rule.
    """
    root = self.start_rule
    if root is None:
        return []
    candidates = [root, *root.descendants]
    return [rule for rule in candidates if rule.conditions is not None]
857
+
858
def update_from_python(self, model_dir: str, package_name: Optional[str] = None,
                       python_file_path: Optional[str] = None):
    """
    Update the rules from the generated python file, that might have been modified by the user.

    The main generated file, its ``<main>_defs.py`` companion and, when it exists, its ``<main>_cases.py``
    companion are imported and reloaded. The RDR metadata is then refreshed from the main module, and the
    conditions, conclusions and corner cases of every rule in ``all_rules`` are rebuilt from those files.

    :param model_dir: The directory where the generated python file is located.
    :param package_name: The name of the package that contains the RDR classifier function, this
     is required in case of relative imports in the generated python file.
    :param python_file_path: The path to the generated python file that contains the RDR classifier function.
     When omitted, the file in ``model_dir`` whose name ends with ``_<acronym>.py`` is used.
    :raises ModuleNotFoundError: If the main generated python file cannot be found.
    """
    all_rules = self.all_rules
    condition_func_names = [rule.generated_conditions_function_name for rule in all_rules]
    conclusion_func_names = [rule.generated_conclusion_function_name for rule in all_rules
                             if not isinstance(rule, MultiClassStopRule)]
    all_func_names = condition_func_names + conclusion_func_names

    if python_file_path is None:
        main_file_suffix = f"_{self.get_acronym().lower()}.py"
        found_file_name = get_file_that_ends_with(model_dir, main_file_suffix)
        if found_file_name is None:
            # Report it like a missing explicit path instead of failing with a TypeError in os.path.join.
            raise ModuleNotFoundError(f"No file ending with {main_file_suffix} was found in {model_dir}")
        main_file_path = os.path.join(model_dir, found_file_name)
    else:
        main_file_path = python_file_path
    if not os.path.exists(main_file_path):
        raise ModuleNotFoundError(main_file_path)

    # Only the trailing extension is rewritten; str.replace(".py", ...) on the whole path would also alter
    # directory names that contain ".py".
    main_file_stem, _ = os.path.splitext(main_file_path)
    main_file_name = os.path.basename(main_file_stem)
    self.generated_python_file_name = main_file_name
    defs_file_path = f"{main_file_stem}_defs.py"
    cases_path = f"{main_file_stem}_cases.py"
    model_import_path = get_import_path_from_path(model_dir)

    def import_and_reload(module_file_name: str) -> ModuleType:
        """Import a generated module of the model directory and reload it so that on-disk edits are picked up."""
        import_path = f"{model_import_path}.{module_file_name}" if model_import_path else module_file_name
        module = importlib.import_module(import_path, package=package_name)
        importlib.reload(module)
        return module

    cases_module = import_and_reload(f"{main_file_name}_cases") if os.path.exists(cases_path) else None
    defs_module = import_and_reload(f"{main_file_name}_defs")
    main_module = import_and_reload(main_file_name)

    self.start_rule.conclusion_name = main_module.attribute_name
    self.update_rdr_metadata_from_python(main_module)
    functions_source = extract_function_source(defs_file_path, all_func_names, include_signature=False)
    scope = extract_imports(defs_file_path, package_name=package_name)
    for rule in all_rules:
        if rule.conditions is not None:
            user_input = functions_source[rule.generated_conditions_function_name]
            rule.conditions = CallableExpression(user_input, (bool,), scope=scope)
            if cases_module is not None:
                rule.corner_case_metadata = cases_module.__dict__.get(rule.generated_corner_case_object_name, None)
            if not isinstance(rule, MultiClassStopRule):
                conclusion_func_name = rule.generated_conclusion_function_name
                user_input = functions_source[conclusion_func_name]
                conclusion_type = get_function_return_type(defs_module.__dict__.get(conclusion_func_name))
                rule.conclusion = CallableExpression(user_input, conclusion_type, scope=scope,
                                                     mutually_exclusive=self.mutually_exclusive)
545
926
 
546
927
  @abstractmethod
547
928
  def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
@@ -604,6 +985,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
604
985
  f.write(f"attribute_name = '{self.attribute_name}'\n")
605
986
  f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
606
987
  f.write(f"mutually_exclusive = {self.mutually_exclusive}\n")
988
+ self.write_rdr_metadata_to_pyton_file(f)
607
989
  f.write(f"\n\n{func_def}")
608
990
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
609
991
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
@@ -669,9 +1051,8 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
669
1051
  :return: The type of the conclusion of the RDR classifier.
670
1052
  """
671
1053
  all_types = []
672
- if self.start_rule is not None:
673
- for rule in [self.start_rule] + list(self.start_rule.descendants):
674
- all_types.extend(list(rule.conclusion.conclusion_type))
1054
+ for rule in self.all_rules:
1055
+ all_types.extend(list(rule.conclusion.conclusion_type))
675
1056
  return tuple(set(all_types))
676
1057
 
677
1058
  @property
@@ -728,6 +1109,10 @@ class SingleClassRDR(RDRWithCodeWriter):
728
1109
  super(SingleClassRDR, self).__init__(**kwargs)
729
1110
  self.default_conclusion: Optional[Any] = default_conclusion
730
1111
 
1112
@classmethod
def get_tree_builder_class(cls) -> Type[TreeBuilder]:
    """:return: :class:`SingleClassTreeBuilder`, which rebuilds single-class rule trees from generated code."""
    builder_class = SingleClassTreeBuilder
    return builder_class
1115
+
731
1116
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
732
1117
  -> Union[CaseAttribute, CallableExpression, None]:
733
1118
  """
@@ -879,6 +1264,10 @@ class MultiClassRDR(RDRWithCodeWriter):
879
1264
  super(MultiClassRDR, self).__init__(start_rule, **kwargs)
880
1265
  self.mode: MCRDRMode = mode
881
1266
 
1267
@classmethod
def get_tree_builder_class(cls) -> Type[TreeBuilder]:
    """:return: :class:`MultiClassTreeBuilder`, which rebuilds multi-class rule trees from generated code."""
    builder_class = MultiClassTreeBuilder
    return builder_class
1270
+
882
1271
  def _classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
883
1272
  case_query: Optional[CaseQuery] = None) -> Set[Any]:
884
1273
  evaluated_rule = self.start_rule
@@ -1128,6 +1517,58 @@ class GeneralRDR(RippleDownRules):
1128
1517
  super(GeneralRDR, self).__init__(**kwargs)
1129
1518
  self.all_figs: List[Figure] = [sr.fig for sr in self.start_rules_dict.values()]
1130
1519
 
1520
@classmethod
def from_python(cls, model_dir: str, python_file_path: Optional[str] = None,
                parent_package_name: Optional[str] = None) -> Self:
    """
    Create an instance of the class from a python file.

    The main module must define ``classifiers_dict``, mapping each classifier name to an object whose
    ``__name__`` ends with ``_<acronym>``. Each of those classifiers is loaded from its own
    ``<prefix>_<name>_<acronym>.py`` file next to the main file.

    :param model_dir: The path to the directory containing the python file.
    :param python_file_path: The path to the python file, if not provided, it will be generated from the model_dir.
    :param parent_package_name: The name of the package that contains the RDR classifier function, this
     is required in case of relative imports in the generated python file.
    :return: An instance of the class.
    :raises ModuleNotFoundError: If no main python file can be found in ``model_dir``.
    :raises ValueError: If a sub-classifier has an unknown acronym.
    """
    if python_file_path is None:
        main_file_suffix = f"_{cls.get_acronym().lower()}.py"
        file_name = get_file_that_ends_with(model_dir, main_file_suffix)
        if file_name is None:
            # Fail clearly instead of with a TypeError from os.path.join(model_dir, None).
            raise ModuleNotFoundError(f"No file ending with {main_file_suffix} was found in {model_dir}")
        main_python_file_path = os.path.join(model_dir, file_name)
    else:
        main_python_file_path = python_file_path
    main_python_file_name = Path(main_python_file_path).stem
    model_import_path = get_import_path_from_path(model_dir)
    main_module_import_path = f"{model_import_path}.{main_python_file_name}" if model_import_path \
        else main_python_file_name
    # Pass the parent package like RDRWithCodeWriter.update_from_python does, so relative paths resolve.
    main_module = importlib.import_module(main_module_import_path, package=parent_package_name)
    importlib.reload(main_module)

    # Sub-classifier files share the main file's prefix: "<prefix>_rdr.py" -> "<prefix>_<name>_<acronym>.py".
    # Only the trailing suffix is rewritten so that "_rdr.py" elsewhere in the path is left untouched.
    # NOTE(review): for main files not ending in "_rdr.py" only the extension is stripped — confirm against the
    # generated file naming.
    if main_python_file_path.endswith("_rdr.py"):
        sub_rdr_path_prefix = main_python_file_path[:-len("_rdr.py")]
    else:
        sub_rdr_path_prefix = os.path.splitext(main_python_file_path)[0]

    start_rules_dict = {}
    for rdr_name, rdr_module in main_module.classifiers_dict.items():
        rdr_acronym = rdr_module.__name__.rsplit('_', 1)[-1]
        rdr_type = cls.get_rdr_type_from_acronym(rdr_acronym)
        rdr_model_path = f"{sub_rdr_path_prefix}_{rdr_name}_{rdr_acronym}.py"
        start_rules_dict[rdr_name] = rdr_type.from_python(model_dir, python_file_path=rdr_model_path,
                                                          parent_package_name=parent_package_name)
    grdr = cls(category_rdr_map=start_rules_dict)
    grdr.update_rdr_metadata_from_python(main_module)
    return grdr
1555
+
1556
+ @classmethod
1557
+ def get_rdr_type_from_acronym(cls, acronym: str) -> Type[Union[SingleClassRDR, MultiClassRDR]]:
1558
+ """
1559
+ Get the type of the ripple down rules classifier from the acronym.
1560
+
1561
+ :param acronym: The acronym of the ripple down rules classifier.
1562
+ :return: The type of the ripple down rules classifier.
1563
+ """
1564
+ acronym = acronym.lower()
1565
+ if acronym == "scrdr":
1566
+ return SingleClassRDR
1567
+ elif acronym == "mcrdr":
1568
+ return MultiClassRDR
1569
+ else:
1570
+ raise ValueError(f"Unknown RDR type acronym: {acronym}")
1571
+
1131
1572
  def add_rdr(self, rdr: Union[SingleClassRDR, MultiClassRDR], case_query: Optional[CaseQuery] = None):
1132
1573
  """
1133
1574
  Add a ripple down rules classifier to the map of classifiers.
@@ -1249,7 +1690,7 @@ class GeneralRDR(RippleDownRules):
1249
1690
  Write the tree of rules as source code to a file.
1250
1691
 
1251
1692
  :param model_dir: The directory where the model is stored.
1252
- :param relative_imports: Whether to use relative imports in the generated python file.
1693
+ :param package_name: The name of the package that contains the RDR classifier function.
1253
1694
  """
1254
1695
  for rdr in self.start_rules_dict.values():
1255
1696
  rdr._write_to_python(model_dir, package_name=package_name)
@@ -1257,6 +1698,7 @@ class GeneralRDR(RippleDownRules):
1257
1698
  file_path = model_dir + f"/{self.generated_python_file_name}.py"
1258
1699
  with open(file_path, "w") as f:
1259
1700
  f.write(self._get_imports(file_path=file_path, package_name=package_name) + "\n\n")
1701
+ self.write_rdr_metadata_to_pyton_file(f)
1260
1702
  f.write("classifiers_dict = dict()\n")
1261
1703
  for rdr_key, rdr in self.start_rules_dict.items():
1262
1704
  f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
@@ -214,6 +214,10 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
214
214
  def generated_conditions_function_name(self) -> str:
215
215
  return f"conditions_{self.uid}"
216
216
 
217
+ @property
218
+ def generated_corner_case_object_name(self) -> str:
219
+ return f"corner_case_{self.uid}"
220
+
217
221
  def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
218
222
  """
219
223
  Convert the conclusion of a rule to source code.
@@ -715,6 +715,42 @@ origin_type_to_hint = {
715
715
  }
716
716
 
717
717
 
718
+ def get_file_that_ends_with(directory_path: str, suffix: str) -> Optional[str]:
719
+ """
720
+ Get the file that ends with the given suffix in the model directory.
721
+
722
+ :param directory_path: The path to the directory where the file is located.
723
+ :param suffix: The suffix to search for.
724
+ :return: The path to the file that ends with the given suffix, or None if not found.
725
+ """
726
+ files = [f for f in os.listdir(directory_path) if f.endswith(suffix)]
727
+ if files:
728
+ return files[0]
729
+ return None
730
+
731
+ def get_function_return_type(func: Callable) -> Union[Type, None, Tuple[Type, ...]]:
732
+ """
733
+ Get the return type of a function.
734
+
735
+ :param func: The function to get the return type for.
736
+ :return: The return type of the function, or None if not specified.
737
+ """
738
+ sig = inspect.signature(func)
739
+ if sig.return_annotation == inspect.Signature.empty:
740
+ return None
741
+ type_hint = sig.return_annotation
742
+ origin = get_origin(type_hint)
743
+ args = get_args(type_hint)
744
+ if origin not in [list, set, None, Union]:
745
+ raise TypeError(f"{origin} is not a handled return type for function {func.__name__}")
746
+ if origin is None:
747
+ return typing_to_python_type(type_hint)
748
+ if args is None or len(args) == 0:
749
+ return typing_to_python_type(type_hint)
750
+ return args
751
+
752
+
753
+
718
754
  def extract_types(tp, seen: Set = None) -> Set[type]:
719
755
  """Recursively extract all base types from a type hint."""
720
756
  if seen is None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.6.31
3
+ Version: 0.6.41
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -1,13 +1,13 @@
1
- ripple_down_rules/__init__.py,sha256=mH0F3aSoNRROyyjLKJo4D1dALPyoYh-J-mPNT1HudhU,99
1
+ ripple_down_rules/__init__.py,sha256=nYtH-We95fPZmgC7xF4K09TKVWFy7gOrO7Hv4zidPiI,99
2
2
  ripple_down_rules/experts.py,sha256=KXwWCmDrCffu9HW3yNewqWc1e5rnPI5Rnc981w_5M7U,17896
3
3
  ripple_down_rules/helpers.py,sha256=X1psHOqrb4_xYN4ssQNB8S9aRKKsqgihAyWJurN0dqk,5499
4
- ripple_down_rules/rdr.py,sha256=KsZbAbOs8U2PL19YOjFqSer8coXkSMDL3ztIrWHmTCA,62833
4
+ ripple_down_rules/rdr.py,sha256=s9y3ImomYOw3WKHggljfIaNFRxid3qtM0h-yso9oOdk,81327
5
5
  ripple_down_rules/rdr_decorators.py,sha256=xoBGsIJMkJYUdsrsEaPZqoAsGuXkuVZAKCoP-xD2Iv8,11668
6
- ripple_down_rules/rules.py,sha256=tmGJ5m9Z_d_qoRaWvRjDnl5AAVDgC_qeHGKAKN7WP64,29237
6
+ ripple_down_rules/rules.py,sha256=32apFTxtWXKRQ2MJDnqc1URjRJDnNBe_t5A_gGfKNd0,29349
7
7
  ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
8
- ripple_down_rules/utils.py,sha256=1fiSF4MOaOUrxlMz8sZA_e10258sMWuX5fG9WDawd2o,76674
8
+ ripple_down_rules/utils.py,sha256=SuZKBnGZvfDqlry7PY_N8AzLBOhEeuev57AVS-fJy78,77992
9
9
  ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
10
- ripple_down_rules/datastructures/callable_expression.py,sha256=IrlnufVsKrUDLVkc2owoFQ05oSOby3HiGuNXoFVj4Dw,13494
10
+ ripple_down_rules/datastructures/callable_expression.py,sha256=rzMrpD5oztaCRlt3hQ2B_xZ09cSuJNkYOCePndfQJRA,13684
11
11
  ripple_down_rules/datastructures/case.py,sha256=dfLnrjsHIVF2bgbz-4ID7OdQvw68V71btCeTK372P-g,15667
12
12
  ripple_down_rules/datastructures/dataclasses.py,sha256=3vX52WrAHgVyw0LUSgSBOVFaQNTSxU8hQpdr7cW-tSg,13278
13
13
  ripple_down_rules/datastructures/enums.py,sha256=CvcROl8fE7A6uTbMfs2lLpyxwS_ZFtFcQlBDDKFfoHc,6059
@@ -17,8 +17,8 @@ ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=RLdPqPxx-a0Sh74U
17
17
  ripple_down_rules/user_interface/object_diagram.py,sha256=FEa2HaYR9QmTE6NsOwBvZ0jqmu3DKyg6mig2VE5ZP4Y,4956
18
18
  ripple_down_rules/user_interface/prompt.py,sha256=WPbw_8_-8SpF2ISyRZRuFwPKBEuGC4HaX3lbCPFHhh8,10314
19
19
  ripple_down_rules/user_interface/template_file_creator.py,sha256=uSbosZS15MOR3Nv7M3MrFuoiKXyP4cBId-EK3I6stHM,13660
20
- ripple_down_rules-0.6.31.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
21
- ripple_down_rules-0.6.31.dist-info/METADATA,sha256=HRME3boNiuoTqo98-gzH7aK2-KCfc561e2eDDi09bhw,48294
22
- ripple_down_rules-0.6.31.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- ripple_down_rules-0.6.31.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
24
- ripple_down_rules-0.6.31.dist-info/RECORD,,
20
+ ripple_down_rules-0.6.41.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
21
+ ripple_down_rules-0.6.41.dist-info/METADATA,sha256=DkE-Aey_rdGNqR0UzZM4uT7BMaCcSB0vobs2I1sYYRo,48294
22
+ ripple_down_rules-0.6.41.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ ripple_down_rules-0.6.41.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
24
+ ripple_down_rules-0.6.41.dist-info/RECORD,,