@botpress/zai 2.1.19 → 2.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CLAUDE.md +696 -0
- package/README.md +28 -2
- package/dist/index.d.ts +39 -18
- package/dist/index.js +1 -0
- package/dist/operations/errors.js +112 -8
- package/dist/operations/extract.js +20 -12
- package/dist/operations/filter.js +3 -1
- package/dist/operations/group.js +278 -0
- package/dist/operations/label.js +3 -1
- package/dist/operations/summarize.js +3 -1
- package/e2e/data/cache.jsonl +219 -0
- package/package.json +4 -3
- package/src/index.ts +1 -0
- package/src/operations/errors.ts +96 -1
- package/src/operations/extract.ts +21 -11
- package/src/operations/filter.ts +3 -1
- package/src/operations/group.ts +421 -0
- package/src/operations/label.ts +3 -1
- package/src/operations/summarize.ts +3 -2
- package/src/zai.ts +7 -9
package/package.json
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@botpress/zai",
|
|
3
3
|
"description": "Zui AI (zai) – An LLM utility library written on top of Zui and the Botpress API",
|
|
4
|
-
"version": "2.
|
|
4
|
+
"version": "2.2.0",
|
|
5
5
|
"main": "./dist/index.js",
|
|
6
6
|
"types": "./dist/index.d.ts",
|
|
7
7
|
"exports": {
|
|
@@ -35,7 +35,8 @@
|
|
|
35
35
|
"@botpress/cognitive": "0.1.50",
|
|
36
36
|
"json5": "^2.2.3",
|
|
37
37
|
"jsonrepair": "^3.10.0",
|
|
38
|
-
"lodash-es": "^4.17.21"
|
|
38
|
+
"lodash-es": "^4.17.21",
|
|
39
|
+
"p-limit": "^7.2.0"
|
|
39
40
|
},
|
|
40
41
|
"devDependencies": {
|
|
41
42
|
"@botpress/client": "workspace:^",
|
|
@@ -53,7 +54,7 @@
|
|
|
53
54
|
},
|
|
54
55
|
"peerDependencies": {
|
|
55
56
|
"@bpinternal/thicktoken": "^1.0.0",
|
|
56
|
-
"@bpinternal/zui": "1.2.
|
|
57
|
+
"@bpinternal/zui": "^1.2.2"
|
|
57
58
|
},
|
|
58
59
|
"engines": {
|
|
59
60
|
"node": ">=18.0.0"
|
package/src/index.ts
CHANGED
package/src/operations/errors.ts
CHANGED
|
@@ -1,9 +1,104 @@
|
|
|
1
|
+
import { ZodError } from '@bpinternal/zui'
|
|
2
|
+
|
|
1
3
|
/**
 * Raised when a model's JSON output cannot be parsed or fails schema validation.
 *
 * The error message contains the offending JSON followed by either a numbered list of
 * validation problems (when `error` is a `ZodError`) or the underlying parse error message.
 */
export class JsonParsingError extends Error {
  /**
   * @param json the value that failed to parse or validate
   * @param error the underlying parse error, or a `ZodError` describing validation failures
   */
  public constructor(
    public json: unknown,
    public error: Error
  ) {
    const message = JsonParsingError._formatError(json, error)
    super(message)
  }

  /** Builds the full error message: a JSON section followed by a validation or parse-error section. */
  private static _formatError(json: unknown, error: Error): string {
    let errorMessage = 'Error parsing JSON:\n\n'
    errorMessage += `---JSON---\n${JsonParsingError._stringifyJson(json)}\n\n`

    if (error instanceof ZodError) {
      errorMessage += '---Validation Errors---\n\n'
      errorMessage += JsonParsingError._formatZodError(error)
    } else {
      errorMessage += '---Error---\n\n'
      errorMessage += 'The JSON provided is not valid JSON.\n'
      errorMessage += `Details: ${error.message}\n`
    }

    return errorMessage
  }

  /**
   * Renders `json` for the message. Strings are kept verbatim; any other value is pretty-printed
   * as JSON, since interpolating an object directly would only yield "[object Object]".
   */
  private static _stringifyJson(json: unknown): string {
    if (typeof json === 'string') {
      return json
    }
    try {
      // JSON.stringify yields undefined for undefined, functions and symbols
      return JSON.stringify(json, null, 2) ?? String(json)
    } catch {
      // Circular structures and BigInt values cannot be serialized
      return String(json)
    }
  }

  /** Formats every Zod issue as a numbered entry; entries are separated by a blank line. */
  private static _formatZodError(zodError: ZodError): string {
    const issues = zodError.issues
    if (issues.length === 0) {
      return 'Unknown validation error\n'
    }

    let message = ''
    for (let i = 0; i < issues.length; i++) {
      const issue = issues[i]
      const path = issue.path.length > 0 ? issue.path.join('.') : 'root'

      message += `${i + 1}. Field: "${path}"\n`

      switch (issue.code) {
        case 'invalid_type':
          message += ` Problem: Expected ${issue.expected}, but received ${issue.received}\n`
          message += ` Message: ${issue.message}\n`
          break
        case 'invalid_string':
          if ('validation' in issue) {
            // `validation` is either a format name or an object such as { includes: '...' }
            const validation =
              typeof issue.validation === 'string' ? issue.validation : JSON.stringify(issue.validation)
            message += ` Problem: Invalid ${validation} format\n`
          }
          message += ` Message: ${issue.message}\n`
          break
        case 'too_small':
          if (issue.type === 'string') {
            if (issue.exact) {
              message += ` Problem: String must be exactly ${issue.minimum} characters\n`
            } else {
              message += ` Problem: String must be at least ${issue.minimum} characters\n`
            }
          } else if (issue.type === 'number') {
            message += ` Problem: Number must be ${issue.inclusive ? 'at least' : 'greater than'} ${issue.minimum}\n`
          } else if (issue.type === 'array') {
            if (issue.exact) {
              // e.g. z.array(...).length(n)
              message += ` Problem: Array must contain exactly ${issue.minimum} items\n`
            } else {
              message += ` Problem: Array must contain ${issue.inclusive ? 'at least' : 'more than'} ${issue.minimum} items\n`
            }
          }
          message += ` Message: ${issue.message}\n`
          break
        case 'too_big':
          if (issue.type === 'string') {
            if (issue.exact) {
              message += ` Problem: String must be exactly ${issue.maximum} characters\n`
            } else {
              message += ` Problem: String must be at most ${issue.maximum} characters\n`
            }
          } else if (issue.type === 'number') {
            message += ` Problem: Number must be ${issue.inclusive ? 'at most' : 'less than'} ${issue.maximum}\n`
          } else if (issue.type === 'array') {
            if (issue.exact) {
              message += ` Problem: Array must contain exactly ${issue.maximum} items\n`
            } else {
              message += ` Problem: Array must contain ${issue.inclusive ? 'at most' : 'fewer than'} ${issue.maximum} items\n`
            }
          }
          message += ` Message: ${issue.message}\n`
          break
        case 'invalid_enum_value':
          message += ` Problem: Invalid value "${issue.received}"\n`
          // String() so that symbol options do not throw inside the template literal
          message += ` Allowed values: ${issue.options.map((o) => `"${String(o)}"`).join(', ')}\n`
          message += ` Message: ${issue.message}\n`
          break
        case 'invalid_literal':
          message += ` Problem: Expected the literal value "${issue.expected}", but received "${issue.received}"\n`
          message += ` Message: ${issue.message}\n`
          break
        case 'invalid_union':
          message += " Problem: Value doesn't match any of the expected formats\n"
          message += ` Message: ${issue.message}\n`
          break
        default:
          message += ` Problem: ${issue.message}\n`
      }

      if (i < issues.length - 1) {
        message += '\n'
      }
    }

    return message
  }
}
|
|
package/src/operations/extract.ts
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
// eslint-disable consistent-type-definitions
|
|
2
|
-
import { z, ZodObject } from '@bpinternal/zui'
|
|
2
|
+
import { z, ZodObject, transforms } from '@bpinternal/zui'
|
|
3
3
|
|
|
4
4
|
import JSON5 from 'json5'
|
|
5
5
|
import { jsonrepair } from 'jsonrepair'
|
|
6
6
|
|
|
7
7
|
import { chunk, isArray } from 'lodash-es'
|
|
8
|
+
import pLimit from 'p-limit'
|
|
8
9
|
import { ZaiContext } from '../context'
|
|
9
10
|
import { Response } from '../response'
|
|
10
11
|
import { getTokenizer } from '../tokenizer'
|
|
@@ -48,6 +49,7 @@ declare module '@botpress/zai' {
|
|
|
48
49
|
const START = '■json_start■'
|
|
49
50
|
const END = '■json_end■'
|
|
50
51
|
const NO_MORE = '■NO_MORE_ELEMENT■'
|
|
52
|
+
const ZERO_ELEMENTS = '■ZERO_ELEMENTS■'
|
|
51
53
|
|
|
52
54
|
const extract = async <S extends OfType<AnyObjectOrArray>>(
|
|
53
55
|
input: unknown,
|
|
@@ -56,7 +58,9 @@ const extract = async <S extends OfType<AnyObjectOrArray>>(
|
|
|
56
58
|
ctx: ZaiContext
|
|
57
59
|
): Promise<S['_output']> => {
|
|
58
60
|
ctx.controller.signal.throwIfAborted()
|
|
59
|
-
|
|
61
|
+
|
|
62
|
+
let schema = transforms.fromJSONSchema(transforms.toJSONSchema(_schema as any as z.ZodType))
|
|
63
|
+
|
|
60
64
|
const options = Options.parse(_options ?? {})
|
|
61
65
|
const tokenizer = await getTokenizer()
|
|
62
66
|
const model = await ctx.getModel()
|
|
@@ -110,18 +114,21 @@ const extract = async <S extends OfType<AnyObjectOrArray>>(
|
|
|
110
114
|
const inputAsString = stringify(input)
|
|
111
115
|
|
|
112
116
|
if (tokenizer.count(inputAsString) > options.chunkLength) {
|
|
117
|
+
const limit = pLimit(10) // Limit to 10 concurrent extraction operations
|
|
113
118
|
const tokens = tokenizer.split(inputAsString)
|
|
114
119
|
const chunks = chunk(tokens, options.chunkLength).map((x) => x.join(''))
|
|
115
120
|
const all = await Promise.allSettled(
|
|
116
121
|
chunks.map((chunk) =>
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
122
|
+
limit(() =>
|
|
123
|
+
extract(
|
|
124
|
+
chunk,
|
|
125
|
+
originalSchema,
|
|
126
|
+
{
|
|
127
|
+
...options,
|
|
128
|
+
strict: false, // We don't want to fail on strict mode for sub-chunks
|
|
129
|
+
},
|
|
130
|
+
ctx
|
|
131
|
+
)
|
|
125
132
|
)
|
|
126
133
|
)
|
|
127
134
|
).then((results) =>
|
|
@@ -162,8 +169,11 @@ Merge it back into a final result.`.trim(),
|
|
|
162
169
|
instructions.push('You may have multiple elements, or zero elements in the input.')
|
|
163
170
|
instructions.push('You must extract each element separately.')
|
|
164
171
|
instructions.push(`Each element must be a JSON object with exactly the format: ${START}${shape}${END}`)
|
|
172
|
+
instructions.push(`If there are no elements to extract, respond with ${ZERO_ELEMENTS}.`)
|
|
165
173
|
instructions.push(`When you are done extracting all elements, type "${NO_MORE}" to finish.`)
|
|
166
|
-
instructions.push(
|
|
174
|
+
instructions.push(
|
|
175
|
+
`For example, if you have zero elements, the output should look like this: ${ZERO_ELEMENTS}${NO_MORE}`
|
|
176
|
+
)
|
|
167
177
|
instructions.push(
|
|
168
178
|
`For example, if you have two elements, the output should look like this: ${START}${abbv}${END}${START}${abbv}${END}${NO_MORE}`
|
|
169
179
|
)
|
package/src/operations/filter.ts
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
import { z } from '@bpinternal/zui'
|
|
3
3
|
|
|
4
4
|
import { clamp } from 'lodash-es'
|
|
5
|
+
import pLimit from 'p-limit'
|
|
5
6
|
import { ZaiContext } from '../context'
|
|
6
7
|
import { Response } from '../response'
|
|
7
8
|
import { getTokenizer } from '../tokenizer'
|
|
@@ -259,7 +260,8 @@ The condition is: "${condition}"
|
|
|
259
260
|
return partial
|
|
260
261
|
}
|
|
261
262
|
|
|
262
|
-
const
|
|
263
|
+
const limit = pLimit(10) // Limit to 10 concurrent filtering operations
|
|
264
|
+
const filteredChunks = await Promise.all(chunks.map((chunk) => limit(() => filterChunk(chunk))))
|
|
263
265
|
|
|
264
266
|
return filteredChunks.flat()
|
|
265
267
|
}
|
|
package/src/operations/group.ts
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
1
|
+
// eslint-disable consistent-type-definitions
|
|
2
|
+
import { z } from '@bpinternal/zui'
|
|
3
|
+
import { clamp } from 'lodash-es'
|
|
4
|
+
import pLimit from 'p-limit'
|
|
5
|
+
import { ZaiContext } from '../context'
|
|
6
|
+
import { Response } from '../response'
|
|
7
|
+
import { getTokenizer } from '../tokenizer'
|
|
8
|
+
import { stringify } from '../utils'
|
|
9
|
+
import { Zai } from '../zai'
|
|
10
|
+
import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from './constants'
|
|
11
|
+
|
|
12
|
+
/** A group produced by `zai.group`. */
export type Group<T> = {
  /** The caller's id for an initial group, otherwise a generated id */
  id: string
  /** Human-readable label of the group */
  label: string
  /** Input elements assigned to this group */
  elements: T[]
}

/** A group known before grouping starts; it is listed to the model as an existing group to reuse. */
type InitialGroup = {
  id: string
  label: string
  elements?: unknown[]
}

/** Runtime validation for `InitialGroup` (id: 1-100 chars, label: 1-250 chars). */
const _InitialGroup = z.object({
  id: z.string().min(1).max(100),
  label: z.string().min(1).max(250),
  elements: z.array(z.any()).optional().default([]),
})

/** Options accepted by `zai.group`; validated and defaulted by `_Options`. */
export type Options = {
  /** Extra instructions given to the model on how to group the elements */
  instructions?: string
  /** Each element is truncated to this many tokens before being shown to the model (default 250) */
  tokensPerElement?: number
  /** Token chunk length (default 16 000, range 100-100 000) */
  chunkLength?: number
  /** Groups to seed the grouping with */
  initialGroups?: InitialGroup[]
}

/** Runtime validation and defaults for `Options`. */
const _Options = z.object({
  instructions: z.string().optional(),
  tokensPerElement: z.number().min(1).max(100_000).optional().default(250),
  chunkLength: z.number().min(100).max(100_000).optional().default(16_000),
  initialGroups: z.array(_InitialGroup).optional().default([]),
})

declare module '@botpress/zai' {
  interface Zai {
    /**
     * Groups `input` into labeled groups.
     * Resolves to the list of groups; the simplified value maps each label to its elements.
     */
    group<T>(input: Array<T>, options?: Options): Response<Array<Group<T>>, Record<string, T[]>>
  }
}

/** Stop sequence that terminates the model's list of label lines. */
const END = '■END■'

/** Internal bookkeeping for a group: id, display label and the normalized label used for de-duplication. */
type GroupInfo = {
  id: string
  label: string
  normalizedLabel: string
}
|
|
58
|
+
|
|
59
|
+
/**
 * Prefixes the model sometimes prepends to labels, tried in order:
 * first with a ":"/"-" separator, then followed by plain whitespace.
 */
const LABEL_PREFIX_PATTERNS: RegExp[] = [/^(group|new group|new)\s*[-:]\s*/i, /^(group|new group|new)\s+/i]

/**
 * Canonical form of a group label, used to detect that two labels denote the same group.
 * The label is trimmed and lower-cased, and a leading "group", "new group" or "new" prefix is removed.
 *
 * NOTE(review): the whitespace-only variant also strips genuine words, e.g. "New York" -> "york"
 * and "Group Therapy" -> "therapy", so such labels merge with "York"/"Therapy".
 */
const normalizeLabel = (label: string): string => {
  const lowered = label.trim().toLowerCase()
  const withoutPrefix = LABEL_PREFIX_PATTERNS.reduce((text, pattern) => text.replace(pattern, ''), lowered)
  return withoutPrefix.trim()
}
|
|
67
|
+
|
|
68
|
+
/** One label proposed by the model for one input element. */
type Assignment = { elementIndex: number; label: string }

/**
 * Picks the final group for an element from the ordered list of groups it was assigned to.
 * The group assigned most often wins. Ties go to the most recent assignment, because later
 * passes were shown a more complete list of groups.
 *
 * @param history group ids in assignment order (repeats count as extra votes); must be non-empty
 * @returns the winning group id
 */
const pickMostVotedGroup = (history: string[]): string => {
  const votes = new Map<string, number>()
  for (const groupId of history) {
    votes.set(groupId, (votes.get(groupId) ?? 0) + 1)
  }

  let winner = history[history.length - 1]
  let winnerVotes = 0
  // Walk backwards with a strict ">" so that, among equally voted groups, the most recent one is kept
  for (let i = history.length - 1; i >= 0; i--) {
    const groupId = history[i]
    const count = votes.get(groupId) ?? 0
    if (count > winnerVotes) {
      winner = groupId
      winnerVotes = count
    }
  }
  return winner
}

/**
 * Groups `input` elements into labeled groups using the model.
 *
 * 1. Elements are chunked by token budget and labeled against the groups known up front
 *    (`initialGroups`). Labels that match no known group (after `normalizeLabel`) become new groups.
 * 2. Elements that are still unassigned, or that were not shown every group created in step 1,
 *    are labeled again against the full group list. In this pass, an unknown label only creates
 *    a group for an element that has no group yet; otherwise it is ignored.
 * 3. Each element keeps the group it was assigned most often (ties: most recent assignment).
 *
 * Elements that the model never labels in either pass are not part of the result.
 *
 * @param input elements to group; an empty array resolves to []
 * @param _options see `Options`; validated and defaulted by `_Options`
 * @param ctx Zai context providing the model, content generation and the abort signal
 * @returns non-empty groups (initial groups first, then new groups in creation order);
 *          elements inside a group keep their input order
 */
const group = async <T>(input: Array<T>, _options: Options | undefined, ctx: ZaiContext): Promise<Array<Group<T>>> => {
  ctx.controller.signal.throwIfAborted()

  const options = _Options.parse(_options ?? {})
  const tokenizer = await getTokenizer()
  const model = await ctx.getModel()

  if (input.length === 0) {
    return []
  }

  // --- Group registry ---------------------------------------------------------
  const groups = new Map<string, GroupInfo>() // groupId -> GroupInfo (insertion order = output order)
  const labelToGroupId = new Map<string, string>() // normalized label -> groupId
  let groupIdCounter = 0

  // Next unused `group_N` id. Skipping taken ids prevents a generated id from overwriting
  // a caller-provided initial group whose id happens to be `group_N`.
  const nextGroupId = (): string => {
    let id = `group_${groupIdCounter++}`
    while (groups.has(id)) {
      id = `group_${groupIdCounter++}`
    }
    return id
  }

  // Id of the group whose normalized label matches `label`; creates that group when none does
  const resolveOrCreateGroup = (label: string): string => {
    const normalized = normalizeLabel(label)
    const existingId = labelToGroupId.get(normalized)
    if (existingId !== undefined) {
      return existingId
    }
    const id = nextGroupId()
    groups.set(id, { id, label, normalizedLabel: normalized })
    labelToGroupId.set(normalized, id)
    return id
  }

  // Seed with caller-provided groups.
  // NOTE(review): `initialGroups[].elements` is accepted by the schema but not used here.
  for (const ig of options.initialGroups) {
    const normalized = normalizeLabel(ig.label)
    groups.set(ig.id, { id: ig.id, label: ig.label, normalizedLabel: normalized })
    labelToGroupId.set(normalized, ig.id)
  }

  // --- Elements -----------------------------------------------------------------
  // Truncate once: the same text is reused for budgeting and for every prompt
  const elements = input.map((element, index) => {
    const truncated = tokenizer.truncate(stringify(element, false), options.tokensPerElement)
    return { element, index, truncated, tokens: tokenizer.count(truncated) }
  })

  const seenGroups = new Map<number, Set<string>>() // elementIndex -> groups it was shown in a prompt or assigned to
  const assignmentHistory = new Map<number, string[]>() // elementIndex -> assigned groupIds in order (repeats = votes)

  const markSeen = (elementIndex: number, groupId: string) => {
    const seen = seenGroups.get(elementIndex)
    if (seen) {
      seen.add(groupId)
    } else {
      seenGroups.set(elementIndex, new Set([groupId]))
    }
  }

  const recordAssignment = (elementIndex: number, groupId: string) => {
    const history = assignmentHistory.get(elementIndex)
    if (history) {
      history.push(groupId)
    } else {
      assignmentHistory.set(elementIndex, [groupId])
    }
    markSeen(elementIndex, groupId)
  }

  // --- Token budget ---------------------------------------------------------------
  const TOKENS_TOTAL_MAX = model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER
  const TOKENS_INSTRUCTIONS_MAX = options.instructions
    ? clamp(tokenizer.count(options.instructions), 100, TOKENS_TOTAL_MAX * 0.2)
    : 0
  const TOKENS_AVAILABLE = TOKENS_TOTAL_MAX - TOKENS_INSTRUCTIONS_MAX
  const TOKENS_FOR_GROUPS_MAX = Math.floor(TOKENS_AVAILABLE * 0.4)
  // The caller's `chunkLength` additionally caps how many element tokens go into one request
  const TOKENS_FOR_ELEMENTS_MAX = Math.min(Math.floor(TOKENS_AVAILABLE * 0.6), options.chunkLength)
  const MAX_ELEMENTS_PER_CHUNK = 50

  // Splits element indices into chunks bounded by the element token budget and MAX_ELEMENTS_PER_CHUNK
  const chunkElements = (indices: number[]): number[][] => {
    const chunks: number[][] = []
    let current: number[] = []
    let currentTokens = 0

    for (const elementIndex of indices) {
      const { tokens } = elements[elementIndex]
      const isFull = currentTokens + tokens > TOKENS_FOR_ELEMENTS_MAX || current.length >= MAX_ELEMENTS_PER_CHUNK
      if (isFull && current.length > 0) {
        chunks.push(current)
        current = []
        currentTokens = 0
      }
      current.push(elementIndex)
      currentTokens += tokens
    }

    if (current.length > 0) {
      chunks.push(current)
    }
    return chunks
  }

  // Splits all currently known groups into chunks bounded by the group token budget (always >= 1 chunk)
  const getGroupChunks = (): string[][] => {
    if (groups.size === 0) return [[]]

    const chunks: string[][] = []
    let currentChunk: string[] = []
    let currentTokens = 0

    for (const [groupId, info] of groups) {
      const groupTokens = tokenizer.count(`${info.label}`) + 10

      if (currentTokens + groupTokens > TOKENS_FOR_GROUPS_MAX && currentChunk.length > 0) {
        chunks.push(currentChunk)
        currentChunk = []
        currentTokens = 0
      }

      currentChunk.push(groupId)
      currentTokens += groupTokens
    }

    if (currentChunk.length > 0) {
      chunks.push(currentChunk)
    }

    return chunks.length > 0 ? chunks : [[]]
  }

  // Asks the model for one label per element of `elementIndices`, listing `groupIds` as existing groups
  const processChunk = async (elementIndices: number[], groupIds: string[]): Promise<Assignment[]> => {
    const elementsText = elementIndices.map((idx, i) => `■${i}: ${elements[idx].truncated}■`).join('\n')

    const groupsList = groupIds.map((gid) => groups.get(gid)!.label)
    const groupsText =
      groupsList.length > 0
        ? `**Existing Groups (prefer reusing these):**\n${groupsList.map((l) => `- ${l}`).join('\n')}\n\n`
        : ''

    const systemPrompt = `You are grouping elements into cohesive groups.

${options.instructions ? `**Instructions:** ${options.instructions}\n` : '**Instructions:** Group similar elements together.'}

**Important:**
- Each element gets exactly ONE group label
- Use EXACT SAME label for similar items (case-sensitive)
- Create new descriptive labels when needed

**Output Format:**
One line per element:
■0:Group Label■
■1:Group Label■
${END}`.trim()

    const userPrompt = `${groupsText}**Elements (■0 to ■${elementIndices.length - 1}):**
${elementsText}

**Task:** For each element, output one line with its group label.
${END}`.trim()

    const { extracted } = await ctx.generateContent({
      systemPrompt,
      stopSequences: [END],
      messages: [{ type: 'text', role: 'user', content: userPrompt }],
      transform: (text) => {
        const assignments: Assignment[] = []
        const regex = /■(\d+):([^■]+)■/g
        let match: RegExpExecArray | null

        while ((match = regex.exec(text)) !== null) {
          const idx = parseInt(match[1] ?? '', 10)
          if (isNaN(idx) || idx < 0 || idx >= elementIndices.length) continue

          const label = (match[2] ?? '').trim()
          if (!label) continue

          assignments.push({
            elementIndex: elementIndices[idx],
            label: label.slice(0, 250),
          })
        }

        return assignments
      },
    })

    return extracted
  }

  const elementLimit = pLimit(10) // Separate limiter for element chunks
  const groupLimit = pLimit(10) // Separate limiter for group chunks (nested calls cannot deadlock)

  // Labels every element chunk against every chunk of the groups known when the chunk starts.
  // Groups are only mutated after a whole pass completes, so all chunks of a pass see the same groups.
  const labelElementChunks = (chunks: number[][]) =>
    Promise.all(
      chunks.map((elementChunk) =>
        elementLimit(async () => {
          const groupChunks = getGroupChunks()
          const perGroupChunk = await Promise.all(
            groupChunks.map((groupChunk) => groupLimit(() => processChunk(elementChunk, groupChunk)))
          )
          return { elementChunk, shownGroupIds: groupChunks.flat(), assignments: perGroupChunk.flat() }
        })
      )
    )

  // Phase 1: label all elements against the initial groups (results applied sequentially)
  const firstPass = await labelElementChunks(chunkElements(elements.map((e) => e.index)))
  for (const { elementChunk, shownGroupIds, assignments } of firstPass) {
    for (const elementIndex of elementChunk) {
      for (const groupId of shownGroupIds) {
        markSeen(elementIndex, groupId)
      }
    }
    for (const { elementIndex, label } of assignments) {
      recordAssignment(elementIndex, resolveOrCreateGroup(label))
    }
  }

  // Phase 2 (coverage): review elements that are unassigned or were not shown every group
  const allGroupIds = Array.from(groups.keys())
  const elementsNeedingReview = elements
    .map((e) => e.index)
    .filter((elementIndex) => {
      if (!assignmentHistory.has(elementIndex)) return true
      const seen = seenGroups.get(elementIndex)
      return allGroupIds.some((groupId) => !seen?.has(groupId))
    })

  if (elementsNeedingReview.length > 0) {
    const reviewPass = await labelElementChunks(chunkElements(elementsNeedingReview))
    for (const { assignments } of reviewPass) {
      for (const { elementIndex, label } of assignments) {
        const existingGroupId = labelToGroupId.get(normalizeLabel(label))
        if (existingGroupId !== undefined) {
          recordAssignment(elementIndex, existingGroupId)
        } else if (!assignmentHistory.has(elementIndex)) {
          // Accept a new label only for an element that would otherwise end up in no group
          recordAssignment(elementIndex, resolveOrCreateGroup(label))
        }
      }
    }
  }

  // Phase 3: every element ends up in exactly one group; iterate in input order
  const membersByGroup = new Map<string, T[]>()
  for (const { element, index } of elements) {
    const history = assignmentHistory.get(index)
    if (!history || history.length === 0) {
      continue // never labeled by the model
    }
    const groupId = pickMostVotedGroup(history)
    const members = membersByGroup.get(groupId)
    if (members) {
      members.push(element)
    } else {
      membersByGroup.set(groupId, [element])
    }
  }

  // Build final result (empty groups are omitted)
  const result: Array<Group<T>> = []
  for (const [groupId, info] of groups) {
    const members = membersByGroup.get(groupId)
    if (members && members.length > 0) {
      result.push({ id: info.id, label: info.label, elements: members })
    }
  }

  return result
}
|
|
397
|
+
|
|
398
|
+
/**
 * `zai.group(input, options)` — groups `input` into labeled groups.
 *
 * The returned `Response` resolves to the full list of groups. Its simplified value is a record
 * mapping each group label to that group's elements; groups sharing a label are concatenated.
 */
Zai.prototype.group = function <T>(
  this: Zai,
  input: Array<T>,
  _options?: Options
): Response<Array<Group<T>>, Record<string, T[]>> {
  const ctx = new ZaiContext({
    client: this.client,
    modelId: this.Model,
    taskId: this.taskId,
    taskType: 'zai.group',
    adapter: this.adapter,
  })

  // Simplifier: label -> elements
  const toRecord = (groupList: Array<Group<T>>): Record<string, T[]> => {
    const byLabel: Record<string, T[]> = {}
    for (const { label, elements } of groupList) {
      const existing = byLabel[label]
      const bucket = existing ? existing : (byLabel[label] = [])
      bucket.push(...elements)
    }
    return byLabel
  }

  return new Response<Array<Group<T>>, Record<string, T[]>>(ctx, group(input, _options, ctx), toRecord)
}
|
package/src/operations/label.ts
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
import { z } from '@bpinternal/zui'
|
|
3
3
|
|
|
4
4
|
import { chunk, clamp } from 'lodash-es'
|
|
5
|
+
import pLimit from 'p-limit'
|
|
5
6
|
import { ZaiContext } from '../context'
|
|
6
7
|
import { Response } from '../response'
|
|
7
8
|
import { getTokenizer } from '../tokenizer'
|
|
@@ -162,9 +163,10 @@ const label = async <T extends string>(
|
|
|
162
163
|
const inputAsString = stringify(input)
|
|
163
164
|
|
|
164
165
|
if (tokenizer.count(inputAsString) > CHUNK_INPUT_MAX_TOKENS) {
|
|
166
|
+
const limit = pLimit(10) // Limit to 10 concurrent labeling operations
|
|
165
167
|
const tokens = tokenizer.split(inputAsString)
|
|
166
168
|
const chunks = chunk(tokens, CHUNK_INPUT_MAX_TOKENS).map((x) => x.join(''))
|
|
167
|
-
const allLabels = await Promise.all(chunks.map((chunk) => label(chunk, _labels, _options, ctx)))
|
|
169
|
+
const allLabels = await Promise.all(chunks.map((chunk) => limit(() => label(chunk, _labels, _options, ctx))))
|
|
168
170
|
|
|
169
171
|
// Merge all the labels together (those who are true will remain true)
|
|
170
172
|
return allLabels.reduce((acc, x) => {
|
|
package/src/operations/summarize.ts
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
import { z } from '@bpinternal/zui'
|
|
3
3
|
|
|
4
4
|
import { chunk } from 'lodash-es'
|
|
5
|
+
import pLimit from 'p-limit'
|
|
5
6
|
import { ZaiContext } from '../context'
|
|
6
7
|
import { Response } from '../response'
|
|
7
8
|
|
|
@@ -115,9 +116,9 @@ ${newText}
|
|
|
115
116
|
const chunkSize = Math.ceil(tokens.length / (parts * N))
|
|
116
117
|
|
|
117
118
|
if (useMergeSort) {
|
|
118
|
-
|
|
119
|
+
const limit = pLimit(10) // Limit to 10 concurrent summarization operations
|
|
119
120
|
const chunks = chunk(tokens, chunkSize).map((x) => x.join(''))
|
|
120
|
-
const allSummaries = (await Promise.allSettled(chunks.map((chunk) => summarize(chunk, options, ctx))))
|
|
121
|
+
const allSummaries = (await Promise.allSettled(chunks.map((chunk) => limit(() => summarize(chunk, options, ctx)))))
|
|
121
122
|
.filter((x) => x.status === 'fulfilled')
|
|
122
123
|
.map((x) => x.value)
|
|
123
124
|
return summarize(allSummaries.join('\n\n============\n\n'), options, ctx)
|