sambanova 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +117 -0
- package/dist/client.d.ts +18 -0
- package/dist/client.js +127 -0
- package/dist/tests/client.test.d.ts +1 -0
- package/dist/tests/client.test.js +112 -0
- package/dist/types.d.ts +39 -0
- package/dist/types.js +13 -0
- package/dist/utils.d.ts +4 -0
- package/dist/utils.js +21 -0
- package/jest.config.js +13 -0
- package/package.json +36 -0
- package/src/client.ts +190 -0
- package/src/tests/client.test.ts +135 -0
- package/src/types.ts +56 -0
- package/src/utils.ts +25 -0
- package/tsconfig.json +14 -0
package/README.md
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
# Sambanova JavaScript/TypeScript Client
|
|
2
|
+
|
|
3
|
+
A JavaScript/TypeScript client for the Sambanova AI API. This package provides an easy way to interact with Sambanova's language and vision models.
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
npm install sambanova
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
## Quick Start
|
|
12
|
+
|
|
13
|
+
```javascript
|
|
14
|
+
import { SambanovaClient } from 'sambanova';
|
|
15
|
+
|
|
16
|
+
// Initialize the client
|
|
17
|
+
const client = new SambanovaClient('YOUR_API_KEY');
|
|
18
|
+
|
|
19
|
+
// Text completion
|
|
20
|
+
async function textExample() {
|
|
21
|
+
const response = await client.chat([
|
|
22
|
+
{ role: 'user', content: 'Hello!' }
|
|
23
|
+
], {
|
|
24
|
+
model: 'Meta-Llama-3.2-3B-Instruct'
|
|
25
|
+
});
|
|
26
|
+
console.log(response.choices[0].message.content);
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
// Vision analysis
|
|
30
|
+
async function visionExample() {
|
|
31
|
+
const response = await client.chat([
|
|
32
|
+
{
|
|
33
|
+
role: 'user',
|
|
34
|
+
content: [
|
|
35
|
+
{ type: 'text', text: 'What is in this image?' },
|
|
36
|
+
{ type: 'image_url', image_url: { url: 'your_image_url_here' }}
|
|
37
|
+
]
|
|
38
|
+
}
|
|
39
|
+
], {
|
|
40
|
+
model: 'Llama-3.2-11B-Vision-Instruct'
|
|
41
|
+
});
|
|
42
|
+
console.log(response.choices[0].message.content);
|
|
43
|
+
}
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
## Features
|
|
47
|
+
|
|
48
|
+
- Support for all Sambanova language and vision models
|
|
49
|
+
- TypeScript support with full type definitions
|
|
50
|
+
- Error handling and automatic retries
|
|
51
|
+
- Streaming support
|
|
52
|
+
- Vision model integration
|
|
53
|
+
|
|
54
|
+
## Supported Models
|
|
55
|
+
|
|
56
|
+
### Vision Models
|
|
57
|
+
- `Llama-3.2-11B-Vision-Instruct`
|
|
58
|
+
- `Llama-3.2-90B-Vision-Instruct`
|
|
59
|
+
|
|
60
|
+
### Language Models
|
|
61
|
+
- `Meta-Llama-3.1-8B-Instruct`
|
|
62
|
+
- `Meta-Llama-3.1-70B-Instruct`
|
|
63
|
+
- `Meta-Llama-3.1-405B-Instruct`
|
|
64
|
+
- `Meta-Llama-3.2-1B-Instruct`
|
|
65
|
+
- `Meta-Llama-3.2-3B-Instruct`
|
|
66
|
+
|
|
67
|
+
## Advanced Usage
|
|
68
|
+
|
|
69
|
+
### Streaming
|
|
70
|
+
|
|
71
|
+
```javascript
|
|
72
|
+
(async () => {
|
|
73
|
+
try {
|
|
74
|
+
for await (const chunk of client.streamChat([
|
|
75
|
+
{ role: 'user', content: 'Tell me a story' }
|
|
76
|
+
])) {
|
|
77
|
+
process.stdout.write(chunk.choices[0].message.content);
|
|
78
|
+
}
|
|
79
|
+
} catch (error) {
|
|
80
|
+
console.error('Stream Chat Error:', error);
|
|
81
|
+
}
|
|
82
|
+
})();
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
### Error Handling
|
|
86
|
+
|
|
87
|
+
```javascript
|
|
88
|
+
import { SambanovaClient, SambanovaError } from 'sambanova';
|
|
89
|
+
|
|
90
|
+
try {
|
|
91
|
+
const response = await client.chat([
|
|
92
|
+
{ role: 'user', content: 'Hello' }
|
|
93
|
+
]);
|
|
94
|
+
console.log(response.choices[0].message.content);
|
|
95
|
+
} catch (error) {
|
|
96
|
+
if (error instanceof SambanovaError) {
|
|
97
|
+
console.error(`API Error: ${error.message} (Code: ${error.code})`);
|
|
98
|
+
} else {
|
|
99
|
+
console.error('Unexpected error:', error);
|
|
100
|
+
}
|
|
101
|
+
}
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
## Configuration Options
|
|
105
|
+
|
|
106
|
+
```javascript
|
|
107
|
+
const client = new SambanovaClient('YOUR_API_KEY', {
|
|
108
|
+
baseUrl: 'https://api.sambanova.ai/v1',
|
|
109
|
+
defaultModel: 'Meta-Llama-3.2-3B-Instruct',
|
|
110
|
+
defaultRetryCount: 3,
|
|
111
|
+
defaultRetryDelay: 1000
|
|
112
|
+
});
|
|
113
|
+
```
|
|
114
|
+
|
|
115
|
+
## License
|
|
116
|
+
|
|
117
|
+
[MIT](LICENSE)
|
package/dist/client.d.ts
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import { ModelType, ChatMessage, ChatOptions, APIResponse } from './types';
|
|
2
|
+
export declare class SambanovaClient {
|
|
3
|
+
private readonly apiKey;
|
|
4
|
+
private readonly baseUrl;
|
|
5
|
+
private readonly defaultModel;
|
|
6
|
+
private readonly defaultRetryCount;
|
|
7
|
+
private readonly defaultRetryDelay;
|
|
8
|
+
constructor(apiKey: string, options?: {
|
|
9
|
+
baseUrl?: string;
|
|
10
|
+
defaultModel?: ModelType;
|
|
11
|
+
defaultRetryCount?: number;
|
|
12
|
+
defaultRetryDelay?: number;
|
|
13
|
+
});
|
|
14
|
+
private makeRequest;
|
|
15
|
+
chat(messages: ChatMessage[], options?: ChatOptions): Promise<APIResponse>;
|
|
16
|
+
streamChat(messages: ChatMessage[], options?: ChatOptions): AsyncGenerator<APIResponse, void, unknown>;
|
|
17
|
+
private handleStreamResponse;
|
|
18
|
+
}
|
package/dist/client.js
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.SambanovaClient = void 0;
|
|
4
|
+
const utils_1 = require("./utils");
|
|
5
|
+
const types_1 = require("./types");
|
|
6
|
+
class SambanovaClient {
|
|
7
|
+
constructor(apiKey, options = {}) {
|
|
8
|
+
this.apiKey = apiKey;
|
|
9
|
+
this.baseUrl = options.baseUrl || 'https://api.sambanova.ai/v1';
|
|
10
|
+
this.defaultModel = options.defaultModel || 'Meta-Llama-3.2-3B-Instruct';
|
|
11
|
+
this.defaultRetryCount = options.defaultRetryCount || 3;
|
|
12
|
+
this.defaultRetryDelay = options.defaultRetryDelay || 1000;
|
|
13
|
+
}
|
|
14
|
+
async makeRequest(endpoint, data, retryCount = this.defaultRetryCount, stream = false) {
|
|
15
|
+
let lastError = null;
|
|
16
|
+
for (let attempt = 0; attempt <= retryCount; attempt++) {
|
|
17
|
+
try {
|
|
18
|
+
const response = await fetch(`${this.baseUrl}${endpoint}`, {
|
|
19
|
+
method: 'POST',
|
|
20
|
+
headers: {
|
|
21
|
+
'Content-Type': 'application/json',
|
|
22
|
+
'Authorization': `Bearer ${this.apiKey}`
|
|
23
|
+
},
|
|
24
|
+
body: JSON.stringify(data)
|
|
25
|
+
});
|
|
26
|
+
if (!response.ok) {
|
|
27
|
+
const errorData = await response.json();
|
|
28
|
+
throw new types_1.SambanovaError(errorData.message || 'API request failed', response.status, errorData.code, errorData);
|
|
29
|
+
}
|
|
30
|
+
if (stream) {
|
|
31
|
+
return response;
|
|
32
|
+
}
|
|
33
|
+
return await response.json();
|
|
34
|
+
}
|
|
35
|
+
catch (error) {
|
|
36
|
+
lastError = error;
|
|
37
|
+
if (attempt < retryCount) {
|
|
38
|
+
await (0, utils_1.sleep)(this.defaultRetryDelay * Math.pow(2, attempt));
|
|
39
|
+
continue;
|
|
40
|
+
}
|
|
41
|
+
}
|
|
42
|
+
}
|
|
43
|
+
throw lastError || new Error('Request failed after retries');
|
|
44
|
+
}
|
|
45
|
+
async chat(messages, options = {}) {
|
|
46
|
+
var _a, _b, _c;
|
|
47
|
+
const model = options.model || this.defaultModel;
|
|
48
|
+
messages.forEach(msg => (0, utils_1.validateMessage)(msg, (0, utils_1.isVisionModel)(model)));
|
|
49
|
+
const payload = {
|
|
50
|
+
model,
|
|
51
|
+
messages,
|
|
52
|
+
temperature: (_a = options.temperature) !== null && _a !== void 0 ? _a : 0.1,
|
|
53
|
+
top_p: (_b = options.top_p) !== null && _b !== void 0 ? _b : 0.1,
|
|
54
|
+
max_tokens: options.max_tokens,
|
|
55
|
+
stream: (_c = options.stream) !== null && _c !== void 0 ? _c : false
|
|
56
|
+
};
|
|
57
|
+
const response = await this.makeRequest('/chat/completions', payload, options.retry_count, payload.stream);
|
|
58
|
+
if (payload.stream && response instanceof Response) {
|
|
59
|
+
throw new Error('Stream response received in chat method. Use streamChat instead.');
|
|
60
|
+
}
|
|
61
|
+
return response;
|
|
62
|
+
}
|
|
63
|
+
async *streamChat(messages, options = {}) {
|
|
64
|
+
var _a, _b;
|
|
65
|
+
const model = options.model || this.defaultModel;
|
|
66
|
+
messages.forEach(msg => (0, utils_1.validateMessage)(msg, (0, utils_1.isVisionModel)(model)));
|
|
67
|
+
const payload = {
|
|
68
|
+
model,
|
|
69
|
+
messages,
|
|
70
|
+
temperature: (_a = options.temperature) !== null && _a !== void 0 ? _a : 0.1,
|
|
71
|
+
top_p: (_b = options.top_p) !== null && _b !== void 0 ? _b : 0.1,
|
|
72
|
+
max_tokens: options.max_tokens,
|
|
73
|
+
stream: true
|
|
74
|
+
};
|
|
75
|
+
const response = await this.makeRequest('/chat/completions', payload, options.retry_count, true);
|
|
76
|
+
// **Modified Check:**
|
|
77
|
+
if (!response ||
|
|
78
|
+
!('body' in response) ||
|
|
79
|
+
typeof response.body.getReader !== 'function') {
|
|
80
|
+
throw new Error('Expected a streaming response');
|
|
81
|
+
}
|
|
82
|
+
yield* this.handleStreamResponse(response);
|
|
83
|
+
}
|
|
84
|
+
async *handleStreamResponse(response) {
|
|
85
|
+
const reader = response.body.getReader();
|
|
86
|
+
const decoder = new TextDecoder();
|
|
87
|
+
let buffer = '';
|
|
88
|
+
while (true) {
|
|
89
|
+
const { done, value } = await reader.read();
|
|
90
|
+
if (done)
|
|
91
|
+
break;
|
|
92
|
+
buffer += decoder.decode(value, { stream: true });
|
|
93
|
+
let lines = buffer.split('\n');
|
|
94
|
+
buffer = lines.pop() || '';
|
|
95
|
+
for (const line of lines) {
|
|
96
|
+
const trimmedLine = line.trim();
|
|
97
|
+
if (trimmedLine.startsWith('data: ')) {
|
|
98
|
+
const dataStr = trimmedLine.slice(6);
|
|
99
|
+
if (dataStr === '[DONE]') {
|
|
100
|
+
return;
|
|
101
|
+
}
|
|
102
|
+
try {
|
|
103
|
+
const data = JSON.parse(dataStr);
|
|
104
|
+
yield data;
|
|
105
|
+
}
|
|
106
|
+
catch (e) {
|
|
107
|
+
console.error('Failed to parse stream data:', e);
|
|
108
|
+
}
|
|
109
|
+
}
|
|
110
|
+
}
|
|
111
|
+
}
|
|
112
|
+
// Processing any remaining buffer
|
|
113
|
+
if (buffer.startsWith('data: ')) {
|
|
114
|
+
const dataStr = buffer.slice(6);
|
|
115
|
+
if (dataStr !== '[DONE]') {
|
|
116
|
+
try {
|
|
117
|
+
const data = JSON.parse(dataStr);
|
|
118
|
+
yield data;
|
|
119
|
+
}
|
|
120
|
+
catch (e) {
|
|
121
|
+
console.error('Failed to parse stream data:', e);
|
|
122
|
+
}
|
|
123
|
+
}
|
|
124
|
+
}
|
|
125
|
+
}
|
|
126
|
+
}
|
|
127
|
+
exports.SambanovaClient = SambanovaClient;
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
export {};
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
const client_1 = require("../client");
|
|
4
|
+
const types_1 = require("../types");
|
|
5
|
+
describe('SambanovaClient', () => {
|
|
6
|
+
const client = new client_1.SambanovaClient('test-api-key');
|
|
7
|
+
beforeEach(() => {
|
|
8
|
+
global.fetch = jest.fn();
|
|
9
|
+
});
|
|
10
|
+
afterEach(() => {
|
|
11
|
+
jest.resetAllMocks();
|
|
12
|
+
});
|
|
13
|
+
describe('chat', () => {
|
|
14
|
+
it('should successfully make a text chat request', async () => {
|
|
15
|
+
const mockResponse = {
|
|
16
|
+
id: 'test-id',
|
|
17
|
+
choices: [{
|
|
18
|
+
message: { role: 'assistant', content: 'Hello!' },
|
|
19
|
+
finish_reason: 'stop'
|
|
20
|
+
}],
|
|
21
|
+
usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }
|
|
22
|
+
};
|
|
23
|
+
global.fetch.mockResolvedValueOnce({
|
|
24
|
+
ok: true,
|
|
25
|
+
json: () => Promise.resolve(mockResponse)
|
|
26
|
+
});
|
|
27
|
+
const result = await client.chat([
|
|
28
|
+
{ role: 'user', content: 'Hi!' }
|
|
29
|
+
]);
|
|
30
|
+
expect(result).toEqual(mockResponse);
|
|
31
|
+
});
|
|
32
|
+
it('should handle vision model requests correctly', async () => {
|
|
33
|
+
const mockResponse = {
|
|
34
|
+
id: 'test-id',
|
|
35
|
+
choices: [{
|
|
36
|
+
message: { role: 'assistant', content: 'I see a cat in the image.' },
|
|
37
|
+
finish_reason: 'stop'
|
|
38
|
+
}],
|
|
39
|
+
usage: { prompt_tokens: 20, completion_tokens: 8, total_tokens: 28 }
|
|
40
|
+
};
|
|
41
|
+
global.fetch.mockResolvedValueOnce({
|
|
42
|
+
ok: true,
|
|
43
|
+
json: () => Promise.resolve(mockResponse)
|
|
44
|
+
});
|
|
45
|
+
const result = await client.chat([
|
|
46
|
+
{
|
|
47
|
+
role: 'user',
|
|
48
|
+
content: [
|
|
49
|
+
{ type: 'text', text: 'What do you see?' },
|
|
50
|
+
{ type: 'image_url', image_url: { url: 'base64...' } }
|
|
51
|
+
]
|
|
52
|
+
}
|
|
53
|
+
], { model: 'Llama-3.2-11B-Vision-Instruct' });
|
|
54
|
+
expect(result).toEqual(mockResponse);
|
|
55
|
+
});
|
|
56
|
+
it('should retry on failure', async () => {
|
|
57
|
+
global.fetch
|
|
58
|
+
.mockRejectedValueOnce(new Error('Network error'))
|
|
59
|
+
.mockRejectedValueOnce(new Error('Network error'))
|
|
60
|
+
.mockResolvedValueOnce({
|
|
61
|
+
ok: true,
|
|
62
|
+
json: () => Promise.resolve({ id: 'test-id' })
|
|
63
|
+
});
|
|
64
|
+
const result = await client.chat([
|
|
65
|
+
{ role: 'user', content: 'Hi!' }
|
|
66
|
+
], { retry_count: 3 });
|
|
67
|
+
expect(global.fetch).toHaveBeenCalledTimes(3);
|
|
68
|
+
expect(result).toEqual({ id: 'test-id' });
|
|
69
|
+
});
|
|
70
|
+
it('should throw SambanovaError on API error', async () => {
|
|
71
|
+
global.fetch.mockResolvedValueOnce({
|
|
72
|
+
ok: false,
|
|
73
|
+
status: 400,
|
|
74
|
+
json: () => Promise.resolve({
|
|
75
|
+
message: 'Invalid request',
|
|
76
|
+
code: 'INVALID_REQUEST'
|
|
77
|
+
})
|
|
78
|
+
});
|
|
79
|
+
await expect(client.chat([{ role: 'user', content: 'Hi!' }], { retry_count: 0 } // this is 0 to prevent retries
|
|
80
|
+
)).rejects.toThrow(types_1.SambanovaError);
|
|
81
|
+
});
|
|
82
|
+
});
|
|
83
|
+
describe('streamChat', () => {
|
|
84
|
+
it('should handle streaming responses', async () => {
|
|
85
|
+
const mockReader = {
|
|
86
|
+
read: jest.fn()
|
|
87
|
+
.mockResolvedValueOnce({
|
|
88
|
+
done: false,
|
|
89
|
+
value: new TextEncoder().encode('data: {"id":"1"}\n')
|
|
90
|
+
})
|
|
91
|
+
.mockResolvedValueOnce({
|
|
92
|
+
done: false,
|
|
93
|
+
value: new TextEncoder().encode('data: {"id":"2"}\n')
|
|
94
|
+
})
|
|
95
|
+
.mockResolvedValueOnce({ done: true })
|
|
96
|
+
};
|
|
97
|
+
global.fetch.mockResolvedValueOnce({
|
|
98
|
+
ok: true,
|
|
99
|
+
body: { getReader: () => mockReader }
|
|
100
|
+
});
|
|
101
|
+
const responses = [];
|
|
102
|
+
for await (const response of client.streamChat([
|
|
103
|
+
{ role: 'user', content: 'Hi!' }
|
|
104
|
+
])) {
|
|
105
|
+
responses.push(response);
|
|
106
|
+
}
|
|
107
|
+
expect(responses).toHaveLength(2);
|
|
108
|
+
expect(responses[0]).toEqual({ id: '1' });
|
|
109
|
+
expect(responses[1]).toEqual({ id: '2' });
|
|
110
|
+
});
|
|
111
|
+
});
|
|
112
|
+
});
|
package/dist/types.d.ts
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
export type ModelType = 'Llama-3.2-11B-Vision-Instruct' | 'Meta-Llama-3.1-8B-Instruct' | 'Meta-Llama-3.1-70B-Instruct' | 'Meta-Llama-3.1-405B-Instruct' | 'Meta-Llama-3.2-1B-Instruct' | 'Meta-Llama-3.2-3B-Instruct' | 'Llama-3.2-90B-Vision-Instruct';
|
|
2
|
+
export interface ChatMessage {
|
|
3
|
+
role: 'system' | 'user' | 'assistant';
|
|
4
|
+
content: string | MessageContent[];
|
|
5
|
+
}
|
|
6
|
+
export interface MessageContent {
|
|
7
|
+
type: 'text' | 'image_url';
|
|
8
|
+
text?: string;
|
|
9
|
+
image_url?: {
|
|
10
|
+
url: string;
|
|
11
|
+
};
|
|
12
|
+
}
|
|
13
|
+
export interface ChatOptions {
|
|
14
|
+
model?: ModelType;
|
|
15
|
+
temperature?: number;
|
|
16
|
+
top_p?: number;
|
|
17
|
+
max_tokens?: number;
|
|
18
|
+
stream?: boolean;
|
|
19
|
+
retry_count?: number;
|
|
20
|
+
retry_delay?: number;
|
|
21
|
+
}
|
|
22
|
+
export interface APIResponse {
|
|
23
|
+
id: string;
|
|
24
|
+
choices: Array<{
|
|
25
|
+
message: ChatMessage;
|
|
26
|
+
finish_reason: string;
|
|
27
|
+
}>;
|
|
28
|
+
usage: {
|
|
29
|
+
prompt_tokens: number;
|
|
30
|
+
completion_tokens: number;
|
|
31
|
+
total_tokens: number;
|
|
32
|
+
};
|
|
33
|
+
}
|
|
34
|
+
export declare class SambanovaError extends Error {
|
|
35
|
+
status?: number | undefined;
|
|
36
|
+
code?: string | undefined;
|
|
37
|
+
details?: any;
|
|
38
|
+
constructor(message: string, status?: number | undefined, code?: string | undefined, details?: any);
|
|
39
|
+
}
|
package/dist/types.js
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.SambanovaError = void 0;
|
|
4
|
+
/**
 * Error raised for Sambanova API failures and client-side validation errors.
 * Carries the HTTP status, a machine-readable code and the raw error payload.
 */
class SambanovaError extends Error {
    constructor(message, status, code, details) {
        super(message);
        // Name first so stack traces read "SambanovaError: ...".
        this.name = 'SambanovaError';
        this.status = status;
        this.code = code;
        this.details = details;
    }
}
|
|
13
|
+
exports.SambanovaError = SambanovaError;
|
package/dist/utils.d.ts
ADDED
package/dist/utils.js
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.validateMessage = exports.isVisionModel = exports.sleep = void 0;
|
|
4
|
+
const types_1 = require("./types");
|
|
5
|
+
const sleep = (ms) => new Promise(resolve => setTimeout(resolve, ms));
|
|
6
|
+
exports.sleep = sleep;
|
|
7
|
+
const isVisionModel = (model) => {
|
|
8
|
+
return model.toLowerCase().includes('vision');
|
|
9
|
+
};
|
|
10
|
+
exports.isVisionModel = isVisionModel;
|
|
11
|
+
const validateMessage = (message, isVision) => {
|
|
12
|
+
if (Array.isArray(message.content)) {
|
|
13
|
+
if (!isVision) {
|
|
14
|
+
throw new types_1.SambanovaError('Array content is only supported for vision models', 400, 'INVALID_MESSAGE_FORMAT');
|
|
15
|
+
}
|
|
16
|
+
}
|
|
17
|
+
else if (isVision) {
|
|
18
|
+
throw new types_1.SambanovaError('Vision models require array content format', 400, 'INVALID_MESSAGE_FORMAT');
|
|
19
|
+
}
|
|
20
|
+
};
|
|
21
|
+
exports.validateMessage = validateMessage;
|
package/jest.config.js
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
module.exports = {
|
|
2
|
+
preset: 'ts-jest',
|
|
3
|
+
testEnvironment: 'node',
|
|
4
|
+
testMatch: [
|
|
5
|
+
'**/tests/**/*.test.ts',
|
|
6
|
+
'**/__tests__/**/*.test.ts'
|
|
7
|
+
],
|
|
8
|
+
moduleFileExtensions: ['ts', 'js', 'json', 'node'],
|
|
9
|
+
rootDir: './',
|
|
10
|
+
transform: {
|
|
11
|
+
'^.+\\.ts?$': 'ts-jest'
|
|
12
|
+
}
|
|
13
|
+
};
|
package/package.json
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
{
|
|
2
|
+
"name": "sambanova",
|
|
3
|
+
"version": "1.0.0",
|
|
4
|
+
"description": "TypeScript/Javascript client for Sambanova AI API with comprehensive model support",
|
|
5
|
+
"main": "dist/client.js",
|
|
6
|
+
"types": "dist/client.d.ts",
|
|
7
|
+
"scripts": {
|
|
8
|
+
"build": "tsc",
|
|
9
|
+
"test": "jest",
|
|
10
|
+
"lint": "eslint src/**/*.ts",
|
|
11
|
+
"prepare": "npm run build"
|
|
12
|
+
},
|
|
13
|
+
"keywords": [
|
|
14
|
+
"sambanova",
|
|
15
|
+
"ai",
|
|
16
|
+
"llm",
|
|
17
|
+
"vision",
|
|
18
|
+
"api",
|
|
19
|
+
"typescript"
|
|
20
|
+
],
|
|
21
|
+
"author": "Aaditya Srivastava",
|
|
22
|
+
"license": "MIT",
|
|
23
|
+
"dependencies": {
|
|
24
|
+
"node-fetch": "^2.7.0"
|
|
25
|
+
},
|
|
26
|
+
"devDependencies": {
|
|
27
|
+
"@types/jest": "^29.5.14",
|
|
28
|
+
"@types/node": "^16.18.119",
|
|
29
|
+
"@typescript-eslint/eslint-plugin": "^5.62.0",
|
|
30
|
+
"@typescript-eslint/parser": "^5.62.0",
|
|
31
|
+
"eslint": "^8.0.0",
|
|
32
|
+
"jest": "^29.7.0",
|
|
33
|
+
"ts-jest": "^29.2.5",
|
|
34
|
+
"typescript": "^4.9.5"
|
|
35
|
+
}
|
|
36
|
+
}
|
package/src/client.ts
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
import { sleep, isVisionModel, validateMessage } from './utils';
|
|
2
|
+
import {
|
|
3
|
+
ModelType,
|
|
4
|
+
ChatMessage,
|
|
5
|
+
ChatOptions,
|
|
6
|
+
APIResponse,
|
|
7
|
+
SambanovaError
|
|
8
|
+
} from './types';
|
|
9
|
+
|
|
10
|
+
export class SambanovaClient {
|
|
11
|
+
private readonly baseUrl: string;
|
|
12
|
+
private readonly defaultModel: ModelType;
|
|
13
|
+
private readonly defaultRetryCount: number;
|
|
14
|
+
private readonly defaultRetryDelay: number;
|
|
15
|
+
|
|
16
|
+
constructor(
|
|
17
|
+
private readonly apiKey: string,
|
|
18
|
+
options: {
|
|
19
|
+
baseUrl?: string;
|
|
20
|
+
defaultModel?: ModelType;
|
|
21
|
+
defaultRetryCount?: number;
|
|
22
|
+
defaultRetryDelay?: number;
|
|
23
|
+
} = {}
|
|
24
|
+
) {
|
|
25
|
+
this.baseUrl = options.baseUrl || 'https://api.sambanova.ai/v1';
|
|
26
|
+
this.defaultModel = options.defaultModel || 'Meta-Llama-3.2-3B-Instruct';
|
|
27
|
+
this.defaultRetryCount = options.defaultRetryCount || 3;
|
|
28
|
+
this.defaultRetryDelay = options.defaultRetryDelay || 1000;
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
private async makeRequest(
|
|
32
|
+
endpoint: string,
|
|
33
|
+
data: any,
|
|
34
|
+
retryCount: number = this.defaultRetryCount,
|
|
35
|
+
stream: boolean = false
|
|
36
|
+
): Promise<APIResponse | Response> {
|
|
37
|
+
let lastError: Error | null = null;
|
|
38
|
+
|
|
39
|
+
for (let attempt = 0; attempt <= retryCount; attempt++) {
|
|
40
|
+
try {
|
|
41
|
+
const response = await fetch(`${this.baseUrl}${endpoint}`, {
|
|
42
|
+
method: 'POST',
|
|
43
|
+
headers: {
|
|
44
|
+
'Content-Type': 'application/json',
|
|
45
|
+
'Authorization': `Bearer ${this.apiKey}`
|
|
46
|
+
},
|
|
47
|
+
body: JSON.stringify(data)
|
|
48
|
+
});
|
|
49
|
+
|
|
50
|
+
if (!response.ok) {
|
|
51
|
+
const errorData = await response.json();
|
|
52
|
+
throw new SambanovaError(
|
|
53
|
+
errorData.message || 'API request failed',
|
|
54
|
+
response.status,
|
|
55
|
+
errorData.code,
|
|
56
|
+
errorData
|
|
57
|
+
);
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
if (stream) {
|
|
61
|
+
return response;
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
return await response.json();
|
|
65
|
+
} catch (error) {
|
|
66
|
+
lastError = error as Error;
|
|
67
|
+
if (attempt < retryCount) {
|
|
68
|
+
await sleep(this.defaultRetryDelay * Math.pow(2, attempt));
|
|
69
|
+
continue;
|
|
70
|
+
}
|
|
71
|
+
}
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
throw lastError || new Error('Request failed after retries');
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
async chat(
|
|
78
|
+
messages: ChatMessage[],
|
|
79
|
+
options: ChatOptions = {}
|
|
80
|
+
): Promise<APIResponse> {
|
|
81
|
+
const model = options.model || this.defaultModel;
|
|
82
|
+
|
|
83
|
+
messages.forEach(msg => validateMessage(msg, isVisionModel(model)));
|
|
84
|
+
|
|
85
|
+
const payload = {
|
|
86
|
+
model,
|
|
87
|
+
messages,
|
|
88
|
+
temperature: options.temperature ?? 0.1,
|
|
89
|
+
top_p: options.top_p ?? 0.1,
|
|
90
|
+
max_tokens: options.max_tokens,
|
|
91
|
+
stream: options.stream ?? false
|
|
92
|
+
};
|
|
93
|
+
|
|
94
|
+
const response = await this.makeRequest(
|
|
95
|
+
'/chat/completions',
|
|
96
|
+
payload,
|
|
97
|
+
options.retry_count,
|
|
98
|
+
payload.stream
|
|
99
|
+
);
|
|
100
|
+
|
|
101
|
+
if (payload.stream && response instanceof Response) {
|
|
102
|
+
throw new Error('Stream response received in chat method. Use streamChat instead.');
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
return response as APIResponse;
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
async *streamChat(
|
|
109
|
+
messages: ChatMessage[],
|
|
110
|
+
options: ChatOptions = {}
|
|
111
|
+
): AsyncGenerator<APIResponse, void, unknown> {
|
|
112
|
+
const model = options.model || this.defaultModel;
|
|
113
|
+
|
|
114
|
+
messages.forEach(msg => validateMessage(msg, isVisionModel(model)));
|
|
115
|
+
|
|
116
|
+
const payload = {
|
|
117
|
+
model,
|
|
118
|
+
messages,
|
|
119
|
+
temperature: options.temperature ?? 0.1,
|
|
120
|
+
top_p: options.top_p ?? 0.1,
|
|
121
|
+
max_tokens: options.max_tokens,
|
|
122
|
+
stream: true
|
|
123
|
+
};
|
|
124
|
+
|
|
125
|
+
const response = await this.makeRequest(
|
|
126
|
+
'/chat/completions',
|
|
127
|
+
payload,
|
|
128
|
+
options.retry_count,
|
|
129
|
+
true
|
|
130
|
+
);
|
|
131
|
+
|
|
132
|
+
// **Modified Check:**
|
|
133
|
+
if (
|
|
134
|
+
!response ||
|
|
135
|
+
!('body' in response) ||
|
|
136
|
+
typeof (response as any).body.getReader !== 'function'
|
|
137
|
+
) {
|
|
138
|
+
throw new Error('Expected a streaming response');
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
yield* this.handleStreamResponse(response as Response);
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
private async *handleStreamResponse(
|
|
145
|
+
response: Response
|
|
146
|
+
): AsyncGenerator<APIResponse, void, unknown> {
|
|
147
|
+
const reader = response.body!.getReader();
|
|
148
|
+
const decoder = new TextDecoder();
|
|
149
|
+
let buffer = '';
|
|
150
|
+
|
|
151
|
+
while (true) {
|
|
152
|
+
const { done, value } = await reader.read();
|
|
153
|
+
if (done) break;
|
|
154
|
+
|
|
155
|
+
buffer += decoder.decode(value, { stream: true });
|
|
156
|
+
let lines = buffer.split('\n');
|
|
157
|
+
|
|
158
|
+
buffer = lines.pop() || '';
|
|
159
|
+
|
|
160
|
+
for (const line of lines) {
|
|
161
|
+
const trimmedLine = line.trim();
|
|
162
|
+
if (trimmedLine.startsWith('data: ')) {
|
|
163
|
+
const dataStr = trimmedLine.slice(6);
|
|
164
|
+
if (dataStr === '[DONE]') {
|
|
165
|
+
return;
|
|
166
|
+
}
|
|
167
|
+
try {
|
|
168
|
+
const data = JSON.parse(dataStr);
|
|
169
|
+
yield data;
|
|
170
|
+
} catch (e) {
|
|
171
|
+
console.error('Failed to parse stream data:', e);
|
|
172
|
+
}
|
|
173
|
+
}
|
|
174
|
+
}
|
|
175
|
+
}
|
|
176
|
+
|
|
177
|
+
// Processing any remaining buffer
|
|
178
|
+
if (buffer.startsWith('data: ')) {
|
|
179
|
+
const dataStr = buffer.slice(6);
|
|
180
|
+
if (dataStr !== '[DONE]') {
|
|
181
|
+
try {
|
|
182
|
+
const data = JSON.parse(dataStr);
|
|
183
|
+
yield data;
|
|
184
|
+
} catch (e) {
|
|
185
|
+
console.error('Failed to parse stream data:', e);
|
|
186
|
+
}
|
|
187
|
+
}
|
|
188
|
+
}
|
|
189
|
+
}
|
|
190
|
+
}
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import { SambanovaClient } from '../client';
|
|
2
|
+
import { SambanovaError } from '../types';
|
|
3
|
+
|
|
4
|
+
describe('SambanovaClient', () => {
|
|
5
|
+
const client = new SambanovaClient('test-api-key');
|
|
6
|
+
|
|
7
|
+
beforeEach(() => {
|
|
8
|
+
global.fetch = jest.fn();
|
|
9
|
+
});
|
|
10
|
+
|
|
11
|
+
afterEach(() => {
|
|
12
|
+
jest.resetAllMocks();
|
|
13
|
+
});
|
|
14
|
+
|
|
15
|
+
describe('chat', () => {
|
|
16
|
+
it('should successfully make a text chat request', async () => {
|
|
17
|
+
const mockResponse = {
|
|
18
|
+
id: 'test-id',
|
|
19
|
+
choices: [{
|
|
20
|
+
message: { role: 'assistant', content: 'Hello!' },
|
|
21
|
+
finish_reason: 'stop'
|
|
22
|
+
}],
|
|
23
|
+
usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }
|
|
24
|
+
};
|
|
25
|
+
|
|
26
|
+
(global.fetch as jest.Mock).mockResolvedValueOnce({
|
|
27
|
+
ok: true,
|
|
28
|
+
json: () => Promise.resolve(mockResponse)
|
|
29
|
+
});
|
|
30
|
+
|
|
31
|
+
const result = await client.chat([
|
|
32
|
+
{ role: 'user', content: 'Hi!' }
|
|
33
|
+
]);
|
|
34
|
+
|
|
35
|
+
expect(result).toEqual(mockResponse);
|
|
36
|
+
});
|
|
37
|
+
|
|
38
|
+
it('should handle vision model requests correctly', async () => {
|
|
39
|
+
const mockResponse = {
|
|
40
|
+
id: 'test-id',
|
|
41
|
+
choices: [{
|
|
42
|
+
message: { role: 'assistant', content: 'I see a cat in the image.' },
|
|
43
|
+
finish_reason: 'stop'
|
|
44
|
+
}],
|
|
45
|
+
usage: { prompt_tokens: 20, completion_tokens: 8, total_tokens: 28 }
|
|
46
|
+
};
|
|
47
|
+
|
|
48
|
+
(global.fetch as jest.Mock).mockResolvedValueOnce({
|
|
49
|
+
ok: true,
|
|
50
|
+
json: () => Promise.resolve(mockResponse)
|
|
51
|
+
});
|
|
52
|
+
|
|
53
|
+
const result = await client.chat([
|
|
54
|
+
{
|
|
55
|
+
role: 'user',
|
|
56
|
+
content: [
|
|
57
|
+
{ type: 'text', text: 'What do you see?' },
|
|
58
|
+
{ type: 'image_url', image_url: { url: 'base64...' } }
|
|
59
|
+
]
|
|
60
|
+
}
|
|
61
|
+
], { model: 'Llama-3.2-11B-Vision-Instruct' });
|
|
62
|
+
|
|
63
|
+
expect(result).toEqual(mockResponse);
|
|
64
|
+
});
|
|
65
|
+
|
|
66
|
+
it('should retry on failure', async () => {
|
|
67
|
+
(global.fetch as jest.Mock)
|
|
68
|
+
.mockRejectedValueOnce(new Error('Network error'))
|
|
69
|
+
.mockRejectedValueOnce(new Error('Network error'))
|
|
70
|
+
.mockResolvedValueOnce({
|
|
71
|
+
ok: true,
|
|
72
|
+
json: () => Promise.resolve({ id: 'test-id' })
|
|
73
|
+
});
|
|
74
|
+
|
|
75
|
+
const result = await client.chat([
|
|
76
|
+
{ role: 'user', content: 'Hi!' }
|
|
77
|
+
], { retry_count: 3 });
|
|
78
|
+
|
|
79
|
+
expect(global.fetch).toHaveBeenCalledTimes(3);
|
|
80
|
+
expect(result).toEqual({ id: 'test-id' });
|
|
81
|
+
});
|
|
82
|
+
|
|
83
|
+
it('should throw SambanovaError on API error', async () => {
|
|
84
|
+
(global.fetch as jest.Mock).mockResolvedValueOnce({
|
|
85
|
+
ok: false,
|
|
86
|
+
status: 400,
|
|
87
|
+
json: () => Promise.resolve({
|
|
88
|
+
message: 'Invalid request',
|
|
89
|
+
code: 'INVALID_REQUEST'
|
|
90
|
+
})
|
|
91
|
+
});
|
|
92
|
+
|
|
93
|
+
await expect(
|
|
94
|
+
client.chat(
|
|
95
|
+
[{ role: 'user', content: 'Hi!' }],
|
|
96
|
+
{ retry_count: 0 } // this is 0 to prevent retries
|
|
97
|
+
)
|
|
98
|
+
).rejects.toThrow(SambanovaError);
|
|
99
|
+
});
|
|
100
|
+
|
|
101
|
+
});
|
|
102
|
+
|
|
103
|
+
describe('streamChat', () => {
|
|
104
|
+
it('should handle streaming responses', async () => {
|
|
105
|
+
const mockReader = {
|
|
106
|
+
read: jest.fn()
|
|
107
|
+
.mockResolvedValueOnce({
|
|
108
|
+
done: false,
|
|
109
|
+
value: new TextEncoder().encode('data: {"id":"1"}\n')
|
|
110
|
+
})
|
|
111
|
+
.mockResolvedValueOnce({
|
|
112
|
+
done: false,
|
|
113
|
+
value: new TextEncoder().encode('data: {"id":"2"}\n')
|
|
114
|
+
})
|
|
115
|
+
.mockResolvedValueOnce({ done: true })
|
|
116
|
+
};
|
|
117
|
+
|
|
118
|
+
(global.fetch as jest.Mock).mockResolvedValueOnce({
|
|
119
|
+
ok: true,
|
|
120
|
+
body: { getReader: () => mockReader }
|
|
121
|
+
});
|
|
122
|
+
|
|
123
|
+
const responses = [];
|
|
124
|
+
for await (const response of client.streamChat([
|
|
125
|
+
{ role: 'user', content: 'Hi!' }
|
|
126
|
+
])) {
|
|
127
|
+
responses.push(response);
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
expect(responses).toHaveLength(2);
|
|
131
|
+
expect(responses[0]).toEqual({ id: '1' });
|
|
132
|
+
expect(responses[1]).toEqual({ id: '2' });
|
|
133
|
+
});
|
|
134
|
+
});
|
|
135
|
+
});
|
package/src/types.ts
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
export type ModelType =
|
|
2
|
+
| 'Llama-3.2-11B-Vision-Instruct'
|
|
3
|
+
| 'Meta-Llama-3.1-8B-Instruct'
|
|
4
|
+
| 'Meta-Llama-3.1-70B-Instruct'
|
|
5
|
+
| 'Meta-Llama-3.1-405B-Instruct'
|
|
6
|
+
| 'Meta-Llama-3.2-1B-Instruct'
|
|
7
|
+
| 'Meta-Llama-3.2-3B-Instruct'
|
|
8
|
+
| 'Llama-3.2-90B-Vision-Instruct';
|
|
9
|
+
|
|
10
|
+
export interface ChatMessage {
|
|
11
|
+
role: 'system' | 'user' | 'assistant';
|
|
12
|
+
content: string | MessageContent[];
|
|
13
|
+
}
|
|
14
|
+
|
|
15
|
+
export interface MessageContent {
|
|
16
|
+
type: 'text' | 'image_url';
|
|
17
|
+
text?: string;
|
|
18
|
+
image_url?: {
|
|
19
|
+
url: string;
|
|
20
|
+
};
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
export interface ChatOptions {
|
|
24
|
+
model?: ModelType;
|
|
25
|
+
temperature?: number;
|
|
26
|
+
top_p?: number;
|
|
27
|
+
max_tokens?: number;
|
|
28
|
+
stream?: boolean;
|
|
29
|
+
retry_count?: number;
|
|
30
|
+
retry_delay?: number;
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
export interface APIResponse {
|
|
34
|
+
id: string;
|
|
35
|
+
choices: Array<{
|
|
36
|
+
message: ChatMessage;
|
|
37
|
+
finish_reason: string;
|
|
38
|
+
}>;
|
|
39
|
+
usage: {
|
|
40
|
+
prompt_tokens: number;
|
|
41
|
+
completion_tokens: number;
|
|
42
|
+
total_tokens: number;
|
|
43
|
+
};
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
export class SambanovaError extends Error {
|
|
47
|
+
constructor(
|
|
48
|
+
message: string,
|
|
49
|
+
public status?: number,
|
|
50
|
+
public code?: string,
|
|
51
|
+
public details?: any
|
|
52
|
+
) {
|
|
53
|
+
super(message);
|
|
54
|
+
this.name = 'SambanovaError';
|
|
55
|
+
}
|
|
56
|
+
}
|
package/src/utils.ts
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import { ModelType, ChatMessage, SambanovaError } from './types';
|
|
2
|
+
|
|
3
|
+
export const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms));
|
|
4
|
+
|
|
5
|
+
export const isVisionModel = (model: ModelType): boolean => {
|
|
6
|
+
return model.toLowerCase().includes('vision');
|
|
7
|
+
};
|
|
8
|
+
|
|
9
|
+
export const validateMessage = (message: ChatMessage, isVision: boolean) => {
|
|
10
|
+
if (Array.isArray(message.content)) {
|
|
11
|
+
if (!isVision) {
|
|
12
|
+
throw new SambanovaError(
|
|
13
|
+
'Array content is only supported for vision models',
|
|
14
|
+
400,
|
|
15
|
+
'INVALID_MESSAGE_FORMAT'
|
|
16
|
+
);
|
|
17
|
+
}
|
|
18
|
+
} else if (isVision) {
|
|
19
|
+
throw new SambanovaError(
|
|
20
|
+
'Vision models require array content format',
|
|
21
|
+
400,
|
|
22
|
+
'INVALID_MESSAGE_FORMAT'
|
|
23
|
+
);
|
|
24
|
+
}
|
|
25
|
+
};
|
package/tsconfig.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
{
|
|
2
|
+
"compilerOptions": {
|
|
3
|
+
"target": "es2018",
|
|
4
|
+
"module": "commonjs",
|
|
5
|
+
"declaration": true,
|
|
6
|
+
"outDir": "./dist",
|
|
7
|
+
"strict": true,
|
|
8
|
+
"esModuleInterop": true,
|
|
9
|
+
"skipLibCheck": true,
|
|
10
|
+
"forceConsistentCasingInFileNames": true
|
|
11
|
+
},
|
|
12
|
+
"include": ["src"],
|
|
13
|
+
"exclude": ["node_modules", "dist", "**/__tests__/*"]
|
|
14
|
+
}
|