ripple-down-rules 0.1.21__py3-none-any.whl → 0.1.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,8 @@ from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationshi
9
9
  from typing_extensions import Tuple, List, Set, Optional
10
10
  from ucimlrepo import fetch_ucirepo
11
11
 
12
- from .datastructures import Case, create_cases_from_dataframe, Category
12
+ from .datastructures.case import Case, create_cases_from_dataframe
13
+ from .datastructures.enums import Category
13
14
 
14
15
 
15
16
  def load_cached_dataset(cache_file):
@@ -1,4 +1,4 @@
1
- from .enums import *
2
- from .dataclasses import *
3
- from .callable_expression import *
4
- from .case import *
1
+ # from .enums import *
2
+ # from .dataclasses import *
3
+ # from .callable_expression import *
4
+ # from .case import *
@@ -4,11 +4,10 @@ import ast
4
4
  import logging
5
5
  from _ast import AST
6
6
 
7
- from sqlalchemy.orm import Session
8
7
  from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set
9
8
 
10
9
  from .case import create_case, Case
11
- from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string
10
+ from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string, conclusion_to_json, is_iterable
12
11
 
13
12
 
14
13
  class VariableVisitor(ast.NodeVisitor):
@@ -65,7 +64,8 @@ class VariableVisitor(ast.NodeVisitor):
65
64
 
66
65
  def get_used_scope(code_str, scope):
67
66
  # Parse the code into an AST
68
- tree = ast.parse(code_str, mode='eval')
67
+ mode = 'exec' if code_str.startswith('def') else 'eval'
68
+ tree = ast.parse(code_str, mode=mode)
69
69
 
70
70
  # Walk the AST to collect used variable names
71
71
  class NameCollector(ast.NodeVisitor):
@@ -88,91 +88,56 @@ class CallableExpression(SubclassJSONSerializer):
88
88
  """
89
89
  A callable that is constructed from a string statement written by an expert.
90
90
  """
91
- conclusion_type: Type
92
- """
93
- The type of the output of the callable, used for assertion.
94
- """
95
- expression_tree: AST
96
- """
97
- The AST tree parsed from the user input.
98
- """
99
- user_input: str
100
- """
101
- The input given by the expert.
102
- """
103
- session: Optional[Session]
104
- """
105
- The sqlalchemy orm session.
106
- """
107
- visitor: VariableVisitor
108
- """
109
- A visitor to extract all variables and comparisons from a python expression represented as an AST tree.
110
- """
111
- code: Any
112
- """
113
- The code that was compiled from the expression tree
114
- """
115
- compares_column_offset: List[int]
116
- """
117
- The start and end indices of each comparison in the string of user input.
118
- """
119
91
 
120
- def __init__(self, user_input: str, conclusion_type: Optional[Type] = None, expression_tree: Optional[AST] = None,
121
- session: Optional[Session] = None, scope: Optional[Dict[str, Any]] = None):
92
+ def __init__(self, user_input: Optional[str] = None, conclusion_type: Optional[Tuple[Type]] = None,
93
+ expression_tree: Optional[AST] = None,
94
+ scope: Optional[Dict[str, Any]] = None, conclusion: Optional[Any] = None):
122
95
  """
123
96
  Create a callable expression.
124
97
 
125
98
  :param user_input: The input given by the expert.
126
99
  :param conclusion_type: The type of the output of the callable.
127
100
  :param expression_tree: The AST tree parsed from the user input.
128
- :param session: The sqlalchemy orm session.
101
+ :param scope: The scope to use for the callable expression.
102
+ :param conclusion: The conclusion to use for the callable expression.
129
103
  """
130
- self.session = session
104
+ if user_input is None and conclusion is None:
105
+ raise ValueError("Either user_input or conclusion must be provided.")
106
+ self.conclusion: Optional[Any] = conclusion
131
107
  self.user_input: str = user_input
132
- self.parsed_user_input = self.parse_user_input(user_input, session)
108
+ if conclusion_type is not None:
109
+ if is_iterable(conclusion_type):
110
+ conclusion_type = tuple(conclusion_type)
111
+ else:
112
+ conclusion_type = (conclusion_type,)
133
113
  self.conclusion_type = conclusion_type
134
114
  self.scope: Optional[Dict[str, Any]] = scope if scope is not None else {}
135
- self.scope = get_used_scope(self.parsed_user_input, self.scope)
136
- self.update_expression(self.parsed_user_input, expression_tree)
137
-
138
- def get_used_scope_in_user_input(self) -> Set[str]:
139
- """
140
- Get the used scope in the user input.
141
- :return: The used scope in the user input.
142
- """
143
- return self.visitor.variables.union(self.visitor.attributes.keys())
144
-
145
- @staticmethod
146
- def parse_user_input(user_input: str, session: Optional[Session] = None) -> str:
147
- if ',' in user_input:
148
- user_input = user_input.split(',')
149
- user_input = [f"({u.strip()})" for u in user_input]
150
- user_input = ' & '.join(user_input) if session else ' and '.join(user_input)
151
- elif session:
152
- user_input = user_input.replace(" and ", " & ")
153
- user_input = user_input.replace(" or ", " | ")
154
- return user_input
155
-
156
- def update_expression(self, user_input: str, expression_tree: Optional[AST] = None):
157
- if not expression_tree:
158
- expression_tree = parse_string_to_expression(user_input)
159
- self.expression_tree: AST = expression_tree
160
- self.visitor = VariableVisitor()
161
- self.visitor.visit(expression_tree)
162
- self.expression_tree = parse_string_to_expression(self.parsed_user_input)
163
- self.compares_column_offset = [(c[0].col_offset, c[2].end_col_offset) for c in self.visitor.compares]
164
- self.code = compile_expression_to_code(self.expression_tree)
115
+ if conclusion is None:
116
+ self.scope = get_used_scope(self.user_input, self.scope)
117
+ self.expression_tree: AST = expression_tree if expression_tree else parse_string_to_expression(self.user_input)
118
+ self.code = compile_expression_to_code(self.expression_tree)
119
+ self.visitor = VariableVisitor()
120
+ self.visitor.visit(self.expression_tree)
165
121
 
166
122
  def __call__(self, case: Any, **kwargs) -> Any:
167
123
  try:
168
- if not isinstance(case, Case):
169
- case = create_case(case, max_recursion_idx=3)
170
- scope = {'case': case, **self.scope}
171
- output = eval(self.code, scope)
172
- if self.conclusion_type is not None:
173
- assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
174
- f" got {type(output)}")
175
- return output
124
+ if self.user_input is not None:
125
+ if not isinstance(case, Case):
126
+ case = create_case(case, max_recursion_idx=3)
127
+ scope = {'case': case, **self.scope}
128
+ output = eval(self.code, scope)
129
+ if output is None:
130
+ output = scope['_get_value'](case)
131
+ if self.conclusion_type is not None:
132
+ if is_iterable(output) and not isinstance(output, self.conclusion_type):
133
+ assert isinstance(list(output)[0], self.conclusion_type), (f"Expected output type {self.conclusion_type},"
134
+ f" got {type(output)}")
135
+ else:
136
+ assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
137
+ f" got {type(output)}")
138
+ return output
139
+ else:
140
+ return self.conclusion
176
141
  except Exception as e:
177
142
  raise ValueError(f"Error during evaluation: {e}")
178
143
 
@@ -181,35 +146,57 @@ class CallableExpression(SubclassJSONSerializer):
181
146
  Combine this callable expression with another callable expression using the 'and' operator.
182
147
  """
183
148
  new_user_input = f"({self.user_input}) and ({other.user_input})"
184
- return CallableExpression(new_user_input, conclusion_type=self.conclusion_type, session=self.session)
149
+ return CallableExpression(new_user_input, conclusion_type=self.conclusion_type)
150
+
151
+ def __eq__(self, other):
152
+ """
153
+ Check if two callable expressions are equal.
154
+ """
155
+ if not isinstance(other, CallableExpression):
156
+ return False
157
+ return self.user_input == other.user_input and self.conclusion == other.conclusion
158
+
159
+ def __hash__(self):
160
+ """
161
+ Hash the callable expression.
162
+ """
163
+ conclusion_hash = self.conclusion if not isinstance(self.conclusion, set) else frozenset(self.conclusion)
164
+ return hash((self.user_input, conclusion_hash))
185
165
 
186
166
  def __str__(self):
187
167
  """
188
168
  Return the user string where each compare is written in a line using compare column offset start and end.
189
169
  """
190
- user_input = self.parsed_user_input
170
+ if self.user_input is None:
171
+ return str(self.conclusion)
191
172
  binary_ops = sorted(self.visitor.binary_ops, key=lambda x: x.end_col_offset)
192
173
  binary_ops_indices = [b.end_col_offset for b in binary_ops]
193
174
  all_binary_ops = []
194
175
  prev_e = 0
195
176
  for i, e in enumerate(binary_ops_indices):
196
177
  if i == 0:
197
- all_binary_ops.append(user_input[:e])
178
+ all_binary_ops.append(self.user_input[:e])
198
179
  else:
199
- all_binary_ops.append(user_input[prev_e:e])
180
+ all_binary_ops.append(self.user_input[prev_e:e])
200
181
  prev_e = e
201
- return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else user_input
182
+ return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else self.user_input
202
183
 
203
184
  def _to_json(self) -> Dict[str, Any]:
204
- return {"user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type),
185
+ return {"user_input": self.user_input,
186
+ "conclusion_type": [get_full_class_name(t) for t in self.conclusion_type]
187
+ if self.conclusion_type is not None else None,
205
188
  "scope": {k: get_full_class_name(v) for k, v in self.scope.items()
206
- if hasattr(v, '__module__') and hasattr(v, '__name__')}
189
+ if hasattr(v, '__module__') and hasattr(v, '__name__')},
190
+ "conclusion": conclusion_to_json(self.conclusion),
207
191
  }
208
192
 
209
193
  @classmethod
210
194
  def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
211
- return cls(user_input=data["user_input"], conclusion_type=get_type_from_string(data["conclusion_type"]),
212
- scope={k: get_type_from_string(v) for k, v in data["scope"].items()})
195
+ return cls(user_input=data["user_input"],
196
+ conclusion_type=tuple(get_type_from_string(t) for t in data["conclusion_type"])
197
+ if data["conclusion_type"] else None,
198
+ scope={k: get_type_from_string(v) for k, v in data["scope"].items()},
199
+ conclusion=SubclassJSONSerializer.from_json(data["conclusion"]))
213
200
 
214
201
 
215
202
  def compile_expression_to_code(expression_tree: AST) -> Any:
@@ -219,51 +206,8 @@ def compile_expression_to_code(expression_tree: AST) -> Any:
219
206
  :param expression_tree: The parsed expression tree.
220
207
  :return: The code that was compiled from the expression tree.
221
208
  """
222
- return compile(expression_tree, filename="<string>", mode="eval")
223
-
224
-
225
- def assert_context_contains_needed_information(case: Any, context: Dict[str, Any],
226
- visitor: VariableVisitor) -> Tuple[Set[str], Set[str]]:
227
- """
228
- Asserts that the variables mentioned in the expression visited by visitor are all in the given context.
229
-
230
- :param case: The case to check the context for.
231
- :param context: The context to check.
232
- :param visitor: The visitor that visited the expression.
233
- :return: The found variables and attributes.
234
- """
235
- found_variables = set()
236
- for key in visitor.variables:
237
- if key not in context:
238
- raise ValueError(f"Variable {key} not found in the case {case}")
239
- found_variables.add(key)
240
-
241
- found_attributes = get_attributes_str(visitor)
242
- for attr in found_attributes:
243
- if attr not in context:
244
- raise ValueError(f"Attribute {attr} not found in the case {case}")
245
- return found_variables, found_attributes
246
-
247
-
248
- def get_attributes_str(visitor: VariableVisitor) -> Set[str]:
249
- """
250
- Get the string representation of the attributes in the given visitor.
251
-
252
- :param visitor: The visitor that visited the expression.
253
- :return: The string representation of the attributes.
254
- """
255
- found_attributes = set()
256
- for key, ast_attr in visitor.attributes.items():
257
- str_attr = ""
258
- while isinstance(key, ast.Attribute):
259
- if len(str_attr) > 0:
260
- str_attr = f"{key.attr}.{str_attr}"
261
- else:
262
- str_attr = key.attr
263
- key = key.value
264
- str_attr = f"{key.id}.{str_attr}" if len(str_attr) > 0 else f"{key.id}.{ast_attr.attr}"
265
- found_attributes.add(str_attr)
266
- return found_attributes
209
+ mode = 'exec' if isinstance(expression_tree, ast.Module) else 'eval'
210
+ return compile(expression_tree, filename="<string>", mode=mode)
267
211
 
268
212
 
269
213
  def parse_string_to_expression(expression_str: str) -> AST:
@@ -273,6 +217,7 @@ def parse_string_to_expression(expression_str: str) -> AST:
273
217
  :param expression_str: The string which will be parsed.
274
218
  :return: The parsed expression.
275
219
  """
276
- tree = ast.parse(expression_str, mode='eval')
220
+ mode = 'exec' if expression_str.startswith('def') else 'eval'
221
+ tree = ast.parse(expression_str, mode=mode)
277
222
  logging.debug(f"AST parsed successfully: {ast.dump(tree)}")
278
223
  return tree
@@ -24,7 +24,8 @@ class Case(UserDict, SubclassJSONSerializer):
24
24
  the names of the attributes and the values are the attributes. All are stored in lower case.
25
25
  """
26
26
 
27
- def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None, _name: Optional[str] = None, **kwargs):
27
+ def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None,
28
+ _name: Optional[str] = None, original_object: Optional[Any] = None, **kwargs):
28
29
  """
29
30
  Create a new row.
30
31
 
@@ -34,6 +35,7 @@ class Case(UserDict, SubclassJSONSerializer):
34
35
  :param kwargs: The attributes of the row.
35
36
  """
36
37
  super().__init__(kwargs)
38
+ self._original_object = original_object
37
39
  self._obj_type: Type = _obj_type
38
40
  self._id: Hashable = _id if _id is not None else id(self)
39
41
  self._name: str = _name if _name is not None else self._obj_type.__name__
@@ -63,7 +65,7 @@ class Case(UserDict, SubclassJSONSerializer):
63
65
  new_list.extend(make_list(value))
64
66
  super().__setitem__(name, new_list)
65
67
  else:
66
- super().__setitem__(name, [self[name], value])
68
+ super().__setitem__(name, self[name])
67
69
  else:
68
70
  super().__setitem__(name, value)
69
71
  setattr(self, name, self[name])
@@ -221,9 +223,9 @@ def create_case(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
221
223
  return obj
222
224
  if ((recursion_idx > max_recursion_idx) or (obj.__class__.__module__ == "builtins")
223
225
  or (obj.__class__ in [MetaData, registry])):
224
- return Case(type(obj), _id=id(obj), _name=obj_name,
226
+ return Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj,
225
227
  **{obj_name or obj.__class__.__name__: make_list(obj) if parent_is_iterable else obj})
226
- case = Case(type(obj), _id=id(obj), _name=obj_name)
228
+ case = Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj)
227
229
  for attr in dir(obj):
228
230
  if attr.startswith("_") or callable(getattr(obj, attr)):
229
231
  continue
@@ -251,7 +253,7 @@ def create_or_update_case_from_attribute(attr_value: Any, name: str, obj: Any, o
251
253
  :return: The updated/created case.
252
254
  """
253
255
  if case is None:
254
- case = Case(type(obj), _id=id(obj), _name=obj_name)
256
+ case = Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj)
255
257
  if isinstance(attr_value, (dict, UserDict)):
256
258
  case.update({f"{obj_name}.{k}": v for k, v in attr_value.items()})
257
259
  if hasattr(attr_value, "__iter__") and not isinstance(attr_value, str):
@@ -280,7 +282,7 @@ def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, ob
280
282
  """
281
283
  values = list(attr_value.values()) if isinstance(attr_value, (dict, UserDict)) else attr_value
282
284
  _type = type(list(values)[0]) if len(values) > 0 else get_value_type_from_type_hint(name, obj)
283
- attr_case = Case(_type, _id=id(attr_value), _name=name)
285
+ attr_case = Case(_type, _id=id(attr_value), _name=name, original_object=attr_value)
284
286
  case_attr = CaseAttribute(values)
285
287
  for idx, val in enumerate(values):
286
288
  sub_attr_case = create_case(val, recursion_idx=recursion_idx,
@@ -1,16 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- from dataclasses import dataclass
4
+ from dataclasses import dataclass, field
5
5
 
6
6
  from sqlalchemy.orm import DeclarativeBase as SQLTable
7
- from typing_extensions import Any, Optional, Type, List, Tuple, Set, Dict, TYPE_CHECKING
7
+ from typing_extensions import Any, Optional, Dict, Type, Tuple, Union
8
8
 
9
+ from .callable_expression import CallableExpression
9
10
  from .case import create_case, Case
10
- from ..utils import get_attribute_name, copy_case, get_hint_for_attribute, typing_to_python_type
11
+ from ..utils import copy_case, make_list, make_set
11
12
 
12
- if TYPE_CHECKING:
13
- from . import CallableExpression
14
13
 
15
14
  @dataclass
16
15
  class CaseQuery:
@@ -19,7 +18,7 @@ class CaseQuery:
19
18
  not provided, it will be inferred from the attribute itself or from the attribute type or from the target value,
20
19
  depending on what is provided.
21
20
  """
22
- case: Any
21
+ original_case: Any
23
22
  """
24
23
  The case that the attribute belongs to.
25
24
  """
@@ -27,64 +26,80 @@ class CaseQuery:
27
26
  """
28
27
  The name of the attribute.
29
28
  """
30
- target: Optional[Any] = None
29
+ _attribute_types: Tuple[Type]
31
30
  """
32
- The target value of the attribute.
31
+ The type(s) of the attribute.
33
32
  """
34
- mutually_exclusive: bool = False
33
+ mutually_exclusive: bool
35
34
  """
36
35
  Whether the attribute can only take one value (i.e. True) or multiple values (i.e. False).
37
36
  """
38
- conditions: Optional[CallableExpression] = None
37
+ _target: Optional[CallableExpression] = None
39
38
  """
40
- The conditions that must be satisfied for the target value to be valid.
39
+ The target expression of the attribute.
41
40
  """
42
- prediction: Optional[CallableExpression] = None
41
+ default_value: Optional[Any] = None
43
42
  """
44
- The predicted value of the attribute.
43
+ The default value of the attribute. This is used when the target value is not provided.
45
44
  """
46
- scope: Optional[Dict[str, Any]] = None
45
+ scope: Optional[Dict[str, Any]] = field(default_factory=lambda: inspect.currentframe().f_back.f_back.f_globals)
47
46
  """
48
47
  The global scope of the case query. This is used to evaluate the conditions and prediction, and is what is available
49
48
  to the user when they are prompted for input. If it is not provided, it will be set to the global scope of the
50
49
  caller.
51
50
  """
51
+ _case: Optional[Union[Case, SQLTable]] = None
52
+ """
53
+ The created case from the original case that the attribute belongs to.
54
+ """
55
+ _target_value: Optional[Any] = None
56
+ """
57
+ The target value of the case query. (This is the result of the target expression evaluation on the case.)
58
+ """
59
+ conditions: Optional[CallableExpression] = None
60
+ """
61
+ The conditions that must be satisfied for the target value to be valid.
62
+ """
52
63
 
53
- def __init__(self, case: Any, attribute_name: str,
54
- target: Optional[Any] = None,
55
- mutually_exclusive: bool = False,
56
- conditions: Optional[CallableExpression] = None,
57
- prediction: Optional[CallableExpression] = None,
58
- scope: Optional[Dict[str, Any]] = None,):
59
- self.original_case = case
60
- self.case = self._get_case()
61
-
62
- self.attribute_name = attribute_name
63
- self.target = target
64
- self.attribute_type = self._get_attribute_type()
65
- self.mutually_exclusive = mutually_exclusive
66
- self.conditions = conditions
67
- self.prediction = prediction
68
- self.scope = scope if scope is not None else inspect.currentframe().f_back.f_globals
69
-
70
- def _get_case(self) -> Any:
71
- if not isinstance(self.original_case, (Case, SQLTable)):
72
- return create_case(self.original_case, max_recursion_idx=3)
64
+ @property
65
+ def case(self) -> Any:
66
+ """
67
+ :return: The case that the attribute belongs to.
68
+ """
69
+ if self._case is not None:
70
+ return self._case
71
+ elif not isinstance(self.original_case, (Case, SQLTable)):
72
+ self._case = create_case(self.original_case, max_recursion_idx=3)
73
73
  else:
74
- return self.original_case
74
+ self._case = self.original_case
75
+ return self._case
75
76
 
76
- def _get_attribute_type(self) -> Type:
77
+ @case.setter
78
+ def case(self, value: Any):
79
+ """
80
+ Set the case that the attribute belongs to.
81
+ """
82
+ if not isinstance(value, (Case, SQLTable)):
83
+ raise ValueError("The case must be a Case or SQLTable object.")
84
+ self._case = value
85
+
86
+ @property
87
+ def attribute_type(self) -> Tuple[Type]:
77
88
  """
78
89
  :return: The type of the attribute.
79
90
  """
80
- if self.target is not None:
81
- return type(self.target)
82
- elif hasattr(self.original_case, self.attribute_name):
83
- hint, origin, args = get_hint_for_attribute(self.attribute_name, self.original_case)
84
- if origin is not None:
85
- return typing_to_python_type(origin)
86
- elif hint is not None:
87
- return typing_to_python_type(hint)
91
+ if not self.mutually_exclusive and (set not in make_list(self._attribute_types)):
92
+ self._attribute_types = tuple(set(make_list(self._attribute_types) + [set, list]))
93
+ elif not isinstance(self._attribute_types, tuple):
94
+ self._attribute_types = tuple(make_list(self._attribute_types))
95
+ return self._attribute_types
96
+
97
+ @attribute_type.setter
98
+ def attribute_type(self, value: Type):
99
+ """
100
+ Set the type of the attribute.
101
+ """
102
+ self._attribute_types = tuple(make_list(value))
88
103
 
89
104
  @property
90
105
  def name(self):
@@ -100,16 +115,55 @@ class CaseQuery:
100
115
  """
101
116
  return self.case._name if isinstance(self.case, Case) else self.case.__class__.__name__
102
117
 
118
+ @property
119
+ def target(self) -> Optional[CallableExpression]:
120
+ """
121
+ :return: The target expression of the attribute.
122
+ """
123
+ if self._target is not None and not isinstance(self._target, CallableExpression):
124
+ self._target = CallableExpression(conclusion=self._target, conclusion_type=self.attribute_type,
125
+ scope=self.scope)
126
+ return self._target
127
+
128
+ @target.setter
129
+ def target(self, value: Optional[CallableExpression]):
130
+ """
131
+ Set the target expression of the attribute.
132
+ """
133
+ if value is not None and not isinstance(value, (CallableExpression, str)):
134
+ raise ValueError("The target must be a CallableExpression or a string.")
135
+ self._target = value
136
+ self._update_target_value()
137
+
138
+ @property
139
+ def target_value(self) -> Any:
140
+ """
141
+ :return: The target value of the case query.
142
+ """
143
+ if self._target_value is None:
144
+ self._update_target_value()
145
+ return self._target_value
146
+
147
+ def _update_target_value(self):
148
+ """
149
+ Update the target value of the case query.
150
+ """
151
+ if isinstance(self.target, CallableExpression):
152
+ self._target_value = self.target(self.case)
153
+ else:
154
+ self._target_value = self.target
155
+
103
156
  def __str__(self):
104
157
  header = f"CaseQuery: {self.name}"
105
158
  target = f"Target: {self.name} |= {self.target if self.target is not None else '?'}"
106
- prediction = f"Prediction: {self.name} |= {self.prediction if self.prediction is not None else '?'}"
107
159
  conditions = f"Conditions: {self.conditions if self.conditions is not None else '?'}"
108
- return "\n".join([header, target, prediction, conditions])
160
+ return "\n".join([header, target, conditions])
109
161
 
110
162
  def __repr__(self):
111
163
  return self.__str__()
112
164
 
113
165
  def __copy__(self):
114
- return CaseQuery(copy_case(self.case), self.attribute_name, self.target, self.mutually_exclusive,
115
- self.conditions, self.prediction, self.scope)
166
+ return CaseQuery(self.original_case, self.attribute_name, self.attribute_type,
167
+ self.mutually_exclusive, _target=self.target, default_value=self.default_value,
168
+ scope=self.scope, _case=copy_case(self.case), _target_value=self.target_value,
169
+ conditions=self.conditions)
@@ -53,7 +53,7 @@ class ExpressionParser(Enum):
53
53
 
54
54
  class PromptFor(Enum):
55
55
  """
56
- The reason of the prompt. (e.g. get conditions, or conclusions).
56
+ The reason of the prompt. (e.g. get conditions, conclusions, or affirmation).
57
57
  """
58
58
  Conditions: str = "conditions"
59
59
  """
@@ -63,6 +63,10 @@ class PromptFor(Enum):
63
63
  """
64
64
  Prompt for rule conclusion about a case.
65
65
  """
66
+ Affirmation: str = "affirmation"
67
+ """
68
+ Prompt for rule conclusion affirmation about a case.
69
+ """
66
70
 
67
71
  def __str__(self):
68
72
  return self.name