ripple-down-rules 0.5.75__py3-none-any.whl → 0.5.81__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- __version__ = "0.5.75"
1
+ __version__ = "0.5.81"
2
2
 
3
3
  import logging
4
4
  logger = logging.Logger("rdr")
ripple_down_rules/rdr.py CHANGED
@@ -1,20 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
- import copyreg
4
3
  import importlib
5
4
  import os
5
+ from abc import ABC, abstractmethod
6
+ from copy import copy
6
7
 
7
8
  from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
8
-
9
9
  from . import logger
10
- import sys
11
- from abc import ABC, abstractmethod
12
- from copy import copy
13
- from io import TextIOWrapper
14
- from types import ModuleType
15
10
 
16
11
  try:
17
12
  from matplotlib import pyplot as plt
13
+
18
14
  Figure = plt.Figure
19
15
  except ImportError as e:
20
16
  logger.debug(f"{e}: matplotlib is not installed")
@@ -32,13 +28,13 @@ from .datastructures.enums import MCRDRMode
32
28
  from .experts import Expert, Human
33
29
  from .helpers import is_matching, general_rdr_classify
34
30
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
31
+
35
32
  try:
36
33
  from .user_interface.gui import RDRCaseViewer
37
34
  except ImportError as e:
38
35
  RDRCaseViewer = None
39
- from .utils import draw_tree, make_set, copy_case, \
40
- SubclassJSONSerializer, make_list, get_type_from_string, \
41
- is_conflicting, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
36
+ from .utils import draw_tree, make_set, SubclassJSONSerializer, make_list, get_type_from_string, \
37
+ is_conflicting, extract_function_source, extract_imports, get_full_class_name, \
42
38
  is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types
43
39
 
44
40
 
@@ -98,13 +94,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
98
94
  if self.viewer is not None:
99
95
  self.viewer.set_save_function(self.save)
100
96
 
101
- def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None) -> str:
97
+ def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None,
98
+ package_name: Optional[str] = None) -> str:
102
99
  """
103
100
  Save the classifier to a file.
104
101
 
105
102
  :param save_dir: The directory to save the classifier to.
106
103
  :param model_name: The name of the model to save. If None, a default name is generated.
107
- :param postfix: The postfix to add to the file name.
104
+ :param package_name: The name of the package that contains the RDR classifier function, this
105
+ is required in case of relative imports in the generated python file.
108
106
  :return: The name of the saved model.
109
107
  """
110
108
  save_dir = save_dir or self.save_dir
@@ -124,22 +122,25 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
124
122
  json_dir = os.path.join(model_dir, self.metadata_folder)
125
123
  os.makedirs(json_dir, exist_ok=True)
126
124
  self.to_json_file(os.path.join(json_dir, self.model_name))
127
- self._write_to_python(model_dir)
125
+ self._write_to_python(model_dir, package_name=package_name)
128
126
  return self.model_name
129
127
 
130
128
  @classmethod
131
- def load(cls, load_dir: str, model_name: str) -> Self:
129
+ def load(cls, load_dir: str, model_name: str,
130
+ package_name: Optional[str] = None) -> Self:
132
131
  """
133
132
  Load the classifier from a file.
134
133
 
135
134
  :param load_dir: The path to the model directory to load the classifier from.
136
135
  :param model_name: The name of the model to load.
136
+ :param package_name: The name of the package that contains the RDR classifier function, this
137
+ is required in case of relative imports in the generated python file.
137
138
  """
138
139
  model_dir = os.path.join(load_dir, model_name)
139
140
  json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
140
141
  rdr = cls.from_json_file(json_file)
141
142
  try:
142
- rdr.update_from_python(model_dir)
143
+ rdr.update_from_python(model_dir, package_name=package_name)
143
144
  except (FileNotFoundError, ValueError) as e:
144
145
  logger.warning(f"Could not load the python file for the model {model_name} from {model_dir}. "
145
146
  f"Make sure the file exists and is valid.")
@@ -148,11 +149,13 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
148
149
  return rdr
149
150
 
150
151
  @abstractmethod
151
- def _write_to_python(self, model_dir: str):
152
+ def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
152
153
  """
153
154
  Write the tree of rules as source code to a file.
154
155
 
155
156
  :param model_dir: The path to the directory to write the source code to.
157
+ :param package_name: The name of the package that contains the RDR classifier function, this
158
+ is required in case of relative imports in the generated python file.
156
159
  """
157
160
  pass
158
161
 
@@ -373,11 +376,13 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
373
376
  pass
374
377
 
375
378
  @abstractmethod
376
- def update_from_python(self, model_dir: str):
379
+ def update_from_python(self, model_dir: str, package_name: Optional[str] = None):
377
380
  """
378
381
  Update the rules from the generated python file, that might have been modified by the user.
379
382
 
380
383
  :param model_dir: The directory where the generated python file is located.
384
+ :param package_name: The name of the package that contains the RDR classifier function, this
385
+ is required in case of relative imports in the generated python file.
381
386
  """
382
387
  pass
383
388
 
@@ -408,30 +413,34 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
408
413
 
409
414
  class RDRWithCodeWriter(RippleDownRules, ABC):
410
415
 
411
- def update_from_python(self, model_dir: str):
416
+ def update_from_python(self, model_dir: str, package_name: Optional[str] = None):
412
417
  """
413
418
  Update the rules from the generated python file, that might have been modified by the user.
414
419
 
415
420
  :param model_dir: The directory where the generated python file is located.
421
+ :param package_name: The name of the package that contains the RDR classifier function, this
422
+ is required in case of relative imports in the generated python file.
416
423
  """
417
- rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants) if r.conditions is not None}
424
+ rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants)
425
+ if r.conditions is not None}
418
426
  condition_func_names = [f'conditions_{rid}' for rid in rules_dict.keys()]
419
- conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys() if not isinstance(rules_dict[rid], MultiClassStopRule)]
427
+ conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys()
428
+ if not isinstance(rules_dict[rid], MultiClassStopRule)]
420
429
  all_func_names = condition_func_names + conclusion_func_names
421
430
  filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
422
431
  cases_path = f"{model_dir}/{self.generated_python_cases_file_name}.py"
423
432
  cases_import_path = get_import_path_from_path(model_dir)
424
- cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path\
433
+ cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path \
425
434
  else self.generated_python_cases_file_name
426
435
  functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
427
436
  # get the scope from the imports in the file
428
- scope = extract_imports(filepath)
437
+ scope = extract_imports(filepath, package_name=package_name)
429
438
  for rule in [self.start_rule] + list(self.start_rule.descendants):
430
439
  if rule.conditions is not None:
431
440
  rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
432
441
  rule.conditions.scope = scope
433
442
  if os.path.exists(cases_path):
434
- module = importlib.import_module(cases_import_path)
443
+ module = importlib.import_module(cases_import_path, package=package_name)
435
444
  importlib.reload(module)
436
445
  rule.corner_case_metadata = module.__dict__.get(f"corner_case_{rule.uid}", None)
437
446
  if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
@@ -440,7 +449,8 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
440
449
 
441
450
  @abstractmethod
442
451
  def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
443
- defs_file: Optional[str] = None, cases_file: Optional[str] = None):
452
+ defs_file: Optional[str] = None, cases_file: Optional[str] = None,
453
+ package_name: Optional[str] = None):
444
454
  """
445
455
  Write the rules as source code to a file.
446
456
 
@@ -449,42 +459,62 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
449
459
  :param parent_indent: The indentation of the parent rule.
450
460
  :param defs_file: The file to write the definitions to.
451
461
  :param cases_file: The file to write the cases to.
462
+ :param package_name: The name of the package that contains the RDR classifier function, this
463
+ is required in case of relative imports in the generated python file.
452
464
  """
453
465
  pass
454
466
 
455
- def _write_to_python(self, model_dir: str):
467
+ def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
456
468
  """
457
469
  Write the tree of rules as source code to a file.
458
470
 
459
471
  :param model_dir: The path to the directory to write the source code to.
472
+ :param package_name: The name of the package that contains the RDR classifier function, this
473
+ is required in case of relative imports in the generated python file.
460
474
  """
475
+ # Make sure the model directory exists and create an __init__.py file if it doesn't exist
461
476
  os.makedirs(model_dir, exist_ok=True)
462
477
  if not os.path.exists(model_dir + '/__init__.py'):
463
478
  with open(model_dir + '/__init__.py', 'w') as f:
464
479
  f.write("from . import *\n")
465
- func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
480
+
481
+ # Set the file names for the generated python files
466
482
  file_name = model_dir + f"/{self.generated_python_file_name}.py"
467
483
  defs_file_name = model_dir + f"/{self.generated_python_defs_file_name}.py"
468
484
  cases_file_name = model_dir + f"/{self.generated_python_cases_file_name}.py"
469
- imports, defs_imports = self._get_imports()
470
- # clear the files first
485
+
486
+ # Get the required imports for the main file and the defs file
487
+ main_types, defs_types, corner_cases_types = self._get_types_to_import()
488
+ imports = get_imports_from_types(main_types, file_name, package_name)
489
+ defs_imports = get_imports_from_types(defs_types, defs_file_name, package_name)
490
+ corner_cases_imports = get_imports_from_types(corner_cases_types, cases_file_name, package_name)
491
+
492
+ # Add the imports to the defs file
471
493
  with open(defs_file_name, "w") as f:
472
- f.write(defs_imports + "\n\n")
473
- case_factory_import = get_imports_from_types([CaseFactoryMetaData])
494
+ f.write('\n'.join(defs_imports) + "\n\n\n")
495
+
496
+ # Add the imports to the cases file
497
+ case_factory_import = get_imports_from_types([CaseFactoryMetaData], cases_file_name, package_name)
498
+ corner_cases_imports.extend(case_factory_import)
474
499
  with open(cases_file_name, "w") as cases_f:
475
500
  cases_f.write("# This file contains the corner cases for the rules.\n")
476
- cases_f.write('\n'.join(case_factory_import) + "\n\n\n")
501
+ cases_f.write('\n'.join(corner_cases_imports) + "\n\n\n")
502
+
503
+ # Add the imports, the attributes, and the function definition to the main file
504
+ func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
477
505
  with open(file_name, "w") as f:
478
- imports += f"from .{self.generated_python_defs_file_name} import *\n"
479
- f.write(imports + "\n\n")
506
+ imports.append(f"from .{self.generated_python_defs_file_name} import *")
507
+ f.write('\n'.join(imports) + "\n\n\n")
480
508
  f.write(f"attribute_name = '{self.attribute_name}'\n")
481
509
  f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
482
510
  f.write(f"mutually_exclusive = {self.mutually_exclusive}\n")
483
511
  f.write(f"\n\n{func_def}")
484
512
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
485
513
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
514
+
515
+ # Write the rules as source code to the main file
486
516
  self.write_rules_as_source_code_to_file(self.start_rule, file_name, " " * 4, defs_file=defs_file_name,
487
- cases_file=cases_file_name)
517
+ cases_file=cases_file_name, package_name=package_name)
488
518
 
489
519
  @property
490
520
  @abstractmethod
@@ -494,31 +524,29 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
494
524
  """
495
525
  pass
496
526
 
497
- def _get_imports(self) -> Tuple[str, str]:
527
+ def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
498
528
  """
499
- :return: The imports for the generated python file of the RDR as a string.
529
+ :return: The types of the main, defs, and corner cases files of the RDR classifier that will be imported.
500
530
  """
501
- defs_imports_list = []
531
+ defs_types = set()
532
+ cases_types = set()
502
533
  for rule in [self.start_rule] + list(self.start_rule.descendants):
503
534
  if not rule.conditions:
504
535
  continue
505
536
  for scope in [rule.conditions.scope, rule.conclusion.scope]:
506
537
  if scope is None:
507
538
  continue
508
- defs_imports_list.extend(get_imports_from_scope(scope))
509
- if self.case_type.__module__ != "builtins":
510
- defs_imports_list.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
511
- defs_imports = "\n".join(set(defs_imports_list)) + "\n"
512
- imports = []
513
- if self.case_type.__module__ != "builtins":
514
- imports.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
515
- for conclusion_type in self.conclusion_type:
516
- if conclusion_type.__module__ != "builtins":
517
- imports.append(f"from {conclusion_type.__module__} import {conclusion_type.__name__}")
518
- imports.append("from ripple_down_rules.datastructures.case import Case, create_case")
519
- imports = set(imports).difference(defs_imports_list)
520
- imports = "\n".join(imports) + "\n"
521
- return imports, defs_imports
539
+ defs_types.update(make_set(scope.values()))
540
+ corner_case_types = rule.get_corner_case_types_to_import()
541
+ if corner_case_types is not None:
542
+ cases_types.update(corner_case_types)
543
+ defs_types.add(self.case_type)
544
+ main_types = set()
545
+ main_types.add(self.case_type)
546
+ main_types.update(make_set(self.conclusion_type))
547
+ main_types.update({Case, create_case})
548
+ main_types = main_types.difference(defs_types)
549
+ return main_types, defs_types, cases_types
522
550
 
523
551
  @property
524
552
  def _default_generated_python_file_name(self) -> Optional[str]:
@@ -537,7 +565,6 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
537
565
  def generated_python_cases_file_name(self) -> str:
538
566
  return f"{self.generated_python_file_name}_cases"
539
567
 
540
-
541
568
  @property
542
569
  def conclusion_type(self) -> Tuple[Type]:
543
570
  """
@@ -589,7 +616,6 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
589
616
 
590
617
 
591
618
  class SingleClassRDR(RDRWithCodeWriter):
592
-
593
619
  mutually_exclusive: bool = True
594
620
  """
595
621
  The output of the classification of this rdr negates all other possible outputs, there can only be one true value.
@@ -646,7 +672,7 @@ class SingleClassRDR(RDRWithCodeWriter):
646
672
  pred = self.evaluate(case)
647
673
  conclusion = pred.conclusion(case) if pred is not None else None
648
674
  if pred is not None and pred.fired and case_query is not None:
649
- if pred.corner_case_metadata is None and conclusion is not None\
675
+ if pred.corner_case_metadata is None and conclusion is not None \
650
676
  and type(conclusion) in case_query.core_attribute_type:
651
677
  pred.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
652
678
  return conclusion if pred is not None and pred.fired else self.default_conclusion
@@ -658,25 +684,27 @@ class SingleClassRDR(RDRWithCodeWriter):
658
684
  matched_rule = self.start_rule(case) if self.start_rule is not None else None
659
685
  return matched_rule if matched_rule is not None else self.start_rule
660
686
 
661
- def _write_to_python(self, model_dir: str):
662
- super()._write_to_python(model_dir)
687
+ def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
688
+ super()._write_to_python(model_dir, package_name=package_name)
663
689
  if self.default_conclusion is not None:
664
690
  with open(model_dir + f"/{self.generated_python_file_name}.py", "a") as f:
665
691
  f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
666
692
 
667
693
  def write_rules_as_source_code_to_file(self, rule: SingleClassRule, filename: str, parent_indent: str = "",
668
- defs_file: Optional[str] = None, cases_file: Optional[str] = None):
694
+ defs_file: Optional[str] = None, cases_file: Optional[str] = None,
695
+ package_name: Optional[str] = None):
669
696
  """
670
697
  Write the rules as source code to a file.
671
698
  """
672
699
  if rule.conditions:
673
- rule.write_corner_case_as_source_code(cases_file)
700
+ rule.write_corner_case_as_source_code(cases_file, package_name=package_name)
674
701
  if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
675
702
  with open(filename, "a") as file:
676
703
  file.write(if_clause)
677
704
  if rule.refinement:
678
705
  self.write_rules_as_source_code_to_file(rule.refinement, filename, parent_indent + " ",
679
- defs_file=defs_file, cases_file=cases_file)
706
+ defs_file=defs_file, cases_file=cases_file,
707
+ package_name=package_name)
680
708
 
681
709
  conclusion_call = rule.write_conclusion_as_source_code(parent_indent, defs_file)
682
710
  with open(filename, "a") as file:
@@ -684,7 +712,7 @@ class SingleClassRDR(RDRWithCodeWriter):
684
712
 
685
713
  if rule.alternative:
686
714
  self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
687
- cases_file=cases_file)
715
+ cases_file=cases_file, package_name=package_name)
688
716
 
689
717
  @property
690
718
  def conclusion_type_hint(self) -> str:
@@ -745,8 +773,9 @@ class MultiClassRDR(RDRWithCodeWriter):
745
773
  if evaluated_rule.fired:
746
774
  rule_conclusion = evaluated_rule.conclusion(case)
747
775
  if evaluated_rule.corner_case_metadata is None and case_query is not None:
748
- if rule_conclusion is not None and len(make_list(rule_conclusion)) > 0\
749
- and any(ct in case_query.core_attribute_type for ct in map(type, make_list(rule_conclusion))):
776
+ if rule_conclusion is not None and len(make_list(rule_conclusion)) > 0 \
777
+ and any(
778
+ ct in case_query.core_attribute_type for ct in map(type, make_list(rule_conclusion))):
750
779
  evaluated_rule.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
751
780
  self.add_conclusion(rule_conclusion)
752
781
  evaluated_rule = next_rule
@@ -789,19 +818,20 @@ class MultiClassRDR(RDRWithCodeWriter):
789
818
 
790
819
  def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
791
820
  filename: str, parent_indent: str = "", defs_file: Optional[str] = None,
792
- cases_file: Optional[str] = None):
821
+ cases_file: Optional[str] = None, package_name: Optional[str] = None):
793
822
  if rule == self.start_rule:
794
823
  with open(filename, "a") as file:
795
824
  file.write(f"{parent_indent}conclusions = set()\n")
796
825
  if rule.conditions:
797
- rule.write_corner_case_as_source_code(cases_file)
826
+ rule.write_corner_case_as_source_code(cases_file, package_name=package_name)
798
827
  if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
799
828
  with open(filename, "a") as file:
800
829
  file.write(if_clause)
801
830
  conclusion_indent = parent_indent
802
831
  if hasattr(rule, "refinement") and rule.refinement:
803
832
  self.write_rules_as_source_code_to_file(rule.refinement, filename, parent_indent + " ",
804
- defs_file=defs_file, cases_file=cases_file)
833
+ defs_file=defs_file, cases_file=cases_file,
834
+ package_name=package_name)
805
835
  conclusion_indent = parent_indent + " " * 4
806
836
  with open(filename, "a") as file:
807
837
  file.write(f"{conclusion_indent}else:\n")
@@ -812,7 +842,7 @@ class MultiClassRDR(RDRWithCodeWriter):
812
842
 
813
843
  if rule.alternative:
814
844
  self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
815
- cases_file=cases_file)
845
+ cases_file=cases_file, package_name=package_name)
816
846
 
817
847
  @property
818
848
  def conclusion_type_hint(self) -> str:
@@ -822,12 +852,11 @@ class MultiClassRDR(RDRWithCodeWriter):
822
852
  else:
823
853
  return f"Set[Union[{', '.join(conclusion_types)}]]"
824
854
 
825
- def _get_imports(self) -> Tuple[str, str]:
826
- imports, defs_imports = super()._get_imports()
827
- imports += f"from typing_extensions import Set, Union\n"
828
- imports += "from ripple_down_rules.utils import make_set\n"
829
- defs_imports += "from typing_extensions import Union\n"
830
- return imports, defs_imports
855
+ def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
856
+ main_types, defs_types, cases_types = super()._get_types_to_import()
857
+ main_types.update({Set, Union, make_set})
858
+ defs_types.add(Union)
859
+ return main_types, defs_types, cases_types
831
860
 
832
861
  def update_start_rule(self, case_query: CaseQuery, expert: Expert):
833
862
  """
@@ -1039,7 +1068,7 @@ class GeneralRDR(RippleDownRules):
1039
1068
 
1040
1069
  def _to_json(self) -> Dict[str, Any]:
1041
1070
  return {"start_rules": {name: rdr.to_json() for name, rdr in self.start_rules_dict.items()}
1042
- , "generated_python_file_name": self.generated_python_file_name,
1071
+ , "generated_python_file_name": self.generated_python_file_name,
1043
1072
  "name": self.name,
1044
1073
  "case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
1045
1074
  "case_name": self.case_name}
@@ -1063,26 +1092,30 @@ class GeneralRDR(RippleDownRules):
1063
1092
  new_rdr.case_name = data["case_name"]
1064
1093
  return new_rdr
1065
1094
 
1066
- def update_from_python(self, model_dir: str) -> None:
1095
+ def update_from_python(self, model_dir: str, package_name: Optional[str] = None) -> None:
1067
1096
  """
1068
1097
  Update the rules from the generated python file, that might have been modified by the user.
1069
1098
 
1070
1099
  :param model_dir: The directory where the model is stored.
1100
+ :param package_name: The name of the package that contains the RDR classifier function, this
1101
+ is required in case of relative imports in the generated python file.
1071
1102
  """
1072
1103
  for rdr in self.start_rules_dict.values():
1073
- rdr.update_from_python(model_dir)
1104
+ rdr.update_from_python(model_dir, package_name=package_name)
1074
1105
 
1075
- def _write_to_python(self, model_dir: str) -> None:
1106
+ def _write_to_python(self, model_dir: str, package_name: Optional[str] = None) -> None:
1076
1107
  """
1077
1108
  Write the tree of rules as source code to a file.
1078
1109
 
1079
1110
  :param model_dir: The directory where the model is stored.
1111
+ :param relative_imports: Whether to use relative imports in the generated python file.
1080
1112
  """
1081
1113
  for rdr in self.start_rules_dict.values():
1082
- rdr._write_to_python(model_dir)
1114
+ rdr._write_to_python(model_dir, package_name=package_name)
1083
1115
  func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
1084
- with open(model_dir + f"/{self.generated_python_file_name}.py", "w") as f:
1085
- f.write(self._get_imports() + "\n\n")
1116
+ file_path = model_dir + f"/{self.generated_python_file_name}.py"
1117
+ with open(file_path, "w") as f:
1118
+ f.write(self._get_imports(file_path=file_path, package_name=package_name) + "\n\n")
1086
1119
  f.write("classifiers_dict = dict()\n")
1087
1120
  for rdr_key, rdr in self.start_rules_dict.items():
1088
1121
  f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
@@ -1105,25 +1138,25 @@ class GeneralRDR(RippleDownRules):
1105
1138
  def conclusion_type_hint(self) -> str:
1106
1139
  return f"Dict[str, Any]"
1107
1140
 
1108
- def _get_imports(self) -> str:
1141
+ def _get_imports(self, file_path: Optional[str] = None, package_name: Optional[str] = None) -> str:
1109
1142
  """
1110
1143
  Get the imports needed for the generated python file.
1111
1144
 
1145
+ :param file_path: The path to the file where the imports will be written, if None, the imports will be absolute.
1146
+ :param package_name: The name of the package that contains the RDR classifier function, this
1147
+ is required in case of relative imports in the generated python file.
1112
1148
  :return: The imports needed for the generated python file.
1113
1149
  """
1114
- imports = ""
1115
- # add type hints
1116
- imports += f"from typing_extensions import Dict, Any\n"
1117
- # import rdr type
1118
- imports += f"from ripple_down_rules.helpers import general_rdr_classify\n"
1119
- # add case type
1120
- imports += f"from ripple_down_rules.datastructures.case import Case, create_case\n"
1121
- imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
1150
+ # get the imports from the types
1151
+ imports = get_imports_from_types([self.case_type], target_file_path=file_path, package_name=package_name)
1122
1152
  # add rdr python generated functions.
1153
+ imports.append("from typing import Dict, Any")
1154
+ imports.append("from ripple_down_rules.datastructures.case import Case, create_case")
1155
+ imports.append("from ripple_down_rules.helpers import general_rdr_classify")
1123
1156
  for rdr_key, rdr in self.start_rules_dict.items():
1124
- imports += (f"from ."
1125
- f" import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}\n")
1126
- return imports
1157
+ imports.append(
1158
+ f"from . import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}")
1159
+ return '\n'.join(imports)
1127
1160
 
1128
1161
  @staticmethod
1129
1162
  def rdr_key_to_function_name(rdr_key: str) -> str:
@@ -8,7 +8,7 @@ from uuid import uuid4
8
8
 
9
9
  from anytree import NodeMixin
10
10
  from sqlalchemy.orm import DeclarativeBase as SQLTable
11
- from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
11
+ from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple, Type, Set
12
12
 
13
13
  from .datastructures.callable_expression import CallableExpression
14
14
  from .datastructures.case import Case
@@ -102,11 +102,21 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
102
102
  """
103
103
  pass
104
104
 
105
- def write_corner_case_as_source_code(self, cases_file: Path) -> None:
105
+ def write_corner_case_as_source_code(self, cases_file: str, package_name: Optional[str] = None) -> None:
106
106
  """
107
107
  Write the source code representation of the corner case of the rule to a file.
108
108
 
109
- :param cases_file: The file to write the corner case to if it is a definition.
109
+ :param cases_file: The file to write the corner case to.
110
+ :param package_name: The package name to use for relative imports.
111
+ """
112
+ if self.corner_case_metadata is None:
113
+ return
114
+ with open(cases_file, 'a') as f:
115
+ f.write(f"corner_case_{self.uid} = {self.corner_case_metadata}" + "\n\n\n")
116
+
117
+ def get_corner_case_types_to_import(self) -> Set[Type]:
118
+ """
119
+ Get the types that need to be imported for the corner case of the rule.
110
120
  """
111
121
  if self.corner_case_metadata is None:
112
122
  return
@@ -117,10 +127,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
117
127
  types_to_import.add(self.corner_case_metadata.scenario)
118
128
  if self.corner_case_metadata.case_conf is not None:
119
129
  types_to_import.add(self.corner_case_metadata.case_conf)
120
- imports = get_imports_from_types(list(types_to_import))
121
- with open(cases_file, 'a') as f:
122
- f.write("\n".join(imports) + "\n\n\n")
123
- f.write(f"corner_case_{self.uid} = {self.corner_case_metadata}" + "\n\n\n")
130
+ return types_to_import
124
131
 
125
132
  def write_conclusion_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
126
133
  """
@@ -198,7 +198,7 @@ class TemplateFileCreator:
198
198
  with open(self.temp_file_path, 'w+') as f:
199
199
  f.write(code)
200
200
 
201
- def get_imports(self):
201
+ def get_imports(self) -> str:
202
202
  """
203
203
  :return: A string containing the imports for the function.
204
204
  """
@@ -8,6 +8,7 @@ import json
8
8
  import logging
9
9
  import os
10
10
  import re
11
+ import sys
11
12
  import threading
12
13
  import uuid
13
14
  from collections import UserDict, defaultdict
@@ -15,6 +16,7 @@ from copy import deepcopy, copy
15
16
  from dataclasses import is_dataclass, fields
16
17
  from enum import Enum
17
18
  from os.path import dirname
19
+ from pathlib import Path
18
20
  from textwrap import dedent
19
21
  from types import NoneType
20
22
 
@@ -44,7 +46,7 @@ from sqlalchemy import MetaData, inspect
44
46
  from sqlalchemy.orm import Mapped, registry, class_mapper, DeclarativeBase as SQLTable, Session
45
47
  from tabulate import tabulate
46
48
  from typing_extensions import Callable, Set, Any, Type, Dict, TYPE_CHECKING, get_type_hints, \
47
- get_origin, get_args, Tuple, Optional, List, Union, Self, ForwardRef
49
+ get_origin, get_args, Tuple, Optional, List, Union, Self, ForwardRef, Sequence, Iterable
48
50
 
49
51
  if TYPE_CHECKING:
50
52
  from .datastructures.case import Case
@@ -122,7 +124,15 @@ def get_imports_from_scope(scope: Dict[str, Any]) -> List[str]:
122
124
  return imports
123
125
 
124
126
 
125
- def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = None) -> Dict[str, Any]:
127
+ def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = None,
128
+ package_name: Optional[str] = None) -> Dict[str, Any]:
129
+ """
130
+ Extract imports from a Python file or an AST tree.
131
+
132
+ :param file_path: The path to the Python file to extract imports from.
133
+ :param tree: An AST tree to extract imports from. If provided, file_path is ignored.
134
+ :param package_name: The name of the package to use for relative imports.
135
+ """
126
136
  if tree is None:
127
137
  if file_path is None:
128
138
  raise ValueError("Either file_path or tree must be provided")
@@ -137,7 +147,7 @@ def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = N
137
147
  module_name = alias.name
138
148
  asname = alias.asname or alias.name
139
149
  try:
140
- scope[asname] = importlib.import_module(module_name)
150
+ scope[asname] = importlib.import_module(module_name, package=package_name)
141
151
  except ImportError as e:
142
152
  print(f"Could not import {module_name}: {e}")
143
153
  elif isinstance(node, ast.ImportFrom):
@@ -146,7 +156,12 @@ def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = N
146
156
  name = alias.name
147
157
  asname = alias.asname or name
148
158
  try:
149
- module = importlib.import_module(module_name)
159
+ if package_name is not None and node.level > 0: # Handle relative imports
160
+ module_rel_path = Path(os.path.join(file_path, *['..'] * node.level, module_name)).resolve()
161
+ idx = str(module_rel_path).rfind(package_name)
162
+ if idx != -1:
163
+ module_name = str(module_rel_path)[idx:].replace(os.path.sep, '.')
164
+ module = importlib.import_module(module_name, package=package_name)
150
165
  scope[asname] = getattr(module, name)
151
166
  except (ImportError, AttributeError) as e:
152
167
  logging.warning(f"Could not import {module_name}: {e} while extracting imports from {file_path}")
@@ -833,37 +848,82 @@ def get_function_representation(func: Callable) -> str:
833
848
  return func_name
834
849
 
835
850
 
836
- def get_imports_from_types(type_objs: List[Type]) -> List[str]:
851
+ def get_relative_import(target_file_path, imported_module_path: Optional[str] = None,
852
+ module: Optional[str] = None) -> str:
853
+ """
854
+ Get a relative import path from the target file to the imported module.
855
+
856
+ :param target_file_path: The file path of the target file.
857
+ :param imported_module_path: The file path of the module being imported.
858
+ :param module: The module name, if available.
859
+ :return: A relative import path as a string.
860
+ """
861
+ # Convert to absolute paths
862
+ if module is not None:
863
+ imported_module_path = sys.modules[module].__file__
864
+ if imported_module_path is None:
865
+ raise ValueError("Either imported_module_path or module must be provided")
866
+ target_path = Path(target_file_path).resolve()
867
+ imported_path = Path(imported_module_path).resolve()
868
+
869
+ # Compute relative path from target to imported module
870
+ rel_path = os.path.relpath(imported_path.parent, target_path.parent)
871
+
872
+ # Convert path to Python import format
873
+ rel_parts = [part.replace('..', '.') for part in Path(rel_path).parts]
874
+ rel_parts = rel_parts if rel_parts else ['']
875
+
876
+ # Join the parts and add the module name
877
+ joined_parts = "".join(rel_parts) + f".{imported_path.stem}"
878
+ joined_parts = f".{joined_parts}" if not joined_parts.startswith(".") else joined_parts
879
+
880
+ return joined_parts
881
+
882
+
883
+ def get_imports_from_types(type_objs: Iterable[Type],
884
+ target_file_path: Optional[str] = None,
885
+ package_name: Optional[str] = None) -> List[str]:
837
886
  """
838
887
  Format import lines from type objects.
839
888
 
840
889
  :param type_objs: A list of type objects to format.
890
+ :param target_file_path: The file path to which the imports should be relative.
891
+ :param package_name: The name of the package to use for relative imports.
841
892
  """
842
893
 
843
894
  module_to_types = defaultdict(list)
895
+ module_to_path = {}
844
896
  other_imports = []
845
897
  for tp in type_objs:
846
898
  try:
847
899
  if isinstance(tp, type) or is_typing_type(tp):
848
900
  module = tp.__module__
901
+ file = getattr(tp, '__file__', None)
849
902
  name = tp.__qualname__
850
903
  elif callable(tp):
851
904
  module, name = get_function_import_data(tp)
905
+ file = get_method_file_name(tp)
852
906
  elif hasattr(type(tp), "__module__"):
853
907
  module = type(tp).__module__
908
+ file = getattr(tp, '__file__', None)
854
909
  name = type(tp).__qualname__
855
910
  else:
856
911
  continue
857
912
  if module is None or module == 'builtins' or module.startswith('_'):
858
913
  continue
859
914
  module_to_types[module].append(name)
915
+ if file:
916
+ module_to_path[module] = file
860
917
  except AttributeError:
861
918
  continue
862
919
 
863
920
  lines = []
864
921
  for module, names in module_to_types.items():
865
922
  joined = ", ".join(sorted(set(names)))
866
- lines.append(f"from {module} import {joined}")
923
+ import_path = module
924
+ if (target_file_path is not None) and (package_name is not None) and (package_name in module):
925
+ import_path = get_relative_import(target_file_path, module=module)
926
+ lines.append(f"from {import_path} import {joined}")
867
927
  if other_imports:
868
928
  lines.extend(other_imports)
869
929
  return sorted(lines)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.5.75
3
+ Version: 0.5.81
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -1,11 +1,11 @@
1
- ripple_down_rules/__init__.py,sha256=6Ze00N3Py1dmFEMGBz3jz63qsUVLz8WzFn7qx3lJnfM,100
1
+ ripple_down_rules/__init__.py,sha256=Hf2qy117DtlfEhbMoyreqc3k02Y_tLUwlxP1IJySXWA,100
2
2
  ripple_down_rules/experts.py,sha256=bwozulI1rv0uyaMZQqEgapDO-s8wvW0D6Jqxmvu5fik,12610
3
3
  ripple_down_rules/helpers.py,sha256=v4oE7C5PfQUVJfSUs1FfLHEwrJXEHJLn4vJhJMvyCR8,4453
4
- ripple_down_rules/rdr.py,sha256=Mqh7lDjQu6wZUcJiJ57CZ3P0-hM4WfhFuV4s1jZnRv8,51833
4
+ ripple_down_rules/rdr.py,sha256=2gh_kRmsR58pLqY4OvhjOgHSIYIcST7tI7jjVQ8yhD4,55214
5
5
  ripple_down_rules/rdr_decorators.py,sha256=0sk7izDB53lTKSB9fm33vQahmY_05FyCOWljyQOMB0U,9072
6
- ripple_down_rules/rules.py,sha256=ctf9yREG5l99HPFcYosjppKXTOwplZmzQbm4R1DMVaA,20107
6
+ ripple_down_rules/rules.py,sha256=iVevv6iZ-6L2IPI0ZYbBjxBymXEQMmJGRFhiKUS-NmA,20352
7
7
  ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
8
- ripple_down_rules/utils.py,sha256=iwfpTlsxUqLHWpYqSKwrDnEEa_FYFHYb2LugEVDH_kk,57132
8
+ ripple_down_rules/utils.py,sha256=mObBszTruGrRvD4MgD8tS1AnMtoyKrPl4RCciinhzY4,60132
9
9
  ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
10
10
  ripple_down_rules/datastructures/callable_expression.py,sha256=f3wUPTrLa1INO-1qfgVz87ryrCABronfyq0_JKWoZCs,12800
11
11
  ripple_down_rules/datastructures/case.py,sha256=1zSaXUljaH6z3SgMGzYPoDyjotNam791KpYgvxuMh90,15463
@@ -16,9 +16,9 @@ ripple_down_rules/user_interface/gui.py,sha256=_lgZAUXxxaBUFQJAHjA5TBPp6XEvJ62t-
16
16
  ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=Jrf7NxOdlrwGXH0Xyz3vzQprY-PNx9etfePOTpm2Gu8,6479
17
17
  ripple_down_rules/user_interface/object_diagram.py,sha256=FEa2HaYR9QmTE6NsOwBvZ0jqmu3DKyg6mig2VE5ZP4Y,4956
18
18
  ripple_down_rules/user_interface/prompt.py,sha256=AkkltdDIaioN43lkRKDPKSjJcmdSSGZDMYz7AL7X9lE,8082
19
- ripple_down_rules/user_interface/template_file_creator.py,sha256=VLS9Nxg6gPNa-YYliJ_VNsTvLPlZ003EVkJ2t8zuDgE,13563
20
- ripple_down_rules-0.5.75.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
21
- ripple_down_rules-0.5.75.dist-info/METADATA,sha256=M1N-k7Zp8qzOsGlF8K4n889agx_bLu_rWO9_c-cEViQ,48214
22
- ripple_down_rules-0.5.75.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- ripple_down_rules-0.5.75.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
24
- ripple_down_rules-0.5.75.dist-info/RECORD,,
19
+ ripple_down_rules/user_interface/template_file_creator.py,sha256=xWUt-RHRqrvoMo74o-bMLo8xNxil68wgbOZAMADZp2A,13570
20
+ ripple_down_rules-0.5.81.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
21
+ ripple_down_rules-0.5.81.dist-info/METADATA,sha256=P3JwR4h_dcxZ9w-S7DSflG8BS9NZvW-IrCNF1-H3-dA,48214
22
+ ripple_down_rules-0.5.81.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ ripple_down_rules-0.5.81.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
24
+ ripple_down_rules-0.5.81.dist-info/RECORD,,