@inference-gateway/sdk 0.3.0 → 0.3.2

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
@@ -2,229 +2,338 @@
 Object.defineProperty(exports, "__esModule", { value: true });
 const client_1 = require("@/client");
 const types_1 = require("@/types");
+const web_1 = require("node:stream/web");
+const node_util_1 = require("node:util");
 describe('InferenceGatewayClient', () => {
     let client;
-    const mockBaseUrl = 'http://localhost:8080';
+    const mockFetch = jest.fn();
     beforeEach(() => {
-        client = new client_1.InferenceGatewayClient(mockBaseUrl);
-        global.fetch = jest.fn();
+        client = new client_1.InferenceGatewayClient({
+            baseURL: 'http://localhost:8080/v1',
+            fetch: mockFetch,
+        });
+    });
+    afterEach(() => {
+        jest.clearAllMocks();
     });
     describe('listModels', () => {
         it('should fetch available models', async () => {
-            const mockResponse = [
-                {
-                    provider: types_1.Provider.Ollama,
-                    models: [
-                        {
-                            name: 'llama2',
-                        },
-                    ],
-                },
-            ];
-            global.fetch.mockResolvedValueOnce({
+            const mockResponse = {
+                object: 'list',
+                data: [
+                    {
+                        id: 'gpt-4o',
+                        object: 'model',
+                        created: 1686935002,
+                        owned_by: 'openai',
+                    },
+                    {
+                        id: 'llama-3.3-70b-versatile',
+                        object: 'model',
+                        created: 1723651281,
+                        owned_by: 'groq',
+                    },
+                ],
+            };
+            mockFetch.mockResolvedValueOnce({
                 ok: true,
                 json: () => Promise.resolve(mockResponse),
             });
             const result = await client.listModels();
             expect(result).toEqual(mockResponse);
-            expect(global.fetch).toHaveBeenCalledWith(`${mockBaseUrl}/llms`, expect.objectContaining({
+            expect(mockFetch).toHaveBeenCalledWith('http://localhost:8080/v1/models', expect.objectContaining({
+                method: 'GET',
                 headers: expect.any(Headers),
             }));
         });
-    });
-    describe('listModelsByProvider', () => {
         it('should fetch models for a specific provider', async () => {
             const mockResponse = {
-                provider: types_1.Provider.OpenAI,
-                models: [
+                object: 'list',
+                data: [
                     {
-                        name: 'gpt-4',
+                        id: 'gpt-4o',
+                        object: 'model',
+                        created: 1686935002,
+                        owned_by: 'openai',
                     },
                 ],
             };
-            global.fetch.mockResolvedValueOnce({
+            mockFetch.mockResolvedValueOnce({
                 ok: true,
                 json: () => Promise.resolve(mockResponse),
             });
-            const result = await client.listModelsByProvider(types_1.Provider.OpenAI);
+            const result = await client.listModels(types_1.Provider.OpenAI);
             expect(result).toEqual(mockResponse);
-            expect(global.fetch).toHaveBeenCalledWith(`${mockBaseUrl}/llms/${types_1.Provider.OpenAI}`, expect.objectContaining({
+            expect(mockFetch).toHaveBeenCalledWith('http://localhost:8080/v1/models?provider=openai', expect.objectContaining({
+                method: 'GET',
                 headers: expect.any(Headers),
             }));
         });
-        it('should throw error when provider request fails', async () => {
+        it('should throw error when request fails', async () => {
             const errorMessage = 'Provider not found';
-            global.fetch.mockResolvedValueOnce({
+            mockFetch.mockResolvedValueOnce({
                 ok: false,
                 status: 404,
                 json: () => Promise.resolve({ error: errorMessage }),
             });
-            await expect(client.listModelsByProvider(types_1.Provider.OpenAI)).rejects.toThrow(errorMessage);
+            await expect(client.listModels(types_1.Provider.OpenAI)).rejects.toThrow(errorMessage);
         });
     });
-    describe('generateContent', () => {
-        it('should generate content with the specified provider', async () => {
+    describe('createChatCompletion', () => {
+        it('should create a chat completion', async () => {
             const mockRequest = {
-                provider: types_1.Provider.Ollama,
-                model: 'llama2',
+                model: 'gpt-4o',
                 messages: [
                     { role: types_1.MessageRole.System, content: 'You are a helpful assistant' },
                     { role: types_1.MessageRole.User, content: 'Hello' },
                 ],
             };
             const mockResponse = {
-                provider: types_1.Provider.Ollama,
-                response: {
-                    role: types_1.MessageRole.Assistant,
-                    model: 'llama2',
-                    content: 'Hi there!',
+                id: 'chatcmpl-123',
+                object: 'chat.completion',
+                created: 1677652288,
+                model: 'gpt-4o',
+                choices: [
+                    {
+                        index: 0,
+                        message: {
+                            role: types_1.MessageRole.Assistant,
+                            content: 'Hello! How can I help you today?',
+                        },
+                        finish_reason: 'stop',
+                    },
+                ],
+                usage: {
+                    prompt_tokens: 10,
+                    completion_tokens: 8,
+                    total_tokens: 18,
                 },
             };
-            global.fetch.mockResolvedValueOnce({
+            mockFetch.mockResolvedValueOnce({
                 ok: true,
                 json: () => Promise.resolve(mockResponse),
             });
-            const result = await client.generateContent(mockRequest);
+            const result = await client.createChatCompletion(mockRequest);
             expect(result).toEqual(mockResponse);
-            expect(global.fetch).toHaveBeenCalledWith(`${mockBaseUrl}/llms/${mockRequest.provider}/generate`, expect.objectContaining({
+            expect(mockFetch).toHaveBeenCalledWith('http://localhost:8080/v1/chat/completions', expect.objectContaining({
                 method: 'POST',
-                body: JSON.stringify({
-                    model: mockRequest.model,
-                    messages: mockRequest.messages,
-                }),
+                body: JSON.stringify(mockRequest),
             }));
         });
-    });
-    describe('healthCheck', () => {
-        it('should return true when API is healthy', async () => {
-            global.fetch.mockResolvedValueOnce({
+        it('should create a chat completion with a specific provider', async () => {
+            const mockRequest = {
+                model: 'claude-3-opus-20240229',
+                messages: [{ role: types_1.MessageRole.User, content: 'Hello' }],
+            };
+            const mockResponse = {
+                id: 'chatcmpl-456',
+                object: 'chat.completion',
+                created: 1677652288,
+                model: 'claude-3-opus-20240229',
+                choices: [
+                    {
+                        index: 0,
+                        message: {
+                            role: types_1.MessageRole.Assistant,
+                            content: 'Hello! How can I assist you today?',
+                        },
+                        finish_reason: 'stop',
+                    },
+                ],
+                usage: {
+                    prompt_tokens: 5,
+                    completion_tokens: 8,
+                    total_tokens: 13,
+                },
+            };
+            mockFetch.mockResolvedValueOnce({
                 ok: true,
-                json: () => Promise.resolve({}),
-            });
-            const result = await client.healthCheck();
-            expect(result).toBe(true);
-            expect(global.fetch).toHaveBeenCalledWith(`${mockBaseUrl}/health`, expect.any(Object));
-        });
-        it('should return false when API is unhealthy', async () => {
-            global.fetch.mockRejectedValueOnce(new Error('API error'));
-            const result = await client.healthCheck();
-            expect(result).toBe(false);
-        });
-    });
-    describe('error handling', () => {
-        it('should throw error when API request fails', async () => {
-            const errorMessage = 'Bad Request';
-            global.fetch.mockResolvedValueOnce({
-                ok: false,
-                status: 400,
-                json: () => Promise.resolve({ error: errorMessage }),
+                json: () => Promise.resolve(mockResponse),
             });
-            await expect(client.listModels()).rejects.toThrow(errorMessage);
+            const result = await client.createChatCompletion(mockRequest, types_1.Provider.Anthropic);
+            expect(result).toEqual(mockResponse);
+            expect(mockFetch).toHaveBeenCalledWith('http://localhost:8080/v1/chat/completions?provider=anthropic', expect.objectContaining({
+                method: 'POST',
+                body: JSON.stringify(mockRequest),
+            }));
         });
     });
-    describe('generateContentStream', () => {
-        it('should handle SSE events correctly', async () => {
+    describe('streamChatCompletion', () => {
+        it('should handle streaming chat completions', async () => {
             const mockRequest = {
-                provider: types_1.Provider.Ollama,
-                model: 'llama2',
-                messages: [
-                    { role: types_1.MessageRole.System, content: 'You are a helpful assistant' },
-                    { role: types_1.MessageRole.User, content: 'Hello' },
-                ],
+                model: 'gpt-4o',
+                messages: [{ role: types_1.MessageRole.User, content: 'Hello' }],
             };
-            const mockStream = new TransformStream();
+            const mockStream = new web_1.TransformStream();
             const writer = mockStream.writable.getWriter();
-            const encoder = new TextEncoder();
-            global.fetch.mockResolvedValueOnce({
+            const encoder = new node_util_1.TextEncoder();
+            mockFetch.mockResolvedValueOnce({
                 ok: true,
                 body: mockStream.readable,
             });
             const callbacks = {
-                onMessageStart: jest.fn(),
-                onStreamStart: jest.fn(),
-                onContentStart: jest.fn(),
-                onContentDelta: jest.fn(),
-                onContentEnd: jest.fn(),
-                onMessageEnd: jest.fn(),
-                onStreamEnd: jest.fn(),
+                onOpen: jest.fn(),
+                onChunk: jest.fn(),
+                onContent: jest.fn(),
+                onFinish: jest.fn(),
+                onError: jest.fn(),
             };
-            const streamPromise = client.generateContentStream(mockRequest, callbacks);
-            await writer.write(encoder.encode('event: message-start\ndata: {"role": "assistant"}\n\n' +
-                'event: stream-start\ndata: {}\n\n' +
-                'event: content-start\ndata: {}\n\n' +
-                'event: content-delta\ndata: {"content": "Hello"}\n\n' +
-                'event: content-delta\ndata: {"content": " there!"}\n\n' +
-                'event: content-end\ndata: {}\n\n' +
-                'event: message-end\ndata: {}\n\n' +
-                'event: stream-end\ndata: {}\n\n'));
+            const streamPromise = client.streamChatCompletion(mockRequest, callbacks);
+            // Simulate SSE events
+            await writer.write(encoder.encode('data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n' +
+                'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}\n\n' +
+                'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}\n\n' +
+                'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}\n\n' +
+                'data: [DONE]\n\n'));
             await writer.close();
             await streamPromise;
-            expect(callbacks.onMessageStart).toHaveBeenCalledWith('assistant');
-            expect(callbacks.onStreamStart).toHaveBeenCalledTimes(1);
-            expect(callbacks.onContentStart).toHaveBeenCalledTimes(1);
-            expect(callbacks.onContentDelta).toHaveBeenCalledWith('Hello');
-            expect(callbacks.onContentDelta).toHaveBeenCalledWith(' there!');
-            expect(callbacks.onContentEnd).toHaveBeenCalledTimes(1);
-            expect(callbacks.onMessageEnd).toHaveBeenCalledTimes(1);
-            expect(callbacks.onStreamEnd).toHaveBeenCalledTimes(1);
-            expect(global.fetch).toHaveBeenCalledWith(`${mockBaseUrl}/llms/${mockRequest.provider}/generate`, expect.objectContaining({
+            expect(callbacks.onOpen).toHaveBeenCalledTimes(1);
+            expect(callbacks.onChunk).toHaveBeenCalledTimes(4);
+            expect(callbacks.onContent).toHaveBeenCalledWith('Hello');
+            expect(callbacks.onContent).toHaveBeenCalledWith('!');
+            expect(callbacks.onFinish).toHaveBeenCalledTimes(1);
+            expect(mockFetch).toHaveBeenCalledWith('http://localhost:8080/v1/chat/completions', expect.objectContaining({
                 method: 'POST',
                 body: JSON.stringify({
-                    model: mockRequest.model,
-                    messages: mockRequest.messages,
+                    ...mockRequest,
                     stream: true,
-                    ssevents: true,
                 }),
             }));
         });
-        it('should handle errors in the stream response', async () => {
+        it('should handle tool calls in streaming chat completions', async () => {
             const mockRequest = {
-                provider: types_1.Provider.Ollama,
-                model: 'llama2',
+                model: 'gpt-4o',
+                messages: [
+                    {
+                        role: types_1.MessageRole.User,
+                        content: 'What is the weather in San Francisco?',
+                    },
+                ],
+                tools: [
+                    {
+                        type: 'function',
+                        function: {
+                            name: 'get_weather',
+                            parameters: {
+                                type: 'object',
+                                properties: {
+                                    location: {
+                                        type: 'string',
+                                        description: 'The city and state, e.g. San Francisco, CA',
+                                    },
+                                },
+                                required: ['location'],
+                            },
+                        },
+                    },
+                ],
+            };
+            const mockStream = new web_1.TransformStream();
+            const writer = mockStream.writable.getWriter();
+            const encoder = new node_util_1.TextEncoder();
+            mockFetch.mockResolvedValueOnce({
+                ok: true,
+                body: mockStream.readable,
+            });
+            const callbacks = {
+                onOpen: jest.fn(),
+                onChunk: jest.fn(),
+                onTool: jest.fn(),
+                onFinish: jest.fn(),
+            };
+            const streamPromise = client.streamChatCompletion(mockRequest, callbacks);
+            // Simulate SSE events with tool calls
+            await writer.write(encoder.encode('data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n' +
+                'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_123","type":"function","function":{"name":"get_weather"}}]},"finish_reason":null}]}\n\n' +
+                'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"location\\""}}]},"finish_reason":null}]}\n\n' +
+                'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":":\\"San Francisco, CA\\""}}]},"finish_reason":null}]}\n\n' +
+                'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"}"}}]},"finish_reason":null}]}\n\n' +
+                'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}\n\n' +
+                'data: [DONE]\n\n'));
+            await writer.close();
+            await streamPromise;
+            expect(callbacks.onOpen).toHaveBeenCalledTimes(1);
+            expect(callbacks.onChunk).toHaveBeenCalledTimes(6);
+            expect(callbacks.onTool).toHaveBeenCalledTimes(4); // Called for each chunk with tool_calls
+            expect(callbacks.onFinish).toHaveBeenCalledTimes(1);
+        });
+        it('should handle errors in streaming chat completions', async () => {
+            const mockRequest = {
+                model: 'gpt-4o',
                 messages: [{ role: types_1.MessageRole.User, content: 'Hello' }],
             };
-            global.fetch.mockResolvedValueOnce({
+            mockFetch.mockResolvedValueOnce({
                 ok: false,
                 status: 400,
                 json: () => Promise.resolve({ error: 'Bad Request' }),
             });
-            await expect(client.generateContentStream(mockRequest, {})).rejects.toThrow('Bad Request');
-        });
-        it('should handle non-readable response body', async () => {
-            const mockRequest = {
-                provider: types_1.Provider.Ollama,
-                model: 'llama2',
-                messages: [{ role: types_1.MessageRole.User, content: 'Hello' }],
+            const callbacks = {
+                onError: jest.fn(),
             };
-            global.fetch.mockResolvedValueOnce({
+            await expect(client.streamChatCompletion(mockRequest, callbacks)).rejects.toThrow('Bad Request');
+            expect(callbacks.onError).toHaveBeenCalledTimes(1);
+        });
+    });
+    describe('proxy', () => {
+        it('should proxy requests to a specific provider', async () => {
+            const mockResponse = { result: 'success' };
+            mockFetch.mockResolvedValueOnce({
                 ok: true,
-                body: null,
+                json: () => Promise.resolve(mockResponse),
+            });
+            const result = await client.proxy(types_1.Provider.OpenAI, 'embeddings', {
+                method: 'POST',
+                body: JSON.stringify({
+                    model: 'text-embedding-ada-002',
+                    input: 'Hello world',
+                }),
             });
-            await expect(client.generateContentStream(mockRequest, {})).rejects.toThrow('Response body is not readable');
+            expect(result).toEqual(mockResponse);
+            expect(mockFetch).toHaveBeenCalledWith('http://localhost:8080/v1/proxy/openai/embeddings', expect.objectContaining({
+                method: 'POST',
+                body: JSON.stringify({
+                    model: 'text-embedding-ada-002',
+                    input: 'Hello world',
+                }),
+            }));
         });
-        it('should handle empty events in the stream', async () => {
-            const mockRequest = {
-                provider: types_1.Provider.Ollama,
-                model: 'llama2',
-                messages: [{ role: types_1.MessageRole.User, content: 'Hello' }],
-            };
-            const mockStream = new TransformStream();
-            const writer = mockStream.writable.getWriter();
-            const encoder = new TextEncoder();
-            global.fetch.mockResolvedValueOnce({
+    });
+    describe('healthCheck', () => {
+        it('should return true when API is healthy', async () => {
+            mockFetch.mockResolvedValueOnce({
                 ok: true,
-                body: mockStream.readable,
             });
-            const callbacks = {
-                onContentDelta: jest.fn(),
-            };
-            const streamPromise = client.generateContentStream(mockRequest, callbacks);
-            await writer.write(encoder.encode('\n\n'));
-            await writer.write(encoder.encode('event: content-delta\ndata: {"content": "Hello"}\n\n'));
-            await writer.close();
-            await streamPromise;
-            expect(callbacks.onContentDelta).toHaveBeenCalledTimes(1);
-            expect(callbacks.onContentDelta).toHaveBeenCalledWith('Hello');
+            const result = await client.healthCheck();
+            expect(result).toBe(true);
+            expect(mockFetch).toHaveBeenCalledWith('http://localhost:8080/health');
+        });
+        it('should return false when API is unhealthy', async () => {
+            mockFetch.mockRejectedValueOnce(new Error('API error'));
+            const result = await client.healthCheck();
+            expect(result).toBe(false);
+        });
+    });
+    describe('withOptions', () => {
+        it('should create a new client with merged options', () => {
+            const originalClient = new client_1.InferenceGatewayClient({
+                baseURL: 'http://localhost:8080/v1',
+                apiKey: 'test-key',
+                fetch: mockFetch,
+            });
+            const newClient = originalClient.withOptions({
+                defaultHeaders: { 'X-Custom-Header': 'value' },
+            });
+            expect(newClient).toBeInstanceOf(client_1.InferenceGatewayClient);
+            expect(newClient).not.toBe(originalClient);
+            // We can't directly test private properties, but we can test behavior
+            mockFetch.mockResolvedValueOnce({
+                ok: true,
+                json: () => Promise.resolve({}),
+            });
+            newClient.listModels();
+            expect(mockFetch).toHaveBeenCalledWith('http://localhost:8080/v1/models', expect.objectContaining({
+                headers: expect.any(Headers),
+            }));
         });
     });
 });
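
For orientation, here is a minimal usage sketch of the 0.3.x client surface exercised by the updated tests above (constructor options, listModels, streamChatCompletion). It assumes the package's root export exposes InferenceGatewayClient, Provider, and MessageRole, and the apiKey value is a placeholder; those assumptions are not part of this diff.

import { InferenceGatewayClient, MessageRole, Provider } from '@inference-gateway/sdk';

// Point the client at a running Inference Gateway; apiKey is optional per the tests above.
const client = new InferenceGatewayClient({
  baseURL: 'http://localhost:8080/v1',
  apiKey: 'my-api-key', // placeholder
});

async function main() {
  // GET /v1/models, optionally filtered with ?provider=...
  const models = await client.listModels(Provider.OpenAI);
  console.log(models.data.map((model) => model.id));

  // POST /v1/chat/completions with stream: true, delivering OpenAI-style chunks via callbacks.
  await client.streamChatCompletion(
    {
      model: 'gpt-4o',
      messages: [{ role: MessageRole.User, content: 'Hello' }],
    },
    {
      onContent: (delta) => process.stdout.write(delta),
      onFinish: () => process.stdout.write('\n'),
      onError: (error) => console.error(error),
    }
  );
}

main().catch(console.error);
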
package/package.json CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "@inference-gateway/sdk",
-  "version": "0.3.0",
+  "version": "0.3.2",
   "description": "An SDK written in Typescript for the [Inference Gateway](https://github.com/inference-gateway/inference-gateway).",
   "main": "dist/src/index.js",
   "types": "dist/src/index.d.ts",
@@ -18,7 +18,8 @@
     "ollama",
     "cloudflare",
     "cohere",
-    "typescript"
+    "typescript",
+    "deepseek"
   ],
   "author": "Eden Reich <eden.reich@gmail.com>",
   "license": "MIT",