ripple-down-rules 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ripple_down_rules/rdr.py CHANGED
@@ -4,6 +4,8 @@ import copyreg
4
4
  import importlib
5
5
  import os
6
6
 
7
+ from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
8
+
7
9
  from . import logger
8
10
  import sys
9
11
  from abc import ABC, abstractmethod
@@ -28,7 +30,7 @@ from .datastructures.case import Case, CaseAttribute, create_case
28
30
  from .datastructures.dataclasses import CaseQuery
29
31
  from .datastructures.enums import MCRDRMode
30
32
  from .experts import Expert, Human
31
- from .helpers import is_matching
33
+ from .helpers import is_matching, general_rdr_classify
32
34
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
33
35
  try:
34
36
  from .user_interface.gui import RDRCaseViewer
@@ -36,8 +38,8 @@ except ImportError as e:
36
38
  RDRCaseViewer = None
37
39
  from .utils import draw_tree, make_set, copy_case, \
38
40
  SubclassJSONSerializer, make_list, get_type_from_string, \
39
- is_conflicting, update_case, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
40
- is_iterable, str_to_snake_case
41
+ is_conflicting, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
42
+ is_iterable, str_to_snake_case, get_import_path_from_path
41
43
 
42
44
 
43
45
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -76,16 +78,18 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
76
78
  """
77
79
  The name of the model. If None, the model name will be the generated python file name.
78
80
  """
81
+ mutually_exclusive: Optional[bool] = None
82
+ """
83
+ Whether the output of the classification of this rdr allows only one possible conclusion or not.
84
+ """
79
85
 
80
86
  def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None,
81
- save_dir: Optional[str] = None, ask_always: bool = True, model_name: Optional[str] = None):
87
+ save_dir: Optional[str] = None, model_name: Optional[str] = None):
82
88
  """
83
89
  :param start_rule: The starting rule for the classifier.
84
90
  :param viewer: The viewer gui to use for the classifier. If None, no viewer is used.
85
91
  :param save_dir: The directory to save the classifier to.
86
- :param ask_always: Whether to always ask the expert (True) or only ask when classification fails (False).
87
92
  """
88
- self.ask_always: bool = ask_always
89
93
  self.model_name: Optional[str] = model_name
90
94
  self.save_dir = save_dir
91
95
  self.start_rule = start_rule
@@ -110,7 +114,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
110
114
  if not os.path.exists(save_dir + '/__init__.py'):
111
115
  os.makedirs(save_dir, exist_ok=True)
112
116
  with open(save_dir + '/__init__.py', 'w') as f:
113
- f.write("# This is an empty __init__.py file to make the directory a package.\n")
117
+ f.write("from . import *\n")
114
118
  if model_name is not None:
115
119
  self.model_name = model_name
116
120
  elif self.model_name is None:
@@ -134,7 +138,11 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
134
138
  model_dir = os.path.join(load_dir, model_name)
135
139
  json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
136
140
  rdr = cls.from_json_file(json_file)
137
- rdr.update_from_python(model_dir)
141
+ try:
142
+ rdr.update_from_python(model_dir)
143
+ except (FileNotFoundError, ValueError) as e:
144
+ logger.warning(f"Could not load the python file for the model {model_name} from {model_dir}. "
145
+ f"Make sure the file exists and is valid.")
138
146
  rdr.save_dir = load_dir
139
147
  rdr.model_name = model_name
140
148
  return rdr
@@ -213,18 +221,24 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
213
221
  return self.classify(case)
214
222
 
215
223
  @abstractmethod
216
- def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) \
224
+ def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
225
+ case_query: Optional[CaseQuery] = None) \
217
226
  -> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
218
227
  """
219
228
  Classify a case.
220
229
 
221
230
  :param case: The case to classify.
222
231
  :param modify_case: Whether to modify the original case attributes with the conclusion or not.
232
+ :param case_query: The case query containing the case to classify and the target category to compare the case with.
223
233
  :return: The category that the case belongs to.
224
234
  """
225
235
  pass
226
236
 
227
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
237
+ def fit_case(self, case_query: CaseQuery,
238
+ expert: Optional[Expert] = None,
239
+ update_existing_rules: bool = True,
240
+ scenario: Optional[Callable] = None,
241
+ **kwargs) \
228
242
  -> Union[CallableExpression, Dict[str, CallableExpression]]:
229
243
  """
230
244
  Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
@@ -232,6 +246,9 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
232
246
 
233
247
  :param case_query: The query containing the case to classify and the target category to compare the case with.
234
248
  :param expert: The expert to ask for differentiating features as new rule conditions.
249
+ :param update_existing_rules: Whether to update the existing same conclusion type rules that already gave
250
+ some conclusions with the type required by the case query.
251
+ :param scenario: The scenario at which the case was created, this is used to recreate the case if needed.
235
252
  :return: The category that the case belongs to.
236
253
  """
237
254
  if case_query is None:
@@ -240,13 +257,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
240
257
  self.name = case_query.attribute_name if self.name is None else self.name
241
258
  self.case_type = case_query.case_type if self.case_type is None else self.case_type
242
259
  self.case_name = case_query.case_name if self.case_name is None else self.case_name
260
+ case_query.scenario = scenario if case_query.scenario is None else case_query.scenario
243
261
 
244
- expert = expert or Human(answers_save_path=self.save_dir + '/expert_answers' if self.save_dir else None)
245
-
262
+ expert = expert or Human(viewer=self.viewer,
263
+ answers_save_path=self.save_dir + '/expert_answers'
264
+ if self.save_dir else None)
246
265
  if case_query.target is None:
247
266
  case_query_cp = copy(case_query)
248
- conclusions = self.classify(case_query_cp.case, modify_case=True)
249
- if self.ask_always or conclusions is None or is_iterable(conclusions) and len(conclusions) == 0:
267
+ conclusions = self.classify(case_query_cp.case, modify_case=True, case_query=case_query_cp)
268
+ if self.should_i_ask_the_expert_for_a_target(conclusions, case_query_cp, update_existing_rules):
250
269
  expert.ask_for_conclusion(case_query_cp)
251
270
  case_query.target = case_query_cp.target
252
271
  if case_query.target is None:
@@ -262,6 +281,34 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
262
281
 
263
282
  return fit_case_result
264
283
 
284
+ @staticmethod
285
+ def should_i_ask_the_expert_for_a_target(conclusions: Union[Any, Dict[str, Any]],
286
+ case_query: CaseQuery,
287
+ update_existing: bool) -> bool:
288
+ """
289
+ Determine if the rdr should ask the expert for the target of a given case query.
290
+
291
+ :param conclusions: The conclusions of the case.
292
+ :param case_query: The query containing the case to classify.
293
+ :param update_existing: Whether to update rules that gave the required type of conclusions.
294
+ :return: True if the rdr should ask the expert, False otherwise.
295
+ """
296
+ if conclusions is None:
297
+ return True
298
+ elif is_iterable(conclusions) and len(conclusions) == 0:
299
+ return True
300
+ elif isinstance(conclusions, dict):
301
+ if case_query.attribute_name not in conclusions:
302
+ return True
303
+ conclusions = conclusions[case_query.attribute_name]
304
+ conclusion_types = map(type, make_list(conclusions))
305
+ if not any(ct in case_query.core_attribute_type for ct in conclusion_types):
306
+ return True
307
+ elif update_existing:
308
+ return True
309
+ else:
310
+ return False
311
+
265
312
  @abstractmethod
266
313
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
267
314
  -> Union[CallableExpression, Dict[str, CallableExpression]]:
@@ -352,7 +399,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
352
399
  :return: The module that contains the rdr classifier function.
353
400
  """
354
401
  # remove from imports if exists first
355
- name = f"{package_name.strip('./').replace('/', '.')}.{self.generated_python_file_name}"
402
+ package_name = get_import_path_from_path(package_name)
403
+ name = f"{package_name}.{self.generated_python_file_name}" if package_name else self.generated_python_file_name
356
404
  try:
357
405
  module = importlib.import_module(name)
358
406
  del sys.modules[name]
@@ -374,6 +422,10 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
374
422
  conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys() if not isinstance(rules_dict[rid], MultiClassStopRule)]
375
423
  all_func_names = condition_func_names + conclusion_func_names
376
424
  filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
425
+ cases_path = f"{model_dir}/{self.generated_python_cases_file_name}.py"
426
+ cases_import_path = get_import_path_from_path(model_dir)
427
+ cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path\
428
+ else self.generated_python_cases_file_name
377
429
  functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
378
430
  # get the scope from the imports in the file
379
431
  scope = extract_imports(filepath)
@@ -381,13 +433,15 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
381
433
  if rule.conditions is not None:
382
434
  rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
383
435
  rule.conditions.scope = scope
436
+ if os.path.exists(cases_path):
437
+ rule.corner_case_metadata = importlib.import_module(cases_import_path).__dict__.get(f"corner_case_{rule.uid}", None)
384
438
  if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
385
439
  rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
386
440
  rule.conclusion.scope = scope
387
441
 
388
442
  @abstractmethod
389
443
  def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
390
- defs_file: Optional[str] = None):
444
+ defs_file: Optional[str] = None, cases_file: Optional[str] = None):
391
445
  """
392
446
  Write the rules as source code to a file.
393
447
 
@@ -395,6 +449,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
395
449
  :param file: The file to write the source code to.
396
450
  :param parent_indent: The indentation of the parent rule.
397
451
  :param defs_file: The file to write the definitions to.
452
+ :param cases_file: The file to write the cases to.
398
453
  """
399
454
  pass
400
455
 
@@ -407,25 +462,28 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
407
462
  os.makedirs(model_dir, exist_ok=True)
408
463
  if not os.path.exists(model_dir + '/__init__.py'):
409
464
  with open(model_dir + '/__init__.py', 'w') as f:
410
- f.write("# This is an empty __init__.py file to make the directory a package.\n")
411
- func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
465
+ f.write("from . import *\n")
466
+ func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
412
467
  file_name = model_dir + f"/{self.generated_python_file_name}.py"
413
468
  defs_file_name = model_dir + f"/{self.generated_python_defs_file_name}.py"
469
+ cases_file_name = model_dir + f"/{self.generated_python_cases_file_name}.py"
414
470
  imports, defs_imports = self._get_imports()
415
471
  # clear the files first
416
472
  with open(defs_file_name, "w") as f:
417
473
  f.write(defs_imports + "\n\n")
474
+ with open(cases_file_name, "w") as cases_f:
475
+ cases_f.write("# This file contains the corner cases for the rules.\n")
418
476
  with open(file_name, "w") as f:
419
477
  imports += f"from .{self.generated_python_defs_file_name} import *\n"
420
- imports += f"from ripple_down_rules.rdr import {self.__class__.__name__}\n"
421
478
  f.write(imports + "\n\n")
422
479
  f.write(f"attribute_name = '{self.attribute_name}'\n")
423
480
  f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
424
- f.write(f"type_ = {self.__class__.__name__}\n")
481
+ f.write(f"mutually_exclusive = {self.mutually_exclusive}\n")
425
482
  f.write(f"\n\n{func_def}")
426
483
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
427
484
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
428
- self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4, defs_file=defs_file_name)
485
+ self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4, defs_file=defs_file_name,
486
+ cases_file=cases_file_name)
429
487
 
430
488
  @property
431
489
  @abstractmethod
@@ -474,6 +532,10 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
474
532
  def generated_python_defs_file_name(self) -> str:
475
533
  return f"{self.generated_python_file_name}_defs"
476
534
 
535
+ @property
536
+ def generated_python_cases_file_name(self) -> str:
537
+ return f"{self.generated_python_file_name}_cases"
538
+
477
539
 
478
540
  @property
479
541
  def conclusion_type(self) -> Tuple[Type]:
@@ -493,7 +555,8 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
493
555
  return self.start_rule.conclusion_name
494
556
 
495
557
  def _to_json(self) -> Dict[str, Any]:
496
- return {"start_rule": self.start_rule.to_json(), "generated_python_file_name": self.generated_python_file_name,
558
+ return {"start_rule": self.start_rule.to_json(),
559
+ "generated_python_file_name": self.generated_python_file_name,
497
560
  "name": self.name,
498
561
  "case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
499
562
  "case_name": self.case_name}
@@ -526,6 +589,11 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
526
589
 
527
590
  class SingleClassRDR(RDRWithCodeWriter):
528
591
 
592
+ mutually_exclusive: bool = True
593
+ """
594
+ The output of the classification of this rdr negates all other possible outputs, there can only be one true value.
595
+ """
596
+
529
597
  def __init__(self, default_conclusion: Optional[Any] = None, **kwargs):
530
598
  """
531
599
  :param start_rule: The starting rule for the classifier.
@@ -550,7 +618,7 @@ class SingleClassRDR(RDRWithCodeWriter):
550
618
  pred = self.evaluate(case_query.case)
551
619
  if pred.conclusion(case_query.case) != case_query.target_value:
552
620
  expert.ask_for_conditions(case_query, pred)
553
- pred.fit_rule(case_query.case, case_query.target, conditions=case_query.conditions)
621
+ pred.fit_rule(case_query)
554
622
 
555
623
  return self.classify(case_query.case)
556
624
 
@@ -563,18 +631,24 @@ class SingleClassRDR(RDRWithCodeWriter):
563
631
  """
564
632
  if not self.start_rule:
565
633
  expert.ask_for_conditions(case_query)
566
- self.start_rule = SingleClassRule(case_query.conditions, case_query.target, corner_case=case_query.case,
567
- conclusion_name=case_query.attribute_name)
634
+ self.start_rule = SingleClassRule.from_case_query(case_query)
568
635
 
569
- def classify(self, case: Case, modify_case: bool = False) -> Optional[Any]:
636
+ def classify(self, case: Case, modify_case: bool = False,
637
+ case_query: Optional[CaseQuery] = None) -> Optional[Any]:
570
638
  """
571
639
  Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
572
640
 
573
641
  :param case: The case to classify.
574
642
  :param modify_case: Whether to modify the original case attributes with the conclusion or not.
643
+ :param case_query: The case query containing the case and the target category to compare the case with.
575
644
  """
576
645
  pred = self.evaluate(case)
577
- return pred.conclusion(case) if pred is not None and pred.fired else self.default_conclusion
646
+ conclusion = pred.conclusion(case) if pred is not None else None
647
+ if pred is not None and pred.fired and case_query is not None:
648
+ if pred.corner_case_metadata is None and conclusion is not None\
649
+ and type(conclusion) in case_query.core_attribute_type:
650
+ pred.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
651
+ return conclusion if pred is not None and pred.fired else self.default_conclusion
578
652
 
579
653
  def evaluate(self, case: Case) -> SingleClassRule:
580
654
  """
@@ -590,22 +664,24 @@ class SingleClassRDR(RDRWithCodeWriter):
590
664
  f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
591
665
 
592
666
  def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
593
- defs_file: Optional[str] = None):
667
+ defs_file: Optional[str] = None, cases_file: Optional[str] = None):
594
668
  """
595
669
  Write the rules as source code to a file.
596
670
  """
597
671
  if rule.conditions:
672
+ rule.write_corner_case_as_source_code(cases_file)
598
673
  if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
599
674
  file.write(if_clause)
600
675
  if rule.refinement:
601
676
  self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
602
- defs_file=defs_file)
677
+ defs_file=defs_file, cases_file=cases_file)
603
678
 
604
679
  conclusion_call = rule.write_conclusion_as_source_code(parent_indent, defs_file)
605
680
  file.write(conclusion_call)
606
681
 
607
682
  if rule.alternative:
608
- self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
683
+ self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file,
684
+ cases_file=cases_file)
609
685
 
610
686
  @property
611
687
  def conclusion_type_hint(self) -> str:
@@ -643,23 +719,33 @@ class MultiClassRDR(RDRWithCodeWriter):
643
719
  """
644
720
  The conditions of the stopping rule if needed.
645
721
  """
722
+ mutually_exclusive: bool = False
723
+ """
724
+ The output of the classification of this rdr allows for more than one true value as conclusion.
725
+ """
646
726
 
647
727
  def __init__(self, start_rule: Optional[MultiClassTopRule] = None,
648
- mode: MCRDRMode = MCRDRMode.StopOnly):
728
+ mode: MCRDRMode = MCRDRMode.StopOnly, **kwargs):
649
729
  """
650
730
  :param start_rule: The starting rules for the classifier.
651
731
  :param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
652
732
  """
653
- super(MultiClassRDR, self).__init__(start_rule)
733
+ super(MultiClassRDR, self).__init__(start_rule, **kwargs)
654
734
  self.mode: MCRDRMode = mode
655
735
 
656
- def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) -> Set[Any]:
736
+ def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
737
+ case_query: Optional[CaseQuery] = None) -> Set[Any]:
657
738
  evaluated_rule = self.start_rule
658
739
  self.conclusions = []
659
740
  while evaluated_rule:
660
741
  next_rule = evaluated_rule(case)
661
742
  if evaluated_rule.fired:
662
- self.add_conclusion(evaluated_rule, case)
743
+ rule_conclusion = evaluated_rule.conclusion(case)
744
+ if evaluated_rule.corner_case_metadata is None and case_query is not None:
745
+ if rule_conclusion is not None and len(make_list(rule_conclusion)) > 0\
746
+ and any(ct in case_query.core_attribute_type for ct in map(type, make_list(rule_conclusion))):
747
+ evaluated_rule.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
748
+ self.add_conclusion(rule_conclusion)
663
749
  evaluated_rule = next_rule
664
750
  return make_set(self.conclusions)
665
751
 
@@ -687,7 +773,7 @@ class MultiClassRDR(RDRWithCodeWriter):
687
773
  self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule)
688
774
  else:
689
775
  # Rule fired and target is correct or there is no target to compare
690
- self.add_conclusion(evaluated_rule, case_query.case)
776
+ self.add_conclusion(rule_conclusion)
691
777
 
692
778
  if not next_rule:
693
779
  if not make_set(target_value).issubset(make_set(self.conclusions)):
@@ -699,16 +785,18 @@ class MultiClassRDR(RDRWithCodeWriter):
699
785
  return self.conclusions
700
786
 
701
787
  def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
702
- file, parent_indent: str = "", defs_file: Optional[str] = None):
788
+ file, parent_indent: str = "", defs_file: Optional[str] = None,
789
+ cases_file: Optional[str] = None):
703
790
  if rule == self.start_rule:
704
791
  file.write(f"{parent_indent}conclusions = set()\n")
705
792
  if rule.conditions:
793
+ rule.write_corner_case_as_source_code(cases_file)
706
794
  if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
707
795
  file.write(if_clause)
708
796
  conclusion_indent = parent_indent
709
797
  if hasattr(rule, "refinement") and rule.refinement:
710
798
  self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
711
- defs_file=defs_file)
799
+ defs_file=defs_file, cases_file=cases_file)
712
800
  conclusion_indent = parent_indent + " " * 4
713
801
  file.write(f"{conclusion_indent}else:\n")
714
802
 
@@ -716,7 +804,8 @@ class MultiClassRDR(RDRWithCodeWriter):
716
804
  file.write(conclusion_call)
717
805
 
718
806
  if rule.alternative:
719
- self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
807
+ self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file,
808
+ cases_file=cases_file)
720
809
 
721
810
  @property
722
811
  def conclusion_type_hint(self) -> str:
@@ -742,8 +831,7 @@ class MultiClassRDR(RDRWithCodeWriter):
742
831
  """
743
832
  if not self.start_rule:
744
833
  conditions = expert.ask_for_conditions(case_query)
745
- self.start_rule = MultiClassTopRule(conditions, case_query.target, corner_case=case_query.case,
746
- conclusion_name=case_query.attribute_name)
834
+ self.start_rule = MultiClassTopRule.from_case_query(case_query)
747
835
 
748
836
  @property
749
837
  def last_top_rule(self) -> Optional[MultiClassTopRule]:
@@ -764,7 +852,7 @@ class MultiClassRDR(RDRWithCodeWriter):
764
852
  if is_conflicting(rule_conclusion, case_query.target_value):
765
853
  self.stop_conclusion(case_query, expert, evaluated_rule)
766
854
  else:
767
- self.add_conclusion(evaluated_rule, case_query.case)
855
+ self.add_conclusion(rule_conclusion)
768
856
 
769
857
  def stop_conclusion(self, case_query: CaseQuery,
770
858
  expert: Expert, evaluated_rule: MultiClassTopRule):
@@ -776,12 +864,13 @@ class MultiClassRDR(RDRWithCodeWriter):
776
864
  :param evaluated_rule: The evaluated rule to ask the expert about.
777
865
  """
778
866
  conditions = expert.ask_for_conditions(case_query, evaluated_rule)
779
- evaluated_rule.fit_rule(case_query.case, case_query.target, conditions=conditions)
867
+ evaluated_rule.fit_rule(case_query)
780
868
  if self.mode == MCRDRMode.StopPlusRule:
781
869
  self.stop_rule_conditions = conditions
782
870
  if self.mode == MCRDRMode.StopPlusRuleCombined:
783
871
  new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
784
- self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
872
+ case_query.conditions = new_top_rule_conditions
873
+ self.add_top_rule(case_query)
785
874
 
786
875
  def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
787
876
  """
@@ -793,19 +882,19 @@ class MultiClassRDR(RDRWithCodeWriter):
793
882
  if self.stop_rule_conditions and self.mode == MCRDRMode.StopPlusRule:
794
883
  conditions = self.stop_rule_conditions
795
884
  self.stop_rule_conditions = None
885
+ case_query.conditions = conditions
796
886
  else:
797
887
  conditions = expert.ask_for_conditions(case_query)
798
- self.add_top_rule(conditions, case_query.target, case_query.case)
888
+ self.add_top_rule(case_query)
799
889
 
800
- def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
890
+ def add_conclusion(self, rule_conclusion: List[Any]) -> None:
801
891
  """
802
892
  Add the conclusion of the evaluated rule to the list of conclusions.
803
893
 
804
- :param evaluated_rule: The evaluated rule to add the conclusion of.
805
- :param case: The case to add the conclusion for.
894
+ :param rule_conclusion: The conclusion of the evaluated rule, which can be a single conclusion
895
+ or a set of conclusions.
806
896
  """
807
897
  conclusion_types = [type(c) for c in self.conclusions]
808
- rule_conclusion = evaluated_rule.conclusion(case)
809
898
  if type(rule_conclusion) not in conclusion_types:
810
899
  self.conclusions.extend(make_list(rule_conclusion))
811
900
  else:
@@ -818,15 +907,13 @@ class MultiClassRDR(RDRWithCodeWriter):
818
907
  self.conclusions.remove(c)
819
908
  self.conclusions.extend(make_list(combined_conclusion))
820
909
 
821
- def add_top_rule(self, conditions: CallableExpression, conclusion: Any, corner_case: Union[Case, SQLTable]):
910
+ def add_top_rule(self, case_query: CaseQuery):
822
911
  """
823
912
  Add a top rule to the classifier, which is a rule that is always checked and is part of the start_rules list.
824
913
 
825
- :param conditions: The conditions of the rule.
826
- :param conclusion: The conclusion of the rule.
827
- :param corner_case: The corner case of the rule.
914
+ :param case_query: The case query to add the top rule for.
828
915
  """
829
- self.start_rule.alternative = MultiClassTopRule(conditions, conclusion, corner_case=corner_case)
916
+ self.start_rule.alternative = MultiClassTopRule.from_case_query(case_query)
830
917
 
831
918
  @staticmethod
832
919
  def start_rule_type() -> Type[Rule]:
@@ -887,59 +974,19 @@ class GeneralRDR(RippleDownRules):
887
974
  def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
888
975
  return [rdr.start_rule for rdr in self.start_rules_dict.values()]
889
976
 
890
- def classify(self, case: Any, modify_case: bool = False) -> Optional[Dict[str, Any]]:
977
+ def classify(self, case: Any, modify_case: bool = False,
978
+ case_query: Optional[CaseQuery] = None) -> Optional[Dict[str, Any]]:
891
979
  """
892
980
  Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
893
981
  the classification until no more categories can be added.
894
982
 
895
983
  :param case: The case to classify.
896
984
  :param modify_case: Whether to modify the original case or create a copy and modify it.
985
+ :param case_query: The case query containing the case and the target category to compare the case with.
897
986
  :return: The categories that the case belongs to.
898
987
  """
899
- return self._classify(self.start_rules_dict, case, modify_original_case=modify_case)
900
-
901
- @staticmethod
902
- def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
903
- case: Any, modify_original_case: bool = False) -> Dict[str, Any]:
904
- """
905
- Classify a case by going through all classifiers and adding the categories that are classified,
906
- and then restarting the classification until no more categories can be added.
907
-
908
- :param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
909
- :param case: The case to classify.
910
- :param modify_original_case: Whether to modify the original case or create a copy and modify it.
911
- :return: The categories that the case belongs to.
912
- """
913
- conclusions = {}
914
- case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
915
- case_cp = copy_case(case) if not modify_original_case else case
916
- while True:
917
- new_conclusions = {}
918
- for attribute_name, rdr in classifiers_dict.items():
919
- pred_atts = rdr.classify(case_cp)
920
- if pred_atts is None:
921
- continue
922
- if rdr.type_ is SingleClassRDR:
923
- if attribute_name not in conclusions or \
924
- (attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
925
- conclusions[attribute_name] = pred_atts
926
- new_conclusions[attribute_name] = pred_atts
927
- else:
928
- pred_atts = make_set(pred_atts)
929
- if attribute_name in conclusions:
930
- pred_atts = {p for p in pred_atts if p not in conclusions[attribute_name]}
931
- if len(pred_atts) > 0:
932
- new_conclusions[attribute_name] = pred_atts
933
- if attribute_name not in conclusions:
934
- conclusions[attribute_name] = set()
935
- conclusions[attribute_name].update(pred_atts)
936
- if attribute_name in new_conclusions:
937
- mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
938
- case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive)
939
- update_case(case_query, new_conclusions)
940
- if len(new_conclusions) == 0:
941
- break
942
- return conclusions
988
+ return general_rdr_classify(self.start_rules_dict, case, modify_original_case=modify_case,
989
+ case_query=case_query)
943
990
 
944
991
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
945
992
  -> Dict[str, Any]:
@@ -1026,7 +1073,7 @@ class GeneralRDR(RippleDownRules):
1026
1073
  """
1027
1074
  for rdr in self.start_rules_dict.values():
1028
1075
  rdr._write_to_python(model_dir)
1029
- func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
1076
+ func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
1030
1077
  with open(model_dir + f"/{self.generated_python_file_name}.py", "w") as f:
1031
1078
  f.write(self._get_imports() + "\n\n")
1032
1079
  f.write("classifiers_dict = dict()\n")
@@ -1036,7 +1083,7 @@ class GeneralRDR(RippleDownRules):
1036
1083
  f.write(func_def)
1037
1084
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
1038
1085
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
1039
- f.write(f"{' ' * 4}return GeneralRDR._classify(classifiers_dict, case)\n")
1086
+ f.write(f"{' ' * 4}return general_rdr_classify(classifiers_dict, case, **kwargs)\n")
1040
1087
 
1041
1088
  @property
1042
1089
  def _default_generated_python_file_name(self) -> Optional[str]:
@@ -1061,7 +1108,7 @@ class GeneralRDR(RippleDownRules):
1061
1108
  # add type hints
1062
1109
  imports += f"from typing_extensions import Dict, Any\n"
1063
1110
  # import rdr type
1064
- imports += f"from ripple_down_rules.rdr import GeneralRDR\n"
1111
+ imports += f"from ripple_down_rules.helpers import general_rdr_classify\n"
1065
1112
  # add case type
1066
1113
  imports += f"from ripple_down_rules.datastructures.case import Case, create_case\n"
1067
1114
  imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"