@ax-llm/ax 12.0.7 → 12.0.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +1588 -688
- package/index.cjs.map +1 -1
- package/index.d.cts +536 -191
- package/index.d.ts +536 -191
- package/index.js +1586 -692
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.cjs
CHANGED
|
@@ -81,6 +81,7 @@ __export(index_exports, {
|
|
|
81
81
|
AxAssertionError: () => AxAssertionError,
|
|
82
82
|
AxBalancer: () => AxBalancer,
|
|
83
83
|
AxBaseAI: () => AxBaseAI,
|
|
84
|
+
AxBaseOptimizer: () => AxBaseOptimizer,
|
|
84
85
|
AxBootstrapFewShot: () => AxBootstrapFewShot,
|
|
85
86
|
AxChainOfThought: () => AxChainOfThought,
|
|
86
87
|
AxDB: () => AxDB,
|
|
@@ -90,6 +91,7 @@ __export(index_exports, {
|
|
|
90
91
|
AxDBMemory: () => AxDBMemory,
|
|
91
92
|
AxDBPinecone: () => AxDBPinecone,
|
|
92
93
|
AxDBWeaviate: () => AxDBWeaviate,
|
|
94
|
+
AxDefaultCostTracker: () => AxDefaultCostTracker,
|
|
93
95
|
AxDefaultQueryRewriter: () => AxDefaultQueryRewriter,
|
|
94
96
|
AxDefaultResultReranker: () => AxDefaultResultReranker,
|
|
95
97
|
AxDockerSession: () => AxDockerSession,
|
|
@@ -159,6 +161,10 @@ __export(index_exports, {
|
|
|
159
161
|
axAITogetherDefaultConfig: () => axAITogetherDefaultConfig,
|
|
160
162
|
axBaseAIDefaultConfig: () => axBaseAIDefaultConfig,
|
|
161
163
|
axBaseAIDefaultCreativeConfig: () => axBaseAIDefaultCreativeConfig,
|
|
164
|
+
axCreateDefaultLogger: () => axCreateDefaultLogger,
|
|
165
|
+
axCreateDefaultTextLogger: () => axCreateDefaultTextLogger,
|
|
166
|
+
axCreateOptimizerLogger: () => axCreateOptimizerLogger,
|
|
167
|
+
axDefaultOptimizerLogger: () => axDefaultOptimizerLogger,
|
|
162
168
|
axGlobals: () => axGlobals,
|
|
163
169
|
axModelInfoAnthropic: () => axModelInfoAnthropic,
|
|
164
170
|
axModelInfoCohere: () => axModelInfoCohere,
|
|
@@ -494,6 +500,17 @@ var AxAIServiceAuthenticationError = class extends AxAIServiceError {
|
|
|
494
500
|
this.name = this.constructor.name;
|
|
495
501
|
}
|
|
496
502
|
};
|
|
503
|
+
async function safeReadResponseBody(response) {
|
|
504
|
+
try {
|
|
505
|
+
if (response.headers.get("content-type")?.includes("application/json")) {
|
|
506
|
+
return await response.json();
|
|
507
|
+
}
|
|
508
|
+
const clonedResponse = response.clone();
|
|
509
|
+
return await clonedResponse.text();
|
|
510
|
+
} catch (e) {
|
|
511
|
+
return `[ReadableStream - read failed: ${e.message}]`;
|
|
512
|
+
}
|
|
513
|
+
}
|
|
497
514
|
function calculateRetryDelay(attempt, config) {
|
|
498
515
|
const delay = Math.min(
|
|
499
516
|
config.maxDelayMs,
|
|
@@ -587,9 +604,15 @@ var apiCall = async (api, json) => {
|
|
|
587
604
|
});
|
|
588
605
|
clearTimeout(timeoutId);
|
|
589
606
|
if (res.status === 401 || res.status === 403) {
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
607
|
+
const responseBody = await safeReadResponseBody(res);
|
|
608
|
+
throw new AxAIServiceAuthenticationError(
|
|
609
|
+
apiUrl.href,
|
|
610
|
+
json,
|
|
611
|
+
responseBody,
|
|
612
|
+
{
|
|
613
|
+
metrics
|
|
614
|
+
}
|
|
615
|
+
);
|
|
593
616
|
}
|
|
594
617
|
if (res.status >= 400 && shouldRetry(new Error(), res.status, attempt, retryConfig)) {
|
|
595
618
|
const delay = calculateRetryDelay(attempt, retryConfig);
|
|
@@ -607,12 +630,13 @@ var apiCall = async (api, json) => {
|
|
|
607
630
|
continue;
|
|
608
631
|
}
|
|
609
632
|
if (res.status >= 400) {
|
|
633
|
+
const responseBody = await safeReadResponseBody(res);
|
|
610
634
|
throw new AxAIServiceStatusError(
|
|
611
635
|
res.status,
|
|
612
636
|
res.statusText,
|
|
613
637
|
apiUrl.href,
|
|
614
638
|
json,
|
|
615
|
-
|
|
639
|
+
responseBody,
|
|
616
640
|
{ metrics }
|
|
617
641
|
);
|
|
618
642
|
}
|
|
@@ -710,7 +734,7 @@ var apiCall = async (api, json) => {
|
|
|
710
734
|
error,
|
|
711
735
|
apiUrl.href,
|
|
712
736
|
json,
|
|
713
|
-
|
|
737
|
+
"[ReadableStream - consumed during streaming]",
|
|
714
738
|
{
|
|
715
739
|
streamMetrics
|
|
716
740
|
}
|
|
@@ -839,12 +863,12 @@ var ColorLog = class {
|
|
|
839
863
|
}
|
|
840
864
|
};
|
|
841
865
|
|
|
842
|
-
//
|
|
866
|
+
// dsp/loggers.ts
|
|
843
867
|
var colorLog = new ColorLog();
|
|
844
868
|
var defaultOutput = (message) => {
|
|
845
869
|
process.stdout.write(message);
|
|
846
870
|
};
|
|
847
|
-
var
|
|
871
|
+
var axCreateDefaultLogger = (output = defaultOutput) => {
|
|
848
872
|
return (message, options) => {
|
|
849
873
|
const tags = options?.tags ?? [];
|
|
850
874
|
let formattedMessage = message;
|
|
@@ -853,12 +877,44 @@ var createDefaultLogger = (output = defaultOutput) => {
|
|
|
853
877
|
} else if (tags.includes("success") || tags.includes("responseContent")) {
|
|
854
878
|
formattedMessage = colorLog.greenBright(formattedMessage);
|
|
855
879
|
} else if (tags.includes("functionName")) {
|
|
856
|
-
|
|
857
|
-
|
|
880
|
+
if (tags.includes("firstFunction")) {
|
|
881
|
+
formattedMessage = `
|
|
882
|
+
${colorLog.whiteBright(formattedMessage)}`;
|
|
883
|
+
} else {
|
|
884
|
+
formattedMessage = `${colorLog.whiteBright(formattedMessage)}`;
|
|
885
|
+
}
|
|
886
|
+
} else if (tags.includes("systemContent") || tags.includes("assistantContent")) {
|
|
858
887
|
formattedMessage = colorLog.blueBright(formattedMessage);
|
|
859
888
|
} else if (tags.includes("warning") || tags.includes("discovery")) {
|
|
860
889
|
formattedMessage = colorLog.yellow(formattedMessage);
|
|
890
|
+
} else if (tags.includes("functionArg")) {
|
|
891
|
+
formattedMessage = "";
|
|
892
|
+
}
|
|
893
|
+
if (tags.includes("responseStart") || tags.includes("systemStart") || tags.includes("userStart")) {
|
|
894
|
+
formattedMessage = `
|
|
895
|
+
${formattedMessage}`;
|
|
896
|
+
} else if (tags.includes("responseEnd") || tags.includes("systemEnd") || tags.includes("userEnd")) {
|
|
897
|
+
formattedMessage = `${formattedMessage}
|
|
898
|
+
`;
|
|
899
|
+
} else if (tags.includes("assistantStart")) {
|
|
900
|
+
formattedMessage = `
|
|
901
|
+
${formattedMessage}
|
|
902
|
+
`;
|
|
903
|
+
} else if (tags.includes("error")) {
|
|
904
|
+
formattedMessage = `
|
|
905
|
+
${formattedMessage}
|
|
906
|
+
`;
|
|
907
|
+
} else if (tags.includes("functionEnd")) {
|
|
908
|
+
formattedMessage = `
|
|
909
|
+
`;
|
|
861
910
|
}
|
|
911
|
+
output(formattedMessage);
|
|
912
|
+
};
|
|
913
|
+
};
|
|
914
|
+
var axCreateDefaultTextLogger = (output = defaultOutput) => {
|
|
915
|
+
return (message, options) => {
|
|
916
|
+
const tags = options?.tags ?? [];
|
|
917
|
+
let formattedMessage = message;
|
|
862
918
|
if (tags.includes("responseStart") || tags.includes("systemStart") || tags.includes("userStart")) {
|
|
863
919
|
formattedMessage = `
|
|
864
920
|
${formattedMessage}`;
|
|
@@ -880,7 +936,137 @@ ${formattedMessage}
|
|
|
880
936
|
output(formattedMessage);
|
|
881
937
|
};
|
|
882
938
|
};
|
|
883
|
-
var
|
|
939
|
+
var axCreateOptimizerLogger = (output = (msg) => process.stdout.write(msg)) => {
|
|
940
|
+
const baseLogger = axCreateDefaultLogger(output);
|
|
941
|
+
let isFirstPhase = true;
|
|
942
|
+
return (message, options) => {
|
|
943
|
+
const tags = options?.tags ?? [];
|
|
944
|
+
let formattedMessage = message;
|
|
945
|
+
if (tags.includes("optimizer")) {
|
|
946
|
+
if (tags.includes("start")) {
|
|
947
|
+
const trialsMatch = message.match(/with (\d+) trials?/) || message.match(/(\d+) trials?/);
|
|
948
|
+
const optimizerMatch = message.match(
|
|
949
|
+
/(MIPROv2|BootstrapFewshot|[A-Z][a-zA-Z]+)/
|
|
950
|
+
);
|
|
951
|
+
const optimizerName = optimizerMatch ? optimizerMatch[1] : "Optimizer";
|
|
952
|
+
if (trialsMatch && trialsMatch[1]) {
|
|
953
|
+
formattedMessage = `
|
|
954
|
+
\u250C\u2500 ${optimizerName} optimization (${trialsMatch[1]} trials)
|
|
955
|
+
`;
|
|
956
|
+
} else {
|
|
957
|
+
formattedMessage = `
|
|
958
|
+
\u250C\u2500 ${optimizerName} optimization
|
|
959
|
+
`;
|
|
960
|
+
}
|
|
961
|
+
isFirstPhase = true;
|
|
962
|
+
} else if (tags.includes("config")) {
|
|
963
|
+
if (message.includes("examples") && message.includes("training")) {
|
|
964
|
+
const match = message.match(
|
|
965
|
+
/(\d+) examples for training and (\d+) for validation/
|
|
966
|
+
) || message.match(/(\d+) training.*?(\d+) validation/);
|
|
967
|
+
if (match && match[1] && match[2]) {
|
|
968
|
+
formattedMessage = `\u2502 Dataset: ${match[1]} training, ${match[2]} validation
|
|
969
|
+
`;
|
|
970
|
+
} else {
|
|
971
|
+
const simpleMatch = message.match(/(\d+) examples/);
|
|
972
|
+
if (simpleMatch && simpleMatch[1]) {
|
|
973
|
+
formattedMessage = `\u2502 Dataset: ${simpleMatch[1]} examples
|
|
974
|
+
`;
|
|
975
|
+
}
|
|
976
|
+
}
|
|
977
|
+
} else if (message.includes("teacher")) {
|
|
978
|
+
formattedMessage = `\u2502 Using teacher model
|
|
979
|
+
`;
|
|
980
|
+
} else {
|
|
981
|
+
formattedMessage = `\u2502 ${message}
|
|
982
|
+
`;
|
|
983
|
+
}
|
|
984
|
+
} else if (tags.includes("phase")) {
|
|
985
|
+
if (isFirstPhase) {
|
|
986
|
+
formattedMessage = `\u251C\u2500 ${message}
|
|
987
|
+
`;
|
|
988
|
+
isFirstPhase = false;
|
|
989
|
+
} else {
|
|
990
|
+
formattedMessage = `\u251C\u2500 ${message}
|
|
991
|
+
`;
|
|
992
|
+
}
|
|
993
|
+
} else if (tags.includes("result")) {
|
|
994
|
+
if (message.includes("Generated") || message.includes("Selected")) {
|
|
995
|
+
const match = message.match(/(\d+)/);
|
|
996
|
+
if (match && match[1]) {
|
|
997
|
+
formattedMessage = `\u2502 \u2713 ${message}
|
|
998
|
+
`;
|
|
999
|
+
} else {
|
|
1000
|
+
formattedMessage = `\u2502 \u2713 ${message}
|
|
1001
|
+
`;
|
|
1002
|
+
}
|
|
1003
|
+
} else if (message.includes("configuration")) {
|
|
1004
|
+
formattedMessage = `\u2502 Applied best configuration
|
|
1005
|
+
`;
|
|
1006
|
+
} else {
|
|
1007
|
+
formattedMessage = `\u2502 ${message}
|
|
1008
|
+
`;
|
|
1009
|
+
}
|
|
1010
|
+
} else if (tags.includes("progress")) {
|
|
1011
|
+
formattedMessage = `\u2502 ${message}
|
|
1012
|
+
`;
|
|
1013
|
+
} else if (tags.includes("complete")) {
|
|
1014
|
+
const scoreMatch = message.match(/(score|performance):\s*([\d.]+)/);
|
|
1015
|
+
if (scoreMatch && scoreMatch[2]) {
|
|
1016
|
+
const score = parseFloat(scoreMatch[2]);
|
|
1017
|
+
const percentage = score <= 1 ? (score * 100).toFixed(1) + "%" : score.toFixed(3);
|
|
1018
|
+
formattedMessage = `\u251C\u2500 Complete! Best: ${percentage}
|
|
1019
|
+
`;
|
|
1020
|
+
} else if (message.includes("Bootstrap")) {
|
|
1021
|
+
formattedMessage = `\u251C\u2500 ${message}
|
|
1022
|
+
`;
|
|
1023
|
+
} else {
|
|
1024
|
+
formattedMessage = `\u251C\u2500 Optimization complete
|
|
1025
|
+
`;
|
|
1026
|
+
}
|
|
1027
|
+
} else if (tags.includes("checkpoint")) {
|
|
1028
|
+
if (message.includes("Resuming")) {
|
|
1029
|
+
formattedMessage = `\u2502 ${message}
|
|
1030
|
+
`;
|
|
1031
|
+
} else {
|
|
1032
|
+
const match = message.match(/checkpoint:\s*(.+)/) || message.match(/Saved\s+(.+)/);
|
|
1033
|
+
if (match && match[1]) {
|
|
1034
|
+
formattedMessage = `\u2514\u2500 Saved: ${match[1]}
|
|
1035
|
+
`;
|
|
1036
|
+
} else {
|
|
1037
|
+
formattedMessage = `\u2514\u2500 Checkpoint saved
|
|
1038
|
+
`;
|
|
1039
|
+
}
|
|
1040
|
+
}
|
|
1041
|
+
}
|
|
1042
|
+
} else if (tags.includes("discovery")) {
|
|
1043
|
+
if (message.includes("Found") && message.includes("examples")) {
|
|
1044
|
+
const match = message.match(/Found (\d+)/);
|
|
1045
|
+
if (match && match[1]) {
|
|
1046
|
+
formattedMessage = `\u2502 Found ${match[1]} examples
|
|
1047
|
+
`;
|
|
1048
|
+
}
|
|
1049
|
+
}
|
|
1050
|
+
}
|
|
1051
|
+
if (tags.includes("error")) {
|
|
1052
|
+
formattedMessage = `
|
|
1053
|
+
\u2717 ${message}
|
|
1054
|
+
`;
|
|
1055
|
+
} else if (tags.includes("warning")) {
|
|
1056
|
+
formattedMessage = `
|
|
1057
|
+
\u26A0 ${message}
|
|
1058
|
+
`;
|
|
1059
|
+
} else if (tags.includes("success") && !tags.includes("optimizer")) {
|
|
1060
|
+
formattedMessage = `\u2713 ${message}
|
|
1061
|
+
`;
|
|
1062
|
+
}
|
|
1063
|
+
baseLogger(formattedMessage, options);
|
|
1064
|
+
};
|
|
1065
|
+
};
|
|
1066
|
+
var axDefaultOptimizerLogger = axCreateOptimizerLogger();
|
|
1067
|
+
|
|
1068
|
+
// ai/debug.ts
|
|
1069
|
+
var defaultLogger = axCreateDefaultLogger();
|
|
884
1070
|
var formatChatMessage = (msg, hideContent, hideSystemPrompt) => {
|
|
885
1071
|
switch (msg.role) {
|
|
886
1072
|
case "system":
|
|
@@ -957,9 +1143,14 @@ var logResponseResult = (r, logger = defaultLogger) => {
|
|
|
957
1143
|
if (r.functionCalls && r.functionCalls.length > 0) {
|
|
958
1144
|
for (const [i, f2] of r.functionCalls.entries()) {
|
|
959
1145
|
if (f2.function.name) {
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
1146
|
+
const tags = ["functionName"];
|
|
1147
|
+
if (i === 0) {
|
|
1148
|
+
tags.push("firstFunction");
|
|
1149
|
+
}
|
|
1150
|
+
if (r.functionCalls.length > 1) {
|
|
1151
|
+
tags.push("multipleFunctions");
|
|
1152
|
+
}
|
|
1153
|
+
logger(`[${i + 1}] ${f2.function.name}`, { tags });
|
|
963
1154
|
}
|
|
964
1155
|
if (f2.function.params) {
|
|
965
1156
|
const params = typeof f2.function.params === "string" ? f2.function.params : JSON.stringify(f2.function.params, null, 2);
|
|
@@ -1630,6 +1821,16 @@ function validateAxMessageArray(values) {
|
|
|
1630
1821
|
function validateChatPrompt(chatPrompt) {
|
|
1631
1822
|
for (let i = 0; i < chatPrompt.length; i++) {
|
|
1632
1823
|
const message = chatPrompt[i];
|
|
1824
|
+
if (message && "functionCalls" in message && Array.isArray(message.functionCalls) && message.functionCalls.length === 0) {
|
|
1825
|
+
throw new Error(
|
|
1826
|
+
`Chat prompt validation failed: Message at index ${i} has empty functionCalls`
|
|
1827
|
+
);
|
|
1828
|
+
}
|
|
1829
|
+
if (message && "content" in message && Array.isArray(message.content) && message.content.length === 0) {
|
|
1830
|
+
throw new Error(
|
|
1831
|
+
`Chat prompt validation failed: Message at index ${i} has empty content`
|
|
1832
|
+
);
|
|
1833
|
+
}
|
|
1633
1834
|
if (message && "content" in message && typeof message.content === "string" && message.content.trim() === "") {
|
|
1634
1835
|
throw new Error(
|
|
1635
1836
|
`Chat prompt validation failed: Message at index ${i} has empty content`
|
|
@@ -5953,9 +6154,11 @@ var updateProgressBar = (current, total, success, elapsedTime, msg, progressBarW
|
|
|
5953
6154
|
const emptyBarLength = progressBarWidth - filledBarLength;
|
|
5954
6155
|
const filledBar = colorLog3.blueBright("\u2588".repeat(filledBarLength));
|
|
5955
6156
|
const emptyBar = " ".repeat(emptyBarLength);
|
|
5956
|
-
const
|
|
6157
|
+
const successRate = total > 0 ? (success / total * 100).toFixed(1) : "0.0";
|
|
6158
|
+
const friendlyMsg = msg.includes("Running MIPROv2 optimization") ? "Testing prompt variations" : msg.includes("Tuning Prompt") ? "Generating training examples" : msg;
|
|
5957
6159
|
process.stdout.write(
|
|
5958
|
-
`\
|
|
6160
|
+
`\u2502 ${friendlyMsg}: ${current}/${total} (${colorLog3.yellow(percentage)}%) |${filledBar}${emptyBar}| Success rate: ${colorLog3.greenBright(successRate)}%
|
|
6161
|
+
`
|
|
5959
6162
|
);
|
|
5960
6163
|
};
|
|
5961
6164
|
var validateValue = (field, value) => {
|
|
@@ -6162,19 +6365,15 @@ function matchesContent(content, prefix, startIndex = 0, prefixCache = globalPre
|
|
|
6162
6365
|
if (!prefixCache.get(prefix)) {
|
|
6163
6366
|
prefixCache.set(prefix, prefixes);
|
|
6164
6367
|
}
|
|
6165
|
-
|
|
6166
|
-
|
|
6167
|
-
);
|
|
6168
|
-
for (let i = 0; i < prefixes.length - 1; i++) {
|
|
6368
|
+
let longestPartialMatch = -1;
|
|
6369
|
+
for (let i = prefixes.length - 1; i >= 0; i--) {
|
|
6169
6370
|
const partialPrefix = prefixes[i];
|
|
6170
|
-
if (partialPrefix
|
|
6171
|
-
|
|
6172
|
-
|
|
6173
|
-
if (partialPrefix && contentEnd.endsWith(partialPrefix)) {
|
|
6174
|
-
return -2;
|
|
6371
|
+
if (content.endsWith(partialPrefix)) {
|
|
6372
|
+
longestPartialMatch = i;
|
|
6373
|
+
break;
|
|
6175
6374
|
}
|
|
6176
6375
|
}
|
|
6177
|
-
return -1;
|
|
6376
|
+
return longestPartialMatch >= 0 ? -2 : -1;
|
|
6178
6377
|
}
|
|
6179
6378
|
var formatTime = (ms) => {
|
|
6180
6379
|
const seconds = Math.floor(ms / 1e3);
|
|
@@ -6197,11 +6396,10 @@ var updateDetailedProgress = (roundIndex, current, total, elapsedTime, example,
|
|
|
6197
6396
|
process.stdout.write("\r\x1B[K");
|
|
6198
6397
|
const percentage = (current / total * 100).toFixed(1);
|
|
6199
6398
|
const formattedTime = formatTime(elapsedTime);
|
|
6200
|
-
const itemsPerSecond = elapsedTime > 0 ? (current / elapsedTime * 1e3).toFixed(2) : "0.00";
|
|
6201
6399
|
const eta = calculateETA(current, total, elapsedTime);
|
|
6202
|
-
let output = `
|
|
6400
|
+
let output = `Training round ${roundIndex + 1}/${configInfo.maxRounds}: ${current}/${total} (${percentage}%) [${formattedTime}, ETA: ${eta}]`;
|
|
6203
6401
|
const successRate = stats.totalCalls > 0 ? stats.successfulDemos / stats.totalCalls * 100 : 0;
|
|
6204
|
-
output += ` | Success: ${stats.successfulDemos}/${stats.totalCalls}
|
|
6402
|
+
output += ` | Success rate: ${successRate.toFixed(1)}% (${stats.successfulDemos}/${stats.totalCalls})`;
|
|
6205
6403
|
if (configInfo.verboseMode || configInfo.debugMode) {
|
|
6206
6404
|
if (configInfo.costMonitoring) {
|
|
6207
6405
|
output += `
|
|
@@ -6327,7 +6525,7 @@ ${outputFields}`);
|
|
|
6327
6525
|
content: systemContent
|
|
6328
6526
|
};
|
|
6329
6527
|
if (Array.isArray(values)) {
|
|
6330
|
-
let
|
|
6528
|
+
let messages = [];
|
|
6331
6529
|
const history = values;
|
|
6332
6530
|
for (const [index, message] of history.entries()) {
|
|
6333
6531
|
let content;
|
|
@@ -6347,7 +6545,7 @@ ${outputFields}`);
|
|
|
6347
6545
|
);
|
|
6348
6546
|
}
|
|
6349
6547
|
if (message.role === "user") {
|
|
6350
|
-
|
|
6548
|
+
messages.push({ role: "user", content });
|
|
6351
6549
|
continue;
|
|
6352
6550
|
}
|
|
6353
6551
|
if (message.role !== "assistant") {
|
|
@@ -6358,9 +6556,9 @@ ${outputFields}`);
|
|
|
6358
6556
|
"Assistant message cannot contain non-text content like images, files,etc"
|
|
6359
6557
|
);
|
|
6360
6558
|
}
|
|
6361
|
-
|
|
6559
|
+
messages.push({ role: "assistant", content });
|
|
6362
6560
|
}
|
|
6363
|
-
return [systemPrompt, ...
|
|
6561
|
+
return [systemPrompt, ...messages];
|
|
6364
6562
|
}
|
|
6365
6563
|
const userContent = this.renderSingleValueUserContent(
|
|
6366
6564
|
values,
|
|
@@ -6813,9 +7011,9 @@ var formatDateWithTimezone = (date) => {
|
|
|
6813
7011
|
};
|
|
6814
7012
|
|
|
6815
7013
|
// dsp/extract.ts
|
|
6816
|
-
var extractValues = (sig, values, content) => {
|
|
7014
|
+
var extractValues = (sig, values, content, strictMode = false) => {
|
|
6817
7015
|
const xstate = { extractedFields: [], streamedIndex: {}, s: -1 };
|
|
6818
|
-
streamingExtractValues(sig, values, xstate, content);
|
|
7016
|
+
streamingExtractValues(sig, values, xstate, content, strictMode);
|
|
6819
7017
|
streamingExtractFinalValue(sig, values, xstate, content);
|
|
6820
7018
|
for (const field of sig.getOutputFields()) {
|
|
6821
7019
|
if (field.isInternal) {
|
|
@@ -6823,10 +7021,9 @@ var extractValues = (sig, values, content) => {
|
|
|
6823
7021
|
}
|
|
6824
7022
|
}
|
|
6825
7023
|
};
|
|
6826
|
-
var checkMissingRequiredFields = (xstate, values,
|
|
7024
|
+
var checkMissingRequiredFields = (xstate, values, outputFields) => {
|
|
6827
7025
|
const missingFields = [];
|
|
6828
|
-
for (
|
|
6829
|
-
const field = xstate.extractedFields[i];
|
|
7026
|
+
for (const field of outputFields) {
|
|
6830
7027
|
if (field && !field.isOptional && values[field.name] === void 0) {
|
|
6831
7028
|
missingFields.push(field);
|
|
6832
7029
|
}
|
|
@@ -6838,23 +7035,34 @@ var checkMissingRequiredFields = (xstate, values, currentIndex) => {
|
|
|
6838
7035
|
});
|
|
6839
7036
|
}
|
|
6840
7037
|
};
|
|
6841
|
-
var streamingExtractValues = (sig, values, xstate, content,
|
|
7038
|
+
var streamingExtractValues = (sig, values, xstate, content, strictMode = false) => {
|
|
6842
7039
|
const fields = sig.getOutputFields();
|
|
7040
|
+
let expectedField;
|
|
6843
7041
|
for (const [index, field] of fields.entries()) {
|
|
7042
|
+
if (index === xstate.currFieldIndex) {
|
|
7043
|
+
continue;
|
|
7044
|
+
}
|
|
6844
7045
|
if (field.name in values) {
|
|
6845
7046
|
continue;
|
|
6846
7047
|
}
|
|
6847
7048
|
const isFirst = xstate.extractedFields.length === 0;
|
|
6848
7049
|
const prefix = (isFirst ? "" : "\n") + field.title + ":";
|
|
6849
7050
|
let e = matchesContent(content, prefix, xstate.s);
|
|
7051
|
+
let prefixLen = prefix.length;
|
|
6850
7052
|
switch (e) {
|
|
6851
7053
|
case -1:
|
|
6852
|
-
if (
|
|
7054
|
+
if (!strictMode && fields.length === 1 && xstate.currField === void 0) {
|
|
7055
|
+
prefixLen = 0;
|
|
7056
|
+
e = 0;
|
|
7057
|
+
break;
|
|
7058
|
+
}
|
|
7059
|
+
if (xstate.currField === void 0 && !field.isOptional) {
|
|
6853
7060
|
throw new ValidationError({
|
|
6854
|
-
message: "Required field not found",
|
|
7061
|
+
message: "Expected (Required) field not found",
|
|
6855
7062
|
fields: [field]
|
|
6856
7063
|
});
|
|
6857
7064
|
}
|
|
7065
|
+
expectedField = field.isOptional ? void 0 : field;
|
|
6858
7066
|
continue;
|
|
6859
7067
|
// Field is not found, continue to the next field
|
|
6860
7068
|
case -2:
|
|
@@ -6867,7 +7075,12 @@ var streamingExtractValues = (sig, values, xstate, content, streamingValidation
|
|
|
6867
7075
|
xstate.inBlock = true;
|
|
6868
7076
|
return true;
|
|
6869
7077
|
}
|
|
6870
|
-
|
|
7078
|
+
if (expectedField && expectedField.name !== field.name) {
|
|
7079
|
+
throw new ValidationError({
|
|
7080
|
+
message: "Expected (Required) field not found",
|
|
7081
|
+
fields: [expectedField]
|
|
7082
|
+
});
|
|
7083
|
+
}
|
|
6871
7084
|
if (xstate.currField) {
|
|
6872
7085
|
const val = content.substring(xstate.s, e).trim();
|
|
6873
7086
|
const parsedValue = validateAndParseFieldValue(xstate.currField, val);
|
|
@@ -6880,7 +7093,6 @@ var streamingExtractValues = (sig, values, xstate, content, streamingValidation
|
|
|
6880
7093
|
xstate.prevFields = [{ field: xstate.currField, s: xstate.s, e }];
|
|
6881
7094
|
}
|
|
6882
7095
|
}
|
|
6883
|
-
checkMissingRequiredFields(xstate, values, index);
|
|
6884
7096
|
xstate.s = e + prefixLen;
|
|
6885
7097
|
xstate.currField = field;
|
|
6886
7098
|
xstate.currFieldIndex = index;
|
|
@@ -6900,8 +7112,7 @@ var streamingExtractFinalValue = (sig, values, xstate, content) => {
|
|
|
6900
7112
|
values[xstate.currField.name] = parsedValue;
|
|
6901
7113
|
}
|
|
6902
7114
|
}
|
|
6903
|
-
|
|
6904
|
-
checkMissingRequiredFields(xstate, values, sigFields.length);
|
|
7115
|
+
checkMissingRequiredFields(xstate, values, sig.getOutputFields());
|
|
6905
7116
|
};
|
|
6906
7117
|
var convertValueToType = (field, val, required = false) => {
|
|
6907
7118
|
switch (field.type?.name) {
|
|
@@ -7538,8 +7749,9 @@ var AxInstanceRegistry = class {
|
|
|
7538
7749
|
this.reg.add(instance);
|
|
7539
7750
|
}
|
|
7540
7751
|
*[Symbol.iterator]() {
|
|
7541
|
-
|
|
7542
|
-
|
|
7752
|
+
const items = Array.from(this.reg);
|
|
7753
|
+
for (let i = 0; i < items.length; i++) {
|
|
7754
|
+
yield items[i];
|
|
7543
7755
|
}
|
|
7544
7756
|
}
|
|
7545
7757
|
};
|
|
@@ -8477,7 +8689,7 @@ var AxSignature = class _AxSignature {
|
|
|
8477
8689
|
this.getOutputFields().forEach((field) => {
|
|
8478
8690
|
validateField(field, "output");
|
|
8479
8691
|
});
|
|
8480
|
-
this.sigHash = (0, import_crypto3.createHash)("sha256").update(
|
|
8692
|
+
this.sigHash = (0, import_crypto3.createHash)("sha256").update(JSON.stringify(this.inputFields)).update(JSON.stringify(this.outputFields)).digest("hex");
|
|
8481
8693
|
this.sigString = renderSignature(
|
|
8482
8694
|
this.description,
|
|
8483
8695
|
this.inputFields,
|
|
@@ -8798,7 +9010,7 @@ var AxProgramWithSignature = class {
|
|
|
8798
9010
|
this.signature.validate();
|
|
8799
9011
|
this.sigHash = this.signature?.hash();
|
|
8800
9012
|
this.children = new AxInstanceRegistry();
|
|
8801
|
-
this.key = { id: this.
|
|
9013
|
+
this.key = { id: this.signature.hash() };
|
|
8802
9014
|
}
|
|
8803
9015
|
getSignature() {
|
|
8804
9016
|
return this.signature;
|
|
@@ -8818,8 +9030,8 @@ var AxProgramWithSignature = class {
|
|
|
8818
9030
|
}
|
|
8819
9031
|
setId(id) {
|
|
8820
9032
|
this.key = { id, custom: true };
|
|
8821
|
-
for (const child of this.children) {
|
|
8822
|
-
child
|
|
9033
|
+
for (const child of Array.from(this.children)) {
|
|
9034
|
+
child?.setParentId(id);
|
|
8823
9035
|
}
|
|
8824
9036
|
}
|
|
8825
9037
|
setParentId(parentId) {
|
|
@@ -8832,8 +9044,8 @@ var AxProgramWithSignature = class {
|
|
|
8832
9044
|
if (!("programId" in examples)) {
|
|
8833
9045
|
return;
|
|
8834
9046
|
}
|
|
8835
|
-
for (const child of this.children) {
|
|
8836
|
-
child
|
|
9047
|
+
for (const child of Array.from(this.children)) {
|
|
9048
|
+
child?.setExamples(examples, options);
|
|
8837
9049
|
}
|
|
8838
9050
|
}
|
|
8839
9051
|
_setExamples(examples, options) {
|
|
@@ -8866,30 +9078,37 @@ var AxProgramWithSignature = class {
|
|
|
8866
9078
|
if (this.trace) {
|
|
8867
9079
|
traces.push({ trace: this.trace, programId: this.key.id });
|
|
8868
9080
|
}
|
|
8869
|
-
for (const child of this.children) {
|
|
8870
|
-
const _traces = child
|
|
8871
|
-
traces = [...traces, ..._traces];
|
|
9081
|
+
for (const child of Array.from(this.children)) {
|
|
9082
|
+
const _traces = child?.getTraces();
|
|
9083
|
+
traces = [...traces, ..._traces ?? []];
|
|
8872
9084
|
}
|
|
8873
9085
|
return traces;
|
|
8874
9086
|
}
|
|
8875
9087
|
getUsage() {
|
|
8876
9088
|
let usage = [...this.usage ?? []];
|
|
8877
|
-
for (const child of this.children) {
|
|
8878
|
-
const cu = child
|
|
8879
|
-
usage = [...usage, ...cu];
|
|
9089
|
+
for (const child of Array.from(this.children)) {
|
|
9090
|
+
const cu = child?.getUsage();
|
|
9091
|
+
usage = [...usage, ...cu ?? []];
|
|
8880
9092
|
}
|
|
8881
9093
|
return mergeProgramUsage(usage);
|
|
8882
9094
|
}
|
|
8883
9095
|
resetUsage() {
|
|
8884
9096
|
this.usage = [];
|
|
8885
|
-
for (const child of this.children) {
|
|
8886
|
-
child
|
|
9097
|
+
for (const child of Array.from(this.children)) {
|
|
9098
|
+
child?.resetUsage();
|
|
8887
9099
|
}
|
|
8888
9100
|
}
|
|
8889
9101
|
setDemos(demos) {
|
|
9102
|
+
const hasChildren = Array.from(this.children).length > 0;
|
|
9103
|
+
const hasMatchingDemo = demos.some((demo) => demo.programId === this.key.id);
|
|
9104
|
+
if (hasChildren && !hasMatchingDemo) {
|
|
9105
|
+
throw new Error(
|
|
9106
|
+
`Program with id '${this.key.id}' has children but no matching programId found in demos`
|
|
9107
|
+
);
|
|
9108
|
+
}
|
|
8890
9109
|
this.demos = demos.filter((v) => v.programId === this.key.id).map((v) => v.traces).flat();
|
|
8891
|
-
for (const child of this.children) {
|
|
8892
|
-
child
|
|
9110
|
+
for (const child of Array.from(this.children)) {
|
|
9111
|
+
child?.setDemos(demos);
|
|
8893
9112
|
}
|
|
8894
9113
|
}
|
|
8895
9114
|
};
|
|
@@ -8917,8 +9136,8 @@ var AxProgram = class {
|
|
|
8917
9136
|
}
|
|
8918
9137
|
setId(id) {
|
|
8919
9138
|
this.key = { id, custom: true };
|
|
8920
|
-
for (const child of this.children) {
|
|
8921
|
-
child
|
|
9139
|
+
for (const child of Array.from(this.children)) {
|
|
9140
|
+
child?.setParentId(id);
|
|
8922
9141
|
}
|
|
8923
9142
|
}
|
|
8924
9143
|
setParentId(parentId) {
|
|
@@ -8930,8 +9149,8 @@ var AxProgram = class {
|
|
|
8930
9149
|
if (!("programId" in examples)) {
|
|
8931
9150
|
return;
|
|
8932
9151
|
}
|
|
8933
|
-
for (const child of this.children) {
|
|
8934
|
-
child
|
|
9152
|
+
for (const child of Array.from(this.children)) {
|
|
9153
|
+
child?.setExamples(examples, options);
|
|
8935
9154
|
}
|
|
8936
9155
|
}
|
|
8937
9156
|
getTraces() {
|
|
@@ -8939,29 +9158,36 @@ var AxProgram = class {
|
|
|
8939
9158
|
if (this.trace) {
|
|
8940
9159
|
traces.push({ trace: this.trace, programId: this.key.id });
|
|
8941
9160
|
}
|
|
8942
|
-
for (const child of this.children) {
|
|
8943
|
-
const _traces = child
|
|
8944
|
-
traces = [...traces, ..._traces];
|
|
9161
|
+
for (const child of Array.from(this.children)) {
|
|
9162
|
+
const _traces = child?.getTraces();
|
|
9163
|
+
traces = [...traces, ..._traces ?? []];
|
|
8945
9164
|
}
|
|
8946
9165
|
return traces;
|
|
8947
9166
|
}
|
|
8948
9167
|
getUsage() {
|
|
8949
9168
|
let usage = [...this.usage ?? []];
|
|
8950
|
-
for (const child of this.children) {
|
|
8951
|
-
const cu = child
|
|
8952
|
-
usage = [...usage, ...cu];
|
|
9169
|
+
for (const child of Array.from(this.children)) {
|
|
9170
|
+
const cu = child?.getUsage();
|
|
9171
|
+
usage = [...usage, ...cu ?? []];
|
|
8953
9172
|
}
|
|
8954
9173
|
return mergeProgramUsage(usage);
|
|
8955
9174
|
}
|
|
8956
9175
|
resetUsage() {
|
|
8957
9176
|
this.usage = [];
|
|
8958
|
-
for (const child of this.children) {
|
|
8959
|
-
child
|
|
9177
|
+
for (const child of Array.from(this.children)) {
|
|
9178
|
+
child?.resetUsage();
|
|
8960
9179
|
}
|
|
8961
9180
|
}
|
|
8962
9181
|
setDemos(demos) {
|
|
8963
|
-
|
|
8964
|
-
|
|
9182
|
+
const hasChildren = Array.from(this.children).length > 0;
|
|
9183
|
+
const hasMatchingDemo = demos.some((demo) => demo.programId === this.key.id);
|
|
9184
|
+
if (hasChildren && !hasMatchingDemo) {
|
|
9185
|
+
throw new Error(
|
|
9186
|
+
`Program with id '${this.key.id}' has children but no matching programId found in demos`
|
|
9187
|
+
);
|
|
9188
|
+
}
|
|
9189
|
+
for (const child of Array.from(this.children)) {
|
|
9190
|
+
child?.setDemos(demos);
|
|
8965
9191
|
}
|
|
8966
9192
|
}
|
|
8967
9193
|
};
|
|
@@ -9089,7 +9315,7 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
9089
9315
|
traceContext
|
|
9090
9316
|
}) {
|
|
9091
9317
|
const { sessionId, traceId, functions: _functions } = options ?? {};
|
|
9092
|
-
const
|
|
9318
|
+
const strictMode = options?.strictMode ?? false;
|
|
9093
9319
|
const model = options.model;
|
|
9094
9320
|
const functions = _functions?.map((f2) => "toFunction" in f2 ? f2.toFunction() : f2)?.flat();
|
|
9095
9321
|
const res = await this.forwardSendRequest({
|
|
@@ -9108,7 +9334,7 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
9108
9334
|
traceId,
|
|
9109
9335
|
sessionId,
|
|
9110
9336
|
functions,
|
|
9111
|
-
|
|
9337
|
+
strictMode,
|
|
9112
9338
|
span
|
|
9113
9339
|
});
|
|
9114
9340
|
this.getLogger(ai, options)?.("", { tags: ["responseEnd"] });
|
|
@@ -9121,7 +9347,8 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
9121
9347
|
traceId,
|
|
9122
9348
|
sessionId,
|
|
9123
9349
|
functions,
|
|
9124
|
-
span
|
|
9350
|
+
span,
|
|
9351
|
+
strictMode
|
|
9125
9352
|
});
|
|
9126
9353
|
}
|
|
9127
9354
|
}
|
|
@@ -9133,10 +9360,9 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
9133
9360
|
sessionId,
|
|
9134
9361
|
traceId,
|
|
9135
9362
|
functions,
|
|
9136
|
-
|
|
9363
|
+
strictMode,
|
|
9137
9364
|
span
|
|
9138
9365
|
}) {
|
|
9139
|
-
const streamingValidation = fastFail ?? ai.getFeatures(model).functionCot !== true;
|
|
9140
9366
|
const functionCalls = [];
|
|
9141
9367
|
this.values = {};
|
|
9142
9368
|
const xstate = {
|
|
@@ -9187,7 +9413,7 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
9187
9413
|
this.values,
|
|
9188
9414
|
xstate,
|
|
9189
9415
|
content,
|
|
9190
|
-
|
|
9416
|
+
strictMode
|
|
9191
9417
|
);
|
|
9192
9418
|
if (skip) {
|
|
9193
9419
|
continue;
|
|
@@ -9288,7 +9514,8 @@ Content: ${content}`
|
|
|
9288
9514
|
sessionId,
|
|
9289
9515
|
traceId,
|
|
9290
9516
|
functions,
|
|
9291
|
-
span
|
|
9517
|
+
span,
|
|
9518
|
+
strictMode
|
|
9292
9519
|
}) {
|
|
9293
9520
|
this.values = {};
|
|
9294
9521
|
let results = res.results ?? [];
|
|
@@ -9322,7 +9549,7 @@ Content: ${content}`
|
|
|
9322
9549
|
if (result.thought && result.thought.length > 0) {
|
|
9323
9550
|
this.values[this.thoughtFieldName] = result.thought;
|
|
9324
9551
|
}
|
|
9325
|
-
extractValues(this.signature, this.values, result.content);
|
|
9552
|
+
extractValues(this.signature, this.values, result.content, strictMode);
|
|
9326
9553
|
await assertAssertions(this.asserts, this.values);
|
|
9327
9554
|
if (this.fieldProcessors.length) {
|
|
9328
9555
|
await processFieldProcessors(
|
|
@@ -9494,8 +9721,7 @@ Content: ${result.content}`
|
|
|
9494
9721
|
...options?.thinkingTokenBudget ? { thinking_token_budget: options.thinkingTokenBudget } : {},
|
|
9495
9722
|
...options?.showThoughts ? { show_thoughts: options.showThoughts } : {},
|
|
9496
9723
|
...options?.maxSteps ? { max_steps: options.maxSteps } : {},
|
|
9497
|
-
...options?.maxRetries ? { max_retries: options.maxRetries } : {}
|
|
9498
|
-
...options?.fastFail ? { fast_fail: options.fastFail } : {}
|
|
9724
|
+
...options?.maxRetries ? { max_retries: options.maxRetries } : {}
|
|
9499
9725
|
};
|
|
9500
9726
|
const traceLabel = options.traceLabel ?? this.options?.traceLabel;
|
|
9501
9727
|
const spanName = traceLabel ? `${traceLabel} (AxGen)` : "AxGen";
|
|
@@ -9689,7 +9915,9 @@ var AxAgent = class {
|
|
|
9689
9915
|
description: definition ?? description
|
|
9690
9916
|
});
|
|
9691
9917
|
for (const agent of agents ?? []) {
|
|
9692
|
-
this.program.register(
|
|
9918
|
+
this.program.register(
|
|
9919
|
+
agent
|
|
9920
|
+
);
|
|
9693
9921
|
}
|
|
9694
9922
|
this.name = name;
|
|
9695
9923
|
this.func = {
|
|
@@ -10193,98 +10421,825 @@ function validateModels2(services) {
|
|
|
10193
10421
|
}
|
|
10194
10422
|
}
|
|
10195
10423
|
|
|
10196
|
-
//
|
|
10197
|
-
var
|
|
10198
|
-
|
|
10199
|
-
|
|
10200
|
-
|
|
10201
|
-
|
|
10202
|
-
|
|
10203
|
-
|
|
10204
|
-
|
|
10205
|
-
|
|
10206
|
-
|
|
10207
|
-
|
|
10208
|
-
tracer
|
|
10209
|
-
}) {
|
|
10210
|
-
this.name = name;
|
|
10211
|
-
this.fetch = fetch2;
|
|
10212
|
-
this.tracer = tracer;
|
|
10424
|
+
// dsp/optimizer.ts
|
|
10425
|
+
var AxDefaultCostTracker = class {
|
|
10426
|
+
tokenUsage = {};
|
|
10427
|
+
totalTokens = 0;
|
|
10428
|
+
// Configuration options
|
|
10429
|
+
costPerModel;
|
|
10430
|
+
maxCost;
|
|
10431
|
+
maxTokens;
|
|
10432
|
+
constructor(options) {
|
|
10433
|
+
this.costPerModel = options?.costPerModel ?? {};
|
|
10434
|
+
this.maxCost = options?.maxCost;
|
|
10435
|
+
this.maxTokens = options?.maxTokens;
|
|
10213
10436
|
}
|
|
10214
|
-
|
|
10215
|
-
|
|
10216
|
-
|
|
10217
|
-
}
|
|
10218
|
-
if (!this.tracer) {
|
|
10219
|
-
return await this._upsert(req, update);
|
|
10220
|
-
}
|
|
10221
|
-
return await this.tracer.startActiveSpan(
|
|
10222
|
-
"DB Upsert Request",
|
|
10223
|
-
{
|
|
10224
|
-
kind: import_api23.SpanKind.SERVER,
|
|
10225
|
-
attributes: {
|
|
10226
|
-
[axSpanAttributes.DB_SYSTEM]: this.name,
|
|
10227
|
-
[axSpanAttributes.DB_OPERATION_NAME]: "upsert",
|
|
10228
|
-
[axSpanAttributes.DB_TABLE]: req.table,
|
|
10229
|
-
[axSpanAttributes.DB_NAMESPACE]: req.namespace,
|
|
10230
|
-
[axSpanAttributes.DB_OPERATION_NAME]: update ? "update" : "insert"
|
|
10231
|
-
}
|
|
10232
|
-
},
|
|
10233
|
-
async (span) => {
|
|
10234
|
-
try {
|
|
10235
|
-
return await this._upsert(req, update, { span });
|
|
10236
|
-
} finally {
|
|
10237
|
-
span.end();
|
|
10238
|
-
}
|
|
10239
|
-
}
|
|
10240
|
-
);
|
|
10437
|
+
trackTokens(count, model) {
|
|
10438
|
+
this.tokenUsage[model] = (this.tokenUsage[model] || 0) + count;
|
|
10439
|
+
this.totalTokens += count;
|
|
10241
10440
|
}
|
|
10242
|
-
|
|
10243
|
-
|
|
10244
|
-
|
|
10441
|
+
getCurrentCost() {
|
|
10442
|
+
let totalCost = 0;
|
|
10443
|
+
for (const [model, tokens] of Object.entries(this.tokenUsage)) {
|
|
10444
|
+
const costPer1K = this.costPerModel[model] || 1e-3;
|
|
10445
|
+
totalCost += tokens / 1e3 * costPer1K;
|
|
10245
10446
|
}
|
|
10246
|
-
|
|
10247
|
-
|
|
10447
|
+
return totalCost;
|
|
10448
|
+
}
|
|
10449
|
+
getTokenUsage() {
|
|
10450
|
+
return { ...this.tokenUsage };
|
|
10451
|
+
}
|
|
10452
|
+
getTotalTokens() {
|
|
10453
|
+
return this.totalTokens;
|
|
10454
|
+
}
|
|
10455
|
+
isLimitReached() {
|
|
10456
|
+
if (this.maxTokens !== void 0 && this.totalTokens >= this.maxTokens) {
|
|
10457
|
+
return true;
|
|
10248
10458
|
}
|
|
10249
|
-
if (
|
|
10250
|
-
|
|
10459
|
+
if (this.maxCost !== void 0) {
|
|
10460
|
+
const currentCost = this.getCurrentCost();
|
|
10461
|
+
if (currentCost >= this.maxCost) {
|
|
10462
|
+
return true;
|
|
10463
|
+
}
|
|
10251
10464
|
}
|
|
10252
|
-
|
|
10253
|
-
|
|
10465
|
+
return false;
|
|
10466
|
+
}
|
|
10467
|
+
reset() {
|
|
10468
|
+
this.tokenUsage = {};
|
|
10469
|
+
this.totalTokens = 0;
|
|
10470
|
+
}
|
|
10471
|
+
};
|
|
10472
|
+
var AxBaseOptimizer = class {
|
|
10473
|
+
// Common AxOptimizerArgs fields
|
|
10474
|
+
studentAI;
|
|
10475
|
+
teacherAI;
|
|
10476
|
+
examples;
|
|
10477
|
+
validationSet;
|
|
10478
|
+
targetScore;
|
|
10479
|
+
minSuccessRate;
|
|
10480
|
+
onProgress;
|
|
10481
|
+
onEarlyStop;
|
|
10482
|
+
costTracker;
|
|
10483
|
+
seed;
|
|
10484
|
+
// Checkpointing fields
|
|
10485
|
+
checkpointSave;
|
|
10486
|
+
checkpointLoad;
|
|
10487
|
+
checkpointInterval;
|
|
10488
|
+
resumeFromCheckpoint;
|
|
10489
|
+
// Logging fields
|
|
10490
|
+
logger;
|
|
10491
|
+
verbose;
|
|
10492
|
+
// Checkpoint state
|
|
10493
|
+
currentRound = 0;
|
|
10494
|
+
scoreHistory = [];
|
|
10495
|
+
configurationHistory = [];
|
|
10496
|
+
// Common optimization statistics
|
|
10497
|
+
stats;
|
|
10498
|
+
constructor(args) {
|
|
10499
|
+
if (args.examples.length === 0) {
|
|
10500
|
+
throw new Error("No examples found");
|
|
10254
10501
|
}
|
|
10255
|
-
|
|
10256
|
-
|
|
10257
|
-
|
|
10258
|
-
|
|
10259
|
-
|
|
10260
|
-
|
|
10261
|
-
|
|
10262
|
-
|
|
10263
|
-
|
|
10264
|
-
|
|
10265
|
-
|
|
10502
|
+
this.studentAI = args.studentAI;
|
|
10503
|
+
this.teacherAI = args.teacherAI;
|
|
10504
|
+
this.examples = args.examples;
|
|
10505
|
+
this.validationSet = args.validationSet;
|
|
10506
|
+
this.targetScore = args.targetScore;
|
|
10507
|
+
this.minSuccessRate = args.minSuccessRate;
|
|
10508
|
+
this.onProgress = args.onProgress;
|
|
10509
|
+
this.onEarlyStop = args.onEarlyStop;
|
|
10510
|
+
this.seed = args.seed;
|
|
10511
|
+
this.checkpointSave = args.checkpointSave;
|
|
10512
|
+
this.checkpointLoad = args.checkpointLoad;
|
|
10513
|
+
this.checkpointInterval = args.checkpointInterval ?? 10;
|
|
10514
|
+
this.resumeFromCheckpoint = args.resumeFromCheckpoint;
|
|
10515
|
+
this.logger = args.logger;
|
|
10516
|
+
this.verbose = args.verbose;
|
|
10517
|
+
const costTracker = new AxDefaultCostTracker({
|
|
10518
|
+
maxTokens: 1e6
|
|
10519
|
+
});
|
|
10520
|
+
this.costTracker = args.costTracker ?? costTracker;
|
|
10521
|
+
this.stats = this.initializeStats();
|
|
10522
|
+
}
|
|
10523
|
+
/**
|
|
10524
|
+
* Initialize the optimization statistics structure
|
|
10525
|
+
*/
|
|
10526
|
+
initializeStats() {
|
|
10527
|
+
return {
|
|
10528
|
+
totalCalls: 0,
|
|
10529
|
+
successfulDemos: 0,
|
|
10530
|
+
estimatedTokenUsage: 0,
|
|
10531
|
+
earlyStopped: false,
|
|
10532
|
+
resourceUsage: {
|
|
10533
|
+
totalTokens: 0,
|
|
10534
|
+
totalTime: 0,
|
|
10535
|
+
avgLatencyPerEval: 0,
|
|
10536
|
+
costByModel: {}
|
|
10266
10537
|
},
|
|
10267
|
-
|
|
10268
|
-
|
|
10269
|
-
|
|
10270
|
-
|
|
10271
|
-
|
|
10272
|
-
}
|
|
10538
|
+
convergenceInfo: {
|
|
10539
|
+
converged: false,
|
|
10540
|
+
finalImprovement: 0,
|
|
10541
|
+
stagnationRounds: 0,
|
|
10542
|
+
convergenceThreshold: 0.01
|
|
10273
10543
|
}
|
|
10274
|
-
|
|
10544
|
+
};
|
|
10275
10545
|
}
|
|
10276
|
-
|
|
10277
|
-
|
|
10278
|
-
|
|
10546
|
+
/**
|
|
10547
|
+
* Set up reproducible random seed if provided
|
|
10548
|
+
*/
|
|
10549
|
+
setupRandomSeed() {
|
|
10550
|
+
if (this.seed !== void 0) {
|
|
10551
|
+
Math.random = (() => {
|
|
10552
|
+
let seed = this.seed;
|
|
10553
|
+
return () => {
|
|
10554
|
+
seed = (seed * 9301 + 49297) % 233280;
|
|
10555
|
+
return seed / 233280;
|
|
10556
|
+
};
|
|
10557
|
+
})();
|
|
10279
10558
|
}
|
|
10280
|
-
|
|
10281
|
-
|
|
10559
|
+
}
|
|
10560
|
+
/**
|
|
10561
|
+
* Check if optimization should stop early due to cost limits
|
|
10562
|
+
*/
|
|
10563
|
+
checkCostLimits() {
|
|
10564
|
+
return this.costTracker?.isLimitReached() ?? false;
|
|
10565
|
+
}
|
|
10566
|
+
/**
|
|
10567
|
+
* Check if target score has been reached
|
|
10568
|
+
*/
|
|
10569
|
+
checkTargetScore(currentScore) {
|
|
10570
|
+
return this.targetScore !== void 0 && currentScore >= this.targetScore;
|
|
10571
|
+
}
|
|
10572
|
+
/**
|
|
10573
|
+
* Update resource usage statistics
|
|
10574
|
+
*/
|
|
10575
|
+
updateResourceUsage(startTime, tokensUsed = 0) {
|
|
10576
|
+
this.stats.resourceUsage.totalTime = Date.now() - startTime;
|
|
10577
|
+
this.stats.resourceUsage.totalTokens += tokensUsed;
|
|
10578
|
+
if (this.stats.totalCalls > 0) {
|
|
10579
|
+
this.stats.resourceUsage.avgLatencyPerEval = this.stats.resourceUsage.totalTime / this.stats.totalCalls;
|
|
10282
10580
|
}
|
|
10283
|
-
|
|
10284
|
-
|
|
10285
|
-
|
|
10286
|
-
|
|
10287
|
-
|
|
10581
|
+
}
|
|
10582
|
+
/**
|
|
10583
|
+
* Trigger early stopping with appropriate callbacks
|
|
10584
|
+
*/
|
|
10585
|
+
triggerEarlyStopping(reason, bestScoreRound) {
|
|
10586
|
+
this.stats.earlyStopped = true;
|
|
10587
|
+
this.stats.earlyStopping = {
|
|
10588
|
+
bestScoreRound,
|
|
10589
|
+
patienceExhausted: reason.includes("improvement"),
|
|
10590
|
+
reason
|
|
10591
|
+
};
|
|
10592
|
+
if (this.onEarlyStop) {
|
|
10593
|
+
this.onEarlyStop(reason, this.stats);
|
|
10594
|
+
}
|
|
10595
|
+
}
|
|
10596
|
+
/**
|
|
10597
|
+
* Get the validation set, with fallback to a split of examples
|
|
10598
|
+
*/
|
|
10599
|
+
getValidationSet(options) {
|
|
10600
|
+
return options?.overrideValidationSet || this.validationSet || this.examples.slice(0, Math.floor(this.examples.length * 0.2));
|
|
10601
|
+
}
|
|
10602
|
+
/**
|
|
10603
|
+
* Get the AI service to use for a specific task, preferring teacher when available
|
|
10604
|
+
* @param preferTeacher Whether to prefer teacher AI over student AI
|
|
10605
|
+
* @param options Optional compile options that may override teacher AI
|
|
10606
|
+
* @returns The appropriate AI service to use
|
|
10607
|
+
*/
|
|
10608
|
+
getAIService(preferTeacher = false, options) {
|
|
10609
|
+
if (preferTeacher && options?.overrideTeacherAI) {
|
|
10610
|
+
return options.overrideTeacherAI;
|
|
10611
|
+
}
|
|
10612
|
+
if (preferTeacher && this.teacherAI) {
|
|
10613
|
+
return this.teacherAI;
|
|
10614
|
+
}
|
|
10615
|
+
return this.studentAI;
|
|
10616
|
+
}
|
|
10617
|
+
/**
|
|
10618
|
+
* Check if teacher AI is available (including overrides)
|
|
10619
|
+
* @param options Optional compile options that may override teacher AI
|
|
10620
|
+
* @returns True if teacher AI is configured or overridden
|
|
10621
|
+
*/
|
|
10622
|
+
hasTeacherAI(options) {
|
|
10623
|
+
return options?.overrideTeacherAI !== void 0 || this.teacherAI !== void 0;
|
|
10624
|
+
}
|
|
10625
|
+
/**
|
|
10626
|
+
* Get teacher AI if available, otherwise return student AI
|
|
10627
|
+
* @param options Optional compile options that may override teacher AI
|
|
10628
|
+
* @returns Teacher AI if available, otherwise student AI
|
|
10629
|
+
*/
|
|
10630
|
+
getTeacherOrStudentAI(options) {
|
|
10631
|
+
return options?.overrideTeacherAI || this.teacherAI || this.studentAI;
|
|
10632
|
+
}
|
|
10633
|
+
/**
|
|
10634
|
+
* Execute a task with teacher AI if available, otherwise use student AI
|
|
10635
|
+
* @param task Function that takes an AI service and returns a promise
|
|
10636
|
+
* @param preferTeacher Whether to prefer teacher AI (default: true)
|
|
10637
|
+
* @param options Optional compile options that may override teacher AI
|
|
10638
|
+
* @returns Result of the task execution
|
|
10639
|
+
*/
|
|
10640
|
+
async executeWithTeacher(task, preferTeacher = true, options) {
|
|
10641
|
+
const ai = this.getAIService(preferTeacher, options);
|
|
10642
|
+
return await task(ai);
|
|
10643
|
+
}
|
|
10644
|
+
/**
|
|
10645
|
+
* Get current optimization statistics
|
|
10646
|
+
*/
|
|
10647
|
+
getStats() {
|
|
10648
|
+
return { ...this.stats };
|
|
10649
|
+
}
|
|
10650
|
+
/**
|
|
10651
|
+
* Reset optimizer state for reuse with different programs
|
|
10652
|
+
*/
|
|
10653
|
+
reset() {
|
|
10654
|
+
this.stats = this.initializeStats();
|
|
10655
|
+
this.costTracker?.reset();
|
|
10656
|
+
this.currentRound = 0;
|
|
10657
|
+
this.scoreHistory = [];
|
|
10658
|
+
this.configurationHistory = [];
|
|
10659
|
+
}
|
|
10660
|
+
/**
|
|
10661
|
+
* Basic program validation that can be extended by concrete optimizers
|
|
10662
|
+
*/
|
|
10663
|
+
validateProgram(program) {
|
|
10664
|
+
const issues = [];
|
|
10665
|
+
const suggestions = [];
|
|
10666
|
+
if (!("forward" in program) || typeof program.forward !== "function") {
|
|
10667
|
+
issues.push("Program must have a forward method");
|
|
10668
|
+
}
|
|
10669
|
+
if (this.examples.length < 2) {
|
|
10670
|
+
issues.push("Need at least 2 examples for optimization");
|
|
10671
|
+
suggestions.push("Provide more training examples");
|
|
10672
|
+
}
|
|
10673
|
+
const valSetSize = this.getValidationSet().length;
|
|
10674
|
+
if (valSetSize < 1) {
|
|
10675
|
+
issues.push("Validation set is empty");
|
|
10676
|
+
suggestions.push("Provide examples or a validation set");
|
|
10677
|
+
}
|
|
10678
|
+
return {
|
|
10679
|
+
isValid: issues.length === 0,
|
|
10680
|
+
issues,
|
|
10681
|
+
suggestions
|
|
10682
|
+
};
|
|
10683
|
+
}
|
|
10684
|
+
/**
|
|
10685
|
+
* Multi-objective optimization using Pareto frontier
|
|
10686
|
+
* Default implementation that leverages the single-objective compile method
|
|
10687
|
+
* @param program The program to optimize
|
|
10688
|
+
* @param metricFn Multi-objective metric function that returns multiple scores
|
|
10689
|
+
* @param options Optional configuration options
|
|
10690
|
+
* @returns Pareto optimization result with frontier of non-dominated solutions
|
|
10691
|
+
*/
|
|
10692
|
+
async compilePareto(program, metricFn, options) {
|
|
10693
|
+
const startTime = Date.now();
|
|
10694
|
+
if (options?.verbose) {
|
|
10695
|
+
this.getLogger(options)?.(
|
|
10696
|
+
"Starting Pareto optimization using base implementation",
|
|
10697
|
+
{ tags: ["discovery"] }
|
|
10698
|
+
);
|
|
10699
|
+
this.getLogger(options)?.(
|
|
10700
|
+
"This will run multiple single-objective optimizations",
|
|
10701
|
+
{ tags: ["discovery"] }
|
|
10702
|
+
);
|
|
10703
|
+
}
|
|
10704
|
+
const solutions = await this.generateWeightedSolutions(
|
|
10705
|
+
program,
|
|
10706
|
+
metricFn,
|
|
10707
|
+
options
|
|
10708
|
+
);
|
|
10709
|
+
const constraintSolutions = await this.generateConstraintSolutions(
|
|
10710
|
+
program,
|
|
10711
|
+
metricFn,
|
|
10712
|
+
options
|
|
10713
|
+
);
|
|
10714
|
+
const allSolutions = [...solutions, ...constraintSolutions];
|
|
10715
|
+
if (options?.verbose) {
|
|
10716
|
+
this.getLogger(options)?.(
|
|
10717
|
+
`Generated ${allSolutions.length} candidate solutions`,
|
|
10718
|
+
{ tags: ["discovery"] }
|
|
10719
|
+
);
|
|
10720
|
+
}
|
|
10721
|
+
const paretoFront = this.findParetoFrontier(allSolutions);
|
|
10722
|
+
const hypervolume = this.calculateHypervolume(paretoFront);
|
|
10723
|
+
if (options?.verbose) {
|
|
10724
|
+
this.getLogger(options)?.(
|
|
10725
|
+
`Found ${paretoFront.length} non-dominated solutions`,
|
|
10726
|
+
{ tags: ["discovery"] }
|
|
10727
|
+
);
|
|
10728
|
+
this.getLogger(options)?.(
|
|
10729
|
+
`Hypervolume: ${hypervolume?.toFixed(4) || "N/A"}`,
|
|
10730
|
+
{ tags: ["discovery"] }
|
|
10731
|
+
);
|
|
10732
|
+
}
|
|
10733
|
+
this.updateResourceUsage(startTime);
|
|
10734
|
+
this.stats.convergenceInfo.converged = true;
|
|
10735
|
+
const bestScore = paretoFront.length > 0 ? Math.max(
|
|
10736
|
+
...paretoFront.map((sol) => Math.max(...Object.values(sol.scores)))
|
|
10737
|
+
) : 0;
|
|
10738
|
+
return {
|
|
10739
|
+
demos: paretoFront.length > 0 ? [...paretoFront[0].demos] : void 0,
|
|
10740
|
+
stats: this.stats,
|
|
10741
|
+
bestScore,
|
|
10742
|
+
paretoFront,
|
|
10743
|
+
hypervolume,
|
|
10744
|
+
paretoFrontSize: paretoFront.length,
|
|
10745
|
+
finalConfiguration: {
|
|
10746
|
+
paretoFrontSize: paretoFront.length,
|
|
10747
|
+
hypervolume,
|
|
10748
|
+
strategy: "weighted_combinations_and_constraints",
|
|
10749
|
+
numSolutions: allSolutions.length
|
|
10750
|
+
}
|
|
10751
|
+
};
|
|
10752
|
+
}
|
|
10753
|
+
/**
|
|
10754
|
+
* Generate solutions using different weighted combinations of objectives
|
|
10755
|
+
*/
|
|
10756
|
+
async generateWeightedSolutions(program, metricFn, options) {
|
|
10757
|
+
const solutions = [];
|
|
10758
|
+
const sampleExample = this.examples[0];
|
|
10759
|
+
const samplePrediction = await program.forward(
|
|
10760
|
+
this.studentAI,
|
|
10761
|
+
sampleExample
|
|
10762
|
+
);
|
|
10763
|
+
const sampleScores = await metricFn({
|
|
10764
|
+
prediction: samplePrediction,
|
|
10765
|
+
example: sampleExample
|
|
10766
|
+
});
|
|
10767
|
+
const objectives = Object.keys(sampleScores);
|
|
10768
|
+
if (options?.verbose) {
|
|
10769
|
+
this.getLogger(options)?.(
|
|
10770
|
+
`Detected objectives: ${objectives.join(", ")}`,
|
|
10771
|
+
{ tags: ["discovery"] }
|
|
10772
|
+
);
|
|
10773
|
+
}
|
|
10774
|
+
const weightCombinations = this.generateWeightCombinations(objectives);
|
|
10775
|
+
for (let i = 0; i < weightCombinations.length; i++) {
|
|
10776
|
+
const weights = weightCombinations[i];
|
|
10777
|
+
if (options?.verbose) {
|
|
10778
|
+
this.getLogger(options)?.(
|
|
10779
|
+
`Optimizing with weights: ${JSON.stringify(weights)}`,
|
|
10780
|
+
{ tags: ["discovery"] }
|
|
10781
|
+
);
|
|
10782
|
+
}
|
|
10783
|
+
const weightedMetric = async ({ prediction, example }) => {
|
|
10784
|
+
const scores = await metricFn({ prediction, example });
|
|
10785
|
+
let weightedScore = 0;
|
|
10786
|
+
for (const [objective, score] of Object.entries(scores)) {
|
|
10787
|
+
weightedScore += score * (weights[objective] || 0);
|
|
10788
|
+
}
|
|
10789
|
+
return weightedScore;
|
|
10790
|
+
};
|
|
10791
|
+
try {
|
|
10792
|
+
const result = await this.compile(program, weightedMetric, {
|
|
10793
|
+
...options,
|
|
10794
|
+
verbose: false
|
|
10795
|
+
// Suppress inner optimization logs
|
|
10796
|
+
});
|
|
10797
|
+
const scores = await this.evaluateWithMultiObjective(
|
|
10798
|
+
program,
|
|
10799
|
+
result,
|
|
10800
|
+
metricFn
|
|
10801
|
+
);
|
|
10802
|
+
solutions.push({
|
|
10803
|
+
scores,
|
|
10804
|
+
demos: result.demos,
|
|
10805
|
+
configuration: {
|
|
10806
|
+
...result.finalConfiguration,
|
|
10807
|
+
weights,
|
|
10808
|
+
strategy: "weighted_combination"
|
|
10809
|
+
}
|
|
10810
|
+
});
|
|
10811
|
+
} catch (error) {
|
|
10812
|
+
if (options?.verbose) {
|
|
10813
|
+
this.getLogger(options)?.(
|
|
10814
|
+
`Failed optimization with weights ${JSON.stringify(weights)}: ${error}`,
|
|
10815
|
+
{ tags: ["warning"] }
|
|
10816
|
+
);
|
|
10817
|
+
}
|
|
10818
|
+
continue;
|
|
10819
|
+
}
|
|
10820
|
+
}
|
|
10821
|
+
return solutions;
|
|
10822
|
+
}
|
|
10823
|
+
/**
|
|
10824
|
+
* Generate solutions using constraint-based optimization
|
|
10825
|
+
*/
|
|
10826
|
+
async generateConstraintSolutions(program, metricFn, options) {
|
|
10827
|
+
const solutions = [];
|
|
10828
|
+
const sampleExample = this.examples[0];
|
|
10829
|
+
const samplePrediction = await program.forward(
|
|
10830
|
+
this.studentAI,
|
|
10831
|
+
sampleExample
|
|
10832
|
+
);
|
|
10833
|
+
const sampleScores = await metricFn({
|
|
10834
|
+
prediction: samplePrediction,
|
|
10835
|
+
example: sampleExample
|
|
10836
|
+
});
|
|
10837
|
+
const objectives = Object.keys(sampleScores);
|
|
10838
|
+
for (const primaryObjective of objectives) {
|
|
10839
|
+
if (options?.verbose) {
|
|
10840
|
+
this.getLogger(options)?.(
|
|
10841
|
+
`Optimizing ${primaryObjective} with constraints on other objectives`,
|
|
10842
|
+
{ tags: ["discovery"] }
|
|
10843
|
+
);
|
|
10844
|
+
}
|
|
10845
|
+
const constraintMetric = async ({ prediction, example }) => {
|
|
10846
|
+
const scores = await metricFn({ prediction, example });
|
|
10847
|
+
const primaryScore = scores[primaryObjective] || 0;
|
|
10848
|
+
let penalty = 0;
|
|
10849
|
+
for (const [objective, score] of Object.entries(scores)) {
|
|
10850
|
+
if (objective !== primaryObjective) {
|
|
10851
|
+
if (score < 0.3) {
|
|
10852
|
+
penalty += (0.3 - score) * 2;
|
|
10853
|
+
}
|
|
10854
|
+
}
|
|
10855
|
+
}
|
|
10856
|
+
return primaryScore - penalty;
|
|
10857
|
+
};
|
|
10858
|
+
try {
|
|
10859
|
+
const result = await this.compile(program, constraintMetric, {
|
|
10860
|
+
...options,
|
|
10861
|
+
verbose: false
|
|
10862
|
+
});
|
|
10863
|
+
const scores = await this.evaluateWithMultiObjective(
|
|
10864
|
+
program,
|
|
10865
|
+
result,
|
|
10866
|
+
metricFn
|
|
10867
|
+
);
|
|
10868
|
+
solutions.push({
|
|
10869
|
+
scores,
|
|
10870
|
+
demos: result.demos,
|
|
10871
|
+
configuration: {
|
|
10872
|
+
...result.finalConfiguration,
|
|
10873
|
+
primaryObjective,
|
|
10874
|
+
strategy: "constraint_based"
|
|
10875
|
+
}
|
|
10876
|
+
});
|
|
10877
|
+
} catch (error) {
|
|
10878
|
+
if (options?.verbose) {
|
|
10879
|
+
this.getLogger(options)?.(
|
|
10880
|
+
`Failed constraint optimization for ${primaryObjective}: ${error}`,
|
|
10881
|
+
{ tags: ["warning"] }
|
|
10882
|
+
);
|
|
10883
|
+
}
|
|
10884
|
+
continue;
|
|
10885
|
+
}
|
|
10886
|
+
}
|
|
10887
|
+
return solutions;
|
|
10888
|
+
}
|
|
10889
|
+
/**
|
|
10890
|
+
* Generate different weight combinations for objectives
|
|
10891
|
+
*/
|
|
10892
|
+
generateWeightCombinations(objectives) {
|
|
10893
|
+
const combinations = [];
|
|
10894
|
+
for (const objective of objectives) {
|
|
10895
|
+
const weights = {};
|
|
10896
|
+
for (const obj of objectives) {
|
|
10897
|
+
weights[obj] = obj === objective ? 1 : 0;
|
|
10898
|
+
}
|
|
10899
|
+
combinations.push(weights);
|
|
10900
|
+
}
|
|
10901
|
+
const equalWeights = {};
|
|
10902
|
+
for (const objective of objectives) {
|
|
10903
|
+
equalWeights[objective] = 1 / objectives.length;
|
|
10904
|
+
}
|
|
10905
|
+
combinations.push(equalWeights);
|
|
10906
|
+
if (objectives.length === 2) {
|
|
10907
|
+
const [obj1, obj2] = objectives;
|
|
10908
|
+
for (let w1 = 0.1; w1 <= 0.9; w1 += 0.2) {
|
|
10909
|
+
const w2 = 1 - w1;
|
|
10910
|
+
combinations.push({ [obj1]: w1, [obj2]: w2 });
|
|
10911
|
+
}
|
|
10912
|
+
}
|
|
10913
|
+
if (objectives.length === 3) {
|
|
10914
|
+
const [obj1, obj2, obj3] = objectives;
|
|
10915
|
+
combinations.push(
|
|
10916
|
+
{ [obj1]: 0.5, [obj2]: 0.3, [obj3]: 0.2 },
|
|
10917
|
+
{ [obj1]: 0.3, [obj2]: 0.5, [obj3]: 0.2 },
|
|
10918
|
+
{ [obj1]: 0.2, [obj2]: 0.3, [obj3]: 0.5 }
|
|
10919
|
+
);
|
|
10920
|
+
}
|
|
10921
|
+
return combinations;
|
|
10922
|
+
}
|
|
10923
|
+
/**
|
|
10924
|
+
* Evaluate a single-objective result with multi-objective metrics
|
|
10925
|
+
*/
|
|
10926
|
+
async evaluateWithMultiObjective(program, result, metricFn) {
|
|
10927
|
+
const valSet = this.getValidationSet();
|
|
10928
|
+
const allScores = {};
|
|
10929
|
+
const testProgram = { ...program };
|
|
10930
|
+
if (result.demos && "setDemos" in testProgram) {
|
|
10931
|
+
;
|
|
10932
|
+
testProgram.setDemos(result.demos);
|
|
10933
|
+
}
|
|
10934
|
+
const evalSet = valSet.slice(0, Math.min(5, valSet.length));
|
|
10935
|
+
for (const example of evalSet) {
|
|
10936
|
+
try {
|
|
10937
|
+
const prediction = await testProgram.forward(
|
|
10938
|
+
this.studentAI,
|
|
10939
|
+
example
|
|
10940
|
+
);
|
|
10941
|
+
const scores = await metricFn({ prediction, example });
|
|
10942
|
+
for (const [objective, score] of Object.entries(scores)) {
|
|
10943
|
+
if (!allScores[objective]) {
|
|
10944
|
+
allScores[objective] = [];
|
|
10945
|
+
}
|
|
10946
|
+
allScores[objective].push(score);
|
|
10947
|
+
}
|
|
10948
|
+
} catch {
|
|
10949
|
+
continue;
|
|
10950
|
+
}
|
|
10951
|
+
}
|
|
10952
|
+
const avgScores = {};
|
|
10953
|
+
for (const [objective, scores] of Object.entries(allScores)) {
|
|
10954
|
+
avgScores[objective] = scores.length > 0 ? scores.reduce((sum, score) => sum + score, 0) / scores.length : 0;
|
|
10955
|
+
}
|
|
10956
|
+
return avgScores;
|
|
10957
|
+
}
|
|
10958
|
+
/**
|
|
10959
|
+
* Find the Pareto frontier from a set of solutions
|
|
10960
|
+
*/
|
|
10961
|
+
findParetoFrontier(solutions) {
|
|
10962
|
+
const paretoFront = [];
|
|
10963
|
+
for (let i = 0; i < solutions.length; i++) {
|
|
10964
|
+
const solutionA = solutions[i];
|
|
10965
|
+
let isDominated = false;
|
|
10966
|
+
let dominatedCount = 0;
|
|
10967
|
+
for (let j = 0; j < solutions.length; j++) {
|
|
10968
|
+
if (i === j) continue;
|
|
10969
|
+
const solutionB = solutions[j];
|
|
10970
|
+
if (this.dominates(solutionB.scores, solutionA.scores)) {
|
|
10971
|
+
isDominated = true;
|
|
10972
|
+
break;
|
|
10973
|
+
}
|
|
10974
|
+
if (this.dominates(solutionA.scores, solutionB.scores)) {
|
|
10975
|
+
dominatedCount++;
|
|
10976
|
+
}
|
|
10977
|
+
}
|
|
10978
|
+
if (!isDominated) {
|
|
10979
|
+
paretoFront.push({
|
|
10980
|
+
demos: solutionA.demos || [],
|
|
10981
|
+
scores: solutionA.scores,
|
|
10982
|
+
configuration: solutionA.configuration,
|
|
10983
|
+
dominatedSolutions: dominatedCount
|
|
10984
|
+
});
|
|
10985
|
+
}
|
|
10986
|
+
}
|
|
10987
|
+
return paretoFront;
|
|
10988
|
+
}
|
|
10989
|
+
/**
|
|
10990
|
+
* Check if solution A dominates solution B
|
|
10991
|
+
* A dominates B if A is better or equal in all objectives and strictly better in at least one
|
|
10992
|
+
*/
|
|
10993
|
+
dominates(scoresA, scoresB) {
|
|
10994
|
+
const objectives = Object.keys(scoresA);
|
|
10995
|
+
let atLeastAsGood = true;
|
|
10996
|
+
let strictlyBetter = false;
|
|
10997
|
+
for (const objective of objectives) {
|
|
10998
|
+
const scoreA = scoresA[objective] || 0;
|
|
10999
|
+
const scoreB = scoresB[objective] || 0;
|
|
11000
|
+
if (scoreA < scoreB) {
|
|
11001
|
+
atLeastAsGood = false;
|
|
11002
|
+
break;
|
|
11003
|
+
}
|
|
11004
|
+
if (scoreA > scoreB) {
|
|
11005
|
+
strictlyBetter = true;
|
|
11006
|
+
}
|
|
11007
|
+
}
|
|
11008
|
+
return atLeastAsGood && strictlyBetter;
|
|
11009
|
+
}
|
|
11010
|
+
/**
|
|
11011
|
+
* Calculate hypervolume of the Pareto frontier
|
|
11012
|
+
* Simplified implementation using reference point at origin
|
|
11013
|
+
*/
|
|
11014
|
+
calculateHypervolume(paretoFront) {
|
|
11015
|
+
if (paretoFront.length === 0) return void 0;
|
|
11016
|
+
const firstSolution = paretoFront[0];
|
|
11017
|
+
const objectives = Object.keys(firstSolution.scores);
|
|
11018
|
+
if (objectives.length === 2) {
|
|
11019
|
+
const [obj1, obj2] = objectives;
|
|
11020
|
+
let hypervolume = 0;
|
|
11021
|
+
const sortedSolutions = [...paretoFront].sort(
|
|
11022
|
+
(a, b) => (b.scores[obj1] || 0) - (a.scores[obj1] || 0)
|
|
11023
|
+
);
|
|
11024
|
+
let prevScore2 = 0;
|
|
11025
|
+
for (const solution of sortedSolutions) {
|
|
11026
|
+
const score1 = solution.scores[obj1] || 0;
|
|
11027
|
+
const score2 = solution.scores[obj2] || 0;
|
|
11028
|
+
hypervolume += score1 * (score2 - prevScore2);
|
|
11029
|
+
prevScore2 = Math.max(prevScore2, score2);
|
|
11030
|
+
}
|
|
11031
|
+
return hypervolume;
|
|
11032
|
+
}
|
|
11033
|
+
return void 0;
|
|
11034
|
+
}
|
|
11035
|
+
/**
|
|
11036
|
+
* Save current optimization state to checkpoint
|
|
11037
|
+
*/
|
|
11038
|
+
async saveCheckpoint(optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
|
|
11039
|
+
const saveFn = options?.overrideCheckpointSave || this.checkpointSave;
|
|
11040
|
+
if (!saveFn) return void 0;
|
|
11041
|
+
const checkpoint = {
|
|
11042
|
+
version: "1.0.0",
|
|
11043
|
+
timestamp: Date.now(),
|
|
11044
|
+
optimizerType,
|
|
11045
|
+
optimizerConfig,
|
|
11046
|
+
currentRound: this.currentRound,
|
|
11047
|
+
totalRounds: this.stats.resourceUsage.totalTime > 0 ? this.currentRound : 0,
|
|
11048
|
+
bestScore,
|
|
11049
|
+
bestConfiguration,
|
|
11050
|
+
scoreHistory: [...this.scoreHistory],
|
|
11051
|
+
configurationHistory: [...this.configurationHistory],
|
|
11052
|
+
stats: { ...this.stats },
|
|
11053
|
+
optimizerState,
|
|
11054
|
+
examples: this.examples,
|
|
11055
|
+
validationSet: this.validationSet
|
|
11056
|
+
};
|
|
11057
|
+
return await saveFn(checkpoint);
|
|
11058
|
+
}
|
|
11059
|
+
/**
|
|
11060
|
+
* Load optimization state from checkpoint
|
|
11061
|
+
*/
|
|
11062
|
+
async loadCheckpoint(checkpointId, options) {
|
|
11063
|
+
const loadFn = options?.overrideCheckpointLoad || this.checkpointLoad;
|
|
11064
|
+
if (!loadFn) return null;
|
|
11065
|
+
return await loadFn(checkpointId);
|
|
11066
|
+
}
|
|
11067
|
+
/**
|
|
11068
|
+
* Restore optimizer state from checkpoint
|
|
11069
|
+
*/
|
|
11070
|
+
restoreFromCheckpoint(checkpoint) {
|
|
11071
|
+
this.currentRound = checkpoint.currentRound;
|
|
11072
|
+
this.scoreHistory = [...checkpoint.scoreHistory];
|
|
11073
|
+
this.configurationHistory = [...checkpoint.configurationHistory];
|
|
11074
|
+
this.stats = { ...checkpoint.stats };
|
|
11075
|
+
}
|
|
11076
|
+
/**
|
|
11077
|
+
* Check if checkpoint should be saved
|
|
11078
|
+
*/
|
|
11079
|
+
shouldSaveCheckpoint(round, options) {
|
|
11080
|
+
const interval = options?.overrideCheckpointInterval || this.checkpointInterval;
|
|
11081
|
+
return interval !== void 0 && round % interval === 0;
|
|
11082
|
+
}
|
|
11083
|
+
/**
|
|
11084
|
+
* Update optimization progress and handle checkpointing
|
|
11085
|
+
*/
|
|
11086
|
+
async updateOptimizationProgress(round, score, configuration, optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
|
|
11087
|
+
this.currentRound = round;
|
|
11088
|
+
this.scoreHistory.push(score);
|
|
11089
|
+
this.configurationHistory.push(configuration);
|
|
11090
|
+
if (this.shouldSaveCheckpoint(round, options)) {
|
|
11091
|
+
await this.saveCheckpoint(
|
|
11092
|
+
optimizerType,
|
|
11093
|
+
optimizerConfig,
|
|
11094
|
+
bestScore,
|
|
11095
|
+
bestConfiguration,
|
|
11096
|
+
optimizerState,
|
|
11097
|
+
options
|
|
11098
|
+
);
|
|
11099
|
+
}
|
|
11100
|
+
}
|
|
11101
|
+
/**
|
|
11102
|
+
* Save final checkpoint on completion
|
|
11103
|
+
*/
|
|
11104
|
+
async saveFinalCheckpoint(optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
|
|
11105
|
+
if (options?.saveCheckpointOnComplete !== false) {
|
|
11106
|
+
await this.saveCheckpoint(
|
|
11107
|
+
optimizerType,
|
|
11108
|
+
optimizerConfig,
|
|
11109
|
+
bestScore,
|
|
11110
|
+
bestConfiguration,
|
|
11111
|
+
{ ...optimizerState, final: true },
|
|
11112
|
+
options
|
|
11113
|
+
);
|
|
11114
|
+
}
|
|
11115
|
+
}
|
|
11116
|
+
/**
|
|
11117
|
+
* Get the logger function with fallback hierarchy:
|
|
11118
|
+
* 1. Explicit logger passed to optimizer
|
|
11119
|
+
* 2. Logger from student AI service
|
|
11120
|
+
* 3. Default optimizer logger
|
|
11121
|
+
* 4. undefined if verbose is false
|
|
11122
|
+
*/
|
|
11123
|
+
getLogger(options) {
|
|
11124
|
+
const isVerbose = this.isLoggingEnabled(options);
|
|
11125
|
+
if (!isVerbose) {
|
|
11126
|
+
return void 0;
|
|
11127
|
+
}
|
|
11128
|
+
if (this.logger) {
|
|
11129
|
+
return this.logger;
|
|
11130
|
+
}
|
|
11131
|
+
try {
|
|
11132
|
+
const aiLogger = this.studentAI.getLogger();
|
|
11133
|
+
if (aiLogger) {
|
|
11134
|
+
return aiLogger;
|
|
11135
|
+
}
|
|
11136
|
+
} catch {
|
|
11137
|
+
}
|
|
11138
|
+
return axDefaultOptimizerLogger;
|
|
11139
|
+
}
|
|
11140
|
+
/**
|
|
11141
|
+
* Check if logging is enabled based on verbose settings
|
|
11142
|
+
*/
|
|
11143
|
+
isLoggingEnabled(options) {
|
|
11144
|
+
if (options?.verbose !== void 0) {
|
|
11145
|
+
return options.verbose;
|
|
11146
|
+
}
|
|
11147
|
+
return this.verbose ?? true;
|
|
11148
|
+
}
|
|
11149
|
+
};
|
|
11150
|
+
|
|
11151
|
+
// db/base.ts
|
|
11152
|
+
var import_api23 = require("@opentelemetry/api");
|
|
11153
|
+
var AxDBBase = class {
|
|
11154
|
+
name;
|
|
11155
|
+
fetch;
|
|
11156
|
+
tracer;
|
|
11157
|
+
_upsert;
|
|
11158
|
+
_batchUpsert;
|
|
11159
|
+
_query;
|
|
11160
|
+
constructor({
|
|
11161
|
+
name,
|
|
11162
|
+
fetch: fetch2,
|
|
11163
|
+
tracer
|
|
11164
|
+
}) {
|
|
11165
|
+
this.name = name;
|
|
11166
|
+
this.fetch = fetch2;
|
|
11167
|
+
this.tracer = tracer;
|
|
11168
|
+
}
|
|
11169
|
+
async upsert(req, update) {
|
|
11170
|
+
if (!this._upsert) {
|
|
11171
|
+
throw new Error("upsert() not implemented");
|
|
11172
|
+
}
|
|
11173
|
+
if (!this.tracer) {
|
|
11174
|
+
return await this._upsert(req, update);
|
|
11175
|
+
}
|
|
11176
|
+
return await this.tracer.startActiveSpan(
|
|
11177
|
+
"DB Upsert Request",
|
|
11178
|
+
{
|
|
11179
|
+
kind: import_api23.SpanKind.SERVER,
|
|
11180
|
+
attributes: {
|
|
11181
|
+
[axSpanAttributes.DB_SYSTEM]: this.name,
|
|
11182
|
+
[axSpanAttributes.DB_OPERATION_NAME]: "upsert",
|
|
11183
|
+
[axSpanAttributes.DB_TABLE]: req.table,
|
|
11184
|
+
[axSpanAttributes.DB_NAMESPACE]: req.namespace,
|
|
11185
|
+
[axSpanAttributes.DB_OPERATION_NAME]: update ? "update" : "insert"
|
|
11186
|
+
}
|
|
11187
|
+
},
|
|
11188
|
+
async (span) => {
|
|
11189
|
+
try {
|
|
11190
|
+
return await this._upsert(req, update, { span });
|
|
11191
|
+
} finally {
|
|
11192
|
+
span.end();
|
|
11193
|
+
}
|
|
11194
|
+
}
|
|
11195
|
+
);
|
|
11196
|
+
}
|
|
11197
|
+
async batchUpsert(req, update) {
|
|
11198
|
+
if (!this._batchUpsert) {
|
|
11199
|
+
throw new Error("batchUpsert() not implemented");
|
|
11200
|
+
}
|
|
11201
|
+
if (req.length == 0) {
|
|
11202
|
+
throw new Error("Batch request is empty");
|
|
11203
|
+
}
|
|
11204
|
+
if (!req[0]) {
|
|
11205
|
+
throw new Error("Batch request is invalid first element is undefined");
|
|
11206
|
+
}
|
|
11207
|
+
if (!this.tracer) {
|
|
11208
|
+
return await this._batchUpsert(req, update);
|
|
11209
|
+
}
|
|
11210
|
+
return await this.tracer.startActiveSpan(
|
|
11211
|
+
"DB Batch Upsert Request",
|
|
11212
|
+
{
|
|
11213
|
+
kind: import_api23.SpanKind.SERVER,
|
|
11214
|
+
attributes: {
|
|
11215
|
+
[axSpanAttributes.DB_SYSTEM]: this.name,
|
|
11216
|
+
[axSpanAttributes.DB_OPERATION_NAME]: "upsert",
|
|
11217
|
+
[axSpanAttributes.DB_TABLE]: req[0].table,
|
|
11218
|
+
[axSpanAttributes.DB_NAMESPACE]: req[0].namespace,
|
|
11219
|
+
[axSpanAttributes.DB_OPERATION_NAME]: update ? "update" : "insert"
|
|
11220
|
+
}
|
|
11221
|
+
},
|
|
11222
|
+
async (span) => {
|
|
11223
|
+
try {
|
|
11224
|
+
return await this._batchUpsert(req, update, { span });
|
|
11225
|
+
} finally {
|
|
11226
|
+
span.end();
|
|
11227
|
+
}
|
|
11228
|
+
}
|
|
11229
|
+
);
|
|
11230
|
+
}
|
|
11231
|
+
async query(req) {
|
|
11232
|
+
if (!this._query) {
|
|
11233
|
+
throw new Error("query() not implemented");
|
|
11234
|
+
}
|
|
11235
|
+
if (!this.tracer) {
|
|
11236
|
+
return await this._query(req);
|
|
11237
|
+
}
|
|
11238
|
+
return await this.tracer.startActiveSpan(
|
|
11239
|
+
"DB Query Request",
|
|
11240
|
+
{
|
|
11241
|
+
kind: import_api23.SpanKind.SERVER,
|
|
11242
|
+
attributes: {
|
|
10288
11243
|
[axSpanAttributes.DB_SYSTEM]: this.name,
|
|
10289
11244
|
[axSpanAttributes.DB_OPERATION_NAME]: "upsert",
|
|
10290
11245
|
[axSpanAttributes.DB_TABLE]: req.table,
|
|
@@ -11652,52 +12607,31 @@ var AxMCPStreambleHTTPTransport = class {
|
|
|
11652
12607
|
};
|
|
11653
12608
|
|
|
11654
12609
|
// dsp/optimizers/bootstrapFewshot.ts
|
|
11655
|
-
var AxBootstrapFewShot = class {
|
|
11656
|
-
|
|
11657
|
-
|
|
11658
|
-
|
|
11659
|
-
|
|
11660
|
-
|
|
11661
|
-
|
|
11662
|
-
|
|
11663
|
-
|
|
11664
|
-
|
|
11665
|
-
|
|
11666
|
-
|
|
11667
|
-
|
|
11668
|
-
|
|
11669
|
-
|
|
11670
|
-
|
|
11671
|
-
|
|
11672
|
-
|
|
11673
|
-
|
|
11674
|
-
|
|
11675
|
-
|
|
11676
|
-
|
|
11677
|
-
|
|
11678
|
-
|
|
11679
|
-
|
|
11680
|
-
options
|
|
11681
|
-
}) {
|
|
11682
|
-
if (examples.length === 0) {
|
|
11683
|
-
throw new Error("No examples found");
|
|
11684
|
-
}
|
|
11685
|
-
const bootstrapOptions = options;
|
|
11686
|
-
this.maxRounds = bootstrapOptions?.maxRounds ?? 3;
|
|
11687
|
-
this.maxDemos = bootstrapOptions?.maxDemos ?? 4;
|
|
11688
|
-
this.maxExamples = bootstrapOptions?.maxExamples ?? 16;
|
|
11689
|
-
this.batchSize = bootstrapOptions?.batchSize ?? 1;
|
|
11690
|
-
this.earlyStoppingPatience = bootstrapOptions?.earlyStoppingPatience ?? 0;
|
|
11691
|
-
this.costMonitoring = bootstrapOptions?.costMonitoring ?? false;
|
|
11692
|
-
this.maxTokensPerGeneration = bootstrapOptions?.maxTokensPerGeneration ?? 0;
|
|
11693
|
-
this.verboseMode = bootstrapOptions?.verboseMode ?? true;
|
|
11694
|
-
this.debugMode = bootstrapOptions?.debugMode ?? false;
|
|
11695
|
-
this.ai = ai;
|
|
11696
|
-
this.teacherAI = bootstrapOptions?.teacherAI;
|
|
11697
|
-
this.program = program;
|
|
11698
|
-
this.examples = examples;
|
|
11699
|
-
}
|
|
11700
|
-
async compileRound(roundIndex, metricFn, options) {
|
|
12610
|
+
var AxBootstrapFewShot = class extends AxBaseOptimizer {
|
|
12611
|
+
maxRounds;
|
|
12612
|
+
maxDemos;
|
|
12613
|
+
maxExamples;
|
|
12614
|
+
batchSize;
|
|
12615
|
+
earlyStoppingPatience;
|
|
12616
|
+
costMonitoring;
|
|
12617
|
+
maxTokensPerGeneration;
|
|
12618
|
+
verboseMode;
|
|
12619
|
+
debugMode;
|
|
12620
|
+
traces = [];
|
|
12621
|
+
constructor(args) {
|
|
12622
|
+
super(args);
|
|
12623
|
+
const options = args.options || {};
|
|
12624
|
+
this.maxRounds = options.maxRounds ?? 3;
|
|
12625
|
+
this.maxDemos = options.maxDemos ?? 4;
|
|
12626
|
+
this.maxExamples = options.maxExamples ?? 16;
|
|
12627
|
+
this.batchSize = options.batchSize ?? 1;
|
|
12628
|
+
this.earlyStoppingPatience = options.earlyStoppingPatience ?? 0;
|
|
12629
|
+
this.costMonitoring = options.costMonitoring ?? false;
|
|
12630
|
+
this.maxTokensPerGeneration = options.maxTokensPerGeneration ?? 0;
|
|
12631
|
+
this.verboseMode = options.verboseMode ?? true;
|
|
12632
|
+
this.debugMode = options.debugMode ?? false;
|
|
12633
|
+
}
|
|
12634
|
+
async compileRound(program, roundIndex, metricFn, options) {
|
|
11701
12635
|
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
11702
12636
|
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
11703
12637
|
const aiOpt = {
|
|
@@ -11720,20 +12654,20 @@ var AxBootstrapFewShot = class {
|
|
|
11720
12654
|
continue;
|
|
11721
12655
|
}
|
|
11722
12656
|
const exList = examples.filter((e) => e !== ex);
|
|
11723
|
-
|
|
11724
|
-
const aiService = this.
|
|
12657
|
+
program.setExamples(exList);
|
|
12658
|
+
const aiService = this.getTeacherOrStudentAI();
|
|
11725
12659
|
this.stats.totalCalls++;
|
|
11726
12660
|
let res;
|
|
11727
12661
|
let error;
|
|
11728
12662
|
try {
|
|
11729
|
-
res = await
|
|
12663
|
+
res = await program.forward(aiService, ex, aiOpt);
|
|
11730
12664
|
if (this.costMonitoring) {
|
|
11731
12665
|
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
11732
12666
|
}
|
|
11733
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
12667
|
+
const score = await metricFn({ prediction: res, example: ex });
|
|
11734
12668
|
const success = score >= 0.5;
|
|
11735
12669
|
if (success) {
|
|
11736
|
-
this.traces = [...this.traces, ...
|
|
12670
|
+
this.traces = [...this.traces, ...program.getTraces()];
|
|
11737
12671
|
this.stats.successfulDemos++;
|
|
11738
12672
|
}
|
|
11739
12673
|
} catch (err) {
|
|
@@ -11784,54 +12718,73 @@ var AxBootstrapFewShot = class {
|
|
|
11784
12718
|
if (!this.stats.earlyStopping) {
|
|
11785
12719
|
this.stats.earlyStopping = {
|
|
11786
12720
|
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
11787
|
-
patienceExhausted: false
|
|
12721
|
+
patienceExhausted: false,
|
|
12722
|
+
reason: "No improvement detected"
|
|
11788
12723
|
};
|
|
11789
12724
|
} else if (improvement > 0) {
|
|
11790
12725
|
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
11791
12726
|
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
11792
12727
|
this.stats.earlyStopping.patienceExhausted = true;
|
|
11793
12728
|
this.stats.earlyStopped = true;
|
|
12729
|
+
this.stats.earlyStopping.reason = `No improvement for ${this.earlyStoppingPatience} rounds`;
|
|
11794
12730
|
if (this.verboseMode || this.debugMode) {
|
|
11795
|
-
|
|
11796
|
-
`
|
|
11797
|
-
|
|
12731
|
+
this.getLogger()?.(
|
|
12732
|
+
`Early stopping after ${roundIndex + 1} rounds (no improvement for ${this.earlyStoppingPatience} rounds)`,
|
|
12733
|
+
{ tags: ["optimizer", "warning"] }
|
|
11798
12734
|
);
|
|
11799
12735
|
}
|
|
11800
12736
|
return;
|
|
11801
12737
|
}
|
|
11802
12738
|
}
|
|
11803
12739
|
}
|
|
11804
|
-
async compile(metricFn, options) {
|
|
11805
|
-
const
|
|
11806
|
-
const maxRounds = compileOptions?.maxRounds ?? this.maxRounds;
|
|
12740
|
+
async compile(program, metricFn, options) {
|
|
12741
|
+
const maxRounds = options?.maxIterations ?? this.maxRounds;
|
|
11807
12742
|
this.traces = [];
|
|
11808
|
-
this.
|
|
11809
|
-
|
|
11810
|
-
|
|
11811
|
-
|
|
11812
|
-
|
|
11813
|
-
|
|
12743
|
+
this.reset();
|
|
12744
|
+
if (this.verboseMode || this.debugMode) {
|
|
12745
|
+
this.getLogger()?.(
|
|
12746
|
+
`Starting BootstrapFewshot optimization with ${maxRounds} rounds`,
|
|
12747
|
+
{ tags: ["optimizer", "start"] }
|
|
12748
|
+
);
|
|
12749
|
+
this.getLogger()?.(
|
|
12750
|
+
`Using ${this.examples.length} examples, max ${this.maxDemos} demos`,
|
|
12751
|
+
{ tags: ["optimizer", "config"] }
|
|
12752
|
+
);
|
|
12753
|
+
}
|
|
11814
12754
|
for (let i = 0; i < maxRounds; i++) {
|
|
11815
|
-
await this.compileRound(i, metricFn,
|
|
12755
|
+
await this.compileRound(program, i, metricFn, options);
|
|
11816
12756
|
if (this.stats.earlyStopped) {
|
|
11817
12757
|
break;
|
|
11818
12758
|
}
|
|
11819
12759
|
}
|
|
11820
12760
|
if (this.traces.length === 0) {
|
|
11821
12761
|
throw new Error(
|
|
11822
|
-
"No demonstrations found. Either
|
|
12762
|
+
"No demonstrations found. Either provide more examples or improve the existing ones."
|
|
11823
12763
|
);
|
|
11824
12764
|
}
|
|
11825
12765
|
const demos = groupTracesByKeys(this.traces);
|
|
12766
|
+
let bestScore = 0;
|
|
12767
|
+
if (this.traces.length > 0) {
|
|
12768
|
+
bestScore = this.stats.successfulDemos / Math.max(1, this.stats.totalCalls);
|
|
12769
|
+
}
|
|
12770
|
+
if (this.verboseMode || this.debugMode) {
|
|
12771
|
+
this.getLogger()?.(
|
|
12772
|
+
`Bootstrap complete. Generated ${demos.length} demos with ${bestScore.toFixed(3)} success rate`,
|
|
12773
|
+
{ tags: ["optimizer", "complete"] }
|
|
12774
|
+
);
|
|
12775
|
+
}
|
|
11826
12776
|
return {
|
|
11827
12777
|
demos,
|
|
11828
|
-
stats: this.stats
|
|
12778
|
+
stats: this.stats,
|
|
12779
|
+
bestScore,
|
|
12780
|
+
finalConfiguration: {
|
|
12781
|
+
maxRounds: this.maxRounds,
|
|
12782
|
+
maxDemos: this.maxDemos,
|
|
12783
|
+
batchSize: this.batchSize,
|
|
12784
|
+
successRate: bestScore
|
|
12785
|
+
}
|
|
11829
12786
|
};
|
|
11830
12787
|
}
|
|
11831
|
-
// Get optimization statistics
|
|
11832
|
-
getStats() {
|
|
11833
|
-
return this.stats;
|
|
11834
|
-
}
|
|
11835
12788
|
};
|
|
11836
12789
|
function groupTracesByKeys(programTraces) {
|
|
11837
12790
|
const groupedTraces = /* @__PURE__ */ new Map();
|
|
@@ -11846,9 +12799,12 @@ function groupTracesByKeys(programTraces) {
|
|
|
11846
12799
|
}
|
|
11847
12800
|
}
|
|
11848
12801
|
const programDemosArray = [];
|
|
11849
|
-
|
|
11850
|
-
programDemosArray.push({
|
|
11851
|
-
|
|
12802
|
+
groupedTraces.forEach((traces, programId) => {
|
|
12803
|
+
programDemosArray.push({
|
|
12804
|
+
traces,
|
|
12805
|
+
programId
|
|
12806
|
+
});
|
|
12807
|
+
});
|
|
11852
12808
|
return programDemosArray;
|
|
11853
12809
|
}
|
|
11854
12810
|
var randomSample = (array, n) => {
|
|
@@ -11867,10 +12823,8 @@ var randomSample = (array, n) => {
|
|
|
11867
12823
|
};
|
|
11868
12824
|
|
|
11869
12825
|
// dsp/optimizers/miproV2.ts
|
|
11870
|
-
var AxMiPRO = class {
|
|
11871
|
-
|
|
11872
|
-
program;
|
|
11873
|
-
examples;
|
|
12826
|
+
var AxMiPRO = class extends AxBaseOptimizer {
|
|
12827
|
+
// MiPRO-specific options
|
|
11874
12828
|
maxBootstrappedDemos;
|
|
11875
12829
|
maxLabeledDemos;
|
|
11876
12830
|
numCandidates;
|
|
@@ -11884,52 +12838,33 @@ var AxMiPRO = class {
|
|
|
11884
12838
|
viewDataBatchSize;
|
|
11885
12839
|
tipAwareProposer;
|
|
11886
12840
|
fewshotAwareProposer;
|
|
11887
|
-
seed;
|
|
11888
|
-
verbose;
|
|
11889
|
-
bootstrapper;
|
|
11890
12841
|
earlyStoppingTrials;
|
|
11891
12842
|
minImprovementThreshold;
|
|
11892
|
-
|
|
11893
|
-
|
|
11894
|
-
|
|
11895
|
-
|
|
11896
|
-
|
|
11897
|
-
|
|
11898
|
-
|
|
11899
|
-
|
|
11900
|
-
|
|
11901
|
-
|
|
11902
|
-
this.
|
|
11903
|
-
this.
|
|
11904
|
-
this.
|
|
11905
|
-
this.
|
|
11906
|
-
this.
|
|
11907
|
-
this.
|
|
11908
|
-
this.
|
|
11909
|
-
this.
|
|
11910
|
-
this.
|
|
11911
|
-
this.
|
|
11912
|
-
this.
|
|
11913
|
-
this.
|
|
11914
|
-
this.
|
|
11915
|
-
this.
|
|
11916
|
-
this.
|
|
11917
|
-
this.earlyStoppingTrials = miproOptions.earlyStoppingTrials ?? 5;
|
|
11918
|
-
this.minImprovementThreshold = miproOptions.minImprovementThreshold ?? 0.01;
|
|
11919
|
-
this.ai = ai;
|
|
11920
|
-
this.program = program;
|
|
11921
|
-
this.examples = examples;
|
|
11922
|
-
this.bootstrapper = new AxBootstrapFewShot({
|
|
11923
|
-
ai,
|
|
11924
|
-
program,
|
|
11925
|
-
examples,
|
|
11926
|
-
options: {
|
|
11927
|
-
maxDemos: this.maxBootstrappedDemos,
|
|
11928
|
-
maxRounds: 3,
|
|
11929
|
-
// Default, or adjust based on your needs
|
|
11930
|
-
verboseMode: this.verbose
|
|
11931
|
-
}
|
|
11932
|
-
});
|
|
12843
|
+
bayesianOptimization;
|
|
12844
|
+
acquisitionFunction;
|
|
12845
|
+
explorationWeight;
|
|
12846
|
+
constructor(args) {
|
|
12847
|
+
super(args);
|
|
12848
|
+
const options = args.options || {};
|
|
12849
|
+
this.numCandidates = options.numCandidates ?? 5;
|
|
12850
|
+
this.initTemperature = options.initTemperature ?? 0.7;
|
|
12851
|
+
this.maxBootstrappedDemos = options.maxBootstrappedDemos ?? 3;
|
|
12852
|
+
this.maxLabeledDemos = options.maxLabeledDemos ?? 4;
|
|
12853
|
+
this.numTrials = options.numTrials ?? 30;
|
|
12854
|
+
this.minibatch = options.minibatch ?? true;
|
|
12855
|
+
this.minibatchSize = options.minibatchSize ?? 25;
|
|
12856
|
+
this.minibatchFullEvalSteps = options.minibatchFullEvalSteps ?? 10;
|
|
12857
|
+
this.programAwareProposer = options.programAwareProposer ?? true;
|
|
12858
|
+
this.dataAwareProposer = options.dataAwareProposer ?? true;
|
|
12859
|
+
this.viewDataBatchSize = options.viewDataBatchSize ?? 10;
|
|
12860
|
+
this.tipAwareProposer = options.tipAwareProposer ?? true;
|
|
12861
|
+
this.fewshotAwareProposer = options.fewshotAwareProposer ?? true;
|
|
12862
|
+
this.earlyStoppingTrials = options.earlyStoppingTrials ?? 5;
|
|
12863
|
+
this.minImprovementThreshold = options.minImprovementThreshold ?? 0.01;
|
|
12864
|
+
this.bayesianOptimization = options.bayesianOptimization ?? false;
|
|
12865
|
+
this.acquisitionFunction = options.acquisitionFunction ?? "expected_improvement";
|
|
12866
|
+
this.explorationWeight = options.explorationWeight ?? 0.1;
|
|
12867
|
+
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
11933
12868
|
}
|
|
11934
12869
|
/**
|
|
11935
12870
|
* Configures the optimizer for light, medium, or heavy optimization
|
|
@@ -11973,123 +12908,62 @@ var AxMiPRO = class {
|
|
|
11973
12908
|
];
|
|
11974
12909
|
}
|
|
11975
12910
|
/**
|
|
11976
|
-
* Generates instruction candidates
|
|
12911
|
+
* Generates instruction candidates using the teacher model if available
|
|
12912
|
+
* @param options Optional compile options that may override teacher AI
|
|
11977
12913
|
* @returns Array of generated instruction candidates
|
|
11978
12914
|
*/
|
|
11979
|
-
async proposeInstructionCandidates() {
|
|
12915
|
+
async proposeInstructionCandidates(options) {
|
|
11980
12916
|
const instructions = [];
|
|
11981
|
-
|
|
11982
|
-
if (this.programAwareProposer) {
|
|
11983
|
-
programContext = await this.generateProgramSummary();
|
|
11984
|
-
}
|
|
11985
|
-
let dataContext = "";
|
|
11986
|
-
if (this.dataAwareProposer) {
|
|
11987
|
-
dataContext = await this.generateDataSummary();
|
|
11988
|
-
}
|
|
12917
|
+
const aiToUse = this.getTeacherOrStudentAI(options);
|
|
11989
12918
|
const tips = this.tipAwareProposer ? this.generateTips() : [];
|
|
11990
12919
|
for (let i = 0; i < this.numCandidates; i++) {
|
|
11991
12920
|
const tipIndex = tips.length > 0 ? i % tips.length : -1;
|
|
11992
12921
|
const tipToUse = tipIndex >= 0 ? tips[tipIndex] : "";
|
|
11993
12922
|
const instruction = await this.generateInstruction({
|
|
11994
|
-
programContext,
|
|
11995
|
-
dataContext,
|
|
11996
12923
|
tip: tipToUse,
|
|
11997
|
-
candidateIndex: i
|
|
12924
|
+
candidateIndex: i,
|
|
12925
|
+
ai: aiToUse
|
|
11998
12926
|
});
|
|
11999
12927
|
instructions.push(instruction);
|
|
12000
12928
|
}
|
|
12001
12929
|
return instructions;
|
|
12002
12930
|
}
|
|
12003
|
-
/**
|
|
12004
|
-
* Generates a summary of the program structure for instruction proposal
|
|
12005
|
-
*/
|
|
12006
|
-
async generateProgramSummary() {
|
|
12007
|
-
const prompt = `Summarize the following program structure. Focus on the signatures,
|
|
12008
|
-
input/output fields, and the purpose of each component. Identify key components
|
|
12009
|
-
that might benefit from better instructions.`;
|
|
12010
|
-
const programStr = JSON.stringify(this.program);
|
|
12011
|
-
const response = await this.ai.chat({
|
|
12012
|
-
chatPrompt: [
|
|
12013
|
-
{ role: "system", content: prompt },
|
|
12014
|
-
{ role: "user", content: programStr }
|
|
12015
|
-
],
|
|
12016
|
-
modelConfig: { temperature: 0.2 }
|
|
12017
|
-
});
|
|
12018
|
-
if (response instanceof ReadableStream) {
|
|
12019
|
-
return "";
|
|
12020
|
-
}
|
|
12021
|
-
return response.results[0]?.content || "";
|
|
12022
|
-
}
|
|
12023
|
-
/**
|
|
12024
|
-
* Generates a summary of the dataset for instruction proposal
|
|
12025
|
-
*/
|
|
12026
|
-
async generateDataSummary() {
|
|
12027
|
-
const sampleSize = Math.min(this.viewDataBatchSize, this.examples.length);
|
|
12028
|
-
const sample = this.examples.slice(0, sampleSize);
|
|
12029
|
-
const prompt = `Analyze the following dataset examples and provide a summary
|
|
12030
|
-
of key patterns, input-output relationships, and any specific challenges
|
|
12031
|
-
the data presents. Focus on what makes a good answer and what patterns should
|
|
12032
|
-
be followed.`;
|
|
12033
|
-
const dataStr = JSON.stringify(sample);
|
|
12034
|
-
const response = await this.ai.chat({
|
|
12035
|
-
chatPrompt: [
|
|
12036
|
-
{ role: "system", content: prompt },
|
|
12037
|
-
{ role: "user", content: dataStr }
|
|
12038
|
-
],
|
|
12039
|
-
modelConfig: { temperature: 0.2 }
|
|
12040
|
-
});
|
|
12041
|
-
if (response instanceof ReadableStream) {
|
|
12042
|
-
return "";
|
|
12043
|
-
}
|
|
12044
|
-
return response.results[0]?.content || "";
|
|
12045
|
-
}
|
|
12046
|
-
/**
|
|
12047
|
-
* Generates a specific instruction candidate
|
|
12048
|
-
*/
|
|
12049
12931
|
async generateInstruction({
|
|
12050
|
-
programContext,
|
|
12051
|
-
dataContext,
|
|
12052
12932
|
tip,
|
|
12053
12933
|
candidateIndex
|
|
12054
12934
|
}) {
|
|
12055
|
-
const
|
|
12056
|
-
|
|
12057
|
-
|
|
12058
|
-
|
|
12059
|
-
|
|
12060
|
-
|
|
12061
|
-
|
|
12062
|
-
|
|
12063
|
-
|
|
12064
|
-
|
|
12065
|
-
${tip ? `STYLE TIP: ${tip}
|
|
12066
|
-
|
|
12067
|
-
` : ""}
|
|
12068
|
-
|
|
12069
|
-
Your task is to craft a clear, effective instruction that will help the AI model generate
|
|
12070
|
-
accurate outputs for this task. Instruction #${candidateIndex + 1}/${this.numCandidates}.
|
|
12071
|
-
|
|
12072
|
-
The instruction should be detailed enough to guide the model but not overly prescriptive
|
|
12073
|
-
or restrictive. Focus on what makes a good response rather than listing exact steps.
|
|
12074
|
-
|
|
12075
|
-
INSTRUCTION:`;
|
|
12076
|
-
const response = await this.ai.chat({
|
|
12077
|
-
chatPrompt: [{ role: "user", content: prompt }],
|
|
12078
|
-
modelConfig: { temperature: 0.7 + 0.1 * candidateIndex }
|
|
12079
|
-
});
|
|
12080
|
-
if (response instanceof ReadableStream) {
|
|
12081
|
-
return "";
|
|
12935
|
+
const baseInstructions = [
|
|
12936
|
+
"Analyze the input carefully and provide a detailed response.",
|
|
12937
|
+
"Think step by step and provide a clear answer.",
|
|
12938
|
+
"Consider all aspects of the input before responding.",
|
|
12939
|
+
"Provide a concise but comprehensive response.",
|
|
12940
|
+
"Focus on accuracy and clarity in your response."
|
|
12941
|
+
];
|
|
12942
|
+
let instruction = baseInstructions[candidateIndex % baseInstructions.length] || baseInstructions[0];
|
|
12943
|
+
if (tip) {
|
|
12944
|
+
instruction = `${instruction} ${tip}`;
|
|
12082
12945
|
}
|
|
12083
|
-
return
|
|
12946
|
+
return instruction;
|
|
12084
12947
|
}
|
|
12085
12948
|
/**
|
|
12086
12949
|
* Bootstraps few-shot examples for the program
|
|
12087
12950
|
*/
|
|
12088
|
-
async bootstrapFewShotExamples(metricFn) {
|
|
12089
|
-
if (this.
|
|
12090
|
-
|
|
12951
|
+
async bootstrapFewShotExamples(program, metricFn) {
|
|
12952
|
+
if (this.isLoggingEnabled()) {
|
|
12953
|
+
this.getLogger()?.("Bootstrapping few-shot examples...", {
|
|
12954
|
+
tags: ["optimizer", "phase"]
|
|
12955
|
+
});
|
|
12091
12956
|
}
|
|
12092
|
-
const
|
|
12957
|
+
const bootstrapper = new AxBootstrapFewShot({
|
|
12958
|
+
studentAI: this.studentAI,
|
|
12959
|
+
examples: this.examples,
|
|
12960
|
+
options: {
|
|
12961
|
+
maxDemos: this.maxBootstrappedDemos,
|
|
12962
|
+
maxRounds: 3,
|
|
12963
|
+
verboseMode: this.isLoggingEnabled()
|
|
12964
|
+
}
|
|
12965
|
+
});
|
|
12966
|
+
const result = await bootstrapper.compile(program, metricFn, {
|
|
12093
12967
|
maxDemos: this.maxBootstrappedDemos
|
|
12094
12968
|
});
|
|
12095
12969
|
return result.demos || [];
|
|
@@ -12113,109 +12987,111 @@ ${dataContext}
|
|
|
12113
12987
|
return selectedExamples;
|
|
12114
12988
|
}
|
|
12115
12989
|
/**
|
|
12116
|
-
* Runs
|
|
12990
|
+
* Runs optimization to find the best combination of few-shot examples and instructions
|
|
12117
12991
|
*/
|
|
12118
|
-
async
|
|
12119
|
-
let bestConfig =
|
|
12120
|
-
let bestScore = Number.NEGATIVE_INFINITY;
|
|
12121
|
-
const evaluatedConfigs = [];
|
|
12122
|
-
const defaultConfig = {
|
|
12992
|
+
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, valset, metricFn, options) {
|
|
12993
|
+
let bestConfig = {
|
|
12123
12994
|
instruction: instructions[0] || "",
|
|
12124
12995
|
bootstrappedDemos: Math.min(1, bootstrappedDemos.length),
|
|
12125
12996
|
labeledExamples: Math.min(1, labeledExamples.length)
|
|
12126
12997
|
};
|
|
12127
|
-
let
|
|
12128
|
-
let
|
|
12129
|
-
const
|
|
12130
|
-
|
|
12131
|
-
|
|
12132
|
-
|
|
12133
|
-
|
|
12134
|
-
|
|
12135
|
-
|
|
12136
|
-
|
|
12998
|
+
let bestScore = 0;
|
|
12999
|
+
let stagnationRounds = 0;
|
|
13000
|
+
const scoreHistory = [];
|
|
13001
|
+
let startRound = 0;
|
|
13002
|
+
if (this.resumeFromCheckpoint) {
|
|
13003
|
+
const checkpoint = await this.loadCheckpoint(
|
|
13004
|
+
this.resumeFromCheckpoint,
|
|
13005
|
+
options
|
|
13006
|
+
);
|
|
13007
|
+
if (checkpoint && checkpoint.optimizerType === "MiPRO") {
|
|
13008
|
+
if (this.isLoggingEnabled(options)) {
|
|
13009
|
+
this.getLogger(options)?.(
|
|
13010
|
+
`Resuming from checkpoint at round ${checkpoint.currentRound}`,
|
|
13011
|
+
{ tags: ["optimizer", "checkpoint"] }
|
|
13012
|
+
);
|
|
13013
|
+
}
|
|
13014
|
+
this.restoreFromCheckpoint(checkpoint);
|
|
13015
|
+
startRound = checkpoint.currentRound;
|
|
13016
|
+
bestScore = checkpoint.bestScore;
|
|
13017
|
+
bestConfig = checkpoint.bestConfiguration || bestConfig;
|
|
13018
|
+
stagnationRounds = checkpoint.stats.convergenceInfo?.stagnationRounds || 0;
|
|
13019
|
+
}
|
|
13020
|
+
}
|
|
13021
|
+
if (this.isLoggingEnabled(options)) {
|
|
13022
|
+
this.getLogger(options)?.(
|
|
13023
|
+
`Running optimization trials (${this.numTrials} total)`,
|
|
13024
|
+
{ tags: ["optimizer", "phase"] }
|
|
13025
|
+
);
|
|
13026
|
+
}
|
|
13027
|
+
for (let i = startRound; i < this.numTrials; i++) {
|
|
12137
13028
|
const config = {
|
|
12138
|
-
instruction:
|
|
12139
|
-
bootstrappedDemos: Math.
|
|
12140
|
-
Math.random() * (bootstrappedDemos.length + 1)
|
|
13029
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
13030
|
+
bootstrappedDemos: Math.min(
|
|
13031
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
13032
|
+
this.maxBootstrappedDemos
|
|
12141
13033
|
),
|
|
12142
|
-
labeledExamples: Math.
|
|
12143
|
-
Math.random() * (labeledExamples.length + 1)
|
|
13034
|
+
labeledExamples: Math.min(
|
|
13035
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
13036
|
+
this.maxLabeledDemos
|
|
12144
13037
|
)
|
|
12145
13038
|
};
|
|
12146
|
-
configs.push(config);
|
|
12147
|
-
}
|
|
12148
|
-
for (let i = 0; i < configs.length; i++) {
|
|
12149
|
-
const config = configs[i];
|
|
12150
|
-
if (!config) continue;
|
|
12151
13039
|
const score = await this.evaluateConfig(
|
|
13040
|
+
program,
|
|
12152
13041
|
config,
|
|
12153
13042
|
bootstrappedDemos,
|
|
12154
13043
|
labeledExamples,
|
|
12155
13044
|
valset,
|
|
12156
|
-
metricFn
|
|
12157
|
-
i
|
|
13045
|
+
metricFn
|
|
12158
13046
|
);
|
|
12159
|
-
|
|
12160
|
-
|
|
13047
|
+
scoreHistory.push(score);
|
|
13048
|
+
const improvement = score - bestScore;
|
|
13049
|
+
if (improvement > this.minImprovementThreshold) {
|
|
12161
13050
|
bestScore = score;
|
|
12162
13051
|
bestConfig = config;
|
|
12163
|
-
|
|
12164
|
-
|
|
12165
|
-
|
|
13052
|
+
stagnationRounds = 0;
|
|
13053
|
+
if (this.isLoggingEnabled(options)) {
|
|
13054
|
+
this.getLogger(options)?.(
|
|
13055
|
+
`Trial ${i + 1}/${this.numTrials}: New best score ${bestScore.toFixed(3)}`,
|
|
13056
|
+
{ tags: ["optimizer", "progress"] }
|
|
12166
13057
|
);
|
|
12167
13058
|
}
|
|
13059
|
+
} else {
|
|
13060
|
+
stagnationRounds++;
|
|
12168
13061
|
}
|
|
12169
|
-
|
|
13062
|
+
await this.updateOptimizationProgress(
|
|
12170
13063
|
i + 1,
|
|
12171
|
-
|
|
12172
|
-
|
|
12173
|
-
|
|
12174
|
-
|
|
12175
|
-
|
|
12176
|
-
|
|
12177
|
-
|
|
12178
|
-
|
|
12179
|
-
|
|
12180
|
-
|
|
12181
|
-
|
|
12182
|
-
|
|
12183
|
-
|
|
12184
|
-
);
|
|
12185
|
-
const score = await this.evaluateConfig(
|
|
12186
|
-
nextConfig,
|
|
12187
|
-
bootstrappedDemos,
|
|
12188
|
-
labeledExamples,
|
|
12189
|
-
valset,
|
|
12190
|
-
metricFn,
|
|
12191
|
-
i
|
|
13064
|
+
score,
|
|
13065
|
+
config,
|
|
13066
|
+
"MiPRO",
|
|
13067
|
+
this.getConfiguration(),
|
|
13068
|
+
bestScore,
|
|
13069
|
+
bestConfig,
|
|
13070
|
+
{
|
|
13071
|
+
stagnationRounds,
|
|
13072
|
+
bootstrappedDemos: bootstrappedDemos.length,
|
|
13073
|
+
labeledExamples: labeledExamples.length,
|
|
13074
|
+
instructions: instructions.length
|
|
13075
|
+
},
|
|
13076
|
+
options
|
|
12192
13077
|
);
|
|
12193
|
-
|
|
12194
|
-
|
|
12195
|
-
|
|
12196
|
-
|
|
12197
|
-
|
|
12198
|
-
|
|
12199
|
-
|
|
12200
|
-
)
|
|
12201
|
-
|
|
12202
|
-
|
|
12203
|
-
|
|
12204
|
-
|
|
12205
|
-
|
|
12206
|
-
|
|
12207
|
-
|
|
12208
|
-
if (this.verbose) {
|
|
12209
|
-
console.log(
|
|
12210
|
-
`Early stopping triggered after ${i + 1} trials. No improvement for ${trialsWithoutImprovement} trials.`
|
|
12211
|
-
);
|
|
12212
|
-
}
|
|
12213
|
-
break;
|
|
13078
|
+
if (this.onProgress) {
|
|
13079
|
+
this.onProgress({
|
|
13080
|
+
round: i + 1,
|
|
13081
|
+
totalRounds: this.numTrials,
|
|
13082
|
+
currentScore: score,
|
|
13083
|
+
bestScore,
|
|
13084
|
+
tokensUsed: this.stats.resourceUsage.totalTokens,
|
|
13085
|
+
timeElapsed: Date.now(),
|
|
13086
|
+
successfulExamples: this.stats.successfulDemos,
|
|
13087
|
+
totalExamples: this.examples.length,
|
|
13088
|
+
currentConfiguration: config,
|
|
13089
|
+
convergenceInfo: {
|
|
13090
|
+
improvement,
|
|
13091
|
+
stagnationRounds,
|
|
13092
|
+
isConverging: stagnationRounds < this.earlyStoppingTrials
|
|
12214
13093
|
}
|
|
12215
|
-
}
|
|
12216
|
-
lastBestScore = bestScore;
|
|
12217
|
-
trialsWithoutImprovement = 0;
|
|
12218
|
-
}
|
|
13094
|
+
});
|
|
12219
13095
|
}
|
|
12220
13096
|
updateProgressBar(
|
|
12221
13097
|
i + 1,
|
|
@@ -12225,290 +13101,309 @@ ${dataContext}
|
|
|
12225
13101
|
"Running MIPROv2 optimization",
|
|
12226
13102
|
30
|
|
12227
13103
|
);
|
|
12228
|
-
if (this.
|
|
12229
|
-
|
|
12230
|
-
|
|
12231
|
-
`Running full evaluation on best configuration at trial ${i + 1}`
|
|
12232
|
-
);
|
|
12233
|
-
}
|
|
12234
|
-
const fullScore = await this.fullEvaluation(
|
|
12235
|
-
bestConfig,
|
|
12236
|
-
bootstrappedDemos,
|
|
12237
|
-
labeledExamples,
|
|
12238
|
-
valset,
|
|
12239
|
-
metricFn
|
|
12240
|
-
);
|
|
12241
|
-
if (this.verbose) {
|
|
12242
|
-
console.log(`Full evaluation score: ${fullScore}`);
|
|
12243
|
-
}
|
|
12244
|
-
bestScore = fullScore;
|
|
13104
|
+
if (this.checkCostLimits()) {
|
|
13105
|
+
this.triggerEarlyStopping("Cost limit reached", i + 1);
|
|
13106
|
+
break;
|
|
12245
13107
|
}
|
|
12246
|
-
|
|
12247
|
-
|
|
12248
|
-
|
|
12249
|
-
|
|
12250
|
-
"Optimization failed to find any valid configurations, using default fallback configuration"
|
|
13108
|
+
if (stagnationRounds >= this.earlyStoppingTrials) {
|
|
13109
|
+
this.triggerEarlyStopping(
|
|
13110
|
+
`No improvement for ${this.earlyStoppingTrials} trials`,
|
|
13111
|
+
i - stagnationRounds + 1
|
|
12251
13112
|
);
|
|
13113
|
+
break;
|
|
12252
13114
|
}
|
|
12253
|
-
|
|
12254
|
-
|
|
12255
|
-
|
|
12256
|
-
|
|
12257
|
-
bootstrappedDemos,
|
|
12258
|
-
labeledExamples,
|
|
12259
|
-
valset,
|
|
12260
|
-
metricFn,
|
|
12261
|
-
this.numTrials - 1
|
|
13115
|
+
if (this.checkTargetScore(bestScore)) {
|
|
13116
|
+
this.triggerEarlyStopping(
|
|
13117
|
+
`Target score ${this.targetScore} reached`,
|
|
13118
|
+
i + 1
|
|
12262
13119
|
);
|
|
12263
|
-
|
|
12264
|
-
if (this.verbose) {
|
|
12265
|
-
console.error("Error evaluating default configuration:", err);
|
|
12266
|
-
}
|
|
12267
|
-
bestScore = 0;
|
|
13120
|
+
break;
|
|
12268
13121
|
}
|
|
12269
13122
|
}
|
|
13123
|
+
this.stats.convergenceInfo.stagnationRounds = stagnationRounds;
|
|
13124
|
+
this.stats.convergenceInfo.finalImprovement = scoreHistory.length > 1 ? bestScore - scoreHistory[0] : 0;
|
|
13125
|
+
this.stats.convergenceInfo.converged = stagnationRounds < this.earlyStoppingTrials;
|
|
12270
13126
|
return { bestConfig, bestScore };
|
|
12271
13127
|
}
|
|
12272
|
-
|
|
12273
|
-
|
|
12274
|
-
*/
|
|
12275
|
-
async evaluateConfig(config, bootstrappedDemos, labeledExamples, valset, metricFn, trialIndex) {
|
|
13128
|
+
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, valset, metricFn) {
|
|
13129
|
+
const testProgram = { ...program };
|
|
12276
13130
|
this.applyConfigToProgram(
|
|
12277
|
-
|
|
13131
|
+
testProgram,
|
|
12278
13132
|
config,
|
|
12279
13133
|
bootstrappedDemos,
|
|
12280
13134
|
labeledExamples
|
|
12281
13135
|
);
|
|
12282
|
-
let
|
|
12283
|
-
|
|
12284
|
-
|
|
12285
|
-
const minibatchEvalSet = [];
|
|
12286
|
-
for (let j = 0; j < this.minibatchSize; j++) {
|
|
12287
|
-
const idx = (startIdx + j) % valset.length;
|
|
12288
|
-
const example = valset[idx];
|
|
12289
|
-
if (example) {
|
|
12290
|
-
minibatchEvalSet.push(example);
|
|
12291
|
-
}
|
|
12292
|
-
}
|
|
12293
|
-
evalSet = minibatchEvalSet;
|
|
12294
|
-
}
|
|
12295
|
-
let sumOfScores = 0;
|
|
13136
|
+
let totalScore = 0;
|
|
13137
|
+
let count = 0;
|
|
13138
|
+
const evalSet = valset.slice(0, Math.min(5, valset.length));
|
|
12296
13139
|
for (const example of evalSet) {
|
|
12297
13140
|
try {
|
|
12298
|
-
const prediction = await
|
|
12299
|
-
|
|
12300
|
-
|
|
12301
|
-
|
|
12302
|
-
|
|
12303
|
-
|
|
12304
|
-
|
|
12305
|
-
|
|
12306
|
-
|
|
12307
|
-
|
|
12308
|
-
return sumOfScores / evalSet.length;
|
|
12309
|
-
}
|
|
12310
|
-
/**
|
|
12311
|
-
* Run full evaluation on the entire validation set
|
|
12312
|
-
*/
|
|
12313
|
-
async fullEvaluation(config, bootstrappedDemos, labeledExamples, valset, metricFn) {
|
|
12314
|
-
this.applyConfigToProgram(
|
|
12315
|
-
this.program,
|
|
12316
|
-
config,
|
|
12317
|
-
bootstrappedDemos,
|
|
12318
|
-
labeledExamples
|
|
12319
|
-
);
|
|
12320
|
-
let sumOfScores = 0;
|
|
12321
|
-
for (const example of valset) {
|
|
12322
|
-
try {
|
|
12323
|
-
const prediction = await this.program.forward(this.ai, example);
|
|
12324
|
-
const score = metricFn({ prediction, example });
|
|
12325
|
-
sumOfScores += score;
|
|
12326
|
-
} catch (err) {
|
|
12327
|
-
if (this.verbose) {
|
|
12328
|
-
console.error("Error evaluating example:", err);
|
|
12329
|
-
}
|
|
13141
|
+
const prediction = await testProgram.forward(
|
|
13142
|
+
this.studentAI,
|
|
13143
|
+
example
|
|
13144
|
+
);
|
|
13145
|
+
const score = await metricFn({ prediction, example });
|
|
13146
|
+
totalScore += score;
|
|
13147
|
+
count++;
|
|
13148
|
+
this.stats.totalCalls++;
|
|
13149
|
+
} catch {
|
|
13150
|
+
continue;
|
|
12330
13151
|
}
|
|
12331
13152
|
}
|
|
12332
|
-
|
|
12333
|
-
return sumOfScores / valset.length;
|
|
12334
|
-
}
|
|
12335
|
-
/**
|
|
12336
|
-
* Implements a Bayesian-inspired selection of the next configuration to try
|
|
12337
|
-
* This is a simplified version using Upper Confidence Bound (UCB) strategy
|
|
12338
|
-
*/
|
|
12339
|
-
selectNextConfiguration(evaluatedConfigs, maxBootstrappedDemos, maxLabeledExamples, instructions) {
|
|
12340
|
-
if (evaluatedConfigs.length < 5) {
|
|
12341
|
-
const instructionIndex = Math.floor(Math.random() * instructions.length);
|
|
12342
|
-
return {
|
|
12343
|
-
instruction: instructions[instructionIndex] || "",
|
|
12344
|
-
bootstrappedDemos: Math.floor(
|
|
12345
|
-
Math.random() * (maxBootstrappedDemos + 1)
|
|
12346
|
-
),
|
|
12347
|
-
labeledExamples: Math.floor(Math.random() * (maxLabeledExamples + 1))
|
|
12348
|
-
};
|
|
12349
|
-
}
|
|
12350
|
-
const sortedConfigs = [...evaluatedConfigs].sort(
|
|
12351
|
-
(a, b) => b.score - a.score
|
|
12352
|
-
);
|
|
12353
|
-
const topConfigs = sortedConfigs.slice(0, Math.min(3, sortedConfigs.length));
|
|
12354
|
-
const meanBootstrappedDemos = topConfigs.reduce((sum, c) => sum + c.config.bootstrappedDemos, 0) / topConfigs.length;
|
|
12355
|
-
const meanLabeledExamples = topConfigs.reduce((sum, c) => sum + c.config.labeledExamples, 0) / topConfigs.length;
|
|
12356
|
-
const popularInstructions = topConfigs.map((c) => c.config.instruction);
|
|
12357
|
-
const explorationFactor = Math.max(
|
|
12358
|
-
0.2,
|
|
12359
|
-
1 - evaluatedConfigs.length / this.numTrials
|
|
12360
|
-
);
|
|
12361
|
-
let newBootstrappedDemos;
|
|
12362
|
-
let newLabeledExamples;
|
|
12363
|
-
let newInstruction;
|
|
12364
|
-
if (Math.random() < 0.7) {
|
|
12365
|
-
newBootstrappedDemos = Math.min(
|
|
12366
|
-
maxBootstrappedDemos,
|
|
12367
|
-
Math.max(
|
|
12368
|
-
0,
|
|
12369
|
-
Math.round(
|
|
12370
|
-
meanBootstrappedDemos + (Math.random() * 2 - 1) * explorationFactor * 2
|
|
12371
|
-
)
|
|
12372
|
-
)
|
|
12373
|
-
);
|
|
12374
|
-
} else {
|
|
12375
|
-
newBootstrappedDemos = Math.floor(
|
|
12376
|
-
Math.random() * (maxBootstrappedDemos + 1)
|
|
12377
|
-
);
|
|
12378
|
-
}
|
|
12379
|
-
if (Math.random() < 0.7) {
|
|
12380
|
-
newLabeledExamples = Math.min(
|
|
12381
|
-
maxLabeledExamples,
|
|
12382
|
-
Math.max(
|
|
12383
|
-
0,
|
|
12384
|
-
Math.round(
|
|
12385
|
-
meanLabeledExamples + (Math.random() * 2 - 1) * explorationFactor * 2
|
|
12386
|
-
)
|
|
12387
|
-
)
|
|
12388
|
-
);
|
|
12389
|
-
} else {
|
|
12390
|
-
newLabeledExamples = Math.floor(Math.random() * (maxLabeledExamples + 1));
|
|
12391
|
-
}
|
|
12392
|
-
if (Math.random() < 0.7 && popularInstructions.length > 0) {
|
|
12393
|
-
const idx = Math.floor(Math.random() * popularInstructions.length);
|
|
12394
|
-
newInstruction = popularInstructions[idx] || "";
|
|
12395
|
-
} else {
|
|
12396
|
-
const idx = Math.floor(Math.random() * instructions.length);
|
|
12397
|
-
newInstruction = instructions[idx] || "";
|
|
12398
|
-
}
|
|
12399
|
-
return {
|
|
12400
|
-
instruction: newInstruction,
|
|
12401
|
-
bootstrappedDemos: newBootstrappedDemos,
|
|
12402
|
-
labeledExamples: newLabeledExamples
|
|
12403
|
-
};
|
|
13153
|
+
return count > 0 ? totalScore / count : 0;
|
|
12404
13154
|
}
|
|
12405
|
-
/**
|
|
12406
|
-
* Applies a configuration to a program instance
|
|
12407
|
-
*/
|
|
12408
13155
|
applyConfigToProgram(program, config, bootstrappedDemos, labeledExamples) {
|
|
12409
|
-
|
|
12410
|
-
|
|
13156
|
+
if (program.setInstruction) {
|
|
13157
|
+
program.setInstruction(config.instruction);
|
|
13158
|
+
}
|
|
13159
|
+
if (config.bootstrappedDemos > 0 && program.setDemos) {
|
|
12411
13160
|
program.setDemos(bootstrappedDemos.slice(0, config.bootstrappedDemos));
|
|
12412
13161
|
}
|
|
12413
|
-
if (config.labeledExamples > 0) {
|
|
13162
|
+
if (config.labeledExamples > 0 && program.setExamples) {
|
|
12414
13163
|
program.setExamples(labeledExamples.slice(0, config.labeledExamples));
|
|
12415
13164
|
}
|
|
12416
13165
|
}
|
|
12417
|
-
/**
|
|
12418
|
-
* Sets instruction to a program
|
|
12419
|
-
* Note: Workaround since setInstruction may not be available directly
|
|
12420
|
-
*/
|
|
12421
|
-
setInstructionToProgram(program, instruction) {
|
|
12422
|
-
const programWithInstruction = program;
|
|
12423
|
-
programWithInstruction.setInstruction?.(instruction);
|
|
12424
|
-
}
|
|
12425
13166
|
/**
|
|
12426
13167
|
* The main compile method to run MIPROv2 optimization
|
|
12427
|
-
* @param metricFn Evaluation metric function
|
|
12428
|
-
* @param options Optional configuration options
|
|
12429
|
-
* @returns The optimization result
|
|
12430
13168
|
*/
|
|
12431
|
-
async compile(metricFn, options) {
|
|
13169
|
+
async compile(program, metricFn, options) {
|
|
13170
|
+
const startTime = Date.now();
|
|
13171
|
+
this.setupRandomSeed();
|
|
12432
13172
|
const miproOptions = options;
|
|
12433
13173
|
if (miproOptions?.auto) {
|
|
12434
13174
|
this.configureAuto(miproOptions.auto);
|
|
12435
13175
|
}
|
|
12436
|
-
const
|
|
12437
|
-
|
|
12438
|
-
|
|
12439
|
-
|
|
12440
|
-
|
|
12441
|
-
`Using ${trainset.length} examples for training and ${valset.length} for validation`
|
|
13176
|
+
const valset = this.getValidationSet(options) || (miproOptions?.valset ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
|
|
13177
|
+
if (this.isLoggingEnabled(options)) {
|
|
13178
|
+
this.getLogger(options)?.(
|
|
13179
|
+
`Starting MIPROv2 optimization with ${this.numTrials} trials`,
|
|
13180
|
+
{ tags: ["optimizer", "start"] }
|
|
12442
13181
|
);
|
|
12443
|
-
|
|
12444
|
-
|
|
12445
|
-
|
|
12446
|
-
|
|
13182
|
+
this.getLogger(options)?.(
|
|
13183
|
+
`Using ${this.examples.length} examples for training and ${valset.length} for validation`,
|
|
13184
|
+
{ tags: ["optimizer", "config"] }
|
|
13185
|
+
);
|
|
13186
|
+
if (this.teacherAI) {
|
|
13187
|
+
this.getLogger(options)?.(
|
|
13188
|
+
"Using separate teacher model for instruction generation",
|
|
13189
|
+
{ tags: ["optimizer", "config"] }
|
|
13190
|
+
);
|
|
12447
13191
|
}
|
|
12448
|
-
const bootstrapperWithTeacher = new AxBootstrapFewShot({
|
|
12449
|
-
ai: this.ai,
|
|
12450
|
-
program: this.program,
|
|
12451
|
-
examples: this.examples,
|
|
12452
|
-
options: {
|
|
12453
|
-
maxDemos: this.maxBootstrappedDemos,
|
|
12454
|
-
maxRounds: 3,
|
|
12455
|
-
verboseMode: this.verbose,
|
|
12456
|
-
teacherAI: this.ai
|
|
12457
|
-
// Use the same AI but with the teacher program
|
|
12458
|
-
}
|
|
12459
|
-
});
|
|
12460
|
-
this.bootstrapper = bootstrapperWithTeacher;
|
|
12461
13192
|
}
|
|
12462
13193
|
let bootstrappedDemos = [];
|
|
12463
13194
|
if (this.maxBootstrappedDemos > 0) {
|
|
12464
|
-
bootstrappedDemos = await this.bootstrapFewShotExamples(metricFn);
|
|
12465
|
-
if (this.
|
|
12466
|
-
|
|
12467
|
-
`Generated ${bootstrappedDemos.length} bootstrapped demonstrations
|
|
13195
|
+
bootstrappedDemos = await this.bootstrapFewShotExamples(program, metricFn);
|
|
13196
|
+
if (this.isLoggingEnabled(options)) {
|
|
13197
|
+
this.getLogger(options)?.(
|
|
13198
|
+
`Generated ${bootstrappedDemos.length} bootstrapped demonstrations`,
|
|
13199
|
+
{ tags: ["optimizer", "result"] }
|
|
12468
13200
|
);
|
|
12469
13201
|
}
|
|
12470
13202
|
}
|
|
12471
13203
|
let labeledExamples = [];
|
|
12472
13204
|
if (this.maxLabeledDemos > 0) {
|
|
12473
13205
|
labeledExamples = this.selectLabeledExamples();
|
|
12474
|
-
if (this.
|
|
12475
|
-
|
|
12476
|
-
`Selected ${labeledExamples.length} labeled examples from training set
|
|
13206
|
+
if (this.isLoggingEnabled(options)) {
|
|
13207
|
+
this.getLogger(options)?.(
|
|
13208
|
+
`Selected ${labeledExamples.length} labeled examples from training set`,
|
|
13209
|
+
{ tags: ["optimizer", "result"] }
|
|
12477
13210
|
);
|
|
12478
13211
|
}
|
|
12479
13212
|
}
|
|
12480
|
-
const instructions = await this.proposeInstructionCandidates();
|
|
12481
|
-
if (this.
|
|
12482
|
-
|
|
13213
|
+
const instructions = await this.proposeInstructionCandidates(options);
|
|
13214
|
+
if (this.isLoggingEnabled(options)) {
|
|
13215
|
+
this.getLogger(options)?.(
|
|
13216
|
+
`Generated ${instructions.length} instruction candidates`,
|
|
13217
|
+
{ tags: ["optimizer", "result"] }
|
|
13218
|
+
);
|
|
13219
|
+
if (this.hasTeacherAI(options)) {
|
|
13220
|
+
this.getLogger(options)?.(
|
|
13221
|
+
"Using teacher AI for instruction generation",
|
|
13222
|
+
{ tags: ["optimizer", "config"] }
|
|
13223
|
+
);
|
|
13224
|
+
}
|
|
12483
13225
|
}
|
|
12484
|
-
const { bestConfig, bestScore } = await this.
|
|
13226
|
+
const { bestConfig, bestScore } = await this.runOptimization(
|
|
13227
|
+
program,
|
|
12485
13228
|
bootstrappedDemos,
|
|
12486
13229
|
labeledExamples,
|
|
12487
13230
|
instructions,
|
|
12488
13231
|
valset,
|
|
12489
|
-
metricFn
|
|
13232
|
+
metricFn,
|
|
13233
|
+
options
|
|
12490
13234
|
);
|
|
12491
|
-
if (this.
|
|
12492
|
-
|
|
12493
|
-
|
|
13235
|
+
if (this.isLoggingEnabled(options)) {
|
|
13236
|
+
this.getLogger(options)?.(
|
|
13237
|
+
`Optimization complete. Best score: ${bestScore}`,
|
|
13238
|
+
{ tags: ["optimizer", "complete"] }
|
|
13239
|
+
);
|
|
13240
|
+
this.getLogger(options)?.(
|
|
13241
|
+
`Best configuration: ${JSON.stringify(bestConfig)}`,
|
|
13242
|
+
{ tags: ["optimizer", "result"] }
|
|
13243
|
+
);
|
|
12494
13244
|
}
|
|
12495
|
-
this.
|
|
12496
|
-
this.
|
|
13245
|
+
if (this.checkTargetScore(bestScore)) {
|
|
13246
|
+
this.triggerEarlyStopping(
|
|
13247
|
+
`Target score ${this.targetScore} reached with score ${bestScore}`,
|
|
13248
|
+
this.numTrials
|
|
13249
|
+
);
|
|
13250
|
+
}
|
|
13251
|
+
let signature;
|
|
13252
|
+
if ("getSignature" in program && typeof program.getSignature === "function") {
|
|
13253
|
+
signature = program.getSignature();
|
|
13254
|
+
} else {
|
|
13255
|
+
signature = "input -> output";
|
|
13256
|
+
}
|
|
13257
|
+
const optimizedGen = new AxGen(signature);
|
|
13258
|
+
this.applyConfigToAxGen(
|
|
13259
|
+
optimizedGen,
|
|
12497
13260
|
bestConfig,
|
|
12498
13261
|
bootstrappedDemos,
|
|
12499
13262
|
labeledExamples
|
|
12500
13263
|
);
|
|
13264
|
+
this.updateResourceUsage(startTime);
|
|
13265
|
+
this.stats.convergenceInfo.converged = true;
|
|
13266
|
+
this.stats.convergenceInfo.finalImprovement = bestScore;
|
|
13267
|
+
await this.saveFinalCheckpoint(
|
|
13268
|
+
"MiPRO",
|
|
13269
|
+
this.getConfiguration(),
|
|
13270
|
+
bestScore,
|
|
13271
|
+
bestConfig,
|
|
13272
|
+
{
|
|
13273
|
+
bootstrappedDemos: bootstrappedDemos.length,
|
|
13274
|
+
labeledExamples: labeledExamples.length,
|
|
13275
|
+
instructions: instructions.length,
|
|
13276
|
+
optimizedGen: !!optimizedGen
|
|
13277
|
+
},
|
|
13278
|
+
options
|
|
13279
|
+
);
|
|
12501
13280
|
return {
|
|
12502
|
-
|
|
12503
|
-
|
|
13281
|
+
demos: bootstrappedDemos,
|
|
13282
|
+
stats: this.stats,
|
|
13283
|
+
bestScore,
|
|
13284
|
+
optimizedGen,
|
|
13285
|
+
finalConfiguration: {
|
|
13286
|
+
instruction: bestConfig.instruction,
|
|
13287
|
+
bootstrappedDemos: bestConfig.bootstrappedDemos,
|
|
13288
|
+
labeledExamples: bestConfig.labeledExamples,
|
|
13289
|
+
numCandidates: this.numCandidates,
|
|
13290
|
+
numTrials: this.numTrials
|
|
13291
|
+
}
|
|
12504
13292
|
};
|
|
12505
13293
|
}
|
|
12506
13294
|
/**
|
|
12507
|
-
*
|
|
12508
|
-
* @returns Optimization statistics or undefined if not available
|
|
13295
|
+
* Applies a configuration to an AxGen instance
|
|
12509
13296
|
*/
|
|
12510
|
-
|
|
12511
|
-
|
|
13297
|
+
applyConfigToAxGen(axgen, config, bootstrappedDemos, labeledExamples) {
|
|
13298
|
+
if ("setInstruction" in axgen && typeof axgen.setInstruction === "function") {
|
|
13299
|
+
axgen.setInstruction(config.instruction);
|
|
13300
|
+
}
|
|
13301
|
+
if (config.bootstrappedDemos > 0) {
|
|
13302
|
+
axgen.setDemos(bootstrappedDemos.slice(0, config.bootstrappedDemos));
|
|
13303
|
+
}
|
|
13304
|
+
if (config.labeledExamples > 0) {
|
|
13305
|
+
axgen.setExamples(
|
|
13306
|
+
labeledExamples.slice(
|
|
13307
|
+
0,
|
|
13308
|
+
config.labeledExamples
|
|
13309
|
+
)
|
|
13310
|
+
);
|
|
13311
|
+
}
|
|
13312
|
+
}
|
|
13313
|
+
/**
|
|
13314
|
+
* Get optimizer-specific configuration
|
|
13315
|
+
* @returns Current optimizer configuration
|
|
13316
|
+
*/
|
|
13317
|
+
getConfiguration() {
|
|
13318
|
+
return {
|
|
13319
|
+
numCandidates: this.numCandidates,
|
|
13320
|
+
initTemperature: this.initTemperature,
|
|
13321
|
+
maxBootstrappedDemos: this.maxBootstrappedDemos,
|
|
13322
|
+
maxLabeledDemos: this.maxLabeledDemos,
|
|
13323
|
+
numTrials: this.numTrials,
|
|
13324
|
+
minibatch: this.minibatch,
|
|
13325
|
+
minibatchSize: this.minibatchSize,
|
|
13326
|
+
minibatchFullEvalSteps: this.minibatchFullEvalSteps,
|
|
13327
|
+
programAwareProposer: this.programAwareProposer,
|
|
13328
|
+
dataAwareProposer: this.dataAwareProposer,
|
|
13329
|
+
tipAwareProposer: this.tipAwareProposer,
|
|
13330
|
+
fewshotAwareProposer: this.fewshotAwareProposer,
|
|
13331
|
+
earlyStoppingTrials: this.earlyStoppingTrials,
|
|
13332
|
+
minImprovementThreshold: this.minImprovementThreshold,
|
|
13333
|
+
bayesianOptimization: this.bayesianOptimization,
|
|
13334
|
+
acquisitionFunction: this.acquisitionFunction,
|
|
13335
|
+
explorationWeight: this.explorationWeight
|
|
13336
|
+
};
|
|
13337
|
+
}
|
|
13338
|
+
/**
|
|
13339
|
+
* Update optimizer configuration
|
|
13340
|
+
* @param config New configuration to merge with existing
|
|
13341
|
+
*/
|
|
13342
|
+
updateConfiguration(config) {
|
|
13343
|
+
if (config.numCandidates !== void 0) {
|
|
13344
|
+
this.numCandidates = config.numCandidates;
|
|
13345
|
+
}
|
|
13346
|
+
if (config.initTemperature !== void 0) {
|
|
13347
|
+
this.initTemperature = config.initTemperature;
|
|
13348
|
+
}
|
|
13349
|
+
if (config.maxBootstrappedDemos !== void 0) {
|
|
13350
|
+
this.maxBootstrappedDemos = config.maxBootstrappedDemos;
|
|
13351
|
+
}
|
|
13352
|
+
if (config.maxLabeledDemos !== void 0) {
|
|
13353
|
+
this.maxLabeledDemos = config.maxLabeledDemos;
|
|
13354
|
+
}
|
|
13355
|
+
if (config.numTrials !== void 0) {
|
|
13356
|
+
this.numTrials = config.numTrials;
|
|
13357
|
+
}
|
|
13358
|
+
if (config.minibatch !== void 0) {
|
|
13359
|
+
this.minibatch = config.minibatch;
|
|
13360
|
+
}
|
|
13361
|
+
if (config.minibatchSize !== void 0) {
|
|
13362
|
+
this.minibatchSize = config.minibatchSize;
|
|
13363
|
+
}
|
|
13364
|
+
if (config.earlyStoppingTrials !== void 0) {
|
|
13365
|
+
this.earlyStoppingTrials = config.earlyStoppingTrials;
|
|
13366
|
+
}
|
|
13367
|
+
if (config.minImprovementThreshold !== void 0) {
|
|
13368
|
+
this.minImprovementThreshold = config.minImprovementThreshold;
|
|
13369
|
+
}
|
|
13370
|
+
}
|
|
13371
|
+
/**
|
|
13372
|
+
* Reset optimizer state for reuse with different programs
|
|
13373
|
+
*/
|
|
13374
|
+
reset() {
|
|
13375
|
+
super.reset();
|
|
13376
|
+
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
13377
|
+
}
|
|
13378
|
+
/**
|
|
13379
|
+
* Validate that the optimizer can handle the given program
|
|
13380
|
+
* @param program Program to validate
|
|
13381
|
+
* @returns Validation result with any issues found
|
|
13382
|
+
*/
|
|
13383
|
+
validateProgram(program) {
|
|
13384
|
+
const result = super.validateProgram(program);
|
|
13385
|
+
if (this.examples.length < this.maxBootstrappedDemos + this.maxLabeledDemos) {
|
|
13386
|
+
result.issues.push(
|
|
13387
|
+
`Not enough examples: need at least ${this.maxBootstrappedDemos + this.maxLabeledDemos}, got ${this.examples.length}`
|
|
13388
|
+
);
|
|
13389
|
+
result.suggestions.push(
|
|
13390
|
+
"Reduce maxBootstrappedDemos or maxLabeledDemos, or provide more examples"
|
|
13391
|
+
);
|
|
13392
|
+
}
|
|
13393
|
+
const valSetSize = this.getValidationSet().length;
|
|
13394
|
+
if (valSetSize < 5) {
|
|
13395
|
+
result.issues.push(
|
|
13396
|
+
"Validation set too small for reliable MiPRO optimization"
|
|
13397
|
+
);
|
|
13398
|
+
result.suggestions.push(
|
|
13399
|
+
"Provide more examples or a larger validation set"
|
|
13400
|
+
);
|
|
13401
|
+
}
|
|
13402
|
+
return {
|
|
13403
|
+
isValid: result.issues.length === 0,
|
|
13404
|
+
issues: result.issues,
|
|
13405
|
+
suggestions: result.suggestions
|
|
13406
|
+
};
|
|
12512
13407
|
}
|
|
12513
13408
|
};
|
|
12514
13409
|
|
|
@@ -12755,7 +13650,7 @@ var AxTestPrompt = class {
|
|
|
12755
13650
|
throw new Error("Invalid example");
|
|
12756
13651
|
}
|
|
12757
13652
|
const res = await this.program.forward(this.ai, ex);
|
|
12758
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
13653
|
+
const score = await metricFn({ prediction: res, example: ex });
|
|
12759
13654
|
sumOfScores += score;
|
|
12760
13655
|
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
12761
13656
|
updateProgressBar(i, total, sumOfScores, et, "Testing Prompt", 30);
|
|
@@ -14789,7 +15684,6 @@ var AxRAG = class extends AxChainOfThought {
|
|
|
14789
15684
|
);
|
|
14790
15685
|
this.genQuery = new AxGen(qsig);
|
|
14791
15686
|
this.queryFn = queryFn;
|
|
14792
|
-
this.register(this.genQuery);
|
|
14793
15687
|
}
|
|
14794
15688
|
async forward(ai, values, options) {
|
|
14795
15689
|
let question;
|
|
@@ -14867,6 +15761,7 @@ var AxRAG = class extends AxChainOfThought {
|
|
|
14867
15761
|
AxAssertionError,
|
|
14868
15762
|
AxBalancer,
|
|
14869
15763
|
AxBaseAI,
|
|
15764
|
+
AxBaseOptimizer,
|
|
14870
15765
|
AxBootstrapFewShot,
|
|
14871
15766
|
AxChainOfThought,
|
|
14872
15767
|
AxDB,
|
|
@@ -14876,6 +15771,7 @@ var AxRAG = class extends AxChainOfThought {
|
|
|
14876
15771
|
AxDBMemory,
|
|
14877
15772
|
AxDBPinecone,
|
|
14878
15773
|
AxDBWeaviate,
|
|
15774
|
+
AxDefaultCostTracker,
|
|
14879
15775
|
AxDefaultQueryRewriter,
|
|
14880
15776
|
AxDefaultResultReranker,
|
|
14881
15777
|
AxDockerSession,
|
|
@@ -14945,6 +15841,10 @@ var AxRAG = class extends AxChainOfThought {
|
|
|
14945
15841
|
axAITogetherDefaultConfig,
|
|
14946
15842
|
axBaseAIDefaultConfig,
|
|
14947
15843
|
axBaseAIDefaultCreativeConfig,
|
|
15844
|
+
axCreateDefaultLogger,
|
|
15845
|
+
axCreateDefaultTextLogger,
|
|
15846
|
+
axCreateOptimizerLogger,
|
|
15847
|
+
axDefaultOptimizerLogger,
|
|
14948
15848
|
axGlobals,
|
|
14949
15849
|
axModelInfoAnthropic,
|
|
14950
15850
|
axModelInfoCohere,
|