pyopenapi-gen 0.21.0__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyopenapi-gen might be problematic; see the package registry's advisory page for more details.

@@ -42,6 +42,7 @@ class EndpointVisitor(Visitor[IROperation, str]):
42
42
  tag: str,
43
43
  method_codes: list[str],
44
44
  context: RenderContext,
45
+ operations: list[IROperation] | None = None,
45
46
  ) -> str:
46
47
  """
47
48
  Emit the endpoint client class for a tag, aggregating all endpoint methods.
@@ -50,6 +51,152 @@ class EndpointVisitor(Visitor[IROperation, str]):
50
51
  tag: The tag name for the endpoint group.
51
52
  method_codes: List of method code blocks as strings.
52
53
  context: The RenderContext for import tracking.
54
+ operations: List of operations for Protocol generation (optional for backwards compatibility).
55
+ """
56
+ # Generate Protocol if operations provided
57
+ protocol_code = ""
58
+ if operations:
59
+ protocol_code = self.generate_endpoint_protocol(tag, operations, context)
60
+
61
+ # Generate implementation
62
+ impl_code = self._generate_endpoint_implementation(tag, method_codes, context)
63
+
64
+ # Combine Protocol and implementation
65
+ if protocol_code:
66
+ return f"{protocol_code}\n\n\n{impl_code}"
67
+ else:
68
+ return impl_code
69
+
70
def generate_endpoint_protocol(self, tag: str, operations: list[IROperation], context: RenderContext) -> str:
    """
    Generate the ``typing.Protocol`` describing a tag's endpoint client.

    Each operation is rendered with ``EndpointMethodGenerator`` and reduced to signature stubs:

    * ``@overload`` signatures are copied line by line until the line ending in ``: ...``.
    * The first ``async def``/``def`` signature that follows is copied up to the line ending in
      ``:``; that trailing ``:`` becomes ``: ...`` and the method body is dropped.

    When that final signature is ``async def`` and its return annotation (the text after the last
    ``->``) mentions ``AsyncIterator``, the method is treated as an async generator and emitted as
    plain ``def``. Calling an async generator function returns the iterator directly, so
    ``def f(...) -> AsyncIterator[T]`` is the matching Protocol member, whereas
    ``async def f(...) -> AsyncIterator[T]`` would describe a coroutine that returns an iterator.

    Signature lines are written stripped: line breaks are kept, continuation indentation is not.

    Args:
        tag: The tag name for the endpoint group.
        operations: Operations belonging to this tag, in emission order.
        context: Render context; ``Protocol`` and ``runtime_checkable`` imports are registered on it
            (the method generator is also handed the context and may register more).

    Returns:
        Source code of the ``@runtime_checkable`` Protocol class.
    """
    # Register Protocol imports
    context.add_import("typing", "Protocol")
    context.add_import("typing", "runtime_checkable")

    writer = CodeWriter()
    class_name = NameSanitizer.sanitize_class_name(tag) + "Client"
    protocol_name = f"{class_name}Protocol"

    # Protocol class header and docstring
    writer.write_line("@runtime_checkable")
    writer.write_line(f"class {protocol_name}(Protocol):")
    writer.indent()
    writer.write_line(f'"""Protocol defining the interface of {class_name} for dependency injection."""')
    writer.write_line("")

    # One generator serves every operation; it does not need rebuilding per operation.
    method_generator = EndpointMethodGenerator(schemas=self.schemas)
    for op in operations:
        lines = method_generator.generate(op, context).split("\n")
        i = 0
        while i < len(lines):
            stripped = lines[i].strip()

            if stripped.startswith("@overload"):
                # Overload stubs already end in ": ...", so copy them as-is (stripped).
                writer.write_line(stripped)
                i += 1
                while i < len(lines):
                    sig_stripped = lines[i].strip()
                    writer.write_line(sig_stripped)
                    i += 1
                    if sig_stripped.endswith(": ..."):
                        writer.write_line("")  # Blank line after overload
                        break
                continue

            if stripped.startswith(("async def ", "def ")) and "(" in stripped:
                # Collect the (possibly multi-line) final signature up to its closing ":".
                signature_lines: list[str] = []
                while i < len(lines):
                    sig_stripped = lines[i].strip()
                    signature_lines.append(sig_stripped)
                    i += 1
                    if sig_stripped.endswith(":"):
                        break

                # An unterminated signature is malformed; emit nothing for it.
                if signature_lines[-1].endswith(":"):
                    _, arrow, return_annotation = " ".join(signature_lines).rpartition("->")
                    # Only the return annotation decides; parameters may legitimately be AsyncIterator.
                    # The first line is rewritten even when it is also the last (one-line signature).
                    if signature_lines[0].startswith("async def ") and arrow and "AsyncIterator" in return_annotation:
                        signature_lines[0] = "def " + signature_lines[0][len("async def ") :]
                    signature_lines[-1] = signature_lines[-1][:-1] + ": ..."
                    for sig in signature_lines:
                        writer.write_line(sig)
                    writer.write_line("")  # Blank line after method

                # Only the signature belongs in the Protocol; skip the implementation body.
                break

            i += 1

    writer.dedent()  # Close class
    return writer.get_code()
188
+
189
+ def _generate_endpoint_implementation(self, tag: str, method_codes: list[str], context: RenderContext) -> str:
190
+ """
191
+ Generate the endpoint client implementation class.
192
+
193
+ Args:
194
+ tag: The tag name for the endpoint group
195
+ method_codes: List of method code blocks as strings
196
+ context: Render context for import management
197
+
198
+ Returns:
199
+ Implementation class code as string
53
200
  """
54
201
  context.add_import("typing", "cast")
55
202
  # Import core transport and streaming helpers
@@ -59,7 +206,10 @@ class EndpointVisitor(Visitor[IROperation, str]):
59
206
  context.add_import("typing", "Optional")
60
207
  writer = CodeWriter()
61
208
  class_name = NameSanitizer.sanitize_class_name(tag) + "Client"
62
- writer.write_line(f"class {class_name}:")
209
+ protocol_name = f"{class_name}Protocol"
210
+
211
+ # Class definition - implements Protocol
212
+ writer.write_line(f"class {class_name}({protocol_name}):")
63
213
  writer.indent()
64
214
  writer.write_line(f'"""Client for {tag} endpoints. Uses HttpTransport for all HTTP and header management."""')
65
215
  writer.write_line("")
@@ -82,3 +232,61 @@ class EndpointVisitor(Visitor[IROperation, str]):
82
232
 
83
233
  writer.dedent() # Dedent to close the class block
84
234
  return writer.get_code()
235
+
236
def generate_endpoint_mock_class(self, tag: str, operations: list[IROperation], context: RenderContext) -> str:
    """
    Build the ``Mock<Tag>Client`` test-double class for one endpoint tag.

    The emitted source begins with an ``if TYPE_CHECKING:`` import of the tag's
    ``<Tag>ClientProtocol`` and is followed by the mock class. Each operation is
    rendered by ``MockGenerator`` into a method, and consecutive methods are separated
    by two blank lines.

    Args:
        tag: Tag name used to derive the client, Protocol, mock class and module names.
        operations: Operations to emit mock methods for, in order.
        context: Render context that collects imports needed by the emitted code.

    Returns:
        Source code for the TYPE_CHECKING import block followed by the mock class.
    """
    from .generators.mock_generator import MockGenerator

    context.add_import("typing", "TYPE_CHECKING")

    client_name = NameSanitizer.sanitize_class_name(tag) + "Client"
    protocol_name = f"{client_name}Protocol"
    mock_name = f"Mock{client_name}"
    module_name = NameSanitizer.sanitize_module_name(tag)

    out = CodeWriter()

    # Protocol import guarded by TYPE_CHECKING, i.e. only visible to static type checkers.
    out.write_line("if TYPE_CHECKING:")
    out.indent()
    out.write_line(f"from ...endpoints.{module_name} import {protocol_name}")
    out.dedent()
    out.write_line("")

    # Class header followed by its docstring (and one trailing blank line).
    out.write_line(f"class {mock_name}:")
    out.indent()
    header_lines = (
        '"""',
        f"Mock implementation of {client_name} for testing.",
        "",
        "Provides default implementations that raise NotImplementedError.",
        "Override methods as needed in your tests.",
        "",
        "Example:",
        f" class Test{client_name}({mock_name}):",
        " async def method_name(self, ...) -> ReturnType:",
        " return test_data",
        '"""',
        "",
    )
    for text in header_lines:
        out.write_line(text)

    # One mock method per operation; two blank lines precede every method but the first.
    mocks = MockGenerator(schemas=self.schemas)
    for index, op in enumerate(operations):
        if index:
            out.write_line("")
            out.write_line("")
        out.write_block(mocks.generate(op, context))

    out.dedent()
    return out.get_code()
@@ -0,0 +1,140 @@
1
+ """
2
+ Generator for creating mock method implementations.
3
+
4
+ This module generates mock methods that raise NotImplementedError,
5
+ allowing users to create test doubles by subclassing and overriding
6
+ only the methods they need.
7
+ """
8
+
9
+ from typing import Any
10
+
11
+ from ....context.render_context import RenderContext
12
+ from ....core.utils import NameSanitizer
13
+ from ....core.writers.code_writer import CodeWriter
14
+ from ....ir import IROperation
15
+ from .endpoint_method_generator import EndpointMethodGenerator
16
+
17
+
18
class MockGenerator:
    """
    Generates mock method implementations for testing.

    Mock methods keep the signature produced by ``EndpointMethodGenerator`` for the
    real implementation, but their body raises ``NotImplementedError`` with a
    helpful message instead of performing actual operations.
    """

    def __init__(self, schemas: dict[str, Any] | None = None) -> None:
        """
        Args:
            schemas: Schema mapping forwarded to ``EndpointMethodGenerator``; ``None``
                (or any falsy value) is replaced by a fresh empty dict.
        """
        self.schemas = schemas or {}
        self.method_generator = EndpointMethodGenerator(self.schemas)

    def generate(self, op: IROperation, context: RenderContext) -> str:
        """
        Generate a mock method that raises NotImplementedError.

        Args:
            op: The operation to generate a mock for
            context: Render context for import tracking (passed to the real method generator)

        Returns:
            Complete mock method code as string
        """
        # Generate the full method, then keep its signature(s) and replace the body.
        full_method = self.method_generator.generate(op, context)
        return self._transform_to_mock(full_method, op)

    def _transform_to_mock(self, full_method_code: str, op: IROperation) -> str:
        """
        Transform a full method implementation into a mock that raises NotImplementedError.

        ``@overload`` stubs are copied unchanged (stripped). The first ``async def``/``def``
        signature is copied and given a mock body; everything after it is discarded.

        Args:
            full_method_code: Complete method code from EndpointMethodGenerator
            op: The operation (for generating error messages)

        Returns:
            Mock method code with NotImplementedError body
        """
        lines = full_method_code.split("\n")
        writer = CodeWriter()

        i = 0
        while i < len(lines):
            stripped = lines[i].strip()

            # Keep @overload stubs; copy their signature until the line ending in ": ...".
            if stripped.startswith("@overload"):
                writer.write_line(stripped)
                i += 1
                while i < len(lines):
                    sig_stripped = lines[i].strip()
                    writer.write_line(sig_stripped)
                    i += 1
                    if sig_stripped.endswith(": ..."):
                        writer.write_line("")  # Blank line after overload
                        break
                continue

            # The implementation signature: copy it and replace the body with the mock body.
            if stripped.startswith(("async def ", "def ")) and "(" in stripped:
                signature_lines, _ = self._collect_signature(lines, i)
                terminated = signature_lines[-1].endswith(":")
                is_async_generator = terminated and self._is_async_generator(signature_lines)

                for sig in signature_lines:
                    writer.write_line(sig)
                self._write_mock_body(writer, op, is_async_generator)

                # The rest of the original implementation is not needed.
                break

            i += 1

        return writer.get_code()

    @staticmethod
    def _collect_signature(lines: list[str], start: int) -> tuple[list[str], int]:
        """
        Collect stripped signature lines starting at ``lines[start]``.

        Collection stops after the first line ending in ``:``. If no such line exists,
        every remaining line is returned.

        Returns:
            ``(signature_lines, next_index)``, where ``next_index`` is the index just past
            the last collected line.
        """
        signature_lines: list[str] = []
        i = start
        while i < len(lines):
            stripped = lines[i].strip()
            signature_lines.append(stripped)
            i += 1
            if stripped.endswith(":"):
                break
        return signature_lines, i

    @staticmethod
    def _is_async_generator(signature_lines: list[str]) -> bool:
        """
        Return True when the signature is ``async def`` with ``AsyncIterator`` in its return annotation.

        Only the text after the last ``->`` is inspected. A parameter typed ``AsyncIterator``
        (e.g. a streamed request body) therefore does not turn a coroutine into a generator.
        A sync ``def`` never qualifies, because a ``yield`` there would make a sync generator.
        """
        if not signature_lines or not signature_lines[0].startswith("async def "):
            return False
        _, arrow, return_annotation = " ".join(signature_lines).rpartition("->")
        return bool(arrow) and "AsyncIterator" in return_annotation

    @staticmethod
    def _write_mock_body(writer: CodeWriter, op: IROperation, is_async_generator: bool) -> None:
        """Write the indented mock body: a docstring, a NotImplementedError raise and, for async generators, a yield."""
        writer.indent()

        # Docstring
        writer.write_line('"""')
        writer.write_line("Mock implementation that raises NotImplementedError.")
        writer.write_line("")
        writer.write_line("Override this method in your test subclass to provide")
        writer.write_line("the behavior needed for your test scenario.")
        writer.write_line('"""')

        # Error message naming the mock class and method.
        # NOTE(review): the class name is derived from op.tags[0] ("Client" when untagged). It may
        # differ from the mock class actually emitted if the operation is grouped under another tag.
        method_name = NameSanitizer.sanitize_method_name(op.operation_id)
        tag = op.tags[0] if op.tags else "Client"
        class_name = f"Mock{NameSanitizer.sanitize_class_name(tag)}Client"
        error_msg = f'"{class_name}.{method_name}() not implemented. Override this method in your test subclass."'
        writer.write_line(f"raise NotImplementedError({error_msg})")

        # An unreachable yield keeps an `async def ... -> AsyncIterator[...]` an async generator.
        if is_async_generator:
            writer.write_line("yield # pragma: no cover")

        writer.dedent()
@@ -43,17 +43,22 @@ class EndpointUrlArgsGenerator:
43
43
  # writer.write_line("# No query parameters to write") # Optional: for clarity during debugging
44
44
  return
45
45
 
46
+ # Import DataclassSerializer since we use it for parameter serialization
47
+ context.add_import(f"{context.core_package_name}.utils", "DataclassSerializer")
48
+
46
49
  for i, p in enumerate(query_params_to_write):
47
50
  param_var_name = NameSanitizer.sanitize_method_name(p["name"]) # Ensure name is sanitized
48
51
  original_param_name = p["original_name"]
49
52
  line_end = "," # Always add comma, let formatter handle final one if needed
50
53
 
51
54
  if p.get("required", False):
52
- writer.write_line(f' "{original_param_name}": {param_var_name}{line_end}')
55
+ writer.write_line(
56
+ f' "{original_param_name}": DataclassSerializer.serialize({param_var_name}){line_end}'
57
+ )
53
58
  else:
54
59
  # Using dict unpacking for conditional parameters
55
60
  writer.write_line(
56
- f' **({{"{original_param_name}": {param_var_name}}} '
61
+ f' **({{"{original_param_name}": DataclassSerializer.serialize({param_var_name})}} '
57
62
  f"if {param_var_name} is not None else {{}}){line_end}"
58
63
  )
59
64
 
@@ -66,6 +71,10 @@ class EndpointUrlArgsGenerator:
66
71
  # if ordered_params is the sole source of truth for method params.
67
72
  header_params_to_write = [p for p in ordered_params if p.get("param_in") == "header"]
68
73
 
74
+ # Import DataclassSerializer since we use it for parameter serialization
75
+ if header_params_to_write:
76
+ context.add_import(f"{context.core_package_name}.utils", "DataclassSerializer")
77
+
69
78
  for p_info in header_params_to_write:
70
79
  param_var_name = NameSanitizer.sanitize_method_name(
71
80
  p_info["name"]
@@ -74,13 +83,15 @@ class EndpointUrlArgsGenerator:
74
83
  line_end = ","
75
84
 
76
85
  if p_info.get("required", False):
77
- writer.write_line(f' "{original_header_name}": {param_var_name}{line_end}')
86
+ writer.write_line(
87
+ f' "{original_header_name}": DataclassSerializer.serialize({param_var_name}){line_end}'
88
+ )
78
89
  else:
79
90
  # Conditional inclusion for optional headers
80
91
  # This assumes that if an optional header parameter is None, it should not be sent.
81
92
  # If specific behavior (e.g. empty string) is needed for None, logic would adjust.
82
93
  writer.write_line(
83
- f' **({{"{original_header_name}": {param_var_name}}} '
94
+ f' **({{"{original_header_name}": DataclassSerializer.serialize({param_var_name})}} '
84
95
  f"if {param_var_name} is not None else {{}}){line_end}"
85
96
  )
86
97
 
@@ -95,6 +106,19 @@ class EndpointUrlArgsGenerator:
95
106
  ) -> bool:
96
107
  """Writes URL, query, and header parameters. Returns True if header params were written."""
97
108
  # Main logic from EndpointMethodGenerator._write_url_and_args
109
+
110
+ # Serialize path parameters before URL construction
111
+ # This ensures enums, dates, and other complex types are converted to strings
112
+ # before f-string interpolation in the URL
113
+ path_params = [p for p in ordered_params if p.get("param_in") == "path"]
114
+ if path_params:
115
+ # Import DataclassSerializer since we use it for parameter serialization
116
+ context.add_import(f"{context.core_package_name}.utils", "DataclassSerializer")
117
+ for p in path_params:
118
+ param_var_name = NameSanitizer.sanitize_method_name(p["name"])
119
+ writer.write_line(f"{param_var_name} = DataclassSerializer.serialize({param_var_name})")
120
+ writer.write_line("") # Blank line after path param serialization
121
+
98
122
  url_expr = self._build_url_with_path_vars(op.path)
99
123
  writer.write_line(f"url = {url_expr}")
100
124
  writer.write_line("") # Add a blank line for readability