otterapi 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- README.md +581 -8
- otterapi/__init__.py +73 -0
- otterapi/cli.py +327 -29
- otterapi/codegen/__init__.py +115 -0
- otterapi/codegen/ast_utils.py +134 -5
- otterapi/codegen/client.py +1271 -0
- otterapi/codegen/codegen.py +1736 -0
- otterapi/codegen/dataframes.py +392 -0
- otterapi/codegen/emitter.py +473 -0
- otterapi/codegen/endpoints.py +2597 -343
- otterapi/codegen/pagination.py +1026 -0
- otterapi/codegen/schema.py +593 -0
- otterapi/codegen/splitting.py +1397 -0
- otterapi/codegen/types.py +1345 -0
- otterapi/codegen/utils.py +180 -1
- otterapi/config.py +1017 -24
- otterapi/exceptions.py +231 -0
- otterapi/openapi/__init__.py +46 -0
- otterapi/openapi/v2/__init__.py +86 -0
- otterapi/openapi/v2/spec.json +1607 -0
- otterapi/openapi/v2/v2.py +1776 -0
- otterapi/openapi/v3/__init__.py +131 -0
- otterapi/openapi/v3/spec.json +1651 -0
- otterapi/openapi/v3/v3.py +1557 -0
- otterapi/openapi/v3_1/__init__.py +133 -0
- otterapi/openapi/v3_1/spec.json +1411 -0
- otterapi/openapi/v3_1/v3_1.py +798 -0
- otterapi/openapi/v3_2/__init__.py +133 -0
- otterapi/openapi/v3_2/spec.json +1666 -0
- otterapi/openapi/v3_2/v3_2.py +777 -0
- otterapi/tests/__init__.py +3 -0
- otterapi/tests/fixtures/__init__.py +455 -0
- otterapi/tests/test_ast_utils.py +680 -0
- otterapi/tests/test_codegen.py +610 -0
- otterapi/tests/test_dataframe.py +1038 -0
- otterapi/tests/test_exceptions.py +493 -0
- otterapi/tests/test_openapi_support.py +616 -0
- otterapi/tests/test_openapi_upgrade.py +215 -0
- otterapi/tests/test_pagination.py +1101 -0
- otterapi/tests/test_splitting_config.py +319 -0
- otterapi/tests/test_splitting_integration.py +427 -0
- otterapi/tests/test_splitting_resolver.py +512 -0
- otterapi/tests/test_splitting_tree.py +525 -0
- otterapi-0.0.6.dist-info/METADATA +627 -0
- otterapi-0.0.6.dist-info/RECORD +48 -0
- {otterapi-0.0.5.dist-info → otterapi-0.0.6.dist-info}/WHEEL +1 -1
- otterapi/codegen/generator.py +0 -358
- otterapi/codegen/openapi_processor.py +0 -27
- otterapi/codegen/type_generator.py +0 -559
- otterapi-0.0.5.dist-info/METADATA +0 -54
- otterapi-0.0.5.dist-info/RECORD +0 -16
- {otterapi-0.0.5.dist-info → otterapi-0.0.6.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,1345 @@
|
|
|
1
|
+
"""Type definitions and generation for OtterAPI code generation.
|
|
2
|
+
|
|
3
|
+
This module provides:
|
|
4
|
+
- Type dataclasses for representing generated types, parameters, responses, and endpoints
|
|
5
|
+
- TypeGenerator for creating Pydantic models from OpenAPI schemas
|
|
6
|
+
- TypeRegistry for managing generated types and their dependencies
|
|
7
|
+
- ModelNameCollector for tracking model usage in generated code
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import ast
|
|
11
|
+
import dataclasses
|
|
12
|
+
from collections.abc import Iterable, Iterator
|
|
13
|
+
from datetime import datetime
|
|
14
|
+
from typing import Any, Literal
|
|
15
|
+
from uuid import UUID
|
|
16
|
+
|
|
17
|
+
from pydantic import BaseModel, Field, RootModel
|
|
18
|
+
|
|
19
|
+
from otterapi.codegen.ast_utils import _call, _name, _subscript, _union_expr
|
|
20
|
+
from otterapi.codegen.utils import (
|
|
21
|
+
OpenAPIProcessor,
|
|
22
|
+
sanitize_identifier,
|
|
23
|
+
sanitize_parameter_field_name,
|
|
24
|
+
)
|
|
25
|
+
from otterapi.openapi.v3_2 import Reference, Schema, Type as DataType
|
|
26
|
+
|
|
27
|
+
# Names exported by ``from otterapi.codegen.types import *``.
__all__ = [
    # Dataclasses describing generated artifacts
    'Type',
    'Parameter',
    'ResponseInfo',
    'RequestBodyInfo',
    'Endpoint',
    # Generation / registry helpers
    'TypeGenerator',
    'TypeInfo',
    'TypeRegistry',
    'ModelNameCollector',
    'collect_used_model_names',
]
|
|
39
|
+
|
|
40
|
+
# Lookup table from an OpenAPI ``(type, format)`` pair to the Python type used
# in generated annotations. A ``None`` format means "no format given"; pairs
# that are not listed here are resolved by the caller.
# NOTE(review): the 'date' format maps to ``datetime`` rather than
# ``datetime.date`` — confirm this is intentional before relying on it.
_PRIMITIVE_TYPE_MAP = {
    # JSON strings and the formats with a dedicated Python type
    ('string', None): str,
    ('string', 'date-time'): datetime,
    ('string', 'date'): datetime,
    ('string', 'uuid'): UUID,
    # JSON integers: every width collapses to ``int``
    ('integer', None): int,
    ('integer', 'int32'): int,
    ('integer', 'int64'): int,
    # JSON numbers: every precision collapses to ``float``
    ('number', None): float,
    ('number', 'float'): float,
    ('number', 'double'): float,
    # Booleans
    ('boolean', None): bool,
    # Explicit null, and a schema with neither type nor format
    ('null', None): None,
    (None, None): None,
}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclasses.dataclass
|
|
58
|
+
class Type:
|
|
59
|
+
reference: str | None # reference is None if type is 'primitive'
|
|
60
|
+
name: str | None
|
|
61
|
+
type: Literal['primitive', 'root', 'model']
|
|
62
|
+
annotation_ast: ast.expr | ast.stmt | None = dataclasses.field(default=None)
|
|
63
|
+
implementation_ast: ast.expr | ast.stmt | None = dataclasses.field(default=None)
|
|
64
|
+
dependencies: set[str] = dataclasses.field(default_factory=set)
|
|
65
|
+
implementation_imports: dict[str, set[str]] = dataclasses.field(
|
|
66
|
+
default_factory=dict
|
|
67
|
+
)
|
|
68
|
+
annotation_imports: dict[str, set[str]] = dataclasses.field(default_factory=dict)
|
|
69
|
+
|
|
70
|
+
def __hash__(self):
|
|
71
|
+
"""Make Type hashable based on its name (for use in sets/dicts)."""
|
|
72
|
+
# We only hash based on name since we use name as the key in the types dict
|
|
73
|
+
return hash(self.name)
|
|
74
|
+
|
|
75
|
+
def add_dependency(self, type_: 'Type') -> None:
|
|
76
|
+
self.dependencies.add(type_.name)
|
|
77
|
+
for dep in type_.dependencies:
|
|
78
|
+
self.dependencies.add(dep)
|
|
79
|
+
|
|
80
|
+
def add_implementation_import(self, module: str, name: str | Iterable[str]) -> None:
|
|
81
|
+
# Skip builtins - they don't need to be imported
|
|
82
|
+
if module == 'builtins':
|
|
83
|
+
return
|
|
84
|
+
|
|
85
|
+
if isinstance(name, str):
|
|
86
|
+
name = [name]
|
|
87
|
+
|
|
88
|
+
if module not in self.implementation_imports:
|
|
89
|
+
self.implementation_imports[module] = set()
|
|
90
|
+
|
|
91
|
+
for n in name:
|
|
92
|
+
self.implementation_imports[module].add(n)
|
|
93
|
+
|
|
94
|
+
def add_annotation_import(self, module: str, name: str | Iterable[str]) -> None:
|
|
95
|
+
# Skip builtins - they don't need to be imported
|
|
96
|
+
if module == 'builtins':
|
|
97
|
+
return
|
|
98
|
+
|
|
99
|
+
if isinstance(name, str):
|
|
100
|
+
name = [name]
|
|
101
|
+
|
|
102
|
+
if module not in self.annotation_imports:
|
|
103
|
+
self.annotation_imports[module] = set()
|
|
104
|
+
|
|
105
|
+
for n in name:
|
|
106
|
+
self.annotation_imports[module].add(n)
|
|
107
|
+
|
|
108
|
+
def copy_imports_from_sub_types(self, types: Iterable['Type']):
|
|
109
|
+
for t in types:
|
|
110
|
+
for module, names in t.annotation_imports.items():
|
|
111
|
+
self.add_annotation_import(module, names)
|
|
112
|
+
|
|
113
|
+
for module, names in t.implementation_imports.items():
|
|
114
|
+
self.add_implementation_import(module, names)
|
|
115
|
+
|
|
116
|
+
def __eq__(self, other):
|
|
117
|
+
"""Deep comparison of Type objects, including AST nodes."""
|
|
118
|
+
if not isinstance(other, Type):
|
|
119
|
+
return False
|
|
120
|
+
|
|
121
|
+
# Compare simple fields
|
|
122
|
+
if (
|
|
123
|
+
self.reference != other.reference
|
|
124
|
+
or self.name != other.name
|
|
125
|
+
or self.type != other.type
|
|
126
|
+
):
|
|
127
|
+
return False
|
|
128
|
+
|
|
129
|
+
# Compare AST nodes by dumping them to strings
|
|
130
|
+
# Compare annotation AST (can be None)
|
|
131
|
+
if self.annotation_ast is None and other.annotation_ast is None:
|
|
132
|
+
pass # Both None, equal
|
|
133
|
+
elif self.annotation_ast is None or other.annotation_ast is None:
|
|
134
|
+
return False # One is None, other isn't
|
|
135
|
+
else:
|
|
136
|
+
if ast.dump(self.annotation_ast) != ast.dump(other.annotation_ast):
|
|
137
|
+
return False
|
|
138
|
+
|
|
139
|
+
# Compare implementation AST (can be None)
|
|
140
|
+
if self.implementation_ast is None and other.implementation_ast is None:
|
|
141
|
+
pass # Both None, equal
|
|
142
|
+
elif self.implementation_ast is None or other.implementation_ast is None:
|
|
143
|
+
return False # One is None, other isn't
|
|
144
|
+
else:
|
|
145
|
+
if ast.dump(self.implementation_ast) != ast.dump(other.implementation_ast):
|
|
146
|
+
return False
|
|
147
|
+
|
|
148
|
+
# Compare imports and dependencies
|
|
149
|
+
if (
|
|
150
|
+
self.dependencies != other.dependencies
|
|
151
|
+
or self.implementation_imports != other.implementation_imports
|
|
152
|
+
or self.annotation_imports != other.annotation_imports
|
|
153
|
+
):
|
|
154
|
+
return False
|
|
155
|
+
|
|
156
|
+
return True
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
@dataclasses.dataclass
class Parameter:
    """One input of a generated endpoint function.

    Attributes:
        name: Parameter name as written in the OpenAPI document.
        name_sanitized: Sanitized variant of ``name`` (computed by the caller).
        location: Where the value is sent: query, path, header, cookie or body.
        required: Whether callers must supply the value.
        type: Generated type of the value, when known.
        description: Free-text description from the spec, if any.
    """

    name: str
    name_sanitized: str
    location: Literal['query', 'path', 'header', 'cookie', 'body']
    required: bool
    type: Type | None = dataclasses.field(default=None)
    description: str | None = dataclasses.field(default=None)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
@dataclasses.dataclass
class ResponseInfo:
    """Information about a response for a specific status code.

    Attributes:
        status_code: The HTTP status code for this response.
        content_type: The content type as written in the spec (e.g.
            'application/json', 'application/octet-stream'). The ``is_*`` checks
            ignore letter case and media-type parameters such as
            '; charset=utf-8'.
        type: The Type object for JSON responses, or None for raw responses.
    """

    status_code: int
    content_type: str
    # String annotation, as in ``Endpoint``; ``Type`` is not needed at class creation.
    type: 'Type | None' = None

    @property
    def _media_type(self) -> str:
        """``content_type`` without parameters, trimmed and lower-cased."""
        # Media types are case-insensitive and may carry ';'-separated parameters.
        return self.content_type.split(';', 1)[0].strip().lower()

    @property
    def is_json(self) -> bool:
        """Check if this is a JSON response (including ``+json`` suffixed types)."""
        media_type = self._media_type
        return media_type in ('application/json', 'text/json') or media_type.endswith(
            '+json'
        )

    @property
    def is_binary(self) -> bool:
        """Check if this is a binary response (file download).

        JSON media types are never binary, even under a binary-looking prefix
        such as ``application/vnd.api+json``.
        """
        if self.is_json:
            return False
        media_type = self._media_type
        binary_types = (
            'application/octet-stream',
            'application/pdf',
            'application/zip',
            'application/gzip',
            'application/x-tar',
            'application/x-rar-compressed',
        )
        binary_prefixes = ('image/', 'audio/', 'video/', 'application/vnd.')
        return media_type in binary_types or media_type.startswith(binary_prefixes)

    @property
    def is_text(self) -> bool:
        """Check if this is a plain text response."""
        return self._media_type.startswith('text/') and not self.is_json

    @property
    def is_raw(self) -> bool:
        """Check if this is an unknown content type that should return the raw httpx.Response."""
        return not (self.is_json or self.is_binary or self.is_text)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
@dataclasses.dataclass
|
|
219
|
+
class RequestBodyInfo:
|
|
220
|
+
"""Information about a request body including its content type.
|
|
221
|
+
|
|
222
|
+
Attributes:
|
|
223
|
+
content_type: The content type (e.g., 'application/json', 'multipart/form-data').
|
|
224
|
+
type: The Type object for the body schema, or None if no schema.
|
|
225
|
+
required: Whether the request body is required.
|
|
226
|
+
description: Optional description of the request body.
|
|
227
|
+
"""
|
|
228
|
+
|
|
229
|
+
content_type: str
|
|
230
|
+
type: Type | None = None
|
|
231
|
+
required: bool = False
|
|
232
|
+
description: str | None = None
|
|
233
|
+
|
|
234
|
+
@property
|
|
235
|
+
def is_json(self) -> bool:
|
|
236
|
+
"""Check if this is a JSON request body."""
|
|
237
|
+
return self.content_type in (
|
|
238
|
+
'application/json',
|
|
239
|
+
'text/json',
|
|
240
|
+
) or self.content_type.endswith('+json')
|
|
241
|
+
|
|
242
|
+
@property
|
|
243
|
+
def is_form(self) -> bool:
|
|
244
|
+
"""Check if this is a form-encoded request body."""
|
|
245
|
+
return self.content_type == 'application/x-www-form-urlencoded'
|
|
246
|
+
|
|
247
|
+
@property
|
|
248
|
+
def is_multipart(self) -> bool:
|
|
249
|
+
"""Check if this is a multipart form data request body."""
|
|
250
|
+
return self.content_type == 'multipart/form-data'
|
|
251
|
+
|
|
252
|
+
@property
|
|
253
|
+
def is_binary(self) -> bool:
|
|
254
|
+
"""Check if this is a binary request body."""
|
|
255
|
+
return self.content_type in ('application/octet-stream',)
|
|
256
|
+
|
|
257
|
+
@property
|
|
258
|
+
def httpx_param_name(self) -> str:
|
|
259
|
+
"""Get the httpx parameter name for this content type.
|
|
260
|
+
|
|
261
|
+
Returns:
|
|
262
|
+
The appropriate httpx parameter name: 'json', 'data', 'files', or 'content'.
|
|
263
|
+
"""
|
|
264
|
+
if self.is_json:
|
|
265
|
+
return 'json'
|
|
266
|
+
elif self.is_form:
|
|
267
|
+
return 'data'
|
|
268
|
+
elif self.is_multipart:
|
|
269
|
+
return 'files'
|
|
270
|
+
elif self.is_binary:
|
|
271
|
+
return 'content'
|
|
272
|
+
else:
|
|
273
|
+
return 'content'
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
@dataclasses.dataclass
|
|
277
|
+
class Endpoint:
|
|
278
|
+
"""Represents a generated API endpoint with sync and async functions."""
|
|
279
|
+
|
|
280
|
+
# AST nodes
|
|
281
|
+
sync_ast: ast.FunctionDef
|
|
282
|
+
async_ast: ast.AsyncFunctionDef
|
|
283
|
+
|
|
284
|
+
# Function names
|
|
285
|
+
sync_fn_name: str
|
|
286
|
+
async_fn_name: str
|
|
287
|
+
|
|
288
|
+
# Endpoint metadata
|
|
289
|
+
name: str
|
|
290
|
+
method: str = ''
|
|
291
|
+
path: str = ''
|
|
292
|
+
description: str | None = None
|
|
293
|
+
tags: list[str] | None = None # OpenAPI tags for module splitting
|
|
294
|
+
|
|
295
|
+
# Parameters and body
|
|
296
|
+
parameters: list['Parameter'] | None = None
|
|
297
|
+
request_body: 'RequestBodyInfo | None' = None
|
|
298
|
+
|
|
299
|
+
# Response info
|
|
300
|
+
response_type: 'Type | None' = None
|
|
301
|
+
response_infos: list['ResponseInfo'] | None = None
|
|
302
|
+
|
|
303
|
+
# Imports needed
|
|
304
|
+
imports: dict[str, set[str]] = dataclasses.field(default_factory=dict)
|
|
305
|
+
|
|
306
|
+
@property
|
|
307
|
+
def fn(self) -> ast.FunctionDef:
|
|
308
|
+
"""Alias for sync_ast."""
|
|
309
|
+
return self.sync_ast
|
|
310
|
+
|
|
311
|
+
@property
|
|
312
|
+
def async_fn(self) -> ast.AsyncFunctionDef:
|
|
313
|
+
"""Alias for async_ast."""
|
|
314
|
+
return self.async_ast
|
|
315
|
+
|
|
316
|
+
def add_imports(self, imports: list[dict[str, set[str]]]):
|
|
317
|
+
for imports_ in imports:
|
|
318
|
+
for module, names in imports_.items():
|
|
319
|
+
if module not in self.imports:
|
|
320
|
+
self.imports[module] = set()
|
|
321
|
+
self.imports[module].update(names)
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
@dataclasses.dataclass
|
|
325
|
+
class TypeGenerator(OpenAPIProcessor):
|
|
326
|
+
types: dict[str, Type] = dataclasses.field(default_factory=dict)
|
|
327
|
+
|
|
328
|
+
def add_type(self, type_: Type, base_name: str | None = None) -> Type:
    """Add a type to the registry under a name that is unique in ``self.types``.

    - A new name is stored as is.
    - If a structurally identical type (see ``Type.__eq__``) is already
      registered, under the same name or under one of the disambiguated names
      this method generates, that existing instance is returned. Identical
      definitions are therefore emitted only once.
    - Otherwise ``type_`` is renamed in place and registered. The first
      candidate is ``{base_name}{name}`` when ``base_name`` is given, followed
      by numbered suffixes ``{prefix}1``, ``{prefix}2``, ...

    Args:
        type_: The type to register. Unnamed types are returned untouched.
        base_name: Optional prefix used to disambiguate a clashing name.

    Returns:
        The registered type: either an existing identical one or ``type_``
        (possibly renamed).
    """
    # Skip types without names (primitive types, inline types, etc.)
    if not type_.name:
        return type_

    existing = self.types.get(type_.name)
    if existing is None:
        self.types[type_.name] = type_
        return type_
    if existing == type_:
        # Same type already registered; avoids Detail20, Detail21 when identical.
        return existing

    prefix = type_.name
    if base_name:
        prefix = f'{base_name}{type_.name}'
        match = self._registered_as(prefix, type_)
        if match is not None:
            return match
        if prefix not in self.types:
            self._rename_type(type_, prefix)
            self.types[prefix] = type_
            return type_

    # Walk numbered candidates, reusing an identical one if present.
    counter = 1
    while (candidate := f'{prefix}{counter}') in self.types:
        match = self._registered_as(candidate, type_)
        if match is not None:
            return match
        counter += 1

    self._rename_type(type_, candidate)
    self.types[candidate] = type_
    return type_


def _registered_as(self, name: str, type_: Type) -> Type | None:
    """Return the type registered under ``name`` if it equals ``type_`` renamed to ``name``."""
    registered = self.types.get(name)
    if registered is None:
        return None
    # Compare against a renamed copy; ``Type.__eq__`` also compares names and
    # ClassDef names, so comparing ``type_`` directly could never match.
    if registered == self._renamed_copy(type_, name):
        return registered
    return None


def _renamed_copy(self, type_: Type, new_name: str) -> Type:
    """Return a copy of ``type_`` carrying ``new_name``, leaving ``type_`` untouched."""
    import copy

    implementation = type_.implementation_ast
    if isinstance(implementation, ast.ClassDef):
        # A shallow copy is enough: only the ``name`` attribute is changed.
        implementation = copy.copy(implementation)
        implementation.name = new_name
    return dataclasses.replace(
        type_,
        name=new_name,
        annotation_ast=_name(new_name),
        implementation_ast=implementation,
    )


def _rename_type(self, type_: Type, new_name: str) -> None:
    """Rename ``type_`` in place, keeping its annotation and class definition in sync."""
    type_.name = new_name
    type_.annotation_ast = _name(new_name)
    if isinstance(type_.implementation_ast, ast.ClassDef):
        type_.implementation_ast.name = new_name
|
|
392
|
+
|
|
393
|
+
def _resolve_reference(
    self, reference: Reference | Schema
) -> tuple[Schema, str | None]:
    """Resolve a ``$ref`` to its target schema and a sanitized name.

    Args:
        reference: A reference object (anything with a non-empty ``ref``) or
            an inline schema. A schema whose ``ref`` attribute is unset is
            treated as inline.

    Returns:
        ``(schema, name)``. For a reference, this is the schema stored under
        ``#/components/schemas/<name>`` and the sanitized ``<name>``. For an
        inline schema, it is the schema itself and its sanitized ``title``
        (``None`` when it has no title).

    Raises:
        ValueError: If the reference does not point into
            ``#/components/schemas/`` or the target schema does not exist.
    """
    # ``getattr`` + truthiness: a schema object may expose ``ref`` set to None.
    ref = getattr(reference, 'ref', None)
    if not ref:
        title = getattr(reference, 'title', None)
        return reference, sanitize_identifier(title) if title else None

    if not ref.startswith('#/components/schemas/'):
        raise ValueError(f'Unsupported reference format: {ref}')

    # Last JSON-pointer segment, unescaped per RFC 6901 ('~1' -> '/', then '~0' -> '~').
    schema_name = ref.split('/')[-1].replace('~1', '/').replace('~0', '~')

    # Tolerate documents without components or component schemas.
    components = self.openapi.components
    schemas = (components.schemas if components is not None else None) or {}

    if schema_name not in schemas:
        raise ValueError(
            f"Referenced schema '{schema_name}' not found in components.schemas"
        )

    return schemas[schema_name], sanitize_identifier(schema_name)
|
|
410
|
+
|
|
411
|
+
def _create_enum_type(
    self,
    schema: Schema,
    name: str | None = None,
    base_name: str | None = None,
    field_name: str | None = None,
) -> Type:
    """Create (or reuse) an Enum class for a schema with ``enum`` values.

    Naming:
    - The class name is, in order of preference: ``name``, the sanitized
      schema ``title``, ``<BaseName><FieldName>``, ``<FieldName>``,
      ``<BaseName>Enum``, then ``AutoEnum``.
    - A numeric suffix is appended when the name is already taken.

    Reuse and fallback:
    - If an Enum class with the same member values was generated before, that
      type is returned. Values are compared by value and Python type, and
      ``None`` is ignored.
    - If ``null`` is the only enum value, a ``Literal`` annotation is built
      instead.

    Args:
        schema: The schema containing enum values.
        name: Optional explicit name for the enum.
        base_name: Optional base name prefix (e.g., parent model name).
        field_name: Optional field name this enum is used for (e.g., 'status').

    Returns:
        A ``'model'`` Type holding the Enum ``ClassDef`` (registered in
        ``self.types``), an existing equivalent Enum type, or a Literal type.
    """
    values = [value for value in schema.enum if value is not None]
    if not values:
        # Nothing to turn into members (only nulls): fall back to Literal.
        return self._create_literal_type(schema)

    existing = self._find_equivalent_enum(values)
    if existing is not None:
        return existing

    enum_name = name or (sanitize_identifier(schema.title) if schema.title else None)
    if not enum_name:
        enum_name = self._derive_enum_name(base_name, field_name)
    if enum_name in self.types:
        counter = 1
        while f'{enum_name}{counter}' in self.types:
            counter += 1
        enum_name = f'{enum_name}{counter}'

    members: list[ast.stmt] = []
    used_member_names: set[str] = set()
    for value in values:
        member_name = self._enum_member_name(value)
        if member_name in used_member_names:
            # e.g. 'mesoderm' and 'Mesoderm' both -> 'MESODERM'; the loop also
            # avoids clashing with a value that literally sanitizes to 'MESODERM_1'.
            suffix = 1
            while f'{member_name}_{suffix}' in used_member_names:
                suffix += 1
            member_name = f'{member_name}_{suffix}'
        used_member_names.add(member_name)
        members.append(
            ast.Assign(
                targets=[ast.Name(id=member_name, ctx=ast.Store())],
                value=ast.Constant(value=value),
            )
        )

    # ``str`` mixin for string enums. The non-null type is used so that
    # OpenAPI 3.1 type lists such as ["string", "null"] are handled.
    json_type = self._get_non_null_type(schema)
    is_string_enum = getattr(json_type, 'value', json_type) == 'string'
    bases = [_name('str'), _name('Enum')] if is_string_enum else [_name('Enum')]

    enum_class = ast.ClassDef(
        name=enum_name,
        bases=bases,
        keywords=[],
        body=members,
        decorator_list=[],
        type_params=[],
    )

    type_ = Type(
        reference=None,
        name=enum_name,
        annotation_ast=_name(enum_name),
        implementation_ast=enum_class,
        type='model',  # Treat as model so it gets included in models.py
    )
    type_.add_implementation_import('enum', 'Enum')

    self.types[enum_name] = type_
    return type_


def _derive_enum_name(self, base_name: str | None, field_name: str | None) -> str:
    """Build a PascalCase enum name from the parent model and field names."""
    if field_name:
        field_part = sanitize_identifier(field_name)
        if field_part:
            field_part = field_part[0].upper() + field_part[1:]
        if base_name:
            return f'{sanitize_identifier(base_name)}{field_part}'
        # Guard against a field name that sanitizes to nothing.
        return field_part or 'AutoEnum'
    if base_name:
        return f'{sanitize_identifier(base_name)}Enum'
    return 'AutoEnum'


def _enum_values_key(self, values) -> tuple[tuple[str, str], ...]:
    """Order-insensitive, type-aware fingerprint of enum values (``None`` ignored)."""
    return tuple(
        sorted((type(value).__name__, str(value)) for value in values if value is not None)
    )


def _find_equivalent_enum(self, values: list[Any]) -> Type | None:
    """Return an already generated Enum type whose members hold exactly ``values``."""
    wanted = self._enum_values_key(values)
    for candidate in self.types.values():
        definition = candidate.implementation_ast
        if candidate.type != 'model' or not isinstance(definition, ast.ClassDef):
            continue
        if not any(
            isinstance(base, ast.Name) and base.id == 'Enum' for base in definition.bases
        ):
            continue
        member_values = [
            node.value.value
            for node in definition.body
            if isinstance(node, ast.Assign) and isinstance(node.value, ast.Constant)
        ]
        if self._enum_values_key(member_values) == wanted:
            return candidate
    return None


def _enum_member_name(self, value: Any) -> str:
    """Turn an enum value into an upper-case Python identifier for its member."""
    if isinstance(value, str):
        member_name = sanitize_identifier(value).upper()
        # Identifiers cannot start with a digit.
        if member_name and member_name[0].isdigit():
            member_name = f'_{member_name}'
        # NOTE(review): fallback for values that sanitize to an empty string
        # (e.g. ''), assuming sanitize_identifier can return one; confirm.
        return member_name or 'EMPTY'
    # Non-string values become VALUE_<value>. '-' and '.' would be invalid in an
    # identifier (-1 -> VALUE_MINUS_1, 1.5 -> VALUE_1_5); non-negative ints are unchanged.
    text = str(value).replace('-', 'MINUS_')
    text = ''.join(ch if ch.isalnum() or ch == '_' else '_' for ch in text)
    return f'VALUE_{text}'
|
|
550
|
+
|
|
551
|
+
def _create_literal_type(self, schema: Schema) -> Type:
    """Build an inline ``Literal[...]`` annotation from ``schema.enum``.

    This is the fallback used when no Enum class is generated.
    """
    title = schema.title
    literal = Type(
        reference=None,
        name=sanitize_identifier(title) if title else None,
        type='primitive',
        annotation_ast=_subscript(
            'Literal',
            ast.Tuple(
                elts=[ast.Constant(value=member) for member in schema.enum],
                ctx=ast.Load(),
            ),
        ),
        implementation_ast=None,
    )
    literal.add_annotation_import('typing', 'Literal')
    return literal
|
|
565
|
+
|
|
566
|
+
def _is_nullable(self, schema: Schema) -> bool:
    """Tell whether ``schema`` admits ``null``.

    OpenAPI 3.1+ expresses nullability as a type list such as
    ``["string", "null"]``. A single, non-list type is never treated as nullable.
    """
    declared = schema.type
    if not isinstance(declared, list):
        return False
    for entry in declared:
        if entry == DataType.null:
            return True
        if hasattr(entry, 'value') and entry.value == 'null':
            return True
    return False
|
|
577
|
+
|
|
578
|
+
def _get_non_null_type(self, schema: Schema) -> DataType | None:
    """Return the first declared type that is not ``null``.

    A non-list ``schema.type`` is returned unchanged. A list containing only
    ``null`` entries yields ``None``.
    """
    declared = schema.type
    if not isinstance(declared, list):
        return declared
    return next(
        (
            entry
            for entry in declared
            if entry != DataType.null
            and not (hasattr(entry, 'value') and entry.value == 'null')
        ),
        None,
    )
|
|
588
|
+
|
|
589
|
+
def _make_nullable_type(self, base_type: Type) -> Type:
    """Return a copy of ``base_type`` whose annotation also admits ``None``.

    The copy gets its own dependency set and import mappings, including
    fresh per-module name sets. Imports added to it later therefore do not
    leak into ``base_type``; the previous shallow ``dict.copy()`` shared the
    inner sets. An annotation that already is ``None`` is not wrapped again,
    which avoids generating ``None | None``.
    """
    annotation = base_type.annotation_ast
    is_none_annotation = (
        isinstance(annotation, ast.Constant) and annotation.value is None
    ) or (isinstance(annotation, ast.Name) and annotation.id == 'None')
    if not is_none_annotation:
        annotation = _union_expr([annotation, ast.Constant(value=None)])

    return Type(
        reference=base_type.reference,
        name=base_type.name,
        annotation_ast=annotation,
        implementation_ast=base_type.implementation_ast,
        type=base_type.type,
        dependencies=set(base_type.dependencies),
        implementation_imports={
            module: set(names)
            for module, names in base_type.implementation_imports.items()
        },
        annotation_imports={
            module: set(names) for module, names in base_type.annotation_imports.items()
        },
    )
|
|
604
|
+
|
|
605
|
+
def _get_primitive_type_ast(
    self,
    schema: Schema,
    base_name: str | None = None,
    field_name: str | None = None,
) -> Type:
    """Map a scalar schema to a primitive ``Type``.

    Resolution order:
    - Schemas with ``enum`` are delegated to ``_create_enum_type``.
    - Otherwise the ``(type, format)`` pair is looked up in ``_PRIMITIVE_TYPE_MAP``.
    - If the type is known but the format is not, the entry for the bare type
      is used (e.g. ``format: email`` on a string keeps the string's type,
      instead of degrading to ``Any``).
    - ``format: binary`` is excluded from that fallback and, like every other
      unmapped pair, becomes ``Any``.
    - Nullable schemas (``type: [X, "null"]``) are widened to admit ``None``.

    Args:
        schema: The scalar schema.
        base_name: Context name forwarded to enum generation.
        field_name: Field name forwarded to enum generation.
    """
    if schema.enum:
        return self._create_enum_type(
            schema, base_name=base_name, field_name=field_name
        )

    is_nullable = self._is_nullable(schema)
    actual_type = self._get_non_null_type(schema)

    # ``schema.type`` is usually a DataType enum member; accept plain strings too.
    type_value = getattr(actual_type, 'value', actual_type) if actual_type else None
    type_format = schema.format or None

    if (type_value, type_format) in _PRIMITIVE_TYPE_MAP:
        mapped = _PRIMITIVE_TYPE_MAP[(type_value, type_format)]
    elif type_value is not None and type_format not in (None, 'binary'):
        # Unknown format on a known type: keep the base type rather than Any.
        mapped = _PRIMITIVE_TYPE_MAP.get((type_value, None), Any)
    else:
        mapped = Any

    type_ = Type(
        None,
        sanitize_identifier(schema.title) if schema.title else None,
        annotation_ast=_name(mapped.__name__ if mapped is not None else 'None'),
        implementation_ast=None,
        type='primitive',
    )

    if mapped is not None and mapped.__module__ != 'builtins':
        type_.add_annotation_import(mapped.__module__, mapped.__name__)

    if is_nullable:
        type_ = self._make_nullable_type(type_)

    return type_
|
|
642
|
+
|
|
643
|
+
def _create_pydantic_field(
    self,
    field_name: str,
    field_schema: Schema,
    field_type: Type,
    is_required: bool = False,
    is_nullable: bool = False,
) -> ast.AnnAssign:
    """Build the ``name: annotation = Field(...)`` statement for one model property.

    Args:
        field_name: Property name as written in the schema.
        field_schema: Property schema; a reference is resolved first.
        field_type: Generated type of the property.
        is_required: Whether the property must be present. Required fields get
            no implicit default.
        is_nullable: Whether ``None`` must be accepted. The annotation is
            widened with ``None`` unless it already admits it.

    Returns:
        An ``ast.AnnAssign``. Its value is a ``Field(...)`` call when a default
        or an alias is needed, and ``None`` otherwise.
    """
    # ``getattr`` + truthiness: a schema object may expose ``ref`` set to None.
    if getattr(field_schema, 'ref', None):
        field_schema, _ = self._resolve_reference(field_schema)

    keywords: list[ast.keyword] = []

    annotation_ast = field_type.annotation_ast
    if is_nullable and not self._type_already_nullable(field_type):
        annotation_ast = _union_expr(
            [field_type.annotation_ast, ast.Constant(value=None)]
        )

    default = field_schema.default
    if isinstance(default, (str, int, float, bool)):
        keywords.append(ast.keyword(arg='default', value=ast.Constant(default)))
    elif not is_required:
        # An optional field without a usable literal default (absent, null, or a
        # list/object value) defaults to None. Otherwise pydantic would treat it
        # as required. Required fields (nullable or not) get no default.
        keywords.append(ast.keyword(arg='default', value=ast.Constant(None)))

    python_name = sanitize_parameter_field_name(field_name)
    if python_name != field_name:
        # Keep the original wire name reachable through an alias.
        keywords.append(ast.keyword(arg='alias', value=ast.Constant(field_name)))

    value = None
    if keywords:
        value = _call(func=_name(Field.__name__), keywords=keywords)
        # NOTE(review): this mutates the caller's ``field_type``, which may be a
        # type already registered in ``self.types``. ``Type.__eq__`` compares
        # imports, so this can affect later de-duplication — confirm intended.
        field_type.add_implementation_import(
            module=Field.__module__, name=Field.__name__
        )

    return ast.AnnAssign(
        target=_name(python_name),
        annotation=annotation_ast,
        value=value,
        simple=1,
    )
|
|
702
|
+
|
|
703
|
+
def _type_already_nullable(self, type_: 'Type') -> bool:
    """Tell whether ``type_``'s annotation already admits ``None``.

    Recognised forms:
    - ``Union[..., None]`` (nested members are searched).
    - ``Optional[...]``.
    - PEP 604 unions such as ``X | None`` (``ast.BinOp`` with ``|``).
    - A bare ``None`` annotation.

    ``None`` may appear as a constant or as the name ``None``. Recognising
    these forms avoids emitting redundant or invalid annotations such as
    ``X | None | None`` or ``None | None``.
    """

    def is_none(node) -> bool:
        return (isinstance(node, ast.Constant) and node.value is None) or (
            isinstance(node, ast.Name) and node.id == 'None'
        )

    def admits_none(node) -> bool:
        if is_none(node):
            return True
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
            return admits_none(node.left) or admits_none(node.right)
        if isinstance(node, ast.Subscript) and isinstance(node.value, ast.Name):
            if node.value.id == 'Optional':
                return True
            if node.value.id == 'Union':
                members = (
                    node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
                )
                return any(admits_none(member) for member in members)
        return False

    return admits_none(type_.annotation_ast)
|
|
714
|
+
|
|
715
|
+
def _create_pydantic_root_model(
    self,
    schema: Schema,
    item_type: Type | None = None,
    name: str | None = None,
    base_name: str | None = None,
) -> Type:
    """Generate ``class <Name>(RootModel[<item>]): pass`` and register it.

    Args:
        schema: Schema the model is generated for. Its ``title`` is the last
            name fallback.
        item_type: Type wrapped by the root model. When ``None`` (the default)
            the model wraps ``Any``; previously this raised AttributeError.
        name: Explicit class name; highest priority.
        base_name: Second-priority name. It is also passed to ``add_type`` to
            disambiguate clashes.

    Returns:
        The registered type, which may be an existing identical one.

    Raises:
        ValueError: If no name can be determined.
    """
    name = (
        name
        or base_name
        or (sanitize_identifier(schema.title) if schema.title else None)
    )
    if not name:
        raise ValueError('Root model must have a name')

    root_annotation = (
        item_type.annotation_ast if item_type is not None else _name('Any')
    )
    model = ast.ClassDef(
        name=name,
        bases=[_subscript(RootModel.__name__, root_annotation)],
        keywords=[],
        body=[ast.Pass()],
        decorator_list=[],
        type_params=[],
    )

    type_ = Type(
        reference=None,
        name=name,
        annotation_ast=_name(name),
        implementation_ast=model,
        type='root',
    )
    type_.add_implementation_import(
        module=RootModel.__module__, name=RootModel.__name__
    )
    if item_type is not None:
        type_.copy_imports_from_sub_types([item_type])
        type_.add_dependency(item_type)
    else:
        # ``Any`` appears in the class bases, i.e. in the implementation.
        type_.add_implementation_import(module='typing', name='Any')

    return self.add_type(type_, base_name=base_name)
|
|
755
|
+
|
|
756
|
+
def _create_pydantic_model(
    self, schema: Schema, name: str | None = None, base_name: str | None = None
) -> Type:
    """Create the type for an object schema.

    ``anyOf``/``oneOf`` schemas become an inline, unnamed union of their
    variants. Any other schema becomes a Pydantic model class. Its bases come
    from the named ``allOf`` entries, falling back to ``BaseModel``, and its
    fields come from ``properties``.

    Args:
        schema: The (already resolved) object schema.
        name: Class name; defaults to the sanitized ``schema.title`` or
            ``'UnnamedModel'``.
        base_name: Naming context forwarded to nested type generation and to
            ``self.add_type``.

    Returns:
        The union ``Type`` for composed schemas. Otherwise, the model ``Type``
        as returned by ``self.add_type``.
    """
    # Each allOf entry is turned into a candidate base class.
    base_bases = [
        self._create_object_type(schema=base_schema, base_name=base_name)
        for base_schema in (schema.allOf or [])
    ]

    if schema.anyOf or schema.oneOf:
        # Use schema_to_type for each variant to properly handle primitives, objects, etc.
        types_ = [
            self.schema_to_type(t, base_name=base_name)
            for t in (schema.anyOf or schema.oneOf)
        ]

        union_type = Type(
            reference=None,
            name=None,  # Union type doesn't need a name, it's used inline
            annotation_ast=_union_expr(types=[t.annotation_ast for t in types_]),
            implementation_ast=None,
            type='primitive',
        )
        union_type.copy_imports_from_sub_types(types_)
        return union_type

    name = name or (
        sanitize_identifier(schema.title) if schema.title else 'UnnamedModel'
    )

    # Only named types can be used as class bases. An allOf entry that resolves
    # to an unnamed type (e.g. a property-less object, which
    # ``_create_object_type`` emits as an inline ``dict[str, ...]``) would
    # otherwise become ``ast.Name(id=None)`` and break code emission.
    base_class_names = [b.name for b in base_bases if b.name]
    bases = [_name(base) for base in base_class_names or [BaseModel.__name__]]

    body = []
    field_types = []
    # Whether a field is required comes from the parent schema's ``required`` array.
    required_fields = set(schema.required or [])
    for property_name, property_schema in (schema.properties or {}).items():
        # Resolve references so nullability is read from the target schema.
        resolved_schema = property_schema
        if hasattr(property_schema, 'ref') and property_schema.ref:
            resolved_schema, _ = self._resolve_reference(property_schema)

        # Check if field is nullable (type array with null)
        is_nullable = (
            self._is_nullable(resolved_schema) if resolved_schema else False
        )

        type_ = self.schema_to_type(
            property_schema, base_name=base_name, field_name=property_name
        )
        is_required = property_name in required_fields
        field = self._create_pydantic_field(
            property_name, property_schema, type_, is_required, is_nullable
        )

        body.append(field)
        field_types.append(type_)

    # Prepend a deprecation docstring if the schema is deprecated.
    if schema.deprecated:
        deprecation_doc = ast.Expr(
            value=ast.Constant(
                value=f'{name} is deprecated.\n\n.. deprecated::\n This model is deprecated.'
            )
        )
        body = [deprecation_doc, *body]

    model = ast.ClassDef(
        name=name,
        bases=bases,
        keywords=[],
        body=body or [ast.Pass()],
        decorator_list=[],
        type_params=[],
    )

    type_ = Type(
        reference=None,
        name=name,
        annotation_ast=_name(name),
        implementation_ast=model,
        dependencies=set(),
        type='model',
    )

    # Add base class dependencies
    for base in base_bases:
        type_.add_dependency(base)

    # Add field type dependencies
    for field_type in field_types:
        if field_type.name:
            type_.dependencies.add(field_type.name)
        type_.dependencies.update(field_type.dependencies)

    type_.add_implementation_import(
        module=BaseModel.__module__, name=BaseModel.__name__
    )
    type_.add_implementation_import(module=Field.__module__, name=Field.__name__)
    type_.copy_imports_from_sub_types(field_types)

    type_ = self.add_type(type_, base_name=base_name)
    return type_
|
|
861
|
+
|
|
862
|
+
def _create_array_type(
    self, schema: Schema, name: str | None = None, base_name: str | None = None
) -> Type:
    """Translate an ``array`` schema into an inline ``list[...]`` type.

    With an ``items`` schema, the element type is generated via
    ``schema_to_type`` and recorded as a dependency. Without one, the result
    is ``list[Any]``. ``name`` is accepted for signature parity but unused.

    Raises:
        ValueError: If ``schema.type`` is not ``DataType.array``.
    """
    if schema.type != DataType.array:
        raise ValueError('Schema is not an array')

    if schema.items:
        element = self.schema_to_type(schema.items, base_name=base_name)

        list_type = Type(
            None,
            None,
            annotation_ast=_subscript(list.__name__, element.annotation_ast),
            implementation_ast=None,
            type='primitive',
        )
        list_type.add_annotation_import(list.__module__, list.__name__)
        list_type.copy_imports_from_sub_types([element])
        if element:
            list_type.add_dependency(element)
        return list_type

    # No ``items``: fall back to an untyped list.
    any_node = ast.Name(id=Any.__name__, ctx=ast.Load())
    # NOTE(review): 'primitive' is passed positionally as the 4th argument
    # here, unlike the keyword calls elsewhere; confirm it binds to ``type``.
    untyped = Type(None, None, _subscript(list.__name__, any_node), 'primitive')
    untyped.add_annotation_import(module=list.__module__, name=list.__name__)
    untyped.add_annotation_import(module=Any.__module__, name=Any.__name__)
    return untyped
|
|
904
|
+
|
|
905
|
+
def _create_object_type(
    self,
    schema: Schema | Reference,
    name: str | None = None,
    base_name: str | None = None,
) -> Type:
    """Create the type for an object-like schema (or a reference to one).

    A schema with no ``properties``/``allOf``/``anyOf``/``oneOf`` becomes an
    inline ``dict[str, V]`` mapping. ``V`` comes from a schema-valued
    ``additionalProperties``; otherwise it is ``Any``. ``Any`` also applies
    when ``additionalProperties`` is ``True``, ``False`` or absent. Every
    other schema is delegated to ``_create_pydantic_model``.

    Args:
        schema: Schema or reference; references are resolved first.
        name: Fallback model name when the resolved schema has no name.
        base_name: Naming context forwarded to nested type generation.

    Returns:
        The mapping ``Type``, or the model/union ``Type``.
    """
    schema, schema_name = self._resolve_reference(schema)

    # Handle additionalProperties for dict-like types
    if (
        not schema.properties
        and not schema.allOf
        and not schema.anyOf
        and not schema.oneOf
    ):
        value_type_ast = ast.Name(id=Any.__name__, ctx=ast.Load())
        value_type_imports: dict[str, set[str]] = {Any.__module__: {Any.__name__}}
        value_type = None

        # ``False`` forbids extra keys but the annotation still stays
        # ``dict[str, Any]``; ``True``/``None`` likewise keep ``Any`` values.
        if isinstance(schema.additionalProperties, (Schema, Reference)):
            value_type = self.schema_to_type(
                schema.additionalProperties, base_name=base_name
            )
            value_type_ast = value_type.annotation_ast
            value_type_imports = value_type.annotation_imports.copy()

        type_ = Type(
            None,
            None,
            annotation_ast=_subscript(
                dict.__name__,
                ast.Tuple(
                    elts=[
                        ast.Name(id=str.__name__, ctx=ast.Load()),
                        value_type_ast,
                    ]
                ),
            ),
            implementation_ast=None,
            type='primitive',
        )

        type_.add_annotation_import(dict.__module__, dict.__name__)
        for module, names in value_type_imports.items():
            for name_import in names:
                type_.add_annotation_import(module, name_import)

        # Record the value type as a dependency, as _create_array_type does for
        # its items. Models that copy ``field_type.dependencies`` then keep the
        # ordering constraint for ``dict[str, SomeModel]`` fields.
        if value_type is not None:
            type_.add_dependency(value_type)

        return type_

    return self._create_pydantic_model(
        schema, schema_name or name, base_name=base_name
    )
|
|
965
|
+
|
|
966
|
+
def schema_to_type(
    self,
    schema: Schema | Reference,
    base_name: str | None = None,
    field_name: str | None = None,
) -> Type:
    """Convert a schema or reference into a generated ``Type``.

    A reference whose sanitized last path segment is already in ``self.types``
    is returned from that cache directly. Otherwise the schema is resolved and
    dispatched by ``type``:

    * ``array`` goes to ``_create_array_type``
    * ``object`` or untyped goes to ``_create_object_type``
    * anything else goes to ``_get_primitive_type_ast``

    The name obtained from ``$ref`` resolution, when present, replaces
    ``base_name`` for nested types. This gives enums inside ``Pet`` names like
    ``PetStatus`` rather than a request-body-derived prefix.
    """
    if isinstance(schema, Reference):
        cache_key = sanitize_identifier(schema.ref.rpartition('/')[2])
        if cache_key in self.types:
            return self.types[cache_key]

    resolved, resolved_name = self._resolve_reference(schema)
    nested_base = resolved_name or base_name

    # TODO: if ``type`` can be a list (OpenAPI 3.1, e.g. ['string', 'null']),
    # such schemas currently fall through to the primitive branch.
    if resolved.type == DataType.array:
        return self._create_array_type(
            schema=resolved, name=resolved_name, base_name=nested_base
        )
    if resolved.type == DataType.object or resolved.type is None:
        return self._create_object_type(
            resolved, name=resolved_name, base_name=nested_base
        )
    return self._get_primitive_type_ast(
        resolved, base_name=nested_base, field_name=field_name
    )
|
|
999
|
+
|
|
1000
|
+
def get_sorted_types(self) -> list[Type]:
    """Returns the types sorted in dependency order using topological sort.

    Each type is placed after every type named in its ``dependencies`` that
    is present in ``self.types``, so types with no dependencies come first.
    This matches ``TypeRegistry.get_types_in_dependency_order``. A type that
    lists its own name as a dependency (a self-referencing schema) is not
    treated as a cycle.

    Raises:
        ValueError: If a dependency cycle between distinct types is detected.
    """
    sorted_types: list[Type] = []
    temp_mark: set[str] = set()
    perm_mark: set[str] = set()

    def visit(type_: Type):
        if type_.name in perm_mark:
            return
        if type_.name in temp_mark:
            raise ValueError(f'Cyclic dependency detected for type: {type_.name}')

        temp_mark.add(type_.name)

        for dep_name in type_.dependencies:
            # A self-reference imposes no ordering constraint.
            if dep_name == type_.name:
                continue
            if dep_name in self.types:
                visit(self.types[dep_name])

        perm_mark.add(type_.name)
        temp_mark.remove(type_.name)
        # Post-order: every dependency has already been appended.
        sorted_types.append(type_)

    for type_ in self.types.values():
        if type_.name not in perm_mark:
            visit(type_)

    # ``dependencies`` point from a type to what it needs, so post-order
    # already yields dependencies first. Reversing it (the textbook "prepend"
    # variant, which assumes edges in the other direction) would emit
    # dependents, e.g. subclasses, before their bases.
    return sorted_types
|
|
1029
|
+
|
|
1030
|
+
|
|
1031
|
+
# =============================================================================
|
|
1032
|
+
# Type Registry
|
|
1033
|
+
# =============================================================================
|
|
1034
|
+
|
|
1035
|
+
|
|
1036
|
+
@dataclasses.dataclass
|
|
1037
|
+
class TypeInfo:
|
|
1038
|
+
"""Information about a registered type.
|
|
1039
|
+
|
|
1040
|
+
Attributes:
|
|
1041
|
+
name: The Python name for this type.
|
|
1042
|
+
reference: The original OpenAPI reference (e.g., '#/components/schemas/Pet').
|
|
1043
|
+
type_obj: The Type object containing AST and metadata.
|
|
1044
|
+
dependencies: Set of type names this type depends on.
|
|
1045
|
+
is_root_model: Whether this is a Pydantic RootModel.
|
|
1046
|
+
is_generated: Whether the AST has been generated for this type.
|
|
1047
|
+
"""
|
|
1048
|
+
|
|
1049
|
+
name: str
|
|
1050
|
+
reference: str | None
|
|
1051
|
+
type_obj: 'Type'
|
|
1052
|
+
dependencies: set[str] = dataclasses.field(default_factory=set)
|
|
1053
|
+
is_root_model: bool = False
|
|
1054
|
+
is_generated: bool = False
|
|
1055
|
+
|
|
1056
|
+
|
|
1057
|
+
class TypeRegistry:
|
|
1058
|
+
"""Registry for managing generated types during code generation.
|
|
1059
|
+
|
|
1060
|
+
This class provides a centralized location for tracking all types generated
|
|
1061
|
+
from an OpenAPI schema, handling dependencies between types, and ensuring
|
|
1062
|
+
types are generated in the correct order.
|
|
1063
|
+
|
|
1064
|
+
Example:
|
|
1065
|
+
>>> registry = TypeRegistry()
|
|
1066
|
+
>>> registry.register(type_obj, name='Pet', reference='#/components/schemas/Pet')
|
|
1067
|
+
>>> if registry.has_type('Pet'):
|
|
1068
|
+
... pet_type = registry.get_type('Pet')
|
|
1069
|
+
>>> for type_info in registry.get_types_in_dependency_order():
|
|
1070
|
+
... generate_code(type_info)
|
|
1071
|
+
"""
|
|
1072
|
+
|
|
1073
|
+
def __init__(self):
|
|
1074
|
+
"""Initialize an empty type registry."""
|
|
1075
|
+
self._types: dict[str, TypeInfo] = {}
|
|
1076
|
+
self._by_reference: dict[str, str] = {}
|
|
1077
|
+
self._primitive_types: set[str] = {'str', 'int', 'float', 'bool', 'None'}
|
|
1078
|
+
|
|
1079
|
+
def register(
|
|
1080
|
+
self,
|
|
1081
|
+
type_obj: 'Type',
|
|
1082
|
+
name: str,
|
|
1083
|
+
reference: str | None = None,
|
|
1084
|
+
dependencies: set[str] | None = None,
|
|
1085
|
+
is_root_model: bool = False,
|
|
1086
|
+
) -> TypeInfo:
|
|
1087
|
+
"""Register a new type in the registry.
|
|
1088
|
+
|
|
1089
|
+
Args:
|
|
1090
|
+
type_obj: The Type object containing the type information.
|
|
1091
|
+
name: The Python name for this type.
|
|
1092
|
+
reference: The OpenAPI reference string, if applicable.
|
|
1093
|
+
dependencies: Set of type names this type depends on.
|
|
1094
|
+
is_root_model: Whether this is a Pydantic RootModel.
|
|
1095
|
+
|
|
1096
|
+
Returns:
|
|
1097
|
+
The TypeInfo object for the registered type.
|
|
1098
|
+
|
|
1099
|
+
Raises:
|
|
1100
|
+
ValueError: If a type with the same name is already registered.
|
|
1101
|
+
"""
|
|
1102
|
+
if name in self._types:
|
|
1103
|
+
raise ValueError(f"Type '{name}' is already registered")
|
|
1104
|
+
|
|
1105
|
+
type_info = TypeInfo(
|
|
1106
|
+
name=name,
|
|
1107
|
+
reference=reference,
|
|
1108
|
+
type_obj=type_obj,
|
|
1109
|
+
dependencies=dependencies or set(),
|
|
1110
|
+
is_root_model=is_root_model,
|
|
1111
|
+
)
|
|
1112
|
+
|
|
1113
|
+
self._types[name] = type_info
|
|
1114
|
+
|
|
1115
|
+
if reference:
|
|
1116
|
+
self._by_reference[reference] = name
|
|
1117
|
+
|
|
1118
|
+
return type_info
|
|
1119
|
+
|
|
1120
|
+
def register_or_get(
|
|
1121
|
+
self,
|
|
1122
|
+
type_obj: 'Type',
|
|
1123
|
+
name: str,
|
|
1124
|
+
reference: str | None = None,
|
|
1125
|
+
dependencies: set[str] | None = None,
|
|
1126
|
+
is_root_model: bool = False,
|
|
1127
|
+
) -> TypeInfo:
|
|
1128
|
+
"""Register a type if not exists, otherwise return the existing one.
|
|
1129
|
+
|
|
1130
|
+
Args:
|
|
1131
|
+
type_obj: The Type object containing the type information.
|
|
1132
|
+
name: The Python name for this type.
|
|
1133
|
+
reference: The OpenAPI reference string, if applicable.
|
|
1134
|
+
dependencies: Set of type names this type depends on.
|
|
1135
|
+
is_root_model: Whether this is a Pydantic RootModel.
|
|
1136
|
+
|
|
1137
|
+
Returns:
|
|
1138
|
+
The TypeInfo object (existing or newly registered).
|
|
1139
|
+
"""
|
|
1140
|
+
if name in self._types:
|
|
1141
|
+
return self._types[name]
|
|
1142
|
+
return self.register(type_obj, name, reference, dependencies, is_root_model)
|
|
1143
|
+
|
|
1144
|
+
def has_type(self, name: str) -> bool:
|
|
1145
|
+
"""Check if a type is registered."""
|
|
1146
|
+
return name in self._types
|
|
1147
|
+
|
|
1148
|
+
def has_reference(self, reference: str) -> bool:
|
|
1149
|
+
"""Check if a reference has been registered."""
|
|
1150
|
+
return reference in self._by_reference
|
|
1151
|
+
|
|
1152
|
+
def get_type(self, name: str) -> TypeInfo | None:
|
|
1153
|
+
"""Get a registered type by name."""
|
|
1154
|
+
return self._types.get(name)
|
|
1155
|
+
|
|
1156
|
+
def get_type_by_reference(self, reference: str) -> TypeInfo | None:
|
|
1157
|
+
"""Get a registered type by its OpenAPI reference."""
|
|
1158
|
+
name = self._by_reference.get(reference)
|
|
1159
|
+
if name:
|
|
1160
|
+
return self._types.get(name)
|
|
1161
|
+
return None
|
|
1162
|
+
|
|
1163
|
+
def get_name_for_reference(self, reference: str) -> str | None:
|
|
1164
|
+
"""Get the registered name for an OpenAPI reference."""
|
|
1165
|
+
return self._by_reference.get(reference)
|
|
1166
|
+
|
|
1167
|
+
def get_all_types(self) -> dict[str, TypeInfo]:
|
|
1168
|
+
"""Get all registered types."""
|
|
1169
|
+
return dict(self._types)
|
|
1170
|
+
|
|
1171
|
+
def get_type_names(self) -> list[str]:
|
|
1172
|
+
"""Get all registered type names, sorted alphabetically."""
|
|
1173
|
+
return sorted(self._types.keys())
|
|
1174
|
+
|
|
1175
|
+
def add_dependency(self, type_name: str, depends_on: str) -> None:
|
|
1176
|
+
"""Add a dependency relationship between types."""
|
|
1177
|
+
if type_name not in self._types:
|
|
1178
|
+
raise KeyError(f"Type '{type_name}' is not registered")
|
|
1179
|
+
self._types[type_name].dependencies.add(depends_on)
|
|
1180
|
+
|
|
1181
|
+
def get_dependencies(self, type_name: str) -> set[str]:
|
|
1182
|
+
"""Get all dependencies for a type."""
|
|
1183
|
+
if type_name not in self._types:
|
|
1184
|
+
raise KeyError(f"Type '{type_name}' is not registered")
|
|
1185
|
+
return self._types[type_name].dependencies.copy()
|
|
1186
|
+
|
|
1187
|
+
def get_types_in_dependency_order(self) -> list[TypeInfo]:
|
|
1188
|
+
"""Get all types sorted in dependency order.
|
|
1189
|
+
|
|
1190
|
+
Types are sorted so that dependencies come before the types that
|
|
1191
|
+
depend on them.
|
|
1192
|
+
"""
|
|
1193
|
+
result: list[TypeInfo] = []
|
|
1194
|
+
visited: set[str] = set()
|
|
1195
|
+
visiting: set[str] = set()
|
|
1196
|
+
|
|
1197
|
+
def visit(name: str) -> None:
|
|
1198
|
+
if name in visited:
|
|
1199
|
+
return
|
|
1200
|
+
if name in visiting:
|
|
1201
|
+
return
|
|
1202
|
+
if name in self._primitive_types:
|
|
1203
|
+
return
|
|
1204
|
+
if name not in self._types:
|
|
1205
|
+
return
|
|
1206
|
+
|
|
1207
|
+
visiting.add(name)
|
|
1208
|
+
type_info = self._types[name]
|
|
1209
|
+
|
|
1210
|
+
for dep in type_info.dependencies:
|
|
1211
|
+
visit(dep)
|
|
1212
|
+
|
|
1213
|
+
visiting.remove(name)
|
|
1214
|
+
visited.add(name)
|
|
1215
|
+
result.append(type_info)
|
|
1216
|
+
|
|
1217
|
+
for name in sorted(self._types.keys()):
|
|
1218
|
+
visit(name)
|
|
1219
|
+
|
|
1220
|
+
return result
|
|
1221
|
+
|
|
1222
|
+
def mark_generated(self, name: str) -> None:
|
|
1223
|
+
"""Mark a type as having its AST generated."""
|
|
1224
|
+
if name not in self._types:
|
|
1225
|
+
raise KeyError(f"Type '{name}' is not registered")
|
|
1226
|
+
self._types[name].is_generated = True
|
|
1227
|
+
|
|
1228
|
+
def get_ungenerated_types(self) -> list[TypeInfo]:
|
|
1229
|
+
"""Get all types that haven't been generated yet."""
|
|
1230
|
+
return [t for t in self._types.values() if not t.is_generated]
|
|
1231
|
+
|
|
1232
|
+
def clear(self) -> None:
|
|
1233
|
+
"""Clear all registered types."""
|
|
1234
|
+
self._types.clear()
|
|
1235
|
+
self._by_reference.clear()
|
|
1236
|
+
|
|
1237
|
+
def __len__(self) -> int:
|
|
1238
|
+
"""Return the number of registered types."""
|
|
1239
|
+
return len(self._types)
|
|
1240
|
+
|
|
1241
|
+
def __iter__(self) -> Iterator[TypeInfo]:
|
|
1242
|
+
"""Iterate over all registered types."""
|
|
1243
|
+
return iter(self._types.values())
|
|
1244
|
+
|
|
1245
|
+
def __contains__(self, name: str) -> bool:
|
|
1246
|
+
"""Check if a type name is registered."""
|
|
1247
|
+
return name in self._types
|
|
1248
|
+
|
|
1249
|
+
def get_root_models(self) -> list[TypeInfo]:
|
|
1250
|
+
"""Get all registered root models."""
|
|
1251
|
+
return [t for t in self._types.values() if t.is_root_model]
|
|
1252
|
+
|
|
1253
|
+
def get_regular_models(self) -> list[TypeInfo]:
|
|
1254
|
+
"""Get all registered non-root models."""
|
|
1255
|
+
return [t for t in self._types.values() if not t.is_root_model]
|
|
1256
|
+
|
|
1257
|
+
def merge(self, other: 'TypeRegistry') -> None:
|
|
1258
|
+
"""Merge another registry into this one."""
|
|
1259
|
+
for type_info in other:
|
|
1260
|
+
if type_info.name not in self._types:
|
|
1261
|
+
self._types[type_info.name] = type_info
|
|
1262
|
+
if type_info.reference:
|
|
1263
|
+
self._by_reference[type_info.reference] = type_info.name
|
|
1264
|
+
|
|
1265
|
+
|
|
1266
|
+
# =============================================================================
|
|
1267
|
+
# Model Name Collector
|
|
1268
|
+
# =============================================================================
|
|
1269
|
+
|
|
1270
|
+
|
|
1271
|
+
class ModelNameCollector(ast.NodeVisitor):
    """Find which known model names are referenced inside AST nodes.

    Every ``ast.Name`` node whose identifier belongs to ``available_models``
    is recorded in ``used_models``; all other names are ignored. This tells
    us which models the generated code actually refers to.

    Example:
        >>> available = {'Pet', 'User', 'Order'}
        >>> collector = ModelNameCollector(available)
        >>> collector.visit(some_function_ast)
        >>> print(collector.used_models)
        {'Pet', 'User'}
    """

    def __init__(self, available_models: set[str]):
        """Create a collector restricted to ``available_models``.

        Args:
            available_models: Model names that may be reported as used.
        """
        self.available_models = available_models
        self.used_models: set[str] = set()

    def visit_Name(self, node: ast.Name) -> None:
        """Record ``node.id`` when it is one of the available models."""
        identifier = node.id
        if identifier in self.available_models:
            self.used_models.add(identifier)
        self.generic_visit(node)

    @classmethod
    def collect_from_endpoints(
        cls,
        endpoints: list['Endpoint'],
        available_models: set[str],
    ) -> set[str]:
        """Scan the sync and async AST of every endpoint for model names.

        Args:
            endpoints: Endpoint objects exposing ``sync_ast`` and ``async_ast``.
            available_models: Model names that may be reported as used.

        Returns:
            The subset of ``available_models`` referenced by the endpoints.
        """
        collector = cls(available_models)
        for endpoint in endpoints:
            for tree in (endpoint.sync_ast, endpoint.async_ast):
                collector.visit(tree)
        return collector.used_models


def collect_used_model_names(
    endpoints: list['Endpoint'],
    typegen_types: dict[str, 'Type'],
) -> set[str]:
    """Return the names of implemented models referenced by the endpoints.

    A type counts as an implemented model when it has both a ``name`` and an
    ``implementation_ast`` (i.e. it is defined in models.py). Only those names
    are searched for in the endpoints' generated code.

    Args:
        endpoints: Endpoint objects whose ASTs are scanned.
        typegen_types: Mapping of type names to Type objects.

    Returns:
        Set of model names actually used in endpoints.
    """
    available_models: set[str] = set()
    for candidate in typegen_types.values():
        if candidate.name and candidate.implementation_ast:
            available_models.add(candidate.name)

    return ModelNameCollector.collect_from_endpoints(endpoints, available_models)
|