ripple-down-rules 0.1.3__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/PKG-INFO +1 -1
  2. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/pyproject.toml +1 -1
  3. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/datasets.py +2 -1
  4. ripple_down_rules-0.1.6/src/ripple_down_rules/datastructures/__init__.py +4 -0
  5. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/datastructures/callable_expression.py +68 -128
  6. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/datastructures/case.py +1 -1
  7. ripple_down_rules-0.1.6/src/ripple_down_rules/datastructures/dataclasses.py +168 -0
  8. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/experts.py +24 -22
  9. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/prompt.py +44 -50
  10. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/rdr.py +270 -164
  11. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/rules.py +64 -32
  12. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/utils.py +130 -2
  13. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
  14. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules.egg-info/SOURCES.txt +1 -0
  15. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/test/test_json_serialization.py +2 -2
  16. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/test/test_on_mutagenic.py +5 -5
  17. ripple_down_rules-0.1.6/test/test_rdr.py +331 -0
  18. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/test/test_rdr_alchemy.py +24 -30
  19. ripple_down_rules-0.1.6/test/test_rdr_world.py +119 -0
  20. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/test/test_relational_rdr.py +7 -7
  21. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/test/test_relational_rdr_alchemy.py +8 -7
  22. ripple_down_rules-0.1.3/src/ripple_down_rules/datastructures/__init__.py +0 -4
  23. ripple_down_rules-0.1.3/src/ripple_down_rules/datastructures/dataclasses.py +0 -115
  24. ripple_down_rules-0.1.3/test/test_rdr.py +0 -243
  25. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/LICENSE +0 -0
  26. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/README.md +0 -0
  27. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/setup.cfg +0 -0
  28. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/__init__.py +0 -0
  29. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/datastructures/enums.py +0 -0
  30. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/failures.py +0 -0
  31. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/helpers.py +0 -0
  32. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules/rdr_decorators.py +0 -0
  33. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
  34. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
  35. {ripple_down_rules-0.1.3 → ripple_down_rules-0.1.6}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.1.3
3
+ Version: 0.1.6
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
6
6
 
7
7
  [project]
8
8
  name = "ripple_down_rules"
9
- version = "0.1.3"
9
+ version = "0.1.6"
10
10
  description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
11
11
  readme = "README.md"
12
12
  authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
@@ -9,7 +9,8 @@ from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationshi
9
9
  from typing_extensions import Tuple, List, Set, Optional
10
10
  from ucimlrepo import fetch_ucirepo
11
11
 
12
- from .datastructures import Case, create_cases_from_dataframe, Category
12
+ from .datastructures.case import Case, create_cases_from_dataframe
13
+ from .datastructures.enums import Category
13
14
 
14
15
 
15
16
  def load_cached_dataset(cache_file):
@@ -0,0 +1,4 @@
1
+ # from .enums import *
2
+ # from .dataclasses import *
3
+ # from .callable_expression import *
4
+ # from .case import *
@@ -4,11 +4,10 @@ import ast
4
4
  import logging
5
5
  from _ast import AST
6
6
 
7
- from sqlalchemy.orm import Session
8
7
  from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set
9
8
 
10
9
  from .case import create_case, Case
11
- from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string
10
+ from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string, conclusion_to_json, is_iterable
12
11
 
13
12
 
14
13
  class VariableVisitor(ast.NodeVisitor):
@@ -89,93 +88,56 @@ class CallableExpression(SubclassJSONSerializer):
89
88
  """
90
89
  A callable that is constructed from a string statement written by an expert.
91
90
  """
92
- conclusion_type: Type
93
- """
94
- The type of the output of the callable, used for assertion.
95
- """
96
- expression_tree: AST
97
- """
98
- The AST tree parsed from the user input.
99
- """
100
- user_input: str
101
- """
102
- The input given by the expert.
103
- """
104
- session: Optional[Session]
105
- """
106
- The sqlalchemy orm session.
107
- """
108
- visitor: VariableVisitor
109
- """
110
- A visitor to extract all variables and comparisons from a python expression represented as an AST tree.
111
- """
112
- code: Any
113
- """
114
- The code that was compiled from the expression tree
115
- """
116
- compares_column_offset: List[int]
117
- """
118
- The start and end indices of each comparison in the string of user input.
119
- """
120
91
 
121
- def __init__(self, user_input: str, conclusion_type: Optional[Type] = None, expression_tree: Optional[AST] = None,
122
- session: Optional[Session] = None, scope: Optional[Dict[str, Any]] = None):
92
+ def __init__(self, user_input: Optional[str] = None, conclusion_type: Optional[Tuple[Type]] = None,
93
+ expression_tree: Optional[AST] = None,
94
+ scope: Optional[Dict[str, Any]] = None, conclusion: Optional[Any] = None):
123
95
  """
124
96
  Create a callable expression.
125
97
 
126
98
  :param user_input: The input given by the expert.
127
99
  :param conclusion_type: The type of the output of the callable.
128
100
  :param expression_tree: The AST tree parsed from the user input.
129
- :param session: The sqlalchemy orm session.
101
+ :param scope: The scope to use for the callable expression.
102
+ :param conclusion: The conclusion to use for the callable expression.
130
103
  """
131
- self.session = session
104
+ if user_input is None and conclusion is None:
105
+ raise ValueError("Either user_input or conclusion must be provided.")
106
+ self.conclusion: Optional[Any] = conclusion
132
107
  self.user_input: str = user_input
133
- self.parsed_user_input = self.parse_user_input(user_input, session)
108
+ if conclusion_type is not None:
109
+ if is_iterable(conclusion_type):
110
+ conclusion_type = tuple(conclusion_type)
111
+ else:
112
+ conclusion_type = (conclusion_type,)
134
113
  self.conclusion_type = conclusion_type
135
114
  self.scope: Optional[Dict[str, Any]] = scope if scope is not None else {}
136
- self.scope = get_used_scope(self.parsed_user_input, self.scope)
137
- self.update_expression(self.parsed_user_input, expression_tree)
138
-
139
- def get_used_scope_in_user_input(self) -> Set[str]:
140
- """
141
- Get the used scope in the user input.
142
- :return: The used scope in the user input.
143
- """
144
- return self.visitor.variables.union(self.visitor.attributes.keys())
145
-
146
- @staticmethod
147
- def parse_user_input(user_input: str, session: Optional[Session] = None) -> str:
148
- if ',' in user_input:
149
- user_input = user_input.split(',')
150
- user_input = [f"({u.strip()})" for u in user_input]
151
- user_input = ' & '.join(user_input) if session else ' and '.join(user_input)
152
- elif session:
153
- user_input = user_input.replace(" and ", " & ")
154
- user_input = user_input.replace(" or ", " | ")
155
- return user_input
156
-
157
- def update_expression(self, user_input: str, expression_tree: Optional[AST] = None):
158
- if not expression_tree:
159
- expression_tree = parse_string_to_expression(user_input)
160
- self.expression_tree: AST = expression_tree
161
- self.visitor = VariableVisitor()
162
- self.visitor.visit(expression_tree)
163
- self.expression_tree = parse_string_to_expression(self.parsed_user_input)
164
- self.compares_column_offset = [(c[0].col_offset, c[2].end_col_offset) for c in self.visitor.compares]
165
- self.code = compile_expression_to_code(self.expression_tree)
115
+ if conclusion is None:
116
+ self.scope = get_used_scope(self.user_input, self.scope)
117
+ self.expression_tree: AST = expression_tree if expression_tree else parse_string_to_expression(self.user_input)
118
+ self.code = compile_expression_to_code(self.expression_tree)
119
+ self.visitor = VariableVisitor()
120
+ self.visitor.visit(self.expression_tree)
166
121
 
167
122
  def __call__(self, case: Any, **kwargs) -> Any:
168
123
  try:
169
- if not isinstance(case, Case):
170
- case = create_case(case, max_recursion_idx=3)
171
- scope = {'case': case, **self.scope}
172
- output = eval(self.code, scope)
173
- if output is None:
174
- output = scope['_get_value'](case)
175
- if self.conclusion_type is not None:
176
- assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
177
- f" got {type(output)}")
178
- return output
124
+ if self.user_input is not None:
125
+ if not isinstance(case, Case):
126
+ case = create_case(case, max_recursion_idx=3)
127
+ scope = {'case': case, **self.scope}
128
+ output = eval(self.code, scope)
129
+ if output is None:
130
+ output = scope['_get_value'](case)
131
+ if self.conclusion_type is not None:
132
+ if is_iterable(output) and not isinstance(output, self.conclusion_type):
133
+ assert isinstance(list(output)[0], self.conclusion_type), (f"Expected output type {self.conclusion_type},"
134
+ f" got {type(output)}")
135
+ else:
136
+ assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
137
+ f" got {type(output)}")
138
+ return output
139
+ else:
140
+ return self.conclusion
179
141
  except Exception as e:
180
142
  raise ValueError(f"Error during evaluation: {e}")
181
143
 
@@ -184,35 +146,57 @@ class CallableExpression(SubclassJSONSerializer):
184
146
  Combine this callable expression with another callable expression using the 'and' operator.
185
147
  """
186
148
  new_user_input = f"({self.user_input}) and ({other.user_input})"
187
- return CallableExpression(new_user_input, conclusion_type=self.conclusion_type, session=self.session)
149
+ return CallableExpression(new_user_input, conclusion_type=self.conclusion_type)
150
+
151
+ def __eq__(self, other):
152
+ """
153
+ Check if two callable expressions are equal.
154
+ """
155
+ if not isinstance(other, CallableExpression):
156
+ return False
157
+ return self.user_input == other.user_input and self.conclusion == other.conclusion
158
+
159
+ def __hash__(self):
160
+ """
161
+ Hash the callable expression.
162
+ """
163
+ conclusion_hash = self.conclusion if not isinstance(self.conclusion, set) else frozenset(self.conclusion)
164
+ return hash((self.user_input, conclusion_hash))
188
165
 
189
166
  def __str__(self):
190
167
  """
191
168
  Return the user string where each compare is written in a line using compare column offset start and end.
192
169
  """
193
- user_input = self.parsed_user_input
170
+ if self.user_input is None:
171
+ return str(self.conclusion)
194
172
  binary_ops = sorted(self.visitor.binary_ops, key=lambda x: x.end_col_offset)
195
173
  binary_ops_indices = [b.end_col_offset for b in binary_ops]
196
174
  all_binary_ops = []
197
175
  prev_e = 0
198
176
  for i, e in enumerate(binary_ops_indices):
199
177
  if i == 0:
200
- all_binary_ops.append(user_input[:e])
178
+ all_binary_ops.append(self.user_input[:e])
201
179
  else:
202
- all_binary_ops.append(user_input[prev_e:e])
180
+ all_binary_ops.append(self.user_input[prev_e:e])
203
181
  prev_e = e
204
- return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else user_input
182
+ return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else self.user_input
205
183
 
206
184
  def _to_json(self) -> Dict[str, Any]:
207
- return {"user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type),
185
+ return {"user_input": self.user_input,
186
+ "conclusion_type": [get_full_class_name(t) for t in self.conclusion_type]
187
+ if self.conclusion_type is not None else None,
208
188
  "scope": {k: get_full_class_name(v) for k, v in self.scope.items()
209
- if hasattr(v, '__module__') and hasattr(v, '__name__')}
189
+ if hasattr(v, '__module__') and hasattr(v, '__name__')},
190
+ "conclusion": conclusion_to_json(self.conclusion),
210
191
  }
211
192
 
212
193
  @classmethod
213
194
  def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
214
- return cls(user_input=data["user_input"], conclusion_type=get_type_from_string(data["conclusion_type"]),
215
- scope={k: get_type_from_string(v) for k, v in data["scope"].items()})
195
+ return cls(user_input=data["user_input"],
196
+ conclusion_type=tuple(get_type_from_string(t) for t in data["conclusion_type"])
197
+ if data["conclusion_type"] else None,
198
+ scope={k: get_type_from_string(v) for k, v in data["scope"].items()},
199
+ conclusion=SubclassJSONSerializer.from_json(data["conclusion"]))
216
200
 
217
201
 
218
202
  def compile_expression_to_code(expression_tree: AST) -> Any:
@@ -226,50 +210,6 @@ def compile_expression_to_code(expression_tree: AST) -> Any:
226
210
  return compile(expression_tree, filename="<string>", mode=mode)
227
211
 
228
212
 
229
- def assert_context_contains_needed_information(case: Any, context: Dict[str, Any],
230
- visitor: VariableVisitor) -> Tuple[Set[str], Set[str]]:
231
- """
232
- Asserts that the variables mentioned in the expression visited by visitor are all in the given context.
233
-
234
- :param case: The case to check the context for.
235
- :param context: The context to check.
236
- :param visitor: The visitor that visited the expression.
237
- :return: The found variables and attributes.
238
- """
239
- found_variables = set()
240
- for key in visitor.variables:
241
- if key not in context:
242
- raise ValueError(f"Variable {key} not found in the case {case}")
243
- found_variables.add(key)
244
-
245
- found_attributes = get_attributes_str(visitor)
246
- for attr in found_attributes:
247
- if attr not in context:
248
- raise ValueError(f"Attribute {attr} not found in the case {case}")
249
- return found_variables, found_attributes
250
-
251
-
252
- def get_attributes_str(visitor: VariableVisitor) -> Set[str]:
253
- """
254
- Get the string representation of the attributes in the given visitor.
255
-
256
- :param visitor: The visitor that visited the expression.
257
- :return: The string representation of the attributes.
258
- """
259
- found_attributes = set()
260
- for key, ast_attr in visitor.attributes.items():
261
- str_attr = ""
262
- while isinstance(key, ast.Attribute):
263
- if len(str_attr) > 0:
264
- str_attr = f"{key.attr}.{str_attr}"
265
- else:
266
- str_attr = key.attr
267
- key = key.value
268
- str_attr = f"{key.id}.{str_attr}" if len(str_attr) > 0 else f"{key.id}.{ast_attr.attr}"
269
- found_attributes.add(str_attr)
270
- return found_attributes
271
-
272
-
273
213
  def parse_string_to_expression(expression_str: str) -> AST:
274
214
  """
275
215
  Parse a string statement into an AST expression.
@@ -63,7 +63,7 @@ class Case(UserDict, SubclassJSONSerializer):
63
63
  new_list.extend(make_list(value))
64
64
  super().__setitem__(name, new_list)
65
65
  else:
66
- super().__setitem__(name, [self[name], value])
66
+ super().__setitem__(name, self[name])
67
67
  else:
68
68
  super().__setitem__(name, value)
69
69
  setattr(self, name, self[name])
@@ -0,0 +1,168 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from dataclasses import dataclass, field
5
+
6
+ from sqlalchemy.orm import DeclarativeBase as SQLTable
7
+ from typing_extensions import Any, Optional, Dict, Type, Tuple, Union
8
+
9
+ from .callable_expression import CallableExpression
10
+ from .case import create_case, Case
11
+ from ..utils import copy_case, make_list
12
+
13
+
14
+ @dataclass
15
+ class CaseQuery:
16
+ """
17
+ This is a dataclass that represents an attribute of an object and its target value. If attribute name is
18
+ not provided, it will be inferred from the attribute itself or from the attribute type or from the target value,
19
+ depending on what is provided.
20
+ """
21
+ original_case: Any
22
+ """
23
+ The case that the attribute belongs to.
24
+ """
25
+ attribute_name: str
26
+ """
27
+ The name of the attribute.
28
+ """
29
+ _attribute_types: Tuple[Type]
30
+ """
31
+ The type(s) of the attribute.
32
+ """
33
+ mutually_exclusive: bool
34
+ """
35
+ Whether the attribute can only take one value (i.e. True) or multiple values (i.e. False).
36
+ """
37
+ _target: Optional[CallableExpression] = None
38
+ """
39
+ The target expression of the attribute.
40
+ """
41
+ default_value: Optional[Any] = None
42
+ """
43
+ The default value of the attribute. This is used when the target value is not provided.
44
+ """
45
+ scope: Optional[Dict[str, Any]] = field(default_factory=lambda: inspect.currentframe().f_back.f_back.f_globals)
46
+ """
47
+ The global scope of the case query. This is used to evaluate the conditions and prediction, and is what is available
48
+ to the user when they are prompted for input. If it is not provided, it will be set to the global scope of the
49
+ caller.
50
+ """
51
+ _case: Optional[Union[Case, SQLTable]] = None
52
+ """
53
+ The created case from the original case that the attribute belongs to.
54
+ """
55
+ _target_value: Optional[Any] = None
56
+ """
57
+ The target value of the case query. (This is the result of the target expression evaluation on the case.)
58
+ """
59
+ conditions: Optional[CallableExpression] = None
60
+ """
61
+ The conditions that must be satisfied for the target value to be valid.
62
+ """
63
+
64
+ @property
65
+ def case(self) -> Any:
66
+ """
67
+ :return: The case that the attribute belongs to.
68
+ """
69
+ if self._case is not None:
70
+ return self._case
71
+ elif not isinstance(self.original_case, (Case, SQLTable)):
72
+ self._case = create_case(self.original_case, max_recursion_idx=3)
73
+ else:
74
+ self._case = self.original_case
75
+ return self._case
76
+
77
+ @case.setter
78
+ def case(self, value: Any):
79
+ """
80
+ Set the case that the attribute belongs to.
81
+ """
82
+ if not isinstance(value, (Case, SQLTable)):
83
+ raise ValueError("The case must be a Case or SQLTable object.")
84
+ self._case = value
85
+
86
+ @property
87
+ def attribute_type(self) -> Tuple[Type]:
88
+ """
89
+ :return: The type of the attribute.
90
+ """
91
+ self._attribute_types = tuple(make_list(self._attribute_types))
92
+ if not self.mutually_exclusive and (set not in self._attribute_types):
93
+ self._attribute_types = tuple(list(self._attribute_types) + [set])
94
+ return self._attribute_types
95
+
96
+ @attribute_type.setter
97
+ def attribute_type(self, value: Type):
98
+ """
99
+ Set the type of the attribute.
100
+ """
101
+ self._attribute_types = tuple(make_list(value))
102
+
103
+ @property
104
+ def name(self):
105
+ """
106
+ :return: The name of the case query.
107
+ """
108
+ return f"{self.case_name}.{self.attribute_name}"
109
+
110
+ @property
111
+ def case_name(self) -> str:
112
+ """
113
+ :return: The name of the case.
114
+ """
115
+ return self.case._name if isinstance(self.case, Case) else self.case.__class__.__name__
116
+
117
+ @property
118
+ def target(self) -> Optional[CallableExpression]:
119
+ """
120
+ :return: The target expression of the attribute.
121
+ """
122
+ if self._target is not None and not isinstance(self._target, CallableExpression):
123
+ self._target = CallableExpression(conclusion=self._target, conclusion_type=self.attribute_type,
124
+ scope=self.scope)
125
+ return self._target
126
+
127
+ @target.setter
128
+ def target(self, value: Optional[CallableExpression]):
129
+ """
130
+ Set the target expression of the attribute.
131
+ """
132
+ if value is not None and not isinstance(value, (CallableExpression, str)):
133
+ raise ValueError("The target must be a CallableExpression or a string.")
134
+ self._target = value
135
+ self._update_target_value()
136
+
137
+ @property
138
+ def target_value(self) -> Any:
139
+ """
140
+ :return: The target value of the case query.
141
+ """
142
+ if self._target_value is None:
143
+ self._update_target_value()
144
+ return self._target_value
145
+
146
+ def _update_target_value(self):
147
+ """
148
+ Update the target value of the case query.
149
+ """
150
+ if isinstance(self.target, CallableExpression):
151
+ self._target_value = self.target(self.case)
152
+ else:
153
+ self._target_value = self.target
154
+
155
+ def __str__(self):
156
+ header = f"CaseQuery: {self.name}"
157
+ target = f"Target: {self.name} |= {self.target if self.target is not None else '?'}"
158
+ conditions = f"Conditions: {self.conditions if self.conditions is not None else '?'}"
159
+ return "\n".join([header, target, conditions])
160
+
161
+ def __repr__(self):
162
+ return self.__str__()
163
+
164
+ def __copy__(self):
165
+ return CaseQuery(self.original_case, self.attribute_name, self.attribute_type,
166
+ self.mutually_exclusive, _target=self.target, default_value=self.default_value,
167
+ scope=self.scope, _case=copy_case(self.case), _target_value=self.target_value,
168
+ conditions=self.conditions)
@@ -3,13 +3,15 @@ from __future__ import annotations
3
3
  import json
4
4
  from abc import ABC, abstractmethod
5
5
 
6
- from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColumn, Session
7
- from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type, Union, Any, get_type_hints
6
+ from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Type, Any
8
7
 
9
- from .datastructures import (Case, PromptFor, CallableExpression, CaseAttribute, CaseQuery)
8
+ from .datastructures.case import Case, CaseAttribute
9
+ from .datastructures.callable_expression import CallableExpression
10
+ from .datastructures.enums import PromptFor
11
+ from .datastructures.dataclasses import CaseQuery
10
12
  from .datastructures.case import show_current_and_corner_cases
11
- from .prompt import prompt_user_for_expression, prompt_user_about_case
12
- from .utils import get_all_subclasses, is_iterable
13
+ from .prompt import prompt_user_for_expression
14
+ from .utils import get_all_subclasses, make_list
13
15
 
14
16
  if TYPE_CHECKING:
15
17
  from .rdr import Rule
@@ -89,10 +91,9 @@ class Human(Expert):
89
91
  The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
90
92
  """
91
93
 
92
- def __init__(self, use_loaded_answers: bool = False, session: Optional[Session] = None):
94
+ def __init__(self, use_loaded_answers: bool = False):
93
95
  self.all_expert_answers = []
94
96
  self.use_loaded_answers = use_loaded_answers
95
- self.session = session
96
97
 
97
98
  def save_answers(self, path: str, append: bool = False):
98
99
  """
@@ -142,9 +143,9 @@ class Human(Expert):
142
143
  if self.use_loaded_answers:
143
144
  user_input = self.all_expert_answers.pop(0)
144
145
  if user_input:
145
- condition = CallableExpression(user_input, bool, scope=case_query.scope, session=self.session)
146
+ condition = CallableExpression(user_input, bool, scope=case_query.scope)
146
147
  else:
147
- user_input, condition = prompt_user_for_expression(case_query, PromptFor.Conditions, session=self.session)
148
+ user_input, condition = prompt_user_for_expression(case_query, PromptFor.Conditions)
148
149
  if not self.use_loaded_answers:
149
150
  self.all_expert_answers.append(user_input)
150
151
  case_query.conditions = condition
@@ -168,21 +169,22 @@ class Human(Expert):
168
169
  extra_conclusions[category] = self._get_conditions(case, {category.__class__.__name__: category})
169
170
  return extra_conclusions
170
171
 
171
- def ask_for_conclusion(self, case_query: CaseQuery) -> CaseQuery:
172
+ def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
172
173
  """
173
174
  Ask the expert to provide a conclusion for the case.
174
175
 
175
176
  :param case_query: The case query containing the case to find a conclusion for.
176
- :return: The case query updated with the conclusion for the case.
177
+ :return: The conclusion for the case as a callable expression.
177
178
  """
179
+ expression: Optional[CallableExpression] = None
178
180
  if self.use_loaded_answers:
179
181
  expert_input = self.all_expert_answers.pop(0)
180
- expression = CallableExpression(expert_input, case_query.attribute_type, session=self.session,
181
- scope=case_query.scope)
182
+ if expert_input is not None:
183
+ expression = CallableExpression(expert_input, case_query.attribute_type,
184
+ scope=case_query.scope)
182
185
  else:
183
186
  show_current_and_corner_cases(case_query.case)
184
- expert_input, expression = prompt_user_for_expression(case_query, PromptFor.Conclusion,
185
- session=self.session)
187
+ expert_input, expression = prompt_user_for_expression(case_query, PromptFor.Conclusion)
186
188
  self.all_expert_answers.append(expert_input)
187
189
  case_query.target = expression
188
190
  return expression
@@ -195,7 +197,8 @@ class Human(Expert):
195
197
  :return: The category type.
196
198
  """
197
199
  cat_name = cat_name.lower()
198
- self.known_categories = get_all_subclasses(CaseAttribute) if not self.known_categories else self.known_categories
200
+ self.known_categories = get_all_subclasses(
201
+ CaseAttribute) if not self.known_categories else self.known_categories
199
202
  self.known_categories.update(CaseAttribute.registry)
200
203
  category_type = None
201
204
  if cat_name in self.known_categories:
@@ -211,9 +214,9 @@ class Human(Expert):
211
214
  question = f"Can a case have multiple values of the new category {category_name}? (y/n):"
212
215
  return not self.ask_yes_no_question(question)
213
216
 
214
- def ask_if_conclusion_is_correct(self, x: Case, conclusion: CaseAttribute,
215
- targets: Optional[List[CaseAttribute]] = None,
216
- current_conclusions: Optional[List[CaseAttribute]] = None) -> bool:
217
+ def ask_if_conclusion_is_correct(self, x: Case, conclusion: Any,
218
+ targets: Optional[List[Any]] = None,
219
+ current_conclusions: Optional[List[Any]] = None) -> bool:
217
220
  """
218
221
  Ask the expert if the conclusion is correct.
219
222
 
@@ -225,9 +228,8 @@ class Human(Expert):
225
228
  question = ""
226
229
  if not self.use_loaded_answers:
227
230
  targets = targets or []
228
- targets = targets if isinstance(targets, list) else [targets]
229
- x.conclusions = current_conclusions
230
- x.targets = targets
231
+ x.conclusions = make_list(current_conclusions)
232
+ x.targets = make_list(targets)
231
233
  question = f"Is the conclusion {conclusion} correct for the case (y/n):" \
232
234
  f"\n{str(x)}"
233
235
  return self.ask_yes_no_question(question)