@mastra/claude 0.1.1-alpha.0 → 0.2.0-alpha.1

This diff shows the changes between publicly available versions of a package that have been released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -6,6 +6,7 @@ import { RequestContext } from '@mastra/core/request-context';
6
6
  import { ChunkFrom, MastraModelOutput } from '@mastra/core/stream';
7
7
  import { MessageList } from '@mastra/core/agent/message-list';
8
8
  import { getOrCreateSpan, EntityType, SpanType, executeWithContext } from '@mastra/core/observability';
9
+ import { toStandardSchema, standardSchemaToJSONSchema } from '@mastra/core/schema';
9
10
 
10
11
  // src/index.ts
11
12
  function createNoopModel({ modelId, provider }) {
@@ -33,7 +34,8 @@ function createCompletedMastraStream({
33
34
  modelId,
34
35
  usage,
35
36
  providerMetadata,
36
- costContext
37
+ costContext,
38
+ object
37
39
  }) {
38
40
  return new ReadableStream({
39
41
  start(controller) {
@@ -58,7 +60,8 @@ function createCompletedMastraStream({
58
60
  modelId,
59
61
  usage,
60
62
  providerMetadata,
61
- costContext
63
+ costContext,
64
+ object
62
65
  });
63
66
  controller.close();
64
67
  }
@@ -107,7 +110,8 @@ function toFullOutput({
107
110
  modelId: result.response.modelId,
108
111
  usage: toLanguageModelUsage(result.usage),
109
112
  providerMetadata: result.providerMetadata,
110
- costContext: result.costContext
113
+ costContext: result.costContext,
114
+ object: result.object
111
115
  });
112
116
  return createMastraOutput({
113
117
  messages,
@@ -486,7 +490,8 @@ function enqueueFinishChunks(controller, {
486
490
  modelId,
487
491
  usage,
488
492
  providerMetadata,
489
- costContext
493
+ costContext,
494
+ object
490
495
  }) {
491
496
  const timestamp = /* @__PURE__ */ new Date();
492
497
  const response = {
@@ -510,6 +515,14 @@ function enqueueFinishChunks(controller, {
510
515
  providerMetadata
511
516
  }
512
517
  });
518
+ if (object !== void 0) {
519
+ controller.enqueue({
520
+ type: "object-result",
521
+ runId,
522
+ from: ChunkFrom.AGENT,
523
+ object
524
+ });
525
+ }
513
526
  controller.enqueue({
514
527
  type: "step-finish",
515
528
  runId,
@@ -614,6 +627,48 @@ function promptToText(prompt) {
614
627
  }
615
628
  return "";
616
629
  }
630
+ function getStructuredOutputSchema(structuredOutput) {
631
+ if (!structuredOutput?.schema) {
632
+ return void 0;
633
+ }
634
+ return standardSchemaToJSONSchema(toStandardSchema(structuredOutput.schema));
635
+ }
636
+ async function getStructuredOutputFromValue(value, structuredOutput) {
637
+ if (!structuredOutput?.schema) {
638
+ return void 0;
639
+ }
640
+ let parsed;
641
+ if (typeof value === "string") {
642
+ try {
643
+ parsed = JSON.parse(value);
644
+ } catch (error) {
645
+ return handleStructuredOutputError(
646
+ new Error("Structured output must be valid JSON.", { cause: error }),
647
+ structuredOutput
648
+ );
649
+ }
650
+ } else {
651
+ parsed = value;
652
+ }
653
+ const schema = toStandardSchema(structuredOutput.schema);
654
+ const result = await schema["~standard"].validate(parsed);
655
+ if (!result.issues) {
656
+ return result.value;
657
+ }
658
+ const message = result.issues.map((issue) => `- ${issue.path?.join(".") || "root"}: ${issue.message}`).join("\n");
659
+ return handleStructuredOutputError(new Error(`Structured output validation failed:
660
+ ${message}`), structuredOutput);
661
+ }
662
+ function handleStructuredOutputError(error, structuredOutput) {
663
+ if (structuredOutput.errorStrategy === "fallback") {
664
+ return structuredOutput.fallbackValue;
665
+ }
666
+ if (structuredOutput.errorStrategy === "warn") {
667
+ structuredOutput.logger?.warn(error.message);
668
+ return void 0;
669
+ }
670
+ throw error;
671
+ }
617
672
  function sumDefined(...values) {
618
673
  const defined = values.filter((value) => typeof value === "number");
619
674
  if (defined.length === 0) {
@@ -674,9 +729,7 @@ var ClaudeSDKAgent = class extends Agent {
674
729
  });
675
730
  let result;
676
731
  try {
677
- result = await telemetry.execute(
678
- () => runClaudeGenerate(prompt, this.options, telemetry, options?.abortSignal ?? options?.signal)
679
- );
732
+ result = await telemetry.execute(() => runClaudeGenerate(prompt, this.options, telemetry, options));
680
733
  telemetry.endGenerate(result);
681
734
  } catch (error) {
682
735
  telemetry.fail(error);
@@ -687,7 +740,7 @@ var ClaudeSDKAgent = class extends Agent {
687
740
  runId,
688
741
  provider: PROVIDER,
689
742
  result,
690
- options: telemetry.outputOptions()
743
+ options: { ...telemetry.outputOptions(), structuredOutput: options?.structuredOutput }
691
744
  });
692
745
  }
693
746
  async stream(messages, options) {
@@ -720,26 +773,82 @@ var ClaudeSDKAgent = class extends Agent {
720
773
  runId,
721
774
  modelId,
722
775
  provider: PROVIDER,
723
- stream: telemetry.wrapStream(
724
- runClaudeAsMastraStream(prompt, this.options, runId, telemetry, options?.abortSignal ?? options?.signal)
725
- ),
726
- options: telemetry.outputOptions()
776
+ stream: telemetry.wrapStream(runClaudeAsMastraStream(prompt, this.options, runId, telemetry, options)),
777
+ options: { ...telemetry.outputOptions(), structuredOutput: options?.structuredOutput }
727
778
  });
728
779
  }
780
+ async resumeGenerate(resumeData, options) {
781
+ const data = validateClaudeResumeData(resumeData);
782
+ return this.generate(data.message, createClaudeResumeRunOptions(data, options));
783
+ }
784
+ async resumeStream(resumeData, options) {
785
+ const data = validateClaudeResumeData(resumeData);
786
+ return this.stream(data.message, createClaudeResumeRunOptions(data, options));
787
+ }
729
788
  };
730
- async function runClaudeGenerate(prompt, options, telemetry, signal) {
789
+ function validateClaudeResumeData(resumeData) {
790
+ if (!isRecord(resumeData) || !("message" in resumeData)) {
791
+ throw new Error("ClaudeSDKAgent resumeData must include a message.");
792
+ }
793
+ const hasSessionId = "sessionId" in resumeData;
794
+ const hasContinue = "continue" in resumeData;
795
+ if (hasSessionId && hasContinue) {
796
+ throw new Error("ClaudeSDKAgent resumeData must include either sessionId or continue: true, not both.");
797
+ }
798
+ if (hasSessionId) {
799
+ if (typeof resumeData.sessionId !== "string") {
800
+ throw new Error("ClaudeSDKAgent resumeData.sessionId must be a string.");
801
+ }
802
+ return resumeData;
803
+ }
804
+ if (hasContinue) {
805
+ if (resumeData.continue !== true) {
806
+ throw new Error("ClaudeSDKAgent resumeData.continue must be true when provided.");
807
+ }
808
+ return resumeData;
809
+ }
810
+ throw new Error("ClaudeSDKAgent resumeData must include sessionId or continue: true.");
811
+ }
812
+ function createClaudeResumeRunOptions(resumeData, options) {
813
+ const sdkOptions = { ...options?.sdkOptions };
814
+ if ("sessionId" in resumeData && typeof resumeData.sessionId === "string") {
815
+ sdkOptions.resume = resumeData.sessionId;
816
+ if (resumeData.forkSession !== void 0) {
817
+ sdkOptions.forkSession = resumeData.forkSession;
818
+ }
819
+ if (resumeData.resumeSessionAt !== void 0) {
820
+ sdkOptions.resumeSessionAt = resumeData.resumeSessionAt;
821
+ }
822
+ } else {
823
+ sdkOptions.continue = true;
824
+ }
825
+ return {
826
+ ...options,
827
+ sdkOptions
828
+ };
829
+ }
830
+ async function runClaudeGenerate(prompt, options, telemetry, runOptions) {
731
831
  let text = "";
832
+ let structuredOutputValue;
732
833
  const usage = createClaudeUsageCollector();
733
- for await (const message of observeClaudeMessages(runClaude(prompt, options, signal), telemetry)) {
834
+ for await (const message of observeClaudeMessages(
835
+ runClaude(prompt, options, runOptions?.abortSignal ?? runOptions?.signal, runOptions),
836
+ telemetry
837
+ )) {
734
838
  usage.record(message);
735
839
  if (message.type === "result") {
736
840
  if (message.subtype !== "success") {
737
841
  throw new Error(message.errors.join("\n") || `Claude Agent SDK failed with ${message.subtype}`);
738
842
  }
739
843
  text = message.result;
844
+ structuredOutputValue = getClaudeStructuredOutput(message);
740
845
  }
741
846
  }
742
847
  const totals = usage.totals();
848
+ const object = await getStructuredOutputFromValue(
849
+ structuredOutputValue === void 0 ? text : structuredOutputValue,
850
+ runOptions?.structuredOutput
851
+ );
743
852
  return {
744
853
  content: [{ type: "text", text }],
745
854
  finishReason: { unified: "stop", raw: "stop" },
@@ -750,10 +859,11 @@ async function runClaudeGenerate(prompt, options, telemetry, signal) {
750
859
  timestamp: /* @__PURE__ */ new Date()
751
860
  },
752
861
  providerMetadata: getClaudeProviderMetadata(options, totals),
753
- costContext: getClaudeCostContext(options, totals)
862
+ costContext: getClaudeCostContext(options, totals),
863
+ object
754
864
  };
755
865
  }
756
- function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
866
+ function runClaudeAsMastraStream(prompt, options, runId, telemetry, runOptions) {
757
867
  return new ReadableStream({
758
868
  start: async (controller) => {
759
869
  const textId = randomUUID();
@@ -761,6 +871,7 @@ function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
761
871
  const modelId = getModelId(options);
762
872
  const usage = createClaudeUsageCollector();
763
873
  let text = "";
874
+ let structuredOutputValue;
764
875
  let sawDelta = false;
765
876
  try {
766
877
  enqueueStartChunks(controller, {
@@ -771,7 +882,10 @@ function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
771
882
  modelId,
772
883
  providerMetadata: getClaudeProviderMetadata(options, usage.totals())
773
884
  });
774
- for await (const message of observeClaudeMessages(runClaude(prompt, options, signal), telemetry)) {
885
+ for await (const message of observeClaudeMessages(
886
+ runClaude(prompt, options, runOptions?.abortSignal ?? runOptions?.signal, runOptions),
887
+ telemetry
888
+ )) {
775
889
  usage.record(message);
776
890
  const delta = getTextDelta(message);
777
891
  if (delta) {
@@ -787,6 +901,7 @@ function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
787
901
  text += message.result;
788
902
  enqueueTextDelta(controller, runId, textId, message.result);
789
903
  }
904
+ structuredOutputValue = getClaudeStructuredOutput(message);
790
905
  }
791
906
  }
792
907
  const totals = usage.totals();
@@ -800,7 +915,11 @@ function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
800
915
  modelId,
801
916
  usage: usage.toLanguageModelUsage(),
802
917
  providerMetadata,
803
- costContext: getClaudeCostContext(options, totals)
918
+ costContext: getClaudeCostContext(options, totals),
919
+ object: await getStructuredOutputFromValue(
920
+ structuredOutputValue === void 0 ? text : structuredOutputValue,
921
+ runOptions?.structuredOutput
922
+ )
804
923
  });
805
924
  controller.close();
806
925
  } catch (error) {
@@ -815,11 +934,19 @@ function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
815
934
  }
816
935
  });
817
936
  }
818
- function runClaude(prompt, options, signal) {
937
+ function runClaude(prompt, options, signal, runOptions) {
819
938
  const abortController = createAbortController(signal);
820
939
  const queryOptions = {
821
- ...options.sdkOptions
940
+ ...options.sdkOptions,
941
+ ...runOptions?.sdkOptions
822
942
  };
943
+ const outputSchema = getStructuredOutputSchema(runOptions?.structuredOutput);
944
+ if (outputSchema) {
945
+ queryOptions.outputFormat = {
946
+ type: "json_schema",
947
+ schema: outputSchema
948
+ };
949
+ }
823
950
  if (abortController) {
824
951
  queryOptions.abortController = abortController;
825
952
  }
@@ -828,6 +955,12 @@ function runClaude(prompt, options, signal) {
828
955
  options: queryOptions
829
956
  });
830
957
  }
958
+ function getClaudeStructuredOutput(message) {
959
+ if (message.type !== "result") {
960
+ return void 0;
961
+ }
962
+ return message.structured_output;
963
+ }
831
964
  async function* observeClaudeMessages(messages, telemetry) {
832
965
  for await (const message of messages) {
833
966
  recordClaudeToolTelemetry(message, telemetry);