@botpress/zai 2.0.16 → 2.1.1

This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
@@ -1,6 +1,9 @@
 // eslint-disable consistent-type-definitions
 import { z } from '@bpinternal/zui'
 
+import { ZaiContext } from '../context'
+import { Response } from '../response'
+import { getTokenizer } from '../tokenizer'
 import { fastHash, stringify, takeUntilTokens } from '../utils'
 import { Zai } from '../zai'
 import { PROMPT_INPUT_BUFFER } from './constants'
@@ -31,29 +34,35 @@ const Options = z.object({
 declare module '@botpress/zai' {
   interface Zai {
     /** Rewrites a string according to match the prompt */
-    rewrite(original: string, prompt: string, options?: Options): Promise<string>
+    rewrite(original: string, prompt: string, options?: Options): Response<string>
   }
 }
 
 const START = '■START■'
 const END = '■END■'
 
-Zai.prototype.rewrite = async function (this: Zai, original, prompt, _options) {
+const rewrite = async (
+  original: string,
+  prompt: string,
+  _options: Options | undefined,
+  ctx: ZaiContext
+): Promise<string> => {
+  ctx.controller.signal.throwIfAborted()
   const options = Options.parse(_options ?? {}) as Options
-  const tokenizer = await this.getTokenizer()
-  await this.fetchModelDetails()
+  const tokenizer = await getTokenizer()
+  const model = await ctx.getModel()
 
-  const taskId = this.taskId
+  const taskId = ctx.taskId
   const taskType = 'zai.rewrite'
 
-  const INPUT_COMPONENT_SIZE = Math.max(100, (this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER) / 2)
+  const INPUT_COMPONENT_SIZE = Math.max(100, (model.input.maxTokens - PROMPT_INPUT_BUFFER) / 2)
   prompt = tokenizer.truncate(prompt, INPUT_COMPONENT_SIZE)
 
   const inputSize = tokenizer.count(original) + tokenizer.count(prompt)
-  const maxInputSize = this.ModelDetails.input.maxTokens - tokenizer.count(prompt) - PROMPT_INPUT_BUFFER
+  const maxInputSize = model.input.maxTokens - tokenizer.count(prompt) - PROMPT_INPUT_BUFFER
   if (inputSize > maxInputSize) {
     throw new Error(
-      `The input size is ${inputSize} tokens long, which is more than the maximum of ${maxInputSize} tokens for this model (${this.ModelDetails.name} = ${this.ModelDetails.input.maxTokens} tokens)`
+      `The input size is ${inputSize} tokens long, which is more than the maximum of ${maxInputSize} tokens for this model (${model.name} = ${model.input.maxTokens} tokens)`
     )
   }
 
@@ -98,13 +107,14 @@ ${END}
     { input: '1\n2\n3', output: '3\n2\n1', instructions: 'reverse the order' },
   ]
 
-  const tableExamples = taskId
-    ? await this.adapter.getExamples<string, string>({
-        input: original,
-        taskId,
-        taskType,
-      })
-    : []
+  const tableExamples =
+    taskId && ctx.adapter
+      ? await ctx.adapter.getExamples<string, string>({
+          input: original,
+          taskId,
+          taskType,
+        })
+      : []
 
   const exactMatch = tableExamples.find((x) => x.key === Key)
   if (exactMatch) {
@@ -116,7 +126,7 @@ ${END}
     ...options.examples,
   ]
 
-  const REMAINING_TOKENS = this.ModelDetails.input.maxTokens - tokenizer.count(prompt) - PROMPT_INPUT_BUFFER
+  const REMAINING_TOKENS = model.input.maxTokens - tokenizer.count(prompt) - PROMPT_INPUT_BUFFER
   const examples = takeUntilTokens(
     savedExamples.length ? savedExamples : defaultExamples,
    REMAINING_TOKENS,
@@ -125,7 +135,7 @@ ${END}
     .map(formatExample)
     .flat()
 
-  const { output, meta } = await this.callModel({
+  const { extracted, meta } = await ctx.generateContent({
     systemPrompt: `
Rewrite the text between the ${START} and ${END} tags to match the user prompt.
${instructions.map((x) => `• ${x}`).join('\n')}
@@ -133,9 +143,16 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
     messages: [...examples, { type: 'text', content: format(original, prompt), role: 'user' }],
     maxTokens: options.length,
     stopSequences: [END],
+    transform: (text) => {
+      if (!text.trim().length) {
+        throw new Error('The model did not return a valid rewrite. The response was empty.')
+      }
+
+      return text
+    },
   })
 
-  let result = output.choices[0]?.content as string
+  let result = extracted
 
   if (result.includes(START)) {
     result = result.slice(result.indexOf(START) + START.length)
@@ -145,8 +162,8 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
     result = result.slice(0, result.indexOf(END))
   }
 
-  if (taskId) {
-    await this.adapter.saveExample({
+  if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
+    await ctx.adapter.saveExample({
      key: Key,
      metadata: {
        cost: {
@@ -154,7 +171,7 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
           output: meta.cost.output,
         },
         latency: meta.latency,
-        model: this.Model,
+        model: ctx.modelId,
        tokens: {
          input: meta.tokens.input,
          output: meta.tokens.output,
@@ -170,3 +187,15 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
 
   return result
 }
+
+Zai.prototype.rewrite = function (this: Zai, original: string, prompt: string, _options?: Options): Response<string> {
+  const context = new ZaiContext({
+    client: this.client,
+    modelId: this.Model,
+    taskId: this.taskId,
+    taskType: 'zai.rewrite',
+    adapter: this.adapter,
+  })
+
+  return new Response<string>(context, rewrite(original, prompt, _options, context), (result) => result)
+}
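With this change, rewrite() returns a Response<string> rather than a bare Promise<string>. The Response is still awaitable, but it also exposes progress events and cancellation. A minimal consumption sketch, assuming an already-configured Zai instance named `zai` (that variable and its setup are not part of this diff):

    // Sketch only: `zai` is an assumed, configured Zai instance.
    const res = zai.rewrite('helo wrold', 'fix the spelling')

    // 'progress' events forward the Usage payloads emitted by ZaiContext.
    res.on('progress', (usage) => console.log('usage update:', usage))

    // Response implements PromiseLike<string>, so a plain `await` still works.
    const fixed = await res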
@@ -2,6 +2,10 @@
 import { z } from '@bpinternal/zui'
 
 import { chunk } from 'lodash-es'
+import { ZaiContext } from '../context'
+import { Response } from '../response'
+
+import { getTokenizer } from '../tokenizer'
 import { Zai } from '../zai'
 import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from './constants'
 
@@ -54,31 +58,31 @@ const Options = z.object({
 declare module '@botpress/zai' {
   interface Zai {
     /** Summarizes a text of any length to a summary of the desired length */
-    summarize(original: string, options?: Options): Promise<string>
+    summarize(original: string, options?: Options): Response<string>
   }
 }
 
 const START = '■START■'
 const END = '■END■'
 
-Zai.prototype.summarize = async function (this: Zai, original, _options) {
-  const options = Options.parse(_options ?? {}) as Options
-  const tokenizer = await this.getTokenizer()
-  await this.fetchModelDetails()
+const summarize = async (original: string, options: Options, ctx: ZaiContext): Promise<string> => {
+  ctx.controller.signal.throwIfAborted()
+  const tokenizer = await getTokenizer()
+  const model = await ctx.getModel()
 
-  const INPUT_COMPONENT_SIZE = Math.max(100, (this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER) / 4)
+  const INPUT_COMPONENT_SIZE = Math.max(100, (model.input.maxTokens - PROMPT_INPUT_BUFFER) / 4)
   options.prompt = tokenizer.truncate(options.prompt, INPUT_COMPONENT_SIZE)
   options.format = tokenizer.truncate(options.format, INPUT_COMPONENT_SIZE)
 
-  const maxOutputSize = this.ModelDetails.output.maxTokens - PROMPT_OUTPUT_BUFFER
+  const maxOutputSize = model.output.maxTokens - PROMPT_OUTPUT_BUFFER
   if (options.length > maxOutputSize) {
     throw new Error(
-      `The desired output length is ${maxOutputSize} tokens long, which is more than the maximum of ${this.ModelDetails.output.maxTokens} tokens for this model (${this.ModelDetails.name})`
+      `The desired output length is ${maxOutputSize} tokens long, which is more than the maximum of ${model.output.maxTokens} tokens for this model (${model.name})`
     )
   }
 
   // Ensure the sliding window is not bigger than the model input size
-  options.sliding.window = Math.min(options.sliding.window, this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER)
+  options.sliding.window = Math.min(options.sliding.window, model.input.maxTokens - PROMPT_INPUT_BUFFER)
 
   // Ensure the overlap is not bigger than the window
   // Most extreme case possible (all 3 same size)
@@ -111,9 +115,12 @@ ${newText}
   const chunkSize = Math.ceil(tokens.length / (parts * N))
 
   if (useMergeSort) {
+    // TODO: use pLimit here to not have too many chunks
     const chunks = chunk(tokens, chunkSize).map((x) => x.join(''))
-    const allSummaries = await Promise.all(chunks.map((chunk) => this.summarize(chunk, options)))
-    return this.summarize(allSummaries.join('\n\n============\n\n'), options)
+    const allSummaries = (await Promise.allSettled(chunks.map((chunk) => summarize(chunk, options, ctx))))
+      .filter((x) => x.status === 'fulfilled')
+      .map((x) => x.value)
+    return summarize(allSummaries.join('\n\n============\n\n'), options, ctx)
   }
 
   const summaries: string[] = []
@@ -176,7 +183,7 @@ ${newText}
     }
   }
 
-  const { output } = await this.callModel({
+  let { extracted: result } = await ctx.generateContent({
     systemPrompt: `
You are summarizing a text. The text is split into ${parts} parts, and you are currently working on part ${iteration}.
At every step, you will receive the current summary and a new part of the text. You need to amend the summary to include the new information (if needed).
@@ -191,9 +198,14 @@ ${options.format}
     messages: [{ type: 'text', content: format(currentSummary, slice), role: 'user' }],
     maxTokens: generationLength,
     stopSequences: [END],
-  })
+    transform: (text) => {
+      if (!text.trim().length) {
+        throw new Error('The model did not return a valid summary. The response was empty.')
+      }
 
-  let result = output?.choices[0]?.content as string
+      return text
+    },
+  })
 
   if (result.includes(START)) {
     result = result.slice(result.indexOf(START) + START.length)
@@ -210,3 +222,17 @@ ${options.format}
 
   return currentSummary.trim()
 }
+
+Zai.prototype.summarize = function (this: Zai, original, _options): Response<string> {
+  const options = Options.parse(_options ?? {}) as Options
+
+  const context = new ZaiContext({
+    client: this.client,
+    modelId: this.Model,
+    taskId: this.taskId,
+    taskType: 'summarize',
+    adapter: this.adapter,
+  })
+
+  return new Response<string, string>(context, summarize(original, options, context), (value) => value)
+}
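Each task now runs under a ZaiContext that owns an AbortController, and the throwIfAborted() call at the top of summarize() means recursive merge-sort passes stop once the context is aborted. Callers can tie this to their own AbortSignal through bindSignal(); a sketch, again assuming a configured `zai` instance (the input literal is a placeholder):

    // Sketch only: enforce a time budget on a summarize() call via AbortSignal.
    const longDocument = '...your long input text...'
    const controller = new AbortController()
    const pending = zai.summarize(longDocument).bindSignal(controller.signal)

    // Aborting the signal aborts the underlying ZaiContext controller.
    const timer = setTimeout(() => controller.abort(new Error('summarize timed out')), 30_000)

    try {
      console.log(await pending)
    } finally {
      clearTimeout(timer)
    }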
@@ -2,6 +2,9 @@
 import { z } from '@bpinternal/zui'
 
 import { clamp } from 'lodash-es'
+import { ZaiContext } from '../context'
+import { Response } from '../response'
+import { getTokenizer } from '../tokenizer'
 import { Zai } from '../zai'
 import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from './constants'
 
@@ -17,19 +20,20 @@ const Options = z.object({
 declare module '@botpress/zai' {
   interface Zai {
     /** Generates a text of the desired length according to the prompt */
-    text(prompt: string, options?: Options): Promise<string>
+    text(prompt: string, options?: Options): Response<string>
   }
 }
 
-Zai.prototype.text = async function (this: Zai, prompt, _options) {
+const text = async (prompt: string, _options: Options | undefined, ctx: ZaiContext): Promise<string> => {
+  ctx.controller.signal.throwIfAborted()
   const options = Options.parse(_options ?? {})
-  const tokenizer = await this.getTokenizer()
-  await this.fetchModelDetails()
+  const tokenizer = await getTokenizer()
+  const model = await ctx.getModel()
 
-  prompt = tokenizer.truncate(prompt, Math.max(this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER, 100))
+  prompt = tokenizer.truncate(prompt, Math.max(model.input.maxTokens - PROMPT_INPUT_BUFFER, 100))
 
   if (options.length) {
-    options.length = Math.min(this.ModelDetails.output.maxTokens - PROMPT_OUTPUT_BUFFER, options.length)
+    options.length = Math.min(model.output.maxTokens - PROMPT_OUTPUT_BUFFER, options.length)
   }
 
   const instructions: string[] = []
@@ -55,7 +59,7 @@ Zai.prototype.text = async function (this: Zai, prompt, _options) {
 | 300-500 tokens| A long paragraph (200-300 words) |`.trim()
   }
 
-  const { output } = await this.callModel({
+  const { extracted } = await ctx.generateContent({
     systemPrompt: `
Generate a text that fulfills the user prompt below. Answer directly to the prompt, without any acknowledgements or fluff. Also, make sure the text is standalone and complete.
${instructions.map((x) => `- ${x}`).join('\n')}
@@ -64,6 +68,26 @@ ${chart}
     temperature: 0.7,
     messages: [{ type: 'text', content: prompt, role: 'user' }],
     maxTokens: options.length,
+    transform: (text) => {
+      if (!text.trim().length) {
+        throw new Error('The model did not return a valid summary. The response was empty.')
+      }
+
+      return text
+    },
+  })
+
+  return extracted
+}
+
+Zai.prototype.text = function (this: Zai, prompt: string, _options?: Options): Response<string> {
+  const context = new ZaiContext({
+    client: this.client,
+    modelId: this.Model,
+    taskId: this.taskId,
+    taskType: 'zai.text',
+    adapter: this.adapter,
   })
-  return output?.choices?.[0]?.content! as string
+
+  return new Response<string>(context, text(prompt, _options, context), (result) => result)
 }
@@ -0,0 +1,114 @@
+import { Usage, ZaiContext } from './context'
+import { EventEmitter } from './emitter'
+
+// Event types for the Response class
+export type ResponseEvents<TComplete = any> = {
+  progress: Usage
+  complete: TComplete
+  error: unknown
+}
+
+export class Response<T = any, S = T> implements PromiseLike<S> {
+  private _promise: Promise<T>
+  private _eventEmitter: EventEmitter<ResponseEvents<T>>
+  private _context: ZaiContext
+  private _elasped: number | null = null
+  private _simplify: (value: T) => S
+
+  public constructor(context: ZaiContext, promise: Promise<T>, simplify: (value: T) => S) {
+    this._context = context
+    this._eventEmitter = new EventEmitter<ResponseEvents<T>>()
+    this._simplify = simplify
+    this._promise = promise.then(
+      (value) => {
+        this._elasped ||= this._context.elapsedTime
+        this._eventEmitter.emit('complete', value)
+        this._eventEmitter.clear()
+        this._context.clear()
+        return value
+      },
+      (reason) => {
+        this._elasped ||= this._context.elapsedTime
+        this._eventEmitter.emit('error', reason)
+        this._eventEmitter.clear()
+        this._context.clear()
+        throw reason
+      }
+    )
+
+    this._context.on('update', (usage) => {
+      this._eventEmitter.emit('progress', usage)
+    })
+  }
+
+  // Event emitter methods
+  public on<K extends keyof ResponseEvents<T>>(type: K, listener: (event: ResponseEvents<T>[K]) => void) {
+    this._eventEmitter.on(type, listener)
+    return this
+  }
+
+  public off<K extends keyof ResponseEvents<T>>(type: K, listener: (event: ResponseEvents<T>[K]) => void) {
+    this._eventEmitter.off(type, listener)
+    return this
+  }
+
+  public once<K extends keyof ResponseEvents<T>>(type: K, listener: (event: ResponseEvents<T>[K]) => void) {
+    this._eventEmitter.once(type, listener)
+    return this
+  }
+
+  public bindSignal(signal: AbortSignal): this {
+    if (signal.aborted) {
+      this.abort(signal.reason)
+    }
+
+    const signalAbort = () => {
+      this.abort(signal.reason)
+    }
+
+    signal.addEventListener('abort', () => signalAbort())
+
+    this.once('complete', () => signal.removeEventListener('abort', signalAbort))
+    this.once('error', () => signal.removeEventListener('abort', signalAbort))
+
+    return this
+  }
+
+  public abort(reason?: string | Error) {
+    this._context.controller.abort(reason)
+  }
+
+  public then<TResult1 = S, TResult2 = never>(
+    onfulfilled?: ((value: S) => TResult1 | PromiseLike<TResult1>) | null,
+    onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null
+  ): PromiseLike<TResult1 | TResult2> {
+    return this._promise.then(
+      (value: T) => {
+        const simplified = this._simplify(value)
+        return onfulfilled ? onfulfilled(simplified) : simplified
+      },
+      (reason) => {
+        if (onrejected) {
+          return onrejected(reason)
+        }
+        throw reason
+      }
+    ) as PromiseLike<TResult1 | TResult2>
+  }
+
+  public catch<TResult = never>(
+    onrejected?: ((reason: any) => TResult | PromiseLike<TResult>) | null
+  ): PromiseLike<S | TResult> {
+    return this._promise.catch(onrejected) as PromiseLike<S | TResult>
+  }
+
+  public async result(): Promise<{
+    output: T
+    usage: Usage
+    elapsed: number
+  }> {
+    const output = await this._promise
+    const usage = this._context.usage
+    return { output, usage, elapsed: this._elasped }
+  }
+}
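The two type parameters separate the task's internal value T from the simplified value S that a plain `await` resolves to (via the simplify callback passed to the constructor), while result() returns the raw T together with usage and elapsed time. A sketch of the distinction, where `ctx` and `runTask` are hypothetical placeholders rather than package exports:

    // Sketch only: ambient declarations stand in for a real task.
    declare const ctx: ZaiContext
    declare function runTask(ctx: ZaiContext): Promise<{ text: string; score: number }>

    const response = new Response<{ text: string; score: number }, string>(
      ctx,
      runTask(ctx),            // the task's rich internal value
      (value) => value.text    // simplify: what `await response` resolves to
    )

    const text = await response                         // string
    const { output, usage } = await response.result()   // output keeps the full { text, score } shape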
@@ -0,0 +1,14 @@
+import { getWasmTokenizer, TextTokenizer } from '@bpinternal/thicktoken'
+
+let tokenizer: TextTokenizer | null = null
+
+export async function getTokenizer(): Promise<TextTokenizer> {
+  if (!tokenizer) {
+    while (!getWasmTokenizer) {
+      // there's an issue with wasm, it doesn't load immediately
+      await new Promise((resolve) => setTimeout(resolve, 25))
+    }
+    tokenizer = getWasmTokenizer() as TextTokenizer
+  }
+  return tokenizer
+}
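This module replaces the old per-instance this.getTokenizer() with a lazily initialized singleton: the first caller waits for the wasm export to become available, and every later caller reuses the same TextTokenizer. The count/truncate calls below mirror how the task files above use it:

    // Sketch only: mirrors the tokenizer usage in the task files above.
    import { getTokenizer } from './tokenizer'

    const tokenizer = await getTokenizer() // top-level await (ESM)
    const prompt = 'Summarize the quarterly report in two sentences.'
    const tokens = tokenizer.count(prompt)          // token count, as used for budget checks
    const clipped = tokenizer.truncate(prompt, 100) // cap the prompt at 100 tokens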