@ai-sdk/google-vertex 4.0.22 → 4.0.24
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +16 -0
- package/dist/anthropic/edge/index.js +1 -1
- package/dist/anthropic/edge/index.mjs +1 -1
- package/dist/edge/index.js +1 -1
- package/dist/edge/index.mjs +1 -1
- package/dist/index.js +1 -1
- package/dist/index.mjs +1 -1
- package/package.json +7 -6
- package/src/__snapshots__/google-vertex-embedding-model.test.ts.snap +39 -0
- package/src/anthropic/edge/google-vertex-anthropic-provider-edge.test.ts +87 -0
- package/src/anthropic/edge/google-vertex-anthropic-provider-edge.ts +41 -0
- package/src/anthropic/edge/index.ts +8 -0
- package/src/anthropic/google-vertex-anthropic-messages-options.ts +15 -0
- package/src/anthropic/google-vertex-anthropic-provider-node.test.ts +73 -0
- package/src/anthropic/google-vertex-anthropic-provider-node.ts +40 -0
- package/src/anthropic/google-vertex-anthropic-provider.test.ts +208 -0
- package/src/anthropic/google-vertex-anthropic-provider.ts +210 -0
- package/src/anthropic/index.ts +8 -0
- package/src/edge/google-vertex-auth-edge.test.ts +308 -0
- package/src/edge/google-vertex-auth-edge.ts +161 -0
- package/src/edge/google-vertex-provider-edge.test.ts +105 -0
- package/src/edge/google-vertex-provider-edge.ts +50 -0
- package/src/edge/index.ts +5 -0
- package/src/google-vertex-auth-google-auth-library.test.ts +59 -0
- package/src/google-vertex-auth-google-auth-library.ts +27 -0
- package/src/google-vertex-config.ts +8 -0
- package/src/google-vertex-embedding-model.test.ts +315 -0
- package/src/google-vertex-embedding-model.ts +135 -0
- package/src/google-vertex-embedding-options.ts +63 -0
- package/src/google-vertex-error.ts +19 -0
- package/src/google-vertex-image-model.test.ts +926 -0
- package/src/google-vertex-image-model.ts +288 -0
- package/src/google-vertex-image-settings.ts +8 -0
- package/src/google-vertex-options.ts +32 -0
- package/src/google-vertex-provider-node.test.ts +88 -0
- package/src/google-vertex-provider-node.ts +49 -0
- package/src/google-vertex-provider.test.ts +318 -0
- package/src/google-vertex-provider.ts +217 -0
- package/src/google-vertex-tools.ts +11 -0
- package/src/index.ts +7 -0
- package/src/version.ts +6 -0
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
import { resolve } from '@ai-sdk/provider-utils';
|
|
2
|
+
import { createVertex as createVertexEdge } from './google-vertex-provider-edge';
|
|
3
|
+
import { createVertex as createVertexOriginal } from '../google-vertex-provider';
|
|
4
|
+
import * as edgeAuth from './google-vertex-auth-edge';
|
|
5
|
+
import { describe, beforeEach, afterEach, expect, it, vi } from 'vitest';
|
|
6
|
+
|
|
7
|
+
// Replace the edge auth helper and the base provider factory with mocks so the
// tests can inspect exactly what the edge `createVertex` forwards.
vi.mock('./google-vertex-auth-edge', () => {
  const generateAuthToken = vi.fn().mockResolvedValue('mock-auth-token');
  return { generateAuthToken };
});

vi.mock('../google-vertex-provider', () => {
  // Echo the received settings back as the "provider" so they can be checked.
  const createVertex = vi
    .fn()
    .mockImplementation(options => ({ ...options }));
  return { createVertex };
});
|
|
17
|
+
|
|
18
|
+
describe('google-vertex-provider-edge', () => {
  // Settings object the edge wrapper handed to the (mocked) base factory on
  // its first call.
  const getForwardedOptions = () =>
    vi.mocked(createVertexOriginal).mock.calls[0][0];

  // The API-key env var switches the edge provider into API-key mode, so it
  // must not leak between tests.
  const clearApiKeyEnv = () => {
    delete process.env.GOOGLE_VERTEX_API_KEY;
  };

  beforeEach(() => {
    vi.clearAllMocks();
    clearApiKeyEnv();
  });

  afterEach(() => {
    clearApiKeyEnv();
  });

  it('default headers function should return auth token', async () => {
    createVertexEdge({ project: 'test-project' });
    const forwarded = getForwardedOptions();

    expect(vi.mocked(createVertexOriginal)).toHaveBeenCalledTimes(1);
    expect(typeof forwarded?.headers).toBe('function');
    expect(await resolve(forwarded?.headers)).toStrictEqual({
      Authorization: 'Bearer mock-auth-token',
    });
  });

  it('should use custom headers in addition to auth token when provided', async () => {
    createVertexEdge({
      project: 'test-project',
      headers: async () => ({ 'Custom-Header': 'custom-value' }),
    });
    const forwarded = getForwardedOptions();

    expect(vi.mocked(createVertexOriginal)).toHaveBeenCalledTimes(1);
    expect(typeof forwarded?.headers).toBe('function');
    expect(await resolve(forwarded?.headers)).toEqual({
      Authorization: 'Bearer mock-auth-token',
      'Custom-Header': 'custom-value',
    });
  });

  it('should use edge auth token generator', async () => {
    createVertexEdge({ project: 'test-project' });
    const forwarded = getForwardedOptions();

    // Resolving the headers has to go through the edge token generator.
    expect(forwarded?.headers).toBeDefined();
    await resolve(forwarded?.headers);
    expect(edgeAuth.generateAuthToken).toHaveBeenCalled();
  });

  it('passes googleCredentials to generateAuthToken', async () => {
    const googleCredentials = {
      clientEmail: 'test@example.com',
      privateKey: 'test-key',
    };
    createVertexEdge({ project: 'test-project', googleCredentials });
    const forwarded = getForwardedOptions();

    // Invoking the headers function triggers token generation.
    await resolve(forwarded?.headers);

    expect(edgeAuth.generateAuthToken).toHaveBeenCalledWith({
      clientEmail: 'test@example.com',
      privateKey: 'test-key',
    });
  });

  it('should pass options through to base provider when apiKey is provided', async () => {
    createVertexEdge({ apiKey: 'test-api-key' });
    const forwarded = getForwardedOptions();

    expect(forwarded?.apiKey).toBe('test-api-key');
    expect(forwarded?.headers).toBeUndefined();
    expect(edgeAuth.generateAuthToken).not.toHaveBeenCalled();
  });
});
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import { loadOptionalSetting, resolve } from '@ai-sdk/provider-utils';
|
|
2
|
+
import {
|
|
3
|
+
createVertex as createVertexOriginal,
|
|
4
|
+
GoogleVertexProvider,
|
|
5
|
+
GoogleVertexProviderSettings as GoogleVertexProviderSettingsOriginal,
|
|
6
|
+
} from '../google-vertex-provider';
|
|
7
|
+
import {
|
|
8
|
+
generateAuthToken,
|
|
9
|
+
GoogleCredentials,
|
|
10
|
+
} from './google-vertex-auth-edge';
|
|
11
|
+
|
|
12
|
+
export type { GoogleVertexProvider };

/**
 * Settings for the edge-runtime variant of the Google Vertex provider.
 *
 * Extends the base provider settings with `googleCredentials`, which the edge
 * `createVertex` passes to its token generator when no API key is configured.
 * When an API key is configured (via `apiKey` or `GOOGLE_VERTEX_API_KEY`), the
 * settings are handed to the base provider unchanged and no bearer token is
 * generated.
 */
export interface GoogleVertexProviderSettings
  extends GoogleVertexProviderSettingsOriginal {
  /**
   * Optional. The Google credentials for the Google Cloud service account. If
   * not provided, the Google Vertex provider will use environment variables to
   * load the credentials.
   */
  googleCredentials?: GoogleCredentials;
}
|
|
23
|
+
|
|
24
|
+
/**
 * Creates a Google Vertex provider for edge runtimes.
 *
 * If an API key is available (`options.apiKey` or the `GOOGLE_VERTEX_API_KEY`
 * setting), the settings are passed to the base provider unchanged.
 *
 * Otherwise the base provider receives a `headers` function that produces an
 * `Authorization: Bearer <token>` header. The token comes from
 * `generateAuthToken(options.googleCredentials)`. Any user-supplied `headers`
 * are spread after it, so they win on key conflicts.
 *
 * @param options - Provider settings; defaults to an empty object.
 * @returns The configured Google Vertex provider.
 */
export function createVertex(
  options: GoogleVertexProviderSettings = {},
): GoogleVertexProvider {
  const apiKey = loadOptionalSetting({
    settingValue: options.apiKey,
    environmentVariableName: 'GOOGLE_VERTEX_API_KEY',
  });

  // API-key mode: the base provider handles authentication on its own.
  if (apiKey) {
    return createVertexOriginal(options);
  }

  // Passed as a function, so the token is generated whenever the base provider
  // resolves its headers rather than once at construction time.
  const buildAuthHeaders = async () => {
    const token = await generateAuthToken(options.googleCredentials);
    const customHeaders = await resolve(options.headers);
    return { Authorization: `Bearer ${token}`, ...customHeaders };
  };

  return createVertexOriginal({ ...options, headers: buildAuthHeaders });
}
|
|
46
|
+
|
|
47
|
+
/**
Default Google Vertex AI provider instance for edge runtimes.

It is created by calling `createVertex()` with no settings when this module is
evaluated. The `GOOGLE_VERTEX_API_KEY` check therefore happens once, at import
time: if the key is present, API-key mode is used. Otherwise requests carry a
bearer token from the edge token generator, called without explicit credentials.
 */
export const vertex = createVertex();
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
|
2
|
+
import {
|
|
3
|
+
generateAuthToken,
|
|
4
|
+
_resetAuthInstance,
|
|
5
|
+
} from './google-vertex-auth-google-auth-library';
|
|
6
|
+
import { GoogleAuth } from 'google-auth-library';
|
|
7
|
+
|
|
8
|
+
// Mocked google-auth-library: every `new GoogleAuth()` yields an object whose
// client returns the access token 'mocked-token'.
vi.mock('google-auth-library', () => ({
  GoogleAuth: vi.fn().mockImplementation(() => ({
    getClient: vi.fn().mockResolvedValue({
      getAccessToken: vi.fn().mockResolvedValue({ token: 'mocked-token' }),
    }),
  })),
}));
|
|
19
|
+
|
|
20
|
+
describe('generateAuthToken', () => {
  // Start each test with cleared mock call history and an empty auth cache.
  beforeEach(() => {
    vi.clearAllMocks();
    _resetAuthInstance();
  });

  it('should generate a valid auth token', async () => {
    await expect(generateAuthToken()).resolves.toBe('mocked-token');
  });

  it('should return null if no token is received', async () => {
    // GoogleAuth stand-in whose client hands back a null access token.
    const createAuthWithoutToken = () =>
      ({
        getClient: vi.fn().mockResolvedValue({
          getAccessToken: vi.fn().mockResolvedValue({ token: null }),
        }),
        isGCE: vi.fn(),
      }) as unknown as GoogleAuth;

    const mockedGoogleAuth = vi.mocked(GoogleAuth);
    mockedGoogleAuth.mockReset();
    mockedGoogleAuth.mockImplementation(createAuthWithoutToken);

    await expect(generateAuthToken()).resolves.toBeNull();
  });

  it('should create new auth instance with provided options', async () => {
    await generateAuthToken({ keyFile: 'test-key.json' });

    expect(GoogleAuth).toHaveBeenCalledWith({
      scopes: ['https://www.googleapis.com/auth/cloud-platform'],
      keyFile: 'test-key.json',
    });
  });
});
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library';
|
|
2
|
+
|
|
3
|
+
// Module-wide cache: the GoogleAuth client and the options object it was
// built from. Options are compared by reference, so callers get reuse by
// passing the same object each time (or no options at all).
let authInstance: GoogleAuth | null = null;
let authOptions: GoogleAuthOptions | undefined = undefined;

/**
 * Returns the cached `GoogleAuth` client, (re)creating it when there is none
 * yet or when `options` is not the same object the cache was built from.
 *
 * `undefined` is itself a stable cache key. Mapping a missing value to a fresh
 * `{}` here would never match the cached options, so the client would be
 * rebuilt on every request.
 *
 * @param options - GoogleAuth options. They are spread over the default
 *   cloud-platform scope, so a caller-supplied `scopes` overrides it.
 * @returns The cached or newly created `GoogleAuth` instance.
 */
function getAuth(options: GoogleAuthOptions | undefined): GoogleAuth {
  if (authInstance === null || options !== authOptions) {
    authInstance = new GoogleAuth({
      scopes: ['https://www.googleapis.com/auth/cloud-platform'],
      ...options,
    });
    authOptions = options;
  }
  return authInstance;
}

/**
 * Obtains an access token using google-auth-library.
 *
 * @param options - Optional GoogleAuth options. Pass the same object across
 *   calls (or omit it) to reuse the cached client.
 * @returns The access token, or `null` if the client returned no token.
 */
export async function generateAuthToken(options?: GoogleAuthOptions) {
  const auth = getAuth(options);
  const client = await auth.getClient();
  const token = await client.getAccessToken();
  return token?.token || null;
}

// For testing purposes only: drops the cached client and the options it was
// built from, so the next call constructs a fresh GoogleAuth.
export function _resetAuthInstance() {
  authInstance = null;
  authOptions = undefined;
}
|
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
import {
|
|
2
|
+
EmbeddingModelV3Embedding,
|
|
3
|
+
TooManyEmbeddingValuesForCallError,
|
|
4
|
+
} from '@ai-sdk/provider';
|
|
5
|
+
import { createTestServer } from '@ai-sdk/test-server/with-vitest';
|
|
6
|
+
import { GoogleVertexEmbeddingModel } from './google-vertex-embedding-model';
|
|
7
|
+
import { describe, it, expect, vi } from 'vitest';
|
|
8
|
+
import { createVertex } from './google-vertex-provider';
|
|
9
|
+
|
|
10
|
+
vi.mock('./version', () => {
  // Pin the version so the user-agent assertion is stable.
  return { VERSION: '0.0.0-test' };
});

const dummyEmbeddings = [
  [0.1, 0.2, 0.3],
  [0.4, 0.5, 0.6],
];
const testValues = ['test text one', 'test text two'];

// Endpoints answered by the test server.
const DEFAULT_URL =
  'https://us-central1-aiplatform.googleapis.com/v1beta1/projects/test-project/locations/us-central1/publishers/google/models/textembedding-gecko@001:predict';
const CUSTOM_URL =
  'https://custom-endpoint.com/models/textembedding-gecko@001:predict';

const server = createTestServer({ [DEFAULT_URL]: {}, [CUSTOM_URL]: {} });
|
|
30
|
+
|
|
31
|
+
describe('GoogleVertexEmbeddingModel', () => {
  const mockModelId = 'textembedding-gecko@001';
  const mockProviderOptions = {
    outputDimensionality: 768,
    taskType: 'SEMANTIC_SIMILARITY',
    title: 'test title',
    autoTruncate: false,
  };

  const mockConfig = {
    provider: 'google-vertex',
    region: 'us-central1',
    project: 'test-project',
    headers: () => ({}),
    baseURL:
      'https://us-central1-aiplatform.googleapis.com/v1beta1/projects/test-project/locations/us-central1/publishers/google',
  };

  const model = new GoogleVertexEmbeddingModel(mockModelId, mockConfig);

  // Embeds the shared test values with the given provider options.
  const embedTestValues = (
    providerOptions: Parameters<typeof model.doEmbed>[0]['providerOptions'],
  ) => model.doEmbed({ values: testValues, providerOptions });

  // Builds a `:predict` response body with one prediction per embedding.
  const buildPredictionsBody = (
    embeddings: EmbeddingModelV3Embedding[],
    tokenCounts: number[],
  ) => ({
    predictions: embeddings.map((values, index) => ({
      embeddings: {
        values,
        statistics: { token_count: tokenCounts[index] },
      },
    })),
  });

  function prepareJsonResponse({
    embeddings = dummyEmbeddings,
    tokenCounts = [1, 1],
    headers,
  }: {
    embeddings?: EmbeddingModelV3Embedding[];
    tokenCounts?: number[];
    headers?: Record<string, string>;
  } = {}) {
    server.urls[DEFAULT_URL].response = {
      type: 'json-value',
      headers,
      body: buildPredictionsBody(embeddings, tokenCounts),
    };
  }

  // Request body expected when every field of `mockProviderOptions` is set.
  const fullRequestBody = {
    instances: testValues.map(content => ({
      content,
      task_type: mockProviderOptions.taskType,
      title: mockProviderOptions.title,
    })),
    parameters: {
      outputDimensionality: mockProviderOptions.outputDimensionality,
      autoTruncate: mockProviderOptions.autoTruncate,
    },
  };

  it('should extract embeddings', async () => {
    prepareJsonResponse();

    const result = await embedTestValues({ google: mockProviderOptions });

    expect(result.embeddings).toStrictEqual(dummyEmbeddings);
  });

  it('should expose the raw response', async () => {
    prepareJsonResponse({ headers: { 'test-header': 'test-value' } });

    const { response } = await embedTestValues({ google: mockProviderOptions });

    expect(response?.headers).toStrictEqual({
      // default headers:
      'content-length': '159',
      'content-type': 'application/json',
      // custom header
      'test-header': 'test-value',
    });
    expect(response).toMatchSnapshot();
  });

  it('should extract usage', async () => {
    prepareJsonResponse({ tokenCounts: [10, 15] });

    const { usage } = await embedTestValues({ google: mockProviderOptions });

    expect(usage).toStrictEqual({ tokens: 25 });
  });

  it('should pass the model parameters correctly', async () => {
    prepareJsonResponse();

    await embedTestValues({ google: mockProviderOptions });

    expect(await server.calls[0].requestBodyJson).toStrictEqual(
      fullRequestBody,
    );
  });

  it('should accept vertex as provider options key', async () => {
    prepareJsonResponse();

    await embedTestValues({ vertex: mockProviderOptions });

    expect(await server.calls[0].requestBodyJson).toStrictEqual(
      fullRequestBody,
    );
  });

  it('should pass the taskType setting in instances', async () => {
    prepareJsonResponse();

    await embedTestValues({
      google: { taskType: mockProviderOptions.taskType },
    });

    expect(await server.calls[0].requestBodyJson).toStrictEqual({
      instances: testValues.map(content => ({
        content,
        task_type: mockProviderOptions.taskType,
      })),
      parameters: {},
    });
  });

  it('should pass the title setting in instances', async () => {
    prepareJsonResponse();

    await embedTestValues({ google: { title: mockProviderOptions.title } });

    expect(await server.calls[0].requestBodyJson).toStrictEqual({
      instances: testValues.map(content => ({
        content,
        title: mockProviderOptions.title,
      })),
      parameters: {},
    });
  });

  // Goes through the provider's `createVertex` so that provider-level and
  // request-level headers are both exercised.
  it('should pass headers correctly', async () => {
    prepareJsonResponse();

    const provider = createVertex({
      project: 'test-project',
      location: 'us-central1',
      headers: { 'X-Custom-Header': 'custom-value' },
    });

    await provider.embeddingModel(mockModelId).doEmbed({
      values: testValues,
      headers: { 'X-Request-Header': 'request-value' },
      providerOptions: { google: mockProviderOptions },
    });

    const [firstCall] = server.calls;
    expect(firstCall.requestHeaders).toStrictEqual({
      'content-type': 'application/json',
      'x-custom-header': 'custom-value',
      'x-request-header': 'request-value',
    });
    expect(firstCall.requestUserAgent).toContain(
      `ai-sdk/google-vertex/0.0.0-test`,
    );
  });

  it('should throw TooManyEmbeddingValuesForCallError when too many values provided', async () => {
    const tooManyValues = Array.from({ length: 2049 }, () => 'test');

    await expect(
      model.doEmbed({
        values: tooManyValues,
        providerOptions: { google: mockProviderOptions },
      }),
    ).rejects.toThrow(TooManyEmbeddingValuesForCallError);
  });

  it('should use custom baseURL when provided', async () => {
    server.urls[CUSTOM_URL].response = {
      type: 'json-value',
      body: buildPredictionsBody(dummyEmbeddings, [1, 1]),
    };

    const modelWithCustomUrl = new GoogleVertexEmbeddingModel(mockModelId, {
      headers: () => ({}),
      baseURL: 'https://custom-endpoint.com',
      provider: 'google-vertex',
    });

    const { embeddings } = await modelWithCustomUrl.doEmbed({
      values: testValues,
      providerOptions: { google: { outputDimensionality: 768 } },
    });

    expect(embeddings).toStrictEqual(dummyEmbeddings);
    expect(server.calls[0].requestUrl).toBe(CUSTOM_URL);
  });

  it('should use custom fetch when provided and include proper request content', async () => {
    const customFetch = vi
      .fn()
      .mockResolvedValue(
        new Response(
          JSON.stringify(buildPredictionsBody(dummyEmbeddings, [1, 1])),
        ),
      );

    const modelWithCustomFetch = new GoogleVertexEmbeddingModel(mockModelId, {
      headers: () => ({}),
      baseURL: 'https://custom-endpoint.com',
      provider: 'google-vertex',
      fetch: customFetch,
    });

    const { embeddings } = await modelWithCustomFetch.doEmbed({
      values: testValues,
      providerOptions: { google: { outputDimensionality: 768 } },
    });

    expect(embeddings).toStrictEqual(dummyEmbeddings);
    expect(customFetch).toHaveBeenCalledWith(CUSTOM_URL, expect.any(Object));

    const requestInit = customFetch.mock.calls[0][1];
    expect(JSON.parse(requestInit.body)).toStrictEqual({
      instances: testValues.map(content => ({ content })),
      parameters: { outputDimensionality: 768 },
    });
  });
});
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import {
|
|
2
|
+
EmbeddingModelV3,
|
|
3
|
+
TooManyEmbeddingValuesForCallError,
|
|
4
|
+
} from '@ai-sdk/provider';
|
|
5
|
+
import {
|
|
6
|
+
combineHeaders,
|
|
7
|
+
createJsonResponseHandler,
|
|
8
|
+
postJsonToApi,
|
|
9
|
+
resolve,
|
|
10
|
+
parseProviderOptions,
|
|
11
|
+
} from '@ai-sdk/provider-utils';
|
|
12
|
+
import { z } from 'zod/v4';
|
|
13
|
+
import { googleVertexFailedResponseHandler } from './google-vertex-error';
|
|
14
|
+
import {
|
|
15
|
+
GoogleVertexEmbeddingModelId,
|
|
16
|
+
googleVertexEmbeddingProviderOptions,
|
|
17
|
+
} from './google-vertex-embedding-options';
|
|
18
|
+
import { GoogleVertexConfig } from './google-vertex-config';
|
|
19
|
+
|
|
20
|
+
/**
 * Embedding model that calls the Vertex AI `{baseURL}/models/{modelId}:predict`
 * endpoint.
 *
 * Provider options are read from `providerOptions.vertex`, falling back to
 * `providerOptions.google` when no `vertex` entry is present.
 */
export class GoogleVertexEmbeddingModel implements EmbeddingModelV3 {
  readonly specificationVersion = 'v3';
  readonly modelId: GoogleVertexEmbeddingModelId;
  // Enforced in `doEmbed` before any request is sent.
  readonly maxEmbeddingsPerCall = 2048;
  readonly supportsParallelCalls = true;

  private readonly config: GoogleVertexConfig;

  /** Provider identifier taken from the model configuration. */
  get provider(): string {
    return this.config.provider;
  }

  constructor(
    modelId: GoogleVertexEmbeddingModelId,
    config: GoogleVertexConfig,
  ) {
    this.modelId = modelId;
    this.config = config;
  }

  /**
   * Embeds all `values` in a single `:predict` request.
   *
   * @returns One embedding per prediction. Usage tokens are the sum of the
   *   per-prediction `token_count`s; the response carries the raw headers and
   *   body.
   * @throws TooManyEmbeddingValuesForCallError when more than
   *   `maxEmbeddingsPerCall` values are passed.
   */
  async doEmbed({
    values,
    headers,
    abortSignal,
    providerOptions,
  }: Parameters<EmbeddingModelV3['doEmbed']>[0]): Promise<
    Awaited<ReturnType<EmbeddingModelV3['doEmbed']>>
  > {
    // `vertex` options win; `google` is only parsed when `vertex` is absent.
    let embeddingOptions = await parseProviderOptions({
      provider: 'vertex',
      providerOptions,
      schema: googleVertexEmbeddingProviderOptions,
    });
    embeddingOptions ??= await parseProviderOptions({
      provider: 'google',
      providerOptions,
      schema: googleVertexEmbeddingProviderOptions,
    });
    embeddingOptions ??= {};

    if (values.length > this.maxEmbeddingsPerCall) {
      throw new TooManyEmbeddingValuesForCallError({
        provider: this.provider,
        modelId: this.modelId,
        maxEmbeddingsPerCall: this.maxEmbeddingsPerCall,
        values,
      });
    }

    const { taskType, title, outputDimensionality, autoTruncate } =
      embeddingOptions;

    const {
      responseHeaders,
      value: response,
      rawValue,
    } = await postJsonToApi({
      url: `${this.config.baseURL}/models/${this.modelId}:predict`,
      headers: combineHeaders(await resolve(this.config.headers), headers),
      body: {
        // Per-instance fields; undefined ones are dropped during JSON encoding.
        instances: values.map(content => ({
          content,
          task_type: taskType,
          title,
        })),
        parameters: { outputDimensionality, autoTruncate },
      },
      failedResponseHandler: googleVertexFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        googleVertexTextEmbeddingResponseSchema,
      ),
      abortSignal,
      fetch: this.config.fetch,
    });

    const embeddings: number[][] = [];
    let tokens = 0;
    for (const { embeddings: prediction } of response.predictions) {
      embeddings.push(prediction.values);
      tokens += prediction.statistics.token_count;
    }

    return {
      warnings: [],
      embeddings,
      usage: { tokens },
      response: { headers: responseHeaders, body: rawValue },
    };
  }
}
|
|
121
|
+
|
|
122
|
+
// Minimal response schema: it validates only the fields the embedding model
// reads (embedding values and per-prediction token counts). Keeping it small
// limits breakage when the API changes and keeps parsing cheap.
const googleVertexEmbeddingPredictionSchema = z.object({
  embeddings: z.object({
    values: z.array(z.number()),
    statistics: z.object({ token_count: z.number() }),
  }),
});

const googleVertexTextEmbeddingResponseSchema = z.object({
  predictions: z.array(googleVertexEmbeddingPredictionSchema),
});
|