ripple-down-rules 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datasets.py +2 -2
- ripple_down_rules/datastructures/__init__.py +1 -1
- ripple_down_rules/datastructures/callable_expression.py +3 -3
- ripple_down_rules/datastructures/{table.py → case.py} +92 -111
- ripple_down_rules/datastructures/dataclasses.py +2 -2
- ripple_down_rules/experts.py +19 -19
- ripple_down_rules/prompt.py +2 -2
- ripple_down_rules/rdr.py +27 -20
- {ripple_down_rules-0.0.8.dist-info → ripple_down_rules-0.0.9.dist-info}/METADATA +1 -1
- ripple_down_rules-0.0.9.dist-info/RECORD +18 -0
- ripple_down_rules-0.0.8.dist-info/RECORD +0 -18
- {ripple_down_rules-0.0.8.dist-info → ripple_down_rules-0.0.9.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.0.8.dist-info → ripple_down_rules-0.0.9.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.0.8.dist-info → ripple_down_rules-0.0.9.dist-info}/top_level.txt +0 -0
ripple_down_rules/datasets.py
CHANGED
@@ -9,7 +9,7 @@ from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationshi
|
|
9
9
|
from typing_extensions import Tuple, List, Set, Optional
|
10
10
|
from ucimlrepo import fetch_ucirepo
|
11
11
|
|
12
|
-
from .datastructures import Case,
|
12
|
+
from .datastructures import Case, create_cases_from_dataframe, Category, CaseAttribute
|
13
13
|
|
14
14
|
|
15
15
|
def load_cached_dataset(cache_file):
|
@@ -77,7 +77,7 @@ def load_zoo_dataset(cache_file: Optional[str] = None) -> Tuple[List[Case], List
|
|
77
77
|
y = zoo['targets']
|
78
78
|
# get ids as list of strings
|
79
79
|
ids = zoo['ids'].values.flatten()
|
80
|
-
all_cases =
|
80
|
+
all_cases = create_cases_from_dataframe(X)
|
81
81
|
|
82
82
|
category_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "molusc"]
|
83
83
|
category_id_to_name = {i + 1: name for i, name in enumerate(category_names)}
|
@@ -7,7 +7,7 @@ from _ast import AST
|
|
7
7
|
from sqlalchemy.orm import Session
|
8
8
|
from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set
|
9
9
|
|
10
|
-
from .
|
10
|
+
from .case import create_case, Case
|
11
11
|
from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string
|
12
12
|
|
13
13
|
|
@@ -128,8 +128,8 @@ class CallableExpression(SubclassJSONSerializer):
|
|
128
128
|
|
129
129
|
def __call__(self, case: Any, **kwargs) -> Any:
|
130
130
|
try:
|
131
|
-
if not isinstance(case,
|
132
|
-
case =
|
131
|
+
if not isinstance(case, Case):
|
132
|
+
case = create_case(case, max_recursion_idx=3)
|
133
133
|
output = eval(self.code)
|
134
134
|
if self.conclusion_type:
|
135
135
|
assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
|
@@ -1,44 +1,40 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import os
|
4
|
-
import time
|
5
|
-
from abc import ABC
|
6
3
|
from collections import UserDict
|
7
|
-
from copy import deepcopy, copy
|
8
4
|
from dataclasses import dataclass
|
9
5
|
from enum import Enum
|
10
6
|
|
11
7
|
from pandas import DataFrame
|
12
8
|
from sqlalchemy import MetaData
|
13
9
|
from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColumn, registry
|
14
|
-
from typing_extensions import Any, Optional, Dict, Type, Set, Hashable, Union, List, TYPE_CHECKING
|
10
|
+
from typing_extensions import Any, Optional, Dict, Type, Set, Hashable, Union, List, TYPE_CHECKING
|
15
11
|
|
16
|
-
from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint,
|
17
|
-
SubclassJSONSerializer
|
12
|
+
from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, SubclassJSONSerializer
|
18
13
|
|
19
14
|
if TYPE_CHECKING:
|
20
15
|
from ripple_down_rules.rules import Rule
|
21
16
|
from .callable_expression import CallableExpression
|
22
17
|
|
23
18
|
|
24
|
-
class
|
19
|
+
class Case(UserDict, SubclassJSONSerializer):
|
25
20
|
"""
|
26
21
|
A collection of attributes that represents a set of constraints on a case. This is a dictionary where the keys are
|
27
22
|
the names of the attributes and the values are the attributes. All are stored in lower case.
|
28
23
|
"""
|
29
24
|
|
30
|
-
def __init__(self,
|
25
|
+
def __init__(self, _id: Optional[Hashable] = None, _type: Optional[Type] = None, **kwargs):
|
31
26
|
"""
|
32
27
|
Create a new row.
|
33
28
|
|
34
|
-
:param
|
29
|
+
:param _id: The id of the row.
|
35
30
|
:param kwargs: The attributes of the row.
|
36
31
|
"""
|
37
32
|
super().__init__(kwargs)
|
38
|
-
self.
|
33
|
+
self._id = _id if _id else id(self)
|
34
|
+
self._type = _type
|
39
35
|
|
40
36
|
@classmethod
|
41
|
-
def from_obj(cls, obj: Any, obj_name: Optional[str] = None, max_recursion_idx: int = 3) ->
|
37
|
+
def from_obj(cls, obj: Any, obj_name: Optional[str] = None, max_recursion_idx: int = 3) -> Case:
|
42
38
|
"""
|
43
39
|
Create a row from an object.
|
44
40
|
|
@@ -47,7 +43,7 @@ class Row(UserDict, SubclassJSONSerializer):
|
|
47
43
|
:param obj_name: The name of the object.
|
48
44
|
:return: The row of the object.
|
49
45
|
"""
|
50
|
-
return
|
46
|
+
return create_case(obj, max_recursion_idx=max_recursion_idx, obj_name=obj_name)
|
51
47
|
|
52
48
|
def __getitem__(self, item: str) -> Any:
|
53
49
|
return super().__getitem__(item.lower())
|
@@ -74,30 +70,22 @@ class Row(UserDict, SubclassJSONSerializer):
|
|
74
70
|
def __delitem__(self, key):
|
75
71
|
super().__delitem__(key.lower())
|
76
72
|
|
77
|
-
def __eq__(self, other):
|
78
|
-
if not isinstance(other, (Row, dict, UserDict)):
|
79
|
-
return False
|
80
|
-
elif isinstance(other, (dict, UserDict)):
|
81
|
-
return super().__eq__(Row(other))
|
82
|
-
else:
|
83
|
-
return super().__eq__(other)
|
84
|
-
|
85
73
|
def __hash__(self):
|
86
|
-
return self.
|
74
|
+
return self._id
|
87
75
|
|
88
76
|
def _to_json(self) -> Dict[str, Any]:
|
89
77
|
serializable = {k: v for k, v in self.items() if not k.startswith("_")}
|
90
|
-
serializable["_id"] = self.
|
78
|
+
serializable["_id"] = self._id
|
91
79
|
return {k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in serializable.items()}
|
92
80
|
|
93
81
|
@classmethod
|
94
|
-
def _from_json(cls, data: Dict[str, Any]) ->
|
82
|
+
def _from_json(cls, data: Dict[str, Any]) -> Case:
|
95
83
|
id_ = data.pop("_id")
|
96
|
-
return cls(
|
84
|
+
return cls(_id=id_, **data)
|
97
85
|
|
98
86
|
|
99
87
|
@dataclass
|
100
|
-
class
|
88
|
+
class CaseAttributeValue(SubclassJSONSerializer):
|
101
89
|
"""
|
102
90
|
A column value is a value in a column.
|
103
91
|
"""
|
@@ -111,7 +99,7 @@ class ColumnValue(SubclassJSONSerializer):
|
|
111
99
|
"""
|
112
100
|
|
113
101
|
def __eq__(self, other):
|
114
|
-
if not isinstance(other,
|
102
|
+
if not isinstance(other, CaseAttributeValue):
|
115
103
|
return False
|
116
104
|
return self.value == other.value
|
117
105
|
|
@@ -122,58 +110,58 @@ class ColumnValue(SubclassJSONSerializer):
|
|
122
110
|
return {"id": self.id, "value": self.value}
|
123
111
|
|
124
112
|
@classmethod
|
125
|
-
def _from_json(cls, data: Dict[str, Any]) ->
|
113
|
+
def _from_json(cls, data: Dict[str, Any]) -> CaseAttributeValue:
|
126
114
|
return cls(id=data["id"], value=data["value"])
|
127
115
|
|
128
116
|
|
129
|
-
class
|
117
|
+
class CaseAttribute(set, SubclassJSONSerializer):
|
130
118
|
nullable: bool = True
|
131
119
|
"""
|
132
|
-
A boolean indicating whether the
|
120
|
+
A boolean indicating whether the case attribute can be None or not.
|
133
121
|
"""
|
134
122
|
mutually_exclusive: bool = False
|
135
123
|
"""
|
136
|
-
A boolean indicating whether the
|
124
|
+
A boolean indicating whether the case attribute is mutually exclusive or not. (i.e. can only have one value)
|
137
125
|
"""
|
138
126
|
|
139
|
-
def __init__(self, values: Set[
|
127
|
+
def __init__(self, values: Set[CaseAttributeValue]):
|
140
128
|
"""
|
141
|
-
Create a new
|
129
|
+
Create a new case attribute.
|
142
130
|
|
143
|
-
:param values: The values of the
|
131
|
+
:param values: The values of the case attribute.
|
144
132
|
"""
|
145
|
-
values = self.
|
146
|
-
self.id_value_map: Dict[Hashable, Union[
|
133
|
+
values = self._type_cast_values_to_set_of_case_attribute_values(values)
|
134
|
+
self.id_value_map: Dict[Hashable, Union[CaseAttributeValue, Set[CaseAttributeValue]]] = {id(v): v for v in values}
|
147
135
|
super().__init__([v.value for v in values])
|
148
136
|
|
149
137
|
@staticmethod
|
150
|
-
def
|
138
|
+
def _type_cast_values_to_set_of_case_attribute_values(values: Set[Any]) -> Set[CaseAttributeValue]:
|
151
139
|
"""
|
152
|
-
Type cast values to a set of
|
140
|
+
Type cast values to a set of case attribute values.
|
153
141
|
|
154
142
|
:param values: The values to type cast.
|
155
143
|
"""
|
156
144
|
values = make_set(values)
|
157
|
-
if len(values) > 0 and not isinstance(next(iter(values)),
|
158
|
-
values = {
|
145
|
+
if len(values) > 0 and not isinstance(next(iter(values)), CaseAttributeValue):
|
146
|
+
values = {CaseAttributeValue(id(values), v) for v in values}
|
159
147
|
return values
|
160
148
|
|
161
149
|
@classmethod
|
162
|
-
def from_obj(cls, values: Set[Any], row_obj: Optional[Any] = None) ->
|
150
|
+
def from_obj(cls, values: Set[Any], row_obj: Optional[Any] = None) -> CaseAttribute:
|
163
151
|
id_ = id(row_obj) if row_obj else id(values)
|
164
152
|
values = make_set(values)
|
165
|
-
return cls({
|
153
|
+
return cls({CaseAttributeValue(id_, v) for v in values})
|
166
154
|
|
167
155
|
@property
|
168
156
|
def as_dict(self) -> Dict[str, Any]:
|
169
157
|
"""
|
170
|
-
Get the
|
158
|
+
Get the case attribute as a dictionary.
|
171
159
|
|
172
|
-
:return: The
|
160
|
+
:return: The case attribute as a dictionary.
|
173
161
|
"""
|
174
162
|
return {self.__class__.__name__: self}
|
175
163
|
|
176
|
-
def filter_by(self, condition: CallableExpression) ->
|
164
|
+
def filter_by(self, condition: CallableExpression) -> CaseAttribute:
|
177
165
|
"""
|
178
166
|
Filter the column by a condition.
|
179
167
|
|
@@ -200,118 +188,114 @@ class Column(set, SubclassJSONSerializer):
|
|
200
188
|
for id_, v in self.id_value_map.items()}
|
201
189
|
|
202
190
|
@classmethod
|
203
|
-
def _from_json(cls, data: Dict[str, Any]) ->
|
204
|
-
return cls({
|
191
|
+
def _from_json(cls, data: Dict[str, Any]) -> CaseAttribute:
|
192
|
+
return cls({CaseAttributeValue.from_json(v) for id_, v in data.items()})
|
205
193
|
|
206
194
|
|
207
|
-
def
|
195
|
+
def create_cases_from_dataframe(df: DataFrame) -> List[Case]:
|
208
196
|
"""
|
209
|
-
Create
|
197
|
+
Create cases from a pandas DataFrame.
|
210
198
|
|
211
|
-
:param df: The DataFrame to create
|
212
|
-
:
|
213
|
-
:return: The row of the DataFrame.
|
199
|
+
:param df: The DataFrame to create cases from.
|
200
|
+
:return: The cases of the DataFrame.
|
214
201
|
"""
|
215
|
-
|
216
|
-
|
217
|
-
for row_id,
|
218
|
-
|
219
|
-
|
220
|
-
return
|
202
|
+
cases = []
|
203
|
+
attribute_names = list(df.columns)
|
204
|
+
for row_id, case in df.iterrows():
|
205
|
+
case = {col_name: case[col_name].item() for col_name in attribute_names}
|
206
|
+
cases.append(Case(_id=row_id, _type=DataFrame, **case))
|
207
|
+
return cases
|
221
208
|
|
222
209
|
|
223
|
-
def
|
224
|
-
|
210
|
+
def create_case(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
|
211
|
+
obj_name: Optional[str] = None, parent_is_iterable: bool = False) -> Case:
|
225
212
|
"""
|
226
|
-
Create a
|
213
|
+
Create a case from an object.
|
227
214
|
|
228
|
-
:param obj: The object to create a
|
215
|
+
:param obj: The object to create a case from.
|
229
216
|
:param recursion_idx: The current recursion index.
|
230
217
|
:param max_recursion_idx: The maximum recursion index to prevent infinite recursion.
|
231
218
|
:param obj_name: The name of the object.
|
232
219
|
:param parent_is_iterable: Boolean indicating whether the parent object is iterable or not.
|
233
|
-
:return: The
|
220
|
+
:return: The case that represents the object.
|
234
221
|
"""
|
235
|
-
if isinstance(obj,
|
222
|
+
if isinstance(obj, Case):
|
236
223
|
return obj
|
237
224
|
if ((recursion_idx > max_recursion_idx) or (obj.__class__.__module__ == "builtins")
|
238
225
|
or (obj.__class__ in [MetaData, registry])):
|
239
|
-
return
|
240
|
-
|
226
|
+
return Case(_id=id(obj), _type=obj.__class__,
|
227
|
+
**{obj_name or obj.__class__.__name__: make_set(obj) if parent_is_iterable else obj})
|
228
|
+
case = Case(_id=id(obj), _type=obj.__class__)
|
241
229
|
for attr in dir(obj):
|
242
230
|
if attr.startswith("_") or callable(getattr(obj, attr)):
|
243
231
|
continue
|
244
232
|
attr_value = getattr(obj, attr)
|
245
|
-
|
246
|
-
|
247
|
-
return
|
233
|
+
case = create_or_update_case_from_attribute(attr_value, attr, obj, attr, recursion_idx,
|
234
|
+
max_recursion_idx, parent_is_iterable, case)
|
235
|
+
return case
|
248
236
|
|
249
237
|
|
250
|
-
def
|
251
|
-
|
252
|
-
|
253
|
-
|
238
|
+
def create_or_update_case_from_attribute(attr_value: Any, name: str, obj: Any, obj_name: Optional[str] = None,
|
239
|
+
recursion_idx: int = 0, max_recursion_idx: int = 1,
|
240
|
+
parent_is_iterable: bool = False,
|
241
|
+
case: Optional[Case] = None) -> Case:
|
254
242
|
"""
|
255
|
-
|
243
|
+
Create or update a case from an attribute of the object that the case represents.
|
256
244
|
|
257
|
-
:param attr_value: The attribute value
|
245
|
+
:param attr_value: The attribute value.
|
258
246
|
:param name: The name of the attribute.
|
259
247
|
:param obj: The parent object of the attribute.
|
260
248
|
:param obj_name: The parent object name.
|
261
249
|
:param recursion_idx: The recursion index to prevent infinite recursion.
|
262
250
|
:param max_recursion_idx: The maximum recursion index.
|
263
251
|
:param parent_is_iterable: Boolean indicating whether the parent object is iterable or not.
|
264
|
-
:param
|
265
|
-
:return:
|
252
|
+
:param case: The case to update.
|
253
|
+
:return: The updated/created case.
|
266
254
|
"""
|
267
|
-
if
|
268
|
-
|
255
|
+
if case is None:
|
256
|
+
case = Case(_id=id(obj), _type=obj.__class__)
|
269
257
|
if isinstance(attr_value, (dict, UserDict)):
|
270
|
-
|
258
|
+
case.update({f"{obj_name}.{k}": v for k, v in attr_value.items()})
|
271
259
|
if hasattr(attr_value, "__iter__") and not isinstance(attr_value, str):
|
272
|
-
column
|
273
|
-
|
274
|
-
|
275
|
-
|
260
|
+
column = create_case_attribute_from_iterable_attribute(attr_value, name, obj, obj_name,
|
261
|
+
recursion_idx=recursion_idx + 1,
|
262
|
+
max_recursion_idx=max_recursion_idx)
|
263
|
+
case[obj_name] = column
|
276
264
|
else:
|
277
|
-
|
278
|
-
return
|
265
|
+
case[obj_name] = make_set(attr_value) if parent_is_iterable else attr_value
|
266
|
+
return case
|
279
267
|
|
280
268
|
|
281
|
-
def
|
269
|
+
def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, obj: Any, obj_name: Optional[str] = None,
|
282
270
|
recursion_idx: int = 0,
|
283
|
-
max_recursion_idx: int = 1) ->
|
271
|
+
max_recursion_idx: int = 1) -> CaseAttribute:
|
284
272
|
"""
|
285
|
-
Get a
|
273
|
+
Get a case attribute from an iterable attribute.
|
286
274
|
|
287
|
-
:param attr_value: The iterable attribute to get the
|
288
|
-
:param name: The name of the
|
275
|
+
:param attr_value: The iterable attribute to get the case from.
|
276
|
+
:param name: The name of the case.
|
289
277
|
:param obj: The parent object of the iterable.
|
290
278
|
:param obj_name: The parent object name.
|
291
279
|
:param recursion_idx: The recursion index to prevent infinite recursion.
|
292
280
|
:param max_recursion_idx: The maximum recursion index.
|
293
|
-
:return: A
|
281
|
+
:return: A case attribute that represents the original iterable attribute.
|
294
282
|
"""
|
295
283
|
values = attr_value.values() if isinstance(attr_value, (dict, UserDict)) else attr_value
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
if not range_:
|
300
|
-
raise ValueError(f"Could not determine the range of {name} in {obj}.")
|
301
|
-
attr_row = Row(id_=id(attr_value))
|
302
|
-
column = Column.from_obj(values, row_obj=obj)
|
284
|
+
_type = type(list(values)[0]) if len(values) > 0 else get_value_type_from_type_hint(name, obj)
|
285
|
+
attr_case = Case(_id=id(attr_value), _type=_type)
|
286
|
+
case_attr = CaseAttribute.from_obj(values, row_obj=obj)
|
303
287
|
for idx, val in enumerate(values):
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
for sub_attr, val in
|
309
|
-
setattr(
|
310
|
-
return
|
288
|
+
sub_attr_case = create_case(val, recursion_idx=recursion_idx,
|
289
|
+
max_recursion_idx=max_recursion_idx,
|
290
|
+
obj_name=obj_name, parent_is_iterable=True)
|
291
|
+
attr_case.update(sub_attr_case)
|
292
|
+
for sub_attr, val in attr_case.items():
|
293
|
+
setattr(case_attr, sub_attr, val)
|
294
|
+
return case_attr
|
311
295
|
|
312
296
|
|
313
|
-
def show_current_and_corner_cases(case: Any, targets: Optional[Union[List[
|
314
|
-
current_conclusions: Optional[Union[List[
|
297
|
+
def show_current_and_corner_cases(case: Any, targets: Optional[Union[List[CaseAttribute], List[SQLColumn]]] = None,
|
298
|
+
current_conclusions: Optional[Union[List[CaseAttribute], List[SQLColumn]]] = None,
|
315
299
|
last_evaluated_rule: Optional[Rule] = None) -> None:
|
316
300
|
"""
|
317
301
|
Show the data of the new case and if last evaluated rule exists also show that of the corner case.
|
@@ -351,6 +335,3 @@ def show_current_and_corner_cases(case: Any, targets: Optional[Union[List[Column
|
|
351
335
|
case_dict.update(targets)
|
352
336
|
case_dict.update(current_conclusions)
|
353
337
|
print(table_rows_as_str(case_dict))
|
354
|
-
|
355
|
-
|
356
|
-
Case = Row
|
@@ -5,7 +5,7 @@ from dataclasses import dataclass
|
|
5
5
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
6
6
|
from typing_extensions import Any, Optional, Type
|
7
7
|
|
8
|
-
from .
|
8
|
+
from .case import create_case, Case
|
9
9
|
from ..utils import get_attribute_name, copy_case
|
10
10
|
|
11
11
|
|
@@ -50,7 +50,7 @@ class CaseQuery:
|
|
50
50
|
self.attribute_name = attribute_name
|
51
51
|
|
52
52
|
if not isinstance(case, (Case, SQLTable)):
|
53
|
-
case =
|
53
|
+
case = create_case(case, max_recursion_idx=3)
|
54
54
|
self.case = case
|
55
55
|
|
56
56
|
self.attribute = getattr(self.case, self.attribute_name) if self.attribute_name else None
|
ripple_down_rules/experts.py
CHANGED
@@ -6,8 +6,8 @@ from abc import ABC, abstractmethod
|
|
6
6
|
from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColumn, Session
|
7
7
|
from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type, Union, Any
|
8
8
|
|
9
|
-
from .datastructures import (Case, PromptFor, CallableExpression,
|
10
|
-
from .datastructures.
|
9
|
+
from .datastructures import (Case, PromptFor, CallableExpression, CaseAttribute, CaseQuery)
|
10
|
+
from .datastructures.case import show_current_and_corner_cases
|
11
11
|
from .prompt import prompt_user_for_expression, prompt_user_about_case
|
12
12
|
from .utils import get_all_subclasses, is_iterable
|
13
13
|
|
@@ -30,13 +30,13 @@ class Expert(ABC):
|
|
30
30
|
"""
|
31
31
|
A flag to indicate if the expert should use loaded answers or not.
|
32
32
|
"""
|
33
|
-
known_categories: Optional[Dict[str, Type[
|
33
|
+
known_categories: Optional[Dict[str, Type[CaseAttribute]]] = None
|
34
34
|
"""
|
35
35
|
The known categories (i.e. Column types) to use.
|
36
36
|
"""
|
37
37
|
|
38
38
|
@abstractmethod
|
39
|
-
def ask_for_conditions(self, x: Case, targets: List[
|
39
|
+
def ask_for_conditions(self, x: Case, targets: List[CaseAttribute], last_evaluated_rule: Optional[Rule] = None) \
|
40
40
|
-> CallableExpression:
|
41
41
|
"""
|
42
42
|
Ask the expert to provide the differentiating features between two cases or unique features for a case
|
@@ -50,8 +50,8 @@ class Expert(ABC):
|
|
50
50
|
pass
|
51
51
|
|
52
52
|
@abstractmethod
|
53
|
-
def ask_for_extra_conclusions(self, x: Case, current_conclusions: List[
|
54
|
-
-> Dict[
|
53
|
+
def ask_for_extra_conclusions(self, x: Case, current_conclusions: List[CaseAttribute]) \
|
54
|
+
-> Dict[CaseAttribute, CallableExpression]:
|
55
55
|
"""
|
56
56
|
Ask the expert to provide extra conclusions for a case by providing a pair of category and conditions for
|
57
57
|
that category.
|
@@ -63,9 +63,9 @@ class Expert(ABC):
|
|
63
63
|
pass
|
64
64
|
|
65
65
|
@abstractmethod
|
66
|
-
def ask_if_conclusion_is_correct(self, x: Case, conclusion:
|
67
|
-
targets: Optional[List[
|
68
|
-
current_conclusions: Optional[List[
|
66
|
+
def ask_if_conclusion_is_correct(self, x: Case, conclusion: CaseAttribute,
|
67
|
+
targets: Optional[List[CaseAttribute]] = None,
|
68
|
+
current_conclusions: Optional[List[CaseAttribute]] = None) -> bool:
|
69
69
|
"""
|
70
70
|
Ask the expert if the conclusion is correct.
|
71
71
|
|
@@ -125,14 +125,14 @@ class Human(Expert):
|
|
125
125
|
self.all_expert_answers = json.load(f)
|
126
126
|
|
127
127
|
def ask_for_conditions(self, case: Case,
|
128
|
-
targets: Union[List[
|
128
|
+
targets: Union[List[CaseAttribute], List[SQLColumn]],
|
129
129
|
last_evaluated_rule: Optional[Rule] = None) \
|
130
130
|
-> CallableExpression:
|
131
131
|
if not self.use_loaded_answers:
|
132
132
|
show_current_and_corner_cases(case, targets, last_evaluated_rule=last_evaluated_rule)
|
133
133
|
return self._get_conditions(case, targets)
|
134
134
|
|
135
|
-
def _get_conditions(self, case: Case, targets: List[
|
135
|
+
def _get_conditions(self, case: Case, targets: List[CaseAttribute]) \
|
136
136
|
-> CallableExpression:
|
137
137
|
"""
|
138
138
|
Ask the expert to provide the differentiating features between two cases or unique features for a case
|
@@ -157,8 +157,8 @@ class Human(Expert):
|
|
157
157
|
self.all_expert_answers.append(user_input)
|
158
158
|
return condition
|
159
159
|
|
160
|
-
def ask_for_extra_conclusions(self, case: Case, current_conclusions: List[
|
161
|
-
-> Dict[
|
160
|
+
def ask_for_extra_conclusions(self, case: Case, current_conclusions: List[CaseAttribute]) \
|
161
|
+
-> Dict[CaseAttribute, CallableExpression]:
|
162
162
|
"""
|
163
163
|
Ask the expert to provide extra conclusions for a case by providing a pair of category and conditions for
|
164
164
|
that category.
|
@@ -198,7 +198,7 @@ class Human(Expert):
|
|
198
198
|
self.all_expert_answers.append(expert_input)
|
199
199
|
return expression
|
200
200
|
|
201
|
-
def get_category_type(self, cat_name: str) -> Optional[Type[
|
201
|
+
def get_category_type(self, cat_name: str) -> Optional[Type[CaseAttribute]]:
|
202
202
|
"""
|
203
203
|
Get the category type from the known categories.
|
204
204
|
|
@@ -206,8 +206,8 @@ class Human(Expert):
|
|
206
206
|
:return: The category type.
|
207
207
|
"""
|
208
208
|
cat_name = cat_name.lower()
|
209
|
-
self.known_categories = get_all_subclasses(
|
210
|
-
self.known_categories.update(
|
209
|
+
self.known_categories = get_all_subclasses(CaseAttribute) if not self.known_categories else self.known_categories
|
210
|
+
self.known_categories.update(CaseAttribute.registry)
|
211
211
|
category_type = None
|
212
212
|
if cat_name in self.known_categories:
|
213
213
|
category_type = self.known_categories[cat_name]
|
@@ -222,9 +222,9 @@ class Human(Expert):
|
|
222
222
|
question = f"Can a case have multiple values of the new category {category_name}? (y/n):"
|
223
223
|
return not self.ask_yes_no_question(question)
|
224
224
|
|
225
|
-
def ask_if_conclusion_is_correct(self, x: Case, conclusion:
|
226
|
-
targets: Optional[List[
|
227
|
-
current_conclusions: Optional[List[
|
225
|
+
def ask_if_conclusion_is_correct(self, x: Case, conclusion: CaseAttribute,
|
226
|
+
targets: Optional[List[CaseAttribute]] = None,
|
227
|
+
current_conclusions: Optional[List[CaseAttribute]] = None) -> bool:
|
228
228
|
"""
|
229
229
|
Ask the expert if the conclusion is correct.
|
230
230
|
|
ripple_down_rules/prompt.py
CHANGED
@@ -7,7 +7,7 @@ from prompt_toolkit.completion import WordCompleter
|
|
7
7
|
from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
|
8
8
|
from typing_extensions import Any, List, Optional, Tuple, Dict, Union, Type
|
9
9
|
|
10
|
-
from .datastructures import Case, PromptFor, CallableExpression,
|
10
|
+
from .datastructures import Case, PromptFor, CallableExpression, create_case, parse_string_to_expression
|
11
11
|
|
12
12
|
|
13
13
|
def prompt_user_for_expression(case: Union[Case, SQLTable], prompt_for: PromptFor, target_name: str,
|
@@ -60,7 +60,7 @@ def get_completions(obj: Any) -> List[str]:
|
|
60
60
|
# Define completer with all object attributes and comparison operators
|
61
61
|
completions = ['==', '!=', '>', '<', '>=', '<=', 'in', 'not', 'and', 'or', 'is']
|
62
62
|
completions += ["isinstance(", "issubclass(", "type(", "len(", "hasattr(", "getattr(", "setattr(", "delattr("]
|
63
|
-
completions += list(
|
63
|
+
completions += list(create_case(obj).keys())
|
64
64
|
return completions
|
65
65
|
|
66
66
|
|
ripple_down_rules/rdr.py
CHANGED
@@ -8,7 +8,7 @@ from ordered_set import OrderedSet
|
|
8
8
|
from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
|
9
9
|
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple
|
10
10
|
|
11
|
-
from .datastructures import Case, MCRDRMode, CallableExpression,
|
11
|
+
from .datastructures import Case, MCRDRMode, CallableExpression, CaseAttribute, CaseQuery
|
12
12
|
from .experts import Expert, Human
|
13
13
|
from .rules import Rule, SingleClassRule, MultiClassTopRule
|
14
14
|
from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
|
@@ -23,7 +23,7 @@ class RippleDownRules(ABC):
|
|
23
23
|
"""
|
24
24
|
The figure to draw the tree on.
|
25
25
|
"""
|
26
|
-
expert_accepted_conclusions: Optional[List[
|
26
|
+
expert_accepted_conclusions: Optional[List[CaseAttribute]] = None
|
27
27
|
"""
|
28
28
|
The conclusions that the expert has accepted, such that they are not asked again.
|
29
29
|
"""
|
@@ -37,11 +37,11 @@ class RippleDownRules(ABC):
|
|
37
37
|
self.session = session
|
38
38
|
self.fig: Optional[plt.Figure] = None
|
39
39
|
|
40
|
-
def __call__(self, case: Union[Case, SQLTable]) ->
|
40
|
+
def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
|
41
41
|
return self.classify(case)
|
42
42
|
|
43
43
|
@abstractmethod
|
44
|
-
def classify(self, case: Union[Case, SQLTable]) -> Optional[
|
44
|
+
def classify(self, case: Union[Case, SQLTable]) -> Optional[CaseAttribute]:
|
45
45
|
"""
|
46
46
|
Classify a case.
|
47
47
|
|
@@ -52,7 +52,7 @@ class RippleDownRules(ABC):
|
|
52
52
|
|
53
53
|
@abstractmethod
|
54
54
|
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
55
|
-
-> Union[
|
55
|
+
-> Union[CaseAttribute, CallableExpression]:
|
56
56
|
"""
|
57
57
|
Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
|
58
58
|
comparing the case with the target category.
|
@@ -118,7 +118,7 @@ class RippleDownRules(ABC):
|
|
118
118
|
plt.show()
|
119
119
|
|
120
120
|
@staticmethod
|
121
|
-
def calculate_precision_and_recall(pred_cat: List[
|
121
|
+
def calculate_precision_and_recall(pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> Tuple[List[bool], List[bool]]:
|
122
122
|
"""
|
123
123
|
:param pred_cat: The predicted category.
|
124
124
|
:param target: The target category.
|
@@ -131,7 +131,7 @@ class RippleDownRules(ABC):
|
|
131
131
|
precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
|
132
132
|
return precision, recall
|
133
133
|
|
134
|
-
def is_matching(self, pred_cat: List[
|
134
|
+
def is_matching(self, pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> bool:
|
135
135
|
"""
|
136
136
|
:param pred_cat: The predicted category.
|
137
137
|
:param target: The target category.
|
@@ -179,7 +179,7 @@ RDR = RippleDownRules
|
|
179
179
|
class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):
|
180
180
|
|
181
181
|
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
182
|
-
-> Union[
|
182
|
+
-> Union[CaseAttribute, CallableExpression]:
|
183
183
|
"""
|
184
184
|
Classify a case, and ask the user for refinements or alternatives if the classification is incorrect by
|
185
185
|
comparing the case with the target category if provided.
|
@@ -207,7 +207,7 @@ class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):
|
|
207
207
|
|
208
208
|
return self.classify(case)
|
209
209
|
|
210
|
-
def classify(self, case: Case) -> Optional[
|
210
|
+
def classify(self, case: Case) -> Optional[CaseAttribute]:
|
211
211
|
"""
|
212
212
|
Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
|
213
213
|
"""
|
@@ -225,18 +225,22 @@ class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):
|
|
225
225
|
"""
|
226
226
|
Write the tree of rules as source code to a file.
|
227
227
|
"""
|
228
|
-
|
229
|
-
|
228
|
+
case = self.start_rule.corner_case
|
229
|
+
if isinstance(case, Case):
|
230
|
+
case_type = case._type
|
231
|
+
else:
|
232
|
+
case_type = type(case)
|
233
|
+
case_module = case_type.__module__
|
230
234
|
conclusion = self.start_rule.conclusion
|
231
235
|
if isinstance(conclusion, CallableExpression):
|
232
236
|
conclusion_types = [conclusion.conclusion_type]
|
233
|
-
elif isinstance(conclusion,
|
237
|
+
elif isinstance(conclusion, CaseAttribute):
|
234
238
|
conclusion_types = list(conclusion._value_range)
|
235
239
|
else:
|
236
240
|
conclusion_types = [type(conclusion)]
|
237
241
|
imports = ""
|
238
242
|
if case_module != "builtins":
|
239
|
-
imports += f"from {case_module} import {case_type}\n"
|
243
|
+
imports += f"from {case_module} import {case_type.__name__}\n"
|
240
244
|
if len(conclusion_types) > 1:
|
241
245
|
conclusion_name = "Union[" + ", ".join([c.__name__ for c in conclusion_types]) + "]"
|
242
246
|
else:
|
@@ -244,11 +248,14 @@ class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):
|
|
244
248
|
for conclusion_type in conclusion_types:
|
245
249
|
if conclusion_type.__module__ != "builtins":
|
246
250
|
imports += f"from {conclusion_type.__module__} import {conclusion_name}\n"
|
251
|
+
imports += "from ripple_down_rules.datastructures import Case, create_case\n"
|
247
252
|
imports += "\n\n"
|
248
|
-
func_def = f"def classify_{conclusion_name.lower()}(case: {case_type}) -> {conclusion_name}:\n"
|
253
|
+
func_def = f"def classify_{conclusion_name.lower()}(case: {case_type.__name__}) -> {conclusion_name}:\n"
|
249
254
|
with open(filename, "w") as f:
|
250
255
|
f.write(imports)
|
251
256
|
f.write(func_def)
|
257
|
+
f.write(f"{' '*4}if not isinstance(case, Case):\n"
|
258
|
+
f"{' '*4} case = create_case(case, recursion_idx=3)\n""")
|
252
259
|
self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)
|
253
260
|
|
254
261
|
def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file, parent_indent: str = ""):
|
@@ -287,7 +294,7 @@ class MultiClassRDR(RippleDownRules):
|
|
287
294
|
"""
|
288
295
|
The evaluated rules in the classifier for one case.
|
289
296
|
"""
|
290
|
-
conclusions: Optional[List[
|
297
|
+
conclusions: Optional[List[CaseAttribute]] = None
|
291
298
|
"""
|
292
299
|
The conclusions that the case belongs to.
|
293
300
|
"""
|
@@ -320,7 +327,7 @@ class MultiClassRDR(RippleDownRules):
|
|
320
327
|
return self.conclusions
|
321
328
|
|
322
329
|
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None,
|
323
|
-
add_extra_conclusions: bool = False) -> List[Union[
|
330
|
+
add_extra_conclusions: bool = False) -> List[Union[CaseAttribute, CallableExpression]]:
|
324
331
|
"""
|
325
332
|
Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
|
326
333
|
or missing by comparing the case with the target category if provided.
|
@@ -448,7 +455,7 @@ class MultiClassRDR(RippleDownRules):
|
|
448
455
|
:param target: The target category to compare the conclusion with.
|
449
456
|
:return: Whether the conclusion is of the same class as the target category but has a different value.
|
450
457
|
"""
|
451
|
-
return conclusion.__class__ == target.__class__ and target.__class__ !=
|
458
|
+
return conclusion.__class__ == target.__class__ and target.__class__ != CaseAttribute
|
452
459
|
|
453
460
|
def conclusion_is_correct(self, case: Union[Case, SQLTable], target: Any, expert: Expert, evaluated_rule: Rule,
|
454
461
|
add_extra_conclusions: bool) -> bool:
|
@@ -597,7 +604,7 @@ class GeneralRDR(RippleDownRules):
|
|
597
604
|
return conclusions
|
598
605
|
|
599
606
|
def fit_case(self, case_queries: List[CaseQuery], expert: Optional[Expert] = None, **kwargs) \
|
600
|
-
-> List[Union[
|
607
|
+
-> List[Union[CaseAttribute, CallableExpression]]:
|
601
608
|
"""
|
602
609
|
Fit the GRDR on a case, if the target is a new type of category, a new RDR is created for it,
|
603
610
|
else the existing RDR of that type will be fitted on the case, and then classification is done and all
|
@@ -624,7 +631,7 @@ class GeneralRDR(RippleDownRules):
|
|
624
631
|
if not target:
|
625
632
|
target = expert.ask_for_conclusion(case_query)
|
626
633
|
case_query_cp = CaseQuery(case_cp, attribute_name=case_query.attribute_name, target=target)
|
627
|
-
if is_iterable(target) and not isinstance(target,
|
634
|
+
if is_iterable(target) and not isinstance(target, CaseAttribute):
|
628
635
|
target_type = type(make_list(target)[0])
|
629
636
|
assert all([type(t) is target_type for t in target]), ("All targets of a case query must be of the same"
|
630
637
|
" type")
|
@@ -659,7 +666,7 @@ class GeneralRDR(RippleDownRules):
|
|
659
666
|
return MultiClassRDR()
|
660
667
|
else:
|
661
668
|
return SingleClassRDR()
|
662
|
-
elif isinstance(attribute,
|
669
|
+
elif isinstance(attribute, CaseAttribute):
|
663
670
|
return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
|
664
671
|
else:
|
665
672
|
return MultiClassRDR() if is_iterable(attribute) else SingleClassRDR()
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.0.8
|
3
|
+
Version: 0.0.9
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -0,0 +1,18 @@
|
|
1
|
+
ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
+
ripple_down_rules/datasets.py,sha256=qYX7IF7ACm0VRbaKfEgQ32j0YbUUyt2GfGU5Lo42CqI,4601
|
3
|
+
ripple_down_rules/experts.py,sha256=wg1uY0ox9dUMR4s1RdGjzpX1_WUqnCa060r1U9lrKYI,11214
|
4
|
+
ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
|
5
|
+
ripple_down_rules/prompt.py,sha256=QAmxg4ssrGUAlK7lbyKw2nuRczTZColpjc9uMC1ts3I,4210
|
6
|
+
ripple_down_rules/rdr.py,sha256=IZNoZm6w_WPTlLkHEQmjWcDLZnyczgSK0zX5OFYl0Ys,34572
|
7
|
+
ripple_down_rules/rules.py,sha256=R9wIyalXBc6fGogzy92OxsyKAtJdQwdso_Vs1evWZGU,10274
|
8
|
+
ripple_down_rules/utils.py,sha256=pmjIPjF9wq6dmCHCoUKFRsJmSaI0b-hRLn375WHgWgQ,18010
|
9
|
+
ripple_down_rules/datastructures/__init__.py,sha256=zpmiYm4WkwNHaGdTIfacS7llN5d2xyU6U-saH_TpydI,103
|
10
|
+
ripple_down_rules/datastructures/callable_expression.py,sha256=TN6bi4VYjyLlSLTEA3dRo5ENfEdQYc8Fjj5nbnsz-C0,9058
|
11
|
+
ripple_down_rules/datastructures/case.py,sha256=mcBVu1IGNpLGyXb152oLrsv8UbNncIhPtjmSTUgJ7uc,13593
|
12
|
+
ripple_down_rules/datastructures/dataclasses.py,sha256=EVQ1jBKW7K7q7_JNgikHX9fm3EmQQKA74sNjEQ4rXn8,2449
|
13
|
+
ripple_down_rules/datastructures/enums.py,sha256=l0Eu-TeJ6qB2XHoJycXmUgLw-3yUebQ8SsEbW8bBZdM,4543
|
14
|
+
ripple_down_rules-0.0.9.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
15
|
+
ripple_down_rules-0.0.9.dist-info/METADATA,sha256=7ACb5X17EumZ1W19aIvx12FcjuovaQvmW4BE4Iccz7o,42518
|
16
|
+
ripple_down_rules-0.0.9.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
17
|
+
ripple_down_rules-0.0.9.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
18
|
+
ripple_down_rules-0.0.9.dist-info/RECORD,,
|
@@ -1,18 +0,0 @@
|
|
1
|
-
ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
ripple_down_rules/datasets.py,sha256=p7WmLiAKcYtzLq87RJhzU_BHCdM-zKK71yldjyYNpUE,4602
|
3
|
-
ripple_down_rules/experts.py,sha256=ajH46NXM6IREtWr2a-ZBxdKl4CFzSx5S7yktnyXgwnc,11089
|
4
|
-
ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
|
5
|
-
ripple_down_rules/prompt.py,sha256=lmREZRyleBTHrVtcf2j_48oc0v3VlxXYGhl6w1mk8qI,4208
|
6
|
-
ripple_down_rules/rdr.py,sha256=Jag1zq4UqCIJGL_DhLcbYu6l7Bvkskay-mnJ5KRF34g,34131
|
7
|
-
ripple_down_rules/rules.py,sha256=R9wIyalXBc6fGogzy92OxsyKAtJdQwdso_Vs1evWZGU,10274
|
8
|
-
ripple_down_rules/utils.py,sha256=pmjIPjF9wq6dmCHCoUKFRsJmSaI0b-hRLn375WHgWgQ,18010
|
9
|
-
ripple_down_rules/datastructures/__init__.py,sha256=wY9WqXavuE3wQ1YP65cs_SZyr7CEMB9tol-4oxgK9CM,104
|
10
|
-
ripple_down_rules/datastructures/callable_expression.py,sha256=lF8ag3-8oYuETAUohtRfviOpypSMQR-JEJGgR0Zhv3E,9055
|
11
|
-
ripple_down_rules/datastructures/dataclasses.py,sha256=FW3MMhGXTPa0XwIyLzGalrPwiltNeUqbGIawCAKDHGk,2448
|
12
|
-
ripple_down_rules/datastructures/enums.py,sha256=l0Eu-TeJ6qB2XHoJycXmUgLw-3yUebQ8SsEbW8bBZdM,4543
|
13
|
-
ripple_down_rules/datastructures/table.py,sha256=G2ShfPmyiwfDXS-JY_jFD6nOx7OLuMr-GP--OCDGKYc,13733
|
14
|
-
ripple_down_rules-0.0.8.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
15
|
-
ripple_down_rules-0.0.8.dist-info/METADATA,sha256=8uoMHOwdTv7D_NqbYZbvFGw7Awn5rn0xTHf_B43hNsw,42518
|
16
|
-
ripple_down_rules-0.0.8.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
17
|
-
ripple_down_rules-0.0.8.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
18
|
-
ripple_down_rules-0.0.8.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|