typespec-rust-emitter 0.3.0 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/emitter.ts CHANGED
@@ -245,22 +245,45 @@ function getOperationParameters(
245
245
 
246
246
  for (const [propName, prop] of model.properties) {
247
247
  const decorators = (prop as any).decorators;
248
+
249
+ // Skip body parameters - they are handled separately
250
+ let isBody = false;
251
+ if (decorators) {
252
+ for (const key of Object.keys(decorators)) {
253
+ const decorator = decorators[key];
254
+ const name = getDecoratorName(decorator);
255
+ if (name === "body" || name === "bodyRoot") {
256
+ isBody = true;
257
+ break;
258
+ }
259
+ }
260
+ }
261
+ if (isBody) continue;
262
+
248
263
  let location: "path" | "query" | "header" | "cookie" = "query";
249
264
  let rustName = toRustIdent(propName);
250
265
 
251
266
  if (decorators) {
252
- if (decorators["$path"]) {
253
- location = "path";
254
- } else if (decorators["$query"]) {
255
- location = "query";
256
- } else if (decorators["$header"]) {
257
- location = "header";
258
- const headerDec = decorators["$header"];
259
- if (headerDec?.value) {
260
- rustName = headerDec.value;
267
+ for (const key of Object.keys(decorators)) {
268
+ const decorator = decorators[key];
269
+ const name = getDecoratorName(decorator);
270
+ if (name === "path") {
271
+ location = "path";
272
+ break;
273
+ } else if (name === "query") {
274
+ location = "query";
275
+ break;
276
+ } else if (name === "header") {
277
+ location = "header";
278
+ const headerVal = getDecoratorArgValue(decorator, 0);
279
+ if (headerVal) {
280
+ rustName = toRustIdent(headerVal);
281
+ }
282
+ break;
283
+ } else if (name === "cookie") {
284
+ location = "cookie";
285
+ break;
261
286
  }
262
- } else if (decorators["$cookie"]) {
263
- location = "cookie";
264
287
  }
265
288
  }
266
289
 
@@ -286,8 +309,14 @@ function getOperationParameters(
286
309
  function getOperationBody(operation: Operation): ModelProperty | undefined {
287
310
  for (const [_propName, prop] of operation.parameters.properties) {
288
311
  const decorators = (prop as any).decorators;
289
- if (decorators?.["$body"] || decorators?.["$bodyRoot"]) {
290
- return prop;
312
+ if (!decorators) continue;
313
+
314
+ for (const key of Object.keys(decorators)) {
315
+ const decorator = decorators[key];
316
+ const name = getDecoratorName(decorator);
317
+ if (name === "body" || name === "bodyRoot") {
318
+ return prop;
319
+ }
291
320
  }
292
321
  }
293
322
  return undefined;
@@ -491,6 +520,8 @@ function generateServerTrait(
491
520
  parts.push(`use super::types::*;
492
521
  use async_trait::async_trait;
493
522
  use axum::{http::StatusCode, Json};
523
+ use axum::extract::Path;
524
+ use axum::Extension;
494
525
  use eyre::Result;
495
526
 
496
527
  #[async_trait]
@@ -512,75 +543,58 @@ pub trait Server: Send + Sync {
512
543
  parts.push(` ${formatDoc(opInfo.doc)}`);
513
544
  }
514
545
 
515
- const requestName = `${nsName}${toPascalCase(opInfo.name)}Request`;
516
546
  const responseName = `${nsName}${toPascalCase(opInfo.name)}Response`;
517
547
  const fnName = toRustIdent(`${nsName}_${opInfo.name}`);
518
548
  const isProtected = hasAuthDecorator(op);
519
549
 
520
- if (isProtected) {
521
- parts.push(
522
- ` async fn ${fnName}(&self, claims: Self::Claims, request: ${requestName}) -> Result<${responseName}>;`,
523
- );
524
- } else {
525
- parts.push(
526
- ` async fn ${fnName}(&self, request: ${requestName}) -> Result<${responseName}>;`,
527
- );
528
- }
529
- }
530
- }
531
-
532
- parts.push("}");
533
-
534
- return parts.join("\n");
535
- }
550
+ // Build parameter list for the trait method
551
+ const paramParts: string[] = [];
536
552
 
537
- function generateRequestStructs(
538
- program: Program,
539
- namespaceGroups: { namespace: Namespace; operations: Operation[] }[],
540
- anonymousEnums: Map<string, AnonymousStringLiteralUnion>,
541
- ): string {
542
- const parts: string[] = [];
543
-
544
- for (const group of namespaceGroups) {
545
- const nsName = toPascalCase(
546
- group.namespace.name.replace(/[^a-zA-Z0-9_]/g, "_"),
547
- );
548
-
549
- for (const op of group.operations) {
550
- const opInfo = emitOperationInfo(program, op, "", anonymousEnums);
551
- if (!opInfo) continue;
552
-
553
- const params = opInfo.parameters;
554
- const requestName = `${nsName}${toPascalCase(opInfo.name)}Request`;
555
-
556
- const fields: string[] = [];
557
-
558
- for (const param of params) {
559
- const rustType = param.optional
560
- ? `Option<${param.rustType}>`
561
- : param.rustType;
562
- fields.push(` #[serde(rename = "${param.name}", flatten)]`);
563
- fields.push(` pub ${param.rustName}: ${rustType},`);
553
+ // Add path parameters
554
+ for (const param of opInfo.parameters) {
555
+ if (param.location === "path") {
556
+ paramParts.push(`${param.rustName}: ${param.rustType}`);
557
+ }
564
558
  }
565
559
 
560
+ // Add body parameter
566
561
  if (opInfo.body) {
567
562
  const bodyType = getRustTypeForProperty(
568
563
  opInfo.body.type,
569
564
  program,
570
565
  anonymousEnums,
571
566
  );
572
- fields.push(` #[serde(rename = "body")]`);
573
- fields.push(` pub body: ${bodyType.type},`);
567
+ paramParts.push(`body: ${bodyType.type}`);
574
568
  }
575
569
 
576
- parts.push(`#[derive(Debug, Clone, serde::Deserialize)]
577
- pub struct ${requestName} {
578
- ${fields.join("\n")}
579
- }
580
- `);
570
+ const paramsStr = paramParts.join(", ");
571
+
572
+ if (isProtected) {
573
+ if (paramsStr) {
574
+ parts.push(
575
+ ` async fn ${fnName}(&self, claims: Self::Claims, ${paramsStr}) -> Result<${responseName}>;`,
576
+ );
577
+ } else {
578
+ parts.push(
579
+ ` async fn ${fnName}(&self, claims: Self::Claims) -> Result<${responseName}>;`,
580
+ );
581
+ }
582
+ } else {
583
+ if (paramsStr) {
584
+ parts.push(
585
+ ` async fn ${fnName}(&self, ${paramsStr}) -> Result<${responseName}>;`,
586
+ );
587
+ } else {
588
+ parts.push(
589
+ ` async fn ${fnName}(&self) -> Result<${responseName}>;`,
590
+ );
591
+ }
592
+ }
581
593
  }
582
594
  }
583
595
 
596
+ parts.push("}");
597
+
584
598
  return parts.join("\n");
585
599
  }
586
600
 
@@ -672,83 +686,57 @@ function generateRouter(
672
686
 
673
687
  const handlerFnName = toRustIdent(`${nsName}_${opInfo.name}`);
674
688
  const traitFnName = handlerFnName;
675
- const requestName = `${nsName}${toPascalCase(opInfo.name)}Request`;
676
689
  const isProtected = hasAuthDecorator(op);
677
690
 
678
691
  const pathParams = opInfo.parameters.filter((p) => p.location === "path");
679
- const queryParams = opInfo.parameters.filter(
680
- (p) => p.location === "query",
681
- );
682
692
  const hasPathParams = pathParams.length > 0;
683
- const hasQueryParams = queryParams.length > 0;
684
693
  const hasBody = !!opInfo.body;
685
694
 
686
- // Build extractor lines and request construction expression
695
+ // Build extractor lines and server method call arguments
696
+ // IMPORTANT: axum requires specific extractor order:
697
+ // State -> Extension -> Path -> Query -> Json -> Body
687
698
  const extractorLines: string[] = [];
688
- let requestExpr = "";
699
+ const serverArgs: string[] = [];
700
+
701
+ // State is always first (added in handler template)
702
+
703
+ // Extension (claims) comes after State
704
+ if (isProtected) {
705
+ extractorLines.push(` Extension(claims): Extension<S::Claims>,`);
706
+ serverArgs.push(`claims`);
707
+ }
689
708
 
709
+ // Path params come after Extension
690
710
  if (hasPathParams) {
691
711
  const pathTypes = pathParams.map((p) => p.rustType).join(", ");
692
712
  const pathFields = pathParams.map((p) => p.rustName).join(", ");
693
713
  if (pathParams.length === 1) {
694
- extractorLines.push(
695
- ` axum::extract::Path(${pathFields}): axum::extract::Path<${pathTypes}>,`,
696
- );
714
+ extractorLines.push(` Path(${pathFields}): Path<${pathTypes}>,`);
697
715
  } else {
698
716
  extractorLines.push(
699
- ` axum::extract::Path((${pathFields})): axum::extract::Path<(${pathTypes})>,`,
717
+ ` Path((${pathFields})): Path<(${pathTypes})>,`,
700
718
  );
701
719
  }
720
+ // Add path params to server method args
721
+ for (const param of pathParams) {
722
+ serverArgs.push(param.rustName);
723
+ }
702
724
  }
703
725
 
704
- if (hasQueryParams) {
705
- extractorLines.push(
706
- ` axum::extract::Query(query): axum::extract::Query<${requestName}>,`,
707
- );
708
- } else if (hasBody) {
709
- extractorLines.push(
710
- ` axum::Json(body): axum::Json<${requestName}Body>,`,
711
- );
712
- }
713
-
714
- if (isProtected) {
715
- extractorLines.push(
716
- ` axum::Extension(claims): axum::Extension<S::Claims>,`,
726
+ // Json body comes last
727
+ if (hasBody && opInfo.body) {
728
+ const bodyType = getRustTypeForProperty(
729
+ opInfo.body.type,
730
+ program,
731
+ anonymousEnums,
717
732
  );
733
+ extractorLines.push(` Json(payload): Json<${bodyType.type}>,`);
734
+ serverArgs.push(`payload`);
718
735
  }
719
736
 
720
- // Build request struct expression
721
- if (hasOnlyPathParams(hasPathParams, hasQueryParams, hasBody)) {
722
- const pathAssignments = pathParams
723
- .map((p) => `${p.rustName},`)
724
- .join(" ");
725
- requestExpr = `${requestName} { ${pathAssignments} }`;
726
- } else if (hasQueryParams) {
727
- if (hasPathParams) {
728
- const pathAssignments = pathParams
729
- .map((p) => `${p.rustName},`)
730
- .join(" ");
731
- requestExpr = `${requestName} { ${pathAssignments} ..query }`;
732
- } else {
733
- requestExpr = "query";
734
- }
735
- } else if (hasBody) {
736
- if (hasPathParams) {
737
- const pathAssignments = pathParams
738
- .map((p) => `${p.rustName},`)
739
- .join(" ");
740
- requestExpr = `${requestName} { ${pathAssignments} body }`;
741
- } else {
742
- requestExpr = `${requestName} { body }`;
743
- }
744
- } else {
745
- requestExpr = `${requestName} {}`;
746
- }
747
-
748
- // Server method call
749
- const serverCall = isProtected
750
- ? `service.${traitFnName}(claims, ${requestExpr}).await`
751
- : `service.${traitFnName}(${requestExpr}).await`;
737
+ // Build server method call
738
+ const serverArgsStr = serverArgs.join(", ");
739
+ const serverCall = `service.${traitFnName}(${serverArgsStr}).await`;
752
740
 
753
741
  // All handlers use <S> generics, Claims is now an associated type
754
742
  let handlerCode = `pub async fn ${handlerFnName}_handler<S>(
@@ -1172,9 +1160,7 @@ function emitModel(
1172
1160
  parts.push('#[error("{code}: {message}")]');
1173
1161
  }
1174
1162
 
1175
- const customAttrs = (model as any)[rustAttrKey] as
1176
- | RustAttrInfo
1177
- | undefined;
1163
+ const customAttrs = (model as any)[rustAttrKey] as RustAttrInfo | undefined;
1178
1164
  if (customAttrs) {
1179
1165
  for (const attr of customAttrs.attrs) {
1180
1166
  parts.push(`#[${attr}]`);
@@ -1277,9 +1263,7 @@ function emitEnum(enumType: Enum): string {
1277
1263
  }
1278
1264
 
1279
1265
  if (isString) {
1280
- parts.push(
1281
- `#[derive(${allDerives.join(", ")})]`,
1282
- );
1266
+ parts.push(`#[derive(${allDerives.join(", ")})]`);
1283
1267
  if (attrLines.length > 0) {
1284
1268
  parts.push(...attrLines);
1285
1269
  }
@@ -1291,9 +1275,7 @@ function emitEnum(enumType: Enum): string {
1291
1275
  parts.push(` ${variantName},`);
1292
1276
  }
1293
1277
  } else {
1294
- parts.push(
1295
- `#[derive(${allDerives.join(", ")})]`,
1296
- );
1278
+ parts.push(`#[derive(${allDerives.join(", ")})]`);
1297
1279
  if (attrLines.length > 0) {
1298
1280
  parts.push(...attrLines);
1299
1281
  }
@@ -1507,11 +1489,6 @@ export async function $onEmit(
1507
1489
  namespaceGroups,
1508
1490
  anonymousEnums,
1509
1491
  );
1510
- const requestStructs = generateRequestStructs(
1511
- context.program,
1512
- namespaceGroups,
1513
- anonymousEnums,
1514
- );
1515
1492
  const responseEnums = generateResponseEnums(
1516
1493
  context.program,
1517
1494
  namespaceGroups,
@@ -1523,12 +1500,7 @@ export async function $onEmit(
1523
1500
  anonymousEnums,
1524
1501
  );
1525
1502
 
1526
- const serverContent = [
1527
- serverTrait,
1528
- requestStructs,
1529
- responseEnums,
1530
- router,
1531
- ].join("\n");
1503
+ const serverContent = [serverTrait, responseEnums, router].join("\n");
1532
1504
 
1533
1505
  await emitFile(context.program, {
1534
1506
  path: resolvePath(context.emitterOutputDir, "server.rs"),