pyopenapi-gen 0.8.3__py3-none-any.whl → 0.8.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. pyopenapi_gen/cli.py +5 -22
  2. pyopenapi_gen/context/import_collector.py +8 -8
  3. pyopenapi_gen/core/loader/operations/parser.py +1 -1
  4. pyopenapi_gen/core/parsing/context.py +2 -1
  5. pyopenapi_gen/core/parsing/cycle_helpers.py +1 -1
  6. pyopenapi_gen/core/parsing/keywords/properties_parser.py +4 -4
  7. pyopenapi_gen/core/parsing/schema_parser.py +4 -4
  8. pyopenapi_gen/core/parsing/transformers/inline_enum_extractor.py +1 -1
  9. pyopenapi_gen/core/postprocess_manager.py +39 -13
  10. pyopenapi_gen/core/schemas.py +101 -16
  11. pyopenapi_gen/core/utils.py +8 -3
  12. pyopenapi_gen/core/writers/python_construct_renderer.py +57 -9
  13. pyopenapi_gen/emitters/endpoints_emitter.py +1 -1
  14. pyopenapi_gen/helpers/endpoint_utils.py +4 -22
  15. pyopenapi_gen/helpers/type_cleaner.py +1 -1
  16. pyopenapi_gen/helpers/type_resolution/composition_resolver.py +1 -1
  17. pyopenapi_gen/helpers/type_resolution/finalizer.py +1 -1
  18. pyopenapi_gen/types/contracts/types.py +0 -1
  19. pyopenapi_gen/types/resolvers/response_resolver.py +5 -33
  20. pyopenapi_gen/types/resolvers/schema_resolver.py +2 -2
  21. pyopenapi_gen/types/services/type_service.py +0 -18
  22. pyopenapi_gen/types/strategies/__init__.py +5 -0
  23. pyopenapi_gen/types/strategies/response_strategy.py +187 -0
  24. pyopenapi_gen/visit/endpoint/endpoint_visitor.py +1 -20
  25. pyopenapi_gen/visit/endpoint/generators/docstring_generator.py +5 -3
  26. pyopenapi_gen/visit/endpoint/generators/endpoint_method_generator.py +12 -6
  27. pyopenapi_gen/visit/endpoint/generators/response_handler_generator.py +352 -343
  28. pyopenapi_gen/visit/endpoint/generators/signature_generator.py +7 -4
  29. pyopenapi_gen/visit/endpoint/processors/import_analyzer.py +4 -2
  30. pyopenapi_gen/visit/endpoint/processors/parameter_processor.py +1 -1
  31. pyopenapi_gen/visit/model/dataclass_generator.py +32 -1
  32. pyopenapi_gen-0.8.6.dist-info/METADATA +383 -0
  33. {pyopenapi_gen-0.8.3.dist-info → pyopenapi_gen-0.8.6.dist-info}/RECORD +36 -34
  34. pyopenapi_gen-0.8.3.dist-info/METADATA +0 -224
  35. {pyopenapi_gen-0.8.3.dist-info → pyopenapi_gen-0.8.6.dist-info}/WHEEL +0 -0
  36. {pyopenapi_gen-0.8.3.dist-info → pyopenapi_gen-0.8.6.dist-info}/entry_points.txt +0 -0
  37. {pyopenapi_gen-0.8.3.dist-info → pyopenapi_gen-0.8.6.dist-info}/licenses/LICENSE +0 -0
@@ -5,23 +5,21 @@ Helper class for generating response handling logic for an endpoint method.
5
5
  from __future__ import annotations
6
6
 
7
7
  import logging
8
- import re # For parsing Union types, etc.
9
8
  from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
10
9
 
11
10
  from pyopenapi_gen.core.writers.code_writer import CodeWriter
12
11
  from pyopenapi_gen.helpers.endpoint_utils import (
13
12
  _get_primary_response,
14
- get_return_type_unified,
15
- get_type_for_specific_response, # Added new helper
16
13
  )
17
14
  from pyopenapi_gen.types.services.type_service import UnifiedTypeService
15
+ from pyopenapi_gen.types.strategies.response_strategy import ResponseStrategy
18
16
 
19
17
  if TYPE_CHECKING:
20
- from pyopenapi_gen import IROperation, IRResponse
18
+ from pyopenapi_gen import IROperation, IRResponse, IRSchema
21
19
  from pyopenapi_gen.context.render_context import RenderContext
22
20
  else:
23
21
  # For runtime, we need to import for TypedDict
24
- from pyopenapi_gen import IRResponse
22
+ from pyopenapi_gen import IRResponse, IRSchema
25
23
 
26
24
  logger = logging.getLogger(__name__)
27
25
 
@@ -32,7 +30,6 @@ class StatusCase(TypedDict):
32
30
  status_code: int
33
31
  type: str # 'primary_success', 'success', or 'error'
34
32
  return_type: str
35
- needs_unwrap: bool
36
33
  response_ir: IRResponse
37
34
 
38
35
 
@@ -41,7 +38,6 @@ class DefaultCase(TypedDict):
41
38
 
42
39
  response_ir: IRResponse
43
40
  return_type: str
44
- needs_unwrap: bool
45
41
 
46
42
 
47
43
  class EndpointResponseHandlerGenerator:
@@ -50,12 +46,189 @@ class EndpointResponseHandlerGenerator:
def __init__(self, schemas: Optional[Dict[str, Any]] = None) -> None:
    """Create a response-handler generator bound to a schema registry.

    Args:
        schemas: Mapping of schema name to schema object, consulted through
            ``self.schemas``. ``None`` or an empty mapping is replaced with a
            fresh empty dict.
    """
    if not schemas:
        schemas = {}
    self.schemas: Dict[str, Any] = schemas
52
48
 
49
def _is_type_alias_to_array(self, type_name: str) -> bool:
    """
    Report whether *type_name* refers to a named alias of an array schema.

    Distinguishes, for example:
    - Type aliases: AgentHistoryListResponse = List[AgentHistory] (should use array deserialization)
    - Dataclasses: class AgentHistoryListResponse(BaseSchema): ... (should use .from_dict())

    Args:
        type_name: The Python type name (e.g., "AgentHistoryListResponse"). Anything
            from the first "[" onward is ignored before the lookup in self.schemas.

    Returns:
        True when the matching schema has a name, no properties, no enum and
        type "array"; False otherwise, including when the name is not in self.schemas.
    """
    lookup_name = type_name.split("[", 1)[0]
    if lookup_name not in self.schemas:
        return False
    schema = self.schemas[lookup_name]

    # An "array" schema can never also be "object", so only the array test is needed.
    if getattr(schema, "type", None) != "array":
        return False

    # Alias criteria (as in ModelVisitor): named, no properties, not an enum.
    has_name = bool(getattr(schema, "name", None))
    return has_name and not getattr(schema, "properties", None) and not getattr(schema, "enum", None)
87
+
88
def _is_type_alias_to_primitive(self, type_name: str) -> bool:
    """
    Report whether *type_name* refers to a named alias of a primitive schema.

    Distinguishes, for example:
    - Type aliases: StringAlias = str (should use cast())
    - Dataclasses: class MyModel(BaseSchema): ... (should use .from_dict())

    Args:
        type_name: The Python type name (e.g., "StringAlias"). Anything from the
            first "[" onward is ignored before the lookup in self.schemas.

    Returns:
        True when the matching schema has a name, no properties, no enum and a
        type of "string", "integer", "number" or "boolean"; False otherwise,
        including when the name is not in self.schemas.
    """
    lookup_name = type_name.split("[", 1)[0]
    if lookup_name not in self.schemas:
        return False
    schema = self.schemas[lookup_name]

    # None of the primitive types is "object", so no separate object test is needed.
    if getattr(schema, "type", None) not in ("string", "integer", "number", "boolean"):
        return False

    # Alias criteria (as in ModelVisitor): named, no properties, not an enum.
    has_name = bool(getattr(schema, "name", None))
    return has_name and not getattr(schema, "properties", None) and not getattr(schema, "enum", None)
126
+
127
def _should_use_base_schema(self, type_name: str) -> bool:
    """
    Determine if a type should use BaseSchema deserialization.

    Returns True only for shapes that ``_get_base_schema_deserialization_code`` can
    turn into ``.from_dict()`` calls:
    - a bare model name ("User", "models.User"), or
    - a model wrapped once in "List[...]", "list[...]" or "Optional[...]".

    Everything else falls back to ``cast()``. That includes primitives, containers of
    primitives, unions, other generics, nested generics and type aliases.

    Args:
        type_name: The Python type name (e.g., "User", "List[User]", "Optional[User]")

    Returns:
        True if the type should use BaseSchema .from_dict() deserialization
    """
    base_type = type_name.strip()
    if not base_type:
        # Nothing to inspect (and base_type[0] below would raise IndexError).
        return False

    # PEP 604 unions ("User | None") cannot be expressed as a single .from_dict() call.
    if "|" in base_type:
        return False

    # Handle List[Type], list[Type], Optional[Type]
    if "[" in base_type and "]" in base_type:
        # The deserializer only understands these wrappers. For any other generic
        # (Union[...], Dict[...], Tuple[...], ...) it raises ValueError, so use cast().
        if not base_type.startswith(("List[", "list[", "Optional[")):
            return False
        inner_type = base_type[base_type.find("[") + 1 : base_type.rfind("]")]
        # Keep only the first comma-separated argument.
        if ", " in inner_type:
            inner_type = inner_type.split(", ")[0]
        base_type = inner_type.strip()
        if not base_type:
            return False

    # Skip primitive types and built-ins (both uppercase and lowercase)
    if base_type in {
        "str",
        "int",
        "float",
        "bool",
        "bytes",
        "None",
        "Any",
        "Dict",
        "List",
        "dict",
        "list",
        "tuple",
    }:
        return False

    # A nested generic (List[List[User]], Optional[Dict[...]], List[Literal[...]]) is
    # never a plain model class, so a single .from_dict() call cannot handle it.
    if "[" in base_type:
        return False

    # Type aliases (to arrays or primitives) have no .from_dict(). Check the *unwrapped*
    # name so that Optional[SomeAlias] / List[SomeAlias] are caught as well.
    if self._is_type_alias_to_array(base_type) or self._is_type_alias_to_primitive(base_type):
        return False

    # Generated models inherit from BaseSchema. They are either dotted (from the models
    # package) or simple class names starting with an uppercase letter.
    return "." in base_type or (
        base_type[0].isupper()
        and base_type not in {"Dict", "List", "Optional", "Union", "Tuple", "dict", "list", "tuple"}
    )
185
+
186
def _get_base_schema_deserialization_code(self, return_type: str, data_expr: str) -> str:
    """
    Generate BaseSchema deserialization code for a given type.

    Supported shapes:
    - "List[Model]" / "list[Model]"      -> list comprehension of Model.from_dict(item)
    - "Optional[Model]"                  -> Model.from_dict(...) guarded by a None check
    - "Optional[List[Model]]" (or list[) -> the list comprehension guarded by a None check
    - "Model"                            -> Model.from_dict(...)

    Note: in the Optional forms, *data_expr* appears twice in the emitted code
    (once for the None test, once for the conversion).

    Args:
        return_type: The return type (e.g., "User", "List[User]", "list[User]")
        data_expr: The expression containing the raw data to deserialize

    Returns:
        Code string for deserializing the data using BaseSchema .from_dict()

    Raises:
        ValueError: If *return_type* is any other generic (e.g. "Dict[str, User]").
    """
    list_prefixes = ("List[", "list[")

    def _list_code(list_type: str) -> str:
        # Both prefixes are 5 characters long, so one slice strips either form.
        item_type = list_type[len("List[") : -1]
        return f"[{item_type}.from_dict(item) for item in {data_expr}]"

    if return_type.startswith(list_prefixes):
        return _list_code(return_type)

    if return_type.startswith("Optional["):
        inner_type = return_type[len("Optional[") : -1]
        if inner_type.startswith(list_prefixes):
            return f"{_list_code(inner_type)} if {data_expr} is not None else None"
        return f"{inner_type}.from_dict({data_expr}) if {data_expr} is not None else None"

    # Only a plain model name remains supported; any other generic is rejected.
    if "[" in return_type and "]" in return_type:
        raise ValueError(f"Unsupported complex type for BaseSchema deserialization: {return_type}")

    return f"{return_type}.from_dict({data_expr})"
226
+
53
227
  def _get_extraction_code(
54
228
  self,
55
229
  return_type: str,
56
230
  context: RenderContext,
57
231
  op: IROperation,
58
- needs_unwrap: bool,
59
232
  response_ir: Optional[IRResponse] = None,
60
233
  ) -> str:
61
234
  """Determines the code snippet to extract/transform the response body."""
@@ -79,7 +252,7 @@ class EndpointResponseHandlerGenerator:
79
252
  model_type = return_type[13:-1] # Remove 'AsyncIterator[' and ']'
80
253
  if response_ir and "text/event-stream" in response_ir.content:
81
254
  context.add_import(f"{context.core_package_name}.streaming_helpers", "iter_sse_events_text")
82
- return "sse_json_stream_marker" # Special marker for SSE
255
+ return "sse_json_stream_marker" # Special marker for SSE streaming
83
256
 
84
257
  # Default to bytes streaming for other types
85
258
  context.add_import(f"{context.core_package_name}.streaming_helpers", "iter_bytes")
@@ -100,32 +273,19 @@ class EndpointResponseHandlerGenerator:
100
273
  elif return_type == "None":
101
274
  return "None" # This will be handled by generate_response_handling directly
102
275
  else: # Includes schema-defined models, List[], Dict[], Optional[]
103
- context.add_import("typing", "cast")
104
276
  context.add_typing_imports_for_type(return_type) # Ensure model itself is imported
105
277
 
106
- if needs_unwrap:
107
- # Special handling for List unwrapping - ensure we have the correct imports
108
- if return_type.startswith("List["):
109
- # Extract the item type from List[ItemType]
110
- item_type = return_type[5:-1] # Remove 'List[' and ']'
111
- context.add_import("typing", "List")
112
- if "." in item_type:
113
- # Ensure we have the proper import for the item type
114
- context.add_typing_imports_for_type(item_type)
115
- # Handle unwrapping of List directly
116
- return (
117
- f"raw_data = response.json().get('data')\n"
118
- f"if raw_data is None:\n"
119
- f" raise ValueError(\"Expected 'data' key in response but found None\")\n"
120
- f"return cast({return_type}, raw_data)"
121
- )
122
- # Standard unwrapping for single object
123
- return (
124
- f"raw_data = response.json().get('data')\n"
125
- f"if raw_data is None:\n"
126
- f" raise ValueError(\"Expected 'data' key in response but found None\")\n"
127
- f"return cast({return_type}, raw_data)"
128
- )
278
+ # Check if we should use BaseSchema deserialization instead of cast()
279
+ use_base_schema = self._should_use_base_schema(return_type)
280
+
281
+ if not use_base_schema:
282
+ # Fallback to cast() for non-BaseSchema types
283
+ context.add_import("typing", "cast")
284
+
285
+ # Direct deserialization using schemas as-is (no unwrapping)
286
+ if use_base_schema:
287
+ deserialization_code = self._get_base_schema_deserialization_code(return_type, "response.json()")
288
+ return deserialization_code
129
289
  else:
130
290
  return f"cast({return_type}, response.json())"
131
291
 
@@ -134,364 +294,213 @@ class EndpointResponseHandlerGenerator:
134
294
  writer: CodeWriter,
135
295
  op: IROperation,
136
296
  context: RenderContext,
297
+ strategy: ResponseStrategy,
137
298
  ) -> None:
138
- """Writes the response parsing and return logic to the CodeWriter, including status code dispatch."""
299
+ """Writes the response parsing and return logic to the CodeWriter, using the unified response strategy."""
139
300
  writer.write_line("# Check response status code and handle accordingly")
140
301
 
141
- # Sort responses: specific 2xx, then default (if configured for success), then errors
142
- # This simplified sorting might need adjustment based on how 'default' is treated
143
- # For now, we'll explicitly find the primary success path first.
302
+ # Generate the match statement for status codes
303
+ writer.write_line("match response.status_code:")
304
+ writer.indent()
144
305
 
306
+ # Handle the primary success response first
145
307
  primary_success_ir = _get_primary_response(op)
146
-
147
- is_primary_actually_success = False
148
- if primary_success_ir: # Explicit check for None to help linter
149
- is_2xx = primary_success_ir.status_code.startswith("2")
150
- is_default_with_content = primary_success_ir.status_code == "default" and bool(
151
- primary_success_ir.content
152
- ) # Ensure this part is boolean
153
- is_primary_actually_success = is_2xx or is_default_with_content
154
-
155
- # Determine if the primary success response will be handled by the first dedicated block
156
- # This first block only handles numeric (2xx) success codes.
157
- is_primary_handled_by_first_block = (
308
+ processed_primary_success = False
309
+ if (
158
310
  primary_success_ir
159
- and is_primary_actually_success
160
- and primary_success_ir.status_code.isdigit() # Key change: first block only for numeric codes
161
- and primary_success_ir.status_code.startswith("2") # Ensure it's 2xx
162
- )
163
-
164
- other_responses = sorted(
165
- [
166
- r for r in op.responses if not (r == primary_success_ir and is_primary_handled_by_first_block)
167
- ], # If primary is handled by first block, exclude it from others
168
- key=lambda r: (
169
- not r.status_code.startswith("2"), # False for 2xx (comes first)
170
- r.status_code != "default", # False for default (comes after 2xx, before errors)
171
- r.status_code, # Then sort by status_code string
172
- ),
173
- )
174
-
175
- # Collect all status codes and their handlers for the match statement
176
- status_cases: list[StatusCase] = []
177
-
178
- # 1. Handle primary success response IF IT IS TRULY A SUCCESS RESPONSE AND NUMERIC (2xx)
179
- if is_primary_handled_by_first_block:
180
- assert primary_success_ir is not None # Add assertion to help linter
181
- # No try-except needed here as isdigit() and startswith("2") already checked
311
+ and primary_success_ir.status_code.isdigit()
312
+ and primary_success_ir.status_code.startswith("2")
313
+ ):
182
314
  status_code_val = int(primary_success_ir.status_code)
315
+ writer.write_line(f"case {status_code_val}:")
316
+ writer.indent()
183
317
 
184
- # This is the return_type for the *entire operation*, based on its primary success response
185
- # First try the fallback method for backward compatibility
186
- return_type_for_op = get_return_type_unified(op, context, self.schemas)
187
- needs_unwrap_for_op = False # Default to False
188
-
189
- # If we have proper schemas, try to get unwrapping information from unified service
190
- if self.schemas and hasattr(list(self.schemas.values())[0] if self.schemas else None, "type"):
191
- try:
192
- type_service = UnifiedTypeService(self.schemas)
193
- return_type_for_op, needs_unwrap_for_op = type_service.resolve_operation_response_with_unwrap_info(
194
- op, context
195
- )
196
- except Exception:
197
- # Fall back to the original approach if there's an issue
198
- needs_unwrap_for_op = False
199
-
200
- status_cases.append(
201
- StatusCase(
202
- status_code=status_code_val,
203
- type="primary_success",
204
- return_type=return_type_for_op,
205
- needs_unwrap=needs_unwrap_for_op,
206
- response_ir=primary_success_ir,
207
- )
208
- )
209
-
210
- # 2. Handle other specific responses (other 2xx, then default, then errors)
211
- default_case: Optional[DefaultCase] = None
212
- for resp_ir in other_responses:
213
- # Determine if this response IR defines a success type different from the primary
214
- # This is complex. For now, if it's 2xx, we'll try to parse it.
215
- # If it's an error, we raise.
216
-
217
- current_return_type_str: str = "None" # Default for e.g. 204 or error cases
218
- current_needs_unwrap: bool = False
318
+ if strategy.return_type == "None":
319
+ writer.write_line("return None")
320
+ else:
321
+ self._write_strategy_based_return(writer, strategy, context)
219
322
 
220
- if resp_ir.status_code.startswith("2"):
221
- if not resp_ir.content: # e.g. 204
222
- current_return_type_str = "None"
223
- else:
224
- # We need a way to get the type for *this specific* resp_ir if its schema differs
225
- # from the primary operation return type.
226
- # Call the new helper for this specific response
227
- current_return_type_str = get_type_for_specific_response(
228
- operation_path=getattr(op, "path", ""),
229
- resp_ir=resp_ir,
230
- all_schemas=self.schemas,
231
- ctx=context,
232
- return_unwrap_data_property=True,
233
- )
234
- current_needs_unwrap = (
235
- "data" in current_return_type_str.lower() or "item" in current_return_type_str.lower()
236
- )
237
-
238
- if resp_ir.status_code == "default":
239
- # Determine type for default response if it has content
240
- default_return_type_str = "None"
241
- default_needs_unwrap = False
242
- if resp_ir.content:
243
- # If 'default' is primary success, get_return_type_unified(op,...) might give its type.
244
- # We use the operation's global/primary return type if default has content.
245
- op_global_return_type = get_return_type_unified(op, context, self.schemas)
246
- op_global_needs_unwrap = False # Unified service handles unwrapping internally
247
- # Only use this if the global type is not 'None', otherwise keep default_return_type_str as 'None'.
248
- if op_global_return_type != "None":
249
- default_return_type_str = op_global_return_type
250
- default_needs_unwrap = op_global_needs_unwrap
251
-
252
- default_case = DefaultCase(
253
- response_ir=resp_ir, return_type=default_return_type_str, needs_unwrap=default_needs_unwrap
254
- )
255
- continue # Handle default separately
323
+ writer.dedent()
324
+ processed_primary_success = True
256
325
 
257
- try:
326
+ # Handle other responses (exclude primary only if it was actually processed)
327
+ other_responses = [r for r in op.responses if not (processed_primary_success and r == primary_success_ir)]
328
+ for resp_ir in other_responses:
329
+ if resp_ir.status_code.isdigit():
258
330
  status_code_val = int(resp_ir.status_code)
259
- case_type = "success" if resp_ir.status_code.startswith("2") else "error"
260
-
261
- status_cases.append(
262
- StatusCase(
263
- status_code=status_code_val,
264
- type=case_type,
265
- return_type=current_return_type_str,
266
- needs_unwrap=current_needs_unwrap,
267
- response_ir=resp_ir,
268
- )
269
- )
270
- except ValueError:
271
- logger.warning(f"Skipping non-integer status code in other_responses: {resp_ir.status_code}")
272
-
273
- # Generate the match statement
274
- if status_cases or default_case:
275
- writer.write_line("match response.status_code:")
276
- writer.indent()
277
-
278
- # Generate cases for specific status codes
279
- for case in status_cases:
280
- writer.write_line(f"case {case['status_code']}:")
331
+ writer.write_line(f"case {status_code_val}:")
281
332
  writer.indent()
282
333
 
283
- if case["type"] == "primary_success":
284
- # If get_return_type determined a specific type (not "None"),
285
- # we should attempt to parse the response accordingly. This handles cases
286
- # where the type was inferred even if the spec lacked explicit content for the 2xx.
287
- # If get_return_type says "None" (e.g., for a 204 or truly no content), then return None.
288
- if case["return_type"] == "None":
289
- writer.write_line("return None")
290
- else:
291
- self._write_parsed_return(
292
- writer, op, context, case["return_type"], case["needs_unwrap"], case["response_ir"]
293
- )
294
- elif case["type"] == "success":
295
- # Other 2xx success
296
- if case["return_type"] == "None" or not case["response_ir"].content:
334
+ if resp_ir.status_code.startswith("2"):
335
+ # Other 2xx success responses - resolve each response individually
336
+ if not resp_ir.content:
297
337
  writer.write_line("return None")
298
338
  else:
299
- self._write_parsed_return(
300
- writer, op, context, case["return_type"], case["needs_unwrap"], case["response_ir"]
301
- )
302
- elif case["type"] == "error":
303
- # Error codes (3xx, 4xx, 5xx)
304
- error_class_name = f"Error{case['status_code']}"
305
- context.add_import(
306
- f"{context.core_package_name}", error_class_name
307
- ) # Import from top-level core package
339
+ # Resolve the specific return type for this response
340
+ resp_schema = self._get_response_schema(resp_ir)
341
+ if resp_schema:
342
+ type_service = UnifiedTypeService(self.schemas)
343
+ response_type = type_service.resolve_schema_type(resp_schema, context)
344
+ if self._should_use_base_schema(response_type):
345
+ deserialization_code = self._get_base_schema_deserialization_code(
346
+ response_type, "response.json()"
347
+ )
348
+ writer.write_line(f"return {deserialization_code}")
349
+ context.add_typing_imports_for_type(response_type)
350
+ else:
351
+ context.add_import("typing", "cast")
352
+ writer.write_line(f"return cast({response_type}, response.json())")
353
+ else:
354
+ writer.write_line("return None")
355
+ else:
356
+ # Error responses
357
+ error_class_name = f"Error{status_code_val}"
358
+ context.add_import(f"{context.core_package_name}", error_class_name)
308
359
  writer.write_line(f"raise {error_class_name}(response=response)")
309
360
 
310
361
  writer.dedent()
311
362
 
312
- # Handle default case if it exists
313
- if default_case:
314
- # Default response case - catch all remaining status codes
315
- if default_case["response_ir"].content and default_case["return_type"] != "None":
316
- # Default case with content (success)
317
- writer.write_line("case _ if response.status_code >= 0: # Default response catch-all")
318
- writer.indent()
319
- self._write_parsed_return(
320
- writer,
321
- op,
322
- context,
323
- default_case["return_type"],
324
- default_case["needs_unwrap"],
325
- default_case["response_ir"],
326
- )
327
- writer.dedent()
328
- else:
329
- # Default case without content (error)
330
- writer.write_line("case _: # Default error response")
331
- writer.indent()
332
- context.add_import(f"{context.core_package_name}.exceptions", "HTTPError")
333
- default_description = default_case["response_ir"].description or "Unknown default error"
334
- writer.write_line(
335
- f"raise HTTPError(response=response, "
336
- f'message="Default error: {default_description}", '
337
- f"status_code=response.status_code)"
338
- )
339
- writer.dedent()
363
+ # Handle default case
364
+ default_response = next((r for r in op.responses if r.status_code == "default"), None)
365
+ if default_response:
366
+ writer.write_line("case _: # Default response")
367
+ writer.indent()
368
+ if default_response.content and strategy.return_type != "None":
369
+ self._write_strategy_based_return(writer, strategy, context)
340
370
  else:
341
- # Final catch-all for unhandled status codes
342
- writer.write_line("case _:")
343
- writer.indent()
344
371
  context.add_import(f"{context.core_package_name}.exceptions", "HTTPError")
345
372
  writer.write_line(
346
- "raise HTTPError("
347
- "response=response, "
348
- 'message="Unhandled status code", '
349
- "status_code=response.status_code)"
373
+ 'raise HTTPError(response=response, message="Default error", status_code=response.status_code)'
350
374
  )
351
- writer.dedent()
352
-
353
- writer.dedent() # End of match statement
375
+ writer.dedent()
354
376
  else:
355
- # Fallback if no responses are defined
356
- writer.write_line("match response.status_code:")
357
- writer.indent()
377
+ # Final catch-all
358
378
  writer.write_line("case _:")
359
379
  writer.indent()
360
380
  context.add_import(f"{context.core_package_name}.exceptions", "HTTPError")
361
381
  writer.write_line(
362
- f'raise HTTPError(response=response, message="Unhandled status code", status_code=response.status_code)'
382
+ 'raise HTTPError(response=response, message="Unhandled status code", status_code=response.status_code)'
363
383
  )
364
384
  writer.dedent()
365
- writer.dedent()
385
+
386
+ writer.dedent() # End of match statement
366
387
 
367
388
  # All code paths should be covered by the match statement above
368
- # But add an explicit assertion for mypy's satisfaction
369
389
  writer.write_line("# All paths above should return or raise - this should never execute")
370
390
  context.add_import("typing", "NoReturn")
371
391
  writer.write_line("assert False, 'Unexpected code path' # pragma: no cover")
372
392
  writer.write_line("") # Add a blank line for readability
373
393
 
374
- def _write_parsed_return(
394
+ def _write_strategy_based_return(
375
395
  self,
376
396
  writer: CodeWriter,
377
- op: IROperation,
397
+ strategy: ResponseStrategy,
378
398
  context: RenderContext,
379
- return_type: str,
380
- needs_unwrap: bool,
381
- response_ir: Optional[IRResponse] = None,
382
399
  ) -> None:
383
- """Helper to write the actual return statement with parsing/extraction logic."""
384
-
385
- # This section largely reuses the logic from the original generate_response_handling
386
- # adapted to be callable for a specific return_type and response context.
387
-
388
- is_op_with_inferred_type = return_type != "None" and not any(
389
- r.content for r in op.responses if r.status_code.startswith("2")
390
- ) # This might need adjustment if called for a specific non-primary response.
391
-
392
- if return_type.startswith("Union["):
393
- context.add_import("typing", "Union")
394
- context.add_import("typing", "cast")
395
- # Corrected regex to parse "Union[TypeA, TypeB]"
396
- match = re.match(r"Union\[([A-Za-z0-9_]+),\s*([A-Za-z0-9_]+)\]", return_type)
397
- if match:
398
- type1_str = match.group(1).strip()
399
- type2_str = match.group(2).strip()
400
- context.add_typing_imports_for_type(type1_str)
401
- context.add_typing_imports_for_type(type2_str)
402
- writer.write_line("try:")
403
- writer.indent()
404
- # Pass response_ir to _get_extraction_code if available
405
- extraction_code_type1 = self._get_extraction_code(type1_str, context, op, needs_unwrap, response_ir)
406
- if "\n" in extraction_code_type1: # Multi-line extraction
407
- lines = extraction_code_type1.split("\n")
408
- for line in lines[:-1]: # all but 'return ...'
409
- writer.write_line(line)
410
- writer.write_line(lines[-1].replace("return ", "return_value = "))
411
- writer.write_line("return return_value")
412
- else:
413
- writer.write_line(f"return {extraction_code_type1}")
414
-
415
- writer.dedent()
416
- writer.write_line("except Exception: # Attempt to parse as the second type")
400
+ """Write the return statement based on the response strategy.
401
+
402
+ This method implements the strategy pattern for response handling,
403
+ ensuring consistent behavior between signature and implementation.
404
+ """
405
+ if strategy.is_streaming:
406
+ # Handle streaming responses
407
+ if "AsyncIterator[bytes]" in strategy.return_type:
408
+ context.add_import(f"{context.core_package_name}.streaming_helpers", "iter_bytes")
409
+ writer.write_line("async for chunk in iter_bytes(response):")
417
410
  writer.indent()
418
- extraction_code_type2 = self._get_extraction_code(type2_str, context, op, needs_unwrap, response_ir)
419
- if "\n" in extraction_code_type2: # Multi-line extraction
420
- lines = extraction_code_type2.split("\n")
421
- for line in lines[:-1]:
422
- writer.write_line(line)
423
- writer.write_line(lines[-1].replace("return ", "return_value = "))
424
- writer.write_line("return return_value")
425
- else:
426
- writer.write_line(f"return {extraction_code_type2}")
411
+ writer.write_line("yield chunk")
427
412
  writer.dedent()
413
+ writer.write_line("return # Explicit return for async generator")
428
414
  else:
429
- logger.warning(
430
- f"Could not parse Union components with regex: {return_type}. Falling back to cast(Any, ...)"
431
- )
432
- context.add_import("typing", "Any")
433
- writer.write_line(f"return cast(Any, response.json())")
434
-
435
- elif return_type == "None": # Explicit None, e.g. for 204 or when specific response has no content
436
- writer.write_line("return None")
437
- elif is_op_with_inferred_type: # This condition may need re-evaluation in this context
438
- context.add_typing_imports_for_type(return_type)
439
- context.add_import("typing", "cast")
440
- writer.write_line(f"return cast({return_type}, response.json())")
441
- else:
442
- context.add_typing_imports_for_type(return_type)
443
- extraction_code_str = self._get_extraction_code(return_type, context, op, needs_unwrap, response_ir)
444
-
445
- if extraction_code_str == "sse_json_stream_marker": # SSE handling
415
+ # Handle other streaming types
446
416
  context.add_plain_import("json")
447
417
  context.add_import(f"{context.core_package_name}.streaming_helpers", "iter_sse_events_text")
448
- # The actual yield loop must be outside, this function is about the *return value* for one branch.
449
- # This indicates that SSE streaming might need to be handled more holistically.
450
- # For now, if we hit this, it means get_return_type decided on AsyncIterator for an SSE.
451
- # The method signature is already async iterator.
452
- # The dispatcher should yield from the iter_sse_events_text.
453
- # This implies that the `if response.status_code == ...:` block itself needs to be `async for ... yield`
454
- # This refactoring is getting deeper.
455
- # Quick fix: if it's sse_json_stream_marker, we write the loop here.
456
- writer.write_line(f"async for chunk in iter_sse_events_text(response):")
418
+ writer.write_line("async for chunk in iter_sse_events_text(response):")
457
419
  writer.indent()
458
- writer.write_line("yield json.loads(chunk)") # Assuming item_type for SSE is JSON decodable
420
+ writer.write_line("yield json.loads(chunk)")
459
421
  writer.dedent()
460
- writer.write_line(
461
- "return # Explicit return for async generator"
462
- ) # Ensure function ends if it's a generator path
463
- elif extraction_code_str == "iter_bytes(response)" or (
464
- return_type.startswith("AsyncIterator[") and "Iterator" in return_type
465
- ):
466
- # Handle streaming responses - either binary (bytes) or event-stream (Dict[str, Any])
467
- context.add_import(f"{context.core_package_name}.streaming_helpers", "iter_bytes")
468
- if return_type == "AsyncIterator[bytes]":
469
- # Binary streaming
470
- writer.write_line(f"async for chunk in iter_bytes(response):")
471
- writer.indent()
472
- writer.write_line("yield chunk")
473
- writer.dedent()
474
- elif "Dict[str, Any]" in return_type or "dict" in return_type.lower():
475
- # Event-stream or JSON streaming
476
- context.add_plain_import("json")
477
- context.add_import(f"{context.core_package_name}.streaming_helpers", "iter_sse_events_text")
478
- writer.write_line(f"async for chunk in iter_sse_events_text(response):")
479
- writer.indent()
480
- writer.write_line("yield json.loads(chunk)")
481
- writer.dedent()
422
+ writer.write_line("return # Explicit return for async generator")
423
+ return
424
+
425
+ # Handle responses using the schema as-is from the OpenAPI spec (no unwrapping)
426
+ if strategy.return_type.startswith("Union["):
427
+ # Special handling for Union types with try/except fallback
428
+ self._write_union_response_handling(writer, context, strategy.return_type, "response.json()")
429
+ elif self._should_use_base_schema(strategy.return_type):
430
+ deserialization_code = self._get_base_schema_deserialization_code(strategy.return_type, "response.json()")
431
+ writer.write_line(f"return {deserialization_code}")
432
+ else:
433
+ context.add_import("typing", "cast")
434
+ writer.write_line(f"return cast({strategy.return_type}, response.json())")
435
+
436
def _get_response_schema(self, response_ir: IRResponse) -> Optional[IRSchema]:
    """Return the schema describing a response body, if the response declares one.

    Selection order:
      1. the ``"application/json"`` entry of ``response_ir.content``, when that
         exact key is present;
      2. otherwise, the first entry of ``response_ir.content`` in iteration order.

    Args:
        response_ir: The response whose ``content`` mapping (content type ->
            schema) is inspected.

    Returns:
        The selected schema, or ``None`` when ``response_ir.content`` is empty
        or ``None``.
    """
    content = response_ir.content
    if not content:
        return None

    # Exact key match only: variants such as "application/json; charset=utf-8"
    # are not treated as JSON here and fall through to the first-entry rule.
    # `content` is non-empty at this point, so `next(iter(content))` cannot raise.
    media_type = "application/json" if "application/json" in content else next(iter(content))
    return content.get(media_type)
451
+
452
+ def _write_union_response_handling(
453
+ self, writer: CodeWriter, context: RenderContext, return_type: str, data_expr: str
454
+ ) -> None:
455
+ """Write try/except logic for Union types."""
456
+ # Parse Union[TypeA, TypeB] to extract the types
457
+ if not return_type.startswith("Union[") or not return_type.endswith("]"):
458
+ raise ValueError(f"Invalid Union type format: {return_type}")
459
+
460
+ union_content = return_type[6:-1] # Remove 'Union[' and ']'
461
+ types = [t.strip() for t in union_content.split(",")]
462
+
463
+ if len(types) < 2:
464
+ raise ValueError(f"Union type must have at least 2 types: {return_type}")
465
+
466
+ # Add Union import
467
+ context.add_import("typing", "Union")
468
+
469
+ # Generate try/except blocks for each type
470
+ first_type = types[0]
471
+ remaining_types = types[1:]
472
+
473
+ # Try the first type
474
+ writer.write_line("try:")
475
+ writer.indent()
476
+ if self._should_use_base_schema(first_type):
477
+ context.add_typing_imports_for_type(first_type)
478
+ deserialization_code = self._get_base_schema_deserialization_code(first_type, data_expr)
479
+ writer.write_line(f"return {deserialization_code}")
480
+ else:
481
+ context.add_import("typing", "cast")
482
+ writer.write_line(f"return cast({first_type}, {data_expr})")
483
+ writer.dedent()
484
+
485
+ # Add except blocks for remaining types
486
+ for i, type_name in enumerate(remaining_types):
487
+ is_last = i == len(remaining_types) - 1
488
+ if is_last:
489
+ writer.write_line("except Exception: # Attempt to parse as the final type")
490
+ else:
491
+ writer.write_line("except Exception: # Attempt to parse as the next type")
492
+ writer.indent()
493
+ if self._should_use_base_schema(type_name):
494
+ context.add_typing_imports_for_type(type_name)
495
+ deserialization_code = self._get_base_schema_deserialization_code(type_name, data_expr)
496
+ if is_last:
497
+ writer.write_line(f"return {deserialization_code}")
482
498
  else:
483
- # Other streaming type
484
- writer.write_line(f"async for chunk in iter_bytes(response):")
499
+ writer.write_line("try:")
485
500
  writer.indent()
486
- writer.write_line("yield chunk")
501
+ writer.write_line(f"return {deserialization_code}")
487
502
  writer.dedent()
488
- writer.write_line("return # Explicit return for async generator")
489
-
490
- elif "\n" in extraction_code_str: # Multi-line extraction code (e.g. data unwrap)
491
- # The _get_extraction_code for unwrap already includes "return cast(...)"
492
- for line in extraction_code_str.split("\n"):
493
- writer.write_line(line)
494
- else: # Single line extraction code
495
- if return_type != "None": # Should already be handled, but as safety
496
- writer.write_line(f"return {extraction_code_str}")
497
- # writer.write_line("") # Blank line might be added by the caller of this helper
503
+ else:
504
+ context.add_import("typing", "cast")
505
+ writer.write_line(f"return cast({type_name}, {data_expr})")
506
+ writer.dedent()