@botpress/zai 2.5.17 → 2.6.0

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
package/dist/context.js CHANGED
@@ -1,4 +1,5 @@
1
1
  import { EventEmitter } from "./emitter";
2
+ import { fastHash } from "./utils";
2
3
  export class ZaiContext {
3
4
  _startedAt = Date.now();
4
5
  _inputCost = 0;
@@ -15,8 +16,10 @@ export class ZaiContext {
15
16
  adapter;
16
17
  source;
17
18
  _eventEmitter;
19
+ _memoizer;
18
20
  controller = new AbortController();
19
21
  _client;
22
+ static _noopMemoizer = { run: (_id, fn) => fn() };
20
23
  constructor(props) {
21
24
  this._client = props.client.clone();
22
25
  this.taskId = props.taskId;
@@ -24,6 +27,7 @@ export class ZaiContext {
24
27
  this.adapter = props.adapter;
25
28
  this.source = props.source;
26
29
  this.taskType = props.taskType;
30
+ this._memoizer = props.memoizer ?? ZaiContext._noopMemoizer;
27
31
  this._eventEmitter = new EventEmitter();
28
32
  this._client.on("request", () => {
29
33
  this._totalRequests++;
@@ -57,6 +61,16 @@ export class ZaiContext {
57
61
  this._eventEmitter.clear();
58
62
  }
59
63
  async generateContent(props) {
64
+ const memoKey = `zai:memo:${this.taskType}:${this.taskId || "default"}:${fastHash(
65
+ JSON.stringify({
66
+ s: props.systemPrompt,
67
+ m: props.messages?.map((m) => "content" in m ? m.content : ""),
68
+ st: props.stopSequences
69
+ })
70
+ )}`;
71
+ return this._memoizer.run(memoKey, () => this._generateContentInner(props));
72
+ }
73
+ async _generateContentInner(props) {
60
74
  const maxRetries = Math.max(props.maxRetries ?? 3, 0);
61
75
  const transform = props.transform;
62
76
  let lastError = null;
package/dist/index.d.ts CHANGED
@@ -41,6 +41,16 @@ declare abstract class Adapter {
41
41
  abstract saveExample<TInput, TOutput>(props: SaveExampleProps<TInput, TOutput>): Promise<void>;
42
42
  }
43
43
 
44
+ /**
45
+ * A memoizer that caches the result of async operations by a unique key.
46
+ *
47
+ * When used with the Botpress ADK workflow `step` function, this enables
48
+ * Zai operations to resume where they left off if a workflow is interrupted.
49
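+ * Zai passes ids of the form `zai:memo:<taskType>:<taskId>:<hash>`; identical requests produce the same id.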
50
+ */
51
+ type Memoizer = {
52
+ run: <T>(id: string, fn: () => Promise<T>) => Promise<T>;
53
+ };
44
54
  /**
45
55
  * Active learning configuration for improving AI operations over time.
46
56
  *
@@ -98,6 +108,16 @@ type ZaiConfig = {
98
108
  activeLearning?: ActiveLearning;
99
109
  /** Namespace for organizing tasks (default: 'zai') */
100
110
  namespace?: string;
111
+ /**
112
+ * Memoizer (or factory returning one) for caching cognitive call results.
113
+ *
114
+ * When provided, all LLM calls are wrapped in the memoizer, allowing results
115
+ * to be cached and replayed. This is useful for resuming workflow runs where
116
+ * Zai operations have already completed their cognitive calls.
117
+ *
118
+ * If a factory function is provided, it is called once per Zai operation invocation.
119
+ */
120
+ memoize?: Memoizer | (() => Memoizer);
101
121
  };
102
122
  /**
103
123
  * Zai - A type-safe LLM utility library for production-ready AI operations.
@@ -171,6 +191,7 @@ declare class Zai {
171
191
  protected namespace: string;
172
192
  protected adapter: Adapter;
173
193
  protected activeLearning: ActiveLearning;
194
+ protected _memoize?: Memoizer | (() => Memoizer);
174
195
  /**
175
196
  * Creates a new Zai instance with the specified configuration.
176
197
  *
@@ -195,6 +216,8 @@ declare class Zai {
195
216
  constructor(config: ZaiConfig);
196
217
  /** @internal */
197
218
  protected callModel(props: Parameters<Cognitive['generateContent']>[0]): ReturnType<Cognitive['generateContent']>;
219
+ /** @internal */
220
+ protected _resolveMemoizer(): Memoizer | undefined;
198
221
  protected getTokenizer(): Promise<TextTokenizer>;
199
222
  protected fetchModelDetails(): Promise<void>;
200
223
  protected get taskId(): string;
@@ -299,6 +322,7 @@ type ZaiContextProps = {
299
322
  modelId: string;
300
323
  adapter?: Adapter;
301
324
  source?: GenerateContentInput['meta'];
325
+ memoizer?: Memoizer;
302
326
  };
303
327
  /**
304
328
  * Usage statistics tracking tokens, cost, and request metrics for an operation.
@@ -370,8 +394,10 @@ declare class ZaiContext {
370
394
  adapter?: Adapter;
371
395
  source?: GenerateContentInput['meta'];
372
396
  private _eventEmitter;
397
+ private _memoizer;
373
398
  controller: AbortController;
374
399
  private _client;
400
+ private static _noopMemoizer;
375
401
  constructor(props: ZaiContextProps);
376
402
  getModel(): Promise<Model>;
377
403
  on<K extends keyof ContextEvents>(type: K, listener: (event: ContextEvents[K]) => void): this;
@@ -382,6 +408,7 @@ declare class ZaiContext {
382
408
  text: string | undefined;
383
409
  extracted: Out;
384
410
  }>;
411
+ private _generateContentInner;
385
412
  get elapsedTime(): number;
386
413
  get usage(): Usage;
387
414
  }
@@ -1390,6 +1417,8 @@ type Options$4 = {
1390
1417
  tokensPerElement?: number;
1391
1418
  chunkLength?: number;
1392
1419
  initialGroups?: Array<InitialGroup>;
1420
+ maxGroups?: number;
1421
+ minElements?: number;
1393
1422
  };
1394
1423
  declare module '@botpress/zai' {
1395
1424
  interface Zai {
@@ -1402,6 +1431,8 @@ declare module '@botpress/zai' {
1402
1431
  *
1403
1432
  * @param input - Array of items to group
1404
1433
  * @param options - Configuration for grouping behavior, instructions, and initial categories
1434
+ * @param options.maxGroups - Maximum number of groups allowed (minimum 2). When set, groups are merged after grouping until within the limit; note that a `minElements` redistribution runs afterwards and can add groups back.
1435
+ * @param options.minElements - Minimum elements per group (minimum 1). Groups below this threshold have their elements redistributed via AI.
1405
1436
  * @returns Response with groups array (simplified to Record<groupLabel, items[]>)
1406
1437
  *
1407
1438
  * @example Automatic grouping
@@ -1543,6 +1574,18 @@ declare module '@botpress/zai' {
1543
1574
  * })
1544
1575
  * ```
1576
+ *
1577
+ * @example Limiting number of groups
1578
+ * ```typescript
1579
+ * const items = ['apple', 'banana', 'carrot', 'chicken', 'rice', 'bread', 'salmon', 'milk']
1580
+ *
1581
+ * const groups = await zai.group(items, {
1582
+ * instructions: 'Group by food type',
1583
+ * maxGroups: 3 // At most 3 groups — smallest groups get merged if exceeded
1584
+ * })
1585
+ * // At most 3 groups in the result, unless `minElements` is also set: its redistribution
1586
+ * // runs after the merge and may create new groups
1587
+ * ```
1545
1588
  */
1546
1589
  group<T>(input: Array<T>, options?: Options$4): Response<Array<Group<T>>, Record<string, T[]>>;
1547
1590
  }
1548
1591
  }
@@ -2127,4 +2170,4 @@ declare module '@botpress/zai' {
2127
2170
  }
2128
2171
  }
2129
2172
 
2130
- export { Zai };
2173
+ export { type Memoizer, Zai };
@@ -373,7 +373,8 @@ Zai.prototype.answer = function(documents, question, _options) {
373
373
  modelId: this.Model,
374
374
  taskId: this.taskId,
375
375
  taskType: "zai.answer",
376
- adapter: this.adapter
376
+ adapter: this.adapter,
377
+ memoizer: this._resolveMemoizer()
377
378
  });
378
379
  return new Response(
379
380
  context,
@@ -181,7 +181,8 @@ Zai.prototype.check = function(input, condition, _options) {
181
181
  modelId: this.Model,
182
182
  taskId: this.taskId,
183
183
  taskType: "zai.check",
184
- adapter: this.adapter
184
+ adapter: this.adapter,
185
+ memoizer: this._resolveMemoizer()
185
186
  });
186
187
  return new Response(context, check(input, condition, options, context), (result) => result.value);
187
188
  };
@@ -313,7 +313,8 @@ Zai.prototype.extract = function(input, schema, _options) {
313
313
  modelId: this.Model,
314
314
  taskId: this.taskId,
315
315
  taskType: "zai.extract",
316
- adapter: this.adapter
316
+ adapter: this.adapter,
317
+ memoizer: this._resolveMemoizer()
317
318
  });
318
319
  return new Response(context, extract(input, schema, _options, context), (result) => result);
319
320
  };
@@ -202,7 +202,8 @@ Zai.prototype.filter = function(input, condition, _options) {
202
202
  modelId: this.Model,
203
203
  taskId: this.taskId,
204
204
  taskType: "zai.filter",
205
- adapter: this.adapter
205
+ adapter: this.adapter,
206
+ memoizer: this._resolveMemoizer()
206
207
  });
207
208
  return new Response(context, filter(input, condition, _options, context), (result) => result);
208
209
  };
@@ -16,7 +16,9 @@ const _Options = z.object({
16
16
  instructions: z.string().optional(),
17
17
  tokensPerElement: z.number().min(1).max(1e5).optional().default(250),
18
18
  chunkLength: z.number().min(100).max(1e5).optional().default(16e3),
19
- initialGroups: z.array(_InitialGroup).optional().default([])
19
+ initialGroups: z.array(_InitialGroup).optional().default([]),
20
+ maxGroups: z.number().min(2).optional(),
21
+ minElements: z.number().min(1).optional()
20
22
  });
21
23
  const END = "\u25A0END\u25A0";
22
24
  const normalizeLabel = (label) => {
@@ -301,6 +303,191 @@ ${END}`.trim();
301
303
  groupElements.get(finalGroupId).add(elementIndex);
302
304
  }
303
305
  }
306
+ if (options.maxGroups !== void 0) {
307
+ const nonEmptyGroupIds = () => Array.from(groupElements.entries()).filter(([, s]) => s.size > 0).map(([id]) => id);
308
+ let currentIds = nonEmptyGroupIds();
309
+ if (currentIds.length > options.maxGroups) {
310
+ const groupSummaries = currentIds.map((gid, idx) => {
311
+ const info = groups.get(gid);
312
+ const elemIndices = Array.from(groupElements.get(gid));
313
+ const sampleElements = elemIndices.slice(0, 3).map((i) => tokenizer.truncate(elements[i].stringified, 60)).join(", ");
314
+ return `\u25A0${idx}:${info.label} (${elemIndices.length} elements, e.g. ${sampleElements})\u25A0`;
315
+ });
316
+ const mergeSystemPrompt = `You are consolidating groups into fewer, broader categories.
317
+
318
+ ${options.instructions ? `**Original instructions:** ${options.instructions}
319
+ ` : ""}
320
+ **Task:** Merge ${currentIds.length} groups down to at most ${options.maxGroups} groups.
321
+ Combine the most semantically related groups together. Give each merged group a new descriptive label.
322
+
323
+ **Output Format:**
324
+ For each input group (\u25A00 to \u25A0${currentIds.length - 1}), output which target label it maps to:
325
+ \u25A00:Merged Label\u25A0
326
+ \u25A01:Merged Label\u25A0
327
+ ${END}
328
+
329
+ Use the EXACT SAME label for groups that should be merged together.`.trim();
330
+ const mergeUserPrompt = `**Current groups:**
331
+ ${groupSummaries.join("\n")}
332
+
333
+ Merge into at most ${options.maxGroups} groups.
334
+ ${END}`.trim();
335
+ const { extracted: mergeAssignments } = await ctx.generateContent({
336
+ systemPrompt: mergeSystemPrompt,
337
+ stopSequences: [END],
338
+ messages: [{ type: "text", role: "user", content: mergeUserPrompt }],
339
+ transform: (text) => {
340
+ const assignments = [];
341
+ const regex = /■(\d+):([^■]+)■/g;
342
+ let match;
343
+ while ((match = regex.exec(text)) !== null) {
344
+ const idx = parseInt(match[1] ?? "", 10);
345
+ if (isNaN(idx) || idx < 0 || idx >= currentIds.length) continue;
346
+ const label = (match[2] ?? "").trim();
347
+ if (!label) continue;
348
+ assignments.push({ sourceIdx: idx, label: label.slice(0, 250) });
349
+ }
350
+ return assignments;
351
+ }
352
+ });
353
+ const mergeMap = /* @__PURE__ */ new Map();
354
+ for (const { sourceIdx, label } of mergeAssignments) {
355
+ const sourceGid = currentIds[sourceIdx];
356
+ if (!sourceGid) continue;
357
+ const normalized = normalizeLabel(label);
358
+ if (!mergeMap.has(normalized)) {
359
+ mergeMap.set(normalized, { label, sourceGroupIds: [] });
360
+ }
361
+ mergeMap.get(normalized).sourceGroupIds.push(sourceGid);
362
+ }
363
+ for (const [, { label, sourceGroupIds }] of mergeMap) {
364
+ if (sourceGroupIds.length <= 1) continue;
365
+ const targetGid = sourceGroupIds[0];
366
+ const targetSet = groupElements.get(targetGid);
367
+ const targetInfo = groups.get(targetGid);
368
+ targetInfo.label = label;
369
+ targetInfo.normalizedLabel = normalizeLabel(label);
370
+ for (let i = 1; i < sourceGroupIds.length; i++) {
371
+ const sourceGid = sourceGroupIds[i];
372
+ const sourceSet = groupElements.get(sourceGid);
373
+ sourceSet.forEach((elemIdx) => targetSet.add(elemIdx));
374
+ sourceSet.clear();
375
+ }
376
+ }
377
+ currentIds = nonEmptyGroupIds();
378
+ while (currentIds.length > options.maxGroups) {
379
+ currentIds.sort((a, b) => groupElements.get(a).size - groupElements.get(b).size);
380
+ const sourceSet = groupElements.get(currentIds[0]);
381
+ const targetSet = groupElements.get(currentIds[1]);
382
+ for (const elemIdx of sourceSet) {
383
+ targetSet.add(elemIdx);
384
+ }
385
+ sourceSet.clear();
386
+ currentIds = nonEmptyGroupIds();
387
+ }
388
+ }
389
+ }
390
+ if (options.minElements !== void 0 && options.minElements > 1) {
391
+ const getNonEmptyGroupIds = () => Array.from(groupElements.entries()).filter(([, s]) => s.size > 0).map(([id]) => id);
392
+ const orphanIndices = [];
393
+ for (const gid of getNonEmptyGroupIds()) {
394
+ const elemSet = groupElements.get(gid);
395
+ if (elemSet.size > 0 && elemSet.size < options.minElements) {
396
+ for (const idx of elemSet) {
397
+ orphanIndices.push(idx);
398
+ }
399
+ elemSet.clear();
400
+ }
401
+ }
402
+ if (orphanIndices.length > 0) {
403
+ const validGroupIds = getNonEmptyGroupIds();
404
+ const orphanChunks = [];
405
+ let currentOrphanChunk = [];
406
+ let currentOrphanTokens = 0;
407
+ for (const elemIdx of orphanIndices) {
408
+ const elem = elements[elemIdx];
409
+ const truncated = tokenizer.truncate(elem.stringified, options.tokensPerElement);
410
+ const elemTokens = tokenizer.count(truncated);
411
+ if ((currentOrphanTokens + elemTokens > TOKENS_FOR_ELEMENTS_MAX || currentOrphanChunk.length >= MAX_ELEMENTS_PER_CHUNK) && currentOrphanChunk.length > 0) {
412
+ orphanChunks.push(currentOrphanChunk);
413
+ currentOrphanChunk = [];
414
+ currentOrphanTokens = 0;
415
+ }
416
+ currentOrphanChunk.push(elemIdx);
417
+ currentOrphanTokens += elemTokens;
418
+ }
419
+ if (currentOrphanChunk.length > 0) {
420
+ orphanChunks.push(currentOrphanChunk);
421
+ }
422
+ const orphanResults = await Promise.all(
423
+ orphanChunks.map(
424
+ (chunk) => elementLimit(async () => {
425
+ const groupChunksForOrphans = validGroupIds.length > 0 ? getGroupChunks() : [[]];
426
+ const allAssignments = await Promise.all(
427
+ groupChunksForOrphans.filter((gc) => gc.length === 0 || gc.some((gid) => validGroupIds.includes(gid))).map((groupChunk) => {
428
+ const filteredGroupChunk = groupChunk.filter((gid) => validGroupIds.includes(gid));
429
+ return groupLimit(() => processChunk(chunk, filteredGroupChunk));
430
+ })
431
+ );
432
+ return allAssignments.flat();
433
+ })
434
+ )
435
+ );
436
+ const flatAssignments = orphanResults.flat();
437
+ for (const { elementIndex, label } of flatAssignments) {
438
+ const normalized = normalizeLabel(label);
439
+ let groupId = labelToGroupId.get(normalized);
440
+ if (!groupId) {
441
+ groupId = `group_${groupIdCounter++}`;
442
+ groups.set(groupId, { id: groupId, label, normalizedLabel: normalized });
443
+ groupElements.set(groupId, /* @__PURE__ */ new Set());
444
+ labelToGroupId.set(normalized, groupId);
445
+ }
446
+ groupElements.get(groupId).add(elementIndex);
447
+ }
448
+ const isAssigned = (idx) => {
449
+ for (const [, elemSet] of groupElements) {
450
+ if (elemSet.has(idx)) return true;
451
+ }
452
+ return false;
453
+ };
454
+ const unassigned = orphanIndices.filter((idx) => !isAssigned(idx));
455
+ const placeIntoLargest = (indices) => {
456
+ const allNonEmpty = getNonEmptyGroupIds();
457
+ if (allNonEmpty.length === 0) return;
458
+ const largestGid = allNonEmpty.reduce(
459
+ (a, b) => groupElements.get(a).size >= groupElements.get(b).size ? a : b
460
+ );
461
+ for (const idx of indices) {
462
+ groupElements.get(largestGid).add(idx);
463
+ }
464
+ };
465
+ if (unassigned.length > 0) {
466
+ placeIntoLargest(unassigned);
467
+ }
468
+ const mergeUndersizedGroups = () => {
469
+ const allNonEmpty = getNonEmptyGroupIds();
470
+ if (allNonEmpty.length <= 1) return false;
471
+ const largestGid = allNonEmpty.reduce(
472
+ (a, b) => groupElements.get(a).size >= groupElements.get(b).size ? a : b
473
+ );
474
+ const targetSet = groupElements.get(largestGid);
475
+ let merged = false;
476
+ for (const gid of allNonEmpty) {
477
+ if (gid === largestGid) continue;
478
+ const elemSet = groupElements.get(gid);
479
+ if (elemSet.size > 0 && elemSet.size < options.minElements) {
480
+ elemSet.forEach((idx) => targetSet.add(idx));
481
+ elemSet.clear();
482
+ merged = true;
483
+ }
484
+ }
485
+ return merged;
486
+ };
487
+ while (mergeUndersizedGroups()) {
488
+ }
489
+ }
490
+ }
304
491
  const result = [];
305
492
  for (const [groupId, elementIndices] of groupElements.entries()) {
306
493
  if (elementIndices.size > 0) {
@@ -354,7 +541,8 @@ Zai.prototype.group = function(input, _options) {
354
541
  modelId: this.Model,
355
542
  taskId: this.taskId,
356
543
  taskType: "zai.group",
357
- adapter: this.adapter
544
+ adapter: this.adapter,
545
+ memoizer: this._resolveMemoizer()
358
546
  });
359
547
  return new Response(context, group(input, _options, context), (result) => {
360
548
  const merged = {};
@@ -276,7 +276,8 @@ Zai.prototype.label = function(input, labels, _options) {
276
276
  modelId: this.Model,
277
277
  taskId: this.taskId,
278
278
  taskType: "zai.label",
279
- adapter: this.adapter
279
+ adapter: this.adapter,
280
+ memoizer: this._resolveMemoizer()
280
281
  });
281
282
  return new Response(
282
283
  context,
@@ -392,7 +392,8 @@ Zai.prototype.patch = function(files, instructions, _options) {
392
392
  modelId: this.Model,
393
393
  taskId: this.taskId,
394
394
  taskType: "zai.patch",
395
- adapter: this.adapter
395
+ adapter: this.adapter,
396
+ memoizer: this._resolveMemoizer()
396
397
  });
397
398
  return new Response(context, patch(files, instructions, _options, context), (result) => result);
398
399
  };
@@ -335,7 +335,8 @@ Zai.prototype.rate = function(input, instructions, _options) {
335
335
  modelId: this.Model,
336
336
  taskId: this.taskId,
337
337
  taskType: "zai.rate",
338
- adapter: this.adapter
338
+ adapter: this.adapter,
339
+ memoizer: this._resolveMemoizer()
339
340
  });
340
341
  return new Response(
341
342
  context,
@@ -136,7 +136,8 @@ Zai.prototype.rewrite = function(original, prompt, _options) {
136
136
  modelId: this.Model,
137
137
  taskId: this.taskId,
138
138
  taskType: "zai.rewrite",
139
- adapter: this.adapter
139
+ adapter: this.adapter,
140
+ memoizer: this._resolveMemoizer()
140
141
  });
141
142
  return new Response(context, rewrite(original, prompt, _options, context), (result) => result);
142
143
  };
@@ -511,7 +511,8 @@ Zai.prototype.sort = function(input, instructions, _options) {
511
511
  modelId: this.Model,
512
512
  taskId: this.taskId,
513
513
  taskType: "zai.sort",
514
- adapter: this.adapter
514
+ adapter: this.adapter,
515
+ memoizer: this._resolveMemoizer()
515
516
  });
516
517
  return new Response(
517
518
  context,
@@ -148,7 +148,8 @@ Zai.prototype.summarize = function(original, _options) {
148
148
  modelId: this.Model,
149
149
  taskId: this.taskId,
150
150
  taskType: "summarize",
151
- adapter: this.adapter
151
+ adapter: this.adapter,
152
+ memoizer: this._resolveMemoizer()
152
153
  });
153
154
  return new Response(context, summarize(original, options, context), (value) => value);
154
155
  };
@@ -60,7 +60,8 @@ Zai.prototype.text = function(prompt, _options) {
60
60
  modelId: this.Model,
61
61
  taskId: this.taskId,
62
62
  taskType: "zai.text",
63
- adapter: this.adapter
63
+ adapter: this.adapter,
64
+ memoizer: this._resolveMemoizer()
64
65
  });
65
66
  return new Response(context, text(prompt, _options, context), (result) => result);
66
67
  };
package/dist/zai.js CHANGED
@@ -47,6 +47,7 @@ export class Zai {
47
47
  namespace;
48
48
  adapter;
49
49
  activeLearning;
50
+ _memoize;
50
51
  /**
51
52
  * Creates a new Zai instance with the specified configuration.
52
53
  *
@@ -80,6 +81,7 @@ export class Zai {
80
81
  client: this.client.client,
81
82
  tableName: parsed.activeLearning.tableName
82
83
  }) : new MemoryAdapter([]);
84
+ this._memoize = config.memoize;
83
85
  }
84
86
  /** @internal */
85
87
  async callModel(props) {
@@ -90,6 +92,13 @@ export class Zai {
90
92
  userId: this._userId
91
93
  });
92
94
  }
95
+ /** @internal */
96
+ _resolveMemoizer() {
97
+ if (!this._memoize) {
98
+ return void 0;
99
+ }
100
+ return typeof this._memoize === "function" ? this._memoize() : this._memoize;
101
+ }
93
102
  async getTokenizer() {
94
103
  Zai.tokenizer ??= await (async () => {
95
104
  while (!getWasmTokenizer) {