ripple-down-rules 0.0.15__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,7 @@ from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationshi
9
9
  from typing_extensions import Tuple, List, Set, Optional
10
10
  from ucimlrepo import fetch_ucirepo
11
11
 
12
- from .datastructures import Case, create_cases_from_dataframe, Category, CaseAttribute
12
+ from .datastructures import Case, create_cases_from_dataframe, Category
13
13
 
14
14
 
15
15
  def load_cached_dataset(cache_file):
@@ -77,7 +77,7 @@ def load_zoo_dataset(cache_file: Optional[str] = None) -> Tuple[List[Case], List
77
77
  y = zoo['targets']
78
78
  # get ids as list of strings
79
79
  ids = zoo['ids'].values.flatten()
80
- all_cases = create_cases_from_dataframe(X)
80
+ all_cases = create_cases_from_dataframe(X, name="Animal")
81
81
 
82
82
  category_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "molusc"]
83
83
  category_id_to_name = {i + 1: name for i, name in enumerate(category_names)}
@@ -22,10 +22,22 @@ class VariableVisitor(ast.NodeVisitor):
22
22
  def __init__(self):
23
23
  self.variables = set()
24
24
  self.attributes: Dict[ast.Name, ast.Attribute] = {}
25
+ self.types = set()
26
+ self.callables = set()
25
27
  self.compares = list()
26
28
  self.binary_ops = list()
27
29
  self.all = list()
28
30
 
31
+ def visit_Constant(self, node):
32
+ self.all.append(node)
33
+ self.types.add(node)
34
+ self.generic_visit(node)
35
+
36
+ def visit_Call(self, node):
37
+ self.all.append(node)
38
+ self.callables.add(node)
39
+ self.generic_visit(node)
40
+
29
41
  def visit_Attribute(self, node):
30
42
  self.all.append(node)
31
43
  self.attributes[node.value] = node
@@ -51,6 +63,27 @@ class VariableVisitor(ast.NodeVisitor):
51
63
  self.generic_visit(node)
52
64
 
53
65
 
66
def get_used_scope(code_str, scope):
    """
    Restrict a scope to the names that an expression actually reads.

    :param code_str: The source of a single Python expression (it is parsed with ``mode='eval'``).
    :param scope: A mapping from names to objects; ``None`` is treated as an empty scope.
    :return: A new dictionary holding the entries of ``scope`` whose keys are read in the expression,
     in the same order as they appear in ``scope``.
    :raises SyntaxError: If ``code_str`` is not a valid expression.
    """
    # Parse first so that an invalid expression is reported even when the scope is empty.
    tree = ast.parse(code_str, mode='eval')
    if scope is None:
        scope = {}

    # Only names that are read (Load context) matter; assignment targets are ignored.
    used_names = {node.id for node in ast.walk(tree)
                  if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load)}

    # Iterate over the scope (not over the name set) so that the result order is deterministic.
    return {name: value for name, value in scope.items() if name in used_names}
85
+
86
+
54
87
  class CallableExpression(SubclassJSONSerializer):
55
88
  """
56
89
  A callable that is constructed from a string statement written by an expert.
@@ -85,7 +118,7 @@ class CallableExpression(SubclassJSONSerializer):
85
118
  """
86
119
 
87
120
  def __init__(self, user_input: str, conclusion_type: Optional[Type] = None, expression_tree: Optional[AST] = None,
88
- session: Optional[Session] = None):
121
+ session: Optional[Session] = None, scope: Optional[Dict[str, Any]] = None):
89
122
  """
90
123
  Create a callable expression.
91
124
 
@@ -98,8 +131,17 @@ class CallableExpression(SubclassJSONSerializer):
98
131
  self.user_input: str = user_input
99
132
  self.parsed_user_input = self.parse_user_input(user_input, session)
100
133
  self.conclusion_type = conclusion_type
134
+ self.scope: Optional[Dict[str, Any]] = scope if scope is not None else {}
135
+ self.scope = get_used_scope(self.parsed_user_input, self.scope)
101
136
  self.update_expression(self.parsed_user_input, expression_tree)
102
137
 
138
+ def get_used_scope_in_user_input(self) -> Set[str]:
139
+ """
140
+ Get the used scope in the user input.
141
+ :return: The used scope in the user input.
142
+ """
143
+ return self.visitor.variables.union(self.visitor.attributes.keys())
144
+
103
145
  @staticmethod
104
146
  def parse_user_input(user_input: str, session: Optional[Session] = None) -> str:
105
147
  if ',' in user_input:
@@ -117,11 +159,6 @@ class CallableExpression(SubclassJSONSerializer):
117
159
  self.expression_tree: AST = expression_tree
118
160
  self.visitor = VariableVisitor()
119
161
  self.visitor.visit(expression_tree)
120
- variables_str = self.visitor.variables
121
- attributes_str = get_attributes_str(self.visitor)
122
- for v in variables_str | attributes_str:
123
- if not v.startswith("case."):
124
- self.parsed_user_input = self.parsed_user_input.replace(v, f"case.{v}")
125
162
  self.expression_tree = parse_string_to_expression(self.parsed_user_input)
126
163
  self.compares_column_offset = [(c[0].col_offset, c[2].end_col_offset) for c in self.visitor.compares]
127
164
  self.code = compile_expression_to_code(self.expression_tree)
@@ -130,8 +167,9 @@ class CallableExpression(SubclassJSONSerializer):
130
167
  try:
131
168
  if not isinstance(case, Case):
132
169
  case = create_case(case, max_recursion_idx=3)
133
- output = eval(self.code)
134
- if self.conclusion_type:
170
+ scope = {'case': case, **self.scope}
171
+ output = eval(self.code, scope)
172
+ if self.conclusion_type is not None:
135
173
  assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
136
174
  f" got {type(output)}")
137
175
  return output
@@ -163,11 +201,15 @@ class CallableExpression(SubclassJSONSerializer):
163
201
  return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else user_input
164
202
 
165
203
  def _to_json(self) -> Dict[str, Any]:
166
- return {"user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type)}
204
+ return {"user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type),
205
+ "scope": {k: get_full_class_name(v) for k, v in self.scope.items()
206
+ if hasattr(v, '__module__') and hasattr(v, '__name__')}
207
+ }
167
208
 
168
209
  @classmethod
169
210
  def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
170
- return cls(user_input=data["user_input"], conclusion_type=get_type_from_string(data["conclusion_type"]))
211
+ return cls(user_input=data["user_input"], conclusion_type=get_type_from_string(data["conclusion_type"]),
212
+ scope={k: get_type_from_string(v) for k, v in data["scope"].items()})
171
213
 
172
214
 
173
215
  def compile_expression_to_code(expression_tree: AST) -> Any:
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from collections import UserDict
4
4
  from copy import copy, deepcopy
5
- from dataclasses import dataclass
5
+ from dataclasses import dataclass, is_dataclass
6
6
  from enum import Enum
7
7
 
8
8
  from pandas import DataFrame
@@ -11,7 +11,7 @@ from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColum
11
11
  from typing_extensions import Any, Optional, Dict, Type, Set, Hashable, Union, List, TYPE_CHECKING
12
12
 
13
13
  from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, SubclassJSONSerializer, \
14
- get_full_class_name, get_type_from_string
14
+ get_full_class_name, get_type_from_string, make_list, is_iterable, serialize_dataclass
15
15
 
16
16
  if TYPE_CHECKING:
17
17
  from ripple_down_rules.rules import Rule
@@ -24,16 +24,19 @@ class Case(UserDict, SubclassJSONSerializer):
24
24
  the names of the attributes and the values are the attributes. All are stored in lower case.
25
25
  """
26
26
 
27
- def __init__(self, _id: Optional[Hashable] = None, _type: Optional[Type] = None, **kwargs):
27
+ def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None, _name: Optional[str] = None, **kwargs):
28
28
  """
29
29
  Create a new row.
30
30
 
31
+ :param _obj_type: The type of the object that the row represents.
31
32
  :param _id: The id of the row.
33
+ :param _name: The semantic name that describes the row.
32
34
  :param kwargs: The attributes of the row.
33
35
  """
34
36
  super().__init__(kwargs)
35
- self._id = _id if _id else id(self)
36
- self._type = _type
37
+ self._obj_type: Type = _obj_type
38
+ self._id: Hashable = _id if _id is not None else id(self)
39
+ self._name: str = _name if _name is not None else self._obj_type.__name__
37
40
 
38
41
  @classmethod
39
42
  def from_obj(cls, obj: Any, obj_name: Optional[str] = None, max_recursion_idx: int = 3) -> Case:
@@ -53,13 +56,14 @@ class Case(UserDict, SubclassJSONSerializer):
53
56
  def __setitem__(self, name: str, value: Any):
54
57
  name = name.lower()
55
58
  if name in self:
56
- if isinstance(self[name], set):
57
- self[name].update(make_set(value))
58
- elif isinstance(value, set):
59
- value.update(make_set(self[name]))
60
- super().__setitem__(name, value)
59
+ if isinstance(self[name], list):
60
+ self[name].extend(make_list(value))
61
+ elif isinstance(value, list):
62
+ new_list = make_list(self[name])
63
+ new_list.extend(make_list(value))
64
+ super().__setitem__(name, new_list)
61
65
  else:
62
- super().__setitem__(name, make_set([self[name], value]))
66
+ super().__setitem__(name, [self[name], value])
63
67
  else:
64
68
  super().__setitem__(name, value)
65
69
  setattr(self, name, self[name])
@@ -78,18 +82,24 @@ class Case(UserDict, SubclassJSONSerializer):
78
82
  def _to_json(self) -> Dict[str, Any]:
79
83
  serializable = {k: v for k, v in self.items() if not k.startswith("_")}
80
84
  serializable["_id"] = self._id
85
+ serializable["_obj_type"] = get_full_class_name(self._obj_type)
86
+ serializable["_name"] = self._name
81
87
  for k, v in serializable.items():
82
88
  if isinstance(v, set):
83
- serializable[k] = {'_type': get_full_class_name(set), 'value': list(v)}
89
+ serializable[k] = {'_type': get_full_class_name(set), 'value': serialize_dataclass(list(v))}
90
+ else:
91
+ serializable[k] = serialize_dataclass(v)
84
92
  return {k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in serializable.items()}
85
93
 
86
94
  @classmethod
87
95
  def _from_json(cls, data: Dict[str, Any]) -> Case:
88
96
  id_ = data.pop("_id")
97
+ obj_type = get_type_from_string(data.pop("_obj_type"))
98
+ name = data.pop("_name")
89
99
  for k, v in data.items():
90
100
  if isinstance(v, dict) and "_type" in v:
91
101
  data[k] = SubclassJSONSerializer.from_json(v)
92
- return cls(_id=id_, **data)
102
+ return cls(_obj_type=obj_type, _id=id_, _name=name, **data)
93
103
 
94
104
 
95
105
  @dataclass
@@ -122,7 +132,7 @@ class CaseAttributeValue(SubclassJSONSerializer):
122
132
  return cls(id=data["id"], value=data["value"])
123
133
 
124
134
 
125
- class CaseAttribute(set, SubclassJSONSerializer):
135
+ class CaseAttribute(list, SubclassJSONSerializer):
126
136
  nullable: bool = True
127
137
  """
128
138
  A boolean indicating whether the case attribute can be None or not.
@@ -132,33 +142,9 @@ class CaseAttribute(set, SubclassJSONSerializer):
132
142
  A boolean indicating whether the case attribute is mutually exclusive or not. (i.e. can only have one value)
133
143
  """
134
144
 
135
- def __init__(self, values: Set[CaseAttributeValue]):
136
- """
137
- Create a new case attribute.
138
-
139
- :param values: The values of the case attribute.
140
- """
141
- values = self._type_cast_values_to_set_of_case_attribute_values(values)
142
- self.id_value_map: Dict[Hashable, Union[CaseAttributeValue, Set[CaseAttributeValue]]] = {id(v): v for v in values}
143
- super().__init__([v.value for v in values])
144
-
145
- @staticmethod
146
- def _type_cast_values_to_set_of_case_attribute_values(values: Set[Any]) -> Set[CaseAttributeValue]:
147
- """
148
- Type cast values to a set of case attribute values.
149
-
150
- :param values: The values to type cast.
151
- """
152
- values = make_set(values)
153
- if len(values) > 0 and not isinstance(next(iter(values)), CaseAttributeValue):
154
- values = {CaseAttributeValue(id(values), v) for v in values}
155
- return values
156
-
157
145
  @classmethod
158
- def from_obj(cls, values: Set[Any], row_obj: Optional[Any] = None) -> CaseAttribute:
159
- id_ = id(row_obj) if row_obj else id(values)
160
- values = make_set(values)
161
- return cls({CaseAttributeValue(id_, v) for v in values})
146
+ def from_obj(cls, values: List[Any]) -> CaseAttribute:
147
+ return cls(make_list(values))
162
148
 
163
149
  @property
164
150
  def as_dict(self) -> Dict[str, Any]:
@@ -176,42 +162,43 @@ class CaseAttribute(set, SubclassJSONSerializer):
176
162
  :param condition: The condition to filter by.
177
163
  :return: The filtered column.
178
164
  """
179
- return self.__class__({v for v in self if condition(v)})
165
+ return self.__class__([v for v in self if condition(v)])
180
166
 
181
167
  def __eq__(self, other):
182
- if not isinstance(other, set):
183
- return super().__eq__(make_set(other))
168
+ if not isinstance(other, list):
169
+ return super().__eq__(make_list(other))
184
170
  return super().__eq__(other)
185
171
 
186
172
  def __hash__(self):
187
- return hash(tuple(self.id_value_map.values()))
173
+ return hash(id(self))
188
174
 
189
175
  def __str__(self):
190
176
  if len(self) == 0:
191
177
  return "None"
192
- return str({v for v in self}) if len(self) > 1 else str(next(iter(self)))
178
+ return str([v for v in self]) if len(self) > 1 else str(next(iter(self)))
193
179
 
194
180
  def _to_json(self) -> Dict[str, Any]:
195
- return {id_: v.to_json() if isinstance(v, SubclassJSONSerializer) else v
196
- for id_, v in self.id_value_map.items()}
181
+ return {str(i): v.to_json() if isinstance(v, SubclassJSONSerializer) else v
182
+ for i, v in enumerate(self)}
197
183
 
198
184
  @classmethod
199
185
  def _from_json(cls, data: Dict[str, Any]) -> CaseAttribute:
200
- return cls({CaseAttributeValue.from_json(v) for id_, v in data.items()})
186
+ return cls([SubclassJSONSerializer.from_json(v) for _, v in data.items()])
201
187
 
202
188
 
203
def create_cases_from_dataframe(df: DataFrame, name: Optional[str] = None) -> List[Case]:
    """
    Create cases from a pandas DataFrame, one case per row.

    :param df: The DataFrame to create cases from.
    :param name: The semantic name that describes the DataFrame; it is given to every created case.
    :return: The cases of the DataFrame.
    """
    cases = []
    attribute_names = list(df.columns)
    for row_id, row in df.iterrows():
        attributes = {}
        for col_name in attribute_names:
            value = row[col_name]
            # numpy scalars are converted to native Python values; values without ``item`` (e.g. strings
            # in object columns) are kept as they are instead of raising an AttributeError.
            attributes[col_name] = value.item() if hasattr(value, "item") else value
        cases.append(Case(DataFrame, _id=row_id, _name=name, **attributes))
    return cases
216
203
 
217
204
 
@@ -228,19 +215,19 @@ def create_case(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
228
215
  :return: The case that represents the object.
229
216
  """
230
217
  if isinstance(obj, DataFrame):
231
- return create_cases_from_dataframe(obj)
218
+ return create_cases_from_dataframe(obj, obj_name)
232
219
  if isinstance(obj, Case):
233
220
  return obj
234
221
  if ((recursion_idx > max_recursion_idx) or (obj.__class__.__module__ == "builtins")
235
222
  or (obj.__class__ in [MetaData, registry])):
236
- return Case(_id=id(obj), _type=obj.__class__,
237
- **{obj_name or obj.__class__.__name__: make_set(obj) if parent_is_iterable else obj})
238
- case = Case(_id=id(obj), _type=obj.__class__)
223
+ return Case(type(obj), _id=id(obj), _name=obj_name,
224
+ **{obj_name or obj.__class__.__name__: make_list(obj) if parent_is_iterable else obj})
225
+ case = Case(type(obj), _id=id(obj), _name=obj_name)
239
226
  for attr in dir(obj):
240
227
  if attr.startswith("_") or callable(getattr(obj, attr)):
241
228
  continue
242
229
  attr_value = getattr(obj, attr)
243
- case = create_or_update_case_from_attribute(attr_value, attr, obj, attr, recursion_idx,
230
+ case = create_or_update_case_from_attribute(attr_value, attr, obj, obj_name, recursion_idx,
244
231
  max_recursion_idx, parent_is_iterable, case)
245
232
  return case
246
233
 
@@ -263,16 +250,16 @@ def create_or_update_case_from_attribute(attr_value: Any, name: str, obj: Any, o
263
250
  :return: The updated/created case.
264
251
  """
265
252
  if case is None:
266
- case = Case(_id=id(obj), _type=obj.__class__)
253
+ case = Case(type(obj), _id=id(obj), _name=obj_name)
267
254
  if isinstance(attr_value, (dict, UserDict)):
268
255
  case.update({f"{obj_name}.{k}": v for k, v in attr_value.items()})
269
256
  if hasattr(attr_value, "__iter__") and not isinstance(attr_value, str):
270
257
  column = create_case_attribute_from_iterable_attribute(attr_value, name, obj, obj_name,
271
258
  recursion_idx=recursion_idx + 1,
272
259
  max_recursion_idx=max_recursion_idx)
273
- case[obj_name] = column
260
+ case[name] = column
274
261
  else:
275
- case[obj_name] = make_set(attr_value) if parent_is_iterable else attr_value
262
+ case[name] = make_list(attr_value) if parent_is_iterable else attr_value
276
263
  return case
277
264
 
278
265
 
@@ -290,22 +277,22 @@ def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, ob
290
277
  :param max_recursion_idx: The maximum recursion index.
291
278
  :return: A case attribute that represents the original iterable attribute.
292
279
  """
293
- values = attr_value.values() if isinstance(attr_value, (dict, UserDict)) else attr_value
280
+ values = list(attr_value.values()) if isinstance(attr_value, (dict, UserDict)) else attr_value
294
281
  _type = type(list(values)[0]) if len(values) > 0 else get_value_type_from_type_hint(name, obj)
295
- attr_case = Case(_id=id(attr_value), _type=_type)
296
- case_attr = CaseAttribute.from_obj(values, row_obj=obj)
282
+ attr_case = Case(_type, _id=id(attr_value), _name=name)
283
+ case_attr = CaseAttribute(values)
297
284
  for idx, val in enumerate(values):
298
285
  sub_attr_case = create_case(val, recursion_idx=recursion_idx,
299
286
  max_recursion_idx=max_recursion_idx,
300
- obj_name=obj_name, parent_is_iterable=True)
287
+ obj_name=name, parent_is_iterable=True)
301
288
  attr_case.update(sub_attr_case)
302
289
  for sub_attr, val in attr_case.items():
303
290
  setattr(case_attr, sub_attr, val)
304
291
  return case_attr
305
292
 
306
293
 
307
- def show_current_and_corner_cases(case: Any, targets: Optional[Union[List[CaseAttribute], List[SQLColumn]]] = None,
308
- current_conclusions: Optional[Union[List[CaseAttribute], List[SQLColumn]]] = None,
294
+ def show_current_and_corner_cases(case: Any, targets: Optional[Dict[str, Any]] = None,
295
+ current_conclusions: Optional[Dict[str, Any]] = None,
309
296
  last_evaluated_rule: Optional[Rule] = None) -> None:
310
297
  """
311
298
  Show the data of the new case and if last evaluated rule exists also show that of the corner case.
@@ -316,12 +303,8 @@ def show_current_and_corner_cases(case: Any, targets: Optional[Union[List[CaseAt
316
303
  :param last_evaluated_rule: The last evaluated rule in the RDR.
317
304
  """
318
305
  corner_case = None
319
- if targets:
320
- targets = targets if isinstance(targets, list) else [targets]
321
- if current_conclusions:
322
- current_conclusions = current_conclusions if isinstance(current_conclusions, list) else [current_conclusions]
323
- targets = {f"target_{t.__class__.__name__}": t for t in targets} if targets else {}
324
- current_conclusions = {c.__class__.__name__: c for c in current_conclusions} if current_conclusions else {}
306
+ targets = {f"target_{name}": value for name, value in targets.items()} if targets else {}
307
+ current_conclusions = {name: value for name, value in current_conclusions.items} if current_conclusions else {}
325
308
  if last_evaluated_rule:
326
309
  action = "Refinement" if last_evaluated_rule.fired else "Alternative"
327
310
  print(f"{action} needed for rule: {last_evaluated_rule}\n")
@@ -1,13 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import inspect
3
4
  from dataclasses import dataclass
4
5
 
5
6
  from sqlalchemy.orm import DeclarativeBase as SQLTable
6
- from typing_extensions import Any, Optional, Type
7
+ from typing_extensions import Any, Optional, Type, List, Tuple, Set, Dict, TYPE_CHECKING
7
8
 
8
9
  from .case import create_case, Case
9
- from ..utils import get_attribute_name, copy_case
10
+ from ..utils import get_attribute_name, copy_case, get_hint_for_attribute, typing_to_python_type
10
11
 
12
+ if TYPE_CHECKING:
13
+ from . import CallableExpression
11
14
 
12
15
  @dataclass
13
16
  class CaseQuery:
@@ -20,56 +23,93 @@ class CaseQuery:
20
23
  """
21
24
  The case that the attribute belongs to.
22
25
  """
23
- attribute: Optional[Any] = None
26
+ attribute_name: str
24
27
  """
25
- The attribute itself.
28
+ The name of the attribute.
26
29
  """
27
- targets: Optional[Any] = None
30
+ target: Optional[Any] = None
28
31
  """
29
32
  The target value of the attribute.
30
33
  """
31
- attribute_name: Optional[str] = None
34
+ mutually_exclusive: bool = False
32
35
  """
33
- The name of the attribute.
36
+ Whether the attribute can only take one value (i.e. True) or multiple values (i.e. False).
37
+ """
38
+ conditions: Optional[CallableExpression] = None
34
39
  """
35
- attribute_type: Optional[Type] = None
40
+ The conditions that must be satisfied for the target value to be valid.
36
41
  """
37
- The type of the attribute.
42
+ prediction: Optional[CallableExpression] = None
38
43
  """
39
- relational_representation: Optional[str] = None
44
+ The predicted value of the attribute.
40
45
  """
41
- The representation of the target value in relational form.
46
+ scope: Optional[Dict[str, Any]] = None
47
+ """
48
+ The global scope of the case query. This is used to evaluate the conditions and prediction, and is what is available
49
+ to the user when they are prompted for input. If it is not provided, it will be set to the global scope of the
50
+ caller.
42
51
  """
43
52
 
44
- def __init__(self, case: Any, attribute: Optional[Any] = None, target: Optional[Any] = None,
45
- attribute_name: Optional[str] = None, attribute_type: Optional[Type] = None,
46
- relational_representation: Optional[str] = None):
53
+ def __init__(self, case: Any, attribute_name: str,
54
+ target: Optional[Any] = None,
55
+ mutually_exclusive: bool = False,
56
+ conditions: Optional[CallableExpression] = None,
57
+ prediction: Optional[CallableExpression] = None,
58
+ scope: Optional[Dict[str, Any]] = None,):
59
+ self.original_case = case
60
+ self.case = self._get_case()
47
61
 
48
- if attribute_name is None:
49
- attribute_name = get_attribute_name(case, attribute, attribute_type, target)
50
62
  self.attribute_name = attribute_name
63
+ self.target = target
64
+ self.attribute_type = self._get_attribute_type()
65
+ self.mutually_exclusive = mutually_exclusive
66
+ self.conditions = conditions
67
+ self.prediction = prediction
68
+ self.scope = scope if scope is not None else inspect.currentframe().f_back.f_globals
51
69
 
52
- if not isinstance(case, (Case, SQLTable)):
53
- case = create_case(case, max_recursion_idx=3)
54
- self.case = case
70
+ def _get_case(self) -> Any:
71
+ if not isinstance(self.original_case, (Case, SQLTable)):
72
+ return create_case(self.original_case, max_recursion_idx=3)
73
+ else:
74
+ return self.original_case
55
75
 
56
- self.attribute = getattr(self.case, self.attribute_name) if self.attribute_name else None
57
- self.attribute_type = type(self.attribute) if self.attribute else None
58
- self.target = target
59
- self.relational_representation = relational_representation
76
+ def _get_attribute_type(self) -> Type:
77
+ """
78
+ :return: The type of the attribute.
79
+ """
80
+ if self.target is not None:
81
+ return type(self.target)
82
+ elif hasattr(self.original_case, self.attribute_name):
83
+ hint, origin, args = get_hint_for_attribute(self.attribute_name, self.original_case)
84
+ if origin is not None:
85
+ return typing_to_python_type(origin)
86
+ elif hint is not None:
87
+ return typing_to_python_type(hint)
60
88
 
61
89
  @property
62
90
  def name(self):
63
- return self.attribute_name if self.attribute_name else self.__class__.__name__
91
+ """
92
+ :return: The name of the case query.
93
+ """
94
+ return f"{self.case_name}.{self.attribute_name}"
95
+
96
+ @property
97
+ def case_name(self) -> str:
98
+ """
99
+ :return: The name of the case.
100
+ """
101
+ return self.case._name if isinstance(self.case, Case) else self.case.__class__.__name__
64
102
 
65
103
  def __str__(self):
66
- if self.relational_representation:
67
- return f"{self.name} |= {self.relational_representation}"
68
- else:
69
- return f"{self.target}"
104
+ header = f"CaseQuery: {self.name}"
105
+ target = f"Target: {self.name} |= {self.target if self.target is not None else '?'}"
106
+ prediction = f"Prediction: {self.name} |= {self.prediction if self.prediction is not None else '?'}"
107
+ conditions = f"Conditions: {self.conditions if self.conditions is not None else '?'}"
108
+ return "\n".join([header, target, prediction, conditions])
70
109
 
71
110
  def __repr__(self):
72
111
  return self.__str__()
73
112
 
74
113
  def __copy__(self):
75
- return CaseQuery(copy_case(self.case), attribute_name=self.attribute_name, target=self.target)
114
+ return CaseQuery(copy_case(self.case), self.attribute_name, self.target, self.mutually_exclusive,
115
+ self.conditions, self.prediction, self.scope)