ripple-down-rules 0.0.5-py3-none-any.whl → 0.0.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datasets.py +4 -3
- ripple_down_rules/datastructures/callable_expression.py +2 -3
- ripple_down_rules/datastructures/dataclasses.py +2 -3
- ripple_down_rules/datastructures/enums.py +11 -2
- ripple_down_rules/datastructures/table.py +16 -209
- ripple_down_rules/experts.py +2 -28
- ripple_down_rules/rdr.py +37 -25
- ripple_down_rules/rules.py +2 -3
- ripple_down_rules/utils.py +27 -15
- {ripple_down_rules-0.0.5.dist-info → ripple_down_rules-0.0.7.dist-info}/METADATA +1 -1
- ripple_down_rules-0.0.7.dist-info/RECORD +18 -0
- ripple_down_rules/datastructures/generated/__init__.py +0 -0
- ripple_down_rules/datastructures/generated/column/__init__.py +0 -0
- ripple_down_rules/datastructures/generated/row/__init__.py +0 -0
- ripple_down_rules-0.0.5.dist-info/RECORD +0 -21
- {ripple_down_rules-0.0.5.dist-info → ripple_down_rules-0.0.7.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.0.5.dist-info → ripple_down_rules-0.0.7.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.0.5.dist-info → ripple_down_rules-0.0.7.dist-info}/top_level.txt +0 -0
ripple_down_rules/datasets.py
CHANGED
@@ -81,7 +81,8 @@ def load_zoo_dataset(cache_file: Optional[str] = None) -> Tuple[List[Case], List
 
     category_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "molusc"]
     category_id_to_name = {i + 1: name for i, name in enumerate(category_names)}
-    targets = [getattr(SpeciesCol, category_id_to_name[i]) for i in y.values.flatten()]
+    # targets = [getattr(SpeciesCol, category_id_to_name[i]) for i in y.values.flatten()]
+    targets = [Species.from_str(category_id_to_name[i]) for i in y.values.flatten()]
     return all_cases, targets
 
 
@@ -104,8 +105,8 @@ class Habitat(Category):
     air = "air"
 
 
-SpeciesCol = Column.create_from_enum(Species, mutually_exclusive=True)
-HabitatCol = Column.create_from_enum(Habitat, mutually_exclusive=False)
+# SpeciesCol = Column.create_from_enum(Species, mutually_exclusive=True)
+# HabitatCol = Column.create_from_enum(Habitat, mutually_exclusive=False)
 
 
 class Base(sqlalchemy.orm.DeclarativeBase):
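For context, a minimal sketch of why the new `targets` line works: `Category` subclasses such as `Species` are plain string-backed enums, so `from_str` can resolve a category name directly to an enum member, with no generated `SpeciesCol` class in between. The `from_str` body below is an assumption (the diff only shows its signature), and only two of the zoo categories are included.

```python
from enum import Enum


class Species(str, Enum):
    # Two of the zoo categories listed in datasets.py, for illustration only.
    mammal = "mammal"
    bird = "bird"

    @classmethod
    def from_str(cls, value: str) -> "Species":
        # Assumed implementation: members are keyed by their string value.
        return cls(value)


category_id_to_name = {1: "mammal", 2: "bird"}
targets = [Species.from_str(category_id_to_name[i]) for i in [1, 2, 1]]
print(targets)  # -> [Species.mammal, Species.bird, Species.mammal]
```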
ripple_down_rules/datastructures/callable_expression.py
CHANGED
@@ -162,9 +162,8 @@ class CallableExpression(SubclassJSONSerializer):
             prev_e = e
         return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else user_input
 
-    def
-        return {
-            "user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type)}
+    def _to_json(self) -> Dict[str, Any]:
+        return {"user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type)}
 
     @classmethod
     def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
ripple_down_rules/datastructures/dataclasses.py
CHANGED
@@ -1,13 +1,12 @@
 from __future__ import annotations
 
-from copy import copy, deepcopy
 from dataclasses import dataclass
 
 from sqlalchemy.orm import DeclarativeBase as SQLTable
-from typing_extensions import Any, Optional, Type
+from typing_extensions import Any, Optional, Type
 
 from .table import create_row, Case
-from ..utils import get_attribute_name,
+from ..utils import get_attribute_name, copy_case
 
 
 @dataclass
ripple_down_rules/datastructures/enums.py
CHANGED
@@ -2,10 +2,12 @@ from __future__ import annotations
 
 from enum import auto, Enum
 
-from typing_extensions import List
+from typing_extensions import List, Dict, Any
 
+from ripple_down_rules.utils import SubclassJSONSerializer
 
-class Category(str, Enum):
+
+class Category(str, SubclassJSONSerializer, Enum):
 
     @classmethod
     def from_str(cls, value: str) -> Category:
@@ -19,6 +21,13 @@ class Category(str, Enum):
     def as_dict(self):
         return {self.__class__.__name__.lower(): self.value}
 
+    def _to_json(self) -> Dict[str, Any]:
+        return self.as_dict
+
+    @classmethod
+    def _from_json(cls, data: Dict[str, Any]) -> Category:
+        return cls.from_str(data[cls.__name__.lower()])
+
 
 class Stop(Category):
     """
ripple_down_rules/datastructures/table.py
CHANGED
@@ -21,165 +21,7 @@ if TYPE_CHECKING:
     from .callable_expression import CallableExpression
 
 
-class
-    """
-    A custom set class that is used to add other attributes to the set. This is similar to a table where the set is the
-    table, the attributes are the columns, and the values are the rows.
-    """
-    _value_range: set
-    """
-    The range of the attribute, this can be a set of possible values or a range of numeric values (int, float).
-    """
-    _registry: Dict[(str, type), Type[SubClassFactory]] = {}
-    """
-    A dictionary of all dynamically created subclasses of this class.
-    """
-    _generated_classes_dir: str = os.path.dirname(os.path.abspath(__file__)) + "/generated"
-
-    @classmethod
-    def create(cls, name: str, range_: set, class_attributes: Optional[Dict[str, Any]] = None,
-               default_values: bool = True,
-               attributes_type_hints: Optional[Dict[str, Type]] = None) -> Type[SubClassFactory]:
-        """
-        Create a new subclass.
-
-        :param name: The name of the subclass.
-        :param range_: The range of the subclass values.
-        :param class_attributes: The attributes of the new subclass.
-        :param default_values: Boolean indicating whether to add default values to the subclass attributes or not.
-        :param attributes_type_hints: The type hints of the subclass attributes.
-        :return: The new subclass.
-        """
-        existing_class = cls._get_and_update_subclass(name, range_)
-        if existing_class:
-            return existing_class
-
-        new_attribute_type = cls._create_class_in_new_python_file_and_import_it(name, range_, default_values,
-                                                                                class_attributes, attributes_type_hints)
-
-        cls.register(new_attribute_type)
-
-        return new_attribute_type
-
-    @classmethod
-    def _create_class_in_new_python_file_and_import_it(cls, name: str, range_: set, default_values: bool = True,
-                                                       class_attributes: Optional[Dict[str, Any]] = None,
-                                                       attributes_type_hints: Optional[Dict[str, Type]] = None)\
-            -> Type[SubClassFactory]:
-        def get_type_import(value_type: Any) -> Tuple[str, str]:
-            if value_type is type(None):
-                return "from types import NoneType\n", "NoneType"
-            elif value_type.__module__ != "builtins":
-                value_type_alias = f"{value_type.__name__}_"
-                return f"from {value_type.__module__} import {value_type.__name__} as {value_type_alias}\n", value_type_alias
-            else:
-                return "", value_type.__name__
-        attributes_type_hints = attributes_type_hints or {}
-        parent_class_alias = cls.__name__ + "_"
-        imports = f"from {cls.__module__} import {cls.__name__} as {parent_class_alias}\n"
-        class_code = f"class {name}({parent_class_alias}):\n"
-        class_attributes = copy(class_attributes) if class_attributes else {}
-        class_attributes.update({"_value_range": range_})
-        for key, value in class_attributes.items():
-            if value is not None:
-                new_import, value_type_name = get_type_import(type(value))
-            elif key in attributes_type_hints:
-                new_import, value_type_name = get_type_import(attributes_type_hints[key])
-            else:
-                new_import, value_type_name = "from typing_extensions import Any", "Any"
-            imports += new_import
-            if isinstance(value, set):
-                value_names = []
-                for v in value:
-                    if isinstance(v, type):
-                        new_import, v_name = get_type_import(v)
-                        imports += new_import
-                    else:
-                        v_name = str(v)
-                    value_names.append(v_name)
-                value_str = ", ".join(value_names)
-                new_value = "{" + value_str + "}"
-            elif isinstance(value, type):
-                new_import, value_name = get_type_import(value)
-                new_value = value_name
-                value_type_name = value_name
-            else:
-                new_value = value
-            if default_values or key == "_value_range":
-                class_code += f" {key}: {value_type_name} = {new_value}\n"
-            else:
-                class_code += f" {key}: {value_type_name}\n"
-        imports += "\n\n"
-        if issubclass(cls, Row):
-            folder_name = "row"
-        elif issubclass(cls, Column):
-            folder_name = "column"
-        else:
-            raise ValueError(f"Unknown class {cls}.")
-        # write the code to a file
-        with open(f"{cls._generated_classes_dir}/{folder_name}/{name.lower()}.py", "w") as f:
-            f.write(imports + class_code)
-
-        # import the class from the file
-        import_path = ".".join(cls.__module__.split(".")[:-1] + ["generated", folder_name, name.lower()])
-        time.sleep(0.3)
-        return __import__(import_path, fromlist=[name.lower()]).__dict__[name]
-
-    @classmethod
-    def _get_and_update_subclass(cls, name: str, range_: set) -> Optional[Type[SubClassFactory]]:
-        """
-        Get a subclass of the attribute class and update its range if necessary.
-
-        :param name: The name of the column.
-        :param range_: The range of the column values.
-        """
-        key = (name.lower(), cls)
-        if key in cls._registry:
-            if not cls._registry[key].is_within_range(range_):
-                if isinstance(cls._registry[key]._value_range, set):
-                    cls._registry[key]._value_range.update(range_)
-                else:
-                    raise ValueError(f"Range of {key} is different from {cls._registry[key]._value_range}.")
-            return cls._registry[key]
-
-    @classmethod
-    def register(cls, subclass: Type[SubClassFactory]):
-        """
-        Register a subclass of the attribute class, this is used to be able to dynamically create Attribute subclasses.
-
-        :param subclass: The subclass to register.
-        """
-        if not issubclass(subclass, SubClassFactory):
-            raise ValueError(f"{subclass} is not a subclass of CustomSet.")
-        if subclass not in cls._registry:
-            cls._registry[(subclass.__name__.lower(), cls)] = subclass
-        else:
-            raise ValueError(f"{subclass} is already registered.")
-
-    @classmethod
-    def is_within_range(cls, value: Any) -> bool:
-        """
-        Check if a value is within the range of the custom set.
-
-        :param value: The value to check.
-        :return: Boolean indicating whether the value is within the range or not.
-        """
-        if hasattr(value, "__iter__") and not isinstance(value, str):
-            if all(isinstance(val_range, type) and isinstance(v, val_range)
-                   for v in value for val_range in cls._value_range):
-                return True
-            else:
-                return set(value).issubset(cls._value_range)
-        elif isinstance(value, str):
-            return value.lower() in cls._value_range
-        else:
-            return value in cls._value_range
-
-    def __instancecheck__(self, instance):
-        return isinstance(instance, (SubClassFactory, *self._value_range))
-
-
-class Row(UserDict, SubClassFactory, SubclassJSONSerializer):
+class Row(UserDict, SubclassJSONSerializer):
     """
     A collection of attributes that represents a set of constraints on a case. This is a dictionary where the keys are
     the names of the attributes and the values are the attributes. All are stored in lower case.
@@ -193,7 +35,7 @@ class Row(UserDict, SubClassFactory, SubclassJSONSerializer):
         :param kwargs: The attributes of the row.
         """
         super().__init__(kwargs)
-        self.
+        self.id_ = id_ if id_ else id(self)
 
     @classmethod
     def from_obj(cls, obj: Any, obj_name: Optional[str] = None, max_recursion_idx: int = 3) -> Row:
@@ -219,10 +61,10 @@ class Row(UserDict, SubClassFactory, SubclassJSONSerializer):
                 value.update(make_set(self[name]))
                 super().__setitem__(name, value)
             else:
-                super().__setitem__(name, make_set(self[name]))
+                super().__setitem__(name, make_set([self[name], value]))
         else:
-            setattr(self, name, value)
             super().__setitem__(name, value)
+            setattr(self, name, self[name])
 
     def __contains__(self, item):
         if isinstance(item, (type, Enum)):
@@ -241,16 +83,12 @@ class Row(UserDict, SubClassFactory, SubclassJSONSerializer):
         return super().__eq__(other)
 
     def __hash__(self):
-        return
+        return self.id_
 
-    def
-        return isinstance(instance, (dict, UserDict, Row)) or super().__instancecheck__(instance)
-
-    def to_json(self) -> Dict[str, Any]:
+    def _to_json(self) -> Dict[str, Any]:
         serializable = {k: v for k, v in self.items() if not k.startswith("_")}
-        serializable["_id"] = self.
-        return {
-            **{k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in serializable.items()}}
+        serializable["_id"] = self.id_
+        return {k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in serializable.items()}
 
     @classmethod
     def _from_json(cls, data: Dict[str, Any]) -> Row:
@@ -280,16 +118,15 @@ class ColumnValue(SubclassJSONSerializer):
     def __hash__(self):
        return self.id
 
-    def
-        return {
-            "id": self.id, "value": self.value}
+    def _to_json(self) -> Dict[str, Any]:
+        return {"id": self.id, "value": self.value}
 
     @classmethod
     def _from_json(cls, data: Dict[str, Any]) -> ColumnValue:
         return cls(id=data["id"], value=data["value"])
 
 
-class Column(set, SubClassFactory, SubclassJSONSerializer):
+class Column(set, SubclassJSONSerializer):
     nullable: bool = True
     """
     A boolean indicating whether the column can be None or not.
@@ -321,21 +158,6 @@ class Column(set, SubClassFactory, SubclassJSONSerializer):
         values = {ColumnValue(id(values), v) for v in values}
         return values
 
-    @classmethod
-    def create(cls, name: str, range_: set,
-               nullable: bool = True, mutually_exclusive: bool = False) -> Type[SubClassFactory]:
-        return super().create(name, range_, {"nullable": nullable, "mutually_exclusive": mutually_exclusive})
-
-    @classmethod
-    def create_from_enum(cls, category: Type[Enum], nullable: bool = True,
-                         mutually_exclusive: bool = False) -> Type[SubClassFactory]:
-        new_cls = cls.create(category.__name__.lower(), {category}, nullable=nullable,
-                             mutually_exclusive=mutually_exclusive)
-        for value in category:
-            value_column = cls.create(category.__name__.lower(), {value}, mutually_exclusive=mutually_exclusive)(value)
-            setattr(new_cls, value.name, value_column)
-        return new_cls
-
     @classmethod
     def from_obj(cls, values: Set[Any], row_obj: Optional[Any] = None) -> Column:
         id_ = id(row_obj) if row_obj else id(values)
@@ -373,12 +195,9 @@ class Column(set, SubClassFactory, SubclassJSONSerializer):
             return "None"
         return str({v for v in self}) if len(self) > 1 else str(next(iter(self)))
 
-    def
-        return
-
-    def to_json(self) -> Dict[str, Any]:
-        return {**SubclassJSONSerializer.to_json(self),
-                **{id_: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for id_, v in self.id_value_map.items()}}
+    def _to_json(self) -> Dict[str, Any]:
+        return {id_: v.to_json() if isinstance(v, SubclassJSONSerializer) else v
+                for id_, v in self.id_value_map.items()}
 
     @classmethod
     def _from_json(cls, data: Dict[str, Any]) -> Column:
@@ -397,8 +216,7 @@ def create_rows_from_dataframe(df: DataFrame, name: Optional[str] = None) -> Lis
     col_names = list(df.columns)
     for row_id, row in df.iterrows():
         row = {col_name: row[col_name].item() for col_name in col_names}
-
-        rows.append(row_cls(id_=row_id, **row))
+        rows.append(Row(id_=row_id, **row))
     return rows
 
 
@@ -420,17 +238,12 @@ def create_row(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
             or (obj.__class__ in [MetaData, registry])):
         return Row(id_=id(obj), **{obj_name or obj.__class__.__name__: make_set(obj) if parent_is_iterable else obj})
     row = Row(id_=id(obj))
-    attributes_type_hints = {}
     for attr in dir(obj):
         if attr.startswith("_") or callable(getattr(obj, attr)):
             continue
         attr_value = getattr(obj, attr)
         row = create_or_update_row_from_attribute(attr_value, attr, obj, attr, recursion_idx,
                                                   max_recursion_idx, parent_is_iterable, row)
-        attributes_type_hints[attr] = get_value_type_from_type_hint(attr, obj)
-    row_cls = Row.create(obj_name or obj.__class__.__name__, make_set(type(obj)), row, default_values=False,
-                         attributes_type_hints=attributes_type_hints)
-    row = row_cls(id_=id(obj), **row)
     return row
 
 
@@ -462,9 +275,6 @@ def create_or_update_row_from_attribute(attr_value: Any, name: str, obj: Any, ob
         row[obj_name] = column
     else:
         row[obj_name] = make_set(attr_value) if parent_is_iterable else attr_value
-    if row.__class__.__name__ == "Row":
-        row_cls = Row.create(obj_name or obj.__class__.__name__, make_set(type(obj)), row, default_values=False)
-        row = row_cls(id_=id(obj), **row)
     return row
 
 
@@ -489,15 +299,12 @@ def create_column_and_row_from_iterable_attribute(attr_value: Any, name: str, ob
     if not range_:
         raise ValueError(f"Could not determine the range of {name} in {obj}.")
     attr_row = Row(id_=id(attr_value))
-    column = Column.
-    attributes_type_hints = {}
+    column = Column.from_obj(values, row_obj=obj)
    for idx, val in enumerate(values):
         sub_attr_row = create_row(val, recursion_idx=recursion_idx,
                                   max_recursion_idx=max_recursion_idx,
                                   obj_name=obj_name, parent_is_iterable=True)
         attr_row.update(sub_attr_row)
-    # attr_row_cls = Row.create(name or list(range_)[0].__name__, range_, attr_row, default_values=False)
-    # attr_row = attr_row_cls(id_=id(attr_value), **attr_row)
     for sub_attr, val in attr_row.items():
         setattr(column, sub_attr, val)
     return column, attr_row
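A rough sketch (assumptions, not the packaged class) of the `__setitem__` behaviour change visible in the `Row` diff above: re-assigning an existing key now merges the old and new values into one set, where 0.0.5 effectively kept only the previously stored value. `MergingRow` and its merge logic are illustrative stand-ins; only `make_set` mirrors the helper used by the library.

```python
from collections import UserDict
from typing import Any, Set


def make_set(value: Any) -> Set:
    # Mirrors the make_set helper used throughout the diff.
    iterable = hasattr(value, "__iter__") and not isinstance(value, (str, type))
    return set(value) if iterable else {value}


class MergingRow(UserDict):
    """Toy stand-in for Row, showing only the merge-on-reassignment idea."""

    def __setitem__(self, name: str, value: Any):
        if name in self:
            # 0.0.5 stored make_set(self[name]) here and dropped `value`;
            # 0.0.7 keeps both, in the spirit of make_set([self[name], value]).
            merged = make_set(self[name])
            merged.update(make_set(value))
            super().__setitem__(name, merged)
        else:
            super().__setitem__(name, value)


row = MergingRow()
row["habitat"] = "land"
row["habitat"] = "water"
print(row["habitat"])  # {'land', 'water'} (set order may vary)
```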
ripple_down_rules/experts.py
CHANGED
@@ -9,7 +9,7 @@ from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type,
 from .datastructures import (Case, PromptFor, CallableExpression, Column, CaseQuery)
 from .datastructures.table import show_current_and_corner_cases
 from .prompt import prompt_user_for_expression, prompt_user_about_case
-from .utils import get_all_subclasses
+from .utils import get_all_subclasses, is_iterable
 
 if TYPE_CHECKING:
     from .rdr import Rule
@@ -125,7 +125,7 @@ class Human(Expert):
             self.all_expert_answers = json.load(f)
 
     def ask_for_conditions(self, case: Case,
-                           targets: Union[List[Column], List[
+                           targets: Union[List[Column], List[SQLColumn]],
                            last_evaluated_rule: Optional[Rule] = None) \
             -> CallableExpression:
         if not self.use_loaded_answers:
@@ -198,19 +198,6 @@ class Human(Expert):
         self.all_expert_answers.append(expert_input)
         return expression
 
-    def create_category_instance(self, cat_name: str, cat_value: Union[str, int, float, set]) -> Column:
-        """
-        Create a new category instance.
-
-        :param cat_name: The name of the category.
-        :param cat_value: The value of the category.
-        :return: A new instance of the category.
-        """
-        category_type = self.get_category_type(cat_name)
-        if not category_type:
-            category_type = self.create_new_category_type(cat_name)
-        return category_type(cat_value)
-
     def get_category_type(self, cat_name: str) -> Optional[Type[Column]]:
         """
         Get the category type from the known categories.
@@ -226,19 +213,6 @@ class Human(Expert):
         category_type = self.known_categories[cat_name]
         return category_type
 
-    def create_new_category_type(self, cat_name: str) -> Type[Column]:
-        """
-        Create a new category type.
-
-        :param cat_name: The name of the category.
-        :return: A new category type.
-        """
-        if self.ask_if_category_is_mutually_exclusive(cat_name):
-            category_type: Type[Column] = Column.create(cat_name, set(), mutually_exclusive=True)
-        else:
-            category_type: Type[Column] = Column.create(cat_name, set())
-        return category_type
-
     def ask_if_category_is_mutually_exclusive(self, category_name: str) -> bool:
         """
         Ask the expert if the new category can have multiple values.
ripple_down_rules/rdr.py
CHANGED
@@ -12,7 +12,7 @@ from .datastructures import Case, MCRDRMode, CallableExpression, Column, CaseQue
 from .experts import Expert, Human
 from .rules import Rule, SingleClassRule, MultiClassTopRule
 from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
-    get_hint_for_attribute, SubclassJSONSerializer
+    get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list
 
 
 class RippleDownRules(ABC):
@@ -124,10 +124,12 @@ class RippleDownRules(ABC):
         :param target: The target category.
         :return: The precision and recall of the classifier.
         """
-        pred_cat = pred_cat if
-        target = target if
+        pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
+        target = target if is_iterable(target) else [target]
         recall = [not yi or (yi in pred_cat) for yi in target]
         target_types = [type(yi) for yi in target]
+        if len(pred_cat) > 1:
+            print(pred_cat)
         precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
         return precision, recall
 
@@ -265,8 +267,8 @@ class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):
         if rule.alternative:
             self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
 
-    def
-        return {
+    def _to_json(self) -> Dict[str, Any]:
+        return {"start_rule": self.start_rule.to_json()}
 
     @classmethod
     def _from_json(cls, data: Dict[str, Any]) -> Self:
@@ -358,7 +360,7 @@ class MultiClassRDR(RippleDownRules):
                 self.add_conclusion(evaluated_rule)
 
         if not next_rule:
-            if
+            if not make_set(target).intersection(make_set(self.conclusions)):
                 # Nothing fired and there is a target that should have been in the conclusions
                 self.add_rule_for_case(case, target, expert)
                 # Have to check all rules again to make sure only this new rule fires
@@ -509,16 +511,16 @@ class MultiClassRDR(RippleDownRules):
         """
         conclusion_types = [type(c) for c in self.conclusions]
         if type(evaluated_rule.conclusion) not in conclusion_types:
-            self.conclusions.
+            self.conclusions.extend(make_list(evaluated_rule.conclusion))
         else:
             same_type_conclusions = [c for c in self.conclusions if type(c) == type(evaluated_rule.conclusion)]
             combined_conclusion = evaluated_rule.conclusion if isinstance(evaluated_rule.conclusion, set) \
                 else {evaluated_rule.conclusion}
-            combined_conclusion =
+            combined_conclusion = copy(combined_conclusion)
             for c in same_type_conclusions:
                 combined_conclusion.update(c if isinstance(c, set) else make_set(c))
                 self.conclusions.remove(c)
-            self.conclusions.extend(combined_conclusion)
+            self.conclusions.extend(make_list(combined_conclusion))
 
     def add_top_rule(self, conditions: CallableExpression, conclusion: Any, corner_case: Union[Case, SQLTable]):
         """
@@ -587,11 +589,11 @@ class GeneralRDR(RippleDownRules):
                     continue
                 pred_atts = rdr.classify(case_cp)
                 if pred_atts:
-                    pred_atts =
+                    pred_atts = make_list(pred_atts)
                    pred_atts = [p for p in pred_atts if p not in conclusions]
                     added_attributes = True
                     conclusions.extend(pred_atts)
-
+                    GeneralRDR.update_case(case_cp, pred_atts)
             if not added_attributes:
                 break
         return conclusions
@@ -624,21 +626,27 @@ class GeneralRDR(RippleDownRules):
         if not target:
             target = expert.ask_for_conclusion(case_query)
         case_query_cp = CaseQuery(case_cp, attribute_name=case_query.attribute_name, target=target)
-        if
+        if is_iterable(target) and not isinstance(target, Column):
+            target_type = type(make_list(target)[0])
+            assert all([type(t) is target_type for t in target]), ("All targets of a case query must be of the same"
+                                                                   " type")
+        else:
+            target_type = type(target)
+        if target_type not in self.start_rules_dict:
             conclusions = self.classify(case)
-            self.
+            self.update_case(case_cp, conclusions)
             new_rdr = self.initialize_new_rdr_for_attribute(target, case_cp)
             new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
-            self.start_rules_dict[
-            self.
-        elif not self.case_has_conclusion(case_cp,
+            self.start_rules_dict[target_type] = new_rdr
+            self.update_case(case_cp, new_conclusions, target_type)
+        elif not self.case_has_conclusion(case_cp, target_type):
             for rdr_type, rdr in self.start_rules_dict.items():
-                if
+                if target_type is not rdr_type:
                     conclusions = rdr.classify(case_cp)
                 else:
-                    conclusions = self.start_rules_dict[
+                    conclusions = self.start_rules_dict[target_type].fit_case(case_query_cp,
                                                                               expert, **kwargs)
-                self.
+                self.update_case(case_cp, conclusions, rdr_type)
 
         return self.classify(case)
 
@@ -653,12 +661,14 @@ class GeneralRDR(RippleDownRules):
                 return MultiClassRDR()
             else:
                 return SingleClassRDR()
-
+        elif isinstance(attribute, Column):
             return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
+        else:
+            return MultiClassRDR() if is_iterable(attribute) else SingleClassRDR()
 
     @staticmethod
-    def
-
+    def update_case(case: Union[Case, SQLTable],
+                    conclusions: List[Any], attribute_type: Optional[Any] = None):
         """
         Update the case with the conclusions.
 
@@ -668,7 +678,7 @@ class GeneralRDR(RippleDownRules):
         """
         if not conclusions:
             return
-        conclusions = [conclusions] if not isinstance(conclusions, list) else conclusions
+        conclusions = [conclusions] if not isinstance(conclusions, list) else list(conclusions)
         if len(conclusions) == 0:
             return
         if isinstance(case, SQLTable):
@@ -677,7 +687,8 @@ class GeneralRDR(RippleDownRules):
             hint, origin, args = get_hint_for_attribute(attr_name, case)
             if isinstance(attribute, set) or origin == set:
                 attribute = set() if attribute is None else attribute
-
+                for c in conclusions:
+                    attribute.update(make_set(c))
             elif isinstance(attribute, list) or origin == list:
                 attribute = [] if attribute is None else attribute
                 attribute.extend(conclusions)
@@ -686,7 +697,8 @@ class GeneralRDR(RippleDownRules):
                 else:
                     raise ValueError(f"Cannot add multiple conclusions to attribute {attr_name}")
         else:
-
+            for c in make_set(conclusions):
+                case.update(c.as_dict)
 
     @property
     def names_of_all_types(self) -> List[str]:
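The precision/recall hunk above is easiest to read outside the class. Below is a hedged, standalone sketch of what the 0.0.7 logic computes; the enclosing method name is not shown in the diff, so `precision_and_recall` is an assumed name, and `is_iterable` is copied from the utils.py change.

```python
from typing import Any, List, Tuple


def is_iterable(obj: Any) -> bool:
    # Same test the diff adds to utils.py.
    return hasattr(obj, "__iter__") and not isinstance(obj, (str, type))


def precision_and_recall(pred_cat: Any, target: Any) -> Tuple[List[bool], List[bool]]:
    # 0.0.7 wraps non-iterable inputs instead of assuming lists.
    pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
    target = target if is_iterable(target) else [target]
    # Recall: every (non-empty) target value should appear among the predictions.
    recall = [not yi or (yi in pred_cat) for yi in target]
    target_types = [type(yi) for yi in target]
    # Precision: a prediction only counts against us if its type was targeted.
    precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
    return precision, recall


print(precision_and_recall("mammal", "mammal"))        # ([True], [True])
print(precision_and_recall({"fish", "bird"}, "bird"))  # one False in precision (the un-targeted "fish")
```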
ripple_down_rules/rules.py
CHANGED
@@ -195,9 +195,8 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
         if_clause = "elif" if self.weight == RDREdge.Alternative.value else "if"
         return f"{parent_indent}{if_clause} {self.conditions.parsed_user_input}:\n"
 
-    def
-        self.json_serialization = {
-            "conditions": self.conditions.to_json(),
+    def _to_json(self) -> Dict[str, Any]:
+        self.json_serialization = {"conditions": self.conditions.to_json(),
                                    "conclusion": self.conclusion.to_json(),
                                    "parent": self.parent.json_serialization if self.parent else None,
                                    "corner_case": self.corner_case.to_json() if self.corner_case else None,
ripple_down_rules/utils.py
CHANGED
@@ -24,6 +24,24 @@ if TYPE_CHECKING:
 matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
 
 
+def make_list(value: Any) -> List:
+    """
+    Make a list from a value.
+
+    :param value: The value to make a list from.
+    """
+    return list(value) if is_iterable(value) else [value]
+
+
+def is_iterable(obj: Any) -> bool:
+    """
+    Check if an object is iterable.
+
+    :param obj: The object to check.
+    """
+    return hasattr(obj, "__iter__") and not isinstance(obj, (str, type))
+
+
 def get_type_from_string(type_path: str):
     """
     Get a type from a string describing its path using the format "module_path.ClassName".
@@ -64,7 +82,14 @@ class SubclassJSONSerializer:
     """
 
     def to_json(self) -> Dict[str, Any]:
-        return {"_type": get_full_class_name(self.__class__)}
+        return {"_type": get_full_class_name(self.__class__), **self._to_json()}
+
+    @abstractmethod
+    def _to_json(self) -> Dict[str, Any]:
+        """
+        Create a json dict from the object.
+        """
+        pass
 
     @classmethod
     @abstractmethod
@@ -350,20 +375,7 @@ def make_set(value: Any) -> Set:
 
     :param value: The value to make a set from.
     """
-
-        return set(value)
-    return {value}
-
-
-def make_list(value: Any) -> List:
-    """
-    Make a list from a value.
-
-    :param value: The value to make a list from.
-    """
-    if hasattr(value, "__iter__") and not isinstance(value, (str, type)):
-        return list(value)
-    return [value]
+    return set(value) if is_iterable(value) else {value}
 
 
 def make_value_or_raise_error(value: Any) -> Any:
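To see what the new `to_json`/`_to_json` split buys, here is a minimal, self-contained sketch (not the packaged implementation; `get_full_class_name` is re-implemented under the assumption that it returns the module-qualified class name): the base class tags every payload with `_type`, while each subclass only serializes its own fields, mirroring the `ColumnValue` methods in the table.py diff above.

```python
from abc import abstractmethod
from typing import Any, Dict


def get_full_class_name(cls: type) -> str:
    # Assumed behaviour of the library helper: module-qualified class name.
    return f"{cls.__module__}.{cls.__qualname__}"


class SubclassJSONSerializer:
    def to_json(self) -> Dict[str, Any]:
        # Public entry point: tag the payload with its concrete type.
        return {"_type": get_full_class_name(self.__class__), **self._to_json()}

    @abstractmethod
    def _to_json(self) -> Dict[str, Any]:
        """Subclasses return only their own fields."""

    @classmethod
    @abstractmethod
    def _from_json(cls, data: Dict[str, Any]) -> "SubclassJSONSerializer":
        """Subclasses rebuild themselves from those fields."""


class ColumnValue(SubclassJSONSerializer):
    # Toy stand-in mirroring the ColumnValue methods in the table.py diff.
    def __init__(self, id: int, value: Any):
        self.id, self.value = id, value

    def _to_json(self) -> Dict[str, Any]:
        return {"id": self.id, "value": self.value}

    @classmethod
    def _from_json(cls, data: Dict[str, Any]) -> "ColumnValue":
        return cls(id=data["id"], value=data["value"])


payload = ColumnValue(1, "mammal").to_json()
print(payload)  # {'_type': '__main__.ColumnValue', 'id': 1, 'value': 'mammal'}
print(ColumnValue._from_json(payload).value)  # mammal
```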
{ripple_down_rules-0.0.5.dist-info → ripple_down_rules-0.0.7.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ripple_down_rules
-Version: 0.0.5
+Version: 0.0.7
 Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
 Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
 License: GNU GENERAL PUBLIC LICENSE
ripple_down_rules-0.0.7.dist-info/RECORD
ADDED
@@ -0,0 +1,18 @@
+ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ripple_down_rules/datasets.py,sha256=p7WmLiAKcYtzLq87RJhzU_BHCdM-zKK71yldjyYNpUE,4602
+ripple_down_rules/experts.py,sha256=ajH46NXM6IREtWr2a-ZBxdKl4CFzSx5S7yktnyXgwnc,11089
+ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
+ripple_down_rules/prompt.py,sha256=lmREZRyleBTHrVtcf2j_48oc0v3VlxXYGhl6w1mk8qI,4208
+ripple_down_rules/rdr.py,sha256=dS_OqaGtP1PAKEe0h-dMy7JI9ycNuH_EIbHlh-Nz4Sg,34189
+ripple_down_rules/rules.py,sha256=R9wIyalXBc6fGogzy92OxsyKAtJdQwdso_Vs1evWZGU,10274
+ripple_down_rules/utils.py,sha256=VjpdJ5W-xK-TQttfpNIHRKpYK91vYjtQ22kgccQgVTQ,17183
+ripple_down_rules/datastructures/__init__.py,sha256=wY9WqXavuE3wQ1YP65cs_SZyr7CEMB9tol-4oxgK9CM,104
+ripple_down_rules/datastructures/callable_expression.py,sha256=lF8ag3-8oYuETAUohtRfviOpypSMQR-JEJGgR0Zhv3E,9055
+ripple_down_rules/datastructures/dataclasses.py,sha256=FW3MMhGXTPa0XwIyLzGalrPwiltNeUqbGIawCAKDHGk,2448
+ripple_down_rules/datastructures/enums.py,sha256=l0Eu-TeJ6qB2XHoJycXmUgLw-3yUebQ8SsEbW8bBZdM,4543
+ripple_down_rules/datastructures/table.py,sha256=G2ShfPmyiwfDXS-JY_jFD6nOx7OLuMr-GP--OCDGKYc,13733
+ripple_down_rules-0.0.7.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
+ripple_down_rules-0.0.7.dist-info/METADATA,sha256=fBb_rMvpmfAueLPCXvNs3IXLNX-COunhRnEq7iYkV5Y,42518
+ripple_down_rules-0.0.7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+ripple_down_rules-0.0.7.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
+ripple_down_rules-0.0.7.dist-info/RECORD,,
ripple_down_rules/datastructures/generated/__init__.py
File without changes
ripple_down_rules/datastructures/generated/column/__init__.py
File without changes
ripple_down_rules/datastructures/generated/row/__init__.py
File without changes
ripple_down_rules-0.0.5.dist-info/RECORD
REMOVED
@@ -1,21 +0,0 @@
-ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-ripple_down_rules/datasets.py,sha256=QRB-1BdFTcUNuhgYEuXYx6qQOYlDu03_iLKDBqrcVrQ,4511
-ripple_down_rules/experts.py,sha256=DMTC-E2g1Fs43oyr30MGeGi5-VKBb3RojzzPa8DvCSA,12093
-ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
-ripple_down_rules/prompt.py,sha256=lmREZRyleBTHrVtcf2j_48oc0v3VlxXYGhl6w1mk8qI,4208
-ripple_down_rules/rdr.py,sha256=nRhrHdQYmvXgfTFLyO_xxwe7GDYuUhaOqZlQtqSUrnE,33751
-ripple_down_rules/rules.py,sha256=TLptqvA6I3QlQaVBTYchbgvXm17XWwFJoTmoN0diHm8,10348
-ripple_down_rules/utils.py,sha256=WlUXTf-B45SPEKpDBVb9QPQWS54250MGAY5xl9MIhR4,16944
-ripple_down_rules/datastructures/__init__.py,sha256=wY9WqXavuE3wQ1YP65cs_SZyr7CEMB9tol-4oxgK9CM,104
-ripple_down_rules/datastructures/callable_expression.py,sha256=yEZ6OWzSiWsRtEz_5UquA0inmodFTOWqXWV_a2gg1cg,9110
-ripple_down_rules/datastructures/dataclasses.py,sha256=z_9B7Nj_MIf2Iyrs5VeUhXhYxwaqnuKVjgwxhZZTygY,2525
-ripple_down_rules/datastructures/enums.py,sha256=6Mh55_8QRuXyYZXtonWr01VBgLP-jYp91K_8hIgh8u8,4244
-ripple_down_rules/datastructures/table.py,sha256=Tq7savAaFoB7n7x6-u2Vz6NfvjLBzPxsqblR4cHujzM,23161
-ripple_down_rules/datastructures/generated/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-ripple_down_rules/datastructures/generated/column/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-ripple_down_rules/datastructures/generated/row/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-ripple_down_rules-0.0.5.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
-ripple_down_rules-0.0.5.dist-info/METADATA,sha256=qwUDSDi3YlBsg1RlYOEtFn5YQ3x7yNizCAKm72_8YPU,42518
-ripple_down_rules-0.0.5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-ripple_down_rules-0.0.5.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
-ripple_down_rules-0.0.5.dist-info/RECORD,,
{ripple_down_rules-0.0.5.dist-info → ripple_down_rules-0.0.7.dist-info}/WHEEL
File without changes
{ripple_down_rules-0.0.5.dist-info → ripple_down_rules-0.0.7.dist-info}/licenses/LICENSE
File without changes
{ripple_down_rules-0.0.5.dist-info → ripple_down_rules-0.0.7.dist-info}/top_level.txt
File without changes