ripple-down-rules 0.6.23__py3-none-any.whl → 0.6.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- __version__ = "0.6.23"
1
+ __version__ = "0.6.25"
2
2
 
3
3
  import logging
4
4
  logger = logging.Logger("rdr")
@@ -95,7 +95,7 @@ class CallableExpression(SubclassJSONSerializer):
95
95
  encapsulating_function_name: str = "_get_value"
96
96
 
97
97
  def __init__(self, user_input: Optional[str] = None,
98
- conclusion_type: Optional[Tuple[Type]] = None,
98
+ conclusion_type: Optional[Tuple[Type, ...]] = None,
99
99
  expression_tree: Optional[AST] = None,
100
100
  scope: Optional[Dict[str, Any]] = None,
101
101
  conclusion: Optional[Any] = None,
@@ -176,7 +176,8 @@ class CallableExpression(SubclassJSONSerializer):
176
176
  new_user_input = (f"{cond1_user_input}\n"
177
177
  f"{cond2_user_input}\n"
178
178
  f"return _cond1(case) and _cond2(case)")
179
- return CallableExpression(new_user_input, conclusion_type=self.conclusion_type)
179
+ return CallableExpression(new_user_input, conclusion_type=self.conclusion_type,
180
+ mutually_exclusive=self.mutually_exclusive)
180
181
 
181
182
  def update_user_input_from_file(self, file_path: str, function_name: str):
182
183
  """
@@ -31,7 +31,7 @@ class CaseQuery:
31
31
  """
32
32
  The name of the attribute.
33
33
  """
34
- _attribute_types: Tuple[Type]
34
+ _attribute_types: Tuple[Type, ...]
35
35
  """
36
36
  The type(s) of the attribute.
37
37
  """
@@ -139,7 +139,7 @@ class CaseQuery:
139
139
  attribute_types_str = f"Union[{', '.join([t.__name__ for t in self.core_attribute_type])}]"
140
140
  else:
141
141
  attribute_types_str = self.core_attribute_type[0].__name__
142
- if all(t in self.attribute_type for t in [list, set]) and len(self.core_attribute_type) > 2:
142
+ if not self.mutually_exclusive:
143
143
  return f"List[{attribute_types_str}]"
144
144
  else:
145
145
  return attribute_types_str
@@ -196,7 +196,10 @@ class RDREdge(Enum):
196
196
  """
197
197
  Next edge, the edge that represents the next rule to be evaluated.
198
198
  """
199
-
199
+ Filter = "filter if"
200
+ """
201
+ Filter edge, the edge that represents the filter condition.
202
+ """
200
203
 
201
204
  class ValueType(Enum):
202
205
  Unary = auto()
@@ -41,14 +41,14 @@ class Expert(ABC):
41
41
  A flag to indicate if the expert should use loaded answers or not.
42
42
  """
43
43
 
44
- def __init__(self, use_loaded_answers: bool = True,
44
+ def __init__(self, use_loaded_answers: bool = False,
45
45
  append: bool = False,
46
46
  answers_save_path: Optional[str] = None):
47
47
  self.all_expert_answers = []
48
48
  self.use_loaded_answers = use_loaded_answers
49
49
  self.append = append
50
50
  self.answers_save_path = answers_save_path
51
- if answers_save_path is not None:
51
+ if answers_save_path is not None and os.path.exists(answers_save_path + '.py'):
52
52
  if use_loaded_answers:
53
53
  self.load_answers(answers_save_path)
54
54
  else:
@@ -2,15 +2,14 @@ from __future__ import annotations
2
2
 
3
3
  import os
4
4
  from types import ModuleType
5
+ from typing import Tuple
5
6
 
6
- from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
7
-
8
- from .datastructures.case import create_case
9
- from .datastructures.dataclasses import CaseQuery
10
7
  from typing_extensions import Type, Optional, Callable, Any, Dict, TYPE_CHECKING, Union
11
8
 
12
- from .utils import get_func_rdr_model_name, copy_case, make_set, update_case
9
+ from .datastructures.case import create_case, Case
10
+ from .datastructures.dataclasses import CaseQuery
13
11
  from .utils import calculate_precision_and_recall
12
+ from .utils import get_func_rdr_model_name, copy_case, make_set, update_case
14
13
 
15
14
  if TYPE_CHECKING:
16
15
  from .rdr import RippleDownRules
@@ -55,12 +54,14 @@ def general_rdr_classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDow
55
54
  if attribute_name in new_conclusions:
56
55
  temp_case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, rdr.mutually_exclusive)
57
56
  update_case(temp_case_query, new_conclusions)
58
- if len(new_conclusions) == 0 or len(classifiers_dict) == 1 and list(classifiers_dict.values())[0].mutually_exclusive:
57
+ if len(new_conclusions) == 0 or len(classifiers_dict) == 1 and list(classifiers_dict.values())[
58
+ 0].mutually_exclusive:
59
59
  break
60
60
  return conclusions
61
61
 
62
62
 
63
- def is_matching(classifier: Callable[[Any], Any], case_query: CaseQuery, pred_cat: Optional[Dict[str, Any]] = None) -> bool:
63
+ def is_matching(classifier: Callable[[Any], Any], case_query: CaseQuery,
64
+ pred_cat: Optional[Dict[str, Any]] = None) -> bool:
64
65
  """
65
66
  :param classifier: The RDR classifier to check the prediction of.
66
67
  :param case_query: The case query to check.
@@ -95,3 +96,23 @@ def load_or_create_func_rdr_model(func, model_dir: str, rdr_type: Type[RippleDow
95
96
  else:
96
97
  rdr = rdr_type(**rdr_kwargs)
97
98
  return rdr
99
+
100
+
101
+ def get_an_updated_case_copy(case: Case, conclusion: Callable, attribute_name: str, conclusion_type: Tuple[Type, ...],
102
+ mutually_exclusive: bool) -> Case:
103
+ """
104
+ :param case: The case to copy and update.
105
+ :param conclusion: The conclusion to add to the case.
106
+ :param attribute_name: The name of the attribute to update.
107
+ :param conclusion_type: The type of the conclusion to update.
108
+ :param mutually_exclusive: Whether the rule belongs to a mutually exclusive RDR.
109
+ :return: A copy of the case updated with the given conclusion.
110
+ """
111
+ case_cp = copy_case(case)
112
+ temp_case_query = CaseQuery(case_cp, attribute_name, conclusion_type,
113
+ mutually_exclusive=mutually_exclusive)
114
+ output = conclusion(case_cp)
115
+ if not isinstance(output, Dict):
116
+ output = {attribute_name: output}
117
+ update_case(temp_case_query, output)
118
+ return case_cp
ripple_down_rules/rdr.py CHANGED
@@ -1,9 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import importlib
4
+ import json
4
5
  import os
5
6
  from abc import ABC, abstractmethod
6
7
  from copy import copy
8
+ from dataclasses import is_dataclass
7
9
  from types import NoneType
8
10
 
9
11
  from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
@@ -27,15 +29,16 @@ from .datastructures.case import Case, CaseAttribute, create_case
27
29
  from .datastructures.dataclasses import CaseQuery
28
30
  from .datastructures.enums import MCRDRMode
29
31
  from .experts import Expert, Human
30
- from .helpers import is_matching, general_rdr_classify
31
- from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
32
+ from .helpers import is_matching, general_rdr_classify, get_an_updated_case_copy
33
+ from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule, MultiClassRefinementRule, \
34
+ MultiClassFilterRule
32
35
 
33
36
  try:
34
37
  from .user_interface.gui import RDRCaseViewer
35
38
  except ImportError as e:
36
39
  RDRCaseViewer = None
37
40
  from .utils import draw_tree, make_set, SubclassJSONSerializer, make_list, get_type_from_string, \
38
- is_conflicting, extract_function_source, extract_imports, get_full_class_name, \
41
+ is_value_conflicting, extract_function_source, extract_imports, get_full_class_name, \
39
42
  is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types, render_tree
40
43
 
41
44
 
@@ -97,11 +100,12 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
97
100
  self.viewer = RDRCaseViewer.instances[0]
98
101
  logger.error("No viewer was provided, but there is already an existing viewer. "
99
102
  "Using the existing viewer.")
103
+ self.input_node: Optional[Rule] = None
100
104
 
101
105
  @property
102
106
  def viewer(self):
103
107
  return self._viewer
104
-
108
+
105
109
  @viewer.setter
106
110
  def viewer(self, value):
107
111
  self._viewer = value
@@ -110,23 +114,24 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
110
114
 
111
115
  def render_evaluated_rule_tree(self, filename: str, show_full_tree: bool = False) -> None:
112
116
  if show_full_tree:
113
- render_tree(self.start_rule, use_dot_exporter=True, filename=filename)
117
+ start_rule = self.start_rule if self.input_node is None else self.input_node
118
+ render_tree(start_rule, use_dot_exporter=True, filename=filename)
114
119
  else:
115
120
  evaluated_rules = self.get_evaluated_rule_tree()
116
121
  if evaluated_rules is not None and len(evaluated_rules) > 0:
117
122
  render_tree(evaluated_rules[0], use_dot_exporter=True, filename=filename,
118
123
  only_nodes=evaluated_rules)
119
124
 
120
- def get_evaluated_rule_tree(self) -> List[Rule]:
125
+ def get_evaluated_rule_tree(self) -> Optional[List[Rule]]:
121
126
  """
122
127
  Get the evaluated rule tree of the classifier.
123
128
 
124
129
  :return: The evaluated rule tree.
125
130
  """
126
131
  if self.start_rule is None:
127
- return
128
- # raise ValueError("The start rule is not set. Please set the start rule before getting the evaluated rule tree.")
129
- evaluated_rule_tree = [r for r in [self.start_rule] + list(self.start_rule.descendants) if r.evaluated]
132
+ return None
133
+ start_rule = self.start_rule
134
+ evaluated_rule_tree = [r for r in [start_rule] + list(start_rule.descendants) if r.evaluated]
130
135
  return evaluated_rule_tree
131
136
 
132
137
  def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None,
@@ -260,9 +265,38 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
260
265
  def __call__(self, case: Union[Case, SQLTable]) -> Union[CallableExpression, Dict[str, CallableExpression]]:
261
266
  return self.classify(case)
262
267
 
268
+ def classify(self, case: Union[Case, SQLTable], modify_case: bool = False, case_query: Optional[CaseQuery] = None) \
269
+ -> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
270
+ """
271
+ Classify a case using the RDR classifier.
272
+
273
+ :param case: The case to classify.
274
+ :param modify_case: Whether to modify the original case attributes with the conclusion or not.
275
+ :param case_query: The case query containing the case to classify and the target category to compare the case with.
276
+ :return: The category that the case belongs to.
277
+ """
278
+ if self.start_rule is not None:
279
+ for rule in [self.start_rule] + list(self.start_rule.descendants):
280
+ rule.evaluated = False
281
+ rule.fired = False
282
+ if self.start_rule is not None and self.start_rule.parent is None:
283
+ if self.input_node is None:
284
+ self.input_node = type(self.start_rule)(parent=None, uid='0')
285
+ self.input_node.evaluated = False
286
+ self.input_node.fired = False
287
+ self.start_rule.parent = self.input_node
288
+ self.start_rule.weight = ""
289
+ if self.input_node is not None:
290
+ data = case.__dict__ if is_dataclass(case) else case
291
+ if hasattr(case, "items"):
292
+ self.input_node.name = json.dumps({k: str(v) for k, v in data.items()}, indent=4)
293
+ else:
294
+ self.input_node.name = str(data)
295
+ return self._classify(case, modify_case=modify_case, case_query=case_query)
296
+
263
297
  @abstractmethod
264
- def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
265
- case_query: Optional[CaseQuery] = None) \
298
+ def _classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
299
+ case_query: Optional[CaseQuery] = None) \
266
300
  -> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
267
301
  """
268
302
  Classify a case.
@@ -467,25 +501,51 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
467
501
  conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys()
468
502
  if not isinstance(rules_dict[rid], MultiClassStopRule)]
469
503
  all_func_names = condition_func_names + conclusion_func_names
504
+ rule_tree_file_path = f"{model_dir}/{self.generated_python_file_name}.py"
470
505
  filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
471
506
  cases_path = f"{model_dir}/{self.generated_python_cases_file_name}.py"
472
507
  cases_import_path = get_import_path_from_path(model_dir)
473
508
  cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path \
474
509
  else self.generated_python_cases_file_name
475
510
  functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
511
+ python_rule_tree_source = ""
512
+ with open(rule_tree_file_path, "r") as rule_tree_source:
513
+ python_rule_tree_source = rule_tree_source.read()
476
514
  # get the scope from the imports in the file
477
515
  scope = extract_imports(filepath, package_name=package_name)
516
+ rules_not_found = set()
478
517
  for rule in [self.start_rule] + list(self.start_rule.descendants):
479
518
  if rule.conditions is not None:
480
- rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
519
+ conditions_name = rule.generated_conditions_function_name
520
+ if conditions_name not in functions_source or conditions_name not in python_rule_tree_source:
521
+ rules_not_found.add(rule)
522
+ continue
523
+ rule.conditions.user_input = functions_source[conditions_name]
481
524
  rule.conditions.scope = scope
482
525
  if os.path.exists(cases_path):
483
526
  module = importlib.import_module(cases_import_path, package=package_name)
484
527
  importlib.reload(module)
485
528
  rule.corner_case_metadata = module.__dict__.get(f"corner_case_{rule.uid}", None)
486
- if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
487
- rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
529
+ if not isinstance(rule, MultiClassStopRule):
530
+ conclusion_name = rule.generated_conclusion_function_name
531
+ if conclusion_name not in functions_source or conclusion_name not in python_rule_tree_source:
532
+ rules_not_found.add(rule)
533
+ rule.conclusion.user_input = functions_source[conclusion_name]
488
534
  rule.conclusion.scope = scope
535
+ for rule in rules_not_found:
536
+ if isinstance(rule, MultiClassTopRule):
537
+ import pdb; pdb.set_trace()
538
+ rule.parent.set_immediate_alternative(rule.alternative)
539
+ if rule.refinement is not None:
540
+ ref_rules = [ref_rule for ref_rule in [rule.refinement] + list(rule.refinement.descendants)]
541
+ for ref_rule in ref_rules:
542
+ del ref_rule
543
+ else:
544
+ rule.parent.refinement = rule.alternative
545
+ if rule.alternative is not None:
546
+ rule.alternative = None
547
+ rule.parent = None
548
+ del rule
489
549
 
490
550
  @abstractmethod
491
551
  def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
@@ -564,7 +624,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
564
624
  """
565
625
  pass
566
626
 
567
- def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
627
+ def _get_types_to_import(self) -> Tuple[Set[Union[Type, Callable]], Set[Type], Set[Type]]:
568
628
  """
569
629
  :return: The types of the main, defs, and corner cases files of the RDR classifier that will be imported.
570
630
  """
@@ -703,8 +763,8 @@ class SingleClassRDR(RDRWithCodeWriter):
703
763
  expert.ask_for_conditions(case_query)
704
764
  self.start_rule = SingleClassRule.from_case_query(case_query)
705
765
 
706
- def classify(self, case: Case, modify_case: bool = False,
707
- case_query: Optional[CaseQuery] = None) -> Optional[Any]:
766
+ def _classify(self, case: Case, modify_case: bool = False,
767
+ case_query: Optional[CaseQuery] = None) -> Optional[Any]:
708
768
  """
709
769
  Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
710
770
 
@@ -818,8 +878,8 @@ class MultiClassRDR(RDRWithCodeWriter):
818
878
  super(MultiClassRDR, self).__init__(start_rule, **kwargs)
819
879
  self.mode: MCRDRMode = mode
820
880
 
821
- def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
822
- case_query: Optional[CaseQuery] = None) -> Set[Any]:
881
+ def _classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
882
+ case_query: Optional[CaseQuery] = None) -> Set[Any]:
823
883
  evaluated_rule = self.start_rule
824
884
  self.conclusions = []
825
885
  while evaluated_rule:
@@ -897,6 +957,9 @@ class MultiClassRDR(RDRWithCodeWriter):
897
957
  if rule.alternative:
898
958
  self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
899
959
  cases_file=cases_file, package_name=package_name)
960
+ elif isinstance(rule, MultiClassTopRule):
961
+ with open(filename, "a") as file:
962
+ file.write(f"{parent_indent}return conclusions\n")
900
963
 
901
964
  @property
902
965
  def conclusion_type_hint(self) -> str:
@@ -906,8 +969,9 @@ class MultiClassRDR(RDRWithCodeWriter):
906
969
  else:
907
970
  return f"Set[Union[{', '.join(conclusion_types)}]]"
908
971
 
909
- def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
972
+ def _get_types_to_import(self) -> Tuple[Set[Union[Type, Callable]], Set[Type], Set[Type]]:
910
973
  main_types, defs_types, cases_types = super()._get_types_to_import()
974
+ main_types.add(get_an_updated_case_copy)
911
975
  main_types.update({Set, make_set})
912
976
  defs_types.update({List, Set})
913
977
  return main_types, defs_types, cases_types
@@ -939,28 +1003,43 @@ class MultiClassRDR(RDRWithCodeWriter):
939
1003
  Stop a wrong conclusion by adding a stopping rule.
940
1004
  """
941
1005
  rule_conclusion = evaluated_rule.conclusion(case_query.case)
942
- if is_conflicting(rule_conclusion, case_query.target_value):
943
- self.stop_conclusion(case_query, expert, evaluated_rule)
944
- else:
1006
+ stop: bool = False
1007
+ add_filter_rule: bool = False
1008
+ if is_value_conflicting(rule_conclusion, case_query.target_value):
1009
+ if make_set(case_query.target_value).issubset(rule_conclusion):
1010
+ add_filter_rule = True
1011
+ else:
1012
+ stop = True
1013
+ elif make_set(case_query.core_attribute_type).issubset(make_set(evaluated_rule.conclusion.conclusion_type)):
1014
+ if make_set(case_query.target_value).issubset(rule_conclusion):
1015
+ add_filter_rule = True
1016
+
1017
+ if not stop:
945
1018
  self.add_conclusion(rule_conclusion)
1019
+ if stop or add_filter_rule:
1020
+ refinement_type = MultiClassStopRule if stop else MultiClassFilterRule
1021
+ self.stop_or_filter_conclusion(case_query, expert, evaluated_rule, refinement_type=refinement_type)
946
1022
 
947
- def stop_conclusion(self, case_query: CaseQuery,
948
- expert: Expert, evaluated_rule: MultiClassTopRule):
1023
+ def stop_or_filter_conclusion(self, case_query: CaseQuery,
1024
+ expert: Expert, evaluated_rule: MultiClassTopRule,
1025
+ refinement_type: Type[MultiClassRefinementRule] = MultiClassStopRule):
949
1026
  """
950
1027
  Stop a conclusion by adding a stopping rule.
951
1028
 
952
1029
  :param case_query: The case query to stop the conclusion for.
953
1030
  :param expert: The expert to ask for differentiating features as new rule conditions.
954
1031
  :param evaluated_rule: The evaluated rule to ask the expert about.
1032
+ :param refinement_type: The refinement type to use.
955
1033
  """
956
1034
  conditions = expert.ask_for_conditions(case_query, evaluated_rule)
957
- evaluated_rule.fit_rule(case_query)
958
- if self.mode == MCRDRMode.StopPlusRule:
959
- self.stop_rule_conditions = conditions
960
- if self.mode == MCRDRMode.StopPlusRuleCombined:
961
- new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
962
- case_query.conditions = new_top_rule_conditions
963
- self.add_top_rule(case_query)
1035
+ evaluated_rule.fit_rule(case_query, refinement_type=refinement_type)
1036
+ if refinement_type is MultiClassStopRule:
1037
+ if self.mode == MCRDRMode.StopPlusRule:
1038
+ self.stop_rule_conditions = conditions
1039
+ if self.mode == MCRDRMode.StopPlusRuleCombined:
1040
+ new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
1041
+ case_query.conditions = new_top_rule_conditions
1042
+ self.add_top_rule(case_query)
964
1043
 
965
1044
  def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
966
1045
  """
@@ -1064,8 +1143,8 @@ class GeneralRDR(RippleDownRules):
1064
1143
  def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
1065
1144
  return [rdr.start_rule for rdr in self.start_rules_dict.values()]
1066
1145
 
1067
- def classify(self, case: Any, modify_case: bool = False,
1068
- case_query: Optional[CaseQuery] = None) -> Optional[Dict[str, Any]]:
1146
+ def _classify(self, case: Any, modify_case: bool = False,
1147
+ case_query: Optional[CaseQuery] = None) -> Optional[Dict[str, Any]]:
1069
1148
  """
1070
1149
  Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
1071
1150
  the classification until no more categories can be added.
@@ -8,15 +8,16 @@ from functools import wraps
8
8
 
9
9
  from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union, Sequence
10
10
 
11
- from ripple_down_rules.datastructures.case import Case
12
- from ripple_down_rules.datastructures.dataclasses import CaseQuery
13
- from ripple_down_rules.experts import Expert, Human
14
- from ripple_down_rules.rdr import GeneralRDR
11
+ from .datastructures.case import Case
12
+ from .datastructures.dataclasses import CaseQuery
13
+ from .experts import Expert, Human
14
+ from .rdr import GeneralRDR
15
+ from . import logger
15
16
  try:
16
- from ripple_down_rules.user_interface.gui import RDRCaseViewer
17
+ from .user_interface.gui import RDRCaseViewer
17
18
  except ImportError:
18
19
  RDRCaseViewer = None
19
- from ripple_down_rules.utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
20
+ from .utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
20
21
  get_method_class_if_exists, str_to_snake_case
21
22
 
22
23
 
@@ -3,7 +3,6 @@ from __future__ import annotations
3
3
  import logging
4
4
  import re
5
5
  from abc import ABC, abstractmethod
6
- from pathlib import Path
7
6
  from types import NoneType
8
7
  from uuid import uuid4
9
8
 
@@ -15,7 +14,8 @@ from .datastructures.callable_expression import CallableExpression
15
14
  from .datastructures.case import Case
16
15
  from .datastructures.dataclasses import CaseFactoryMetaData, CaseQuery
17
16
  from .datastructures.enums import RDREdge, Stop
18
- from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_imports_from_types
17
+ from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_type_from_string
18
+ from .helpers import get_an_updated_case_copy
19
19
 
20
20
 
21
21
  class Rule(NodeMixin, SubclassJSONSerializer, ABC):
@@ -23,6 +23,10 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
23
23
  """
24
24
  Whether the rule has fired or not.
25
25
  """
26
+ mutually_exclusive: bool
27
+ """
28
+ Whether the rule is mutually exclusive with other rules.
29
+ """
26
30
 
27
31
  def __init__(self, conditions: Optional[CallableExpression] = None,
28
32
  conclusion: Optional[CallableExpression] = None,
@@ -60,6 +64,14 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
60
64
  self.evaluated: bool = False
61
65
  self._user_defined_name: Optional[str] = None
62
66
 
67
+ def get_an_updated_case_copy(self, case: Case) -> Case:
68
+ """
69
+ :param case: The case to copy and update.
70
+ :return: A copy of the case updated with this rule conclusion.
71
+ """
72
+ return get_an_updated_case_copy(case, self.conclusion, self.conclusion_name, self.conclusion.conclusion_type,
73
+ self.mutually_exclusive)
74
+
63
75
  @property
64
76
  def color(self) -> str:
65
77
  if self.evaluated:
@@ -78,22 +90,27 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
78
90
  if self._user_defined_name is None:
79
91
  if self.conditions and self.conditions.user_input and "def " in self.conditions.user_input:
80
92
  # If the conditions have a user input, use it as the name
81
- self._user_defined_name = self.conditions.user_input.split('(')[0].replace('def ', '').strip()
93
+ func_name = self.conditions.user_input.split('(')[0].replace('def ', '').strip()
94
+ if func_name == self.conditions.encapsulating_function_name:
95
+ self._user_defined_name = str(self.conditions)
96
+ else:
97
+ self._user_defined_name = func_name
82
98
  else:
83
99
  self._user_defined_name = f"Rule_{self.uid}"
84
100
  return self._user_defined_name
85
101
 
86
102
  @classmethod
87
- def from_case_query(cls, case_query: CaseQuery) -> Rule:
103
+ def from_case_query(cls, case_query: CaseQuery, parent: Optional[Rule] = None) -> Rule:
88
104
  """
89
105
  Create a SingleClassRule from a CaseQuery.
90
106
 
91
107
  :param case_query: The CaseQuery to create the rule from.
108
+ :param parent: The parent rule of this rule.
92
109
  :return: A SingleClassRule instance.
93
110
  """
94
111
  corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
95
112
  return cls(conditions=case_query.conditions, conclusion=case_query.target,
96
- corner_case=case_query.case, parent=None,
113
+ corner_case=case_query.case, parent=parent,
97
114
  corner_case_metadata=corner_case_metadata,
98
115
  conclusion_name=case_query.attribute_name)
99
116
 
@@ -116,9 +133,6 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
116
133
  :param x: The case to evaluate the rule on.
117
134
  :return: The rule that fired or the last evaluated rule if no rule fired.
118
135
  """
119
- if self.root is self:
120
- for descendant in self.descendants:
121
- descendant.evaluated = False
122
136
  self.evaluated = True
123
137
  if not self.conditions:
124
138
  raise ValueError("Rule has no conditions")
@@ -178,6 +192,14 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
178
192
  f.write(conclusion_func.strip() + "\n\n\n")
179
193
  return conclusion_func_call
180
194
 
195
+ @property
196
+ def generated_conclusion_function_name(self) -> str:
197
+ return f"conclusion_{self.uid}"
198
+
199
+ @property
200
+ def generated_conditions_function_name(self) -> str:
201
+ return f"conditions_{self.uid}"
202
+
181
203
  def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
182
204
  """
183
205
  Convert the conclusion of a rule to source code.
@@ -190,23 +212,24 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
190
212
  # This means the conclusion is a definition that should be written and then called
191
213
  conclusion_lines = conclusion.split('\n')
192
214
  # use regex to replace the function name
193
- new_function_name = f"def conclusion_{self.uid}"
215
+ new_function_name = f"def {self.generated_conclusion_function_name}"
194
216
  conclusion_lines[0] = re.sub(r"def (\w+)", new_function_name, conclusion_lines[0])
195
217
  # add type hint
196
- if len(self.conclusion.conclusion_type) == 1:
197
- hint = self.conclusion.conclusion_type[0].__name__
218
+ if not self.conclusion.mutually_exclusive:
219
+ type_names = [t.__name__ for t in self.conclusion.conclusion_type if t not in [list, set]]
220
+ if len(type_names) == 1:
221
+ hint = f"List[{type_names[0]}]"
222
+ else:
223
+ hint = f"List[Union[{', '.join(type_names)}]]"
198
224
  else:
199
- if (all(t in self.conclusion.conclusion_type for t in [list, set])
200
- and len(self.conclusion.conclusion_type) > 2):
201
- type_names = [t.__name__ for t in self.conclusion.conclusion_type if t not in [list, set]]
202
- hint = f"List[{', '.join(type_names)}]"
225
+ if NoneType in self.conclusion.conclusion_type:
226
+ type_names = [t.__name__ for t in self.conclusion.conclusion_type if t is not NoneType]
227
+ hint = f"Optional[{', '.join(type_names)}]"
228
+ elif len(self.conclusion.conclusion_type) == 1:
229
+ hint = self.conclusion.conclusion_type[0].__name__
203
230
  else:
204
- if NoneType in self.conclusion.conclusion_type:
205
- type_names = [t.__name__ for t in self.conclusion.conclusion_type if t is not NoneType]
206
- hint = f"Optional[{', '.join(type_names)}]"
207
- else:
208
- type_names = [t.__name__ for t in self.conclusion.conclusion_type]
209
- hint = f"Union[{', '.join(type_names)}]"
231
+ type_names = [t.__name__ for t in self.conclusion.conclusion_type]
232
+ hint = f"Union[{', '.join(type_names)}]"
210
233
  conclusion_lines[0] = conclusion_lines[0].replace("):", f") -> {hint}:")
211
234
  func_call = f"{parent_indent} return {new_function_name.replace('def ', '')}(case)\n"
212
235
  return "\n".join(conclusion_lines).strip(' '), func_call
@@ -228,7 +251,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
228
251
  # This means the conditions are a definition that should be written and then called
229
252
  conditions_lines = self.conditions.user_input.split('\n')
230
253
  # use regex to replace the function name
231
- new_function_name = f"def conditions_{self.uid}"
254
+ new_function_name = f"def {self.generated_conditions_function_name}"
232
255
  conditions_lines[0] = re.sub(r"def (\w+)", new_function_name, conditions_lines[0])
233
256
  # add type hint
234
257
  conditions_lines[0] = conditions_lines[0].replace('):', ') -> bool:')
@@ -310,10 +333,11 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
310
333
  Get the name of the expression, which is the user input of the expression if it exists,
311
334
  otherwise it is the conclusion or conditions of the rule.
312
335
  """
313
- if expression.user_defined_name is not None:
336
+ if expression.user_defined_name is not None and expression.user_defined_name != expression.encapsulating_function_name:
314
337
  return expression.user_defined_name.strip()
315
- elif expression.user_input and "def " in expression.user_input:
316
- return expression.user_input.split('(')[0].replace('def ', '').strip()
338
+ func_name = expression.user_input.split('(')[0].replace('def ', '').strip() if "def " in expression.user_input else None
339
+ if func_name is not None and func_name != expression.encapsulating_function_name:
340
+ return func_name
317
341
  elif expression.user_input:
318
342
  return expression.user_input.strip()
319
343
  else:
@@ -350,6 +374,9 @@ class HasAlternativeRule:
350
374
  def alternative(self) -> Optional[Rule]:
351
375
  return self._alternative
352
376
 
377
+ def set_immediate_alternative(self, alternative: Optional[Rule]):
378
+ self._alternative = alternative
379
+
353
380
  @alternative.setter
354
381
  def alternative(self, new_rule: Rule):
355
382
  """
@@ -385,12 +412,11 @@ class HasRefinementRule:
385
412
  """
386
413
  if new_rule is None:
387
414
  return
388
- new_rule.top_rule = self
389
415
  if self.refinement:
390
416
  self.refinement.alternative = new_rule
391
417
  else:
392
418
  new_rule.parent = self
393
- new_rule.weight = RDREdge.Refinement.value
419
+ new_rule.weight = RDREdge.Refinement.value if not isinstance(new_rule, MultiClassFilterRule) else new_rule.weight
394
420
  self._refinement = new_rule
395
421
 
396
422
 
@@ -399,6 +425,8 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
399
425
  A rule in the SingleClassRDR classifier, it can have a refinement or an alternative rule or both.
400
426
  """
401
427
 
428
+ mutually_exclusive: bool = True
429
+
402
430
  def evaluate_next_rule(self, x: Case) -> SingleClassRule:
403
431
  if self.fired:
404
432
  returned_rule = self.refinement(x) if self.refinement else self
@@ -434,28 +462,15 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
434
462
  return "elif" if self.weight == RDREdge.Alternative.value else "if"
435
463
 
436
464
 
437
- class MultiClassStopRule(Rule, HasAlternativeRule):
465
+ class MultiClassRefinementRule(Rule, HasAlternativeRule, ABC):
438
466
  """
439
- A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule,
440
- the conclusion of the rule is a Stop category meant to stop the parent conclusion from being made.
467
+ A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule.
441
468
  """
442
469
  top_rule: Optional[MultiClassTopRule] = None
443
470
  """
444
471
  The top rule of the rule, which is the nearest ancestor that fired and this rule is a refinement of.
445
472
  """
446
-
447
- def __init__(self, *args, **kwargs):
448
- super(MultiClassStopRule, self).__init__(*args, **kwargs)
449
- self.conclusion = CallableExpression(conclusion_type=(Stop,), conclusion=Stop.stop)
450
-
451
- def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassStopRule, MultiClassTopRule]]:
452
- if self.fired:
453
- self.top_rule.fired = False
454
- return self.top_rule.alternative
455
- elif self.alternative:
456
- return self.alternative(x)
457
- else:
458
- return self.top_rule.alternative
473
+ mutually_exclusive: bool = False
459
474
 
460
475
  def _to_json(self) -> Dict[str, Any]:
461
476
  self.json_serialization = {**Rule._to_json(self),
@@ -463,27 +478,106 @@ class MultiClassStopRule(Rule, HasAlternativeRule):
463
478
  return self.json_serialization
464
479
 
465
480
  @classmethod
466
- def _from_json(cls, data: Dict[str, Any]) -> MultiClassStopRule:
467
- loaded_rule = super(MultiClassStopRule, cls)._from_json(data)
481
+ def _from_json(cls, data: Dict[str, Any]) -> MultiClassRefinementRule:
482
+ loaded_rule = super(MultiClassRefinementRule, cls)._from_json(data)
468
483
  # The following is done to prevent re-initialization of the top rule,
469
484
  # so the top rule that is already initialized is passed in the data instead of its json serialization.
470
485
  loaded_rule.top_rule = data['top_rule']
471
486
  if data['alternative'] is not None:
472
487
  data['alternative']['top_rule'] = data['top_rule']
473
- loaded_rule.alternative = MultiClassStopRule.from_json(data["alternative"])
488
+ loaded_rule.alternative = SubclassJSONSerializer.from_json(data["alternative"])
474
489
  return loaded_rule
475
490
 
491
+ def _if_statement_source_code_clause(self) -> str:
492
+ return "elif" if self.weight == RDREdge.Alternative.value else "if"
493
+
494
+
495
+ class MultiClassStopRule(MultiClassRefinementRule):
496
+ """
497
+ A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule,
498
+ the conclusion of the rule is a Stop category meant to stop the parent conclusion from being made.
499
+ """
500
+
501
+ def __init__(self, *args, **kwargs):
502
+ super(MultiClassRefinementRule, self).__init__(*args, **kwargs)
503
+ self.conclusion = CallableExpression(conclusion_type=(Stop,), conclusion=Stop.stop)
504
+
505
+ def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassRefinementRule, MultiClassTopRule]]:
506
+ if self.fired:
507
+ self.top_rule.fired = False
508
+ return self.top_rule.alternative
509
+ elif self.alternative:
510
+ return self.alternative(x)
511
+ else:
512
+ return self.top_rule.alternative
513
+
476
514
  def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[None, str]:
477
515
  return None, f"{parent_indent}{' ' * 4}pass\n"
478
516
 
479
- def _if_statement_source_code_clause(self) -> str:
480
- return "elif" if self.weight == RDREdge.Alternative.value else "if"
517
+
518
+ class MultiClassFilterRule(MultiClassRefinementRule, HasRefinementRule):
519
+ """
520
+ A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule,
521
+ the conclusion of the rule is a Filter category meant to filter the parent conclusion.
522
+ """
523
+
524
+ def __init__(self, *args, **kwargs):
525
+ super(MultiClassRefinementRule, self).__init__(*args, **kwargs)
526
+ self.weight = RDREdge.Filter.value
527
+
528
+ def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassRefinementRule, MultiClassTopRule]]:
529
+ if self.fired:
530
+ if self.refinement:
531
+ case_cp = x
532
+ if isinstance(self.refinement, MultiClassFilterRule):
533
+ case_cp = self.get_an_updated_case_copy(case_cp)
534
+ return self.refinement(case_cp)
535
+ else:
536
+ return self.top_rule.alternative
537
+ elif self.alternative:
538
+ return self.alternative(x)
539
+ else:
540
+ return self.top_rule.alternative
541
+
542
+ def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[None, str]:
543
+ func, func_call = super().get_conclusion_as_source_code(str(conclusion), parent_indent=parent_indent)
544
+ conclusion_str = func_call.replace("return ", "").strip()
545
+ conclusion_str = conclusion_str.replace("(case)", "(case_cp)")
546
+
547
+ parent_to_filter = self.get_parent_to_filter()
548
+ statement = (
549
+ f"\n{parent_indent} case_cp = get_an_updated_case_copy(case, {parent_to_filter.generated_conclusion_function_name},"
550
+ f" attribute_name, conclusion_type, mutually_exclusive)")
551
+ statement += f"\n{parent_indent} conclusions.update(make_set({conclusion_str}))\n"
552
+ return func, statement
553
+
554
+ def get_parent_to_filter(self, parent: Union[None, MultiClassRefinementRule, MultiClassTopRule] = None) \
555
+ -> Union[MultiClassFilterRule, MultiClassTopRule]:
556
+ parent = self.parent if parent is None else parent
557
+ if isinstance(parent, (MultiClassFilterRule, MultiClassTopRule)) and parent.fired:
558
+ return parent
559
+ else:
560
+ return parent.parent
561
+
562
+ def _to_json(self) -> Dict[str, Any]:
563
+ self.json_serialization = super(MultiClassFilterRule, self)._to_json()
564
+ self.json_serialization['refinement'] = self.refinement.to_json() if self.refinement is not None else None
565
+ return self.json_serialization
566
+
567
+ @classmethod
568
+ def _from_json(cls, data: Dict[str, Any]) -> MultiClassFilterRule:
569
+ loaded_rule = super(MultiClassFilterRule, cls)._from_json(data)
570
+ if data['refinement'] is not None:
571
+ data['refinement']['top_rule'] = data['top_rule']
572
+ loaded_rule.refinement = cls.from_json(data["refinement"]) if data["refinement"] is not None else None
573
+ return loaded_rule
481
574
 
482
575
 
483
576
  class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
484
577
  """
485
578
  A rule in the MultiClassRDR classifier, it can have a refinement and a next rule.
486
579
  """
580
+ mutually_exclusive: bool = False
487
581
 
488
582
  def __init__(self, *args, **kwargs):
489
583
  super(MultiClassTopRule, self).__init__(*args, **kwargs)
@@ -491,16 +585,27 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
491
585
 
492
586
  def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassStopRule, MultiClassTopRule]]:
493
587
  if self.fired and self.refinement:
494
- return self.refinement(x)
588
+ case_cp = x
589
+ if isinstance(self.refinement, MultiClassFilterRule):
590
+ case_cp = self.get_an_updated_case_copy(case_cp)
591
+ return self.refinement(case_cp)
495
592
  elif self.alternative: # Here alternative refers to next rule in MultiClassRDR
496
593
  return self.alternative
594
+ return None
497
595
 
498
- def fit_rule(self, case_query: CaseQuery):
596
+ def fit_rule(self, case_query: CaseQuery, refinement_type: Optional[Type[MultiClassRefinementRule]] = None):
499
597
  if self.fired and case_query.target != self.conclusion:
500
- self.refinement = MultiClassStopRule(case_query.conditions, corner_case=case_query.case, parent=self)
598
+ if refinement_type in [None, MultiClassStopRule]:
599
+ new_rule = MultiClassStopRule(case_query.conditions, corner_case=case_query.case,
600
+ parent=self)
601
+ elif refinement_type is MultiClassFilterRule:
602
+ new_rule = MultiClassFilterRule.from_case_query(case_query, parent=self)
603
+ else:
604
+ raise ValueError(f"Unknown refinement type {refinement_type}")
605
+ new_rule.top_rule = self
606
+ self.refinement = new_rule
501
607
  elif not self.fired:
502
- self.alternative = MultiClassTopRule(case_query.conditions, case_query.target,
503
- corner_case=case_query.case, parent=self)
608
+ self.alternative = MultiClassTopRule.from_case_query(case_query, parent=self)
504
609
 
505
610
  def _to_json(self) -> Dict[str, Any]:
506
611
  self.json_serialization = {**Rule._to_json(self),
@@ -515,7 +620,8 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
515
620
  # so the top rule that is already initialized is passed in the data instead of its json serialization.
516
621
  if data['refinement'] is not None:
517
622
  data['refinement']['top_rule'] = loaded_rule
518
- loaded_rule.refinement = MultiClassStopRule.from_json(data["refinement"])
623
+ data_type = get_type_from_string(data["refinement"]["_type"])
624
+ loaded_rule.refinement = data_type.from_json(data["refinement"])
519
625
  loaded_rule.alternative = MultiClassTopRule.from_json(data["alternative"])
520
626
  return loaded_rule
521
627
 
@@ -524,8 +630,6 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
524
630
  conclusion_str = func_call.replace("return ", "").strip()
525
631
 
526
632
  statement = f"{parent_indent} conclusions.update(make_set({conclusion_str}))\n"
527
- if self.alternative is None:
528
- statement += f"{parent_indent}return conclusions\n"
529
633
  return func, statement
530
634
 
531
635
  def _if_statement_source_code_clause(self) -> str:
@@ -50,6 +50,7 @@ class UserPrompt:
50
50
  :return: A callable expression that takes a case and executes user expression on it.
51
51
  """
52
52
  prev_user_input: Optional[str] = None
53
+ user_input_to_modify: Optional[str] = None
53
54
  callable_expression: Optional[CallableExpression] = None
54
55
  while True:
55
56
  with self.shell_lock:
@@ -69,12 +70,14 @@ class UserPrompt:
69
70
  conclusion_type = bool if prompt_for == PromptFor.Conditions else case_query.attribute_type
70
71
  callable_expression = CallableExpression(user_input, conclusion_type, expression_tree=expression_tree,
71
72
  scope=case_query.scope,
72
- mutually_exclusive=case_query.mutually_exclusive)
73
+ mutually_exclusive=case_query.mutually_exclusive)
73
74
  try:
74
75
  result = callable_expression(case_query.case)
75
- if len(make_list(result)) == 0:
76
+ if len(make_list(result)) == 0 and (user_input_to_modify is not None
77
+ and (prev_user_input != user_input_to_modify)):
78
+ user_input_to_modify = prev_user_input
76
79
  self.print_func(f"{Fore.YELLOW}The given expression gave an empty result for case {case_query.name}."
77
- f" Please modify!{Style.RESET_ALL}")
80
+ f" Please accept or modify!{Style.RESET_ALL}")
78
81
  continue
79
82
  break
80
83
  except Exception as e:
@@ -8,6 +8,7 @@ import json
8
8
  import logging
9
9
  import os
10
10
  import re
11
+ import shutil
11
12
  import sys
12
13
  import threading
13
14
  import uuid
@@ -24,6 +25,7 @@ from types import NoneType
24
25
 
25
26
  import six
26
27
  from sqlalchemy.exc import NoInspectionAvailable
28
+ from src.pycram.ros import logwarn
27
29
 
28
30
  try:
29
31
  import matplotlib
@@ -44,12 +46,12 @@ except ImportError as e:
44
46
 
45
47
  import requests
46
48
  from anytree import Node, RenderTree, PreOrderIter
47
- from anytree.exporter import DotExporter
48
49
  from sqlalchemy import MetaData, inspect
49
50
  from sqlalchemy.orm import Mapped, registry, class_mapper, DeclarativeBase as SQLTable, Session
50
51
  from tabulate import tabulate
51
52
  from typing_extensions import Callable, Set, Any, Type, Dict, TYPE_CHECKING, get_type_hints, \
52
53
  get_origin, get_args, Tuple, Optional, List, Union, Self, ForwardRef, Iterable
54
+ from . import logger
53
55
 
54
56
  if TYPE_CHECKING:
55
57
  from .datastructures.case import Case
@@ -180,8 +182,8 @@ def extract_function_source(file_path: str,
180
182
  if (len(functions_source) >= len(function_names)) and (not len(function_names) == 0):
181
183
  break
182
184
  if len(functions_source) < len(function_names):
183
- raise ValueError(f"Could not find all functions in {file_path}: {function_names} not found,"
184
- f"functions not found: {set(function_names) - set(functions_source.keys())}")
185
+ logwarn(f"Could not find all functions in {file_path}: {function_names} not found, "
186
+ f"functions not found: {set(function_names) - set(functions_source.keys())}")
185
187
  if return_line_numbers:
186
188
  return functions_source, line_numbers
187
189
  return functions_source
@@ -285,7 +287,7 @@ def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
285
287
  case_query.case.update(conclusions)
286
288
 
287
289
 
288
- def is_conflicting(conclusion: Any, target: Any) -> bool:
290
+ def is_value_conflicting(conclusion: Any, target: Any) -> bool:
289
291
  """
290
292
  :param conclusion: The conclusion to check.
291
293
  :param target: The target to compare the conclusion with.
@@ -845,10 +847,12 @@ def get_relative_import(target_file_path, imported_module_path: Optional[str] =
845
847
  imported_file_name = Path(imported_module_path).name
846
848
  target_file_name = Path(target_file_path).name
847
849
  if package_name is not None:
848
- target_path = Path(get_path_starting_from_latest_encounter_of(str(target_path), package_name, [target_file_name]))
850
+ target_path = Path(
851
+ get_path_starting_from_latest_encounter_of(str(target_path), package_name, [target_file_name]))
849
852
  imported_path = Path(imported_module_path).resolve()
850
853
  if package_name is not None:
851
- imported_path = Path(get_path_starting_from_latest_encounter_of(str(imported_path), package_name, [imported_file_name]))
854
+ imported_path = Path(
855
+ get_path_starting_from_latest_encounter_of(str(imported_path), package_name, [imported_file_name]))
852
856
 
853
857
  # Compute relative path from target to imported module
854
858
  rel_path = os.path.relpath(imported_path.parent, target_path.parent)
@@ -925,8 +929,8 @@ def get_imports_from_types(type_objs: Iterable[Type],
925
929
  continue
926
930
  if name == "NoneType":
927
931
  module = "types"
928
- if module is None or module == 'builtins' or module.startswith('_')\
929
- or module in sys.builtin_module_names or module in excluded_modules or "<" in module \
932
+ if module is None or module == 'builtins' or module.startswith('_') \
933
+ or module in sys.builtin_module_names or module in excluded_modules or "<" in module \
930
934
  or name in exclueded_names:
931
935
  continue
932
936
  if module == "typing":
@@ -1216,7 +1220,8 @@ class SubclassJSONSerializer:
1216
1220
  return cls._from_json(data)
1217
1221
  for subclass in recursive_subclasses(SubclassJSONSerializer):
1218
1222
  if get_full_class_name(subclass) == data["_type"]:
1219
- subclass_data = deepcopy(data)
1223
+ # subclass_data = deepcopy(data)
1224
+ subclass_data = data
1220
1225
  subclass_data.pop("_type")
1221
1226
  return subclass._from_json(subclass_data)
1222
1227
 
@@ -1408,7 +1413,11 @@ def table_rows_as_str(row_dicts: List[Dict[str, Any]], columns_per_row: int = 20
1408
1413
  row_values = [list(map(lambda v: v[:max_line_sze] + '...' if len(v) > max_line_sze else v, row)) for row in
1409
1414
  row_values]
1410
1415
  row_values = [list(map(lambda v: v.lower() if v in ["True", "False"] else v, row)) for row in row_values]
1411
- table = tabulate(row_values, tablefmt='simple_grid', maxcolwidths=[max_line_sze] * 2)
1416
+ # Step 1: Get terminal size
1417
+ terminal_width = shutil.get_terminal_size((80, 20)).columns
1418
+ # Step 2: Dynamically calculate max width per column (simple approximation)
1419
+ max_col_width = terminal_width // len(row_values[0])
1420
+ table = tabulate(row_values, tablefmt='simple_grid', maxcolwidths=max_col_width) # [max_line_sze] * 2)
1412
1421
  all_table_rows.append(table)
1413
1422
  return "\n".join(all_table_rows)
1414
1423
 
@@ -1628,12 +1637,14 @@ def edge_attr_setter(parent, child):
1628
1637
  """
1629
1638
  Set the edge attributes for the dot exporter.
1630
1639
  """
1631
- if child and hasattr(child, "weight") and child.weight:
1640
+ if child and hasattr(child, "weight") and child.weight is not None:
1632
1641
  return f'style="bold", label=" {child.weight}"'
1633
1642
  return ""
1634
1643
 
1635
1644
 
1636
1645
  _RE_ESC = re.compile(r'["\\]')
1646
+
1647
+
1637
1648
  class FilteredDotExporter(object):
1638
1649
 
1639
1650
  def __init__(self, node, include_nodes=None, graph="digraph", name="tree", options=None,
@@ -1913,7 +1924,7 @@ class FilteredDotExporter(object):
1913
1924
 
1914
1925
 
1915
1926
  def render_tree(root: Node, use_dot_exporter: bool = False,
1916
- filename: str = "scrdr", only_nodes: List[Node] = None):
1927
+ filename: str = "scrdr", only_nodes: List[Node] = None, show_in_console: bool = False):
1917
1928
  """
1918
1929
  Render the tree using the console and optionally export it to a dot file.
1919
1930
 
@@ -1921,12 +1932,16 @@ def render_tree(root: Node, use_dot_exporter: bool = False,
1921
1932
  :param use_dot_exporter: Whether to export the tree to a dot file.
1922
1933
  :param filename: The name of the file to export the tree to.
1923
1934
  :param only_nodes: A list of nodes to include in the dot export.
1935
+ :param show_in_console: Whether to print the tree to the console.
1924
1936
  """
1925
1937
  if not root:
1926
- logging.warning("No rules to render")
1938
+ logger.warning("No rules to render")
1927
1939
  return
1928
- # for pre, _, node in RenderTree(root):
1929
- # print(f"{pre}{node.weight if hasattr(node, 'weight') and node.weight else ''} {node.__str__()}")
1940
+ if show_in_console:
1941
+ for pre, _, node in RenderTree(root):
1942
+ if only_nodes is not None and node not in only_nodes:
1943
+ continue
1944
+ print(f"{pre}{node.weight if hasattr(node, 'weight') and node.weight else ''} {node.__str__()}")
1930
1945
  if use_dot_exporter:
1931
1946
  unique_node_names = get_unique_node_names_func(root)
1932
1947
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.6.23
3
+ Version: 0.6.25
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -0,0 +1,24 @@
1
+ ripple_down_rules/__init__.py,sha256=hhuChAaI4pksG6FgW6-mLaJhifqkw_2Dy21ccjHfIFs,100
2
+ ripple_down_rules/experts.py,sha256=MYK1-vuvU1Stp82YZa8TcwOzvriIiYb0WrPFpWUNnXc,13005
3
+ ripple_down_rules/helpers.py,sha256=X1psHOqrb4_xYN4ssQNB8S9aRKKsqgihAyWJurN0dqk,5499
4
+ ripple_down_rules/rdr.py,sha256=avufEkijvJQCji1S5O1zaDSS7aIoP7Yjefy_4uK1TII,62125
5
+ ripple_down_rules/rdr_decorators.py,sha256=TRhbaB_ZIXN0n8Up19NI43_mMjmTm24qo8axAAOzbxM,11728
6
+ ripple_down_rules/rules.py,sha256=N4dEx-xyqxGZpoEYzRd9P5u97_DcDEVLY_UiNhZ4E7g,28726
7
+ ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
8
+ ripple_down_rules/utils.py,sha256=CD-J33gKHUSGU-hZdVvlGtCvsuID7b1pirv5ejpkEP0,73891
9
+ ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
10
+ ripple_down_rules/datastructures/callable_expression.py,sha256=P3o-z54Jt4rtIczeFWiuHFTNqMzYEOm94OyOP535D6Q,13378
11
+ ripple_down_rules/datastructures/case.py,sha256=PJ7_-AdxYic6BO5z816piFODj6nU5J6Jt1YzTFH-dds,15510
12
+ ripple_down_rules/datastructures/dataclasses.py,sha256=kI3Kv8GiVR8igMgA_BlKN6djUYxC2mLecvyh19pqQQA,10998
13
+ ripple_down_rules/datastructures/enums.py,sha256=R9AkhMKTDErOSZ8J6gEdh2lQ0Bjsxs22eMBtCPrXosI,5804
14
+ ripple_down_rules/user_interface/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
+ ripple_down_rules/user_interface/gui.py,sha256=druufu9cVeVUajPW-RqGW3ZiEbgdgNBQD2CLhadQo18,27486
16
+ ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=yp-F8YRWGhj1PLB33HE6vJkdYWFN5Zn2244S2DUWRTM,6576
17
+ ripple_down_rules/user_interface/object_diagram.py,sha256=FEa2HaYR9QmTE6NsOwBvZ0jqmu3DKyg6mig2VE5ZP4Y,4956
18
+ ripple_down_rules/user_interface/prompt.py,sha256=e5FzVfiIagwKTK3WCKsHvWaWZ4kb8FP8X-SgieTln6E,9156
19
+ ripple_down_rules/user_interface/template_file_creator.py,sha256=kwBbFLyN6Yx2NTIHPSwOoytWgbJDYhgrUOVFw_jkDQ4,13522
20
+ ripple_down_rules-0.6.25.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
21
+ ripple_down_rules-0.6.25.dist-info/METADATA,sha256=OCHGDEQ_jkpOq8iUIZONmvsMGXHPdsfzfuTSZOJUb_k,48294
22
+ ripple_down_rules-0.6.25.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ ripple_down_rules-0.6.25.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
24
+ ripple_down_rules-0.6.25.dist-info/RECORD,,
@@ -1,24 +0,0 @@
1
- ripple_down_rules/__init__.py,sha256=Y8dJbmLSU5jLm9E0jiReo5eIcp03jmLvx6sNfRfi49M,100
2
- ripple_down_rules/experts.py,sha256=4-dMIVeMzFXCLYl_XBG_P7_Xs4sZih9-vZxCIPri6dA,12958
3
- ripple_down_rules/helpers.py,sha256=RUdfiSWMZjGwCxuCy44TcEJf2UNAFlPJusgHzuAs6qI,4583
4
- ripple_down_rules/rdr.py,sha256=sibe0amvb2MuVFRGjADgel6a-K0bp8fpJQRLnRuYW-k,57803
5
- ripple_down_rules/rdr_decorators.py,sha256=0xVM-yRZJ6BrpyMfQNBuFUKT_-me6tbd4UTgG5Exx2g,11809
6
- ripple_down_rules/rules.py,sha256=2d2Fgo5urVvUoIJWdRyaojcFBj_lfhLZaBVrXwb5KLA,23764
7
- ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
8
- ripple_down_rules/utils.py,sha256=VM8vrshOJIqXAY48dghu86KSGpYQoGGDMEiq3GrLXsQ,73319
9
- ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
10
- ripple_down_rules/datastructures/callable_expression.py,sha256=ysK-4JmZ4oSUTJC7zpo_o77g4ONxPDEcIpSWggsnx3c,13320
11
- ripple_down_rules/datastructures/case.py,sha256=PJ7_-AdxYic6BO5z816piFODj6nU5J6Jt1YzTFH-dds,15510
12
- ripple_down_rules/datastructures/dataclasses.py,sha256=D-nrVEW_27njmDGkyiHRnq5lmqEdO8RHKnLb1mdnwrA,11054
13
- ripple_down_rules/datastructures/enums.py,sha256=ce7tqS0otfSTNAOwsnXlhsvIn4iW_Y_N3TNebF3YoZs,5700
14
- ripple_down_rules/user_interface/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
- ripple_down_rules/user_interface/gui.py,sha256=druufu9cVeVUajPW-RqGW3ZiEbgdgNBQD2CLhadQo18,27486
16
- ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=yp-F8YRWGhj1PLB33HE6vJkdYWFN5Zn2244S2DUWRTM,6576
17
- ripple_down_rules/user_interface/object_diagram.py,sha256=FEa2HaYR9QmTE6NsOwBvZ0jqmu3DKyg6mig2VE5ZP4Y,4956
18
- ripple_down_rules/user_interface/prompt.py,sha256=JceEUGYsd0lIvd-v2y3D3swoo96_C0lxfp3CxM7Vfts,8900
19
- ripple_down_rules/user_interface/template_file_creator.py,sha256=kwBbFLyN6Yx2NTIHPSwOoytWgbJDYhgrUOVFw_jkDQ4,13522
20
- ripple_down_rules-0.6.23.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
21
- ripple_down_rules-0.6.23.dist-info/METADATA,sha256=5B9P3h5EyJEZvyuj7CtFVNeJxQOg8kjG7pQJfoVlAqs,48294
22
- ripple_down_rules-0.6.23.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- ripple_down_rules-0.6.23.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
24
- ripple_down_rules-0.6.23.dist-info/RECORD,,