ripple-down-rules 0.1.3__tar.gz → 0.1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/PKG-INFO +1 -1
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/pyproject.toml +1 -1
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/datasets.py +2 -1
- ripple_down_rules-0.1.6/src/ripple_down_rules/datastructures/__init__.py +4 -0
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/datastructures/callable_expression.py +68 -128
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/datastructures/case.py +1 -1
- ripple_down_rules-0.1.6/src/ripple_down_rules/datastructures/dataclasses.py +168 -0
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/experts.py +24 -22
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/prompt.py +44 -50
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/rdr.py +270 -164
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/rules.py +64 -32
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/utils.py +130 -2
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules.egg-info/SOURCES.txt +1 -0
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/test/test_json_serialization.py +2 -2
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/test/test_on_mutagenic.py +5 -5
- ripple_down_rules-0.1.6/test/test_rdr.py +331 -0
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/test/test_rdr_alchemy.py +24 -30
- ripple_down_rules-0.1.6/test/test_rdr_world.py +119 -0
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/test/test_relational_rdr.py +7 -7
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/test/test_relational_rdr_alchemy.py +8 -7
- ripple_down_rules-0.1.3/src/ripple_down_rules/datastructures/__init__.py +0 -4
- ripple_down_rules-0.1.3/src/ripple_down_rules/datastructures/dataclasses.py +0 -115
- ripple_down_rules-0.1.3/test/test_rdr.py +0 -243
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/LICENSE +0 -0
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/README.md +0 -0
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/setup.cfg +0 -0
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/__init__.py +0 -0
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/datastructures/enums.py +0 -0
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/failures.py +0 -0
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/helpers.py +0 -0
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/rdr_decorators.py +0 -0
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
- {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.6
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
|
|
6
6
|
|
7
7
|
[project]
|
8
8
|
name = "ripple_down_rules"
|
9
|
-
version = "0.1.
|
9
|
+
version = "0.1.6"
|
10
10
|
description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
|
11
11
|
readme = "README.md"
|
12
12
|
authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
|
@@ -9,7 +9,8 @@ from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationshi
|
|
9
9
|
from typing_extensions import Tuple, List, Set, Optional
|
10
10
|
from ucimlrepo import fetch_ucirepo
|
11
11
|
|
12
|
-
from .datastructures import Case, create_cases_from_dataframe
|
12
|
+
from .datastructures.case import Case, create_cases_from_dataframe
|
13
|
+
from .datastructures.enums import Category
|
13
14
|
|
14
15
|
|
15
16
|
def load_cached_dataset(cache_file):
|
@@ -4,11 +4,10 @@ import ast
|
|
4
4
|
import logging
|
5
5
|
from _ast import AST
|
6
6
|
|
7
|
-
from sqlalchemy.orm import Session
|
8
7
|
from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set
|
9
8
|
|
10
9
|
from .case import create_case, Case
|
11
|
-
from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string
|
10
|
+
from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string, conclusion_to_json, is_iterable
|
12
11
|
|
13
12
|
|
14
13
|
class VariableVisitor(ast.NodeVisitor):
|
@@ -89,93 +88,56 @@ class CallableExpression(SubclassJSONSerializer):
|
|
89
88
|
"""
|
90
89
|
A callable that is constructed from a string statement written by an expert.
|
91
90
|
"""
|
92
|
-
conclusion_type: Type
|
93
|
-
"""
|
94
|
-
The type of the output of the callable, used for assertion.
|
95
|
-
"""
|
96
|
-
expression_tree: AST
|
97
|
-
"""
|
98
|
-
The AST tree parsed from the user input.
|
99
|
-
"""
|
100
|
-
user_input: str
|
101
|
-
"""
|
102
|
-
The input given by the expert.
|
103
|
-
"""
|
104
|
-
session: Optional[Session]
|
105
|
-
"""
|
106
|
-
The sqlalchemy orm session.
|
107
|
-
"""
|
108
|
-
visitor: VariableVisitor
|
109
|
-
"""
|
110
|
-
A visitor to extract all variables and comparisons from a python expression represented as an AST tree.
|
111
|
-
"""
|
112
|
-
code: Any
|
113
|
-
"""
|
114
|
-
The code that was compiled from the expression tree
|
115
|
-
"""
|
116
|
-
compares_column_offset: List[int]
|
117
|
-
"""
|
118
|
-
The start and end indices of each comparison in the string of user input.
|
119
|
-
"""
|
120
91
|
|
121
|
-
def __init__(self, user_input:
|
122
|
-
|
92
|
+
def __init__(self, user_input: Optional[str] = None, conclusion_type: Optional[Tuple[Type]] = None,
|
93
|
+
expression_tree: Optional[AST] = None,
|
94
|
+
scope: Optional[Dict[str, Any]] = None, conclusion: Optional[Any] = None):
|
123
95
|
"""
|
124
96
|
Create a callable expression.
|
125
97
|
|
126
98
|
:param user_input: The input given by the expert.
|
127
99
|
:param conclusion_type: The type of the output of the callable.
|
128
100
|
:param expression_tree: The AST tree parsed from the user input.
|
129
|
-
:param
|
101
|
+
:param scope: The scope to use for the callable expression.
|
102
|
+
:param conclusion: The conclusion to use for the callable expression.
|
130
103
|
"""
|
131
|
-
|
104
|
+
if user_input is None and conclusion is None:
|
105
|
+
raise ValueError("Either user_input or conclusion must be provided.")
|
106
|
+
self.conclusion: Optional[Any] = conclusion
|
132
107
|
self.user_input: str = user_input
|
133
|
-
|
108
|
+
if conclusion_type is not None:
|
109
|
+
if is_iterable(conclusion_type):
|
110
|
+
conclusion_type = tuple(conclusion_type)
|
111
|
+
else:
|
112
|
+
conclusion_type = (conclusion_type,)
|
134
113
|
self.conclusion_type = conclusion_type
|
135
114
|
self.scope: Optional[Dict[str, Any]] = scope if scope is not None else {}
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
:return: The used scope in the user input.
|
143
|
-
"""
|
144
|
-
return self.visitor.variables.union(self.visitor.attributes.keys())
|
145
|
-
|
146
|
-
@staticmethod
|
147
|
-
def parse_user_input(user_input: str, session: Optional[Session] = None) -> str:
|
148
|
-
if ',' in user_input:
|
149
|
-
user_input = user_input.split(',')
|
150
|
-
user_input = [f"({u.strip()})" for u in user_input]
|
151
|
-
user_input = ' & '.join(user_input) if session else ' and '.join(user_input)
|
152
|
-
elif session:
|
153
|
-
user_input = user_input.replace(" and ", " & ")
|
154
|
-
user_input = user_input.replace(" or ", " | ")
|
155
|
-
return user_input
|
156
|
-
|
157
|
-
def update_expression(self, user_input: str, expression_tree: Optional[AST] = None):
|
158
|
-
if not expression_tree:
|
159
|
-
expression_tree = parse_string_to_expression(user_input)
|
160
|
-
self.expression_tree: AST = expression_tree
|
161
|
-
self.visitor = VariableVisitor()
|
162
|
-
self.visitor.visit(expression_tree)
|
163
|
-
self.expression_tree = parse_string_to_expression(self.parsed_user_input)
|
164
|
-
self.compares_column_offset = [(c[0].col_offset, c[2].end_col_offset) for c in self.visitor.compares]
|
165
|
-
self.code = compile_expression_to_code(self.expression_tree)
|
115
|
+
if conclusion is None:
|
116
|
+
self.scope = get_used_scope(self.user_input, self.scope)
|
117
|
+
self.expression_tree: AST = expression_tree if expression_tree else parse_string_to_expression(self.user_input)
|
118
|
+
self.code = compile_expression_to_code(self.expression_tree)
|
119
|
+
self.visitor = VariableVisitor()
|
120
|
+
self.visitor.visit(self.expression_tree)
|
166
121
|
|
167
122
|
def __call__(self, case: Any, **kwargs) -> Any:
|
168
123
|
try:
|
169
|
-
if not
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
output
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
124
|
+
if self.user_input is not None:
|
125
|
+
if not isinstance(case, Case):
|
126
|
+
case = create_case(case, max_recursion_idx=3)
|
127
|
+
scope = {'case': case, **self.scope}
|
128
|
+
output = eval(self.code, scope)
|
129
|
+
if output is None:
|
130
|
+
output = scope['_get_value'](case)
|
131
|
+
if self.conclusion_type is not None:
|
132
|
+
if is_iterable(output) and not isinstance(output, self.conclusion_type):
|
133
|
+
assert isinstance(list(output)[0], self.conclusion_type), (f"Expected output type {self.conclusion_type},"
|
134
|
+
f" got {type(output)}")
|
135
|
+
else:
|
136
|
+
assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
|
137
|
+
f" got {type(output)}")
|
138
|
+
return output
|
139
|
+
else:
|
140
|
+
return self.conclusion
|
179
141
|
except Exception as e:
|
180
142
|
raise ValueError(f"Error during evaluation: {e}")
|
181
143
|
|
@@ -184,35 +146,57 @@ class CallableExpression(SubclassJSONSerializer):
|
|
184
146
|
Combine this callable expression with another callable expression using the 'and' operator.
|
185
147
|
"""
|
186
148
|
new_user_input = f"({self.user_input}) and ({other.user_input})"
|
187
|
-
return CallableExpression(new_user_input, conclusion_type=self.conclusion_type
|
149
|
+
return CallableExpression(new_user_input, conclusion_type=self.conclusion_type)
|
150
|
+
|
151
|
+
def __eq__(self, other):
|
152
|
+
"""
|
153
|
+
Check if two callable expressions are equal.
|
154
|
+
"""
|
155
|
+
if not isinstance(other, CallableExpression):
|
156
|
+
return False
|
157
|
+
return self.user_input == other.user_input and self.conclusion == other.conclusion
|
158
|
+
|
159
|
+
def __hash__(self):
|
160
|
+
"""
|
161
|
+
Hash the callable expression.
|
162
|
+
"""
|
163
|
+
conclusion_hash = self.conclusion if not isinstance(self.conclusion, set) else frozenset(self.conclusion)
|
164
|
+
return hash((self.user_input, conclusion_hash))
|
188
165
|
|
189
166
|
def __str__(self):
|
190
167
|
"""
|
191
168
|
Return the user string where each compare is written in a line using compare column offset start and end.
|
192
169
|
"""
|
193
|
-
user_input
|
170
|
+
if self.user_input is None:
|
171
|
+
return str(self.conclusion)
|
194
172
|
binary_ops = sorted(self.visitor.binary_ops, key=lambda x: x.end_col_offset)
|
195
173
|
binary_ops_indices = [b.end_col_offset for b in binary_ops]
|
196
174
|
all_binary_ops = []
|
197
175
|
prev_e = 0
|
198
176
|
for i, e in enumerate(binary_ops_indices):
|
199
177
|
if i == 0:
|
200
|
-
all_binary_ops.append(user_input[:e])
|
178
|
+
all_binary_ops.append(self.user_input[:e])
|
201
179
|
else:
|
202
|
-
all_binary_ops.append(user_input[prev_e:e])
|
180
|
+
all_binary_ops.append(self.user_input[prev_e:e])
|
203
181
|
prev_e = e
|
204
|
-
return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else user_input
|
182
|
+
return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else self.user_input
|
205
183
|
|
206
184
|
def _to_json(self) -> Dict[str, Any]:
|
207
|
-
return {"user_input": self.user_input,
|
185
|
+
return {"user_input": self.user_input,
|
186
|
+
"conclusion_type": [get_full_class_name(t) for t in self.conclusion_type]
|
187
|
+
if self.conclusion_type is not None else None,
|
208
188
|
"scope": {k: get_full_class_name(v) for k, v in self.scope.items()
|
209
|
-
if hasattr(v, '__module__') and hasattr(v, '__name__')}
|
189
|
+
if hasattr(v, '__module__') and hasattr(v, '__name__')},
|
190
|
+
"conclusion": conclusion_to_json(self.conclusion),
|
210
191
|
}
|
211
192
|
|
212
193
|
@classmethod
|
213
194
|
def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
|
214
|
-
return cls(user_input=data["user_input"],
|
215
|
-
|
195
|
+
return cls(user_input=data["user_input"],
|
196
|
+
conclusion_type=tuple(get_type_from_string(t) for t in data["conclusion_type"])
|
197
|
+
if data["conclusion_type"] else None,
|
198
|
+
scope={k: get_type_from_string(v) for k, v in data["scope"].items()},
|
199
|
+
conclusion=SubclassJSONSerializer.from_json(data["conclusion"]))
|
216
200
|
|
217
201
|
|
218
202
|
def compile_expression_to_code(expression_tree: AST) -> Any:
|
@@ -226,50 +210,6 @@ def compile_expression_to_code(expression_tree: AST) -> Any:
|
|
226
210
|
return compile(expression_tree, filename="<string>", mode=mode)
|
227
211
|
|
228
212
|
|
229
|
-
def assert_context_contains_needed_information(case: Any, context: Dict[str, Any],
|
230
|
-
visitor: VariableVisitor) -> Tuple[Set[str], Set[str]]:
|
231
|
-
"""
|
232
|
-
Asserts that the variables mentioned in the expression visited by visitor are all in the given context.
|
233
|
-
|
234
|
-
:param case: The case to check the context for.
|
235
|
-
:param context: The context to check.
|
236
|
-
:param visitor: The visitor that visited the expression.
|
237
|
-
:return: The found variables and attributes.
|
238
|
-
"""
|
239
|
-
found_variables = set()
|
240
|
-
for key in visitor.variables:
|
241
|
-
if key not in context:
|
242
|
-
raise ValueError(f"Variable {key} not found in the case {case}")
|
243
|
-
found_variables.add(key)
|
244
|
-
|
245
|
-
found_attributes = get_attributes_str(visitor)
|
246
|
-
for attr in found_attributes:
|
247
|
-
if attr not in context:
|
248
|
-
raise ValueError(f"Attribute {attr} not found in the case {case}")
|
249
|
-
return found_variables, found_attributes
|
250
|
-
|
251
|
-
|
252
|
-
def get_attributes_str(visitor: VariableVisitor) -> Set[str]:
|
253
|
-
"""
|
254
|
-
Get the string representation of the attributes in the given visitor.
|
255
|
-
|
256
|
-
:param visitor: The visitor that visited the expression.
|
257
|
-
:return: The string representation of the attributes.
|
258
|
-
"""
|
259
|
-
found_attributes = set()
|
260
|
-
for key, ast_attr in visitor.attributes.items():
|
261
|
-
str_attr = ""
|
262
|
-
while isinstance(key, ast.Attribute):
|
263
|
-
if len(str_attr) > 0:
|
264
|
-
str_attr = f"{key.attr}.{str_attr}"
|
265
|
-
else:
|
266
|
-
str_attr = key.attr
|
267
|
-
key = key.value
|
268
|
-
str_attr = f"{key.id}.{str_attr}" if len(str_attr) > 0 else f"{key.id}.{ast_attr.attr}"
|
269
|
-
found_attributes.add(str_attr)
|
270
|
-
return found_attributes
|
271
|
-
|
272
|
-
|
273
213
|
def parse_string_to_expression(expression_str: str) -> AST:
|
274
214
|
"""
|
275
215
|
Parse a string statement into an AST expression.
|
{ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/datastructures/case.py
RENAMED
@@ -63,7 +63,7 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
63
63
|
new_list.extend(make_list(value))
|
64
64
|
super().__setitem__(name, new_list)
|
65
65
|
else:
|
66
|
-
super().__setitem__(name,
|
66
|
+
super().__setitem__(name, self[name])
|
67
67
|
else:
|
68
68
|
super().__setitem__(name, value)
|
69
69
|
setattr(self, name, self[name])
|
@@ -0,0 +1,168 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import inspect
|
4
|
+
from dataclasses import dataclass, field
|
5
|
+
|
6
|
+
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
7
|
+
from typing_extensions import Any, Optional, Dict, Type, Tuple, Union
|
8
|
+
|
9
|
+
from .callable_expression import CallableExpression
|
10
|
+
from .case import create_case, Case
|
11
|
+
from ..utils import copy_case, make_list
|
12
|
+
|
13
|
+
|
14
|
+
@dataclass
|
15
|
+
class CaseQuery:
|
16
|
+
"""
|
17
|
+
This is a dataclass that represents an attribute of an object and its target value. If attribute name is
|
18
|
+
not provided, it will be inferred from the attribute itself or from the attribute type or from the target value,
|
19
|
+
depending on what is provided.
|
20
|
+
"""
|
21
|
+
original_case: Any
|
22
|
+
"""
|
23
|
+
The case that the attribute belongs to.
|
24
|
+
"""
|
25
|
+
attribute_name: str
|
26
|
+
"""
|
27
|
+
The name of the attribute.
|
28
|
+
"""
|
29
|
+
_attribute_types: Tuple[Type]
|
30
|
+
"""
|
31
|
+
The type(s) of the attribute.
|
32
|
+
"""
|
33
|
+
mutually_exclusive: bool
|
34
|
+
"""
|
35
|
+
Whether the attribute can only take one value (i.e. True) or multiple values (i.e. False).
|
36
|
+
"""
|
37
|
+
_target: Optional[CallableExpression] = None
|
38
|
+
"""
|
39
|
+
The target expression of the attribute.
|
40
|
+
"""
|
41
|
+
default_value: Optional[Any] = None
|
42
|
+
"""
|
43
|
+
The default value of the attribute. This is used when the target value is not provided.
|
44
|
+
"""
|
45
|
+
scope: Optional[Dict[str, Any]] = field(default_factory=lambda: inspect.currentframe().f_back.f_back.f_globals)
|
46
|
+
"""
|
47
|
+
The global scope of the case query. This is used to evaluate the conditions and prediction, and is what is available
|
48
|
+
to the user when they are prompted for input. If it is not provided, it will be set to the global scope of the
|
49
|
+
caller.
|
50
|
+
"""
|
51
|
+
_case: Optional[Union[Case, SQLTable]] = None
|
52
|
+
"""
|
53
|
+
The created case from the original case that the attribute belongs to.
|
54
|
+
"""
|
55
|
+
_target_value: Optional[Any] = None
|
56
|
+
"""
|
57
|
+
The target value of the case query. (This is the result of the target expression evaluation on the case.)
|
58
|
+
"""
|
59
|
+
conditions: Optional[CallableExpression] = None
|
60
|
+
"""
|
61
|
+
The conditions that must be satisfied for the target value to be valid.
|
62
|
+
"""
|
63
|
+
|
64
|
+
@property
|
65
|
+
def case(self) -> Any:
|
66
|
+
"""
|
67
|
+
:return: The case that the attribute belongs to.
|
68
|
+
"""
|
69
|
+
if self._case is not None:
|
70
|
+
return self._case
|
71
|
+
elif not isinstance(self.original_case, (Case, SQLTable)):
|
72
|
+
self._case = create_case(self.original_case, max_recursion_idx=3)
|
73
|
+
else:
|
74
|
+
self._case = self.original_case
|
75
|
+
return self._case
|
76
|
+
|
77
|
+
@case.setter
|
78
|
+
def case(self, value: Any):
|
79
|
+
"""
|
80
|
+
Set the case that the attribute belongs to.
|
81
|
+
"""
|
82
|
+
if not isinstance(value, (Case, SQLTable)):
|
83
|
+
raise ValueError("The case must be a Case or SQLTable object.")
|
84
|
+
self._case = value
|
85
|
+
|
86
|
+
@property
|
87
|
+
def attribute_type(self) -> Tuple[Type]:
|
88
|
+
"""
|
89
|
+
:return: The type of the attribute.
|
90
|
+
"""
|
91
|
+
self._attribute_types = tuple(make_list(self._attribute_types))
|
92
|
+
if not self.mutually_exclusive and (set not in self._attribute_types):
|
93
|
+
self._attribute_types = tuple(list(self._attribute_types) + [set])
|
94
|
+
return self._attribute_types
|
95
|
+
|
96
|
+
@attribute_type.setter
|
97
|
+
def attribute_type(self, value: Type):
|
98
|
+
"""
|
99
|
+
Set the type of the attribute.
|
100
|
+
"""
|
101
|
+
self._attribute_types = tuple(make_list(value))
|
102
|
+
|
103
|
+
@property
|
104
|
+
def name(self):
|
105
|
+
"""
|
106
|
+
:return: The name of the case query.
|
107
|
+
"""
|
108
|
+
return f"{self.case_name}.{self.attribute_name}"
|
109
|
+
|
110
|
+
@property
|
111
|
+
def case_name(self) -> str:
|
112
|
+
"""
|
113
|
+
:return: The name of the case.
|
114
|
+
"""
|
115
|
+
return self.case._name if isinstance(self.case, Case) else self.case.__class__.__name__
|
116
|
+
|
117
|
+
@property
|
118
|
+
def target(self) -> Optional[CallableExpression]:
|
119
|
+
"""
|
120
|
+
:return: The target expression of the attribute.
|
121
|
+
"""
|
122
|
+
if self._target is not None and not isinstance(self._target, CallableExpression):
|
123
|
+
self._target = CallableExpression(conclusion=self._target, conclusion_type=self.attribute_type,
|
124
|
+
scope=self.scope)
|
125
|
+
return self._target
|
126
|
+
|
127
|
+
@target.setter
|
128
|
+
def target(self, value: Optional[CallableExpression]):
|
129
|
+
"""
|
130
|
+
Set the target expression of the attribute.
|
131
|
+
"""
|
132
|
+
if value is not None and not isinstance(value, (CallableExpression, str)):
|
133
|
+
raise ValueError("The target must be a CallableExpression or a string.")
|
134
|
+
self._target = value
|
135
|
+
self._update_target_value()
|
136
|
+
|
137
|
+
@property
|
138
|
+
def target_value(self) -> Any:
|
139
|
+
"""
|
140
|
+
:return: The target value of the case query.
|
141
|
+
"""
|
142
|
+
if self._target_value is None:
|
143
|
+
self._update_target_value()
|
144
|
+
return self._target_value
|
145
|
+
|
146
|
+
def _update_target_value(self):
|
147
|
+
"""
|
148
|
+
Update the target value of the case query.
|
149
|
+
"""
|
150
|
+
if isinstance(self.target, CallableExpression):
|
151
|
+
self._target_value = self.target(self.case)
|
152
|
+
else:
|
153
|
+
self._target_value = self.target
|
154
|
+
|
155
|
+
def __str__(self):
|
156
|
+
header = f"CaseQuery: {self.name}"
|
157
|
+
target = f"Target: {self.name} |= {self.target if self.target is not None else '?'}"
|
158
|
+
conditions = f"Conditions: {self.conditions if self.conditions is not None else '?'}"
|
159
|
+
return "\n".join([header, target, conditions])
|
160
|
+
|
161
|
+
def __repr__(self):
|
162
|
+
return self.__str__()
|
163
|
+
|
164
|
+
def __copy__(self):
|
165
|
+
return CaseQuery(self.original_case, self.attribute_name, self.attribute_type,
|
166
|
+
self.mutually_exclusive, _target=self.target, default_value=self.default_value,
|
167
|
+
scope=self.scope, _case=copy_case(self.case), _target_value=self.target_value,
|
168
|
+
conditions=self.conditions)
|
@@ -3,13 +3,15 @@ from __future__ import annotations
|
|
3
3
|
import json
|
4
4
|
from abc import ABC, abstractmethod
|
5
5
|
|
6
|
-
from
|
7
|
-
from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type, Union, Any, get_type_hints
|
6
|
+
from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Type, Any
|
8
7
|
|
9
|
-
from .datastructures import
|
8
|
+
from .datastructures.case import Case, CaseAttribute
|
9
|
+
from .datastructures.callable_expression import CallableExpression
|
10
|
+
from .datastructures.enums import PromptFor
|
11
|
+
from .datastructures.dataclasses import CaseQuery
|
10
12
|
from .datastructures.case import show_current_and_corner_cases
|
11
|
-
from .prompt import prompt_user_for_expression
|
12
|
-
from .utils import get_all_subclasses,
|
13
|
+
from .prompt import prompt_user_for_expression
|
14
|
+
from .utils import get_all_subclasses, make_list
|
13
15
|
|
14
16
|
if TYPE_CHECKING:
|
15
17
|
from .rdr import Rule
|
@@ -89,10 +91,9 @@ class Human(Expert):
|
|
89
91
|
The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
|
90
92
|
"""
|
91
93
|
|
92
|
-
def __init__(self, use_loaded_answers: bool = False
|
94
|
+
def __init__(self, use_loaded_answers: bool = False):
|
93
95
|
self.all_expert_answers = []
|
94
96
|
self.use_loaded_answers = use_loaded_answers
|
95
|
-
self.session = session
|
96
97
|
|
97
98
|
def save_answers(self, path: str, append: bool = False):
|
98
99
|
"""
|
@@ -142,9 +143,9 @@ class Human(Expert):
|
|
142
143
|
if self.use_loaded_answers:
|
143
144
|
user_input = self.all_expert_answers.pop(0)
|
144
145
|
if user_input:
|
145
|
-
condition = CallableExpression(user_input, bool, scope=case_query.scope
|
146
|
+
condition = CallableExpression(user_input, bool, scope=case_query.scope)
|
146
147
|
else:
|
147
|
-
user_input, condition = prompt_user_for_expression(case_query, PromptFor.Conditions
|
148
|
+
user_input, condition = prompt_user_for_expression(case_query, PromptFor.Conditions)
|
148
149
|
if not self.use_loaded_answers:
|
149
150
|
self.all_expert_answers.append(user_input)
|
150
151
|
case_query.conditions = condition
|
@@ -168,21 +169,22 @@ class Human(Expert):
|
|
168
169
|
extra_conclusions[category] = self._get_conditions(case, {category.__class__.__name__: category})
|
169
170
|
return extra_conclusions
|
170
171
|
|
171
|
-
def ask_for_conclusion(self, case_query: CaseQuery) ->
|
172
|
+
def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
|
172
173
|
"""
|
173
174
|
Ask the expert to provide a conclusion for the case.
|
174
175
|
|
175
176
|
:param case_query: The case query containing the case to find a conclusion for.
|
176
|
-
:return: The
|
177
|
+
:return: The conclusion for the case as a callable expression.
|
177
178
|
"""
|
179
|
+
expression: Optional[CallableExpression] = None
|
178
180
|
if self.use_loaded_answers:
|
179
181
|
expert_input = self.all_expert_answers.pop(0)
|
180
|
-
|
181
|
-
|
182
|
+
if expert_input is not None:
|
183
|
+
expression = CallableExpression(expert_input, case_query.attribute_type,
|
184
|
+
scope=case_query.scope)
|
182
185
|
else:
|
183
186
|
show_current_and_corner_cases(case_query.case)
|
184
|
-
expert_input, expression = prompt_user_for_expression(case_query, PromptFor.Conclusion
|
185
|
-
session=self.session)
|
187
|
+
expert_input, expression = prompt_user_for_expression(case_query, PromptFor.Conclusion)
|
186
188
|
self.all_expert_answers.append(expert_input)
|
187
189
|
case_query.target = expression
|
188
190
|
return expression
|
@@ -195,7 +197,8 @@ class Human(Expert):
|
|
195
197
|
:return: The category type.
|
196
198
|
"""
|
197
199
|
cat_name = cat_name.lower()
|
198
|
-
self.known_categories = get_all_subclasses(
|
200
|
+
self.known_categories = get_all_subclasses(
|
201
|
+
CaseAttribute) if not self.known_categories else self.known_categories
|
199
202
|
self.known_categories.update(CaseAttribute.registry)
|
200
203
|
category_type = None
|
201
204
|
if cat_name in self.known_categories:
|
@@ -211,9 +214,9 @@ class Human(Expert):
|
|
211
214
|
question = f"Can a case have multiple values of the new category {category_name}? (y/n):"
|
212
215
|
return not self.ask_yes_no_question(question)
|
213
216
|
|
214
|
-
def ask_if_conclusion_is_correct(self, x: Case, conclusion:
|
215
|
-
targets: Optional[List[
|
216
|
-
current_conclusions: Optional[List[
|
217
|
+
def ask_if_conclusion_is_correct(self, x: Case, conclusion: Any,
|
218
|
+
targets: Optional[List[Any]] = None,
|
219
|
+
current_conclusions: Optional[List[Any]] = None) -> bool:
|
217
220
|
"""
|
218
221
|
Ask the expert if the conclusion is correct.
|
219
222
|
|
@@ -225,9 +228,8 @@ class Human(Expert):
|
|
225
228
|
question = ""
|
226
229
|
if not self.use_loaded_answers:
|
227
230
|
targets = targets or []
|
228
|
-
|
229
|
-
x.
|
230
|
-
x.targets = targets
|
231
|
+
x.conclusions = make_list(current_conclusions)
|
232
|
+
x.targets = make_list(targets)
|
231
233
|
question = f"Is the conclusion {conclusion} correct for the case (y/n):" \
|
232
234
|
f"\n{str(x)}"
|
233
235
|
return self.ask_yes_no_question(question)
|