@llumiverse/drivers 0.9.2 → 0.10.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. package/README.md +50 -0
  2. package/lib/cjs/bedrock/index.js +4 -3
  3. package/lib/cjs/bedrock/index.js.map +1 -1
  4. package/lib/cjs/huggingface_ie.js +1 -6
  5. package/lib/cjs/huggingface_ie.js.map +1 -1
  6. package/lib/cjs/mistral/index.js +5 -9
  7. package/lib/cjs/mistral/index.js.map +1 -1
  8. package/lib/cjs/openai.js +2 -2
  9. package/lib/cjs/openai.js.map +1 -1
  10. package/lib/cjs/replicate.js +3 -4
  11. package/lib/cjs/replicate.js.map +1 -1
  12. package/lib/cjs/test/index.js.map +1 -1
  13. package/lib/cjs/togetherai/index.js +5 -9
  14. package/lib/cjs/togetherai/index.js.map +1 -1
  15. package/lib/cjs/vertexai/embeddings/embeddings-text.js +23 -0
  16. package/lib/cjs/vertexai/embeddings/embeddings-text.js.map +1 -0
  17. package/lib/cjs/vertexai/index.js +3 -5
  18. package/lib/cjs/vertexai/index.js.map +1 -1
  19. package/lib/esm/bedrock/index.js +4 -3
  20. package/lib/esm/bedrock/index.js.map +1 -1
  21. package/lib/esm/huggingface_ie.js +1 -6
  22. package/lib/esm/huggingface_ie.js.map +1 -1
  23. package/lib/esm/mistral/index.js +5 -9
  24. package/lib/esm/mistral/index.js.map +1 -1
  25. package/lib/esm/openai.js +2 -2
  26. package/lib/esm/openai.js.map +1 -1
  27. package/lib/esm/replicate.js +3 -4
  28. package/lib/esm/replicate.js.map +1 -1
  29. package/lib/esm/test/index.js.map +1 -1
  30. package/lib/esm/togetherai/index.js +5 -9
  31. package/lib/esm/togetherai/index.js.map +1 -1
  32. package/lib/esm/vertexai/embeddings/embeddings-text.js +19 -0
  33. package/lib/esm/vertexai/embeddings/embeddings-text.js.map +1 -0
  34. package/lib/esm/vertexai/index.js +3 -5
  35. package/lib/esm/vertexai/index.js.map +1 -1
  36. package/lib/types/bedrock/index.d.ts +2 -5
  37. package/lib/types/bedrock/index.d.ts.map +1 -1
  38. package/lib/types/huggingface_ie.d.ts +2 -6
  39. package/lib/types/huggingface_ie.d.ts.map +1 -1
  40. package/lib/types/mistral/index.d.ts +2 -6
  41. package/lib/types/mistral/index.d.ts.map +1 -1
  42. package/lib/types/openai.d.ts +2 -5
  43. package/lib/types/openai.d.ts.map +1 -1
  44. package/lib/types/replicate.d.ts +2 -5
  45. package/lib/types/replicate.d.ts.map +1 -1
  46. package/lib/types/test/index.d.ts +2 -5
  47. package/lib/types/test/index.d.ts.map +1 -1
  48. package/lib/types/togetherai/index.d.ts +2 -6
  49. package/lib/types/togetherai/index.d.ts.map +1 -1
  50. package/lib/types/vertexai/embeddings/embeddings-text.d.ts +10 -0
  51. package/lib/types/vertexai/embeddings/embeddings-text.d.ts.map +1 -0
  52. package/lib/types/vertexai/index.d.ts +3 -6
  53. package/lib/types/vertexai/index.d.ts.map +1 -1
  54. package/package.json +2 -2
  55. package/src/bedrock/index.ts +5 -4
  56. package/src/huggingface_ie.ts +2 -7
  57. package/src/mistral/index.ts +6 -10
  58. package/src/openai.ts +4 -2
  59. package/src/replicate.ts +4 -4
  60. package/src/test/index.ts +2 -2
  61. package/src/togetherai/index.ts +6 -10
  62. package/src/vertexai/embeddings/embeddings-text.ts +52 -0
  63. package/src/vertexai/index.ts +6 -6
package/lib/types/vertexai/index.d.ts.map CHANGED
@@ -1 +1 @@
1
- {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../../src/vertexai/index.ts"],"names":[],"mappings":"AACA,OAAO,EAAE,sBAAsB,EAAE,QAAQ,EAAE,MAAM,wBAAwB,CAAC;AAC1E,OAAO,EAAE,OAAO,EAAE,cAAc,EAAE,gBAAgB,EAAE,UAAU,EAAE,aAAa,EAAE,gBAAgB,EAAE,kBAAkB,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,MAAM,kBAAkB,CAAC;AAC3L,OAAO,EAAE,WAAW,EAAE,MAAM,kBAAkB,CAAC;AAI/C,MAAM,WAAW,qBAAsB,SAAQ,aAAa;IACxD,OAAO,EAAE,MAAM,CAAC;IAChB,MAAM,EAAE,MAAM,CAAC;CAClB;AAED,qBAAa,cAAe,SAAQ,cAAc,CAAC,qBAAqB,EAAE,sBAAsB,CAAC;IAC7F,QAAQ,mBAA6B;IACrC,aAAa,gBAAgC;IAG7C,QAAQ,EAAE,QAAQ,CAAC;IACnB,WAAW,EAAE,WAAW,CAAC;gBAGrB,OAAO,EAAE,qBAAqB;IAqBlC,SAAS,CAAC,SAAS,CAAC,OAAO,EAAE,gBAAgB,GAAG,OAAO,CAAC,OAAO,CAAC;IAIzD,YAAY,CAAC,QAAQ,EAAE,aAAa,EAAE,EAAE,OAAO,EAAE,aAAa,GAAG,sBAAsB;IAIxF,iBAAiB,CAAC,MAAM,EAAE,sBAAsB,EAAE,OAAO,EAAE,gBAAgB,GAAG,OAAO,CAAC,UAAU,CAAC,GAAG,CAAC,CAAC;IAGtG,uBAAuB,CAAC,MAAM,EAAE,sBAAsB,EAAE,OAAO,EAAE,gBAAgB,GAAG,OAAO,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC;IAIlH,UAAU,CAAC,OAAO,CAAC,EAAE,kBAAkB,GAAG,OAAO,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;IAiB1E,mBAAmB,IAAI,OAAO,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;IAGjD,kBAAkB,IAAI,OAAO,CAAC,OAAO,CAAC;IAGtC,kBAAkB,CAAC,QAAQ,EAAE,MAAM,EAAE,MAAM,CAAC,EAAE,MAAM,GAAG,SAAS,GAAG,OAAO,CAAC;QAAE,UAAU,EAAE,MAAM,EAAE,CAAC;QAAC,KAAK,EAAE,MAAM,CAAC;KAAE,CAAC;CAIvH"}
1
+ {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../../src/vertexai/index.ts"],"names":[],"mappings":"AACA,OAAO,EAAE,sBAAsB,EAAE,QAAQ,EAAE,MAAM,wBAAwB,CAAC;AAC1E,OAAO,EAAE,OAAO,EAAE,cAAc,EAAE,gBAAgB,EAAE,UAAU,EAAE,aAAa,EAAE,gBAAgB,EAAE,gBAAgB,EAAE,kBAAkB,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,MAAM,kBAAkB,CAAC;AAC7M,OAAO,EAAE,WAAW,EAAE,MAAM,kBAAkB,CAAC;AAC/C,OAAO,EAAE,qBAAqB,EAAwB,MAAM,iCAAiC,CAAC;AAI9F,MAAM,WAAW,qBAAsB,SAAQ,aAAa;IACxD,OAAO,EAAE,MAAM,CAAC;IAChB,MAAM,EAAE,MAAM,CAAC;CAClB;AAED,qBAAa,cAAe,SAAQ,cAAc,CAAC,qBAAqB,EAAE,sBAAsB,CAAC;IAC7F,QAAQ,mBAA6B;IACrC,aAAa,gBAAgC;IAG7C,QAAQ,EAAE,QAAQ,CAAC;IACnB,WAAW,EAAE,WAAW,CAAC;gBAGrB,OAAO,EAAE,qBAAqB;IAqBlC,SAAS,CAAC,SAAS,CAAC,OAAO,EAAE,gBAAgB,GAAG,OAAO,CAAC,OAAO,CAAC;IAIzD,YAAY,CAAC,QAAQ,EAAE,aAAa,EAAE,EAAE,OAAO,EAAE,aAAa,GAAG,sBAAsB;IAIxF,iBAAiB,CAAC,MAAM,EAAE,sBAAsB,EAAE,OAAO,EAAE,gBAAgB,GAAG,OAAO,CAAC,UAAU,CAAC,GAAG,CAAC,CAAC;IAGtG,uBAAuB,CAAC,MAAM,EAAE,sBAAsB,EAAE,OAAO,EAAE,gBAAgB,GAAG,OAAO,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC;IAIlH,UAAU,CAAC,OAAO,CAAC,EAAE,kBAAkB,GAAG,OAAO,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;IAkB1E,kBAAkB,IAAI,OAAO,CAAC,OAAO,CAAC;IAIhC,kBAAkB,CAAC,OAAO,EAAE,qBAAqB,GAAG,OAAO,CAAC,gBAAgB,CAAC;CAItF"}
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@llumiverse/drivers",
3
- "version": "0.9.2",
3
+ "version": "0.10.0",
4
4
  "type": "module",
5
5
  "description": "LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.",
6
6
  "files": [
@@ -63,7 +63,7 @@
63
63
  "@google-cloud/aiplatform": "^3.10.0",
64
64
  "@google-cloud/vertexai": "^0.2.1",
65
65
  "@huggingface/inference": "^2.6.4",
66
- "@llumiverse/core": "^0.9.2",
66
+ "@llumiverse/core": "^0.10.0",
67
67
  "api-fetch-client": "^0.8.6",
68
68
  "eventsource": "^2.0.2",
69
69
  "google-auth-library": "^9.6.1",
package/src/bedrock/index.ts CHANGED
@@ -1,7 +1,7 @@
1
1
  import { Bedrock, CreateModelCustomizationJobCommand, FoundationModelSummary, GetModelCustomizationJobCommand, GetModelCustomizationJobCommandOutput, ModelCustomizationJobStatus, StopModelCustomizationJobCommand } from "@aws-sdk/client-bedrock";
2
2
  import { BedrockRuntime, InvokeModelCommandOutput, ResponseStream } from "@aws-sdk/client-bedrock-runtime";
3
3
  import { S3Client } from "@aws-sdk/client-s3";
4
- import { AIModel, AbstractDriver, BuiltinProviders, Completion, DataSource, DriverOptions, ExecutionOptions, PromptFormats, PromptFormatters, PromptOptions, PromptSegment, TrainingJob, TrainingJobStatus, TrainingOptions } from "@llumiverse/core";
4
+ import { AIModel, AbstractDriver, BuiltinProviders, Completion, DataSource, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptFormats, PromptFormatters, PromptOptions, PromptSegment, TrainingJob, TrainingJobStatus, TrainingOptions } from "@llumiverse/core";
5
5
  import { transformAsyncIterator } from "@llumiverse/core/async";
6
6
  import { AwsCredentialIdentity, Provider } from "@smithy/types";
7
7
  import mnemonist from "mnemonist";
@@ -203,7 +203,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
203
203
  anthropic_version: "bedrock-2023-05-31",
204
204
  ...(prompt as ClaudeMessagesPrompt),
205
205
  temperature: options.temperature,
206
- max_tokens: options.max_tokens ?? 256,
206
+ max_tokens: options.max_tokens,
207
207
  } as ClaudeRequestPayload;
208
208
  } else if (contains(options.model, "ai21")) {
209
209
  return {
@@ -374,7 +374,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
374
374
  return aimodels;
375
375
  }
376
376
 
377
- async generateEmbeddings(content: string, model: string = "amazon.titan-embed-text-v1"): Promise<{ embeddings: number[], model: string; }> {
377
+ async generateEmbeddings({ content, model = "amazon.titan-embed-text-v1" }: EmbeddingsOptions): Promise<EmbeddingsResult> {
378
378
 
379
379
  this.logger.info("[Bedrock] Generating embeddings with model " + model);
380
380
 
@@ -401,8 +401,9 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
401
401
  }
402
402
 
403
403
  return {
404
- embeddings: result.embedding,
404
+ values: result.embedding,
405
405
  model: model,
406
+ token_count: result.inputTextTokenCount
406
407
  };
407
408
 
408
409
  }
package/src/huggingface_ie.ts CHANGED
@@ -9,6 +9,7 @@ import {
9
9
  AbstractDriver,
10
10
  BuiltinProviders,
11
11
  DriverOptions,
12
+ EmbeddingsResult,
12
13
  ExecutionOptions,
13
14
  PromptFormats
14
15
  } from "@llumiverse/core";
@@ -106,11 +107,6 @@ export class HuggingFaceIEDriver extends AbstractDriver<HuggingFaceIEDriverOptio
106
107
 
107
108
  // ============== management API ==============
108
109
 
109
- // Not implemented
110
- async listTrainableModels(): Promise<AIModel<string>[]> {
111
- return [];
112
- }
113
-
114
110
  async listModels(): Promise<AIModel[]> {
115
111
  const res = await this.service.get("/");
116
112
  const hfModels = res.items as HuggingFaceIEModel[];
@@ -136,8 +132,7 @@ export class HuggingFaceIEDriver extends AbstractDriver<HuggingFaceIEDriverOptio
136
132
  }
137
133
  }
138
134
 
139
- async generateEmbeddings(content: string, model?: string): Promise<{ embeddings: number[], model: string; }> {
140
- this.logger?.debug(`[Huggingface] Generating embeddings for ${content} on ${model}`);
135
+ async generateEmbeddings(): Promise<EmbeddingsResult> {
141
136
  throw new Error("Method not implemented.");
142
137
  }
143
138
 
package/src/mistral/index.ts CHANGED
@@ -1,4 +1,4 @@
1
- import { AIModel, AbstractDriver, Completion, DriverOptions, ExecutionOptions, PromptFormats, PromptSegment } from "@llumiverse/core";
1
+ import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsResult, ExecutionOptions, PromptFormats, PromptSegment } from "@llumiverse/core";
2
2
  import { transformSSEStream } from "@llumiverse/core/async";
3
3
  import { FetchClient } from "api-fetch-client";
4
4
  import { CompletionRequestParams, ListModelsResponse, ResponseFormat } from "./types.js";
@@ -73,8 +73,8 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, LLMM
73
73
  payload: _makeChatCompletionRequest({
74
74
  model: options.model,
75
75
  messages: messages,
76
- maxTokens: options.max_tokens ?? 1024,
77
- temperature: options.temperature ?? 0.7,
76
+ maxTokens: options.max_tokens,
77
+ temperature: options.temperature,
78
78
  responseFormat: this.getResponseFormat(options),
79
79
  })
80
80
  })
@@ -97,8 +97,8 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, LLMM
97
97
  payload: _makeChatCompletionRequest({
98
98
  model: options.model,
99
99
  messages: messages,
100
- maxTokens: options.max_tokens ?? 1024,
101
- temperature: options.temperature ?? 0.7,
100
+ maxTokens: options.max_tokens,
101
+ temperature: options.temperature,
102
102
  responseFormat: this.getResponseFormat(options),
103
103
  stream: true
104
104
  }),
@@ -128,14 +128,10 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, LLMM
128
128
  return aimodels;
129
129
  }
130
130
 
131
- listTrainableModels(): Promise<AIModel<string>[]> {
132
- throw new Error("Method not implemented.");
133
- }
134
131
  validateConnection(): Promise<boolean> {
135
132
  throw new Error("Method not implemented.");
136
133
  }
137
- //@ts-ignore
138
- generateEmbeddings(content: string, model?: string | undefined): Promise<{ embeddings: number[]; model: string; }> {
134
+ generateEmbeddings(): Promise<EmbeddingsResult> {
139
135
  throw new Error("Method not implemented.");
140
136
  }
141
137
 
package/src/openai.ts CHANGED
@@ -5,6 +5,8 @@ import {
5
5
  Completion,
6
6
  DataSource,
7
7
  DriverOptions,
8
+ EmbeddingsOptions,
9
+ EmbeddingsResult,
8
10
  ExecutionOptions,
9
11
  ExecutionTokenUsage,
10
12
  ModelType,
@@ -211,7 +213,7 @@ export class OpenAIDriver extends AbstractDriver<
211
213
  }
212
214
 
213
215
 
214
- async generateEmbeddings(content: string, model: string = "text-embedding-ada-002"): Promise<{ embeddings: number[], model: string; }> {
216
+ async generateEmbeddings({ content, model = "text-embedding-ada-002" }: EmbeddingsOptions): Promise<EmbeddingsResult> {
215
217
  const res = await this.service.embeddings.create({
216
218
  input: content,
217
219
  model: model,
@@ -223,7 +225,7 @@ export class OpenAIDriver extends AbstractDriver<
223
225
  throw new Error("No embedding found");
224
226
  }
225
227
 
226
- return { embeddings, model };
228
+ return { values: embeddings, model } as EmbeddingsResult;
227
229
  }
228
230
 
229
231
  }
package/src/replicate.ts CHANGED
@@ -5,6 +5,7 @@ import {
5
5
  Completion,
6
6
  DataSource,
7
7
  DriverOptions,
8
+ EmbeddingsResult,
8
9
  ExecutionOptions,
9
10
  ModelSearchPayload,
10
11
  PromptFormats,
@@ -73,7 +74,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
73
74
  const predictionData = {
74
75
  input: {
75
76
  prompt: prompt,
76
- max_new_tokens: options.max_tokens || 1024,
77
+ max_new_tokens: options.max_tokens,
77
78
  temperature: options.temperature,
78
79
  },
79
80
  version: model.version,
@@ -113,7 +114,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
113
114
  const predictionData = {
114
115
  input: {
115
116
  prompt: prompt,
116
- max_new_tokens: options.max_tokens || 1024,
117
+ max_new_tokens: options.max_tokens,
117
118
  temperature: options.temperature,
118
119
  },
119
120
  version: model.version,
@@ -282,8 +283,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
282
283
  return models;
283
284
  }
284
285
 
285
- generateEmbeddings(content: string, model?: string): Promise<{ embeddings: number[], model: string; }> {
286
- this.logger?.debug(`[Replicate] Generating embeddings for ${content} on ${model}`);
286
+ async generateEmbeddings(): Promise<EmbeddingsResult> {
287
287
  throw new Error("Method not implemented.");
288
288
  }
289
289
 
package/src/test/index.ts CHANGED
@@ -1,4 +1,4 @@
1
- import { AIModel, AIModelStatus, CompletionStream, Driver, ExecutionOptions, ExecutionResponse, ModelType, PromptOptions, PromptSegment, TrainingJob } from "@llumiverse/core";
1
+ import { AIModel, AIModelStatus, CompletionStream, Driver, EmbeddingsResult, ExecutionOptions, ExecutionResponse, ModelType, PromptOptions, PromptSegment, TrainingJob } from "@llumiverse/core";
2
2
  import { TestErrorCompletionStream } from "./TestErrorCompletionStream.js";
3
3
  import { TestValidationErrorCompletionStream } from "./TestValidationErrorCompletionStream.js";
4
4
  import { createValidationErrorCompletion, sleep, throwError } from "./utils.js";
@@ -83,7 +83,7 @@ export class TestDriver implements Driver<PromptSegment[]> {
83
83
  validateConnection(): Promise<boolean> {
84
84
  throw new Error("Method not implemented.");
85
85
  }
86
- generateEmbeddings(): Promise<{ embeddings: number[]; model: string; }> {
86
+ generateEmbeddings(): Promise<EmbeddingsResult> {
87
87
  throw new Error("Method not implemented.");
88
88
  }
89
89
 
package/src/togetherai/index.ts CHANGED
@@ -1,4 +1,4 @@
1
- import { AIModel, AbstractDriver, Completion, DriverOptions, ExecutionOptions, PromptFormats } from "@llumiverse/core";
1
+ import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsResult, ExecutionOptions, PromptFormats } from "@llumiverse/core";
2
2
  import { transformSSEStream } from "@llumiverse/core/async";
3
3
  import { FetchClient } from "api-fetch-client";
4
4
  import { TogetherModelInfo } from "./interfaces.js";
@@ -37,8 +37,8 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
37
37
  model: options.model,
38
38
  prompt: prompt,
39
39
  response_format: this.getResponseFormat(options),
40
- max_tokens: options.max_tokens ?? 1024,
41
- temperature: options.temperature ?? 0.7,
40
+ max_tokens: options.max_tokens,
41
+ temperature: options.temperature,
42
42
  stop: [
43
43
  "</s>",
44
44
  "[/INST]"
@@ -64,8 +64,8 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
64
64
  payload: {
65
65
  model: options.model,
66
66
  prompt: prompt,
67
- max_tokens: options.max_tokens ?? 1024,
68
- temperature: options.temperature ?? 0.7,
67
+ max_tokens: options.max_tokens,
68
+ temperature: options.temperature,
69
69
  response_format: this.getResponseFormat(options),
70
70
  stream: true,
71
71
  stop: [
@@ -101,14 +101,10 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
101
101
 
102
102
  }
103
103
 
104
- listTrainableModels(): Promise<AIModel<string>[]> {
105
- throw new Error("Method not implemented.");
106
- }
107
104
  validateConnection(): Promise<boolean> {
108
105
  throw new Error("Method not implemented.");
109
106
  }
110
- //@ts-ignore
111
- generateEmbeddings(content: string, model?: string | undefined): Promise<{ embeddings: number[]; model: string; }> {
107
+ generateEmbeddings(): Promise<EmbeddingsResult> {
112
108
  throw new Error("Method not implemented.");
113
109
  }
114
110
 
package/src/vertexai/embeddings/embeddings-text.ts ADDED
@@ -0,0 +1,52 @@
1
+
2
+ import { EmbeddingsResult } from '@llumiverse/core';
3
+ import { VertexAIDriver } from '../index.js';
4
+
5
+ export interface TextEmbeddingsOptions {
6
+ model?: string;
7
+ task_type?: "RETRIEVAL_QUERY" | "RETRIEVAL_DOCUMENT" | "SEMANTIC_SIMILARITY" | "CLASSIFICATION" | "CLUSTERING",
8
+ title?: string, // the title for the embedding
9
+ content: string // the text to generate embeddings for
10
+ }
11
+
12
+ interface EmbedingsForTextPrompt {
13
+ instances: TextEmbeddingsOptions[]
14
+ }
15
+
16
+ interface TextEmbeddingsResult {
17
+ predictions: [
18
+ {
19
+ embeddings: TextEmbeddings
20
+ }
21
+ ]
22
+ }
23
+
24
+ interface TextEmbeddings {
25
+ statistics: {
26
+ truncated: boolean,
27
+ token_count: number
28
+ },
29
+ values: [number]
30
+ }
31
+
32
+ export async function getEmbeddingsForText(driver: VertexAIDriver, options: TextEmbeddingsOptions): Promise<EmbeddingsResult> {
33
+ const prompt = {
34
+ instances: [{
35
+ task_type: options.task_type,
36
+ title: options.title,
37
+ content: options.content
38
+ }]
39
+ } as EmbedingsForTextPrompt;
40
+
41
+ const model = options.model || "textembedding-gecko@latest";
42
+
43
+ const result = await driver.fetchClient.post(`/publishers/google/models/${model}:predict`, {
44
+ payload: prompt
45
+ }) as TextEmbeddingsResult;
46
+
47
+ return {
48
+ ...result.predictions[0].embeddings,
49
+ model,
50
+ token_count: result.predictions[0].embeddings.statistics?.token_count
51
+ };
52
+ }
package/src/vertexai/index.ts CHANGED
@@ -1,7 +1,8 @@
1
1
  //import { v1 } from "@google-cloud/aiplatform";
2
2
  import { GenerateContentRequest, VertexAI } from "@google-cloud/vertexai";
3
- import { AIModel, AbstractDriver, BuiltinProviders, Completion, DriverOptions, ExecutionOptions, ModelSearchPayload, PromptFormats, PromptOptions, PromptSegment } from "@llumiverse/core";
3
+ import { AIModel, AbstractDriver, BuiltinProviders, Completion, DriverOptions, EmbeddingsResult, ExecutionOptions, ModelSearchPayload, PromptFormats, PromptOptions, PromptSegment } from "@llumiverse/core";
4
4
  import { FetchClient } from "api-fetch-client";
5
+ import { TextEmbeddingsOptions, getEmbeddingsForText } from "./embeddings/embeddings-text.js";
5
6
  import { BuiltinModels, getModelDefinition } from "./models.js";
6
7
  //import { GoogleAuth } from "google-auth-library";
7
8
 
@@ -72,14 +73,13 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Genera
72
73
 
73
74
  return []; //TODO
74
75
  }
75
- listTrainableModels(): Promise<AIModel<string>[]> {
76
- throw new Error("Method not implemented.");
77
- }
76
+
78
77
  validateConnection(): Promise<boolean> {
79
78
  throw new Error("Method not implemented.");
80
79
  }
81
- generateEmbeddings(_content: string, _model?: string | undefined): Promise<{ embeddings: number[]; model: string; }> {
82
- throw new Error("Method not implemented.");
80
+
81
+ async generateEmbeddings(options: TextEmbeddingsOptions): Promise<EmbeddingsResult> {
82
+ return getEmbeddingsForText(this, options);
83
83
  }
84
84
 
85
85
  }