@botpress/zai 2.1.19 → 2.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CLAUDE.md +696 -0
- package/README.md +28 -2
- package/dist/index.d.ts +39 -18
- package/dist/index.js +1 -0
- package/dist/operations/errors.js +112 -8
- package/dist/operations/extract.js +20 -12
- package/dist/operations/filter.js +3 -1
- package/dist/operations/group.js +278 -0
- package/dist/operations/label.js +3 -1
- package/dist/operations/summarize.js +3 -1
- package/e2e/data/cache.jsonl +219 -0
- package/package.json +4 -3
- package/src/index.ts +1 -0
- package/src/operations/errors.ts +96 -1
- package/src/operations/extract.ts +21 -11
- package/src/operations/filter.ts +3 -1
- package/src/operations/group.ts +421 -0
- package/src/operations/label.ts +3 -1
- package/src/operations/summarize.ts +3 -2
- package/src/zai.ts +7 -9
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
import { z } from "@bpinternal/zui";
|
|
2
|
+
import { clamp } from "lodash-es";
|
|
3
|
+
import pLimit from "p-limit";
|
|
4
|
+
import { ZaiContext } from "../context";
|
|
5
|
+
import { Response } from "../response";
|
|
6
|
+
import { getTokenizer } from "../tokenizer";
|
|
7
|
+
import { stringify } from "../utils";
|
|
8
|
+
import { Zai } from "../zai";
|
|
9
|
+
import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from "./constants";
|
|
10
|
+
// Schema for a caller-supplied starting group: a non-empty id (max 100 chars),
// a non-empty label (max 250 chars) and an optional `elements` array that
// defaults to [].
const _InitialGroup = z.object({
|
|
11
|
+
id: z.string().min(1).max(100),
|
|
12
|
+
label: z.string().min(1).max(250),
|
|
13
|
+
// NOTE(review): accepted by the schema, but the `group` implementation in this
// file only reads `id` and `label` of each initial group - confirm intent.
elements: z.array(z.any()).optional().default([])
|
|
14
|
+
});
|
|
15
|
+
// Schema for the options accepted by `group`; every field is optional.
const _Options = z.object({
|
|
16
|
+
// Free-form grouping instructions interpolated into the system prompt.
instructions: z.string().optional(),
|
|
17
|
+
// Token budget each stringified element is truncated to (1..100000, default 250).
tokensPerElement: z.number().min(1).max(1e5).optional().default(250),
|
|
18
|
+
// 100..100000, default 16000.
// NOTE(review): not referenced by the `group` implementation in this file - confirm.
chunkLength: z.number().min(100).max(1e5).optional().default(16e3),
|
|
19
|
+
// Groups that exist before any element is processed (default: none).
initialGroups: z.array(_InitialGroup).optional().default([])
|
|
20
|
+
});
|
|
21
|
+
// Terminator appended to both prompts and passed to the model as a stop sequence.
const END = "\u25A0END\u25A0";
|
|
22
|
+
/**
 * Builds the lookup key used to decide whether two labels name the same group.
 *
 * The label is trimmed and lower-cased, then a leading "group", "new group" or
 * "new" marker is removed, whether it is followed by a "-" / ":" separator or
 * only by whitespace.
 *
 * @param {string} label - Raw label, as supplied by the caller or by the model.
 * @returns {string} The normalized key. It is empty only when `label` itself is
 *   blank.
 */
const normalizeLabel = (label) => {
  const base = label.trim().toLowerCase();
  const stripped = base
    .replace(/^(group|new group|new)\s*[-:]\s*/i, "")
    .replace(/^(group|new group|new)\s+/i, "")
    .trim();
  // A label made up of nothing but the prefix (e.g. "Group:" or "New -") would
  // otherwise collapse to "", making unrelated labels share one key. Keep the
  // unstripped form in that case so such labels stay distinct.
  return stripped === "" ? base : stripped;
};
|
|
25
|
+
/**
 * Groups the elements of `input` under labels produced by the model.
 *
 * Flow:
 *   1. Register `options.initialGroups`, then stringify and truncate every element.
 *   2. Split the elements into chunks bounded by a token budget and by
 *      MAX_ELEMENTS_PER_CHUNK, and ask the model for one label per element.
 *      Labels are matched to groups through `normalizeLabel`; unknown labels
 *      create a new group with a generated `group_<n>` id.
 *   3. Run a review pass over every element that is not a member of every
 *      group. This pass only attaches elements to groups that already exist.
 *   4. Resolve elements that ended up in several groups (first one recorded
 *      wins) and build the result.
 *
 * @param {Array<*>} input - Elements to group.
 * @param {object} [_options] - Raw options, validated with `_Options`.
 * @param {object} ctx - Operation context; this function uses
 *   `ctx.controller.signal`, `ctx.getModel()` and `ctx.generateContent()`.
 * @returns {Promise<Array<{id: string, label: string, elements: Array<*>}>>}
 *   The non-empty groups, initial groups first. Elements for which the model
 *   never returned a usable label are not part of any group.
 * @throws Whatever `throwIfAborted`, `_Options.parse` or the model calls throw.
 */
const group = async (input, _options, ctx) => {
  ctx.controller.signal.throwIfAborted();
  const options = _Options.parse(_options ?? {});
  // Nothing to group: return before loading the tokenizer or the model.
  if (input.length === 0) {
    return [];
  }
  const tokenizer = await getTokenizer();
  const model = await ctx.getModel();

  // groupId -> { id, label, normalizedLabel }
  const groups = new Map();
  // groupId -> Set of element indices
  const groupElements = new Map();
  // element index -> Set of groupIds, in the order they were recorded
  const elementGroups = new Map();
  // normalized label -> groupId
  const labelToGroupId = new Map();
  let groupIdCounter = 0;

  // Records a group in all three group-keyed maps, starting with no elements.
  const registerGroup = (id, label, normalizedLabel) => {
    groups.set(id, { id, label, normalizedLabel });
    groupElements.set(id, new Set());
    labelToGroupId.set(normalizedLabel, id);
  };

  // Returns the next generated id that is not already taken. Initial groups may
  // carry ids of the same `group_<n>` shape (e.g. a previous result passed back
  // in); reusing such an id would overwrite that group and drop its elements.
  const nextGroupId = () => {
    let candidate;
    do {
      candidate = `group_${groupIdCounter++}`;
    } while (groups.has(candidate));
    return candidate;
  };

  // Records that `elementIndex` belongs to `groupId`, in both directions.
  const assignElement = (elementIndex, groupId) => {
    groupElements.get(groupId).add(elementIndex);
    let memberships = elementGroups.get(elementIndex);
    if (!memberships) {
      memberships = new Set();
      elementGroups.set(elementIndex, memberships);
    }
    memberships.add(groupId);
  };

  for (const initialGroup of options.initialGroups) {
    registerGroup(initialGroup.id, initialGroup.label, normalizeLabel(initialGroup.label));
  }

  // Truncate and count each element once; both chunking passes and the prompt
  // builder reuse these values.
  const elements = input.map((element, index) => {
    const stringified = stringify(element, false);
    const truncated = tokenizer.truncate(stringified, options.tokensPerElement);
    return { element, index, stringified, truncated, tokens: tokenizer.count(truncated) };
  });

  // Token budget: instructions get at most 20% of the total, the remainder is
  // split 40% for the group list and 60% for the elements.
  const TOKENS_TOTAL_MAX = model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER;
  const TOKENS_INSTRUCTIONS_MAX = options.instructions
    ? clamp(tokenizer.count(options.instructions), 100, TOKENS_TOTAL_MAX * 0.2)
    : 0;
  const TOKENS_AVAILABLE = TOKENS_TOTAL_MAX - TOKENS_INSTRUCTIONS_MAX;
  const TOKENS_FOR_GROUPS_MAX = Math.floor(TOKENS_AVAILABLE * 0.4);
  const TOKENS_FOR_ELEMENTS_MAX = Math.floor(TOKENS_AVAILABLE * 0.6);
  const MAX_ELEMENTS_PER_CHUNK = 50;

  // Splits element indices into chunks that respect the element token budget
  // and the per-chunk element cap. A chunk always holds at least one element,
  // even if that element alone exceeds the budget.
  const chunkElementIndices = (indices) => {
    const chunks = [];
    let current = [];
    let currentTokens = 0;
    for (const idx of indices) {
      const { tokens } = elements[idx];
      const isFull = currentTokens + tokens > TOKENS_FOR_ELEMENTS_MAX || current.length >= MAX_ELEMENTS_PER_CHUNK;
      if (isFull && current.length > 0) {
        chunks.push(current);
        current = [];
        currentTokens = 0;
      }
      current.push(idx);
      currentTokens += tokens;
    }
    if (current.length > 0) {
      chunks.push(current);
    }
    return chunks;
  };

  // Splits the currently known group ids into chunks that fit the group token
  // budget. Returns a single empty chunk when there are no groups, so callers
  // still issue one model request.
  const getGroupChunks = () => {
    const chunks = [];
    let current = [];
    let currentTokens = 0;
    for (const [groupId, info] of groups) {
      // +10 tokens per group, presumably for list formatting overhead.
      const groupTokens = tokenizer.count(`${info.label}`) + 10;
      if (currentTokens + groupTokens > TOKENS_FOR_GROUPS_MAX && current.length > 0) {
        chunks.push(current);
        current = [];
        currentTokens = 0;
      }
      current.push(groupId);
      currentTokens += groupTokens;
    }
    if (current.length > 0) {
      chunks.push(current);
    }
    return chunks.length > 0 ? chunks : [[]];
  };

  // Asks the model to label the given elements, offering the given groups for
  // reuse. Resolves to a list of { elementIndex, label } assignments, where
  // `elementIndex` is the index into `elements` (not the position in the chunk).
  const processChunk = async (elementIndices, groupIds) => {
    const elementsText = elementIndices
      .map((idx, i) => `\u25A0${i}: ${elements[idx].truncated}\u25A0`)
      .join("\n");
    const groupsList = groupIds.map((gid) => groups.get(gid).label);
    const groupsText = groupsList.length > 0 ? `**Existing Groups (prefer reusing these):**
${groupsList.map((l) => `- ${l}`).join("\n")}

` : "";
    const systemPrompt = `You are grouping elements into cohesive groups.

${options.instructions ? `**Instructions:** ${options.instructions}
` : "**Instructions:** Group similar elements together."}

**Important:**
- Each element gets exactly ONE group label
- Use EXACT SAME label for similar items (case-sensitive)
- Create new descriptive labels when needed

**Output Format:**
One line per element:
\u25A00:Group Label\u25A0
\u25A01:Group Label\u25A0
${END}`.trim();
    const userPrompt = `${groupsText}**Elements (\u25A00 to \u25A0${elementIndices.length - 1}):**
${elementsText}

**Task:** For each element, output one line with its group label.
${END}`.trim();
    const { extracted } = await ctx.generateContent({
      systemPrompt,
      stopSequences: [END],
      messages: [{ type: "text", role: "user", content: userPrompt }],
      transform: (text) => {
        const assignments = [];
        // Fresh regex per call: a /g regex keeps state in `lastIndex`.
        const regex = /■(\d+):([^■]+)■/g;
        let match;
        while ((match = regex.exec(text)) !== null) {
          const position = Number.parseInt(match[1] ?? "", 10);
          // Ignore positions that do not refer to an element of this chunk.
          if (Number.isNaN(position) || position < 0 || position >= elementIndices.length) continue;
          const label = (match[2] ?? "").trim();
          if (!label) continue;
          assignments.push({
            elementIndex: elementIndices[position],
            // Same upper bound as the label length allowed for initial groups.
            label: label.slice(0, 250)
          });
        }
        return assignments;
      }
    });
    return extracted;
  };

  const elementLimit = pLimit(10);
  const groupLimit = pLimit(10);

  // Runs every element chunk against every group chunk and resolves to one
  // flat assignment list per element chunk. The group chunks are computed once
  // per pass: `groups` is only modified between passes, never during one.
  const runPass = (elementChunks) => {
    const groupChunks = getGroupChunks();
    return Promise.all(
      elementChunks.map(
        (elementChunk) => elementLimit(async () => {
          const perGroupChunk = await Promise.all(
            groupChunks.map((groupChunk) => groupLimit(() => processChunk(elementChunk, groupChunk)))
          );
          return perGroupChunk.flat();
        })
      )
    );
  };

  // Pass 1: label every element; unknown labels create new groups.
  const firstPassResults = await runPass(chunkElementIndices(elements.map((elem) => elem.index)));
  for (const assignments of firstPassResults) {
    for (const { elementIndex, label } of assignments) {
      const normalized = normalizeLabel(label);
      let groupId = labelToGroupId.get(normalized);
      if (!groupId) {
        groupId = nextGroupId();
        registerGroup(groupId, label, normalized);
      }
      assignElement(elementIndex, groupId);
    }
  }

  // Pass 2 (review): re-submit every element that is not a member of every
  // group, this time with the full group list available.
  // NOTE(review): an element that already has a group keeps its first recorded
  // group in the conflict resolution below, so this pass can only change the
  // outcome for elements the first pass left unassigned - confirm that
  // re-sending the already-assigned elements is intended.
  const allGroupIds = Array.from(groups.keys());
  if (allGroupIds.length > 0) {
    const elementsNeedingReview = elements
      .filter((elem) => {
        const memberships = elementGroups.get(elem.index);
        return allGroupIds.some((gid) => !memberships?.has(gid));
      })
      .map((elem) => elem.index);
    if (elementsNeedingReview.length > 0) {
      const reviewResults = await runPass(chunkElementIndices(elementsNeedingReview));
      for (const assignments of reviewResults) {
        for (const { elementIndex, label } of assignments) {
          // The review pass never creates groups: unknown labels are ignored.
          const groupId = labelToGroupId.get(normalizeLabel(label));
          if (!groupId) continue;
          assignElement(elementIndex, groupId);
        }
      }
    }
  }

  // Each element must end up in exactly one group: keep the first group that
  // was recorded for it and remove it from all the others.
  for (const [elementIndex, memberships] of elementGroups.entries()) {
    if (memberships.size <= 1) continue;
    const [finalGroupId] = memberships;
    for (const gid of memberships) {
      groupElements.get(gid)?.delete(elementIndex);
    }
    groupElements.get(finalGroupId).add(elementIndex);
  }

  // Emit only the groups that actually received elements.
  const result = [];
  for (const [groupId, memberIndices] of groupElements.entries()) {
    if (memberIndices.size === 0) continue;
    const groupInfo = groups.get(groupId);
    result.push({
      id: groupInfo.id,
      label: groupInfo.label,
      elements: Array.from(memberIndices, (idx) => elements[idx].element)
    });
  }
  return result;
};
|
|
260
|
+
/**
 * Public entry point: groups `input` and wraps the pending result in a `Response`.
 *
 * A fresh `ZaiContext` tagged with task type "zai.group" is created from this
 * instance's client, model id, task id and adapter, and the `group` operation
 * is started with it.
 *
 * @param {Array<*>} input - Elements to group.
 * @param {object} [_options] - Options forwarded unchanged to `group`.
 * @returns {Response} Response built from the context, the pending `group`
 *   promise and a callback that converts the group list into a plain object
 *   mapping each label to its elements (groups sharing a label are merged).
 */
Zai.prototype.group = function(input, _options) {
  const context = new ZaiContext({
    client: this.client,
    modelId: this.Model,
    taskId: this.taskId,
    taskType: "zai.group",
    adapter: this.adapter
  });
  return new Response(context, group(input, _options, context), (result) => {
    // Labels come from the model, so they can be any string. Accumulate in a
    // Map rather than a plain object: on `{}` a label such as "constructor" or
    // "__proto__" would resolve to an inherited Object.prototype member and
    // make the `.push` call throw.
    const elementsByLabel = new Map();
    for (const { label, elements } of result) {
      let bucket = elementsByLabel.get(label);
      if (!bucket) {
        bucket = [];
        elementsByLabel.set(label, bucket);
      }
      for (const element of elements) {
        bucket.push(element);
      }
    }
    // Object.fromEntries defines own data properties, so every label - even
    // "__proto__" - becomes a regular key of the returned object.
    return Object.fromEntries(elementsByLabel);
  });
};
|
package/dist/operations/label.js
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import { z } from "@bpinternal/zui";
|
|
2
2
|
import { chunk, clamp } from "lodash-es";
|
|
3
|
+
import pLimit from "p-limit";
|
|
3
4
|
import { ZaiContext } from "../context";
|
|
4
5
|
import { Response } from "../response";
|
|
5
6
|
import { getTokenizer } from "../tokenizer";
|
|
@@ -87,9 +88,10 @@ const label = async (input, _labels, _options, ctx) => {
|
|
|
87
88
|
);
|
|
88
89
|
const inputAsString = stringify(input);
|
|
89
90
|
if (tokenizer.count(inputAsString) > CHUNK_INPUT_MAX_TOKENS) {
|
|
91
|
+
const limit = pLimit(10);
|
|
90
92
|
const tokens = tokenizer.split(inputAsString);
|
|
91
93
|
const chunks = chunk(tokens, CHUNK_INPUT_MAX_TOKENS).map((x) => x.join(""));
|
|
92
|
-
const allLabels = await Promise.all(chunks.map((chunk2) => label(chunk2, _labels, _options, ctx)));
|
|
94
|
+
const allLabels = await Promise.all(chunks.map((chunk2) => limit(() => label(chunk2, _labels, _options, ctx))));
|
|
93
95
|
return allLabels.reduce((acc, x) => {
|
|
94
96
|
Object.keys(x).forEach((key) => {
|
|
95
97
|
if (acc[key]?.value === true) {
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import { z } from "@bpinternal/zui";
|
|
2
2
|
import { chunk } from "lodash-es";
|
|
3
|
+
import pLimit from "p-limit";
|
|
3
4
|
import { ZaiContext } from "../context";
|
|
4
5
|
import { Response } from "../response";
|
|
5
6
|
import { getTokenizer } from "../tokenizer";
|
|
@@ -54,8 +55,9 @@ ${newText}
|
|
|
54
55
|
const useMergeSort = parts >= Math.pow(2, N);
|
|
55
56
|
const chunkSize = Math.ceil(tokens.length / (parts * N));
|
|
56
57
|
if (useMergeSort) {
|
|
58
|
+
const limit = pLimit(10);
|
|
57
59
|
const chunks = chunk(tokens, chunkSize).map((x) => x.join(""));
|
|
58
|
-
const allSummaries = (await Promise.allSettled(chunks.map((chunk2) => summarize(chunk2, options, ctx)))).filter((x) => x.status === "fulfilled").map((x) => x.value);
|
|
60
|
+
const allSummaries = (await Promise.allSettled(chunks.map((chunk2) => limit(() => summarize(chunk2, options, ctx))))).filter((x) => x.status === "fulfilled").map((x) => x.value);
|
|
59
61
|
return summarize(allSummaries.join("\n\n============\n\n"), options, ctx);
|
|
60
62
|
}
|
|
61
63
|
const summaries = [];
|