otterapi 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- README.md +581 -8
- otterapi/__init__.py +73 -0
- otterapi/cli.py +327 -29
- otterapi/codegen/__init__.py +115 -0
- otterapi/codegen/ast_utils.py +134 -5
- otterapi/codegen/client.py +1271 -0
- otterapi/codegen/codegen.py +1736 -0
- otterapi/codegen/dataframes.py +392 -0
- otterapi/codegen/emitter.py +473 -0
- otterapi/codegen/endpoints.py +2597 -343
- otterapi/codegen/pagination.py +1026 -0
- otterapi/codegen/schema.py +593 -0
- otterapi/codegen/splitting.py +1397 -0
- otterapi/codegen/types.py +1345 -0
- otterapi/codegen/utils.py +180 -1
- otterapi/config.py +1017 -24
- otterapi/exceptions.py +231 -0
- otterapi/openapi/__init__.py +46 -0
- otterapi/openapi/v2/__init__.py +86 -0
- otterapi/openapi/v2/spec.json +1607 -0
- otterapi/openapi/v2/v2.py +1776 -0
- otterapi/openapi/v3/__init__.py +131 -0
- otterapi/openapi/v3/spec.json +1651 -0
- otterapi/openapi/v3/v3.py +1557 -0
- otterapi/openapi/v3_1/__init__.py +133 -0
- otterapi/openapi/v3_1/spec.json +1411 -0
- otterapi/openapi/v3_1/v3_1.py +798 -0
- otterapi/openapi/v3_2/__init__.py +133 -0
- otterapi/openapi/v3_2/spec.json +1666 -0
- otterapi/openapi/v3_2/v3_2.py +777 -0
- otterapi/tests/__init__.py +3 -0
- otterapi/tests/fixtures/__init__.py +455 -0
- otterapi/tests/test_ast_utils.py +680 -0
- otterapi/tests/test_codegen.py +610 -0
- otterapi/tests/test_dataframe.py +1038 -0
- otterapi/tests/test_exceptions.py +493 -0
- otterapi/tests/test_openapi_support.py +616 -0
- otterapi/tests/test_openapi_upgrade.py +215 -0
- otterapi/tests/test_pagination.py +1101 -0
- otterapi/tests/test_splitting_config.py +319 -0
- otterapi/tests/test_splitting_integration.py +427 -0
- otterapi/tests/test_splitting_resolver.py +512 -0
- otterapi/tests/test_splitting_tree.py +525 -0
- otterapi-0.0.6.dist-info/METADATA +627 -0
- otterapi-0.0.6.dist-info/RECORD +48 -0
- {otterapi-0.0.5.dist-info → otterapi-0.0.6.dist-info}/WHEEL +1 -1
- otterapi/codegen/generator.py +0 -358
- otterapi/codegen/openapi_processor.py +0 -27
- otterapi/codegen/type_generator.py +0 -559
- otterapi-0.0.5.dist-info/METADATA +0 -54
- otterapi-0.0.5.dist-info/RECORD +0 -16
- {otterapi-0.0.5.dist-info → otterapi-0.0.6.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,1736 @@
|
|
|
1
|
+
"""Code generation module for OtterAPI.
|
|
2
|
+
|
|
3
|
+
This module provides the main Codegen class that orchestrates the generation
|
|
4
|
+
of Python client code from OpenAPI specifications.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import ast
import fnmatch
import http
import logging
import os
from typing import Any
from urllib.parse import urljoin, urlparse

from upath import UPath

from otterapi.codegen.ast_utils import (
    ImportCollector,
    _all,
    _assign,
    _call,
    _name,
    _union_expr,
)
from otterapi.codegen.client import (
    EndpointInfo,
    generate_api_error_class,
    generate_base_client_class,
    generate_client_stub,
)
from otterapi.codegen.dataframes import (
    DataFrameMethodConfig,
    endpoint_returns_list,
    generate_dataframe_module,
    get_dataframe_config_for_endpoint,
)
from otterapi.codegen.endpoints import async_request_fn, request_fn
from otterapi.codegen.pagination import (
    PaginationMethodConfig,
    generate_pagination_module,
    get_pagination_config_for_endpoint,
)
from otterapi.codegen.schema import SchemaLoader
from otterapi.codegen.types import (
    Endpoint,
    Parameter,
    RequestBodyInfo,
    ResponseInfo,
    Type,
    TypeGenerator,
    collect_used_model_names,
)
from otterapi.codegen.utils import (
    OpenAPIProcessor,
    sanitize_identifier,
    sanitize_parameter_field_name,
    to_snake_case,
    write_mod,
)
from otterapi.config import DocumentConfig
from otterapi.openapi.v3_2.v3_2 import (
    OpenAPI as OpenAPIv3_2,
    Operation,
    Parameter as OpenAPIParameter,
    Reference,
    RequestBody as OpenAPIRequestBody,
    Response as OpenAPIResponse,
)
|
|
67
|
+
|
|
68
|
+
# Content types that should be treated as JSON
# (structured suffixes like 'application/foo+json' are handled separately
# via str.endswith('+json') at the call sites).
JSON_CONTENT_TYPES = {'application/json', 'text/json'}

# Lowercased HTTP method names ('get', 'post', ...) used to probe
# OpenAPI path items for operations.
HTTP_METHODS = [method.value.lower() for method in http.HTTPMethod]
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class Codegen(OpenAPIProcessor):
|
|
75
|
+
"""Main code generator for creating Python clients from OpenAPI specifications.
|
|
76
|
+
|
|
77
|
+
This class orchestrates the entire code generation process, including:
|
|
78
|
+
- Loading and validating OpenAPI schemas
|
|
79
|
+
- Generating Pydantic models from schema definitions
|
|
80
|
+
- Creating typed endpoint functions for API operations
|
|
81
|
+
- Writing output files with proper imports and structure
|
|
82
|
+
|
|
83
|
+
The generator supports OpenAPI 3.x specifications and produces
|
|
84
|
+
fully typed Python code compatible with httpx for HTTP requests
|
|
85
|
+
and Pydantic for data validation.
|
|
86
|
+
|
|
87
|
+
Attributes:
|
|
88
|
+
config: The DocumentConfig containing source and output settings.
|
|
89
|
+
openapi: The loaded OpenAPI schema (populated after _load_schema).
|
|
90
|
+
typegen: The TypeGenerator for creating Pydantic models.
|
|
91
|
+
|
|
92
|
+
Example:
|
|
93
|
+
>>> from otterapi.config import DocumentConfig
|
|
94
|
+
>>> from otterapi.codegen.codegen import Codegen
|
|
95
|
+
>>>
|
|
96
|
+
>>> config = DocumentConfig(
|
|
97
|
+
... source="https://api.example.com/openapi.json",
|
|
98
|
+
... output="./client"
|
|
99
|
+
... )
|
|
100
|
+
>>> codegen = Codegen(config)
|
|
101
|
+
>>> codegen.generate()
|
|
102
|
+
# Creates models.py and endpoints.py in ./client/
|
|
103
|
+
|
|
104
|
+
Note:
|
|
105
|
+
The schema is not loaded until generate() is called or
|
|
106
|
+
_load_schema() is explicitly invoked.
|
|
107
|
+
"""
|
|
108
|
+
|
|
109
|
+
def __init__(
    self, config: DocumentConfig, schema_loader: SchemaLoader | None = None
):
    """Set up the generator with its configuration and schema loader.

    Args:
        config: Configuration specifying source schema and output location.
        schema_loader: Optional custom schema loader; a default
            SchemaLoader is created when omitted.
    """
    super().__init__(None)
    self.config = config
    # Populated later by _load_schema(); generation cannot run before that.
    self.openapi: OpenAPIv3_2 | None = None
    self._schema_loader = schema_loader or SchemaLoader()
|
|
123
|
+
|
|
124
|
+
def _load_schema(self) -> None:
    """Fetch, validate, and parse the OpenAPI document.

    Loads the schema from the configured URL or file path via the schema
    loader and builds the TypeGenerator used for model generation. After
    this call both ``self.openapi`` and ``self.typegen`` are populated.

    Raises:
        SchemaLoadError: If the schema cannot be loaded from the source.
        SchemaValidationError: If the schema is not valid OpenAPI.
    """
    loaded = self._schema_loader.load(self.config.source)
    self.openapi = loaded
    self.typegen = TypeGenerator(loaded)
|
|
139
|
+
|
|
140
|
+
def _extract_response_info(self, operation: Operation) -> dict[int, ResponseInfo]:
    """Build a status-code to ResponseInfo map for an operation.

    When a response offers several content types, JSON-like ones are
    preferred and get a typed response model. For non-JSON content
    (XML, binary, text, ...) no model is generated and the endpoint
    will hand back the raw httpx.Response instead.

    Args:
        operation: The OpenAPI operation to extract responses from.

    Returns:
        Dictionary mapping numeric status codes to ResponseInfo objects.
    """
    result: dict[int, ResponseInfo] = {}

    if not operation.responses:
        return result

    for code_key, response_or_ref in operation.responses.root.items():
        try:
            code = int(code_key)
        except ValueError:
            # OpenAPI also allows keys like '4XX' or 'default'; skip them.
            logging.debug(f'Skipping non-numeric status code: {code_key}')
            continue

        resolved = self._resolve_response_reference(response_or_ref)
        if resolved is None or not resolved.content:
            continue

        content_type, media_type = self._select_content_type(resolved.content)

        # Only JSON-like payloads get a generated model; everything else
        # falls back to the raw httpx.Response at the call site.
        json_like = content_type in JSON_CONTENT_TYPES or content_type.endswith(
            '+json'
        )
        model = None
        if json_like and media_type.schema_:
            model = self.typegen.schema_to_type(
                media_type.schema_,
                base_name=f'{sanitize_identifier(operation.operationId)}Response',
            )

        result[code] = ResponseInfo(
            status_code=code,
            content_type=content_type,
            type=model,
        )

    return result
|
|
198
|
+
|
|
199
|
+
def _select_content_type(self, content: dict) -> tuple[str, Any]:
    """Select the best content type from available options.

    JSON content types (including ``+json``-suffixed ones) are preferred
    for better type safety; otherwise the first declared entry wins.

    Args:
        content: Dictionary mapping content types to media type objects.
            Callers guarantee it is non-empty.

    Returns:
        Tuple of (selected_content_type, selected_media_type).
    """
    for content_type, media_type in content.items():
        if content_type in JSON_CONTENT_TYPES or content_type.endswith('+json'):
            return content_type, media_type

    # No JSON variant available: fall back to the first declared entry.
    return next(iter(content.items()))
|
|
215
|
+
|
|
216
|
+
def _create_response_union(self, types: list[Type]) -> Type:
    """Combine several response types into a single union Type.

    Args:
        types: List of response types to combine into a union.

    Returns:
        A union Type whose annotation is ``T1 | T2 | ...`` and which
        carries the annotation imports of every member type.
    """
    member_annotations = [member.annotation_ast for member in types]
    union = Type(
        None,
        None,
        annotation_ast=_union_expr(member_annotations),
        implementation_ast=None,
        type='primitive',
    )
    # The union itself has no implementation; it only needs the imports
    # required to spell out each member in an annotation.
    union.copy_imports_from_sub_types(types)
    return union
|
|
234
|
+
|
|
235
|
+
def _collect_non_json_types(self, response_list: list[ResponseInfo]) -> list[Type]:
    """Return the fallback type for non-JSON responses, if any are present.

    Every non-JSON content type (binary, text, XML, etc.) is represented
    by the raw httpx.Response object, giving callers full control over
    decoding (.content, .text, .json(), ...).
    """
    if all(r.is_json for r in response_list):
        # All responses are JSON-typed; no raw-response fallback needed.
        return []

    raw_response = Type(
        reference=None,
        name=None,
        type='primitive',
        annotation_ast=_name('Response'),
    )
    raw_response.add_annotation_import('httpx', 'Response')
    return [raw_response]
|
|
256
|
+
|
|
257
|
+
def _get_response_models(
    self, operation: Operation
) -> tuple[list[ResponseInfo], Type | None]:
    """Collect response infos and a unified response type for an operation.

    Args:
        operation: The OpenAPI operation to extract response models from.

    Returns:
        A tuple ``(response_infos, response_type)`` where response_type
        is None (no typed responses), a single Type, or a union of Types.
    """
    by_status = self._extract_response_info(operation)

    if not by_status:
        return [], None

    infos = list(by_status.values())

    # Typed JSON models first, then the raw-Response fallback (if any).
    candidates = [info.type for info in infos if info.is_json and info.type]
    candidates += self._collect_non_json_types(infos)

    if not candidates:
        return infos, None
    if len(candidates) == 1:
        return infos, candidates[0]
    return infos, self._create_response_union(candidates)
|
|
288
|
+
|
|
289
|
+
def _extract_operation_parameters(
    self, operation: Operation, path_item_parameters: list | None = None
) -> list[Parameter]:
    """Extract path, query, header, and cookie parameters from an operation.

    Merges path-level parameters with operation-level parameters;
    operation-level parameters override path-level ones sharing the same
    (name, location) key. ``$ref`` references to
    ``#/components/parameters/`` are resolved.

    Args:
        operation: The OpenAPI operation to extract parameters from.
        path_item_parameters: Optional path-level parameters to inherit.

    Returns:
        List of Parameter objects for path/query/header/cookie parameters.
    """

    def build(param) -> Parameter:
        # Convert a resolved OpenAPI parameter into our internal model.
        param_type = None
        if param.schema_:
            param_type = self.typegen.schema_to_type(param.schema_)
        return Parameter(
            name=param.name,
            name_sanitized=sanitize_parameter_field_name(param.name),
            location=param.in_,
            required=param.required or False,
            type=param_type,
            description=param.description,
        )

    all_params: list[Parameter] = []
    param_keys_seen: set[tuple[str, str]] = set()  # (name, location) overrides

    # Operation-level parameters take precedence.
    for param_or_ref in operation.parameters or []:
        param = self._resolve_parameter_reference(param_or_ref)
        if param is None:
            continue
        all_params.append(build(param))
        param_keys_seen.add((param.name, param.in_))

    # Path-level parameters apply only when not overridden above.
    for param_or_ref in path_item_parameters or []:
        param = self._resolve_parameter_reference(param_or_ref)
        if param is None:
            continue
        if (param.name, param.in_) not in param_keys_seen:
            all_params.append(build(param))

    return all_params
|
|
357
|
+
|
|
358
|
+
def _resolve_parameter_reference(
    self, param_or_ref: OpenAPIParameter | Reference
) -> OpenAPIParameter | None:
    """Resolve a parameter reference to the actual Parameter object.

    Args:
        param_or_ref: Either a Parameter object or a Reference to one.

    Returns:
        The resolved Parameter object, or None if reference cannot be resolved.
    """
    if not isinstance(param_or_ref, Reference):
        # Already a concrete parameter; nothing to resolve.
        return param_or_ref

    ref = param_or_ref.ref
    if not ref.startswith('#/components/parameters/'):
        logging.warning(f'Unsupported parameter reference format: {ref}')
        return None

    param_name = ref.split('/')[-1]
    components = self.openapi.components
    if (
        not components
        or not components.parameters
        or param_name not in components.parameters
    ):
        logging.warning(
            f"Referenced parameter '{param_name}' not found in components.parameters"
        )
        return None

    resolved = components.parameters[param_name]
    # A component entry may itself be a $ref; follow it recursively.
    if isinstance(resolved, Reference):
        return self._resolve_parameter_reference(resolved)
    return resolved
|
|
394
|
+
|
|
395
|
+
def _resolve_response_reference(
    self, response_or_ref: OpenAPIResponse | Reference
) -> OpenAPIResponse | None:
    """Resolve a response reference to the actual Response object.

    Args:
        response_or_ref: Either a Response object or a Reference to one.

    Returns:
        The resolved Response object, or None if reference cannot be resolved.
    """
    if not isinstance(response_or_ref, Reference):
        # Already a concrete response; nothing to resolve.
        return response_or_ref

    ref = response_or_ref.ref
    if not ref.startswith('#/components/responses/'):
        logging.warning(f'Unsupported response reference format: {ref}')
        return None

    response_name = ref.split('/')[-1]
    components = self.openapi.components
    if (
        not components
        or not components.responses
        or response_name not in components.responses
    ):
        logging.warning(
            f"Referenced response '{response_name}' not found in components.responses"
        )
        return None

    resolved = components.responses[response_name]
    # A component entry may itself be a $ref; follow it recursively.
    if isinstance(resolved, Reference):
        return self._resolve_response_reference(resolved)
    return resolved
|
|
431
|
+
|
|
432
|
+
def _resolve_request_body_reference(
    self, body_or_ref: OpenAPIRequestBody | Reference
) -> OpenAPIRequestBody | None:
    """Resolve a request body reference to the actual RequestBody object.

    Args:
        body_or_ref: Either a RequestBody object or a Reference to one.

    Returns:
        The resolved RequestBody object, or None if reference cannot be resolved.
    """
    if not isinstance(body_or_ref, Reference):
        # Already a concrete request body; nothing to resolve.
        return body_or_ref

    ref = body_or_ref.ref
    if not ref.startswith('#/components/requestBodies/'):
        logging.warning(f'Unsupported request body reference format: {ref}')
        return None

    body_name = ref.split('/')[-1]
    components = self.openapi.components
    if (
        not components
        or not components.requestBodies
        or body_name not in components.requestBodies
    ):
        logging.warning(
            f"Referenced request body '{body_name}' not found in components.requestBodies"
        )
        return None

    resolved = components.requestBodies[body_name]
    # A component entry may itself be a $ref; follow it recursively.
    if isinstance(resolved, Reference):
        return self._resolve_request_body_reference(resolved)
    return resolved
|
|
468
|
+
|
|
469
|
+
def _extract_request_body(self, operation: Operation) -> RequestBodyInfo | None:
    """Extract request body information from an operation.

    Args:
        operation: The OpenAPI operation to extract request body from.

    Returns:
        RequestBodyInfo with content type and schema, or None when the
        operation declares no (resolvable) body content.
    """
    if not operation.requestBody:
        return None

    resolved_body = self._resolve_request_body_reference(operation.requestBody)
    if resolved_body is None or not resolved_body.content:
        return None

    content_type, media_type = self._select_content_type(resolved_body.content)

    body_type = None
    if media_type.schema_:
        body_type = self.typegen.schema_to_type(
            media_type.schema_,
            base_name=f'{sanitize_identifier(operation.operationId)}RequestBody',
        )

    return RequestBodyInfo(
        content_type=content_type,
        type=body_type,
        required=resolved_body.required or False,
        description=resolved_body.description,
    )
|
|
502
|
+
|
|
503
|
+
def _get_param_model(
    self, operation: Operation, path_item_parameters: list | None = None
) -> tuple[list[Parameter], RequestBodyInfo | None]:
    """Collect all parameters and the request body info for an operation.

    Args:
        operation: The OpenAPI operation to extract parameters from.
        path_item_parameters: Optional path-level parameters to inherit.

    Returns:
        A tuple ``(parameters, request_body_info)`` where parameters
        covers path/query/header locations and request_body_info may
        be None.
    """
    return (
        self._extract_operation_parameters(operation, path_item_parameters),
        self._extract_request_body(operation),
    )
|
|
521
|
+
|
|
522
|
+
def _generate_endpoint(
    self,
    path: str,
    method: str,
    operation: Operation,
    path_item_parameters: list | None = None,
) -> Endpoint:
    """Generate an endpoint with sync and async functions.

    Args:
        path: The API path for the endpoint.
        method: The HTTP method (get, post, etc.).
        operation: The OpenAPI operation definition.
        path_item_parameters: Optional list of path-level parameters to inherit.

    Returns:
        An Endpoint object containing the generated functions and imports.
    """
    # Convert operationId to snake_case for Pythonic function names
    # (falls back to a method+path-derived name when operationId is absent).
    raw_name = (
        operation.operationId
        or f'{method}_{path.replace("/", "_").replace("{", "").replace("}", "")}'
    )
    fn_name = to_snake_case(raw_name)
    # The async variant is the sync name prefixed with 'a' (e.g. aget_user).
    async_fn_name = f'a{fn_name}'

    parameters, request_body_info = self._get_param_model(
        operation, path_item_parameters
    )
    response_infos, response_model = self._get_response_models(operation)

    # Build docstring with deprecation warning if needed
    docs = operation.description or ''
    if operation.deprecated:
        deprecation_notice = '.. deprecated::\n This endpoint is deprecated.'
        if docs:
            docs = f'{docs}\n\n{deprecation_notice}'
        else:
            docs = deprecation_notice

    # Generate the two function ASTs; each call also reports the imports
    # its generated body needs.
    async_fn, async_imports = async_request_fn(
        name=async_fn_name,
        method=method,
        path=path,
        response_model=response_model,
        docs=docs if docs else None,
        parameters=parameters,
        response_infos=response_infos,
        request_body_info=request_body_info,
    )

    sync_fn, imports = request_fn(
        name=fn_name,
        method=method,
        path=path,
        response_model=response_model,
        docs=docs if docs else None,
        parameters=parameters,
        response_infos=response_infos,
        request_body_info=request_body_info,
    )

    # Extract tags from operation
    tags = list(operation.tags) if operation.tags else None

    ep = Endpoint(
        sync_ast=sync_fn,
        sync_fn_name=fn_name,
        async_fn_name=async_fn_name,
        async_ast=async_fn,
        name=fn_name,
        method=method,
        path=path,
        description=operation.description,
        tags=tags,
        parameters=parameters,
        request_body=request_body_info,
        response_type=response_model,
        response_infos=response_infos,
    )

    # Merge imports from both generated function bodies.
    ep.add_imports([imports, async_imports])

    # Add annotation imports for every typed parameter.
    for param in parameters:
        if param.type and param.type.annotation_imports:
            ep.add_imports([param.type.annotation_imports])

    # Add imports for request body type if present
    if request_body_info and request_body_info.type:
        if request_body_info.type.annotation_imports:
            ep.add_imports([request_body_info.type.annotation_imports])

    # Add imports for response model type if present
    if response_model and response_model.annotation_imports:
        ep.add_imports([response_model.annotation_imports])

    return ep
|
|
619
|
+
|
|
620
|
+
def _generate_endpoints(self) -> list[Endpoint]:
    """Generate endpoints for every operation in the OpenAPI paths.

    Returns:
        List of generated Endpoint objects, one per (path, method)
        operation that passes the configured path filters.
    """
    generated: list[Endpoint] = []

    # Paths may be a RootModel wrapping a dict; unwrap via .root if present.
    paths = getattr(self.openapi.paths, 'root', self.openapi.paths)

    for path, path_item in paths.items():
        if not self._should_include_path(path):
            continue

        # Path-level parameters are inherited by every operation below.
        shared_params = getattr(path_item, 'parameters', None)

        for method in HTTP_METHODS:
            operation = getattr(path_item, method, None)
            if not operation:
                continue
            generated.append(
                self._generate_endpoint(path, method, operation, shared_params)
            )

    return generated
|
|
650
|
+
|
|
651
|
+
def _should_include_path(self, path: str) -> bool:
    """Decide whether an API path passes the include/exclude filters.

    ``include_paths`` acts as an allow-list: when set, the path must match
    at least one glob pattern. ``exclude_paths`` then acts as a deny-list,
    so exclusion wins over inclusion.

    Args:
        path: The API path to check.

    Returns:
        True if the path should be included, False otherwise.
    """
    # fnmatch is imported at module level (hoisted out of this hot path).
    include = self.config.include_paths
    if include and not any(fnmatch.fnmatch(path, pat) for pat in include):
        return False

    exclude = self.config.exclude_paths
    if exclude and any(fnmatch.fnmatch(path, pat) for pat in exclude):
        return False

    return True
|
|
679
|
+
|
|
680
|
+
def _is_absolute_url(self, url: str) -> bool:
    """Check whether *url* is absolute (carries both scheme and host).

    Args:
        url: The URL to check.

    Returns:
        True if the URL is absolute, False otherwise.
    """
    parts = urlparse(url)
    if parts.scheme and parts.netloc:
        return True
    return False
|
|
691
|
+
|
|
692
|
+
def _resolve_base_url(self) -> str:
    """Determine the base URL for generated API requests.

    The config's ``base_url`` takes precedence; otherwise the single
    server entry in the OpenAPI document is used. A relative server URL
    can only be resolved when the spec itself was fetched from a URL.

    Returns:
        The base URL to use for API requests.

    Raises:
        ValueError: If no base URL can be determined, multiple servers
            are defined, or a relative server URL cannot be resolved.
    """
    # Config base_url takes precedence over anything in the spec.
    if self.config.base_url:
        return self.config.base_url

    missing_msg = 'No base url provided. Make sure you specify the base_url in the otterapi config or the OpenAPI document contains a valid servers section'

    if not self.openapi.servers:
        raise ValueError(missing_msg)

    # Only a single server entry is supported.
    if len(self.openapi.servers) > 1:
        raise ValueError(
            'Multiple servers are not supported. Set the base_url in the config.'
        )

    # TODO: handle server variables
    server_url = self.openapi.servers[0].url
    if not server_url:
        raise ValueError(missing_msg)

    if self._is_absolute_url(server_url):
        return server_url

    # Relative server URL: resolvable only when the spec came from a URL.
    source = self.config.source
    if not self._is_absolute_url(source):
        raise ValueError(
            f"Server URL '{server_url}' is relative and cannot be resolved. "
            f'The OpenAPI spec was loaded from a file, not a URL. '
            f'Please specify an absolute base_url in the otterapi config.'
        )

    resolved_url = urljoin(source, server_url)
    logging.info(
        f"Resolved relative server URL '{server_url}' to '{resolved_url}' "
        f"using source URL '{source}'"
    )
    return resolved_url
|
|
750
|
+
|
|
751
|
+
def _collect_used_model_names(self, endpoints: list[Endpoint]) -> set[str]:
    """Gather the names of generated models referenced by the given endpoints.

    Only models that have an implementation (i.e. are defined in models.py)
    count; a model is considered "used" when it appears in an endpoint's
    parameters, request body, or responses.

    Args:
        endpoints: Endpoint objects whose signatures should be scanned.

    Returns:
        The set of model names referenced by at least one endpoint.

    Note:
        The actual scan is delegated to collect_used_model_names() from
        builders.model_collector.
    """
    used = collect_used_model_names(endpoints, self.typegen.types)
    return used
def _create_model_import(
|
|
769
|
+
self, models_file: UPath, model_names: set[str]
|
|
770
|
+
) -> ast.ImportFrom:
|
|
771
|
+
"""Create an import statement for models.
|
|
772
|
+
|
|
773
|
+
Args:
|
|
774
|
+
models_file: Path to the models file.
|
|
775
|
+
model_names: Set of model names to import.
|
|
776
|
+
|
|
777
|
+
Returns:
|
|
778
|
+
AST ImportFrom statement for the models.
|
|
779
|
+
"""
|
|
780
|
+
return ast.ImportFrom(
|
|
781
|
+
module=self.config.models_import_path or models_file.stem,
|
|
782
|
+
names=[ast.alias(name=name, asname=None) for name in sorted(model_names)],
|
|
783
|
+
level=1 if not self.config.models_import_path else 0, # relative import
|
|
784
|
+
)
|
|
785
|
+
|
|
786
|
+
def _build_endpoint_file_body(
    self, baseurl: str, endpoints: list[Endpoint]
) -> tuple[list[ast.stmt], ImportCollector, set[str]]:
    """Build the body of the endpoints file with standalone functions.

    Generates standalone functions with full implementations that use the
    Client. For each endpoint this emits either a sync/async paginated
    function pair plus ``*_iter`` iterator variants (when pagination is
    configured), or a plain sync/async function pair. DataFrame variants
    (``*_df`` for pandas, ``*_pl`` for polars) are appended when enabled.

    Args:
        baseurl: The base URL for API requests (unused, kept for API compat).
        endpoints: List of Endpoint objects to include.

    Returns:
        Tuple of (body statements, import collector, endpoint names).
    """
    from otterapi.codegen.endpoints import (
        build_default_client_code,
        build_standalone_dataframe_fn,
        build_standalone_endpoint_fn,
        build_standalone_paginated_dataframe_fn,
        build_standalone_paginated_fn,
        build_standalone_paginated_iter_fn,
    )

    body: list[ast.stmt] = []
    import_collector = ImportCollector()

    # Add default client variable and _get_client() function
    client_stmts, client_imports = build_default_client_code()
    body.extend(client_stmts)
    import_collector.add_imports(client_imports)

    # Track if we need DataFrame type hints
    has_dataframe_methods = False

    # Track if we need pagination imports
    has_pagination_methods = False

    # Add standalone endpoint functions
    endpoint_names = set()
    for endpoint in endpoints:
        # Track whether paginated DataFrame methods were generated for this endpoint
        generated_paginated_df = False

        # Check if this endpoint has pagination configured
        pag_config = None
        if self.config.pagination.enabled:
            pag_config = self._get_pagination_config(endpoint)

        # Generate pagination methods if configured, otherwise regular functions
        if pag_config:
            has_pagination_methods = True

            # Get item type from response type if it's a list
            item_type_ast = self._get_item_type_ast(endpoint)

            # Build pagination config dict
            pag_dict = {
                'offset_param': pag_config.offset_param,
                'limit_param': pag_config.limit_param,
                'cursor_param': pag_config.cursor_param,
                'page_param': pag_config.page_param,
                'per_page_param': pag_config.per_page_param,
                'data_path': pag_config.data_path,
                'total_path': pag_config.total_path,
                'next_cursor_path': pag_config.next_cursor_path,
                'total_pages_path': pag_config.total_pages_path,
                'default_page_size': pag_config.default_page_size,
            }

            # Sync paginated function (replaces regular sync function)
            pag_fn, pag_imports = build_standalone_paginated_fn(
                fn_name=endpoint.sync_fn_name,
                method=endpoint.method,
                path=endpoint.path,
                parameters=endpoint.parameters,
                request_body_info=endpoint.request_body,
                response_type=endpoint.response_type,
                pagination_style=pag_config.style,
                pagination_config=pag_dict,
                item_type_ast=item_type_ast,
                docs=endpoint.description,
                is_async=False,
            )
            endpoint_names.add(endpoint.sync_fn_name)
            body.append(pag_fn)
            import_collector.add_imports(pag_imports)

            # Async paginated function (replaces regular async function)
            async_pag_fn, async_pag_imports = build_standalone_paginated_fn(
                fn_name=endpoint.async_fn_name,
                method=endpoint.method,
                path=endpoint.path,
                parameters=endpoint.parameters,
                request_body_info=endpoint.request_body,
                response_type=endpoint.response_type,
                pagination_style=pag_config.style,
                pagination_config=pag_dict,
                item_type_ast=item_type_ast,
                docs=endpoint.description,
                is_async=True,
            )
            endpoint_names.add(endpoint.async_fn_name)
            body.append(async_pag_fn)
            import_collector.add_imports(async_pag_imports)

            # Sync iterator function
            iter_fn_name = f'{endpoint.sync_fn_name}_iter'
            iter_fn, iter_imports = build_standalone_paginated_iter_fn(
                fn_name=iter_fn_name,
                method=endpoint.method,
                path=endpoint.path,
                parameters=endpoint.parameters,
                request_body_info=endpoint.request_body,
                response_type=endpoint.response_type,
                pagination_style=pag_config.style,
                pagination_config=pag_dict,
                item_type_ast=item_type_ast,
                docs=endpoint.description,
                is_async=False,
            )
            endpoint_names.add(iter_fn_name)
            body.append(iter_fn)
            import_collector.add_imports(iter_imports)

            # Async iterator function
            async_iter_fn_name = f'{endpoint.async_fn_name}_iter'
            async_iter_fn, async_iter_imports = build_standalone_paginated_iter_fn(
                fn_name=async_iter_fn_name,
                method=endpoint.method,
                path=endpoint.path,
                parameters=endpoint.parameters,
                request_body_info=endpoint.request_body,
                response_type=endpoint.response_type,
                pagination_style=pag_config.style,
                pagination_config=pag_dict,
                item_type_ast=item_type_ast,
                docs=endpoint.description,
                is_async=True,
            )
            endpoint_names.add(async_iter_fn_name)
            body.append(async_iter_fn)
            import_collector.add_imports(async_iter_imports)

            # Generate paginated DataFrame methods if dataframe is enabled
            # For paginated endpoints, we know they return lists, so check config directly
            if self.config.dataframe.enabled:
                # Check if endpoint is explicitly disabled
                # (per-endpoint override keyed by the sync function name)
                endpoint_df_config = self.config.dataframe.endpoints.get(
                    endpoint.sync_fn_name
                )
                if endpoint_df_config and endpoint_df_config.enabled is False:
                    pass  # Skip DataFrame generation for this endpoint
                elif self.config.dataframe.pandas:
                    generated_paginated_df = True
                    has_dataframe_methods = True
                    has_pagination_methods = True
                    # Sync pandas paginated method
                    pandas_fn_name = f'{endpoint.sync_fn_name}_df'
                    pandas_fn, pandas_imports = (
                        build_standalone_paginated_dataframe_fn(
                            fn_name=pandas_fn_name,
                            method=endpoint.method,
                            path=endpoint.path,
                            parameters=endpoint.parameters,
                            request_body_info=endpoint.request_body,
                            response_type=endpoint.response_type,
                            pagination_style=pag_config.style,
                            pagination_config=pag_dict,
                            library='pandas',
                            item_type_ast=item_type_ast,
                            docs=endpoint.description,
                            is_async=False,
                        )
                    )
                    endpoint_names.add(pandas_fn_name)
                    body.append(pandas_fn)
                    import_collector.add_imports(pandas_imports)

                    # Async pandas paginated method
                    async_pandas_fn_name = f'{endpoint.async_fn_name}_df'
                    async_pandas_fn, async_pandas_imports = (
                        build_standalone_paginated_dataframe_fn(
                            fn_name=async_pandas_fn_name,
                            method=endpoint.method,
                            path=endpoint.path,
                            parameters=endpoint.parameters,
                            request_body_info=endpoint.request_body,
                            response_type=endpoint.response_type,
                            pagination_style=pag_config.style,
                            pagination_config=pag_dict,
                            library='pandas',
                            item_type_ast=item_type_ast,
                            docs=endpoint.description,
                            is_async=True,
                        )
                    )
                    endpoint_names.add(async_pandas_fn_name)
                    body.append(async_pandas_fn)
                    import_collector.add_imports(async_pandas_imports)

                # Check for polars - use elif to skip if endpoint is disabled
                if endpoint_df_config and endpoint_df_config.enabled is False:
                    pass  # Skip polars DataFrame generation for this endpoint
                elif self.config.dataframe.polars:
                    generated_paginated_df = True
                    has_dataframe_methods = True
                    has_pagination_methods = True
                    # Sync polars paginated method
                    polars_fn_name = f'{endpoint.sync_fn_name}_pl'
                    polars_fn, polars_imports = (
                        build_standalone_paginated_dataframe_fn(
                            fn_name=polars_fn_name,
                            method=endpoint.method,
                            path=endpoint.path,
                            parameters=endpoint.parameters,
                            request_body_info=endpoint.request_body,
                            response_type=endpoint.response_type,
                            pagination_style=pag_config.style,
                            pagination_config=pag_dict,
                            library='polars',
                            item_type_ast=item_type_ast,
                            docs=endpoint.description,
                            is_async=False,
                        )
                    )
                    endpoint_names.add(polars_fn_name)
                    body.append(polars_fn)
                    import_collector.add_imports(polars_imports)

                    # Async polars paginated method
                    async_polars_fn_name = f'{endpoint.async_fn_name}_pl'
                    async_polars_fn, async_polars_imports = (
                        build_standalone_paginated_dataframe_fn(
                            fn_name=async_polars_fn_name,
                            method=endpoint.method,
                            path=endpoint.path,
                            parameters=endpoint.parameters,
                            request_body_info=endpoint.request_body,
                            response_type=endpoint.response_type,
                            pagination_style=pag_config.style,
                            pagination_config=pag_dict,
                            library='polars',
                            item_type_ast=item_type_ast,
                            docs=endpoint.description,
                            is_async=True,
                        )
                    )
                    endpoint_names.add(async_polars_fn_name)
                    body.append(async_polars_fn)
                    import_collector.add_imports(async_polars_imports)
        else:
            # Build regular sync standalone function
            sync_fn, sync_imports = build_standalone_endpoint_fn(
                fn_name=endpoint.sync_fn_name,
                method=endpoint.method,
                path=endpoint.path,
                parameters=endpoint.parameters,
                request_body_info=endpoint.request_body,
                response_type=endpoint.response_type,
                response_infos=endpoint.response_infos,
                docs=endpoint.description,
                is_async=False,
            )
            endpoint_names.add(endpoint.sync_fn_name)
            body.append(sync_fn)
            import_collector.add_imports(sync_imports)

            # Build regular async standalone function
            async_fn, async_imports = build_standalone_endpoint_fn(
                fn_name=endpoint.async_fn_name,
                method=endpoint.method,
                path=endpoint.path,
                parameters=endpoint.parameters,
                request_body_info=endpoint.request_body,
                response_type=endpoint.response_type,
                response_infos=endpoint.response_infos,
                docs=endpoint.description,
                is_async=True,
            )
            endpoint_names.add(endpoint.async_fn_name)
            body.append(async_fn)
            import_collector.add_imports(async_imports)

            # Note: Pagination methods already handled above if pag_config exists
            # Skip to next endpoint since pagination methods already generated
            pass

        # Generate DataFrame methods if configured
        # Skip if paginated DataFrame methods were already generated for this endpoint
        if self.config.dataframe.enabled and not generated_paginated_df:
            df_config = self._get_dataframe_config(endpoint)

            if df_config.generate_pandas:
                has_dataframe_methods = True
                # Sync pandas method
                pandas_fn_name = f'{endpoint.sync_fn_name}_df'
                pandas_fn, pandas_imports = build_standalone_dataframe_fn(
                    fn_name=pandas_fn_name,
                    method=endpoint.method,
                    path=endpoint.path,
                    parameters=endpoint.parameters,
                    request_body_info=endpoint.request_body,
                    library='pandas',
                    default_path=df_config.path,
                    docs=endpoint.description,
                    is_async=False,
                )
                endpoint_names.add(pandas_fn_name)
                body.append(pandas_fn)
                import_collector.add_imports(pandas_imports)

                # Async pandas method
                async_pandas_fn_name = f'{endpoint.async_fn_name}_df'
                async_pandas_fn, async_pandas_imports = (
                    build_standalone_dataframe_fn(
                        fn_name=async_pandas_fn_name,
                        method=endpoint.method,
                        path=endpoint.path,
                        parameters=endpoint.parameters,
                        request_body_info=endpoint.request_body,
                        library='pandas',
                        default_path=df_config.path,
                        docs=endpoint.description,
                        is_async=True,
                    )
                )
                endpoint_names.add(async_pandas_fn_name)
                body.append(async_pandas_fn)
                import_collector.add_imports(async_pandas_imports)

            if df_config.generate_polars:
                has_dataframe_methods = True
                # Sync polars method
                polars_fn_name = f'{endpoint.sync_fn_name}_pl'
                polars_fn, polars_imports = build_standalone_dataframe_fn(
                    fn_name=polars_fn_name,
                    method=endpoint.method,
                    path=endpoint.path,
                    parameters=endpoint.parameters,
                    request_body_info=endpoint.request_body,
                    library='polars',
                    default_path=df_config.path,
                    docs=endpoint.description,
                    is_async=False,
                )
                endpoint_names.add(polars_fn_name)
                body.append(polars_fn)
                import_collector.add_imports(polars_imports)

                # Async polars method
                async_polars_fn_name = f'{endpoint.async_fn_name}_pl'
                async_polars_fn, async_polars_imports = (
                    build_standalone_dataframe_fn(
                        fn_name=async_polars_fn_name,
                        method=endpoint.method,
                        path=endpoint.path,
                        parameters=endpoint.parameters,
                        request_body_info=endpoint.request_body,
                        library='polars',
                        default_path=df_config.path,
                        docs=endpoint.description,
                        is_async=True,
                    )
                )
                endpoint_names.add(async_polars_fn_name)
                body.append(async_polars_fn)
                import_collector.add_imports(async_polars_imports)

    # Add TYPE_CHECKING block for DataFrame type hints if needed.
    # Statements are inserted at index 0, so later inserts end up above
    # earlier ones in the emitted module.
    if has_dataframe_methods:
        import_collector.add_imports({'typing': {'TYPE_CHECKING'}})
        type_checking_block = ast.If(
            test=_name('TYPE_CHECKING'),
            body=[
                ast.Import(names=[ast.alias(name='pandas', asname='pd')]),
                ast.Import(names=[ast.alias(name='polars', asname='pl')]),
            ],
            orelse=[],
        )
        body.insert(0, type_checking_block)

        # Add dataframe helper imports
        dataframe_import = ast.ImportFrom(
            module='_dataframe',
            names=[
                ast.alias(name='to_pandas', asname=None),
                ast.alias(name='to_polars', asname=None),
            ],
            level=1,
        )
        body.insert(0, dataframe_import)

    # Add pagination imports if needed
    if has_pagination_methods:
        import_collector.add_imports(
            {'collections.abc': {'Iterator', 'AsyncIterator'}}
        )
        pagination_import = ast.ImportFrom(
            module='_pagination',
            names=[
                ast.alias(name='paginate_offset', asname=None),
                ast.alias(name='paginate_offset_async', asname=None),
                ast.alias(name='paginate_cursor', asname=None),
                ast.alias(name='paginate_cursor_async', asname=None),
                ast.alias(name='paginate_page', asname=None),
                ast.alias(name='paginate_page_async', asname=None),
                ast.alias(name='iterate_offset', asname=None),
                ast.alias(name='iterate_offset_async', asname=None),
                ast.alias(name='iterate_cursor', asname=None),
                ast.alias(name='iterate_cursor_async', asname=None),
                ast.alias(name='iterate_page', asname=None),
                ast.alias(name='iterate_page_async', asname=None),
                ast.alias(name='extract_path', asname=None),
            ],
            level=1,
        )
        body.insert(0, pagination_import)

    return body, import_collector, endpoint_names
|
|
1206
|
+
def _generate_endpoint_file(
    self, path: UPath, models_file: UPath, endpoints: list[Endpoint]
) -> None:
    """Generate the endpoints Python file with delegating functions.

    Args:
        path: Path where the endpoints file should be written.
        models_file: Path to the models file for import generation.
        endpoints: List of Endpoint objects to include.
    """
    base_url = self._resolve_base_url()

    # Build the function bodies first; header statements are then stacked
    # on top via insert(0, ...), so the LAST insert ends up first in the
    # written module.
    stmts, collector, fn_names = self._build_endpoint_file_body(
        base_url, endpoints
    )

    # __all__ export
    stmts.insert(0, _all(sorted(fn_names)))

    # Import only the models actually referenced by endpoint signatures.
    used_models = self._collect_used_model_names(endpoints)
    if used_models:
        stmts.insert(0, self._create_model_import(models_file, used_models))

    # Relative import of the Client class from the same package.
    stmts.insert(
        0,
        ast.ImportFrom(
            module='client',
            names=[ast.alias(name='Client', asname=None)],
            level=1,
        ),
    )

    # Remaining collected imports go above everything else.
    for stmt in collector.to_ast():
        stmts.insert(0, stmt)

    write_mod(stmts, path)
|
|
1246
|
+
def _generate_models_file(self, path: UPath) -> None:
    """Generate the models Python file with Pydantic models.

    Args:
        path: Path where the models file should be written.
    """
    assert self.typegen is not None

    stmts: list[ast.stmt] = []
    collector = ImportCollector()
    exported = set()

    for type_ in self.typegen.types.values():
        if type_.implementation_ast:
            stmts.append(type_.implementation_ast)
            if type_.name:
                exported.add(type_.name)

        # Imports are collected for every type, even those without an
        # implementation, since annotations may still reference them.
        collector.add_imports(type_.implementation_imports)
        collector.add_imports(type_.annotation_imports)

    # __all__ first, then the import statements stacked above it.
    stmts.insert(0, _all(sorted(exported)))
    for import_stmt in collector.to_ast():
        stmts.insert(0, import_stmt)

    write_mod(stmts, path)
|
|
1277
|
+
def generate(self):
    """Run the full code-generation pipeline.

    Loads the schema, prepares the output directory, and writes the
    models, optional pagination helpers, endpoint module(s), client
    files, and (in non-split mode) the package ``__init__.py``.

    Returns:
        List of generated file paths, relative to the output root.

    Raises:
        ValueError: If the OpenAPI spec defines no paths.
        RuntimeError: If the output directory is not writable.
    """
    self._load_schema()

    assert self.openapi is not None

    if not self.openapi.paths:
        raise ValueError('OpenAPI spec has no paths to generate endpoints from')

    out_dir = UPath(self.config.output)
    out_dir.mkdir(parents=True, exist_ok=True)

    if not os.access(str(out_dir), os.W_OK):
        raise RuntimeError(f'Directory {out_dir} is not writable')

    files: list[str] = []
    prefix = self.config.output

    endpoints = self._generate_endpoints()

    # Models go first; endpoint modules import from them.
    models_file = out_dir / self.config.models_file
    self._generate_models_file(models_file)
    files.append(f'{prefix}/{self.config.models_file}')

    base_url = self._resolve_base_url()

    # Shared pagination runtime helpers, emitted once if enabled.
    if self.config.pagination.enabled:
        generate_pagination_module(out_dir)
        files.append(f'{prefix}/_pagination.py')

    # Generate client class
    client_name = self._get_client_class_name()

    split_enabled = self.config.module_split.enabled
    if split_enabled:
        files.extend(
            self._generate_split_endpoints(
                out_dir, models_file, endpoints, base_url, client_name
            )
        )
    else:
        # Original single-file generation
        endpoints_file = out_dir / self.config.endpoints_file
        self._generate_endpoint_file(endpoints_file, models_file, endpoints)
        files.append(f'{prefix}/{self.config.endpoints_file}')

    files.extend(
        self._generate_client_file(out_dir, endpoints, base_url, client_name)
    )

    # Split mode writes its own __init__.py files; only emit one here
    # for single-file generation.
    if not split_enabled:
        self._generate_init_file(out_dir, endpoints, client_name)
        files.append(f'{prefix}/__init__.py')

    return files
|
|
1334
|
+
def _generate_init_file(
    self,
    directory: UPath,
    endpoints: list[Endpoint],
    client_class_name: str,
) -> None:
    """Generate __init__.py with all exports for non-split mode.

    The exported names must mirror exactly what _build_endpoint_file_body
    emits into the endpoints module (base functions, ``*_iter`` pagination
    iterators, and ``*_df``/``*_pl`` DataFrame variants), plus the Client,
    the BaseClient, and every generated model.

    Args:
        directory: Output directory.
        endpoints: List of Endpoint objects.
        client_class_name: Name of the client class.
    """
    body: list[ast.stmt] = []
    all_names: list[str] = []

    # Get endpoint names (including DataFrame methods if configured)
    endpoint_names = []
    for endpoint in endpoints:
        endpoint_names.append(endpoint.sync_fn_name)
        endpoint_names.append(endpoint.async_fn_name)

        # Check pagination config for this endpoint
        pag_config = None
        if self.config.pagination.enabled:
            pag_config = self._get_pagination_config(endpoint)

        # Add pagination method names if configured
        if pag_config:
            endpoint_names.append(f'{endpoint.sync_fn_name}_iter')
            endpoint_names.append(f'{endpoint.async_fn_name}_iter')

        # Add DataFrame method names if configured
        if self.config.dataframe.enabled:
            # For paginated endpoints, DataFrame methods are generated regardless
            # of whether the original response type is a list
            is_paginated = pag_config is not None

            if is_paginated:
                # Check if endpoint is explicitly disabled
                # (per-endpoint override keyed by the sync function name)
                endpoint_df_config = self.config.dataframe.endpoints.get(
                    endpoint.sync_fn_name
                )
                if endpoint_df_config and endpoint_df_config.enabled is False:
                    pass  # Skip DataFrame exports for this endpoint
                else:
                    if self.config.dataframe.pandas:
                        endpoint_names.append(f'{endpoint.sync_fn_name}_df')
                        endpoint_names.append(f'{endpoint.async_fn_name}_df')
                    if self.config.dataframe.polars:
                        endpoint_names.append(f'{endpoint.sync_fn_name}_pl')
                        endpoint_names.append(f'{endpoint.async_fn_name}_pl')
            else:
                # For non-paginated endpoints, use the standard config check
                df_config = self._get_dataframe_config(endpoint)
                if df_config.generate_pandas:
                    endpoint_names.append(f'{endpoint.sync_fn_name}_df')
                    endpoint_names.append(f'{endpoint.async_fn_name}_df')
                if df_config.generate_polars:
                    endpoint_names.append(f'{endpoint.sync_fn_name}_pl')
                    endpoint_names.append(f'{endpoint.async_fn_name}_pl')

    # Import endpoints from endpoints.py
    endpoints_file_stem = self.config.endpoints_file.replace('.py', '')
    if endpoint_names:
        body.append(
            ast.ImportFrom(
                module=endpoints_file_stem,
                names=[
                    ast.alias(name=name, asname=None)
                    for name in sorted(endpoint_names)
                ],
                level=1,
            )
        )
        all_names.extend(endpoint_names)

    # Import Client from client.py
    body.append(
        ast.ImportFrom(
            module='client',
            names=[ast.alias(name='Client', asname=None)],
            level=1,
        )
    )
    all_names.append('Client')

    # Import BaseClient from _client.py
    base_client_name = f'Base{client_class_name}'
    body.append(
        ast.ImportFrom(
            module='_client',
            names=[ast.alias(name=base_client_name, asname=None)],
            level=1,
        )
    )
    all_names.append(base_client_name)

    # Also get all model names from typegen
    # (only models with an implementation are re-exported)
    all_model_names = {
        type_.name
        for type_ in self.typegen.types.values()
        if type_.name and type_.implementation_ast
    }
    if all_model_names:
        body.append(
            ast.ImportFrom(
                module=self.config.models_file.replace('.py', ''),
                names=[
                    ast.alias(name=name, asname=None)
                    for name in sorted(all_model_names)
                ],
                level=1,
            )
        )
        all_names.extend(all_model_names)

    # Add __all__ at the beginning (deduplicated and sorted)
    body.insert(0, _all(sorted(set(all_names))))

    # Write __init__.py
    init_path = directory / '__init__.py'
    write_mod(body, init_path)
|
|
1458
|
+
def _generate_split_endpoints(
    self,
    directory: UPath,
    models_file: UPath,
    endpoints: list[Endpoint],
    base_url: str,
    client_class_name: str,
) -> list[str]:
    """Generate split endpoint modules based on configuration.

    Args:
        directory: Output directory.
        models_file: Path to the models file.
        endpoints: List of Endpoint objects.
        base_url: The base URL for API requests.
        client_class_name: Name of the client class (e.g., 'SwaggerPetstoreOpenAPI30Client').

    Returns:
        List of relative paths to generated files.
    """
    from otterapi.codegen.splitting import (
        ModuleTreeBuilder,
        SplitModuleEmitter,
    )

    # Group endpoints into a module tree per the split configuration.
    tree = ModuleTreeBuilder(self.config.module_split).build(endpoints)

    # Emit one module per tree node.
    emitter = SplitModuleEmitter(
        config=self.config.module_split,
        output_dir=directory,
        models_file=models_file,
        models_import_path=self.config.models_import_path,
        client_class_name=client_class_name,
        dataframe_config=self.config.dataframe,
        pagination_config=self.config.pagination,
    )
    modules = emitter.emit(
        tree=tree,
        base_url=base_url,
        typegen_types=self.typegen.types,
    )

    # Report paths relative to the output root.
    prefix = self.config.output
    generated = [
        f'{prefix}/{module.path.relative_to(directory)}' for module in modules
    ]

    # The package __init__.py is emitted alongside the split modules.
    generated.append(f'{prefix}/__init__.py')

    return generated
|
|
1516
|
+
def _get_client_class_name(self) -> str:
|
|
1517
|
+
"""Get the client class name from config or derive from API title."""
|
|
1518
|
+
if self.config.client_class_name:
|
|
1519
|
+
return self.config.client_class_name
|
|
1520
|
+
|
|
1521
|
+
# Derive from API title
|
|
1522
|
+
if self.openapi and self.openapi.info and self.openapi.info.title:
|
|
1523
|
+
title = self.openapi.info.title
|
|
1524
|
+
# Convert to PascalCase and add Client suffix
|
|
1525
|
+
name = sanitize_identifier(title)
|
|
1526
|
+
if not name.endswith('Client'):
|
|
1527
|
+
name = f'{name}Client'
|
|
1528
|
+
return name
|
|
1529
|
+
|
|
1530
|
+
return 'APIClient'
|
|
1531
|
+
|
|
1532
|
+
def _generate_client_file(
    self,
    directory: UPath,
    endpoints: list[Endpoint],
    base_url: str,
    client_name: str,
) -> list[str]:
    """Generate the client class files.

    Generates:
    - _client.py: Always regenerated, contains the Base client class
      (infrastructure only — endpoint methods live in module files)
    - client.py: Created once if missing, so the user can customize it
      without losing edits on regeneration
    - _dataframe.py: Generated only when DataFrame methods are enabled
      for at least one endpoint

    Args:
        directory: Output directory.
        endpoints: List of Endpoint objects.
        base_url: Default base URL from spec.
        client_name: Name for the client class.

    Returns:
        List of relative paths (prefixed with the configured output name)
        to the files generated by this call.
    """
    # The always-regenerated base class gets a 'Base' prefix; client.py
    # subclasses it under the user-facing name.
    base_client_name = f'Base{client_name}'

    # Convert endpoints to EndpointInfo for client generation
    endpoint_infos = self._endpoints_to_info(endpoints)

    # Check if any endpoint has DataFrame methods
    has_dataframe_methods = any(
        ep.dataframe_config.generate_pandas or ep.dataframe_config.generate_polars
        for ep in endpoint_infos
    )

    # Also check if pagination + dataframe is enabled.
    # Paginated endpoints always get dataframe methods if dataframe is enabled,
    # regardless of whether the original endpoint returns a list.
    if (
        not has_dataframe_methods
        and self.config.pagination.enabled
        and self.config.dataframe.enabled
        and (self.config.dataframe.pandas or self.config.dataframe.polars)
    ):
        for endpoint in endpoints:
            pag_config = self._get_pagination_config(endpoint)
            if pag_config:
                # Check if endpoint is not explicitly disabled for dataframe
                endpoint_df_config = self.config.dataframe.endpoints.get(
                    endpoint.sync_fn_name
                )
                if not (endpoint_df_config and endpoint_df_config.enabled is False):
                    has_dataframe_methods = True
                    break  # one qualifying endpoint is enough

    output_name = self.config.output
    generated_files: list[str] = []

    # Generate _dataframe.py if needed
    if has_dataframe_methods:
        generate_dataframe_module(directory)
        generated_files.append(f'{output_name}/_dataframe.py')

    # Generate base client class (infrastructure only, no endpoint methods)
    class_ast, client_imports = generate_base_client_class(
        class_name=base_client_name,
        default_base_url=base_url,
        default_timeout=30.0,
    )

    # Build the _client.py file as an AST statement list.
    body: list[ast.stmt] = []
    import_collector = ImportCollector()
    import_collector.add_imports(client_imports)

    # No model imports needed in _client.py anymore - models are imported in module files

    # Add other imports.
    # NOTE(review): inserting each statement at index 0 emits the imports in
    # the reverse of to_ast()'s order — presumably cosmetic, but confirm the
    # collector's ordering expectations before changing.
    for import_stmt in import_collector.to_ast():
        body.insert(0, import_stmt)

    # Add TypeVar definition: T = TypeVar('T')
    typevar_def = _assign(
        _name('T'),
        _call(
            func=_name('TypeVar'),
            args=[ast.Constant(value='T')],
        ),
    )
    body.append(typevar_def)

    # Add __all__ export (include APIError)
    body.append(_all([base_client_name, 'APIError']))

    # Add APIError class
    api_error_class = generate_api_error_class()
    body.append(api_error_class)

    # Add the client class
    body.append(class_ast)

    # Write _client.py (always regenerated)
    client_file = directory / '_client.py'
    write_mod(body, client_file)
    generated_files.append(f'{output_name}/_client.py')

    # Generate client.py stub (only if it doesn't exist) — this file is
    # user-owned after first generation, so never overwrite it.
    user_client_file = directory / 'client.py'
    if not user_client_file.exists():
        stub_content = generate_client_stub(
            class_name=client_name,
            base_class_name=base_client_name,
            module_name='_client',
        )
        user_client_file.write_text(stub_content)
        generated_files.append(f'{output_name}/client.py')

    return generated_files
|
|
1649
|
+
|
|
1650
|
+
def _endpoints_to_info(self, endpoints: list[Endpoint]) -> list[EndpointInfo]:
    """Translate Endpoint objects into EndpointInfo records for client generation."""

    def to_info(ep: Endpoint) -> EndpointInfo:
        # Per-endpoint DataFrame settings are resolved from the project config.
        return EndpointInfo(
            name=ep.fn.name,
            async_name=ep.async_fn.name,
            method=ep.method,
            path=ep.path,
            parameters=ep.parameters,
            request_body=ep.request_body,
            response_type=ep.response_type,
            response_infos=ep.response_infos,
            description=ep.description,
            dataframe_config=self._get_dataframe_config(ep),
        )

    return [to_info(ep) for ep in endpoints]
|
|
1671
|
+
|
|
1672
|
+
def _get_dataframe_config(self, endpoint: Endpoint) -> DataFrameMethodConfig:
    """Resolve the DataFrame method configuration for *endpoint*.

    Thin wrapper around ``get_dataframe_config_for_endpoint`` (from
    dataframe_utils) that supplies the project-level DataFrame config.

    Args:
        endpoint: The endpoint to check.

    Returns:
        DataFrameMethodConfig with generation flags and path.
    """
    df_settings = self.config.dataframe
    return get_dataframe_config_for_endpoint(endpoint, df_settings)
|
|
1685
|
+
|
|
1686
|
+
def _endpoint_returns_list(self, endpoint: Endpoint) -> bool:
    """Report whether *endpoint*'s response type is a list.

    Thin wrapper that delegates to ``endpoint_returns_list`` from
    dataframe_utils; kept as a method for call-site symmetry.

    Args:
        endpoint: The endpoint to check.

    Returns:
        True if the endpoint returns a list, False otherwise.
    """
    is_list_response = endpoint_returns_list(endpoint)
    return is_list_response
|
|
1699
|
+
|
|
1700
|
+
def _get_pagination_config(
    self, endpoint: Endpoint
) -> PaginationMethodConfig | None:
    """Look up the pagination method configuration for *endpoint*.

    Args:
        endpoint: The endpoint to check.

    Returns:
        PaginationMethodConfig if pagination is configured, None otherwise.
    """
    pagination_settings = self.config.pagination
    return get_pagination_config_for_endpoint(
        endpoint.sync_fn_name,
        pagination_settings,
        endpoint.parameters,
    )
|
|
1716
|
+
|
|
1717
|
+
def _get_item_type_ast(self, endpoint: Endpoint) -> ast.expr | None:
|
|
1718
|
+
"""Extract the item type AST from a list response type.
|
|
1719
|
+
|
|
1720
|
+
For example, if response_type is list[User], returns the AST for User.
|
|
1721
|
+
|
|
1722
|
+
Args:
|
|
1723
|
+
endpoint: The endpoint to check.
|
|
1724
|
+
|
|
1725
|
+
Returns:
|
|
1726
|
+
The AST expression for the item type, or None if not a list type.
|
|
1727
|
+
"""
|
|
1728
|
+
if not endpoint.response_type or not endpoint.response_type.annotation_ast:
|
|
1729
|
+
return None
|
|
1730
|
+
|
|
1731
|
+
ann = endpoint.response_type.annotation_ast
|
|
1732
|
+
if isinstance(ann, ast.Subscript):
|
|
1733
|
+
if isinstance(ann.value, ast.Name) and ann.value.id == 'list':
|
|
1734
|
+
return ann.slice
|
|
1735
|
+
|
|
1736
|
+
return None
|