ripple-down-rules 0.4.7__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ripple_down_rules/rdr.py CHANGED
@@ -2,13 +2,24 @@ from __future__ import annotations
2
2
 
3
3
  import copyreg
4
4
  import importlib
5
+ import os
6
+
7
+ from . import logger
5
8
  import sys
6
9
  from abc import ABC, abstractmethod
7
10
  from copy import copy
8
11
  from io import TextIOWrapper
9
12
  from types import ModuleType
10
13
 
11
- from matplotlib import pyplot as plt
14
+ try:
15
+ from matplotlib import pyplot as plt
16
+ Figure = plt.Figure
17
+ except ImportError as e:
18
+ logger.debug(f"{e}: matplotlib is not installed")
19
+ matplotlib = None
20
+ Figure = None
21
+ plt = None
22
+
12
23
  from sqlalchemy.orm import DeclarativeBase as SQLTable
13
24
  from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
14
25
 
@@ -19,17 +30,21 @@ from .datastructures.enums import MCRDRMode
19
30
  from .experts import Expert, Human
20
31
  from .helpers import is_matching
21
32
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
22
- from .user_interface.gui import RDRCaseViewer
33
+ try:
34
+ from .user_interface.gui import RDRCaseViewer
35
+ except ImportError as e:
36
+ RDRCaseViewer = None
23
37
  from .utils import draw_tree, make_set, copy_case, \
24
38
  SubclassJSONSerializer, make_list, get_type_from_string, \
25
- is_conflicting, update_case, get_imports_from_scope, extract_function_source, extract_imports
39
+ is_conflicting, update_case, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
40
+ is_iterable, str_to_snake_case
26
41
 
27
42
 
28
43
  class RippleDownRules(SubclassJSONSerializer, ABC):
29
44
  """
30
45
  The abstract base class for the ripple down rules classifiers.
31
46
  """
32
- fig: Optional[plt.Figure] = None
47
+ fig: Optional[Figure] = None
33
48
  """
34
49
  The figure to draw the tree on.
35
50
  """
@@ -45,17 +60,94 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
45
60
  """
46
61
  The name of the classifier.
47
62
  """
63
+ case_type: Optional[Type] = None
64
+ """
65
+ The type of the case (input) to the RDR classifier.
66
+ """
67
+ case_name: Optional[str] = None
68
+ """
69
+ The name of the case type.
70
+ """
71
+ metadata_folder: str = "rdr_metadata"
72
+ """
73
+ The folder to save the metadata of the RDR classifier.
74
+ """
75
+ model_name: Optional[str] = None
76
+ """
77
+ The name of the model. If None, the model name will be the generated python file name.
78
+ """
48
79
 
49
- def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None):
80
+ def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None,
81
+ save_dir: Optional[str] = None, ask_always: bool = True, model_name: Optional[str] = None):
50
82
  """
51
83
  :param start_rule: The starting rule for the classifier.
84
+ :param viewer: The viewer gui to use for the classifier. If None, no viewer is used.
85
+ :param save_dir: The directory to save the classifier to.
86
+ :param ask_always: Whether to always ask the expert (True) or only ask when classification fails (False).
52
87
  """
88
+ self.ask_always: bool = ask_always
89
+ self.model_name: Optional[str] = model_name
90
+ self.save_dir = save_dir
53
91
  self.start_rule = start_rule
54
- self.fig: Optional[plt.Figure] = None
92
+ self.fig: Optional[Figure] = None
55
93
  self.viewer: Optional[RDRCaseViewer] = viewer
56
94
  if self.viewer is not None:
57
95
  self.viewer.set_save_function(self.save)
58
96
 
97
+ def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None) -> str:
98
+ """
99
+ Save the classifier to a file.
100
+
101
+ :param save_dir: The directory to save the classifier to.
102
+ :param model_name: The name of the model to save. If None, a default name is generated.
103
+ :param postfix: The postfix to add to the file name.
104
+ :return: The name of the saved model.
105
+ """
106
+ save_dir = save_dir or self.save_dir
107
+ if save_dir is None:
108
+ raise ValueError("The save directory cannot be None. Please provide a valid directory to save"
109
+ " the classifier.")
110
+ if not os.path.exists(save_dir + '/__init__.py'):
111
+ os.makedirs(save_dir, exist_ok=True)
112
+ with open(save_dir + '/__init__.py', 'w') as f:
113
+ f.write("# This is an empty __init__.py file to make the directory a package.\n")
114
+ if model_name is not None:
115
+ self.model_name = model_name
116
+ elif self.model_name is None:
117
+ self.model_name = self.generated_python_file_name
118
+ model_dir = os.path.join(save_dir, self.model_name)
119
+ os.makedirs(model_dir, exist_ok=True)
120
+ json_dir = os.path.join(model_dir, self.metadata_folder)
121
+ os.makedirs(json_dir, exist_ok=True)
122
+ self.to_json_file(os.path.join(json_dir, self.model_name))
123
+ self._write_to_python(model_dir)
124
+ return self.model_name
125
+
126
+ @classmethod
127
+ def load(cls, load_dir: str, model_name: str) -> Self:
128
+ """
129
+ Load the classifier from a file.
130
+
131
+ :param load_dir: The path to the model directory to load the classifier from.
132
+ :param model_name: The name of the model to load.
133
+ """
134
+ model_dir = os.path.join(load_dir, model_name)
135
+ json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
136
+ rdr = cls.from_json_file(json_file)
137
+ rdr.update_from_python(model_dir)
138
+ rdr.save_dir = load_dir
139
+ rdr.model_name = model_name
140
+ return rdr
141
+
142
+ @abstractmethod
143
+ def _write_to_python(self, model_dir: str):
144
+ """
145
+ Write the tree of rules as source code to a file.
146
+
147
+ :param model_dir: The path to the directory to write the source code to.
148
+ """
149
+ pass
150
+
59
151
  def set_viewer(self, viewer: RDRCaseViewer):
60
152
  """
61
153
  Set the viewer for the classifier.
@@ -82,6 +174,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
82
174
  """
83
175
  targets = []
84
176
  if animate_tree:
177
+ if plt is None:
178
+ raise ImportError("matplotlib is not installed, cannot animate the tree.")
85
179
  plt.ion()
86
180
  i = 0
87
181
  stop_iterating = False
@@ -104,13 +198,13 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
104
198
  all_predictions = [1 if is_matching(self.classify, case_query) else 0 for case_query in case_queries
105
199
  if case_query.target is not None]
106
200
  all_pred = sum(all_predictions)
107
- print(f"Accuracy: {all_pred}/{len(targets)}")
201
+ logger.info(f"Accuracy: {all_pred}/{len(targets)}")
108
202
  all_predicted = targets and all_pred == len(targets)
109
203
  num_iter_reached = n_iter and i >= n_iter
110
204
  stop_iterating = all_predicted or num_iter_reached
111
205
  if stop_iterating:
112
206
  break
113
- print(f"Finished training in {i} iterations")
207
+ logger.info(f"Finished training in {i} iterations")
114
208
  if animate_tree:
115
209
  plt.ioff()
116
210
  plt.show()
@@ -142,18 +236,29 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
142
236
  """
143
237
  if case_query is None:
144
238
  raise ValueError("The case query cannot be None.")
239
+
145
240
  self.name = case_query.attribute_name if self.name is None else self.name
241
+ self.case_type = case_query.case_type if self.case_type is None else self.case_type
242
+ self.case_name = case_query.case_name if self.case_name is None else self.case_name
243
+
146
244
  if case_query.target is None:
147
245
  case_query_cp = copy(case_query)
148
- self.classify(case_query_cp.case, modify_case=True)
149
- expert.ask_for_conclusion(case_query_cp)
150
- case_query.target = case_query_cp.target
246
+ conclusions = self.classify(case_query_cp.case, modify_case=True)
247
+ if self.ask_always or conclusions is None or is_iterable(conclusions) and len(conclusions) == 0:
248
+ expert.ask_for_conclusion(case_query_cp)
249
+ case_query.target = case_query_cp.target
151
250
  if case_query.target is None:
152
251
  return self.classify(case_query.case)
153
252
 
154
253
  self.update_start_rule(case_query, expert)
155
254
 
156
- return self._fit_case(case_query, expert=expert, **kwargs)
255
+ fit_case_result = self._fit_case(case_query, expert=expert, **kwargs)
256
+
257
+ if self.save_dir is not None:
258
+ self.save()
259
+ expert.clear_answers()
260
+
261
+ return fit_case_result
157
262
 
158
263
  @abstractmethod
159
264
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
@@ -219,28 +324,54 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
219
324
  pass
220
325
 
221
326
  @abstractmethod
222
- def update_from_python_file(self, package_dir: str):
327
+ def update_from_python(self, model_dir: str):
223
328
  """
224
329
  Update the rules from the generated python file, that might have been modified by the user.
225
330
 
226
- :param package_dir: The directory of the package that contains the generated python file.
331
+ :param model_dir: The directory where the generated python file is located.
227
332
  """
228
333
  pass
229
334
 
335
+ @classmethod
336
+ def get_acronym(cls) -> str:
337
+ """
338
+ :return: The acronym of the classifier.
339
+ """
340
+ if cls.__name__ == "GeneralRDR":
341
+ return "RDR"
342
+ elif cls.__name__ == "MultiClassRDR":
343
+ return "MCRDR"
344
+ else:
345
+ return "SCRDR"
346
+
347
+ def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
348
+ """
349
+ :param package_name: The name of the package that contains the RDR classifier function.
350
+ :return: The module that contains the rdr classifier function.
351
+ """
352
+ # remove from imports if exists first
353
+ name = f"{package_name.strip('./').replace('/', '.')}.{self.generated_python_file_name}"
354
+ try:
355
+ module = importlib.import_module(name)
356
+ del sys.modules[name]
357
+ except ModuleNotFoundError:
358
+ pass
359
+ return importlib.import_module(name).classify
360
+
230
361
 
231
362
  class RDRWithCodeWriter(RippleDownRules, ABC):
232
363
 
233
- def update_from_python_file(self, package_dir: str):
364
+ def update_from_python(self, model_dir: str):
234
365
  """
235
366
  Update the rules from the generated python file, that might have been modified by the user.
236
367
 
237
- :param package_dir: The directory of the package that contains the generated python file.
368
+ :param model_dir: The directory where the generated python file is located.
238
369
  """
239
- rule_ids = [r.uid for r in [self.start_rule] + list(self.start_rule.descendants) if r.conditions is not None]
240
- condition_func_names = [f'conditions_{rid}' for rid in rule_ids]
241
- conclusion_func_names = [f'conclusion_{rid}' for rid in rule_ids]
370
+ rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants) if r.conditions is not None}
371
+ condition_func_names = [f'conditions_{rid}' for rid in rules_dict.keys()]
372
+ conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys() if not isinstance(rules_dict[rid], MultiClassStopRule)]
242
373
  all_func_names = condition_func_names + conclusion_func_names
243
- filepath = f"{package_dir}/{self.generated_python_defs_file_name}.py"
374
+ filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
244
375
  functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
245
376
  # get the scope from the imports in the file
246
377
  scope = extract_imports(filepath)
@@ -248,7 +379,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
248
379
  if rule.conditions is not None:
249
380
  rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
250
381
  rule.conditions.scope = scope
251
- if rule.conclusion is not None:
382
+ if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
252
383
  rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
253
384
  rule.conclusion.scope = scope
254
385
 
@@ -265,17 +396,19 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
265
396
  """
266
397
  pass
267
398
 
268
- def write_to_python_file(self, file_path: str, postfix: str = ""):
399
+ def _write_to_python(self, model_dir: str):
269
400
  """
270
401
  Write the tree of rules as source code to a file.
271
402
 
272
- :param file_path: The path to the file to write the source code to.
273
- :param postfix: The postfix to add to the file name.
403
+ :param model_dir: The path to the directory to write the source code to.
274
404
  """
275
- self.generated_python_file_name = self._default_generated_python_file_name + postfix
405
+ os.makedirs(model_dir, exist_ok=True)
406
+ if not os.path.exists(model_dir + '/__init__.py'):
407
+ with open(model_dir + '/__init__.py', 'w') as f:
408
+ f.write("# This is an empty __init__.py file to make the directory a package.\n")
276
409
  func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
277
- file_name = file_path + f"/{self.generated_python_file_name}.py"
278
- defs_file_name = file_path + f"/{self.generated_python_defs_file_name}.py"
410
+ file_name = model_dir + f"/{self.generated_python_file_name}.py"
411
+ defs_file_name = model_dir + f"/{self.generated_python_defs_file_name}.py"
279
412
  imports, defs_imports = self._get_imports()
280
413
  # clear the files first
281
414
  with open(defs_file_name, "w") as f:
@@ -326,56 +459,19 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
326
459
  imports = "\n".join(imports) + "\n"
327
460
  return imports, defs_imports
328
461
 
329
- def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
330
- """
331
- :param package_name: The name of the package that contains the RDR classifier function.
332
- :return: The module that contains the rdr classifier function.
333
- """
334
- # remove from imports if exists first
335
- name = f"{package_name.strip('./')}.{self.generated_python_file_name}"
336
- try:
337
- module = importlib.import_module(name)
338
- del sys.modules[name]
339
- except ModuleNotFoundError:
340
- pass
341
- return importlib.import_module(name).classify
342
-
343
462
  @property
344
- def _default_generated_python_file_name(self) -> str:
463
+ def _default_generated_python_file_name(self) -> Optional[str]:
345
464
  """
346
465
  :return: The default generated python file name.
347
466
  """
348
- if isinstance(self.start_rule.corner_case, Case):
349
- name = self.start_rule.corner_case._name
350
- else:
351
- name = self.start_rule.corner_case.__class__.__name__
352
- return f"{name.lower()}_{self.attribute_name}_{self.acronym.lower()}"
467
+ if self.start_rule is None or self.start_rule.conclusion is None:
468
+ return None
469
+ return f"{str_to_snake_case(self.case_name)}_{self.attribute_name}_{self.get_acronym().lower()}"
353
470
 
354
471
  @property
355
472
  def generated_python_defs_file_name(self) -> str:
356
473
  return f"{self.generated_python_file_name}_defs"
357
474
 
358
- @property
359
- def acronym(self) -> str:
360
- """
361
- :return: The acronym of the classifier.
362
- """
363
- if self.__class__.__name__ == "GeneralRDR":
364
- return "GRDR"
365
- elif self.__class__.__name__ == "MultiClassRDR":
366
- return "MCRDR"
367
- else:
368
- return "SCRDR"
369
-
370
- @property
371
- def case_type(self) -> Type:
372
- """
373
- :return: The type of the case (input) to the RDR classifier.
374
- """
375
- if isinstance(self.start_rule.corner_case, Case):
376
- return self.start_rule.corner_case._obj_type
377
- else:
378
- return type(self.start_rule.corner_case)
379
475
 
380
476
  @property
381
477
  def conclusion_type(self) -> Tuple[Type]:
@@ -396,7 +492,9 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
396
492
 
397
493
  def _to_json(self) -> Dict[str, Any]:
398
494
  return {"start_rule": self.start_rule.to_json(), "generated_python_file_name": self.generated_python_file_name,
399
- "name": self.name}
495
+ "name": self.name,
496
+ "case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
497
+ "case_name": self.case_name}
400
498
 
401
499
  @classmethod
402
500
  def _from_json(cls, data: Dict[str, Any]) -> Self:
@@ -404,11 +502,15 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
404
502
  Create an instance of the class from a json
405
503
  """
406
504
  start_rule = cls.start_rule_type().from_json(data["start_rule"])
407
- new_rdr = cls(start_rule)
505
+ new_rdr = cls(start_rule=start_rule)
408
506
  if "generated_python_file_name" in data:
409
507
  new_rdr.generated_python_file_name = data["generated_python_file_name"]
410
508
  if "name" in data:
411
509
  new_rdr.name = data["name"]
510
+ if "case_type" in data:
511
+ new_rdr.case_type = get_type_from_string(data["case_type"])
512
+ if "case_name" in data:
513
+ new_rdr.case_name = data["case_name"]
412
514
  return new_rdr
413
515
 
414
516
  @staticmethod
@@ -422,12 +524,12 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
422
524
 
423
525
  class SingleClassRDR(RDRWithCodeWriter):
424
526
 
425
- def __init__(self, start_rule: Optional[SingleClassRule] = None, default_conclusion: Optional[Any] = None):
527
+ def __init__(self, default_conclusion: Optional[Any] = None, **kwargs):
426
528
  """
427
529
  :param start_rule: The starting rule for the classifier.
428
530
  :param default_conclusion: The default conclusion for the classifier if no rules fire.
429
531
  """
430
- super(SingleClassRDR, self).__init__(start_rule)
532
+ super(SingleClassRDR, self).__init__(**kwargs)
431
533
  self.default_conclusion: Optional[Any] = default_conclusion
432
534
 
433
535
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
@@ -479,10 +581,10 @@ class SingleClassRDR(RDRWithCodeWriter):
479
581
  matched_rule = self.start_rule(case) if self.start_rule is not None else None
480
582
  return matched_rule if matched_rule is not None else self.start_rule
481
583
 
482
- def write_to_python_file(self, file_path: str, postfix: str = ""):
483
- super().write_to_python_file(file_path, postfix)
584
+ def _write_to_python(self, model_dir: str):
585
+ super()._write_to_python(model_dir)
484
586
  if self.default_conclusion is not None:
485
- with open(file_path + f"/{self.generated_python_file_name}.py", "a") as f:
587
+ with open(model_dir + f"/{self.generated_python_file_name}.py", "a") as f:
486
588
  f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
487
589
 
488
590
  def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
@@ -758,7 +860,7 @@ class GeneralRDR(RippleDownRules):
758
860
  self.start_rules_dict: Dict[str, Union[SingleClassRDR, MultiClassRDR]] \
759
861
  = category_rdr_map if category_rdr_map else {}
760
862
  super(GeneralRDR, self).__init__(**kwargs)
761
- self.all_figs: List[plt.Figure] = [sr.fig for sr in self.start_rules_dict.values()]
863
+ self.all_figs: List[Figure] = [sr.fig for sr in self.start_rules_dict.values()]
762
864
 
763
865
  def add_rdr(self, rdr: Union[SingleClassRDR, MultiClassRDR], case_query: Optional[CaseQuery] = None):
764
866
  """
@@ -882,7 +984,9 @@ class GeneralRDR(RippleDownRules):
882
984
  def _to_json(self) -> Dict[str, Any]:
883
985
  return {"start_rules": {name: rdr.to_json() for name, rdr in self.start_rules_dict.items()}
884
986
  , "generated_python_file_name": self.generated_python_file_name,
885
- "name": self.name}
987
+ "name": self.name,
988
+ "case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
989
+ "case_name": self.case_name}
886
990
 
887
991
  @classmethod
888
992
  def _from_json(cls, data: Dict[str, Any]) -> GeneralRDR:
@@ -892,35 +996,37 @@ class GeneralRDR(RippleDownRules):
892
996
  start_rules_dict = {}
893
997
  for k, v in data["start_rules"].items():
894
998
  start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
895
- new_rdr = cls(start_rules_dict)
999
+ new_rdr = cls(category_rdr_map=start_rules_dict)
896
1000
  if "generated_python_file_name" in data:
897
1001
  new_rdr.generated_python_file_name = data["generated_python_file_name"]
898
1002
  if "name" in data:
899
1003
  new_rdr.name = data["name"]
1004
+ if "case_type" in data:
1005
+ new_rdr.case_type = get_type_from_string(data["case_type"])
1006
+ if "case_name" in data:
1007
+ new_rdr.case_name = data["case_name"]
900
1008
  return new_rdr
901
1009
 
902
- def update_from_python_file(self, package_dir: str) -> None:
1010
+ def update_from_python(self, model_dir: str) -> None:
903
1011
  """
904
1012
  Update the rules from the generated python file, that might have been modified by the user.
905
1013
 
906
- :param package_dir: The directory of the package that contains the generated python file.
1014
+ :param model_dir: The directory where the model is stored.
907
1015
  """
908
1016
  for rdr in self.start_rules_dict.values():
909
- rdr.update_from_python_file(package_dir)
1017
+ rdr.update_from_python(model_dir)
910
1018
 
911
- def write_to_python_file(self, file_path: str, postfix: str = "") -> None:
1019
+ def _write_to_python(self, model_dir: str) -> None:
912
1020
  """
913
1021
  Write the tree of rules as source code to a file.
914
1022
 
915
- :param file_path: The path to the file to write the source code to.
916
- :param postfix: The postfix to add to the file name.
1023
+ :param model_dir: The directory where the model is stored.
917
1024
  """
918
- self.generated_python_file_name = self._default_generated_python_file_name + postfix
919
1025
  for rdr in self.start_rules_dict.values():
920
- rdr.write_to_python_file(file_path, postfix=f"_of_grdr{postfix}")
1026
+ rdr._write_to_python(model_dir)
921
1027
  func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
922
- with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
923
- f.write(self._get_imports(file_path) + "\n\n")
1028
+ with open(model_dir + f"/{self.generated_python_file_name}.py", "w") as f:
1029
+ f.write(self._get_imports() + "\n\n")
924
1030
  f.write("classifiers_dict = dict()\n")
925
1031
  for rdr_key, rdr in self.start_rules_dict.items():
926
1032
  f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
@@ -930,25 +1036,6 @@ class GeneralRDR(RippleDownRules):
930
1036
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
931
1037
  f.write(f"{' ' * 4}return GeneralRDR._classify(classifiers_dict, case)\n")
932
1038
 
933
- @property
934
- def case_type(self) -> Optional[Type]:
935
- """
936
- :return: The type of the case (input) to the RDR classifier.
937
- """
938
- if self.start_rule is None or self.start_rule.corner_case is None:
939
- return None
940
- if isinstance(self.start_rule.corner_case, Case):
941
- return self.start_rule.corner_case._obj_type
942
- else:
943
- return type(self.start_rule.corner_case)
944
-
945
- def get_rdr_classifier_from_python_file(self, file_path: str) -> Callable[[Any], Any]:
946
- """
947
- :param file_path: The path to the file that contains the RDR classifier function.
948
- :return: The module that contains the rdr classifier function.
949
- """
950
- return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
951
-
952
1039
  @property
953
1040
  def _default_generated_python_file_name(self) -> Optional[str]:
954
1041
  """
@@ -956,21 +1043,16 @@ class GeneralRDR(RippleDownRules):
956
1043
  """
957
1044
  if self.start_rule is None or self.start_rule.conclusion is None:
958
1045
  return None
959
- if isinstance(self.start_rule.corner_case, Case):
960
- name = self.start_rule.corner_case._name
961
- else:
962
- name = self.start_rule.corner_case.__class__.__name__
963
- return f"{name}_rdr".lower()
1046
+ return f"{str_to_snake_case(self.case_name)}_rdr".lower()
964
1047
 
965
1048
  @property
966
1049
  def conclusion_type_hint(self) -> str:
967
1050
  return f"Dict[str, Any]"
968
1051
 
969
- def _get_imports(self, file_path: str) -> str:
1052
+ def _get_imports(self) -> str:
970
1053
  """
971
1054
  Get the imports needed for the generated python file.
972
1055
 
973
- :param file_path: The path to the file that contains the RDR classifier function.
974
1056
  :return: The imports needed for the generated python file.
975
1057
  """
976
1058
  imports = ""