otterapi 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. README.md +581 -8
  2. otterapi/__init__.py +73 -0
  3. otterapi/cli.py +327 -29
  4. otterapi/codegen/__init__.py +115 -0
  5. otterapi/codegen/ast_utils.py +134 -5
  6. otterapi/codegen/client.py +1271 -0
  7. otterapi/codegen/codegen.py +1736 -0
  8. otterapi/codegen/dataframes.py +392 -0
  9. otterapi/codegen/emitter.py +473 -0
  10. otterapi/codegen/endpoints.py +2597 -343
  11. otterapi/codegen/pagination.py +1026 -0
  12. otterapi/codegen/schema.py +593 -0
  13. otterapi/codegen/splitting.py +1397 -0
  14. otterapi/codegen/types.py +1345 -0
  15. otterapi/codegen/utils.py +180 -1
  16. otterapi/config.py +1017 -24
  17. otterapi/exceptions.py +231 -0
  18. otterapi/openapi/__init__.py +46 -0
  19. otterapi/openapi/v2/__init__.py +86 -0
  20. otterapi/openapi/v2/spec.json +1607 -0
  21. otterapi/openapi/v2/v2.py +1776 -0
  22. otterapi/openapi/v3/__init__.py +131 -0
  23. otterapi/openapi/v3/spec.json +1651 -0
  24. otterapi/openapi/v3/v3.py +1557 -0
  25. otterapi/openapi/v3_1/__init__.py +133 -0
  26. otterapi/openapi/v3_1/spec.json +1411 -0
  27. otterapi/openapi/v3_1/v3_1.py +798 -0
  28. otterapi/openapi/v3_2/__init__.py +133 -0
  29. otterapi/openapi/v3_2/spec.json +1666 -0
  30. otterapi/openapi/v3_2/v3_2.py +777 -0
  31. otterapi/tests/__init__.py +3 -0
  32. otterapi/tests/fixtures/__init__.py +455 -0
  33. otterapi/tests/test_ast_utils.py +680 -0
  34. otterapi/tests/test_codegen.py +610 -0
  35. otterapi/tests/test_dataframe.py +1038 -0
  36. otterapi/tests/test_exceptions.py +493 -0
  37. otterapi/tests/test_openapi_support.py +616 -0
  38. otterapi/tests/test_openapi_upgrade.py +215 -0
  39. otterapi/tests/test_pagination.py +1101 -0
  40. otterapi/tests/test_splitting_config.py +319 -0
  41. otterapi/tests/test_splitting_integration.py +427 -0
  42. otterapi/tests/test_splitting_resolver.py +512 -0
  43. otterapi/tests/test_splitting_tree.py +525 -0
  44. otterapi-0.0.6.dist-info/METADATA +627 -0
  45. otterapi-0.0.6.dist-info/RECORD +48 -0
  46. {otterapi-0.0.5.dist-info → otterapi-0.0.6.dist-info}/WHEEL +1 -1
  47. otterapi/codegen/generator.py +0 -358
  48. otterapi/codegen/openapi_processor.py +0 -27
  49. otterapi/codegen/type_generator.py +0 -559
  50. otterapi-0.0.5.dist-info/METADATA +0 -54
  51. otterapi-0.0.5.dist-info/RECORD +0 -16
  52. {otterapi-0.0.5.dist-info → otterapi-0.0.6.dist-info}/entry_points.txt +0 -0
otterapi/cli.py CHANGED
@@ -1,13 +1,32 @@
1
+ """Command-line interface for OtterAPI.
2
+
3
+ This module provides the CLI commands for generating Python client code
4
+ from OpenAPI specifications.
5
+ """
6
+
7
+ import json
8
+ import logging
9
+ import traceback
10
+ from pathlib import Path
1
11
  from typing import Annotated
2
12
 
3
13
  import typer
14
+ import yaml
4
15
  from rich.console import Console
16
+ from rich.logging import RichHandler
17
+ from rich.panel import Panel
5
18
  from rich.progress import Progress, SpinnerColumn, TextColumn
19
+ from rich.syntax import Syntax
20
+ from rich.table import Table
6
21
 
7
- from otterapi.codegen.generator import Codegen
8
- from otterapi.config import get_config
22
+ from otterapi.codegen.codegen import Codegen
23
+ from otterapi.codegen.schema import SchemaLoader
24
+ from otterapi.config import CodegenConfig, DocumentConfig, get_config
25
+ from otterapi.exceptions import OtterAPIError, SchemaLoadError, SchemaValidationError
9
26
 
# Shared Rich consoles: regular output is printed to stdout via `console`,
# while errors (and the logging handler configured in setup_logging) use
# `error_console`, which targets stderr.
console = Console()
error_console = Console(stderr=True)
29
+
11
30
  app = typer.Typer(
12
31
  name='otterapi',
13
32
  help='Generate Python client code from OpenAPI specifications',
@@ -15,6 +34,23 @@ app = typer.Typer(
15
34
  )
16
35
 
17
36
 
def setup_logging(verbose: bool = False, debug: bool = False) -> None:
    """Configure logging based on verbosity settings.

    ``debug`` takes precedence over ``verbose``: the root level becomes
    DEBUG when ``debug`` is set, INFO when only ``verbose`` is set, and
    WARNING otherwise. Records are rendered by a ``RichHandler`` bound to
    ``error_console``.

    Args:
        verbose: Enable INFO-level logging.
        debug: Enable DEBUG-level logging (overrides ``verbose``).
    """
    if debug:
        level = logging.DEBUG
    elif verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    logging.basicConfig(
        level=level,
        format='%(message)s',
        datefmt='[%X]',
        handlers=[RichHandler(console=error_console, rich_tracebacks=True)],
    )
    # basicConfig does nothing once the root logger already has handlers
    # (e.g. a second CLI invocation in the same process, as under a test
    # runner), so apply the requested level explicitly as well. Unlike
    # basicConfig(force=True) this keeps any handlers installed by the host.
    logging.getLogger().setLevel(level)
52
+
53
+
18
54
  @app.command()
19
55
  def generate(
20
56
  config: Annotated[
@@ -23,62 +59,324 @@ def generate(
23
59
  '--config', '-c', help='Path to configuration file (YAML or JSON)'
24
60
  ),
25
61
  ] = None,
62
+ source: Annotated[
63
+ str | None,
64
+ typer.Option(
65
+ '--source', '-s', help='Direct path or URL to OpenAPI specification'
66
+ ),
67
+ ] = None,
68
+ output: Annotated[
69
+ str | None,
70
+ typer.Option('--output', '-o', help='Output directory for generated code'),
71
+ ] = None,
72
+ verbose: Annotated[
73
+ bool, typer.Option('--verbose', '-v', help='Enable verbose output')
74
+ ] = False,
75
+ debug: Annotated[bool, typer.Option('--debug', help='Enable debug output')] = False,
26
76
  ) -> None:
27
- """Generate Python client code from configuration.
77
+ """Generate Python client code from OpenAPI specifications.
28
78
 
29
- If no config file is specified, will look for default config files
30
- in the current directory or use environment variables.
79
+ You can either use a configuration file or specify the source and output
80
+ directly via command-line options.
31
81
 
32
82
  Examples:
33
83
  otterapi generate
34
84
  otterapi generate --config my-config.yaml
35
85
  otterapi generate -c config.json
86
+ otterapi generate --source https://api.example.com/openapi.json --output ./client
87
+ otterapi generate -s ./api.yaml -o ./generated
36
88
  """
37
- config = get_config(config)
89
+ setup_logging(verbose, debug)
38
90
 
39
91
  try:
40
- with Progress(
41
- SpinnerColumn(),
42
- TextColumn('[progress.description]{task.description}'),
43
- console=console,
44
- ) as progress:
45
- for document_config in config.documents:
92
+ # Build configuration from options or file
93
+ if source and output:
94
+ codegen_config = CodegenConfig(
95
+ documents=[DocumentConfig(source=source, output=output)]
96
+ )
97
+ elif source or output:
98
+ error_console.print(
99
+ '[red]Error:[/red] Both --source and --output must be provided together'
100
+ )
101
+ raise typer.Exit(1)
102
+ else:
103
+ try:
104
+ codegen_config = get_config(config)
105
+ except FileNotFoundError as e:
106
+ error_console.print(f'[red]Error:[/red] {e}')
107
+ error_console.print(
108
+ '\n[dim]Hint: Run [bold]otterapi init[/bold] to create a configuration file,[/dim]'
109
+ )
110
+ error_console.print(
111
+ '[dim]or use [bold]--source[/bold] and [bold]--output[/bold] options.[/dim]'
112
+ )
113
+ raise typer.Exit(1)
114
+
115
+ for document_config in codegen_config.documents:
116
+ with Progress(
117
+ SpinnerColumn(),
118
+ TextColumn('[progress.description]{task.description}'),
119
+ console=console,
120
+ ) as progress:
46
121
  task = progress.add_task(
47
- f'Generating code for {document_config.source} in {document_config.output}...',
122
+ f'Generating code for {document_config.source}...',
48
123
  total=None,
49
124
  )
50
125
 
51
126
  codegen = Codegen(document_config)
52
- codegen.generate()
53
127
 
54
- console.print(
55
- f"[green]✓[/green] Successfully generated code in '{document_config.output}'"
128
+ generated_files = codegen.generate()
129
+ progress.update(
130
+ task,
131
+ description=f'Code generation completed for {document_config.source}!',
56
132
  )
133
+
57
134
  console.print('[dim]Generated files:[/dim]')
58
- console.print(
59
- f' - {document_config.output}/{document_config.models_file}'
60
- )
61
- console.print(
62
- f' - {document_config.output}/{document_config.endpoints_file}'
63
- )
135
+ for file_path in generated_files:
136
+ console.print(f' - {file_path}')
64
137
 
65
- progress.update(task, description='Code generation completed!')
138
+ console.print('\n[green]✓[/green] Code generation completed!')
66
139
 
140
+ except OtterAPIError as e:
141
+ error_console.print(f'[red]Error:[/red] {e.message}')
142
+ if debug:
143
+ traceback.print_exc()
144
+ raise typer.Exit(1)
67
145
  except Exception as e:
68
- console.print(f'[red]Error:[/red] {str(e)}')
146
+ error_console.print(f'[red]Error:[/red] {str(e)}')
147
+ if debug:
148
+ traceback.print_exc()
69
149
  raise typer.Exit(1)
70
150
 
71
151
 
@app.command()
def init(
    path: Annotated[
        str, typer.Argument(help='Path for the configuration file')
    ] = 'otter.yml',
    force: Annotated[
        bool, typer.Option('--force', '-f', help='Overwrite existing file')
    ] = False,
) -> None:
    """Create a new configuration file interactively.

    This command guides you through creating an OtterAPI configuration file
    with all the necessary settings.

    Examples:
        otterapi init
        otterapi init otter.yaml
        otterapi init config.json --force
    """
    config_path = Path(path)
    # The output format follows the file extension. Compare case-insensitively
    # so that e.g. `config.JSON` gets JSON content rather than YAML.
    is_json = config_path.suffix.lower() == '.json'

    # Refuse to clobber an existing file unless --force was given.
    if config_path.exists() and not force:
        error_console.print(
            f'[red]Error:[/red] File {config_path} already exists. Use --force to overwrite.'
        )
        raise typer.Exit(1)

    console.print(Panel('[bold]OtterAPI Configuration Setup[/bold]'))

    # Settings for the first (primary) document.
    source = typer.prompt(
        '\nOpenAPI specification source (URL or file path)',
        default='https://petstore3.swagger.io/api/v3/openapi.json',
    )
    output = typer.prompt('Output directory for generated code', default='./client')
    models_file = typer.prompt('Models file name', default='models.py')
    endpoints_file = typer.prompt('Endpoints file name', default='endpoints.py')

    config_data = {
        'documents': [
            {
                'source': source,
                'output': output,
                'models_file': models_file,
                'endpoints_file': endpoints_file,
            }
        ]
    }

    # Additional documents only record source/output; file names are left
    # unset in the written config.
    while typer.confirm('\nAdd another document?', default=False):
        extra_source = typer.prompt('OpenAPI specification source')
        extra_output = typer.prompt('Output directory')
        config_data['documents'].append({'source': extra_source, 'output': extra_output})

    config_path.parent.mkdir(parents=True, exist_ok=True)

    if is_json:
        content = json.dumps(config_data, indent=2)
    else:
        content = yaml.dump(config_data, default_flow_style=False, sort_keys=False)

    # Explicit encoding so the result does not depend on the platform locale.
    config_path.write_text(content, encoding='utf-8')

    console.print(f'\n[green]✓[/green] Configuration saved to {config_path}')
    console.print('\n[dim]Preview:[/dim]')
    syntax = Syntax(content, 'json' if is_json else 'yaml')
    console.print(syntax)

    console.print(
        f'\n[dim]Run [bold]otterapi generate -c {path}[/bold] to generate code.[/dim]'
    )
233
+
234
+
@app.command()
def validate(
    source: Annotated[str, typer.Argument(help='Path or URL to OpenAPI specification')],
    verbose: Annotated[
        bool, typer.Option('--verbose', '-v', help='Show detailed schema information')
    ] = False,
) -> None:
    """Validate an OpenAPI specification.

    This command loads and validates an OpenAPI specification without
    generating any code, reporting any errors or warnings found.

    Examples:
        otterapi validate ./api.yaml
        otterapi validate https://api.example.com/openapi.json
        otterapi validate ./api.yaml --verbose
    """
    # HTTP methods a path item can define ('trace' is part of OpenAPI 3.x).
    # This single tuple drives both the operation count and the
    # missing-operationId check so the two can't drift apart;
    # getattr(..., None) tolerates models that lack an attribute.
    http_methods = ('get', 'post', 'put', 'patch', 'delete', 'head', 'options', 'trace')

    with Progress(
        SpinnerColumn(),
        TextColumn('[progress.description]{task.description}'),
        console=console,
    ) as progress:
        task = progress.add_task(f'Loading {source}...', total=None)

        try:
            loader = SchemaLoader()
            schema = loader.load(source)
            progress.update(task, description='Validating schema...')

            # Rows for the summary table, and non-fatal issues found.
            warnings: list[str] = []
            info: dict[str, object] = {}

            if schema.info:
                info['title'] = schema.info.title
                info['version'] = schema.info.version
                description = schema.info.description
                if description:
                    # Truncate long descriptions to keep the table readable.
                    info['description'] = (
                        description[:200] + '...'
                        if len(description) > 200
                        else description
                    )

            if schema.paths:
                path_items = schema.paths.root
                info['paths'] = len(path_items)

                # Count operations and flag those without an operationId in one pass.
                operations = 0
                for path, path_item in path_items.items():
                    for method in http_methods:
                        operation = getattr(path_item, method, None)
                        if not operation:
                            continue
                        operations += 1
                        if not operation.operationId:
                            warnings.append(
                                f'{method.upper()} {path}: Missing operationId'
                            )
                info['operations'] = operations

            if schema.components:
                if schema.components.schemas:
                    info['schemas'] = len(schema.components.schemas)
                if schema.components.securitySchemes:
                    info['security_schemes'] = len(schema.components.securitySchemes)

            progress.update(task, description='Validation complete!')

        # Stop the spinner before printing so the error isn't interleaved with it.
        except SchemaLoadError as e:
            progress.stop()
            error_console.print(f'[red]✗ Failed to load schema:[/red] {e.message}')
            raise typer.Exit(1)
        except SchemaValidationError as e:
            progress.stop()
            error_console.print(f'[red]✗ Schema validation failed:[/red] {e.message}')
            raise typer.Exit(1)
        except Exception as e:
            progress.stop()
            error_console.print(f'[red]✗ Error:[/red] {str(e)}')
            raise typer.Exit(1)

    # Print results
    console.print(f'\n[green]✓[/green] Schema is valid: {source}\n')

    if verbose or info:
        table = Table(title='Schema Information')
        table.add_column('Property', style='cyan')
        table.add_column('Value')

        for key, value in info.items():
            table.add_row(key.replace('_', ' ').title(), str(value))

        console.print(table)

    if warnings:
        console.print(f'\n[yellow]⚠ {len(warnings)} warning(s):[/yellow]')
        for warning in warnings[:10]:  # Show first 10
            console.print(f' - {warning}')
        if len(warnings) > 10:
            console.print(f' ... and {len(warnings) - 10} more')
347
+
348
+
@app.command()
def version() -> None:
    """Show the version of OtterAPI."""
    import importlib

    try:
        from otterapi._version import version as ver

        console.print(f'otterapi version: [bold]{ver}[/bold]')
    except ImportError:
        # No generated _version module (e.g. running from a source checkout).
        console.print('otterapi version: [dim]unknown (development)[/dim]')

    # Always list key runtime dependencies. Packages that are not installed
    # are skipped; a package without __version__ is reported as 'unknown'
    # instead of crashing the command.
    console.print('\n[dim]Dependencies:[/dim]')
    for module_name in ('pydantic', 'httpx', 'typer'):
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue
        module_version = getattr(module, '__version__', 'unknown')
        console.print(f' {module_name}: {module_version}')
81
379
 
82
380
 
if __name__ == '__main__':
    # Run the Typer application so every registered sub-command is reachable
    # when this module is executed directly.
    app()
@@ -0,0 +1,115 @@
1
+ """Code generation module for OtterAPI.
2
+
3
+ This module provides the core code generation functionality for creating
4
+ Python client code from OpenAPI specifications.
5
+
6
+ Main Components:
7
+ - Codegen: The main orchestrator for code generation
8
+ - TypeGenerator: Generates Pydantic models from OpenAPI schemas
9
+ - SchemaLoader: Loads OpenAPI schemas from URLs or files
10
+ - SchemaResolver: Resolves $ref references in schemas
11
+ - TypeRegistry: Manages generated types and their dependencies
12
+ - CodeEmitter: Handles output of generated code
13
+
14
+ Example:
15
+ >>> from otterapi.codegen import Codegen
16
+ >>> from otterapi.config import DocumentConfig
17
+ >>>
18
+ >>> config = DocumentConfig(
19
+ ... source="./openapi.json",
20
+ ... output="./client"
21
+ ... )
22
+ >>> codegen = Codegen(config)
23
+ >>> codegen.generate()
24
+ """
25
+
26
+ from otterapi.codegen.ast_utils import ImportCollector
27
+ from otterapi.codegen.codegen import Codegen
28
+
29
+ # Re-export from dataframes module
30
+ from otterapi.codegen.dataframes import (
31
+ DataFrameMethodConfig,
32
+ generate_dataframe_module,
33
+ get_dataframe_config_for_endpoint,
34
+ )
35
+ from otterapi.codegen.emitter import CodeEmitter, FileEmitter, StringEmitter
36
+
37
+ # Re-export from endpoints module
38
+ from otterapi.codegen.endpoints import (
39
+ DataFrameLibrary,
40
+ EndpointFunctionConfig,
41
+ EndpointFunctionFactory,
42
+ EndpointMode,
43
+ FunctionSignature,
44
+ FunctionSignatureBuilder,
45
+ ParameterASTBuilder,
46
+ )
47
+ from otterapi.codegen.schema import SchemaLoader, SchemaResolver
48
+
49
+ # Re-export from splitting module
50
+ from otterapi.codegen.splitting import (
51
+ EmittedModule,
52
+ ModuleMapResolver,
53
+ ModuleTree,
54
+ ModuleTreeBuilder,
55
+ ResolvedModule,
56
+ SplitModuleEmitter,
57
+ build_module_tree,
58
+ )
59
+ from otterapi.codegen.types import (
60
+ Endpoint,
61
+ ModelNameCollector,
62
+ Parameter,
63
+ RequestBodyInfo,
64
+ ResponseInfo,
65
+ Type,
66
+ TypeGenerator,
67
+ TypeInfo,
68
+ TypeRegistry,
69
+ collect_used_model_names,
70
+ )
71
+
# Public API of ``otterapi.codegen`` (what ``from otterapi.codegen import *``
# exposes), grouped by area.
__all__ = [
    # Main codegen class
    'Codegen',
    # Type generation
    'TypeGenerator', 'Type', 'TypeRegistry', 'TypeInfo',
    'ModelNameCollector', 'collect_used_model_names',
    # Schema handling
    'SchemaLoader', 'SchemaResolver',
    # Endpoint types
    'Endpoint', 'Parameter', 'RequestBodyInfo', 'ResponseInfo',
    # Code emission
    'CodeEmitter', 'FileEmitter', 'StringEmitter', 'ImportCollector',
    # Endpoint building
    'EndpointFunctionConfig', 'EndpointFunctionFactory', 'EndpointMode',
    'DataFrameLibrary', 'FunctionSignature', 'FunctionSignatureBuilder',
    'ParameterASTBuilder',
    # DataFrame utilities
    'DataFrameMethodConfig', 'generate_dataframe_module',
    'get_dataframe_config_for_endpoint',
    # Module splitting
    'ModuleTree', 'ModuleTreeBuilder', 'ModuleMapResolver', 'ResolvedModule',
    'EmittedModule', 'SplitModuleEmitter', 'build_module_tree',
]
@@ -1,9 +1,33 @@
1
+ """AST utilities and import collection for code generation.
2
+
3
+ This module provides helper functions for building Python AST nodes
4
+ and utilities for collecting and organizing imports during code generation.
5
+ """
6
+
1
7
  import ast
2
8
  import keyword
3
9
  from collections.abc import Iterable
4
10
 
# Reserved words of the running interpreter (from ``keyword.kwlist``); these
# cannot be used as identifiers in generated code.
PYTHON_KEYWORDS = {*keyword.kwlist}
6
12
 
# Names exported by ``from ... import *``; the underscore-prefixed AST helpers
# are listed explicitly so star-imports still pick them up.
__all__ = [
    # AST helpers
    '_name', '_attr', '_subscript', '_union_expr', '_optional_expr',
    '_argument', '_assign', '_import', '_call', '_func', '_async_func', '_all',
    # Import collection
    'ImportCollector',
]
30
+
7
31
 
def _name(name: str) -> ast.Name:
    """Return a ``Name`` node that loads the identifier *name*."""
    load_ctx = ast.Load()
    return ast.Name(name, load_ctx)
@@ -11,17 +35,27 @@ def _name(name: str) -> ast.Name:
11
35
 
12
36
  def _attr(value: str | ast.expr, attr: str) -> ast.Attribute:
13
37
  return ast.Attribute(
14
- value=_name(value) if isinstance(value, str) else value, attr=attr
38
+ value=_name(value) if isinstance(value, str) else value,
39
+ attr=attr,
40
+ ctx=ast.Load(),
15
41
  )
16
42
 
17
43
 
def _subscript(generic: str, inner: ast.expr) -> ast.Subscript:
    """Return a load-context ``generic[inner]`` subscript node."""
    container = _name(generic)
    return ast.Subscript(
        value=container,
        slice=inner,
        ctx=ast.Load(),
    )
20
46
 
21
47
 
def _union_expr(types: list[ast.expr]) -> ast.expr:
    """Combine type expressions into a PEP 604 union (``A | B | C``).

    A single type is returned unchanged (same node object). Several types are
    folded left-to-right into nested ``BinOp(BitOr)`` nodes, i.e.
    ``(A | B) | C``.

    Raises:
        ValueError: If *types* is empty.
    """
    if len(types) == 0:
        raise ValueError('_union_expr requires at least one type')
    union, *remaining = types
    for operand in remaining:
        union = ast.BinOp(left=union, op=ast.BitOr(), right=operand)
    return union
25
59
 
26
60
 
27
61
  def _optional_expr(inner: ast.expr) -> ast.Subscript:
@@ -36,6 +70,12 @@ def _argument(name: str, value: ast.expr | None = None) -> ast.arg:
36
70
 
37
71
 
38
72
  def _assign(target: ast.expr, value: ast.expr) -> ast.Assign:
73
+ # Ensure target has Store context
74
+ if isinstance(target, ast.Name):
75
+ target = ast.Name(id=target.id, ctx=ast.Store())
76
+ elif isinstance(target, ast.Attribute):
77
+ # For attributes, only the outermost needs Store context
78
+ target.ctx = ast.Store()
39
79
  return ast.Assign(
40
80
  targets=[target],
41
81
  value=value,
@@ -119,3 +159,92 @@ def _all(names: Iterable[str]) -> ast.Assign:
119
159
  elts=[ast.Constant(value=name) for name in names], ctx=ast.Load()
120
160
  ),
121
161
  )
162
+
163
+
164
+ # =============================================================================
165
+ # Import Collection
166
+ # =============================================================================
167
+
168
+
class ImportCollector:
    """Collects and manages imports for generated Python code.

    This class provides a centralized way to collect imports from various
    sources during code generation and convert them to AST import statements.
    Names are deduplicated per module, and ``to_ast`` emits modules and names
    in a deterministic order.

    Example:
        >>> collector = ImportCollector()
        >>> collector.add_imports({'typing': {'List', 'Dict'}})
        >>> collector.add_imports({'typing': {'Optional'}})
        >>> imports = collector.to_ast()
        >>> # Returns [ImportFrom(module='typing', names=['Dict', 'List', 'Optional'])]
    """

    def __init__(self):
        """Initialize an empty import collector."""
        # module name -> names imported from it (a set, so duplicates collapse)
        self._imports: dict[str, set[str]] = {}

    def add_imports(self, imports: dict[str, set[str]]) -> None:
        """Add imports from a dictionary mapping modules to sets of names.

        Modules whose name collection is empty are ignored. Registering them
        would make ``has_imports`` report True and would render as the invalid
        statement ``from module import``.

        Args:
            imports: Dictionary mapping module names to sets of imported names.
                Example: {'typing': {'List', 'Dict'}, 'pydantic': {'BaseModel'}}
        """
        for module, names in imports.items():
            if names:
                self._imports.setdefault(module, set()).update(names)

    def add_import(self, module: str, name: str) -> None:
        """Add a single import.

        Args:
            module: The module to import from (e.g., 'typing', 'pydantic').
            name: The name to import (e.g., 'List', 'BaseModel').
        """
        self._imports.setdefault(module, set()).add(name)

    def to_ast(self, reverse_sort: bool = True) -> list[ast.ImportFrom]:
        """Convert collected imports to AST ImportFrom statements.

        Args:
            reverse_sort: If True (the default), modules are emitted in reverse
                alphabetical order; if False, in alphabetical order. In both
                cases a ``__future__`` import is emitted first, because Python
                rejects ``from __future__ import ...`` after other imports.

        Returns:
            One absolute (level 0) ``ast.ImportFrom`` per module that has at
            least one name. Names within each statement are sorted
            alphabetically.
        """
        modules = sorted(
            (module for module, names in self._imports.items() if names),
            reverse=reverse_sort,
        )
        # Stable sort: moves '__future__' (if present) to the front and keeps
        # the relative order of every other module.
        modules.sort(key=lambda module: module != '__future__')
        return [
            ast.ImportFrom(
                module=module,
                names=[
                    ast.alias(name=name, asname=None)
                    for name in sorted(self._imports[module])
                ],
                level=0,
            )
            for module in modules
        ]

    def has_imports(self) -> bool:
        """Check if any imports have been collected.

        Returns:
            True if imports exist, False otherwise.
        """
        return bool(self._imports)

    def clear(self) -> None:
        """Clear all collected imports."""
        self._imports.clear()

    def get_modules(self) -> set[str]:
        """Get the set of all modules that have been imported.

        Returns:
            Set of module names (a new set; mutating it does not affect the
            collector).
        """
        return set(self._imports.keys())