ripple-down-rules 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +1 -1
- ripple_down_rules/datastructures/callable_expression.py +2 -1
- ripple_down_rules/experts.py +11 -6
- ripple_down_rules/rdr.py +4 -2
- ripple_down_rules/utils.py +1 -1
- {ripple_down_rules-0.5.0.dist-info → ripple_down_rules-0.5.2.dist-info}/METADATA +140 -14
- {ripple_down_rules-0.5.0.dist-info → ripple_down_rules-0.5.2.dist-info}/RECORD +10 -11
- ripple_down_rules/datasets.py +0 -222
- {ripple_down_rules-0.5.0.dist-info → ripple_down_rules-0.5.2.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.5.0.dist-info → ripple_down_rules-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.5.0.dist-info → ripple_down_rules-0.5.2.dist-info}/top_level.txt +0 -0
ripple_down_rules/__init__.py
CHANGED
@@ -92,7 +92,8 @@ class CallableExpression(SubclassJSONSerializer):
|
|
92
92
|
"""
|
93
93
|
A callable that is constructed from a string statement written by an expert.
|
94
94
|
"""
|
95
|
-
|
95
|
+
encapsulating_function_name: str = "_get_value"
|
96
|
+
encapsulating_function: str = f"def {encapsulating_function_name}(case):"
|
96
97
|
|
97
98
|
def __init__(self, user_input: Optional[str] = None,
|
98
99
|
conclusion_type: Optional[Tuple[Type]] = None,
|
ripple_down_rules/experts.py
CHANGED
@@ -4,6 +4,7 @@ import ast
|
|
4
4
|
import json
|
5
5
|
import logging
|
6
6
|
import os
|
7
|
+
import uuid
|
7
8
|
from abc import ABC, abstractmethod
|
8
9
|
|
9
10
|
from typing_extensions import Optional, TYPE_CHECKING, List
|
@@ -145,7 +146,8 @@ class Expert(ABC):
|
|
145
146
|
else:
|
146
147
|
imports = ''
|
147
148
|
if func_source is not None:
|
148
|
-
|
149
|
+
uid = uuid.uuid4().hex
|
150
|
+
func_source = encapsulate_user_input(func_source, CallableExpression.encapsulating_function + f'_{uid}')
|
149
151
|
else:
|
150
152
|
func_source = 'pass # No user input provided for this case.\n'
|
151
153
|
f.write(imports + func_source + '\n' + '\n\n\n\'===New Answer===\'\n\n\n')
|
@@ -185,14 +187,17 @@ class Expert(ABC):
|
|
185
187
|
"""
|
186
188
|
file_path = path + '.py'
|
187
189
|
with open(file_path, "r") as f:
|
188
|
-
all_answers = f.read().split('\n\n\n\'===New Answer===\'\n\n\n')
|
189
|
-
|
190
|
+
all_answers = f.read().split('\n\n\n\'===New Answer===\'\n\n\n')[:-1]
|
191
|
+
all_function_sources = list(extract_function_source(file_path, []).values())
|
192
|
+
all_function_sources_names = list(extract_function_source(file_path, []).keys())
|
193
|
+
for i, answer in enumerate(all_answers):
|
190
194
|
answer = answer.strip('\n').strip()
|
191
195
|
if 'def ' not in answer and 'pass' in answer:
|
192
196
|
self.all_expert_answers.append(({}, None))
|
193
197
|
scope = extract_imports(tree=ast.parse(answer))
|
194
|
-
|
195
|
-
|
198
|
+
function_source = all_function_sources[i].replace(all_function_sources_names[i],
|
199
|
+
CallableExpression.encapsulating_function_name)
|
200
|
+
self.all_expert_answers.append((scope, function_source))
|
196
201
|
|
197
202
|
|
198
203
|
class Human(Expert):
|
@@ -212,7 +217,7 @@ class Human(Expert):
|
|
212
217
|
def ask_for_conditions(self, case_query: CaseQuery,
|
213
218
|
last_evaluated_rule: Optional[Rule] = None) \
|
214
219
|
-> CallableExpression:
|
215
|
-
if not self.use_loaded_answers and self.user_prompt.viewer is None:
|
220
|
+
if (not self.use_loaded_answers or len(self.all_expert_answers) == 0) and self.user_prompt.viewer is None:
|
216
221
|
show_current_and_corner_cases(case_query.case, {case_query.attribute_name: case_query.target_value},
|
217
222
|
last_evaluated_rule=last_evaluated_rule)
|
218
223
|
return self._get_conditions(case_query)
|
ripple_down_rules/rdr.py
CHANGED
@@ -191,8 +191,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
191
191
|
match = is_matching(self.classify, case_query, pred_cat)
|
192
192
|
if not match:
|
193
193
|
print(f"Predicted: {pred_cat} but expected: {target}")
|
194
|
-
if animate_tree and self.start_rule.
|
195
|
-
num_rules = self.start_rule.
|
194
|
+
if animate_tree and len(self.start_rule.descendants) > num_rules:
|
195
|
+
num_rules = len(self.start_rule.descendants)
|
196
196
|
self.update_figures()
|
197
197
|
i += 1
|
198
198
|
all_predictions = [1 if is_matching(self.classify, case_query) else 0 for case_query in case_queries
|
@@ -241,6 +241,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
241
241
|
self.case_type = case_query.case_type if self.case_type is None else self.case_type
|
242
242
|
self.case_name = case_query.case_name if self.case_name is None else self.case_name
|
243
243
|
|
244
|
+
expert = expert or Human(answers_save_path=self.save_dir + '/expert_answers' if self.save_dir else None)
|
245
|
+
|
244
246
|
if case_query.target is None:
|
245
247
|
case_query_cp = copy(case_query)
|
246
248
|
conclusions = self.classify(case_query_cp.case, modify_case=True)
|
ripple_down_rules/utils.py
CHANGED
@@ -186,7 +186,7 @@ def extract_function_source(file_path: str,
|
|
186
186
|
func_lines = func_lines[1:]
|
187
187
|
line_numbers.append((node.lineno, node.end_lineno))
|
188
188
|
functions_source[node.name] = dedent("\n".join(func_lines)) if join_lines else func_lines
|
189
|
-
if len(functions_source) >= len(function_names):
|
189
|
+
if (len(functions_source) >= len(function_names)) and (not len(function_names) == 0):
|
190
190
|
break
|
191
191
|
if len(functions_source) < len(function_names):
|
192
192
|
raise ValueError(f"Could not find all functions in {file_path}: {function_names} not found,"
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.5.
|
3
|
+
Version: 0.5.2
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -693,6 +693,7 @@ Requires-Dist: colorama
|
|
693
693
|
Requires-Dist: pygments
|
694
694
|
Requires-Dist: sqlalchemy
|
695
695
|
Requires-Dist: pandas
|
696
|
+
Requires-Dist: pyparsing
|
696
697
|
Provides-Extra: viz
|
697
698
|
Requires-Dist: networkx>=3.1; extra == "viz"
|
698
699
|
Requires-Dist: matplotlib>=3.7.5; extra == "viz"
|
@@ -724,36 +725,161 @@ For GUI support, also install:
|
|
724
725
|
sudo apt-get install libxcb-cursor-dev
|
725
726
|
```
|
726
727
|
|
727
|
-
```bash
|
728
|
-
|
729
728
|
## Example Usage
|
730
729
|
|
731
|
-
|
732
|
-
|
730
|
+
### Propositional Example
|
731
|
+
|
732
|
+
By propositional, I mean that each rule conclusion is a propositional logic statement with a constant value.
|
733
|
+
|
734
|
+
For this example, we will use the [UCI Zoo dataset](https://archive.ics.uci.edu/ml/datasets/zoo) to classify animals
|
735
|
+
into their species based on their features. The dataset contains 101 animals with 16 features, and the target is
|
736
|
+
the species of the animal.
|
737
|
+
|
738
|
+
To install the dataset:
|
739
|
+
```bash
|
740
|
+
pip install ucimlrepo
|
733
741
|
```
|
734
742
|
|
735
743
|
```python
|
744
|
+
from __future__ import annotations
|
736
745
|
from ripple_down_rules.datastructures.dataclasses import CaseQuery
|
737
|
-
from ripple_down_rules.
|
738
|
-
from ripple_down_rules.
|
746
|
+
from ripple_down_rules.datastructures.case import create_cases_from_dataframe
|
747
|
+
from ripple_down_rules.rdr import GeneralRDR
|
739
748
|
from ripple_down_rules.utils import render_tree
|
749
|
+
from ucimlrepo import fetch_ucirepo
|
750
|
+
from enum import Enum
|
740
751
|
|
741
|
-
|
752
|
+
class Species(str, Enum):
|
753
|
+
"""Enum for the species of the animals in the UCI Zoo dataset."""
|
754
|
+
mammal = "mammal"
|
755
|
+
bird = "bird"
|
756
|
+
reptile = "reptile"
|
757
|
+
fish = "fish"
|
758
|
+
amphibian = "amphibian"
|
759
|
+
insect = "insect"
|
760
|
+
molusc = "molusc"
|
761
|
+
|
762
|
+
@classmethod
|
763
|
+
def from_str(cls, value: str) -> Species:
|
764
|
+
return getattr(cls, value)
|
742
765
|
|
743
|
-
|
766
|
+
# fetch dataset
|
767
|
+
zoo = fetch_ucirepo(id=111)
|
744
768
|
|
745
|
-
#
|
769
|
+
# data (as pandas dataframes)
|
770
|
+
X = zoo.data.features
|
771
|
+
y = zoo.data.targets
|
772
|
+
|
773
|
+
# This is a utility that allows each row to be a Case instance,
|
774
|
+
# which simplifies access to column values using dot notation.
|
775
|
+
all_cases = create_cases_from_dataframe(X, name="Animal")
|
776
|
+
|
777
|
+
# The targets are the species of the animals
|
778
|
+
category_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "molusc"]
|
779
|
+
category_id_to_name = {i + 1: name for i, name in enumerate(category_names)}
|
780
|
+
targets = [Species.from_str(category_id_to_name[i]) for i in y.values.flatten()]
|
781
|
+
|
782
|
+
# Now that we are done with the data preparation, we can create and use the Ripple Down Rules classifier.
|
783
|
+
grdr = GeneralRDR()
|
784
|
+
|
785
|
+
# Fit the GRDR to the data
|
746
786
|
case_queries = [CaseQuery(case, 'species', type(target), True, _target=target)
|
747
787
|
for case, target in zip(all_cases[:10], targets[:10])]
|
748
|
-
|
788
|
+
grdr.fit(case_queries, animate_tree=True)
|
749
789
|
|
750
790
|
# Render the tree to a file
|
751
|
-
render_tree(
|
791
|
+
render_tree(grdr.start_rules[0], use_dot_exporter=True, filename="species_rdr")
|
752
792
|
|
753
|
-
|
793
|
+
# Classify a case
|
794
|
+
cat = grdr.classify(all_cases[50])['species']
|
754
795
|
assert cat == targets[50]
|
755
796
|
```
|
756
797
|
|
798
|
+
### Relational Example
|
799
|
+
|
800
|
+
By relational, I mean that each rule conclusion is not a constant value, but depends on the case being classified;
|
801
|
+
the next example makes this clearer.
|
802
|
+
|
803
|
+
In this example, we will create a simple robot with parts and use Ripple Down Rules to find the contained objects inside
|
804
|
+
another object, in this case a robot. The result of such a rule varies depending on the robot
|
805
|
+
and the parts it has.
|
806
|
+
|
807
|
+
```python
|
808
|
+
from __future__ import annotations
|
809
|
+
|
810
|
+
import os.path
|
811
|
+
from dataclasses import dataclass, field
|
812
|
+
|
813
|
+
from typing_extensions import List, Optional
|
814
|
+
|
815
|
+
from ripple_down_rules.datastructures.dataclasses import CaseQuery
|
816
|
+
from ripple_down_rules.rdr import GeneralRDR
|
817
|
+
|
818
|
+
|
819
|
+
@dataclass(unsafe_hash=True)
|
820
|
+
class PhysicalObject:
|
821
|
+
"""
|
822
|
+
A physical object is an object that can be contained in a container.
|
823
|
+
"""
|
824
|
+
name: str
|
825
|
+
contained_objects: List[PhysicalObject] = field(default_factory=list, hash=False)
|
826
|
+
|
827
|
+
@dataclass(unsafe_hash=True)
|
828
|
+
class Part(PhysicalObject):
|
829
|
+
...
|
830
|
+
|
831
|
+
@dataclass(unsafe_hash=True)
|
832
|
+
class Robot(PhysicalObject):
|
833
|
+
parts: List[Part] = field(default_factory=list, hash=False)
|
834
|
+
|
835
|
+
|
836
|
+
part_a = Part(name="A")
|
837
|
+
part_b = Part(name="B")
|
838
|
+
part_c = Part(name="C")
|
839
|
+
robot = Robot("pr2", parts=[part_a])
|
840
|
+
part_a.contained_objects = [part_b]
|
841
|
+
part_b.contained_objects = [part_c]
|
842
|
+
|
843
|
+
case_query = CaseQuery(robot, "contained_objects", (PhysicalObject,), False)
|
844
|
+
|
845
|
+
load = True # Set to True if you want to load an existing model, False if you want to create a new one.
|
846
|
+
if load and os.path.exists('./part_containment_rdr'):
|
847
|
+
grdr = GeneralRDR.load('./', model_name='part_containment_rdr')
|
848
|
+
grdr.ask_always = False # Set to True if you want to always ask the expert for a target value.
|
849
|
+
else:
|
850
|
+
grdr = GeneralRDR(save_dir='./', model_name='part_containment_rdr')
|
851
|
+
|
852
|
+
grdr.fit_case(case_query)
|
853
|
+
|
854
|
+
print(grdr.classify(robot)['contained_objects'])
|
855
|
+
assert grdr.classify(robot)['contained_objects'] == {part_b}
|
856
|
+
```
|
857
|
+
|
858
|
+
When prompted to write a rule, I wrote the following inside the template function that the Ripple Down Rules created
|
859
|
+
for me; this function takes a `case` object as input:
|
860
|
+
|
861
|
+
```python
|
862
|
+
contained_objects = []
|
863
|
+
for part in case.parts:
|
864
|
+
contained_objects.extend(part.contained_objects)
|
865
|
+
return contained_objects
|
866
|
+
```
|
867
|
+
|
868
|
+
Then, when asked for conditions, I wrote the following inside the template function that the Ripple Down Rules
|
869
|
+
created:
|
870
|
+
|
871
|
+
```python
|
872
|
+
return len(case.parts) > 0
|
873
|
+
```
|
874
|
+
|
875
|
+
This means that the rule will only be applied if the robot has parts.
|
876
|
+
|
877
|
+
Notice that the result contains only part B, although one could argue that part C is also contained in the robot;
|
878
|
+
the rule we wrote only returns the contained objects of the robot's parts. To get part C, we would have to
|
879
|
+
add another rule stating that the contained objects of my contained objects are also contained in me; you can
|
880
|
+
try that yourself and see if it works!
|
881
|
+
|
882
|
+
|
757
883
|
## To Cite:
|
758
884
|
|
759
885
|
```bib
|
@@ -761,6 +887,6 @@ assert cat == targets[50]
|
|
761
887
|
author = {Bassiouny, Abdelrhman},
|
762
888
|
title = {Ripple-Down-Rules},
|
763
889
|
url = {https://github.com/AbdelrhmanBassiouny/ripple_down_rules},
|
764
|
-
version = {0.
|
890
|
+
version = {0.5.2},
|
765
891
|
}
|
766
892
|
```
|
@@ -1,15 +1,14 @@
|
|
1
|
-
ripple_down_rules/__init__.py,sha256=
|
2
|
-
ripple_down_rules/
|
3
|
-
ripple_down_rules/experts.py,sha256=9Vc3vx0uhDPy3YlNjwKuWJLl_A-kubRPUU6bMvQhaAg,13237
|
1
|
+
ripple_down_rules/__init__.py,sha256=K8GayszN_Ydn9s_OsfTRq83trUr2_x64cQwQX9gwF-E,99
|
2
|
+
ripple_down_rules/experts.py,sha256=tjCq_T_d2qc_DhyBlxfqoT3oHk6-HmKFZFqGZAdXUb0,13660
|
4
3
|
ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
|
5
4
|
ripple_down_rules/helpers.py,sha256=TvTJU0BA3dPcAyzvZFvAu7jZqsp8Lu0HAAwvuizlGjg,2018
|
6
|
-
ripple_down_rules/rdr.py,sha256=
|
5
|
+
ripple_down_rules/rdr.py,sha256=FJYuRXgpUYSSK1pYrp2yeXb_ZZ2xjPED31tzxofokL4,48865
|
7
6
|
ripple_down_rules/rdr_decorators.py,sha256=pYCKLgMKgQ6x_252WQtF2t4ZNjWPBxnaWtJ6TpGdcc0,7820
|
8
7
|
ripple_down_rules/rules.py,sha256=TPNVMqW9T-_46BS4WemrspLg5uG8kP6tsPvWWBAzJxg,17515
|
9
8
|
ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
|
10
|
-
ripple_down_rules/utils.py,sha256=
|
9
|
+
ripple_down_rules/utils.py,sha256=cv40XBj-tp11aRYcAPhPtkrYatMvAKk_1d5P7PB1-tw,51123
|
11
10
|
ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
|
12
|
-
ripple_down_rules/datastructures/callable_expression.py,sha256=
|
11
|
+
ripple_down_rules/datastructures/callable_expression.py,sha256=D2KD1RdShzxYZPAERgywZ5ZPE4ar8WmMtXINqvYo_Tc,12497
|
13
12
|
ripple_down_rules/datastructures/case.py,sha256=r8kjL9xP_wk84ThXusspgPMrAoed2bGQmKi54fzhmH8,15258
|
14
13
|
ripple_down_rules/datastructures/dataclasses.py,sha256=PuD-7zWqWT2p4FnGvnihHvZlZKg9A1ctnFgVYf2cs-8,8554
|
15
14
|
ripple_down_rules/datastructures/enums.py,sha256=ce7tqS0otfSTNAOwsnXlhsvIn4iW_Y_N3TNebF3YoZs,5700
|
@@ -19,8 +18,8 @@ ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=24MIFwqnAhC6ofOb
|
|
19
18
|
ripple_down_rules/user_interface/object_diagram.py,sha256=tsB6iuLNEbHxp5lR2WjyejjWbnAX_nHF9xS8jNPOQVk,4548
|
20
19
|
ripple_down_rules/user_interface/prompt.py,sha256=AkkltdDIaioN43lkRKDPKSjJcmdSSGZDMYz7AL7X9lE,8082
|
21
20
|
ripple_down_rules/user_interface/template_file_creator.py,sha256=ycCbddy_BJP8d0Q2Sj21UzamhGtqGZuK_e73VTJqznY,13766
|
22
|
-
ripple_down_rules-0.5.
|
23
|
-
ripple_down_rules-0.5.
|
24
|
-
ripple_down_rules-0.5.
|
25
|
-
ripple_down_rules-0.5.
|
26
|
-
ripple_down_rules-0.5.
|
21
|
+
ripple_down_rules-0.5.2.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
22
|
+
ripple_down_rules-0.5.2.dist-info/METADATA,sha256=O3NmfxnYkTpT9dNAkJ3nOEsG-oau7PMotx37fFSwnqQ,47688
|
23
|
+
ripple_down_rules-0.5.2.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
|
24
|
+
ripple_down_rules-0.5.2.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
25
|
+
ripple_down_rules-0.5.2.dist-info/RECORD,,
|
ripple_down_rules/datasets.py
DELETED
@@ -1,222 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import os
|
4
|
-
import pickle
|
5
|
-
from dataclasses import dataclass, field
|
6
|
-
|
7
|
-
import sqlalchemy
|
8
|
-
from sqlalchemy import ForeignKey
|
9
|
-
from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationship, MappedColumn
|
10
|
-
from typing_extensions import Tuple, List, Set, Optional, Self
|
11
|
-
from ucimlrepo import fetch_ucirepo
|
12
|
-
|
13
|
-
from .datastructures.case import Case, create_cases_from_dataframe
|
14
|
-
from .datastructures.enums import Category
|
15
|
-
from .rdr_decorators import RDRDecorator
|
16
|
-
|
17
|
-
|
18
|
-
def load_cached_dataset(cache_file):
|
19
|
-
"""Loads the dataset from cache if it exists."""
|
20
|
-
dataset = {}
|
21
|
-
if '.pkl' not in cache_file:
|
22
|
-
cache_file += ".pkl"
|
23
|
-
for key in ["features", "targets", "ids"]:
|
24
|
-
part_file = cache_file.replace(".pkl", f"_{key}.pkl")
|
25
|
-
if not os.path.exists(part_file):
|
26
|
-
return None
|
27
|
-
with open(part_file, "rb") as f:
|
28
|
-
dataset[key] = pickle.load(f)
|
29
|
-
return dataset
|
30
|
-
|
31
|
-
|
32
|
-
def save_dataset_to_cache(dataset, cache_file):
|
33
|
-
"""Saves only essential parts of the dataset to cache."""
|
34
|
-
dataset_to_cache = {
|
35
|
-
"features": dataset.data.features,
|
36
|
-
"targets": dataset.data.targets,
|
37
|
-
"ids": dataset.data.ids,
|
38
|
-
}
|
39
|
-
|
40
|
-
for key, value in dataset_to_cache.items():
|
41
|
-
with open(cache_file.replace(".pkl", f"_{key}.pkl"), "wb") as f:
|
42
|
-
pickle.dump(dataset_to_cache[key], f)
|
43
|
-
print("Dataset cached successfully.")
|
44
|
-
|
45
|
-
|
46
|
-
def get_dataset(dataset_id, cache_file: Optional[str] = None):
|
47
|
-
"""Fetches dataset from cache or downloads it if not available."""
|
48
|
-
if cache_file is not None:
|
49
|
-
if not cache_file.endswith(".pkl"):
|
50
|
-
cache_file += ".pkl"
|
51
|
-
dataset = load_cached_dataset(cache_file) if cache_file else None
|
52
|
-
if dataset is None:
|
53
|
-
print("Downloading dataset...")
|
54
|
-
dataset = fetch_ucirepo(id=dataset_id)
|
55
|
-
|
56
|
-
# Check if dataset is valid before caching
|
57
|
-
if dataset is None or not hasattr(dataset, "data"):
|
58
|
-
print("Error: Failed to fetch dataset.")
|
59
|
-
return None
|
60
|
-
|
61
|
-
if cache_file:
|
62
|
-
save_dataset_to_cache(dataset, cache_file)
|
63
|
-
|
64
|
-
dataset = {
|
65
|
-
"features": dataset.data.features,
|
66
|
-
"targets": dataset.data.targets,
|
67
|
-
"ids": dataset.data.ids,
|
68
|
-
}
|
69
|
-
|
70
|
-
return dataset
|
71
|
-
|
72
|
-
|
73
|
-
def load_zoo_dataset(cache_file: Optional[str] = None) -> Tuple[List[Case], List[Species]]:
|
74
|
-
"""
|
75
|
-
Load the zoo dataset.
|
76
|
-
|
77
|
-
:param cache_file: the cache file to store the dataset or load it from.
|
78
|
-
:return: all cases and targets.
|
79
|
-
"""
|
80
|
-
# fetch dataset
|
81
|
-
zoo = get_dataset(111, cache_file)
|
82
|
-
|
83
|
-
# data (as pandas dataframes)
|
84
|
-
X = zoo['features']
|
85
|
-
y = zoo['targets']
|
86
|
-
# get ids as list of strings
|
87
|
-
ids = zoo['ids'].values.flatten()
|
88
|
-
all_cases = create_cases_from_dataframe(X, name="Animal")
|
89
|
-
|
90
|
-
category_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "molusc"]
|
91
|
-
category_id_to_name = {i + 1: name for i, name in enumerate(category_names)}
|
92
|
-
# targets = [getattr(SpeciesCol, category_id_to_name[i]) for i in y.values.flatten()]
|
93
|
-
targets = [Species.from_str(category_id_to_name[i]) for i in y.values.flatten()]
|
94
|
-
return all_cases, targets
|
95
|
-
|
96
|
-
|
97
|
-
class Species(Category):
|
98
|
-
mammal = "mammal"
|
99
|
-
bird = "bird"
|
100
|
-
reptile = "reptile"
|
101
|
-
fish = "fish"
|
102
|
-
amphibian = "amphibian"
|
103
|
-
insect = "insect"
|
104
|
-
molusc = "molusc"
|
105
|
-
|
106
|
-
|
107
|
-
class Habitat(Category):
|
108
|
-
"""
|
109
|
-
A habitat category is a category that represents the habitat of an animal.
|
110
|
-
"""
|
111
|
-
land = "land"
|
112
|
-
water = "water"
|
113
|
-
air = "air"
|
114
|
-
|
115
|
-
|
116
|
-
class PhysicalObject:
|
117
|
-
"""
|
118
|
-
A physical object is an object that can be contained in a container.
|
119
|
-
"""
|
120
|
-
_rdr_json_dir: str = os.path.join(os.path.dirname(__file__), "../../test/test_results")
|
121
|
-
"""
|
122
|
-
The directory where the RDR serialized JSON files are stored.
|
123
|
-
"""
|
124
|
-
_rdr_python_dir: str = os.path.join(os.path.dirname(__file__), "../../test/test_generated_rdrs")
|
125
|
-
"""
|
126
|
-
The directory where the RDR generated Python files are stored.
|
127
|
-
"""
|
128
|
-
_is_a_robot_rdr: RDRDecorator = RDRDecorator(_rdr_json_dir, (bool,), True,
|
129
|
-
python_dir=_rdr_python_dir)
|
130
|
-
"""
|
131
|
-
The RDR decorator that is used to determine if the object is a robot or not.
|
132
|
-
"""
|
133
|
-
_select_parts_rdr: RDRDecorator = RDRDecorator(_rdr_json_dir, (Self,), False,
|
134
|
-
python_dir=_rdr_python_dir)
|
135
|
-
"""
|
136
|
-
The RDR decorator that is used to determine if the object is a robot or not.
|
137
|
-
"""
|
138
|
-
|
139
|
-
def __init__(self, name: str, contained_objects: Optional[List[PhysicalObject]] = None):
|
140
|
-
self.name: str = name
|
141
|
-
self._contained_objects: List[PhysicalObject] = contained_objects or []
|
142
|
-
|
143
|
-
@property
|
144
|
-
def contained_objects(self) -> List[PhysicalObject]:
|
145
|
-
return self._contained_objects
|
146
|
-
|
147
|
-
@contained_objects.setter
|
148
|
-
def contained_objects(self, value: List[PhysicalObject]):
|
149
|
-
self._contained_objects = value
|
150
|
-
|
151
|
-
@_is_a_robot_rdr.decorator
|
152
|
-
def is_a_robot(self) -> bool:
|
153
|
-
pass
|
154
|
-
|
155
|
-
@_select_parts_rdr.decorator
|
156
|
-
def select_objects_that_are_parts_of_robot(self, objects: List[PhysicalObject], robot: Robot) -> List[PhysicalObject]:
|
157
|
-
pass
|
158
|
-
|
159
|
-
def __str__(self):
|
160
|
-
return self.name
|
161
|
-
|
162
|
-
def __repr__(self):
|
163
|
-
return self.name
|
164
|
-
|
165
|
-
|
166
|
-
class Part(PhysicalObject):
|
167
|
-
...
|
168
|
-
|
169
|
-
|
170
|
-
class Robot(PhysicalObject):
|
171
|
-
|
172
|
-
def __init__(self, name: str, parts: Optional[List[Part]] = None):
|
173
|
-
super().__init__(name)
|
174
|
-
self.parts: List[Part] = parts if parts else []
|
175
|
-
|
176
|
-
|
177
|
-
class Base(sqlalchemy.orm.DeclarativeBase):
|
178
|
-
pass
|
179
|
-
|
180
|
-
|
181
|
-
class HabitatTable(MappedAsDataclass, Base):
|
182
|
-
__tablename__ = "Habitat"
|
183
|
-
|
184
|
-
id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
|
185
|
-
habitat: Mapped[Habitat]
|
186
|
-
animal_id: MappedColumn = mapped_column(ForeignKey("Animal.id"), init=False)
|
187
|
-
|
188
|
-
def __hash__(self):
|
189
|
-
return hash(self.habitat)
|
190
|
-
|
191
|
-
def __str__(self):
|
192
|
-
return f"{HabitatTable.__name__}({Habitat.__name__}.{self.habitat.name})"
|
193
|
-
|
194
|
-
def __repr__(self):
|
195
|
-
return self.__str__()
|
196
|
-
|
197
|
-
|
198
|
-
class MappedAnimal(MappedAsDataclass, Base):
|
199
|
-
__tablename__ = "Animal"
|
200
|
-
|
201
|
-
id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
|
202
|
-
name: Mapped[str]
|
203
|
-
hair: Mapped[bool]
|
204
|
-
feathers: Mapped[bool]
|
205
|
-
eggs: Mapped[bool]
|
206
|
-
milk: Mapped[bool]
|
207
|
-
airborne: Mapped[bool]
|
208
|
-
aquatic: Mapped[bool]
|
209
|
-
predator: Mapped[bool]
|
210
|
-
toothed: Mapped[bool]
|
211
|
-
backbone: Mapped[bool]
|
212
|
-
breathes: Mapped[bool]
|
213
|
-
venomous: Mapped[bool]
|
214
|
-
fins: Mapped[bool]
|
215
|
-
legs: Mapped[int]
|
216
|
-
tail: Mapped[bool]
|
217
|
-
domestic: Mapped[bool]
|
218
|
-
catsize: Mapped[bool]
|
219
|
-
species: Mapped[Species] = mapped_column(nullable=True)
|
220
|
-
|
221
|
-
habitats: Mapped[Set[HabitatTable]] = relationship(default_factory=set)
|
222
|
-
|
File without changes
|
File without changes
|
File without changes
|