ripple-down-rules 0.1.21__tar.gz → 0.1.61__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36):
  1. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/PKG-INFO +5 -4
  2. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/README.md +4 -3
  3. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/pyproject.toml +1 -1
  4. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules/datasets.py +2 -1
  5. ripple_down_rules-0.1.61/src/ripple_down_rules/datastructures/__init__.py +4 -0
  6. ripple_down_rules-0.1.61/src/ripple_down_rules/datastructures/callable_expression.py +223 -0
  7. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules/datastructures/case.py +1 -1
  8. ripple_down_rules-0.1.61/src/ripple_down_rules/datastructures/dataclasses.py +168 -0
  9. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules/experts.py +24 -22
  10. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules/prompt.py +68 -68
  11. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules/rdr.py +270 -164
  12. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules/rules.py +64 -32
  13. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules/utils.py +205 -4
  14. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules.egg-info/PKG-INFO +5 -4
  15. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules.egg-info/SOURCES.txt +1 -0
  16. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/test/test_json_serialization.py +2 -2
  17. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/test/test_on_mutagenic.py +5 -5
  18. ripple_down_rules-0.1.61/test/test_rdr.py +331 -0
  19. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/test/test_rdr_alchemy.py +24 -30
  20. ripple_down_rules-0.1.61/test/test_rdr_world.py +119 -0
  21. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/test/test_relational_rdr.py +7 -7
  22. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/test/test_relational_rdr_alchemy.py +8 -7
  23. ripple_down_rules-0.1.21/src/ripple_down_rules/datastructures/__init__.py +0 -4
  24. ripple_down_rules-0.1.21/src/ripple_down_rules/datastructures/callable_expression.py +0 -278
  25. ripple_down_rules-0.1.21/src/ripple_down_rules/datastructures/dataclasses.py +0 -115
  26. ripple_down_rules-0.1.21/test/test_rdr.py +0 -243
  27. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/LICENSE +0 -0
  28. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/setup.cfg +0 -0
  29. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules/__init__.py +0 -0
  30. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules/datastructures/enums.py +0 -0
  31. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules/failures.py +0 -0
  32. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules/helpers.py +0 -0
  33. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules/rdr_decorators.py +0 -0
  34. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
  35. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
  36. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.61}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.1.21
3
+ Version: 0.1.61
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -709,7 +709,7 @@ Fit the SCRDR to the data, then classify one of the data cases to check if its c
709
709
  and render the tree to a file:
710
710
 
711
711
  ```Python
712
- from ripple_down_rules.datastructures import CaseQuery
712
+ from ripple_down_rules.datastructures.dataclasses import CaseQuery
713
713
  from ripple_down_rules.rdr import SingleClassRDR
714
714
  from ripple_down_rules.datasets import load_zoo_dataset
715
715
  from ripple_down_rules.utils import render_tree
@@ -719,12 +719,13 @@ all_cases, targets = load_zoo_dataset()
719
719
  scrdr = SingleClassRDR()
720
720
 
721
721
  # Fit the SCRDR to the data
722
- case_queries = [CaseQuery(case, target=target) for case, target in zip(all_cases, targets)]
722
+ case_queries = [CaseQuery(case, 'species', type(target), True, _target=target)
723
+ for case, target in zip(all_cases[:10], targets[:10])]
723
724
  scrdr.fit(case_queries, animate_tree=True)
724
725
 
725
726
  # Render the tree to a file
726
727
  render_tree(scrdr.start_rule, use_dot_exporter=True, filename="scrdr")
727
728
 
728
- cat = scrdr.fit_case(all_cases[50], targets[50])
729
+ cat = scrdr.classify(all_cases[50])
729
730
  assert cat == targets[50]
730
731
  ```
@@ -22,7 +22,7 @@ Fit the SCRDR to the data, then classify one of the data cases to check if its c
22
22
  and render the tree to a file:
23
23
 
24
24
  ```Python
25
- from ripple_down_rules.datastructures import CaseQuery
25
+ from ripple_down_rules.datastructures.dataclasses import CaseQuery
26
26
  from ripple_down_rules.rdr import SingleClassRDR
27
27
  from ripple_down_rules.datasets import load_zoo_dataset
28
28
  from ripple_down_rules.utils import render_tree
@@ -32,12 +32,13 @@ all_cases, targets = load_zoo_dataset()
32
32
  scrdr = SingleClassRDR()
33
33
 
34
34
  # Fit the SCRDR to the data
35
- case_queries = [CaseQuery(case, target=target) for case, target in zip(all_cases, targets)]
35
+ case_queries = [CaseQuery(case, 'species', type(target), True, _target=target)
36
+ for case, target in zip(all_cases[:10], targets[:10])]
36
37
  scrdr.fit(case_queries, animate_tree=True)
37
38
 
38
39
  # Render the tree to a file
39
40
  render_tree(scrdr.start_rule, use_dot_exporter=True, filename="scrdr")
40
41
 
41
- cat = scrdr.fit_case(all_cases[50], targets[50])
42
+ cat = scrdr.classify(all_cases[50])
42
43
  assert cat == targets[50]
43
44
  ```
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
6
6
 
7
7
  [project]
8
8
  name = "ripple_down_rules"
9
- version = "0.1.21"
9
+ version = "0.1.61"
10
10
  description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
11
11
  readme = "README.md"
12
12
  authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
@@ -9,7 +9,8 @@ from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationshi
9
9
  from typing_extensions import Tuple, List, Set, Optional
10
10
  from ucimlrepo import fetch_ucirepo
11
11
 
12
- from .datastructures import Case, create_cases_from_dataframe, Category
12
+ from .datastructures.case import Case, create_cases_from_dataframe
13
+ from .datastructures.enums import Category
13
14
 
14
15
 
15
16
  def load_cached_dataset(cache_file):
@@ -0,0 +1,4 @@
1
+ # from .enums import *
2
+ # from .dataclasses import *
3
+ # from .callable_expression import *
4
+ # from .case import *
@@ -0,0 +1,223 @@
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import logging
5
+ from _ast import AST
6
+
7
+ from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set
8
+
9
+ from .case import create_case, Case
10
+ from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string, conclusion_to_json, is_iterable
11
+
12
+
13
+ class VariableVisitor(ast.NodeVisitor):
14
+ """
15
+ A visitor to extract all variables and comparisons from a python expression represented as an AST tree.
16
+ """
17
+ compares: List[Tuple[Union[ast.Name, ast.Call], ast.cmpop, Union[ast.Name, ast.Call]]]
18
+ variables: Set[str]
19
+ all: List[ast.BoolOp]
20
+
21
+ def __init__(self):
22
+ self.variables = set()
23
+ self.attributes: Dict[ast.Name, ast.Attribute] = {}
24
+ self.types = set()
25
+ self.callables = set()
26
+ self.compares = list()
27
+ self.binary_ops = list()
28
+ self.all = list()
29
+
30
+ def visit_Constant(self, node):
31
+ self.all.append(node)
32
+ self.types.add(node)
33
+ self.generic_visit(node)
34
+
35
+ def visit_Call(self, node):
36
+ self.all.append(node)
37
+ self.callables.add(node)
38
+ self.generic_visit(node)
39
+
40
+ def visit_Attribute(self, node):
41
+ self.all.append(node)
42
+ self.attributes[node.value] = node
43
+ self.generic_visit(node)
44
+
45
+ def visit_BinOp(self, node):
46
+ self.binary_ops.append(node)
47
+ self.all.append(node)
48
+ self.generic_visit(node)
49
+
50
+ def visit_BoolOp(self, node):
51
+ self.all.append(node)
52
+ self.generic_visit(node)
53
+
54
+ def visit_Compare(self, node):
55
+ self.all.append(node)
56
+ self.compares.append([node.left, node.ops[0], node.comparators[0]])
57
+ self.generic_visit(node)
58
+
59
+ def visit_Name(self, node):
60
+ if f"__{node.id}__" not in dir(__builtins__) and node not in self.attributes:
61
+ self.variables.add(node.id)
62
+ self.generic_visit(node)
63
+
64
+
65
+ def get_used_scope(code_str, scope):
66
+ # Parse the code into an AST
67
+ mode = 'exec' if code_str.startswith('def') else 'eval'
68
+ tree = ast.parse(code_str, mode=mode)
69
+
70
+ # Walk the AST to collect used variable names
71
+ class NameCollector(ast.NodeVisitor):
72
+ def __init__(self):
73
+ self.names = set()
74
+
75
+ def visit_Name(self, node):
76
+ if isinstance(node.ctx, ast.Load): # We care only about variables being read
77
+ self.names.add(node.id)
78
+
79
+ collector = NameCollector()
80
+ collector.visit(tree)
81
+
82
+ # Filter the scope to include only used names
83
+ used_scope = {k: scope[k] for k in collector.names if k in scope}
84
+ return used_scope
85
+
86
+
87
+ class CallableExpression(SubclassJSONSerializer):
88
+ """
89
+ A callable that is constructed from a string statement written by an expert.
90
+ """
91
+
92
+ def __init__(self, user_input: Optional[str] = None, conclusion_type: Optional[Tuple[Type]] = None,
93
+ expression_tree: Optional[AST] = None,
94
+ scope: Optional[Dict[str, Any]] = None, conclusion: Optional[Any] = None):
95
+ """
96
+ Create a callable expression.
97
+
98
+ :param user_input: The input given by the expert.
99
+ :param conclusion_type: The type of the output of the callable.
100
+ :param expression_tree: The AST tree parsed from the user input.
101
+ :param scope: The scope to use for the callable expression.
102
+ :param conclusion: The conclusion to use for the callable expression.
103
+ """
104
+ if user_input is None and conclusion is None:
105
+ raise ValueError("Either user_input or conclusion must be provided.")
106
+ self.conclusion: Optional[Any] = conclusion
107
+ self.user_input: str = user_input
108
+ if conclusion_type is not None:
109
+ if is_iterable(conclusion_type):
110
+ conclusion_type = tuple(conclusion_type)
111
+ else:
112
+ conclusion_type = (conclusion_type,)
113
+ self.conclusion_type = conclusion_type
114
+ self.scope: Optional[Dict[str, Any]] = scope if scope is not None else {}
115
+ if conclusion is None:
116
+ self.scope = get_used_scope(self.user_input, self.scope)
117
+ self.expression_tree: AST = expression_tree if expression_tree else parse_string_to_expression(self.user_input)
118
+ self.code = compile_expression_to_code(self.expression_tree)
119
+ self.visitor = VariableVisitor()
120
+ self.visitor.visit(self.expression_tree)
121
+
122
+ def __call__(self, case: Any, **kwargs) -> Any:
123
+ try:
124
+ if self.user_input is not None:
125
+ if not isinstance(case, Case):
126
+ case = create_case(case, max_recursion_idx=3)
127
+ scope = {'case': case, **self.scope}
128
+ output = eval(self.code, scope)
129
+ if output is None:
130
+ output = scope['_get_value'](case)
131
+ if self.conclusion_type is not None:
132
+ if is_iterable(output) and not isinstance(output, self.conclusion_type):
133
+ assert isinstance(list(output)[0], self.conclusion_type), (f"Expected output type {self.conclusion_type},"
134
+ f" got {type(output)}")
135
+ else:
136
+ assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
137
+ f" got {type(output)}")
138
+ return output
139
+ else:
140
+ return self.conclusion
141
+ except Exception as e:
142
+ raise ValueError(f"Error during evaluation: {e}")
143
+
144
+ def combine_with(self, other: 'CallableExpression') -> 'CallableExpression':
145
+ """
146
+ Combine this callable expression with another callable expression using the 'and' operator.
147
+ """
148
+ new_user_input = f"({self.user_input}) and ({other.user_input})"
149
+ return CallableExpression(new_user_input, conclusion_type=self.conclusion_type)
150
+
151
+ def __eq__(self, other):
152
+ """
153
+ Check if two callable expressions are equal.
154
+ """
155
+ if not isinstance(other, CallableExpression):
156
+ return False
157
+ return self.user_input == other.user_input and self.conclusion == other.conclusion
158
+
159
+ def __hash__(self):
160
+ """
161
+ Hash the callable expression.
162
+ """
163
+ conclusion_hash = self.conclusion if not isinstance(self.conclusion, set) else frozenset(self.conclusion)
164
+ return hash((self.user_input, conclusion_hash))
165
+
166
+ def __str__(self):
167
+ """
168
+ Return the user string where each compare is written in a line using compare column offset start and end.
169
+ """
170
+ if self.user_input is None:
171
+ return str(self.conclusion)
172
+ binary_ops = sorted(self.visitor.binary_ops, key=lambda x: x.end_col_offset)
173
+ binary_ops_indices = [b.end_col_offset for b in binary_ops]
174
+ all_binary_ops = []
175
+ prev_e = 0
176
+ for i, e in enumerate(binary_ops_indices):
177
+ if i == 0:
178
+ all_binary_ops.append(self.user_input[:e])
179
+ else:
180
+ all_binary_ops.append(self.user_input[prev_e:e])
181
+ prev_e = e
182
+ return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else self.user_input
183
+
184
+ def _to_json(self) -> Dict[str, Any]:
185
+ return {"user_input": self.user_input,
186
+ "conclusion_type": [get_full_class_name(t) for t in self.conclusion_type]
187
+ if self.conclusion_type is not None else None,
188
+ "scope": {k: get_full_class_name(v) for k, v in self.scope.items()
189
+ if hasattr(v, '__module__') and hasattr(v, '__name__')},
190
+ "conclusion": conclusion_to_json(self.conclusion),
191
+ }
192
+
193
+ @classmethod
194
+ def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
195
+ return cls(user_input=data["user_input"],
196
+ conclusion_type=tuple(get_type_from_string(t) for t in data["conclusion_type"])
197
+ if data["conclusion_type"] else None,
198
+ scope={k: get_type_from_string(v) for k, v in data["scope"].items()},
199
+ conclusion=SubclassJSONSerializer.from_json(data["conclusion"]))
200
+
201
+
202
+ def compile_expression_to_code(expression_tree: AST) -> Any:
203
+ """
204
+ Compile an expression tree that was parsed from string into code that can be executed using 'eval(code)'
205
+
206
+ :param expression_tree: The parsed expression tree.
207
+ :return: The code that was compiled from the expression tree.
208
+ """
209
+ mode = 'exec' if isinstance(expression_tree, ast.Module) else 'eval'
210
+ return compile(expression_tree, filename="<string>", mode=mode)
211
+
212
+
213
+ def parse_string_to_expression(expression_str: str) -> AST:
214
+ """
215
+ Parse a string statement into an AST expression.
216
+
217
+ :param expression_str: The string which will be parsed.
218
+ :return: The parsed expression.
219
+ """
220
+ mode = 'exec' if expression_str.startswith('def') else 'eval'
221
+ tree = ast.parse(expression_str, mode=mode)
222
+ logging.debug(f"AST parsed successfully: {ast.dump(tree)}")
223
+ return tree
@@ -63,7 +63,7 @@ class Case(UserDict, SubclassJSONSerializer):
63
63
  new_list.extend(make_list(value))
64
64
  super().__setitem__(name, new_list)
65
65
  else:
66
- super().__setitem__(name, [self[name], value])
66
+ super().__setitem__(name, self[name])
67
67
  else:
68
68
  super().__setitem__(name, value)
69
69
  setattr(self, name, self[name])
@@ -0,0 +1,168 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from dataclasses import dataclass, field
5
+
6
+ from sqlalchemy.orm import DeclarativeBase as SQLTable
7
+ from typing_extensions import Any, Optional, Dict, Type, Tuple, Union
8
+
9
+ from .callable_expression import CallableExpression
10
+ from .case import create_case, Case
11
+ from ..utils import copy_case, make_list
12
+
13
+
14
+ @dataclass
15
+ class CaseQuery:
16
+ """
17
+ This is a dataclass that represents an attribute of an object and its target value. If attribute name is
18
+ not provided, it will be inferred from the attribute itself or from the attribute type or from the target value,
19
+ depending on what is provided.
20
+ """
21
+ original_case: Any
22
+ """
23
+ The case that the attribute belongs to.
24
+ """
25
+ attribute_name: str
26
+ """
27
+ The name of the attribute.
28
+ """
29
+ _attribute_types: Tuple[Type]
30
+ """
31
+ The type(s) of the attribute.
32
+ """
33
+ mutually_exclusive: bool
34
+ """
35
+ Whether the attribute can only take one value (i.e. True) or multiple values (i.e. False).
36
+ """
37
+ _target: Optional[CallableExpression] = None
38
+ """
39
+ The target expression of the attribute.
40
+ """
41
+ default_value: Optional[Any] = None
42
+ """
43
+ The default value of the attribute. This is used when the target value is not provided.
44
+ """
45
+ scope: Optional[Dict[str, Any]] = field(default_factory=lambda: inspect.currentframe().f_back.f_back.f_globals)
46
+ """
47
+ The global scope of the case query. This is used to evaluate the conditions and prediction, and is what is available
48
+ to the user when they are prompted for input. If it is not provided, it will be set to the global scope of the
49
+ caller.
50
+ """
51
+ _case: Optional[Union[Case, SQLTable]] = None
52
+ """
53
+ The created case from the original case that the attribute belongs to.
54
+ """
55
+ _target_value: Optional[Any] = None
56
+ """
57
+ The target value of the case query. (This is the result of the target expression evaluation on the case.)
58
+ """
59
+ conditions: Optional[CallableExpression] = None
60
+ """
61
+ The conditions that must be satisfied for the target value to be valid.
62
+ """
63
+
64
+ @property
65
+ def case(self) -> Any:
66
+ """
67
+ :return: The case that the attribute belongs to.
68
+ """
69
+ if self._case is not None:
70
+ return self._case
71
+ elif not isinstance(self.original_case, (Case, SQLTable)):
72
+ self._case = create_case(self.original_case, max_recursion_idx=3)
73
+ else:
74
+ self._case = self.original_case
75
+ return self._case
76
+
77
+ @case.setter
78
+ def case(self, value: Any):
79
+ """
80
+ Set the case that the attribute belongs to.
81
+ """
82
+ if not isinstance(value, (Case, SQLTable)):
83
+ raise ValueError("The case must be a Case or SQLTable object.")
84
+ self._case = value
85
+
86
+ @property
87
+ def attribute_type(self) -> Tuple[Type]:
88
+ """
89
+ :return: The type of the attribute.
90
+ """
91
+ self._attribute_types = tuple(make_list(self._attribute_types))
92
+ if not self.mutually_exclusive and (set not in self._attribute_types):
93
+ self._attribute_types = tuple(list(self._attribute_types) + [set])
94
+ return self._attribute_types
95
+
96
+ @attribute_type.setter
97
+ def attribute_type(self, value: Type):
98
+ """
99
+ Set the type of the attribute.
100
+ """
101
+ self._attribute_types = tuple(make_list(value))
102
+
103
+ @property
104
+ def name(self):
105
+ """
106
+ :return: The name of the case query.
107
+ """
108
+ return f"{self.case_name}.{self.attribute_name}"
109
+
110
+ @property
111
+ def case_name(self) -> str:
112
+ """
113
+ :return: The name of the case.
114
+ """
115
+ return self.case._name if isinstance(self.case, Case) else self.case.__class__.__name__
116
+
117
+ @property
118
+ def target(self) -> Optional[CallableExpression]:
119
+ """
120
+ :return: The target expression of the attribute.
121
+ """
122
+ if self._target is not None and not isinstance(self._target, CallableExpression):
123
+ self._target = CallableExpression(conclusion=self._target, conclusion_type=self.attribute_type,
124
+ scope=self.scope)
125
+ return self._target
126
+
127
+ @target.setter
128
+ def target(self, value: Optional[CallableExpression]):
129
+ """
130
+ Set the target expression of the attribute.
131
+ """
132
+ if value is not None and not isinstance(value, (CallableExpression, str)):
133
+ raise ValueError("The target must be a CallableExpression or a string.")
134
+ self._target = value
135
+ self._update_target_value()
136
+
137
+ @property
138
+ def target_value(self) -> Any:
139
+ """
140
+ :return: The target value of the case query.
141
+ """
142
+ if self._target_value is None:
143
+ self._update_target_value()
144
+ return self._target_value
145
+
146
+ def _update_target_value(self):
147
+ """
148
+ Update the target value of the case query.
149
+ """
150
+ if isinstance(self.target, CallableExpression):
151
+ self._target_value = self.target(self.case)
152
+ else:
153
+ self._target_value = self.target
154
+
155
+ def __str__(self):
156
+ header = f"CaseQuery: {self.name}"
157
+ target = f"Target: {self.name} |= {self.target if self.target is not None else '?'}"
158
+ conditions = f"Conditions: {self.conditions if self.conditions is not None else '?'}"
159
+ return "\n".join([header, target, conditions])
160
+
161
+ def __repr__(self):
162
+ return self.__str__()
163
+
164
+ def __copy__(self):
165
+ return CaseQuery(self.original_case, self.attribute_name, self.attribute_type,
166
+ self.mutually_exclusive, _target=self.target, default_value=self.default_value,
167
+ scope=self.scope, _case=copy_case(self.case), _target_value=self.target_value,
168
+ conditions=self.conditions)
@@ -3,13 +3,15 @@ from __future__ import annotations
3
3
  import json
4
4
  from abc import ABC, abstractmethod
5
5
 
6
- from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColumn, Session
7
- from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type, Union, Any, get_type_hints
6
+ from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Type, Any
8
7
 
9
- from .datastructures import (Case, PromptFor, CallableExpression, CaseAttribute, CaseQuery)
8
+ from .datastructures.case import Case, CaseAttribute
9
+ from .datastructures.callable_expression import CallableExpression
10
+ from .datastructures.enums import PromptFor
11
+ from .datastructures.dataclasses import CaseQuery
10
12
  from .datastructures.case import show_current_and_corner_cases
11
- from .prompt import prompt_user_for_expression, prompt_user_about_case
12
- from .utils import get_all_subclasses, is_iterable
13
+ from .prompt import prompt_user_for_expression
14
+ from .utils import get_all_subclasses, make_list
13
15
 
14
16
  if TYPE_CHECKING:
15
17
  from .rdr import Rule
@@ -89,10 +91,9 @@ class Human(Expert):
89
91
  The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
90
92
  """
91
93
 
92
- def __init__(self, use_loaded_answers: bool = False, session: Optional[Session] = None):
94
+ def __init__(self, use_loaded_answers: bool = False):
93
95
  self.all_expert_answers = []
94
96
  self.use_loaded_answers = use_loaded_answers
95
- self.session = session
96
97
 
97
98
  def save_answers(self, path: str, append: bool = False):
98
99
  """
@@ -142,9 +143,9 @@ class Human(Expert):
142
143
  if self.use_loaded_answers:
143
144
  user_input = self.all_expert_answers.pop(0)
144
145
  if user_input:
145
- condition = CallableExpression(user_input, bool, scope=case_query.scope, session=self.session)
146
+ condition = CallableExpression(user_input, bool, scope=case_query.scope)
146
147
  else:
147
- user_input, condition = prompt_user_for_expression(case_query, PromptFor.Conditions, session=self.session)
148
+ user_input, condition = prompt_user_for_expression(case_query, PromptFor.Conditions)
148
149
  if not self.use_loaded_answers:
149
150
  self.all_expert_answers.append(user_input)
150
151
  case_query.conditions = condition
@@ -168,21 +169,22 @@ class Human(Expert):
168
169
  extra_conclusions[category] = self._get_conditions(case, {category.__class__.__name__: category})
169
170
  return extra_conclusions
170
171
 
171
- def ask_for_conclusion(self, case_query: CaseQuery) -> CaseQuery:
172
+ def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
172
173
  """
173
174
  Ask the expert to provide a conclusion for the case.
174
175
 
175
176
  :param case_query: The case query containing the case to find a conclusion for.
176
- :return: The case query updated with the conclusion for the case.
177
+ :return: The conclusion for the case as a callable expression.
177
178
  """
179
+ expression: Optional[CallableExpression] = None
178
180
  if self.use_loaded_answers:
179
181
  expert_input = self.all_expert_answers.pop(0)
180
- expression = CallableExpression(expert_input, case_query.attribute_type, session=self.session,
181
- scope=case_query.scope)
182
+ if expert_input is not None:
183
+ expression = CallableExpression(expert_input, case_query.attribute_type,
184
+ scope=case_query.scope)
182
185
  else:
183
186
  show_current_and_corner_cases(case_query.case)
184
- expert_input, expression = prompt_user_for_expression(case_query, PromptFor.Conclusion,
185
- session=self.session)
187
+ expert_input, expression = prompt_user_for_expression(case_query, PromptFor.Conclusion)
186
188
  self.all_expert_answers.append(expert_input)
187
189
  case_query.target = expression
188
190
  return expression
@@ -195,7 +197,8 @@ class Human(Expert):
195
197
  :return: The category type.
196
198
  """
197
199
  cat_name = cat_name.lower()
198
- self.known_categories = get_all_subclasses(CaseAttribute) if not self.known_categories else self.known_categories
200
+ self.known_categories = get_all_subclasses(
201
+ CaseAttribute) if not self.known_categories else self.known_categories
199
202
  self.known_categories.update(CaseAttribute.registry)
200
203
  category_type = None
201
204
  if cat_name in self.known_categories:
@@ -211,9 +214,9 @@ class Human(Expert):
211
214
  question = f"Can a case have multiple values of the new category {category_name}? (y/n):"
212
215
  return not self.ask_yes_no_question(question)
213
216
 
214
- def ask_if_conclusion_is_correct(self, x: Case, conclusion: CaseAttribute,
215
- targets: Optional[List[CaseAttribute]] = None,
216
- current_conclusions: Optional[List[CaseAttribute]] = None) -> bool:
217
+ def ask_if_conclusion_is_correct(self, x: Case, conclusion: Any,
218
+ targets: Optional[List[Any]] = None,
219
+ current_conclusions: Optional[List[Any]] = None) -> bool:
217
220
  """
218
221
  Ask the expert if the conclusion is correct.
219
222
 
@@ -225,9 +228,8 @@ class Human(Expert):
225
228
  question = ""
226
229
  if not self.use_loaded_answers:
227
230
  targets = targets or []
228
- targets = targets if isinstance(targets, list) else [targets]
229
- x.conclusions = current_conclusions
230
- x.targets = targets
231
+ x.conclusions = make_list(current_conclusions)
232
+ x.targets = make_list(targets)
231
233
  question = f"Is the conclusion {conclusion} correct for the case (y/n):" \
232
234
  f"\n{str(x)}"
233
235
  return self.ask_yes_no_question(question)