@botpress/zai 2.2.0 → 2.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,350 @@
1
+ import { z } from "@bpinternal/zui";
2
+ import pLimit from "p-limit";
3
+ import { ZaiContext } from "../context";
4
+ import { Response } from "../response";
5
+ import { getTokenizer } from "../tokenizer";
6
+ import { fastHash, stringify } from "../utils";
7
+ import { Zai } from "../zai";
8
+ import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from "./constants";
9
/**
 * Numeric score (1-5) for each rating label the model is allowed to answer with.
 *
 * The table is indexed with label text that comes straight from model output,
 * so it is created without a prototype: an unexpected label such as
 * "constructor" or "toString" then resolves to `undefined` (and falls back to
 * the caller's default) instead of returning an inherited Object.prototype
 * member. It is frozen because it is a shared module-level constant.
 */
const RATING_VALUES = Object.freeze(
  Object.assign(Object.create(null), {
    very_bad: 1,
    bad: 2,
    average: 3,
    good: 4,
    very_good: 5,
  })
);
16
/**
 * Schema for the optional settings accepted by the rate operation.
 *
 * - `tokensPerItem`: number between 1 and 100000, defaults to 250.
 * - `maxItemsPerChunk`: number between 1 and 100, defaults to 50.
 */
const _Options = z.object({
  tokensPerItem: z
    .number()
    .min(1)
    .max(100000)
    .optional()
    .describe("The maximum number of tokens per item")
    .default(250),
  maxItemsPerChunk: z
    .number()
    .min(1)
    .max(100)
    .optional()
    .describe("The maximum number of items to rate per chunk")
    .default(50),
});
20
/**
 * Sentinel that terminates a ratings answer: it is shown to the model as the
 * last line of the expected output and passed as the stop sequence.
 */
const END = "\u25A0END\u25A0";
21
/** Rating used when the model omits a criterion or answers with an unknown label. */
const NEUTRAL_RATING = 3;

/** Rating labels ordered by score: index 0 is score 1, index 4 is score 5. */
const RATING_LABELS = ["very_bad", "bad", "average", "good", "very_good"];

/** Extracts and parses the first JSON object (fenced or bare) found in the model's criteria reply. */
const extractCriteriaJson = (text) => {
  const jsonMatch = text.match(/```(?:json)?\s*(\{[\s\S]*?\})\s*```/) || text.match(/(\{[\s\S]*\})/);
  if (!jsonMatch) {
    throw new Error("Failed to parse evaluation criteria JSON");
  }
  return JSON.parse(jsonMatch[1]);
};

/** Greedily groups items, in order, into chunks bounded by a token budget and an item count. */
const splitIntoChunks = (items, countTokens, maxTokens, maxItems) => {
  const chunks = [];
  let current = [];
  let currentTokens = 0;
  for (const item of items) {
    const itemTokens = countTokens(item);
    const isFull = currentTokens + itemTokens > maxTokens || current.length >= maxItems;
    if (isFull && current.length > 0) {
      chunks.push(current);
      current = [];
      currentTokens = 0;
    }
    // An item that alone exceeds the budget still gets a chunk of its own.
    current.push(item);
    currentTokens += itemTokens;
  }
  if (current.length > 0) {
    chunks.push(current);
  }
  return chunks;
};

/** Builds one rating record holding every criterion (neutral when absent) plus their sum as `total`. */
const buildRatingRecord = (criteriaKeys, ratedValues) => {
  const record = {};
  let total = 0;
  for (const key of criteriaKeys) {
    const value = ratedValues.get(key) ?? NEUTRAL_RATING;
    record[key] = value;
    total += value;
  }
  record.total = total;
  return record;
};

/**
 * Parses `■index:criterion=label;criterion=label■` entries into one rating
 * record per item.
 *
 * Only criteria listed in `criteriaKeys` are kept (matched exactly, then
 * case-insensitively); a repeated criterion keeps its last value; missing
 * criteria, unknown labels and items the model skipped all get the neutral
 * rating. `total` is therefore always the sum of the values in the record.
 */
const parseRatingsOutput = (text, itemCount, criteriaKeys) => {
  const canonicalByName = new Map(criteriaKeys.map((key) => [key, key]));
  for (const key of criteriaKeys) {
    const lowered = key.toLowerCase();
    if (!canonicalByName.has(lowered)) {
      canonicalByName.set(lowered, key);
    }
  }

  const records = new Array(itemCount).fill(undefined);
  for (const match of text.matchAll(/■(\d+):([^■]+)■/g)) {
    const index = Number.parseInt(match[1], 10);
    if (Number.isNaN(index) || index < 0 || index >= itemCount) {
      continue;
    }
    const rated = new Map();
    for (const pair of match[2].split(";")) {
      const [rawCriterion, rawLabel] = pair.split("=").map((part) => part.trim());
      if (!rawCriterion || !rawLabel) {
        continue;
      }
      const criterion = canonicalByName.get(rawCriterion) ?? canonicalByName.get(rawCriterion.toLowerCase());
      if (criterion === undefined) {
        // Criterion invented by the model: ignore it rather than polluting the record.
        continue;
      }
      const label = rawLabel.toLowerCase().replace(/\s+/g, "_");
      // Own-property check so labels like "constructor" never resolve through the prototype chain.
      rated.set(criterion, Object.hasOwn(RATING_VALUES, label) ? RATING_VALUES[label] : NEUTRAL_RATING);
    }
    records[index] = buildRatingRecord(criteriaKeys, rated);
  }
  return records.map((record) => record ?? buildRatingRecord(criteriaKeys, new Map()));
};

/**
 * Converts up to five stored examples into few-shot messages.
 *
 * Each usable example contributes a user message, an assistant answer and,
 * when the example has an explanation, a follow-up assistant message. The
 * messages of an example are only added once all of them were built, so a
 * malformed example never leaves a user message without its answer.
 */
const buildExampleMessages = (examples, criteriaKeys, formatItems) => {
  const messages = [];
  for (const example of examples.slice(0, 5)) {
    try {
      const exampleOutput = example.output;
      if (!Array.isArray(exampleOutput) || exampleOutput.length === 0) {
        continue;
      }
      const exampleInput = JSON.parse(example.input);
      const formattedRatings = exampleOutput
        .map((rating, idx) => {
          const pairs = criteriaKeys
            .map((key) => {
              const value = rating[key];
              return typeof value === "number" ? `${key}=${RATING_LABELS[value - 1] || "average"}` : null;
            })
            .filter(Boolean)
            .join(";");
          return `\u25A0${idx}:${pairs}\u25A0`;
        })
        .join("\n");
      const exampleTurns = [
        {
          type: "text",
          role: "user",
          content: `Expert Example - Items to rate:
${formatItems(Array.isArray(exampleInput) ? exampleInput : [exampleInput])}

Rate each item on all criteria.`
        },
        {
          type: "text",
          role: "assistant",
          content: `${formattedRatings}
${END}`
        }
      ];
      if (example.explanation) {
        exampleTurns.push({
          type: "text",
          role: "assistant",
          content: `Reasoning: ${example.explanation}`
        });
      }
      messages.push(...exampleTurns);
    } catch {
      // Examples are best-effort context: an unparsable one is skipped, not fatal.
    }
  }
  return messages;
};

/**
 * Rates every item of `input` on a 1-5 scale for a set of evaluation criteria.
 *
 * Flow: (1) ask the model for the five label descriptions of each criterion
 * (criteria are generated when `instructions` is a string, taken from its keys
 * otherwise); (2) split the items into chunks; (3) rate the chunks with at
 * most 10 concurrent requests; (4) when a task id and an adapter are
 * available, store the run as an example.
 *
 * @param {Array<unknown>} input - Items to rate; each is stringified and truncated to `tokensPerItem` tokens.
 * @param {string | Record<string, string>} instructions - Free-text goal, or a map of criterion name to description.
 * @param {object} [_options] - Validated against `_Options` (`tokensPerItem`, `maxItemsPerChunk`).
 * @param {object} ctx - Context providing `controller`, `getModel`, `generateContent`, `taskId`, `adapter` and `modelId`.
 * @returns {Promise<Array<Record<string, number>>>} One record per item, in input order, with a rating per
 *   criterion and their sum under `total`.
 * @throws If the operation was aborted, the options are invalid, the criteria reply contains no JSON
 *   object, or no criteria were produced.
 */
const rate = async (input, instructions, _options, ctx) => {
  ctx.controller.signal.throwIfAborted();
  const options = _Options.parse(_options ?? {});
  if (input.length === 0) {
    // Nothing to rate: skip loading the tokenizer and the model altogether.
    return [];
  }
  const [tokenizer, model] = await Promise.all([getTokenizer(), ctx.getModel()]);

  const taskId = ctx.taskId;
  const taskType = "zai.rate";
  const TOKENS_TOTAL_MAX = model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER;
  const isStringInstructions = typeof instructions === "string";
  const criteriaKeys = isStringInstructions ? [] : Object.keys(instructions);
  const generateCriteriaPrompt = isStringInstructions ? `Generate 3-5 evaluation criteria for: "${instructions}"

For each criterion, provide 5 labels (very_bad, bad, average, good, very_good) with brief descriptions.

Output format (JSON):
{
"criterion1_name": {
"very_bad": "description",
"bad": "description",
"average": "description",
"good": "description",
"very_good": "description"
},
"criterion2_name": { ... }
}

Keep criterion names short (1-2 words, lowercase, use underscores).
Keep descriptions brief (5-10 words each).` : `For these evaluation criteria, provide 5 labels (very_bad, bad, average, good, very_good) with brief descriptions for each:

${criteriaKeys.map((key) => `- ${key}: ${instructions[key]}`).join("\n")}

Output format (JSON):
{
"${criteriaKeys[0]}": {
"very_bad": "description",
"bad": "description",
"average": "description",
"good": "description",
"very_good": "description"
}
${criteriaKeys.length > 1 ? "..." : ""}
}

Keep descriptions brief (5-10 words each).`;

  const { extracted: evaluationCriteria } = await ctx.generateContent({
    systemPrompt: `You are creating evaluation criteria for rating items on a 1-5 scale.
Each criterion must have exactly 5 labels: very_bad (1), bad (2), average (3), good (4), very_good (5).
Output valid JSON only.`,
    messages: [
      {
        type: "text",
        role: "user",
        content: generateCriteriaPrompt
      }
    ],
    transform: extractCriteriaJson
  });
  // The criteria actually rated are the ones the model returned.
  const finalCriteriaKeys = Object.keys(evaluationCriteria);
  if (finalCriteriaKeys.length === 0) {
    throw new Error("No evaluation criteria generated");
  }

  // 30% of the usable prompt budget is set aside for the criteria text; items get the rest.
  const TOKENS_CRITERIA_MAX = Math.floor(TOKENS_TOTAL_MAX * 0.3);
  const TOKENS_ITEMS_MAX = TOKENS_TOTAL_MAX - TOKENS_CRITERIA_MAX;

  const renderItem = (item) => tokenizer.truncate(stringify(item, false), options.tokensPerItem);
  const formatItems = (items) => items.map((item, idx) => `\u25A0${idx}: ${renderItem(item)}\u25A0`).join("\n");

  const chunks = splitIntoChunks(
    input,
    (item) => tokenizer.count(renderItem(item)),
    TOKENS_ITEMS_MAX,
    options.maxItemsPerChunk
  );

  // Identical for every chunk, so it is rendered once.
  const criteriaText = finalCriteriaKeys.map((key) => {
    const labels = evaluationCriteria[key];
    return `**${key}**:
- very_bad (1): ${labels?.very_bad}
- bad (2): ${labels?.bad}
- average (3): ${labels?.average}
- good (4): ${labels?.good}
- very_good (5): ${labels?.very_good}`;
  }).join("\n\n");

  const rateChunk = async (chunk) => {
    ctx.controller.signal.throwIfAborted();
    const chunkInputStr = JSON.stringify(chunk);
    const examples = taskId && ctx.adapter ? await ctx.adapter.getExamples({
      input: chunkInputStr.slice(0, 1e3),
      // Limit search string length
      taskType,
      taskId
    }) : [];
    const key = fastHash(
      stringify({
        taskId,
        taskType,
        input: chunkInputStr,
        instructions: stringify(instructions)
      })
    );
    const exactMatch = examples.find((x) => x.key === key);
    // Reuse a stored answer only when it really holds one rating per item of this chunk.
    if (Array.isArray(exactMatch?.output) && exactMatch.output.length === chunk.length) {
      return {
        ratings: exactMatch.output,
        meta: { cost: { input: 0, output: 0 }, latency: 0, tokens: { input: 0, output: 0 } }
      };
    }

    const lastIndex = chunk.length - 1;
    const { extracted, meta } = await ctx.generateContent({
      systemPrompt: `You are rating items based on evaluation criteria.

Evaluation Criteria:
${criteriaText}

For each item, rate it on EACH criterion using one of these labels:
very_bad, bad, average, good, very_good

Output format:
\u25A00:criterion1=label;criterion2=label;criterion3=label\u25A0
\u25A01:criterion1=label;criterion2=label;criterion3=label\u25A0
${END}

IMPORTANT:
- Rate every item (\u25A00 to \u25A0${lastIndex})
- Use exact criterion names: ${finalCriteriaKeys.join(", ")}
- Use exact label names: very_bad, bad, average, good, very_good
- Use semicolons (;) between criteria
- Use equals (=) between criterion and label`,
      stopSequences: [END],
      messages: [
        ...buildExampleMessages(examples, finalCriteriaKeys, formatItems),
        {
          type: "text",
          role: "user",
          content: `Items to rate (\u25A00 to \u25A0${lastIndex}):
${formatItems(chunk)}

Rate each item on all criteria.
Output format: \u25A0index:criterion1=label;criterion2=label\u25A0
${END}`
        }
      ],
      transform: (text) => parseRatingsOutput(text, chunk.length, finalCriteriaKeys)
    });
    return { ratings: extracted, meta };
  };

  const limit = pLimit(10);
  const ratedChunks = await Promise.all(chunks.map((chunk) => limit(() => rateChunk(chunk))));
  const allRatings = ratedChunks.flatMap((result) => result.ratings);
  const totalMeta = ratedChunks.reduce(
    (acc, result) => ({
      cost: {
        input: acc.cost.input + result.meta.cost.input,
        output: acc.cost.output + result.meta.cost.output
      },
      // Chunks run concurrently, so the slowest chunk is reported rather than the sum.
      latency: Math.max(acc.latency, result.meta.latency),
      tokens: {
        input: acc.tokens.input + result.meta.tokens.input,
        output: acc.tokens.output + result.meta.tokens.output
      }
    }),
    {
      cost: { input: 0, output: 0 },
      latency: 0,
      tokens: { input: 0, output: 0 }
    }
  );

  if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
    // NOTE: this key hashes the whole input, whereas rateChunk looks examples up
    // by the hash of a single chunk; the two only coincide when the input fits
    // in one chunk.
    const inputStr = JSON.stringify(input);
    const key = fastHash(
      stringify({
        taskId,
        taskType,
        input: inputStr,
        instructions: stringify(instructions)
      })
    );
    await ctx.adapter.saveExample({
      key,
      taskType,
      taskId,
      input: inputStr,
      output: allRatings,
      instructions: isStringInstructions ? instructions : JSON.stringify(instructions),
      metadata: {
        cost: {
          input: totalMeta.cost.input,
          output: totalMeta.cost.output
        },
        latency: totalMeta.latency,
        model: ctx.modelId,
        tokens: {
          input: totalMeta.tokens.input,
          output: totalMeta.tokens.output
        }
      }
    });
  }
  return allRatings;
};
332
/**
 * Public entry point: starts a rating run for `input` and wraps it in a
 * `Response`.
 *
 * A fresh `ZaiContext` is built from this instance's client, model, task id
 * and adapter, and handed to `rate` together with the arguments. The third
 * `Response` argument reduces the full result: with string instructions each
 * item collapses to its `total`; with criteria-object instructions the rating
 * records are passed through unchanged.
 */
Zai.prototype.rate = function(input, instructions, _options) {
  const contextProps = {
    client: this.client,
    modelId: this.Model,
    taskId: this.taskId,
    taskType: "zai.rate",
    adapter: this.adapter
  };
  const context = new ZaiContext(contextProps);
  const pendingRatings = rate(input, instructions, _options, context);
  const simplify = (ratings) => {
    if (typeof instructions !== "string") {
      return ratings;
    }
    return ratings.map((entry) => entry.total);
  };
  return new Response(context, pendingRatings, simplify);
};