ripple-down-rules 0.0.15__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datasets.py +2 -2
- ripple_down_rules/datastructures/callable_expression.py +52 -10
- ripple_down_rules/datastructures/case.py +54 -70
- ripple_down_rules/datastructures/dataclasses.py +69 -29
- ripple_down_rules/experts.py +29 -40
- ripple_down_rules/helpers.py +27 -0
- ripple_down_rules/prompt.py +77 -24
- ripple_down_rules/rdr.py +218 -200
- ripple_down_rules/rdr_decorators.py +55 -0
- ripple_down_rules/rules.py +7 -2
- ripple_down_rules/utils.py +167 -3
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/METADATA +1 -1
- ripple_down_rules-0.1.1.dist-info/RECORD +20 -0
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/WHEEL +1 -1
- ripple_down_rules-0.0.15.dist-info/RECORD +0 -18
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/top_level.txt +0 -0
ripple_down_rules/datasets.py
CHANGED
@@ -9,7 +9,7 @@ from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationshi
|
|
9
9
|
from typing_extensions import Tuple, List, Set, Optional
|
10
10
|
from ucimlrepo import fetch_ucirepo
|
11
11
|
|
12
|
-
from .datastructures import Case, create_cases_from_dataframe, Category
|
12
|
+
from .datastructures import Case, create_cases_from_dataframe, Category
|
13
13
|
|
14
14
|
|
15
15
|
def load_cached_dataset(cache_file):
|
@@ -77,7 +77,7 @@ def load_zoo_dataset(cache_file: Optional[str] = None) -> Tuple[List[Case], List
|
|
77
77
|
y = zoo['targets']
|
78
78
|
# get ids as list of strings
|
79
79
|
ids = zoo['ids'].values.flatten()
|
80
|
-
all_cases = create_cases_from_dataframe(X)
|
80
|
+
all_cases = create_cases_from_dataframe(X, name="Animal")
|
81
81
|
|
82
82
|
category_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "molusc"]
|
83
83
|
category_id_to_name = {i + 1: name for i, name in enumerate(category_names)}
|
@@ -22,10 +22,22 @@ class VariableVisitor(ast.NodeVisitor):
|
|
22
22
|
def __init__(self):
|
23
23
|
self.variables = set()
|
24
24
|
self.attributes: Dict[ast.Name, ast.Attribute] = {}
|
25
|
+
self.types = set()
|
26
|
+
self.callables = set()
|
25
27
|
self.compares = list()
|
26
28
|
self.binary_ops = list()
|
27
29
|
self.all = list()
|
28
30
|
|
31
|
+
def visit_Constant(self, node):
|
32
|
+
self.all.append(node)
|
33
|
+
self.types.add(node)
|
34
|
+
self.generic_visit(node)
|
35
|
+
|
36
|
+
def visit_Call(self, node):
|
37
|
+
self.all.append(node)
|
38
|
+
self.callables.add(node)
|
39
|
+
self.generic_visit(node)
|
40
|
+
|
29
41
|
def visit_Attribute(self, node):
|
30
42
|
self.all.append(node)
|
31
43
|
self.attributes[node.value] = node
|
@@ -51,6 +63,27 @@ class VariableVisitor(ast.NodeVisitor):
|
|
51
63
|
self.generic_visit(node)
|
52
64
|
|
53
65
|
|
66
|
+
def get_used_scope(code_str, scope):
|
67
|
+
# Parse the code into an AST
|
68
|
+
tree = ast.parse(code_str, mode='eval')
|
69
|
+
|
70
|
+
# Walk the AST to collect used variable names
|
71
|
+
class NameCollector(ast.NodeVisitor):
|
72
|
+
def __init__(self):
|
73
|
+
self.names = set()
|
74
|
+
|
75
|
+
def visit_Name(self, node):
|
76
|
+
if isinstance(node.ctx, ast.Load): # We care only about variables being read
|
77
|
+
self.names.add(node.id)
|
78
|
+
|
79
|
+
collector = NameCollector()
|
80
|
+
collector.visit(tree)
|
81
|
+
|
82
|
+
# Filter the scope to include only used names
|
83
|
+
used_scope = {k: scope[k] for k in collector.names if k in scope}
|
84
|
+
return used_scope
|
85
|
+
|
86
|
+
|
54
87
|
class CallableExpression(SubclassJSONSerializer):
|
55
88
|
"""
|
56
89
|
A callable that is constructed from a string statement written by an expert.
|
@@ -85,7 +118,7 @@ class CallableExpression(SubclassJSONSerializer):
|
|
85
118
|
"""
|
86
119
|
|
87
120
|
def __init__(self, user_input: str, conclusion_type: Optional[Type] = None, expression_tree: Optional[AST] = None,
|
88
|
-
session: Optional[Session] = None):
|
121
|
+
session: Optional[Session] = None, scope: Optional[Dict[str, Any]] = None):
|
89
122
|
"""
|
90
123
|
Create a callable expression.
|
91
124
|
|
@@ -98,8 +131,17 @@ class CallableExpression(SubclassJSONSerializer):
|
|
98
131
|
self.user_input: str = user_input
|
99
132
|
self.parsed_user_input = self.parse_user_input(user_input, session)
|
100
133
|
self.conclusion_type = conclusion_type
|
134
|
+
self.scope: Optional[Dict[str, Any]] = scope if scope is not None else {}
|
135
|
+
self.scope = get_used_scope(self.parsed_user_input, self.scope)
|
101
136
|
self.update_expression(self.parsed_user_input, expression_tree)
|
102
137
|
|
138
|
+
def get_used_scope_in_user_input(self) -> Set[str]:
|
139
|
+
"""
|
140
|
+
Get the used scope in the user input.
|
141
|
+
:return: The used scope in the user input.
|
142
|
+
"""
|
143
|
+
return self.visitor.variables.union(self.visitor.attributes.keys())
|
144
|
+
|
103
145
|
@staticmethod
|
104
146
|
def parse_user_input(user_input: str, session: Optional[Session] = None) -> str:
|
105
147
|
if ',' in user_input:
|
@@ -117,11 +159,6 @@ class CallableExpression(SubclassJSONSerializer):
|
|
117
159
|
self.expression_tree: AST = expression_tree
|
118
160
|
self.visitor = VariableVisitor()
|
119
161
|
self.visitor.visit(expression_tree)
|
120
|
-
variables_str = self.visitor.variables
|
121
|
-
attributes_str = get_attributes_str(self.visitor)
|
122
|
-
for v in variables_str | attributes_str:
|
123
|
-
if not v.startswith("case."):
|
124
|
-
self.parsed_user_input = self.parsed_user_input.replace(v, f"case.{v}")
|
125
162
|
self.expression_tree = parse_string_to_expression(self.parsed_user_input)
|
126
163
|
self.compares_column_offset = [(c[0].col_offset, c[2].end_col_offset) for c in self.visitor.compares]
|
127
164
|
self.code = compile_expression_to_code(self.expression_tree)
|
@@ -130,8 +167,9 @@ class CallableExpression(SubclassJSONSerializer):
|
|
130
167
|
try:
|
131
168
|
if not isinstance(case, Case):
|
132
169
|
case = create_case(case, max_recursion_idx=3)
|
133
|
-
|
134
|
-
|
170
|
+
scope = {'case': case, **self.scope}
|
171
|
+
output = eval(self.code, scope)
|
172
|
+
if self.conclusion_type is not None:
|
135
173
|
assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
|
136
174
|
f" got {type(output)}")
|
137
175
|
return output
|
@@ -163,11 +201,15 @@ class CallableExpression(SubclassJSONSerializer):
|
|
163
201
|
return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else user_input
|
164
202
|
|
165
203
|
def _to_json(self) -> Dict[str, Any]:
|
166
|
-
return {"user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type)
|
204
|
+
return {"user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type),
|
205
|
+
"scope": {k: get_full_class_name(v) for k, v in self.scope.items()
|
206
|
+
if hasattr(v, '__module__') and hasattr(v, '__name__')}
|
207
|
+
}
|
167
208
|
|
168
209
|
@classmethod
|
169
210
|
def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
|
170
|
-
return cls(user_input=data["user_input"], conclusion_type=get_type_from_string(data["conclusion_type"])
|
211
|
+
return cls(user_input=data["user_input"], conclusion_type=get_type_from_string(data["conclusion_type"]),
|
212
|
+
scope={k: get_type_from_string(v) for k, v in data["scope"].items()})
|
171
213
|
|
172
214
|
|
173
215
|
def compile_expression_to_code(expression_tree: AST) -> Any:
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from collections import UserDict
|
4
4
|
from copy import copy, deepcopy
|
5
|
-
from dataclasses import dataclass
|
5
|
+
from dataclasses import dataclass, is_dataclass
|
6
6
|
from enum import Enum
|
7
7
|
|
8
8
|
from pandas import DataFrame
|
@@ -11,7 +11,7 @@ from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColum
|
|
11
11
|
from typing_extensions import Any, Optional, Dict, Type, Set, Hashable, Union, List, TYPE_CHECKING
|
12
12
|
|
13
13
|
from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, SubclassJSONSerializer, \
|
14
|
-
get_full_class_name, get_type_from_string
|
14
|
+
get_full_class_name, get_type_from_string, make_list, is_iterable, serialize_dataclass
|
15
15
|
|
16
16
|
if TYPE_CHECKING:
|
17
17
|
from ripple_down_rules.rules import Rule
|
@@ -24,16 +24,19 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
24
24
|
the names of the attributes and the values are the attributes. All are stored in lower case.
|
25
25
|
"""
|
26
26
|
|
27
|
-
def __init__(self, _id: Optional[Hashable] = None,
|
27
|
+
def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None, _name: Optional[str] = None, **kwargs):
|
28
28
|
"""
|
29
29
|
Create a new row.
|
30
30
|
|
31
|
+
:param _obj_type: The type of the object that the row represents.
|
31
32
|
:param _id: The id of the row.
|
33
|
+
:param _name: The semantic name that describes the row.
|
32
34
|
:param kwargs: The attributes of the row.
|
33
35
|
"""
|
34
36
|
super().__init__(kwargs)
|
35
|
-
self.
|
36
|
-
self.
|
37
|
+
self._obj_type: Type = _obj_type
|
38
|
+
self._id: Hashable = _id if _id is not None else id(self)
|
39
|
+
self._name: str = _name if _name is not None else self._obj_type.__name__
|
37
40
|
|
38
41
|
@classmethod
|
39
42
|
def from_obj(cls, obj: Any, obj_name: Optional[str] = None, max_recursion_idx: int = 3) -> Case:
|
@@ -53,13 +56,14 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
53
56
|
def __setitem__(self, name: str, value: Any):
|
54
57
|
name = name.lower()
|
55
58
|
if name in self:
|
56
|
-
if isinstance(self[name],
|
57
|
-
self[name].
|
58
|
-
elif isinstance(value,
|
59
|
-
|
60
|
-
|
59
|
+
if isinstance(self[name], list):
|
60
|
+
self[name].extend(make_list(value))
|
61
|
+
elif isinstance(value, list):
|
62
|
+
new_list = make_list(self[name])
|
63
|
+
new_list.extend(make_list(value))
|
64
|
+
super().__setitem__(name, new_list)
|
61
65
|
else:
|
62
|
-
super().__setitem__(name,
|
66
|
+
super().__setitem__(name, [self[name], value])
|
63
67
|
else:
|
64
68
|
super().__setitem__(name, value)
|
65
69
|
setattr(self, name, self[name])
|
@@ -78,18 +82,24 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
78
82
|
def _to_json(self) -> Dict[str, Any]:
|
79
83
|
serializable = {k: v for k, v in self.items() if not k.startswith("_")}
|
80
84
|
serializable["_id"] = self._id
|
85
|
+
serializable["_obj_type"] = get_full_class_name(self._obj_type)
|
86
|
+
serializable["_name"] = self._name
|
81
87
|
for k, v in serializable.items():
|
82
88
|
if isinstance(v, set):
|
83
|
-
serializable[k] = {'_type': get_full_class_name(set), 'value': list(v)}
|
89
|
+
serializable[k] = {'_type': get_full_class_name(set), 'value': serialize_dataclass(list(v))}
|
90
|
+
else:
|
91
|
+
serializable[k] = serialize_dataclass(v)
|
84
92
|
return {k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in serializable.items()}
|
85
93
|
|
86
94
|
@classmethod
|
87
95
|
def _from_json(cls, data: Dict[str, Any]) -> Case:
|
88
96
|
id_ = data.pop("_id")
|
97
|
+
obj_type = get_type_from_string(data.pop("_obj_type"))
|
98
|
+
name = data.pop("_name")
|
89
99
|
for k, v in data.items():
|
90
100
|
if isinstance(v, dict) and "_type" in v:
|
91
101
|
data[k] = SubclassJSONSerializer.from_json(v)
|
92
|
-
return cls(_id=id_, **data)
|
102
|
+
return cls(_obj_type=obj_type, _id=id_, _name=name, **data)
|
93
103
|
|
94
104
|
|
95
105
|
@dataclass
|
@@ -122,7 +132,7 @@ class CaseAttributeValue(SubclassJSONSerializer):
|
|
122
132
|
return cls(id=data["id"], value=data["value"])
|
123
133
|
|
124
134
|
|
125
|
-
class CaseAttribute(
|
135
|
+
class CaseAttribute(list, SubclassJSONSerializer):
|
126
136
|
nullable: bool = True
|
127
137
|
"""
|
128
138
|
A boolean indicating whether the case attribute can be None or not.
|
@@ -132,33 +142,9 @@ class CaseAttribute(set, SubclassJSONSerializer):
|
|
132
142
|
A boolean indicating whether the case attribute is mutually exclusive or not. (i.e. can only have one value)
|
133
143
|
"""
|
134
144
|
|
135
|
-
def __init__(self, values: Set[CaseAttributeValue]):
|
136
|
-
"""
|
137
|
-
Create a new case attribute.
|
138
|
-
|
139
|
-
:param values: The values of the case attribute.
|
140
|
-
"""
|
141
|
-
values = self._type_cast_values_to_set_of_case_attribute_values(values)
|
142
|
-
self.id_value_map: Dict[Hashable, Union[CaseAttributeValue, Set[CaseAttributeValue]]] = {id(v): v for v in values}
|
143
|
-
super().__init__([v.value for v in values])
|
144
|
-
|
145
|
-
@staticmethod
|
146
|
-
def _type_cast_values_to_set_of_case_attribute_values(values: Set[Any]) -> Set[CaseAttributeValue]:
|
147
|
-
"""
|
148
|
-
Type cast values to a set of case attribute values.
|
149
|
-
|
150
|
-
:param values: The values to type cast.
|
151
|
-
"""
|
152
|
-
values = make_set(values)
|
153
|
-
if len(values) > 0 and not isinstance(next(iter(values)), CaseAttributeValue):
|
154
|
-
values = {CaseAttributeValue(id(values), v) for v in values}
|
155
|
-
return values
|
156
|
-
|
157
145
|
@classmethod
|
158
|
-
def from_obj(cls, values:
|
159
|
-
|
160
|
-
values = make_set(values)
|
161
|
-
return cls({CaseAttributeValue(id_, v) for v in values})
|
146
|
+
def from_obj(cls, values: List[Any]) -> CaseAttribute:
|
147
|
+
return cls(make_list(values))
|
162
148
|
|
163
149
|
@property
|
164
150
|
def as_dict(self) -> Dict[str, Any]:
|
@@ -176,42 +162,43 @@ class CaseAttribute(set, SubclassJSONSerializer):
|
|
176
162
|
:param condition: The condition to filter by.
|
177
163
|
:return: The filtered column.
|
178
164
|
"""
|
179
|
-
return self.__class__(
|
165
|
+
return self.__class__([v for v in self if condition(v)])
|
180
166
|
|
181
167
|
def __eq__(self, other):
|
182
|
-
if not isinstance(other,
|
183
|
-
return super().__eq__(
|
168
|
+
if not isinstance(other, list):
|
169
|
+
return super().__eq__(make_list(other))
|
184
170
|
return super().__eq__(other)
|
185
171
|
|
186
172
|
def __hash__(self):
|
187
|
-
return hash(
|
173
|
+
return hash(id(self))
|
188
174
|
|
189
175
|
def __str__(self):
|
190
176
|
if len(self) == 0:
|
191
177
|
return "None"
|
192
|
-
return str(
|
178
|
+
return str([v for v in self]) if len(self) > 1 else str(next(iter(self)))
|
193
179
|
|
194
180
|
def _to_json(self) -> Dict[str, Any]:
|
195
|
-
return {
|
196
|
-
for
|
181
|
+
return {str(i): v.to_json() if isinstance(v, SubclassJSONSerializer) else v
|
182
|
+
for i, v in enumerate(self)}
|
197
183
|
|
198
184
|
@classmethod
|
199
185
|
def _from_json(cls, data: Dict[str, Any]) -> CaseAttribute:
|
200
|
-
return cls(
|
186
|
+
return cls([SubclassJSONSerializer.from_json(v) for _, v in data.items()])
|
201
187
|
|
202
188
|
|
203
|
-
def create_cases_from_dataframe(df: DataFrame) -> List[Case]:
|
189
|
+
def create_cases_from_dataframe(df: DataFrame, name: Optional[str] = None) -> List[Case]:
|
204
190
|
"""
|
205
191
|
Create cases from a pandas DataFrame.
|
206
192
|
|
207
193
|
:param df: The DataFrame to create cases from.
|
194
|
+
:param name: The semantic name of the DataFrame that describes the DataFrame.
|
208
195
|
:return: The cases of the DataFrame.
|
209
196
|
"""
|
210
197
|
cases = []
|
211
198
|
attribute_names = list(df.columns)
|
212
199
|
for row_id, case in df.iterrows():
|
213
200
|
case = {col_name: case[col_name].item() for col_name in attribute_names}
|
214
|
-
cases.append(Case(_id=row_id,
|
201
|
+
cases.append(Case(DataFrame, _id=row_id, _name=name, **case))
|
215
202
|
return cases
|
216
203
|
|
217
204
|
|
@@ -227,20 +214,21 @@ def create_case(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
|
|
227
214
|
:param parent_is_iterable: Boolean indicating whether the parent object is iterable or not.
|
228
215
|
:return: The case that represents the object.
|
229
216
|
"""
|
217
|
+
obj_name = obj_name or obj.__class__.__name__
|
230
218
|
if isinstance(obj, DataFrame):
|
231
|
-
return create_cases_from_dataframe(obj)
|
219
|
+
return create_cases_from_dataframe(obj, obj_name)
|
232
220
|
if isinstance(obj, Case):
|
233
221
|
return obj
|
234
222
|
if ((recursion_idx > max_recursion_idx) or (obj.__class__.__module__ == "builtins")
|
235
223
|
or (obj.__class__ in [MetaData, registry])):
|
236
|
-
return Case(_id=id(obj),
|
237
|
-
**{obj_name or obj.__class__.__name__:
|
238
|
-
case = Case(_id=id(obj),
|
224
|
+
return Case(type(obj), _id=id(obj), _name=obj_name,
|
225
|
+
**{obj_name or obj.__class__.__name__: make_list(obj) if parent_is_iterable else obj})
|
226
|
+
case = Case(type(obj), _id=id(obj), _name=obj_name)
|
239
227
|
for attr in dir(obj):
|
240
228
|
if attr.startswith("_") or callable(getattr(obj, attr)):
|
241
229
|
continue
|
242
230
|
attr_value = getattr(obj, attr)
|
243
|
-
case = create_or_update_case_from_attribute(attr_value, attr, obj,
|
231
|
+
case = create_or_update_case_from_attribute(attr_value, attr, obj, obj_name, recursion_idx,
|
244
232
|
max_recursion_idx, parent_is_iterable, case)
|
245
233
|
return case
|
246
234
|
|
@@ -263,16 +251,16 @@ def create_or_update_case_from_attribute(attr_value: Any, name: str, obj: Any, o
|
|
263
251
|
:return: The updated/created case.
|
264
252
|
"""
|
265
253
|
if case is None:
|
266
|
-
case = Case(_id=id(obj),
|
254
|
+
case = Case(type(obj), _id=id(obj), _name=obj_name)
|
267
255
|
if isinstance(attr_value, (dict, UserDict)):
|
268
256
|
case.update({f"{obj_name}.{k}": v for k, v in attr_value.items()})
|
269
257
|
if hasattr(attr_value, "__iter__") and not isinstance(attr_value, str):
|
270
258
|
column = create_case_attribute_from_iterable_attribute(attr_value, name, obj, obj_name,
|
271
259
|
recursion_idx=recursion_idx + 1,
|
272
260
|
max_recursion_idx=max_recursion_idx)
|
273
|
-
case[
|
261
|
+
case[name] = column
|
274
262
|
else:
|
275
|
-
case[
|
263
|
+
case[name] = make_list(attr_value) if parent_is_iterable else attr_value
|
276
264
|
return case
|
277
265
|
|
278
266
|
|
@@ -290,22 +278,22 @@ def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, ob
|
|
290
278
|
:param max_recursion_idx: The maximum recursion index.
|
291
279
|
:return: A case attribute that represents the original iterable attribute.
|
292
280
|
"""
|
293
|
-
values = attr_value.values() if isinstance(attr_value, (dict, UserDict)) else attr_value
|
281
|
+
values = list(attr_value.values()) if isinstance(attr_value, (dict, UserDict)) else attr_value
|
294
282
|
_type = type(list(values)[0]) if len(values) > 0 else get_value_type_from_type_hint(name, obj)
|
295
|
-
attr_case = Case(_id=id(attr_value),
|
296
|
-
case_attr = CaseAttribute
|
283
|
+
attr_case = Case(_type, _id=id(attr_value), _name=name)
|
284
|
+
case_attr = CaseAttribute(values)
|
297
285
|
for idx, val in enumerate(values):
|
298
286
|
sub_attr_case = create_case(val, recursion_idx=recursion_idx,
|
299
287
|
max_recursion_idx=max_recursion_idx,
|
300
|
-
obj_name=
|
288
|
+
obj_name=name, parent_is_iterable=True)
|
301
289
|
attr_case.update(sub_attr_case)
|
302
290
|
for sub_attr, val in attr_case.items():
|
303
291
|
setattr(case_attr, sub_attr, val)
|
304
292
|
return case_attr
|
305
293
|
|
306
294
|
|
307
|
-
def show_current_and_corner_cases(case: Any, targets: Optional[
|
308
|
-
current_conclusions: Optional[
|
295
|
+
def show_current_and_corner_cases(case: Any, targets: Optional[Dict[str, Any]] = None,
|
296
|
+
current_conclusions: Optional[Dict[str, Any]] = None,
|
309
297
|
last_evaluated_rule: Optional[Rule] = None) -> None:
|
310
298
|
"""
|
311
299
|
Show the data of the new case and if last evaluated rule exists also show that of the corner case.
|
@@ -316,12 +304,8 @@ def show_current_and_corner_cases(case: Any, targets: Optional[Union[List[CaseAt
|
|
316
304
|
:param last_evaluated_rule: The last evaluated rule in the RDR.
|
317
305
|
"""
|
318
306
|
corner_case = None
|
319
|
-
if targets
|
320
|
-
|
321
|
-
if current_conclusions:
|
322
|
-
current_conclusions = current_conclusions if isinstance(current_conclusions, list) else [current_conclusions]
|
323
|
-
targets = {f"target_{t.__class__.__name__}": t for t in targets} if targets else {}
|
324
|
-
current_conclusions = {c.__class__.__name__: c for c in current_conclusions} if current_conclusions else {}
|
307
|
+
targets = {f"target_{name}": value for name, value in targets.items()} if targets else {}
|
308
|
+
current_conclusions = {name: value for name, value in current_conclusions.items} if current_conclusions else {}
|
325
309
|
if last_evaluated_rule:
|
326
310
|
action = "Refinement" if last_evaluated_rule.fired else "Alternative"
|
327
311
|
print(f"{action} needed for rule: {last_evaluated_rule}\n")
|
@@ -1,13 +1,16 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import inspect
|
3
4
|
from dataclasses import dataclass
|
4
5
|
|
5
6
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
6
|
-
from typing_extensions import Any, Optional, Type
|
7
|
+
from typing_extensions import Any, Optional, Type, List, Tuple, Set, Dict, TYPE_CHECKING
|
7
8
|
|
8
9
|
from .case import create_case, Case
|
9
|
-
from ..utils import get_attribute_name, copy_case
|
10
|
+
from ..utils import get_attribute_name, copy_case, get_hint_for_attribute, typing_to_python_type
|
10
11
|
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from . import CallableExpression
|
11
14
|
|
12
15
|
@dataclass
|
13
16
|
class CaseQuery:
|
@@ -20,56 +23,93 @@ class CaseQuery:
|
|
20
23
|
"""
|
21
24
|
The case that the attribute belongs to.
|
22
25
|
"""
|
23
|
-
|
26
|
+
attribute_name: str
|
24
27
|
"""
|
25
|
-
The attribute
|
28
|
+
The name of the attribute.
|
26
29
|
"""
|
27
|
-
|
30
|
+
target: Optional[Any] = None
|
28
31
|
"""
|
29
32
|
The target value of the attribute.
|
30
33
|
"""
|
31
|
-
|
34
|
+
mutually_exclusive: bool = False
|
32
35
|
"""
|
33
|
-
|
36
|
+
Whether the attribute can only take one value (i.e. True) or multiple values (i.e. False).
|
37
|
+
"""
|
38
|
+
conditions: Optional[CallableExpression] = None
|
34
39
|
"""
|
35
|
-
|
40
|
+
The conditions that must be satisfied for the target value to be valid.
|
36
41
|
"""
|
37
|
-
|
42
|
+
prediction: Optional[CallableExpression] = None
|
38
43
|
"""
|
39
|
-
|
44
|
+
The predicted value of the attribute.
|
40
45
|
"""
|
41
|
-
|
46
|
+
scope: Optional[Dict[str, Any]] = None
|
47
|
+
"""
|
48
|
+
The global scope of the case query. This is used to evaluate the conditions and prediction, and is what is available
|
49
|
+
to the user when they are prompted for input. If it is not provided, it will be set to the global scope of the
|
50
|
+
caller.
|
42
51
|
"""
|
43
52
|
|
44
|
-
def __init__(self, case: Any,
|
45
|
-
|
46
|
-
|
53
|
+
def __init__(self, case: Any, attribute_name: str,
|
54
|
+
target: Optional[Any] = None,
|
55
|
+
mutually_exclusive: bool = False,
|
56
|
+
conditions: Optional[CallableExpression] = None,
|
57
|
+
prediction: Optional[CallableExpression] = None,
|
58
|
+
scope: Optional[Dict[str, Any]] = None,):
|
59
|
+
self.original_case = case
|
60
|
+
self.case = self._get_case()
|
47
61
|
|
48
|
-
if attribute_name is None:
|
49
|
-
attribute_name = get_attribute_name(case, attribute, attribute_type, target)
|
50
62
|
self.attribute_name = attribute_name
|
63
|
+
self.target = target
|
64
|
+
self.attribute_type = self._get_attribute_type()
|
65
|
+
self.mutually_exclusive = mutually_exclusive
|
66
|
+
self.conditions = conditions
|
67
|
+
self.prediction = prediction
|
68
|
+
self.scope = scope if scope is not None else inspect.currentframe().f_back.f_globals
|
51
69
|
|
52
|
-
|
53
|
-
|
54
|
-
|
70
|
+
def _get_case(self) -> Any:
|
71
|
+
if not isinstance(self.original_case, (Case, SQLTable)):
|
72
|
+
return create_case(self.original_case, max_recursion_idx=3)
|
73
|
+
else:
|
74
|
+
return self.original_case
|
55
75
|
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
76
|
+
def _get_attribute_type(self) -> Type:
|
77
|
+
"""
|
78
|
+
:return: The type of the attribute.
|
79
|
+
"""
|
80
|
+
if self.target is not None:
|
81
|
+
return type(self.target)
|
82
|
+
elif hasattr(self.original_case, self.attribute_name):
|
83
|
+
hint, origin, args = get_hint_for_attribute(self.attribute_name, self.original_case)
|
84
|
+
if origin is not None:
|
85
|
+
return typing_to_python_type(origin)
|
86
|
+
elif hint is not None:
|
87
|
+
return typing_to_python_type(hint)
|
60
88
|
|
61
89
|
@property
|
62
90
|
def name(self):
|
63
|
-
|
91
|
+
"""
|
92
|
+
:return: The name of the case query.
|
93
|
+
"""
|
94
|
+
return f"{self.case_name}.{self.attribute_name}"
|
95
|
+
|
96
|
+
@property
|
97
|
+
def case_name(self) -> str:
|
98
|
+
"""
|
99
|
+
:return: The name of the case.
|
100
|
+
"""
|
101
|
+
return self.case._name if isinstance(self.case, Case) else self.case.__class__.__name__
|
64
102
|
|
65
103
|
def __str__(self):
|
66
|
-
|
67
|
-
|
68
|
-
else
|
69
|
-
|
104
|
+
header = f"CaseQuery: {self.name}"
|
105
|
+
target = f"Target: {self.name} |= {self.target if self.target is not None else '?'}"
|
106
|
+
prediction = f"Prediction: {self.name} |= {self.prediction if self.prediction is not None else '?'}"
|
107
|
+
conditions = f"Conditions: {self.conditions if self.conditions is not None else '?'}"
|
108
|
+
return "\n".join([header, target, prediction, conditions])
|
70
109
|
|
71
110
|
def __repr__(self):
|
72
111
|
return self.__str__()
|
73
112
|
|
74
113
|
def __copy__(self):
|
75
|
-
return CaseQuery(copy_case(self.case),
|
114
|
+
return CaseQuery(copy_case(self.case), self.attribute_name, self.target, self.mutually_exclusive,
|
115
|
+
self.conditions, self.prediction, self.scope)
|