@botpress/zai 2.0.15 → 2.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,56 +2,103 @@ import { z } from "@bpinternal/zui";
2
2
  import JSON5 from "json5";
3
3
  import { jsonrepair } from "jsonrepair";
4
4
  import { chunk, isArray } from "lodash-es";
5
+ import { ZaiContext } from "../context";
6
+ import { Response } from "../response";
7
+ import { getTokenizer } from "../tokenizer";
5
8
  import { fastHash, stringify, takeUntilTokens } from "../utils";
6
9
  import { Zai } from "../zai";
7
10
  import { PROMPT_INPUT_BUFFER } from "./constants";
8
11
  import { JsonParsingError } from "./errors";
9
12
  const Options = z.object({
10
13
  instructions: z.string().optional().describe("Instructions to guide the user on how to extract the data"),
11
- chunkLength: z.number().min(100).max(1e5).optional().describe("The maximum number of tokens per chunk").default(16e3)
14
+ chunkLength: z.number().min(100).max(1e5).optional().describe("The maximum number of tokens per chunk").default(16e3),
15
+ strict: z.boolean().optional().default(true).describe("Whether to strictly follow the schema or not")
12
16
  });
13
17
  const START = "\u25A0json_start\u25A0";
14
18
  const END = "\u25A0json_end\u25A0";
15
19
  const NO_MORE = "\u25A0NO_MORE_ELEMENT\u25A0";
16
- Zai.prototype.extract = async function(input, _schema, _options) {
20
+ const extract = async (input, _schema, _options, ctx) => {
21
+ ctx.controller.signal.throwIfAborted();
17
22
  let schema = _schema;
18
23
  const options = Options.parse(_options ?? {});
19
- const tokenizer = await this.getTokenizer();
20
- await this.fetchModelDetails();
21
- const taskId = this.taskId;
24
+ const tokenizer = await getTokenizer();
25
+ const model = await ctx.getModel();
26
+ const taskId = ctx.taskId;
22
27
  const taskType = "zai.extract";
23
- const PROMPT_COMPONENT = Math.max(this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER, 100);
28
+ const PROMPT_COMPONENT = Math.max(model.input.maxTokens - PROMPT_INPUT_BUFFER, 100);
24
29
  let isArrayOfObjects = false;
30
+ let wrappedValue = false;
25
31
  const originalSchema = schema;
26
32
  const baseType = (schema.naked ? schema.naked() : schema)?.constructor?.name ?? "unknown";
27
- if (baseType === "ZodObject") {
28
- } else if (baseType === "ZodArray") {
33
+ if (baseType === "ZodArray") {
34
+ isArrayOfObjects = true;
29
35
  let elementType = schema.element;
30
36
  if (elementType.naked) {
31
37
  elementType = elementType.naked();
32
38
  }
33
39
  if (elementType?.constructor?.name === "ZodObject") {
34
- isArrayOfObjects = true;
35
40
  schema = elementType;
36
41
  } else {
37
- throw new Error("Schema must be a ZodObject or a ZodArray<ZodObject>");
42
+ wrappedValue = true;
43
+ schema = z.object({
44
+ value: elementType
45
+ });
46
+ }
47
+ } else if (baseType !== "ZodObject") {
48
+ wrappedValue = true;
49
+ schema = z.object({
50
+ value: originalSchema
51
+ });
52
+ }
53
+ if (!options.strict) {
54
+ try {
55
+ schema = schema.partial();
56
+ } catch {
38
57
  }
39
- } else {
40
- throw new Error("Schema must be either a ZuiObject or a ZuiArray<ZuiObject>");
41
58
  }
42
59
  const schemaTypescript = schema.toTypescriptType({ declaration: false });
43
60
  const schemaLength = tokenizer.count(schemaTypescript);
44
- options.chunkLength = Math.min(
45
- options.chunkLength,
46
- this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER - schemaLength
47
- );
61
+ options.chunkLength = Math.min(options.chunkLength, model.input.maxTokens - PROMPT_INPUT_BUFFER - schemaLength);
48
62
  const keys = Object.keys(schema.shape);
49
63
  const inputAsString = stringify(input);
50
64
  if (tokenizer.count(inputAsString) > options.chunkLength) {
51
65
  const tokens = tokenizer.split(inputAsString);
52
66
  const chunks = chunk(tokens, options.chunkLength).map((x) => x.join(""));
53
- const all = await Promise.all(chunks.map((chunk2) => this.extract(chunk2, originalSchema)));
54
- return this.extract(all, originalSchema, options);
67
+ const all = await Promise.allSettled(
68
+ chunks.map(
69
+ (chunk2) => extract(
70
+ chunk2,
71
+ originalSchema,
72
+ {
73
+ ...options,
74
+ strict: false
75
+ // We don't want to fail on strict mode for sub-chunks
76
+ },
77
+ ctx
78
+ )
79
+ )
80
+ ).then(
81
+ (results) => results.filter((x) => x.status === "fulfilled").map((x) => x.value)
82
+ );
83
+ ctx.controller.signal.throwIfAborted();
84
+ const rows = all.map((x, idx) => `<part-${idx + 1}>
85
+ ${stringify(x, true)}
86
+ </part-${idx + 1}>`).join("\n");
87
+ return extract(
88
+ `
89
+ The result has been split into ${all.length} parts. Recursively merge the result into the final result.
90
+ When merging arrays, take unique values.
91
+ When merging conflictual (but defined) information, take the most reasonable and frequent value.
92
+ Non-defined values are OK and normal. Don't delete fields because of null values. Focus on defined values.
93
+
94
+ Here's the data:
95
+ ${rows}
96
+
97
+ Merge it back into a final result.`.trim(),
98
+ originalSchema,
99
+ options,
100
+ ctx
101
+ );
55
102
  }
56
103
  const instructions = [];
57
104
  if (options.instructions) {
@@ -72,6 +119,9 @@ Zai.prototype.extract = async function(input, _schema, _options) {
72
119
  instructions.push("You may have exactly one element in the input.");
73
120
  instructions.push(`The element must be a JSON object with exactly the format: ${START}${shape}${END}`);
74
121
  }
122
+ if (!options.strict) {
123
+ instructions.push("You may ignore any fields that are not present in the input. All keys are optional.");
124
+ }
75
125
  const EXAMPLES_TOKENS = PROMPT_COMPONENT - tokenizer.count(inputAsString) - tokenizer.count(instructions.join("\n"));
76
126
  const Key = fastHash(
77
127
  JSON.stringify({
@@ -81,7 +131,7 @@ Zai.prototype.extract = async function(input, _schema, _options) {
81
131
  instructions: options.instructions
82
132
  })
83
133
  );
84
- const examples = taskId ? await this.adapter.getExamples({
134
+ const examples = taskId && ctx.adapter ? await ctx.adapter.getExamples({
85
135
  input: inputAsString,
86
136
  taskType,
87
137
  taskId
@@ -140,9 +190,9 @@ ${input2.trim()}
140
190
  <|end_input|>
141
191
  `.trim();
142
192
  };
143
- const formatOutput = (extracted) => {
144
- extracted = isArray(extracted) ? extracted : [extracted];
145
- return extracted.map(
193
+ const formatOutput = (extracted2) => {
194
+ extracted2 = isArray(extracted2) ? extracted2 : [extracted2];
195
+ return extracted2.map(
146
196
  (x) => `
147
197
  ${START}
148
198
  ${JSON.stringify(x, null, 2)}
@@ -166,7 +216,7 @@ ${END}`.trim()
166
216
  EXAMPLES_TOKENS,
167
217
  (el) => tokenizer.count(stringify(el.input)) + tokenizer.count(stringify(el.extracted))
168
218
  ).map(formatExample).flat();
169
- const { output, meta } = await this.callModel({
219
+ const { meta, extracted } = await ctx.generateContent({
170
220
  systemPrompt: `
171
221
  Extract the following information from the input:
172
222
  ${schemaTypescript}
@@ -182,29 +232,42 @@ ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
182
232
  type: "text",
183
233
  content: formatInput(inputAsString, schemaTypescript, options.instructions ?? "")
184
234
  }
185
- ]
235
+ ],
236
+ transform: (text) => (text || "{}")?.split(START).filter((x) => x.trim().length > 0 && x.includes("}")).map((x) => {
237
+ try {
238
+ const json = x.slice(0, x.indexOf(END)).trim();
239
+ const repairedJson = jsonrepair(json);
240
+ const parsedJson = JSON5.parse(repairedJson);
241
+ const safe = schema.safeParse(parsedJson);
242
+ if (safe.success) {
243
+ return safe.data;
244
+ }
245
+ if (options.strict) {
246
+ throw new JsonParsingError(x, safe.error);
247
+ }
248
+ return parsedJson;
249
+ } catch (error) {
250
+ throw new JsonParsingError(x, error instanceof Error ? error : new Error("Unknown error"));
251
+ }
252
+ }).filter((x) => x !== null)
186
253
  });
187
- const answer = output.choices[0]?.content;
188
- const elements = answer.split(START).filter((x) => x.trim().length > 0).map((x) => {
189
- try {
190
- const json = x.slice(0, x.indexOf(END)).trim();
191
- const repairedJson = jsonrepair(json);
192
- const parsedJson = JSON5.parse(repairedJson);
193
- return schema.parse(parsedJson);
194
- } catch (error) {
195
- throw new JsonParsingError(x, error instanceof Error ? error : new Error("Unknown error"));
196
- }
197
- }).filter((x) => x !== null);
198
254
  let final;
199
255
  if (isArrayOfObjects) {
200
- final = elements;
201
- } else if (elements.length === 0) {
202
- final = schema.parse({});
256
+ final = extracted;
257
+ } else if (extracted.length === 0) {
258
+ final = options.strict ? schema.parse({}) : {};
203
259
  } else {
204
- final = elements[0];
260
+ final = extracted[0];
205
261
  }
206
- if (taskId) {
207
- await this.adapter.saveExample({
262
+ if (wrappedValue) {
263
+ if (Array.isArray(final)) {
264
+ final = final.map((x) => "value" in x ? x.value : x);
265
+ } else {
266
+ final = "value" in final ? final.value : final;
267
+ }
268
+ }
269
+ if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
270
+ await ctx.adapter.saveExample({
208
271
  key: Key,
209
272
  taskId: `zai/${taskId}`,
210
273
  taskType,
@@ -217,7 +280,7 @@ ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
217
280
  output: meta.cost.output
218
281
  },
219
282
  latency: meta.latency,
220
- model: this.Model,
283
+ model: ctx.modelId,
221
284
  tokens: {
222
285
  input: meta.tokens.input,
223
286
  output: meta.tokens.output
@@ -227,3 +290,13 @@ ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
227
290
  }
228
291
  return final;
229
292
  };
293
/**
 * Public entry point for the `zai.extract` operation.
 *
 * Creates a new ZaiContext from this instance's client, model id, task id and
 * adapter, starts `extract` with that context, and returns a Response built
 * from the context, the pending extraction, and an identity mapper.
 *
 * @param input - Value forwarded to `extract` as the data to extract from.
 * @param schema - Schema forwarded to `extract`.
 * @param _options - Optional options forwarded to `extract`.
 * @returns A Response wrapping the context and the pending extraction.
 */
Zai.prototype.extract = function(input, schema, _options) {
  const { client, Model: modelId, taskId, adapter } = this;
  const context = new ZaiContext({ client, modelId, taskId, taskType: "zai.extract", adapter });
  const pending = extract(input, schema, _options, context);
  // The result is handed to Response as-is (identity mapper).
  const passThrough = (result) => result;
  return new Response(context, pending, passThrough);
};
@@ -1,5 +1,8 @@
1
1
  import { z } from "@bpinternal/zui";
2
2
  import { clamp } from "lodash-es";
3
+ import { ZaiContext } from "../context";
4
+ import { Response } from "../response";
5
+ import { getTokenizer } from "../tokenizer";
3
6
  import { fastHash, stringify, takeUntilTokens } from "../utils";
4
7
  import { Zai } from "../zai";
5
8
  import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from "./constants";
@@ -13,14 +16,15 @@ const _Options = z.object({
13
16
  examples: z.array(_Example).describe("Examples to filter the condition against").default([])
14
17
  });
15
18
  const END = "\u25A0END\u25A0";
16
- Zai.prototype.filter = async function(input, condition, _options) {
19
+ const filter = async (input, condition, _options, ctx) => {
20
+ ctx.controller.signal.throwIfAborted();
17
21
  const options = _Options.parse(_options ?? {});
18
- const tokenizer = await this.getTokenizer();
19
- await this.fetchModelDetails();
20
- const taskId = this.taskId;
22
+ const tokenizer = await getTokenizer();
23
+ const model = await ctx.getModel();
24
+ const taskId = ctx.taskId;
21
25
  const taskType = "zai.filter";
22
26
  const MAX_ITEMS_PER_CHUNK = 50;
23
- const TOKENS_TOTAL_MAX = this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER;
27
+ const TOKENS_TOTAL_MAX = model.input.maxTokens - PROMPT_INPUT_BUFFER - PROMPT_OUTPUT_BUFFER;
24
28
  const TOKENS_EXAMPLES_MAX = Math.floor(Math.max(250, TOKENS_TOTAL_MAX * 0.5));
25
29
  const TOKENS_CONDITION_MAX = clamp(TOKENS_TOTAL_MAX * 0.25, 250, tokenizer.count(condition));
26
30
  const TOKENS_INPUT_ARRAY_MAX = TOKENS_TOTAL_MAX - TOKENS_EXAMPLES_MAX - TOKENS_CONDITION_MAX;
@@ -97,7 +101,7 @@ ${examples.map((x, idx) => `\u25A0${idx}:${!!x.filter ? "true" : "false"}:${x.re
97
101
  }
98
102
  ];
99
103
  const filterChunk = async (chunk) => {
100
- const examples = taskId ? await this.adapter.getExamples({
104
+ const examples = taskId && ctx.adapter ? await ctx.adapter.getExamples({
101
105
  // The Table API can't search for a huge input string
102
106
  input: JSON.stringify(chunk).slice(0, 1e3),
103
107
  taskType,
@@ -122,7 +126,7 @@ ${examples.map((x, idx) => `\u25A0${idx}:${!!x.filter ? "true" : "false"}:${x.re
122
126
  role: "assistant"
123
127
  }
124
128
  ];
125
- const { output, meta } = await this.callModel({
129
+ const { extracted: partial, meta } = await ctx.generateContent({
126
130
  systemPrompt: `
127
131
  You are given a list of items. Your task is to filter out the items that meet the condition below.
128
132
  You need to return the full list of items with the format:
@@ -144,17 +148,18 @@ The condition is: "${condition}"
144
148
  ),
145
149
  role: "user"
146
150
  }
147
- ]
148
- });
149
- const answer = output.choices[0]?.content;
150
- const indices = answer.trim().split("\u25A0").filter((x) => x.length > 0).map((x) => {
151
- const [idx, filter] = x.split(":");
152
- return { idx: parseInt(idx?.trim() ?? ""), filter: filter?.toLowerCase().trim() === "true" };
153
- });
154
- const partial = chunk.filter((_, idx) => {
155
- return indices.find((x) => x.idx === idx)?.filter ?? false;
151
+ ],
152
+ transform: (text) => {
153
+ const indices = text.trim().split("\u25A0").filter((x) => x.length > 0).map((x) => {
154
+ const [idx, filter2] = x.split(":");
155
+ return { idx: parseInt(idx?.trim() ?? ""), filter: filter2?.toLowerCase().trim() === "true" };
156
+ });
157
+ return chunk.filter((_, idx) => {
158
+ return indices.find((x) => x.idx === idx && x.filter) ?? false;
159
+ });
160
+ }
156
161
  });
157
- if (taskId) {
162
+ if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
158
163
  const key = fastHash(
159
164
  stringify({
160
165
  taskId,
@@ -163,7 +168,7 @@ The condition is: "${condition}"
163
168
  condition
164
169
  })
165
170
  );
166
- await this.adapter.saveExample({
171
+ await ctx.adapter.saveExample({
167
172
  key,
168
173
  taskType,
169
174
  taskId,
@@ -176,7 +181,7 @@ The condition is: "${condition}"
176
181
  output: meta.cost.output
177
182
  },
178
183
  latency: meta.latency,
179
- model: this.Model,
184
+ model: ctx.modelId,
180
185
  tokens: {
181
186
  input: meta.tokens.input,
182
187
  output: meta.tokens.output
@@ -189,3 +194,13 @@ The condition is: "${condition}"
189
194
  const filteredChunks = await Promise.all(chunks.map(filterChunk));
190
195
  return filteredChunks.flat();
191
196
  };
197
/**
 * Public entry point for the `zai.filter` operation.
 *
 * Creates a new ZaiContext from this instance's client, model id, task id and
 * adapter, starts `filter` with that context, and returns a Response built
 * from the context, the pending filter call, and an identity mapper.
 *
 * @param input - Items forwarded to `filter`.
 * @param condition - Condition text forwarded to `filter`.
 * @param _options - Optional options forwarded to `filter`.
 * @returns A Response wrapping the context and the pending filter call.
 */
Zai.prototype.filter = function(input, condition, _options) {
  const { client, Model: modelId, taskId, adapter } = this;
  const context = new ZaiContext({ client, modelId, taskId, taskType: "zai.filter", adapter });
  const pending = filter(input, condition, _options, context);
  // The result is handed to Response as-is (identity mapper).
  const passThrough = (result) => result;
  return new Response(context, pending, passThrough);
};
@@ -1,5 +1,8 @@
1
1
  import { z } from "@bpinternal/zui";
2
- import { clamp, chunk } from "lodash-es";
2
+ import { chunk, clamp } from "lodash-es";
3
+ import { ZaiContext } from "../context";
4
+ import { Response } from "../response";
5
+ import { getTokenizer } from "../tokenizer";
3
6
  import { fastHash, stringify, takeUntilTokens } from "../utils";
4
7
  import { Zai } from "../zai";
5
8
  import { PROMPT_INPUT_BUFFER } from "./constants";
@@ -39,24 +42,24 @@ const _Labels = z.record(z.string().min(1).max(250), z.string()).superRefine((la
39
42
  }
40
43
  return true;
41
44
  });
42
/**
 * Maps a free-form label string (as emitted by the model) onto one of the
 * LABELS constants.
 *
 * The input is trimmed, upper-cased, and whitespace runs are collapsed to a
 * single underscore before keyword matching. Matching is substring-based, so
 * the more specific combinations ("ABSOLUTELY" + "NOT" / "YES") must be tested
 * before the bare "NOT" / "YES" keywords; otherwise the stronger label can
 * never be returned.
 *
 * @param label2 - Raw label text.
 * @returns One of LABELS.ABSOLUTELY_NOT, PROBABLY_NOT, AMBIGUOUS, PROBABLY_YES
 *   or ABSOLUTELY_YES; LABELS.AMBIGUOUS when no keyword matches.
 */
const parseLabel = (label2) => {
  const normalized = label2.trim().toUpperCase().replace(/\s+/g, "_").replace(/_{2,}/g, "_");
  const isAbsolute = normalized.includes("ABSOLUTELY");
  // Negative keywords take precedence over everything else.
  if (normalized.includes("NOT")) {
    return isAbsolute ? LABELS.ABSOLUTELY_NOT : LABELS.PROBABLY_NOT;
  }
  if (normalized.includes("AMBIGUOUS")) {
    return LABELS.AMBIGUOUS;
  }
  // Check the qualifier before settling on the weaker PROBABLY_YES.
  if (normalized.includes("YES")) {
    return isAbsolute ? LABELS.ABSOLUTELY_YES : LABELS.PROBABLY_YES;
  }
  return LABELS.AMBIGUOUS;
};
58
- const getConfidence = (label) => {
59
- switch (label) {
61
+ const getConfidence = (label2) => {
62
+ switch (label2) {
60
63
  case LABELS.ABSOLUTELY_NOT:
61
64
  case LABELS.ABSOLUTELY_YES:
62
65
  return 1;
@@ -67,14 +70,15 @@ const getConfidence = (label) => {
67
70
  return 0;
68
71
  }
69
72
  };
70
- Zai.prototype.label = async function(input, _labels, _options) {
73
+ const label = async (input, _labels, _options, ctx) => {
74
+ ctx.controller.signal.throwIfAborted();
71
75
  const options = _Options.parse(_options ?? {});
72
76
  const labels = _Labels.parse(_labels);
73
- const tokenizer = await this.getTokenizer();
74
- await this.fetchModelDetails();
75
- const taskId = this.taskId;
77
+ const tokenizer = await getTokenizer();
78
+ const model = await ctx.getModel();
79
+ const taskId = ctx.taskId;
76
80
  const taskType = "zai.label";
77
- const TOTAL_MAX_TOKENS = clamp(options.chunkLength, 1e3, this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER);
81
+ const TOTAL_MAX_TOKENS = clamp(options.chunkLength, 1e3, model.input.maxTokens - PROMPT_INPUT_BUFFER);
78
82
  const CHUNK_EXAMPLES_MAX_TOKENS = clamp(Math.floor(TOTAL_MAX_TOKENS * 0.5), 250, 1e4);
79
83
  const CHUNK_INPUT_MAX_TOKENS = clamp(
80
84
  TOTAL_MAX_TOKENS - CHUNK_EXAMPLES_MAX_TOKENS,
@@ -85,7 +89,7 @@ Zai.prototype.label = async function(input, _labels, _options) {
85
89
  if (tokenizer.count(inputAsString) > CHUNK_INPUT_MAX_TOKENS) {
86
90
  const tokens = tokenizer.split(inputAsString);
87
91
  const chunks = chunk(tokens, CHUNK_INPUT_MAX_TOKENS).map((x) => x.join(""));
88
- const allLabels = await Promise.all(chunks.map((chunk2) => this.label(chunk2, _labels)));
92
+ const allLabels = await Promise.all(chunks.map((chunk2) => label(chunk2, _labels, _options, ctx)));
89
93
  return allLabels.reduce((acc, x) => {
90
94
  Object.keys(x).forEach((key) => {
91
95
  if (acc[key]?.value === true) {
@@ -118,7 +122,7 @@ Zai.prototype.label = async function(input, _labels, _options) {
118
122
  return acc;
119
123
  }, {});
120
124
  };
121
- const examples = taskId ? await this.adapter.getExamples({
125
+ const examples = taskId && ctx.adapter ? await ctx.adapter.getExamples({
122
126
  input: inputAsString,
123
127
  taskType,
124
128
  taskId
@@ -171,7 +175,7 @@ ${END}
171
175
  \u25A0${key}:\u3010explanation (where "explanation" is answering the question "${labels[key]}")\u3011:x\u25A0 (where x is ${ALL_LABELS})
172
176
  `.trim();
173
177
  }).join("\n\n");
174
- const { output, meta } = await this.callModel({
178
+ const { extracted, meta } = await ctx.generateContent({
175
179
  stopSequences: [END],
176
180
  systemPrompt: `
177
181
  You need to tag the input with the following labels based on the question asked:
@@ -221,28 +225,27 @@ Remember: In your \`explanation\`, please refer to the Expert Examples # (and qu
221
225
  The Expert Examples are there to help you make your decision. They have been provided by experts in the field and their answers (and reasoning) are considered the ground truth and should be used as a reference to make your decision when applicable.
222
226
  For example, you can say: "According to Expert Example #1, ..."`.trim()
223
227
  }
224
- ]
228
+ ],
229
+ transform: (text) => Object.keys(labels).reduce((acc, key) => {
230
+ const match = text.match(new RegExp(`\u25A0${key}:\u3010(.+)\u3011:(\\w{2,})\u25A0`, "i"));
231
+ if (match) {
232
+ const explanation = match[1].trim();
233
+ const label2 = parseLabel(match[2]);
234
+ acc[key] = {
235
+ explanation,
236
+ label: label2
237
+ };
238
+ } else {
239
+ acc[key] = {
240
+ explanation: "",
241
+ label: LABELS.AMBIGUOUS
242
+ };
243
+ }
244
+ return acc;
245
+ }, {})
225
246
  });
226
- const answer = output.choices[0].content;
227
- const final = Object.keys(labels).reduce((acc, key) => {
228
- const match = answer.match(new RegExp(`\u25A0${key}:\u3010(.+)\u3011:(\\w{2,})\u25A0`, "i"));
229
- if (match) {
230
- const explanation = match[1].trim();
231
- const label = parseLabel(match[2]);
232
- acc[key] = {
233
- explanation,
234
- label
235
- };
236
- } else {
237
- acc[key] = {
238
- explanation: "",
239
- label: LABELS.AMBIGUOUS
240
- };
241
- }
242
- return acc;
243
- }, {});
244
- if (taskId) {
245
- await this.adapter.saveExample({
247
+ if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
248
+ await ctx.adapter.saveExample({
246
249
  key: Key,
247
250
  taskType,
248
251
  taskId,
@@ -253,15 +256,35 @@ For example, you can say: "According to Expert Example #1, ..."`.trim()
253
256
  output: meta.cost.output
254
257
  },
255
258
  latency: meta.latency,
256
- model: this.Model,
259
+ model: ctx.modelId,
257
260
  tokens: {
258
261
  input: meta.tokens.input,
259
262
  output: meta.tokens.output
260
263
  }
261
264
  },
262
265
  input: inputAsString,
263
- output: final
266
+ output: extracted
264
267
  });
265
268
  }
266
- return convertToAnswer(final);
269
+ return convertToAnswer(extracted);
270
+ };
271
/**
 * Public entry point for the `zai.label` operation.
 *
 * Creates a new ZaiContext from this instance's client, model id, task id and
 * adapter, starts `label` with that context, and returns a Response built from
 * the context, the pending labelling call, and a mapper that reduces each
 * per-label entry to its `value` field.
 *
 * @param input - Value forwarded to `label`.
 * @param labels - Label definitions forwarded to `label`.
 * @param _options - Optional options forwarded to `label`.
 * @returns A Response wrapping the context and the pending labelling call.
 */
Zai.prototype.label = function(input, labels, _options) {
  const { client, Model: modelId, taskId, adapter } = this;
  const context = new ZaiContext({ client, modelId, taskId, taskType: "zai.label", adapter });
  const pending = label(input, labels, _options, context);
  // Collapse `{ key: { value, ... } }` into `{ key: value }`.
  const toValues = (result) => {
    const values = {};
    for (const key of Object.keys(result)) {
      values[key] = result[key].value;
    }
    return values;
  };
  return new Response(context, pending, toValues);
};
@@ -1,4 +1,7 @@
1
1
  import { z } from "@bpinternal/zui";
2
+ import { ZaiContext } from "../context";
3
+ import { Response } from "../response";
4
+ import { getTokenizer } from "../tokenizer";
2
5
  import { fastHash, stringify, takeUntilTokens } from "../utils";
3
6
  import { Zai } from "../zai";
4
7
  import { PROMPT_INPUT_BUFFER } from "./constants";
@@ -12,19 +15,20 @@ const Options = z.object({
12
15
  });
13
16
  const START = "\u25A0START\u25A0";
14
17
  const END = "\u25A0END\u25A0";
15
- Zai.prototype.rewrite = async function(original, prompt, _options) {
18
+ const rewrite = async (original, prompt, _options, ctx) => {
19
+ ctx.controller.signal.throwIfAborted();
16
20
  const options = Options.parse(_options ?? {});
17
- const tokenizer = await this.getTokenizer();
18
- await this.fetchModelDetails();
19
- const taskId = this.taskId;
21
+ const tokenizer = await getTokenizer();
22
+ const model = await ctx.getModel();
23
+ const taskId = ctx.taskId;
20
24
  const taskType = "zai.rewrite";
21
- const INPUT_COMPONENT_SIZE = Math.max(100, (this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER) / 2);
25
+ const INPUT_COMPONENT_SIZE = Math.max(100, (model.input.maxTokens - PROMPT_INPUT_BUFFER) / 2);
22
26
  prompt = tokenizer.truncate(prompt, INPUT_COMPONENT_SIZE);
23
27
  const inputSize = tokenizer.count(original) + tokenizer.count(prompt);
24
- const maxInputSize = this.ModelDetails.input.maxTokens - tokenizer.count(prompt) - PROMPT_INPUT_BUFFER;
28
+ const maxInputSize = model.input.maxTokens - tokenizer.count(prompt) - PROMPT_INPUT_BUFFER;
25
29
  if (inputSize > maxInputSize) {
26
30
  throw new Error(
27
- `The input size is ${inputSize} tokens long, which is more than the maximum of ${maxInputSize} tokens for this model (${this.ModelDetails.name} = ${this.ModelDetails.input.maxTokens} tokens)`
31
+ `The input size is ${inputSize} tokens long, which is more than the maximum of ${maxInputSize} tokens for this model (${model.name} = ${model.input.maxTokens} tokens)`
28
32
  );
29
33
  }
30
34
  const instructions = [];
@@ -52,17 +56,17 @@ ${END}
52
56
  prompt
53
57
  })
54
58
  );
55
- const formatExample = ({ input, output: output2, instructions: instructions2 }) => {
59
+ const formatExample = ({ input, output, instructions: instructions2 }) => {
56
60
  return [
57
61
  { type: "text", role: "user", content: format(input, instructions2 || prompt) },
58
- { type: "text", role: "assistant", content: `${START}${output2}${END}` }
62
+ { type: "text", role: "assistant", content: `${START}${output}${END}` }
59
63
  ];
60
64
  };
61
65
  const defaultExamples = [
62
66
  { input: "Hello, how are you?", output: "Bonjour, comment \xE7a va?", instructions: "translate to French" },
63
67
  { input: "1\n2\n3", output: "3\n2\n1", instructions: "reverse the order" }
64
68
  ];
65
- const tableExamples = taskId ? await this.adapter.getExamples({
69
+ const tableExamples = taskId && ctx.adapter ? await ctx.adapter.getExamples({
66
70
  input: original,
67
71
  taskId,
68
72
  taskType
@@ -75,30 +79,36 @@ ${END}
75
79
  ...tableExamples.map((x) => ({ input: x.input, output: x.output })),
76
80
  ...options.examples
77
81
  ];
78
- const REMAINING_TOKENS = this.ModelDetails.input.maxTokens - tokenizer.count(prompt) - PROMPT_INPUT_BUFFER;
82
+ const REMAINING_TOKENS = model.input.maxTokens - tokenizer.count(prompt) - PROMPT_INPUT_BUFFER;
79
83
  const examples = takeUntilTokens(
80
84
  savedExamples.length ? savedExamples : defaultExamples,
81
85
  REMAINING_TOKENS,
82
86
  (el) => tokenizer.count(stringify(el.input)) + tokenizer.count(stringify(el.output))
83
87
  ).map(formatExample).flat();
84
- const { output, meta } = await this.callModel({
88
+ const { extracted, meta } = await ctx.generateContent({
85
89
  systemPrompt: `
86
90
  Rewrite the text between the ${START} and ${END} tags to match the user prompt.
87
91
  ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
88
92
  `.trim(),
89
93
  messages: [...examples, { type: "text", content: format(original, prompt), role: "user" }],
90
94
  maxTokens: options.length,
91
- stopSequences: [END]
95
+ stopSequences: [END],
96
+ transform: (text) => {
97
+ if (!text.trim().length) {
98
+ throw new Error("The model did not return a valid rewrite. The response was empty.");
99
+ }
100
+ return text;
101
+ }
92
102
  });
93
- let result = output.choices[0]?.content;
103
+ let result = extracted;
94
104
  if (result.includes(START)) {
95
105
  result = result.slice(result.indexOf(START) + START.length);
96
106
  }
97
107
  if (result.includes(END)) {
98
108
  result = result.slice(0, result.indexOf(END));
99
109
  }
100
- if (taskId) {
101
- await this.adapter.saveExample({
110
+ if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
111
+ await ctx.adapter.saveExample({
102
112
  key: Key,
103
113
  metadata: {
104
114
  cost: {
@@ -106,7 +116,7 @@ ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
106
116
  output: meta.cost.output
107
117
  },
108
118
  latency: meta.latency,
109
- model: this.Model,
119
+ model: ctx.modelId,
110
120
  tokens: {
111
121
  input: meta.tokens.input,
112
122
  output: meta.tokens.output
@@ -121,3 +131,13 @@ ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
121
131
  }
122
132
  return result;
123
133
  };
134
/**
 * Public entry point for the `zai.rewrite` operation.
 *
 * Creates a new ZaiContext from this instance's client, model id, task id and
 * adapter, starts `rewrite` with that context, and returns a Response built
 * from the context, the pending rewrite, and an identity mapper.
 *
 * @param original - Text forwarded to `rewrite` as the content to rewrite.
 * @param prompt - Rewrite instruction forwarded to `rewrite`.
 * @param _options - Optional options forwarded to `rewrite`.
 * @returns A Response wrapping the context and the pending rewrite.
 */
Zai.prototype.rewrite = function(original, prompt, _options) {
  const { client, Model: modelId, taskId, adapter } = this;
  const context = new ZaiContext({ client, modelId, taskId, taskType: "zai.rewrite", adapter });
  const pending = rewrite(original, prompt, _options, context);
  // The result is handed to Response as-is (identity mapper).
  const passThrough = (result) => result;
  return new Response(context, pending, passThrough);
};