ripple-down-rules 0.1.21__py3-none-any.whl → 0.1.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datasets.py +2 -1
- ripple_down_rules/datastructures/__init__.py +4 -4
- ripple_down_rules/datastructures/callable_expression.py +74 -129
- ripple_down_rules/datastructures/case.py +8 -6
- ripple_down_rules/datastructures/dataclasses.py +102 -48
- ripple_down_rules/datastructures/enums.py +5 -1
- ripple_down_rules/experts.py +61 -68
- ripple_down_rules/helpers.py +27 -3
- ripple_down_rules/prompt.py +87 -74
- ripple_down_rules/rdr.py +291 -206
- ripple_down_rules/rules.py +64 -32
- ripple_down_rules/utils.py +209 -4
- {ripple_down_rules-0.1.21.dist-info → ripple_down_rules-0.1.62.dist-info}/METADATA +5 -4
- ripple_down_rules-0.1.62.dist-info/RECORD +20 -0
- {ripple_down_rules-0.1.21.dist-info → ripple_down_rules-0.1.62.dist-info}/WHEEL +1 -1
- ripple_down_rules-0.1.21.dist-info/RECORD +0 -20
- {ripple_down_rules-0.1.21.dist-info → ripple_down_rules-0.1.62.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.1.21.dist-info → ripple_down_rules-0.1.62.dist-info}/top_level.txt +0 -0
ripple_down_rules/datasets.py
CHANGED
@@ -9,7 +9,8 @@ from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationshi
|
|
9
9
|
from typing_extensions import Tuple, List, Set, Optional
|
10
10
|
from ucimlrepo import fetch_ucirepo
|
11
11
|
|
12
|
-
from .datastructures import Case, create_cases_from_dataframe
|
12
|
+
from .datastructures.case import Case, create_cases_from_dataframe
|
13
|
+
from .datastructures.enums import Category
|
13
14
|
|
14
15
|
|
15
16
|
def load_cached_dataset(cache_file):
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from .enums import *
|
2
|
-
from .dataclasses import *
|
3
|
-
from .callable_expression import *
|
4
|
-
from .case import *
|
1
|
+
# from .enums import *
|
2
|
+
# from .dataclasses import *
|
3
|
+
# from .callable_expression import *
|
4
|
+
# from .case import *
|
@@ -4,11 +4,10 @@ import ast
|
|
4
4
|
import logging
|
5
5
|
from _ast import AST
|
6
6
|
|
7
|
-
from sqlalchemy.orm import Session
|
8
7
|
from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set
|
9
8
|
|
10
9
|
from .case import create_case, Case
|
11
|
-
from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string
|
10
|
+
from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string, conclusion_to_json, is_iterable
|
12
11
|
|
13
12
|
|
14
13
|
class VariableVisitor(ast.NodeVisitor):
|
@@ -65,7 +64,8 @@ class VariableVisitor(ast.NodeVisitor):
|
|
65
64
|
|
66
65
|
def get_used_scope(code_str, scope):
|
67
66
|
# Parse the code into an AST
|
68
|
-
|
67
|
+
mode = 'exec' if code_str.startswith('def') else 'eval'
|
68
|
+
tree = ast.parse(code_str, mode=mode)
|
69
69
|
|
70
70
|
# Walk the AST to collect used variable names
|
71
71
|
class NameCollector(ast.NodeVisitor):
|
@@ -88,91 +88,56 @@ class CallableExpression(SubclassJSONSerializer):
|
|
88
88
|
"""
|
89
89
|
A callable that is constructed from a string statement written by an expert.
|
90
90
|
"""
|
91
|
-
conclusion_type: Type
|
92
|
-
"""
|
93
|
-
The type of the output of the callable, used for assertion.
|
94
|
-
"""
|
95
|
-
expression_tree: AST
|
96
|
-
"""
|
97
|
-
The AST tree parsed from the user input.
|
98
|
-
"""
|
99
|
-
user_input: str
|
100
|
-
"""
|
101
|
-
The input given by the expert.
|
102
|
-
"""
|
103
|
-
session: Optional[Session]
|
104
|
-
"""
|
105
|
-
The sqlalchemy orm session.
|
106
|
-
"""
|
107
|
-
visitor: VariableVisitor
|
108
|
-
"""
|
109
|
-
A visitor to extract all variables and comparisons from a python expression represented as an AST tree.
|
110
|
-
"""
|
111
|
-
code: Any
|
112
|
-
"""
|
113
|
-
The code that was compiled from the expression tree
|
114
|
-
"""
|
115
|
-
compares_column_offset: List[int]
|
116
|
-
"""
|
117
|
-
The start and end indices of each comparison in the string of user input.
|
118
|
-
"""
|
119
91
|
|
120
|
-
def __init__(self, user_input:
|
121
|
-
|
92
|
+
def __init__(self, user_input: Optional[str] = None, conclusion_type: Optional[Tuple[Type]] = None,
|
93
|
+
expression_tree: Optional[AST] = None,
|
94
|
+
scope: Optional[Dict[str, Any]] = None, conclusion: Optional[Any] = None):
|
122
95
|
"""
|
123
96
|
Create a callable expression.
|
124
97
|
|
125
98
|
:param user_input: The input given by the expert.
|
126
99
|
:param conclusion_type: The type of the output of the callable.
|
127
100
|
:param expression_tree: The AST tree parsed from the user input.
|
128
|
-
:param
|
101
|
+
:param scope: The scope to use for the callable expression.
|
102
|
+
:param conclusion: The conclusion to use for the callable expression.
|
129
103
|
"""
|
130
|
-
|
104
|
+
if user_input is None and conclusion is None:
|
105
|
+
raise ValueError("Either user_input or conclusion must be provided.")
|
106
|
+
self.conclusion: Optional[Any] = conclusion
|
131
107
|
self.user_input: str = user_input
|
132
|
-
|
108
|
+
if conclusion_type is not None:
|
109
|
+
if is_iterable(conclusion_type):
|
110
|
+
conclusion_type = tuple(conclusion_type)
|
111
|
+
else:
|
112
|
+
conclusion_type = (conclusion_type,)
|
133
113
|
self.conclusion_type = conclusion_type
|
134
114
|
self.scope: Optional[Dict[str, Any]] = scope if scope is not None else {}
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
:return: The used scope in the user input.
|
142
|
-
"""
|
143
|
-
return self.visitor.variables.union(self.visitor.attributes.keys())
|
144
|
-
|
145
|
-
@staticmethod
|
146
|
-
def parse_user_input(user_input: str, session: Optional[Session] = None) -> str:
|
147
|
-
if ',' in user_input:
|
148
|
-
user_input = user_input.split(',')
|
149
|
-
user_input = [f"({u.strip()})" for u in user_input]
|
150
|
-
user_input = ' & '.join(user_input) if session else ' and '.join(user_input)
|
151
|
-
elif session:
|
152
|
-
user_input = user_input.replace(" and ", " & ")
|
153
|
-
user_input = user_input.replace(" or ", " | ")
|
154
|
-
return user_input
|
155
|
-
|
156
|
-
def update_expression(self, user_input: str, expression_tree: Optional[AST] = None):
|
157
|
-
if not expression_tree:
|
158
|
-
expression_tree = parse_string_to_expression(user_input)
|
159
|
-
self.expression_tree: AST = expression_tree
|
160
|
-
self.visitor = VariableVisitor()
|
161
|
-
self.visitor.visit(expression_tree)
|
162
|
-
self.expression_tree = parse_string_to_expression(self.parsed_user_input)
|
163
|
-
self.compares_column_offset = [(c[0].col_offset, c[2].end_col_offset) for c in self.visitor.compares]
|
164
|
-
self.code = compile_expression_to_code(self.expression_tree)
|
115
|
+
if conclusion is None:
|
116
|
+
self.scope = get_used_scope(self.user_input, self.scope)
|
117
|
+
self.expression_tree: AST = expression_tree if expression_tree else parse_string_to_expression(self.user_input)
|
118
|
+
self.code = compile_expression_to_code(self.expression_tree)
|
119
|
+
self.visitor = VariableVisitor()
|
120
|
+
self.visitor.visit(self.expression_tree)
|
165
121
|
|
166
122
|
def __call__(self, case: Any, **kwargs) -> Any:
|
167
123
|
try:
|
168
|
-
if not
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
124
|
+
if self.user_input is not None:
|
125
|
+
if not isinstance(case, Case):
|
126
|
+
case = create_case(case, max_recursion_idx=3)
|
127
|
+
scope = {'case': case, **self.scope}
|
128
|
+
output = eval(self.code, scope)
|
129
|
+
if output is None:
|
130
|
+
output = scope['_get_value'](case)
|
131
|
+
if self.conclusion_type is not None:
|
132
|
+
if is_iterable(output) and not isinstance(output, self.conclusion_type):
|
133
|
+
assert isinstance(list(output)[0], self.conclusion_type), (f"Expected output type {self.conclusion_type},"
|
134
|
+
f" got {type(output)}")
|
135
|
+
else:
|
136
|
+
assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
|
137
|
+
f" got {type(output)}")
|
138
|
+
return output
|
139
|
+
else:
|
140
|
+
return self.conclusion
|
176
141
|
except Exception as e:
|
177
142
|
raise ValueError(f"Error during evaluation: {e}")
|
178
143
|
|
@@ -181,35 +146,57 @@ class CallableExpression(SubclassJSONSerializer):
|
|
181
146
|
Combine this callable expression with another callable expression using the 'and' operator.
|
182
147
|
"""
|
183
148
|
new_user_input = f"({self.user_input}) and ({other.user_input})"
|
184
|
-
return CallableExpression(new_user_input, conclusion_type=self.conclusion_type
|
149
|
+
return CallableExpression(new_user_input, conclusion_type=self.conclusion_type)
|
150
|
+
|
151
|
+
def __eq__(self, other):
|
152
|
+
"""
|
153
|
+
Check if two callable expressions are equal.
|
154
|
+
"""
|
155
|
+
if not isinstance(other, CallableExpression):
|
156
|
+
return False
|
157
|
+
return self.user_input == other.user_input and self.conclusion == other.conclusion
|
158
|
+
|
159
|
+
def __hash__(self):
|
160
|
+
"""
|
161
|
+
Hash the callable expression.
|
162
|
+
"""
|
163
|
+
conclusion_hash = self.conclusion if not isinstance(self.conclusion, set) else frozenset(self.conclusion)
|
164
|
+
return hash((self.user_input, conclusion_hash))
|
185
165
|
|
186
166
|
def __str__(self):
|
187
167
|
"""
|
188
168
|
Return the user string where each compare is written in a line using compare column offset start and end.
|
189
169
|
"""
|
190
|
-
user_input
|
170
|
+
if self.user_input is None:
|
171
|
+
return str(self.conclusion)
|
191
172
|
binary_ops = sorted(self.visitor.binary_ops, key=lambda x: x.end_col_offset)
|
192
173
|
binary_ops_indices = [b.end_col_offset for b in binary_ops]
|
193
174
|
all_binary_ops = []
|
194
175
|
prev_e = 0
|
195
176
|
for i, e in enumerate(binary_ops_indices):
|
196
177
|
if i == 0:
|
197
|
-
all_binary_ops.append(user_input[:e])
|
178
|
+
all_binary_ops.append(self.user_input[:e])
|
198
179
|
else:
|
199
|
-
all_binary_ops.append(user_input[prev_e:e])
|
180
|
+
all_binary_ops.append(self.user_input[prev_e:e])
|
200
181
|
prev_e = e
|
201
|
-
return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else user_input
|
182
|
+
return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else self.user_input
|
202
183
|
|
203
184
|
def _to_json(self) -> Dict[str, Any]:
|
204
|
-
return {"user_input": self.user_input,
|
185
|
+
return {"user_input": self.user_input,
|
186
|
+
"conclusion_type": [get_full_class_name(t) for t in self.conclusion_type]
|
187
|
+
if self.conclusion_type is not None else None,
|
205
188
|
"scope": {k: get_full_class_name(v) for k, v in self.scope.items()
|
206
|
-
if hasattr(v, '__module__') and hasattr(v, '__name__')}
|
189
|
+
if hasattr(v, '__module__') and hasattr(v, '__name__')},
|
190
|
+
"conclusion": conclusion_to_json(self.conclusion),
|
207
191
|
}
|
208
192
|
|
209
193
|
@classmethod
|
210
194
|
def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
|
211
|
-
return cls(user_input=data["user_input"],
|
212
|
-
|
195
|
+
return cls(user_input=data["user_input"],
|
196
|
+
conclusion_type=tuple(get_type_from_string(t) for t in data["conclusion_type"])
|
197
|
+
if data["conclusion_type"] else None,
|
198
|
+
scope={k: get_type_from_string(v) for k, v in data["scope"].items()},
|
199
|
+
conclusion=SubclassJSONSerializer.from_json(data["conclusion"]))
|
213
200
|
|
214
201
|
|
215
202
|
def compile_expression_to_code(expression_tree: AST) -> Any:
|
@@ -219,51 +206,8 @@ def compile_expression_to_code(expression_tree: AST) -> Any:
|
|
219
206
|
:param expression_tree: The parsed expression tree.
|
220
207
|
:return: The code that was compiled from the expression tree.
|
221
208
|
"""
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
def assert_context_contains_needed_information(case: Any, context: Dict[str, Any],
|
226
|
-
visitor: VariableVisitor) -> Tuple[Set[str], Set[str]]:
|
227
|
-
"""
|
228
|
-
Asserts that the variables mentioned in the expression visited by visitor are all in the given context.
|
229
|
-
|
230
|
-
:param case: The case to check the context for.
|
231
|
-
:param context: The context to check.
|
232
|
-
:param visitor: The visitor that visited the expression.
|
233
|
-
:return: The found variables and attributes.
|
234
|
-
"""
|
235
|
-
found_variables = set()
|
236
|
-
for key in visitor.variables:
|
237
|
-
if key not in context:
|
238
|
-
raise ValueError(f"Variable {key} not found in the case {case}")
|
239
|
-
found_variables.add(key)
|
240
|
-
|
241
|
-
found_attributes = get_attributes_str(visitor)
|
242
|
-
for attr in found_attributes:
|
243
|
-
if attr not in context:
|
244
|
-
raise ValueError(f"Attribute {attr} not found in the case {case}")
|
245
|
-
return found_variables, found_attributes
|
246
|
-
|
247
|
-
|
248
|
-
def get_attributes_str(visitor: VariableVisitor) -> Set[str]:
|
249
|
-
"""
|
250
|
-
Get the string representation of the attributes in the given visitor.
|
251
|
-
|
252
|
-
:param visitor: The visitor that visited the expression.
|
253
|
-
:return: The string representation of the attributes.
|
254
|
-
"""
|
255
|
-
found_attributes = set()
|
256
|
-
for key, ast_attr in visitor.attributes.items():
|
257
|
-
str_attr = ""
|
258
|
-
while isinstance(key, ast.Attribute):
|
259
|
-
if len(str_attr) > 0:
|
260
|
-
str_attr = f"{key.attr}.{str_attr}"
|
261
|
-
else:
|
262
|
-
str_attr = key.attr
|
263
|
-
key = key.value
|
264
|
-
str_attr = f"{key.id}.{str_attr}" if len(str_attr) > 0 else f"{key.id}.{ast_attr.attr}"
|
265
|
-
found_attributes.add(str_attr)
|
266
|
-
return found_attributes
|
209
|
+
mode = 'exec' if isinstance(expression_tree, ast.Module) else 'eval'
|
210
|
+
return compile(expression_tree, filename="<string>", mode=mode)
|
267
211
|
|
268
212
|
|
269
213
|
def parse_string_to_expression(expression_str: str) -> AST:
|
@@ -273,6 +217,7 @@ def parse_string_to_expression(expression_str: str) -> AST:
|
|
273
217
|
:param expression_str: The string which will be parsed.
|
274
218
|
:return: The parsed expression.
|
275
219
|
"""
|
276
|
-
|
220
|
+
mode = 'exec' if expression_str.startswith('def') else 'eval'
|
221
|
+
tree = ast.parse(expression_str, mode=mode)
|
277
222
|
logging.debug(f"AST parsed successfully: {ast.dump(tree)}")
|
278
223
|
return tree
|
@@ -24,7 +24,8 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
24
24
|
the names of the attributes and the values are the attributes. All are stored in lower case.
|
25
25
|
"""
|
26
26
|
|
27
|
-
def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None,
|
27
|
+
def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None,
|
28
|
+
_name: Optional[str] = None, original_object: Optional[Any] = None, **kwargs):
|
28
29
|
"""
|
29
30
|
Create a new row.
|
30
31
|
|
@@ -34,6 +35,7 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
34
35
|
:param kwargs: The attributes of the row.
|
35
36
|
"""
|
36
37
|
super().__init__(kwargs)
|
38
|
+
self._original_object = original_object
|
37
39
|
self._obj_type: Type = _obj_type
|
38
40
|
self._id: Hashable = _id if _id is not None else id(self)
|
39
41
|
self._name: str = _name if _name is not None else self._obj_type.__name__
|
@@ -63,7 +65,7 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
63
65
|
new_list.extend(make_list(value))
|
64
66
|
super().__setitem__(name, new_list)
|
65
67
|
else:
|
66
|
-
super().__setitem__(name,
|
68
|
+
super().__setitem__(name, self[name])
|
67
69
|
else:
|
68
70
|
super().__setitem__(name, value)
|
69
71
|
setattr(self, name, self[name])
|
@@ -221,9 +223,9 @@ def create_case(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
|
|
221
223
|
return obj
|
222
224
|
if ((recursion_idx > max_recursion_idx) or (obj.__class__.__module__ == "builtins")
|
223
225
|
or (obj.__class__ in [MetaData, registry])):
|
224
|
-
return Case(type(obj), _id=id(obj), _name=obj_name,
|
226
|
+
return Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj,
|
225
227
|
**{obj_name or obj.__class__.__name__: make_list(obj) if parent_is_iterable else obj})
|
226
|
-
case = Case(type(obj), _id=id(obj), _name=obj_name)
|
228
|
+
case = Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj)
|
227
229
|
for attr in dir(obj):
|
228
230
|
if attr.startswith("_") or callable(getattr(obj, attr)):
|
229
231
|
continue
|
@@ -251,7 +253,7 @@ def create_or_update_case_from_attribute(attr_value: Any, name: str, obj: Any, o
|
|
251
253
|
:return: The updated/created case.
|
252
254
|
"""
|
253
255
|
if case is None:
|
254
|
-
case = Case(type(obj), _id=id(obj), _name=obj_name)
|
256
|
+
case = Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj)
|
255
257
|
if isinstance(attr_value, (dict, UserDict)):
|
256
258
|
case.update({f"{obj_name}.{k}": v for k, v in attr_value.items()})
|
257
259
|
if hasattr(attr_value, "__iter__") and not isinstance(attr_value, str):
|
@@ -280,7 +282,7 @@ def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, ob
|
|
280
282
|
"""
|
281
283
|
values = list(attr_value.values()) if isinstance(attr_value, (dict, UserDict)) else attr_value
|
282
284
|
_type = type(list(values)[0]) if len(values) > 0 else get_value_type_from_type_hint(name, obj)
|
283
|
-
attr_case = Case(_type, _id=id(attr_value), _name=name)
|
285
|
+
attr_case = Case(_type, _id=id(attr_value), _name=name, original_object=attr_value)
|
284
286
|
case_attr = CaseAttribute(values)
|
285
287
|
for idx, val in enumerate(values):
|
286
288
|
sub_attr_case = create_case(val, recursion_idx=recursion_idx,
|
@@ -1,16 +1,15 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import inspect
|
4
|
-
from dataclasses import dataclass
|
4
|
+
from dataclasses import dataclass, field
|
5
5
|
|
6
6
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
7
|
-
from typing_extensions import Any, Optional,
|
7
|
+
from typing_extensions import Any, Optional, Dict, Type, Tuple, Union
|
8
8
|
|
9
|
+
from .callable_expression import CallableExpression
|
9
10
|
from .case import create_case, Case
|
10
|
-
from ..utils import
|
11
|
+
from ..utils import copy_case, make_list, make_set
|
11
12
|
|
12
|
-
if TYPE_CHECKING:
|
13
|
-
from . import CallableExpression
|
14
13
|
|
15
14
|
@dataclass
|
16
15
|
class CaseQuery:
|
@@ -19,7 +18,7 @@ class CaseQuery:
|
|
19
18
|
not provided, it will be inferred from the attribute itself or from the attribute type or from the target value,
|
20
19
|
depending on what is provided.
|
21
20
|
"""
|
22
|
-
|
21
|
+
original_case: Any
|
23
22
|
"""
|
24
23
|
The case that the attribute belongs to.
|
25
24
|
"""
|
@@ -27,64 +26,80 @@ class CaseQuery:
|
|
27
26
|
"""
|
28
27
|
The name of the attribute.
|
29
28
|
"""
|
30
|
-
|
29
|
+
_attribute_types: Tuple[Type]
|
31
30
|
"""
|
32
|
-
The
|
31
|
+
The type(s) of the attribute.
|
33
32
|
"""
|
34
|
-
mutually_exclusive: bool
|
33
|
+
mutually_exclusive: bool
|
35
34
|
"""
|
36
35
|
Whether the attribute can only take one value (i.e. True) or multiple values (i.e. False).
|
37
36
|
"""
|
38
|
-
|
37
|
+
_target: Optional[CallableExpression] = None
|
39
38
|
"""
|
40
|
-
The
|
39
|
+
The target expression of the attribute.
|
41
40
|
"""
|
42
|
-
|
41
|
+
default_value: Optional[Any] = None
|
43
42
|
"""
|
44
|
-
The
|
43
|
+
The default value of the attribute. This is used when the target value is not provided.
|
45
44
|
"""
|
46
|
-
scope: Optional[Dict[str, Any]] =
|
45
|
+
scope: Optional[Dict[str, Any]] = field(default_factory=lambda: inspect.currentframe().f_back.f_back.f_globals)
|
47
46
|
"""
|
48
47
|
The global scope of the case query. This is used to evaluate the conditions and prediction, and is what is available
|
49
48
|
to the user when they are prompted for input. If it is not provided, it will be set to the global scope of the
|
50
49
|
caller.
|
51
50
|
"""
|
51
|
+
_case: Optional[Union[Case, SQLTable]] = None
|
52
|
+
"""
|
53
|
+
The created case from the original case that the attribute belongs to.
|
54
|
+
"""
|
55
|
+
_target_value: Optional[Any] = None
|
56
|
+
"""
|
57
|
+
The target value of the case query. (This is the result of the target expression evaluation on the case.)
|
58
|
+
"""
|
59
|
+
conditions: Optional[CallableExpression] = None
|
60
|
+
"""
|
61
|
+
The conditions that must be satisfied for the target value to be valid.
|
62
|
+
"""
|
52
63
|
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
self.attribute_name = attribute_name
|
63
|
-
self.target = target
|
64
|
-
self.attribute_type = self._get_attribute_type()
|
65
|
-
self.mutually_exclusive = mutually_exclusive
|
66
|
-
self.conditions = conditions
|
67
|
-
self.prediction = prediction
|
68
|
-
self.scope = scope if scope is not None else inspect.currentframe().f_back.f_globals
|
69
|
-
|
70
|
-
def _get_case(self) -> Any:
|
71
|
-
if not isinstance(self.original_case, (Case, SQLTable)):
|
72
|
-
return create_case(self.original_case, max_recursion_idx=3)
|
64
|
+
@property
|
65
|
+
def case(self) -> Any:
|
66
|
+
"""
|
67
|
+
:return: The case that the attribute belongs to.
|
68
|
+
"""
|
69
|
+
if self._case is not None:
|
70
|
+
return self._case
|
71
|
+
elif not isinstance(self.original_case, (Case, SQLTable)):
|
72
|
+
self._case = create_case(self.original_case, max_recursion_idx=3)
|
73
73
|
else:
|
74
|
-
|
74
|
+
self._case = self.original_case
|
75
|
+
return self._case
|
75
76
|
|
76
|
-
|
77
|
+
@case.setter
|
78
|
+
def case(self, value: Any):
|
79
|
+
"""
|
80
|
+
Set the case that the attribute belongs to.
|
81
|
+
"""
|
82
|
+
if not isinstance(value, (Case, SQLTable)):
|
83
|
+
raise ValueError("The case must be a Case or SQLTable object.")
|
84
|
+
self._case = value
|
85
|
+
|
86
|
+
@property
|
87
|
+
def attribute_type(self) -> Tuple[Type]:
|
77
88
|
"""
|
78
89
|
:return: The type of the attribute.
|
79
90
|
"""
|
80
|
-
if self.
|
81
|
-
|
82
|
-
elif
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
91
|
+
if not self.mutually_exclusive and (set not in make_list(self._attribute_types)):
|
92
|
+
self._attribute_types = tuple(set(make_list(self._attribute_types) + [set, list]))
|
93
|
+
elif not isinstance(self._attribute_types, tuple):
|
94
|
+
self._attribute_types = tuple(make_list(self._attribute_types))
|
95
|
+
return self._attribute_types
|
96
|
+
|
97
|
+
@attribute_type.setter
|
98
|
+
def attribute_type(self, value: Type):
|
99
|
+
"""
|
100
|
+
Set the type of the attribute.
|
101
|
+
"""
|
102
|
+
self._attribute_types = tuple(make_list(value))
|
88
103
|
|
89
104
|
@property
|
90
105
|
def name(self):
|
@@ -100,16 +115,55 @@ class CaseQuery:
|
|
100
115
|
"""
|
101
116
|
return self.case._name if isinstance(self.case, Case) else self.case.__class__.__name__
|
102
117
|
|
118
|
+
@property
|
119
|
+
def target(self) -> Optional[CallableExpression]:
|
120
|
+
"""
|
121
|
+
:return: The target expression of the attribute.
|
122
|
+
"""
|
123
|
+
if self._target is not None and not isinstance(self._target, CallableExpression):
|
124
|
+
self._target = CallableExpression(conclusion=self._target, conclusion_type=self.attribute_type,
|
125
|
+
scope=self.scope)
|
126
|
+
return self._target
|
127
|
+
|
128
|
+
@target.setter
|
129
|
+
def target(self, value: Optional[CallableExpression]):
|
130
|
+
"""
|
131
|
+
Set the target expression of the attribute.
|
132
|
+
"""
|
133
|
+
if value is not None and not isinstance(value, (CallableExpression, str)):
|
134
|
+
raise ValueError("The target must be a CallableExpression or a string.")
|
135
|
+
self._target = value
|
136
|
+
self._update_target_value()
|
137
|
+
|
138
|
+
@property
|
139
|
+
def target_value(self) -> Any:
|
140
|
+
"""
|
141
|
+
:return: The target value of the case query.
|
142
|
+
"""
|
143
|
+
if self._target_value is None:
|
144
|
+
self._update_target_value()
|
145
|
+
return self._target_value
|
146
|
+
|
147
|
+
def _update_target_value(self):
|
148
|
+
"""
|
149
|
+
Update the target value of the case query.
|
150
|
+
"""
|
151
|
+
if isinstance(self.target, CallableExpression):
|
152
|
+
self._target_value = self.target(self.case)
|
153
|
+
else:
|
154
|
+
self._target_value = self.target
|
155
|
+
|
103
156
|
def __str__(self):
|
104
157
|
header = f"CaseQuery: {self.name}"
|
105
158
|
target = f"Target: {self.name} |= {self.target if self.target is not None else '?'}"
|
106
|
-
prediction = f"Prediction: {self.name} |= {self.prediction if self.prediction is not None else '?'}"
|
107
159
|
conditions = f"Conditions: {self.conditions if self.conditions is not None else '?'}"
|
108
|
-
return "\n".join([header, target,
|
160
|
+
return "\n".join([header, target, conditions])
|
109
161
|
|
110
162
|
def __repr__(self):
|
111
163
|
return self.__str__()
|
112
164
|
|
113
165
|
def __copy__(self):
|
114
|
-
return CaseQuery(
|
115
|
-
self.
|
166
|
+
return CaseQuery(self.original_case, self.attribute_name, self.attribute_type,
|
167
|
+
self.mutually_exclusive, _target=self.target, default_value=self.default_value,
|
168
|
+
scope=self.scope, _case=copy_case(self.case), _target_value=self.target_value,
|
169
|
+
conditions=self.conditions)
|
@@ -53,7 +53,7 @@ class ExpressionParser(Enum):
|
|
53
53
|
|
54
54
|
class PromptFor(Enum):
|
55
55
|
"""
|
56
|
-
The reason of the prompt. (e.g. get conditions, or
|
56
|
+
The reason of the prompt. (e.g. get conditions, conclusions, or affirmation).
|
57
57
|
"""
|
58
58
|
Conditions: str = "conditions"
|
59
59
|
"""
|
@@ -63,6 +63,10 @@ class PromptFor(Enum):
|
|
63
63
|
"""
|
64
64
|
Prompt for rule conclusion about a case.
|
65
65
|
"""
|
66
|
+
Affirmation: str = "affirmation"
|
67
|
+
"""
|
68
|
+
Prompt for rule conclusion affirmation about a case.
|
69
|
+
"""
|
66
70
|
|
67
71
|
def __str__(self):
|
68
72
|
return self.name
|