otterapi 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. README.md +581 -8
  2. otterapi/__init__.py +73 -0
  3. otterapi/cli.py +327 -29
  4. otterapi/codegen/__init__.py +115 -0
  5. otterapi/codegen/ast_utils.py +134 -5
  6. otterapi/codegen/client.py +1271 -0
  7. otterapi/codegen/codegen.py +1736 -0
  8. otterapi/codegen/dataframes.py +392 -0
  9. otterapi/codegen/emitter.py +473 -0
  10. otterapi/codegen/endpoints.py +2597 -343
  11. otterapi/codegen/pagination.py +1026 -0
  12. otterapi/codegen/schema.py +593 -0
  13. otterapi/codegen/splitting.py +1397 -0
  14. otterapi/codegen/types.py +1345 -0
  15. otterapi/codegen/utils.py +180 -1
  16. otterapi/config.py +1017 -24
  17. otterapi/exceptions.py +231 -0
  18. otterapi/openapi/__init__.py +46 -0
  19. otterapi/openapi/v2/__init__.py +86 -0
  20. otterapi/openapi/v2/spec.json +1607 -0
  21. otterapi/openapi/v2/v2.py +1776 -0
  22. otterapi/openapi/v3/__init__.py +131 -0
  23. otterapi/openapi/v3/spec.json +1651 -0
  24. otterapi/openapi/v3/v3.py +1557 -0
  25. otterapi/openapi/v3_1/__init__.py +133 -0
  26. otterapi/openapi/v3_1/spec.json +1411 -0
  27. otterapi/openapi/v3_1/v3_1.py +798 -0
  28. otterapi/openapi/v3_2/__init__.py +133 -0
  29. otterapi/openapi/v3_2/spec.json +1666 -0
  30. otterapi/openapi/v3_2/v3_2.py +777 -0
  31. otterapi/tests/__init__.py +3 -0
  32. otterapi/tests/fixtures/__init__.py +455 -0
  33. otterapi/tests/test_ast_utils.py +680 -0
  34. otterapi/tests/test_codegen.py +610 -0
  35. otterapi/tests/test_dataframe.py +1038 -0
  36. otterapi/tests/test_exceptions.py +493 -0
  37. otterapi/tests/test_openapi_support.py +616 -0
  38. otterapi/tests/test_openapi_upgrade.py +215 -0
  39. otterapi/tests/test_pagination.py +1101 -0
  40. otterapi/tests/test_splitting_config.py +319 -0
  41. otterapi/tests/test_splitting_integration.py +427 -0
  42. otterapi/tests/test_splitting_resolver.py +512 -0
  43. otterapi/tests/test_splitting_tree.py +525 -0
  44. otterapi-0.0.6.dist-info/METADATA +627 -0
  45. otterapi-0.0.6.dist-info/RECORD +48 -0
  46. {otterapi-0.0.5.dist-info → otterapi-0.0.6.dist-info}/WHEEL +1 -1
  47. otterapi/codegen/generator.py +0 -358
  48. otterapi/codegen/openapi_processor.py +0 -27
  49. otterapi/codegen/type_generator.py +0 -559
  50. otterapi-0.0.5.dist-info/METADATA +0 -54
  51. otterapi-0.0.5.dist-info/RECORD +0 -16
  52. {otterapi-0.0.5.dist-info → otterapi-0.0.6.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,1345 @@
1
+ """Type definitions and generation for OtterAPI code generation.
2
+
3
+ This module provides:
4
+ - Type dataclasses for representing generated types, parameters, responses, and endpoints
5
+ - TypeGenerator for creating Pydantic models from OpenAPI schemas
6
+ - TypeRegistry for managing generated types and their dependencies
7
+ - ModelNameCollector for tracking model usage in generated code
8
+ """
9
+
10
+ import ast
11
+ import dataclasses
12
+ from collections.abc import Iterable, Iterator
13
+ from datetime import datetime
14
+ from typing import Any, Literal
15
+ from uuid import UUID
16
+
17
+ from pydantic import BaseModel, Field, RootModel
18
+
19
+ from otterapi.codegen.ast_utils import _call, _name, _subscript, _union_expr
20
+ from otterapi.codegen.utils import (
21
+ OpenAPIProcessor,
22
+ sanitize_identifier,
23
+ sanitize_parameter_field_name,
24
+ )
25
+ from otterapi.openapi.v3_2 import Reference, Schema, Type as DataType
26
+
27
+ __all__ = [
28
+ 'Type',
29
+ 'Parameter',
30
+ 'ResponseInfo',
31
+ 'RequestBodyInfo',
32
+ 'Endpoint',
33
+ 'TypeGenerator',
34
+ 'TypeInfo',
35
+ 'TypeRegistry',
36
+ 'ModelNameCollector',
37
+ 'collect_used_model_names',
38
+ ]
39
+
40
# Maps (OpenAPI type value, format) pairs to the Python runtime type used in
# generated annotations. A value of None renders as the `None` annotation;
# unknown (type, format) keys fall back to `Any` at the lookup site
# (see _get_primitive_type_ast).
_PRIMITIVE_TYPE_MAP = {
    ('string', None): str,
    ('string', 'date-time'): datetime,
    # NOTE(review): 'date' maps to datetime, not datetime.date — confirm intended.
    ('string', 'date'): datetime,
    ('string', 'uuid'): UUID,
    ('integer', None): int,
    ('integer', 'int32'): int,
    ('integer', 'int64'): int,
    ('number', None): float,
    ('number', 'float'): float,
    ('number', 'double'): float,
    ('boolean', None): bool,
    ('null', None): None,
    (None, None): None,
}
55
+
56
+
57
+ @dataclasses.dataclass
58
+ class Type:
59
+ reference: str | None # reference is None if type is 'primitive'
60
+ name: str | None
61
+ type: Literal['primitive', 'root', 'model']
62
+ annotation_ast: ast.expr | ast.stmt | None = dataclasses.field(default=None)
63
+ implementation_ast: ast.expr | ast.stmt | None = dataclasses.field(default=None)
64
+ dependencies: set[str] = dataclasses.field(default_factory=set)
65
+ implementation_imports: dict[str, set[str]] = dataclasses.field(
66
+ default_factory=dict
67
+ )
68
+ annotation_imports: dict[str, set[str]] = dataclasses.field(default_factory=dict)
69
+
70
+ def __hash__(self):
71
+ """Make Type hashable based on its name (for use in sets/dicts)."""
72
+ # We only hash based on name since we use name as the key in the types dict
73
+ return hash(self.name)
74
+
75
+ def add_dependency(self, type_: 'Type') -> None:
76
+ self.dependencies.add(type_.name)
77
+ for dep in type_.dependencies:
78
+ self.dependencies.add(dep)
79
+
80
+ def add_implementation_import(self, module: str, name: str | Iterable[str]) -> None:
81
+ # Skip builtins - they don't need to be imported
82
+ if module == 'builtins':
83
+ return
84
+
85
+ if isinstance(name, str):
86
+ name = [name]
87
+
88
+ if module not in self.implementation_imports:
89
+ self.implementation_imports[module] = set()
90
+
91
+ for n in name:
92
+ self.implementation_imports[module].add(n)
93
+
94
+ def add_annotation_import(self, module: str, name: str | Iterable[str]) -> None:
95
+ # Skip builtins - they don't need to be imported
96
+ if module == 'builtins':
97
+ return
98
+
99
+ if isinstance(name, str):
100
+ name = [name]
101
+
102
+ if module not in self.annotation_imports:
103
+ self.annotation_imports[module] = set()
104
+
105
+ for n in name:
106
+ self.annotation_imports[module].add(n)
107
+
108
+ def copy_imports_from_sub_types(self, types: Iterable['Type']):
109
+ for t in types:
110
+ for module, names in t.annotation_imports.items():
111
+ self.add_annotation_import(module, names)
112
+
113
+ for module, names in t.implementation_imports.items():
114
+ self.add_implementation_import(module, names)
115
+
116
+ def __eq__(self, other):
117
+ """Deep comparison of Type objects, including AST nodes."""
118
+ if not isinstance(other, Type):
119
+ return False
120
+
121
+ # Compare simple fields
122
+ if (
123
+ self.reference != other.reference
124
+ or self.name != other.name
125
+ or self.type != other.type
126
+ ):
127
+ return False
128
+
129
+ # Compare AST nodes by dumping them to strings
130
+ # Compare annotation AST (can be None)
131
+ if self.annotation_ast is None and other.annotation_ast is None:
132
+ pass # Both None, equal
133
+ elif self.annotation_ast is None or other.annotation_ast is None:
134
+ return False # One is None, other isn't
135
+ else:
136
+ if ast.dump(self.annotation_ast) != ast.dump(other.annotation_ast):
137
+ return False
138
+
139
+ # Compare implementation AST (can be None)
140
+ if self.implementation_ast is None and other.implementation_ast is None:
141
+ pass # Both None, equal
142
+ elif self.implementation_ast is None or other.implementation_ast is None:
143
+ return False # One is None, other isn't
144
+ else:
145
+ if ast.dump(self.implementation_ast) != ast.dump(other.implementation_ast):
146
+ return False
147
+
148
+ # Compare imports and dependencies
149
+ if (
150
+ self.dependencies != other.dependencies
151
+ or self.implementation_imports != other.implementation_imports
152
+ or self.annotation_imports != other.annotation_imports
153
+ ):
154
+ return False
155
+
156
+ return True
157
+
158
+
159
@dataclasses.dataclass
class Parameter:
    """A single endpoint parameter and where it is transmitted.

    Attributes:
        name: The parameter name exactly as it appears in the OpenAPI spec.
        name_sanitized: A Python-safe identifier derived from ``name`` for
            use in generated function signatures.
        location: Where the parameter is sent ('query', 'path', 'header',
            'cookie', or 'body').
        required: Whether the caller must supply the parameter.
        type: The generated Type for the parameter value, if resolved.
        description: Optional human-readable description from the spec.
    """

    name: str
    name_sanitized: str
    location: Literal['query', 'path', 'header', 'cookie', 'body']
    required: bool
    type: Type | None = None
    description: str | None = None
167
+
168
+
169
@dataclasses.dataclass
class ResponseInfo:
    """Information about a response for a specific status code.

    Attributes:
        status_code: The HTTP status code for this response.
        content_type: The content type (e.g., 'application/json', 'application/octet-stream').
        type: The Type object for JSON responses, or None for raw responses.
    """

    status_code: int
    content_type: str
    type: Type | None = None

    @property
    def is_json(self) -> bool:
        """Check if this is a JSON response."""
        content_type = self.content_type
        return content_type.endswith('+json') or content_type in (
            'application/json',
            'text/json',
        )

    @property
    def is_binary(self) -> bool:
        """Check if this is a binary response (file download)."""
        # Media-type families that are always treated as binary.
        if self.content_type.startswith(
            ('image/', 'audio/', 'video/', 'application/vnd.')
        ):
            return True
        # Specific application/* types downloaded as bytes.
        return self.content_type in {
            'application/octet-stream',
            'application/pdf',
            'application/zip',
            'application/gzip',
            'application/x-tar',
            'application/x-rar-compressed',
        }

    @property
    def is_text(self) -> bool:
        """Check if this is a plain text response."""
        if not self.content_type.startswith('text/'):
            return False
        return not self.is_json

    @property
    def is_raw(self) -> bool:
        """Check if this is an unknown content type that should return the raw httpx.Response."""
        return not any((self.is_json, self.is_binary, self.is_text))
216
+
217
+
218
@dataclasses.dataclass
class RequestBodyInfo:
    """Information about a request body including its content type.

    Attributes:
        content_type: The content type (e.g., 'application/json', 'multipart/form-data').
        type: The Type object for the body schema, or None if no schema.
        required: Whether the request body is required.
        description: Optional description of the request body.
    """

    content_type: str
    type: Type | None = None
    required: bool = False
    description: str | None = None

    @property
    def is_json(self) -> bool:
        """Check if this is a JSON request body."""
        content_type = self.content_type
        return content_type.endswith('+json') or content_type in (
            'application/json',
            'text/json',
        )

    @property
    def is_form(self) -> bool:
        """Check if this is a form-encoded request body."""
        return self.content_type == 'application/x-www-form-urlencoded'

    @property
    def is_multipart(self) -> bool:
        """Check if this is a multipart form data request body."""
        return self.content_type == 'multipart/form-data'

    @property
    def is_binary(self) -> bool:
        """Check if this is a binary request body."""
        return self.content_type in ('application/octet-stream',)

    @property
    def httpx_param_name(self) -> str:
        """Get the httpx parameter name for this content type.

        Returns:
            The appropriate httpx parameter name: 'json', 'data', 'files', or 'content'.
        """
        if self.is_json:
            return 'json'
        if self.is_form:
            return 'data'
        if self.is_multipart:
            return 'files'
        # Binary bodies and unknown content types are both sent as raw content.
        return 'content'
274
+
275
+
276
+ @dataclasses.dataclass
277
+ class Endpoint:
278
+ """Represents a generated API endpoint with sync and async functions."""
279
+
280
+ # AST nodes
281
+ sync_ast: ast.FunctionDef
282
+ async_ast: ast.AsyncFunctionDef
283
+
284
+ # Function names
285
+ sync_fn_name: str
286
+ async_fn_name: str
287
+
288
+ # Endpoint metadata
289
+ name: str
290
+ method: str = ''
291
+ path: str = ''
292
+ description: str | None = None
293
+ tags: list[str] | None = None # OpenAPI tags for module splitting
294
+
295
+ # Parameters and body
296
+ parameters: list['Parameter'] | None = None
297
+ request_body: 'RequestBodyInfo | None' = None
298
+
299
+ # Response info
300
+ response_type: 'Type | None' = None
301
+ response_infos: list['ResponseInfo'] | None = None
302
+
303
+ # Imports needed
304
+ imports: dict[str, set[str]] = dataclasses.field(default_factory=dict)
305
+
306
+ @property
307
+ def fn(self) -> ast.FunctionDef:
308
+ """Alias for sync_ast."""
309
+ return self.sync_ast
310
+
311
+ @property
312
+ def async_fn(self) -> ast.AsyncFunctionDef:
313
+ """Alias for async_ast."""
314
+ return self.async_ast
315
+
316
+ def add_imports(self, imports: list[dict[str, set[str]]]):
317
+ for imports_ in imports:
318
+ for module, names in imports_.items():
319
+ if module not in self.imports:
320
+ self.imports[module] = set()
321
+ self.imports[module].update(names)
322
+
323
+
324
+ @dataclasses.dataclass
325
+ class TypeGenerator(OpenAPIProcessor):
326
+ types: dict[str, Type] = dataclasses.field(default_factory=dict)
327
+
328
    def add_type(self, type_: Type, base_name: str | None = None) -> Type:
        """Add a type to the registry. If a type with the same name but different definition
        already exists, generate a unique name using the base_name prefix.
        Returns the type (potentially with a modified name).

        Args:
            type_: The type to register; its name and AST nodes may be
                rewritten in place when a rename is needed.
            base_name: Optional prefix (e.g. an endpoint name) tried first
                when disambiguating a name collision.

        Returns:
            The registered Type — either the (possibly renamed) input, or an
            already-registered Type whose definition is identical.
        """
        # Skip types without names (primitive types, inline types, etc.)
        if not type_.name:
            return type_

        # If type with same name and same definition exists, just return the existing one
        if type_.name in self.types:
            existing = self.types[type_.name]
            if existing == type_:
                # Same type already registered, return the existing one
                # This avoids creating Detail20, Detail21 when they're identical
                return existing
            else:
                # Different definition with same name - generate a unique name
                if base_name:
                    # Use base_name as prefix for endpoint-specific types
                    unique_name = f'{base_name}{type_.name}'
                    if unique_name not in self.types:
                        # Rename: keep annotation and ClassDef in sync with the new name
                        type_.name = unique_name
                        type_.annotation_ast = _name(unique_name)
                        # Update the implementation_ast name if it's a ClassDef
                        if isinstance(type_.implementation_ast, ast.ClassDef):
                            type_.implementation_ast.name = unique_name
                    else:
                        # Check if even the base_name version is the same
                        if (
                            unique_name in self.types
                            and self.types[unique_name] == type_
                        ):
                            return self.types[unique_name]
                        # If even that exists with different def, add a counter
                        counter = 1
                        while f'{unique_name}{counter}' in self.types:
                            candidate = f'{unique_name}{counter}'
                            if self.types[candidate] == type_:
                                # A numbered duplicate with the same definition: reuse it
                                return self.types[candidate]
                            counter += 1
                        unique_name = f'{unique_name}{counter}'
                        type_.name = unique_name
                        type_.annotation_ast = _name(unique_name)
                        if isinstance(type_.implementation_ast, ast.ClassDef):
                            type_.implementation_ast.name = unique_name
                else:
                    # No base_name provided, just add a counter
                    counter = 1
                    original_name = type_.name
                    while f'{original_name}{counter}' in self.types:
                        candidate = f'{original_name}{counter}'
                        if self.types[candidate] == type_:
                            # Found identical type with numbered name
                            return self.types[candidate]
                        counter += 1
                    unique_name = f'{original_name}{counter}'
                    type_.name = unique_name
                    type_.annotation_ast = _name(unique_name)
                    if isinstance(type_.implementation_ast, ast.ClassDef):
                        type_.implementation_ast.name = unique_name

        # Register under the (possibly rewritten) final name.
        self.types[type_.name] = type_
        return type_
392
+
393
+ def _resolve_reference(self, reference: Reference | Schema) -> tuple[Schema, str]:
394
+ if hasattr(reference, 'ref'):
395
+ if not reference.ref.startswith('#/components/schemas/'):
396
+ raise ValueError(f'Unsupported reference format: {reference.ref}')
397
+
398
+ schema_name = reference.ref.split('/')[-1]
399
+ schemas = self.openapi.components.schemas
400
+
401
+ if schema_name not in schemas:
402
+ raise ValueError(
403
+ f"Referenced schema '{schema_name}' not found in components.schemas"
404
+ )
405
+
406
+ return schemas[schema_name], sanitize_identifier(schema_name)
407
+ return reference, sanitize_identifier(
408
+ reference.title
409
+ ) if reference.title else None
410
+
411
    def _create_enum_type(
        self,
        schema: Schema,
        name: str | None = None,
        base_name: str | None = None,
        field_name: str | None = None,
    ) -> Type:
        """Create an Enum class for schema with enum values.

        Identical enums (same stringified value set) already in the registry
        are reused instead of generating a new class.

        Args:
            schema: The schema containing enum values.
            name: Optional explicit name for the enum.
            base_name: Optional base name prefix (e.g., parent model name).
            field_name: Optional field name this enum is used for (e.g., 'status').

        Returns:
            A Type representing the generated Enum class.
        """
        # Determine enum name - prefer schema title, then derive from context
        enum_name = name or (
            sanitize_identifier(schema.title) if schema.title else None
        )
        if not enum_name:
            # Generate name from field_name with base_name context for uniqueness
            if field_name:
                # e.g., Pet + status -> 'PetStatus', Order + status -> 'OrderStatus'
                # Capitalize the field part to ensure proper PascalCase
                field_part = sanitize_identifier(field_name)
                # Ensure first letter is capitalized for PascalCase
                if field_part:
                    field_part = field_part[0].upper() + field_part[1:]
                if base_name:
                    base_part = sanitize_identifier(base_name)
                    enum_name = f'{base_part}{field_part}'
                else:
                    enum_name = field_part
            elif base_name:
                enum_name = f'{sanitize_identifier(base_name)}Enum'
            else:
                enum_name = 'AutoEnum'

        # Create a hashable key from enum values to detect duplicates
        # NOTE(review): values are compared via str(), so e.g. 1 and '1' collide.
        enum_values_key = tuple(sorted(str(v) for v in schema.enum if v is not None))

        # Check if an identical enum already exists
        for existing_name, existing_type in self.types.items():
            if existing_type.type == 'model' and isinstance(
                existing_type.implementation_ast, ast.ClassDef
            ):
                # Check if it's an Enum class with same values
                existing_class = existing_type.implementation_ast
                if any(
                    isinstance(base, ast.Name) and base.id == 'Enum'
                    for base in existing_class.bases
                ):
                    # Extract values from existing enum
                    existing_values = []
                    for node in existing_class.body:
                        if isinstance(node, ast.Assign) and node.value:
                            if isinstance(node.value, ast.Constant):
                                existing_values.append(str(node.value.value))
                    if tuple(sorted(existing_values)) == enum_values_key:
                        # Reuse existing enum
                        return existing_type

        # Ensure the name is unique
        if enum_name in self.types:
            counter = 1
            original_name = enum_name
            while f'{original_name}{counter}' in self.types:
                counter += 1
            enum_name = f'{original_name}{counter}'

        # Build enum members: NAME = 'value'
        # For string enums, use the value as the member name (sanitized)
        enum_body = []
        seen_member_names: dict[str, int] = {}  # Track seen names to handle duplicates
        for value in schema.enum:
            if value is None:
                continue  # Skip None values in enums
            # Create a valid Python identifier for the enum member
            if isinstance(value, str):
                member_name = sanitize_identifier(value).upper()
                # If the sanitized name starts with a digit, prefix with underscore
                if member_name and member_name[0].isdigit():
                    member_name = f'_{member_name}'
            else:
                # For numeric enums, create VALUE_X names
                member_name = f'VALUE_{value}'

            # Handle duplicate member names (e.g., 'mesoderm' and 'Mesoderm' both -> 'MESODERM')
            if member_name in seen_member_names:
                seen_member_names[member_name] += 1
                member_name = f'{member_name}_{seen_member_names[member_name]}'
            else:
                seen_member_names[member_name] = 0

            enum_body.append(
                ast.Assign(
                    targets=[ast.Name(id=member_name, ctx=ast.Store())],
                    value=ast.Constant(value=value),
                )
            )

        # If no valid members, fall back to Literal
        if not enum_body:
            return self._create_literal_type(schema)

        # Create the Enum class
        # class EnumName(str, Enum):  # str mixin for string enums
        #     MEMBER = 'value'
        bases = (
            [_name('str'), _name('Enum')]
            if schema.type and schema.type.value == 'string'
            else [_name('Enum')]
        )

        enum_class = ast.ClassDef(
            name=enum_name,
            bases=bases,
            keywords=[],
            body=enum_body,
            decorator_list=[],
            type_params=[],
        )

        type_ = Type(
            reference=None,
            name=enum_name,
            annotation_ast=_name(enum_name),
            implementation_ast=enum_class,
            type='model',  # Treat as model so it gets included in models.py
        )
        type_.add_implementation_import('enum', 'Enum')

        # Register the type
        self.types[enum_name] = type_

        return type_
550
+
551
+ def _create_literal_type(self, schema: Schema) -> Type:
552
+ """Create a Literal type for enum values (fallback)."""
553
+ literal_values = [ast.Constant(value=v) for v in schema.enum]
554
+ type_ = Type(
555
+ None,
556
+ sanitize_identifier(schema.title) if schema.title else None,
557
+ annotation_ast=_subscript(
558
+ 'Literal', ast.Tuple(elts=literal_values, ctx=ast.Load())
559
+ ),
560
+ implementation_ast=None,
561
+ type='primitive',
562
+ )
563
+ type_.add_annotation_import('typing', 'Literal')
564
+ return type_
565
+
566
+ def _is_nullable(self, schema: Schema) -> bool:
567
+ """Check if a schema represents a nullable type.
568
+
569
+ In OpenAPI 3.1+, nullable is expressed via type arrays like ["string", "null"].
570
+ """
571
+ if isinstance(schema.type, list):
572
+ return any(
573
+ t == DataType.null or (hasattr(t, 'value') and t.value == 'null')
574
+ for t in schema.type
575
+ )
576
+ return False
577
+
578
+ def _get_non_null_type(self, schema: Schema) -> DataType | None:
579
+ """Extract the non-null type from a potentially nullable schema."""
580
+ if isinstance(schema.type, list):
581
+ for t in schema.type:
582
+ if t != DataType.null and (
583
+ not hasattr(t, 'value') or t.value != 'null'
584
+ ):
585
+ return t
586
+ return None
587
+ return schema.type
588
+
589
+ def _make_nullable_type(self, base_type: Type) -> Type:
590
+ """Wrap a type annotation to make it nullable (T | None)."""
591
+ nullable_ast = _union_expr([base_type.annotation_ast, ast.Constant(value=None)])
592
+
593
+ type_ = Type(
594
+ reference=base_type.reference,
595
+ name=base_type.name,
596
+ annotation_ast=nullable_ast,
597
+ implementation_ast=base_type.implementation_ast,
598
+ type=base_type.type,
599
+ dependencies=base_type.dependencies.copy(),
600
+ implementation_imports=base_type.implementation_imports.copy(),
601
+ annotation_imports=base_type.annotation_imports.copy(),
602
+ )
603
+ return type_
604
+
605
    def _get_primitive_type_ast(
        self,
        schema: Schema,
        base_name: str | None = None,
        field_name: str | None = None,
    ) -> Type:
        """Create a Type for a primitive (non-object, non-array) schema.

        Enum schemas are delegated to _create_enum_type. The (type, format)
        pair is looked up in _PRIMITIVE_TYPE_MAP; unknown pairs fall back to
        ``Any``, and null-only schemas map to the ``None`` annotation.
        Nullable type arrays (["string", "null"]) yield ``T | None``.

        Args:
            schema: The primitive schema.
            base_name: Context prefix forwarded to enum naming.
            field_name: Owning field name forwarded to enum naming.

        Returns:
            An inline (unregistered) primitive Type.
        """
        # Handle enum types - generate Enum class
        if schema.enum:
            return self._create_enum_type(
                schema, base_name=base_name, field_name=field_name
            )

        # Check for nullable type (type array with null)
        is_nullable = self._is_nullable(schema)
        actual_type = self._get_non_null_type(schema)

        # Fix: schema.type is a Type enum, need to use .value for string lookup
        type_value = actual_type.value if actual_type else None
        key = (type_value, schema.format or None)
        mapped = _PRIMITIVE_TYPE_MAP.get(key, Any)

        type_ = Type(
            None,
            sanitize_identifier(schema.title) if schema.title else None,
            # A mapped value of None means the schema is null-typed: annotate as `None`.
            annotation_ast=_name(mapped.__name__ if mapped is not None else 'None'),
            implementation_ast=None,
            type='primitive',
        )

        # Non-builtin mapped types (datetime, UUID, Any) need an import.
        if mapped is not None and mapped.__module__ != 'builtins':
            type_.add_annotation_import(mapped.__module__, mapped.__name__)

        # Wrap in Union with None if nullable
        if is_nullable:
            type_ = self._make_nullable_type(type_)

        return type_
642
+
643
    def _create_pydantic_field(
        self,
        field_name: str,
        field_schema: Schema,
        field_type: Type,
        is_required: bool = False,
        is_nullable: bool = False,
    ) -> ast.AnnAssign:
        """Build the annotated-assignment AST for one Pydantic model field.

        Args:
            field_name: The property name from the schema.
            field_schema: The property's schema (a Reference is resolved first).
            field_type: The generated Type for the property value.
            is_required: Whether the property appears in the parent schema's
                ``required`` list.
            is_nullable: Whether the property's type array includes 'null'.

        Returns:
            An ``ast.AnnAssign`` like ``name: Annotation`` or
            ``name: Annotation = Field(...)``.
        """
        if hasattr(field_schema, 'ref'):
            field_schema, _ = self._resolve_reference(field_schema)

        field_keywords = list()

        sanitized_field_name = sanitize_parameter_field_name(field_name)

        # Determine the annotation - wrap in Union with None if nullable
        annotation_ast = field_type.annotation_ast
        if is_nullable and not self._type_already_nullable(field_type):
            annotation_ast = _union_expr(
                [field_type.annotation_ast, ast.Constant(value=None)]
            )

        value = None
        # Only simple scalar defaults are carried into Field(default=...).
        if field_schema.default is not None and isinstance(
            field_schema.default, (str, int, float, bool)
        ):
            field_keywords.append(
                ast.keyword(arg='default', value=ast.Constant(field_schema.default))
            )
        elif field_schema.default is None and not is_required:
            # Only add default=None for optional (not required) fields
            # Nullable but required fields should NOT have a default
            field_keywords.append(ast.keyword(arg='default', value=ast.Constant(None)))

        if sanitized_field_name != field_name:
            # Keep the wire name via alias when the Python identifier differs.
            field_keywords.append(
                ast.keyword(
                    arg='alias',
                    value=ast.Constant(field_name),  # original name before adding _
                )
            )
            field_name = sanitized_field_name

        if field_keywords:
            value = _call(
                func=_name(Field.__name__),
                keywords=field_keywords,
            )

            field_type.add_implementation_import(
                module=Field.__module__, name=Field.__name__
            )

        return ast.AnnAssign(
            target=_name(field_name),
            annotation=annotation_ast,
            value=value,
            simple=1,
        )
702
+
703
+ def _type_already_nullable(self, type_: Type) -> bool:
704
+ """Check if a type annotation already includes None."""
705
+ if isinstance(type_.annotation_ast, ast.Subscript):
706
+ # Check if it's Union[..., None]
707
+ if isinstance(type_.annotation_ast.value, ast.Name):
708
+ if type_.annotation_ast.value.id == 'Union':
709
+ if isinstance(type_.annotation_ast.slice, ast.Tuple):
710
+ for elt in type_.annotation_ast.slice.elts:
711
+ if isinstance(elt, ast.Constant) and elt.value is None:
712
+ return True
713
+ return False
714
+
715
+ def _create_pydantic_root_model(
716
+ self,
717
+ schema: Schema,
718
+ item_type: Type | None = None,
719
+ name: str | None = None,
720
+ base_name: str | None = None,
721
+ ) -> Type:
722
+ name = (
723
+ name
724
+ or base_name
725
+ or (sanitize_identifier(schema.title) if schema.title else None)
726
+ )
727
+ if not name:
728
+ raise ValueError('Root model must have a name')
729
+
730
+ model = ast.ClassDef(
731
+ name=name,
732
+ bases=[_subscript(RootModel.__name__, item_type.annotation_ast)],
733
+ keywords=[],
734
+ body=[ast.Pass()],
735
+ decorator_list=[],
736
+ type_params=[],
737
+ )
738
+
739
+ type_ = Type(
740
+ reference=None,
741
+ name=name,
742
+ annotation_ast=_name(name),
743
+ implementation_ast=model,
744
+ type='root',
745
+ )
746
+ type_.add_implementation_import(
747
+ module=RootModel.__module__, name=RootModel.__name__
748
+ )
749
+ type_.copy_imports_from_sub_types([item_type] if item_type else [])
750
+ if item_type is not None:
751
+ type_.add_dependency(item_type)
752
+ type_ = self.add_type(type_, base_name=base_name)
753
+
754
+ return type_
755
+
756
    def _create_pydantic_model(
        self, schema: Schema, name: str | None = None, base_name: str | None = None
    ) -> Type:
        """Create a Pydantic model Type for an object schema.

        Handles allOf (base classes), anyOf/oneOf (returned as an unnamed
        inline union instead of a model), and plain property objects.

        Args:
            schema: The object schema.
            name: Explicit model name; falls back to the sanitized schema
                title, then 'UnnamedModel'.
            base_name: Prefix used by add_type to disambiguate name clashes.

        Returns:
            The registered model Type, or an inline union Type for
            anyOf/oneOf schemas.
        """
        base_bases = []
        if schema.allOf:
            for base_schema in schema.allOf:
                base = self._create_object_type(schema=base_schema, base_name=base_name)
                base_bases.append(base)

        if schema.anyOf or schema.oneOf:
            # Use schema_to_type for each variant to properly handle primitives, objects, etc.
            types_ = [
                self.schema_to_type(t, base_name=base_name)
                for t in (schema.anyOf or schema.oneOf)
            ]

            union_type = Type(
                reference=None,
                name=None,  # Union type doesn't need a name, it's used inline
                annotation_ast=_union_expr(types=[t.annotation_ast for t in types_]),
                implementation_ast=None,
                type='primitive',
            )
            union_type.copy_imports_from_sub_types(types_)
            return union_type

        name = name or (
            sanitize_identifier(schema.title) if schema.title else 'UnnamedModel'
        )

        # Inherit from the allOf bases when present, otherwise from BaseModel.
        bases = [b.name for b in base_bases] or [BaseModel.__name__]
        bases = [_name(base) for base in bases]

        body = []
        field_types = []
        # Fix: Get the required fields from the parent schema's required array
        required_fields = set(schema.required or [])
        for property_name, property_schema in (schema.properties or {}).items():
            # Resolve reference to check for nullable
            resolved_schema = property_schema
            if hasattr(property_schema, 'ref') and property_schema.ref:
                resolved_schema, _ = self._resolve_reference(property_schema)

            # Check if field is nullable (type array with null)
            is_nullable = (
                self._is_nullable(resolved_schema) if resolved_schema else False
            )

            type_ = self.schema_to_type(
                property_schema, base_name=base_name, field_name=property_name
            )
            is_required = property_name in required_fields
            field = self._create_pydantic_field(
                property_name, property_schema, type_, is_required, is_nullable
            )

            body.append(field)
            field_types.append(type_)

        # Add deprecation docstring if schema is deprecated
        if schema.deprecated:
            deprecation_doc = ast.Expr(
                value=ast.Constant(
                    value=f'{name} is deprecated.\n\n.. deprecated::\n This model is deprecated.'
                )
            )
            body = [deprecation_doc] + body if body else [deprecation_doc]

        model = ast.ClassDef(
            name=name,
            bases=bases,
            keywords=[],
            body=body or [ast.Pass()],
            decorator_list=[],
            type_params=[],
        )

        type_ = Type(
            reference=None,
            name=name,
            annotation_ast=_name(name),
            implementation_ast=model,
            dependencies=set(),
            type='model',
        )

        # Add base class dependencies
        if base_bases:
            for base in base_bases:
                type_.add_dependency(base)

        # Add field type dependencies
        for field_type in field_types:
            if field_type.name:
                type_.dependencies.add(field_type.name)
                type_.dependencies.update(field_type.dependencies)

        type_.add_implementation_import(
            module=BaseModel.__module__, name=BaseModel.__name__
        )
        type_.add_implementation_import(module=Field.__module__, name=Field.__name__)
        type_.copy_imports_from_sub_types(field_types)

        type_ = self.add_type(type_, base_name=base_name)
        return type_
861
+
862
+ def _create_array_type(
863
+ self, schema: Schema, name: str | None = None, base_name: str | None = None
864
+ ) -> Type:
865
+ if schema.type != DataType.array:
866
+ raise ValueError('Schema is not an array')
867
+
868
+ if not schema.items:
869
+ type_ = Type(
870
+ None,
871
+ None,
872
+ _subscript(
873
+ list.__name__,
874
+ ast.Name(id=Any.__name__, ctx=ast.Load()),
875
+ ),
876
+ 'primitive',
877
+ )
878
+
879
+ type_.add_annotation_import(module=list.__module__, name=list.__name__)
880
+ type_.add_annotation_import(module=Any.__module__, name=Any.__name__)
881
+
882
+ return type_
883
+
884
+ item_type = self.schema_to_type(schema.items, base_name=base_name)
885
+
886
+ type_ = Type(
887
+ None,
888
+ None,
889
+ annotation_ast=_subscript(
890
+ list.__name__,
891
+ item_type.annotation_ast,
892
+ ),
893
+ implementation_ast=None,
894
+ type='primitive',
895
+ )
896
+
897
+ type_.add_annotation_import(list.__module__, list.__name__)
898
+ type_.copy_imports_from_sub_types([item_type])
899
+
900
+ if item_type:
901
+ type_.add_dependency(item_type)
902
+
903
+ return type_
904
+
905
+ def _create_object_type(
906
+ self,
907
+ schema: Schema | Reference,
908
+ name: str | None = None,
909
+ base_name: str | None = None,
910
+ ) -> Type:
911
+ schema, schema_name = self._resolve_reference(schema)
912
+
913
+ # Handle additionalProperties for dict-like types
914
+ if (
915
+ not schema.properties
916
+ and not schema.allOf
917
+ and not schema.anyOf
918
+ and not schema.oneOf
919
+ ):
920
+ # Check for additionalProperties to determine value type
921
+ value_type_ast = ast.Name(id=Any.__name__, ctx=ast.Load())
922
+ value_type_imports: dict[str, set[str]] = {Any.__module__: {Any.__name__}}
923
+
924
+ if (
925
+ schema.additionalProperties is not None
926
+ and schema.additionalProperties is not True
927
+ ):
928
+ if schema.additionalProperties is False:
929
+ # No additional properties allowed - still generate dict[str, Any]
930
+ pass
931
+ elif isinstance(schema.additionalProperties, (Schema, Reference)):
932
+ # additionalProperties has a schema - use it for value type
933
+ additional_type = self.schema_to_type(
934
+ schema.additionalProperties, base_name=base_name
935
+ )
936
+ value_type_ast = additional_type.annotation_ast
937
+ value_type_imports = additional_type.annotation_imports.copy()
938
+
939
+ type_ = Type(
940
+ None,
941
+ None,
942
+ annotation_ast=_subscript(
943
+ dict.__name__,
944
+ ast.Tuple(
945
+ elts=[
946
+ ast.Name(id=str.__name__, ctx=ast.Load()),
947
+ value_type_ast,
948
+ ]
949
+ ),
950
+ ),
951
+ implementation_ast=None,
952
+ type='primitive',
953
+ )
954
+
955
+ type_.add_annotation_import(dict.__module__, dict.__name__)
956
+ for module, names in value_type_imports.items():
957
+ for name_import in names:
958
+ type_.add_annotation_import(module, name_import)
959
+
960
+ return type_
961
+
962
+ return self._create_pydantic_model(
963
+ schema, schema_name or name, base_name=base_name
964
+ )
965
+
966
+ def schema_to_type(
967
+ self,
968
+ schema: Schema | Reference,
969
+ base_name: str | None = None,
970
+ field_name: str | None = None,
971
+ ) -> Type:
972
+ if isinstance(schema, Reference):
973
+ ref_name = schema.ref.split('/')[-1]
974
+ sanitized_ref_name = sanitize_identifier(ref_name)
975
+ if sanitized_ref_name in self.types:
976
+ return self.types[sanitized_ref_name]
977
+
978
+ schema, schema_name = self._resolve_reference(schema)
979
+
980
+ # Use schema_name (from $ref) as base_name for nested types if available
981
+ # This ensures enums inside Pet get names like "PetStatus" not "addPetRequestBodyStatus"
982
+ effective_base_name = schema_name or base_name
983
+
984
+ # TODO: schema.type can be array?
985
+ if schema.type == DataType.array:
986
+ type_ = self._create_array_type(
987
+ schema=schema, name=schema_name, base_name=effective_base_name
988
+ )
989
+ elif schema.type == DataType.object or schema.type is None:
990
+ type_ = self._create_object_type(
991
+ schema, name=schema_name, base_name=effective_base_name
992
+ )
993
+ else:
994
+ type_ = self._get_primitive_type_ast(
995
+ schema, base_name=effective_base_name, field_name=field_name
996
+ )
997
+
998
+ return type_
999
+
1000
    def get_sorted_types(self) -> list[Type]:
        """Returns the types sorted in dependency order using topological sort.
        Types with no dependencies come first.

        NOTE(review): the DFS appends each type *after* its dependencies, so
        ``sorted_types`` already lists dependencies first; the trailing
        ``reversed(...)`` therefore returns dependents before their
        dependencies, which appears to contradict this docstring — confirm
        which order consumers actually expect.

        Raises:
            ValueError: If a cyclic dependency is detected among the types.
        """
        sorted_types: list[Type] = []
        temp_mark: set[str] = set()  # names on the current DFS path (cycle detection)
        perm_mark: set[str] = set()  # names already fully processed

        def visit(type_: Type):
            # Post-order DFS: visit all dependencies before appending this type.
            if type_.name in perm_mark:
                return
            if type_.name in temp_mark:
                raise ValueError(f'Cyclic dependency detected for type: {type_.name}')

            temp_mark.add(type_.name)

            for dep_name in type_.dependencies:
                # Dependencies that are not registered types (e.g. primitives)
                # are silently skipped.
                if dep_name in self.types:
                    visit(self.types[dep_name])

            perm_mark.add(type_.name)
            temp_mark.remove(type_.name)
            sorted_types.append(type_)

        for type_ in self.types.values():
            if type_.name not in perm_mark:
                visit(type_)

        return list(reversed(sorted_types))
1029
+
1030
+
1031
+ # =============================================================================
1032
+ # Type Registry
1033
+ # =============================================================================
1034
+
1035
+
1036
@dataclasses.dataclass
class TypeInfo:
    """Information about a registered type.

    Attributes:
        name: The Python name for this type.
        reference: The original OpenAPI reference (e.g., '#/components/schemas/Pet'),
            or None when the type was not produced from a ``$ref``.
        type_obj: The Type object containing AST and metadata.
        dependencies: Set of type names this type depends on.
        is_root_model: Whether this is a Pydantic RootModel.
        is_generated: Whether the AST has been generated for this type.
    """

    name: str
    reference: str | None
    type_obj: 'Type'
    # default_factory so each instance gets its own (mutable) set.
    dependencies: set[str] = dataclasses.field(default_factory=set)
    is_root_model: bool = False
    is_generated: bool = False
1055
+
1056
+
1057
class TypeRegistry:
    """Registry of types produced during code generation.

    Keeps track of every type generated from an OpenAPI schema, records the
    dependency relationships between them, and can return the types in an
    order where dependencies precede their dependents.

    Example:
        >>> registry = TypeRegistry()
        >>> registry.register(type_obj, name='Pet', reference='#/components/schemas/Pet')
        >>> if registry.has_type('Pet'):
        ...     pet_type = registry.get_type('Pet')
        >>> for type_info in registry.get_types_in_dependency_order():
        ...     generate_code(type_info)
    """

    def __init__(self):
        """Initialize an empty type registry."""
        self._types: dict[str, TypeInfo] = {}
        # Maps an OpenAPI reference string to the registered Python name.
        self._by_reference: dict[str, str] = {}
        # Names treated as built-ins: never visited during dependency sort.
        self._primitive_types: set[str] = {'str', 'int', 'float', 'bool', 'None'}

    def register(
        self,
        type_obj: 'Type',
        name: str,
        reference: str | None = None,
        dependencies: set[str] | None = None,
        is_root_model: bool = False,
    ) -> TypeInfo:
        """Register a new type in the registry.

        Args:
            type_obj: The Type object containing the type information.
            name: The Python name for this type.
            reference: The OpenAPI reference string, if applicable.
            dependencies: Set of type names this type depends on.
            is_root_model: Whether this is a Pydantic RootModel.

        Returns:
            The TypeInfo object for the registered type.

        Raises:
            ValueError: If a type with the same name is already registered.
        """
        if name in self._types:
            raise ValueError(f"Type '{name}' is already registered")

        info = TypeInfo(
            name=name,
            reference=reference,
            type_obj=type_obj,
            dependencies=dependencies or set(),
            is_root_model=is_root_model,
        )
        self._types[name] = info
        if reference:
            self._by_reference[reference] = name
        return info

    def register_or_get(
        self,
        type_obj: 'Type',
        name: str,
        reference: str | None = None,
        dependencies: set[str] | None = None,
        is_root_model: bool = False,
    ) -> TypeInfo:
        """Register a type if not present, otherwise return the existing one.

        Args:
            type_obj: The Type object containing the type information.
            name: The Python name for this type.
            reference: The OpenAPI reference string, if applicable.
            dependencies: Set of type names this type depends on.
            is_root_model: Whether this is a Pydantic RootModel.

        Returns:
            The TypeInfo object (existing or newly registered).
        """
        existing = self._types.get(name)
        if existing is not None:
            return existing
        return self.register(type_obj, name, reference, dependencies, is_root_model)

    def has_type(self, name: str) -> bool:
        """Check if a type is registered under *name*."""
        return name in self._types

    def has_reference(self, reference: str) -> bool:
        """Check if an OpenAPI reference has been registered."""
        return reference in self._by_reference

    def get_type(self, name: str) -> TypeInfo | None:
        """Get a registered type by name, or None when unknown."""
        return self._types.get(name)

    def get_type_by_reference(self, reference: str) -> TypeInfo | None:
        """Get a registered type by its OpenAPI reference, or None."""
        registered_name = self._by_reference.get(reference)
        if registered_name is None:
            return None
        return self._types.get(registered_name)

    def get_name_for_reference(self, reference: str) -> str | None:
        """Get the registered Python name for an OpenAPI reference."""
        return self._by_reference.get(reference)

    def get_all_types(self) -> dict[str, TypeInfo]:
        """Get a shallow copy of all registered types, keyed by name."""
        return dict(self._types)

    def get_type_names(self) -> list[str]:
        """Get all registered type names, sorted alphabetically."""
        return sorted(self._types)

    def add_dependency(self, type_name: str, depends_on: str) -> None:
        """Record that *type_name* depends on *depends_on*.

        Raises:
            KeyError: If *type_name* is not registered.
        """
        if type_name not in self._types:
            raise KeyError(f"Type '{type_name}' is not registered")
        self._types[type_name].dependencies.add(depends_on)

    def get_dependencies(self, type_name: str) -> set[str]:
        """Get a copy of the dependency set for *type_name*.

        Raises:
            KeyError: If *type_name* is not registered.
        """
        if type_name not in self._types:
            raise KeyError(f"Type '{type_name}' is not registered")
        return self._types[type_name].dependencies.copy()

    def get_types_in_dependency_order(self) -> list[TypeInfo]:
        """Get all types sorted in dependency order.

        Types are ordered so that each type's dependencies appear before
        the type itself; cycles are broken silently rather than raising.
        """
        ordered: list[TypeInfo] = []
        done: set[str] = set()
        on_stack: set[str] = set()

        def walk(type_name: str) -> None:
            # Skip finished names, cycle back-edges, primitives, and
            # dependencies that were never registered.
            if (
                type_name in done
                or type_name in on_stack
                or type_name in self._primitive_types
                or type_name not in self._types
            ):
                return

            on_stack.add(type_name)
            info = self._types[type_name]
            for dependency in info.dependencies:
                walk(dependency)
            on_stack.remove(type_name)

            done.add(type_name)
            ordered.append(info)

        # Deterministic output: roots are visited in alphabetical order.
        for type_name in sorted(self._types):
            walk(type_name)

        return ordered

    def mark_generated(self, name: str) -> None:
        """Mark a type as having its AST generated.

        Raises:
            KeyError: If *name* is not registered.
        """
        if name not in self._types:
            raise KeyError(f"Type '{name}' is not registered")
        self._types[name].is_generated = True

    def get_ungenerated_types(self) -> list[TypeInfo]:
        """Get all types whose AST has not been generated yet."""
        return [info for info in self._types.values() if not info.is_generated]

    def clear(self) -> None:
        """Remove every registered type and reference mapping."""
        self._types.clear()
        self._by_reference.clear()

    def __len__(self) -> int:
        """Return the number of registered types."""
        return len(self._types)

    def __iter__(self) -> Iterator[TypeInfo]:
        """Iterate over all registered TypeInfo objects."""
        return iter(self._types.values())

    def __contains__(self, name: str) -> bool:
        """Check if a type name is registered."""
        return name in self._types

    def get_root_models(self) -> list[TypeInfo]:
        """Get all registered root models."""
        return [info for info in self._types.values() if info.is_root_model]

    def get_regular_models(self) -> list[TypeInfo]:
        """Get all registered non-root models."""
        return [info for info in self._types.values() if not info.is_root_model]

    def merge(self, other: 'TypeRegistry') -> None:
        """Merge another registry into this one; existing names win."""
        for incoming in other:
            if incoming.name in self._types:
                continue
            self._types[incoming.name] = incoming
            if incoming.reference:
                self._by_reference[incoming.reference] = incoming.name
1264
+
1265
+
1266
+ # =============================================================================
1267
+ # Model Name Collector
1268
+ # =============================================================================
1269
+
1270
+
1271
class ModelNameCollector(ast.NodeVisitor):
    """AST visitor that records which known model names appear in a tree.

    Walks AST nodes and remembers every ``Name`` node whose identifier is in
    a provided set of available model names, which lets callers determine
    which models are actually referenced in generated code.

    Example:
        >>> available = {'Pet', 'User', 'Order'}
        >>> collector = ModelNameCollector(available)
        >>> collector.visit(some_function_ast)
        >>> print(collector.used_models)
        {'Pet', 'User'}
    """

    def __init__(self, available_models: set[str]):
        """Initialize the collector.

        Args:
            available_models: Set of model names that are available for import.
        """
        self.available_models = available_models
        self.used_models: set[str] = set()

    def visit_Name(self, node: ast.Name) -> None:
        """Record the node's identifier when it names an available model."""
        if node.id in self.available_models:
            self.used_models.add(node.id)
        # Keep walking so nested nodes are visited as well.
        self.generic_visit(node)

    @classmethod
    def collect_from_endpoints(
        cls,
        endpoints: list['Endpoint'],
        available_models: set[str],
    ) -> set[str]:
        """Collect model names used across multiple endpoints.

        Args:
            endpoints: List of Endpoint objects to scan.
            available_models: Set of model names that are available.

        Returns:
            Set of model names that are actually used in the endpoints.
        """
        scanner = cls(available_models)
        for endpoint in endpoints:
            # Both the sync and async variants may reference models.
            for tree in (endpoint.sync_ast, endpoint.async_ast):
                scanner.visit(tree)
        return scanner.used_models
1321
+
1322
+
1323
def collect_used_model_names(
    endpoints: list['Endpoint'],
    typegen_types: dict[str, 'Type'],
) -> set[str]:
    """Collect model names that are actually used in endpoint signatures.

    Only collects models that have implementations (defined in models.py)
    and are referenced in endpoint parameters, request bodies, or responses.

    Args:
        endpoints: List of Endpoint objects to check for model usage.
        typegen_types: Dictionary mapping type names to Type objects.

    Returns:
        Set of model names actually used in endpoints.
    """
    # A model counts as importable only when it has both a name and an
    # implementation AST (i.e. it will actually exist in models.py).
    importable_models: set[str] = set()
    for candidate in typegen_types.values():
        if candidate.name and candidate.implementation_ast:
            importable_models.add(candidate.name)

    return ModelNameCollector.collect_from_endpoints(endpoints, importable_models)