@jackchen_me/open-multi-agent 0.2.0 → 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104) hide show
  1. package/.github/workflows/ci.yml +1 -1
  2. package/CLAUDE.md +11 -3
  3. package/README.md +87 -20
  4. package/README_zh.md +85 -25
  5. package/dist/agent/agent.d.ts +15 -1
  6. package/dist/agent/agent.d.ts.map +1 -1
  7. package/dist/agent/agent.js +144 -10
  8. package/dist/agent/agent.js.map +1 -1
  9. package/dist/agent/loop-detector.d.ts +39 -0
  10. package/dist/agent/loop-detector.d.ts.map +1 -0
  11. package/dist/agent/loop-detector.js +122 -0
  12. package/dist/agent/loop-detector.js.map +1 -0
  13. package/dist/agent/pool.d.ts +2 -1
  14. package/dist/agent/pool.d.ts.map +1 -1
  15. package/dist/agent/pool.js +4 -2
  16. package/dist/agent/pool.js.map +1 -1
  17. package/dist/agent/runner.d.ts +23 -1
  18. package/dist/agent/runner.d.ts.map +1 -1
  19. package/dist/agent/runner.js +113 -12
  20. package/dist/agent/runner.js.map +1 -1
  21. package/dist/index.d.ts +3 -1
  22. package/dist/index.d.ts.map +1 -1
  23. package/dist/index.js +2 -0
  24. package/dist/index.js.map +1 -1
  25. package/dist/llm/adapter.d.ts +4 -1
  26. package/dist/llm/adapter.d.ts.map +1 -1
  27. package/dist/llm/adapter.js +11 -0
  28. package/dist/llm/adapter.js.map +1 -1
  29. package/dist/llm/copilot.d.ts.map +1 -1
  30. package/dist/llm/copilot.js +2 -1
  31. package/dist/llm/copilot.js.map +1 -1
  32. package/dist/llm/gemini.d.ts +65 -0
  33. package/dist/llm/gemini.d.ts.map +1 -0
  34. package/dist/llm/gemini.js +317 -0
  35. package/dist/llm/gemini.js.map +1 -0
  36. package/dist/llm/grok.d.ts +21 -0
  37. package/dist/llm/grok.d.ts.map +1 -0
  38. package/dist/llm/grok.js +24 -0
  39. package/dist/llm/grok.js.map +1 -0
  40. package/dist/llm/openai-common.d.ts +8 -1
  41. package/dist/llm/openai-common.d.ts.map +1 -1
  42. package/dist/llm/openai-common.js +35 -2
  43. package/dist/llm/openai-common.js.map +1 -1
  44. package/dist/llm/openai.d.ts +1 -1
  45. package/dist/llm/openai.d.ts.map +1 -1
  46. package/dist/llm/openai.js +20 -2
  47. package/dist/llm/openai.js.map +1 -1
  48. package/dist/orchestrator/orchestrator.d.ts.map +1 -1
  49. package/dist/orchestrator/orchestrator.js +89 -9
  50. package/dist/orchestrator/orchestrator.js.map +1 -1
  51. package/dist/task/queue.d.ts +31 -2
  52. package/dist/task/queue.d.ts.map +1 -1
  53. package/dist/task/queue.js +69 -2
  54. package/dist/task/queue.js.map +1 -1
  55. package/dist/tool/text-tool-extractor.d.ts +32 -0
  56. package/dist/tool/text-tool-extractor.d.ts.map +1 -0
  57. package/dist/tool/text-tool-extractor.js +187 -0
  58. package/dist/tool/text-tool-extractor.js.map +1 -0
  59. package/dist/types.d.ts +139 -7
  60. package/dist/types.d.ts.map +1 -1
  61. package/dist/utils/trace.d.ts +12 -0
  62. package/dist/utils/trace.d.ts.map +1 -0
  63. package/dist/utils/trace.js +30 -0
  64. package/dist/utils/trace.js.map +1 -0
  65. package/examples/06-local-model.ts +1 -0
  66. package/examples/08-gemma4-local.ts +76 -87
  67. package/examples/09-structured-output.ts +73 -0
  68. package/examples/10-task-retry.ts +132 -0
  69. package/examples/11-trace-observability.ts +133 -0
  70. package/examples/12-grok.ts +154 -0
  71. package/examples/13-gemini.ts +48 -0
  72. package/package.json +11 -1
  73. package/src/agent/agent.ts +159 -10
  74. package/src/agent/loop-detector.ts +137 -0
  75. package/src/agent/pool.ts +9 -2
  76. package/src/agent/runner.ts +148 -19
  77. package/src/index.ts +15 -0
  78. package/src/llm/adapter.ts +12 -1
  79. package/src/llm/copilot.ts +2 -1
  80. package/src/llm/gemini.ts +378 -0
  81. package/src/llm/grok.ts +29 -0
  82. package/src/llm/openai-common.ts +41 -2
  83. package/src/llm/openai.ts +23 -3
  84. package/src/orchestrator/orchestrator.ts +105 -11
  85. package/src/task/queue.ts +73 -3
  86. package/src/tool/text-tool-extractor.ts +219 -0
  87. package/src/types.ts +157 -6
  88. package/src/utils/trace.ts +34 -0
  89. package/tests/agent-hooks.test.ts +473 -0
  90. package/tests/agent-pool.test.ts +212 -0
  91. package/tests/approval.test.ts +464 -0
  92. package/tests/built-in-tools.test.ts +393 -0
  93. package/tests/gemini-adapter.test.ts +97 -0
  94. package/tests/grok-adapter.test.ts +74 -0
  95. package/tests/llm-adapters.test.ts +357 -0
  96. package/tests/loop-detection.test.ts +456 -0
  97. package/tests/openai-fallback.test.ts +159 -0
  98. package/tests/orchestrator.test.ts +281 -0
  99. package/tests/scheduler.test.ts +221 -0
  100. package/tests/team-messaging.test.ts +329 -0
  101. package/tests/text-tool-extractor.test.ts +170 -0
  102. package/tests/trace.test.ts +453 -0
  103. package/vitest.config.ts +9 -0
  104. package/examples/09-gemma4-auto-orchestration.ts +0 -162
@@ -52,8 +52,10 @@ import type {
52
52
  TeamRunResult,
53
53
  TokenUsage,
54
54
  } from '../types.js'
55
+ import type { RunOptions } from '../agent/runner.js'
55
56
  import { Agent } from '../agent/agent.js'
56
57
  import { AgentPool } from '../agent/pool.js'
58
+ import { emitTrace, generateRunId } from '../utils/trace.js'
57
59
  import { ToolRegistry } from '../tool/framework.js'
58
60
  import { ToolExecutor } from '../tool/executor.js'
59
61
  import { registerBuiltInTools } from '../tool/built-in/index.js'
@@ -128,9 +130,10 @@ export async function executeWithRetry(
128
130
  onRetry?: (data: { attempt: number; maxAttempts: number; error: string; nextDelayMs: number }) => void,
129
131
  delayFn: (ms: number) => Promise<void> = sleep,
130
132
  ): Promise<AgentRunResult> {
131
- const maxAttempts = Math.max(0, task.maxRetries ?? 0) + 1
132
- const baseDelay = Math.max(0, task.retryDelayMs ?? 1000)
133
- const backoff = Math.max(1, task.retryBackoff ?? 2)
133
+ const rawRetries = Number.isFinite(task.maxRetries) ? task.maxRetries! : 0
134
+ const maxAttempts = Math.max(0, rawRetries) + 1
135
+ const baseDelay = Math.max(0, Number.isFinite(task.retryDelayMs) ? task.retryDelayMs! : 1000)
136
+ const backoff = Math.max(1, Number.isFinite(task.retryBackoff) ? task.retryBackoff! : 2)
134
137
 
135
138
  let lastError: string = ''
136
139
  // Accumulate token usage across all attempts so billing/observability
@@ -259,6 +262,8 @@ interface RunContext {
259
262
  readonly scheduler: Scheduler
260
263
  readonly agentResults: Map<string, AgentRunResult>
261
264
  readonly config: OrchestratorConfig
265
+ /** Trace run ID, present when `onTrace` is configured. */
266
+ readonly runId?: string
262
267
  }
263
268
 
264
269
  /**
@@ -278,6 +283,17 @@ async function executeQueue(
278
283
  ): Promise<void> {
279
284
  const { team, pool, scheduler, config } = ctx
280
285
 
286
+ // Relay queue-level skip events to the orchestrator's onProgress callback.
287
+ const unsubSkipped = config.onProgress
288
+ ? queue.on('task:skipped', (task) => {
289
+ config.onProgress!({
290
+ type: 'task_skipped',
291
+ task: task.id,
292
+ data: task,
293
+ } satisfies OrchestratorEvent)
294
+ })
295
+ : undefined
296
+
281
297
  while (true) {
282
298
  // Re-run auto-assignment each iteration so tasks that were unblocked since
283
299
  // the last round (and thus have no assignee yet) get assigned before dispatch.
@@ -289,6 +305,11 @@ async function executeQueue(
289
305
  break
290
306
  }
291
307
 
308
+ // Track tasks that complete successfully in this round for the approval gate.
309
+ // Safe to push from concurrent promises: JS is single-threaded, so
310
+ // Array.push calls from resolved microtasks never interleave.
311
+ const completedThisRound: Task[] = []
312
+
292
313
  // Dispatch all currently-pending tasks as a parallel batch.
293
314
  const dispatchPromises = pending.map(async (task): Promise<void> => {
294
315
  // Mark in-progress
@@ -337,10 +358,19 @@ async function executeQueue(
337
358
  // Build the prompt: inject shared memory context + task description
338
359
  const prompt = await buildTaskPrompt(task, team)
339
360
 
361
+ // Build trace context for this task's agent run
362
+ const traceOptions: Partial<RunOptions> | undefined = config.onTrace
363
+ ? { onTrace: config.onTrace, runId: ctx.runId ?? '', taskId: task.id, traceAgent: assignee }
364
+ : undefined
365
+
366
+ const taskStartMs = config.onTrace ? Date.now() : 0
367
+ let retryCount = 0
368
+
340
369
  const result = await executeWithRetry(
341
- () => pool.run(assignee, prompt),
370
+ () => pool.run(assignee, prompt, traceOptions),
342
371
  task,
343
372
  (retryData) => {
373
+ retryCount++
344
374
  config.onProgress?.({
345
375
  type: 'task_retry',
346
376
  task: task.id,
@@ -350,6 +380,23 @@ async function executeQueue(
350
380
  },
351
381
  )
352
382
 
383
+ // Emit task trace
384
+ if (config.onTrace) {
385
+ const taskEndMs = Date.now()
386
+ emitTrace(config.onTrace, {
387
+ type: 'task',
388
+ runId: ctx.runId ?? '',
389
+ taskId: task.id,
390
+ taskTitle: task.title,
391
+ agent: assignee,
392
+ success: result.success,
393
+ retries: retryCount,
394
+ startMs: taskStartMs,
395
+ endMs: taskEndMs,
396
+ durationMs: taskEndMs - taskStartMs,
397
+ })
398
+ }
399
+
353
400
  ctx.agentResults.set(`${assignee}:${task.id}`, result)
354
401
 
355
402
  if (result.success) {
@@ -359,7 +406,8 @@ async function executeQueue(
359
406
  await sharedMem.write(assignee, `task:${task.id}:result`, result.output)
360
407
  }
361
408
 
362
- queue.complete(task.id, result.output)
409
+ const completedTask = queue.complete(task.id, result.output)
410
+ completedThisRound.push(completedTask)
363
411
 
364
412
  config.onProgress?.({
365
413
  type: 'task_complete',
@@ -387,7 +435,32 @@ async function executeQueue(
387
435
 
388
436
  // Wait for the entire parallel batch before checking for newly-unblocked tasks.
389
437
  await Promise.all(dispatchPromises)
438
+
439
+ // --- Approval gate ---
440
+ // After the batch completes, check if the caller wants to approve
441
+ // the next round before it starts.
442
+ if (config.onApproval && completedThisRound.length > 0) {
443
+ scheduler.autoAssign(queue, team.getAgents())
444
+ const nextPending = queue.getByStatus('pending')
445
+
446
+ if (nextPending.length > 0) {
447
+ let approved: boolean
448
+ try {
449
+ approved = await config.onApproval(completedThisRound, nextPending)
450
+ } catch (err) {
451
+ const reason = `Skipped: approval callback error — ${err instanceof Error ? err.message : String(err)}`
452
+ queue.skipRemaining(reason)
453
+ break
454
+ }
455
+ if (!approved) {
456
+ queue.skipRemaining('Skipped: approval rejected.')
457
+ break
458
+ }
459
+ }
460
+ }
390
461
  }
462
+
463
+ unsubSkipped?.()
391
464
  }
392
465
 
393
466
  /**
@@ -440,8 +513,8 @@ async function buildTaskPrompt(task: Task, team: Team): Promise<string> {
440
513
  */
441
514
  export class OpenMultiAgent {
442
515
  private readonly config: Required<
443
- Omit<OrchestratorConfig, 'onProgress' | 'defaultBaseURL' | 'defaultApiKey'>
444
- > & Pick<OrchestratorConfig, 'onProgress' | 'defaultBaseURL' | 'defaultApiKey'>
516
+ Omit<OrchestratorConfig, 'onApproval' | 'onProgress' | 'onTrace' | 'defaultBaseURL' | 'defaultApiKey'>
517
+ > & Pick<OrchestratorConfig, 'onApproval' | 'onProgress' | 'onTrace' | 'defaultBaseURL' | 'defaultApiKey'>
445
518
 
446
519
  private readonly teams: Map<string, Team> = new Map()
447
520
  private completedTaskCount = 0
@@ -461,7 +534,9 @@ export class OpenMultiAgent {
461
534
  defaultProvider: config.defaultProvider ?? 'anthropic',
462
535
  defaultBaseURL: config.defaultBaseURL,
463
536
  defaultApiKey: config.defaultApiKey,
537
+ onApproval: config.onApproval,
464
538
  onProgress: config.onProgress,
539
+ onTrace: config.onTrace,
465
540
  }
466
541
  }
467
542
 
@@ -519,7 +594,11 @@ export class OpenMultiAgent {
519
594
  data: { prompt },
520
595
  })
521
596
 
522
- const result = await agent.run(prompt)
597
+ const traceOptions: Partial<RunOptions> | undefined = this.config.onTrace
598
+ ? { onTrace: this.config.onTrace, runId: generateRunId(), traceAgent: config.name }
599
+ : undefined
600
+
601
+ const result = await agent.run(prompt, traceOptions)
523
602
 
524
603
  this.config.onProgress?.({
525
604
  type: 'agent_complete',
@@ -577,6 +656,7 @@ export class OpenMultiAgent {
577
656
 
578
657
  const decompositionPrompt = this.buildDecompositionPrompt(goal, agentConfigs)
579
658
  const coordinatorAgent = buildAgent(coordinatorConfig)
659
+ const runId = this.config.onTrace ? generateRunId() : undefined
580
660
 
581
661
  this.config.onProgress?.({
582
662
  type: 'agent_start',
@@ -584,7 +664,10 @@ export class OpenMultiAgent {
584
664
  data: { phase: 'decomposition', goal },
585
665
  })
586
666
 
587
- const decompositionResult = await coordinatorAgent.run(decompositionPrompt)
667
+ const decompTraceOptions: Partial<RunOptions> | undefined = this.config.onTrace
668
+ ? { onTrace: this.config.onTrace, runId: runId ?? '', traceAgent: 'coordinator' }
669
+ : undefined
670
+ const decompositionResult = await coordinatorAgent.run(decompositionPrompt, decompTraceOptions)
588
671
  const agentResults = new Map<string, AgentRunResult>()
589
672
  agentResults.set('coordinator:decompose', decompositionResult)
590
673
 
@@ -628,6 +711,7 @@ export class OpenMultiAgent {
628
711
  scheduler,
629
712
  agentResults,
630
713
  config: this.config,
714
+ runId,
631
715
  }
632
716
 
633
717
  await executeQueue(queue, ctx)
@@ -636,7 +720,10 @@ export class OpenMultiAgent {
636
720
  // Step 5: Coordinator synthesises final result
637
721
  // ------------------------------------------------------------------
638
722
  const synthesisPrompt = await this.buildSynthesisPrompt(goal, queue.list(), team)
639
- const synthesisResult = await coordinatorAgent.run(synthesisPrompt)
723
+ const synthTraceOptions: Partial<RunOptions> | undefined = this.config.onTrace
724
+ ? { onTrace: this.config.onTrace, runId: runId ?? '', traceAgent: 'coordinator' }
725
+ : undefined
726
+ const synthesisResult = await coordinatorAgent.run(synthesisPrompt, synthTraceOptions)
640
727
  agentResults.set('coordinator', synthesisResult)
641
728
 
642
729
  this.config.onProgress?.({
@@ -706,6 +793,7 @@ export class OpenMultiAgent {
706
793
  scheduler,
707
794
  agentResults,
708
795
  config: this.config,
796
+ runId: this.config.onTrace ? generateRunId() : undefined,
709
797
  }
710
798
 
711
799
  await executeQueue(queue, ctx)
@@ -809,6 +897,7 @@ export class OpenMultiAgent {
809
897
  ): Promise<string> {
810
898
  const completedTasks = tasks.filter((t) => t.status === 'completed')
811
899
  const failedTasks = tasks.filter((t) => t.status === 'failed')
900
+ const skippedTasks = tasks.filter((t) => t.status === 'skipped')
812
901
 
813
902
  const resultSections = completedTasks.map((t) => {
814
903
  const assignee = t.assignee ?? 'unknown'
@@ -819,6 +908,10 @@ export class OpenMultiAgent {
819
908
  (t) => `### ${t.title} (FAILED)\nError: ${t.result ?? 'unknown error'}`,
820
909
  )
821
910
 
911
+ const skippedSections = skippedTasks.map(
912
+ (t) => `### ${t.title} (SKIPPED)\nReason: ${t.result ?? 'approval rejected'}`,
913
+ )
914
+
822
915
  // Also include shared memory summary for additional context
823
916
  let memorySummary = ''
824
917
  const sharedMem = team.getSharedMemoryInstance()
@@ -833,11 +926,12 @@ export class OpenMultiAgent {
833
926
  `## Task Results`,
834
927
  ...resultSections,
835
928
  ...(failureSections.length > 0 ? ['', '## Failed Tasks', ...failureSections] : []),
929
+ ...(skippedSections.length > 0 ? ['', '## Skipped Tasks', ...skippedSections] : []),
836
930
  ...(memorySummary ? ['', memorySummary] : []),
837
931
  '',
838
932
  '## Your Task',
839
933
  'Synthesise the above results into a comprehensive final answer that addresses the original goal.',
840
- 'If some tasks failed, note any gaps in the result.',
934
+ 'If some tasks failed or were skipped, note any gaps in the result.',
841
935
  ].join('\n')
842
936
  }
843
937
 
package/src/task/queue.ts CHANGED
@@ -18,6 +18,7 @@ export type TaskQueueEvent =
18
18
  | 'task:ready'
19
19
  | 'task:complete'
20
20
  | 'task:failed'
21
+ | 'task:skipped'
21
22
  | 'all:complete'
22
23
 
23
24
  /** Handler for `'task:ready' | 'task:complete' | 'task:failed'` events. */
@@ -156,6 +157,51 @@ export class TaskQueue {
156
157
  return failed
157
158
  }
158
159
 
160
/**
 * Transitions `taskId` to `'skipped'`, storing `reason` in its `result` field.
 *
 * A `'task:skipped'` event is emitted for the task itself, after which the
 * skip is propagated through `cascadeSkip()` to every task that transitively
 * depends on it — regardless of whether those dependents also have other
 * dependencies that are pending or completed. A skipped upstream is treated
 * as permanently unsatisfiable, mirroring `fail()`.
 *
 * When the queue reports completion afterwards, `emitAllComplete()` is called.
 *
 * @param taskId - ID of the task to skip.
 * @param reason - Explanation recorded as the task's `result`.
 * @returns The task object returned by `update()`.
 * @throws {Error} when `taskId` is not found.
 */
skip(taskId: string, reason: string): Task {
  const updated = this.update(taskId, { status: 'skipped', result: reason })
  this.emit('task:skipped', updated)

  // Anything waiting on this task can no longer run.
  this.cascadeSkip(taskId)

  const queueFinished = this.isComplete()
  if (queueFinished) this.emitAllComplete()

  return updated
}
179
+
180
/**
 * Skips every task that has not reached a terminal status.
 *
 * Intended for an approval gate that rejects continuation: each pending,
 * blocked, or in-progress task is marked `'skipped'` with `reason` and a
 * `'task:skipped'` event is emitted for it.
 *
 * **Important:** Call only when no tasks are actively executing. The
 * orchestrator invokes this after `await Promise.all()`, so no tasks are
 * in-flight. Calling while agents are running may mark an in-progress task
 * as skipped while its agent continues executing.
 *
 * @param reason - Text recorded as each skipped task's `result`.
 */
skipRemaining(reason = 'Skipped: approval rejected.'): void {
  // Work from a copy of the values: update() writes to the live map, and
  // iterating it while it is being modified is unsafe.
  const tasksAtCall = [...this.tasks.values()]

  for (const candidate of tasksAtCall) {
    switch (candidate.status) {
      case 'completed':
      case 'failed':
      case 'skipped':
        // Already terminal — leave untouched.
        continue
    }

    const updated = this.update(candidate.id, { status: 'skipped', result: reason })
    this.emit('task:skipped', updated)
  }

  const queueFinished = this.isComplete()
  if (queueFinished) this.emitAllComplete()
}
204
+
159
205
  /**
160
206
  * Recursively marks all tasks that (transitively) depend on `failedTaskId`
161
207
  * as `'failed'` with an informative message, firing `'task:failed'` for each.
@@ -178,6 +224,24 @@ export class TaskQueue {
178
224
  }
179
225
  }
180
226
 
227
/**
 * Propagates a skip downstream: every blocked or pending task that lists
 * `skippedTaskId` in `dependsOn` is marked `'skipped'`, a `'task:skipped'`
 * event is emitted for it, and the cascade then recurses from that task so
 * transitive dependents are covered as well.
 */
private cascadeSkip(skippedTaskId: string): void {
  for (const candidate of this.tasks.values()) {
    const awaitingRun = candidate.status === 'blocked' || candidate.status === 'pending'
    const dependsOnSkipped =
      awaitingRun && (candidate.dependsOn?.includes(skippedTaskId) ?? false)
    if (!dependsOnSkipped) continue

    const updated = this.update(candidate.id, {
      status: 'skipped',
      result: `Skipped: dependency "${skippedTaskId}" was skipped.`,
    })
    this.emit('task:skipped', updated)

    // Continue down the chain from the task that was just skipped.
    this.cascadeSkip(candidate.id)
  }
}
244
+
181
245
  // ---------------------------------------------------------------------------
182
246
  // Queries
183
247
  // ---------------------------------------------------------------------------
@@ -227,11 +291,11 @@ export class TaskQueue {
227
291
 
228
292
  /**
229
293
  * Returns `true` when every task in the queue has reached a terminal state
230
- * (`'completed'` or `'failed'`), **or** the queue is empty.
294
+ * (`'completed'`, `'failed'`, or `'skipped'`), **or** the queue is empty.
231
295
  */
232
296
  isComplete(): boolean {
233
297
  for (const task of this.tasks.values()) {
234
- if (task.status !== 'completed' && task.status !== 'failed') return false
298
+ if (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'skipped') return false
235
299
  }
236
300
  return true
237
301
  }
@@ -249,12 +313,14 @@ export class TaskQueue {
249
313
  total: number
250
314
  completed: number
251
315
  failed: number
316
+ skipped: number
252
317
  inProgress: number
253
318
  pending: number
254
319
  blocked: number
255
320
  } {
256
321
  let completed = 0
257
322
  let failed = 0
323
+ let skipped = 0
258
324
  let inProgress = 0
259
325
  let pending = 0
260
326
  let blocked = 0
@@ -267,6 +333,9 @@ export class TaskQueue {
267
333
  case 'failed':
268
334
  failed++
269
335
  break
336
+ case 'skipped':
337
+ skipped++
338
+ break
270
339
  case 'in_progress':
271
340
  inProgress++
272
341
  break
@@ -283,6 +352,7 @@ export class TaskQueue {
283
352
  total: this.tasks.size,
284
353
  completed,
285
354
  failed,
355
+ skipped,
286
356
  inProgress,
287
357
  pending,
288
358
  blocked,
@@ -370,7 +440,7 @@ export class TaskQueue {
370
440
  }
371
441
  }
372
442
 
373
- private emit(event: 'task:ready' | 'task:complete' | 'task:failed', task: Task): void {
443
+ private emit(event: 'task:ready' | 'task:complete' | 'task:failed' | 'task:skipped', task: Task): void {
374
444
  const map = this.listeners.get(event)
375
445
  if (!map) return
376
446
  for (const handler of map.values()) {
@@ -0,0 +1,219 @@
1
+ /**
2
+ * @fileoverview Fallback tool-call extractor for local models.
3
+ *
4
+ * When a local model (Ollama, vLLM, LM Studio) returns tool calls as plain
5
+ * text instead of using the OpenAI `tool_calls` wire format, this module
6
+ * attempts to extract them from the text output.
7
+ *
8
+ * Common scenarios:
9
+ * - Ollama thinking-model bug: tool call JSON ends up inside unclosed `<think>` tags
10
+ * - Model outputs raw JSON tool calls without the server parsing them
11
+ * - Model wraps tool calls in markdown code fences
12
+ * - Hermes-format `<tool_call>` tags
13
+ *
14
+ * This is a **safety net**, not the primary path. Native `tool_calls` from
15
+ * the server are always preferred.
16
+ */
17
+
18
+ import type { ToolUseBlock } from '../types.js'
19
+
20
+ // ---------------------------------------------------------------------------
21
+ // ID generation
22
+ // ---------------------------------------------------------------------------
23
+
24
/** Module-level sequence number appended to every generated ID. */
let callCounter = 0

/**
 * Produce an ID for a tool call that was recovered from text.
 *
 * The ID has the form `extracted_call_<Date.now()>_<sequence>`, where the
 * sequence is the incremented module-level counter.
 */
function generateToolCallId(): string {
  const timestamp = Date.now()
  callCounter += 1
  return `extracted_call_${timestamp}_${callCounter}`
}
30
+
31
+ // ---------------------------------------------------------------------------
32
+ // Internal parsers
33
+ // ---------------------------------------------------------------------------
34
+
35
/**
 * Interpret one already-parsed JSON value as a tool call.
 *
 * Recognised shapes:
 * ```json
 * { "name": "bash", "arguments": { "command": "ls" } }
 * { "name": "bash", "parameters": { "command": "ls" } }
 * { "function": { "name": "bash", "arguments": { "command": "ls" } } }
 * ```
 *
 * @param json - Value obtained from `JSON.parse`.
 * @param knownToolNames - Tool-name whitelist forwarded to `parseFlat`.
 * @returns A tool-use block, or `null` when `json` is not a non-array object
 *          or `parseFlat` rejects it.
 */
function parseToolCallJSON(
  json: unknown,
  knownToolNames: ReadonlySet<string>,
): ToolUseBlock | null {
  const isObjectLike = typeof json === 'object' && json !== null && !Array.isArray(json)
  if (!isObjectLike) return null

  const record = json as Record<string, unknown>
  const nested = record['function']

  // When a `function` object is present it carries the name/arguments pair;
  // otherwise the top-level object itself is the flat shape.
  const source =
    typeof nested === 'object' && nested !== null
      ? (nested as Record<string, unknown>)
      : record

  return parseFlat(source, knownToolNames)
}
64
+
65
/**
 * Convert a flat `{ name, arguments | parameters | input }` object into a
 * tool-use block.
 *
 * @param obj - Candidate object.
 * @param knownToolNames - When non-empty, `name` must be a member.
 * @returns A block with a freshly generated ID, or `null` when `name` is not
 *          a non-empty string or is not whitelisted. Arguments that are
 *          missing, malformed, or not a plain object result in `input: {}`.
 */
function parseFlat(
  obj: Record<string, unknown>,
  knownToolNames: ReadonlySet<string>,
): ToolUseBlock | null {
  const isRecord = (value: unknown): value is Record<string, unknown> =>
    typeof value === 'object' && value !== null && !Array.isArray(value)

  const name = obj['name']
  if (typeof name !== 'string' || name.length === 0) return null

  // Whitelist check — arbitrary JSON must not be mistaken for a tool call.
  const whitelistActive = knownToolNames.size > 0
  if (whitelistActive && !knownToolNames.has(name)) return null

  const rawArgs = obj['arguments'] ?? obj['parameters'] ?? obj['input']
  let input: Record<string, unknown> = {}

  if (typeof rawArgs === 'string') {
    // Arguments may arrive as a JSON-encoded string.
    try {
      const decoded: unknown = JSON.parse(rawArgs)
      if (isRecord(decoded)) input = decoded
    } catch {
      // Malformed — keep the empty input.
    }
  } else if (isRecord(rawArgs)) {
    input = rawArgs
  }

  return {
    type: 'tool_use',
    id: generateToolCallId(),
    name,
    input,
  }
}
99
+
100
+ // ---------------------------------------------------------------------------
101
+ // JSON extraction from text
102
+ // ---------------------------------------------------------------------------
103
+
104
/**
 * Find all top-level JSON objects in a string by tracking brace depth.
 *
 * The scan honours double-quoted strings and backslash escapes so braces
 * inside string values do not affect the depth. Each balanced `{ ... }` span
 * that starts at depth 0 is passed to `JSON.parse`; spans that fail to parse
 * are dropped silently.
 *
 * A closing brace seen while no object is open (for example a stray `}` in
 * surrounding prose) is ignored. Without that guard the depth would go
 * negative and every object later in the text would be missed.
 *
 * @param text - Arbitrary text that may contain JSON objects.
 * @returns The parsed top-level objects in order of appearance (sub-objects
 *          of a successfully parsed object are not returned separately).
 */
function extractJSONObjects(text: string): unknown[] {
  const results: unknown[] = []
  let depth = 0
  let start = -1
  let inString = false
  let escape = false

  for (let i = 0; i < text.length; i++) {
    const ch = text[i]!

    // The character following a backslash inside a string is consumed as-is.
    if (escape) {
      escape = false
      continue
    }

    if (ch === '\\' && inString) {
      escape = true
      continue
    }

    if (ch === '"') {
      inString = !inString
      continue
    }

    if (inString) continue

    if (ch === '{') {
      if (depth === 0) start = i
      depth++
      continue
    }

    if (ch !== '}') continue

    // Unmatched closer outside any object — never let depth drop below zero.
    if (depth === 0) continue

    depth--
    if (depth === 0 && start !== -1) {
      const candidate = text.slice(start, i + 1)
      try {
        results.push(JSON.parse(candidate))
      } catch {
        // Not valid JSON — skip
      }
      start = -1
    }
  }

  return results
}
154
+
155
+ // ---------------------------------------------------------------------------
156
+ // Hermes format: <tool_call>...</tool_call>
157
+ // ---------------------------------------------------------------------------
158
+
159
/**
 * Collect tool calls wrapped in Hermes-style `<tool_call>...</tool_call>` tags.
 *
 * The content of each tag pair is trimmed, parsed as JSON, and handed to
 * `parseToolCallJSON`. Tags whose content cannot be parsed, or that do not
 * yield a tool call, contribute nothing.
 *
 * @param text - Model output to scan.
 * @param knownToolNames - Tool-name whitelist forwarded to `parseToolCallJSON`.
 * @returns Extracted blocks in order of appearance.
 */
function extractHermesToolCalls(
  text: string,
  knownToolNames: ReadonlySet<string>,
): ToolUseBlock[] {
  const blocks: ToolUseBlock[] = []
  const tagPattern = /<tool_call>\s*([\s\S]*?)\s*<\/tool_call>/g

  for (const tagMatch of text.matchAll(tagPattern)) {
    const payloadText = tagMatch[1]!.trim()
    try {
      const payload: unknown = JSON.parse(payloadText)
      const block = parseToolCallJSON(payload, knownToolNames)
      if (block !== null) {
        blocks.push(block)
      }
    } catch {
      // Malformed hermes content — skip
    }
  }

  return blocks
}
178
+
179
+ // ---------------------------------------------------------------------------
180
+ // Public API
181
+ // ---------------------------------------------------------------------------
182
+
183
/**
 * Recover tool calls from a model's plain-text output.
 *
 * Strategies, in order:
 * 1. Hermes `<tool_call>` tags — if any call is found this way, it is
 *    returned immediately and the second strategy is not attempted.
 * 2. JSON objects in the text, after markdown code fences are unwrapped.
 *
 * @param text - The model's text output.
 * @param knownToolNames - Whitelist of registered tool names. When non-empty,
 *                         only JSON objects whose `name` matches a known tool
 *                         are treated as tool calls.
 * @returns Extracted {@link ToolUseBlock}s, or an empty array if none found.
 */
export function extractToolCallsFromText(
  text: string,
  knownToolNames: string[],
): ToolUseBlock[] {
  if (text.length === 0) return []

  const allowedNames = new Set(knownToolNames)

  // Strategy 1: Hermes format
  const fromTags = extractHermesToolCalls(text, allowedNames)
  if (fromTags.length > 0) return fromTags

  // Strategy 2: unwrap code fences, then look for JSON objects
  const unfenced = text.replace(/```(?:json)?\s*\n?([\s\S]*?)\n?\s*```/g, '$1')

  return extractJSONObjects(unfenced)
    .map((candidate) => parseToolCallJSON(candidate, allowedNames))
    .filter((block): block is ToolUseBlock => block !== null)
}