ripple-down-rules 0.1.66 (tar.gz) → 0.1.68 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/PKG-INFO +1 -1
  2. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/pyproject.toml +1 -1
  3. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules/datasets.py +1 -1
  4. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules/datastructures/callable_expression.py +17 -9
  5. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules/datastructures/case.py +7 -4
  6. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules/datastructures/dataclasses.py +1 -1
  7. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules/prompt.py +3 -9
  8. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules/rdr.py +46 -30
  9. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules/rules.py +17 -17
  10. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules/utils.py +196 -38
  11. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
  12. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/test/test_rdr_world.py +8 -0
  13. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/test/test_relational_rdr.py +5 -4
  14. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/test/test_relational_rdr_alchemy.py +5 -4
  15. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/LICENSE +0 -0
  16. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/README.md +0 -0
  17. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/setup.cfg +0 -0
  18. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules/__init__.py +0 -0
  19. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules/datastructures/__init__.py +0 -0
  20. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules/datastructures/enums.py +0 -0
  21. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules/experts.py +0 -0
  22. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules/failures.py +0 -0
  23. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules/helpers.py +0 -0
  24. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules/rdr_decorators.py +0 -0
  25. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules.egg-info/SOURCES.txt +0 -0
  26. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
  27. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
  28. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/test/test_json_serialization.py +0 -0
  29. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/test/test_on_mutagenic.py +0 -0
  30. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/test/test_rdr.py +0 -0
  31. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/test/test_rdr_alchemy.py +0 -0
  32. {ripple_down_rules-0.1.66 → ripple_down_rules-0.1.68}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.1.66
3
+ Version: 0.1.68
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
6
6
 
7
7
  [project]
8
8
  name = "ripple_down_rules"
9
- version = "0.1.66"
9
+ version = "0.1.68"
10
10
  description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
11
11
  readme = "README.md"
12
12
  authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
@@ -125,7 +125,7 @@ class HabitatTable(MappedAsDataclass, Base):
125
125
  return hash(self.habitat)
126
126
 
127
127
  def __str__(self):
128
- return self.habitat.value
128
+ return f"{HabitatTable.__name__}({Habitat.__name__}.{self.habitat.name})"
129
129
 
130
130
  def __repr__(self):
131
131
  return self.__str__()
@@ -3,11 +3,13 @@ from __future__ import annotations
3
3
  import ast
4
4
  import logging
5
5
  from _ast import AST
6
+ from enum import Enum
6
7
 
7
8
  from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set
8
9
 
9
10
  from .case import create_case, Case
10
- from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string, conclusion_to_json, is_iterable
11
+ from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string, conclusion_to_json, is_iterable, \
12
+ build_user_input_from_conclusion, encapsulate_user_input
11
13
 
12
14
 
13
15
  class VariableVisitor(ast.NodeVisitor):
@@ -88,6 +90,7 @@ class CallableExpression(SubclassJSONSerializer):
88
90
  """
89
91
  A callable that is constructed from a string statement written by an expert.
90
92
  """
93
+ encapsulating_function: str = "def _get_value(case):"
91
94
 
92
95
  def __init__(self, user_input: Optional[str] = None, conclusion_type: Optional[Tuple[Type]] = None,
93
96
  expression_tree: Optional[AST] = None,
@@ -103,8 +106,10 @@ class CallableExpression(SubclassJSONSerializer):
103
106
  """
104
107
  if user_input is None and conclusion is None:
105
108
  raise ValueError("Either user_input or conclusion must be provided.")
109
+ if user_input is None:
110
+ user_input = build_user_input_from_conclusion(conclusion)
106
111
  self.conclusion: Optional[Any] = conclusion
107
- self.user_input: str = user_input
112
+ self.user_input: str = encapsulate_user_input(user_input, self.encapsulating_function)
108
113
  if conclusion_type is not None:
109
114
  if is_iterable(conclusion_type):
110
115
  conclusion_type = tuple(conclusion_type)
@@ -112,12 +117,11 @@ class CallableExpression(SubclassJSONSerializer):
112
117
  conclusion_type = (conclusion_type,)
113
118
  self.conclusion_type = conclusion_type
114
119
  self.scope: Optional[Dict[str, Any]] = scope if scope is not None else {}
115
- if conclusion is None:
116
- self.scope = get_used_scope(self.user_input, self.scope)
117
- self.expression_tree: AST = expression_tree if expression_tree else parse_string_to_expression(self.user_input)
118
- self.code = compile_expression_to_code(self.expression_tree)
119
- self.visitor = VariableVisitor()
120
- self.visitor.visit(self.expression_tree)
120
+ self.scope = get_used_scope(self.user_input, self.scope)
121
+ self.expression_tree: AST = expression_tree if expression_tree else parse_string_to_expression(self.user_input)
122
+ self.code = compile_expression_to_code(self.expression_tree)
123
+ self.visitor = VariableVisitor()
124
+ self.visitor.visit(self.expression_tree)
121
125
 
122
126
  def __call__(self, case: Any, **kwargs) -> Any:
123
127
  try:
@@ -145,7 +149,11 @@ class CallableExpression(SubclassJSONSerializer):
145
149
  """
146
150
  Combine this callable expression with another callable expression using the 'and' operator.
147
151
  """
148
- new_user_input = f"({self.user_input}) and ({other.user_input})"
152
+ cond1_user_input = self.user_input.replace(self.encapsulating_function, "def _cond1(case):")
153
+ cond2_user_input = other.user_input.replace(self.encapsulating_function, "def _cond2(case):")
154
+ new_user_input = (f"{cond1_user_input}\n"
155
+ f"{cond2_user_input}\n"
156
+ f"return _cond1(case) and _cond2(case)")
149
157
  return CallableExpression(new_user_input, conclusion_type=self.conclusion_type)
150
158
 
151
159
  def __eq__(self, other):
@@ -11,7 +11,7 @@ from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColum
11
11
  from typing_extensions import Any, Optional, Dict, Type, Set, Hashable, Union, List, TYPE_CHECKING
12
12
 
13
13
  from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, SubclassJSONSerializer, \
14
- get_full_class_name, get_type_from_string, make_list, is_iterable, serialize_dataclass
14
+ get_full_class_name, get_type_from_string, make_list, is_iterable, serialize_dataclass, dataclass_to_dict
15
15
 
16
16
  if TYPE_CHECKING:
17
17
  from ripple_down_rules.rules import Rule
@@ -99,8 +99,7 @@ class Case(UserDict, SubclassJSONSerializer):
99
99
  obj_type = get_type_from_string(data.pop("_obj_type"))
100
100
  name = data.pop("_name")
101
101
  for k, v in data.items():
102
- if isinstance(v, dict) and "_type" in v:
103
- data[k] = SubclassJSONSerializer.from_json(v)
102
+ data[k] = SubclassJSONSerializer.from_json(v)
104
103
  return cls(_obj_type=obj_type, _id=id_, _name=name, **data)
105
104
 
106
105
 
@@ -219,7 +218,7 @@ def create_case(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
219
218
  obj_name = obj_name or obj.__class__.__name__
220
219
  if isinstance(obj, DataFrame):
221
220
  return create_cases_from_dataframe(obj, obj_name)
222
- if isinstance(obj, Case):
221
+ if isinstance(obj, Case) or (is_dataclass(obj) and not isinstance(obj, SQLTable)):
223
222
  return obj
224
223
  if ((recursion_idx > max_recursion_idx) or (obj.__class__.__module__ == "builtins")
225
224
  or (obj.__class__ in [MetaData, registry])):
@@ -318,6 +317,10 @@ def show_current_and_corner_cases(case: Any, targets: Optional[Dict[str, Any]] =
318
317
  case_dict = row_to_dict(case)
319
318
  if last_evaluated_rule and last_evaluated_rule.fired:
320
319
  corner_row_dict = row_to_dict(last_evaluated_rule.corner_case)
320
+ elif is_dataclass(case):
321
+ case_dict = dataclass_to_dict(case)
322
+ if last_evaluated_rule and last_evaluated_rule.fired:
323
+ corner_row_dict = dataclass_to_dict(last_evaluated_rule.corner_case)
321
324
  else:
322
325
  case_dict = case
323
326
  if last_evaluated_rule and last_evaluated_rule.fired:
@@ -120,7 +120,7 @@ class CaseQuery:
120
120
  """
121
121
  :return: The target expression of the attribute.
122
122
  """
123
- if self._target is not None and not isinstance(self._target, CallableExpression):
123
+ if (self._target is not None) and (not isinstance(self._target, CallableExpression)):
124
124
  self._target = CallableExpression(conclusion=self._target, conclusion_type=self.attribute_type,
125
125
  scope=self.scope)
126
126
  return self._target
@@ -80,16 +80,10 @@ class IPythonShell:
80
80
  self.user_input = None
81
81
  else:
82
82
  self.all_code_lines = extract_dependencies(self.shell.all_lines)
83
- if len(self.all_code_lines) == 1:
84
- if self.all_code_lines[0].strip() == '':
85
- self.user_input = None
86
- else:
87
- self.user_input = self.all_code_lines[0].replace('return', '').strip()
83
+ if len(self.all_code_lines) == 1 and self.all_code_lines[0].strip() == '':
84
+ self.user_input = None
88
85
  else:
89
- self.user_input = f"def _get_value(case):\n "
90
- for cl in self.all_code_lines:
91
- sub_code_lines = cl.split('\n')
92
- self.user_input += '\n '.join(sub_code_lines) + '\n '
86
+ self.user_input = '\n'.join(self.all_code_lines)
93
87
 
94
88
 
95
89
  def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, prompt_str: Optional[str] = None)\
@@ -211,17 +211,17 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
211
211
  func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
212
212
  file_name = file_path + f"/{self.generated_python_file_name}.py"
213
213
  defs_file_name = file_path + f"/{self.generated_python_defs_file_name}.py"
214
- imports = self._get_imports()
214
+ imports, defs_imports = self._get_imports()
215
215
  # clear the files first
216
216
  with open(defs_file_name, "w") as f:
217
- f.write(imports + "\n\n")
217
+ f.write(defs_imports + "\n\n")
218
218
  with open(file_name, "w") as f:
219
219
  imports += f"from .{self.generated_python_defs_file_name} import *\n"
220
220
  imports += f"from ripple_down_rules.rdr import {self.__class__.__name__}\n"
221
221
  f.write(imports + "\n\n")
222
- f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n\n")
223
- f.write(f"type_ = {self.__class__.__name__}\n\n")
224
- f.write(func_def)
222
+ f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
223
+ f.write(f"type_ = {self.__class__.__name__}\n")
224
+ f.write(f"\n\n{func_def}")
225
225
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
226
226
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
227
227
  self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4, defs_file=defs_file_name)
@@ -234,17 +234,11 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
234
234
  """
235
235
  pass
236
236
 
237
- def _get_imports(self) -> str:
237
+ def _get_imports(self) -> Tuple[str, str]:
238
238
  """
239
239
  :return: The imports for the generated python file of the RDR as a string.
240
240
  """
241
- imports = ""
242
- if self.case_type.__module__ != "builtins":
243
- imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
244
- for conclusion_type in self.conclusion_type:
245
- if conclusion_type.__module__ != "builtins":
246
- imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
247
- imports += "from ripple_down_rules.datastructures.case import Case, create_case\n"
241
+ defs_imports = ""
248
242
  for rule in [self.start_rule] + list(self.start_rule.descendants):
249
243
  if not rule.conditions:
250
244
  continue
@@ -255,10 +249,21 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
255
249
  if not hasattr(v, "__module__") or not hasattr(v, "__name__"):
256
250
  continue
257
251
  new_imports = f"from {v.__module__} import {v.__name__}\n"
258
- if new_imports in imports:
252
+ if new_imports in defs_imports:
259
253
  continue
260
- imports += new_imports
261
- return imports
254
+ defs_imports += new_imports
255
+ imports = ""
256
+ if self.case_type.__module__ != "builtins":
257
+ new_import = f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
258
+ if new_import not in defs_imports:
259
+ imports += new_import
260
+ for conclusion_type in self.conclusion_type:
261
+ if conclusion_type.__module__ != "builtins":
262
+ new_import = f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
263
+ if new_import not in defs_imports:
264
+ imports += new_import
265
+ imports += "from ripple_down_rules.datastructures.case import Case, create_case\n"
266
+ return imports, defs_imports
262
267
 
263
268
  def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
264
269
  """
@@ -293,7 +298,11 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
293
298
  """
294
299
  :return: The default generated python file name.
295
300
  """
296
- return f"{self.start_rule.corner_case._name.lower()}_{self.attribute_name}_{self.acronym.lower()}"
301
+ if isinstance(self.start_rule.corner_case, Case):
302
+ name = self.start_rule.corner_case._name
303
+ else:
304
+ name = self.start_rule.corner_case.__class__.__name__
305
+ return f"{name.lower()}_{self.attribute_name}_{self.acronym.lower()}"
297
306
 
298
307
  @property
299
308
  def generated_python_defs_file_name(self) -> str:
@@ -539,13 +548,21 @@ class MultiClassRDR(RDRWithCodeWriter):
539
548
 
540
549
  @property
541
550
  def conclusion_type_hint(self) -> str:
542
- return f"Set[Union[{', '.join([ct.__name__ for ct in self.conclusion_type if ct not in [list, set]])}]]"
551
+ conclusion_types = [ct.__name__ for ct in self.conclusion_type if ct not in [list, set]]
552
+ if len(conclusion_types) == 1:
553
+ return f"Set[{conclusion_types[0]}]"
554
+ else:
555
+ return f"Set[Union[{', '.join(conclusion_types)}]]"
543
556
 
544
- def _get_imports(self) -> str:
545
- imports = super()._get_imports()
546
- imports += "from typing_extensions import Set, Union\n"
557
+ def _get_imports(self) -> Tuple[str, str]:
558
+ imports, defs_imports = super()._get_imports()
559
+ conclusion_types = [ct for ct in self.conclusion_type if ct not in [list, set]]
560
+ if len(conclusion_types) == 1:
561
+ imports += f"from typing_extensions import Set\n"
562
+ else:
563
+ imports += "from typing_extensions import Set, Union\n"
547
564
  imports += "from ripple_down_rules.utils import make_set\n"
548
- return imports
565
+ return imports, defs_imports
549
566
 
550
567
  def update_start_rule(self, case_query: CaseQuery, expert: Expert):
551
568
  """
@@ -715,7 +732,7 @@ class GeneralRDR(RippleDownRules):
715
732
 
716
733
  @staticmethod
717
734
  def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
718
- case: Any, modify_original_case: bool = True) -> Dict[str, Any]:
735
+ case: Any, modify_original_case: bool = False) -> Dict[str, Any]:
719
736
  """
720
737
  Classify a case by going through all classifiers and adding the categories that are classified,
721
738
  and then restarting the classification until no more categories can be added.
@@ -867,7 +884,11 @@ class GeneralRDR(RippleDownRules):
867
884
  """
868
885
  :return: The default generated python file name.
869
886
  """
870
- return f"{self.start_rule.corner_case._name.lower()}_rdr"
887
+ if isinstance(self.start_rule.corner_case, Case):
888
+ name = self.start_rule.corner_case._name
889
+ else:
890
+ name = self.start_rule.corner_case.__class__.__name__
891
+ return f"{name}_rdr".lower()
871
892
 
872
893
  @property
873
894
  def conclusion_type_hint(self) -> str:
@@ -882,17 +903,12 @@ class GeneralRDR(RippleDownRules):
882
903
  """
883
904
  imports = ""
884
905
  # add type hints
885
- imports += f"from typing_extensions import Dict, Any, Union, Set\n"
906
+ imports += f"from typing_extensions import Dict, Any\n"
886
907
  # import rdr type
887
908
  imports += f"from ripple_down_rules.rdr import GeneralRDR\n"
888
909
  # add case type
889
910
  imports += f"from ripple_down_rules.datastructures.case import Case, create_case\n"
890
911
  imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
891
- # add conclusion type imports
892
- for rdr in self.start_rules_dict.values():
893
- for conclusion_type in rdr.conclusion_type:
894
- if conclusion_type.__module__ != "builtins":
895
- imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
896
912
  # add rdr python generated functions.
897
913
  for rdr_key, rdr in self.start_rules_dict.items():
898
914
  imports += (f"from {file_path.strip('./')}"
@@ -93,7 +93,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
93
93
  conclusion_func, conclusion_func_call = self._conclusion_source_code(conclusion, parent_indent=parent_indent)
94
94
  if conclusion_func is not None:
95
95
  with open(defs_file, 'a') as f:
96
- f.write(conclusion_func + "\n\n")
96
+ f.write(conclusion_func.strip() + "\n\n\n")
97
97
  return conclusion_func_call
98
98
 
99
99
  @abstractmethod
@@ -120,7 +120,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
120
120
  conditions_lines[0] = re.sub(r"def (\w+)", new_function_name, conditions_lines[0])
121
121
  def_code = "\n".join(conditions_lines)
122
122
  with open(defs_file, 'a') as f:
123
- f.write(def_code + "\n\n")
123
+ f.write(def_code.strip() + "\n\n\n")
124
124
  return f"\n{parent_indent}{if_clause} {new_function_name.replace('def ', '')}(case):\n"
125
125
 
126
126
  @abstractmethod
@@ -131,7 +131,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
131
131
  json_serialization = {"conditions": self.conditions.to_json(),
132
132
  "conclusion": conclusion_to_json(self.conclusion),
133
133
  "parent": self.parent.json_serialization if self.parent else None,
134
- "corner_case": self.corner_case.to_json() if self.corner_case else None,
134
+ "corner_case": SubclassJSONSerializer.to_json_static(self.corner_case),
135
135
  "conclusion_name": self.conclusion_name,
136
136
  "weight": self.weight}
137
137
  return json_serialization
@@ -270,11 +270,11 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
270
270
 
271
271
  def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
272
272
  conclusion = str(conclusion)
273
- indent = parent_indent + " " * 4
274
- if '\n' not in conclusion:
275
- return None, f"{indent}return {conclusion}\n"
276
- else:
277
- return get_rule_conclusion_as_source_code(self, conclusion, parent_indent=parent_indent)
273
+ # indent = parent_indent + " " * 4
274
+ # if '\n' not in conclusion:
275
+ # return None, f"{indent}return {conclusion}\n"
276
+ # else:
277
+ return get_rule_conclusion_as_source_code(self, conclusion, parent_indent=parent_indent)
278
278
 
279
279
  def _if_statement_source_code_clause(self) -> str:
280
280
  return "elif" if self.weight == RDREdge.Alternative.value else "if"
@@ -367,15 +367,15 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
367
367
  def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[str, str]:
368
368
  conclusion_str = str(conclusion)
369
369
  indent = parent_indent + " " * 4
370
- if '\n' not in conclusion_str:
371
- func = None
372
- if is_iterable(conclusion):
373
- conclusion_str = "{" + ", ".join([str(c) for c in conclusion]) + "}"
374
- else:
375
- conclusion_str = "{" + str(conclusion) + "}"
376
- else:
377
- func, func_call = get_rule_conclusion_as_source_code(self, conclusion_str, parent_indent=parent_indent)
378
- conclusion_str = func_call.replace("return ", "").strip()
370
+ # if '\n' not in conclusion_str:
371
+ # func = None
372
+ # if is_iterable(conclusion):
373
+ # conclusion_str = "{" + ", ".join([str(c) for c in conclusion]) + "}"
374
+ # else:
375
+ # conclusion_str = "{" + str(conclusion) + "}"
376
+ # else:
377
+ func, func_call = get_rule_conclusion_as_source_code(self, conclusion_str, parent_indent=parent_indent)
378
+ conclusion_str = func_call.replace("return ", "").strip()
379
379
 
380
380
  statement = f"{indent}conclusions.update(make_set({conclusion_str}))\n"
381
381
  if self.alternative is None:
@@ -7,9 +7,11 @@ import json
7
7
  import logging
8
8
  import os
9
9
  import re
10
+ import uuid
10
11
  from collections import UserDict
11
12
  from copy import deepcopy
12
13
  from dataclasses import is_dataclass, fields
14
+ from enum import Enum
13
15
  from types import NoneType
14
16
 
15
17
  import matplotlib
@@ -34,6 +36,58 @@ import ast
34
36
  matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
35
37
 
36
38
 
39
+ def encapsulate_user_input(user_input: str, func_signature: str) -> str:
40
+ """
41
+ Encapsulate the user input string with a function definition.
42
+
43
+ :param user_input: The user input string.
44
+ :param func_signature: The function signature to use for encapsulation.
45
+ :return: The encapsulated user input string.
46
+ """
47
+ if func_signature not in user_input:
48
+ new_user_input = func_signature + "\n "
49
+ if "return " not in user_input:
50
+ if '\n' not in user_input:
51
+ new_user_input += f"return {user_input}"
52
+ else:
53
+ raise ValueError("User input must contain a return statement or be a single line.")
54
+ else:
55
+ for cl in user_input.split('\n'):
56
+ sub_code_lines = cl.split('\n')
57
+ new_user_input += '\n '.join(sub_code_lines) + '\n '
58
+ else:
59
+ new_user_input = user_input
60
+ return new_user_input
61
+
62
+
63
+ def build_user_input_from_conclusion(conclusion: Any) -> str:
64
+ """
65
+ Build a user input string from the conclusion.
66
+
67
+ :param conclusion: The conclusion to use for the callable expression.
68
+ :return: The user input string.
69
+ """
70
+
71
+ # set user_input to the string representation of the conclusion
72
+ if isinstance(conclusion, set):
73
+ user_input = '{' + f"{', '.join([conclusion_to_str(t) for t in conclusion])}" + '}'
74
+ elif isinstance(conclusion, list):
75
+ user_input = '[' + f"{', '.join([conclusion_to_str(t) for t in conclusion])}" + ']'
76
+ elif isinstance(conclusion, tuple):
77
+ user_input = '(' + f"{', '.join([conclusion_to_str(t) for t in conclusion])}" + ')'
78
+ else:
79
+ user_input = conclusion_to_str(conclusion)
80
+
81
+ return user_input
82
+
83
+
84
+ def conclusion_to_str(conclusion_: Any) -> str:
85
+ if isinstance(conclusion_, Enum):
86
+ return type(conclusion_).__name__ + '.' + conclusion_.name
87
+ else:
88
+ return str(conclusion_)
89
+
90
+
37
91
  def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
38
92
  """
39
93
  Update the case with the conclusions.
@@ -269,7 +323,7 @@ def extract_dependencies(code_lines):
269
323
  return required_lines
270
324
 
271
325
 
272
- def serialize_dataclass(obj: Any) -> Union[Dict, Any]:
326
+ def serialize_dataclass(obj: Any, seen=None) -> Any:
273
327
  """
274
328
  Recursively serialize a dataclass to a dictionary. If the dataclass contains any nested dataclasses, they will be
275
329
  serialized as well. If the object is not a dataclass, it will be returned as is.
@@ -277,24 +331,44 @@ def serialize_dataclass(obj: Any) -> Union[Dict, Any]:
277
331
  :param obj: The dataclass to serialize.
278
332
  :return: The serialized dataclass as a dictionary or the object itself if it is not a dataclass.
279
333
  """
280
-
281
- def recursive_convert(obj):
282
- if is_dataclass(obj):
283
- return {
284
- "__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}",
285
- "fields": {f.name: recursive_convert(getattr(obj, f.name)) for f in fields(obj)}
286
- }
287
- elif isinstance(obj, list):
288
- return [recursive_convert(item) for item in obj]
289
- elif isinstance(obj, dict):
290
- return {k: recursive_convert(v) for k, v in obj.items()}
291
- else:
334
+ if seen is None:
335
+ seen = {}
336
+
337
+ obj_id = id(obj)
338
+ if obj_id in seen:
339
+ return {'$ref': seen[obj_id]}
340
+
341
+ if is_dataclass(obj):
342
+ uid = str(uuid.uuid4())
343
+ seen[obj_id] = uid
344
+ result = {
345
+ '$id': uid,
346
+ "__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}",
347
+ 'fields': {}
348
+ }
349
+ for f in fields(obj):
350
+ value = getattr(obj, f.name)
351
+ result['fields'][f.name] = serialize_dataclass(value, seen)
352
+ return result
353
+ elif isinstance(obj, list):
354
+ return [serialize_dataclass(v, seen) for v in obj]
355
+ elif isinstance(obj, dict):
356
+ return {k: serialize_dataclass(v, seen) for k, v in obj.items()}
357
+ else:
358
+ try:
359
+ json.dumps(obj) # Check if the object is JSON serializable
292
360
  return obj
361
+ except TypeError:
362
+ return None
293
363
 
294
- return recursive_convert(obj)
295
364
 
365
+ def deserialize_dataclass(data: Any, refs: Optional[Dict[str, Any]] = None) -> Any:
366
+ refs = {} if refs is None else refs
367
+ preloaded = preload_serialized_objects(data, refs)
368
+ return resolve_refs(preloaded, refs)
296
369
 
297
- def deserialize_dataclass(data: dict) -> Any:
370
+
371
+ def preload_serialized_objects(data: Any, refs: Dict[str, Any] = None) -> Any:
298
372
  """
299
373
  Recursively deserialize a dataclass from a dictionary, if the dictionary contains a key "__dataclass__" (Most likely
300
374
  created by the serialize_dataclass function), it will be treated as a dataclass and deserialized accordingly,
@@ -303,25 +377,81 @@ def deserialize_dataclass(data: dict) -> Any:
303
377
  :param data: The dictionary to deserialize.
304
378
  :return: The deserialized dataclass.
305
379
  """
380
+ if refs is None:
381
+ refs = {}
382
+
383
+ if isinstance(data, dict):
384
+
385
+ if '$ref' in data:
386
+ ref_id = data['$ref']
387
+ if ref_id not in refs:
388
+ return {'$ref': data['$ref']}
389
+ return refs[ref_id]
390
+
391
+ elif '$id' in data and '__dataclass__' in data and 'fields' in data:
392
+ cls_path = data['__dataclass__']
393
+ module_name, class_name = cls_path.rsplit('.', 1)
394
+ cls = getattr(importlib.import_module(module_name), class_name)
395
+
396
+ dummy_instance = cls.__new__(cls) # Don't call __init__ yet
397
+ refs[data['$id']] = dummy_instance
398
+
399
+ for f in fields(cls):
400
+ raw_value = data['fields'].get(f.name)
401
+ value = preload_serialized_objects(raw_value, refs)
402
+ setattr(dummy_instance, f.name, value)
403
+
404
+ return dummy_instance
306
405
 
307
- def recursive_load(obj):
308
- if isinstance(obj, dict) and "__dataclass__" in obj:
309
- module_name, class_name = obj["__dataclass__"].rsplit(".", 1)
310
- module = importlib.import_module(module_name)
311
- cls: Type = getattr(module, class_name)
312
- field_values = {
313
- k: recursive_load(v)
314
- for k, v in obj["fields"].items()
315
- }
316
- return cls(**field_values)
317
- elif isinstance(obj, list):
318
- return [recursive_load(item) for item in obj]
319
- elif isinstance(obj, dict):
320
- return {k: recursive_load(v) for k, v in obj.items()}
321
406
  else:
322
- return obj
407
+ return {k: preload_serialized_objects(v, refs) for k, v in data.items()}
408
+
409
+ elif isinstance(data, list):
410
+ return [preload_serialized_objects(item, refs) for item in data]
411
+ elif isinstance(data, dict):
412
+ return {k: preload_serialized_objects(v, refs) for k, v in data.items()}
413
+
414
+ return data # Primitive
415
+
416
+
417
+ def resolve_refs(obj, refs, seen=None):
418
+ if seen is None:
419
+ seen = {}
420
+
421
+ obj_id = id(obj)
422
+ if obj_id in seen:
423
+ return seen[obj_id]
323
424
 
324
- return recursive_load(data)
425
+ # Resolve if dict with $ref
426
+ if isinstance(obj, dict) and '$ref' in obj:
427
+ ref_id = obj['$ref']
428
+ if ref_id not in refs:
429
+ raise KeyError(f"$ref to unknown ID: {ref_id}")
430
+ return refs[ref_id]
431
+
432
+ elif is_dataclass(obj):
433
+ seen[obj_id] = obj # Mark before diving deeper
434
+ for f in fields(obj):
435
+ val = getattr(obj, f.name)
436
+ resolved = resolve_refs(val, refs, seen)
437
+ setattr(obj, f.name, resolved)
438
+ return obj
439
+
440
+ elif isinstance(obj, list):
441
+ resolved_list = []
442
+ seen[obj_id] = resolved_list
443
+ for item in obj:
444
+ resolved_list.append(resolve_refs(item, refs, seen))
445
+ return resolved_list
446
+
447
+ elif isinstance(obj, dict):
448
+ resolved_dict = {}
449
+ seen[obj_id] = resolved_dict
450
+ for k, v in obj.items():
451
+ resolved_dict[k] = resolve_refs(v, refs, seen)
452
+ return resolved_dict
453
+
454
+ return obj # Primitive
325
455
 
326
456
 
327
457
  def typing_to_python_type(typing_hint: Type) -> Type:
@@ -494,6 +624,7 @@ class SubclassJSONSerializer:
494
624
  Classes that inherit from this class can be serialized and deserialized automatically by calling this classes
495
625
  'from_json' method.
496
626
  """
627
+ data_class_refs = {}
497
628
 
498
629
  def to_json_file(self, filename: str):
499
630
  """
@@ -507,8 +638,14 @@ class SubclassJSONSerializer:
507
638
  json.dump(data, f, indent=4)
508
639
  return data
509
640
 
641
+ @staticmethod
642
+ def to_json_static(obj) -> Dict[str, Any]:
643
+ if is_dataclass(obj):
644
+ return serialize_dataclass(obj)
645
+ return {"_type": get_full_class_name(obj.__class__), **obj._to_json()}
646
+
510
647
  def to_json(self) -> Dict[str, Any]:
511
- return {"_type": get_full_class_name(self.__class__), **self._to_json()}
648
+ return self.to_json_static(self)
512
649
 
513
650
  def _to_json(self) -> Dict[str, Any]:
514
651
  """
@@ -529,7 +666,7 @@ class SubclassJSONSerializer:
529
666
  raise NotImplementedError()
530
667
 
531
668
  @classmethod
532
- def from_json_file(cls, filename: str):
669
+ def from_json_file(cls, filename: str) -> Any:
533
670
  """
534
671
  Create an instance of the subclass from the data in the given json file.
535
672
 
@@ -539,7 +676,9 @@ class SubclassJSONSerializer:
539
676
  filename += ".json"
540
677
  with open(filename, "r") as f:
541
678
  scrdr_json = json.load(f)
542
- return cls.from_json(scrdr_json)
679
+ deserialized_obj = cls.from_json(scrdr_json)
680
+ cls.data_class_refs.clear()
681
+ return deserialized_obj
543
682
 
544
683
  @classmethod
545
684
  def from_json(cls, data: Dict[str, Any]) -> Self:
@@ -551,11 +690,17 @@ class SubclassJSONSerializer:
551
690
  """
552
691
  if data is None:
553
692
  return None
554
- if not isinstance(data, dict) or ('_type' not in data):
693
+ if isinstance(data, list):
694
+ # if the data is a list, deserialize it
695
+ return [cls.from_json(d) for d in data]
696
+ elif isinstance(data, dict):
697
+ if '__dataclass__' in data:
698
+ # if the data is a dataclass, deserialize it
699
+ return deserialize_dataclass(data, cls.data_class_refs)
700
+ elif '_type' not in data:
701
+ return {k: cls.from_json(v) for k, v in data.items()}
702
+ elif not isinstance(data, dict):
555
703
  return data
556
- if '__dataclass__' in data:
557
- # if the data is a dataclass, deserialize it
558
- return deserialize_dataclass(data)
559
704
 
560
705
  # check if type module is builtins
561
706
  data_type = get_type_from_string(data["_type"])
@@ -704,6 +849,19 @@ def row_to_dict(obj):
704
849
  }
705
850
 
706
851
 
852
+ def dataclass_to_dict(obj):
853
+ """
854
+ Convert a dataclass to a dictionary.
855
+
856
+ :param obj: The dataclass to convert.
857
+ :return: The dictionary representation of the dataclass.
858
+ """
859
+ if is_dataclass(obj):
860
+ return {f.name: getattr(obj, f.name) for f in fields(obj) if not f.name.startswith("_")}
861
+ else:
862
+ raise ValueError(f"Object {obj} is not a dataclass.")
863
+
864
+
707
865
  def get_attribute_name(obj: Any, attribute: Optional[Any] = None, attribute_type: Optional[Type] = None,
708
866
  possible_value: Optional[Any] = None) -> Optional[str]:
709
867
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.1.66
3
+ Version: 0.1.68
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -113,6 +113,14 @@ class TestRDRWorld(TestCase):
113
113
  def test_view_rdr(self):
114
114
  self.get_view_rdr(use_loaded_answers=True, save_answers=False, append=False)
115
115
 
116
+ def test_save_and_load_view_rdr(self):
117
+ view_rdr = self.get_view_rdr(use_loaded_answers=True, save_answers=False, append=False)
118
+ filename = os.path.join(os.getcwd(), "test_results/world_views_rdr")
119
+ view_rdr.save(filename)
120
+ loaded_rdr = GeneralRDR.load(filename)
121
+ self.assertEqual(view_rdr.classify(self.world), loaded_rdr.classify(self.world))
122
+ self.assertEqual(self.world.bodies, loaded_rdr.start_rules[0].corner_case.bodies)
123
+
116
124
  def test_write_view_rdr_to_python_file(self):
117
125
  rdrs_dir = "./test_generated_rdrs"
118
126
  view_rdr = self.get_view_rdr()
@@ -54,6 +54,7 @@ class RelationalRDRTestCase(TestCase):
54
54
  part_d: Part
55
55
  part_e: Part
56
56
  part_f: Part
57
+ target: List[PhysicalObject]
57
58
 
58
59
  @classmethod
59
60
  def setUpClass(cls):
@@ -71,8 +72,8 @@ class RelationalRDRTestCase(TestCase):
71
72
  cls.part_d.contained_objects = [cls.part_e]
72
73
  cls.part_e.contained_objects = [cls.part_f]
73
74
  cls.robot: Robot = robot
74
- cls.case_query = CaseQuery(robot, "contained_objects", (PhysicalObject,), False,
75
- _target=[cls.part_b, cls.part_c, cls.part_d, cls.part_e])
75
+ cls.case_query = CaseQuery(robot, "contained_objects", (PhysicalObject,), False)
76
+ cls.target = [cls.part_b, cls.part_c, cls.part_d, cls.part_e]
76
77
 
77
78
  def test_classify_scrdr(self):
78
79
  use_loaded_answers = True
@@ -86,7 +87,7 @@ class RelationalRDRTestCase(TestCase):
86
87
  cat = scrdr.fit_case(CaseQuery(self.robot, "contained_objects", (PhysicalObject,), False), expert=expert)
87
88
  render_tree(scrdr.start_rule, use_dot_exporter=True,
88
89
  filename=self.test_results_dir + "/relational_scrdr_classify")
89
- assert cat == self.case_query.target(self.case_query.case)
90
+ assert cat == self.target
90
91
 
91
92
  if save_answers:
92
93
  cwd = os.getcwd()
@@ -105,4 +106,4 @@ class RelationalRDRTestCase(TestCase):
105
106
  conclusion = CallableExpression(user_input, list)
106
107
  print(conclusion)
107
108
  print(conclusion(self.robot))
108
- assert conclusion(self.robot) == self.case_query.target(self.case_query.case)
109
+ assert conclusion(self.robot) == self.target
@@ -108,6 +108,7 @@ class RelationalRDRTestCase(TestCase):
108
108
  part_f: PhysicalObject
109
109
  rob_has_parts: List[HasPart]
110
110
  containments: List[ContainsObject]
111
+ target: List[PhysicalObject]
111
112
 
112
113
  @classmethod
113
114
  def setUpClass(cls):
@@ -129,8 +130,8 @@ class RelationalRDRTestCase(TestCase):
129
130
  cls.containments.append(ContainsObject(left=cls.part_d, right=cls.part_e))
130
131
  cls.containments.append(ContainsObject(left=cls.part_e, right=cls.part_f))
131
132
  cls.robot: PhysicalObject = robot
132
- cls.case_query = CaseQuery(robot, robot.contained_objects, (PhysicalObject,), False,
133
- _target=[cls.part_b, cls.part_c, cls.part_d, cls.part_e])
133
+ cls.case_query = CaseQuery(robot, robot.contained_objects, (PhysicalObject,), False)
134
+ cls.target = [cls.part_b, cls.part_c, cls.part_d, cls.part_e]
134
135
 
135
136
  def test_setup(self):
136
137
  assert self.robot.parts == [self.part_a, self.part_b, self.part_c, self.part_d]
@@ -153,7 +154,7 @@ class RelationalRDRTestCase(TestCase):
153
154
  cat = scrdr.fit_case(CaseQuery(self.robot, "contained_objects", (PhysicalObject,), False), expert=expert)
154
155
  render_tree(scrdr.start_rule, use_dot_exporter=True,
155
156
  filename=self.test_results_dir + "/relational_scrdr_classify")
156
- assert cat == self.case_query.target(self.case_query.case)
157
+ assert cat == self.target
157
158
 
158
159
  if save_answers:
159
160
  cwd = os.getcwd()
@@ -172,4 +173,4 @@ class RelationalRDRTestCase(TestCase):
172
173
  conclusion = CallableExpression(user_input, list)
173
174
  print(conclusion)
174
175
  print(conclusion(self.robot))
175
- assert conclusion(self.robot) == self.case_query.target(self.case_query.case)
176
+ assert conclusion(self.robot) == self.target