ripple-down-rules 0.0.12__tar.gz → 0.0.14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/PKG-INFO +1 -1
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/pyproject.toml +1 -1
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/datastructures/case.py +9 -1
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/rdr.py +33 -13
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/rules.py +22 -6
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/utils.py +20 -3
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
- ripple_down_rules-0.0.14/test/test_json_serialization.py +51 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/test/test_rdr.py +5 -24
- ripple_down_rules-0.0.12/test/test_json_serialization.py +0 -43
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/LICENSE +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/README.md +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/setup.cfg +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/__init__.py +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/datasets.py +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/datastructures/__init__.py +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/datastructures/callable_expression.py +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/datastructures/dataclasses.py +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/datastructures/enums.py +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/experts.py +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/failures.py +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/prompt.py +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules.egg-info/SOURCES.txt +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/test/test_rdr_alchemy.py +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/test/test_relational_rdr.py +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/test/test_relational_rdr_alchemy.py +0 -0
- {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.0.12
|
3
|
+
Version: 0.0.14
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
|
|
6
6
|
|
7
7
|
[project]
|
8
8
|
name = "ripple_down_rules"
|
9
|
-
version = "0.0.12"
|
9
|
+
version = "0.0.14"
|
10
10
|
description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
|
11
11
|
readme = "README.md"
|
12
12
|
authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
|
{ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/datastructures/case.py
RENAMED
@@ -1,6 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
from collections import UserDict
|
4
|
+
from copy import copy, deepcopy
|
4
5
|
from dataclasses import dataclass
|
5
6
|
from enum import Enum
|
6
7
|
|
@@ -9,7 +10,8 @@ from sqlalchemy import MetaData
|
|
9
10
|
from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColumn, registry
|
10
11
|
from typing_extensions import Any, Optional, Dict, Type, Set, Hashable, Union, List, TYPE_CHECKING
|
11
12
|
|
12
|
-
from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, SubclassJSONSerializer
|
13
|
+
from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, SubclassJSONSerializer, \
|
14
|
+
get_full_class_name, get_type_from_string
|
13
15
|
|
14
16
|
if TYPE_CHECKING:
|
15
17
|
from ripple_down_rules.rules import Rule
|
@@ -76,11 +78,17 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
76
78
|
def _to_json(self) -> Dict[str, Any]:
|
77
79
|
serializable = {k: v for k, v in self.items() if not k.startswith("_")}
|
78
80
|
serializable["_id"] = self._id
|
81
|
+
for k, v in serializable.items():
|
82
|
+
if isinstance(v, set):
|
83
|
+
serializable[k] = {'_type': get_full_class_name(set), 'value': list(v)}
|
79
84
|
return {k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in serializable.items()}
|
80
85
|
|
81
86
|
@classmethod
|
82
87
|
def _from_json(cls, data: Dict[str, Any]) -> Case:
|
83
88
|
id_ = data.pop("_id")
|
89
|
+
for k, v in data.items():
|
90
|
+
if isinstance(v, dict) and "_type" in v:
|
91
|
+
data[k] = SubclassJSONSerializer.from_json(v)
|
84
92
|
return cls(_id=id_, **data)
|
85
93
|
|
86
94
|
|
@@ -14,10 +14,10 @@ from .datastructures import Case, MCRDRMode, CallableExpression, CaseAttribute,
|
|
14
14
|
from .experts import Expert, Human
|
15
15
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
16
16
|
from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
|
17
|
-
get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list
|
17
|
+
get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_full_class_name, get_type_from_string
|
18
18
|
|
19
19
|
|
20
|
-
class RippleDownRules(ABC):
|
20
|
+
class RippleDownRules(SubclassJSONSerializer, ABC):
|
21
21
|
"""
|
22
22
|
The abstract base class for the ripple down rules classifiers.
|
23
23
|
"""
|
@@ -175,10 +175,7 @@ class RippleDownRules(ABC):
|
|
175
175
|
return conclusion_type in case
|
176
176
|
|
177
177
|
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
class RDRWithCodeWriter(RDR, ABC):
|
178
|
+
class RDRWithCodeWriter(RippleDownRules, ABC):
|
182
179
|
|
183
180
|
@abstractmethod
|
184
181
|
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = ""):
|
@@ -260,7 +257,7 @@ class RDRWithCodeWriter(RDR, ABC):
|
|
260
257
|
return type(self.start_rule.conclusion)
|
261
258
|
|
262
259
|
|
263
|
-
class SingleClassRDR(RDRWithCodeWriter
|
260
|
+
class SingleClassRDR(RDRWithCodeWriter):
|
264
261
|
|
265
262
|
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
266
263
|
-> Union[CaseAttribute, CallableExpression]:
|
@@ -353,17 +350,15 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
353
350
|
The conditions of the stopping rule if needed.
|
354
351
|
"""
|
355
352
|
|
356
|
-
def __init__(self,
|
353
|
+
def __init__(self, start_rule: Optional[Rule] = None,
|
357
354
|
mode: MCRDRMode = MCRDRMode.StopOnly, session: Optional[Session] = None):
|
358
355
|
"""
|
359
|
-
:param
|
360
|
-
and are always checked, in contrast to the refinement and alternative rules which are only checked if the
|
361
|
-
starting rules fire or not.
|
356
|
+
:param start_rule: The starting rules for the classifier.
|
362
357
|
:param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
|
363
358
|
:param session: The sqlalchemy orm session.
|
364
359
|
"""
|
365
|
-
|
366
|
-
super(MultiClassRDR, self).__init__(
|
360
|
+
start_rule = MultiClassTopRule() if not start_rule else start_rule
|
361
|
+
super(MultiClassRDR, self).__init__(start_rule, session=session)
|
367
362
|
self.mode: MCRDRMode = mode
|
368
363
|
|
369
364
|
def classify(self, case: Union[Case, SQLTable]) -> List[Any]:
|
@@ -614,6 +609,17 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
614
609
|
"""
|
615
610
|
self.start_rule.alternative = MultiClassTopRule(conditions, conclusion, corner_case=corner_case)
|
616
611
|
|
612
|
+
def _to_json(self) -> Dict[str, Any]:
|
613
|
+
return {"start_rule": self.start_rule.to_json()}
|
614
|
+
|
615
|
+
@classmethod
|
616
|
+
def _from_json(cls, data: Dict[str, Any]) -> Self:
|
617
|
+
"""
|
618
|
+
Create an instance of the class from a json
|
619
|
+
"""
|
620
|
+
start_rule = SingleClassRule.from_json(data["start_rule"])
|
621
|
+
return cls(start_rule)
|
622
|
+
|
617
623
|
|
618
624
|
class GeneralRDR(RippleDownRules):
|
619
625
|
"""
|
@@ -795,3 +801,17 @@ class GeneralRDR(RippleDownRules):
|
|
795
801
|
Get all the types of categories that the GRDR can classify.
|
796
802
|
"""
|
797
803
|
return list(self.start_rules_dict.keys())
|
804
|
+
|
805
|
+
def _to_json(self) -> Dict[str, Any]:
|
806
|
+
return {"start_rules": {get_full_class_name(t): rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
|
807
|
+
|
808
|
+
@classmethod
|
809
|
+
def _from_json(cls, data: Dict[str, Any]) -> GeneralRDR:
|
810
|
+
"""
|
811
|
+
Create an instance of the class from a json
|
812
|
+
"""
|
813
|
+
start_rules_dict = {}
|
814
|
+
for k, v in data["start_rules"].items():
|
815
|
+
k = get_type_from_string(k)
|
816
|
+
start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
|
817
|
+
return cls(start_rules_dict)
|
@@ -8,7 +8,7 @@ from typing_extensions import List, Optional, Self, Union, Dict, Any
|
|
8
8
|
|
9
9
|
from .datastructures import CallableExpression, Case, SQLTable
|
10
10
|
from .datastructures.enums import RDREdge, Stop
|
11
|
-
from .utils import SubclassJSONSerializer
|
11
|
+
from .utils import SubclassJSONSerializer, is_iterable, get_full_class_name
|
12
12
|
|
13
13
|
|
14
14
|
class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
@@ -102,8 +102,17 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
102
102
|
pass
|
103
103
|
|
104
104
|
def _to_json(self) -> Dict[str, Any]:
|
105
|
+
def conclusion_to_json(conclusion):
|
106
|
+
if is_iterable(conclusion):
|
107
|
+
conclusions = {'_type': get_full_class_name(type(conclusion)), 'value': []}
|
108
|
+
for c in conclusion:
|
109
|
+
conclusions['value'].append(conclusion_to_json(c))
|
110
|
+
else:
|
111
|
+
conclusions = conclusion.to_json()
|
112
|
+
return conclusions
|
113
|
+
|
105
114
|
json_serialization = {"conditions": self.conditions.to_json(),
|
106
|
-
"conclusion": self.conclusion
|
115
|
+
"conclusion": conclusion_to_json(self.conclusion),
|
107
116
|
"parent": self.parent.json_serialization if self.parent else None,
|
108
117
|
"corner_case": self.corner_case.to_json() if self.corner_case else None,
|
109
118
|
"weight": self.weight}
|
@@ -265,14 +274,17 @@ class MultiClassStopRule(Rule, HasAlternativeRule):
|
|
265
274
|
|
266
275
|
def _to_json(self) -> Dict[str, Any]:
|
267
276
|
self.json_serialization = {**Rule._to_json(self),
|
268
|
-
"top_rule": self.top_rule.to_json(),
|
269
277
|
"alternative": self.alternative.to_json() if self.alternative is not None else None}
|
270
278
|
return self.json_serialization
|
271
279
|
|
272
280
|
@classmethod
|
273
281
|
def _from_json(cls, data: Dict[str, Any]) -> MultiClassStopRule:
|
274
|
-
loaded_rule =
|
275
|
-
|
282
|
+
loaded_rule = super(MultiClassStopRule, cls)._from_json(data)
|
283
|
+
# The following is done to prevent re-initialization of the top rule,
|
284
|
+
# so the top rule that is already initialized is passed in the data instead of its json serialization.
|
285
|
+
loaded_rule.top_rule = data['top_rule']
|
286
|
+
if data['alternative'] is not None:
|
287
|
+
data['alternative']['top_rule'] = data['top_rule']
|
276
288
|
loaded_rule.alternative = MultiClassStopRule.from_json(data["alternative"])
|
277
289
|
return loaded_rule
|
278
290
|
|
@@ -312,7 +324,11 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
|
312
324
|
|
313
325
|
@classmethod
|
314
326
|
def _from_json(cls, data: Dict[str, Any]) -> MultiClassTopRule:
|
315
|
-
loaded_rule =
|
327
|
+
loaded_rule = super(MultiClassTopRule, cls)._from_json(data)
|
328
|
+
# The following is done to prevent re-initialization of the top rule,
|
329
|
+
# so the top rule that is already initialized is passed in the data instead of its json serialization.
|
330
|
+
if data['refinement'] is not None:
|
331
|
+
data['refinement']['top_rule'] = loaded_rule
|
316
332
|
loaded_rule.refinement = MultiClassStopRule.from_json(data["refinement"])
|
317
333
|
loaded_rule.alternative = MultiClassTopRule.from_json(data["alternative"])
|
318
334
|
return loaded_rule
|
@@ -25,6 +25,16 @@ if TYPE_CHECKING:
|
|
25
25
|
matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
|
26
26
|
|
27
27
|
|
28
|
+
def flatten_list(a: List):
|
29
|
+
a_flattened = []
|
30
|
+
for c in a:
|
31
|
+
if is_iterable(c):
|
32
|
+
a_flattened.extend(list(c))
|
33
|
+
else:
|
34
|
+
a_flattened.append(c)
|
35
|
+
return a_flattened
|
36
|
+
|
37
|
+
|
28
38
|
def make_list(value: Any) -> List:
|
29
39
|
"""
|
30
40
|
Make a list from a value.
|
@@ -97,15 +107,13 @@ class SubclassJSONSerializer:
|
|
97
107
|
def to_json(self) -> Dict[str, Any]:
|
98
108
|
return {"_type": get_full_class_name(self.__class__), **self._to_json()}
|
99
109
|
|
100
|
-
@abstractmethod
|
101
110
|
def _to_json(self) -> Dict[str, Any]:
|
102
111
|
"""
|
103
112
|
Create a json dict from the object.
|
104
113
|
"""
|
105
|
-
|
114
|
+
raise NotImplementedError()
|
106
115
|
|
107
116
|
@classmethod
|
108
|
-
@abstractmethod
|
109
117
|
def _from_json(cls, data: Dict[str, Any]) -> Self:
|
110
118
|
"""
|
111
119
|
Create a variable from a json dict.
|
@@ -140,7 +148,16 @@ class SubclassJSONSerializer:
|
|
140
148
|
"""
|
141
149
|
if data is None:
|
142
150
|
return None
|
151
|
+
if not isinstance(data, dict) or ('_type' not in data):
|
152
|
+
return data
|
153
|
+
# check if type module is builtins
|
154
|
+
data_type = get_type_from_string(data["_type"])
|
155
|
+
if data_type.__module__ == 'builtins':
|
156
|
+
if is_iterable(data['value']) and not isinstance(data['value'], dict):
|
157
|
+
return data_type([cls.from_json(d) for d in data['value']])
|
158
|
+
return data_type(data["value"])
|
143
159
|
if get_full_class_name(cls) == data["_type"]:
|
160
|
+
data.pop("_type")
|
144
161
|
return cls._from_json(data)
|
145
162
|
for subclass in recursive_subclasses(SubclassJSONSerializer):
|
146
163
|
if get_full_class_name(subclass) == data["_type"]:
|
{ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules.egg-info/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.0.12
|
3
|
+
Version: 0.0.14
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -0,0 +1,51 @@
|
|
1
|
+
import json
|
2
|
+
from unittest import TestCase
|
3
|
+
|
4
|
+
from typing_extensions import List
|
5
|
+
|
6
|
+
from ripple_down_rules.datasets import load_zoo_dataset
|
7
|
+
from ripple_down_rules.datastructures import CaseQuery, Case
|
8
|
+
from ripple_down_rules.experts import Human
|
9
|
+
from ripple_down_rules.rdr import SingleClassRDR, MultiClassRDR, GeneralRDR
|
10
|
+
from ripple_down_rules.utils import make_set, flatten_list
|
11
|
+
from test_helpers.helpers import get_fit_mcrdr, get_fit_scrdr, get_fit_grdr
|
12
|
+
|
13
|
+
|
14
|
+
class TestJSONSerialization(TestCase):
|
15
|
+
all_cases: List[Case]
|
16
|
+
targets: List[str]
|
17
|
+
cache_dir: str = "./test_results"
|
18
|
+
expert_answers_dir: str = "./test_expert_answers"
|
19
|
+
|
20
|
+
@classmethod
|
21
|
+
def setUpClass(cls):
|
22
|
+
cls.all_cases, cls.targets = load_zoo_dataset(cls.cache_dir + "/zoo_dataset.pkl")
|
23
|
+
|
24
|
+
def test_scrdr_json_serialization(self):
|
25
|
+
scrdr = get_fit_scrdr(self.all_cases, self.targets)
|
26
|
+
filename = f"{self.cache_dir}/scrdr.json"
|
27
|
+
scrdr.save(filename)
|
28
|
+
scrdr = SingleClassRDR.load(filename)
|
29
|
+
for case, target in zip(self.all_cases, self.targets):
|
30
|
+
cat = scrdr.classify(case)
|
31
|
+
self.assertEqual(cat, target)
|
32
|
+
|
33
|
+
def test_mcrdr_json_serialization(self):
|
34
|
+
mcrdr = get_fit_mcrdr(self.all_cases, self.targets)
|
35
|
+
filename = f"{self.cache_dir}/mcrdr.json"
|
36
|
+
mcrdr.save(filename)
|
37
|
+
mcrdr = MultiClassRDR.load(filename)
|
38
|
+
for case, target in zip(self.all_cases, self.targets):
|
39
|
+
cat = mcrdr.classify(case)
|
40
|
+
self.assertEqual(make_set(cat), make_set(target))
|
41
|
+
|
42
|
+
def test_grdr_json_serialization(self):
|
43
|
+
grdr, all_targets = get_fit_grdr(self.all_cases, self.targets)
|
44
|
+
filename = f"{self.cache_dir}/grdr.json"
|
45
|
+
grdr.save(filename)
|
46
|
+
grdr = GeneralRDR.load(filename)
|
47
|
+
for case, case_targets in zip(self.all_cases[:len(all_targets)], all_targets):
|
48
|
+
cat = grdr.classify(case)
|
49
|
+
cat = flatten_list(cat)
|
50
|
+
case_targets = flatten_list(case_targets)
|
51
|
+
self.assertEqual(make_set(cat), make_set(case_targets))
|
@@ -11,6 +11,7 @@ from ripple_down_rules.datastructures import Case, MCRDRMode, \
|
|
11
11
|
from ripple_down_rules.experts import Human
|
12
12
|
from ripple_down_rules.rdr import SingleClassRDR, MultiClassRDR, GeneralRDR
|
13
13
|
from ripple_down_rules.utils import render_tree, get_all_subclasses, make_set
|
14
|
+
from test_helpers.helpers import get_fit_scrdr, get_fit_mcrdr
|
14
15
|
|
15
16
|
|
16
17
|
class TestRDR(TestCase):
|
@@ -71,7 +72,7 @@ class TestRDR(TestCase):
|
|
71
72
|
expert.save_answers(file)
|
72
73
|
|
73
74
|
def test_write_scrdr_to_python_file(self):
|
74
|
-
scrdr = self.
|
75
|
+
scrdr = get_fit_scrdr(self.all_cases,self.targets)
|
75
76
|
scrdr.write_to_python_file(self.generated_rdrs_dir)
|
76
77
|
classify_species_scrdr = scrdr.get_rdr_classifier_from_python_file(self.generated_rdrs_dir)
|
77
78
|
for case, target in zip(self.all_cases, self.targets):
|
@@ -79,7 +80,7 @@ class TestRDR(TestCase):
|
|
79
80
|
self.assertEqual(cat, target)
|
80
81
|
|
81
82
|
def test_write_mcrdr_to_python_file(self):
|
82
|
-
mcrdr = self.
|
83
|
+
mcrdr = get_fit_mcrdr(self.all_cases, self.targets)
|
83
84
|
mcrdr.write_to_python_file(self.generated_rdrs_dir)
|
84
85
|
classify_species_mcrdr = mcrdr.get_rdr_classifier_from_python_file(self.generated_rdrs_dir)
|
85
86
|
for case, target in zip(self.all_cases, self.targets):
|
@@ -258,7 +259,7 @@ class TestRDR(TestCase):
|
|
258
259
|
if use_loaded_answers:
|
259
260
|
expert.load_answers(filename)
|
260
261
|
|
261
|
-
fit_scrdr = self.
|
262
|
+
fit_scrdr = get_fit_scrdr(self.all_cases, self.targets, draw_tree=False)
|
262
263
|
|
263
264
|
grdr = GeneralRDR({type(fit_scrdr.start_rule.conclusion): fit_scrdr})
|
264
265
|
|
@@ -309,7 +310,7 @@ class TestRDR(TestCase):
|
|
309
310
|
if use_loaded_answers:
|
310
311
|
expert.load_answers(filename)
|
311
312
|
|
312
|
-
fit_scrdr = self.
|
313
|
+
fit_scrdr = get_fit_scrdr(self.all_cases, self.targets, draw_tree=False)
|
313
314
|
|
314
315
|
grdr = GeneralRDR({type(fit_scrdr.start_rule.conclusion): fit_scrdr})
|
315
316
|
|
@@ -326,23 +327,3 @@ class TestRDR(TestCase):
|
|
326
327
|
cwd = os.getcwd()
|
327
328
|
file = os.path.join(cwd, filename)
|
328
329
|
expert.save_answers(file)
|
329
|
-
|
330
|
-
def get_fit_scrdr(self, draw_tree=False) -> SingleClassRDR:
|
331
|
-
filename = self.expert_answers_dir + "/scrdr_expert_answers_fit"
|
332
|
-
expert = Human(use_loaded_answers=True)
|
333
|
-
expert.load_answers(filename)
|
334
|
-
|
335
|
-
scrdr = SingleClassRDR()
|
336
|
-
case_queries = [CaseQuery(case, target=target) for case, target in zip(self.all_cases, self.targets)]
|
337
|
-
scrdr.fit(case_queries, expert=expert,
|
338
|
-
animate_tree=draw_tree)
|
339
|
-
return scrdr
|
340
|
-
|
341
|
-
def get_fit_mcrdr(self, draw_tree: bool = False):
|
342
|
-
filename = self.expert_answers_dir + "/mcrdr_expert_answers_stop_only_fit"
|
343
|
-
expert = Human(use_loaded_answers=True)
|
344
|
-
expert.load_answers(filename)
|
345
|
-
mcrdr = MultiClassRDR()
|
346
|
-
case_queries = [CaseQuery(case, target=target) for case, target in zip(self.all_cases, self.targets)]
|
347
|
-
mcrdr.fit(case_queries, expert=expert, animate_tree=draw_tree)
|
348
|
-
return mcrdr
|
@@ -1,43 +0,0 @@
|
|
1
|
-
import json
|
2
|
-
from unittest import TestCase
|
3
|
-
|
4
|
-
from typing_extensions import List
|
5
|
-
|
6
|
-
from ripple_down_rules.datasets import load_zoo_dataset
|
7
|
-
from ripple_down_rules.datastructures import CaseQuery, Case
|
8
|
-
from ripple_down_rules.experts import Human
|
9
|
-
from ripple_down_rules.rdr import SingleClassRDR
|
10
|
-
|
11
|
-
|
12
|
-
class TestJSONSerialization(TestCase):
|
13
|
-
all_cases: List[Case]
|
14
|
-
targets: List[str]
|
15
|
-
cache_dir: str = "./test_results"
|
16
|
-
expert_answers_dir: str = "./test_expert_answers"
|
17
|
-
|
18
|
-
@classmethod
|
19
|
-
def setUpClass(cls):
|
20
|
-
cls.all_cases, cls.targets = load_zoo_dataset(cls.cache_dir + "/zoo_dataset.pkl")
|
21
|
-
|
22
|
-
def test_scrdr_json_serialization(self):
|
23
|
-
scrdr = self.get_fit_scrdr()
|
24
|
-
filename = f"{self.cache_dir}/scrdr.json"
|
25
|
-
scrdr.save(filename)
|
26
|
-
scrdr = SingleClassRDR.load(filename)
|
27
|
-
for case, target in zip(self.all_cases, self.targets):
|
28
|
-
cat = scrdr.classify(case)
|
29
|
-
self.assertEqual(cat, target)
|
30
|
-
|
31
|
-
def get_fit_scrdr(self, draw_tree=False) -> SingleClassRDR:
|
32
|
-
filename = self.expert_answers_dir + "/scrdr_expert_answers_fit"
|
33
|
-
expert = Human(use_loaded_answers=True)
|
34
|
-
expert.load_answers(filename)
|
35
|
-
|
36
|
-
scrdr = SingleClassRDR()
|
37
|
-
case_queries = [CaseQuery(case, target=target) for case, target in zip(self.all_cases, self.targets)]
|
38
|
-
scrdr.fit(case_queries, expert=expert,
|
39
|
-
animate_tree=draw_tree)
|
40
|
-
for case, target in zip(self.all_cases, self.targets):
|
41
|
-
cat = scrdr.classify(case)
|
42
|
-
self.assertEqual(cat, target)
|
43
|
-
return scrdr
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/datastructures/enums.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules.egg-info/SOURCES.txt
RENAMED
File without changes
|
File without changes
|
{ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules.egg-info/top_level.txt
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|