ripple-down-rules 0.5.5__py3-none-any.whl → 0.5.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ripple_down_rules/rdr.py CHANGED
@@ -1,18 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
- import copyreg
4
3
  import importlib
5
4
  import os
6
-
7
- from . import logger
8
- import sys
9
5
  from abc import ABC, abstractmethod
10
6
  from copy import copy
11
- from io import TextIOWrapper
12
- from types import ModuleType
7
+
8
+ from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
9
+ from . import logger
13
10
 
14
11
  try:
15
12
  from matplotlib import pyplot as plt
13
+
16
14
  Figure = plt.Figure
17
15
  except ImportError as e:
18
16
  logger.debug(f"{e}: matplotlib is not installed")
@@ -28,16 +26,16 @@ from .datastructures.case import Case, CaseAttribute, create_case
28
26
  from .datastructures.dataclasses import CaseQuery
29
27
  from .datastructures.enums import MCRDRMode
30
28
  from .experts import Expert, Human
31
- from .helpers import is_matching
29
+ from .helpers import is_matching, general_rdr_classify
32
30
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
31
+
33
32
  try:
34
33
  from .user_interface.gui import RDRCaseViewer
35
34
  except ImportError as e:
36
35
  RDRCaseViewer = None
37
- from .utils import draw_tree, make_set, copy_case, \
38
- SubclassJSONSerializer, make_list, get_type_from_string, \
39
- is_conflicting, update_case, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
40
- is_iterable, str_to_snake_case
36
+ from .utils import draw_tree, make_set, SubclassJSONSerializer, make_list, get_type_from_string, \
37
+ is_conflicting, extract_function_source, extract_imports, get_full_class_name, \
38
+ is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types
41
39
 
42
40
 
43
41
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -76,16 +74,18 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
76
74
  """
77
75
  The name of the model. If None, the model name will be the generated python file name.
78
76
  """
77
+ mutually_exclusive: Optional[bool] = None
78
+ """
79
+ Whether the output of the classification of this rdr allows only one possible conclusion or not.
80
+ """
79
81
 
80
82
  def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None,
81
- save_dir: Optional[str] = None, ask_always: bool = True, model_name: Optional[str] = None):
83
+ save_dir: Optional[str] = None, model_name: Optional[str] = None):
82
84
  """
83
85
  :param start_rule: The starting rule for the classifier.
84
86
  :param viewer: The viewer gui to use for the classifier. If None, no viewer is used.
85
87
  :param save_dir: The directory to save the classifier to.
86
- :param ask_always: Whether to always ask the expert (True) or only ask when classification fails (False).
87
88
  """
88
- self.ask_always: bool = ask_always
89
89
  self.model_name: Optional[str] = model_name
90
90
  self.save_dir = save_dir
91
91
  self.start_rule = start_rule
@@ -94,13 +94,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
94
94
  if self.viewer is not None:
95
95
  self.viewer.set_save_function(self.save)
96
96
 
97
- def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None) -> str:
97
+ def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None,
98
+ package_name: Optional[str] = None) -> str:
98
99
  """
99
100
  Save the classifier to a file.
100
101
 
101
102
  :param save_dir: The directory to save the classifier to.
102
103
  :param model_name: The name of the model to save. If None, a default name is generated.
103
- :param postfix: The postfix to add to the file name.
104
+ :param package_name: The name of the package that contains the RDR classifier function, this
105
+ is required in case of relative imports in the generated python file.
104
106
  :return: The name of the saved model.
105
107
  """
106
108
  save_dir = save_dir or self.save_dir
@@ -110,7 +112,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
110
112
  if not os.path.exists(save_dir + '/__init__.py'):
111
113
  os.makedirs(save_dir, exist_ok=True)
112
114
  with open(save_dir + '/__init__.py', 'w') as f:
113
- f.write("# This is an empty __init__.py file to make the directory a package.\n")
115
+ f.write("from . import *\n")
114
116
  if model_name is not None:
115
117
  self.model_name = model_name
116
118
  elif self.model_name is None:
@@ -120,31 +122,40 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
120
122
  json_dir = os.path.join(model_dir, self.metadata_folder)
121
123
  os.makedirs(json_dir, exist_ok=True)
122
124
  self.to_json_file(os.path.join(json_dir, self.model_name))
123
- self._write_to_python(model_dir)
125
+ self._write_to_python(model_dir, package_name=package_name)
124
126
  return self.model_name
125
127
 
126
128
  @classmethod
127
- def load(cls, load_dir: str, model_name: str) -> Self:
129
+ def load(cls, load_dir: str, model_name: str,
130
+ package_name: Optional[str] = None) -> Self:
128
131
  """
129
132
  Load the classifier from a file.
130
133
 
131
134
  :param load_dir: The path to the model directory to load the classifier from.
132
135
  :param model_name: The name of the model to load.
136
+ :param package_name: The name of the package that contains the RDR classifier function, this
137
+ is required in case of relative imports in the generated python file.
133
138
  """
134
139
  model_dir = os.path.join(load_dir, model_name)
135
140
  json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
136
141
  rdr = cls.from_json_file(json_file)
137
- rdr.update_from_python(model_dir)
142
+ try:
143
+ rdr.update_from_python(model_dir, package_name=package_name)
144
+ except (FileNotFoundError, ValueError) as e:
145
+ logger.warning(f"Could not load the python file for the model {model_name} from {model_dir}. "
146
+ f"Make sure the file exists and is valid.")
138
147
  rdr.save_dir = load_dir
139
148
  rdr.model_name = model_name
140
149
  return rdr
141
150
 
142
151
  @abstractmethod
143
- def _write_to_python(self, model_dir: str):
152
+ def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
144
153
  """
145
154
  Write the tree of rules as source code to a file.
146
155
 
147
156
  :param model_dir: The path to the directory to write the source code to.
157
+ :param package_name: The name of the package that contains the RDR classifier function, this
158
+ is required in case of relative imports in the generated python file.
148
159
  """
149
160
  pass
150
161
 
@@ -213,18 +224,24 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
213
224
  return self.classify(case)
214
225
 
215
226
  @abstractmethod
216
- def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) \
227
+ def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
228
+ case_query: Optional[CaseQuery] = None) \
217
229
  -> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
218
230
  """
219
231
  Classify a case.
220
232
 
221
233
  :param case: The case to classify.
222
234
  :param modify_case: Whether to modify the original case attributes with the conclusion or not.
235
+ :param case_query: The case query containing the case to classify and the target category to compare the case with.
223
236
  :return: The category that the case belongs to.
224
237
  """
225
238
  pass
226
239
 
227
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
240
+ def fit_case(self, case_query: CaseQuery,
241
+ expert: Optional[Expert] = None,
242
+ update_existing_rules: bool = True,
243
+ scenario: Optional[Callable] = None,
244
+ **kwargs) \
228
245
  -> Union[CallableExpression, Dict[str, CallableExpression]]:
229
246
  """
230
247
  Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
@@ -232,6 +249,9 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
232
249
 
233
250
  :param case_query: The query containing the case to classify and the target category to compare the case with.
234
251
  :param expert: The expert to ask for differentiating features as new rule conditions.
252
+ :param update_existing_rules: Whether to update the existing same conclusion type rules that already gave
253
+ some conclusions with the type required by the case query.
254
+ :param scenario: The scenario at which the case was created, this is used to recreate the case if needed.
235
255
  :return: The category that the case belongs to.
236
256
  """
237
257
  if case_query is None:
@@ -240,13 +260,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
240
260
  self.name = case_query.attribute_name if self.name is None else self.name
241
261
  self.case_type = case_query.case_type if self.case_type is None else self.case_type
242
262
  self.case_name = case_query.case_name if self.case_name is None else self.case_name
263
+ case_query.scenario = scenario if case_query.scenario is None else case_query.scenario
243
264
 
244
- expert = expert or Human(answers_save_path=self.save_dir + '/expert_answers' if self.save_dir else None)
245
-
265
+ expert = expert or Human(viewer=self.viewer,
266
+ answers_save_path=self.save_dir + '/expert_answers'
267
+ if self.save_dir else None)
246
268
  if case_query.target is None:
247
269
  case_query_cp = copy(case_query)
248
- conclusions = self.classify(case_query_cp.case, modify_case=True)
249
- if self.ask_always or conclusions is None or is_iterable(conclusions) and len(conclusions) == 0:
270
+ conclusions = self.classify(case_query_cp.case, modify_case=True, case_query=case_query_cp)
271
+ if self.should_i_ask_the_expert_for_a_target(conclusions, case_query_cp, update_existing_rules):
250
272
  expert.ask_for_conclusion(case_query_cp)
251
273
  case_query.target = case_query_cp.target
252
274
  if case_query.target is None:
@@ -262,6 +284,34 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
262
284
 
263
285
  return fit_case_result
264
286
 
287
+ @staticmethod
288
+ def should_i_ask_the_expert_for_a_target(conclusions: Union[Any, Dict[str, Any]],
289
+ case_query: CaseQuery,
290
+ update_existing: bool) -> bool:
291
+ """
292
+ Determine if the rdr should ask the expert for the target of a given case query.
293
+
294
+ :param conclusions: The conclusions of the case.
295
+ :param case_query: The query containing the case to classify.
296
+ :param update_existing: Whether to update rules that gave the required type of conclusions.
297
+ :return: True if the rdr should ask the expert, False otherwise.
298
+ """
299
+ if conclusions is None:
300
+ return True
301
+ elif is_iterable(conclusions) and len(conclusions) == 0:
302
+ return True
303
+ elif isinstance(conclusions, dict):
304
+ if case_query.attribute_name not in conclusions:
305
+ return True
306
+ conclusions = conclusions[case_query.attribute_name]
307
+ conclusion_types = map(type, make_list(conclusions))
308
+ if not any(ct in case_query.core_attribute_type for ct in conclusion_types):
309
+ return True
310
+ elif update_existing:
311
+ return True
312
+ else:
313
+ return False
314
+
265
315
  @abstractmethod
266
316
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
267
317
  -> Union[CallableExpression, Dict[str, CallableExpression]]:
@@ -326,11 +376,13 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
326
376
  pass
327
377
 
328
378
  @abstractmethod
329
- def update_from_python(self, model_dir: str):
379
+ def update_from_python(self, model_dir: str, package_name: Optional[str] = None):
330
380
  """
331
381
  Update the rules from the generated python file, that might have been modified by the user.
332
382
 
333
383
  :param model_dir: The directory where the generated python file is located.
384
+ :param package_name: The name of the package that contains the RDR classifier function, this
385
+ is required in case of relative imports in the generated python file.
334
386
  """
335
387
  pass
336
388
 
@@ -352,42 +404,53 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
352
404
  :return: The module that contains the rdr classifier function.
353
405
  """
354
406
  # remove from imports if exists first
355
- name = f"{package_name.strip('./').replace('/', '.')}.{self.generated_python_file_name}"
356
- try:
357
- module = importlib.import_module(name)
358
- del sys.modules[name]
359
- except ModuleNotFoundError:
360
- pass
361
- return importlib.import_module(name).classify
407
+ package_name = get_import_path_from_path(package_name)
408
+ name = f"{package_name}.{self.generated_python_file_name}" if package_name else self.generated_python_file_name
409
+ module = importlib.import_module(name)
410
+ importlib.reload(module)
411
+ return module.classify
362
412
 
363
413
 
364
414
  class RDRWithCodeWriter(RippleDownRules, ABC):
365
415
 
366
- def update_from_python(self, model_dir: str):
416
+ def update_from_python(self, model_dir: str, package_name: Optional[str] = None):
367
417
  """
368
418
  Update the rules from the generated python file, that might have been modified by the user.
369
419
 
370
420
  :param model_dir: The directory where the generated python file is located.
421
+ :param package_name: The name of the package that contains the RDR classifier function, this
422
+ is required in case of relative imports in the generated python file.
371
423
  """
372
- rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants) if r.conditions is not None}
424
+ rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants)
425
+ if r.conditions is not None}
373
426
  condition_func_names = [f'conditions_{rid}' for rid in rules_dict.keys()]
374
- conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys() if not isinstance(rules_dict[rid], MultiClassStopRule)]
427
+ conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys()
428
+ if not isinstance(rules_dict[rid], MultiClassStopRule)]
375
429
  all_func_names = condition_func_names + conclusion_func_names
376
430
  filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
431
+ cases_path = f"{model_dir}/{self.generated_python_cases_file_name}.py"
432
+ cases_import_path = get_import_path_from_path(model_dir)
433
+ cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path \
434
+ else self.generated_python_cases_file_name
377
435
  functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
378
436
  # get the scope from the imports in the file
379
- scope = extract_imports(filepath)
437
+ scope = extract_imports(filepath, package_name=package_name)
380
438
  for rule in [self.start_rule] + list(self.start_rule.descendants):
381
439
  if rule.conditions is not None:
382
440
  rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
383
441
  rule.conditions.scope = scope
442
+ if os.path.exists(cases_path):
443
+ module = importlib.import_module(cases_import_path, package=package_name)
444
+ importlib.reload(module)
445
+ rule.corner_case_metadata = module.__dict__.get(f"corner_case_{rule.uid}", None)
384
446
  if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
385
447
  rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
386
448
  rule.conclusion.scope = scope
387
449
 
388
450
  @abstractmethod
389
451
  def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
390
- defs_file: Optional[str] = None):
452
+ defs_file: Optional[str] = None, cases_file: Optional[str] = None,
453
+ package_name: Optional[str] = None):
391
454
  """
392
455
  Write the rules as source code to a file.
393
456
 
@@ -395,37 +458,63 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
395
458
  :param file: The file to write the source code to.
396
459
  :param parent_indent: The indentation of the parent rule.
397
460
  :param defs_file: The file to write the definitions to.
461
+ :param cases_file: The file to write the cases to.
462
+ :param package_name: The name of the package that contains the RDR classifier function, this
463
+ is required in case of relative imports in the generated python file.
398
464
  """
399
465
  pass
400
466
 
401
- def _write_to_python(self, model_dir: str):
467
+ def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
402
468
  """
403
469
  Write the tree of rules as source code to a file.
404
470
 
405
471
  :param model_dir: The path to the directory to write the source code to.
472
+ :param package_name: The name of the package that contains the RDR classifier function, this
473
+ is required in case of relative imports in the generated python file.
406
474
  """
475
+ # Make sure the model directory exists and create an __init__.py file if it doesn't exist
407
476
  os.makedirs(model_dir, exist_ok=True)
408
477
  if not os.path.exists(model_dir + '/__init__.py'):
409
478
  with open(model_dir + '/__init__.py', 'w') as f:
410
- f.write("# This is an empty __init__.py file to make the directory a package.\n")
411
- func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
479
+ f.write("from . import *\n")
480
+
481
+ # Set the file names for the generated python files
412
482
  file_name = model_dir + f"/{self.generated_python_file_name}.py"
413
483
  defs_file_name = model_dir + f"/{self.generated_python_defs_file_name}.py"
414
- imports, defs_imports = self._get_imports()
415
- # clear the files first
484
+ cases_file_name = model_dir + f"/{self.generated_python_cases_file_name}.py"
485
+
486
+ # Get the required imports for the main file and the defs file
487
+ main_types, defs_types, corner_cases_types = self._get_types_to_import()
488
+ imports = get_imports_from_types(main_types, file_name, package_name)
489
+ defs_imports = get_imports_from_types(defs_types, defs_file_name, package_name)
490
+ corner_cases_imports = get_imports_from_types(corner_cases_types, cases_file_name, package_name)
491
+
492
+ # Add the imports to the defs file
416
493
  with open(defs_file_name, "w") as f:
417
- f.write(defs_imports + "\n\n")
494
+ f.write('\n'.join(defs_imports) + "\n\n\n")
495
+
496
+ # Add the imports to the cases file
497
+ case_factory_import = get_imports_from_types([CaseFactoryMetaData], cases_file_name, package_name)
498
+ corner_cases_imports.extend(case_factory_import)
499
+ with open(cases_file_name, "w") as cases_f:
500
+ cases_f.write("# This file contains the corner cases for the rules.\n")
501
+ cases_f.write('\n'.join(corner_cases_imports) + "\n\n\n")
502
+
503
+ # Add the imports, the attributes, and the function definition to the main file
504
+ func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
418
505
  with open(file_name, "w") as f:
419
- imports += f"from .{self.generated_python_defs_file_name} import *\n"
420
- imports += f"from ripple_down_rules.rdr import {self.__class__.__name__}\n"
421
- f.write(imports + "\n\n")
506
+ imports.append(f"from .{self.generated_python_defs_file_name} import *")
507
+ f.write('\n'.join(imports) + "\n\n\n")
422
508
  f.write(f"attribute_name = '{self.attribute_name}'\n")
423
509
  f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
424
- f.write(f"type_ = {self.__class__.__name__}\n")
510
+ f.write(f"mutually_exclusive = {self.mutually_exclusive}\n")
425
511
  f.write(f"\n\n{func_def}")
426
512
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
427
513
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
428
- self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4, defs_file=defs_file_name)
514
+
515
+ # Write the rules as source code to the main file
516
+ self.write_rules_as_source_code_to_file(self.start_rule, file_name, " " * 4, defs_file=defs_file_name,
517
+ cases_file=cases_file_name, package_name=package_name)
429
518
 
430
519
  @property
431
520
  @abstractmethod
@@ -435,31 +524,27 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
435
524
  """
436
525
  pass
437
526
 
438
- def _get_imports(self) -> Tuple[str, str]:
527
+ def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
439
528
  """
440
- :return: The imports for the generated python file of the RDR as a string.
529
+ :return: The types of the main, defs, and corner cases files of the RDR classifier that will be imported.
441
530
  """
442
- defs_imports_list = []
531
+ defs_types = set()
532
+ cases_types = set()
443
533
  for rule in [self.start_rule] + list(self.start_rule.descendants):
444
534
  if not rule.conditions:
445
535
  continue
446
536
  for scope in [rule.conditions.scope, rule.conclusion.scope]:
447
537
  if scope is None:
448
538
  continue
449
- defs_imports_list.extend(get_imports_from_scope(scope))
450
- if self.case_type.__module__ != "builtins":
451
- defs_imports_list.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
452
- defs_imports = "\n".join(set(defs_imports_list)) + "\n"
453
- imports = []
454
- if self.case_type.__module__ != "builtins":
455
- imports.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
456
- for conclusion_type in self.conclusion_type:
457
- if conclusion_type.__module__ != "builtins":
458
- imports.append(f"from {conclusion_type.__module__} import {conclusion_type.__name__}")
459
- imports.append("from ripple_down_rules.datastructures.case import Case, create_case")
460
- imports = set(imports).difference(defs_imports_list)
461
- imports = "\n".join(imports) + "\n"
462
- return imports, defs_imports
539
+ defs_types.update(make_set(scope.values()))
540
+ cases_types.update(rule.get_corner_case_types_to_import())
541
+ defs_types.add(self.case_type)
542
+ main_types = set()
543
+ main_types.add(self.case_type)
544
+ main_types.update(make_set(self.conclusion_type))
545
+ main_types.update({Case, create_case})
546
+ main_types = main_types.difference(defs_types)
547
+ return main_types, defs_types, cases_types
463
548
 
464
549
  @property
465
550
  def _default_generated_python_file_name(self) -> Optional[str]:
@@ -474,6 +559,9 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
474
559
  def generated_python_defs_file_name(self) -> str:
475
560
  return f"{self.generated_python_file_name}_defs"
476
561
 
562
+ @property
563
+ def generated_python_cases_file_name(self) -> str:
564
+ return f"{self.generated_python_file_name}_cases"
477
565
 
478
566
  @property
479
567
  def conclusion_type(self) -> Tuple[Type]:
@@ -493,7 +581,8 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
493
581
  return self.start_rule.conclusion_name
494
582
 
495
583
  def _to_json(self) -> Dict[str, Any]:
496
- return {"start_rule": self.start_rule.to_json(), "generated_python_file_name": self.generated_python_file_name,
584
+ return {"start_rule": self.start_rule.to_json(),
585
+ "generated_python_file_name": self.generated_python_file_name,
497
586
  "name": self.name,
498
587
  "case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
499
588
  "case_name": self.case_name}
@@ -525,6 +614,10 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
525
614
 
526
615
 
527
616
  class SingleClassRDR(RDRWithCodeWriter):
617
+ mutually_exclusive: bool = True
618
+ """
619
+ The output of the classification of this rdr negates all other possible outputs, there can only be one true value.
620
+ """
528
621
 
529
622
  def __init__(self, default_conclusion: Optional[Any] = None, **kwargs):
530
623
  """
@@ -550,7 +643,7 @@ class SingleClassRDR(RDRWithCodeWriter):
550
643
  pred = self.evaluate(case_query.case)
551
644
  if pred.conclusion(case_query.case) != case_query.target_value:
552
645
  expert.ask_for_conditions(case_query, pred)
553
- pred.fit_rule(case_query.case, case_query.target, conditions=case_query.conditions)
646
+ pred.fit_rule(case_query)
554
647
 
555
648
  return self.classify(case_query.case)
556
649
 
@@ -563,18 +656,24 @@ class SingleClassRDR(RDRWithCodeWriter):
563
656
  """
564
657
  if not self.start_rule:
565
658
  expert.ask_for_conditions(case_query)
566
- self.start_rule = SingleClassRule(case_query.conditions, case_query.target, corner_case=case_query.case,
567
- conclusion_name=case_query.attribute_name)
659
+ self.start_rule = SingleClassRule.from_case_query(case_query)
568
660
 
569
- def classify(self, case: Case, modify_case: bool = False) -> Optional[Any]:
661
+ def classify(self, case: Case, modify_case: bool = False,
662
+ case_query: Optional[CaseQuery] = None) -> Optional[Any]:
570
663
  """
571
664
  Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
572
665
 
573
666
  :param case: The case to classify.
574
667
  :param modify_case: Whether to modify the original case attributes with the conclusion or not.
668
+ :param case_query: The case query containing the case and the target category to compare the case with.
575
669
  """
576
670
  pred = self.evaluate(case)
577
- return pred.conclusion(case) if pred is not None and pred.fired else self.default_conclusion
671
+ conclusion = pred.conclusion(case) if pred is not None else None
672
+ if pred is not None and pred.fired and case_query is not None:
673
+ if pred.corner_case_metadata is None and conclusion is not None \
674
+ and type(conclusion) in case_query.core_attribute_type:
675
+ pred.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
676
+ return conclusion if pred is not None and pred.fired else self.default_conclusion
578
677
 
579
678
  def evaluate(self, case: Case) -> SingleClassRule:
580
679
  """
@@ -583,29 +682,35 @@ class SingleClassRDR(RDRWithCodeWriter):
583
682
  matched_rule = self.start_rule(case) if self.start_rule is not None else None
584
683
  return matched_rule if matched_rule is not None else self.start_rule
585
684
 
586
- def _write_to_python(self, model_dir: str):
587
- super()._write_to_python(model_dir)
685
+ def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
686
+ super()._write_to_python(model_dir, package_name=package_name)
588
687
  if self.default_conclusion is not None:
589
688
  with open(model_dir + f"/{self.generated_python_file_name}.py", "a") as f:
590
689
  f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
591
690
 
592
- def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
593
- defs_file: Optional[str] = None):
691
+ def write_rules_as_source_code_to_file(self, rule: SingleClassRule, filename: str, parent_indent: str = "",
692
+ defs_file: Optional[str] = None, cases_file: Optional[str] = None,
693
+ package_name: Optional[str] = None):
594
694
  """
595
695
  Write the rules as source code to a file.
596
696
  """
597
697
  if rule.conditions:
698
+ rule.write_corner_case_as_source_code(cases_file, package_name=package_name)
598
699
  if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
599
- file.write(if_clause)
700
+ with open(filename, "a") as file:
701
+ file.write(if_clause)
600
702
  if rule.refinement:
601
- self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
602
- defs_file=defs_file)
703
+ self.write_rules_as_source_code_to_file(rule.refinement, filename, parent_indent + " ",
704
+ defs_file=defs_file, cases_file=cases_file,
705
+ package_name=package_name)
603
706
 
604
707
  conclusion_call = rule.write_conclusion_as_source_code(parent_indent, defs_file)
605
- file.write(conclusion_call)
708
+ with open(filename, "a") as file:
709
+ file.write(conclusion_call)
606
710
 
607
711
  if rule.alternative:
608
- self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
712
+ self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
713
+ cases_file=cases_file, package_name=package_name)
609
714
 
610
715
  @property
611
716
  def conclusion_type_hint(self) -> str:
@@ -643,23 +748,34 @@ class MultiClassRDR(RDRWithCodeWriter):
643
748
  """
644
749
  The conditions of the stopping rule if needed.
645
750
  """
751
+ mutually_exclusive: bool = False
752
+ """
753
+ The output of the classification of this rdr allows for more than one true value as conclusion.
754
+ """
646
755
 
647
756
  def __init__(self, start_rule: Optional[MultiClassTopRule] = None,
648
- mode: MCRDRMode = MCRDRMode.StopOnly):
757
+ mode: MCRDRMode = MCRDRMode.StopOnly, **kwargs):
649
758
  """
650
759
  :param start_rule: The starting rules for the classifier.
651
760
  :param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
652
761
  """
653
- super(MultiClassRDR, self).__init__(start_rule)
762
+ super(MultiClassRDR, self).__init__(start_rule, **kwargs)
654
763
  self.mode: MCRDRMode = mode
655
764
 
656
- def classify(self, case: Union[Case, SQLTable], modify_case: bool = False) -> Set[Any]:
765
+ def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
766
+ case_query: Optional[CaseQuery] = None) -> Set[Any]:
657
767
  evaluated_rule = self.start_rule
658
768
  self.conclusions = []
659
769
  while evaluated_rule:
660
770
  next_rule = evaluated_rule(case)
661
771
  if evaluated_rule.fired:
662
- self.add_conclusion(evaluated_rule, case)
772
+ rule_conclusion = evaluated_rule.conclusion(case)
773
+ if evaluated_rule.corner_case_metadata is None and case_query is not None:
774
+ if rule_conclusion is not None and len(make_list(rule_conclusion)) > 0 \
775
+ and any(
776
+ ct in case_query.core_attribute_type for ct in map(type, make_list(rule_conclusion))):
777
+ evaluated_rule.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
778
+ self.add_conclusion(rule_conclusion)
663
779
  evaluated_rule = next_rule
664
780
  return make_set(self.conclusions)
665
781
 
@@ -687,7 +803,7 @@ class MultiClassRDR(RDRWithCodeWriter):
687
803
  self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule)
688
804
  else:
689
805
  # Rule fired and target is correct or there is no target to compare
690
- self.add_conclusion(evaluated_rule, case_query.case)
806
+ self.add_conclusion(rule_conclusion)
691
807
 
692
808
  if not next_rule:
693
809
  if not make_set(target_value).issubset(make_set(self.conclusions)):
@@ -699,24 +815,32 @@ class MultiClassRDR(RDRWithCodeWriter):
699
815
  return self.conclusions
700
816
 
701
817
  def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
702
- file, parent_indent: str = "", defs_file: Optional[str] = None):
818
+ filename: str, parent_indent: str = "", defs_file: Optional[str] = None,
819
+ cases_file: Optional[str] = None, package_name: Optional[str] = None):
703
820
  if rule == self.start_rule:
704
- file.write(f"{parent_indent}conclusions = set()\n")
821
+ with open(filename, "a") as file:
822
+ file.write(f"{parent_indent}conclusions = set()\n")
705
823
  if rule.conditions:
824
+ rule.write_corner_case_as_source_code(cases_file, package_name=package_name)
706
825
  if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
707
- file.write(if_clause)
826
+ with open(filename, "a") as file:
827
+ file.write(if_clause)
708
828
  conclusion_indent = parent_indent
709
829
  if hasattr(rule, "refinement") and rule.refinement:
710
- self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
711
- defs_file=defs_file)
830
+ self.write_rules_as_source_code_to_file(rule.refinement, filename, parent_indent + " ",
831
+ defs_file=defs_file, cases_file=cases_file,
832
+ package_name=package_name)
712
833
  conclusion_indent = parent_indent + " " * 4
713
- file.write(f"{conclusion_indent}else:\n")
834
+ with open(filename, "a") as file:
835
+ file.write(f"{conclusion_indent}else:\n")
714
836
 
715
837
  conclusion_call = rule.write_conclusion_as_source_code(conclusion_indent, defs_file)
716
- file.write(conclusion_call)
838
+ with open(filename, "a") as file:
839
+ file.write(conclusion_call)
717
840
 
718
841
  if rule.alternative:
719
- self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
842
+ self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
843
+ cases_file=cases_file, package_name=package_name)
720
844
 
721
845
  @property
722
846
  def conclusion_type_hint(self) -> str:
@@ -726,12 +850,11 @@ class MultiClassRDR(RDRWithCodeWriter):
726
850
  else:
727
851
  return f"Set[Union[{', '.join(conclusion_types)}]]"
728
852
 
729
- def _get_imports(self) -> Tuple[str, str]:
730
- imports, defs_imports = super()._get_imports()
731
- imports += f"from typing_extensions import Set, Union\n"
732
- imports += "from ripple_down_rules.utils import make_set\n"
733
- defs_imports += "from typing_extensions import Union\n"
734
- return imports, defs_imports
853
+ def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
854
+ main_types, defs_types, cases_types = super()._get_types_to_import()
855
+ main_types.update({Set, Union, make_set})
856
+ defs_types.add(Union)
857
+ return main_types, defs_types, cases_types
735
858
 
736
859
  def update_start_rule(self, case_query: CaseQuery, expert: Expert):
737
860
  """
@@ -742,8 +865,7 @@ class MultiClassRDR(RDRWithCodeWriter):
742
865
  """
743
866
  if not self.start_rule:
744
867
  conditions = expert.ask_for_conditions(case_query)
745
- self.start_rule = MultiClassTopRule(conditions, case_query.target, corner_case=case_query.case,
746
- conclusion_name=case_query.attribute_name)
868
+ self.start_rule = MultiClassTopRule.from_case_query(case_query)
747
869
 
748
870
  @property
749
871
  def last_top_rule(self) -> Optional[MultiClassTopRule]:
@@ -764,7 +886,7 @@ class MultiClassRDR(RDRWithCodeWriter):
764
886
  if is_conflicting(rule_conclusion, case_query.target_value):
765
887
  self.stop_conclusion(case_query, expert, evaluated_rule)
766
888
  else:
767
- self.add_conclusion(evaluated_rule, case_query.case)
889
+ self.add_conclusion(rule_conclusion)
768
890
 
769
891
  def stop_conclusion(self, case_query: CaseQuery,
770
892
  expert: Expert, evaluated_rule: MultiClassTopRule):
@@ -776,12 +898,13 @@ class MultiClassRDR(RDRWithCodeWriter):
776
898
  :param evaluated_rule: The evaluated rule to ask the expert about.
777
899
  """
778
900
  conditions = expert.ask_for_conditions(case_query, evaluated_rule)
779
- evaluated_rule.fit_rule(case_query.case, case_query.target, conditions=conditions)
901
+ evaluated_rule.fit_rule(case_query)
780
902
  if self.mode == MCRDRMode.StopPlusRule:
781
903
  self.stop_rule_conditions = conditions
782
904
  if self.mode == MCRDRMode.StopPlusRuleCombined:
783
905
  new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
784
- self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
906
+ case_query.conditions = new_top_rule_conditions
907
+ self.add_top_rule(case_query)
785
908
 
786
909
  def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
787
910
  """
@@ -793,19 +916,19 @@ class MultiClassRDR(RDRWithCodeWriter):
793
916
  if self.stop_rule_conditions and self.mode == MCRDRMode.StopPlusRule:
794
917
  conditions = self.stop_rule_conditions
795
918
  self.stop_rule_conditions = None
919
+ case_query.conditions = conditions
796
920
  else:
797
921
  conditions = expert.ask_for_conditions(case_query)
798
- self.add_top_rule(conditions, case_query.target, case_query.case)
922
+ self.add_top_rule(case_query)
799
923
 
800
- def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
924
+ def add_conclusion(self, rule_conclusion: List[Any]) -> None:
801
925
  """
802
926
  Add the conclusion of the evaluated rule to the list of conclusions.
803
927
 
804
- :param evaluated_rule: The evaluated rule to add the conclusion of.
805
- :param case: The case to add the conclusion for.
928
+ :param rule_conclusion: The conclusion of the evaluated rule, which can be a single conclusion
929
+ or a set of conclusions.
806
930
  """
807
931
  conclusion_types = [type(c) for c in self.conclusions]
808
- rule_conclusion = evaluated_rule.conclusion(case)
809
932
  if type(rule_conclusion) not in conclusion_types:
810
933
  self.conclusions.extend(make_list(rule_conclusion))
811
934
  else:
@@ -818,15 +941,13 @@ class MultiClassRDR(RDRWithCodeWriter):
818
941
  self.conclusions.remove(c)
819
942
  self.conclusions.extend(make_list(combined_conclusion))
820
943
 
821
- def add_top_rule(self, conditions: CallableExpression, conclusion: Any, corner_case: Union[Case, SQLTable]):
944
+ def add_top_rule(self, case_query: CaseQuery):
822
945
  """
823
946
  Add a top rule to the classifier, which is a rule that is always checked and is part of the start_rules list.
824
947
 
825
- :param conditions: The conditions of the rule.
826
- :param conclusion: The conclusion of the rule.
827
- :param corner_case: The corner case of the rule.
948
+ :param case_query: The case query to add the top rule for.
828
949
  """
829
- self.start_rule.alternative = MultiClassTopRule(conditions, conclusion, corner_case=corner_case)
950
+ self.start_rule.alternative = MultiClassTopRule.from_case_query(case_query)
830
951
 
831
952
  @staticmethod
832
953
  def start_rule_type() -> Type[Rule]:
@@ -887,59 +1008,19 @@ class GeneralRDR(RippleDownRules):
887
1008
  def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
888
1009
  return [rdr.start_rule for rdr in self.start_rules_dict.values()]
889
1010
 
890
- def classify(self, case: Any, modify_case: bool = False) -> Optional[Dict[str, Any]]:
1011
+ def classify(self, case: Any, modify_case: bool = False,
1012
+ case_query: Optional[CaseQuery] = None) -> Optional[Dict[str, Any]]:
891
1013
  """
892
1014
  Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
893
1015
  the classification until no more categories can be added.
894
1016
 
895
1017
  :param case: The case to classify.
896
1018
  :param modify_case: Whether to modify the original case or create a copy and modify it.
1019
+ :param case_query: The case query containing the case and the target category to compare the case with.
897
1020
  :return: The categories that the case belongs to.
898
1021
  """
899
- return self._classify(self.start_rules_dict, case, modify_original_case=modify_case)
900
-
901
- @staticmethod
902
- def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
903
- case: Any, modify_original_case: bool = False) -> Dict[str, Any]:
904
- """
905
- Classify a case by going through all classifiers and adding the categories that are classified,
906
- and then restarting the classification until no more categories can be added.
907
-
908
- :param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
909
- :param case: The case to classify.
910
- :param modify_original_case: Whether to modify the original case or create a copy and modify it.
911
- :return: The categories that the case belongs to.
912
- """
913
- conclusions = {}
914
- case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
915
- case_cp = copy_case(case) if not modify_original_case else case
916
- while True:
917
- new_conclusions = {}
918
- for attribute_name, rdr in classifiers_dict.items():
919
- pred_atts = rdr.classify(case_cp)
920
- if pred_atts is None:
921
- continue
922
- if rdr.type_ is SingleClassRDR:
923
- if attribute_name not in conclusions or \
924
- (attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
925
- conclusions[attribute_name] = pred_atts
926
- new_conclusions[attribute_name] = pred_atts
927
- else:
928
- pred_atts = make_set(pred_atts)
929
- if attribute_name in conclusions:
930
- pred_atts = {p for p in pred_atts if p not in conclusions[attribute_name]}
931
- if len(pred_atts) > 0:
932
- new_conclusions[attribute_name] = pred_atts
933
- if attribute_name not in conclusions:
934
- conclusions[attribute_name] = set()
935
- conclusions[attribute_name].update(pred_atts)
936
- if attribute_name in new_conclusions:
937
- mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
938
- case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive)
939
- update_case(case_query, new_conclusions)
940
- if len(new_conclusions) == 0:
941
- break
942
- return conclusions
1022
+ return general_rdr_classify(self.start_rules_dict, case, modify_original_case=modify_case,
1023
+ case_query=case_query)
943
1024
 
944
1025
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
945
1026
  -> Dict[str, Any]:
@@ -985,7 +1066,7 @@ class GeneralRDR(RippleDownRules):
985
1066
 
986
1067
  def _to_json(self) -> Dict[str, Any]:
987
1068
  return {"start_rules": {name: rdr.to_json() for name, rdr in self.start_rules_dict.items()}
988
- , "generated_python_file_name": self.generated_python_file_name,
1069
+ , "generated_python_file_name": self.generated_python_file_name,
989
1070
  "name": self.name,
990
1071
  "case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
991
1072
  "case_name": self.case_name}
@@ -1009,26 +1090,30 @@ class GeneralRDR(RippleDownRules):
1009
1090
  new_rdr.case_name = data["case_name"]
1010
1091
  return new_rdr
1011
1092
 
1012
- def update_from_python(self, model_dir: str) -> None:
1093
+ def update_from_python(self, model_dir: str, package_name: Optional[str] = None) -> None:
1013
1094
  """
1014
1095
  Update the rules from the generated python file, that might have been modified by the user.
1015
1096
 
1016
1097
  :param model_dir: The directory where the model is stored.
1098
+ :param package_name: The name of the package that contains the RDR classifier function, this
1099
+ is required in case of relative imports in the generated python file.
1017
1100
  """
1018
1101
  for rdr in self.start_rules_dict.values():
1019
- rdr.update_from_python(model_dir)
1102
+ rdr.update_from_python(model_dir, package_name=package_name)
1020
1103
 
1021
- def _write_to_python(self, model_dir: str) -> None:
1104
+ def _write_to_python(self, model_dir: str, package_name: Optional[str] = None) -> None:
1022
1105
  """
1023
1106
  Write the tree of rules as source code to a file.
1024
1107
 
1025
1108
  :param model_dir: The directory where the model is stored.
1109
+ :param relative_imports: Whether to use relative imports in the generated python file.
1026
1110
  """
1027
1111
  for rdr in self.start_rules_dict.values():
1028
- rdr._write_to_python(model_dir)
1029
- func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
1030
- with open(model_dir + f"/{self.generated_python_file_name}.py", "w") as f:
1031
- f.write(self._get_imports() + "\n\n")
1112
+ rdr._write_to_python(model_dir, package_name=package_name)
1113
+ func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
1114
+ file_path = model_dir + f"/{self.generated_python_file_name}.py"
1115
+ with open(file_path, "w") as f:
1116
+ f.write(self._get_imports(file_path=file_path, package_name=package_name) + "\n\n")
1032
1117
  f.write("classifiers_dict = dict()\n")
1033
1118
  for rdr_key, rdr in self.start_rules_dict.items():
1034
1119
  f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
@@ -1036,7 +1121,7 @@ class GeneralRDR(RippleDownRules):
1036
1121
  f.write(func_def)
1037
1122
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
1038
1123
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
1039
- f.write(f"{' ' * 4}return GeneralRDR._classify(classifiers_dict, case)\n")
1124
+ f.write(f"{' ' * 4}return general_rdr_classify(classifiers_dict, case, **kwargs)\n")
1040
1125
 
1041
1126
  @property
1042
1127
  def _default_generated_python_file_name(self) -> Optional[str]:
@@ -1051,25 +1136,29 @@ class GeneralRDR(RippleDownRules):
1051
1136
  def conclusion_type_hint(self) -> str:
1052
1137
  return f"Dict[str, Any]"
1053
1138
 
1054
- def _get_imports(self) -> str:
1139
+ def _get_imports(self, file_path: Optional[str] = None, package_name: Optional[str] = None) -> str:
1055
1140
  """
1056
1141
  Get the imports needed for the generated python file.
1057
1142
 
1143
+ :param file_path: The path to the file where the imports will be written, if None, the imports will be absolute.
1144
+ :param package_name: The name of the package that contains the RDR classifier function, this
1145
+ is required in case of relative imports in the generated python file.
1058
1146
  :return: The imports needed for the generated python file.
1059
1147
  """
1060
- imports = ""
1148
+ all_types = set()
1061
1149
  # add type hints
1062
- imports += f"from typing_extensions import Dict, Any\n"
1150
+ all_types.update({Dict, Any})
1063
1151
  # import rdr type
1064
- imports += f"from ripple_down_rules.rdr import GeneralRDR\n"
1152
+ all_types.add(general_rdr_classify)
1065
1153
  # add case type
1066
- imports += f"from ripple_down_rules.datastructures.case import Case, create_case\n"
1067
- imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
1154
+ all_types.update({Case, create_case, self.case_type})
1155
+ # get the imports from the types
1156
+ imports = get_imports_from_types(all_types, target_file_path=file_path, package_name=package_name)
1068
1157
  # add rdr python generated functions.
1069
1158
  for rdr_key, rdr in self.start_rules_dict.items():
1070
- imports += (f"from ."
1071
- f" import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}\n")
1072
- return imports
1159
+ imports.append(
1160
+ f"from . import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}")
1161
+ return '\n'.join(imports)
1073
1162
 
1074
1163
  @staticmethod
1075
1164
  def rdr_key_to_function_name(rdr_key: str) -> str: