ripple-down-rules 0.6.31__py3-none-any.whl → 0.6.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ripple_down_rules/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.6.31"
1
+ __version__ = "0.6.40"
2
2
 
3
3
  import logging
4
4
  logger = logging.Logger("rdr")
ripple_down_rules/datastructures/callable_expression.py CHANGED
@@ -129,6 +129,9 @@ class CallableExpression(SubclassJSONSerializer):
129
129
  else:
130
130
  conclusion_type = (conclusion_type,)
131
131
  self.conclusion_type = conclusion_type
132
+ self.expected_types: Set[Type] = set(conclusion_type) if conclusion_type is not None else set()
133
+ if not mutually_exclusive:
134
+ self.expected_types.update({list, set})
132
135
  self.scope: Optional[Dict[str, Any]] = scope if scope is not None else {}
133
136
  self.scope = get_used_scope(self.user_input, self.scope)
134
137
  self.expression_tree: AST = expression_tree if expression_tree else parse_string_to_expression(self.user_input)
@@ -158,7 +161,7 @@ class CallableExpression(SubclassJSONSerializer):
158
161
  raise ValueError(f"Mutually exclusive types cannot be lists or sets, got {type(output)}")
159
162
  output_types = {type(o) for o in make_list(output)}
160
163
  output_types.add(type(output))
161
- if not are_results_subclass_of_types(output_types, self.conclusion_type):
164
+ if not are_results_subclass_of_types(output_types, self.expected_types):
162
165
  raise ValueError(f"Not all result types {output_types} are subclasses of expected types"
163
166
  f" {self.conclusion_type}")
164
167
  return output
ripple_down_rules/rdr.py CHANGED
@@ -1,12 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import ast
3
4
  import importlib
4
5
  import json
5
6
  import os
6
7
  from abc import ABC, abstractmethod
7
8
  from copy import copy
8
9
  from dataclasses import is_dataclass
9
- from types import NoneType
10
+ from io import TextIOWrapper
11
+ from pathlib import Path
12
+ from types import NoneType, ModuleType
10
13
 
11
14
  from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
12
15
  from . import logger
@@ -26,7 +29,7 @@ from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tupl
26
29
  from .datastructures.callable_expression import CallableExpression
27
30
  from .datastructures.case import Case, CaseAttribute, create_case
28
31
  from .datastructures.dataclasses import CaseQuery
29
- from .datastructures.enums import MCRDRMode
32
+ from .datastructures.enums import MCRDRMode, RDREdge
30
33
  from .experts import Expert, Human
31
34
  from .helpers import is_matching, general_rdr_classify, get_an_updated_case_copy
32
35
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule, MultiClassRefinementRule, \
@@ -38,7 +41,8 @@ except ImportError as e:
38
41
  RDRCaseViewer = None
39
42
  from .utils import draw_tree, make_set, SubclassJSONSerializer, make_list, get_type_from_string, \
40
43
  is_value_conflicting, extract_function_source, extract_imports, get_full_class_name, \
41
- is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types, render_tree
44
+ is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types, render_tree, \
45
+ get_types_to_import_from_func_type_hints, get_function_return_type, get_file_that_ends_with
42
46
 
43
47
 
44
48
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -96,6 +100,30 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
96
100
  if RDRCaseViewer and any(RDRCaseViewer.instances) else None
97
101
  self.input_node: Optional[Rule] = None
98
102
 
103
+ def write_rdr_metadata_to_pyton_file(self, file: TextIOWrapper):
104
+ """
105
+ Write the metadata of the RDR classifier to a python file.
106
+
107
+ :param file: The file to write the metadata to.
108
+ """
109
+ file.write(f"name = \'{self.name}\'\n")
110
+ file.write(f"case_type = {self.case_type.__name__ if self.case_type is not None else None}\n")
111
+ file.write(f"case_name = \'{self.case_name}\'\n")
112
+
113
+ def update_rdr_metadata_from_python(self, module: ModuleType):
114
+ """
115
+ Update the RDR metadata from the module that contains the RDR classifier function.
116
+
117
+ :param module: The module that contains the RDR classifier function.
118
+ """
119
+ try:
120
+ self.name = module.name if hasattr(module, "name") else self.start_rule.conclusion_name
121
+ self.case_type = module.case_type
122
+ self.case_name = module.case_name if hasattr(module, "case_name") else f"{self.case_type.__name__}.{self.name}"
123
+ except AttributeError as e:
124
+ logger.warning(f"Could not update the RDR metadata from the module {module.__name__}. "
125
+ f"Make sure the module has the required attributes: {e}")
126
+
99
127
  def render_evaluated_rule_tree(self, filename: str, show_full_tree: bool = False) -> None:
100
128
  if show_full_tree:
101
129
  start_rule = self.start_rule if self.input_node is None else self.input_node
@@ -180,13 +208,21 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
180
208
  :param package_name: The name of the package that contains the RDR classifier function, this
181
209
  is required in case of relative imports in the generated python file.
182
210
  """
211
+ rdr: Optional[RippleDownRules] = None
183
212
  model_dir = os.path.join(load_dir, model_name)
184
213
  json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
185
- rdr = cls.from_json_file(json_file)
186
- rdr.save_dir = load_dir
187
- rdr.model_name = model_name
214
+ if os.path.exists(json_file + ".json"):
215
+ rdr = cls.from_json_file(json_file)
188
216
  try:
189
- rdr.update_from_python(model_dir, package_name=package_name)
217
+ if rdr is None:
218
+ acronym = cls.get_acronym().lower()
219
+ python_file_name = get_file_that_ends_with(model_dir, f"_{acronym}.py")
220
+ python_file_path = os.path.join(model_dir, python_file_name)
221
+ rdr = cls.from_python(model_dir, parent_package_name=package_name, python_file_path=python_file_path)
222
+ else:
223
+ rdr.update_from_python(model_dir, package_name=package_name)
224
+ rdr.save_dir = load_dir
225
+ rdr.model_name = model_name
190
226
  rdr.to_json_file(json_file)
191
227
  except (FileNotFoundError, ValueError, SyntaxError) as e:
192
228
  logger.warning(f"Could not load the python file for the model {model_name} from {model_dir}. "
@@ -194,6 +230,20 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
194
230
  rdr.save(save_dir=load_dir, model_name=model_name, package_name=package_name)
195
231
  return rdr
196
232
 
233
+ @classmethod
234
+ @abstractmethod
235
+ def from_python(cls, model_dir: str, python_file_path: Optional[str] = None,
236
+ parent_package_name: Optional[str] = None) -> Self:
237
+ """
238
+ Load the classifier from a python file.
239
+
240
+ :param model_dir: The path to the directory where the generated python file is located.
241
+ :param python_file_path: The path to the python file to load the classifier from.
242
+ :param parent_package_name: The name of the package that contains the RDR classifier function, this
243
+ is required in case of relative imports in the generated python file.
244
+ """
245
+ pass
246
+
197
247
  @abstractmethod
198
248
  def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
199
249
  """
@@ -424,6 +474,58 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
424
474
  def type_(self):
425
475
  return self.__class__
426
476
 
477
+ @classmethod
478
+ def get_json_file_path(cls, model_path: str) -> str:
479
+ """
480
+ Get the path to the saved json file.
481
+
482
+ :param model_path : The path to the model directory.
483
+ :return: The path to the saved model.
484
+ """
485
+ model_name = cls.get_model_name_from_model_path(model_path)
486
+ return os.path.join(model_path, cls.metadata_folder, f"{model_name}.json")
487
+
488
+ @classmethod
489
+ def get_generated_cases_file_path(cls, model_path: str) -> str:
490
+ """
491
+ Get the path to the python file that contains the RDR classifier cases.
492
+
493
+ :param model_path : The path to the model directory.
494
+ :return: The path to the generated python file.
495
+ """
496
+ return cls.get_generated_python_file_path(model_path).replace(".py", "_cases.py")
497
+
498
+ @classmethod
499
+ def get_generated_defs_file_path(cls, model_path: str) -> str:
500
+ """
501
+ Get the path to the python file that contains the RDR classifier function definitions.
502
+
503
+ :param model_path : The path to the model directory.
504
+ :return: The path to the generated python file.
505
+ """
506
+ return cls.get_generated_python_file_path(model_path).replace(".py", "_defs.py")
507
+
508
+ @classmethod
509
+ def get_generated_python_file_path(cls, model_path: str) -> str:
510
+ """
511
+ Get the path to the python file that contains the RDR classifier function.
512
+
513
+ :param model_path : The path to the model directory.
514
+ :return: The path to the generated python file.
515
+ """
516
+ model_name = cls.get_model_name_from_model_path(model_path)
517
+ return os.path.join(model_path, f"{model_name}.py")
518
+
519
+ @classmethod
520
+ def get_model_name_from_model_path(cls, model_path: str) -> str:
521
+ """
522
+ Get the model name from the model path.
523
+
524
+ :param model_path: The path to the model directory.
525
+ :return: The name of the model.
526
+ """
527
+ return Path(model_path).name
528
+
427
529
  @property
428
530
  def generated_python_file_name(self) -> str:
429
531
  if self._generated_python_file_name is None:
@@ -482,66 +584,343 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
482
584
  return module.classify
483
585
 
484
586
 
587
+ class TreeBuilder(ast.NodeVisitor, ABC):
588
+ """Parses an AST of nested if-elif statements and reconstructs the tree."""
589
+
590
+ def __init__(self):
591
+ self.root: Optional[Rule] = None
592
+ self.current_parent: Optional[Rule] = None
593
+ self.current_edge: Optional[RDREdge] = None
594
+ self.default_conclusion: Optional[str] = None
595
+
596
+ def visit_FunctionDef(self, node):
597
+ """Finds the main function and starts parsing its body."""
598
+ for stmt in node.body:
599
+ self.visit(stmt)
600
+
601
+ def visit_If(self, node):
602
+ """Handles if-elif blocks and creates nodes."""
603
+ condition = self.get_condition_name(node.test)
604
+ if condition is None:
605
+ return
606
+ rule_uid = condition.split("conditions_")[1]
607
+
608
+ new_rule_type = self.get_new_rule_type(node)
609
+ new_node = new_rule_type(conditions=condition, parent=self.current_parent, uid=rule_uid)
610
+ if self.current_parent is not None:
611
+ self.update_current_parent(new_node)
612
+
613
+ if self.current_parent is None and self.root is None:
614
+ self.root = new_node
615
+
616
+ self.current_parent = new_node
617
+
618
+ # Parse the body of the if statement
619
+ for stmt in node.body:
620
+ self.current_edge = self.get_refinement_edge(node)
621
+ self.current_parent = new_node
622
+ self.visit(stmt)
623
+
624
+ # Parse elif/else
625
+ for stmt in node.orelse:
626
+ self.current_edge = self.get_alternative_edge(node)
627
+ self.current_parent = new_node
628
+ if isinstance(stmt, ast.If): # elif case
629
+ self.visit_If(stmt)
630
+ else: # else case (return)
631
+ self.process_else_statement(stmt)
632
+ self.current_parent = new_node
633
+ self.current_edge = None
634
+
635
+ @abstractmethod
636
+ def process_else_statement(self, stmt: ast.AST):
637
+ """
638
+ Process the else statement in the if-elif-else block.
639
+
640
+ :param stmt: The else statement to process.
641
+ """
642
+ pass
643
+
644
+ @abstractmethod
645
+ def get_refinement_edge(self, node: ast.AST) -> RDREdge:
646
+ """
647
+ :param node: The current AST node to determine the edge type from.
648
+ :return: The refinement edge type.
649
+ """
650
+ pass
651
+
652
+ @abstractmethod
653
+ def get_alternative_edge(self, node: ast.AST) -> RDREdge:
654
+ """
655
+ :param node: The current AST node to determine the alternative edge type from.
656
+ :return: The alternative edge type.
657
+ """
658
+ pass
659
+
660
+ @abstractmethod
661
+ def get_new_rule_type(self, node: ast.AST) -> Type[Rule]:
662
+ """
663
+ Get the new rule type to create.
664
+ :param node: The current AST node to determine the rule type from.
665
+ :return: The new rule type.
666
+ """
667
+ pass
668
+
669
+ @abstractmethod
670
+ def update_current_parent(self, new_node: Rule):
671
+ """
672
+ Update the current parent rule with the new node.
673
+ :param new_node: The new node to set as the current parent.
674
+ """
675
+ pass
676
+
677
+ def visit_Return(self, node):
678
+ """Handles return statements as leaf nodes."""
679
+ if isinstance(node.value, ast.Call):
680
+ return_value = node.value.func.id
681
+ else:
682
+ return_value = ast.literal_eval(node.value)
683
+ if self.current_parent is None:
684
+ self.default_conclusion = return_value
685
+ else:
686
+ self.current_parent.conclusion = return_value
687
+
688
+ def get_condition_name(self, node):
689
+ """Extracts the condition function name from an AST expression."""
690
+ if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
691
+ return node.func.id
692
+ return None
693
+
694
+ class SingleClassTreeBuilder(TreeBuilder):
695
+ """Parses an AST of generated SingleClassRDR classifier and reconstructs the rdr tree."""
696
+
697
+ def get_new_rule_type(self, node: ast.AST) -> Type[Rule]:
698
+ return SingleClassRule
699
+
700
+ def get_refinement_edge(self, node: ast.AST) -> RDREdge:
701
+ return RDREdge.Refinement
702
+
703
+ def get_alternative_edge(self, node: ast.AST) -> RDREdge:
704
+ return RDREdge.Alternative
705
+
706
+ def update_current_parent(self, new_node: Rule):
707
+ if self.current_edge == RDREdge.Alternative:
708
+ self.current_parent.alternative = new_node
709
+ elif self.current_edge == RDREdge.Refinement:
710
+ self.current_parent.refinement = new_node
711
+
712
+ def process_else_statement(self, stmt: ast.AST):
713
+ """Handles the else statement in the if-elif-else block."""
714
+ if isinstance(stmt, ast.Return):
715
+ self.current_parent = None
716
+ self.visit_Return(stmt)
717
+ else:
718
+ raise ValueError(f"Unexpected statement in else block: {stmt}")
719
+
720
+
721
+ class MultiClassTreeBuilder(TreeBuilder):
722
+ """Parses an AST of generated MultiClassRDR classifier and reconstructs the rdr tree."""
723
+
724
+ def visit_If(self, stmt: ast.If):
725
+ super().visit_If(stmt)
726
+ if isinstance(self.current_parent, (MultiClassTopRule, MultiClassFilterRule)):
727
+ self.current_parent.conclusion = self.current_parent.conditions.replace("conditions_", "conclusion_")
728
+
729
+ def visit_Return(self, node):
730
+ pass
731
+
732
+ def get_new_rule_type(self, node: ast.AST) -> Type[Rule]:
733
+ if self.current_edge == RDREdge.Refinement:
734
+ return MultiClassStopRule
735
+ elif self.current_edge == RDREdge.Filter:
736
+ return MultiClassFilterRule
737
+ elif self.current_edge in [RDREdge.Next, None]:
738
+ return MultiClassTopRule
739
+ elif self.current_edge == RDREdge.Alternative:
740
+ return self.get_refinement_rule_type(node)
741
+ else:
742
+ raise ValueError(f"Unknown edge type: {self.current_edge}")
743
+
744
+ def get_alternative_edge(self, node: ast.AST) -> RDREdge:
745
+ if isinstance(self.current_parent, MultiClassTopRule):
746
+ return RDREdge.Next
747
+ else:
748
+ return RDREdge.Alternative
749
+
750
+ def get_refinement_edge(self, node: ast.AST) -> RDREdge:
751
+ rule_type = self.get_refinement_rule_type(node)
752
+ return self.get_refinement_edge_from_refinement_rule(rule_type)
753
+
754
+ def get_refinement_edge_from_refinement_rule(self, rule_type: Type[Rule]) -> RDREdge:
755
+ """
756
+ :param rule_type: The type of the rule to determine the refinement edge from.
757
+ :return: The refinement edge type based on the rule type.
758
+ """
759
+ if isinstance(self.current_parent, MultiClassRefinementRule):
760
+ return RDREdge.Alternative
761
+ if rule_type == MultiClassStopRule:
762
+ return RDREdge.Refinement
763
+ else:
764
+ return RDREdge.Filter
765
+
766
+ def get_refinement_rule_type(self, node: ast.AST) -> Type[Rule]:
767
+ """
768
+ :param node: The current AST node to determine the rule type from.
769
+ :return: The rule type based on the node body.
770
+ """
771
+ for stmt in node.body:
772
+ if len(node.body) == 1 and isinstance(stmt, ast.Pass):
773
+ return MultiClassStopRule
774
+ elif isinstance(stmt, ast.If):
775
+ return self.get_refinement_rule_type(stmt)
776
+ else:
777
+ return MultiClassFilterRule
778
+ raise ValueError(f"Could not determine the refinement rule type from the node: {node} as it has an empty body.")
779
+
780
+ def update_current_parent(self, new_node: Rule):
781
+ if isinstance(new_node, MultiClassRefinementRule):
782
+ if isinstance(self.current_parent, MultiClassTopRule):
783
+ new_node.top_rule = self.current_parent
784
+ elif hasattr(self.current_parent, "top_rule"):
785
+ new_node.top_rule = self.current_parent.top_rule
786
+ else:
787
+ raise ValueError(f"Could not set the top rule for the refinement rule: {new_node}")
788
+ if self.current_edge in [RDREdge.Alternative, RDREdge.Next, None]:
789
+ self.current_parent.alternative = new_node
790
+ elif self.current_edge in [RDREdge.Refinement, RDREdge.Filter]:
791
+ self.current_parent.refinement = new_node
792
+
793
+ def process_else_statement(self, stmt: ast.AST):
794
+ """Handles the else statement in the if-elif-else block."""
795
+ pass
796
+
485
797
  class RDRWithCodeWriter(RippleDownRules, ABC):
486
798
 
487
- def update_from_python(self, model_dir: str, package_name: Optional[str] = None):
799
+ @classmethod
800
+ def from_python(cls, model_path: str,
801
+ python_file_path: Optional[str] = None,
802
+ parent_package_name: Optional[str] = None) -> Self:
803
+ """
804
+ Load the RDR classifier from a generated python file.
805
+
806
+ :param model_path: The directory where the generated python file is located.
807
+ :param python_file_path: The path to the generated python file that contains the RDR classifier function.
808
+ :param parent_package_name: The name of the package that contains the RDR classifier function, this
809
+ is required in case of relative imports in the generated python file.
810
+ :return: An instance of the RDR classifier.
811
+ """
812
+ rule_tree_root = cls.read_rule_tree_from_python(model_path, python_file_path=python_file_path)
813
+ rdr = cls(start_rule=rule_tree_root)
814
+ rdr.update_from_python(model_path, package_name=parent_package_name, python_file_path=python_file_path)
815
+ return rdr
816
+
817
+ @classmethod
818
+ def read_rule_tree_from_python(cls, model_path: str, python_file_path: Optional[str] = None) -> Rule:
819
+ """
820
+ :param model_path: The path to the generated python file that contains the RDR classifier function.
821
+ :param python_file_path: The path to the generated python file that contains the RDR classifier function.
822
+ """
823
+ if python_file_path is None:
824
+ python_file_path = cls.get_generated_python_file_path(model_path)
825
+ with open(python_file_path, "r") as f:
826
+ source_code = f.read()
827
+
828
+ tree = ast.parse(source_code)
829
+ builder = cls.get_tree_builder_class()()
830
+
831
+ # Find and process the function
832
+ for node in tree.body:
833
+ if isinstance(node, ast.FunctionDef) and node.name == "classify":
834
+ builder.visit_FunctionDef(node)
835
+
836
+ return builder.root
837
+
838
+ @classmethod
839
+ @abstractmethod
840
+ def get_tree_builder_class(cls) -> Type[TreeBuilder]:
841
+ """
842
+ :return: The class that builds the rule tree from the generated python file.
843
+ This should be either SingleClassTreeBuilder or MultiClassTreeBuilder.
844
+ """
845
+ pass
846
+
847
+ @property
848
+ def all_rules(self) -> List[Rule]:
849
+ """
850
+ Get all rules in the classifier.
851
+
852
+ :return: A list of all rules in the classifier.
853
+ """
854
+ if self.start_rule is None:
855
+ return []
856
+ return [r for r in [self.start_rule] + list(self.start_rule.descendants) if r.conditions is not None]
857
+
858
+ def update_from_python(self, model_dir: str, package_name: Optional[str] = None,
859
+ python_file_path: Optional[str] = None):
488
860
  """
489
861
  Update the rules from the generated python file, that might have been modified by the user.
490
862
 
491
863
  :param model_dir: The directory where the generated python file is located.
492
864
  :param package_name: The name of the package that contains the RDR classifier function, this
493
865
  is required in case of relative imports in the generated python file.
866
+ :param python_file_path: The path to the generated python file that contains the RDR classifier function.
494
867
  """
495
- rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants)
496
- if r.conditions is not None}
497
- condition_func_names = [f'conditions_{rid}' for rid in rules_dict.keys()]
498
- conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys()
499
- if not isinstance(rules_dict[rid], MultiClassStopRule)]
868
+ all_rules = self.all_rules
869
+ condition_func_names = [rule.generated_conditions_function_name for rule in all_rules]
870
+ conclusion_func_names = [rule.generated_conclusion_function_name for rule in all_rules
871
+ if not isinstance(rule, MultiClassStopRule)]
500
872
  all_func_names = condition_func_names + conclusion_func_names
501
- rule_tree_file_path = f"{model_dir}/{self.generated_python_file_name}.py"
502
- filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
503
- cases_path = f"{model_dir}/{self.generated_python_cases_file_name}.py"
504
- cases_import_path = get_import_path_from_path(model_dir)
505
- cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path \
506
- else self.generated_python_cases_file_name
507
- functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
508
- python_rule_tree_source = ""
509
- with open(rule_tree_file_path, "r") as rule_tree_source:
510
- python_rule_tree_source = rule_tree_source.read()
511
- # get the scope from the imports in the file
512
- scope = extract_imports(filepath, package_name=package_name)
513
- rules_not_found = set()
514
- for rule in [self.start_rule] + list(self.start_rule.descendants):
873
+
874
+ if python_file_path is None:
875
+ main_file_name = get_file_that_ends_with(model_dir, f"_{self.get_acronym().lower()}.py")
876
+ main_file_path = os.path.join(model_dir, main_file_name)
877
+ else:
878
+ main_file_path = python_file_path
879
+ self.generated_python_file_name = Path(main_file_path).name.replace(".py", "")
880
+
881
+ defs_file_path = main_file_path.replace(".py", "_defs.py")
882
+ defs_file_name = Path(defs_file_path).name.replace(".py", "")
883
+
884
+ cases_path = main_file_path.replace(".py", "_cases.py")
885
+ cases_file_name = Path(cases_path).name.replace(".py", "")
886
+ model_import_path = get_import_path_from_path(model_dir)
887
+ cases_import_path = f"{model_import_path}.{cases_file_name}" if model_import_path \
888
+ else cases_file_name
889
+ if os.path.exists(cases_path):
890
+ cases_module = importlib.import_module(cases_import_path, package=package_name)
891
+ importlib.reload(cases_module)
892
+ else:
893
+ cases_module = None
894
+
895
+ defs_import_path = f"{model_import_path}.{defs_file_name}" if model_import_path \
896
+ else defs_file_name
897
+ defs_module = importlib.import_module(defs_import_path, package=package_name)
898
+ importlib.reload(defs_module)
899
+
900
+ main_file_name = Path(main_file_path).name.replace(".py", "")
901
+ main_import_path = f"{model_import_path}.{main_file_name}" if model_import_path \
902
+ else main_file_name
903
+ main_module = importlib.import_module(main_import_path, package=package_name)
904
+ importlib.reload(main_module)
905
+
906
+ self.start_rule.conclusion_name = main_module.attribute_name
907
+ self.update_rdr_metadata_from_python(main_module)
908
+ functions_source = extract_function_source(defs_file_path, all_func_names, include_signature=False)
909
+ scope = extract_imports(defs_file_path, package_name=package_name)
910
+ for rule in all_rules:
515
911
  if rule.conditions is not None:
516
- conditions_name = rule.generated_conditions_function_name
517
- if conditions_name not in functions_source or conditions_name not in python_rule_tree_source:
518
- rules_not_found.add(rule)
519
- continue
520
- rule.conditions.user_input = functions_source[conditions_name]
521
- rule.conditions.scope = scope
912
+ conditions_wrapper_func_name = rule.generated_conditions_function_name
913
+ user_input = functions_source[conditions_wrapper_func_name]
914
+ rule.conditions = CallableExpression(user_input, (bool,), scope=scope)
522
915
  if os.path.exists(cases_path):
523
- module = importlib.import_module(cases_import_path, package=package_name)
524
- importlib.reload(module)
525
- rule.corner_case_metadata = module.__dict__.get(f"corner_case_{rule.uid}", None)
916
+ rule.corner_case_metadata = cases_module.__dict__.get(rule.generated_corner_case_object_name, None)
526
917
  if not isinstance(rule, MultiClassStopRule):
527
- conclusion_name = rule.generated_conclusion_function_name
528
- if conclusion_name not in functions_source or conclusion_name not in python_rule_tree_source:
529
- rules_not_found.add(rule)
530
- rule.conclusion.user_input = functions_source[conclusion_name]
531
- rule.conclusion.scope = scope
532
- for rule in rules_not_found:
533
- if isinstance(rule, MultiClassTopRule):
534
- rule.parent.set_immediate_alternative(rule.alternative)
535
- if rule.refinement is not None:
536
- ref_rules = [ref_rule for ref_rule in [rule.refinement] + list(rule.refinement.descendants)]
537
- for ref_rule in ref_rules:
538
- del ref_rule
539
- else:
540
- rule.parent.refinement = rule.alternative
541
- if rule.alternative is not None:
542
- rule.alternative = None
543
- rule.parent = None
544
- del rule
918
+ conclusion_wrapper_func_name = rule.generated_conclusion_function_name
919
+ user_input = functions_source[conclusion_wrapper_func_name]
920
+ conclusion_func = defs_module.__dict__.get(rule.generated_conclusion_function_name)
921
+ conclusion_type = get_function_return_type(conclusion_func)
922
+ rule.conclusion = CallableExpression(user_input, conclusion_type, scope=scope,
923
+ mutually_exclusive=self.mutually_exclusive)
545
924
 
546
925
  @abstractmethod
547
926
  def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
@@ -604,6 +983,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
604
983
  f.write(f"attribute_name = '{self.attribute_name}'\n")
605
984
  f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
606
985
  f.write(f"mutually_exclusive = {self.mutually_exclusive}\n")
986
+ self.write_rdr_metadata_to_pyton_file(f)
607
987
  f.write(f"\n\n{func_def}")
608
988
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
609
989
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
@@ -669,9 +1049,8 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
669
1049
  :return: The type of the conclusion of the RDR classifier.
670
1050
  """
671
1051
  all_types = []
672
- if self.start_rule is not None:
673
- for rule in [self.start_rule] + list(self.start_rule.descendants):
674
- all_types.extend(list(rule.conclusion.conclusion_type))
1052
+ for rule in self.all_rules:
1053
+ all_types.extend(list(rule.conclusion.conclusion_type))
675
1054
  return tuple(set(all_types))
676
1055
 
677
1056
  @property
@@ -728,6 +1107,10 @@ class SingleClassRDR(RDRWithCodeWriter):
728
1107
  super(SingleClassRDR, self).__init__(**kwargs)
729
1108
  self.default_conclusion: Optional[Any] = default_conclusion
730
1109
 
1110
+ @classmethod
1111
+ def get_tree_builder_class(cls) -> Type[TreeBuilder]:
1112
+ return SingleClassTreeBuilder
1113
+
731
1114
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
732
1115
  -> Union[CaseAttribute, CallableExpression, None]:
733
1116
  """
@@ -879,6 +1262,10 @@ class MultiClassRDR(RDRWithCodeWriter):
879
1262
  super(MultiClassRDR, self).__init__(start_rule, **kwargs)
880
1263
  self.mode: MCRDRMode = mode
881
1264
 
1265
+ @classmethod
1266
+ def get_tree_builder_class(cls) -> Type[TreeBuilder]:
1267
+ return MultiClassTreeBuilder
1268
+
882
1269
  def _classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
883
1270
  case_query: Optional[CaseQuery] = None) -> Set[Any]:
884
1271
  evaluated_rule = self.start_rule
@@ -1128,6 +1515,58 @@ class GeneralRDR(RippleDownRules):
1128
1515
  super(GeneralRDR, self).__init__(**kwargs)
1129
1516
  self.all_figs: List[Figure] = [sr.fig for sr in self.start_rules_dict.values()]
1130
1517
 
1518
+ @classmethod
1519
+ def from_python(cls, model_dir: str, python_file_path: Optional[str] = None,
1520
+ parent_package_name: Optional[str] = None) -> Self:
1521
+ """
1522
+ Create an instance of the class from a python file.
1523
+
1524
+ :param model_dir: The path to the directory containing the python file.
1525
+ :param python_file_path: The path to the python file, if not provided, it will be generated from the model_dir.
1526
+ :param parent_package_name: The name of the package that contains the RDR classifier function, this
1527
+ is required in case of relative imports in the generated python file.
1528
+ :return: An instance of the class.
1529
+ """
1530
+ if python_file_path is None:
1531
+ file_name = get_file_that_ends_with(model_dir, f"_{cls.get_acronym().lower()}.py",)
1532
+ main_python_file_path = os.path.join(model_dir, file_name)
1533
+ else:
1534
+ main_python_file_path = python_file_path
1535
+ main_python_file_name = Path(main_python_file_path).name.replace('.py', '')
1536
+ main_module_import_path = get_import_path_from_path(model_dir)
1537
+ main_module_import_path = f"{main_module_import_path}.{main_python_file_name}" \
1538
+ if main_module_import_path else main_python_file_name
1539
+ main_module = importlib.import_module(main_module_import_path)
1540
+ importlib.reload(main_module)
1541
+ classifiers_dict = main_module.classifiers_dict
1542
+ start_rules_dict = {}
1543
+ for rdr_name, rdr_module in classifiers_dict.items():
1544
+ rdr_module_name = rdr_module.__name__
1545
+ rdr_acronym = rdr_module_name.split('_')[-1]
1546
+ rdr_type = cls.get_rdr_type_from_acronym(rdr_acronym)
1547
+ rdr_model_path = main_python_file_path.replace('_rdr.py', f'_{rdr_name}_{rdr_acronym}.py')
1548
+ rdr = rdr_type.from_python(model_dir, python_file_path=rdr_model_path, parent_package_name=parent_package_name)
1549
+ start_rules_dict[rdr_name] = rdr
1550
+ grdr = cls(category_rdr_map=start_rules_dict)
1551
+ grdr.update_rdr_metadata_from_python(main_module)
1552
+ return grdr
1553
+
1554
+ @classmethod
1555
+ def get_rdr_type_from_acronym(cls, acronym: str) -> Type[Union[SingleClassRDR, MultiClassRDR]]:
1556
+ """
1557
+ Get the type of the ripple down rules classifier from the acronym.
1558
+
1559
+ :param acronym: The acronym of the ripple down rules classifier.
1560
+ :return: The type of the ripple down rules classifier.
1561
+ """
1562
+ acronym = acronym.lower()
1563
+ if acronym == "scrdr":
1564
+ return SingleClassRDR
1565
+ elif acronym == "mcrdr":
1566
+ return MultiClassRDR
1567
+ else:
1568
+ raise ValueError(f"Unknown RDR type acronym: {acronym}")
1569
+
1131
1570
  def add_rdr(self, rdr: Union[SingleClassRDR, MultiClassRDR], case_query: Optional[CaseQuery] = None):
1132
1571
  """
1133
1572
  Add a ripple down rules classifier to the map of classifiers.
@@ -1249,7 +1688,7 @@ class GeneralRDR(RippleDownRules):
1249
1688
  Write the tree of rules as source code to a file.
1250
1689
 
1251
1690
  :param model_dir: The directory where the model is stored.
1252
- :param relative_imports: Whether to use relative imports in the generated python file.
1691
+ :param package_name: The name of the package that contains the RDR classifier function.
1253
1692
  """
1254
1693
  for rdr in self.start_rules_dict.values():
1255
1694
  rdr._write_to_python(model_dir, package_name=package_name)
@@ -1257,6 +1696,7 @@ class GeneralRDR(RippleDownRules):
1257
1696
  file_path = model_dir + f"/{self.generated_python_file_name}.py"
1258
1697
  with open(file_path, "w") as f:
1259
1698
  f.write(self._get_imports(file_path=file_path, package_name=package_name) + "\n\n")
1699
+ self.write_rdr_metadata_to_pyton_file(f)
1260
1700
  f.write("classifiers_dict = dict()\n")
1261
1701
  for rdr_key, rdr in self.start_rules_dict.items():
1262
1702
  f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
@@ -214,6 +214,10 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
214
214
  def generated_conditions_function_name(self) -> str:
215
215
  return f"conditions_{self.uid}"
216
216
 
217
@property
def generated_corner_case_object_name(self) -> str:
    """
    :return: The string "corner_case_" followed by this object's ``uid``.
    """
    return "corner_case_{}".format(self.uid)
220
+
217
221
  def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
218
222
  """
219
223
  Convert the conclusion of a rule to source code.
@@ -715,6 +715,42 @@ origin_type_to_hint = {
715
715
  }
716
716
 
717
717
 
718
def get_file_that_ends_with(directory_path: str, suffix: str) -> Optional[str]:
    """
    Get the name of an entry in the given directory that ends with the given suffix.

    Only the entry name is returned, not a full path; join it with ``directory_path`` to obtain one.
    When several entries match, the lexicographically smallest name is returned, so the result does
    not depend on the arbitrary order in which ``os.listdir`` reports directory entries.

    :param directory_path: The path to the directory to search in.
    :param suffix: The suffix to search for.
    :return: The name of the matching entry, or None if no entry ends with the suffix.
    """
    matching_names = [name for name in os.listdir(directory_path) if name.endswith(suffix)]
    # min() with a default covers the "no match" case without a separate branch.
    return min(matching_names, default=None)
730
+
731
def get_function_return_type(func: Callable) -> Union[Type, None, Tuple[Type, ...]]:
    """
    Determine the return type(s) declared by the return annotation of a function.

    :param func: The function whose return annotation is inspected.
    :return: None if the function has no return annotation. For a parameterized list, set or Union
     annotation, the tuple of its type arguments. Otherwise (a non-generic annotation, or a generic one
     without arguments) the result of converting the annotation with ``typing_to_python_type``.
    :raises TypeError: If the annotation is a generic whose origin is not list, set or Union.
    """
    annotation = inspect.signature(func).return_annotation
    if annotation == inspect.Signature.empty:
        return None

    hint_origin = get_origin(annotation)
    if hint_origin is None:
        # Not a generic alias: convert the annotation itself.
        return typing_to_python_type(annotation)
    if hint_origin not in (list, set, Union):
        raise TypeError(f"{hint_origin} is not a handled return type for function {func.__name__}")

    hint_args = get_args(annotation)
    if not hint_args:
        # Generic without type arguments: fall back to converting the annotation.
        return typing_to_python_type(annotation)
    return hint_args
751
+
752
+
753
+
718
754
  def extract_types(tp, seen: Set = None) -> Set[type]:
719
755
  """Recursively extract all base types from a type hint."""
720
756
  if seen is None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.6.31
3
+ Version: 0.6.40
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -1,13 +1,13 @@
1
- ripple_down_rules/__init__.py,sha256=mH0F3aSoNRROyyjLKJo4D1dALPyoYh-J-mPNT1HudhU,99
1
+ ripple_down_rules/__init__.py,sha256=ZDGcC9yYEOFz8iqWsnSdjTpnmCTdj7BklFWMBfFaylI,99
2
2
  ripple_down_rules/experts.py,sha256=KXwWCmDrCffu9HW3yNewqWc1e5rnPI5Rnc981w_5M7U,17896
3
3
  ripple_down_rules/helpers.py,sha256=X1psHOqrb4_xYN4ssQNB8S9aRKKsqgihAyWJurN0dqk,5499
4
- ripple_down_rules/rdr.py,sha256=KsZbAbOs8U2PL19YOjFqSer8coXkSMDL3ztIrWHmTCA,62833
4
+ ripple_down_rules/rdr.py,sha256=-6q_N1vGaui2plVvvS0V4cW0j9ETTDIYyveWZx7UTuY,81213
5
5
  ripple_down_rules/rdr_decorators.py,sha256=xoBGsIJMkJYUdsrsEaPZqoAsGuXkuVZAKCoP-xD2Iv8,11668
6
- ripple_down_rules/rules.py,sha256=tmGJ5m9Z_d_qoRaWvRjDnl5AAVDgC_qeHGKAKN7WP64,29237
6
+ ripple_down_rules/rules.py,sha256=32apFTxtWXKRQ2MJDnqc1URjRJDnNBe_t5A_gGfKNd0,29349
7
7
  ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
8
- ripple_down_rules/utils.py,sha256=1fiSF4MOaOUrxlMz8sZA_e10258sMWuX5fG9WDawd2o,76674
8
+ ripple_down_rules/utils.py,sha256=SuZKBnGZvfDqlry7PY_N8AzLBOhEeuev57AVS-fJy78,77992
9
9
  ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
10
- ripple_down_rules/datastructures/callable_expression.py,sha256=IrlnufVsKrUDLVkc2owoFQ05oSOby3HiGuNXoFVj4Dw,13494
10
+ ripple_down_rules/datastructures/callable_expression.py,sha256=rzMrpD5oztaCRlt3hQ2B_xZ09cSuJNkYOCePndfQJRA,13684
11
11
  ripple_down_rules/datastructures/case.py,sha256=dfLnrjsHIVF2bgbz-4ID7OdQvw68V71btCeTK372P-g,15667
12
12
  ripple_down_rules/datastructures/dataclasses.py,sha256=3vX52WrAHgVyw0LUSgSBOVFaQNTSxU8hQpdr7cW-tSg,13278
13
13
  ripple_down_rules/datastructures/enums.py,sha256=CvcROl8fE7A6uTbMfs2lLpyxwS_ZFtFcQlBDDKFfoHc,6059
@@ -17,8 +17,8 @@ ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=RLdPqPxx-a0Sh74U
17
17
  ripple_down_rules/user_interface/object_diagram.py,sha256=FEa2HaYR9QmTE6NsOwBvZ0jqmu3DKyg6mig2VE5ZP4Y,4956
18
18
  ripple_down_rules/user_interface/prompt.py,sha256=WPbw_8_-8SpF2ISyRZRuFwPKBEuGC4HaX3lbCPFHhh8,10314
19
19
  ripple_down_rules/user_interface/template_file_creator.py,sha256=uSbosZS15MOR3Nv7M3MrFuoiKXyP4cBId-EK3I6stHM,13660
20
- ripple_down_rules-0.6.31.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
21
- ripple_down_rules-0.6.31.dist-info/METADATA,sha256=HRME3boNiuoTqo98-gzH7aK2-KCfc561e2eDDi09bhw,48294
22
- ripple_down_rules-0.6.31.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- ripple_down_rules-0.6.31.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
24
- ripple_down_rules-0.6.31.dist-info/RECORD,,
20
+ ripple_down_rules-0.6.40.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
21
+ ripple_down_rules-0.6.40.dist-info/METADATA,sha256=W9zg7je9jG0609I2OlUL-vqfhlTaX5rtibb8PLyaOXQ,48294
22
+ ripple_down_rules-0.6.40.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ ripple_down_rules-0.6.40.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
24
+ ripple_down_rules-0.6.40.dist-info/RECORD,,