ripple-down-rules 0.1.21.tar.gz → 0.1.62.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/PKG-INFO +5 -4
  2. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/README.md +4 -3
  3. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/pyproject.toml +1 -1
  4. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/src/ripple_down_rules/datasets.py +2 -1
  5. ripple_down_rules-0.1.62/src/ripple_down_rules/datastructures/__init__.py +4 -0
  6. ripple_down_rules-0.1.62/src/ripple_down_rules/datastructures/callable_expression.py +223 -0
  7. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/src/ripple_down_rules/datastructures/case.py +8 -6
  8. ripple_down_rules-0.1.62/src/ripple_down_rules/datastructures/dataclasses.py +169 -0
  9. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/src/ripple_down_rules/datastructures/enums.py +5 -1
  10. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/src/ripple_down_rules/experts.py +61 -68
  11. ripple_down_rules-0.1.62/src/ripple_down_rules/helpers.py +51 -0
  12. ripple_down_rules-0.1.62/src/ripple_down_rules/prompt.py +167 -0
  13. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/src/ripple_down_rules/rdr.py +291 -206
  14. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/src/ripple_down_rules/rules.py +64 -32
  15. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/src/ripple_down_rules/utils.py +209 -4
  16. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/src/ripple_down_rules.egg-info/PKG-INFO +5 -4
  17. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/src/ripple_down_rules.egg-info/SOURCES.txt +1 -0
  18. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/test/test_json_serialization.py +2 -2
  19. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/test/test_on_mutagenic.py +5 -5
  20. ripple_down_rules-0.1.62/test/test_rdr.py +331 -0
  21. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/test/test_rdr_alchemy.py +24 -30
  22. ripple_down_rules-0.1.62/test/test_rdr_world.py +161 -0
  23. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/test/test_relational_rdr.py +7 -7
  24. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/test/test_relational_rdr_alchemy.py +8 -7
  25. ripple_down_rules-0.1.21/src/ripple_down_rules/datastructures/__init__.py +0 -4
  26. ripple_down_rules-0.1.21/src/ripple_down_rules/datastructures/callable_expression.py +0 -278
  27. ripple_down_rules-0.1.21/src/ripple_down_rules/datastructures/dataclasses.py +0 -115
  28. ripple_down_rules-0.1.21/src/ripple_down_rules/helpers.py +0 -27
  29. ripple_down_rules-0.1.21/src/ripple_down_rules/prompt.py +0 -154
  30. ripple_down_rules-0.1.21/test/test_rdr.py +0 -243
  31. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/LICENSE +0 -0
  32. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/setup.cfg +0 -0
  33. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/src/ripple_down_rules/__init__.py +0 -0
  34. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/src/ripple_down_rules/failures.py +0 -0
  35. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/src/ripple_down_rules/rdr_decorators.py +0 -0
  36. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
  37. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
  38. {ripple_down_rules-0.1.21 → ripple_down_rules-0.1.62}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.1.21
3
+ Version: 0.1.62
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -709,7 +709,7 @@ Fit the SCRDR to the data, then classify one of the data cases to check if its c
709
709
  and render the tree to a file:
710
710
 
711
711
  ```Python
712
- from ripple_down_rules.datastructures import CaseQuery
712
+ from ripple_down_rules.datastructures.dataclasses import CaseQuery
713
713
  from ripple_down_rules.rdr import SingleClassRDR
714
714
  from ripple_down_rules.datasets import load_zoo_dataset
715
715
  from ripple_down_rules.utils import render_tree
@@ -719,12 +719,13 @@ all_cases, targets = load_zoo_dataset()
719
719
  scrdr = SingleClassRDR()
720
720
 
721
721
  # Fit the SCRDR to the data
722
- case_queries = [CaseQuery(case, target=target) for case, target in zip(all_cases, targets)]
722
+ case_queries = [CaseQuery(case, 'species', type(target), True, _target=target)
723
+ for case, target in zip(all_cases[:10], targets[:10])]
723
724
  scrdr.fit(case_queries, animate_tree=True)
724
725
 
725
726
  # Render the tree to a file
726
727
  render_tree(scrdr.start_rule, use_dot_exporter=True, filename="scrdr")
727
728
 
728
- cat = scrdr.fit_case(all_cases[50], targets[50])
729
+ cat = scrdr.classify(all_cases[50])
729
730
  assert cat == targets[50]
730
731
  ```
@@ -22,7 +22,7 @@ Fit the SCRDR to the data, then classify one of the data cases to check if its c
22
22
  and render the tree to a file:
23
23
 
24
24
  ```Python
25
- from ripple_down_rules.datastructures import CaseQuery
25
+ from ripple_down_rules.datastructures.dataclasses import CaseQuery
26
26
  from ripple_down_rules.rdr import SingleClassRDR
27
27
  from ripple_down_rules.datasets import load_zoo_dataset
28
28
  from ripple_down_rules.utils import render_tree
@@ -32,12 +32,13 @@ all_cases, targets = load_zoo_dataset()
32
32
  scrdr = SingleClassRDR()
33
33
 
34
34
  # Fit the SCRDR to the data
35
- case_queries = [CaseQuery(case, target=target) for case, target in zip(all_cases, targets)]
35
+ case_queries = [CaseQuery(case, 'species', type(target), True, _target=target)
36
+ for case, target in zip(all_cases[:10], targets[:10])]
36
37
  scrdr.fit(case_queries, animate_tree=True)
37
38
 
38
39
  # Render the tree to a file
39
40
  render_tree(scrdr.start_rule, use_dot_exporter=True, filename="scrdr")
40
41
 
41
- cat = scrdr.fit_case(all_cases[50], targets[50])
42
+ cat = scrdr.classify(all_cases[50])
42
43
  assert cat == targets[50]
43
44
  ```
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
6
6
 
7
7
  [project]
8
8
  name = "ripple_down_rules"
9
- version = "0.1.21"
9
+ version = "0.1.62"
10
10
  description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
11
11
  readme = "README.md"
12
12
  authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
@@ -9,7 +9,8 @@ from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationshi
9
9
  from typing_extensions import Tuple, List, Set, Optional
10
10
  from ucimlrepo import fetch_ucirepo
11
11
 
12
- from .datastructures import Case, create_cases_from_dataframe, Category
12
+ from .datastructures.case import Case, create_cases_from_dataframe
13
+ from .datastructures.enums import Category
13
14
 
14
15
 
15
16
  def load_cached_dataset(cache_file):
@@ -0,0 +1,4 @@
1
+ # from .enums import *
2
+ # from .dataclasses import *
3
+ # from .callable_expression import *
4
+ # from .case import *
@@ -0,0 +1,223 @@
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import logging
5
+ from _ast import AST
6
+
7
+ from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set
8
+
9
+ from .case import create_case, Case
10
+ from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string, conclusion_to_json, is_iterable
11
+
12
+
13
+ class VariableVisitor(ast.NodeVisitor):
14
+ """
15
+ A visitor to extract all variables and comparisons from a python expression represented as an AST tree.
16
+ """
17
+ compares: List[Tuple[Union[ast.Name, ast.Call], ast.cmpop, Union[ast.Name, ast.Call]]]
18
+ variables: Set[str]
19
+ all: List[ast.BoolOp]
20
+
21
+ def __init__(self):
22
+ self.variables = set()
23
+ self.attributes: Dict[ast.Name, ast.Attribute] = {}
24
+ self.types = set()
25
+ self.callables = set()
26
+ self.compares = list()
27
+ self.binary_ops = list()
28
+ self.all = list()
29
+
30
+ def visit_Constant(self, node):
31
+ self.all.append(node)
32
+ self.types.add(node)
33
+ self.generic_visit(node)
34
+
35
+ def visit_Call(self, node):
36
+ self.all.append(node)
37
+ self.callables.add(node)
38
+ self.generic_visit(node)
39
+
40
+ def visit_Attribute(self, node):
41
+ self.all.append(node)
42
+ self.attributes[node.value] = node
43
+ self.generic_visit(node)
44
+
45
+ def visit_BinOp(self, node):
46
+ self.binary_ops.append(node)
47
+ self.all.append(node)
48
+ self.generic_visit(node)
49
+
50
+ def visit_BoolOp(self, node):
51
+ self.all.append(node)
52
+ self.generic_visit(node)
53
+
54
+ def visit_Compare(self, node):
55
+ self.all.append(node)
56
+ self.compares.append([node.left, node.ops[0], node.comparators[0]])
57
+ self.generic_visit(node)
58
+
59
+ def visit_Name(self, node):
60
+ if f"__{node.id}__" not in dir(__builtins__) and node not in self.attributes:
61
+ self.variables.add(node.id)
62
+ self.generic_visit(node)
63
+
64
+
65
+ def get_used_scope(code_str, scope):
66
+ # Parse the code into an AST
67
+ mode = 'exec' if code_str.startswith('def') else 'eval'
68
+ tree = ast.parse(code_str, mode=mode)
69
+
70
+ # Walk the AST to collect used variable names
71
+ class NameCollector(ast.NodeVisitor):
72
+ def __init__(self):
73
+ self.names = set()
74
+
75
+ def visit_Name(self, node):
76
+ if isinstance(node.ctx, ast.Load): # We care only about variables being read
77
+ self.names.add(node.id)
78
+
79
+ collector = NameCollector()
80
+ collector.visit(tree)
81
+
82
+ # Filter the scope to include only used names
83
+ used_scope = {k: scope[k] for k in collector.names if k in scope}
84
+ return used_scope
85
+
86
+
87
+ class CallableExpression(SubclassJSONSerializer):
88
+ """
89
+ A callable that is constructed from a string statement written by an expert.
90
+ """
91
+
92
+ def __init__(self, user_input: Optional[str] = None, conclusion_type: Optional[Tuple[Type]] = None,
93
+ expression_tree: Optional[AST] = None,
94
+ scope: Optional[Dict[str, Any]] = None, conclusion: Optional[Any] = None):
95
+ """
96
+ Create a callable expression.
97
+
98
+ :param user_input: The input given by the expert.
99
+ :param conclusion_type: The type of the output of the callable.
100
+ :param expression_tree: The AST tree parsed from the user input.
101
+ :param scope: The scope to use for the callable expression.
102
+ :param conclusion: The conclusion to use for the callable expression.
103
+ """
104
+ if user_input is None and conclusion is None:
105
+ raise ValueError("Either user_input or conclusion must be provided.")
106
+ self.conclusion: Optional[Any] = conclusion
107
+ self.user_input: str = user_input
108
+ if conclusion_type is not None:
109
+ if is_iterable(conclusion_type):
110
+ conclusion_type = tuple(conclusion_type)
111
+ else:
112
+ conclusion_type = (conclusion_type,)
113
+ self.conclusion_type = conclusion_type
114
+ self.scope: Optional[Dict[str, Any]] = scope if scope is not None else {}
115
+ if conclusion is None:
116
+ self.scope = get_used_scope(self.user_input, self.scope)
117
+ self.expression_tree: AST = expression_tree if expression_tree else parse_string_to_expression(self.user_input)
118
+ self.code = compile_expression_to_code(self.expression_tree)
119
+ self.visitor = VariableVisitor()
120
+ self.visitor.visit(self.expression_tree)
121
+
122
+ def __call__(self, case: Any, **kwargs) -> Any:
123
+ try:
124
+ if self.user_input is not None:
125
+ if not isinstance(case, Case):
126
+ case = create_case(case, max_recursion_idx=3)
127
+ scope = {'case': case, **self.scope}
128
+ output = eval(self.code, scope)
129
+ if output is None:
130
+ output = scope['_get_value'](case)
131
+ if self.conclusion_type is not None:
132
+ if is_iterable(output) and not isinstance(output, self.conclusion_type):
133
+ assert isinstance(list(output)[0], self.conclusion_type), (f"Expected output type {self.conclusion_type},"
134
+ f" got {type(output)}")
135
+ else:
136
+ assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
137
+ f" got {type(output)}")
138
+ return output
139
+ else:
140
+ return self.conclusion
141
+ except Exception as e:
142
+ raise ValueError(f"Error during evaluation: {e}")
143
+
144
+ def combine_with(self, other: 'CallableExpression') -> 'CallableExpression':
145
+ """
146
+ Combine this callable expression with another callable expression using the 'and' operator.
147
+ """
148
+ new_user_input = f"({self.user_input}) and ({other.user_input})"
149
+ return CallableExpression(new_user_input, conclusion_type=self.conclusion_type)
150
+
151
+ def __eq__(self, other):
152
+ """
153
+ Check if two callable expressions are equal.
154
+ """
155
+ if not isinstance(other, CallableExpression):
156
+ return False
157
+ return self.user_input == other.user_input and self.conclusion == other.conclusion
158
+
159
+ def __hash__(self):
160
+ """
161
+ Hash the callable expression.
162
+ """
163
+ conclusion_hash = self.conclusion if not isinstance(self.conclusion, set) else frozenset(self.conclusion)
164
+ return hash((self.user_input, conclusion_hash))
165
+
166
+ def __str__(self):
167
+ """
168
+ Return the user string where each compare is written in a line using compare column offset start and end.
169
+ """
170
+ if self.user_input is None:
171
+ return str(self.conclusion)
172
+ binary_ops = sorted(self.visitor.binary_ops, key=lambda x: x.end_col_offset)
173
+ binary_ops_indices = [b.end_col_offset for b in binary_ops]
174
+ all_binary_ops = []
175
+ prev_e = 0
176
+ for i, e in enumerate(binary_ops_indices):
177
+ if i == 0:
178
+ all_binary_ops.append(self.user_input[:e])
179
+ else:
180
+ all_binary_ops.append(self.user_input[prev_e:e])
181
+ prev_e = e
182
+ return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else self.user_input
183
+
184
+ def _to_json(self) -> Dict[str, Any]:
185
+ return {"user_input": self.user_input,
186
+ "conclusion_type": [get_full_class_name(t) for t in self.conclusion_type]
187
+ if self.conclusion_type is not None else None,
188
+ "scope": {k: get_full_class_name(v) for k, v in self.scope.items()
189
+ if hasattr(v, '__module__') and hasattr(v, '__name__')},
190
+ "conclusion": conclusion_to_json(self.conclusion),
191
+ }
192
+
193
+ @classmethod
194
+ def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
195
+ return cls(user_input=data["user_input"],
196
+ conclusion_type=tuple(get_type_from_string(t) for t in data["conclusion_type"])
197
+ if data["conclusion_type"] else None,
198
+ scope={k: get_type_from_string(v) for k, v in data["scope"].items()},
199
+ conclusion=SubclassJSONSerializer.from_json(data["conclusion"]))
200
+
201
+
202
+ def compile_expression_to_code(expression_tree: AST) -> Any:
203
+ """
204
+ Compile an expression tree that was parsed from string into code that can be executed using 'eval(code)'
205
+
206
+ :param expression_tree: The parsed expression tree.
207
+ :return: The code that was compiled from the expression tree.
208
+ """
209
+ mode = 'exec' if isinstance(expression_tree, ast.Module) else 'eval'
210
+ return compile(expression_tree, filename="<string>", mode=mode)
211
+
212
+
213
+ def parse_string_to_expression(expression_str: str) -> AST:
214
+ """
215
+ Parse a string statement into an AST expression.
216
+
217
+ :param expression_str: The string which will be parsed.
218
+ :return: The parsed expression.
219
+ """
220
+ mode = 'exec' if expression_str.startswith('def') else 'eval'
221
+ tree = ast.parse(expression_str, mode=mode)
222
+ logging.debug(f"AST parsed successfully: {ast.dump(tree)}")
223
+ return tree
@@ -24,7 +24,8 @@ class Case(UserDict, SubclassJSONSerializer):
24
24
  the names of the attributes and the values are the attributes. All are stored in lower case.
25
25
  """
26
26
 
27
- def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None, _name: Optional[str] = None, **kwargs):
27
+ def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None,
28
+ _name: Optional[str] = None, original_object: Optional[Any] = None, **kwargs):
28
29
  """
29
30
  Create a new row.
30
31
 
@@ -34,6 +35,7 @@ class Case(UserDict, SubclassJSONSerializer):
34
35
  :param kwargs: The attributes of the row.
35
36
  """
36
37
  super().__init__(kwargs)
38
+ self._original_object = original_object
37
39
  self._obj_type: Type = _obj_type
38
40
  self._id: Hashable = _id if _id is not None else id(self)
39
41
  self._name: str = _name if _name is not None else self._obj_type.__name__
@@ -63,7 +65,7 @@ class Case(UserDict, SubclassJSONSerializer):
63
65
  new_list.extend(make_list(value))
64
66
  super().__setitem__(name, new_list)
65
67
  else:
66
- super().__setitem__(name, [self[name], value])
68
+ super().__setitem__(name, self[name])
67
69
  else:
68
70
  super().__setitem__(name, value)
69
71
  setattr(self, name, self[name])
@@ -221,9 +223,9 @@ def create_case(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
221
223
  return obj
222
224
  if ((recursion_idx > max_recursion_idx) or (obj.__class__.__module__ == "builtins")
223
225
  or (obj.__class__ in [MetaData, registry])):
224
- return Case(type(obj), _id=id(obj), _name=obj_name,
226
+ return Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj,
225
227
  **{obj_name or obj.__class__.__name__: make_list(obj) if parent_is_iterable else obj})
226
- case = Case(type(obj), _id=id(obj), _name=obj_name)
228
+ case = Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj)
227
229
  for attr in dir(obj):
228
230
  if attr.startswith("_") or callable(getattr(obj, attr)):
229
231
  continue
@@ -251,7 +253,7 @@ def create_or_update_case_from_attribute(attr_value: Any, name: str, obj: Any, o
251
253
  :return: The updated/created case.
252
254
  """
253
255
  if case is None:
254
- case = Case(type(obj), _id=id(obj), _name=obj_name)
256
+ case = Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj)
255
257
  if isinstance(attr_value, (dict, UserDict)):
256
258
  case.update({f"{obj_name}.{k}": v for k, v in attr_value.items()})
257
259
  if hasattr(attr_value, "__iter__") and not isinstance(attr_value, str):
@@ -280,7 +282,7 @@ def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, ob
280
282
  """
281
283
  values = list(attr_value.values()) if isinstance(attr_value, (dict, UserDict)) else attr_value
282
284
  _type = type(list(values)[0]) if len(values) > 0 else get_value_type_from_type_hint(name, obj)
283
- attr_case = Case(_type, _id=id(attr_value), _name=name)
285
+ attr_case = Case(_type, _id=id(attr_value), _name=name, original_object=attr_value)
284
286
  case_attr = CaseAttribute(values)
285
287
  for idx, val in enumerate(values):
286
288
  sub_attr_case = create_case(val, recursion_idx=recursion_idx,
@@ -0,0 +1,169 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from dataclasses import dataclass, field
5
+
6
+ from sqlalchemy.orm import DeclarativeBase as SQLTable
7
+ from typing_extensions import Any, Optional, Dict, Type, Tuple, Union
8
+
9
+ from .callable_expression import CallableExpression
10
+ from .case import create_case, Case
11
+ from ..utils import copy_case, make_list, make_set
12
+
13
+
14
+ @dataclass
15
+ class CaseQuery:
16
+ """
17
+ This is a dataclass that represents an attribute of an object and its target value. If attribute name is
18
+ not provided, it will be inferred from the attribute itself or from the attribute type or from the target value,
19
+ depending on what is provided.
20
+ """
21
+ original_case: Any
22
+ """
23
+ The case that the attribute belongs to.
24
+ """
25
+ attribute_name: str
26
+ """
27
+ The name of the attribute.
28
+ """
29
+ _attribute_types: Tuple[Type]
30
+ """
31
+ The type(s) of the attribute.
32
+ """
33
+ mutually_exclusive: bool
34
+ """
35
+ Whether the attribute can only take one value (i.e. True) or multiple values (i.e. False).
36
+ """
37
+ _target: Optional[CallableExpression] = None
38
+ """
39
+ The target expression of the attribute.
40
+ """
41
+ default_value: Optional[Any] = None
42
+ """
43
+ The default value of the attribute. This is used when the target value is not provided.
44
+ """
45
+ scope: Optional[Dict[str, Any]] = field(default_factory=lambda: inspect.currentframe().f_back.f_back.f_globals)
46
+ """
47
+ The global scope of the case query. This is used to evaluate the conditions and prediction, and is what is available
48
+ to the user when they are prompted for input. If it is not provided, it will be set to the global scope of the
49
+ caller.
50
+ """
51
+ _case: Optional[Union[Case, SQLTable]] = None
52
+ """
53
+ The created case from the original case that the attribute belongs to.
54
+ """
55
+ _target_value: Optional[Any] = None
56
+ """
57
+ The target value of the case query. (This is the result of the target expression evaluation on the case.)
58
+ """
59
+ conditions: Optional[CallableExpression] = None
60
+ """
61
+ The conditions that must be satisfied for the target value to be valid.
62
+ """
63
+
64
+ @property
65
+ def case(self) -> Any:
66
+ """
67
+ :return: The case that the attribute belongs to.
68
+ """
69
+ if self._case is not None:
70
+ return self._case
71
+ elif not isinstance(self.original_case, (Case, SQLTable)):
72
+ self._case = create_case(self.original_case, max_recursion_idx=3)
73
+ else:
74
+ self._case = self.original_case
75
+ return self._case
76
+
77
+ @case.setter
78
+ def case(self, value: Any):
79
+ """
80
+ Set the case that the attribute belongs to.
81
+ """
82
+ if not isinstance(value, (Case, SQLTable)):
83
+ raise ValueError("The case must be a Case or SQLTable object.")
84
+ self._case = value
85
+
86
+ @property
87
+ def attribute_type(self) -> Tuple[Type]:
88
+ """
89
+ :return: The type of the attribute.
90
+ """
91
+ if not self.mutually_exclusive and (set not in make_list(self._attribute_types)):
92
+ self._attribute_types = tuple(set(make_list(self._attribute_types) + [set, list]))
93
+ elif not isinstance(self._attribute_types, tuple):
94
+ self._attribute_types = tuple(make_list(self._attribute_types))
95
+ return self._attribute_types
96
+
97
+ @attribute_type.setter
98
+ def attribute_type(self, value: Type):
99
+ """
100
+ Set the type of the attribute.
101
+ """
102
+ self._attribute_types = tuple(make_list(value))
103
+
104
+ @property
105
+ def name(self):
106
+ """
107
+ :return: The name of the case query.
108
+ """
109
+ return f"{self.case_name}.{self.attribute_name}"
110
+
111
+ @property
112
+ def case_name(self) -> str:
113
+ """
114
+ :return: The name of the case.
115
+ """
116
+ return self.case._name if isinstance(self.case, Case) else self.case.__class__.__name__
117
+
118
+ @property
119
+ def target(self) -> Optional[CallableExpression]:
120
+ """
121
+ :return: The target expression of the attribute.
122
+ """
123
+ if self._target is not None and not isinstance(self._target, CallableExpression):
124
+ self._target = CallableExpression(conclusion=self._target, conclusion_type=self.attribute_type,
125
+ scope=self.scope)
126
+ return self._target
127
+
128
+ @target.setter
129
+ def target(self, value: Optional[CallableExpression]):
130
+ """
131
+ Set the target expression of the attribute.
132
+ """
133
+ if value is not None and not isinstance(value, (CallableExpression, str)):
134
+ raise ValueError("The target must be a CallableExpression or a string.")
135
+ self._target = value
136
+ self._update_target_value()
137
+
138
+ @property
139
+ def target_value(self) -> Any:
140
+ """
141
+ :return: The target value of the case query.
142
+ """
143
+ if self._target_value is None:
144
+ self._update_target_value()
145
+ return self._target_value
146
+
147
+ def _update_target_value(self):
148
+ """
149
+ Update the target value of the case query.
150
+ """
151
+ if isinstance(self.target, CallableExpression):
152
+ self._target_value = self.target(self.case)
153
+ else:
154
+ self._target_value = self.target
155
+
156
+ def __str__(self):
157
+ header = f"CaseQuery: {self.name}"
158
+ target = f"Target: {self.name} |= {self.target if self.target is not None else '?'}"
159
+ conditions = f"Conditions: {self.conditions if self.conditions is not None else '?'}"
160
+ return "\n".join([header, target, conditions])
161
+
162
+ def __repr__(self):
163
+ return self.__str__()
164
+
165
+ def __copy__(self):
166
+ return CaseQuery(self.original_case, self.attribute_name, self.attribute_type,
167
+ self.mutually_exclusive, _target=self.target, default_value=self.default_value,
168
+ scope=self.scope, _case=copy_case(self.case), _target_value=self.target_value,
169
+ conditions=self.conditions)
@@ -53,7 +53,7 @@ class ExpressionParser(Enum):
53
53
 
54
54
  class PromptFor(Enum):
55
55
  """
56
- The reason of the prompt. (e.g. get conditions, or conclusions).
56
+ The reason of the prompt. (e.g. get conditions, conclusions, or affirmation).
57
57
  """
58
58
  Conditions: str = "conditions"
59
59
  """
@@ -63,6 +63,10 @@ class PromptFor(Enum):
63
63
  """
64
64
  Prompt for rule conclusion about a case.
65
65
  """
66
+ Affirmation: str = "affirmation"
67
+ """
68
+ Prompt for rule conclusion affirmation about a case.
69
+ """
66
70
 
67
71
  def __str__(self):
68
72
  return self.name