@llumiverse/drivers 0.9.2 → 0.11.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +51 -7
- package/lib/cjs/bedrock/index.js +32 -8
- package/lib/cjs/bedrock/index.js.map +1 -1
- package/lib/cjs/bedrock/s3.js.map +1 -1
- package/lib/cjs/huggingface_ie.js +1 -7
- package/lib/cjs/huggingface_ie.js.map +1 -1
- package/lib/cjs/mistral/index.js +31 -29
- package/lib/cjs/mistral/index.js.map +1 -1
- package/lib/cjs/openai.js +4 -7
- package/lib/cjs/openai.js.map +1 -1
- package/lib/cjs/replicate.js +3 -5
- package/lib/cjs/replicate.js.map +1 -1
- package/lib/cjs/test/index.js.map +1 -1
- package/lib/cjs/togetherai/index.js +5 -12
- package/lib/cjs/togetherai/index.js.map +1 -1
- package/lib/cjs/vertexai/embeddings/embeddings-text.js +23 -0
- package/lib/cjs/vertexai/embeddings/embeddings-text.js.map +1 -0
- package/lib/cjs/vertexai/index.js +5 -8
- package/lib/cjs/vertexai/index.js.map +1 -1
- package/lib/cjs/vertexai/models/codey-chat.js +2 -2
- package/lib/cjs/vertexai/models/codey-chat.js.map +1 -1
- package/lib/cjs/vertexai/models/codey-text.js +2 -2
- package/lib/cjs/vertexai/models/codey-text.js.map +1 -1
- package/lib/cjs/vertexai/models/gemini.js +2 -2
- package/lib/cjs/vertexai/models/gemini.js.map +1 -1
- package/lib/cjs/vertexai/models/palm-model-base.js.map +1 -1
- package/lib/cjs/vertexai/models/palm2-chat.js +2 -2
- package/lib/cjs/vertexai/models/palm2-chat.js.map +1 -1
- package/lib/cjs/vertexai/models/palm2-text.js +2 -2
- package/lib/cjs/vertexai/models/palm2-text.js.map +1 -1
- package/lib/esm/bedrock/index.js +33 -9
- package/lib/esm/bedrock/index.js.map +1 -1
- package/lib/esm/bedrock/s3.js.map +1 -1
- package/lib/esm/huggingface_ie.js +2 -8
- package/lib/esm/huggingface_ie.js.map +1 -1
- package/lib/esm/mistral/index.js +32 -30
- package/lib/esm/mistral/index.js.map +1 -1
- package/lib/esm/openai.js +5 -8
- package/lib/esm/openai.js.map +1 -1
- package/lib/esm/replicate.js +4 -6
- package/lib/esm/replicate.js.map +1 -1
- package/lib/esm/src/bedrock/index.js +375 -0
- package/lib/esm/src/bedrock/index.js.map +1 -0
- package/lib/esm/src/bedrock/s3.js +53 -0
- package/lib/esm/src/bedrock/s3.js.map +1 -0
- package/lib/esm/src/huggingface_ie.js +173 -0
- package/lib/esm/src/huggingface_ie.js.map +1 -0
- package/lib/esm/src/index.js +9 -0
- package/lib/esm/src/index.js.map +1 -0
- package/lib/esm/src/mistral/index.js +145 -0
- package/lib/esm/src/mistral/index.js.map +1 -0
- package/lib/esm/src/mistral/types.js +80 -0
- package/lib/esm/src/mistral/types.js.map +1 -0
- package/lib/esm/src/openai.js +195 -0
- package/lib/esm/src/openai.js.map +1 -0
- package/lib/esm/src/replicate.js +281 -0
- package/lib/esm/src/replicate.js.map +1 -0
- package/lib/esm/src/test/TestErrorCompletionStream.js +16 -0
- package/lib/esm/src/test/TestErrorCompletionStream.js.map +1 -0
- package/lib/esm/src/test/TestValidationErrorCompletionStream.js +20 -0
- package/lib/esm/src/test/TestValidationErrorCompletionStream.js.map +1 -0
- package/lib/esm/src/test/index.js +91 -0
- package/lib/esm/src/test/index.js.map +1 -0
- package/lib/esm/src/test/utils.js +25 -0
- package/lib/esm/src/test/utils.js.map +1 -0
- package/lib/esm/src/togetherai/index.js +89 -0
- package/lib/esm/src/togetherai/index.js.map +1 -0
- package/lib/esm/src/togetherai/interfaces.js +2 -0
- package/lib/esm/src/togetherai/interfaces.js.map +1 -0
- package/lib/esm/src/vertexai/debug.js +6 -0
- package/lib/esm/src/vertexai/debug.js.map +1 -0
- package/lib/esm/src/vertexai/embeddings/embeddings-text.js +19 -0
- package/lib/esm/src/vertexai/embeddings/embeddings-text.js.map +1 -0
- package/lib/esm/src/vertexai/index.js +73 -0
- package/lib/esm/src/vertexai/index.js.map +1 -0
- package/lib/esm/src/vertexai/models/codey-chat.js +61 -0
- package/lib/esm/src/vertexai/models/codey-chat.js.map +1 -0
- package/lib/esm/src/vertexai/models/codey-text.js +31 -0
- package/lib/esm/src/vertexai/models/codey-text.js.map +1 -0
- package/lib/esm/src/vertexai/models/gemini.js +136 -0
- package/lib/esm/src/vertexai/models/gemini.js.map +1 -0
- package/lib/esm/src/vertexai/models/palm-model-base.js +53 -0
- package/lib/esm/src/vertexai/models/palm-model-base.js.map +1 -0
- package/lib/esm/src/vertexai/models/palm2-chat.js +61 -0
- package/lib/esm/src/vertexai/models/palm2-chat.js.map +1 -0
- package/lib/esm/src/vertexai/models/palm2-text.js +31 -0
- package/lib/esm/src/vertexai/models/palm2-text.js.map +1 -0
- package/lib/esm/src/vertexai/models.js +87 -0
- package/lib/esm/src/vertexai/models.js.map +1 -0
- package/{src/vertexai/utils/prompts.ts → lib/esm/src/vertexai/utils/prompts.js} +10 -29
- package/lib/esm/src/vertexai/utils/prompts.js.map +1 -0
- package/lib/esm/src/vertexai/utils/tensor.js +82 -0
- package/lib/esm/src/vertexai/utils/tensor.js.map +1 -0
- package/lib/esm/test/index.js.map +1 -1
- package/lib/esm/togetherai/index.js +6 -13
- package/lib/esm/togetherai/index.js.map +1 -1
- package/lib/esm/tsconfig.tsbuildinfo +1 -0
- package/lib/esm/vertexai/embeddings/embeddings-text.js +19 -0
- package/lib/esm/vertexai/embeddings/embeddings-text.js.map +1 -0
- package/lib/esm/vertexai/index.js +6 -9
- package/lib/esm/vertexai/index.js.map +1 -1
- package/lib/esm/vertexai/models/codey-chat.js +1 -1
- package/lib/esm/vertexai/models/codey-chat.js.map +1 -1
- package/lib/esm/vertexai/models/codey-text.js +2 -2
- package/lib/esm/vertexai/models/codey-text.js.map +1 -1
- package/lib/esm/vertexai/models/gemini.js +2 -2
- package/lib/esm/vertexai/models/gemini.js.map +1 -1
- package/lib/esm/vertexai/models/palm-model-base.js.map +1 -1
- package/lib/esm/vertexai/models/palm2-chat.js +1 -1
- package/lib/esm/vertexai/models/palm2-chat.js.map +1 -1
- package/lib/esm/vertexai/models/palm2-text.js +2 -2
- package/lib/esm/vertexai/models/palm2-text.js.map +1 -1
- package/lib/types/bedrock/index.d.ts +12 -9
- package/lib/types/bedrock/index.d.ts.map +1 -1
- package/lib/types/bedrock/s3.d.ts +2 -5
- package/lib/types/bedrock/s3.d.ts.map +1 -1
- package/lib/types/huggingface_ie.d.ts +5 -10
- package/lib/types/huggingface_ie.d.ts.map +1 -1
- package/lib/types/mistral/index.d.ts +7 -15
- package/lib/types/mistral/index.d.ts.map +1 -1
- package/lib/types/openai.d.ts +2 -7
- package/lib/types/openai.d.ts.map +1 -1
- package/lib/types/replicate.d.ts +2 -6
- package/lib/types/replicate.d.ts.map +1 -1
- package/lib/types/src/bedrock/index.d.ts +94 -0
- package/lib/types/src/bedrock/s3.d.ts +16 -0
- package/lib/types/src/huggingface_ie.d.ts +30 -0
- package/lib/types/src/index.d.ts +8 -0
- package/lib/types/src/mistral/index.d.ts +23 -0
- package/lib/types/src/mistral/types.d.ts +130 -0
- package/lib/types/src/openai.d.ts +30 -0
- package/lib/types/src/replicate.d.ts +47 -0
- package/lib/types/src/test/TestErrorCompletionStream.d.ts +8 -0
- package/lib/types/src/test/TestValidationErrorCompletionStream.d.ts +8 -0
- package/lib/types/src/test/index.d.ts +23 -0
- package/lib/types/src/test/utils.d.ts +4 -0
- package/lib/types/src/togetherai/index.d.ts +21 -0
- package/lib/types/src/togetherai/interfaces.d.ts +80 -0
- package/lib/types/src/vertexai/debug.d.ts +1 -0
- package/lib/types/src/vertexai/embeddings/embeddings-text.d.ts +9 -0
- package/lib/types/src/vertexai/index.d.ts +21 -0
- package/lib/types/src/vertexai/models/codey-chat.d.ts +50 -0
- package/lib/types/src/vertexai/models/codey-text.d.ts +38 -0
- package/lib/types/src/vertexai/models/gemini.d.ts +10 -0
- package/lib/types/src/vertexai/models/palm-model-base.d.ts +60 -0
- package/lib/types/src/vertexai/models/palm2-chat.d.ts +60 -0
- package/lib/types/src/vertexai/models/palm2-text.d.ts +38 -0
- package/lib/types/src/vertexai/models.d.ts +13 -0
- package/lib/types/src/vertexai/utils/prompts.d.ts +19 -0
- package/lib/types/src/vertexai/utils/tensor.d.ts +5 -0
- package/lib/types/test/index.d.ts +2 -5
- package/lib/types/test/index.d.ts.map +1 -1
- package/lib/types/togetherai/index.d.ts +2 -7
- package/lib/types/togetherai/index.d.ts.map +1 -1
- package/lib/types/vertexai/embeddings/embeddings-text.d.ts +10 -0
- package/lib/types/vertexai/embeddings/embeddings-text.d.ts.map +1 -0
- package/lib/types/vertexai/index.d.ts +3 -7
- package/lib/types/vertexai/index.d.ts.map +1 -1
- package/lib/types/vertexai/models/codey-chat.d.ts.map +1 -1
- package/lib/types/vertexai/models/codey-text.d.ts.map +1 -1
- package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
- package/lib/types/vertexai/models/palm-model-base.d.ts +15 -1
- package/lib/types/vertexai/models/palm-model-base.d.ts.map +1 -1
- package/lib/types/vertexai/models/palm2-chat.d.ts.map +1 -1
- package/lib/types/vertexai/models/palm2-text.d.ts.map +1 -1
- package/package.json +15 -16
- package/src/bedrock/index.ts +37 -12
- package/src/bedrock/s3.ts +2 -3
- package/src/huggingface_ie.ts +3 -10
- package/src/mistral/index.ts +36 -43
- package/src/openai.ts +7 -11
- package/src/replicate.ts +4 -6
- package/src/test/index.ts +2 -2
- package/src/togetherai/index.ts +6 -13
- package/src/vertexai/embeddings/embeddings-text.ts +52 -0
- package/src/vertexai/index.ts +9 -10
- package/src/vertexai/models/codey-chat.ts +1 -1
- package/src/vertexai/models/codey-text.ts +2 -2
- package/src/vertexai/models/gemini.ts +4 -4
- package/src/vertexai/models/palm-model-base.ts +17 -1
- package/src/vertexai/models/palm2-chat.ts +1 -1
- package/src/vertexai/models/palm2-text.ts +2 -2
package/src/mistral/index.ts
CHANGED
@@ -1,5 +1,6 @@
-import { AIModel, AbstractDriver, Completion, DriverOptions,
+import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment } from "@llumiverse/core";
 import { transformSSEStream } from "@llumiverse/core/async";
+import { OpenAITextMessage, formatOpenAILikePrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
 import { FetchClient } from "api-fetch-client";
 import { CompletionRequestParams, ListModelsResponse, ResponseFormat } from "./types.js";
 
@@ -13,10 +14,9 @@ interface MistralAIDriverOptions extends DriverOptions {
     endpoint_url?: string;
 }
 
-export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions,
+export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, OpenAITextMessage[]> {
     provider: string;
     apiKey: string;
-    defaultFormat: PromptFormats;
     //client: MistralClient;
     client: FetchClient;
     endpointUrl?: string;
@@ -24,7 +24,6 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, LLMM
     constructor(options: MistralAIDriverOptions) {
         super(options);
         this.provider = "MistralAI";
-        this.defaultFormat = PromptFormats.genericTextLLM;
         this.apiKey = options.apiKey;
         //this.client = new MistralClient(options.apiKey, options.endpointUrl);
         this.client = new FetchClient(options.endpoint_url || ENDPOINT).withHeaders({
@@ -34,47 +33,41 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, LLMM
 
     getResponseFormat = (_options: ExecutionOptions): ResponseFormat | undefined => {
 
+        // const responseFormatJson: ResponseFormat = {
+        //     type: "json_object",
+        // } as ResponseFormat
 
-        const responseFormatJson: ResponseFormat = {
-            type: "json_object",
-        } as ResponseFormat
+        // const responseFormatText: ResponseFormat = {
+        //     type: "text",
+        // } as ResponseFormat;
 
-        const responseFormatText: ResponseFormat = {
-            type: "text",
-        } as ResponseFormat;
-        */
 
-        //return _options.resultSchema ? responseFormatJson : responseFormatText;
+        // return _options.resultSchema ? responseFormatJson : responseFormatText;
 
         //TODO remove this when Mistral properly supports the parameters - it makes an error for now
+        // some models like mixtral mistrall tiny or medium are throwing an error when using the response_format parameter
         return undefined
     }
 
-
-
-        const prompts = super.createPrompt(segments, { ...opts, format: PromptFormats.openai })
-
+    protected formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): OpenAITextMessage[] {
+        const messages = formatOpenAILikePrompt(segments);
         //Add JSON instruction is schema is provided
         if (opts.resultSchema) {
-
-            prompts.push({
+            messages.push({
                 role: "user",
-                content:
+                content: "IMPORTANT: " + getJSONSafetyNotice(opts.resultSchema)
             });
         }
-
-        return prompts;
-
+        return messages;
     }
 
-    async requestCompletion(messages:
-
+    async requestCompletion(messages: OpenAITextMessage[], options: ExecutionOptions): Promise<Completion<any>> {
         const res = await this.client.post('/v1/chat/completions', {
             payload: _makeChatCompletionRequest({
                 model: options.model,
                 messages: messages,
-                maxTokens: options.max_tokens
-                temperature: options.temperature
+                maxTokens: options.max_tokens,
+                temperature: options.temperature,
                 responseFormat: this.getResponseFormat(options),
             })
         })
@@ -91,14 +84,13 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, LLMM
         };
     }
 
-    async requestCompletionStream(messages:
-
+    async requestCompletionStream(messages: OpenAITextMessage[], options: ExecutionOptions): Promise<AsyncIterable<string>> {
         const stream = await this.client.post('/v1/chat/completions', {
             payload: _makeChatCompletionRequest({
                 model: options.model,
                 messages: messages,
-                maxTokens: options.max_tokens
-                temperature: options.temperature
+                maxTokens: options.max_tokens,
+                temperature: options.temperature,
                 responseFormat: this.getResponseFormat(options),
                 stream: true
             }),
@@ -121,32 +113,33 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, LLMM
                 name: m.id,
                 description: undefined,
                 provider: m.owned_by,
-                formats: [PromptFormats.genericTextLLM],
             }
         });
 
         return aimodels;
     }
 
-    listTrainableModels(): Promise<AIModel<string>[]> {
-        throw new Error("Method not implemented.");
-    }
     validateConnection(): Promise<boolean> {
         throw new Error("Method not implemented.");
     }
-    //@ts-ignore
-    generateEmbeddings(content: string, model?: string | undefined): Promise<{ embeddings: number[]; model: string; }> {
-        throw new Error("Method not implemented.");
-    }
 
-}
+    async generateEmbeddings({ content, model = "mistral-embed" }: EmbeddingsOptions): Promise<EmbeddingsResult> {
+        const r = await this.client.post('/v1/embeddings', {
+            payload: {
+                model,
+                input: [content],
+                encoding_format: "float"
+            },
+        });
+        return {
+            values: r.data[0].embedding,
+            model,
+            token_count: r.usage.total_tokens
+        }
+    }
 
-interface LLMMessage {
-    role: string;
-    content: string;
 }
 
-
 /**
  * Creates a chat completion request
  * @param {*} model
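Net effect for this file: the Mistral driver drops the PromptFormats plumbing and gains a working generateEmbeddings backed by /v1/embeddings. A minimal usage sketch based only on the signatures visible above; the package-root export and the API-key handling are assumptions:

    import { MistralAIDriver } from "@llumiverse/drivers";

    // endpoint_url is optional per MistralAIDriverOptions; the env var name is an assumption
    const driver = new MistralAIDriver({ apiKey: process.env.MISTRAL_API_KEY! });

    // model defaults to "mistral-embed", as destructured in generateEmbeddings above
    const res = await driver.generateEmbeddings({ content: "Paris is the capital of France." });
    console.log(res.values.length, res.model, res.token_count);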
package/src/openai.ts
CHANGED
@@ -5,16 +5,17 @@ import {
     Completion,
     DataSource,
     DriverOptions,
+    EmbeddingsOptions,
+    EmbeddingsResult,
     ExecutionOptions,
     ExecutionTokenUsage,
     ModelType,
-    PromptFormats,
-    PromptSegment,
     TrainingJob,
     TrainingJobStatus,
     TrainingOptions,
-    TrainingPromptOptions
+    TrainingPromptOptions,
 } from "@llumiverse/core";
+import { formatOpenAILikePrompt } from "@llumiverse/core/formatters";
 import { asyncMap } from "@llumiverse/core/async";
 import OpenAI from "openai";
 import { Stream } from "openai/streaming";
@@ -39,18 +40,13 @@ export class OpenAIDriver extends AbstractDriver<
     generatedContentTypes: string[] = ["text/plain"];
     service: OpenAI;
     provider = BuiltinProviders.openai;
-    defaultFormat = PromptFormats.openai;
 
     constructor(opts: OpenAIDriverOptions) {
         super(opts);
         this.service = new OpenAI({
             apiKey: opts.apiKey,
         });
-
-
-    createPrompt(segments: PromptSegment[], opts: ExecutionOptions): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
-        // openai only supports opanai format - force the format
-        return super.createPrompt(segments, { ...opts, format: PromptFormats.openai })
+        this.formatPrompt = formatOpenAILikePrompt;
     }
 
     extractDataFromResponse(
@@ -211,7 +207,7 @@ export class OpenAIDriver extends AbstractDriver<
     }
 
 
-    async generateEmbeddings(content
+    async generateEmbeddings({ content, model = "text-embedding-ada-002" }: EmbeddingsOptions): Promise<EmbeddingsResult> {
         const res = await this.service.embeddings.create({
             input: content,
             model: model,
@@ -223,7 +219,7 @@ export class OpenAIDriver extends AbstractDriver<
             throw new Error("No embedding found");
         }
 
-        return { embeddings, model };
+        return { values: embeddings, model } as EmbeddingsResult;
     }
 
 }
package/src/replicate.ts
CHANGED
@@ -5,9 +5,9 @@ import {
     Completion,
     DataSource,
     DriverOptions,
+    EmbeddingsResult,
     ExecutionOptions,
     ModelSearchPayload,
-    PromptFormats,
     TrainingJob,
     TrainingJobStatus,
     TrainingOptions
@@ -36,7 +36,6 @@ export interface ReplicateDriverOptions extends DriverOptions {
 export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
     provider = BuiltinProviders.replicate;
     service: Replicate;
-    defaultFormat = PromptFormats.genericTextLLM;
 
     static parseModelId(modelId: string) {
         const [owner, modelPart] = modelId.split("/");
@@ -73,7 +72,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
         const predictionData = {
             input: {
                 prompt: prompt,
-                max_new_tokens: options.max_tokens
+                max_new_tokens: options.max_tokens,
                 temperature: options.temperature,
             },
             version: model.version,
@@ -113,7 +112,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
         const predictionData = {
             input: {
                 prompt: prompt,
-                max_new_tokens: options.max_tokens
+                max_new_tokens: options.max_tokens,
                 temperature: options.temperature,
             },
             version: model.version,
@@ -282,8 +281,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
         return models;
     }
 
-    generateEmbeddings(
-        this.logger?.debug(`[Replicate] Generating embeddings for ${content} on ${model}`);
+    async generateEmbeddings(): Promise<EmbeddingsResult> {
         throw new Error("Method not implemented.");
     }
 
package/src/test/index.ts
CHANGED
@@ -1,4 +1,4 @@
-import { AIModel, AIModelStatus, CompletionStream, Driver, ExecutionOptions, ExecutionResponse, ModelType, PromptOptions, PromptSegment, TrainingJob } from "@llumiverse/core";
+import { AIModel, AIModelStatus, CompletionStream, Driver, EmbeddingsResult, ExecutionOptions, ExecutionResponse, ModelType, PromptOptions, PromptSegment, TrainingJob } from "@llumiverse/core";
 import { TestErrorCompletionStream } from "./TestErrorCompletionStream.js";
 import { TestValidationErrorCompletionStream } from "./TestValidationErrorCompletionStream.js";
 import { createValidationErrorCompletion, sleep, throwError } from "./utils.js";
@@ -83,7 +83,7 @@ export class TestDriver implements Driver<PromptSegment[]> {
     validateConnection(): Promise<boolean> {
         throw new Error("Method not implemented.");
     }
-    generateEmbeddings(): Promise<
+    generateEmbeddings(): Promise<EmbeddingsResult> {
         throw new Error("Method not implemented.");
     }
 
package/src/togetherai/index.ts
CHANGED
@@ -1,4 +1,4 @@
-import { AIModel, AbstractDriver, Completion, DriverOptions,
+import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsResult, ExecutionOptions } from "@llumiverse/core";
 import { transformSSEStream } from "@llumiverse/core/async";
 import { FetchClient } from "api-fetch-client";
 import { TogetherModelInfo } from "./interfaces.js";
@@ -10,13 +10,11 @@ interface TogetherAIDriverOptions extends DriverOptions {
 export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, string> {
     provider: string;
     apiKey: string;
-    defaultFormat: PromptFormats;
     fetchClient: FetchClient;
 
     constructor(options: TogetherAIDriverOptions) {
         super(options);
         this.provider = "togetherai";
-        this.defaultFormat = PromptFormats.genericTextLLM;
         this.apiKey = options.apiKey;
         this.fetchClient = new FetchClient('https://api.together.xyz').withHeaders({
             authorization: `Bearer ${this.apiKey}`
@@ -37,8 +35,8 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
                 model: options.model,
                 prompt: prompt,
                 response_format: this.getResponseFormat(options),
-                max_tokens: options.max_tokens
-                temperature: options.temperature
+                max_tokens: options.max_tokens,
+                temperature: options.temperature,
                 stop: [
                     "</s>",
                     "[/INST]"
@@ -64,8 +62,8 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
             payload: {
                 model: options.model,
                 prompt: prompt,
-                max_tokens: options.max_tokens
-                temperature: options.temperature
+                max_tokens: options.max_tokens,
+                temperature: options.temperature,
                 response_format: this.getResponseFormat(options),
                 stream: true,
                 stop: [
@@ -93,7 +91,6 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
                 name: m.display_name,
                 description: m.description,
                 provider: this.provider,
-                formats: [PromptFormats.genericTextLLM],
             }
         });
 
@@ -101,14 +98,10 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
 
     }
 
-    listTrainableModels(): Promise<AIModel<string>[]> {
-        throw new Error("Method not implemented.");
-    }
     validateConnection(): Promise<boolean> {
         throw new Error("Method not implemented.");
     }
-
-    generateEmbeddings(content: string, model?: string | undefined): Promise<{ embeddings: number[]; model: string; }> {
+    generateEmbeddings(): Promise<EmbeddingsResult> {
         throw new Error("Method not implemented.");
     }
 
package/src/vertexai/embeddings/embeddings-text.ts
ADDED
@@ -0,0 +1,52 @@
+
+import { EmbeddingsResult } from '@llumiverse/core';
+import { VertexAIDriver } from '../index.js';
+
+export interface TextEmbeddingsOptions {
+    model?: string;
+    task_type?: "RETRIEVAL_QUERY" | "RETRIEVAL_DOCUMENT" | "SEMANTIC_SIMILARITY" | "CLASSIFICATION" | "CLUSTERING",
+    title?: string, // the title for the embedding
+    content: string // the text to generate embeddings for
+}
+
+interface EmbedingsForTextPrompt {
+    instances: TextEmbeddingsOptions[]
+}
+
+interface TextEmbeddingsResult {
+    predictions: [
+        {
+            embeddings: TextEmbeddings
+        }
+    ]
+}
+
+interface TextEmbeddings {
+    statistics: {
+        truncated: boolean,
+        token_count: number
+    },
+    values: [number]
+}
+
+export async function getEmbeddingsForText(driver: VertexAIDriver, options: TextEmbeddingsOptions): Promise<EmbeddingsResult> {
+    const prompt = {
+        instances: [{
+            task_type: options.task_type,
+            title: options.title,
+            content: options.content
+        }]
+    } as EmbedingsForTextPrompt;
+
+    const model = options.model || "textembedding-gecko@latest";
+
+    const result = await driver.fetchClient.post(`/publishers/google/models/${model}:predict`, {
+        payload: prompt
+    }) as TextEmbeddingsResult;
+
+    return {
+        ...result.predictions[0].embeddings,
+        model,
+        token_count: result.predictions[0].embeddings.statistics?.token_count
+    };
+}
package/src/vertexai/index.ts
CHANGED
@@ -1,9 +1,9 @@
-//import { v1 } from "@google-cloud/aiplatform";
 import { GenerateContentRequest, VertexAI } from "@google-cloud/vertexai";
-import { AIModel, AbstractDriver, BuiltinProviders, Completion, DriverOptions, ExecutionOptions, ModelSearchPayload,
+import { AIModel, AbstractDriver, BuiltinProviders, Completion, DriverOptions, EmbeddingsResult, ExecutionOptions, ModelSearchPayload, PromptOptions, PromptSegment } from "@llumiverse/core";
 import { FetchClient } from "api-fetch-client";
+import { TextEmbeddingsOptions, getEmbeddingsForText } from "./embeddings/embeddings-text.js";
 import { BuiltinModels, getModelDefinition } from "./models.js";
-
+
 
 export interface VertexAIDriverOptions extends DriverOptions {
     project: string;
@@ -12,7 +12,6 @@ export interface VertexAIDriverOptions extends DriverOptions {
 
 export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, GenerateContentRequest> {
     provider = BuiltinProviders.vertexai;
-    defaultFormat = PromptFormats.genericTextLLM;
 
     //aiplatform: v1.ModelServiceClient;
     vertexai: VertexAI;
@@ -31,7 +30,8 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Genera
             region: this.options.region,
             project: this.options.project,
         }).withAuthCallback(async () => {
-
+            //@ts-ignore
+            const token = await this.vertexai.preview.googleAuth.getAccessToken();
             return `Bearer ${token}`;
         });
         // this.aiplatform = new v1.ModelServiceClient({
@@ -72,14 +72,13 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Genera
 
         return []; //TODO
     }
-
-        throw new Error("Method not implemented.");
-    }
+
     validateConnection(): Promise<boolean> {
         throw new Error("Method not implemented.");
     }
-
-
+
+    async generateEmbeddings(options: TextEmbeddingsOptions): Promise<EmbeddingsResult> {
+        return getEmbeddingsForText(this, options);
     }
 
 }
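With the helper from embeddings-text.ts wired in, Vertex embeddings become a single call on the driver. A sketch assuming the class is exported from the package root and ambient Google credentials are available:

    import { VertexAIDriver } from "@llumiverse/drivers";

    // project and region are required by VertexAIDriverOptions; the values are placeholders
    const driver = new VertexAIDriver({ project: "my-gcp-project", region: "us-central1" });

    // task_type and title are optional; model defaults to "textembedding-gecko@latest"
    const res = await driver.generateEmbeddings({
        content: "Paris is the capital of France.",
        task_type: "SEMANTIC_SIMILARITY",
    });
    console.log(res.values.length, res.token_count);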
package/src/vertexai/models/codey-chat.ts
CHANGED
@@ -1,6 +1,6 @@
 import { AIModel, ModelType, PromptOptions, PromptRole, PromptSegment } from "@llumiverse/core";
+import { getJSONSafetyNotice } from "@llumiverse/core/formatters";
 import { VertexAIDriver } from "../index.js";
-import { getJSONSafetyNotice } from "../utils/prompts.js";
 import { AbstractPalmModelDefinition, NonStreamingPromptBase, PalmResponseMetadata, StreamingPromptBase } from "./palm-model-base.js";
 
 export interface CodeyChatMessage {
package/src/vertexai/models/codey-text.ts
CHANGED
@@ -1,6 +1,6 @@
 import { AIModel, ModelType, PromptOptions, PromptSegment } from "@llumiverse/core";
+import { formatTextPrompt } from "@llumiverse/core/formatters";
 import { VertexAIDriver } from "../index.js";
-import { getPromptAsText } from "../utils/prompts.js";
 import { AbstractPalmModelDefinition, NonStreamingPromptBase, PalmResponseMetadata, StreamingPromptBase } from "./palm-model-base.js";
 
 
@@ -50,7 +50,7 @@ export class CodeyTextDefinition extends AbstractPalmModelDefinition<CodeyTextPr
     createNonStreamingPrompt(_driver: VertexAIDriver, segments: PromptSegment[], opts: PromptOptions): CodeyTextPrompt {
         return {
             instances: [{
-                prefix:
+                prefix: formatTextPrompt(segments, opts.resultSchema)
             }],
             parameters: {
                 // put defauilts here
package/src/vertexai/models/gemini.ts
CHANGED
@@ -1,10 +1,10 @@
-import { Content, GenerateContentRequest,
+import { Content, GenerateContentRequest, HarmBlockThreshold, HarmCategory, TextPart } from "@google-cloud/vertexai";
 import { AIModel, Completion, ExecutionOptions, ExecutionTokenUsage, ModelType, PromptOptions, PromptRole, PromptSegment } from "@llumiverse/core";
 import { asyncMap } from "@llumiverse/core/async";
 import { VertexAIDriver } from "../index.js";
 import { ModelDefinition } from "../models.js";
 
-function getGenerativeModel(driver: VertexAIDriver, options: ExecutionOptions)
+function getGenerativeModel(driver: VertexAIDriver, options: ExecutionOptions) {
     return driver.vertexai.preview.getGenerativeModel({
         model: options.model,
         //TODO pass in the options
@@ -106,8 +106,8 @@ export class GeminiModelDefinition implements ModelDefinition<GenerateContentReq
         const response = await r.response;
         const usage = response.usageMetadata;
         const token_usage: ExecutionTokenUsage = {
-            prompt: usage?.
-            result: usage?.
+            prompt: usage?.promptTokenCount,
+            result: usage?.candidatesTokenCount,
             total: usage?.totalTokenCount,
         }
 
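The second hunk fills in the two token counts that were previously left blank; promptTokenCount and candidatesTokenCount are the field names Vertex AI reports on usageMetadata. A sketch of the resulting mapping, with an illustrative usage object standing in for response.usageMetadata:

    import type { ExecutionTokenUsage } from "@llumiverse/core";

    // stands in for response.usageMetadata from the hunk above
    const usage = { promptTokenCount: 12, candidatesTokenCount: 34, totalTokenCount: 46 };

    const token_usage: ExecutionTokenUsage = {
        prompt: usage?.promptTokenCount,
        result: usage?.candidatesTokenCount,
        total: usage?.totalTokenCount,
    };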
package/src/vertexai/models/palm-model-base.ts
CHANGED
@@ -2,9 +2,25 @@ import { AIModel, Completion, ExecutionOptions, PromptOptions, PromptSegment } f
 import { transformSSEStream } from "@llumiverse/core/async";
 import { VertexAIDriver } from "../index.js";
 import { ModelDefinition } from "../models.js";
-import { PromptParamatersBase } from "../utils/prompts.js";
 import { generateStreamingPrompt } from "../utils/tensor.js";
 
+
+export interface PromptParamatersBase {
+    temperature?: number,
+    maxOutputTokens?: number,
+    topK?: number,
+    topP?: number,
+    groundingConfig?: string,
+    stopSequences?: string[],
+    candidateCount?: number,
+    logprobs?: number,
+    presencePenalty?: number,
+    frequencyPenalty?: number,
+    logitBias?: Record<string, number>,
+    seed?: number,
+    echo?: boolean
+}
+
 export interface NonStreamingPromptBase<InstanceType = any> {
     instances: InstanceType[];
     parameters: PromptParamatersBase;
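PromptParamatersBase (spelling as in the source) now lives next to the prompt shapes it parameterizes. A request-body sketch using only fields declared above; the values are illustrative, not library defaults:

    // non-streaming PaLM/Codey request body per NonStreamingPromptBase
    const palmRequest = {
        instances: [{ prompt: "Say hello" }],
        parameters: {
            temperature: 0.2,
            maxOutputTokens: 256,
            topK: 40,
            topP: 0.95,
            stopSequences: ["\n\n"],
        },
    };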
package/src/vertexai/models/palm2-chat.ts
CHANGED
@@ -1,6 +1,6 @@
 import { AIModel, ModelType, PromptOptions, PromptRole, PromptSegment } from "@llumiverse/core";
+import { getJSONSafetyNotice } from "@llumiverse/core/formatters";
 import { VertexAIDriver } from "../index.js";
-import { getJSONSafetyNotice } from "../utils/prompts.js";
 import { AbstractPalmModelDefinition, NonStreamingPromptBase, PalmResponseMetadata, StreamingPromptBase } from "./palm-model-base.js";
 
 export interface Palm2ChatMessage {
package/src/vertexai/models/palm2-text.ts
CHANGED
@@ -1,6 +1,6 @@
 import { AIModel, ModelType, PromptOptions, PromptSegment } from "@llumiverse/core";
+import { formatTextPrompt } from "@llumiverse/core/formatters";
 import { VertexAIDriver } from "../index.js";
-import { getPromptAsText } from "../utils/prompts.js";
 import { AbstractPalmModelDefinition, NonStreamingPromptBase, PalmResponseMetadata, StreamingPromptBase } from "./palm-model-base.js";
 
 export type Palm2TextPrompt = NonStreamingPromptBase<{
@@ -50,7 +50,7 @@ export class Palm2TextDefinition extends AbstractPalmModelDefinition<Palm2TextPr
     createNonStreamingPrompt(_driver: VertexAIDriver, segments: PromptSegment[], opts: PromptOptions): Palm2TextPrompt {
         return {
             instances: [{
-                prompt:
+                prompt: formatTextPrompt(segments, opts.resultSchema)
             }],
             parameters: {
                 // put defauilts here