@ai-sdk/google-vertex 0.0.40 → 0.0.43

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.mjs CHANGED
@@ -1,7 +1,4 @@
1
1
  // src/google-vertex-provider.ts
2
- import {
3
- NoSuchModelError
4
- } from "@ai-sdk/provider";
5
2
  import { generateId, loadSetting } from "@ai-sdk/provider-utils";
6
3
  import { VertexAI as VertexAI2 } from "@google-cloud/vertexai";
7
4
 
@@ -11,7 +8,7 @@ import {
11
8
  } from "@ai-sdk/provider";
12
9
  import { convertAsyncGeneratorToReadableStream } from "@ai-sdk/provider-utils";
13
10
  import {
14
- FunctionCallingMode
11
+ FunctionCallingMode as FunctionCallingMode2
15
12
  } from "@google-cloud/vertexai";
16
13
 
17
14
  // src/convert-json-schema-to-openapi-schema.ts
@@ -88,7 +85,7 @@ import {
88
85
  } from "@ai-sdk/provider";
89
86
  import { convertUint8ArrayToBase64 } from "@ai-sdk/provider-utils";
90
87
  function convertToGoogleVertexContentRequest(prompt) {
91
- var _a;
88
+ var _a, _b;
92
89
  const systemInstructionParts = [];
93
90
  const contents = [];
94
91
  let systemMessagesAllowed = true;
@@ -113,28 +110,35 @@ function convertToGoogleVertexContentRequest(prompt) {
113
110
  break;
114
111
  }
115
112
  case "image": {
116
- if (part.image instanceof URL) {
117
- throw new UnsupportedFunctionalityError({
118
- functionality: "Image URLs in user messages"
119
- });
120
- }
121
- parts.push({
122
- inlineData: {
123
- mimeType: (_a = part.mimeType) != null ? _a : "image/jpeg",
124
- data: convertUint8ArrayToBase64(part.image)
113
+ parts.push(
114
+ part.image instanceof URL ? {
115
+ fileData: {
116
+ mimeType: (_a = part.mimeType) != null ? _a : "image/jpeg",
117
+ fileUri: part.image.toString()
118
+ }
119
+ } : {
120
+ inlineData: {
121
+ mimeType: (_b = part.mimeType) != null ? _b : "image/jpeg",
122
+ data: convertUint8ArrayToBase64(part.image)
123
+ }
125
124
  }
126
- });
125
+ );
127
126
  break;
128
127
  }
129
128
  case "file": {
130
- if (part.data instanceof URL) {
131
- throw new UnsupportedFunctionalityError({
132
- functionality: "File URLs in user messages"
133
- });
134
- }
135
- parts.push({
136
- inlineData: { mimeType: part.mimeType, data: part.data }
137
- });
129
+ parts.push(
130
+ part.data instanceof URL ? {
131
+ fileData: {
132
+ mimeType: part.mimeType,
133
+ fileUri: part.data.toString()
134
+ }
135
+ } : {
136
+ inlineData: {
137
+ mimeType: part.mimeType,
138
+ data: part.data
139
+ }
140
+ }
141
+ );
138
142
  break;
139
143
  }
140
144
  default: {
@@ -203,6 +207,91 @@ function convertToGoogleVertexContentRequest(prompt) {
203
207
  };
204
208
  }
205
209
 
210
+ // src/google-vertex-prepare-tools.ts
211
+ import {
212
+ FunctionCallingMode
213
+ } from "@google-cloud/vertexai";
214
+ function prepareTools({
215
+ useSearchGrounding,
216
+ mode
217
+ }) {
218
+ var _a, _b;
219
+ const tools = ((_a = mode.tools) == null ? void 0 : _a.length) ? mode.tools : void 0;
220
+ const toolWarnings = [];
221
+ const vertexTools = [];
222
+ if (tools != null) {
223
+ const functionDeclarations = [];
224
+ for (const tool of tools) {
225
+ if (tool.type === "provider-defined") {
226
+ toolWarnings.push({ type: "unsupported-tool", tool });
227
+ } else {
228
+ functionDeclarations.push({
229
+ name: tool.name,
230
+ description: (_b = tool.description) != null ? _b : "",
231
+ parameters: convertJSONSchemaToOpenAPISchema(
232
+ tool.parameters
233
+ )
234
+ });
235
+ }
236
+ }
237
+ vertexTools.push({ functionDeclarations });
238
+ }
239
+ if (useSearchGrounding) {
240
+ vertexTools.push({ googleSearchRetrieval: {} });
241
+ }
242
+ const finalTools = vertexTools.length > 0 ? vertexTools : void 0;
243
+ const toolChoice = mode.toolChoice;
244
+ if (toolChoice == null) {
245
+ return {
246
+ tools: finalTools,
247
+ toolConfig: void 0,
248
+ toolWarnings
249
+ };
250
+ }
251
+ const type = toolChoice.type;
252
+ switch (type) {
253
+ case "auto":
254
+ return {
255
+ tools: finalTools,
256
+ toolConfig: {
257
+ functionCallingConfig: { mode: FunctionCallingMode.AUTO }
258
+ },
259
+ toolWarnings
260
+ };
261
+ case "none":
262
+ return {
263
+ tools: finalTools,
264
+ toolConfig: {
265
+ functionCallingConfig: { mode: FunctionCallingMode.NONE }
266
+ },
267
+ toolWarnings
268
+ };
269
+ case "required":
270
+ return {
271
+ tools: finalTools,
272
+ toolConfig: {
273
+ functionCallingConfig: { mode: FunctionCallingMode.ANY }
274
+ },
275
+ toolWarnings
276
+ };
277
+ case "tool":
278
+ return {
279
+ tools: finalTools,
280
+ toolConfig: {
281
+ functionCallingConfig: {
282
+ mode: FunctionCallingMode.ANY,
283
+ allowedFunctionNames: [toolChoice.toolName]
284
+ }
285
+ },
286
+ toolWarnings
287
+ };
288
+ default: {
289
+ const _exhaustiveCheck = type;
290
+ throw new Error(`Unsupported tool choice type: ${_exhaustiveCheck}`);
291
+ }
292
+ }
293
+ }
294
+
206
295
  // src/map-google-vertex-finish-reason.ts
207
296
  function mapGoogleVertexFinishReason({
208
297
  finishReason,
@@ -295,19 +384,21 @@ var GoogleVertexLanguageModel = class {
295
384
  const type = mode.type;
296
385
  switch (type) {
297
386
  case "regular": {
387
+ const { tools, toolConfig, toolWarnings } = prepareTools({
388
+ mode,
389
+ useSearchGrounding: (_a = this.settings.useSearchGrounding) != null ? _a : false
390
+ });
298
391
  const configuration = {
299
392
  model: this.modelId,
300
393
  generationConfig,
301
- ...prepareToolsAndToolConfig({
302
- mode,
303
- useSearchGrounding: (_a = this.settings.useSearchGrounding) != null ? _a : false
304
- }),
394
+ tools,
395
+ toolConfig,
305
396
  safetySettings: this.settings.safetySettings
306
397
  };
307
398
  return {
308
399
  model: this.config.vertexAI.getGenerativeModel(configuration),
309
400
  contentRequest: convertToGoogleVertexContentRequest(prompt),
310
- warnings
401
+ warnings: [...warnings, ...toolWarnings]
311
402
  };
312
403
  }
313
404
  case "object-json": {
@@ -347,7 +438,7 @@ var GoogleVertexLanguageModel = class {
347
438
  }
348
439
  ],
349
440
  toolConfig: {
350
- functionCallingConfig: { mode: FunctionCallingMode.ANY }
441
+ functionCallingConfig: { mode: FunctionCallingMode2.ANY }
351
442
  },
352
443
  safetySettings: this.settings.safetySettings
353
444
  };
@@ -363,6 +454,9 @@ var GoogleVertexLanguageModel = class {
363
454
  }
364
455
  }
365
456
  }
457
+ supportsUrl(url) {
458
+ return url.protocol === "gs:";
459
+ }
366
460
  async doGenerate(options) {
367
461
  var _a, _b, _c;
368
462
  const { model, contentRequest, warnings } = await this.getArgs(options);
@@ -392,6 +486,11 @@ var GoogleVertexLanguageModel = class {
392
486
  rawPrompt: contentRequest,
393
487
  rawSettings: {}
394
488
  },
489
+ providerMetadata: this.settings.useSearchGrounding ? {
490
+ vertex: {
491
+ groundingMetadata: firstCandidate.groundingMetadata
492
+ }
493
+ } : void 0,
395
494
  warnings
396
495
  };
397
496
  }
@@ -405,6 +504,7 @@ var GoogleVertexLanguageModel = class {
405
504
  };
406
505
  const generateId2 = this.config.generateId;
407
506
  let hasToolCalls = false;
507
+ let providerMetadata;
408
508
  return {
409
509
  stream: convertAsyncGeneratorToReadableStream(stream).pipeThrough(
410
510
  new TransformStream(
@@ -428,6 +528,13 @@ var GoogleVertexLanguageModel = class {
428
528
  hasToolCalls
429
529
  });
430
530
  }
531
+ if (candidate.groundingMetadata != null) {
532
+ providerMetadata = {
533
+ vertex: {
534
+ groundingMetadata: candidate.groundingMetadata
535
+ }
536
+ };
537
+ }
431
538
  const content = candidate.content;
432
539
  const deltaText = getTextFromParts(content.parts);
433
540
  if (deltaText != null) {
@@ -464,7 +571,8 @@ var GoogleVertexLanguageModel = class {
464
571
  controller.enqueue({
465
572
  type: "finish",
466
573
  finishReason,
467
- usage
574
+ usage,
575
+ providerMetadata
468
576
  });
469
577
  }
470
578
  }
@@ -478,76 +586,6 @@ var GoogleVertexLanguageModel = class {
478
586
  };
479
587
  }
480
588
  };
481
- function prepareToolsAndToolConfig({
482
- useSearchGrounding,
483
- mode
484
- }) {
485
- var _a;
486
- const tools = ((_a = mode.tools) == null ? void 0 : _a.length) ? mode.tools : void 0;
487
- const mappedTools = tools == null ? [] : [
488
- {
489
- functionDeclarations: tools.map((tool) => {
490
- var _a2;
491
- return {
492
- name: tool.name,
493
- description: (_a2 = tool.description) != null ? _a2 : "",
494
- parameters: convertJSONSchemaToOpenAPISchema(
495
- tool.parameters
496
- )
497
- };
498
- })
499
- }
500
- ];
501
- if (useSearchGrounding) {
502
- mappedTools.push({ googleSearchRetrieval: {} });
503
- }
504
- const finalTools = mappedTools.length > 0 ? mappedTools : void 0;
505
- const toolChoice = mode.toolChoice;
506
- if (toolChoice == null) {
507
- return {
508
- tools: finalTools,
509
- toolConfig: void 0
510
- };
511
- }
512
- const type = toolChoice.type;
513
- switch (type) {
514
- case "auto":
515
- return {
516
- tools: finalTools,
517
- toolConfig: {
518
- functionCallingConfig: { mode: FunctionCallingMode.AUTO }
519
- }
520
- };
521
- case "none":
522
- return {
523
- tools: finalTools,
524
- toolConfig: {
525
- functionCallingConfig: { mode: FunctionCallingMode.NONE }
526
- }
527
- };
528
- case "required":
529
- return {
530
- tools: finalTools,
531
- toolConfig: {
532
- functionCallingConfig: { mode: FunctionCallingMode.ANY }
533
- }
534
- };
535
- case "tool":
536
- return {
537
- tools: finalTools,
538
- toolConfig: {
539
- functionCallingConfig: {
540
- mode: FunctionCallingMode.ANY,
541
- allowedFunctionNames: [toolChoice.toolName]
542
- }
543
- }
544
- };
545
- default: {
546
- const _exhaustiveCheck = type;
547
- throw new Error(`Unsupported tool choice type: ${_exhaustiveCheck}`);
548
- }
549
- }
550
- }
551
589
  function getToolCallsFromParts({
552
590
  parts,
553
591
  generateId: generateId2
@@ -572,23 +610,126 @@ function getTextFromParts(parts) {
572
610
  return textParts.length === 0 ? void 0 : textParts.map((part) => part.text).join("");
573
611
  }
574
612
 
613
+ // src/google-vertex-embedding-model.ts
614
+ import {
615
+ TooManyEmbeddingValuesForCallError
616
+ } from "@ai-sdk/provider";
617
+ import {
618
+ combineHeaders,
619
+ createJsonResponseHandler,
620
+ postJsonToApi
621
+ } from "@ai-sdk/provider-utils";
622
+ import { z as z2 } from "zod";
623
+
624
+ // src/google-error.ts
625
+ import { createJsonErrorResponseHandler } from "@ai-sdk/provider-utils";
626
+ import { z } from "zod";
627
+ var googleErrorDataSchema = z.object({
628
+ error: z.object({
629
+ code: z.number().nullable(),
630
+ message: z.string(),
631
+ status: z.string()
632
+ })
633
+ });
634
+ var googleFailedResponseHandler = createJsonErrorResponseHandler({
635
+ errorSchema: googleErrorDataSchema,
636
+ errorToMessage: (data) => data.error.message
637
+ });
638
+
639
+ // src/google-vertex-embedding-model.ts
640
+ var GoogleVertexEmbeddingModel = class {
641
+ constructor(modelId, settings, config) {
642
+ this.specificationVersion = "v1";
643
+ this.modelId = modelId;
644
+ this.settings = settings;
645
+ this.config = config;
646
+ }
647
+ get provider() {
648
+ return this.config.provider;
649
+ }
650
+ get maxEmbeddingsPerCall() {
651
+ return 2048;
652
+ }
653
+ get supportsParallelCalls() {
654
+ return true;
655
+ }
656
+ async doEmbed({
657
+ values,
658
+ headers,
659
+ abortSignal
660
+ }) {
661
+ if (values.length > this.maxEmbeddingsPerCall) {
662
+ throw new TooManyEmbeddingValuesForCallError({
663
+ provider: this.provider,
664
+ modelId: this.modelId,
665
+ maxEmbeddingsPerCall: this.maxEmbeddingsPerCall,
666
+ values
667
+ });
668
+ }
669
+ const { responseHeaders, value: response } = await postJsonToApi({
670
+ url: `https://${this.config.region}-aiplatform.googleapis.com/v1/projects/${this.config.project}/locations/${this.config.region}/publishers/google/models/${this.modelId}:predict`,
671
+ headers: combineHeaders(
672
+ { Authorization: `Bearer ${await this.config.generateAuthToken()}` },
673
+ headers
674
+ ),
675
+ body: {
676
+ instances: values.map((value) => ({ content: value })),
677
+ parameters: {
678
+ outputDimensionality: this.settings.outputDimensionality
679
+ }
680
+ },
681
+ failedResponseHandler: googleFailedResponseHandler,
682
+ successfulResponseHandler: createJsonResponseHandler(
683
+ googleVertexTextEmbeddingResponseSchema
684
+ ),
685
+ abortSignal
686
+ });
687
+ return {
688
+ embeddings: response.predictions.map(
689
+ (prediction) => prediction.embeddings.values
690
+ ),
691
+ usage: {
692
+ tokens: response.predictions.reduce(
693
+ (tokenCount, prediction) => tokenCount + prediction.embeddings.statistics.token_count,
694
+ 0
695
+ )
696
+ },
697
+ rawResponse: { headers: responseHeaders }
698
+ };
699
+ }
700
+ };
701
+ var googleVertexTextEmbeddingResponseSchema = z2.object({
702
+ predictions: z2.array(
703
+ z2.object({
704
+ embeddings: z2.object({
705
+ values: z2.array(z2.number()),
706
+ statistics: z2.object({
707
+ token_count: z2.number()
708
+ })
709
+ })
710
+ })
711
+ )
712
+ });
713
+
575
714
  // src/google-vertex-provider.ts
576
715
  function createVertex(options = {}) {
716
+ const loadVertexProject = () => loadSetting({
717
+ settingValue: options.project,
718
+ settingName: "project",
719
+ environmentVariableName: "GOOGLE_VERTEX_PROJECT",
720
+ description: "Google Vertex project"
721
+ });
722
+ const loadVertexLocation = () => loadSetting({
723
+ settingValue: options.location,
724
+ settingName: "location",
725
+ environmentVariableName: "GOOGLE_VERTEX_LOCATION",
726
+ description: "Google Vertex location"
727
+ });
577
728
  const createVertexAI = () => {
578
729
  var _a, _b;
579
730
  const config = {
580
- project: loadSetting({
581
- settingValue: options.project,
582
- settingName: "project",
583
- environmentVariableName: "GOOGLE_VERTEX_PROJECT",
584
- description: "Google Vertex project"
585
- }),
586
- location: loadSetting({
587
- settingValue: options.location,
588
- settingName: "location",
589
- environmentVariableName: "GOOGLE_VERTEX_LOCATION",
590
- description: "Google Vertex location"
591
- }),
731
+ project: loadVertexProject(),
732
+ location: loadVertexLocation(),
592
733
  googleAuthOptions: options.googleAuthOptions
593
734
  };
594
735
  return (_b = (_a = options.createVertexAI) == null ? void 0 : _a.call(options, config)) != null ? _b : new VertexAI2(config);
@@ -600,6 +741,15 @@ function createVertex(options = {}) {
600
741
  generateId: (_a = options.generateId) != null ? _a : generateId
601
742
  });
602
743
  };
744
+ const createEmbeddingModel = (modelId, settings = {}) => {
745
+ const vertexAI = createVertexAI();
746
+ return new GoogleVertexEmbeddingModel(modelId, settings, {
747
+ provider: "google.vertex",
748
+ region: loadVertexLocation(),
749
+ project: loadVertexProject(),
750
+ generateAuthToken: () => vertexAI.googleAuth.getAccessToken()
751
+ });
752
+ };
603
753
  const provider = function(modelId, settings) {
604
754
  if (new.target) {
605
755
  throw new Error(
@@ -609,9 +759,7 @@ function createVertex(options = {}) {
609
759
  return createChatModel(modelId, settings);
610
760
  };
611
761
  provider.languageModel = createChatModel;
612
- provider.textEmbeddingModel = (modelId) => {
613
- throw new NoSuchModelError({ modelId, modelType: "textEmbeddingModel" });
614
- };
762
+ provider.textEmbeddingModel = createEmbeddingModel;
615
763
  return provider;
616
764
  }
617
765
  var vertex = createVertex();