@botpress/zai 2.0.16 → 2.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/context.js +131 -0
- package/dist/emitter.js +42 -0
- package/dist/index.d.ts +104 -9
- package/dist/operations/check.js +46 -27
- package/dist/operations/extract.js +63 -46
- package/dist/operations/filter.js +34 -19
- package/dist/operations/label.js +65 -42
- package/dist/operations/rewrite.js +37 -17
- package/dist/operations/summarize.js +32 -13
- package/dist/operations/text.js +28 -8
- package/dist/response.js +82 -0
- package/dist/tokenizer.js +11 -0
- package/e2e/client.ts +43 -29
- package/e2e/data/cache.jsonl +276 -0
- package/package.json +11 -3
- package/src/context.ts +197 -0
- package/src/emitter.ts +49 -0
- package/src/operations/check.ts +99 -49
- package/src/operations/extract.ts +85 -60
- package/src/operations/filter.ts +62 -35
- package/src/operations/label.ts +117 -62
- package/src/operations/rewrite.ts +50 -21
- package/src/operations/summarize.ts +40 -14
- package/src/operations/text.ts +32 -8
- package/src/response.ts +114 -0
- package/src/tokenizer.ts +14 -0
package/dist/operations/filter.js
CHANGED
@@ -1,5 +1,8 @@
 import { z } from "@bpinternal/zui";
 import { clamp } from "lodash-es";
+import { ZaiContext } from "../context";
+import { Response } from "../response";
+import { getTokenizer } from "../tokenizer";
 import { fastHash, stringify, takeUntilTokens } from "../utils";
 import { Zai } from "../zai";
 import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from "./constants";
@@ -13,14 +16,15 @@ const _Options = z.object({
   examples: z.array(_Example).describe("Examples to filter the condition against").default([])
 });
 const END = "\u25A0END\u25A0";
-
+const filter = async (input, condition, _options, ctx) => {
+  ctx.controller.signal.throwIfAborted();
   const options = _Options.parse(_options ?? {});
-  const tokenizer = await
-  await
-  const taskId =
+  const tokenizer = await getTokenizer();
+  const model = await ctx.getModel();
+  const taskId = ctx.taskId;
   const taskType = "zai.filter";
   const MAX_ITEMS_PER_CHUNK = 50;
-  const TOKENS_TOTAL_MAX =
+  const TOKENS_TOTAL_MAX = model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER;
   const TOKENS_EXAMPLES_MAX = Math.floor(Math.max(250, TOKENS_TOTAL_MAX * 0.5));
   const TOKENS_CONDITION_MAX = clamp(TOKENS_TOTAL_MAX * 0.25, 250, tokenizer.count(condition));
   const TOKENS_INPUT_ARRAY_MAX = TOKENS_TOTAL_MAX - TOKENS_EXAMPLES_MAX - TOKENS_CONDITION_MAX;
@@ -97,7 +101,7 @@ ${examples.map((x, idx) => `\u25A0${idx}:${!!x.filter ? "true" : "false"}:${x.re
     }
   ];
   const filterChunk = async (chunk) => {
-    const examples = taskId ? await
+    const examples = taskId && ctx.adapter ? await ctx.adapter.getExamples({
      // The Table API can't search for a huge input string
      input: JSON.stringify(chunk).slice(0, 1e3),
      taskType,
@@ -122,7 +126,7 @@ ${examples.map((x, idx) => `\u25A0${idx}:${!!x.filter ? "true" : "false"}:${x.re
        role: "assistant"
      }
    ];
-    const {
+    const { extracted: partial, meta } = await ctx.generateContent({
      systemPrompt: `
 You are given a list of items. Your task is to filter out the items that meet the condition below.
 You need to return the full list of items with the format:
@@ -144,17 +148,18 @@ The condition is: "${condition}"
        ),
        role: "user"
      }
-    ]
-
-
-
-
-
-
-
-
+    ],
+    transform: (text) => {
+      const indices = text.trim().split("\u25A0").filter((x) => x.length > 0).map((x) => {
+        const [idx, filter2] = x.split(":");
+        return { idx: parseInt(idx?.trim() ?? ""), filter: filter2?.toLowerCase().trim() === "true" };
+      });
+      return chunk.filter((_, idx) => {
+        return indices.find((x) => x.idx === idx && x.filter) ?? false;
+      });
+    }
   });
-  if (taskId) {
+  if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
    const key = fastHash(
      stringify({
        taskId,
@@ -163,7 +168,7 @@ The condition is: "${condition}"
        condition
      })
    );
-    await
+    await ctx.adapter.saveExample({
      key,
      taskType,
      taskId,
@@ -176,7 +181,7 @@ The condition is: "${condition}"
        output: meta.cost.output
      },
      latency: meta.latency,
-      model:
+      model: ctx.modelId,
      tokens: {
        input: meta.tokens.input,
        output: meta.tokens.output
@@ -189,3 +194,13 @@ The condition is: "${condition}"
   const filteredChunks = await Promise.all(chunks.map(filterChunk));
   return filteredChunks.flat();
 };
+Zai.prototype.filter = function(input, condition, _options) {
+  const context = new ZaiContext({
+    client: this.client,
+    modelId: this.Model,
+    taskId: this.taskId,
+    taskType: "zai.filter",
+    adapter: this.adapter
+  });
+  return new Response(context, filter(input, condition, _options, context), (result) => result);
+};
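The pattern in filter.js repeats across every operation in this release: the core logic moves into a standalone async function that receives a ZaiContext, and the Zai.prototype.* wrapper builds that context and hands the pending work to a Response. A consumer-side sketch of what this enables, assuming an already-configured Zai instance named `zai` (construction is outside this diff):

    // Sketch only; `zai` and `emails` are assumed, not part of the diff.
    declare const zai: any
    declare const emails: string[]

    const res = zai.filter(emails, 'mentions an invoice number')

    // New in 2.1.0: usage updates stream out while the call runs.
    res.on('progress', (usage: unknown) => console.log('usage so far:', usage))

    // Response is thenable, so existing `await` call sites keep working:
    const kept = await res

    // result() additionally reports accumulated usage and elapsed time.
    const { output, usage, elapsed } = await res.result()

Awaiting the Response yields the simplified value (for filter, the kept items), while result() exposes the bookkeeping the ZaiContext accumulates.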
package/dist/operations/label.js
CHANGED
@@ -1,5 +1,8 @@
 import { z } from "@bpinternal/zui";
-import {
+import { chunk, clamp } from "lodash-es";
+import { ZaiContext } from "../context";
+import { Response } from "../response";
+import { getTokenizer } from "../tokenizer";
 import { fastHash, stringify, takeUntilTokens } from "../utils";
 import { Zai } from "../zai";
 import { PROMPT_INPUT_BUFFER } from "./constants";
@@ -39,24 +42,24 @@ const _Labels = z.record(z.string().min(1).max(250), z.string()).superRefine((la
   }
   return true;
 });
-const parseLabel = (
-
-  if (
+const parseLabel = (label2) => {
+  label2 = label2.toUpperCase().replace(/\s+/g, "_").replace(/_{2,}/g, "_").trim();
+  if (label2.includes("ABSOLUTELY") && label2.includes("NOT")) {
     return LABELS.ABSOLUTELY_NOT;
-  } else if (
+  } else if (label2.includes("NOT")) {
     return LABELS.PROBABLY_NOT;
-  } else if (
+  } else if (label2.includes("AMBIGUOUS")) {
     return LABELS.AMBIGUOUS;
   }
-  if (
+  if (label2.includes("YES")) {
     return LABELS.PROBABLY_YES;
-  } else if (
+  } else if (label2.includes("ABSOLUTELY") && label2.includes("YES")) {
     return LABELS.ABSOLUTELY_YES;
   }
   return LABELS.AMBIGUOUS;
 };
-const getConfidence = (
-  switch (
+const getConfidence = (label2) => {
+  switch (label2) {
     case LABELS.ABSOLUTELY_NOT:
     case LABELS.ABSOLUTELY_YES:
       return 1;
@@ -67,14 +70,15 @@ const getConfidence = (label) => {
       return 0;
   }
 };
-
+const label = async (input, _labels, _options, ctx) => {
+  ctx.controller.signal.throwIfAborted();
   const options = _Options.parse(_options ?? {});
   const labels = _Labels.parse(_labels);
-  const tokenizer = await
-  await
-  const taskId =
+  const tokenizer = await getTokenizer();
+  const model = await ctx.getModel();
+  const taskId = ctx.taskId;
   const taskType = "zai.label";
-  const TOTAL_MAX_TOKENS = clamp(options.chunkLength, 1e3,
+  const TOTAL_MAX_TOKENS = clamp(options.chunkLength, 1e3, model.input.maxTokens - PROMPT_INPUT_BUFFER);
   const CHUNK_EXAMPLES_MAX_TOKENS = clamp(Math.floor(TOTAL_MAX_TOKENS * 0.5), 250, 1e4);
   const CHUNK_INPUT_MAX_TOKENS = clamp(
     TOTAL_MAX_TOKENS - CHUNK_EXAMPLES_MAX_TOKENS,
@@ -85,7 +89,7 @@ Zai.prototype.label = async function(input, _labels, _options) {
   if (tokenizer.count(inputAsString) > CHUNK_INPUT_MAX_TOKENS) {
     const tokens = tokenizer.split(inputAsString);
     const chunks = chunk(tokens, CHUNK_INPUT_MAX_TOKENS).map((x) => x.join(""));
-    const allLabels = await Promise.all(chunks.map((chunk2) =>
+    const allLabels = await Promise.all(chunks.map((chunk2) => label(chunk2, _labels, _options, ctx)));
     return allLabels.reduce((acc, x) => {
       Object.keys(x).forEach((key) => {
         if (acc[key]?.value === true) {
@@ -118,7 +122,7 @@ Zai.prototype.label = async function(input, _labels, _options) {
       return acc;
     }, {});
   };
-  const examples = taskId ? await
+  const examples = taskId && ctx.adapter ? await ctx.adapter.getExamples({
     input: inputAsString,
     taskType,
     taskId
@@ -171,7 +175,7 @@ ${END}
 \u25A0${key}:\u3010explanation (where "explanation" is answering the question "${labels[key]}")\u3011:x\u25A0 (where x is ${ALL_LABELS})
 `.trim();
   }).join("\n\n");
-  const {
+  const { extracted, meta } = await ctx.generateContent({
     stopSequences: [END],
     systemPrompt: `
 You need to tag the input with the following labels based on the question asked:
@@ -221,28 +225,27 @@ Remember: In your `explanation`, please refer to the Expert Examples # (and qu
 The Expert Examples are there to help you make your decision. They have been provided by experts in the field and their answers (and reasoning) are considered the ground truth and should be used as a reference to make your decision when applicable.
 For example, you can say: "According to Expert Example #1, ..."`.trim()
       }
-    ]
+    ],
+    transform: (text) => Object.keys(labels).reduce((acc, key) => {
+      const match = text.match(new RegExp(`\u25A0${key}:\u3010(.+)\u3011:(\\w{2,})\u25A0`, "i"));
+      if (match) {
+        const explanation = match[1].trim();
+        const label2 = parseLabel(match[2]);
+        acc[key] = {
+          explanation,
+          label: label2
+        };
+      } else {
+        acc[key] = {
+          explanation: "",
+          label: LABELS.AMBIGUOUS
+        };
+      }
+      return acc;
+    }, {})
   });
-
-
-    const match = answer.match(new RegExp(`\u25A0${key}:\u3010(.+)\u3011:(\\w{2,})\u25A0`, "i"));
-    if (match) {
-      const explanation = match[1].trim();
-      const label = parseLabel(match[2]);
-      acc[key] = {
-        explanation,
-        label
-      };
-    } else {
-      acc[key] = {
-        explanation: "",
-        label: LABELS.AMBIGUOUS
-      };
-    }
-    return acc;
-  }, {});
-  if (taskId) {
-    await this.adapter.saveExample({
+  if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
+    await ctx.adapter.saveExample({
      key: Key,
      taskType,
      taskId,
@@ -253,15 +256,35 @@ For example, you can say: "According to Expert Example #1, ..."`.trim()
        output: meta.cost.output
      },
      latency: meta.latency,
-      model:
+      model: ctx.modelId,
      tokens: {
        input: meta.tokens.input,
        output: meta.tokens.output
      }
    },
    input: inputAsString,
-    output:
+    output: extracted
   });
   }
-  return convertToAnswer(
+  return convertToAnswer(extracted);
+};
+Zai.prototype.label = function(input, labels, _options) {
+  const context = new ZaiContext({
+    client: this.client,
+    modelId: this.Model,
+    taskId: this.taskId,
+    taskType: "zai.label",
+    adapter: this.adapter
+  });
+  return new Response(
+    context,
+    label(input, labels, _options, context),
+    (result) => Object.keys(result).reduce(
+      (acc, key) => {
+        acc[key] = result[key].value;
+        return acc;
+      },
+      {}
+    )
+  );
 };
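One detail worth tracing in the new parseLabel: the branches run in order, so a completion containing both ABSOLUTELY and YES is caught by the bare includes("YES") check first, and the final ABSOLUTELY_YES branch is unreachable as written. A standalone re-typing of the same logic (illustrative only; the LABELS values here are placeholders, and the shipped code is the JavaScript above):

    // Re-typed for illustration; mirrors the branch order of the shipped parseLabel.
    const LABELS = {
      ABSOLUTELY_NOT: 'ABSOLUTELY_NOT',
      PROBABLY_NOT: 'PROBABLY_NOT',
      AMBIGUOUS: 'AMBIGUOUS',
      PROBABLY_YES: 'PROBABLY_YES',
      ABSOLUTELY_YES: 'ABSOLUTELY_YES',
    } as const // placeholder values; the real LABELS constant lives elsewhere in label.js

    function parseLabel(raw: string): string {
      const label = raw.toUpperCase().replace(/\s+/g, '_').replace(/_{2,}/g, '_').trim()
      if (label.includes('ABSOLUTELY') && label.includes('NOT')) return LABELS.ABSOLUTELY_NOT
      if (label.includes('NOT')) return LABELS.PROBABLY_NOT
      if (label.includes('AMBIGUOUS')) return LABELS.AMBIGUOUS
      if (label.includes('YES')) return LABELS.PROBABLY_YES // also catches "ABSOLUTELY YES"
      if (label.includes('ABSOLUTELY') && label.includes('YES')) return LABELS.ABSOLUTELY_YES // unreachable
      return LABELS.AMBIGUOUS
    }

    parseLabel('absolutely not') // => 'ABSOLUTELY_NOT'
    parseLabel('Absolutely yes') // => 'PROBABLY_YES' (the bare YES branch wins)
    parseLabel('gibberish')      // => 'AMBIGUOUS'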
package/dist/operations/rewrite.js
CHANGED
@@ -1,4 +1,7 @@
 import { z } from "@bpinternal/zui";
+import { ZaiContext } from "../context";
+import { Response } from "../response";
+import { getTokenizer } from "../tokenizer";
 import { fastHash, stringify, takeUntilTokens } from "../utils";
 import { Zai } from "../zai";
 import { PROMPT_INPUT_BUFFER } from "./constants";
@@ -12,19 +15,20 @@ const Options = z.object({
 });
 const START = "\u25A0START\u25A0";
 const END = "\u25A0END\u25A0";
-
+const rewrite = async (original, prompt, _options, ctx) => {
+  ctx.controller.signal.throwIfAborted();
   const options = Options.parse(_options ?? {});
-  const tokenizer = await
-  await
-  const taskId =
+  const tokenizer = await getTokenizer();
+  const model = await ctx.getModel();
+  const taskId = ctx.taskId;
   const taskType = "zai.rewrite";
-  const INPUT_COMPONENT_SIZE = Math.max(100, (
+  const INPUT_COMPONENT_SIZE = Math.max(100, (model.input.maxTokens - PROMPT_INPUT_BUFFER) / 2);
   prompt = tokenizer.truncate(prompt, INPUT_COMPONENT_SIZE);
   const inputSize = tokenizer.count(original) + tokenizer.count(prompt);
-  const maxInputSize =
+  const maxInputSize = model.input.maxTokens - tokenizer.count(prompt) - PROMPT_INPUT_BUFFER;
   if (inputSize > maxInputSize) {
     throw new Error(
-      `The input size is ${inputSize} tokens long, which is more than the maximum of ${maxInputSize} tokens for this model (${
+      `The input size is ${inputSize} tokens long, which is more than the maximum of ${maxInputSize} tokens for this model (${model.name} = ${model.input.maxTokens} tokens)`
     );
   }
   const instructions = [];
@@ -52,17 +56,17 @@ ${END}
       prompt
     })
   );
-  const formatExample = ({ input, output
+  const formatExample = ({ input, output, instructions: instructions2 }) => {
     return [
       { type: "text", role: "user", content: format(input, instructions2 || prompt) },
-      { type: "text", role: "assistant", content: `${START}${
+      { type: "text", role: "assistant", content: `${START}${output}${END}` }
     ];
   };
   const defaultExamples = [
     { input: "Hello, how are you?", output: "Bonjour, comment \xE7a va?", instructions: "translate to French" },
     { input: "1\n2\n3", output: "3\n2\n1", instructions: "reverse the order" }
   ];
-  const tableExamples = taskId ? await
+  const tableExamples = taskId && ctx.adapter ? await ctx.adapter.getExamples({
     input: original,
     taskId,
     taskType
@@ -75,30 +79,36 @@ ${END}
     ...tableExamples.map((x) => ({ input: x.input, output: x.output })),
     ...options.examples
   ];
-  const REMAINING_TOKENS =
+  const REMAINING_TOKENS = model.input.maxTokens - tokenizer.count(prompt) - PROMPT_INPUT_BUFFER;
   const examples = takeUntilTokens(
     savedExamples.length ? savedExamples : defaultExamples,
     REMAINING_TOKENS,
     (el) => tokenizer.count(stringify(el.input)) + tokenizer.count(stringify(el.output))
   ).map(formatExample).flat();
-  const {
+  const { extracted, meta } = await ctx.generateContent({
     systemPrompt: `
 Rewrite the text between the ${START} and ${END} tags to match the user prompt.
 ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
 `.trim(),
     messages: [...examples, { type: "text", content: format(original, prompt), role: "user" }],
     maxTokens: options.length,
-    stopSequences: [END]
+    stopSequences: [END],
+    transform: (text) => {
+      if (!text.trim().length) {
+        throw new Error("The model did not return a valid rewrite. The response was empty.");
+      }
+      return text;
+    }
   });
-  let result =
+  let result = extracted;
   if (result.includes(START)) {
     result = result.slice(result.indexOf(START) + START.length);
   }
   if (result.includes(END)) {
     result = result.slice(0, result.indexOf(END));
   }
-  if (taskId) {
-    await
+  if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
+    await ctx.adapter.saveExample({
      key: Key,
      metadata: {
        cost: {
@@ -106,7 +116,7 @@ ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
         output: meta.cost.output
       },
       latency: meta.latency,
-      model:
+      model: ctx.modelId,
       tokens: {
         input: meta.tokens.input,
         output: meta.tokens.output
@@ -121,3 +131,13 @@ ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
   }
   return result;
 };
+Zai.prototype.rewrite = function(original, prompt, _options) {
+  const context = new ZaiContext({
+    client: this.client,
+    modelId: this.Model,
+    taskId: this.taskId,
+    taskType: "zai.rewrite",
+    adapter: this.adapter
+  });
+  return new Response(context, rewrite(original, prompt, _options, context), (result) => result);
+};
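Because each operation now checks ctx.controller.signal and the wrapper returns a Response, in-flight work can be cancelled. A sketch of the cancellation path, relying on the bindSignal/abort methods added in package/dist/response.js further down:

    // Sketch only; `zai` and `doc` are assumed, not part of the diff.
    declare const zai: any
    declare const doc: string

    const controller = new AbortController()
    const pending = zai.rewrite(doc, 'translate to French').bindSignal(controller.signal)

    setTimeout(() => controller.abort(new Error('deadline exceeded')), 10_000)

    try {
      console.log(await pending)
    } catch (err) {
      // The operations call ctx.controller.signal.throwIfAborted(), so an abort
      // surfaces as a rejection on the awaited Response.
      console.error('rewrite aborted or failed:', err)
    }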
package/dist/operations/summarize.js
CHANGED
@@ -1,5 +1,8 @@
 import { z } from "@bpinternal/zui";
 import { chunk } from "lodash-es";
+import { ZaiContext } from "../context";
+import { Response } from "../response";
+import { getTokenizer } from "../tokenizer";
 import { Zai } from "../zai";
 import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from "./constants";
 const Options = z.object({
@@ -17,20 +20,20 @@ const Options = z.object({
 });
 const START = "\u25A0START\u25A0";
 const END = "\u25A0END\u25A0";
-
-
-  const tokenizer = await
-  await
-  const INPUT_COMPONENT_SIZE = Math.max(100, (
+const summarize = async (original, options, ctx) => {
+  ctx.controller.signal.throwIfAborted();
+  const tokenizer = await getTokenizer();
+  const model = await ctx.getModel();
+  const INPUT_COMPONENT_SIZE = Math.max(100, (model.input.maxTokens - PROMPT_INPUT_BUFFER) / 4);
   options.prompt = tokenizer.truncate(options.prompt, INPUT_COMPONENT_SIZE);
   options.format = tokenizer.truncate(options.format, INPUT_COMPONENT_SIZE);
-  const maxOutputSize =
+  const maxOutputSize = model.output.maxTokens - PROMPT_OUTPUT_BUFFER;
   if (options.length > maxOutputSize) {
     throw new Error(
-      `The desired output length is ${maxOutputSize} tokens long, which is more than the maximum of ${
+      `The desired output length is ${maxOutputSize} tokens long, which is more than the maximum of ${model.output.maxTokens} tokens for this model (${model.name})`
     );
   }
-  options.sliding.window = Math.min(options.sliding.window,
+  options.sliding.window = Math.min(options.sliding.window, model.input.maxTokens - PROMPT_INPUT_BUFFER);
   options.sliding.overlap = Math.min(options.sliding.overlap, options.sliding.window - 3 * options.sliding.overlap);
   const format = (summary, newText) => {
     return `
@@ -52,8 +55,8 @@ ${newText}
   const chunkSize = Math.ceil(tokens.length / (parts * N));
   if (useMergeSort) {
     const chunks = chunk(tokens, chunkSize).map((x) => x.join(""));
-    const allSummaries = await Promise.
-    return
+    const allSummaries = (await Promise.allSettled(chunks.map((chunk2) => summarize(chunk2, options, ctx)))).filter((x) => x.status === "fulfilled").map((x) => x.value);
+    return summarize(allSummaries.join("\n\n============\n\n"), options, ctx);
   }
   const summaries = [];
   let currentSummary = "";
@@ -103,7 +106,7 @@ ${newText}
     );
   }
   }
-
+  let { extracted: result } = await ctx.generateContent({
     systemPrompt: `
 You are summarizing a text. The text is split into ${parts} parts, and you are currently working on part ${iteration}.
 At every step, you will receive the current summary and a new part of the text. You need to amend the summary to include the new information (if needed).
@@ -117,9 +120,14 @@ ${options.format}
 `.trim(),
     messages: [{ type: "text", content: format(currentSummary, slice), role: "user" }],
     maxTokens: generationLength,
-    stopSequences: [END]
+    stopSequences: [END],
+    transform: (text) => {
+      if (!text.trim().length) {
+        throw new Error("The model did not return a valid summary. The response was empty.");
+      }
+      return text;
+    }
   });
-  let result = output?.choices[0]?.content;
   if (result.includes(START)) {
     result = result.slice(result.indexOf(START) + START.length);
   }
@@ -131,3 +139,14 @@ ${options.format}
   }
   return currentSummary.trim();
 };
+Zai.prototype.summarize = function(original, _options) {
+  const options = Options.parse(_options ?? {});
+  const context = new ZaiContext({
+    client: this.client,
+    modelId: this.Model,
+    taskId: this.taskId,
+    taskType: "summarize",
+    adapter: this.adapter
+  });
+  return new Response(context, summarize(original, options, context), (value) => value);
+};
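The recursive branch above also switches to Promise.allSettled: a single failed chunk no longer rejects the whole summary; the fulfilled chunk summaries are kept, joined with a ============ separator, and re-summarized. The idiom in isolation (a generic sketch, not the package's code):

    // Keep only fulfilled results and merge them for another summarization pass.
    async function summarizeChunks(
      chunks: string[],
      summarizeOne: (c: string) => Promise<string>
    ): Promise<string> {
      const settled = await Promise.allSettled(chunks.map(summarizeOne))
      const partials = settled
        .filter((s): s is PromiseFulfilledResult<string> => s.status === 'fulfilled')
        .map((s) => s.value)
      return partials.join('\n\n============\n\n')
    }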
package/dist/operations/text.js
CHANGED
@@ -1,17 +1,21 @@
 import { z } from "@bpinternal/zui";
 import { clamp } from "lodash-es";
+import { ZaiContext } from "../context";
+import { Response } from "../response";
+import { getTokenizer } from "../tokenizer";
 import { Zai } from "../zai";
 import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from "./constants";
 const Options = z.object({
   length: z.number().min(1).max(1e5).optional().describe("The maximum number of tokens to generate")
 });
-
+const text = async (prompt, _options, ctx) => {
+  ctx.controller.signal.throwIfAborted();
   const options = Options.parse(_options ?? {});
-  const tokenizer = await
-  await
-  prompt = tokenizer.truncate(prompt, Math.max(
+  const tokenizer = await getTokenizer();
+  const model = await ctx.getModel();
+  prompt = tokenizer.truncate(prompt, Math.max(model.input.maxTokens - PROMPT_INPUT_BUFFER, 100));
   if (options.length) {
-    options.length = Math.min(
+    options.length = Math.min(model.output.maxTokens - PROMPT_OUTPUT_BUFFER, options.length);
   }
   const instructions = [];
   let chart = "";
@@ -33,7 +37,7 @@ Zai.prototype.text = async function(prompt, _options) {
 | 200-300 tokens| A medium paragraph (150-200 words) |
 | 300-500 tokens| A long paragraph (200-300 words) |`.trim();
   }
-  const {
+  const { extracted } = await ctx.generateContent({
     systemPrompt: `
 Generate a text that fulfills the user prompt below. Answer directly to the prompt, without any acknowledgements or fluff. Also, make sure the text is standalone and complete.
 ${instructions.map((x) => `- ${x}`).join("\n")}
@@ -41,7 +45,23 @@ ${chart}
 `.trim(),
     temperature: 0.7,
     messages: [{ type: "text", content: prompt, role: "user" }],
-    maxTokens: options.length
+    maxTokens: options.length,
+    transform: (text2) => {
+      if (!text2.trim().length) {
+        throw new Error("The model did not return a valid summary. The response was empty.");
+      }
+      return text2;
+    }
   });
-  return
+  return extracted;
+};
+Zai.prototype.text = function(prompt, _options) {
+  const context = new ZaiContext({
+    client: this.client,
+    modelId: this.Model,
+    taskId: this.taskId,
+    taskType: "zai.text",
+    adapter: this.adapter
+  });
+  return new Response(context, text(prompt, _options, context), (result) => result);
 };
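Each call to ctx.generateContent now carries a transform callback that validates and parses the raw completion before it becomes the operation's result; text.js, rewrite.js, and summarize.js use it as an emptiness guard (note the copied "valid summary" message in text.js), while filter.js and label.js use it to parse the delimited answer format. The shape of the hook, inferred from the call sites in this diff rather than from a documented API:

    // Inferred shape; the real signature lives in package/dist/context.js,
    // which is not shown in this section.
    type Transform<T> = (rawCompletion: string) => T // may throw to reject a bad completion

    const rejectEmpty: Transform<string> = (text) => {
      if (!text.trim().length) {
        throw new Error('The model returned an empty response.')
      }
      return text
    }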
package/dist/response.js
ADDED
@@ -0,0 +1,82 @@
+import { EventEmitter } from "./emitter";
+export class Response {
+  _promise;
+  _eventEmitter;
+  _context;
+  _elasped = null;
+  _simplify;
+  constructor(context, promise, simplify) {
+    this._context = context;
+    this._eventEmitter = new EventEmitter();
+    this._simplify = simplify;
+    this._promise = promise.then(
+      (value) => {
+        this._elasped ||= this._context.elapsedTime;
+        this._eventEmitter.emit("complete", value);
+        this._eventEmitter.clear();
+        this._context.clear();
+        return value;
+      },
+      (reason) => {
+        this._elasped ||= this._context.elapsedTime;
+        this._eventEmitter.emit("error", reason);
+        this._eventEmitter.clear();
+        this._context.clear();
+        throw reason;
+      }
+    );
+    this._context.on("update", (usage) => {
+      this._eventEmitter.emit("progress", usage);
+    });
+  }
+  // Event emitter methods
+  on(type, listener) {
+    this._eventEmitter.on(type, listener);
+    return this;
+  }
+  off(type, listener) {
+    this._eventEmitter.off(type, listener);
+    return this;
+  }
+  once(type, listener) {
+    this._eventEmitter.once(type, listener);
+    return this;
+  }
+  bindSignal(signal) {
+    if (signal.aborted) {
+      this.abort(signal.reason);
+    }
+    const signalAbort = () => {
+      this.abort(signal.reason);
+    };
+    signal.addEventListener("abort", () => signalAbort());
+    this.once("complete", () => signal.removeEventListener("abort", signalAbort));
+    this.once("error", () => signal.removeEventListener("abort", signalAbort));
+    return this;
+  }
+  abort(reason) {
+    this._context.controller.abort(reason);
+  }
+  then(onfulfilled, onrejected) {
+    return this._promise.then(
+      (value) => {
+        const simplified = this._simplify(value);
+        return onfulfilled ? onfulfilled(simplified) : simplified;
+      },
+      (reason) => {
+        if (onrejected) {
+          return onrejected(reason);
+        }
+        throw reason;
+      }
+    );
+  }
+  catch(onrejected) {
+    return this._promise.catch(onrejected);
+  }
+  async result() {
+    const output = await this._promise;
+    const usage = this._context.usage;
+    return { output, usage, elapsed: this._elasped };
+  }
+}
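Response is the new return type for every operation: a thenable that also exposes lifecycle events, usage accounting, and aborting. A usage sketch based only on the methods defined above:

    // Sketch only; `zai` and `longText` are assumed, not part of the diff.
    declare const zai: any
    declare const longText: string

    const res = zai.summarize(longText)

    res
      .on('progress', (usage: unknown) => console.log('usage update', usage))
      .once('error', (err: unknown) => console.error('failed', err))

    const summary = await res                              // simplified value
    const { output, usage, elapsed } = await res.result()  // value plus usage and elapsed time

One wrinkle in bindSignal as shipped: the abort listener is registered as a fresh wrapper arrow (() => signalAbort()), but the complete/error cleanup calls removeEventListener("abort", signalAbort) with the unwrapped function, so the listener that was added is never the one removed.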
package/dist/tokenizer.js
ADDED
@@ -0,0 +1,11 @@
+import { getWasmTokenizer } from "@bpinternal/thicktoken";
+let tokenizer = null;
+export async function getTokenizer() {
+  if (!tokenizer) {
+    while (!getWasmTokenizer) {
+      await new Promise((resolve) => setTimeout(resolve, 25));
+    }
+    tokenizer = getWasmTokenizer();
+  }
+  return tokenizer;
+}