ripple-down-rules 0.5.63__py3-none-any.whl → 0.5.71__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +1 -1
- ripple_down_rules/datastructures/case.py +10 -4
- ripple_down_rules/datastructures/dataclasses.py +62 -3
- ripple_down_rules/helpers.py +55 -9
- ripple_down_rules/rdr.py +141 -101
- ripple_down_rules/rdr_decorators.py +54 -23
- ripple_down_rules/rules.py +63 -13
- ripple_down_rules/user_interface/gui.py +9 -7
- ripple_down_rules/user_interface/ipython_custom_shell.py +1 -1
- ripple_down_rules/user_interface/object_diagram.py +9 -1
- ripple_down_rules/user_interface/template_file_creator.py +17 -22
- ripple_down_rules/utils.py +235 -62
- {ripple_down_rules-0.5.63.dist-info → ripple_down_rules-0.5.71.dist-info}/METADATA +2 -1
- ripple_down_rules-0.5.71.dist-info/RECORD +24 -0
- ripple_down_rules-0.5.63.dist-info/RECORD +0 -24
- {ripple_down_rules-0.5.63.dist-info → ripple_down_rules-0.5.71.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.5.63.dist-info → ripple_down_rules-0.5.71.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.5.63.dist-info → ripple_down_rules-0.5.71.dist-info}/top_level.txt +0 -0
@@ -6,16 +6,18 @@ of the RDRs.
|
|
6
6
|
import os.path
|
7
7
|
from functools import wraps
|
8
8
|
|
9
|
-
from
|
10
|
-
from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union
|
9
|
+
from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union, Sequence
|
11
10
|
|
12
|
-
from ripple_down_rules.datastructures.case import
|
11
|
+
from ripple_down_rules.datastructures.case import Case
|
13
12
|
from ripple_down_rules.datastructures.dataclasses import CaseQuery
|
14
|
-
from ripple_down_rules.datastructures.enums import Category
|
15
13
|
from ripple_down_rules.experts import Expert, Human
|
16
|
-
from ripple_down_rules.rdr import GeneralRDR
|
14
|
+
from ripple_down_rules.rdr import GeneralRDR
|
15
|
+
try:
|
16
|
+
from ripple_down_rules.user_interface.gui import RDRCaseViewer
|
17
|
+
except ImportError:
|
18
|
+
RDRCaseViewer = None
|
17
19
|
from ripple_down_rules.utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
|
18
|
-
get_method_class_if_exists,
|
20
|
+
get_method_class_if_exists, str_to_snake_case
|
19
21
|
|
20
22
|
|
21
23
|
class RDRDecorator:
|
@@ -26,7 +28,10 @@ class RDRDecorator:
|
|
26
28
|
mutual_exclusive: bool,
|
27
29
|
output_name: str = "output_",
|
28
30
|
fit: bool = True,
|
29
|
-
expert: Optional[Expert] = None
|
31
|
+
expert: Optional[Expert] = None,
|
32
|
+
ask_always: bool = False,
|
33
|
+
update_existing_rules: bool = True,
|
34
|
+
viewer: Optional[RDRCaseViewer] = None):
|
30
35
|
"""
|
31
36
|
:param models_dir: The directory to save/load the RDR models.
|
32
37
|
:param output_type: The type of the output. This is used to create the RDR model.
|
@@ -38,6 +43,9 @@ class RDRDecorator:
|
|
38
43
|
classification mode. This means that the RDR will classify the function's output based on the RDR model.
|
39
44
|
:param expert: The expert that will be used to prompt the user for the correct output. If None, a Human
|
40
45
|
expert will be used.
|
46
|
+
:param ask_always: If True, the function will ask the user for a target if it doesn't exist.
|
47
|
+
:param update_existing_rules: If True, the function will update the existing RDR rules
|
48
|
+
even if they gave an output.
|
41
49
|
:return: A decorator to use a GeneralRDR as a classifier that monitors and modifies the function's output.
|
42
50
|
"""
|
43
51
|
self.rdr_models_dir = models_dir
|
@@ -48,6 +56,9 @@ class RDRDecorator:
|
|
48
56
|
self.output_name = output_name
|
49
57
|
self.fit: bool = fit
|
50
58
|
self.expert: Optional[Expert] = expert
|
59
|
+
self.ask_always = ask_always
|
60
|
+
self.update_existing_rules = update_existing_rules
|
61
|
+
self.viewer = viewer
|
51
62
|
self.load()
|
52
63
|
|
53
64
|
def decorator(self, func: Callable) -> Callable:
|
@@ -59,59 +70,77 @@ class RDRDecorator:
|
|
59
70
|
self.parsed_output_type = self.parse_output_type(func, self.output_type, *args)
|
60
71
|
if self.model_name is None:
|
61
72
|
self.initialize_rdr_model_name_and_load(func)
|
73
|
+
if self.expert is None:
|
74
|
+
self.expert = Human(viewer=self.viewer,
|
75
|
+
answers_save_path=self.rdr_models_dir + f'/expert_answers')
|
76
|
+
|
77
|
+
func_output = {self.output_name: func(*args, **kwargs)}
|
62
78
|
|
63
79
|
if self.fit:
|
64
|
-
case_query = self.create_case_query_from_method(func,
|
65
|
-
self.
|
80
|
+
case_query = self.create_case_query_from_method(func, func_output,
|
81
|
+
self.parsed_output_type,
|
82
|
+
self.mutual_exclusive,
|
66
83
|
*args, **kwargs)
|
67
|
-
output = self.rdr.fit_case(case_query, expert=self.expert
|
84
|
+
output = self.rdr.fit_case(case_query, expert=self.expert,
|
85
|
+
ask_always_for_target=self.ask_always,
|
86
|
+
update_existing_rules=self.update_existing_rules,
|
87
|
+
viewer=self.viewer)
|
88
|
+
else:
|
89
|
+
case, case_dict = self.create_case_from_method(func, func_output, *args, **kwargs)
|
90
|
+
output = self.rdr.classify(case)
|
91
|
+
|
92
|
+
if self.output_name in output:
|
68
93
|
return output[self.output_name]
|
69
94
|
else:
|
70
|
-
|
71
|
-
return self.rdr.classify(case)[self.output_name]
|
95
|
+
return func_output[self.output_name]
|
72
96
|
|
73
97
|
return wrapper
|
74
98
|
|
75
99
|
@staticmethod
|
76
|
-
def create_case_query_from_method(func: Callable,
|
77
|
-
|
100
|
+
def create_case_query_from_method(func: Callable,
|
101
|
+
func_output: Dict[str, Any],
|
102
|
+
output_type: Sequence[Type],
|
103
|
+
mutual_exclusive: bool,
|
104
|
+
*args, **kwargs) -> CaseQuery:
|
78
105
|
"""
|
79
106
|
Create a CaseQuery from the function and its arguments.
|
80
107
|
|
81
108
|
:param func: The function to create a case from.
|
82
|
-
:param
|
109
|
+
:param func_output: The output of the function as a dictionary, where the key is the output name.
|
110
|
+
:param output_type: The type of the output as a sequence of types.
|
83
111
|
:param mutual_exclusive: If True, the output types are mutually exclusive.
|
84
|
-
:param output_name: The name of the output in the case. Defaults to 'output_'.
|
85
112
|
:param args: The positional arguments of the function.
|
86
113
|
:param kwargs: The keyword arguments of the function.
|
87
114
|
:return: A CaseQuery object representing the case.
|
88
115
|
"""
|
89
116
|
output_type = make_set(output_type)
|
90
|
-
case, case_dict = RDRDecorator.create_case_from_method(func,
|
117
|
+
case, case_dict = RDRDecorator.create_case_from_method(func, func_output, *args, **kwargs)
|
91
118
|
scope = func.__globals__
|
92
119
|
scope.update(case_dict)
|
93
120
|
func_args_type_hints = get_type_hints(func)
|
121
|
+
output_name = list(func_output.keys())[0]
|
94
122
|
func_args_type_hints.update({output_name: Union[tuple(output_type)]})
|
95
123
|
return CaseQuery(case, output_name, Union[tuple(output_type)],
|
96
124
|
mutual_exclusive, scope=scope,
|
97
125
|
is_function=True, function_args_type_hints=func_args_type_hints)
|
98
126
|
|
99
127
|
@staticmethod
|
100
|
-
def create_case_from_method(func: Callable,
|
128
|
+
def create_case_from_method(func: Callable,
|
129
|
+
func_output: Dict[str, Any],
|
130
|
+
*args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
|
101
131
|
"""
|
102
132
|
Create a Case from the function and its arguments.
|
103
133
|
|
104
134
|
:param func: The function to create a case from.
|
105
|
-
:param
|
135
|
+
:param func_output: A dictionary containing the output of the function, where the key is the output name.
|
106
136
|
:param args: The positional arguments of the function.
|
107
137
|
:param kwargs: The keyword arguments of the function.
|
108
138
|
:return: A Case object representing the case.
|
109
139
|
"""
|
110
140
|
case_dict = get_method_args_as_dict(func, *args, **kwargs)
|
111
|
-
func_output
|
112
|
-
case_dict.update({output_name: func_output})
|
141
|
+
case_dict.update(func_output)
|
113
142
|
case_name = get_func_rdr_model_name(func)
|
114
|
-
return
|
143
|
+
return Case(dict, id(case_dict), case_name, case_dict, **case_dict), case_dict
|
115
144
|
|
116
145
|
def initialize_rdr_model_name_and_load(self, func: Callable) -> None:
|
117
146
|
model_file_name = get_func_rdr_model_name(func, include_file_name=True)
|
@@ -148,8 +177,10 @@ class RDRDecorator:
|
|
148
177
|
model_path = os.path.join(self.rdr_models_dir, self.model_name + f"/rdr_metadata/{self.model_name}.json")
|
149
178
|
if os.path.exists(os.path.join(self.rdr_models_dir, model_path)):
|
150
179
|
self.rdr = GeneralRDR.load(self.rdr_models_dir, self.model_name)
|
180
|
+
self.rdr.set_viewer(self.viewer)
|
151
181
|
if self.rdr is None:
|
152
|
-
self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name
|
182
|
+
self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name,
|
183
|
+
viewer=self.viewer)
|
153
184
|
|
154
185
|
def update_from_python(self):
|
155
186
|
"""
|
ripple_down_rules/rules.py
CHANGED
@@ -3,16 +3,18 @@ from __future__ import annotations
|
|
3
3
|
import logging
|
4
4
|
import re
|
5
5
|
from abc import ABC, abstractmethod
|
6
|
+
from pathlib import Path
|
6
7
|
from uuid import uuid4
|
7
8
|
|
8
9
|
from anytree import NodeMixin
|
9
10
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
10
|
-
from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
|
11
|
+
from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple, Callable
|
11
12
|
|
12
13
|
from .datastructures.callable_expression import CallableExpression
|
13
14
|
from .datastructures.case import Case
|
15
|
+
from .datastructures.dataclasses import CaseFactoryMetaData, CaseConf, CaseQuery
|
14
16
|
from .datastructures.enums import RDREdge, Stop
|
15
|
-
from .utils import SubclassJSONSerializer, conclusion_to_json
|
17
|
+
from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_imports_from_types
|
16
18
|
|
17
19
|
|
18
20
|
class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
@@ -27,7 +29,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
27
29
|
corner_case: Optional[Union[Case, SQLTable]] = None,
|
28
30
|
weight: Optional[str] = None,
|
29
31
|
conclusion_name: Optional[str] = None,
|
30
|
-
uid: Optional[str] = None
|
32
|
+
uid: Optional[str] = None,
|
33
|
+
corner_case_metadata: Optional[CaseFactoryMetaData] = None):
|
31
34
|
"""
|
32
35
|
A rule in the ripple down rules classifier.
|
33
36
|
|
@@ -38,10 +41,13 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
38
41
|
:param weight: The weight of the rule, which is the type of edge connecting the rule to its parent.
|
39
42
|
:param conclusion_name: The name of the conclusion of the rule.
|
40
43
|
:param uid: The unique id of the rule.
|
44
|
+
:param corner_case_metadata: Metadata about the corner case, such as the factory that created it or the
|
45
|
+
scenario it is based on.
|
41
46
|
"""
|
42
47
|
super(Rule, self).__init__()
|
43
48
|
self.conclusion = conclusion
|
44
49
|
self.corner_case = corner_case
|
50
|
+
self.corner_case_metadata: Optional[CaseFactoryMetaData] = corner_case_metadata
|
45
51
|
self.parent = parent
|
46
52
|
self.weight: Optional[str] = weight
|
47
53
|
self.conditions = conditions if conditions else None
|
@@ -51,6 +57,20 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
51
57
|
# generate a unique id for the rule using uuid4
|
52
58
|
self.uid: str = uid if uid else str(uuid4().int)
|
53
59
|
|
60
|
+
@classmethod
|
61
|
+
def from_case_query(cls, case_query: CaseQuery) -> Rule:
|
62
|
+
"""
|
63
|
+
Create a SingleClassRule from a CaseQuery.
|
64
|
+
|
65
|
+
:param case_query: The CaseQuery to create the rule from.
|
66
|
+
:return: A SingleClassRule instance.
|
67
|
+
"""
|
68
|
+
corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
69
|
+
return cls(conditions=case_query.conditions, conclusion=case_query.target,
|
70
|
+
corner_case=case_query.case, parent=None,
|
71
|
+
corner_case_metadata=corner_case_metadata,
|
72
|
+
conclusion_name=case_query.attribute_name)
|
73
|
+
|
54
74
|
def _post_detach(self, parent):
|
55
75
|
"""
|
56
76
|
Called after this node is detached from the tree, useful when drawing the tree.
|
@@ -82,6 +102,27 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
82
102
|
"""
|
83
103
|
pass
|
84
104
|
|
105
|
+
def write_corner_case_as_source_code(self, cases_file: Path) -> None:
|
106
|
+
"""
|
107
|
+
Write the source code representation of the corner case of the rule to a file.
|
108
|
+
|
109
|
+
:param cases_file: The file to write the corner case to if it is a definition.
|
110
|
+
"""
|
111
|
+
if self.corner_case_metadata is None:
|
112
|
+
return
|
113
|
+
types_to_import = set()
|
114
|
+
if self.corner_case_metadata.factory_method is not None:
|
115
|
+
types_to_import.add(self.corner_case_metadata.factory_method)
|
116
|
+
if self.corner_case_metadata.scenario is not None:
|
117
|
+
types_to_import.add(self.corner_case_metadata.scenario)
|
118
|
+
if self.corner_case_metadata.case_conf is not None:
|
119
|
+
types_to_import.add(self.corner_case_metadata.case_conf)
|
120
|
+
types_to_import.add(CaseFactoryMetaData)
|
121
|
+
imports = get_imports_from_types(list(types_to_import))
|
122
|
+
with open(cases_file, 'a') as f:
|
123
|
+
f.write("\n".join(imports) + "\n\n\n")
|
124
|
+
f.write(f"corner_case_{self.uid} = {self.corner_case_metadata}" + "\n\n\n")
|
125
|
+
|
85
126
|
def write_conclusion_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
|
86
127
|
"""
|
87
128
|
Get the source code representation of the conclusion of the rule.
|
@@ -150,11 +191,16 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
150
191
|
pass
|
151
192
|
|
152
193
|
def _to_json(self) -> Dict[str, Any]:
|
153
|
-
|
194
|
+
try:
|
195
|
+
corner_case = SubclassJSONSerializer.to_json_static(self.corner_case) if self.corner_case else None
|
196
|
+
except Exception as e:
|
197
|
+
logging.debug("Failed to serialize corner case to json, setting it to None. Error: %s", e)
|
198
|
+
corner_case = None
|
199
|
+
json_serialization = {"_type": get_full_class_name(type(self)),
|
200
|
+
"conditions": self.conditions.to_json(),
|
154
201
|
"conclusion": conclusion_to_json(self.conclusion),
|
155
202
|
"parent": self.parent.json_serialization if self.parent else None,
|
156
|
-
"corner_case":
|
157
|
-
if self.corner_case else None,
|
203
|
+
"corner_case": corner_case,
|
158
204
|
"conclusion_name": self.conclusion_name,
|
159
205
|
"weight": self.weight,
|
160
206
|
"uid": self.uid}
|
@@ -277,9 +323,12 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
|
|
277
323
|
returned_rule = self.alternative(x) if self.alternative else self
|
278
324
|
return returned_rule if returned_rule.fired else self
|
279
325
|
|
280
|
-
def fit_rule(self,
|
281
|
-
|
282
|
-
|
326
|
+
def fit_rule(self, case_query: CaseQuery):
|
327
|
+
corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
328
|
+
new_rule = SingleClassRule(case_query.conditions, case_query.target,
|
329
|
+
corner_case=case_query.case, parent=self,
|
330
|
+
corner_case_metadata=corner_case_metadata,
|
331
|
+
)
|
283
332
|
if self.fired:
|
284
333
|
self.refinement = new_rule
|
285
334
|
else:
|
@@ -363,11 +412,12 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
|
363
412
|
elif self.alternative: # Here alternative refers to next rule in MultiClassRDR
|
364
413
|
return self.alternative
|
365
414
|
|
366
|
-
def fit_rule(self,
|
367
|
-
if self.fired and target != self.conclusion:
|
368
|
-
self.refinement = MultiClassStopRule(conditions, corner_case=
|
415
|
+
def fit_rule(self, case_query: CaseQuery):
|
416
|
+
if self.fired and case_query.target != self.conclusion:
|
417
|
+
self.refinement = MultiClassStopRule(case_query.conditions, corner_case=case_query.case, parent=self)
|
369
418
|
elif not self.fired:
|
370
|
-
self.alternative = MultiClassTopRule(conditions, target,
|
419
|
+
self.alternative = MultiClassTopRule(case_query.conditions, case_query.target,
|
420
|
+
corner_case=case_query.case, parent=self)
|
371
421
|
|
372
422
|
def _to_json(self) -> Dict[str, Any]:
|
373
423
|
self.json_serialization = {**Rule._to_json(self),
|
@@ -281,12 +281,14 @@ class RDRCaseViewer(QMainWindow):
|
|
281
281
|
main_obj: Optional[Dict[str, Any]] = None
|
282
282
|
user_input: Optional[str] = None
|
283
283
|
attributes_widget: Optional[QWidget] = None
|
284
|
-
save_function: Optional[Callable[str], None] = None
|
284
|
+
save_function: Optional[Callable[str, str], None] = None
|
285
285
|
|
286
|
-
|
287
|
-
|
286
|
+
def __init__(self, parent=None,
|
287
|
+
save_dir: Optional[str] = None,
|
288
|
+
save_model_name: Optional[str] = None):
|
288
289
|
super().__init__(parent)
|
289
|
-
self.
|
290
|
+
self.save_dir = save_dir
|
291
|
+
self.save_model_name = save_model_name
|
290
292
|
|
291
293
|
self.setWindowTitle("RDR Case Viewer")
|
292
294
|
|
@@ -323,17 +325,17 @@ class RDRCaseViewer(QMainWindow):
|
|
323
325
|
|
324
326
|
# Add both to main layout
|
325
327
|
main_layout.addWidget(self.attributes_widget, stretch=1)
|
326
|
-
main_layout.addWidget(middle_widget, stretch=
|
328
|
+
main_layout.addWidget(middle_widget, stretch=1)
|
327
329
|
main_layout.addWidget(self.obj_diagram_viewer, stretch=2)
|
328
330
|
|
329
|
-
def set_save_function(self, save_function: Callable[[str], None]) -> None:
|
331
|
+
def set_save_function(self, save_function: Callable[[str, str], None]) -> None:
|
330
332
|
"""
|
331
333
|
Set the function to save the file.
|
332
334
|
|
333
335
|
:param save_function: The function to save the file.
|
334
336
|
"""
|
335
337
|
self.save_function = save_function
|
336
|
-
self.save_btn.clicked.connect(lambda: self.save_function(self.
|
338
|
+
self.save_btn.clicked.connect(lambda: self.save_function(self.save_dir, self.save_model_name))
|
337
339
|
|
338
340
|
def print(self, msg):
|
339
341
|
"""
|
@@ -134,7 +134,7 @@ class IPythonShell:
|
|
134
134
|
"""
|
135
135
|
Update the user input from the code lines captured in the shell.
|
136
136
|
"""
|
137
|
-
if
|
137
|
+
if self.shell.all_lines[0].replace('return', '').strip() == '':
|
138
138
|
self.user_input = None
|
139
139
|
else:
|
140
140
|
self.all_code_lines = extract_dependencies(self.shell.all_lines)
|
@@ -1,5 +1,9 @@
|
|
1
1
|
import logging
|
2
2
|
|
3
|
+
from ripple_down_rules.datastructures.case import Case
|
4
|
+
from ripple_down_rules.datastructures.dataclasses import CaseQuery
|
5
|
+
from ripple_down_rules.utils import SubclassJSONSerializer
|
6
|
+
|
3
7
|
try:
|
4
8
|
import graphviz
|
5
9
|
except ImportError:
|
@@ -77,7 +81,11 @@ def generate_object_graph(obj, name='root', seen=None, graph=None, current_depth
|
|
77
81
|
for attr in dir(obj):
|
78
82
|
if attr.startswith('_'):
|
79
83
|
continue
|
80
|
-
if attr == 'scope':
|
84
|
+
if isinstance(obj, CaseQuery) and attr == 'scope':
|
85
|
+
continue
|
86
|
+
if isinstance(obj, Case) and attr in ['data']:
|
87
|
+
continue
|
88
|
+
if isinstance(obj, SubclassJSONSerializer) and attr == 'data_class_refs':
|
81
89
|
continue
|
82
90
|
value = getattr(obj, attr)
|
83
91
|
if callable(value):
|
@@ -8,13 +8,13 @@ from functools import cached_property
|
|
8
8
|
from textwrap import indent, dedent
|
9
9
|
|
10
10
|
from colorama import Fore, Style
|
11
|
-
from typing_extensions import Optional, Type, List, Callable, Tuple, Dict
|
11
|
+
from typing_extensions import Optional, Type, List, Callable, Tuple, Dict, Any, Union
|
12
12
|
|
13
13
|
from ..datastructures.case import Case
|
14
14
|
from ..datastructures.dataclasses import CaseQuery
|
15
15
|
from ..datastructures.enums import Editor, PromptFor
|
16
|
-
from ..utils import str_to_snake_case, get_imports_from_scope, make_list,
|
17
|
-
get_imports_from_types, extract_function_source, extract_imports
|
16
|
+
from ..utils import str_to_snake_case, get_imports_from_scope, make_list, stringify_hint, \
|
17
|
+
get_imports_from_types, extract_function_source, extract_imports, get_types_to_import_from_type_hints
|
18
18
|
|
19
19
|
|
20
20
|
def detect_available_editor() -> Optional[Editor]:
|
@@ -177,7 +177,7 @@ class TemplateFileCreator:
|
|
177
177
|
for k, v in self.case_query.case.items():
|
178
178
|
if (self.case_query.function_args_type_hints is not None
|
179
179
|
and k in self.case_query.function_args_type_hints):
|
180
|
-
func_args[k] =
|
180
|
+
func_args[k] = stringify_hint(self.case_query.function_args_type_hints[k])
|
181
181
|
else:
|
182
182
|
func_args[k] = type(v).__name__ if not isinstance(v, type) else f"Type[{v.__name__}]"
|
183
183
|
func_args = ', '.join([f"{k}: {v}" if str(v) not in ["NoneType", "None"] else str(k)
|
@@ -207,30 +207,25 @@ class TemplateFileCreator:
|
|
207
207
|
for k, v in self.case_query.case.items():
|
208
208
|
if (self.case_query.function_args_type_hints is not None
|
209
209
|
and k in self.case_query.function_args_type_hints):
|
210
|
-
|
211
|
-
|
212
|
-
hint_split = hint.split('.')
|
213
|
-
if len(hint_split) > 1:
|
214
|
-
case_type_imports.append(f"from {'.'.join(hint_split[:-1])} import {hint_split[-1]}")
|
210
|
+
types_to_import = get_types_to_import_from_type_hints([self.case_query.function_args_type_hints[k]])
|
211
|
+
case_type_imports.extend(list(types_to_import))
|
215
212
|
else:
|
216
|
-
|
217
|
-
case_type_imports.append(f"from {v.__module__} import {v.__name__}")
|
218
|
-
elif hasattr(v, "__module__") and not v.__module__.startswith("__"):
|
219
|
-
case_type_imports.append(f"\nfrom {type(v).__module__} import {type(v).__name__}")
|
213
|
+
case_type_imports.append(v)
|
220
214
|
else:
|
221
|
-
case_type_imports.append(
|
215
|
+
case_type_imports.append(self.case_type)
|
222
216
|
if self.output_type is None:
|
223
|
-
output_type_imports = [
|
217
|
+
output_type_imports = [Any]
|
224
218
|
else:
|
225
|
-
output_type_imports =
|
219
|
+
output_type_imports = self.output_type
|
226
220
|
if len(self.output_type) > 1:
|
227
|
-
output_type_imports.append(
|
221
|
+
output_type_imports.append(Union)
|
228
222
|
if list in self.output_type:
|
229
|
-
output_type_imports.append(
|
230
|
-
|
231
|
-
imports = [i for i in imports if ("get_ipython" not in i)]
|
232
|
-
|
233
|
-
|
223
|
+
output_type_imports.append(List)
|
224
|
+
import_types = list(self.case_query.scope.values())
|
225
|
+
# imports = [i for i in imports if ("get_ipython" not in i)]
|
226
|
+
import_types.extend(case_type_imports)
|
227
|
+
import_types.extend(output_type_imports)
|
228
|
+
imports = get_imports_from_types(import_types)
|
234
229
|
imports = set(imports)
|
235
230
|
return '\n'.join(imports)
|
236
231
|
|