@ax-llm/ax 12.0.7 → 12.0.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +1588 -688
- package/index.cjs.map +1 -1
- package/index.d.cts +536 -191
- package/index.d.ts +536 -191
- package/index.js +1586 -692
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.js
CHANGED
|
@@ -69,7 +69,7 @@ var AxSpanKindValues = /* @__PURE__ */ ((AxSpanKindValues2) => {
|
|
|
69
69
|
// util/apicall.ts
|
|
70
70
|
import crypto from "crypto";
|
|
71
71
|
import {
|
|
72
|
-
ReadableStream
|
|
72
|
+
ReadableStream,
|
|
73
73
|
TextDecoderStream as TextDecoderStreamNative,
|
|
74
74
|
TransformStream as TransformStream3
|
|
75
75
|
} from "stream/web";
|
|
@@ -320,6 +320,17 @@ var AxAIServiceAuthenticationError = class extends AxAIServiceError {
|
|
|
320
320
|
this.name = this.constructor.name;
|
|
321
321
|
}
|
|
322
322
|
};
|
|
323
|
+
/**
 * Best-effort read of an HTTP response body so it can be attached to an error.
 *
 * Never throws: any failure while reading is reported as a placeholder string.
 *
 * @param {Response} response - Fetch-style response (needs `headers.get`, `text`, `clone`).
 * @returns {Promise<unknown>} The parsed JSON value when the content type is
 *   `application/json` and the body parses; otherwise the body text; or
 *   `[ReadableStream - read failed: <reason>]` when the body cannot be read.
 */
async function safeReadResponseBody(response) {
  try {
    const isJson = response.headers.get("content-type")?.includes("application/json");
    if (!isJson) {
      // Read from a clone so the original response body is not consumed.
      return await response.clone().text();
    }
    // Read the text once and parse it ourselves: calling response.json() directly
    // would consume the stream and discard the payload whenever parsing fails.
    const raw = await response.text();
    try {
      return JSON.parse(raw);
    } catch {
      // Error responses often advertise JSON but carry an empty, truncated or
      // HTML body; the raw text is more useful to the caller than a parse error.
      return raw;
    }
  } catch (e) {
    const reason = e instanceof Error ? e.message : String(e);
    return `[ReadableStream - read failed: ${reason}]`;
  }
}
|
|
323
334
|
function calculateRetryDelay(attempt, config) {
|
|
324
335
|
const delay = Math.min(
|
|
325
336
|
config.maxDelayMs,
|
|
@@ -413,9 +424,15 @@ var apiCall = async (api, json) => {
|
|
|
413
424
|
});
|
|
414
425
|
clearTimeout(timeoutId);
|
|
415
426
|
if (res.status === 401 || res.status === 403) {
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
427
|
+
const responseBody = await safeReadResponseBody(res);
|
|
428
|
+
throw new AxAIServiceAuthenticationError(
|
|
429
|
+
apiUrl.href,
|
|
430
|
+
json,
|
|
431
|
+
responseBody,
|
|
432
|
+
{
|
|
433
|
+
metrics
|
|
434
|
+
}
|
|
435
|
+
);
|
|
419
436
|
}
|
|
420
437
|
if (res.status >= 400 && shouldRetry(new Error(), res.status, attempt, retryConfig)) {
|
|
421
438
|
const delay = calculateRetryDelay(attempt, retryConfig);
|
|
@@ -433,12 +450,13 @@ var apiCall = async (api, json) => {
|
|
|
433
450
|
continue;
|
|
434
451
|
}
|
|
435
452
|
if (res.status >= 400) {
|
|
453
|
+
const responseBody = await safeReadResponseBody(res);
|
|
436
454
|
throw new AxAIServiceStatusError(
|
|
437
455
|
res.status,
|
|
438
456
|
res.statusText,
|
|
439
457
|
apiUrl.href,
|
|
440
458
|
json,
|
|
441
|
-
|
|
459
|
+
responseBody,
|
|
442
460
|
{ metrics }
|
|
443
461
|
);
|
|
444
462
|
}
|
|
@@ -486,7 +504,7 @@ var apiCall = async (api, json) => {
|
|
|
486
504
|
}
|
|
487
505
|
});
|
|
488
506
|
let closed = false;
|
|
489
|
-
return new
|
|
507
|
+
return new ReadableStream({
|
|
490
508
|
start(controller) {
|
|
491
509
|
const reader = res.body.pipeThrough(new textDecoderStream()).pipeThrough(new SSEParser()).pipeThrough(trackingStream).getReader();
|
|
492
510
|
async function read() {
|
|
@@ -536,7 +554,7 @@ var apiCall = async (api, json) => {
|
|
|
536
554
|
error,
|
|
537
555
|
apiUrl.href,
|
|
538
556
|
json,
|
|
539
|
-
|
|
557
|
+
"[ReadableStream - consumed during streaming]",
|
|
540
558
|
{
|
|
541
559
|
streamMetrics
|
|
542
560
|
}
|
|
@@ -667,12 +685,12 @@ var ColorLog = class {
|
|
|
667
685
|
}
|
|
668
686
|
};
|
|
669
687
|
|
|
670
|
-
//
|
|
688
|
+
// dsp/loggers.ts
|
|
671
689
|
var colorLog = new ColorLog();
|
|
672
690
|
var defaultOutput = (message) => {
|
|
673
691
|
process.stdout.write(message);
|
|
674
692
|
};
|
|
675
|
-
var
|
|
693
|
+
var axCreateDefaultLogger = (output = defaultOutput) => {
|
|
676
694
|
return (message, options) => {
|
|
677
695
|
const tags = options?.tags ?? [];
|
|
678
696
|
let formattedMessage = message;
|
|
@@ -681,12 +699,44 @@ var createDefaultLogger = (output = defaultOutput) => {
|
|
|
681
699
|
} else if (tags.includes("success") || tags.includes("responseContent")) {
|
|
682
700
|
formattedMessage = colorLog.greenBright(formattedMessage);
|
|
683
701
|
} else if (tags.includes("functionName")) {
|
|
684
|
-
|
|
685
|
-
|
|
702
|
+
if (tags.includes("firstFunction")) {
|
|
703
|
+
formattedMessage = `
|
|
704
|
+
${colorLog.whiteBright(formattedMessage)}`;
|
|
705
|
+
} else {
|
|
706
|
+
formattedMessage = `${colorLog.whiteBright(formattedMessage)}`;
|
|
707
|
+
}
|
|
708
|
+
} else if (tags.includes("systemContent") || tags.includes("assistantContent")) {
|
|
686
709
|
formattedMessage = colorLog.blueBright(formattedMessage);
|
|
687
710
|
} else if (tags.includes("warning") || tags.includes("discovery")) {
|
|
688
711
|
formattedMessage = colorLog.yellow(formattedMessage);
|
|
712
|
+
} else if (tags.includes("functionArg")) {
|
|
713
|
+
formattedMessage = "";
|
|
714
|
+
}
|
|
715
|
+
if (tags.includes("responseStart") || tags.includes("systemStart") || tags.includes("userStart")) {
|
|
716
|
+
formattedMessage = `
|
|
717
|
+
${formattedMessage}`;
|
|
718
|
+
} else if (tags.includes("responseEnd") || tags.includes("systemEnd") || tags.includes("userEnd")) {
|
|
719
|
+
formattedMessage = `${formattedMessage}
|
|
720
|
+
`;
|
|
721
|
+
} else if (tags.includes("assistantStart")) {
|
|
722
|
+
formattedMessage = `
|
|
723
|
+
${formattedMessage}
|
|
724
|
+
`;
|
|
725
|
+
} else if (tags.includes("error")) {
|
|
726
|
+
formattedMessage = `
|
|
727
|
+
${formattedMessage}
|
|
728
|
+
`;
|
|
729
|
+
} else if (tags.includes("functionEnd")) {
|
|
730
|
+
formattedMessage = `
|
|
731
|
+
`;
|
|
689
732
|
}
|
|
733
|
+
output(formattedMessage);
|
|
734
|
+
};
|
|
735
|
+
};
|
|
736
|
+
var axCreateDefaultTextLogger = (output = defaultOutput) => {
|
|
737
|
+
return (message, options) => {
|
|
738
|
+
const tags = options?.tags ?? [];
|
|
739
|
+
let formattedMessage = message;
|
|
690
740
|
if (tags.includes("responseStart") || tags.includes("systemStart") || tags.includes("userStart")) {
|
|
691
741
|
formattedMessage = `
|
|
692
742
|
${formattedMessage}`;
|
|
@@ -708,7 +758,137 @@ ${formattedMessage}
|
|
|
708
758
|
output(formattedMessage);
|
|
709
759
|
};
|
|
710
760
|
};
|
|
711
|
-
var
|
|
761
|
+
/**
 * Creates a logger that rewrites optimizer log messages into a compact,
 * tree-style layout and then hands them to the default logger.
 *
 * Messages tagged `optimizer` are reformatted according to their secondary tag
 * (`start`, `config`, `phase`, `result`, `progress`, `complete`, `checkpoint`).
 * Messages tagged `discovery` (without `optimizer`) get a short "Found N examples"
 * form. `error`, `warning` and non-optimizer `success` tags override any of the
 * above. Anything else is passed through unchanged.
 *
 * @param {(msg: string) => void} [output] - Sink for the final text; defaults to stdout.
 * @returns {(message: string, options?: { tags?: string[] }) => void} The logger function.
 */
var axCreateOptimizerLogger = (output = (msg) => process.stdout.write(msg)) => {
  const baseLogger = axCreateDefaultLogger(output);

  // Header line for the start of an optimization run.
  const formatStart = (message) => {
    const trials = message.match(/with (\d+) trials?/) ?? message.match(/(\d+) trials?/);
    // Look for the known optimizer names first. A single alternation with the
    // generic capitalised-word pattern is matched leftmost-first, so it would
    // pick up an earlier word such as "Starting" instead of the optimizer name.
    const named = message.match(/MIPROv2|BootstrapFewshot/) ?? message.match(/[A-Z][a-zA-Z]+/);
    const optimizerName = named ? named[0] : "Optimizer";
    const suffix = trials?.[1] ? ` (${trials[1]} trials)` : "";
    return `\n\u250C\u2500 ${optimizerName} optimization${suffix}\n`;
  };

  // Dataset / teacher-model configuration lines.
  const formatConfig = (message) => {
    if (message.includes("examples") && message.includes("training")) {
      const split = message.match(/(\d+) examples for training and (\d+) for validation/) ?? message.match(/(\d+) training.*?(\d+) validation/);
      if (split?.[1] && split[2]) {
        return `\u2502 Dataset: ${split[1]} training, ${split[2]} validation\n`;
      }
      const total = message.match(/(\d+) examples/);
      // No recognisable counts: leave the message exactly as it was received.
      return total?.[1] ? `\u2502 Dataset: ${total[1]} examples\n` : message;
    }
    if (message.includes("teacher")) {
      return "\u2502 Using teacher model\n";
    }
    return `\u2502 ${message}\n`;
  };

  // Intermediate result lines.
  const formatResult = (message) => {
    if (message.includes("Generated") || message.includes("Selected")) {
      return `\u2502 \u2713 ${message}\n`;
    }
    if (message.includes("configuration")) {
      return "\u2502 Applied best configuration\n";
    }
    return `\u2502 ${message}\n`;
  };

  // Completion line; scores in [0, 1] are shown as a percentage.
  const formatComplete = (message) => {
    const scoreMatch = message.match(/(score|performance):\s*([\d.]+)/);
    const score = scoreMatch?.[2] ? Number.parseFloat(scoreMatch[2]) : Number.NaN;
    // The capture also accepts strings like "." that are not numbers; treat
    // those as "no score" rather than printing NaN.
    if (Number.isFinite(score)) {
      const best = score <= 1 ? `${(score * 100).toFixed(1)}%` : score.toFixed(3);
      return `\u251C\u2500 Complete! Best: ${best}\n`;
    }
    if (message.includes("Bootstrap")) {
      return `\u251C\u2500 ${message}\n`;
    }
    return "\u251C\u2500 Optimization complete\n";
  };

  // Checkpoint resume / save lines.
  const formatCheckpoint = (message) => {
    if (message.includes("Resuming")) {
      return `\u2502 ${message}\n`;
    }
    const saved = message.match(/checkpoint:\s*(.+)/) ?? message.match(/Saved\s+(.+)/);
    return saved?.[1] ? `\u2514\u2500 Saved: ${saved[1]}\n` : "\u2514\u2500 Checkpoint saved\n";
  };

  // Dispatch on the secondary tag of an `optimizer` message; first match wins.
  const formatOptimizer = (message, tags) => {
    if (tags.includes("start")) return formatStart(message);
    if (tags.includes("config")) return formatConfig(message);
    if (tags.includes("phase")) return `\u251C\u2500 ${message}\n`;
    if (tags.includes("result")) return formatResult(message);
    if (tags.includes("progress")) return `\u2502 ${message}\n`;
    if (tags.includes("complete")) return formatComplete(message);
    if (tags.includes("checkpoint")) return formatCheckpoint(message);
    return message;
  };

  // Example-discovery lines emitted outside the optimizer itself.
  const formatDiscovery = (message) => {
    if (message.includes("Found") && message.includes("examples")) {
      const found = message.match(/Found (\d+)/);
      if (found?.[1]) {
        return `\u2502 Found ${found[1]} examples\n`;
      }
    }
    return message;
  };

  return (message, options) => {
    const tags = options?.tags ?? [];
    let formattedMessage = message;
    if (tags.includes("optimizer")) {
      formattedMessage = formatOptimizer(message, tags);
    } else if (tags.includes("discovery")) {
      formattedMessage = formatDiscovery(message);
    }
    // Severity tags take precedence over any formatting chosen above.
    if (tags.includes("error")) {
      formattedMessage = `\n\u2717 ${message}\n`;
    } else if (tags.includes("warning")) {
      formattedMessage = `\n\u26A0 ${message}\n`;
    } else if (tags.includes("success") && !tags.includes("optimizer")) {
      formattedMessage = `\u2713 ${message}\n`;
    }
    baseLogger(formattedMessage, options);
  };
};
|
|
888
|
+
var axDefaultOptimizerLogger = axCreateOptimizerLogger();
|
|
889
|
+
|
|
890
|
+
// ai/debug.ts
|
|
891
|
+
var defaultLogger = axCreateDefaultLogger();
|
|
712
892
|
var formatChatMessage = (msg, hideContent, hideSystemPrompt) => {
|
|
713
893
|
switch (msg.role) {
|
|
714
894
|
case "system":
|
|
@@ -785,9 +965,14 @@ var logResponseResult = (r, logger = defaultLogger) => {
|
|
|
785
965
|
if (r.functionCalls && r.functionCalls.length > 0) {
|
|
786
966
|
for (const [i, f2] of r.functionCalls.entries()) {
|
|
787
967
|
if (f2.function.name) {
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
968
|
+
const tags = ["functionName"];
|
|
969
|
+
if (i === 0) {
|
|
970
|
+
tags.push("firstFunction");
|
|
971
|
+
}
|
|
972
|
+
if (r.functionCalls.length > 1) {
|
|
973
|
+
tags.push("multipleFunctions");
|
|
974
|
+
}
|
|
975
|
+
logger(`[${i + 1}] ${f2.function.name}`, { tags });
|
|
791
976
|
}
|
|
792
977
|
if (f2.function.params) {
|
|
793
978
|
const params = typeof f2.function.params === "string" ? f2.function.params : JSON.stringify(f2.function.params, null, 2);
|
|
@@ -1458,6 +1643,16 @@ function validateAxMessageArray(values) {
|
|
|
1458
1643
|
function validateChatPrompt(chatPrompt) {
|
|
1459
1644
|
for (let i = 0; i < chatPrompt.length; i++) {
|
|
1460
1645
|
const message = chatPrompt[i];
|
|
1646
|
+
if (message && "functionCalls" in message && Array.isArray(message.functionCalls) && message.functionCalls.length === 0) {
|
|
1647
|
+
throw new Error(
|
|
1648
|
+
`Chat prompt validation failed: Message at index ${i} has empty functionCalls`
|
|
1649
|
+
);
|
|
1650
|
+
}
|
|
1651
|
+
if (message && "content" in message && Array.isArray(message.content) && message.content.length === 0) {
|
|
1652
|
+
throw new Error(
|
|
1653
|
+
`Chat prompt validation failed: Message at index ${i} has empty content`
|
|
1654
|
+
);
|
|
1655
|
+
}
|
|
1461
1656
|
if (message && "content" in message && typeof message.content === "string" && message.content.trim() === "") {
|
|
1462
1657
|
throw new Error(
|
|
1463
1658
|
`Chat prompt validation failed: Message at index ${i} has empty content`
|
|
@@ -5498,7 +5693,7 @@ var AxAIGrok = class extends AxAIOpenAIBase {
|
|
|
5498
5693
|
};
|
|
5499
5694
|
|
|
5500
5695
|
// dsp/generate.ts
|
|
5501
|
-
import { ReadableStream as
|
|
5696
|
+
import { ReadableStream as ReadableStream2 } from "stream/web";
|
|
5502
5697
|
import {
|
|
5503
5698
|
context as context2,
|
|
5504
5699
|
SpanKind as SpanKind2,
|
|
@@ -5785,9 +5980,11 @@ var updateProgressBar = (current, total, success, elapsedTime, msg, progressBarW
|
|
|
5785
5980
|
const emptyBarLength = progressBarWidth - filledBarLength;
|
|
5786
5981
|
const filledBar = colorLog3.blueBright("\u2588".repeat(filledBarLength));
|
|
5787
5982
|
const emptyBar = " ".repeat(emptyBarLength);
|
|
5788
|
-
const
|
|
5983
|
+
const successRate = total > 0 ? (success / total * 100).toFixed(1) : "0.0";
|
|
5984
|
+
const friendlyMsg = msg.includes("Running MIPROv2 optimization") ? "Testing prompt variations" : msg.includes("Tuning Prompt") ? "Generating training examples" : msg;
|
|
5789
5985
|
process.stdout.write(
|
|
5790
|
-
`\
|
|
5986
|
+
`\u2502 ${friendlyMsg}: ${current}/${total} (${colorLog3.yellow(percentage)}%) |${filledBar}${emptyBar}| Success rate: ${colorLog3.greenBright(successRate)}%
|
|
5987
|
+
`
|
|
5791
5988
|
);
|
|
5792
5989
|
};
|
|
5793
5990
|
var validateValue = (field, value) => {
|
|
@@ -5994,19 +6191,15 @@ function matchesContent(content, prefix, startIndex = 0, prefixCache = globalPre
|
|
|
5994
6191
|
if (!prefixCache.get(prefix)) {
|
|
5995
6192
|
prefixCache.set(prefix, prefixes);
|
|
5996
6193
|
}
|
|
5997
|
-
|
|
5998
|
-
|
|
5999
|
-
);
|
|
6000
|
-
for (let i = 0; i < prefixes.length - 1; i++) {
|
|
6194
|
+
let longestPartialMatch = -1;
|
|
6195
|
+
for (let i = prefixes.length - 1; i >= 0; i--) {
|
|
6001
6196
|
const partialPrefix = prefixes[i];
|
|
6002
|
-
if (partialPrefix
|
|
6003
|
-
|
|
6004
|
-
|
|
6005
|
-
if (partialPrefix && contentEnd.endsWith(partialPrefix)) {
|
|
6006
|
-
return -2;
|
|
6197
|
+
if (content.endsWith(partialPrefix)) {
|
|
6198
|
+
longestPartialMatch = i;
|
|
6199
|
+
break;
|
|
6007
6200
|
}
|
|
6008
6201
|
}
|
|
6009
|
-
return -1;
|
|
6202
|
+
return longestPartialMatch >= 0 ? -2 : -1;
|
|
6010
6203
|
}
|
|
6011
6204
|
var formatTime = (ms) => {
|
|
6012
6205
|
const seconds = Math.floor(ms / 1e3);
|
|
@@ -6029,11 +6222,10 @@ var updateDetailedProgress = (roundIndex, current, total, elapsedTime, example,
|
|
|
6029
6222
|
process.stdout.write("\r\x1B[K");
|
|
6030
6223
|
const percentage = (current / total * 100).toFixed(1);
|
|
6031
6224
|
const formattedTime = formatTime(elapsedTime);
|
|
6032
|
-
const itemsPerSecond = elapsedTime > 0 ? (current / elapsedTime * 1e3).toFixed(2) : "0.00";
|
|
6033
6225
|
const eta = calculateETA(current, total, elapsedTime);
|
|
6034
|
-
let output = `
|
|
6226
|
+
let output = `Training round ${roundIndex + 1}/${configInfo.maxRounds}: ${current}/${total} (${percentage}%) [${formattedTime}, ETA: ${eta}]`;
|
|
6035
6227
|
const successRate = stats.totalCalls > 0 ? stats.successfulDemos / stats.totalCalls * 100 : 0;
|
|
6036
|
-
output += ` | Success: ${stats.successfulDemos}/${stats.totalCalls}
|
|
6228
|
+
output += ` | Success rate: ${successRate.toFixed(1)}% (${stats.successfulDemos}/${stats.totalCalls})`;
|
|
6037
6229
|
if (configInfo.verboseMode || configInfo.debugMode) {
|
|
6038
6230
|
if (configInfo.costMonitoring) {
|
|
6039
6231
|
output += `
|
|
@@ -6159,7 +6351,7 @@ ${outputFields}`);
|
|
|
6159
6351
|
content: systemContent
|
|
6160
6352
|
};
|
|
6161
6353
|
if (Array.isArray(values)) {
|
|
6162
|
-
let
|
|
6354
|
+
let messages = [];
|
|
6163
6355
|
const history = values;
|
|
6164
6356
|
for (const [index, message] of history.entries()) {
|
|
6165
6357
|
let content;
|
|
@@ -6179,7 +6371,7 @@ ${outputFields}`);
|
|
|
6179
6371
|
);
|
|
6180
6372
|
}
|
|
6181
6373
|
if (message.role === "user") {
|
|
6182
|
-
|
|
6374
|
+
messages.push({ role: "user", content });
|
|
6183
6375
|
continue;
|
|
6184
6376
|
}
|
|
6185
6377
|
if (message.role !== "assistant") {
|
|
@@ -6190,9 +6382,9 @@ ${outputFields}`);
|
|
|
6190
6382
|
"Assistant message cannot contain non-text content like images, files,etc"
|
|
6191
6383
|
);
|
|
6192
6384
|
}
|
|
6193
|
-
|
|
6385
|
+
messages.push({ role: "assistant", content });
|
|
6194
6386
|
}
|
|
6195
|
-
return [systemPrompt, ...
|
|
6387
|
+
return [systemPrompt, ...messages];
|
|
6196
6388
|
}
|
|
6197
6389
|
const userContent = this.renderSingleValueUserContent(
|
|
6198
6390
|
values,
|
|
@@ -6645,9 +6837,9 @@ var formatDateWithTimezone = (date) => {
|
|
|
6645
6837
|
};
|
|
6646
6838
|
|
|
6647
6839
|
// dsp/extract.ts
|
|
6648
|
-
var extractValues = (sig, values, content) => {
|
|
6840
|
+
var extractValues = (sig, values, content, strictMode = false) => {
|
|
6649
6841
|
const xstate = { extractedFields: [], streamedIndex: {}, s: -1 };
|
|
6650
|
-
streamingExtractValues(sig, values, xstate, content);
|
|
6842
|
+
streamingExtractValues(sig, values, xstate, content, strictMode);
|
|
6651
6843
|
streamingExtractFinalValue(sig, values, xstate, content);
|
|
6652
6844
|
for (const field of sig.getOutputFields()) {
|
|
6653
6845
|
if (field.isInternal) {
|
|
@@ -6655,10 +6847,9 @@ var extractValues = (sig, values, content) => {
|
|
|
6655
6847
|
}
|
|
6656
6848
|
}
|
|
6657
6849
|
};
|
|
6658
|
-
var checkMissingRequiredFields = (xstate, values,
|
|
6850
|
+
var checkMissingRequiredFields = (xstate, values, outputFields) => {
|
|
6659
6851
|
const missingFields = [];
|
|
6660
|
-
for (
|
|
6661
|
-
const field = xstate.extractedFields[i];
|
|
6852
|
+
for (const field of outputFields) {
|
|
6662
6853
|
if (field && !field.isOptional && values[field.name] === void 0) {
|
|
6663
6854
|
missingFields.push(field);
|
|
6664
6855
|
}
|
|
@@ -6670,23 +6861,34 @@ var checkMissingRequiredFields = (xstate, values, currentIndex) => {
|
|
|
6670
6861
|
});
|
|
6671
6862
|
}
|
|
6672
6863
|
};
|
|
6673
|
-
var streamingExtractValues = (sig, values, xstate, content,
|
|
6864
|
+
var streamingExtractValues = (sig, values, xstate, content, strictMode = false) => {
|
|
6674
6865
|
const fields = sig.getOutputFields();
|
|
6866
|
+
let expectedField;
|
|
6675
6867
|
for (const [index, field] of fields.entries()) {
|
|
6868
|
+
if (index === xstate.currFieldIndex) {
|
|
6869
|
+
continue;
|
|
6870
|
+
}
|
|
6676
6871
|
if (field.name in values) {
|
|
6677
6872
|
continue;
|
|
6678
6873
|
}
|
|
6679
6874
|
const isFirst = xstate.extractedFields.length === 0;
|
|
6680
6875
|
const prefix = (isFirst ? "" : "\n") + field.title + ":";
|
|
6681
6876
|
let e = matchesContent(content, prefix, xstate.s);
|
|
6877
|
+
let prefixLen = prefix.length;
|
|
6682
6878
|
switch (e) {
|
|
6683
6879
|
case -1:
|
|
6684
|
-
if (
|
|
6880
|
+
if (!strictMode && fields.length === 1 && xstate.currField === void 0) {
|
|
6881
|
+
prefixLen = 0;
|
|
6882
|
+
e = 0;
|
|
6883
|
+
break;
|
|
6884
|
+
}
|
|
6885
|
+
if (xstate.currField === void 0 && !field.isOptional) {
|
|
6685
6886
|
throw new ValidationError({
|
|
6686
|
-
message: "Required field not found",
|
|
6887
|
+
message: "Expected (Required) field not found",
|
|
6687
6888
|
fields: [field]
|
|
6688
6889
|
});
|
|
6689
6890
|
}
|
|
6891
|
+
expectedField = field.isOptional ? void 0 : field;
|
|
6690
6892
|
continue;
|
|
6691
6893
|
// Field is not found, continue to the next field
|
|
6692
6894
|
case -2:
|
|
@@ -6699,7 +6901,12 @@ var streamingExtractValues = (sig, values, xstate, content, streamingValidation
|
|
|
6699
6901
|
xstate.inBlock = true;
|
|
6700
6902
|
return true;
|
|
6701
6903
|
}
|
|
6702
|
-
|
|
6904
|
+
if (expectedField && expectedField.name !== field.name) {
|
|
6905
|
+
throw new ValidationError({
|
|
6906
|
+
message: "Expected (Required) field not found",
|
|
6907
|
+
fields: [expectedField]
|
|
6908
|
+
});
|
|
6909
|
+
}
|
|
6703
6910
|
if (xstate.currField) {
|
|
6704
6911
|
const val = content.substring(xstate.s, e).trim();
|
|
6705
6912
|
const parsedValue = validateAndParseFieldValue(xstate.currField, val);
|
|
@@ -6712,7 +6919,6 @@ var streamingExtractValues = (sig, values, xstate, content, streamingValidation
|
|
|
6712
6919
|
xstate.prevFields = [{ field: xstate.currField, s: xstate.s, e }];
|
|
6713
6920
|
}
|
|
6714
6921
|
}
|
|
6715
|
-
checkMissingRequiredFields(xstate, values, index);
|
|
6716
6922
|
xstate.s = e + prefixLen;
|
|
6717
6923
|
xstate.currField = field;
|
|
6718
6924
|
xstate.currFieldIndex = index;
|
|
@@ -6732,8 +6938,7 @@ var streamingExtractFinalValue = (sig, values, xstate, content) => {
|
|
|
6732
6938
|
values[xstate.currField.name] = parsedValue;
|
|
6733
6939
|
}
|
|
6734
6940
|
}
|
|
6735
|
-
|
|
6736
|
-
checkMissingRequiredFields(xstate, values, sigFields.length);
|
|
6941
|
+
checkMissingRequiredFields(xstate, values, sig.getOutputFields());
|
|
6737
6942
|
};
|
|
6738
6943
|
var convertValueToType = (field, val, required = false) => {
|
|
6739
6944
|
switch (field.type?.name) {
|
|
@@ -7370,8 +7575,9 @@ var AxInstanceRegistry = class {
|
|
|
7370
7575
|
this.reg.add(instance);
|
|
7371
7576
|
}
|
|
7372
7577
|
*[Symbol.iterator]() {
|
|
7373
|
-
|
|
7374
|
-
|
|
7578
|
+
const items = Array.from(this.reg);
|
|
7579
|
+
for (let i = 0; i < items.length; i++) {
|
|
7580
|
+
yield items[i];
|
|
7375
7581
|
}
|
|
7376
7582
|
}
|
|
7377
7583
|
};
|
|
@@ -8309,7 +8515,7 @@ var AxSignature = class _AxSignature {
|
|
|
8309
8515
|
this.getOutputFields().forEach((field) => {
|
|
8310
8516
|
validateField(field, "output");
|
|
8311
8517
|
});
|
|
8312
|
-
this.sigHash = createHash("sha256").update(
|
|
8518
|
+
this.sigHash = createHash("sha256").update(JSON.stringify(this.inputFields)).update(JSON.stringify(this.outputFields)).digest("hex");
|
|
8313
8519
|
this.sigString = renderSignature(
|
|
8314
8520
|
this.description,
|
|
8315
8521
|
this.inputFields,
|
|
@@ -8630,7 +8836,7 @@ var AxProgramWithSignature = class {
|
|
|
8630
8836
|
this.signature.validate();
|
|
8631
8837
|
this.sigHash = this.signature?.hash();
|
|
8632
8838
|
this.children = new AxInstanceRegistry();
|
|
8633
|
-
this.key = { id: this.
|
|
8839
|
+
this.key = { id: this.signature.hash() };
|
|
8634
8840
|
}
|
|
8635
8841
|
getSignature() {
|
|
8636
8842
|
return this.signature;
|
|
@@ -8650,8 +8856,8 @@ var AxProgramWithSignature = class {
|
|
|
8650
8856
|
}
|
|
8651
8857
|
setId(id) {
|
|
8652
8858
|
this.key = { id, custom: true };
|
|
8653
|
-
for (const child of this.children) {
|
|
8654
|
-
child
|
|
8859
|
+
for (const child of Array.from(this.children)) {
|
|
8860
|
+
child?.setParentId(id);
|
|
8655
8861
|
}
|
|
8656
8862
|
}
|
|
8657
8863
|
setParentId(parentId) {
|
|
@@ -8664,8 +8870,8 @@ var AxProgramWithSignature = class {
|
|
|
8664
8870
|
if (!("programId" in examples)) {
|
|
8665
8871
|
return;
|
|
8666
8872
|
}
|
|
8667
|
-
for (const child of this.children) {
|
|
8668
|
-
child
|
|
8873
|
+
for (const child of Array.from(this.children)) {
|
|
8874
|
+
child?.setExamples(examples, options);
|
|
8669
8875
|
}
|
|
8670
8876
|
}
|
|
8671
8877
|
_setExamples(examples, options) {
|
|
@@ -8698,30 +8904,37 @@ var AxProgramWithSignature = class {
|
|
|
8698
8904
|
if (this.trace) {
|
|
8699
8905
|
traces.push({ trace: this.trace, programId: this.key.id });
|
|
8700
8906
|
}
|
|
8701
|
-
for (const child of this.children) {
|
|
8702
|
-
const _traces = child
|
|
8703
|
-
traces = [...traces, ..._traces];
|
|
8907
|
+
for (const child of Array.from(this.children)) {
|
|
8908
|
+
const _traces = child?.getTraces();
|
|
8909
|
+
traces = [...traces, ..._traces ?? []];
|
|
8704
8910
|
}
|
|
8705
8911
|
return traces;
|
|
8706
8912
|
}
|
|
8707
8913
|
getUsage() {
|
|
8708
8914
|
let usage = [...this.usage ?? []];
|
|
8709
|
-
for (const child of this.children) {
|
|
8710
|
-
const cu = child
|
|
8711
|
-
usage = [...usage, ...cu];
|
|
8915
|
+
for (const child of Array.from(this.children)) {
|
|
8916
|
+
const cu = child?.getUsage();
|
|
8917
|
+
usage = [...usage, ...cu ?? []];
|
|
8712
8918
|
}
|
|
8713
8919
|
return mergeProgramUsage(usage);
|
|
8714
8920
|
}
|
|
8715
8921
|
resetUsage() {
|
|
8716
8922
|
this.usage = [];
|
|
8717
|
-
for (const child of this.children) {
|
|
8718
|
-
child
|
|
8923
|
+
for (const child of Array.from(this.children)) {
|
|
8924
|
+
child?.resetUsage();
|
|
8719
8925
|
}
|
|
8720
8926
|
}
|
|
8721
8927
|
setDemos(demos) {
|
|
8928
|
+
const hasChildren = Array.from(this.children).length > 0;
|
|
8929
|
+
const hasMatchingDemo = demos.some((demo) => demo.programId === this.key.id);
|
|
8930
|
+
if (hasChildren && !hasMatchingDemo) {
|
|
8931
|
+
throw new Error(
|
|
8932
|
+
`Program with id '${this.key.id}' has children but no matching programId found in demos`
|
|
8933
|
+
);
|
|
8934
|
+
}
|
|
8722
8935
|
this.demos = demos.filter((v) => v.programId === this.key.id).map((v) => v.traces).flat();
|
|
8723
|
-
for (const child of this.children) {
|
|
8724
|
-
child
|
|
8936
|
+
for (const child of Array.from(this.children)) {
|
|
8937
|
+
child?.setDemos(demos);
|
|
8725
8938
|
}
|
|
8726
8939
|
}
|
|
8727
8940
|
};
|
|
@@ -8749,8 +8962,8 @@ var AxProgram = class {
|
|
|
8749
8962
|
}
|
|
8750
8963
|
setId(id) {
|
|
8751
8964
|
this.key = { id, custom: true };
|
|
8752
|
-
for (const child of this.children) {
|
|
8753
|
-
child
|
|
8965
|
+
for (const child of Array.from(this.children)) {
|
|
8966
|
+
child?.setParentId(id);
|
|
8754
8967
|
}
|
|
8755
8968
|
}
|
|
8756
8969
|
setParentId(parentId) {
|
|
@@ -8762,8 +8975,8 @@ var AxProgram = class {
|
|
|
8762
8975
|
if (!("programId" in examples)) {
|
|
8763
8976
|
return;
|
|
8764
8977
|
}
|
|
8765
|
-
for (const child of this.children) {
|
|
8766
|
-
child
|
|
8978
|
+
for (const child of Array.from(this.children)) {
|
|
8979
|
+
child?.setExamples(examples, options);
|
|
8767
8980
|
}
|
|
8768
8981
|
}
|
|
8769
8982
|
getTraces() {
|
|
@@ -8771,29 +8984,36 @@ var AxProgram = class {
|
|
|
8771
8984
|
if (this.trace) {
|
|
8772
8985
|
traces.push({ trace: this.trace, programId: this.key.id });
|
|
8773
8986
|
}
|
|
8774
|
-
for (const child of this.children) {
|
|
8775
|
-
const _traces = child
|
|
8776
|
-
traces = [...traces, ..._traces];
|
|
8987
|
+
for (const child of Array.from(this.children)) {
|
|
8988
|
+
const _traces = child?.getTraces();
|
|
8989
|
+
traces = [...traces, ..._traces ?? []];
|
|
8777
8990
|
}
|
|
8778
8991
|
return traces;
|
|
8779
8992
|
}
|
|
8780
8993
|
getUsage() {
|
|
8781
8994
|
let usage = [...this.usage ?? []];
|
|
8782
|
-
for (const child of this.children) {
|
|
8783
|
-
const cu = child
|
|
8784
|
-
usage = [...usage, ...cu];
|
|
8995
|
+
for (const child of Array.from(this.children)) {
|
|
8996
|
+
const cu = child?.getUsage();
|
|
8997
|
+
usage = [...usage, ...cu ?? []];
|
|
8785
8998
|
}
|
|
8786
8999
|
return mergeProgramUsage(usage);
|
|
8787
9000
|
}
|
|
8788
9001
|
resetUsage() {
|
|
8789
9002
|
this.usage = [];
|
|
8790
|
-
for (const child of this.children) {
|
|
8791
|
-
child
|
|
9003
|
+
for (const child of Array.from(this.children)) {
|
|
9004
|
+
child?.resetUsage();
|
|
8792
9005
|
}
|
|
8793
9006
|
}
|
|
8794
9007
|
setDemos(demos) {
|
|
8795
|
-
|
|
8796
|
-
|
|
9008
|
+
const hasChildren = Array.from(this.children).length > 0;
|
|
9009
|
+
const hasMatchingDemo = demos.some((demo) => demo.programId === this.key.id);
|
|
9010
|
+
if (hasChildren && !hasMatchingDemo) {
|
|
9011
|
+
throw new Error(
|
|
9012
|
+
`Program with id '${this.key.id}' has children but no matching programId found in demos`
|
|
9013
|
+
);
|
|
9014
|
+
}
|
|
9015
|
+
for (const child of Array.from(this.children)) {
|
|
9016
|
+
child?.setDemos(demos);
|
|
8797
9017
|
}
|
|
8798
9018
|
}
|
|
8799
9019
|
};
|
|
@@ -8921,7 +9141,7 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
8921
9141
|
traceContext
|
|
8922
9142
|
}) {
|
|
8923
9143
|
const { sessionId, traceId, functions: _functions } = options ?? {};
|
|
8924
|
-
const
|
|
9144
|
+
const strictMode = options?.strictMode ?? false;
|
|
8925
9145
|
const model = options.model;
|
|
8926
9146
|
const functions = _functions?.map((f2) => "toFunction" in f2 ? f2.toFunction() : f2)?.flat();
|
|
8927
9147
|
const res = await this.forwardSendRequest({
|
|
@@ -8931,7 +9151,7 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
8931
9151
|
traceContext,
|
|
8932
9152
|
firstStep
|
|
8933
9153
|
});
|
|
8934
|
-
if (res instanceof
|
|
9154
|
+
if (res instanceof ReadableStream2) {
|
|
8935
9155
|
yield* this.processStreamingResponse({
|
|
8936
9156
|
ai,
|
|
8937
9157
|
model,
|
|
@@ -8940,7 +9160,7 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
8940
9160
|
traceId,
|
|
8941
9161
|
sessionId,
|
|
8942
9162
|
functions,
|
|
8943
|
-
|
|
9163
|
+
strictMode,
|
|
8944
9164
|
span
|
|
8945
9165
|
});
|
|
8946
9166
|
this.getLogger(ai, options)?.("", { tags: ["responseEnd"] });
|
|
@@ -8953,7 +9173,8 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
8953
9173
|
traceId,
|
|
8954
9174
|
sessionId,
|
|
8955
9175
|
functions,
|
|
8956
|
-
span
|
|
9176
|
+
span,
|
|
9177
|
+
strictMode
|
|
8957
9178
|
});
|
|
8958
9179
|
}
|
|
8959
9180
|
}
|
|
@@ -8965,10 +9186,9 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
8965
9186
|
sessionId,
|
|
8966
9187
|
traceId,
|
|
8967
9188
|
functions,
|
|
8968
|
-
|
|
9189
|
+
strictMode,
|
|
8969
9190
|
span
|
|
8970
9191
|
}) {
|
|
8971
|
-
const streamingValidation = fastFail ?? ai.getFeatures(model).functionCot !== true;
|
|
8972
9192
|
const functionCalls = [];
|
|
8973
9193
|
this.values = {};
|
|
8974
9194
|
const xstate = {
|
|
@@ -9019,7 +9239,7 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
9019
9239
|
this.values,
|
|
9020
9240
|
xstate,
|
|
9021
9241
|
content,
|
|
9022
|
-
|
|
9242
|
+
strictMode
|
|
9023
9243
|
);
|
|
9024
9244
|
if (skip) {
|
|
9025
9245
|
continue;
|
|
@@ -9120,7 +9340,8 @@ Content: ${content}`
|
|
|
9120
9340
|
sessionId,
|
|
9121
9341
|
traceId,
|
|
9122
9342
|
functions,
|
|
9123
|
-
span
|
|
9343
|
+
span,
|
|
9344
|
+
strictMode
|
|
9124
9345
|
}) {
|
|
9125
9346
|
this.values = {};
|
|
9126
9347
|
let results = res.results ?? [];
|
|
@@ -9154,7 +9375,7 @@ Content: ${content}`
|
|
|
9154
9375
|
if (result.thought && result.thought.length > 0) {
|
|
9155
9376
|
this.values[this.thoughtFieldName] = result.thought;
|
|
9156
9377
|
}
|
|
9157
|
-
extractValues(this.signature, this.values, result.content);
|
|
9378
|
+
extractValues(this.signature, this.values, result.content, strictMode);
|
|
9158
9379
|
await assertAssertions(this.asserts, this.values);
|
|
9159
9380
|
if (this.fieldProcessors.length) {
|
|
9160
9381
|
await processFieldProcessors(
|
|
@@ -9326,8 +9547,7 @@ Content: ${result.content}`
|
|
|
9326
9547
|
...options?.thinkingTokenBudget ? { thinking_token_budget: options.thinkingTokenBudget } : {},
|
|
9327
9548
|
...options?.showThoughts ? { show_thoughts: options.showThoughts } : {},
|
|
9328
9549
|
...options?.maxSteps ? { max_steps: options.maxSteps } : {},
|
|
9329
|
-
...options?.maxRetries ? { max_retries: options.maxRetries } : {}
|
|
9330
|
-
...options?.fastFail ? { fast_fail: options.fastFail } : {}
|
|
9550
|
+
...options?.maxRetries ? { max_retries: options.maxRetries } : {}
|
|
9331
9551
|
};
|
|
9332
9552
|
const traceLabel = options.traceLabel ?? this.options?.traceLabel;
|
|
9333
9553
|
const spanName = traceLabel ? `${traceLabel} (AxGen)` : "AxGen";
|
|
@@ -9521,7 +9741,9 @@ var AxAgent = class {
|
|
|
9521
9741
|
description: definition ?? description
|
|
9522
9742
|
});
|
|
9523
9743
|
for (const agent of agents ?? []) {
|
|
9524
|
-
this.program.register(
|
|
9744
|
+
this.program.register(
|
|
9745
|
+
agent
|
|
9746
|
+
);
|
|
9525
9747
|
}
|
|
9526
9748
|
this.name = name;
|
|
9527
9749
|
this.func = {
|
|
@@ -10025,98 +10247,825 @@ function validateModels2(services) {
|
|
|
10025
10247
|
}
|
|
10026
10248
|
}
|
|
10027
10249
|
|
|
10028
|
-
//
|
|
10029
|
-
|
|
10030
|
-
|
|
10031
|
-
|
|
10032
|
-
|
|
10033
|
-
|
|
10034
|
-
|
|
10035
|
-
|
|
10036
|
-
|
|
10037
|
-
|
|
10038
|
-
|
|
10039
|
-
|
|
10040
|
-
tracer
|
|
10041
|
-
}) {
|
|
10042
|
-
this.name = name;
|
|
10043
|
-
this.fetch = fetch2;
|
|
10044
|
-
this.tracer = tracer;
|
|
10250
|
+
// dsp/optimizer.ts
|
|
10251
|
+
/**
 * In-memory token and cost accounting, keyed by model name.
 *
 * Tracks how many tokens each model has consumed and reports whether an
 * optional token budget (`maxTokens`) or cost budget (`maxCost`) has been
 * reached.
 */
var AxDefaultCostTracker = class {
  // Tokens consumed so far, keyed by model name.
  tokenUsage = {};
  // Sum of all tracked tokens across every model.
  totalTokens = 0;
  // Configuration options
  costPerModel;
  maxCost;
  maxTokens;
  /**
   * @param {object} [options]
   * @param {Record<string, number>} [options.costPerModel] Cost per 1,000 tokens, keyed by model name.
   * @param {number} [options.maxCost] Cost at or above which the limit counts as reached.
   * @param {number} [options.maxTokens] Token count at or above which the limit counts as reached.
   */
  constructor(options) {
    this.costPerModel = options?.costPerModel ?? {};
    this.maxCost = options?.maxCost;
    this.maxTokens = options?.maxTokens;
  }
  /**
   * Record `count` tokens against `model`.
   * @param {number} count Number of tokens to add.
   * @param {string} model Model name the tokens are attributed to.
   */
  trackTokens(count, model) {
    this.tokenUsage[model] = (this.tokenUsage[model] ?? 0) + count;
    this.totalTokens += count;
  }
  /**
   * Compute the accumulated cost over all tracked models.
   * @returns {number} Sum of `tokens / 1000 * ratePer1K` for every model.
   */
  getCurrentCost() {
    // Rate applied to models that have no entry in `costPerModel`.
    const fallbackCostPer1K = 1e-3;
    let totalCost = 0;
    for (const [model, tokens] of Object.entries(this.tokenUsage)) {
      // `??` rather than `||`: an explicitly configured rate of 0 (a free
      // model) must stay 0 instead of being billed at the fallback rate.
      const costPer1K = this.costPerModel[model] ?? fallbackCostPer1K;
      totalCost += tokens / 1e3 * costPer1K;
    }
    return totalCost;
  }
  /**
   * @returns {Record<string, number>} A shallow copy of the per-model token counts.
   */
  getTokenUsage() {
    return { ...this.tokenUsage };
  }
  /**
   * @returns {number} Total tokens tracked across all models.
   */
  getTotalTokens() {
    return this.totalTokens;
  }
  /**
   * @returns {boolean} True once the token total has reached `maxTokens` or
   * the current cost has reached `maxCost`; unset limits are ignored.
   */
  isLimitReached() {
    if (this.maxTokens !== undefined && this.totalTokens >= this.maxTokens) {
      return true;
    }
    if (this.maxCost !== undefined && this.getCurrentCost() >= this.maxCost) {
      return true;
    }
    return false;
  }
  /**
   * Clear all tracked usage; the configured rates and limits are kept.
   */
  reset() {
    this.tokenUsage = {};
    this.totalTokens = 0;
  }
};
|
|
10298
|
+
/**
 * Shared base for optimizers: holds the common configuration, statistics,
 * cost tracking, checkpointing and logging helpers, plus a default
 * multi-objective (`compilePareto`) implementation layered on top of
 * `this.compile`.
 *
 * `compile(program, metricFn, options)` is called by this class but not
 * defined here — presumably concrete optimizers provide it.
 */
var AxBaseOptimizer = class {
  // Common AxOptimizerArgs fields
  studentAI;
  teacherAI;
  examples;
  validationSet;
  targetScore;
  minSuccessRate;
  onProgress;
  onEarlyStop;
  costTracker;
  seed;
  // Checkpointing fields
  checkpointSave;
  checkpointLoad;
  checkpointInterval;
  resumeFromCheckpoint;
  // Logging fields
  logger;
  verbose;
  // Checkpoint state
  currentRound = 0;
  scoreHistory = [];
  configurationHistory = [];
  // Common optimization statistics
  stats;
  /**
   * @param {object} args Optimizer arguments; every field is copied onto the instance.
   * @throws {Error} If `args.examples` is empty.
   */
  constructor(args) {
    if (args.examples.length === 0) {
      throw new Error("No examples found");
    }
    this.studentAI = args.studentAI;
    this.teacherAI = args.teacherAI;
    this.examples = args.examples;
    this.validationSet = args.validationSet;
    this.targetScore = args.targetScore;
    this.minSuccessRate = args.minSuccessRate;
    this.onProgress = args.onProgress;
    this.onEarlyStop = args.onEarlyStop;
    this.seed = args.seed;
    this.checkpointSave = args.checkpointSave;
    this.checkpointLoad = args.checkpointLoad;
    this.checkpointInterval = args.checkpointInterval ?? 10;
    this.resumeFromCheckpoint = args.resumeFromCheckpoint;
    this.logger = args.logger;
    this.verbose = args.verbose;
    // Build the default tracker (1,000,000-token budget) only when the caller
    // did not supply one, instead of always allocating and then discarding it.
    this.costTracker = args.costTracker ?? new AxDefaultCostTracker({ maxTokens: 1e6 });
    this.stats = this.initializeStats();
  }
  /**
   * Initialize the optimization statistics structure
   * @returns {object} A fresh statistics object with all counters zeroed.
   */
  initializeStats() {
    return {
      totalCalls: 0,
      successfulDemos: 0,
      estimatedTokenUsage: 0,
      earlyStopped: false,
      resourceUsage: {
        totalTokens: 0,
        totalTime: 0,
        avgLatencyPerEval: 0,
        costByModel: {}
      },
      convergenceInfo: {
        converged: false,
        finalImprovement: 0,
        stagnationRounds: 0,
        convergenceThreshold: 0.01
      }
    };
  }
  /**
   * Set up reproducible random seed if provided
   *
   * WARNING: this replaces the global `Math.random` with a small linear
   * congruential generator seeded from `this.seed`. Nothing in this class
   * restores the original function afterwards, so the override affects all
   * code sharing the same global.
   */
  setupRandomSeed() {
    if (this.seed !== undefined) {
      Math.random = (() => {
        let seed = this.seed;
        return () => {
          seed = (seed * 9301 + 49297) % 233280;
          return seed / 233280;
        };
      })();
    }
  }
  /**
   * Check if optimization should stop early due to cost limits
   * @returns {boolean} The cost tracker's verdict, or false when there is no tracker.
   */
  checkCostLimits() {
    return this.costTracker?.isLimitReached() ?? false;
  }
  /**
   * Check if target score has been reached
   * @param {number} currentScore Score to compare against `targetScore`.
   * @returns {boolean} False when no target score is configured.
   */
  checkTargetScore(currentScore) {
    return this.targetScore !== undefined && currentScore >= this.targetScore;
  }
  /**
   * Update resource usage statistics
   * @param {number} startTime Epoch milliseconds at which the run started.
   * @param {number} [tokensUsed=0] Tokens to add to the running total.
   */
  updateResourceUsage(startTime, tokensUsed = 0) {
    const usage = this.stats.resourceUsage;
    usage.totalTime = Date.now() - startTime;
    usage.totalTokens += tokensUsed;
    if (this.stats.totalCalls > 0) {
      usage.avgLatencyPerEval = usage.totalTime / this.stats.totalCalls;
    }
  }
  /**
   * Trigger early stopping with appropriate callbacks
   * @param {string} reason Human-readable reason; `patienceExhausted` is set
   *   when it contains the substring "improvement".
   * @param {number} bestScoreRound Round in which the best score was seen.
   */
  triggerEarlyStopping(reason, bestScoreRound) {
    this.stats.earlyStopped = true;
    this.stats.earlyStopping = {
      bestScoreRound,
      patienceExhausted: reason.includes("improvement"),
      reason
    };
    if (this.onEarlyStop) {
      this.onEarlyStop(reason, this.stats);
    }
  }
  /**
   * Get the validation set, with fallback to a split of examples
   *
   * Preference order: `options.overrideValidationSet`, then the configured
   * `validationSet`, then the first 20% (rounded down) of `examples`. With
   * fewer than five examples the fallback slice is empty.
   * @param {object} [options]
   * @returns {Array} The validation examples.
   */
  getValidationSet(options) {
    return options?.overrideValidationSet || this.validationSet || this.examples.slice(0, Math.floor(this.examples.length * 0.2));
  }
  /**
   * Get the AI service to use for a specific task, preferring teacher when available
   * @param preferTeacher Whether to prefer teacher AI over student AI
   * @param options Optional compile options that may override teacher AI
   * @returns The appropriate AI service to use
   */
  getAIService(preferTeacher = false, options) {
    if (preferTeacher && options?.overrideTeacherAI) {
      return options.overrideTeacherAI;
    }
    if (preferTeacher && this.teacherAI) {
      return this.teacherAI;
    }
    return this.studentAI;
  }
  /**
   * Check if teacher AI is available (including overrides)
   * @param options Optional compile options that may override teacher AI
   * @returns True if teacher AI is configured or overridden
   */
  hasTeacherAI(options) {
    return options?.overrideTeacherAI !== undefined || this.teacherAI !== undefined;
  }
  /**
   * Get teacher AI if available, otherwise return student AI
   * @param options Optional compile options that may override teacher AI
   * @returns Teacher AI if available, otherwise student AI
   */
  getTeacherOrStudentAI(options) {
    return options?.overrideTeacherAI || this.teacherAI || this.studentAI;
  }
  /**
   * Execute a task with teacher AI if available, otherwise use student AI
   * @param task Function that takes an AI service and returns a promise
   * @param preferTeacher Whether to prefer teacher AI (default: true)
   * @param options Optional compile options that may override teacher AI
   * @returns Result of the task execution
   */
  async executeWithTeacher(task, preferTeacher = true, options) {
    const ai = this.getAIService(preferTeacher, options);
    return await task(ai);
  }
  /**
   * Get current optimization statistics
   * @returns {object} A shallow copy of `stats`; nested objects are still shared.
   */
  getStats() {
    return { ...this.stats };
  }
  /**
   * Reset optimizer state for reuse with different programs
   */
  reset() {
    this.stats = this.initializeStats();
    this.costTracker?.reset();
    this.currentRound = 0;
    this.scoreHistory = [];
    this.configurationHistory = [];
  }
  /**
   * Basic program validation that can be extended by concrete optimizers
   * @param {object} program Program to check for a `forward` method.
   * @returns {{isValid: boolean, issues: string[], suggestions: string[]}}
   */
  validateProgram(program) {
    const issues = [];
    const suggestions = [];
    if (!("forward" in program) || typeof program.forward !== "function") {
      issues.push("Program must have a forward method");
    }
    if (this.examples.length < 2) {
      issues.push("Need at least 2 examples for optimization");
      suggestions.push("Provide more training examples");
    }
    const valSetSize = this.getValidationSet().length;
    if (valSetSize < 1) {
      issues.push("Validation set is empty");
      suggestions.push("Provide examples or a validation set");
    }
    return {
      isValid: issues.length === 0,
      issues,
      suggestions
    };
  }
  /**
   * Multi-objective optimization using Pareto frontier
   * Default implementation that leverages the single-objective compile method
   * @param program The program to optimize
   * @param metricFn Multi-objective metric function that returns multiple scores
   * @param options Optional configuration options
   * @returns Pareto optimization result with frontier of non-dominated solutions
   */
  async compilePareto(program, metricFn, options) {
    const startTime = Date.now();
    if (options?.verbose) {
      this.getLogger(options)?.(
        "Starting Pareto optimization using base implementation",
        { tags: ["discovery"] }
      );
      this.getLogger(options)?.(
        "This will run multiple single-objective optimizations",
        { tags: ["discovery"] }
      );
    }
    const solutions = await this.generateWeightedSolutions(
      program,
      metricFn,
      options
    );
    const constraintSolutions = await this.generateConstraintSolutions(
      program,
      metricFn,
      options
    );
    const allSolutions = [...solutions, ...constraintSolutions];
    if (options?.verbose) {
      this.getLogger(options)?.(
        `Generated ${allSolutions.length} candidate solutions`,
        { tags: ["discovery"] }
      );
    }
    const paretoFront = this.findParetoFrontier(allSolutions);
    const hypervolume = this.calculateHypervolume(paretoFront);
    if (options?.verbose) {
      this.getLogger(options)?.(
        `Found ${paretoFront.length} non-dominated solutions`,
        { tags: ["discovery"] }
      );
      this.getLogger(options)?.(
        `Hypervolume: ${hypervolume?.toFixed(4) || "N/A"}`,
        { tags: ["discovery"] }
      );
    }
    this.updateResourceUsage(startTime);
    this.stats.convergenceInfo.converged = true;
    // Flatten before taking the maximum so frontier entries with an empty
    // score map cannot turn the result into -Infinity.
    const frontierScores = paretoFront.flatMap((sol) => Object.values(sol.scores));
    const bestScore = frontierScores.length > 0 ? Math.max(...frontierScores) : 0;
    return {
      demos: paretoFront.length > 0 ? [...paretoFront[0].demos] : undefined,
      stats: this.stats,
      bestScore,
      paretoFront,
      hypervolume,
      paretoFrontSize: paretoFront.length,
      finalConfiguration: {
        paretoFrontSize: paretoFront.length,
        hypervolume,
        strategy: "weighted_combinations_and_constraints",
        numSolutions: allSolutions.length
      }
    };
  }
  /**
   * Generate solutions using different weighted combinations of objectives
   *
   * The objective names are discovered by running the program once on the
   * first example and reading the keys of the metric's result. Each weight
   * combination then drives one `this.compile` run; failed runs are logged
   * (when verbose) and skipped.
   */
  async generateWeightedSolutions(program, metricFn, options) {
    const solutions = [];
    const sampleExample = this.examples[0];
    const samplePrediction = await program.forward(
      this.studentAI,
      sampleExample
    );
    const sampleScores = await metricFn({
      prediction: samplePrediction,
      example: sampleExample
    });
    const objectives = Object.keys(sampleScores);
    if (options?.verbose) {
      this.getLogger(options)?.(
        `Detected objectives: ${objectives.join(", ")}`,
        { tags: ["discovery"] }
      );
    }
    const weightCombinations = this.generateWeightCombinations(objectives);
    for (const weights of weightCombinations) {
      if (options?.verbose) {
        this.getLogger(options)?.(
          `Optimizing with weights: ${JSON.stringify(weights)}`,
          { tags: ["discovery"] }
        );
      }
      // Collapse the multi-objective scores into one weighted sum.
      const weightedMetric = async ({ prediction, example }) => {
        const scores = await metricFn({ prediction, example });
        let weightedScore = 0;
        for (const [objective, score] of Object.entries(scores)) {
          weightedScore += score * (weights[objective] || 0);
        }
        return weightedScore;
      };
      try {
        const result = await this.compile(program, weightedMetric, {
          ...options,
          verbose: false
          // Suppress inner optimization logs
        });
        const scores = await this.evaluateWithMultiObjective(
          program,
          result,
          metricFn
        );
        solutions.push({
          scores,
          demos: result.demos,
          configuration: {
            ...result.finalConfiguration,
            weights,
            strategy: "weighted_combination"
          }
        });
      } catch (error) {
        if (options?.verbose) {
          this.getLogger(options)?.(
            `Failed optimization with weights ${JSON.stringify(weights)}: ${error}`,
            { tags: ["warning"] }
          );
        }
        continue;
      }
    }
    return solutions;
  }
  /**
   * Generate solutions using constraint-based optimization
   *
   * For each objective in turn, optimizes that objective while penalizing
   * every other objective whose score falls below 0.3 (penalty is twice the
   * shortfall). Failed runs are logged (when verbose) and skipped.
   */
  async generateConstraintSolutions(program, metricFn, options) {
    const solutions = [];
    const sampleExample = this.examples[0];
    const samplePrediction = await program.forward(
      this.studentAI,
      sampleExample
    );
    const sampleScores = await metricFn({
      prediction: samplePrediction,
      example: sampleExample
    });
    const objectives = Object.keys(sampleScores);
    for (const primaryObjective of objectives) {
      if (options?.verbose) {
        this.getLogger(options)?.(
          `Optimizing ${primaryObjective} with constraints on other objectives`,
          { tags: ["discovery"] }
        );
      }
      const constraintMetric = async ({ prediction, example }) => {
        const scores = await metricFn({ prediction, example });
        const primaryScore = scores[primaryObjective] || 0;
        let penalty = 0;
        for (const [objective, score] of Object.entries(scores)) {
          if (objective !== primaryObjective && score < 0.3) {
            penalty += (0.3 - score) * 2;
          }
        }
        return primaryScore - penalty;
      };
      try {
        const result = await this.compile(program, constraintMetric, {
          ...options,
          verbose: false
        });
        const scores = await this.evaluateWithMultiObjective(
          program,
          result,
          metricFn
        );
        solutions.push({
          scores,
          demos: result.demos,
          configuration: {
            ...result.finalConfiguration,
            primaryObjective,
            strategy: "constraint_based"
          }
        });
      } catch (error) {
        if (options?.verbose) {
          this.getLogger(options)?.(
            `Failed constraint optimization for ${primaryObjective}: ${error}`,
            { tags: ["warning"] }
          );
        }
        continue;
      }
    }
    return solutions;
  }
  /**
   * Generate different weight combinations for objectives
   *
   * Produces, in order: one single-objective combination per objective, one
   * equal-weights combination, and a few extra splits for exactly two or
   * exactly three objectives.
   * @param {string[]} objectives Objective names.
   * @returns {Array<Record<string, number>>} Weight maps keyed by objective.
   */
  generateWeightCombinations(objectives) {
    const combinations = [];
    for (const objective of objectives) {
      const weights = {};
      for (const obj of objectives) {
        weights[obj] = obj === objective ? 1 : 0;
      }
      combinations.push(weights);
    }
    const equalWeights = {};
    for (const objective of objectives) {
      equalWeights[objective] = 1 / objectives.length;
    }
    combinations.push(equalWeights);
    if (objectives.length === 2) {
      const [obj1, obj2] = objectives;
      // Integer tenths avoid floating-point drift from repeatedly adding 0.2
      // (which yields values like 0.30000000000000004). The 5/5 split is
      // omitted because it is identical to the equal-weights entry above and
      // every combination costs a full optimization run.
      for (const tenths of [1, 3, 7, 9]) {
        combinations.push({ [obj1]: tenths / 10, [obj2]: (10 - tenths) / 10 });
      }
    }
    if (objectives.length === 3) {
      const [obj1, obj2, obj3] = objectives;
      combinations.push(
        { [obj1]: 0.5, [obj2]: 0.3, [obj3]: 0.2 },
        { [obj1]: 0.3, [obj2]: 0.5, [obj3]: 0.2 },
        { [obj1]: 0.2, [obj2]: 0.3, [obj3]: 0.5 }
      );
    }
    return combinations;
  }
  /**
   * Evaluate a single-objective result with multi-objective metrics
   *
   * Runs the program (with `result.demos` applied to a copy) on at most the
   * first five validation examples and averages each objective's scores.
   * Examples whose evaluation throws are skipped.
   * @returns {Promise<Record<string, number>>} Average score per objective.
   */
  async evaluateWithMultiObjective(program, result, metricFn) {
    const valSet = this.getValidationSet();
    const allScores = {};
    // Shallow copy that keeps the prototype. A plain `{ ...program }` copies
    // only own enumerable properties, so class-based programs would lose
    // `forward` / `setDemos` and every evaluation below would fail silently.
    const testProgram = Object.assign(
      Object.create(Object.getPrototypeOf(program)),
      program
    );
    if (result.demos && "setDemos" in testProgram) {
      testProgram.setDemos(result.demos);
    }
    const evalSet = valSet.slice(0, Math.min(5, valSet.length));
    for (const example of evalSet) {
      try {
        const prediction = await testProgram.forward(
          this.studentAI,
          example
        );
        const scores = await metricFn({ prediction, example });
        for (const [objective, score] of Object.entries(scores)) {
          if (!allScores[objective]) {
            allScores[objective] = [];
          }
          allScores[objective].push(score);
        }
      } catch {
        // Best-effort: an example that fails to evaluate contributes nothing.
        continue;
      }
    }
    const avgScores = {};
    for (const [objective, scores] of Object.entries(allScores)) {
      avgScores[objective] = scores.length > 0 ? scores.reduce((sum, score) => sum + score, 0) / scores.length : 0;
    }
    return avgScores;
  }
  /**
   * Find the Pareto frontier from a set of solutions
   *
   * A solution is kept when no other solution dominates it. For kept
   * solutions, `dominatedSolutions` counts how many others they dominate.
   */
  findParetoFrontier(solutions) {
    const paretoFront = [];
    for (let i = 0; i < solutions.length; i++) {
      const solutionA = solutions[i];
      let isDominated = false;
      let dominatedCount = 0;
      for (let j = 0; j < solutions.length; j++) {
        if (i === j) continue;
        const solutionB = solutions[j];
        if (this.dominates(solutionB.scores, solutionA.scores)) {
          isDominated = true;
          break;
        }
        if (this.dominates(solutionA.scores, solutionB.scores)) {
          dominatedCount++;
        }
      }
      if (!isDominated) {
        paretoFront.push({
          demos: solutionA.demos || [],
          scores: solutionA.scores,
          configuration: solutionA.configuration,
          dominatedSolutions: dominatedCount
        });
      }
    }
    return paretoFront;
  }
  /**
   * Check if solution A dominates solution B
   * A dominates B if A is better or equal in all objectives and strictly better in at least one
   *
   * Objectives are taken from the union of both score maps; a score missing
   * on either side counts as 0.
   */
  dominates(scoresA, scoresB) {
    // Union of keys: an objective reported only by B must still be able to
    // prevent A from dominating.
    const objectives = new Set([...Object.keys(scoresA), ...Object.keys(scoresB)]);
    let strictlyBetter = false;
    for (const objective of objectives) {
      const scoreA = scoresA[objective] || 0;
      const scoreB = scoresB[objective] || 0;
      if (scoreA < scoreB) {
        return false;
      }
      if (scoreA > scoreB) {
        strictlyBetter = true;
      }
    }
    return strictlyBetter;
  }
  /**
   * Calculate hypervolume of the Pareto frontier
   * Simplified implementation using reference point at origin
   *
   * Only the two-objective case is computed; any other objective count (or
   * an empty frontier) returns undefined.
   */
  calculateHypervolume(paretoFront) {
    if (paretoFront.length === 0) return undefined;
    const firstSolution = paretoFront[0];
    const objectives = Object.keys(firstSolution.scores);
    if (objectives.length === 2) {
      const [obj1, obj2] = objectives;
      let hypervolume = 0;
      // Sweep from the highest first-objective score downwards, adding the
      // strip gained in the second objective at each step.
      const sortedSolutions = [...paretoFront].sort(
        (a, b) => (b.scores[obj1] || 0) - (a.scores[obj1] || 0)
      );
      let prevScore2 = 0;
      for (const solution of sortedSolutions) {
        const score1 = solution.scores[obj1] || 0;
        const score2 = solution.scores[obj2] || 0;
        hypervolume += score1 * (score2 - prevScore2);
        prevScore2 = Math.max(prevScore2, score2);
      }
      return hypervolume;
    }
    return undefined;
  }
  /**
   * Save current optimization state to checkpoint
   * @returns The save function's result, or undefined when no save function
   *   is configured (neither `options.overrideCheckpointSave` nor `checkpointSave`).
   */
  async saveCheckpoint(optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
    const saveFn = options?.overrideCheckpointSave || this.checkpointSave;
    if (!saveFn) return undefined;
    const checkpoint = {
      version: "1.0.0",
      timestamp: Date.now(),
      optimizerType,
      optimizerConfig,
      currentRound: this.currentRound,
      totalRounds: this.stats.resourceUsage.totalTime > 0 ? this.currentRound : 0,
      bestScore,
      bestConfiguration,
      scoreHistory: [...this.scoreHistory],
      configurationHistory: [...this.configurationHistory],
      stats: { ...this.stats },
      optimizerState,
      examples: this.examples,
      validationSet: this.validationSet
    };
    return await saveFn(checkpoint);
  }
  /**
   * Load optimization state from checkpoint
   * @returns The loaded checkpoint, or null when no load function is configured.
   */
  async loadCheckpoint(checkpointId, options) {
    const loadFn = options?.overrideCheckpointLoad || this.checkpointLoad;
    if (!loadFn) return null;
    return await loadFn(checkpointId);
  }
  /**
   * Restore optimizer state from checkpoint
   */
  restoreFromCheckpoint(checkpoint) {
    this.currentRound = checkpoint.currentRound;
    this.scoreHistory = [...checkpoint.scoreHistory];
    this.configurationHistory = [...checkpoint.configurationHistory];
    this.stats = { ...checkpoint.stats };
  }
  /**
   * Check if checkpoint should be saved
   * @returns {boolean} True when `round` is a multiple of the effective interval.
   */
  shouldSaveCheckpoint(round, options) {
    const interval = options?.overrideCheckpointInterval || this.checkpointInterval;
    return interval !== undefined && round % interval === 0;
  }
  /**
   * Update optimization progress and handle checkpointing
   */
  async updateOptimizationProgress(round, score, configuration, optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
    this.currentRound = round;
    this.scoreHistory.push(score);
    this.configurationHistory.push(configuration);
    if (this.shouldSaveCheckpoint(round, options)) {
      await this.saveCheckpoint(
        optimizerType,
        optimizerConfig,
        bestScore,
        bestConfiguration,
        optimizerState,
        options
      );
    }
  }
  /**
   * Save final checkpoint on completion
   *
   * Skipped only when `options.saveCheckpointOnComplete` is explicitly false.
   * The saved optimizer state is tagged with `final: true`.
   */
  async saveFinalCheckpoint(optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
    if (options?.saveCheckpointOnComplete !== false) {
      await this.saveCheckpoint(
        optimizerType,
        optimizerConfig,
        bestScore,
        bestConfiguration,
        { ...optimizerState, final: true },
        options
      );
    }
  }
  /**
   * Get the logger function, or undefined when logging is disabled.
   * When logging is enabled the fallback hierarchy is:
   * 1. Explicit logger passed to optimizer
   * 2. Logger from student AI service
   * 3. Default optimizer logger
   */
  getLogger(options) {
    if (!this.isLoggingEnabled(options)) {
      return undefined;
    }
    if (this.logger) {
      return this.logger;
    }
    try {
      const aiLogger = this.studentAI.getLogger();
      if (aiLogger) {
        return aiLogger;
      }
    } catch {
      // Deliberately ignored: fall through to the default logger when the
      // student AI has no usable getLogger().
    }
    return axDefaultOptimizerLogger;
  }
  /**
   * Check if logging is enabled based on verbose settings
   *
   * `options.verbose` wins when defined; otherwise the instance's `verbose`
   * is used, defaulting to true when that is unset too.
   */
  isLoggingEnabled(options) {
    if (options?.verbose !== undefined) {
      return options.verbose;
    }
    return this.verbose ?? true;
  }
};
|
|
10976
|
+
|
|
10977
|
+
// db/base.ts
|
|
10978
|
+
import { SpanKind as SpanKind3 } from "@opentelemetry/api";
|
|
10979
|
+
var AxDBBase = class {
|
|
10980
|
+
name;
|
|
10981
|
+
fetch;
|
|
10982
|
+
tracer;
|
|
10983
|
+
_upsert;
|
|
10984
|
+
_batchUpsert;
|
|
10985
|
+
_query;
|
|
10986
|
+
constructor({
|
|
10987
|
+
name,
|
|
10988
|
+
fetch: fetch2,
|
|
10989
|
+
tracer
|
|
10990
|
+
}) {
|
|
10991
|
+
this.name = name;
|
|
10992
|
+
this.fetch = fetch2;
|
|
10993
|
+
this.tracer = tracer;
|
|
10994
|
+
}
|
|
10995
|
+
async upsert(req, update) {
|
|
10996
|
+
if (!this._upsert) {
|
|
10997
|
+
throw new Error("upsert() not implemented");
|
|
10998
|
+
}
|
|
10999
|
+
if (!this.tracer) {
|
|
11000
|
+
return await this._upsert(req, update);
|
|
11001
|
+
}
|
|
11002
|
+
return await this.tracer.startActiveSpan(
|
|
11003
|
+
"DB Upsert Request",
|
|
11004
|
+
{
|
|
11005
|
+
kind: SpanKind3.SERVER,
|
|
11006
|
+
attributes: {
|
|
11007
|
+
[axSpanAttributes.DB_SYSTEM]: this.name,
|
|
11008
|
+
[axSpanAttributes.DB_OPERATION_NAME]: "upsert",
|
|
11009
|
+
[axSpanAttributes.DB_TABLE]: req.table,
|
|
11010
|
+
[axSpanAttributes.DB_NAMESPACE]: req.namespace,
|
|
11011
|
+
[axSpanAttributes.DB_OPERATION_NAME]: update ? "update" : "insert"
|
|
11012
|
+
}
|
|
11013
|
+
},
|
|
11014
|
+
async (span) => {
|
|
11015
|
+
try {
|
|
11016
|
+
return await this._upsert(req, update, { span });
|
|
11017
|
+
} finally {
|
|
11018
|
+
span.end();
|
|
11019
|
+
}
|
|
11020
|
+
}
|
|
11021
|
+
);
|
|
11022
|
+
}
|
|
11023
|
+
async batchUpsert(req, update) {
|
|
11024
|
+
if (!this._batchUpsert) {
|
|
11025
|
+
throw new Error("batchUpsert() not implemented");
|
|
11026
|
+
}
|
|
11027
|
+
if (req.length == 0) {
|
|
11028
|
+
throw new Error("Batch request is empty");
|
|
11029
|
+
}
|
|
11030
|
+
if (!req[0]) {
|
|
11031
|
+
throw new Error("Batch request is invalid first element is undefined");
|
|
11032
|
+
}
|
|
11033
|
+
if (!this.tracer) {
|
|
11034
|
+
return await this._batchUpsert(req, update);
|
|
11035
|
+
}
|
|
11036
|
+
return await this.tracer.startActiveSpan(
|
|
11037
|
+
"DB Batch Upsert Request",
|
|
11038
|
+
{
|
|
11039
|
+
kind: SpanKind3.SERVER,
|
|
11040
|
+
attributes: {
|
|
11041
|
+
[axSpanAttributes.DB_SYSTEM]: this.name,
|
|
11042
|
+
[axSpanAttributes.DB_OPERATION_NAME]: "upsert",
|
|
11043
|
+
[axSpanAttributes.DB_TABLE]: req[0].table,
|
|
11044
|
+
[axSpanAttributes.DB_NAMESPACE]: req[0].namespace,
|
|
11045
|
+
[axSpanAttributes.DB_OPERATION_NAME]: update ? "update" : "insert"
|
|
11046
|
+
}
|
|
11047
|
+
},
|
|
11048
|
+
async (span) => {
|
|
11049
|
+
try {
|
|
11050
|
+
return await this._batchUpsert(req, update, { span });
|
|
11051
|
+
} finally {
|
|
11052
|
+
span.end();
|
|
11053
|
+
}
|
|
11054
|
+
}
|
|
11055
|
+
);
|
|
11056
|
+
}
|
|
11057
|
+
async query(req) {
|
|
11058
|
+
if (!this._query) {
|
|
11059
|
+
throw new Error("query() not implemented");
|
|
11060
|
+
}
|
|
11061
|
+
if (!this.tracer) {
|
|
11062
|
+
return await this._query(req);
|
|
11063
|
+
}
|
|
11064
|
+
return await this.tracer.startActiveSpan(
|
|
11065
|
+
"DB Query Request",
|
|
11066
|
+
{
|
|
11067
|
+
kind: SpanKind3.SERVER,
|
|
11068
|
+
attributes: {
|
|
10120
11069
|
[axSpanAttributes.DB_SYSTEM]: this.name,
|
|
10121
11070
|
[axSpanAttributes.DB_OPERATION_NAME]: "upsert",
|
|
10122
11071
|
[axSpanAttributes.DB_TABLE]: req.table,
|
|
@@ -11484,52 +12433,31 @@ var AxMCPStreambleHTTPTransport = class {
|
|
|
11484
12433
|
};
|
|
11485
12434
|
|
|
11486
12435
|
// dsp/optimizers/bootstrapFewshot.ts
|
|
11487
|
-
var AxBootstrapFewShot = class {
|
|
11488
|
-
|
|
11489
|
-
|
|
11490
|
-
|
|
11491
|
-
|
|
11492
|
-
|
|
11493
|
-
|
|
11494
|
-
|
|
11495
|
-
|
|
11496
|
-
|
|
11497
|
-
|
|
11498
|
-
|
|
11499
|
-
|
|
11500
|
-
|
|
11501
|
-
|
|
11502
|
-
|
|
11503
|
-
|
|
11504
|
-
|
|
11505
|
-
|
|
11506
|
-
|
|
11507
|
-
|
|
11508
|
-
|
|
11509
|
-
|
|
11510
|
-
|
|
11511
|
-
|
|
11512
|
-
options
|
|
11513
|
-
}) {
|
|
11514
|
-
if (examples.length === 0) {
|
|
11515
|
-
throw new Error("No examples found");
|
|
11516
|
-
}
|
|
11517
|
-
const bootstrapOptions = options;
|
|
11518
|
-
this.maxRounds = bootstrapOptions?.maxRounds ?? 3;
|
|
11519
|
-
this.maxDemos = bootstrapOptions?.maxDemos ?? 4;
|
|
11520
|
-
this.maxExamples = bootstrapOptions?.maxExamples ?? 16;
|
|
11521
|
-
this.batchSize = bootstrapOptions?.batchSize ?? 1;
|
|
11522
|
-
this.earlyStoppingPatience = bootstrapOptions?.earlyStoppingPatience ?? 0;
|
|
11523
|
-
this.costMonitoring = bootstrapOptions?.costMonitoring ?? false;
|
|
11524
|
-
this.maxTokensPerGeneration = bootstrapOptions?.maxTokensPerGeneration ?? 0;
|
|
11525
|
-
this.verboseMode = bootstrapOptions?.verboseMode ?? true;
|
|
11526
|
-
this.debugMode = bootstrapOptions?.debugMode ?? false;
|
|
11527
|
-
this.ai = ai;
|
|
11528
|
-
this.teacherAI = bootstrapOptions?.teacherAI;
|
|
11529
|
-
this.program = program;
|
|
11530
|
-
this.examples = examples;
|
|
11531
|
-
}
|
|
11532
|
-
async compileRound(roundIndex, metricFn, options) {
|
|
12436
|
+
var AxBootstrapFewShot = class extends AxBaseOptimizer {
|
|
12437
|
+
maxRounds;
|
|
12438
|
+
maxDemos;
|
|
12439
|
+
maxExamples;
|
|
12440
|
+
batchSize;
|
|
12441
|
+
earlyStoppingPatience;
|
|
12442
|
+
costMonitoring;
|
|
12443
|
+
maxTokensPerGeneration;
|
|
12444
|
+
verboseMode;
|
|
12445
|
+
debugMode;
|
|
12446
|
+
traces = [];
|
|
12447
|
+
constructor(args) {
|
|
12448
|
+
super(args);
|
|
12449
|
+
const options = args.options || {};
|
|
12450
|
+
this.maxRounds = options.maxRounds ?? 3;
|
|
12451
|
+
this.maxDemos = options.maxDemos ?? 4;
|
|
12452
|
+
this.maxExamples = options.maxExamples ?? 16;
|
|
12453
|
+
this.batchSize = options.batchSize ?? 1;
|
|
12454
|
+
this.earlyStoppingPatience = options.earlyStoppingPatience ?? 0;
|
|
12455
|
+
this.costMonitoring = options.costMonitoring ?? false;
|
|
12456
|
+
this.maxTokensPerGeneration = options.maxTokensPerGeneration ?? 0;
|
|
12457
|
+
this.verboseMode = options.verboseMode ?? true;
|
|
12458
|
+
this.debugMode = options.debugMode ?? false;
|
|
12459
|
+
}
|
|
12460
|
+
async compileRound(program, roundIndex, metricFn, options) {
|
|
11533
12461
|
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
11534
12462
|
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
11535
12463
|
const aiOpt = {
|
|
@@ -11552,20 +12480,20 @@ var AxBootstrapFewShot = class {
|
|
|
11552
12480
|
continue;
|
|
11553
12481
|
}
|
|
11554
12482
|
const exList = examples.filter((e) => e !== ex);
|
|
11555
|
-
|
|
11556
|
-
const aiService = this.
|
|
12483
|
+
program.setExamples(exList);
|
|
12484
|
+
const aiService = this.getTeacherOrStudentAI();
|
|
11557
12485
|
this.stats.totalCalls++;
|
|
11558
12486
|
let res;
|
|
11559
12487
|
let error;
|
|
11560
12488
|
try {
|
|
11561
|
-
res = await
|
|
12489
|
+
res = await program.forward(aiService, ex, aiOpt);
|
|
11562
12490
|
if (this.costMonitoring) {
|
|
11563
12491
|
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
11564
12492
|
}
|
|
11565
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
12493
|
+
const score = await metricFn({ prediction: res, example: ex });
|
|
11566
12494
|
const success = score >= 0.5;
|
|
11567
12495
|
if (success) {
|
|
11568
|
-
this.traces = [...this.traces, ...
|
|
12496
|
+
this.traces = [...this.traces, ...program.getTraces()];
|
|
11569
12497
|
this.stats.successfulDemos++;
|
|
11570
12498
|
}
|
|
11571
12499
|
} catch (err) {
|
|
@@ -11616,54 +12544,73 @@ var AxBootstrapFewShot = class {
|
|
|
11616
12544
|
if (!this.stats.earlyStopping) {
|
|
11617
12545
|
this.stats.earlyStopping = {
|
|
11618
12546
|
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
11619
|
-
patienceExhausted: false
|
|
12547
|
+
patienceExhausted: false,
|
|
12548
|
+
reason: "No improvement detected"
|
|
11620
12549
|
};
|
|
11621
12550
|
} else if (improvement > 0) {
|
|
11622
12551
|
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
11623
12552
|
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
11624
12553
|
this.stats.earlyStopping.patienceExhausted = true;
|
|
11625
12554
|
this.stats.earlyStopped = true;
|
|
12555
|
+
this.stats.earlyStopping.reason = `No improvement for ${this.earlyStoppingPatience} rounds`;
|
|
11626
12556
|
if (this.verboseMode || this.debugMode) {
|
|
11627
|
-
|
|
11628
|
-
`
|
|
11629
|
-
|
|
12557
|
+
this.getLogger()?.(
|
|
12558
|
+
`Early stopping after ${roundIndex + 1} rounds (no improvement for ${this.earlyStoppingPatience} rounds)`,
|
|
12559
|
+
{ tags: ["optimizer", "warning"] }
|
|
11630
12560
|
);
|
|
11631
12561
|
}
|
|
11632
12562
|
return;
|
|
11633
12563
|
}
|
|
11634
12564
|
}
|
|
11635
12565
|
}
|
|
11636
|
-
async compile(metricFn, options) {
|
|
11637
|
-
const
|
|
11638
|
-
const maxRounds = compileOptions?.maxRounds ?? this.maxRounds;
|
|
12566
|
+
async compile(program, metricFn, options) {
|
|
12567
|
+
const maxRounds = options?.maxIterations ?? this.maxRounds;
|
|
11639
12568
|
this.traces = [];
|
|
11640
|
-
this.
|
|
11641
|
-
|
|
11642
|
-
|
|
11643
|
-
|
|
11644
|
-
|
|
11645
|
-
|
|
12569
|
+
this.reset();
|
|
12570
|
+
if (this.verboseMode || this.debugMode) {
|
|
12571
|
+
this.getLogger()?.(
|
|
12572
|
+
`Starting BootstrapFewshot optimization with ${maxRounds} rounds`,
|
|
12573
|
+
{ tags: ["optimizer", "start"] }
|
|
12574
|
+
);
|
|
12575
|
+
this.getLogger()?.(
|
|
12576
|
+
`Using ${this.examples.length} examples, max ${this.maxDemos} demos`,
|
|
12577
|
+
{ tags: ["optimizer", "config"] }
|
|
12578
|
+
);
|
|
12579
|
+
}
|
|
11646
12580
|
for (let i = 0; i < maxRounds; i++) {
|
|
11647
|
-
await this.compileRound(i, metricFn,
|
|
12581
|
+
await this.compileRound(program, i, metricFn, options);
|
|
11648
12582
|
if (this.stats.earlyStopped) {
|
|
11649
12583
|
break;
|
|
11650
12584
|
}
|
|
11651
12585
|
}
|
|
11652
12586
|
if (this.traces.length === 0) {
|
|
11653
12587
|
throw new Error(
|
|
11654
|
-
"No demonstrations found. Either
|
|
12588
|
+
"No demonstrations found. Either provide more examples or improve the existing ones."
|
|
11655
12589
|
);
|
|
11656
12590
|
}
|
|
11657
12591
|
const demos = groupTracesByKeys(this.traces);
|
|
12592
|
+
let bestScore = 0;
|
|
12593
|
+
if (this.traces.length > 0) {
|
|
12594
|
+
bestScore = this.stats.successfulDemos / Math.max(1, this.stats.totalCalls);
|
|
12595
|
+
}
|
|
12596
|
+
if (this.verboseMode || this.debugMode) {
|
|
12597
|
+
this.getLogger()?.(
|
|
12598
|
+
`Bootstrap complete. Generated ${demos.length} demos with ${bestScore.toFixed(3)} success rate`,
|
|
12599
|
+
{ tags: ["optimizer", "complete"] }
|
|
12600
|
+
);
|
|
12601
|
+
}
|
|
11658
12602
|
return {
|
|
11659
12603
|
demos,
|
|
11660
|
-
stats: this.stats
|
|
12604
|
+
stats: this.stats,
|
|
12605
|
+
bestScore,
|
|
12606
|
+
finalConfiguration: {
|
|
12607
|
+
maxRounds: this.maxRounds,
|
|
12608
|
+
maxDemos: this.maxDemos,
|
|
12609
|
+
batchSize: this.batchSize,
|
|
12610
|
+
successRate: bestScore
|
|
12611
|
+
}
|
|
11661
12612
|
};
|
|
11662
12613
|
}
|
|
11663
|
-
// Get optimization statistics
|
|
11664
|
-
getStats() {
|
|
11665
|
-
return this.stats;
|
|
11666
|
-
}
|
|
11667
12614
|
};
|
|
11668
12615
|
function groupTracesByKeys(programTraces) {
|
|
11669
12616
|
const groupedTraces = /* @__PURE__ */ new Map();
|
|
@@ -11678,9 +12625,12 @@ function groupTracesByKeys(programTraces) {
|
|
|
11678
12625
|
}
|
|
11679
12626
|
}
|
|
11680
12627
|
const programDemosArray = [];
|
|
11681
|
-
|
|
11682
|
-
programDemosArray.push({
|
|
11683
|
-
|
|
12628
|
+
groupedTraces.forEach((traces, programId) => {
|
|
12629
|
+
programDemosArray.push({
|
|
12630
|
+
traces,
|
|
12631
|
+
programId
|
|
12632
|
+
});
|
|
12633
|
+
});
|
|
11684
12634
|
return programDemosArray;
|
|
11685
12635
|
}
|
|
11686
12636
|
var randomSample = (array, n) => {
|
|
@@ -11699,10 +12649,8 @@ var randomSample = (array, n) => {
|
|
|
11699
12649
|
};
|
|
11700
12650
|
|
|
11701
12651
|
// dsp/optimizers/miproV2.ts
|
|
11702
|
-
var AxMiPRO = class {
|
|
11703
|
-
|
|
11704
|
-
program;
|
|
11705
|
-
examples;
|
|
12652
|
+
var AxMiPRO = class extends AxBaseOptimizer {
|
|
12653
|
+
// MiPRO-specific options
|
|
11706
12654
|
maxBootstrappedDemos;
|
|
11707
12655
|
maxLabeledDemos;
|
|
11708
12656
|
numCandidates;
|
|
@@ -11716,52 +12664,33 @@ var AxMiPRO = class {
|
|
|
11716
12664
|
viewDataBatchSize;
|
|
11717
12665
|
tipAwareProposer;
|
|
11718
12666
|
fewshotAwareProposer;
|
|
11719
|
-
seed;
|
|
11720
|
-
verbose;
|
|
11721
|
-
bootstrapper;
|
|
11722
12667
|
earlyStoppingTrials;
|
|
11723
12668
|
minImprovementThreshold;
|
|
11724
|
-
|
|
11725
|
-
|
|
11726
|
-
|
|
11727
|
-
|
|
11728
|
-
|
|
11729
|
-
|
|
11730
|
-
|
|
11731
|
-
|
|
11732
|
-
|
|
11733
|
-
|
|
11734
|
-
this.
|
|
11735
|
-
this.
|
|
11736
|
-
this.
|
|
11737
|
-
this.
|
|
11738
|
-
this.
|
|
11739
|
-
this.
|
|
11740
|
-
this.
|
|
11741
|
-
this.
|
|
11742
|
-
this.
|
|
11743
|
-
this.
|
|
11744
|
-
this.
|
|
11745
|
-
this.
|
|
11746
|
-
this.
|
|
11747
|
-
this.
|
|
11748
|
-
this.
|
|
11749
|
-
this.earlyStoppingTrials = miproOptions.earlyStoppingTrials ?? 5;
|
|
11750
|
-
this.minImprovementThreshold = miproOptions.minImprovementThreshold ?? 0.01;
|
|
11751
|
-
this.ai = ai;
|
|
11752
|
-
this.program = program;
|
|
11753
|
-
this.examples = examples;
|
|
11754
|
-
this.bootstrapper = new AxBootstrapFewShot({
|
|
11755
|
-
ai,
|
|
11756
|
-
program,
|
|
11757
|
-
examples,
|
|
11758
|
-
options: {
|
|
11759
|
-
maxDemos: this.maxBootstrappedDemos,
|
|
11760
|
-
maxRounds: 3,
|
|
11761
|
-
// Default, or adjust based on your needs
|
|
11762
|
-
verboseMode: this.verbose
|
|
11763
|
-
}
|
|
11764
|
-
});
|
|
12669
|
+
bayesianOptimization;
|
|
12670
|
+
acquisitionFunction;
|
|
12671
|
+
explorationWeight;
|
|
12672
|
+
constructor(args) {
|
|
12673
|
+
super(args);
|
|
12674
|
+
const options = args.options || {};
|
|
12675
|
+
this.numCandidates = options.numCandidates ?? 5;
|
|
12676
|
+
this.initTemperature = options.initTemperature ?? 0.7;
|
|
12677
|
+
this.maxBootstrappedDemos = options.maxBootstrappedDemos ?? 3;
|
|
12678
|
+
this.maxLabeledDemos = options.maxLabeledDemos ?? 4;
|
|
12679
|
+
this.numTrials = options.numTrials ?? 30;
|
|
12680
|
+
this.minibatch = options.minibatch ?? true;
|
|
12681
|
+
this.minibatchSize = options.minibatchSize ?? 25;
|
|
12682
|
+
this.minibatchFullEvalSteps = options.minibatchFullEvalSteps ?? 10;
|
|
12683
|
+
this.programAwareProposer = options.programAwareProposer ?? true;
|
|
12684
|
+
this.dataAwareProposer = options.dataAwareProposer ?? true;
|
|
12685
|
+
this.viewDataBatchSize = options.viewDataBatchSize ?? 10;
|
|
12686
|
+
this.tipAwareProposer = options.tipAwareProposer ?? true;
|
|
12687
|
+
this.fewshotAwareProposer = options.fewshotAwareProposer ?? true;
|
|
12688
|
+
this.earlyStoppingTrials = options.earlyStoppingTrials ?? 5;
|
|
12689
|
+
this.minImprovementThreshold = options.minImprovementThreshold ?? 0.01;
|
|
12690
|
+
this.bayesianOptimization = options.bayesianOptimization ?? false;
|
|
12691
|
+
this.acquisitionFunction = options.acquisitionFunction ?? "expected_improvement";
|
|
12692
|
+
this.explorationWeight = options.explorationWeight ?? 0.1;
|
|
12693
|
+
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
11765
12694
|
}
|
|
11766
12695
|
/**
|
|
11767
12696
|
* Configures the optimizer for light, medium, or heavy optimization
|
|
@@ -11805,123 +12734,62 @@ var AxMiPRO = class {
|
|
|
11805
12734
|
];
|
|
11806
12735
|
}
|
|
11807
12736
|
/**
|
|
11808
|
-
* Generates instruction candidates
|
|
12737
|
+
* Generates instruction candidates using the teacher model if available
|
|
12738
|
+
* @param options Optional compile options that may override teacher AI
|
|
11809
12739
|
* @returns Array of generated instruction candidates
|
|
11810
12740
|
*/
|
|
11811
|
-
async proposeInstructionCandidates() {
|
|
12741
|
+
async proposeInstructionCandidates(options) {
|
|
11812
12742
|
const instructions = [];
|
|
11813
|
-
|
|
11814
|
-
if (this.programAwareProposer) {
|
|
11815
|
-
programContext = await this.generateProgramSummary();
|
|
11816
|
-
}
|
|
11817
|
-
let dataContext = "";
|
|
11818
|
-
if (this.dataAwareProposer) {
|
|
11819
|
-
dataContext = await this.generateDataSummary();
|
|
11820
|
-
}
|
|
12743
|
+
const aiToUse = this.getTeacherOrStudentAI(options);
|
|
11821
12744
|
const tips = this.tipAwareProposer ? this.generateTips() : [];
|
|
11822
12745
|
for (let i = 0; i < this.numCandidates; i++) {
|
|
11823
12746
|
const tipIndex = tips.length > 0 ? i % tips.length : -1;
|
|
11824
12747
|
const tipToUse = tipIndex >= 0 ? tips[tipIndex] : "";
|
|
11825
12748
|
const instruction = await this.generateInstruction({
|
|
11826
|
-
programContext,
|
|
11827
|
-
dataContext,
|
|
11828
12749
|
tip: tipToUse,
|
|
11829
|
-
candidateIndex: i
|
|
12750
|
+
candidateIndex: i,
|
|
12751
|
+
ai: aiToUse
|
|
11830
12752
|
});
|
|
11831
12753
|
instructions.push(instruction);
|
|
11832
12754
|
}
|
|
11833
12755
|
return instructions;
|
|
11834
12756
|
}
|
|
11835
|
-
/**
|
|
11836
|
-
* Generates a summary of the program structure for instruction proposal
|
|
11837
|
-
*/
|
|
11838
|
-
async generateProgramSummary() {
|
|
11839
|
-
const prompt = `Summarize the following program structure. Focus on the signatures,
|
|
11840
|
-
input/output fields, and the purpose of each component. Identify key components
|
|
11841
|
-
that might benefit from better instructions.`;
|
|
11842
|
-
const programStr = JSON.stringify(this.program);
|
|
11843
|
-
const response = await this.ai.chat({
|
|
11844
|
-
chatPrompt: [
|
|
11845
|
-
{ role: "system", content: prompt },
|
|
11846
|
-
{ role: "user", content: programStr }
|
|
11847
|
-
],
|
|
11848
|
-
modelConfig: { temperature: 0.2 }
|
|
11849
|
-
});
|
|
11850
|
-
if (response instanceof ReadableStream) {
|
|
11851
|
-
return "";
|
|
11852
|
-
}
|
|
11853
|
-
return response.results[0]?.content || "";
|
|
11854
|
-
}
|
|
11855
|
-
/**
|
|
11856
|
-
* Generates a summary of the dataset for instruction proposal
|
|
11857
|
-
*/
|
|
11858
|
-
async generateDataSummary() {
|
|
11859
|
-
const sampleSize = Math.min(this.viewDataBatchSize, this.examples.length);
|
|
11860
|
-
const sample = this.examples.slice(0, sampleSize);
|
|
11861
|
-
const prompt = `Analyze the following dataset examples and provide a summary
|
|
11862
|
-
of key patterns, input-output relationships, and any specific challenges
|
|
11863
|
-
the data presents. Focus on what makes a good answer and what patterns should
|
|
11864
|
-
be followed.`;
|
|
11865
|
-
const dataStr = JSON.stringify(sample);
|
|
11866
|
-
const response = await this.ai.chat({
|
|
11867
|
-
chatPrompt: [
|
|
11868
|
-
{ role: "system", content: prompt },
|
|
11869
|
-
{ role: "user", content: dataStr }
|
|
11870
|
-
],
|
|
11871
|
-
modelConfig: { temperature: 0.2 }
|
|
11872
|
-
});
|
|
11873
|
-
if (response instanceof ReadableStream) {
|
|
11874
|
-
return "";
|
|
11875
|
-
}
|
|
11876
|
-
return response.results[0]?.content || "";
|
|
11877
|
-
}
|
|
11878
|
-
/**
|
|
11879
|
-
* Generates a specific instruction candidate
|
|
11880
|
-
*/
|
|
11881
12757
|
async generateInstruction({
|
|
11882
|
-
programContext,
|
|
11883
|
-
dataContext,
|
|
11884
12758
|
tip,
|
|
11885
12759
|
candidateIndex
|
|
11886
12760
|
}) {
|
|
11887
|
-
const
|
|
11888
|
-
|
|
11889
|
-
|
|
11890
|
-
|
|
11891
|
-
|
|
11892
|
-
|
|
11893
|
-
|
|
11894
|
-
|
|
11895
|
-
|
|
11896
|
-
|
|
11897
|
-
${tip ? `STYLE TIP: ${tip}
|
|
11898
|
-
|
|
11899
|
-
` : ""}
|
|
11900
|
-
|
|
11901
|
-
Your task is to craft a clear, effective instruction that will help the AI model generate
|
|
11902
|
-
accurate outputs for this task. Instruction #${candidateIndex + 1}/${this.numCandidates}.
|
|
11903
|
-
|
|
11904
|
-
The instruction should be detailed enough to guide the model but not overly prescriptive
|
|
11905
|
-
or restrictive. Focus on what makes a good response rather than listing exact steps.
|
|
11906
|
-
|
|
11907
|
-
INSTRUCTION:`;
|
|
11908
|
-
const response = await this.ai.chat({
|
|
11909
|
-
chatPrompt: [{ role: "user", content: prompt }],
|
|
11910
|
-
modelConfig: { temperature: 0.7 + 0.1 * candidateIndex }
|
|
11911
|
-
});
|
|
11912
|
-
if (response instanceof ReadableStream) {
|
|
11913
|
-
return "";
|
|
12761
|
+
const baseInstructions = [
|
|
12762
|
+
"Analyze the input carefully and provide a detailed response.",
|
|
12763
|
+
"Think step by step and provide a clear answer.",
|
|
12764
|
+
"Consider all aspects of the input before responding.",
|
|
12765
|
+
"Provide a concise but comprehensive response.",
|
|
12766
|
+
"Focus on accuracy and clarity in your response."
|
|
12767
|
+
];
|
|
12768
|
+
let instruction = baseInstructions[candidateIndex % baseInstructions.length] || baseInstructions[0];
|
|
12769
|
+
if (tip) {
|
|
12770
|
+
instruction = `${instruction} ${tip}`;
|
|
11914
12771
|
}
|
|
11915
|
-
return
|
|
12772
|
+
return instruction;
|
|
11916
12773
|
}
|
|
11917
12774
|
/**
|
|
11918
12775
|
* Bootstraps few-shot examples for the program
|
|
11919
12776
|
*/
|
|
11920
|
-
async bootstrapFewShotExamples(metricFn) {
|
|
11921
|
-
if (this.
|
|
11922
|
-
|
|
12777
|
+
async bootstrapFewShotExamples(program, metricFn) {
|
|
12778
|
+
if (this.isLoggingEnabled()) {
|
|
12779
|
+
this.getLogger()?.("Bootstrapping few-shot examples...", {
|
|
12780
|
+
tags: ["optimizer", "phase"]
|
|
12781
|
+
});
|
|
11923
12782
|
}
|
|
11924
|
-
const
|
|
12783
|
+
const bootstrapper = new AxBootstrapFewShot({
|
|
12784
|
+
studentAI: this.studentAI,
|
|
12785
|
+
examples: this.examples,
|
|
12786
|
+
options: {
|
|
12787
|
+
maxDemos: this.maxBootstrappedDemos,
|
|
12788
|
+
maxRounds: 3,
|
|
12789
|
+
verboseMode: this.isLoggingEnabled()
|
|
12790
|
+
}
|
|
12791
|
+
});
|
|
12792
|
+
const result = await bootstrapper.compile(program, metricFn, {
|
|
11925
12793
|
maxDemos: this.maxBootstrappedDemos
|
|
11926
12794
|
});
|
|
11927
12795
|
return result.demos || [];
|
|
@@ -11945,109 +12813,111 @@ ${dataContext}
|
|
|
11945
12813
|
return selectedExamples;
|
|
11946
12814
|
}
|
|
11947
12815
|
/**
|
|
11948
|
-
* Runs
|
|
12816
|
+
* Runs optimization to find the best combination of few-shot examples and instructions
|
|
11949
12817
|
*/
|
|
11950
|
-
async
|
|
11951
|
-
let bestConfig =
|
|
11952
|
-
let bestScore = Number.NEGATIVE_INFINITY;
|
|
11953
|
-
const evaluatedConfigs = [];
|
|
11954
|
-
const defaultConfig = {
|
|
12818
|
+
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, valset, metricFn, options) {
|
|
12819
|
+
let bestConfig = {
|
|
11955
12820
|
instruction: instructions[0] || "",
|
|
11956
12821
|
bootstrappedDemos: Math.min(1, bootstrappedDemos.length),
|
|
11957
12822
|
labeledExamples: Math.min(1, labeledExamples.length)
|
|
11958
12823
|
};
|
|
11959
|
-
let
|
|
11960
|
-
let
|
|
11961
|
-
const
|
|
11962
|
-
|
|
11963
|
-
|
|
11964
|
-
|
|
11965
|
-
|
|
11966
|
-
|
|
11967
|
-
|
|
11968
|
-
|
|
12824
|
+
let bestScore = 0;
|
|
12825
|
+
let stagnationRounds = 0;
|
|
12826
|
+
const scoreHistory = [];
|
|
12827
|
+
let startRound = 0;
|
|
12828
|
+
if (this.resumeFromCheckpoint) {
|
|
12829
|
+
const checkpoint = await this.loadCheckpoint(
|
|
12830
|
+
this.resumeFromCheckpoint,
|
|
12831
|
+
options
|
|
12832
|
+
);
|
|
12833
|
+
if (checkpoint && checkpoint.optimizerType === "MiPRO") {
|
|
12834
|
+
if (this.isLoggingEnabled(options)) {
|
|
12835
|
+
this.getLogger(options)?.(
|
|
12836
|
+
`Resuming from checkpoint at round ${checkpoint.currentRound}`,
|
|
12837
|
+
{ tags: ["optimizer", "checkpoint"] }
|
|
12838
|
+
);
|
|
12839
|
+
}
|
|
12840
|
+
this.restoreFromCheckpoint(checkpoint);
|
|
12841
|
+
startRound = checkpoint.currentRound;
|
|
12842
|
+
bestScore = checkpoint.bestScore;
|
|
12843
|
+
bestConfig = checkpoint.bestConfiguration || bestConfig;
|
|
12844
|
+
stagnationRounds = checkpoint.stats.convergenceInfo?.stagnationRounds || 0;
|
|
12845
|
+
}
|
|
12846
|
+
}
|
|
12847
|
+
if (this.isLoggingEnabled(options)) {
|
|
12848
|
+
this.getLogger(options)?.(
|
|
12849
|
+
`Running optimization trials (${this.numTrials} total)`,
|
|
12850
|
+
{ tags: ["optimizer", "phase"] }
|
|
12851
|
+
);
|
|
12852
|
+
}
|
|
12853
|
+
for (let i = startRound; i < this.numTrials; i++) {
|
|
11969
12854
|
const config = {
|
|
11970
|
-
instruction:
|
|
11971
|
-
bootstrappedDemos: Math.
|
|
11972
|
-
Math.random() * (bootstrappedDemos.length + 1)
|
|
12855
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
12856
|
+
bootstrappedDemos: Math.min(
|
|
12857
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
12858
|
+
this.maxBootstrappedDemos
|
|
11973
12859
|
),
|
|
11974
|
-
labeledExamples: Math.
|
|
11975
|
-
Math.random() * (labeledExamples.length + 1)
|
|
12860
|
+
labeledExamples: Math.min(
|
|
12861
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
12862
|
+
this.maxLabeledDemos
|
|
11976
12863
|
)
|
|
11977
12864
|
};
|
|
11978
|
-
configs.push(config);
|
|
11979
|
-
}
|
|
11980
|
-
for (let i = 0; i < configs.length; i++) {
|
|
11981
|
-
const config = configs[i];
|
|
11982
|
-
if (!config) continue;
|
|
11983
12865
|
const score = await this.evaluateConfig(
|
|
12866
|
+
program,
|
|
11984
12867
|
config,
|
|
11985
12868
|
bootstrappedDemos,
|
|
11986
12869
|
labeledExamples,
|
|
11987
12870
|
valset,
|
|
11988
|
-
metricFn
|
|
11989
|
-
i
|
|
12871
|
+
metricFn
|
|
11990
12872
|
);
|
|
11991
|
-
|
|
11992
|
-
|
|
12873
|
+
scoreHistory.push(score);
|
|
12874
|
+
const improvement = score - bestScore;
|
|
12875
|
+
if (improvement > this.minImprovementThreshold) {
|
|
11993
12876
|
bestScore = score;
|
|
11994
12877
|
bestConfig = config;
|
|
11995
|
-
|
|
11996
|
-
|
|
11997
|
-
|
|
12878
|
+
stagnationRounds = 0;
|
|
12879
|
+
if (this.isLoggingEnabled(options)) {
|
|
12880
|
+
this.getLogger(options)?.(
|
|
12881
|
+
`Trial ${i + 1}/${this.numTrials}: New best score ${bestScore.toFixed(3)}`,
|
|
12882
|
+
{ tags: ["optimizer", "progress"] }
|
|
11998
12883
|
);
|
|
11999
12884
|
}
|
|
12885
|
+
} else {
|
|
12886
|
+
stagnationRounds++;
|
|
12000
12887
|
}
|
|
12001
|
-
|
|
12888
|
+
await this.updateOptimizationProgress(
|
|
12002
12889
|
i + 1,
|
|
12003
|
-
|
|
12004
|
-
|
|
12005
|
-
|
|
12006
|
-
|
|
12007
|
-
|
|
12008
|
-
|
|
12009
|
-
|
|
12010
|
-
|
|
12011
|
-
|
|
12012
|
-
|
|
12013
|
-
|
|
12014
|
-
|
|
12015
|
-
|
|
12016
|
-
);
|
|
12017
|
-
const score = await this.evaluateConfig(
|
|
12018
|
-
nextConfig,
|
|
12019
|
-
bootstrappedDemos,
|
|
12020
|
-
labeledExamples,
|
|
12021
|
-
valset,
|
|
12022
|
-
metricFn,
|
|
12023
|
-
i
|
|
12890
|
+
score,
|
|
12891
|
+
config,
|
|
12892
|
+
"MiPRO",
|
|
12893
|
+
this.getConfiguration(),
|
|
12894
|
+
bestScore,
|
|
12895
|
+
bestConfig,
|
|
12896
|
+
{
|
|
12897
|
+
stagnationRounds,
|
|
12898
|
+
bootstrappedDemos: bootstrappedDemos.length,
|
|
12899
|
+
labeledExamples: labeledExamples.length,
|
|
12900
|
+
instructions: instructions.length
|
|
12901
|
+
},
|
|
12902
|
+
options
|
|
12024
12903
|
);
|
|
12025
|
-
|
|
12026
|
-
|
|
12027
|
-
|
|
12028
|
-
|
|
12029
|
-
|
|
12030
|
-
|
|
12031
|
-
|
|
12032
|
-
)
|
|
12033
|
-
|
|
12034
|
-
|
|
12035
|
-
|
|
12036
|
-
|
|
12037
|
-
|
|
12038
|
-
|
|
12039
|
-
|
|
12040
|
-
if (this.verbose) {
|
|
12041
|
-
console.log(
|
|
12042
|
-
`Early stopping triggered after ${i + 1} trials. No improvement for ${trialsWithoutImprovement} trials.`
|
|
12043
|
-
);
|
|
12044
|
-
}
|
|
12045
|
-
break;
|
|
12904
|
+
if (this.onProgress) {
|
|
12905
|
+
this.onProgress({
|
|
12906
|
+
round: i + 1,
|
|
12907
|
+
totalRounds: this.numTrials,
|
|
12908
|
+
currentScore: score,
|
|
12909
|
+
bestScore,
|
|
12910
|
+
tokensUsed: this.stats.resourceUsage.totalTokens,
|
|
12911
|
+
timeElapsed: Date.now(),
|
|
12912
|
+
successfulExamples: this.stats.successfulDemos,
|
|
12913
|
+
totalExamples: this.examples.length,
|
|
12914
|
+
currentConfiguration: config,
|
|
12915
|
+
convergenceInfo: {
|
|
12916
|
+
improvement,
|
|
12917
|
+
stagnationRounds,
|
|
12918
|
+
isConverging: stagnationRounds < this.earlyStoppingTrials
|
|
12046
12919
|
}
|
|
12047
|
-
}
|
|
12048
|
-
lastBestScore = bestScore;
|
|
12049
|
-
trialsWithoutImprovement = 0;
|
|
12050
|
-
}
|
|
12920
|
+
});
|
|
12051
12921
|
}
|
|
12052
12922
|
updateProgressBar(
|
|
12053
12923
|
i + 1,
|
|
@@ -12057,290 +12927,309 @@ ${dataContext}
|
|
|
12057
12927
|
"Running MIPROv2 optimization",
|
|
12058
12928
|
30
|
|
12059
12929
|
);
|
|
12060
|
-
if (this.
|
|
12061
|
-
|
|
12062
|
-
|
|
12063
|
-
`Running full evaluation on best configuration at trial ${i + 1}`
|
|
12064
|
-
);
|
|
12065
|
-
}
|
|
12066
|
-
const fullScore = await this.fullEvaluation(
|
|
12067
|
-
bestConfig,
|
|
12068
|
-
bootstrappedDemos,
|
|
12069
|
-
labeledExamples,
|
|
12070
|
-
valset,
|
|
12071
|
-
metricFn
|
|
12072
|
-
);
|
|
12073
|
-
if (this.verbose) {
|
|
12074
|
-
console.log(`Full evaluation score: ${fullScore}`);
|
|
12075
|
-
}
|
|
12076
|
-
bestScore = fullScore;
|
|
12930
|
+
if (this.checkCostLimits()) {
|
|
12931
|
+
this.triggerEarlyStopping("Cost limit reached", i + 1);
|
|
12932
|
+
break;
|
|
12077
12933
|
}
|
|
12078
|
-
|
|
12079
|
-
|
|
12080
|
-
|
|
12081
|
-
|
|
12082
|
-
"Optimization failed to find any valid configurations, using default fallback configuration"
|
|
12934
|
+
if (stagnationRounds >= this.earlyStoppingTrials) {
|
|
12935
|
+
this.triggerEarlyStopping(
|
|
12936
|
+
`No improvement for ${this.earlyStoppingTrials} trials`,
|
|
12937
|
+
i - stagnationRounds + 1
|
|
12083
12938
|
);
|
|
12939
|
+
break;
|
|
12084
12940
|
}
|
|
12085
|
-
|
|
12086
|
-
|
|
12087
|
-
|
|
12088
|
-
|
|
12089
|
-
bootstrappedDemos,
|
|
12090
|
-
labeledExamples,
|
|
12091
|
-
valset,
|
|
12092
|
-
metricFn,
|
|
12093
|
-
this.numTrials - 1
|
|
12941
|
+
if (this.checkTargetScore(bestScore)) {
|
|
12942
|
+
this.triggerEarlyStopping(
|
|
12943
|
+
`Target score ${this.targetScore} reached`,
|
|
12944
|
+
i + 1
|
|
12094
12945
|
);
|
|
12095
|
-
|
|
12096
|
-
if (this.verbose) {
|
|
12097
|
-
console.error("Error evaluating default configuration:", err);
|
|
12098
|
-
}
|
|
12099
|
-
bestScore = 0;
|
|
12946
|
+
break;
|
|
12100
12947
|
}
|
|
12101
12948
|
}
|
|
12949
|
+
this.stats.convergenceInfo.stagnationRounds = stagnationRounds;
|
|
12950
|
+
this.stats.convergenceInfo.finalImprovement = scoreHistory.length > 1 ? bestScore - scoreHistory[0] : 0;
|
|
12951
|
+
this.stats.convergenceInfo.converged = stagnationRounds < this.earlyStoppingTrials;
|
|
12102
12952
|
return { bestConfig, bestScore };
|
|
12103
12953
|
}
|
|
12104
|
-
|
|
12105
|
-
|
|
12106
|
-
*/
|
|
12107
|
-
async evaluateConfig(config, bootstrappedDemos, labeledExamples, valset, metricFn, trialIndex) {
|
|
12954
|
+
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, valset, metricFn) {
|
|
12955
|
+
const testProgram = { ...program };
|
|
12108
12956
|
this.applyConfigToProgram(
|
|
12109
|
-
|
|
12957
|
+
testProgram,
|
|
12110
12958
|
config,
|
|
12111
12959
|
bootstrappedDemos,
|
|
12112
12960
|
labeledExamples
|
|
12113
12961
|
);
|
|
12114
|
-
let
|
|
12115
|
-
|
|
12116
|
-
|
|
12117
|
-
const minibatchEvalSet = [];
|
|
12118
|
-
for (let j = 0; j < this.minibatchSize; j++) {
|
|
12119
|
-
const idx = (startIdx + j) % valset.length;
|
|
12120
|
-
const example = valset[idx];
|
|
12121
|
-
if (example) {
|
|
12122
|
-
minibatchEvalSet.push(example);
|
|
12123
|
-
}
|
|
12124
|
-
}
|
|
12125
|
-
evalSet = minibatchEvalSet;
|
|
12126
|
-
}
|
|
12127
|
-
let sumOfScores = 0;
|
|
12962
|
+
let totalScore = 0;
|
|
12963
|
+
let count = 0;
|
|
12964
|
+
const evalSet = valset.slice(0, Math.min(5, valset.length));
|
|
12128
12965
|
for (const example of evalSet) {
|
|
12129
12966
|
try {
|
|
12130
|
-
const prediction = await
|
|
12131
|
-
|
|
12132
|
-
|
|
12133
|
-
|
|
12134
|
-
|
|
12135
|
-
|
|
12136
|
-
|
|
12137
|
-
|
|
12138
|
-
|
|
12139
|
-
|
|
12140
|
-
return sumOfScores / evalSet.length;
|
|
12141
|
-
}
|
|
12142
|
-
/**
|
|
12143
|
-
* Run full evaluation on the entire validation set
|
|
12144
|
-
*/
|
|
12145
|
-
async fullEvaluation(config, bootstrappedDemos, labeledExamples, valset, metricFn) {
|
|
12146
|
-
this.applyConfigToProgram(
|
|
12147
|
-
this.program,
|
|
12148
|
-
config,
|
|
12149
|
-
bootstrappedDemos,
|
|
12150
|
-
labeledExamples
|
|
12151
|
-
);
|
|
12152
|
-
let sumOfScores = 0;
|
|
12153
|
-
for (const example of valset) {
|
|
12154
|
-
try {
|
|
12155
|
-
const prediction = await this.program.forward(this.ai, example);
|
|
12156
|
-
const score = metricFn({ prediction, example });
|
|
12157
|
-
sumOfScores += score;
|
|
12158
|
-
} catch (err) {
|
|
12159
|
-
if (this.verbose) {
|
|
12160
|
-
console.error("Error evaluating example:", err);
|
|
12161
|
-
}
|
|
12967
|
+
const prediction = await testProgram.forward(
|
|
12968
|
+
this.studentAI,
|
|
12969
|
+
example
|
|
12970
|
+
);
|
|
12971
|
+
const score = await metricFn({ prediction, example });
|
|
12972
|
+
totalScore += score;
|
|
12973
|
+
count++;
|
|
12974
|
+
this.stats.totalCalls++;
|
|
12975
|
+
} catch {
|
|
12976
|
+
continue;
|
|
12162
12977
|
}
|
|
12163
12978
|
}
|
|
12164
|
-
|
|
12165
|
-
return sumOfScores / valset.length;
|
|
12166
|
-
}
|
|
12167
|
-
/**
|
|
12168
|
-
* Implements a Bayesian-inspired selection of the next configuration to try
|
|
12169
|
-
* This is a simplified version using Upper Confidence Bound (UCB) strategy
|
|
12170
|
-
*/
|
|
12171
|
-
selectNextConfiguration(evaluatedConfigs, maxBootstrappedDemos, maxLabeledExamples, instructions) {
|
|
12172
|
-
if (evaluatedConfigs.length < 5) {
|
|
12173
|
-
const instructionIndex = Math.floor(Math.random() * instructions.length);
|
|
12174
|
-
return {
|
|
12175
|
-
instruction: instructions[instructionIndex] || "",
|
|
12176
|
-
bootstrappedDemos: Math.floor(
|
|
12177
|
-
Math.random() * (maxBootstrappedDemos + 1)
|
|
12178
|
-
),
|
|
12179
|
-
labeledExamples: Math.floor(Math.random() * (maxLabeledExamples + 1))
|
|
12180
|
-
};
|
|
12181
|
-
}
|
|
12182
|
-
const sortedConfigs = [...evaluatedConfigs].sort(
|
|
12183
|
-
(a, b) => b.score - a.score
|
|
12184
|
-
);
|
|
12185
|
-
const topConfigs = sortedConfigs.slice(0, Math.min(3, sortedConfigs.length));
|
|
12186
|
-
const meanBootstrappedDemos = topConfigs.reduce((sum, c) => sum + c.config.bootstrappedDemos, 0) / topConfigs.length;
|
|
12187
|
-
const meanLabeledExamples = topConfigs.reduce((sum, c) => sum + c.config.labeledExamples, 0) / topConfigs.length;
|
|
12188
|
-
const popularInstructions = topConfigs.map((c) => c.config.instruction);
|
|
12189
|
-
const explorationFactor = Math.max(
|
|
12190
|
-
0.2,
|
|
12191
|
-
1 - evaluatedConfigs.length / this.numTrials
|
|
12192
|
-
);
|
|
12193
|
-
let newBootstrappedDemos;
|
|
12194
|
-
let newLabeledExamples;
|
|
12195
|
-
let newInstruction;
|
|
12196
|
-
if (Math.random() < 0.7) {
|
|
12197
|
-
newBootstrappedDemos = Math.min(
|
|
12198
|
-
maxBootstrappedDemos,
|
|
12199
|
-
Math.max(
|
|
12200
|
-
0,
|
|
12201
|
-
Math.round(
|
|
12202
|
-
meanBootstrappedDemos + (Math.random() * 2 - 1) * explorationFactor * 2
|
|
12203
|
-
)
|
|
12204
|
-
)
|
|
12205
|
-
);
|
|
12206
|
-
} else {
|
|
12207
|
-
newBootstrappedDemos = Math.floor(
|
|
12208
|
-
Math.random() * (maxBootstrappedDemos + 1)
|
|
12209
|
-
);
|
|
12210
|
-
}
|
|
12211
|
-
if (Math.random() < 0.7) {
|
|
12212
|
-
newLabeledExamples = Math.min(
|
|
12213
|
-
maxLabeledExamples,
|
|
12214
|
-
Math.max(
|
|
12215
|
-
0,
|
|
12216
|
-
Math.round(
|
|
12217
|
-
meanLabeledExamples + (Math.random() * 2 - 1) * explorationFactor * 2
|
|
12218
|
-
)
|
|
12219
|
-
)
|
|
12220
|
-
);
|
|
12221
|
-
} else {
|
|
12222
|
-
newLabeledExamples = Math.floor(Math.random() * (maxLabeledExamples + 1));
|
|
12223
|
-
}
|
|
12224
|
-
if (Math.random() < 0.7 && popularInstructions.length > 0) {
|
|
12225
|
-
const idx = Math.floor(Math.random() * popularInstructions.length);
|
|
12226
|
-
newInstruction = popularInstructions[idx] || "";
|
|
12227
|
-
} else {
|
|
12228
|
-
const idx = Math.floor(Math.random() * instructions.length);
|
|
12229
|
-
newInstruction = instructions[idx] || "";
|
|
12230
|
-
}
|
|
12231
|
-
return {
|
|
12232
|
-
instruction: newInstruction,
|
|
12233
|
-
bootstrappedDemos: newBootstrappedDemos,
|
|
12234
|
-
labeledExamples: newLabeledExamples
|
|
12235
|
-
};
|
|
12979
|
+
return count > 0 ? totalScore / count : 0;
|
|
12236
12980
|
}
|
|
12237
|
-
/**
|
|
12238
|
-
* Applies a configuration to a program instance
|
|
12239
|
-
*/
|
|
12240
12981
|
applyConfigToProgram(program, config, bootstrappedDemos, labeledExamples) {
|
|
12241
|
-
|
|
12242
|
-
|
|
12982
|
+
if (program.setInstruction) {
|
|
12983
|
+
program.setInstruction(config.instruction);
|
|
12984
|
+
}
|
|
12985
|
+
if (config.bootstrappedDemos > 0 && program.setDemos) {
|
|
12243
12986
|
program.setDemos(bootstrappedDemos.slice(0, config.bootstrappedDemos));
|
|
12244
12987
|
}
|
|
12245
|
-
if (config.labeledExamples > 0) {
|
|
12988
|
+
if (config.labeledExamples > 0 && program.setExamples) {
|
|
12246
12989
|
program.setExamples(labeledExamples.slice(0, config.labeledExamples));
|
|
12247
12990
|
}
|
|
12248
12991
|
}
|
|
12249
|
-
/**
|
|
12250
|
-
* Sets instruction to a program
|
|
12251
|
-
* Note: Workaround since setInstruction may not be available directly
|
|
12252
|
-
*/
|
|
12253
|
-
setInstructionToProgram(program, instruction) {
|
|
12254
|
-
const programWithInstruction = program;
|
|
12255
|
-
programWithInstruction.setInstruction?.(instruction);
|
|
12256
|
-
}
|
|
12257
12992
|
/**
|
|
12258
12993
|
* The main compile method to run MIPROv2 optimization
|
|
12259
|
-
* @param metricFn Evaluation metric function
|
|
12260
|
-
* @param options Optional configuration options
|
|
12261
|
-
* @returns The optimization result
|
|
12262
12994
|
*/
|
|
12263
|
-
async compile(metricFn, options) {
|
|
12995
|
+
async compile(program, metricFn, options) {
|
|
12996
|
+
const startTime = Date.now();
|
|
12997
|
+
this.setupRandomSeed();
|
|
12264
12998
|
const miproOptions = options;
|
|
12265
12999
|
if (miproOptions?.auto) {
|
|
12266
13000
|
this.configureAuto(miproOptions.auto);
|
|
12267
13001
|
}
|
|
12268
|
-
const
|
|
12269
|
-
|
|
12270
|
-
|
|
12271
|
-
|
|
12272
|
-
|
|
12273
|
-
`Using ${trainset.length} examples for training and ${valset.length} for validation`
|
|
13002
|
+
const valset = this.getValidationSet(options) || (miproOptions?.valset ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
|
|
13003
|
+
if (this.isLoggingEnabled(options)) {
|
|
13004
|
+
this.getLogger(options)?.(
|
|
13005
|
+
`Starting MIPROv2 optimization with ${this.numTrials} trials`,
|
|
13006
|
+
{ tags: ["optimizer", "start"] }
|
|
12274
13007
|
);
|
|
12275
|
-
|
|
12276
|
-
|
|
12277
|
-
|
|
12278
|
-
|
|
13008
|
+
this.getLogger(options)?.(
|
|
13009
|
+
`Using ${this.examples.length} examples for training and ${valset.length} for validation`,
|
|
13010
|
+
{ tags: ["optimizer", "config"] }
|
|
13011
|
+
);
|
|
13012
|
+
if (this.teacherAI) {
|
|
13013
|
+
this.getLogger(options)?.(
|
|
13014
|
+
"Using separate teacher model for instruction generation",
|
|
13015
|
+
{ tags: ["optimizer", "config"] }
|
|
13016
|
+
);
|
|
12279
13017
|
}
|
|
12280
|
-
const bootstrapperWithTeacher = new AxBootstrapFewShot({
|
|
12281
|
-
ai: this.ai,
|
|
12282
|
-
program: this.program,
|
|
12283
|
-
examples: this.examples,
|
|
12284
|
-
options: {
|
|
12285
|
-
maxDemos: this.maxBootstrappedDemos,
|
|
12286
|
-
maxRounds: 3,
|
|
12287
|
-
verboseMode: this.verbose,
|
|
12288
|
-
teacherAI: this.ai
|
|
12289
|
-
// Use the same AI but with the teacher program
|
|
12290
|
-
}
|
|
12291
|
-
});
|
|
12292
|
-
this.bootstrapper = bootstrapperWithTeacher;
|
|
12293
13018
|
}
|
|
12294
13019
|
let bootstrappedDemos = [];
|
|
12295
13020
|
if (this.maxBootstrappedDemos > 0) {
|
|
12296
|
-
bootstrappedDemos = await this.bootstrapFewShotExamples(metricFn);
|
|
12297
|
-
if (this.
|
|
12298
|
-
|
|
12299
|
-
`Generated ${bootstrappedDemos.length} bootstrapped demonstrations
|
|
13021
|
+
bootstrappedDemos = await this.bootstrapFewShotExamples(program, metricFn);
|
|
13022
|
+
if (this.isLoggingEnabled(options)) {
|
|
13023
|
+
this.getLogger(options)?.(
|
|
13024
|
+
`Generated ${bootstrappedDemos.length} bootstrapped demonstrations`,
|
|
13025
|
+
{ tags: ["optimizer", "result"] }
|
|
12300
13026
|
);
|
|
12301
13027
|
}
|
|
12302
13028
|
}
|
|
12303
13029
|
let labeledExamples = [];
|
|
12304
13030
|
if (this.maxLabeledDemos > 0) {
|
|
12305
13031
|
labeledExamples = this.selectLabeledExamples();
|
|
12306
|
-
if (this.
|
|
12307
|
-
|
|
12308
|
-
`Selected ${labeledExamples.length} labeled examples from training set
|
|
13032
|
+
if (this.isLoggingEnabled(options)) {
|
|
13033
|
+
this.getLogger(options)?.(
|
|
13034
|
+
`Selected ${labeledExamples.length} labeled examples from training set`,
|
|
13035
|
+
{ tags: ["optimizer", "result"] }
|
|
12309
13036
|
);
|
|
12310
13037
|
}
|
|
12311
13038
|
}
|
|
12312
|
-
const instructions = await this.proposeInstructionCandidates();
|
|
12313
|
-
if (this.
|
|
12314
|
-
|
|
13039
|
+
const instructions = await this.proposeInstructionCandidates(options);
|
|
13040
|
+
if (this.isLoggingEnabled(options)) {
|
|
13041
|
+
this.getLogger(options)?.(
|
|
13042
|
+
`Generated ${instructions.length} instruction candidates`,
|
|
13043
|
+
{ tags: ["optimizer", "result"] }
|
|
13044
|
+
);
|
|
13045
|
+
if (this.hasTeacherAI(options)) {
|
|
13046
|
+
this.getLogger(options)?.(
|
|
13047
|
+
"Using teacher AI for instruction generation",
|
|
13048
|
+
{ tags: ["optimizer", "config"] }
|
|
13049
|
+
);
|
|
13050
|
+
}
|
|
12315
13051
|
}
|
|
12316
|
-
const { bestConfig, bestScore } = await this.
|
|
13052
|
+
const { bestConfig, bestScore } = await this.runOptimization(
|
|
13053
|
+
program,
|
|
12317
13054
|
bootstrappedDemos,
|
|
12318
13055
|
labeledExamples,
|
|
12319
13056
|
instructions,
|
|
12320
13057
|
valset,
|
|
12321
|
-
metricFn
|
|
13058
|
+
metricFn,
|
|
13059
|
+
options
|
|
12322
13060
|
);
|
|
12323
|
-
if (this.
|
|
12324
|
-
|
|
12325
|
-
|
|
13061
|
+
if (this.isLoggingEnabled(options)) {
|
|
13062
|
+
this.getLogger(options)?.(
|
|
13063
|
+
`Optimization complete. Best score: ${bestScore}`,
|
|
13064
|
+
{ tags: ["optimizer", "complete"] }
|
|
13065
|
+
);
|
|
13066
|
+
this.getLogger(options)?.(
|
|
13067
|
+
`Best configuration: ${JSON.stringify(bestConfig)}`,
|
|
13068
|
+
{ tags: ["optimizer", "result"] }
|
|
13069
|
+
);
|
|
12326
13070
|
}
|
|
12327
|
-
this.
|
|
12328
|
-
this.
|
|
13071
|
+
if (this.checkTargetScore(bestScore)) {
|
|
13072
|
+
this.triggerEarlyStopping(
|
|
13073
|
+
`Target score ${this.targetScore} reached with score ${bestScore}`,
|
|
13074
|
+
this.numTrials
|
|
13075
|
+
);
|
|
13076
|
+
}
|
|
13077
|
+
let signature;
|
|
13078
|
+
if ("getSignature" in program && typeof program.getSignature === "function") {
|
|
13079
|
+
signature = program.getSignature();
|
|
13080
|
+
} else {
|
|
13081
|
+
signature = "input -> output";
|
|
13082
|
+
}
|
|
13083
|
+
const optimizedGen = new AxGen(signature);
|
|
13084
|
+
this.applyConfigToAxGen(
|
|
13085
|
+
optimizedGen,
|
|
12329
13086
|
bestConfig,
|
|
12330
13087
|
bootstrappedDemos,
|
|
12331
13088
|
labeledExamples
|
|
12332
13089
|
);
|
|
13090
|
+
this.updateResourceUsage(startTime);
|
|
13091
|
+
this.stats.convergenceInfo.converged = true;
|
|
13092
|
+
this.stats.convergenceInfo.finalImprovement = bestScore;
|
|
13093
|
+
await this.saveFinalCheckpoint(
|
|
13094
|
+
"MiPRO",
|
|
13095
|
+
this.getConfiguration(),
|
|
13096
|
+
bestScore,
|
|
13097
|
+
bestConfig,
|
|
13098
|
+
{
|
|
13099
|
+
bootstrappedDemos: bootstrappedDemos.length,
|
|
13100
|
+
labeledExamples: labeledExamples.length,
|
|
13101
|
+
instructions: instructions.length,
|
|
13102
|
+
optimizedGen: !!optimizedGen
|
|
13103
|
+
},
|
|
13104
|
+
options
|
|
13105
|
+
);
|
|
12333
13106
|
return {
|
|
12334
|
-
|
|
12335
|
-
|
|
13107
|
+
demos: bootstrappedDemos,
|
|
13108
|
+
stats: this.stats,
|
|
13109
|
+
bestScore,
|
|
13110
|
+
optimizedGen,
|
|
13111
|
+
finalConfiguration: {
|
|
13112
|
+
instruction: bestConfig.instruction,
|
|
13113
|
+
bootstrappedDemos: bestConfig.bootstrappedDemos,
|
|
13114
|
+
labeledExamples: bestConfig.labeledExamples,
|
|
13115
|
+
numCandidates: this.numCandidates,
|
|
13116
|
+
numTrials: this.numTrials
|
|
13117
|
+
}
|
|
12336
13118
|
};
|
|
12337
13119
|
}
|
|
12338
13120
|
/**
|
|
12339
|
-
*
|
|
12340
|
-
* @returns Optimization statistics or undefined if not available
|
|
13121
|
+
* Applies a configuration to an AxGen instance
|
|
12341
13122
|
*/
|
|
12342
|
-
|
|
12343
|
-
|
|
13123
|
+
applyConfigToAxGen(axgen, config, bootstrappedDemos, labeledExamples) {
|
|
13124
|
+
if ("setInstruction" in axgen && typeof axgen.setInstruction === "function") {
|
|
13125
|
+
axgen.setInstruction(config.instruction);
|
|
13126
|
+
}
|
|
13127
|
+
if (config.bootstrappedDemos > 0) {
|
|
13128
|
+
axgen.setDemos(bootstrappedDemos.slice(0, config.bootstrappedDemos));
|
|
13129
|
+
}
|
|
13130
|
+
if (config.labeledExamples > 0) {
|
|
13131
|
+
axgen.setExamples(
|
|
13132
|
+
labeledExamples.slice(
|
|
13133
|
+
0,
|
|
13134
|
+
config.labeledExamples
|
|
13135
|
+
)
|
|
13136
|
+
);
|
|
13137
|
+
}
|
|
13138
|
+
}
|
|
13139
|
+
/**
|
|
13140
|
+
* Get optimizer-specific configuration
|
|
13141
|
+
* @returns Current optimizer configuration
|
|
13142
|
+
*/
|
|
13143
|
+
getConfiguration() {
|
|
13144
|
+
return {
|
|
13145
|
+
numCandidates: this.numCandidates,
|
|
13146
|
+
initTemperature: this.initTemperature,
|
|
13147
|
+
maxBootstrappedDemos: this.maxBootstrappedDemos,
|
|
13148
|
+
maxLabeledDemos: this.maxLabeledDemos,
|
|
13149
|
+
numTrials: this.numTrials,
|
|
13150
|
+
minibatch: this.minibatch,
|
|
13151
|
+
minibatchSize: this.minibatchSize,
|
|
13152
|
+
minibatchFullEvalSteps: this.minibatchFullEvalSteps,
|
|
13153
|
+
programAwareProposer: this.programAwareProposer,
|
|
13154
|
+
dataAwareProposer: this.dataAwareProposer,
|
|
13155
|
+
tipAwareProposer: this.tipAwareProposer,
|
|
13156
|
+
fewshotAwareProposer: this.fewshotAwareProposer,
|
|
13157
|
+
earlyStoppingTrials: this.earlyStoppingTrials,
|
|
13158
|
+
minImprovementThreshold: this.minImprovementThreshold,
|
|
13159
|
+
bayesianOptimization: this.bayesianOptimization,
|
|
13160
|
+
acquisitionFunction: this.acquisitionFunction,
|
|
13161
|
+
explorationWeight: this.explorationWeight
|
|
13162
|
+
};
|
|
13163
|
+
}
|
|
13164
|
+
/**
|
|
13165
|
+
* Update optimizer configuration
|
|
13166
|
+
* @param config New configuration to merge with existing
|
|
13167
|
+
*/
|
|
13168
|
+
updateConfiguration(config) {
|
|
13169
|
+
if (config.numCandidates !== void 0) {
|
|
13170
|
+
this.numCandidates = config.numCandidates;
|
|
13171
|
+
}
|
|
13172
|
+
if (config.initTemperature !== void 0) {
|
|
13173
|
+
this.initTemperature = config.initTemperature;
|
|
13174
|
+
}
|
|
13175
|
+
if (config.maxBootstrappedDemos !== void 0) {
|
|
13176
|
+
this.maxBootstrappedDemos = config.maxBootstrappedDemos;
|
|
13177
|
+
}
|
|
13178
|
+
if (config.maxLabeledDemos !== void 0) {
|
|
13179
|
+
this.maxLabeledDemos = config.maxLabeledDemos;
|
|
13180
|
+
}
|
|
13181
|
+
if (config.numTrials !== void 0) {
|
|
13182
|
+
this.numTrials = config.numTrials;
|
|
13183
|
+
}
|
|
13184
|
+
if (config.minibatch !== void 0) {
|
|
13185
|
+
this.minibatch = config.minibatch;
|
|
13186
|
+
}
|
|
13187
|
+
if (config.minibatchSize !== void 0) {
|
|
13188
|
+
this.minibatchSize = config.minibatchSize;
|
|
13189
|
+
}
|
|
13190
|
+
if (config.earlyStoppingTrials !== void 0) {
|
|
13191
|
+
this.earlyStoppingTrials = config.earlyStoppingTrials;
|
|
13192
|
+
}
|
|
13193
|
+
if (config.minImprovementThreshold !== void 0) {
|
|
13194
|
+
this.minImprovementThreshold = config.minImprovementThreshold;
|
|
13195
|
+
}
|
|
13196
|
+
}
|
|
13197
|
+
/**
|
|
13198
|
+
* Reset optimizer state for reuse with different programs
|
|
13199
|
+
*/
|
|
13200
|
+
reset() {
|
|
13201
|
+
super.reset();
|
|
13202
|
+
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
13203
|
+
}
|
|
13204
|
+
/**
|
|
13205
|
+
* Validate that the optimizer can handle the given program
|
|
13206
|
+
* @param program Program to validate
|
|
13207
|
+
* @returns Validation result with any issues found
|
|
13208
|
+
*/
|
|
13209
|
+
validateProgram(program) {
|
|
13210
|
+
const result = super.validateProgram(program);
|
|
13211
|
+
if (this.examples.length < this.maxBootstrappedDemos + this.maxLabeledDemos) {
|
|
13212
|
+
result.issues.push(
|
|
13213
|
+
`Not enough examples: need at least ${this.maxBootstrappedDemos + this.maxLabeledDemos}, got ${this.examples.length}`
|
|
13214
|
+
);
|
|
13215
|
+
result.suggestions.push(
|
|
13216
|
+
"Reduce maxBootstrappedDemos or maxLabeledDemos, or provide more examples"
|
|
13217
|
+
);
|
|
13218
|
+
}
|
|
13219
|
+
const valSetSize = this.getValidationSet().length;
|
|
13220
|
+
if (valSetSize < 5) {
|
|
13221
|
+
result.issues.push(
|
|
13222
|
+
"Validation set too small for reliable MiPRO optimization"
|
|
13223
|
+
);
|
|
13224
|
+
result.suggestions.push(
|
|
13225
|
+
"Provide more examples or a larger validation set"
|
|
13226
|
+
);
|
|
13227
|
+
}
|
|
13228
|
+
return {
|
|
13229
|
+
isValid: result.issues.length === 0,
|
|
13230
|
+
issues: result.issues,
|
|
13231
|
+
suggestions: result.suggestions
|
|
13232
|
+
};
|
|
12344
13233
|
}
|
|
12345
13234
|
};
|
|
12346
13235
|
|
|
@@ -12587,7 +13476,7 @@ var AxTestPrompt = class {
|
|
|
12587
13476
|
throw new Error("Invalid example");
|
|
12588
13477
|
}
|
|
12589
13478
|
const res = await this.program.forward(this.ai, ex);
|
|
12590
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
13479
|
+
const score = await metricFn({ prediction: res, example: ex });
|
|
12591
13480
|
sumOfScores += score;
|
|
12592
13481
|
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
12593
13482
|
updateProgressBar(i, total, sumOfScores, et, "Testing Prompt", 30);
|
|
@@ -14621,7 +15510,6 @@ var AxRAG = class extends AxChainOfThought {
|
|
|
14621
15510
|
);
|
|
14622
15511
|
this.genQuery = new AxGen(qsig);
|
|
14623
15512
|
this.queryFn = queryFn;
|
|
14624
|
-
this.register(this.genQuery);
|
|
14625
15513
|
}
|
|
14626
15514
|
async forward(ai, values, options) {
|
|
14627
15515
|
let question;
|
|
@@ -14698,6 +15586,7 @@ export {
|
|
|
14698
15586
|
AxAssertionError,
|
|
14699
15587
|
AxBalancer,
|
|
14700
15588
|
AxBaseAI,
|
|
15589
|
+
AxBaseOptimizer,
|
|
14701
15590
|
AxBootstrapFewShot,
|
|
14702
15591
|
AxChainOfThought,
|
|
14703
15592
|
AxDB,
|
|
@@ -14707,6 +15596,7 @@ export {
|
|
|
14707
15596
|
AxDBMemory,
|
|
14708
15597
|
AxDBPinecone,
|
|
14709
15598
|
AxDBWeaviate,
|
|
15599
|
+
AxDefaultCostTracker,
|
|
14710
15600
|
AxDefaultQueryRewriter,
|
|
14711
15601
|
AxDefaultResultReranker,
|
|
14712
15602
|
AxDockerSession,
|
|
@@ -14776,6 +15666,10 @@ export {
|
|
|
14776
15666
|
axAITogetherDefaultConfig,
|
|
14777
15667
|
axBaseAIDefaultConfig,
|
|
14778
15668
|
axBaseAIDefaultCreativeConfig,
|
|
15669
|
+
axCreateDefaultLogger,
|
|
15670
|
+
axCreateDefaultTextLogger,
|
|
15671
|
+
axCreateOptimizerLogger,
|
|
15672
|
+
axDefaultOptimizerLogger,
|
|
14779
15673
|
axGlobals,
|
|
14780
15674
|
axModelInfoAnthropic,
|
|
14781
15675
|
axModelInfoCohere,
|