ripple-down-rules 0.0.12__tar.gz → 0.0.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/PKG-INFO +1 -1
  2. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/pyproject.toml +1 -1
  3. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/datastructures/case.py +9 -1
  4. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/rdr.py +33 -13
  5. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/rules.py +22 -6
  6. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/utils.py +20 -3
  7. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
  8. ripple_down_rules-0.0.14/test/test_json_serialization.py +51 -0
  9. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/test/test_rdr.py +5 -24
  10. ripple_down_rules-0.0.12/test/test_json_serialization.py +0 -43
  11. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/LICENSE +0 -0
  12. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/README.md +0 -0
  13. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/setup.cfg +0 -0
  14. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/__init__.py +0 -0
  15. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/datasets.py +0 -0
  16. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/datastructures/__init__.py +0 -0
  17. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/datastructures/callable_expression.py +0 -0
  18. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/datastructures/dataclasses.py +0 -0
  19. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/datastructures/enums.py +0 -0
  20. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/experts.py +0 -0
  21. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/failures.py +0 -0
  22. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules/prompt.py +0 -0
  23. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules.egg-info/SOURCES.txt +0 -0
  24. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
  25. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
  26. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/test/test_rdr_alchemy.py +0 -0
  27. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/test/test_relational_rdr.py +0 -0
  28. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/test/test_relational_rdr_alchemy.py +0 -0
  29. {ripple_down_rules-0.0.12 → ripple_down_rules-0.0.14}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.0.12
3
+ Version: 0.0.14
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
6
6
 
7
7
  [project]
8
8
  name = "ripple_down_rules"
9
- version = "0.0.12"
9
+ version = "0.0.14"
10
10
  description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
11
11
  readme = "README.md"
12
12
  authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from collections import UserDict
4
+ from copy import copy, deepcopy
4
5
  from dataclasses import dataclass
5
6
  from enum import Enum
6
7
 
@@ -9,7 +10,8 @@ from sqlalchemy import MetaData
9
10
  from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColumn, registry
10
11
  from typing_extensions import Any, Optional, Dict, Type, Set, Hashable, Union, List, TYPE_CHECKING
11
12
 
12
- from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, SubclassJSONSerializer
13
+ from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, SubclassJSONSerializer, \
14
+ get_full_class_name, get_type_from_string
13
15
 
14
16
  if TYPE_CHECKING:
15
17
  from ripple_down_rules.rules import Rule
@@ -76,11 +78,17 @@ class Case(UserDict, SubclassJSONSerializer):
76
78
  def _to_json(self) -> Dict[str, Any]:
77
79
  serializable = {k: v for k, v in self.items() if not k.startswith("_")}
78
80
  serializable["_id"] = self._id
81
+ for k, v in serializable.items():
82
+ if isinstance(v, set):
83
+ serializable[k] = {'_type': get_full_class_name(set), 'value': list(v)}
79
84
  return {k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in serializable.items()}
80
85
 
81
86
  @classmethod
82
87
  def _from_json(cls, data: Dict[str, Any]) -> Case:
83
88
  id_ = data.pop("_id")
89
+ for k, v in data.items():
90
+ if isinstance(v, dict) and "_type" in v:
91
+ data[k] = SubclassJSONSerializer.from_json(v)
84
92
  return cls(_id=id_, **data)
85
93
 
86
94
 
@@ -14,10 +14,10 @@ from .datastructures import Case, MCRDRMode, CallableExpression, CaseAttribute,
14
14
  from .experts import Expert, Human
15
15
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
16
16
  from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
17
- get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list
17
+ get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_full_class_name, get_type_from_string
18
18
 
19
19
 
20
- class RippleDownRules(ABC):
20
+ class RippleDownRules(SubclassJSONSerializer, ABC):
21
21
  """
22
22
  The abstract base class for the ripple down rules classifiers.
23
23
  """
@@ -175,10 +175,7 @@ class RippleDownRules(ABC):
175
175
  return conclusion_type in case
176
176
 
177
177
 
178
- RDR = RippleDownRules
179
-
180
-
181
- class RDRWithCodeWriter(RDR, ABC):
178
+ class RDRWithCodeWriter(RippleDownRules, ABC):
182
179
 
183
180
  @abstractmethod
184
181
  def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = ""):
@@ -260,7 +257,7 @@ class RDRWithCodeWriter(RDR, ABC):
260
257
  return type(self.start_rule.conclusion)
261
258
 
262
259
 
263
- class SingleClassRDR(RDRWithCodeWriter, SubclassJSONSerializer):
260
+ class SingleClassRDR(RDRWithCodeWriter):
264
261
 
265
262
  def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
266
263
  -> Union[CaseAttribute, CallableExpression]:
@@ -353,17 +350,15 @@ class MultiClassRDR(RDRWithCodeWriter):
353
350
  The conditions of the stopping rule if needed.
354
351
  """
355
352
 
356
- def __init__(self, start_rules: Optional[List[Rule]] = None,
353
+ def __init__(self, start_rule: Optional[Rule] = None,
357
354
  mode: MCRDRMode = MCRDRMode.StopOnly, session: Optional[Session] = None):
358
355
  """
359
- :param start_rules: The starting rules for the classifier, these are the rules that are at the top of the tree
360
- and are always checked, in contrast to the refinement and alternative rules which are only checked if the
361
- starting rules fire or not.
356
+ :param start_rule: The starting rules for the classifier.
362
357
  :param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
363
358
  :param session: The sqlalchemy orm session.
364
359
  """
365
- self.start_rules = [MultiClassTopRule()] if not start_rules else start_rules
366
- super(MultiClassRDR, self).__init__(self.start_rules[0], session=session)
360
+ start_rule = MultiClassTopRule() if not start_rule else start_rule
361
+ super(MultiClassRDR, self).__init__(start_rule, session=session)
367
362
  self.mode: MCRDRMode = mode
368
363
 
369
364
  def classify(self, case: Union[Case, SQLTable]) -> List[Any]:
@@ -614,6 +609,17 @@ class MultiClassRDR(RDRWithCodeWriter):
614
609
  """
615
610
  self.start_rule.alternative = MultiClassTopRule(conditions, conclusion, corner_case=corner_case)
616
611
 
612
+ def _to_json(self) -> Dict[str, Any]:
613
+ return {"start_rule": self.start_rule.to_json()}
614
+
615
+ @classmethod
616
+ def _from_json(cls, data: Dict[str, Any]) -> Self:
617
+ """
618
+ Create an instance of the class from a json
619
+ """
620
+ start_rule = SingleClassRule.from_json(data["start_rule"])
621
+ return cls(start_rule)
622
+
617
623
 
618
624
  class GeneralRDR(RippleDownRules):
619
625
  """
@@ -795,3 +801,17 @@ class GeneralRDR(RippleDownRules):
795
801
  Get all the types of categories that the GRDR can classify.
796
802
  """
797
803
  return list(self.start_rules_dict.keys())
804
+
805
+ def _to_json(self) -> Dict[str, Any]:
806
+ return {"start_rules": {get_full_class_name(t): rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
807
+
808
+ @classmethod
809
+ def _from_json(cls, data: Dict[str, Any]) -> GeneralRDR:
810
+ """
811
+ Create an instance of the class from a json
812
+ """
813
+ start_rules_dict = {}
814
+ for k, v in data["start_rules"].items():
815
+ k = get_type_from_string(k)
816
+ start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
817
+ return cls(start_rules_dict)
@@ -8,7 +8,7 @@ from typing_extensions import List, Optional, Self, Union, Dict, Any
8
8
 
9
9
  from .datastructures import CallableExpression, Case, SQLTable
10
10
  from .datastructures.enums import RDREdge, Stop
11
- from .utils import SubclassJSONSerializer
11
+ from .utils import SubclassJSONSerializer, is_iterable, get_full_class_name
12
12
 
13
13
 
14
14
  class Rule(NodeMixin, SubclassJSONSerializer, ABC):
@@ -102,8 +102,17 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
102
102
  pass
103
103
 
104
104
  def _to_json(self) -> Dict[str, Any]:
105
+ def conclusion_to_json(conclusion):
106
+ if is_iterable(conclusion):
107
+ conclusions = {'_type': get_full_class_name(type(conclusion)), 'value': []}
108
+ for c in conclusion:
109
+ conclusions['value'].append(conclusion_to_json(c))
110
+ else:
111
+ conclusions = conclusion.to_json()
112
+ return conclusions
113
+
105
114
  json_serialization = {"conditions": self.conditions.to_json(),
106
- "conclusion": self.conclusion.to_json(),
115
+ "conclusion": conclusion_to_json(self.conclusion),
107
116
  "parent": self.parent.json_serialization if self.parent else None,
108
117
  "corner_case": self.corner_case.to_json() if self.corner_case else None,
109
118
  "weight": self.weight}
@@ -265,14 +274,17 @@ class MultiClassStopRule(Rule, HasAlternativeRule):
265
274
 
266
275
  def _to_json(self) -> Dict[str, Any]:
267
276
  self.json_serialization = {**Rule._to_json(self),
268
- "top_rule": self.top_rule.to_json(),
269
277
  "alternative": self.alternative.to_json() if self.alternative is not None else None}
270
278
  return self.json_serialization
271
279
 
272
280
  @classmethod
273
281
  def _from_json(cls, data: Dict[str, Any]) -> MultiClassStopRule:
274
- loaded_rule = Rule._from_json(data)
275
- loaded_rule.top_rule = MultiClassTopRule.from_json(data["top_rule"])
282
+ loaded_rule = super(MultiClassStopRule, cls)._from_json(data)
283
+ # The following is done to prevent re-initialization of the top rule,
284
+ # so the top rule that is already initialized is passed in the data instead of its json serialization.
285
+ loaded_rule.top_rule = data['top_rule']
286
+ if data['alternative'] is not None:
287
+ data['alternative']['top_rule'] = data['top_rule']
276
288
  loaded_rule.alternative = MultiClassStopRule.from_json(data["alternative"])
277
289
  return loaded_rule
278
290
 
@@ -312,7 +324,11 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
312
324
 
313
325
  @classmethod
314
326
  def _from_json(cls, data: Dict[str, Any]) -> MultiClassTopRule:
315
- loaded_rule = Rule._from_json(data)
327
+ loaded_rule = super(MultiClassTopRule, cls)._from_json(data)
328
+ # The following is done to prevent re-initialization of the top rule,
329
+ # so the top rule that is already initialized is passed in the data instead of its json serialization.
330
+ if data['refinement'] is not None:
331
+ data['refinement']['top_rule'] = loaded_rule
316
332
  loaded_rule.refinement = MultiClassStopRule.from_json(data["refinement"])
317
333
  loaded_rule.alternative = MultiClassTopRule.from_json(data["alternative"])
318
334
  return loaded_rule
@@ -25,6 +25,16 @@ if TYPE_CHECKING:
25
25
  matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
26
26
 
27
27
 
28
+ def flatten_list(a: List):
29
+ a_flattened = []
30
+ for c in a:
31
+ if is_iterable(c):
32
+ a_flattened.extend(list(c))
33
+ else:
34
+ a_flattened.append(c)
35
+ return a_flattened
36
+
37
+
28
38
  def make_list(value: Any) -> List:
29
39
  """
30
40
  Make a list from a value.
@@ -97,15 +107,13 @@ class SubclassJSONSerializer:
97
107
  def to_json(self) -> Dict[str, Any]:
98
108
  return {"_type": get_full_class_name(self.__class__), **self._to_json()}
99
109
 
100
- @abstractmethod
101
110
  def _to_json(self) -> Dict[str, Any]:
102
111
  """
103
112
  Create a json dict from the object.
104
113
  """
105
- pass
114
+ raise NotImplementedError()
106
115
 
107
116
  @classmethod
108
- @abstractmethod
109
117
  def _from_json(cls, data: Dict[str, Any]) -> Self:
110
118
  """
111
119
  Create a variable from a json dict.
@@ -140,7 +148,16 @@ class SubclassJSONSerializer:
140
148
  """
141
149
  if data is None:
142
150
  return None
151
+ if not isinstance(data, dict) or ('_type' not in data):
152
+ return data
153
+ # check if type module is builtins
154
+ data_type = get_type_from_string(data["_type"])
155
+ if data_type.__module__ == 'builtins':
156
+ if is_iterable(data['value']) and not isinstance(data['value'], dict):
157
+ return data_type([cls.from_json(d) for d in data['value']])
158
+ return data_type(data["value"])
143
159
  if get_full_class_name(cls) == data["_type"]:
160
+ data.pop("_type")
144
161
  return cls._from_json(data)
145
162
  for subclass in recursive_subclasses(SubclassJSONSerializer):
146
163
  if get_full_class_name(subclass) == data["_type"]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.0.12
3
+ Version: 0.0.14
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -0,0 +1,51 @@
1
+ import json
2
+ from unittest import TestCase
3
+
4
+ from typing_extensions import List
5
+
6
+ from ripple_down_rules.datasets import load_zoo_dataset
7
+ from ripple_down_rules.datastructures import CaseQuery, Case
8
+ from ripple_down_rules.experts import Human
9
+ from ripple_down_rules.rdr import SingleClassRDR, MultiClassRDR, GeneralRDR
10
+ from ripple_down_rules.utils import make_set, flatten_list
11
+ from test_helpers.helpers import get_fit_mcrdr, get_fit_scrdr, get_fit_grdr
12
+
13
+
14
+ class TestJSONSerialization(TestCase):
15
+ all_cases: List[Case]
16
+ targets: List[str]
17
+ cache_dir: str = "./test_results"
18
+ expert_answers_dir: str = "./test_expert_answers"
19
+
20
+ @classmethod
21
+ def setUpClass(cls):
22
+ cls.all_cases, cls.targets = load_zoo_dataset(cls.cache_dir + "/zoo_dataset.pkl")
23
+
24
+ def test_scrdr_json_serialization(self):
25
+ scrdr = get_fit_scrdr(self.all_cases, self.targets)
26
+ filename = f"{self.cache_dir}/scrdr.json"
27
+ scrdr.save(filename)
28
+ scrdr = SingleClassRDR.load(filename)
29
+ for case, target in zip(self.all_cases, self.targets):
30
+ cat = scrdr.classify(case)
31
+ self.assertEqual(cat, target)
32
+
33
+ def test_mcrdr_json_serialization(self):
34
+ mcrdr = get_fit_mcrdr(self.all_cases, self.targets)
35
+ filename = f"{self.cache_dir}/mcrdr.json"
36
+ mcrdr.save(filename)
37
+ mcrdr = MultiClassRDR.load(filename)
38
+ for case, target in zip(self.all_cases, self.targets):
39
+ cat = mcrdr.classify(case)
40
+ self.assertEqual(make_set(cat), make_set(target))
41
+
42
+ def test_grdr_json_serialization(self):
43
+ grdr, all_targets = get_fit_grdr(self.all_cases, self.targets)
44
+ filename = f"{self.cache_dir}/grdr.json"
45
+ grdr.save(filename)
46
+ grdr = GeneralRDR.load(filename)
47
+ for case, case_targets in zip(self.all_cases[:len(all_targets)], all_targets):
48
+ cat = grdr.classify(case)
49
+ cat = flatten_list(cat)
50
+ case_targets = flatten_list(case_targets)
51
+ self.assertEqual(make_set(cat), make_set(case_targets))
@@ -11,6 +11,7 @@ from ripple_down_rules.datastructures import Case, MCRDRMode, \
11
11
  from ripple_down_rules.experts import Human
12
12
  from ripple_down_rules.rdr import SingleClassRDR, MultiClassRDR, GeneralRDR
13
13
  from ripple_down_rules.utils import render_tree, get_all_subclasses, make_set
14
+ from test_helpers.helpers import get_fit_scrdr, get_fit_mcrdr
14
15
 
15
16
 
16
17
  class TestRDR(TestCase):
@@ -71,7 +72,7 @@ class TestRDR(TestCase):
71
72
  expert.save_answers(file)
72
73
 
73
74
  def test_write_scrdr_to_python_file(self):
74
- scrdr = self.get_fit_scrdr()
75
+ scrdr = get_fit_scrdr(self.all_cases,self.targets)
75
76
  scrdr.write_to_python_file(self.generated_rdrs_dir)
76
77
  classify_species_scrdr = scrdr.get_rdr_classifier_from_python_file(self.generated_rdrs_dir)
77
78
  for case, target in zip(self.all_cases, self.targets):
@@ -79,7 +80,7 @@ class TestRDR(TestCase):
79
80
  self.assertEqual(cat, target)
80
81
 
81
82
  def test_write_mcrdr_to_python_file(self):
82
- mcrdr = self.get_fit_mcrdr()
83
+ mcrdr = get_fit_mcrdr(self.all_cases, self.targets)
83
84
  mcrdr.write_to_python_file(self.generated_rdrs_dir)
84
85
  classify_species_mcrdr = mcrdr.get_rdr_classifier_from_python_file(self.generated_rdrs_dir)
85
86
  for case, target in zip(self.all_cases, self.targets):
@@ -258,7 +259,7 @@ class TestRDR(TestCase):
258
259
  if use_loaded_answers:
259
260
  expert.load_answers(filename)
260
261
 
261
- fit_scrdr = self.get_fit_scrdr(draw_tree=False)
262
+ fit_scrdr = get_fit_scrdr(self.all_cases, self.targets, draw_tree=False)
262
263
 
263
264
  grdr = GeneralRDR({type(fit_scrdr.start_rule.conclusion): fit_scrdr})
264
265
 
@@ -309,7 +310,7 @@ class TestRDR(TestCase):
309
310
  if use_loaded_answers:
310
311
  expert.load_answers(filename)
311
312
 
312
- fit_scrdr = self.get_fit_scrdr(draw_tree=False)
313
+ fit_scrdr = get_fit_scrdr(self.all_cases, self.targets, draw_tree=False)
313
314
 
314
315
  grdr = GeneralRDR({type(fit_scrdr.start_rule.conclusion): fit_scrdr})
315
316
 
@@ -326,23 +327,3 @@ class TestRDR(TestCase):
326
327
  cwd = os.getcwd()
327
328
  file = os.path.join(cwd, filename)
328
329
  expert.save_answers(file)
329
-
330
- def get_fit_scrdr(self, draw_tree=False) -> SingleClassRDR:
331
- filename = self.expert_answers_dir + "/scrdr_expert_answers_fit"
332
- expert = Human(use_loaded_answers=True)
333
- expert.load_answers(filename)
334
-
335
- scrdr = SingleClassRDR()
336
- case_queries = [CaseQuery(case, target=target) for case, target in zip(self.all_cases, self.targets)]
337
- scrdr.fit(case_queries, expert=expert,
338
- animate_tree=draw_tree)
339
- return scrdr
340
-
341
- def get_fit_mcrdr(self, draw_tree: bool = False):
342
- filename = self.expert_answers_dir + "/mcrdr_expert_answers_stop_only_fit"
343
- expert = Human(use_loaded_answers=True)
344
- expert.load_answers(filename)
345
- mcrdr = MultiClassRDR()
346
- case_queries = [CaseQuery(case, target=target) for case, target in zip(self.all_cases, self.targets)]
347
- mcrdr.fit(case_queries, expert=expert, animate_tree=draw_tree)
348
- return mcrdr
@@ -1,43 +0,0 @@
1
- import json
2
- from unittest import TestCase
3
-
4
- from typing_extensions import List
5
-
6
- from ripple_down_rules.datasets import load_zoo_dataset
7
- from ripple_down_rules.datastructures import CaseQuery, Case
8
- from ripple_down_rules.experts import Human
9
- from ripple_down_rules.rdr import SingleClassRDR
10
-
11
-
12
- class TestJSONSerialization(TestCase):
13
- all_cases: List[Case]
14
- targets: List[str]
15
- cache_dir: str = "./test_results"
16
- expert_answers_dir: str = "./test_expert_answers"
17
-
18
- @classmethod
19
- def setUpClass(cls):
20
- cls.all_cases, cls.targets = load_zoo_dataset(cls.cache_dir + "/zoo_dataset.pkl")
21
-
22
- def test_scrdr_json_serialization(self):
23
- scrdr = self.get_fit_scrdr()
24
- filename = f"{self.cache_dir}/scrdr.json"
25
- scrdr.save(filename)
26
- scrdr = SingleClassRDR.load(filename)
27
- for case, target in zip(self.all_cases, self.targets):
28
- cat = scrdr.classify(case)
29
- self.assertEqual(cat, target)
30
-
31
- def get_fit_scrdr(self, draw_tree=False) -> SingleClassRDR:
32
- filename = self.expert_answers_dir + "/scrdr_expert_answers_fit"
33
- expert = Human(use_loaded_answers=True)
34
- expert.load_answers(filename)
35
-
36
- scrdr = SingleClassRDR()
37
- case_queries = [CaseQuery(case, target=target) for case, target in zip(self.all_cases, self.targets)]
38
- scrdr.fit(case_queries, expert=expert,
39
- animate_tree=draw_tree)
40
- for case, target in zip(self.all_cases, self.targets):
41
- cat = scrdr.classify(case)
42
- self.assertEqual(cat, target)
43
- return scrdr