@botpress/zai 2.5.17 → 2.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/context.js +14 -0
- package/dist/index.d.ts +44 -1
- package/dist/operations/answer.js +2 -1
- package/dist/operations/check.js +2 -1
- package/dist/operations/extract.js +2 -1
- package/dist/operations/filter.js +2 -1
- package/dist/operations/group.js +190 -2
- package/dist/operations/label.js +2 -1
- package/dist/operations/patch.js +2 -1
- package/dist/operations/rate.js +2 -1
- package/dist/operations/rewrite.js +2 -1
- package/dist/operations/sort.js +2 -1
- package/dist/operations/summarize.js +2 -1
- package/dist/operations/text.js +2 -1
- package/dist/zai.js +9 -0
- package/e2e/data/cache.jsonl +70 -0
- package/package.json +1 -1
- package/src/context.ts +21 -0
- package/src/index.ts +2 -1
- package/src/operations/answer.ts +1 -0
- package/src/operations/check.ts +1 -0
- package/src/operations/extract.ts +1 -0
- package/src/operations/filter.ts +1 -0
- package/src/operations/group.ts +278 -0
- package/src/operations/label.ts +1 -0
- package/src/operations/patch.ts +1 -0
- package/src/operations/rate.ts +1 -0
- package/src/operations/rewrite.ts +1 -0
- package/src/operations/sort.ts +1 -0
- package/src/operations/summarize.ts +1 -0
- package/src/operations/text.ts +1 -0
- package/src/zai.ts +32 -0
package/package.json
CHANGED
package/src/context.ts
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import { Cognitive, Model, GenerateContentInput, GenerateContentOutput } from '@botpress/cognitive'
|
|
2
2
|
import { Adapter } from './adapters/adapter'
|
|
3
3
|
import { EventEmitter } from './emitter'
|
|
4
|
+
import { fastHash } from './utils'
|
|
5
|
+
import type { Memoizer } from './zai'
|
|
4
6
|
|
|
5
7
|
type Meta = Awaited<ReturnType<Cognitive['generateContent']>>['meta']
|
|
6
8
|
|
|
@@ -16,6 +18,7 @@ export type ZaiContextProps = {
|
|
|
16
18
|
modelId: string
|
|
17
19
|
adapter?: Adapter
|
|
18
20
|
source?: GenerateContentInput['meta']
|
|
21
|
+
memoizer?: Memoizer
|
|
19
22
|
}
|
|
20
23
|
|
|
21
24
|
/**
|
|
@@ -94,10 +97,13 @@ export class ZaiContext {
|
|
|
94
97
|
public source?: GenerateContentInput['meta']
|
|
95
98
|
|
|
96
99
|
private _eventEmitter: EventEmitter<ContextEvents>
|
|
100
|
+
private _memoizer: Memoizer
|
|
97
101
|
|
|
98
102
|
public controller: AbortController = new AbortController()
|
|
99
103
|
private _client: Cognitive
|
|
100
104
|
|
|
105
|
+
private static _noopMemoizer: Memoizer = { run: (_id, fn) => fn() }
|
|
106
|
+
|
|
101
107
|
public constructor(props: ZaiContextProps) {
|
|
102
108
|
this._client = props.client.clone()
|
|
103
109
|
this.taskId = props.taskId
|
|
@@ -105,6 +111,7 @@ export class ZaiContext {
|
|
|
105
111
|
this.adapter = props.adapter
|
|
106
112
|
this.source = props.source
|
|
107
113
|
this.taskType = props.taskType
|
|
114
|
+
this._memoizer = props.memoizer ?? ZaiContext._noopMemoizer
|
|
108
115
|
this._eventEmitter = new EventEmitter<ContextEvents>()
|
|
109
116
|
|
|
110
117
|
this._client.on('request', () => {
|
|
@@ -148,6 +155,20 @@ export class ZaiContext {
|
|
|
148
155
|
|
|
149
156
|
public async generateContent<Out = string>(
|
|
150
157
|
props: GenerateContentProps<Out>
|
|
158
|
+
): Promise<{ meta: Meta; output: GenerateContentOutput; text: string | undefined; extracted: Out }> {
|
|
159
|
+
const memoKey = `zai:memo:${this.taskType}:${this.taskId || 'default'}:${fastHash(
|
|
160
|
+
JSON.stringify({
|
|
161
|
+
s: props.systemPrompt,
|
|
162
|
+
m: props.messages?.map((m) => ('content' in m ? m.content : '')),
|
|
163
|
+
st: props.stopSequences,
|
|
164
|
+
})
|
|
165
|
+
)}`
|
|
166
|
+
|
|
167
|
+
return this._memoizer.run(memoKey, () => this._generateContentInner(props))
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
private async _generateContentInner<Out = string>(
|
|
171
|
+
props: GenerateContentProps<Out>
|
|
151
172
|
): Promise<{ meta: Meta; output: GenerateContentOutput; text: string | undefined; extracted: Out }> {
|
|
152
173
|
const maxRetries = Math.max(props.maxRetries ?? 3, 0)
|
|
153
174
|
const transform = props.transform
|
package/src/index.ts
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { Zai } from './zai'
|
|
1
|
+
import { Zai, type Memoizer } from './zai'
|
|
2
2
|
|
|
3
3
|
import './operations/text'
|
|
4
4
|
import './operations/rewrite'
|
|
@@ -14,3 +14,4 @@ import './operations/answer'
|
|
|
14
14
|
import './operations/patch'
|
|
15
15
|
|
|
16
16
|
export { Zai }
|
|
17
|
+
export type { Memoizer }
|
package/src/operations/answer.ts
CHANGED
package/src/operations/check.ts
CHANGED
|
@@ -484,6 +484,7 @@ Zai.prototype.extract = function <S extends OfType<AnyObjectOrArray>>(
|
|
|
484
484
|
taskId: this.taskId,
|
|
485
485
|
taskType: 'zai.extract',
|
|
486
486
|
adapter: this.adapter,
|
|
487
|
+
memoizer: this._resolveMemoizer(),
|
|
487
488
|
})
|
|
488
489
|
|
|
489
490
|
return new Response<S['_output']>(context, extract(input, schema, _options, context), (result) => result)
|
package/src/operations/filter.ts
CHANGED
|
@@ -363,6 +363,7 @@ Zai.prototype.filter = function <T>(
|
|
|
363
363
|
taskId: this.taskId,
|
|
364
364
|
taskType: 'zai.filter',
|
|
365
365
|
adapter: this.adapter,
|
|
366
|
+
memoizer: this._resolveMemoizer(),
|
|
366
367
|
})
|
|
367
368
|
|
|
368
369
|
return new Response<Array<T>>(context, filter(input, condition, _options, context), (result) => result)
|
package/src/operations/group.ts
CHANGED
|
@@ -32,6 +32,8 @@ export type Options = {
|
|
|
32
32
|
tokensPerElement?: number
|
|
33
33
|
chunkLength?: number
|
|
34
34
|
initialGroups?: Array<InitialGroup>
|
|
35
|
+
maxGroups?: number
|
|
36
|
+
minElements?: number
|
|
35
37
|
}
|
|
36
38
|
|
|
37
39
|
const _Options = z.object({
|
|
@@ -39,6 +41,8 @@ const _Options = z.object({
|
|
|
39
41
|
tokensPerElement: z.number().min(1).max(100_000).optional().default(250),
|
|
40
42
|
chunkLength: z.number().min(100).max(100_000).optional().default(16_000),
|
|
41
43
|
initialGroups: z.array(_InitialGroup).optional().default([]),
|
|
44
|
+
maxGroups: z.number().min(2).optional(),
|
|
45
|
+
minElements: z.number().min(1).optional(),
|
|
42
46
|
})
|
|
43
47
|
|
|
44
48
|
declare module '@botpress/zai' {
|
|
@@ -52,6 +56,8 @@ declare module '@botpress/zai' {
|
|
|
52
56
|
*
|
|
53
57
|
* @param input - Array of items to group
|
|
54
58
|
* @param options - Configuration for grouping behavior, instructions, and initial categories
|
|
59
|
+
* @param options.maxGroups - Maximum number of groups allowed (minimum 2). When set, groups are merged at the end until the group count is within the limit.
|
|
60
|
+
* @param options.minElements - Minimum elements per group (minimum 1). Groups below this threshold have their elements redistributed via AI.
|
|
55
61
|
* @returns Response with groups array (simplified to Record<groupLabel, items[]>)
|
|
56
62
|
*
|
|
57
63
|
* @example Automatic grouping
|
|
@@ -193,6 +199,18 @@ declare module '@botpress/zai' {
|
|
|
193
199
|
* })
|
|
194
200
|
* ```
|
|
195
201
|
*/
|
|
202
|
+
/**
|
|
203
|
+
* @example Limiting number of groups
|
|
204
|
+
* ```typescript
|
|
205
|
+
* const items = ['apple', 'banana', 'carrot', 'chicken', 'rice', 'bread', 'salmon', 'milk']
|
|
206
|
+
*
|
|
207
|
+
* const groups = await zai.group(items, {
|
|
208
|
+
* instructions: 'Group by food type',
|
|
209
|
+
* maxGroups: 3 // At most 3 groups — related groups get merged by the AI if exceeded (smallest groups are merged as a fallback)
|
|
210
|
+
* })
|
|
211
|
+
* // Guarantees no more than 3 groups in the result
|
|
212
|
+
* ```
|
|
213
|
+
*/
|
|
196
214
|
group<T>(input: Array<T>, options?: Options): Response<Array<Group<T>>, Record<string, T[]>>
|
|
197
215
|
}
|
|
198
216
|
}
|
|
@@ -609,6 +627,265 @@ ${END}`.trim()
|
|
|
609
627
|
}
|
|
610
628
|
}
|
|
611
629
|
|
|
630
|
+
// Phase 4: Merge groups if maxGroups is set (AI-driven)
|
|
631
|
+
if (options.maxGroups !== undefined) {
|
|
632
|
+
const nonEmptyGroupIds = () =>
|
|
633
|
+
Array.from(groupElements.entries())
|
|
634
|
+
.filter(([, s]) => s.size > 0)
|
|
635
|
+
.map(([id]) => id)
|
|
636
|
+
|
|
637
|
+
let currentIds = nonEmptyGroupIds()
|
|
638
|
+
|
|
639
|
+
if (currentIds.length > options.maxGroups) {
|
|
640
|
+
// Build a summary of each group: label + element count + sample elements
|
|
641
|
+
const groupSummaries = currentIds.map((gid, idx) => {
|
|
642
|
+
const info = groups.get(gid)!
|
|
643
|
+
const elemIndices = Array.from(groupElements.get(gid)!)
|
|
644
|
+
const sampleElements = elemIndices
|
|
645
|
+
.slice(0, 3)
|
|
646
|
+
.map((i) => tokenizer.truncate(elements[i].stringified, 60))
|
|
647
|
+
.join(', ')
|
|
648
|
+
return `■${idx}:${info.label} (${elemIndices.length} elements, e.g. ${sampleElements})■`
|
|
649
|
+
})
|
|
650
|
+
|
|
651
|
+
const mergeSystemPrompt = `You are consolidating groups into fewer, broader categories.
|
|
652
|
+
|
|
653
|
+
${options.instructions ? `**Original instructions:** ${options.instructions}\n` : ''}
|
|
654
|
+
**Task:** Merge ${currentIds.length} groups down to at most ${options.maxGroups} groups.
|
|
655
|
+
Combine the most semantically related groups together. Give each merged group a new descriptive label.
|
|
656
|
+
|
|
657
|
+
**Output Format:**
|
|
658
|
+
For each input group (■0 to ■${currentIds.length - 1}), output which target label it maps to:
|
|
659
|
+
■0:Merged Label■
|
|
660
|
+
■1:Merged Label■
|
|
661
|
+
${END}
|
|
662
|
+
|
|
663
|
+
Use the EXACT SAME label for groups that should be merged together.`.trim()
|
|
664
|
+
|
|
665
|
+
const mergeUserPrompt = `**Current groups:**
|
|
666
|
+
${groupSummaries.join('\n')}
|
|
667
|
+
|
|
668
|
+
Merge into at most ${options.maxGroups} groups.
|
|
669
|
+
${END}`.trim()
|
|
670
|
+
|
|
671
|
+
const { extracted: mergeAssignments } = await ctx.generateContent({
|
|
672
|
+
systemPrompt: mergeSystemPrompt,
|
|
673
|
+
stopSequences: [END],
|
|
674
|
+
messages: [{ type: 'text', role: 'user', content: mergeUserPrompt }],
|
|
675
|
+
transform: (text) => {
|
|
676
|
+
const assignments: Array<{ sourceIdx: number; label: string }> = []
|
|
677
|
+
const regex = /■(\d+):([^■]+)■/g
|
|
678
|
+
let match: RegExpExecArray | null
|
|
679
|
+
|
|
680
|
+
while ((match = regex.exec(text)) !== null) {
|
|
681
|
+
const idx = parseInt(match[1] ?? '', 10)
|
|
682
|
+
if (isNaN(idx) || idx < 0 || idx >= currentIds.length) continue
|
|
683
|
+
|
|
684
|
+
const label = (match[2] ?? '').trim()
|
|
685
|
+
if (!label) continue
|
|
686
|
+
|
|
687
|
+
assignments.push({ sourceIdx: idx, label: label.slice(0, 250) })
|
|
688
|
+
}
|
|
689
|
+
|
|
690
|
+
return assignments
|
|
691
|
+
},
|
|
692
|
+
})
|
|
693
|
+
|
|
694
|
+
// Build merge map: normalized merge label → list of source group IDs
|
|
695
|
+
const mergeMap = new Map<string, { label: string; sourceGroupIds: string[] }>()
|
|
696
|
+
|
|
697
|
+
for (const { sourceIdx, label } of mergeAssignments) {
|
|
698
|
+
const sourceGid = currentIds[sourceIdx]
|
|
699
|
+
if (!sourceGid) continue
|
|
700
|
+
|
|
701
|
+
const normalized = normalizeLabel(label)
|
|
702
|
+
if (!mergeMap.has(normalized)) {
|
|
703
|
+
mergeMap.set(normalized, { label, sourceGroupIds: [] })
|
|
704
|
+
}
|
|
705
|
+
mergeMap.get(normalized)!.sourceGroupIds.push(sourceGid)
|
|
706
|
+
}
|
|
707
|
+
|
|
708
|
+
// Apply merges: for each merge target, pick the first source group as the target
|
|
709
|
+
// and move all elements from other source groups into it
|
|
710
|
+
for (const [, { label, sourceGroupIds }] of mergeMap) {
|
|
711
|
+
if (sourceGroupIds.length <= 1) continue
|
|
712
|
+
|
|
713
|
+
const targetGid = sourceGroupIds[0]
|
|
714
|
+
const targetSet = groupElements.get(targetGid)!
|
|
715
|
+
|
|
716
|
+
// Update label on the target group
|
|
717
|
+
const targetInfo = groups.get(targetGid)!
|
|
718
|
+
targetInfo.label = label
|
|
719
|
+
targetInfo.normalizedLabel = normalizeLabel(label)
|
|
720
|
+
|
|
721
|
+
for (let i = 1; i < sourceGroupIds.length; i++) {
|
|
722
|
+
const sourceGid = sourceGroupIds[i]
|
|
723
|
+
const sourceSet = groupElements.get(sourceGid)!
|
|
724
|
+
sourceSet.forEach((elemIdx) => targetSet.add(elemIdx))
|
|
725
|
+
sourceSet.clear()
|
|
726
|
+
}
|
|
727
|
+
}
|
|
728
|
+
|
|
729
|
+
// Safety: if LLM still produced too many groups, fall back to merging smallest pairs
|
|
730
|
+
currentIds = nonEmptyGroupIds()
|
|
731
|
+
while (currentIds.length > options.maxGroups) {
|
|
732
|
+
currentIds.sort((a, b) => groupElements.get(a)!.size - groupElements.get(b)!.size)
|
|
733
|
+
|
|
734
|
+
const sourceSet = groupElements.get(currentIds[0])!
|
|
735
|
+
const targetSet = groupElements.get(currentIds[1])!
|
|
736
|
+
for (const elemIdx of sourceSet) {
|
|
737
|
+
targetSet.add(elemIdx)
|
|
738
|
+
}
|
|
739
|
+
sourceSet.clear()
|
|
740
|
+
|
|
741
|
+
currentIds = nonEmptyGroupIds()
|
|
742
|
+
}
|
|
743
|
+
}
|
|
744
|
+
}
|
|
745
|
+
|
|
746
|
+
// Phase 5: Redistribute undersized groups if minElements is set
|
|
747
|
+
// Reuses processChunk so orphans see the valid groups as available buckets
|
|
748
|
+
if (options.minElements !== undefined && options.minElements > 1) {
|
|
749
|
+
const getNonEmptyGroupIds = () =>
|
|
750
|
+
Array.from(groupElements.entries())
|
|
751
|
+
.filter(([, s]) => s.size > 0)
|
|
752
|
+
.map(([id]) => id)
|
|
753
|
+
|
|
754
|
+
// Collect orphan elements from all undersized groups
|
|
755
|
+
const orphanIndices: number[] = []
|
|
756
|
+
|
|
757
|
+
for (const gid of getNonEmptyGroupIds()) {
|
|
758
|
+
const elemSet = groupElements.get(gid)!
|
|
759
|
+
if (elemSet.size > 0 && elemSet.size < options.minElements) {
|
|
760
|
+
for (const idx of elemSet) {
|
|
761
|
+
orphanIndices.push(idx)
|
|
762
|
+
}
|
|
763
|
+
elemSet.clear()
|
|
764
|
+
}
|
|
765
|
+
}
|
|
766
|
+
|
|
767
|
+
if (orphanIndices.length > 0) {
|
|
768
|
+
// Valid groups = everything that's still non-empty (i.e. at or above minElements)
|
|
769
|
+
const validGroupIds = getNonEmptyGroupIds()
|
|
770
|
+
|
|
771
|
+
// Chunk orphans and run them through processChunk with only valid groups visible
|
|
772
|
+
const orphanChunks: number[][] = []
|
|
773
|
+
let currentOrphanChunk: number[] = []
|
|
774
|
+
let currentOrphanTokens = 0
|
|
775
|
+
|
|
776
|
+
for (const elemIdx of orphanIndices) {
|
|
777
|
+
const elem = elements[elemIdx]
|
|
778
|
+
const truncated = tokenizer.truncate(elem.stringified, options.tokensPerElement)
|
|
779
|
+
const elemTokens = tokenizer.count(truncated)
|
|
780
|
+
|
|
781
|
+
if (
|
|
782
|
+
(currentOrphanTokens + elemTokens > TOKENS_FOR_ELEMENTS_MAX ||
|
|
783
|
+
currentOrphanChunk.length >= MAX_ELEMENTS_PER_CHUNK) &&
|
|
784
|
+
currentOrphanChunk.length > 0
|
|
785
|
+
) {
|
|
786
|
+
orphanChunks.push(currentOrphanChunk)
|
|
787
|
+
currentOrphanChunk = []
|
|
788
|
+
currentOrphanTokens = 0
|
|
789
|
+
}
|
|
790
|
+
|
|
791
|
+
currentOrphanChunk.push(elemIdx)
|
|
792
|
+
currentOrphanTokens += elemTokens
|
|
793
|
+
}
|
|
794
|
+
|
|
795
|
+
if (currentOrphanChunk.length > 0) {
|
|
796
|
+
orphanChunks.push(currentOrphanChunk)
|
|
797
|
+
}
|
|
798
|
+
|
|
799
|
+
// Process orphan chunks against valid groups (reuses the same processChunk as Phase 1)
|
|
800
|
+
const orphanResults = await Promise.all(
|
|
801
|
+
orphanChunks.map((chunk) =>
|
|
802
|
+
elementLimit(async () => {
|
|
803
|
+
// If there are valid groups, chunk them; otherwise pass empty so LLM creates new groups
|
|
804
|
+
const groupChunksForOrphans = validGroupIds.length > 0 ? getGroupChunks() : [[]]
|
|
805
|
+
|
|
806
|
+
const allAssignments = await Promise.all(
|
|
807
|
+
groupChunksForOrphans
|
|
808
|
+
.filter((gc) => gc.length === 0 || gc.some((gid) => validGroupIds.includes(gid)))
|
|
809
|
+
.map((groupChunk) => {
|
|
810
|
+
// Only show valid groups (exclude the orphaned/undersized ones)
|
|
811
|
+
const filteredGroupChunk = groupChunk.filter((gid) => validGroupIds.includes(gid))
|
|
812
|
+
return groupLimit(() => processChunk(chunk, filteredGroupChunk))
|
|
813
|
+
})
|
|
814
|
+
)
|
|
815
|
+
|
|
816
|
+
return allAssignments.flat()
|
|
817
|
+
})
|
|
818
|
+
)
|
|
819
|
+
)
|
|
820
|
+
|
|
821
|
+
// Apply assignments
|
|
822
|
+
const flatAssignments = orphanResults.flat()
|
|
823
|
+
for (const { elementIndex, label } of flatAssignments) {
|
|
824
|
+
const normalized = normalizeLabel(label)
|
|
825
|
+
let groupId = labelToGroupId.get(normalized)
|
|
826
|
+
|
|
827
|
+
if (!groupId) {
|
|
828
|
+
groupId = `group_${groupIdCounter++}`
|
|
829
|
+
groups.set(groupId, { id: groupId, label, normalizedLabel: normalized })
|
|
830
|
+
groupElements.set(groupId, new Set())
|
|
831
|
+
labelToGroupId.set(normalized, groupId)
|
|
832
|
+
}
|
|
833
|
+
groupElements.get(groupId)!.add(elementIndex)
|
|
834
|
+
}
|
|
835
|
+
|
|
836
|
+
// Safety: any orphans the LLM missed get placed into the largest group
|
|
837
|
+
const isAssigned = (idx: number) => {
|
|
838
|
+
for (const [, elemSet] of groupElements) {
|
|
839
|
+
if (elemSet.has(idx)) return true
|
|
840
|
+
}
|
|
841
|
+
return false
|
|
842
|
+
}
|
|
843
|
+
const unassigned = orphanIndices.filter((idx) => !isAssigned(idx))
|
|
844
|
+
const placeIntoLargest = (indices: number[]) => {
|
|
845
|
+
const allNonEmpty = getNonEmptyGroupIds()
|
|
846
|
+
if (allNonEmpty.length === 0) return
|
|
847
|
+
const largestGid = allNonEmpty.reduce((a, b) =>
|
|
848
|
+
groupElements.get(a)!.size >= groupElements.get(b)!.size ? a : b
|
|
849
|
+
)
|
|
850
|
+
for (const idx of indices) {
|
|
851
|
+
groupElements.get(largestGid)!.add(idx)
|
|
852
|
+
}
|
|
853
|
+
}
|
|
854
|
+
|
|
855
|
+
if (unassigned.length > 0) {
|
|
856
|
+
placeIntoLargest(unassigned)
|
|
857
|
+
}
|
|
858
|
+
|
|
859
|
+
// Second pass: if any groups are still undersized after redistribution,
|
|
860
|
+
// merge their elements into the largest group
|
|
861
|
+
const mergeUndersizedGroups = () => {
|
|
862
|
+
const allNonEmpty = getNonEmptyGroupIds()
|
|
863
|
+
if (allNonEmpty.length <= 1) return false
|
|
864
|
+
|
|
865
|
+
const largestGid = allNonEmpty.reduce((a, b) =>
|
|
866
|
+
groupElements.get(a)!.size >= groupElements.get(b)!.size ? a : b
|
|
867
|
+
)
|
|
868
|
+
const targetSet = groupElements.get(largestGid)!
|
|
869
|
+
let merged = false
|
|
870
|
+
|
|
871
|
+
for (const gid of allNonEmpty) {
|
|
872
|
+
if (gid === largestGid) continue
|
|
873
|
+
const elemSet = groupElements.get(gid)!
|
|
874
|
+
if (elemSet.size > 0 && elemSet.size < options.minElements) {
|
|
875
|
+
elemSet.forEach((idx) => targetSet.add(idx))
|
|
876
|
+
elemSet.clear()
|
|
877
|
+
merged = true
|
|
878
|
+
}
|
|
879
|
+
}
|
|
880
|
+
return merged
|
|
881
|
+
}
|
|
882
|
+
|
|
883
|
+
while (mergeUndersizedGroups()) {
|
|
884
|
+
// keep merging until no undersized groups remain
|
|
885
|
+
}
|
|
886
|
+
}
|
|
887
|
+
}
|
|
888
|
+
|
|
612
889
|
// Build final result
|
|
613
890
|
const result: Array<Group<T>> = []
|
|
614
891
|
|
|
@@ -678,6 +955,7 @@ Zai.prototype.group = function <T>(
|
|
|
678
955
|
taskId: this.taskId,
|
|
679
956
|
taskType: 'zai.group',
|
|
680
957
|
adapter: this.adapter,
|
|
958
|
+
memoizer: this._resolveMemoizer(),
|
|
681
959
|
})
|
|
682
960
|
|
|
683
961
|
return new Response<Array<Group<T>>, Record<string, T[]>>(context, group(input, _options, context), (result) => {
|
package/src/operations/label.ts
CHANGED
package/src/operations/patch.ts
CHANGED
|
@@ -650,6 +650,7 @@ Zai.prototype.patch = function (
|
|
|
650
650
|
taskId: this.taskId,
|
|
651
651
|
taskType: 'zai.patch',
|
|
652
652
|
adapter: this.adapter,
|
|
653
|
+
memoizer: this._resolveMemoizer(),
|
|
653
654
|
})
|
|
654
655
|
|
|
655
656
|
return new Response<Array<File>>(context, patch(files, instructions, _options, context), (result) => result)
|
package/src/operations/rate.ts
CHANGED
|
@@ -611,6 +611,7 @@ Zai.prototype.rate = function <T, I extends RatingInstructions>(
|
|
|
611
611
|
taskId: this.taskId,
|
|
612
612
|
taskType: 'zai.rate',
|
|
613
613
|
adapter: this.adapter,
|
|
614
|
+
memoizer: this._resolveMemoizer(),
|
|
614
615
|
})
|
|
615
616
|
|
|
616
617
|
return new Response<Array<RatingResult<I>>, Array<SimplifiedRatingResult<I>>>(
|
|
@@ -277,6 +277,7 @@ Zai.prototype.rewrite = function (this: Zai, original: string, prompt: string, _
|
|
|
277
277
|
taskId: this.taskId,
|
|
278
278
|
taskType: 'zai.rewrite',
|
|
279
279
|
adapter: this.adapter,
|
|
280
|
+
memoizer: this._resolveMemoizer(),
|
|
280
281
|
})
|
|
281
282
|
|
|
282
283
|
return new Response<string>(context, rewrite(original, prompt, _options, context), (result) => result)
|
package/src/operations/sort.ts
CHANGED
|
@@ -306,6 +306,7 @@ Zai.prototype.summarize = function (this: Zai, original, _options): Response<str
|
|
|
306
306
|
taskId: this.taskId,
|
|
307
307
|
taskType: 'summarize',
|
|
308
308
|
adapter: this.adapter,
|
|
309
|
+
memoizer: this._resolveMemoizer(),
|
|
309
310
|
})
|
|
310
311
|
|
|
311
312
|
return new Response<string, string>(context, summarize(original, options, context), (value) => value)
|
package/src/operations/text.ts
CHANGED
|
@@ -135,6 +135,7 @@ Zai.prototype.text = function (this: Zai, prompt: string, _options?: Options): R
|
|
|
135
135
|
taskId: this.taskId,
|
|
136
136
|
taskType: 'zai.text',
|
|
137
137
|
adapter: this.adapter,
|
|
138
|
+
memoizer: this._resolveMemoizer(),
|
|
138
139
|
})
|
|
139
140
|
|
|
140
141
|
return new Response<string>(context, text(prompt, _options, context), (result) => result)
|
package/src/zai.ts
CHANGED
|
@@ -8,6 +8,17 @@ import { Adapter } from './adapters/adapter'
|
|
|
8
8
|
import { TableAdapter } from './adapters/botpress-table'
|
|
9
9
|
import { MemoryAdapter } from './adapters/memory'
|
|
10
10
|
|
|
11
|
+
/**
|
|
12
|
+
* A memoizer that caches the result of async operations by a unique key.
|
|
13
|
+
*
|
|
14
|
+
* When used with the Botpress ADK workflow `step` function, this enables
|
|
15
|
+
* Zai operations to resume where they left off if a workflow is interrupted.
|
|
16
|
+
*
|
|
17
|
+
*/
|
|
18
|
+
export type Memoizer = {
|
|
19
|
+
run: <T>(id: string, fn: () => Promise<T>) => Promise<T>
|
|
20
|
+
}
|
|
21
|
+
|
|
11
22
|
/**
|
|
12
23
|
* Active learning configuration for improving AI operations over time.
|
|
13
24
|
*
|
|
@@ -86,6 +97,16 @@ type ZaiConfig = {
|
|
|
86
97
|
activeLearning?: ActiveLearning
|
|
87
98
|
/** Namespace for organizing tasks (default: 'zai') */
|
|
88
99
|
namespace?: string
|
|
100
|
+
/**
|
|
101
|
+
* Memoizer (or factory returning one) for caching cognitive call results.
|
|
102
|
+
*
|
|
103
|
+
* When provided, all LLM calls are wrapped in the memoizer, allowing results
|
|
104
|
+
* to be cached and replayed. This is useful for resuming workflow runs where
|
|
105
|
+
* Zai operations have already completed their cognitive calls.
|
|
106
|
+
*
|
|
107
|
+
* If a factory function is provided, it is called once per Zai operation invocation.
|
|
108
|
+
*/
|
|
109
|
+
memoize?: Memoizer | (() => Memoizer)
|
|
89
110
|
}
|
|
90
111
|
|
|
91
112
|
const _ZaiConfig = z.object({
|
|
@@ -195,6 +216,7 @@ export class Zai {
|
|
|
195
216
|
protected namespace: string
|
|
196
217
|
protected adapter: Adapter
|
|
197
218
|
protected activeLearning: ActiveLearning
|
|
219
|
+
protected _memoize?: Memoizer | (() => Memoizer)
|
|
198
220
|
|
|
199
221
|
/**
|
|
200
222
|
* Creates a new Zai instance with the specified configuration.
|
|
@@ -236,6 +258,8 @@ export class Zai {
|
|
|
236
258
|
tableName: parsed.activeLearning.tableName,
|
|
237
259
|
})
|
|
238
260
|
: new MemoryAdapter([])
|
|
261
|
+
|
|
262
|
+
this._memoize = config.memoize
|
|
239
263
|
}
|
|
240
264
|
|
|
241
265
|
/** @internal */
|
|
@@ -250,6 +274,14 @@ export class Zai {
|
|
|
250
274
|
})
|
|
251
275
|
}
|
|
252
276
|
|
|
277
|
+
/** @internal */
|
|
278
|
+
protected _resolveMemoizer(): Memoizer | undefined {
|
|
279
|
+
if (!this._memoize) {
|
|
280
|
+
return undefined
|
|
281
|
+
}
|
|
282
|
+
return typeof this._memoize === 'function' ? this._memoize() : this._memoize
|
|
283
|
+
}
|
|
284
|
+
|
|
253
285
|
protected async getTokenizer() {
|
|
254
286
|
Zai.tokenizer ??= await (async () => {
|
|
255
287
|
while (!getWasmTokenizer) {
|