ripple-down-rules 0.2.4__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41):
  1. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/PKG-INFO +8 -1
  2. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/README.md +7 -0
  3. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/pyproject.toml +1 -1
  4. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/datasets.py +71 -6
  5. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/datastructures/callable_expression.py +13 -5
  6. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/datastructures/case.py +33 -5
  7. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/datastructures/dataclasses.py +30 -8
  8. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/datastructures/enums.py +44 -1
  9. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/experts.py +16 -8
  10. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/rdr.py +13 -7
  11. ripple_down_rules-0.4.0/src/ripple_down_rules/rdr_decorators.py +139 -0
  12. ripple_down_rules-0.4.0/src/ripple_down_rules/user_interface/__init__.py +0 -0
  13. ripple_down_rules-0.4.0/src/ripple_down_rules/user_interface/gui.py +630 -0
  14. ripple_down_rules-0.4.0/src/ripple_down_rules/user_interface/ipython_custom_shell.py +146 -0
  15. ripple_down_rules-0.4.0/src/ripple_down_rules/user_interface/object_diagram.py +109 -0
  16. ripple_down_rules-0.4.0/src/ripple_down_rules/user_interface/prompt.py +159 -0
  17. ripple_down_rules-0.4.0/src/ripple_down_rules/user_interface/template_file_creator.py +293 -0
  18. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/utils.py +163 -19
  19. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules.egg-info/PKG-INFO +8 -1
  20. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules.egg-info/SOURCES.txt +8 -1
  21. ripple_down_rules-0.4.0/test/test_object_diagram.py +43 -0
  22. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_rdr.py +22 -6
  23. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_rdr_alchemy.py +6 -6
  24. ripple_down_rules-0.4.0/test/test_rdr_decorators.py +27 -0
  25. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_rdr_world.py +9 -1
  26. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_relational_rdr.py +6 -39
  27. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_relational_rdr_alchemy.py +15 -16
  28. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_sql_model.py +4 -4
  29. ripple_down_rules-0.2.4/src/ripple_down_rules/prompt.py +0 -404
  30. ripple_down_rules-0.2.4/src/ripple_down_rules/rdr_decorators.py +0 -55
  31. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/LICENSE +0 -0
  32. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/setup.cfg +0 -0
  33. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/__init__.py +0 -0
  34. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/datastructures/__init__.py +0 -0
  35. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/failures.py +0 -0
  36. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/helpers.py +0 -0
  37. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/rules.py +0 -0
  38. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
  39. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
  40. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_json_serialization.py +0 -0
  41. {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_on_mutagenic.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.2.4
3
+ Version: 0.4.0
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -702,6 +702,13 @@ SCRDR, MCRDR, and GRDR implementation were inspired from the book:
702
702
  sudo apt-get install graphviz graphviz-dev
703
703
  pip install ripple_down_rules
704
704
  ```
705
+ For GUI support, also install:
706
+
707
+ ```bash
708
+ sudo apt-get install libxcb-cursor-dev
709
+ ```
710
+
711
+ ```bash
705
712
 
706
713
  ## Example Usage
707
714
 
@@ -15,6 +15,13 @@ SCRDR, MCRDR, and GRDR implementation were inspired from the book:
15
15
  sudo apt-get install graphviz graphviz-dev
16
16
  pip install ripple_down_rules
17
17
  ```
18
+ For GUI support, also install:
19
+
20
+ ```bash
21
+ sudo apt-get install libxcb-cursor-dev
22
+ ```
23
+
24
+ ```bash
18
25
 
19
26
  ## Example Usage
20
27
 
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
6
6
 
7
7
  [project]
8
8
  name = "ripple_down_rules"
9
- version = "0.2.4"
9
+ version = "0.4.0"
10
10
  description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
11
11
  readme = "README.md"
12
12
  authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
@@ -2,20 +2,24 @@ from __future__ import annotations
2
2
 
3
3
  import os
4
4
  import pickle
5
+ from dataclasses import dataclass, field
5
6
 
6
7
  import sqlalchemy
7
8
  from sqlalchemy import ForeignKey
8
- from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationship
9
- from typing_extensions import Tuple, List, Set, Optional
9
+ from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationship, MappedColumn
10
+ from typing_extensions import Tuple, List, Set, Optional, Self
10
11
  from ucimlrepo import fetch_ucirepo
11
12
 
12
13
  from .datastructures.case import Case, create_cases_from_dataframe
13
14
  from .datastructures.enums import Category
15
+ from .rdr_decorators import RDRDecorator
14
16
 
15
17
 
16
18
  def load_cached_dataset(cache_file):
17
19
  """Loads the dataset from cache if it exists."""
18
20
  dataset = {}
21
+ if '.pkl' not in cache_file:
22
+ cache_file += ".pkl"
19
23
  for key in ["features", "targets", "ids"]:
20
24
  part_file = cache_file.replace(".pkl", f"_{key}.pkl")
21
25
  if not os.path.exists(part_file):
@@ -41,6 +45,9 @@ def save_dataset_to_cache(dataset, cache_file):
41
45
 
42
46
  def get_dataset(dataset_id, cache_file: Optional[str] = None):
43
47
  """Fetches dataset from cache or downloads it if not available."""
48
+ if cache_file is not None:
49
+ if not cache_file.endswith(".pkl"):
50
+ cache_file += ".pkl"
44
51
  dataset = load_cached_dataset(cache_file) if cache_file else None
45
52
  if dataset is None:
46
53
  print("Downloading dataset...")
@@ -106,8 +113,65 @@ class Habitat(Category):
106
113
  air = "air"
107
114
 
108
115
 
109
- # SpeciesCol = Column.create_from_enum(Species, mutually_exclusive=True)
110
- # HabitatCol = Column.create_from_enum(Habitat, mutually_exclusive=False)
116
+ class PhysicalObject:
117
+ """
118
+ A physical object is an object that can be contained in a container.
119
+ """
120
+ _rdr_json_dir: str = os.path.join(os.path.dirname(__file__), "../../test/test_results")
121
+ """
122
+ The directory where the RDR serialized JSON files are stored.
123
+ """
124
+ _rdr_python_dir: str = os.path.join(os.path.dirname(__file__), "../../test/test_generated_rdrs")
125
+ """
126
+ The directory where the RDR generated Python files are stored.
127
+ """
128
+ _is_a_robot_rdr: RDRDecorator = RDRDecorator(_rdr_json_dir, (bool,), True,
129
+ python_dir=_rdr_python_dir)
130
+ """
131
+ The RDR decorator that is used to determine if the object is a robot or not.
132
+ """
133
+ _select_parts_rdr: RDRDecorator = RDRDecorator(_rdr_json_dir, (Self,), False,
134
+ python_dir=_rdr_python_dir)
135
+ """
136
+ The RDR decorator that is used to determine if the object is a robot or not.
137
+ """
138
+
139
+ def __init__(self, name: str, contained_objects: Optional[List[PhysicalObject]] = None):
140
+ self.name: str = name
141
+ self._contained_objects: List[PhysicalObject] = contained_objects or []
142
+
143
+ @property
144
+ def contained_objects(self) -> List[PhysicalObject]:
145
+ return self._contained_objects
146
+
147
+ @contained_objects.setter
148
+ def contained_objects(self, value: List[PhysicalObject]):
149
+ self._contained_objects = value
150
+
151
+ @_is_a_robot_rdr.decorator
152
+ def is_a_robot(self) -> bool:
153
+ pass
154
+
155
+ @_select_parts_rdr.decorator
156
+ def select_objects_that_are_parts_of_robot(self, objects: List[PhysicalObject], robot: Robot) -> List[PhysicalObject]:
157
+ pass
158
+
159
+ def __str__(self):
160
+ return self.name
161
+
162
+ def __repr__(self):
163
+ return self.name
164
+
165
+
166
+ class Part(PhysicalObject):
167
+ ...
168
+
169
+
170
+ class Robot(PhysicalObject):
171
+
172
+ def __init__(self, name: str, parts: Optional[List[Part]] = None):
173
+ super().__init__(name)
174
+ self.parts: List[Part] = parts if parts else []
111
175
 
112
176
 
113
177
  class Base(sqlalchemy.orm.DeclarativeBase):
@@ -119,7 +183,7 @@ class HabitatTable(MappedAsDataclass, Base):
119
183
 
120
184
  id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
121
185
  habitat: Mapped[Habitat]
122
- animal_id = mapped_column(ForeignKey("Animal.id"), init=False)
186
+ animal_id: MappedColumn = mapped_column(ForeignKey("Animal.id"), init=False)
123
187
 
124
188
  def __hash__(self):
125
189
  return hash(self.habitat)
@@ -131,7 +195,7 @@ class HabitatTable(MappedAsDataclass, Base):
131
195
  return self.__str__()
132
196
 
133
197
 
134
- class Animal(MappedAsDataclass, Base):
198
+ class MappedAnimal(MappedAsDataclass, Base):
135
199
  __tablename__ = "Animal"
136
200
 
137
201
  id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
@@ -155,3 +219,4 @@ class Animal(MappedAsDataclass, Base):
155
219
  species: Mapped[Species] = mapped_column(nullable=True)
156
220
 
157
221
  habitats: Mapped[Set[HabitatTable]] = relationship(default_factory=set)
222
+
@@ -93,9 +93,12 @@ class CallableExpression(SubclassJSONSerializer):
93
93
  """
94
94
  encapsulating_function: str = "def _get_value(case):"
95
95
 
96
- def __init__(self, user_input: Optional[str] = None, conclusion_type: Optional[Tuple[Type]] = None,
96
+ def __init__(self, user_input: Optional[str] = None,
97
+ conclusion_type: Optional[Tuple[Type]] = None,
97
98
  expression_tree: Optional[AST] = None,
98
- scope: Optional[Dict[str, Any]] = None, conclusion: Optional[Any] = None):
99
+ scope: Optional[Dict[str, Any]] = None,
100
+ conclusion: Optional[Any] = None,
101
+ mutually_exclusive: bool = True):
99
102
  """
100
103
  Create a callable expression.
101
104
 
@@ -104,6 +107,8 @@ class CallableExpression(SubclassJSONSerializer):
104
107
  :param expression_tree: The AST tree parsed from the user input.
105
108
  :param scope: The scope to use for the callable expression.
106
109
  :param conclusion: The conclusion to use for the callable expression.
110
+ :param mutually_exclusive: If True, the conclusion is mutually exclusive, i.e. the callable expression can only
111
+ return one conclusion. If False, the callable expression can return multiple conclusions.
107
112
  """
108
113
  if user_input is None and conclusion is None:
109
114
  raise ValueError("Either user_input or conclusion must be provided.")
@@ -123,6 +128,7 @@ class CallableExpression(SubclassJSONSerializer):
123
128
  self.code = compile_expression_to_code(self.expression_tree)
124
129
  self.visitor = VariableVisitor()
125
130
  self.visitor.visit(self.expression_tree)
131
+ self.mutually_exclusive: bool = mutually_exclusive
126
132
 
127
133
  def __call__(self, case: Any, **kwargs) -> Any:
128
134
  try:
@@ -134,8 +140,8 @@ class CallableExpression(SubclassJSONSerializer):
134
140
  if output is None:
135
141
  output = scope['_get_value'](case)
136
142
  if self.conclusion_type is not None:
137
- if not any([issubclass(ct, (list, set)) for ct in self.conclusion_type]) and is_iterable(output):
138
- raise ValueError(f"Expected output to be {self.conclusion_type}, but got {type(output)}")
143
+ if self.mutually_exclusive and issubclass(type(output), (list, set)):
144
+ raise ValueError(f"Mutually exclusive types cannot be lists or sets, got {type(output)}")
139
145
  output_types = {type(o) for o in make_list(output)}
140
146
  output_types.add(type(output))
141
147
  if not are_results_subclass_of_types(output_types, self.conclusion_type):
@@ -229,6 +235,7 @@ class CallableExpression(SubclassJSONSerializer):
229
235
  "scope": {k: get_full_class_name(v) for k, v in self.scope.items()
230
236
  if hasattr(v, '__module__') and hasattr(v, '__name__')},
231
237
  "conclusion": conclusion_to_json(self.conclusion),
238
+ "mutually_exclusive": self.mutually_exclusive,
232
239
  }
233
240
 
234
241
  @classmethod
@@ -237,7 +244,8 @@ class CallableExpression(SubclassJSONSerializer):
237
244
  conclusion_type=tuple(get_type_from_string(t) for t in data["conclusion_type"])
238
245
  if data["conclusion_type"] else None,
239
246
  scope={k: get_type_from_string(v) for k, v in data["scope"].items()},
240
- conclusion=SubclassJSONSerializer.from_json(data["conclusion"]))
247
+ conclusion=SubclassJSONSerializer.from_json(data["conclusion"]),
248
+ mutually_exclusive=data["mutually_exclusive"])
241
249
 
242
250
 
243
251
  def compile_expression_to_code(expression_tree: AST) -> Any:
@@ -11,7 +11,7 @@ from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColum
11
11
  from typing_extensions import Any, Optional, Dict, Type, Set, Hashable, Union, List, TYPE_CHECKING
12
12
 
13
13
  from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, SubclassJSONSerializer, \
14
- get_full_class_name, get_type_from_string, make_list, is_iterable, serialize_dataclass, dataclass_to_dict
14
+ get_full_class_name, get_type_from_string, make_list, is_iterable, serialize_dataclass, dataclass_to_dict, copy_case
15
15
 
16
16
  if TYPE_CHECKING:
17
17
  from ripple_down_rules.rules import Rule
@@ -65,7 +65,7 @@ class Case(UserDict, SubclassJSONSerializer):
65
65
  new_list.extend(make_list(value))
66
66
  super().__setitem__(name, new_list)
67
67
  else:
68
- super().__setitem__(name, self[name])
68
+ super().__setitem__(name, value)
69
69
  else:
70
70
  super().__setitem__(name, value)
71
71
  setattr(self, name, self[name])
@@ -102,6 +102,29 @@ class Case(UserDict, SubclassJSONSerializer):
102
102
  data[k] = SubclassJSONSerializer.from_json(v)
103
103
  return cls(_obj_type=obj_type, _id=id_, _name=name, **data)
104
104
 
105
+ def __deepcopy__(self, memo: Dict[Hashable, Any]) -> Case:
106
+ """
107
+ Create a deep copy of the case.
108
+
109
+ :param memo: A dictionary to keep track of objects that have already been copied.
110
+ :return: A deep copy of the case.
111
+ """
112
+ new_case = Case(self._obj_type, _id=self._id, _name=self._name, original_object=self._original_object)
113
+ for k, v in self.items():
114
+ new_case[k] = deepcopy(v)
115
+ return new_case
116
+
117
+ def __copy__(self) -> Case:
118
+ """
119
+ Create a shallow copy of the case.
120
+
121
+ :return: A shallow copy of the case.
122
+ """
123
+ new_case = Case(self._obj_type, _id=self._id, _name=self._name, original_object=self._original_object)
124
+ for k, v in self.items():
125
+ new_case[k] = copy(v)
126
+ return new_case
127
+
105
128
 
106
129
  @dataclass
107
130
  class CaseAttributeValue(SubclassJSONSerializer):
@@ -220,11 +243,16 @@ def create_case(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
220
243
  return create_cases_from_dataframe(obj, obj_name)
221
244
  if isinstance(obj, Case) or (is_dataclass(obj) and not isinstance(obj, SQLTable)):
222
245
  return obj
223
- if ((recursion_idx > max_recursion_idx) or (obj.__class__.__module__ == "builtins")
246
+ if ((recursion_idx > max_recursion_idx)
247
+ or (obj.__class__.__module__ == "builtins" and not isinstance(obj, (list, set, dict)))
224
248
  or (obj.__class__ in [MetaData, registry])):
225
249
  return Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj,
226
250
  **{obj_name or obj.__class__.__name__: make_list(obj) if parent_is_iterable else obj})
227
251
  case = Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj)
252
+ if isinstance(obj, dict):
253
+ for k, v in obj.items():
254
+ case = create_or_update_case_from_attribute(v, k, obj, obj_name, recursion_idx,
255
+ max_recursion_idx, parent_is_iterable, case)
228
256
  for attr in dir(obj):
229
257
  if attr.startswith("_") or callable(getattr(obj, attr)):
230
258
  continue
@@ -322,9 +350,9 @@ def show_current_and_corner_cases(case: Any, targets: Optional[Dict[str, Any]] =
322
350
  if last_evaluated_rule and last_evaluated_rule.fired:
323
351
  corner_row_dict = dataclass_to_dict(last_evaluated_rule.corner_case)
324
352
  else:
325
- case_dict = case
353
+ case_dict = copy_case(case)
326
354
  if last_evaluated_rule and last_evaluated_rule.fired:
327
- corner_row_dict = corner_case
355
+ corner_row_dict = copy_case(corner_case)
328
356
 
329
357
  if corner_row_dict:
330
358
  corner_conclusion = last_evaluated_rule.conclusion(case)
@@ -1,15 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
+ import typing
4
5
  from dataclasses import dataclass, field
5
6
 
6
7
  import typing_extensions
7
8
  from sqlalchemy.orm import DeclarativeBase as SQLTable
8
- from typing_extensions import Any, Optional, Dict, Type, Tuple, Union, List
9
+ from typing_extensions import Any, Optional, Dict, Type, Tuple, Union, List, get_origin, Set
9
10
 
10
11
  from .callable_expression import CallableExpression
11
12
  from .case import create_case, Case
12
- from ..utils import copy_case, make_list, make_set
13
+ from ..utils import copy_case, make_list, make_set, get_origin_and_args_from_type_hint, get_value_type_from_type_hint, \
14
+ typing_to_python_type
13
15
 
14
16
 
15
17
  @dataclass
@@ -61,6 +63,15 @@ class CaseQuery:
61
63
  """
62
64
  The conditions that must be satisfied for the target value to be valid.
63
65
  """
66
+ is_function: bool = False
67
+ """
68
+ Whether the case is a dict representing the arguments of an actual function or not,
69
+ most likely means it came from RDRDecorator, the the rdr takes function arguments and outputs the function output.
70
+ """
71
+ function_args_type_hints: Optional[Dict[str, Type]] = None
72
+ """
73
+ The type hints of the function arguments. This is used to recreate the function signature.
74
+ """
64
75
 
65
76
  @property
66
77
  def case_type(self) -> Type:
@@ -117,10 +128,20 @@ class CaseQuery:
117
128
  """
118
129
  :return: The type of the attribute.
119
130
  """
120
- if not self.mutually_exclusive and (list not in make_list(self._attribute_types)):
121
- self._attribute_types = tuple(set(make_list(self._attribute_types) + [set, list]))
122
- elif not isinstance(self._attribute_types, tuple):
123
- self._attribute_types = tuple(make_list(self._attribute_types))
131
+ if not isinstance(self._attribute_types, tuple):
132
+ self._attribute_types = tuple(make_set(self._attribute_types))
133
+ origin, args = get_origin_and_args_from_type_hint(self._attribute_types)
134
+ if origin is not None:
135
+ att_types = make_set(origin)
136
+ if origin in (list, set, tuple, List, Set, Union, Tuple):
137
+ att_types.update(make_set(args))
138
+ elif origin in (dict, Dict):
139
+ # ignore the key type
140
+ if args and len(args) > 1:
141
+ att_types.update(make_set(args[1]))
142
+ self._attribute_types = tuple(att_types)
143
+ if not self.mutually_exclusive and (list not in self._attribute_types):
144
+ self._attribute_types = tuple(make_list(self._attribute_types) + [set, list])
124
145
  return self._attribute_types
125
146
 
126
147
  @attribute_type.setter
@@ -151,7 +172,7 @@ class CaseQuery:
151
172
  """
152
173
  if (self._target is not None) and (not isinstance(self._target, CallableExpression)):
153
174
  self._target = CallableExpression(conclusion=self._target, conclusion_type=self.attribute_type,
154
- scope=self.scope)
175
+ scope=self.scope, mutually_exclusive=self.mutually_exclusive)
155
176
  return self._target
156
177
 
157
178
  @target.setter
@@ -195,4 +216,5 @@ class CaseQuery:
195
216
  return CaseQuery(self.original_case, self.attribute_name, self.attribute_type,
196
217
  self.mutually_exclusive, _target=self.target, default_value=self.default_value,
197
218
  scope=self.scope, _case=copy_case(self.case), _target_value=self.target_value,
198
- conditions=self.conditions)
219
+ conditions=self.conditions, is_function=self.is_function,
220
+ function_args_type_hints=self.function_args_type_hints)
@@ -2,11 +2,54 @@ from __future__ import annotations
2
2
 
3
3
  from enum import auto, Enum
4
4
 
5
- from typing_extensions import List, Dict, Any
5
+ from typing_extensions import List, Dict, Any, Type
6
6
 
7
7
  from ripple_down_rules.utils import SubclassJSONSerializer
8
8
 
9
9
 
10
+ class InteractionMode(Enum):
11
+ """
12
+ The interaction mode of the RDR.
13
+ """
14
+ IPythonOnly = auto()
15
+ """
16
+ IPythonOnly mode, the mode where the user uses only an Ipython shell to interact with the RDR.
17
+ """
18
+ GUI = auto()
19
+ """
20
+ GUI mode, the mode where the user uses a GUI to interact with the RDR.
21
+ """
22
+
23
+
24
+ class Editor(str, Enum):
25
+ """
26
+ The editor that is used to edit the rules.
27
+ """
28
+ Pycharm = "pycharm"
29
+ """
30
+ PyCharm editor.
31
+ """
32
+ Code = "code"
33
+ """
34
+ Visual Studio Code editor.
35
+ """
36
+ CodeServer = "code-server"
37
+ """
38
+ Visual Studio Code server editor.
39
+ """
40
+ @classmethod
41
+ def from_str(cls, editor: str) -> Editor:
42
+ """
43
+ Convert a string value to an Editor enum.
44
+
45
+ :param editor: The string that represents the editor name.
46
+ :return: The Editor enum.
47
+ """
48
+ if editor not in cls._value2member_map_:
49
+ raise ValueError(f"Editor {editor} is not supported.")
50
+ return cls._value2member_map_[editor]
51
+
52
+
10
53
  class Category(str, SubclassJSONSerializer, Enum):
11
54
 
12
55
  @classmethod
@@ -3,15 +3,14 @@ from __future__ import annotations
3
3
  import json
4
4
  from abc import ABC, abstractmethod
5
5
 
6
- from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Type, Any
6
+ from typing_extensions import Optional, TYPE_CHECKING, List
7
7
 
8
- from .datastructures.case import Case, CaseAttribute
9
8
  from .datastructures.callable_expression import CallableExpression
10
9
  from .datastructures.enums import PromptFor
11
10
  from .datastructures.dataclasses import CaseQuery
12
11
  from .datastructures.case import show_current_and_corner_cases
13
- from .prompt import prompt_user_for_expression, IPythonShell
14
- from .utils import get_all_subclasses, make_list
12
+ from .user_interface.gui import RDRCaseViewer
13
+ from .user_interface.prompt import UserPrompt
15
14
 
16
15
  if TYPE_CHECKING:
17
16
  from .rdr import Rule
@@ -66,12 +65,20 @@ class Human(Expert):
66
65
  The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
67
66
  """
68
67
 
68
+ def __init__(self, use_loaded_answers: bool = False, append: bool = False, viewer: Optional[RDRCaseViewer] = None):
69
+ """
70
+ Initialize the Human expert.
71
+
72
+ :param viewer: The RDRCaseViewer instance to use for prompting the user.
73
+ """
74
+ super().__init__(use_loaded_answers=use_loaded_answers, append=append)
75
+ self.user_prompt = UserPrompt(viewer)
76
+
69
77
  def save_answers(self, path: str):
70
78
  """
71
79
  Save the expert answers to a file.
72
80
 
73
81
  :param path: The path to save the answers to.
74
- :param append: A flag to indicate if the answers should be appended to the file or not.
75
82
  """
76
83
  if self.append:
77
84
  # read the file and append the new answers
@@ -118,7 +125,7 @@ class Human(Expert):
118
125
  if user_input:
119
126
  condition = CallableExpression(user_input, bool, scope=case_query.scope)
120
127
  else:
121
- user_input, condition = prompt_user_for_expression(case_query, PromptFor.Conditions)
128
+ user_input, condition = self.user_prompt.prompt_user_for_expression(case_query, PromptFor.Conditions)
122
129
  if not self.use_loaded_answers:
123
130
  self.all_expert_answers.append(user_input)
124
131
  case_query.conditions = condition
@@ -138,10 +145,11 @@ class Human(Expert):
138
145
  expert_input = self.all_expert_answers.pop(0)
139
146
  if expert_input is not None:
140
147
  expression = CallableExpression(expert_input, case_query.attribute_type,
141
- scope=case_query.scope)
148
+ scope=case_query.scope,
149
+ mutually_exclusive=case_query.mutually_exclusive)
142
150
  else:
143
151
  show_current_and_corner_cases(case_query.case)
144
- expert_input, expression = prompt_user_for_expression(case_query, PromptFor.Conclusion)
152
+ expert_input, expression = self.user_prompt.prompt_user_for_expression(case_query, PromptFor.Conclusion)
145
153
  self.all_expert_answers.append(expert_input)
146
154
  case_query.target = expression
147
155
  return expression
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import copyreg
3
4
  import importlib
4
5
  import sys
5
6
  from abc import ABC, abstractmethod
@@ -96,11 +97,12 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
96
97
  plt.ioff()
97
98
  plt.show()
98
99
 
99
- def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
100
+ def __call__(self, case: Union[Case, SQLTable]) -> Union[CallableExpression, Dict[str, CallableExpression]]:
100
101
  return self.classify(case)
101
102
 
102
103
  @abstractmethod
103
- def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) -> Optional[CaseAttribute]:
104
+ def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) \
105
+ -> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
104
106
  """
105
107
  Classify a case.
106
108
 
@@ -111,7 +113,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
111
113
  pass
112
114
 
113
115
  def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
114
- -> Union[CaseAttribute, CallableExpression]:
116
+ -> Union[CallableExpression, Dict[str, CallableExpression]]:
115
117
  """
116
118
  Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
117
119
  incorrect by comparing the case with the target category.
@@ -136,7 +138,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
136
138
 
137
139
  @abstractmethod
138
140
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
139
- -> Union[CaseAttribute, CallableExpression]:
141
+ -> Union[CallableExpression, Dict[str, CallableExpression]]:
140
142
  """
141
143
  Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
142
144
  comparing the case with the target category.
@@ -881,10 +883,12 @@ class GeneralRDR(RippleDownRules):
881
883
  f.write(f"{' ' * 4}return GeneralRDR._classify(classifiers_dict, case)\n")
882
884
 
883
885
  @property
884
- def case_type(self) -> Type:
886
+ def case_type(self) -> Optional[Type]:
885
887
  """
886
888
  :return: The type of the case (input) to the RDR classifier.
887
889
  """
890
+ if self.start_rule is None or self.start_rule.corner_case is None:
891
+ return None
888
892
  if isinstance(self.start_rule.corner_case, Case):
889
893
  return self.start_rule.corner_case._obj_type
890
894
  else:
@@ -898,10 +902,12 @@ class GeneralRDR(RippleDownRules):
898
902
  return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
899
903
 
900
904
  @property
901
- def _default_generated_python_file_name(self) -> str:
905
+ def _default_generated_python_file_name(self) -> Optional[str]:
902
906
  """
903
907
  :return: The default generated python file name.
904
908
  """
909
+ if self.start_rule is None or self.start_rule.corner_case is None:
910
+ return None
905
911
  if isinstance(self.start_rule.corner_case, Case):
906
912
  name = self.start_rule.corner_case._name
907
913
  else:
@@ -929,7 +935,7 @@ class GeneralRDR(RippleDownRules):
929
935
  imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
930
936
  # add rdr python generated functions.
931
937
  for rdr_key, rdr in self.start_rules_dict.items():
932
- imports += (f"from {file_path.strip('./')}"
938
+ imports += (f"from ."
933
939
  f" import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}\n")
934
940
  return imports
935
941