@botpress/zai 2.0.16 → 2.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/context.js +131 -0
- package/dist/emitter.js +42 -0
- package/dist/index.d.ts +104 -9
- package/dist/operations/check.js +46 -27
- package/dist/operations/extract.js +63 -46
- package/dist/operations/filter.js +34 -19
- package/dist/operations/label.js +65 -42
- package/dist/operations/rewrite.js +37 -17
- package/dist/operations/summarize.js +32 -13
- package/dist/operations/text.js +28 -8
- package/dist/response.js +82 -0
- package/dist/tokenizer.js +11 -0
- package/e2e/client.ts +43 -29
- package/e2e/data/cache.jsonl +276 -0
- package/package.json +11 -3
- package/src/context.ts +197 -0
- package/src/emitter.ts +49 -0
- package/src/operations/check.ts +99 -49
- package/src/operations/extract.ts +85 -60
- package/src/operations/filter.ts +62 -35
- package/src/operations/label.ts +117 -62
- package/src/operations/rewrite.ts +50 -21
- package/src/operations/summarize.ts +40 -14
- package/src/operations/text.ts +32 -8
- package/src/response.ts +114 -0
- package/src/tokenizer.ts +14 -0

package/dist/context.js
ADDED
@@ -0,0 +1,131 @@
+import { EventEmitter } from "./emitter";
+export class ZaiContext {
+  _startedAt = Date.now();
+  _inputCost = 0;
+  _outputCost = 0;
+  _inputTokens = 0;
+  _outputTokens = 0;
+  _totalCachedResponses = 0;
+  _totalRequests = 0;
+  _totalErrors = 0;
+  _totalResponses = 0;
+  taskId;
+  taskType;
+  modelId;
+  adapter;
+  source;
+  _eventEmitter;
+  controller = new AbortController();
+  _client;
+  constructor(props) {
+    this._client = props.client.clone();
+    this.taskId = props.taskId;
+    this.modelId = props.modelId;
+    this.adapter = props.adapter;
+    this.source = props.source;
+    this.taskType = props.taskType;
+    this._eventEmitter = new EventEmitter();
+    this._client.on("request", () => {
+      this._totalRequests++;
+      this._eventEmitter.emit("update", this.usage);
+    });
+    this._client.on("response", (_req, res) => {
+      this._totalResponses++;
+      if (res.meta.cached) {
+        this._totalCachedResponses++;
+      } else {
+        this._inputTokens += res.meta.tokens.input || 0;
+        this._outputTokens += res.meta.tokens.output || 0;
+        this._inputCost += res.meta.cost.input || 0;
+        this._outputCost += res.meta.cost.output || 0;
+      }
+      this._eventEmitter.emit("update", this.usage);
+    });
+    this._client.on("error", () => {
+      this._totalErrors++;
+      this._eventEmitter.emit("update", this.usage);
+    });
+  }
+  async getModel() {
+    return this._client.getModelDetails(this.modelId);
+  }
+  on(type, listener) {
+    this._eventEmitter.on(type, listener);
+    return this;
+  }
+  clear() {
+    this._eventEmitter.clear();
+  }
+  async generateContent(props) {
+    const maxRetries = Math.max(props.maxRetries ?? 3, 0);
+    const transform = props.transform;
+    let lastError = null;
+    const messages = [...props.messages || []];
+    for (let attempt = 0; attempt <= maxRetries; attempt++) {
+      try {
+        const response = await this._client.generateContent({
+          ...props,
+          messages,
+          signal: this.controller.signal,
+          model: this.modelId,
+          meta: {
+            integrationName: props.meta?.integrationName || "zai",
+            promptCategory: props.meta?.promptCategory || `zai:${this.taskType}`,
+            promptSource: props.meta?.promptSource || `zai:${this.taskType}:${this.taskId ?? "default"}`
+          }
+        });
+        const content = response.output.choices[0]?.content;
+        const str = typeof content === "string" ? content : content?.[0]?.text || "";
+        let output;
+        messages.push({
+          role: "assistant",
+          content: str || "<Invalid output, no content provided>"
+        });
+        if (!transform) {
+          output = str;
+        } else {
+          output = transform(str, response.output);
+        }
+        return { meta: response.meta, output: response.output, text: str, extracted: output };
+      } catch (error) {
+        lastError = error;
+        if (attempt === maxRetries) {
+          throw lastError;
+        }
+        messages.push({
+          role: "user",
+          content: `ERROR PARSING OUTPUT
+
+${lastError.message}.
+
+Please return a valid response addressing the error above.`
+        });
+      }
+    }
+    throw lastError;
+  }
+  get elapsedTime() {
+    return Date.now() - this._startedAt;
+  }
+  get usage() {
+    return {
+      requests: {
+        errors: this._totalErrors,
+        requests: this._totalRequests,
+        responses: this._totalResponses,
+        cached: this._totalCachedResponses,
+        percentage: this._totalRequests > 0 ? (this._totalResponses + this._totalErrors) / this._totalRequests : 0
+      },
+      tokens: {
+        input: this._inputTokens,
+        output: this._outputTokens,
+        total: this._inputTokens + this._outputTokens
+      },
+      cost: {
+        input: this._inputCost,
+        output: this._outputCost,
+        total: this._inputCost + this._outputCost
+      }
+    };
+  }
+}
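
Note (editorial, not part of the package diff): ZaiContext.generateContent wraps Cognitive's generateContent in a retry loop. When the optional transform callback throws (for example because the model output could not be parsed), the error message is pushed back into the conversation as a user message and the request is retried, up to maxRetries attempts (default 3). A minimal TypeScript sketch of that contract, assuming a @botpress/cognitive client instance named cognitive is already available (its construction is not shown in this diff, and the model id below is a placeholder):

  const ctx = new ZaiContext({
    client: cognitive,      // assumed: an existing @botpress/cognitive client
    taskType: 'demo',
    taskId: 'default',
    modelId: 'my-model-id'  // placeholder model id
  })

  const { extracted } = await ctx.generateContent<number>({
    messages: [{ role: 'user', content: 'Reply with a single integer.' }],
    maxRetries: 2,
    transform: (text) => {
      const n = Number.parseInt(text ?? '', 10)
      // throwing here feeds the error back to the model and triggers another attempt
      if (Number.isNaN(n)) throw new Error(`Expected an integer, got: ${text}`)
      return n
    }
  })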

package/dist/emitter.js
ADDED
@@ -0,0 +1,42 @@
+export class EventEmitter {
+  _listeners = {};
+  emit(type, event) {
+    const listeners = this._listeners[type];
+    if (!listeners) {
+      return;
+    }
+    for (const listener of listeners) {
+      listener(event);
+    }
+  }
+  once(type, listener) {
+    const wrapped = (event) => {
+      this.off(type, wrapped);
+      listener(event);
+    };
+    this.on(type, wrapped);
+  }
+  on(type, listener) {
+    if (!this._listeners[type]) {
+      this._listeners[type] = [];
+    }
+    this._listeners[type].push(listener);
+  }
+  off(type, listener) {
+    const listeners = this._listeners[type];
+    if (!listeners) {
+      return;
+    }
+    const index = listeners.indexOf(listener);
+    if (index !== -1) {
+      listeners.splice(index, 1);
+    }
+  }
+  clear(type) {
+    if (type) {
+      delete this._listeners[type];
+    } else {
+      this._listeners = {};
+    }
+  }
+}

package/dist/index.d.ts
CHANGED
@@ -1,4 +1,4 @@
-import { Cognitive, Model, BotpressClientLike } from '@botpress/cognitive';
+import { Cognitive, Model, BotpressClientLike, GenerateContentInput, GenerateContentOutput } from '@botpress/cognitive';
 import { TextTokenizer } from '@bpinternal/thicktoken';
 
 type GenerationMetadata = {
@@ -74,6 +74,99 @@ declare class Zai {
     learn(taskId: string): Zai;
 }
 
+type Meta = Awaited<ReturnType<Cognitive['generateContent']>>['meta'];
+type GenerateContentProps<T> = Omit<GenerateContentInput, 'model' | 'signal'> & {
+    maxRetries?: number;
+    transform?: (text: string | undefined, output: GenerateContentOutput) => T;
+};
+type ZaiContextProps = {
+    client: Cognitive;
+    taskType: string;
+    taskId: string;
+    modelId: string;
+    adapter?: Adapter;
+    source?: GenerateContentInput['meta'];
+};
+type Usage = {
+    requests: {
+        requests: number;
+        errors: number;
+        responses: number;
+        cached: number;
+        percentage: number;
+    };
+    cost: {
+        input: number;
+        output: number;
+        total: number;
+    };
+    tokens: {
+        input: number;
+        output: number;
+        total: number;
+    };
+};
+type ContextEvents = {
+    update: Usage;
+};
+declare class ZaiContext {
+    private _startedAt;
+    private _inputCost;
+    private _outputCost;
+    private _inputTokens;
+    private _outputTokens;
+    private _totalCachedResponses;
+    private _totalRequests;
+    private _totalErrors;
+    private _totalResponses;
+    taskId: string;
+    taskType: string;
+    modelId: GenerateContentInput['model'];
+    adapter?: Adapter;
+    source?: GenerateContentInput['meta'];
+    private _eventEmitter;
+    controller: AbortController;
+    private _client;
+    constructor(props: ZaiContextProps);
+    getModel(): Promise<Model>;
+    on<K extends keyof ContextEvents>(type: K, listener: (event: ContextEvents[K]) => void): this;
+    clear(): void;
+    generateContent<Out = string>(props: GenerateContentProps<Out>): Promise<{
+        meta: Meta;
+        output: GenerateContentOutput;
+        text: string | undefined;
+        extracted: Out;
+    }>;
+    get elapsedTime(): number;
+    get usage(): Usage;
+}
+
+type ResponseEvents<TComplete = any> = {
+    progress: Usage;
+    complete: TComplete;
+    error: unknown;
+};
+declare class Response<T = any, S = T> implements PromiseLike<S> {
+    private _promise;
+    private _eventEmitter;
+    private _context;
+    private _elasped;
+    private _simplify;
+    constructor(context: ZaiContext, promise: Promise<T>, simplify: (value: T) => S);
+    on<K extends keyof ResponseEvents<T>>(type: K, listener: (event: ResponseEvents<T>[K]) => void): this;
+    off<K extends keyof ResponseEvents<T>>(type: K, listener: (event: ResponseEvents<T>[K]) => void): this;
+    once<K extends keyof ResponseEvents<T>>(type: K, listener: (event: ResponseEvents<T>[K]) => void): this;
+    bindSignal(signal: AbortSignal): this;
+    abort(reason?: string | Error): void;
+    then<TResult1 = S, TResult2 = never>(onfulfilled?: ((value: S) => TResult1 | PromiseLike<TResult1>) | null, onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null): PromiseLike<TResult1 | TResult2>;
+    catch<TResult = never>(onrejected?: ((reason: any) => TResult | PromiseLike<TResult>) | null): PromiseLike<S | TResult>;
+    result(): Promise<{
+        output: T;
+        usage: Usage;
+        elapsed: number;
+    }>;
+}
+
 type Options$6 = {
     /** The maximum number of tokens to generate */
     length?: number;
@@ -81,7 +174,7 @@ type Options$6 = {
 declare module '@botpress/zai' {
     interface Zai {
         /** Generates a text of the desired length according to the prompt */
-        text(prompt: string, options?: Options$6):
+        text(prompt: string, options?: Options$6): Response<string>;
     }
 }
 
@@ -99,7 +192,7 @@ type Options$5 = {
 declare module '@botpress/zai' {
     interface Zai {
         /** Rewrites a string according to match the prompt */
-        rewrite(original: string, prompt: string, options?: Options$5):
+        rewrite(original: string, prompt: string, options?: Options$5): Response<string>;
     }
 }
 
@@ -123,7 +216,7 @@ type Options$4 = {
 declare module '@botpress/zai' {
     interface Zai {
         /** Summarizes a text of any length to a summary of the desired length */
-        summarize(original: string, options?: Options$4):
+        summarize(original: string, options?: Options$4): Response<string>;
     }
 }
 
@@ -140,12 +233,12 @@ type Options$3 = {
 declare module '@botpress/zai' {
     interface Zai {
         /** Checks wether a condition is true or not */
-        check(input: unknown, condition: string, options?: Options$3):
+        check(input: unknown, condition: string, options?: Options$3): Response<{
            /** Whether the condition is true or not */
            value: boolean;
            /** The explanation of the decision */
            explanation: string;
-        }>;
+        }, boolean>;
     }
 }
 
@@ -163,7 +256,7 @@ type Options$2 = {
 declare module '@botpress/zai' {
     interface Zai {
         /** Filters elements of an array against a condition */
-        filter<T>(input: Array<T>, condition: string, options?: Options$2):
+        filter<T>(input: Array<T>, condition: string, options?: Options$2): Response<Array<T>>;
     }
 }
 
@@ -182,7 +275,7 @@ type OfType<O, T extends __Z = __Z<O>> = T extends __Z<O> ? T : never;
 declare module '@botpress/zai' {
     interface Zai {
         /** Extracts one or many elements from an arbitrary input */
-        extract<S extends OfType<any>>(input: unknown, schema: S, options?: Options$1):
+        extract<S extends OfType<any>>(input: unknown, schema: S, options?: Options$1): Response<S['_output']>;
     }
 }
 
@@ -213,12 +306,14 @@ type Labels<T extends string> = Record<T, string>;
 declare module '@botpress/zai' {
     interface Zai {
         /** Tags the provided input with a list of predefined labels */
-        label<T extends string>(input: unknown, labels: Labels<T>, options?: Options<T>):
+        label<T extends string>(input: unknown, labels: Labels<T>, options?: Options<T>): Response<{
            [K in T]: {
                explanation: string;
                value: boolean;
                confidence: number;
            };
+        }, {
+            [K in T]: boolean;
         }>;
     }
 }
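
Note (editorial, not part of the package diff): every Zai operation above now returns a Response<T, S> instead of a plain Promise. Awaiting the Response yields the simplified value S (for check, a boolean), while .result() resolves to the full output T together with token/cost usage and elapsed time; .on('progress' | 'complete' | 'error') exposes live events, and abort() / bindSignal() cancel the underlying requests. A usage sketch, assuming an already constructed Zai instance named zai and placeholder inputs (email, schema):

  // awaiting yields the simplified value (a boolean for check)
  const isSpam = await zai.check(email, 'is a spam email')

  // .result() exposes the full output plus usage and elapsed time
  const pending = zai.check(email, 'is a spam email')
  pending.on('progress', (usage) => console.log(`${usage.tokens.total} tokens so far`))
  const { output, usage, elapsed } = await pending.result()
  console.log(output.value, output.explanation, usage.cost.total, elapsed)

  // requests can be cancelled explicitly or tied to an AbortSignal
  const controller = new AbortController()
  const extraction = zai.extract(email, schema).bindSignal(controller.signal)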

package/dist/operations/check.js
CHANGED
@@ -1,4 +1,7 @@
 import { z } from "@bpinternal/zui";
+import { ZaiContext } from "../context";
+import { Response } from "../response";
+import { getTokenizer } from "../tokenizer";
 import { fastHash, stringify, takeUntilTokens } from "../utils";
 import { Zai } from "../zai";
 import { PROMPT_INPUT_BUFFER } from "./constants";
@@ -14,12 +17,12 @@ const _Options = z.object({
 const TRUE = "\u25A0TRUE\u25A0";
 const FALSE = "\u25A0FALSE\u25A0";
 const END = "\u25A0END\u25A0";
-
-
-const tokenizer = await
-await
-const PROMPT_COMPONENT = Math.max(
-const taskId =
+const check = async (input, condition, options, ctx) => {
+  ctx.controller.signal.throwIfAborted();
+  const tokenizer = await getTokenizer();
+  const model = await ctx.getModel();
+  const PROMPT_COMPONENT = Math.max(model.input.maxTokens - PROMPT_INPUT_BUFFER, 100);
+  const taskId = ctx.taskId;
   const taskType = "zai.check";
   const PROMPT_TOKENS = {
     INPUT: Math.floor(0.5 * PROMPT_COMPONENT),
@@ -36,7 +39,7 @@ Zai.prototype.check = async function(input, condition, _options) {
       condition
     })
   );
-  const examples = taskId ? await
+  const examples = taskId && ctx.adapter ? await ctx.adapter.getExamples({
    input: inputAsString,
    taskType,
    taskId
@@ -73,10 +76,10 @@ ${input2.trim()}
 <|end_input|>
 `.trim();
   };
-  const formatOutput = (
+  const formatOutput = (answer, justification) => {
    return `
 Analysis: ${justification}
-Final Answer: ${
+Final Answer: ${answer ? TRUE : FALSE}
 ${END}
 `.trim();
   };
@@ -103,7 +106,10 @@ ${END}
 - When in doubt, ground your decision on the examples provided by the experts instead of your own intuition.
 - When no example is similar to the input, make sure to provide a clear justification for your decision while inferring the decision-making process from the examples provided by the experts.
 `.trim() : "";
-  const {
+  const {
+    extracted: { finalAnswer, explanation },
+    meta
+  } = await ctx.generateContent({
    systemPrompt: `
 Check if the following condition is true or false for the given input. Before answering, make sure to read the input and the condition carefully.
 Justify your answer, then answer with either ${TRUE} or ${FALSE} at the very end, then add ${END} to finish the response.
@@ -123,23 +129,25 @@ ${formatInput(inputAsString, condition)}
 In your "Analysis", please refer to the Expert Examples # to justify your decision.`.trim(),
        role: "user"
      }
-    ]
+    ],
+    transform: (text) => {
+      const hasTrue = text.includes(TRUE);
+      const hasFalse = text.includes(FALSE);
+      if (!hasTrue && !hasFalse) {
+        throw new Error(`The model did not return a valid answer. The response was: ${text}`);
+      }
+      let finalAnswer2;
+      const explanation2 = text.replace(TRUE, "").replace(FALSE, "").replace(END, "").replace("Final Answer:", "").replace("Analysis:", "").trim();
+      if (hasTrue && hasFalse) {
+        finalAnswer2 = text.lastIndexOf(TRUE) > text.lastIndexOf(FALSE);
+      } else {
+        finalAnswer2 = hasTrue;
+      }
+      return { finalAnswer: finalAnswer2, explanation: explanation2.trim() };
+    }
   });
-
-
-  const hasFalse = answer.includes(FALSE);
-  if (!hasTrue && !hasFalse) {
-    throw new Error(`The model did not return a valid answer. The response was: ${answer}`);
-  }
-  let finalAnswer;
-  const explanation = answer.replace(TRUE, "").replace(FALSE, "").replace(END, "").replace("Final Answer:", "").replace("Analysis:", "").trim();
-  if (hasTrue && hasFalse) {
-    finalAnswer = answer.lastIndexOf(TRUE) > answer.lastIndexOf(FALSE);
-  } else {
-    finalAnswer = hasTrue;
-  }
-  if (taskId) {
-    await this.adapter.saveExample({
+  if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
+    await ctx.adapter.saveExample({
      key: Key,
      taskType,
      taskId,
@@ -151,7 +159,7 @@ In your "Analysis", please refer to the Expert Examples # to justify your decisi
        output: meta.cost.output
      },
      latency: meta.latency,
-      model:
+      model: ctx.modelId,
      tokens: {
        input: meta.tokens.input,
        output: meta.tokens.output
@@ -166,3 +174,14 @@ In your "Analysis", please refer to the Expert Examples # to justify your decisi
    explanation: explanation.trim()
  };
 };
+Zai.prototype.check = function(input, condition, _options) {
+  const options = _Options.parse(_options ?? {});
+  const context = new ZaiContext({
+    client: this.client,
+    modelId: this.Model,
+    taskId: this.taskId,
+    taskType: "zai.check",
+    adapter: this.adapter
+  });
+  return new Response(context, check(input, condition, options, context), (result) => result.value);
+};

package/dist/operations/extract.js
CHANGED
@@ -2,6 +2,9 @@ import { z } from "@bpinternal/zui";
 import JSON5 from "json5";
 import { jsonrepair } from "jsonrepair";
 import { chunk, isArray } from "lodash-es";
+import { ZaiContext } from "../context";
+import { Response } from "../response";
+import { getTokenizer } from "../tokenizer";
 import { fastHash, stringify, takeUntilTokens } from "../utils";
 import { Zai } from "../zai";
 import { PROMPT_INPUT_BUFFER } from "./constants";
@@ -14,14 +17,15 @@ const Options = z.object({
 const START = "\u25A0json_start\u25A0";
 const END = "\u25A0json_end\u25A0";
 const NO_MORE = "\u25A0NO_MORE_ELEMENT\u25A0";
-
+const extract = async (input, _schema, _options, ctx) => {
+  ctx.controller.signal.throwIfAborted();
   let schema = _schema;
   const options = Options.parse(_options ?? {});
-  const tokenizer = await
-  await
-  const taskId =
+  const tokenizer = await getTokenizer();
+  const model = await ctx.getModel();
+  const taskId = ctx.taskId;
   const taskType = "zai.extract";
-  const PROMPT_COMPONENT = Math.max(
+  const PROMPT_COMPONENT = Math.max(model.input.maxTokens - PROMPT_INPUT_BUFFER, 100);
   let isArrayOfObjects = false;
   let wrappedValue = false;
   const originalSchema = schema;
@@ -54,10 +58,7 @@ Zai.prototype.extract = async function(input, _schema, _options) {
   }
   const schemaTypescript = schema.toTypescriptType({ declaration: false });
   const schemaLength = tokenizer.count(schemaTypescript);
-  options.chunkLength = Math.min(
-    options.chunkLength,
-    this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER - schemaLength
-  );
+  options.chunkLength = Math.min(options.chunkLength, model.input.maxTokens - PROMPT_INPUT_BUFFER - schemaLength);
   const keys = Object.keys(schema.shape);
   const inputAsString = stringify(input);
   if (tokenizer.count(inputAsString) > options.chunkLength) {
@@ -65,19 +66,25 @@ Zai.prototype.extract = async function(input, _schema, _options) {
     const chunks = chunk(tokens, options.chunkLength).map((x) => x.join(""));
     const all = await Promise.allSettled(
       chunks.map(
-        (chunk2) =>
-
-
-
-
+        (chunk2) => extract(
+          chunk2,
+          originalSchema,
+          {
+            ...options,
+            strict: false
+            // We don't want to fail on strict mode for sub-chunks
+          },
+          ctx
+        )
       )
     ).then(
       (results) => results.filter((x) => x.status === "fulfilled").map((x) => x.value)
     );
+    ctx.controller.signal.throwIfAborted();
     const rows = all.map((x, idx) => `<part-${idx + 1}>
 ${stringify(x, true)}
 </part-${idx + 1}>`).join("\n");
-    return
+    return extract(
       `
 The result has been split into ${all.length} parts. Recursively merge the result into the final result.
 When merging arrays, take unique values.
@@ -89,7 +96,8 @@ ${rows}
 
 Merge it back into a final result.`.trim(),
       originalSchema,
-      options
+      options,
+      ctx
     );
   }
   const instructions = [];
@@ -123,7 +131,7 @@ Merge it back into a final result.`.trim(),
      instructions: options.instructions
    })
  );
-  const examples = taskId ? await
+  const examples = taskId && ctx.adapter ? await ctx.adapter.getExamples({
    input: inputAsString,
    taskType,
    taskId
@@ -182,9 +190,9 @@ ${input2.trim()}
 <|end_input|>
 `.trim();
   };
-  const formatOutput = (
-
-  return
+  const formatOutput = (extracted2) => {
+    extracted2 = isArray(extracted2) ? extracted2 : [extracted2];
+    return extracted2.map(
      (x) => `
 ${START}
 ${JSON.stringify(x, null, 2)}
@@ -208,7 +216,7 @@ ${END}`.trim()
    EXAMPLES_TOKENS,
    (el) => tokenizer.count(stringify(el.input)) + tokenizer.count(stringify(el.extracted))
  ).map(formatExample).flat();
-  const {
+  const { meta, extracted } = await ctx.generateContent({
    systemPrompt: `
 Extract the following information from the input:
 ${schemaTypescript}
@@ -224,33 +232,32 @@ ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
        type: "text",
        content: formatInput(inputAsString, schemaTypescript, options.instructions ?? "")
      }
-    ]
-
-
-
-
-
-
-
-
-
-
-
-
-
+    ],
+    transform: (text) => (text || "{}")?.split(START).filter((x) => x.trim().length > 0 && x.includes("}")).map((x) => {
+      try {
+        const json = x.slice(0, x.indexOf(END)).trim();
+        const repairedJson = jsonrepair(json);
+        const parsedJson = JSON5.parse(repairedJson);
+        const safe = schema.safeParse(parsedJson);
+        if (safe.success) {
+          return safe.data;
+        }
+        if (options.strict) {
+          throw new JsonParsingError(x, safe.error);
+        }
+        return parsedJson;
+      } catch (error) {
+        throw new JsonParsingError(x, error instanceof Error ? error : new Error("Unknown error"));
      }
-
-
-throw new JsonParsingError(x, error instanceof Error ? error : new Error("Unknown error"));
-}
-}).filter((x) => x !== null);
+    }).filter((x) => x !== null)
+  });
   let final;
   if (isArrayOfObjects) {
-    final =
-  } else if (
+    final = extracted;
+  } else if (extracted.length === 0) {
     final = options.strict ? schema.parse({}) : {};
   } else {
-    final =
+    final = extracted[0];
   }
   if (wrappedValue) {
     if (Array.isArray(final)) {
@@ -259,8 +266,8 @@
      final = "value" in final ? final.value : final;
    }
  }
-  if (taskId) {
-    await
+  if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
+    await ctx.adapter.saveExample({
      key: Key,
      taskId: `zai/${taskId}`,
      taskType,
@@ -273,7 +280,7 @@
        output: meta.cost.output
      },
      latency: meta.latency,
-      model:
+      model: ctx.modelId,
      tokens: {
        input: meta.tokens.input,
        output: meta.tokens.output
@@ -283,3 +290,13 @@
  }
  return final;
 };
+Zai.prototype.extract = function(input, schema, _options) {
+  const context = new ZaiContext({
+    client: this.client,
+    modelId: this.Model,
+    taskId: this.taskId,
+    taskType: "zai.extract",
+    adapter: this.adapter
+  });
+  return new Response(context, extract(input, schema, _options, context), (result) => result);
+};
|