promptfoo 0.18.1 → 0.18.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. package/dist/package.json +1 -1
  2. package/dist/src/assertions.d.ts +2 -2
  3. package/dist/src/assertions.d.ts.map +1 -1
  4. package/dist/src/assertions.js +42 -11
  5. package/dist/src/assertions.js.map +1 -1
  6. package/dist/src/cache.d.ts +1 -1
  7. package/dist/src/cache.d.ts.map +1 -1
  8. package/dist/src/cache.js +4 -4
  9. package/dist/src/cache.js.map +1 -1
  10. package/dist/src/evaluator.d.ts.map +1 -1
  11. package/dist/src/evaluator.js +5 -2
  12. package/dist/src/evaluator.js.map +1 -1
  13. package/dist/src/main.js +4 -4
  14. package/dist/src/main.js.map +1 -1
  15. package/dist/src/providers/azureopenai.d.ts +2 -2
  16. package/dist/src/providers/azureopenai.d.ts.map +1 -1
  17. package/dist/src/providers/azureopenai.js +7 -5
  18. package/dist/src/providers/azureopenai.js.map +1 -1
  19. package/dist/src/providers/llama.js +1 -1
  20. package/dist/src/providers/llama.js.map +1 -1
  21. package/dist/src/providers/localai.js +2 -2
  22. package/dist/src/providers/localai.js.map +1 -1
  23. package/dist/src/providers/ollama.d.ts +9 -0
  24. package/dist/src/providers/ollama.d.ts.map +1 -0
  25. package/dist/src/providers/ollama.js +66 -0
  26. package/dist/src/providers/ollama.js.map +1 -0
  27. package/dist/src/providers/openai.d.ts +2 -2
  28. package/dist/src/providers/openai.d.ts.map +1 -1
  29. package/dist/src/providers/openai.js +7 -5
  30. package/dist/src/providers/openai.js.map +1 -1
  31. package/dist/src/providers.d.ts.map +1 -1
  32. package/dist/src/providers.js +11 -5
  33. package/dist/src/providers.js.map +1 -1
  34. package/dist/src/types.d.ts +6 -2
  35. package/dist/src/types.d.ts.map +1 -1
  36. package/dist/src/util.d.ts +2 -0
  37. package/dist/src/util.d.ts.map +1 -1
  38. package/dist/src/util.js +24 -12
  39. package/dist/src/util.js.map +1 -1
  40. package/dist/src/web/client/assets/index-6d2a3573.js +200 -0
  41. package/dist/src/web/client/index.html +1 -1
  42. package/package.json +1 -1
  43. package/src/assertions.ts +45 -11
  44. package/src/cache.ts +3 -2
  45. package/src/evaluator.ts +5 -1
  46. package/src/main.ts +4 -4
  47. package/src/providers/azureopenai.ts +18 -6
  48. package/src/providers/llama.ts +2 -2
  49. package/src/providers/localai.ts +3 -3
  50. package/src/providers/ollama.ts +88 -0
  51. package/src/providers/openai.ts +8 -6
  52. package/src/providers.ts +20 -5
  53. package/src/types.ts +6 -2
  54. package/src/util.ts +25 -17
  55. package/src/web/client/package-lock.json +5726 -0
  56. package/src/web/client/src/EvalOutputPromptDialog.tsx +78 -16
  57. package/src/web/client/src/ResultsTable.tsx +32 -9
  58. package/src/web/client/src/ResultsView.tsx +1 -1
  59. package/src/web/client/src/types.ts +3 -1
  60. package/dist/src/web/client/assets/index-8388d689.js +0 -199
package/dist/src/web/client/index.html CHANGED
@@ -5,7 +5,7 @@
5
5
  <link rel="icon" type="image/svg+xml" href="favicon.ico" />
6
6
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
7
  <title>promptfoo web viewer</title>
8
- <script type="module" crossorigin src="/assets/index-8388d689.js"></script>
8
+ <script type="module" crossorigin src="/assets/index-6d2a3573.js"></script>
9
9
  <link rel="stylesheet" href="/assets/index-d2b6a160.css">
10
10
  </head>
11
11
  <body>
package/package.json CHANGED
@@ -2,7 +2,7 @@
2
2
  "name": "promptfoo",
3
3
  "description": "LLM eval & testing toolkit",
4
4
  "author": "Ian Webster",
5
- "version": "0.18.1",
5
+ "version": "0.18.3",
6
6
  "license": "MIT",
7
7
  "type": "commonjs",
8
8
  "main": "dist/src/index.js",
package/src/assertions.ts CHANGED
@@ -1,10 +1,9 @@
1
1
  import rouge from 'rouge';
2
2
  import invariant from 'tiny-invariant';
3
- import nunjucks from 'nunjucks';
4
3
 
5
4
  import telemetry from './telemetry';
6
5
  import { DefaultEmbeddingProvider, DefaultGradingProvider } from './providers/openai';
7
- import { cosineSimilarity, fetchWithRetries } from './util';
6
+ import { cosineSimilarity, fetchWithRetries, getNunjucksEngine } from './util';
8
7
  import { loadApiProvider } from './providers';
9
8
  import { DEFAULT_GRADING_PROMPT } from './prompts';
10
9
 
@@ -18,6 +17,8 @@ import type {
18
17
 
19
18
  const DEFAULT_SEMANTIC_SIMILARITY_THRESHOLD = 0.8;
20
19
 
20
+ const nunjucks = getNunjucksEngine();
21
+
21
22
  function handleRougeScore(
22
23
  baseType: 'rouge-n',
23
24
  assertion: Assertion,
@@ -40,6 +41,7 @@ function handleRougeScore(
40
41
  : `${baseType.toUpperCase()} score ${score} is less than threshold ${
41
42
  assertion.threshold || 0.75
42
43
  }`,
44
+ assertion,
43
45
  };
44
46
  }
45
47
 
@@ -51,17 +53,23 @@ export async function runAssertions(test: AtomicTestCase, output: string): Promi
51
53
  };
52
54
 
53
55
  if (!test.assert || test.assert.length < 1) {
54
- return { pass: true, score: 1, reason: 'No assertions', tokensUsed };
56
+ return { pass: true, score: 1, reason: 'No assertions', tokensUsed, assertion: null };
55
57
  }
56
58
 
57
59
  let totalScore = 0;
58
60
  let totalWeight = 0;
61
+ let allPass = true;
62
+ let failedReason = '';
63
+ const componentResults: GradingResult[] = [];
64
+
59
65
  for (const assertion of test.assert) {
60
66
  const weight = assertion.weight || 1;
61
67
  totalWeight += weight;
62
68
 
63
69
  const result = await runAssertion(assertion, test, output);
64
70
  totalScore += result.score * weight;
71
+ componentResults.push(result);
72
+
65
73
  if (result.tokensUsed) {
66
74
  tokensUsed.total += result.tokensUsed.total;
67
75
  tokensUsed.prompt += result.tokensUsed.prompt;
@@ -69,16 +77,21 @@ export async function runAssertions(test: AtomicTestCase, output: string): Promi
69
77
  }
70
78
 
71
79
  if (!result.pass) {
72
- // Short-circuit assertions
73
- return result;
80
+ allPass = false;
81
+ failedReason = result.reason;
82
+ if (process.env.PROMPTFOO_SHORT_CIRCUIT_TEST_FAILURES) {
83
+ return result;
84
+ }
74
85
  }
75
86
  }
76
87
 
77
88
  return {
78
- pass: true,
89
+ pass: allPass,
79
90
  score: totalScore / totalWeight,
80
- reason: 'All assertions passed',
91
+ reason: allPass ? 'All assertions passed' : failedReason,
81
92
  tokensUsed,
93
+ componentResults,
94
+ assertion: null,
82
95
  };
83
96
  }
84
97
 
@@ -114,6 +127,7 @@ export async function runAssertion(
114
127
  pass,
115
128
  score: pass ? 1 : 0,
116
129
  reason: pass ? 'Assertion passed' : `Expected output "${renderedValue}"`,
130
+ assertion,
117
131
  };
118
132
  }
119
133
 
@@ -128,6 +142,7 @@ export async function runAssertion(
128
142
  pass,
129
143
  score: pass ? 1 : 0,
130
144
  reason: pass ? 'Assertion passed' : 'Expected output to be valid JSON',
145
+ assertion,
131
146
  };
132
147
  }
133
148
 
@@ -144,6 +159,7 @@ export async function runAssertion(
144
159
  reason: pass
145
160
  ? 'Assertion passed'
146
161
  : `Expected output to ${inverse ? 'not ' : ''}contain "${renderedValue}"`,
162
+ assertion,
147
163
  };
148
164
  }
149
165
 
@@ -160,6 +176,7 @@ export async function runAssertion(
160
176
  reason: pass
161
177
  ? 'Assertion passed'
162
178
  : `Expected output to ${inverse ? 'not ' : ''}contain one of "${renderedValue.join(', ')}"`,
179
+ assertion,
163
180
  };
164
181
  }
165
182
 
@@ -176,6 +193,7 @@ export async function runAssertion(
176
193
  reason: pass
177
194
  ? 'Assertion passed'
178
195
  : `Expected output to ${inverse ? 'not ' : ''}contain all of "${renderedValue.join(', ')}"`,
196
+ assertion,
179
197
  };
180
198
  }
181
199
 
@@ -193,6 +211,7 @@ export async function runAssertion(
193
211
  reason: pass
194
212
  ? 'Assertion passed'
195
213
  : `Expected output to ${inverse ? 'not ' : ''}match regex "${renderedValue}"`,
214
+ assertion,
196
215
  };
197
216
  }
198
217
 
@@ -209,6 +228,7 @@ export async function runAssertion(
209
228
  reason: pass
210
229
  ? 'Assertion passed'
211
230
  : `Expected output to ${inverse ? 'not ' : ''}contain "${renderedValue}"`,
231
+ assertion,
212
232
  };
213
233
  }
214
234
 
@@ -225,6 +245,7 @@ export async function runAssertion(
225
245
  reason: pass
226
246
  ? 'Assertion passed'
227
247
  : `Expected output to ${inverse ? 'not ' : ''}start with "${renderedValue}"`,
248
+ assertion,
228
249
  };
229
250
  }
230
251
 
@@ -236,6 +257,7 @@ export async function runAssertion(
236
257
  reason: pass
237
258
  ? 'Assertion passed'
238
259
  : `Expected output to ${inverse ? 'not ' : ''}contain valid JSON`,
260
+ assertion,
239
261
  };
240
262
  }
241
263
 
@@ -265,6 +287,7 @@ export async function runAssertion(
265
287
  score: 0,
266
288
  reason: `Custom function threw error: ${(err as Error).message}
267
289
  ${renderedValue}`,
290
+ assertion,
268
291
  };
269
292
  }
270
293
  return {
@@ -274,6 +297,7 @@ ${renderedValue}`,
274
297
  ? 'Assertion passed'
275
298
  : `Custom function returned ${inverse ? 'true' : 'false'}
276
299
  ${renderedValue}`,
300
+ assertion,
277
301
  };
278
302
  }
279
303
 
@@ -309,6 +333,7 @@ ${renderedValue}`,
309
333
  pass: false,
310
334
  score: 0,
311
335
  reason: `Python code execution failed: ${(err as Error).message}`,
336
+ assertion,
312
337
  };
313
338
  }
314
339
  return {
@@ -318,6 +343,7 @@ ${renderedValue}`,
318
343
  ? 'Assertion passed'
319
344
  : `Python code returned ${pass ? 'true' : 'false'}
320
345
  ${assertion.value}`,
346
+ assertion,
321
347
  };
322
348
  }
323
349
 
@@ -327,7 +353,10 @@ ${assertion.value}`,
327
353
  typeof renderedValue === 'string',
328
354
  '"contains" assertion type must have a string value',
329
355
  );
330
- return matchesSimilarity(renderedValue, output, assertion.threshold || 0.75, inverse);
356
+ return {
357
+ assertion,
358
+ ...(await matchesSimilarity(renderedValue, output, assertion.threshold || 0.75, inverse)),
359
+ };
331
360
  }
332
361
 
333
362
  if (baseType === 'llm-rubric') {
@@ -336,7 +365,10 @@ ${assertion.value}`,
336
365
  typeof renderedValue === 'string',
337
366
  '"contains" assertion type must have a string value',
338
367
  );
339
- return matchesLlmRubric(renderedValue, output, test.options);
368
+ return {
369
+ assertion,
370
+ ...(await matchesLlmRubric(renderedValue, output, test.options)),
371
+ };
340
372
  }
341
373
 
342
374
  if (baseType === 'webhook') {
@@ -378,6 +410,7 @@ ${assertion.value}`,
378
410
  pass: false,
379
411
  score: 0,
380
412
  reason: `Webhook error: ${(err as Error).message}`,
413
+ assertion,
381
414
  };
382
415
  }
383
416
 
@@ -385,6 +418,7 @@ ${assertion.value}`,
385
418
  pass,
386
419
  score,
387
420
  reason: pass ? 'Assertion passed' : `Webhook returned ${inverse ? 'true' : 'false'}`,
421
+ assertion,
388
422
  };
389
423
  }
390
424
 
@@ -422,7 +456,7 @@ export async function matchesSimilarity(
422
456
  output: string,
423
457
  threshold: number,
424
458
  inverse: boolean = false,
425
- ): Promise<GradingResult> {
459
+ ): Promise<Omit<GradingResult, 'assertion'>> {
426
460
  const expectedEmbedding = await DefaultEmbeddingProvider.callEmbeddingApi(expected);
427
461
  const outputEmbedding = await DefaultEmbeddingProvider.callEmbeddingApi(output);
428
462
 
@@ -477,7 +511,7 @@ export async function matchesLlmRubric(
477
511
  expected: string,
478
512
  output: string,
479
513
  options?: GradingConfig,
480
- ): Promise<GradingResult> {
514
+ ): Promise<Omit<GradingResult, 'assertion'>> {
481
515
  if (!options) {
482
516
  throw new Error(
483
517
  'Cannot grade output without grading config. Specify --grader option or grading config.',
package/src/cache.ts CHANGED
@@ -42,10 +42,11 @@ export function getCache() {
42
42
  return cacheInstance;
43
43
  }
44
44
 
45
- export async function fetchJsonWithCache(
45
+ export async function fetchWithCache(
46
46
  url: RequestInfo,
47
47
  options: RequestInit = {},
48
48
  timeout: number,
49
+ format: 'json' | 'text' = 'json',
49
50
  ): Promise<{ data: any; cached: boolean }> {
50
51
  if (!enabled) {
51
52
  const resp = await fetchWithRetries(url, options, timeout);
@@ -75,7 +76,7 @@ export async function fetchJsonWithCache(
75
76
  // Fetch the actual data and store it in the cache
76
77
  const response = await fetchWithRetries(url, options, timeout);
77
78
  try {
78
- const data = await response.json();
79
+ const data = format === 'json' ? await response.json() : await response.text();
79
80
  if (response.ok) {
80
81
  logger.debug(`Storing ${url} response in cache: ${JSON.stringify(data)}`);
81
82
  await cache.set(cacheKey, JSON.stringify(data));
package/src/evaluator.ts CHANGED
@@ -2,13 +2,13 @@ import readline from 'readline';
2
2
 
3
3
  import async from 'async';
4
4
  import chalk from 'chalk';
5
- import nunjucks from 'nunjucks';
6
5
  import invariant from 'tiny-invariant';
7
6
 
8
7
  import logger from './logger';
9
8
  import telemetry from './telemetry';
10
9
  import { runAssertions } from './assertions';
11
10
  import { generatePrompts } from './suggestions';
11
+ import { getNunjucksEngine } from './util';
12
12
 
13
13
  import type { SingleBar } from 'cli-progress';
14
14
  import type {
@@ -39,6 +39,8 @@ interface RunEvalOptions {
39
39
 
40
40
  const DEFAULT_MAX_CONCURRENCY = 4;
41
41
 
42
+ const nunjucks = getNunjucksEngine();
43
+
42
44
  function generateVarCombinations(
43
45
  vars: Record<string, string | string[] | any>,
44
46
  ): Record<string, string | any[]>[] {
@@ -156,6 +158,7 @@ class Evaluator {
156
158
  this.stats.tokenUsage.completion += checkResult.tokensUsed.completion;
157
159
  }
158
160
  ret.response = processedResponse;
161
+ ret.gradingResult = checkResult;
159
162
  } else {
160
163
  ret.success = false;
161
164
  ret.score = 0;
@@ -464,6 +467,7 @@ class Evaluator {
464
467
  prompt: row.prompt.raw,
465
468
  latencyMs: row.latencyMs,
466
469
  tokenUsage: row.response?.tokenUsage,
470
+ gradingResult: row.gradingResult,
467
471
  };
468
472
  },
469
473
  );
package/src/main.ts CHANGED
@@ -286,7 +286,7 @@ async function main() {
286
286
  process.env.PROMPTFOO_DISABLE_SHARING === '1'
287
287
  ? false
288
288
  : fileConfig.sharing ?? defaultConfig.sharing ?? true,
289
- defaultTest: fileConfig.defaultTest,
289
+ defaultTest: fileConfig.defaultTest || defaultConfig.defaultTest,
290
290
  };
291
291
 
292
292
  // Validation
@@ -312,7 +312,7 @@ async function main() {
312
312
  cmdObj.tests ? undefined : basePath,
313
313
  );
314
314
 
315
- //parse testCases for each scenario
315
+ // Parse testCases for each scenario
316
316
  if (fileConfig.scenarios) {
317
317
  for (const scenario of fileConfig.scenarios) {
318
318
  const parsedScenarioTests: TestCase[] = await readTests(
@@ -335,8 +335,8 @@ async function main() {
335
335
  prefix: cmdObj.promptPrefix,
336
336
  suffix: cmdObj.promptSuffix,
337
337
  provider: cmdObj.grader,
338
- // rubricPrompt:
339
- // postprocess
338
+ // rubricPrompt
339
+ ...(config.defaultTest?.options || {}),
340
340
  },
341
341
  ...config.defaultTest,
342
342
  };
package/src/providers/azureopenai.ts CHANGED
@@ -1,5 +1,5 @@
1
1
  import logger from '../logger';
2
- import { fetchJsonWithCache } from '../cache';
2
+ import { fetchWithCache } from '../cache';
3
3
  import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
4
4
 
5
5
  import type { ApiProvider, ProviderEmbeddingResponse, ProviderResponse } from '../types.js';
@@ -61,7 +61,7 @@ export class AzureOpenAiEmbeddingProvider extends AzureOpenAiGenericProvider {
61
61
  let data,
62
62
  cached = false;
63
63
  try {
64
- ({ data, cached } = (await fetchJsonWithCache(
64
+ ({ data, cached } = (await fetchWithCache(
65
65
  `https://${this.apiHost}/openai/deployments/${this.deploymentName}/embeddings?api-version=2023-07-01-preview`,
66
66
  {
67
67
  method: 'POST',
@@ -117,9 +117,15 @@ export class AzureOpenAiEmbeddingProvider extends AzureOpenAiGenericProvider {
117
117
  export class AzureOpenAiCompletionProvider extends AzureOpenAiGenericProvider {
118
118
  options: AzureOpenAiCompletionOptions;
119
119
 
120
- constructor(deploymentName: string, apiKey?: string, context?: AzureOpenAiCompletionOptions) {
120
+ constructor(
121
+ deploymentName: string,
122
+ apiKey?: string,
123
+ context?: AzureOpenAiCompletionOptions,
124
+ id?: string,
125
+ ) {
121
126
  super(deploymentName, apiKey);
122
127
  this.options = context || {};
128
+ this.id = id ? () => id : this.id;
123
129
  }
124
130
 
125
131
  async callApi(prompt: string, options?: AzureOpenAiCompletionOptions): Promise<ProviderResponse> {
@@ -165,7 +171,7 @@ export class AzureOpenAiCompletionProvider extends AzureOpenAiGenericProvider {
165
171
  let data,
166
172
  cached = false;
167
173
  try {
168
- ({ data, cached } = (await fetchJsonWithCache(
174
+ ({ data, cached } = (await fetchWithCache(
169
175
  `https://${this.apiHost}/openai/deployments/${this.deploymentName}/completions?api-version=2023-07-01-preview`,
170
176
  {
171
177
  method: 'POST',
@@ -205,9 +211,15 @@ export class AzureOpenAiCompletionProvider extends AzureOpenAiGenericProvider {
205
211
  export class AzureOpenAiChatCompletionProvider extends AzureOpenAiGenericProvider {
206
212
  options: AzureOpenAiCompletionOptions;
207
213
 
208
- constructor(deploymentName: string, apiKey?: string, context?: AzureOpenAiCompletionOptions) {
214
+ constructor(
215
+ deploymentName: string,
216
+ apiKey?: string,
217
+ context?: AzureOpenAiCompletionOptions,
218
+ id?: string,
219
+ ) {
209
220
  super(deploymentName, apiKey);
210
221
  this.options = context || {};
222
+ this.id = id ? () => id : this.id;
211
223
  }
212
224
 
213
225
  async callApi(prompt: string, options?: AzureOpenAiCompletionOptions): Promise<ProviderResponse> {
@@ -246,7 +258,7 @@ export class AzureOpenAiChatCompletionProvider extends AzureOpenAiGenericProvide
246
258
  let data,
247
259
  cached = false;
248
260
  try {
249
- ({ data, cached } = (await fetchJsonWithCache(
261
+ ({ data, cached } = (await fetchWithCache(
250
262
  `https://${this.apiHost}/openai/deployments/${this.deploymentName}/chat/completions?api-version=2023-07-01-preview`,
251
263
  {
252
264
  method: 'POST',
package/src/providers/llama.ts CHANGED
@@ -1,4 +1,4 @@
1
- import { fetchJsonWithCache } from '../cache';
1
+ import { fetchWithCache } from '../cache';
2
2
  import { REQUEST_TIMEOUT_MS } from './shared';
3
3
 
4
4
  import type { ApiProvider, ProviderResponse } from '../types.js';
@@ -65,7 +65,7 @@ export class LlamaProvider implements ApiProvider {
65
65
 
66
66
  let response;
67
67
  try {
68
- response = await fetchJsonWithCache(
68
+ response = await fetchWithCache(
69
69
  `${process.env.LLAMA_BASE_URL || 'http://localhost:8080'}/completion`,
70
70
  {
71
71
  method: 'POST',
package/src/providers/localai.ts CHANGED
@@ -1,5 +1,5 @@
1
1
  import logger from '../logger';
2
- import { fetchJsonWithCache } from '../cache';
2
+ import { fetchWithCache } from '../cache';
3
3
  import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
4
4
 
5
5
  import type { ApiProvider, ProviderResponse } from '../types.js';
@@ -40,7 +40,7 @@ export class LocalAiChatProvider extends LocalAiGenericProvider {
40
40
  let data,
41
41
  cached = false;
42
42
  try {
43
- ({ data, cached } = (await fetchJsonWithCache(
43
+ ({ data, cached } = (await fetchWithCache(
44
44
  `${this.apiBaseUrl}/chat/completions`,
45
45
  {
46
46
  method: 'POST',
@@ -81,7 +81,7 @@ export class LocalAiCompletionProvider extends LocalAiGenericProvider {
81
81
  let data,
82
82
  cached = false;
83
83
  try {
84
- ({ data, cached } = (await fetchJsonWithCache(
84
+ ({ data, cached } = (await fetchWithCache(
85
85
  `${this.apiBaseUrl}/completions`,
86
86
  {
87
87
  method: 'POST',
package/src/providers/ollama.ts ADDED
@@ -0,0 +1,88 @@
1
+ import logger from '../logger';
2
+ import { fetchWithCache } from '../cache';
3
+
4
+ import type { ApiProvider, ProviderResponse } from '../types.js';
5
+ import { REQUEST_TIMEOUT_MS } from './shared';
6
+
7
+ interface OllamaJsonL {
8
+ model: string;
9
+ created_at: string;
10
+ response?: string;
11
+ done: boolean;
12
+ context?: number[];
13
+ total_duration?: number;
14
+ load_duration?: number;
15
+ sample_count?: number;
16
+ sample_duration?: number;
17
+ prompt_eval_count?: number;
18
+ prompt_eval_duration?: number;
19
+ eval_count?: number;
20
+ eval_duration?: number;
21
+ }
22
+
23
+ export class OllamaProvider implements ApiProvider {
24
+ modelName: string;
25
+
26
+ constructor(modelName: string) {
27
+ this.modelName = modelName;
28
+ }
29
+
30
+ id(): string {
31
+ return `ollama:${this.modelName}`;
32
+ }
33
+
34
+ toString(): string {
35
+ return `[Ollama Provider ${this.modelName}]`;
36
+ }
37
+
38
+ async callApi(prompt: string): Promise<ProviderResponse> {
39
+ const params = {
40
+ model: this.modelName,
41
+ prompt,
42
+ };
43
+
44
+ logger.debug(`Calling Ollama API: ${JSON.stringify(params)}`);
45
+ let response;
46
+ try {
47
+ response = await fetchWithCache(
48
+ `${process.env.OLLAMA_BASE_URL || 'http://localhost:11434'}/api/generate`,
49
+ {
50
+ method: 'POST',
51
+ headers: {
52
+ 'Content-Type': 'application/json',
53
+ },
54
+ body: JSON.stringify(params),
55
+ },
56
+ REQUEST_TIMEOUT_MS,
57
+ 'text',
58
+ );
59
+ } catch (err) {
60
+ return {
61
+ error: `API call error: ${String(err)}`,
62
+ };
63
+ }
64
+ logger.debug(`\tOllama API response: ${response.data}`);
65
+
66
+ try {
67
+ const output = response.data
68
+ .split('\n')
69
+ .map((line: string) => {
70
+ const parsed = JSON.parse(line) as OllamaJsonL;
71
+ if (parsed.response) {
72
+ return parsed.response;
73
+ }
74
+ return null;
75
+ })
76
+ .filter((s: string | null) => s !== null)
77
+ .join('');
78
+
79
+ return {
80
+ output,
81
+ };
82
+ } catch (err) {
83
+ return {
84
+ error: `API response error: ${String(err)}: ${JSON.stringify(response.data)}`,
85
+ };
86
+ }
87
+ }
88
+ }
package/src/providers/openai.ts CHANGED
@@ -1,5 +1,5 @@
1
1
  import logger from '../logger';
2
- import { fetchJsonWithCache } from '../cache';
2
+ import { fetchWithCache } from '../cache';
3
3
  import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
4
4
 
5
5
  import type { ApiProvider, ProviderEmbeddingResponse, ProviderResponse } from '../types.js';
@@ -61,7 +61,7 @@ export class OpenAiEmbeddingProvider extends OpenAiGenericProvider {
61
61
  let data,
62
62
  cached = false;
63
63
  try {
64
- ({ data, cached } = (await fetchJsonWithCache(
64
+ ({ data, cached } = (await fetchWithCache(
65
65
  `https://${this.apiHost}/v1/embeddings`,
66
66
  {
67
67
  method: 'POST',
@@ -125,12 +125,13 @@ export class OpenAiCompletionProvider extends OpenAiGenericProvider {
125
125
 
126
126
  options: OpenAiCompletionOptions;
127
127
 
128
- constructor(modelName: string, apiKey?: string, context?: OpenAiCompletionOptions) {
128
+ constructor(modelName: string, apiKey?: string, context?: OpenAiCompletionOptions, id?: string) {
129
129
  if (!OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.includes(modelName)) {
130
130
  logger.warn(`Using unknown OpenAI completion model: ${modelName}`);
131
131
  }
132
132
  super(modelName, apiKey);
133
133
  this.options = context || {};
134
+ this.id = id ? () => id : this.id;
134
135
  }
135
136
 
136
137
  async callApi(prompt: string, options?: OpenAiCompletionOptions): Promise<ProviderResponse> {
@@ -176,7 +177,7 @@ export class OpenAiCompletionProvider extends OpenAiGenericProvider {
176
177
  let data,
177
178
  cached = false;
178
179
  try {
179
- ({ data, cached } = (await fetchJsonWithCache(
180
+ ({ data, cached } = (await fetchWithCache(
180
181
  `https://${this.apiHost}/v1/completions`,
181
182
  {
182
183
  method: 'POST',
@@ -229,12 +230,13 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
229
230
 
230
231
  options: OpenAiCompletionOptions;
231
232
 
232
- constructor(modelName: string, apiKey?: string, context?: OpenAiCompletionOptions) {
233
+ constructor(modelName: string, apiKey?: string, context?: OpenAiCompletionOptions, id?: string) {
233
234
  if (!OpenAiChatCompletionProvider.OPENAI_CHAT_MODELS.includes(modelName)) {
234
235
  logger.warn(`Using unknown OpenAI chat model: ${modelName}`);
235
236
  }
236
237
  super(modelName, apiKey);
237
238
  this.options = context || {};
239
+ this.id = id ? () => id : this.id;
238
240
  }
239
241
 
240
242
  async callApi(prompt: string, options?: OpenAiCompletionOptions): Promise<ProviderResponse> {
@@ -273,7 +275,7 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
273
275
  let data,
274
276
  cached = false;
275
277
  try {
276
- ({ data, cached } = (await fetchJsonWithCache(
278
+ ({ data, cached } = (await fetchWithCache(
277
279
  `https://${this.apiHost}/v1/chat/completions`,
278
280
  {
279
281
  method: 'POST',
package/src/providers.ts CHANGED
@@ -5,6 +5,7 @@ import { AnthropicCompletionProvider } from './providers/anthropic';
5
5
  import { ReplicateProvider } from './providers/replicate';
6
6
  import { LocalAiCompletionProvider, LocalAiChatProvider } from './providers/localai';
7
7
  import { LlamaProvider } from './providers/llama';
8
+ import { OllamaProvider } from './providers/ollama';
8
9
  import { ScriptCompletionProvider } from './providers/scriptCompletion';
9
10
  import {
10
11
  AzureOpenAiChatCompletionProvider,
@@ -44,7 +45,8 @@ export async function loadApiProviders(
44
45
  };
45
46
  } else {
46
47
  const id = Object.keys(provider)[0];
47
- const context = { ...provider[id], id };
48
+ const providerObject = provider[id];
49
+ const context = { ...providerObject, id: providerObject.id || id };
48
50
  return loadApiProvider(id, context, basePath);
49
51
  }
50
52
  }),
@@ -84,9 +86,9 @@ export async function loadApiProvider(
84
86
  context?.config,
85
87
  );
86
88
  } else if (OpenAiChatCompletionProvider.OPENAI_CHAT_MODELS.includes(modelType)) {
87
- return new OpenAiChatCompletionProvider(modelType, undefined, context?.config);
89
+ return new OpenAiChatCompletionProvider(modelType, undefined, context?.config, context?.id);
88
90
  } else if (OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.includes(modelType)) {
89
- return new OpenAiCompletionProvider(modelType, undefined, context?.config);
91
+ return new OpenAiCompletionProvider(modelType, undefined, context?.config, context?.id);
90
92
  } else {
91
93
  throw new Error(
92
94
  `Unknown OpenAI model type: ${modelType}. Use one of the following providers: openai:chat:<model name>, openai:completion:<model name>`,
@@ -99,9 +101,19 @@ export async function loadApiProvider(
99
101
  const deploymentName = options[2];
100
102
 
101
103
  if (modelType === 'chat') {
102
- return new AzureOpenAiChatCompletionProvider(deploymentName, undefined, context?.config);
104
+ return new AzureOpenAiChatCompletionProvider(
105
+ deploymentName,
106
+ undefined,
107
+ context?.config,
108
+ context?.id,
109
+ );
103
110
  } else if (modelType === 'completion') {
104
- return new AzureOpenAiCompletionProvider(deploymentName, undefined, context?.config);
111
+ return new AzureOpenAiCompletionProvider(
112
+ deploymentName,
113
+ undefined,
114
+ context?.config,
115
+ context?.id,
116
+ );
105
117
  } else {
106
118
  throw new Error(
107
119
  `Unknown Azure OpenAI model type: ${modelType}. Use one of the following providers: openai:chat:<model name>, openai:completion:<model name>`,
@@ -137,6 +149,9 @@ export async function loadApiProvider(
137
149
  if (providerPath === 'llama' || providerPath.startsWith('llama:')) {
138
150
  const modelName = providerPath.split(':')[1];
139
151
  return new LlamaProvider(modelName, context?.config);
152
+ } else if (providerPath.startsWith('ollama:')) {
153
+ const modelName = providerPath.split(':')[1];
154
+ return new OllamaProvider(modelName);
140
155
  } else if (providerPath?.startsWith('localai:')) {
141
156
  const options = providerPath.split(':');
142
157
  const modelType = options[1];