@botpress/zai 2.1.20 → 2.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,450 @@
1
+ import { z } from "@bpinternal/zui";
2
+ import pLimit from "p-limit";
3
+ import { ZaiContext } from "../context";
4
+ import { Response } from "../response";
5
+ import { getTokenizer } from "../tokenizer";
6
+ import { fastHash, stringify } from "../utils";
7
+ import { Zai } from "../zai";
8
+ import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from "./constants";
9
/**
 * Schema for the optional settings accepted by `sort`.
 * `tokensPerItem` caps how many tokens of each stringified item are kept
 * (between 1 and 100000, 250 when omitted).
 */
const _Options = z.object({
  tokensPerItem: z
    .number()
    .min(1)
    .max(1e5)
    .optional()
    .describe("The maximum number of tokens per item")
    .default(250)
});

/** Terminator appended to the prompts and passed as the stop sequence for scoring and tie-breaking. */
const END = "\u25A0END\u25A0";
13
/** Lowercases a token and replaces runs of whitespace with underscores, as the prompts require. */
const normalizeSortToken = (token) => token.trim().toLowerCase().replace(/\s+/g, "_");

/** Neutral rank for a label that could not be resolved: the middle of the scale (5 if the scale is empty). */
const middleRank = (labels) => (labels.length > 0 ? Math.floor(labels.length / 2) : 5);

/** True when `value` has the shape of a per-item score record ({ elementIndex, scores, totalScore }). */
const isScoreRecord = (value) =>
  value !== null &&
  typeof value === "object" &&
  Number.isInteger(value.elementIndex) &&
  typeof value.totalScore === "number" &&
  value.scores !== null &&
  typeof value.scores === "object";

/**
 * Parses the criteria reply (`■name■` followed by `label;label;...`) into
 * `{ [name]: { description, labels } }`. Criteria with fewer than 3 or more than 10 usable labels are dropped.
 * @throws {Error} when no criterion could be parsed.
 */
const parseSortingCriteria = (text) => {
  const criteria = {};
  const criterionRegex = /■([^■]+)■\s*([^\n■]+)/g;
  let match;
  while ((match = criterionRegex.exec(text)) !== null) {
    const name = (match[1] ?? "").trim().toLowerCase();
    // The terminator marker has the same ■...■ shape as a criterion header.
    if (!name || name === "end") continue;
    const labels = (match[2] ?? "")
      .trim()
      .split(";")
      .map(normalizeSortToken)
      .filter((label) => label.length > 0 && label.length < 50);
    if (labels.length >= 3 && labels.length <= 10) {
      criteria[name] = { description: `${labels.length} ordered labels`, labels };
    }
  }
  if (Object.keys(criteria).length === 0) {
    throw new Error(`Failed to parse sorting criteria. LLM output: ${text.slice(0, 500)}`);
  }
  return criteria;
};

/**
 * Builds a complete score record: every generated criterion gets a rank (the parsed one, or the
 * neutral middle rank when missing) and `totalScore` is the sum over exactly those criteria,
 * so totals of different items are always comparable.
 */
const completeScoreRecord = (elementIndex, parsedRanks, criteriaKeys, sortingCriteria) => {
  const scores = {};
  let totalScore = 0;
  for (const key of criteriaKeys) {
    const rank = parsedRanks.has(key) ? parsedRanks.get(key) : middleRank(sortingCriteria[key]?.labels ?? []);
    scores[key] = rank;
    totalScore += rank;
  }
  return { elementIndex, scores, totalScore };
};

/**
 * Parses the scoring reply (`■index:criterion=label;criterion=label■`) into one score record per
 * chunk position. Criterion names that were never generated are ignored instead of inflating the
 * total; items the model skipped receive neutral ranks.
 */
const parseChunkScores = (text, chunk, criteriaKeys, sortingCriteria) => {
  const results = [];
  const itemRegex = /■(\d+):([^■]+)■/g;
  let match;
  while ((match = itemRegex.exec(text)) !== null) {
    const idx = Number.parseInt(match[1] ?? "", 10);
    if (Number.isNaN(idx) || idx < 0 || idx >= chunk.length) continue;
    const parsedRanks = new Map();
    for (const pair of (match[2] ?? "").split(";")) {
      const [criterion, label] = pair.split("=").map(normalizeSortToken);
      if (!criterion || !label || !criteriaKeys.includes(criterion)) continue;
      const labels = sortingCriteria[criterion].labels;
      const labelIndex = labels.indexOf(label);
      parsedRanks.set(criterion, labelIndex >= 0 ? labelIndex : middleRank(labels));
    }
    results[idx] = completeScoreRecord(chunk[idx].index, parsedRanks, criteriaKeys, sortingCriteria);
  }
  for (let i = 0; i < chunk.length; i++) {
    if (!results[i]) {
      results[i] = completeScoreRecord(chunk[i].index, new Map(), criteriaKeys, sortingCriteria);
    }
  }
  return results;
};

/**
 * Parses the tie-break reply (`■index■` per line) into a permutation of `0..size-1`.
 * Repeated indices keep their first position; indices the model omitted are appended in ascending order.
 */
const parseTieBreakOrder = (text, size) => {
  const order = new Set(); // a Set keeps first-insertion order and drops repeats
  const indexRegex = /■(\d+)■/g;
  let match;
  while ((match = indexRegex.exec(text)) !== null) {
    const idx = Number.parseInt(match[1] ?? "", 10);
    if (!Number.isNaN(idx) && idx >= 0 && idx < size) {
      order.add(idx);
    }
  }
  for (let i = 0; i < size; i++) {
    order.add(i);
  }
  return [...order];
};

/**
 * Turns up to three stored examples into few-shot user/assistant message pairs.
 * An example is emitted only when its input parses and its output is a non-empty list of score
 * records, so a user turn is never pushed without its matching assistant turn.
 */
const buildSortExampleMessages = (examples, criteriaKeys) => {
  const messages = [];
  for (const example of examples.slice(0, 3)) {
    const exampleOutput = example.output;
    if (!Array.isArray(exampleOutput) || exampleOutput.length === 0 || !exampleOutput.every(isScoreRecord)) {
      continue;
    }
    let exampleInput;
    try {
      exampleInput = JSON.parse(example.input);
    } catch {
      // Few-shot examples are best-effort: an example whose input cannot be parsed is skipped.
      continue;
    }
    const exampleItems = Array.isArray(exampleInput) ? exampleInput : [exampleInput];
    const formattedScores = exampleOutput.map((score) => {
      const pairs = criteriaKeys.map((key) => `${key}=${score.scores[key] ?? 0}`).join(";");
      return `\u25A0${score.elementIndex}:${pairs}\u25A0`;
    }).join("\n");
    messages.push({
      type: "text",
      role: "user",
      content: `Expert Example - Items to score:
${exampleItems.map((el, i) => `\u25A0${i}: ${stringify(el, false).slice(0, 200)}\u25A0`).join("\n")}

Score each item.`
    });
    messages.push({
      type: "text",
      role: "assistant",
      content: `${formattedScores}
${END}`
    });
    if (example.explanation) {
      messages.push({
        type: "text",
        role: "assistant",
        content: `Reasoning: ${example.explanation}`
      });
    }
  }
  return messages;
};

/**
 * Sorts `input` according to natural-language `instructions` using the model behind `ctx`.
 *
 * Steps: (1) ask the model for 1-3 criteria, each an ordered list of labels; (2) label every item
 * per criterion, in chunks, and sum the label positions into a total score; (3) ask the model to
 * order items whose totals tie; (4) return the items ordered by ascending total, then tie-break rank.
 *
 * @param {Array} input - Items to sort. An empty array yields `[]`; a single-item array is returned as is.
 * @param {string} instructions - Natural-language description of the desired order.
 * @param {{ tokensPerItem?: number }} [_options] - Parsed with `_Options`.
 * @param {object} ctx - Provides `controller`, `getModel()`, `generateContent()`, `taskId`, `adapter`, `modelId`.
 * @returns {Promise<Array>} The items of `input` in sorted order.
 * @throws {Error} When no criteria can be parsed or an item ends up without a score; also throws if `ctx` is aborted.
 */
const sort = async (input, instructions, _options, ctx) => {
  ctx.controller.signal.throwIfAborted();
  const options = _Options.parse(_options ?? {});
  const tokenizer = await getTokenizer();
  const model = await ctx.getModel();
  const taskId = ctx.taskId;
  const taskType = "zai.sort";
  if (input.length === 0) {
    return [];
  }
  if (input.length === 1) {
    return input;
  }

  // Examples saved by this function hold the final sorted array under a key derived from the whole
  // input, so the exact-match lookup must happen here (not per chunk, where that array would be
  // mistaken for score records).
  const usesExamples = Boolean(taskId && ctx.adapter);
  let serializedInput;
  let taskKey;
  if (usesExamples) {
    serializedInput = JSON.stringify(input);
    taskKey = fastHash(
      stringify({
        taskId,
        taskType,
        input: serializedInput,
        instructions
      })
    );
    const savedExamples = await ctx.adapter.getExamples({
      input: serializedInput.slice(0, 1e3),
      taskType,
      taskId
    });
    const exactMatch = savedExamples.find((x) => x.key === taskKey);
    if (exactMatch && Array.isArray(exactMatch.output) && exactMatch.output.length === input.length) {
      return exactMatch.output;
    }
  }

  const TOKENS_TOTAL_MAX = model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER;
  const sampleSize = Math.min(5, input.length);
  const sampleItems = input.slice(0, sampleSize);
  const sampleItemsText = sampleItems.map((item, idx) => `\u25A0${idx}: ${stringify(item, false)}`).join("\n");
  const generateCriteriaPrompt = `Analyze this sorting instruction: "${instructions}"

Sample items to be sorted:
${sampleItemsText}

Create 1-3 sorting criteria with ordered label arrays (3-10 labels each).

**CRITICAL RULES**:
1. Labels are single words, lowercase, no spaces, use underscores
2. Labels are ordered from FIRST to LAST in sorted result
3. If instruction says "from X to Y": first label represents X, last label represents Y
4. If instruction says "prioritize" or "highest/lowest priority":
- First label = HIGHEST priority (top of todo list)
- Last label = LOWEST priority (bottom of todo list)

Examples:

"from slowest to fastest" \u2192 first=slowest, last=fastest
\u25A0speed\u25A0
very_slow;slow;medium;fast;very_fast
\u25A0END\u25A0

"from most dangerous to least dangerous" \u2192 first=most dangerous, last=least dangerous
\u25A0danger\u25A0
extremely_dangerous;very_dangerous;dangerous;moderate;slightly_dangerous;harmless
\u25A0END\u25A0

"from least urgent (spam) to most urgent (bills)" \u2192 first=spam, last=bills
\u25A0urgency\u25A0
spam;promotional;normal;important;urgent;critical
\u25A0END\u25A0

"prioritize: highest priority=open old tickets; lowest priority=closed" \u2192 first=high priority, last=low priority
\u25A0status\u25A0
open_old;open_recent;closed
\u25A0age\u25A0
oldest;old;recent;new
\u25A0END\u25A0

Output format:
\u25A0criterion_name\u25A0
label1;label2;label3;label4
\u25A0END\u25A0

Use 3-10 labels per criterion. Labels should be intuitive and match the domain.
Keep criterion names short (1-2 words, lowercase, underscores).
`;
  const { extracted: sortingCriteria } = await ctx.generateContent({
    systemPrompt: `You are creating sorting criteria with ordered label arrays.

CRITICAL: Output ordered labels from FIRST to LAST position in sorted result.
- Labels are single words, lowercase, underscores only
- 3-10 labels per criterion
- Order matters: first label = appears first, last label = appears last`,
    messages: [
      {
        type: "text",
        role: "user",
        content: generateCriteriaPrompt
      }
    ],
    transform: parseSortingCriteria
  });
  const criteriaKeys = Object.keys(sortingCriteria);
  if (criteriaKeys.length === 0) {
    throw new Error("No sorting criteria generated");
  }

  // A fifth of the usable budget is set aside for the criteria text; the rest bounds each chunk's items.
  const TOKENS_CRITERIA_MAX = Math.floor(TOKENS_TOTAL_MAX * 0.2);
  const TOKENS_ITEMS_MAX = TOKENS_TOTAL_MAX - TOKENS_CRITERIA_MAX;
  const MAX_ITEMS_PER_CHUNK = 50;
  const elements = input.map((element, idx) => {
    const stringified = stringify(element, false);
    return {
      element,
      index: idx,
      stringified,
      // Truncated once here and reused for chunk sizing, scoring and tie-breaking.
      truncated: tokenizer.truncate(stringified, options.tokensPerItem)
    };
  });
  const chunks = [];
  let currentChunk = [];
  let currentTokens = 0;
  for (const elem of elements) {
    const elemTokens = tokenizer.count(elem.truncated);
    const chunkIsFull = currentTokens + elemTokens > TOKENS_ITEMS_MAX || currentChunk.length >= MAX_ITEMS_PER_CHUNK;
    if (chunkIsFull && currentChunk.length > 0) {
      chunks.push(currentChunk);
      currentChunk = [];
      currentTokens = 0;
    }
    currentChunk.push(elem);
    currentTokens += elemTokens;
  }
  if (currentChunk.length > 0) {
    chunks.push(currentChunk);
  }

  const criteriaText = criteriaKeys.map((key) => `**${key}**: ${sortingCriteria[key].labels.join(";")}`).join("\n");

  /** Labels every item of one chunk and returns one score record per chunk position. */
  const scoreChunk = async (chunk) => {
    ctx.controller.signal.throwIfAborted();
    const chunkSize = chunk.length;
    const chunkInputStr = JSON.stringify(chunk.map((c) => c.element));
    const examples = usesExamples ? await ctx.adapter.getExamples({
      input: chunkInputStr.slice(0, 1e3),
      taskType,
      taskId
    }) : [];
    const elementsText = chunk.map((elem, i) => `\u25A0${i}: ${elem.truncated}\u25A0`).join("\n");
    const exampleMessages = buildSortExampleMessages(examples, criteriaKeys);
    const { extracted } = await ctx.generateContent({
      systemPrompt: `You are ranking items for sorting using ordered label arrays.

${criteriaText}

Instructions: "${instructions}"

SCORING RULES:
- For each item and each criterion, assign ONE label from the ordered list
- Labels are ordered: first label = appears FIRST in sorted result, last label = appears LAST
- Choose the label that best describes each item

Output format:
\u25A00:criterion1=label;criterion2=label\u25A0
\u25A01:criterion1=label;criterion2=label\u25A0
${END}

IMPORTANT:
- Rank every item (\u25A00 to \u25A0${chunkSize - 1})
- Use exact criterion names: ${criteriaKeys.join(", ")}
- Use exact labels from the lists above (lowercase, underscores)
- Use semicolons (;) between criteria
- Use equals (=) between criterion and label`,
      stopSequences: [END],
      messages: [
        ...exampleMessages,
        {
          type: "text",
          role: "user",
          content: `Items to rank (\u25A00 to \u25A0${chunkSize - 1}):
${elementsText}

Rank each item using the labeled scales.
Output format: \u25A0index:criterion1=label;criterion2=label\u25A0
${END}`
        }
      ],
      transform: (text) => parseChunkScores(text, chunk, criteriaKeys, sortingCriteria)
    });
    return extracted;
  };

  const limit = pLimit(10);
  const allScores = await Promise.all(chunks.map((chunk) => limit(() => scoreChunk(chunk))));

  // Every element belongs to exactly one chunk, so each index is expected exactly once.
  const scoreMap = /* @__PURE__ */ new Map();
  for (const chunkScores of allScores) {
    for (const itemScore of chunkScores) {
      scoreMap.set(itemScore.elementIndex, {
        scores: { ...itemScore.scores },
        totalScore: itemScore.totalScore
      });
    }
  }
  if (scoreMap.size !== input.length) {
    throw new Error(`Score map size mismatch: expected ${input.length}, got ${scoreMap.size}`);
  }

  // Group element indices by total score to find the ties that need a model-decided order.
  const scoreGroups = /* @__PURE__ */ new Map();
  for (const [index, scoreData] of scoreMap.entries()) {
    const roundedScore = Math.round(scoreData.totalScore * 100);
    const group = scoreGroups.get(roundedScore) ?? [];
    group.push(index);
    scoreGroups.set(roundedScore, group);
  }
  const tiedGroups = Array.from(scoreGroups.values()).filter((group) => group.length > 1);
  if (tiedGroups.length > 0) {
    const tieBreakCriteria = criteriaKeys.map((key) => `- ${key}: ${sortingCriteria[key].labels.join(";")}`).join("\n");
    const tieBreakLimit = pLimit(10);
    await Promise.all(
      tiedGroups.map(
        (tiedIndices) => tieBreakLimit(async () => {
          const tiedElements = tiedIndices.map((idx) => elements[idx]);
          const tieBreakText = tiedElements.map((elem, i) => `\u25A0${i}: ${elem.truncated}\u25A0`).join("\n");
          const { extracted: tieBreakOrder } = await ctx.generateContent({
            systemPrompt: `You are breaking a tie between items with identical total scores.

Instructions: ${instructions}

Criteria:
${tieBreakCriteria}

Order these ${tiedElements.length} items from FIRST to LAST based on the instructions.
Earlier labels in each criterion should come FIRST.

Output format:
\u25A0original_index\u25A0
\u25A0original_index\u25A0
${END}

Output the indices in the order they should appear (first item at top).`,
            stopSequences: [END],
            messages: [
              {
                type: "text",
                role: "user",
                content: `Items with identical scores (need tie-breaking):
${tieBreakText}

Order them from first to last.
Output format: \u25A0index\u25A0 (one per line)
${END}`
              }
            ],
            transform: (text) => parseTieBreakOrder(text, tiedElements.length)
          });
          tieBreakOrder.forEach((tiedPosition, rank) => {
            const scoreData = scoreMap.get(tiedElements[tiedPosition].index);
            if (scoreData) {
              scoreData.tieBreakOrder = rank;
            }
          });
        })
      )
    );
  }

  const result = Array.from(scoreMap.entries()).sort((a, b) => {
    const scoreDiff = a[1].totalScore - b[1].totalScore;
    if (scoreDiff !== 0) return scoreDiff;
    return (a[1].tieBreakOrder ?? 0) - (b[1].tieBreakOrder ?? 0);
  }).map(([index]) => elements[index].element);

  if (usesExamples && !ctx.controller.signal.aborted) {
    await ctx.adapter.saveExample({
      key: taskKey,
      taskType,
      taskId,
      input: serializedInput,
      output: result,
      instructions,
      metadata: {
        cost: { input: 0, output: 0 },
        latency: 0,
        model: ctx.modelId,
        tokens: { input: 0, output: 0 }
      }
    });
  }
  return result;
};
436
/**
 * Public `sort` entry point on `Zai`: builds a per-call `ZaiContext` from this instance's
 * client, model, task id and adapter, starts the sort, and wraps the pending result in a `Response`.
 */
Zai.prototype.sort = function(input, instructions, _options) {
  const { client, Model: modelId, taskId, adapter } = this;
  const context = new ZaiContext({
    client,
    modelId,
    taskId,
    taskType: "zai.sort",
    adapter
  });
  const pending = sort(input, instructions, _options, context);
  // The simplified form of the response is the sorted array itself.
  const simplify = (result) => result;
  return new Response(context, pending, simplify);
};