@botpress/zai 2.0.16 → 2.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,131 @@
1
+ import { EventEmitter } from "./emitter";
2
+ export class ZaiContext {
3
+ _startedAt = Date.now();
4
+ _inputCost = 0;
5
+ _outputCost = 0;
6
+ _inputTokens = 0;
7
+ _outputTokens = 0;
8
+ _totalCachedResponses = 0;
9
+ _totalRequests = 0;
10
+ _totalErrors = 0;
11
+ _totalResponses = 0;
12
+ taskId;
13
+ taskType;
14
+ modelId;
15
+ adapter;
16
+ source;
17
+ _eventEmitter;
18
+ controller = new AbortController();
19
+ _client;
20
+ constructor(props) {
21
+ this._client = props.client.clone();
22
+ this.taskId = props.taskId;
23
+ this.modelId = props.modelId;
24
+ this.adapter = props.adapter;
25
+ this.source = props.source;
26
+ this.taskType = props.taskType;
27
+ this._eventEmitter = new EventEmitter();
28
+ this._client.on("request", () => {
29
+ this._totalRequests++;
30
+ this._eventEmitter.emit("update", this.usage);
31
+ });
32
+ this._client.on("response", (_req, res) => {
33
+ this._totalResponses++;
34
+ if (res.meta.cached) {
35
+ this._totalCachedResponses++;
36
+ } else {
37
+ this._inputTokens += res.meta.tokens.input || 0;
38
+ this._outputTokens += res.meta.tokens.output || 0;
39
+ this._inputCost += res.meta.cost.input || 0;
40
+ this._outputCost += res.meta.cost.output || 0;
41
+ }
42
+ this._eventEmitter.emit("update", this.usage);
43
+ });
44
+ this._client.on("error", () => {
45
+ this._totalErrors++;
46
+ this._eventEmitter.emit("update", this.usage);
47
+ });
48
+ }
49
+ async getModel() {
50
+ return this._client.getModelDetails(this.modelId);
51
+ }
52
+ on(type, listener) {
53
+ this._eventEmitter.on(type, listener);
54
+ return this;
55
+ }
56
+ clear() {
57
+ this._eventEmitter.clear();
58
+ }
59
+ async generateContent(props) {
60
+ const maxRetries = Math.max(props.maxRetries ?? 3, 0);
61
+ const transform = props.transform;
62
+ let lastError = null;
63
+ const messages = [...props.messages || []];
64
+ for (let attempt = 0; attempt <= maxRetries; attempt++) {
65
+ try {
66
+ const response = await this._client.generateContent({
67
+ ...props,
68
+ messages,
69
+ signal: this.controller.signal,
70
+ model: this.modelId,
71
+ meta: {
72
+ integrationName: props.meta?.integrationName || "zai",
73
+ promptCategory: props.meta?.promptCategory || `zai:${this.taskType}`,
74
+ promptSource: props.meta?.promptSource || `zai:${this.taskType}:${this.taskId ?? "default"}`
75
+ }
76
+ });
77
+ const content = response.output.choices[0]?.content;
78
+ const str = typeof content === "string" ? content : content?.[0]?.text || "";
79
+ let output;
80
+ messages.push({
81
+ role: "assistant",
82
+ content: str || "<Invalid output, no content provided>"
83
+ });
84
+ if (!transform) {
85
+ output = str;
86
+ } else {
87
+ output = transform(str, response.output);
88
+ }
89
+ return { meta: response.meta, output: response.output, text: str, extracted: output };
90
+ } catch (error) {
91
+ lastError = error;
92
+ if (attempt === maxRetries) {
93
+ throw lastError;
94
+ }
95
+ messages.push({
96
+ role: "user",
97
+ content: `ERROR PARSING OUTPUT
98
+
99
+ ${lastError.message}.
100
+
101
+ Please return a valid response addressing the error above.`
102
+ });
103
+ }
104
+ }
105
+ throw lastError;
106
+ }
107
+ get elapsedTime() {
108
+ return Date.now() - this._startedAt;
109
+ }
110
+ get usage() {
111
+ return {
112
+ requests: {
113
+ errors: this._totalErrors,
114
+ requests: this._totalRequests,
115
+ responses: this._totalResponses,
116
+ cached: this._totalCachedResponses,
117
+ percentage: this._totalRequests > 0 ? (this._totalResponses + this._totalErrors) / this._totalRequests : 0
118
+ },
119
+ tokens: {
120
+ input: this._inputTokens,
121
+ output: this._outputTokens,
122
+ total: this._inputTokens + this._outputTokens
123
+ },
124
+ cost: {
125
+ input: this._inputCost,
126
+ output: this._outputCost,
127
+ total: this._inputCost + this._outputCost
128
+ }
129
+ };
130
+ }
131
+ }
@@ -0,0 +1,42 @@
1
+ export class EventEmitter {
2
+ _listeners = {};
3
+ emit(type, event) {
4
+ const listeners = this._listeners[type];
5
+ if (!listeners) {
6
+ return;
7
+ }
8
+ for (const listener of listeners) {
9
+ listener(event);
10
+ }
11
+ }
12
+ once(type, listener) {
13
+ const wrapped = (event) => {
14
+ this.off(type, wrapped);
15
+ listener(event);
16
+ };
17
+ this.on(type, wrapped);
18
+ }
19
+ on(type, listener) {
20
+ if (!this._listeners[type]) {
21
+ this._listeners[type] = [];
22
+ }
23
+ this._listeners[type].push(listener);
24
+ }
25
+ off(type, listener) {
26
+ const listeners = this._listeners[type];
27
+ if (!listeners) {
28
+ return;
29
+ }
30
+ const index = listeners.indexOf(listener);
31
+ if (index !== -1) {
32
+ listeners.splice(index, 1);
33
+ }
34
+ }
35
+ clear(type) {
36
+ if (type) {
37
+ delete this._listeners[type];
38
+ } else {
39
+ this._listeners = {};
40
+ }
41
+ }
42
+ }
package/dist/index.d.ts CHANGED
@@ -1,4 +1,4 @@
1
- import { Cognitive, Model, BotpressClientLike } from '@botpress/cognitive';
1
+ import { Cognitive, Model, BotpressClientLike, GenerateContentInput, GenerateContentOutput } from '@botpress/cognitive';
2
2
  import { TextTokenizer } from '@bpinternal/thicktoken';
3
3
 
4
4
  type GenerationMetadata = {
@@ -74,6 +74,99 @@ declare class Zai {
74
74
  learn(taskId: string): Zai;
75
75
  }
76
76
 
77
+ type Meta = Awaited<ReturnType<Cognitive['generateContent']>>['meta'];
78
+ type GenerateContentProps<T> = Omit<GenerateContentInput, 'model' | 'signal'> & {
79
+ maxRetries?: number;
80
+ transform?: (text: string | undefined, output: GenerateContentOutput) => T;
81
+ };
82
+ type ZaiContextProps = {
83
+ client: Cognitive;
84
+ taskType: string;
85
+ taskId: string;
86
+ modelId: string;
87
+ adapter?: Adapter;
88
+ source?: GenerateContentInput['meta'];
89
+ };
90
+ type Usage = {
91
+ requests: {
92
+ requests: number;
93
+ errors: number;
94
+ responses: number;
95
+ cached: number;
96
+ percentage: number;
97
+ };
98
+ cost: {
99
+ input: number;
100
+ output: number;
101
+ total: number;
102
+ };
103
+ tokens: {
104
+ input: number;
105
+ output: number;
106
+ total: number;
107
+ };
108
+ };
109
+ type ContextEvents = {
110
+ update: Usage;
111
+ };
112
+ declare class ZaiContext {
113
+ private _startedAt;
114
+ private _inputCost;
115
+ private _outputCost;
116
+ private _inputTokens;
117
+ private _outputTokens;
118
+ private _totalCachedResponses;
119
+ private _totalRequests;
120
+ private _totalErrors;
121
+ private _totalResponses;
122
+ taskId: string;
123
+ taskType: string;
124
+ modelId: GenerateContentInput['model'];
125
+ adapter?: Adapter;
126
+ source?: GenerateContentInput['meta'];
127
+ private _eventEmitter;
128
+ controller: AbortController;
129
+ private _client;
130
+ constructor(props: ZaiContextProps);
131
+ getModel(): Promise<Model>;
132
+ on<K extends keyof ContextEvents>(type: K, listener: (event: ContextEvents[K]) => void): this;
133
+ clear(): void;
134
+ generateContent<Out = string>(props: GenerateContentProps<Out>): Promise<{
135
+ meta: Meta;
136
+ output: GenerateContentOutput;
137
+ text: string | undefined;
138
+ extracted: Out;
139
+ }>;
140
+ get elapsedTime(): number;
141
+ get usage(): Usage;
142
+ }
143
+
144
+ type ResponseEvents<TComplete = any> = {
145
+ progress: Usage;
146
+ complete: TComplete;
147
+ error: unknown;
148
+ };
149
+ declare class Response<T = any, S = T> implements PromiseLike<S> {
150
+ private _promise;
151
+ private _eventEmitter;
152
+ private _context;
153
+ private _elasped;
154
+ private _simplify;
155
+ constructor(context: ZaiContext, promise: Promise<T>, simplify: (value: T) => S);
156
+ on<K extends keyof ResponseEvents<T>>(type: K, listener: (event: ResponseEvents<T>[K]) => void): this;
157
+ off<K extends keyof ResponseEvents<T>>(type: K, listener: (event: ResponseEvents<T>[K]) => void): this;
158
+ once<K extends keyof ResponseEvents<T>>(type: K, listener: (event: ResponseEvents<T>[K]) => void): this;
159
+ bindSignal(signal: AbortSignal): this;
160
+ abort(reason?: string | Error): void;
161
+ then<TResult1 = S, TResult2 = never>(onfulfilled?: ((value: S) => TResult1 | PromiseLike<TResult1>) | null, onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null): PromiseLike<TResult1 | TResult2>;
162
+ catch<TResult = never>(onrejected?: ((reason: any) => TResult | PromiseLike<TResult>) | null): PromiseLike<S | TResult>;
163
+ result(): Promise<{
164
+ output: T;
165
+ usage: Usage;
166
+ elapsed: number;
167
+ }>;
168
+ }
169
+
77
170
  type Options$6 = {
78
171
  /** The maximum number of tokens to generate */
79
172
  length?: number;
@@ -81,7 +174,7 @@ type Options$6 = {
81
174
  declare module '@botpress/zai' {
82
175
  interface Zai {
83
176
  /** Generates a text of the desired length according to the prompt */
84
- text(prompt: string, options?: Options$6): Promise<string>;
177
+ text(prompt: string, options?: Options$6): Response<string>;
85
178
  }
86
179
  }
87
180
 
@@ -99,7 +192,7 @@ type Options$5 = {
99
192
  declare module '@botpress/zai' {
100
193
  interface Zai {
101
194
  /** Rewrites a string to match the prompt */
102
- rewrite(original: string, prompt: string, options?: Options$5): Promise<string>;
195
+ rewrite(original: string, prompt: string, options?: Options$5): Response<string>;
103
196
  }
104
197
  }
105
198
 
@@ -123,7 +216,7 @@ type Options$4 = {
123
216
  declare module '@botpress/zai' {
124
217
  interface Zai {
125
218
  /** Summarizes a text of any length to a summary of the desired length */
126
- summarize(original: string, options?: Options$4): Promise<string>;
219
+ summarize(original: string, options?: Options$4): Response<string>;
127
220
  }
128
221
  }
129
222
 
@@ -140,12 +233,12 @@ type Options$3 = {
140
233
  declare module '@botpress/zai' {
141
234
  interface Zai {
142
235
  /** Checks whether a condition is true or not */
143
- check(input: unknown, condition: string, options?: Options$3): Promise<{
236
+ check(input: unknown, condition: string, options?: Options$3): Response<{
144
237
  /** Whether the condition is true or not */
145
238
  value: boolean;
146
239
  /** The explanation of the decision */
147
240
  explanation: string;
148
- }>;
241
+ }, boolean>;
149
242
  }
150
243
  }
151
244
 
@@ -163,7 +256,7 @@ type Options$2 = {
163
256
  declare module '@botpress/zai' {
164
257
  interface Zai {
165
258
  /** Filters elements of an array against a condition */
166
- filter<T>(input: Array<T>, condition: string, options?: Options$2): Promise<Array<T>>;
259
+ filter<T>(input: Array<T>, condition: string, options?: Options$2): Response<Array<T>>;
167
260
  }
168
261
  }
169
262
 
@@ -182,7 +275,7 @@ type OfType<O, T extends __Z = __Z<O>> = T extends __Z<O> ? T : never;
182
275
  declare module '@botpress/zai' {
183
276
  interface Zai {
184
277
  /** Extracts one or many elements from an arbitrary input */
185
- extract<S extends OfType<any>>(input: unknown, schema: S, options?: Options$1): Promise<S['_output']>;
278
+ extract<S extends OfType<any>>(input: unknown, schema: S, options?: Options$1): Response<S['_output']>;
186
279
  }
187
280
  }
188
281
 
@@ -213,12 +306,14 @@ type Labels<T extends string> = Record<T, string>;
213
306
  declare module '@botpress/zai' {
214
307
  interface Zai {
215
308
  /** Tags the provided input with a list of predefined labels */
216
- label<T extends string>(input: unknown, labels: Labels<T>, options?: Options<T>): Promise<{
309
+ label<T extends string>(input: unknown, labels: Labels<T>, options?: Options<T>): Response<{
217
310
  [K in T]: {
218
311
  explanation: string;
219
312
  value: boolean;
220
313
  confidence: number;
221
314
  };
315
+ }, {
316
+ [K in T]: boolean;
222
317
  }>;
223
318
  }
224
319
  }
@@ -1,4 +1,7 @@
1
1
  import { z } from "@bpinternal/zui";
2
+ import { ZaiContext } from "../context";
3
+ import { Response } from "../response";
4
+ import { getTokenizer } from "../tokenizer";
2
5
  import { fastHash, stringify, takeUntilTokens } from "../utils";
3
6
  import { Zai } from "../zai";
4
7
  import { PROMPT_INPUT_BUFFER } from "./constants";
@@ -14,12 +17,12 @@ const _Options = z.object({
14
17
  const TRUE = "\u25A0TRUE\u25A0";
15
18
  const FALSE = "\u25A0FALSE\u25A0";
16
19
  const END = "\u25A0END\u25A0";
17
- Zai.prototype.check = async function(input, condition, _options) {
18
- const options = _Options.parse(_options ?? {});
19
- const tokenizer = await this.getTokenizer();
20
- await this.fetchModelDetails();
21
- const PROMPT_COMPONENT = Math.max(this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER, 100);
22
- const taskId = this.taskId;
20
+ const check = async (input, condition, options, ctx) => {
21
+ ctx.controller.signal.throwIfAborted();
22
+ const tokenizer = await getTokenizer();
23
+ const model = await ctx.getModel();
24
+ const PROMPT_COMPONENT = Math.max(model.input.maxTokens - PROMPT_INPUT_BUFFER, 100);
25
+ const taskId = ctx.taskId;
23
26
  const taskType = "zai.check";
24
27
  const PROMPT_TOKENS = {
25
28
  INPUT: Math.floor(0.5 * PROMPT_COMPONENT),
@@ -36,7 +39,7 @@ Zai.prototype.check = async function(input, condition, _options) {
36
39
  condition
37
40
  })
38
41
  );
39
- const examples = taskId ? await this.adapter.getExamples({
42
+ const examples = taskId && ctx.adapter ? await ctx.adapter.getExamples({
40
43
  input: inputAsString,
41
44
  taskType,
42
45
  taskId
@@ -73,10 +76,10 @@ ${input2.trim()}
73
76
  <|end_input|>
74
77
  `.trim();
75
78
  };
76
- const formatOutput = (answer2, justification) => {
79
+ const formatOutput = (answer, justification) => {
77
80
  return `
78
81
  Analysis: ${justification}
79
- Final Answer: ${answer2 ? TRUE : FALSE}
82
+ Final Answer: ${answer ? TRUE : FALSE}
80
83
  ${END}
81
84
  `.trim();
82
85
  };
@@ -103,7 +106,10 @@ ${END}
103
106
  - When in doubt, ground your decision on the examples provided by the experts instead of your own intuition.
104
107
  - When no example is similar to the input, make sure to provide a clear justification for your decision while inferring the decision-making process from the examples provided by the experts.
105
108
  `.trim() : "";
106
- const { output, meta } = await this.callModel({
109
+ const {
110
+ extracted: { finalAnswer, explanation },
111
+ meta
112
+ } = await ctx.generateContent({
107
113
  systemPrompt: `
108
114
  Check if the following condition is true or false for the given input. Before answering, make sure to read the input and the condition carefully.
109
115
  Justify your answer, then answer with either ${TRUE} or ${FALSE} at the very end, then add ${END} to finish the response.
@@ -123,23 +129,25 @@ ${formatInput(inputAsString, condition)}
123
129
  In your "Analysis", please refer to the Expert Examples # to justify your decision.`.trim(),
124
130
  role: "user"
125
131
  }
126
- ]
132
+ ],
133
+ transform: (text) => {
134
+ const hasTrue = text.includes(TRUE);
135
+ const hasFalse = text.includes(FALSE);
136
+ if (!hasTrue && !hasFalse) {
137
+ throw new Error(`The model did not return a valid answer. The response was: ${text}`);
138
+ }
139
+ let finalAnswer2;
140
+ const explanation2 = text.replace(TRUE, "").replace(FALSE, "").replace(END, "").replace("Final Answer:", "").replace("Analysis:", "").trim();
141
+ if (hasTrue && hasFalse) {
142
+ finalAnswer2 = text.lastIndexOf(TRUE) > text.lastIndexOf(FALSE);
143
+ } else {
144
+ finalAnswer2 = hasTrue;
145
+ }
146
+ return { finalAnswer: finalAnswer2, explanation: explanation2.trim() };
147
+ }
127
148
  });
128
- const answer = output.choices[0]?.content;
129
- const hasTrue = answer.includes(TRUE);
130
- const hasFalse = answer.includes(FALSE);
131
- if (!hasTrue && !hasFalse) {
132
- throw new Error(`The model did not return a valid answer. The response was: ${answer}`);
133
- }
134
- let finalAnswer;
135
- const explanation = answer.replace(TRUE, "").replace(FALSE, "").replace(END, "").replace("Final Answer:", "").replace("Analysis:", "").trim();
136
- if (hasTrue && hasFalse) {
137
- finalAnswer = answer.lastIndexOf(TRUE) > answer.lastIndexOf(FALSE);
138
- } else {
139
- finalAnswer = hasTrue;
140
- }
141
- if (taskId) {
142
- await this.adapter.saveExample({
149
+ if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
150
+ await ctx.adapter.saveExample({
143
151
  key: Key,
144
152
  taskType,
145
153
  taskId,
@@ -151,7 +159,7 @@ In your "Analysis", please refer to the Expert Examples # to justify your decisi
151
159
  output: meta.cost.output
152
160
  },
153
161
  latency: meta.latency,
154
- model: this.Model,
162
+ model: ctx.modelId,
155
163
  tokens: {
156
164
  input: meta.tokens.input,
157
165
  output: meta.tokens.output
@@ -166,3 +174,14 @@ In your "Analysis", please refer to the Expert Examples # to justify your decisi
166
174
  explanation: explanation.trim()
167
175
  };
168
176
  };
177
+ Zai.prototype.check = function(input, condition, _options) {
178
+ const options = _Options.parse(_options ?? {});
179
+ const context = new ZaiContext({
180
+ client: this.client,
181
+ modelId: this.Model,
182
+ taskId: this.taskId,
183
+ taskType: "zai.check",
184
+ adapter: this.adapter
185
+ });
186
+ return new Response(context, check(input, condition, options, context), (result) => result.value);
187
+ };
@@ -2,6 +2,9 @@ import { z } from "@bpinternal/zui";
2
2
  import JSON5 from "json5";
3
3
  import { jsonrepair } from "jsonrepair";
4
4
  import { chunk, isArray } from "lodash-es";
5
+ import { ZaiContext } from "../context";
6
+ import { Response } from "../response";
7
+ import { getTokenizer } from "../tokenizer";
5
8
  import { fastHash, stringify, takeUntilTokens } from "../utils";
6
9
  import { Zai } from "../zai";
7
10
  import { PROMPT_INPUT_BUFFER } from "./constants";
@@ -14,14 +17,15 @@ const Options = z.object({
14
17
  const START = "\u25A0json_start\u25A0";
15
18
  const END = "\u25A0json_end\u25A0";
16
19
  const NO_MORE = "\u25A0NO_MORE_ELEMENT\u25A0";
17
- Zai.prototype.extract = async function(input, _schema, _options) {
20
+ const extract = async (input, _schema, _options, ctx) => {
21
+ ctx.controller.signal.throwIfAborted();
18
22
  let schema = _schema;
19
23
  const options = Options.parse(_options ?? {});
20
- const tokenizer = await this.getTokenizer();
21
- await this.fetchModelDetails();
22
- const taskId = this.taskId;
24
+ const tokenizer = await getTokenizer();
25
+ const model = await ctx.getModel();
26
+ const taskId = ctx.taskId;
23
27
  const taskType = "zai.extract";
24
- const PROMPT_COMPONENT = Math.max(this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER, 100);
28
+ const PROMPT_COMPONENT = Math.max(model.input.maxTokens - PROMPT_INPUT_BUFFER, 100);
25
29
  let isArrayOfObjects = false;
26
30
  let wrappedValue = false;
27
31
  const originalSchema = schema;
@@ -54,10 +58,7 @@ Zai.prototype.extract = async function(input, _schema, _options) {
54
58
  }
55
59
  const schemaTypescript = schema.toTypescriptType({ declaration: false });
56
60
  const schemaLength = tokenizer.count(schemaTypescript);
57
- options.chunkLength = Math.min(
58
- options.chunkLength,
59
- this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER - schemaLength
60
- );
61
+ options.chunkLength = Math.min(options.chunkLength, model.input.maxTokens - PROMPT_INPUT_BUFFER - schemaLength);
61
62
  const keys = Object.keys(schema.shape);
62
63
  const inputAsString = stringify(input);
63
64
  if (tokenizer.count(inputAsString) > options.chunkLength) {
@@ -65,19 +66,25 @@ Zai.prototype.extract = async function(input, _schema, _options) {
65
66
  const chunks = chunk(tokens, options.chunkLength).map((x) => x.join(""));
66
67
  const all = await Promise.allSettled(
67
68
  chunks.map(
68
- (chunk2) => this.extract(chunk2, originalSchema, {
69
- ...options,
70
- strict: false
71
- // We don't want to fail on strict mode for sub-chunks
72
- })
69
+ (chunk2) => extract(
70
+ chunk2,
71
+ originalSchema,
72
+ {
73
+ ...options,
74
+ strict: false
75
+ // We don't want to fail on strict mode for sub-chunks
76
+ },
77
+ ctx
78
+ )
73
79
  )
74
80
  ).then(
75
81
  (results) => results.filter((x) => x.status === "fulfilled").map((x) => x.value)
76
82
  );
83
+ ctx.controller.signal.throwIfAborted();
77
84
  const rows = all.map((x, idx) => `<part-${idx + 1}>
78
85
  ${stringify(x, true)}
79
86
  </part-${idx + 1}>`).join("\n");
80
- return this.extract(
87
+ return extract(
81
88
  `
82
89
  The result has been split into ${all.length} parts. Recursively merge the result into the final result.
83
90
  When merging arrays, take unique values.
@@ -89,7 +96,8 @@ ${rows}
89
96
 
90
97
  Merge it back into a final result.`.trim(),
91
98
  originalSchema,
92
- options
99
+ options,
100
+ ctx
93
101
  );
94
102
  }
95
103
  const instructions = [];
@@ -123,7 +131,7 @@ Merge it back into a final result.`.trim(),
123
131
  instructions: options.instructions
124
132
  })
125
133
  );
126
- const examples = taskId ? await this.adapter.getExamples({
134
+ const examples = taskId && ctx.adapter ? await ctx.adapter.getExamples({
127
135
  input: inputAsString,
128
136
  taskType,
129
137
  taskId
@@ -182,9 +190,9 @@ ${input2.trim()}
182
190
  <|end_input|>
183
191
  `.trim();
184
192
  };
185
- const formatOutput = (extracted) => {
186
- extracted = isArray(extracted) ? extracted : [extracted];
187
- return extracted.map(
193
+ const formatOutput = (extracted2) => {
194
+ extracted2 = isArray(extracted2) ? extracted2 : [extracted2];
195
+ return extracted2.map(
188
196
  (x) => `
189
197
  ${START}
190
198
  ${JSON.stringify(x, null, 2)}
@@ -208,7 +216,7 @@ ${END}`.trim()
208
216
  EXAMPLES_TOKENS,
209
217
  (el) => tokenizer.count(stringify(el.input)) + tokenizer.count(stringify(el.extracted))
210
218
  ).map(formatExample).flat();
211
- const { output, meta } = await this.callModel({
219
+ const { meta, extracted } = await ctx.generateContent({
212
220
  systemPrompt: `
213
221
  Extract the following information from the input:
214
222
  ${schemaTypescript}
@@ -224,33 +232,32 @@ ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
224
232
  type: "text",
225
233
  content: formatInput(inputAsString, schemaTypescript, options.instructions ?? "")
226
234
  }
227
- ]
228
- });
229
- const answer = output.choices[0]?.content ?? "{}";
230
- const elements = answer?.split(START).filter((x) => x.trim().length > 0 && x.includes("}")).map((x) => {
231
- try {
232
- const json = x.slice(0, x.indexOf(END)).trim();
233
- const repairedJson = jsonrepair(json);
234
- const parsedJson = JSON5.parse(repairedJson);
235
- const safe = schema.safeParse(parsedJson);
236
- if (safe.success) {
237
- return safe.data;
238
- }
239
- if (options.strict) {
240
- throw new JsonParsingError(x, safe.error);
235
+ ],
236
+ transform: (text) => (text || "{}")?.split(START).filter((x) => x.trim().length > 0 && x.includes("}")).map((x) => {
237
+ try {
238
+ const json = x.slice(0, x.indexOf(END)).trim();
239
+ const repairedJson = jsonrepair(json);
240
+ const parsedJson = JSON5.parse(repairedJson);
241
+ const safe = schema.safeParse(parsedJson);
242
+ if (safe.success) {
243
+ return safe.data;
244
+ }
245
+ if (options.strict) {
246
+ throw new JsonParsingError(x, safe.error);
247
+ }
248
+ return parsedJson;
249
+ } catch (error) {
250
+ throw new JsonParsingError(x, error instanceof Error ? error : new Error("Unknown error"));
241
251
  }
242
- return parsedJson;
243
- } catch (error) {
244
- throw new JsonParsingError(x, error instanceof Error ? error : new Error("Unknown error"));
245
- }
246
- }).filter((x) => x !== null);
252
+ }).filter((x) => x !== null)
253
+ });
247
254
  let final;
248
255
  if (isArrayOfObjects) {
249
- final = elements;
250
- } else if (elements.length === 0) {
256
+ final = extracted;
257
+ } else if (extracted.length === 0) {
251
258
  final = options.strict ? schema.parse({}) : {};
252
259
  } else {
253
- final = elements[0];
260
+ final = extracted[0];
254
261
  }
255
262
  if (wrappedValue) {
256
263
  if (Array.isArray(final)) {
@@ -259,8 +266,8 @@ ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
259
266
  final = "value" in final ? final.value : final;
260
267
  }
261
268
  }
262
- if (taskId) {
263
- await this.adapter.saveExample({
269
+ if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
270
+ await ctx.adapter.saveExample({
264
271
  key: Key,
265
272
  taskId: `zai/${taskId}`,
266
273
  taskType,
@@ -273,7 +280,7 @@ ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
273
280
  output: meta.cost.output
274
281
  },
275
282
  latency: meta.latency,
276
- model: this.Model,
283
+ model: ctx.modelId,
277
284
  tokens: {
278
285
  input: meta.tokens.input,
279
286
  output: meta.tokens.output
@@ -283,3 +290,13 @@ ${instructions.map((x) => `\u2022 ${x}`).join("\n")}
283
290
  }
284
291
  return final;
285
292
  };
293
+ Zai.prototype.extract = function(input, schema, _options) {
294
+ const context = new ZaiContext({
295
+ client: this.client,
296
+ modelId: this.Model,
297
+ taskId: this.taskId,
298
+ taskType: "zai.extract",
299
+ adapter: this.adapter
300
+ });
301
+ return new Response(context, extract(input, schema, _options, context), (result) => result);
302
+ };