@botpress/zai 2.0.16 → 2.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,6 +5,9 @@ import JSON5 from 'json5'
5
5
  import { jsonrepair } from 'jsonrepair'
6
6
 
7
7
  import { chunk, isArray } from 'lodash-es'
8
+ import { ZaiContext } from '../context'
9
+ import { Response } from '../response'
10
+ import { getTokenizer } from '../tokenizer'
8
11
  import { fastHash, stringify, takeUntilTokens } from '../utils'
9
12
  import { Zai } from '../zai'
10
13
  import { PROMPT_INPUT_BUFFER } from './constants'
@@ -38,7 +41,7 @@ type AnyObjectOrArray = Record<string, unknown> | Array<unknown>
38
41
  declare module '@botpress/zai' {
39
42
  interface Zai {
40
43
  /** Extracts one or many elements from an arbitrary input */
41
- extract<S extends OfType<any>>(input: unknown, schema: S, options?: Options): Promise<S['_output']>
44
+ extract<S extends OfType<any>>(input: unknown, schema: S, options?: Options): Response<S['_output']>
42
45
  }
43
46
  }
44
47
 
@@ -46,21 +49,22 @@ const START = '■json_start■'
46
49
  const END = '■json_end■'
47
50
  const NO_MORE = '■NO_MORE_ELEMENT■'
48
51
 
49
- Zai.prototype.extract = async function <S extends OfType<AnyObjectOrArray>>(
50
- this: Zai,
52
+ const extract = async <S extends OfType<AnyObjectOrArray>>(
51
53
  input: unknown,
52
54
  _schema: S,
53
- _options?: Options
54
- ): Promise<S['_output']> {
55
+ _options: Options | undefined,
56
+ ctx: ZaiContext
57
+ ): Promise<S['_output']> => {
58
+ ctx.controller.signal.throwIfAborted()
55
59
  let schema = _schema as any as z.ZodType
56
60
  const options = Options.parse(_options ?? {})
57
- const tokenizer = await this.getTokenizer()
58
- await this.fetchModelDetails()
61
+ const tokenizer = await getTokenizer()
62
+ const model = await ctx.getModel()
59
63
 
60
- const taskId = this.taskId
64
+ const taskId = ctx.taskId
61
65
  const taskType = 'zai.extract'
62
66
 
63
- const PROMPT_COMPONENT = Math.max(this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER, 100)
67
+ const PROMPT_COMPONENT = Math.max(model.input.maxTokens - PROMPT_INPUT_BUFFER, 100)
64
68
 
65
69
  let isArrayOfObjects = false
66
70
  let wrappedValue = false
@@ -99,10 +103,7 @@ Zai.prototype.extract = async function <S extends OfType<AnyObjectOrArray>>(
99
103
  const schemaTypescript = schema.toTypescriptType({ declaration: false })
100
104
  const schemaLength = tokenizer.count(schemaTypescript)
101
105
 
102
- options.chunkLength = Math.min(
103
- options.chunkLength,
104
- this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER - schemaLength
105
- )
106
+ options.chunkLength = Math.min(options.chunkLength, model.input.maxTokens - PROMPT_INPUT_BUFFER - schemaLength)
106
107
 
107
108
  const keys = Object.keys((schema as ZodObject).shape)
108
109
 
@@ -113,18 +114,25 @@ Zai.prototype.extract = async function <S extends OfType<AnyObjectOrArray>>(
113
114
  const chunks = chunk(tokens, options.chunkLength).map((x) => x.join(''))
114
115
  const all = await Promise.allSettled(
115
116
  chunks.map((chunk) =>
116
- this.extract(chunk, originalSchema, {
117
- ...options,
118
- strict: false, // We don't want to fail on strict mode for sub-chunks
119
- })
117
+ extract(
118
+ chunk,
119
+ originalSchema,
120
+ {
121
+ ...options,
122
+ strict: false, // We don't want to fail on strict mode for sub-chunks
123
+ },
124
+ ctx
125
+ )
120
126
  )
121
127
  ).then((results) =>
122
128
  results.filter((x) => x.status === 'fulfilled').map((x) => (x as PromiseFulfilledResult<S['_output']>).value)
123
129
  )
124
130
 
131
+ ctx.controller.signal.throwIfAborted()
132
+
125
133
  // We run this function recursively until all chunks are merged into a single output
126
134
  const rows = all.map((x, idx) => `<part-${idx + 1}>\n${stringify(x, true)}\n</part-${idx + 1}>`).join('\n')
127
- return this.extract(
135
+ return extract(
128
136
  `
129
137
  The result has been split into ${all.length} parts. Recursively merge the result into the final result.
130
138
  When merging arrays, take unique values.
@@ -136,7 +144,8 @@ ${rows}
136
144
 
137
145
  Merge it back into a final result.`.trim(),
138
146
  originalSchema,
139
- options
147
+ options,
148
+ ctx
140
149
  )
141
150
  }
142
151
 
@@ -179,13 +188,14 @@ Merge it back into a final result.`.trim(),
179
188
  })
180
189
  )
181
190
 
182
- const examples = taskId
183
- ? await this.adapter.getExamples<string, unknown>({
184
- input: inputAsString,
185
- taskType,
186
- taskId,
187
- })
188
- : []
191
+ const examples =
192
+ taskId && ctx.adapter
193
+ ? await ctx.adapter.getExamples<string, unknown>({
194
+ input: inputAsString,
195
+ taskType,
196
+ taskId,
197
+ })
198
+ : []
189
199
 
190
200
  const exactMatch = examples.find((x) => x.key === Key)
191
201
  if (exactMatch) {
@@ -287,7 +297,7 @@ ${END}`.trim()
287
297
  .map(formatExample)
288
298
  .flat()
289
299
 
290
- const { output, meta } = await this.callModel({
300
+ const { meta, extracted } = await ctx.generateContent({
291
301
  systemPrompt: `
292
302
  Extract the following information from the input:
293
303
  ${schemaTypescript}
@@ -304,43 +314,41 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
304
314
  content: formatInput(inputAsString, schemaTypescript, options.instructions ?? ''),
305
315
  },
306
316
  ],
317
+ transform: (text) =>
318
+ (text || '{}')
319
+ ?.split(START)
320
+ .filter((x) => x.trim().length > 0 && x.includes('}'))
321
+ .map((x) => {
322
+ try {
323
+ const json = x.slice(0, x.indexOf(END)).trim()
324
+ const repairedJson = jsonrepair(json)
325
+ const parsedJson = JSON5.parse(repairedJson)
326
+ const safe = schema.safeParse(parsedJson)
327
+
328
+ if (safe.success) {
329
+ return safe.data
330
+ }
331
+
332
+ if (options.strict) {
333
+ throw new JsonParsingError(x, safe.error)
334
+ }
335
+
336
+ return parsedJson
337
+ } catch (error) {
338
+ throw new JsonParsingError(x, error instanceof Error ? error : new Error('Unknown error'))
339
+ }
340
+ })
341
+ .filter((x) => x !== null),
307
342
  })
308
343
 
309
- const answer = (output.choices[0]?.content ?? '{}') as string
310
-
311
- const elements = answer
312
- ?.split(START)
313
- .filter((x) => x.trim().length > 0 && x.includes('}'))
314
- .map((x) => {
315
- try {
316
- const json = x.slice(0, x.indexOf(END)).trim()
317
- const repairedJson = jsonrepair(json)
318
- const parsedJson = JSON5.parse(repairedJson)
319
- const safe = schema.safeParse(parsedJson)
320
-
321
- if (safe.success) {
322
- return safe.data
323
- }
324
-
325
- if (options.strict) {
326
- throw new JsonParsingError(x, safe.error)
327
- }
328
-
329
- return parsedJson
330
- } catch (error) {
331
- throw new JsonParsingError(x, error instanceof Error ? error : new Error('Unknown error'))
332
- }
333
- })
334
- .filter((x) => x !== null)
335
-
336
344
  let final: any
337
345
 
338
346
  if (isArrayOfObjects) {
339
- final = elements
340
- } else if (elements.length === 0) {
347
+ final = extracted
348
+ } else if (extracted.length === 0) {
341
349
  final = options.strict ? schema.parse({}) : {}
342
350
  } else {
343
- final = elements[0]
351
+ final = extracted[0]
344
352
  }
345
353
 
346
354
  if (wrappedValue) {
@@ -351,8 +359,8 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
351
359
  }
352
360
  }
353
361
 
354
- if (taskId) {
355
- await this.adapter.saveExample({
362
+ if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
363
+ await ctx.adapter.saveExample({
356
364
  key: Key,
357
365
  taskId: `zai/${taskId}`,
358
366
  taskType,
@@ -365,7 +373,7 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
365
373
  output: meta.cost.output,
366
374
  },
367
375
  latency: meta.latency,
368
- model: this.Model,
376
+ model: ctx.modelId,
369
377
  tokens: {
370
378
  input: meta.tokens.input,
371
379
  output: meta.tokens.output,
@@ -376,3 +384,20 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
376
384
 
377
385
  return final
378
386
  }
387
+
388
+ Zai.prototype.extract = function <S extends OfType<AnyObjectOrArray>>(
389
+ this: Zai,
390
+ input: unknown,
391
+ schema: S,
392
+ _options?: Options
393
+ ): Response<S['_output']> {
394
+ const context = new ZaiContext({
395
+ client: this.client,
396
+ modelId: this.Model,
397
+ taskId: this.taskId,
398
+ taskType: 'zai.extract',
399
+ adapter: this.adapter,
400
+ })
401
+
402
+ return new Response<S['_output']>(context, extract(input, schema, _options, context), (result) => result)
403
+ }
@@ -2,6 +2,9 @@
2
2
  import { z } from '@bpinternal/zui'
3
3
 
4
4
  import { clamp } from 'lodash-es'
5
+ import { ZaiContext } from '../context'
6
+ import { Response } from '../response'
7
+ import { getTokenizer } from '../tokenizer'
5
8
  import { fastHash, stringify, takeUntilTokens } from '../utils'
6
9
  import { Zai } from '../zai'
7
10
  import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from './constants'
@@ -39,22 +42,28 @@ const _Options = z.object({
39
42
  declare module '@botpress/zai' {
40
43
  interface Zai {
41
44
  /** Filters elements of an array against a condition */
42
- filter<T>(input: Array<T>, condition: string, options?: Options): Promise<Array<T>>
45
+ filter<T>(input: Array<T>, condition: string, options?: Options): Response<Array<T>>
43
46
  }
44
47
  }
45
48
 
46
49
  const END = '■END■'
47
50
 
48
- Zai.prototype.filter = async function (this: Zai, input, condition, _options) {
51
+ const filter = async <T>(
52
+ input: Array<T>,
53
+ condition: string,
54
+ _options: Options | undefined,
55
+ ctx: ZaiContext
56
+ ): Promise<Array<T>> => {
57
+ ctx.controller.signal.throwIfAborted()
49
58
  const options = _Options.parse(_options ?? {}) as Options
50
- const tokenizer = await this.getTokenizer()
51
- await this.fetchModelDetails()
59
+ const tokenizer = await getTokenizer()
60
+ const model = await ctx.getModel()
52
61
 
53
- const taskId = this.taskId
62
+ const taskId = ctx.taskId
54
63
  const taskType = 'zai.filter'
55
64
 
56
65
  const MAX_ITEMS_PER_CHUNK = 50
57
- const TOKENS_TOTAL_MAX = this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER
66
+ const TOKENS_TOTAL_MAX = model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER
58
67
  const TOKENS_EXAMPLES_MAX = Math.floor(Math.max(250, TOKENS_TOTAL_MAX * 0.5))
59
68
  const TOKENS_CONDITION_MAX = clamp(TOKENS_TOTAL_MAX * 0.25, 250, tokenizer.count(condition))
60
69
  const TOKENS_INPUT_ARRAY_MAX = TOKENS_TOTAL_MAX - TOKENS_EXAMPLES_MAX - TOKENS_CONDITION_MAX
@@ -145,18 +154,19 @@ ${examples.map((x, idx) => `■${idx}:${!!x.filter ? 'true' : 'false'}:${x.reaso
145
154
  ]
146
155
 
147
156
  const filterChunk = async (chunk: typeof input) => {
148
- const examples = taskId
149
- ? await this.adapter
150
- .getExamples<string, unknown>({
151
- // The Table API can't search for a huge input string
152
- input: JSON.stringify(chunk).slice(0, 1000),
153
- taskType,
154
- taskId,
155
- })
156
- .then((x) =>
157
- x.map((y) => ({ filter: y.output as boolean, input: y.input, reason: y.explanation }) satisfies Example)
158
- )
159
- : []
157
+ const examples =
158
+ taskId && ctx.adapter
159
+ ? await ctx.adapter
160
+ .getExamples<string, unknown>({
161
+ // The Table API can't search for a huge input string
162
+ input: JSON.stringify(chunk).slice(0, 1000),
163
+ taskType,
164
+ taskId,
165
+ })
166
+ .then((x) =>
167
+ x.map((y) => ({ filter: y.output as boolean, input: y.input, reason: y.explanation }) satisfies Example)
168
+ )
169
+ : []
160
170
 
161
171
  const allExamples = takeUntilTokens([...examples, ...(options.examples ?? [])], TOKENS_EXAMPLES_MAX, (el) =>
162
172
  tokenizer.count(stringify(el.input))
@@ -175,7 +185,7 @@ ${examples.map((x, idx) => `■${idx}:${!!x.filter ? 'true' : 'false'}:${x.reaso
175
185
  },
176
186
  ]
177
187
 
178
- const { output, meta } = await this.callModel({
188
+ const { extracted: partial, meta } = await ctx.generateContent({
179
189
  systemPrompt: `
180
190
  You are given a list of items. Your task is to filter out the items that meet the condition below.
181
191
  You need to return the full list of items with the format:
@@ -198,23 +208,23 @@ The condition is: "${condition}"
198
208
  role: 'user',
199
209
  },
200
210
  ],
201
- })
202
-
203
- const answer = output.choices[0]?.content as string
204
- const indices = answer
205
- .trim()
206
- .split('■')
207
- .filter((x) => x.length > 0)
208
- .map((x) => {
209
- const [idx, filter] = x.split(':')
210
- return { idx: parseInt(idx?.trim() ?? ''), filter: filter?.toLowerCase().trim() === 'true' }
211
- })
211
+ transform: (text) => {
212
+ const indices = text
213
+ .trim()
214
+ .split('■')
215
+ .filter((x) => x.length > 0)
216
+ .map((x) => {
217
+ const [idx, filter] = x.split(':')
218
+ return { idx: parseInt(idx?.trim() ?? ''), filter: filter?.toLowerCase().trim() === 'true' }
219
+ })
212
220
 
213
- const partial = chunk.filter((_, idx) => {
214
- return indices.find((x) => x.idx === idx)?.filter ?? false
221
+ return chunk.filter((_, idx) => {
222
+ return indices.find((x) => x.idx === idx && x.filter) ?? false
223
+ })
224
+ },
215
225
  })
216
226
 
217
- if (taskId) {
227
+ if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
218
228
  const key = fastHash(
219
229
  stringify({
220
230
  taskId,
@@ -224,7 +234,7 @@ The condition is: "${condition}"
224
234
  })
225
235
  )
226
236
 
227
- await this.adapter.saveExample({
237
+ await ctx.adapter.saveExample({
228
238
  key,
229
239
  taskType,
230
240
  taskId,
@@ -237,7 +247,7 @@ The condition is: "${condition}"
237
247
  output: meta.cost.output,
238
248
  },
239
249
  latency: meta.latency,
240
- model: this.Model,
250
+ model: ctx.modelId,
241
251
  tokens: {
242
252
  input: meta.tokens.input,
243
253
  output: meta.tokens.output,
@@ -253,3 +263,20 @@ The condition is: "${condition}"
253
263
 
254
264
  return filteredChunks.flat()
255
265
  }
266
+
267
+ Zai.prototype.filter = function <T>(
268
+ this: Zai,
269
+ input: Array<T>,
270
+ condition: string,
271
+ _options?: Options
272
+ ): Response<Array<T>> {
273
+ const context = new ZaiContext({
274
+ client: this.client,
275
+ modelId: this.Model,
276
+ taskId: this.taskId,
277
+ taskType: 'zai.filter',
278
+ adapter: this.adapter,
279
+ })
280
+
281
+ return new Response<Array<T>>(context, filter(input, condition, _options, context), (result) => result)
282
+ }
@@ -1,7 +1,10 @@
1
1
  // eslint-disable consistent-type-definitions
2
2
  import { z } from '@bpinternal/zui'
3
3
 
4
- import { clamp, chunk } from 'lodash-es'
4
+ import { chunk, clamp } from 'lodash-es'
5
+ import { ZaiContext } from '../context'
6
+ import { Response } from '../response'
7
+ import { getTokenizer } from '../tokenizer'
5
8
  import { fastHash, stringify, takeUntilTokens } from '../utils'
6
9
  import { Zai } from '../zai'
7
10
  import { PROMPT_INPUT_BUFFER } from './constants'
@@ -83,13 +86,16 @@ declare module '@botpress/zai' {
83
86
  input: unknown,
84
87
  labels: Labels<T>,
85
88
  options?: Options<T>
86
- ): Promise<{
87
- [K in T]: {
88
- explanation: string
89
- value: boolean
90
- confidence: number
91
- }
92
- }>
89
+ ): Response<
90
+ {
91
+ [K in T]: {
92
+ explanation: string
93
+ value: boolean
94
+ confidence: number
95
+ }
96
+ },
97
+ { [K in T]: boolean }
98
+ >
93
99
  }
94
100
  }
95
101
 
@@ -124,21 +130,28 @@ const getConfidence = (label: Label) => {
124
130
  }
125
131
  }
126
132
 
127
- Zai.prototype.label = async function <T extends string>(
128
- this: Zai,
133
+ const label = async <T extends string>(
129
134
  input: unknown,
130
135
  _labels: Labels<T>,
131
- _options: Options<T> | undefined
132
- ) {
136
+ _options: Options<T> | undefined,
137
+ ctx: ZaiContext
138
+ ): Promise<{
139
+ [K in T]: {
140
+ explanation: string
141
+ value: boolean
142
+ confidence: number
143
+ }
144
+ }> => {
145
+ ctx.controller.signal.throwIfAborted()
133
146
  const options = _Options.parse(_options ?? {}) as unknown as Options<T>
134
147
  const labels = _Labels.parse(_labels) as Labels<T>
135
- const tokenizer = await this.getTokenizer()
136
- await this.fetchModelDetails()
148
+ const tokenizer = await getTokenizer()
149
+ const model = await ctx.getModel()
137
150
 
138
- const taskId = this.taskId
151
+ const taskId = ctx.taskId
139
152
  const taskType = 'zai.label'
140
153
 
141
- const TOTAL_MAX_TOKENS = clamp(options.chunkLength, 1000, this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER)
154
+ const TOTAL_MAX_TOKENS = clamp(options.chunkLength, 1000, model.input.maxTokens - PROMPT_INPUT_BUFFER)
142
155
  const CHUNK_EXAMPLES_MAX_TOKENS = clamp(Math.floor(TOTAL_MAX_TOKENS * 0.5), 250, 10_000)
143
156
  const CHUNK_INPUT_MAX_TOKENS = clamp(
144
157
  TOTAL_MAX_TOKENS - CHUNK_EXAMPLES_MAX_TOKENS,
@@ -151,7 +164,7 @@ Zai.prototype.label = async function <T extends string>(
151
164
  if (tokenizer.count(inputAsString) > CHUNK_INPUT_MAX_TOKENS) {
152
165
  const tokens = tokenizer.split(inputAsString)
153
166
  const chunks = chunk(tokens, CHUNK_INPUT_MAX_TOKENS).map((x) => x.join(''))
154
- const allLabels = await Promise.all(chunks.map((chunk) => this.label(chunk, _labels)))
167
+ const allLabels = await Promise.all(chunks.map((chunk) => label(chunk, _labels, _options, ctx)))
155
168
 
156
169
  // Merge all the labels together (those who are true will remain true)
157
170
  return allLabels.reduce((acc, x) => {
@@ -202,21 +215,22 @@ Zai.prototype.label = async function <T extends string>(
202
215
  }
203
216
  }
204
217
 
205
- const examples = taskId
206
- ? await this.adapter.getExamples<
207
- string,
208
- {
209
- [K in T]: {
210
- explanation: string
211
- label: Label
218
+ const examples =
219
+ taskId && ctx.adapter
220
+ ? await ctx.adapter.getExamples<
221
+ string,
222
+ {
223
+ [K in T]: {
224
+ explanation: string
225
+ label: Label
226
+ }
212
227
  }
213
- }
214
- >({
215
- input: inputAsString,
216
- taskType,
217
- taskId,
218
- })
219
- : []
228
+ >({
229
+ input: inputAsString,
230
+ taskType,
231
+ taskId,
232
+ })
233
+ : []
220
234
 
221
235
  options.examples.forEach((example) => {
222
236
  examples.push({
@@ -285,7 +299,7 @@ ${END}
285
299
  })
286
300
  .join('\n\n')
287
301
 
288
- const { output, meta } = await this.callModel({
302
+ const { extracted, meta } = await ctx.generateContent({
289
303
  stopSequences: [END],
290
304
  systemPrompt: `
291
305
  You need to tag the input with the following labels based on the question asked:
@@ -336,35 +350,33 @@ The Expert Examples are there to help you make your decision. They have been pro
336
350
  For example, you can say: "According to Expert Example #1, ..."`.trim(),
337
351
  },
338
352
  ],
353
+ transform: (text) =>
354
+ Object.keys(labels).reduce((acc, key) => {
355
+ const match = text.match(new RegExp(`■${key}:【(.+)】:(\\w{2,})■`, 'i'))
356
+ if (match) {
357
+ const explanation = match[1].trim()
358
+ const label = parseLabel(match[2])
359
+ acc[key] = {
360
+ explanation,
361
+ label,
362
+ }
363
+ } else {
364
+ acc[key] = {
365
+ explanation: '',
366
+ label: LABELS.AMBIGUOUS,
367
+ }
368
+ }
369
+ return acc
370
+ }, {}) as {
371
+ [K in T]: {
372
+ explanation: string
373
+ label: Label
374
+ }
375
+ },
339
376
  })
340
377
 
341
- const answer = output.choices[0].content as string
342
-
343
- const final = Object.keys(labels).reduce((acc, key) => {
344
- const match = answer.match(new RegExp(`■${key}:【(.+)】:(\\w{2,})■`, 'i'))
345
- if (match) {
346
- const explanation = match[1].trim()
347
- const label = parseLabel(match[2])
348
- acc[key] = {
349
- explanation,
350
- label,
351
- }
352
- } else {
353
- acc[key] = {
354
- explanation: '',
355
- label: LABELS.AMBIGUOUS,
356
- }
357
- }
358
- return acc
359
- }, {}) as {
360
- [K in T]: {
361
- explanation: string
362
- label: Label
363
- }
364
- }
365
-
366
- if (taskId) {
367
- await this.adapter.saveExample({
378
+ if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
379
+ await ctx.adapter.saveExample({
368
380
  key: Key,
369
381
  taskType,
370
382
  taskId,
@@ -375,16 +387,59 @@ For example, you can say: "According to Expert Example #1, ..."`.trim(),
375
387
  output: meta.cost.output,
376
388
  },
377
389
  latency: meta.latency,
378
- model: this.Model,
390
+ model: ctx.modelId,
379
391
  tokens: {
380
392
  input: meta.tokens.input,
381
393
  output: meta.tokens.output,
382
394
  },
383
395
  },
384
396
  input: inputAsString,
385
- output: final,
397
+ output: extracted,
386
398
  })
387
399
  }
388
400
 
389
- return convertToAnswer(final)
401
+ return convertToAnswer(extracted)
402
+ }
403
+
404
+ Zai.prototype.label = function <T extends string>(
405
+ this: Zai,
406
+ input: unknown,
407
+ labels: Labels<T>,
408
+ _options?: Options<T>
409
+ ): Response<
410
+ {
411
+ [K in T]: {
412
+ explanation: string
413
+ value: boolean
414
+ confidence: number
415
+ }
416
+ },
417
+ { [K in T]: boolean }
418
+ > {
419
+ const context = new ZaiContext({
420
+ client: this.client,
421
+ modelId: this.Model,
422
+ taskId: this.taskId,
423
+ taskType: 'zai.label',
424
+ adapter: this.adapter,
425
+ })
426
+
427
+ return new Response<
428
+ {
429
+ [K in T]: {
430
+ explanation: string
431
+ value: boolean
432
+ confidence: number
433
+ }
434
+ },
435
+ { [K in T]: boolean }
436
+ >(context, label(input, labels, _options, context), (result) =>
437
+ Object.keys(result).reduce(
438
+ (acc, key) => {
439
+ acc[key] = result[key].value
440
+ return acc
441
+ },
442
+ {} as { [K in T]: boolean }
443
+ )
444
+ )
390
445
  }