@botpress/zai 2.0.16 → 2.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/context.js +131 -0
- package/dist/emitter.js +42 -0
- package/dist/index.d.ts +104 -9
- package/dist/operations/check.js +46 -27
- package/dist/operations/extract.js +63 -46
- package/dist/operations/filter.js +34 -19
- package/dist/operations/label.js +65 -42
- package/dist/operations/rewrite.js +37 -17
- package/dist/operations/summarize.js +32 -13
- package/dist/operations/text.js +28 -8
- package/dist/response.js +82 -0
- package/dist/tokenizer.js +11 -0
- package/e2e/client.ts +43 -29
- package/e2e/data/cache.jsonl +276 -0
- package/package.json +11 -3
- package/src/context.ts +197 -0
- package/src/emitter.ts +49 -0
- package/src/operations/check.ts +99 -49
- package/src/operations/extract.ts +85 -60
- package/src/operations/filter.ts +62 -35
- package/src/operations/label.ts +117 -62
- package/src/operations/rewrite.ts +50 -21
- package/src/operations/summarize.ts +40 -14
- package/src/operations/text.ts +32 -8
- package/src/response.ts +114 -0
- package/src/tokenizer.ts +14 -0
|
@@ -5,6 +5,9 @@ import JSON5 from 'json5'
|
|
|
5
5
|
import { jsonrepair } from 'jsonrepair'
|
|
6
6
|
|
|
7
7
|
import { chunk, isArray } from 'lodash-es'
|
|
8
|
+
import { ZaiContext } from '../context'
|
|
9
|
+
import { Response } from '../response'
|
|
10
|
+
import { getTokenizer } from '../tokenizer'
|
|
8
11
|
import { fastHash, stringify, takeUntilTokens } from '../utils'
|
|
9
12
|
import { Zai } from '../zai'
|
|
10
13
|
import { PROMPT_INPUT_BUFFER } from './constants'
|
|
@@ -38,7 +41,7 @@ type AnyObjectOrArray = Record<string, unknown> | Array<unknown>
|
|
|
38
41
|
declare module '@botpress/zai' {
|
|
39
42
|
interface Zai {
|
|
40
43
|
/** Extracts one or many elements from an arbitrary input */
|
|
41
|
-
extract<S extends OfType<any>>(input: unknown, schema: S, options?: Options):
|
|
44
|
+
extract<S extends OfType<any>>(input: unknown, schema: S, options?: Options): Response<S['_output']>
|
|
42
45
|
}
|
|
43
46
|
}
|
|
44
47
|
|
|
@@ -46,21 +49,22 @@ const START = '■json_start■'
|
|
|
46
49
|
const END = '■json_end■'
|
|
47
50
|
const NO_MORE = '■NO_MORE_ELEMENT■'
|
|
48
51
|
|
|
49
|
-
|
|
50
|
-
this: Zai,
|
|
52
|
+
const extract = async <S extends OfType<AnyObjectOrArray>>(
|
|
51
53
|
input: unknown,
|
|
52
54
|
_schema: S,
|
|
53
|
-
_options
|
|
54
|
-
|
|
55
|
+
_options: Options | undefined,
|
|
56
|
+
ctx: ZaiContext
|
|
57
|
+
): Promise<S['_output']> => {
|
|
58
|
+
ctx.controller.signal.throwIfAborted()
|
|
55
59
|
let schema = _schema as any as z.ZodType
|
|
56
60
|
const options = Options.parse(_options ?? {})
|
|
57
|
-
const tokenizer = await
|
|
58
|
-
await
|
|
61
|
+
const tokenizer = await getTokenizer()
|
|
62
|
+
const model = await ctx.getModel()
|
|
59
63
|
|
|
60
|
-
const taskId =
|
|
64
|
+
const taskId = ctx.taskId
|
|
61
65
|
const taskType = 'zai.extract'
|
|
62
66
|
|
|
63
|
-
const PROMPT_COMPONENT = Math.max(
|
|
67
|
+
const PROMPT_COMPONENT = Math.max(model.input.maxTokens - PROMPT_INPUT_BUFFER, 100)
|
|
64
68
|
|
|
65
69
|
let isArrayOfObjects = false
|
|
66
70
|
let wrappedValue = false
|
|
@@ -99,10 +103,7 @@ Zai.prototype.extract = async function <S extends OfType<AnyObjectOrArray>>(
|
|
|
99
103
|
const schemaTypescript = schema.toTypescriptType({ declaration: false })
|
|
100
104
|
const schemaLength = tokenizer.count(schemaTypescript)
|
|
101
105
|
|
|
102
|
-
options.chunkLength = Math.min(
|
|
103
|
-
options.chunkLength,
|
|
104
|
-
this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER - schemaLength
|
|
105
|
-
)
|
|
106
|
+
options.chunkLength = Math.min(options.chunkLength, model.input.maxTokens - PROMPT_INPUT_BUFFER - schemaLength)
|
|
106
107
|
|
|
107
108
|
const keys = Object.keys((schema as ZodObject).shape)
|
|
108
109
|
|
|
@@ -113,18 +114,25 @@ Zai.prototype.extract = async function <S extends OfType<AnyObjectOrArray>>(
|
|
|
113
114
|
const chunks = chunk(tokens, options.chunkLength).map((x) => x.join(''))
|
|
114
115
|
const all = await Promise.allSettled(
|
|
115
116
|
chunks.map((chunk) =>
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
117
|
+
extract(
|
|
118
|
+
chunk,
|
|
119
|
+
originalSchema,
|
|
120
|
+
{
|
|
121
|
+
...options,
|
|
122
|
+
strict: false, // We don't want to fail on strict mode for sub-chunks
|
|
123
|
+
},
|
|
124
|
+
ctx
|
|
125
|
+
)
|
|
120
126
|
)
|
|
121
127
|
).then((results) =>
|
|
122
128
|
results.filter((x) => x.status === 'fulfilled').map((x) => (x as PromiseFulfilledResult<S['_output']>).value)
|
|
123
129
|
)
|
|
124
130
|
|
|
131
|
+
ctx.controller.signal.throwIfAborted()
|
|
132
|
+
|
|
125
133
|
// We run this function recursively until all chunks are merged into a single output
|
|
126
134
|
const rows = all.map((x, idx) => `<part-${idx + 1}>\n${stringify(x, true)}\n</part-${idx + 1}>`).join('\n')
|
|
127
|
-
return
|
|
135
|
+
return extract(
|
|
128
136
|
`
|
|
129
137
|
The result has been split into ${all.length} parts. Recursively merge the result into the final result.
|
|
130
138
|
When merging arrays, take unique values.
|
|
@@ -136,7 +144,8 @@ ${rows}
|
|
|
136
144
|
|
|
137
145
|
Merge it back into a final result.`.trim(),
|
|
138
146
|
originalSchema,
|
|
139
|
-
options
|
|
147
|
+
options,
|
|
148
|
+
ctx
|
|
140
149
|
)
|
|
141
150
|
}
|
|
142
151
|
|
|
@@ -179,13 +188,14 @@ Merge it back into a final result.`.trim(),
|
|
|
179
188
|
})
|
|
180
189
|
)
|
|
181
190
|
|
|
182
|
-
const examples =
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
191
|
+
const examples =
|
|
192
|
+
taskId && ctx.adapter
|
|
193
|
+
? await ctx.adapter.getExamples<string, unknown>({
|
|
194
|
+
input: inputAsString,
|
|
195
|
+
taskType,
|
|
196
|
+
taskId,
|
|
197
|
+
})
|
|
198
|
+
: []
|
|
189
199
|
|
|
190
200
|
const exactMatch = examples.find((x) => x.key === Key)
|
|
191
201
|
if (exactMatch) {
|
|
@@ -287,7 +297,7 @@ ${END}`.trim()
|
|
|
287
297
|
.map(formatExample)
|
|
288
298
|
.flat()
|
|
289
299
|
|
|
290
|
-
const {
|
|
300
|
+
const { meta, extracted } = await ctx.generateContent({
|
|
291
301
|
systemPrompt: `
|
|
292
302
|
Extract the following information from the input:
|
|
293
303
|
${schemaTypescript}
|
|
@@ -304,43 +314,41 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
|
|
|
304
314
|
content: formatInput(inputAsString, schemaTypescript, options.instructions ?? ''),
|
|
305
315
|
},
|
|
306
316
|
],
|
|
317
|
+
transform: (text) =>
|
|
318
|
+
(text || '{}')
|
|
319
|
+
?.split(START)
|
|
320
|
+
.filter((x) => x.trim().length > 0 && x.includes('}'))
|
|
321
|
+
.map((x) => {
|
|
322
|
+
try {
|
|
323
|
+
const json = x.slice(0, x.indexOf(END)).trim()
|
|
324
|
+
const repairedJson = jsonrepair(json)
|
|
325
|
+
const parsedJson = JSON5.parse(repairedJson)
|
|
326
|
+
const safe = schema.safeParse(parsedJson)
|
|
327
|
+
|
|
328
|
+
if (safe.success) {
|
|
329
|
+
return safe.data
|
|
330
|
+
}
|
|
331
|
+
|
|
332
|
+
if (options.strict) {
|
|
333
|
+
throw new JsonParsingError(x, safe.error)
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
return parsedJson
|
|
337
|
+
} catch (error) {
|
|
338
|
+
throw new JsonParsingError(x, error instanceof Error ? error : new Error('Unknown error'))
|
|
339
|
+
}
|
|
340
|
+
})
|
|
341
|
+
.filter((x) => x !== null),
|
|
307
342
|
})
|
|
308
343
|
|
|
309
|
-
const answer = (output.choices[0]?.content ?? '{}') as string
|
|
310
|
-
|
|
311
|
-
const elements = answer
|
|
312
|
-
?.split(START)
|
|
313
|
-
.filter((x) => x.trim().length > 0 && x.includes('}'))
|
|
314
|
-
.map((x) => {
|
|
315
|
-
try {
|
|
316
|
-
const json = x.slice(0, x.indexOf(END)).trim()
|
|
317
|
-
const repairedJson = jsonrepair(json)
|
|
318
|
-
const parsedJson = JSON5.parse(repairedJson)
|
|
319
|
-
const safe = schema.safeParse(parsedJson)
|
|
320
|
-
|
|
321
|
-
if (safe.success) {
|
|
322
|
-
return safe.data
|
|
323
|
-
}
|
|
324
|
-
|
|
325
|
-
if (options.strict) {
|
|
326
|
-
throw new JsonParsingError(x, safe.error)
|
|
327
|
-
}
|
|
328
|
-
|
|
329
|
-
return parsedJson
|
|
330
|
-
} catch (error) {
|
|
331
|
-
throw new JsonParsingError(x, error instanceof Error ? error : new Error('Unknown error'))
|
|
332
|
-
}
|
|
333
|
-
})
|
|
334
|
-
.filter((x) => x !== null)
|
|
335
|
-
|
|
336
344
|
let final: any
|
|
337
345
|
|
|
338
346
|
if (isArrayOfObjects) {
|
|
339
|
-
final =
|
|
340
|
-
} else if (
|
|
347
|
+
final = extracted
|
|
348
|
+
} else if (extracted.length === 0) {
|
|
341
349
|
final = options.strict ? schema.parse({}) : {}
|
|
342
350
|
} else {
|
|
343
|
-
final =
|
|
351
|
+
final = extracted[0]
|
|
344
352
|
}
|
|
345
353
|
|
|
346
354
|
if (wrappedValue) {
|
|
@@ -351,8 +359,8 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
|
|
|
351
359
|
}
|
|
352
360
|
}
|
|
353
361
|
|
|
354
|
-
if (taskId) {
|
|
355
|
-
await
|
|
362
|
+
if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
|
|
363
|
+
await ctx.adapter.saveExample({
|
|
356
364
|
key: Key,
|
|
357
365
|
taskId: `zai/${taskId}`,
|
|
358
366
|
taskType,
|
|
@@ -365,7 +373,7 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
|
|
|
365
373
|
output: meta.cost.output,
|
|
366
374
|
},
|
|
367
375
|
latency: meta.latency,
|
|
368
|
-
model:
|
|
376
|
+
model: ctx.modelId,
|
|
369
377
|
tokens: {
|
|
370
378
|
input: meta.tokens.input,
|
|
371
379
|
output: meta.tokens.output,
|
|
@@ -376,3 +384,20 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
|
|
|
376
384
|
|
|
377
385
|
return final
|
|
378
386
|
}
|
|
387
|
+
|
|
388
|
+
Zai.prototype.extract = function <S extends OfType<AnyObjectOrArray>>(
|
|
389
|
+
this: Zai,
|
|
390
|
+
input: unknown,
|
|
391
|
+
schema: S,
|
|
392
|
+
_options?: Options
|
|
393
|
+
): Response<S['_output']> {
|
|
394
|
+
const context = new ZaiContext({
|
|
395
|
+
client: this.client,
|
|
396
|
+
modelId: this.Model,
|
|
397
|
+
taskId: this.taskId,
|
|
398
|
+
taskType: 'zai.extract',
|
|
399
|
+
adapter: this.adapter,
|
|
400
|
+
})
|
|
401
|
+
|
|
402
|
+
return new Response<S['_output']>(context, extract(input, schema, _options, context), (result) => result)
|
|
403
|
+
}
|
package/src/operations/filter.ts
CHANGED
|
@@ -2,6 +2,9 @@
|
|
|
2
2
|
import { z } from '@bpinternal/zui'
|
|
3
3
|
|
|
4
4
|
import { clamp } from 'lodash-es'
|
|
5
|
+
import { ZaiContext } from '../context'
|
|
6
|
+
import { Response } from '../response'
|
|
7
|
+
import { getTokenizer } from '../tokenizer'
|
|
5
8
|
import { fastHash, stringify, takeUntilTokens } from '../utils'
|
|
6
9
|
import { Zai } from '../zai'
|
|
7
10
|
import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from './constants'
|
|
@@ -39,22 +42,28 @@ const _Options = z.object({
|
|
|
39
42
|
declare module '@botpress/zai' {
|
|
40
43
|
interface Zai {
|
|
41
44
|
/** Filters elements of an array against a condition */
|
|
42
|
-
filter<T>(input: Array<T>, condition: string, options?: Options):
|
|
45
|
+
filter<T>(input: Array<T>, condition: string, options?: Options): Response<Array<T>>
|
|
43
46
|
}
|
|
44
47
|
}
|
|
45
48
|
|
|
46
49
|
const END = '■END■'
|
|
47
50
|
|
|
48
|
-
|
|
51
|
+
const filter = async <T>(
|
|
52
|
+
input: Array<T>,
|
|
53
|
+
condition: string,
|
|
54
|
+
_options: Options | undefined,
|
|
55
|
+
ctx: ZaiContext
|
|
56
|
+
): Promise<Array<T>> => {
|
|
57
|
+
ctx.controller.signal.throwIfAborted()
|
|
49
58
|
const options = _Options.parse(_options ?? {}) as Options
|
|
50
|
-
const tokenizer = await
|
|
51
|
-
await
|
|
59
|
+
const tokenizer = await getTokenizer()
|
|
60
|
+
const model = await ctx.getModel()
|
|
52
61
|
|
|
53
|
-
const taskId =
|
|
62
|
+
const taskId = ctx.taskId
|
|
54
63
|
const taskType = 'zai.filter'
|
|
55
64
|
|
|
56
65
|
const MAX_ITEMS_PER_CHUNK = 50
|
|
57
|
-
const TOKENS_TOTAL_MAX =
|
|
66
|
+
const TOKENS_TOTAL_MAX = model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER
|
|
58
67
|
const TOKENS_EXAMPLES_MAX = Math.floor(Math.max(250, TOKENS_TOTAL_MAX * 0.5))
|
|
59
68
|
const TOKENS_CONDITION_MAX = clamp(TOKENS_TOTAL_MAX * 0.25, 250, tokenizer.count(condition))
|
|
60
69
|
const TOKENS_INPUT_ARRAY_MAX = TOKENS_TOTAL_MAX - TOKENS_EXAMPLES_MAX - TOKENS_CONDITION_MAX
|
|
@@ -145,18 +154,19 @@ ${examples.map((x, idx) => `■${idx}:${!!x.filter ? 'true' : 'false'}:${x.reaso
|
|
|
145
154
|
]
|
|
146
155
|
|
|
147
156
|
const filterChunk = async (chunk: typeof input) => {
|
|
148
|
-
const examples =
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
157
|
+
const examples =
|
|
158
|
+
taskId && ctx.adapter
|
|
159
|
+
? await ctx.adapter
|
|
160
|
+
.getExamples<string, unknown>({
|
|
161
|
+
// The Table API can't search for a huge input string
|
|
162
|
+
input: JSON.stringify(chunk).slice(0, 1000),
|
|
163
|
+
taskType,
|
|
164
|
+
taskId,
|
|
165
|
+
})
|
|
166
|
+
.then((x) =>
|
|
167
|
+
x.map((y) => ({ filter: y.output as boolean, input: y.input, reason: y.explanation }) satisfies Example)
|
|
168
|
+
)
|
|
169
|
+
: []
|
|
160
170
|
|
|
161
171
|
const allExamples = takeUntilTokens([...examples, ...(options.examples ?? [])], TOKENS_EXAMPLES_MAX, (el) =>
|
|
162
172
|
tokenizer.count(stringify(el.input))
|
|
@@ -175,7 +185,7 @@ ${examples.map((x, idx) => `■${idx}:${!!x.filter ? 'true' : 'false'}:${x.reaso
|
|
|
175
185
|
},
|
|
176
186
|
]
|
|
177
187
|
|
|
178
|
-
const {
|
|
188
|
+
const { extracted: partial, meta } = await ctx.generateContent({
|
|
179
189
|
systemPrompt: `
|
|
180
190
|
You are given a list of items. Your task is to filter out the items that meet the condition below.
|
|
181
191
|
You need to return the full list of items with the format:
|
|
@@ -198,23 +208,23 @@ The condition is: "${condition}"
|
|
|
198
208
|
role: 'user',
|
|
199
209
|
},
|
|
200
210
|
],
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
return { idx: parseInt(idx?.trim() ?? ''), filter: filter?.toLowerCase().trim() === 'true' }
|
|
211
|
-
})
|
|
211
|
+
transform: (text) => {
|
|
212
|
+
const indices = text
|
|
213
|
+
.trim()
|
|
214
|
+
.split('■')
|
|
215
|
+
.filter((x) => x.length > 0)
|
|
216
|
+
.map((x) => {
|
|
217
|
+
const [idx, filter] = x.split(':')
|
|
218
|
+
return { idx: parseInt(idx?.trim() ?? ''), filter: filter?.toLowerCase().trim() === 'true' }
|
|
219
|
+
})
|
|
212
220
|
|
|
213
|
-
|
|
214
|
-
|
|
221
|
+
return chunk.filter((_, idx) => {
|
|
222
|
+
return indices.find((x) => x.idx === idx && x.filter) ?? false
|
|
223
|
+
})
|
|
224
|
+
},
|
|
215
225
|
})
|
|
216
226
|
|
|
217
|
-
if (taskId) {
|
|
227
|
+
if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
|
|
218
228
|
const key = fastHash(
|
|
219
229
|
stringify({
|
|
220
230
|
taskId,
|
|
@@ -224,7 +234,7 @@ The condition is: "${condition}"
|
|
|
224
234
|
})
|
|
225
235
|
)
|
|
226
236
|
|
|
227
|
-
await
|
|
237
|
+
await ctx.adapter.saveExample({
|
|
228
238
|
key,
|
|
229
239
|
taskType,
|
|
230
240
|
taskId,
|
|
@@ -237,7 +247,7 @@ The condition is: "${condition}"
|
|
|
237
247
|
output: meta.cost.output,
|
|
238
248
|
},
|
|
239
249
|
latency: meta.latency,
|
|
240
|
-
model:
|
|
250
|
+
model: ctx.modelId,
|
|
241
251
|
tokens: {
|
|
242
252
|
input: meta.tokens.input,
|
|
243
253
|
output: meta.tokens.output,
|
|
@@ -253,3 +263,20 @@ The condition is: "${condition}"
|
|
|
253
263
|
|
|
254
264
|
return filteredChunks.flat()
|
|
255
265
|
}
|
|
266
|
+
|
|
267
|
+
Zai.prototype.filter = function <T>(
|
|
268
|
+
this: Zai,
|
|
269
|
+
input: Array<T>,
|
|
270
|
+
condition: string,
|
|
271
|
+
_options?: Options
|
|
272
|
+
): Response<Array<T>> {
|
|
273
|
+
const context = new ZaiContext({
|
|
274
|
+
client: this.client,
|
|
275
|
+
modelId: this.Model,
|
|
276
|
+
taskId: this.taskId,
|
|
277
|
+
taskType: 'zai.filter',
|
|
278
|
+
adapter: this.adapter,
|
|
279
|
+
})
|
|
280
|
+
|
|
281
|
+
return new Response<Array<T>>(context, filter(input, condition, _options, context), (result) => result)
|
|
282
|
+
}
|
package/src/operations/label.ts
CHANGED
|
@@ -1,7 +1,10 @@
|
|
|
1
1
|
// eslint-disable consistent-type-definitions
|
|
2
2
|
import { z } from '@bpinternal/zui'
|
|
3
3
|
|
|
4
|
-
import {
|
|
4
|
+
import { chunk, clamp } from 'lodash-es'
|
|
5
|
+
import { ZaiContext } from '../context'
|
|
6
|
+
import { Response } from '../response'
|
|
7
|
+
import { getTokenizer } from '../tokenizer'
|
|
5
8
|
import { fastHash, stringify, takeUntilTokens } from '../utils'
|
|
6
9
|
import { Zai } from '../zai'
|
|
7
10
|
import { PROMPT_INPUT_BUFFER } from './constants'
|
|
@@ -83,13 +86,16 @@ declare module '@botpress/zai' {
|
|
|
83
86
|
input: unknown,
|
|
84
87
|
labels: Labels<T>,
|
|
85
88
|
options?: Options<T>
|
|
86
|
-
):
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
89
|
+
): Response<
|
|
90
|
+
{
|
|
91
|
+
[K in T]: {
|
|
92
|
+
explanation: string
|
|
93
|
+
value: boolean
|
|
94
|
+
confidence: number
|
|
95
|
+
}
|
|
96
|
+
},
|
|
97
|
+
{ [K in T]: boolean }
|
|
98
|
+
>
|
|
93
99
|
}
|
|
94
100
|
}
|
|
95
101
|
|
|
@@ -124,21 +130,28 @@ const getConfidence = (label: Label) => {
|
|
|
124
130
|
}
|
|
125
131
|
}
|
|
126
132
|
|
|
127
|
-
|
|
128
|
-
this: Zai,
|
|
133
|
+
const label = async <T extends string>(
|
|
129
134
|
input: unknown,
|
|
130
135
|
_labels: Labels<T>,
|
|
131
|
-
_options: Options<T> | undefined
|
|
132
|
-
|
|
136
|
+
_options: Options<T> | undefined,
|
|
137
|
+
ctx: ZaiContext
|
|
138
|
+
): Promise<{
|
|
139
|
+
[K in T]: {
|
|
140
|
+
explanation: string
|
|
141
|
+
value: boolean
|
|
142
|
+
confidence: number
|
|
143
|
+
}
|
|
144
|
+
}> => {
|
|
145
|
+
ctx.controller.signal.throwIfAborted()
|
|
133
146
|
const options = _Options.parse(_options ?? {}) as unknown as Options<T>
|
|
134
147
|
const labels = _Labels.parse(_labels) as Labels<T>
|
|
135
|
-
const tokenizer = await
|
|
136
|
-
await
|
|
148
|
+
const tokenizer = await getTokenizer()
|
|
149
|
+
const model = await ctx.getModel()
|
|
137
150
|
|
|
138
|
-
const taskId =
|
|
151
|
+
const taskId = ctx.taskId
|
|
139
152
|
const taskType = 'zai.label'
|
|
140
153
|
|
|
141
|
-
const TOTAL_MAX_TOKENS = clamp(options.chunkLength, 1000,
|
|
154
|
+
const TOTAL_MAX_TOKENS = clamp(options.chunkLength, 1000, model.input.maxTokens - PROMPT_INPUT_BUFFER)
|
|
142
155
|
const CHUNK_EXAMPLES_MAX_TOKENS = clamp(Math.floor(TOTAL_MAX_TOKENS * 0.5), 250, 10_000)
|
|
143
156
|
const CHUNK_INPUT_MAX_TOKENS = clamp(
|
|
144
157
|
TOTAL_MAX_TOKENS - CHUNK_EXAMPLES_MAX_TOKENS,
|
|
@@ -151,7 +164,7 @@ Zai.prototype.label = async function <T extends string>(
|
|
|
151
164
|
if (tokenizer.count(inputAsString) > CHUNK_INPUT_MAX_TOKENS) {
|
|
152
165
|
const tokens = tokenizer.split(inputAsString)
|
|
153
166
|
const chunks = chunk(tokens, CHUNK_INPUT_MAX_TOKENS).map((x) => x.join(''))
|
|
154
|
-
const allLabels = await Promise.all(chunks.map((chunk) =>
|
|
167
|
+
const allLabels = await Promise.all(chunks.map((chunk) => label(chunk, _labels, _options, ctx)))
|
|
155
168
|
|
|
156
169
|
// Merge all the labels together (those who are true will remain true)
|
|
157
170
|
return allLabels.reduce((acc, x) => {
|
|
@@ -202,21 +215,22 @@ Zai.prototype.label = async function <T extends string>(
|
|
|
202
215
|
}
|
|
203
216
|
}
|
|
204
217
|
|
|
205
|
-
const examples =
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
218
|
+
const examples =
|
|
219
|
+
taskId && ctx.adapter
|
|
220
|
+
? await ctx.adapter.getExamples<
|
|
221
|
+
string,
|
|
222
|
+
{
|
|
223
|
+
[K in T]: {
|
|
224
|
+
explanation: string
|
|
225
|
+
label: Label
|
|
226
|
+
}
|
|
212
227
|
}
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
: []
|
|
228
|
+
>({
|
|
229
|
+
input: inputAsString,
|
|
230
|
+
taskType,
|
|
231
|
+
taskId,
|
|
232
|
+
})
|
|
233
|
+
: []
|
|
220
234
|
|
|
221
235
|
options.examples.forEach((example) => {
|
|
222
236
|
examples.push({
|
|
@@ -285,7 +299,7 @@ ${END}
|
|
|
285
299
|
})
|
|
286
300
|
.join('\n\n')
|
|
287
301
|
|
|
288
|
-
const {
|
|
302
|
+
const { extracted, meta } = await ctx.generateContent({
|
|
289
303
|
stopSequences: [END],
|
|
290
304
|
systemPrompt: `
|
|
291
305
|
You need to tag the input with the following labels based on the question asked:
|
|
@@ -336,35 +350,33 @@ The Expert Examples are there to help you make your decision. They have been pro
|
|
|
336
350
|
For example, you can say: "According to Expert Example #1, ..."`.trim(),
|
|
337
351
|
},
|
|
338
352
|
],
|
|
353
|
+
transform: (text) =>
|
|
354
|
+
Object.keys(labels).reduce((acc, key) => {
|
|
355
|
+
const match = text.match(new RegExp(`■${key}:【(.+)】:(\\w{2,})■`, 'i'))
|
|
356
|
+
if (match) {
|
|
357
|
+
const explanation = match[1].trim()
|
|
358
|
+
const label = parseLabel(match[2])
|
|
359
|
+
acc[key] = {
|
|
360
|
+
explanation,
|
|
361
|
+
label,
|
|
362
|
+
}
|
|
363
|
+
} else {
|
|
364
|
+
acc[key] = {
|
|
365
|
+
explanation: '',
|
|
366
|
+
label: LABELS.AMBIGUOUS,
|
|
367
|
+
}
|
|
368
|
+
}
|
|
369
|
+
return acc
|
|
370
|
+
}, {}) as {
|
|
371
|
+
[K in T]: {
|
|
372
|
+
explanation: string
|
|
373
|
+
label: Label
|
|
374
|
+
}
|
|
375
|
+
},
|
|
339
376
|
})
|
|
340
377
|
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
const final = Object.keys(labels).reduce((acc, key) => {
|
|
344
|
-
const match = answer.match(new RegExp(`■${key}:【(.+)】:(\\w{2,})■`, 'i'))
|
|
345
|
-
if (match) {
|
|
346
|
-
const explanation = match[1].trim()
|
|
347
|
-
const label = parseLabel(match[2])
|
|
348
|
-
acc[key] = {
|
|
349
|
-
explanation,
|
|
350
|
-
label,
|
|
351
|
-
}
|
|
352
|
-
} else {
|
|
353
|
-
acc[key] = {
|
|
354
|
-
explanation: '',
|
|
355
|
-
label: LABELS.AMBIGUOUS,
|
|
356
|
-
}
|
|
357
|
-
}
|
|
358
|
-
return acc
|
|
359
|
-
}, {}) as {
|
|
360
|
-
[K in T]: {
|
|
361
|
-
explanation: string
|
|
362
|
-
label: Label
|
|
363
|
-
}
|
|
364
|
-
}
|
|
365
|
-
|
|
366
|
-
if (taskId) {
|
|
367
|
-
await this.adapter.saveExample({
|
|
378
|
+
if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
|
|
379
|
+
await ctx.adapter.saveExample({
|
|
368
380
|
key: Key,
|
|
369
381
|
taskType,
|
|
370
382
|
taskId,
|
|
@@ -375,16 +387,59 @@ For example, you can say: "According to Expert Example #1, ..."`.trim(),
|
|
|
375
387
|
output: meta.cost.output,
|
|
376
388
|
},
|
|
377
389
|
latency: meta.latency,
|
|
378
|
-
model:
|
|
390
|
+
model: ctx.modelId,
|
|
379
391
|
tokens: {
|
|
380
392
|
input: meta.tokens.input,
|
|
381
393
|
output: meta.tokens.output,
|
|
382
394
|
},
|
|
383
395
|
},
|
|
384
396
|
input: inputAsString,
|
|
385
|
-
output:
|
|
397
|
+
output: extracted,
|
|
386
398
|
})
|
|
387
399
|
}
|
|
388
400
|
|
|
389
|
-
return convertToAnswer(
|
|
401
|
+
return convertToAnswer(extracted)
|
|
402
|
+
}
|
|
403
|
+
|
|
404
|
+
Zai.prototype.label = function <T extends string>(
|
|
405
|
+
this: Zai,
|
|
406
|
+
input: unknown,
|
|
407
|
+
labels: Labels<T>,
|
|
408
|
+
_options?: Options<T>
|
|
409
|
+
): Response<
|
|
410
|
+
{
|
|
411
|
+
[K in T]: {
|
|
412
|
+
explanation: string
|
|
413
|
+
value: boolean
|
|
414
|
+
confidence: number
|
|
415
|
+
}
|
|
416
|
+
},
|
|
417
|
+
{ [K in T]: boolean }
|
|
418
|
+
> {
|
|
419
|
+
const context = new ZaiContext({
|
|
420
|
+
client: this.client,
|
|
421
|
+
modelId: this.Model,
|
|
422
|
+
taskId: this.taskId,
|
|
423
|
+
taskType: 'zai.label',
|
|
424
|
+
adapter: this.adapter,
|
|
425
|
+
})
|
|
426
|
+
|
|
427
|
+
return new Response<
|
|
428
|
+
{
|
|
429
|
+
[K in T]: {
|
|
430
|
+
explanation: string
|
|
431
|
+
value: boolean
|
|
432
|
+
confidence: number
|
|
433
|
+
}
|
|
434
|
+
},
|
|
435
|
+
{ [K in T]: boolean }
|
|
436
|
+
>(context, label(input, labels, _options, context), (result) =>
|
|
437
|
+
Object.keys(result).reduce(
|
|
438
|
+
(acc, key) => {
|
|
439
|
+
acc[key] = result[key].value
|
|
440
|
+
return acc
|
|
441
|
+
},
|
|
442
|
+
{} as { [K in T]: boolean }
|
|
443
|
+
)
|
|
444
|
+
)
|
|
390
445
|
}
|