@botpress/zai 2.5.17 → 2.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/context.js +14 -0
- package/dist/index.d.ts +44 -1
- package/dist/operations/answer.js +2 -1
- package/dist/operations/check.js +2 -1
- package/dist/operations/extract.js +2 -1
- package/dist/operations/filter.js +2 -1
- package/dist/operations/group.js +190 -2
- package/dist/operations/label.js +2 -1
- package/dist/operations/patch.js +2 -1
- package/dist/operations/rate.js +2 -1
- package/dist/operations/rewrite.js +2 -1
- package/dist/operations/sort.js +2 -1
- package/dist/operations/summarize.js +2 -1
- package/dist/operations/text.js +2 -1
- package/dist/zai.js +9 -0
- package/e2e/data/cache.jsonl +70 -0
- package/package.json +1 -1
- package/src/context.ts +21 -0
- package/src/index.ts +2 -1
- package/src/operations/answer.ts +1 -0
- package/src/operations/check.ts +1 -0
- package/src/operations/extract.ts +1 -0
- package/src/operations/filter.ts +1 -0
- package/src/operations/group.ts +278 -0
- package/src/operations/label.ts +1 -0
- package/src/operations/patch.ts +1 -0
- package/src/operations/rate.ts +1 -0
- package/src/operations/rewrite.ts +1 -0
- package/src/operations/sort.ts +1 -0
- package/src/operations/summarize.ts +1 -0
- package/src/operations/text.ts +1 -0
- package/src/zai.ts +32 -0
package/dist/context.js
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import { EventEmitter } from "./emitter";
|
|
2
|
+
import { fastHash } from "./utils";
|
|
2
3
|
export class ZaiContext {
|
|
3
4
|
_startedAt = Date.now();
|
|
4
5
|
_inputCost = 0;
|
|
@@ -15,8 +16,10 @@ export class ZaiContext {
|
|
|
15
16
|
adapter;
|
|
16
17
|
source;
|
|
17
18
|
_eventEmitter;
|
|
19
|
+
_memoizer;
|
|
18
20
|
controller = new AbortController();
|
|
19
21
|
_client;
|
|
22
|
+
static _noopMemoizer = { run: (_id, fn) => fn() };
|
|
20
23
|
constructor(props) {
|
|
21
24
|
this._client = props.client.clone();
|
|
22
25
|
this.taskId = props.taskId;
|
|
@@ -24,6 +27,7 @@ export class ZaiContext {
|
|
|
24
27
|
this.adapter = props.adapter;
|
|
25
28
|
this.source = props.source;
|
|
26
29
|
this.taskType = props.taskType;
|
|
30
|
+
this._memoizer = props.memoizer ?? ZaiContext._noopMemoizer;
|
|
27
31
|
this._eventEmitter = new EventEmitter();
|
|
28
32
|
this._client.on("request", () => {
|
|
29
33
|
this._totalRequests++;
|
|
@@ -57,6 +61,16 @@ export class ZaiContext {
|
|
|
57
61
|
this._eventEmitter.clear();
|
|
58
62
|
}
|
|
59
63
|
async generateContent(props) {
|
|
64
|
+
const memoKey = `zai:memo:${this.taskType}:${this.taskId || "default"}:${fastHash(
|
|
65
|
+
JSON.stringify({
|
|
66
|
+
s: props.systemPrompt,
|
|
67
|
+
m: props.messages?.map((m) => "content" in m ? m.content : ""),
|
|
68
|
+
st: props.stopSequences
|
|
69
|
+
})
|
|
70
|
+
)}`;
|
|
71
|
+
return this._memoizer.run(memoKey, () => this._generateContentInner(props));
|
|
72
|
+
}
|
|
73
|
+
async _generateContentInner(props) {
|
|
60
74
|
const maxRetries = Math.max(props.maxRetries ?? 3, 0);
|
|
61
75
|
const transform = props.transform;
|
|
62
76
|
let lastError = null;
|
package/dist/index.d.ts
CHANGED
|
@@ -41,6 +41,16 @@ declare abstract class Adapter {
|
|
|
41
41
|
abstract saveExample<TInput, TOutput>(props: SaveExampleProps<TInput, TOutput>): Promise<void>;
|
|
42
42
|
}
|
|
43
43
|
|
|
44
|
+
/**
|
|
45
|
+
* A memoizer that caches the result of async operations by a unique key.
|
|
46
|
+
*
|
|
47
|
+
* When used with the Botpress ADK workflow `step` function, this enables
|
|
48
|
+
* Zai operations to resume where they left off if a workflow is interrupted.
|
|
49
|
+
*
|
|
50
|
+
*/
|
|
51
|
+
type Memoizer = {
|
|
52
|
+
run: <T>(id: string, fn: () => Promise<T>) => Promise<T>;
|
|
53
|
+
};
|
|
44
54
|
/**
|
|
45
55
|
* Active learning configuration for improving AI operations over time.
|
|
46
56
|
*
|
|
@@ -98,6 +108,16 @@ type ZaiConfig = {
|
|
|
98
108
|
activeLearning?: ActiveLearning;
|
|
99
109
|
/** Namespace for organizing tasks (default: 'zai') */
|
|
100
110
|
namespace?: string;
|
|
111
|
+
/**
|
|
112
|
+
* Memoizer (or factory returning one) for caching cognitive call results.
|
|
113
|
+
*
|
|
114
|
+
* When provided, all LLM calls are wrapped in the memoizer, allowing results
|
|
115
|
+
* to be cached and replayed. This is useful for resuming workflow runs where
|
|
116
|
+
* Zai operations have already completed their cognitive calls.
|
|
117
|
+
*
|
|
118
|
+
* If a factory function is provided, it is called once per Zai operation invocation.
|
|
119
|
+
*/
|
|
120
|
+
memoize?: Memoizer | (() => Memoizer);
|
|
101
121
|
};
|
|
102
122
|
/**
|
|
103
123
|
* Zai - A type-safe LLM utility library for production-ready AI operations.
|
|
@@ -171,6 +191,7 @@ declare class Zai {
|
|
|
171
191
|
protected namespace: string;
|
|
172
192
|
protected adapter: Adapter;
|
|
173
193
|
protected activeLearning: ActiveLearning;
|
|
194
|
+
protected _memoize?: Memoizer | (() => Memoizer);
|
|
174
195
|
/**
|
|
175
196
|
* Creates a new Zai instance with the specified configuration.
|
|
176
197
|
*
|
|
@@ -195,6 +216,8 @@ declare class Zai {
|
|
|
195
216
|
constructor(config: ZaiConfig);
|
|
196
217
|
/** @internal */
|
|
197
218
|
protected callModel(props: Parameters<Cognitive['generateContent']>[0]): ReturnType<Cognitive['generateContent']>;
|
|
219
|
+
/** @internal */
|
|
220
|
+
protected _resolveMemoizer(): Memoizer | undefined;
|
|
198
221
|
protected getTokenizer(): Promise<TextTokenizer>;
|
|
199
222
|
protected fetchModelDetails(): Promise<void>;
|
|
200
223
|
protected get taskId(): string;
|
|
@@ -299,6 +322,7 @@ type ZaiContextProps = {
|
|
|
299
322
|
modelId: string;
|
|
300
323
|
adapter?: Adapter;
|
|
301
324
|
source?: GenerateContentInput['meta'];
|
|
325
|
+
memoizer?: Memoizer;
|
|
302
326
|
};
|
|
303
327
|
/**
|
|
304
328
|
* Usage statistics tracking tokens, cost, and request metrics for an operation.
|
|
@@ -370,8 +394,10 @@ declare class ZaiContext {
|
|
|
370
394
|
adapter?: Adapter;
|
|
371
395
|
source?: GenerateContentInput['meta'];
|
|
372
396
|
private _eventEmitter;
|
|
397
|
+
private _memoizer;
|
|
373
398
|
controller: AbortController;
|
|
374
399
|
private _client;
|
|
400
|
+
private static _noopMemoizer;
|
|
375
401
|
constructor(props: ZaiContextProps);
|
|
376
402
|
getModel(): Promise<Model>;
|
|
377
403
|
on<K extends keyof ContextEvents>(type: K, listener: (event: ContextEvents[K]) => void): this;
|
|
@@ -382,6 +408,7 @@ declare class ZaiContext {
|
|
|
382
408
|
text: string | undefined;
|
|
383
409
|
extracted: Out;
|
|
384
410
|
}>;
|
|
411
|
+
private _generateContentInner;
|
|
385
412
|
get elapsedTime(): number;
|
|
386
413
|
get usage(): Usage;
|
|
387
414
|
}
|
|
@@ -1390,6 +1417,8 @@ type Options$4 = {
|
|
|
1390
1417
|
tokensPerElement?: number;
|
|
1391
1418
|
chunkLength?: number;
|
|
1392
1419
|
initialGroups?: Array<InitialGroup>;
|
|
1420
|
+
maxGroups?: number;
|
|
1421
|
+
minElements?: number;
|
|
1393
1422
|
};
|
|
1394
1423
|
declare module '@botpress/zai' {
|
|
1395
1424
|
interface Zai {
|
|
@@ -1402,6 +1431,8 @@ declare module '@botpress/zai' {
|
|
|
1402
1431
|
*
|
|
1403
1432
|
* @param input - Array of items to group
|
|
1404
1433
|
* @param options - Configuration for grouping behavior, instructions, and initial categories
|
|
1434
|
+
* @param options.maxGroups - Maximum number of groups allowed (minimum 2). When set, groups are merged at the end until within limit.
|
|
1435
|
+
* @param options.minElements - Minimum elements per group (minimum 1). Groups below this threshold have their elements redistributed via AI.
|
|
1405
1436
|
* @returns Response with groups array (simplified to Record<groupLabel, items[]>)
|
|
1406
1437
|
*
|
|
1407
1438
|
* @example Automatic grouping
|
|
@@ -1543,6 +1574,18 @@ declare module '@botpress/zai' {
|
|
|
1543
1574
|
* })
|
|
1544
1575
|
* ```
|
|
1545
1576
|
*/
|
|
1577
|
+
/**
|
|
1578
|
+
* @example Limiting number of groups
|
|
1579
|
+
* ```typescript
|
|
1580
|
+
* const items = ['apple', 'banana', 'carrot', 'chicken', 'rice', 'bread', 'salmon', 'milk']
|
|
1581
|
+
*
|
|
1582
|
+
* const groups = await zai.group(items, {
|
|
1583
|
+
* instructions: 'Group by food type',
|
|
1584
|
+
* maxGroups: 3 // At most 3 groups — smallest groups get merged if exceeded
|
|
1585
|
+
* })
|
|
1586
|
+
* // Guarantees no more than 3 groups in the result
|
|
1587
|
+
* ```
|
|
1588
|
+
*/
|
|
1546
1589
|
group<T>(input: Array<T>, options?: Options$4): Response<Array<Group<T>>, Record<string, T[]>>;
|
|
1547
1590
|
}
|
|
1548
1591
|
}
|
|
@@ -2127,4 +2170,4 @@ declare module '@botpress/zai' {
|
|
|
2127
2170
|
}
|
|
2128
2171
|
}
|
|
2129
2172
|
|
|
2130
|
-
export { Zai };
|
|
2173
|
+
export { type Memoizer, Zai };
|
|
@@ -373,7 +373,8 @@ Zai.prototype.answer = function(documents, question, _options) {
|
|
|
373
373
|
modelId: this.Model,
|
|
374
374
|
taskId: this.taskId,
|
|
375
375
|
taskType: "zai.answer",
|
|
376
|
-
adapter: this.adapter
|
|
376
|
+
adapter: this.adapter,
|
|
377
|
+
memoizer: this._resolveMemoizer()
|
|
377
378
|
});
|
|
378
379
|
return new Response(
|
|
379
380
|
context,
|
package/dist/operations/check.js
CHANGED
|
@@ -181,7 +181,8 @@ Zai.prototype.check = function(input, condition, _options) {
|
|
|
181
181
|
modelId: this.Model,
|
|
182
182
|
taskId: this.taskId,
|
|
183
183
|
taskType: "zai.check",
|
|
184
|
-
adapter: this.adapter
|
|
184
|
+
adapter: this.adapter,
|
|
185
|
+
memoizer: this._resolveMemoizer()
|
|
185
186
|
});
|
|
186
187
|
return new Response(context, check(input, condition, options, context), (result) => result.value);
|
|
187
188
|
};
|
|
@@ -313,7 +313,8 @@ Zai.prototype.extract = function(input, schema, _options) {
|
|
|
313
313
|
modelId: this.Model,
|
|
314
314
|
taskId: this.taskId,
|
|
315
315
|
taskType: "zai.extract",
|
|
316
|
-
adapter: this.adapter
|
|
316
|
+
adapter: this.adapter,
|
|
317
|
+
memoizer: this._resolveMemoizer()
|
|
317
318
|
});
|
|
318
319
|
return new Response(context, extract(input, schema, _options, context), (result) => result);
|
|
319
320
|
};
|
|
@@ -202,7 +202,8 @@ Zai.prototype.filter = function(input, condition, _options) {
|
|
|
202
202
|
modelId: this.Model,
|
|
203
203
|
taskId: this.taskId,
|
|
204
204
|
taskType: "zai.filter",
|
|
205
|
-
adapter: this.adapter
|
|
205
|
+
adapter: this.adapter,
|
|
206
|
+
memoizer: this._resolveMemoizer()
|
|
206
207
|
});
|
|
207
208
|
return new Response(context, filter(input, condition, _options, context), (result) => result);
|
|
208
209
|
};
|
package/dist/operations/group.js
CHANGED
|
@@ -16,7 +16,9 @@ const _Options = z.object({
|
|
|
16
16
|
instructions: z.string().optional(),
|
|
17
17
|
tokensPerElement: z.number().min(1).max(1e5).optional().default(250),
|
|
18
18
|
chunkLength: z.number().min(100).max(1e5).optional().default(16e3),
|
|
19
|
-
initialGroups: z.array(_InitialGroup).optional().default([])
|
|
19
|
+
initialGroups: z.array(_InitialGroup).optional().default([]),
|
|
20
|
+
maxGroups: z.number().min(2).optional(),
|
|
21
|
+
minElements: z.number().min(1).optional()
|
|
20
22
|
});
|
|
21
23
|
const END = "\u25A0END\u25A0";
|
|
22
24
|
const normalizeLabel = (label) => {
|
|
@@ -301,6 +303,191 @@ ${END}`.trim();
|
|
|
301
303
|
groupElements.get(finalGroupId).add(elementIndex);
|
|
302
304
|
}
|
|
303
305
|
}
|
|
306
|
+
if (options.maxGroups !== void 0) {
|
|
307
|
+
const nonEmptyGroupIds = () => Array.from(groupElements.entries()).filter(([, s]) => s.size > 0).map(([id]) => id);
|
|
308
|
+
let currentIds = nonEmptyGroupIds();
|
|
309
|
+
if (currentIds.length > options.maxGroups) {
|
|
310
|
+
const groupSummaries = currentIds.map((gid, idx) => {
|
|
311
|
+
const info = groups.get(gid);
|
|
312
|
+
const elemIndices = Array.from(groupElements.get(gid));
|
|
313
|
+
const sampleElements = elemIndices.slice(0, 3).map((i) => tokenizer.truncate(elements[i].stringified, 60)).join(", ");
|
|
314
|
+
return `\u25A0${idx}:${info.label} (${elemIndices.length} elements, e.g. ${sampleElements})\u25A0`;
|
|
315
|
+
});
|
|
316
|
+
const mergeSystemPrompt = `You are consolidating groups into fewer, broader categories.
|
|
317
|
+
|
|
318
|
+
${options.instructions ? `**Original instructions:** ${options.instructions}
|
|
319
|
+
` : ""}
|
|
320
|
+
**Task:** Merge ${currentIds.length} groups down to at most ${options.maxGroups} groups.
|
|
321
|
+
Combine the most semantically related groups together. Give each merged group a new descriptive label.
|
|
322
|
+
|
|
323
|
+
**Output Format:**
|
|
324
|
+
For each input group (\u25A00 to \u25A0${currentIds.length - 1}), output which target label it maps to:
|
|
325
|
+
\u25A00:Merged Label\u25A0
|
|
326
|
+
\u25A01:Merged Label\u25A0
|
|
327
|
+
${END}
|
|
328
|
+
|
|
329
|
+
Use the EXACT SAME label for groups that should be merged together.`.trim();
|
|
330
|
+
const mergeUserPrompt = `**Current groups:**
|
|
331
|
+
${groupSummaries.join("\n")}
|
|
332
|
+
|
|
333
|
+
Merge into at most ${options.maxGroups} groups.
|
|
334
|
+
${END}`.trim();
|
|
335
|
+
const { extracted: mergeAssignments } = await ctx.generateContent({
|
|
336
|
+
systemPrompt: mergeSystemPrompt,
|
|
337
|
+
stopSequences: [END],
|
|
338
|
+
messages: [{ type: "text", role: "user", content: mergeUserPrompt }],
|
|
339
|
+
transform: (text) => {
|
|
340
|
+
const assignments = [];
|
|
341
|
+
const regex = /■(\d+):([^■]+)■/g;
|
|
342
|
+
let match;
|
|
343
|
+
while ((match = regex.exec(text)) !== null) {
|
|
344
|
+
const idx = parseInt(match[1] ?? "", 10);
|
|
345
|
+
if (isNaN(idx) || idx < 0 || idx >= currentIds.length) continue;
|
|
346
|
+
const label = (match[2] ?? "").trim();
|
|
347
|
+
if (!label) continue;
|
|
348
|
+
assignments.push({ sourceIdx: idx, label: label.slice(0, 250) });
|
|
349
|
+
}
|
|
350
|
+
return assignments;
|
|
351
|
+
}
|
|
352
|
+
});
|
|
353
|
+
const mergeMap = /* @__PURE__ */ new Map();
|
|
354
|
+
for (const { sourceIdx, label } of mergeAssignments) {
|
|
355
|
+
const sourceGid = currentIds[sourceIdx];
|
|
356
|
+
if (!sourceGid) continue;
|
|
357
|
+
const normalized = normalizeLabel(label);
|
|
358
|
+
if (!mergeMap.has(normalized)) {
|
|
359
|
+
mergeMap.set(normalized, { label, sourceGroupIds: [] });
|
|
360
|
+
}
|
|
361
|
+
mergeMap.get(normalized).sourceGroupIds.push(sourceGid);
|
|
362
|
+
}
|
|
363
|
+
for (const [, { label, sourceGroupIds }] of mergeMap) {
|
|
364
|
+
if (sourceGroupIds.length <= 1) continue;
|
|
365
|
+
const targetGid = sourceGroupIds[0];
|
|
366
|
+
const targetSet = groupElements.get(targetGid);
|
|
367
|
+
const targetInfo = groups.get(targetGid);
|
|
368
|
+
targetInfo.label = label;
|
|
369
|
+
targetInfo.normalizedLabel = normalizeLabel(label);
|
|
370
|
+
for (let i = 1; i < sourceGroupIds.length; i++) {
|
|
371
|
+
const sourceGid = sourceGroupIds[i];
|
|
372
|
+
const sourceSet = groupElements.get(sourceGid);
|
|
373
|
+
sourceSet.forEach((elemIdx) => targetSet.add(elemIdx));
|
|
374
|
+
sourceSet.clear();
|
|
375
|
+
}
|
|
376
|
+
}
|
|
377
|
+
currentIds = nonEmptyGroupIds();
|
|
378
|
+
while (currentIds.length > options.maxGroups) {
|
|
379
|
+
currentIds.sort((a, b) => groupElements.get(a).size - groupElements.get(b).size);
|
|
380
|
+
const sourceSet = groupElements.get(currentIds[0]);
|
|
381
|
+
const targetSet = groupElements.get(currentIds[1]);
|
|
382
|
+
for (const elemIdx of sourceSet) {
|
|
383
|
+
targetSet.add(elemIdx);
|
|
384
|
+
}
|
|
385
|
+
sourceSet.clear();
|
|
386
|
+
currentIds = nonEmptyGroupIds();
|
|
387
|
+
}
|
|
388
|
+
}
|
|
389
|
+
}
|
|
390
|
+
if (options.minElements !== void 0 && options.minElements > 1) {
|
|
391
|
+
const getNonEmptyGroupIds = () => Array.from(groupElements.entries()).filter(([, s]) => s.size > 0).map(([id]) => id);
|
|
392
|
+
const orphanIndices = [];
|
|
393
|
+
for (const gid of getNonEmptyGroupIds()) {
|
|
394
|
+
const elemSet = groupElements.get(gid);
|
|
395
|
+
if (elemSet.size > 0 && elemSet.size < options.minElements) {
|
|
396
|
+
for (const idx of elemSet) {
|
|
397
|
+
orphanIndices.push(idx);
|
|
398
|
+
}
|
|
399
|
+
elemSet.clear();
|
|
400
|
+
}
|
|
401
|
+
}
|
|
402
|
+
if (orphanIndices.length > 0) {
|
|
403
|
+
const validGroupIds = getNonEmptyGroupIds();
|
|
404
|
+
const orphanChunks = [];
|
|
405
|
+
let currentOrphanChunk = [];
|
|
406
|
+
let currentOrphanTokens = 0;
|
|
407
|
+
for (const elemIdx of orphanIndices) {
|
|
408
|
+
const elem = elements[elemIdx];
|
|
409
|
+
const truncated = tokenizer.truncate(elem.stringified, options.tokensPerElement);
|
|
410
|
+
const elemTokens = tokenizer.count(truncated);
|
|
411
|
+
if ((currentOrphanTokens + elemTokens > TOKENS_FOR_ELEMENTS_MAX || currentOrphanChunk.length >= MAX_ELEMENTS_PER_CHUNK) && currentOrphanChunk.length > 0) {
|
|
412
|
+
orphanChunks.push(currentOrphanChunk);
|
|
413
|
+
currentOrphanChunk = [];
|
|
414
|
+
currentOrphanTokens = 0;
|
|
415
|
+
}
|
|
416
|
+
currentOrphanChunk.push(elemIdx);
|
|
417
|
+
currentOrphanTokens += elemTokens;
|
|
418
|
+
}
|
|
419
|
+
if (currentOrphanChunk.length > 0) {
|
|
420
|
+
orphanChunks.push(currentOrphanChunk);
|
|
421
|
+
}
|
|
422
|
+
const orphanResults = await Promise.all(
|
|
423
|
+
orphanChunks.map(
|
|
424
|
+
(chunk) => elementLimit(async () => {
|
|
425
|
+
const groupChunksForOrphans = validGroupIds.length > 0 ? getGroupChunks() : [[]];
|
|
426
|
+
const allAssignments = await Promise.all(
|
|
427
|
+
groupChunksForOrphans.filter((gc) => gc.length === 0 || gc.some((gid) => validGroupIds.includes(gid))).map((groupChunk) => {
|
|
428
|
+
const filteredGroupChunk = groupChunk.filter((gid) => validGroupIds.includes(gid));
|
|
429
|
+
return groupLimit(() => processChunk(chunk, filteredGroupChunk));
|
|
430
|
+
})
|
|
431
|
+
);
|
|
432
|
+
return allAssignments.flat();
|
|
433
|
+
})
|
|
434
|
+
)
|
|
435
|
+
);
|
|
436
|
+
const flatAssignments = orphanResults.flat();
|
|
437
|
+
for (const { elementIndex, label } of flatAssignments) {
|
|
438
|
+
const normalized = normalizeLabel(label);
|
|
439
|
+
let groupId = labelToGroupId.get(normalized);
|
|
440
|
+
if (!groupId) {
|
|
441
|
+
groupId = `group_${groupIdCounter++}`;
|
|
442
|
+
groups.set(groupId, { id: groupId, label, normalizedLabel: normalized });
|
|
443
|
+
groupElements.set(groupId, /* @__PURE__ */ new Set());
|
|
444
|
+
labelToGroupId.set(normalized, groupId);
|
|
445
|
+
}
|
|
446
|
+
groupElements.get(groupId).add(elementIndex);
|
|
447
|
+
}
|
|
448
|
+
const isAssigned = (idx) => {
|
|
449
|
+
for (const [, elemSet] of groupElements) {
|
|
450
|
+
if (elemSet.has(idx)) return true;
|
|
451
|
+
}
|
|
452
|
+
return false;
|
|
453
|
+
};
|
|
454
|
+
const unassigned = orphanIndices.filter((idx) => !isAssigned(idx));
|
|
455
|
+
const placeIntoLargest = (indices) => {
|
|
456
|
+
const allNonEmpty = getNonEmptyGroupIds();
|
|
457
|
+
if (allNonEmpty.length === 0) return;
|
|
458
|
+
const largestGid = allNonEmpty.reduce(
|
|
459
|
+
(a, b) => groupElements.get(a).size >= groupElements.get(b).size ? a : b
|
|
460
|
+
);
|
|
461
|
+
for (const idx of indices) {
|
|
462
|
+
groupElements.get(largestGid).add(idx);
|
|
463
|
+
}
|
|
464
|
+
};
|
|
465
|
+
if (unassigned.length > 0) {
|
|
466
|
+
placeIntoLargest(unassigned);
|
|
467
|
+
}
|
|
468
|
+
const mergeUndersizedGroups = () => {
|
|
469
|
+
const allNonEmpty = getNonEmptyGroupIds();
|
|
470
|
+
if (allNonEmpty.length <= 1) return false;
|
|
471
|
+
const largestGid = allNonEmpty.reduce(
|
|
472
|
+
(a, b) => groupElements.get(a).size >= groupElements.get(b).size ? a : b
|
|
473
|
+
);
|
|
474
|
+
const targetSet = groupElements.get(largestGid);
|
|
475
|
+
let merged = false;
|
|
476
|
+
for (const gid of allNonEmpty) {
|
|
477
|
+
if (gid === largestGid) continue;
|
|
478
|
+
const elemSet = groupElements.get(gid);
|
|
479
|
+
if (elemSet.size > 0 && elemSet.size < options.minElements) {
|
|
480
|
+
elemSet.forEach((idx) => targetSet.add(idx));
|
|
481
|
+
elemSet.clear();
|
|
482
|
+
merged = true;
|
|
483
|
+
}
|
|
484
|
+
}
|
|
485
|
+
return merged;
|
|
486
|
+
};
|
|
487
|
+
while (mergeUndersizedGroups()) {
|
|
488
|
+
}
|
|
489
|
+
}
|
|
490
|
+
}
|
|
304
491
|
const result = [];
|
|
305
492
|
for (const [groupId, elementIndices] of groupElements.entries()) {
|
|
306
493
|
if (elementIndices.size > 0) {
|
|
@@ -354,7 +541,8 @@ Zai.prototype.group = function(input, _options) {
|
|
|
354
541
|
modelId: this.Model,
|
|
355
542
|
taskId: this.taskId,
|
|
356
543
|
taskType: "zai.group",
|
|
357
|
-
adapter: this.adapter
|
|
544
|
+
adapter: this.adapter,
|
|
545
|
+
memoizer: this._resolveMemoizer()
|
|
358
546
|
});
|
|
359
547
|
return new Response(context, group(input, _options, context), (result) => {
|
|
360
548
|
const merged = {};
|
package/dist/operations/label.js
CHANGED
|
@@ -276,7 +276,8 @@ Zai.prototype.label = function(input, labels, _options) {
|
|
|
276
276
|
modelId: this.Model,
|
|
277
277
|
taskId: this.taskId,
|
|
278
278
|
taskType: "zai.label",
|
|
279
|
-
adapter: this.adapter
|
|
279
|
+
adapter: this.adapter,
|
|
280
|
+
memoizer: this._resolveMemoizer()
|
|
280
281
|
});
|
|
281
282
|
return new Response(
|
|
282
283
|
context,
|
package/dist/operations/patch.js
CHANGED
|
@@ -392,7 +392,8 @@ Zai.prototype.patch = function(files, instructions, _options) {
|
|
|
392
392
|
modelId: this.Model,
|
|
393
393
|
taskId: this.taskId,
|
|
394
394
|
taskType: "zai.patch",
|
|
395
|
-
adapter: this.adapter
|
|
395
|
+
adapter: this.adapter,
|
|
396
|
+
memoizer: this._resolveMemoizer()
|
|
396
397
|
});
|
|
397
398
|
return new Response(context, patch(files, instructions, _options, context), (result) => result);
|
|
398
399
|
};
|
package/dist/operations/rate.js
CHANGED
|
@@ -335,7 +335,8 @@ Zai.prototype.rate = function(input, instructions, _options) {
|
|
|
335
335
|
modelId: this.Model,
|
|
336
336
|
taskId: this.taskId,
|
|
337
337
|
taskType: "zai.rate",
|
|
338
|
-
adapter: this.adapter
|
|
338
|
+
adapter: this.adapter,
|
|
339
|
+
memoizer: this._resolveMemoizer()
|
|
339
340
|
});
|
|
340
341
|
return new Response(
|
|
341
342
|
context,
|
|
@@ -136,7 +136,8 @@ Zai.prototype.rewrite = function(original, prompt, _options) {
|
|
|
136
136
|
modelId: this.Model,
|
|
137
137
|
taskId: this.taskId,
|
|
138
138
|
taskType: "zai.rewrite",
|
|
139
|
-
adapter: this.adapter
|
|
139
|
+
adapter: this.adapter,
|
|
140
|
+
memoizer: this._resolveMemoizer()
|
|
140
141
|
});
|
|
141
142
|
return new Response(context, rewrite(original, prompt, _options, context), (result) => result);
|
|
142
143
|
};
|
package/dist/operations/sort.js
CHANGED
|
@@ -511,7 +511,8 @@ Zai.prototype.sort = function(input, instructions, _options) {
|
|
|
511
511
|
modelId: this.Model,
|
|
512
512
|
taskId: this.taskId,
|
|
513
513
|
taskType: "zai.sort",
|
|
514
|
-
adapter: this.adapter
|
|
514
|
+
adapter: this.adapter,
|
|
515
|
+
memoizer: this._resolveMemoizer()
|
|
515
516
|
});
|
|
516
517
|
return new Response(
|
|
517
518
|
context,
|
|
@@ -148,7 +148,8 @@ Zai.prototype.summarize = function(original, _options) {
|
|
|
148
148
|
modelId: this.Model,
|
|
149
149
|
taskId: this.taskId,
|
|
150
150
|
taskType: "summarize",
|
|
151
|
-
adapter: this.adapter
|
|
151
|
+
adapter: this.adapter,
|
|
152
|
+
memoizer: this._resolveMemoizer()
|
|
152
153
|
});
|
|
153
154
|
return new Response(context, summarize(original, options, context), (value) => value);
|
|
154
155
|
};
|
package/dist/operations/text.js
CHANGED
|
@@ -60,7 +60,8 @@ Zai.prototype.text = function(prompt, _options) {
|
|
|
60
60
|
modelId: this.Model,
|
|
61
61
|
taskId: this.taskId,
|
|
62
62
|
taskType: "zai.text",
|
|
63
|
-
adapter: this.adapter
|
|
63
|
+
adapter: this.adapter,
|
|
64
|
+
memoizer: this._resolveMemoizer()
|
|
64
65
|
});
|
|
65
66
|
return new Response(context, text(prompt, _options, context), (result) => result);
|
|
66
67
|
};
|
package/dist/zai.js
CHANGED
|
@@ -47,6 +47,7 @@ export class Zai {
|
|
|
47
47
|
namespace;
|
|
48
48
|
adapter;
|
|
49
49
|
activeLearning;
|
|
50
|
+
_memoize;
|
|
50
51
|
/**
|
|
51
52
|
* Creates a new Zai instance with the specified configuration.
|
|
52
53
|
*
|
|
@@ -80,6 +81,7 @@ export class Zai {
|
|
|
80
81
|
client: this.client.client,
|
|
81
82
|
tableName: parsed.activeLearning.tableName
|
|
82
83
|
}) : new MemoryAdapter([]);
|
|
84
|
+
this._memoize = config.memoize;
|
|
83
85
|
}
|
|
84
86
|
/** @internal */
|
|
85
87
|
async callModel(props) {
|
|
@@ -90,6 +92,13 @@ export class Zai {
|
|
|
90
92
|
userId: this._userId
|
|
91
93
|
});
|
|
92
94
|
}
|
|
95
|
+
/** @internal */
|
|
96
|
+
_resolveMemoizer() {
|
|
97
|
+
if (!this._memoize) {
|
|
98
|
+
return void 0;
|
|
99
|
+
}
|
|
100
|
+
return typeof this._memoize === "function" ? this._memoize() : this._memoize;
|
|
101
|
+
}
|
|
93
102
|
async getTokenizer() {
|
|
94
103
|
Zai.tokenizer ??= await (async () => {
|
|
95
104
|
while (!getWasmTokenizer) {
|