ripple-down-rules 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,148 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import pickle
5
+
6
+ import sqlalchemy
7
+ from sqlalchemy import ForeignKey
8
+ from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationship
9
+ from typing_extensions import Tuple, List, Set
10
+ from ucimlrepo import fetch_ucirepo
11
+
12
+ from .datastructures import Case, create_rows_from_dataframe, Category, Column
13
+
14
+
15
# Parts of a ucimlrepo dataset that are cached, in the order they are written and read.
_CACHED_PARTS = ("features", "targets", "ids")


def _cache_part_path(cache_file, key):
    """
    Return the file that stores the ``key`` part of the dataset cached under ``cache_file``.

    ``"zoo.pkl"`` becomes ``"zoo_features.pkl"`` etc. Only the final extension is split off, so a ".pkl" elsewhere
    in the path is left alone, and a path without an extension still gets one distinct ``.pkl`` file per part
    (with plain ``str.replace(".pkl", ...)`` every part of such a path collided on the same file).
    """
    root, ext = os.path.splitext(cache_file)
    return f"{root}_{key}{ext or '.pkl'}"


def _essential_parts(dataset):
    """Return the parts of a fetched ucimlrepo ``dataset`` that get cached, keyed like the cache files."""
    return {key: getattr(dataset.data, key) for key in _CACHED_PARTS}


def load_cached_dataset(cache_file):
    """
    Loads the dataset from cache if it exists.

    :param cache_file: base cache path; every part lives in its own file derived from it.
    :return: a dict with the "features", "targets" and "ids" parts, or None if any part file is missing.
    """
    dataset = {}
    for key in _CACHED_PARTS:
        part_file = _cache_part_path(cache_file, key)
        if not os.path.exists(part_file):
            return None
        # pickle can execute arbitrary code while loading: only point this at cache files written by this module.
        with open(part_file, "rb") as f:
            dataset[key] = pickle.load(f)
    return dataset


def save_dataset_to_cache(dataset, cache_file):
    """
    Saves only essential parts of the dataset to cache.

    :param dataset: a fetched ucimlrepo dataset exposing ``dataset.data.features``, ``.targets`` and ``.ids``.
    :param cache_file: base cache path; every part is pickled to its own file derived from it.
    """
    for key, value in _essential_parts(dataset).items():
        with open(_cache_part_path(cache_file, key), "wb") as f:
            pickle.dump(value, f)
    print("Dataset cached successfully.")


def get_dataset(dataset_id, cache_file):
    """
    Fetches dataset from cache or downloads it if not available.

    :param dataset_id: the UCI repository id passed to ``fetch_ucirepo``.
    :param cache_file: base cache path used for loading and saving.
    :return: a dict with the "features", "targets" and "ids" parts (same shape whether it came from the cache or
        was just downloaded), or None if the download failed.
    """
    dataset = load_cached_dataset(cache_file)
    if dataset is not None:
        return dataset

    print("Downloading dataset...")
    fetched = fetch_ucirepo(id=dataset_id)

    # Check if dataset is valid before caching
    if fetched is None or not hasattr(fetched, "data"):
        print("Error: Failed to fetch dataset.")
        return None

    save_dataset_to_cache(fetched, cache_file)
    # Return the same dict shape as the cache path so callers can index dataset["features"] on the first run too.
    return _essential_parts(fetched)
56
+
57
+
58
def load_zoo_dataset(cache_file: str) -> Tuple[List[Case], List[Species]]:
    """
    Load the zoo dataset (UCI id 111).

    :param cache_file: the cache file.
    :return: all cases (built from the features dataframe as "Animal" rows) and the targets, one per row, looked up
        on ``SpeciesCol`` by species name.
    :raises ValueError: if the dataset could neither be loaded from the cache nor downloaded.
    """
    # fetch dataset
    zoo = get_dataset(111, cache_file)
    if zoo is None:
        raise ValueError(f"Could not load the zoo dataset (cache file: {cache_file}).")
    # A freshly downloaded ucimlrepo result exposes its parts under `.data` rather than as dict keys; accept both.
    if hasattr(zoo, "data"):
        zoo = {"features": zoo.data.features, "targets": zoo.data.targets, "ids": zoo.data.ids}

    # data (as pandas dataframes)
    X = zoo['features']
    y = zoo['targets']
    all_cases = create_rows_from_dataframe(X, "Animal")

    # The dataset encodes the class as an integer 1..7, in the order of this list.
    category_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "molusc"]
    category_id_to_name = {i: name for i, name in enumerate(category_names, start=1)}
    targets = [getattr(SpeciesCol, category_id_to_name[i]) for i in y.values.flatten()]
    return all_cases, targets
79
+
80
+
81
class Species(Category):
    """
    The classes (species) an animal of the zoo dataset can belong to.

    The member names are the same strings as ``category_names`` in ``load_zoo_dataset``, which maps the dataset's
    numeric class ids 1-7 onto them in this order.
    """
    mammal = "mammal"
    bird = "bird"
    reptile = "reptile"
    fish = "fish"
    amphibian = "amphibian"
    insect = "insect"
    # NOTE(review): "molusc" looks like a misspelling of "mollusc"; renaming it would change the name/value that
    # load_zoo_dataset's category_names matches verbatim.
    molusc = "molusc"
89
+
90
+
91
class Habitat(Category):
    """
    A habitat category is a category that represents the habitat of an animal.

    An animal can have several habitats: ``HabitatCol`` is created with ``mutually_exclusive=False`` and
    ``Animal.habitats`` is a set.
    """
    land = "land"
    water = "water"
    air = "air"
98
+
99
+
100
# Column types built from the enums above via Column.create_from_enum (defined in .datastructures).
# NOTE(review): presumably mutually_exclusive=True means a row holds exactly one species, while a row may hold several
# habitats (Animal has a single `species` but a set of `habitats`) -- confirm against Column.create_from_enum.
SpeciesCol = Column.create_from_enum(Species, mutually_exclusive=True)
HabitatCol = Column.create_from_enum(Habitat, mutually_exclusive=False)
102
+
103
+
104
class Base(sqlalchemy.orm.DeclarativeBase):
    """SQLAlchemy declarative base shared by the ORM tables of this module (HabitatTable, Animal)."""
    pass
106
+
107
+
108
class HabitatTable(MappedAsDataclass, Base):
    """
    ORM table "Habitat": one habitat value, linked to a row of the "Animal" table through ``animal_id``.

    ``str()``/``repr()`` give the habitat's value (e.g. ``land``).
    """
    __tablename__ = "Habitat"

    # Auto-incremented primary key; not an __init__ argument (init=False).
    id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
    # The habitat, stored through the Habitat enum type.
    habitat: Mapped[Habitat]
    # Foreign key to Animal.id; not an __init__ argument (init=False).
    # NOTE(review): presumably filled in by SQLAlchemy when the row is added to Animal.habitats -- confirm.
    animal_id = mapped_column(ForeignKey("Animal.id"), init=False)

    def __hash__(self):
        # Hash by habitat only. Animal.habitats is a Set[HabitatTable], so instances must be hashable; presumably an
        # explicit __hash__ is needed because the dataclass-generated __eq__ would otherwise disable hashing.
        return hash(self.habitat)

    def __str__(self):
        return self.habitat.value

    def __repr__(self):
        return self.__str__()
123
+
124
+
125
class Animal(MappedAsDataclass, Base):
    """
    ORM table "Animal": one animal of the zoo dataset with its features, species and habitats.

    As a MappedAsDataclass, the mapped fields below (except ``id``, which has init=False) become ``__init__``
    parameters in declaration order, so the field order is part of the constructor's interface.
    """
    __tablename__ = "Animal"

    # Auto-incremented primary key; not an __init__ argument.
    id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
    name: Mapped[str]
    # Feature columns; all boolean except `legs`.
    # NOTE(review): presumably these mirror the UCI zoo dataset attributes of the same names -- confirm.
    hair: Mapped[bool]
    feathers: Mapped[bool]
    eggs: Mapped[bool]
    milk: Mapped[bool]
    airborne: Mapped[bool]
    aquatic: Mapped[bool]
    predator: Mapped[bool]
    toothed: Mapped[bool]
    backbone: Mapped[bool]
    breathes: Mapped[bool]
    venomous: Mapped[bool]
    fins: Mapped[bool]
    legs: Mapped[int]
    tail: Mapped[bool]
    domestic: Mapped[bool]
    catsize: Mapped[bool]
    # The species column is nullable; no default is given, so it is still an __init__ argument.
    species: Mapped[Species] = mapped_column(nullable=True)

    # Habitats of this animal; HabitatTable.animal_id is the foreign key back to this table. Defaults to an empty set.
    habitats: Mapped[Set[HabitatTable]] = relationship(default_factory=set)
@@ -0,0 +1,4 @@
1
+ from .enums import *
2
+ from .dataclasses import *
3
+ from .callable_expression import *
4
+ from .table import *
@@ -0,0 +1,237 @@
1
from __future__ import annotations

import ast
import builtins
import logging
import re
from _ast import AST

from sqlalchemy.orm import Session
from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set

from .table import create_row, Row
from ..utils import SubclassJSONSerializer, get_full_class_name
12
+
13
+
14
class VariableVisitor(ast.NodeVisitor):
    """
    A visitor to extract all variables and comparisons from a python expression represented as an AST tree.

    After ``visit(tree)``:

    - ``variables``: identifiers of ``ast.Name`` nodes that are neither Python builtins (``len``, ``int``,
      ``isinstance``, ...) nor the base object of an attribute access (``x`` in ``x.y``).
    - ``attributes``: maps the ``value`` node of every ``ast.Attribute`` to that attribute node.
    - ``compares``: ``[left, first operator, first comparator]`` of every comparison (later operators of a chained
      comparison such as ``a < b < c`` are not recorded).
    - ``binary_ops``: every ``ast.BinOp`` node.
    - ``all``: every attribute, binary-op, boolean-op and comparison node, in visiting order.
    """
    compares: List[Tuple[Union[ast.Name, ast.Call], ast.cmpop, Union[ast.Name, ast.Call]]]
    variables: Set[str]
    all: List[ast.AST]

    # Names provided by Python itself; these are never reported as variables.
    _builtin_names = frozenset(dir(builtins))

    def __init__(self):
        self.variables = set()
        self.attributes: Dict[ast.AST, ast.Attribute] = {}
        self.compares = list()
        self.binary_ops = list()
        self.all = list()

    def visit_Attribute(self, node):
        self.all.append(node)
        # Recorded before visiting the children so that visit_Name can recognise the base name as non-variable.
        self.attributes[node.value] = node
        self.generic_visit(node)

    def visit_BinOp(self, node):
        self.binary_ops.append(node)
        self.all.append(node)
        self.generic_visit(node)

    def visit_BoolOp(self, node):
        self.all.append(node)
        self.generic_visit(node)

    def visit_Compare(self, node):
        self.all.append(node)
        self.compares.append([node.left, node.ops[0], node.comparators[0]])
        self.generic_visit(node)

    def visit_Name(self, node):
        # Check against the builtins module itself. Looking up f"__{name}__" in dir(__builtins__) only matched the
        # dunder methods of a dict (what __builtins__ is inside imported modules), so names such as `int` or
        # `isinstance` were reported as variables and later rewritten to `case.int`.
        if node.id not in self._builtin_names and node not in self.attributes:
            self.variables.add(node.id)
        self.generic_visit(node)
52
+
53
+
54
class CallableExpression(SubclassJSONSerializer):
    """
    A callable that is constructed from a string statement written by an expert.

    The input is normalised (see ``parse_user_input``), every free name and dotted attribute path in it is qualified
    with ``case.``, and the result is compiled once. Calling the object evaluates that code with ``case`` bound to the
    given case.

    NOTE(review): the expression is executed with ``eval``; only build it from trusted (expert) input.
    """
    conclusion_type: Type
    """
    The type of the output of the callable, used for assertion.
    """
    expression_tree: AST
    """
    The AST tree parsed from the user input.
    """
    user_input: str
    """
    The input given by the expert.
    """
    session: Optional[Session]
    """
    The sqlalchemy orm session.
    """
    visitor: VariableVisitor
    """
    A visitor to extract all variables and comparisons from a python expression represented as an AST tree.
    """
    code: Any
    """
    The code that was compiled from the expression tree
    """
    compares_column_offset: List[Tuple[int, int]]
    """
    The start and end indices of each comparison in the string of user input.
    """

    def __init__(self, user_input: str, conclusion_type: Optional[Type] = None, expression_tree: Optional[AST] = None,
                 session: Optional[Session] = None):
        """
        Create a callable expression.

        :param user_input: The input given by the expert.
        :param conclusion_type: The type of the output of the callable.
        :param expression_tree: The AST tree parsed from the user input.
        :param session: The sqlalchemy orm session.
        """
        self.session = session
        self.user_input: str = user_input
        self.parsed_user_input = self.parse_user_input(user_input, session)
        self.conclusion_type = conclusion_type
        self.update_expression(self.parsed_user_input, expression_tree)

    @staticmethod
    def parse_user_input(user_input: str, session: Optional[Session] = None) -> str:
        """
        Normalise the expert's input into a single Python expression.

        Top-level commas separate conditions that must all hold. Each part is parenthesised, and the parts are joined
        with ``&`` when a session is given or with ``and`` otherwise. Commas inside calls, subscripts or literals (e.g.
        ``isinstance(x, int)``) are not separators. Without top-level commas and with a session, `` and ``/`` or ``
        are rewritten to ``&``/``|``.

        :param user_input: The input given by the expert.
        :param session: The sqlalchemy orm session, if any.
        :return: The normalised expression string.
        """
        parts = CallableExpression._split_top_level_commas(user_input)
        if parts is not None:
            parts = [f"({part.strip()})" for part in parts]
            return ' & '.join(parts) if session else ' and '.join(parts)
        if session:
            user_input = user_input.replace(" and ", " & ")
            user_input = user_input.replace(" or ", " | ")
        return user_input

    @staticmethod
    def _split_top_level_commas(user_input: str) -> Optional[List[str]]:
        """
        Return the comma-separated parts of ``user_input`` if its top level is a tuple, else None.

        Input that is not valid Python falls back to a plain ``split(',')``.
        """
        if ',' not in user_input:
            return None
        source = user_input.strip()  # leading whitespace would be an indentation error for ast.parse
        try:
            body = ast.parse(source, mode='eval').body
        except (SyntaxError, ValueError):
            return user_input.split(',')
        if isinstance(body, ast.Tuple):
            return [ast.get_source_segment(source, element) for element in body.elts]
        return None

    def update_expression(self, user_input: str, expression_tree: Optional[AST] = None):
        """
        Parse ``user_input``, qualify its free names with ``case.`` and compile the result.

        :param user_input: The expression to analyse.
        :param expression_tree: An already parsed tree of ``user_input``; parsed here when not given.
        """
        if expression_tree is None:
            expression_tree = parse_string_to_expression(user_input)
        self.expression_tree: AST = expression_tree
        self.visitor = VariableVisitor()
        self.visitor.visit(expression_tree)
        variables_str = self.visitor.variables
        attributes_str = get_attributes_str(self.visitor)
        for v in variables_str | attributes_str:
            if not v.startswith("case."):
                self.parsed_user_input = self._prefix_with_case(self.parsed_user_input, v)
        self.expression_tree = parse_string_to_expression(self.parsed_user_input)
        self.compares_column_offset = [(c[0].col_offset, c[2].end_col_offset) for c in self.visitor.compares]
        self.code = compile_expression_to_code(self.expression_tree)

    @staticmethod
    def _prefix_with_case(expression: str, name: str) -> str:
        """
        Prefix every standalone occurrence of ``name`` (an identifier or dotted path) in ``expression`` with ``case.``.

        Occurrences inside a longer identifier (``air`` in ``airborne``, ``a`` in ``and``) or right after a ``.``
        (already qualified, e.g. ``case.hair``) are left alone.
        NOTE(review): occurrences inside string literals are still rewritten.
        """
        pattern = rf"(?<![\w.]){re.escape(name)}(?!\w)"
        return re.sub(pattern, lambda _match: f"case.{name}", expression)

    def __call__(self, case: Any, **kwargs) -> Any:
        """
        Evaluate the expression for ``case``; non-Row cases are first converted with ``create_row``.

        :raises ValueError: if evaluation fails or the output is not an instance of ``conclusion_type``.
        """
        try:
            if not isinstance(case, Row):
                case = create_row(case, max_recursion_idx=3)
            # The compiled code refers to the local name `case`.
            output = eval(self.code)
            # Explicit check instead of `assert`, which is stripped when running with -O.
            if self.conclusion_type and not isinstance(output, self.conclusion_type):
                raise TypeError(f"Expected output type {self.conclusion_type},"
                                f" got {type(output)}")
            return output
        except Exception as e:
            raise ValueError(f"Error during evaluation: {e}") from e

    def combine_with(self, other: 'CallableExpression') -> 'CallableExpression':
        """
        Combine this callable expression with another callable expression using the 'and' operator.
        """
        new_user_input = f"({self.user_input}) and ({other.user_input})"
        return CallableExpression(new_user_input, conclusion_type=self.conclusion_type, session=self.session)

    def __str__(self):
        """
        Return the user string where each compare is written in a line using compare column offset start and end.

        The case-qualified input is split after the end of every binary operation. Any text after the last one (e.g.
        a closing parenthesis) stays on the last line. Inputs without binary operations are returned unchanged.
        """
        user_input = self.parsed_user_input
        # The offsets come from the tree parsed from parsed_user_input itself, so they index exactly that string.
        # AST column offsets are UTF-8 byte offsets, hence the slicing of the encoded text. Assumes one line of input.
        end_offsets = sorted({node.end_col_offset for node in ast.walk(self.expression_tree)
                              if isinstance(node, ast.BinOp)})
        if not end_offsets:
            return user_input
        encoded = user_input.encode("utf-8")
        lines = []
        start = 0
        for end in end_offsets:
            lines.append(encoded[start:end].decode("utf-8"))
            start = end
        lines[-1] += encoded[start:].decode("utf-8")
        return "\n".join(lines)

    def to_json(self) -> Dict[str, Any]:
        return {**SubclassJSONSerializer.to_json(self),
                "user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type)}

    @classmethod
    def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
        # NOTE(review): to_json stores conclusion_type as a class-name string, which is passed back unchanged here, so
        # __call__'s isinstance check receives a string; presumably it should be resolved back to a type -- confirm.
        return cls(user_input=data["user_input"], conclusion_type=data["conclusion_type"])
172
+
173
+
174
def compile_expression_to_code(expression_tree: AST) -> Any:
    """
    Turn a parsed expression tree into a code object that can be run with ``eval(code)``.

    :param expression_tree: The tree produced by parsing an expression string in ``eval`` mode.
    :return: The compiled code object (its filename is ``"<string>"``).
    """
    code_object = compile(expression_tree, "<string>", "eval")
    return code_object
182
+
183
+
184
def assert_context_contains_needed_information(case: Any, context: Dict[str, Any],
                                               visitor: VariableVisitor) -> Tuple[Set[str], Set[str]]:
    """
    Verify that ``context`` provides every variable and attribute used by the expression that ``visitor`` walked.

    Variables are checked first, then the dotted attribute paths returned by ``get_attributes_str``.

    :param case: The case to check the context for (only used in the error messages).
    :param context: The context to check.
    :param visitor: The visitor that visited the expression.
    :return: The found variables and attributes.
    :raises ValueError: naming the first variable, or else attribute, that is missing from ``context``.
    """
    found_variables = set()
    for variable_name in visitor.variables:
        if variable_name in context:
            found_variables.add(variable_name)
            continue
        raise ValueError(f"Variable {variable_name} not found in the case {case}")

    found_attributes = get_attributes_str(visitor)
    missing_attribute = next((path for path in found_attributes if path not in context), None)
    if missing_attribute is not None:
        raise ValueError(f"Attribute {missing_attribute} not found in the case {case}")
    return found_variables, found_attributes
205
+
206
+
207
def get_attributes_str(visitor: VariableVisitor) -> Set[str]:
    """
    Get the string representation of the attributes in the given visitor.

    Each entry of ``visitor.attributes`` (accessed object -> attribute node) yields one dotted path that starts at a
    plain name:

    - if the accessed object is the name ``x`` itself: ``x.attr``;
    - if it is an attribute chain such as ``x.a.b`` (for ``x.a.b.c``): the path of that object, ``x.a.b``.

    Chains that do not start at a plain name (e.g. ``f().attr`` or ``x[0].attr``) are skipped; they used to raise
    AttributeError.

    :param visitor: The visitor that visited the expression.
    :return: The string representation of the attributes.
    """
    found_attributes = set()
    for accessed, attribute_node in visitor.attributes.items():
        reversed_path = []
        node = accessed
        while isinstance(node, ast.Attribute):
            reversed_path.append(node.attr)
            node = node.value
        if not isinstance(node, ast.Name):
            continue
        if not reversed_path:
            reversed_path = [attribute_node.attr]
        found_attributes.add(".".join([node.id, *reversed(reversed_path)]))
    return found_attributes
226
+
227
+
228
def parse_string_to_expression(expression_str: str) -> AST:
    """
    Parse a string statement into an AST expression.

    :param expression_str: The string which will be parsed.
    :return: The parsed expression.
    :raises SyntaxError: if ``expression_str`` is not a valid Python expression.
    """
    tree = ast.parse(expression_str, mode='eval')
    # ast.dump() walks the whole tree, and the module-level logging.debug() configures the root logger when it has no
    # handlers, so only log when debug output is actually enabled.
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        logging.debug(f"AST parsed successfully: {ast.dump(tree)}")
    return tree
@@ -0,0 +1,76 @@
1
+ from __future__ import annotations
2
+
3
+ from copy import copy, deepcopy
4
+ from dataclasses import dataclass
5
+
6
+ from sqlalchemy.orm import DeclarativeBase as SQLTable
7
+ from typing_extensions import Any, Optional, Type, Union
8
+
9
+ from .table import create_row, Case
10
+ from ..utils import get_attribute_name, copy_orm_instance_with_relationships, copy_case
11
+
12
+
13
@dataclass
class CaseQuery:
    """
    This is a dataclass that represents an attribute of an object and its target value. If attribute name is
    not provided, it will be inferred from the attribute itself or from the attribute type or from the target value,
    depending on what is provided.
    """
    case: Any
    """
    The case that the attribute belongs to.
    """
    attribute: Optional[Any] = None
    """
    The attribute itself.
    """
    target: Optional[Any] = None
    """
    The target value of the attribute.
    """
    attribute_name: Optional[str] = None
    """
    The name of the attribute.
    """
    attribute_type: Optional[Type] = None
    """
    The type of the attribute.
    """
    relational_representation: Optional[str] = None
    """
    The representation of the target value in relational form.
    """

    def __init__(self, case: Any, attribute: Optional[Any] = None, target: Optional[Any] = None,
                 attribute_name: Optional[str] = None, attribute_type: Optional[Type] = None,
                 relational_representation: Optional[str] = None):
        """
        :param case: The case; anything that is not a Case or ORM table instance is converted with ``create_row``.
        :param attribute: Only used to infer ``attribute_name``; the stored attribute is read from the case.
        :param target: The target value of the attribute.
        :param attribute_name: The attribute name; inferred with ``get_attribute_name`` when None.
        :param attribute_type: The attribute type; used when the attribute's value on the case is None.
        :param relational_representation: The target in relational form, used by ``__str__``.
        """
        if attribute_name is None:
            attribute_name = get_attribute_name(case, attribute, attribute_type, target)
        self.attribute_name = attribute_name

        if not isinstance(case, (Case, SQLTable)):
            case = create_row(case, max_recursion_idx=3)
        self.case = case

        self.attribute = getattr(self.case, self.attribute_name) if self.attribute_name else None
        # Falsy values (False, 0, "", empty containers) still have a type; fall back to the caller-provided type only
        # when there is no value to take it from.
        self.attribute_type = type(self.attribute) if self.attribute is not None else attribute_type
        self.target = target
        self.relational_representation = relational_representation

    @property
    def targets(self):
        """Backward-compatible alias of ``target`` (the field used to be declared as ``targets``)."""
        return self.target

    @property
    def name(self):
        return self.attribute_name if self.attribute_name else self.__class__.__name__

    def __str__(self):
        if self.relational_representation:
            return f"{self.name} |= {self.relational_representation}"
        else:
            return f"{self.target}"

    def __repr__(self):
        return self.__str__()

    def __copy__(self):
        # Carry over every constructor-level setting so the copy renders and compares like the original.
        return CaseQuery(copy_case(self.case), attribute_name=self.attribute_name, target=self.target,
                         attribute_type=self.attribute_type,
                         relational_representation=self.relational_representation)
@@ -0,0 +1,173 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import auto, Enum
4
+
5
+ from typing_extensions import List
6
+
7
+
8
class Category(str, Enum):
    """
    Base class for string-valued category enums; subclasses declare their members as ``name = "value"``.
    """

    @classmethod
    def from_str(cls, value: str) -> Category:
        """
        Return the member whose name is ``value``.

        :raises AttributeError: if ``value`` is not a member name. AttributeError is kept for compatibility with the
            former ``getattr`` lookup, which also returned non-members such as methods (e.g. ``"upper"``).
        """
        try:
            return cls[value]
        except KeyError:
            raise AttributeError(f"{cls.__name__} has no member named {value!r}") from None

    @classmethod
    def from_strs(cls, values: List[str]) -> List[Category]:
        """Return the members named by ``values``, in order."""
        return [cls.from_str(value) for value in values]

    @property
    def as_dict(self):
        """A one-entry dict mapping the lower-cased class name to this member's value."""
        return {self.__class__.__name__.lower(): self.value}
21
+
22
+
23
class Stop(Category):
    """
    A stop category is a special category that represents the stopping of the classification to prevent a wrong
    conclusion from being made.
    """
    # The only member; its value is the string "stop".
    stop = "stop"
29
+
30
+
31
class ExpressionParser(Enum):
    """
    Parsers for expressions to evaluate and encapsulate the expression into a callable function.

    Member values are generated with ``auto()`` and only serve to distinguish the members.
    """
    ASTVisitor: int = auto()
    """
    Generic python Abstract Syntax Tree that detects variables, attributes, binary/boolean expressions , ...etc.
    """
    SQLAlchemy: int = auto()
    """
    Specific for SQLAlchemy expressions on ORM Tables.
    """
43
+
44
+
45
class PromptFor(Enum):
    """
    What the expert is being prompted for: the conditions of a rule or its conclusion.

    Both ``str()`` and ``repr()`` of a member give the bare member name (e.g. ``Conditions``).
    """
    # Prompt for rule conditions about a case.
    Conditions: str = "conditions"
    # Prompt for rule conclusion about a case.
    Conclusion: str = "conclusion"

    def __str__(self):
        member_name = self.name
        return member_name

    def __repr__(self):
        return str(self)
63
+
64
+
65
class CategoricalValue(Enum):
    """
    Base enum whose members compare and hash by their name.

    A member equals another CategoricalValue with the same name, or any object equal to its name (e.g. the plain
    string ``"apple"``). ``str()``/``repr()`` give the member name; ``from_str`` looks members up by lower-cased name.
    """

    def __eq__(self, other):
        other_name = other.name if isinstance(other, CategoricalValue) else other
        return self.name == other_name

    def __hash__(self):
        return hash(self.name)

    @classmethod
    def to_list(cls):
        # The member values, as registered in the enum's value-to-member map.
        return [value for value in cls._value2member_map_]

    @classmethod
    def from_str(cls, category: str):
        member_name = category.lower()
        return cls[member_name]

    @classmethod
    def from_strs(cls, categories: List[str]):
        return list(map(cls.from_str, categories))

    def __str__(self):
        return self.name

    def __repr__(self):
        return str(self)
97
+
98
+
99
class RDRMode(Enum):
    """
    Whether the rules of a ripple-down-rules classifier are propositional or relational.

    Member values are generated with ``auto()`` and only serve to distinguish the members.
    """
    Propositional = auto()
    """
    Propositional mode, the mode where the rules are propositional.
    """
    Relational = auto()
    """
    Relational mode, the mode where the rules are relational.
    """
108
+
109
+
110
class MCRDRMode(Enum):
    """
    The modes of the MultiClassRDR.

    They differ in what is added when a wrong conclusion has to be stopped (see the member descriptions).
    """
    StopOnly = auto()
    """
    StopOnly mode, stop wrong conclusion from being made and does not add a new rule to make the correct conclusion.
    """
    StopPlusRule = auto()
    """
    StopPlusRule mode, stop wrong conclusion from being made and adds a new rule with same conditions as stopping rule
    to make the correct conclusion.
    """
    StopPlusRuleCombined = auto()
    """
    StopPlusRuleCombined mode, stop wrong conclusion from being made and adds a new rule with combined conditions of
    stopping rule and the rule that should have fired.
    """
128
+
129
+
130
class RDREdge(Enum):
    """
    The kinds of edges between rules of a ripple-down-rules tree.

    Each member's value is a short text ("except if", "else if", "next").
    """
    Refinement = "except if"
    """
    Refinement edge, the edge that represents the refinement of an incorrectly fired rule.
    """
    Alternative = "else if"
    """
    Alternative edge, the edge that represents the alternative to the rule that has not fired.
    """
    Next = "next"
    """
    Next edge, the edge that represents the next rule to be evaluated.
    """
142
+ """
143
+
144
+
145
class ValueType(Enum):
    """
    The kinds of values an attribute can take (unary, binary, discrete, continuous, nominal, ordinal, iterable).

    Member values are generated with ``auto()`` and only serve to distinguish the members.
    """
    Unary = auto()
    """
    Unary value type (eg. null).
    """
    Binary = auto()
    """
    Binary value type (eg. True, False).
    """
    Discrete = auto()
    """
    Discrete value type (eg. 1, 2, 3).
    """
    Continuous = auto()
    """
    Continuous value type (eg. 1.0, 2.5, 3.4).
    """
    Nominal = auto()
    """
    Nominal value type (eg. red, blue, green), categories where the values have no natural order.
    """
    Ordinal = auto()
    """
    Ordinal value type (eg. low, medium, high), categories where the values have a natural order.
    """
    Iterable = auto()
    """
    Iterable value type (eg. [1, 2, 3]).
    """
File without changes