ripple-down-rules 0.0.14__tar.gz → 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/PKG-INFO +1 -1
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/pyproject.toml +1 -1
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules/datasets.py +2 -2
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules/datastructures/callable_expression.py +52 -10
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules/datastructures/case.py +53 -70
- ripple_down_rules-0.1.0/src/ripple_down_rules/datastructures/dataclasses.py +115 -0
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules/experts.py +29 -40
- ripple_down_rules-0.1.0/src/ripple_down_rules/helpers.py +27 -0
- ripple_down_rules-0.1.0/src/ripple_down_rules/prompt.py +154 -0
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules/rdr.py +298 -192
- ripple_down_rules-0.1.0/src/ripple_down_rules/rdr_decorators.py +55 -0
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules/rules.py +12 -3
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules/utils.py +154 -3
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules.egg-info/SOURCES.txt +3 -0
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/test/test_json_serialization.py +25 -1
- ripple_down_rules-0.1.0/test/test_on_mutagenic.py +200 -0
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/test/test_rdr.py +35 -121
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/test/test_rdr_alchemy.py +18 -16
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/test/test_relational_rdr.py +8 -8
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/test/test_relational_rdr_alchemy.py +19 -19
- ripple_down_rules-0.0.14/src/ripple_down_rules/datastructures/dataclasses.py +0 -75
- ripple_down_rules-0.0.14/src/ripple_down_rules/prompt.py +0 -101
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/LICENSE +0 -0
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/README.md +0 -0
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/setup.cfg +0 -0
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules/__init__.py +0 -0
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules/datastructures/__init__.py +0 -0
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules/datastructures/enums.py +0 -0
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules/failures.py +0 -0
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
- {ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/test/test_sql_model.py +0 -0
{ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ripple_down_rules
-Version: 0.0.14
+Version: 0.1.0
 Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
 Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
 License: GNU GENERAL PUBLIC LICENSE
{ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/pyproject.toml
RENAMED
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "ripple_down_rules"
-version = "0.0.14"
+version = "0.1.0"
 description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
 readme = "README.md"
 authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
{ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules/datasets.py
RENAMED
@@ -9,7 +9,7 @@ from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationship
 from typing_extensions import Tuple, List, Set, Optional
 from ucimlrepo import fetch_ucirepo
 
-from .datastructures import Case, create_cases_from_dataframe, Category
+from .datastructures import Case, create_cases_from_dataframe, Category
 
 
 def load_cached_dataset(cache_file):
@@ -77,7 +77,7 @@ def load_zoo_dataset(cache_file: Optional[str] = None) -> Tuple[List[Case], List
     y = zoo['targets']
     # get ids as list of strings
     ids = zoo['ids'].values.flatten()
-    all_cases = create_cases_from_dataframe(X)
+    all_cases = create_cases_from_dataframe(X, name="Animal")
 
     category_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "molusc"]
     category_id_to_name = {i + 1: name for i, name in enumerate(category_names)}
{ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules/datastructures/callable_expression.py
RENAMED
@@ -22,10 +22,22 @@ class VariableVisitor(ast.NodeVisitor):
     def __init__(self):
         self.variables = set()
         self.attributes: Dict[ast.Name, ast.Attribute] = {}
+        self.types = set()
+        self.callables = set()
         self.compares = list()
         self.binary_ops = list()
         self.all = list()
 
+    def visit_Constant(self, node):
+        self.all.append(node)
+        self.types.add(node)
+        self.generic_visit(node)
+
+    def visit_Call(self, node):
+        self.all.append(node)
+        self.callables.add(node)
+        self.generic_visit(node)
+
     def visit_Attribute(self, node):
         self.all.append(node)
         self.attributes[node.value] = node
@@ -51,6 +63,27 @@ class VariableVisitor(ast.NodeVisitor):
         self.generic_visit(node)
 
 
+def get_used_scope(code_str, scope):
+    # Parse the code into an AST
+    tree = ast.parse(code_str, mode='eval')
+
+    # Walk the AST to collect used variable names
+    class NameCollector(ast.NodeVisitor):
+        def __init__(self):
+            self.names = set()
+
+        def visit_Name(self, node):
+            if isinstance(node.ctx, ast.Load):  # We care only about variables being read
+                self.names.add(node.id)
+
+    collector = NameCollector()
+    collector.visit(tree)
+
+    # Filter the scope to include only used names
+    used_scope = {k: scope[k] for k in collector.names if k in scope}
+    return used_scope
+
+
 class CallableExpression(SubclassJSONSerializer):
     """
     A callable that is constructed from a string statement written by an expert.
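Note: the new `get_used_scope` helper added above keeps only the scope entries whose names are actually read by the expression. A quick illustration (assumes ripple_down_rules 0.1.0 is installed; the example names are made up):

```python
# Illustration only: exercise the get_used_scope helper added in this diff.
from ripple_down_rules.datastructures.callable_expression import get_used_scope

scope = {"threshold": 3, "unused_helper": object(), "len": len}
used = get_used_scope("len(items) > threshold", scope)
print(sorted(used))  # ['len', 'threshold'] -- 'unused_helper' is dropped, 'items' is not in scope
```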
@@ -85,7 +118,7 @@ class CallableExpression(SubclassJSONSerializer):
     """
 
     def __init__(self, user_input: str, conclusion_type: Optional[Type] = None, expression_tree: Optional[AST] = None,
-                 session: Optional[Session] = None):
+                 session: Optional[Session] = None, scope: Optional[Dict[str, Any]] = None):
         """
         Create a callable expression.
 
@@ -98,8 +131,17 @@ class CallableExpression(SubclassJSONSerializer):
         self.user_input: str = user_input
         self.parsed_user_input = self.parse_user_input(user_input, session)
         self.conclusion_type = conclusion_type
+        self.scope: Optional[Dict[str, Any]] = scope if scope is not None else {}
+        self.scope = get_used_scope(self.parsed_user_input, self.scope)
         self.update_expression(self.parsed_user_input, expression_tree)
 
+    def get_used_scope_in_user_input(self) -> Set[str]:
+        """
+        Get the used scope in the user input.
+        :return: The used scope in the user input.
+        """
+        return self.visitor.variables.union(self.visitor.attributes.keys())
+
     @staticmethod
     def parse_user_input(user_input: str, session: Optional[Session] = None) -> str:
         if ',' in user_input:
@@ -117,11 +159,6 @@ class CallableExpression(SubclassJSONSerializer):
         self.expression_tree: AST = expression_tree
         self.visitor = VariableVisitor()
         self.visitor.visit(expression_tree)
-        variables_str = self.visitor.variables
-        attributes_str = get_attributes_str(self.visitor)
-        for v in variables_str | attributes_str:
-            if not v.startswith("case."):
-                self.parsed_user_input = self.parsed_user_input.replace(v, f"case.{v}")
         self.expression_tree = parse_string_to_expression(self.parsed_user_input)
         self.compares_column_offset = [(c[0].col_offset, c[2].end_col_offset) for c in self.visitor.compares]
         self.code = compile_expression_to_code(self.expression_tree)
@@ -130,8 +167,9 @@ class CallableExpression(SubclassJSONSerializer):
         try:
             if not isinstance(case, Case):
                 case = create_case(case, max_recursion_idx=3)
-
-
+            scope = {'case': case, **self.scope}
+            output = eval(self.code, scope)
+            if self.conclusion_type is not None:
                 assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
                                                                   f" got {type(output)}")
             return output
@@ -163,11 +201,15 @@ class CallableExpression(SubclassJSONSerializer):
         return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else user_input
 
     def _to_json(self) -> Dict[str, Any]:
-        return {"user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type)
+        return {"user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type),
+                "scope": {k: get_full_class_name(v) for k, v in self.scope.items()
+                          if hasattr(v, '__module__') and hasattr(v, '__name__')}
+                }
 
     @classmethod
     def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
-        return cls(user_input=data["user_input"], conclusion_type=get_type_from_string(data["conclusion_type"])
+        return cls(user_input=data["user_input"], conclusion_type=get_type_from_string(data["conclusion_type"]),
+                   scope={k: get_type_from_string(v) for k, v in data["scope"].items()})
 
 
 def compile_expression_to_code(expression_tree: AST) -> Any:
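The `__call__` change above replaces the old `case.`-prefix rewriting with a plain `eval` over a merged namespace: the case object plus the (filtered) user scope. A minimal, self-contained sketch of that evaluation step, using a plain dict in place of a `Case`:

```python
# Minimal sketch of the new evaluation strategy (not the library's API):
# the expression sees 'case' plus whatever names the expert's scope provides.
user_scope = {"THRESHOLD": 3, "len": len}       # names the expert referred to
case = {"legs": 4, "milk": 1}                   # stand-in for a Case object

code = compile("case['legs'] > THRESHOLD and case['milk'] == 1", "<expr>", "eval")
print(eval(code, {"case": case, **user_scope})) # True, mirrors eval(self.code, scope)
```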
{ripple_down_rules-0.0.14 → ripple_down_rules-0.1.0}/src/ripple_down_rules/datastructures/case.py
RENAMED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from collections import UserDict
 from copy import copy, deepcopy
-from dataclasses import dataclass
+from dataclasses import dataclass, is_dataclass
 from enum import Enum
 
 from pandas import DataFrame
@@ -11,7 +11,7 @@ from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColumn
 from typing_extensions import Any, Optional, Dict, Type, Set, Hashable, Union, List, TYPE_CHECKING
 
 from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, SubclassJSONSerializer, \
-    get_full_class_name, get_type_from_string
+    get_full_class_name, get_type_from_string, make_list, is_iterable, serialize_dataclass
 
 if TYPE_CHECKING:
     from ripple_down_rules.rules import Rule
@@ -24,16 +24,19 @@ class Case(UserDict, SubclassJSONSerializer):
     the names of the attributes and the values are the attributes. All are stored in lower case.
     """
 
-    def __init__(self, _id: Optional[Hashable] = None,
+    def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None, _name: Optional[str] = None, **kwargs):
         """
         Create a new row.
 
+        :param _obj_type: The type of the object that the row represents.
         :param _id: The id of the row.
+        :param _name: The semantic name that describes the row.
         :param kwargs: The attributes of the row.
         """
         super().__init__(kwargs)
-        self.
-        self.
+        self._obj_type: Type = _obj_type
+        self._id: Hashable = _id if _id is not None else id(self)
+        self._name: str = _name if _name is not None else self._obj_type.__name__
 
     @classmethod
     def from_obj(cls, obj: Any, obj_name: Optional[str] = None, max_recursion_idx: int = 3) -> Case:
@@ -53,13 +56,14 @@ class Case(UserDict, SubclassJSONSerializer):
     def __setitem__(self, name: str, value: Any):
         name = name.lower()
         if name in self:
-            if isinstance(self[name],
-                self[name].
-            elif isinstance(value,
-
-
+            if isinstance(self[name], list):
+                self[name].extend(make_list(value))
+            elif isinstance(value, list):
+                new_list = make_list(self[name])
+                new_list.extend(make_list(value))
+                super().__setitem__(name, new_list)
             else:
-                super().__setitem__(name,
+                super().__setitem__(name, [self[name], value])
         else:
             super().__setitem__(name, value)
         setattr(self, name, self[name])
@@ -78,18 +82,24 @@ class Case(UserDict, SubclassJSONSerializer):
     def _to_json(self) -> Dict[str, Any]:
         serializable = {k: v for k, v in self.items() if not k.startswith("_")}
         serializable["_id"] = self._id
+        serializable["_obj_type"] = get_full_class_name(self._obj_type)
+        serializable["_name"] = self._name
         for k, v in serializable.items():
             if isinstance(v, set):
-                serializable[k] = {'_type': get_full_class_name(set), 'value': list(v)}
+                serializable[k] = {'_type': get_full_class_name(set), 'value': serialize_dataclass(list(v))}
+            else:
+                serializable[k] = serialize_dataclass(v)
         return {k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in serializable.items()}
 
     @classmethod
     def _from_json(cls, data: Dict[str, Any]) -> Case:
         id_ = data.pop("_id")
+        obj_type = get_type_from_string(data.pop("_obj_type"))
+        name = data.pop("_name")
         for k, v in data.items():
             if isinstance(v, dict) and "_type" in v:
                 data[k] = SubclassJSONSerializer.from_json(v)
-        return cls(_id=id_, **data)
+        return cls(_obj_type=obj_type, _id=id_, _name=name, **data)
 
 
 @dataclass
@@ -122,7 +132,7 @@ class CaseAttributeValue(SubclassJSONSerializer):
         return cls(id=data["id"], value=data["value"])
 
 
-class CaseAttribute(
+class CaseAttribute(list, SubclassJSONSerializer):
     nullable: bool = True
     """
     A boolean indicating whether the case attribute can be None or not.
@@ -132,33 +142,9 @@ class CaseAttribute(set, SubclassJSONSerializer):
     A boolean indicating whether the case attribute is mutually exclusive or not. (i.e. can only have one value)
     """
 
-    def __init__(self, values: Set[CaseAttributeValue]):
-        """
-        Create a new case attribute.
-
-        :param values: The values of the case attribute.
-        """
-        values = self._type_cast_values_to_set_of_case_attribute_values(values)
-        self.id_value_map: Dict[Hashable, Union[CaseAttributeValue, Set[CaseAttributeValue]]] = {id(v): v for v in values}
-        super().__init__([v.value for v in values])
-
-    @staticmethod
-    def _type_cast_values_to_set_of_case_attribute_values(values: Set[Any]) -> Set[CaseAttributeValue]:
-        """
-        Type cast values to a set of case attribute values.
-
-        :param values: The values to type cast.
-        """
-        values = make_set(values)
-        if len(values) > 0 and not isinstance(next(iter(values)), CaseAttributeValue):
-            values = {CaseAttributeValue(id(values), v) for v in values}
-        return values
-
     @classmethod
-    def from_obj(cls, values:
-
-        values = make_set(values)
-        return cls({CaseAttributeValue(id_, v) for v in values})
+    def from_obj(cls, values: List[Any]) -> CaseAttribute:
+        return cls(make_list(values))
 
     @property
     def as_dict(self) -> Dict[str, Any]:
@@ -176,42 +162,43 @@ class CaseAttribute(set, SubclassJSONSerializer):
         :param condition: The condition to filter by.
         :return: The filtered column.
         """
-        return self.__class__(
+        return self.__class__([v for v in self if condition(v)])
 
     def __eq__(self, other):
-        if not isinstance(other,
-            return super().__eq__(
+        if not isinstance(other, list):
+            return super().__eq__(make_list(other))
         return super().__eq__(other)
 
     def __hash__(self):
-        return hash(
+        return hash(id(self))
 
     def __str__(self):
         if len(self) == 0:
             return "None"
-        return str(
+        return str([v for v in self]) if len(self) > 1 else str(next(iter(self)))
 
     def _to_json(self) -> Dict[str, Any]:
-        return {
-            for
+        return {str(i): v.to_json() if isinstance(v, SubclassJSONSerializer) else v
+                for i, v in enumerate(self)}
 
     @classmethod
     def _from_json(cls, data: Dict[str, Any]) -> CaseAttribute:
-        return cls(
+        return cls([SubclassJSONSerializer.from_json(v) for _, v in data.items()])
 
 
-def create_cases_from_dataframe(df: DataFrame) -> List[Case]:
+def create_cases_from_dataframe(df: DataFrame, name: Optional[str] = None) -> List[Case]:
     """
     Create cases from a pandas DataFrame.
 
     :param df: The DataFrame to create cases from.
+    :param name: The semantic name of the DataFrame that describes the DataFrame.
     :return: The cases of the DataFrame.
     """
     cases = []
     attribute_names = list(df.columns)
     for row_id, case in df.iterrows():
         case = {col_name: case[col_name].item() for col_name in attribute_names}
-        cases.append(Case(_id=row_id,
+        cases.append(Case(DataFrame, _id=row_id, _name=name, **case))
     return cases
 
 
|
|
228
215
|
:return: The case that represents the object.
|
229
216
|
"""
|
230
217
|
if isinstance(obj, DataFrame):
|
231
|
-
return create_cases_from_dataframe(obj)
|
218
|
+
return create_cases_from_dataframe(obj, obj_name)
|
232
219
|
if isinstance(obj, Case):
|
233
220
|
return obj
|
234
221
|
if ((recursion_idx > max_recursion_idx) or (obj.__class__.__module__ == "builtins")
|
235
222
|
or (obj.__class__ in [MetaData, registry])):
|
236
|
-
return Case(_id=id(obj),
|
237
|
-
**{obj_name or obj.__class__.__name__:
|
238
|
-
case = Case(_id=id(obj),
|
223
|
+
return Case(type(obj), _id=id(obj), _name=obj_name,
|
224
|
+
**{obj_name or obj.__class__.__name__: make_list(obj) if parent_is_iterable else obj})
|
225
|
+
case = Case(type(obj), _id=id(obj), _name=obj_name)
|
239
226
|
for attr in dir(obj):
|
240
227
|
if attr.startswith("_") or callable(getattr(obj, attr)):
|
241
228
|
continue
|
242
229
|
attr_value = getattr(obj, attr)
|
243
|
-
case = create_or_update_case_from_attribute(attr_value, attr, obj,
|
230
|
+
case = create_or_update_case_from_attribute(attr_value, attr, obj, obj_name, recursion_idx,
|
244
231
|
max_recursion_idx, parent_is_iterable, case)
|
245
232
|
return case
|
246
233
|
|
@@ -263,16 +250,16 @@ def create_or_update_case_from_attribute(attr_value: Any, name: str, obj: Any, o
|
|
263
250
|
:return: The updated/created case.
|
264
251
|
"""
|
265
252
|
if case is None:
|
266
|
-
case = Case(_id=id(obj),
|
253
|
+
case = Case(type(obj), _id=id(obj), _name=obj_name)
|
267
254
|
if isinstance(attr_value, (dict, UserDict)):
|
268
255
|
case.update({f"{obj_name}.{k}": v for k, v in attr_value.items()})
|
269
256
|
if hasattr(attr_value, "__iter__") and not isinstance(attr_value, str):
|
270
257
|
column = create_case_attribute_from_iterable_attribute(attr_value, name, obj, obj_name,
|
271
258
|
recursion_idx=recursion_idx + 1,
|
272
259
|
max_recursion_idx=max_recursion_idx)
|
273
|
-
case[
|
260
|
+
case[name] = column
|
274
261
|
else:
|
275
|
-
case[
|
262
|
+
case[name] = make_list(attr_value) if parent_is_iterable else attr_value
|
276
263
|
return case
|
277
264
|
|
278
265
|
|
@@ -290,22 +277,22 @@ def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, ob
|
|
290
277
|
:param max_recursion_idx: The maximum recursion index.
|
291
278
|
:return: A case attribute that represents the original iterable attribute.
|
292
279
|
"""
|
293
|
-
values = attr_value.values() if isinstance(attr_value, (dict, UserDict)) else attr_value
|
280
|
+
values = list(attr_value.values()) if isinstance(attr_value, (dict, UserDict)) else attr_value
|
294
281
|
_type = type(list(values)[0]) if len(values) > 0 else get_value_type_from_type_hint(name, obj)
|
295
|
-
attr_case = Case(_id=id(attr_value),
|
296
|
-
case_attr = CaseAttribute
|
282
|
+
attr_case = Case(_type, _id=id(attr_value), _name=name)
|
283
|
+
case_attr = CaseAttribute(values)
|
297
284
|
for idx, val in enumerate(values):
|
298
285
|
sub_attr_case = create_case(val, recursion_idx=recursion_idx,
|
299
286
|
max_recursion_idx=max_recursion_idx,
|
300
|
-
obj_name=
|
287
|
+
obj_name=name, parent_is_iterable=True)
|
301
288
|
attr_case.update(sub_attr_case)
|
302
289
|
for sub_attr, val in attr_case.items():
|
303
290
|
setattr(case_attr, sub_attr, val)
|
304
291
|
return case_attr
|
305
292
|
|
306
293
|
|
307
|
-
def show_current_and_corner_cases(case: Any, targets: Optional[
|
308
|
-
current_conclusions: Optional[
|
294
|
+
def show_current_and_corner_cases(case: Any, targets: Optional[Dict[str, Any]] = None,
|
295
|
+
current_conclusions: Optional[Dict[str, Any]] = None,
|
309
296
|
last_evaluated_rule: Optional[Rule] = None) -> None:
|
310
297
|
"""
|
311
298
|
Show the data of the new case and if last evaluated rule exists also show that of the corner case.
|
@@ -316,12 +303,8 @@ def show_current_and_corner_cases(case: Any, targets: Optional[Union[List[CaseAt
|
|
316
303
|
:param last_evaluated_rule: The last evaluated rule in the RDR.
|
317
304
|
"""
|
318
305
|
corner_case = None
|
319
|
-
if targets
|
320
|
-
|
321
|
-
if current_conclusions:
|
322
|
-
current_conclusions = current_conclusions if isinstance(current_conclusions, list) else [current_conclusions]
|
323
|
-
targets = {f"target_{t.__class__.__name__}": t for t in targets} if targets else {}
|
324
|
-
current_conclusions = {c.__class__.__name__: c for c in current_conclusions} if current_conclusions else {}
|
306
|
+
targets = {f"target_{name}": value for name, value in targets.items()} if targets else {}
|
307
|
+
current_conclusions = {name: value for name, value in current_conclusions.items} if current_conclusions else {}
|
325
308
|
if last_evaluated_rule:
|
326
309
|
action = "Refinement" if last_evaluated_rule.fired else "Alternative"
|
327
310
|
print(f"{action} needed for rule: {last_evaluated_rule}\n")
|
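`Case.__setitem__` now accumulates repeated assignments into plain lists instead of the old set-based `CaseAttributeValue` machinery. A self-contained sketch of just that merging behaviour (`make_list` here is a hypothetical stand-in for the `ripple_down_rules.utils.make_list` helper):

```python
from collections import UserDict


def make_list(value):
    # Hypothetical stand-in for ripple_down_rules.utils.make_list.
    return list(value) if isinstance(value, (list, tuple, set)) else [value]


class MiniCase(UserDict):
    """Illustrates only the list-merging __setitem__ behaviour shown in the diff above."""

    def __setitem__(self, name, value):
        name = name.lower()
        if name in self:
            if isinstance(self[name], list):
                self[name].extend(make_list(value))
            elif isinstance(value, list):
                merged = make_list(self[name])
                merged.extend(make_list(value))
                super().__setitem__(name, merged)
            else:
                super().__setitem__(name, [self[name], value])
        else:
            super().__setitem__(name, value)


c = MiniCase()
c["habitat"] = "land"
c["habitat"] = "water"   # second assignment merges into a list
print(c["habitat"])      # ['land', 'water']
```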
ripple_down_rules-0.1.0/src/ripple_down_rules/datastructures/dataclasses.py
ADDED
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+import inspect
+from dataclasses import dataclass
+
+from sqlalchemy.orm import DeclarativeBase as SQLTable
+from typing_extensions import Any, Optional, Type, List, Tuple, Set, Dict, TYPE_CHECKING
+
+from .case import create_case, Case
+from ..utils import get_attribute_name, copy_case, get_hint_for_attribute, typing_to_python_type
+
+if TYPE_CHECKING:
+    from . import CallableExpression
+
+@dataclass
+class CaseQuery:
+    """
+    This is a dataclass that represents an attribute of an object and its target value. If attribute name is
+    not provided, it will be inferred from the attribute itself or from the attribute type or from the target value,
+    depending on what is provided.
+    """
+    case: Any
+    """
+    The case that the attribute belongs to.
+    """
+    attribute_name: str
+    """
+    The name of the attribute.
+    """
+    target: Optional[Any] = None
+    """
+    The target value of the attribute.
+    """
+    mutually_exclusive: bool = False
+    """
+    Whether the attribute can only take one value (i.e. True) or multiple values (i.e. False).
+    """
+    conditions: Optional[CallableExpression] = None
+    """
+    The conditions that must be satisfied for the target value to be valid.
+    """
+    prediction: Optional[CallableExpression] = None
+    """
+    The predicted value of the attribute.
+    """
+    scope: Optional[Dict[str, Any]] = None
+    """
+    The global scope of the case query. This is used to evaluate the conditions and prediction, and is what is available
+    to the user when they are prompted for input. If it is not provided, it will be set to the global scope of the
+    caller.
+    """
+
+    def __init__(self, case: Any, attribute_name: str,
+                 target: Optional[Any] = None,
+                 mutually_exclusive: bool = False,
+                 conditions: Optional[CallableExpression] = None,
+                 prediction: Optional[CallableExpression] = None,
+                 scope: Optional[Dict[str, Any]] = None,):
+        self.original_case = case
+        self.case = self._get_case()
+
+        self.attribute_name = attribute_name
+        self.target = target
+        self.attribute_type = self._get_attribute_type()
+        self.mutually_exclusive = mutually_exclusive
+        self.conditions = conditions
+        self.prediction = prediction
+        self.scope = scope if scope is not None else inspect.currentframe().f_back.f_globals
+
+    def _get_case(self) -> Any:
+        if not isinstance(self.original_case, (Case, SQLTable)):
+            return create_case(self.original_case, max_recursion_idx=3)
+        else:
+            return self.original_case
+
+    def _get_attribute_type(self) -> Type:
+        """
+        :return: The type of the attribute.
+        """
+        if self.target is not None:
+            return type(self.target)
+        elif hasattr(self.original_case, self.attribute_name):
+            hint, origin, args = get_hint_for_attribute(self.attribute_name, self.original_case)
+            if origin is not None:
+                return typing_to_python_type(origin)
+            elif hint is not None:
+                return typing_to_python_type(hint)
+
+    @property
+    def name(self):
+        """
+        :return: The name of the case query.
+        """
+        return f"{self.case_name}.{self.attribute_name}"
+
+    @property
+    def case_name(self) -> str:
+        """
+        :return: The name of the case.
+        """
+        return self.case._name if isinstance(self.case, Case) else self.case.__class__.__name__
+
+    def __str__(self):
+        header = f"CaseQuery: {self.name}"
+        target = f"Target: {self.name} |= {self.target if self.target is not None else '?'}"
+        prediction = f"Prediction: {self.name} |= {self.prediction if self.prediction is not None else '?'}"
+        conditions = f"Conditions: {self.conditions if self.conditions is not None else '?'}"
+        return "\n".join([header, target, prediction, conditions])
+
+    def __repr__(self):
+        return self.__str__()
+
+    def __copy__(self):
+        return CaseQuery(copy_case(self.case), self.attribute_name, self.target, self.mutually_exclusive,
+                         self.conditions, self.prediction, self.scope)
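The new `CaseQuery` dataclass bundles a case, the attribute being asked about, and the optional target, conditions, and prediction into one object that the RDR classes and the expert prompt can share. A hypothetical usage sketch (assumes ripple_down_rules 0.1.0 is installed; `Animal` and its attributes are invented for the example):

```python
from ripple_down_rules.datastructures.dataclasses import CaseQuery


class Animal:
    def __init__(self, milk: int, legs: int):
        self.milk = milk
        self.legs = legs
        self.species: str = ""


# Ask "what is the species of this animal?", with a known target for training.
query = CaseQuery(Animal(milk=1, legs=4), attribute_name="species", target="mammal")
print(query.name)  # e.g. "Animal.species"
print(query)       # CaseQuery / Target / Prediction / Conditions summary from __str__
```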