@codama/renderers-rust 1.0.0 → 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
1
  import { CodamaError, CODAMA_ERROR__RENDERERS__UNSUPPORTED_NODE, logWarn, logError, CODAMA_ERROR__UNEXPECTED_NODE_KIND } from '@codama/errors';
2
- import { REGISTERED_TYPE_NODE_KINDS, pascalCase, isNode, resolveNestedTypeNode, remainderCountNode, arrayTypeNode, numberTypeNode, isScalarEnum, snakeCase, isNodeFilter, VALUE_NODES, structTypeNodeFromInstructionArgumentNodes, getAllInstructionsWithSubs, getAllPrograms, getAllAccounts, getAllDefinedTypes, camelCase, kebabCase, titleCase, fixedCountNode, prefixedCountNode, numberValueNode, arrayValueNode, bytesValueNode } from '@codama/nodes';
2
+ import { REGISTERED_TYPE_NODE_KINDS, pascalCase, isNode, resolveNestedTypeNode, remainderCountNode, arrayTypeNode, numberTypeNode, snakeCase, definedTypeNode, isNodeFilter, VALUE_NODES, structTypeNodeFromInstructionArgumentNodes, getAllInstructionsWithSubs, getAllPrograms, getAllAccounts, getAllDefinedTypes, camelCase, kebabCase, titleCase, assertIsNode, isScalarEnum, fixedCountNode, prefixedCountNode, numberValueNode, arrayValueNode, bytesValueNode } from '@codama/nodes';
3
3
  import { RenderMap, deleteDirectory, writeRenderMapVisitor } from '@codama/renderers-core';
4
4
  import { pipe, mergeVisitor, extendVisitor, visit, LinkableDictionary, staticVisitor, recordLinkablesVisitor, rootNodeVisitor } from '@codama/visitors-core';
5
5
  import { getBase64Encoder, getBase58Encoder, getBase16Encoder, getUtf8Encoder } from '@solana/codecs-strings';
@@ -149,10 +149,109 @@ var render = (template, context, options) => {
149
149
  env.addFilter("rustDocblock", rustDocblock);
150
150
  return env.render(template, context);
151
151
  };
152
/**
 * Baseline trait configuration. `getTraitsFromNode` shallow-merges the
 * caller's options over this object, so every key listed here is always
 * present on the merged result unless the caller replaces it.
 */
const DEFAULT_TRAIT_OPTIONS = {
  // Traits shared by every rendered kind (struct, scalar enum, data enum).
  baseDefaults: [
    "borsh::BorshSerialize",
    "borsh::BorshDeserialize",
    "serde::Serialize",
    "serde::Deserialize",
    "Clone",
    "Debug",
    "Eq",
    "PartialEq",
  ],
  // Extra traits appended to the base list for enums that carry data.
  dataEnumDefaults: [],
  // Feature name -> traits rendered behind `#[cfg_attr(feature = "...")]`.
  featureFlags: {
    serde: ["serde::Serialize", "serde::Deserialize"],
  },
  // Node name -> complete trait list that replaces the defaults for that node.
  overrides: {},
  // Extra traits appended to the base list for scalar enums.
  scalarEnumDefaults: [
    "Copy",
    "PartialOrd",
    "Hash",
    "num_derive::FromPrimitive",
  ],
  // Extra traits appended to the base list for structs.
  structDefaults: [],
  // When false, unfeatured traits are imported and rendered by short name.
  useFullyQualifiedName: false,
};
170
/**
 * Binds a set of trait options once and returns a one-argument resolver.
 *
 * @param {object} [options={}] Trait options forwarded unchanged to
 *   `getTraitsFromNode` on every call.
 * @returns {(node: object) => {imports: object, render: string}} Resolver
 *   that computes the trait attributes for a single node.
 */
function getTraitsFromNodeFactory(options = {}) {
  const resolveTraits = (node) => getTraitsFromNode(node, options);
  return resolveTraits;
}
173
/**
 * Computes the Rust trait attributes (`#[derive(...)]` and feature-gated
 * `#[cfg_attr(...)]` lines) for an account or defined-type node.
 *
 * @param {object} node An `accountNode` or `definedTypeNode`; any other kind
 *   is rejected by `assertIsNode`.
 * @param {object} [userOptions={}] Partial trait options merged over
 *   `DEFAULT_TRAIT_OPTIONS`. Keys whose value is `undefined` are ignored so
 *   they cannot erase a default.
 * @returns {{imports: ImportMap, render: string}} The imports required by the
 *   rendered attributes and the attribute text itself (one line per
 *   attribute, each newline-terminated). Alias types yield an empty render.
 */
function getTraitsFromNode(node, userOptions = {}) {
  assertIsNode(node, ["accountNode", "definedTypeNode"]);

  // Drop explicitly-undefined entries: spreading `{ overrides: undefined }`
  // would otherwise replace the default and crash the lookups below.
  const definedUserOptions = Object.fromEntries(
    Object.entries(userOptions ?? {}).filter(([, value]) => value !== undefined)
  );
  const options = { ...DEFAULT_TRAIT_OPTIONS, ...definedUserOptions };

  const nodeType = getNodeType(node);
  if (nodeType === "alias") {
    // Type aliases cannot carry derive attributes.
    return { imports: new ImportMap(), render: "" };
  }

  // Override keys are normalised to camelCase so they match `node.name`.
  // A Map is used so that only user-provided keys can match: a plain-object
  // lookup would resolve names such as "constructor" through the prototype.
  const overridesByName = new Map(
    Object.entries(options.overrides).map(([key, value]) => [camelCase(key), value])
  );
  const nodeOverrides = overridesByName.get(node.name);
  const allTraits = nodeOverrides === undefined ? getDefaultTraits(nodeType, options) : nodeOverrides;

  const [plainTraits, featuredTraits] = partitionTraitsInFeatures(allTraits, options.featureFlags);
  const imports = new ImportMap();

  // Unless fully-qualified names were requested, import each qualified trait
  // and refer to it by its short name inside the derive list.
  const unfeaturedTraits = options.useFullyQualifiedName
    ? plainTraits
    : extractFullyQualifiedNames(plainTraits, imports);

  const traitLines = [];
  if (unfeaturedTraits.length > 0) {
    traitLines.push(`#[derive(${unfeaturedTraits.join(", ")})]\n`);
  }
  for (const [feature, traits] of Object.entries(featuredTraits)) {
    traitLines.push(`#[cfg_attr(feature = "${feature}", derive(${traits.join(", ")}))]\n`);
  }

  return { imports, render: traitLines.join("") };
}
202
/**
 * Classifies a node into the category used to pick its default traits.
 *
 * @param {object} node An account node, or a node exposing a `type` child.
 * @returns {"struct"|"scalarEnum"|"dataEnum"|"alias"} "struct" for account
 *   nodes and struct-typed nodes, "scalarEnum" or "dataEnum" for enum-typed
 *   nodes depending on `isScalarEnum`, and "alias" for everything else.
 */
function getNodeType(node) {
  // Account nodes are checked first so `node.type` is never read for them.
  if (isNode(node, "accountNode") || isNode(node.type, "structTypeNode")) {
    return "struct";
  }
  if (!isNode(node.type, "enumTypeNode")) {
    return "alias";
  }
  if (isScalarEnum(node.type)) {
    return "scalarEnum";
  }
  return "dataEnum";
}
210
/**
 * Builds the default trait list for a node category: the shared base traits
 * followed by the category-specific additions.
 *
 * Duplicates are removed (first occurrence wins) because repeating a trait
 * inside a single Rust `#[derive(...)]` list is a compile error, which would
 * otherwise happen whenever a per-kind list overlaps `baseDefaults`.
 *
 * @param {string} nodeType One of "dataEnum", "scalarEnum" or "struct".
 * @param {{baseDefaults: string[], dataEnumDefaults: string[],
 *   scalarEnumDefaults: string[], structDefaults: string[]}} options
 *   Trait options providing the base and per-kind lists.
 * @returns {string[]} A new array of unique trait names; empty for any
 *   unrecognised `nodeType` (previously an implicit `undefined`).
 */
function getDefaultTraits(nodeType, options) {
  let kindDefaults;
  switch (nodeType) {
    case "dataEnum":
      kindDefaults = options.dataEnumDefaults;
      break;
    case "scalarEnum":
      kindDefaults = options.scalarEnumDefaults;
      break;
    case "struct":
      kindDefaults = options.structDefaults;
      break;
    default:
      // Unknown categories (e.g. "alias") derive nothing.
      return [];
  }
  // A Set preserves insertion order, so the base traits keep their position.
  return [...new Set([...options.baseDefaults, ...kindDefaults])];
}
220
/**
 * Splits a trait list into traits that are always derived and traits that are
 * derived only when a Cargo feature is enabled.
 *
 * Lookups go through `Map`s rather than plain objects so that trait or
 * feature names that coincide with `Object.prototype` members (such as
 * "toString" or "constructor") are treated like any other string.
 *
 * @param {string[]} traits Trait names, in the order they should be rendered.
 * @param {Record<string, string[]>} featureFlags Feature name -> traits gated
 *   behind that feature. When a trait is listed under several features, the
 *   first feature in iteration order wins.
 * @returns {[string[], Record<string, string[]>]} A pair of the unfeatured
 *   traits and a plain object mapping each feature to its traits, both
 *   preserving the input order.
 */
function partitionTraitsInFeatures(traits, featureFlags) {
  // Reverse index: trait name -> owning feature.
  const featureByTrait = new Map();
  for (const [feature, gatedTraits] of Object.entries(featureFlags)) {
    for (const trait of gatedTraits) {
      if (!featureByTrait.has(trait)) featureByTrait.set(trait, feature);
    }
  }

  const unfeaturedTraits = [];
  const traitsByFeature = new Map();
  for (const trait of traits) {
    if (!featureByTrait.has(trait)) {
      unfeaturedTraits.push(trait);
      continue;
    }
    const feature = featureByTrait.get(trait);
    if (!traitsByFeature.has(feature)) traitsByFeature.set(feature, []);
    traitsByFeature.get(feature).push(trait);
  }

  // Callers iterate the result with Object.entries, so hand back a plain object.
  return [unfeaturedTraits, Object.fromEntries(traitsByFeature)];
}
243
/**
 * Replaces path-qualified trait names with their final segment, registering
 * the full path as an import along the way.
 *
 * @param {string[]} traits Trait names, optionally qualified with `::`.
 * @param {{add: Function}} imports Import collector; `add` is called once,
 *   in order, with every trait that contains `::`.
 * @returns {string[]} A new array of the same length where each qualified
 *   name is cut down to the text after its last `::`; unqualified names are
 *   passed through untouched.
 */
function extractFullyQualifiedNames(traits, imports) {
  const SEPARATOR = "::";
  const shortNames = [];
  for (const qualifiedName of traits) {
    const separatorAt = qualifiedName.lastIndexOf(SEPARATOR);
    if (separatorAt !== -1) {
      imports.add(qualifiedName);
      shortNames.push(qualifiedName.slice(separatorAt + SEPARATOR.length));
    } else {
      shortNames.push(qualifiedName);
    }
  }
  return shortNames;
}
152
251
 
153
252
  // src/getTypeManifestVisitor.ts
154
253
  function getTypeManifestVisitor(options) {
155
- const { getImportFrom } = options;
254
+ const { getImportFrom, getTraitsFromNode: getTraitsFromNode2 } = options;
156
255
  let parentName = options.parentName ?? null;
157
256
  let nestedStruct = options.nestedStruct ?? false;
158
257
  let inlineStruct = false;
@@ -170,13 +269,12 @@ function getTypeManifestVisitor(options) {
170
269
  visitAccount(account, { self }) {
171
270
  parentName = pascalCase(account.name);
172
271
  const manifest = visit(account.data, self);
173
- manifest.imports.add(["borsh::BorshSerialize", "borsh::BorshDeserialize"]);
272
+ const traits = getTraitsFromNode2(account);
273
+ manifest.imports.mergeWith(traits.imports);
174
274
  parentName = null;
175
275
  return {
176
276
  ...manifest,
177
- type: `#[derive(BorshSerialize, BorshDeserialize, Clone, Debug, Eq, PartialEq)]
178
- #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
179
- ${manifest.type}`
277
+ type: traits.render + manifest.type
180
278
  };
181
279
  },
182
280
  visitArrayType(arrayType, { self }) {
@@ -249,35 +347,11 @@ ${manifest.type}`
249
347
  visitDefinedType(definedType, { self }) {
250
348
  parentName = pascalCase(definedType.name);
251
349
  const manifest = visit(definedType.type, self);
350
+ const traits = getTraitsFromNode2(definedType);
351
+ manifest.imports.mergeWith(traits.imports);
252
352
  parentName = null;
253
- const traits = ["BorshSerialize", "BorshDeserialize", "Clone", "Debug", "Eq", "PartialEq"];
254
- if (isNode(definedType.type, "enumTypeNode") && isScalarEnum(definedType.type)) {
255
- traits.push("Copy", "PartialOrd", "Hash", "FromPrimitive");
256
- manifest.imports.add(["num_derive::FromPrimitive"]);
257
- }
258
- const nestedStructs = manifest.nestedStructs.map(
259
- (struct) => `#[derive(${traits.join(", ")})]
260
- #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
261
- ${struct}`
262
- );
263
- if (!isNode(definedType.type, ["enumTypeNode", "structTypeNode"])) {
264
- if (nestedStructs.length > 0) {
265
- manifest.imports.add(["borsh::BorshSerialize", "borsh::BorshDeserialize"]);
266
- }
267
- return {
268
- ...manifest,
269
- nestedStructs,
270
- type: `pub type ${pascalCase(definedType.name)} = ${manifest.type}`
271
- };
272
- }
273
- manifest.imports.add(["borsh::BorshSerialize", "borsh::BorshDeserialize"]);
274
- return {
275
- ...manifest,
276
- nestedStructs,
277
- type: `#[derive(${traits.join(", ")})]
278
- #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
279
- ${manifest.type}`
280
- };
353
+ const renderedType = isNode(definedType.type, ["enumTypeNode", "structTypeNode"]) ? manifest.type : `pub type ${pascalCase(definedType.name)} = ${manifest.type};`;
354
+ return { ...manifest, type: `${traits.render}${renderedType}` };
281
355
  },
282
356
  visitDefinedTypeLink(node) {
283
357
  const pascalCaseDefinedType = pascalCase(node.name);
@@ -493,11 +567,15 @@ ${variantNames}
493
567
  const fieldTypes = fields.map((field) => field.type).join("\n");
494
568
  const mergedManifest = mergeManifests(fields);
495
569
  if (nestedStruct) {
570
+ const nestedTraits = getTraitsFromNode2(
571
+ definedTypeNode({ name: originalParentName, type: structType })
572
+ );
573
+ mergedManifest.imports.mergeWith(nestedTraits.imports);
496
574
  return {
497
575
  ...mergedManifest,
498
576
  nestedStructs: [
499
577
  ...mergedManifest.nestedStructs,
500
- `pub struct ${pascalCase(originalParentName)} {
578
+ `${nestedTraits.render}pub struct ${pascalCase(originalParentName)} {
501
579
  ${fieldTypes}
502
580
  }`
503
581
  ],
@@ -677,7 +755,8 @@ function getRenderMapVisitor(options = {}) {
677
755
  const renderParentInstructions = options.renderParentInstructions ?? false;
678
756
  const dependencyMap = options.dependencyMap ?? {};
679
757
  const getImportFrom = getImportFromFactory(options.linkOverrides ?? {});
680
- const typeManifestVisitor = getTypeManifestVisitor({ getImportFrom });
758
+ const getTraitsFromNode2 = getTraitsFromNodeFactory(options.traitOptions);
759
+ const typeManifestVisitor = getTypeManifestVisitor({ getImportFrom, getTraitsFromNode: getTraitsFromNode2 });
681
760
  return pipe(
682
761
  staticVisitor(
683
762
  () => new RenderMap(),
@@ -752,6 +831,7 @@ function getRenderMapVisitor(options = {}) {
752
831
  node.arguments.forEach((argument) => {
753
832
  const argumentVisitor = getTypeManifestVisitor({
754
833
  getImportFrom,
834
+ getTraitsFromNode: getTraitsFromNode2,
755
835
  nestedStruct: true,
756
836
  parentName: `${pascalCase(node.name)}InstructionData`
757
837
  });
@@ -783,6 +863,7 @@ function getRenderMapVisitor(options = {}) {
783
863
  const struct = structTypeNodeFromInstructionArgumentNodes(node.arguments);
784
864
  const structVisitor = getTypeManifestVisitor({
785
865
  getImportFrom,
866
+ getTraitsFromNode: getTraitsFromNode2,
786
867
  parentName: `${pascalCase(node.name)}InstructionData`
787
868
  });
788
869
  const typeManifest = visit(struct, structVisitor);