ripple-down-rules 0.0.13__tar.gz → 0.0.15__tar.gz

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (28)
  1. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/PKG-INFO +1 -1
  2. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/pyproject.toml +1 -1
  3. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules/datastructures/case.py +9 -1
  4. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules/rdr.py +112 -14
  5. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules/rules.py +16 -3
  6. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules/utils.py +20 -3
  7. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
  8. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/test/test_json_serialization.py +14 -3
  9. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/test/test_rdr.py +12 -2
  10. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/LICENSE +0 -0
  11. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/README.md +0 -0
  12. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/setup.cfg +0 -0
  13. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules/__init__.py +0 -0
  14. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules/datasets.py +0 -0
  15. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules/datastructures/__init__.py +0 -0
  16. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules/datastructures/callable_expression.py +0 -0
  17. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules/datastructures/dataclasses.py +0 -0
  18. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules/datastructures/enums.py +0 -0
  19. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules/experts.py +0 -0
  20. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules/failures.py +0 -0
  21. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules/prompt.py +0 -0
  22. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules.egg-info/SOURCES.txt +0 -0
  23. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
  24. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
  25. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/test/test_rdr_alchemy.py +0 -0
  26. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/test/test_relational_rdr.py +0 -0
  27. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/test/test_relational_rdr_alchemy.py +0 -0
  28. {ripple_down_rules-0.0.13 → ripple_down_rules-0.0.15}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.0.13
3
+ Version: 0.0.15
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
6
6
 
7
7
  [project]
8
8
  name = "ripple_down_rules"
9
- version = "0.0.13"
9
+ version = "0.0.15"
10
10
  description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
11
11
  readme = "README.md"
12
12
  authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from collections import UserDict
4
+ from copy import copy, deepcopy
4
5
  from dataclasses import dataclass
5
6
  from enum import Enum
6
7
 
@@ -9,7 +10,8 @@ from sqlalchemy import MetaData
9
10
  from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColumn, registry
10
11
  from typing_extensions import Any, Optional, Dict, Type, Set, Hashable, Union, List, TYPE_CHECKING
11
12
 
12
- from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, SubclassJSONSerializer
13
+ from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, SubclassJSONSerializer, \
14
+ get_full_class_name, get_type_from_string
13
15
 
14
16
  if TYPE_CHECKING:
15
17
  from ripple_down_rules.rules import Rule
@@ -76,11 +78,17 @@ class Case(UserDict, SubclassJSONSerializer):
76
78
  def _to_json(self) -> Dict[str, Any]:
77
79
  serializable = {k: v for k, v in self.items() if not k.startswith("_")}
78
80
  serializable["_id"] = self._id
81
+ for k, v in serializable.items():
82
+ if isinstance(v, set):
83
+ serializable[k] = {'_type': get_full_class_name(set), 'value': list(v)}
79
84
  return {k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in serializable.items()}
80
85
 
81
86
  @classmethod
82
87
  def _from_json(cls, data: Dict[str, Any]) -> Case:
83
88
  id_ = data.pop("_id")
89
+ for k, v in data.items():
90
+ if isinstance(v, dict) and "_type" in v:
91
+ data[k] = SubclassJSONSerializer.from_json(v)
84
92
  return cls(_id=id_, **data)
85
93
 
86
94
 
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import importlib
4
4
  from abc import ABC, abstractmethod
5
- from copy import copy, deepcopy
5
+ from copy import copy
6
6
  from types import ModuleType
7
7
 
8
8
  from matplotlib import pyplot as plt
@@ -14,7 +14,7 @@ from .datastructures import Case, MCRDRMode, CallableExpression, CaseAttribute,
14
14
  from .experts import Expert, Human
15
15
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
16
16
  from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
17
- get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list
17
+ get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_full_class_name, get_type_from_string
18
18
 
19
19
 
20
20
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -120,7 +120,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
120
120
  plt.show()
121
121
 
122
122
  @staticmethod
123
- def calculate_precision_and_recall(pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> Tuple[List[bool], List[bool]]:
123
+ def calculate_precision_and_recall(pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> Tuple[
124
+ List[bool], List[bool]]:
124
125
  """
125
126
  :param pred_cat: The predicted category.
126
127
  :param target: The target category.
@@ -194,16 +195,17 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
194
195
 
195
196
  :param file_path: The path to the file to write the source code to.
196
197
  """
197
- func_def = f"def classify(case: {self.case_type.__name__}) -> {self._get_conclusion_type_hint()}:\n"
198
+ func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
198
199
  with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
199
200
  f.write(self._get_imports() + "\n\n")
200
201
  f.write(func_def)
201
- f.write(f"{' '*4}if not isinstance(case, Case):\n"
202
- f"{' '*4} case = create_case(case, recursion_idx=3)\n""")
202
+ f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
203
+ f"{' ' * 4} case = create_case(case, recursion_idx=3)\n""")
203
204
  self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)
204
205
 
206
+ @property
205
207
  @abstractmethod
206
- def _get_conclusion_type_hint(self) -> str:
208
+ def conclusion_type_hint(self) -> str:
207
209
  """
208
210
  :return: The type hint of the conclusion of the rdr as a string.
209
211
  """
@@ -254,6 +256,8 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
254
256
  if isinstance(self.start_rule.conclusion, CallableExpression):
255
257
  return self.start_rule.conclusion.conclusion_type
256
258
  else:
259
+ if isinstance(self.start_rule.conclusion, set):
260
+ return type(list(self.start_rule.conclusion)[0])
257
261
  return type(self.start_rule.conclusion)
258
262
 
259
263
 
@@ -316,7 +320,8 @@ class SingleClassRDR(RDRWithCodeWriter):
316
320
  if rule.alternative:
317
321
  self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
318
322
 
319
- def _get_conclusion_type_hint(self) -> str:
323
+ @property
324
+ def conclusion_type_hint(self) -> str:
320
325
  return self.conclusion_type.__name__
321
326
 
322
327
  def _to_json(self) -> Dict[str, Any]:
@@ -427,6 +432,8 @@ class MultiClassRDR(RDRWithCodeWriter):
427
432
  file, parent_indent: str = ""):
428
433
  """
429
434
  Write the rules as source code to a file.
435
+
436
+ :
430
437
  """
431
438
  if rule == self.start_rule:
432
439
  file.write(f"{parent_indent}conclusions = set()\n")
@@ -435,14 +442,15 @@ class MultiClassRDR(RDRWithCodeWriter):
435
442
  conclusion_indent = parent_indent
436
443
  if hasattr(rule, "refinement") and rule.refinement:
437
444
  self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ")
438
- conclusion_indent = parent_indent + " "*4
445
+ conclusion_indent = parent_indent + " " * 4
439
446
  file.write(f"{conclusion_indent}else:\n")
440
447
  file.write(rule.write_conclusion_as_source_code(conclusion_indent))
441
448
 
442
449
  if rule.alternative:
443
450
  self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
444
451
 
445
- def _get_conclusion_type_hint(self) -> str:
452
+ @property
453
+ def conclusion_type_hint(self) -> str:
446
454
  return f"Set[{self.conclusion_type.__name__}]"
447
455
 
448
456
  def _get_imports(self) -> str:
@@ -654,7 +662,7 @@ class GeneralRDR(RippleDownRules):
654
662
  @start_rule.setter
655
663
  def start_rule(self, value: Union[SingleClassRDR, MultiClassRDR]):
656
664
  if value:
657
- self.start_rules_dict[type(value.start_rule.conclusion)] = value
665
+ self.start_rules_dict[value.conclusion_type] = value
658
666
 
659
667
  @property
660
668
  def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
@@ -665,6 +673,19 @@ class GeneralRDR(RippleDownRules):
665
673
  Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
666
674
  the classification until no more categories can be added.
667
675
 
676
+ :param case: The case to classify.
677
+ :return: The categories that the case belongs to.
678
+ """
679
+ return self._classify(self.start_rules_dict, case)
680
+
681
+ @staticmethod
682
+ def _classify(classifiers_dict: Dict[Type, Union[ModuleType, RippleDownRules]],
683
+ case: Union[Case, SQLTable]) -> Optional[List[Any]]:
684
+ """
685
+ Classify a case by going through all classifiers and adding the categories that are classified,
686
+ and then restarting the classification until no more categories can be added.
687
+
688
+ :param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
668
689
  :param case: The case to classify.
669
690
  :return: The categories that the case belongs to.
670
691
  """
@@ -672,8 +693,8 @@ class GeneralRDR(RippleDownRules):
672
693
  case_cp = copy_case(case)
673
694
  while True:
674
695
  added_attributes = False
675
- for cat_type, rdr in self.start_rules_dict.items():
676
- if self.case_has_conclusion(case_cp, cat_type):
696
+ for cat_type, rdr in classifiers_dict.items():
697
+ if GeneralRDR.case_has_conclusion(case_cp, cat_type):
677
698
  continue
678
699
  pred_atts = rdr.classify(case_cp)
679
700
  if pred_atts:
@@ -733,7 +754,7 @@ class GeneralRDR(RippleDownRules):
733
754
  conclusions = rdr.classify(case_cp)
734
755
  else:
735
756
  conclusions = self.start_rules_dict[target_type].fit_case(case_query_cp,
736
- expert, **kwargs)
757
+ expert, **kwargs)
737
758
  self.update_case(case_cp, conclusions, rdr_type)
738
759
 
739
760
  return self.classify(case)
@@ -801,3 +822,80 @@ class GeneralRDR(RippleDownRules):
801
822
  Get all the types of categories that the GRDR can classify.
802
823
  """
803
824
  return list(self.start_rules_dict.keys())
825
+
826
+ def _to_json(self) -> Dict[str, Any]:
827
+ return {"start_rules": {get_full_class_name(t): rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
828
+
829
+ @classmethod
830
+ def _from_json(cls, data: Dict[str, Any]) -> GeneralRDR:
831
+ """
832
+ Create an instance of the class from a json
833
+ """
834
+ start_rules_dict = {}
835
+ for k, v in data["start_rules"].items():
836
+ k = get_type_from_string(k)
837
+ start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
838
+ return cls(start_rules_dict)
839
+
840
+ def write_to_python_file(self, file_path: str):
841
+ """
842
+ Write the tree of rules as source code to a file.
843
+
844
+ :param file_path: The path to the file to write the source code to.
845
+ """
846
+ for rdr in self.start_rules_dict.values():
847
+ rdr.write_to_python_file(file_path)
848
+ func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
849
+ with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
850
+ f.write(self._get_imports(file_path) + "\n\n")
851
+ f.write("classifiers_dict = dict()\n")
852
+ for t, rdr in self.start_rules_dict.items():
853
+ f.write(f"classifiers_dict[{t.__name__}] = {t.__name__.lower()}_classifier\n")
854
+ f.write("\n\n")
855
+ f.write(func_def)
856
+ f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
857
+ f"{' ' * 4} case = create_case(case, recursion_idx=3)\n""")
858
+ f.write(f"{' ' * 4}return GeneralRDR._classify(classifiers_dict, case)\n")
859
+
860
+ @property
861
+ def case_type(self) -> Type:
862
+ """
863
+ :return: The type of the case (input) to the RDR classifier.
864
+ """
865
+ if isinstance(self.start_rule.corner_case, Case):
866
+ return self.start_rule.corner_case._type
867
+ else:
868
+ return type(self.start_rule.corner_case)
869
+
870
+ def get_rdr_classifier_from_python_file(self, file_path: str):
871
+ """
872
+ :param file_path: The path to the file that contains the RDR classifier function.
873
+ :return: The module that contains the rdr classifier function.
874
+ """
875
+ return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
876
+
877
+ @property
878
+ def generated_python_file_name(self) -> str:
879
+ return f"{self.case_type.__name__.lower()}_grdr"
880
+
881
+ @property
882
+ def conclusion_type_hint(self) -> str:
883
+ return f"List[Union[{', '.join([rdr.conclusion_type_hint for rdr in self.start_rules_dict.values()])}]]"
884
+
885
+ def _get_imports(self, file_path: str) -> str:
886
+ imports = ""
887
+ # add type hints
888
+ imports += f"from typing_extensions import List, Union, Set\n"
889
+ # import rdr type
890
+ imports += f"from ripple_down_rules.rdr import GeneralRDR\n"
891
+ # add case type
892
+ imports += f"from ripple_down_rules.datastructures import Case, create_case\n"
893
+ imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
894
+ # add conclusion type imports
895
+ for conclusion_type in self.start_rules_dict.keys():
896
+ imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
897
+ # add rdr python generated functions.
898
+ for conclusion_type, rdr in self.start_rules_dict.items():
899
+ imports += (f"from {file_path.strip('./')}"
900
+ f" import {rdr.generated_python_file_name} as {conclusion_type.__name__.lower()}_classifier\n")
901
+ return imports
@@ -8,7 +8,7 @@ from typing_extensions import List, Optional, Self, Union, Dict, Any
8
8
 
9
9
  from .datastructures import CallableExpression, Case, SQLTable
10
10
  from .datastructures.enums import RDREdge, Stop
11
- from .utils import SubclassJSONSerializer
11
+ from .utils import SubclassJSONSerializer, is_iterable, get_full_class_name
12
12
 
13
13
 
14
14
  class Rule(NodeMixin, SubclassJSONSerializer, ABC):
@@ -102,8 +102,17 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
102
102
  pass
103
103
 
104
104
  def _to_json(self) -> Dict[str, Any]:
105
+ def conclusion_to_json(conclusion):
106
+ if is_iterable(conclusion):
107
+ conclusions = {'_type': get_full_class_name(type(conclusion)), 'value': []}
108
+ for c in conclusion:
109
+ conclusions['value'].append(conclusion_to_json(c))
110
+ else:
111
+ conclusions = conclusion.to_json()
112
+ return conclusions
113
+
105
114
  json_serialization = {"conditions": self.conditions.to_json(),
106
- "conclusion": self.conclusion.to_json(),
115
+ "conclusion": conclusion_to_json(self.conclusion),
107
116
  "parent": self.parent.json_serialization if self.parent else None,
108
117
  "corner_case": self.corner_case.to_json() if self.corner_case else None,
109
118
  "weight": self.weight}
@@ -325,7 +334,11 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
325
334
  return loaded_rule
326
335
 
327
336
  def _conclusion_source_code_clause(self, conclusion: Any, parent_indent: str = "") -> str:
328
- statement = f"{parent_indent}{' ' * 4}conclusions.add({conclusion})\n"
337
+ if is_iterable(conclusion):
338
+ conclusion_str = "{" + ", ".join([str(c) for c in conclusion]) + "}"
339
+ else:
340
+ conclusion_str = "{" + str(conclusion) + "}"
341
+ statement = f"{parent_indent}{' ' * 4}conclusions.update({conclusion_str})\n"
329
342
  if self.alternative is None:
330
343
  statement += f"{parent_indent}return conclusions\n"
331
344
  return statement
@@ -25,6 +25,16 @@ if TYPE_CHECKING:
25
25
  matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
26
26
 
27
27
 
28
+ def flatten_list(a: List):
29
+ a_flattened = []
30
+ for c in a:
31
+ if is_iterable(c):
32
+ a_flattened.extend(list(c))
33
+ else:
34
+ a_flattened.append(c)
35
+ return a_flattened
36
+
37
+
28
38
  def make_list(value: Any) -> List:
29
39
  """
30
40
  Make a list from a value.
@@ -97,15 +107,13 @@ class SubclassJSONSerializer:
97
107
  def to_json(self) -> Dict[str, Any]:
98
108
  return {"_type": get_full_class_name(self.__class__), **self._to_json()}
99
109
 
100
- @abstractmethod
101
110
  def _to_json(self) -> Dict[str, Any]:
102
111
  """
103
112
  Create a json dict from the object.
104
113
  """
105
- pass
114
+ raise NotImplementedError()
106
115
 
107
116
  @classmethod
108
- @abstractmethod
109
117
  def _from_json(cls, data: Dict[str, Any]) -> Self:
110
118
  """
111
119
  Create a variable from a json dict.
@@ -140,7 +148,16 @@ class SubclassJSONSerializer:
140
148
  """
141
149
  if data is None:
142
150
  return None
151
+ if not isinstance(data, dict) or ('_type' not in data):
152
+ return data
153
+ # check if type module is builtins
154
+ data_type = get_type_from_string(data["_type"])
155
+ if data_type.__module__ == 'builtins':
156
+ if is_iterable(data['value']) and not isinstance(data['value'], dict):
157
+ return data_type([cls.from_json(d) for d in data['value']])
158
+ return data_type(data["value"])
143
159
  if get_full_class_name(cls) == data["_type"]:
160
+ data.pop("_type")
144
161
  return cls._from_json(data)
145
162
  for subclass in recursive_subclasses(SubclassJSONSerializer):
146
163
  if get_full_class_name(subclass) == data["_type"]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.0.13
3
+ Version: 0.0.15
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -6,9 +6,9 @@ from typing_extensions import List
6
6
  from ripple_down_rules.datasets import load_zoo_dataset
7
7
  from ripple_down_rules.datastructures import CaseQuery, Case
8
8
  from ripple_down_rules.experts import Human
9
- from ripple_down_rules.rdr import SingleClassRDR, MultiClassRDR
10
- from ripple_down_rules.utils import make_set
11
- from test_helpers.helpers import get_fit_mcrdr, get_fit_scrdr
9
+ from ripple_down_rules.rdr import SingleClassRDR, MultiClassRDR, GeneralRDR
10
+ from ripple_down_rules.utils import make_set, flatten_list
11
+ from test_helpers.helpers import get_fit_mcrdr, get_fit_scrdr, get_fit_grdr
12
12
 
13
13
 
14
14
  class TestJSONSerialization(TestCase):
@@ -38,3 +38,14 @@ class TestJSONSerialization(TestCase):
38
38
  for case, target in zip(self.all_cases, self.targets):
39
39
  cat = mcrdr.classify(case)
40
40
  self.assertEqual(make_set(cat), make_set(target))
41
+
42
+ def test_grdr_json_serialization(self):
43
+ grdr, all_targets = get_fit_grdr(self.all_cases, self.targets)
44
+ filename = f"{self.cache_dir}/grdr.json"
45
+ grdr.save(filename)
46
+ grdr = GeneralRDR.load(filename)
47
+ for case, case_targets in zip(self.all_cases[:len(all_targets)], all_targets):
48
+ cat = grdr.classify(case)
49
+ cat = flatten_list(cat)
50
+ case_targets = flatten_list(case_targets)
51
+ self.assertEqual(make_set(cat), make_set(case_targets))
@@ -10,8 +10,8 @@ from ripple_down_rules.datastructures import Case, MCRDRMode, \
10
10
  Case, CaseAttribute, Category, CaseQuery
11
11
  from ripple_down_rules.experts import Human
12
12
  from ripple_down_rules.rdr import SingleClassRDR, MultiClassRDR, GeneralRDR
13
- from ripple_down_rules.utils import render_tree, get_all_subclasses, make_set
14
- from test_helpers.helpers import get_fit_scrdr, get_fit_mcrdr
13
+ from ripple_down_rules.utils import render_tree, get_all_subclasses, make_set, flatten_list
14
+ from test_helpers.helpers import get_fit_scrdr, get_fit_mcrdr, get_fit_grdr
15
15
 
16
16
 
17
17
  class TestRDR(TestCase):
@@ -87,6 +87,16 @@ class TestRDR(TestCase):
87
87
  cat = classify_species_mcrdr(case)
88
88
  self.assertEqual(make_set(cat), make_set(target))
89
89
 
90
+ def test_write_grdr_to_python_file(self):
91
+ grdr, all_targets = get_fit_grdr(self.all_cases, self.targets)
92
+ grdr.write_to_python_file(self.generated_rdrs_dir)
93
+ classify_animal_grdr = grdr.get_rdr_classifier_from_python_file(self.generated_rdrs_dir)
94
+ for case, case_targets in zip(self.all_cases[:len(all_targets)], all_targets):
95
+ cat = classify_animal_grdr(case)
96
+ cat = flatten_list(cat)
97
+ case_targets = flatten_list(case_targets)
98
+ self.assertEqual(make_set(cat), make_set(case_targets))
99
+
90
100
  def test_classify_mcrdr(self):
91
101
  use_loaded_answers = True
92
102
  save_answers = False