ripple-down-rules 0.5.75__py3-none-any.whl → 0.5.80__py3-none-any.whl

This diff shows the contents of two publicly available package versions that were released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as published in their respective public registries.
@@ -1,4 +1,4 @@
1
- __version__ = "0.5.75"
1
+ __version__ = "0.5.80"
2
2
 
3
3
  import logging
4
4
  logger = logging.Logger("rdr")
ripple_down_rules/rdr.py CHANGED
@@ -1,20 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
- import copyreg
4
3
  import importlib
5
4
  import os
5
+ from abc import ABC, abstractmethod
6
+ from copy import copy
6
7
 
7
8
  from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
8
-
9
9
  from . import logger
10
- import sys
11
- from abc import ABC, abstractmethod
12
- from copy import copy
13
- from io import TextIOWrapper
14
- from types import ModuleType
15
10
 
16
11
  try:
17
12
  from matplotlib import pyplot as plt
13
+
18
14
  Figure = plt.Figure
19
15
  except ImportError as e:
20
16
  logger.debug(f"{e}: matplotlib is not installed")
@@ -32,13 +28,13 @@ from .datastructures.enums import MCRDRMode
32
28
  from .experts import Expert, Human
33
29
  from .helpers import is_matching, general_rdr_classify
34
30
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
31
+
35
32
  try:
36
33
  from .user_interface.gui import RDRCaseViewer
37
34
  except ImportError as e:
38
35
  RDRCaseViewer = None
39
- from .utils import draw_tree, make_set, copy_case, \
40
- SubclassJSONSerializer, make_list, get_type_from_string, \
41
- is_conflicting, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
36
+ from .utils import draw_tree, make_set, SubclassJSONSerializer, make_list, get_type_from_string, \
37
+ is_conflicting, extract_function_source, extract_imports, get_full_class_name, \
42
38
  is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types
43
39
 
44
40
 
@@ -98,13 +94,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
98
94
  if self.viewer is not None:
99
95
  self.viewer.set_save_function(self.save)
100
96
 
101
- def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None) -> str:
97
+ def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None,
98
+ package_name: Optional[str] = None) -> str:
102
99
  """
103
100
  Save the classifier to a file.
104
101
 
105
102
  :param save_dir: The directory to save the classifier to.
106
103
  :param model_name: The name of the model to save. If None, a default name is generated.
107
- :param postfix: The postfix to add to the file name.
104
+ :param package_name: The name of the package that contains the RDR classifier function, this
105
+ is required in case of relative imports in the generated python file.
108
106
  :return: The name of the saved model.
109
107
  """
110
108
  save_dir = save_dir or self.save_dir
@@ -124,22 +122,25 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
124
122
  json_dir = os.path.join(model_dir, self.metadata_folder)
125
123
  os.makedirs(json_dir, exist_ok=True)
126
124
  self.to_json_file(os.path.join(json_dir, self.model_name))
127
- self._write_to_python(model_dir)
125
+ self._write_to_python(model_dir, package_name=package_name)
128
126
  return self.model_name
129
127
 
130
128
  @classmethod
131
- def load(cls, load_dir: str, model_name: str) -> Self:
129
+ def load(cls, load_dir: str, model_name: str,
130
+ package_name: Optional[str] = None) -> Self:
132
131
  """
133
132
  Load the classifier from a file.
134
133
 
135
134
  :param load_dir: The path to the model directory to load the classifier from.
136
135
  :param model_name: The name of the model to load.
136
+ :param package_name: The name of the package that contains the RDR classifier function, this
137
+ is required in case of relative imports in the generated python file.
137
138
  """
138
139
  model_dir = os.path.join(load_dir, model_name)
139
140
  json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
140
141
  rdr = cls.from_json_file(json_file)
141
142
  try:
142
- rdr.update_from_python(model_dir)
143
+ rdr.update_from_python(model_dir, package_name=package_name)
143
144
  except (FileNotFoundError, ValueError) as e:
144
145
  logger.warning(f"Could not load the python file for the model {model_name} from {model_dir}. "
145
146
  f"Make sure the file exists and is valid.")
@@ -148,11 +149,13 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
148
149
  return rdr
149
150
 
150
151
  @abstractmethod
151
- def _write_to_python(self, model_dir: str):
152
+ def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
152
153
  """
153
154
  Write the tree of rules as source code to a file.
154
155
 
155
156
  :param model_dir: The path to the directory to write the source code to.
157
+ :param package_name: The name of the package that contains the RDR classifier function, this
158
+ is required in case of relative imports in the generated python file.
156
159
  """
157
160
  pass
158
161
 
@@ -373,11 +376,13 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
373
376
  pass
374
377
 
375
378
  @abstractmethod
376
- def update_from_python(self, model_dir: str):
379
+ def update_from_python(self, model_dir: str, package_name: Optional[str] = None):
377
380
  """
378
381
  Update the rules from the generated python file, that might have been modified by the user.
379
382
 
380
383
  :param model_dir: The directory where the generated python file is located.
384
+ :param package_name: The name of the package that contains the RDR classifier function, this
385
+ is required in case of relative imports in the generated python file.
381
386
  """
382
387
  pass
383
388
 
@@ -408,30 +413,34 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
408
413
 
409
414
  class RDRWithCodeWriter(RippleDownRules, ABC):
410
415
 
411
- def update_from_python(self, model_dir: str):
416
+ def update_from_python(self, model_dir: str, package_name: Optional[str] = None):
412
417
  """
413
418
  Update the rules from the generated python file, that might have been modified by the user.
414
419
 
415
420
  :param model_dir: The directory where the generated python file is located.
421
+ :param package_name: The name of the package that contains the RDR classifier function, this
422
+ is required in case of relative imports in the generated python file.
416
423
  """
417
- rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants) if r.conditions is not None}
424
+ rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants)
425
+ if r.conditions is not None}
418
426
  condition_func_names = [f'conditions_{rid}' for rid in rules_dict.keys()]
419
- conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys() if not isinstance(rules_dict[rid], MultiClassStopRule)]
427
+ conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys()
428
+ if not isinstance(rules_dict[rid], MultiClassStopRule)]
420
429
  all_func_names = condition_func_names + conclusion_func_names
421
430
  filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
422
431
  cases_path = f"{model_dir}/{self.generated_python_cases_file_name}.py"
423
432
  cases_import_path = get_import_path_from_path(model_dir)
424
- cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path\
433
+ cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path \
425
434
  else self.generated_python_cases_file_name
426
435
  functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
427
436
  # get the scope from the imports in the file
428
- scope = extract_imports(filepath)
437
+ scope = extract_imports(filepath, package_name=package_name)
429
438
  for rule in [self.start_rule] + list(self.start_rule.descendants):
430
439
  if rule.conditions is not None:
431
440
  rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
432
441
  rule.conditions.scope = scope
433
442
  if os.path.exists(cases_path):
434
- module = importlib.import_module(cases_import_path)
443
+ module = importlib.import_module(cases_import_path, package=package_name)
435
444
  importlib.reload(module)
436
445
  rule.corner_case_metadata = module.__dict__.get(f"corner_case_{rule.uid}", None)
437
446
  if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
@@ -440,7 +449,8 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
440
449
 
441
450
  @abstractmethod
442
451
  def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
443
- defs_file: Optional[str] = None, cases_file: Optional[str] = None):
452
+ defs_file: Optional[str] = None, cases_file: Optional[str] = None,
453
+ package_name: Optional[str] = None):
444
454
  """
445
455
  Write the rules as source code to a file.
446
456
 
@@ -449,42 +459,62 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
449
459
  :param parent_indent: The indentation of the parent rule.
450
460
  :param defs_file: The file to write the definitions to.
451
461
  :param cases_file: The file to write the cases to.
462
+ :param package_name: The name of the package that contains the RDR classifier function, this
463
+ is required in case of relative imports in the generated python file.
452
464
  """
453
465
  pass
454
466
 
455
- def _write_to_python(self, model_dir: str):
467
+ def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
456
468
  """
457
469
  Write the tree of rules as source code to a file.
458
470
 
459
471
  :param model_dir: The path to the directory to write the source code to.
472
+ :param package_name: The name of the package that contains the RDR classifier function, this
473
+ is required in case of relative imports in the generated python file.
460
474
  """
475
+ # Make sure the model directory exists and create an __init__.py file if it doesn't exist
461
476
  os.makedirs(model_dir, exist_ok=True)
462
477
  if not os.path.exists(model_dir + '/__init__.py'):
463
478
  with open(model_dir + '/__init__.py', 'w') as f:
464
479
  f.write("from . import *\n")
465
- func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
480
+
481
+ # Set the file names for the generated python files
466
482
  file_name = model_dir + f"/{self.generated_python_file_name}.py"
467
483
  defs_file_name = model_dir + f"/{self.generated_python_defs_file_name}.py"
468
484
  cases_file_name = model_dir + f"/{self.generated_python_cases_file_name}.py"
469
- imports, defs_imports = self._get_imports()
470
- # clear the files first
485
+
486
+ # Get the required imports for the main file and the defs file
487
+ main_types, defs_types, corner_cases_types = self._get_types_to_import()
488
+ imports = get_imports_from_types(main_types, file_name, package_name)
489
+ defs_imports = get_imports_from_types(defs_types, defs_file_name, package_name)
490
+ corner_cases_imports = get_imports_from_types(corner_cases_types, cases_file_name, package_name)
491
+
492
+ # Add the imports to the defs file
471
493
  with open(defs_file_name, "w") as f:
472
- f.write(defs_imports + "\n\n")
473
- case_factory_import = get_imports_from_types([CaseFactoryMetaData])
494
+ f.write('\n'.join(defs_imports) + "\n\n\n")
495
+
496
+ # Add the imports to the cases file
497
+ case_factory_import = get_imports_from_types([CaseFactoryMetaData], cases_file_name, package_name)
498
+ corner_cases_imports.extend(case_factory_import)
474
499
  with open(cases_file_name, "w") as cases_f:
475
500
  cases_f.write("# This file contains the corner cases for the rules.\n")
476
- cases_f.write('\n'.join(case_factory_import) + "\n\n\n")
501
+ cases_f.write('\n'.join(corner_cases_imports) + "\n\n\n")
502
+
503
+ # Add the imports, the attributes, and the function definition to the main file
504
+ func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
477
505
  with open(file_name, "w") as f:
478
- imports += f"from .{self.generated_python_defs_file_name} import *\n"
479
- f.write(imports + "\n\n")
506
+ imports.append(f"from .{self.generated_python_defs_file_name} import *")
507
+ f.write('\n'.join(imports) + "\n\n\n")
480
508
  f.write(f"attribute_name = '{self.attribute_name}'\n")
481
509
  f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
482
510
  f.write(f"mutually_exclusive = {self.mutually_exclusive}\n")
483
511
  f.write(f"\n\n{func_def}")
484
512
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
485
513
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
514
+
515
+ # Write the rules as source code to the main file
486
516
  self.write_rules_as_source_code_to_file(self.start_rule, file_name, " " * 4, defs_file=defs_file_name,
487
- cases_file=cases_file_name)
517
+ cases_file=cases_file_name, package_name=package_name)
488
518
 
489
519
  @property
490
520
  @abstractmethod
@@ -494,31 +524,27 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
494
524
  """
495
525
  pass
496
526
 
497
- def _get_imports(self) -> Tuple[str, str]:
527
+ def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
498
528
  """
499
- :return: The imports for the generated python file of the RDR as a string.
529
+ :return: The types of the main, defs, and corner cases files of the RDR classifier that will be imported.
500
530
  """
501
- defs_imports_list = []
531
+ defs_types = set()
532
+ cases_types = set()
502
533
  for rule in [self.start_rule] + list(self.start_rule.descendants):
503
534
  if not rule.conditions:
504
535
  continue
505
536
  for scope in [rule.conditions.scope, rule.conclusion.scope]:
506
537
  if scope is None:
507
538
  continue
508
- defs_imports_list.extend(get_imports_from_scope(scope))
509
- if self.case_type.__module__ != "builtins":
510
- defs_imports_list.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
511
- defs_imports = "\n".join(set(defs_imports_list)) + "\n"
512
- imports = []
513
- if self.case_type.__module__ != "builtins":
514
- imports.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
515
- for conclusion_type in self.conclusion_type:
516
- if conclusion_type.__module__ != "builtins":
517
- imports.append(f"from {conclusion_type.__module__} import {conclusion_type.__name__}")
518
- imports.append("from ripple_down_rules.datastructures.case import Case, create_case")
519
- imports = set(imports).difference(defs_imports_list)
520
- imports = "\n".join(imports) + "\n"
521
- return imports, defs_imports
539
+ defs_types.update(make_set(scope.values()))
540
+ cases_types.update(rule.get_corner_case_types_to_import())
541
+ defs_types.add(self.case_type)
542
+ main_types = set()
543
+ main_types.add(self.case_type)
544
+ main_types.update(make_set(self.conclusion_type))
545
+ main_types.update({Case, create_case})
546
+ main_types = main_types.difference(defs_types)
547
+ return main_types, defs_types, cases_types
522
548
 
523
549
  @property
524
550
  def _default_generated_python_file_name(self) -> Optional[str]:
@@ -537,7 +563,6 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
537
563
  def generated_python_cases_file_name(self) -> str:
538
564
  return f"{self.generated_python_file_name}_cases"
539
565
 
540
-
541
566
  @property
542
567
  def conclusion_type(self) -> Tuple[Type]:
543
568
  """
@@ -589,7 +614,6 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
589
614
 
590
615
 
591
616
  class SingleClassRDR(RDRWithCodeWriter):
592
-
593
617
  mutually_exclusive: bool = True
594
618
  """
595
619
  The output of the classification of this rdr negates all other possible outputs, there can only be one true value.
@@ -646,7 +670,7 @@ class SingleClassRDR(RDRWithCodeWriter):
646
670
  pred = self.evaluate(case)
647
671
  conclusion = pred.conclusion(case) if pred is not None else None
648
672
  if pred is not None and pred.fired and case_query is not None:
649
- if pred.corner_case_metadata is None and conclusion is not None\
673
+ if pred.corner_case_metadata is None and conclusion is not None \
650
674
  and type(conclusion) in case_query.core_attribute_type:
651
675
  pred.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
652
676
  return conclusion if pred is not None and pred.fired else self.default_conclusion
@@ -658,25 +682,27 @@ class SingleClassRDR(RDRWithCodeWriter):
658
682
  matched_rule = self.start_rule(case) if self.start_rule is not None else None
659
683
  return matched_rule if matched_rule is not None else self.start_rule
660
684
 
661
- def _write_to_python(self, model_dir: str):
662
- super()._write_to_python(model_dir)
685
+ def _write_to_python(self, model_dir: str, package_name: Optional[str] = None):
686
+ super()._write_to_python(model_dir, package_name=package_name)
663
687
  if self.default_conclusion is not None:
664
688
  with open(model_dir + f"/{self.generated_python_file_name}.py", "a") as f:
665
689
  f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
666
690
 
667
691
  def write_rules_as_source_code_to_file(self, rule: SingleClassRule, filename: str, parent_indent: str = "",
668
- defs_file: Optional[str] = None, cases_file: Optional[str] = None):
692
+ defs_file: Optional[str] = None, cases_file: Optional[str] = None,
693
+ package_name: Optional[str] = None):
669
694
  """
670
695
  Write the rules as source code to a file.
671
696
  """
672
697
  if rule.conditions:
673
- rule.write_corner_case_as_source_code(cases_file)
698
+ rule.write_corner_case_as_source_code(cases_file, package_name=package_name)
674
699
  if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
675
700
  with open(filename, "a") as file:
676
701
  file.write(if_clause)
677
702
  if rule.refinement:
678
703
  self.write_rules_as_source_code_to_file(rule.refinement, filename, parent_indent + " ",
679
- defs_file=defs_file, cases_file=cases_file)
704
+ defs_file=defs_file, cases_file=cases_file,
705
+ package_name=package_name)
680
706
 
681
707
  conclusion_call = rule.write_conclusion_as_source_code(parent_indent, defs_file)
682
708
  with open(filename, "a") as file:
@@ -684,7 +710,7 @@ class SingleClassRDR(RDRWithCodeWriter):
684
710
 
685
711
  if rule.alternative:
686
712
  self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
687
- cases_file=cases_file)
713
+ cases_file=cases_file, package_name=package_name)
688
714
 
689
715
  @property
690
716
  def conclusion_type_hint(self) -> str:
@@ -745,8 +771,9 @@ class MultiClassRDR(RDRWithCodeWriter):
745
771
  if evaluated_rule.fired:
746
772
  rule_conclusion = evaluated_rule.conclusion(case)
747
773
  if evaluated_rule.corner_case_metadata is None and case_query is not None:
748
- if rule_conclusion is not None and len(make_list(rule_conclusion)) > 0\
749
- and any(ct in case_query.core_attribute_type for ct in map(type, make_list(rule_conclusion))):
774
+ if rule_conclusion is not None and len(make_list(rule_conclusion)) > 0 \
775
+ and any(
776
+ ct in case_query.core_attribute_type for ct in map(type, make_list(rule_conclusion))):
750
777
  evaluated_rule.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
751
778
  self.add_conclusion(rule_conclusion)
752
779
  evaluated_rule = next_rule
@@ -789,19 +816,20 @@ class MultiClassRDR(RDRWithCodeWriter):
789
816
 
790
817
  def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
791
818
  filename: str, parent_indent: str = "", defs_file: Optional[str] = None,
792
- cases_file: Optional[str] = None):
819
+ cases_file: Optional[str] = None, package_name: Optional[str] = None):
793
820
  if rule == self.start_rule:
794
821
  with open(filename, "a") as file:
795
822
  file.write(f"{parent_indent}conclusions = set()\n")
796
823
  if rule.conditions:
797
- rule.write_corner_case_as_source_code(cases_file)
824
+ rule.write_corner_case_as_source_code(cases_file, package_name=package_name)
798
825
  if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
799
826
  with open(filename, "a") as file:
800
827
  file.write(if_clause)
801
828
  conclusion_indent = parent_indent
802
829
  if hasattr(rule, "refinement") and rule.refinement:
803
830
  self.write_rules_as_source_code_to_file(rule.refinement, filename, parent_indent + " ",
804
- defs_file=defs_file, cases_file=cases_file)
831
+ defs_file=defs_file, cases_file=cases_file,
832
+ package_name=package_name)
805
833
  conclusion_indent = parent_indent + " " * 4
806
834
  with open(filename, "a") as file:
807
835
  file.write(f"{conclusion_indent}else:\n")
@@ -812,7 +840,7 @@ class MultiClassRDR(RDRWithCodeWriter):
812
840
 
813
841
  if rule.alternative:
814
842
  self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
815
- cases_file=cases_file)
843
+ cases_file=cases_file, package_name=package_name)
816
844
 
817
845
  @property
818
846
  def conclusion_type_hint(self) -> str:
@@ -822,12 +850,11 @@ class MultiClassRDR(RDRWithCodeWriter):
822
850
  else:
823
851
  return f"Set[Union[{', '.join(conclusion_types)}]]"
824
852
 
825
- def _get_imports(self) -> Tuple[str, str]:
826
- imports, defs_imports = super()._get_imports()
827
- imports += f"from typing_extensions import Set, Union\n"
828
- imports += "from ripple_down_rules.utils import make_set\n"
829
- defs_imports += "from typing_extensions import Union\n"
830
- return imports, defs_imports
853
+ def _get_types_to_import(self) -> Tuple[Set[Type], Set[Type], Set[Type]]:
854
+ main_types, defs_types, cases_types = super()._get_types_to_import()
855
+ main_types.update({Set, Union, make_set})
856
+ defs_types.add(Union)
857
+ return main_types, defs_types, cases_types
831
858
 
832
859
  def update_start_rule(self, case_query: CaseQuery, expert: Expert):
833
860
  """
@@ -1039,7 +1066,7 @@ class GeneralRDR(RippleDownRules):
1039
1066
 
1040
1067
  def _to_json(self) -> Dict[str, Any]:
1041
1068
  return {"start_rules": {name: rdr.to_json() for name, rdr in self.start_rules_dict.items()}
1042
- , "generated_python_file_name": self.generated_python_file_name,
1069
+ , "generated_python_file_name": self.generated_python_file_name,
1043
1070
  "name": self.name,
1044
1071
  "case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
1045
1072
  "case_name": self.case_name}
@@ -1063,26 +1090,30 @@ class GeneralRDR(RippleDownRules):
1063
1090
  new_rdr.case_name = data["case_name"]
1064
1091
  return new_rdr
1065
1092
 
1066
- def update_from_python(self, model_dir: str) -> None:
1093
+ def update_from_python(self, model_dir: str, package_name: Optional[str] = None) -> None:
1067
1094
  """
1068
1095
  Update the rules from the generated python file, that might have been modified by the user.
1069
1096
 
1070
1097
  :param model_dir: The directory where the model is stored.
1098
+ :param package_name: The name of the package that contains the RDR classifier function, this
1099
+ is required in case of relative imports in the generated python file.
1071
1100
  """
1072
1101
  for rdr in self.start_rules_dict.values():
1073
- rdr.update_from_python(model_dir)
1102
+ rdr.update_from_python(model_dir, package_name=package_name)
1074
1103
 
1075
- def _write_to_python(self, model_dir: str) -> None:
1104
+ def _write_to_python(self, model_dir: str, package_name: Optional[str] = None) -> None:
1076
1105
  """
1077
1106
  Write the tree of rules as source code to a file.
1078
1107
 
1079
1108
  :param model_dir: The directory where the model is stored.
1109
+ :param relative_imports: Whether to use relative imports in the generated python file.
1080
1110
  """
1081
1111
  for rdr in self.start_rules_dict.values():
1082
- rdr._write_to_python(model_dir)
1112
+ rdr._write_to_python(model_dir, package_name=package_name)
1083
1113
  func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
1084
- with open(model_dir + f"/{self.generated_python_file_name}.py", "w") as f:
1085
- f.write(self._get_imports() + "\n\n")
1114
+ file_path = model_dir + f"/{self.generated_python_file_name}.py"
1115
+ with open(file_path, "w") as f:
1116
+ f.write(self._get_imports(file_path=file_path, package_name=package_name) + "\n\n")
1086
1117
  f.write("classifiers_dict = dict()\n")
1087
1118
  for rdr_key, rdr in self.start_rules_dict.items():
1088
1119
  f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
@@ -1105,25 +1136,29 @@ class GeneralRDR(RippleDownRules):
1105
1136
  def conclusion_type_hint(self) -> str:
1106
1137
  return f"Dict[str, Any]"
1107
1138
 
1108
- def _get_imports(self) -> str:
1139
+ def _get_imports(self, file_path: Optional[str] = None, package_name: Optional[str] = None) -> str:
1109
1140
  """
1110
1141
  Get the imports needed for the generated python file.
1111
1142
 
1143
+ :param file_path: The path to the file where the imports will be written, if None, the imports will be absolute.
1144
+ :param package_name: The name of the package that contains the RDR classifier function, this
1145
+ is required in case of relative imports in the generated python file.
1112
1146
  :return: The imports needed for the generated python file.
1113
1147
  """
1114
- imports = ""
1148
+ all_types = set()
1115
1149
  # add type hints
1116
- imports += f"from typing_extensions import Dict, Any\n"
1150
+ all_types.update({Dict, Any})
1117
1151
  # import rdr type
1118
- imports += f"from ripple_down_rules.helpers import general_rdr_classify\n"
1152
+ all_types.add(general_rdr_classify)
1119
1153
  # add case type
1120
- imports += f"from ripple_down_rules.datastructures.case import Case, create_case\n"
1121
- imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
1154
+ all_types.update({Case, create_case, self.case_type})
1155
+ # get the imports from the types
1156
+ imports = get_imports_from_types(all_types, target_file_path=file_path, package_name=package_name)
1122
1157
  # add rdr python generated functions.
1123
1158
  for rdr_key, rdr in self.start_rules_dict.items():
1124
- imports += (f"from ."
1125
- f" import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}\n")
1126
- return imports
1159
+ imports.append(
1160
+ f"from . import {rdr.generated_python_file_name} as {self.rdr_key_to_function_name(rdr_key)}")
1161
+ return '\n'.join(imports)
1127
1162
 
1128
1163
  @staticmethod
1129
1164
  def rdr_key_to_function_name(rdr_key: str) -> str:
@@ -8,7 +8,7 @@ from uuid import uuid4
8
8
 
9
9
  from anytree import NodeMixin
10
10
  from sqlalchemy.orm import DeclarativeBase as SQLTable
11
- from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
11
+ from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple, Type, Set
12
12
 
13
13
  from .datastructures.callable_expression import CallableExpression
14
14
  from .datastructures.case import Case
@@ -102,11 +102,21 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
102
102
  """
103
103
  pass
104
104
 
105
- def write_corner_case_as_source_code(self, cases_file: Path) -> None:
105
+ def write_corner_case_as_source_code(self, cases_file: str, package_name: Optional[str] = None) -> None:
106
106
  """
107
107
  Write the source code representation of the corner case of the rule to a file.
108
108
 
109
- :param cases_file: The file to write the corner case to if it is a definition.
109
+ :param cases_file: The file to write the corner case to.
110
+ :param package_name: The package name to use for relative imports.
111
+ """
112
+ if self.corner_case_metadata is None:
113
+ return
114
+ with open(cases_file, 'a') as f:
115
+ f.write(f"corner_case_{self.uid} = {self.corner_case_metadata}" + "\n\n\n")
116
+
117
+ def get_corner_case_types_to_import(self) -> Set[Type]:
118
+ """
119
+ Get the types that need to be imported for the corner case of the rule.
110
120
  """
111
121
  if self.corner_case_metadata is None:
112
122
  return
@@ -117,10 +127,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
117
127
  types_to_import.add(self.corner_case_metadata.scenario)
118
128
  if self.corner_case_metadata.case_conf is not None:
119
129
  types_to_import.add(self.corner_case_metadata.case_conf)
120
- imports = get_imports_from_types(list(types_to_import))
121
- with open(cases_file, 'a') as f:
122
- f.write("\n".join(imports) + "\n\n\n")
123
- f.write(f"corner_case_{self.uid} = {self.corner_case_metadata}" + "\n\n\n")
130
+ return types_to_import
124
131
 
125
132
  def write_conclusion_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
126
133
  """
@@ -198,7 +198,7 @@ class TemplateFileCreator:
198
198
  with open(self.temp_file_path, 'w+') as f:
199
199
  f.write(code)
200
200
 
201
- def get_imports(self):
201
+ def get_imports(self) -> str:
202
202
  """
203
203
  :return: A string containing the imports for the function.
204
204
  """
@@ -8,6 +8,7 @@ import json
8
8
  import logging
9
9
  import os
10
10
  import re
11
+ import sys
11
12
  import threading
12
13
  import uuid
13
14
  from collections import UserDict, defaultdict
@@ -15,6 +16,7 @@ from copy import deepcopy, copy
15
16
  from dataclasses import is_dataclass, fields
16
17
  from enum import Enum
17
18
  from os.path import dirname
19
+ from pathlib import Path
18
20
  from textwrap import dedent
19
21
  from types import NoneType
20
22
 
@@ -44,7 +46,7 @@ from sqlalchemy import MetaData, inspect
44
46
  from sqlalchemy.orm import Mapped, registry, class_mapper, DeclarativeBase as SQLTable, Session
45
47
  from tabulate import tabulate
46
48
  from typing_extensions import Callable, Set, Any, Type, Dict, TYPE_CHECKING, get_type_hints, \
47
- get_origin, get_args, Tuple, Optional, List, Union, Self, ForwardRef
49
+ get_origin, get_args, Tuple, Optional, List, Union, Self, ForwardRef, Sequence, Iterable
48
50
 
49
51
  if TYPE_CHECKING:
50
52
  from .datastructures.case import Case
@@ -122,7 +124,15 @@ def get_imports_from_scope(scope: Dict[str, Any]) -> List[str]:
122
124
  return imports
123
125
 
124
126
 
125
- def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = None) -> Dict[str, Any]:
127
+ def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = None,
128
+ package_name: Optional[str] = None) -> Dict[str, Any]:
129
+ """
130
+ Extract imports from a Python file or an AST tree.
131
+
132
+ :param file_path: The path to the Python file to extract imports from.
133
+ :param tree: An AST tree to extract imports from. If provided, file_path is ignored.
134
+ :param package_name: The name of the package to use for relative imports.
135
+ """
126
136
  if tree is None:
127
137
  if file_path is None:
128
138
  raise ValueError("Either file_path or tree must be provided")
@@ -137,7 +147,7 @@ def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = N
137
147
  module_name = alias.name
138
148
  asname = alias.asname or alias.name
139
149
  try:
140
- scope[asname] = importlib.import_module(module_name)
150
+ scope[asname] = importlib.import_module(module_name, package=package_name)
141
151
  except ImportError as e:
142
152
  print(f"Could not import {module_name}: {e}")
143
153
  elif isinstance(node, ast.ImportFrom):
@@ -146,7 +156,12 @@ def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = N
146
156
  name = alias.name
147
157
  asname = alias.asname or name
148
158
  try:
149
- module = importlib.import_module(module_name)
159
+ if package_name is not None and node.level > 0: # Handle relative imports
160
+ module_rel_path = Path(os.path.join(file_path, *['..'] * node.level, module_name)).resolve()
161
+ idx = str(module_rel_path).rfind(package_name)
162
+ if idx != -1:
163
+ module_name = str(module_rel_path)[idx:].replace(os.path.sep, '.')
164
+ module = importlib.import_module(module_name, package=package_name)
150
165
  scope[asname] = getattr(module, name)
151
166
  except (ImportError, AttributeError) as e:
152
167
  logging.warning(f"Could not import {module_name}: {e} while extracting imports from {file_path}")
@@ -833,37 +848,82 @@ def get_function_representation(func: Callable) -> str:
833
848
  return func_name
834
849
 
835
850
 
836
- def get_imports_from_types(type_objs: List[Type]) -> List[str]:
851
def get_relative_import(target_file_path, imported_module_path: Optional[str] = None,
                        module: Optional[str] = None) -> str:
    """
    Get a relative import path from the target file to the imported module.

    The result is meant to be used as ``from <result> import <names>`` inside
    ``target_file_path``. The single leading dot refers to the target file's own
    directory. Each additional dot goes one directory further up. Sub-directories
    below the common ancestor are joined with dots, followed by the module name.

    :param target_file_path: The file path of the file that will contain the import.
    :param imported_module_path: The file path of the module being imported. It is used when
        ``module`` is not given, or when ``module`` cannot be resolved to a file through
        ``sys.modules``.
    :param module: The dotted module name, if available. When the module is loaded in
        ``sys.modules`` and has a ``__file__``, that file takes precedence over
        ``imported_module_path``.
    :return: A relative import path as a string, e.g. ``".sibling"``, ``"..pkg.module"``,
        or ``".."`` for the ``__init__.py`` of the parent package.
    :raises ValueError: If the file of the imported module cannot be determined.
    """
    import sys

    if module is not None:
        # Tolerate modules that are not loaded or have no file (e.g. namespace packages)
        # instead of failing with a KeyError; fall back to the explicitly given path.
        module_file = getattr(sys.modules.get(module), '__file__', None)
        if module_file is not None:
            imported_module_path = module_file
    if imported_module_path is None:
        raise ValueError("Either imported_module_path or module must be provided")

    target_dir = Path(target_file_path).resolve().parent
    imported_path = Path(imported_module_path).resolve()

    # os.path.relpath returns a normalized path whose '..' components only appear as a
    # prefix, e.g. '.', 'sub', '..', '../../sub/pkg'.
    rel_parts = Path(os.path.relpath(imported_path.parent, target_dir)).parts
    levels_up = 0
    while levels_up < len(rel_parts) and rel_parts[levels_up] == '..':
        levels_up += 1
    sub_packages = list(rel_parts[levels_up:])

    # A package's __init__.py is imported by its directory name, which is already part
    # of the path; importing ".__init__" would load the package a second time.
    if imported_path.stem != '__init__':
        sub_packages.append(imported_path.stem)

    # One dot for the current directory plus one per level up, then the dotted remainder.
    return "." * (levels_up + 1) + ".".join(sub_packages)
881
+
882
+
883
def get_imports_from_types(type_objs: Iterable[Type],
                           target_file_path: Optional[str] = None,
                           package_name: Optional[str] = None) -> List[str]:
    """
    Format import lines from type objects.

    Objects are grouped by the module that defines them, producing one
    ``from <module> import <names>`` line per module with de-duplicated, sorted names.
    The following objects are skipped:

    * objects from ``builtins``;
    * objects from modules whose name starts with ``_``;
    * objects whose module or name cannot be read (``AttributeError``).

    :param type_objs: An iterable of objects to format: classes, typing constructs,
        callables, or instances (for which their class is imported).
    :param target_file_path: The file path to which the imports should be relative.
    :param package_name: The name of the package to use for relative imports. A module is
        considered part of the package when ``package_name`` matches whole dotted
        components of the module name (e.g. ``"pkg"`` matches ``"pkg.sub"`` and
        ``"root.pkg.sub"`` but not ``"pkgx.sub"``).
    :return: The import lines, sorted.
    """
    module_to_types = defaultdict(list)
    module_to_path = {}
    for tp in type_objs:
        try:
            if isinstance(tp, type) or is_typing_type(tp):
                module = tp.__module__
                file = getattr(tp, '__file__', None)
                name = tp.__qualname__
            elif callable(tp):
                module, name = get_function_import_data(tp)
                file = get_method_file_name(tp)
            elif hasattr(type(tp), "__module__"):
                module = type(tp).__module__
                file = getattr(tp, '__file__', None)
                name = type(tp).__qualname__
            else:
                continue
            if module is None or module == 'builtins' or module.startswith('_'):
                continue
            module_to_types[module].append(name)
            if file:
                module_to_path[module] = file
        except AttributeError:
            continue

    make_relative = target_file_path is not None and package_name is not None
    lines = []
    for module, names in module_to_types.items():
        joined = ", ".join(sorted(set(names)))
        import_path = module
        # Match on dotted-name boundaries; a raw substring test would treat e.g. "re" as
        # part of "requests" and emit a bogus relative import.
        if make_relative and f".{package_name}." in f".{module}.":
            try:
                import_path = get_relative_import(target_file_path,
                                                  imported_module_path=module_to_path.get(module),
                                                  module=module)
            except (KeyError, ValueError):
                # The module's file could not be located (or lies on another drive);
                # the absolute import is still valid, so keep it.
                import_path = module
        lines.append(f"from {import_path} import {joined}")
    return sorted(lines)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.5.75
3
+ Version: 0.5.80
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -1,11 +1,11 @@
1
- ripple_down_rules/__init__.py,sha256=6Ze00N3Py1dmFEMGBz3jz63qsUVLz8WzFn7qx3lJnfM,100
1
+ ripple_down_rules/__init__.py,sha256=OCXYavUU_yIcymr1OaTIQ9uNk3ZwDNlg1bX6wK7sZLQ,100
2
2
  ripple_down_rules/experts.py,sha256=bwozulI1rv0uyaMZQqEgapDO-s8wvW0D6Jqxmvu5fik,12610
3
3
  ripple_down_rules/helpers.py,sha256=v4oE7C5PfQUVJfSUs1FfLHEwrJXEHJLn4vJhJMvyCR8,4453
4
- ripple_down_rules/rdr.py,sha256=Mqh7lDjQu6wZUcJiJ57CZ3P0-hM4WfhFuV4s1jZnRv8,51833
4
+ ripple_down_rules/rdr.py,sha256=2t04Qj931dh6UypLBy4RxJ5hFIui0ejFgvm7D5P8b-E,55118
5
5
  ripple_down_rules/rdr_decorators.py,sha256=0sk7izDB53lTKSB9fm33vQahmY_05FyCOWljyQOMB0U,9072
6
- ripple_down_rules/rules.py,sha256=ctf9yREG5l99HPFcYosjppKXTOwplZmzQbm4R1DMVaA,20107
6
+ ripple_down_rules/rules.py,sha256=iVevv6iZ-6L2IPI0ZYbBjxBymXEQMmJGRFhiKUS-NmA,20352
7
7
  ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
8
- ripple_down_rules/utils.py,sha256=iwfpTlsxUqLHWpYqSKwrDnEEa_FYFHYb2LugEVDH_kk,57132
8
+ ripple_down_rules/utils.py,sha256=mObBszTruGrRvD4MgD8tS1AnMtoyKrPl4RCciinhzY4,60132
9
9
  ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
10
10
  ripple_down_rules/datastructures/callable_expression.py,sha256=f3wUPTrLa1INO-1qfgVz87ryrCABronfyq0_JKWoZCs,12800
11
11
  ripple_down_rules/datastructures/case.py,sha256=1zSaXUljaH6z3SgMGzYPoDyjotNam791KpYgvxuMh90,15463
@@ -16,9 +16,9 @@ ripple_down_rules/user_interface/gui.py,sha256=_lgZAUXxxaBUFQJAHjA5TBPp6XEvJ62t-
16
16
  ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=Jrf7NxOdlrwGXH0Xyz3vzQprY-PNx9etfePOTpm2Gu8,6479
17
17
  ripple_down_rules/user_interface/object_diagram.py,sha256=FEa2HaYR9QmTE6NsOwBvZ0jqmu3DKyg6mig2VE5ZP4Y,4956
18
18
  ripple_down_rules/user_interface/prompt.py,sha256=AkkltdDIaioN43lkRKDPKSjJcmdSSGZDMYz7AL7X9lE,8082
19
- ripple_down_rules/user_interface/template_file_creator.py,sha256=VLS9Nxg6gPNa-YYliJ_VNsTvLPlZ003EVkJ2t8zuDgE,13563
20
- ripple_down_rules-0.5.75.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
21
- ripple_down_rules-0.5.75.dist-info/METADATA,sha256=M1N-k7Zp8qzOsGlF8K4n889agx_bLu_rWO9_c-cEViQ,48214
22
- ripple_down_rules-0.5.75.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- ripple_down_rules-0.5.75.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
24
- ripple_down_rules-0.5.75.dist-info/RECORD,,
19
+ ripple_down_rules/user_interface/template_file_creator.py,sha256=xWUt-RHRqrvoMo74o-bMLo8xNxil68wgbOZAMADZp2A,13570
20
+ ripple_down_rules-0.5.80.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
21
+ ripple_down_rules-0.5.80.dist-info/METADATA,sha256=slyVN82WJd9yd4efs415WgtiiCM__5BJeKneQhWl5zY,48214
22
+ ripple_down_rules-0.5.80.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ ripple_down_rules-0.5.80.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
24
+ ripple_down_rules-0.5.80.dist-info/RECORD,,