ripple-down-rules 0.6.0__py3-none-any.whl → 0.6.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ripple_down_rules/rdr.py CHANGED
@@ -1,17 +1,24 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import ast
3
4
  import importlib
5
+ import json
4
6
  import os
7
+ import sys
5
8
  from abc import ABC, abstractmethod
6
9
  from copy import copy
7
- from types import NoneType
10
+ from dataclasses import is_dataclass
11
+ from io import TextIOWrapper
12
+ from os.path import dirname
13
+ from pathlib import Path
14
+ from types import NoneType, ModuleType
8
15
 
9
16
  from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
10
17
  from . import logger
18
+ from .failures import RDRLoadError
11
19
 
12
20
  try:
13
21
  from matplotlib import pyplot as plt
14
-
15
22
  Figure = plt.Figure
16
23
  except ImportError as e:
17
24
  logger.debug(f"{e}: matplotlib is not installed")
@@ -25,18 +32,22 @@ from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tupl
25
32
  from .datastructures.callable_expression import CallableExpression
26
33
  from .datastructures.case import Case, CaseAttribute, create_case
27
34
  from .datastructures.dataclasses import CaseQuery
28
- from .datastructures.enums import MCRDRMode
35
+ from .datastructures.enums import MCRDRMode, RDREdge
29
36
  from .experts import Expert, Human
30
- from .helpers import is_matching, general_rdr_classify
31
- from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
37
+ from .helpers import is_matching, general_rdr_classify, get_an_updated_case_copy
38
+ from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule, MultiClassRefinementRule, \
39
+ MultiClassFilterRule
32
40
 
33
41
  try:
34
42
  from .user_interface.gui import RDRCaseViewer
35
43
  except ImportError as e:
36
44
  RDRCaseViewer = None
37
45
  from .utils import draw_tree, make_set, SubclassJSONSerializer, make_list, get_type_from_string, \
38
- is_conflicting, extract_function_source, extract_imports, get_full_class_name, \
39
- is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types
46
+ is_value_conflicting, extract_function_source, extract_imports, get_full_class_name, \
47
+ is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types, render_tree, \
48
+ get_types_to_import_from_func_type_hints, get_function_return_type, get_file_that_ends_with, \
49
+ get_and_import_python_module, get_and_import_python_modules_in_a_package, get_type_from_type_hint, \
50
+ are_results_subclass_of_types
40
51
 
41
52
 
42
53
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -80,20 +91,98 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
80
91
  Whether the output of the classification of this rdr allows only one possible conclusion or not.
81
92
  """
82
93
 
83
- def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None,
94
+ def __init__(self, start_rule: Optional[Rule] = None,
84
95
  save_dir: Optional[str] = None, model_name: Optional[str] = None):
85
96
  """
86
97
  :param start_rule: The starting rule for the classifier.
87
- :param viewer: The viewer gui to use for the classifier. If None, no viewer is used.
88
98
  :param save_dir: The directory to save the classifier to.
89
99
  """
90
100
  self.model_name: Optional[str] = model_name
91
- self.save_dir = save_dir
92
- self.start_rule = start_rule
101
+ self.save_dir: Optional[str] = save_dir
102
+ self.start_rule: Optional[Rule] = start_rule
93
103
  self.fig: Optional[Figure] = None
94
- self.viewer: Optional[RDRCaseViewer] = viewer
95
- if self.viewer is not None:
96
- self.viewer.set_save_function(self.save)
104
+ self.viewer: Optional[RDRCaseViewer] = RDRCaseViewer.instances[0]\
105
+ if RDRCaseViewer and any(RDRCaseViewer.instances) else None
106
+ self.input_node: Optional[Rule] = None
107
+ self.update_model()
108
+
109
def update_model(self):
    """
    Refresh the classifier from its saved model directory, if that directory exists.

    Nothing is done unless both ``save_dir`` and ``model_name`` are set and the joined path exists.
    Otherwise the rules are re-read through ``update_from_python(..., update_rule_tree=True)``.
    Loading is best effort: a missing file or module leaves the classifier as it is, and the
    reason is logged at debug level instead of being discarded.
    """
    if self.save_dir is None or self.model_name is None:
        return
    model_path = os.path.join(self.save_dir, self.model_name)
    if not os.path.exists(model_path):
        return
    try:
        self.update_from_python(model_path, update_rule_tree=True)
    except (FileNotFoundError, ModuleNotFoundError) as e:
        # Keep the best-effort contract (no exception escapes), but leave a trace for debugging.
        logger.debug(f"Could not update the model from {model_path}: {e}")
120
+
121
def write_rdr_metadata_to_pyton_file(self, file: TextIOWrapper):
    """
    Append the classifier metadata to a python file as three assignments.

    The written lines assign ``name`` and ``case_name`` as single-quoted strings and ``case_type``
    as the bare class name (or ``None`` when no case type is set).

    :param file: The open, writable file to write the metadata to.
    """
    file.write(f"name = '{self.name}'\n")
    case_type_name = None if self.case_type is None else self.case_type.__name__
    file.write(f"case_type = {case_type_name}\n")
    file.write(f"case_name = '{self.case_name}'\n")
130
+
131
def update_rdr_metadata_from_python(self, module: ModuleType):
    """
    Copy ``name``, ``case_type`` and ``case_name`` from the generated module onto this classifier.

    ``name`` falls back to the start rule's ``conclusion_name`` and ``case_name`` falls back to
    ``"<case type name>.<name>"`` when the module does not define them. ``case_type`` has no
    fallback; if it (or anything needed for a fallback) is missing, a warning is logged and the
    attributes assigned so far are kept.

    :param module: The module that contains the RDR classifier function.
    """
    try:
        if hasattr(module, "name"):
            self.name = module.name
        else:
            self.name = self.start_rule.conclusion_name
        self.case_type = module.case_type
        if hasattr(module, "case_name"):
            self.case_name = module.case_name
        else:
            self.case_name = f"{self.case_type.__name__}.{self.name}"
    except AttributeError as e:
        logger.warning(f"Could not update the RDR metadata from the module {module.__name__}. "
                       f"Make sure the module has the required attributes: {e}")
144
+
145
def render_evaluated_rule_tree(self, filename: str, show_full_tree: bool = False) -> None:
    """
    Render the rule tree to ``filename`` using the dot exporter.

    :param filename: The output file name passed on to ``render_tree``.
    :param show_full_tree: If True render the whole tree, rooted at the input node when there is one and at
     the start rule otherwise. If False render only the evaluated rules; nothing is rendered when there
     are none.
    """
    if not show_full_tree:
        evaluated_rules = self.get_evaluated_rule_tree()
        if evaluated_rules is None or len(evaluated_rules) == 0:
            return
        render_tree(evaluated_rules[0], use_dot_exporter=True, filename=filename,
                    only_nodes=evaluated_rules)
        return
    root = self.input_node if self.input_node is not None else self.start_rule
    render_tree(root, use_dot_exporter=True, filename=filename)
154
+
155
def get_contributing_rules(self) -> Optional[List[Rule]]:
    """
    Collect the fired rules whose ``contributed`` flag is set.

    :return: The contributing rules, or None when the classifier has no start rule.
    """
    if self.start_rule is None:
        return None
    contributing = []
    for rule in self.get_fired_rule_tree():
        if rule.contributed:
            contributing.append(rule)
    return contributing
164
+
165
def get_fired_rule_tree(self) -> Optional[List[Rule]]:
    """
    Collect the evaluated rules whose ``fired`` flag is set.

    :return: The fired rules, or None when the classifier has no start rule.
    """
    if self.start_rule is None:
        return None
    fired = []
    for rule in self.get_evaluated_rule_tree():
        if rule.fired:
            fired.append(rule)
    return fired
174
+
175
def get_evaluated_rule_tree(self) -> Optional[List[Rule]]:
    """
    Collect the rules whose ``evaluated`` flag is set, start rule first and then its descendants.

    :return: The evaluated rules, or None when the classifier has no start rule.
    """
    root = self.start_rule
    if root is None:
        return None
    return [rule for rule in (root, *root.descendants) if rule.evaluated]
97
186
 
98
187
  def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None,
99
188
  package_name: Optional[str] = None) -> str:
@@ -137,18 +226,43 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
137
226
  :param package_name: The name of the package that contains the RDR classifier function, this
138
227
  is required in case of relative imports in the generated python file.
139
228
  """
229
+ rdr: Optional[RippleDownRules] = None
140
230
  model_dir = os.path.join(load_dir, model_name)
141
231
  json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
142
- rdr = cls.from_json_file(json_file)
143
- rdr.save_dir = load_dir
144
- rdr.model_name = model_name
232
+ if os.path.exists(json_file + ".json"):
233
+ rdr = cls.from_json_file(json_file)
145
234
  try:
146
- rdr.update_from_python(model_dir, package_name=package_name)
235
+ if rdr is None:
236
+ rdr = cls.from_python(model_dir, parent_package_name=package_name)
237
+ else:
238
+ rdr.update_from_python(model_dir, parent_package_name=package_name)
147
239
  rdr.to_json_file(json_file)
148
- except (FileNotFoundError, ValueError, SyntaxError) as e:
240
+ except (FileNotFoundError, ValueError, SyntaxError, ModuleNotFoundError) as e:
149
241
  logger.warning(f"Could not load the python file for the model {model_name} from {model_dir}. "
150
242
  f"Make sure the file exists and is valid.")
243
+ if rdr is None:
244
+ raise RDRLoadError(f"Could not load the rdr model {model_name} from {model_dir}, error is {e}")
151
245
  rdr.save(save_dir=load_dir, model_name=model_name, package_name=package_name)
246
+ rdr.save_dir = load_dir
247
+ rdr.model_name = model_name
248
+ return rdr
249
+
250
@classmethod
def from_python(cls, model_path: str,
                python_file_path: Optional[str] = None,
                parent_package_name: Optional[str] = None) -> Self:
    """
    Build a classifier from a generated python file.

    A default-constructed instance is created and then filled in by ``update_from_python`` with
    ``update_rule_tree=True``.

    :param model_path: The directory where the generated python file is located.
    :param python_file_path: The path to the generated python file that contains the RDR classifier function.
    :param parent_package_name: The name of the package that contains the RDR classifier function, this
     is required in case of relative imports in the generated python file.
    :return: An instance of the RDR classifier.
    """
    instance = cls()
    instance.update_from_python(model_path,
                                python_file_path=python_file_path,
                                parent_package_name=parent_package_name,
                                update_rule_tree=True)
    return instance
153
267
 
154
268
  @abstractmethod
@@ -162,16 +276,6 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
162
276
  """
163
277
  pass
164
278
 
165
- def set_viewer(self, viewer: RDRCaseViewer):
166
- """
167
- Set the viewer for the classifier.
168
-
169
- :param viewer: The viewer to set.
170
- """
171
- self.viewer = viewer
172
- if self.viewer is not None:
173
- self.viewer.set_save_function(self.save)
174
-
175
279
  def fit(self, case_queries: List[CaseQuery],
176
280
  expert: Optional[Expert] = None,
177
281
  n_iter: int = None,
@@ -196,7 +300,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
196
300
  num_rules: int = 0
197
301
  while not stop_iterating:
198
302
  for case_query in case_queries:
199
- pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
303
+ pred_cat = self.fit_case(case_query, expert=expert, clear_expert_answers=False,
304
+ **kwargs_for_fit_case)
200
305
  if case_query.target is None:
201
306
  continue
202
307
  target = {case_query.attribute_name: case_query.target(case_query.case)}
@@ -226,9 +331,37 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
226
331
  def __call__(self, case: Union[Case, SQLTable]) -> Union[CallableExpression, Dict[str, CallableExpression]]:
227
332
  return self.classify(case)
228
333
 
334
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False, case_query: Optional[CaseQuery] = None) \
        -> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
    """
    Classify a case using the RDR classifier.

    Before delegating to ``_classify`` this calls ``reset()`` on the start rule and all of its
    descendants, makes sure a parentless start rule hangs under an input node, and labels the input node
    with a textual description of the case.

    :param case: The case to classify.
    :param modify_case: Whether to modify the original case attributes with the conclusion or not.
    :param case_query: The case query containing the case to classify and the target category to compare the case with.
    :return: The category that the case belongs to.
    """
    if self.start_rule is not None:
        for rule in [self.start_rule] + list(self.start_rule.descendants):
            rule.reset()
        if self.start_rule.parent is None:
            if self.input_node is None:
                # The input node is an instance of the start rule's own type with uid '0'.
                self.input_node = type(self.start_rule)(_parent=None, uid='0')
                self.input_node.evaluated = False
                self.input_node.fired = False
            self.start_rule.parent = self.input_node
            self.start_rule.weight = RDREdge.Empty
    if self.input_node is not None:
        data = case.__dict__ if is_dataclass(case) else case
        # Check the extracted data, not the case: a dataclass case has no ``items`` itself,
        # but its ``__dict__`` does, and it should be shown as JSON like any other mapping.
        if hasattr(data, "items"):
            self.input_node.name = json.dumps({k: str(v) for k, v in data.items()}, indent=4)
        else:
            self.input_node.name = str(data)
    return self._classify(case, modify_case=modify_case, case_query=case_query)
361
+
229
362
  @abstractmethod
230
- def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
231
- case_query: Optional[CaseQuery] = None) \
363
+ def _classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
364
+ case_query: Optional[CaseQuery] = None) \
232
365
  -> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
233
366
  """
234
367
  Classify a case.
@@ -244,6 +377,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
244
377
  expert: Optional[Expert] = None,
245
378
  update_existing_rules: bool = True,
246
379
  scenario: Optional[Callable] = None,
380
+ ask_now: Callable = lambda _: True,
381
+ clear_expert_answers: bool = True,
247
382
  **kwargs) \
248
383
  -> Union[CallableExpression, Dict[str, CallableExpression]]:
249
384
  """
@@ -255,6 +390,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
255
390
  :param update_existing_rules: Whether to update the existing same conclusion type rules that already gave
256
391
  some conclusions with the type required by the case query.
257
392
  :param scenario: The scenario at which the case was created, this is used to recreate the case if needed.
393
+ :param ask_now: Whether to ask the expert for refinements or alternatives.
394
+ :param clear_expert_answers: Whether to clear expert answers after saving the new rule.
258
395
  :return: The category that the case belongs to.
259
396
  """
260
397
  if case_query is None:
@@ -264,14 +401,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
264
401
  self.case_type = case_query.case_type if self.case_type is None else self.case_type
265
402
  self.case_name = case_query.case_name if self.case_name is None else self.case_name
266
403
  case_query.scenario = scenario if case_query.scenario is None else case_query.scenario
404
+ case_query.rdr = self
267
405
 
268
- expert = expert or Human(viewer=self.viewer,
269
- answers_save_path=self.save_dir + '/expert_answers'
406
+ expert = expert or Human(answers_save_path=self.save_dir + '/expert_answers'
270
407
  if self.save_dir else None)
271
408
  if case_query.target is None:
272
409
  case_query_cp = copy(case_query)
273
410
  conclusions = self.classify(case_query_cp.case, modify_case=True, case_query=case_query_cp)
274
- if self.should_i_ask_the_expert_for_a_target(conclusions, case_query_cp, update_existing_rules):
411
+ if (self.should_i_ask_the_expert_for_a_target(conclusions, case_query_cp, update_existing_rules)
412
+ and ask_now(case_query_cp.case)):
275
413
  expert.ask_for_conclusion(case_query_cp)
276
414
  case_query.target = case_query_cp.target
277
415
  if case_query.target is None:
@@ -283,7 +421,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
283
421
 
284
422
  if self.save_dir is not None:
285
423
  self.save()
286
- expert.clear_answers()
424
+ if clear_expert_answers:
425
+ expert.clear_answers()
287
426
 
288
427
  return fit_case_result
289
428
 
@@ -356,6 +495,61 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
356
495
  def type_(self):
357
496
  return self.__class__
358
497
 
498
@classmethod
def get_json_file_path(cls, model_path: str) -> str:
    """
    Build the path of the saved json file inside the model's metadata folder.

    :param model_path: The path to the model directory.
    :return: ``<model_path>/<metadata_folder>/<model_name>.json``.
    """
    json_file_name = f"{cls.get_model_name_from_model_path(model_path)}.json"
    return os.path.join(model_path, cls.metadata_folder, json_file_name)
508
+
509
@classmethod
def get_generated_cases_file_path(cls, model_path: str) -> str:
    """
    Get the path to the python file that contains the RDR classifier cases.

    The path is derived from the main generated python file by turning its ``.py`` extension into
    ``_cases.py``. Only the trailing extension is rewritten, so a ``.py`` that appears elsewhere in the
    path (for example in a directory name) is left alone.

    :param model_path: The path to the model directory.
    :return: The path to the generated cases python file.
    """
    python_file_path = cls.get_generated_python_file_path(model_path)
    extension = ".py"
    if python_file_path.endswith(extension):
        return python_file_path[:-len(extension)] + "_cases.py"
    return python_file_path
518
+
519
@classmethod
def get_generated_defs_file_path(cls, model_path: str) -> str:
    """
    Get the path to the python file that contains the RDR classifier function definitions.

    The path is derived from the main generated python file by turning its ``.py`` extension into
    ``_defs.py``. Only the trailing extension is rewritten, so a ``.py`` that appears elsewhere in the
    path (for example in a directory name) is left alone.

    :param model_path: The path to the model directory.
    :return: The path to the generated definitions python file.
    """
    python_file_path = cls.get_generated_python_file_path(model_path)
    extension = ".py"
    if python_file_path.endswith(extension):
        return python_file_path[:-len(extension)] + "_defs.py"
    return python_file_path
528
+
529
@classmethod
def get_generated_python_file_path(cls, model_path: str) -> str:
    """
    Build the path of the python file that contains the RDR classifier function.

    :param model_path: The path to the model directory.
    :return: ``<model_path>/<model_name>.py``.
    """
    python_file_name = "{}.py".format(cls.get_model_name_from_model_path(model_path))
    return os.path.join(model_path, python_file_name)
539
+
540
@classmethod
def get_model_name_from_model_path(cls, model_path: str) -> str:
    """
    Get the model name from the model path.

    The model file is looked up with ``get_file_that_ends_with`` using the suffix
    ``_<lower-cased acronym>.py``; the model name is that result without its trailing ``.py``.

    :param model_path: The path to the model directory.
    :return: The name of the model.
    :raises FileNotFoundError: If no matching python file is found for the model path.
    """
    file_name = get_file_that_ends_with(model_path, f"_{cls.get_acronym().lower()}.py")
    if file_name is None:
        raise FileNotFoundError(f"Could not find the python file for the model in the given path: {model_path}.")
    # Strip only the extension; ``replace`` would also remove a '.py' occurring inside the name.
    return file_name.removesuffix('.py')
552
+
359
553
  @property
360
554
  def generated_python_file_name(self) -> str:
361
555
  if self._generated_python_file_name is None:
@@ -379,13 +573,17 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
379
573
  pass
380
574
 
381
575
  @abstractmethod
382
- def update_from_python(self, model_dir: str, package_name: Optional[str] = None):
576
+ def update_from_python(self, model_dir: str, parent_package_name: Optional[str] = None,
577
+ python_file_path: Optional[str] = None,
578
+ update_rule_tree: bool = False):
383
579
  """
384
580
  Update the rules from the generated python file, that might have been modified by the user.
385
581
 
386
582
  :param model_dir: The directory where the generated python file is located.
387
- :param package_name: The name of the package that contains the RDR classifier function, this
583
+ :param parent_package_name: The name of the package that contains the RDR classifier function, this
388
584
  is required in case of relative imports in the generated python file.
585
+ :param python_file_path: The path to the generated python file that contains the RDR classifier function.
586
+ :param update_rule_tree: Whether to update the rule tree from the python file or not.
389
587
  """
390
588
  pass
391
589
 
@@ -414,41 +612,363 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
414
612
  return module.classify
415
613
 
416
614
 
615
class TreeBuilder(ast.NodeVisitor, ABC):
    """
    Parses an AST of nested if-elif statements and reconstructs the tree.

    Every ``if`` whose test is a direct call to a ``conditions_<uid>`` function becomes a rule node.
    Statements in the body of the ``if`` are visited with the refinement edge active, statements in
    its ``elif``/``else`` part with the alternative edge active. Subclasses decide which rule type
    to create and how a new node is attached to the current parent.
    """

    def __init__(self):
        # First rule node that was created while no parent was active.
        self.root: Optional[Rule] = None
        # Rule node that the next created node or return statement is attached to.
        self.current_parent: Optional[Rule] = None
        # Edge that links the next created node to the current parent.
        self.current_edge: Optional[RDREdge] = None
        # Value of a return statement visited while no parent is active.
        self.default_conclusion: Optional[str] = None

    def visit_FunctionDef(self, node):
        """Finds the main function and starts parsing its body."""
        for stmt in node.body:
            self.visit(stmt)

    def visit_If(self, node):
        """
        Handles if-elif blocks and creates nodes.

        ``if`` statements whose test is not a direct call to a ``conditions_<uid>`` function are not rules
        and are skipped together with their body.
        """
        condition = self.get_condition_name(node.test)
        if condition is None:
            return
        name_parts = condition.split("conditions_")
        if len(name_parts) < 2:
            # A call to some other function (e.g. ``isinstance``): not a rule condition.
            return
        rule_uid = name_parts[1]

        new_rule_type = self.get_new_rule_type(node)
        new_node = new_rule_type(conditions=condition, _parent=self.current_parent, uid=rule_uid)
        if self.current_parent is not None:
            self.update_current_parent(new_node)

        if self.current_parent is None and self.root is None:
            self.root = new_node

        self.current_parent = new_node

        # Parse the body of the if statement
        for stmt in node.body:
            # The edge is computed before the parent is restored, in this order on every iteration,
            # because nested visits change both attributes.
            self.current_edge = self.get_refinement_edge(node)
            self.current_parent = new_node
            self.visit(stmt)

        # Parse elif/else
        for stmt in node.orelse:
            self.current_edge = self.get_alternative_edge(node)
            self.current_parent = new_node
            if isinstance(stmt, ast.If):  # elif case
                self.visit_If(stmt)
            else:  # else case (return)
                self.process_else_statement(stmt)
        # Leave this node as the current parent so that callers can keep working on it.
        self.current_parent = new_node
        self.current_edge = None

    @abstractmethod
    def process_else_statement(self, stmt: ast.AST):
        """
        Process the else statement in the if-elif-else block.

        :param stmt: The else statement to process.
        """
        pass

    @abstractmethod
    def get_refinement_edge(self, node: ast.AST) -> RDREdge:
        """
        :param node: The current AST node to determine the edge type from.
        :return: The refinement edge type.
        """
        pass

    @abstractmethod
    def get_alternative_edge(self, node: ast.AST) -> RDREdge:
        """
        :param node: The current AST node to determine the alternative edge type from.
        :return: The alternative edge type.
        """
        pass

    @abstractmethod
    def get_new_rule_type(self, node: ast.AST) -> Type[Rule]:
        """
        Get the new rule type to create.

        :param node: The current AST node to determine the rule type from.
        :return: The new rule type.
        """
        pass

    @abstractmethod
    def update_current_parent(self, new_node: Rule):
        """
        Update the current parent rule with the new node.

        :param new_node: The new node to attach to the current parent.
        """
        pass

    def visit_Return(self, node):
        """
        Handles return statements as leaf nodes.

        A returned call is recorded by the name of the called function, any other returned expression by
        its literal value, and a bare ``return`` as None. The value becomes the conclusion of the current
        parent, or the default conclusion when there is no current parent.
        """
        if node.value is None:
            # Bare ``return``: ``ast.literal_eval`` cannot handle a missing value.
            return_value = None
        elif isinstance(node.value, ast.Call):
            return_value = node.value.func.id
        else:
            return_value = ast.literal_eval(node.value)
        if self.current_parent is None:
            self.default_conclusion = return_value
        else:
            self.current_parent.conclusion = return_value

    def get_condition_name(self, node):
        """
        Extracts the condition function name from an AST expression.

        :param node: The test expression of an if statement.
        :return: The name of the called function when the expression is a call to a plain name, else None.
        """
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            return node.func.id
        return None
721
+
722
class SingleClassTreeBuilder(TreeBuilder):
    """Rebuilds the rule tree of a generated SingleClassRDR classifier from its AST."""

    def get_new_rule_type(self, node: ast.AST) -> Type[Rule]:
        """Every rule of a single class tree is a SingleClassRule."""
        return SingleClassRule

    def get_refinement_edge(self, node: ast.AST) -> RDREdge:
        """Rules found in the body of an if statement are refinements."""
        return RDREdge.Refinement

    def get_alternative_edge(self, node: ast.AST) -> RDREdge:
        """Rules found in the elif/else part of an if statement are alternatives."""
        return RDREdge.Alternative

    def update_current_parent(self, new_node: Rule):
        """
        Attach the new node to the current parent on the side named by the current edge.

        :param new_node: The rule to attach; it is left unattached for any other edge.
        """
        edge = self.current_edge
        if edge == RDREdge.Alternative:
            self.current_parent.alternative = new_node
        elif edge == RDREdge.Refinement:
            self.current_parent.refinement = new_node

    def process_else_statement(self, stmt: ast.AST):
        """
        Record the return statement of an else block as the default conclusion.

        :param stmt: The statement found in the else block.
        :raises ValueError: If the statement is not a return statement.
        """
        if not isinstance(stmt, ast.Return):
            raise ValueError(f"Unexpected statement in else block: {stmt}")
        # With no current parent, visit_Return stores the value as the default conclusion.
        self.current_parent = None
        self.visit_Return(stmt)
747
+
748
+
749
class MultiClassTreeBuilder(TreeBuilder):
    """Rebuilds the rule tree of a generated MultiClassRDR classifier from its AST."""

    def visit_If(self, stmt: ast.If):
        """
        Create the rule for the if statement, then derive the conclusion name of top and filter rules
        from their conditions name (``conditions_`` replaced by ``conclusion_``).
        """
        super().visit_If(stmt)
        parent = self.current_parent
        if isinstance(parent, (MultiClassTopRule, MultiClassFilterRule)):
            parent.conclusion = parent.conditions.replace("conditions_", "conclusion_")

    def visit_Return(self, node):
        """Return statements are ignored in multi class trees."""
        pass

    def get_new_rule_type(self, node: ast.AST) -> Type[Rule]:
        """
        Choose the rule type for the if statement from the edge that leads to it.

        :param node: The if statement the rule is created for.
        :return: The rule type to create.
        :raises ValueError: If the current edge is not one of the handled edge types.
        """
        edge = self.current_edge
        if edge == RDREdge.Refinement:
            return MultiClassStopRule
        if edge == RDREdge.Filter:
            return MultiClassFilterRule
        if edge in [RDREdge.Next, None]:
            return MultiClassTopRule
        if edge == RDREdge.Alternative:
            return self.get_refinement_rule_type(node)
        raise ValueError(f"Unknown edge type: {self.current_edge}")

    def get_alternative_edge(self, node: ast.AST) -> RDREdge:
        """The elif/else part of a top rule leads to the next top rule, otherwise to an alternative."""
        return RDREdge.Next if isinstance(self.current_parent, MultiClassTopRule) else RDREdge.Alternative

    def get_refinement_edge(self, node: ast.AST) -> RDREdge:
        """The edge into the body of the if statement depends on the rule type found in that body."""
        return self.get_refinement_edge_from_refinement_rule(self.get_refinement_rule_type(node))

    def get_refinement_edge_from_refinement_rule(self, rule_type: Type[Rule]) -> RDREdge:
        """
        :param rule_type: The type of the rule to determine the refinement edge from.
        :return: The refinement edge type based on the rule type.
        """
        if isinstance(self.current_parent, MultiClassRefinementRule):
            return RDREdge.Alternative
        return RDREdge.Refinement if rule_type == MultiClassStopRule else RDREdge.Filter

    def get_refinement_rule_type(self, node: ast.AST) -> Type[Rule]:
        """
        Determine the refinement rule type from the first statement of the node body.

        :param node: The current AST node to determine the rule type from.
        :return: A stop rule for a body that is a single ``pass``, the type found recursively when the first
         statement is an if statement, and a filter rule otherwise.
        :raises ValueError: If the node has an empty body.
        """
        body = node.body
        if not body:
            raise ValueError(f"Could not determine the refinement rule type from the node: {node} as it has an empty body.")
        first_stmt = body[0]
        if len(body) == 1 and isinstance(first_stmt, ast.Pass):
            return MultiClassStopRule
        if isinstance(first_stmt, ast.If):
            return self.get_refinement_rule_type(first_stmt)
        return MultiClassFilterRule

    def update_current_parent(self, new_node: Rule):
        """
        Attach the new node to the current parent and give refinement rules their top rule.

        :param new_node: The rule to attach.
        :raises ValueError: If the new node is a refinement rule and no top rule can be found for it.
        """
        parent = self.current_parent
        if isinstance(new_node, MultiClassRefinementRule):
            if isinstance(parent, MultiClassTopRule):
                new_node.top_rule = parent
            elif hasattr(parent, "top_rule"):
                new_node.top_rule = parent.top_rule
            else:
                raise ValueError(f"Could not set the top rule for the refinement rule: {new_node}")
        edge = self.current_edge
        if edge in [RDREdge.Alternative, RDREdge.Next, None]:
            parent.alternative = new_node
        elif edge in [RDREdge.Refinement, RDREdge.Filter]:
            parent.refinement = new_node

    def process_else_statement(self, stmt: ast.AST):
        """Else blocks carry no rule information in multi class trees and are ignored."""
        pass
824
+
417
825
  class RDRWithCodeWriter(RippleDownRules, ABC):
418
826
 
419
- def update_from_python(self, model_dir: str, package_name: Optional[str] = None):
827
@classmethod
def read_rule_tree_from_python(cls, model_path: str, python_file_path: Optional[str] = None) -> Rule:
    """
    Parse the generated python file and rebuild the rule tree from its ``classify`` function.

    :param model_path: The path to the model directory, used to locate the generated python file when
     ``python_file_path`` is not given.
    :param python_file_path: The path to the generated python file that contains the RDR classifier function.
    :return: The root of the rebuilt tree as reported by the tree builder (None when the builder found no rule).
    """
    source_path = python_file_path
    if source_path is None:
        source_path = cls.get_generated_python_file_path(model_path)
    with open(source_path, "r") as source_file:
        source_code = source_file.read()

    module_ast = ast.parse(source_code)
    builder = cls.get_tree_builder_class()()

    # Only top level functions named "classify" are handed to the builder.
    classify_defs = [node for node in module_ast.body
                     if isinstance(node, ast.FunctionDef) and node.name == "classify"]
    for function_def in classify_defs:
        builder.visit_FunctionDef(function_def)

    return builder.root
847
+
848
@classmethod
@abstractmethod
def get_tree_builder_class(cls) -> Type[TreeBuilder]:
    """
    Name the tree builder that rebuilds the rule tree of this classifier type.

    :return: The class that builds the rule tree from the generated python file.
     This should be either SingleClassTreeBuilder or MultiClassTreeBuilder.
    """
856
+
857
@property
def all_rules(self) -> List[Rule]:
    """
    List the start rule and its descendants, leaving out rules that have no conditions.

    :return: A list of all rules in the classifier, empty when there is no start rule.
    """
    root = self.start_rule
    if root is None:
        return []
    return [rule for rule in (root, *root.descendants) if rule.conditions is not None]
867
+
868
+ def update_from_python(self, model_dir: str, parent_package_name: Optional[str] = None,
869
+ python_file_path: Optional[str] = None, update_rule_tree: bool = False):
420
870
  """
421
871
  Update the rules from the generated python file, that might have been modified by the user.
422
872
 
423
873
  :param model_dir: The directory where the generated python file is located.
424
- :param package_name: The name of the package that contains the RDR classifier function, this
874
+ :param parent_package_name: The name of the package that contains the RDR classifier function, this
425
875
  is required in case of relative imports in the generated python file.
426
- """
427
- rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants)
428
- if r.conditions is not None}
429
- condition_func_names = [f'conditions_{rid}' for rid in rules_dict.keys()]
430
- conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys()
431
- if not isinstance(rules_dict[rid], MultiClassStopRule)]
876
+ :param python_file_path: The path to the generated python file that contains the RDR classifier function.
877
+ :param update_rule_tree: Whether to update the rule tree from the python file or not.
878
+ """
879
+ if update_rule_tree:
880
+ self.start_rule = self.read_rule_tree_from_python(model_dir, python_file_path=python_file_path)
881
+ all_rules = self.all_rules
882
+ condition_func_names = [rule.generated_conditions_function_name for rule in all_rules]
883
+ conclusion_func_names = [rule.generated_conclusion_function_name for rule in all_rules
884
+ if not isinstance(rule, MultiClassStopRule)]
432
885
  all_func_names = condition_func_names + conclusion_func_names
433
- filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
434
- cases_path = f"{model_dir}/{self.generated_python_cases_file_name}.py"
435
- cases_import_path = get_import_path_from_path(model_dir)
436
- cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path \
437
- else self.generated_python_cases_file_name
438
- functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
439
- # get the scope from the imports in the file
440
- scope = extract_imports(filepath, package_name=package_name)
441
- for rule in [self.start_rule] + list(self.start_rule.descendants):
886
+
887
+ main_module, defs_module, cases_module = self.get_and_import_model_python_modules(
888
+ model_dir,
889
+ python_file_path=python_file_path,
890
+ parent_package_name=parent_package_name)
891
+ self.generated_python_file_name = Path(main_module.__file__).name.replace(".py", "")
892
+
893
+ self.update_rdr_metadata_from_python(main_module)
894
+
895
+ functions_source = extract_function_source(defs_module.__file__,
896
+ all_func_names, include_signature=True)
897
+ scope = extract_imports(defs_module.__file__, package_name=parent_package_name)
898
+
899
+ cases_source, cases_scope = None, None
900
+ if cases_module:
901
+ with open(cases_module.__file__, "r") as f:
902
+ cases_source = f.read()
903
+ cases_scope = extract_imports(cases_module.__file__, package_name=parent_package_name)
904
+
905
+ with open(main_module.__file__, "r") as f:
906
+ main_source = f.read()
907
+ main_scope = extract_imports(main_module.__file__, package_name=parent_package_name)
908
+ attribute_name_line = [l for l in main_source.split('\n') if "attribute_name = " in l]
909
+ conclusion_name = None
910
+ if len(attribute_name_line) > 0:
911
+ conclusion_name = eval(attribute_name_line[0].split('=')[-1].strip(), main_scope)
912
+
913
+ for rule in all_rules:
442
914
  if rule.conditions is not None:
443
- rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
444
- rule.conditions.scope = scope
445
- if os.path.exists(cases_path):
446
- module = importlib.import_module(cases_import_path, package=package_name)
447
- importlib.reload(module)
448
- rule.corner_case_metadata = module.__dict__.get(f"corner_case_{rule.uid}", None)
449
- if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
450
- rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
451
- rule.conclusion.scope = scope
915
+ conditions_wrapper_func_name = rule.generated_conditions_function_name
916
+ user_input = functions_source[conditions_wrapper_func_name]
917
+ user_input = '\n'.join(user_input.split("\n")[1:]) # Remove the function signature line
918
+ rule.conditions = CallableExpression(user_input, (bool,), scope=scope)
919
+ if cases_module:
920
+ try:
921
+ rule.corner_case_metadata = cases_module.__dict__[rule.generated_corner_case_object_name]
922
+ except KeyError:
923
+ case_def_lines = [l for l in cases_source.split('\n') if rule.generated_corner_case_object_name in l]
924
+ if len(case_def_lines) > 0:
925
+ case_def_line = case_def_lines[0].split('=')[-1].strip()
926
+ rule.corner_case_metadata = eval(case_def_line, cases_scope)
927
+
928
+ if not isinstance(rule, MultiClassStopRule):
929
+ if conclusion_name:
930
+ rule.conclusion_name = conclusion_name
931
+ user_input = functions_source[rule.generated_conclusion_function_name]
932
+ split_user_input = user_input.split("\n")
933
+ user_input = '\n'.join(split_user_input[1:])
934
+ conclusion_func = defs_module.__dict__.get(rule.generated_conclusion_function_name)
935
+ if conclusion_func is None:
936
+ function_signature = split_user_input[0]
937
+ return_type_hint_str = function_signature.split('->')[-1].strip(' :')
938
+ return_type_hint = eval(return_type_hint_str, scope)
939
+ conclusion_type = get_type_from_type_hint(return_type_hint)
940
+ else:
941
+ conclusion_type = get_function_return_type(conclusion_func)
942
+ rule.conclusion = CallableExpression(user_input, conclusion_type, scope=scope,
943
+ mutually_exclusive=self.mutually_exclusive)
944
+
945
+ @classmethod
946
+ def get_and_import_model_python_modules(cls, model_dir: str,
947
+ python_file_path: Optional[str] = None,
948
+ parent_package_name: Optional[str] = None)\
949
+ -> Tuple[ModuleType, ModuleType, ModuleType]:
950
+ """
951
+ Get and import the python modules that contain the RDR classifier function, definitions, and corner cases.
952
+
953
+ :param model_dir: The path to the directory where the generated python files are located.
954
+ :param python_file_path: The path to the generated python file that contains the RDR classifier function; if None, it is derived from model_dir.
955
+ :param parent_package_name: The name of the package that contains the RDR classifier function, this
956
+ is required in case of relative imports in the generated python file.
957
+ :return: A tuple containing the main module, defs module, and cases module.
958
+ """
959
+ if python_file_path is None:
960
+ main_file_path = cls.get_generated_python_file_path(model_dir)
961
+ else:
962
+ main_file_path = python_file_path
963
+ if not os.path.exists(main_file_path):
964
+ raise ModuleNotFoundError(main_file_path)
965
+
966
+ defs_file_path = main_file_path.replace(".py", "_defs.py")
967
+ cases_path = main_file_path.replace(".py", "_cases.py")
968
+
969
+ main_module, defs_module, cases_module = get_and_import_python_modules_in_a_package(
970
+ [main_file_path, defs_file_path, cases_path], parent_package_name=parent_package_name)
971
+ return main_module, defs_module, cases_module
452
972
 
453
973
  @abstractmethod
454
974
  def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
@@ -492,6 +1012,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
492
1012
  defs_imports = get_imports_from_types(defs_types, defs_file_name, package_name)
493
1013
  corner_cases_imports = get_imports_from_types(corner_cases_types, cases_file_name, package_name)
494
1014
 
1015
+ defs_imports.append(f"from ripple_down_rules import *")
495
1016
  # Add the imports to the defs file
496
1017
  with open(defs_file_name, "w") as f:
497
1018
  f.write('\n'.join(defs_imports) + "\n\n\n")
@@ -511,6 +1032,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
511
1032
  f.write(f"attribute_name = '{self.attribute_name}'\n")
512
1033
  f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
513
1034
  f.write(f"mutually_exclusive = {self.mutually_exclusive}\n")
1035
+ self.write_rdr_metadata_to_pyton_file(f)
514
1036
  f.write(f"\n\n{func_def}")
515
1037
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
516
1038
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
@@ -527,7 +1049,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
527
1049
  """
528
1050
  pass
529
1051
 
530
- def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
1052
+ def _get_types_to_import(self) -> Tuple[Set[Union[Type, Callable]], Set[Type], Set[Type]]:
531
1053
  """
532
1054
  :return: The types of the main, defs, and corner cases files of the RDR classifier that will be imported.
533
1055
  """
@@ -550,7 +1072,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
550
1072
  main_types.update({Union, Optional})
551
1073
  defs_types.add(Union)
552
1074
  main_types.update({Case, create_case})
553
- main_types = main_types.difference(defs_types)
1075
+ # main_types = main_types.difference(defs_types)
554
1076
  return main_types, defs_types, cases_types
555
1077
 
556
1078
  @property
@@ -576,9 +1098,8 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
576
1098
  :return: The type of the conclusion of the RDR classifier.
577
1099
  """
578
1100
  all_types = []
579
- if self.start_rule is not None:
580
- for rule in [self.start_rule] + list(self.start_rule.descendants):
581
- all_types.extend(list(rule.conclusion.conclusion_type))
1101
+ for rule in self.all_rules:
1102
+ all_types.extend(list(rule.conclusion.conclusion_type))
582
1103
  return tuple(set(all_types))
583
1104
 
584
1105
  @property
@@ -635,6 +1156,10 @@ class SingleClassRDR(RDRWithCodeWriter):
635
1156
  super(SingleClassRDR, self).__init__(**kwargs)
636
1157
  self.default_conclusion: Optional[Any] = default_conclusion
637
1158
 
1159
+ @classmethod
1160
+ def get_tree_builder_class(cls) -> Type[TreeBuilder]:
1161
+ return SingleClassTreeBuilder
1162
+
638
1163
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
639
1164
  -> Union[CaseAttribute, CallableExpression, None]:
640
1165
  """
@@ -649,7 +1174,7 @@ class SingleClassRDR(RDRWithCodeWriter):
649
1174
  self.default_conclusion = case_query.default_value
650
1175
 
651
1176
  pred = self.evaluate(case_query.case)
652
- if pred.conclusion(case_query.case) != case_query.target_value:
1177
+ if (not pred.fired and self.default_conclusion is None) or pred.conclusion(case_query.case) != case_query.target_value:
653
1178
  expert.ask_for_conditions(case_query, pred)
654
1179
  pred.fit_rule(case_query)
655
1180
 
@@ -666,8 +1191,8 @@ class SingleClassRDR(RDRWithCodeWriter):
666
1191
  expert.ask_for_conditions(case_query)
667
1192
  self.start_rule = SingleClassRule.from_case_query(case_query)
668
1193
 
669
- def classify(self, case: Case, modify_case: bool = False,
670
- case_query: Optional[CaseQuery] = None) -> Optional[Any]:
1194
+ def _classify(self, case: Case, modify_case: bool = False,
1195
+ case_query: Optional[CaseQuery] = None) -> Optional[Any]:
671
1196
  """
672
1197
  Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
673
1198
 
@@ -677,6 +1202,11 @@ class SingleClassRDR(RDRWithCodeWriter):
677
1202
  """
678
1203
  pred = self.evaluate(case)
679
1204
  conclusion = pred.conclusion(case) if pred is not None and pred.fired else self.default_conclusion
1205
+ if pred is not None and pred.fired:
1206
+ pred.contributed = True
1207
+ pred.last_conclusion = conclusion
1208
+ if case_query is not None:
1209
+ pred.contributed_to_case_query = True
680
1210
  if pred is not None and pred.fired and case_query is not None:
681
1211
  if pred.corner_case_metadata is None and conclusion is not None \
682
1212
  and type(conclusion) in case_query.core_attribute_type:
@@ -781,8 +1311,12 @@ class MultiClassRDR(RDRWithCodeWriter):
781
1311
  super(MultiClassRDR, self).__init__(start_rule, **kwargs)
782
1312
  self.mode: MCRDRMode = mode
783
1313
 
784
- def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
785
- case_query: Optional[CaseQuery] = None) -> Set[Any]:
1314
+ @classmethod
1315
+ def get_tree_builder_class(cls) -> Type[TreeBuilder]:
1316
+ return MultiClassTreeBuilder
1317
+
1318
+ def _classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
1319
+ case_query: Optional[CaseQuery] = None) -> Set[Any]:
786
1320
  evaluated_rule = self.start_rule
787
1321
  self.conclusions = []
788
1322
  while evaluated_rule:
@@ -794,6 +1328,13 @@ class MultiClassRDR(RDRWithCodeWriter):
794
1328
  and any(
795
1329
  ct in case_query.core_attribute_type for ct in map(type, make_list(rule_conclusion))):
796
1330
  evaluated_rule.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
1331
+ if rule_conclusion is not None and any(make_list(rule_conclusion)):
1332
+ evaluated_rule.contributed = True
1333
+ evaluated_rule.last_conclusion = rule_conclusion
1334
+ if case_query is not None:
1335
+ rule_conclusion_types = set(map(type, make_list(rule_conclusion)))
1336
+ if are_results_subclass_of_types(rule_conclusion_types, case_query.core_attribute_type):
1337
+ evaluated_rule.contributed_to_case_query = True
797
1338
  self.add_conclusion(rule_conclusion)
798
1339
  evaluated_rule = next_rule
799
1340
  return make_set(self.conclusions)
@@ -860,6 +1401,9 @@ class MultiClassRDR(RDRWithCodeWriter):
860
1401
  if rule.alternative:
861
1402
  self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
862
1403
  cases_file=cases_file, package_name=package_name)
1404
+ elif isinstance(rule, MultiClassTopRule):
1405
+ with open(filename, "a") as file:
1406
+ file.write(f"{parent_indent}return conclusions\n")
863
1407
 
864
1408
  @property
865
1409
  def conclusion_type_hint(self) -> str:
@@ -869,8 +1413,9 @@ class MultiClassRDR(RDRWithCodeWriter):
869
1413
  else:
870
1414
  return f"Set[Union[{', '.join(conclusion_types)}]]"
871
1415
 
872
- def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
1416
+ def _get_types_to_import(self) -> Tuple[Set[Union[Type, Callable]], Set[Type], Set[Type]]:
873
1417
  main_types, defs_types, cases_types = super()._get_types_to_import()
1418
+ main_types.add(get_an_updated_case_copy)
874
1419
  main_types.update({Set, make_set})
875
1420
  defs_types.update({List, Set})
876
1421
  return main_types, defs_types, cases_types
@@ -884,7 +1429,7 @@ class MultiClassRDR(RDRWithCodeWriter):
884
1429
  """
885
1430
  if not self.start_rule:
886
1431
  conditions = expert.ask_for_conditions(case_query)
887
- self.start_rule = MultiClassTopRule.from_case_query(case_query)
1432
+ self.start_rule: MultiClassTopRule = MultiClassTopRule.from_case_query(case_query)
888
1433
 
889
1434
  @property
890
1435
  def last_top_rule(self) -> Optional[MultiClassTopRule]:
@@ -902,28 +1447,43 @@ class MultiClassRDR(RDRWithCodeWriter):
902
1447
  Stop a wrong conclusion by adding a stopping rule.
903
1448
  """
904
1449
  rule_conclusion = evaluated_rule.conclusion(case_query.case)
905
- if is_conflicting(rule_conclusion, case_query.target_value):
906
- self.stop_conclusion(case_query, expert, evaluated_rule)
907
- else:
1450
+ stop: bool = False
1451
+ add_filter_rule: bool = False
1452
+ if is_value_conflicting(rule_conclusion, case_query.target_value):
1453
+ if make_set(case_query.target_value).issubset(rule_conclusion):
1454
+ add_filter_rule = True
1455
+ else:
1456
+ stop = True
1457
+ elif make_set(case_query.core_attribute_type).issubset(make_set(evaluated_rule.conclusion.conclusion_type)):
1458
+ if make_set(case_query.target_value).issubset(rule_conclusion):
1459
+ add_filter_rule = True
1460
+
1461
+ if not stop:
908
1462
  self.add_conclusion(rule_conclusion)
1463
+ if stop or add_filter_rule:
1464
+ refinement_type = MultiClassStopRule if stop else MultiClassFilterRule
1465
+ self.stop_or_filter_conclusion(case_query, expert, evaluated_rule, refinement_type=refinement_type)
909
1466
 
910
- def stop_conclusion(self, case_query: CaseQuery,
911
- expert: Expert, evaluated_rule: MultiClassTopRule):
1467
+ def stop_or_filter_conclusion(self, case_query: CaseQuery,
1468
+ expert: Expert, evaluated_rule: MultiClassTopRule,
1469
+ refinement_type: Type[MultiClassRefinementRule] = MultiClassStopRule):
912
1470
  """
913
1471
  Stop or filter a conclusion by adding a refinement rule (a stopping rule or a filter rule).
914
1472
 
915
1473
  :param case_query: The case query to stop the conclusion for.
916
1474
  :param expert: The expert to ask for differentiating features as new rule conditions.
917
1475
  :param evaluated_rule: The evaluated rule to ask the expert about.
1476
+ :param refinement_type: The refinement rule class to add, e.g. MultiClassStopRule to stop the conclusion or MultiClassFilterRule to filter it.
918
1477
  """
919
1478
  conditions = expert.ask_for_conditions(case_query, evaluated_rule)
920
- evaluated_rule.fit_rule(case_query)
921
- if self.mode == MCRDRMode.StopPlusRule:
922
- self.stop_rule_conditions = conditions
923
- if self.mode == MCRDRMode.StopPlusRuleCombined:
924
- new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
925
- case_query.conditions = new_top_rule_conditions
926
- self.add_top_rule(case_query)
1479
+ evaluated_rule.fit_rule(case_query, refinement_type=refinement_type)
1480
+ if refinement_type is MultiClassStopRule:
1481
+ if self.mode == MCRDRMode.StopPlusRule:
1482
+ self.stop_rule_conditions = conditions
1483
+ if self.mode == MCRDRMode.StopPlusRuleCombined:
1484
+ new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
1485
+ case_query.conditions = new_top_rule_conditions
1486
+ self.add_top_rule(case_query)
927
1487
 
928
1488
  def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
929
1489
  """
@@ -1004,6 +1564,39 @@ class GeneralRDR(RippleDownRules):
1004
1564
  super(GeneralRDR, self).__init__(**kwargs)
1005
1565
  self.all_figs: List[Figure] = [sr.fig for sr in self.start_rules_dict.values()]
1006
1566
 
1567
+ @classmethod
1568
+ def from_python(cls, model_dir: str, python_file_path: Optional[str] = None,
1569
+ parent_package_name: Optional[str] = None) -> Self:
1570
+ """
1571
+ Create an instance of the class from a python file.
1572
+
1573
+ :param model_dir: The path to the directory containing the python file.
1574
+ :param python_file_path: The path to the python file; if not provided, it is derived from model_dir.
1575
+ :param parent_package_name: The name of the package that contains the RDR classifier function, this
1576
+ is required in case of relative imports in the generated python file.
1577
+ :return: An instance of the class.
1578
+ """
1579
+ grdr = cls()
1580
+ grdr.update_from_python(model_dir, parent_package_name=parent_package_name, python_file_path=python_file_path,
1581
+ update_rule_tree=True)
1582
+ return grdr
1583
+
1584
+ @classmethod
1585
+ def get_rdr_type_from_acronym(cls, acronym: str) -> Type[Union[SingleClassRDR, MultiClassRDR]]:
1586
+ """
1587
+ Get the type of the ripple down rules classifier from the acronym.
1588
+
1589
+ :param acronym: The acronym of the ripple down rules classifier.
1590
+ :return: The type of the ripple down rules classifier.
1591
+ """
1592
+ acronym = acronym.lower()
1593
+ if acronym == "scrdr":
1594
+ return SingleClassRDR
1595
+ elif acronym == "mcrdr":
1596
+ return MultiClassRDR
1597
+ else:
1598
+ raise ValueError(f"Unknown RDR type acronym: {acronym}")
1599
+
1007
1600
  def add_rdr(self, rdr: Union[SingleClassRDR, MultiClassRDR], case_query: Optional[CaseQuery] = None):
1008
1601
  """
1009
1602
  Add a ripple down rules classifier to the map of classifiers.
@@ -1027,8 +1620,8 @@ class GeneralRDR(RippleDownRules):
1027
1620
  def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
1028
1621
  return [rdr.start_rule for rdr in self.start_rules_dict.values()]
1029
1622
 
1030
- def classify(self, case: Any, modify_case: bool = False,
1031
- case_query: Optional[CaseQuery] = None) -> Optional[Dict[str, Any]]:
1623
+ def _classify(self, case: Any, modify_case: bool = False,
1624
+ case_query: Optional[CaseQuery] = None) -> Optional[Dict[str, Any]]:
1032
1625
  """
1033
1626
  Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
1034
1627
  the classification until no more categories can be added.
@@ -1109,23 +1702,45 @@ class GeneralRDR(RippleDownRules):
1109
1702
  new_rdr.case_name = data["case_name"]
1110
1703
  return new_rdr
1111
1704
 
1112
- def update_from_python(self, model_dir: str, package_name: Optional[str] = None) -> None:
1705
+ def update_from_python(self, model_dir: str, parent_package_name: Optional[str] = None,
1706
+ python_file_path: Optional[str] = None,
1707
+ update_rule_tree: bool = False) -> None:
1113
1708
  """
1114
1709
  Update the rules from the generated python file, which might have been modified by the user.
1115
1710
 
1116
1711
  :param model_dir: The directory where the model is stored.
1117
- :param package_name: The name of the package that contains the RDR classifier function, this
1712
+ :param parent_package_name: The name of the package that contains the RDR classifier function, this
1118
1713
  is required in case of relative imports in the generated python file.
1119
- """
1120
- for rdr in self.start_rules_dict.values():
1121
- rdr.update_from_python(model_dir, package_name=package_name)
1714
+ :param python_file_path: The path to the python file; if not provided, it is derived from model_dir.
1715
+ :param update_rule_tree: Whether to update the rule tree from the python file or not.
1716
+ """
1717
+ if update_rule_tree:
1718
+ if python_file_path is None:
1719
+ main_python_file_path = self.get_generated_python_file_path(model_dir)
1720
+ else:
1721
+ main_python_file_path = python_file_path
1722
+ main_module = get_and_import_python_module(main_python_file_path, parent_package_name=parent_package_name)
1723
+ classifiers_dict = main_module.classifiers_dict
1724
+ self.start_rules_dict = {}
1725
+ for rdr_name, rdr_module in classifiers_dict.items():
1726
+ rdr_acronym = rdr_module.__name__.split('_')[-1]
1727
+ rdr_type = self.get_rdr_type_from_acronym(rdr_acronym)
1728
+ rdr_model_path = main_python_file_path.replace('_rdr.py', f'_{rdr_name}_{rdr_acronym}.py')
1729
+ rdr = rdr_type.from_python(model_dir, python_file_path=rdr_model_path,
1730
+ parent_package_name=parent_package_name)
1731
+ self.start_rules_dict[rdr_name] = rdr
1732
+
1733
+ self.update_rdr_metadata_from_python(main_module)
1734
+ else:
1735
+ for rdr in self.start_rules_dict.values():
1736
+ rdr.update_from_python(model_dir, parent_package_name=parent_package_name)
1122
1737
 
1123
1738
  def _write_to_python(self, model_dir: str, package_name: Optional[str] = None) -> None:
1124
1739
  """
1125
1740
  Write the tree of rules as source code to a file.
1126
1741
 
1127
1742
  :param model_dir: The directory where the model is stored.
1128
- :param relative_imports: Whether to use relative imports in the generated python file.
1743
+ :param package_name: The name of the package that contains the RDR classifier function.
1129
1744
  """
1130
1745
  for rdr in self.start_rules_dict.values():
1131
1746
  rdr._write_to_python(model_dir, package_name=package_name)
@@ -1133,6 +1748,7 @@ class GeneralRDR(RippleDownRules):
1133
1748
  file_path = model_dir + f"/{self.generated_python_file_name}.py"
1134
1749
  with open(file_path, "w") as f:
1135
1750
  f.write(self._get_imports(file_path=file_path, package_name=package_name) + "\n\n")
1751
+ self.write_rdr_metadata_to_pyton_file(f)
1136
1752
  f.write("classifiers_dict = dict()\n")
1137
1753
  for rdr_key, rdr in self.start_rules_dict.items():
1138
1754
  f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")