aster-cli 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aster_cli/codegen.py ADDED
@@ -0,0 +1,828 @@
1
+ """
2
+ aster_cli.codegen -- Generate typed client libraries from Aster manifests.
3
+
4
+ Implements the algorithm described in docs/_internal/aster-client-generation.md.
5
+ Currently supports Python output only.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import os
11
+ import re
12
+ import textwrap
13
+ from datetime import datetime, timezone
14
+ from typing import Any
15
+
16
+
17
+ # ── Type mapping ─────────────────────────────────────────────────────────────
18
+
19
# Manifest type spelling -> Python annotation. Several wire-level aliases
# collapse onto one builtin; the bare "optional" spelling carries no inner
# type, so it is read as an optional string.
_PY_TYPE_MAP: dict[str, str] = {
    **dict.fromkeys(("str", "string"), "str"),
    **dict.fromkeys(("int", "int32", "int64"), "int"),
    **dict.fromkeys(("float", "float32", "float64", "double"), "float"),
    **dict.fromkeys(("bool", "boolean"), "bool"),
    "bytes": "bytes",
    "optional": "Optional[str]",
}

# Zero-value default (as Python source text) for each builtin annotation.
_PY_DEFAULTS: dict[str, str] = dict(
    str='""',
    int="0",
    float="0.0",
    bool="False",
    bytes='b""',
)

# Field names whose manifest type is the bare "optional" spelling but whose
# real annotation is known; consulted by _field_py_type.
_OPTIONAL_FIELD_HINTS: dict[str, str] = {
    **dict.fromkeys(("bio", "display_name", "rate_limit"), "Optional[str]"),
    "recovery_codes": "Optional[list[str]]",
    **dict.fromkeys(("replacement", "scope_node_id", "url"), "Optional[str]"),
}

# Aster wire types that a field may name directly in its "type"; maps the
# bare type name to (wire tag, field specs). Every field is optional and
# carries the listed default.
_KNOWN_WIRE_TYPES: dict[str, tuple[str, list[dict[str, Any]]]] = {
    type_name: (
        f"aster/{type_name}",
        [
            {"name": fname, "type": ftype, "required": False, "default": fdefault}
            for fname, ftype, fdefault in spec
        ],
    )
    for type_name, spec in (
        (
            "DelegationStatement",
            (
                ("authority", "str", "consumer"),
                ("mode", "str", "open"),
                ("token_ttl", "int", 300),
                ("rate_limit", "Optional", None),
                ("roles", "list[str]", None),
            ),
        ),
        (
            "SigningKeyAttestation",
            (
                ("signing_pubkey", "str", ""),
                ("key_id", "str", ""),
                ("valid_from", "int", 0),
                ("valid_until", "int", 0),
                ("root_signature", "str", ""),
            ),
        ),
    )
}
75
+
76
+
77
def _to_snake_case(name: str) -> str:
    """Convert a CamelCase name to a snake_case module/package name.

    ``"DiscoverEntry"`` -> ``"discover_entry"``, ``"HTTPServer"`` -> ``"http_server"``.
    Characters that cannot appear in a Python module name (for example the
    ``-`` or ``.`` of a handle used as the package namespace) are replaced
    with ``_``, so ``"my-org"`` yields the importable ``"my_org"``.
    """
    # Split "xYz" / "XYZabc" boundaries first, then lowercase.
    s1 = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    s2 = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
    # Any remaining non-word character would make the name unimportable.
    return re.sub(r"\W", "_", s2)
81
+
82
+
83
def _split_top_level(text: str, sep: str) -> list[str]:
    """Split *text* on *sep* characters that are not nested inside ``[...]``."""
    parts: list[str] = []
    depth = 0
    start = 0
    for i, ch in enumerate(text):
        if ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
        elif ch == sep and depth == 0:
            parts.append(text[start:i].strip())
            start = i + 1
    parts.append(text[start:].strip())
    return parts


def _generic_args(type_name: str, *generics: str) -> list[str] | None:
    """Return the top-level arguments when *type_name* is ``G[...]`` for a G in *generics*.

    The generic name is matched case-insensitively. The final ``]`` must close
    the first ``[`` (so ``list[a] | list[b]`` does not match). Returns None for
    any other shape.
    """
    head, bracket, rest = type_name.partition("[")
    if not bracket or head.lower() not in generics or not rest.endswith("]"):
        return None
    body = rest[:-1]
    depth = 0
    for ch in body:
        if ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
            if depth < 0:
                return None
    if depth != 0:
        return None
    return _split_top_level(body, ",")


def _py_type_str(type_name: str, known_types: dict[str, str]) -> str:
    """Convert a manifest type name to a Python type annotation string.

    Handles the following, recursively:
      - builtin spellings via _PY_TYPE_MAP;
      - ``list[X]`` and ``dict[K, V]``;
      - ``Optional[X]`` and top-level unions with None (``X | None`` or
        ``None | X``). Both are emitted as ``typing.Optional`` rather than a
        PEP 604 union, for pyfory compatibility;
      - generated classes, via *known_types*.

    Unknown class-like names (dotted or capitalised) become ``Any``, also
    inside containers. The generated type modules evaluate annotations at
    class creation, so an undefined name would raise NameError there. An
    empty name is ``Any``. Any other unrecognised spelling is returned
    unchanged.

    Args:
        type_name: The type name from the manifest (e.g., "str", "list[DiscoverEntry]").
        known_types: Map of display_name -> Python class name for generated types.
    """
    type_name = type_name.strip()
    if not type_name:
        return "Any"

    def _arg(name: str) -> str:
        # Inside a container, a generated class wins over a builtin spelling.
        return known_types[name] if name in known_types else _py_type_str(name, known_types)

    def _optional(inner: str) -> str:
        # Avoid Optional[Optional[X]].
        return inner if inner.startswith("Optional[") else f"Optional[{inner}]"

    # Top-level union containing None -> Optional[rest]. Unions without None
    # fall through unchanged, as before.
    members = _split_top_level(type_name, "|")
    if len(members) > 1:
        non_null = [m for m in members if m.lower() != "none"]
        if non_null and len(non_null) < len(members):
            return _optional(_py_type_str(" | ".join(non_null), known_types))

    lower = type_name.lower()
    if lower in _PY_TYPE_MAP:
        return _PY_TYPE_MAP[lower]

    args = _generic_args(type_name, "list")
    if args is not None and len(args) == 1 and args[0]:
        return f"list[{_arg(args[0])}]"

    args = _generic_args(type_name, "dict")
    if args is not None and len(args) == 2 and all(args):
        return f"dict[{_arg(args[0])}, {_arg(args[1])}]"

    # Optional[X] -- use typing.Optional, not PEP 604 union (pyfory compat)
    args = _generic_args(type_name, "optional")
    if args is not None and len(args) == 1 and args[0]:
        return _optional(_py_type_str(args[0], known_types))

    # Known generated type
    if type_name in known_types:
        return known_types[type_name]

    # Unknown complex type -- use Any to avoid NameErrors
    if "." in type_name or type_name[0].isupper():
        return "Any"

    return type_name
130
+
131
+
132
def _py_default_str(type_name: str, default: Any) -> str | None:
    """Return the Python source for a field default, or None if there is none.

    Args:
        type_name: Manifest type name of the field. It is only consulted when
            *default* is None, to pick a zero value or an empty-container
            factory for the type.
        default: Explicit default from the manifest (a JSON-shaped value), or None.

    Mutable containers MUST go through ``dataclasses.field(default_factory=...)``.
    Python 3.13 raises ValueError on ``tags: dict = {}`` (and the equivalent
    ``[]`` for list) at class creation time. Non-empty containers keep their
    contents through a ``lambda`` factory, so each instance gets a fresh copy.
    """
    if default is not None:
        if isinstance(default, (dict, list)):
            if default:
                return f"dataclasses.field(default_factory=lambda: {default!r})"
            factory = "dict" if isinstance(default, dict) else "list"
            return f"dataclasses.field(default_factory={factory})"
        if isinstance(default, str):
            return repr(default)
        if isinstance(default, bool):
            return "True" if default else "False"
        return str(default)

    lower = type_name.lower()
    if lower in _PY_DEFAULTS:
        return _PY_DEFAULTS[lower]

    if lower.startswith("list"):
        return "dataclasses.field(default_factory=list)"
    if lower.startswith("dict"):
        return "dataclasses.field(default_factory=dict)"
    if lower.startswith("optional") or "| none" in lower:
        return "None"

    return None
160
+
161
+
162
+ # ── V1 schema-aware helpers ──────────────────────────────────────────────────
163
+
164
# V1 schema scalar ``kind`` -> Python builtin. Only "string" is spelled
# differently; the other scalar kinds share their builtin's name.
_KIND_TO_PY: dict[str, str] = {"string": "str"} | {
    scalar: scalar for scalar in ("int", "float", "bool", "bytes")
}
171
+
172
+
173
def _py_type_from_field(f: dict[str, Any], known_types: dict[str, str]) -> str:
    """Derive the Python type annotation from a v1 schema field dict.

    Fields without a ``kind`` key are unversioned and fall back to the legacy
    path (_field_py_type). Supported kinds:
      - the scalars in _KIND_TO_PY;
      - "list", with ``item_kind`` / ``item_ref``;
      - "map", with ``key_kind`` / ``value_kind`` / ``value_ref``;
      - "ref", with ``ref_name``;
      - "enum", rendered as ``str``.
    Anything else is ``Any``. ``nullable`` wraps the result in ``Optional[...]``.

    A reference that is missing, or that names no generated class (not in
    *known_types*), becomes ``Any``. The generated type modules evaluate
    annotations at class creation, so emitting an undefined class name would
    raise NameError there.
    """
    kind = f.get("kind")
    if kind is None:
        return _field_py_type(f, known_types)

    def _ref(name: Any) -> str:
        # Only reference classes that are actually generated.
        return known_types.get(name, "Any") if name else "Any"

    def _element(elem_kind: Any, elem_ref: Any) -> str:
        # Annotation for a list item or map value.
        if elem_kind == "ref":
            return _ref(elem_ref)
        return _KIND_TO_PY.get(elem_kind, "Any")

    if kind in _KIND_TO_PY:
        base = _KIND_TO_PY[kind]
    elif kind == "list":
        base = f"list[{_element(f.get('item_kind', 'string'), f.get('item_ref'))}]"
    elif kind == "map":
        key_py = _KIND_TO_PY.get(f.get("key_kind", "string"), "str")
        val_py = _element(f.get("value_kind", "string"), f.get("value_ref"))
        base = f"dict[{key_py}, {val_py}]"
    elif kind == "ref":
        base = _ref(f.get("ref_name"))
    elif kind == "enum":
        base = "str"
    else:
        base = "Any"

    if f.get("nullable", False):
        return f"Optional[{base}]"
    return base
218
+
219
+
220
def _py_default_from_field(f: dict[str, Any]) -> str | None:
    """Derive the Python default expression (source text) from a v1 schema field dict.

    ``default_kind`` selects the form:
      - "value": the literal in ``default_value``, or "None" when it is absent.
        Lists and dicts go through ``dataclasses.field(default_factory=...)``,
        because Python 3.13 rejects ``tags: dict = {}`` at class creation.
        Non-empty ones keep their contents via a ``lambda`` factory.
      - "empty_list" / "empty_map": an empty-container factory.
      - "null": "None".
      - "none" or anything unrecognised: None (no default).

    Fields without ``default_kind`` are unversioned and fall back to
    _py_default_str(type, default).
    """
    dk = f.get("default_kind")
    if dk is None:
        return _py_default_str(f.get("type", "str"), f.get("default"))

    if dk == "empty_list":
        return "dataclasses.field(default_factory=list)"
    if dk == "empty_map":
        return "dataclasses.field(default_factory=dict)"
    if dk == "null":
        return "None"
    if dk != "value":
        # "none" and unknown kinds: the caller decides what to emit.
        return None

    dv = f.get("default_value")
    if dv is None:
        return "None"
    if isinstance(dv, (dict, list)):
        if dv:
            # Keep the declared contents; the lambda builds a fresh copy per instance.
            return f"dataclasses.field(default_factory=lambda: {dv!r})"
        factory = "dict" if isinstance(dv, dict) else "list"
        return f"dataclasses.field(default_factory={factory})"
    if isinstance(dv, str):
        return repr(dv)
    if isinstance(dv, bool):
        return "True" if dv else "False"
    return str(dv)
254
+
255
+
256
def _field_py_type(field: dict[str, Any], known_types: dict[str, str]) -> str:
    """Derive the Python annotation for a legacy (unversioned) manifest field.

    - A bare "optional" type takes its real annotation from
      _OPTIONAL_FIELD_HINTS when the field name is listed there.
    - A list field with an ``element_type`` is annotated ``list[<element>]``.
      The element resolves to the generated class when one is known.
      Otherwise it goes through _py_type_str, so builtin spellings such as
      "int64" are normalised and unknown class names become ``Any`` instead
      of an undefined name.
    - Everything else is delegated to _py_type_str. A missing, None or empty
      ``type`` is treated as "str".
    """
    raw_type = field.get("type") or "str"
    type_lower = str(raw_type).lower()
    if type_lower == "optional":
        hinted = _OPTIONAL_FIELD_HINTS.get(field.get("name", ""))
        if hinted:
            return hinted
    elem_type_name = field.get("element_type", "")
    if elem_type_name and type_lower.startswith("list"):
        elem_cls = known_types.get(elem_type_name) or _py_type_str(elem_type_name, known_types)
        return f"list[{elem_cls}]"
    return _py_type_str(raw_type, known_types)
267
+
268
+
269
+ # ── Type collection ──────────────────────────────────────────────────────────
270
+
271
+
272
class _TypeRecord:
    """A message type gathered from the manifests.

    Instance attributes:
        wire_tag: Fory XLANG wire tag ("" when the manifest gave none).
        display_name: Class name used for the generated dataclass.
        fields: Field dicts from the manifest describing the type.
        services: Names of the services whose methods reach this type. It
            starts empty and is filled in while the manifests are walked.
        is_request_response: True once the type is seen as a method's direct
            request or response, not only nested inside another type.
    """

    def __init__(self, wire_tag: str, display_name: str, fields: list[dict[str, Any]]):
        # Identity and schema, stored exactly as passed in.
        self.wire_tag, self.display_name, self.fields = wire_tag, display_name, fields
        # Usage bookkeeping, updated by the collector.
        self.services: set[str] = set()
        self.is_request_response = False
281
+
282
+
283
def collect_types(
    manifests: dict[str, dict[str, Any]],
) -> dict[str, _TypeRecord]:
    """Walk all manifests and collect every referenced type, keyed by wire tag.

    The following are collected for every method:
      - its request and response types, flagged ``is_request_response``;
      - list element types that carry a wire tag, found under the legacy
        ``element_*`` keys or the v1 ``item_*`` keys. Their
        ``element_fields`` are walked recursively;
      - v1 ``ref`` fields that carry a ``wire_tag``;
      - the built-in types named in _KNOWN_WIRE_TYPES.

    A type may first be seen as a bare reference, without fields or a name.
    It is completed from a later, fuller sighting, so its class is not
    generated empty.

    Args:
        manifests: Map of service_name -> manifest dict.

    Returns:
        Map of key -> _TypeRecord. The key is the wire tag, or the display
        name for types without one. Each record's ``services`` holds every
        service that reached the type.
    """
    types: dict[str, _TypeRecord] = {}
    # (service, id(field list)) pairs already walked. A field list shared by
    # several types, or one that refers back to itself, is walked only once
    # per service.
    walked: set[tuple[str, int]] = set()

    def _ensure_type(
        wire_tag: str, display_name: str, fields: list[dict], service: str
    ) -> _TypeRecord | None:
        if not wire_tag and not display_name:
            return None
        # Use wire_tag as key if available, otherwise display_name
        key = wire_tag or display_name
        rec = types.get(key)
        if rec is None:
            rec = types[key] = _TypeRecord(wire_tag, display_name, fields)
        else:
            if not rec.fields and fields:
                rec.fields = fields
            if not rec.display_name and display_name:
                rec.display_name = display_name
        rec.services.add(service)
        return rec

    def _collect_element_types(fields: list[dict], service: str) -> None:
        if not fields or (service, id(fields)) in walked:
            return
        walked.add((service, id(fields)))
        for f in fields:
            # List element types (legacy element_* or V1 item_* keys).
            elem_tag = f.get("item_wire_tag", "") or f.get("element_wire_tag", "")
            if elem_tag:
                elem_name = f.get("item_ref", "") or f.get("element_type", "")
                elem_fields = f.get("element_fields") or []
                _ensure_type(elem_tag, elem_name, elem_fields, service)
                # Element types can themselves contain further element types.
                _collect_element_types(elem_fields, service)
            # V1 schema: ref fields
            if f.get("kind") == "ref":
                ref_tag = f.get("wire_tag", "")
                if ref_tag:
                    _ensure_type(ref_tag, f.get("ref_name", ""), [], service)
            # Built-in Aster wire types named directly as the field type.
            field_type = f.get("type", "")
            if isinstance(field_type, str) and field_type in _KNOWN_WIRE_TYPES:
                wire_tag, nested_fields = _KNOWN_WIRE_TYPES[field_type]
                _ensure_type(wire_tag, field_type, nested_fields, service)

    for svc_name, manifest in manifests.items():
        for method in manifest.get("methods", []):
            for tag_key, name_key, fields_key in (
                ("request_wire_tag", "request_type", "fields"),
                ("response_wire_tag", "response_type", "response_fields"),
            ):
                tag = method.get(tag_key, "")
                name = method.get(name_key, "")
                fields = method.get(fields_key, [])
                if tag or name:
                    rec = _ensure_type(tag, name, fields, svc_name)
                    if rec is not None:
                        rec.is_request_response = True
                    _collect_element_types(fields, svc_name)

    return types
347
+
348
+
349
def classify_types(
    types: dict[str, _TypeRecord],
) -> tuple[dict[str, list[_TypeRecord]], list[_TypeRecord]]:
    """Split collected types into per-service and shared groups.

    A type reached by exactly one service is scoped to that service. Any
    other type (reached by several services, or defensively by none) is
    shared.

    Returns:
        ``(service_types, shared_types)``:
          - service_types: map of service name -> the records scoped to it;
          - shared_types: the shared records.
        Both preserve the iteration order of *types*.
    """
    service_types: dict[str, list[_TypeRecord]] = {}
    shared_types: list[_TypeRecord] = []
    for record in types.values():
        if len(record.services) != 1:
            shared_types.append(record)
            continue
        (owner,) = record.services
        service_types.setdefault(owner, []).append(record)
    return service_types, shared_types
370
+
371
+
372
+ # ── Python code generation ───────────────────────────────────────────────────
373
+
374
+
375
def _gen_header(source: str, contract_id: str) -> str:
    """Build the module docstring that heads every generated file.

    The docstring records:
      - the regeneration command, including *source*;
      - *contract_id*;
      - a UTC generation timestamp.

    It is emitted as a raw string (r-prefix). A Windows path in *source*,
    such as one under C:/Users written with backslashes, then stays literal
    text instead of becoming an invalid unicode escape that stops the
    generated module from compiling. Any triple double-quote sequence in the
    inputs is replaced by three single quotes, so it cannot end the
    docstring early.
    """
    ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    source = str(source).replace('"""', "'''")
    contract_id = str(contract_id).replace('"""', "'''")
    return (
        f'r"""\nAuto-generated by: aster contract gen-client {source} --lang python\n'
        f"Contract ID: {contract_id}\n"
        f"Generated at: {ts}\n"
        f'DO NOT EDIT -- regenerate with: aster contract gen-client {source} --lang python\n"""\n'
    )
383
+
384
+
385
def _gen_type_class(rec: _TypeRecord, known_types: dict[str, str]) -> str:
    """Render *rec* as the source text of a ``@dataclasses.dataclass`` class.

    The class is decorated ``@wire_type("<tag>")`` when the record has a wire
    tag. Every field receives a default: the schema-derived one, or ``None``
    when the schema yields none. A record without fields renders as a
    ``pass`` body.
    """
    header = [f'@wire_type("{rec.wire_tag}")'] if rec.wire_tag else []
    header += ["@dataclasses.dataclass", f"class {rec.display_name}:"]

    if not rec.fields:
        return "\n".join(header + ["    pass"])

    body: list[str] = []
    for field in rec.fields:
        name = field["name"]
        annotation = _py_type_from_field(field, known_types)
        value = _py_default_from_field(field)
        body.append(f"    {name}: {annotation} = {'None' if value is None else value}")
    return "\n".join(header + body)
408
+
409
+
410
def _gen_service_client(
    svc_name: str,
    manifest: dict[str, Any],
    known_types: dict[str, str],
    all_types: dict[str, _TypeRecord] | None = None,
) -> str:
    """Render the source of the typed ``<svc_name>Client(ServiceClient)`` class.

    The generated class contains:
      - the service name, version and contract id;
      - ``_wire_types``: the wire-tagged classes its methods reference,
        registered with ForyCodec in ``from_connection``;
      - one ``MethodInfo`` class attribute per method;
      - a ``from_connection`` classmethod;
      - one stub per method, shaped by its ``pattern`` (unary, server_stream,
        client_stream or bidi_stream). Other patterns get no stub.

    Args:
        svc_name: Service name; the class is named ``<svc_name>Client``.
        manifest: Service manifest (``contract_id``, ``version``, ``methods``).
        known_types: Map of manifest display name -> generated class name.
        all_types: Collected type records. Only records with a wire tag are
            eligible for ``_wire_types``. None means no wire types.
    """
    contract_id = manifest.get("contract_id", "")
    version = manifest.get("version", 1)
    methods = manifest.get("methods", [])
    client_cls = f"{svc_name}Client"

    def wire_type_names() -> list[str]:
        # Candidate classes: request/response types, list element types and
        # directly-named field types. Only wire-tagged ones are kept, because
        # Fory XLANG requires a tag.
        tagged = {rec.display_name for rec in (all_types or {}).values() if rec.wire_tag}
        candidates: list[str] = []
        for m in methods:
            candidates += [m.get("request_type", ""), m.get("response_type", "")]
            for list_key in ("fields", "response_fields"):
                for fld in m.get(list_key, []):
                    elem = fld.get("element_type", "")
                    if elem:
                        candidates.append(elem)
                    ftype = fld.get("type", "")
                    if isinstance(ftype, str) and ftype in known_types:
                        candidates.append(ftype)
        names: list[str] = []
        for cand in candidates:
            resolved = known_types.get(cand, cand)
            if resolved and resolved not in ("None", "Any") and resolved in tagged and resolved not in names:
                names.append(resolved)
        return names

    out: list[str] = [
        f"class {client_cls}(ServiceClient):",
        f'    """Typed client for {svc_name} v{version}.',
        "",
        "    Usage::",
        "",
        "        client = AsterClient(address='aster1...')",
        "        await client.connect()",
        f"        svc = await {client_cls}.from_connection(client)",
        '    """',
        "",
        f'    _service_name = "{svc_name}"',
        f"    _service_version = {version}",
        f'    _contract_id = "{contract_id}"',
        f"    _wire_types: list[type] = [{', '.join(wire_type_names())}]",
        "",
    ]

    # One MethodInfo class attribute per method.
    for m in methods:
        name = m["name"]
        out += [
            f"    _mi_{name} = MethodInfo(",
            f'        name="{name}",',
            f'        pattern="{m.get("pattern", "unary")}",',
            f"        request_type={known_types.get(m.get('request_type', ''), 'None')},",
            f"        response_type={known_types.get(m.get('response_type', ''), 'None')},",
            f"        timeout={m.get('timeout')},",
            f"        idempotent={m.get('idempotent', False)},",
            "    )",
        ]

    # from_connection: locate the service, pick a codec, build the transport.
    method_table = ", ".join(f'"{n}": cls._mi_{n}' for n in [m["name"] for m in methods])
    out += [
        "",
        "    @classmethod",
        f"    async def from_connection(cls, aster_client: AsterClient) -> {client_cls}:",
        f'        """Create a {client_cls} from a connected AsterClient."""',
        "        from aster.codec import ForyCodec",
        "        from aster.service import ServiceInfo",
        "        from aster.transport.iroh import IrohTransport",
        "        summary = None",
        "        for s in aster_client._services:",
        f'            if s.name == "{svc_name}":',
        "                summary = s",
        "                break",
        "        if summary is None:",
        f'            raise RuntimeError("{svc_name} not found on this connection")',
        '        rpc_addr = summary.channels.get("rpc", "")',
        "        if not rpc_addr:",
        f'            raise RuntimeError("{svc_name} has no rpc channel")',
        "        conn = await aster_client._rpc_conn_for(rpc_addr)",
        "        modes = list(getattr(summary, 'serialization_modes', None) or [])",
        "        if modes and 'xlang' not in modes and 'json' in modes:",
        "            from aster.json_codec import JsonProxyCodec",
        "            codec = JsonProxyCodec()",
        "        else:",
        "            codec = ForyCodec(types=cls._wire_types)",
        "        transport = IrohTransport(conn, codec=codec)",
        "        info = ServiceInfo(",
        f'            name="{svc_name}",',
        f"            version={version},",
        f"            methods={{{method_table}}},",
        "        )",
        "        return cls(transport, info, codec)",
    ]

    # Method stubs, shaped by the call pattern.
    for m in methods:
        name = m["name"]
        pattern = m.get("pattern", "unary")
        req_cls = known_types.get(m.get("request_type", ""), "Any")
        resp_cls = known_types.get(m.get("response_type", ""), "Any")
        single_request = f"        self, request: {req_cls}, *, timeout: float | None = None"
        out.append("")
        if pattern == "unary":
            out += [
                f"    async def {name}(",
                single_request,
                f"    ) -> {resp_cls}:",
                "        return await self._call_unary(",
                f"            method_info=self._mi_{name},",
                "            request=request,",
                "            timeout=timeout,",
                "        )",
            ]
        elif pattern == "server_stream":
            out += [
                f"    def {name}(",
                single_request,
                f"    ) -> AsyncIterator[{resp_cls}]:",
                "        return self._call_server_stream(",
                f"            method_info=self._mi_{name},",
                "            request=request,",
                "            timeout=timeout,",
                "        )",
            ]
        elif pattern == "client_stream":
            out += [
                f"    async def {name}(",
                f"        self, requests: AsyncIterator[{req_cls}], *, timeout: float | None = None",
                f"    ) -> {resp_cls}:",
                "        return await self._call_client_stream(",
                f"            method_info=self._mi_{name},",
                "            requests=requests,",
                "            timeout=timeout,",
                "        )",
            ]
        elif pattern == "bidi_stream":
            out += [
                f"    def {name}(",
                "        self, *, timeout: float | None = None",
                "    ) -> BidiChannel:",
                "        return self._call_bidi_stream(",
                f"            method_info=self._mi_{name},",
                "            timeout=timeout,",
                "        )",
            ]

    return "\n".join(out)
565
+
566
+
567
+ # ── Main generation entry point ──────────────────────────────────────────────
568
+
569
+
570
def _order_by_dependencies(
    recs: list[_TypeRecord], known_types: dict[str, str]
) -> list[_TypeRecord]:
    """Order *recs* by name, moving each class after the same-file classes it uses.

    Generated type modules evaluate dataclass annotations when the class body
    runs. A field such as ``items: list[Entry]`` therefore raises NameError
    unless ``Entry`` is defined earlier in the module. Mutually-referencing
    classes cannot be fixed by reordering; they keep the order the walk
    produces.
    """
    by_name = {rec.display_name: rec for rec in recs}
    ordered: list[_TypeRecord] = []
    started: set[int] = set()

    def visit(rec: _TypeRecord) -> None:
        if id(rec) in started:
            return
        started.add(id(rec))
        for f in rec.fields:
            annotation = _py_type_from_field(f, known_types)
            for dep_name in sorted(set(re.findall(r"[A-Za-z_]\w*", annotation))):
                dep = by_name.get(dep_name)
                if dep is not None and dep is not rec:
                    visit(dep)
        ordered.append(rec)

    for rec in sorted(recs, key=lambda r: r.display_name):
        visit(rec)
    return ordered


def generate_python_clients(
    manifests: dict[str, dict[str, Any]],
    out_dir: str,
    namespace: str,
    source: str = "",
) -> list[str]:
    """Generate Python client files from manifests.

    Layout under ``<out_dir>/<snake_case namespace>/``:
      - ``types/<shared_type>.py``: one module per type shared by several
        services;
      - ``types/<service>_v<N>.py``: the types used by a single service;
      - ``services/<service>_v<N>.py``: the typed ``<Service>Client`` class;
      - an empty ``__init__.py`` in each of the three directories.

    Args:
        manifests: Map of service_name -> manifest dict (from ContractManifest).
        out_dir: Root output directory.
        namespace: Handle or endpoint_id prefix for the package namespace.
        source: Source description for the header comment.

    Returns:
        Paths of the generated type and service modules (the ``__init__.py``
        files are not listed).
    """
    generated: list[str] = []

    # Step 2-3: Collect and classify types
    all_types = collect_types(manifests)
    service_types, shared_types = classify_types(all_types)

    # Build known_types map: display_name -> display_name (identity for now)
    known_types: dict[str, str] = {rec.display_name: rec.display_name for rec in all_types.values()}

    # Create directory structure
    ns_dir = os.path.join(out_dir, _to_snake_case(namespace))
    types_dir = os.path.join(ns_dir, "types")
    services_dir = os.path.join(ns_dir, "services")
    os.makedirs(types_dir, exist_ok=True)
    os.makedirs(services_dir, exist_ok=True)

    # Contract ID for header (use first manifest's)
    first_manifest = next(iter(manifests.values()), {})
    contract_id = first_manifest.get("contract_id", "")

    # Common imports for type files
    _type_imports = "\nimport dataclasses\nfrom typing import Any, Optional\n\nfrom aster.codec import wire_type\n"

    # Step 4a: Generate shared type files. A shared type may reference other
    # shared types, which live in their own modules and must be imported.
    for rec in sorted(shared_types, key=lambda r: r.display_name):
        fpath = os.path.join(types_dir, _to_snake_case(rec.display_name) + ".py")
        content = _gen_header(source, contract_id) + _type_imports
        for imp_name, imp_module in sorted(_collect_shared_imports([rec], shared_types, known_types)):
            if imp_name != rec.display_name:  # never import a module into itself
                content += f"from .{imp_module} import {imp_name}\n"
        content += "\n\n" + _gen_type_class(rec, known_types) + "\n"
        _write_file(fpath, content)
        generated.append(fpath)

    # Step 4b: Generate service-scoped type files
    for svc_name, manifest in sorted(manifests.items()):
        svc_recs = service_types.get(svc_name, [])
        if not svc_recs:
            continue
        svc_contract_id = manifest.get("contract_id", contract_id)
        fname = _to_snake_case(svc_name) + "_v" + str(manifest.get("version", 1)) + ".py"
        fpath = os.path.join(types_dir, fname)

        content = _gen_header(source, svc_contract_id) + _type_imports

        # Import shared types referenced by this service's types
        shared_imports = _collect_shared_imports(svc_recs, shared_types, known_types)
        for imp_name, imp_file in sorted(shared_imports):
            content += f"from .{imp_file} import {imp_name}\n"

        # Referenced classes first: annotations are evaluated at class creation.
        for rec in _order_by_dependencies(svc_recs, known_types):
            content += "\n\n" + _gen_type_class(rec, known_types)

        content += "\n"
        _write_file(fpath, content)
        generated.append(fpath)

    # Step 5: Generate service client files
    for svc_name, manifest in sorted(manifests.items()):
        svc_contract_id = manifest.get("contract_id", contract_id)
        version = manifest.get("version", 1)
        fname = _to_snake_case(svc_name) + "_v" + str(version) + ".py"
        fpath = os.path.join(services_dir, fname)

        content = _gen_header(source, svc_contract_id)
        content += "\nfrom __future__ import annotations\n\n"
        content += "from collections.abc import AsyncIterator\n"
        content += "from typing import TYPE_CHECKING\n\n"
        content += "from aster.client import ServiceClient\n"
        content += "from aster.service import MethodInfo\n"
        content += "from aster.transport.base import BidiChannel\n\n"
        content += "if TYPE_CHECKING:\n"
        content += "    from aster.runtime import AsterClient\n"

        # Import all types used by this service's methods
        type_imports = _collect_service_type_imports(svc_name, manifest, all_types, shared_types)
        for imp_name, imp_module in sorted(type_imports):
            content += f"from ..types.{imp_module} import {imp_name}\n"

        content += "\n\n"
        content += _gen_service_client(svc_name, manifest, known_types, all_types)
        content += "\n"

        _write_file(fpath, content)
        generated.append(fpath)

    # Step 6: Generate __init__.py files
    _write_init(ns_dir, manifests, namespace, generated)
    _write_init(types_dir, {}, "", [])
    _write_init(services_dir, {}, "", [])

    return generated
682
+
683
+
684
def _collect_shared_imports(
    svc_recs: list[_TypeRecord],
    shared_types: list[_TypeRecord],
    known_types: dict[str, str],
) -> list[tuple[str, str]]:
    """Find the shared types referenced by the fields of *svc_recs*.

    A field can name another type through several keys:
      - legacy wire tag: ``element_wire_tag``;
      - v1 wire tags: ``item_wire_tag``, ``wire_tag``;
      - legacy names: ``type``, ``element_type``;
      - v1 names: ``item_ref``, ``ref_name``, ``value_ref``.

    Every name-like token inside a type expression counts, so ``list[X]``,
    ``Optional[X]`` and ``dict[str, X]`` all reference ``X``. Each shared
    type lives in its own module (``types/<snake_name>.py``), so each hit
    becomes an import.

    Args:
        svc_recs: Records whose fields are scanned.
        shared_types: All shared records; only these can be imported.
        known_types: Accepted for call-site compatibility; not used.

    Returns:
        ``(class_name, module_name)`` pairs in first-seen order, without
        duplicates. A record never imports its own name.
    """
    shared_names = {r.display_name for r in shared_types}
    shared_by_tag = {r.wire_tag: r for r in shared_types if r.wire_tag}
    imports: list[tuple[str, str]] = []  # (class_name, module_name)
    seen: set[str] = set()

    def _add(name: str, owner: str) -> None:
        if name in shared_names and name != owner and name not in seen:
            seen.add(name)
            imports.append((name, _to_snake_case(name)))

    for rec in svc_recs:
        for f in rec.fields:
            # References by wire tag.
            for tag_key in ("element_wire_tag", "item_wire_tag", "wire_tag"):
                tag = f.get(tag_key)
                if isinstance(tag, str) and tag in shared_by_tag:
                    _add(shared_by_tag[tag].display_name, rec.display_name)
            # References by name, including names nested in type expressions.
            for name_key in ("type", "element_type", "item_ref", "ref_name", "value_ref"):
                value = f.get(name_key)
                if isinstance(value, str):
                    for token in re.findall(r"[A-Za-z_]\w*", value):
                        _add(token, rec.display_name)

    return imports
717
+
718
+
719
def _collect_service_type_imports(
    svc_name: str,
    manifest: dict[str, Any],
    all_types: dict[str, _TypeRecord],
    shared_types: list[_TypeRecord],
) -> list[tuple[str, str]]:
    """Collect all type imports needed by a service client file.

    The imported names are:
      - each method's request and response type;
      - the ``element_type`` of each request/response field;
      - field types that name a generated class, optionally wrapped in
        ``Optional[...]``.

    Only names of generated types (display names in *all_types*) are
    imported. A builtin element type such as ``str``, or a type that was
    never generated, has no module to import from. Shared types come from
    their own module; all others come from the service's
    ``<svc>_v<version>`` types module.

    Returns:
        ``(class_name, module_name)`` pairs in first-seen order, without duplicates.
    """
    imports: list[tuple[str, str]] = []  # (class_name, module_name)
    seen: set[str] = set()
    version = manifest.get("version", 1)
    svc_types_module = _to_snake_case(svc_name) + f"_v{version}"
    shared_names = {r.display_name for r in shared_types}
    known_display_names = {r.display_name for r in all_types.values()}

    def _add_import(display_name: Any) -> None:
        if not isinstance(display_name, str) or display_name in seen:
            return
        if display_name in ("", "None", "Any"):
            return
        # Skip generic-wrapper forms like ``AsyncIterator[CommandResult]``
        # that occasionally leak through from PEP 563 string annotations.
        # Importing them produces ``from x import AsyncIterator[Y]``, which
        # is invalid Python.
        if not display_name.isidentifier():
            return
        # Builtins ("str") and uncollected names exist in no types module.
        if display_name not in known_display_names:
            return
        seen.add(display_name)
        if display_name in shared_names:
            imports.append((display_name, _to_snake_case(display_name)))
        else:
            imports.append((display_name, svc_types_module))

    for method in manifest.get("methods", []):
        # Direct request/response types
        for name_key in ("request_type", "response_type"):
            _add_import(method.get(name_key, ""))
        # Element types from list fields, and directly-named field types
        for field_list_key in ("fields", "response_fields"):
            for f in method.get(field_list_key, []):
                _add_import(f.get("element_type", ""))
                field_type = f.get("type", "")
                if isinstance(field_type, str):
                    _add_import(field_type.removeprefix("Optional[").removesuffix("]"))

    return imports
764
+
765
+
766
def _write_init(dir_path: str, manifests: dict, namespace: str, generated: list[str]) -> None:
    """Create an empty ``__init__.py`` in *dir_path* so it is importable as a package.

    ``manifests``, ``namespace`` and ``generated`` are accepted for
    call-site compatibility but not used; the file is always empty.
    """
    _write_file(os.path.join(dir_path, "__init__.py"), "")
770
+
771
+
772
def _write_file(path: str, content: str) -> None:
    """Write *content* to *path* as UTF-8, creating any missing parent directories.

    The encoding is explicit so the output does not depend on the platform's
    locale encoding (for example cp1252 on Windows, which cannot encode every
    character a source path may contain). A bare filename has no parent
    directory to create; ``os.makedirs("")`` would raise.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(content)
777
+
778
+
779
+ # ── Usage snippet ────────────────────────────────────────────────────────────
780
+
781
+
782
def format_usage_snippet(
    out_dir: str,
    namespace: str,
    manifests: dict[str, dict[str, Any]],
    address: str = "",
) -> str:
    """Build the usage hint printed after client generation.

    The example uses the first service in *manifests* (``MyService`` when
    there is none). When that service has a unary method, its first one is
    shown as the example call. *address*, when given, replaces the
    ``aster1...`` placeholder.
    """
    ns_snake = _to_snake_case(namespace)
    svc_name = next(iter(manifests), "MyService")
    manifest = manifests.get(svc_name, {})
    version = manifest.get("version", 1)
    svc_module = f"{_to_snake_case(svc_name)}_v{version}"

    # First unary method, if any, serves as the example call.
    unary = next(
        (m for m in manifest.get("methods", []) if m.get("pattern", "unary") == "unary"),
        None,
    )
    example_method = None if unary is None else unary["name"]
    example_req = None if unary is None else unary.get("request_type", "Request")

    lines = [
        f"\nGenerated clients -> {out_dir}/{ns_snake}/\n",
        "Usage:\n",
        "    from aster import AsterClient",
        f"    from {ns_snake}.services.{svc_module} import {svc_name}Client",
    ]
    if example_req:
        lines.append(f"    from {ns_snake}.types.{svc_module} import {example_req}")

    lines += [
        "",
        f'    client = AsterClient(address="{address or "aster1..."}")',
        "    await client.connect()",
        f"    svc = await {svc_name}Client.from_connection(client)",
    ]
    if example_method and example_req:
        lines.append(f"    result = await svc.{example_method}({example_req}())")

    return "\n".join(lines)