@botpress/zai 2.0.14 → 2.0.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/e2e/utils.ts CHANGED
@@ -1,4 +1,5 @@
1
1
  import { Client } from '@botpress/client'
2
+ import { Cognitive } from '@botpress/cognitive'
2
3
  import { type TextTokenizer, getWasmTokenizer } from '@bpinternal/thicktoken'
3
4
  import fs from 'node:fs'
4
5
  import path from 'node:path'
@@ -21,8 +22,8 @@ export const getCachedClient = () => {
21
22
  return getCachedCognitiveClient()
22
23
  }
23
24
 
24
- export const getZai = () => {
25
- const client = getCachedClient()
25
+ export const getZai = (cognitive?: Cognitive) => {
26
+ const client = cognitive || getCachedClient()
26
27
  return new Zai({ client })
27
28
  }
28
29
 
package/package.json CHANGED
@@ -1,7 +1,7 @@
1
1
  {
2
2
  "name": "@botpress/zai",
3
3
  "description": "Zui AI (zai) – An LLM utility library written on top of Zui and the Botpress API",
4
- "version": "2.0.14",
4
+ "version": "2.0.16",
5
5
  "main": "./dist/index.js",
6
6
  "types": "./dist/index.d.ts",
7
7
  "exports": {
@@ -15,6 +15,8 @@ export type Options = {
15
15
  instructions?: string
16
16
  /** The maximum number of tokens per chunk */
17
17
  chunkLength?: number
18
+ /** Whether to strictly follow the schema or not */
19
+ strict?: boolean
18
20
  }
19
21
 
20
22
  const Options = z.object({
@@ -26,6 +28,7 @@ const Options = z.object({
26
28
  .optional()
27
29
  .describe('The maximum number of tokens per chunk')
28
30
  .default(16_000),
31
+ strict: z.boolean().optional().default(true).describe('Whether to strictly follow the schema or not'),
29
32
  })
30
33
 
31
34
  type __Z<T extends any = any> = { _output: T }
@@ -35,7 +38,7 @@ type AnyObjectOrArray = Record<string, unknown> | Array<unknown>
35
38
  declare module '@botpress/zai' {
36
39
  interface Zai {
37
40
  /** Extracts one or many elements from an arbitrary input */
38
- extract<S extends OfType<AnyObjectOrArray>>(input: unknown, schema: S, options?: Options): Promise<S['_output']>
41
+ extract<S extends OfType<any>>(input: unknown, schema: S, options?: Options): Promise<S['_output']>
39
42
  }
40
43
  }
41
44
 
@@ -60,26 +63,37 @@ Zai.prototype.extract = async function <S extends OfType<AnyObjectOrArray>>(
60
63
  const PROMPT_COMPONENT = Math.max(this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER, 100)
61
64
 
62
65
  let isArrayOfObjects = false
66
+ let wrappedValue = false
63
67
  const originalSchema = schema
64
68
 
65
69
  const baseType = (schema.naked ? schema.naked() : schema)?.constructor?.name ?? 'unknown'
66
70
 
67
- if (baseType === 'ZodObject') {
68
- // Do nothing
69
- } else if (baseType === 'ZodArray') {
71
+ if (baseType === 'ZodArray') {
72
+ isArrayOfObjects = true
70
73
  let elementType = (schema as any).element
71
74
  if (elementType.naked) {
72
75
  elementType = elementType.naked()
73
76
  }
74
77
 
75
78
  if (elementType?.constructor?.name === 'ZodObject') {
76
- isArrayOfObjects = true
77
79
  schema = elementType
78
80
  } else {
79
- throw new Error('Schema must be a ZodObject or a ZodArray<ZodObject>')
81
+ wrappedValue = true
82
+ schema = z.object({
83
+ value: elementType,
84
+ })
80
85
  }
81
- } else {
82
- throw new Error('Schema must be either a ZuiObject or a ZuiArray<ZuiObject>')
86
+ } else if (baseType !== 'ZodObject') {
87
+ wrappedValue = true
88
+ schema = z.object({
89
+ value: originalSchema,
90
+ })
91
+ }
92
+
93
+ if (!options.strict) {
94
+ try {
95
+ schema = (schema as ZodObject).partial()
96
+ } catch {}
83
97
  }
84
98
 
85
99
  const schemaTypescript = schema.toTypescriptType({ declaration: false })
@@ -92,20 +106,38 @@ Zai.prototype.extract = async function <S extends OfType<AnyObjectOrArray>>(
92
106
 
93
107
  const keys = Object.keys((schema as ZodObject).shape)
94
108
 
95
- let inputAsString = stringify(input)
109
+ const inputAsString = stringify(input)
96
110
 
97
111
  if (tokenizer.count(inputAsString) > options.chunkLength) {
98
- // If we want to extract an array of objects, we will run this function recursively
99
- if (isArrayOfObjects) {
100
- const tokens = tokenizer.split(inputAsString)
101
- const chunks = chunk(tokens, options.chunkLength).map((x) => x.join(''))
102
- const all = await Promise.all(chunks.map((chunk) => this.extract(chunk, originalSchema)))
112
+ const tokens = tokenizer.split(inputAsString)
113
+ const chunks = chunk(tokens, options.chunkLength).map((x) => x.join(''))
114
+ const all = await Promise.allSettled(
115
+ chunks.map((chunk) =>
116
+ this.extract(chunk, originalSchema, {
117
+ ...options,
118
+ strict: false, // We don't want to fail on strict mode for sub-chunks
119
+ })
120
+ )
121
+ ).then((results) =>
122
+ results.filter((x) => x.status === 'fulfilled').map((x) => (x as PromiseFulfilledResult<S['_output']>).value)
123
+ )
103
124
 
104
- return all.flat() as any as S['_output']
105
- } else {
106
- // Truncate the input to fit the model's input size
107
- inputAsString = tokenizer.truncate(stringify(input), options.chunkLength)
108
- }
125
+ // We run this function recursively until all chunks are merged into a single output
126
+ const rows = all.map((x, idx) => `<part-${idx + 1}>\n${stringify(x, true)}\n</part-${idx + 1}>`).join('\n')
127
+ return this.extract(
128
+ `
129
+ The result has been split into ${all.length} parts. Recursively merge the result into the final result.
130
+ When merging arrays, take unique values.
131
+ When merging conflictual (but defined) information, take the most reasonable and frequent value.
132
+ Non-defined values are OK and normal. Don't delete fields because of null values. Focus on defined values.
133
+
134
+ Here's the data:
135
+ ${rows}
136
+
137
+ Merge it back into a final result.`.trim(),
138
+ originalSchema,
139
+ options
140
+ )
109
141
  }
110
142
 
111
143
  const instructions: string[] = []
@@ -131,6 +163,10 @@ Zai.prototype.extract = async function <S extends OfType<AnyObjectOrArray>>(
131
163
  instructions.push(`The element must be a JSON object with exactly the format: ${START}${shape}${END}`)
132
164
  }
133
165
 
166
+ if (!options.strict) {
167
+ instructions.push('You may ignore any fields that are not present in the input. All keys are optional.')
168
+ }
169
+
134
170
  // All tokens remaining after the input and condition are accounted can be used for examples
135
171
  const EXAMPLES_TOKENS = PROMPT_COMPONENT - tokenizer.count(inputAsString) - tokenizer.count(instructions.join('\n'))
136
172
 
@@ -270,18 +306,27 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
270
306
  ],
271
307
  })
272
308
 
273
- const answer = output.choices[0]?.content as string
309
+ const answer = (output.choices[0]?.content ?? '{}') as string
274
310
 
275
311
  const elements = answer
276
- .split(START)
277
- .filter((x) => x.trim().length > 0)
312
+ ?.split(START)
313
+ .filter((x) => x.trim().length > 0 && x.includes('}'))
278
314
  .map((x) => {
279
315
  try {
280
316
  const json = x.slice(0, x.indexOf(END)).trim()
281
317
  const repairedJson = jsonrepair(json)
282
318
  const parsedJson = JSON5.parse(repairedJson)
319
+ const safe = schema.safeParse(parsedJson)
320
+
321
+ if (safe.success) {
322
+ return safe.data
323
+ }
324
+
325
+ if (options.strict) {
326
+ throw new JsonParsingError(x, safe.error)
327
+ }
283
328
 
284
- return schema.parse(parsedJson)
329
+ return parsedJson
285
330
  } catch (error) {
286
331
  throw new JsonParsingError(x, error instanceof Error ? error : new Error('Unknown error'))
287
332
  }
@@ -293,11 +338,19 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
293
338
  if (isArrayOfObjects) {
294
339
  final = elements
295
340
  } else if (elements.length === 0) {
296
- final = schema.parse({})
341
+ final = options.strict ? schema.parse({}) : {}
297
342
  } else {
298
343
  final = elements[0]
299
344
  }
300
345
 
346
+ if (wrappedValue) {
347
+ if (Array.isArray(final)) {
348
+ final = final.map((x) => ('value' in x ? x.value : x))
349
+ } else {
350
+ final = 'value' in final ? final.value : final
351
+ }
352
+ }
353
+
301
354
  if (taskId) {
302
355
  await this.adapter.saveExample({
303
356
  key: Key,