@botpress/zai 2.1.19 → 2.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,7 +1,7 @@
1
1
  {
2
2
  "name": "@botpress/zai",
3
3
  "description": "Zui AI (zai) – An LLM utility library written on top of Zui and the Botpress API",
4
- "version": "2.1.19",
4
+ "version": "2.2.0",
5
5
  "main": "./dist/index.js",
6
6
  "types": "./dist/index.d.ts",
7
7
  "exports": {
@@ -35,7 +35,8 @@
35
35
  "@botpress/cognitive": "0.1.50",
36
36
  "json5": "^2.2.3",
37
37
  "jsonrepair": "^3.10.0",
38
- "lodash-es": "^4.17.21"
38
+ "lodash-es": "^4.17.21",
39
+ "p-limit": "^7.2.0"
39
40
  },
40
41
  "devDependencies": {
41
42
  "@botpress/client": "workspace:^",
@@ -53,7 +54,7 @@
53
54
  },
54
55
  "peerDependencies": {
55
56
  "@bpinternal/thicktoken": "^1.0.0",
56
- "@bpinternal/zui": "1.2.1"
57
+ "@bpinternal/zui": "^1.2.2"
57
58
  },
58
59
  "engines": {
59
60
  "node": ">=18.0.0"
package/src/index.ts CHANGED
@@ -7,5 +7,6 @@ import './operations/check'
7
7
  import './operations/filter'
8
8
  import './operations/extract'
9
9
  import './operations/label'
10
+ import './operations/group'
10
11
 
11
12
  export { Zai }
@@ -1,9 +1,104 @@
1
+ import { ZodError } from '@bpinternal/zui'
2
+
1
3
/**
 * Error describing a value that failed to parse as JSON or — when `error` is a `ZodError` — failed
 * schema validation. The message contains the offending JSON followed by a human-readable list of
 * problems, one numbered entry per Zod issue.
 */
export class JsonParsingError extends Error {
  /**
   * @param json - the offending input; strings are shown verbatim, other values as pretty-printed JSON
   * @param error - the underlying parse error, or a `ZodError` whose issues are listed individually
   */
  public constructor(
    public json: unknown,
    public error: Error
  ) {
    super(JsonParsingError._formatError(json, error))
  }

  /** Builds the full error message: a `---JSON---` section, then validation or parse error details. */
  private static _formatError(json: unknown, error: Error): string {
    let errorMessage = 'Error parsing JSON:\n\n'
    errorMessage += `---JSON---\n${JsonParsingError._stringifyJson(json)}\n\n`

    if (error instanceof ZodError) {
      errorMessage += '---Validation Errors---\n\n'
      errorMessage += JsonParsingError._formatZodError(error)
    } else {
      errorMessage += '---Error---\n\n'
      errorMessage += 'The JSON provided is not valid JSON.\n'
      errorMessage += `Details: ${error.message}\n`
    }

    return errorMessage
  }

  /**
   * Renders `json` for the message. Strings (raw model output) are kept verbatim. Other values are
   * pretty-printed, because template interpolation would turn objects into "[object Object]".
   */
  private static _stringifyJson(json: unknown): string {
    if (typeof json === 'string') {
      return json
    }
    try {
      // JSON.stringify returns undefined for undefined/functions; fall back to String() like `${json}` did.
      return JSON.stringify(json, null, 2) ?? String(json)
    } catch {
      // Circular structures and BigInt values cannot be serialized.
      return String(json)
    }
  }

  /** Lists every Zod issue as `N. Field: "path"` followed by its problem description, blank-line separated. */
  private static _formatZodError(zodError: ZodError): string {
    const issues = zodError.issues
    if (issues.length === 0) {
      return 'Unknown validation error\n'
    }

    return issues
      .map((issue, i) => {
        const path = issue.path.length > 0 ? issue.path.join('.') : 'root'
        return `${i + 1}. Field: "${path}"\n${JsonParsingError._describeIssue(issue)}`
      })
      .join('\n')
  }

  /** Problem/Message lines for a single issue; every line ends with "\n". */
  private static _describeIssue(issue: ZodError['issues'][number]): string {
    const message = ` Message: ${issue.message}\n`

    switch (issue.code) {
      case 'invalid_type':
        return ` Problem: Expected ${issue.expected}, but received ${issue.received}\n` + message

      case 'invalid_string': {
        const name = JsonParsingError._validationName(issue.validation)
        return (name ? ` Problem: Invalid ${name} format\n` : '') + message
      }

      case 'too_small': {
        let problem = ''
        if (issue.type === 'string') {
          problem = ` Problem: String must be ${issue.exact ? 'exactly' : 'at least'} ${issue.minimum} characters\n`
        } else if (issue.type === 'number') {
          problem = ` Problem: Number must be ${issue.inclusive ? 'at least' : 'greater than'} ${issue.minimum}\n`
        } else if (issue.type === 'array') {
          // `.length(n)` reports exact=true; "at least n" would be misleading there.
          const bound = issue.exact ? 'exactly' : issue.inclusive ? 'at least' : 'more than'
          problem = ` Problem: Array must contain ${bound} ${issue.minimum} items\n`
        }
        return problem + message
      }

      case 'too_big': {
        let problem = ''
        if (issue.type === 'string') {
          problem = ` Problem: String must be ${issue.exact ? 'exactly' : 'at most'} ${issue.maximum} characters\n`
        } else if (issue.type === 'number') {
          problem = ` Problem: Number must be ${issue.inclusive ? 'at most' : 'less than'} ${issue.maximum}\n`
        } else if (issue.type === 'array') {
          const bound = issue.exact ? 'exactly' : issue.inclusive ? 'at most' : 'fewer than'
          problem = ` Problem: Array must contain ${bound} ${issue.maximum} items\n`
        }
        return problem + message
      }

      case 'invalid_enum_value':
        return (
          ` Problem: Invalid value "${issue.received}"\n` +
          ` Allowed values: ${issue.options.map((option) => `"${String(option)}"`).join(', ')}\n` +
          message
        )

      case 'invalid_literal':
        return ` Problem: Expected the literal value "${issue.expected}", but received "${issue.received}"\n` + message

      case 'invalid_union':
        return " Problem: Value doesn't match any of the expected formats\n" + message

      default:
        return ` Problem: ${issue.message}\n`
    }
  }

  /**
   * Name of a string-validation rule. Zod reports either a name ("email", "url", …) or an object keyed
   * by the rule (e.g. `{ startsWith: "x" }`); interpolating the object directly would print "[object Object]".
   */
  private static _validationName(validation: unknown): string | undefined {
    if (typeof validation === 'string') {
      return validation
    }
    if (validation !== null && typeof validation === 'object') {
      return Object.keys(validation)[0]
    }
    return undefined
  }
}
@@ -1,10 +1,11 @@
1
1
  // eslint-disable consistent-type-definitions
2
- import { z, ZodObject } from '@bpinternal/zui'
2
+ import { z, ZodObject, transforms } from '@bpinternal/zui'
3
3
 
4
4
  import JSON5 from 'json5'
5
5
  import { jsonrepair } from 'jsonrepair'
6
6
 
7
7
  import { chunk, isArray } from 'lodash-es'
8
+ import pLimit from 'p-limit'
8
9
  import { ZaiContext } from '../context'
9
10
  import { Response } from '../response'
10
11
  import { getTokenizer } from '../tokenizer'
@@ -48,6 +49,7 @@ declare module '@botpress/zai' {
48
49
  const START = '■json_start■'
49
50
  const END = '■json_end■'
50
51
  const NO_MORE = '■NO_MORE_ELEMENT■'
52
+ const ZERO_ELEMENTS = '■ZERO_ELEMENTS■'
51
53
 
52
54
  const extract = async <S extends OfType<AnyObjectOrArray>>(
53
55
  input: unknown,
@@ -56,7 +58,9 @@ const extract = async <S extends OfType<AnyObjectOrArray>>(
56
58
  ctx: ZaiContext
57
59
  ): Promise<S['_output']> => {
58
60
  ctx.controller.signal.throwIfAborted()
59
- let schema = _schema as any as z.ZodType
61
+
62
+ let schema = transforms.fromJSONSchema(transforms.toJSONSchema(_schema as any as z.ZodType))
63
+
60
64
  const options = Options.parse(_options ?? {})
61
65
  const tokenizer = await getTokenizer()
62
66
  const model = await ctx.getModel()
@@ -110,18 +114,21 @@ const extract = async <S extends OfType<AnyObjectOrArray>>(
110
114
  const inputAsString = stringify(input)
111
115
 
112
116
  if (tokenizer.count(inputAsString) > options.chunkLength) {
117
+ const limit = pLimit(10) // Limit to 10 concurrent extraction operations
113
118
  const tokens = tokenizer.split(inputAsString)
114
119
  const chunks = chunk(tokens, options.chunkLength).map((x) => x.join(''))
115
120
  const all = await Promise.allSettled(
116
121
  chunks.map((chunk) =>
117
- extract(
118
- chunk,
119
- originalSchema,
120
- {
121
- ...options,
122
- strict: false, // We don't want to fail on strict mode for sub-chunks
123
- },
124
- ctx
122
+ limit(() =>
123
+ extract(
124
+ chunk,
125
+ originalSchema,
126
+ {
127
+ ...options,
128
+ strict: false, // We don't want to fail on strict mode for sub-chunks
129
+ },
130
+ ctx
131
+ )
125
132
  )
126
133
  )
127
134
  ).then((results) =>
@@ -162,8 +169,11 @@ Merge it back into a final result.`.trim(),
162
169
  instructions.push('You may have multiple elements, or zero elements in the input.')
163
170
  instructions.push('You must extract each element separately.')
164
171
  instructions.push(`Each element must be a JSON object with exactly the format: ${START}${shape}${END}`)
172
+ instructions.push(`If there are no elements to extract, respond with ${ZERO_ELEMENTS}.`)
165
173
  instructions.push(`When you are done extracting all elements, type "${NO_MORE}" to finish.`)
166
- instructions.push(`For example, if you have zero elements, the output should look like this: ${NO_MORE}`)
174
+ instructions.push(
175
+ `For example, if you have zero elements, the output should look like this: ${ZERO_ELEMENTS}${NO_MORE}`
176
+ )
167
177
  instructions.push(
168
178
  `For example, if you have two elements, the output should look like this: ${START}${abbv}${END}${START}${abbv}${END}${NO_MORE}`
169
179
  )
@@ -2,6 +2,7 @@
2
2
  import { z } from '@bpinternal/zui'
3
3
 
4
4
  import { clamp } from 'lodash-es'
5
+ import pLimit from 'p-limit'
5
6
  import { ZaiContext } from '../context'
6
7
  import { Response } from '../response'
7
8
  import { getTokenizer } from '../tokenizer'
@@ -259,7 +260,8 @@ The condition is: "${condition}"
259
260
  return partial
260
261
  }
261
262
 
262
- const filteredChunks = await Promise.all(chunks.map(filterChunk))
263
+ const limit = pLimit(10) // Limit to 10 concurrent filtering operations
264
+ const filteredChunks = await Promise.all(chunks.map((chunk) => limit(() => filterChunk(chunk))))
263
265
 
264
266
  return filteredChunks.flat()
265
267
  }
@@ -0,0 +1,421 @@
1
+ // eslint-disable consistent-type-definitions
2
+ import { z } from '@bpinternal/zui'
3
+ import { clamp } from 'lodash-es'
4
+ import pLimit from 'p-limit'
5
+ import { ZaiContext } from '../context'
6
+ import { Response } from '../response'
7
+ import { getTokenizer } from '../tokenizer'
8
+ import { stringify } from '../utils'
9
+ import { Zai } from '../zai'
10
+ import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from './constants'
11
+
12
/** One output group: a stable id, the label shown to callers, and the elements assigned to it. */
export type Group<T> = {
  id: string
  label: string
  elements: T[]
}

/** A caller-supplied group that exists before any element is labeled. */
type InitialGroup = {
  id: string
  label: string
  elements?: unknown[]
}

/** Options accepted by `zai.group`; defaults are applied when parsing with `_Options`. */
export type Options = {
  /** Extra guidance placed in the grouping prompt. */
  instructions?: string
  /** Each element is truncated to this many tokens before being shown to the model (default 250). */
  tokensPerElement?: number
  /** Validated to 100–100,000 (default 16,000). */
  chunkLength?: number
  /** Groups to seed the run with; model labels that normalize to one of these reuse it. */
  initialGroups?: InitialGroup[]
}

// Runtime validation for `InitialGroup`; `elements` defaults to an empty array.
const _InitialGroup = z.object({
  id: z.string().min(1).max(100),
  label: z.string().min(1).max(250),
  elements: z.array(z.any()).optional().default([]),
})

// Runtime validation + defaults for `Options`.
const _Options = z.object({
  instructions: z.string().optional(),
  tokensPerElement: z.number().min(1).max(100_000).optional().default(250),
  chunkLength: z.number().min(100).max(100_000).optional().default(16_000),
  initialGroups: z.array(_InitialGroup).optional().default([]),
})

declare module '@botpress/zai' {
  interface Zai {
    /** Groups `input` into labeled groups; the simplified value maps each label to its elements. */
    group<T>(input: T[], options?: Options): Response<Group<T>[], Record<string, T[]>>
  }
}

/** Stop sequence used to end the model's answer in the grouping prompts. */
const END = '■END■'

/** Internal bookkeeping for a group; `normalizedLabel` holds `normalizeLabel(label)`. */
type GroupInfo = {
  id: string
  label: string
  normalizedLabel: string
}
58
+
59
// "group" / "new group" / "new" followed by a ":" or "-" separator, e.g. "New group: Fruits", "New - Fruits".
const LABEL_PREFIX_WITH_SEPARATOR = /^(?:new\s+group|group|new)\s*[-:]\s*/
// "group" / "new group" followed by whitespace, e.g. "Group Fruits", "New Group Fruits".
const LABEL_PREFIX_GROUP_WORD = /^(?:new\s+)?group\s+/

/**
 * Canonical form of a group label, used to decide whether two labels denote the same group.
 *
 * The label is trimmed and lower-cased. A leading "group"/"new group" prefix is removed, and so is
 * a "new" that is followed by ":" or "-". A bare leading "new" followed only by a space is kept on
 * purpose. Stripping it would merge unrelated labels such as "New York" with "York", or
 * "New customers" with "customers".
 *
 * @param label - raw label as produced by the model or supplied by the caller
 * @returns the normalized matching key
 */
const normalizeLabel = (label: string): string =>
  label
    .trim()
    .toLowerCase()
    .replace(LABEL_PREFIX_WITH_SEPARATOR, '')
    .replace(LABEL_PREFIX_GROUP_WORD, '')
    .trim()
67
+
68
/**
 * Groups `input` into labeled clusters using the LLM.
 *
 * Pass 1 labels every chunk of elements against the groups known up front. Because all chunks run
 * concurrently, these are only the caller's initial groups. Labels are matched to existing groups
 * through `normalizeLabel`. Unmatched labels create new groups with ids `group_<n>`; ids already
 * taken by an initial group are skipped.
 *
 * Pass 2 asks again, against the complete group list, for the elements that pass 1 left unlabeled.
 * Each element keeps the first label it received, and groups that end up empty are omitted.
 *
 * @param input - elements to group; an empty array yields `[]`
 * @param _options - see `Options`; validated and defaulted by `_Options`
 * @param ctx - task context used for cancellation, model lookup and generation
 * @returns the non-empty groups, in creation order (initial groups first)
 *
 * NOTE(review): `initialGroups[].elements` is accepted but not used. Elements the model never labels
 * in either pass are absent from the result.
 */
const group = async <T>(input: Array<T>, _options: Options | undefined, ctx: ZaiContext): Promise<Array<Group<T>>> => {
  ctx.controller.signal.throwIfAborted()

  const options = _Options.parse(_options ?? {})
  const tokenizer = await getTokenizer()
  const model = await ctx.getModel()

  if (input.length === 0) {
    return []
  }

  // Serialize and truncate every element once; the result is reused for chunk budgeting and for the prompts.
  const elements = input.map((element, index) => {
    const truncated = tokenizer.truncate(stringify(element, false), options.tokensPerElement)
    return { element, index, truncated, tokens: tokenizer.count(truncated) }
  })
  type Entry = (typeof elements)[number]
  type Assignment = { entry: Entry; label: string }

  // Group registry, in creation order: groupId -> info + members (in assignment order).
  const groups = new Map<string, { info: GroupInfo; members: Entry[] }>()
  const labelToGroupId = new Map<string, string>() // normalized label -> groupId
  const assigned = new Set<number>() // indices of elements that already have their final group
  let groupIdCounter = 0

  // Register a group under `id` and return its (empty) member list.
  const registerGroup = (id: string, label: string): Entry[] => {
    const normalizedLabel = normalizeLabel(label)
    const members: Entry[] = []
    groups.set(id, { info: { id, label, normalizedLabel }, members })
    labelToGroupId.set(normalizedLabel, id)
    return members
  }

  for (const ig of options.initialGroups) {
    registerGroup(ig.id, ig.label)
  }

  // Member list of the group matching `label`, creating a new group when no existing label matches.
  const resolveGroupMembers = (label: string): Entry[] => {
    const existingId = labelToGroupId.get(normalizeLabel(label))
    const existing = existingId === undefined ? undefined : groups.get(existingId)
    if (existing) {
      return existing.members
    }
    let id: string
    do {
      id = `group_${groupIdCounter++}`
    } while (groups.has(id)) // never overwrite (and wipe) an initial group that happens to use this id
    return registerGroup(id, label)
  }

  // Record one proposed label. Every label becomes a group, but an element keeps only its first label.
  const assign = ({ entry, label }: Assignment): void => {
    const members = resolveGroupMembers(label)
    if (assigned.has(entry.index)) {
      return
    }
    assigned.add(entry.index)
    members.push(entry)
  }

  // Token budget
  const TOKENS_TOTAL_MAX = model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER
  const TOKENS_INSTRUCTIONS_MAX = options.instructions
    ? clamp(tokenizer.count(options.instructions), 100, TOKENS_TOTAL_MAX * 0.2)
    : 0
  const TOKENS_AVAILABLE = TOKENS_TOTAL_MAX - TOKENS_INSTRUCTIONS_MAX
  const TOKENS_FOR_GROUPS_MAX = Math.floor(TOKENS_AVAILABLE * 0.4)
  const TOKENS_FOR_ELEMENTS_MAX = Math.floor(TOKENS_AVAILABLE * 0.6)
  const MAX_ELEMENTS_PER_CHUNK = 50

  // Split entries into chunks bounded by the element token budget and by MAX_ELEMENTS_PER_CHUNK.
  const chunkElements = (entries: Entry[]): Entry[][] => {
    const chunks: Entry[][] = []
    let current: Entry[] = []
    let currentTokens = 0

    for (const entry of entries) {
      const full = currentTokens + entry.tokens > TOKENS_FOR_ELEMENTS_MAX || current.length >= MAX_ELEMENTS_PER_CHUNK
      if (full && current.length > 0) {
        chunks.push(current)
        current = []
        currentTokens = 0
      }
      current.push(entry)
      currentTokens += entry.tokens
    }

    if (current.length > 0) {
      chunks.push(current)
    }
    return chunks
  }

  // Split the current groups into chunks bounded by the group token budget; always at least one (maybe empty) chunk.
  const chunkGroups = (): GroupInfo[][] => {
    const chunks: GroupInfo[][] = []
    let current: GroupInfo[] = []
    let currentTokens = 0

    for (const { info } of groups.values()) {
      const groupTokens = tokenizer.count(info.label) + 10
      if (currentTokens + groupTokens > TOKENS_FOR_GROUPS_MAX && current.length > 0) {
        chunks.push(current)
        current = []
        currentTokens = 0
      }
      current.push(info)
      currentTokens += groupTokens
    }

    if (current.length > 0) {
      chunks.push(current)
    }
    return chunks.length > 0 ? chunks : [[]]
  }

  // Ask the model for one label per element of `chunk`, offering the labels of `groupChunk` for reuse.
  const processChunk = async (chunk: Entry[], groupChunk: GroupInfo[]): Promise<Assignment[]> => {
    const elementsText = chunk.map((entry, i) => `■${i}: ${entry.truncated}■`).join('\n')

    const groupsText =
      groupChunk.length > 0
        ? `**Existing Groups (prefer reusing these):**\n${groupChunk.map((g) => `- ${g.label}`).join('\n')}\n\n`
        : ''

    const systemPrompt = `You are grouping elements into cohesive groups.

${options.instructions ? `**Instructions:** ${options.instructions}\n` : '**Instructions:** Group similar elements together.'}

**Important:**
- Each element gets exactly ONE group label
- Use EXACT SAME label for similar items (case-sensitive)
- Create new descriptive labels when needed

**Output Format:**
One line per element:
■0:Group Label■
■1:Group Label■
${END}`.trim()

    const userPrompt = `${groupsText}**Elements (■0 to ■${chunk.length - 1}):**
${elementsText}

**Task:** For each element, output one line with its group label.
${END}`.trim()

    const { extracted } = await ctx.generateContent({
      systemPrompt,
      stopSequences: [END],
      messages: [{ type: 'text', role: 'user', content: userPrompt }],
      transform: (text) => {
        const assignments: Assignment[] = []
        const regex = /■(\d+):([^■]+)■/g
        let match: RegExpExecArray | null

        while ((match = regex.exec(text)) !== null) {
          // Out-of-range or unparsable indices resolve to undefined and are skipped.
          const entry = chunk[parseInt(match[1] ?? '', 10)]
          const label = (match[2] ?? '').trim()
          if (!entry || !label) continue

          assignments.push({ entry, label: label.slice(0, 250) })
        }

        return assignments
      },
    })

    return extracted
  }

  const elementLimit = pLimit(10) // concurrency limit for element chunks
  const groupLimit = pLimit(10) // concurrency limit for group chunks within an element chunk

  // Label every element chunk against every group chunk; results keep chunk order.
  const labelChunks = async (chunks: Entry[][]): Promise<Assignment[]> => {
    // Groups are only mutated between passes, so a single snapshot serves the whole pass.
    const groupChunks = chunkGroups()
    const perChunk = await Promise.all(
      chunks.map((chunk) =>
        elementLimit(async () => {
          const perGroupChunk = await Promise.all(
            groupChunks.map((groupChunk) => groupLimit(() => processChunk(chunk, groupChunk)))
          )
          return perGroupChunk.flat()
        })
      )
    )
    return perChunk.flat()
  }

  // Pass 1: every element, against the groups known up front. Applied sequentially after all calls finish.
  for (const assignment of await labelChunks(chunkElements(elements))) {
    assign(assignment)
  }

  // Pass 2: elements that pass 1 left unlabeled get another try against the full group list. Elements
  // that already have a label are not re-asked: the first label wins, so a second answer could not
  // change their group.
  if (groups.size > 0) {
    const unassigned = elements.filter((entry) => !assigned.has(entry.index))
    if (unassigned.length > 0) {
      for (const assignment of await labelChunks(chunkElements(unassigned))) {
        assign(assignment)
      }
    }
  }

  // Build the final result, skipping groups that received no element.
  const result: Array<Group<T>> = []
  for (const { info, members } of groups.values()) {
    if (members.length > 0) {
      result.push({ id: info.id, label: info.label, elements: members.map((entry) => entry.element) })
    }
  }

  return result
}
397
+
398
/**
 * `zai.group(input, options)`: groups `input` with the LLM in a dedicated `zai.group` task context.
 *
 * The Response's full value is the list of groups. Its simplified value maps each label to the
 * concatenated elements of every group carrying that label.
 */
Zai.prototype.group = function <T>(
  this: Zai,
  input: Array<T>,
  _options?: Options
): Response<Array<Group<T>>, Record<string, T[]>> {
  const context = new ZaiContext({
    client: this.client,
    modelId: this.Model,
    taskId: this.taskId,
    taskType: 'zai.group',
    adapter: this.adapter,
  })

  return new Response<Array<Group<T>>, Record<string, T[]>>(context, group(input, _options, context), (result) => {
    // Collect into a Map first. Labels come from the model, and on a plain object a label such as
    // "constructor", "toString" or "__proto__" resolves to an inherited Object.prototype member, so
    // `.push` would throw.
    const byLabel = new Map<string, T[]>()
    for (const { label, elements } of result) {
      const bucket = byLabel.get(label) ?? []
      bucket.push(...elements)
      byLabel.set(label, bucket)
    }
    // Object.fromEntries defines own data properties, so even "__proto__" becomes an ordinary key.
    return Object.fromEntries(byLabel)
  })
}
@@ -2,6 +2,7 @@
2
2
  import { z } from '@bpinternal/zui'
3
3
 
4
4
  import { chunk, clamp } from 'lodash-es'
5
+ import pLimit from 'p-limit'
5
6
  import { ZaiContext } from '../context'
6
7
  import { Response } from '../response'
7
8
  import { getTokenizer } from '../tokenizer'
@@ -162,9 +163,10 @@ const label = async <T extends string>(
162
163
  const inputAsString = stringify(input)
163
164
 
164
165
  if (tokenizer.count(inputAsString) > CHUNK_INPUT_MAX_TOKENS) {
166
+ const limit = pLimit(10) // Limit to 10 concurrent labeling operations
165
167
  const tokens = tokenizer.split(inputAsString)
166
168
  const chunks = chunk(tokens, CHUNK_INPUT_MAX_TOKENS).map((x) => x.join(''))
167
- const allLabels = await Promise.all(chunks.map((chunk) => label(chunk, _labels, _options, ctx)))
169
+ const allLabels = await Promise.all(chunks.map((chunk) => limit(() => label(chunk, _labels, _options, ctx))))
168
170
 
169
171
  // Merge all the labels together (those who are true will remain true)
170
172
  return allLabels.reduce((acc, x) => {
@@ -2,6 +2,7 @@
2
2
  import { z } from '@bpinternal/zui'
3
3
 
4
4
  import { chunk } from 'lodash-es'
5
+ import pLimit from 'p-limit'
5
6
  import { ZaiContext } from '../context'
6
7
  import { Response } from '../response'
7
8
 
@@ -115,9 +116,9 @@ ${newText}
115
116
  const chunkSize = Math.ceil(tokens.length / (parts * N))
116
117
 
117
118
  if (useMergeSort) {
118
- // TODO: use pLimit here to not have too many chunks
119
+ const limit = pLimit(10) // Limit to 10 concurrent summarization operations
119
120
  const chunks = chunk(tokens, chunkSize).map((x) => x.join(''))
120
- const allSummaries = (await Promise.allSettled(chunks.map((chunk) => summarize(chunk, options, ctx))))
121
+ const allSummaries = (await Promise.allSettled(chunks.map((chunk) => limit(() => summarize(chunk, options, ctx)))))
121
122
  .filter((x) => x.status === 'fulfilled')
122
123
  .map((x) => x.value)
123
124
  return summarize(allSummaries.join('\n\n============\n\n'), options, ctx)