ripple-down-rules 0.1.2__tar.gz → 0.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/PKG-INFO +1 -1
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/pyproject.toml +1 -1
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules/datasets.py +2 -1
- ripple_down_rules-0.1.5/src/ripple_down_rules/datastructures/__init__.py +4 -0
- ripple_down_rules-0.1.5/src/ripple_down_rules/datastructures/callable_expression.py +223 -0
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules/datastructures/case.py +1 -1
- ripple_down_rules-0.1.5/src/ripple_down_rules/datastructures/dataclasses.py +169 -0
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules/experts.py +24 -22
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules/prompt.py +68 -68
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules/rdr.py +290 -153
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules/rules.py +64 -32
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules/utils.py +166 -4
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules.egg-info/SOURCES.txt +1 -0
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/test/test_json_serialization.py +2 -2
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/test/test_on_mutagenic.py +5 -5
- ripple_down_rules-0.1.5/test/test_rdr.py +331 -0
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/test/test_rdr_alchemy.py +24 -30
- ripple_down_rules-0.1.5/test/test_rdr_world.py +109 -0
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/test/test_relational_rdr.py +7 -7
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/test/test_relational_rdr_alchemy.py +8 -7
- ripple_down_rules-0.1.2/src/ripple_down_rules/datastructures/__init__.py +0 -4
- ripple_down_rules-0.1.2/src/ripple_down_rules/datastructures/callable_expression.py +0 -278
- ripple_down_rules-0.1.2/src/ripple_down_rules/datastructures/dataclasses.py +0 -115
- ripple_down_rules-0.1.2/test/test_rdr.py +0 -243
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/LICENSE +0 -0
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/README.md +0 -0
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/setup.cfg +0 -0
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules/__init__.py +0 -0
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules/datastructures/enums.py +0 -0
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules/failures.py +0 -0
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules/helpers.py +0 -0
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules/rdr_decorators.py +0 -0
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
- {ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.1.2
|
3
|
+
Version: 0.1.5
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
|
|
6
6
|
|
7
7
|
[project]
|
8
8
|
name = "ripple_down_rules"
|
9
|
-
version = "0.1.2"
|
9
|
+
version = "0.1.5"
|
10
10
|
description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
|
11
11
|
readme = "README.md"
|
12
12
|
authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
|
@@ -9,7 +9,8 @@ from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationshi
|
|
9
9
|
from typing_extensions import Tuple, List, Set, Optional
|
10
10
|
from ucimlrepo import fetch_ucirepo
|
11
11
|
|
12
|
-
from .datastructures import Case, create_cases_from_dataframe
|
12
|
+
from .datastructures.case import Case, create_cases_from_dataframe
|
13
|
+
from .datastructures.enums import Category
|
13
14
|
|
14
15
|
|
15
16
|
def load_cached_dataset(cache_file):
|
@@ -0,0 +1,223 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import ast
|
4
|
+
import logging
|
5
|
+
from _ast import AST
|
6
|
+
|
7
|
+
from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set
|
8
|
+
|
9
|
+
from .case import create_case, Case
|
10
|
+
from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string, conclusion_to_json, is_iterable
|
11
|
+
|
12
|
+
|
13
|
+
class VariableVisitor(ast.NodeVisitor):
|
14
|
+
"""
|
15
|
+
A visitor to extract all variables and comparisons from a python expression represented as an AST tree.
|
16
|
+
"""
|
17
|
+
compares: List[Tuple[Union[ast.Name, ast.Call], ast.cmpop, Union[ast.Name, ast.Call]]]
|
18
|
+
variables: Set[str]
|
19
|
+
all: List[ast.BoolOp]
|
20
|
+
|
21
|
+
def __init__(self):
|
22
|
+
self.variables = set()
|
23
|
+
self.attributes: Dict[ast.Name, ast.Attribute] = {}
|
24
|
+
self.types = set()
|
25
|
+
self.callables = set()
|
26
|
+
self.compares = list()
|
27
|
+
self.binary_ops = list()
|
28
|
+
self.all = list()
|
29
|
+
|
30
|
+
def visit_Constant(self, node):
|
31
|
+
self.all.append(node)
|
32
|
+
self.types.add(node)
|
33
|
+
self.generic_visit(node)
|
34
|
+
|
35
|
+
def visit_Call(self, node):
|
36
|
+
self.all.append(node)
|
37
|
+
self.callables.add(node)
|
38
|
+
self.generic_visit(node)
|
39
|
+
|
40
|
+
def visit_Attribute(self, node):
|
41
|
+
self.all.append(node)
|
42
|
+
self.attributes[node.value] = node
|
43
|
+
self.generic_visit(node)
|
44
|
+
|
45
|
+
def visit_BinOp(self, node):
|
46
|
+
self.binary_ops.append(node)
|
47
|
+
self.all.append(node)
|
48
|
+
self.generic_visit(node)
|
49
|
+
|
50
|
+
def visit_BoolOp(self, node):
|
51
|
+
self.all.append(node)
|
52
|
+
self.generic_visit(node)
|
53
|
+
|
54
|
+
def visit_Compare(self, node):
|
55
|
+
self.all.append(node)
|
56
|
+
self.compares.append([node.left, node.ops[0], node.comparators[0]])
|
57
|
+
self.generic_visit(node)
|
58
|
+
|
59
|
+
def visit_Name(self, node):
|
60
|
+
if f"__{node.id}__" not in dir(__builtins__) and node not in self.attributes:
|
61
|
+
self.variables.add(node.id)
|
62
|
+
self.generic_visit(node)
|
63
|
+
|
64
|
+
|
65
|
+
def get_used_scope(code_str, scope):
|
66
|
+
# Parse the code into an AST
|
67
|
+
mode = 'exec' if code_str.startswith('def') else 'eval'
|
68
|
+
tree = ast.parse(code_str, mode=mode)
|
69
|
+
|
70
|
+
# Walk the AST to collect used variable names
|
71
|
+
class NameCollector(ast.NodeVisitor):
|
72
|
+
def __init__(self):
|
73
|
+
self.names = set()
|
74
|
+
|
75
|
+
def visit_Name(self, node):
|
76
|
+
if isinstance(node.ctx, ast.Load): # We care only about variables being read
|
77
|
+
self.names.add(node.id)
|
78
|
+
|
79
|
+
collector = NameCollector()
|
80
|
+
collector.visit(tree)
|
81
|
+
|
82
|
+
# Filter the scope to include only used names
|
83
|
+
used_scope = {k: scope[k] for k in collector.names if k in scope}
|
84
|
+
return used_scope
|
85
|
+
|
86
|
+
|
87
|
+
class CallableExpression(SubclassJSONSerializer):
|
88
|
+
"""
|
89
|
+
A callable that is constructed from a string statement written by an expert.
|
90
|
+
"""
|
91
|
+
|
92
|
+
def __init__(self, user_input: Optional[str] = None, conclusion_type: Optional[Tuple[Type]] = None,
|
93
|
+
expression_tree: Optional[AST] = None,
|
94
|
+
scope: Optional[Dict[str, Any]] = None, conclusion: Optional[Any] = None):
|
95
|
+
"""
|
96
|
+
Create a callable expression.
|
97
|
+
|
98
|
+
:param user_input: The input given by the expert.
|
99
|
+
:param conclusion_type: The type of the output of the callable.
|
100
|
+
:param expression_tree: The AST tree parsed from the user input.
|
101
|
+
:param scope: The scope to use for the callable expression.
|
102
|
+
:param conclusion: The conclusion to use for the callable expression.
|
103
|
+
"""
|
104
|
+
if user_input is None and conclusion is None:
|
105
|
+
raise ValueError("Either user_input or conclusion must be provided.")
|
106
|
+
self.conclusion: Optional[Any] = conclusion
|
107
|
+
self.user_input: str = user_input
|
108
|
+
if conclusion_type is not None:
|
109
|
+
if is_iterable(conclusion_type):
|
110
|
+
conclusion_type = tuple(conclusion_type)
|
111
|
+
else:
|
112
|
+
conclusion_type = (conclusion_type,)
|
113
|
+
self.conclusion_type = conclusion_type
|
114
|
+
self.scope: Optional[Dict[str, Any]] = scope if scope is not None else {}
|
115
|
+
if conclusion is None:
|
116
|
+
self.scope = get_used_scope(self.user_input, self.scope)
|
117
|
+
self.expression_tree: AST = expression_tree if expression_tree else parse_string_to_expression(self.user_input)
|
118
|
+
self.code = compile_expression_to_code(self.expression_tree)
|
119
|
+
self.visitor = VariableVisitor()
|
120
|
+
self.visitor.visit(self.expression_tree)
|
121
|
+
|
122
|
+
def __call__(self, case: Any, **kwargs) -> Any:
|
123
|
+
try:
|
124
|
+
if self.user_input is not None:
|
125
|
+
if not isinstance(case, Case):
|
126
|
+
case = create_case(case, max_recursion_idx=3)
|
127
|
+
scope = {'case': case, **self.scope}
|
128
|
+
output = eval(self.code, scope)
|
129
|
+
if output is None:
|
130
|
+
output = scope['_get_value'](case)
|
131
|
+
if self.conclusion_type is not None:
|
132
|
+
if is_iterable(output) and not isinstance(output, self.conclusion_type):
|
133
|
+
assert isinstance(list(output)[0], self.conclusion_type), (f"Expected output type {self.conclusion_type},"
|
134
|
+
f" got {type(output)}")
|
135
|
+
else:
|
136
|
+
assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
|
137
|
+
f" got {type(output)}")
|
138
|
+
return output
|
139
|
+
else:
|
140
|
+
return self.conclusion
|
141
|
+
except Exception as e:
|
142
|
+
raise ValueError(f"Error during evaluation: {e}")
|
143
|
+
|
144
|
+
def combine_with(self, other: 'CallableExpression') -> 'CallableExpression':
|
145
|
+
"""
|
146
|
+
Combine this callable expression with another callable expression using the 'and' operator.
|
147
|
+
"""
|
148
|
+
new_user_input = f"({self.user_input}) and ({other.user_input})"
|
149
|
+
return CallableExpression(new_user_input, conclusion_type=self.conclusion_type)
|
150
|
+
|
151
|
+
def __eq__(self, other):
|
152
|
+
"""
|
153
|
+
Check if two callable expressions are equal.
|
154
|
+
"""
|
155
|
+
if not isinstance(other, CallableExpression):
|
156
|
+
return False
|
157
|
+
return self.user_input == other.user_input and self.conclusion == other.conclusion
|
158
|
+
|
159
|
+
def __hash__(self):
|
160
|
+
"""
|
161
|
+
Hash the callable expression.
|
162
|
+
"""
|
163
|
+
conclusion_hash = self.conclusion if not isinstance(self.conclusion, set) else frozenset(self.conclusion)
|
164
|
+
return hash((self.user_input, conclusion_hash))
|
165
|
+
|
166
|
+
def __str__(self):
|
167
|
+
"""
|
168
|
+
Return the user string where each compare is written in a line using compare column offset start and end.
|
169
|
+
"""
|
170
|
+
if self.user_input is None:
|
171
|
+
return str(self.conclusion)
|
172
|
+
binary_ops = sorted(self.visitor.binary_ops, key=lambda x: x.end_col_offset)
|
173
|
+
binary_ops_indices = [b.end_col_offset for b in binary_ops]
|
174
|
+
all_binary_ops = []
|
175
|
+
prev_e = 0
|
176
|
+
for i, e in enumerate(binary_ops_indices):
|
177
|
+
if i == 0:
|
178
|
+
all_binary_ops.append(self.user_input[:e])
|
179
|
+
else:
|
180
|
+
all_binary_ops.append(self.user_input[prev_e:e])
|
181
|
+
prev_e = e
|
182
|
+
return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else self.user_input
|
183
|
+
|
184
|
+
def _to_json(self) -> Dict[str, Any]:
|
185
|
+
return {"user_input": self.user_input,
|
186
|
+
"conclusion_type": [get_full_class_name(t) for t in self.conclusion_type]
|
187
|
+
if self.conclusion_type is not None else None,
|
188
|
+
"scope": {k: get_full_class_name(v) for k, v in self.scope.items()
|
189
|
+
if hasattr(v, '__module__') and hasattr(v, '__name__')},
|
190
|
+
"conclusion": conclusion_to_json(self.conclusion),
|
191
|
+
}
|
192
|
+
|
193
|
+
@classmethod
|
194
|
+
def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
|
195
|
+
return cls(user_input=data["user_input"],
|
196
|
+
conclusion_type=tuple(get_type_from_string(t) for t in data["conclusion_type"])
|
197
|
+
if data["conclusion_type"] else None,
|
198
|
+
scope={k: get_type_from_string(v) for k, v in data["scope"].items()},
|
199
|
+
conclusion=SubclassJSONSerializer.from_json(data["conclusion"]))
|
200
|
+
|
201
|
+
|
202
|
+
def compile_expression_to_code(expression_tree: AST) -> Any:
|
203
|
+
"""
|
204
|
+
Compile an expression tree that was parsed from string into code that can be executed using 'eval(code)'
|
205
|
+
|
206
|
+
:param expression_tree: The parsed expression tree.
|
207
|
+
:return: The code that was compiled from the expression tree.
|
208
|
+
"""
|
209
|
+
mode = 'exec' if isinstance(expression_tree, ast.Module) else 'eval'
|
210
|
+
return compile(expression_tree, filename="<string>", mode=mode)
|
211
|
+
|
212
|
+
|
213
|
+
def parse_string_to_expression(expression_str: str) -> AST:
|
214
|
+
"""
|
215
|
+
Parse a string statement into an AST expression.
|
216
|
+
|
217
|
+
:param expression_str: The string which will be parsed.
|
218
|
+
:return: The parsed expression.
|
219
|
+
"""
|
220
|
+
mode = 'exec' if expression_str.startswith('def') else 'eval'
|
221
|
+
tree = ast.parse(expression_str, mode=mode)
|
222
|
+
logging.debug(f"AST parsed successfully: {ast.dump(tree)}")
|
223
|
+
return tree
|
{ripple_down_rules-0.1.2 → ripple_down_rules-0.1.5}/src/ripple_down_rules/datastructures/case.py
RENAMED
@@ -63,7 +63,7 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
63
63
|
new_list.extend(make_list(value))
|
64
64
|
super().__setitem__(name, new_list)
|
65
65
|
else:
|
66
|
-
super().__setitem__(name,
|
66
|
+
super().__setitem__(name, self[name])
|
67
67
|
else:
|
68
68
|
super().__setitem__(name, value)
|
69
69
|
setattr(self, name, self[name])
|
@@ -0,0 +1,169 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import inspect
|
4
|
+
from dataclasses import dataclass, field
|
5
|
+
|
6
|
+
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
7
|
+
from typing_extensions import Any, Optional, Dict, Type, Tuple, Union
|
8
|
+
|
9
|
+
from .callable_expression import CallableExpression
|
10
|
+
from .case import create_case, Case
|
11
|
+
from ..utils import copy_case
|
12
|
+
|
13
|
+
|
14
|
+
@dataclass
|
15
|
+
class CaseQuery:
|
16
|
+
"""
|
17
|
+
This is a dataclass that represents an attribute of an object and its target value. If attribute name is
|
18
|
+
not provided, it will be inferred from the attribute itself or from the attribute type or from the target value,
|
19
|
+
depending on what is provided.
|
20
|
+
"""
|
21
|
+
original_case: Any
|
22
|
+
"""
|
23
|
+
The case that the attribute belongs to.
|
24
|
+
"""
|
25
|
+
attribute_name: str
|
26
|
+
"""
|
27
|
+
The name of the attribute.
|
28
|
+
"""
|
29
|
+
_attribute_types: Tuple[Type]
|
30
|
+
"""
|
31
|
+
The type(s) of the attribute.
|
32
|
+
"""
|
33
|
+
mutually_exclusive: bool
|
34
|
+
"""
|
35
|
+
Whether the attribute can only take one value (i.e. True) or multiple values (i.e. False).
|
36
|
+
"""
|
37
|
+
_target: Optional[CallableExpression] = None
|
38
|
+
"""
|
39
|
+
The target expression of the attribute.
|
40
|
+
"""
|
41
|
+
default_value: Optional[Any] = None
|
42
|
+
"""
|
43
|
+
The default value of the attribute. This is used when the target value is not provided.
|
44
|
+
"""
|
45
|
+
scope: Optional[Dict[str, Any]] = field(default_factory=lambda: inspect.currentframe().f_back.f_back.f_globals)
|
46
|
+
"""
|
47
|
+
The global scope of the case query. This is used to evaluate the conditions and prediction, and is what is available
|
48
|
+
to the user when they are prompted for input. If it is not provided, it will be set to the global scope of the
|
49
|
+
caller.
|
50
|
+
"""
|
51
|
+
_case: Optional[Union[Case, SQLTable]] = None
|
52
|
+
"""
|
53
|
+
The created case from the original case that the attribute belongs to.
|
54
|
+
"""
|
55
|
+
_target_value: Optional[Any] = None
|
56
|
+
"""
|
57
|
+
The target value of the case query. (This is the result of the target expression evaluation on the case.)
|
58
|
+
"""
|
59
|
+
conditions: Optional[CallableExpression] = None
|
60
|
+
"""
|
61
|
+
The conditions that must be satisfied for the target value to be valid.
|
62
|
+
"""
|
63
|
+
|
64
|
+
@property
|
65
|
+
def case(self) -> Any:
|
66
|
+
"""
|
67
|
+
:return: The case that the attribute belongs to.
|
68
|
+
"""
|
69
|
+
if self._case is not None:
|
70
|
+
return self._case
|
71
|
+
elif not isinstance(self.original_case, (Case, SQLTable)):
|
72
|
+
self._case = create_case(self.original_case, max_recursion_idx=3)
|
73
|
+
else:
|
74
|
+
self._case = self.original_case
|
75
|
+
return self._case
|
76
|
+
|
77
|
+
@case.setter
|
78
|
+
def case(self, value: Any):
|
79
|
+
"""
|
80
|
+
Set the case that the attribute belongs to.
|
81
|
+
"""
|
82
|
+
if not isinstance(value, (Case, SQLTable)):
|
83
|
+
raise ValueError("The case must be a Case or SQLTable object.")
|
84
|
+
self._case = value
|
85
|
+
|
86
|
+
@property
|
87
|
+
def attribute_type(self) -> Tuple[Type]:
|
88
|
+
"""
|
89
|
+
:return: The type of the attribute.
|
90
|
+
"""
|
91
|
+
if not self.mutually_exclusive and (set not in self._attribute_types):
|
92
|
+
self._attribute_types = tuple(list(self._attribute_types) + [set])
|
93
|
+
return self._attribute_types
|
94
|
+
|
95
|
+
@attribute_type.setter
|
96
|
+
def attribute_type(self, value: Type):
|
97
|
+
"""
|
98
|
+
Set the type of the attribute.
|
99
|
+
"""
|
100
|
+
if not isinstance(value, tuple):
|
101
|
+
value = (value,)
|
102
|
+
self._attribute_types = value
|
103
|
+
|
104
|
+
@property
|
105
|
+
def name(self):
|
106
|
+
"""
|
107
|
+
:return: The name of the case query.
|
108
|
+
"""
|
109
|
+
return f"{self.case_name}.{self.attribute_name}"
|
110
|
+
|
111
|
+
@property
|
112
|
+
def case_name(self) -> str:
|
113
|
+
"""
|
114
|
+
:return: The name of the case.
|
115
|
+
"""
|
116
|
+
return self.case._name if isinstance(self.case, Case) else self.case.__class__.__name__
|
117
|
+
|
118
|
+
@property
|
119
|
+
def target(self) -> Optional[CallableExpression]:
|
120
|
+
"""
|
121
|
+
:return: The target expression of the attribute.
|
122
|
+
"""
|
123
|
+
if self._target is not None and not isinstance(self._target, CallableExpression):
|
124
|
+
self._target = CallableExpression(conclusion=self._target, conclusion_type=self.attribute_type,
|
125
|
+
scope=self.scope)
|
126
|
+
return self._target
|
127
|
+
|
128
|
+
@target.setter
|
129
|
+
def target(self, value: Optional[CallableExpression]):
|
130
|
+
"""
|
131
|
+
Set the target expression of the attribute.
|
132
|
+
"""
|
133
|
+
if value is not None and not isinstance(value, (CallableExpression, str)):
|
134
|
+
raise ValueError("The target must be a CallableExpression or a string.")
|
135
|
+
self._target = value
|
136
|
+
self._update_target_value()
|
137
|
+
|
138
|
+
@property
|
139
|
+
def target_value(self) -> Any:
|
140
|
+
"""
|
141
|
+
:return: The target value of the case query.
|
142
|
+
"""
|
143
|
+
if self._target_value is None:
|
144
|
+
self._update_target_value()
|
145
|
+
return self._target_value
|
146
|
+
|
147
|
+
def _update_target_value(self):
|
148
|
+
"""
|
149
|
+
Update the target value of the case query.
|
150
|
+
"""
|
151
|
+
if isinstance(self.target, CallableExpression):
|
152
|
+
self._target_value = self.target(self.case)
|
153
|
+
else:
|
154
|
+
self._target_value = self.target
|
155
|
+
|
156
|
+
def __str__(self):
|
157
|
+
header = f"CaseQuery: {self.name}"
|
158
|
+
target = f"Target: {self.name} |= {self.target if self.target is not None else '?'}"
|
159
|
+
conditions = f"Conditions: {self.conditions if self.conditions is not None else '?'}"
|
160
|
+
return "\n".join([header, target, conditions])
|
161
|
+
|
162
|
+
def __repr__(self):
|
163
|
+
return self.__str__()
|
164
|
+
|
165
|
+
def __copy__(self):
|
166
|
+
return CaseQuery(self.original_case, self.attribute_name, self.attribute_type,
|
167
|
+
self.mutually_exclusive, _target=self.target, default_value=self.default_value,
|
168
|
+
scope=self.scope, _case=copy_case(self.case), _target_value=self.target_value,
|
169
|
+
conditions=self.conditions)
|
@@ -3,13 +3,15 @@ from __future__ import annotations
|
|
3
3
|
import json
|
4
4
|
from abc import ABC, abstractmethod
|
5
5
|
|
6
|
-
from
|
7
|
-
from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type, Union, Any, get_type_hints
|
6
|
+
from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Type, Any
|
8
7
|
|
9
|
-
from .datastructures import
|
8
|
+
from .datastructures.case import Case, CaseAttribute
|
9
|
+
from .datastructures.callable_expression import CallableExpression
|
10
|
+
from .datastructures.enums import PromptFor
|
11
|
+
from .datastructures.dataclasses import CaseQuery
|
10
12
|
from .datastructures.case import show_current_and_corner_cases
|
11
|
-
from .prompt import prompt_user_for_expression
|
12
|
-
from .utils import get_all_subclasses,
|
13
|
+
from .prompt import prompt_user_for_expression
|
14
|
+
from .utils import get_all_subclasses, make_list
|
13
15
|
|
14
16
|
if TYPE_CHECKING:
|
15
17
|
from .rdr import Rule
|
@@ -89,10 +91,9 @@ class Human(Expert):
|
|
89
91
|
The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
|
90
92
|
"""
|
91
93
|
|
92
|
-
def __init__(self, use_loaded_answers: bool = False
|
94
|
+
def __init__(self, use_loaded_answers: bool = False):
|
93
95
|
self.all_expert_answers = []
|
94
96
|
self.use_loaded_answers = use_loaded_answers
|
95
|
-
self.session = session
|
96
97
|
|
97
98
|
def save_answers(self, path: str, append: bool = False):
|
98
99
|
"""
|
@@ -142,9 +143,9 @@ class Human(Expert):
|
|
142
143
|
if self.use_loaded_answers:
|
143
144
|
user_input = self.all_expert_answers.pop(0)
|
144
145
|
if user_input:
|
145
|
-
condition = CallableExpression(user_input, bool, scope=case_query.scope
|
146
|
+
condition = CallableExpression(user_input, bool, scope=case_query.scope)
|
146
147
|
else:
|
147
|
-
user_input, condition = prompt_user_for_expression(case_query, PromptFor.Conditions
|
148
|
+
user_input, condition = prompt_user_for_expression(case_query, PromptFor.Conditions)
|
148
149
|
if not self.use_loaded_answers:
|
149
150
|
self.all_expert_answers.append(user_input)
|
150
151
|
case_query.conditions = condition
|
@@ -168,21 +169,22 @@ class Human(Expert):
|
|
168
169
|
extra_conclusions[category] = self._get_conditions(case, {category.__class__.__name__: category})
|
169
170
|
return extra_conclusions
|
170
171
|
|
171
|
-
def ask_for_conclusion(self, case_query: CaseQuery) ->
|
172
|
+
def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
|
172
173
|
"""
|
173
174
|
Ask the expert to provide a conclusion for the case.
|
174
175
|
|
175
176
|
:param case_query: The case query containing the case to find a conclusion for.
|
176
|
-
:return: The
|
177
|
+
:return: The conclusion for the case as a callable expression.
|
177
178
|
"""
|
179
|
+
expression: Optional[CallableExpression] = None
|
178
180
|
if self.use_loaded_answers:
|
179
181
|
expert_input = self.all_expert_answers.pop(0)
|
180
|
-
|
181
|
-
|
182
|
+
if expert_input is not None:
|
183
|
+
expression = CallableExpression(expert_input, case_query.attribute_type,
|
184
|
+
scope=case_query.scope)
|
182
185
|
else:
|
183
186
|
show_current_and_corner_cases(case_query.case)
|
184
|
-
expert_input, expression = prompt_user_for_expression(case_query, PromptFor.Conclusion,
|
185
|
-
session=self.session)
|
187
|
+
expert_input, expression = prompt_user_for_expression(case_query, PromptFor.Conclusion)
|
186
188
|
self.all_expert_answers.append(expert_input)
|
187
189
|
case_query.target = expression
|
188
190
|
return expression
|
@@ -195,7 +197,8 @@ class Human(Expert):
|
|
195
197
|
:return: The category type.
|
196
198
|
"""
|
197
199
|
cat_name = cat_name.lower()
|
198
|
-
self.known_categories = get_all_subclasses(
|
200
|
+
self.known_categories = get_all_subclasses(
|
201
|
+
CaseAttribute) if not self.known_categories else self.known_categories
|
199
202
|
self.known_categories.update(CaseAttribute.registry)
|
200
203
|
category_type = None
|
201
204
|
if cat_name in self.known_categories:
|
@@ -211,9 +214,9 @@ class Human(Expert):
|
|
211
214
|
question = f"Can a case have multiple values of the new category {category_name}? (y/n):"
|
212
215
|
return not self.ask_yes_no_question(question)
|
213
216
|
|
214
|
-
def ask_if_conclusion_is_correct(self, x: Case, conclusion:
|
215
|
-
targets: Optional[List[
|
216
|
-
current_conclusions: Optional[List[
|
217
|
+
def ask_if_conclusion_is_correct(self, x: Case, conclusion: Any,
|
218
|
+
targets: Optional[List[Any]] = None,
|
219
|
+
current_conclusions: Optional[List[Any]] = None) -> bool:
|
217
220
|
"""
|
218
221
|
Ask the expert if the conclusion is correct.
|
219
222
|
|
@@ -225,9 +228,8 @@ class Human(Expert):
|
|
225
228
|
question = ""
|
226
229
|
if not self.use_loaded_answers:
|
227
230
|
targets = targets or []
|
228
|
-
|
229
|
-
x.
|
230
|
-
x.targets = targets
|
231
|
+
x.conclusions = make_list(current_conclusions)
|
232
|
+
x.targets = make_list(targets)
|
231
233
|
question = f"Is the conclusion {conclusion} correct for the case (y/n):" \
|
232
234
|
f"\n{str(x)}"
|
233
235
|
return self.ask_yes_no_question(question)
|