@botpress/zai 2.0.14 → 2.0.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.d.ts +3 -2
- package/dist/operations/extract.js +72 -20
- package/e2e/data/cache.jsonl +157 -0
- package/e2e/utils.ts +3 -2
- package/package.json +1 -1
- package/src/operations/extract.ts +77 -24
package/e2e/utils.ts
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import { Client } from '@botpress/client'
|
|
2
|
+
import { Cognitive } from '@botpress/cognitive'
|
|
2
3
|
import { type TextTokenizer, getWasmTokenizer } from '@bpinternal/thicktoken'
|
|
3
4
|
import fs from 'node:fs'
|
|
4
5
|
import path from 'node:path'
|
|
@@ -21,8 +22,8 @@ export const getCachedClient = () => {
|
|
|
21
22
|
return getCachedCognitiveClient()
|
|
22
23
|
}
|
|
23
24
|
|
|
24
|
-
export const getZai = () => {
|
|
25
|
-
const client = getCachedClient()
|
|
25
|
+
export const getZai = (cognitive?: Cognitive) => {
|
|
26
|
+
const client = cognitive || getCachedClient()
|
|
26
27
|
return new Zai({ client })
|
|
27
28
|
}
|
|
28
29
|
|
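package/package.json
CHANGED
[diff body for package/package.json (+1 -1) was not captured in this extract]
package/src/operations/extract.ts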
CHANGED
|
@@ -15,6 +15,8 @@ export type Options = {
|
|
|
15
15
|
instructions?: string
|
|
16
16
|
/** The maximum number of tokens per chunk */
|
|
17
17
|
chunkLength?: number
|
|
18
|
+
/** Whether to strictly follow the schema or not */
|
|
19
|
+
strict?: boolean
|
|
18
20
|
}
|
|
19
21
|
|
|
20
22
|
const Options = z.object({
|
|
@@ -26,6 +28,7 @@ const Options = z.object({
|
|
|
26
28
|
.optional()
|
|
27
29
|
.describe('The maximum number of tokens per chunk')
|
|
28
30
|
.default(16_000),
|
|
31
|
+
strict: z.boolean().optional().default(true).describe('Whether to strictly follow the schema or not'),
|
|
29
32
|
})
|
|
30
33
|
|
|
31
34
|
type __Z<T extends any = any> = { _output: T }
|
|
@@ -35,7 +38,7 @@ type AnyObjectOrArray = Record<string, unknown> | Array<unknown>
|
|
|
35
38
|
declare module '@botpress/zai' {
|
|
36
39
|
interface Zai {
|
|
37
40
|
/** Extracts one or many elements from an arbitrary input */
|
|
38
|
-
extract<S extends OfType<AnyObjectOrArray>>(input: unknown, schema: S, options?: Options): Promise<S['_output']>
|
|
41
|
+
extract<S extends OfType<any>>(input: unknown, schema: S, options?: Options): Promise<S['_output']>
|
|
39
42
|
}
|
|
40
43
|
}
|
|
41
44
|
|
|
@@ -60,26 +63,37 @@ Zai.prototype.extract = async function <S extends OfType<AnyObjectOrArray>>(
|
|
|
60
63
|
const PROMPT_COMPONENT = Math.max(this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER, 100)
|
|
61
64
|
|
|
62
65
|
let isArrayOfObjects = false
|
|
66
|
+
let wrappedValue = false
|
|
63
67
|
const originalSchema = schema
|
|
64
68
|
|
|
65
69
|
const baseType = (schema.naked ? schema.naked() : schema)?.constructor?.name ?? 'unknown'
|
|
66
70
|
|
|
67
|
-
if (baseType === 'ZodObject') {
|
|
68
|
-
|
|
69
|
-
} else if (baseType === 'ZodArray') {
|
|
71
|
+
if (baseType === 'ZodArray') {
|
|
72
|
+
isArrayOfObjects = true
|
|
70
73
|
let elementType = (schema as any).element
|
|
71
74
|
if (elementType.naked) {
|
|
72
75
|
elementType = elementType.naked()
|
|
73
76
|
}
|
|
74
77
|
|
|
75
78
|
if (elementType?.constructor?.name === 'ZodObject') {
|
|
76
|
-
isArrayOfObjects = true
|
|
77
79
|
schema = elementType
|
|
78
80
|
} else {
|
|
79
|
-
|
|
81
|
+
wrappedValue = true
|
|
82
|
+
schema = z.object({
|
|
83
|
+
value: elementType,
|
|
84
|
+
})
|
|
80
85
|
}
|
|
81
|
-
} else {
|
|
82
|
-
|
|
86
|
+
} else if (baseType !== 'ZodObject') {
|
|
87
|
+
wrappedValue = true
|
|
88
|
+
schema = z.object({
|
|
89
|
+
value: originalSchema,
|
|
90
|
+
})
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
if (!options.strict) {
|
|
94
|
+
try {
|
|
95
|
+
schema = (schema as ZodObject).partial()
|
|
96
|
+
} catch {}
|
|
83
97
|
}
|
|
84
98
|
|
|
85
99
|
const schemaTypescript = schema.toTypescriptType({ declaration: false })
|
|
@@ -92,20 +106,38 @@ Zai.prototype.extract = async function <S extends OfType<AnyObjectOrArray>>(
|
|
|
92
106
|
|
|
93
107
|
const keys = Object.keys((schema as ZodObject).shape)
|
|
94
108
|
|
|
95
|
-
|
|
109
|
+
const inputAsString = stringify(input)
|
|
96
110
|
|
|
97
111
|
if (tokenizer.count(inputAsString) > options.chunkLength) {
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
112
|
+
const tokens = tokenizer.split(inputAsString)
|
|
113
|
+
const chunks = chunk(tokens, options.chunkLength).map((x) => x.join(''))
|
|
114
|
+
const all = await Promise.allSettled(
|
|
115
|
+
chunks.map((chunk) =>
|
|
116
|
+
this.extract(chunk, originalSchema, {
|
|
117
|
+
...options,
|
|
118
|
+
strict: false, // We don't want to fail on strict mode for sub-chunks
|
|
119
|
+
})
|
|
120
|
+
)
|
|
121
|
+
).then((results) =>
|
|
122
|
+
results.filter((x) => x.status === 'fulfilled').map((x) => (x as PromiseFulfilledResult<S['_output']>).value)
|
|
123
|
+
)
|
|
103
124
|
|
|
104
|
-
|
|
105
|
-
}
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
125
|
+
// We run this function recursively until all chunks are merged into a single output
|
|
126
|
+
const rows = all.map((x, idx) => `<part-${idx + 1}>\n${stringify(x, true)}\n</part-${idx + 1}>`).join('\n')
|
|
127
|
+
return this.extract(
|
|
128
|
+
`
|
|
129
|
+
The result has been split into ${all.length} parts. Recursively merge the result into the final result.
|
|
130
|
+
When merging arrays, take unique values.
|
|
131
|
+
When merging conflictual (but defined) information, take the most reasonable and frequent value.
|
|
132
|
+
Non-defined values are OK and normal. Don't delete fields because of null values. Focus on defined values.
|
|
133
|
+
|
|
134
|
+
Here's the data:
|
|
135
|
+
${rows}
|
|
136
|
+
|
|
137
|
+
Merge it back into a final result.`.trim(),
|
|
138
|
+
originalSchema,
|
|
139
|
+
options
|
|
140
|
+
)
|
|
109
141
|
}
|
|
110
142
|
|
|
111
143
|
const instructions: string[] = []
|
|
@@ -131,6 +163,10 @@ Zai.prototype.extract = async function <S extends OfType<AnyObjectOrArray>>(
|
|
|
131
163
|
instructions.push(`The element must be a JSON object with exactly the format: ${START}${shape}${END}`)
|
|
132
164
|
}
|
|
133
165
|
|
|
166
|
+
if (!options.strict) {
|
|
167
|
+
instructions.push('You may ignore any fields that are not present in the input. All keys are optional.')
|
|
168
|
+
}
|
|
169
|
+
|
|
134
170
|
// All tokens remaining after the input and condition are accounted can be used for examples
|
|
135
171
|
const EXAMPLES_TOKENS = PROMPT_COMPONENT - tokenizer.count(inputAsString) - tokenizer.count(instructions.join('\n'))
|
|
136
172
|
|
|
@@ -270,18 +306,27 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
|
|
|
270
306
|
],
|
|
271
307
|
})
|
|
272
308
|
|
|
273
|
-
const answer = output.choices[0]?.content as string
|
|
309
|
+
const answer = (output.choices[0]?.content ?? '{}') as string
|
|
274
310
|
|
|
275
311
|
const elements = answer
|
|
276
|
-
|
|
277
|
-
.filter((x) => x.trim().length > 0)
|
|
312
|
+
?.split(START)
|
|
313
|
+
.filter((x) => x.trim().length > 0 && x.includes('}'))
|
|
278
314
|
.map((x) => {
|
|
279
315
|
try {
|
|
280
316
|
const json = x.slice(0, x.indexOf(END)).trim()
|
|
281
317
|
const repairedJson = jsonrepair(json)
|
|
282
318
|
const parsedJson = JSON5.parse(repairedJson)
|
|
319
|
+
const safe = schema.safeParse(parsedJson)
|
|
320
|
+
|
|
321
|
+
if (safe.success) {
|
|
322
|
+
return safe.data
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
if (options.strict) {
|
|
326
|
+
throw new JsonParsingError(x, safe.error)
|
|
327
|
+
}
|
|
283
328
|
|
|
284
|
-
return
|
|
329
|
+
return parsedJson
|
|
285
330
|
} catch (error) {
|
|
286
331
|
throw new JsonParsingError(x, error instanceof Error ? error : new Error('Unknown error'))
|
|
287
332
|
}
|
|
@@ -293,11 +338,19 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
|
|
|
293
338
|
if (isArrayOfObjects) {
|
|
294
339
|
final = elements
|
|
295
340
|
} else if (elements.length === 0) {
|
|
296
|
-
final = schema.parse({})
|
|
341
|
+
final = options.strict ? schema.parse({}) : {}
|
|
297
342
|
} else {
|
|
298
343
|
final = elements[0]
|
|
299
344
|
}
|
|
300
345
|
|
|
346
|
+
if (wrappedValue) {
|
|
347
|
+
if (Array.isArray(final)) {
|
|
348
|
+
final = final.map((x) => ('value' in x ? x.value : x))
|
|
349
|
+
} else {
|
|
350
|
+
final = 'value' in final ? final.value : final
|
|
351
|
+
}
|
|
352
|
+
}
|
|
353
|
+
|
|
301
354
|
if (taskId) {
|
|
302
355
|
await this.adapter.saveExample({
|
|
303
356
|
key: Key,
|