@botpress/zai 2.0.14 → 2.0.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.d.ts +3 -2
- package/dist/operations/extract.js +72 -20
- package/e2e/data/cache.jsonl +157 -0
- package/e2e/utils.ts +3 -2
- package/package.json +1 -1
- package/src/operations/extract.ts +77 -24
package/dist/index.d.ts
CHANGED
|
@@ -172,16 +172,17 @@ type Options$1 = {
|
|
|
172
172
|
instructions?: string;
|
|
173
173
|
/** The maximum number of tokens per chunk */
|
|
174
174
|
chunkLength?: number;
|
|
175
|
+
/** Whether to strictly follow the schema or not */
|
|
176
|
+
strict?: boolean;
|
|
175
177
|
};
|
|
176
178
|
type __Z<T extends any = any> = {
|
|
177
179
|
_output: T;
|
|
178
180
|
};
|
|
179
181
|
type OfType<O, T extends __Z = __Z<O>> = T extends __Z<O> ? T : never;
|
|
180
|
-
type AnyObjectOrArray = Record<string, unknown> | Array<unknown>;
|
|
181
182
|
declare module '@botpress/zai' {
|
|
182
183
|
interface Zai {
|
|
183
184
|
/** Extracts one or many elements from an arbitrary input */
|
|
184
|
-
extract<S extends OfType<
|
|
185
|
+
extract<S extends OfType<any>>(input: unknown, schema: S, options?: Options$1): Promise<S['_output']>;
|
|
185
186
|
}
|
|
186
187
|
}
|
|
187
188
|
|
|
@@ -8,7 +8,8 @@ import { PROMPT_INPUT_BUFFER } from "./constants";
|
|
|
8
8
|
import { JsonParsingError } from "./errors";
|
|
9
9
|
const Options = z.object({
|
|
10
10
|
instructions: z.string().optional().describe("Instructions to guide the user on how to extract the data"),
|
|
11
|
-
chunkLength: z.number().min(100).max(1e5).optional().describe("The maximum number of tokens per chunk").default(16e3)
|
|
11
|
+
chunkLength: z.number().min(100).max(1e5).optional().describe("The maximum number of tokens per chunk").default(16e3),
|
|
12
|
+
strict: z.boolean().optional().default(true).describe("Whether to strictly follow the schema or not")
|
|
12
13
|
});
|
|
13
14
|
const START = "\u25A0json_start\u25A0";
|
|
14
15
|
const END = "\u25A0json_end\u25A0";
|
|
@@ -22,22 +23,34 @@ Zai.prototype.extract = async function(input, _schema, _options) {
|
|
|
22
23
|
const taskType = "zai.extract";
|
|
23
24
|
const PROMPT_COMPONENT = Math.max(this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER, 100);
|
|
24
25
|
let isArrayOfObjects = false;
|
|
26
|
+
let wrappedValue = false;
|
|
25
27
|
const originalSchema = schema;
|
|
26
28
|
const baseType = (schema.naked ? schema.naked() : schema)?.constructor?.name ?? "unknown";
|
|
27
|
-
if (baseType === "
|
|
28
|
-
|
|
29
|
+
if (baseType === "ZodArray") {
|
|
30
|
+
isArrayOfObjects = true;
|
|
29
31
|
let elementType = schema.element;
|
|
30
32
|
if (elementType.naked) {
|
|
31
33
|
elementType = elementType.naked();
|
|
32
34
|
}
|
|
33
35
|
if (elementType?.constructor?.name === "ZodObject") {
|
|
34
|
-
isArrayOfObjects = true;
|
|
35
36
|
schema = elementType;
|
|
36
37
|
} else {
|
|
37
|
-
|
|
38
|
+
wrappedValue = true;
|
|
39
|
+
schema = z.object({
|
|
40
|
+
value: elementType
|
|
41
|
+
});
|
|
42
|
+
}
|
|
43
|
+
} else if (baseType !== "ZodObject") {
|
|
44
|
+
wrappedValue = true;
|
|
45
|
+
schema = z.object({
|
|
46
|
+
value: originalSchema
|
|
47
|
+
});
|
|
48
|
+
}
|
|
49
|
+
if (!options.strict) {
|
|
50
|
+
try {
|
|
51
|
+
schema = schema.partial();
|
|
52
|
+
} catch {
|
|
38
53
|
}
|
|
39
|
-
} else {
|
|
40
|
-
throw new Error("Schema must be either a ZuiObject or a ZuiArray<ZuiObject>");
|
|
41
54
|
}
|
|
42
55
|
const schemaTypescript = schema.toTypescriptType({ declaration: false });
|
|
43
56
|
const schemaLength = tokenizer.count(schemaTypescript);
|
|
@@ -46,16 +59,38 @@ Zai.prototype.extract = async function(input, _schema, _options) {
|
|
|
46
59
|
this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER - schemaLength
|
|
47
60
|
);
|
|
48
61
|
const keys = Object.keys(schema.shape);
|
|
49
|
-
|
|
62
|
+
const inputAsString = stringify(input);
|
|
50
63
|
if (tokenizer.count(inputAsString) > options.chunkLength) {
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
64
|
+
const tokens = tokenizer.split(inputAsString);
|
|
65
|
+
const chunks = chunk(tokens, options.chunkLength).map((x) => x.join(""));
|
|
66
|
+
const all = await Promise.allSettled(
|
|
67
|
+
chunks.map(
|
|
68
|
+
(chunk2) => this.extract(chunk2, originalSchema, {
|
|
69
|
+
...options,
|
|
70
|
+
strict: false
|
|
71
|
+
// We don't want to fail on strict mode for sub-chunks
|
|
72
|
+
})
|
|
73
|
+
)
|
|
74
|
+
).then(
|
|
75
|
+
(results) => results.filter((x) => x.status === "fulfilled").map((x) => x.value)
|
|
76
|
+
);
|
|
77
|
+
const rows = all.map((x, idx) => `<part-${idx + 1}>
|
|
78
|
+
${stringify(x, true)}
|
|
79
|
+
</part-${idx + 1}>`).join("\n");
|
|
80
|
+
return this.extract(
|
|
81
|
+
`
|
|
82
|
+
The result has been split into ${all.length} parts. Recursively merge the result into the final result.
|
|
83
|
+
When merging arrays, take unique values.
|
|
84
|
+
When merging conflictual (but defined) information, take the most reasonable and frequent value.
|
|
85
|
+
Non-defined values are OK and normal. Don't delete fields because of null values. Focus on defined values.
|
|
86
|
+
|
|
87
|
+
Here's the data:
|
|
88
|
+
${rows}
|
|
89
|
+
|
|
90
|
+
Merge it back into a final result.`.trim(),
|
|
91
|
+
originalSchema,
|
|
92
|
+
options
|
|
93
|
+
);
|
|
59
94
|
}
|
|
60
95
|
const instructions = [];
|
|
61
96
|
if (options.instructions) {
|
|
@@ -76,6 +111,9 @@ Zai.prototype.extract = async function(input, _schema, _options) {
|
|
|
76
111
|
instructions.push("You may have exactly one element in the input.");
|
|
77
112
|
instructions.push(`The element must be a JSON object with exactly the format: ${START}${shape}${END}`);
|
|
78
113
|
}
|
|
114
|
+
if (!options.strict) {
|
|
115
|
+
instructions.push("You may ignore any fields that are not present in the input. All keys are optional.");
|
|
116
|
+
}
|
|
79
117
|
const EXAMPLES_TOKENS = PROMPT_COMPONENT - tokenizer.count(inputAsString) - tokenizer.count(instructions.join("\n"));
|
|
80
118
|
const Key = fastHash(
|
|
81
119
|
JSON.stringify({
|
|
@@ -188,13 +226,20 @@ ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
|
|
|
188
226
|
}
|
|
189
227
|
]
|
|
190
228
|
});
|
|
191
|
-
const answer = output.choices[0]?.content;
|
|
192
|
-
const elements = answer
|
|
229
|
+
const answer = output.choices[0]?.content ?? "{}";
|
|
230
|
+
const elements = answer?.split(START).filter((x) => x.trim().length > 0 && x.includes("}")).map((x) => {
|
|
193
231
|
try {
|
|
194
232
|
const json = x.slice(0, x.indexOf(END)).trim();
|
|
195
233
|
const repairedJson = jsonrepair(json);
|
|
196
234
|
const parsedJson = JSON5.parse(repairedJson);
|
|
197
|
-
|
|
235
|
+
const safe = schema.safeParse(parsedJson);
|
|
236
|
+
if (safe.success) {
|
|
237
|
+
return safe.data;
|
|
238
|
+
}
|
|
239
|
+
if (options.strict) {
|
|
240
|
+
throw new JsonParsingError(x, safe.error);
|
|
241
|
+
}
|
|
242
|
+
return parsedJson;
|
|
198
243
|
} catch (error) {
|
|
199
244
|
throw new JsonParsingError(x, error instanceof Error ? error : new Error("Unknown error"));
|
|
200
245
|
}
|
|
@@ -203,10 +248,17 @@ ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
|
|
|
203
248
|
if (isArrayOfObjects) {
|
|
204
249
|
final = elements;
|
|
205
250
|
} else if (elements.length === 0) {
|
|
206
|
-
final = schema.parse({});
|
|
251
|
+
final = options.strict ? schema.parse({}) : {};
|
|
207
252
|
} else {
|
|
208
253
|
final = elements[0];
|
|
209
254
|
}
|
|
255
|
+
if (wrappedValue) {
|
|
256
|
+
if (Array.isArray(final)) {
|
|
257
|
+
final = final.map((x) => "value" in x ? x.value : x);
|
|
258
|
+
} else {
|
|
259
|
+
final = "value" in final ? final.value : final;
|
|
260
|
+
}
|
|
261
|
+
}
|
|
210
262
|
if (taskId) {
|
|
211
263
|
await this.adapter.saveExample({
|
|
212
264
|
key: Key,
|