@mastra/claude 0.1.0 → 0.2.0-alpha.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +43 -0
- package/dist/docs/SKILL.md +2 -2
- package/dist/docs/assets/SOURCE_MAP.json +1 -1
- package/dist/docs/references/docs-agents-sdk-agents.md +168 -4
- package/dist/index.cjs +156 -21
- package/dist/index.cjs.map +1 -1
- package/dist/index.d.ts +41 -20
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +156 -21
- package/dist/index.js.map +1 -1
- package/dist/utils.d.ts +13 -4
- package/dist/utils.d.ts.map +1 -1
- package/package.json +7 -6
package/dist/index.js
CHANGED
|
@@ -6,6 +6,7 @@ import { RequestContext } from '@mastra/core/request-context';
|
|
|
6
6
|
import { ChunkFrom, MastraModelOutput } from '@mastra/core/stream';
|
|
7
7
|
import { MessageList } from '@mastra/core/agent/message-list';
|
|
8
8
|
import { getOrCreateSpan, EntityType, SpanType, executeWithContext } from '@mastra/core/observability';
|
|
9
|
+
import { toStandardSchema, standardSchemaToJSONSchema } from '@mastra/core/schema';
|
|
9
10
|
|
|
10
11
|
// src/index.ts
|
|
11
12
|
function createNoopModel({ modelId, provider }) {
|
|
@@ -33,7 +34,8 @@ function createCompletedMastraStream({
|
|
|
33
34
|
modelId,
|
|
34
35
|
usage,
|
|
35
36
|
providerMetadata,
|
|
36
|
-
costContext
|
|
37
|
+
costContext,
|
|
38
|
+
object
|
|
37
39
|
}) {
|
|
38
40
|
return new ReadableStream({
|
|
39
41
|
start(controller) {
|
|
@@ -58,7 +60,8 @@ function createCompletedMastraStream({
|
|
|
58
60
|
modelId,
|
|
59
61
|
usage,
|
|
60
62
|
providerMetadata,
|
|
61
|
-
costContext
|
|
63
|
+
costContext,
|
|
64
|
+
object
|
|
62
65
|
});
|
|
63
66
|
controller.close();
|
|
64
67
|
}
|
|
@@ -70,11 +73,12 @@ function createMastraOutput({
|
|
|
70
73
|
modelId,
|
|
71
74
|
provider,
|
|
72
75
|
stream,
|
|
76
|
+
responseText = "",
|
|
73
77
|
options
|
|
74
78
|
}) {
|
|
75
79
|
const messageList = new MessageList();
|
|
76
80
|
messageList.add(messages, "input");
|
|
77
|
-
messageList.add([{ role: "assistant", content:
|
|
81
|
+
messageList.add([{ role: "assistant", content: responseText }], "response");
|
|
78
82
|
return new MastraModelOutput({
|
|
79
83
|
model: {
|
|
80
84
|
modelId,
|
|
@@ -106,7 +110,8 @@ function toFullOutput({
|
|
|
106
110
|
modelId: result.response.modelId,
|
|
107
111
|
usage: toLanguageModelUsage(result.usage),
|
|
108
112
|
providerMetadata: result.providerMetadata,
|
|
109
|
-
costContext: result.costContext
|
|
113
|
+
costContext: result.costContext,
|
|
114
|
+
object: result.object
|
|
110
115
|
});
|
|
111
116
|
return createMastraOutput({
|
|
112
117
|
messages,
|
|
@@ -114,6 +119,7 @@ function toFullOutput({
|
|
|
114
119
|
modelId: result.response.modelId,
|
|
115
120
|
provider,
|
|
116
121
|
stream,
|
|
122
|
+
responseText: text,
|
|
117
123
|
options
|
|
118
124
|
}).getFullOutput();
|
|
119
125
|
}
|
|
@@ -484,7 +490,8 @@ function enqueueFinishChunks(controller, {
|
|
|
484
490
|
modelId,
|
|
485
491
|
usage,
|
|
486
492
|
providerMetadata,
|
|
487
|
-
costContext
|
|
493
|
+
costContext,
|
|
494
|
+
object
|
|
488
495
|
}) {
|
|
489
496
|
const timestamp = /* @__PURE__ */ new Date();
|
|
490
497
|
const response = {
|
|
@@ -508,6 +515,14 @@ function enqueueFinishChunks(controller, {
|
|
|
508
515
|
providerMetadata
|
|
509
516
|
}
|
|
510
517
|
});
|
|
518
|
+
if (object !== void 0) {
|
|
519
|
+
controller.enqueue({
|
|
520
|
+
type: "object-result",
|
|
521
|
+
runId,
|
|
522
|
+
from: ChunkFrom.AGENT,
|
|
523
|
+
object
|
|
524
|
+
});
|
|
525
|
+
}
|
|
511
526
|
controller.enqueue({
|
|
512
527
|
type: "step-finish",
|
|
513
528
|
runId,
|
|
@@ -612,6 +627,48 @@ function promptToText(prompt) {
|
|
|
612
627
|
}
|
|
613
628
|
return "";
|
|
614
629
|
}
|
|
630
|
+
/**
 * Builds the JSON Schema that is forwarded to the Claude Agent SDK as `outputFormat.schema`.
 *
 * @param {object | undefined} structuredOutput - Run-level structured output options; only `schema` is read.
 * @returns {object | undefined} The JSON Schema derived from `structuredOutput.schema`, or `undefined`
 *   when no schema was requested.
 */
function getStructuredOutputSchema(structuredOutput) {
  const requestedSchema = structuredOutput?.schema;
  if (requestedSchema) {
    // Convert to Standard Schema form first, then derive the JSON Schema from it.
    const standardSchema = toStandardSchema(requestedSchema);
    return standardSchemaToJSONSchema(standardSchema);
  }
  return undefined;
}
|
|
636
|
+
/**
 * Renders a Standard Schema issue path as a dotted string, e.g. `items.0.name`.
 *
 * Standard Schema allows each path entry to be either a bare property key or a
 * `{ key }` segment object, and keys may be symbols. `Array#join` would print
 * segment objects as "[object Object]" and throws a TypeError on symbols, so each
 * entry is unwrapped and stringified explicitly.
 *
 * @param {ReadonlyArray<PropertyKey | { key: PropertyKey }> | undefined} path
 * @returns {string} The dotted path, or "root" when the path is missing or renders empty.
 */
function formatStructuredOutputIssuePath(path) {
  if (!path) {
    return "root";
  }
  const rendered = path
    .map((segment) => String(typeof segment === "object" && segment !== null ? segment.key : segment))
    .join(".");
  return rendered || "root";
}

/**
 * Parses and validates a structured-output value against the requested schema.
 *
 * @param {unknown} value - The SDK's `structured_output` value, or the final result text
 *   when the SDK returned none. String values are parsed as JSON before validation.
 * @param {object | undefined} structuredOutput - Run-level structured output options
 *   (`schema`, plus `errorStrategy` / `fallbackValue` / `logger` consumed by
 *   `handleStructuredOutputError`).
 * @returns {Promise<unknown>} The validated value; `undefined` when no schema was requested;
 *   otherwise whatever `handleStructuredOutputError` yields for an invalid value.
 * @throws {Error} When the value is invalid JSON or fails validation and the configured
 *   error strategy does not recover.
 */
async function getStructuredOutputFromValue(value, structuredOutput) {
  if (!structuredOutput?.schema) {
    return undefined;
  }
  let parsed = value;
  if (typeof value === "string") {
    try {
      parsed = JSON.parse(value);
    } catch (error) {
      return handleStructuredOutputError(
        new Error("Structured output must be valid JSON.", { cause: error }),
        structuredOutput
      );
    }
  }
  const schema = toStandardSchema(structuredOutput.schema);
  // Standard Schema `validate` may return a result synchronously or a Promise; `await` handles both.
  const result = await schema["~standard"].validate(parsed);
  if (!result.issues) {
    return result.value;
  }
  const message = result.issues
    .map((issue) => `- ${formatStructuredOutputIssuePath(issue.path)}: ${issue.message}`)
    .join("\n");
  return handleStructuredOutputError(new Error(`Structured output validation failed:\n${message}`), structuredOutput);
}
|
|
662
|
+
/**
 * Applies the caller's `errorStrategy` to a structured-output failure.
 *
 * - "fallback": yields `structuredOutput.fallbackValue` instead of failing.
 * - "warn": reports `error.message` through `structuredOutput.logger.warn` (only if a
 *   logger is set) and yields `undefined`.
 * - any other value, including none: rethrows `error`.
 *
 * @param {Error} error - The parse or validation failure.
 * @param {object} structuredOutput - Run-level structured output options.
 * @returns {unknown} The fallback value, or `undefined` for "warn".
 * @throws {Error} `error` itself when no recovering strategy is configured.
 */
function handleStructuredOutputError(error, structuredOutput) {
  switch (structuredOutput.errorStrategy) {
    case "fallback":
      return structuredOutput.fallbackValue;
    case "warn":
      structuredOutput.logger?.warn(error.message);
      return undefined;
    default:
      throw error;
  }
}
|
|
615
672
|
function sumDefined(...values) {
|
|
616
673
|
const defined = values.filter((value) => typeof value === "number");
|
|
617
674
|
if (defined.length === 0) {
|
|
@@ -672,9 +729,7 @@ var ClaudeSDKAgent = class extends Agent {
|
|
|
672
729
|
});
|
|
673
730
|
let result;
|
|
674
731
|
try {
|
|
675
|
-
result = await telemetry.execute(
|
|
676
|
-
() => runClaudeGenerate(prompt, this.options, telemetry, options?.abortSignal ?? options?.signal)
|
|
677
|
-
);
|
|
732
|
+
result = await telemetry.execute(() => runClaudeGenerate(prompt, this.options, telemetry, options));
|
|
678
733
|
telemetry.endGenerate(result);
|
|
679
734
|
} catch (error) {
|
|
680
735
|
telemetry.fail(error);
|
|
@@ -685,7 +740,7 @@ var ClaudeSDKAgent = class extends Agent {
|
|
|
685
740
|
runId,
|
|
686
741
|
provider: PROVIDER,
|
|
687
742
|
result,
|
|
688
|
-
options: telemetry.outputOptions()
|
|
743
|
+
options: { ...telemetry.outputOptions(), structuredOutput: options?.structuredOutput }
|
|
689
744
|
});
|
|
690
745
|
}
|
|
691
746
|
async stream(messages, options) {
|
|
@@ -718,26 +773,82 @@ var ClaudeSDKAgent = class extends Agent {
|
|
|
718
773
|
runId,
|
|
719
774
|
modelId,
|
|
720
775
|
provider: PROVIDER,
|
|
721
|
-
stream: telemetry.wrapStream(
|
|
722
|
-
|
|
723
|
-
),
|
|
724
|
-
options: telemetry.outputOptions()
|
|
776
|
+
stream: telemetry.wrapStream(runClaudeAsMastraStream(prompt, this.options, runId, telemetry, options)),
|
|
777
|
+
options: { ...telemetry.outputOptions(), structuredOutput: options?.structuredOutput }
|
|
725
778
|
});
|
|
726
779
|
}
|
|
780
|
+
async resumeGenerate(resumeData, options) {
|
|
781
|
+
const data = validateClaudeResumeData(resumeData);
|
|
782
|
+
return this.generate(data.message, createClaudeResumeRunOptions(data, options));
|
|
783
|
+
}
|
|
784
|
+
async resumeStream(resumeData, options) {
|
|
785
|
+
const data = validateClaudeResumeData(resumeData);
|
|
786
|
+
return this.stream(data.message, createClaudeResumeRunOptions(data, options));
|
|
787
|
+
}
|
|
727
788
|
};
|
|
728
|
-
|
|
789
|
+
/**
 * Validates the `resumeData` accepted by `ClaudeSDKAgent.resumeGenerate` / `resumeStream`.
 *
 * Accepted shapes (keys are detected with `in`, so a key that is present but `undefined`
 * still counts as supplied):
 *   - `{ message, sessionId: string, forkSession?, resumeSessionAt? }`: resume a specific session.
 *   - `{ message, continue: true }`: continue instead of resuming by id.
 *
 * @param {unknown} resumeData - Untrusted resume descriptor from the caller.
 * @returns {object} `resumeData` itself (not a copy) once validated.
 * @throws {Error} When `resumeData` is not a record (per `isRecord`), lacks `message`,
 *   supplies both or neither of `sessionId` / `continue`, has a non-string or empty
 *   `sessionId`, or has a `continue` value other than `true`.
 */
function validateClaudeResumeData(resumeData) {
  if (!isRecord(resumeData) || !("message" in resumeData)) {
    throw new Error("ClaudeSDKAgent resumeData must include a message.");
  }
  const hasSessionId = "sessionId" in resumeData;
  const hasContinue = "continue" in resumeData;
  if (hasSessionId && hasContinue) {
    throw new Error("ClaudeSDKAgent resumeData must include either sessionId or continue: true, not both.");
  }
  if (hasSessionId) {
    if (typeof resumeData.sessionId !== "string") {
      throw new Error("ClaudeSDKAgent resumeData.sessionId must be a string.");
    }
    // An empty id cannot identify a session; forwarding it as `resume` would not resume
    // anything, so reject it instead of silently running without the intended history.
    if (resumeData.sessionId.length === 0) {
      throw new Error("ClaudeSDKAgent resumeData.sessionId must not be empty.");
    }
    return resumeData;
  }
  if (hasContinue) {
    if (resumeData.continue !== true) {
      throw new Error("ClaudeSDKAgent resumeData.continue must be true when provided.");
    }
    return resumeData;
  }
  throw new Error("ClaudeSDKAgent resumeData must include sessionId or continue: true.");
}
|
|
812
|
+
/**
 * Builds the run options for a resumed `generate` / `stream` call.
 *
 * The caller's `options` are copied shallowly, and the resume settings are written onto a
 * copy of `options.sdkOptions`, so the caller's objects are left unmodified.
 *
 * `runClaude` later merges these run-level `sdkOptions` over the agent-level
 * `options.sdkOptions` using object spread. Because of that, the resume mode chosen here
 * explicitly overrides the competing key rather than only setting its own:
 *   - resuming by id sets `continue: undefined`;
 *   - continuing sets `resume: undefined`.
 * Assigning `undefined`, rather than deleting the key, lets the run-level value win over an
 * inherited agent-level value in that spread. Without this, both flags could reach the SDK
 * even though `validateClaudeResumeData` forbids supplying both.
 *
 * @param {object} resumeData - Already validated by `validateClaudeResumeData`.
 * @param {object} [options] - The caller's generate/stream options.
 * @returns {object} A new options object whose `sdkOptions` carries the resume settings.
 */
function createClaudeResumeRunOptions(resumeData, options) {
  const sdkOptions = { ...options?.sdkOptions };
  if ("sessionId" in resumeData && typeof resumeData.sessionId === "string") {
    sdkOptions.resume = resumeData.sessionId;
    sdkOptions.continue = undefined;
    if (resumeData.forkSession !== undefined) {
      sdkOptions.forkSession = resumeData.forkSession;
    }
    if (resumeData.resumeSessionAt !== undefined) {
      sdkOptions.resumeSessionAt = resumeData.resumeSessionAt;
    }
  } else {
    sdkOptions.continue = true;
    sdkOptions.resume = undefined;
  }
  return {
    ...options,
    sdkOptions
  };
}
|
|
830
|
+
async function runClaudeGenerate(prompt, options, telemetry, runOptions) {
|
|
729
831
|
let text = "";
|
|
832
|
+
let structuredOutputValue;
|
|
730
833
|
const usage = createClaudeUsageCollector();
|
|
731
|
-
for await (const message of observeClaudeMessages(
|
|
834
|
+
for await (const message of observeClaudeMessages(
|
|
835
|
+
runClaude(prompt, options, runOptions?.abortSignal ?? runOptions?.signal, runOptions),
|
|
836
|
+
telemetry
|
|
837
|
+
)) {
|
|
732
838
|
usage.record(message);
|
|
733
839
|
if (message.type === "result") {
|
|
734
840
|
if (message.subtype !== "success") {
|
|
735
841
|
throw new Error(message.errors.join("\n") || `Claude Agent SDK failed with ${message.subtype}`);
|
|
736
842
|
}
|
|
737
843
|
text = message.result;
|
|
844
|
+
structuredOutputValue = getClaudeStructuredOutput(message);
|
|
738
845
|
}
|
|
739
846
|
}
|
|
740
847
|
const totals = usage.totals();
|
|
848
|
+
const object = await getStructuredOutputFromValue(
|
|
849
|
+
structuredOutputValue === void 0 ? text : structuredOutputValue,
|
|
850
|
+
runOptions?.structuredOutput
|
|
851
|
+
);
|
|
741
852
|
return {
|
|
742
853
|
content: [{ type: "text", text }],
|
|
743
854
|
finishReason: { unified: "stop", raw: "stop" },
|
|
@@ -748,10 +859,11 @@ async function runClaudeGenerate(prompt, options, telemetry, signal) {
|
|
|
748
859
|
timestamp: /* @__PURE__ */ new Date()
|
|
749
860
|
},
|
|
750
861
|
providerMetadata: getClaudeProviderMetadata(options, totals),
|
|
751
|
-
costContext: getClaudeCostContext(options, totals)
|
|
862
|
+
costContext: getClaudeCostContext(options, totals),
|
|
863
|
+
object
|
|
752
864
|
};
|
|
753
865
|
}
|
|
754
|
-
function runClaudeAsMastraStream(prompt, options, runId, telemetry,
|
|
866
|
+
function runClaudeAsMastraStream(prompt, options, runId, telemetry, runOptions) {
|
|
755
867
|
return new ReadableStream({
|
|
756
868
|
start: async (controller) => {
|
|
757
869
|
const textId = randomUUID();
|
|
@@ -759,6 +871,7 @@ function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
|
|
|
759
871
|
const modelId = getModelId(options);
|
|
760
872
|
const usage = createClaudeUsageCollector();
|
|
761
873
|
let text = "";
|
|
874
|
+
let structuredOutputValue;
|
|
762
875
|
let sawDelta = false;
|
|
763
876
|
try {
|
|
764
877
|
enqueueStartChunks(controller, {
|
|
@@ -769,7 +882,10 @@ function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
|
|
|
769
882
|
modelId,
|
|
770
883
|
providerMetadata: getClaudeProviderMetadata(options, usage.totals())
|
|
771
884
|
});
|
|
772
|
-
for await (const message of observeClaudeMessages(
|
|
885
|
+
for await (const message of observeClaudeMessages(
|
|
886
|
+
runClaude(prompt, options, runOptions?.abortSignal ?? runOptions?.signal, runOptions),
|
|
887
|
+
telemetry
|
|
888
|
+
)) {
|
|
773
889
|
usage.record(message);
|
|
774
890
|
const delta = getTextDelta(message);
|
|
775
891
|
if (delta) {
|
|
@@ -785,6 +901,7 @@ function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
|
|
|
785
901
|
text += message.result;
|
|
786
902
|
enqueueTextDelta(controller, runId, textId, message.result);
|
|
787
903
|
}
|
|
904
|
+
structuredOutputValue = getClaudeStructuredOutput(message);
|
|
788
905
|
}
|
|
789
906
|
}
|
|
790
907
|
const totals = usage.totals();
|
|
@@ -798,7 +915,11 @@ function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
|
|
|
798
915
|
modelId,
|
|
799
916
|
usage: usage.toLanguageModelUsage(),
|
|
800
917
|
providerMetadata,
|
|
801
|
-
costContext: getClaudeCostContext(options, totals)
|
|
918
|
+
costContext: getClaudeCostContext(options, totals),
|
|
919
|
+
object: await getStructuredOutputFromValue(
|
|
920
|
+
structuredOutputValue === void 0 ? text : structuredOutputValue,
|
|
921
|
+
runOptions?.structuredOutput
|
|
922
|
+
)
|
|
802
923
|
});
|
|
803
924
|
controller.close();
|
|
804
925
|
} catch (error) {
|
|
@@ -813,11 +934,19 @@ function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
|
|
|
813
934
|
}
|
|
814
935
|
});
|
|
815
936
|
}
|
|
816
|
-
function runClaude(prompt, options, signal) {
|
|
937
|
+
function runClaude(prompt, options, signal, runOptions) {
|
|
817
938
|
const abortController = createAbortController(signal);
|
|
818
939
|
const queryOptions = {
|
|
819
|
-
...options.sdkOptions
|
|
940
|
+
...options.sdkOptions,
|
|
941
|
+
...runOptions?.sdkOptions
|
|
820
942
|
};
|
|
943
|
+
const outputSchema = getStructuredOutputSchema(runOptions?.structuredOutput);
|
|
944
|
+
if (outputSchema) {
|
|
945
|
+
queryOptions.outputFormat = {
|
|
946
|
+
type: "json_schema",
|
|
947
|
+
schema: outputSchema
|
|
948
|
+
};
|
|
949
|
+
}
|
|
821
950
|
if (abortController) {
|
|
822
951
|
queryOptions.abortController = abortController;
|
|
823
952
|
}
|
|
@@ -826,6 +955,12 @@ function runClaude(prompt, options, signal) {
|
|
|
826
955
|
options: queryOptions
|
|
827
956
|
});
|
|
828
957
|
}
|
|
958
|
+
/**
 * Reads the SDK-provided structured output off a Claude Agent SDK message.
 *
 * @param {object} message - A message yielded by the Claude Agent SDK query stream.
 * @returns {unknown} `message.structured_output` for `result` messages, `undefined` for any other type.
 */
function getClaudeStructuredOutput(message) {
  return message.type === "result" ? message.structured_output : undefined;
}
|
|
829
964
|
async function* observeClaudeMessages(messages, telemetry) {
|
|
830
965
|
for await (const message of messages) {
|
|
831
966
|
recordClaudeToolTelemetry(message, telemetry);
|