ripple-down-rules 0.0.15__tar.gz → 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/PKG-INFO +1 -1
  2. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/pyproject.toml +1 -1
  3. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/src/ripple_down_rules/datasets.py +2 -2
  4. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/src/ripple_down_rules/datastructures/callable_expression.py +52 -10
  5. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/src/ripple_down_rules/datastructures/case.py +53 -70
  6. ripple_down_rules-0.1.0/src/ripple_down_rules/datastructures/dataclasses.py +115 -0
  7. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/src/ripple_down_rules/experts.py +29 -40
  8. ripple_down_rules-0.1.0/src/ripple_down_rules/helpers.py +27 -0
  9. ripple_down_rules-0.1.0/src/ripple_down_rules/prompt.py +154 -0
  10. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/src/ripple_down_rules/rdr.py +214 -192
  11. ripple_down_rules-0.1.0/src/ripple_down_rules/rdr_decorators.py +55 -0
  12. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/src/ripple_down_rules/rules.py +7 -2
  13. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/src/ripple_down_rules/utils.py +154 -3
  14. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
  15. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/src/ripple_down_rules.egg-info/SOURCES.txt +3 -0
  16. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/test/test_json_serialization.py +25 -1
  17. ripple_down_rules-0.1.0/test/test_on_mutagenic.py +200 -0
  18. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/test/test_rdr.py +27 -123
  19. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/test/test_rdr_alchemy.py +18 -16
  20. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/test/test_relational_rdr.py +8 -8
  21. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/test/test_relational_rdr_alchemy.py +19 -19
  22. ripple_down_rules-0.0.15/src/ripple_down_rules/datastructures/dataclasses.py +0 -75
  23. ripple_down_rules-0.0.15/src/ripple_down_rules/prompt.py +0 -101
  24. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/LICENSE +0 -0
  25. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/README.md +0 -0
  26. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/setup.cfg +0 -0
  27. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/src/ripple_down_rules/__init__.py +0 -0
  28. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/src/ripple_down_rules/datastructures/__init__.py +0 -0
  29. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/src/ripple_down_rules/datastructures/enums.py +0 -0
  30. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/src/ripple_down_rules/failures.py +0 -0
  31. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
  32. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
  33. {ripple_down_rules-0.0.15 → ripple_down_rules-0.1.0}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.0.15
3
+ Version: 0.1.0
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
6
6
 
7
7
  [project]
8
8
  name = "ripple_down_rules"
9
- version = "0.0.15"
9
+ version = "0.1.0"
10
10
  description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
11
11
  readme = "README.md"
12
12
  authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
@@ -9,7 +9,7 @@ from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationshi
9
9
  from typing_extensions import Tuple, List, Set, Optional
10
10
  from ucimlrepo import fetch_ucirepo
11
11
 
12
- from .datastructures import Case, create_cases_from_dataframe, Category, CaseAttribute
12
+ from .datastructures import Case, create_cases_from_dataframe, Category
13
13
 
14
14
 
15
15
  def load_cached_dataset(cache_file):
@@ -77,7 +77,7 @@ def load_zoo_dataset(cache_file: Optional[str] = None) -> Tuple[List[Case], List
77
77
  y = zoo['targets']
78
78
  # get ids as list of strings
79
79
  ids = zoo['ids'].values.flatten()
80
- all_cases = create_cases_from_dataframe(X)
80
+ all_cases = create_cases_from_dataframe(X, name="Animal")
81
81
 
82
82
  category_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "molusc"]
83
83
  category_id_to_name = {i + 1: name for i, name in enumerate(category_names)}
@@ -22,10 +22,22 @@ class VariableVisitor(ast.NodeVisitor):
22
22
  def __init__(self):
23
23
  self.variables = set()
24
24
  self.attributes: Dict[ast.Name, ast.Attribute] = {}
25
+ self.types = set()
26
+ self.callables = set()
25
27
  self.compares = list()
26
28
  self.binary_ops = list()
27
29
  self.all = list()
28
30
 
31
+ def visit_Constant(self, node):
32
+ self.all.append(node)
33
+ self.types.add(node)
34
+ self.generic_visit(node)
35
+
36
+ def visit_Call(self, node):
37
+ self.all.append(node)
38
+ self.callables.add(node)
39
+ self.generic_visit(node)
40
+
29
41
  def visit_Attribute(self, node):
30
42
  self.all.append(node)
31
43
  self.attributes[node.value] = node
@@ -51,6 +63,27 @@ class VariableVisitor(ast.NodeVisitor):
51
63
  self.generic_visit(node)
52
64
 
53
65
 
66
+ def get_used_scope(code_str, scope):
67
+ # Parse the code into an AST
68
+ tree = ast.parse(code_str, mode='eval')
69
+
70
+ # Walk the AST to collect used variable names
71
+ class NameCollector(ast.NodeVisitor):
72
+ def __init__(self):
73
+ self.names = set()
74
+
75
+ def visit_Name(self, node):
76
+ if isinstance(node.ctx, ast.Load): # We care only about variables being read
77
+ self.names.add(node.id)
78
+
79
+ collector = NameCollector()
80
+ collector.visit(tree)
81
+
82
+ # Filter the scope to include only used names
83
+ used_scope = {k: scope[k] for k in collector.names if k in scope}
84
+ return used_scope
85
+
86
+
54
87
  class CallableExpression(SubclassJSONSerializer):
55
88
  """
56
89
  A callable that is constructed from a string statement written by an expert.
@@ -85,7 +118,7 @@ class CallableExpression(SubclassJSONSerializer):
85
118
  """
86
119
 
87
120
  def __init__(self, user_input: str, conclusion_type: Optional[Type] = None, expression_tree: Optional[AST] = None,
88
- session: Optional[Session] = None):
121
+ session: Optional[Session] = None, scope: Optional[Dict[str, Any]] = None):
89
122
  """
90
123
  Create a callable expression.
91
124
 
@@ -98,8 +131,17 @@ class CallableExpression(SubclassJSONSerializer):
98
131
  self.user_input: str = user_input
99
132
  self.parsed_user_input = self.parse_user_input(user_input, session)
100
133
  self.conclusion_type = conclusion_type
134
+ self.scope: Optional[Dict[str, Any]] = scope if scope is not None else {}
135
+ self.scope = get_used_scope(self.parsed_user_input, self.scope)
101
136
  self.update_expression(self.parsed_user_input, expression_tree)
102
137
 
138
+ def get_used_scope_in_user_input(self) -> Set[str]:
139
+ """
140
+ Get the used scope in the user input.
141
+ :return: The used scope in the user input.
142
+ """
143
+ return self.visitor.variables.union(self.visitor.attributes.keys())
144
+
103
145
  @staticmethod
104
146
  def parse_user_input(user_input: str, session: Optional[Session] = None) -> str:
105
147
  if ',' in user_input:
@@ -117,11 +159,6 @@ class CallableExpression(SubclassJSONSerializer):
117
159
  self.expression_tree: AST = expression_tree
118
160
  self.visitor = VariableVisitor()
119
161
  self.visitor.visit(expression_tree)
120
- variables_str = self.visitor.variables
121
- attributes_str = get_attributes_str(self.visitor)
122
- for v in variables_str | attributes_str:
123
- if not v.startswith("case."):
124
- self.parsed_user_input = self.parsed_user_input.replace(v, f"case.{v}")
125
162
  self.expression_tree = parse_string_to_expression(self.parsed_user_input)
126
163
  self.compares_column_offset = [(c[0].col_offset, c[2].end_col_offset) for c in self.visitor.compares]
127
164
  self.code = compile_expression_to_code(self.expression_tree)
@@ -130,8 +167,9 @@ class CallableExpression(SubclassJSONSerializer):
130
167
  try:
131
168
  if not isinstance(case, Case):
132
169
  case = create_case(case, max_recursion_idx=3)
133
- output = eval(self.code)
134
- if self.conclusion_type:
170
+ scope = {'case': case, **self.scope}
171
+ output = eval(self.code, scope)
172
+ if self.conclusion_type is not None:
135
173
  assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
136
174
  f" got {type(output)}")
137
175
  return output
@@ -163,11 +201,15 @@ class CallableExpression(SubclassJSONSerializer):
163
201
  return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else user_input
164
202
 
165
203
  def _to_json(self) -> Dict[str, Any]:
166
- return {"user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type)}
204
+ return {"user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type),
205
+ "scope": {k: get_full_class_name(v) for k, v in self.scope.items()
206
+ if hasattr(v, '__module__') and hasattr(v, '__name__')}
207
+ }
167
208
 
168
209
  @classmethod
169
210
  def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
170
- return cls(user_input=data["user_input"], conclusion_type=get_type_from_string(data["conclusion_type"]))
211
+ return cls(user_input=data["user_input"], conclusion_type=get_type_from_string(data["conclusion_type"]),
212
+ scope={k: get_type_from_string(v) for k, v in data["scope"].items()})
171
213
 
172
214
 
173
215
  def compile_expression_to_code(expression_tree: AST) -> Any:
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from collections import UserDict
4
4
  from copy import copy, deepcopy
5
- from dataclasses import dataclass
5
+ from dataclasses import dataclass, is_dataclass
6
6
  from enum import Enum
7
7
 
8
8
  from pandas import DataFrame
@@ -11,7 +11,7 @@ from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColum
11
11
  from typing_extensions import Any, Optional, Dict, Type, Set, Hashable, Union, List, TYPE_CHECKING
12
12
 
13
13
  from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, SubclassJSONSerializer, \
14
- get_full_class_name, get_type_from_string
14
+ get_full_class_name, get_type_from_string, make_list, is_iterable, serialize_dataclass
15
15
 
16
16
  if TYPE_CHECKING:
17
17
  from ripple_down_rules.rules import Rule
@@ -24,16 +24,19 @@ class Case(UserDict, SubclassJSONSerializer):
24
24
  the names of the attributes and the values are the attributes. All are stored in lower case.
25
25
  """
26
26
 
27
- def __init__(self, _id: Optional[Hashable] = None, _type: Optional[Type] = None, **kwargs):
27
+ def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None, _name: Optional[str] = None, **kwargs):
28
28
  """
29
29
  Create a new row.
30
30
 
31
+ :param _obj_type: The type of the object that the row represents.
31
32
  :param _id: The id of the row.
33
+ :param _name: The semantic name that describes the row.
32
34
  :param kwargs: The attributes of the row.
33
35
  """
34
36
  super().__init__(kwargs)
35
- self._id = _id if _id else id(self)
36
- self._type = _type
37
+ self._obj_type: Type = _obj_type
38
+ self._id: Hashable = _id if _id is not None else id(self)
39
+ self._name: str = _name if _name is not None else self._obj_type.__name__
37
40
 
38
41
  @classmethod
39
42
  def from_obj(cls, obj: Any, obj_name: Optional[str] = None, max_recursion_idx: int = 3) -> Case:
@@ -53,13 +56,14 @@ class Case(UserDict, SubclassJSONSerializer):
53
56
  def __setitem__(self, name: str, value: Any):
54
57
  name = name.lower()
55
58
  if name in self:
56
- if isinstance(self[name], set):
57
- self[name].update(make_set(value))
58
- elif isinstance(value, set):
59
- value.update(make_set(self[name]))
60
- super().__setitem__(name, value)
59
+ if isinstance(self[name], list):
60
+ self[name].extend(make_list(value))
61
+ elif isinstance(value, list):
62
+ new_list = make_list(self[name])
63
+ new_list.extend(make_list(value))
64
+ super().__setitem__(name, new_list)
61
65
  else:
62
- super().__setitem__(name, make_set([self[name], value]))
66
+ super().__setitem__(name, [self[name], value])
63
67
  else:
64
68
  super().__setitem__(name, value)
65
69
  setattr(self, name, self[name])
@@ -78,18 +82,24 @@ class Case(UserDict, SubclassJSONSerializer):
78
82
  def _to_json(self) -> Dict[str, Any]:
79
83
  serializable = {k: v for k, v in self.items() if not k.startswith("_")}
80
84
  serializable["_id"] = self._id
85
+ serializable["_obj_type"] = get_full_class_name(self._obj_type)
86
+ serializable["_name"] = self._name
81
87
  for k, v in serializable.items():
82
88
  if isinstance(v, set):
83
- serializable[k] = {'_type': get_full_class_name(set), 'value': list(v)}
89
+ serializable[k] = {'_type': get_full_class_name(set), 'value': serialize_dataclass(list(v))}
90
+ else:
91
+ serializable[k] = serialize_dataclass(v)
84
92
  return {k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in serializable.items()}
85
93
 
86
94
  @classmethod
87
95
  def _from_json(cls, data: Dict[str, Any]) -> Case:
88
96
  id_ = data.pop("_id")
97
+ obj_type = get_type_from_string(data.pop("_obj_type"))
98
+ name = data.pop("_name")
89
99
  for k, v in data.items():
90
100
  if isinstance(v, dict) and "_type" in v:
91
101
  data[k] = SubclassJSONSerializer.from_json(v)
92
- return cls(_id=id_, **data)
102
+ return cls(_obj_type=obj_type, _id=id_, _name=name, **data)
93
103
 
94
104
 
95
105
  @dataclass
@@ -122,7 +132,7 @@ class CaseAttributeValue(SubclassJSONSerializer):
122
132
  return cls(id=data["id"], value=data["value"])
123
133
 
124
134
 
125
- class CaseAttribute(set, SubclassJSONSerializer):
135
+ class CaseAttribute(list, SubclassJSONSerializer):
126
136
  nullable: bool = True
127
137
  """
128
138
  A boolean indicating whether the case attribute can be None or not.
@@ -132,33 +142,9 @@ class CaseAttribute(set, SubclassJSONSerializer):
132
142
  A boolean indicating whether the case attribute is mutually exclusive or not. (i.e. can only have one value)
133
143
  """
134
144
 
135
- def __init__(self, values: Set[CaseAttributeValue]):
136
- """
137
- Create a new case attribute.
138
-
139
- :param values: The values of the case attribute.
140
- """
141
- values = self._type_cast_values_to_set_of_case_attribute_values(values)
142
- self.id_value_map: Dict[Hashable, Union[CaseAttributeValue, Set[CaseAttributeValue]]] = {id(v): v for v in values}
143
- super().__init__([v.value for v in values])
144
-
145
- @staticmethod
146
- def _type_cast_values_to_set_of_case_attribute_values(values: Set[Any]) -> Set[CaseAttributeValue]:
147
- """
148
- Type cast values to a set of case attribute values.
149
-
150
- :param values: The values to type cast.
151
- """
152
- values = make_set(values)
153
- if len(values) > 0 and not isinstance(next(iter(values)), CaseAttributeValue):
154
- values = {CaseAttributeValue(id(values), v) for v in values}
155
- return values
156
-
157
145
  @classmethod
158
- def from_obj(cls, values: Set[Any], row_obj: Optional[Any] = None) -> CaseAttribute:
159
- id_ = id(row_obj) if row_obj else id(values)
160
- values = make_set(values)
161
- return cls({CaseAttributeValue(id_, v) for v in values})
146
+ def from_obj(cls, values: List[Any]) -> CaseAttribute:
147
+ return cls(make_list(values))
162
148
 
163
149
  @property
164
150
  def as_dict(self) -> Dict[str, Any]:
@@ -176,42 +162,43 @@ class CaseAttribute(set, SubclassJSONSerializer):
176
162
  :param condition: The condition to filter by.
177
163
  :return: The filtered column.
178
164
  """
179
- return self.__class__({v for v in self if condition(v)})
165
+ return self.__class__([v for v in self if condition(v)])
180
166
 
181
167
  def __eq__(self, other):
182
- if not isinstance(other, set):
183
- return super().__eq__(make_set(other))
168
+ if not isinstance(other, list):
169
+ return super().__eq__(make_list(other))
184
170
  return super().__eq__(other)
185
171
 
186
172
  def __hash__(self):
187
- return hash(tuple(self.id_value_map.values()))
173
+ return hash(id(self))
188
174
 
189
175
  def __str__(self):
190
176
  if len(self) == 0:
191
177
  return "None"
192
- return str({v for v in self}) if len(self) > 1 else str(next(iter(self)))
178
+ return str([v for v in self]) if len(self) > 1 else str(next(iter(self)))
193
179
 
194
180
  def _to_json(self) -> Dict[str, Any]:
195
- return {id_: v.to_json() if isinstance(v, SubclassJSONSerializer) else v
196
- for id_, v in self.id_value_map.items()}
181
+ return {str(i): v.to_json() if isinstance(v, SubclassJSONSerializer) else v
182
+ for i, v in enumerate(self)}
197
183
 
198
184
  @classmethod
199
185
  def _from_json(cls, data: Dict[str, Any]) -> CaseAttribute:
200
- return cls({CaseAttributeValue.from_json(v) for id_, v in data.items()})
186
+ return cls([SubclassJSONSerializer.from_json(v) for _, v in data.items()])
201
187
 
202
188
 
203
- def create_cases_from_dataframe(df: DataFrame) -> List[Case]:
189
+ def create_cases_from_dataframe(df: DataFrame, name: Optional[str] = None) -> List[Case]:
204
190
  """
205
191
  Create cases from a pandas DataFrame.
206
192
 
207
193
  :param df: The DataFrame to create cases from.
194
+ :param name: The semantic name of the DataFrame that describes the DataFrame.
208
195
  :return: The cases of the DataFrame.
209
196
  """
210
197
  cases = []
211
198
  attribute_names = list(df.columns)
212
199
  for row_id, case in df.iterrows():
213
200
  case = {col_name: case[col_name].item() for col_name in attribute_names}
214
- cases.append(Case(_id=row_id, _type=DataFrame, **case))
201
+ cases.append(Case(DataFrame, _id=row_id, _name=name, **case))
215
202
  return cases
216
203
 
217
204
 
@@ -228,19 +215,19 @@ def create_case(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
228
215
  :return: The case that represents the object.
229
216
  """
230
217
  if isinstance(obj, DataFrame):
231
- return create_cases_from_dataframe(obj)
218
+ return create_cases_from_dataframe(obj, obj_name)
232
219
  if isinstance(obj, Case):
233
220
  return obj
234
221
  if ((recursion_idx > max_recursion_idx) or (obj.__class__.__module__ == "builtins")
235
222
  or (obj.__class__ in [MetaData, registry])):
236
- return Case(_id=id(obj), _type=obj.__class__,
237
- **{obj_name or obj.__class__.__name__: make_set(obj) if parent_is_iterable else obj})
238
- case = Case(_id=id(obj), _type=obj.__class__)
223
+ return Case(type(obj), _id=id(obj), _name=obj_name,
224
+ **{obj_name or obj.__class__.__name__: make_list(obj) if parent_is_iterable else obj})
225
+ case = Case(type(obj), _id=id(obj), _name=obj_name)
239
226
  for attr in dir(obj):
240
227
  if attr.startswith("_") or callable(getattr(obj, attr)):
241
228
  continue
242
229
  attr_value = getattr(obj, attr)
243
- case = create_or_update_case_from_attribute(attr_value, attr, obj, attr, recursion_idx,
230
+ case = create_or_update_case_from_attribute(attr_value, attr, obj, obj_name, recursion_idx,
244
231
  max_recursion_idx, parent_is_iterable, case)
245
232
  return case
246
233
 
@@ -263,16 +250,16 @@ def create_or_update_case_from_attribute(attr_value: Any, name: str, obj: Any, o
263
250
  :return: The updated/created case.
264
251
  """
265
252
  if case is None:
266
- case = Case(_id=id(obj), _type=obj.__class__)
253
+ case = Case(type(obj), _id=id(obj), _name=obj_name)
267
254
  if isinstance(attr_value, (dict, UserDict)):
268
255
  case.update({f"{obj_name}.{k}": v for k, v in attr_value.items()})
269
256
  if hasattr(attr_value, "__iter__") and not isinstance(attr_value, str):
270
257
  column = create_case_attribute_from_iterable_attribute(attr_value, name, obj, obj_name,
271
258
  recursion_idx=recursion_idx + 1,
272
259
  max_recursion_idx=max_recursion_idx)
273
- case[obj_name] = column
260
+ case[name] = column
274
261
  else:
275
- case[obj_name] = make_set(attr_value) if parent_is_iterable else attr_value
262
+ case[name] = make_list(attr_value) if parent_is_iterable else attr_value
276
263
  return case
277
264
 
278
265
 
@@ -290,22 +277,22 @@ def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, ob
290
277
  :param max_recursion_idx: The maximum recursion index.
291
278
  :return: A case attribute that represents the original iterable attribute.
292
279
  """
293
- values = attr_value.values() if isinstance(attr_value, (dict, UserDict)) else attr_value
280
+ values = list(attr_value.values()) if isinstance(attr_value, (dict, UserDict)) else attr_value
294
281
  _type = type(list(values)[0]) if len(values) > 0 else get_value_type_from_type_hint(name, obj)
295
- attr_case = Case(_id=id(attr_value), _type=_type)
296
- case_attr = CaseAttribute.from_obj(values, row_obj=obj)
282
+ attr_case = Case(_type, _id=id(attr_value), _name=name)
283
+ case_attr = CaseAttribute(values)
297
284
  for idx, val in enumerate(values):
298
285
  sub_attr_case = create_case(val, recursion_idx=recursion_idx,
299
286
  max_recursion_idx=max_recursion_idx,
300
- obj_name=obj_name, parent_is_iterable=True)
287
+ obj_name=name, parent_is_iterable=True)
301
288
  attr_case.update(sub_attr_case)
302
289
  for sub_attr, val in attr_case.items():
303
290
  setattr(case_attr, sub_attr, val)
304
291
  return case_attr
305
292
 
306
293
 
307
- def show_current_and_corner_cases(case: Any, targets: Optional[Union[List[CaseAttribute], List[SQLColumn]]] = None,
308
- current_conclusions: Optional[Union[List[CaseAttribute], List[SQLColumn]]] = None,
294
+ def show_current_and_corner_cases(case: Any, targets: Optional[Dict[str, Any]] = None,
295
+ current_conclusions: Optional[Dict[str, Any]] = None,
309
296
  last_evaluated_rule: Optional[Rule] = None) -> None:
310
297
  """
311
298
  Show the data of the new case and if last evaluated rule exists also show that of the corner case.
@@ -316,12 +303,8 @@ def show_current_and_corner_cases(case: Any, targets: Optional[Union[List[CaseAt
316
303
  :param last_evaluated_rule: The last evaluated rule in the RDR.
317
304
  """
318
305
  corner_case = None
319
- if targets:
320
- targets = targets if isinstance(targets, list) else [targets]
321
- if current_conclusions:
322
- current_conclusions = current_conclusions if isinstance(current_conclusions, list) else [current_conclusions]
323
- targets = {f"target_{t.__class__.__name__}": t for t in targets} if targets else {}
324
- current_conclusions = {c.__class__.__name__: c for c in current_conclusions} if current_conclusions else {}
306
+ targets = {f"target_{name}": value for name, value in targets.items()} if targets else {}
307
+ current_conclusions = {name: value for name, value in current_conclusions.items} if current_conclusions else {}
325
308
  if last_evaluated_rule:
326
309
  action = "Refinement" if last_evaluated_rule.fired else "Alternative"
327
310
  print(f"{action} needed for rule: {last_evaluated_rule}\n")
@@ -0,0 +1,115 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from dataclasses import dataclass
5
+
6
+ from sqlalchemy.orm import DeclarativeBase as SQLTable
7
+ from typing_extensions import Any, Optional, Type, List, Tuple, Set, Dict, TYPE_CHECKING
8
+
9
+ from .case import create_case, Case
10
+ from ..utils import get_attribute_name, copy_case, get_hint_for_attribute, typing_to_python_type
11
+
12
+ if TYPE_CHECKING:
13
+ from . import CallableExpression
14
+
15
+ @dataclass
16
+ class CaseQuery:
17
+ """
18
+ This is a dataclass that represents an attribute of an object and its target value. If attribute name is
19
+ not provided, it will be inferred from the attribute itself or from the attribute type or from the target value,
20
+ depending on what is provided.
21
+ """
22
+ case: Any
23
+ """
24
+ The case that the attribute belongs to.
25
+ """
26
+ attribute_name: str
27
+ """
28
+ The name of the attribute.
29
+ """
30
+ target: Optional[Any] = None
31
+ """
32
+ The target value of the attribute.
33
+ """
34
+ mutually_exclusive: bool = False
35
+ """
36
+ Whether the attribute can only take one value (i.e. True) or multiple values (i.e. False).
37
+ """
38
+ conditions: Optional[CallableExpression] = None
39
+ """
40
+ The conditions that must be satisfied for the target value to be valid.
41
+ """
42
+ prediction: Optional[CallableExpression] = None
43
+ """
44
+ The predicted value of the attribute.
45
+ """
46
+ scope: Optional[Dict[str, Any]] = None
47
+ """
48
+ The global scope of the case query. This is used to evaluate the conditions and prediction, and is what is available
49
+ to the user when they are prompted for input. If it is not provided, it will be set to the global scope of the
50
+ caller.
51
+ """
52
+
53
+ def __init__(self, case: Any, attribute_name: str,
54
+ target: Optional[Any] = None,
55
+ mutually_exclusive: bool = False,
56
+ conditions: Optional[CallableExpression] = None,
57
+ prediction: Optional[CallableExpression] = None,
58
+ scope: Optional[Dict[str, Any]] = None,):
59
+ self.original_case = case
60
+ self.case = self._get_case()
61
+
62
+ self.attribute_name = attribute_name
63
+ self.target = target
64
+ self.attribute_type = self._get_attribute_type()
65
+ self.mutually_exclusive = mutually_exclusive
66
+ self.conditions = conditions
67
+ self.prediction = prediction
68
+ self.scope = scope if scope is not None else inspect.currentframe().f_back.f_globals
69
+
70
+ def _get_case(self) -> Any:
71
+ if not isinstance(self.original_case, (Case, SQLTable)):
72
+ return create_case(self.original_case, max_recursion_idx=3)
73
+ else:
74
+ return self.original_case
75
+
76
+ def _get_attribute_type(self) -> Type:
77
+ """
78
+ :return: The type of the attribute.
79
+ """
80
+ if self.target is not None:
81
+ return type(self.target)
82
+ elif hasattr(self.original_case, self.attribute_name):
83
+ hint, origin, args = get_hint_for_attribute(self.attribute_name, self.original_case)
84
+ if origin is not None:
85
+ return typing_to_python_type(origin)
86
+ elif hint is not None:
87
+ return typing_to_python_type(hint)
88
+
89
+ @property
90
+ def name(self):
91
+ """
92
+ :return: The name of the case query.
93
+ """
94
+ return f"{self.case_name}.{self.attribute_name}"
95
+
96
+ @property
97
+ def case_name(self) -> str:
98
+ """
99
+ :return: The name of the case.
100
+ """
101
+ return self.case._name if isinstance(self.case, Case) else self.case.__class__.__name__
102
+
103
+ def __str__(self):
104
+ header = f"CaseQuery: {self.name}"
105
+ target = f"Target: {self.name} |= {self.target if self.target is not None else '?'}"
106
+ prediction = f"Prediction: {self.name} |= {self.prediction if self.prediction is not None else '?'}"
107
+ conditions = f"Conditions: {self.conditions if self.conditions is not None else '?'}"
108
+ return "\n".join([header, target, prediction, conditions])
109
+
110
+ def __repr__(self):
111
+ return self.__str__()
112
+
113
+ def __copy__(self):
114
+ return CaseQuery(copy_case(self.case), self.attribute_name, self.target, self.mutually_exclusive,
115
+ self.conditions, self.prediction, self.scope)