@botpress/zai 2.0.15 → 2.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,131 @@
1
+ import { EventEmitter } from "./emitter";
2
/**
 * Execution context for a single Zai task run.
 *
 * Holds a clone of the cognitive client (obtained through `client.clone()`),
 * tallies requests, responses, errors, tokens and cost from that client's
 * "request" / "response" / "error" events, and publishes the aggregated usage
 * as an "update" event after each of them. It also owns the AbortController
 * whose signal is attached to every generation request it issues.
 */
export class ZaiContext {
  /** Epoch milliseconds captured when the context was constructed. */
  _startedAt = Date.now();
  /** Accumulated input cost of non-cached responses. */
  _inputCost = 0;
  /** Accumulated output cost of non-cached responses. */
  _outputCost = 0;
  /** Accumulated input tokens of non-cached responses. */
  _inputTokens = 0;
  /** Accumulated output tokens of non-cached responses. */
  _outputTokens = 0;
  /** Number of responses flagged `meta.cached`. */
  _totalCachedResponses = 0;
  /** Number of "request" events seen on the client. */
  _totalRequests = 0;
  /** Number of "error" events seen on the client. */
  _totalErrors = 0;
  /** Number of "response" events seen on the client (cached or not). */
  _totalResponses = 0;
  taskId;
  taskType;
  modelId;
  adapter;
  /** Stored from `props.source`; not read anywhere inside this class. */
  source;
  _eventEmitter;
  /** Its signal is passed to every request made through `generateContent`. */
  controller = new AbortController();
  _client;

  /**
   * @param props `client`, `taskId`, `modelId`, `taskType` and the optional
   *   `adapter` / `source`. The client is cloned so the listeners registered
   *   here are attached to the clone rather than to the object passed in.
   */
  constructor(props) {
    this._client = props.client.clone();
    this.taskId = props.taskId;
    this.modelId = props.modelId;
    this.adapter = props.adapter;
    this.source = props.source;
    this.taskType = props.taskType;
    this._eventEmitter = new EventEmitter();

    const publishUsage = () => this._eventEmitter.emit("update", this.usage);

    this._client.on("request", () => {
      this._totalRequests++;
      publishUsage();
    });
    this._client.on("response", (_req, res) => {
      this._totalResponses++;
      if (res.meta.cached) {
        // Cached responses are counted but contribute no tokens or cost.
        this._totalCachedResponses++;
      } else {
        this._inputTokens += res.meta.tokens.input || 0;
        this._outputTokens += res.meta.tokens.output || 0;
        this._inputCost += res.meta.cost.input || 0;
        this._outputCost += res.meta.cost.output || 0;
      }
      publishUsage();
    });
    this._client.on("error", () => {
      this._totalErrors++;
      publishUsage();
    });
  }

  /** Resolves the details of `modelId` through the client. */
  async getModel() {
    return this._client.getModelDetails(this.modelId);
  }

  /**
   * Registers a listener on this context's emitter ("update" carries the
   * current usage snapshot).
   * @returns this context, for chaining.
   */
  on(type, listener) {
    this._eventEmitter.on(type, listener);
    return this;
  }

  /** Removes every listener registered through `on`. */
  clear() {
    this._eventEmitter.clear();
  }

  /**
   * Calls the model and optionally post-processes the text with
   * `props.transform`, retrying on failure.
   *
   * Every attempt sends the conversation accumulated so far: the model's reply
   * is appended as an assistant message before `transform` runs, and when an
   * attempt throws (in the client call or in `transform`) the error text is
   * appended as a user message so the next attempt can correct itself.
   *
   * @param props Generation input plus `maxRetries` (default 3, negatives
   *   clamp to 0, NaN falls back to 3) and an optional
   *   `transform(text, output)`.
   * @returns `{ meta, output, text, extracted }`; `extracted` is the transform
   *   result, or the raw text when no transform is given.
   * @throws The last error once the attempts are exhausted, or immediately
   *   when this context's signal has been aborted.
   */
  async generateContent(props) {
    const clamped = Math.max(props.maxRetries ?? 3, 0);
    // A NaN budget would skip the loop entirely and end up throwing `null`.
    const maxRetries = Number.isNaN(clamped) ? 3 : clamped;
    const transform = props.transform;
    let lastError = null;
    const messages = [...props.messages || []];

    for (let attempt = 0; attempt <= maxRetries; attempt++) {
      try {
        const response = await this._client.generateContent({
          ...props,
          messages,
          signal: this.controller.signal,
          model: this.modelId,
          meta: {
            integrationName: props.meta?.integrationName || "zai",
            promptCategory: props.meta?.promptCategory || `zai:${this.taskType}`,
            promptSource: props.meta?.promptSource || `zai:${this.taskType}:${this.taskId ?? "default"}`
          }
        });
        const content = response.output.choices[0]?.content;
        const str = typeof content === "string" ? content : content?.[0]?.text || "";

        // Recorded before transforming so that a transform failure is followed
        // by the error feedback in the right conversational order.
        messages.push({
          role: "assistant",
          content: str || "<Invalid output, no content provided>"
        });

        const output = transform ? transform(str, response.output) : str;
        return { meta: response.meta, output: response.output, text: str, extracted: output };
      } catch (error) {
        lastError = error;
        // Retrying after cancellation would only issue requests that are
        // already doomed, so an aborted signal ends the loop right away.
        if (attempt === maxRetries || this.controller.signal.aborted) {
          throw lastError;
        }
        // Thrown values are not guaranteed to be Error instances.
        const reason = typeof error?.message === "string" ? error.message : String(error);
        messages.push({
          role: "user",
          content: `ERROR PARSING OUTPUT\n\n${reason}.\n\nPlease return a valid response addressing the error above.`
        });
      }
    }
    // Not reachable: the final attempt either returns or throws above.
    throw lastError;
  }

  /** Milliseconds elapsed since the context was constructed. */
  get elapsedTime() {
    return Date.now() - this._startedAt;
  }

  /**
   * Snapshot of the counters gathered so far. `requests.percentage` is
   * (responses + errors) / requests, or 0 before the first request.
   */
  get usage() {
    return {
      requests: {
        errors: this._totalErrors,
        requests: this._totalRequests,
        responses: this._totalResponses,
        cached: this._totalCachedResponses,
        percentage: this._totalRequests > 0 ? (this._totalResponses + this._totalErrors) / this._totalRequests : 0
      },
      tokens: {
        input: this._inputTokens,
        output: this._outputTokens,
        total: this._inputTokens + this._outputTokens
      },
      cost: {
        input: this._inputCost,
        output: this._outputCost,
        total: this._inputCost + this._outputCost
      }
    };
  }
}
@@ -0,0 +1,42 @@
1
/**
 * Minimal synchronous event emitter: listeners are stored per event type and
 * invoked in registration order with the emitted payload.
 */
export class EventEmitter {
  /**
   * Listener arrays keyed by event type. The registry has no prototype so that
   * inherited keys such as "constructor" or "toString" are never mistaken for
   * registered event types.
   */
  _listeners = Object.create(null);

  /**
   * Invokes every listener registered for `type` with `event`.
   * Does nothing when no listener is registered for that type.
   */
  emit(type, event) {
    const listeners = this._listeners[type];
    if (!listeners) {
      return;
    }
    // Iterate over a snapshot: a `once` wrapper removes itself while running,
    // and splicing the live array mid-iteration would skip the next listener.
    for (const listener of [...listeners]) {
      listener(event);
    }
  }

  /** Registers `listener` so that it runs for the next `type` event only. */
  once(type, listener) {
    const wrapped = (event) => {
      this.off(type, wrapped);
      listener(event);
    };
    // Remember the original so `off(type, listener)` can find this wrapper.
    wrapped.listener = listener;
    this.on(type, wrapped);
  }

  /** Registers `listener` for every subsequent `type` event. */
  on(type, listener) {
    if (!this._listeners[type]) {
      this._listeners[type] = [];
    }
    this._listeners[type].push(listener);
  }

  /**
   * Removes the first registration of `listener` for `type`, whether it was
   * added with `on` or with `once`. Unknown listeners are ignored.
   */
  off(type, listener) {
    const listeners = this._listeners[type];
    if (!listeners) {
      return;
    }
    const index = listeners.findIndex((entry) => entry === listener || entry.listener === listener);
    if (index !== -1) {
      listeners.splice(index, 1);
    }
  }

  /** Removes the listeners of `type`, or all listeners when `type` is omitted. */
  clear(type) {
    if (type) {
      delete this._listeners[type];
    } else {
      this._listeners = Object.create(null);
    }
  }
}
package/dist/index.d.ts CHANGED
@@ -1,4 +1,4 @@
1
- import { Cognitive, Model, BotpressClientLike } from '@botpress/cognitive';
1
+ import { Cognitive, Model, BotpressClientLike, GenerateContentInput, GenerateContentOutput } from '@botpress/cognitive';
2
2
  import { TextTokenizer } from '@bpinternal/thicktoken';
3
3
 
4
4
  type GenerationMetadata = {
@@ -74,6 +74,99 @@ declare class Zai {
74
74
  learn(taskId: string): Zai;
75
75
  }
76
76
 
77
/** The `meta` field of the value that `Cognitive.generateContent` resolves to. */
+ type Meta = Awaited<ReturnType<Cognitive['generateContent']>>['meta'];
78
/**
 * Input accepted by `ZaiContext.generateContent`: the cognitive generation input
 * without `model` and `signal` (the context supplies its own model id and abort
 * signal), extended with retry and post-processing options.
 */
+ type GenerateContentProps<T> = Omit<GenerateContentInput, 'model' | 'signal'> & {
79
/** Number of retries after the first attempt; defaults to 3 and negative values are treated as 0. */
+ maxRetries?: number;
80
/**
 * Maps the generated text (and the raw output) to the `extracted` result.
 * An exception thrown here counts as a failed attempt and triggers a retry.
 */
+ transform?: (text: string | undefined, output: GenerateContentOutput) => T;
81
+ };
82
/** Constructor arguments of `ZaiContext`. */
+ type ZaiContextProps = {
83
/** Cognitive client; the context works on a clone obtained through `client.clone()`. */
+ client: Cognitive;
84
/** Task type, used to build the default `promptCategory` / `promptSource` metadata. */
+ taskType: string;
85
/** Task identifier, used in the default `promptSource` metadata. */
+ taskId: string;
86
/** Model used for every generation request issued through the context. */
+ modelId: string;
87
+ adapter?: Adapter;
88
+ source?: GenerateContentInput['meta'];
89
+ };
90
/** Aggregated counters exposed by `ZaiContext.usage` and carried by its "update" event. */
+ type Usage = {
91
+ requests: {
92
/** Number of requests issued by the client. */
+ requests: number;
93
/** Number of client errors. */
+ errors: number;
94
/** Number of responses received, cached ones included. */
+ responses: number;
95
/** Number of responses that were served from cache. */
+ cached: number;
96
/** (responses + errors) / requests, or 0 while no request has been issued. */
+ percentage: number;
97
+ };
98
/** Cost accumulated from non-cached responses. */
+ cost: {
99
+ input: number;
100
+ output: number;
101
/** input + output. */
+ total: number;
102
+ };
103
/** Tokens accumulated from non-cached responses. */
+ tokens: {
104
+ input: number;
105
+ output: number;
106
/** input + output. */
+ total: number;
107
+ };
108
+ };
109
/** Events published by `ZaiContext`, keyed by event name. */
+ type ContextEvents = {
110
/** Emitted with the current usage after every client request, response and error. */
+ update: Usage;
111
+ };
112
/**
 * Execution context of a single Zai task: wraps a clone of the cognitive client,
 * tracks request/token/cost usage from the client's events and owns the
 * AbortController whose signal is attached to every generation request.
 */
+ declare class ZaiContext {
113
/** Timestamp (Date.now()) taken at construction; basis of `elapsedTime`. */
+ private _startedAt;
114
+ private _inputCost;
115
+ private _outputCost;
116
+ private _inputTokens;
117
+ private _outputTokens;
118
+ private _totalCachedResponses;
119
+ private _totalRequests;
120
+ private _totalErrors;
121
+ private _totalResponses;
122
+ taskId: string;
123
+ taskType: string;
124
+ modelId: GenerateContentInput['model'];
125
+ adapter?: Adapter;
126
+ source?: GenerateContentInput['meta'];
127
+ private _eventEmitter;
128
/** Aborting this controller cancels requests made through `generateContent`, which all carry its signal. */
+ controller: AbortController;
129
+ private _client;
130
+ constructor(props: ZaiContextProps);
131
/** Resolves the details of `modelId` through the client. */
+ getModel(): Promise<Model>;
132
/** Registers a listener for a context event and returns the context for chaining. */
+ on<K extends keyof ContextEvents>(type: K, listener: (event: ContextEvents[K]) => void): this;
133
/** Removes every listener registered through `on`. */
+ clear(): void;
134
/**
 * Calls the model, retrying up to `maxRetries` times. On each failure the error
 * message is appended to the conversation as a user message before the next
 * attempt; the last error is rethrown once the attempts are exhausted.
 */
+ generateContent<Out = string>(props: GenerateContentProps<Out>): Promise<{
135
+ meta: Meta;
136
+ output: GenerateContentOutput;
137
/** Text of the first choice; the implementation falls back to an empty string when none is present. */
+ text: string | undefined;
138
/** Result of `transform`, or the raw text when no transform was supplied. */
+ extracted: Out;
139
+ }>;
140
/** Milliseconds elapsed since the context was constructed. */
+ get elapsedTime(): number;
141
/** Snapshot of the usage counters gathered so far. */
+ get usage(): Usage;
142
+ }
143
+
144
/** Events a `Response` can be subscribed to, keyed by event name. */
+ type ResponseEvents<TComplete = any> = {
145
/** Carries a usage snapshot. Presumably forwarded from the context's "update" event - confirm in the implementation. */
+ progress: Usage;
146
/** Carries the completed value. */
+ complete: TComplete;
147
/** Carries the failure reason. */
+ error: unknown;
148
+ };
/**
 * Thenable handle returned by Zai operations. Per the signatures below, awaiting
 * it (through `then`) yields the simplified value `S`, while `result()` resolves
 * to the full output `T` together with usage and elapsed time.
 * NOTE(review): the implementation is not part of this declaration; behavioural
 * notes below are inferred from the signatures and should be confirmed.
 */
149
+ declare class Response<T = any, S = T> implements PromiseLike<S> {
150
+ private _promise;
151
+ private _eventEmitter;
152
+ private _context;
153
/** NOTE(review): spelled `_elasped` (probably meant `_elapsed`); the declaration mirrors the implementation's field name. */
+ private _elasped;
154
+ private _simplify;
155
/**
 * @param context context the operation runs in
 * @param promise promise of the operation's full output
 * @param simplify maps the full output `T` to the value `S` exposed through `then`
 */
+ constructor(context: ZaiContext, promise: Promise<T>, simplify: (value: T) => S);
156
+ on<K extends keyof ResponseEvents<T>>(type: K, listener: (event: ResponseEvents<T>[K]) => void): this;
157
+ off<K extends keyof ResponseEvents<T>>(type: K, listener: (event: ResponseEvents<T>[K]) => void): this;
158
+ once<K extends keyof ResponseEvents<T>>(type: K, listener: (event: ResponseEvents<T>[K]) => void): this;
159
/** Presumably aborts the operation when `signal` is aborted - confirm in the implementation. */
+ bindSignal(signal: AbortSignal): this;
160
/** Requests cancellation of the operation, optionally with a reason. */
+ abort(reason?: string | Error): void;
161
+ then<TResult1 = S, TResult2 = never>(onfulfilled?: ((value: S) => TResult1 | PromiseLike<TResult1>) | null, onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null): PromiseLike<TResult1 | TResult2>;
162
+ catch<TResult = never>(onrejected?: ((reason: any) => TResult | PromiseLike<TResult>) | null): PromiseLike<S | TResult>;
163
/** Resolves to the unsimplified output along with the usage counters and the elapsed time. */
+ result(): Promise<{
164
+ output: T;
165
+ usage: Usage;
166
+ elapsed: number;
167
+ }>;
168
+ }
169
+
77
170
  type Options$6 = {
78
171
  /** The maximum number of tokens to generate */
79
172
  length?: number;
@@ -81,7 +174,7 @@ type Options$6 = {
81
174
  declare module '@botpress/zai' {
82
175
  interface Zai {
83
176
  /** Generates a text of the desired length according to the prompt */
84
- text(prompt: string, options?: Options$6): Promise<string>;
177
+ text(prompt: string, options?: Options$6): Response<string>;
85
178
  }
86
179
  }
87
180
 
@@ -99,7 +192,7 @@ type Options$5 = {
99
192
  declare module '@botpress/zai' {
100
193
  interface Zai {
101
194
  /** Rewrites a string to match the prompt */
102
- rewrite(original: string, prompt: string, options?: Options$5): Promise<string>;
195
+ rewrite(original: string, prompt: string, options?: Options$5): Response<string>;
103
196
  }
104
197
  }
105
198
 
@@ -123,7 +216,7 @@ type Options$4 = {
123
216
  declare module '@botpress/zai' {
124
217
  interface Zai {
125
218
  /** Summarizes a text of any length to a summary of the desired length */
126
- summarize(original: string, options?: Options$4): Promise<string>;
219
+ summarize(original: string, options?: Options$4): Response<string>;
127
220
  }
128
221
  }
129
222
 
@@ -140,12 +233,12 @@ type Options$3 = {
140
233
  declare module '@botpress/zai' {
141
234
  interface Zai {
142
235
  /** Checks whether a condition is true or not */
143
- check(input: unknown, condition: string, options?: Options$3): Promise<{
236
+ check(input: unknown, condition: string, options?: Options$3): Response<{
144
237
  /** Whether the condition is true or not */
145
238
  value: boolean;
146
239
  /** The explanation of the decision */
147
240
  explanation: string;
148
- }>;
241
+ }, boolean>;
149
242
  }
150
243
  }
151
244
 
@@ -163,7 +256,7 @@ type Options$2 = {
163
256
  declare module '@botpress/zai' {
164
257
  interface Zai {
165
258
  /** Filters elements of an array against a condition */
166
- filter<T>(input: Array<T>, condition: string, options?: Options$2): Promise<Array<T>>;
259
+ filter<T>(input: Array<T>, condition: string, options?: Options$2): Response<Array<T>>;
167
260
  }
168
261
  }
169
262
 
@@ -172,16 +265,17 @@ type Options$1 = {
172
265
  instructions?: string;
173
266
  /** The maximum number of tokens per chunk */
174
267
  chunkLength?: number;
268
+ /** Whether to strictly follow the schema or not */
269
+ strict?: boolean;
175
270
  };
176
271
  type __Z<T extends any = any> = {
177
272
  _output: T;
178
273
  };
179
274
  type OfType<O, T extends __Z = __Z<O>> = T extends __Z<O> ? T : never;
180
- type AnyObjectOrArray = Record<string, unknown> | Array<unknown>;
181
275
  declare module '@botpress/zai' {
182
276
  interface Zai {
183
277
  /** Extracts one or many elements from an arbitrary input */
184
- extract<S extends OfType<AnyObjectOrArray>>(input: unknown, schema: S, options?: Options$1): Promise<S['_output']>;
278
+ extract<S extends OfType<any>>(input: unknown, schema: S, options?: Options$1): Response<S['_output']>;
185
279
  }
186
280
  }
187
281
 
@@ -212,12 +306,14 @@ type Labels<T extends string> = Record<T, string>;
212
306
  declare module '@botpress/zai' {
213
307
  interface Zai {
214
308
  /** Tags the provided input with a list of predefined labels */
215
- label<T extends string>(input: unknown, labels: Labels<T>, options?: Options<T>): Promise<{
309
+ label<T extends string>(input: unknown, labels: Labels<T>, options?: Options<T>): Response<{
216
310
  [K in T]: {
217
311
  explanation: string;
218
312
  value: boolean;
219
313
  confidence: number;
220
314
  };
315
+ }, {
316
+ [K in T]: boolean;
221
317
  }>;
222
318
  }
223
319
  }
@@ -1,4 +1,7 @@
1
1
  import { z } from "@bpinternal/zui";
2
+ import { ZaiContext } from "../context";
3
+ import { Response } from "../response";
4
+ import { getTokenizer } from "../tokenizer";
2
5
  import { fastHash, stringify, takeUntilTokens } from "../utils";
3
6
  import { Zai } from "../zai";
4
7
  import { PROMPT_INPUT_BUFFER } from "./constants";
@@ -14,12 +17,12 @@ const _Options = z.object({
14
17
  const TRUE = "\u25A0TRUE\u25A0";
15
18
  const FALSE = "\u25A0FALSE\u25A0";
16
19
  const END = "\u25A0END\u25A0";
17
- Zai.prototype.check = async function(input, condition, _options) {
18
- const options = _Options.parse(_options ?? {});
19
- const tokenizer = await this.getTokenizer();
20
- await this.fetchModelDetails();
21
- const PROMPT_COMPONENT = Math.max(this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER, 100);
22
- const taskId = this.taskId;
20
+ const check = async (input, condition, options, ctx) => {
21
+ ctx.controller.signal.throwIfAborted();
22
+ const tokenizer = await getTokenizer();
23
+ const model = await ctx.getModel();
24
+ const PROMPT_COMPONENT = Math.max(model.input.maxTokens - PROMPT_INPUT_BUFFER, 100);
25
+ const taskId = ctx.taskId;
23
26
  const taskType = "zai.check";
24
27
  const PROMPT_TOKENS = {
25
28
  INPUT: Math.floor(0.5 * PROMPT_COMPONENT),
@@ -36,7 +39,7 @@ Zai.prototype.check = async function(input, condition, _options) {
36
39
  condition
37
40
  })
38
41
  );
39
- const examples = taskId ? await this.adapter.getExamples({
42
+ const examples = taskId && ctx.adapter ? await ctx.adapter.getExamples({
40
43
  input: inputAsString,
41
44
  taskType,
42
45
  taskId
@@ -73,10 +76,10 @@ ${input2.trim()}
73
76
  <|end_input|>
74
77
  `.trim();
75
78
  };
76
- const formatOutput = (answer2, justification) => {
79
+ const formatOutput = (answer, justification) => {
77
80
  return `
78
81
  Analysis: ${justification}
79
- Final Answer: ${answer2 ? TRUE : FALSE}
82
+ Final Answer: ${answer ? TRUE : FALSE}
80
83
  ${END}
81
84
  `.trim();
82
85
  };
@@ -103,7 +106,10 @@ ${END}
103
106
  - When in doubt, ground your decision on the examples provided by the experts instead of your own intuition.
104
107
  - When no example is similar to the input, make sure to provide a clear justification for your decision while inferring the decision-making process from the examples provided by the experts.
105
108
  `.trim() : "";
106
- const { output, meta } = await this.callModel({
109
+ const {
110
+ extracted: { finalAnswer, explanation },
111
+ meta
112
+ } = await ctx.generateContent({
107
113
  systemPrompt: `
108
114
  Check if the following condition is true or false for the given input. Before answering, make sure to read the input and the condition carefully.
109
115
  Justify your answer, then answer with either ${TRUE} or ${FALSE} at the very end, then add ${END} to finish the response.
@@ -123,23 +129,25 @@ ${formatInput(inputAsString, condition)}
123
129
  In your "Analysis", please refer to the Expert Examples # to justify your decision.`.trim(),
124
130
  role: "user"
125
131
  }
126
- ]
132
+ ],
133
+ transform: (text) => {
134
+ const hasTrue = text.includes(TRUE);
135
+ const hasFalse = text.includes(FALSE);
136
+ if (!hasTrue && !hasFalse) {
137
+ throw new Error(`The model did not return a valid answer. The response was: ${text}`);
138
+ }
139
+ let finalAnswer2;
140
+ const explanation2 = text.replace(TRUE, "").replace(FALSE, "").replace(END, "").replace("Final Answer:", "").replace("Analysis:", "").trim();
141
+ if (hasTrue && hasFalse) {
142
+ finalAnswer2 = text.lastIndexOf(TRUE) > text.lastIndexOf(FALSE);
143
+ } else {
144
+ finalAnswer2 = hasTrue;
145
+ }
146
+ return { finalAnswer: finalAnswer2, explanation: explanation2.trim() };
147
+ }
127
148
  });
128
- const answer = output.choices[0]?.content;
129
- const hasTrue = answer.includes(TRUE);
130
- const hasFalse = answer.includes(FALSE);
131
- if (!hasTrue && !hasFalse) {
132
- throw new Error(`The model did not return a valid answer. The response was: ${answer}`);
133
- }
134
- let finalAnswer;
135
- const explanation = answer.replace(TRUE, "").replace(FALSE, "").replace(END, "").replace("Final Answer:", "").replace("Analysis:", "").trim();
136
- if (hasTrue && hasFalse) {
137
- finalAnswer = answer.lastIndexOf(TRUE) > answer.lastIndexOf(FALSE);
138
- } else {
139
- finalAnswer = hasTrue;
140
- }
141
- if (taskId) {
142
- await this.adapter.saveExample({
149
+ if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
150
+ await ctx.adapter.saveExample({
143
151
  key: Key,
144
152
  taskType,
145
153
  taskId,
@@ -151,7 +159,7 @@ In your "Analysis", please refer to the Expert Examples # to justify your decisi
151
159
  output: meta.cost.output
152
160
  },
153
161
  latency: meta.latency,
154
- model: this.Model,
162
+ model: ctx.modelId,
155
163
  tokens: {
156
164
  input: meta.tokens.input,
157
165
  output: meta.tokens.output
@@ -166,3 +174,14 @@ In your "Analysis", please refer to the Expert Examples # to justify your decisi
166
174
  explanation: explanation.trim()
167
175
  };
168
176
  };
177
/**
 * Public `check` entry point. Validates the options, creates a dedicated
 * ZaiContext for the "zai.check" task and starts the check inside it.
 * The returned Response wraps the pending result and is constructed with a
 * simplifier that selects the `value` field of that result.
 */
Zai.prototype.check = function(input, condition, _options) {
  const parsedOptions = _Options.parse(_options ?? {});
  const { client, Model: modelId, taskId, adapter } = this;
  const ctx = new ZaiContext({ client, modelId, taskId, taskType: "zai.check", adapter });
  const pending = check(input, condition, parsedOptions, ctx);
  const pickValue = (outcome) => outcome.value;
  return new Response(ctx, pending, pickValue);
};