ripple-down-rules 0.5.4__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,13 +7,14 @@ import os.path
7
7
  from functools import wraps
8
8
 
9
9
  from pyparsing.tools.cvt_pyparsing_pep8_names import camel_to_snake
10
- from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union
10
+ from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union, Sequence
11
11
 
12
12
  from ripple_down_rules.datastructures.case import create_case, Case
13
13
  from ripple_down_rules.datastructures.dataclasses import CaseQuery
14
14
  from ripple_down_rules.datastructures.enums import Category
15
15
  from ripple_down_rules.experts import Expert, Human
16
16
  from ripple_down_rules.rdr import GeneralRDR, RippleDownRules
17
+ from ripple_down_rules.user_interface.gui import RDRCaseViewer
17
18
  from ripple_down_rules.utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
18
19
  get_method_class_if_exists, get_method_name, str_to_snake_case
19
20
 
@@ -24,15 +25,16 @@ class RDRDecorator:
24
25
  def __init__(self, models_dir: str,
25
26
  output_type: Tuple[Type],
26
27
  mutual_exclusive: bool,
27
- python_dir: Optional[str] = None,
28
28
  output_name: str = "output_",
29
29
  fit: bool = True,
30
- expert: Optional[Expert] = None):
30
+ expert: Optional[Expert] = None,
31
+ ask_always: bool = False,
32
+ update_existing_rules: bool = True,
33
+ viewer: Optional[RDRCaseViewer] = None):
31
34
  """
32
35
  :param models_dir: The directory to save/load the RDR models.
33
36
  :param output_type: The type of the output. This is used to create the RDR model.
34
37
  :param mutual_exclusive: If True, the output types are mutually exclusive.
35
- :param python_dir: The directory to save the RDR model as a python file.
36
38
  If None, the RDR model will not be saved as a python file.
37
39
  :param output_name: The name of the output. This is used to create the RDR model.
38
40
  :param fit: If True, the function will be in fit mode. This means that the RDR will prompt the user for the
@@ -40,6 +42,9 @@ class RDRDecorator:
40
42
  classification mode. This means that the RDR will classify the function's output based on the RDR model.
41
43
  :param expert: The expert that will be used to prompt the user for the correct output. If None, a Human
42
44
  expert will be used.
45
+ :param ask_always: If True, the function will ask the user for a target if it doesn't exist.
46
+ :param update_existing_rules: If True, the function will update the existing RDR rules
47
+ even if they gave an output.
43
48
  :return: A decorator to use a GeneralRDR as a classifier that monitors and modifies the function's output.
44
49
  """
45
50
  self.rdr_models_dir = models_dir
@@ -47,10 +52,12 @@ class RDRDecorator:
47
52
  self.output_type = output_type
48
53
  self.parsed_output_type: List[Type] = []
49
54
  self.mutual_exclusive = mutual_exclusive
50
- self.rdr_python_path: Optional[str] = python_dir
51
55
  self.output_name = output_name
52
56
  self.fit: bool = fit
53
57
  self.expert: Optional[Expert] = expert
58
+ self.ask_always = ask_always
59
+ self.update_existing_rules = update_existing_rules
60
+ self.viewer = viewer
54
61
  self.load()
55
62
 
56
63
  def decorator(self, func: Callable) -> Callable:
@@ -62,61 +69,77 @@ class RDRDecorator:
62
69
  self.parsed_output_type = self.parse_output_type(func, self.output_type, *args)
63
70
  if self.model_name is None:
64
71
  self.initialize_rdr_model_name_and_load(func)
72
+ if self.expert is None:
73
+ self.expert = Human(viewer=self.viewer,
74
+ answers_save_path=self.rdr_models_dir + f'/expert_answers')
75
+
76
+ func_output = {self.output_name: func(*args, **kwargs)}
65
77
 
66
78
  if self.fit:
67
- expert_answers_path = os.path.join(self.rdr_models_dir, self.model_name, "expert_answers")
68
- self.expert = self.expert or Human(answers_save_path=expert_answers_path)
69
- case_query = self.create_case_query_from_method(func, self.parsed_output_type,
70
- self.mutual_exclusive, self.output_name,
79
+ case_query = self.create_case_query_from_method(func, func_output,
80
+ self.parsed_output_type,
81
+ self.mutual_exclusive,
71
82
  *args, **kwargs)
72
- output = self.rdr.fit_case(case_query, expert=self.expert)
83
+ output = self.rdr.fit_case(case_query, expert=self.expert,
84
+ ask_always_for_target=self.ask_always,
85
+ update_existing_rules=self.update_existing_rules,
86
+ viewer=self.viewer)
87
+ else:
88
+ case, case_dict = self.create_case_from_method(func, func_output, *args, **kwargs)
89
+ output = self.rdr.classify(case)
90
+
91
+ if self.output_name in output:
73
92
  return output[self.output_name]
74
93
  else:
75
- case, case_dict = self.create_case_from_method(func, self.output_name, *args, **kwargs)
76
- return self.rdr.classify(case)[self.output_name]
94
+ return func_output[self.output_name]
77
95
 
78
96
  return wrapper
79
97
 
80
98
  @staticmethod
81
- def create_case_query_from_method(func: Callable, output_type, mutual_exclusive: bool,
82
- output_name: str = 'output_', *args, **kwargs) -> CaseQuery:
99
+ def create_case_query_from_method(func: Callable,
100
+ func_output: Dict[str, Any],
101
+ output_type: Sequence[Type],
102
+ mutual_exclusive: bool,
103
+ *args, **kwargs) -> CaseQuery:
83
104
  """
84
105
  Create a CaseQuery from the function and its arguments.
85
106
 
86
107
  :param func: The function to create a case from.
87
- :param output_type: The type of the output.
108
+ :param func_output: The output of the function as a dictionary, where the key is the output name.
109
+ :param output_type: The type of the output as a sequence of types.
88
110
  :param mutual_exclusive: If True, the output types are mutually exclusive.
89
- :param output_name: The name of the output in the case. Defaults to 'output_'.
90
111
  :param args: The positional arguments of the function.
91
112
  :param kwargs: The keyword arguments of the function.
92
113
  :return: A CaseQuery object representing the case.
93
114
  """
94
115
  output_type = make_set(output_type)
95
- case, case_dict = RDRDecorator.create_case_from_method(func, output_name, *args, **kwargs)
116
+ case, case_dict = RDRDecorator.create_case_from_method(func, func_output, *args, **kwargs)
96
117
  scope = func.__globals__
97
118
  scope.update(case_dict)
98
119
  func_args_type_hints = get_type_hints(func)
120
+ output_name = list(func_output.keys())[0]
99
121
  func_args_type_hints.update({output_name: Union[tuple(output_type)]})
100
122
  return CaseQuery(case, output_name, Union[tuple(output_type)],
101
123
  mutual_exclusive, scope=scope,
102
124
  is_function=True, function_args_type_hints=func_args_type_hints)
103
125
 
104
126
  @staticmethod
105
- def create_case_from_method(func: Callable, output_name: str = "output_", *args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
127
+ def create_case_from_method(func: Callable,
128
+ func_output: Dict[str, Any],
129
+ *args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
106
130
  """
107
131
  Create a Case from the function and its arguments.
108
132
 
109
133
  :param func: The function to create a case from.
110
- :param output_name: The name of the output in the case. Defaults to 'output_'.
134
+ :param func_output: A dictionary containing the output of the function, where the key is the output name.
111
135
  :param args: The positional arguments of the function.
112
136
  :param kwargs: The keyword arguments of the function.
113
137
  :return: A Case object representing the case.
114
138
  """
115
139
  case_dict = get_method_args_as_dict(func, *args, **kwargs)
116
- func_output = func(*args, **kwargs)
117
- case_dict.update({output_name: func_output})
140
+ case_dict.update(func_output)
118
141
  case_name = get_func_rdr_model_name(func)
119
- return create_case(case_dict, obj_name=case_name, max_recursion_idx=3), case_dict
142
+ return Case(dict, id(case_dict), case_name, case_dict, **case_dict), case_dict
120
143
 
121
144
  def initialize_rdr_model_name_and_load(self, func: Callable) -> None:
122
145
  model_file_name = get_func_rdr_model_name(func, include_file_name=True)
@@ -148,10 +171,15 @@ class RDRDecorator:
148
171
  """
149
172
  Load the RDR model from the specified directory.
150
173
  """
151
- if self.model_name is not None and os.path.exists(os.path.join(self.rdr_models_dir, self.model_name)):
152
- self.rdr = GeneralRDR.load(self.rdr_models_dir, self.model_name)
153
- else:
154
- self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name)
174
+ self.rdr = None
175
+ if self.model_name is not None:
176
+ model_path = os.path.join(self.rdr_models_dir, self.model_name + f"/rdr_metadata/{self.model_name}.json")
177
+ if os.path.exists(os.path.join(self.rdr_models_dir, model_path)):
178
+ self.rdr = GeneralRDR.load(self.rdr_models_dir, self.model_name)
179
+ self.rdr.set_viewer(self.viewer)
180
+ if self.rdr is None:
181
+ self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name,
182
+ viewer=self.viewer)
155
183
 
156
184
  def update_from_python(self):
157
185
  """
@@ -3,16 +3,18 @@ from __future__ import annotations
3
3
  import logging
4
4
  import re
5
5
  from abc import ABC, abstractmethod
6
+ from pathlib import Path
6
7
  from uuid import uuid4
7
8
 
8
9
  from anytree import NodeMixin
9
10
  from sqlalchemy.orm import DeclarativeBase as SQLTable
10
- from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
11
+ from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple, Callable
11
12
 
12
13
  from .datastructures.callable_expression import CallableExpression
13
14
  from .datastructures.case import Case
15
+ from .datastructures.dataclasses import CaseFactoryMetaData, CaseConf, CaseQuery
14
16
  from .datastructures.enums import RDREdge, Stop
15
- from .utils import SubclassJSONSerializer, conclusion_to_json
17
+ from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_imports_from_types
16
18
 
17
19
 
18
20
  class Rule(NodeMixin, SubclassJSONSerializer, ABC):
@@ -27,7 +29,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
27
29
  corner_case: Optional[Union[Case, SQLTable]] = None,
28
30
  weight: Optional[str] = None,
29
31
  conclusion_name: Optional[str] = None,
30
- uid: Optional[str] = None):
32
+ uid: Optional[str] = None,
33
+ corner_case_metadata: Optional[CaseFactoryMetaData] = None):
31
34
  """
32
35
  A rule in the ripple down rules classifier.
33
36
 
@@ -38,10 +41,13 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
38
41
  :param weight: The weight of the rule, which is the type of edge connecting the rule to its parent.
39
42
  :param conclusion_name: The name of the conclusion of the rule.
40
43
  :param uid: The unique id of the rule.
44
+ :param corner_case_metadata: Metadata about the corner case, such as the factory that created it or the
45
+ scenario it is based on.
41
46
  """
42
47
  super(Rule, self).__init__()
43
48
  self.conclusion = conclusion
44
49
  self.corner_case = corner_case
50
+ self.corner_case_metadata: Optional[CaseFactoryMetaData] = corner_case_metadata
45
51
  self.parent = parent
46
52
  self.weight: Optional[str] = weight
47
53
  self.conditions = conditions if conditions else None
@@ -51,6 +57,20 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
51
57
  # generate a unique id for the rule using uuid4
52
58
  self.uid: str = uid if uid else str(uuid4().int)
53
59
 
60
+ @classmethod
61
+ def from_case_query(cls, case_query: CaseQuery) -> Rule:
62
+ """
63
+ Create a SingleClassRule from a CaseQuery.
64
+
65
+ :param case_query: The CaseQuery to create the rule from.
66
+ :return: A SingleClassRule instance.
67
+ """
68
+ corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
69
+ return cls(conditions=case_query.conditions, conclusion=case_query.target,
70
+ corner_case=case_query.case, parent=None,
71
+ corner_case_metadata=corner_case_metadata,
72
+ conclusion_name=case_query.attribute_name)
73
+
54
74
  def _post_detach(self, parent):
55
75
  """
56
76
  Called after this node is detached from the tree, useful when drawing the tree.
@@ -82,6 +102,27 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
82
102
  """
83
103
  pass
84
104
 
105
+ def write_corner_case_as_source_code(self, cases_file: Path) -> None:
106
+ """
107
+ Write the source code representation of the corner case of the rule to a file.
108
+
109
+ :param cases_file: The file to write the corner case to if it is a definition.
110
+ """
111
+ if self.corner_case_metadata is None:
112
+ return
113
+ types_to_import = set()
114
+ if self.corner_case_metadata.factory_method is not None:
115
+ types_to_import.add(self.corner_case_metadata.factory_method)
116
+ if self.corner_case_metadata.scenario is not None:
117
+ types_to_import.add(self.corner_case_metadata.scenario)
118
+ if self.corner_case_metadata.case_conf is not None:
119
+ types_to_import.add(self.corner_case_metadata.case_conf)
120
+ types_to_import.add(CaseFactoryMetaData)
121
+ imports = get_imports_from_types(list(types_to_import))
122
+ with open(cases_file, 'a') as f:
123
+ f.write("\n".join(imports) + "\n\n\n")
124
+ f.write(f"corner_case_{self.uid} = {self.corner_case_metadata}" + "\n\n\n")
125
+
85
126
  def write_conclusion_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
86
127
  """
87
128
  Get the source code representation of the conclusion of the rule.
@@ -150,11 +191,16 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
150
191
  pass
151
192
 
152
193
  def _to_json(self) -> Dict[str, Any]:
153
- json_serialization = {"conditions": self.conditions.to_json(),
194
+ try:
195
+ corner_case = SubclassJSONSerializer.to_json_static(self.corner_case) if self.corner_case else None
196
+ except Exception as e:
197
+ logging.debug("Failed to serialize corner case to json, setting it to None. Error: %s", e)
198
+ corner_case = None
199
+ json_serialization = {"_type": get_full_class_name(type(self)),
200
+ "conditions": self.conditions.to_json(),
154
201
  "conclusion": conclusion_to_json(self.conclusion),
155
202
  "parent": self.parent.json_serialization if self.parent else None,
156
- "corner_case": SubclassJSONSerializer.to_json_static(self.corner_case)
157
- if self.corner_case else None,
203
+ "corner_case": corner_case,
158
204
  "conclusion_name": self.conclusion_name,
159
205
  "weight": self.weight,
160
206
  "uid": self.uid}
@@ -277,9 +323,12 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
277
323
  returned_rule = self.alternative(x) if self.alternative else self
278
324
  return returned_rule if returned_rule.fired else self
279
325
 
280
- def fit_rule(self, x: Case, target: CallableExpression, conditions: CallableExpression):
281
- new_rule = SingleClassRule(conditions, target,
282
- corner_case=x, parent=self)
326
+ def fit_rule(self, case_query: CaseQuery):
327
+ corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
328
+ new_rule = SingleClassRule(case_query.conditions, case_query.target,
329
+ corner_case=case_query.case, parent=self,
330
+ corner_case_metadata=corner_case_metadata,
331
+ )
283
332
  if self.fired:
284
333
  self.refinement = new_rule
285
334
  else:
@@ -363,11 +412,12 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
363
412
  elif self.alternative: # Here alternative refers to next rule in MultiClassRDR
364
413
  return self.alternative
365
414
 
366
- def fit_rule(self, x: Case, target: CallableExpression, conditions: CallableExpression):
367
- if self.fired and target != self.conclusion:
368
- self.refinement = MultiClassStopRule(conditions, corner_case=x, parent=self)
415
+ def fit_rule(self, case_query: CaseQuery):
416
+ if self.fired and case_query.target != self.conclusion:
417
+ self.refinement = MultiClassStopRule(case_query.conditions, corner_case=case_query.case, parent=self)
369
418
  elif not self.fired:
370
- self.alternative = MultiClassTopRule(conditions, target, corner_case=x, parent=self)
419
+ self.alternative = MultiClassTopRule(case_query.conditions, case_query.target,
420
+ corner_case=case_query.case, parent=self)
371
421
 
372
422
  def _to_json(self) -> Dict[str, Any]:
373
423
  self.json_serialization = {**Rule._to_json(self),
@@ -281,12 +281,14 @@ class RDRCaseViewer(QMainWindow):
281
281
  main_obj: Optional[Dict[str, Any]] = None
282
282
  user_input: Optional[str] = None
283
283
  attributes_widget: Optional[QWidget] = None
284
- save_function: Optional[Callable[str], None] = None
284
+ save_function: Optional[Callable[str, str], None] = None
285
285
 
286
-
287
- def __init__(self, parent=None, save_file: Optional[str] = None):
286
+ def __init__(self, parent=None,
287
+ save_dir: Optional[str] = None,
288
+ save_model_name: Optional[str] = None):
288
289
  super().__init__(parent)
289
- self.save_file = save_file
290
+ self.save_dir = save_dir
291
+ self.save_model_name = save_model_name
290
292
 
291
293
  self.setWindowTitle("RDR Case Viewer")
292
294
 
@@ -323,17 +325,17 @@ class RDRCaseViewer(QMainWindow):
323
325
 
324
326
  # Add both to main layout
325
327
  main_layout.addWidget(self.attributes_widget, stretch=1)
326
- main_layout.addWidget(middle_widget, stretch=2)
328
+ main_layout.addWidget(middle_widget, stretch=1)
327
329
  main_layout.addWidget(self.obj_diagram_viewer, stretch=2)
328
330
 
329
- def set_save_function(self, save_function: Callable[[str], None]) -> None:
331
+ def set_save_function(self, save_function: Callable[[str, str], None]) -> None:
330
332
  """
331
333
  Set the function to save the file.
332
334
 
333
335
  :param save_function: The function to save the file.
334
336
  """
335
337
  self.save_function = save_function
336
- self.save_btn.clicked.connect(lambda: self.save_function(self.save_file))
338
+ self.save_btn.clicked.connect(lambda: self.save_function(self.save_dir, self.save_model_name))
337
339
 
338
340
  def print(self, msg):
339
341
  """
@@ -489,6 +491,7 @@ class RDRCaseViewer(QMainWindow):
489
491
  self.code_lines, self.template_file_creator.func_name,
490
492
  self.template_file_creator.function_signature,
491
493
  self.template_file_creator.func_doc, self.case_query)
494
+ self.case_query.scope.update(updates)
492
495
  self.template_file_creator = None
493
496
 
494
497
  def update_attribute_layout(self, obj, name: str):
@@ -134,7 +134,7 @@ class IPythonShell:
134
134
  """
135
135
  Update the user input from the code lines captured in the shell.
136
136
  """
137
- if len(self.shell.all_lines) == 1 and self.shell.all_lines[0].replace('return', '').strip() == '':
137
+ if self.shell.all_lines[0].replace('return', '').strip() == '':
138
138
  self.user_input = None
139
139
  else:
140
140
  self.all_code_lines = extract_dependencies(self.shell.all_lines)
@@ -1,5 +1,9 @@
1
1
  import logging
2
2
 
3
+ from ripple_down_rules.datastructures.case import Case
4
+ from ripple_down_rules.datastructures.dataclasses import CaseQuery
5
+ from ripple_down_rules.utils import SubclassJSONSerializer
6
+
3
7
  try:
4
8
  import graphviz
5
9
  except ImportError:
@@ -77,7 +81,11 @@ def generate_object_graph(obj, name='root', seen=None, graph=None, current_depth
77
81
  for attr in dir(obj):
78
82
  if attr.startswith('_'):
79
83
  continue
80
- if attr == 'scope':
84
+ if isinstance(obj, CaseQuery) and attr == 'scope':
85
+ continue
86
+ if isinstance(obj, Case) and attr in ['data']:
87
+ continue
88
+ if isinstance(obj, SubclassJSONSerializer) and attr == 'data_class_refs':
81
89
  continue
82
90
  value = getattr(obj, attr)
83
91
  if callable(value):
@@ -8,13 +8,13 @@ from functools import cached_property
8
8
  from textwrap import indent, dedent
9
9
 
10
10
  from colorama import Fore, Style
11
- from typing_extensions import Optional, Type, List, Callable, Tuple, Dict
11
+ from typing_extensions import Optional, Type, List, Callable, Tuple, Dict, Any, Union
12
12
 
13
13
  from ..datastructures.case import Case
14
14
  from ..datastructures.dataclasses import CaseQuery
15
15
  from ..datastructures.enums import Editor, PromptFor
16
- from ..utils import str_to_snake_case, get_imports_from_scope, make_list, typing_hint_to_str, \
17
- get_imports_from_types, extract_function_source, extract_imports
16
+ from ..utils import str_to_snake_case, get_imports_from_scope, make_list, stringify_hint, \
17
+ get_imports_from_types, extract_function_source, extract_imports, get_types_to_import_from_type_hints
18
18
 
19
19
 
20
20
  def detect_available_editor() -> Optional[Editor]:
@@ -84,6 +84,7 @@ class TemplateFileCreator:
84
84
  self.func_doc: str = self.get_func_doc()
85
85
  self.function_signature: str = self.get_function_signature()
86
86
  self.editor: Optional[Editor] = detect_available_editor()
87
+ self.editor_cmd: Optional[str] = os.environ.get("RDR_EDITOR_CMD")
87
88
  self.workspace: str = os.environ.get("RDR_EDITOR_WORKSPACE", os.path.dirname(self.case_query.scope['__file__']))
88
89
  self.temp_file_path: str = os.path.join(self.workspace, "edit_code_here.py")
89
90
 
@@ -98,7 +99,7 @@ class TemplateFileCreator:
98
99
  return make_list(output_type) if output_type is not None else None
99
100
 
100
101
  def edit(self):
101
- if self.editor is None:
102
+ if self.editor is None and self.editor_cmd is None:
102
103
  self.print_func(
103
104
  f"{Fore.RED}ERROR:: No editor found. Please install PyCharm, VSCode or code-server.{Style.RESET_ALL}")
104
105
  return
@@ -112,7 +113,11 @@ class TemplateFileCreator:
112
113
  """
113
114
  Open the file in the available editor.
114
115
  """
115
- if self.editor == Editor.Pycharm:
116
+ if self.editor_cmd is not None:
117
+ subprocess.Popen([self.editor_cmd, self.temp_file_path],
118
+ stdout=subprocess.DEVNULL,
119
+ stderr=subprocess.DEVNULL)
120
+ elif self.editor == Editor.Pycharm:
116
121
  subprocess.Popen(["pycharm", "--line", str(self.user_edit_line), self.temp_file_path],
117
122
  stdout=subprocess.DEVNULL,
118
123
  stderr=subprocess.DEVNULL)
@@ -172,7 +177,7 @@ class TemplateFileCreator:
172
177
  for k, v in self.case_query.case.items():
173
178
  if (self.case_query.function_args_type_hints is not None
174
179
  and k in self.case_query.function_args_type_hints):
175
- func_args[k] = typing_hint_to_str(self.case_query.function_args_type_hints[k])[0]
180
+ func_args[k] = stringify_hint(self.case_query.function_args_type_hints[k])
176
181
  else:
177
182
  func_args[k] = type(v).__name__ if not isinstance(v, type) else f"Type[{v.__name__}]"
178
183
  func_args = ', '.join([f"{k}: {v}" if str(v) not in ["NoneType", "None"] else str(k)
@@ -202,30 +207,25 @@ class TemplateFileCreator:
202
207
  for k, v in self.case_query.case.items():
203
208
  if (self.case_query.function_args_type_hints is not None
204
209
  and k in self.case_query.function_args_type_hints):
205
- hint_list = typing_hint_to_str(self.case_query.function_args_type_hints[k])[1]
206
- for hint in hint_list:
207
- hint_split = hint.split('.')
208
- if len(hint_split) > 1:
209
- case_type_imports.append(f"from {'.'.join(hint_split[:-1])} import {hint_split[-1]}")
210
+ types_to_import = get_types_to_import_from_type_hints([self.case_query.function_args_type_hints[k]])
211
+ case_type_imports.extend(list(types_to_import))
210
212
  else:
211
- if isinstance(v, type):
212
- case_type_imports.append(f"from {v.__module__} import {v.__name__}")
213
- elif hasattr(v, "__module__") and not v.__module__.startswith("__"):
214
- case_type_imports.append(f"\nfrom {type(v).__module__} import {type(v).__name__}")
213
+ case_type_imports.append(v)
215
214
  else:
216
- case_type_imports.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
215
+ case_type_imports.append(self.case_type)
217
216
  if self.output_type is None:
218
- output_type_imports = [f"from typing_extensions import Any"]
217
+ output_type_imports = [Any]
219
218
  else:
220
- output_type_imports = get_imports_from_types(self.output_type)
219
+ output_type_imports = self.output_type
221
220
  if len(self.output_type) > 1:
222
- output_type_imports.append("from typing_extensions import Union")
221
+ output_type_imports.append(Union)
223
222
  if list in self.output_type:
224
- output_type_imports.append("from typing_extensions import List")
225
- imports = get_imports_from_scope(self.case_query.scope)
226
- imports = [i for i in imports if ("get_ipython" not in i)]
227
- imports.extend(case_type_imports)
228
- imports.extend([oti for oti in output_type_imports if oti not in imports])
223
+ output_type_imports.append(List)
224
+ import_types = list(self.case_query.scope.values())
225
+ # imports = [i for i in imports if ("get_ipython" not in i)]
226
+ import_types.extend(case_type_imports)
227
+ import_types.extend(output_type_imports)
228
+ imports = get_imports_from_types(import_types)
229
229
  imports = set(imports)
230
230
  return '\n'.join(imports)
231
231
 
@@ -300,6 +300,7 @@ class TemplateFileCreator:
300
300
  if isinstance(node, ast.FunctionDef) and node.name == func_name:
301
301
  exec_globals = {}
302
302
  scope = extract_imports(tree=tree)
303
+ updates.update(scope)
303
304
  exec(source, scope, exec_globals)
304
305
  user_function = exec_globals[func_name]
305
306
  updates[func_name] = user_function