@llumiverse/drivers 0.12.2 → 0.13.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/lib/cjs/bedrock/index.js +22 -9
- package/lib/cjs/bedrock/index.js.map +1 -1
- package/lib/cjs/huggingface_ie.js.map +1 -1
- package/lib/cjs/index.js +2 -1
- package/lib/cjs/index.js.map +1 -1
- package/lib/cjs/vertexai/models/gemini.js +7 -7
- package/lib/cjs/vertexai/models/gemini.js.map +1 -1
- package/lib/cjs/watsonx/index.js +124 -0
- package/lib/cjs/watsonx/index.js.map +1 -0
- package/lib/cjs/watsonx/interfaces.js +3 -0
- package/lib/cjs/watsonx/interfaces.js.map +1 -0
- package/lib/esm/bedrock/index.js +22 -9
- package/lib/esm/bedrock/index.js.map +1 -1
- package/lib/esm/huggingface_ie.js.map +1 -1
- package/lib/esm/index.js +2 -1
- package/lib/esm/index.js.map +1 -1
- package/lib/esm/vertexai/models/gemini.js +7 -7
- package/lib/esm/vertexai/models/gemini.js.map +1 -1
- package/lib/esm/watsonx/index.js +120 -0
- package/lib/esm/watsonx/index.js.map +1 -0
- package/lib/esm/watsonx/interfaces.js +2 -0
- package/lib/esm/watsonx/interfaces.js.map +1 -0
- package/lib/types/bedrock/index.d.ts.map +1 -1
- package/lib/types/huggingface_ie.d.ts +5 -5
- package/lib/types/index.d.ts +2 -1
- package/lib/types/index.d.ts.map +1 -1
- package/lib/types/watsonx/index.d.ts +27 -0
- package/lib/types/watsonx/index.d.ts.map +1 -0
- package/lib/types/watsonx/interfaces.d.ts +61 -0
- package/lib/types/watsonx/interfaces.d.ts.map +1 -0
- package/package.json +5 -6
- package/src/bedrock/index.ts +24 -9
- package/src/huggingface_ie.ts +1 -1
- package/src/index.ts +2 -1
- package/src/vertexai/models/gemini.ts +7 -7
- package/src/watsonx/index.ts +163 -0
- package/src/watsonx/interfaces.ts +71 -0
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions } from "@llumiverse/core";
|
|
2
|
+
import { transformSSEStream } from "@llumiverse/core/async";
|
|
3
|
+
import { FetchClient } from "api-fetch-client";
|
|
4
|
+
import { GenerateEmbeddingPayload, GenerateEmbeddingResponse, WatsonAuthToken, WatsonxListModelResponse, WatsonxModelSpec, WatsonxTextGenerationPayload, WatsonxTextGenerationResponse } from "./interfaces.js";
|
|
5
|
+
|
|
6
|
+
/**
 * Construction options for {@link WatsonxDriver}, on top of the common
 * {@link DriverOptions}.
 */
interface WatsonxDriverOptions extends DriverOptions {
    /** Base URL of the watsonx.ai REST API; every request path is resolved against it. */
    endpointUrl: string;
    /** watsonx.ai project identifier sent with every generation/embedding request. */
    projectId: string;
    /** IBM Cloud API key exchanged for an IAM bearer token. */
    apiKey: string;
}
|
|
11
|
+
|
|
12
|
+
/** watsonx.ai API version date, passed as the `version` query parameter on every call. */
const API_VERSION = "2024-03-14";
|
|
13
|
+
|
|
14
|
+
/**
 * LLM driver for IBM watsonx.ai.
 *
 * Requests go to `endpointUrl` through a FetchClient whose auth callback
 * returns `Bearer <access_token>`. The token comes from IBM Cloud IAM in
 * exchange for the configured API key. It is cached and reused until it is
 * within {@link WatsonxDriver.TOKEN_REFRESH_MARGIN_SECONDS} of its
 * `expiration` timestamp.
 */
export class WatsonxDriver extends AbstractDriver<WatsonxDriverOptions, string> {
    static PROVIDER = "watsonx";
    /** Refresh a cached IAM token this many seconds before `expiration` so it cannot lapse mid-request. */
    static readonly TOKEN_REFRESH_MARGIN_SECONDS = 60;

    provider = WatsonxDriver.PROVIDER;
    apiKey: string;
    endpoint_url: string;
    projectId: string;
    /** Most recent IAM token response; reused by getAuthToken() while still valid. */
    authToken?: WatsonAuthToken;
    /** Not used by this driver (requests go through `fetchClient`); kept so existing references still compile. */
    fetcher?: FetchClient;
    /** HTTP client bound to `endpoint_url`, authenticated via getAuthToken(). */
    fetchClient: FetchClient;

    constructor(options: WatsonxDriverOptions) {
        super(options);
        this.apiKey = options.apiKey;
        this.projectId = options.projectId;
        this.endpoint_url = options.endpointUrl;
        this.fetchClient = new FetchClient(this.endpoint_url).withAuthCallback(async () => this.getAuthToken());
    }

    /**
     * Builds the text-generation request body shared by the blocking and
     * streaming endpoints, so both always send the same parameters.
     */
    private buildGenerationPayload(prompt: string, options: ExecutionOptions): WatsonxTextGenerationPayload {
        return {
            model_id: options.model,
            input: prompt,
            parameters: {
                max_new_tokens: options.max_tokens,
                //time_limit: options.time_limit,
            },
            project_id: this.projectId,
        };
    }

    /**
     * Runs a single (non-streaming) generation.
     *
     * @param prompt - prompt text sent as `input`
     * @param options - `model` and `max_tokens` are forwarded; `include_original_response` attaches the raw response
     * @returns the first generation result with its token usage and stop reason
     * @throws Error when the service returns no results
     */
    async requestCompletion(prompt: string, options: ExecutionOptions): Promise<Completion<any>> {
        const payload = this.buildGenerationPayload(prompt, options);

        const res = await this.fetchClient.post(`/ml/v1/text/generation?version=${API_VERSION}`, { payload }) as WatsonxTextGenerationResponse;

        const result = res.results?.[0];
        if (!result) {
            // Fail with context rather than a TypeError on `undefined.generated_text`.
            throw new Error(`Watsonx returned no generation results for model ${options.model}`);
        }

        return {
            result: result.generated_text,
            token_usage: {
                prompt: result.input_token_count,
                result: result.generated_token_count,
                total: result.input_token_count + result.generated_token_count,
            },
            finish_reason: result.stop_reason,
            original_response: options.include_original_response ? res : undefined,
        };
    }

    /**
     * Runs a streaming generation over server-sent events.
     *
     * @returns an async iterable of text chunks; an event with no result yields `''`
     */
    async requestCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<string>> {
        const payload = this.buildGenerationPayload(prompt, options);

        const stream = await this.fetchClient.post(`/ml/v1/text/generation_stream?version=${API_VERSION}`, {
            payload: payload,
            reader: 'sse'
        });

        return transformSSEStream(stream, (data: string) => {
            const json = JSON.parse(data) as WatsonxTextGenerationResponse;
            return json.results?.[0]?.generated_text ?? '';
        });
    }

    /**
     * Lists the foundation models exposed by the watsonx.ai instance.
     *
     * @throws the underlying request error (after logging a warning) when the model specs cannot be fetched
     */
    async listModels(): Promise<AIModel<string>[]> {
        let res: WatsonxListModelResponse;
        try {
            res = await this.fetchClient.get(`/ml/v1/foundation_model_specs?version=${API_VERSION}`) as WatsonxListModelResponse;
        } catch (err) {
            // Log for diagnostics, then propagate. Swallowing the error used to surface
            // later as an opaque TypeError on `undefined.resources`.
            this.logger.warn(`Can't list models on Watsonx: ${err}`);
            throw err;
        }

        return res.resources.map((m: WatsonxModelSpec) => ({
            id: m.model_id,
            name: m.label,
            description: m.short_description,
            provider: this.provider,
        }));
    }

    /**
     * Returns the Authorization header value `Bearer <access_token>`.
     *
     * Reuses the cached IAM token while it is outside the refresh margin.
     * Otherwise it exchanges the API key for a new token at IBM Cloud IAM.
     *
     * @throws Error when IAM answers with a non-2xx status or without an `access_token`
     */
    async getAuthToken(): Promise<string> {
        if (this.authToken) {
            const now = Date.now() / 1000;
            if (now < this.authToken.expiration - WatsonxDriver.TOKEN_REFRESH_MARGIN_SECONDS) {
                // Same "Bearer <token>" form as a freshly fetched token. The cached
                // path previously returned the bare token, which broke auth on every
                // call after the first.
                return 'Bearer ' + this.authToken.access_token;
            } else {
                // Log only the expiration, never the credential itself.
                this.logger.debug("Token expired, refetching", this.authToken.expiration, now)
            }
        }

        const response = await fetch('https://iam.cloud.ibm.com/identity/token', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/x-www-form-urlencoded',
            },
            // URLSearchParams percent-encodes the values so reserved characters in
            // the API key cannot corrupt the form body.
            body: new URLSearchParams({
                grant_type: 'urn:ibm:params:oauth:grant-type:apikey',
                apikey: this.apiKey,
            }).toString(),
        });

        if (!response.ok) {
            const detail = await response.text().catch(() => '');
            throw new Error(`Failed to obtain IBM Cloud IAM token: HTTP ${response.status} ${detail}`);
        }

        const authToken = await response.json() as WatsonAuthToken;
        if (!authToken?.access_token) {
            throw new Error("IBM Cloud IAM token response did not contain an access_token");
        }

        this.authToken = authToken;

        return 'Bearer ' + authToken.access_token;
    }

    /** Resolves to true when the model list can be fetched, false (after a warning) otherwise. */
    async validateConnection(): Promise<boolean> {
        return this.listModels()
            .then(() => true)
            .catch((err) => {
                this.logger.warn("Failed to connect to WatsonX", err);
                return false
            });
    }

    /**
     * Embeds `options.content` as a single input.
     *
     * @param options - `model` defaults to `ibm/slate-125m-english-rtrvr`
     * @returns the first embedding vector and the model id reported by the service
     * @throws Error when the service returns no embedding
     */
    async generateEmbeddings(options: EmbeddingsOptions): Promise<EmbeddingsResult> {
        const payload: GenerateEmbeddingPayload = {
            inputs: [options.content],
            model_id: options.model ?? 'ibm/slate-125m-english-rtrvr',
            project_id: this.projectId
        };

        const res = await this.fetchClient.post(`/ml/v1/text/embeddings?version=${API_VERSION}`, { payload }) as GenerateEmbeddingResponse;

        const first = res.results?.[0];
        if (!first) {
            throw new Error(`Watsonx returned no embeddings for model ${payload.model_id}`);
        }

        return {
            values: first.embedding,
            model: res.model_id
        };
    }
}
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
/*interface ListModelsParams extends ModelSearchPayload {
|
|
162
|
+
limit?: number;
|
|
163
|
+
}*/
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
/**
 * Request body posted by the driver to the watsonx.ai text generation
 * endpoints (blocking and streaming).
 */
export interface WatsonxTextGenerationPayload {
    /** Identifier of the foundation model to run. */
    model_id: string;
    /** Prompt text. */
    input: string;
    /** Generation parameters; only the fields below are modelled. */
    parameters: {
        max_new_tokens?: number,
        time_limit?: number,
    };
    /** watsonx.ai project identifier sent with the request. */
    project_id: string;
}
|
|
12
|
+
|
|
13
|
+
/**
 * Response of the watsonx.ai text generation endpoints. The streaming endpoint
 * emits one such object per server-sent event.
 */
export interface WatsonxTextGenerationResponse {
    model_id: string;
    created_at: string;
    /** Generation results; the driver only reads the first entry. */
    results: Array<{
        generated_text: string,
        generated_token_count: number,
        input_token_count: number,
        stop_reason: string,
    }>;
}
|
|
23
|
+
|
|
24
|
+
/** Request body posted by the driver to the watsonx.ai embeddings endpoint. */
export interface GenerateEmbeddingPayload {
    /** Texts to embed; the driver sends exactly one. */
    inputs: string[];
    /** Embedding model identifier. */
    model_id: string;
    /** watsonx.ai project identifier sent with the request. */
    project_id: string;
}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
/** Response of the watsonx.ai embeddings endpoint. */
export interface GenerateEmbeddingResponse {
    model_id: string;
    created_at: string;
    /** One entry per input text, each holding its embedding vector. */
    results: Array<{
        embedding: number[],
    }>;
    input_token_count: number;
}
|
|
39
|
+
|
|
40
|
+
/** One foundation model entry from the watsonx.ai model specs listing. */
export interface WatsonxModelSpec {
    /** Identifier used as `model_id` in requests. */
    model_id: string;
    /** Human-readable model name. */
    label: string;
    provider: string;
    source: string;
    short_description: string;
    tasks: Array<{
        id: string,
        ratings: {
            quality: number,
        },
    }>;
    min_shot_size: number;
    tier: string;
    number_params: string;
}
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
/** Page of model specs returned by the watsonx.ai foundation model specs endpoint. */
export interface WatsonxListModelResponse {
    /** Model entries in this page. */
    resources: WatsonxModelSpec[];
    total_count: number;
    limit: number;
}
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
/**
 * Token response returned by IBM Cloud IAM for the API-key grant.
 * The driver uses `access_token` and the `expiration` timestamp (seconds since the epoch).
 */
export interface WatsonAuthToken {
    access_token: string;
    refresh_token: string;
    token_type: string;
    /** Token lifetime in seconds, as named by the IAM response (`expires_in`). */
    expires_in: number;
    /**
     * @deprecated Misspelling of `expires_in`; IAM never sends this field. Kept
     * optional only so existing references keep compiling.
     */
    expire_in?: number;
    /** Absolute expiry time, in seconds since the Unix epoch. */
    expiration: number;
    /** Scopes granted to the token, when IAM includes them. */
    scope?: string;
}
|