@botpress/zai 2.2.0 → 2.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +53 -2
- package/dist/index.d.ts +65 -16
- package/dist/index.js +2 -0
- package/dist/operations/group.js +93 -2
- package/dist/operations/rate.js +350 -0
- package/dist/operations/sort.js +450 -0
- package/e2e/data/cache.jsonl +142 -0
- package/package.json +1 -1
- package/src/index.ts +2 -0
- package/src/operations/group.ts +124 -2
- package/src/operations/rate.ts +518 -0
- package/src/operations/sort.ts +618 -0
|
@@ -0,0 +1,350 @@
|
|
|
1
|
+
import { z } from "@bpinternal/zui";
|
|
2
|
+
import pLimit from "p-limit";
|
|
3
|
+
import { ZaiContext } from "../context";
|
|
4
|
+
import { Response } from "../response";
|
|
5
|
+
import { getTokenizer } from "../tokenizer";
|
|
6
|
+
import { fastHash, stringify } from "../utils";
|
|
7
|
+
import { Zai } from "../zai";
|
|
8
|
+
import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from "./constants";
|
|
9
|
+
/**
 * Numeric score (1-5) for each rating label the model is asked to emit.
 *
 * The table is indexed with labels parsed from model output, so it is built
 * on a null prototype: an unexpected label such as "constructor" or
 * "toString" yields `undefined` instead of an inherited Object.prototype
 * member. It is frozen because it is a shared module-level constant.
 */
const RATING_VALUES = Object.freeze(
  Object.assign(Object.create(null), {
    very_bad: 1,
    bad: 2,
    average: 3,
    good: 4,
    very_good: 5
  })
);
|
|
16
|
+
// Schema for the optional settings accepted by `rate`. Both fields are
// optional and fall back to the defaults declared here when omitted.
const _Options = z.object({
  // Per-item token budget; items are truncated to this size before rating.
  tokensPerItem: z
    .number()
    .min(1)
    .max(100000)
    .optional()
    .describe("The maximum number of tokens per item")
    .default(250),
  // Upper bound on how many items are sent to the model in one request.
  maxItemsPerChunk: z
    .number()
    .min(1)
    .max(100)
    .optional()
    .describe("The maximum number of items to rate per chunk")
    .default(50)
});
|
|
20
|
+
// Sentinel appended after the last rating line in prompts and few-shot
// examples; it is also passed to the model as a stop sequence.
const END = "\u25A0END\u25A0";
|
|
21
|
+
/**
 * Rates every element of `input` on a 1-5 scale against a set of criteria.
 *
 * Flow:
 *  1. Ask the model to produce, for each criterion, a description of the five
 *     labels (very_bad .. very_good). With string instructions the model also
 *     invents the criteria; with object instructions the keys are the criteria.
 *  2. Split the input into chunks bounded by a token budget and by
 *     `maxItemsPerChunk`.
 *  3. Rate the chunks concurrently (at most 10 in flight) and parse the
 *     `■index:criterion=label;...■` lines of each answer.
 *  4. When a task id and an adapter are configured, save the run as an example.
 *
 * @param {Array} input - Items to rate; each is stringified and truncated to
 *   `tokensPerItem` tokens.
 * @param {string|Object} instructions - Free-form instructions, or an object
 *   whose keys are criterion names and whose values describe them.
 * @param {Object} [_options] - Optional `tokensPerItem` / `maxItemsPerChunk`,
 *   validated with `_Options`.
 * @param {Object} ctx - Execution context providing `controller`, `getModel`,
 *   `generateContent`, `taskId`, `adapter` and `modelId`.
 * @returns {Promise<Array<Object>>} One object per input item, in input order,
 *   holding a 1-5 score for every criterion plus their sum under `total`.
 * @throws If the operation is aborted, the options are invalid, or no
 *   evaluation criteria could be obtained from the model.
 */
const rate = async (input, instructions, _options, ctx) => {
  ctx.controller.signal.throwIfAborted();
  const options = _Options.parse(_options ?? {});
  // Nothing to rate: return before loading the tokenizer/model or calling the LLM.
  if (input.length === 0) {
    return [];
  }
  const tokenizer = await getTokenizer();
  const model = await ctx.getModel();
  const taskId = ctx.taskId;
  const taskType = "zai.rate";
  // Score used whenever the model gives no usable label for a criterion.
  const DEFAULT_RATING = 3;
  // Reverse lookup (score -> label) used to render stored examples.
  const LABEL_BY_VALUE = Object.fromEntries(
    Object.entries(RATING_VALUES).map(([label, value]) => [value, label])
  );
  const TOKENS_TOTAL_MAX = model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER;
  const isStringInstructions = typeof instructions === "string";
  const criteriaKeys = isStringInstructions ? [] : Object.keys(instructions);
  const generateCriteriaPrompt = isStringInstructions ? `Generate 3-5 evaluation criteria for: "${instructions}"

For each criterion, provide 5 labels (very_bad, bad, average, good, very_good) with brief descriptions.

Output format (JSON):
{
"criterion1_name": {
"very_bad": "description",
"bad": "description",
"average": "description",
"good": "description",
"very_good": "description"
},
"criterion2_name": { ... }
}

Keep criterion names short (1-2 words, lowercase, use underscores).
Keep descriptions brief (5-10 words each).` : `For these evaluation criteria, provide 5 labels (very_bad, bad, average, good, very_good) with brief descriptions for each:

${criteriaKeys.map((key) => `- ${key}: ${instructions[key]}`).join("\n")}

Output format (JSON):
{
"${criteriaKeys[0]}": {
"very_bad": "description",
"bad": "description",
"average": "description",
"good": "description",
"very_good": "description"
}
${criteriaKeys.length > 1 ? "..." : ""}
}

Keep descriptions brief (5-10 words each).`;
  const { extracted: evaluationCriteria } = await ctx.generateContent({
    systemPrompt: `You are creating evaluation criteria for rating items on a 1-5 scale.
Each criterion must have exactly 5 labels: very_bad (1), bad (2), average (3), good (4), very_good (5).
Output valid JSON only.`,
    messages: [
      {
        type: "text",
        role: "user",
        content: generateCriteriaPrompt
      }
    ],
    transform: (text) => {
      // Prefer a fenced JSON block; otherwise take the outermost {...} span.
      const jsonMatch = text.match(/```(?:json)?\s*(\{[\s\S]*?\})\s*```/) || text.match(/(\{[\s\S]*\})/);
      if (!jsonMatch) {
        throw new Error("Failed to parse evaluation criteria JSON");
      }
      return JSON.parse(jsonMatch[1]);
    }
  });
  const finalCriteriaKeys = Object.keys(evaluationCriteria);
  if (finalCriteriaKeys.length === 0) {
    throw new Error("No evaluation criteria generated");
  }
  // Reserve 30% of the usable prompt budget for the criteria, the rest for items.
  const TOKENS_CRITERIA_MAX = Math.floor(TOKENS_TOTAL_MAX * 0.3);
  const TOKENS_ITEMS_MAX = TOKENS_TOTAL_MAX - TOKENS_CRITERIA_MAX;

  // Greedy chunking: start a new chunk once the token budget or the item cap
  // would be exceeded. Only non-empty chunks are ever pushed.
  const chunks = [];
  let currentChunk = [];
  let currentChunkTokens = 0;
  for (const element of input) {
    const elementAsString = tokenizer.truncate(stringify(element, false), options.tokensPerItem);
    const elementTokens = tokenizer.count(elementAsString);
    if (currentChunkTokens + elementTokens > TOKENS_ITEMS_MAX || currentChunk.length >= options.maxItemsPerChunk) {
      if (currentChunk.length > 0) {
        chunks.push(currentChunk);
      }
      currentChunk = [];
      currentChunkTokens = 0;
    }
    currentChunk.push(element);
    currentChunkTokens += elementTokens;
  }
  if (currentChunk.length > 0) {
    chunks.push(currentChunk);
  }

  // The criteria description is identical for every chunk, so render it once.
  const criteriaDescription = finalCriteriaKeys.map((criterion) => {
    const labels = evaluationCriteria[criterion];
    return `**${criterion}**:
- very_bad (1): ${labels?.very_bad}
- bad (2): ${labels?.bad}
- average (3): ${labels?.average}
- good (4): ${labels?.good}
- very_good (5): ${labels?.very_good}`;
  }).join("\n\n");

  // Renders items as "■index: content■" lines, truncating each item.
  const formatItems = (items) => {
    return items.map((item, idx) => {
      const itemStr = tokenizer.truncate(stringify(item, false), options.tokensPerItem);
      return `\u25A0${idx}: ${itemStr}\u25A0`;
    }).join("\n");
  };

  // Maps a criterion name emitted by the model onto one of the expected
  // criteria (exact match first, then case-insensitive); undefined if unknown.
  const knownCriteria = new Set(finalCriteriaKeys);
  const criterionByLowerName = new Map(finalCriteriaKeys.map((name) => [name.toLowerCase(), name]));
  const resolveCriterion = (rawName) => {
    return knownCriteria.has(rawName) ? rawName : criterionByLowerName.get(rawName.toLowerCase());
  };

  // Builds a rating object covering every expected criterion, so each item
  // has the same shape and `total` always sums the same set of criteria.
  const buildRatings = (valueFor) => {
    const ratings = {};
    let total = 0;
    for (const criterion of finalCriteriaKeys) {
      const value = valueFor(criterion) ?? DEFAULT_RATING;
      ratings[criterion] = value;
      total += value;
    }
    ratings.total = total;
    return ratings;
  };

  const rateChunk = async (chunk) => {
    ctx.controller.signal.throwIfAborted();
    const chunkInputStr = JSON.stringify(chunk);
    const examples = taskId && ctx.adapter ? await ctx.adapter.getExamples({
      input: chunkInputStr.slice(0, 1e3),
      // Limit search string length
      taskType,
      taskId
    }) : [];
    // NOTE: examples are saved below under a key derived from the whole input,
    // so this lookup can only hit when the entire input fits in one chunk.
    const key = fastHash(
      stringify({
        taskId,
        taskType,
        input: chunkInputStr,
        instructions: stringify(instructions)
      })
    );
    const exactMatch = examples.find((x) => x.key === key);
    // Reuse a stored answer only when it lines up one-to-one with this chunk;
    // anything else would misalign ratings once chunks are flattened.
    if (Array.isArray(exactMatch?.output) && exactMatch.output.length === chunk.length) {
      return {
        ratings: exactMatch.output,
        meta: { cost: { input: 0, output: 0 }, latency: 0, tokens: { input: 0, output: 0 } }
      };
    }

    const exampleMessages = [];
    for (const example of examples.slice(0, 5)) {
      const exampleOutput = example.output;
      if (!Array.isArray(exampleOutput) || exampleOutput.length === 0) {
        continue;
      }
      try {
        // Assemble the whole exchange first and append it only on success, so
        // a malformed example never leaves a user turn without its answer.
        const exampleInput = JSON.parse(example.input);
        const exchange = [
          {
            type: "text",
            role: "user",
            content: `Expert Example - Items to rate:
${formatItems(Array.isArray(exampleInput) ? exampleInput : [exampleInput])}

Rate each item on all criteria.`
          }
        ];
        const formattedRatings = exampleOutput.map((rating, idx) => {
          const pairs = finalCriteriaKeys.map((criterion) => {
            const value = rating[criterion];
            if (typeof value === "number") {
              return `${criterion}=${LABEL_BY_VALUE[value] || "average"}`;
            }
            return null;
          }).filter(Boolean).join(";");
          return `\u25A0${idx}:${pairs}\u25A0`;
        }).join("\n");
        exchange.push({
          type: "text",
          role: "assistant",
          content: `${formattedRatings}
${END}`
        });
        if (example.explanation) {
          exchange.push({
            type: "text",
            role: "assistant",
            content: `Reasoning: ${example.explanation}`
          });
        }
        exampleMessages.push(...exchange);
      } catch {
        // Examples are best-effort context: skip any that cannot be parsed or
        // rendered rather than failing the rating request.
      }
    }

    const { extracted, meta } = await ctx.generateContent({
      systemPrompt: `You are rating items based on evaluation criteria.

Evaluation Criteria:
${criteriaDescription}

For each item, rate it on EACH criterion using one of these labels:
very_bad, bad, average, good, very_good

Output format:
\u25A00:criterion1=label;criterion2=label;criterion3=label\u25A0
\u25A01:criterion1=label;criterion2=label;criterion3=label\u25A0
${END}

IMPORTANT:
- Rate every item (\u25A00 to \u25A0${chunk.length - 1})
- Use exact criterion names: ${finalCriteriaKeys.join(", ")}
- Use exact label names: very_bad, bad, average, good, very_good
- Use semicolons (;) between criteria
- Use equals (=) between criterion and label`,
      stopSequences: [END],
      messages: [
        ...exampleMessages,
        {
          type: "text",
          role: "user",
          content: `Items to rate (\u25A00 to \u25A0${chunk.length - 1}):
${formatItems(chunk)}

Rate each item on all criteria.
Output format: \u25A0index:criterion1=label;criterion2=label\u25A0
${END}`
        }
      ],
      transform: (text) => {
        const results = [];
        for (const match of text.matchAll(/■(\d+):([^■]+)■/g)) {
          const idx = Number.parseInt(match[1], 10);
          // Ignore lines that point outside the chunk.
          if (Number.isNaN(idx) || idx < 0 || idx >= chunk.length) {
            continue;
          }
          const parsed = new Map();
          for (const pair of match[2].split(";")) {
            const [rawCriterion, rawLabel] = pair.split("=").map((x) => x.trim());
            if (!rawCriterion || !rawLabel) continue;
            const criterion = resolveCriterion(rawCriterion);
            if (criterion === undefined) continue;
            const label = rawLabel.toLowerCase().replace(/\s+/g, "_");
            // Own-property check: labels such as "constructor" must not
            // resolve to inherited members of the lookup table.
            parsed.set(criterion, Object.hasOwn(RATING_VALUES, label) ? RATING_VALUES[label] : DEFAULT_RATING);
          }
          results[idx] = buildRatings((criterion) => parsed.get(criterion));
        }
        // Items the model skipped get the neutral default on every criterion.
        for (let i = 0; i < chunk.length; i++) {
          if (!results[i]) {
            results[i] = buildRatings(() => undefined);
          }
        }
        return results;
      }
    });
    return { ratings: extracted, meta };
  };

  // Rate chunks concurrently, at most 10 requests in flight.
  const limit = pLimit(10);
  const chunkPromises = chunks.map((chunk) => limit(() => rateChunk(chunk)));
  const ratedChunks = await Promise.all(chunkPromises);
  const allRatings = ratedChunks.flatMap((result) => result.ratings);
  const totalMeta = ratedChunks.reduce(
    (acc, result) => ({
      cost: {
        input: acc.cost.input + result.meta.cost.input,
        output: acc.cost.output + result.meta.cost.output
      },
      // Chunks run concurrently, so report the slowest one rather than the sum.
      latency: Math.max(acc.latency, result.meta.latency),
      tokens: {
        input: acc.tokens.input + result.meta.tokens.input,
        output: acc.tokens.output + result.meta.tokens.output
      }
    }),
    {
      cost: { input: 0, output: 0 },
      latency: 0,
      tokens: { input: 0, output: 0 }
    }
  );
  if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
    const serializedInput = JSON.stringify(input);
    const key = fastHash(
      stringify({
        taskId,
        taskType,
        input: serializedInput,
        instructions: stringify(instructions)
      })
    );
    await ctx.adapter.saveExample({
      key,
      taskType,
      taskId,
      input: serializedInput,
      output: allRatings,
      instructions: typeof instructions === "string" ? instructions : JSON.stringify(instructions),
      metadata: {
        cost: {
          input: totalMeta.cost.input,
          output: totalMeta.cost.output
        },
        latency: totalMeta.latency,
        model: ctx.modelId,
        tokens: {
          input: totalMeta.tokens.input,
          output: totalMeta.tokens.output
        }
      }
    });
  }
  return allRatings;
};
|
|
332
|
+
/**
 * Public entry point: rates `input` according to `instructions`.
 *
 * Creates a fresh ZaiContext for the call, starts the `rate` operation, and
 * wraps the pending result in a Response. With string instructions the
 * Response's simplified value is the list of per-item totals; with object
 * instructions it is the full per-criterion rating objects.
 */
Zai.prototype.rate = function(input, instructions, _options) {
  const { client, Model: modelId, taskId, adapter } = this;
  const context = new ZaiContext({
    client,
    modelId,
    taskId,
    taskType: "zai.rate",
    adapter
  });
  const pending = rate(input, instructions, _options, context);
  const wantsTotalsOnly = typeof instructions === "string";
  const simplify = (result) => {
    return wantsTotalsOnly ? result.map(({ total }) => total) : result;
  };
  return new Response(context, pending, simplify);
};
|