typespec-rust-emitter 0.10.7 → 0.12.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/emitter.ts CHANGED
@@ -39,6 +39,9 @@ interface RustAttrInfo {
39
39
// Private symbol keys for emitter metadata attached to TypeSpec types.
// `rustSelfReceiverKey` is written by the @rustMut / @rustOwn decorators and
// read back when the Server trait and router are generated.
const rustDeriveKey = Symbol("rustDerive");
const rustAttrKey = Symbol("rustAttr");
const rustImplKey = Symbol("rustImpl");
const rustSelfReceiverKey = Symbol("rustSelfReceiver");

/**
 * Receiver form emitted for a generated `Server` trait method:
 * shared borrow (default), mutable borrow (@rustMut) or owned value (@rustOwn).
 */
type SelfReceiver =
  | "&self"
  | "&mut self"
  | "self";
42
45
 
43
46
  interface RustImplInfo {
44
47
  impl: string;
@@ -168,6 +171,42 @@ export function $rustImpl(
168
171
  }
169
172
  }
170
173
 
174
/**
 * True when `ns` is the built-in `TypeSpec` namespace or one of its children
 * (e.g. `TypeSpec.Http`). A bare `startsWith("TypeSpec")` would also match
 * user namespaces such as `TypeSpecDemo` and wrongly skip them.
 */
function isBuiltInTypeSpecNamespace(ns: string): boolean {
  return ns === "TypeSpec" || ns.startsWith("TypeSpec.");
}

/**
 * Shared implementation of @rustMut / @rustOwn.
 *
 * Reports `invalidTargetCode` when applied to anything but an operation and
 * ignores operations declared in built-in TypeSpec namespaces. Otherwise it
 * records `receiver` on the operation under `rustSelfReceiverKey`. If a
 * different receiver was already recorded (both decorators on one
 * operation), it reports an error instead of silently letting the last one win.
 */
function applyRustSelfReceiver(
  context: DecoratorContext,
  target: Type,
  decoratorName: string,
  invalidTargetCode: string,
  receiver: SelfReceiver,
): void {
  if (target.kind !== "Operation") {
    context.program.reportDiagnostic({
      code: invalidTargetCode,
      message: `@${decoratorName} can only be applied to operations`,
      severity: "error",
      target: context.decoratorTarget,
    });
    return;
  }

  const ns = target.namespace ? getNamespaceFullName(target.namespace) : "";
  if (isBuiltInTypeSpecNamespace(ns)) {
    return;
  }

  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const tagged = target as any;
  const existing: unknown = tagged[rustSelfReceiverKey];
  if (existing !== undefined && existing !== receiver) {
    context.program.reportDiagnostic({
      code: "rust-self-receiver-conflict",
      message: `@${decoratorName} conflicts with the "${String(existing)}" receiver already set on this operation; apply at most one of @rustMut or @rustOwn`,
      severity: "error",
      target: context.decoratorTarget,
    });
    return;
  }
  tagged[rustSelfReceiverKey] = receiver;
}

/** `@rustMut`: generate the Server trait method with a `&mut self` receiver. */
export function $rustMut(context: DecoratorContext, target: Type) {
  applyRustSelfReceiver(
    context,
    target,
    "rustMut",
    "rust-mut-invalid-target",
    "&mut self",
  );
}

/** `@rustOwn`: generate the Server trait method with an owned `self` receiver. */
export function $rustOwn(context: DecoratorContext, target: Type) {
  applyRustSelfReceiver(
    context,
    target,
    "rustOwn",
    "rust-own-invalid-target",
    "self",
  );
}
209
+
171
210
/** HTTP verbs the emitter maps operations onto. */
type HttpMethod =
  | "GET"
  | "POST"
  | "PUT"
  | "PATCH"
  | "DELETE"
  | "HEAD";
172
211
 
173
212
  interface OperationInfo {
@@ -298,6 +337,14 @@ function hasAuthDecorator(operation: Operation): boolean {
298
337
  return false;
299
338
  }
300
339
 
340
/**
 * Returns the Rust receiver for the generated `Server` trait method of
 * `operation`, as recorded by `@rustMut` / `@rustOwn`.
 *
 * Falls back to `"&self"` when no receiver decorator was applied. It also
 * falls back when the stored value is not one of the known receiver strings,
 * so an unexpected value is never spliced verbatim into the generated Rust
 * signature.
 */
function getSelfReceiver(operation: Operation): SelfReceiver {
  // Read as `unknown` and narrow, rather than trusting an `as any` cast.
  const stored: unknown = Reflect.get(operation, rustSelfReceiverKey);
  if (stored === "&mut self" || stored === "self" || stored === "&self") {
    return stored;
  }
  return "&self";
}
347
+
301
348
  function getOperationParameters(
302
349
  program: Program,
303
350
  operation: Operation,
@@ -316,7 +363,11 @@ function getOperationParameters(
316
363
  for (const key of Object.keys(decorators)) {
317
364
  const decorator = decorators[key];
318
365
  const name = getDecoratorName(decorator);
319
- if (name === "body" || name === "bodyRoot") {
366
+ if (
367
+ name === "body" ||
368
+ name === "bodyRoot" ||
369
+ name === "multipartBody"
370
+ ) {
320
371
  isBody = true;
321
372
  break;
322
373
  }
@@ -379,7 +430,7 @@ function getOperationBody(operation: Operation): ModelProperty | undefined {
379
430
  for (const key of Object.keys(decorators)) {
380
431
  const decorator = decorators[key];
381
432
  const name = getDecoratorName(decorator);
382
- if (name === "body" || name === "bodyRoot") {
433
+ if (name === "body" || name === "bodyRoot" || name === "multipartBody") {
383
434
  return prop;
384
435
  }
385
436
  }
@@ -387,6 +438,23 @@ function getOperationBody(operation: Operation): ModelProperty | undefined {
387
438
  return undefined;
388
439
  }
389
440
 
441
/**
 * Reports whether any parameter of `operation` carries the `@multipartBody`
 * decorator. Decorator names are resolved through `getDecoratorName`.
 * Parameters without a decorator collection are skipped.
 */
function hasMultipartBody(operation: Operation): boolean {
  for (const [, param] of operation.parameters.properties) {
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const applied = (param as any).decorators;
    if (!applied) continue;

    const isMultipart = Object.keys(applied).some(
      (slot) => getDecoratorName(applied[slot]) === "multipartBody",
    );
    if (isMultipart) return true;
  }
  return false;
}
457
+
390
458
  function getOperationResponses(
391
459
  program: Program,
392
460
  operation: Operation,
@@ -439,6 +507,7 @@ function getOperationResponses(
439
507
  });
440
508
  return responses;
441
509
  }
510
+ let foundStatusCode = false;
442
511
  for (const [propName, prop] of model.properties) {
443
512
  if (propName === "body") {
444
513
  const { type: rustType } = getRustTypeForProperty(
@@ -451,8 +520,27 @@ function getOperationResponses(
451
520
  bodyType: rustType,
452
521
  bodyDescription: getDoc(program, prop),
453
522
  });
523
+ } else if (propName === "statusCode") {
524
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
525
+ const typeAny = prop.type as any;
526
+ if (typeAny.value !== undefined) {
527
+ const statusCode = typeAny.value as number;
528
+ responses.push({
529
+ statusCode,
530
+ bodyType: undefined,
531
+ bodyDescription: undefined,
532
+ });
533
+ foundStatusCode = true;
534
+ }
454
535
  }
455
536
  }
537
+ if (!foundStatusCode && responses.length === 0) {
538
+ responses.push({
539
+ statusCode: 200,
540
+ bodyType: undefined,
541
+ bodyDescription: undefined,
542
+ });
543
+ }
456
544
  }
457
545
 
458
546
  return responses;
@@ -691,34 +779,39 @@ pub trait Server: Send + Sync {
691
779
 
692
780
  // Add body parameter
693
781
  if (opInfo.body) {
694
- const bodyType = getRustTypeForProperty(
695
- opInfo.body.type,
696
- program,
697
- anonymousEnums,
698
- );
699
- paramParts.push(`body: ${bodyType.type}`);
782
+ if (hasMultipartBody(op)) {
783
+ paramParts.push(`body: Multipart`);
784
+ } else {
785
+ const bodyType = getRustTypeForProperty(
786
+ opInfo.body.type,
787
+ program,
788
+ anonymousEnums,
789
+ );
790
+ paramParts.push(`body: ${bodyType.type}`);
791
+ }
700
792
  }
701
793
 
702
794
  const paramsStr = paramParts.join(", ");
795
+ const selfReceiver = getSelfReceiver(op);
703
796
 
704
797
  if (isProtected) {
705
798
  if (paramsStr) {
706
799
  parts.push(
707
- ` async fn ${fnName}(&self, claims: Self::Claims, ${paramsStr}) -> Result<${responseName}>;`,
800
+ ` async fn ${fnName}(${selfReceiver}, claims: Self::Claims, ${paramsStr}) -> Result<${responseName}>;`,
708
801
  );
709
802
  } else {
710
803
  parts.push(
711
- ` async fn ${fnName}(&self, claims: Self::Claims) -> Result<${responseName}>;`,
804
+ ` async fn ${fnName}(${selfReceiver}, claims: Self::Claims) -> Result<${responseName}>;`,
712
805
  );
713
806
  }
714
807
  } else {
715
808
  if (paramsStr) {
716
809
  parts.push(
717
- ` async fn ${fnName}(&self, ${paramsStr}) -> Result<${responseName}>;`,
810
+ ` async fn ${fnName}(${selfReceiver}, ${paramsStr}) -> Result<${responseName}>;`,
718
811
  );
719
812
  } else {
720
813
  parts.push(
721
- ` async fn ${fnName}(&self) -> Result<${responseName}>;`,
814
+ ` async fn ${fnName}(${selfReceiver}) -> Result<${responseName}>;`,
722
815
  );
723
816
  }
724
817
  }
@@ -827,6 +920,7 @@ function generateRouter(
827
920
  const handlerFnName = toRustIdent(`${nsName}_${opInfo.name}`);
828
921
  const traitFnName = handlerFnName;
829
922
  const isProtected = hasAuthDecorator(op);
923
+ const selfReceiver = getSelfReceiver(op);
830
924
 
831
925
  const pathParams = opInfo.parameters.filter((p) => p.location === "path");
832
926
  const hasPathParams = pathParams.length > 0;
@@ -835,14 +929,17 @@ function generateRouter(
835
929
  );
836
930
  const hasQueryParams = queryParams.length > 0;
837
931
  const hasBody = !!opInfo.body;
932
+ const isMultipartBody = hasMultipartBody(op);
838
933
 
839
934
  // Build extractor lines and server method call arguments
840
935
  // IMPORTANT: axum requires specific extractor order:
841
- // State -> Extension -> Path -> Query -> Json -> Body
936
+ // State -> Extension -> Path -> Query -> Json/Multipart -> Body
842
937
  const extractorLines: string[] = [];
843
938
  const serverArgs: string[] = [];
844
939
 
845
940
  // State is always first (added in handler template)
941
+ const serviceBinding =
942
+ selfReceiver === "&mut self" ? "mut service" : "service";
846
943
 
847
944
  // Extension (claims) comes after State
848
945
  if (isProtected) {
@@ -885,15 +982,20 @@ function generateRouter(
885
982
  }
886
983
  }
887
984
 
888
- // Json body comes last
985
+ // Body comes last (Json or Multipart based on decorator)
889
986
  if (hasBody && opInfo.body) {
890
- const bodyType = getRustTypeForProperty(
891
- opInfo.body.type,
892
- program,
893
- anonymousEnums,
894
- );
895
- extractorLines.push(` Json(payload): Json<${bodyType.type}>,`);
896
- serverArgs.push(`payload`);
987
+ if (isMultipartBody) {
988
+ extractorLines.push(` multipart: axum::extract::Multipart,`);
989
+ serverArgs.push(`multipart`);
990
+ } else {
991
+ const bodyType = getRustTypeForProperty(
992
+ opInfo.body.type,
993
+ program,
994
+ anonymousEnums,
995
+ );
996
+ extractorLines.push(` Json(payload): Json<${bodyType.type}>,`);
997
+ serverArgs.push(`payload`);
998
+ }
897
999
  }
898
1000
 
899
1001
  // Build server method call
@@ -901,12 +1003,37 @@ function generateRouter(
901
1003
  const serverCall = `service.${traitFnName}(${serverArgsStr}).await`;
902
1004
 
903
1005
  // All handlers use <S> generics, Claims is now an associated type
904
- const handlerCode = `pub async fn ${handlerFnName}_handler<S>(
1006
+ // For &mut self, we need Clone because service is extracted multiple times
1007
+ // For self, we can't use Clone (would need Arc/Mutex or different pattern)
1008
+ const needsClone = selfReceiver !== "self" ? "+ Clone" : "";
1009
+ const handlerCode =
1010
+ selfReceiver === "self"
1011
+ ? `// NOTE: ${handlerFnName} takes self and cannot be used with the router pattern.
1012
+ // It consumes the service, so you need to implement your own handler pattern.
1013
+ pub async fn ${handlerFnName}_handler<S>(
905
1014
  axum::extract::State(service): axum::extract::State<S>,
906
1015
  ${extractorLines.join("\n")}
907
1016
  ) -> impl axum::response::IntoResponse
908
1017
  where
909
- S: Server + Clone + Send + Sync + 'static,
1018
+ S: Server + Send + Sync + 'static,
1019
+ S::Claims: Send + Sync + Clone + 'static,
1020
+ {
1021
+ let result = service.${traitFnName}(${serverArgsStr}).await;
1022
+ match result {
1023
+ Ok(response) => response.into_response(),
1024
+ Err(e) => (
1025
+ axum::http::StatusCode::INTERNAL_SERVER_ERROR,
1026
+ format!("Internal error: {e}"),
1027
+ )
1028
+ .into_response(),
1029
+ }
1030
+ }`
1031
+ : `pub async fn ${handlerFnName}_handler<S>(
1032
+ axum::extract::State(${serviceBinding}): axum::extract::State<S>,
1033
+ ${extractorLines.join("\n")}
1034
+ ) -> impl axum::response::IntoResponse
1035
+ where
1036
+ S: Server${needsClone} + Send + Sync + 'static,
910
1037
  S::Claims: Send + Sync + Clone + 'static,
911
1038
  {
912
1039
  let result = ${serverCall};
@@ -922,6 +1049,11 @@ where
922
1049
 
923
1050
  handlers.push(handlerCode);
924
1051
 
1052
+ // Don't add routes for self methods (they consume the service)
1053
+ if (selfReceiver === "self") {
1054
+ continue;
1055
+ }
1056
+
925
1057
  const routePath = `"${opInfo.path}"`;
926
1058
  let routeStmt = "";
927
1059
  if (isProtected) {
@@ -939,7 +1071,7 @@ where
939
1071
  const routerBody = buildRouterBody(publicRoutes, protectedRoutes);
940
1072
 
941
1073
  const parts: string[] = [];
942
- parts.push(`use axum::extract::Query;
1074
+ parts.push(`use axum::extract::{Query, Multipart};
943
1075
  use axum::routing::{${methodImports}};
944
1076
  use axum::Router;
945
1077
 
package/src/index.ts CHANGED
@@ -5,5 +5,7 @@ export {
5
5
  $rustAttr,
6
6
  $rustAttrs,
7
7
  $rustImpl,
8
+ $rustMut,
9
+ $rustOwn,
8
10
  } from "./emitter.js";
9
11
  export { $lib } from "./lib.js";
package/src/lib.tsp CHANGED
@@ -7,3 +7,5 @@ extern dec rustDerives(target: Model | Enum, ...derives: valueof string[]);
7
7
extern dec rustAttr(target: Model | Enum | ModelProperty, attr: valueof string);
extern dec rustAttrs(target: Model | Enum | ModelProperty, ...attrs: valueof string[]);
extern dec rustImpl(target: Model, impl: valueof string);

/**
 * Generate the `Server` trait method for this operation with a `&mut self`
 * receiver instead of the default `&self`.
 *
 * Note: the generated axum handler receives the service from `State<S>`
 * (which requires `S: Clone`), i.e. a per-request clone. Mutations made
 * through `&mut self` therefore do not persist across requests unless the
 * service uses shared interior mutability.
 */
extern dec rustMut(target: Operation);

/**
 * Generate the `Server` trait method for this operation with an owned `self`
 * receiver. The generated router does not register a route for such
 * operations, because calling the method would consume the service; wire a
 * custom handler instead.
 */
extern dec rustOwn(target: Operation);
@@ -1,6 +1,6 @@
1
1
  import { strictEqual } from "node:assert";
2
2
  import { describe, it } from "node:test";
3
- import { emit } from "./test-host.js";
3
+ import { emit, emitWithDiagnostics } from "./test-host.js";
4
4
 
5
5
  describe("Rust emitter", () => {
6
6
  it("emits basic model", async () => {
@@ -443,4 +443,146 @@ describe("Rust emitter", () => {
443
443
  true,
444
444
  );
445
445
  });
446
+
447
it("uses &self by default in trait methods", async () => {
  const { "server.rs": serverSource } = await emit(`
    import "@typespec/http";
    import "typespec-rust-emitter";
    using TypeSpec.Http;

    @route("/test")
    namespace Test {
      @get
      op getItem(): { @statusCode status: 200; @body body: string };
    }
  `);
  const hasSharedBorrow = serverSource.includes("async fn test_get_item(&self)");
  strictEqual(hasSharedBorrow, true);
});

it("uses &mut self with @rustMut decorator", async () => {
  const { "server.rs": serverSource } = await emit(`
    import "@typespec/http";
    import "typespec-rust-emitter";
    using TypeSpec.Http;

    @route("/test")
    namespace Test {
      @rustMut
      @post
      op createItem(@body name: string): { @statusCode status: 200; @body body: string };
    }
  `);
  const hasMutableBorrow = serverSource.includes("test_create_item(&mut self,");
  strictEqual(hasMutableBorrow, true);
});

it("uses self with @rustOwn decorator", async () => {
  const { "server.rs": serverSource } = await emit(`
    import "@typespec/http";
    import "typespec-rust-emitter";
    using TypeSpec.Http;

    @route("/test")
    namespace Test {
      @rustOwn
      @delete
      op deleteItem(): { @statusCode status: 200; @body body: string };
    }
  `);
  const hasOwnedReceiver = serverSource.includes("test_delete_item(self)");
  strictEqual(hasOwnedReceiver, true);
});
496
+
497
it("@rustMut works with protected routes", async () => {
  // NOTE(review): this operation only declares an `Authorization` header.
  // Confirm hasAuthDecorator() treats it as protected; otherwise this test
  // exercises the unprotected signature path.
  const results = await emit(`
    import "@typespec/http";
    import "typespec-rust-emitter";
    using TypeSpec.Http;

    @route("/test")
    namespace Test {
      @rustMut
      @post
      op createItem(@body name: string, @header Authorization: string): { @statusCode status: 200; @body body: string };
    }
  `);
  const server = results["server.rs"];
  strictEqual(server.includes("test_create_item(&mut self,"), true);
});

// This operation declares no auth at all. It checks that the owned `self`
// receiver is emitted ahead of the remaining (query) parameters.
it("@rustOwn keeps owned self receiver before query parameters", async () => {
  const results = await emit(`
    import "@typespec/http";
    import "typespec-rust-emitter";
    using TypeSpec.Http;

    @route("/test")
    namespace Test {
      @rustOwn
      @delete
      op deleteItem(@query id: string): { @statusCode status: 200; @body body: string };
    }
  `);
  const server = results["server.rs"];
  strictEqual(server.includes("test_delete_item(self,"), true);
});
530
+
531
// Both receiver decorators must be rejected on non-operation targets. The
// expected code is the compiler's generic wrong-target diagnostic (lib.tsp
// restricts both decorators to `Operation`).
for (const decoratorName of ["rustMut", "rustOwn"]) {
  it(`reports error when @${decoratorName} is applied to non-operation`, async () => {
    const [, diagnostics] = await emitWithDiagnostics(`
      import "@typespec/http";
      import "typespec-rust-emitter";
      using TypeSpec.Http;

      @${decoratorName}
      model Test {
        name: string;
      }
    `);
    const wrongTarget = diagnostics.find(
      (d) => d.code === "decorator-wrong-target",
    );
    strictEqual(wrongTarget !== undefined, true);
  });
}
564
+
565
it("emits multipartBody with Multipart extractor", async () => {
  const { "server.rs": serverSource } = await emit(`
    import "@typespec/http";
    using TypeSpec.Http;

    @route("/upload")
    namespace Upload {
      @post
      op uploadFile(
        @path accountId: string,
        @multipartBody body: {
          image: HttpPart<File>,
        },
      ): {
        @statusCode statusCode: 201;
        @body body: string;
      };
    }
  `);
  // Router side: the handler extracts the raw axum multipart stream.
  const hasExtractor = serverSource.includes("multipart: axum::extract::Multipart,");
  // Trait side: the server method receives it as its body parameter.
  const hasTraitParam = serverSource.includes("body: Multipart");
  strictEqual(hasExtractor, true);
  strictEqual(hasTraitParam, true);
});
446
588
  });