@botpress/zai 2.5.17 → 2.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,7 +1,7 @@
1
1
  {
2
2
  "name": "@botpress/zai",
3
3
  "description": "Zui AI (zai) – An LLM utility library written on top of Zui and the Botpress API",
4
- "version": "2.5.17",
4
+ "version": "2.6.0",
5
5
  "main": "./dist/index.js",
6
6
  "types": "./dist/index.d.ts",
7
7
  "exports": {
package/src/context.ts CHANGED
@@ -1,6 +1,8 @@
1
1
  import { Cognitive, Model, GenerateContentInput, GenerateContentOutput } from '@botpress/cognitive'
2
2
  import { Adapter } from './adapters/adapter'
3
3
  import { EventEmitter } from './emitter'
4
+ import { fastHash } from './utils'
5
+ import type { Memoizer } from './zai'
4
6
 
5
7
  type Meta = Awaited<ReturnType<Cognitive['generateContent']>>['meta']
6
8
 
@@ -16,6 +18,7 @@ export type ZaiContextProps = {
16
18
  modelId: string
17
19
  adapter?: Adapter
18
20
  source?: GenerateContentInput['meta']
21
+ memoizer?: Memoizer
19
22
  }
20
23
 
21
24
  /**
@@ -94,10 +97,13 @@ export class ZaiContext {
94
97
  public source?: GenerateContentInput['meta']
95
98
 
96
99
  private _eventEmitter: EventEmitter<ContextEvents>
100
+ private _memoizer: Memoizer
97
101
 
98
102
  public controller: AbortController = new AbortController()
99
103
  private _client: Cognitive
100
104
 
105
/**
 * Pass-through memoizer used when the context is created without one.
 * It ignores the id and simply invokes `fn`, so nothing is ever cached.
 */
private static _noopMemoizer: Memoizer = {
  run(_id, fn) {
    return fn()
  },
}
106
+
101
107
  public constructor(props: ZaiContextProps) {
102
108
  this._client = props.client.clone()
103
109
  this.taskId = props.taskId
@@ -105,6 +111,7 @@ export class ZaiContext {
105
111
  this.adapter = props.adapter
106
112
  this.source = props.source
107
113
  this.taskType = props.taskType
114
+ this._memoizer = props.memoizer ?? ZaiContext._noopMemoizer
108
115
  this._eventEmitter = new EventEmitter<ContextEvents>()
109
116
 
110
117
  this._client.on('request', () => {
@@ -148,6 +155,20 @@ export class ZaiContext {
148
155
 
149
156
  public async generateContent<Out = string>(
150
157
  props: GenerateContentProps<Out>
158
+ ): Promise<{ meta: Meta; output: GenerateContentOutput; text: string | undefined; extracted: Out }> {
159
+ const memoKey = `zai:memo:${this.taskType}:${this.taskId || 'default'}:${fastHash(
160
+ JSON.stringify({
161
+ s: props.systemPrompt,
162
+ m: props.messages?.map((m) => ('content' in m ? m.content : '')),
163
+ st: props.stopSequences,
164
+ })
165
+ )}`
166
+
167
+ return this._memoizer.run(memoKey, () => this._generateContentInner(props))
168
+ }
169
+
170
+ private async _generateContentInner<Out = string>(
171
+ props: GenerateContentProps<Out>
151
172
  ): Promise<{ meta: Meta; output: GenerateContentOutput; text: string | undefined; extracted: Out }> {
152
173
  const maxRetries = Math.max(props.maxRetries ?? 3, 0)
153
174
  const transform = props.transform
package/src/index.ts CHANGED
@@ -1,4 +1,4 @@
1
- import { Zai } from './zai'
1
+ import { Zai, type Memoizer } from './zai'
2
2
 
3
3
  import './operations/text'
4
4
  import './operations/rewrite'
@@ -14,3 +14,4 @@ import './operations/answer'
14
14
  import './operations/patch'
15
15
 
16
16
  export { Zai }
17
+ export type { Memoizer }
@@ -816,6 +816,7 @@ Zai.prototype.answer = function <T>(
816
816
  taskId: this.taskId,
817
817
  taskType: 'zai.answer',
818
818
  adapter: this.adapter,
819
+ memoizer: this._resolveMemoizer(),
819
820
  })
820
821
 
821
822
  return new Response<AnswerResult<T>, AnswerResult<T>>(
@@ -354,6 +354,7 @@ Zai.prototype.check = function (
354
354
  taskId: this.taskId,
355
355
  taskType: 'zai.check',
356
356
  adapter: this.adapter,
357
+ memoizer: this._resolveMemoizer(),
357
358
  })
358
359
 
359
360
  return new Response<
@@ -484,6 +484,7 @@ Zai.prototype.extract = function <S extends OfType<AnyObjectOrArray>>(
484
484
  taskId: this.taskId,
485
485
  taskType: 'zai.extract',
486
486
  adapter: this.adapter,
487
+ memoizer: this._resolveMemoizer(),
487
488
  })
488
489
 
489
490
  return new Response<S['_output']>(context, extract(input, schema, _options, context), (result) => result)
@@ -363,6 +363,7 @@ Zai.prototype.filter = function <T>(
363
363
  taskId: this.taskId,
364
364
  taskType: 'zai.filter',
365
365
  adapter: this.adapter,
366
+ memoizer: this._resolveMemoizer(),
366
367
  })
367
368
 
368
369
  return new Response<Array<T>>(context, filter(input, condition, _options, context), (result) => result)
@@ -32,6 +32,8 @@ export type Options = {
32
32
  tokensPerElement?: number
33
33
  chunkLength?: number
34
34
  initialGroups?: Array<InitialGroup>
35
+ maxGroups?: number
36
+ minElements?: number
35
37
  }
36
38
 
37
39
  const _Options = z.object({
@@ -39,6 +41,8 @@ const _Options = z.object({
39
41
  tokensPerElement: z.number().min(1).max(100_000).optional().default(250),
40
42
  chunkLength: z.number().min(100).max(100_000).optional().default(16_000),
41
43
  initialGroups: z.array(_InitialGroup).optional().default([]),
44
+ maxGroups: z.number().min(2).optional(),
45
+ minElements: z.number().min(1).optional(),
42
46
  })
43
47
 
44
48
  declare module '@botpress/zai' {
@@ -52,6 +56,8 @@ declare module '@botpress/zai' {
52
56
  *
53
57
  * @param input - Array of items to group
54
58
  * @param options - Configuration for grouping behavior, instructions, and initial categories
59
+ * @param options.maxGroups - Maximum number of groups allowed (minimum 2). When set, groups are merged at the end until within limit.
60
+ * @param options.minElements - Minimum elements per group (minimum 1). Groups below this threshold have their elements redistributed via AI.
55
61
  * @returns Response with groups array (simplified to Record<groupLabel, items[]>)
56
62
  *
57
63
  * @example Automatic grouping
@@ -193,6 +199,18 @@ declare module '@botpress/zai' {
193
199
  * })
194
200
  * ```
195
201
  *
202
+ *
203
+ * @example Limiting number of groups
204
+ * ```typescript
205
+ * const items = ['apple', 'banana', 'carrot', 'chicken', 'rice', 'bread', 'salmon', 'milk']
206
+ *
207
+ * const groups = await zai.group(items, {
208
+ * instructions: 'Group by food type',
209
+ * maxGroups: 3 // At most 3 groups — smallest groups get merged if exceeded
210
+ * })
211
+ * // Guarantees no more than 3 groups in the result
212
+ * ```
213
+ */
196
214
  group<T>(input: Array<T>, options?: Options): Response<Array<Group<T>>, Record<string, T[]>>
197
215
  }
198
216
  }
@@ -609,6 +627,265 @@ ${END}`.trim()
609
627
  }
610
628
  }
611
629
 
630
// Phase 4: Merge groups if maxGroups is set (AI-driven)
//
// When more non-empty groups exist than `maxGroups` allows, the model is asked to map every
// group onto a (possibly shared) target label. Groups that share a target label are folded into
// the first of them, which takes the new label. If that still leaves too many groups, the two
// smallest groups are merged repeatedly until the limit holds.
if (options.maxGroups !== undefined) {
  const maxGroups = options.maxGroups

  // Ids of the groups that currently hold at least one element
  const nonEmptyGroupIds = () =>
    Array.from(groupElements.entries())
      .filter(([, members]) => members.size > 0)
      .map(([id]) => id)

  let currentIds = nonEmptyGroupIds()

  if (currentIds.length > maxGroups) {
    // Build a summary of each group: label + element count + sample elements
    const groupSummaries = currentIds.map((gid, idx) => {
      const info = groups.get(gid)!
      const elemIndices = Array.from(groupElements.get(gid)!)
      const sampleElements = elemIndices
        .slice(0, 3)
        .map((i) => tokenizer.truncate(elements[i].stringified, 60))
        .join(', ')
      return `■${idx}:${info.label} (${elemIndices.length} elements, e.g. ${sampleElements})■`
    })

    const mergeSystemPrompt = `You are consolidating groups into fewer, broader categories.

${options.instructions ? `**Original instructions:** ${options.instructions}\n` : ''}
**Task:** Merge ${currentIds.length} groups down to at most ${maxGroups} groups.
Combine the most semantically related groups together. Give each merged group a new descriptive label.

**Output Format:**
For each input group (■0 to ■${currentIds.length - 1}), output which target label it maps to:
■0:Merged Label■
■1:Merged Label■
${END}

Use the EXACT SAME label for groups that should be merged together.`.trim()

    const mergeUserPrompt = `**Current groups:**
${groupSummaries.join('\n')}

Merge into at most ${maxGroups} groups.
${END}`.trim()

    const { extracted: mergeAssignments } = await ctx.generateContent({
      systemPrompt: mergeSystemPrompt,
      stopSequences: [END],
      messages: [{ type: 'text', role: 'user', content: mergeUserPrompt }],
      // Parses every `■<index>:<label>■` pair; out-of-range indices and blank labels are ignored
      transform: (text) => {
        const assignments: Array<{ sourceIdx: number; label: string }> = []
        const regex = /■(\d+):([^■]+)■/g
        let match: RegExpExecArray | null

        while ((match = regex.exec(text)) !== null) {
          const idx = parseInt(match[1] ?? '', 10)
          if (isNaN(idx) || idx < 0 || idx >= currentIds.length) continue

          const label = (match[2] ?? '').trim()
          if (!label) continue

          assignments.push({ sourceIdx: idx, label: label.slice(0, 250) })
        }

        return assignments
      },
    })

    // Build merge map: normalized merge label → list of source group IDs.
    // Each source group is honoured once (first assignment wins). A repeated index would
    // otherwise list the same group twice under one label, or enrol it in two different merges.
    const mergeMap = new Map<string, { label: string; sourceGroupIds: string[] }>()
    const claimedSources = new Set<number>()

    for (const { sourceIdx, label } of mergeAssignments) {
      if (claimedSources.has(sourceIdx)) continue

      const sourceGid = currentIds[sourceIdx]
      if (!sourceGid) continue
      claimedSources.add(sourceIdx)

      const normalized = normalizeLabel(label)
      let entry = mergeMap.get(normalized)
      if (!entry) {
        entry = { label, sourceGroupIds: [] }
        mergeMap.set(normalized, entry)
      }
      entry.sourceGroupIds.push(sourceGid)
    }

    // Apply merges: the first source group of each target label becomes the target and
    // receives the elements of the remaining source groups.
    // NOTE(review): the relabelled target is not re-registered in any label → id index from
    // here; if later phases resolve groups by label, confirm they can find the merged label.
    for (const { label, sourceGroupIds } of mergeMap.values()) {
      // A label that only one group maps to is not a merge; that group keeps its current label
      if (sourceGroupIds.length <= 1) continue

      const [targetGid, ...otherGids] = sourceGroupIds
      if (targetGid === undefined) continue
      const targetSet = groupElements.get(targetGid)!

      // Update label on the target group
      const targetInfo = groups.get(targetGid)!
      targetInfo.label = label
      targetInfo.normalizedLabel = normalizeLabel(label)

      for (const sourceGid of otherGids) {
        // Merging a group into itself would end by clearing the target and losing its elements
        if (sourceGid === targetGid) continue

        const sourceSet = groupElements.get(sourceGid)!
        sourceSet.forEach((elemIdx) => targetSet.add(elemIdx))
        sourceSet.clear()
      }
    }

    // Safety: if LLM still produced too many groups, fall back to merging smallest pairs.
    // Each pass empties exactly one group, so the loop ends once the limit is met.
    currentIds = nonEmptyGroupIds()
    while (currentIds.length > maxGroups) {
      currentIds.sort((a, b) => groupElements.get(a)!.size - groupElements.get(b)!.size)

      const sourceSet = groupElements.get(currentIds[0])!
      const targetSet = groupElements.get(currentIds[1])!
      for (const elemIdx of sourceSet) {
        targetSet.add(elemIdx)
      }
      sourceSet.clear()

      currentIds = nonEmptyGroupIds()
    }
  }
}
745
+
746
// Phase 5: Redistribute undersized groups if minElements is set
// Reuses processChunk so orphans see the valid groups as available buckets
//
// Steps: (1) empty every group below `minElements` and collect its elements as orphans,
// (2) chunk the orphans and ask processChunk to place them against the remaining groups,
// (3) apply the returned labels, creating groups for labels not seen before,
// (4) place orphans that received no label, (5) fold groups that are still undersized
// into the largest group.
if (options.minElements !== undefined && options.minElements > 1) {
  // Captured once so the closures below work with a plain number
  const minElements = options.minElements

  const getNonEmptyGroupIds = () =>
    Array.from(groupElements.entries())
      .filter(([, members]) => members.size > 0)
      .map(([id]) => id)

  // Collect orphan elements from all undersized groups, remembering where each came from
  // so it can be put back if redistribution yields nothing at all
  const orphanIndices: number[] = []
  const orphanOrigin = new Map<number, string>()

  for (const gid of getNonEmptyGroupIds()) {
    const elemSet = groupElements.get(gid)!
    if (elemSet.size > 0 && elemSet.size < minElements) {
      for (const idx of elemSet) {
        orphanIndices.push(idx)
        orphanOrigin.set(idx, gid)
      }
      elemSet.clear()
    }
  }

  if (orphanIndices.length > 0) {
    // Valid groups = everything that's still non-empty (i.e. above minElements)
    const validGroupIds = getNonEmptyGroupIds()

    // Chunk orphans and run them through processChunk with only valid groups visible.
    // A chunk is closed when adding the next element would exceed the token budget or the
    // per-chunk element cap; a chunk is never closed while empty.
    const orphanChunks: number[][] = []
    let currentOrphanChunk: number[] = []
    let currentOrphanTokens = 0

    for (const elemIdx of orphanIndices) {
      const elem = elements[elemIdx]
      const truncated = tokenizer.truncate(elem.stringified, options.tokensPerElement)
      const elemTokens = tokenizer.count(truncated)

      if (
        (currentOrphanTokens + elemTokens > TOKENS_FOR_ELEMENTS_MAX ||
          currentOrphanChunk.length >= MAX_ELEMENTS_PER_CHUNK) &&
        currentOrphanChunk.length > 0
      ) {
        orphanChunks.push(currentOrphanChunk)
        currentOrphanChunk = []
        currentOrphanTokens = 0
      }

      currentOrphanChunk.push(elemIdx)
      currentOrphanTokens += elemTokens
    }

    if (currentOrphanChunk.length > 0) {
      orphanChunks.push(currentOrphanChunk)
    }

    // Process orphan chunks against valid groups (reuses the same processChunk as Phase 1)
    const orphanResults = await Promise.all(
      orphanChunks.map((chunk) =>
        elementLimit(async () => {
          // If there are valid groups, chunk them; otherwise pass empty so LLM creates new groups
          const groupChunksForOrphans = validGroupIds.length > 0 ? getGroupChunks() : [[]]

          const allAssignments = await Promise.all(
            groupChunksForOrphans
              .filter((gc) => gc.length === 0 || gc.some((gid) => validGroupIds.includes(gid)))
              .map((groupChunk) => {
                // Only show valid groups (exclude the orphaned/undersized ones)
                const filteredGroupChunk = groupChunk.filter((gid) => validGroupIds.includes(gid))
                return groupLimit(() => processChunk(chunk, filteredGroupChunk))
              })
          )

          return allAssignments.flat()
        })
      )
    )

    // Apply assignments
    // NOTE(review): every returned assignment is applied, so an orphan that is labelled by
    // more than one group chunk ends up in more than one group — confirm this is intended.
    const flatAssignments = orphanResults.flat()
    for (const { elementIndex, label } of flatAssignments) {
      const normalized = normalizeLabel(label)
      let groupId = labelToGroupId.get(normalized)

      if (!groupId) {
        groupId = `group_${groupIdCounter++}`
        groups.set(groupId, { id: groupId, label, normalizedLabel: normalized })
        groupElements.set(groupId, new Set())
        labelToGroupId.set(normalized, groupId)
      }
      groupElements.get(groupId)!.add(elementIndex)
    }

    // Safety: any orphans the LLM missed get placed into the largest group
    const isAssigned = (idx: number) => {
      for (const elemSet of groupElements.values()) {
        if (elemSet.has(idx)) return true
      }
      return false
    }
    const unassigned = orphanIndices.filter((idx) => !isAssigned(idx))
    const placeIntoLargest = (indices: number[]) => {
      const allNonEmpty = getNonEmptyGroupIds()
      if (allNonEmpty.length === 0) {
        // Every group was undersized and nothing was assigned back: return the orphans to
        // the groups they came from rather than dropping them from the result
        for (const idx of indices) {
          const originGid = orphanOrigin.get(idx)
          if (originGid !== undefined) {
            groupElements.get(originGid)?.add(idx)
          }
        }
        return
      }
      // On ties the earlier group wins
      const largestGid = allNonEmpty.reduce((a, b) =>
        groupElements.get(a)!.size >= groupElements.get(b)!.size ? a : b
      )
      for (const idx of indices) {
        groupElements.get(largestGid)!.add(idx)
      }
    }

    if (unassigned.length > 0) {
      placeIntoLargest(unassigned)
    }

    // Second pass: if any groups are still undersized after redistribution,
    // merge their elements into the largest group.
    // Returns true when at least one group was emptied during the pass.
    const mergeUndersizedGroups = () => {
      const allNonEmpty = getNonEmptyGroupIds()
      if (allNonEmpty.length <= 1) return false

      const largestGid = allNonEmpty.reduce((a, b) =>
        groupElements.get(a)!.size >= groupElements.get(b)!.size ? a : b
      )
      const targetSet = groupElements.get(largestGid)!
      let merged = false

      for (const gid of allNonEmpty) {
        if (gid === largestGid) continue
        const elemSet = groupElements.get(gid)!
        if (elemSet.size > 0 && elemSet.size < minElements) {
          elemSet.forEach((idx) => targetSet.add(idx))
          elemSet.clear()
          merged = true
        }
      }
      return merged
    }

    while (mergeUndersizedGroups()) {
      // keep merging until no undersized groups remain
    }
  }
}
888
+
612
889
  // Build final result
613
890
  const result: Array<Group<T>> = []
614
891
 
@@ -678,6 +955,7 @@ Zai.prototype.group = function <T>(
678
955
  taskId: this.taskId,
679
956
  taskType: 'zai.group',
680
957
  adapter: this.adapter,
958
+ memoizer: this._resolveMemoizer(),
681
959
  })
682
960
 
683
961
  return new Response<Array<Group<T>>, Record<string, T[]>>(context, group(input, _options, context), (result) => {
@@ -542,6 +542,7 @@ Zai.prototype.label = function <T extends string>(
542
542
  taskId: this.taskId,
543
543
  taskType: 'zai.label',
544
544
  adapter: this.adapter,
545
+ memoizer: this._resolveMemoizer(),
545
546
  })
546
547
 
547
548
  return new Response<
@@ -650,6 +650,7 @@ Zai.prototype.patch = function (
650
650
  taskId: this.taskId,
651
651
  taskType: 'zai.patch',
652
652
  adapter: this.adapter,
653
+ memoizer: this._resolveMemoizer(),
653
654
  })
654
655
 
655
656
  return new Response<Array<File>>(context, patch(files, instructions, _options, context), (result) => result)
@@ -611,6 +611,7 @@ Zai.prototype.rate = function <T, I extends RatingInstructions>(
611
611
  taskId: this.taskId,
612
612
  taskType: 'zai.rate',
613
613
  adapter: this.adapter,
614
+ memoizer: this._resolveMemoizer(),
614
615
  })
615
616
 
616
617
  return new Response<Array<RatingResult<I>>, Array<SimplifiedRatingResult<I>>>(
@@ -277,6 +277,7 @@ Zai.prototype.rewrite = function (this: Zai, original: string, prompt: string, _
277
277
  taskId: this.taskId,
278
278
  taskType: 'zai.rewrite',
279
279
  adapter: this.adapter,
280
+ memoizer: this._resolveMemoizer(),
280
281
  })
281
282
 
282
283
  return new Response<string>(context, rewrite(original, prompt, _options, context), (result) => result)
@@ -800,6 +800,7 @@ Zai.prototype.sort = function <T>(
800
800
  taskId: this.taskId,
801
801
  taskType: 'zai.sort',
802
802
  adapter: this.adapter,
803
+ memoizer: this._resolveMemoizer(),
803
804
  })
804
805
 
805
806
  return new Response<Array<T>, Array<T>>(
@@ -306,6 +306,7 @@ Zai.prototype.summarize = function (this: Zai, original, _options): Response<str
306
306
  taskId: this.taskId,
307
307
  taskType: 'summarize',
308
308
  adapter: this.adapter,
309
+ memoizer: this._resolveMemoizer(),
309
310
  })
310
311
 
311
312
  return new Response<string, string>(context, summarize(original, options, context), (value) => value)
@@ -135,6 +135,7 @@ Zai.prototype.text = function (this: Zai, prompt: string, _options?: Options): R
135
135
  taskId: this.taskId,
136
136
  taskType: 'zai.text',
137
137
  adapter: this.adapter,
138
+ memoizer: this._resolveMemoizer(),
138
139
  })
139
140
 
140
141
  return new Response<string>(context, text(prompt, _options, context), (result) => result)
package/src/zai.ts CHANGED
@@ -8,6 +8,17 @@ import { Adapter } from './adapters/adapter'
8
8
  import { TableAdapter } from './adapters/botpress-table'
9
9
  import { MemoryAdapter } from './adapters/memory'
10
10
 
11
/**
 * Contract for caching the result of an async operation under a unique key.
 *
 * `run` receives an id and a function producing the value; the implementation decides
 * whether to invoke the function or to hand back a result it already holds for that id.
 *
 * When backed by the Botpress ADK workflow `step` function, this lets Zai operations
 * resume where they left off if a workflow is interrupted.
 */
export type Memoizer = {
  run: <TResult>(id: string, fn: () => Promise<TResult>) => Promise<TResult>
}
21
+
11
22
  /**
12
23
  * Active learning configuration for improving AI operations over time.
13
24
  *
@@ -86,6 +97,16 @@ type ZaiConfig = {
86
97
  activeLearning?: ActiveLearning
87
98
  /** Namespace for organizing tasks (default: 'zai') */
88
99
  namespace?: string
100
+ /**
101
+ * Memoizer (or factory returning one) for caching cognitive call results.
102
+ *
103
+ * When provided, all LLM calls are wrapped in the memoizer, allowing results
104
+ * to be cached and replayed. This is useful for resuming workflow runs where
105
+ * Zai operations have already completed their cognitive calls.
106
+ *
107
+ * If a factory function is provided, it is called once per Zai operation invocation.
108
+ */
109
+ memoize?: Memoizer | (() => Memoizer)
89
110
  }
90
111
 
91
112
  const _ZaiConfig = z.object({
@@ -195,6 +216,7 @@ export class Zai {
195
216
  protected namespace: string
196
217
  protected adapter: Adapter
197
218
  protected activeLearning: ActiveLearning
219
+ protected _memoize?: Memoizer | (() => Memoizer)
198
220
 
199
221
  /**
200
222
  * Creates a new Zai instance with the specified configuration.
@@ -236,6 +258,8 @@ export class Zai {
236
258
  tableName: parsed.activeLearning.tableName,
237
259
  })
238
260
  : new MemoryAdapter([])
261
+
262
+ this._memoize = config.memoize
239
263
  }
240
264
 
241
265
  /** @internal */
@@ -250,6 +274,14 @@ export class Zai {
250
274
  })
251
275
  }
252
276
 
277
/**
 * Resolves the memoizer to pass to a new operation context.
 *
 * @returns `undefined` when no memoizer is configured, the value returned by the factory
 * when `_memoize` is a function, and the configured memoizer itself otherwise.
 * @internal
 */
protected _resolveMemoizer(): Memoizer | undefined {
  const configured = this._memoize
  if (!configured) {
    return undefined
  }
  if (typeof configured === 'function') {
    // Invoked with the instance as `this`, exactly as a call through the field would be
    return configured.call(this)
  }
  return configured
}
284
+
253
285
  protected async getTokenizer() {
254
286
  Zai.tokenizer ??= await (async () => {
255
287
  while (!getWasmTokenizer) {