typespec-rust-emitter 0.8.0 → 0.10.0

This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
package/src/emitter.ts CHANGED
@@ -8,7 +8,6 @@ import {
8
8
  isRecordModelType,
9
9
  getFormat,
10
10
  getPattern,
11
- isErrorModel,
12
11
  Type,
13
12
  Model,
14
13
  ModelProperty,
@@ -39,6 +38,11 @@ interface RustAttrInfo {
39
38
 
40
39
  const rustDeriveKey = Symbol("rustDerive");
41
40
  const rustAttrKey = Symbol("rustAttr");
41
+ const rustImplKey = Symbol("rustImpl");
42
+
43
+ interface RustImplInfo {
44
+ impl: string;
45
+ }
42
46
 
43
47
  export function $rustDerive(
44
48
  context: DecoratorContext,
@@ -136,6 +140,29 @@ export function $rustAttrs(
136
140
  }
137
141
  }
138
142
 
143
+ export function $rustImpl(
144
+ context: DecoratorContext,
145
+ target: Type,
146
+ impl: string,
147
+ ) {
148
+ if (target.kind !== "Model") {
149
+ context.program.reportDiagnostic({
150
+ code: "rust-impl-invalid-target",
151
+ message: `@rustImpl can only be applied to models`,
152
+ severity: "error",
153
+ target: context.decoratorTarget,
154
+ });
155
+ return;
156
+ }
157
+
158
+ const ns = target.namespace ? getNamespaceFullName(target.namespace) : "";
159
+
160
+ if (!ns.startsWith("TypeSpec")) {
161
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
162
+ (target as any)[rustImplKey] = { impl: impl };
163
+ }
164
+ }
165
+
139
166
  type HttpMethod = "GET" | "POST" | "PUT" | "PATCH" | "DELETE" | "HEAD";
140
167
 
141
168
  interface OperationInfo {
@@ -558,13 +585,6 @@ function getHttpStatusCode(statusCode: number): string {
558
585
  return statusCodes[statusCode] || `StatusCode::from_u16(${statusCode})`;
559
586
  }
560
587
 
561
- function getHttpStatusCodeForError(errorName: string): string {
562
- if (errorName.endsWith("NotFoundError")) return "NOT_FOUND";
563
- if (errorName.endsWith("ValidationError")) return "BAD_REQUEST";
564
- if (errorName.endsWith("ConflictError")) return "CONFLICT";
565
- return "INTERNAL_SERVER_ERROR";
566
- }
567
-
568
588
  function generateServerTrait(
569
589
  program: Program,
570
590
  namespaceGroups: { namespace: Namespace; operations: Operation[] }[],
@@ -1154,24 +1174,21 @@ function emitStringLiteralUnion(union: Union): string {
1154
1174
  const variants = Array.from(union.variants.values());
1155
1175
 
1156
1176
  parts.push(
1157
- `#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]\npub enum ${name} {`,
1177
+ `#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, serde::Serialize, serde::Deserialize)]\npub enum ${name} {`,
1158
1178
  );
1159
1179
 
1160
- for (const variant of variants) {
1161
- const literalType = variant.type as StringLiteral;
1180
+ for (let i = 0; i < variants.length; i++) {
1181
+ const literalType = variants[i].type as StringLiteral;
1162
1182
  const variantName = toRustVariantName(literalType.value);
1163
1183
  const serdeValue = literalType.value;
1184
+ if (i === 0) {
1185
+ parts.push(` #[default]`);
1186
+ }
1164
1187
  parts.push(` #[serde(rename = "${serdeValue}")]`);
1165
1188
  parts.push(` ${variantName},`);
1166
1189
  }
1167
1190
 
1168
1191
  parts.push("}");
1169
- const defaultVariant = toRustVariantName(
1170
- variants[0]?.type ? (variants[0].type as StringLiteral).value : "",
1171
- );
1172
- parts.push(
1173
- `\n\nimpl Default for ${name} {\n fn default() -> Self {\n ${name}::${defaultVariant}\n }\n}`,
1174
- );
1175
1192
  return parts.join("\n");
1176
1193
  }
1177
1194
 
@@ -1183,7 +1200,6 @@ function emitModel(
1183
1200
  const parts: string[] = [];
1184
1201
  const name = toPascalCase(model.name);
1185
1202
  const allProps = getAllProperties(model, program);
1186
- const isError = isErrorModel(program, model);
1187
1203
 
1188
1204
  const deriveAttrs = [
1189
1205
  "Debug",
@@ -1200,16 +1216,8 @@ function emitModel(
1200
1216
  deriveAttrs.push(...customDerives.derives);
1201
1217
  }
1202
1218
 
1203
- if (isError) {
1204
- deriveAttrs.push("thiserror::Error");
1205
- }
1206
-
1207
1219
  parts.push(`#[derive(${deriveAttrs.join(", ")})]`);
1208
1220
 
1209
- if (isError && allProps.size > 0) {
1210
- parts.push('#[error("{code}: {message}")]');
1211
- }
1212
-
1213
1221
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
1214
1222
  const customAttrs = (model as any)[rustAttrKey] as RustAttrInfo | undefined;
1215
1223
  if (customAttrs) {
@@ -1253,14 +1261,11 @@ ${fields.join("\n")}
1253
1261
  parts.push("(());");
1254
1262
  }
1255
1263
 
1256
- if (isError && allProps.size > 0) {
1257
- const statusCode = getHttpStatusCodeForError(name);
1264
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
1265
+ const customImpl = (model as any)[rustImplKey] as RustImplInfo | undefined;
1266
+ if (customImpl) {
1258
1267
  parts.push(`
1259
- impl IntoResponse for ${name} {
1260
- fn into_response(self) -> axum::response::Response {
1261
- (StatusCode::${statusCode}, Json(self)).into_response()
1262
- }
1263
- }`);
1268
+ ${customImpl.impl}`);
1264
1269
  }
1265
1270
 
1266
1271
  return parts.join("\n");
@@ -1301,6 +1306,7 @@ function emitEnum(enumType: Enum): string {
1301
1306
  "PartialEq",
1302
1307
  "Eq",
1303
1308
  "Hash",
1309
+ "Default",
1304
1310
  "serde::Serialize",
1305
1311
  "serde::Deserialize",
1306
1312
  ];
@@ -1331,9 +1337,12 @@ function emitEnum(enumType: Enum): string {
1331
1337
  parts.push(...attrLines);
1332
1338
  }
1333
1339
  parts.push(`pub enum ${name} {`);
1334
- for (const member of members) {
1335
- const variantName = toRustVariantName(member.name);
1336
- const serdeValue = member.value ?? member.name;
1340
+ for (let i = 0; i < members.length; i++) {
1341
+ const variantName = toRustVariantName(members[i].name);
1342
+ const serdeValue = members[i].value ?? members[i].name;
1343
+ if (i === 0) {
1344
+ parts.push(` #[default]`);
1345
+ }
1337
1346
  parts.push(` #[serde(rename = "${serdeValue}")]`);
1338
1347
  parts.push(` ${variantName},`);
1339
1348
  }
@@ -1343,17 +1352,16 @@ function emitEnum(enumType: Enum): string {
1343
1352
  parts.push(...attrLines);
1344
1353
  }
1345
1354
  parts.push(`pub enum ${name} {`);
1346
- for (const member of members) {
1347
- const variantName = toRustVariantName(member.name);
1348
- const enumValue = member.value ?? 0;
1355
+ for (let i = 0; i < members.length; i++) {
1356
+ const variantName = toRustVariantName(members[i].name);
1357
+ const enumValue = members[i].value ?? 0;
1358
+ if (i === 0) {
1359
+ parts.push(` #[default]`);
1360
+ }
1349
1361
  parts.push(` ${variantName} = ${enumValue},`);
1350
1362
  }
1351
1363
  }
1352
1364
  parts.push("}");
1353
- const defaultVariant = toRustVariantName(members[0]?.name ?? "");
1354
- parts.push(
1355
- `\n\nimpl Default for ${name} {\n fn default() -> Self {\n ${name}::${defaultVariant}\n }\n}`,
1356
- );
1357
1365
  return parts.join("\n");
1358
1366
  }
1359
1367
 
@@ -1408,12 +1416,9 @@ function emitScalar(
1408
1416
  impls.push(
1409
1417
  `\nimpl TryFrom<String> for ${structName} {\n type Error = String;\n\n fn try_from(value: String) -> Result<Self, Self::Error> {\n let re = regex::Regex::new(r"${pattern}").unwrap();\n if re.is_match(&value) { Ok(Self(value)) } else { Err(format!("Invalid value: {}", value)) }\n }\n}`,
1410
1418
  );
1411
- impls.push(
1412
- `\nimpl Default for ${structName} {\n fn default() -> Self {\n Self(String::new())\n }\n}`,
1413
- );
1414
1419
 
1415
1420
  return {
1416
- typeDef: `#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]\npub struct ${structName}(pub ${rustType});`,
1421
+ typeDef: `#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, serde::Serialize, serde::Deserialize)]\npub struct ${structName}(pub ${rustType});`,
1417
1422
  impls,
1418
1423
  };
1419
1424
  }
@@ -1497,18 +1502,18 @@ export async function $onEmit(
1497
1502
  for (const [enumName, anonEnum] of anonymousEnums) {
1498
1503
  const parts: string[] = [];
1499
1504
  parts.push(
1500
- `#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]\npub enum ${enumName} {`,
1505
+ `#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, serde::Serialize, serde::Deserialize)]\npub enum ${enumName} {`,
1501
1506
  );
1502
- for (const literal of anonEnum.variants) {
1507
+ for (let i = 0; i < anonEnum.variants.length; i++) {
1508
+ const literal = anonEnum.variants[i];
1503
1509
  const variantName = toRustVariantName(literal.value);
1510
+ if (i === 0) {
1511
+ parts.push(` #[default]`);
1512
+ }
1504
1513
  parts.push(` #[serde(rename = "${literal.value}")]`);
1505
1514
  parts.push(` ${variantName},`);
1506
1515
  }
1507
1516
  parts.push("}");
1508
- const defaultVariant = toRustVariantName(anonEnum.variants[0]?.value ?? "");
1509
- parts.push(
1510
- `\n\nimpl Default for ${enumName} {\n fn default() -> Self {\n ${enumName}::${defaultVariant}\n }\n}`,
1511
- );
1512
1517
  content.push(parts.join("\n"));
1513
1518
  content.push("");
1514
1519
  }
@@ -1541,6 +1546,7 @@ export async function $onEmit(
1541
1546
 
1542
1547
  const outputContent =
1543
1548
  "#![allow(unused)]\n\n" +
1549
+ "use std::str::FromStr;\n" +
1544
1550
  "use axum::http::StatusCode;\n" +
1545
1551
  "use axum::response::IntoResponse;\n" +
1546
1552
  "use axum::Json;\n\n" +
package/src/index.ts CHANGED
@@ -1,3 +1,9 @@
1
1
  export { $onEmit } from "./emitter.js";
2
- export { $rustDerive, $rustDerives, $rustAttr, $rustAttrs } from "./emitter.js";
2
+ export {
3
+ $rustDerive,
4
+ $rustDerives,
5
+ $rustAttr,
6
+ $rustAttrs,
7
+ $rustImpl,
8
+ } from "./emitter.js";
3
9
  export { $lib } from "./lib.js";
package/src/lib.tsp CHANGED
@@ -6,3 +6,4 @@ extern dec rustDerive(target: Model | Enum, derive: valueof string);
6
6
  extern dec rustDerives(target: Model | Enum, ...derives: valueof string[]);
7
7
  extern dec rustAttr(target: Model | Enum, attr: valueof string);
8
8
  extern dec rustAttrs(target: Model | Enum, ...attrs: valueof string[]);
9
+ extern dec rustImpl(target: Model, impl: valueof string);
@@ -179,7 +179,7 @@ describe("Rust emitter", () => {
179
179
  const output = results["types.rs"];
180
180
  strictEqual(
181
181
  output.includes(
182
- "#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize, strum::Display)]",
182
+ "#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, serde::Serialize, serde::Deserialize, strum::Display)]",
183
183
  ),
184
184
  true,
185
185
  );
@@ -198,7 +198,7 @@ describe("Rust emitter", () => {
198
198
  const output = results["types.rs"];
199
199
  strictEqual(
200
200
  output.includes(
201
- "#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize, strum::Display, strum::EnumString)]",
201
+ "#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, serde::Serialize, serde::Deserialize, strum::Display, strum::EnumString)]",
202
202
  ),
203
203
  true,
204
204
  );
@@ -265,4 +265,67 @@ describe("Rust emitter", () => {
265
265
  true,
266
266
  );
267
267
  });
268
+
269
+ it("does not add thiserror::Error by default on @error model", async () => {
270
+ const results = await emit(`
271
+ import "typespec-rust-emitter";
272
+
273
+ @error
274
+ model ApiError {
275
+ code: int32;
276
+ message: string;
277
+ }
278
+ `);
279
+ const output = results["types.rs"];
280
+ strictEqual(output.includes("thiserror::Error"), false);
281
+ strictEqual(output.includes('#[error("{code}: {message}")]'), false);
282
+ });
283
+
284
+ it("allows user to add thiserror::Error and error attribute manually", async () => {
285
+ const results = await emit(`
286
+ import "typespec-rust-emitter";
287
+
288
+ @rustDerive("thiserror::Error")
289
+ @rustAttr("error(\\"{code}: {message}\\")")
290
+ model ApiError {
291
+ code: int32;
292
+ message: string;
293
+ }
294
+ `);
295
+ const output = results["types.rs"];
296
+ strictEqual(output.includes("thiserror::Error"), true);
297
+ strictEqual(output.includes('#[error("{code}: {message}")]'), true);
298
+ });
299
+
300
+ it("allows user to add custom impl blocks", async () => {
301
+ const results = await emit(`
302
+ import "typespec-rust-emitter";
303
+
304
+ @rustImpl("impl IntoResponse for ApiError { fn into_response(self) -> axum::response::Response { self.message.into_response() } }")
305
+ model ApiError {
306
+ @doc("Human-readable message.")
307
+ message: string;
308
+ }
309
+ `);
310
+ const output = results["types.rs"];
311
+ strictEqual(output.includes("impl IntoResponse for ApiError"), true);
312
+ strictEqual(
313
+ output.includes("fn into_response(self) -> axum::response::Response"),
314
+ true,
315
+ );
316
+ strictEqual(output.includes("self.message.into_response()"), true);
317
+ });
318
+
319
+ it("uses default IntoResponse when no custom impl provided", async () => {
320
+ const results = await emit(`
321
+ import "typespec-rust-emitter";
322
+
323
+ model ApiError {
324
+ code: int32;
325
+ message: string;
326
+ }
327
+ `);
328
+ const output = results["types.rs"];
329
+ strictEqual(output.includes("impl IntoResponse for ApiError"), false);
330
+ });
268
331
  });