@botpress/zai 2.1.20 → 2.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CLAUDE.md +696 -0
- package/README.md +79 -2
- package/dist/index.d.ts +85 -14
- package/dist/index.js +3 -0
- package/dist/operations/group.js +369 -0
- package/dist/operations/rate.js +350 -0
- package/dist/operations/sort.js +450 -0
- package/e2e/data/cache.jsonl +289 -0
- package/package.json +1 -1
- package/src/index.ts +3 -0
- package/src/operations/group.ts +543 -0
- package/src/operations/rate.ts +518 -0
- package/src/operations/sort.ts +618 -0
|
@@ -0,0 +1,618 @@
|
|
|
1
|
+
// eslint-disable consistent-type-definitions
|
|
2
|
+
import { z } from '@bpinternal/zui'
|
|
3
|
+
import pLimit from 'p-limit'
|
|
4
|
+
import { ZaiContext } from '../context'
|
|
5
|
+
import { Response } from '../response'
|
|
6
|
+
import { getTokenizer } from '../tokenizer'
|
|
7
|
+
import { fastHash, stringify } from '../utils'
|
|
8
|
+
import { Zai } from '../zai'
|
|
9
|
+
import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from './constants'
|
|
10
|
+
|
|
11
|
+
/**
 * Caller-facing options for `zai.sort`.
 */
export type Options = {
  /**
   * The maximum number of tokens per item.
   *
   * Each item is stringified and truncated to this many tokens before it is
   * sent to the model. The runtime schema accepts values from 1 to 100000 and
   * uses 250 when the option is omitted.
   */
  tokensPerItem?: number
}
|
|
15
|
+
|
|
16
|
+
/**
 * Runtime schema used to validate the caller-supplied `Options`.
 * `tokensPerItem` must be a number between 1 and 100000; it falls back to 250
 * when it is not provided.
 */
const _Options = z.object({
  tokensPerItem: z.number().min(1).max(100_000).optional().describe('The maximum number of tokens per item').default(250),
})
|
|
25
|
+
|
|
26
|
+
// Evaluation criteria generated by LLM
|
|
27
|
+
/**
 * Sorting criteria keyed by criterion name.
 *
 * `labels` is ordered from FIRST to LAST position in the sorted result,
 * e.g. ['very_slow', 'slow', 'medium', 'fast', 'very_fast']; the position of a
 * label inside the array is the numeric score given to an item.
 */
type SortingCriteria = Record<string, { description: string; labels: string[] }>
|
|
34
|
+
|
|
35
|
+
declare module '@botpress/zai' {
  interface Zai {
    /**
     * Orders an array of items according to natural-language instructions.
     *
     * Awaiting the returned `Response` yields the sorted array itself; call
     * `.result()` on it when the detailed result (e.g. `usage`) is needed.
     *
     * @param input - The items to sort.
     * @param instructions - How the items should be ordered,
     *   e.g. 'from least expensive to most expensive'.
     * @param options - Optional tuning, see `Options`.
     *
     * @example
     * // Simple usage
     * const sorted = await zai.sort(items, 'from least expensive to most expensive')
     *
     * @example
     * // Get detailed results
     * const { output: sorted, usage } = await zai.sort(items, 'by priority').result()
     */
    sort<T>(input: Array<T>, instructions: string, options?: Options): Response<Array<T>, Array<T>>
  }
}
|
|
53
|
+
|
|
54
|
+
// Terminator marker: shown to the model at the end of the expected output format,
// appended to few-shot answers, and passed as a stop sequence when scoring and tie-breaking.
const END = '■END■'
|
|
55
|
+
|
|
56
|
+
/**
 * Sorts `input` according to natural-language `instructions` using the model exposed by `ctx`.
 *
 * Pipeline:
 * 1. Ask the model for sorting criteria, each one an ordered list of labels (first label sorts first).
 * 2. Split the items into chunks bounded by a token budget and by a maximum item count.
 * 3. Ask the model to label every item of every chunk; the position of a label in its list is the score.
 * 4. Merge the chunk scores into one entry per input index.
 * 5. Group items whose total score is identical and ask the model to order each group.
 * 6. Sort ascending by total score, then by tie-break position.
 * 7. When a task id and an adapter are configured, persist the sorted array as an example.
 *
 * @param input - Items to sort; the array is not mutated.
 * @param instructions - Natural-language description of the desired order.
 * @param _options - Raw options, validated through `_Options` (defaults applied there).
 * @param ctx - Execution context providing the model, the abort signal and the optional adapter.
 * @returns A new array holding the items of `input` in sorted order.
 * @throws If the abort signal has fired, if the options are rejected by `_Options`, if no sorting
 *   criterion can be parsed from the model output, or if the merged scores do not cover every item.
 */
const sort = async <T>(
  input: Array<T>,
  instructions: string,
  _options: Options | undefined,
  ctx: ZaiContext
): Promise<Array<T>> => {
  ctx.controller.signal.throwIfAborted()

  const options = _Options.parse(_options ?? {})
  const tokenizer = await getTokenizer()
  const model = await ctx.getModel()

  const taskId = ctx.taskId
  const taskType = 'zai.sort'

  // Nothing to order. Always hand back a fresh array so that no code path
  // returns the caller's own array by reference.
  if (input.length <= 1) {
    return [...input]
  }

  const TOKENS_TOTAL_MAX = model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER

  // Phase 1: Generate sorting criteria from instructions + sample items
  const sampleSize = Math.min(5, input.length)
  const sampleItems = input.slice(0, sampleSize)
  const sampleItemsText = sampleItems.map((item, idx) => `■${idx}: ${stringify(item, false)}`).join('\n')

  const generateCriteriaPrompt = `Analyze this sorting instruction: "${instructions}"

Sample items to be sorted:
${sampleItemsText}

Create 1-3 sorting criteria with ordered label arrays (3-10 labels each).

**CRITICAL RULES**:
1. Labels are single words, lowercase, no spaces, use underscores
2. Labels are ordered from FIRST to LAST in sorted result
3. If instruction says "from X to Y": first label represents X, last label represents Y
4. If instruction says "prioritize" or "highest/lowest priority":
- First label = HIGHEST priority (top of todo list)
- Last label = LOWEST priority (bottom of todo list)

Examples:

"from slowest to fastest" → first=slowest, last=fastest
■speed■
very_slow;slow;medium;fast;very_fast
■END■

"from most dangerous to least dangerous" → first=most dangerous, last=least dangerous
■danger■
extremely_dangerous;very_dangerous;dangerous;moderate;slightly_dangerous;harmless
■END■

"from least urgent (spam) to most urgent (bills)" → first=spam, last=bills
■urgency■
spam;promotional;normal;important;urgent;critical
■END■

"prioritize: highest priority=open old tickets; lowest priority=closed" → first=high priority, last=low priority
■status■
open_old;open_recent;closed
■age■
oldest;old;recent;new
■END■

Output format:
■criterion_name■
label1;label2;label3;label4
■END■

Use 3-10 labels per criterion. Labels should be intuitive and match the domain.
Keep criterion names short (1-2 words, lowercase, underscores).
`

  const { extracted: sortingCriteria } = await ctx.generateContent({
    systemPrompt: `You are creating sorting criteria with ordered label arrays.

CRITICAL: Output ordered labels from FIRST to LAST position in sorted result.
- Labels are single words, lowercase, underscores only
- 3-10 labels per criterion
- Order matters: first label = appears first, last label = appears last`,
    messages: [
      {
        type: 'text',
        role: 'user',
        content: generateCriteriaPrompt,
      },
    ],
    transform: (text) => {
      const criteria: SortingCriteria = {}
      // The name group must not span a newline. Otherwise a `■END■` terminator emitted
      // between two criteria pairs its closing marker with the opening marker of the
      // next criterion, and every second criterion is silently dropped.
      const criterionRegex = /■([^■\n]+)■\s*([^\n■]+)/g

      for (const match of text.matchAll(criterionRegex)) {
        const name = (match[1] ?? '').trim().toLowerCase()
        const labelsStr = (match[2] ?? '').trim()

        if (!name || name === 'end') continue

        // Parse semicolon-separated labels
        const labels = labelsStr
          .split(';')
          .map((l) => l.trim().toLowerCase().replace(/\s+/g, '_'))
          .filter((l) => l.length > 0 && l.length < 50)

        if (labels.length >= 3 && labels.length <= 10) {
          criteria[name] = {
            description: `${labels.length} ordered labels`,
            labels,
          }
        }
      }

      if (Object.keys(criteria).length === 0) {
        throw new Error(`Failed to parse sorting criteria. LLM output: ${text.slice(0, 500)}`)
      }

      return criteria
    },
  })

  const criteriaKeys = Object.keys(sortingCriteria)
  if (criteriaKeys.length === 0) {
    throw new Error('No sorting criteria generated')
  }

  // Phase 2: Chunk items and score them in parallel
  const TOKENS_CRITERIA_MAX = Math.floor(TOKENS_TOTAL_MAX * 0.2)
  const TOKENS_ITEMS_MAX = TOKENS_TOTAL_MAX - TOKENS_CRITERIA_MAX

  const MAX_ITEMS_PER_CHUNK = 50

  // Prepare elements with their original index. The truncated text is computed once
  // here and reused for token counting, scoring prompts and tie-break prompts.
  const elements = input.map((element, idx) => ({
    element,
    index: idx,
    truncated: tokenizer.truncate(stringify(element, false), options.tokensPerItem),
  }))

  // Chunk elements: close the current chunk when adding the next element would exceed
  // the token budget or the item cap (a chunk always holds at least one element).
  const chunks: Array<typeof elements> = []
  let currentChunk: typeof elements = []
  let currentTokens = 0

  for (const elem of elements) {
    const elemTokens = tokenizer.count(elem.truncated)
    const chunkIsFull = currentTokens + elemTokens > TOKENS_ITEMS_MAX || currentChunk.length >= MAX_ITEMS_PER_CHUNK

    if (chunkIsFull && currentChunk.length > 0) {
      chunks.push(currentChunk)
      currentChunk = []
      currentTokens = 0
    }

    currentChunk.push(elem)
    currentTokens += elemTokens
  }

  if (currentChunk.length > 0) {
    chunks.push(currentChunk)
  }

  // Phase 3: Score each chunk
  type ItemScore = {
    elementIndex: number
    scores: Record<string, number>
    totalScore: number
  }

  /** Score used when the model gives no usable label: the middle of the label list. */
  const middleIndexOf = (criterion: string): number => {
    const labels = sortingCriteria[criterion]?.labels ?? []
    return labels.length > 0 ? Math.floor(labels.length / 2) : 5
  }

  /** Builds a score entry where every generated criterion holds its middle value. */
  const neutralScore = (elementIndex: number): ItemScore => {
    const scores: Record<string, number> = {}
    let totalScore = 0
    for (const key of criteriaKeys) {
      const middle = middleIndexOf(key)
      scores[key] = middle
      totalScore += middle
    }
    return { elementIndex, scores, totalScore }
  }

  /** Structural check for values coming back from the adapter, which are untyped at runtime. */
  const isItemScore = (value: unknown): value is ItemScore => {
    if (typeof value !== 'object' || value === null) return false
    const candidate = value as Partial<ItemScore>
    return (
      typeof candidate.elementIndex === 'number' &&
      typeof candidate.totalScore === 'number' &&
      typeof candidate.scores === 'object' &&
      candidate.scores !== null
    )
  }

  /** Turns a numeric score back into its label so few-shot answers use the requested output format. */
  const labelFor = (criterion: string, score: number | undefined): string => {
    const labels = sortingCriteria[criterion]?.labels ?? []
    const index = score === undefined ? -1 : Math.round(score)
    return labels[index] ?? labels[middleIndexOf(criterion)] ?? String(score ?? 0)
  }

  // Loop-invariant prompt fragments
  const criteriaText = criteriaKeys
    .map((key) => `**${key}**: ${(sortingCriteria[key]?.labels ?? []).join(';')}`)
    .join('\n')
  const tieBreakCriteriaText = criteriaKeys
    .map((key) => `- ${key}: ${(sortingCriteria[key]?.labels ?? []).join(';')}`)
    .join('\n')

  const scoreChunk = async (chunk: typeof elements): Promise<ItemScore[]> => {
    ctx.controller.signal.throwIfAborted()

    const chunkSize = chunk.length
    const chunkInputStr = JSON.stringify(chunk.map((c) => c.element))

    // Get examples from adapter for active learning
    const examples =
      taskId && ctx.adapter
        ? await ctx.adapter.getExamples<string, ItemScore[]>({
            input: chunkInputStr.slice(0, 1000),
            taskType,
            taskId,
          })
        : []

    // Check for exact match (cache hit)
    const key = fastHash(
      stringify({
        taskId,
        taskType,
        input: chunkInputStr,
        instructions,
      })
    )

    // The example saved at the end of this function stores the sorted array (not
    // per-item scores) under a key that equals this one whenever the whole input fits
    // in a single chunk. Returning it unchecked used to corrupt the score map and end
    // in a "Score map size mismatch" error, so the cached value is only trusted when it
    // really is one score entry per chunk element, in chunk order.
    const exactMatch = examples.find((x) => x.key === key)
    const cached: unknown = exactMatch?.output
    if (
      Array.isArray(cached) &&
      cached.length === chunkSize &&
      cached.every((entry, i) => isItemScore(entry) && entry.elementIndex === chunk[i]?.index)
    ) {
      return cached as ItemScore[]
    }

    const elementsText = chunk.map((elem, i) => `■${i}: ${elem.truncated}■`).join('\n')

    // Format examples for few-shot learning
    const exampleMessages: Array<{ type: 'text'; role: 'user' | 'assistant'; content: string }> = []

    for (const example of examples.slice(0, 3)) {
      try {
        // Validate the answer BEFORE emitting the question, so a malformed example can
        // never leave a user turn without its assistant turn.
        const exampleOutput: unknown = example.output
        if (!Array.isArray(exampleOutput) || exampleOutput.length === 0 || !exampleOutput.every(isItemScore)) {
          continue
        }

        const exampleInput = JSON.parse(example.input)
        const exampleItems = Array.isArray(exampleInput) ? exampleInput : [exampleInput]
        if (exampleItems.length !== exampleOutput.length) {
          continue
        }

        const itemsText = exampleItems.map((el, i) => `■${i}: ${stringify(el, false).slice(0, 200)}■`).join('\n')

        // Answers reuse the item positions shown in the question and spell out label
        // names, matching the `■index:criterion=label■` format requested from the model.
        const formattedScores = (exampleOutput as ItemScore[])
          .map((score, i) => {
            const pairs = criteriaKeys.map((k) => `${k}=${labelFor(k, score.scores[k])}`).join(';')
            return `■${i}:${pairs}■`
          })
          .join('\n')

        exampleMessages.push({
          type: 'text',
          role: 'user',
          content: `Expert Example - Items to score:\n${itemsText}\n\nScore each item.`,
        })

        exampleMessages.push({
          type: 'text',
          role: 'assistant',
          content: `${formattedScores}\n${END}`,
        })

        if (example.explanation) {
          exampleMessages.push({
            type: 'text',
            role: 'assistant',
            content: `Reasoning: ${example.explanation}`,
          })
        }
      } catch {
        // Skip malformed examples (e.g. input that is not valid JSON)
      }
    }

    const { extracted } = await ctx.generateContent({
      systemPrompt: `You are ranking items for sorting using ordered label arrays.

${criteriaText}

Instructions: "${instructions}"

SCORING RULES:
- For each item and each criterion, assign ONE label from the ordered list
- Labels are ordered: first label = appears FIRST in sorted result, last label = appears LAST
- Choose the label that best describes each item

Output format:
■0:criterion1=label;criterion2=label■
■1:criterion1=label;criterion2=label■
${END}

IMPORTANT:
- Rank every item (■0 to ■${chunkSize - 1})
- Use exact criterion names: ${criteriaKeys.join(', ')}
- Use exact labels from the lists above (lowercase, underscores)
- Use semicolons (;) between criteria
- Use equals (=) between criterion and label`,
      stopSequences: [END],
      messages: [
        ...exampleMessages,
        {
          type: 'text',
          role: 'user',
          content: `Items to rank (■0 to ■${chunkSize - 1}):\n${elementsText}\n\nRank each item using the labeled scales.\nOutput format: ■index:criterion1=label;criterion2=label■\n${END}`,
        },
      ],
      transform: (text) => {
        const results: ItemScore[] = []

        for (const match of text.matchAll(/■(\d+):([^■]+)■/g)) {
          const idx = parseInt(match[1] ?? '', 10)
          if (isNaN(idx)) continue

          // Out-of-range indices have no matching chunk element and are ignored
          const target = chunk[idx]
          if (!target) continue

          // Start from the middle value for every criterion. A criterion the model
          // omitted or labelled with an unknown value then stays neutral instead of
          // lowering the total and pushing the item towards the front.
          const entry = neutralScore(target.index)

          for (const pair of (match[2] ?? '').split(';')) {
            const [criterion, labelStr] = pair.split('=').map((x) => x.trim().toLowerCase().replace(/\s+/g, '_'))
            // Criterion names that were never generated must not contribute to the total
            if (!criterion || !labelStr || !criteriaKeys.includes(criterion)) continue

            // The label's position in the ordered list is the score (0 = first)
            const labelIndex = (sortingCriteria[criterion]?.labels ?? []).indexOf(labelStr)
            if (labelIndex >= 0) {
              entry.scores[criterion] = labelIndex
            }
          }

          entry.totalScore = criteriaKeys.reduce((sum, k) => sum + (entry.scores[k] ?? 0), 0)
          results[idx] = entry
        }

        // Fill in missing results with middle scores
        chunk.forEach((elem, i) => {
          if (!results[i]) {
            results[i] = neutralScore(elem.index)
          }
        })

        return results
      },
    })

    return extracted
  }

  // Process all chunks in parallel (at most 10 model calls in flight)
  const limit = pLimit(10)
  const allScores = await Promise.all(chunks.map((chunk) => limit(() => scoreChunk(chunk))))

  // Phase 4: Merge scores from all chunks
  // Build a map of elementIndex -> accumulated scores
  const scoreMap = new Map<number, { scores: Record<string, number>; totalScore: number; tieBreakOrder?: number }>()

  for (const chunkScores of allScores) {
    for (const itemScore of chunkScores) {
      const existing = scoreMap.get(itemScore.elementIndex)

      if (!existing) {
        scoreMap.set(itemScore.elementIndex, {
          scores: { ...itemScore.scores },
          totalScore: itemScore.totalScore,
        })
        continue
      }

      // Same element scored more than once: average the scores
      for (const key of criteriaKeys) {
        existing.scores[key] = ((existing.scores[key] ?? 0) + (itemScore.scores[key] ?? 0)) / 2
      }
      existing.totalScore = (existing.totalScore + itemScore.totalScore) / 2
    }
  }

  // Verify we have scores for all elements
  if (scoreMap.size !== input.length) {
    throw new Error(`Score map size mismatch: expected ${input.length}, got ${scoreMap.size}`)
  }

  // Phase 5: Identify ties
  const scoreGroups = new Map<number, number[]>()
  for (const [index, scoreData] of scoreMap.entries()) {
    const roundedScore = Math.round(scoreData.totalScore * 100) // Round to avoid floating point issues
    const group = scoreGroups.get(roundedScore) ?? []
    group.push(index)
    scoreGroups.set(roundedScore, group)
  }

  // Find groups with more than one item (ties)
  const tiedGroups = Array.from(scoreGroups.values()).filter((group) => group.length > 1)

  // Phase 6: Tie-breaking - process all tied groups IN PARALLEL
  if (tiedGroups.length > 0) {
    const tieBreakLimit = pLimit(10)

    await Promise.all(
      tiedGroups.map((tiedIndices) =>
        tieBreakLimit(async () => {
          if (tiedIndices.length <= 1) return

          const tiedElements = tiedIndices.map((idx) => elements[idx])

          // Re-score just these items for tie-breaking
          const tieBreakText = tiedElements.map((elem, i) => `■${i}: ${elem.truncated}■`).join('\n')

          const { extracted: tieBreakOrder } = await ctx.generateContent({
            systemPrompt: `You are breaking a tie between items with identical total scores.

Instructions: ${instructions}

Criteria:
${tieBreakCriteriaText}

Order these ${tiedElements.length} items from FIRST to LAST based on the instructions.
Earlier labels in each criterion should come FIRST.

Output format:
■original_index■
■original_index■
${END}

Output the indices in the order they should appear (first item at top).`,
            stopSequences: [END],
            messages: [
              {
                type: 'text',
                role: 'user',
                content: `Items with identical scores (need tie-breaking):\n${tieBreakText}\n\nOrder them from first to last.\nOutput format: ■index■ (one per line)\n${END}`,
              },
            ],
            transform: (text) => {
              const order: number[] = []
              // Keep only the first mention of each index; a repeated index would
              // otherwise overwrite the position the model gave it first.
              const seen = new Set<number>()

              for (const match of text.matchAll(/■(\d+)■/g)) {
                const idx = parseInt(match[1] ?? '', 10)
                if (isNaN(idx) || idx < 0 || idx >= tiedElements.length || seen.has(idx)) continue
                seen.add(idx)
                order.push(idx)
              }

              // If not all items were ordered, append missing ones in their current order
              for (let i = 0; i < tiedElements.length; i++) {
                if (!seen.has(i)) {
                  order.push(i)
                }
              }

              return order
            },
          })

          // Update scoreMap with tie-break order
          tieBreakOrder.forEach((localIndex, position) => {
            const tied = tiedElements[localIndex]
            const scoreData = tied ? scoreMap.get(tied.index) : undefined
            if (scoreData) {
              scoreData.tieBreakOrder = position
            }
          })
        })
      )
    )
  }

  // Phase 7: Sort by total score, then by tie-break order
  const result = Array.from(scoreMap.entries())
    .sort(([, a], [, b]) => {
      // First sort by total score (lower = earlier)
      const scoreDiff = a.totalScore - b.totalScore
      if (scoreDiff !== 0) return scoreDiff

      // If scores are equal, use tie-break order
      return (a.tieBreakOrder ?? 0) - (b.tieBreakOrder ?? 0)
    })
    .map(([index]) => elements[index].element)

  // Save example for active learning
  if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
    const serializedInput = JSON.stringify(input)
    const key = fastHash(
      stringify({
        taskId,
        taskType,
        input: serializedInput,
        instructions,
      })
    )

    await ctx.adapter.saveExample({
      key,
      taskType,
      taskId,
      input: serializedInput,
      output: result,
      instructions,
      metadata: {
        cost: { input: 0, output: 0 },
        latency: 0,
        model: ctx.modelId,
        tokens: { input: 0, output: 0 },
      },
    })
  }

  return result
}
|
|
598
|
+
|
|
599
|
+
/**
 * Public entry point for `zai.sort`.
 *
 * Creates a dedicated `ZaiContext` from this instance's client, model, task id
 * and adapter, starts the sort right away, and wraps the pending promise in a
 * `Response` whose simplified value is the sorted array itself.
 */
Zai.prototype.sort = function <T>(
  this: Zai,
  input: Array<T>,
  instructions: string,
  _options?: Options
): Response<Array<T>, Array<T>> {
  const { client, Model: modelId, taskId, adapter } = this
  const context = new ZaiContext({ client, modelId, taskId, taskType: 'zai.sort', adapter })

  // The sort starts immediately; the Response only wraps the in-flight promise.
  const pending = sort(input, instructions, _options, context)

  // Simplified form is just the sorted array
  return new Response<Array<T>, Array<T>>(context, pending, (sorted) => sorted)
}
|