@ax-llm/ax 12.0.4 → 12.0.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +301 -294
- package/index.cjs.map +1 -1
- package/index.d.cts +125 -68
- package/index.d.ts +125 -68
- package/index.js +301 -294
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.cjs
CHANGED
|
@@ -3500,7 +3500,7 @@ var AxAIGoogleGeminiImpl = class {
|
|
|
3500
3500
|
}
|
|
3501
3501
|
];
|
|
3502
3502
|
return {
|
|
3503
|
-
role: "
|
|
3503
|
+
role: "user",
|
|
3504
3504
|
parts
|
|
3505
3505
|
};
|
|
3506
3506
|
}
|
|
@@ -4223,8 +4223,11 @@ var AxAIMistral = class extends AxAIOpenAIBase {
|
|
|
4223
4223
|
for (const message of messages) {
|
|
4224
4224
|
if (message.role === "user" && Array.isArray(message.content)) {
|
|
4225
4225
|
const contentUpdated = message.content.map((item) => {
|
|
4226
|
-
if (typeof item === "object" && item !== null &&
|
|
4227
|
-
return {
|
|
4226
|
+
if (typeof item === "object" && item !== null && "image_url" in item) {
|
|
4227
|
+
return {
|
|
4228
|
+
type: "image_url",
|
|
4229
|
+
image_url: { url: item.image_url?.url }
|
|
4230
|
+
};
|
|
4228
4231
|
}
|
|
4229
4232
|
return item;
|
|
4230
4233
|
});
|
|
@@ -4897,8 +4900,6 @@ var AxAIOpenAIResponsesImpl = class {
|
|
|
4897
4900
|
baseResult.content = event.delta;
|
|
4898
4901
|
break;
|
|
4899
4902
|
case "response.output_text.done":
|
|
4900
|
-
baseResult.id = event.item_id;
|
|
4901
|
-
baseResult.content = event.text;
|
|
4902
4903
|
break;
|
|
4903
4904
|
case "response.function_call_arguments.delta":
|
|
4904
4905
|
baseResult.id = event.item_id;
|
|
@@ -5717,9 +5718,6 @@ var MemoryImpl = class {
|
|
|
5717
5718
|
functionCalls
|
|
5718
5719
|
}) {
|
|
5719
5720
|
const isContentEmpty = typeof content === "string" && content.trim() === "";
|
|
5720
|
-
if (isContentEmpty && (!functionCalls || functionCalls.length === 0)) {
|
|
5721
|
-
return;
|
|
5722
|
-
}
|
|
5723
5721
|
if (isContentEmpty) {
|
|
5724
5722
|
this.addMemory({ name, role: "assistant", functionCalls });
|
|
5725
5723
|
} else {
|
|
@@ -5744,17 +5742,16 @@ var MemoryImpl = class {
|
|
|
5744
5742
|
}) {
|
|
5745
5743
|
const lastItem = this.data.at(-1);
|
|
5746
5744
|
if (!lastItem || lastItem.chat.role !== "assistant") {
|
|
5747
|
-
|
|
5748
|
-
}
|
|
5749
|
-
|
|
5750
|
-
|
|
5751
|
-
|
|
5752
|
-
|
|
5753
|
-
|
|
5754
|
-
|
|
5755
|
-
|
|
5756
|
-
|
|
5757
|
-
}
|
|
5745
|
+
throw new Error("No assistant message to update");
|
|
5746
|
+
}
|
|
5747
|
+
if (typeof content === "string" && content.trim() !== "") {
|
|
5748
|
+
lastItem.chat.content = content;
|
|
5749
|
+
}
|
|
5750
|
+
if (name && name.trim() !== "") {
|
|
5751
|
+
lastItem.chat.name = name;
|
|
5752
|
+
}
|
|
5753
|
+
if (functionCalls && functionCalls.length > 0) {
|
|
5754
|
+
lastItem.chat.functionCalls = functionCalls;
|
|
5758
5755
|
}
|
|
5759
5756
|
if (this.options?.debug) {
|
|
5760
5757
|
if (delta && typeof delta === "string") {
|
|
@@ -6284,6 +6281,12 @@ ${outputFields}`);
|
|
|
6284
6281
|
text: task.join("\n\n")
|
|
6285
6282
|
};
|
|
6286
6283
|
}
|
|
6284
|
+
renderSingleValueUserContent = (values, renderedExamples, renderedDemos, examplesInSystemPrompt) => {
|
|
6285
|
+
const completion = this.renderInputFields(values);
|
|
6286
|
+
const promptList = examplesInSystemPrompt ? completion : [...renderedExamples, ...renderedDemos, ...completion];
|
|
6287
|
+
const prompt = promptList.filter((v) => v !== void 0);
|
|
6288
|
+
return prompt.every((v) => v.type === "text") ? prompt.map((v) => v.text).join("\n") : prompt.reduce(combineConsecutiveStrings("\n"), []);
|
|
6289
|
+
};
|
|
6287
6290
|
render = (values, {
|
|
6288
6291
|
examples,
|
|
6289
6292
|
demos
|
|
@@ -6312,60 +6315,49 @@ ${outputFields}`);
|
|
|
6312
6315
|
role: "system",
|
|
6313
6316
|
content: systemContent
|
|
6314
6317
|
};
|
|
6315
|
-
let userMessages = [];
|
|
6316
6318
|
if (Array.isArray(values)) {
|
|
6319
|
+
let userMessages = [];
|
|
6317
6320
|
const history = values;
|
|
6318
|
-
|
|
6319
|
-
|
|
6320
|
-
|
|
6321
|
-
|
|
6322
|
-
|
|
6323
|
-
|
|
6324
|
-
|
|
6321
|
+
for (const [index, message] of history.entries()) {
|
|
6322
|
+
let content;
|
|
6323
|
+
if (index === 0) {
|
|
6324
|
+
content = this.renderSingleValueUserContent(
|
|
6325
|
+
message.values,
|
|
6326
|
+
renderedExamples,
|
|
6327
|
+
renderedDemos,
|
|
6328
|
+
examplesInSystemPrompt
|
|
6325
6329
|
);
|
|
6326
|
-
|
|
6327
|
-
|
|
6328
|
-
|
|
6329
|
-
|
|
6330
|
-
|
|
6330
|
+
} else {
|
|
6331
|
+
content = this.renderSingleValueUserContent(
|
|
6332
|
+
message.values,
|
|
6333
|
+
[],
|
|
6334
|
+
[],
|
|
6335
|
+
false
|
|
6331
6336
|
);
|
|
6332
|
-
messageContent = assistantMsgParts.map((part) => part.type === "text" ? part.text : "").join("").trim();
|
|
6333
6337
|
}
|
|
6334
|
-
if (
|
|
6335
|
-
|
|
6336
|
-
|
|
6337
|
-
if (lastMessage) {
|
|
6338
|
-
lastMessage.content += "\n" + messageContent;
|
|
6339
|
-
}
|
|
6340
|
-
} else {
|
|
6341
|
-
if (message.role === "user") {
|
|
6342
|
-
userMessages.push({ role: "user", content: messageContent });
|
|
6343
|
-
} else if (message.role === "assistant") {
|
|
6344
|
-
userMessages.push({ role: "assistant", content: messageContent });
|
|
6345
|
-
}
|
|
6346
|
-
}
|
|
6347
|
-
lastRole = message.role;
|
|
6338
|
+
if (message.role === "user") {
|
|
6339
|
+
userMessages.push({ role: "user", content });
|
|
6340
|
+
continue;
|
|
6348
6341
|
}
|
|
6342
|
+
if (message.role !== "assistant") {
|
|
6343
|
+
throw new Error("Invalid message role");
|
|
6344
|
+
}
|
|
6345
|
+
if (typeof content !== "string") {
|
|
6346
|
+
throw new Error(
|
|
6347
|
+
"Assistant message cannot contain non-text content like images, files,etc"
|
|
6348
|
+
);
|
|
6349
|
+
}
|
|
6350
|
+
userMessages.push({ role: "assistant", content });
|
|
6349
6351
|
}
|
|
6350
|
-
|
|
6351
|
-
const currentValues = values;
|
|
6352
|
-
const completion = this.renderInputFields(currentValues);
|
|
6353
|
-
const promptList = examplesInSystemPrompt ? completion : [...renderedExamples, ...renderedDemos, ...completion];
|
|
6354
|
-
const promptFilter = promptList.filter((v) => v !== void 0);
|
|
6355
|
-
let userContent;
|
|
6356
|
-
if (promptFilter.every((v) => v.type === "text")) {
|
|
6357
|
-
userContent = promptFilter.map((v) => v.text).join("\n");
|
|
6358
|
-
} else {
|
|
6359
|
-
userContent = promptFilter.map((part) => {
|
|
6360
|
-
if (part.type === "text") return part.text;
|
|
6361
|
-
if (part.type === "image") return "[IMAGE]";
|
|
6362
|
-
if (part.type === "audio") return "[AUDIO]";
|
|
6363
|
-
return "";
|
|
6364
|
-
}).join("\n").trim();
|
|
6365
|
-
}
|
|
6366
|
-
userMessages.push({ role: "user", content: userContent });
|
|
6352
|
+
return [systemPrompt, ...userMessages];
|
|
6367
6353
|
}
|
|
6368
|
-
|
|
6354
|
+
const userContent = this.renderSingleValueUserContent(
|
|
6355
|
+
values,
|
|
6356
|
+
renderedExamples,
|
|
6357
|
+
renderedDemos,
|
|
6358
|
+
examplesInSystemPrompt
|
|
6359
|
+
);
|
|
6360
|
+
return [systemPrompt, { role: "user", content: userContent }];
|
|
6369
6361
|
};
|
|
6370
6362
|
renderExtraFields = (extraFields) => {
|
|
6371
6363
|
const prompt = [];
|
|
@@ -6432,9 +6424,6 @@ ${outputFields}`);
|
|
|
6432
6424
|
if ("text" in v) {
|
|
6433
6425
|
v.text = v.text + "\n";
|
|
6434
6426
|
}
|
|
6435
|
-
if ("image" in v) {
|
|
6436
|
-
v.image = v.image;
|
|
6437
|
-
}
|
|
6438
6427
|
list.push(v);
|
|
6439
6428
|
});
|
|
6440
6429
|
}
|
|
@@ -6465,9 +6454,6 @@ ${outputFields}`);
|
|
|
6465
6454
|
if ("text" in v) {
|
|
6466
6455
|
v.text = v.text + "\n";
|
|
6467
6456
|
}
|
|
6468
|
-
if ("image" in v) {
|
|
6469
|
-
v.image = v.image;
|
|
6470
|
-
}
|
|
6471
6457
|
list.push(v);
|
|
6472
6458
|
});
|
|
6473
6459
|
}
|
|
@@ -8364,7 +8350,7 @@ var AxSignature = class _AxSignature {
|
|
|
8364
8350
|
});
|
|
8365
8351
|
this.inputFields = parsedFields;
|
|
8366
8352
|
this.invalidateValidationCache();
|
|
8367
|
-
this.
|
|
8353
|
+
this.updateHashLight();
|
|
8368
8354
|
} catch (error) {
|
|
8369
8355
|
if (error instanceof AxSignatureValidationError) {
|
|
8370
8356
|
throw error;
|
|
@@ -8390,7 +8376,7 @@ var AxSignature = class _AxSignature {
|
|
|
8390
8376
|
});
|
|
8391
8377
|
this.outputFields = parsedFields;
|
|
8392
8378
|
this.invalidateValidationCache();
|
|
8393
|
-
this.
|
|
8379
|
+
this.updateHashLight();
|
|
8394
8380
|
} catch (error) {
|
|
8395
8381
|
if (error instanceof AxSignatureValidationError) {
|
|
8396
8382
|
throw error;
|
|
@@ -9118,6 +9104,14 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
9118
9104
|
s: -1
|
|
9119
9105
|
};
|
|
9120
9106
|
let content = "";
|
|
9107
|
+
mem.addResult(
|
|
9108
|
+
{
|
|
9109
|
+
content: "",
|
|
9110
|
+
name: "initial",
|
|
9111
|
+
functionCalls: []
|
|
9112
|
+
},
|
|
9113
|
+
sessionId
|
|
9114
|
+
);
|
|
9121
9115
|
for await (const v of res) {
|
|
9122
9116
|
const result = v.results[0];
|
|
9123
9117
|
if (!result) {
|
|
@@ -10158,219 +10152,6 @@ function validateModels2(services) {
|
|
|
10158
10152
|
}
|
|
10159
10153
|
}
|
|
10160
10154
|
|
|
10161
|
-
// dsp/optimize.ts
|
|
10162
|
-
var AxBootstrapFewShot = class {
|
|
10163
|
-
ai;
|
|
10164
|
-
teacherAI;
|
|
10165
|
-
program;
|
|
10166
|
-
examples;
|
|
10167
|
-
maxRounds;
|
|
10168
|
-
maxDemos;
|
|
10169
|
-
maxExamples;
|
|
10170
|
-
batchSize;
|
|
10171
|
-
earlyStoppingPatience;
|
|
10172
|
-
costMonitoring;
|
|
10173
|
-
maxTokensPerGeneration;
|
|
10174
|
-
verboseMode;
|
|
10175
|
-
debugMode;
|
|
10176
|
-
traces = [];
|
|
10177
|
-
stats = {
|
|
10178
|
-
totalCalls: 0,
|
|
10179
|
-
successfulDemos: 0,
|
|
10180
|
-
estimatedTokenUsage: 0,
|
|
10181
|
-
earlyStopped: false
|
|
10182
|
-
};
|
|
10183
|
-
constructor({
|
|
10184
|
-
ai,
|
|
10185
|
-
program,
|
|
10186
|
-
examples = [],
|
|
10187
|
-
options
|
|
10188
|
-
}) {
|
|
10189
|
-
if (examples.length === 0) {
|
|
10190
|
-
throw new Error("No examples found");
|
|
10191
|
-
}
|
|
10192
|
-
this.maxRounds = options?.maxRounds ?? 3;
|
|
10193
|
-
this.maxDemos = options?.maxDemos ?? 4;
|
|
10194
|
-
this.maxExamples = options?.maxExamples ?? 16;
|
|
10195
|
-
this.batchSize = options?.batchSize ?? 1;
|
|
10196
|
-
this.earlyStoppingPatience = options?.earlyStoppingPatience ?? 0;
|
|
10197
|
-
this.costMonitoring = options?.costMonitoring ?? false;
|
|
10198
|
-
this.maxTokensPerGeneration = options?.maxTokensPerGeneration ?? 0;
|
|
10199
|
-
this.verboseMode = options?.verboseMode ?? true;
|
|
10200
|
-
this.debugMode = options?.debugMode ?? false;
|
|
10201
|
-
this.ai = ai;
|
|
10202
|
-
this.teacherAI = options?.teacherAI;
|
|
10203
|
-
this.program = program;
|
|
10204
|
-
this.examples = examples;
|
|
10205
|
-
}
|
|
10206
|
-
async compileRound(roundIndex, metricFn, options) {
|
|
10207
|
-
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
10208
|
-
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
10209
|
-
const aiOpt = {
|
|
10210
|
-
modelConfig: {
|
|
10211
|
-
temperature: 0.7
|
|
10212
|
-
}
|
|
10213
|
-
};
|
|
10214
|
-
if (this.maxTokensPerGeneration > 0) {
|
|
10215
|
-
aiOpt.modelConfig.max_tokens = this.maxTokensPerGeneration;
|
|
10216
|
-
}
|
|
10217
|
-
const examples = randomSample(this.examples, this.maxExamples);
|
|
10218
|
-
const previousSuccessCount = this.traces.length;
|
|
10219
|
-
for (let i = 0; i < examples.length; i += this.batchSize) {
|
|
10220
|
-
if (i > 0) {
|
|
10221
|
-
aiOpt.modelConfig.temperature = 0.7 + 1e-3 * i;
|
|
10222
|
-
}
|
|
10223
|
-
const batch = examples.slice(i, i + this.batchSize);
|
|
10224
|
-
for (const ex of batch) {
|
|
10225
|
-
if (!ex) {
|
|
10226
|
-
continue;
|
|
10227
|
-
}
|
|
10228
|
-
const exList = examples.filter((e) => e !== ex);
|
|
10229
|
-
this.program.setExamples(exList);
|
|
10230
|
-
const aiService = this.teacherAI || this.ai;
|
|
10231
|
-
this.stats.totalCalls++;
|
|
10232
|
-
let res;
|
|
10233
|
-
let error;
|
|
10234
|
-
try {
|
|
10235
|
-
res = await this.program.forward(aiService, ex, aiOpt);
|
|
10236
|
-
if (this.costMonitoring) {
|
|
10237
|
-
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
10238
|
-
}
|
|
10239
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
10240
|
-
const success = score >= 0.5;
|
|
10241
|
-
if (success) {
|
|
10242
|
-
this.traces = [...this.traces, ...this.program.getTraces()];
|
|
10243
|
-
this.stats.successfulDemos++;
|
|
10244
|
-
}
|
|
10245
|
-
} catch (err) {
|
|
10246
|
-
error = err;
|
|
10247
|
-
res = {};
|
|
10248
|
-
}
|
|
10249
|
-
const current = i + examples.length * roundIndex + (batch.indexOf(ex) + 1);
|
|
10250
|
-
const total = examples.length * this.maxRounds;
|
|
10251
|
-
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
10252
|
-
if (this.verboseMode || this.debugMode) {
|
|
10253
|
-
const configInfo = {
|
|
10254
|
-
maxRounds: this.maxRounds,
|
|
10255
|
-
batchSize: this.batchSize,
|
|
10256
|
-
earlyStoppingPatience: this.earlyStoppingPatience,
|
|
10257
|
-
costMonitoring: this.costMonitoring,
|
|
10258
|
-
verboseMode: this.verboseMode,
|
|
10259
|
-
debugMode: this.debugMode
|
|
10260
|
-
};
|
|
10261
|
-
updateDetailedProgress(
|
|
10262
|
-
roundIndex,
|
|
10263
|
-
current,
|
|
10264
|
-
total,
|
|
10265
|
-
et,
|
|
10266
|
-
ex,
|
|
10267
|
-
this.stats,
|
|
10268
|
-
configInfo,
|
|
10269
|
-
res,
|
|
10270
|
-
error
|
|
10271
|
-
);
|
|
10272
|
-
} else {
|
|
10273
|
-
updateProgressBar(
|
|
10274
|
-
current,
|
|
10275
|
-
total,
|
|
10276
|
-
this.traces.length,
|
|
10277
|
-
et,
|
|
10278
|
-
"Tuning Prompt",
|
|
10279
|
-
30
|
|
10280
|
-
);
|
|
10281
|
-
}
|
|
10282
|
-
if (this.traces.length >= maxDemos) {
|
|
10283
|
-
return;
|
|
10284
|
-
}
|
|
10285
|
-
}
|
|
10286
|
-
}
|
|
10287
|
-
if (this.earlyStoppingPatience > 0) {
|
|
10288
|
-
const newSuccessCount = this.traces.length;
|
|
10289
|
-
const improvement = newSuccessCount - previousSuccessCount;
|
|
10290
|
-
if (!this.stats.earlyStopping) {
|
|
10291
|
-
this.stats.earlyStopping = {
|
|
10292
|
-
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
10293
|
-
patienceExhausted: false
|
|
10294
|
-
};
|
|
10295
|
-
} else if (improvement > 0) {
|
|
10296
|
-
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
10297
|
-
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
10298
|
-
this.stats.earlyStopping.patienceExhausted = true;
|
|
10299
|
-
this.stats.earlyStopped = true;
|
|
10300
|
-
if (this.verboseMode || this.debugMode) {
|
|
10301
|
-
console.log(
|
|
10302
|
-
`
|
|
10303
|
-
Early stopping triggered after ${roundIndex + 1} rounds. No improvement for ${this.earlyStoppingPatience} rounds.`
|
|
10304
|
-
);
|
|
10305
|
-
}
|
|
10306
|
-
return;
|
|
10307
|
-
}
|
|
10308
|
-
}
|
|
10309
|
-
}
|
|
10310
|
-
async compile(metricFn, options) {
|
|
10311
|
-
const maxRounds = options?.maxRounds ?? this.maxRounds;
|
|
10312
|
-
this.traces = [];
|
|
10313
|
-
this.stats = {
|
|
10314
|
-
totalCalls: 0,
|
|
10315
|
-
successfulDemos: 0,
|
|
10316
|
-
estimatedTokenUsage: 0,
|
|
10317
|
-
earlyStopped: false
|
|
10318
|
-
};
|
|
10319
|
-
for (let i = 0; i < maxRounds; i++) {
|
|
10320
|
-
await this.compileRound(i, metricFn, options);
|
|
10321
|
-
if (this.stats.earlyStopped) {
|
|
10322
|
-
break;
|
|
10323
|
-
}
|
|
10324
|
-
}
|
|
10325
|
-
if (this.traces.length === 0) {
|
|
10326
|
-
throw new Error(
|
|
10327
|
-
"No demonstrations found. Either provider more examples or improve the existing ones."
|
|
10328
|
-
);
|
|
10329
|
-
}
|
|
10330
|
-
const demos = groupTracesByKeys(this.traces);
|
|
10331
|
-
return {
|
|
10332
|
-
demos,
|
|
10333
|
-
stats: this.stats
|
|
10334
|
-
};
|
|
10335
|
-
}
|
|
10336
|
-
// Get optimization statistics
|
|
10337
|
-
getStats() {
|
|
10338
|
-
return this.stats;
|
|
10339
|
-
}
|
|
10340
|
-
};
|
|
10341
|
-
function groupTracesByKeys(programTraces) {
|
|
10342
|
-
const groupedTraces = /* @__PURE__ */ new Map();
|
|
10343
|
-
for (const programTrace of programTraces) {
|
|
10344
|
-
if (groupedTraces.has(programTrace.programId)) {
|
|
10345
|
-
const traces = groupedTraces.get(programTrace.programId);
|
|
10346
|
-
if (traces) {
|
|
10347
|
-
traces.push(programTrace.trace);
|
|
10348
|
-
}
|
|
10349
|
-
} else {
|
|
10350
|
-
groupedTraces.set(programTrace.programId, [programTrace.trace]);
|
|
10351
|
-
}
|
|
10352
|
-
}
|
|
10353
|
-
const programDemosArray = [];
|
|
10354
|
-
for (const [programId, traces] of groupedTraces.entries()) {
|
|
10355
|
-
programDemosArray.push({ traces, programId });
|
|
10356
|
-
}
|
|
10357
|
-
return programDemosArray;
|
|
10358
|
-
}
|
|
10359
|
-
var randomSample = (array, n) => {
|
|
10360
|
-
const clonedArray = [...array];
|
|
10361
|
-
for (let i = clonedArray.length - 1; i > 0; i--) {
|
|
10362
|
-
const j = Math.floor(Math.random() * (i + 1));
|
|
10363
|
-
const caI = clonedArray[i];
|
|
10364
|
-
const caJ = clonedArray[j];
|
|
10365
|
-
if (!caI || !caJ) {
|
|
10366
|
-
throw new Error("Invalid array elements");
|
|
10367
|
-
}
|
|
10368
|
-
;
|
|
10369
|
-
[clonedArray[i], clonedArray[j]] = [caJ, caI];
|
|
10370
|
-
}
|
|
10371
|
-
return clonedArray.slice(0, n);
|
|
10372
|
-
};
|
|
10373
|
-
|
|
10374
10155
|
// db/base.ts
|
|
10375
10156
|
var import_api23 = require("@opentelemetry/api");
|
|
10376
10157
|
var AxDBBase = class {
|
|
@@ -11829,7 +11610,222 @@ var AxMCPStreambleHTTPTransport = class {
|
|
|
11829
11610
|
}
|
|
11830
11611
|
};
|
|
11831
11612
|
|
|
11832
|
-
// dsp/
|
|
11613
|
+
// dsp/optimizers/bootstrapFewshot.ts
|
|
11614
|
+
var AxBootstrapFewShot = class {
|
|
11615
|
+
ai;
|
|
11616
|
+
teacherAI;
|
|
11617
|
+
program;
|
|
11618
|
+
examples;
|
|
11619
|
+
maxRounds;
|
|
11620
|
+
maxDemos;
|
|
11621
|
+
maxExamples;
|
|
11622
|
+
batchSize;
|
|
11623
|
+
earlyStoppingPatience;
|
|
11624
|
+
costMonitoring;
|
|
11625
|
+
maxTokensPerGeneration;
|
|
11626
|
+
verboseMode;
|
|
11627
|
+
debugMode;
|
|
11628
|
+
traces = [];
|
|
11629
|
+
stats = {
|
|
11630
|
+
totalCalls: 0,
|
|
11631
|
+
successfulDemos: 0,
|
|
11632
|
+
estimatedTokenUsage: 0,
|
|
11633
|
+
earlyStopped: false
|
|
11634
|
+
};
|
|
11635
|
+
constructor({
|
|
11636
|
+
ai,
|
|
11637
|
+
program,
|
|
11638
|
+
examples = [],
|
|
11639
|
+
options
|
|
11640
|
+
}) {
|
|
11641
|
+
if (examples.length === 0) {
|
|
11642
|
+
throw new Error("No examples found");
|
|
11643
|
+
}
|
|
11644
|
+
const bootstrapOptions = options;
|
|
11645
|
+
this.maxRounds = bootstrapOptions?.maxRounds ?? 3;
|
|
11646
|
+
this.maxDemos = bootstrapOptions?.maxDemos ?? 4;
|
|
11647
|
+
this.maxExamples = bootstrapOptions?.maxExamples ?? 16;
|
|
11648
|
+
this.batchSize = bootstrapOptions?.batchSize ?? 1;
|
|
11649
|
+
this.earlyStoppingPatience = bootstrapOptions?.earlyStoppingPatience ?? 0;
|
|
11650
|
+
this.costMonitoring = bootstrapOptions?.costMonitoring ?? false;
|
|
11651
|
+
this.maxTokensPerGeneration = bootstrapOptions?.maxTokensPerGeneration ?? 0;
|
|
11652
|
+
this.verboseMode = bootstrapOptions?.verboseMode ?? true;
|
|
11653
|
+
this.debugMode = bootstrapOptions?.debugMode ?? false;
|
|
11654
|
+
this.ai = ai;
|
|
11655
|
+
this.teacherAI = bootstrapOptions?.teacherAI;
|
|
11656
|
+
this.program = program;
|
|
11657
|
+
this.examples = examples;
|
|
11658
|
+
}
|
|
11659
|
+
async compileRound(roundIndex, metricFn, options) {
|
|
11660
|
+
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
11661
|
+
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
11662
|
+
const aiOpt = {
|
|
11663
|
+
modelConfig: {
|
|
11664
|
+
temperature: 0.7
|
|
11665
|
+
}
|
|
11666
|
+
};
|
|
11667
|
+
if (this.maxTokensPerGeneration > 0) {
|
|
11668
|
+
aiOpt.modelConfig.max_tokens = this.maxTokensPerGeneration;
|
|
11669
|
+
}
|
|
11670
|
+
const examples = randomSample(this.examples, this.maxExamples);
|
|
11671
|
+
const previousSuccessCount = this.traces.length;
|
|
11672
|
+
for (let i = 0; i < examples.length; i += this.batchSize) {
|
|
11673
|
+
if (i > 0) {
|
|
11674
|
+
aiOpt.modelConfig.temperature = 0.7 + 1e-3 * i;
|
|
11675
|
+
}
|
|
11676
|
+
const batch = examples.slice(i, i + this.batchSize);
|
|
11677
|
+
for (const ex of batch) {
|
|
11678
|
+
if (!ex) {
|
|
11679
|
+
continue;
|
|
11680
|
+
}
|
|
11681
|
+
const exList = examples.filter((e) => e !== ex);
|
|
11682
|
+
this.program.setExamples(exList);
|
|
11683
|
+
const aiService = this.teacherAI || this.ai;
|
|
11684
|
+
this.stats.totalCalls++;
|
|
11685
|
+
let res;
|
|
11686
|
+
let error;
|
|
11687
|
+
try {
|
|
11688
|
+
res = await this.program.forward(aiService, ex, aiOpt);
|
|
11689
|
+
if (this.costMonitoring) {
|
|
11690
|
+
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
11691
|
+
}
|
|
11692
|
+
const score = metricFn({ prediction: res, example: ex });
|
|
11693
|
+
const success = score >= 0.5;
|
|
11694
|
+
if (success) {
|
|
11695
|
+
this.traces = [...this.traces, ...this.program.getTraces()];
|
|
11696
|
+
this.stats.successfulDemos++;
|
|
11697
|
+
}
|
|
11698
|
+
} catch (err) {
|
|
11699
|
+
error = err;
|
|
11700
|
+
res = {};
|
|
11701
|
+
}
|
|
11702
|
+
const current = i + examples.length * roundIndex + (batch.indexOf(ex) + 1);
|
|
11703
|
+
const total = examples.length * this.maxRounds;
|
|
11704
|
+
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
11705
|
+
if (this.verboseMode || this.debugMode) {
|
|
11706
|
+
const configInfo = {
|
|
11707
|
+
maxRounds: this.maxRounds,
|
|
11708
|
+
batchSize: this.batchSize,
|
|
11709
|
+
earlyStoppingPatience: this.earlyStoppingPatience,
|
|
11710
|
+
costMonitoring: this.costMonitoring,
|
|
11711
|
+
verboseMode: this.verboseMode,
|
|
11712
|
+
debugMode: this.debugMode
|
|
11713
|
+
};
|
|
11714
|
+
updateDetailedProgress(
|
|
11715
|
+
roundIndex,
|
|
11716
|
+
current,
|
|
11717
|
+
total,
|
|
11718
|
+
et,
|
|
11719
|
+
ex,
|
|
11720
|
+
this.stats,
|
|
11721
|
+
configInfo,
|
|
11722
|
+
res,
|
|
11723
|
+
error
|
|
11724
|
+
);
|
|
11725
|
+
} else {
|
|
11726
|
+
updateProgressBar(
|
|
11727
|
+
current,
|
|
11728
|
+
total,
|
|
11729
|
+
this.traces.length,
|
|
11730
|
+
et,
|
|
11731
|
+
"Tuning Prompt",
|
|
11732
|
+
30
|
|
11733
|
+
);
|
|
11734
|
+
}
|
|
11735
|
+
if (this.traces.length >= maxDemos) {
|
|
11736
|
+
return;
|
|
11737
|
+
}
|
|
11738
|
+
}
|
|
11739
|
+
}
|
|
11740
|
+
if (this.earlyStoppingPatience > 0) {
|
|
11741
|
+
const newSuccessCount = this.traces.length;
|
|
11742
|
+
const improvement = newSuccessCount - previousSuccessCount;
|
|
11743
|
+
if (!this.stats.earlyStopping) {
|
|
11744
|
+
this.stats.earlyStopping = {
|
|
11745
|
+
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
11746
|
+
patienceExhausted: false
|
|
11747
|
+
};
|
|
11748
|
+
} else if (improvement > 0) {
|
|
11749
|
+
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
11750
|
+
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
11751
|
+
this.stats.earlyStopping.patienceExhausted = true;
|
|
11752
|
+
this.stats.earlyStopped = true;
|
|
11753
|
+
if (this.verboseMode || this.debugMode) {
|
|
11754
|
+
console.log(
|
|
11755
|
+
`
|
|
11756
|
+
Early stopping triggered after ${roundIndex + 1} rounds. No improvement for ${this.earlyStoppingPatience} rounds.`
|
|
11757
|
+
);
|
|
11758
|
+
}
|
|
11759
|
+
return;
|
|
11760
|
+
}
|
|
11761
|
+
}
|
|
11762
|
+
}
|
|
11763
|
+
async compile(metricFn, options) {
|
|
11764
|
+
const compileOptions = options;
|
|
11765
|
+
const maxRounds = compileOptions?.maxRounds ?? this.maxRounds;
|
|
11766
|
+
this.traces = [];
|
|
11767
|
+
this.stats = {
|
|
11768
|
+
totalCalls: 0,
|
|
11769
|
+
successfulDemos: 0,
|
|
11770
|
+
estimatedTokenUsage: 0,
|
|
11771
|
+
earlyStopped: false
|
|
11772
|
+
};
|
|
11773
|
+
for (let i = 0; i < maxRounds; i++) {
|
|
11774
|
+
await this.compileRound(i, metricFn, compileOptions);
|
|
11775
|
+
if (this.stats.earlyStopped) {
|
|
11776
|
+
break;
|
|
11777
|
+
}
|
|
11778
|
+
}
|
|
11779
|
+
if (this.traces.length === 0) {
|
|
11780
|
+
throw new Error(
|
|
11781
|
+
"No demonstrations found. Either provider more examples or improve the existing ones."
|
|
11782
|
+
);
|
|
11783
|
+
}
|
|
11784
|
+
const demos = groupTracesByKeys(this.traces);
|
|
11785
|
+
return {
|
|
11786
|
+
demos,
|
|
11787
|
+
stats: this.stats
|
|
11788
|
+
};
|
|
11789
|
+
}
|
|
11790
|
+
// Get optimization statistics
|
|
11791
|
+
getStats() {
|
|
11792
|
+
return this.stats;
|
|
11793
|
+
}
|
|
11794
|
+
};
|
|
11795
|
+
function groupTracesByKeys(programTraces) {
|
|
11796
|
+
const groupedTraces = /* @__PURE__ */ new Map();
|
|
11797
|
+
for (const programTrace of programTraces) {
|
|
11798
|
+
if (groupedTraces.has(programTrace.programId)) {
|
|
11799
|
+
const traces = groupedTraces.get(programTrace.programId);
|
|
11800
|
+
if (traces) {
|
|
11801
|
+
traces.push(programTrace.trace);
|
|
11802
|
+
}
|
|
11803
|
+
} else {
|
|
11804
|
+
groupedTraces.set(programTrace.programId, [programTrace.trace]);
|
|
11805
|
+
}
|
|
11806
|
+
}
|
|
11807
|
+
const programDemosArray = [];
|
|
11808
|
+
for (const [programId, traces] of groupedTraces.entries()) {
|
|
11809
|
+
programDemosArray.push({ traces, programId });
|
|
11810
|
+
}
|
|
11811
|
+
return programDemosArray;
|
|
11812
|
+
}
|
|
11813
|
+
var randomSample = (array, n) => {
|
|
11814
|
+
const clonedArray = [...array];
|
|
11815
|
+
for (let i = clonedArray.length - 1; i > 0; i--) {
|
|
11816
|
+
const j = Math.floor(Math.random() * (i + 1));
|
|
11817
|
+
const caI = clonedArray[i];
|
|
11818
|
+
const caJ = clonedArray[j];
|
|
11819
|
+
if (!caI || !caJ) {
|
|
11820
|
+
throw new Error("Invalid array elements");
|
|
11821
|
+
}
|
|
11822
|
+
;
|
|
11823
|
+
[clonedArray[i], clonedArray[j]] = [caJ, caI];
|
|
11824
|
+
}
|
|
11825
|
+
return clonedArray.slice(0, n);
|
|
11826
|
+
};
|
|
11827
|
+
|
|
11828
|
+
// dsp/optimizers/miproV2.ts
|
|
11833
11829
|
var AxMiPRO = class {
|
|
11834
11830
|
ai;
|
|
11835
11831
|
program;
|
|
@@ -12055,7 +12051,7 @@ ${dataContext}
|
|
|
12055
12051
|
const result = await this.bootstrapper.compile(metricFn, {
|
|
12056
12052
|
maxDemos: this.maxBootstrappedDemos
|
|
12057
12053
|
});
|
|
12058
|
-
return result.demos;
|
|
12054
|
+
return result.demos || [];
|
|
12059
12055
|
}
|
|
12060
12056
|
/**
|
|
12061
12057
|
* Selects labeled examples directly from the training set
|
|
@@ -12389,21 +12385,22 @@ ${dataContext}
|
|
|
12389
12385
|
* The main compile method to run MIPROv2 optimization
|
|
12390
12386
|
* @param metricFn Evaluation metric function
|
|
12391
12387
|
* @param options Optional configuration options
|
|
12392
|
-
* @returns The
|
|
12388
|
+
* @returns The optimization result
|
|
12393
12389
|
*/
|
|
12394
12390
|
async compile(metricFn, options) {
|
|
12395
|
-
|
|
12396
|
-
|
|
12391
|
+
const miproOptions = options;
|
|
12392
|
+
if (miproOptions?.auto) {
|
|
12393
|
+
this.configureAuto(miproOptions.auto);
|
|
12397
12394
|
}
|
|
12398
12395
|
const trainset = this.examples;
|
|
12399
|
-
const valset =
|
|
12396
|
+
const valset = miproOptions?.valset || this.examples.slice(0, Math.floor(this.examples.length * 0.8));
|
|
12400
12397
|
if (this.verbose) {
|
|
12401
12398
|
console.log(`Starting MIPROv2 optimization with ${this.numTrials} trials`);
|
|
12402
12399
|
console.log(
|
|
12403
12400
|
`Using ${trainset.length} examples for training and ${valset.length} for validation`
|
|
12404
12401
|
);
|
|
12405
12402
|
}
|
|
12406
|
-
if (
|
|
12403
|
+
if (miproOptions?.teacher) {
|
|
12407
12404
|
if (this.verbose) {
|
|
12408
12405
|
console.log("Using provided teacher to assist with bootstrapping");
|
|
12409
12406
|
}
|
|
@@ -12460,7 +12457,17 @@ ${dataContext}
|
|
|
12460
12457
|
bootstrappedDemos,
|
|
12461
12458
|
labeledExamples
|
|
12462
12459
|
);
|
|
12463
|
-
return
|
|
12460
|
+
return {
|
|
12461
|
+
program: this.program,
|
|
12462
|
+
demos: bootstrappedDemos
|
|
12463
|
+
};
|
|
12464
|
+
}
|
|
12465
|
+
/**
|
|
12466
|
+
* Get optimization statistics from the internal bootstrapper
|
|
12467
|
+
* @returns Optimization statistics or undefined if not available
|
|
12468
|
+
*/
|
|
12469
|
+
getStats() {
|
|
12470
|
+
return this.bootstrapper.getStats();
|
|
12464
12471
|
}
|
|
12465
12472
|
};
|
|
12466
12473
|
|