ripple-down-rules 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +0 -0
- ripple_down_rules/datasets.py +148 -0
- ripple_down_rules/datastructures/__init__.py +4 -0
- ripple_down_rules/datastructures/callable_expression.py +237 -0
- ripple_down_rules/datastructures/dataclasses.py +76 -0
- ripple_down_rules/datastructures/enums.py +173 -0
- ripple_down_rules/datastructures/generated/__init__.py +0 -0
- ripple_down_rules/datastructures/generated/column/__init__.py +0 -0
- ripple_down_rules/datastructures/generated/row/__init__.py +0 -0
- ripple_down_rules/datastructures/table.py +544 -0
- ripple_down_rules/experts.py +281 -0
- ripple_down_rules/failures.py +10 -0
- ripple_down_rules/prompt.py +101 -0
- ripple_down_rules/rdr.py +687 -0
- ripple_down_rules/rules.py +260 -0
- ripple_down_rules/utils.py +463 -0
- ripple_down_rules-0.0.0.dist-info/METADATA +54 -0
- ripple_down_rules-0.0.0.dist-info/RECORD +20 -0
- ripple_down_rules-0.0.0.dist-info/WHEEL +5 -0
- ripple_down_rules-0.0.0.dist-info/top_level.txt +1 -0
File without changes
|
@@ -0,0 +1,148 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import os
|
4
|
+
import pickle
|
5
|
+
|
6
|
+
import sqlalchemy
|
7
|
+
from sqlalchemy import ForeignKey
|
8
|
+
from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationship
|
9
|
+
from typing_extensions import Tuple, List, Set
|
10
|
+
from ucimlrepo import fetch_ucirepo
|
11
|
+
|
12
|
+
from .datastructures import Case, create_rows_from_dataframe, Category, Column
|
13
|
+
|
14
|
+
|
15
|
+
def load_cached_dataset(cache_file):
    """
    Load a dataset previously cached by ``save_dataset_to_cache``.

    The dataset is stored as one pickle file per part ("features", "targets", "ids"). Each part file name is
    derived from ``cache_file`` by inserting ``_<part>`` before its extension (``zoo.pkl`` -> ``zoo_features.pkl``);
    when ``cache_file`` has no extension, ``.pkl`` is used, so the three parts never share one file.

    :param cache_file: The path the part file names are derived from.
    :return: A dict mapping each part name to its unpickled value, or None if any part file is missing.
    """
    # NOTE: pickle.load can execute arbitrary code; only point this at cache files this program wrote.
    root, extension = os.path.splitext(cache_file)
    extension = extension or ".pkl"
    dataset = {}
    for key in ("features", "targets", "ids"):
        part_file = f"{root}_{key}{extension}"
        if not os.path.exists(part_file):
            return None
        with open(part_file, "rb") as f:
            dataset[key] = pickle.load(f)
    return dataset
|
25
|
+
|
26
|
+
|
27
|
+
def save_dataset_to_cache(dataset, cache_file):
    """
    Save only the essential parts of a fetched dataset to cache, one pickle file per part.

    The stored parts are ``dataset.data.features``, ``dataset.data.targets`` and ``dataset.data.ids``. Each part
    file name is derived from ``cache_file`` by inserting ``_<part>`` before its extension (``.pkl`` when it has
    none), so the parts never overwrite each other. Missing parent directories are created.

    :param dataset: The fetched dataset object (``get_dataset`` passes the result of ``fetch_ucirepo``).
    :param cache_file: The path the part file names are derived from.
    """
    dataset_to_cache = {
        "features": dataset.data.features,
        "targets": dataset.data.targets,
        "ids": dataset.data.ids,
    }

    directory = os.path.dirname(cache_file)
    if directory:
        os.makedirs(directory, exist_ok=True)

    root, extension = os.path.splitext(cache_file)
    extension = extension or ".pkl"
    for key, value in dataset_to_cache.items():
        with open(f"{root}_{key}{extension}", "wb") as f:
            pickle.dump(value, f)
    print("Dataset cached successfully.")
|
39
|
+
|
40
|
+
|
41
|
+
def get_dataset(dataset_id, cache_file):
    """
    Fetch a dataset's features, targets and ids from cache, or download (and cache) it if not available.

    :param dataset_id: The id passed to ``fetch_ucirepo``.
    :param cache_file: The path the cache part files are derived from.
    :return: A dict with the keys "features", "targets" and "ids", or None if the download failed.
    """
    dataset = load_cached_dataset(cache_file)
    if dataset is not None:
        return dataset

    print("Downloading dataset...")
    fetched = fetch_ucirepo(id=dataset_id)

    # Check if dataset is valid before caching
    if fetched is None or not hasattr(fetched, "data"):
        print("Error: Failed to fetch dataset.")
        return None

    save_dataset_to_cache(fetched, cache_file)

    # Return the same dict shape as the cache path, so callers can index the result the same way whether
    # or not the dataset was already cached.
    return {
        "features": fetched.data.features,
        "targets": fetched.data.targets,
        "ids": fetched.data.ids,
    }
|
56
|
+
|
57
|
+
|
58
|
+
def load_zoo_dataset(cache_file: str) -> Tuple[List[Case], List[Species]]:
    """
    Load the zoo dataset (UCI id 111).

    :param cache_file: the cache file the dataset parts are stored under.
    :return: all cases (rows created from the features, named "Animal") and their targets, each target being
        the ``SpeciesCol`` attribute named after the row's class.
    :raises ValueError: if the dataset could not be fetched.
    """
    # fetch dataset
    zoo = get_dataset(111, cache_file)
    if zoo is None:
        raise ValueError("Failed to load the zoo dataset (UCI id 111).")
    if hasattr(zoo, "data"):
        # A raw fetch_ucirepo result instead of the dict of cached parts; read the parts from it.
        zoo = {"features": zoo.data.features, "targets": zoo.data.targets}

    # data (as pandas dataframes)
    X = zoo['features']
    y = zoo['targets']
    all_cases = create_rows_from_dataframe(X, "Animal")

    # The targets are 1-based class ids, in this order.
    category_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "molusc"]
    category_id_to_name = {i + 1: name for i, name in enumerate(category_names)}
    targets = [getattr(SpeciesCol, category_id_to_name[i]) for i in y.values.flatten()]
    return all_cases, targets
|
79
|
+
|
80
|
+
|
81
|
+
class Species(Category):
    """
    The species (class types) an animal of the zoo dataset can belong to; each member's value is its name.
    """
    mammal = "mammal"
    bird = "bird"
    reptile = "reptile"
    fish = "fish"
    amphibian = "amphibian"
    insect = "insect"
    molusc = "molusc"
|
89
|
+
|
90
|
+
|
91
|
+
class Habitat(Category):
    """
    The habitat category of an animal; each member's value is its own name.
    """
    land = "land"
    water = "water"
    air = "air"
|
98
|
+
|
99
|
+
|
100
|
+
# Columns generated from the category enums: species values are declared mutually exclusive,
# habitat values are not.
SpeciesCol, HabitatCol = (
    Column.create_from_enum(Species, mutually_exclusive=True),
    Column.create_from_enum(Habitat, mutually_exclusive=False),
)
|
102
|
+
|
103
|
+
|
104
|
+
class Base(sqlalchemy.orm.DeclarativeBase):
    """Declarative base shared by the ORM tables of this module."""
|
106
|
+
|
107
|
+
|
108
|
+
class HabitatTable(MappedAsDataclass, Base):
    """
    ORM row holding one ``Habitat`` value, linked to an ``Animal`` row through ``animal_id``.
    """
    __tablename__ = "Habitat"

    id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
    habitat: Mapped[Habitat]
    animal_id = mapped_column(ForeignKey("Animal.id"), init=False)

    def __hash__(self):
        # The hash depends on the habitat value only.
        habitat_value = self.habitat
        return hash(habitat_value)

    def __str__(self):
        return self.habitat.value

    def __repr__(self):
        return str(self)
|
123
|
+
|
124
|
+
|
125
|
+
class Animal(MappedAsDataclass, Base):
    """
    ORM table of animals; its columns appear to mirror the feature names of the zoo dataset.

    The declaration order of the fields is significant: ``MappedAsDataclass`` derives the ``__init__``
    parameters from it, so do not reorder them.
    """
    __tablename__ = "Animal"

    # Primary key assigned by the database; not an __init__ parameter.
    id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
    name: Mapped[str]
    # Dataset features (all boolean except ``legs``).
    hair: Mapped[bool]
    feathers: Mapped[bool]
    eggs: Mapped[bool]
    milk: Mapped[bool]
    airborne: Mapped[bool]
    aquatic: Mapped[bool]
    predator: Mapped[bool]
    toothed: Mapped[bool]
    backbone: Mapped[bool]
    breathes: Mapped[bool]
    venomous: Mapped[bool]
    fins: Mapped[bool]
    legs: Mapped[int]
    tail: Mapped[bool]
    domestic: Mapped[bool]
    catsize: Mapped[bool]
    # The species column is nullable, i.e. it may be left unset in the database.
    species: Mapped[Species] = mapped_column(nullable=True)

    # Habitat rows pointing at this animal via HabitatTable.animal_id; defaults to an empty set.
    habitats: Mapped[Set[HabitatTable]] = relationship(default_factory=set)
|
@@ -0,0 +1,237 @@
|
|
1
|
+
from __future__ import annotations

import ast
import builtins
import importlib
import logging
from _ast import AST

from sqlalchemy.orm import Session
from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set

from .table import create_row, Row
from ..utils import SubclassJSONSerializer, get_full_class_name
|
12
|
+
|
13
|
+
|
14
|
+
class VariableVisitor(ast.NodeVisitor):
    """
    A visitor to extract all variables, attribute accesses, comparisons and binary operations from a python
    expression represented as an AST tree.
    """
    # [left operand, first operator, first comparator] of each visited comparison (chained comparisons such
    # as ``1 < a < 3`` only contribute their first pair).
    compares: List[Tuple[Union[ast.Name, ast.Call], ast.cmpop, Union[ast.Name, ast.Call]]]
    # Free names that are neither builtins nor the root of an attribute access.
    variables: Set[str]
    # Every visited Attribute, BinOp, BoolOp and Compare node, in visit order.
    all: List[ast.AST]

    # Names of the python builtins (len, isinstance, int, ...), which are never case variables.
    _BUILTIN_NAMES = frozenset(dir(builtins))

    def __init__(self):
        self.variables = set()
        # Maps the value of each attribute access (e.g. the Name ``a`` in ``a.b``) to the access node.
        self.attributes: Dict[ast.Name, ast.Attribute] = {}
        self.compares = list()
        self.binary_ops = list()
        self.all = list()

    def visit_Attribute(self, node):
        self.all.append(node)
        # Recorded before visiting the children, so visit_Name can tell that the root name belongs to an
        # attribute access.
        self.attributes[node.value] = node
        self.generic_visit(node)

    def visit_BinOp(self, node):
        self.binary_ops.append(node)
        self.all.append(node)
        self.generic_visit(node)

    def visit_BoolOp(self, node):
        self.all.append(node)
        self.generic_visit(node)

    def visit_Compare(self, node):
        self.all.append(node)
        self.compares.append([node.left, node.ops[0], node.comparators[0]])
        self.generic_visit(node)

    def visit_Name(self, node):
        # Check against the builtins module itself: ``__builtins__`` is a dict inside imported modules, so the
        # previous ``f"__{name}__" in dir(__builtins__)`` test matched dict dunders instead of builtin names.
        if node.id not in self._BUILTIN_NAMES and node not in self.attributes:
            self.variables.add(node.id)
        self.generic_visit(node)
|
52
|
+
|
53
|
+
|
54
|
+
class CallableExpression(SubclassJSONSerializer):
    """
    A callable that is constructed from a string statement written by an expert.

    The statement is normalized (see ``parse_user_input``), every free name in it is rewritten into an attribute
    of the evaluated case (``legs > 2`` becomes ``case.legs > 2``), and the result is compiled once and
    evaluated with ``eval`` on every call.
    """
    conclusion_type: Type
    """
    The type of the output of the callable, used for assertion.
    """
    expression_tree: AST
    """
    The AST tree parsed from the user input.
    """
    user_input: str
    """
    The input given by the expert.
    """
    session: Optional[Session]
    """
    The sqlalchemy orm session.
    """
    visitor: VariableVisitor
    """
    A visitor to extract all variables and comparisons from a python expression represented as an AST tree.
    """
    code: Any
    """
    The code that was compiled from the expression tree
    """
    compares_column_offset: List[Tuple[int, int]]
    """
    The start and end column offsets of each comparison, relative to the parsed user input before the
    ``case.`` prefixes are added.
    """
    # The normalized user input with every free name prefixed by ``case.``; this is what gets compiled.
    parsed_user_input: str

    def __init__(self, user_input: str, conclusion_type: Optional[Type] = None, expression_tree: Optional[AST] = None,
                 session: Optional[Session] = None):
        """
        Create a callable expression.

        :param user_input: The input given by the expert.
        :param conclusion_type: The type of the output of the callable.
        :param expression_tree: The AST tree parsed from the user input.
        :param session: The sqlalchemy orm session.
        """
        self.session = session
        self.user_input: str = user_input
        self.parsed_user_input = self.parse_user_input(user_input, session)
        self.conclusion_type = conclusion_type
        self.update_expression(self.parsed_user_input, expression_tree)

    @staticmethod
    def parse_user_input(user_input: str, session: Optional[Session] = None) -> str:
        """
        Normalize the expert's input before it is compiled.

        Conditions separated by top-level commas (``legs > 2, hair``) are parenthesized and joined with ``and``,
        or with ``&`` when a session is given. Commas nested inside calls, literals or subscripts
        (``isinstance(x, int)``, ``legs in [2, 4]``) are not treated as separators. Otherwise, when a session is
        given, `` and `` / `` or `` are replaced by `` & `` / `` | ``.

        :param user_input: The raw input given by the expert.
        :param session: The sqlalchemy orm session, if any.
        :return: The normalized expression string.
        """
        conditions = CallableExpression._split_top_level_conditions(user_input)
        if conditions is not None:
            conditions = [f"({condition})" for condition in conditions]
            return ' & '.join(conditions) if session else ' and '.join(conditions)
        if session:
            user_input = user_input.replace(" and ", " & ")
            user_input = user_input.replace(" or ", " | ")
        return user_input

    @staticmethod
    def _split_top_level_conditions(user_input: str) -> Optional[List[str]]:
        """
        Split the input on its top-level commas only.

        :param user_input: The raw input given by the expert.
        :return: The stripped conditions, or None when the input contains no top-level comma.
        """
        if ',' not in user_input:
            return None
        stripped = user_input.strip()
        try:
            tree = ast.parse(stripped, mode="eval")
        except (SyntaxError, ValueError):
            # Not a valid python expression as a whole: fall back to splitting on every comma.
            return [part.strip() for part in user_input.split(',')]
        if not isinstance(tree.body, ast.Tuple):
            # Every comma is nested inside a call, literal, subscript, ...
            return None
        segments = [ast.get_source_segment(stripped, element) for element in tree.body.elts]
        if any(segment is None for segment in segments):
            return [part.strip() for part in user_input.split(',')]
        return [segment.strip() for segment in segments]

    @staticmethod
    def _line_start_offsets(source: bytes) -> List[int]:
        """Return the byte offset at which each line of ``source`` starts (index 0 is line 1)."""
        starts = [0]
        for line in source.splitlines(keepends=True):
            starts.append(starts[-1] + len(line))
        return starts

    @staticmethod
    def _prefix_names_with_case(expression_str: str) -> str:
        """
        Prefix every free name of the expression with ``case.`` so that it is looked up on the evaluated case.

        Names are located through the AST rather than by text replacement. As a result, identifiers that contain
        one another (``leg``/``legs``), letters inside other words (``a`` in ``and``) and text inside string
        literals are never rewritten. The following names are left unchanged: builtins, the name ``case`` itself,
        and names bound inside the expression (comprehension targets, lambda arguments, ``:=`` targets).

        :param expression_str: The expression to rewrite.
        :return: The rewritten expression; all other formatting is preserved.
        """
        tree = ast.parse(expression_str, mode="eval")
        nodes = list(ast.walk(tree))
        bound_names = {node.id for node in nodes
                       if isinstance(node, ast.Name) and not isinstance(node.ctx, ast.Load)}
        bound_names |= {node.arg for node in nodes if isinstance(node, ast.arg)}

        # AST column offsets are UTF-8 byte offsets within a line, so splice on the encoded source.
        source = expression_str.encode("utf-8")
        line_starts = CallableExpression._line_start_offsets(source)
        insert_at = set()
        for node in nodes:
            if not isinstance(node, ast.Name) or not isinstance(node.ctx, ast.Load):
                continue
            if node.id == "case" or node.id in bound_names or hasattr(builtins, node.id):
                continue
            start = line_starts[node.lineno - 1] + node.col_offset
            name_bytes = node.id.encode("utf-8")
            # Only prefix when the reported position really points at the identifier (it may not for
            # normalized identifiers or names inside f-strings on older python versions).
            if source[start:start + len(name_bytes)] == name_bytes:
                insert_at.add(start)
        # Insert from the end so earlier offsets stay valid.
        for start in sorted(insert_at, reverse=True):
            source = source[:start] + b"case." + source[start:]
        return source.decode("utf-8")

    def update_expression(self, user_input: str, expression_tree: Optional[AST] = None):
        """
        (Re)build the visitor, the case-prefixed input, the expression tree and the compiled code.

        :param user_input: The parsed user input (as returned by ``parse_user_input``).
        :param expression_tree: Optional AST of ``user_input``; used only to collect the variables and
            comparisons, the compiled code is always built from ``user_input``.
        """
        if expression_tree is None:
            expression_tree = parse_string_to_expression(user_input)
        self.visitor = VariableVisitor()
        self.visitor.visit(expression_tree)
        # Rewrite the given input (not a previous self.parsed_user_input), so calling this method with a new
        # expression really updates the compiled code.
        self.parsed_user_input = self._prefix_names_with_case(user_input)
        self.expression_tree = parse_string_to_expression(self.parsed_user_input)
        self.compares_column_offset = [(c[0].col_offset, c[2].end_col_offset) for c in self.visitor.compares]
        self.code = compile_expression_to_code(self.expression_tree)

    def __call__(self, case: Any, **kwargs) -> Any:
        """
        Evaluate the expression on the given case.

        :param case: The case; anything that is not a Row is converted with ``create_row`` first.
        :return: The value of the expression.
        :raises ValueError: if the evaluation fails or the output is not an instance of ``conclusion_type``.
        """
        try:
            if not isinstance(case, Row):
                case = create_row(case, max_recursion_idx=3)
            # SECURITY: eval executes the expert-written expression with full python privileges; only build
            # CallableExpressions from trusted input. The expression sees this method's locals (``case``, ...).
            output = eval(self.code)
            if self.conclusion_type and not isinstance(output, self.conclusion_type):
                # An explicit check instead of ``assert`` so it still runs under ``python -O``.
                raise TypeError(f"Expected output type {self.conclusion_type}, got {type(output)}")
            return output
        except Exception as e:
            raise ValueError(f"Error during evaluation: {e}") from e

    def combine_with(self, other: 'CallableExpression') -> 'CallableExpression':
        """
        Combine this callable expression with another callable expression using the 'and' operator.
        """
        new_user_input = f"({self.user_input}) and ({other.user_input})"
        return CallableExpression(new_user_input, conclusion_type=self.conclusion_type, session=self.session)

    def __str__(self):
        """
        Return the parsed user input split into lines, each line ending where a binary operation ends.

        Any text after the last binary operation is kept as a final line.
        """
        user_input = self.parsed_user_input
        try:
            tree = ast.parse(user_input, mode="eval")
        except (SyntaxError, ValueError):
            return user_input
        source = user_input.encode("utf-8")
        line_starts = self._line_start_offsets(source)
        # Offsets come from parsing parsed_user_input itself, so they line up with the string being split.
        ends = sorted(line_starts[node.end_lineno - 1] + node.end_col_offset
                      for node in ast.walk(tree) if isinstance(node, ast.BinOp))
        if not ends:
            return user_input
        lines = []
        previous_end = 0
        for end in ends:
            lines.append(source[previous_end:end].decode("utf-8"))
            previous_end = end
        if previous_end < len(source):
            lines.append(source[previous_end:].decode("utf-8"))
        return "\n".join(lines)

    def to_json(self) -> Dict[str, Any]:
        return {**SubclassJSONSerializer.to_json(self),
                "user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type)}

    @classmethod
    def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
        return cls(user_input=data["user_input"], conclusion_type=cls._resolve_type(data["conclusion_type"]))

    @staticmethod
    def _resolve_type(type_name: Any) -> Optional[Type]:
        """
        Turn a fully qualified class name back into the class, as needed after JSON deserialization.

        Non-string values (already a type, or None) are returned unchanged.
        NOTE(review): names that cannot be imported (or do not name a class) resolve to None, which disables the
        output type check. Importing runs module code, so only deserialize trusted JSON.

        :param type_name: The serialized conclusion type.
        :return: The resolved class, or None.
        """
        if not isinstance(type_name, str):
            return type_name
        module_name, qualname = type_name, []
        while module_name:
            try:
                resolved = importlib.import_module(module_name)
            except ImportError:
                # Move the last dotted component from the module path to the attribute path and retry.
                module_name, _, last = module_name.rpartition(".")
                qualname.insert(0, last)
                continue
            try:
                for part in qualname:
                    resolved = getattr(resolved, part)
            except AttributeError:
                return None
            return resolved if isinstance(resolved, type) else None
        return None
|
172
|
+
|
173
|
+
|
174
|
+
def compile_expression_to_code(expression_tree: AST) -> Any:
    """
    Turn an expression tree parsed from a string into a code object that ``eval(code)`` can run.

    :param expression_tree: The tree produced by parsing the expression in ``eval`` mode.
    :return: The compiled code object.
    """
    compile_options = {"filename": "<string>", "mode": "eval"}
    return compile(expression_tree, **compile_options)
|
182
|
+
|
183
|
+
|
184
|
+
def assert_context_contains_needed_information(case: Any, context: Dict[str, Any],
                                               visitor: VariableVisitor) -> Tuple[Set[str], Set[str]]:
    """
    Check that every variable and attribute referenced by the visited expression is a key of the context.

    :param case: The case the context belongs to; only used in the error messages.
    :param context: The context to check.
    :param visitor: The visitor that visited the expression.
    :return: The found variables and attributes.
    :raises ValueError: for the first missing variable, or, if all variables are present, the first missing
        attribute.
    """
    missing_variable = next((name for name in visitor.variables if name not in context), None)
    if missing_variable is not None:
        raise ValueError(f"Variable {missing_variable} not found in the case {case}")
    found_variables = set(visitor.variables)

    found_attributes = get_attributes_str(visitor)
    missing_attribute = next((attr for attr in found_attributes if attr not in context), None)
    if missing_attribute is not None:
        raise ValueError(f"Attribute {missing_attribute} not found in the case {case}")
    return found_variables, found_attributes
|
205
|
+
|
206
|
+
|
207
|
+
def get_attributes_str(visitor: VariableVisitor) -> Set[str]:
    """
    Get the dotted string representation of the attributes collected by the given visitor.

    Each entry of ``visitor.attributes`` maps the value of an attribute access to the access node. When that value
    is a plain name the dotted path is returned (``case.legs`` -> ``"case.legs"``). When it is itself an attribute
    chain, the path of that chain is returned, so ``a.b.c`` yields ``"a.b"``.
    NOTE(review): confirm that dropping the last attribute of longer chains is intended.
    Accesses whose root is not a plain name (``f(x).real``, ``xs[0].y``, ``"s".upper``) are skipped, since there
    is no variable to name them by.

    :param visitor: The visitor that visited the expression.
    :return: The string representation of the attributes.
    """
    found_attributes = set()
    for key, ast_attr in visitor.attributes.items():
        parts = []
        node = key
        while isinstance(node, ast.Attribute):
            parts.insert(0, node.attr)
            node = node.value
        if not isinstance(node, ast.Name):
            # Rooted at a call, subscript, constant, ...: there is no name (``.id``) to build the path from.
            continue
        if not parts:
            parts = [ast_attr.attr]
        found_attributes.add(".".join([node.id] + parts))
    return found_attributes
|
226
|
+
|
227
|
+
|
228
|
+
def parse_string_to_expression(expression_str: str) -> AST:
    """
    Parse a string statement into an AST expression.

    :param expression_str: The string which will be parsed.
    :return: The parsed expression (parsed in ``eval`` mode, so an ``ast.Expression``).
    :raises SyntaxError: if the string is not a valid python expression.
    """
    tree = ast.parse(expression_str, mode='eval')
    # Only dump the tree when debug logging is enabled. This avoids formatting the whole tree on every parse.
    # It also avoids the implicit logging.basicConfig() that a module-level logging.debug call performs when the
    # root logger has no handlers.
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        logging.debug("AST parsed successfully: %s", ast.dump(tree))
    return tree
|
@@ -0,0 +1,76 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from copy import copy, deepcopy
|
4
|
+
from dataclasses import dataclass
|
5
|
+
|
6
|
+
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
7
|
+
from typing_extensions import Any, Optional, Type, Union
|
8
|
+
|
9
|
+
from .table import create_row, Case
|
10
|
+
from ..utils import get_attribute_name, copy_orm_instance_with_relationships, copy_case
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass
class CaseQuery:
    """
    This is a dataclass that represents an attribute of an object and its target value. If attribute name is
    not provided, it will be inferred from the attribute itself or from the attribute type or from the target value,
    depending on what is provided.
    """
    case: Any
    """
    The case that the attribute belongs to.
    """
    attribute: Optional[Any] = None
    """
    The attribute itself.
    """
    # Named ``target`` (not ``targets``) to match the attribute set in __init__, so the generated __eq__
    # actually compares the target value.
    target: Optional[Any] = None
    """
    The target value of the attribute.
    """
    attribute_name: Optional[str] = None
    """
    The name of the attribute.
    """
    attribute_type: Optional[Type] = None
    """
    The type of the attribute.
    """
    relational_representation: Optional[str] = None
    """
    The representation of the target value in relational form.
    """

    def __init__(self, case: Any, attribute: Optional[Any] = None, target: Optional[Any] = None,
                 attribute_name: Optional[str] = None, attribute_type: Optional[Type] = None,
                 relational_representation: Optional[str] = None):
        """
        Create a case query.

        :param case: The case; anything that is not already a Case or an ORM table instance is converted with
            ``create_row``.
        :param attribute: The attribute; only used to infer ``attribute_name`` (the stored attribute is read back
            from the case).
        :param target: The target value of the attribute.
        :param attribute_name: The name of the attribute; inferred with ``get_attribute_name`` when None.
        :param attribute_type: The attribute type; only used to infer ``attribute_name``.
        :param relational_representation: The representation of the target value in relational form.
        """
        if attribute_name is None:
            attribute_name = get_attribute_name(case, attribute, attribute_type, target)
        self.attribute_name = attribute_name

        if not isinstance(case, (Case, SQLTable)):
            case = create_row(case, max_recursion_idx=3)
        self.case = case

        self.attribute = getattr(self.case, self.attribute_name) if self.attribute_name else None
        # ``is not None`` so falsy attribute values (False, 0, empty containers) still report their type.
        self.attribute_type = type(self.attribute) if self.attribute is not None else None
        self.target = target
        self.relational_representation = relational_representation

    @property
    def name(self):
        return self.attribute_name if self.attribute_name else self.__class__.__name__

    def __str__(self):
        if self.relational_representation:
            return f"{self.name} |= {self.relational_representation}"
        else:
            return f"{self.target}"

    def __repr__(self):
        return self.__str__()

    def __copy__(self):
        # Carry the relational representation over as well; without it the copy's __str__ would differ.
        return CaseQuery(copy_case(self.case), attribute_name=self.attribute_name, target=self.target,
                         relational_representation=self.relational_representation)
|
@@ -0,0 +1,173 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from enum import auto, Enum
|
4
|
+
|
5
|
+
from typing_extensions import List
|
6
|
+
|
7
|
+
|
8
|
+
class Category(str, Enum):
    """
    Base class for string-valued categories; subclasses declare the members.
    """

    @classmethod
    def from_str(cls, value: str) -> Category:
        """Return the class attribute named ``value`` (normally the member of that name)."""
        member = getattr(cls, value)
        return member

    @classmethod
    def from_strs(cls, values: List[str]) -> List[Category]:
        """Apply ``from_str`` to each of ``values``."""
        return list(map(cls.from_str, values))

    @property
    def as_dict(self):
        """A one-entry dict keyed by the lower-cased class name, holding this member's value."""
        key = type(self).__name__.lower()
        return {key: self.value}
|
21
|
+
|
22
|
+
|
23
|
+
class Stop(Category):
    """
    Special category that halts the classification, so that a wrong conclusion is not produced.
    """
    stop = "stop"
|
29
|
+
|
30
|
+
|
31
|
+
class ExpressionParser(Enum):
    """
    Parsers that evaluate an expression and wrap it into a callable function.
    """
    # Generic python Abstract Syntax Tree visitor detecting variables, attributes,
    # binary/boolean expressions, etc.
    ASTVisitor: int = auto()
    # Parser specific to SQLAlchemy expressions on ORM tables.
    SQLAlchemy: int = auto()
|
43
|
+
|
44
|
+
|
45
|
+
class PromptFor(Enum):
    """
    What the expert is prompted for: the conditions or the conclusion of a rule about a case.
    """
    # Ask for the conditions of a rule about a case.
    Conditions: str = "conditions"
    # Ask for the conclusion of a rule about a case.
    Conclusion: str = "conclusion"

    def __str__(self):
        member_name = self.name
        return member_name

    def __repr__(self):
        return str(self)
|
63
|
+
|
64
|
+
|
65
|
+
class CategoricalValue(Enum):
    """
    An enum whose members compare equal (to each other and to plain strings) by member name.
    """

    def __eq__(self, other):
        # Another CategoricalValue is compared by its name; anything else is compared to our name directly.
        other_key = other.name if isinstance(other, CategoricalValue) else other
        return self.name == other_key

    def __hash__(self):
        return hash(self.name)

    @classmethod
    def to_list(cls):
        """Return the keys of the enum's value-to-member map, i.e. the member values."""
        return [value for value in cls._value2member_map_]

    @classmethod
    def from_str(cls, category: str):
        """Return the member whose name is ``category`` lower-cased."""
        lowered = category.lower()
        return cls[lowered]

    @classmethod
    def from_strs(cls, categories: List[str]):
        """Apply ``from_str`` to each of ``categories``."""
        return list(map(cls.from_str, categories))

    def __str__(self):
        return self.name

    def __repr__(self):
        return str(self)
|
97
|
+
|
98
|
+
|
99
|
+
class RDRMode(Enum):
    """
    Whether the rules of a ripple down rules tree are propositional or relational.
    """
    # The rules are propositional.
    Propositional = auto()
    # The rules are relational.
    Relational = auto()
|
108
|
+
|
109
|
+
|
110
|
+
class MCRDRMode(Enum):
    """
    The modes of the MultiClassRDR.
    """
    # Stop the wrong conclusion from being made; no new rule is added for the correct conclusion.
    StopOnly = auto()
    # Stop the wrong conclusion and add a new rule, with the same conditions as the stopping rule,
    # that makes the correct conclusion.
    StopPlusRule = auto()
    # Stop the wrong conclusion and add a new rule whose conditions combine those of the stopping rule
    # and of the rule that should have fired.
    StopPlusRuleCombined = auto()
|
128
|
+
|
129
|
+
|
130
|
+
class RDREdge(Enum):
    """
    The kinds of edges between rules, labelled with their textual form.
    """
    # Refines a rule that fired incorrectly.
    Refinement = "except if"
    # Alternative to a rule that did not fire.
    Alternative = "else if"
    # Points to the next rule to be evaluated.
    Next = "next"
|
143
|
+
|
144
|
+
|
145
|
+
class ValueType(Enum):
    """
    The kinds of values an attribute can take.
    """
    # Unary value (e.g. null).
    Unary = auto()
    # Binary value (e.g. True, False).
    Binary = auto()
    # Discrete value (e.g. 1, 2, 3).
    Discrete = auto()
    # Continuous value (e.g. 1.0, 2.5, 3.4).
    Continuous = auto()
    # Nominal value (e.g. red, blue, green): categories without a natural order.
    Nominal = auto()
    # Ordinal value (e.g. low, medium, high): categories with a natural order.
    Ordinal = auto()
    # Iterable value (e.g. [1, 2, 3]).
    Iterable = auto()
|
File without changes
|
File without changes
|
File without changes
|