ripple-down-rules 0.5.5__py3-none-any.whl → 0.5.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,16 +6,18 @@ of the RDRs.
6
6
  import os.path
7
7
  from functools import wraps
8
8
 
9
- from pyparsing.tools.cvt_pyparsing_pep8_names import camel_to_snake
10
- from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union
9
+ from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union, Sequence
11
10
 
12
- from ripple_down_rules.datastructures.case import create_case, Case
11
+ from ripple_down_rules.datastructures.case import Case
13
12
  from ripple_down_rules.datastructures.dataclasses import CaseQuery
14
- from ripple_down_rules.datastructures.enums import Category
15
13
  from ripple_down_rules.experts import Expert, Human
16
- from ripple_down_rules.rdr import GeneralRDR, RippleDownRules
14
+ from ripple_down_rules.rdr import GeneralRDR
15
+ try:
16
+ from ripple_down_rules.user_interface.gui import RDRCaseViewer
17
+ except ImportError:
18
+ RDRCaseViewer = None
17
19
  from ripple_down_rules.utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
18
- get_method_class_if_exists, get_method_name, str_to_snake_case
20
+ get_method_class_if_exists, str_to_snake_case
19
21
 
20
22
 
21
23
  class RDRDecorator:
@@ -24,15 +26,16 @@ class RDRDecorator:
24
26
  def __init__(self, models_dir: str,
25
27
  output_type: Tuple[Type],
26
28
  mutual_exclusive: bool,
27
- python_dir: Optional[str] = None,
28
29
  output_name: str = "output_",
29
30
  fit: bool = True,
30
- expert: Optional[Expert] = None):
31
+ expert: Optional[Expert] = None,
32
+ ask_always: bool = False,
33
+ update_existing_rules: bool = True,
34
+ viewer: Optional[RDRCaseViewer] = None):
31
35
  """
32
36
  :param models_dir: The directory to save/load the RDR models.
33
37
  :param output_type: The type of the output. This is used to create the RDR model.
34
38
  :param mutual_exclusive: If True, the output types are mutually exclusive.
35
- :param python_dir: The directory to save the RDR model as a python file.
36
39
  If None, the RDR model will not be saved as a python file.
37
40
  :param output_name: The name of the output. This is used to create the RDR model.
38
41
  :param fit: If True, the function will be in fit mode. This means that the RDR will prompt the user for the
@@ -40,6 +43,9 @@ class RDRDecorator:
40
43
  classification mode. This means that the RDR will classify the function's output based on the RDR model.
41
44
  :param expert: The expert that will be used to prompt the user for the correct output. If None, a Human
42
45
  expert will be used.
46
+ :param ask_always: If True, the function will ask the user for a target if it doesn't exist.
47
+ :param update_existing_rules: If True, the function will update the existing RDR rules
48
+ even if they gave an output.
43
49
  :return: A decorator to use a GeneralRDR as a classifier that monitors and modifies the function's output.
44
50
  """
45
51
  self.rdr_models_dir = models_dir
@@ -47,10 +53,12 @@ class RDRDecorator:
47
53
  self.output_type = output_type
48
54
  self.parsed_output_type: List[Type] = []
49
55
  self.mutual_exclusive = mutual_exclusive
50
- self.rdr_python_path: Optional[str] = python_dir
51
56
  self.output_name = output_name
52
57
  self.fit: bool = fit
53
58
  self.expert: Optional[Expert] = expert
59
+ self.ask_always = ask_always
60
+ self.update_existing_rules = update_existing_rules
61
+ self.viewer = viewer
54
62
  self.load()
55
63
 
56
64
  def decorator(self, func: Callable) -> Callable:
@@ -62,61 +70,77 @@ class RDRDecorator:
62
70
  self.parsed_output_type = self.parse_output_type(func, self.output_type, *args)
63
71
  if self.model_name is None:
64
72
  self.initialize_rdr_model_name_and_load(func)
73
+ if self.expert is None:
74
+ self.expert = Human(viewer=self.viewer,
75
+ answers_save_path=self.rdr_models_dir + f'/expert_answers')
76
+
77
+ func_output = {self.output_name: func(*args, **kwargs)}
65
78
 
66
79
  if self.fit:
67
- expert_answers_path = os.path.join(self.rdr_models_dir, self.model_name, "expert_answers")
68
- self.expert = self.expert or Human(answers_save_path=expert_answers_path)
69
- case_query = self.create_case_query_from_method(func, self.parsed_output_type,
70
- self.mutual_exclusive, self.output_name,
80
+ case_query = self.create_case_query_from_method(func, func_output,
81
+ self.parsed_output_type,
82
+ self.mutual_exclusive,
71
83
  *args, **kwargs)
72
- output = self.rdr.fit_case(case_query, expert=self.expert)
84
+ output = self.rdr.fit_case(case_query, expert=self.expert,
85
+ ask_always_for_target=self.ask_always,
86
+ update_existing_rules=self.update_existing_rules,
87
+ viewer=self.viewer)
88
+ else:
89
+ case, case_dict = self.create_case_from_method(func, func_output, *args, **kwargs)
90
+ output = self.rdr.classify(case)
91
+
92
+ if self.output_name in output:
73
93
  return output[self.output_name]
74
94
  else:
75
- case, case_dict = self.create_case_from_method(func, self.output_name, *args, **kwargs)
76
- return self.rdr.classify(case)[self.output_name]
95
+ return func_output[self.output_name]
77
96
 
78
97
  return wrapper
79
98
 
80
99
  @staticmethod
81
- def create_case_query_from_method(func: Callable, output_type, mutual_exclusive: bool,
82
- output_name: str = 'output_', *args, **kwargs) -> CaseQuery:
100
+ def create_case_query_from_method(func: Callable,
101
+ func_output: Dict[str, Any],
102
+ output_type: Sequence[Type],
103
+ mutual_exclusive: bool,
104
+ *args, **kwargs) -> CaseQuery:
83
105
  """
84
106
  Create a CaseQuery from the function and its arguments.
85
107
 
86
108
  :param func: The function to create a case from.
87
- :param output_type: The type of the output.
109
+ :param func_output: The output of the function as a dictionary, where the key is the output name.
110
+ :param output_type: The type of the output as a sequence of types.
88
111
  :param mutual_exclusive: If True, the output types are mutually exclusive.
89
- :param output_name: The name of the output in the case. Defaults to 'output_'.
90
112
  :param args: The positional arguments of the function.
91
113
  :param kwargs: The keyword arguments of the function.
92
114
  :return: A CaseQuery object representing the case.
93
115
  """
94
116
  output_type = make_set(output_type)
95
- case, case_dict = RDRDecorator.create_case_from_method(func, output_name, *args, **kwargs)
117
+ case, case_dict = RDRDecorator.create_case_from_method(func, func_output, *args, **kwargs)
96
118
  scope = func.__globals__
97
119
  scope.update(case_dict)
98
120
  func_args_type_hints = get_type_hints(func)
121
+ output_name = list(func_output.keys())[0]
99
122
  func_args_type_hints.update({output_name: Union[tuple(output_type)]})
100
123
  return CaseQuery(case, output_name, Union[tuple(output_type)],
101
124
  mutual_exclusive, scope=scope,
102
125
  is_function=True, function_args_type_hints=func_args_type_hints)
103
126
 
104
127
  @staticmethod
105
- def create_case_from_method(func: Callable, output_name: str = "output_", *args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
128
+ def create_case_from_method(func: Callable,
129
+ func_output: Dict[str, Any],
130
+ *args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
106
131
  """
107
132
  Create a Case from the function and its arguments.
108
133
 
109
134
  :param func: The function to create a case from.
110
- :param output_name: The name of the output in the case. Defaults to 'output_'.
135
+ :param func_output: A dictionary containing the output of the function, where the key is the output name.
111
136
  :param args: The positional arguments of the function.
112
137
  :param kwargs: The keyword arguments of the function.
113
138
  :return: A Case object representing the case.
114
139
  """
115
140
  case_dict = get_method_args_as_dict(func, *args, **kwargs)
116
- func_output = func(*args, **kwargs)
117
- case_dict.update({output_name: func_output})
141
+ case_dict.update(func_output)
118
142
  case_name = get_func_rdr_model_name(func)
119
- return create_case(case_dict, obj_name=case_name, max_recursion_idx=3), case_dict
143
+ return Case(dict, id(case_dict), case_name, case_dict, **case_dict), case_dict
120
144
 
121
145
  def initialize_rdr_model_name_and_load(self, func: Callable) -> None:
122
146
  model_file_name = get_func_rdr_model_name(func, include_file_name=True)
@@ -148,10 +172,15 @@ class RDRDecorator:
148
172
  """
149
173
  Load the RDR model from the specified directory.
150
174
  """
151
- if self.model_name is not None and os.path.exists(os.path.join(self.rdr_models_dir, self.model_name)):
152
- self.rdr = GeneralRDR.load(self.rdr_models_dir, self.model_name)
153
- else:
154
- self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name)
175
+ self.rdr = None
176
+ if self.model_name is not None:
177
+ model_path = os.path.join(self.rdr_models_dir, self.model_name + f"/rdr_metadata/{self.model_name}.json")
178
+ if os.path.exists(os.path.join(self.rdr_models_dir, model_path)):
179
+ self.rdr = GeneralRDR.load(self.rdr_models_dir, self.model_name)
180
+ self.rdr.set_viewer(self.viewer)
181
+ if self.rdr is None:
182
+ self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name,
183
+ viewer=self.viewer)
155
184
 
156
185
  def update_from_python(self):
157
186
  """
@@ -3,16 +3,18 @@ from __future__ import annotations
3
3
  import logging
4
4
  import re
5
5
  from abc import ABC, abstractmethod
6
+ from pathlib import Path
6
7
  from uuid import uuid4
7
8
 
8
9
  from anytree import NodeMixin
9
10
  from sqlalchemy.orm import DeclarativeBase as SQLTable
10
- from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
11
+ from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple, Type, Set
11
12
 
12
13
  from .datastructures.callable_expression import CallableExpression
13
14
  from .datastructures.case import Case
15
+ from .datastructures.dataclasses import CaseFactoryMetaData, CaseQuery
14
16
  from .datastructures.enums import RDREdge, Stop
15
- from .utils import SubclassJSONSerializer, conclusion_to_json
17
+ from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_imports_from_types
16
18
 
17
19
 
18
20
  class Rule(NodeMixin, SubclassJSONSerializer, ABC):
@@ -27,7 +29,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
27
29
  corner_case: Optional[Union[Case, SQLTable]] = None,
28
30
  weight: Optional[str] = None,
29
31
  conclusion_name: Optional[str] = None,
30
- uid: Optional[str] = None):
32
+ uid: Optional[str] = None,
33
+ corner_case_metadata: Optional[CaseFactoryMetaData] = None):
31
34
  """
32
35
  A rule in the ripple down rules classifier.
33
36
 
@@ -38,10 +41,13 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
38
41
  :param weight: The weight of the rule, which is the type of edge connecting the rule to its parent.
39
42
  :param conclusion_name: The name of the conclusion of the rule.
40
43
  :param uid: The unique id of the rule.
44
+ :param corner_case_metadata: Metadata about the corner case, such as the factory that created it or the
45
+ scenario it is based on.
41
46
  """
42
47
  super(Rule, self).__init__()
43
48
  self.conclusion = conclusion
44
49
  self.corner_case = corner_case
50
+ self.corner_case_metadata: Optional[CaseFactoryMetaData] = corner_case_metadata
45
51
  self.parent = parent
46
52
  self.weight: Optional[str] = weight
47
53
  self.conditions = conditions if conditions else None
@@ -51,6 +57,20 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
51
57
  # generate a unique id for the rule using uuid4
52
58
  self.uid: str = uid if uid else str(uuid4().int)
53
59
 
60
+ @classmethod
61
+ def from_case_query(cls, case_query: CaseQuery) -> Rule:
62
+ """
63
+ Create a SingleClassRule from a CaseQuery.
64
+
65
+ :param case_query: The CaseQuery to create the rule from.
66
+ :return: A SingleClassRule instance.
67
+ """
68
+ corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
69
+ return cls(conditions=case_query.conditions, conclusion=case_query.target,
70
+ corner_case=case_query.case, parent=None,
71
+ corner_case_metadata=corner_case_metadata,
72
+ conclusion_name=case_query.attribute_name)
73
+
54
74
  def _post_detach(self, parent):
55
75
  """
56
76
  Called after this node is detached from the tree, useful when drawing the tree.
@@ -82,6 +102,33 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
82
102
  """
83
103
  pass
84
104
 
105
+ def write_corner_case_as_source_code(self, cases_file: str, package_name: Optional[str] = None) -> None:
106
+ """
107
+ Write the source code representation of the corner case of the rule to a file.
108
+
109
+ :param cases_file: The file to write the corner case to.
110
+ :param package_name: The package name to use for relative imports.
111
+ """
112
+ if self.corner_case_metadata is None:
113
+ return
114
+ with open(cases_file, 'a') as f:
115
+ f.write(f"corner_case_{self.uid} = {self.corner_case_metadata}" + "\n\n\n")
116
+
117
+ def get_corner_case_types_to_import(self) -> Set[Type]:
118
+ """
119
+ Get the types that need to be imported for the corner case of the rule.
120
+ """
121
+ if self.corner_case_metadata is None:
122
+ return
123
+ types_to_import = set()
124
+ if self.corner_case_metadata.factory_method is not None:
125
+ types_to_import.add(self.corner_case_metadata.factory_method)
126
+ if self.corner_case_metadata.scenario is not None:
127
+ types_to_import.add(self.corner_case_metadata.scenario)
128
+ if self.corner_case_metadata.case_conf is not None:
129
+ types_to_import.add(self.corner_case_metadata.case_conf)
130
+ return types_to_import
131
+
85
132
  def write_conclusion_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
86
133
  """
87
134
  Get the source code representation of the conclusion of the rule.
@@ -150,11 +197,16 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
150
197
  pass
151
198
 
152
199
  def _to_json(self) -> Dict[str, Any]:
153
- json_serialization = {"conditions": self.conditions.to_json(),
200
+ try:
201
+ corner_case = SubclassJSONSerializer.to_json_static(self.corner_case) if self.corner_case else None
202
+ except Exception as e:
203
+ logging.debug("Failed to serialize corner case to json, setting it to None. Error: %s", e)
204
+ corner_case = None
205
+ json_serialization = {"_type": get_full_class_name(type(self)),
206
+ "conditions": self.conditions.to_json(),
154
207
  "conclusion": conclusion_to_json(self.conclusion),
155
208
  "parent": self.parent.json_serialization if self.parent else None,
156
- "corner_case": SubclassJSONSerializer.to_json_static(self.corner_case)
157
- if self.corner_case else None,
209
+ "corner_case": corner_case,
158
210
  "conclusion_name": self.conclusion_name,
159
211
  "weight": self.weight,
160
212
  "uid": self.uid}
@@ -277,9 +329,12 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
277
329
  returned_rule = self.alternative(x) if self.alternative else self
278
330
  return returned_rule if returned_rule.fired else self
279
331
 
280
- def fit_rule(self, x: Case, target: CallableExpression, conditions: CallableExpression):
281
- new_rule = SingleClassRule(conditions, target,
282
- corner_case=x, parent=self)
332
+ def fit_rule(self, case_query: CaseQuery):
333
+ corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
334
+ new_rule = SingleClassRule(case_query.conditions, case_query.target,
335
+ corner_case=case_query.case, parent=self,
336
+ corner_case_metadata=corner_case_metadata,
337
+ )
283
338
  if self.fired:
284
339
  self.refinement = new_rule
285
340
  else:
@@ -363,11 +418,12 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
363
418
  elif self.alternative: # Here alternative refers to next rule in MultiClassRDR
364
419
  return self.alternative
365
420
 
366
- def fit_rule(self, x: Case, target: CallableExpression, conditions: CallableExpression):
367
- if self.fired and target != self.conclusion:
368
- self.refinement = MultiClassStopRule(conditions, corner_case=x, parent=self)
421
+ def fit_rule(self, case_query: CaseQuery):
422
+ if self.fired and case_query.target != self.conclusion:
423
+ self.refinement = MultiClassStopRule(case_query.conditions, corner_case=case_query.case, parent=self)
369
424
  elif not self.fired:
370
- self.alternative = MultiClassTopRule(conditions, target, corner_case=x, parent=self)
425
+ self.alternative = MultiClassTopRule(case_query.conditions, case_query.target,
426
+ corner_case=case_query.case, parent=self)
371
427
 
372
428
  def _to_json(self) -> Dict[str, Any]:
373
429
  self.json_serialization = {**Rule._to_json(self),
@@ -281,12 +281,14 @@ class RDRCaseViewer(QMainWindow):
281
281
  main_obj: Optional[Dict[str, Any]] = None
282
282
  user_input: Optional[str] = None
283
283
  attributes_widget: Optional[QWidget] = None
284
- save_function: Optional[Callable[str], None] = None
284
+ save_function: Optional[Callable[str, str], None] = None
285
285
 
286
-
287
- def __init__(self, parent=None, save_file: Optional[str] = None):
286
+ def __init__(self, parent=None,
287
+ save_dir: Optional[str] = None,
288
+ save_model_name: Optional[str] = None):
288
289
  super().__init__(parent)
289
- self.save_file = save_file
290
+ self.save_dir = save_dir
291
+ self.save_model_name = save_model_name
290
292
 
291
293
  self.setWindowTitle("RDR Case Viewer")
292
294
 
@@ -323,17 +325,17 @@ class RDRCaseViewer(QMainWindow):
323
325
 
324
326
  # Add both to main layout
325
327
  main_layout.addWidget(self.attributes_widget, stretch=1)
326
- main_layout.addWidget(middle_widget, stretch=2)
328
+ main_layout.addWidget(middle_widget, stretch=1)
327
329
  main_layout.addWidget(self.obj_diagram_viewer, stretch=2)
328
330
 
329
- def set_save_function(self, save_function: Callable[[str], None]) -> None:
331
+ def set_save_function(self, save_function: Callable[[str, str], None]) -> None:
330
332
  """
331
333
  Set the function to save the file.
332
334
 
333
335
  :param save_function: The function to save the file.
334
336
  """
335
337
  self.save_function = save_function
336
- self.save_btn.clicked.connect(lambda: self.save_function(self.save_file))
338
+ self.save_btn.clicked.connect(lambda: self.save_function(self.save_dir, self.save_model_name))
337
339
 
338
340
  def print(self, msg):
339
341
  """
@@ -489,6 +491,7 @@ class RDRCaseViewer(QMainWindow):
489
491
  self.code_lines, self.template_file_creator.func_name,
490
492
  self.template_file_creator.function_signature,
491
493
  self.template_file_creator.func_doc, self.case_query)
494
+ self.case_query.scope.update(updates)
492
495
  self.template_file_creator = None
493
496
 
494
497
  def update_attribute_layout(self, obj, name: str):
@@ -134,7 +134,7 @@ class IPythonShell:
134
134
  """
135
135
  Update the user input from the code lines captured in the shell.
136
136
  """
137
- if len(self.shell.all_lines) == 1 and self.shell.all_lines[0].replace('return', '').strip() == '':
137
+ if self.shell.all_lines[0].replace('return', '').strip() == '':
138
138
  self.user_input = None
139
139
  else:
140
140
  self.all_code_lines = extract_dependencies(self.shell.all_lines)
@@ -1,5 +1,9 @@
1
1
  import logging
2
2
 
3
+ from ripple_down_rules.datastructures.case import Case
4
+ from ripple_down_rules.datastructures.dataclasses import CaseQuery
5
+ from ripple_down_rules.utils import SubclassJSONSerializer
6
+
3
7
  try:
4
8
  import graphviz
5
9
  except ImportError:
@@ -77,7 +81,11 @@ def generate_object_graph(obj, name='root', seen=None, graph=None, current_depth
77
81
  for attr in dir(obj):
78
82
  if attr.startswith('_'):
79
83
  continue
80
- if attr == 'scope':
84
+ if isinstance(obj, CaseQuery) and attr == 'scope':
85
+ continue
86
+ if isinstance(obj, Case) and attr in ['data']:
87
+ continue
88
+ if isinstance(obj, SubclassJSONSerializer) and attr == 'data_class_refs':
81
89
  continue
82
90
  value = getattr(obj, attr)
83
91
  if callable(value):
@@ -8,13 +8,13 @@ from functools import cached_property
8
8
  from textwrap import indent, dedent
9
9
 
10
10
  from colorama import Fore, Style
11
- from typing_extensions import Optional, Type, List, Callable, Tuple, Dict
11
+ from typing_extensions import Optional, Type, List, Callable, Tuple, Dict, Any, Union
12
12
 
13
13
  from ..datastructures.case import Case
14
14
  from ..datastructures.dataclasses import CaseQuery
15
15
  from ..datastructures.enums import Editor, PromptFor
16
- from ..utils import str_to_snake_case, get_imports_from_scope, make_list, typing_hint_to_str, \
17
- get_imports_from_types, extract_function_source, extract_imports
16
+ from ..utils import str_to_snake_case, get_imports_from_scope, make_list, stringify_hint, \
17
+ get_imports_from_types, extract_function_source, extract_imports, get_types_to_import_from_type_hints
18
18
 
19
19
 
20
20
  def detect_available_editor() -> Optional[Editor]:
@@ -84,6 +84,7 @@ class TemplateFileCreator:
84
84
  self.func_doc: str = self.get_func_doc()
85
85
  self.function_signature: str = self.get_function_signature()
86
86
  self.editor: Optional[Editor] = detect_available_editor()
87
+ self.editor_cmd: Optional[str] = os.environ.get("RDR_EDITOR_CMD")
87
88
  self.workspace: str = os.environ.get("RDR_EDITOR_WORKSPACE", os.path.dirname(self.case_query.scope['__file__']))
88
89
  self.temp_file_path: str = os.path.join(self.workspace, "edit_code_here.py")
89
90
 
@@ -98,7 +99,7 @@ class TemplateFileCreator:
98
99
  return make_list(output_type) if output_type is not None else None
99
100
 
100
101
  def edit(self):
101
- if self.editor is None:
102
+ if self.editor is None and self.editor_cmd is None:
102
103
  self.print_func(
103
104
  f"{Fore.RED}ERROR:: No editor found. Please install PyCharm, VSCode or code-server.{Style.RESET_ALL}")
104
105
  return
@@ -112,7 +113,11 @@ class TemplateFileCreator:
112
113
  """
113
114
  Open the file in the available editor.
114
115
  """
115
- if self.editor == Editor.Pycharm:
116
+ if self.editor_cmd is not None:
117
+ subprocess.Popen([self.editor_cmd, self.temp_file_path],
118
+ stdout=subprocess.DEVNULL,
119
+ stderr=subprocess.DEVNULL)
120
+ elif self.editor == Editor.Pycharm:
116
121
  subprocess.Popen(["pycharm", "--line", str(self.user_edit_line), self.temp_file_path],
117
122
  stdout=subprocess.DEVNULL,
118
123
  stderr=subprocess.DEVNULL)
@@ -172,7 +177,7 @@ class TemplateFileCreator:
172
177
  for k, v in self.case_query.case.items():
173
178
  if (self.case_query.function_args_type_hints is not None
174
179
  and k in self.case_query.function_args_type_hints):
175
- func_args[k] = typing_hint_to_str(self.case_query.function_args_type_hints[k])[0]
180
+ func_args[k] = stringify_hint(self.case_query.function_args_type_hints[k])
176
181
  else:
177
182
  func_args[k] = type(v).__name__ if not isinstance(v, type) else f"Type[{v.__name__}]"
178
183
  func_args = ', '.join([f"{k}: {v}" if str(v) not in ["NoneType", "None"] else str(k)
@@ -193,7 +198,7 @@ class TemplateFileCreator:
193
198
  with open(self.temp_file_path, 'w+') as f:
194
199
  f.write(code)
195
200
 
196
- def get_imports(self):
201
+ def get_imports(self) -> str:
197
202
  """
198
203
  :return: A string containing the imports for the function.
199
204
  """
@@ -202,30 +207,25 @@ class TemplateFileCreator:
202
207
  for k, v in self.case_query.case.items():
203
208
  if (self.case_query.function_args_type_hints is not None
204
209
  and k in self.case_query.function_args_type_hints):
205
- hint_list = typing_hint_to_str(self.case_query.function_args_type_hints[k])[1]
206
- for hint in hint_list:
207
- hint_split = hint.split('.')
208
- if len(hint_split) > 1:
209
- case_type_imports.append(f"from {'.'.join(hint_split[:-1])} import {hint_split[-1]}")
210
+ types_to_import = get_types_to_import_from_type_hints([self.case_query.function_args_type_hints[k]])
211
+ case_type_imports.extend(list(types_to_import))
210
212
  else:
211
- if isinstance(v, type):
212
- case_type_imports.append(f"from {v.__module__} import {v.__name__}")
213
- elif hasattr(v, "__module__") and not v.__module__.startswith("__"):
214
- case_type_imports.append(f"\nfrom {type(v).__module__} import {type(v).__name__}")
213
+ case_type_imports.append(v)
215
214
  else:
216
- case_type_imports.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
215
+ case_type_imports.append(self.case_type)
217
216
  if self.output_type is None:
218
- output_type_imports = [f"from typing_extensions import Any"]
217
+ output_type_imports = [Any]
219
218
  else:
220
- output_type_imports = get_imports_from_types(self.output_type)
219
+ output_type_imports = self.output_type
221
220
  if len(self.output_type) > 1:
222
- output_type_imports.append("from typing_extensions import Union")
221
+ output_type_imports.append(Union)
223
222
  if list in self.output_type:
224
- output_type_imports.append("from typing_extensions import List")
225
- imports = get_imports_from_scope(self.case_query.scope)
226
- imports = [i for i in imports if ("get_ipython" not in i)]
227
- imports.extend(case_type_imports)
228
- imports.extend([oti for oti in output_type_imports if oti not in imports])
223
+ output_type_imports.append(List)
224
+ import_types = list(self.case_query.scope.values())
225
+ # imports = [i for i in imports if ("get_ipython" not in i)]
226
+ import_types.extend(case_type_imports)
227
+ import_types.extend(output_type_imports)
228
+ imports = get_imports_from_types(import_types)
229
229
  imports = set(imports)
230
230
  return '\n'.join(imports)
231
231