ripple-down-rules 0.0.10 (tar.gz) → 0.0.11 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/PKG-INFO +1 -1
  2. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/pyproject.toml +1 -1
  3. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules/rdr.py +119 -41
  4. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules/rules.py +97 -33
  5. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
  6. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/test/test_rdr.py +28 -2
  7. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/LICENSE +0 -0
  8. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/README.md +0 -0
  9. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/setup.cfg +0 -0
  10. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules/__init__.py +0 -0
  11. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules/datasets.py +0 -0
  12. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules/datastructures/__init__.py +0 -0
  13. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules/datastructures/callable_expression.py +0 -0
  14. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules/datastructures/case.py +0 -0
  15. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules/datastructures/dataclasses.py +0 -0
  16. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules/datastructures/enums.py +0 -0
  17. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules/experts.py +0 -0
  18. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules/failures.py +0 -0
  19. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules/prompt.py +0 -0
  20. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules/utils.py +0 -0
  21. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules.egg-info/SOURCES.txt +0 -0
  22. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
  23. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
  24. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/test/test_json_serialization.py +0 -0
  25. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/test/test_rdr_alchemy.py +0 -0
  26. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/test/test_relational_rdr.py +0 -0
  27. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/test/test_relational_rdr_alchemy.py +0 -0
  28. {ripple_down_rules-0.0.10 → ripple_down_rules-0.0.11}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.0.10
3
+ Version: 0.0.11
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
6
6
 
7
7
  [project]
8
8
  name = "ripple_down_rules"
9
- version = "0.0.10"
9
+ version = "0.0.11"
10
10
  description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
11
11
  readme = "README.md"
12
12
  authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
@@ -1,16 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import importlib
3
4
  from abc import ABC, abstractmethod
4
5
  from copy import copy, deepcopy
6
+ from types import ModuleType
5
7
 
6
8
  from matplotlib import pyplot as plt
7
9
  from ordered_set import OrderedSet
8
10
  from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
9
- from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple
11
+ from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable
10
12
 
11
13
  from .datastructures import Case, MCRDRMode, CallableExpression, CaseAttribute, CaseQuery
12
14
  from .experts import Expert, Human
13
- from .rules import Rule, SingleClassRule, MultiClassTopRule
15
+ from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
14
16
  from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
15
17
  get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list
16
18
 
@@ -176,7 +178,90 @@ class RippleDownRules(ABC):
176
178
  RDR = RippleDownRules
177
179
 
178
180
 
179
- class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):
181
+ class RDRWithCodeWriter(RDR, ABC):
182
+
183
+ @abstractmethod
184
+ def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = ""):
185
+ """
186
+ Write the rules as source code to a file.
187
+
188
+ :param rule: The rule to write as source code.
189
+ :param file: The file to write the source code to.
190
+ :param parent_indent: The indentation of the parent rule.
191
+ """
192
+ pass
193
+
194
+ def write_to_python_file(self, file_path: str):
195
+ """
196
+ Write the tree of rules as source code to a file.
197
+
198
+ :param file_path: The path to the file to write the source code to.
199
+ """
200
+ func_def = f"def classify(case: {self.case_type.__name__}) -> {self._get_conclusion_type_hint()}:\n"
201
+ with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
202
+ f.write(self._get_imports())
203
+ f.write(func_def)
204
+ f.write(f"{' '*4}if not isinstance(case, Case):\n"
205
+ f"{' '*4} case = create_case(case, recursion_idx=3)\n""")
206
+ self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)
207
+
208
+ @abstractmethod
209
+ def _get_conclusion_type_hint(self) -> str:
210
+ """
211
+ :return: The type hint of the conclusion of the rdr as a string.
212
+ """
213
+ pass
214
+
215
+ def _get_imports(self) -> str:
216
+ """
217
+ :return: The imports for the generated python file of the RDR as a string.
218
+ """
219
+ imports = ""
220
+ if self.case_type.__module__ != "builtins":
221
+ imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
222
+ if self.conclusion_type.__module__ != "builtins":
223
+ imports += f"from {self.conclusion_type.__module__} import {self.conclusion_type.__name__}\n"
224
+ imports += "from ripple_down_rules.datastructures import Case, create_case\n"
225
+ imports += "\n\n"
226
+ return imports
227
+
228
+ def get_rdr_classifier_from_python_file(self, package_name) -> Callable[[Any], Any]:
229
+ """
230
+ :param package_name: The name of the package that contains the RDR classifier function.
231
+ :return: The module that contains the rdr classifier function.
232
+ """
233
+ return importlib.import_module(f"{package_name.strip('./')}.{self.generated_python_file_name}").classify
234
+
235
+ @property
236
+ def generated_python_file_name(self) -> str:
237
+ return f"{self.conclusion_type.__name__.lower()}_{self.__class__.__name__}"
238
+
239
+ @property
240
+ def python_file_name(self):
241
+ return f"{self.start_rule.conclusion.__name__.lower()}_rdr"
242
+
243
+ @property
244
+ def case_type(self) -> Type:
245
+ """
246
+ :return: The type of the case (input) to the RDR classifier.
247
+ """
248
+ if isinstance(self.start_rule.corner_case, Case):
249
+ return self.start_rule.corner_case._type
250
+ else:
251
+ return type(self.start_rule.corner_case)
252
+
253
+ @property
254
+ def conclusion_type(self) -> Type:
255
+ """
256
+ :return: The type of the conclusion of the RDR classifier.
257
+ """
258
+ if isinstance(self.start_rule.conclusion, CallableExpression):
259
+ return self.start_rule.conclusion.conclusion_type
260
+ else:
261
+ return type(self.start_rule.conclusion)
262
+
263
+
264
+ class SingleClassRDR(RDRWithCodeWriter, SubclassJSONSerializer):
180
265
 
181
266
  def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
182
267
  -> Union[CaseAttribute, CallableExpression]:
@@ -221,43 +306,6 @@ class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):
221
306
  matched_rule = self.start_rule(case)
222
307
  return matched_rule if matched_rule else self.start_rule
223
308
 
224
- def write_to_python_file(self, filename: str):
225
- """
226
- Write the tree of rules as source code to a file.
227
- """
228
- case = self.start_rule.corner_case
229
- if isinstance(case, Case):
230
- case_type = case._type
231
- else:
232
- case_type = type(case)
233
- case_module = case_type.__module__
234
- conclusion = self.start_rule.conclusion
235
- if isinstance(conclusion, CallableExpression):
236
- conclusion_types = [conclusion.conclusion_type]
237
- elif isinstance(conclusion, CaseAttribute):
238
- conclusion_types = list(conclusion._value_range)
239
- else:
240
- conclusion_types = [type(conclusion)]
241
- imports = ""
242
- if case_module != "builtins":
243
- imports += f"from {case_module} import {case_type.__name__}\n"
244
- if len(conclusion_types) > 1:
245
- conclusion_name = "Union[" + ", ".join([c.__name__ for c in conclusion_types]) + "]"
246
- else:
247
- conclusion_name = conclusion_types[0].__name__
248
- for conclusion_type in conclusion_types:
249
- if conclusion_type.__module__ != "builtins":
250
- imports += f"from {conclusion_type.__module__} import {conclusion_name}\n"
251
- imports += "from ripple_down_rules.datastructures import Case, create_case\n"
252
- imports += "\n\n"
253
- func_def = f"def classify_{conclusion_name.lower()}(case: {case_type.__name__}) -> {conclusion_name}:\n"
254
- with open(filename, "w") as f:
255
- f.write(imports)
256
- f.write(func_def)
257
- f.write(f"{' '*4}if not isinstance(case, Case):\n"
258
- f"{' '*4} case = create_case(case, recursion_idx=3)\n""")
259
- self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)
260
-
261
309
  def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file, parent_indent: str = ""):
262
310
  """
263
311
  Write the rules as source code to a file.
@@ -272,6 +320,9 @@ class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):
272
320
  if rule.alternative:
273
321
  self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
274
322
 
323
+ def _get_conclusion_type_hint(self) -> str:
324
+ return self.conclusion_type.__name__
325
+
275
326
  def _to_json(self) -> Dict[str, Any]:
276
327
  return {"start_rule": self.start_rule.to_json()}
277
328
 
@@ -284,7 +335,7 @@ class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):
284
335
  return cls(start_rule)
285
336
 
286
337
 
287
- class MultiClassRDR(RippleDownRules):
338
+ class MultiClassRDR(RDRWithCodeWriter):
288
339
  """
289
340
  A multi class ripple down rules classifier, which can draw multiple conclusions for a case.
290
341
  This is done by going through all rules and checking if they fire or not, and adding stopping rules if needed,
@@ -378,6 +429,33 @@ class MultiClassRDR(RippleDownRules):
378
429
  evaluated_rule = next_rule
379
430
  return self.conclusions
380
431
 
432
+ def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
433
+ file, parent_indent: str = ""):
434
+ """
435
+ Write the rules as source code to a file.
436
+ """
437
+ if rule == self.start_rule:
438
+ file.write(f"{parent_indent}conclusions = set()\n")
439
+ if rule.conditions:
440
+ file.write(rule.write_condition_as_source_code(parent_indent))
441
+ conclusion_indent = parent_indent
442
+ if hasattr(rule, "refinement") and rule.refinement:
443
+ self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ")
444
+ conclusion_indent = parent_indent + " "*4
445
+ file.write(f"{conclusion_indent}else:\n")
446
+ file.write(rule.write_conclusion_as_source_code(conclusion_indent))
447
+
448
+ if rule.alternative:
449
+ self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
450
+
451
+ def _get_conclusion_type_hint(self) -> str:
452
+ return f"Set[{self.conclusion_type.__name__}]"
453
+
454
+ def _get_imports(self) -> str:
455
+ imports = super()._get_imports().strip('\n')
456
+ imports += "\nfrom typing_extensions import Set\n\n"
457
+ return imports
458
+
381
459
  def update_start_rule(self, case: Union[Case, SQLTable], target: Any, expert: Expert):
382
460
  """
383
461
  Update the starting rule of the classifier.
@@ -70,6 +70,54 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
70
70
  """
71
71
  pass
72
72
 
73
+ def write_conclusion_as_source_code(self, parent_indent: str = "") -> str:
74
+ """
75
+ Get the source code representation of the conclusion of the rule.
76
+
77
+ :param parent_indent: The indentation of the parent rule.
78
+ """
79
+ if isinstance(self.conclusion, CallableExpression):
80
+ conclusion = self.conclusion.parsed_user_input
81
+ elif isinstance(self.conclusion, Enum):
82
+ conclusion = str(self.conclusion)
83
+ else:
84
+ conclusion = self.conclusion
85
+ return self._conclusion_source_code_clause(conclusion, parent_indent=parent_indent)
86
+
87
+ @abstractmethod
88
+ def _conclusion_source_code_clause(self, conclusion: Any, parent_indent: str = "") -> str:
89
+ pass
90
+
91
+ def write_condition_as_source_code(self, parent_indent: str = "") -> str:
92
+ """
93
+ Get the source code representation of the conditions of the rule.
94
+
95
+ :param parent_indent: The indentation of the parent rule.
96
+ """
97
+ if_clause = self._if_statement_source_code_clause()
98
+ return f"{parent_indent}{if_clause} {self.conditions.parsed_user_input}:\n"
99
+
100
+ @abstractmethod
101
+ def _if_statement_source_code_clause(self) -> str:
102
+ pass
103
+
104
+ def _to_json(self) -> Dict[str, Any]:
105
+ json_serialization = {"conditions": self.conditions.to_json(),
106
+ "conclusion": self.conclusion.to_json(),
107
+ "parent": self.parent.json_serialization if self.parent else None,
108
+ "corner_case": self.corner_case.to_json() if self.corner_case else None,
109
+ "weight": self.weight}
110
+ return json_serialization
111
+
112
+ @classmethod
113
+ def _from_json(cls, data: Dict[str, Any]) -> Rule:
114
+ loaded_rule = cls(conditions=CallableExpression.from_json(data["conditions"]),
115
+ conclusion=CallableExpression.from_json(data["conclusion"]),
116
+ parent=cls.from_json(data["parent"]),
117
+ corner_case=Case.from_json(data["corner_case"]),
118
+ weight=data["weight"])
119
+ return loaded_rule
120
+
73
121
  @property
74
122
  def name(self):
75
123
  """
@@ -172,50 +220,25 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
172
220
  else:
173
221
  self.alternative = new_rule
174
222
 
175
- def write_conclusion_as_source_code(self, parent_indent: str = "") -> str:
176
- """
177
- Get the source code representation of the conclusion of the rule.
178
-
179
- :param parent_indent: The indentation of the parent rule.
180
- """
181
- if isinstance(self.conclusion, CallableExpression):
182
- conclusion = self.conclusion.parsed_user_input
183
- elif isinstance(self.conclusion, Enum):
184
- conclusion = str(self.conclusion)
185
- else:
186
- conclusion = self.conclusion
187
- return f"{parent_indent}{' ' * 4}return {conclusion}\n"
188
-
189
- def write_condition_as_source_code(self, parent_indent: str = "") -> str:
190
- """
191
- Get the source code representation of the conditions of the rule.
192
-
193
- :param parent_indent: The indentation of the parent rule.
194
- """
195
- if_clause = "elif" if self.weight == RDREdge.Alternative.value else "if"
196
- return f"{parent_indent}{if_clause} {self.conditions.parsed_user_input}:\n"
197
-
198
223
  def _to_json(self) -> Dict[str, Any]:
199
- self.json_serialization = {"conditions": self.conditions.to_json(),
200
- "conclusion": self.conclusion.to_json(),
201
- "parent": self.parent.json_serialization if self.parent else None,
202
- "corner_case": self.corner_case.to_json() if self.corner_case else None,
203
- "weight": self.weight,
224
+ self.json_serialization = {**super(SingleClassRule, self)._to_json(),
204
225
  "refinement": self.refinement.to_json() if self.refinement is not None else None,
205
226
  "alternative": self.alternative.to_json() if self.alternative is not None else None}
206
227
  return self.json_serialization
207
228
 
208
229
  @classmethod
209
230
  def _from_json(cls, data: Dict[str, Any]) -> SingleClassRule:
210
- loaded_rule = cls(conditions=CallableExpression.from_json(data["conditions"]),
211
- conclusion=CallableExpression.from_json(data["conclusion"]),
212
- parent=SingleClassRule.from_json(data["parent"]),
213
- corner_case=Case.from_json(data["corner_case"]),
214
- weight=data["weight"])
231
+ loaded_rule = super(SingleClassRule, cls)._from_json(data)
215
232
  loaded_rule.refinement = SingleClassRule.from_json(data["refinement"])
216
233
  loaded_rule.alternative = SingleClassRule.from_json(data["alternative"])
217
234
  return loaded_rule
218
235
 
236
+ def _conclusion_source_code_clause(self, conclusion: Any, parent_indent: str = "") -> str:
237
+ return f"{parent_indent}{' ' * 4}return {conclusion}\n"
238
+
239
+ def _if_statement_source_code_clause(self) -> str:
240
+ return "elif" if self.weight == RDREdge.Alternative.value else "if"
241
+
219
242
 
220
243
  class MultiClassStopRule(Rule, HasAlternativeRule):
221
244
  """
@@ -240,6 +263,25 @@ class MultiClassStopRule(Rule, HasAlternativeRule):
240
263
  else:
241
264
  return self.top_rule.alternative
242
265
 
266
+ def _to_json(self) -> Dict[str, Any]:
267
+ self.json_serialization = {**Rule._to_json(self),
268
+ "top_rule": self.top_rule.to_json(),
269
+ "alternative": self.alternative.to_json() if self.alternative is not None else None}
270
+ return self.json_serialization
271
+
272
+ @classmethod
273
+ def _from_json(cls, data: Dict[str, Any]) -> MultiClassStopRule:
274
+ loaded_rule = Rule._from_json(data)
275
+ loaded_rule.top_rule = MultiClassTopRule.from_json(data["top_rule"])
276
+ loaded_rule.alternative = MultiClassStopRule.from_json(data["alternative"])
277
+ return loaded_rule
278
+
279
+ def _conclusion_source_code_clause(self, conclusion: Any, parent_indent: str = "") -> str:
280
+ return f"{parent_indent}{' ' * 4}pass\n"
281
+
282
+ def _if_statement_source_code_clause(self) -> str:
283
+ return "elif" if self.weight == RDREdge.Alternative.value else "if"
284
+
243
285
 
244
286
  class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
245
287
  """
@@ -261,3 +303,25 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
261
303
  self.refinement = MultiClassStopRule(conditions, corner_case=x, parent=self)
262
304
  elif not self.fired:
263
305
  self.alternative = MultiClassTopRule(conditions, target, corner_case=x, parent=self)
306
+
307
+ def _to_json(self) -> Dict[str, Any]:
308
+ self.json_serialization = {**Rule._to_json(self),
309
+ "refinement": self.refinement.to_json() if self.refinement is not None else None,
310
+ "alternative": self.alternative.to_json() if self.alternative is not None else None}
311
+ return self.json_serialization
312
+
313
+ @classmethod
314
+ def _from_json(cls, data: Dict[str, Any]) -> MultiClassTopRule:
315
+ loaded_rule = Rule._from_json(data)
316
+ loaded_rule.refinement = MultiClassStopRule.from_json(data["refinement"])
317
+ loaded_rule.alternative = MultiClassTopRule.from_json(data["alternative"])
318
+ return loaded_rule
319
+
320
+ def _conclusion_source_code_clause(self, conclusion: Any, parent_indent: str = "") -> str:
321
+ statement = f"{parent_indent}{' ' * 4}conclusions.add({conclusion})\n"
322
+ if self.alternative is None:
323
+ statement += f"{parent_indent}return conclusions\n"
324
+ return statement
325
+
326
+ def _if_statement_source_code_clause(self) -> str:
327
+ return "if"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.0.10
3
+ Version: 0.0.11
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -1,3 +1,4 @@
1
+ import importlib
1
2
  import os
2
3
  from unittest import TestCase, skip
3
4
 
@@ -9,7 +10,7 @@ from ripple_down_rules.datastructures import Case, MCRDRMode, \
9
10
  Case, CaseAttribute, Category, CaseQuery
10
11
  from ripple_down_rules.experts import Human
11
12
  from ripple_down_rules.rdr import SingleClassRDR, MultiClassRDR, GeneralRDR
12
- from ripple_down_rules.utils import render_tree, get_all_subclasses
13
+ from ripple_down_rules.utils import render_tree, get_all_subclasses, make_set
13
14
 
14
15
 
15
16
  class TestRDR(TestCase):
@@ -58,7 +59,6 @@ class TestRDR(TestCase):
58
59
  case_queries = [CaseQuery(case, target=target) for case, target in zip(self.all_cases, self.targets)]
59
60
  scrdr.fit(case_queries, expert=expert,
60
61
  animate_tree=draw_tree)
61
- scrdr.write_to_python_file(self.generated_rdrs_dir + "/scrdr.py")
62
62
  render_tree(scrdr.start_rule, use_dot_exporter=True,
63
63
  filename=self.test_results_dir + f"/scrdr")
64
64
 
@@ -70,6 +70,23 @@ class TestRDR(TestCase):
70
70
  file = os.path.join(cwd, filename)
71
71
  expert.save_answers(file)
72
72
 
73
+ def test_write_scrdr_to_python_file(self):
74
+ scrdr = self.get_fit_scrdr()
75
+ scrdr.write_to_python_file(self.generated_rdrs_dir)
76
+ classify_species_scrdr = scrdr.get_rdr_classifier_from_python_file(self.generated_rdrs_dir)
77
+ for case, target in zip(self.all_cases, self.targets):
78
+ cat = classify_species_scrdr(case)
79
+ self.assertEqual(cat, target)
80
+
81
+ def test_write_mcrdr_to_python_file(self):
82
+ mcrdr = self.get_fit_mcrdr()
83
+ mcrdr.write_to_python_file(self.generated_rdrs_dir)
84
+ classify_species_mcrdr = mcrdr.get_rdr_classifier_from_python_file(self.generated_rdrs_dir)
85
+ for case in self.all_cases:
86
+ cat_1 = mcrdr.classify(case)
87
+ cat_2 = classify_species_mcrdr(case)
88
+ self.assertEqual(make_set(cat_1), make_set(cat_2))
89
+
73
90
  def test_classify_mcrdr(self):
74
91
  use_loaded_answers = True
75
92
  save_answers = False
@@ -321,3 +338,12 @@ class TestRDR(TestCase):
321
338
  scrdr.fit(case_queries, expert=expert,
322
339
  animate_tree=draw_tree)
323
340
  return scrdr
341
+
342
+ def get_fit_mcrdr(self, draw_tree: bool = False):
343
+ filename = self.expert_answers_dir + "/mcrdr_expert_answers_stop_only_fit"
344
+ expert = Human(use_loaded_answers=True)
345
+ expert.load_answers(filename)
346
+ mcrdr = MultiClassRDR()
347
+ case_queries = [CaseQuery(case, target=target) for case, target in zip(self.all_cases, self.targets)]
348
+ mcrdr.fit(case_queries, expert=expert, animate_tree=draw_tree, n_iter=1)
349
+ return mcrdr