@botpress/zai 2.1.20 → 2.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,7 +1,7 @@
1
1
  {
2
2
  "name": "@botpress/zai",
3
3
  "description": "Zui AI (zai) – An LLM utility library written on top of Zui and the Botpress API",
4
- "version": "2.1.20",
4
+ "version": "2.3.0",
5
5
  "main": "./dist/index.js",
6
6
  "types": "./dist/index.d.ts",
7
7
  "exports": {
package/src/index.ts CHANGED
@@ -7,5 +7,8 @@ import './operations/check'
7
7
  import './operations/filter'
8
8
  import './operations/extract'
9
9
  import './operations/label'
10
+ import './operations/group'
11
+ import './operations/rate'
12
+ import './operations/sort'
10
13
 
11
14
  export { Zai }
@@ -0,0 +1,543 @@
1
+ // eslint-disable consistent-type-definitions
2
+ import { z } from '@bpinternal/zui'
3
+ import { clamp } from 'lodash-es'
4
+ import pLimit from 'p-limit'
5
+ import { ZaiContext } from '../context'
6
+ import { Response } from '../response'
7
+ import { getTokenizer } from '../tokenizer'
8
+ import { fastHash, stringify } from '../utils'
9
+ import { Zai } from '../zai'
10
+ import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from './constants'
11
+
12
/**
 * One cluster returned by the grouping operation.
 *
 * @typeParam T - Type of the elements that were grouped.
 */
export type Group<T> = {
  /** Group identifier: either the id of a caller-supplied initial group or a generated `group_<n>` id. */
  id: string
  /** Human-readable label of the group. */
  label: string
  /** Input elements that were assigned to this group. */
  elements: Array<T>
}
17
+
18
/**
 * A group the caller can provide up front so the model prefers reusing it.
 */
type InitialGroup = {
  /** Identifier to keep for this group in the result. */
  id: string
  /** Label shown to the model in the list of existing groups. */
  label: string
  /**
   * Optional pre-existing members.
   * NOTE(review): the grouping routine visible in this file only reads `id` and `label`.
   */
  elements?: Array<unknown>
}
23
+
24
/**
 * Runtime schema for a caller-supplied initial group:
 * `id` is 1-100 characters, `label` is 1-250 characters and `elements` defaults to an empty array.
 */
const _InitialGroup = z.object({
  id: z
    .string()
    .min(1)
    .max(100),
  label: z
    .string()
    .min(1)
    .max(250),
  elements: z
    .array(z.any())
    .optional()
    .default([]),
})
29
+
30
/**
 * User-facing options of the grouping operation.
 */
export type Options = {
  /** Free-form guidance injected into the system prompt; a generic instruction is used when omitted. */
  instructions?: string
  /** Maximum number of tokens kept per element before it is shown to the model (schema default: 250). */
  tokensPerElement?: number
  /**
   * Validated by the schema (100 to 100,000, default 16,000).
   * NOTE(review): not referenced by the grouping routine visible in this file.
   */
  chunkLength?: number
  /** Groups that already exist and should be reused when they fit. */
  initialGroups?: InitialGroup[]
}
36
+
37
/**
 * Runtime schema for `Options`; fills in defaults for every field except `instructions`.
 */
const _Options = z.object({
  instructions: z.string().optional(),
  tokensPerElement: z
    .number()
    .min(1)
    .max(100_000)
    .optional()
    .default(250),
  chunkLength: z
    .number()
    .min(100)
    .max(100_000)
    .optional()
    .default(16_000),
  initialGroups: z
    .array(_InitialGroup)
    .optional()
    .default([]),
})
43
+
44
declare module '@botpress/zai' {
  interface Zai {
    /**
     * Groups `input` into labelled clusters.
     *
     * The full result is an array of groups; the simplified result is a record keyed by group label.
     *
     * @param input - Elements to group.
     * @param options - Optional instructions, token limits and initial groups.
     */
    group<T>(input: T[], options?: Options): Response<Group<T>[], Record<string, Array<T>>>
  }
}
49
+
50
/**
 * Sentinel the model is told to emit after its last assignment line;
 * it is also passed to the generation call as a stop sequence.
 */
const END = '■END■'
51
+
52
+ // Simplified data structures
53
/**
 * Internal bookkeeping record for a known group.
 */
type GroupInfo = {
  /** Group identifier (caller-supplied or generated). */
  id: string
  /** Label exactly as first seen, used for display and in the final result. */
  label: string
  /** Output of `normalizeLabel(label)`, used to detect that two labels mean the same group. */
  normalizedLabel: string
}
58
+
59
/**
 * Reduces a group label to the canonical key used to decide whether two labels denote the same group.
 *
 * The label is trimmed and lower-cased, then a leading "group", "new group" or "new" prefix is removed,
 * first when followed by `-` or `:` and then when followed by whitespace.
 *
 * @param label - Raw label, typically produced by the model or supplied as an initial group.
 * @returns The normalized key. If stripping the prefix would leave nothing (e.g. `"Group:"`), the
 *   trimmed, lower-cased label is returned instead so that unrelated prefix-only labels do not all
 *   collapse onto the empty key.
 */
const normalizeLabel = (label: string): string => {
  const lowered = label.trim().toLowerCase()

  const stripped = lowered
    .replace(/^(group|new group|new)\s*[-:]\s*/i, '')
    .replace(/^(group|new group|new)\s+/i, '')
    .trim()

  // Never hand back an empty key for a non-empty label.
  return stripped || lowered
}
67
+
68
/**
 * Clusters `input` into labelled groups by asking the model to label every element.
 *
 * Pipeline:
 * 1. Each element is stringified, truncated to `tokensPerElement` tokens and packed into chunks
 *    (at most 50 elements and 60% of the available prompt tokens per chunk).
 * 2. Phase 1: every element chunk is sent to the model once per chunk of known group labels. Returned
 *    labels are normalized and either matched to a known group or registered as a new group.
 * 3. Phase 2: every element that is not assigned to all known groups is sent again against the final
 *    group list; only labels that resolve to an already-known group are kept.
 * 4. Phase 3: an element that ended up in several groups is kept only in the first group it was
 *    assigned to.
 *
 * When both `ctx.taskId` and `ctx.adapter` are set, stored examples are used as few-shot messages and the
 * final per-element assignments are saved back through the adapter.
 *
 * @param input - Elements to group; an empty array yields an empty result.
 * @param _options - Raw user options, validated with `_Options`.
 * @param ctx - Execution context providing the model, the abort signal, the adapter and `generateContent`.
 * @returns The non-empty groups; each input element appears in at most one of them.
 * @throws If the context is already aborted when called, or if `_options` fails validation.
 */
const group = async <T>(input: Array<T>, _options: Options | undefined, ctx: ZaiContext): Promise<Array<Group<T>>> => {
  ctx.controller.signal.throwIfAborted()

  const options = _Options.parse(_options ?? {})
  const tokenizer = await getTokenizer()
  const model = await ctx.getModel()

  const taskId = ctx.taskId
  const taskType = 'zai.group'

  if (input.length === 0) {
    return []
  }

  /** One model decision: the element at `elementIndex` (index into `input`) belongs to `label`. */
  type Assignment = { elementIndex: number; label: string }

  const isAssignment = (value: unknown): value is Assignment =>
    typeof value === 'object' &&
    value !== null &&
    typeof (value as Assignment).elementIndex === 'number' &&
    typeof (value as Assignment).label === 'string'

  // Bookkeeping
  const groups = new Map<string, GroupInfo>() // groupId -> GroupInfo
  const groupElements = new Map<string, Set<number>>() // groupId -> element indices
  const elementGroups = new Map<number, Set<string>>() // element index -> groupIds it was assigned to
  const labelToGroupId = new Map<string, string>() // normalized label -> groupId
  let groupIdCounter = 0

  // Seed with the caller-provided groups
  options.initialGroups.forEach((ig) => {
    const normalized = normalizeLabel(ig.label)
    groups.set(ig.id, { id: ig.id, label: ig.label, normalizedLabel: normalized })
    groupElements.set(ig.id, new Set())
    labelToGroupId.set(normalized, ig.id)
  })

  // Truncate and measure every element once; both chunking passes and the prompt builder reuse this.
  const elements = input.map((element, index) => {
    const truncated = tokenizer.truncate(stringify(element, false), options.tokensPerElement)
    return { element, index, truncated, tokens: tokenizer.count(truncated) }
  })

  // Token budget
  const TOKENS_TOTAL_MAX = model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER
  const TOKENS_INSTRUCTIONS_MAX = options.instructions
    ? clamp(tokenizer.count(options.instructions), 100, TOKENS_TOTAL_MAX * 0.2)
    : 0
  const TOKENS_AVAILABLE = TOKENS_TOTAL_MAX - TOKENS_INSTRUCTIONS_MAX
  const TOKENS_FOR_GROUPS_MAX = Math.floor(TOKENS_AVAILABLE * 0.4)
  const TOKENS_FOR_ELEMENTS_MAX = Math.floor(TOKENS_AVAILABLE * 0.6)
  const MAX_ELEMENTS_PER_CHUNK = 50

  // Packs element indices into chunks bounded by the element token budget and the per-chunk element cap.
  const chunkElementIndices = (indices: number[]): number[][] => {
    const chunks: number[][] = []
    let current: number[] = []
    let currentTokens = 0

    for (const idx of indices) {
      const tokens = elements[idx].tokens
      const isFull = currentTokens + tokens > TOKENS_FOR_ELEMENTS_MAX || current.length >= MAX_ELEMENTS_PER_CHUNK

      // An oversized element still gets a chunk of its own rather than being dropped.
      if (isFull && current.length > 0) {
        chunks.push(current)
        current = []
        currentTokens = 0
      }

      current.push(idx)
      currentTokens += tokens
    }

    if (current.length > 0) {
      chunks.push(current)
    }

    return chunks
  }

  // Packs the currently known group ids into chunks bounded by the group token budget.
  // Always returns at least one (possibly empty) chunk so every element chunk is processed once.
  const getGroupChunks = (): string[][] => {
    const chunks: string[][] = []
    let current: string[] = []
    let currentTokens = 0

    for (const info of groups.values()) {
      const groupTokens = tokenizer.count(`${info.label}`) + 10

      if (currentTokens + groupTokens > TOKENS_FOR_GROUPS_MAX && current.length > 0) {
        chunks.push(current)
        current = []
        currentTokens = 0
      }

      current.push(info.id)
      currentTokens += groupTokens
    }

    if (current.length > 0) {
      chunks.push(current)
    }

    return chunks.length > 0 ? chunks : [[]]
  }

  // Asks the model to label the given elements, offering the given groups as reusable labels.
  const processChunk = async (elementIndices: number[], groupIds: string[]): Promise<Assignment[]> => {
    const chunkElements = elementIndices.map((idx) => elements[idx].element)
    const chunkInputStr = JSON.stringify(chunkElements)

    // Fetch stored examples for few-shot prompting
    const examples =
      taskId && ctx.adapter
        ? await ctx.adapter.getExamples<string, Assignment[]>({
            input: chunkInputStr.slice(0, 1000), // Limit search string length
            taskType,
            taskId,
          })
        : []

    // Exact match (cache hit)
    const key = fastHash(
      stringify({
        taskId,
        taskType,
        input: chunkInputStr,
        instructions: options.instructions ?? '',
        groupIds: groupIds.join(','),
      })
    )

    const exactMatch = examples.find((x) => x.key === key)
    if (exactMatch && exactMatch.output) {
      return exactMatch.output
    }

    const elementsText = elementIndices.map((idx, i) => `■${i}: ${elements[idx].truncated}■`).join('\n')

    const groupsList = groupIds.map((gid) => groups.get(gid)!.label)
    const groupsText =
      groupsList.length > 0
        ? `**Existing Groups (prefer reusing these):**\n${groupsList.map((l) => `- ${l}`).join('\n')}\n\n`
        : ''

    // Format examples for few-shot learning
    const exampleMessages: Array<{ type: 'text'; role: 'user' | 'assistant'; content: string }> = []

    for (const example of examples.slice(0, 5)) {
      try {
        // Only use examples whose stored output really is a list of assignments; anything else
        // would render as "■undefined:...■" or leave a user message without an assistant reply.
        const exampleOutput: unknown = example.output
        const exampleAssignments = Array.isArray(exampleOutput) ? exampleOutput.filter(isAssignment) : []
        if (exampleAssignments.length === 0) {
          continue
        }

        const exampleInput = JSON.parse(example.input)
        const exampleElements: unknown[] = Array.isArray(exampleInput) ? exampleInput : [exampleInput]

        // User message
        const exampleElementsText = exampleElements
          .map((el, i) => `■${i}: ${stringify(el, false).slice(0, 200)}■`)
          .join('\n')

        exampleMessages.push({
          type: 'text',
          role: 'user',
          content: `Expert Example - Elements to group:
${exampleElementsText}

Group each element.`,
        })

        // Assistant message
        const formattedAssignments = exampleAssignments
          .map((assignment) => `■${assignment.elementIndex}:${assignment.label}■`)
          .join('\n')

        exampleMessages.push({
          type: 'text',
          role: 'assistant',
          content: `${formattedAssignments}\n${END}`,
        })

        if (example.explanation) {
          exampleMessages.push({
            type: 'text',
            role: 'assistant',
            content: `Reasoning: ${example.explanation}`,
          })
        }
      } catch {
        // Skip malformed examples (e.g. stored input that is not valid JSON)
      }
    }

    const systemPrompt = `You are grouping elements into cohesive groups.

${options.instructions ? `**Instructions:** ${options.instructions}\n` : '**Instructions:** Group similar elements together.'}

**Important:**
- Each element gets exactly ONE group label
- Use EXACT SAME label for similar items (case-sensitive)
- Create new descriptive labels when needed

**Output Format:**
One line per element:
■0:Group Label■
■1:Group Label■
${END}`.trim()

    const userPrompt = `${groupsText}**Elements (■0 to ■${elementIndices.length - 1}):**
${elementsText}

**Task:** For each element, output one line with its group label.
${END}`.trim()

    const { extracted } = await ctx.generateContent({
      systemPrompt,
      stopSequences: [END],
      messages: [...exampleMessages, { type: 'text', role: 'user', content: userPrompt }],
      transform: (text) => {
        const assignments: Assignment[] = []
        const regex = /■(\d+):([^■]+)■/g
        let match: RegExpExecArray | null

        while ((match = regex.exec(text)) !== null) {
          // The model answers with chunk-local positions; ignore anything outside the chunk.
          const idx = parseInt(match[1] ?? '', 10)
          if (isNaN(idx) || idx < 0 || idx >= elementIndices.length) continue

          const label = (match[2] ?? '').trim()
          if (!label) continue

          assignments.push({
            elementIndex: elementIndices[idx], // map back to the index in `input`
            label: label.slice(0, 250),
          })
        }

        return assignments
      },
    })

    return extracted
  }

  const elementLimit = pLimit(10) // Limits concurrently processed element chunks
  const groupLimit = pLimit(10) // Limits concurrent model calls across group chunks

  // Labels every element chunk against every group chunk in parallel; one result array per element chunk.
  const labelChunks = (chunks: number[][]): Promise<Assignment[][]> =>
    Promise.all(
      chunks.map((chunk) =>
        elementLimit(async () => {
          const perGroupChunk = await Promise.all(
            getGroupChunks().map((groupChunk) => groupLimit(() => processChunk(chunk, groupChunk)))
          )
          return perGroupChunk.flat()
        })
      )
    )

  // Records that an element belongs to a known group, in both directions.
  const recordAssignment = (elementIndex: number, groupId: string) => {
    groupElements.get(groupId)!.add(elementIndex)

    let assigned = elementGroups.get(elementIndex)
    if (!assigned) {
      assigned = new Set<string>()
      elementGroups.set(elementIndex, assigned)
    }
    assigned.add(groupId)
  }

  // Phase 1: label all element chunks against the current groups in parallel, then apply the
  // results sequentially so group creation is deterministic and free of races.
  const elementChunks = chunkElementIndices(elements.map((elem) => elem.index))

  for (const assignments of await labelChunks(elementChunks)) {
    for (const { elementIndex, label } of assignments) {
      const normalized = normalizeLabel(label)
      let groupId = labelToGroupId.get(normalized)

      if (!groupId) {
        // Create a new group; skip generated ids already taken by a caller-supplied initial group.
        let candidate = `group_${groupIdCounter++}`
        while (groups.has(candidate)) {
          candidate = `group_${groupIdCounter++}`
        }
        groupId = candidate

        groups.set(groupId, { id: groupId, label, normalizedLabel: normalized })
        groupElements.set(groupId, new Set())
        labelToGroupId.set(normalized, groupId)
      }

      recordAssignment(elementIndex, groupId)
    }
  }

  // Phase 2: re-label elements that are not assigned to every known group, now that the full
  // group list is known. Labels that do not resolve to a known group are ignored here.
  const allGroupIds = Array.from(groups.keys())

  if (allGroupIds.length > 0) {
    const elementsNeedingReview = elements
      .filter((elem) => {
        const assigned = elementGroups.get(elem.index)
        return allGroupIds.some((gid) => !assigned?.has(gid))
      })
      .map((elem) => elem.index)

    if (elementsNeedingReview.length > 0) {
      const reviewResults = await labelChunks(chunkElementIndices(elementsNeedingReview))

      for (const assignments of reviewResults) {
        for (const { elementIndex, label } of assignments) {
          const groupId = labelToGroupId.get(normalizeLabel(label))
          if (groupId) {
            recordAssignment(elementIndex, groupId)
          }
        }
      }
    }
  }

  // Phase 3: resolve conflicts. An element assigned to several groups is kept only in the first
  // group it was assigned to (sets iterate in insertion order).
  for (const [elementIndex, groupSet] of elementGroups.entries()) {
    if (groupSet.size > 1) {
      const groupIds = Array.from(groupSet)

      for (const gid of groupIds) {
        groupElements.get(gid)?.delete(elementIndex)
      }

      const finalGroupId = groupIds[0]
      groupElements.get(finalGroupId)!.add(elementIndex)
    }
  }

  // Build final result (empty groups are omitted)
  const result: Array<Group<T>> = []

  for (const [groupId, elementIndices] of groupElements.entries()) {
    if (elementIndices.size > 0) {
      const groupInfo = groups.get(groupId)!
      result.push({
        id: groupInfo.id,
        label: groupInfo.label,
        elements: Array.from(elementIndices).map((idx) => elements[idx].element),
      })
    }
  }

  // Save example for active learning
  if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
    const serializedInput = JSON.stringify(input)

    const key = fastHash(
      stringify({
        taskId,
        taskType,
        input: serializedInput,
        instructions: options.instructions ?? '',
      })
    )

    // Store the output in the same per-element assignment shape that `processChunk` reads back
    // when it formats examples, ordered by element index to mirror the expected model output.
    const outputAssignments: Assignment[] = []
    for (const [groupId, elementIndices] of groupElements.entries()) {
      const groupInfo = groups.get(groupId)!
      for (const idx of elementIndices) {
        outputAssignments.push({
          elementIndex: idx,
          label: groupInfo.label,
        })
      }
    }
    outputAssignments.sort((a, b) => a.elementIndex - b.elementIndex)

    // Note: usage metadata is spread across many parallel model calls and is not aggregated here,
    // so placeholder values are stored.
    await ctx.adapter.saveExample({
      key,
      taskType,
      taskId,
      input: serializedInput,
      output: outputAssignments,
      instructions: options.instructions ?? '',
      metadata: {
        cost: { input: 0, output: 0 },
        latency: 0,
        model: ctx.modelId,
        tokens: { input: 0, output: 0 },
      },
    })
  }

  return result
}
519
+
520
/**
 * Public entry point of the grouping operation.
 *
 * Creates a fresh `ZaiContext` for the call and wraps the asynchronous grouping in a `Response`.
 * The full result is the array of groups; the simplified result is a record that maps each group
 * label to its elements, with groups that share a label merged in order.
 *
 * @param input - Elements to group.
 * @param _options - Optional grouping options.
 * @returns A `Response` resolving to the groups, or to the label-keyed record in its simplified form.
 */
Zai.prototype.group = function <T>(
  this: Zai,
  input: Array<T>,
  _options?: Options
): Response<Array<Group<T>>, Record<string, T[]>> {
  const context = new ZaiContext({
    client: this.client,
    modelId: this.Model,
    taskId: this.taskId,
    taskType: 'zai.group',
    adapter: this.adapter,
  })

  return new Response<Array<Group<T>>, Record<string, T[]>>(context, group(input, _options, context), (result) => {
    // Accumulate in a Map: labels come from the model and may collide with Object.prototype
    // members ("constructor", "toString", "__proto__"), which would break a plain-object lookup.
    const byLabel = new Map<string, T[]>()

    for (const { label, elements } of result) {
      let bucket = byLabel.get(label)
      if (!bucket) {
        bucket = []
        byLabel.set(label, bucket)
      }
      for (const element of elements) {
        bucket.push(element)
      }
    }

    // Object.fromEntries defines own data properties, so even "__proto__" becomes a regular key.
    return Object.fromEntries(byLabel)
  })
}