@botpress/zai 2.0.15 → 2.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,10 @@
1
1
  // eslint-disable consistent-type-definitions
2
2
  import { z } from '@bpinternal/zui'
3
3
 
4
- import { clamp, chunk } from 'lodash-es'
4
+ import { chunk, clamp } from 'lodash-es'
5
+ import { ZaiContext } from '../context'
6
+ import { Response } from '../response'
7
+ import { getTokenizer } from '../tokenizer'
5
8
  import { fastHash, stringify, takeUntilTokens } from '../utils'
6
9
  import { Zai } from '../zai'
7
10
  import { PROMPT_INPUT_BUFFER } from './constants'
@@ -83,13 +86,16 @@ declare module '@botpress/zai' {
83
86
  input: unknown,
84
87
  labels: Labels<T>,
85
88
  options?: Options<T>
86
- ): Promise<{
87
- [K in T]: {
88
- explanation: string
89
- value: boolean
90
- confidence: number
91
- }
92
- }>
89
+ ): Response<
90
+ {
91
+ [K in T]: {
92
+ explanation: string
93
+ value: boolean
94
+ confidence: number
95
+ }
96
+ },
97
+ { [K in T]: boolean }
98
+ >
93
99
  }
94
100
  }
95
101
 
@@ -124,21 +130,28 @@ const getConfidence = (label: Label) => {
124
130
  }
125
131
  }
126
132
 
127
- Zai.prototype.label = async function <T extends string>(
128
- this: Zai,
133
+ const label = async <T extends string>(
129
134
  input: unknown,
130
135
  _labels: Labels<T>,
131
- _options: Options<T> | undefined
132
- ) {
136
+ _options: Options<T> | undefined,
137
+ ctx: ZaiContext
138
+ ): Promise<{
139
+ [K in T]: {
140
+ explanation: string
141
+ value: boolean
142
+ confidence: number
143
+ }
144
+ }> => {
145
+ ctx.controller.signal.throwIfAborted()
133
146
  const options = _Options.parse(_options ?? {}) as unknown as Options<T>
134
147
  const labels = _Labels.parse(_labels) as Labels<T>
135
- const tokenizer = await this.getTokenizer()
136
- await this.fetchModelDetails()
148
+ const tokenizer = await getTokenizer()
149
+ const model = await ctx.getModel()
137
150
 
138
- const taskId = this.taskId
151
+ const taskId = ctx.taskId
139
152
  const taskType = 'zai.label'
140
153
 
141
- const TOTAL_MAX_TOKENS = clamp(options.chunkLength, 1000, this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER)
154
+ const TOTAL_MAX_TOKENS = clamp(options.chunkLength, 1000, model.input.maxTokens - PROMPT_INPUT_BUFFER)
142
155
  const CHUNK_EXAMPLES_MAX_TOKENS = clamp(Math.floor(TOTAL_MAX_TOKENS * 0.5), 250, 10_000)
143
156
  const CHUNK_INPUT_MAX_TOKENS = clamp(
144
157
  TOTAL_MAX_TOKENS - CHUNK_EXAMPLES_MAX_TOKENS,
@@ -151,7 +164,7 @@ Zai.prototype.label = async function <T extends string>(
151
164
  if (tokenizer.count(inputAsString) > CHUNK_INPUT_MAX_TOKENS) {
152
165
  const tokens = tokenizer.split(inputAsString)
153
166
  const chunks = chunk(tokens, CHUNK_INPUT_MAX_TOKENS).map((x) => x.join(''))
154
- const allLabels = await Promise.all(chunks.map((chunk) => this.label(chunk, _labels)))
167
+ const allLabels = await Promise.all(chunks.map((chunk) => label(chunk, _labels, _options, ctx)))
155
168
 
156
169
  // Merge all the labels together (those who are true will remain true)
157
170
  return allLabels.reduce((acc, x) => {
@@ -202,21 +215,22 @@ Zai.prototype.label = async function <T extends string>(
202
215
  }
203
216
  }
204
217
 
205
- const examples = taskId
206
- ? await this.adapter.getExamples<
207
- string,
208
- {
209
- [K in T]: {
210
- explanation: string
211
- label: Label
218
+ const examples =
219
+ taskId && ctx.adapter
220
+ ? await ctx.adapter.getExamples<
221
+ string,
222
+ {
223
+ [K in T]: {
224
+ explanation: string
225
+ label: Label
226
+ }
212
227
  }
213
- }
214
- >({
215
- input: inputAsString,
216
- taskType,
217
- taskId,
218
- })
219
- : []
228
+ >({
229
+ input: inputAsString,
230
+ taskType,
231
+ taskId,
232
+ })
233
+ : []
220
234
 
221
235
  options.examples.forEach((example) => {
222
236
  examples.push({
@@ -285,7 +299,7 @@ ${END}
285
299
  })
286
300
  .join('\n\n')
287
301
 
288
- const { output, meta } = await this.callModel({
302
+ const { extracted, meta } = await ctx.generateContent({
289
303
  stopSequences: [END],
290
304
  systemPrompt: `
291
305
  You need to tag the input with the following labels based on the question asked:
@@ -336,35 +350,33 @@ The Expert Examples are there to help you make your decision. They have been pro
336
350
  For example, you can say: "According to Expert Example #1, ..."`.trim(),
337
351
  },
338
352
  ],
353
+ transform: (text) =>
354
+ Object.keys(labels).reduce((acc, key) => {
355
+ const match = text.match(new RegExp(`■${key}:【(.+)】:(\\w{2,})■`, 'i'))
356
+ if (match) {
357
+ const explanation = match[1].trim()
358
+ const label = parseLabel(match[2])
359
+ acc[key] = {
360
+ explanation,
361
+ label,
362
+ }
363
+ } else {
364
+ acc[key] = {
365
+ explanation: '',
366
+ label: LABELS.AMBIGUOUS,
367
+ }
368
+ }
369
+ return acc
370
+ }, {}) as {
371
+ [K in T]: {
372
+ explanation: string
373
+ label: Label
374
+ }
375
+ },
339
376
  })
340
377
 
341
- const answer = output.choices[0].content as string
342
-
343
- const final = Object.keys(labels).reduce((acc, key) => {
344
- const match = answer.match(new RegExp(`■${key}:【(.+)】:(\\w{2,})■`, 'i'))
345
- if (match) {
346
- const explanation = match[1].trim()
347
- const label = parseLabel(match[2])
348
- acc[key] = {
349
- explanation,
350
- label,
351
- }
352
- } else {
353
- acc[key] = {
354
- explanation: '',
355
- label: LABELS.AMBIGUOUS,
356
- }
357
- }
358
- return acc
359
- }, {}) as {
360
- [K in T]: {
361
- explanation: string
362
- label: Label
363
- }
364
- }
365
-
366
- if (taskId) {
367
- await this.adapter.saveExample({
378
+ if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
379
+ await ctx.adapter.saveExample({
368
380
  key: Key,
369
381
  taskType,
370
382
  taskId,
@@ -375,16 +387,59 @@ For example, you can say: "According to Expert Example #1, ..."`.trim(),
375
387
  output: meta.cost.output,
376
388
  },
377
389
  latency: meta.latency,
378
- model: this.Model,
390
+ model: ctx.modelId,
379
391
  tokens: {
380
392
  input: meta.tokens.input,
381
393
  output: meta.tokens.output,
382
394
  },
383
395
  },
384
396
  input: inputAsString,
385
- output: final,
397
+ output: extracted,
386
398
  })
387
399
  }
388
400
 
389
- return convertToAnswer(final)
401
+ return convertToAnswer(extracted)
402
+ }
403
+
404
/**
 * Public entry point for the label operation.
 *
 * Creates a fresh `ZaiContext` for this call, starts the `label` operation with it and
 * wraps the pending promise in a `Response`. The `Response` carries the detailed
 * per-label result (`explanation`, `value`, `confidence`); its simplify callback reduces
 * that to a plain `{ [label]: boolean }` map built from each entry's `value`.
 *
 * @param input - Value to label; passed through to `label` untouched.
 * @param labels - Label definitions keyed by label name.
 * @param _options - Optional settings; passed through to `label` untouched.
 * @returns A `Response` over the detailed result, simplified to a boolean per label.
 */
Zai.prototype.label = function <T extends string>(
  this: Zai,
  input: unknown,
  labels: Labels<T>,
  _options?: Options<T>
): Response<
  {
    [K in T]: {
      explanation: string
      value: boolean
      confidence: number
    }
  },
  { [K in T]: boolean }
> {
  type Detailed = {
    [K in T]: {
      explanation: string
      value: boolean
      confidence: number
    }
  }
  type Flags = { [K in T]: boolean }

  const ctx = new ZaiContext({
    client: this.client,
    modelId: this.Model,
    taskId: this.taskId,
    taskType: 'zai.label',
    adapter: this.adapter,
  })

  // Kick off the operation before handing the promise to the Response wrapper.
  const pending = label(input, labels, _options, ctx)

  // Keep only the boolean verdict of every label, in the result's own key order.
  const toFlags = (detailed: Detailed): Flags => {
    const flags = {} as Flags
    for (const name of Object.keys(detailed) as T[]) {
      flags[name] = detailed[name].value
    }
    return flags
  }

  return new Response<Detailed, Flags>(ctx, pending, toFlags)
}
@@ -1,6 +1,9 @@
1
1
  // eslint-disable consistent-type-definitions
2
2
  import { z } from '@bpinternal/zui'
3
3
 
4
+ import { ZaiContext } from '../context'
5
+ import { Response } from '../response'
6
+ import { getTokenizer } from '../tokenizer'
4
7
  import { fastHash, stringify, takeUntilTokens } from '../utils'
5
8
  import { Zai } from '../zai'
6
9
  import { PROMPT_INPUT_BUFFER } from './constants'
@@ -31,29 +34,35 @@ const Options = z.object({
31
34
  declare module '@botpress/zai' {
32
35
  interface Zai {
33
36
  /** Rewrites a string according to match the prompt */
34
- rewrite(original: string, prompt: string, options?: Options): Promise<string>
37
+ rewrite(original: string, prompt: string, options?: Options): Response<string>
35
38
  }
36
39
  }
37
40
 
38
41
  const START = '■START■'
39
42
  const END = '■END■'
40
43
 
41
- Zai.prototype.rewrite = async function (this: Zai, original, prompt, _options) {
44
+ const rewrite = async (
45
+ original: string,
46
+ prompt: string,
47
+ _options: Options | undefined,
48
+ ctx: ZaiContext
49
+ ): Promise<string> => {
50
+ ctx.controller.signal.throwIfAborted()
42
51
  const options = Options.parse(_options ?? {}) as Options
43
- const tokenizer = await this.getTokenizer()
44
- await this.fetchModelDetails()
52
+ const tokenizer = await getTokenizer()
53
+ const model = await ctx.getModel()
45
54
 
46
- const taskId = this.taskId
55
+ const taskId = ctx.taskId
47
56
  const taskType = 'zai.rewrite'
48
57
 
49
- const INPUT_COMPONENT_SIZE = Math.max(100, (this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER) / 2)
58
+ const INPUT_COMPONENT_SIZE = Math.max(100, (model.input.maxTokens - PROMPT_INPUT_BUFFER) / 2)
50
59
  prompt = tokenizer.truncate(prompt, INPUT_COMPONENT_SIZE)
51
60
 
52
61
  const inputSize = tokenizer.count(original) + tokenizer.count(prompt)
53
- const maxInputSize = this.ModelDetails.input.maxTokens - tokenizer.count(prompt) - PROMPT_INPUT_BUFFER
62
+ const maxInputSize = model.input.maxTokens - tokenizer.count(prompt) - PROMPT_INPUT_BUFFER
54
63
  if (inputSize > maxInputSize) {
55
64
  throw new Error(
56
- `The input size is ${inputSize} tokens long, which is more than the maximum of ${maxInputSize} tokens for this model (${this.ModelDetails.name} = ${this.ModelDetails.input.maxTokens} tokens)`
65
+ `The input size is ${inputSize} tokens long, which is more than the maximum of ${maxInputSize} tokens for this model (${model.name} = ${model.input.maxTokens} tokens)`
57
66
  )
58
67
  }
59
68
 
@@ -98,13 +107,14 @@ ${END}
98
107
  { input: '1\n2\n3', output: '3\n2\n1', instructions: 'reverse the order' },
99
108
  ]
100
109
 
101
- const tableExamples = taskId
102
- ? await this.adapter.getExamples<string, string>({
103
- input: original,
104
- taskId,
105
- taskType,
106
- })
107
- : []
110
+ const tableExamples =
111
+ taskId && ctx.adapter
112
+ ? await ctx.adapter.getExamples<string, string>({
113
+ input: original,
114
+ taskId,
115
+ taskType,
116
+ })
117
+ : []
108
118
 
109
119
  const exactMatch = tableExamples.find((x) => x.key === Key)
110
120
  if (exactMatch) {
@@ -116,7 +126,7 @@ ${END}
116
126
  ...options.examples,
117
127
  ]
118
128
 
119
- const REMAINING_TOKENS = this.ModelDetails.input.maxTokens - tokenizer.count(prompt) - PROMPT_INPUT_BUFFER
129
+ const REMAINING_TOKENS = model.input.maxTokens - tokenizer.count(prompt) - PROMPT_INPUT_BUFFER
120
130
  const examples = takeUntilTokens(
121
131
  savedExamples.length ? savedExamples : defaultExamples,
122
132
  REMAINING_TOKENS,
@@ -125,7 +135,7 @@ ${END}
125
135
  .map(formatExample)
126
136
  .flat()
127
137
 
128
- const { output, meta } = await this.callModel({
138
+ const { extracted, meta } = await ctx.generateContent({
129
139
  systemPrompt: `
130
140
  Rewrite the text between the ${START} and ${END} tags to match the user prompt.
131
141
  ${instructions.map((x) => `• ${x}`).join('\n')}
@@ -133,9 +143,16 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
133
143
  messages: [...examples, { type: 'text', content: format(original, prompt), role: 'user' }],
134
144
  maxTokens: options.length,
135
145
  stopSequences: [END],
146
+ transform: (text) => {
147
+ if (!text.trim().length) {
148
+ throw new Error('The model did not return a valid rewrite. The response was empty.')
149
+ }
150
+
151
+ return text
152
+ },
136
153
  })
137
154
 
138
- let result = output.choices[0]?.content as string
155
+ let result = extracted
139
156
 
140
157
  if (result.includes(START)) {
141
158
  result = result.slice(result.indexOf(START) + START.length)
@@ -145,8 +162,8 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
145
162
  result = result.slice(0, result.indexOf(END))
146
163
  }
147
164
 
148
- if (taskId) {
149
- await this.adapter.saveExample({
165
+ if (taskId && ctx.adapter && !ctx.controller.signal.aborted) {
166
+ await ctx.adapter.saveExample({
150
167
  key: Key,
151
168
  metadata: {
152
169
  cost: {
@@ -154,7 +171,7 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
154
171
  output: meta.cost.output,
155
172
  },
156
173
  latency: meta.latency,
157
- model: this.Model,
174
+ model: ctx.modelId,
158
175
  tokens: {
159
176
  input: meta.tokens.input,
160
177
  output: meta.tokens.output,
@@ -170,3 +187,15 @@ ${instructions.map((x) => `• ${x}`).join('\n')}
170
187
 
171
188
  return result
172
189
  }
190
+
191
/**
 * Public entry point for the rewrite operation.
 *
 * Creates a fresh `ZaiContext` for this call, starts the `rewrite` operation with it and
 * wraps the pending promise in a `Response`. The simplify callback is the identity, so
 * awaiting the `Response` yields the same string the operation resolved with.
 *
 * @param original - Text to rewrite; passed through to `rewrite`.
 * @param prompt - Rewrite instruction; passed through to `rewrite`.
 * @param _options - Optional settings; passed through to `rewrite` untouched.
 * @returns A `Response` resolving to the rewritten string.
 */
Zai.prototype.rewrite = function (this: Zai, original: string, prompt: string, _options?: Options): Response<string> {
  const ctx = new ZaiContext({
    client: this.client,
    modelId: this.Model,
    taskId: this.taskId,
    taskType: 'zai.rewrite',
    adapter: this.adapter,
  })

  // Start the operation first, then hand its promise to the Response wrapper.
  const pending = rewrite(original, prompt, _options, ctx)
  const passthrough = (rewritten: string): string => rewritten

  return new Response<string>(ctx, pending, passthrough)
}
@@ -2,6 +2,10 @@
2
2
  import { z } from '@bpinternal/zui'
3
3
 
4
4
  import { chunk } from 'lodash-es'
5
+ import { ZaiContext } from '../context'
6
+ import { Response } from '../response'
7
+
8
+ import { getTokenizer } from '../tokenizer'
5
9
  import { Zai } from '../zai'
6
10
  import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from './constants'
7
11
 
@@ -54,31 +58,31 @@ const Options = z.object({
54
58
  declare module '@botpress/zai' {
55
59
  interface Zai {
56
60
  /** Summarizes a text of any length to a summary of the desired length */
57
- summarize(original: string, options?: Options): Promise<string>
61
+ summarize(original: string, options?: Options): Response<string>
58
62
  }
59
63
  }
60
64
 
61
65
  const START = '■START■'
62
66
  const END = '■END■'
63
67
 
64
- Zai.prototype.summarize = async function (this: Zai, original, _options) {
65
- const options = Options.parse(_options ?? {}) as Options
66
- const tokenizer = await this.getTokenizer()
67
- await this.fetchModelDetails()
68
+ const summarize = async (original: string, options: Options, ctx: ZaiContext): Promise<string> => {
69
+ ctx.controller.signal.throwIfAborted()
70
+ const tokenizer = await getTokenizer()
71
+ const model = await ctx.getModel()
68
72
 
69
- const INPUT_COMPONENT_SIZE = Math.max(100, (this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER) / 4)
73
+ const INPUT_COMPONENT_SIZE = Math.max(100, (model.input.maxTokens - PROMPT_INPUT_BUFFER) / 4)
70
74
  options.prompt = tokenizer.truncate(options.prompt, INPUT_COMPONENT_SIZE)
71
75
  options.format = tokenizer.truncate(options.format, INPUT_COMPONENT_SIZE)
72
76
 
73
- const maxOutputSize = this.ModelDetails.output.maxTokens - PROMPT_OUTPUT_BUFFER
77
+ const maxOutputSize = model.output.maxTokens - PROMPT_OUTPUT_BUFFER
74
78
  if (options.length > maxOutputSize) {
75
79
  throw new Error(
76
- `The desired output length is ${maxOutputSize} tokens long, which is more than the maximum of ${this.ModelDetails.output.maxTokens} tokens for this model (${this.ModelDetails.name})`
80
+ `The desired output length is ${maxOutputSize} tokens long, which is more than the maximum of ${model.output.maxTokens} tokens for this model (${model.name})`
77
81
  )
78
82
  }
79
83
 
80
84
  // Ensure the sliding window is not bigger than the model input size
81
- options.sliding.window = Math.min(options.sliding.window, this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER)
85
+ options.sliding.window = Math.min(options.sliding.window, model.input.maxTokens - PROMPT_INPUT_BUFFER)
82
86
 
83
87
  // Ensure the overlap is not bigger than the window
84
88
  // Most extreme case possible (all 3 same size)
@@ -111,9 +115,12 @@ ${newText}
111
115
  const chunkSize = Math.ceil(tokens.length / (parts * N))
112
116
 
113
117
  if (useMergeSort) {
118
+ // TODO: use pLimit here to not have too many chunks
114
119
  const chunks = chunk(tokens, chunkSize).map((x) => x.join(''))
115
- const allSummaries = await Promise.all(chunks.map((chunk) => this.summarize(chunk, options)))
116
- return this.summarize(allSummaries.join('\n\n============\n\n'), options)
120
+ const allSummaries = (await Promise.allSettled(chunks.map((chunk) => summarize(chunk, options, ctx))))
121
+ .filter((x) => x.status === 'fulfilled')
122
+ .map((x) => x.value)
123
+ return summarize(allSummaries.join('\n\n============\n\n'), options, ctx)
117
124
  }
118
125
 
119
126
  const summaries: string[] = []
@@ -176,7 +183,7 @@ ${newText}
176
183
  }
177
184
  }
178
185
 
179
- const { output } = await this.callModel({
186
+ let { extracted: result } = await ctx.generateContent({
180
187
  systemPrompt: `
181
188
  You are summarizing a text. The text is split into ${parts} parts, and you are currently working on part ${iteration}.
182
189
  At every step, you will receive the current summary and a new part of the text. You need to amend the summary to include the new information (if needed).
@@ -191,9 +198,14 @@ ${options.format}
191
198
  messages: [{ type: 'text', content: format(currentSummary, slice), role: 'user' }],
192
199
  maxTokens: generationLength,
193
200
  stopSequences: [END],
194
- })
201
+ transform: (text) => {
202
+ if (!text.trim().length) {
203
+ throw new Error('The model did not return a valid summary. The response was empty.')
204
+ }
195
205
 
196
- let result = output?.choices[0]?.content as string
206
+ return text
207
+ },
208
+ })
197
209
 
198
210
  if (result.includes(START)) {
199
211
  result = result.slice(result.indexOf(START) + START.length)
@@ -210,3 +222,17 @@ ${options.format}
210
222
 
211
223
  return currentSummary.trim()
212
224
  }
225
+
226
/**
 * Public entry point for the summarize operation.
 *
 * Options are parsed up front, before any context is created, so a parse failure is
 * thrown synchronously from this call rather than surfacing through the returned
 * `Response`. A fresh `ZaiContext` is then created, the `summarize` operation is started
 * with the parsed options, and its promise is wrapped in a `Response` whose simplify
 * callback is the identity.
 *
 * @param original - Text to summarize; passed through to `summarize`.
 * @param _options - Optional raw settings; defaults to `{}` before parsing.
 * @returns A `Response` resolving to the summary string.
 */
Zai.prototype.summarize = function (this: Zai, original, _options): Response<string> {
  const parsed = Options.parse(_options ?? {}) as Options

  const ctx = new ZaiContext({
    client: this.client,
    modelId: this.Model,
    taskId: this.taskId,
    // NOTE(review): the label/rewrite/text wrappers use 'zai.<operation>' here — confirm
    // whether 'summarize' (without the 'zai.' prefix) is intentional.
    taskType: 'summarize',
    adapter: this.adapter,
  })

  // Start the operation first, then hand its promise to the Response wrapper.
  const pending = summarize(original, parsed, ctx)
  const passthrough = (summary: string): string => summary

  return new Response<string, string>(ctx, pending, passthrough)
}
@@ -2,6 +2,9 @@
2
2
  import { z } from '@bpinternal/zui'
3
3
 
4
4
  import { clamp } from 'lodash-es'
5
+ import { ZaiContext } from '../context'
6
+ import { Response } from '../response'
7
+ import { getTokenizer } from '../tokenizer'
5
8
  import { Zai } from '../zai'
6
9
  import { PROMPT_INPUT_BUFFER, PROMPT_OUTPUT_BUFFER } from './constants'
7
10
 
@@ -17,19 +20,20 @@ const Options = z.object({
17
20
  declare module '@botpress/zai' {
18
21
  interface Zai {
19
22
  /** Generates a text of the desired length according to the prompt */
20
- text(prompt: string, options?: Options): Promise<string>
23
+ text(prompt: string, options?: Options): Response<string>
21
24
  }
22
25
  }
23
26
 
24
- Zai.prototype.text = async function (this: Zai, prompt, _options) {
27
+ const text = async (prompt: string, _options: Options | undefined, ctx: ZaiContext): Promise<string> => {
28
+ ctx.controller.signal.throwIfAborted()
25
29
  const options = Options.parse(_options ?? {})
26
- const tokenizer = await this.getTokenizer()
27
- await this.fetchModelDetails()
30
+ const tokenizer = await getTokenizer()
31
+ const model = await ctx.getModel()
28
32
 
29
- prompt = tokenizer.truncate(prompt, Math.max(this.ModelDetails.input.maxTokens - PROMPT_INPUT_BUFFER, 100))
33
+ prompt = tokenizer.truncate(prompt, Math.max(model.input.maxTokens - PROMPT_INPUT_BUFFER, 100))
30
34
 
31
35
  if (options.length) {
32
- options.length = Math.min(this.ModelDetails.output.maxTokens - PROMPT_OUTPUT_BUFFER, options.length)
36
+ options.length = Math.min(model.output.maxTokens - PROMPT_OUTPUT_BUFFER, options.length)
33
37
  }
34
38
 
35
39
  const instructions: string[] = []
@@ -55,7 +59,7 @@ Zai.prototype.text = async function (this: Zai, prompt, _options) {
55
59
  | 300-500 tokens| A long paragraph (200-300 words) |`.trim()
56
60
  }
57
61
 
58
- const { output } = await this.callModel({
62
+ const { extracted } = await ctx.generateContent({
59
63
  systemPrompt: `
60
64
  Generate a text that fulfills the user prompt below. Answer directly to the prompt, without any acknowledgements or fluff. Also, make sure the text is standalone and complete.
61
65
  ${instructions.map((x) => `- ${x}`).join('\n')}
@@ -64,6 +68,26 @@ ${chart}
64
68
  temperature: 0.7,
65
69
  messages: [{ type: 'text', content: prompt, role: 'user' }],
66
70
  maxTokens: options.length,
71
+ transform: (text) => {
72
+ if (!text.trim().length) {
73
+ throw new Error('The model did not return a valid summary. The response was empty.')
74
+ }
75
+
76
+ return text
77
+ },
78
+ })
79
+
80
+ return extracted
81
+ }
82
+
83
+ Zai.prototype.text = function (this: Zai, prompt: string, _options?: Options): Response<string> {
84
+ const context = new ZaiContext({
85
+ client: this.client,
86
+ modelId: this.Model,
87
+ taskId: this.taskId,
88
+ taskType: 'zai.text',
89
+ adapter: this.adapter,
67
90
  })
68
- return output?.choices?.[0]?.content! as string
91
+
92
+ return new Response<string>(context, text(prompt, _options, context), (result) => result)
69
93
  }
@@ -0,0 +1,114 @@
1
+ import { Usage, ZaiContext } from './context'
2
+ import { EventEmitter } from './emitter'
3
+
4
/**
 * Map of event name to payload type for the events a `Response` can emit.
 *
 * @typeParam TComplete - Payload type of the `complete` event.
 */
export type ResponseEvents<TComplete = any> = {
  /** Usage snapshot; `Response` forwards it from the context's `update` event. */
  progress: Usage
  /** Value the wrapped promise resolved with. */
  complete: TComplete
  /** Reason the wrapped promise rejected with. */
  error: unknown
}
10
+
11
/**
 * Awaitable handle around a running Zai operation.
 *
 * A `Response` wraps the operation's promise together with the `ZaiContext` it runs in:
 * - awaiting it (`then` / `catch`) yields the *simplified* value `S`, produced by the
 *   `simplify` callback given to the constructor;
 * - `result()` yields the raw value `T` together with usage and elapsed time;
 * - `on` / `once` / `off` subscribe to `progress`, `complete` and `error` events;
 * - `abort` / `bindSignal` abort the context's `AbortController`.
 *
 * @typeParam T - Raw value the operation resolves with.
 * @typeParam S - Simplified value exposed to `await`; defaults to `T`.
 */
export class Response<T = any, S = T> implements PromiseLike<S> {
  private _promise: Promise<T>
  private _eventEmitter: EventEmitter<ResponseEvents<T>>
  private _context: ZaiContext
  /** Elapsed time read from the context when the operation settled; `null` until then. */
  private _elapsed: number | null = null
  private _simplify: (value: T) => S

  /**
   * @param context - Context the operation runs in; provides usage, elapsed time and the abort controller.
   * @param promise - The already-started operation.
   * @param simplify - Maps the raw result to the value exposed through `then` / `catch`.
   */
  public constructor(context: ZaiContext, promise: Promise<T>, simplify: (value: T) => S) {
    this._context = context
    this._eventEmitter = new EventEmitter<ResponseEvents<T>>()
    this._simplify = simplify

    // On settlement: record the elapsed time, notify listeners, then clear both the
    // emitter and the context. The outcome (value or reason) is passed through unchanged.
    this._promise = promise.then(
      (value) => {
        this._elapsed ??= this._context.elapsedTime
        this._eventEmitter.emit('complete', value)
        this._eventEmitter.clear()
        this._context.clear()
        return value
      },
      (reason) => {
        this._elapsed ??= this._context.elapsedTime
        this._eventEmitter.emit('error', reason)
        this._eventEmitter.clear()
        this._context.clear()
        throw reason
      }
    )

    // Re-emit the context's usage updates as `progress` events.
    this._context.on('update', (usage) => {
      this._eventEmitter.emit('progress', usage)
    })
  }

  /** Subscribes `listener` to events of the given type. Returns `this` for chaining. */
  public on<K extends keyof ResponseEvents<T>>(type: K, listener: (event: ResponseEvents<T>[K]) => void) {
    this._eventEmitter.on(type, listener)
    return this
  }

  /** Unsubscribes `listener` from events of the given type. Returns `this` for chaining. */
  public off<K extends keyof ResponseEvents<T>>(type: K, listener: (event: ResponseEvents<T>[K]) => void) {
    this._eventEmitter.off(type, listener)
    return this
  }

  /** Subscribes `listener` through the emitter's `once`. Returns `this` for chaining. */
  public once<K extends keyof ResponseEvents<T>>(type: K, listener: (event: ResponseEvents<T>[K]) => void) {
    this._eventEmitter.once(type, listener)
    return this
  }

  /**
   * Ties this response to an external `AbortSignal`: when the signal aborts, the
   * response is aborted with the signal's reason.
   *
   * If the signal is already aborted, the response is aborted immediately and no
   * listener is attached. Otherwise the listener is detached again once the response
   * emits `complete` or `error`, so a long-lived signal does not retain this response.
   *
   * @param signal - External signal to follow.
   * @returns `this` for chaining.
   */
  public bindSignal(signal: AbortSignal): this {
    if (signal.aborted) {
      this.abort(signal.reason)
      return this
    }

    const signalAbort = () => {
      this.abort(signal.reason)
    }
    const unbind = () => signal.removeEventListener('abort', signalAbort)

    // Register the very same function reference that `unbind` removes; an inline
    // wrapper here would make the later removeEventListener call a no-op.
    signal.addEventListener('abort', signalAbort, { once: true })

    this.once('complete', unbind)
    this.once('error', unbind)

    return this
  }

  /**
   * Aborts the context's `AbortController`.
   *
   * @param reason - Optional abort reason forwarded to the controller.
   */
  public abort(reason?: string | Error) {
    this._context.controller.abort(reason)
  }

  /**
   * Promise-like `then`. The fulfilled value is passed through `simplify` before
   * `onfulfilled` sees it; rejections are forwarded to `onrejected`, or rethrown when
   * no handler is given.
   */
  public then<TResult1 = S, TResult2 = never>(
    onfulfilled?: ((value: S) => TResult1 | PromiseLike<TResult1>) | null,
    onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null
  ): PromiseLike<TResult1 | TResult2> {
    return this._promise.then(
      (value: T) => {
        const simplified = this._simplify(value)
        return onfulfilled ? onfulfilled(simplified) : simplified
      },
      (reason) => {
        if (onrejected) {
          return onrejected(reason)
        }
        throw reason
      }
    ) as PromiseLike<TResult1 | TResult2>
  }

  /**
   * Promise-like `catch`. Delegates to `then` so that, when the operation succeeds,
   * the resolved value is the simplified `S` — matching the declared return type.
   */
  public catch<TResult = never>(
    onrejected?: ((reason: any) => TResult | PromiseLike<TResult>) | null
  ): PromiseLike<S | TResult> {
    return this.then(undefined, onrejected)
  }

  /**
   * Waits for the operation and returns the raw (un-simplified) output together with
   * the context's usage and the elapsed time recorded at settlement.
   *
   * Rejects with the operation's error if it failed.
   */
  public async result(): Promise<{
    output: T
    usage: Usage
    elapsed: number
  }> {
    const output = await this._promise
    const usage = this._context.usage
    // `_elapsed` is assigned in the settlement handler that runs before this point;
    // the fallback only exists to keep the declared `number` type honest.
    return { output, usage, elapsed: this._elapsed ?? 0 }
  }
}