typespec-rust-emitter 0.2.0 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/emitter.ts CHANGED
@@ -33,34 +33,45 @@ interface RustDeriveInfo {
33
33
  derives: string[];
34
34
  }
35
35
 
36
/**
 * Extra Rust attributes collected by the `@rustAttr` / `@rustAttrs` decorators.
 */
interface RustAttrInfo {
  /** Attribute bodies without the surrounding `#[` / `]`, in first-seen order. */
  attrs: Array<string>;
}

// Private property keys under which decorator state is attached directly to
// TypeSpec type objects; the emitters read the collected info back via these.
const rustDeriveKey = Symbol("rustDerive");
const rustAttrKey = Symbol("rustAttr");
37
42
 
38
43
/**
 * Implementation of the `@rustDerive` decorator.
 *
 * Records an extra Rust derive macro (e.g. `"strum::Display"`) on a model or
 * enum. Type emitters append the recorded entries to the generated
 * `#[derive(...)]` list. Adding the same derive twice is a no-op.
 *
 * Types declared in the built-in `TypeSpec` namespace (or any of its
 * sub-namespaces such as `TypeSpec.Http`) are left untouched.
 *
 * @param context - Decorator context supplied by the TypeSpec compiler.
 * @param target - The decorated type; must be a `Model` or an `Enum`.
 * @param derive - Path of the derive macro, without `#[derive(...)]`.
 */
export function $rustDerive(
  context: DecoratorContext,
  target: Type,
  derive: string,
) {
  if (target.kind !== "Model" && target.kind !== "Enum") {
    context.program.reportDiagnostic({
      code: "rust-derive-invalid-target",
      message: `@rustDerive can only be applied to models and enums`,
      severity: "error",
      target: context.decoratorTarget,
    });
    return;
  }

  // Models and enums both carry an optional `namespace`, so a single lookup
  // covers both kinds.
  const ns = target.namespace ? getNamespaceFullName(target.namespace) : "";

  // Skip only the TypeSpec library namespace itself and its children. A bare
  // `startsWith("TypeSpec")` would also swallow user namespaces such as
  // "TypeSpecDemo" and silently drop their derives.
  if (ns === "TypeSpec" || ns.startsWith("TypeSpec.")) {
    return;
  }

  // State lives on the type object under a private symbol key; the emitters
  // (e.g. `emitEnum`) read it back from the same key.
  const holder = target as typeof target & {
    [rustDeriveKey]?: RustDeriveInfo;
  };
  const info = holder[rustDeriveKey];
  if (!info) {
    holder[rustDeriveKey] = { derives: [derive] };
  } else if (!info.derives.includes(derive)) {
    info.derives.push(derive);
  }
}
@@ -75,6 +86,52 @@ export function $rustDerives(
75
86
  }
76
87
  }
77
88
 
89
/**
 * Implementation of the `@rustAttr` decorator.
 *
 * Records a raw Rust attribute body (e.g. `sqlx(type_name = "user")`) on a
 * model or enum. The emitters later wrap each entry as `#[...]` directly above
 * the generated type. Repeated attributes are de-duplicated, and first-seen
 * order is kept.
 *
 * Types declared in the built-in `TypeSpec` namespace (or any of its
 * sub-namespaces) are left untouched.
 *
 * @param context - Decorator context supplied by the TypeSpec compiler.
 * @param target - The decorated type; must be a `Model` or an `Enum`.
 * @param attr - Attribute body without the surrounding `#[` and `]`.
 */
export function $rustAttr(
  context: DecoratorContext,
  target: Type,
  attr: string,
) {
  if (target.kind !== "Model" && target.kind !== "Enum") {
    context.program.reportDiagnostic({
      code: "rust-attr-invalid-target",
      message: `@rustAttr can only be applied to models and enums`,
      severity: "error",
      target: context.decoratorTarget,
    });
    return;
  }

  // Models and enums both carry an optional `namespace`, so a single lookup
  // covers both kinds.
  const ns = target.namespace ? getNamespaceFullName(target.namespace) : "";

  // Skip only the TypeSpec library namespace itself and its children; a bare
  // prefix test would also match user namespaces like "TypeSpecDemo".
  if (ns === "TypeSpec" || ns.startsWith("TypeSpec.")) {
    return;
  }

  // State lives on the type object under a private symbol key; `emitModel`
  // and `emitEnum` read it back from the same key.
  const holder = target as typeof target & { [rustAttrKey]?: RustAttrInfo };
  const info = holder[rustAttrKey];
  if (!info) {
    holder[rustAttrKey] = { attrs: [attr] };
  } else if (!info.attrs.includes(attr)) {
    info.attrs.push(attr);
  }
}
124
+
125
/**
 * Implementation of the variadic `@rustAttrs` decorator.
 *
 * Each attribute is forwarded individually to `$rustAttr`, which performs the
 * target validation and de-duplication.
 *
 * @param context - Decorator context supplied by the TypeSpec compiler.
 * @param target - The decorated type.
 * @param attrs - Attribute bodies, each without the surrounding `#[` and `]`.
 */
export function $rustAttrs(
  context: DecoratorContext,
  target: Type,
  ...attrs: string[]
) {
  attrs.forEach((single) => $rustAttr(context, target, single));
}
134
+
78
135
  type HttpMethod = "GET" | "POST" | "PUT" | "PATCH" | "DELETE" | "HEAD";
79
136
 
80
137
  interface OperationInfo {
@@ -188,22 +245,45 @@ function getOperationParameters(
188
245
 
189
246
  for (const [propName, prop] of model.properties) {
190
247
  const decorators = (prop as any).decorators;
248
+
249
+ // Skip body parameters - they are handled separately
250
+ let isBody = false;
251
+ if (decorators) {
252
+ for (const key of Object.keys(decorators)) {
253
+ const decorator = decorators[key];
254
+ const name = getDecoratorName(decorator);
255
+ if (name === "body" || name === "bodyRoot") {
256
+ isBody = true;
257
+ break;
258
+ }
259
+ }
260
+ }
261
+ if (isBody) continue;
262
+
191
263
  let location: "path" | "query" | "header" | "cookie" = "query";
192
264
  let rustName = toRustIdent(propName);
193
265
 
194
266
  if (decorators) {
195
- if (decorators["$path"]) {
196
- location = "path";
197
- } else if (decorators["$query"]) {
198
- location = "query";
199
- } else if (decorators["$header"]) {
200
- location = "header";
201
- const headerDec = decorators["$header"];
202
- if (headerDec?.value) {
203
- rustName = headerDec.value;
267
+ for (const key of Object.keys(decorators)) {
268
+ const decorator = decorators[key];
269
+ const name = getDecoratorName(decorator);
270
+ if (name === "path") {
271
+ location = "path";
272
+ break;
273
+ } else if (name === "query") {
274
+ location = "query";
275
+ break;
276
+ } else if (name === "header") {
277
+ location = "header";
278
+ const headerVal = getDecoratorArgValue(decorator, 0);
279
+ if (headerVal) {
280
+ rustName = toRustIdent(headerVal);
281
+ }
282
+ break;
283
+ } else if (name === "cookie") {
284
+ location = "cookie";
285
+ break;
204
286
  }
205
- } else if (decorators["$cookie"]) {
206
- location = "cookie";
207
287
  }
208
288
  }
209
289
 
@@ -229,8 +309,14 @@ function getOperationParameters(
229
309
  function getOperationBody(operation: Operation): ModelProperty | undefined {
230
310
  for (const [_propName, prop] of operation.parameters.properties) {
231
311
  const decorators = (prop as any).decorators;
232
- if (decorators?.["$body"] || decorators?.["$bodyRoot"]) {
233
- return prop;
312
+ if (!decorators) continue;
313
+
314
+ for (const key of Object.keys(decorators)) {
315
+ const decorator = decorators[key];
316
+ const name = getDecoratorName(decorator);
317
+ if (name === "body" || name === "bodyRoot") {
318
+ return prop;
319
+ }
234
320
  }
235
321
  }
236
322
  return undefined;
@@ -434,6 +520,8 @@ function generateServerTrait(
434
520
  parts.push(`use super::types::*;
435
521
  use async_trait::async_trait;
436
522
  use axum::{http::StatusCode, Json};
523
+ use axum::extract::Path;
524
+ use axum::Extension;
437
525
  use eyre::Result;
438
526
 
439
527
  #[async_trait]
@@ -455,75 +543,58 @@ pub trait Server: Send + Sync {
455
543
  parts.push(` ${formatDoc(opInfo.doc)}`);
456
544
  }
457
545
 
458
- const requestName = `${nsName}${toPascalCase(opInfo.name)}Request`;
459
546
  const responseName = `${nsName}${toPascalCase(opInfo.name)}Response`;
460
547
  const fnName = toRustIdent(`${nsName}_${opInfo.name}`);
461
548
  const isProtected = hasAuthDecorator(op);
462
549
 
463
- if (isProtected) {
464
- parts.push(
465
- ` async fn ${fnName}(&self, claims: Self::Claims, request: ${requestName}) -> Result<${responseName}>;`,
466
- );
467
- } else {
468
- parts.push(
469
- ` async fn ${fnName}(&self, request: ${requestName}) -> Result<${responseName}>;`,
470
- );
471
- }
472
- }
473
- }
474
-
475
- parts.push("}");
476
-
477
- return parts.join("\n");
478
- }
479
-
480
- function generateRequestStructs(
481
- program: Program,
482
- namespaceGroups: { namespace: Namespace; operations: Operation[] }[],
483
- anonymousEnums: Map<string, AnonymousStringLiteralUnion>,
484
- ): string {
485
- const parts: string[] = [];
550
+ // Build parameter list for the trait method
551
+ const paramParts: string[] = [];
486
552
 
487
- for (const group of namespaceGroups) {
488
- const nsName = toPascalCase(
489
- group.namespace.name.replace(/[^a-zA-Z0-9_]/g, "_"),
490
- );
491
-
492
- for (const op of group.operations) {
493
- const opInfo = emitOperationInfo(program, op, "", anonymousEnums);
494
- if (!opInfo) continue;
495
-
496
- const params = opInfo.parameters;
497
- const requestName = `${nsName}${toPascalCase(opInfo.name)}Request`;
498
-
499
- const fields: string[] = [];
500
-
501
- for (const param of params) {
502
- const rustType = param.optional
503
- ? `Option<${param.rustType}>`
504
- : param.rustType;
505
- fields.push(` #[serde(rename = "${param.name}", flatten)]`);
506
- fields.push(` pub ${param.rustName}: ${rustType},`);
553
+ // Add path parameters
554
+ for (const param of opInfo.parameters) {
555
+ if (param.location === "path") {
556
+ paramParts.push(`${param.rustName}: ${param.rustType}`);
557
+ }
507
558
  }
508
559
 
560
+ // Add body parameter
509
561
  if (opInfo.body) {
510
562
  const bodyType = getRustTypeForProperty(
511
563
  opInfo.body.type,
512
564
  program,
513
565
  anonymousEnums,
514
566
  );
515
- fields.push(` #[serde(rename = "body")]`);
516
- fields.push(` pub body: ${bodyType.type},`);
567
+ paramParts.push(`body: ${bodyType.type}`);
517
568
  }
518
569
 
519
- parts.push(`#[derive(Debug, Clone, serde::Deserialize)]
520
- pub struct ${requestName} {
521
- ${fields.join("\n")}
522
- }
523
- `);
570
+ const paramsStr = paramParts.join(", ");
571
+
572
+ if (isProtected) {
573
+ if (paramsStr) {
574
+ parts.push(
575
+ ` async fn ${fnName}(&self, claims: Self::Claims, ${paramsStr}) -> Result<${responseName}>;`,
576
+ );
577
+ } else {
578
+ parts.push(
579
+ ` async fn ${fnName}(&self, claims: Self::Claims) -> Result<${responseName}>;`,
580
+ );
581
+ }
582
+ } else {
583
+ if (paramsStr) {
584
+ parts.push(
585
+ ` async fn ${fnName}(&self, ${paramsStr}) -> Result<${responseName}>;`,
586
+ );
587
+ } else {
588
+ parts.push(
589
+ ` async fn ${fnName}(&self) -> Result<${responseName}>;`,
590
+ );
591
+ }
592
+ }
524
593
  }
525
594
  }
526
595
 
596
+ parts.push("}");
597
+
527
598
  return parts.join("\n");
528
599
  }
529
600
 
@@ -615,83 +686,57 @@ function generateRouter(
615
686
 
616
687
  const handlerFnName = toRustIdent(`${nsName}_${opInfo.name}`);
617
688
  const traitFnName = handlerFnName;
618
- const requestName = `${nsName}${toPascalCase(opInfo.name)}Request`;
619
689
  const isProtected = hasAuthDecorator(op);
620
690
 
621
691
  const pathParams = opInfo.parameters.filter((p) => p.location === "path");
622
- const queryParams = opInfo.parameters.filter(
623
- (p) => p.location === "query",
624
- );
625
692
  const hasPathParams = pathParams.length > 0;
626
- const hasQueryParams = queryParams.length > 0;
627
693
  const hasBody = !!opInfo.body;
628
694
 
629
- // Build extractor lines and request construction expression
695
+ // Build extractor lines and server method call arguments
696
+ // IMPORTANT: axum requires specific extractor order:
697
+ // State -> Extension -> Path -> Query -> Json -> Body
630
698
  const extractorLines: string[] = [];
631
- let requestExpr = "";
699
+ const serverArgs: string[] = [];
632
700
 
701
+ // State is always first (added in handler template)
702
+
703
+ // Extension (claims) comes after State
704
+ if (isProtected) {
705
+ extractorLines.push(` Extension(claims): Extension<S::Claims>,`);
706
+ serverArgs.push(`claims`);
707
+ }
708
+
709
+ // Path params come after Extension
633
710
  if (hasPathParams) {
634
711
  const pathTypes = pathParams.map((p) => p.rustType).join(", ");
635
712
  const pathFields = pathParams.map((p) => p.rustName).join(", ");
636
713
  if (pathParams.length === 1) {
637
- extractorLines.push(
638
- ` axum::extract::Path(${pathFields}): axum::extract::Path<${pathTypes}>,`,
639
- );
714
+ extractorLines.push(` Path(${pathFields}): Path<${pathTypes}>,`);
640
715
  } else {
641
716
  extractorLines.push(
642
- ` axum::extract::Path((${pathFields})): axum::extract::Path<(${pathTypes})>,`,
717
+ ` Path((${pathFields})): Path<(${pathTypes})>,`,
643
718
  );
644
719
  }
720
+ // Add path params to server method args
721
+ for (const param of pathParams) {
722
+ serverArgs.push(param.rustName);
723
+ }
645
724
  }
646
725
 
647
- if (hasQueryParams) {
648
- extractorLines.push(
649
- ` axum::extract::Query(query): axum::extract::Query<${requestName}>,`,
650
- );
651
- } else if (hasBody) {
652
- extractorLines.push(
653
- ` axum::Json(body): axum::Json<${requestName}Body>,`,
726
+ // Json body comes last
727
+ if (hasBody && opInfo.body) {
728
+ const bodyType = getRustTypeForProperty(
729
+ opInfo.body.type,
730
+ program,
731
+ anonymousEnums,
654
732
  );
733
+ extractorLines.push(` Json(payload): Json<${bodyType.type}>,`);
734
+ serverArgs.push(`payload`);
655
735
  }
656
736
 
657
- if (isProtected) {
658
- extractorLines.push(
659
- ` axum::Extension(claims): axum::Extension<S::Claims>,`,
660
- );
661
- }
662
-
663
- // Build request struct expression
664
- if (hasOnlyPathParams(hasPathParams, hasQueryParams, hasBody)) {
665
- const pathAssignments = pathParams
666
- .map((p) => `${p.rustName},`)
667
- .join(" ");
668
- requestExpr = `${requestName} { ${pathAssignments} }`;
669
- } else if (hasQueryParams) {
670
- if (hasPathParams) {
671
- const pathAssignments = pathParams
672
- .map((p) => `${p.rustName},`)
673
- .join(" ");
674
- requestExpr = `${requestName} { ${pathAssignments} ..query }`;
675
- } else {
676
- requestExpr = "query";
677
- }
678
- } else if (hasBody) {
679
- if (hasPathParams) {
680
- const pathAssignments = pathParams
681
- .map((p) => `${p.rustName},`)
682
- .join(" ");
683
- requestExpr = `${requestName} { ${pathAssignments} body }`;
684
- } else {
685
- requestExpr = `${requestName} { body }`;
686
- }
687
- } else {
688
- requestExpr = `${requestName} {}`;
689
- }
690
-
691
- // Server method call
692
- const serverCall = isProtected
693
- ? `service.${traitFnName}(claims, ${requestExpr}).await`
694
- : `service.${traitFnName}(${requestExpr}).await`;
737
+ // Build server method call
738
+ const serverArgsStr = serverArgs.join(", ");
739
+ const serverCall = `service.${traitFnName}(${serverArgsStr}).await`;
695
740
 
696
741
  // All handlers use <S> generics, Claims is now an associated type
697
742
  let handlerCode = `pub async fn ${handlerFnName}_handler<S>(
@@ -1115,6 +1160,13 @@ function emitModel(
1115
1160
  parts.push('#[error("{code}: {message}")]');
1116
1161
  }
1117
1162
 
1163
+ const customAttrs = (model as any)[rustAttrKey] as RustAttrInfo | undefined;
1164
+ if (customAttrs) {
1165
+ for (const attr of customAttrs.attrs) {
1166
+ parts.push(`#[${attr}]`);
1167
+ }
1168
+ }
1169
+
1118
1170
  if (allProps.size > 0) {
1119
1171
  const fields: string[] = [];
1120
1172
  for (const [propName, prop] of allProps) {
@@ -1181,10 +1233,41 @@ function emitEnum(enumType: Enum): string {
1181
1233
  (m) => m.value === undefined || typeof m.value === "string",
1182
1234
  );
1183
1235
 
1236
+ const baseDerives = [
1237
+ "Debug",
1238
+ "Clone",
1239
+ "Copy",
1240
+ "PartialEq",
1241
+ "Eq",
1242
+ "Hash",
1243
+ "serde::Serialize",
1244
+ "serde::Deserialize",
1245
+ ];
1246
+
1247
+ const customDerives = (enumType as any)[rustDeriveKey] as
1248
+ | RustDeriveInfo
1249
+ | undefined;
1250
+ const allDerives = [...baseDerives];
1251
+ if (customDerives) {
1252
+ allDerives.push(...customDerives.derives);
1253
+ }
1254
+
1255
+ const customAttrs = (enumType as any)[rustAttrKey] as
1256
+ | RustAttrInfo
1257
+ | undefined;
1258
+ const attrLines: string[] = [];
1259
+ if (customAttrs) {
1260
+ for (const attr of customAttrs.attrs) {
1261
+ attrLines.push(`#[${attr}]`);
1262
+ }
1263
+ }
1264
+
1184
1265
  if (isString) {
1185
- parts.push(
1186
- `#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]\npub enum ${name} {`,
1187
- );
1266
+ parts.push(`#[derive(${allDerives.join(", ")})]`);
1267
+ if (attrLines.length > 0) {
1268
+ parts.push(...attrLines);
1269
+ }
1270
+ parts.push(`pub enum ${name} {`);
1188
1271
  for (const member of members) {
1189
1272
  const variantName = toRustVariantName(member.name);
1190
1273
  const serdeValue = member.value ?? member.name;
@@ -1192,9 +1275,11 @@ function emitEnum(enumType: Enum): string {
1192
1275
  parts.push(` ${variantName},`);
1193
1276
  }
1194
1277
  } else {
1195
- parts.push(
1196
- `#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]\npub enum ${name} {`,
1197
- );
1278
+ parts.push(`#[derive(${allDerives.join(", ")})]`);
1279
+ if (attrLines.length > 0) {
1280
+ parts.push(...attrLines);
1281
+ }
1282
+ parts.push(`pub enum ${name} {`);
1198
1283
  for (const member of members) {
1199
1284
  const variantName = toRustVariantName(member.name);
1200
1285
  const enumValue = member.value ?? 0;
@@ -1404,11 +1489,6 @@ export async function $onEmit(
1404
1489
  namespaceGroups,
1405
1490
  anonymousEnums,
1406
1491
  );
1407
- const requestStructs = generateRequestStructs(
1408
- context.program,
1409
- namespaceGroups,
1410
- anonymousEnums,
1411
- );
1412
1492
  const responseEnums = generateResponseEnums(
1413
1493
  context.program,
1414
1494
  namespaceGroups,
@@ -1420,12 +1500,7 @@ export async function $onEmit(
1420
1500
  anonymousEnums,
1421
1501
  );
1422
1502
 
1423
- const serverContent = [
1424
- serverTrait,
1425
- requestStructs,
1426
- responseEnums,
1427
- router,
1428
- ].join("\n");
1503
+ const serverContent = [serverTrait, responseEnums, router].join("\n");
1429
1504
 
1430
1505
  await emitFile(context.program, {
1431
1506
  path: resolvePath(context.emitterOutputDir, "server.rs"),
package/src/index.ts CHANGED
@@ -1,3 +1,3 @@
1
1
  export { $onEmit } from "./emitter.js";
2
- export { $rustDerive, $rustDerives } from "./emitter.js";
2
+ export { $rustDerive, $rustDerives, $rustAttr, $rustAttrs } from "./emitter.js";
3
3
  export { $lib } from "./lib.js";
package/src/lib.tsp CHANGED
@@ -2,5 +2,7 @@ import "../dist/src/emitter.js";
2
2
 
3
3
  using TypeSpec.Reflection;
4
4
 
5
/**
 * Adds a Rust derive macro (e.g. `"strum::Display"`) to the generated
 * `#[derive(...)]` list of the decorated model or enum. Duplicates are ignored.
 */
extern dec rustDerive(target: Model | Enum, derive: valueof string);

/**
 * Adds several Rust derive macros to the decorated model or enum in one
 * decorator.
 */
extern dec rustDerives(target: Model | Enum, ...derives: valueof string[]);

/**
 * Adds an arbitrary Rust attribute above the generated model or enum. Pass the
 * attribute body without `#[` / `]`, e.g. `sqlx(type_name = "user")`.
 * Duplicates are ignored.
 */
extern dec rustAttr(target: Model | Enum, attr: valueof string);

/**
 * Adds several Rust attributes to the decorated model or enum; each value is
 * handled exactly like a separate `@rustAttr`.
 */
extern dec rustAttrs(target: Model | Enum, ...attrs: valueof string[]);
@@ -165,4 +165,104 @@ describe("Rust emitter", () => {
165
165
  true,
166
166
  );
167
167
  });
168
+
169
+ it("emits custom rustDerive on enum", async () => {
170
+ const results = await emit(`
171
+ import "typespec-rust-emitter";
172
+
173
+ @rustDerive("strum::Display")
174
+ enum Status {
175
+ active,
176
+ inactive,
177
+ }
178
+ `);
179
+ const output = results["types.rs"];
180
+ strictEqual(
181
+ output.includes(
182
+ "#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize, strum::Display)]",
183
+ ),
184
+ true,
185
+ );
186
+ });
187
+
188
+ it("emits multiple rustDerive macros on enum via $rustDerives", async () => {
189
+ const results = await emit(`
190
+ import "typespec-rust-emitter";
191
+
192
+ @rustDerives("strum::Display", "strum::EnumString")
193
+ enum Priority {
194
+ low,
195
+ high,
196
+ }
197
+ `);
198
+ const output = results["types.rs"];
199
+ strictEqual(
200
+ output.includes(
201
+ "#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize, strum::Display, strum::EnumString)]",
202
+ ),
203
+ true,
204
+ );
205
+ });
206
+
207
+ it("emits custom rustAttr on model", async () => {
208
+ const results = await emit(`
209
+ import "typespec-rust-emitter";
210
+
211
+ @rustAttr("sqlx(type_name = \\"user\\")")
212
+ model User {
213
+ name: string;
214
+ }
215
+ `);
216
+ const output = results["types.rs"];
217
+ strictEqual(output.includes('#[sqlx(type_name = "user")]'), true);
218
+ });
219
+
220
+ it("emits multiple rustAttrs on model via $rustAttrs", async () => {
221
+ const results = await emit(`
222
+ import "typespec-rust-emitter";
223
+
224
+ @rustAttrs("sqlx(type_name = \\"user\\")", "cfg_attr(feature = \\"test\\", derive(Debug))")
225
+ model Person {
226
+ name: string;
227
+ }
228
+ `);
229
+ const output = results["types.rs"];
230
+ strictEqual(output.includes('#[sqlx(type_name = "user")]'), true);
231
+ strictEqual(
232
+ output.includes('#[cfg_attr(feature = "test", derive(Debug))]'),
233
+ true,
234
+ );
235
+ });
236
+
237
+ it("emits custom rustAttr on enum", async () => {
238
+ const results = await emit(`
239
+ import "typespec-rust-emitter";
240
+
241
+ @rustAttr("sqlx(type_name = \\"study_status\\")")
242
+ enum Status {
243
+ active,
244
+ inactive,
245
+ }
246
+ `);
247
+ const output = results["types.rs"];
248
+ strictEqual(output.includes('#[sqlx(type_name = "study_status")]'), true);
249
+ });
250
+
251
+ it("emits multiple rustAttrs on enum via $rustAttrs", async () => {
252
+ const results = await emit(`
253
+ import "typespec-rust-emitter";
254
+
255
+ @rustAttrs("sqlx(type_name = \\"priority\\")", "cfg_attr(feature = \\"test\\", derive(Debug))")
256
+ enum Priority {
257
+ low,
258
+ high,
259
+ }
260
+ `);
261
+ const output = results["types.rs"];
262
+ strictEqual(output.includes('#[sqlx(type_name = "priority")]'), true);
263
+ strictEqual(
264
+ output.includes('#[cfg_attr(feature = "test", derive(Debug))]'),
265
+ true,
266
+ );
267
+ });
168
268
  });