@codama/renderers-rust 1.0.1 → 1.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
1
  import { CodamaError, CODAMA_ERROR__RENDERERS__UNSUPPORTED_NODE, logWarn, logError, CODAMA_ERROR__UNEXPECTED_NODE_KIND } from '@codama/errors';
2
- import { REGISTERED_TYPE_NODE_KINDS, pascalCase, isNode, resolveNestedTypeNode, remainderCountNode, arrayTypeNode, numberTypeNode, isScalarEnum, snakeCase, isNodeFilter, VALUE_NODES, structTypeNodeFromInstructionArgumentNodes, getAllInstructionsWithSubs, getAllPrograms, getAllAccounts, getAllDefinedTypes, camelCase, kebabCase, titleCase, fixedCountNode, prefixedCountNode, numberValueNode, arrayValueNode, bytesValueNode } from '@codama/nodes';
2
+ import { REGISTERED_TYPE_NODE_KINDS, pascalCase, isNode, resolveNestedTypeNode, remainderCountNode, arrayTypeNode, numberTypeNode, snakeCase, definedTypeNode, isNodeFilter, VALUE_NODES, structTypeNodeFromInstructionArgumentNodes, getAllInstructionsWithSubs, getAllPrograms, getAllAccounts, getAllDefinedTypes, camelCase, kebabCase, titleCase, assertIsNode, isScalarEnum, fixedCountNode, prefixedCountNode, numberValueNode, arrayValueNode, bytesValueNode } from '@codama/nodes';
3
3
  import { RenderMap, deleteDirectory, writeRenderMapVisitor } from '@codama/renderers-core';
4
4
  import { pipe, mergeVisitor, extendVisitor, visit, LinkableDictionary, staticVisitor, recordLinkablesVisitor, rootNodeVisitor } from '@codama/visitors-core';
5
5
  import { getBase64Encoder, getBase58Encoder, getBase16Encoder, getUtf8Encoder } from '@solana/codecs-strings';
@@ -149,10 +149,109 @@ var render = (template, context, options) => {
149
149
  env.addFilter("rustDocblock", rustDocblock);
150
150
  return env.render(template, context);
151
151
  };
152
/**
 * Default trait-derivation settings merged under any user-supplied trait options.
 *
 * - `baseDefaults`: traits derived for every struct, scalar enum and data enum.
 * - `dataEnumDefaults` / `scalarEnumDefaults` / `structDefaults`: extra traits appended
 *   for that node shape only.
 * - `featureFlags`: feature name -> traits that are emitted inside
 *   `#[cfg_attr(feature = "...", derive(...))]` instead of the plain `#[derive(...)]`.
 * - `overrides`: per-type replacement trait lists (keyed by type name).
 * - `useFullyQualifiedName`: when false, `a::b::Trait` paths are imported and derived
 *   by their short name.
 */
var DEFAULT_TRAIT_OPTIONS = (() => {
  // The serde pair appears both in the base list and as the `serde` feature gate.
  const serdeTraits = ["serde::Serialize", "serde::Deserialize"];
  return {
    baseDefaults: [
      "borsh::BorshSerialize",
      "borsh::BorshDeserialize",
      ...serdeTraits,
      "Clone",
      "Debug",
      "Eq",
      "PartialEq",
    ],
    dataEnumDefaults: [],
    featureFlags: { serde: [...serdeTraits] },
    overrides: {},
    scalarEnumDefaults: ["Copy", "PartialOrd", "Hash", "num_derive::FromPrimitive"],
    structDefaults: [],
    useFullyQualifiedName: false,
  };
})();
170
/**
 * Binds trait options once and returns a single-argument resolver.
 *
 * @param options - Trait options forwarded unchanged to `getTraitsFromNode` on every call.
 * @returns A function `(node) => getTraitsFromNode(node, options)`.
 */
function getTraitsFromNodeFactory(options = {}) {
  const resolveTraitsForNode = (node) => getTraitsFromNode(node, options);
  return resolveTraitsForNode;
}
173
/**
 * Computes the derive attributes (and the imports they require) for an account or defined type.
 *
 * @param node - An `accountNode` or `definedTypeNode`; any other kind is rejected by `assertIsNode`.
 * @param userOptions - Partial trait options merged over `DEFAULT_TRAIT_OPTIONS`. Keys whose value
 *   is `undefined` are ignored so they do not wipe out the corresponding default.
 * @returns `{ imports, render }`. `render` is the concatenation of a `#[derive(...)]` line for
 *   ungated traits (omitted when there are none) and one `#[cfg_attr(feature = "...", derive(...))]`
 *   line per feature, each terminated by a newline. `imports` holds the paths needed when traits are
 *   derived by their short name. Alias types get an empty render and no imports.
 */
function getTraitsFromNode(node, userOptions = {}) {
  assertIsNode(node, ["accountNode", "definedTypeNode"]);
  // Drop explicitly-undefined options: `{ overrides: undefined }` would otherwise replace the
  // default `{}` and make `Object.entries` below throw.
  const definedUserOptions = Object.fromEntries(
    Object.entries(userOptions ?? {}).filter(([, value]) => value !== void 0)
  );
  const options = { ...DEFAULT_TRAIT_OPTIONS, ...definedUserOptions };
  const nodeType = getNodeType(node);
  if (nodeType === "alias") {
    return { imports: new ImportMap(), render: "" };
  }
  // Override keys are normalised to camelCase so they match node names.
  const sanitizedOverrides = Object.fromEntries(
    Object.entries(options.overrides).map(([key, value]) => [camelCase(key), value])
  );
  // Own-property lookup only: a plain `obj[name]` would return inherited members such as
  // `Object.prototype.constructor` for a type named "constructor", which is not a trait list.
  const nodeOverrides = Object.hasOwn(sanitizedOverrides, node.name)
    ? sanitizedOverrides[node.name]
    : void 0;
  const allTraits = nodeOverrides === void 0 ? getDefaultTraits(nodeType, options) : nodeOverrides;
  const [unfeaturedTraits, featuredTraits] = partitionTraitsInFeatures(allTraits, options.featureFlags);
  const imports = new ImportMap();
  // Only ungated traits are shortened; feature-gated traits keep their full path.
  const derivedTraits = options.useFullyQualifiedName
    ? unfeaturedTraits
    : extractFullyQualifiedNames(unfeaturedTraits, imports);
  const traitLines = [
    ...(derivedTraits.length > 0 ? [`#[derive(${derivedTraits.join(", ")})]\n`] : []),
    ...Object.entries(featuredTraits).map(
      ([feature, traits]) => `#[cfg_attr(feature = "${feature}", derive(${traits.join(", ")}))]\n`
    ),
  ];
  return { imports, render: traitLines.join("") };
}
202
/**
 * Classifies a node by the Rust item shape it renders to.
 *
 * @param node - An `accountNode` or `definedTypeNode`.
 * @returns `"struct"` for accounts and struct types, `"scalarEnum"` / `"dataEnum"` for enum
 *   types depending on `isScalarEnum`, and `"alias"` for every other defined type.
 */
function getNodeType(node) {
  // Accounts always render as structs; short-circuit before inspecting `node.type`.
  if (isNode(node, "accountNode") || isNode(node.type, "structTypeNode")) {
    return "struct";
  }
  if (!isNode(node.type, "enumTypeNode")) {
    return "alias";
  }
  return isScalarEnum(node.type) ? "scalarEnum" : "dataEnum";
}
210
/**
 * Builds the default trait list for a node shape: the base traits followed by the
 * shape-specific extras.
 *
 * @param nodeType - `"dataEnum"`, `"scalarEnum"` or `"struct"`.
 * @param options - Resolved trait options providing `baseDefaults` and the per-shape lists.
 * @returns A new array of trait paths, or `undefined` for any other `nodeType`.
 */
function getDefaultTraits(nodeType, options) {
  const extrasOptionByNodeType = {
    dataEnum: "dataEnumDefaults",
    scalarEnum: "scalarEnumDefaults",
    struct: "structDefaults",
  };
  if (!Object.hasOwn(extrasOptionByNodeType, nodeType)) {
    return undefined;
  }
  const extraTraits = options[extrasOptionByNodeType[nodeType]];
  return [...options.baseDefaults, ...extraTraits];
}
220
/**
 * Splits traits into those derived unconditionally and those gated behind a Cargo feature.
 *
 * @param traits - Trait paths to derive, in output order.
 * @param featureFlags - Object mapping feature name -> traits that may only be derived when that
 *   feature is enabled. If a trait is listed under several features, the first feature
 *   (in `Object.entries` order) wins.
 * @returns A `[unfeaturedTraits, featuredTraits]` tuple: an array of ungated traits and a plain
 *   object mapping feature name -> gated traits. Both preserve the relative order of `traits`.
 */
function partitionTraitsInFeatures(traits, featureFlags) {
  // Maps instead of `{}` lookups: plain objects expose inherited `Object.prototype` members, so a
  // trait named e.g. "toString" was mistaken for a gated trait, and a feature named
  // "constructor"/"__proto__" made `featuredTraits[feature].push` throw.
  const featureByTrait = new Map();
  for (const [feature, featureTraits] of Object.entries(featureFlags)) {
    for (const trait of featureTraits) {
      if (!featureByTrait.has(trait)) featureByTrait.set(trait, feature);
    }
  }
  const unfeaturedTraits = [];
  const traitsByFeature = new Map();
  for (const trait of traits) {
    const feature = featureByTrait.get(trait);
    if (feature === void 0) {
      unfeaturedTraits.push(trait);
      continue;
    }
    const bucket = traitsByFeature.get(feature);
    if (bucket) {
      bucket.push(trait);
    } else {
      traitsByFeature.set(feature, [trait]);
    }
  }
  // `Object.fromEntries` defines own data properties, so even "__proto__" becomes a normal key.
  return [unfeaturedTraits, Object.fromEntries(traitsByFeature)];
}
243
/**
 * Shortens path-qualified trait names and registers their full path as an import.
 *
 * @param traits - Trait names, possibly qualified with `::` (e.g. `borsh::BorshSerialize`).
 * @param imports - Receives `add(fullPath)` for every qualified trait, in input order.
 * @returns A new array where each qualified trait is replaced by the segment after its last
 *   `::`; unqualified traits are returned unchanged.
 */
function extractFullyQualifiedNames(traits, imports) {
  const separator = "::";
  const toShortName = (trait) => {
    const separatorIndex = trait.lastIndexOf(separator);
    if (separatorIndex === -1) {
      return trait;
    }
    imports.add(trait);
    return trait.slice(separatorIndex + separator.length);
  };
  return traits.map(toShortName);
}
152
251
 
153
252
  // src/getTypeManifestVisitor.ts
154
253
  function getTypeManifestVisitor(options) {
155
- const { getImportFrom } = options;
254
+ const { getImportFrom, getTraitsFromNode: getTraitsFromNode2 } = options;
156
255
  let parentName = options.parentName ?? null;
157
256
  let nestedStruct = options.nestedStruct ?? false;
158
257
  let inlineStruct = false;
@@ -170,13 +269,12 @@ function getTypeManifestVisitor(options) {
170
269
  visitAccount(account, { self }) {
171
270
  parentName = pascalCase(account.name);
172
271
  const manifest = visit(account.data, self);
173
- manifest.imports.add(["borsh::BorshSerialize", "borsh::BorshDeserialize"]);
272
+ const traits = getTraitsFromNode2(account);
273
+ manifest.imports.mergeWith(traits.imports);
174
274
  parentName = null;
175
275
  return {
176
276
  ...manifest,
177
- type: `#[derive(BorshSerialize, BorshDeserialize, Clone, Debug, Eq, PartialEq)]
178
- #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
179
- ${manifest.type}`
277
+ type: traits.render + manifest.type
180
278
  };
181
279
  },
182
280
  visitArrayType(arrayType, { self }) {
@@ -249,35 +347,11 @@ ${manifest.type}`
249
347
  visitDefinedType(definedType, { self }) {
250
348
  parentName = pascalCase(definedType.name);
251
349
  const manifest = visit(definedType.type, self);
350
+ const traits = getTraitsFromNode2(definedType);
351
+ manifest.imports.mergeWith(traits.imports);
252
352
  parentName = null;
253
- const traits = ["BorshSerialize", "BorshDeserialize", "Clone", "Debug", "Eq", "PartialEq"];
254
- if (isNode(definedType.type, "enumTypeNode") && isScalarEnum(definedType.type)) {
255
- traits.push("Copy", "PartialOrd", "Hash", "FromPrimitive");
256
- manifest.imports.add(["num_derive::FromPrimitive"]);
257
- }
258
- const nestedStructs = manifest.nestedStructs.map(
259
- (struct) => `#[derive(${traits.join(", ")})]
260
- #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
261
- ${struct}`
262
- );
263
- if (!isNode(definedType.type, ["enumTypeNode", "structTypeNode"])) {
264
- if (nestedStructs.length > 0) {
265
- manifest.imports.add(["borsh::BorshSerialize", "borsh::BorshDeserialize"]);
266
- }
267
- return {
268
- ...manifest,
269
- nestedStructs,
270
- type: `pub type ${pascalCase(definedType.name)} = ${manifest.type};`
271
- };
272
- }
273
- manifest.imports.add(["borsh::BorshSerialize", "borsh::BorshDeserialize"]);
274
- return {
275
- ...manifest,
276
- nestedStructs,
277
- type: `#[derive(${traits.join(", ")})]
278
- #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
279
- ${manifest.type}`
280
- };
353
+ const renderedType = isNode(definedType.type, ["enumTypeNode", "structTypeNode"]) ? manifest.type : `pub type ${pascalCase(definedType.name)} = ${manifest.type};`;
354
+ return { ...manifest, type: `${traits.render}${renderedType}` };
281
355
  },
282
356
  visitDefinedTypeLink(node) {
283
357
  const pascalCaseDefinedType = pascalCase(node.name);
@@ -493,11 +567,15 @@ ${variantNames}
493
567
  const fieldTypes = fields.map((field) => field.type).join("\n");
494
568
  const mergedManifest = mergeManifests(fields);
495
569
  if (nestedStruct) {
570
+ const nestedTraits = getTraitsFromNode2(
571
+ definedTypeNode({ name: originalParentName, type: structType })
572
+ );
573
+ mergedManifest.imports.mergeWith(nestedTraits.imports);
496
574
  return {
497
575
  ...mergedManifest,
498
576
  nestedStructs: [
499
577
  ...mergedManifest.nestedStructs,
500
- `pub struct ${pascalCase(originalParentName)} {
578
+ `${nestedTraits.render}pub struct ${pascalCase(originalParentName)} {
501
579
  ${fieldTypes}
502
580
  }`
503
581
  ],
@@ -677,7 +755,9 @@ function getRenderMapVisitor(options = {}) {
677
755
  const renderParentInstructions = options.renderParentInstructions ?? false;
678
756
  const dependencyMap = options.dependencyMap ?? {};
679
757
  const getImportFrom = getImportFromFactory(options.linkOverrides ?? {});
680
- const typeManifestVisitor = getTypeManifestVisitor({ getImportFrom });
758
+ const getTraitsFromNode2 = getTraitsFromNodeFactory(options.traitOptions);
759
+ const typeManifestVisitor = getTypeManifestVisitor({ getImportFrom, getTraitsFromNode: getTraitsFromNode2 });
760
+ const anchorTraits = options.anchorTraits ?? true;
681
761
  return pipe(
682
762
  staticVisitor(
683
763
  () => new RenderMap(),
@@ -715,6 +795,7 @@ function getRenderMapVisitor(options = {}) {
715
795
  `accounts/${snakeCase(node.name)}.rs`,
716
796
  render("accountsPage.njk", {
717
797
  account: node,
798
+ anchorTraits,
718
799
  constantSeeds,
719
800
  hasVariableSeeds,
720
801
  imports: imports.remove(`generatedAccounts::${pascalCase(node.name)}`).toString(dependencyMap),
@@ -752,6 +833,7 @@ function getRenderMapVisitor(options = {}) {
752
833
  node.arguments.forEach((argument) => {
753
834
  const argumentVisitor = getTypeManifestVisitor({
754
835
  getImportFrom,
836
+ getTraitsFromNode: getTraitsFromNode2,
755
837
  nestedStruct: true,
756
838
  parentName: `${pascalCase(node.name)}InstructionData`
757
839
  });
@@ -783,6 +865,7 @@ function getRenderMapVisitor(options = {}) {
783
865
  const struct = structTypeNodeFromInstructionArgumentNodes(node.arguments);
784
866
  const structVisitor = getTypeManifestVisitor({
785
867
  getImportFrom,
868
+ getTraitsFromNode: getTraitsFromNode2,
786
869
  parentName: `${pascalCase(node.name)}InstructionData`
787
870
  });
788
871
  const typeManifest = visit(struct, structVisitor);