@botpress/zai 1.0.0-beta.2 → 1.0.0-beta.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +9 -2
- package/dist/src/operations/check.test.js +0 -129
- package/dist/src/operations/extract.test.js +0 -163
- package/dist/src/operations/filter.test.js +0 -151
- package/dist/src/operations/label.test.js +0 -213
- package/dist/src/operations/rewrite.test.js +0 -96
- package/dist/src/operations/summarize.test.js +0 -22
- package/dist/src/operations/text.test.js +0 -51
- package/dist/src/operations/zai-learn.test.js +0 -71
- package/dist/src/operations/zai-retry.test.js +0 -50
- package/src/adapters/adapter.ts +0 -35
- package/src/adapters/botpress-table.ts +0 -213
- package/src/adapters/memory.ts +0 -13
- package/src/env.ts +0 -54
- package/src/index.ts +0 -11
- package/src/models.ts +0 -347
- package/src/operations/__tests/botpress_docs.txt +0 -26040
- package/src/operations/__tests/index.ts +0 -30
- package/src/operations/check.test.ts +0 -155
- package/src/operations/check.ts +0 -187
- package/src/operations/constants.ts +0 -2
- package/src/operations/errors.ts +0 -9
- package/src/operations/extract.test.ts +0 -209
- package/src/operations/extract.ts +0 -291
- package/src/operations/filter.test.ts +0 -182
- package/src/operations/filter.ts +0 -231
- package/src/operations/label.test.ts +0 -239
- package/src/operations/label.ts +0 -332
- package/src/operations/rewrite.test.ts +0 -114
- package/src/operations/rewrite.ts +0 -148
- package/src/operations/summarize.test.ts +0 -25
- package/src/operations/summarize.ts +0 -193
- package/src/operations/text.test.ts +0 -60
- package/src/operations/text.ts +0 -63
- package/src/operations/zai-learn.test.ts +0 -85
- package/src/operations/zai-retry.test.ts +0 -64
- package/src/scripts/update-models.ts +0 -74
- package/src/utils.ts +0 -61
- package/src/zai.ts +0 -185
|
@@ -1,291 +0,0 @@
|
|
|
1
|
-
import { ZodArray, ZodObject, z } from '@bpinternal/zui'
|
|
2
|
-
|
|
3
|
-
import JSON5 from 'json5'
|
|
4
|
-
import { jsonrepair } from 'jsonrepair'
|
|
5
|
-
|
|
6
|
-
import _ from 'lodash'
|
|
7
|
-
import { fastHash, stringify, takeUntilTokens } from '../utils'
|
|
8
|
-
import { Zai } from '../zai'
|
|
9
|
-
import { PROMPT_INPUT_BUFFER } from './constants'
|
|
10
|
-
import { JsonParsingError } from './errors'
|
|
11
|
-
|
|
12
|
-
/** Caller-facing options accepted by `zai.extract` (input type of the `Options` schema below). */
export type Options = z.input<typeof Options>

/** Runtime schema used to validate and default the options passed to `zai.extract`. */
const Options = z.object({
  // Free-form guidance forwarded to the model alongside the schema.
  instructions: z.string().optional().describe('Instructions to guide the user on how to extract the data'),
  // Upper bound, in tokens, for each chunk of input (100 to 100_000, defaults to 16_000).
  chunkLength: z.number().min(100).max(100_000).optional().describe('The maximum number of tokens per chunk').default(16_000)
})

declare module '../zai' {
  interface Zai {
    /** Extracts one or many elements from an arbitrary input */
    extract<S extends z.AnyZodObject>(input: unknown, schema: S, options?: Options): Promise<z.infer<S>>
    extract<S extends z.AnyZodObject>(input: unknown, schema: z.ZodArray<S>, options?: Options): Promise<Array<z.infer<S>>>
  }
}

// Delimiters the model is asked to emit: START/END wrap each JSON element,
// NO_MORE marks the end of the element list.
const START = '■json_start■'
const END = '■json_end■'
const NO_MORE = '■NO_MORE_ELEMENT■'
|
|
39
|
-
|
|
40
|
-
/**
 * Extracts structured data matching `schema` from an arbitrary `input` by prompting the model.
 *
 * - `schema` must be a `ZodObject` (a single element is returned) or a `ZodArray` whose item
 *   type is a `ZodObject` (an array of elements is returned). Anything else throws an `Error`.
 * - When the stringified input exceeds the effective chunk length, array extraction splits the
 *   input into token chunks, extracts each chunk with the same options and concatenates the
 *   results; single-object extraction truncates the input instead.
 * - When `this.taskId` is set, examples are loaded from `this.adapter`; an example whose key
 *   equals the hash of this request is returned as-is, and fresh results are saved back.
 *
 * @param input Any value; it is converted to text with `stringify`.
 * @param schema Object schema, or array-of-object schema, describing what to extract.
 * @param _options Optional `Options` (`instructions`, `chunkLength`).
 * @throws JsonParsingError when an element emitted by the model cannot be repaired, parsed or
 *   validated against the object schema.
 */
Zai.prototype.extract = async function (this: Zai, input, schema, _options) {
  const options = Options.parse(_options ?? {})
  const tokenizer = await this.getTokenizer()

  const taskId = this.taskId
  const taskType = 'zai.extract'

  // Token budget for the prompt (input + instructions + examples), never below 100.
  const PROMPT_COMPONENT = Math.max(this.Model.input.maxTokens - PROMPT_INPUT_BUFFER, 100)

  let isArrayOfObjects = false
  const originalSchema = schema

  if (schema instanceof ZodObject) {
    // Do nothing
  } else if (schema instanceof ZodArray) {
    if (schema._def.type instanceof ZodObject) {
      // From here on `schema` is the item schema; `originalSchema` keeps the array wrapper.
      isArrayOfObjects = true
      schema = schema._def.type
    } else {
      throw new Error('Schema must be a ZodObject or a ZodArray<ZodObject>')
    }
  } else {
    throw new Error('Schema must be either a ZuiObject or a ZuiArray<ZuiObject>')
  }

  const schemaTypescript = schema.toTypescript({ declaration: false })
  const schemaLength = tokenizer.count(schemaTypescript)

  // A chunk can never be larger than what is left once the buffer and the schema are accounted for.
  options.chunkLength = Math.min(options.chunkLength, this.Model.input.maxTokens - PROMPT_INPUT_BUFFER - schemaLength)

  const keys = Object.keys(schema.shape)

  let inputAsString = stringify(input)

  if (tokenizer.count(inputAsString) > options.chunkLength) {
    if (isArrayOfObjects) {
      // Array extraction: process every chunk independently and merge the results.
      const tokens = tokenizer.split(inputAsString)
      const chunks = _.chunk(tokens, options.chunkLength).map((x) => x.join(''))
      // Forward the caller's options so `instructions` and `chunkLength` also apply to each
      // chunk (the recursive call re-derives the same effective chunk length).
      const all = await Promise.all(chunks.map((chunk) => this.extract(chunk, originalSchema, _options)))

      return all.flat()
    } else {
      // Truncate the input to fit the model's input size
      inputAsString = tokenizer.truncate(stringify(input), options.chunkLength)
    }
  }

  const instructions: string[] = []

  if (options.instructions) {
    instructions.push(options.instructions)
  }

  // Skeletons shown to the model: the expected keys, and an abbreviated element.
  const shape = `{ ${keys.map((key) => `"${key}": ...`).join(', ')} }`
  const abbv = '{ ... }'

  if (isArrayOfObjects) {
    instructions.push('You may have multiple elements, or zero elements in the input.')
    instructions.push('You must extract each element separately.')
    instructions.push(`Each element must be a JSON object with exactly the format: ${START}${shape}${END}`)
    instructions.push(`When you are done extracting all elements, type "${NO_MORE}" to finish.`)
    instructions.push(`For example, if you have zero elements, the output should look like this: ${NO_MORE}`)
    instructions.push(
      `For example, if you have two elements, the output should look like this: ${START}${abbv}${END}${START}${abbv}${END}${NO_MORE}`
    )
  } else {
    instructions.push('You may have exactly one element in the input.')
    instructions.push(`The element must be a JSON object with exactly the format: ${START}${shape}${END}`)
  }

  // All tokens remaining after the input and condition are accounted can be used for examples
  const EXAMPLES_TOKENS = PROMPT_COMPONENT - tokenizer.count(inputAsString) - tokenizer.count(instructions.join('\n'))

  // Cache key for this request.
  // NOTE(review): the schema is not part of the key, so two different schemas applied to the
  // same input/instructions/task resolve to the same stored example — confirm this is intended.
  const Key = fastHash(
    JSON.stringify({
      taskType,
      taskId,
      input: inputAsString,
      instructions: options.instructions
    })
  )

  const examples = taskId
    ? await this.adapter.getExamples<string, unknown>({
        input: inputAsString,
        taskType,
        taskId
      })
    : []

  // An example saved for this exact request short-circuits the model call.
  const exactMatch = examples.find((x) => x.key === Key)
  if (exactMatch) {
    return exactMatch.output
  }

  // Built-in example used only when no stored example is available.
  const defaultExample = isArrayOfObjects
    ? {
        input: `The story goes as follow.
Once upon a time, there was a person named Alice who was 30 years old.
Then, there was a person named Bob who was 25 years old.
The end.`,
        schema: 'Array<{ name: string, age: number }>',
        instructions: 'Extract all people',
        extracted: [
          {
            name: 'Alice',
            age: 30
          },
          {
            name: 'Bob',
            age: 25
          }
        ]
      }
    : {
        input: `The story goes as follow.
Once upon a time, there was a person named Alice who was 30 years old.
The end.`,
        schema: '{ name: string, age: number }',
        instructions: 'Extract the person',
        extracted: { name: 'Alice', age: 30 }
      }

  const userExamples = examples.map((e) => ({
    input: e.input,
    extracted: e.output,
    schema: schemaTypescript,
    instructions: options.instructions
  }))

  let exampleId = 1

  // Renders one user message (header + schema + instructions + input).
  // It is used for the examples and for the real input alike.
  const formatInput = (input: string, schema: string, instructions?: string) => {
    const header = userExamples.length
      ? `Expert Example #${exampleId++}`
      : "Here's an example to help you understand the format:"

    return `
${header}

<|start_schema|>
${schema}
<|end_schema|>

<|start_instructions|>
${instructions ?? 'No specific instructions, just follow the schema above.'}
<|end_instructions|>

<|start_input|>
${input.trim()}
<|end_input|>
`.trim()
  }

  // Renders the assistant answer expected for an example: one START/END pair per element, then NO_MORE.
  const formatOutput = (extracted: any) => {
    extracted = _.isArray(extracted) ? extracted : [extracted]

    return (
      extracted
        .map((x) =>
          `
${START}
${JSON.stringify(x, null, 2)}
${END}`.trim()
        )
        .join('\n') + NO_MORE
    )
  }

  const formatExample = (example: { input?: any; schema: string; instructions?: string; extracted: any }) => [
    {
      type: 'text' as const,
      content: formatInput(stringify(example.input ?? null), example.schema, example.instructions),
      role: 'user' as const
    },
    {
      type: 'text' as const,
      content: formatOutput(example.extracted),
      role: 'assistant' as const
    }
  ]

  // Keep as many examples as fit in the remaining token budget.
  const allExamples = takeUntilTokens(
    userExamples.length ? userExamples : [defaultExample],
    EXAMPLES_TOKENS,
    (el) => tokenizer.count(stringify(el.input)) + tokenizer.count(stringify(el.extracted))
  )
    .map(formatExample)
    .flat()

  const output = await this.callModel({
    systemPrompt: `
Extract the following information from the input:
${schemaTypescript}
====

${instructions.map((x) => `• ${x}`).join('\n')}
`.trim(),
    stopSequences: [isArrayOfObjects ? NO_MORE : END],
    messages: [
      ...allExamples,
      {
        role: 'user',
        type: 'text',
        // Pass `undefined` (not '') when there are no instructions so that formatInput's
        // fallback text is used, exactly as it is for the examples.
        content: formatInput(inputAsString, schemaTypescript, options.instructions)
      }
    ]
  })

  const answer = output.choices[0].content as string

  const elements = answer
    .split(START)
    .filter((x) => x.trim().length > 0)
    .map((x) => {
      try {
        // END is the stop sequence in single-object mode, so it is typically absent from the
        // answer. `slice(0, -1)` would then silently drop the last character of the JSON.
        const endIndex = x.indexOf(END)
        const json = (endIndex === -1 ? x : x.slice(0, endIndex)).trim()
        const repairedJson = jsonrepair(json)
        const parsedJson = JSON5.parse(repairedJson)

        return schema.parse(parsedJson)
      } catch (error) {
        throw new JsonParsingError(x, error)
      }
    })
    .filter((x) => x !== null)

  let final: any

  if (isArrayOfObjects) {
    final = elements
  } else if (elements.length === 0) {
    // Nothing was extracted: fall back to whatever the schema accepts for an empty object
    // (this throws if the schema has required fields).
    final = schema.parse({})
  } else {
    final = elements[0]
  }

  if (taskId) {
    await this.adapter.saveExample({
      key: Key,
      // NOTE(review): examples are read above with `taskId` but saved with a `zai/` prefix;
      // verify against how `this.taskId` is built that this is not a double prefix.
      taskId: `zai/${taskId}`,
      taskType,
      instructions: options.instructions ?? 'No specific instructions',
      input: inputAsString,
      output: final,
      metadata: output.metadata
    })
  }

  return final
}
|
|
@@ -1,182 +0,0 @@
|
|
|
1
|
-
import { describe, it, expect, beforeEach, afterEach, afterAll, vi } from 'vitest'
|
|
2
|
-
|
|
3
|
-
import { getClient, getZai, metadata } from './__tests'
|
|
4
|
-
import { TableAdapter } from '../adapters/botpress-table'
|
|
5
|
-
import { Client } from '@botpress/client'
|
|
6
|
-
|
|
7
|
-
// End-to-end tests for `zai.filter`: basic filtering, chunking of oversized inputs, and
// filtering guided by caller-provided examples.
describe('zai.filter', { timeout: 60_000 }, () => {
  let zai = getZai()

  // Start every test from a fresh instance.
  beforeEach(async () => {
    zai = getZai()
  })

  it('basic filter with small items', async () => {
    const people = [
      { name: 'John', description: 'is a bad person' },
      { name: 'Alice', description: 'is a good person' },
      { name: 'Bob', description: 'is a good person' },
      { name: 'Eve', description: 'is a bad person' },
      { name: 'Alex', description: 'is a good person' },
      { name: 'Sara', description: 'donates to charity every month' },
      { name: 'Tom', description: 'commits crimes and is in jail' }
    ]

    const kept = await zai.filter(people, 'generally good people')

    expect(kept.map((person) => person.name)).toMatchInlineSnapshot(`
      [
        "Alice",
        "Bob",
        "Alex",
        "Sara",
      ]
    `)
  })

  it('filtering huge array chunks it up', async () => {
    // Replace the client's `callAction` with a spy so the requests can be inspected.
    const callAction = vi.fn()
    const client = { ...getClient(), callAction } as unknown as Client
    zai = getZai().with({ client })

    // JSON of the request issued at the given position (negative counts from the end).
    const serializedCall = (position: number) => JSON.stringify(callAction.mock.calls.at(position))

    const hugeArray = Array.from({ length: 100 }, (_, i) => ({
      name: `Person #${i}#`,
      description: 'blah blah '.repeat(50_000)
    }))

    // The result is irrelevant here (and the call may reject); only the issued requests matter.
    try {
      await zai.filter(hugeArray, 'generally good people', { tokensPerItem: 100_000 })
    } catch {}

    expect(callAction.mock.calls.length).toBeGreaterThan(20)

    const firstCall = serializedCall(0)
    expect(firstCall).toContain('Person #0#')
    expect(firstCall).not.toContain('Person #99#')

    const lastCall = serializedCall(-1)
    expect(lastCall).not.toContain('Person #0#')
    expect(lastCall).toContain('Person #99#')

    callAction.mockReset()

    // Same input with a small per-item budget: expected to need only two requests.
    try {
      await zai.filter(hugeArray, 'generally good people', { tokensPerItem: 100 })
    } catch {}

    expect(callAction.mock.calls.length).toBe(2)
  })

  it('filter with examples', async () => {
    const competitorExamples = [
      {
        input: 'Rasa (framework)',
        filter: true,
        reason: 'Rasa is a chatbot framework, so it competes with us (Botpress).'
      },
      {
        input: 'Rasa (coffee company)',
        filter: false,
        reason:
          'Rasa (coffee company) is not in the chatbot or AI agent industry, therefore it does not compete with us (Botpress).'
      },
      {
        input: 'Dialogflow',
        filter: true,
        reason: 'Dialogflow is a chatbot development product, so it competes with us (Botpress).'
      }
    ]

    const companies = [{ name: 'Moveworks' }, { name: 'Ada.cx' }, { name: 'Nike' }, { name: 'Voiceflow' }, { name: 'Adidas' }]

    const kept = await zai.filter(companies, 'competes with us', { examples: competitorExamples })

    expect(kept.map((company) => company.name)).toMatchInlineSnapshot(`
      [
        "Moveworks",
        "Ada.cx",
        "Voiceflow",
      ]
    `)
  })
})
|
|
108
|
-
|
|
109
|
-
// Active-learning test for `zai.filter`: approved examples are seeded in a table, then a
// filter call made through `learn(taskId)` must use them and store its own result.
describe('zai.learn.filter', { timeout: 60_000 }, () => {
  const client = getClient()
  const tableName = 'ZaiTestFilterInternalTable'
  const taskId = 'filter'
  let zai = getZai()

  beforeEach(async () => {
    zai = getZai().with({
      activeLearning: { enable: true, taskId, tableName }
    })
  })

  // Best-effort cleanup: failures (e.g. the table was never created) are ignored on purpose.
  afterEach(async () => {
    try {
      await client.deleteTableRows({ table: tableName, deleteAllRows: true })
    } catch {}
  })

  afterAll(async () => {
    try {
      await client.deleteTable({ table: tableName })
    } catch {}
  })

  it('learns a filtering rule from examples', async () => {
    const adapter = new TableAdapter({ client, tableName })

    // Approved examples saved (in this order) before the call under test.
    const seeds = [
      {
        key: 't1',
        input: ['Rasa (framework)', 'Rasa (coffee company)'],
        output: ['Rasa (framework)'],
        explanation: `Rasa is a chatbot framework, so it competes with us (Botpress). We should keep it. Rasa (coffee company) is not in the chatbot or AI agent industry, therefore it does not compete with us (Botpress). We should filter it out.`
      },
      {
        key: 't2',
        input: ['Voiceflow', 'Dialogflow'],
        output: ['Voiceflow', 'Dialogflow'],
        explanation: `Voiceflow is a chatbot development product, so it competes with us (Botpress). We should keep it. Dialogflow is a chatbot development product, so it competes with us (Botpress). We should keep it.`
      }
    ]

    for (const seed of seeds) {
      await adapter.saveExample({
        key: seed.key,
        taskId: `zai/${taskId}`,
        taskType: 'zai.filter',
        instructions: 'competes with us?',
        input: seed.input,
        output: seed.output,
        explanation: seed.explanation,
        metadata,
        status: 'approved'
      })
    }

    const candidates = ['Nike', 'Ada.cx', 'Adidas', 'Moveworks', 'Lululemon']
    const second = await zai.learn(taskId).filter(candidates, 'competes with us? (botpress)')

    expect(second).toMatchInlineSnapshot(`
      [
        "Ada.cx",
        "Moveworks",
      ]
    `)

    // Two seeded rows plus the row written by the call above.
    const { rows } = await client.findTableRows({ table: tableName })
    expect(rows.length).toBe(3)
    expect(rows.at(-1)!.output.value).toEqual(second)
  })
})
|
package/src/operations/filter.ts
DELETED
|
@@ -1,231 +0,0 @@
|
|
|
1
|
-
import { z } from '@bpinternal/zui'
|
|
2
|
-
|
|
3
|
-
import _ from 'lodash'
|
|
4
|
-
import { fastHash, stringify, takeUntilTokens } from '../utils'
|
|
5
|
-
import { Zai } from '../zai'
|
|
6
|
-
import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from './constants'
|
|
7
|
-
|
|
8
|
-
/** A labelled sample: an input value, whether it passes the condition, and optionally why. */
type Example = z.input<typeof Example>
const Example = z.object({
  input: z.any(),
  filter: z.boolean(),
  reason: z.string().optional()
})

/** Caller-facing options accepted by `zai.filter` (input type of the `Options` schema below). */
export type Options = z.input<typeof Options>

/** Runtime schema used to validate and default the options passed to `zai.filter`. */
const Options = z.object({
  // Per-item token budget (1 to 100_000, defaults to 250).
  tokensPerItem: z.number().min(1).max(100_000).optional().describe('The maximum number of tokens per item').default(250),
  // Caller-provided examples; defaults to none.
  examples: z.array(Example).describe('Examples to filter the condition against').default([])
})

declare module '../zai' {
  interface Zai {
    /** Filters elements of an array against a condition */
    filter<T>(input: Array<T>, condition: string, options?: Options): Promise<Array<T>>
  }
}

// Marker that ends the model's list of verdicts; it is also passed as the stop sequence.
const END = '■END■'
|
|
35
|
-
|
|
36
|
-
/**
 * Filters `input`, keeping the elements the model judges to satisfy `condition`.
 *
 * The input is split into chunks (at most 50 items each, bounded by a token budget). Every
 * chunk is sent to the model in parallel, and the kept elements are concatenated in input
 * order. The model answers with `■<index>:true|false` pairs; an element is kept only when its
 * index is reported as `true`.
 *
 * When `this.taskId` is set, stored examples are loaded from `this.adapter` for each chunk and
 * the chunk's result is saved back as a new example.
 *
 * @param input Elements to filter.
 * @param condition Natural-language condition; truncated to its token budget.
 * @param _options Optional `Options` (`tokensPerItem`, `examples`).
 * @returns The elements of `input` that were kept, in their original order.
 */
Zai.prototype.filter = async function (this: Zai, input, condition, _options) {
  const options = Options.parse(_options ?? {})
  const tokenizer = await this.getTokenizer()

  const taskId = this.taskId
  const taskType = 'zai.filter'

  // Token budgets: half of what the model accepts (at least 250) goes to examples, a share to
  // the condition, and the remainder to the items of a chunk.
  const MAX_ITEMS_PER_CHUNK = 50
  const TOKENS_TOTAL_MAX = this.Model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER
  const TOKENS_EXAMPLES_MAX = Math.floor(Math.max(250, TOKENS_TOTAL_MAX * 0.5))
  // `_.clamp(number, lower, upper)`: a quarter of the total, capped by the condition's own
  // token count, then raised to at least 250.
  const TOKENS_CONDITION_MAX = _.clamp(TOKENS_TOTAL_MAX * 0.25, 250, tokenizer.count(condition))
  const TOKENS_INPUT_ARRAY_MAX = TOKENS_TOTAL_MAX - TOKENS_EXAMPLES_MAX - TOKENS_CONDITION_MAX

  condition = tokenizer.truncate(condition, TOKENS_CONDITION_MAX)

  let chunks: Array<typeof input> = []
  let currentChunk: typeof input = []
  let currentChunkTokens = 0

  for (const element of input) {
    // Each element is budgeted at most `tokensPerItem` tokens.
    // NOTE(review): only the count uses the truncated text; the prompt below is rendered from
    // the original element — confirm whether items should be truncated in the prompt as well.
    const elementAsString = tokenizer.truncate(stringify(element, false), options.tokensPerItem)
    const elementTokens = tokenizer.count(elementAsString)

    // Close the current chunk when adding this element would exceed either limit.
    if (currentChunkTokens + elementTokens > TOKENS_INPUT_ARRAY_MAX || currentChunk.length >= MAX_ITEMS_PER_CHUNK) {
      chunks.push(currentChunk)
      currentChunk = []
      currentChunkTokens = 0
    }

    currentChunk.push(element)
    currentChunkTokens += elementTokens
  }

  if (currentChunk.length > 0) {
    chunks.push(currentChunk)
  }

  // An oversized first element closes an empty chunk; drop such empties.
  chunks = chunks.filter((x) => x.length > 0)

  // ■1:true■2:true■3:true

  // Renders a user message: the condition followed by the indexed items.
  const formatInput = (input: Example[], condition: string) => {
    return `
Condition to check:
${condition}

Items (from ■0 to ■${input.length - 1})
==============================
${input.map((x, idx) => `■${idx} = ${stringify(x.input ?? null, false)}`).join('\n')}
`.trim()
  }

  // Renders the assistant answer for a set of examples: verdicts, END, then the reasons.
  const formatExamples = (examples: Example[]) => {
    return `
${examples.map((x, idx) => `■${idx}:${!!x.filter ? 'true' : 'false'}`).join('')}
${END}
====
Here's the reasoning behind each example:
${examples.map((x, idx) => `■${idx}:${!!x.filter ? 'true' : 'false'}:${x.reason ?? 'No reason provided'}`).join('\n')}
`.trim()
  }

  // Built-in examples, used only when neither stored nor caller-provided examples exist.
  const genericExamples: Example[] = [
    {
      input: 'apple',
      filter: true,
      reason: 'Apples are fruits'
    },
    {
      input: 'Apple Inc.',
      filter: false,
      reason: 'Apple Inc. is a company, not a fruit'
    },
    {
      input: 'banana',
      filter: true,
      reason: 'Bananas are fruits'
    },
    {
      input: 'potato',
      filter: false,
      reason: 'Potatoes are vegetables'
    }
  ]

  const genericExamplesMessages = [
    {
      type: 'text' as const,
      content: formatInput(genericExamples, 'is a fruit'),
      role: 'user' as const
    },
    {
      type: 'text' as const,
      content: formatExamples(genericExamples),
      role: 'assistant' as const
    }
  ]

  // Filters a single chunk with one model call and returns the kept elements.
  const filterChunk = async (chunk: typeof input) => {
    // NOTE(review): a saved example's `output` is the array of kept items (see saveExample
    // below), yet it is read back here as a boolean verdict — verify the intended shape.
    const examples = taskId
      ? await this.adapter
          .getExamples<string, unknown>({
            // The Table API can't search for a huge input string
            input: JSON.stringify(chunk).slice(0, 1000),
            taskType,
            taskId
          })
          .then((x) =>
            x.map((y) => ({ filter: y.output as boolean, input: y.input, reason: y.explanation }) satisfies Example)
          )
      : []

    // Stored examples first, then caller-provided ones, as many as fit in the examples budget.
    const allExamples = takeUntilTokens([...examples, ...(options.examples ?? [])], TOKENS_EXAMPLES_MAX, (el) =>
      tokenizer.count(stringify(el.input))
    )

    const exampleMessages = [
      {
        type: 'text' as const,
        content: formatInput(allExamples, condition),
        role: 'user' as const
      },
      {
        type: 'text' as const,
        content: formatExamples(allExamples),
        role: 'assistant' as const
      }
    ]

    const output = await this.callModel({
      systemPrompt: `
You are given a list of items. Your task is to filter out the items that meet the condition below.
You need to return the full list of items with the format:
■x:true■y:false■z:true (where x, y, z are the indices of the items in the list)
You need to start with "■0" and go up to the last index "■${chunk.length - 1}".
If an item meets the condition, you should return ":true", otherwise ":false".

IMPORTANT: Make sure to read the condition and the examples carefully before making your decision.
The condition is: "${condition}"
`.trim(),
      stopSequences: [END],
      messages: [
        // `exampleMessages` always holds two entries, so test the examples themselves:
        // with no example available, fall back to the generic ones instead of an empty pair.
        ...(allExamples.length ? exampleMessages : genericExamplesMessages),
        {
          type: 'text',
          content: formatInput(
            chunk.map((x) => ({ input: x }) as Example),
            condition
          ),
          role: 'user'
        }
      ]
    })

    const answer = output.choices[0].content as string

    // Parse "■<idx>:<true|false>" pairs. Segments that lack a numeric index or a verdict are
    // ignored rather than crashing; the first verdict reported for an index wins.
    const verdicts = new Map<number, boolean>()
    for (const segment of answer.trim().split('■')) {
      if (segment.length === 0) {
        continue
      }

      const [rawIdx = '', rawVerdict = ''] = segment.split(':')
      const idx = parseInt(rawIdx.trim(), 10)

      if (Number.isNaN(idx) || verdicts.has(idx)) {
        continue
      }

      verdicts.set(idx, rawVerdict.toLowerCase().trim() === 'true')
    }

    // Items the model did not mention are dropped.
    const partial = chunk.filter((_, idx) => verdicts.get(idx) ?? false)

    if (taskId) {
      const key = fastHash(
        stringify({
          taskId,
          taskType,
          input: JSON.stringify(chunk),
          condition
        })
      )

      await this.adapter.saveExample({
        key,
        taskType,
        taskId,
        input: JSON.stringify(chunk),
        output: partial,
        instructions: condition,
        metadata: output.metadata
      })
    }

    return partial
  }

  // Chunks are processed concurrently; Promise.all preserves their order.
  const filteredChunks = await Promise.all(chunks.map(filterChunk))

  return filteredChunks.flat()
}
|