@proteinjs/conversation 2.1.3 → 2.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +11 -0
- package/LICENSE +21 -0
- package/dist/index.d.ts +8 -1
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +8 -0
- package/dist/index.js.map +1 -1
- package/dist/src/Conversation.d.ts +43 -1
- package/dist/src/Conversation.d.ts.map +1 -1
- package/dist/src/Conversation.js +255 -5
- package/dist/src/Conversation.js.map +1 -1
- package/dist/src/OpenAi.d.ts +29 -0
- package/dist/src/OpenAi.d.ts.map +1 -1
- package/dist/src/OpenAi.js +77 -30
- package/dist/src/OpenAi.js.map +1 -1
- package/dist/src/fs/conversation_fs/ConversationFsModule.d.ts +1 -0
- package/dist/src/fs/conversation_fs/ConversationFsModule.d.ts.map +1 -1
- package/dist/src/fs/conversation_fs/ConversationFsModule.js +6 -2
- package/dist/src/fs/conversation_fs/ConversationFsModule.js.map +1 -1
- package/dist/src/fs/conversation_fs/FsFunctions.d.ts +36 -3
- package/dist/src/fs/conversation_fs/FsFunctions.d.ts.map +1 -1
- package/dist/src/fs/conversation_fs/FsFunctions.js +142 -20
- package/dist/src/fs/conversation_fs/FsFunctions.js.map +1 -1
- package/dist/src/fs/keyword_to_files_index/KeywordToFilesIndexModule.d.ts +4 -1
- package/dist/src/fs/keyword_to_files_index/KeywordToFilesIndexModule.d.ts.map +1 -1
- package/dist/src/fs/keyword_to_files_index/KeywordToFilesIndexModule.js +13 -9
- package/dist/src/fs/keyword_to_files_index/KeywordToFilesIndexModule.js.map +1 -1
- package/index.ts +10 -1
- package/package.json +6 -4
- package/src/Conversation.ts +311 -5
- package/src/OpenAi.ts +123 -13
- package/src/fs/conversation_fs/ConversationFsModule.ts +8 -2
- package/src/fs/conversation_fs/FsFunctions.ts +97 -17
- package/src/fs/keyword_to_files_index/KeywordToFilesIndexModule.ts +14 -9
package/src/Conversation.ts
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { ChatCompletionMessageParam } from 'openai/resources/chat';
|
|
2
|
-
import { DEFAULT_MODEL, OpenAi } from './OpenAi';
|
|
2
|
+
import { DEFAULT_MODEL, OpenAi, ToolInvocationProgressEvent } from './OpenAi';
|
|
3
3
|
import { MessageHistory } from './history/MessageHistory';
|
|
4
4
|
import { Function } from './Function';
|
|
5
5
|
import { Logger, LogLevel } from '@proteinjs/logger';
|
|
@@ -9,6 +9,8 @@ import { ConversationModule } from './ConversationModule';
|
|
|
9
9
|
import { TiktokenModel, encoding_for_model } from 'tiktoken';
|
|
10
10
|
import { searchLibrariesFunctionName } from './fs/package/PackageFunctions';
|
|
11
11
|
import { UsageData } from './UsageData';
|
|
12
|
+
import type { ModelMessage, LanguageModel } from 'ai';
|
|
13
|
+
import { generateObject as aiGenerateObject, jsonSchema } from 'ai';
|
|
12
14
|
|
|
13
15
|
export type ConversationParams = {
|
|
14
16
|
name: string;
|
|
@@ -21,8 +23,36 @@ export type ConversationParams = {
|
|
|
21
23
|
};
|
|
22
24
|
};
|
|
23
25
|
|
|
26
|
+
/**
 * Parameters for `Conversation.generateObject`: object-only generation (no tool calls in this run).
 *
 * @typeParam S - type of `schema`; defaults to `unknown` so the type can be referenced without an argument.
 */
export type GenerateObjectParams<S = unknown> = {
  /** Same input contract as generateResponse: plain strings are treated as user messages. */
  messages: (string | ChatCompletionMessageParam)[];

  /** A ready AI SDK model, e.g., openai('gpt-5') / openai('gpt-4o') */
  model: LanguageModel;

  /** Zod schema or plain JSON Schema (plain JSON Schema is strictified for OpenAI strict mode). */
  schema: S;

  /** Sampling temperature, forwarded to the AI SDK call. */
  temperature?: number;
  /** Nucleus-sampling cutoff, forwarded to the AI SDK call. */
  topP?: number;
  /** Output token limit, forwarded to the AI SDK as `maxOutputTokens`. */
  maxTokens?: number;

  /** Invoked with the normalized usage data after generation completes. */
  onUsageData?: (usageData: UsageData) => Promise<void>;

  /**
   * Append the final object, serialized as JSON, to history as assistant text; default true.
   * The input `messages` are recorded to history regardless of this flag.
   */
  recordInHistory?: boolean;
};

/** Result of `Conversation.generateObject`. */
export type GenerateObjectOutcome<T> = {
  /** The generated object (validated by the AI SDK against the schema), typed as `T` by the caller. */
  object: T;
  /** Normalized token usage for the request. */
  usageData: UsageData;
};
|
|
53
|
+
|
|
24
54
|
export class Conversation {
|
|
25
|
-
private tokenLimit =
|
|
55
|
+
private tokenLimit = 50000;
|
|
26
56
|
private history;
|
|
27
57
|
private systemMessages: ChatCompletionMessageParam[] = [];
|
|
28
58
|
private functions: Function[] = [];
|
|
@@ -42,7 +72,7 @@ export class Conversation {
|
|
|
42
72
|
});
|
|
43
73
|
this.logger = new Logger({ name: params.name, logLevel: params.logLevel });
|
|
44
74
|
|
|
45
|
-
if (
|
|
75
|
+
if (params?.limits?.enforceLimits) {
|
|
46
76
|
this.addFunctions('Conversation', [summarizeConversationHistoryFunction(this)]);
|
|
47
77
|
}
|
|
48
78
|
|
|
@@ -137,7 +167,7 @@ export class Conversation {
|
|
|
137
167
|
}
|
|
138
168
|
|
|
139
169
|
private async enforceTokenLimit(messages: (string | ChatCompletionMessageParam)[], model?: TiktokenModel) {
|
|
140
|
-
if (this.params.limits?.enforceLimits
|
|
170
|
+
if (!this.params.limits?.enforceLimits) {
|
|
141
171
|
return;
|
|
142
172
|
}
|
|
143
173
|
|
|
@@ -234,9 +264,12 @@ export class Conversation {
|
|
|
234
264
|
async generateResponse({
|
|
235
265
|
messages,
|
|
236
266
|
model,
|
|
267
|
+
...rest
|
|
237
268
|
}: {
|
|
238
269
|
messages: (string | ChatCompletionMessageParam)[];
|
|
239
270
|
model?: TiktokenModel;
|
|
271
|
+
onUsageData?: (usageData: UsageData) => Promise<void>;
|
|
272
|
+
onToolInvocation?: (evt: ToolInvocationProgressEvent) => void;
|
|
240
273
|
}) {
|
|
241
274
|
await this.ensureModulesProcessed();
|
|
242
275
|
await this.enforceTokenLimit(messages, model);
|
|
@@ -245,7 +278,7 @@ export class Conversation {
|
|
|
245
278
|
functions: this.functions,
|
|
246
279
|
messageModerators: this.messageModerators,
|
|
247
280
|
logLevel: this.params.logLevel,
|
|
248
|
-
}).generateResponse({ messages, model });
|
|
281
|
+
}).generateResponse({ messages, model, ...rest });
|
|
249
282
|
}
|
|
250
283
|
|
|
251
284
|
async generateStreamingResponse({
|
|
@@ -257,6 +290,7 @@ export class Conversation {
|
|
|
257
290
|
model?: TiktokenModel;
|
|
258
291
|
abortSignal?: AbortSignal;
|
|
259
292
|
onUsageData?: (usageData: UsageData) => Promise<void>;
|
|
293
|
+
onToolInvocation?: (evt: ToolInvocationProgressEvent) => void;
|
|
260
294
|
}) {
|
|
261
295
|
await this.ensureModulesProcessed();
|
|
262
296
|
await this.enforceTokenLimit(messages, model);
|
|
@@ -268,6 +302,278 @@ export class Conversation {
|
|
|
268
302
|
}).generateStreamingResponse({ messages, model, ...rest });
|
|
269
303
|
}
|
|
270
304
|
|
|
305
|
+
/**
 * Generate a validated JSON object (no tools in this run).
 *
 * Uses AI SDK `generateObject`, which leverages provider-native structured outputs when available.
 * The prompt is the current history followed by `messages`. After the call succeeds, `messages` are
 * appended to history (strings as user messages). When `recordInHistory` is true, the generated
 * object is also appended as JSON assistant text.
 *
 * @returns the generated object (cast to `T`) and usage data normalized by `processUsageData`
 * @throws whatever the AI SDK throws when generation fails; history is not modified in that case
 */
async generateObject<T>({
  messages,
  model,
  schema,
  temperature,
  topP,
  maxTokens,
  onUsageData,
  recordInHistory = true,
}: GenerateObjectParams<unknown>): Promise<GenerateObjectOutcome<T>> {
  await this.ensureModulesProcessed();

  const combined: ModelMessage[] = [
    ...this.toModelMessages(this.history.getMessages()),
    ...this.toModelMessages(messages),
  ];

  const result = await aiGenerateObject({
    model,
    messages: combined,
    schema: this.normalizeObjectSchema(schema),
    providerOptions: {
      openai: {
        strictJsonSchema: true,
      },
    },
    maxOutputTokens: maxTokens,
    temperature,
    topP,
    // Models sometimes wrap JSON in a markdown fence: strip it, and accept the text only if it then parses.
    experimental_repairText: async ({ text }: any) => {
      const cleaned = String(text ?? '')
        .trim()
        .replace(/^```(?:json)?/i, '')
        .replace(/```$/, '');
      try {
        JSON.parse(cleaned);
        return cleaned;
      } catch {
        return null;
      }
    },
  } as any);

  // Record user messages to history (parity with other methods)
  const chatCompletions: ChatCompletionMessageParam[] = messages.map((m) =>
    typeof m === 'string' ? ({ role: 'user', content: m } as ChatCompletionMessageParam) : m
  );
  this.addMessagesToHistory(chatCompletions);

  // Optionally persist the final JSON in history
  if (recordInHistory) {
    this.recordObjectInHistory(result?.object);
  }

  const usageData = this.processUsageData({
    result,
    model,
  });

  if (onUsageData) {
    await onUsageData(usageData);
  }

  return {
    object: (result?.object ?? ({} as any)) as T,
    usageData,
  };
}

/** Heuristic Zod detection: a `safeParse` method, or a `_def` carrying a string `typeName`. */
private isZodSchema(schema: unknown): boolean {
  const candidate = schema as any;
  return (
    !!candidate &&
    (typeof candidate.safeParse === 'function' ||
      (!!candidate._def && typeof candidate._def.typeName === 'string'))
  );
}

/** Zod schemas pass through unchanged; plain JSON Schema is strictified, then wrapped with `jsonSchema`. */
private normalizeObjectSchema(schema: unknown): any {
  return this.isZodSchema(schema) ? schema : jsonSchema(this.strictifyJsonSchema(schema));
}

/**
 * Append a generated object to history as JSON assistant text.
 * Only non-null objects (arrays included) are recorded. `null` is excluded explicitly because
 * `typeof null === 'object'`; otherwise the literal text "null" would be recorded.
 */
private recordObjectInHistory(object: unknown): void {
  if (object === null || typeof object !== 'object') {
    return;
  }
  try {
    this.addAssistantMessagesToHistory([JSON.stringify(object)]);
  } catch (error: any) {
    // Best-effort: a non-serializable value (e.g. a BigInt from a schema transform) must not fail a successful call.
    this.logger.error({ message: `Failed to record generated object in history`, error });
  }
}
|
|
391
|
+
|
|
392
|
+
/**
 * Convert (string | ChatCompletionMessageParam)[] -> AI SDK ModelMessage[].
 *
 * Every message is flattened to plain text content:
 * - a bare string becomes a user message;
 * - array content keeps only its text-bearing parts (strings, or parts with `text`), joined by newlines,
 *   so non-text parts such as images do not leave blank lines behind;
 * - `system` and `developer` map to `system`, because OpenAI developer messages carry the same kind of
 *   instructions and should not be demoted to user text;
 * - `user` and `assistant` keep their role; any other role (e.g. `tool`) is downgraded to `user` text.
 */
private toModelMessages(input: (string | ChatCompletionMessageParam)[]): ModelMessage[] {
  return input.map((m): ModelMessage => {
    if (typeof m === 'string') {
      return { role: 'user', content: m };
    }

    const text = Array.isArray(m.content)
      ? m.content
          .map((p: any) => (typeof p === 'string' ? p : p?.text ?? ''))
          .filter((t: string) => t !== '')
          .join('\n')
      : (m.content as string | undefined) ?? '';

    // Compare as a plain string so this compiles against openai typings that predate the `developer` role.
    const sourceRole: string = m.role;
    let role: 'system' | 'user' | 'assistant';
    if (sourceRole === 'system' || sourceRole === 'developer') {
      role = 'system';
    } else if (sourceRole === 'assistant') {
      role = 'assistant';
    } else {
      role = 'user';
    }

    return { role, content: text };
  });
}
|
|
405
|
+
|
|
406
|
+
/**
 * Strictifies a plain JSON Schema for OpenAI Structured Outputs (strict mode).
 *
 * Works on a deep copy made by a JSON round-trip. For every reachable subschema it:
 * - adds a missing `type: "object"` / `type: "array"` where implied by keywords;
 * - on objects, sets `additionalProperties: false` and makes `required` include **all** keys in `properties`;
 * - recurses into `properties`, `patternProperties`, `items`, `prefixItems`, `oneOf` / `anyOf` / `allOf`,
 *   `not`, and `$defs` / `definitions`.
 *
 * `$defs` / `definitions` are visited on every node, not only object-typed ones. This way, definitions
 * attached to a combinator root (e.g. `{ anyOf: [...], $defs: {...} }`) are strictified too.
 *
 * Note: forcing every property into `required` makes optional fields mandatory. OpenAI's strict mode
 * expects optional fields to be modeled as a union with `null` in the source schema.
 */
private strictifyJsonSchema(schema: any): any {
  const root = JSON.parse(JSON.stringify(schema));

  // Visit each element of a subschema array (e.g. anyOf, prefixItems); ignores non-arrays.
  function visitEach(children: unknown): void {
    if (Array.isArray(children)) {
      children.forEach((child) => visit(child));
    }
  }

  // Visit each value of a name -> subschema map (e.g. properties, $defs); ignores non-objects.
  function visitValuesOf(container: unknown): void {
    if (container && typeof container === 'object') {
      Object.values(container).forEach((child) => visit(child));
    }
  }

  function visit(node: any): void {
    if (!node || typeof node !== 'object') {
      return;
    }

    // If keywords imply a type but it's missing, add it (helps downstream validators)
    if (!node.type) {
      if (node.properties || node.additionalProperties || node.patternProperties) {
        node.type = 'object';
      } else if (node.items || node.prefixItems) {
        node.type = 'array';
      }
    }

    const types: unknown[] = Array.isArray(node.type) ? node.type : node.type ? [node.type] : [];

    if (types.includes('object')) {
      // Strict mode rejects undeclared keys.
      node.additionalProperties = false;

      // Strict mode requires every declared property to be listed in `required`.
      if (node.properties && typeof node.properties === 'object') {
        const currentReq: string[] = Array.isArray(node.required) ? node.required : [];
        node.required = Array.from(new Set([...currentReq, ...Object.keys(node.properties)]));
        visitValuesOf(node.properties);
      }

      visitValuesOf(node.patternProperties);
    }

    if (types.includes('array')) {
      if (Array.isArray(node.items)) {
        visitEach(node.items);
      } else {
        visit(node.items);
      }
      visitEach(node.prefixItems);
    }

    // Combinators
    for (const k of ['oneOf', 'anyOf', 'allOf']) {
      visitEach(node[k]);
    }

    // Negation
    visit(node.not);

    // Definitions can be referenced via $ref from anywhere, so visit them regardless of this node's type.
    visitValuesOf(node.$defs);
    visitValuesOf(node.definitions);
  }

  visit(root);
  return root;
}
|
|
498
|
+
|
|
499
|
+
// ---- Usage + provider metadata normalization ----

/**
 * Normalize an AI SDK result's token usage into this package's `UsageData` shape.
 *
 * - Token counts come from `result.usage`, falling back to `result.response.usage` and then
 *   `result.response.metadata.usage`. Missing or non-finite counts become 0, and `totalTokens`
 *   falls back to input + output.
 * - Cached input tokens come from the AI SDK's normalized `usage.cachedInputTokens` if present,
 *   then from the OpenAI-specific details found by `extractOpenAiUsageDetails`, and otherwise 0.
 * - The model id comes from `model` (a string id, or an object with `modelId`), then from the
 *   response / provider metadata.
 * - Exactly one request is reported, so `initialRequestTokenUsage` and `totalTokenUsage` are identical.
 */
private processUsageData(args: {
  result: any;
  model?: LanguageModel;
  toolCounts?: Map<string, number>;
  toolLedgerLen?: number;
}): UsageData {
  const { result, model, toolCounts, toolLedgerLen } = args;

  // Try several shapes used by AI SDK / providers
  const u: any = result?.usage ?? result?.response?.usage ?? result?.response?.metadata?.usage;

  const countOr = (value: unknown, fallback: number): number => (Number.isFinite(value) ? Number(value) : fallback);

  const input = countOr(u?.inputTokens, 0);
  const output = countOr(u?.outputTokens, 0);
  const total = countOr(u?.totalTokens, input + output);
  const cached = countOr(u?.cachedInputTokens, countOr(this.extractOpenAiUsageDetails(result).cachedInputTokens, 0));

  // Resolve model id for pricing/telemetry
  const modelId: any =
    (typeof model === 'string' ? model : (model as any)?.modelId) ??
    result?.response?.modelId ??
    result?.response?.providerMetadata?.openai?.model ??
    result?.providerMetadata?.openai?.model ??
    result?.response?.model ??
    undefined;

  const tokenUsage = {
    promptTokens: input,
    cachedPromptTokens: cached,
    completionTokens: output,
    totalTokens: total,
  };

  const callsPerTool = toolCounts ? Object.fromEntries(toolCounts) : {};
  const totalToolCalls =
    typeof toolLedgerLen === 'number' ? toolLedgerLen : Object.values(callsPerTool).reduce((a, b) => a + (b || 0), 0);

  return {
    model: modelId,
    initialRequestTokenUsage: { ...tokenUsage },
    totalTokenUsage: { ...tokenUsage },
    totalRequestsToAssistant: 1,
    totalToolCalls,
    callsPerTool,
  };
}
|
|
548
|
+
|
|
549
|
+
// Pull OpenAI-specific cached/reasoning token counts from provider metadata or raw usage.
// Safe across providers: any count that cannot be found comes back as undefined.
private extractOpenAiUsageDetails(result: any): {
  cachedInputTokens?: number;
  reasoningTokens?: number;
} {
  try {
    const md = result?.providerMetadata?.openai ?? result?.response?.providerMetadata?.openai;
    const usage = md?.usage ?? result?.response?.usage ?? result?.usage;

    // The AI SDK exposes camelCase counts (on OpenAI provider metadata, or on the normalized usage).
    // Raw OpenAI payloads use snake_case detail objects, and their shape has changed over time; try all of them.
    const cachedInputTokens =
      md?.cachedPromptTokens ??
      usage?.cachedInputTokens ??
      usage?.input_tokens_details?.cached_tokens ??
      usage?.prompt_tokens_details?.cached_tokens ??
      usage?.cached_input_tokens;

    // Reasoning tokens (when available on reasoning models)
    const reasoningTokens =
      md?.reasoningTokens ??
      usage?.reasoningTokens ??
      usage?.output_tokens_details?.reasoning_tokens ??
      usage?.reasoning_tokens;

    return {
      cachedInputTokens: typeof cachedInputTokens === 'number' ? cachedInputTokens : undefined,
      reasoningTokens: typeof reasoningTokens === 'number' ? reasoningTokens : undefined,
    };
  } catch {
    // Exotic result objects (e.g. throwing getters) must not break usage reporting.
    return {};
  }
}
|
|
576
|
+
|
|
271
577
|
async generateCode({ description, model }: { description: string[]; model?: TiktokenModel }) {
|
|
272
578
|
this.logger.debug({ message: `Generating code`, obj: { description } });
|
|
273
579
|
await this.ensureModulesProcessed();
|
package/src/OpenAi.ts
CHANGED
|
@@ -22,14 +22,44 @@ function delay(ms: number) {
|
|
|
22
22
|
return new Promise((resolve) => setTimeout(resolve, ms));
|
|
23
23
|
}
|
|
24
24
|
|
|
25
|
+
/**
 * Record of one tool call executed during a single generateResponse loop.
 * Successful calls carry `ok: true` with `data`; failed calls carry `ok: false` with `error`.
 */
export type ToolInvocationResult = {
  /** `tool_call_id` supplied by the model. */
  id: string;
  /** Name of the function that was invoked. */
  name: string;
  /** Start timestamp (on the failure path this is currently the time the error was caught). */
  startedAt: Date;
  /** End timestamp. */
  finishedAt: Date;
  /** Parsed JSON arguments, or the raw argument string when they could not be parsed. */
  input: unknown;
  /** True when the tool returned without throwing. */
  ok: boolean;
  /** Value returned by the tool (success only). */
  data?: unknown;
  /** Failure details (failure only). */
  error?: { message: string; stack?: string };
};

/**
 * Realtime tool-call lifecycle notification.
 * - `started` is emitted before the tool runs and carries the identifying fields of the eventual result.
 * - `finished` carries the complete result record.
 * NOTE(review): a failure raised before execution begins may produce `finished` without a prior `started`.
 */
export type ToolInvocationProgressEvent =
  | ({ type: 'started' } & Pick<ToolInvocationResult, 'id' | 'name' | 'startedAt' | 'input'>)
  | { type: 'finished'; result: ToolInvocationResult };
|
|
50
|
+
|
|
25
51
|
/** Input to `OpenAi.generateResponse`. */
export type GenerateResponseParams = {
  /** Prompt messages; plain strings are accepted alongside full chat-completion message params. */
  messages: Array<string | ChatCompletionMessageParam>;
  /** Model name; when omitted, the instance's configured model is used. */
  model?: TiktokenModel;
  /** Optional realtime hook notified as each tool call starts and finishes. */
  onToolInvocation?: (evt: ToolInvocationProgressEvent) => void;
};

/** Result of a non-streaming `OpenAi.generateResponse` call. */
export type GenerateResponseReturn = {
  /** Final assistant text. */
  message: string;
  /** Usage data from the accumulator threaded through the tool-call loop. */
  usagedata: UsageData;
  /** Every tool call executed while producing `message`, in the order the calls completed. */
  toolInvocations: Array<ToolInvocationResult>;
};
|
|
34
64
|
|
|
35
65
|
export type GenerateStreamingResponseParams = GenerateResponseParams & {
|
|
@@ -42,6 +72,8 @@ type GenerateResponseHelperParams = GenerateStreamingResponseParams & {
|
|
|
42
72
|
stream: boolean;
|
|
43
73
|
currentFunctionCalls?: number;
|
|
44
74
|
usageDataAccumulator?: UsageDataAccumulator;
|
|
75
|
+
/** Accumulated across recursive tool loops. */
|
|
76
|
+
toolInvocations?: ToolInvocationResult[];
|
|
45
77
|
};
|
|
46
78
|
|
|
47
79
|
export type OpenAiParams = {
|
|
@@ -53,7 +85,7 @@ export type OpenAiParams = {
|
|
|
53
85
|
logLevel?: LogLevel;
|
|
54
86
|
};
|
|
55
87
|
|
|
56
|
-
export const DEFAULT_MODEL: TiktokenModel = 'gpt-
|
|
88
|
+
export const DEFAULT_MODEL: TiktokenModel = 'gpt-4o';
|
|
57
89
|
export const DEFAULT_MAX_FUNCTION_CALLS = 50;
|
|
58
90
|
|
|
59
91
|
export class OpenAi {
|
|
@@ -77,7 +109,7 @@ export class OpenAi {
|
|
|
77
109
|
this.functions = functions;
|
|
78
110
|
this.messageModerators = messageModerators;
|
|
79
111
|
this.maxFunctionCalls = maxFunctionCalls;
|
|
80
|
-
this.logLevel = logLevel;
|
|
112
|
+
this.logLevel = (process.env.PROTEINJS_CONVERSATION_OPENAI_LOG_LEVEL as LogLevel | undefined) ?? logLevel;
|
|
81
113
|
}
|
|
82
114
|
|
|
83
115
|
async generateResponse({ model, ...rest }: GenerateResponseParams): Promise<GenerateResponseReturn> {
|
|
@@ -85,11 +117,17 @@ export class OpenAi {
|
|
|
85
117
|
model: model ?? this.model,
|
|
86
118
|
stream: false,
|
|
87
119
|
...rest,
|
|
120
|
+
toolInvocations: [],
|
|
88
121
|
})) as GenerateResponseReturn;
|
|
89
122
|
}
|
|
90
123
|
|
|
91
124
|
/**
 * Streaming counterpart of `generateResponse`: resolves to the `Readable` produced by the shared helper.
 * Uses the instance's model when `model` is omitted, and starts each call with an empty tool-invocation ledger.
 */
async generateStreamingResponse({ model, ...rest }: GenerateStreamingResponseParams): Promise<Readable> {
  const helperParams: GenerateResponseHelperParams = {
    model: model ?? this.model,
    stream: true,
    ...rest,
    toolInvocations: [],
  };
  const readable = await this.generateResponseHelper(helperParams);
  return readable as Readable;
}
|
|
94
132
|
|
|
95
133
|
private async generateResponseHelper({
|
|
@@ -98,8 +136,10 @@ export class OpenAi {
|
|
|
98
136
|
stream,
|
|
99
137
|
abortSignal,
|
|
100
138
|
onUsageData,
|
|
139
|
+
onToolInvocation,
|
|
101
140
|
usageDataAccumulator,
|
|
102
141
|
currentFunctionCalls = 0,
|
|
142
|
+
toolInvocations = [],
|
|
103
143
|
}: GenerateResponseHelperParams): Promise<GenerateResponseReturn | Readable> {
|
|
104
144
|
const logger = new Logger({ name: 'OpenAi.generateResponseHelper', logLevel: this.logLevel });
|
|
105
145
|
this.updateMessageHistory(messages);
|
|
@@ -130,7 +170,9 @@ export class OpenAi {
|
|
|
130
170
|
currentFunctionCalls,
|
|
131
171
|
resolvedUsageDataAccumulator,
|
|
132
172
|
abortSignal,
|
|
133
|
-
onUsageData
|
|
173
|
+
onUsageData,
|
|
174
|
+
toolInvocations,
|
|
175
|
+
onToolInvocation
|
|
134
176
|
)) as (toolCalls: ChatCompletionMessageToolCall[], currentFunctionCalls: number) => Promise<Readable>;
|
|
135
177
|
const streamProcessor = new OpenAiStreamProcessor(
|
|
136
178
|
inputStream,
|
|
@@ -152,7 +194,9 @@ export class OpenAi {
|
|
|
152
194
|
currentFunctionCalls,
|
|
153
195
|
resolvedUsageDataAccumulator,
|
|
154
196
|
abortSignal,
|
|
155
|
-
onUsageData
|
|
197
|
+
onUsageData,
|
|
198
|
+
toolInvocations,
|
|
199
|
+
onToolInvocation
|
|
156
200
|
);
|
|
157
201
|
}
|
|
158
202
|
|
|
@@ -162,7 +206,7 @@ export class OpenAi {
|
|
|
162
206
|
}
|
|
163
207
|
|
|
164
208
|
this.history.push([responseMessage]);
|
|
165
|
-
return { message: responseText, usagedata: resolvedUsageDataAccumulator.usageData };
|
|
209
|
+
return { message: responseText, usagedata: resolvedUsageDataAccumulator.usageData, toolInvocations };
|
|
166
210
|
};
|
|
167
211
|
|
|
168
212
|
// Only wrap in context if this is the first call
|
|
@@ -208,7 +252,6 @@ export class OpenAi {
|
|
|
208
252
|
const response = await openaiApi.chat.completions.create(
|
|
209
253
|
{
|
|
210
254
|
model,
|
|
211
|
-
temperature: 0,
|
|
212
255
|
messages: this.history.getMessages(),
|
|
213
256
|
...(this.functions &&
|
|
214
257
|
this.functions.length > 0 && {
|
|
@@ -310,7 +353,9 @@ export class OpenAi {
|
|
|
310
353
|
currentFunctionCalls: number,
|
|
311
354
|
usageDataAccumulator: UsageDataAccumulator,
|
|
312
355
|
abortSignal?: AbortSignal,
|
|
313
|
-
onUsageData?: (usageData: UsageData) => Promise<void
|
|
356
|
+
onUsageData?: (usageData: UsageData) => Promise<void>,
|
|
357
|
+
toolInvocations: ToolInvocationResult[] = [],
|
|
358
|
+
onToolInvocation?: (evt: ToolInvocationProgressEvent) => void
|
|
314
359
|
): Promise<GenerateResponseReturn | Readable> {
|
|
315
360
|
if (currentFunctionCalls >= this.maxFunctionCalls) {
|
|
316
361
|
throw new Error(`Max function calls (${this.maxFunctionCalls}) reached. Stopping execution.`);
|
|
@@ -327,7 +372,7 @@ export class OpenAi {
|
|
|
327
372
|
this.history.push([toolCallMessage]);
|
|
328
373
|
|
|
329
374
|
// Call the tools and get the responses
|
|
330
|
-
const toolMessageParams = await this.callTools(toolCalls, usageDataAccumulator);
|
|
375
|
+
const toolMessageParams = await this.callTools(toolCalls, usageDataAccumulator, toolInvocations, onToolInvocation);
|
|
331
376
|
|
|
332
377
|
// Add the tool responses to the history
|
|
333
378
|
this.history.push(toolMessageParams);
|
|
@@ -339,18 +384,31 @@ export class OpenAi {
|
|
|
339
384
|
stream,
|
|
340
385
|
abortSignal,
|
|
341
386
|
onUsageData,
|
|
387
|
+
onToolInvocation,
|
|
342
388
|
usageDataAccumulator,
|
|
343
389
|
currentFunctionCalls: currentFunctionCalls + toolCalls.length,
|
|
390
|
+
toolInvocations,
|
|
344
391
|
});
|
|
345
392
|
}
|
|
346
393
|
|
|
347
394
|
private async callTools(
|
|
348
395
|
toolCalls: ChatCompletionMessageToolCall[],
|
|
349
|
-
usageDataAccumulator: UsageDataAccumulator
|
|
396
|
+
usageDataAccumulator: UsageDataAccumulator,
|
|
397
|
+
toolInvocations: ToolInvocationResult[],
|
|
398
|
+
onToolInvocation?: (evt: ToolInvocationProgressEvent) => void
|
|
350
399
|
): Promise<ChatCompletionMessageParam[]> {
|
|
351
400
|
const toolMessageParams: ChatCompletionMessageParam[] = (
|
|
352
401
|
await Promise.all(
|
|
353
|
-
toolCalls.map(
|
|
402
|
+
toolCalls.map(
|
|
403
|
+
async (toolCall) =>
|
|
404
|
+
await this.callFunction(
|
|
405
|
+
toolCall.function,
|
|
406
|
+
toolCall.id,
|
|
407
|
+
usageDataAccumulator,
|
|
408
|
+
toolInvocations,
|
|
409
|
+
onToolInvocation
|
|
410
|
+
)
|
|
411
|
+
)
|
|
354
412
|
)
|
|
355
413
|
).reduce((acc, val) => acc.concat(val), []);
|
|
356
414
|
|
|
@@ -360,7 +418,9 @@ export class OpenAi {
|
|
|
360
418
|
private async callFunction(
|
|
361
419
|
functionCall: ChatCompletionMessageToolCall.Function,
|
|
362
420
|
toolCallId: string,
|
|
363
|
-
usageDataAccumulator: UsageDataAccumulator
|
|
421
|
+
usageDataAccumulator: UsageDataAccumulator,
|
|
422
|
+
toolInvocations: ToolInvocationResult[],
|
|
423
|
+
onToolInvocation?: (evt: ToolInvocationProgressEvent) => void
|
|
364
424
|
): Promise<ChatCompletionMessageParam[]> {
|
|
365
425
|
const logger = new Logger({ name: 'OpenAi.callFunction', logLevel: this.logLevel });
|
|
366
426
|
if (!this.functions) {
|
|
@@ -370,7 +430,7 @@ export class OpenAi {
|
|
|
370
430
|
}
|
|
371
431
|
|
|
372
432
|
functionCall.name = functionCall.name.split('.').pop() as string;
|
|
373
|
-
const f = this.functions.find((
|
|
433
|
+
const f = this.functions.find((fx) => fx.definition.name === functionCall.name);
|
|
374
434
|
if (!f) {
|
|
375
435
|
const errorMessage = `Assistant attempted to call nonexistent function`;
|
|
376
436
|
logger.error({ message: errorMessage, obj: { functionName: functionCall.name } });
|
|
@@ -390,7 +450,33 @@ export class OpenAi {
|
|
|
390
450
|
obj: { toolCallId, functionName: f.definition.name, args: parsedArguments },
|
|
391
451
|
});
|
|
392
452
|
usageDataAccumulator.recordToolCall(f.definition.name);
|
|
453
|
+
|
|
454
|
+
const startedAt = new Date();
|
|
455
|
+
|
|
456
|
+
onToolInvocation?.({
|
|
457
|
+
type: 'started',
|
|
458
|
+
id: toolCallId,
|
|
459
|
+
name: f.definition.name,
|
|
460
|
+
startedAt,
|
|
461
|
+
input: parsedArguments,
|
|
462
|
+
});
|
|
463
|
+
|
|
393
464
|
const returnObject = await f.call(parsedArguments);
|
|
465
|
+
const finishedAt = new Date();
|
|
466
|
+
|
|
467
|
+
// Record success
|
|
468
|
+
const rec: ToolInvocationResult = {
|
|
469
|
+
id: toolCallId,
|
|
470
|
+
name: f.definition.name,
|
|
471
|
+
startedAt,
|
|
472
|
+
finishedAt,
|
|
473
|
+
input: parsedArguments,
|
|
474
|
+
ok: true,
|
|
475
|
+
data: returnObject,
|
|
476
|
+
};
|
|
477
|
+
toolInvocations.push(rec);
|
|
478
|
+
|
|
479
|
+
onToolInvocation?.({ type: 'finished', result: rec });
|
|
394
480
|
|
|
395
481
|
const returnObjectCompletionParams: ChatCompletionMessageParam[] = [];
|
|
396
482
|
if (isInstanceOf(returnObject, ChatCompletionMessageParamFactory)) {
|
|
@@ -433,11 +519,35 @@ export class OpenAi {
|
|
|
433
519
|
|
|
434
520
|
return returnObjectCompletionParams;
|
|
435
521
|
} catch (error: any) {
|
|
522
|
+
const now = new Date();
|
|
523
|
+
const attemptedArgs = (() => {
|
|
524
|
+
try {
|
|
525
|
+
return JSON.parse(functionCall.arguments);
|
|
526
|
+
} catch {
|
|
527
|
+
return functionCall.arguments;
|
|
528
|
+
}
|
|
529
|
+
})();
|
|
530
|
+
|
|
531
|
+
// Record failure
|
|
532
|
+
const rec: ToolInvocationResult = {
|
|
533
|
+
id: toolCallId,
|
|
534
|
+
name: functionCall.name,
|
|
535
|
+
startedAt: now,
|
|
536
|
+
finishedAt: now,
|
|
537
|
+
input: attemptedArgs,
|
|
538
|
+
ok: false,
|
|
539
|
+
error: { message: String(error?.message ?? error), stack: (error as any)?.stack },
|
|
540
|
+
};
|
|
541
|
+
toolInvocations.push(rec);
|
|
542
|
+
|
|
543
|
+
onToolInvocation?.({ type: 'finished', result: rec });
|
|
544
|
+
|
|
436
545
|
logger.error({
|
|
437
546
|
message: `An error occurred while executing function`,
|
|
438
547
|
error,
|
|
439
548
|
obj: { toolCallId, functionName: f.definition.name },
|
|
440
549
|
});
|
|
550
|
+
|
|
441
551
|
throw error;
|
|
442
552
|
}
|
|
443
553
|
}
|
|
@@ -6,7 +6,8 @@ import { ConversationFsModerator } from './ConversationFsModerator';
|
|
|
6
6
|
import {
|
|
7
7
|
fsFunctions,
|
|
8
8
|
getRecentlyAccessedFilePathsFunction,
|
|
9
|
-
|
|
9
|
+
grepFunction,
|
|
10
|
+
grepFunctionName,
|
|
10
11
|
readFilesFunction,
|
|
11
12
|
readFilesFunctionName,
|
|
12
13
|
writeFilesFunction,
|
|
@@ -36,7 +37,7 @@ export class ConversationFsModule implements ConversationModule {
|
|
|
36
37
|
`When reading/writing a file in a specified package, join the package directory with the relative path to form the file path`,
|
|
37
38
|
`When searching for source files, do not look in the dist or node_modules directories`,
|
|
38
39
|
`If you don't know a file path, don't try to guess it, use the ${searchFilesFunctionName} function to find it`,
|
|
39
|
-
`When searching for something (ie. a file to work with/in), unless more context is specified, use the ${
|
|
40
|
+
`When searching for something (ie. a file to work with/in), unless more context is specified, use the ${grepFunctionName} function first, then fall back to functions: ${searchPackagesFunctionName}, ${searchFilesFunctionName}`,
|
|
40
41
|
`After finding a file to work with, assume the user's following question pertains to that file and use ${readFilesFunctionName} to read the file if needed`,
|
|
41
42
|
];
|
|
42
43
|
}
|
|
@@ -46,6 +47,7 @@ export class ConversationFsModule implements ConversationModule {
|
|
|
46
47
|
readFilesFunction(this),
|
|
47
48
|
writeFilesFunction(this),
|
|
48
49
|
getRecentlyAccessedFilePathsFunction(this),
|
|
50
|
+
grepFunction(this),
|
|
49
51
|
...fsFunctions,
|
|
50
52
|
];
|
|
51
53
|
}
|
|
@@ -61,6 +63,10 @@ export class ConversationFsModule implements ConversationModule {
|
|
|
61
63
|
getRecentlyAccessedFilePaths() {
|
|
62
64
|
return this.recentlyAccessedFilePaths;
|
|
63
65
|
}
|
|
66
|
+
|
|
67
|
+
/** Returns the repository path held by this module. */
getRepoPath(): string {
  const { repoPath } = this;
  return repoPath;
}
|
|
64
70
|
}
|
|
65
71
|
|
|
66
72
|
export class ConversationFsModuleFactory implements ConversationModuleFactory {
|