ripple-down-rules 0.5.63__py3-none-any.whl → 0.5.71__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ripple_down_rules/rdr.py CHANGED
@@ -4,6 +4,8 @@ import copyreg
4
4
  import importlib
5
5
  import os
6
6
 
7
+ from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
8
+
7
9
  from . import logger
8
10
  import sys
9
11
  from abc import ABC, abstractmethod
@@ -28,7 +30,7 @@ from .datastructures.case import Case, CaseAttribute, create_case
28
30
  from .datastructures.dataclasses import CaseQuery
29
31
  from .datastructures.enums import MCRDRMode
30
32
  from .experts import Expert, Human
31
- from .helpers import is_matching
33
+ from .helpers import is_matching, general_rdr_classify
32
34
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
33
35
  try:
34
36
  from .user_interface.gui import RDRCaseViewer
@@ -36,8 +38,8 @@ except ImportError as e:
36
38
  RDRCaseViewer = None
37
39
  from .utils import draw_tree, make_set, copy_case, \
38
40
  SubclassJSONSerializer, make_list, get_type_from_string, \
39
- is_conflicting, update_case, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
40
- is_iterable, str_to_snake_case
41
+ is_conflicting, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
42
+ is_iterable, str_to_snake_case, get_import_path_from_path
41
43
 
42
44
 
43
45
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -76,16 +78,18 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
76
78
  """
77
79
  The name of the model. If None, the model name will be the generated python file name.
78
80
  """
81
+ mutually_exclusive: Optional[bool] = None
82
+ """
83
+ Whether the output of the classification of this rdr allows only one possible conclusion or not.
84
+ """
79
85
 
80
86
  def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None,
81
- save_dir: Optional[str] = None, ask_always: bool = False, model_name: Optional[str] = None):
87
+ save_dir: Optional[str] = None, model_name: Optional[str] = None):
82
88
  """
83
89
  :param start_rule: The starting rule for the classifier.
84
90
  :param viewer: The viewer gui to use for the classifier. If None, no viewer is used.
85
91
  :param save_dir: The directory to save the classifier to.
86
- :param ask_always: Whether to always ask the expert (True) or only ask when classification fails (False).
87
92
  """
88
- self.ask_always: bool = ask_always
89
93
  self.model_name: Optional[str] = model_name
90
94
  self.save_dir = save_dir
91
95
  self.start_rule = start_rule
@@ -110,7 +114,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
110
114
  if not os.path.exists(save_dir + '/__init__.py'):
111
115
  os.makedirs(save_dir, exist_ok=True)
112
116
  with open(save_dir + '/__init__.py', 'w') as f:
113
- f.write("# This is an empty __init__.py file to make the directory a package.\n")
117
+ f.write("from . import *\n")
114
118
  if model_name is not None:
115
119
  self.model_name = model_name
116
120
  elif self.model_name is None:
@@ -134,7 +138,11 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
134
138
  model_dir = os.path.join(load_dir, model_name)
135
139
  json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
136
140
  rdr = cls.from_json_file(json_file)
137
- rdr.update_from_python(model_dir)
141
+ try:
142
+ rdr.update_from_python(model_dir)
143
+ except (FileNotFoundError, ValueError) as e:
144
+ logger.warning(f"Could not load the python file for the model {model_name} from {model_dir}. "
145
+ f"Make sure the file exists and is valid.")
138
146
  rdr.save_dir = load_dir
139
147
  rdr.model_name = model_name
140
148
  return rdr
@@ -213,18 +221,24 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
213
221
  return self.classify(case)
214
222
 
215
223
  @abstractmethod
216
- def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) \
224
+ def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
225
+ case_query: Optional[CaseQuery] = None) \
217
226
  -> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
218
227
  """
219
228
  Classify a case.
220
229
 
221
230
  :param case: The case to classify.
222
231
  :param modify_case: Whether to modify the original case attributes with the conclusion or not.
232
+ :param case_query: The case query containing the case to classify and the target category to compare the case with.
223
233
  :return: The category that the case belongs to.
224
234
  """
225
235
  pass
226
236
 
227
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
237
+ def fit_case(self, case_query: CaseQuery,
238
+ expert: Optional[Expert] = None,
239
+ update_existing_rules: bool = True,
240
+ scenario: Optional[Callable] = None,
241
+ **kwargs) \
228
242
  -> Union[CallableExpression, Dict[str, CallableExpression]]:
229
243
  """
230
244
  Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
@@ -232,6 +246,9 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
232
246
 
233
247
  :param case_query: The query containing the case to classify and the target category to compare the case with.
234
248
  :param expert: The expert to ask for differentiating features as new rule conditions.
249
+ :param update_existing_rules: Whether to update the existing same conclusion type rules that already gave
250
+ some conclusions with the type required by the case query.
251
+ :param scenario: The scenario at which the case was created, this is used to recreate the case if needed.
235
252
  :return: The category that the case belongs to.
236
253
  """
237
254
  if case_query is None:
@@ -240,19 +257,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
240
257
  self.name = case_query.attribute_name if self.name is None else self.name
241
258
  self.case_type = case_query.case_type if self.case_type is None else self.case_type
242
259
  self.case_name = case_query.case_name if self.case_name is None else self.case_name
260
+ case_query.scenario = scenario if case_query.scenario is None else case_query.scenario
243
261
 
244
262
  expert = expert or Human(viewer=self.viewer,
245
263
  answers_save_path=self.save_dir + '/expert_answers'
246
264
  if self.save_dir else None)
247
-
248
265
  if case_query.target is None:
249
266
  case_query_cp = copy(case_query)
250
- conclusions = self.classify(case_query_cp.case, modify_case=True)
251
- if (self.ask_always or conclusions is None
252
- or is_iterable(conclusions) and len(conclusions) == 0
253
- or (isinstance(conclusions, dict) and (case_query_cp.attribute_name not in conclusions
254
- or not any(type(c) in case_query_cp.core_attribute_type
255
- for c in make_list(conclusions[case_query_cp.attribute_name]))))):
267
+ conclusions = self.classify(case_query_cp.case, modify_case=True, case_query=case_query_cp)
268
+ if self.should_i_ask_the_expert_for_a_target(conclusions, case_query_cp, update_existing_rules):
256
269
  expert.ask_for_conclusion(case_query_cp)
257
270
  case_query.target = case_query_cp.target
258
271
  if case_query.target is None:
@@ -268,6 +281,34 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
268
281
 
269
282
  return fit_case_result
270
283
 
284
+ @staticmethod
285
+ def should_i_ask_the_expert_for_a_target(conclusions: Union[Any, Dict[str, Any]],
286
+ case_query: CaseQuery,
287
+ update_existing: bool) -> bool:
288
+ """
289
+ Determine if the rdr should ask the expert for the target of a given case query.
290
+
291
+ :param conclusions: The conclusions of the case.
292
+ :param case_query: The query containing the case to classify.
293
+ :param update_existing: Whether to update rules that gave the required type of conclusions.
294
+ :return: True if the rdr should ask the expert, False otherwise.
295
+ """
296
+ if conclusions is None:
297
+ return True
298
+ elif is_iterable(conclusions) and len(conclusions) == 0:
299
+ return True
300
+ elif isinstance(conclusions, dict):
301
+ if case_query.attribute_name not in conclusions:
302
+ return True
303
+ conclusions = conclusions[case_query.attribute_name]
304
+ conclusion_types = map(type, make_list(conclusions))
305
+ if not any(ct in case_query.core_attribute_type for ct in conclusion_types):
306
+ return True
307
+ elif update_existing:
308
+ return True
309
+ else:
310
+ return False
311
+
271
312
  @abstractmethod
272
313
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
273
314
  -> Union[CallableExpression, Dict[str, CallableExpression]]:
@@ -358,7 +399,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
358
399
  :return: The module that contains the rdr classifier function.
359
400
  """
360
401
  # remove from imports if exists first
361
- name = f"{package_name.strip('./').replace('/', '.')}.{self.generated_python_file_name}"
402
+ package_name = get_import_path_from_path(package_name)
403
+ name = f"{package_name}.{self.generated_python_file_name}" if package_name else self.generated_python_file_name
362
404
  try:
363
405
  module = importlib.import_module(name)
364
406
  del sys.modules[name]
@@ -380,6 +422,10 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
380
422
  conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys() if not isinstance(rules_dict[rid], MultiClassStopRule)]
381
423
  all_func_names = condition_func_names + conclusion_func_names
382
424
  filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
425
+ cases_path = f"{model_dir}/{self.generated_python_cases_file_name}.py"
426
+ cases_import_path = get_import_path_from_path(model_dir)
427
+ cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path\
428
+ else self.generated_python_cases_file_name
383
429
  functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
384
430
  # get the scope from the imports in the file
385
431
  scope = extract_imports(filepath)
@@ -387,13 +433,15 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
387
433
  if rule.conditions is not None:
388
434
  rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
389
435
  rule.conditions.scope = scope
436
+ if os.path.exists(cases_path):
437
+ rule.corner_case_metadata = importlib.import_module(cases_import_path).__dict__.get(f"corner_case_{rule.uid}", None)
390
438
  if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
391
439
  rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
392
440
  rule.conclusion.scope = scope
393
441
 
394
442
  @abstractmethod
395
443
  def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
396
- defs_file: Optional[str] = None):
444
+ defs_file: Optional[str] = None, cases_file: Optional[str] = None):
397
445
  """
398
446
  Write the rules as source code to a file.
399
447
 
@@ -401,6 +449,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
401
449
  :param file: The file to write the source code to.
402
450
  :param parent_indent: The indentation of the parent rule.
403
451
  :param defs_file: The file to write the definitions to.
452
+ :param cases_file: The file to write the cases to.
404
453
  """
405
454
  pass
406
455
 
@@ -413,25 +462,28 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
413
462
  os.makedirs(model_dir, exist_ok=True)
414
463
  if not os.path.exists(model_dir + '/__init__.py'):
415
464
  with open(model_dir + '/__init__.py', 'w') as f:
416
- f.write("# This is an empty __init__.py file to make the directory a package.\n")
417
- func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
465
+ f.write("from . import *\n")
466
+ func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
418
467
  file_name = model_dir + f"/{self.generated_python_file_name}.py"
419
468
  defs_file_name = model_dir + f"/{self.generated_python_defs_file_name}.py"
469
+ cases_file_name = model_dir + f"/{self.generated_python_cases_file_name}.py"
420
470
  imports, defs_imports = self._get_imports()
421
471
  # clear the files first
422
472
  with open(defs_file_name, "w") as f:
423
473
  f.write(defs_imports + "\n\n")
474
+ with open(cases_file_name, "w") as cases_f:
475
+ cases_f.write("# This file contains the corner cases for the rules.\n")
424
476
  with open(file_name, "w") as f:
425
477
  imports += f"from .{self.generated_python_defs_file_name} import *\n"
426
- imports += f"from ripple_down_rules.rdr import {self.__class__.__name__}\n"
427
478
  f.write(imports + "\n\n")
428
479
  f.write(f"attribute_name = '{self.attribute_name}'\n")
429
480
  f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
430
- f.write(f"type_ = {self.__class__.__name__}\n")
481
+ f.write(f"mutually_exclusive = {self.mutually_exclusive}\n")
431
482
  f.write(f"\n\n{func_def}")
432
483
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
433
484
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
434
- self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4, defs_file=defs_file_name)
485
+ self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4, defs_file=defs_file_name,
486
+ cases_file=cases_file_name)
435
487
 
436
488
  @property
437
489
  @abstractmethod
@@ -480,6 +532,10 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
480
532
  def generated_python_defs_file_name(self) -> str:
481
533
  return f"{self.generated_python_file_name}_defs"
482
534
 
535
+ @property
536
+ def generated_python_cases_file_name(self) -> str:
537
+ return f"{self.generated_python_file_name}_cases"
538
+
483
539
 
484
540
  @property
485
541
  def conclusion_type(self) -> Tuple[Type]:
@@ -533,6 +589,11 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
533
589
 
534
590
  class SingleClassRDR(RDRWithCodeWriter):
535
591
 
592
+ mutually_exclusive: bool = True
593
+ """
594
+ The output of the classification of this rdr negates all other possible outputs, there can only be one true value.
595
+ """
596
+
536
597
  def __init__(self, default_conclusion: Optional[Any] = None, **kwargs):
537
598
  """
538
599
  :param start_rule: The starting rule for the classifier.
@@ -557,7 +618,7 @@ class SingleClassRDR(RDRWithCodeWriter):
557
618
  pred = self.evaluate(case_query.case)
558
619
  if pred.conclusion(case_query.case) != case_query.target_value:
559
620
  expert.ask_for_conditions(case_query, pred)
560
- pred.fit_rule(case_query.case, case_query.target, conditions=case_query.conditions)
621
+ pred.fit_rule(case_query)
561
622
 
562
623
  return self.classify(case_query.case)
563
624
 
@@ -570,18 +631,24 @@ class SingleClassRDR(RDRWithCodeWriter):
570
631
  """
571
632
  if not self.start_rule:
572
633
  expert.ask_for_conditions(case_query)
573
- self.start_rule = SingleClassRule(case_query.conditions, case_query.target, corner_case=case_query.case,
574
- conclusion_name=case_query.attribute_name)
634
+ self.start_rule = SingleClassRule.from_case_query(case_query)
575
635
 
576
- def classify(self, case: Case, modify_case: bool = False) -> Optional[Any]:
636
+ def classify(self, case: Case, modify_case: bool = False,
637
+ case_query: Optional[CaseQuery] = None) -> Optional[Any]:
577
638
  """
578
639
  Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
579
640
 
580
641
  :param case: The case to classify.
581
642
  :param modify_case: Whether to modify the original case attributes with the conclusion or not.
643
+ :param case_query: The case query containing the case and the target category to compare the case with.
582
644
  """
583
645
  pred = self.evaluate(case)
584
- return pred.conclusion(case) if pred is not None and pred.fired else self.default_conclusion
646
+ conclusion = pred.conclusion(case) if pred is not None else None
647
+ if pred is not None and pred.fired and case_query is not None:
648
+ if pred.corner_case_metadata is None and conclusion is not None\
649
+ and type(conclusion) in case_query.core_attribute_type:
650
+ pred.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
651
+ return conclusion if pred is not None and pred.fired else self.default_conclusion
585
652
 
586
653
  def evaluate(self, case: Case) -> SingleClassRule:
587
654
  """
@@ -597,22 +664,24 @@ class SingleClassRDR(RDRWithCodeWriter):
597
664
  f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
598
665
 
599
666
  def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
600
- defs_file: Optional[str] = None):
667
+ defs_file: Optional[str] = None, cases_file: Optional[str] = None):
601
668
  """
602
669
  Write the rules as source code to a file.
603
670
  """
604
671
  if rule.conditions:
672
+ rule.write_corner_case_as_source_code(cases_file)
605
673
  if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
606
674
  file.write(if_clause)
607
675
  if rule.refinement:
608
676
  self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
609
- defs_file=defs_file)
677
+ defs_file=defs_file, cases_file=cases_file)
610
678
 
611
679
  conclusion_call = rule.write_conclusion_as_source_code(parent_indent, defs_file)
612
680
  file.write(conclusion_call)
613
681
 
614
682
  if rule.alternative:
615
- self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
683
+ self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file,
684
+ cases_file=cases_file)
616
685
 
617
686
  @property
618
687
  def conclusion_type_hint(self) -> str:
@@ -650,6 +719,10 @@ class MultiClassRDR(RDRWithCodeWriter):
650
719
  """
651
720
  The conditions of the stopping rule if needed.
652
721
  """
722
+ mutually_exclusive: bool = False
723
+ """
724
+ The output of the classification of this rdr allows for more than one true value as conclusion.
725
+ """
653
726
 
654
727
  def __init__(self, start_rule: Optional[MultiClassTopRule] = None,
655
728
  mode: MCRDRMode = MCRDRMode.StopOnly, **kwargs):
@@ -660,13 +733,19 @@ class MultiClassRDR(RDRWithCodeWriter):
660
733
  super(MultiClassRDR, self).__init__(start_rule, **kwargs)
661
734
  self.mode: MCRDRMode = mode
662
735
 
663
- def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) -> Set[Any]:
736
+ def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
737
+ case_query: Optional[CaseQuery] = None) -> Set[Any]:
664
738
  evaluated_rule = self.start_rule
665
739
  self.conclusions = []
666
740
  while evaluated_rule:
667
741
  next_rule = evaluated_rule(case)
668
742
  if evaluated_rule.fired:
669
- self.add_conclusion(evaluated_rule, case)
743
+ rule_conclusion = evaluated_rule.conclusion(case)
744
+ if evaluated_rule.corner_case_metadata is None and case_query is not None:
745
+ if rule_conclusion is not None and len(make_list(rule_conclusion)) > 0\
746
+ and any(ct in case_query.core_attribute_type for ct in map(type, make_list(rule_conclusion))):
747
+ evaluated_rule.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
748
+ self.add_conclusion(rule_conclusion)
670
749
  evaluated_rule = next_rule
671
750
  return make_set(self.conclusions)
672
751
 
@@ -694,7 +773,7 @@ class MultiClassRDR(RDRWithCodeWriter):
694
773
  self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule)
695
774
  else:
696
775
  # Rule fired and target is correct or there is no target to compare
697
- self.add_conclusion(evaluated_rule, case_query.case)
776
+ self.add_conclusion(rule_conclusion)
698
777
 
699
778
  if not next_rule:
700
779
  if not make_set(target_value).issubset(make_set(self.conclusions)):
@@ -706,16 +785,18 @@ class MultiClassRDR(RDRWithCodeWriter):
706
785
  return self.conclusions
707
786
 
708
787
  def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
709
- file, parent_indent: str = "", defs_file: Optional[str] = None):
788
+ file, parent_indent: str = "", defs_file: Optional[str] = None,
789
+ cases_file: Optional[str] = None):
710
790
  if rule == self.start_rule:
711
791
  file.write(f"{parent_indent}conclusions = set()\n")
712
792
  if rule.conditions:
793
+ rule.write_corner_case_as_source_code(cases_file)
713
794
  if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
714
795
  file.write(if_clause)
715
796
  conclusion_indent = parent_indent
716
797
  if hasattr(rule, "refinement") and rule.refinement:
717
798
  self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
718
- defs_file=defs_file)
799
+ defs_file=defs_file, cases_file=cases_file)
719
800
  conclusion_indent = parent_indent + " " * 4
720
801
  file.write(f"{conclusion_indent}else:\n")
721
802
 
@@ -723,7 +804,8 @@ class MultiClassRDR(RDRWithCodeWriter):
723
804
  file.write(conclusion_call)
724
805
 
725
806
  if rule.alternative:
726
- self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
807
+ self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file,
808
+ cases_file=cases_file)
727
809
 
728
810
  @property
729
811
  def conclusion_type_hint(self) -> str:
@@ -749,8 +831,7 @@ class MultiClassRDR(RDRWithCodeWriter):
749
831
  """
750
832
  if not self.start_rule:
751
833
  conditions = expert.ask_for_conditions(case_query)
752
- self.start_rule = MultiClassTopRule(conditions, case_query.target, corner_case=case_query.case,
753
- conclusion_name=case_query.attribute_name)
834
+ self.start_rule = MultiClassTopRule.from_case_query(case_query)
754
835
 
755
836
  @property
756
837
  def last_top_rule(self) -> Optional[MultiClassTopRule]:
@@ -771,7 +852,7 @@ class MultiClassRDR(RDRWithCodeWriter):
771
852
  if is_conflicting(rule_conclusion, case_query.target_value):
772
853
  self.stop_conclusion(case_query, expert, evaluated_rule)
773
854
  else:
774
- self.add_conclusion(evaluated_rule, case_query.case)
855
+ self.add_conclusion(rule_conclusion)
775
856
 
776
857
  def stop_conclusion(self, case_query: CaseQuery,
777
858
  expert: Expert, evaluated_rule: MultiClassTopRule):
@@ -783,12 +864,13 @@ class MultiClassRDR(RDRWithCodeWriter):
783
864
  :param evaluated_rule: The evaluated rule to ask the expert about.
784
865
  """
785
866
  conditions = expert.ask_for_conditions(case_query, evaluated_rule)
786
- evaluated_rule.fit_rule(case_query.case, case_query.target, conditions=conditions)
867
+ evaluated_rule.fit_rule(case_query)
787
868
  if self.mode == MCRDRMode.StopPlusRule:
788
869
  self.stop_rule_conditions = conditions
789
870
  if self.mode == MCRDRMode.StopPlusRuleCombined:
790
871
  new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
791
- self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
872
+ case_query.conditions = new_top_rule_conditions
873
+ self.add_top_rule(case_query)
792
874
 
793
875
  def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
794
876
  """
@@ -800,19 +882,19 @@ class MultiClassRDR(RDRWithCodeWriter):
800
882
  if self.stop_rule_conditions and self.mode == MCRDRMode.StopPlusRule:
801
883
  conditions = self.stop_rule_conditions
802
884
  self.stop_rule_conditions = None
885
+ case_query.conditions = conditions
803
886
  else:
804
887
  conditions = expert.ask_for_conditions(case_query)
805
- self.add_top_rule(conditions, case_query.target, case_query.case)
888
+ self.add_top_rule(case_query)
806
889
 
807
- def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
890
+ def add_conclusion(self, rule_conclusion: List[Any]) -> None:
808
891
  """
809
892
  Add the conclusion of the evaluated rule to the list of conclusions.
810
893
 
811
- :param evaluated_rule: The evaluated rule to add the conclusion of.
812
- :param case: The case to add the conclusion for.
894
+ :param rule_conclusion: The conclusion of the evaluated rule, which can be a single conclusion
895
+ or a set of conclusions.
813
896
  """
814
897
  conclusion_types = [type(c) for c in self.conclusions]
815
- rule_conclusion = evaluated_rule.conclusion(case)
816
898
  if type(rule_conclusion) not in conclusion_types:
817
899
  self.conclusions.extend(make_list(rule_conclusion))
818
900
  else:
@@ -825,15 +907,13 @@ class MultiClassRDR(RDRWithCodeWriter):
825
907
  self.conclusions.remove(c)
826
908
  self.conclusions.extend(make_list(combined_conclusion))
827
909
 
828
- def add_top_rule(self, conditions: CallableExpression, conclusion: Any, corner_case: Union[Case, SQLTable]):
910
+ def add_top_rule(self, case_query: CaseQuery):
829
911
  """
830
912
  Add a top rule to the classifier, which is a rule that is always checked and is part of the start_rules list.
831
913
 
832
- :param conditions: The conditions of the rule.
833
- :param conclusion: The conclusion of the rule.
834
- :param corner_case: The corner case of the rule.
914
+ :param case_query: The case query to add the top rule for.
835
915
  """
836
- self.start_rule.alternative = MultiClassTopRule(conditions, conclusion, corner_case=corner_case)
916
+ self.start_rule.alternative = MultiClassTopRule.from_case_query(case_query)
837
917
 
838
918
  @staticmethod
839
919
  def start_rule_type() -> Type[Rule]:
@@ -894,59 +974,19 @@ class GeneralRDR(RippleDownRules):
894
974
  def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
895
975
  return [rdr.start_rule for rdr in self.start_rules_dict.values()]
896
976
 
897
- def classify(self, case: Any, modify_case: bool = False) -> Optional[Dict[str, Any]]:
977
+ def classify(self, case: Any, modify_case: bool = False,
978
+ case_query: Optional[CaseQuery] = None) -> Optional[Dict[str, Any]]:
898
979
  """
899
980
  Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
900
981
  the classification until no more categories can be added.
901
982
 
902
983
  :param case: The case to classify.
903
984
  :param modify_case: Whether to modify the original case or create a copy and modify it.
985
+ :param case_query: The case query containing the case and the target category to compare the case with.
904
986
  :return: The categories that the case belongs to.
905
987
  """
906
- return self._classify(self.start_rules_dict, case, modify_original_case=modify_case)
907
-
908
- @staticmethod
909
- def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
910
- case: Any, modify_original_case: bool = False) -> Dict[str, Any]:
911
- """
912
- Classify a case by going through all classifiers and adding the categories that are classified,
913
- and then restarting the classification until no more categories can be added.
914
-
915
- :param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
916
- :param case: The case to classify.
917
- :param modify_original_case: Whether to modify the original case or create a copy and modify it.
918
- :return: The categories that the case belongs to.
919
- """
920
- conclusions = {}
921
- case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
922
- case_cp = copy_case(case) if not modify_original_case else case
923
- while True:
924
- new_conclusions = {}
925
- for attribute_name, rdr in classifiers_dict.items():
926
- pred_atts = rdr.classify(case_cp)
927
- if pred_atts is None:
928
- continue
929
- if rdr.type_ is SingleClassRDR:
930
- if attribute_name not in conclusions or \
931
- (attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
932
- conclusions[attribute_name] = pred_atts
933
- new_conclusions[attribute_name] = pred_atts
934
- else:
935
- pred_atts = make_set(pred_atts)
936
- if attribute_name in conclusions:
937
- pred_atts = {p for p in pred_atts if p not in conclusions[attribute_name]}
938
- if len(pred_atts) > 0:
939
- new_conclusions[attribute_name] = pred_atts
940
- if attribute_name not in conclusions:
941
- conclusions[attribute_name] = set()
942
- conclusions[attribute_name].update(pred_atts)
943
- if attribute_name in new_conclusions:
944
- mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
945
- case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive)
946
- update_case(case_query, new_conclusions)
947
- if len(new_conclusions) == 0:
948
- break
949
- return conclusions
988
+ return general_rdr_classify(self.start_rules_dict, case, modify_original_case=modify_case,
989
+ case_query=case_query)
950
990
 
951
991
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
952
992
  -> Dict[str, Any]:
@@ -1033,7 +1073,7 @@ class GeneralRDR(RippleDownRules):
1033
1073
  """
1034
1074
  for rdr in self.start_rules_dict.values():
1035
1075
  rdr._write_to_python(model_dir)
1036
- func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
1076
+ func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
1037
1077
  with open(model_dir + f"/{self.generated_python_file_name}.py", "w") as f:
1038
1078
  f.write(self._get_imports() + "\n\n")
1039
1079
  f.write("classifiers_dict = dict()\n")
@@ -1043,7 +1083,7 @@ class GeneralRDR(RippleDownRules):
1043
1083
  f.write(func_def)
1044
1084
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
1045
1085
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
1046
- f.write(f"{' ' * 4}return GeneralRDR._classify(classifiers_dict, case)\n")
1086
+ f.write(f"{' ' * 4}return general_rdr_classify(classifiers_dict, case, **kwargs)\n")
1047
1087
 
1048
1088
  @property
1049
1089
  def _default_generated_python_file_name(self) -> Optional[str]:
@@ -1068,7 +1108,7 @@ class GeneralRDR(RippleDownRules):
1068
1108
  # add type hints
1069
1109
  imports += f"from typing_extensions import Dict, Any\n"
1070
1110
  # import rdr type
1071
- imports += f"from ripple_down_rules.rdr import GeneralRDR\n"
1111
+ imports += f"from ripple_down_rules.helpers import general_rdr_classify\n"
1072
1112
  # add case type
1073
1113
  imports += f"from ripple_down_rules.datastructures.case import Case, create_case\n"
1074
1114
  imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"