@botpress/zai 2.2.0 → 2.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,618 @@
1
+ // eslint-disable consistent-type-definitions
2
+ import { z } from '@bpinternal/zui'
3
+ import pLimit from 'p-limit'
4
+ import { ZaiContext } from '../context'
5
+ import { Response } from '../response'
6
+ import { getTokenizer } from '../tokenizer'
7
+ import { fastHash, stringify } from '../utils'
8
+ import { Zai } from '../zai'
9
+ import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from './constants'
10
+
11
/** Caller-facing settings for `zai.sort`. */
export type Options = {
  /** The maximum number of tokens per item */
  tokensPerItem?: number
}

// Schema for the single tunable: a number between 1 and 100_000 that resolves to 250 when omitted.
const _tokensPerItemSchema = z
  .number()
  .min(1)
  .max(100_000)
  .optional()
  .describe('The maximum number of tokens per item')
  .default(250)

// Runtime validator applied to the caller's `Options` before sorting starts.
const _Options = z.object({ tokensPerItem: _tokensPerItemSchema })
25
+
26
/** A single evaluation criterion generated by the LLM. */
type SortingCriterion = {
  description: string
  /** Ordered array of labels from FIRST to LAST, e.g., ['very_slow', 'slow', 'medium', 'fast', 'very_fast'] */
  labels: string[]
}

/** Evaluation criteria generated by the LLM, keyed by criterion name. */
type SortingCriteria = { [criterionName: string]: SortingCriterion }
34
+
35
declare module '@botpress/zai' {
  interface Zai {
    /**
     * Orders `input` according to natural-language `instructions`.
     *
     * Awaiting the returned response yields the sorted array. Calling `.result()` on it
     * yields an object that carries the sorted array as `output` alongside `usage`, as in
     * the second example below.
     *
     * @param input - Items to sort
     * @param instructions - Description of the desired order, e.g. "from X to Y"
     * @param options - Optional tuning, see {@link Options}
     * @returns A response resolving to the items of `input` in sorted order
     *
     * @example
     * // Simple usage
     * const sorted = await zai.sort(items, 'from least expensive to most expensive')
     *
     * @example
     * // Get detailed results
     * const { output: sorted, usage } = await zai.sort(items, 'by priority').result()
     */
    sort<T>(input: T[], instructions: string, options?: Options): Response<T[], T[]>
  }
}
53
+
54
/** Marker the model is asked to print after its answer; also passed as a stop sequence. */
const END = '■END■'

/** Score assigned to one input item by the scoring phase. */
type ItemScore = {
  /** Position of the item in the caller's input array */
  elementIndex: number
  /** Label position per criterion (0 = sorts first) */
  scores: Record<string, number>
  /** Sum of `scores` over every criterion */
  totalScore: number
}

/** An input item together with the text forms that are sent to the model. */
type SortElement<T> = {
  element: T
  index: number
  stringified: string
  truncated: string
}

/** A scored item on its way to the final ordering. */
type RankedItem<T> = {
  source: SortElement<T>
  scores: Record<string, number>
  totalScore: number
  tieBreakOrder?: number
}

/** Normalizes a criterion name or label: trimmed, lowercase, whitespace runs replaced by underscores. */
const normalizeToken = (raw: string): string => raw.trim().toLowerCase().replace(/\s+/g, '_')

/** Neutral score used when the model supplied no usable label for a criterion. */
const middleScore = (labels: string[]): number => Math.floor(labels.length / 2)

/** Runtime check that a stored example entry really has the `ItemScore` shape. */
const isItemScore = (value: unknown): value is ItemScore => {
  if (typeof value !== 'object' || value === null) {
    return false
  }
  const candidate = value as Partial<ItemScore>
  return (
    typeof candidate.elementIndex === 'number' && typeof candidate.scores === 'object' && candidate.scores !== null
  )
}

/**
 * Parses the criteria-generation reply (`■name■` header followed by `a;b;c` labels).
 *
 * @throws Error when no criterion with 3-10 labels can be extracted
 */
const parseSortingCriteria = (text: string): SortingCriteria => {
  const criteria: SortingCriteria = {}
  // A name may not span lines. Otherwise the closing marker of an intermediate `■END■` pairs up
  // with the opening marker of the next header and that criterion is lost.
  const criterionRegex = /■([^■\n]+)■\s*([^\n■]+)/g
  let match: RegExpExecArray | null

  while ((match = criterionRegex.exec(text)) !== null) {
    // Names get the same normalization that is applied when scores are parsed, so lookups always agree
    const name = normalizeToken(match[1] ?? '')
    if (!name || name === 'end') continue

    // Parse semicolon-separated labels
    const labels = (match[2] ?? '')
      .split(';')
      .map(normalizeToken)
      .filter((label) => label.length > 0 && label.length < 50)

    if (labels.length >= 3 && labels.length <= 10) {
      criteria[name] = {
        description: `${labels.length} ordered labels`,
        labels,
      }
    }
  }

  if (Object.keys(criteria).length === 0) {
    throw new Error(`Failed to parse sorting criteria. LLM output: ${text.slice(0, 500)}`)
  }

  return criteria
}

/**
 * Parses a scoring reply (`■index:criterion=label;criterion=label■` per item) into one
 * `ItemScore` per chunk position.
 *
 * Every criterion always receives a score: the label's position when the model named a known
 * label, the neutral middle position otherwise. Unknown criterion names are ignored, so all
 * totals are sums over the same set of criteria and stay comparable.
 *
 * @param elementIndexes - Input-array index of the item at each chunk position
 */
const parseChunkScores = (text: string, elementIndexes: number[], criteria: SortingCriteria): ItemScore[] => {
  const criteriaEntries = Object.entries(criteria)
  const parsed = new Map<number, Record<string, number>>()
  const regex = /■(\d+):([^■]+)■/g
  let match: RegExpExecArray | null

  while ((match = regex.exec(text)) !== null) {
    const position = parseInt(match[1] ?? '', 10)
    if (isNaN(position) || position < 0 || position >= elementIndexes.length) continue

    const recognized: Record<string, number> = {}
    for (const pair of (match[2] ?? '').split(';')) {
      const [criterion, label] = pair.split('=').map(normalizeToken)
      if (!criterion || !label) continue

      const known = criteria[criterion]
      if (!known || !Array.isArray(known.labels)) continue

      // Use index as score (0 = first, higher = later)
      const labelIndex = known.labels.indexOf(label)
      if (labelIndex >= 0) {
        recognized[criterion] = labelIndex
      }
    }

    parsed.set(position, recognized)
  }

  return elementIndexes.map((elementIndex, position) => {
    const recognized = parsed.get(position) ?? {}
    const scores: Record<string, number> = {}
    let totalScore = 0

    for (const [key, criterion] of criteriaEntries) {
      const score = recognized[key] ?? middleScore(criterion.labels)
      scores[key] = score
      totalScore += score
    }

    return { elementIndex, scores, totalScore }
  })
}

/**
 * Parses a tie-break reply (`■index■` per line) into a permutation of `0..size-1`.
 * Repeated or out-of-range indices are ignored; indices the model left out are appended in order.
 */
const parseTieBreakOrder = (text: string, size: number): number[] => {
  const order: number[] = []
  const seen = new Set<number>()
  const regex = /■(\d+)■/g
  let match: RegExpExecArray | null

  while ((match = regex.exec(text)) !== null) {
    const idx = parseInt(match[1] ?? '', 10)
    if (!isNaN(idx) && idx >= 0 && idx < size && !seen.has(idx)) {
      seen.add(idx)
      order.push(idx)
    }
  }

  for (let i = 0; i < size; i++) {
    if (!seen.has(i)) {
      order.push(i)
    }
  }

  return order
}

/**
 * Sorts `input` according to natural-language `instructions` using the model behind `ctx`.
 *
 * Steps: derive 1-3 ordered label scales from the instructions, have the model label every
 * item chunk by chunk, sum the label positions into a total per item, ask the model to order
 * items whose totals tie, then sort ascending by total and tie-break position.
 *
 * When `ctx.taskId` and `ctx.adapter` are set, a previously saved result for the same input
 * and instructions is returned as-is, and a fresh result is saved after sorting.
 *
 * @param input - Items to sort; not mutated
 * @param instructions - Description of the desired order
 * @param _options - Optional tuning, validated with `_Options`
 * @param ctx - Context providing the model, abort signal and optional example adapter
 * @returns A new array with the items of `input` in sorted order
 * @throws Error when no sorting criteria can be parsed or when scores do not cover every item
 */
const sort = async <T>(
  input: Array<T>,
  instructions: string,
  _options: Options | undefined,
  ctx: ZaiContext
): Promise<Array<T>> => {
  ctx.controller.signal.throwIfAborted()

  const options = _Options.parse(_options ?? {})

  // Empty and single-element arrays need neither a tokenizer nor a model
  if (input.length <= 1) {
    return [...input]
  }

  const tokenizer = await getTokenizer()
  const model = await ctx.getModel()

  const taskId = ctx.taskId
  const taskType = 'zai.sort'

  // Check for exact match (cache hit). The key mirrors the one used when the result is saved
  // below, and the stored output is the sorted array itself.
  if (taskId && ctx.adapter) {
    const inputStr = JSON.stringify(input)
    const key = fastHash(
      stringify({
        taskId,
        taskType,
        input: inputStr,
        instructions,
      })
    )

    const savedResults = await ctx.adapter.getExamples<string, Array<T>>({
      input: inputStr.slice(0, 1000),
      taskType,
      taskId,
    })

    const exactMatch = savedResults.find((x) => x.key === key)
    if (exactMatch && Array.isArray(exactMatch.output) && exactMatch.output.length === input.length) {
      return exactMatch.output
    }
  }

  const TOKENS_TOTAL_MAX = model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER

  // Prepare elements with indices; truncation is computed once and reused by every phase
  const elements: Array<SortElement<T>> = input.map((element, index) => {
    const stringified = stringify(element, false)
    return {
      element,
      index,
      stringified,
      truncated: tokenizer.truncate(stringified, options.tokensPerItem),
    }
  })

  // Phase 1: Generate sorting criteria from instructions + sample items
  const sampleItemsText = elements
    .slice(0, 5)
    .map((elem) => `■${elem.index}: ${elem.truncated}`)
    .join('\n')

  const generateCriteriaPrompt = `Analyze this sorting instruction: "${instructions}"

Sample items to be sorted:
${sampleItemsText}

Create 1-3 sorting criteria with ordered label arrays (3-10 labels each).

**CRITICAL RULES**:
1. Labels are single words, lowercase, no spaces, use underscores
2. Labels are ordered from FIRST to LAST in sorted result
3. If instruction says "from X to Y": first label represents X, last label represents Y
4. If instruction says "prioritize" or "highest/lowest priority":
- First label = HIGHEST priority (top of todo list)
- Last label = LOWEST priority (bottom of todo list)

Examples:

"from slowest to fastest" → first=slowest, last=fastest
■speed■
very_slow;slow;medium;fast;very_fast
■END■

"from most dangerous to least dangerous" → first=most dangerous, last=least dangerous
■danger■
extremely_dangerous;very_dangerous;dangerous;moderate;slightly_dangerous;harmless
■END■

"from least urgent (spam) to most urgent (bills)" → first=spam, last=bills
■urgency■
spam;promotional;normal;important;urgent;critical
■END■

"prioritize: highest priority=open old tickets; lowest priority=closed" → first=high priority, last=low priority
■status■
open_old;open_recent;closed
■age■
oldest;old;recent;new
■END■

Output format:
■criterion_name■
label1;label2;label3;label4
■END■

Use 3-10 labels per criterion. Labels should be intuitive and match the domain.
Keep criterion names short (1-2 words, lowercase, underscores).
`

  const { extracted: sortingCriteria } = await ctx.generateContent({
    systemPrompt: `You are creating sorting criteria with ordered label arrays.

CRITICAL: Output ordered labels from FIRST to LAST position in sorted result.
- Labels are single words, lowercase, underscores only
- 3-10 labels per criterion
- Order matters: first label = appears first, last label = appears last`,
    messages: [
      {
        type: 'text',
        role: 'user',
        content: generateCriteriaPrompt,
      },
    ],
    transform: (text) => parseSortingCriteria(text),
  })

  const criteriaEntries = Object.entries(sortingCriteria)
  const criteriaKeys = criteriaEntries.map(([key]) => key)
  if (criteriaKeys.length === 0) {
    throw new Error('No sorting criteria generated')
  }

  // Both renderings are identical for every chunk / tie group, so build them once
  const criteriaText = criteriaEntries.map(([key, criterion]) => `**${key}**: ${criterion.labels.join(';')}`).join('\n')
  const tieBreakCriteriaText = criteriaEntries
    .map(([key, criterion]) => `- ${key}: ${criterion.labels.join(';')}`)
    .join('\n')

  // Phase 2: Chunk items so each scoring call stays within the token and item budgets
  const TOKENS_CRITERIA_MAX = Math.floor(TOKENS_TOTAL_MAX * 0.2)
  const TOKENS_ITEMS_MAX = TOKENS_TOTAL_MAX - TOKENS_CRITERIA_MAX

  const MAX_ITEMS_PER_CHUNK = 50

  const chunks: Array<Array<SortElement<T>>> = []
  let currentChunk: Array<SortElement<T>> = []
  let currentTokens = 0

  for (const elem of elements) {
    const elemTokens = tokenizer.count(elem.truncated)
    const chunkIsFull = currentTokens + elemTokens > TOKENS_ITEMS_MAX || currentChunk.length >= MAX_ITEMS_PER_CHUNK

    if (chunkIsFull && currentChunk.length > 0) {
      chunks.push(currentChunk)
      currentChunk = []
      currentTokens = 0
    }

    currentChunk.push(elem)
    currentTokens += elemTokens
  }

  if (currentChunk.length > 0) {
    chunks.push(currentChunk)
  }

  // Phase 3: Score each chunk
  const scoreChunk = async (chunk: Array<SortElement<T>>): Promise<ItemScore[]> => {
    ctx.controller.signal.throwIfAborted()

    const chunkSize = chunk.length

    // Get examples from adapter for active learning
    const examples =
      taskId && ctx.adapter
        ? await ctx.adapter.getExamples<string, ItemScore[]>({
            input: JSON.stringify(chunk.map((c) => c.element)).slice(0, 1000),
            taskType,
            taskId,
          })
        : []

    const elementsText = chunk.map((elem, i) => `■${i}: ${elem.truncated}■`).join('\n')

    // Format examples for few-shot learning. An example is only used when its stored output
    // really is a list of item scores; anything else is skipped before any message is added,
    // so the conversation never contains a user turn without its assistant reply.
    const exampleMessages: Array<{ type: 'text'; role: 'user' | 'assistant'; content: string }> = []

    for (const example of examples.slice(0, 3)) {
      const exampleOutput: unknown = example.output
      if (!Array.isArray(exampleOutput) || exampleOutput.length === 0 || !exampleOutput.every(isItemScore)) {
        continue
      }

      let exampleItems: unknown[]
      try {
        const exampleInput: unknown = JSON.parse(example.input)
        exampleItems = Array.isArray(exampleInput) ? exampleInput : [exampleInput]
      } catch {
        // Skip malformed examples
        continue
      }

      exampleMessages.push({
        type: 'text',
        role: 'user',
        content: `Expert Example - Items to score:\n${exampleItems.map((el, i) => `■${i}: ${stringify(el, false).slice(0, 200)}■`).join('\n')}\n\nScore each item.`,
      })

      // Show labels (not numeric positions) so the example follows the requested output format
      const formattedScores = exampleOutput
        .map((score) => {
          const pairs = criteriaEntries
            .map(([key, criterion]) => {
              const position = score.scores[key]
              const label =
                (typeof position === 'number' ? criterion.labels[Math.round(position)] : undefined) ??
                criterion.labels[middleScore(criterion.labels)]
              return `${key}=${label}`
            })
            .join(';')
          return `■${score.elementIndex}:${pairs}■`
        })
        .join('\n')

      exampleMessages.push({
        type: 'text',
        role: 'assistant',
        content: `${formattedScores}\n${END}`,
      })

      if (example.explanation) {
        exampleMessages.push({
          type: 'text',
          role: 'assistant',
          content: `Reasoning: ${example.explanation}`,
        })
      }
    }

    const { extracted } = await ctx.generateContent({
      systemPrompt: `You are ranking items for sorting using ordered label arrays.

${criteriaText}

Instructions: "${instructions}"

SCORING RULES:
- For each item and each criterion, assign ONE label from the ordered list
- Labels are ordered: first label = appears FIRST in sorted result, last label = appears LAST
- Choose the label that best describes each item

Output format:
■0:criterion1=label;criterion2=label■
■1:criterion1=label;criterion2=label■
${END}

IMPORTANT:
- Rank every item (■0 to ■${chunkSize - 1})
- Use exact criterion names: ${criteriaKeys.join(', ')}
- Use exact labels from the lists above (lowercase, underscores)
- Use semicolons (;) between criteria
- Use equals (=) between criterion and label`,
      stopSequences: [END],
      messages: [
        ...exampleMessages,
        {
          type: 'text',
          role: 'user',
          content: `Items to rank (■0 to ■${chunkSize - 1}):\n${elementsText}\n\nRank each item using the labeled scales.\nOutput format: ■index:criterion1=label;criterion2=label■\n${END}`,
        },
      ],
      transform: (text) =>
        parseChunkScores(
          text,
          chunk.map((c) => c.index),
          sortingCriteria
        ),
    })

    return extracted
  }

  // Process all chunks in parallel
  const limit = pLimit(10)
  const allScores = await Promise.all(chunks.map((chunk) => limit(() => scoreChunk(chunk))))

  // Phase 4: Merge scores from all chunks, keyed by the item's index in `input`
  const scoreMap = new Map<number, RankedItem<T>>()

  for (const chunkScores of allScores) {
    for (const itemScore of chunkScores) {
      const source = elements[itemScore.elementIndex]
      if (!source) continue

      const existing = scoreMap.get(itemScore.elementIndex)

      if (existing) {
        // Average the scores
        for (const key of criteriaKeys) {
          existing.scores[key] = ((existing.scores[key] ?? 0) + (itemScore.scores[key] ?? 0)) / 2
        }
        existing.totalScore = (existing.totalScore + itemScore.totalScore) / 2
      } else {
        scoreMap.set(itemScore.elementIndex, {
          source,
          scores: { ...itemScore.scores },
          totalScore: itemScore.totalScore,
        })
      }
    }
  }

  // Verify we have scores for all elements
  if (scoreMap.size !== input.length) {
    throw new Error(`Score map size mismatch: expected ${input.length}, got ${scoreMap.size}`)
  }

  // Phase 5: Identify ties
  const scoreGroups = new Map<number, Array<RankedItem<T>>>()
  for (const item of scoreMap.values()) {
    const roundedScore = Math.round(item.totalScore * 100) // Round to avoid floating point issues
    const group = scoreGroups.get(roundedScore)
    if (group) {
      group.push(item)
    } else {
      scoreGroups.set(roundedScore, [item])
    }
  }

  // Find groups with more than one item (ties)
  const tiedGroups = Array.from(scoreGroups.values()).filter((group) => group.length > 1)

  // Phase 6: Tie-breaking - process all tied groups IN PARALLEL
  // NOTE(review): a tie group is sent to the model in one request regardless of its size;
  // very large groups may exceed the model's input budget - confirm whether chunking is needed.
  if (tiedGroups.length > 0) {
    const tieBreakLimit = pLimit(10)

    await Promise.all(
      tiedGroups.map((tied) =>
        tieBreakLimit(async () => {
          const tieBreakText = tied.map((item, i) => `■${i}: ${item.source.truncated}■`).join('\n')

          const { extracted: tieBreakOrder } = await ctx.generateContent({
            systemPrompt: `You are breaking a tie between items with identical total scores.

Instructions: ${instructions}

Criteria:
${tieBreakCriteriaText}

Order these ${tied.length} items from FIRST to LAST based on the instructions.
Earlier labels in each criterion should come FIRST.

Output format:
■original_index■
■original_index■
${END}

Output the indices in the order they should appear (first item at top).`,
            stopSequences: [END],
            messages: [
              {
                type: 'text',
                role: 'user',
                content: `Items with identical scores (need tie-breaking):\n${tieBreakText}\n\nOrder them from first to last.\nOutput format: ■index■ (one per line)\n${END}`,
              },
            ],
            transform: (text) => parseTieBreakOrder(text, tied.length),
          })

          // Record each item's rank within its tie group
          tieBreakOrder.forEach((position, rank) => {
            const item = tied[position]
            if (item) {
              item.tieBreakOrder = rank
            }
          })
        })
      )
    )
  }

  // Phase 7: Sort by total score, then by tie-break order.
  // The map iterates in input order and Array.prototype.sort is stable, so items that are
  // still equal keep their original relative order.
  const result = Array.from(scoreMap.values())
    .sort((a, b) => {
      const scoreDiff = a.totalScore - b.totalScore
      if (scoreDiff !== 0) return scoreDiff

      return (a.tieBreakOrder ?? 0) - (b.tieBreakOrder ?? 0)
    })
    .map((item) => item.source.element)

  // Save example for active learning
  if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
    const inputStr = JSON.stringify(input)
    const key = fastHash(
      stringify({
        taskId,
        taskType,
        input: inputStr,
        instructions,
      })
    )

    await ctx.adapter.saveExample({
      key,
      taskType,
      taskId,
      input: inputStr,
      output: result,
      instructions,
      metadata: {
        cost: { input: 0, output: 0 },
        latency: 0,
        model: ctx.modelId,
        tokens: { input: 0, output: 0 },
      },
    })
  }

  return result
}
598
+
599
/**
 * Public `zai.sort` entry point: creates a dedicated context from this instance's settings,
 * starts the sort, and wraps the pending result in a `Response`.
 */
Zai.prototype.sort = function <T>(
  this: Zai,
  input: Array<T>,
  instructions: string,
  _options?: Options
): Response<Array<T>, Array<T>> {
  const { client, Model: modelId, taskId, adapter } = this
  const context = new ZaiContext({ client, modelId, taskId, taskType: 'zai.sort', adapter })

  const pending = sort(input, instructions, _options, context)

  // The simplified form of the response is the sorted array itself
  return new Response<Array<T>, Array<T>>(context, pending, (sorted) => sorted)
}