ripple-down-rules 0.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules-0.0.0/PKG-INFO +54 -0
- ripple_down_rules-0.0.0/README.md +42 -0
- ripple_down_rules-0.0.0/pyproject.toml +23 -0
- ripple_down_rules-0.0.0/setup.cfg +4 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/__init__.py +0 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/datasets.py +148 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/datastructures/__init__.py +4 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/datastructures/callable_expression.py +237 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/datastructures/dataclasses.py +76 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/datastructures/enums.py +173 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/datastructures/generated/__init__.py +0 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/datastructures/generated/column/__init__.py +0 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/datastructures/generated/row/__init__.py +0 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/datastructures/table.py +544 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/experts.py +281 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/failures.py +10 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/prompt.py +101 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/rdr.py +687 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/rules.py +260 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules/utils.py +463 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules.egg-info/PKG-INFO +54 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules.egg-info/SOURCES.txt +29 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules.egg-info/dependency_links.txt +1 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules.egg-info/requires.txt +1 -0
- ripple_down_rules-0.0.0/src/ripple_down_rules.egg-info/top_level.txt +1 -0
- ripple_down_rules-0.0.0/test/test_json_serialization.py +38 -0
- ripple_down_rules-0.0.0/test/test_rdr.py +309 -0
- ripple_down_rules-0.0.0/test/test_rdr_alchemy.py +154 -0
- ripple_down_rules-0.0.0/test/test_relational_rdr.py +108 -0
- ripple_down_rules-0.0.0/test/test_relational_rdr_alchemy.py +174 -0
- ripple_down_rules-0.0.0/test/test_sql_model.py +34 -0
@@ -0,0 +1,54 @@
+Metadata-Version: 2.4
+Name: ripple_down_rules
+Version: 0.0.0
+Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
+Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
+Project-URL: Homepage, https://github.com/AbdelrhmanBassiouny/ripple_down_rules
+Keywords: robotics,knowledge,reasoning,representation
+Classifier: Programming Language :: Python :: 3
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+Requires-Dist: neem_pycram_interface==1.0.167
+
+# Ripple Down Rules (RDR)
+
+A Python implementation of the various ripple down rules versions, including Single Classification (SCRDR),
+Multi Classification (MCRDR), and Generalised Ripple Down Rules (GRDR).
+
+SCRDR, MCRDR, and GRDR are rule-based classifiers that are built incrementally and can be used to classify
+data cases. The rules are refined as new data cases are classified.
+
+The SCRDR, MCRDR, and GRDR implementations were inspired by the book:
+["Ripple Down Rules: An Alternative to Machine Learning"](https://doi.org/10.1201/9781003126157) by Paul Compton and Byeong Ho Kang.
+
+## Installation
+
+```bash
+sudo apt-get install graphviz graphviz-dev
+pip install ripple_down_rules
+```
+
+## Example Usage
+
+Fit the SCRDR to the data, then classify one of the data cases to check that it is correct,
+and render the tree to a file:
+
+```Python
+from ripple_down_rules.rdr import SingleClassRDR
+from ripple_down_rules.datasets import load_zoo_dataset
+from ripple_down_rules.utils import render_tree
+
+all_cases, targets = load_zoo_dataset()
+
+scrdr = SingleClassRDR()
+
+# Fit the SCRDR to the data
+scrdr.fit(all_cases, targets,
+          animate_tree=True, n_iter=10)
+
+# Render the tree to a file
+render_tree(scrdr.start_rule, use_dot_exporter=True, filename="scrdr")
+
+cat = scrdr.fit_case(all_cases[50], targets[50])
+assert cat == targets[50]
+```
@@ -0,0 +1,42 @@
+# Ripple Down Rules (RDR)
+
+A Python implementation of the various ripple down rules versions, including Single Classification (SCRDR),
+Multi Classification (MCRDR), and Generalised Ripple Down Rules (GRDR).
+
+SCRDR, MCRDR, and GRDR are rule-based classifiers that are built incrementally and can be used to classify
+data cases. The rules are refined as new data cases are classified.
+
+The SCRDR, MCRDR, and GRDR implementations were inspired by the book:
+["Ripple Down Rules: An Alternative to Machine Learning"](https://doi.org/10.1201/9781003126157) by Paul Compton and Byeong Ho Kang.
+
+## Installation
+
+```bash
+sudo apt-get install graphviz graphviz-dev
+pip install ripple_down_rules
+```
+
+## Example Usage
+
+Fit the SCRDR to the data, then classify one of the data cases to check that it is correct,
+and render the tree to a file:
+
+```Python
+from ripple_down_rules.rdr import SingleClassRDR
+from ripple_down_rules.datasets import load_zoo_dataset
+from ripple_down_rules.utils import render_tree
+
+all_cases, targets = load_zoo_dataset()
+
+scrdr = SingleClassRDR()
+
+# Fit the SCRDR to the data
+scrdr.fit(all_cases, targets,
+          animate_tree=True, n_iter=10)
+
+# Render the tree to a file
+render_tree(scrdr.start_rule, use_dot_exporter=True, filename="scrdr")
+
+cat = scrdr.fit_case(all_cases[50], targets[50])
+assert cat == targets[50]
+```
@@ -0,0 +1,23 @@
+# pyproject.toml
+
+[build-system]
+requires = ["setuptools>=61.0.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "ripple_down_rules"
+version = "0.0.0"
+description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
+readme = "README.md"
+authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
+license = { file = "LICENSE" }
+classifiers = [
+    "Programming Language :: Python :: 3",
+]
+keywords = ["robotics", "knowledge", "reasoning", "representation"]
+dependencies = ["neem_pycram_interface==1.0.167",
+]
+requires-python = ">=3.8"
+
+[project.urls]
+Homepage = "https://github.com/AbdelrhmanBassiouny/ripple_down_rules"
File without changes
@@ -0,0 +1,148 @@
+from __future__ import annotations
+
+import os
+import pickle
+
+import sqlalchemy
+from sqlalchemy import ForeignKey
+from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationship
+from typing_extensions import Tuple, List, Set
+from ucimlrepo import fetch_ucirepo
+
+from .datastructures import Case, create_rows_from_dataframe, Category, Column
+
+
+def load_cached_dataset(cache_file):
+    """Loads the dataset from cache if it exists."""
+    dataset = {}
+    for key in ["features", "targets", "ids"]:
+        part_file = cache_file.replace(".pkl", f"_{key}.pkl")
+        if not os.path.exists(part_file):
+            return None
+        with open(part_file, "rb") as f:
+            dataset[key] = pickle.load(f)
+    return dataset
+
+
+def save_dataset_to_cache(dataset, cache_file):
+    """Saves only essential parts of the dataset to cache."""
+    dataset_to_cache = {
+        "features": dataset.data.features,
+        "targets": dataset.data.targets,
+        "ids": dataset.data.ids,
+    }
+
+    for key, value in dataset_to_cache.items():
+        with open(cache_file.replace(".pkl", f"_{key}.pkl"), "wb") as f:
+            pickle.dump(dataset_to_cache[key], f)
+    print("Dataset cached successfully.")
+
+
+def get_dataset(dataset_id, cache_file):
+    """Fetches dataset from cache or downloads it if not available."""
+    dataset = load_cached_dataset(cache_file)
+    if dataset is None:
+        print("Downloading dataset...")
+        dataset = fetch_ucirepo(id=dataset_id)
+
+        # Check if dataset is valid before caching
+        if dataset is None or not hasattr(dataset, "data"):
+            print("Error: Failed to fetch dataset.")
+            return None
+
+        save_dataset_to_cache(dataset, cache_file)
+
+    return dataset
+
+
+def load_zoo_dataset(cache_file: str) -> Tuple[List[Case], List[Species]]:
+    """
+    Load the zoo dataset.
+
+    :param cache_file: the cache file.
+    :return: all cases and targets.
+    """
+    # fetch dataset
+    zoo = get_dataset(111, cache_file)
+
+    # data (as pandas dataframes)
+    X = zoo['features']
+    y = zoo['targets']
+    # get ids as list of strings
+    ids = zoo['ids'].values.flatten()
+    all_cases = create_rows_from_dataframe(X, "Animal")
+
+    category_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "molusc"]
+    category_id_to_name = {i + 1: name for i, name in enumerate(category_names)}
+    targets = [getattr(SpeciesCol, category_id_to_name[i]) for i in y.values.flatten()]
+    return all_cases, targets
+
+
+class Species(Category):
+    mammal = "mammal"
+    bird = "bird"
+    reptile = "reptile"
+    fish = "fish"
+    amphibian = "amphibian"
+    insect = "insect"
+    molusc = "molusc"
+
+
+class Habitat(Category):
+    """
+    A habitat category is a category that represents the habitat of an animal.
+    """
+    land = "land"
+    water = "water"
+    air = "air"
+
+
+SpeciesCol = Column.create_from_enum(Species, mutually_exclusive=True)
+HabitatCol = Column.create_from_enum(Habitat, mutually_exclusive=False)
+
+
+class Base(sqlalchemy.orm.DeclarativeBase):
+    pass
+
+
+class HabitatTable(MappedAsDataclass, Base):
+    __tablename__ = "Habitat"
+
+    id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
+    habitat: Mapped[Habitat]
+    animal_id = mapped_column(ForeignKey("Animal.id"), init=False)
+
+    def __hash__(self):
+        return hash(self.habitat)
+
+    def __str__(self):
+        return self.habitat.value
+
+    def __repr__(self):
+        return self.__str__()
+
+
+class Animal(MappedAsDataclass, Base):
+    __tablename__ = "Animal"
+
+    id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
+    name: Mapped[str]
+    hair: Mapped[bool]
+    feathers: Mapped[bool]
+    eggs: Mapped[bool]
+    milk: Mapped[bool]
+    airborne: Mapped[bool]
+    aquatic: Mapped[bool]
+    predator: Mapped[bool]
+    toothed: Mapped[bool]
+    backbone: Mapped[bool]
+    breathes: Mapped[bool]
+    venomous: Mapped[bool]
+    fins: Mapped[bool]
+    legs: Mapped[int]
+    tail: Mapped[bool]
+    domestic: Mapped[bool]
+    catsize: Mapped[bool]
+    species: Mapped[Species] = mapped_column(nullable=True)
+
+    habitats: Mapped[Set[HabitatTable]] = relationship(default_factory=set)
@@ -0,0 +1,237 @@
+from __future__ import annotations
+
+import ast
+import logging
+from _ast import AST
+
+from sqlalchemy.orm import Session
+from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set
+
+from .table import create_row, Row
+from ..utils import SubclassJSONSerializer, get_full_class_name
+
+
+class VariableVisitor(ast.NodeVisitor):
+    """
+    A visitor to extract all variables and comparisons from a python expression represented as an AST tree.
+    """
+    compares: List[Tuple[Union[ast.Name, ast.Call], ast.cmpop, Union[ast.Name, ast.Call]]]
+    variables: Set[str]
+    all: List[ast.BoolOp]
+
+    def __init__(self):
+        self.variables = set()
+        self.attributes: Dict[ast.Name, ast.Attribute] = {}
+        self.compares = list()
+        self.binary_ops = list()
+        self.all = list()
+
+    def visit_Attribute(self, node):
+        self.all.append(node)
+        self.attributes[node.value] = node
+        self.generic_visit(node)
+
+    def visit_BinOp(self, node):
+        self.binary_ops.append(node)
+        self.all.append(node)
+        self.generic_visit(node)
+
+    def visit_BoolOp(self, node):
+        self.all.append(node)
+        self.generic_visit(node)
+
+    def visit_Compare(self, node):
+        self.all.append(node)
+        self.compares.append([node.left, node.ops[0], node.comparators[0]])
+        self.generic_visit(node)
+
+    def visit_Name(self, node):
+        if f"__{node.id}__" not in dir(__builtins__) and node not in self.attributes:
+            self.variables.add(node.id)
+        self.generic_visit(node)
+
+
+class CallableExpression(SubclassJSONSerializer):
+    """
+    A callable that is constructed from a string statement written by an expert.
+    """
+    conclusion_type: Type
+    """
+    The type of the output of the callable, used for assertion.
+    """
+    expression_tree: AST
+    """
+    The AST tree parsed from the user input.
+    """
+    user_input: str
+    """
+    The input given by the expert.
+    """
+    session: Optional[Session]
+    """
+    The sqlalchemy orm session.
+    """
+    visitor: VariableVisitor
+    """
+    A visitor to extract all variables and comparisons from a python expression represented as an AST tree.
+    """
+    code: Any
+    """
+    The code that was compiled from the expression tree.
+    """
+    compares_column_offset: List[int]
+    """
+    The start and end indices of each comparison in the string of user input.
+    """
+
+    def __init__(self, user_input: str, conclusion_type: Optional[Type] = None, expression_tree: Optional[AST] = None,
+                 session: Optional[Session] = None):
+        """
+        Create a callable expression.
+
+        :param user_input: The input given by the expert.
+        :param conclusion_type: The type of the output of the callable.
+        :param expression_tree: The AST tree parsed from the user input.
+        :param session: The sqlalchemy orm session.
+        """
+        self.session = session
+        self.user_input: str = user_input
+        self.parsed_user_input = self.parse_user_input(user_input, session)
+        self.conclusion_type = conclusion_type
+        self.update_expression(self.parsed_user_input, expression_tree)
+
+    @staticmethod
+    def parse_user_input(user_input: str, session: Optional[Session] = None) -> str:
+        if ',' in user_input:
+            user_input = user_input.split(',')
+            user_input = [f"({u.strip()})" for u in user_input]
+            user_input = ' & '.join(user_input) if session else ' and '.join(user_input)
+        elif session:
+            user_input = user_input.replace(" and ", " & ")
+            user_input = user_input.replace(" or ", " | ")
+        return user_input
+
+    def update_expression(self, user_input: str, expression_tree: Optional[AST] = None):
+        if not expression_tree:
+            expression_tree = parse_string_to_expression(user_input)
+        self.expression_tree: AST = expression_tree
+        self.visitor = VariableVisitor()
+        self.visitor.visit(expression_tree)
+        variables_str = self.visitor.variables
+        attributes_str = get_attributes_str(self.visitor)
+        for v in variables_str | attributes_str:
+            if not v.startswith("case."):
+                self.parsed_user_input = self.parsed_user_input.replace(v, f"case.{v}")
+        self.expression_tree = parse_string_to_expression(self.parsed_user_input)
+        self.compares_column_offset = [(c[0].col_offset, c[2].end_col_offset) for c in self.visitor.compares]
+        self.code = compile_expression_to_code(self.expression_tree)
+
+    def __call__(self, case: Any, **kwargs) -> Any:
+        try:
+            if not isinstance(case, Row):
+                case = create_row(case, max_recursion_idx=3)
+            output = eval(self.code)
+            if self.conclusion_type:
+                assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
+                                                                  f" got {type(output)}")
+            return output
+        except Exception as e:
+            raise ValueError(f"Error during evaluation: {e}")
+
+    def combine_with(self, other: 'CallableExpression') -> 'CallableExpression':
+        """
+        Combine this callable expression with another callable expression using the 'and' operator.
+        """
+        new_user_input = f"({self.user_input}) and ({other.user_input})"
+        return CallableExpression(new_user_input, conclusion_type=self.conclusion_type, session=self.session)
+
+    def __str__(self):
+        """
+        Return the user string where each compare is written in a line using compare column offset start and end.
+        """
+        user_input = self.parsed_user_input
+        binary_ops = sorted(self.visitor.binary_ops, key=lambda x: x.end_col_offset)
+        binary_ops_indices = [b.end_col_offset for b in binary_ops]
+        all_binary_ops = []
+        prev_e = 0
+        for i, e in enumerate(binary_ops_indices):
+            if i == 0:
+                all_binary_ops.append(user_input[:e])
+            else:
+                all_binary_ops.append(user_input[prev_e:e])
+            prev_e = e
+        return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else user_input
+
+    def to_json(self) -> Dict[str, Any]:
+        return {**SubclassJSONSerializer.to_json(self),
+                "user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type)}
+
+    @classmethod
+    def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
+        return cls(user_input=data["user_input"], conclusion_type=data["conclusion_type"])
+
+
+def compile_expression_to_code(expression_tree: AST) -> Any:
+    """
+    Compile an expression tree that was parsed from string into code that can be executed using 'eval(code)'.
+
+    :param expression_tree: The parsed expression tree.
+    :return: The code that was compiled from the expression tree.
+    """
+    return compile(expression_tree, filename="<string>", mode="eval")
+
+
+def assert_context_contains_needed_information(case: Any, context: Dict[str, Any],
+                                               visitor: VariableVisitor) -> Tuple[Set[str], Set[str]]:
+    """
+    Asserts that the variables mentioned in the expression visited by visitor are all in the given context.
+
+    :param case: The case to check the context for.
+    :param context: The context to check.
+    :param visitor: The visitor that visited the expression.
+    :return: The found variables and attributes.
+    """
+    found_variables = set()
+    for key in visitor.variables:
+        if key not in context:
+            raise ValueError(f"Variable {key} not found in the case {case}")
+        found_variables.add(key)
+
+    found_attributes = get_attributes_str(visitor)
+    for attr in found_attributes:
+        if attr not in context:
+            raise ValueError(f"Attribute {attr} not found in the case {case}")
+    return found_variables, found_attributes
+
+
+def get_attributes_str(visitor: VariableVisitor) -> Set[str]:
+    """
+    Get the string representation of the attributes in the given visitor.
+
+    :param visitor: The visitor that visited the expression.
+    :return: The string representation of the attributes.
+    """
+    found_attributes = set()
+    for key, ast_attr in visitor.attributes.items():
+        str_attr = ""
+        while isinstance(key, ast.Attribute):
+            if len(str_attr) > 0:
+                str_attr = f"{key.attr}.{str_attr}"
+            else:
+                str_attr = key.attr
+            key = key.value
+        str_attr = f"{key.id}.{str_attr}" if len(str_attr) > 0 else f"{key.id}.{ast_attr.attr}"
+        found_attributes.add(str_attr)
+    return found_attributes
+
+
+def parse_string_to_expression(expression_str: str) -> AST:
+    """
+    Parse a string statement into an AST expression.
+
+    :param expression_str: The string which will be parsed.
+    :return: The parsed expression.
+    """
+    tree = ast.parse(expression_str, mode='eval')
+    logging.debug(f"AST parsed successfully: {ast.dump(tree)}")
+    return tree
@@ -0,0 +1,76 @@
+from __future__ import annotations
+
+from copy import copy, deepcopy
+from dataclasses import dataclass
+
+from sqlalchemy.orm import DeclarativeBase as SQLTable
+from typing_extensions import Any, Optional, Type, Union
+
+from .table import create_row, Case
+from ..utils import get_attribute_name, copy_orm_instance_with_relationships, copy_case
+
+
+@dataclass
+class CaseQuery:
+    """
+    This is a dataclass that represents an attribute of an object and its target value. If the attribute name is
+    not provided, it will be inferred from the attribute itself, the attribute type, or the target value,
+    depending on what is provided.
+    """
+    case: Any
+    """
+    The case that the attribute belongs to.
+    """
+    attribute: Optional[Any] = None
+    """
+    The attribute itself.
+    """
+    targets: Optional[Any] = None
+    """
+    The target value of the attribute.
+    """
+    attribute_name: Optional[str] = None
+    """
+    The name of the attribute.
+    """
+    attribute_type: Optional[Type] = None
+    """
+    The type of the attribute.
+    """
+    relational_representation: Optional[str] = None
+    """
+    The representation of the target value in relational form.
+    """
+
+    def __init__(self, case: Any, attribute: Optional[Any] = None, target: Optional[Any] = None,
+                 attribute_name: Optional[str] = None, attribute_type: Optional[Type] = None,
+                 relational_representation: Optional[str] = None):
+
+        if attribute_name is None:
+            attribute_name = get_attribute_name(case, attribute, attribute_type, target)
+        self.attribute_name = attribute_name
+
+        if not isinstance(case, (Case, SQLTable)):
+            case = create_row(case, max_recursion_idx=3)
+        self.case = case
+
+        self.attribute = getattr(self.case, self.attribute_name) if self.attribute_name else None
+        self.attribute_type = type(self.attribute) if self.attribute else None
+        self.target = target
+        self.relational_representation = relational_representation
+
+    @property
+    def name(self):
+        return self.attribute_name if self.attribute_name else self.__class__.__name__
+
+    def __str__(self):
+        if self.relational_representation:
+            return f"{self.name} |= {self.relational_representation}"
+        else:
+            return f"{self.target}"
+
+    def __repr__(self):
+        return self.__str__()
+
+    def __copy__(self):
+        return CaseQuery(copy_case(self.case), attribute_name=self.attribute_name, target=self.target)
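
`CaseQuery` bundles a case with the attribute being asked about and the expert-provided target value. A hedged construction sketch, not from the package, reusing the `Animal` model from datasets.py above (the import paths follow the package layout in the file list):

```Python
# Hedged sketch: build a CaseQuery asking for the species of an unlabeled case.
from ripple_down_rules.datasets import Animal, Species
from ripple_down_rules.datastructures.dataclasses import CaseQuery

sparrow = Animal(name="sparrow", hair=False, feathers=True, eggs=True, milk=False,
                 airborne=True, aquatic=False, predator=False, toothed=False,
                 backbone=True, breathes=True, venomous=False, fins=False,
                 legs=2, tail=True, domestic=False, catsize=False, species=None)

# attribute_name is given explicitly, so no inference via get_attribute_name;
# the SQLAlchemy-mapped case also skips the create_row conversion.
query = CaseQuery(sparrow, attribute_name="species", target=Species.bird)
print(query)  # __str__ falls back to the target when no relational form is set
```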
|