@ai-sdk/cohere 3.0.7 → 3.0.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,169 @@
1
+ import {
2
+ EmbeddingModelV3,
3
+ LanguageModelV3,
4
+ NoSuchModelError,
5
+ RerankingModelV3,
6
+ ProviderV3,
7
+ } from '@ai-sdk/provider';
8
+
9
+ import {
10
+ FetchFunction,
11
+ generateId,
12
+ loadApiKey,
13
+ withoutTrailingSlash,
14
+ withUserAgentSuffix,
15
+ } from '@ai-sdk/provider-utils';
16
+ import { CohereChatLanguageModel } from './cohere-chat-language-model';
17
+ import { CohereChatModelId } from './cohere-chat-options';
18
+ import { CohereEmbeddingModel } from './cohere-embedding-model';
19
+ import { CohereRerankingModelId } from './reranking/cohere-reranking-options';
20
+ import { CohereRerankingModel } from './reranking/cohere-reranking-model';
21
+ import { CohereEmbeddingModelId } from './cohere-embedding-options';
22
+ import { VERSION } from './version';
23
+
/**
 * Cohere provider. It can be called directly with a chat model id and also
 * exposes named factory methods for chat, embedding, and reranking models.
 */
export interface CohereProvider extends ProviderV3 {
  /**
   * Creates a model for text generation.
   */
  (modelId: CohereChatModelId): LanguageModelV3;

  /**
   * Creates a model for text generation.
   */
  languageModel(modelId: CohereChatModelId): LanguageModelV3;

  /**
   * Creates a model for text embeddings.
   */
  embedding(modelId: CohereEmbeddingModelId): EmbeddingModelV3;

  /**
   * Creates a model for text embeddings.
   */
  embeddingModel(modelId: CohereEmbeddingModelId): EmbeddingModelV3;

  /**
   * @deprecated Use `embedding` instead.
   */
  textEmbedding(modelId: CohereEmbeddingModelId): EmbeddingModelV3;

  /**
   * @deprecated Use `embeddingModel` instead.
   */
  textEmbeddingModel(modelId: CohereEmbeddingModelId): EmbeddingModelV3;

  /**
   * Creates a model for reranking.
   */
  reranking(modelId: CohereRerankingModelId): RerankingModelV3;

  /**
   * Creates a model for reranking.
   */
  rerankingModel(modelId: CohereRerankingModelId): RerankingModelV3;
}
62
+
/**
 * Optional settings accepted by `createCohere`.
 */
export interface CohereProviderSettings {
  /**
   * URL prefix for API calls, e.g. to route requests through a proxy server.
   * The default prefix is `https://api.cohere.com/v2`.
   */
  baseURL?: string;

  /**
   * API key that is sent using the `Authorization` header.
   * It defaults to the `COHERE_API_KEY` environment variable.
   */
  apiKey?: string;

  /**
   * Custom headers to include in the requests.
   */
  headers?: Record<string, string>;

  /**
   * Custom fetch implementation. It can act as a middleware to intercept
   * requests, or replace fetch entirely, e.g. for testing.
   */
  fetch?: FetchFunction;

  /**
   * Optional function to generate a unique ID for each request.
   */
  generateId?: () => string;
}
92
+
/**
 * Creates a Cohere AI provider instance.
 *
 * The returned provider is callable with a chat model id and also exposes
 * factory methods for chat, embedding, and reranking models. Image models are
 * not available: `imageModel` always throws `NoSuchModelError`.
 *
 * @param options - Provider settings; defaults to an empty settings object.
 * @returns The configured provider.
 * @throws Error when the provider function itself is invoked with `new`.
 */
export function createCohere(
  options: CohereProviderSettings = {},
): CohereProvider {
  const baseURL =
    withoutTrailingSlash(options.baseURL) ?? 'https://api.cohere.com/v2';

  // Headers are rebuilt on every invocation of this function.
  const getHeaders = () => {
    const apiKey = loadApiKey({
      apiKey: options.apiKey,
      environmentVariableName: 'COHERE_API_KEY',
      description: 'Cohere',
    });

    const requestHeaders = {
      Authorization: `Bearer ${apiKey}`,
      // Spread last so that custom headers win over the default entry above.
      ...options.headers,
    };

    return withUserAgentSuffix(requestHeaders, `ai-sdk/cohere/${VERSION}`);
  };

  // Connection settings shared by every model; evaluated per model creation.
  const transportConfig = () => ({
    baseURL,
    headers: getHeaders,
    fetch: options.fetch,
  });

  const createChatModel = (modelId: CohereChatModelId) =>
    new CohereChatLanguageModel(modelId, {
      provider: 'cohere.chat',
      ...transportConfig(),
      generateId: options.generateId ?? generateId,
    });

  const createEmbeddingModel = (modelId: CohereEmbeddingModelId) =>
    new CohereEmbeddingModel(modelId, {
      provider: 'cohere.textEmbedding',
      ...transportConfig(),
    });

  const createRerankingModel = (modelId: CohereRerankingModelId) =>
    new CohereRerankingModel(modelId, {
      provider: 'cohere.reranking',
      ...transportConfig(),
    });

  // Must remain a `function` expression: `new.target` is used to reject `new`.
  const provider = function (modelId: CohereChatModelId) {
    if (!new.target) {
      return createChatModel(modelId);
    }

    throw new Error(
      'The Cohere model function cannot be called with the new keyword.',
    );
  };

  provider.specificationVersion = 'v3' as const;

  // Chat
  provider.languageModel = createChatModel;

  // Embeddings (including the deprecated `textEmbedding*` aliases)
  provider.embedding = createEmbeddingModel;
  provider.embeddingModel = createEmbeddingModel;
  provider.textEmbedding = createEmbeddingModel;
  provider.textEmbeddingModel = createEmbeddingModel;

  // Reranking
  provider.reranking = createRerankingModel;
  provider.rerankingModel = createRerankingModel;

  // No image models are offered by this provider.
  provider.imageModel = (modelId: string) => {
    throw new NoSuchModelError({ modelId, modelType: 'imageModel' });
  };

  return provider;
}
165
+
/**
 * Default Cohere provider instance, created by calling `createCohere` with
 * no explicit settings.
 */
export const cohere: CohereProvider = createCohere();
@@ -0,0 +1,45 @@
1
+ import { LanguageModelV3Usage } from '@ai-sdk/provider';
2
+
/**
 * Token counts in the snake_case shape that `convertCohereUsage` consumes.
 */
export type CohereUsageTokens = {
  /** Mapped to the input token totals by `convertCohereUsage`. */
  input_tokens: number;
  /** Mapped to the output token totals by `convertCohereUsage`. */
  output_tokens: number;
};
7
+
/**
 * Converts Cohere token counts into the `LanguageModelV3Usage` structure.
 *
 * The input count is reported as both `total` and `noCache`, and the output
 * count as both `total` and `text`. Cache and reasoning figures are always
 * `undefined`. When `tokens` is `null` or `undefined`, every field (including
 * `raw`) is `undefined`.
 *
 * @param tokens - Token counts, or `null`/`undefined` when unavailable.
 * @returns The usage record; `raw` is the `tokens` object that was passed in.
 */
export function convertCohereUsage(
  tokens: CohereUsageTokens | undefined | null,
): LanguageModelV3Usage {
  const promptCount = tokens?.input_tokens;
  const completionCount = tokens?.output_tokens;

  return {
    inputTokens: {
      total: promptCount,
      noCache: promptCount,
      cacheRead: undefined,
      cacheWrite: undefined,
    },
    outputTokens: {
      total: completionCount,
      text: completionCount,
      reasoning: undefined,
    },
    // Normalizes `null` to `undefined`; a real object is passed through as-is.
    raw: tokens ?? undefined,
  };
}
@@ -0,0 +1,175 @@
1
+ import { convertToCohereChatPrompt } from './convert-to-cohere-chat-prompt';
2
+ import { describe, it, expect } from 'vitest';
3
+
// Tests for convertToCohereChatPrompt: file-to-document extraction and the
// conversion of assistant tool calls and tool results.
describe('convert to cohere chat prompt', () => {
  describe('file processing', () => {
    it('should extract documents from file parts', () => {
      const question = 'Analyze this file: ';
      const fileText = 'This is file content';
      const fileName = 'test.txt';

      const result = convertToCohereChatPrompt([
        {
          role: 'user',
          content: [
            { type: 'text', text: question },
            {
              type: 'file',
              data: Buffer.from(fileText),
              mediaType: 'text/plain',
              filename: fileName,
            },
          ],
        },
      ]);

      // The file is moved to `documents`; only the text part stays in the message.
      expect(result).toEqual({
        messages: [{ role: 'user', content: question }],
        documents: [{ data: { text: fileText, title: fileName } }],
        warnings: [],
      });
    });

    it('should throw error for unsupported media types', () => {
      const convertPdf = () => {
        convertToCohereChatPrompt([
          {
            role: 'user',
            content: [
              {
                type: 'file',
                data: Buffer.from('PDF content'),
                mediaType: 'application/pdf',
                filename: 'test.pdf',
              },
            ],
          },
        ]);
      };

      expect(convertPdf).toThrow(
        "Media type 'application/pdf' is not supported",
      );
    });
  });

  describe('tool messages', () => {
    // Payloads shared by the tool-call and tool-result cases below.
    const firstPayload = { test: 'This is a tool message' };
    const secondPayload = { something: 'else' };

    it('should convert a tool call into a cohere chatbot message', async () => {
      const result = convertToCohereChatPrompt([
        {
          role: 'assistant',
          content: [
            { type: 'text', text: 'Calling a tool' },
            {
              type: 'tool-call',
              toolName: 'tool-1',
              toolCallId: 'tool-call-1',
              input: firstPayload,
            },
          ],
        },
      ]);

      const expectedToolCall = {
        id: 'tool-call-1',
        type: 'function',
        function: {
          name: 'tool-1',
          arguments: JSON.stringify(firstPayload),
        },
      };

      expect(result).toEqual({
        messages: [
          {
            content: undefined,
            role: 'assistant',
            tool_calls: [expectedToolCall],
          },
        ],
        documents: [],
        warnings: [],
      });
    });

    it('should convert a single tool result into a cohere tool message', async () => {
      const result = convertToCohereChatPrompt([
        {
          role: 'tool',
          content: [
            {
              type: 'tool-result',
              toolName: 'tool-1',
              toolCallId: 'tool-call-1',
              output: { type: 'json', value: firstPayload },
            },
          ],
        },
      ]);

      expect(result).toEqual({
        messages: [
          {
            role: 'tool',
            content: JSON.stringify(firstPayload),
            tool_call_id: 'tool-call-1',
          },
        ],
        documents: [],
        warnings: [],
      });
    });

    it('should convert multiple tool results into a cohere tool message', async () => {
      const result = convertToCohereChatPrompt([
        {
          role: 'tool',
          content: [
            {
              type: 'tool-result',
              toolName: 'tool-1',
              toolCallId: 'tool-call-1',
              output: { type: 'json', value: firstPayload },
            },
            {
              type: 'tool-result',
              toolName: 'tool-2',
              toolCallId: 'tool-call-2',
              output: { type: 'json', value: secondPayload },
            },
          ],
        },
      ]);

      // One tool message per result, in the original order.
      expect(result).toEqual({
        messages: [
          {
            role: 'tool',
            content: JSON.stringify(firstPayload),
            tool_call_id: 'tool-call-1',
          },
          {
            role: 'tool',
            content: JSON.stringify(secondPayload),
            tool_call_id: 'tool-call-2',
          },
        ],
        documents: [],
        warnings: [],
      });
    });
  });
});
@@ -0,0 +1,156 @@
1
+ import {
2
+ SharedV3Warning,
3
+ LanguageModelV3Prompt,
4
+ UnsupportedFunctionalityError,
5
+ } from '@ai-sdk/provider';
6
+ import { CohereAssistantMessage, CohereChatPrompt } from './cohere-chat-prompt';
7
+
/**
 * Converts a `LanguageModelV3Prompt` into Cohere chat messages plus a list of
 * documents extracted from user file parts.
 *
 * - System messages are forwarded unchanged.
 * - User text parts are concatenated; file parts are moved into `documents`
 *   and contribute no text to the message.
 * - Assistant messages carry either their text or their tool calls: when at
 *   least one tool call is present, `content` is `undefined`.
 * - Each tool result becomes its own `tool` message; tool approval responses
 *   are skipped.
 *
 * @param prompt - The prompt to convert.
 * @returns The converted messages, the extracted documents, and a warnings
 * array (nothing in this function adds to it).
 * @throws UnsupportedFunctionalityError when binary file data has a media type
 * other than `text/*` or `application/json`, or when file data is neither a
 * string nor a `Uint8Array`.
 */
export function convertToCohereChatPrompt(prompt: LanguageModelV3Prompt): {
  messages: CohereChatPrompt;
  documents: Array<{
    data: { text: string; title?: string };
  }>;
  warnings: SharedV3Warning[];
} {
  const messages: CohereChatPrompt = [];
  const documents: Array<{ data: { text: string; title?: string } }> = [];
  const warnings: SharedV3Warning[] = [];

  for (const { role, content } of prompt) {
    switch (role) {
      case 'system': {
        messages.push({ role: 'system', content });
        break;
      }

      case 'user': {
        const textPieces: string[] = [];

        for (const part of content) {
          if (part.type === 'text') {
            textPieces.push(part.text);
            continue;
          }

          if (part.type !== 'file') {
            continue;
          }

          // File parts become documents (for RAG) rather than message text.
          const fileData = part.data;

          if (
            typeof fileData !== 'string' &&
            !(fileData instanceof Uint8Array)
          ) {
            throw new UnsupportedFunctionalityError({
              functionality: 'File URL data',
              message:
                'URLs should be downloaded by the AI SDK and not reach this point. This indicates a configuration issue.',
            });
          }

          let documentText: string;

          if (fileData instanceof Uint8Array) {
            // Only binary data is validated against the supported media types.
            const isTextual =
              part.mediaType?.startsWith('text/') ||
              part.mediaType === 'application/json';

            if (!isTextual) {
              throw new UnsupportedFunctionalityError({
                functionality: `document media type: ${part.mediaType}`,
                message: `Media type '${part.mediaType}' is not supported. Supported media types are: text/* and application/json.`,
              });
            }

            documentText = new TextDecoder().decode(fileData);
          } else {
            // String data (base64 or text) is used as-is.
            documentText = fileData;
          }

          documents.push({
            data: { text: documentText, title: part.filename },
          });
        }

        messages.push({ role: 'user', content: textPieces.join('') });
        break;
      }

      case 'assistant': {
        let assistantText = '';
        const pendingCalls: CohereAssistantMessage['tool_calls'] = [];

        for (const part of content) {
          if (part.type === 'text') {
            assistantText += part.text;
          } else if (part.type === 'tool-call') {
            pendingCalls.push({
              id: part.toolCallId,
              type: 'function' as const,
              function: {
                name: part.toolName,
                arguments: JSON.stringify(part.input),
              },
            });
          }
        }

        // Text is dropped whenever the message carries tool calls.
        const hasToolCalls = pendingCalls.length > 0;

        messages.push({
          role: 'assistant',
          content: hasToolCalls ? undefined : assistantText,
          tool_calls: hasToolCalls ? pendingCalls : undefined,
          tool_plan: undefined,
        });

        break;
      }

      case 'tool': {
        for (const toolPart of content) {
          if (toolPart.type === 'tool-approval-response') {
            continue;
          }

          const { output } = toolPart;
          let serialized: string;

          switch (output.type) {
            case 'execution-denied':
              serialized = output.reason ?? 'Tool execution denied.';
              break;
            case 'text':
            case 'error-text':
              serialized = output.value;
              break;
            case 'json':
            case 'error-json':
            case 'content':
              serialized = JSON.stringify(output.value);
              break;
          }

          messages.push({
            role: 'tool' as const,
            content: serialized,
            tool_call_id: toolPart.toolCallId,
          });
        }

        break;
      }

      default: {
        const _exhaustiveCheck: never = role;
        throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
      }
    }
  }

  return { messages, documents, warnings };
}
package/src/index.ts ADDED
@@ -0,0 +1,5 @@
1
+ export type { CohereChatModelOptions } from './cohere-chat-options';
2
+ export { cohere, createCohere } from './cohere-provider';
3
+ export type { CohereProvider, CohereProviderSettings } from './cohere-provider';
4
+ export type { CohereRerankingOptions } from './reranking/cohere-reranking-options';
5
+ export { VERSION } from './version';
@@ -0,0 +1,23 @@
1
+ import { LanguageModelV3FinishReason } from '@ai-sdk/provider';
2
+
/**
 * Maps a Cohere finish reason string onto the unified finish reason.
 *
 * `COMPLETE` and `STOP_SEQUENCE` map to `stop`, `MAX_TOKENS` to `length`,
 * `ERROR` to `error`, and `TOOL_CALL` to `tool-calls`. Any other value,
 * including `null` and `undefined`, maps to `other`. Matching is exact and
 * case-sensitive.
 *
 * @param finishReason - The finish reason to map.
 * @returns The unified finish reason.
 */
export function mapCohereFinishReason(
  finishReason: string | null | undefined,
): LanguageModelV3FinishReason['unified'] {
  if (finishReason === 'COMPLETE' || finishReason === 'STOP_SEQUENCE') {
    return 'stop';
  }

  if (finishReason === 'MAX_TOKENS') {
    return 'length';
  }

  if (finishReason === 'ERROR') {
    return 'error';
  }

  if (finishReason === 'TOOL_CALL') {
    return 'tool-calls';
  }

  return 'other';
}
@@ -0,0 +1,21 @@
1
+ {
2
+ "id": "b44fe75b-e3d3-489a-b61e-1a1aede3ef72",
3
+ "results": [
4
+ {
5
+ "index": 1,
6
+ "relevance_score": 0.10183054
7
+ },
8
+ {
9
+ "index": 0,
10
+ "relevance_score": 0.03762639
11
+ }
12
+ ],
13
+ "meta": {
14
+ "api_version": {
15
+ "version": "2"
16
+ },
17
+ "billed_units": {
18
+ "search_units": 1
19
+ }
20
+ }
21
+ }
@@ -0,0 +1,27 @@
1
+ import { lazySchema, zodSchema } from '@ai-sdk/provider-utils';
2
+ import { z } from 'zod/v4';
3
+
/**
 * Input payload shape for Cohere reranking.
 *
 * The last three properties must always be present but may be `undefined`.
 *
 * @see https://docs.cohere.com/v2/reference/rerank
 */
export type CohereRerankingInput = {
  model: string;
  query: string;
  documents: Array<string>;
  top_n: number | undefined;
  max_tokens_per_doc: number | undefined;
  priority: number | undefined;
};
13
+
/**
 * Lazily constructed schema for a Cohere rerank response: an optional
 * (nullish) string `id`, a `results` array of `index` / `relevance_score`
 * number pairs, and an unvalidated `meta` value.
 */
export const cohereRerankingResponseSchema = lazySchema(() => {
  const rankedResult = z.object({
    index: z.number(),
    relevance_score: z.number(),
  });

  const rerankResponse = z.object({
    id: z.string().nullish(),
    results: z.array(rankedResult),
    // `meta` is accepted as-is, without validation.
    meta: z.any(),
  });

  return zodSchema(rerankResponse);
});