@botpress/zai 2.0.15 → 2.1.0

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
@@ -5,6 +5,9 @@ import JSON5 from 'json5'
5
5
  import { jsonrepair } from 'jsonrepair'
6
6
 
7
7
  import { chunk, isArray } from 'lodash-es'
8
+ import { ZaiContext } from '../context'
9
+ import { Response } from '../response'
10
+ import { getTokenizer } from '../tokenizer'
8
11
  import { fastHash, stringify, takeUntilTokens } from '../utils'
9
12
  import { Zai } from '../zai'
10
13
  import { PROMPT_INPUT_BUFFER } from './constants'
@@ -15,6 +18,8 @@ export type Options = {
15
18
  instructions?: string
16
19
  /** The maximum number of tokens per chunk */
17
20
  chunkLength?: number
21
+ /** Whether to strictly follow the schema or not */
22
+ strict?: boolean
18
23
  }
19
24
 
20
25
  const Options = z.object({
@@ -26,6 +31,7 @@ const Options = z.object({
26
31
  .optional()
27
32
  .describe('The maximum number of tokens per chunk')
28
33
  .default(16_000),
34
+ strict: z.boolean().optional().default(true).describe('Whether to strictly follow the schema or not'),
29
35
  })
30
36
 
31
37
  type __Z<T extends any = any> = { _output: T }
@@ -35,7 +41,7 @@ type AnyObjectOrArray = Record<string, unknown> | Array<unknown>
35
41
  declare module '@botpress/zai' {
36
42
  interface Zai {
37
43
  /** Extracts one or many elements from an arbitrary input */
38
- extract<S extends OfType<AnyObjectOrArray>>(input: unknown, schema: S, options?: Options): Promise<S['_output']>
44
+ extract<S extends OfType<any>>(input: unknown, schema: S, options?: Options): Response<S['_output']>
39
45
  }
40
46
  }
41
47
 
@@ -43,52 +49,61 @@ const START = '■json_start■'
43
49
  const END = '■json_end■'
44
50
  const NO_MORE = '■NO_MORE_ELEMENT■'
45
51
 
46
- Zai.prototype.extract = async function <S extends OfType<AnyObjectOrArray>>(
47
- this: Zai,
52
+ const extract = async <S extends OfType<AnyObjectOrArray>>(
48
53
  input: unknown,
49
54
  _schema: S,
50
- _options?: Options
51
- ): Promise<S['_output']> {
55
+ _options: Options | undefined,
56
+ ctx: ZaiContext
57
+ ): Promise<S['_output']> => {
58
+ ctx.controller.signal.throwIfAborted()
52
59
  let schema = _schema as any as z.ZodType
53
60
  const options = Options.parse(_options ?? {})
54
- const tokenizer = await this.getTokenizer()
55
- await this.fetchModelDetails()
61
+ const tokenizer = await getTokenizer()
62
+ const model = await ctx.getModel()
56
63
 
57
- const taskId = this.taskId
64
+ const taskId = ctx.taskId
58
65
  const taskType = 'zai.extract'
59
66
 
60
- const PROMPT_COMPONENT = Math.max(this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER, 100)
67
+ const PROMPT_COMPONENT = Math.max(model.input.maxTokens - PROMPT_INPUT_BUFFER, 100)
61
68
 
62
69
  let isArrayOfObjects = false
70
+ let wrappedValue = false
63
71
  const originalSchema = schema
64
72
 
65
73
  const baseType = (schema.naked ? schema.naked() : schema)?.constructor?.name ?? 'unknown'
66
74
 
67
- if (baseType === 'ZodObject') {
68
- // Do nothing
69
- } else if (baseType === 'ZodArray') {
75
+ if (baseType === 'ZodArray') {
76
+ isArrayOfObjects = true
70
77
  let elementType = (schema as any).element
71
78
  if (elementType.naked) {
72
79
  elementType = elementType.naked()
73
80
  }
74
81
 
75
82
  if (elementType?.constructor?.name === 'ZodObject') {
76
- isArrayOfObjects = true
77
83
  schema = elementType
78
84
  } else {
79
- throw new Error('Schema must be a ZodObject or a ZodArray<ZodObject>')
85
+ wrappedValue = true
86
+ schema = z.object({
87
+ value: elementType,
88
+ })
80
89
  }
81
- } else {
82
- throw new Error('Schema must be either a ZuiObject or a ZuiArray<ZuiObject>')
90
+ } else if (baseType !== 'ZodObject') {
91
+ wrappedValue = true
92
+ schema = z.object({
93
+ value: originalSchema,
94
+ })
95
+ }
96
+
97
+ if (!options.strict) {
98
+ try {
99
+ schema = (schema as ZodObject).partial()
100
+ } catch {}
83
101
  }
84
102
 
85
103
  const schemaTypescript = schema.toTypescriptType({ declaration: false })
86
104
  const schemaLength = tokenizer.count(schemaTypescript)
87
105
 
88
- options.chunkLength = Math.min(
89
- options.chunkLength,
90
- this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER - schemaLength
91
- )
106
+ options.chunkLength = Math.min(options.chunkLength, model.input.maxTokens - PROMPT_INPUT_BUFFER - schemaLength)
92
107
 
93
108
  const keys = Object.keys((schema as ZodObject).shape)
94
109
 
@@ -97,10 +112,41 @@ Zai.prototype.extract = async function <S extends OfType<AnyObjectOrArray>>(
97
112
  if (tokenizer.count(inputAsString) > options.chunkLength) {
98
113
  const tokens = tokenizer.split(inputAsString)
99
114
  const chunks = chunk(tokens, options.chunkLength).map((x) => x.join(''))
100
- const all = await Promise.all(chunks.map((chunk) => this.extract(chunk, originalSchema)))
115
+ const all = await Promise.allSettled(
116
+ chunks.map((chunk) =>
117
+ extract(
118
+ chunk,
119
+ originalSchema,
120
+ {
121
+ ...options,
122
+ strict: false, // We don't want to fail on strict mode for sub-chunks
123
+ },
124
+ ctx
125
+ )
126
+ )
127
+ ).then((results) =>
128
+ results.filter((x) => x.status === 'fulfilled').map((x) => (x as PromiseFulfilledResult<S['_output']>).value)
129
+ )
130
+
131
+ ctx.controller.signal.throwIfAborted()
101
132
 
102
133
  // We run this function recursively until all chunks are merged into a single output
103
- return this.extract(all, originalSchema, options)
134
+ const rows = all.map((x, idx) => `<part-${idx + 1}>\n${stringify(x, true)}\n</part-${idx + 1}>`).join('\n')
135
+ return extract(
136
+ `
137
+ The result has been split into ${all.length} parts. Recursively merge the result into the final result.
138
+ When merging arrays, take unique values.
139
+ When merging conflictual (but defined) information, take the most reasonable and frequent value.
140
+ Non-defined values are OK and normal. Don't delete fields because of null values. Focus on defined values.
141
+
142
+ Here's the data:
143
+ ${rows}
144
+
145
+ Merge it back into a final result.`.trim(),
146
+ originalSchema,
147
+ options,
148
+ ctx
149
+ )
104
150
  }
105
151
 
106
152
  const instructions: string[] = []
@@ -126,6 +172,10 @@ Zai.prototype.extract = async function <S extends OfType<AnyObjectOrArray>>(
126
172
  instructions.push(`The element must be a JSON object with exactly the format: ${START}${shape}${END}`)
127
173
  }
128
174
 
175
+ if (!options.strict) {
176
+ instructions.push('You may ignore any fields that are not present in the input. All keys are optional.')
177
+ }
178
+
129
179
  // All tokens remaining after the input and condition are accounted can be used for examples
130
180
  const EXAMPLES_TOKENS = PROMPT_COMPONENT - tokenizer.count(inputAsString) - tokenizer.count(instructions.join('\n'))
131
181
 
@@ -138,13 +188,14 @@ Zai.prototype.extract = async function <S extends OfType<AnyObjectOrArray>>(
138
188
  })
139
189
  )
140
190
 
141
- const examples = taskId
142
- ? await this.adapter.getExamples<string, unknown>({
143
- input: inputAsString,
144
- taskType,
145
- taskId,
146
- })
147
- : []
191
+ const examples =
192
+ taskId && ctx.adapter
193
+ ? await ctx.adapter.getExamples<string, unknown>({
194
+ input: inputAsString,
195
+ taskType,
196
+ taskId,
197
+ })
198
+ : []
148
199
 
149
200
  const exactMatch = examples.find((x) => x.key === Key)
150
201
  if (exactMatch) {
@@ -246,7 +297,7 @@ ${END}`.trim()
246
297
  .map(formatExample)
247
298
  .flat()
248
299
 
249
- const { output, meta } = await this.callModel({
300
+ const { meta, extracted } = await ctx.generateContent({
250
301
  systemPrompt: `
251
302
  Extract the following information from the input:
252
303
  ${schemaTypescript}
@@ -263,38 +314,53 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
263
314
  content: formatInput(inputAsString, schemaTypescript, options.instructions ?? ''),
264
315
  },
265
316
  ],
317
+ transform: (text) =>
318
+ (text || '{}')
319
+ ?.split(START)
320
+ .filter((x) => x.trim().length > 0 && x.includes('}'))
321
+ .map((x) => {
322
+ try {
323
+ const json = x.slice(0, x.indexOf(END)).trim()
324
+ const repairedJson = jsonrepair(json)
325
+ const parsedJson = JSON5.parse(repairedJson)
326
+ const safe = schema.safeParse(parsedJson)
327
+
328
+ if (safe.success) {
329
+ return safe.data
330
+ }
331
+
332
+ if (options.strict) {
333
+ throw new JsonParsingError(x, safe.error)
334
+ }
335
+
336
+ return parsedJson
337
+ } catch (error) {
338
+ throw new JsonParsingError(x, error instanceof Error ? error : new Error('Unknown error'))
339
+ }
340
+ })
341
+ .filter((x) => x !== null),
266
342
  })
267
343
 
268
- const answer = output.choices[0]?.content as string
269
-
270
- const elements = answer
271
- .split(START)
272
- .filter((x) => x.trim().length > 0)
273
- .map((x) => {
274
- try {
275
- const json = x.slice(0, x.indexOf(END)).trim()
276
- const repairedJson = jsonrepair(json)
277
- const parsedJson = JSON5.parse(repairedJson)
278
-
279
- return schema.parse(parsedJson)
280
- } catch (error) {
281
- throw new JsonParsingError(x, error instanceof Error ? error : new Error('Unknown error'))
282
- }
283
- })
284
- .filter((x) => x !== null)
285
-
286
344
  let final: any
287
345
 
288
346
  if (isArrayOfObjects) {
289
- final = elements
290
- } else if (elements.length === 0) {
291
- final = schema.parse({})
347
+ final = extracted
348
+ } else if (extracted.length === 0) {
349
+ final = options.strict ? schema.parse({}) : {}
292
350
  } else {
293
- final = elements[0]
351
+ final = extracted[0]
352
+ }
353
+
354
+ if (wrappedValue) {
355
+ if (Array.isArray(final)) {
356
+ final = final.map((x) => ('value' in x ? x.value : x))
357
+ } else {
358
+ final = 'value' in final ? final.value : final
359
+ }
294
360
  }
295
361
 
296
- if (taskId) {
297
- await this.adapter.saveExample({
362
+ if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
363
+ await ctx.adapter.saveExample({
298
364
  key: Key,
299
365
  taskId: `zai/${taskId}`,
300
366
  taskType,
@@ -307,7 +373,7 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
307
373
  output: meta.cost.output,
308
374
  },
309
375
  latency: meta.latency,
310
- model: this.Model,
376
+ model: ctx.modelId,
311
377
  tokens: {
312
378
  input: meta.tokens.input,
313
379
  output: meta.tokens.output,
@@ -318,3 +384,20 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
318
384
 
319
385
  return final
320
386
  }
387
+
388
+ Zai.prototype.extract = function <S extends OfType<AnyObjectOrArray>>(
389
+ this: Zai,
390
+ input: unknown,
391
+ schema: S,
392
+ _options?: Options
393
+ ): Response<S['_output']> {
394
+ const context = new ZaiContext({
395
+ client: this.client,
396
+ modelId: this.Model,
397
+ taskId: this.taskId,
398
+ taskType: 'zai.extract',
399
+ adapter: this.adapter,
400
+ })
401
+
402
+ return new Response<S['_output']>(context, extract(input, schema, _options, context), (result) => result)
403
+ }
@@ -2,6 +2,9 @@
2
2
  import { z } from '@bpinternal/zui'
3
3
 
4
4
  import { clamp } from 'lodash-es'
5
+ import { ZaiContext } from '../context'
6
+ import { Response } from '../response'
7
+ import { getTokenizer } from '../tokenizer'
5
8
  import { fastHash, stringify, takeUntilTokens } from '../utils'
6
9
  import { Zai } from '../zai'
7
10
  import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from './constants'
@@ -39,22 +42,28 @@ const _Options = z.object({
39
42
  declare module '@botpress/zai' {
40
43
  interface Zai {
41
44
  /** Filters elements of an array against a condition */
42
- filter<T>(input: Array<T>, condition: string, options?: Options): Promise<Array<T>>
45
+ filter<T>(input: Array<T>, condition: string, options?: Options): Response<Array<T>>
43
46
  }
44
47
  }
45
48
 
46
49
  const END = '■END■'
47
50
 
48
- Zai.prototype.filter = async function (this: Zai, input, condition, _options) {
51
+ const filter = async <T>(
52
+ input: Array<T>,
53
+ condition: string,
54
+ _options: Options | undefined,
55
+ ctx: ZaiContext
56
+ ): Promise<Array<T>> => {
57
+ ctx.controller.signal.throwIfAborted()
49
58
  const options = _Options.parse(_options ?? {}) as Options
50
- const tokenizer = await this.getTokenizer()
51
- await this.fetchModelDetails()
59
+ const tokenizer = await getTokenizer()
60
+ const model = await ctx.getModel()
52
61
 
53
- const taskId = this.taskId
62
+ const taskId = ctx.taskId
54
63
  const taskType = 'zai.filter'
55
64
 
56
65
  const MAX_ITEMS_PER_CHUNK = 50
57
- const TOKENS_TOTAL_MAX = this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER
66
+ const TOKENS_TOTAL_MAX = model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER
58
67
  const TOKENS_EXAMPLES_MAX = Math.floor(Math.max(250, TOKENS_TOTAL_MAX * 0.5))
59
68
  const TOKENS_CONDITION_MAX = clamp(TOKENS_TOTAL_MAX * 0.25, 250, tokenizer.count(condition))
60
69
  const TOKENS_INPUT_ARRAY_MAX = TOKENS_TOTAL_MAX - TOKENS_EXAMPLES_MAX - TOKENS_CONDITION_MAX
@@ -145,18 +154,19 @@ ${examples.map((x, idx) => `■${idx}:${!!x.filter ? 'true' : 'false'}:${x.reaso
145
154
  ]
146
155
 
147
156
  const filterChunk = async (chunk: typeof input) => {
148
- const examples = taskId
149
- ? await this.adapter
150
- .getExamples<string, unknown>({
151
- // The Table API can't search for a huge input string
152
- input: JSON.stringify(chunk).slice(0, 1000),
153
- taskType,
154
- taskId,
155
- })
156
- .then((x) =>
157
- x.map((y) => ({ filter: y.output as boolean, input: y.input, reason: y.explanation }) satisfies Example)
158
- )
159
- : []
157
+ const examples =
158
+ taskId && ctx.adapter
159
+ ? await ctx.adapter
160
+ .getExamples<string, unknown>({
161
+ // The Table API can't search for a huge input string
162
+ input: JSON.stringify(chunk).slice(0, 1000),
163
+ taskType,
164
+ taskId,
165
+ })
166
+ .then((x) =>
167
+ x.map((y) => ({ filter: y.output as boolean, input: y.input, reason: y.explanation }) satisfies Example)
168
+ )
169
+ : []
160
170
 
161
171
  const allExamples = takeUntilTokens([...examples, ...(options.examples ?? [])], TOKENS_EXAMPLES_MAX, (el) =>
162
172
  tokenizer.count(stringify(el.input))
@@ -175,7 +185,7 @@ ${examples.map((x, idx) => `■${idx}:${!!x.filter ? 'true' : 'false'}:${x.reaso
175
185
  },
176
186
  ]
177
187
 
178
- const { output, meta } = await this.callModel({
188
+ const { extracted: partial, meta } = await ctx.generateContent({
179
189
  systemPrompt: `
180
190
  You are given a list of items. Your task is to filter out the items that meet the condition below.
181
191
  You need to return the full list of items with the format:
@@ -198,23 +208,23 @@ The condition is: "${condition}"
198
208
  role: 'user',
199
209
  },
200
210
  ],
201
- })
202
-
203
- const answer = output.choices[0]?.content as string
204
- const indices = answer
205
- .trim()
206
- .split('■')
207
- .filter((x) => x.length > 0)
208
- .map((x) => {
209
- const [idx, filter] = x.split(':')
210
- return { idx: parseInt(idx?.trim() ?? ''), filter: filter?.toLowerCase().trim() === 'true' }
211
- })
211
+ transform: (text) => {
212
+ const indices = text
213
+ .trim()
214
+ .split('■')
215
+ .filter((x) => x.length > 0)
216
+ .map((x) => {
217
+ const [idx, filter] = x.split(':')
218
+ return { idx: parseInt(idx?.trim() ?? ''), filter: filter?.toLowerCase().trim() === 'true' }
219
+ })
212
220
 
213
- const partial = chunk.filter((_, idx) => {
214
- return indices.find((x) => x.idx === idx)?.filter ?? false
221
+ return chunk.filter((_, idx) => {
222
+ return indices.find((x) => x.idx === idx && x.filter) ?? false
223
+ })
224
+ },
215
225
  })
216
226
 
217
- if (taskId) {
227
+ if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
218
228
  const key = fastHash(
219
229
  stringify({
220
230
  taskId,
@@ -224,7 +234,7 @@ The condition is: "${condition}"
224
234
  })
225
235
  )
226
236
 
227
- await this.adapter.saveExample({
237
+ await ctx.adapter.saveExample({
228
238
  key,
229
239
  taskType,
230
240
  taskId,
@@ -237,7 +247,7 @@ The condition is: "${condition}"
237
247
  output: meta.cost.output,
238
248
  },
239
249
  latency: meta.latency,
240
- model: this.Model,
250
+ model: ctx.modelId,
241
251
  tokens: {
242
252
  input: meta.tokens.input,
243
253
  output: meta.tokens.output,
@@ -253,3 +263,20 @@ The condition is: "${condition}"
253
263
 
254
264
  return filteredChunks.flat()
255
265
  }
266
+
267
+ Zai.prototype.filter = function <T>(
268
+ this: Zai,
269
+ input: Array<T>,
270
+ condition: string,
271
+ _options?: Options
272
+ ): Response<Array<T>> {
273
+ const context = new ZaiContext({
274
+ client: this.client,
275
+ modelId: this.Model,
276
+ taskId: this.taskId,
277
+ taskType: 'zai.filter',
278
+ adapter: this.adapter,
279
+ })
280
+
281
+ return new Response<Array<T>>(context, filter(input, condition, _options, context), (result) => result)
282
+ }