ripple-down-rules 0.2.4__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/PKG-INFO +8 -1
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/README.md +7 -0
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/pyproject.toml +1 -1
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/datasets.py +71 -6
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/datastructures/callable_expression.py +13 -5
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/datastructures/case.py +33 -5
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/datastructures/dataclasses.py +30 -8
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/datastructures/enums.py +44 -1
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/experts.py +16 -8
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/rdr.py +13 -7
- ripple_down_rules-0.4.0/src/ripple_down_rules/rdr_decorators.py +139 -0
- ripple_down_rules-0.4.0/src/ripple_down_rules/user_interface/__init__.py +0 -0
- ripple_down_rules-0.4.0/src/ripple_down_rules/user_interface/gui.py +630 -0
- ripple_down_rules-0.4.0/src/ripple_down_rules/user_interface/ipython_custom_shell.py +146 -0
- ripple_down_rules-0.4.0/src/ripple_down_rules/user_interface/object_diagram.py +109 -0
- ripple_down_rules-0.4.0/src/ripple_down_rules/user_interface/prompt.py +159 -0
- ripple_down_rules-0.4.0/src/ripple_down_rules/user_interface/template_file_creator.py +293 -0
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/utils.py +163 -19
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules.egg-info/PKG-INFO +8 -1
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules.egg-info/SOURCES.txt +8 -1
- ripple_down_rules-0.4.0/test/test_object_diagram.py +43 -0
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_rdr.py +22 -6
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_rdr_alchemy.py +6 -6
- ripple_down_rules-0.4.0/test/test_rdr_decorators.py +27 -0
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_rdr_world.py +9 -1
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_relational_rdr.py +6 -39
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_relational_rdr_alchemy.py +15 -16
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_sql_model.py +4 -4
- ripple_down_rules-0.2.4/src/ripple_down_rules/prompt.py +0 -404
- ripple_down_rules-0.2.4/src/ripple_down_rules/rdr_decorators.py +0 -55
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/LICENSE +0 -0
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/setup.cfg +0 -0
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/__init__.py +0 -0
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/datastructures/__init__.py +0 -0
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/failures.py +0 -0
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/helpers.py +0 -0
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/rules.py +0 -0
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_json_serialization.py +0 -0
- {ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/test/test_on_mutagenic.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.4.0
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -702,6 +702,13 @@ SCRDR, MCRDR, and GRDR implementation were inspired from the book:
|
|
702
702
|
sudo apt-get install graphviz graphviz-dev
|
703
703
|
pip install ripple_down_rules
|
704
704
|
```
|
705
|
+
For GUI support, also install:
|
706
|
+
|
707
|
+
```bash
|
708
|
+
sudo apt-get install libxcb-cursor-dev
|
709
|
+
```
|
710
|
+
|
711
|
+
```bash
|
705
712
|
|
706
713
|
## Example Usage
|
707
714
|
|
@@ -15,6 +15,13 @@ SCRDR, MCRDR, and GRDR implementation were inspired from the book:
|
|
15
15
|
sudo apt-get install graphviz graphviz-dev
|
16
16
|
pip install ripple_down_rules
|
17
17
|
```
|
18
|
+
For GUI support, also install:
|
19
|
+
|
20
|
+
```bash
|
21
|
+
sudo apt-get install libxcb-cursor-dev
|
22
|
+
```
|
23
|
+
|
24
|
+
```bash
|
18
25
|
|
19
26
|
## Example Usage
|
20
27
|
|
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
|
|
6
6
|
|
7
7
|
[project]
|
8
8
|
name = "ripple_down_rules"
|
9
|
-
version = "0.
|
9
|
+
version = "0.4.0"
|
10
10
|
description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
|
11
11
|
readme = "README.md"
|
12
12
|
authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
|
@@ -2,20 +2,24 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import os
|
4
4
|
import pickle
|
5
|
+
from dataclasses import dataclass, field
|
5
6
|
|
6
7
|
import sqlalchemy
|
7
8
|
from sqlalchemy import ForeignKey
|
8
|
-
from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationship
|
9
|
-
from typing_extensions import Tuple, List, Set, Optional
|
9
|
+
from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationship, MappedColumn
|
10
|
+
from typing_extensions import Tuple, List, Set, Optional, Self
|
10
11
|
from ucimlrepo import fetch_ucirepo
|
11
12
|
|
12
13
|
from .datastructures.case import Case, create_cases_from_dataframe
|
13
14
|
from .datastructures.enums import Category
|
15
|
+
from .rdr_decorators import RDRDecorator
|
14
16
|
|
15
17
|
|
16
18
|
def load_cached_dataset(cache_file):
|
17
19
|
"""Loads the dataset from cache if it exists."""
|
18
20
|
dataset = {}
|
21
|
+
if '.pkl' not in cache_file:
|
22
|
+
cache_file += ".pkl"
|
19
23
|
for key in ["features", "targets", "ids"]:
|
20
24
|
part_file = cache_file.replace(".pkl", f"_{key}.pkl")
|
21
25
|
if not os.path.exists(part_file):
|
@@ -41,6 +45,9 @@ def save_dataset_to_cache(dataset, cache_file):
|
|
41
45
|
|
42
46
|
def get_dataset(dataset_id, cache_file: Optional[str] = None):
|
43
47
|
"""Fetches dataset from cache or downloads it if not available."""
|
48
|
+
if cache_file is not None:
|
49
|
+
if not cache_file.endswith(".pkl"):
|
50
|
+
cache_file += ".pkl"
|
44
51
|
dataset = load_cached_dataset(cache_file) if cache_file else None
|
45
52
|
if dataset is None:
|
46
53
|
print("Downloading dataset...")
|
@@ -106,8 +113,65 @@ class Habitat(Category):
|
|
106
113
|
air = "air"
|
107
114
|
|
108
115
|
|
109
|
-
|
110
|
-
|
116
|
+
class PhysicalObject:
|
117
|
+
"""
|
118
|
+
A physical object is an object that can be contained in a container.
|
119
|
+
"""
|
120
|
+
_rdr_json_dir: str = os.path.join(os.path.dirname(__file__), "../../test/test_results")
|
121
|
+
"""
|
122
|
+
The directory where the RDR serialized JSON files are stored.
|
123
|
+
"""
|
124
|
+
_rdr_python_dir: str = os.path.join(os.path.dirname(__file__), "../../test/test_generated_rdrs")
|
125
|
+
"""
|
126
|
+
The directory where the RDR generated Python files are stored.
|
127
|
+
"""
|
128
|
+
_is_a_robot_rdr: RDRDecorator = RDRDecorator(_rdr_json_dir, (bool,), True,
|
129
|
+
python_dir=_rdr_python_dir)
|
130
|
+
"""
|
131
|
+
The RDR decorator that is used to determine if the object is a robot or not.
|
132
|
+
"""
|
133
|
+
_select_parts_rdr: RDRDecorator = RDRDecorator(_rdr_json_dir, (Self,), False,
|
134
|
+
python_dir=_rdr_python_dir)
|
135
|
+
"""
|
136
|
+
The RDR decorator that is used to determine if the object is a robot or not.
|
137
|
+
"""
|
138
|
+
|
139
|
+
def __init__(self, name: str, contained_objects: Optional[List[PhysicalObject]] = None):
|
140
|
+
self.name: str = name
|
141
|
+
self._contained_objects: List[PhysicalObject] = contained_objects or []
|
142
|
+
|
143
|
+
@property
|
144
|
+
def contained_objects(self) -> List[PhysicalObject]:
|
145
|
+
return self._contained_objects
|
146
|
+
|
147
|
+
@contained_objects.setter
|
148
|
+
def contained_objects(self, value: List[PhysicalObject]):
|
149
|
+
self._contained_objects = value
|
150
|
+
|
151
|
+
@_is_a_robot_rdr.decorator
|
152
|
+
def is_a_robot(self) -> bool:
|
153
|
+
pass
|
154
|
+
|
155
|
+
@_select_parts_rdr.decorator
|
156
|
+
def select_objects_that_are_parts_of_robot(self, objects: List[PhysicalObject], robot: Robot) -> List[PhysicalObject]:
|
157
|
+
pass
|
158
|
+
|
159
|
+
def __str__(self):
|
160
|
+
return self.name
|
161
|
+
|
162
|
+
def __repr__(self):
|
163
|
+
return self.name
|
164
|
+
|
165
|
+
|
166
|
+
class Part(PhysicalObject):
|
167
|
+
...
|
168
|
+
|
169
|
+
|
170
|
+
class Robot(PhysicalObject):
|
171
|
+
|
172
|
+
def __init__(self, name: str, parts: Optional[List[Part]] = None):
|
173
|
+
super().__init__(name)
|
174
|
+
self.parts: List[Part] = parts if parts else []
|
111
175
|
|
112
176
|
|
113
177
|
class Base(sqlalchemy.orm.DeclarativeBase):
|
@@ -119,7 +183,7 @@ class HabitatTable(MappedAsDataclass, Base):
|
|
119
183
|
|
120
184
|
id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
|
121
185
|
habitat: Mapped[Habitat]
|
122
|
-
animal_id = mapped_column(ForeignKey("Animal.id"), init=False)
|
186
|
+
animal_id: MappedColumn = mapped_column(ForeignKey("Animal.id"), init=False)
|
123
187
|
|
124
188
|
def __hash__(self):
|
125
189
|
return hash(self.habitat)
|
@@ -131,7 +195,7 @@ class HabitatTable(MappedAsDataclass, Base):
|
|
131
195
|
return self.__str__()
|
132
196
|
|
133
197
|
|
134
|
-
class
|
198
|
+
class MappedAnimal(MappedAsDataclass, Base):
|
135
199
|
__tablename__ = "Animal"
|
136
200
|
|
137
201
|
id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
|
@@ -155,3 +219,4 @@ class Animal(MappedAsDataclass, Base):
|
|
155
219
|
species: Mapped[Species] = mapped_column(nullable=True)
|
156
220
|
|
157
221
|
habitats: Mapped[Set[HabitatTable]] = relationship(default_factory=set)
|
222
|
+
|
@@ -93,9 +93,12 @@ class CallableExpression(SubclassJSONSerializer):
|
|
93
93
|
"""
|
94
94
|
encapsulating_function: str = "def _get_value(case):"
|
95
95
|
|
96
|
-
def __init__(self, user_input: Optional[str] = None,
|
96
|
+
def __init__(self, user_input: Optional[str] = None,
|
97
|
+
conclusion_type: Optional[Tuple[Type]] = None,
|
97
98
|
expression_tree: Optional[AST] = None,
|
98
|
-
scope: Optional[Dict[str, Any]] = None,
|
99
|
+
scope: Optional[Dict[str, Any]] = None,
|
100
|
+
conclusion: Optional[Any] = None,
|
101
|
+
mutually_exclusive: bool = True):
|
99
102
|
"""
|
100
103
|
Create a callable expression.
|
101
104
|
|
@@ -104,6 +107,8 @@ class CallableExpression(SubclassJSONSerializer):
|
|
104
107
|
:param expression_tree: The AST tree parsed from the user input.
|
105
108
|
:param scope: The scope to use for the callable expression.
|
106
109
|
:param conclusion: The conclusion to use for the callable expression.
|
110
|
+
:param mutually_exclusive: If True, the conclusion is mutually exclusive, i.e. the callable expression can only
|
111
|
+
return one conclusion. If False, the callable expression can return multiple conclusions.
|
107
112
|
"""
|
108
113
|
if user_input is None and conclusion is None:
|
109
114
|
raise ValueError("Either user_input or conclusion must be provided.")
|
@@ -123,6 +128,7 @@ class CallableExpression(SubclassJSONSerializer):
|
|
123
128
|
self.code = compile_expression_to_code(self.expression_tree)
|
124
129
|
self.visitor = VariableVisitor()
|
125
130
|
self.visitor.visit(self.expression_tree)
|
131
|
+
self.mutually_exclusive: bool = mutually_exclusive
|
126
132
|
|
127
133
|
def __call__(self, case: Any, **kwargs) -> Any:
|
128
134
|
try:
|
@@ -134,8 +140,8 @@ class CallableExpression(SubclassJSONSerializer):
|
|
134
140
|
if output is None:
|
135
141
|
output = scope['_get_value'](case)
|
136
142
|
if self.conclusion_type is not None:
|
137
|
-
if
|
138
|
-
raise ValueError(f"
|
143
|
+
if self.mutually_exclusive and issubclass(type(output), (list, set)):
|
144
|
+
raise ValueError(f"Mutually exclusive types cannot be lists or sets, got {type(output)}")
|
139
145
|
output_types = {type(o) for o in make_list(output)}
|
140
146
|
output_types.add(type(output))
|
141
147
|
if not are_results_subclass_of_types(output_types, self.conclusion_type):
|
@@ -229,6 +235,7 @@ class CallableExpression(SubclassJSONSerializer):
|
|
229
235
|
"scope": {k: get_full_class_name(v) for k, v in self.scope.items()
|
230
236
|
if hasattr(v, '__module__') and hasattr(v, '__name__')},
|
231
237
|
"conclusion": conclusion_to_json(self.conclusion),
|
238
|
+
"mutually_exclusive": self.mutually_exclusive,
|
232
239
|
}
|
233
240
|
|
234
241
|
@classmethod
|
@@ -237,7 +244,8 @@ class CallableExpression(SubclassJSONSerializer):
|
|
237
244
|
conclusion_type=tuple(get_type_from_string(t) for t in data["conclusion_type"])
|
238
245
|
if data["conclusion_type"] else None,
|
239
246
|
scope={k: get_type_from_string(v) for k, v in data["scope"].items()},
|
240
|
-
conclusion=SubclassJSONSerializer.from_json(data["conclusion"])
|
247
|
+
conclusion=SubclassJSONSerializer.from_json(data["conclusion"]),
|
248
|
+
mutually_exclusive=data["mutually_exclusive"])
|
241
249
|
|
242
250
|
|
243
251
|
def compile_expression_to_code(expression_tree: AST) -> Any:
|
{ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/datastructures/case.py
RENAMED
@@ -11,7 +11,7 @@ from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColum
|
|
11
11
|
from typing_extensions import Any, Optional, Dict, Type, Set, Hashable, Union, List, TYPE_CHECKING
|
12
12
|
|
13
13
|
from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, SubclassJSONSerializer, \
|
14
|
-
get_full_class_name, get_type_from_string, make_list, is_iterable, serialize_dataclass, dataclass_to_dict
|
14
|
+
get_full_class_name, get_type_from_string, make_list, is_iterable, serialize_dataclass, dataclass_to_dict, copy_case
|
15
15
|
|
16
16
|
if TYPE_CHECKING:
|
17
17
|
from ripple_down_rules.rules import Rule
|
@@ -65,7 +65,7 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
65
65
|
new_list.extend(make_list(value))
|
66
66
|
super().__setitem__(name, new_list)
|
67
67
|
else:
|
68
|
-
super().__setitem__(name,
|
68
|
+
super().__setitem__(name, value)
|
69
69
|
else:
|
70
70
|
super().__setitem__(name, value)
|
71
71
|
setattr(self, name, self[name])
|
@@ -102,6 +102,29 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
102
102
|
data[k] = SubclassJSONSerializer.from_json(v)
|
103
103
|
return cls(_obj_type=obj_type, _id=id_, _name=name, **data)
|
104
104
|
|
105
|
+
def __deepcopy__(self, memo: Dict[Hashable, Any]) -> Case:
|
106
|
+
"""
|
107
|
+
Create a deep copy of the case.
|
108
|
+
|
109
|
+
:param memo: A dictionary to keep track of objects that have already been copied.
|
110
|
+
:return: A deep copy of the case.
|
111
|
+
"""
|
112
|
+
new_case = Case(self._obj_type, _id=self._id, _name=self._name, original_object=self._original_object)
|
113
|
+
for k, v in self.items():
|
114
|
+
new_case[k] = deepcopy(v)
|
115
|
+
return new_case
|
116
|
+
|
117
|
+
def __copy__(self) -> Case:
|
118
|
+
"""
|
119
|
+
Create a shallow copy of the case.
|
120
|
+
|
121
|
+
:return: A shallow copy of the case.
|
122
|
+
"""
|
123
|
+
new_case = Case(self._obj_type, _id=self._id, _name=self._name, original_object=self._original_object)
|
124
|
+
for k, v in self.items():
|
125
|
+
new_case[k] = copy(v)
|
126
|
+
return new_case
|
127
|
+
|
105
128
|
|
106
129
|
@dataclass
|
107
130
|
class CaseAttributeValue(SubclassJSONSerializer):
|
@@ -220,11 +243,16 @@ def create_case(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
|
|
220
243
|
return create_cases_from_dataframe(obj, obj_name)
|
221
244
|
if isinstance(obj, Case) or (is_dataclass(obj) and not isinstance(obj, SQLTable)):
|
222
245
|
return obj
|
223
|
-
if ((recursion_idx > max_recursion_idx)
|
246
|
+
if ((recursion_idx > max_recursion_idx)
|
247
|
+
or (obj.__class__.__module__ == "builtins" and not isinstance(obj, (list, set, dict)))
|
224
248
|
or (obj.__class__ in [MetaData, registry])):
|
225
249
|
return Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj,
|
226
250
|
**{obj_name or obj.__class__.__name__: make_list(obj) if parent_is_iterable else obj})
|
227
251
|
case = Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj)
|
252
|
+
if isinstance(obj, dict):
|
253
|
+
for k, v in obj.items():
|
254
|
+
case = create_or_update_case_from_attribute(v, k, obj, obj_name, recursion_idx,
|
255
|
+
max_recursion_idx, parent_is_iterable, case)
|
228
256
|
for attr in dir(obj):
|
229
257
|
if attr.startswith("_") or callable(getattr(obj, attr)):
|
230
258
|
continue
|
@@ -322,9 +350,9 @@ def show_current_and_corner_cases(case: Any, targets: Optional[Dict[str, Any]] =
|
|
322
350
|
if last_evaluated_rule and last_evaluated_rule.fired:
|
323
351
|
corner_row_dict = dataclass_to_dict(last_evaluated_rule.corner_case)
|
324
352
|
else:
|
325
|
-
case_dict = case
|
353
|
+
case_dict = copy_case(case)
|
326
354
|
if last_evaluated_rule and last_evaluated_rule.fired:
|
327
|
-
corner_row_dict = corner_case
|
355
|
+
corner_row_dict = copy_case(corner_case)
|
328
356
|
|
329
357
|
if corner_row_dict:
|
330
358
|
corner_conclusion = last_evaluated_rule.conclusion(case)
|
@@ -1,15 +1,17 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import inspect
|
4
|
+
import typing
|
4
5
|
from dataclasses import dataclass, field
|
5
6
|
|
6
7
|
import typing_extensions
|
7
8
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
8
|
-
from typing_extensions import Any, Optional, Dict, Type, Tuple, Union, List
|
9
|
+
from typing_extensions import Any, Optional, Dict, Type, Tuple, Union, List, get_origin, Set
|
9
10
|
|
10
11
|
from .callable_expression import CallableExpression
|
11
12
|
from .case import create_case, Case
|
12
|
-
from ..utils import copy_case, make_list, make_set
|
13
|
+
from ..utils import copy_case, make_list, make_set, get_origin_and_args_from_type_hint, get_value_type_from_type_hint, \
|
14
|
+
typing_to_python_type
|
13
15
|
|
14
16
|
|
15
17
|
@dataclass
|
@@ -61,6 +63,15 @@ class CaseQuery:
|
|
61
63
|
"""
|
62
64
|
The conditions that must be satisfied for the target value to be valid.
|
63
65
|
"""
|
66
|
+
is_function: bool = False
|
67
|
+
"""
|
68
|
+
Whether the case is a dict representing the arguments of an actual function or not,
|
69
|
+
most likely means it came from RDRDecorator, the the rdr takes function arguments and outputs the function output.
|
70
|
+
"""
|
71
|
+
function_args_type_hints: Optional[Dict[str, Type]] = None
|
72
|
+
"""
|
73
|
+
The type hints of the function arguments. This is used to recreate the function signature.
|
74
|
+
"""
|
64
75
|
|
65
76
|
@property
|
66
77
|
def case_type(self) -> Type:
|
@@ -117,10 +128,20 @@ class CaseQuery:
|
|
117
128
|
"""
|
118
129
|
:return: The type of the attribute.
|
119
130
|
"""
|
120
|
-
if not
|
121
|
-
self._attribute_types = tuple(
|
122
|
-
|
123
|
-
|
131
|
+
if not isinstance(self._attribute_types, tuple):
|
132
|
+
self._attribute_types = tuple(make_set(self._attribute_types))
|
133
|
+
origin, args = get_origin_and_args_from_type_hint(self._attribute_types)
|
134
|
+
if origin is not None:
|
135
|
+
att_types = make_set(origin)
|
136
|
+
if origin in (list, set, tuple, List, Set, Union, Tuple):
|
137
|
+
att_types.update(make_set(args))
|
138
|
+
elif origin in (dict, Dict):
|
139
|
+
# ignore the key type
|
140
|
+
if args and len(args) > 1:
|
141
|
+
att_types.update(make_set(args[1]))
|
142
|
+
self._attribute_types = tuple(att_types)
|
143
|
+
if not self.mutually_exclusive and (list not in self._attribute_types):
|
144
|
+
self._attribute_types = tuple(make_list(self._attribute_types) + [set, list])
|
124
145
|
return self._attribute_types
|
125
146
|
|
126
147
|
@attribute_type.setter
|
@@ -151,7 +172,7 @@ class CaseQuery:
|
|
151
172
|
"""
|
152
173
|
if (self._target is not None) and (not isinstance(self._target, CallableExpression)):
|
153
174
|
self._target = CallableExpression(conclusion=self._target, conclusion_type=self.attribute_type,
|
154
|
-
scope=self.scope)
|
175
|
+
scope=self.scope, mutually_exclusive=self.mutually_exclusive)
|
155
176
|
return self._target
|
156
177
|
|
157
178
|
@target.setter
|
@@ -195,4 +216,5 @@ class CaseQuery:
|
|
195
216
|
return CaseQuery(self.original_case, self.attribute_name, self.attribute_type,
|
196
217
|
self.mutually_exclusive, _target=self.target, default_value=self.default_value,
|
197
218
|
scope=self.scope, _case=copy_case(self.case), _target_value=self.target_value,
|
198
|
-
conditions=self.conditions
|
219
|
+
conditions=self.conditions, is_function=self.is_function,
|
220
|
+
function_args_type_hints=self.function_args_type_hints)
|
{ripple_down_rules-0.2.4 → ripple_down_rules-0.4.0}/src/ripple_down_rules/datastructures/enums.py
RENAMED
@@ -2,11 +2,54 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from enum import auto, Enum
|
4
4
|
|
5
|
-
from typing_extensions import List, Dict, Any
|
5
|
+
from typing_extensions import List, Dict, Any, Type
|
6
6
|
|
7
7
|
from ripple_down_rules.utils import SubclassJSONSerializer
|
8
8
|
|
9
9
|
|
10
|
+
class InteractionMode(Enum):
|
11
|
+
"""
|
12
|
+
The interaction mode of the RDR.
|
13
|
+
"""
|
14
|
+
IPythonOnly = auto()
|
15
|
+
"""
|
16
|
+
IPythonOnly mode, the mode where the user uses only an Ipython shell to interact with the RDR.
|
17
|
+
"""
|
18
|
+
GUI = auto()
|
19
|
+
"""
|
20
|
+
GUI mode, the mode where the user uses a GUI to interact with the RDR.
|
21
|
+
"""
|
22
|
+
|
23
|
+
|
24
|
+
class Editor(str, Enum):
|
25
|
+
"""
|
26
|
+
The editor that is used to edit the rules.
|
27
|
+
"""
|
28
|
+
Pycharm = "pycharm"
|
29
|
+
"""
|
30
|
+
PyCharm editor.
|
31
|
+
"""
|
32
|
+
Code = "code"
|
33
|
+
"""
|
34
|
+
Visual Studio Code editor.
|
35
|
+
"""
|
36
|
+
CodeServer = "code-server"
|
37
|
+
"""
|
38
|
+
Visual Studio Code server editor.
|
39
|
+
"""
|
40
|
+
@classmethod
|
41
|
+
def from_str(cls, editor: str) -> Editor:
|
42
|
+
"""
|
43
|
+
Convert a string value to an Editor enum.
|
44
|
+
|
45
|
+
:param editor: The string that represents the editor name.
|
46
|
+
:return: The Editor enum.
|
47
|
+
"""
|
48
|
+
if editor not in cls._value2member_map_:
|
49
|
+
raise ValueError(f"Editor {editor} is not supported.")
|
50
|
+
return cls._value2member_map_[editor]
|
51
|
+
|
52
|
+
|
10
53
|
class Category(str, SubclassJSONSerializer, Enum):
|
11
54
|
|
12
55
|
@classmethod
|
@@ -3,15 +3,14 @@ from __future__ import annotations
|
|
3
3
|
import json
|
4
4
|
from abc import ABC, abstractmethod
|
5
5
|
|
6
|
-
from typing_extensions import Optional,
|
6
|
+
from typing_extensions import Optional, TYPE_CHECKING, List
|
7
7
|
|
8
|
-
from .datastructures.case import Case, CaseAttribute
|
9
8
|
from .datastructures.callable_expression import CallableExpression
|
10
9
|
from .datastructures.enums import PromptFor
|
11
10
|
from .datastructures.dataclasses import CaseQuery
|
12
11
|
from .datastructures.case import show_current_and_corner_cases
|
13
|
-
from .
|
14
|
-
from .
|
12
|
+
from .user_interface.gui import RDRCaseViewer
|
13
|
+
from .user_interface.prompt import UserPrompt
|
15
14
|
|
16
15
|
if TYPE_CHECKING:
|
17
16
|
from .rdr import Rule
|
@@ -66,12 +65,20 @@ class Human(Expert):
|
|
66
65
|
The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
|
67
66
|
"""
|
68
67
|
|
68
|
+
def __init__(self, use_loaded_answers: bool = False, append: bool = False, viewer: Optional[RDRCaseViewer] = None):
|
69
|
+
"""
|
70
|
+
Initialize the Human expert.
|
71
|
+
|
72
|
+
:param viewer: The RDRCaseViewer instance to use for prompting the user.
|
73
|
+
"""
|
74
|
+
super().__init__(use_loaded_answers=use_loaded_answers, append=append)
|
75
|
+
self.user_prompt = UserPrompt(viewer)
|
76
|
+
|
69
77
|
def save_answers(self, path: str):
|
70
78
|
"""
|
71
79
|
Save the expert answers to a file.
|
72
80
|
|
73
81
|
:param path: The path to save the answers to.
|
74
|
-
:param append: A flag to indicate if the answers should be appended to the file or not.
|
75
82
|
"""
|
76
83
|
if self.append:
|
77
84
|
# read the file and append the new answers
|
@@ -118,7 +125,7 @@ class Human(Expert):
|
|
118
125
|
if user_input:
|
119
126
|
condition = CallableExpression(user_input, bool, scope=case_query.scope)
|
120
127
|
else:
|
121
|
-
user_input, condition = prompt_user_for_expression(case_query, PromptFor.Conditions)
|
128
|
+
user_input, condition = self.user_prompt.prompt_user_for_expression(case_query, PromptFor.Conditions)
|
122
129
|
if not self.use_loaded_answers:
|
123
130
|
self.all_expert_answers.append(user_input)
|
124
131
|
case_query.conditions = condition
|
@@ -138,10 +145,11 @@ class Human(Expert):
|
|
138
145
|
expert_input = self.all_expert_answers.pop(0)
|
139
146
|
if expert_input is not None:
|
140
147
|
expression = CallableExpression(expert_input, case_query.attribute_type,
|
141
|
-
scope=case_query.scope
|
148
|
+
scope=case_query.scope,
|
149
|
+
mutually_exclusive=case_query.mutually_exclusive)
|
142
150
|
else:
|
143
151
|
show_current_and_corner_cases(case_query.case)
|
144
|
-
expert_input, expression = prompt_user_for_expression(case_query, PromptFor.Conclusion)
|
152
|
+
expert_input, expression = self.user_prompt.prompt_user_for_expression(case_query, PromptFor.Conclusion)
|
145
153
|
self.all_expert_answers.append(expert_input)
|
146
154
|
case_query.target = expression
|
147
155
|
return expression
|
@@ -1,5 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import copyreg
|
3
4
|
import importlib
|
4
5
|
import sys
|
5
6
|
from abc import ABC, abstractmethod
|
@@ -96,11 +97,12 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
96
97
|
plt.ioff()
|
97
98
|
plt.show()
|
98
99
|
|
99
|
-
def __call__(self, case: Union[Case, SQLTable]) ->
|
100
|
+
def __call__(self, case: Union[Case, SQLTable]) -> Union[CallableExpression, Dict[str, CallableExpression]]:
|
100
101
|
return self.classify(case)
|
101
102
|
|
102
103
|
@abstractmethod
|
103
|
-
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False)
|
104
|
+
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) \
|
105
|
+
-> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
|
104
106
|
"""
|
105
107
|
Classify a case.
|
106
108
|
|
@@ -111,7 +113,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
111
113
|
pass
|
112
114
|
|
113
115
|
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
114
|
-
-> Union[
|
116
|
+
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
115
117
|
"""
|
116
118
|
Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
|
117
119
|
incorrect by comparing the case with the target category.
|
@@ -136,7 +138,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
136
138
|
|
137
139
|
@abstractmethod
|
138
140
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
139
|
-
-> Union[
|
141
|
+
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
140
142
|
"""
|
141
143
|
Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
|
142
144
|
comparing the case with the target category.
|
@@ -881,10 +883,12 @@ class GeneralRDR(RippleDownRules):
|
|
881
883
|
f.write(f"{' ' * 4}return GeneralRDR._classify(classifiers_dict, case)\n")
|
882
884
|
|
883
885
|
@property
|
884
|
-
def case_type(self) -> Type:
|
886
|
+
def case_type(self) -> Optional[Type]:
|
885
887
|
"""
|
886
888
|
:return: The type of the case (input) to the RDR classifier.
|
887
889
|
"""
|
890
|
+
if self.start_rule is None or self.start_rule.corner_case is None:
|
891
|
+
return None
|
888
892
|
if isinstance(self.start_rule.corner_case, Case):
|
889
893
|
return self.start_rule.corner_case._obj_type
|
890
894
|
else:
|
@@ -898,10 +902,12 @@ class GeneralRDR(RippleDownRules):
|
|
898
902
|
return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
|
899
903
|
|
900
904
|
@property
|
901
|
-
def _default_generated_python_file_name(self) -> str:
|
905
|
+
def _default_generated_python_file_name(self) -> Optional[str]:
|
902
906
|
"""
|
903
907
|
:return: The default generated python file name.
|
904
908
|
"""
|
909
|
+
if self.start_rule is None or self.start_rule.corner_case is None:
|
910
|
+
return None
|
905
911
|
if isinstance(self.start_rule.corner_case, Case):
|
906
912
|
name = self.start_rule.corner_case._name
|
907
913
|
else:
|
@@ -929,7 +935,7 @@ class GeneralRDR(RippleDownRules):
|
|
929
935
|
imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
|
930
936
|
# add rdr python generated functions.
|
931
937
|
for rdr_key, rdr in self.start_rules_dict.items():
|
932
|
-
imports += (f"from
|
938
|
+
imports += (f"from ."
|
933
939
|
f" import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}\n")
|
934
940
|
return imports
|
935
941
|
|