@llumiverse/drivers 0.9.2 → 0.11.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. package/README.md +51 -7
  2. package/lib/cjs/bedrock/index.js +32 -8
  3. package/lib/cjs/bedrock/index.js.map +1 -1
  4. package/lib/cjs/bedrock/s3.js.map +1 -1
  5. package/lib/cjs/huggingface_ie.js +1 -7
  6. package/lib/cjs/huggingface_ie.js.map +1 -1
  7. package/lib/cjs/mistral/index.js +31 -29
  8. package/lib/cjs/mistral/index.js.map +1 -1
  9. package/lib/cjs/openai.js +4 -7
  10. package/lib/cjs/openai.js.map +1 -1
  11. package/lib/cjs/replicate.js +3 -5
  12. package/lib/cjs/replicate.js.map +1 -1
  13. package/lib/cjs/test/index.js.map +1 -1
  14. package/lib/cjs/togetherai/index.js +5 -12
  15. package/lib/cjs/togetherai/index.js.map +1 -1
  16. package/lib/cjs/vertexai/embeddings/embeddings-text.js +23 -0
  17. package/lib/cjs/vertexai/embeddings/embeddings-text.js.map +1 -0
  18. package/lib/cjs/vertexai/index.js +5 -8
  19. package/lib/cjs/vertexai/index.js.map +1 -1
  20. package/lib/cjs/vertexai/models/codey-chat.js +2 -2
  21. package/lib/cjs/vertexai/models/codey-chat.js.map +1 -1
  22. package/lib/cjs/vertexai/models/codey-text.js +2 -2
  23. package/lib/cjs/vertexai/models/codey-text.js.map +1 -1
  24. package/lib/cjs/vertexai/models/gemini.js +2 -2
  25. package/lib/cjs/vertexai/models/gemini.js.map +1 -1
  26. package/lib/cjs/vertexai/models/palm-model-base.js.map +1 -1
  27. package/lib/cjs/vertexai/models/palm2-chat.js +2 -2
  28. package/lib/cjs/vertexai/models/palm2-chat.js.map +1 -1
  29. package/lib/cjs/vertexai/models/palm2-text.js +2 -2
  30. package/lib/cjs/vertexai/models/palm2-text.js.map +1 -1
  31. package/lib/esm/bedrock/index.js +33 -9
  32. package/lib/esm/bedrock/index.js.map +1 -1
  33. package/lib/esm/bedrock/s3.js.map +1 -1
  34. package/lib/esm/huggingface_ie.js +2 -8
  35. package/lib/esm/huggingface_ie.js.map +1 -1
  36. package/lib/esm/mistral/index.js +32 -30
  37. package/lib/esm/mistral/index.js.map +1 -1
  38. package/lib/esm/openai.js +5 -8
  39. package/lib/esm/openai.js.map +1 -1
  40. package/lib/esm/replicate.js +4 -6
  41. package/lib/esm/replicate.js.map +1 -1
  42. package/lib/esm/src/bedrock/index.js +375 -0
  43. package/lib/esm/src/bedrock/index.js.map +1 -0
  44. package/lib/esm/src/bedrock/s3.js +53 -0
  45. package/lib/esm/src/bedrock/s3.js.map +1 -0
  46. package/lib/esm/src/huggingface_ie.js +173 -0
  47. package/lib/esm/src/huggingface_ie.js.map +1 -0
  48. package/lib/esm/src/index.js +9 -0
  49. package/lib/esm/src/index.js.map +1 -0
  50. package/lib/esm/src/mistral/index.js +145 -0
  51. package/lib/esm/src/mistral/index.js.map +1 -0
  52. package/lib/esm/src/mistral/types.js +80 -0
  53. package/lib/esm/src/mistral/types.js.map +1 -0
  54. package/lib/esm/src/openai.js +195 -0
  55. package/lib/esm/src/openai.js.map +1 -0
  56. package/lib/esm/src/replicate.js +281 -0
  57. package/lib/esm/src/replicate.js.map +1 -0
  58. package/lib/esm/src/test/TestErrorCompletionStream.js +16 -0
  59. package/lib/esm/src/test/TestErrorCompletionStream.js.map +1 -0
  60. package/lib/esm/src/test/TestValidationErrorCompletionStream.js +20 -0
  61. package/lib/esm/src/test/TestValidationErrorCompletionStream.js.map +1 -0
  62. package/lib/esm/src/test/index.js +91 -0
  63. package/lib/esm/src/test/index.js.map +1 -0
  64. package/lib/esm/src/test/utils.js +25 -0
  65. package/lib/esm/src/test/utils.js.map +1 -0
  66. package/lib/esm/src/togetherai/index.js +89 -0
  67. package/lib/esm/src/togetherai/index.js.map +1 -0
  68. package/lib/esm/src/togetherai/interfaces.js +2 -0
  69. package/lib/esm/src/togetherai/interfaces.js.map +1 -0
  70. package/lib/esm/src/vertexai/debug.js +6 -0
  71. package/lib/esm/src/vertexai/debug.js.map +1 -0
  72. package/lib/esm/src/vertexai/embeddings/embeddings-text.js +19 -0
  73. package/lib/esm/src/vertexai/embeddings/embeddings-text.js.map +1 -0
  74. package/lib/esm/src/vertexai/index.js +73 -0
  75. package/lib/esm/src/vertexai/index.js.map +1 -0
  76. package/lib/esm/src/vertexai/models/codey-chat.js +61 -0
  77. package/lib/esm/src/vertexai/models/codey-chat.js.map +1 -0
  78. package/lib/esm/src/vertexai/models/codey-text.js +31 -0
  79. package/lib/esm/src/vertexai/models/codey-text.js.map +1 -0
  80. package/lib/esm/src/vertexai/models/gemini.js +136 -0
  81. package/lib/esm/src/vertexai/models/gemini.js.map +1 -0
  82. package/lib/esm/src/vertexai/models/palm-model-base.js +53 -0
  83. package/lib/esm/src/vertexai/models/palm-model-base.js.map +1 -0
  84. package/lib/esm/src/vertexai/models/palm2-chat.js +61 -0
  85. package/lib/esm/src/vertexai/models/palm2-chat.js.map +1 -0
  86. package/lib/esm/src/vertexai/models/palm2-text.js +31 -0
  87. package/lib/esm/src/vertexai/models/palm2-text.js.map +1 -0
  88. package/lib/esm/src/vertexai/models.js +87 -0
  89. package/lib/esm/src/vertexai/models.js.map +1 -0
  90. package/{src/vertexai/utils/prompts.ts → lib/esm/src/vertexai/utils/prompts.js} +10 -29
  91. package/lib/esm/src/vertexai/utils/prompts.js.map +1 -0
  92. package/lib/esm/src/vertexai/utils/tensor.js +82 -0
  93. package/lib/esm/src/vertexai/utils/tensor.js.map +1 -0
  94. package/lib/esm/test/index.js.map +1 -1
  95. package/lib/esm/togetherai/index.js +6 -13
  96. package/lib/esm/togetherai/index.js.map +1 -1
  97. package/lib/esm/tsconfig.tsbuildinfo +1 -0
  98. package/lib/esm/vertexai/embeddings/embeddings-text.js +19 -0
  99. package/lib/esm/vertexai/embeddings/embeddings-text.js.map +1 -0
  100. package/lib/esm/vertexai/index.js +6 -9
  101. package/lib/esm/vertexai/index.js.map +1 -1
  102. package/lib/esm/vertexai/models/codey-chat.js +1 -1
  103. package/lib/esm/vertexai/models/codey-chat.js.map +1 -1
  104. package/lib/esm/vertexai/models/codey-text.js +2 -2
  105. package/lib/esm/vertexai/models/codey-text.js.map +1 -1
  106. package/lib/esm/vertexai/models/gemini.js +2 -2
  107. package/lib/esm/vertexai/models/gemini.js.map +1 -1
  108. package/lib/esm/vertexai/models/palm-model-base.js.map +1 -1
  109. package/lib/esm/vertexai/models/palm2-chat.js +1 -1
  110. package/lib/esm/vertexai/models/palm2-chat.js.map +1 -1
  111. package/lib/esm/vertexai/models/palm2-text.js +2 -2
  112. package/lib/esm/vertexai/models/palm2-text.js.map +1 -1
  113. package/lib/types/bedrock/index.d.ts +12 -9
  114. package/lib/types/bedrock/index.d.ts.map +1 -1
  115. package/lib/types/bedrock/s3.d.ts +2 -5
  116. package/lib/types/bedrock/s3.d.ts.map +1 -1
  117. package/lib/types/huggingface_ie.d.ts +5 -10
  118. package/lib/types/huggingface_ie.d.ts.map +1 -1
  119. package/lib/types/mistral/index.d.ts +7 -15
  120. package/lib/types/mistral/index.d.ts.map +1 -1
  121. package/lib/types/openai.d.ts +2 -7
  122. package/lib/types/openai.d.ts.map +1 -1
  123. package/lib/types/replicate.d.ts +2 -6
  124. package/lib/types/replicate.d.ts.map +1 -1
  125. package/lib/types/src/bedrock/index.d.ts +94 -0
  126. package/lib/types/src/bedrock/s3.d.ts +16 -0
  127. package/lib/types/src/huggingface_ie.d.ts +30 -0
  128. package/lib/types/src/index.d.ts +8 -0
  129. package/lib/types/src/mistral/index.d.ts +23 -0
  130. package/lib/types/src/mistral/types.d.ts +130 -0
  131. package/lib/types/src/openai.d.ts +30 -0
  132. package/lib/types/src/replicate.d.ts +47 -0
  133. package/lib/types/src/test/TestErrorCompletionStream.d.ts +8 -0
  134. package/lib/types/src/test/TestValidationErrorCompletionStream.d.ts +8 -0
  135. package/lib/types/src/test/index.d.ts +23 -0
  136. package/lib/types/src/test/utils.d.ts +4 -0
  137. package/lib/types/src/togetherai/index.d.ts +21 -0
  138. package/lib/types/src/togetherai/interfaces.d.ts +80 -0
  139. package/lib/types/src/vertexai/debug.d.ts +1 -0
  140. package/lib/types/src/vertexai/embeddings/embeddings-text.d.ts +9 -0
  141. package/lib/types/src/vertexai/index.d.ts +21 -0
  142. package/lib/types/src/vertexai/models/codey-chat.d.ts +50 -0
  143. package/lib/types/src/vertexai/models/codey-text.d.ts +38 -0
  144. package/lib/types/src/vertexai/models/gemini.d.ts +10 -0
  145. package/lib/types/src/vertexai/models/palm-model-base.d.ts +60 -0
  146. package/lib/types/src/vertexai/models/palm2-chat.d.ts +60 -0
  147. package/lib/types/src/vertexai/models/palm2-text.d.ts +38 -0
  148. package/lib/types/src/vertexai/models.d.ts +13 -0
  149. package/lib/types/src/vertexai/utils/prompts.d.ts +19 -0
  150. package/lib/types/src/vertexai/utils/tensor.d.ts +5 -0
  151. package/lib/types/test/index.d.ts +2 -5
  152. package/lib/types/test/index.d.ts.map +1 -1
  153. package/lib/types/togetherai/index.d.ts +2 -7
  154. package/lib/types/togetherai/index.d.ts.map +1 -1
  155. package/lib/types/vertexai/embeddings/embeddings-text.d.ts +10 -0
  156. package/lib/types/vertexai/embeddings/embeddings-text.d.ts.map +1 -0
  157. package/lib/types/vertexai/index.d.ts +3 -7
  158. package/lib/types/vertexai/index.d.ts.map +1 -1
  159. package/lib/types/vertexai/models/codey-chat.d.ts.map +1 -1
  160. package/lib/types/vertexai/models/codey-text.d.ts.map +1 -1
  161. package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
  162. package/lib/types/vertexai/models/palm-model-base.d.ts +15 -1
  163. package/lib/types/vertexai/models/palm-model-base.d.ts.map +1 -1
  164. package/lib/types/vertexai/models/palm2-chat.d.ts.map +1 -1
  165. package/lib/types/vertexai/models/palm2-text.d.ts.map +1 -1
  166. package/package.json +15 -16
  167. package/src/bedrock/index.ts +37 -12
  168. package/src/bedrock/s3.ts +2 -3
  169. package/src/huggingface_ie.ts +3 -10
  170. package/src/mistral/index.ts +36 -43
  171. package/src/openai.ts +7 -11
  172. package/src/replicate.ts +4 -6
  173. package/src/test/index.ts +2 -2
  174. package/src/togetherai/index.ts +6 -13
  175. package/src/vertexai/embeddings/embeddings-text.ts +52 -0
  176. package/src/vertexai/index.ts +9 -10
  177. package/src/vertexai/models/codey-chat.ts +1 -1
  178. package/src/vertexai/models/codey-text.ts +2 -2
  179. package/src/vertexai/models/gemini.ts +4 -4
  180. package/src/vertexai/models/palm-model-base.ts +17 -1
  181. package/src/vertexai/models/palm2-chat.ts +1 -1
  182. package/src/vertexai/models/palm2-text.ts +2 -2
package/src/mistral/index.ts CHANGED
@@ -1,5 +1,6 @@
-import { AIModel, AbstractDriver, Completion, DriverOptions, ExecutionOptions, PromptFormats, PromptSegment } from "@llumiverse/core";
+import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment } from "@llumiverse/core";
 import { transformSSEStream } from "@llumiverse/core/async";
+import { OpenAITextMessage, formatOpenAILikePrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
 import { FetchClient } from "api-fetch-client";
 import { CompletionRequestParams, ListModelsResponse, ResponseFormat } from "./types.js";
 
@@ -13,10 +14,9 @@ interface MistralAIDriverOptions extends DriverOptions {
     endpoint_url?: string;
 }
 
-export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, LLMMessage[]> {
+export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, OpenAITextMessage[]> {
     provider: string;
     apiKey: string;
-    defaultFormat: PromptFormats;
     //client: MistralClient;
     client: FetchClient;
     endpointUrl?: string;
@@ -24,7 +24,6 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, LLMM
     constructor(options: MistralAIDriverOptions) {
         super(options);
         this.provider = "MistralAI";
-        this.defaultFormat = PromptFormats.genericTextLLM;
         this.apiKey = options.apiKey;
         //this.client = new MistralClient(options.apiKey, options.endpointUrl);
         this.client = new FetchClient(options.endpoint_url || ENDPOINT).withHeaders({
@@ -34,47 +33,41 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, LLMM
 
     getResponseFormat = (_options: ExecutionOptions): ResponseFormat | undefined => {
 
+        // const responseFormatJson: ResponseFormat = {
+        //     type: "json_object",
+        // } as ResponseFormat
 
-        /*const responseFormatJson: ResponseFormat = {
-            type: "json_object",
-        } as ResponseFormat
+        // const responseFormatText: ResponseFormat = {
+        //     type: "text",
+        // } as ResponseFormat;
 
-        const responseFormatText: ResponseFormat = {
-            type: "text",
-        } as ResponseFormat;
-        */
 
-        //return _options.resultSchema ? responseFormatJson : responseFormatText;
+        // return _options.resultSchema ? responseFormatJson : responseFormatText;
 
         //TODO remove this when Mistral properly supports the parameters - it makes an error for now
+        // some models like mixtral mistrall tiny or medium are throwing an error when using the response_format parameter
         return undefined
     }
 
-    createPrompt(segments: PromptSegment[], opts: ExecutionOptions): LLMMessage[] {
-        // use same format as OpenAI as that's what MistralAI uses
-        const prompts = super.createPrompt(segments, { ...opts, format: PromptFormats.openai })
-
+    protected formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): OpenAITextMessage[] {
+        const messages = formatOpenAILikePrompt(segments);
        //Add JSON instruction is schema is provided
        if (opts.resultSchema) {
-            const content = "The user is explicitely instructing that the result should be a JSON object.\nThe schema is as follows: \n" + JSON.stringify(opts.resultSchema);
-            prompts.push({
+            messages.push({
                role: "user",
-                content: content
+                content: "IMPORTANT: " + getJSONSafetyNotice(opts.resultSchema)
            });
        }
-
-        return prompts;
-
+        return messages;
    }
 
-    async requestCompletion(messages: LLMMessage[], options: ExecutionOptions): Promise<Completion<any>> {
-
+    async requestCompletion(messages: OpenAITextMessage[], options: ExecutionOptions): Promise<Completion<any>> {
        const res = await this.client.post('/v1/chat/completions', {
            payload: _makeChatCompletionRequest({
                model: options.model,
                messages: messages,
-                maxTokens: options.max_tokens ?? 1024,
-                temperature: options.temperature ?? 0.7,
+                maxTokens: options.max_tokens,
+                temperature: options.temperature,
                responseFormat: this.getResponseFormat(options),
            })
        })
@@ -91,14 +84,13 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, LLMM
        };
    }
 
-    async requestCompletionStream(messages: LLMMessage[], options: ExecutionOptions): Promise<AsyncIterable<string>> {
-
+    async requestCompletionStream(messages: OpenAITextMessage[], options: ExecutionOptions): Promise<AsyncIterable<string>> {
        const stream = await this.client.post('/v1/chat/completions', {
            payload: _makeChatCompletionRequest({
                model: options.model,
                messages: messages,
-                maxTokens: options.max_tokens ?? 1024,
-                temperature: options.temperature ?? 0.7,
+                maxTokens: options.max_tokens,
+                temperature: options.temperature,
                responseFormat: this.getResponseFormat(options),
                stream: true
            }),
@@ -121,32 +113,33 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, LLMM
                name: m.id,
                description: undefined,
                provider: m.owned_by,
-                formats: [PromptFormats.genericTextLLM],
            }
        });
 
        return aimodels;
    }
 
-    listTrainableModels(): Promise<AIModel<string>[]> {
-        throw new Error("Method not implemented.");
-    }
    validateConnection(): Promise<boolean> {
        throw new Error("Method not implemented.");
    }
-    //@ts-ignore
-    generateEmbeddings(content: string, model?: string | undefined): Promise<{ embeddings: number[]; model: string; }> {
-        throw new Error("Method not implemented.");
-    }
 
-}
+    async generateEmbeddings({ content, model = "mistral-embed" }: EmbeddingsOptions): Promise<EmbeddingsResult> {
+        const r = await this.client.post('/v1/embeddings', {
+            payload: {
+                model,
+                input: [content],
+                encoding_format: "float"
+            },
+        });
+        return {
+            values: r.data[0].embedding,
+            model,
+            token_count: r.usage.total_tokens
+        }
+    }
 
-interface LLMMessage {
-    role: string;
-    content: string;
 }
 
-
 
 /**
  * Creates a chat completion request
package/src/openai.ts CHANGED
@@ -5,16 +5,17 @@ import {
     Completion,
     DataSource,
     DriverOptions,
+    EmbeddingsOptions,
+    EmbeddingsResult,
     ExecutionOptions,
     ExecutionTokenUsage,
     ModelType,
-    PromptFormats,
-    PromptSegment,
     TrainingJob,
     TrainingJobStatus,
     TrainingOptions,
-    TrainingPromptOptions
+    TrainingPromptOptions,
 } from "@llumiverse/core";
+import { formatOpenAILikePrompt } from "@llumiverse/core/formatters";
 import { asyncMap } from "@llumiverse/core/async";
 import OpenAI from "openai";
 import { Stream } from "openai/streaming";
@@ -39,18 +40,13 @@ export class OpenAIDriver extends AbstractDriver<
     generatedContentTypes: string[] = ["text/plain"];
     service: OpenAI;
     provider = BuiltinProviders.openai;
-    defaultFormat = PromptFormats.openai;
 
     constructor(opts: OpenAIDriverOptions) {
         super(opts);
         this.service = new OpenAI({
             apiKey: opts.apiKey,
         });
-    }
-
-    createPrompt(segments: PromptSegment[], opts: ExecutionOptions): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
-        // openai only supports opanai format - force the format
-        return super.createPrompt(segments, { ...opts, format: PromptFormats.openai })
+        this.formatPrompt = formatOpenAILikePrompt;
     }
 
     extractDataFromResponse(
@@ -211,7 +207,7 @@
     }
 
 
-    async generateEmbeddings(content: string, model: string = "text-embedding-ada-002"): Promise<{ embeddings: number[], model: string; }> {
+    async generateEmbeddings({ content, model = "text-embedding-ada-002" }: EmbeddingsOptions): Promise<EmbeddingsResult> {
        const res = await this.service.embeddings.create({
            input: content,
            model: model,
@@ -223,7 +219,7 @@
            throw new Error("No embedding found");
        }
 
-        return { embeddings, model };
+        return { values: embeddings, model } as EmbeddingsResult;
    }
 
 }
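For OpenAI callers this is a pure call-site migration: generateEmbeddings now takes a single EmbeddingsOptions object instead of positional arguments, and the vector moves from an embeddings field to values on EmbeddingsResult. A hedged before/after sketch, assuming OpenAIDriver is re-exported from the package root:

    import { OpenAIDriver } from "@llumiverse/drivers";

    async function embed(text: string) {
        const driver = new OpenAIDriver({ apiKey: process.env.OPENAI_API_KEY! });
        // 0.9.2: const { embeddings } = await driver.generateEmbeddings(text, "text-embedding-ada-002");
        // 0.11.0: single options object; model still defaults to "text-embedding-ada-002"
        const { values, model } = await driver.generateEmbeddings({ content: text });
        return { values, model };
    }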
package/src/replicate.ts CHANGED
@@ -5,9 +5,9 @@ import {
     Completion,
     DataSource,
     DriverOptions,
+    EmbeddingsResult,
     ExecutionOptions,
     ModelSearchPayload,
-    PromptFormats,
     TrainingJob,
     TrainingJobStatus,
     TrainingOptions
@@ -36,7 +36,6 @@ export interface ReplicateDriverOptions extends DriverOptions {
 export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
     provider = BuiltinProviders.replicate;
     service: Replicate;
-    defaultFormat = PromptFormats.genericTextLLM;
 
     static parseModelId(modelId: string) {
         const [owner, modelPart] = modelId.split("/");
@@ -73,7 +72,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
         const predictionData = {
             input: {
                 prompt: prompt,
-                max_new_tokens: options.max_tokens || 1024,
+                max_new_tokens: options.max_tokens,
                 temperature: options.temperature,
             },
             version: model.version,
@@ -113,7 +112,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
         const predictionData = {
             input: {
                 prompt: prompt,
-                max_new_tokens: options.max_tokens || 1024,
+                max_new_tokens: options.max_tokens,
                 temperature: options.temperature,
             },
             version: model.version,
@@ -282,8 +281,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
         return models;
     }
 
-    generateEmbeddings(content: string, model?: string): Promise<{ embeddings: number[], model: string; }> {
-        this.logger?.debug(`[Replicate] Generating embeddings for ${content} on ${model}`);
+    async generateEmbeddings(): Promise<EmbeddingsResult> {
        throw new Error("Method not implemented.");
    }
 
package/src/test/index.ts CHANGED
@@ -1,4 +1,4 @@
-import { AIModel, AIModelStatus, CompletionStream, Driver, ExecutionOptions, ExecutionResponse, ModelType, PromptOptions, PromptSegment, TrainingJob } from "@llumiverse/core";
+import { AIModel, AIModelStatus, CompletionStream, Driver, EmbeddingsResult, ExecutionOptions, ExecutionResponse, ModelType, PromptOptions, PromptSegment, TrainingJob } from "@llumiverse/core";
 import { TestErrorCompletionStream } from "./TestErrorCompletionStream.js";
 import { TestValidationErrorCompletionStream } from "./TestValidationErrorCompletionStream.js";
 import { createValidationErrorCompletion, sleep, throwError } from "./utils.js";
@@ -83,7 +83,7 @@ export class TestDriver implements Driver<PromptSegment[]> {
     validateConnection(): Promise<boolean> {
         throw new Error("Method not implemented.");
     }
-    generateEmbeddings(): Promise<{ embeddings: number[]; model: string; }> {
+    generateEmbeddings(): Promise<EmbeddingsResult> {
        throw new Error("Method not implemented.");
    }
 
package/src/togetherai/index.ts CHANGED
@@ -1,4 +1,4 @@
-import { AIModel, AbstractDriver, Completion, DriverOptions, ExecutionOptions, PromptFormats } from "@llumiverse/core";
+import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsResult, ExecutionOptions } from "@llumiverse/core";
 import { transformSSEStream } from "@llumiverse/core/async";
 import { FetchClient } from "api-fetch-client";
 import { TogetherModelInfo } from "./interfaces.js";
@@ -10,13 +10,11 @@ interface TogetherAIDriverOptions extends DriverOptions {
 export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, string> {
     provider: string;
     apiKey: string;
-    defaultFormat: PromptFormats;
     fetchClient: FetchClient;
 
     constructor(options: TogetherAIDriverOptions) {
         super(options);
         this.provider = "togetherai";
-        this.defaultFormat = PromptFormats.genericTextLLM;
         this.apiKey = options.apiKey;
         this.fetchClient = new FetchClient('https://api.together.xyz').withHeaders({
             authorization: `Bearer ${this.apiKey}`
@@ -37,8 +35,8 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
                 model: options.model,
                 prompt: prompt,
                 response_format: this.getResponseFormat(options),
-                max_tokens: options.max_tokens ?? 1024,
-                temperature: options.temperature ?? 0.7,
+                max_tokens: options.max_tokens,
+                temperature: options.temperature,
                 stop: [
                     "</s>",
                     "[/INST]"
@@ -64,8 +62,8 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
             payload: {
                 model: options.model,
                 prompt: prompt,
-                max_tokens: options.max_tokens ?? 1024,
-                temperature: options.temperature ?? 0.7,
+                max_tokens: options.max_tokens,
+                temperature: options.temperature,
                 response_format: this.getResponseFormat(options),
                 stream: true,
                 stop: [
@@ -93,7 +91,6 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
                 name: m.display_name,
                 description: m.description,
                 provider: this.provider,
-                formats: [PromptFormats.genericTextLLM],
            }
        });
 
@@ -101,14 +98,10 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
 
    }
 
-    listTrainableModels(): Promise<AIModel<string>[]> {
-        throw new Error("Method not implemented.");
-    }
    validateConnection(): Promise<boolean> {
        throw new Error("Method not implemented.");
    }
-    //@ts-ignore
-    generateEmbeddings(content: string, model?: string | undefined): Promise<{ embeddings: number[]; model: string; }> {
+    generateEmbeddings(): Promise<EmbeddingsResult> {
        throw new Error("Method not implemented.");
    }
 
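A pattern worth noting across the MistralAI, Replicate, and TogetherAI hunks above: the `?? 1024` / `?? 0.7` fallbacks are removed, so max_tokens and temperature are only sent when the caller sets them. Code that depended on the old implicit defaults should now pass them explicitly; a sketch, assuming the core Driver interface's execute method and a placeholder model id:

    import { PromptRole, PromptSegment } from "@llumiverse/core";
    import { TogetherAIDriver } from "@llumiverse/drivers"; // assuming a root re-export

    async function run() {
        const driver = new TogetherAIDriver({ apiKey: process.env.TOGETHER_API_KEY! });
        const segments: PromptSegment[] = [
            { role: PromptRole.user, content: "Write a haiku about unified diffs." },
        ];
        return driver.execute(segments, {
            model: "mistralai/Mixtral-8x7B-Instruct-v0.1", // placeholder model id
            max_tokens: 1024,  // 0.9.2 injected these when unset; 0.11.0 forwards only what you pass
            temperature: 0.7,
        });
    }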
package/src/vertexai/embeddings/embeddings-text.ts ADDED
@@ -0,0 +1,52 @@
+
+import { EmbeddingsResult } from '@llumiverse/core';
+import { VertexAIDriver } from '../index.js';
+
+export interface TextEmbeddingsOptions {
+    model?: string;
+    task_type?: "RETRIEVAL_QUERY" | "RETRIEVAL_DOCUMENT" | "SEMANTIC_SIMILARITY" | "CLASSIFICATION" | "CLUSTERING",
+    title?: string, // the title for the embedding
+    content: string // the text to generate embeddings for
+}
+
+interface EmbedingsForTextPrompt {
+    instances: TextEmbeddingsOptions[]
+}
+
+interface TextEmbeddingsResult {
+    predictions: [
+        {
+            embeddings: TextEmbeddings
+        }
+    ]
+}
+
+interface TextEmbeddings {
+    statistics: {
+        truncated: boolean,
+        token_count: number
+    },
+    values: [number]
+}
+
+export async function getEmbeddingsForText(driver: VertexAIDriver, options: TextEmbeddingsOptions): Promise<EmbeddingsResult> {
+    const prompt = {
+        instances: [{
+            task_type: options.task_type,
+            title: options.title,
+            content: options.content
+        }]
+    } as EmbedingsForTextPrompt;
+
+    const model = options.model || "textembedding-gecko@latest";
+
+    const result = await driver.fetchClient.post(`/publishers/google/models/${model}:predict`, {
+        payload: prompt
+    }) as TextEmbeddingsResult;
+
+    return {
+        ...result.predictions[0].embeddings,
+        model,
+        token_count: result.predictions[0].embeddings.statistics?.token_count
+    };
+}
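A usage sketch for the new helper, with placeholder project and region values; as the next diff shows, it is also reachable through the driver's generateEmbeddings method (the root re-export of VertexAIDriver is an assumption):

    import { VertexAIDriver } from "@llumiverse/drivers";

    async function run() {
        const vertex = new VertexAIDriver({ project: "my-gcp-project", region: "us-central1" });
        const result = await vertex.generateEmbeddings({
            content: "How do I rotate service account keys?",
            task_type: "RETRIEVAL_QUERY", // one of the five task types declared above
            // model defaults to "textembedding-gecko@latest"
        });
        console.log(result.values.length, result.token_count);
    }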
package/src/vertexai/index.ts CHANGED
@@ -1,9 +1,9 @@
-//import { v1 } from "@google-cloud/aiplatform";
 import { GenerateContentRequest, VertexAI } from "@google-cloud/vertexai";
-import { AIModel, AbstractDriver, BuiltinProviders, Completion, DriverOptions, ExecutionOptions, ModelSearchPayload, PromptFormats, PromptOptions, PromptSegment } from "@llumiverse/core";
+import { AIModel, AbstractDriver, BuiltinProviders, Completion, DriverOptions, EmbeddingsResult, ExecutionOptions, ModelSearchPayload, PromptOptions, PromptSegment } from "@llumiverse/core";
 import { FetchClient } from "api-fetch-client";
+import { TextEmbeddingsOptions, getEmbeddingsForText } from "./embeddings/embeddings-text.js";
 import { BuiltinModels, getModelDefinition } from "./models.js";
-//import { GoogleAuth } from "google-auth-library";
+
 
 export interface VertexAIDriverOptions extends DriverOptions {
     project: string;
@@ -12,7 +12,6 @@ export interface VertexAIDriverOptions extends DriverOptions {
 
 export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, GenerateContentRequest> {
     provider = BuiltinProviders.vertexai;
-    defaultFormat = PromptFormats.genericTextLLM;
 
     //aiplatform: v1.ModelServiceClient;
     vertexai: VertexAI;
@@ -31,7 +30,8 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Genera
             region: this.options.region,
             project: this.options.project,
         }).withAuthCallback(async () => {
-            const token = await this.vertexai.preview.token;
+            //@ts-ignore
+            const token = await this.vertexai.preview.googleAuth.getAccessToken();
            return `Bearer ${token}`;
        });
        // this.aiplatform = new v1.ModelServiceClient({
@@ -72,14 +72,13 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Genera
 
        return []; //TODO
    }
-    listTrainableModels(): Promise<AIModel<string>[]> {
-        throw new Error("Method not implemented.");
-    }
+
    validateConnection(): Promise<boolean> {
        throw new Error("Method not implemented.");
    }
-    generateEmbeddings(_content: string, _model?: string | undefined): Promise<{ embeddings: number[]; model: string; }> {
-        throw new Error("Method not implemented.");
+
+    async generateEmbeddings(options: TextEmbeddingsOptions): Promise<EmbeddingsResult> {
+        return getEmbeddingsForText(this, options);
    }
 }
 
package/src/vertexai/models/codey-chat.ts CHANGED
@@ -1,6 +1,6 @@
 import { AIModel, ModelType, PromptOptions, PromptRole, PromptSegment } from "@llumiverse/core";
+import { getJSONSafetyNotice } from "@llumiverse/core/formatters";
 import { VertexAIDriver } from "../index.js";
-import { getJSONSafetyNotice } from "../utils/prompts.js";
 import { AbstractPalmModelDefinition, NonStreamingPromptBase, PalmResponseMetadata, StreamingPromptBase } from "./palm-model-base.js";
 
 export interface CodeyChatMessage {
package/src/vertexai/models/codey-text.ts CHANGED
@@ -1,6 +1,6 @@
 import { AIModel, ModelType, PromptOptions, PromptSegment } from "@llumiverse/core";
+import { formatTextPrompt } from "@llumiverse/core/formatters";
 import { VertexAIDriver } from "../index.js";
-import { getPromptAsText } from "../utils/prompts.js";
 import { AbstractPalmModelDefinition, NonStreamingPromptBase, PalmResponseMetadata, StreamingPromptBase } from "./palm-model-base.js";
 
 
@@ -50,7 +50,7 @@ export class CodeyTextDefinition extends AbstractPalmModelDefinition<CodeyTextPr
     createNonStreamingPrompt(_driver: VertexAIDriver, segments: PromptSegment[], opts: PromptOptions): CodeyTextPrompt {
         return {
             instances: [{
-                prefix: getPromptAsText(segments, opts)
+                prefix: formatTextPrompt(segments, opts.resultSchema)
            }],
            parameters: {
                // put defauilts here
package/src/vertexai/models/gemini.ts CHANGED
@@ -1,10 +1,10 @@
-import { Content, GenerateContentRequest, GenerativeModel, HarmBlockThreshold, HarmCategory, TextPart } from "@google-cloud/vertexai";
+import { Content, GenerateContentRequest, HarmBlockThreshold, HarmCategory, TextPart } from "@google-cloud/vertexai";
 import { AIModel, Completion, ExecutionOptions, ExecutionTokenUsage, ModelType, PromptOptions, PromptRole, PromptSegment } from "@llumiverse/core";
 import { asyncMap } from "@llumiverse/core/async";
 import { VertexAIDriver } from "../index.js";
 import { ModelDefinition } from "../models.js";
 
-function getGenerativeModel(driver: VertexAIDriver, options: ExecutionOptions): GenerativeModel {
+function getGenerativeModel(driver: VertexAIDriver, options: ExecutionOptions) {
     return driver.vertexai.preview.getGenerativeModel({
         model: options.model,
         //TODO pass in the options
@@ -106,8 +106,8 @@ export class GeminiModelDefinition implements ModelDefinition<GenerateContentReq
         const response = await r.response;
         const usage = response.usageMetadata;
         const token_usage: ExecutionTokenUsage = {
-            prompt: usage?.prompt_token_count,
-            result: usage?.candidates_token_count,
+            prompt: usage?.promptTokenCount,
+            result: usage?.candidatesTokenCount,
            total: usage?.totalTokenCount,
        }
 
package/src/vertexai/models/palm-model-base.ts CHANGED
@@ -2,9 +2,25 @@ import { AIModel, Completion, ExecutionOptions, PromptOptions, PromptSegment } f
 import { transformSSEStream } from "@llumiverse/core/async";
 import { VertexAIDriver } from "../index.js";
 import { ModelDefinition } from "../models.js";
-import { PromptParamatersBase } from "../utils/prompts.js";
 import { generateStreamingPrompt } from "../utils/tensor.js";
 
+
+export interface PromptParamatersBase {
+    temperature?: number,
+    maxOutputTokens?: number,
+    topK?: number,
+    topP?: number,
+    groundingConfig?: string,
+    stopSequences?: string[],
+    candidateCount?: number,
+    logprobs?: number,
+    presencePenalty?: number,
+    frequencyPenalty?: number,
+    logitBias?: Record<string, number>,
+    seed?: number,
+    echo?: boolean
+}
+
 export interface NonStreamingPromptBase<InstanceType = any> {
     instances: InstanceType[];
     parameters: PromptParamatersBase;
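With PromptParamatersBase (spelling as published) now declared alongside the prompt base types, a PaLM-style request body is shaped like this sketch; the import path is an assumption, relative to the models folder:

    import { NonStreamingPromptBase } from "./palm-model-base.js";

    // illustrative payload only; field names come from the interfaces above
    const body: NonStreamingPromptBase<{ prompt: string }> = {
        instances: [{ prompt: "Say hello in French." }],
        parameters: {
            temperature: 0.2,
            maxOutputTokens: 256,
            stopSequences: ["\n\n"],
        },
    };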
package/src/vertexai/models/palm2-chat.ts CHANGED
@@ -1,6 +1,6 @@
 import { AIModel, ModelType, PromptOptions, PromptRole, PromptSegment } from "@llumiverse/core";
+import { getJSONSafetyNotice } from "@llumiverse/core/formatters";
 import { VertexAIDriver } from "../index.js";
-import { getJSONSafetyNotice } from "../utils/prompts.js";
 import { AbstractPalmModelDefinition, NonStreamingPromptBase, PalmResponseMetadata, StreamingPromptBase } from "./palm-model-base.js";
 
 export interface Palm2ChatMessage {
package/src/vertexai/models/palm2-text.ts CHANGED
@@ -1,6 +1,6 @@
 import { AIModel, ModelType, PromptOptions, PromptSegment } from "@llumiverse/core";
+import { formatTextPrompt } from "@llumiverse/core/formatters";
 import { VertexAIDriver } from "../index.js";
-import { getPromptAsText } from "../utils/prompts.js";
 import { AbstractPalmModelDefinition, NonStreamingPromptBase, PalmResponseMetadata, StreamingPromptBase } from "./palm-model-base.js";
 
 export type Palm2TextPrompt = NonStreamingPromptBase<{
@@ -50,7 +50,7 @@ export class Palm2TextDefinition extends AbstractPalmModelDefinition<Palm2TextPr
     createNonStreamingPrompt(_driver: VertexAIDriver, segments: PromptSegment[], opts: PromptOptions): Palm2TextPrompt {
         return {
             instances: [{
-                prompt: getPromptAsText(segments, opts)
+                prompt: formatTextPrompt(segments, opts.resultSchema)
            }],
            parameters: {
                // put defauilts here