pyopenapi-gen 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyopenapi_gen/cli.py +5 -22
- pyopenapi_gen/context/import_collector.py +8 -8
- pyopenapi_gen/core/loader/operations/parser.py +1 -1
- pyopenapi_gen/core/parsing/context.py +2 -1
- pyopenapi_gen/core/parsing/cycle_helpers.py +1 -1
- pyopenapi_gen/core/parsing/keywords/properties_parser.py +4 -4
- pyopenapi_gen/core/parsing/schema_parser.py +4 -4
- pyopenapi_gen/core/parsing/transformers/inline_enum_extractor.py +1 -1
- pyopenapi_gen/core/postprocess_manager.py +39 -13
- pyopenapi_gen/core/schemas.py +101 -16
- pyopenapi_gen/core/writers/python_construct_renderer.py +57 -9
- pyopenapi_gen/emitters/endpoints_emitter.py +1 -1
- pyopenapi_gen/helpers/endpoint_utils.py +4 -22
- pyopenapi_gen/helpers/type_cleaner.py +1 -1
- pyopenapi_gen/helpers/type_resolution/composition_resolver.py +1 -1
- pyopenapi_gen/helpers/type_resolution/finalizer.py +1 -1
- pyopenapi_gen/types/contracts/types.py +0 -1
- pyopenapi_gen/types/resolvers/response_resolver.py +5 -33
- pyopenapi_gen/types/resolvers/schema_resolver.py +2 -2
- pyopenapi_gen/types/services/type_service.py +0 -18
- pyopenapi_gen/types/strategies/__init__.py +5 -0
- pyopenapi_gen/types/strategies/response_strategy.py +187 -0
- pyopenapi_gen/visit/endpoint/endpoint_visitor.py +1 -20
- pyopenapi_gen/visit/endpoint/generators/docstring_generator.py +5 -3
- pyopenapi_gen/visit/endpoint/generators/endpoint_method_generator.py +12 -6
- pyopenapi_gen/visit/endpoint/generators/response_handler_generator.py +352 -343
- pyopenapi_gen/visit/endpoint/generators/signature_generator.py +7 -4
- pyopenapi_gen/visit/endpoint/processors/import_analyzer.py +4 -2
- pyopenapi_gen/visit/endpoint/processors/parameter_processor.py +1 -1
- pyopenapi_gen/visit/model/dataclass_generator.py +32 -1
- pyopenapi_gen-0.8.5.dist-info/METADATA +383 -0
- {pyopenapi_gen-0.8.3.dist-info → pyopenapi_gen-0.8.5.dist-info}/RECORD +35 -33
- pyopenapi_gen-0.8.3.dist-info/METADATA +0 -224
- {pyopenapi_gen-0.8.3.dist-info → pyopenapi_gen-0.8.5.dist-info}/WHEEL +0 -0
- {pyopenapi_gen-0.8.3.dist-info → pyopenapi_gen-0.8.5.dist-info}/entry_points.txt +0 -0
- {pyopenapi_gen-0.8.3.dist-info → pyopenapi_gen-0.8.5.dist-info}/licenses/LICENSE +0 -0
@@ -5,23 +5,21 @@ Helper class for generating response handling logic for an endpoint method.
|
|
5
5
|
from __future__ import annotations
|
6
6
|
|
7
7
|
import logging
|
8
|
-
import re # For parsing Union types, etc.
|
9
8
|
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
|
10
9
|
|
11
10
|
from pyopenapi_gen.core.writers.code_writer import CodeWriter
|
12
11
|
from pyopenapi_gen.helpers.endpoint_utils import (
|
13
12
|
_get_primary_response,
|
14
|
-
get_return_type_unified,
|
15
|
-
get_type_for_specific_response, # Added new helper
|
16
13
|
)
|
17
14
|
from pyopenapi_gen.types.services.type_service import UnifiedTypeService
|
15
|
+
from pyopenapi_gen.types.strategies.response_strategy import ResponseStrategy
|
18
16
|
|
19
17
|
if TYPE_CHECKING:
|
20
|
-
from pyopenapi_gen import IROperation, IRResponse
|
18
|
+
from pyopenapi_gen import IROperation, IRResponse, IRSchema
|
21
19
|
from pyopenapi_gen.context.render_context import RenderContext
|
22
20
|
else:
|
23
21
|
# For runtime, we need to import for TypedDict
|
24
|
-
from pyopenapi_gen import IRResponse
|
22
|
+
from pyopenapi_gen import IRResponse, IRSchema
|
25
23
|
|
26
24
|
logger = logging.getLogger(__name__)
|
27
25
|
|
@@ -32,7 +30,6 @@ class StatusCase(TypedDict):
|
|
32
30
|
status_code: int
|
33
31
|
type: str # 'primary_success', 'success', or 'error'
|
34
32
|
return_type: str
|
35
|
-
needs_unwrap: bool
|
36
33
|
response_ir: IRResponse
|
37
34
|
|
38
35
|
|
@@ -41,7 +38,6 @@ class DefaultCase(TypedDict):
|
|
41
38
|
|
42
39
|
response_ir: IRResponse
|
43
40
|
return_type: str
|
44
|
-
needs_unwrap: bool
|
45
41
|
|
46
42
|
|
47
43
|
class EndpointResponseHandlerGenerator:
|
@@ -50,12 +46,189 @@ class EndpointResponseHandlerGenerator:
|
|
50
46
|
def __init__(self, schemas: Optional[Dict[str, Any]] = None) -> None:
|
51
47
|
self.schemas: Dict[str, Any] = schemas or {}
|
52
48
|
|
49
|
+
def _is_type_alias_to_array(self, type_name: str) -> bool:
|
50
|
+
"""
|
51
|
+
Check if a type name corresponds to a type alias that resolves to a List/array type.
|
52
|
+
|
53
|
+
This helps distinguish between:
|
54
|
+
- Type aliases: AgentHistoryListResponse = List[AgentHistory] (should use array deserialization)
|
55
|
+
- Dataclasses: class AgentHistoryListResponse(BaseSchema): ... (should use .from_dict())
|
56
|
+
|
57
|
+
Args:
|
58
|
+
type_name: The Python type name (e.g., "AgentHistoryListResponse")
|
59
|
+
|
60
|
+
Returns:
|
61
|
+
True if this is a type alias that resolves to List[SomeType]
|
62
|
+
"""
|
63
|
+
# Extract base type name without generics
|
64
|
+
base_type = type_name
|
65
|
+
if "[" in base_type:
|
66
|
+
base_type = base_type[: base_type.find("[")]
|
67
|
+
|
68
|
+
# Look up the schema for this type name
|
69
|
+
if base_type in self.schemas:
|
70
|
+
schema = self.schemas[base_type]
|
71
|
+
# Check if it's a type alias using ModelVisitor's logic:
|
72
|
+
# - Has a name
|
73
|
+
# - No properties (not an object with fields)
|
74
|
+
# - Not an enum
|
75
|
+
# - Type is not "object" (which would be a dataclass)
|
76
|
+
# - Type is "array" (indicating it's an array type alias)
|
77
|
+
is_type_alias = bool(
|
78
|
+
getattr(schema, "name", None)
|
79
|
+
and not getattr(schema, "properties", None)
|
80
|
+
and not getattr(schema, "enum", None)
|
81
|
+
and getattr(schema, "type", None) != "object"
|
82
|
+
)
|
83
|
+
is_array_type = getattr(schema, "type", None) == "array"
|
84
|
+
return is_type_alias and is_array_type
|
85
|
+
|
86
|
+
return False
|
87
|
+
|
88
|
+
def _is_type_alias_to_primitive(self, type_name: str) -> bool:
|
89
|
+
"""
|
90
|
+
Check if a type name corresponds to a type alias that resolves to a primitive type.
|
91
|
+
|
92
|
+
This helps distinguish between:
|
93
|
+
- Type aliases: StringAlias = str (should use cast())
|
94
|
+
- Dataclasses: class MyModel(BaseSchema): ... (should use .from_dict())
|
95
|
+
|
96
|
+
Args:
|
97
|
+
type_name: The Python type name (e.g., "StringAlias")
|
98
|
+
|
99
|
+
Returns:
|
100
|
+
True if this is a type alias that resolves to a primitive type (str, int, float, bool)
|
101
|
+
"""
|
102
|
+
# Extract base type name without generics
|
103
|
+
base_type = type_name
|
104
|
+
if "[" in base_type:
|
105
|
+
base_type = base_type[: base_type.find("[")]
|
106
|
+
|
107
|
+
# Look up the schema for this type name
|
108
|
+
if base_type in self.schemas:
|
109
|
+
schema = self.schemas[base_type]
|
110
|
+
# Check if it's a type alias using ModelVisitor's logic:
|
111
|
+
# - Has a name
|
112
|
+
# - No properties (not an object with fields)
|
113
|
+
# - Not an enum
|
114
|
+
# - Type is not "object" (which would be a dataclass)
|
115
|
+
# - Type is a primitive (string, integer, number, boolean)
|
116
|
+
is_type_alias = bool(
|
117
|
+
getattr(schema, "name", None)
|
118
|
+
and not getattr(schema, "properties", None)
|
119
|
+
and not getattr(schema, "enum", None)
|
120
|
+
and getattr(schema, "type", None) != "object"
|
121
|
+
)
|
122
|
+
is_primitive_type = getattr(schema, "type", None) in ("string", "integer", "number", "boolean")
|
123
|
+
return is_type_alias and is_primitive_type
|
124
|
+
|
125
|
+
return False
|
126
|
+
|
127
|
+
def _should_use_base_schema(self, type_name: str) -> bool:
|
128
|
+
"""
|
129
|
+
Determine if a type should use BaseSchema deserialization.
|
130
|
+
|
131
|
+
Args:
|
132
|
+
type_name: The Python type name (e.g., "User", "List[User]", "Optional[User]")
|
133
|
+
|
134
|
+
Returns:
|
135
|
+
True if the type should use BaseSchema .from_dict() deserialization
|
136
|
+
"""
|
137
|
+
# Extract the base type name from complex types
|
138
|
+
base_type = type_name
|
139
|
+
|
140
|
+
# Handle List[Type], Optional[Type], etc.
|
141
|
+
if "[" in base_type and "]" in base_type:
|
142
|
+
# Extract the inner type from List[Type], Optional[Type], etc.
|
143
|
+
start_bracket = base_type.find("[")
|
144
|
+
end_bracket = base_type.rfind("]")
|
145
|
+
inner_type = base_type[start_bracket + 1 : end_bracket]
|
146
|
+
|
147
|
+
# For Union types like Optional[User] -> Union[User, None], take the first type
|
148
|
+
if ", " in inner_type:
|
149
|
+
inner_type = inner_type.split(", ")[0]
|
150
|
+
|
151
|
+
base_type = inner_type.strip()
|
152
|
+
|
153
|
+
# Skip primitive types and built-ins (both uppercase and lowercase)
|
154
|
+
if base_type in {
|
155
|
+
"str",
|
156
|
+
"int",
|
157
|
+
"float",
|
158
|
+
"bool",
|
159
|
+
"bytes",
|
160
|
+
"None",
|
161
|
+
"Any",
|
162
|
+
"Dict",
|
163
|
+
"List",
|
164
|
+
"dict",
|
165
|
+
"list",
|
166
|
+
"tuple",
|
167
|
+
}:
|
168
|
+
return False
|
169
|
+
|
170
|
+
# Skip typing constructs (both uppercase and lowercase)
|
171
|
+
if base_type.startswith(("Dict[", "List[", "Optional[", "Union[", "Tuple[", "dict[", "list[", "tuple[")):
|
172
|
+
return False
|
173
|
+
|
174
|
+
# Check if this is a type alias (array or non-array) - these should NOT use BaseSchema
|
175
|
+
if self._is_type_alias_to_array(type_name) or self._is_type_alias_to_primitive(type_name):
|
176
|
+
return False
|
177
|
+
|
178
|
+
# All custom model types now inherit from BaseSchema for automatic field mapping
|
179
|
+
# Check if it's a model type (contains a dot indicating it's from models package)
|
180
|
+
# or if it's a simple class name that's likely a generated model (starts with uppercase)
|
181
|
+
return "." in base_type or (
|
182
|
+
base_type[0].isupper()
|
183
|
+
and base_type not in {"Dict", "List", "Optional", "Union", "Tuple", "dict", "list", "tuple"}
|
184
|
+
)
|
185
|
+
|
186
|
+
def _get_base_schema_deserialization_code(self, return_type: str, data_expr: str) -> str:
|
187
|
+
"""
|
188
|
+
Generate BaseSchema deserialization code for a given type.
|
189
|
+
|
190
|
+
Args:
|
191
|
+
return_type: The return type (e.g., "User", "List[User]", "list[User]")
|
192
|
+
data_expr: The expression containing the raw data to deserialize
|
193
|
+
|
194
|
+
Returns:
|
195
|
+
Code string for deserializing the data using BaseSchema .from_dict()
|
196
|
+
"""
|
197
|
+
if return_type.startswith("List[") or return_type.startswith("list["):
|
198
|
+
# Handle List[Model] or list[Model] types
|
199
|
+
if return_type.startswith("List["):
|
200
|
+
item_type = return_type[5:-1] # Remove 'List[' and ']'
|
201
|
+
else: # starts with "list["
|
202
|
+
item_type = return_type[5:-1] # Remove 'list[' and ']'
|
203
|
+
return f"[{item_type}.from_dict(item) for item in {data_expr}]"
|
204
|
+
elif return_type.startswith("Optional["):
|
205
|
+
# Handle Optional[Model] types
|
206
|
+
inner_type = return_type[9:-1] # Remove 'Optional[' and ']'
|
207
|
+
# Check if inner type is also a list
|
208
|
+
if inner_type.startswith("List[") or inner_type.startswith("list["):
|
209
|
+
list_code = self._get_base_schema_deserialization_code(inner_type, data_expr)
|
210
|
+
return f"{list_code} if {data_expr} is not None else None"
|
211
|
+
else:
|
212
|
+
return f"{inner_type}.from_dict({data_expr}) if {data_expr} is not None else None"
|
213
|
+
else:
|
214
|
+
# Handle simple Model types only - this should not be called for list types
|
215
|
+
if "[" in return_type and "]" in return_type:
|
216
|
+
# This is a complex type that we missed - should not happen
|
217
|
+
raise ValueError(f"Unsupported complex type for BaseSchema deserialization: {return_type}")
|
218
|
+
|
219
|
+
# Safety check: catch the specific issue we're debugging
|
220
|
+
if return_type.startswith("list[") or return_type.startswith("List["):
|
221
|
+
raise ValueError(
|
222
|
+
f"CRITICAL BUG: List type {return_type} reached simple type handler! This should never happen."
|
223
|
+
)
|
224
|
+
|
225
|
+
return f"{return_type}.from_dict({data_expr})"
|
226
|
+
|
53
227
|
def _get_extraction_code(
|
54
228
|
self,
|
55
229
|
return_type: str,
|
56
230
|
context: RenderContext,
|
57
231
|
op: IROperation,
|
58
|
-
needs_unwrap: bool,
|
59
232
|
response_ir: Optional[IRResponse] = None,
|
60
233
|
) -> str:
|
61
234
|
"""Determines the code snippet to extract/transform the response body."""
|
@@ -79,7 +252,7 @@ class EndpointResponseHandlerGenerator:
|
|
79
252
|
model_type = return_type[13:-1] # Remove 'AsyncIterator[' and ']'
|
80
253
|
if response_ir and "text/event-stream" in response_ir.content:
|
81
254
|
context.add_import(f"{context.core_package_name}.streaming_helpers", "iter_sse_events_text")
|
82
|
-
return "sse_json_stream_marker" # Special marker for SSE
|
255
|
+
return "sse_json_stream_marker" # Special marker for SSE streaming
|
83
256
|
|
84
257
|
# Default to bytes streaming for other types
|
85
258
|
context.add_import(f"{context.core_package_name}.streaming_helpers", "iter_bytes")
|
@@ -100,32 +273,19 @@ class EndpointResponseHandlerGenerator:
|
|
100
273
|
elif return_type == "None":
|
101
274
|
return "None" # This will be handled by generate_response_handling directly
|
102
275
|
else: # Includes schema-defined models, List[], Dict[], Optional[]
|
103
|
-
context.add_import("typing", "cast")
|
104
276
|
context.add_typing_imports_for_type(return_type) # Ensure model itself is imported
|
105
277
|
|
106
|
-
if
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
f"raw_data = response.json().get('data')\n"
|
118
|
-
f"if raw_data is None:\n"
|
119
|
-
f" raise ValueError(\"Expected 'data' key in response but found None\")\n"
|
120
|
-
f"return cast({return_type}, raw_data)"
|
121
|
-
)
|
122
|
-
# Standard unwrapping for single object
|
123
|
-
return (
|
124
|
-
f"raw_data = response.json().get('data')\n"
|
125
|
-
f"if raw_data is None:\n"
|
126
|
-
f" raise ValueError(\"Expected 'data' key in response but found None\")\n"
|
127
|
-
f"return cast({return_type}, raw_data)"
|
128
|
-
)
|
278
|
+
# Check if we should use BaseSchema deserialization instead of cast()
|
279
|
+
use_base_schema = self._should_use_base_schema(return_type)
|
280
|
+
|
281
|
+
if not use_base_schema:
|
282
|
+
# Fallback to cast() for non-BaseSchema types
|
283
|
+
context.add_import("typing", "cast")
|
284
|
+
|
285
|
+
# Direct deserialization using schemas as-is (no unwrapping)
|
286
|
+
if use_base_schema:
|
287
|
+
deserialization_code = self._get_base_schema_deserialization_code(return_type, "response.json()")
|
288
|
+
return deserialization_code
|
129
289
|
else:
|
130
290
|
return f"cast({return_type}, response.json())"
|
131
291
|
|
@@ -134,364 +294,213 @@ class EndpointResponseHandlerGenerator:
|
|
134
294
|
writer: CodeWriter,
|
135
295
|
op: IROperation,
|
136
296
|
context: RenderContext,
|
297
|
+
strategy: ResponseStrategy,
|
137
298
|
) -> None:
|
138
|
-
"""Writes the response parsing and return logic to the CodeWriter,
|
299
|
+
"""Writes the response parsing and return logic to the CodeWriter, using the unified response strategy."""
|
139
300
|
writer.write_line("# Check response status code and handle accordingly")
|
140
301
|
|
141
|
-
#
|
142
|
-
|
143
|
-
|
302
|
+
# Generate the match statement for status codes
|
303
|
+
writer.write_line("match response.status_code:")
|
304
|
+
writer.indent()
|
144
305
|
|
306
|
+
# Handle the primary success response first
|
145
307
|
primary_success_ir = _get_primary_response(op)
|
146
|
-
|
147
|
-
|
148
|
-
if primary_success_ir: # Explicit check for None to help linter
|
149
|
-
is_2xx = primary_success_ir.status_code.startswith("2")
|
150
|
-
is_default_with_content = primary_success_ir.status_code == "default" and bool(
|
151
|
-
primary_success_ir.content
|
152
|
-
) # Ensure this part is boolean
|
153
|
-
is_primary_actually_success = is_2xx or is_default_with_content
|
154
|
-
|
155
|
-
# Determine if the primary success response will be handled by the first dedicated block
|
156
|
-
# This first block only handles numeric (2xx) success codes.
|
157
|
-
is_primary_handled_by_first_block = (
|
308
|
+
processed_primary_success = False
|
309
|
+
if (
|
158
310
|
primary_success_ir
|
159
|
-
and
|
160
|
-
and primary_success_ir.status_code.
|
161
|
-
|
162
|
-
)
|
163
|
-
|
164
|
-
other_responses = sorted(
|
165
|
-
[
|
166
|
-
r for r in op.responses if not (r == primary_success_ir and is_primary_handled_by_first_block)
|
167
|
-
], # If primary is handled by first block, exclude it from others
|
168
|
-
key=lambda r: (
|
169
|
-
not r.status_code.startswith("2"), # False for 2xx (comes first)
|
170
|
-
r.status_code != "default", # False for default (comes after 2xx, before errors)
|
171
|
-
r.status_code, # Then sort by status_code string
|
172
|
-
),
|
173
|
-
)
|
174
|
-
|
175
|
-
# Collect all status codes and their handlers for the match statement
|
176
|
-
status_cases: list[StatusCase] = []
|
177
|
-
|
178
|
-
# 1. Handle primary success response IF IT IS TRULY A SUCCESS RESPONSE AND NUMERIC (2xx)
|
179
|
-
if is_primary_handled_by_first_block:
|
180
|
-
assert primary_success_ir is not None # Add assertion to help linter
|
181
|
-
# No try-except needed here as isdigit() and startswith("2") already checked
|
311
|
+
and primary_success_ir.status_code.isdigit()
|
312
|
+
and primary_success_ir.status_code.startswith("2")
|
313
|
+
):
|
182
314
|
status_code_val = int(primary_success_ir.status_code)
|
315
|
+
writer.write_line(f"case {status_code_val}:")
|
316
|
+
writer.indent()
|
183
317
|
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
# If we have proper schemas, try to get unwrapping information from unified service
|
190
|
-
if self.schemas and hasattr(list(self.schemas.values())[0] if self.schemas else None, "type"):
|
191
|
-
try:
|
192
|
-
type_service = UnifiedTypeService(self.schemas)
|
193
|
-
return_type_for_op, needs_unwrap_for_op = type_service.resolve_operation_response_with_unwrap_info(
|
194
|
-
op, context
|
195
|
-
)
|
196
|
-
except Exception:
|
197
|
-
# Fall back to the original approach if there's an issue
|
198
|
-
needs_unwrap_for_op = False
|
199
|
-
|
200
|
-
status_cases.append(
|
201
|
-
StatusCase(
|
202
|
-
status_code=status_code_val,
|
203
|
-
type="primary_success",
|
204
|
-
return_type=return_type_for_op,
|
205
|
-
needs_unwrap=needs_unwrap_for_op,
|
206
|
-
response_ir=primary_success_ir,
|
207
|
-
)
|
208
|
-
)
|
209
|
-
|
210
|
-
# 2. Handle other specific responses (other 2xx, then default, then errors)
|
211
|
-
default_case: Optional[DefaultCase] = None
|
212
|
-
for resp_ir in other_responses:
|
213
|
-
# Determine if this response IR defines a success type different from the primary
|
214
|
-
# This is complex. For now, if it's 2xx, we'll try to parse it.
|
215
|
-
# If it's an error, we raise.
|
216
|
-
|
217
|
-
current_return_type_str: str = "None" # Default for e.g. 204 or error cases
|
218
|
-
current_needs_unwrap: bool = False
|
318
|
+
if strategy.return_type == "None":
|
319
|
+
writer.write_line("return None")
|
320
|
+
else:
|
321
|
+
self._write_strategy_based_return(writer, strategy, context)
|
219
322
|
|
220
|
-
|
221
|
-
|
222
|
-
current_return_type_str = "None"
|
223
|
-
else:
|
224
|
-
# We need a way to get the type for *this specific* resp_ir if its schema differs
|
225
|
-
# from the primary operation return type.
|
226
|
-
# Call the new helper for this specific response
|
227
|
-
current_return_type_str = get_type_for_specific_response(
|
228
|
-
operation_path=getattr(op, "path", ""),
|
229
|
-
resp_ir=resp_ir,
|
230
|
-
all_schemas=self.schemas,
|
231
|
-
ctx=context,
|
232
|
-
return_unwrap_data_property=True,
|
233
|
-
)
|
234
|
-
current_needs_unwrap = (
|
235
|
-
"data" in current_return_type_str.lower() or "item" in current_return_type_str.lower()
|
236
|
-
)
|
237
|
-
|
238
|
-
if resp_ir.status_code == "default":
|
239
|
-
# Determine type for default response if it has content
|
240
|
-
default_return_type_str = "None"
|
241
|
-
default_needs_unwrap = False
|
242
|
-
if resp_ir.content:
|
243
|
-
# If 'default' is primary success, get_return_type_unified(op,...) might give its type.
|
244
|
-
# We use the operation's global/primary return type if default has content.
|
245
|
-
op_global_return_type = get_return_type_unified(op, context, self.schemas)
|
246
|
-
op_global_needs_unwrap = False # Unified service handles unwrapping internally
|
247
|
-
# Only use this if the global type is not 'None', otherwise keep default_return_type_str as 'None'.
|
248
|
-
if op_global_return_type != "None":
|
249
|
-
default_return_type_str = op_global_return_type
|
250
|
-
default_needs_unwrap = op_global_needs_unwrap
|
251
|
-
|
252
|
-
default_case = DefaultCase(
|
253
|
-
response_ir=resp_ir, return_type=default_return_type_str, needs_unwrap=default_needs_unwrap
|
254
|
-
)
|
255
|
-
continue # Handle default separately
|
323
|
+
writer.dedent()
|
324
|
+
processed_primary_success = True
|
256
325
|
|
257
|
-
|
326
|
+
# Handle other responses (exclude primary only if it was actually processed)
|
327
|
+
other_responses = [r for r in op.responses if not (processed_primary_success and r == primary_success_ir)]
|
328
|
+
for resp_ir in other_responses:
|
329
|
+
if resp_ir.status_code.isdigit():
|
258
330
|
status_code_val = int(resp_ir.status_code)
|
259
|
-
|
260
|
-
|
261
|
-
status_cases.append(
|
262
|
-
StatusCase(
|
263
|
-
status_code=status_code_val,
|
264
|
-
type=case_type,
|
265
|
-
return_type=current_return_type_str,
|
266
|
-
needs_unwrap=current_needs_unwrap,
|
267
|
-
response_ir=resp_ir,
|
268
|
-
)
|
269
|
-
)
|
270
|
-
except ValueError:
|
271
|
-
logger.warning(f"Skipping non-integer status code in other_responses: {resp_ir.status_code}")
|
272
|
-
|
273
|
-
# Generate the match statement
|
274
|
-
if status_cases or default_case:
|
275
|
-
writer.write_line("match response.status_code:")
|
276
|
-
writer.indent()
|
277
|
-
|
278
|
-
# Generate cases for specific status codes
|
279
|
-
for case in status_cases:
|
280
|
-
writer.write_line(f"case {case['status_code']}:")
|
331
|
+
writer.write_line(f"case {status_code_val}:")
|
281
332
|
writer.indent()
|
282
333
|
|
283
|
-
if
|
284
|
-
#
|
285
|
-
|
286
|
-
# where the type was inferred even if the spec lacked explicit content for the 2xx.
|
287
|
-
# If get_return_type says "None" (e.g., for a 204 or truly no content), then return None.
|
288
|
-
if case["return_type"] == "None":
|
289
|
-
writer.write_line("return None")
|
290
|
-
else:
|
291
|
-
self._write_parsed_return(
|
292
|
-
writer, op, context, case["return_type"], case["needs_unwrap"], case["response_ir"]
|
293
|
-
)
|
294
|
-
elif case["type"] == "success":
|
295
|
-
# Other 2xx success
|
296
|
-
if case["return_type"] == "None" or not case["response_ir"].content:
|
334
|
+
if resp_ir.status_code.startswith("2"):
|
335
|
+
# Other 2xx success responses - resolve each response individually
|
336
|
+
if not resp_ir.content:
|
297
337
|
writer.write_line("return None")
|
298
338
|
else:
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
339
|
+
# Resolve the specific return type for this response
|
340
|
+
resp_schema = self._get_response_schema(resp_ir)
|
341
|
+
if resp_schema:
|
342
|
+
type_service = UnifiedTypeService(self.schemas)
|
343
|
+
response_type = type_service.resolve_schema_type(resp_schema, context)
|
344
|
+
if self._should_use_base_schema(response_type):
|
345
|
+
deserialization_code = self._get_base_schema_deserialization_code(
|
346
|
+
response_type, "response.json()"
|
347
|
+
)
|
348
|
+
writer.write_line(f"return {deserialization_code}")
|
349
|
+
context.add_typing_imports_for_type(response_type)
|
350
|
+
else:
|
351
|
+
context.add_import("typing", "cast")
|
352
|
+
writer.write_line(f"return cast({response_type}, response.json())")
|
353
|
+
else:
|
354
|
+
writer.write_line("return None")
|
355
|
+
else:
|
356
|
+
# Error responses
|
357
|
+
error_class_name = f"Error{status_code_val}"
|
358
|
+
context.add_import(f"{context.core_package_name}", error_class_name)
|
308
359
|
writer.write_line(f"raise {error_class_name}(response=response)")
|
309
360
|
|
310
361
|
writer.dedent()
|
311
362
|
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
self._write_parsed_return(
|
320
|
-
writer,
|
321
|
-
op,
|
322
|
-
context,
|
323
|
-
default_case["return_type"],
|
324
|
-
default_case["needs_unwrap"],
|
325
|
-
default_case["response_ir"],
|
326
|
-
)
|
327
|
-
writer.dedent()
|
328
|
-
else:
|
329
|
-
# Default case without content (error)
|
330
|
-
writer.write_line("case _: # Default error response")
|
331
|
-
writer.indent()
|
332
|
-
context.add_import(f"{context.core_package_name}.exceptions", "HTTPError")
|
333
|
-
default_description = default_case["response_ir"].description or "Unknown default error"
|
334
|
-
writer.write_line(
|
335
|
-
f"raise HTTPError(response=response, "
|
336
|
-
f'message="Default error: {default_description}", '
|
337
|
-
f"status_code=response.status_code)"
|
338
|
-
)
|
339
|
-
writer.dedent()
|
363
|
+
# Handle default case
|
364
|
+
default_response = next((r for r in op.responses if r.status_code == "default"), None)
|
365
|
+
if default_response:
|
366
|
+
writer.write_line("case _: # Default response")
|
367
|
+
writer.indent()
|
368
|
+
if default_response.content and strategy.return_type != "None":
|
369
|
+
self._write_strategy_based_return(writer, strategy, context)
|
340
370
|
else:
|
341
|
-
# Final catch-all for unhandled status codes
|
342
|
-
writer.write_line("case _:")
|
343
|
-
writer.indent()
|
344
371
|
context.add_import(f"{context.core_package_name}.exceptions", "HTTPError")
|
345
372
|
writer.write_line(
|
346
|
-
|
347
|
-
"response=response, "
|
348
|
-
'message="Unhandled status code", '
|
349
|
-
"status_code=response.status_code)"
|
373
|
+
'raise HTTPError(response=response, message="Default error", status_code=response.status_code)'
|
350
374
|
)
|
351
|
-
|
352
|
-
|
353
|
-
writer.dedent() # End of match statement
|
375
|
+
writer.dedent()
|
354
376
|
else:
|
355
|
-
#
|
356
|
-
writer.write_line("match response.status_code:")
|
357
|
-
writer.indent()
|
377
|
+
# Final catch-all
|
358
378
|
writer.write_line("case _:")
|
359
379
|
writer.indent()
|
360
380
|
context.add_import(f"{context.core_package_name}.exceptions", "HTTPError")
|
361
381
|
writer.write_line(
|
362
|
-
|
382
|
+
'raise HTTPError(response=response, message="Unhandled status code", status_code=response.status_code)'
|
363
383
|
)
|
364
384
|
writer.dedent()
|
365
|
-
|
385
|
+
|
386
|
+
writer.dedent() # End of match statement
|
366
387
|
|
367
388
|
# All code paths should be covered by the match statement above
|
368
|
-
# But add an explicit assertion for mypy's satisfaction
|
369
389
|
writer.write_line("# All paths above should return or raise - this should never execute")
|
370
390
|
context.add_import("typing", "NoReturn")
|
371
391
|
writer.write_line("assert False, 'Unexpected code path' # pragma: no cover")
|
372
392
|
writer.write_line("") # Add a blank line for readability
|
373
393
|
|
374
|
-
def
|
394
|
+
def _write_strategy_based_return(
|
375
395
|
self,
|
376
396
|
writer: CodeWriter,
|
377
|
-
|
397
|
+
strategy: ResponseStrategy,
|
378
398
|
context: RenderContext,
|
379
|
-
return_type: str,
|
380
|
-
needs_unwrap: bool,
|
381
|
-
response_ir: Optional[IRResponse] = None,
|
382
399
|
) -> None:
|
383
|
-
"""
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
context.add_import("typing", "Union")
|
394
|
-
context.add_import("typing", "cast")
|
395
|
-
# Corrected regex to parse "Union[TypeA, TypeB]"
|
396
|
-
match = re.match(r"Union\[([A-Za-z0-9_]+),\s*([A-Za-z0-9_]+)\]", return_type)
|
397
|
-
if match:
|
398
|
-
type1_str = match.group(1).strip()
|
399
|
-
type2_str = match.group(2).strip()
|
400
|
-
context.add_typing_imports_for_type(type1_str)
|
401
|
-
context.add_typing_imports_for_type(type2_str)
|
402
|
-
writer.write_line("try:")
|
403
|
-
writer.indent()
|
404
|
-
# Pass response_ir to _get_extraction_code if available
|
405
|
-
extraction_code_type1 = self._get_extraction_code(type1_str, context, op, needs_unwrap, response_ir)
|
406
|
-
if "\n" in extraction_code_type1: # Multi-line extraction
|
407
|
-
lines = extraction_code_type1.split("\n")
|
408
|
-
for line in lines[:-1]: # all but 'return ...'
|
409
|
-
writer.write_line(line)
|
410
|
-
writer.write_line(lines[-1].replace("return ", "return_value = "))
|
411
|
-
writer.write_line("return return_value")
|
412
|
-
else:
|
413
|
-
writer.write_line(f"return {extraction_code_type1}")
|
414
|
-
|
415
|
-
writer.dedent()
|
416
|
-
writer.write_line("except Exception: # Attempt to parse as the second type")
|
400
|
+
"""Write the return statement based on the response strategy.
|
401
|
+
|
402
|
+
This method implements the strategy pattern for response handling,
|
403
|
+
ensuring consistent behavior between signature and implementation.
|
404
|
+
"""
|
405
|
+
if strategy.is_streaming:
|
406
|
+
# Handle streaming responses
|
407
|
+
if "AsyncIterator[bytes]" in strategy.return_type:
|
408
|
+
context.add_import(f"{context.core_package_name}.streaming_helpers", "iter_bytes")
|
409
|
+
writer.write_line("async for chunk in iter_bytes(response):")
|
417
410
|
writer.indent()
|
418
|
-
|
419
|
-
if "\n" in extraction_code_type2: # Multi-line extraction
|
420
|
-
lines = extraction_code_type2.split("\n")
|
421
|
-
for line in lines[:-1]:
|
422
|
-
writer.write_line(line)
|
423
|
-
writer.write_line(lines[-1].replace("return ", "return_value = "))
|
424
|
-
writer.write_line("return return_value")
|
425
|
-
else:
|
426
|
-
writer.write_line(f"return {extraction_code_type2}")
|
411
|
+
writer.write_line("yield chunk")
|
427
412
|
writer.dedent()
|
413
|
+
writer.write_line("return # Explicit return for async generator")
|
428
414
|
else:
|
429
|
-
|
430
|
-
f"Could not parse Union components with regex: {return_type}. Falling back to cast(Any, ...)"
|
431
|
-
)
|
432
|
-
context.add_import("typing", "Any")
|
433
|
-
writer.write_line(f"return cast(Any, response.json())")
|
434
|
-
|
435
|
-
elif return_type == "None": # Explicit None, e.g. for 204 or when specific response has no content
|
436
|
-
writer.write_line("return None")
|
437
|
-
elif is_op_with_inferred_type: # This condition may need re-evaluation in this context
|
438
|
-
context.add_typing_imports_for_type(return_type)
|
439
|
-
context.add_import("typing", "cast")
|
440
|
-
writer.write_line(f"return cast({return_type}, response.json())")
|
441
|
-
else:
|
442
|
-
context.add_typing_imports_for_type(return_type)
|
443
|
-
extraction_code_str = self._get_extraction_code(return_type, context, op, needs_unwrap, response_ir)
|
444
|
-
|
445
|
-
if extraction_code_str == "sse_json_stream_marker": # SSE handling
|
415
|
+
# Handle other streaming types
|
446
416
|
context.add_plain_import("json")
|
447
417
|
context.add_import(f"{context.core_package_name}.streaming_helpers", "iter_sse_events_text")
|
448
|
-
|
449
|
-
# This indicates that SSE streaming might need to be handled more holistically.
|
450
|
-
# For now, if we hit this, it means get_return_type decided on AsyncIterator for an SSE.
|
451
|
-
# The method signature is already async iterator.
|
452
|
-
# The dispatcher should yield from the iter_sse_events_text.
|
453
|
-
# This implies that the `if response.status_code == ...:` block itself needs to be `async for ... yield`
|
454
|
-
# This refactoring is getting deeper.
|
455
|
-
# Quick fix: if it's sse_json_stream_marker, we write the loop here.
|
456
|
-
writer.write_line(f"async for chunk in iter_sse_events_text(response):")
|
418
|
+
writer.write_line("async for chunk in iter_sse_events_text(response):")
|
457
419
|
writer.indent()
|
458
|
-
writer.write_line("yield json.loads(chunk)")
|
420
|
+
writer.write_line("yield json.loads(chunk)")
|
459
421
|
writer.dedent()
|
460
|
-
writer.write_line(
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
422
|
+
writer.write_line("return # Explicit return for async generator")
|
423
|
+
return
|
424
|
+
|
425
|
+
# Handle responses using the schema as-is from the OpenAPI spec (no unwrapping)
|
426
|
+
if strategy.return_type.startswith("Union["):
|
427
|
+
# Special handling for Union types with try/except fallback
|
428
|
+
self._write_union_response_handling(writer, context, strategy.return_type, "response.json()")
|
429
|
+
elif self._should_use_base_schema(strategy.return_type):
|
430
|
+
deserialization_code = self._get_base_schema_deserialization_code(strategy.return_type, "response.json()")
|
431
|
+
writer.write_line(f"return {deserialization_code}")
|
432
|
+
else:
|
433
|
+
context.add_import("typing", "cast")
|
434
|
+
writer.write_line(f"return cast({strategy.return_type}, response.json())")
|
435
|
+
|
436
|
+
def _get_response_schema(self, response_ir: IRResponse) -> Optional[IRSchema]:
|
437
|
+
"""Extract the schema from a response IR."""
|
438
|
+
if not response_ir.content:
|
439
|
+
return None
|
440
|
+
|
441
|
+
# Prefer application/json, then first available content type
|
442
|
+
content_types = list(response_ir.content.keys())
|
443
|
+
preferred_content_type = next((ct for ct in content_types if ct == "application/json"), None)
|
444
|
+
if not preferred_content_type:
|
445
|
+
preferred_content_type = content_types[0] if content_types else None
|
446
|
+
|
447
|
+
if preferred_content_type:
|
448
|
+
return response_ir.content.get(preferred_content_type)
|
449
|
+
|
450
|
+
return None
|
451
|
+
|
452
|
+
def _write_union_response_handling(
|
453
|
+
self, writer: CodeWriter, context: RenderContext, return_type: str, data_expr: str
|
454
|
+
) -> None:
|
455
|
+
"""Write try/except logic for Union types."""
|
456
|
+
# Parse Union[TypeA, TypeB] to extract the types
|
457
|
+
if not return_type.startswith("Union[") or not return_type.endswith("]"):
|
458
|
+
raise ValueError(f"Invalid Union type format: {return_type}")
|
459
|
+
|
460
|
+
union_content = return_type[6:-1] # Remove 'Union[' and ']'
|
461
|
+
types = [t.strip() for t in union_content.split(",")]
|
462
|
+
|
463
|
+
if len(types) < 2:
|
464
|
+
raise ValueError(f"Union type must have at least 2 types: {return_type}")
|
465
|
+
|
466
|
+
# Add Union import
|
467
|
+
context.add_import("typing", "Union")
|
468
|
+
|
469
|
+
# Generate try/except blocks for each type
|
470
|
+
first_type = types[0]
|
471
|
+
remaining_types = types[1:]
|
472
|
+
|
473
|
+
# Try the first type
|
474
|
+
writer.write_line("try:")
|
475
|
+
writer.indent()
|
476
|
+
if self._should_use_base_schema(first_type):
|
477
|
+
context.add_typing_imports_for_type(first_type)
|
478
|
+
deserialization_code = self._get_base_schema_deserialization_code(first_type, data_expr)
|
479
|
+
writer.write_line(f"return {deserialization_code}")
|
480
|
+
else:
|
481
|
+
context.add_import("typing", "cast")
|
482
|
+
writer.write_line(f"return cast({first_type}, {data_expr})")
|
483
|
+
writer.dedent()
|
484
|
+
|
485
|
+
# Add except blocks for remaining types
|
486
|
+
for i, type_name in enumerate(remaining_types):
|
487
|
+
is_last = i == len(remaining_types) - 1
|
488
|
+
if is_last:
|
489
|
+
writer.write_line("except Exception: # Attempt to parse as the final type")
|
490
|
+
else:
|
491
|
+
writer.write_line("except Exception: # Attempt to parse as the next type")
|
492
|
+
writer.indent()
|
493
|
+
if self._should_use_base_schema(type_name):
|
494
|
+
context.add_typing_imports_for_type(type_name)
|
495
|
+
deserialization_code = self._get_base_schema_deserialization_code(type_name, data_expr)
|
496
|
+
if is_last:
|
497
|
+
writer.write_line(f"return {deserialization_code}")
|
482
498
|
else:
|
483
|
-
|
484
|
-
writer.write_line(f"async for chunk in iter_bytes(response):")
|
499
|
+
writer.write_line("try:")
|
485
500
|
writer.indent()
|
486
|
-
writer.write_line("
|
501
|
+
writer.write_line(f"return {deserialization_code}")
|
487
502
|
writer.dedent()
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
for line in extraction_code_str.split("\n"):
|
493
|
-
writer.write_line(line)
|
494
|
-
else: # Single line extraction code
|
495
|
-
if return_type != "None": # Should already be handled, but as safety
|
496
|
-
writer.write_line(f"return {extraction_code_str}")
|
497
|
-
# writer.write_line("") # Blank line might be added by the caller of this helper
|
503
|
+
else:
|
504
|
+
context.add_import("typing", "cast")
|
505
|
+
writer.write_line(f"return cast({type_name}, {data_expr})")
|
506
|
+
writer.dedent()
|