@ax-llm/ax 12.0.4 → 12.0.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +301 -294
- package/index.cjs.map +1 -1
- package/index.d.cts +125 -68
- package/index.d.ts +125 -68
- package/index.js +301 -294
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.js
CHANGED
|
@@ -3328,7 +3328,7 @@ var AxAIGoogleGeminiImpl = class {
|
|
|
3328
3328
|
}
|
|
3329
3329
|
];
|
|
3330
3330
|
return {
|
|
3331
|
-
role: "
|
|
3331
|
+
role: "user",
|
|
3332
3332
|
parts
|
|
3333
3333
|
};
|
|
3334
3334
|
}
|
|
@@ -4051,8 +4051,11 @@ var AxAIMistral = class extends AxAIOpenAIBase {
|
|
|
4051
4051
|
for (const message of messages) {
|
|
4052
4052
|
if (message.role === "user" && Array.isArray(message.content)) {
|
|
4053
4053
|
const contentUpdated = message.content.map((item) => {
|
|
4054
|
-
if (typeof item === "object" && item !== null &&
|
|
4055
|
-
return {
|
|
4054
|
+
if (typeof item === "object" && item !== null && "image_url" in item) {
|
|
4055
|
+
return {
|
|
4056
|
+
type: "image_url",
|
|
4057
|
+
image_url: { url: item.image_url?.url }
|
|
4058
|
+
};
|
|
4056
4059
|
}
|
|
4057
4060
|
return item;
|
|
4058
4061
|
});
|
|
@@ -4725,8 +4728,6 @@ var AxAIOpenAIResponsesImpl = class {
|
|
|
4725
4728
|
baseResult.content = event.delta;
|
|
4726
4729
|
break;
|
|
4727
4730
|
case "response.output_text.done":
|
|
4728
|
-
baseResult.id = event.item_id;
|
|
4729
|
-
baseResult.content = event.text;
|
|
4730
4731
|
break;
|
|
4731
4732
|
case "response.function_call_arguments.delta":
|
|
4732
4733
|
baseResult.id = event.item_id;
|
|
@@ -5549,9 +5550,6 @@ var MemoryImpl = class {
|
|
|
5549
5550
|
functionCalls
|
|
5550
5551
|
}) {
|
|
5551
5552
|
const isContentEmpty = typeof content === "string" && content.trim() === "";
|
|
5552
|
-
if (isContentEmpty && (!functionCalls || functionCalls.length === 0)) {
|
|
5553
|
-
return;
|
|
5554
|
-
}
|
|
5555
5553
|
if (isContentEmpty) {
|
|
5556
5554
|
this.addMemory({ name, role: "assistant", functionCalls });
|
|
5557
5555
|
} else {
|
|
@@ -5576,17 +5574,16 @@ var MemoryImpl = class {
|
|
|
5576
5574
|
}) {
|
|
5577
5575
|
const lastItem = this.data.at(-1);
|
|
5578
5576
|
if (!lastItem || lastItem.chat.role !== "assistant") {
|
|
5579
|
-
|
|
5580
|
-
}
|
|
5581
|
-
|
|
5582
|
-
|
|
5583
|
-
|
|
5584
|
-
|
|
5585
|
-
|
|
5586
|
-
|
|
5587
|
-
|
|
5588
|
-
|
|
5589
|
-
}
|
|
5577
|
+
throw new Error("No assistant message to update");
|
|
5578
|
+
}
|
|
5579
|
+
if (typeof content === "string" && content.trim() !== "") {
|
|
5580
|
+
lastItem.chat.content = content;
|
|
5581
|
+
}
|
|
5582
|
+
if (name && name.trim() !== "") {
|
|
5583
|
+
lastItem.chat.name = name;
|
|
5584
|
+
}
|
|
5585
|
+
if (functionCalls && functionCalls.length > 0) {
|
|
5586
|
+
lastItem.chat.functionCalls = functionCalls;
|
|
5590
5587
|
}
|
|
5591
5588
|
if (this.options?.debug) {
|
|
5592
5589
|
if (delta && typeof delta === "string") {
|
|
@@ -6116,6 +6113,12 @@ ${outputFields}`);
|
|
|
6116
6113
|
text: task.join("\n\n")
|
|
6117
6114
|
};
|
|
6118
6115
|
}
|
|
6116
|
+
renderSingleValueUserContent = (values, renderedExamples, renderedDemos, examplesInSystemPrompt) => {
|
|
6117
|
+
const completion = this.renderInputFields(values);
|
|
6118
|
+
const promptList = examplesInSystemPrompt ? completion : [...renderedExamples, ...renderedDemos, ...completion];
|
|
6119
|
+
const prompt = promptList.filter((v) => v !== void 0);
|
|
6120
|
+
return prompt.every((v) => v.type === "text") ? prompt.map((v) => v.text).join("\n") : prompt.reduce(combineConsecutiveStrings("\n"), []);
|
|
6121
|
+
};
|
|
6119
6122
|
render = (values, {
|
|
6120
6123
|
examples,
|
|
6121
6124
|
demos
|
|
@@ -6144,60 +6147,49 @@ ${outputFields}`);
|
|
|
6144
6147
|
role: "system",
|
|
6145
6148
|
content: systemContent
|
|
6146
6149
|
};
|
|
6147
|
-
let userMessages = [];
|
|
6148
6150
|
if (Array.isArray(values)) {
|
|
6151
|
+
let userMessages = [];
|
|
6149
6152
|
const history = values;
|
|
6150
|
-
|
|
6151
|
-
|
|
6152
|
-
|
|
6153
|
-
|
|
6154
|
-
|
|
6155
|
-
|
|
6156
|
-
|
|
6153
|
+
for (const [index, message] of history.entries()) {
|
|
6154
|
+
let content;
|
|
6155
|
+
if (index === 0) {
|
|
6156
|
+
content = this.renderSingleValueUserContent(
|
|
6157
|
+
message.values,
|
|
6158
|
+
renderedExamples,
|
|
6159
|
+
renderedDemos,
|
|
6160
|
+
examplesInSystemPrompt
|
|
6157
6161
|
);
|
|
6158
|
-
|
|
6159
|
-
|
|
6160
|
-
|
|
6161
|
-
|
|
6162
|
-
|
|
6162
|
+
} else {
|
|
6163
|
+
content = this.renderSingleValueUserContent(
|
|
6164
|
+
message.values,
|
|
6165
|
+
[],
|
|
6166
|
+
[],
|
|
6167
|
+
false
|
|
6163
6168
|
);
|
|
6164
|
-
messageContent = assistantMsgParts.map((part) => part.type === "text" ? part.text : "").join("").trim();
|
|
6165
6169
|
}
|
|
6166
|
-
if (
|
|
6167
|
-
|
|
6168
|
-
|
|
6169
|
-
if (lastMessage) {
|
|
6170
|
-
lastMessage.content += "\n" + messageContent;
|
|
6171
|
-
}
|
|
6172
|
-
} else {
|
|
6173
|
-
if (message.role === "user") {
|
|
6174
|
-
userMessages.push({ role: "user", content: messageContent });
|
|
6175
|
-
} else if (message.role === "assistant") {
|
|
6176
|
-
userMessages.push({ role: "assistant", content: messageContent });
|
|
6177
|
-
}
|
|
6178
|
-
}
|
|
6179
|
-
lastRole = message.role;
|
|
6170
|
+
if (message.role === "user") {
|
|
6171
|
+
userMessages.push({ role: "user", content });
|
|
6172
|
+
continue;
|
|
6180
6173
|
}
|
|
6174
|
+
if (message.role !== "assistant") {
|
|
6175
|
+
throw new Error("Invalid message role");
|
|
6176
|
+
}
|
|
6177
|
+
if (typeof content !== "string") {
|
|
6178
|
+
throw new Error(
|
|
6179
|
+
"Assistant message cannot contain non-text content like images, files,etc"
|
|
6180
|
+
);
|
|
6181
|
+
}
|
|
6182
|
+
userMessages.push({ role: "assistant", content });
|
|
6181
6183
|
}
|
|
6182
|
-
|
|
6183
|
-
const currentValues = values;
|
|
6184
|
-
const completion = this.renderInputFields(currentValues);
|
|
6185
|
-
const promptList = examplesInSystemPrompt ? completion : [...renderedExamples, ...renderedDemos, ...completion];
|
|
6186
|
-
const promptFilter = promptList.filter((v) => v !== void 0);
|
|
6187
|
-
let userContent;
|
|
6188
|
-
if (promptFilter.every((v) => v.type === "text")) {
|
|
6189
|
-
userContent = promptFilter.map((v) => v.text).join("\n");
|
|
6190
|
-
} else {
|
|
6191
|
-
userContent = promptFilter.map((part) => {
|
|
6192
|
-
if (part.type === "text") return part.text;
|
|
6193
|
-
if (part.type === "image") return "[IMAGE]";
|
|
6194
|
-
if (part.type === "audio") return "[AUDIO]";
|
|
6195
|
-
return "";
|
|
6196
|
-
}).join("\n").trim();
|
|
6197
|
-
}
|
|
6198
|
-
userMessages.push({ role: "user", content: userContent });
|
|
6184
|
+
return [systemPrompt, ...userMessages];
|
|
6199
6185
|
}
|
|
6200
|
-
|
|
6186
|
+
const userContent = this.renderSingleValueUserContent(
|
|
6187
|
+
values,
|
|
6188
|
+
renderedExamples,
|
|
6189
|
+
renderedDemos,
|
|
6190
|
+
examplesInSystemPrompt
|
|
6191
|
+
);
|
|
6192
|
+
return [systemPrompt, { role: "user", content: userContent }];
|
|
6201
6193
|
};
|
|
6202
6194
|
renderExtraFields = (extraFields) => {
|
|
6203
6195
|
const prompt = [];
|
|
@@ -6264,9 +6256,6 @@ ${outputFields}`);
|
|
|
6264
6256
|
if ("text" in v) {
|
|
6265
6257
|
v.text = v.text + "\n";
|
|
6266
6258
|
}
|
|
6267
|
-
if ("image" in v) {
|
|
6268
|
-
v.image = v.image;
|
|
6269
|
-
}
|
|
6270
6259
|
list.push(v);
|
|
6271
6260
|
});
|
|
6272
6261
|
}
|
|
@@ -6297,9 +6286,6 @@ ${outputFields}`);
|
|
|
6297
6286
|
if ("text" in v) {
|
|
6298
6287
|
v.text = v.text + "\n";
|
|
6299
6288
|
}
|
|
6300
|
-
if ("image" in v) {
|
|
6301
|
-
v.image = v.image;
|
|
6302
|
-
}
|
|
6303
6289
|
list.push(v);
|
|
6304
6290
|
});
|
|
6305
6291
|
}
|
|
@@ -8196,7 +8182,7 @@ var AxSignature = class _AxSignature {
|
|
|
8196
8182
|
});
|
|
8197
8183
|
this.inputFields = parsedFields;
|
|
8198
8184
|
this.invalidateValidationCache();
|
|
8199
|
-
this.
|
|
8185
|
+
this.updateHashLight();
|
|
8200
8186
|
} catch (error) {
|
|
8201
8187
|
if (error instanceof AxSignatureValidationError) {
|
|
8202
8188
|
throw error;
|
|
@@ -8222,7 +8208,7 @@ var AxSignature = class _AxSignature {
|
|
|
8222
8208
|
});
|
|
8223
8209
|
this.outputFields = parsedFields;
|
|
8224
8210
|
this.invalidateValidationCache();
|
|
8225
|
-
this.
|
|
8211
|
+
this.updateHashLight();
|
|
8226
8212
|
} catch (error) {
|
|
8227
8213
|
if (error instanceof AxSignatureValidationError) {
|
|
8228
8214
|
throw error;
|
|
@@ -8950,6 +8936,14 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
8950
8936
|
s: -1
|
|
8951
8937
|
};
|
|
8952
8938
|
let content = "";
|
|
8939
|
+
mem.addResult(
|
|
8940
|
+
{
|
|
8941
|
+
content: "",
|
|
8942
|
+
name: "initial",
|
|
8943
|
+
functionCalls: []
|
|
8944
|
+
},
|
|
8945
|
+
sessionId
|
|
8946
|
+
);
|
|
8953
8947
|
for await (const v of res) {
|
|
8954
8948
|
const result = v.results[0];
|
|
8955
8949
|
if (!result) {
|
|
@@ -9990,219 +9984,6 @@ function validateModels2(services) {
|
|
|
9990
9984
|
}
|
|
9991
9985
|
}
|
|
9992
9986
|
|
|
9993
|
-
// dsp/optimize.ts
|
|
9994
|
-
var AxBootstrapFewShot = class {
|
|
9995
|
-
ai;
|
|
9996
|
-
teacherAI;
|
|
9997
|
-
program;
|
|
9998
|
-
examples;
|
|
9999
|
-
maxRounds;
|
|
10000
|
-
maxDemos;
|
|
10001
|
-
maxExamples;
|
|
10002
|
-
batchSize;
|
|
10003
|
-
earlyStoppingPatience;
|
|
10004
|
-
costMonitoring;
|
|
10005
|
-
maxTokensPerGeneration;
|
|
10006
|
-
verboseMode;
|
|
10007
|
-
debugMode;
|
|
10008
|
-
traces = [];
|
|
10009
|
-
stats = {
|
|
10010
|
-
totalCalls: 0,
|
|
10011
|
-
successfulDemos: 0,
|
|
10012
|
-
estimatedTokenUsage: 0,
|
|
10013
|
-
earlyStopped: false
|
|
10014
|
-
};
|
|
10015
|
-
constructor({
|
|
10016
|
-
ai,
|
|
10017
|
-
program,
|
|
10018
|
-
examples = [],
|
|
10019
|
-
options
|
|
10020
|
-
}) {
|
|
10021
|
-
if (examples.length === 0) {
|
|
10022
|
-
throw new Error("No examples found");
|
|
10023
|
-
}
|
|
10024
|
-
this.maxRounds = options?.maxRounds ?? 3;
|
|
10025
|
-
this.maxDemos = options?.maxDemos ?? 4;
|
|
10026
|
-
this.maxExamples = options?.maxExamples ?? 16;
|
|
10027
|
-
this.batchSize = options?.batchSize ?? 1;
|
|
10028
|
-
this.earlyStoppingPatience = options?.earlyStoppingPatience ?? 0;
|
|
10029
|
-
this.costMonitoring = options?.costMonitoring ?? false;
|
|
10030
|
-
this.maxTokensPerGeneration = options?.maxTokensPerGeneration ?? 0;
|
|
10031
|
-
this.verboseMode = options?.verboseMode ?? true;
|
|
10032
|
-
this.debugMode = options?.debugMode ?? false;
|
|
10033
|
-
this.ai = ai;
|
|
10034
|
-
this.teacherAI = options?.teacherAI;
|
|
10035
|
-
this.program = program;
|
|
10036
|
-
this.examples = examples;
|
|
10037
|
-
}
|
|
10038
|
-
async compileRound(roundIndex, metricFn, options) {
|
|
10039
|
-
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
10040
|
-
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
10041
|
-
const aiOpt = {
|
|
10042
|
-
modelConfig: {
|
|
10043
|
-
temperature: 0.7
|
|
10044
|
-
}
|
|
10045
|
-
};
|
|
10046
|
-
if (this.maxTokensPerGeneration > 0) {
|
|
10047
|
-
aiOpt.modelConfig.max_tokens = this.maxTokensPerGeneration;
|
|
10048
|
-
}
|
|
10049
|
-
const examples = randomSample(this.examples, this.maxExamples);
|
|
10050
|
-
const previousSuccessCount = this.traces.length;
|
|
10051
|
-
for (let i = 0; i < examples.length; i += this.batchSize) {
|
|
10052
|
-
if (i > 0) {
|
|
10053
|
-
aiOpt.modelConfig.temperature = 0.7 + 1e-3 * i;
|
|
10054
|
-
}
|
|
10055
|
-
const batch = examples.slice(i, i + this.batchSize);
|
|
10056
|
-
for (const ex of batch) {
|
|
10057
|
-
if (!ex) {
|
|
10058
|
-
continue;
|
|
10059
|
-
}
|
|
10060
|
-
const exList = examples.filter((e) => e !== ex);
|
|
10061
|
-
this.program.setExamples(exList);
|
|
10062
|
-
const aiService = this.teacherAI || this.ai;
|
|
10063
|
-
this.stats.totalCalls++;
|
|
10064
|
-
let res;
|
|
10065
|
-
let error;
|
|
10066
|
-
try {
|
|
10067
|
-
res = await this.program.forward(aiService, ex, aiOpt);
|
|
10068
|
-
if (this.costMonitoring) {
|
|
10069
|
-
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
10070
|
-
}
|
|
10071
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
10072
|
-
const success = score >= 0.5;
|
|
10073
|
-
if (success) {
|
|
10074
|
-
this.traces = [...this.traces, ...this.program.getTraces()];
|
|
10075
|
-
this.stats.successfulDemos++;
|
|
10076
|
-
}
|
|
10077
|
-
} catch (err) {
|
|
10078
|
-
error = err;
|
|
10079
|
-
res = {};
|
|
10080
|
-
}
|
|
10081
|
-
const current = i + examples.length * roundIndex + (batch.indexOf(ex) + 1);
|
|
10082
|
-
const total = examples.length * this.maxRounds;
|
|
10083
|
-
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
10084
|
-
if (this.verboseMode || this.debugMode) {
|
|
10085
|
-
const configInfo = {
|
|
10086
|
-
maxRounds: this.maxRounds,
|
|
10087
|
-
batchSize: this.batchSize,
|
|
10088
|
-
earlyStoppingPatience: this.earlyStoppingPatience,
|
|
10089
|
-
costMonitoring: this.costMonitoring,
|
|
10090
|
-
verboseMode: this.verboseMode,
|
|
10091
|
-
debugMode: this.debugMode
|
|
10092
|
-
};
|
|
10093
|
-
updateDetailedProgress(
|
|
10094
|
-
roundIndex,
|
|
10095
|
-
current,
|
|
10096
|
-
total,
|
|
10097
|
-
et,
|
|
10098
|
-
ex,
|
|
10099
|
-
this.stats,
|
|
10100
|
-
configInfo,
|
|
10101
|
-
res,
|
|
10102
|
-
error
|
|
10103
|
-
);
|
|
10104
|
-
} else {
|
|
10105
|
-
updateProgressBar(
|
|
10106
|
-
current,
|
|
10107
|
-
total,
|
|
10108
|
-
this.traces.length,
|
|
10109
|
-
et,
|
|
10110
|
-
"Tuning Prompt",
|
|
10111
|
-
30
|
|
10112
|
-
);
|
|
10113
|
-
}
|
|
10114
|
-
if (this.traces.length >= maxDemos) {
|
|
10115
|
-
return;
|
|
10116
|
-
}
|
|
10117
|
-
}
|
|
10118
|
-
}
|
|
10119
|
-
if (this.earlyStoppingPatience > 0) {
|
|
10120
|
-
const newSuccessCount = this.traces.length;
|
|
10121
|
-
const improvement = newSuccessCount - previousSuccessCount;
|
|
10122
|
-
if (!this.stats.earlyStopping) {
|
|
10123
|
-
this.stats.earlyStopping = {
|
|
10124
|
-
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
10125
|
-
patienceExhausted: false
|
|
10126
|
-
};
|
|
10127
|
-
} else if (improvement > 0) {
|
|
10128
|
-
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
10129
|
-
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
10130
|
-
this.stats.earlyStopping.patienceExhausted = true;
|
|
10131
|
-
this.stats.earlyStopped = true;
|
|
10132
|
-
if (this.verboseMode || this.debugMode) {
|
|
10133
|
-
console.log(
|
|
10134
|
-
`
|
|
10135
|
-
Early stopping triggered after ${roundIndex + 1} rounds. No improvement for ${this.earlyStoppingPatience} rounds.`
|
|
10136
|
-
);
|
|
10137
|
-
}
|
|
10138
|
-
return;
|
|
10139
|
-
}
|
|
10140
|
-
}
|
|
10141
|
-
}
|
|
10142
|
-
async compile(metricFn, options) {
|
|
10143
|
-
const maxRounds = options?.maxRounds ?? this.maxRounds;
|
|
10144
|
-
this.traces = [];
|
|
10145
|
-
this.stats = {
|
|
10146
|
-
totalCalls: 0,
|
|
10147
|
-
successfulDemos: 0,
|
|
10148
|
-
estimatedTokenUsage: 0,
|
|
10149
|
-
earlyStopped: false
|
|
10150
|
-
};
|
|
10151
|
-
for (let i = 0; i < maxRounds; i++) {
|
|
10152
|
-
await this.compileRound(i, metricFn, options);
|
|
10153
|
-
if (this.stats.earlyStopped) {
|
|
10154
|
-
break;
|
|
10155
|
-
}
|
|
10156
|
-
}
|
|
10157
|
-
if (this.traces.length === 0) {
|
|
10158
|
-
throw new Error(
|
|
10159
|
-
"No demonstrations found. Either provider more examples or improve the existing ones."
|
|
10160
|
-
);
|
|
10161
|
-
}
|
|
10162
|
-
const demos = groupTracesByKeys(this.traces);
|
|
10163
|
-
return {
|
|
10164
|
-
demos,
|
|
10165
|
-
stats: this.stats
|
|
10166
|
-
};
|
|
10167
|
-
}
|
|
10168
|
-
// Get optimization statistics
|
|
10169
|
-
getStats() {
|
|
10170
|
-
return this.stats;
|
|
10171
|
-
}
|
|
10172
|
-
};
|
|
10173
|
-
function groupTracesByKeys(programTraces) {
|
|
10174
|
-
const groupedTraces = /* @__PURE__ */ new Map();
|
|
10175
|
-
for (const programTrace of programTraces) {
|
|
10176
|
-
if (groupedTraces.has(programTrace.programId)) {
|
|
10177
|
-
const traces = groupedTraces.get(programTrace.programId);
|
|
10178
|
-
if (traces) {
|
|
10179
|
-
traces.push(programTrace.trace);
|
|
10180
|
-
}
|
|
10181
|
-
} else {
|
|
10182
|
-
groupedTraces.set(programTrace.programId, [programTrace.trace]);
|
|
10183
|
-
}
|
|
10184
|
-
}
|
|
10185
|
-
const programDemosArray = [];
|
|
10186
|
-
for (const [programId, traces] of groupedTraces.entries()) {
|
|
10187
|
-
programDemosArray.push({ traces, programId });
|
|
10188
|
-
}
|
|
10189
|
-
return programDemosArray;
|
|
10190
|
-
}
|
|
10191
|
-
var randomSample = (array, n) => {
|
|
10192
|
-
const clonedArray = [...array];
|
|
10193
|
-
for (let i = clonedArray.length - 1; i > 0; i--) {
|
|
10194
|
-
const j = Math.floor(Math.random() * (i + 1));
|
|
10195
|
-
const caI = clonedArray[i];
|
|
10196
|
-
const caJ = clonedArray[j];
|
|
10197
|
-
if (!caI || !caJ) {
|
|
10198
|
-
throw new Error("Invalid array elements");
|
|
10199
|
-
}
|
|
10200
|
-
;
|
|
10201
|
-
[clonedArray[i], clonedArray[j]] = [caJ, caI];
|
|
10202
|
-
}
|
|
10203
|
-
return clonedArray.slice(0, n);
|
|
10204
|
-
};
|
|
10205
|
-
|
|
10206
9987
|
// db/base.ts
|
|
10207
9988
|
import { SpanKind as SpanKind3 } from "@opentelemetry/api";
|
|
10208
9989
|
var AxDBBase = class {
|
|
@@ -11661,7 +11442,222 @@ var AxMCPStreambleHTTPTransport = class {
|
|
|
11661
11442
|
}
|
|
11662
11443
|
};
|
|
11663
11444
|
|
|
11664
|
-
// dsp/
|
|
11445
|
+
// dsp/optimizers/bootstrapFewshot.ts
|
|
11446
|
+
var AxBootstrapFewShot = class {
|
|
11447
|
+
ai;
|
|
11448
|
+
teacherAI;
|
|
11449
|
+
program;
|
|
11450
|
+
examples;
|
|
11451
|
+
maxRounds;
|
|
11452
|
+
maxDemos;
|
|
11453
|
+
maxExamples;
|
|
11454
|
+
batchSize;
|
|
11455
|
+
earlyStoppingPatience;
|
|
11456
|
+
costMonitoring;
|
|
11457
|
+
maxTokensPerGeneration;
|
|
11458
|
+
verboseMode;
|
|
11459
|
+
debugMode;
|
|
11460
|
+
traces = [];
|
|
11461
|
+
stats = {
|
|
11462
|
+
totalCalls: 0,
|
|
11463
|
+
successfulDemos: 0,
|
|
11464
|
+
estimatedTokenUsage: 0,
|
|
11465
|
+
earlyStopped: false
|
|
11466
|
+
};
|
|
11467
|
+
constructor({
|
|
11468
|
+
ai,
|
|
11469
|
+
program,
|
|
11470
|
+
examples = [],
|
|
11471
|
+
options
|
|
11472
|
+
}) {
|
|
11473
|
+
if (examples.length === 0) {
|
|
11474
|
+
throw new Error("No examples found");
|
|
11475
|
+
}
|
|
11476
|
+
const bootstrapOptions = options;
|
|
11477
|
+
this.maxRounds = bootstrapOptions?.maxRounds ?? 3;
|
|
11478
|
+
this.maxDemos = bootstrapOptions?.maxDemos ?? 4;
|
|
11479
|
+
this.maxExamples = bootstrapOptions?.maxExamples ?? 16;
|
|
11480
|
+
this.batchSize = bootstrapOptions?.batchSize ?? 1;
|
|
11481
|
+
this.earlyStoppingPatience = bootstrapOptions?.earlyStoppingPatience ?? 0;
|
|
11482
|
+
this.costMonitoring = bootstrapOptions?.costMonitoring ?? false;
|
|
11483
|
+
this.maxTokensPerGeneration = bootstrapOptions?.maxTokensPerGeneration ?? 0;
|
|
11484
|
+
this.verboseMode = bootstrapOptions?.verboseMode ?? true;
|
|
11485
|
+
this.debugMode = bootstrapOptions?.debugMode ?? false;
|
|
11486
|
+
this.ai = ai;
|
|
11487
|
+
this.teacherAI = bootstrapOptions?.teacherAI;
|
|
11488
|
+
this.program = program;
|
|
11489
|
+
this.examples = examples;
|
|
11490
|
+
}
|
|
11491
|
+
async compileRound(roundIndex, metricFn, options) {
|
|
11492
|
+
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
11493
|
+
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
11494
|
+
const aiOpt = {
|
|
11495
|
+
modelConfig: {
|
|
11496
|
+
temperature: 0.7
|
|
11497
|
+
}
|
|
11498
|
+
};
|
|
11499
|
+
if (this.maxTokensPerGeneration > 0) {
|
|
11500
|
+
aiOpt.modelConfig.max_tokens = this.maxTokensPerGeneration;
|
|
11501
|
+
}
|
|
11502
|
+
const examples = randomSample(this.examples, this.maxExamples);
|
|
11503
|
+
const previousSuccessCount = this.traces.length;
|
|
11504
|
+
for (let i = 0; i < examples.length; i += this.batchSize) {
|
|
11505
|
+
if (i > 0) {
|
|
11506
|
+
aiOpt.modelConfig.temperature = 0.7 + 1e-3 * i;
|
|
11507
|
+
}
|
|
11508
|
+
const batch = examples.slice(i, i + this.batchSize);
|
|
11509
|
+
for (const ex of batch) {
|
|
11510
|
+
if (!ex) {
|
|
11511
|
+
continue;
|
|
11512
|
+
}
|
|
11513
|
+
const exList = examples.filter((e) => e !== ex);
|
|
11514
|
+
this.program.setExamples(exList);
|
|
11515
|
+
const aiService = this.teacherAI || this.ai;
|
|
11516
|
+
this.stats.totalCalls++;
|
|
11517
|
+
let res;
|
|
11518
|
+
let error;
|
|
11519
|
+
try {
|
|
11520
|
+
res = await this.program.forward(aiService, ex, aiOpt);
|
|
11521
|
+
if (this.costMonitoring) {
|
|
11522
|
+
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
11523
|
+
}
|
|
11524
|
+
const score = metricFn({ prediction: res, example: ex });
|
|
11525
|
+
const success = score >= 0.5;
|
|
11526
|
+
if (success) {
|
|
11527
|
+
this.traces = [...this.traces, ...this.program.getTraces()];
|
|
11528
|
+
this.stats.successfulDemos++;
|
|
11529
|
+
}
|
|
11530
|
+
} catch (err) {
|
|
11531
|
+
error = err;
|
|
11532
|
+
res = {};
|
|
11533
|
+
}
|
|
11534
|
+
const current = i + examples.length * roundIndex + (batch.indexOf(ex) + 1);
|
|
11535
|
+
const total = examples.length * this.maxRounds;
|
|
11536
|
+
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
11537
|
+
if (this.verboseMode || this.debugMode) {
|
|
11538
|
+
const configInfo = {
|
|
11539
|
+
maxRounds: this.maxRounds,
|
|
11540
|
+
batchSize: this.batchSize,
|
|
11541
|
+
earlyStoppingPatience: this.earlyStoppingPatience,
|
|
11542
|
+
costMonitoring: this.costMonitoring,
|
|
11543
|
+
verboseMode: this.verboseMode,
|
|
11544
|
+
debugMode: this.debugMode
|
|
11545
|
+
};
|
|
11546
|
+
updateDetailedProgress(
|
|
11547
|
+
roundIndex,
|
|
11548
|
+
current,
|
|
11549
|
+
total,
|
|
11550
|
+
et,
|
|
11551
|
+
ex,
|
|
11552
|
+
this.stats,
|
|
11553
|
+
configInfo,
|
|
11554
|
+
res,
|
|
11555
|
+
error
|
|
11556
|
+
);
|
|
11557
|
+
} else {
|
|
11558
|
+
updateProgressBar(
|
|
11559
|
+
current,
|
|
11560
|
+
total,
|
|
11561
|
+
this.traces.length,
|
|
11562
|
+
et,
|
|
11563
|
+
"Tuning Prompt",
|
|
11564
|
+
30
|
|
11565
|
+
);
|
|
11566
|
+
}
|
|
11567
|
+
if (this.traces.length >= maxDemos) {
|
|
11568
|
+
return;
|
|
11569
|
+
}
|
|
11570
|
+
}
|
|
11571
|
+
}
|
|
11572
|
+
if (this.earlyStoppingPatience > 0) {
|
|
11573
|
+
const newSuccessCount = this.traces.length;
|
|
11574
|
+
const improvement = newSuccessCount - previousSuccessCount;
|
|
11575
|
+
if (!this.stats.earlyStopping) {
|
|
11576
|
+
this.stats.earlyStopping = {
|
|
11577
|
+
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
11578
|
+
patienceExhausted: false
|
|
11579
|
+
};
|
|
11580
|
+
} else if (improvement > 0) {
|
|
11581
|
+
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
11582
|
+
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
11583
|
+
this.stats.earlyStopping.patienceExhausted = true;
|
|
11584
|
+
this.stats.earlyStopped = true;
|
|
11585
|
+
if (this.verboseMode || this.debugMode) {
|
|
11586
|
+
console.log(
|
|
11587
|
+
`
|
|
11588
|
+
Early stopping triggered after ${roundIndex + 1} rounds. No improvement for ${this.earlyStoppingPatience} rounds.`
|
|
11589
|
+
);
|
|
11590
|
+
}
|
|
11591
|
+
return;
|
|
11592
|
+
}
|
|
11593
|
+
}
|
|
11594
|
+
}
|
|
11595
|
+
async compile(metricFn, options) {
|
|
11596
|
+
const compileOptions = options;
|
|
11597
|
+
const maxRounds = compileOptions?.maxRounds ?? this.maxRounds;
|
|
11598
|
+
this.traces = [];
|
|
11599
|
+
this.stats = {
|
|
11600
|
+
totalCalls: 0,
|
|
11601
|
+
successfulDemos: 0,
|
|
11602
|
+
estimatedTokenUsage: 0,
|
|
11603
|
+
earlyStopped: false
|
|
11604
|
+
};
|
|
11605
|
+
for (let i = 0; i < maxRounds; i++) {
|
|
11606
|
+
await this.compileRound(i, metricFn, compileOptions);
|
|
11607
|
+
if (this.stats.earlyStopped) {
|
|
11608
|
+
break;
|
|
11609
|
+
}
|
|
11610
|
+
}
|
|
11611
|
+
if (this.traces.length === 0) {
|
|
11612
|
+
throw new Error(
|
|
11613
|
+
"No demonstrations found. Either provider more examples or improve the existing ones."
|
|
11614
|
+
);
|
|
11615
|
+
}
|
|
11616
|
+
const demos = groupTracesByKeys(this.traces);
|
|
11617
|
+
return {
|
|
11618
|
+
demos,
|
|
11619
|
+
stats: this.stats
|
|
11620
|
+
};
|
|
11621
|
+
}
|
|
11622
|
+
// Get optimization statistics
|
|
11623
|
+
getStats() {
|
|
11624
|
+
return this.stats;
|
|
11625
|
+
}
|
|
11626
|
+
};
|
|
11627
|
+
function groupTracesByKeys(programTraces) {
|
|
11628
|
+
const groupedTraces = /* @__PURE__ */ new Map();
|
|
11629
|
+
for (const programTrace of programTraces) {
|
|
11630
|
+
if (groupedTraces.has(programTrace.programId)) {
|
|
11631
|
+
const traces = groupedTraces.get(programTrace.programId);
|
|
11632
|
+
if (traces) {
|
|
11633
|
+
traces.push(programTrace.trace);
|
|
11634
|
+
}
|
|
11635
|
+
} else {
|
|
11636
|
+
groupedTraces.set(programTrace.programId, [programTrace.trace]);
|
|
11637
|
+
}
|
|
11638
|
+
}
|
|
11639
|
+
const programDemosArray = [];
|
|
11640
|
+
for (const [programId, traces] of groupedTraces.entries()) {
|
|
11641
|
+
programDemosArray.push({ traces, programId });
|
|
11642
|
+
}
|
|
11643
|
+
return programDemosArray;
|
|
11644
|
+
}
|
|
11645
|
+
var randomSample = (array, n) => {
|
|
11646
|
+
const clonedArray = [...array];
|
|
11647
|
+
for (let i = clonedArray.length - 1; i > 0; i--) {
|
|
11648
|
+
const j = Math.floor(Math.random() * (i + 1));
|
|
11649
|
+
const caI = clonedArray[i];
|
|
11650
|
+
const caJ = clonedArray[j];
|
|
11651
|
+
if (!caI || !caJ) {
|
|
11652
|
+
throw new Error("Invalid array elements");
|
|
11653
|
+
}
|
|
11654
|
+
;
|
|
11655
|
+
[clonedArray[i], clonedArray[j]] = [caJ, caI];
|
|
11656
|
+
}
|
|
11657
|
+
return clonedArray.slice(0, n);
|
|
11658
|
+
};
|
|
11659
|
+
|
|
11660
|
+
// dsp/optimizers/miproV2.ts
|
|
11665
11661
|
var AxMiPRO = class {
|
|
11666
11662
|
ai;
|
|
11667
11663
|
program;
|
|
@@ -11887,7 +11883,7 @@ ${dataContext}
|
|
|
11887
11883
|
const result = await this.bootstrapper.compile(metricFn, {
|
|
11888
11884
|
maxDemos: this.maxBootstrappedDemos
|
|
11889
11885
|
});
|
|
11890
|
-
return result.demos;
|
|
11886
|
+
return result.demos || [];
|
|
11891
11887
|
}
|
|
11892
11888
|
/**
|
|
11893
11889
|
* Selects labeled examples directly from the training set
|
|
@@ -12221,21 +12217,22 @@ ${dataContext}
|
|
|
12221
12217
|
* The main compile method to run MIPROv2 optimization
|
|
12222
12218
|
* @param metricFn Evaluation metric function
|
|
12223
12219
|
* @param options Optional configuration options
|
|
12224
|
-
* @returns The
|
|
12220
|
+
* @returns The optimization result
|
|
12225
12221
|
*/
|
|
12226
12222
|
async compile(metricFn, options) {
|
|
12227
|
-
|
|
12228
|
-
|
|
12223
|
+
const miproOptions = options;
|
|
12224
|
+
if (miproOptions?.auto) {
|
|
12225
|
+
this.configureAuto(miproOptions.auto);
|
|
12229
12226
|
}
|
|
12230
12227
|
const trainset = this.examples;
|
|
12231
|
-
const valset =
|
|
12228
|
+
const valset = miproOptions?.valset || this.examples.slice(0, Math.floor(this.examples.length * 0.8));
|
|
12232
12229
|
if (this.verbose) {
|
|
12233
12230
|
console.log(`Starting MIPROv2 optimization with ${this.numTrials} trials`);
|
|
12234
12231
|
console.log(
|
|
12235
12232
|
`Using ${trainset.length} examples for training and ${valset.length} for validation`
|
|
12236
12233
|
);
|
|
12237
12234
|
}
|
|
12238
|
-
if (
|
|
12235
|
+
if (miproOptions?.teacher) {
|
|
12239
12236
|
if (this.verbose) {
|
|
12240
12237
|
console.log("Using provided teacher to assist with bootstrapping");
|
|
12241
12238
|
}
|
|
@@ -12292,7 +12289,17 @@ ${dataContext}
|
|
|
12292
12289
|
bootstrappedDemos,
|
|
12293
12290
|
labeledExamples
|
|
12294
12291
|
);
|
|
12295
|
-
return
|
|
12292
|
+
return {
|
|
12293
|
+
program: this.program,
|
|
12294
|
+
demos: bootstrappedDemos
|
|
12295
|
+
};
|
|
12296
|
+
}
|
|
12297
|
+
/**
|
|
12298
|
+
* Get optimization statistics from the internal bootstrapper
|
|
12299
|
+
* @returns Optimization statistics or undefined if not available
|
|
12300
|
+
*/
|
|
12301
|
+
getStats() {
|
|
12302
|
+
return this.bootstrapper.getStats();
|
|
12296
12303
|
}
|
|
12297
12304
|
};
|
|
12298
12305
|
|