ripple-down-rules 0.4.7__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,15 +5,17 @@ of the RDRs.
5
5
  """
6
6
  import os.path
7
7
  from functools import wraps
8
+
9
+ from pyparsing.tools.cvt_pyparsing_pep8_names import camel_to_snake
8
10
  from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union
9
11
 
10
- from ripple_down_rules.datastructures.case import create_case
12
+ from ripple_down_rules.datastructures.case import create_case, Case
11
13
  from ripple_down_rules.datastructures.dataclasses import CaseQuery
12
14
  from ripple_down_rules.datastructures.enums import Category
13
15
  from ripple_down_rules.experts import Expert, Human
14
16
  from ripple_down_rules.rdr import GeneralRDR, RippleDownRules
15
17
  from ripple_down_rules.utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
16
- get_method_class_if_exists
18
+ get_method_class_if_exists, get_method_name, str_to_snake_case
17
19
 
18
20
 
19
21
  class RDRDecorator:
@@ -41,99 +43,118 @@ class RDRDecorator:
41
43
  :return: A decorator to use a GeneralRDR as a classifier that monitors and modifies the function's output.
42
44
  """
43
45
  self.rdr_models_dir = models_dir
46
+ self.model_name: Optional[str] = None
44
47
  self.output_type = output_type
45
48
  self.parsed_output_type: List[Type] = []
46
49
  self.mutual_exclusive = mutual_exclusive
47
50
  self.rdr_python_path: Optional[str] = python_dir
48
51
  self.output_name = output_name
49
52
  self.fit: bool = fit
50
- self.expert = expert if expert else Human()
51
- self.rdr_model_path: Optional[str] = None
53
+ self.expert: Optional[Expert] = expert
52
54
  self.load()
53
55
 
54
56
  def decorator(self, func: Callable) -> Callable:
55
57
 
56
58
  @wraps(func)
57
59
  def wrapper(*args, **kwargs) -> Optional[Any]:
60
+
58
61
  if len(self.parsed_output_type) == 0:
59
- self.parse_output_type(func, *args)
60
- if self.rdr_model_path is None:
61
- self.initialize_rdr_model_path_and_load(func)
62
- case_dict = get_method_args_as_dict(func, *args, **kwargs)
63
- func_output = func(*args, **kwargs)
64
- case_dict.update({self.output_name: func_output})
65
- case = create_case(case_dict, obj_name=get_func_rdr_model_name(func), max_recursion_idx=3)
62
+ self.parsed_output_type = self.parse_output_type(func, self.output_type, *args)
63
+ if self.model_name is None:
64
+ self.initialize_rdr_model_name_and_load(func)
65
+
66
66
  if self.fit:
67
- scope = func.__globals__
68
- scope.update(case_dict)
69
- func_args_type_hints = get_type_hints(func)
70
- func_args_type_hints.update({self.output_name: Union[tuple(self.parsed_output_type)]})
71
- case_query = CaseQuery(case, self.output_name, Union[tuple(self.parsed_output_type)],
72
- self.mutual_exclusive,
73
- scope=scope, is_function=True, function_args_type_hints=func_args_type_hints)
67
+ expert_answers_path = os.path.join(self.rdr_models_dir, self.model_name, "expert_answers")
68
+ self.expert = self.expert or Human(answers_save_path=expert_answers_path)
69
+ case_query = self.create_case_query_from_method(func, self.parsed_output_type,
70
+ self.mutual_exclusive, self.output_name,
71
+ *args, **kwargs)
74
72
  output = self.rdr.fit_case(case_query, expert=self.expert)
75
73
  return output[self.output_name]
76
74
  else:
75
+ case, case_dict = self.create_case_from_method(func, self.output_name, *args, **kwargs)
77
76
  return self.rdr.classify(case)[self.output_name]
78
77
 
79
78
  return wrapper
80
79
 
81
- def initialize_rdr_model_path_and_load(self, func: Callable) -> None:
80
+ @staticmethod
81
+ def create_case_query_from_method(func: Callable, output_type, mutual_exclusive: bool,
82
+ output_name: str = 'output_', *args, **kwargs) -> CaseQuery:
83
+ """
84
+ Create a CaseQuery from the function and its arguments.
85
+
86
+ :param func: The function to create a case from.
87
+ :param output_type: The type of the output.
88
+ :param mutual_exclusive: If True, the output types are mutually exclusive.
89
+ :param output_name: The name of the output in the case. Defaults to 'output_'.
90
+ :param args: The positional arguments of the function.
91
+ :param kwargs: The keyword arguments of the function.
92
+ :return: A CaseQuery object representing the case.
93
+ """
94
+ output_type = make_set(output_type)
95
+ case, case_dict = RDRDecorator.create_case_from_method(func, output_name, *args, **kwargs)
96
+ scope = func.__globals__
97
+ scope.update(case_dict)
98
+ func_args_type_hints = get_type_hints(func)
99
+ func_args_type_hints.update({output_name: Union[tuple(output_type)]})
100
+ return CaseQuery(case, output_name, Union[tuple(output_type)],
101
+ mutual_exclusive, scope=scope,
102
+ is_function=True, function_args_type_hints=func_args_type_hints)
103
+
104
+ @staticmethod
105
+ def create_case_from_method(func: Callable, output_name: str = "output_", *args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
106
+ """
107
+ Create a Case from the function and its arguments.
108
+
109
+ :param func: The function to create a case from.
110
+ :param output_name: The name of the output in the case. Defaults to 'output_'.
111
+ :param args: The positional arguments of the function.
112
+ :param kwargs: The keyword arguments of the function.
113
+ :return: A Case object representing the case.
114
+ """
115
+ case_dict = get_method_args_as_dict(func, *args, **kwargs)
116
+ func_output = func(*args, **kwargs)
117
+ case_dict.update({output_name: func_output})
118
+ case_name = get_func_rdr_model_name(func)
119
+ return create_case(case_dict, obj_name=case_name, max_recursion_idx=3), case_dict
120
+
121
+ def initialize_rdr_model_name_and_load(self, func: Callable) -> None:
82
122
  model_file_name = get_func_rdr_model_name(func, include_file_name=True)
83
- model_file_name = (''.join(['_' + c.lower() if c.isupper() else c for c in model_file_name]).lstrip('_')
84
- .replace('__', '_') + ".json")
85
- self.rdr_model_path = os.path.join(self.rdr_models_dir, model_file_name)
123
+ self.model_name = str_to_snake_case(model_file_name)
86
124
  self.load()
87
125
 
88
- def parse_output_type(self, func: Callable, *args) -> None:
89
- for ot in make_set(self.output_type):
126
+ @staticmethod
127
+ def parse_output_type(func: Callable, output_type: Any, *args) -> List[Type]:
128
+ parsed_output_type = []
129
+ for ot in make_set(output_type):
90
130
  if ot is Self:
91
131
  func_class = get_method_class_if_exists(func, *args)
92
132
  if func_class is not None:
93
- self.parsed_output_type.append(func_class)
133
+ parsed_output_type.append(func_class)
94
134
  else:
95
135
  raise ValueError(f"The function {func} is not a method of a class,"
96
136
  f" and the output type is {Self}.")
97
137
  else:
98
- self.parsed_output_type.append(ot)
138
+ parsed_output_type.append(ot)
139
+ return parsed_output_type
99
140
 
100
141
  def save(self):
101
142
  """
102
143
  Save the RDR model to the specified directory.
103
144
  """
104
- self.rdr.save(self.rdr_model_path)
105
-
106
- if self.rdr_python_path is not None:
107
- if not os.path.exists(self.rdr_python_path):
108
- os.makedirs(self.rdr_python_path)
109
- if not os.path.exists(os.path.join(self.rdr_python_path, "__init__.py")):
110
- # add __init__.py file to the directory
111
- with open(os.path.join(self.rdr_python_path, "__init__.py"), "w") as f:
112
- f.write("# This is an empty __init__.py file to make the directory a package.")
113
- # write the RDR model to a python file
114
- self.rdr.write_to_python_file(self.rdr_python_path)
145
+ self.rdr.save(self.rdr_models_dir)
115
146
 
116
147
  def load(self):
117
148
  """
118
149
  Load the RDR model from the specified directory.
119
150
  """
120
- if self.rdr_model_path is not None and os.path.exists(self.rdr_model_path):
121
- self.rdr = GeneralRDR.load(self.rdr_model_path)
151
+ if self.model_name is not None and os.path.exists(os.path.join(self.rdr_models_dir, self.model_name)):
152
+ self.rdr = GeneralRDR.load(self.rdr_models_dir, self.model_name)
122
153
  else:
123
- self.rdr = GeneralRDR()
124
-
125
- def write_to_python_file(self, package_dir: str, file_name_postfix: str = ""):
126
- """
127
- Write the RDR model to a python file.
128
-
129
- :param package_dir: The path to the directory to write the python file.
130
- """
131
- self.rdr.write_to_python_file(package_dir, postfix=file_name_postfix)
154
+ self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name)
132
155
 
133
- def update_from_python_file(self, package_dir: str):
156
+ def update_from_python(self):
134
157
  """
135
158
  Update the RDR model from a python file.
136
-
137
- :param package_dir: The directory of the package that contains the generated python file.
138
159
  """
139
- self.rdr.update_from_python_file(package_dir)
160
+ self.rdr.update_from_python(self.rdr_models_dir, self.model_name)
@@ -1,11 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import logging
3
4
  import re
4
5
  from abc import ABC, abstractmethod
5
6
  from uuid import uuid4
6
7
 
7
8
  from anytree import NodeMixin
8
- from rospy import logwarn, logdebug
9
9
  from sqlalchemy.orm import DeclarativeBase as SQLTable
10
10
  from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
11
11
 
@@ -118,7 +118,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
118
118
  func_call = f"{parent_indent} return {new_function_name.replace('def ', '')}(case)\n"
119
119
  return "\n".join(conclusion_lines).strip(' '), func_call
120
120
  else:
121
- raise ValueError(f"Conclusion is format is not valid, it should be contain a function definition."
121
+ raise ValueError(f"Conclusion format is not valid, it should contain a function definition."
122
122
  f" Instead got:\n{conclusion}\n")
123
123
 
124
124
  def write_condition_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
@@ -129,9 +129,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
129
129
  :param defs_file: The file to write the conditions to if they are a definition.
130
130
  """
131
131
  if_clause = self._if_statement_source_code_clause()
132
- if '\n' not in self.conditions.user_input:
133
- return f"{parent_indent}{if_clause} {self.conditions.user_input}:\n"
134
- elif "def " in self.conditions.user_input:
132
+ if "def " in self.conditions.user_input:
135
133
  if defs_file is None:
136
134
  raise ValueError("Cannot write conditions to source code as definitions python file was not given.")
137
135
  # This means the conditions are a definition that should be written and then called
@@ -143,6 +141,9 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
143
141
  with open(defs_file, 'a') as f:
144
142
  f.write(def_code.strip() + "\n\n\n")
145
143
  return f"\n{parent_indent}{if_clause} {new_function_name.replace('def ', '')}(case):\n"
144
+ else:
145
+ raise ValueError(f"Conditions format is not valid, it should contain a function definition"
146
+ f" Instead got:\n{self.conditions.user_input}\n")
146
147
 
147
148
  @abstractmethod
148
149
  def _if_statement_source_code_clause(self) -> str:
@@ -164,7 +165,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
164
165
  try:
165
166
  corner_case = Case.from_json(data["corner_case"])
166
167
  except Exception as e:
167
- logdebug("Failed to load corner case from json, setting it to None.")
168
+ logging.debug("Failed to load corner case from json, setting it to None.")
168
169
  corner_case = None
169
170
  loaded_rule = cls(conditions=CallableExpression.from_json(data["conditions"]),
170
171
  conclusion=CallableExpression.from_json(data["conclusion"]),
@@ -0,0 +1,27 @@
1
+ #!/bin/bash
2
+ set -e
3
+ if [ -z "$RDR_EDITOR_PORT" ]; then
4
+ echo "RDR_EDITOR_PORT is not set. Using default port 8080."
5
+ RDR_EDITOR_PORT=8080
6
+ fi
7
+ ADDR="0.0.0.0:$RDR_EDITOR_PORT"
8
+ # DATA_DIR="/root/.local/share/code-server"
9
+ echo "🚀 Starting code-server on $ADDR"
10
+ # Activate your Python virtual environment if exists else ignore
11
+ if [ -z "$RDR_VENV_PATH" ]; then
12
+ echo "No virtual environment found. Skipping activation."
13
+ else
14
+ source "$RDR_VENV_PATH/bin/activate"
15
+ # Set the default Python interpreter for VS Code
16
+ export DEFAULT_PYTHON_PATH=$(which python)
17
+ fi
18
+
19
+ # Start code-server.
20
+ echo "🚀 Starting code-server on $ADDR"
21
+ if [ -z "$CODE_SERVER_USER_DATA_DIR" ]; then
22
+ echo "No user data directory found. Using default"
23
+ code-server --bind-addr $ADDR --auth none "$@"
24
+ else
25
+ echo "Using user data directory: $CODE_SERVER_USER_DATA_DIR"
26
+ code-server --bind-addr $ADDR --user-data-dir $CODE_SERVER_USER_DATA_DIR --auth none "$@"
27
+ fi
@@ -1,23 +1,29 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- from copy import copy
4
+ import logging
5
5
  from types import MethodType
6
6
 
7
- from PyQt6.QtCore import Qt
8
- from PyQt6.QtGui import QPixmap, QPainter, QPalette
9
- from PyQt6.QtWidgets import (
10
- QWidget, QVBoxLayout, QLabel, QScrollArea,
11
- QSizePolicy, QToolButton, QHBoxLayout, QPushButton, QMainWindow, QGraphicsView, QGraphicsScene, QGraphicsPixmapItem
12
- )
13
- from qtconsole.inprocess import QtInProcessKernelManager
14
- from qtconsole.rich_jupyter_widget import RichJupyterWidget
7
+ try:
8
+ from PyQt6.QtCore import Qt
9
+ from PyQt6.QtGui import QPixmap, QPainter, QPalette
10
+ from PyQt6.QtWidgets import (
11
+ QWidget, QVBoxLayout, QLabel, QScrollArea,
12
+ QSizePolicy, QToolButton, QHBoxLayout, QPushButton, QMainWindow, QGraphicsView, QGraphicsScene, QGraphicsPixmapItem
13
+ )
14
+ from qtconsole.inprocess import QtInProcessKernelManager
15
+ from qtconsole.rich_jupyter_widget import RichJupyterWidget
16
+ except ImportError as e:
17
+ logging.debug("RDRCaseViewer is not available. GUI features will not work. "
18
+ "Make sure you have PyQt6 installed if you want to use the GUI features.")
19
+ raise ImportError("PyQt6 is required for the GUI features. Please install it using 'pip install PyQt6'") from e
20
+
15
21
  from typing_extensions import Optional, Any, List, Dict, Callable
16
22
 
17
23
  from ..datastructures.dataclasses import CaseQuery
18
24
  from ..datastructures.enums import PromptFor
19
25
  from .template_file_creator import TemplateFileCreator
20
- from ..utils import is_iterable, contains_return_statement, encapsulate_user_input
26
+ from ..utils import is_iterable, contains_return_statement, encapsulate_code_lines_into_a_function
21
27
  from .object_diagram import generate_object_graph
22
28
 
23
29
 
@@ -467,15 +473,17 @@ class RDRCaseViewer(QMainWindow):
467
473
  self.close()
468
474
 
469
475
  def _edit(self):
470
- self.template_file_creator = TemplateFileCreator(self.ipython_console.kernel.shell,
471
- self.case_query, self.prompt_for, self.code_to_modify,
476
+ self.template_file_creator = TemplateFileCreator(self.case_query, self.prompt_for, self.code_to_modify,
472
477
  self.print)
473
478
  self.template_file_creator.edit()
474
479
 
475
480
  def _load(self):
476
481
  if not self.template_file_creator:
477
482
  return
478
- self.code_lines = self.template_file_creator.load()
483
+ self.code_lines, updates = self.template_file_creator.load(self.template_file_creator.temp_file_path,
484
+ self.template_file_creator.func_name,
485
+ self.template_file_creator.print_func)
486
+ self.ipython_console.kernel.shell.user_ns.update(updates)
479
487
  if self.code_lines is not None:
480
488
  self.user_input = encapsulate_code_lines_into_a_function(
481
489
  self.code_lines, self.template_file_creator.func_name,
@@ -558,28 +566,6 @@ class RDRCaseViewer(QMainWindow):
558
566
  layout.addWidget(item_label)
559
567
 
560
568
 
561
- def encapsulate_code_lines_into_a_function(code_lines: List[str], function_name: str, function_signature: str,
562
- func_doc: str, case_query: CaseQuery) -> str:
563
- """
564
- Encapsulate the given code lines into a function with the specified name, signature, and docstring.
565
-
566
- :param code_lines: The lines of code to include in the user input.
567
- :param function_name: The name of the function to include in the user input.
568
- :param function_signature: The function signature to include in the user input.
569
- :param func_doc: The function docstring to include in the user input.
570
- :param case_query: The case query object.
571
- """
572
- code = '\n'.join(code_lines)
573
- code = encapsulate_user_input(code, function_signature, func_doc)
574
- if case_query.is_function:
575
- args = "**case"
576
- else:
577
- args = "case"
578
- if f"return {function_name}({args})" not in code:
579
- code = code.strip() + f"\nreturn {function_name}({args})"
580
- return code
581
-
582
-
583
569
  class IPythonConsole(RichJupyterWidget):
584
570
  def __init__(self, namespace=None, parent=None):
585
571
  super(IPythonConsole, self).__init__(parent)
@@ -8,9 +8,8 @@ from traitlets.config import Config
8
8
 
9
9
  from ..datastructures.dataclasses import CaseQuery
10
10
  from ..datastructures.enums import PromptFor
11
- from .gui import encapsulate_code_lines_into_a_function
12
11
  from .template_file_creator import TemplateFileCreator
13
- from ..utils import contains_return_statement, extract_dependencies
12
+ from ..utils import contains_return_statement, extract_dependencies, encapsulate_code_lines_into_a_function
14
13
 
15
14
 
16
15
  @magics_class
@@ -21,7 +20,7 @@ class MyMagics(Magics):
21
20
  prompt_for: Optional[PromptFor] = None,
22
21
  case_query: Optional[CaseQuery] = None):
23
22
  super().__init__(shell)
24
- self.rule_editor = TemplateFileCreator(shell, case_query, prompt_for=prompt_for, code_to_modify=code_to_modify)
23
+ self.rule_editor = TemplateFileCreator(case_query, prompt_for=prompt_for, code_to_modify=code_to_modify)
25
24
  self.all_code_lines: Optional[List[str]] = None
26
25
 
27
26
  @line_magic
@@ -30,7 +29,10 @@ class MyMagics(Magics):
30
29
 
31
30
  @line_magic
32
31
  def load(self, line):
33
- self.all_code_lines = self.rule_editor.load()
32
+ self.all_code_lines, updates = self.rule_editor.load(self.rule_editor.temp_file_path,
33
+ self.rule_editor.func_name,
34
+ self.rule_editor.print_func)
35
+ self.shell.user_ns.update(updates)
34
36
 
35
37
  @line_magic
36
38
  def help(self, line):
@@ -1,4 +1,10 @@
1
- import graphviz
1
+ import logging
2
+
3
+ try:
4
+ import graphviz
5
+ except ImportError:
6
+ graphviz = None
7
+ logging.debug("Graphviz is not installed")
2
8
 
3
9
 
4
10
  def is_simple(obj):
@@ -2,7 +2,13 @@ import ast
2
2
  import logging
3
3
  from _ast import AST
4
4
 
5
- from PyQt6.QtWidgets import QApplication
5
+ try:
6
+ from PyQt6.QtWidgets import QApplication
7
+ from .gui import RDRCaseViewer
8
+ except ImportError:
9
+ QApplication = None
10
+ RDRCaseViewer = None
11
+
6
12
  from colorama import Fore, Style
7
13
  from pygments import highlight
8
14
  from pygments.formatters.terminal import TerminalFormatter
@@ -11,8 +17,7 @@ from typing_extensions import Optional, Tuple
11
17
 
12
18
  from ..datastructures.callable_expression import CallableExpression, parse_string_to_expression
13
19
  from ..datastructures.dataclasses import CaseQuery
14
- from ..datastructures.enums import PromptFor, InteractionMode
15
- from .gui import RDRCaseViewer
20
+ from ..datastructures.enums import PromptFor
16
21
  from .ipython_custom_shell import IPythonShell
17
22
  from ..utils import make_list
18
23
 
@@ -90,7 +95,7 @@ class UserPrompt:
90
95
  prompt_str = f"Give conditions on when can the rule be evaluated for:"
91
96
  case_query.scope.update({'case': case_query.case})
92
97
  shell = None
93
- if QApplication.instance() is None:
98
+ if self.viewer is None:
94
99
  prompt_str = self.construct_prompt_str_for_shell(case_query, prompt_for, prompt_str)
95
100
  shell = IPythonShell(header=prompt_str, prompt_for=prompt_for, case_query=case_query,
96
101
  code_to_modify=code_to_modify)
@@ -8,14 +8,13 @@ from functools import cached_property
8
8
  from textwrap import indent, dedent
9
9
 
10
10
  from colorama import Fore, Style
11
- from ipykernel.inprocess.ipkernel import InProcessInteractiveShell
12
- from typing_extensions import Optional, Type, List, Callable
11
+ from typing_extensions import Optional, Type, List, Callable, Tuple, Dict
13
12
 
14
13
  from ..datastructures.case import Case
15
14
  from ..datastructures.dataclasses import CaseQuery
16
15
  from ..datastructures.enums import Editor, PromptFor
17
16
  from ..utils import str_to_snake_case, get_imports_from_scope, make_list, typing_hint_to_str, \
18
- get_imports_from_types, extract_function_source
17
+ get_imports_from_types, extract_function_source, extract_imports
19
18
 
20
19
 
21
20
  def detect_available_editor() -> Optional[Editor]:
@@ -49,6 +48,9 @@ def start_code_server(workspace):
49
48
  stderr=subprocess.PIPE, text=True)
50
49
 
51
50
 
51
+ FunctionData = Tuple[Optional[List[str]], Optional[Dict[str, Callable]]]
52
+
53
+
52
54
  class TemplateFileCreator:
53
55
  """
54
56
  A class to create a rule template file for a given case and prompt for the user to edit it.
@@ -70,10 +72,9 @@ class TemplateFileCreator:
70
72
  The list of all code lines in the function in the temporary file.
71
73
  """
72
74
 
73
- def __init__(self, shell: InProcessInteractiveShell, case_query: CaseQuery, prompt_for: PromptFor,
74
- code_to_modify: Optional[str] = None, print_func: Optional[Callable[[str], None]] = None):
75
- self.print_func = print_func if print_func else print
76
- self.shell = shell
75
+ def __init__(self, case_query: CaseQuery, prompt_for: PromptFor,
76
+ code_to_modify: Optional[str] = None, print_func: Callable[[str], None] = print):
77
+ self.print_func = print_func
77
78
  self.code_to_modify = code_to_modify
78
79
  self.prompt_for = prompt_for
79
80
  self.case_query = case_query
@@ -228,15 +229,27 @@ class TemplateFileCreator:
228
229
  imports = set(imports)
229
230
  return '\n'.join(imports)
230
231
 
232
+ @staticmethod
233
+ def get_core_attribute_types(case_query: CaseQuery) -> List[Type]:
234
+ """
235
+ Get the core attribute types of the case query.
236
+
237
+ :return: A list of core attribute types.
238
+ """
239
+ attr_types = [t for t in case_query.core_attribute_type if t.__module__ != "builtins" and t is not None
240
+ and t is not type(None)]
241
+ return attr_types
242
+
231
243
  def get_func_doc(self) -> Optional[str]:
232
244
  """
233
245
  :return: A string containing the function docstring.
234
246
  """
247
+ type_data = f" of type {' or '.join(map(lambda c: c.__name__, self.get_core_attribute_types(self.case_query)))}"
235
248
  if self.prompt_for == PromptFor.Conditions:
236
249
  return (f"Get conditions on whether it's possible to conclude a value"
237
- f" for {self.case_query.name}")
250
+ f" for {self.case_query.name} {type_data}.")
238
251
  else:
239
- return f"Get possible value(s) for {self.case_query.name}"
252
+ return f"Get possible value(s) for {self.case_query.name} {type_data}."
240
253
 
241
254
  @staticmethod
242
255
  def get_func_name(prompt_for, case_query) -> Optional[str]:
@@ -245,12 +258,12 @@ class TemplateFileCreator:
245
258
  func_name = f"{prompt_for.value.lower()}_for_"
246
259
  case_name = case_query.name.replace(".", "_")
247
260
  if case_query.is_function:
248
- # convert any CamelCase word into snake_case by adding _ before each capital letter
249
- case_name = case_name.replace(f"_{case_query.attribute_name}", "")
250
- func_name += case_name
251
- attr_types = [t for t in case_query.core_attribute_type if t.__module__ != "builtins" and t is not None
252
- and t is not type(None)]
253
- func_name += f"_of_type_{'_or_'.join(map(lambda c: c.__name__, attr_types))}"
261
+ func_name += case_name.replace(f"_{case_query.attribute_name}", "")
262
+ else:
263
+ func_name += case_name
264
+ attribute_types = TemplateFileCreator.get_core_attribute_types(case_query)
265
+ attribute_type_names = [t.__name__ for t in attribute_types]
266
+ func_name += f"_of_type_{'_or_'.join(attribute_type_names)}"
254
267
  return str_to_snake_case(func_name)
255
268
 
256
269
  @cached_property
@@ -263,33 +276,43 @@ class TemplateFileCreator:
263
276
  case = self.case_query.scope['case']
264
277
  return case._obj_type if isinstance(case, Case) else type(case)
265
278
 
266
- def load(self) -> Optional[List[str]]:
267
- if not self.temp_file_path:
268
- self.print_func(f"{Fore.RED}ERROR:: No file to load. Run %edit first.{Style.RESET_ALL}")
269
- return None
279
+ @staticmethod
280
+ def load(file_path: str, func_name: str, print_func: Callable = print) -> FunctionData:
281
+ """
282
+ Load the function from the given file path.
283
+
284
+ :param file_path: The path to the file to load.
285
+ :param func_name: The name of the function to load.
286
+ :param print_func: The function to use for printing messages.
287
+ :return: A tuple containing the function source code and the function object as a dictionary
288
+ with the function name as the key and the function object as the value.
289
+ """
290
+ if not file_path:
291
+ print_func(f"{Fore.RED}ERROR:: No file to load. Run %edit first.{Style.RESET_ALL}")
292
+ return None, None
270
293
 
271
- with open(self.temp_file_path, 'r') as f:
294
+ with open(file_path, 'r') as f:
272
295
  source = f.read()
273
296
 
274
297
  tree = ast.parse(source)
275
298
  updates = {}
276
299
  for node in tree.body:
277
- if isinstance(node, ast.FunctionDef) and node.name == self.func_name:
300
+ if isinstance(node, ast.FunctionDef) and node.name == func_name:
278
301
  exec_globals = {}
279
- exec(source, self.case_query.scope, exec_globals)
280
- user_function = exec_globals[self.func_name]
281
- updates[self.func_name] = user_function
282
- self.print_func(f"{Fore.BLUE}Loaded `{self.func_name}` function into user namespace.{Style.RESET_ALL}")
302
+ scope = extract_imports(tree=tree)
303
+ exec(source, scope, exec_globals)
304
+ user_function = exec_globals[func_name]
305
+ updates[func_name] = user_function
306
+ print_func(f"{Fore.BLUE}Loaded `{func_name}` function into user namespace.{Style.RESET_ALL}")
283
307
  break
284
308
  if updates:
285
- self.shell.user_ns.update(updates)
286
- self.all_code_lines = extract_function_source(self.temp_file_path,
287
- [self.func_name],
288
- join_lines=False)[self.func_name]
289
- return self.all_code_lines
309
+ all_code_lines = extract_function_source(file_path,
310
+ [func_name],
311
+ join_lines=False)[func_name]
312
+ return all_code_lines, updates
290
313
  else:
291
- self.print_func(f"{Fore.RED}ERROR:: Function `{self.func_name}` not found.{Style.RESET_ALL}")
292
- return None
314
+ print_func(f"{Fore.RED}ERROR:: Function `{func_name}` not found.{Style.RESET_ALL}")
315
+ return None, None
293
316
 
294
317
  def __del__(self):
295
318
  if hasattr(self, 'process') and self.process is not None and self.process.poll() is None: