typesync 0.0.1a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- typesync/__init__.py +13 -0
- typesync/cli.py +126 -0
- typesync/codegen/__init__.py +7 -0
- typesync/codegen/extractor.py +232 -0
- typesync/codegen/inference.py +239 -0
- typesync/codegen/writer.py +195 -0
- typesync/ts_types.py +125 -0
- typesync/type_translators/__init__.py +12 -0
- typesync/type_translators/abstract.py +21 -0
- typesync/type_translators/base_translator.py +149 -0
- typesync/type_translators/flask_translator.py +25 -0
- typesync/type_translators/type_node.py +84 -0
- typesync/utils.py +66 -0
- typesync-0.0.1a1.dist-info/METADATA +94 -0
- typesync-0.0.1a1.dist-info/RECORD +18 -0
- typesync-0.0.1a1.dist-info/WHEEL +5 -0
- typesync-0.0.1a1.dist-info/licenses/LICENSE +7 -0
- typesync-0.0.1a1.dist-info/top_level.txt +1 -0
typesync/__init__.py
ADDED
typesync/cli.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import click
|
|
4
|
+
from flask import current_app
|
|
5
|
+
from flask.cli import AppGroup
|
|
6
|
+
from werkzeug.routing.rules import Rule
|
|
7
|
+
|
|
8
|
+
from .codegen import CodeWriter, FlaskRouteTypeExtractor
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
cli = AppGroup("typesync")


@cli.command(help="Generate Typescript types based on Flask routes.")
@click.argument("out_dir", type=click.Path(file_okay=False, resolve_path=True))
@click.option("--endpoint", "-E", help="The base endpoint.", default="")
@click.option("--samefile", "-S", help="Write types and apis to the same file")
@click.option(
    "--inference",
    "-i",
    is_flag=True,
    help="Whether to use inference when type annotations cannot be resolved",
)
@click.option(
    "--inference-can-eval",
    is_flag=True,
    help=(
        "Whether eval() can be called during inference. WARNING: this will"
        " execute arbitrary code."
    ),
)
@click.option(
    "--types-file",
    help="Name of output file containing type definitions (defaults to 'types.ts')",
    default="types.ts",
)
@click.option(
    "--apis-file",
    help="Name of output file containing API functions (defaults to 'apis.ts')",
    default="apis.ts",
)
@click.option(
    "--return-type-format",
    default="{pc}ReturnType",
    help=(
        "Format string used to generate return type names from the route name. "
        "Available placeholders are: "
        "{d} (default route name), "
        "{cc} (camelCase), "
        "{pc} (PascalCase), "
        "{uc} (UPPERCASE), "
        "{lc} (lowercase), "
        "{sc} (snake_case). "
        "Defaults to: '{pc}ReturnType'"
    ),
)
@click.option(
    "--args-type-format",
    default="{pc}ArgsType",
    help=(
        "Format string used to generate argument type names from the route name. "
        "Available placeholders are: "
        "{d} (default route name), "
        "{cc} (camelCase), "
        "{pc} (PascalCase), "
        "{uc} (UPPERCASE), "
        "{lc} (lowercase), "
        "{sc} (snake_case). "
        "Defaults to: '{pc}ArgsType'"
    ),
)
@click.option(
    "--function-name-format",
    default="{m_lc}{r_pc}",
    help=(
        "Format string used to generate function names from the route and HTTP method. "
        "Available placeholders are: "
        "{r_d} or {m_d} (default route name or HTTP method), "
        "{r_cc} or {m_cc} (camelCase), "
        "{r_pc} or {m_pc} (PascalCase), "
        "{r_uc} or {m_uc} (UPPERCASE), "
        "{r_lc} or {m_lc} (lowercase), "
        "{r_sc} or {m_sc} (snake_case). "
        "Defaults to: '{m_lc}{r_pc}'"
    ),
)
def generate(
    out_dir: str,
    endpoint: str,
    inference: bool,
    inference_can_eval: bool,
    types_file: str,
    apis_file: str,
    return_type_format: str,
    args_type_format: str,
    function_name_format: str,
    samefile: str | None = None,
):
    """Write TypeScript type definitions and API functions for every URL rule.

    One ``FlaskRouteTypeExtractor`` is created per rule of the current Flask
    app and handed to a ``CodeWriter``, which writes ``types_file`` and
    ``apis_file`` inside ``out_dir`` (created if missing). When the writer
    reports failure, an error is printed to stderr and the command exits with
    status 1, so scripts and CI can detect it.

    NOTE(review): ``samefile`` is accepted on the command line but is not used
    here -- both files are always written separately. TODO wire it up or
    remove the option.
    """
    rules: list[Rule] = list(current_app.url_map.iter_rules())

    os.makedirs(out_dir, exist_ok=True)

    # Write generated TypeScript as UTF-8 regardless of the platform's default
    # encoding, so non-ASCII identifiers or endpoints don't break the output.
    with (
        open(os.path.join(out_dir, types_file), "w", encoding="utf-8") as types_f,
        open(os.path.join(out_dir, apis_file), "w", encoding="utf-8") as api_f,
    ):
        code_writer = CodeWriter(
            types_f,
            api_f,
            types_file,
            return_type_format,
            args_type_format,
            function_name_format,
            endpoint,
        )
        result = code_writer.write(
            FlaskRouteTypeExtractor(
                current_app,
                rule,
                inference_enabled=inference,
                inference_can_eval=inference_can_eval,
            )
            for rule in rules
        )

    if not result:
        click.secho("Errors occurred during file generation", fg="red", err=True)
        # Non-zero exit status so callers can tell that generation failed.
        click.get_current_context().exit(1)
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
import click
|
|
4
|
+
from flask import Flask
|
|
5
|
+
from werkzeug.routing import (
|
|
6
|
+
FloatConverter,
|
|
7
|
+
BaseConverter,
|
|
8
|
+
IntegerConverter,
|
|
9
|
+
UUIDConverter,
|
|
10
|
+
PathConverter,
|
|
11
|
+
UnicodeConverter,
|
|
12
|
+
)
|
|
13
|
+
from werkzeug.routing.rules import Rule
|
|
14
|
+
|
|
15
|
+
from .inference import infer_return_type
|
|
16
|
+
from typesync.ts_types import TSType, TSSimpleType, TSObject
|
|
17
|
+
from typesync.type_translators import TypeNode, to_type_node
|
|
18
|
+
|
|
19
|
+
if typing.TYPE_CHECKING:
|
|
20
|
+
from typesync.type_translators import Translator
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
type Type = typing.Any
|
|
24
|
+
type TypeTreeTuple = tuple[TypeTree, ...]
|
|
25
|
+
type TypeTreeDict = dict[str, TypeTree]
|
|
26
|
+
type TypeTree = Type | tuple[Type, TypeTreeTuple | TypeTreeDict]
|
|
27
|
+
type GenericParamValues = dict[typing.TypeVar, TypeTree]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Logger(typing.Protocol):
    """Structural interface for the message sinks used during extraction."""

    def info(self, text: str) -> None:
        """Report an informational message."""

    def warning(self, text: str) -> None:
        """Report a recoverable problem."""

    def error(self, text: str) -> None:
        """Report a failure."""
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ClickLogger:
    """``Logger`` implementation that prints through click.

    Informational messages go to stdout. Warnings (yellow) and errors (red)
    go to stderr, keeping diagnostics separate from regular output.
    """

    @staticmethod
    def info(text: str) -> None:
        """Print ``text`` prefixed with ``Info:`` to stdout."""
        click.echo(f"Info: {text}")

    @staticmethod
    def warning(text: str) -> None:
        """Print ``text`` prefixed with ``Warning:`` in yellow to stderr."""
        click.secho(f"Warning: {text}", fg="yellow", err=True)

    @staticmethod
    def error(text: str) -> None:
        """Print ``text`` prefixed with ``Error:`` in red to stderr."""
        click.secho(f"Error: {text}", fg="red", err=True)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def get_type_hints(tp: typing.Any) -> dict[str, typing.Any]:
    """Return the annotations of ``tp`` with string annotations resolved.

    Intended for functions and methods. Uses :func:`typing.get_type_hints`
    (keeping ``Annotated`` metadata) so that annotations written as strings
    come back as real types. This covers modules using
    ``from __future__ import annotations`` and explicit forward references.
    Note that this resolution also turns a ``None`` annotation into
    ``NoneType``.

    If resolution fails (undefined names, objects ``typing`` does not
    support, ...), the raw ``__annotations__`` mapping is returned unchanged.
    When ``tp`` has no annotations, ``{}`` is returned.

    NOTE: resolving string annotations evaluates them in ``tp``'s module
    globals.
    """
    try:
        return typing.get_type_hints(tp, include_extras=True)
    except Exception:
        # Best effort: unresolved annotations are still better than none.
        return getattr(tp, "__annotations__", {})
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class FlaskRouteTypeExtractor:
    """Derive TypeScript types for one URL rule of a Flask application.

    The ``parse_*`` methods report problems through ``logger`` and return
    ``None`` instead of raising.
    """

    def __init__(
        self,
        app: Flask,
        rule: Rule,
        translators: tuple[type["Translator"], ...] | None = None,
        inference_enabled: bool = False,
        inference_can_eval: bool = False,
        logger: Logger | None = None,
    ) -> None:
        """Create an extractor for ``rule`` of ``app``.

        Args:
            app: Flask application; its ``view_functions`` provide the view.
            rule: the werkzeug URL rule to describe.
            translators: translator classes tried in order when converting a
                Python type; defaults to ``BaseTranslator`` and
                ``FlaskTranslator``.
            inference_enabled: infer the return type from the view's source
                when its annotation is missing or not fully translatable.
            inference_can_eval: allow ``eval()`` of annotations during
                inference (executes arbitrary code).
            logger: destination for messages; defaults to ``ClickLogger``.
        """
        self.app = app
        self.rule = rule
        self.inference_enabled = inference_enabled
        self.inference_can_eval = inference_can_eval
        self.logger = ClickLogger() if logger is None else logger
        self.translators = (
            self._load_default_translators() if translators is None else translators
        )

    @staticmethod
    def _load_default_translators() -> tuple[type["Translator"], ...]:
        """Return the built-in translator classes, imported lazily."""
        # Imported at call time rather than module level -- presumably to avoid
        # an import cycle with typesync.type_translators (TODO confirm).
        from typesync.type_translators import BaseTranslator, FlaskTranslator  # noqa: PLC0415

        return (BaseTranslator, FlaskTranslator)

    @property
    def rule_name(self) -> str:
        """The rule's endpoint with ``.`` replaced by ``_``."""
        return self.rule.endpoint.replace(".", "_")

    @property
    def rule_url(self) -> str:
        """The rule's path with every variable rendered as ``<name>``.

        werkzeug's (private) ``Rule._trace`` lists the domain part (subdomain
        or host, possibly empty), then a ``(False, "|")`` separator, then the
        path parts. Only the parts after the separator belong to the URL path.
        """
        trace = self.rule._trace
        try:
            start = trace.index((False, "|")) + 1
        except ValueError:
            # No separator found: keep the previous behaviour of dropping the
            # first entry only.
            start = 1
        return "".join(
            f"<{content}>" if is_argument else content
            for is_argument, content in trace[start:]
        )

    def parse_args_type(self) -> TSType | None:
        """Build the TypeScript type of the rule's URL arguments.

        Returns a ``TSObject`` with one field per converted URL argument,
        ``TSSimpleType("undefined")`` when the rule has none, or ``None`` after
        an error (which is logged).
        """
        try:
            arguments = self.rule.arguments
            # ``Rule.arguments`` is a set, so iterating it gives a different
            # order from run to run (string hashing is randomized). The
            # converters mapping is a dict filled while werkzeug parses the
            # rule, so its order is stable and follows the URL.
            types: list[tuple[str, TSType]] = [
                (arg, self._get_converter_type(arg, converter))
                for arg, converter in self.rule._converters.items()
                if arg in arguments and converter is not None
            ]

            if not types:
                return TSSimpleType("undefined")
            return TSObject([name for name, _ in types], [ts for _, ts in types])

        except Exception as e:
            self.logger.error(f"couldn't parse argument types ({e})")
            return None

    def translate_type(self, type_: Type) -> tuple[TSType, str | None]:
        """Translate ``type_`` into a ``TSType`` using the configured translators.

        Each type node is offered to the translators in order; the first
        non-``None`` result wins. A node no translator accepts becomes ``any``.

        Returns:
            The translated type, plus a warning naming every node that fell
            back to ``any`` (joined with ``"; "``), or ``None`` if none did.
        """
        warnings: list[str] = []

        def translate(
            node: TypeNode, generics: dict[typing.TypeVar, TSType] | None
        ) -> TSType:
            # Translators are built with this callback (presumably so they can
            # translate nested nodes through it).
            for translator in translators:
                translated = translator.translate(node, generics)
                if translated is not None:
                    return translated

            message = (
                f"can't translate '{getattr(node.origin, '__name__', node.origin)}'"
                " to a TypeScript equivalent, defaulting to 'any'"
            )
            if message not in warnings:
                warnings.append(message)
            return TSSimpleType("any")

        translators = [translator_cls(translate) for translator_cls in self.translators]
        ts_type = translate(to_type_node(type_), {})
        return ts_type, "; ".join(warnings) or None

    def parse_return_type(self) -> TSType | None:
        """Translate the return type of the rule's view function.

        The view's ``return`` annotation is used first. For a tuple
        annotation, its first element is used. When the annotation is
        missing or only partially translatable and inference is enabled, the
        return type is inferred from the view's source instead.

        Returns the translated type, ``TSSimpleType("any")`` when nothing
        could be determined, or ``None`` after an error (which is logged).
        """
        try:
            function = self.app.view_functions[self.rule.endpoint]
            annotations = get_type_hints(function)

            return_type: TSType | None = None
            warning: str | None = None
            if annotations is not None and "return" in annotations:
                return_type, warning = self.translate_type(
                    self._get_route_annotations(annotations["return"])
                )

            if self.inference_enabled and (return_type is None or warning is not None):
                inferred = infer_return_type(
                    function, self.logger, self.inference_can_eval
                )
                if inferred is not None:
                    return_type, warning = self.translate_type(
                        self._get_route_annotations(inferred)
                    )

            if warning is not None:
                self.logger.warning(warning)

        except Exception as e:
            self.logger.error(
                f"couldn't parse return type of '{self.rule.endpoint}' ({e})"
            )
            return None

        return TSSimpleType("any") if return_type is None else return_type

    def parse_json_body(self) -> TSType | None:
        """Translate the annotated type of the view's JSON-body argument.

        The view function's ``_typesync`` attribute names the keyword argument
        that receives the JSON body (presumably set by a decorator elsewhere
        in the package -- TODO confirm). Returns ``None`` when the attribute is
        absent, when that argument is not annotated, or after an error (both
        of which are logged).
        """
        try:
            function = self.app.view_functions[self.rule.endpoint]
            json_key = getattr(function, "_typesync", None)
            if json_key is None:
                return None

            annotations = get_type_hints(function)
            if json_key not in annotations:
                self.logger.error(
                    f"'{self.rule.endpoint}' expected to receive JSON body as keyword "
                    f"argument '{json_key}'"
                )
                return None

            json_body_type, warning = self.translate_type(annotations[json_key])
            if warning is not None:
                self.logger.warning(warning)

        except Exception as e:
            self.logger.error(
                f"couldn't parse JSON body type of '{self.rule.endpoint}' ({e})"
            )
            return None

        return json_body_type

    def _get_route_annotations_from_tuple(self, tp: typing.Any) -> Type:
        """Return the first element type of a tuple annotation.

        Flask views may return ``(body, status[, headers])`` tuples; the body
        is the first element. A bare ``tuple`` (no arguments) raises
        ``IndexError``, which the ``parse_*`` callers log as an error.
        """
        args = typing.get_args(tp)
        return args[0]

    def _get_route_annotations(self, tp: typing.Any) -> Type:
        """Unwrap tuple return annotations to their body type; pass others through."""
        origin = typing.get_origin(tp) or tp
        if origin is tuple:
            return self._get_route_annotations_from_tuple(tp)
        return tp

    def _get_converter_type(self, arg: str, converter: BaseConverter) -> TSType:
        """Map a werkzeug URL converter to a TypeScript type.

        Built-in numeric converters map to ``number`` and the built-in
        string-like ones to ``string``. Any other converter is typed from the
        return annotation of its ``to_python``. A warning is logged and
        ``string`` is used when that annotation is missing.
        """
        if isinstance(converter, (FloatConverter, IntegerConverter)):
            return TSSimpleType("number")
        if isinstance(converter, (UUIDConverter, PathConverter, UnicodeConverter)):
            return TSSimpleType("string")

        # Custom converter, check to_python() annotations
        annotations = get_type_hints(converter.to_python)
        if annotations is None or "return" not in annotations:
            self.logger.warning(
                f"route '{self.rule.endpoint}', argument '{arg}': using non-standard "
                "converter without type annotations, defaulting to 'string'",
            )
            return TSSimpleType("string")

        return_type, warning = self.translate_type(annotations["return"])
        if warning is not None:
            self.logger.warning(warning)
        return return_type
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import builtins
|
|
3
|
+
import inspect
|
|
4
|
+
import itertools
|
|
5
|
+
import textwrap
|
|
6
|
+
import types
|
|
7
|
+
import typing
|
|
8
|
+
|
|
9
|
+
if typing.TYPE_CHECKING:
|
|
10
|
+
from .extractor import Logger
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ASTVisitor(ast.NodeVisitor):
|
|
14
|
+
def __init__(
|
|
15
|
+
self, function: typing.Callable, logger: "Logger", can_eval: bool = False
|
|
16
|
+
) -> None:
|
|
17
|
+
self.function = function
|
|
18
|
+
self.logger = logger
|
|
19
|
+
self.can_eval = can_eval
|
|
20
|
+
self.locals: dict[str, typing.Any] = {}
|
|
21
|
+
self.returns: list = []
|
|
22
|
+
|
|
23
|
+
def get_variable(self, name: ast.Name) -> typing.Any:
|
|
24
|
+
local_var = self.locals.get(name.id, None)
|
|
25
|
+
if local_var is not None:
|
|
26
|
+
return local_var
|
|
27
|
+
global_var = self.function.__globals__.get(name.id, None)
|
|
28
|
+
if global_var is not None:
|
|
29
|
+
return global_var
|
|
30
|
+
builtin = getattr(builtins, name.id, None)
|
|
31
|
+
if isinstance(builtin, type):
|
|
32
|
+
return builtin
|
|
33
|
+
return None
|
|
34
|
+
|
|
35
|
+
def get_constant(self, constant: ast.Constant) -> typing.Any:
|
|
36
|
+
try:
|
|
37
|
+
constant_value = ast.literal_eval(constant)
|
|
38
|
+
return type(constant_value)
|
|
39
|
+
except Exception:
|
|
40
|
+
return None
|
|
41
|
+
|
|
42
|
+
def get_type_if_all_equal(self, expressions: list[ast.expr]) -> typing.Any:
|
|
43
|
+
if len(expressions) == 0:
|
|
44
|
+
return None
|
|
45
|
+
if len(expressions) == 1:
|
|
46
|
+
return self.get_value(expressions[0])
|
|
47
|
+
|
|
48
|
+
el_types: dict[ast.expr, typing.Any] = {}
|
|
49
|
+
for el1, el2 in itertools.pairwise(expressions):
|
|
50
|
+
el1_type = el_types[el1] if el1 in el_types else self.get_value(el1)
|
|
51
|
+
el_types.setdefault(el1, el1_type)
|
|
52
|
+
|
|
53
|
+
el2_type = el_types[el2] if el2 in el_types else self.get_value(el2)
|
|
54
|
+
el_types.setdefault(el2, el2_type)
|
|
55
|
+
if el1_type != el2_type:
|
|
56
|
+
return None
|
|
57
|
+
|
|
58
|
+
return el_types[el1]
|
|
59
|
+
|
|
60
|
+
def get_list(self, list_: ast.List) -> typing.Any:
|
|
61
|
+
list_type = self.get_type_if_all_equal(list_.elts)
|
|
62
|
+
return list if list_type is None else list[list_type]
|
|
63
|
+
|
|
64
|
+
def get_tuple(self, tuple_: ast.Tuple) -> typing.Any:
|
|
65
|
+
types = tuple(self.get_value(el) for el in tuple_.elts)
|
|
66
|
+
if len([t for t in types if t is None]) != 0:
|
|
67
|
+
return tuple
|
|
68
|
+
return tuple[types]
|
|
69
|
+
|
|
70
|
+
def get_dict(self, dict_: ast.Dict) -> typing.Any:
|
|
71
|
+
keys_type = self.get_type_if_all_equal(dict_.keys)
|
|
72
|
+
values_type = self.get_type_if_all_equal(dict_.values)
|
|
73
|
+
|
|
74
|
+
if keys_type is None and values_type is None:
|
|
75
|
+
return dict
|
|
76
|
+
|
|
77
|
+
return dict[
|
|
78
|
+
typing.Any if keys_type is None else keys_type,
|
|
79
|
+
typing.Any if values_type is None else values_type,
|
|
80
|
+
]
|
|
81
|
+
|
|
82
|
+
def get_value(self, expr: ast.expr) -> typing.Any:
|
|
83
|
+
match expr:
|
|
84
|
+
case ast.Name():
|
|
85
|
+
return self.get_variable(expr)
|
|
86
|
+
case ast.Constant():
|
|
87
|
+
return self.get_constant(expr)
|
|
88
|
+
case ast.List():
|
|
89
|
+
return self.get_list(expr)
|
|
90
|
+
case ast.Tuple():
|
|
91
|
+
return self.get_tuple(expr)
|
|
92
|
+
case ast.Dict():
|
|
93
|
+
return self.get_dict(expr)
|
|
94
|
+
case ast.Call():
|
|
95
|
+
return self.infer_call_type(expr)
|
|
96
|
+
case _:
|
|
97
|
+
return None
|
|
98
|
+
|
|
99
|
+
def from_func_call(self, func: ast.Name) -> typing.Any:
|
|
100
|
+
called_function = self.get_variable(func)
|
|
101
|
+
if called_function is None:
|
|
102
|
+
return None
|
|
103
|
+
|
|
104
|
+
if isinstance(called_function, type):
|
|
105
|
+
# This is a class
|
|
106
|
+
return called_function
|
|
107
|
+
|
|
108
|
+
origin = typing.get_origin(called_function) or called_function
|
|
109
|
+
if isinstance(origin, type):
|
|
110
|
+
# This is of the form type[T], so we should return T
|
|
111
|
+
args = typing.get_args(called_function)
|
|
112
|
+
if len(args) != 1:
|
|
113
|
+
return None
|
|
114
|
+
return args[0]
|
|
115
|
+
|
|
116
|
+
annotations = getattr(called_function, "__annotations__", {})
|
|
117
|
+
if "return" not in annotations:
|
|
118
|
+
return infer_return_type(called_function, self.logger, self.can_eval)
|
|
119
|
+
|
|
120
|
+
return annotations["return"]
|
|
121
|
+
|
|
122
|
+
def from_method_call(self, method: ast.Attribute) -> typing.Any:
|
|
123
|
+
value = self.get_value(method.value)
|
|
124
|
+
if value is None:
|
|
125
|
+
return None
|
|
126
|
+
|
|
127
|
+
func = getattr(value, method.attr, None)
|
|
128
|
+
if func is None or not callable(func):
|
|
129
|
+
return None
|
|
130
|
+
annotations = getattr(func, "__annotations__", {})
|
|
131
|
+
if "return" not in annotations:
|
|
132
|
+
return infer_return_type(func, self.logger, self.can_eval)
|
|
133
|
+
return annotations["return"]
|
|
134
|
+
|
|
135
|
+
def infer_call_type(self, call: ast.Call) -> typing.Any:
|
|
136
|
+
callable_ = call.func
|
|
137
|
+
|
|
138
|
+
match callable_:
|
|
139
|
+
case ast.Name():
|
|
140
|
+
return self.from_func_call(callable_)
|
|
141
|
+
case ast.Attribute():
|
|
142
|
+
return self.from_method_call(callable_)
|
|
143
|
+
case _:
|
|
144
|
+
return None
|
|
145
|
+
|
|
146
|
+
def visit_Return(self, node: ast.Return) -> None:
|
|
147
|
+
if node.value is not None:
|
|
148
|
+
self.returns.append(self.get_value(node.value))
|
|
149
|
+
self.generic_visit(node)
|
|
150
|
+
|
|
151
|
+
def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
|
|
152
|
+
annotation = self.get_value(node.annotation)
|
|
153
|
+
if annotation is None and self.can_eval:
|
|
154
|
+
annotation_string = ast.unparse(node.annotation)
|
|
155
|
+
annotation_code = compile(
|
|
156
|
+
annotation_string,
|
|
157
|
+
"<annotation>",
|
|
158
|
+
"eval",
|
|
159
|
+
0,
|
|
160
|
+
)
|
|
161
|
+
try:
|
|
162
|
+
annotation = eval(annotation_code, self.function.__globals__, {}) # noqa: S307
|
|
163
|
+
except Exception as e:
|
|
164
|
+
self.logger.warning(
|
|
165
|
+
f"failed to parse annotation {annotation_string!r}: {e!s}"
|
|
166
|
+
)
|
|
167
|
+
annotation = None
|
|
168
|
+
if annotation is not None:
|
|
169
|
+
self.locals[node.target.id] = annotation
|
|
170
|
+
self.generic_visit(node)
|
|
171
|
+
|
|
172
|
+
def visit_Assign(self, node: ast.Assign) -> None:
|
|
173
|
+
for target in node.targets:
|
|
174
|
+
if not isinstance(target, ast.Name):
|
|
175
|
+
continue
|
|
176
|
+
self.locals[target.id] = self.get_value(node.value)
|
|
177
|
+
self.generic_visit(node)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def smart_type(type_: typing.Any) -> typing.Any:
    """Wrap a class as ``type[cls]``; hand back any other object untouched.

    A captured class is a value whose type is ``type[cls]``. This lets calls
    to it resolve to ``cls`` later on.
    """
    return type[type_] if isinstance(type_, type) else type_
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def define_types_from_closure(function: typing.Callable, visitor: ASTVisitor) -> None:
    """Seed ``visitor.locals`` with the variables captured by ``function``.

    Each free variable is paired with its closure cell's contents. Classes
    are stored as ``type[cls]`` (see ``smart_type``). Functions without a
    closure or code object are ignored.

    Cells that are still empty are skipped instead of raising ``ValueError``;
    this happens when the enclosing variable has not been assigned yet.
    """
    closure: tuple[types.CellType, ...] | None = getattr(function, "__closure__", None)
    code: types.CodeType | None = getattr(function, "__code__", None)
    if code is None or closure is None:
        return

    for name, cell in zip(code.co_freevars, closure, strict=True):
        try:
            contents = cell.cell_contents
        except ValueError:
            continue
        visitor.locals[name] = smart_type(contents)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def define_types_from_signature(function: typing.Callable, visitor: ASTVisitor) -> None:
    """Seed ``visitor.locals`` with the annotated types of ``function``'s parameters.

    Unannotated parameters are recorded as ``typing.Any``. They still shadow
    same-named globals, but the ``inspect.Parameter.empty`` marker class is
    no longer mistaken for a real type. Callables whose signature cannot be
    inspected are ignored.
    """
    try:
        signature = inspect.signature(function)
    except (TypeError, ValueError):
        return

    visitor.locals.update(
        {
            name: (
                typing.Any
                if param.annotation is inspect.Parameter.empty
                else param.annotation
            )
            for name, param in signature.parameters.items()
        }
    )
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
# ids of the functions whose return type is currently being inferred. Used to
# stop infinite recursion through recursive or mutually recursive callees.
_INFERENCE_IN_PROGRESS: set[int] = set()


def infer_return_type(
    function: typing.Callable, logger: "Logger", can_eval: bool
) -> typing.Any:
    """Infer ``function``'s return type by statically analysing its source.

    Returns:
    - ``NoneType`` when the function has no ``return <expr>``.
    - The shared type when all returns agree.
    - ``typing.Any`` when they differ.
    - ``None`` when the source is unavailable or unsupported, or when
      ``function`` is already being inferred further up the call chain.
    """
    # Bound methods are re-created on each attribute access, so key on the
    # underlying function to recognise recursion through ``obj.method()``.
    key = id(getattr(function, "__func__", function))
    if key in _INFERENCE_IN_PROGRESS:
        return None
    _INFERENCE_IN_PROGRESS.add(key)
    try:
        return _infer_return_type(function, logger, can_eval)
    finally:
        _INFERENCE_IN_PROGRESS.discard(key)


def _infer_return_type(
    function: typing.Callable, logger: "Logger", can_eval: bool
) -> typing.Any:
    """Do the work of ``infer_return_type`` (without the recursion guard)."""
    try:
        source = inspect.getsource(function)
    except (TypeError, OSError):
        return None
    try:
        statements = ast.parse(textwrap.dedent(source)).body
    except SyntaxError:
        # Source fragments that don't parse on their own (possible e.g. for
        # lambdas embedded in a larger expression).
        return None
    if len(statements) != 1:
        # TODO: Handle this case
        return None

    body = statements[0]
    if not isinstance(body, (ast.FunctionDef, ast.AsyncFunctionDef)):
        return None

    visitor = ASTVisitor(function, logger, can_eval=can_eval)
    define_types_from_closure(function, visitor)
    define_types_from_signature(function, visitor)
    visitor.visit(body)
    if not visitor.returns:
        return type(None)

    for rt1, rt2 in itertools.pairwise(visitor.returns):
        if rt1 != rt2:
            return typing.Any

    return visitor.returns[0]
|