@ax-llm/ax 12.0.13 → 12.0.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +409 -64
- package/index.cjs.map +1 -1
- package/index.d.cts +47 -4
- package/index.d.ts +47 -4
- package/index.js +409 -64
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.cjs
CHANGED
|
@@ -7710,6 +7710,22 @@ function parseFunctionCalls(ai, functionCalls, values, model) {
|
|
|
7710
7710
|
}));
|
|
7711
7711
|
return funcs;
|
|
7712
7712
|
}
|
|
7713
|
+
/**
 * Resolves the function list and function-call mode for one generation step.
 *
 * @param functionList Optional array whose entries are either function
 *   definitions or objects exposing `toFunction()`; the latter are expanded,
 *   and array results are flattened one level.
 * @param definedFunctionCall Requested function-call mode; returned unchanged
 *   unless it is suppressed by the rule below.
 * @param firstStep Truthy on the first step of a run.
 * @returns `{ functions, functionCall }` — `functions` is always an array.
 */
function createFunctionConfig(functionList, definedFunctionCall, firstStep) {
  // After the first step, a forced call mode ("required" or a function-typed
  // value) is dropped together with the whole function list.
  const isForcedCall = definedFunctionCall === "required" || typeof definedFunctionCall === "function";
  if (!firstStep && isForcedCall) {
    return { functions: [], functionCall: void 0 };
  }
  if (!functionList) {
    return { functions: [], functionCall: definedFunctionCall };
  }
  // flatMap expands `toFunction()` results and flattens them in a single pass.
  const functions = functionList.flatMap(
    (f2) => "toFunction" in f2 ? f2.toFunction() : f2
  );
  return { functions, functionCall: definedFunctionCall };
}
|
|
7713
7729
|
|
|
7714
7730
|
// dsp/processResponse.ts
|
|
7715
7731
|
var import_web5 = require("stream/web");
|
|
@@ -10158,7 +10174,8 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
10158
10174
|
mem,
|
|
10159
10175
|
options,
|
|
10160
10176
|
traceContext,
|
|
10161
|
-
|
|
10177
|
+
functions,
|
|
10178
|
+
functionCall
|
|
10162
10179
|
}) {
|
|
10163
10180
|
const {
|
|
10164
10181
|
sessionId,
|
|
@@ -10166,8 +10183,6 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
10166
10183
|
model,
|
|
10167
10184
|
rateLimiter,
|
|
10168
10185
|
stream,
|
|
10169
|
-
functions: _functions,
|
|
10170
|
-
functionCall: _functionCall,
|
|
10171
10186
|
thinkingTokenBudget,
|
|
10172
10187
|
showThoughts
|
|
10173
10188
|
} = options ?? {};
|
|
@@ -10178,11 +10193,6 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
10178
10193
|
if (chatPrompt.length === 0) {
|
|
10179
10194
|
throw new Error("No chat prompt found");
|
|
10180
10195
|
}
|
|
10181
|
-
const functions = _functions?.map((f2) => "toFunction" in f2 ? f2.toFunction() : f2)?.flat();
|
|
10182
|
-
let functionCall = _functionCall ?? this.options?.functionCall;
|
|
10183
|
-
if (!firstStep && (functionCall === "required" || typeof functionCall === "function")) {
|
|
10184
|
-
functionCall = void 0;
|
|
10185
|
-
}
|
|
10186
10196
|
const modelConfig = {
|
|
10187
10197
|
...options?.modelConfig,
|
|
10188
10198
|
...options?.sampleCount ? { n: options.sampleCount } : {},
|
|
@@ -10219,18 +10229,24 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
10219
10229
|
span,
|
|
10220
10230
|
traceContext
|
|
10221
10231
|
}) {
|
|
10222
|
-
const { sessionId, traceId, functions:
|
|
10232
|
+
const { sessionId, traceId, functions: functionList } = options ?? {};
|
|
10233
|
+
const definedFunctionCall = options?.functionCall ?? this.options?.functionCall;
|
|
10223
10234
|
const strictMode = options?.strictMode ?? false;
|
|
10224
10235
|
const model = options.model;
|
|
10225
10236
|
const states = this.createStates(options.sampleCount ?? 1);
|
|
10226
10237
|
const usage = this.usage;
|
|
10227
|
-
const functions
|
|
10238
|
+
const { functions, functionCall } = createFunctionConfig(
|
|
10239
|
+
functionList,
|
|
10240
|
+
definedFunctionCall,
|
|
10241
|
+
firstStep
|
|
10242
|
+
);
|
|
10228
10243
|
const res = await this.forwardSendRequest({
|
|
10229
10244
|
ai,
|
|
10230
10245
|
mem,
|
|
10231
10246
|
options,
|
|
10232
10247
|
traceContext,
|
|
10233
|
-
|
|
10248
|
+
functions,
|
|
10249
|
+
functionCall
|
|
10234
10250
|
});
|
|
10235
10251
|
if (res instanceof import_web6.ReadableStream) {
|
|
10236
10252
|
yield* processStreamingResponse({
|
|
@@ -11874,13 +11890,6 @@ var AxBaseOptimizer = class {
|
|
|
11874
11890
|
if (this.logger) {
|
|
11875
11891
|
return this.logger;
|
|
11876
11892
|
}
|
|
11877
|
-
try {
|
|
11878
|
-
const aiLogger = this.studentAI.getLogger();
|
|
11879
|
-
if (aiLogger) {
|
|
11880
|
-
return aiLogger;
|
|
11881
|
-
}
|
|
11882
|
-
} catch {
|
|
11883
|
-
}
|
|
11884
11893
|
return axDefaultOptimizerLogger;
|
|
11885
11894
|
}
|
|
11886
11895
|
/**
|
|
@@ -13572,6 +13581,11 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13572
13581
|
bayesianOptimization;
|
|
13573
13582
|
acquisitionFunction;
|
|
13574
13583
|
explorationWeight;
|
|
13584
|
+
// Self-consistency / multiple sampling
|
|
13585
|
+
sampleCount;
|
|
13586
|
+
// Surrogate model state for Bayesian optimization
|
|
13587
|
+
miproConfigHistory = [];
|
|
13588
|
+
surrogateModel = /* @__PURE__ */ new Map();
|
|
13575
13589
|
constructor(args) {
|
|
13576
13590
|
super(args);
|
|
13577
13591
|
const options = args.options || {};
|
|
@@ -13593,6 +13607,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13593
13607
|
this.bayesianOptimization = options.bayesianOptimization ?? false;
|
|
13594
13608
|
this.acquisitionFunction = options.acquisitionFunction ?? "expected_improvement";
|
|
13595
13609
|
this.explorationWeight = options.explorationWeight ?? 0.1;
|
|
13610
|
+
this.sampleCount = options.sampleCount ?? 1;
|
|
13596
13611
|
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
13597
13612
|
}
|
|
13598
13613
|
/**
|
|
@@ -13637,43 +13652,186 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13637
13652
|
];
|
|
13638
13653
|
}
|
|
13639
13654
|
/**
|
|
13640
|
-
* Generates
|
|
13655
|
+
* Generates program summary for context-aware instruction generation
|
|
13656
|
+
*/
|
|
13657
|
+
async generateProgramSummary(program, ai) {
|
|
13658
|
+
let signature = "input -> output";
|
|
13659
|
+
if ("getSignature" in program && typeof program.getSignature === "function") {
|
|
13660
|
+
signature = program.getSignature();
|
|
13661
|
+
}
|
|
13662
|
+
const summaryPrompt = `
|
|
13663
|
+
Analyze this language model program and provide a concise summary of its purpose and structure.
|
|
13664
|
+
|
|
13665
|
+
Program Signature: ${signature}
|
|
13666
|
+
|
|
13667
|
+
Provide a 2-3 sentence summary focusing on:
|
|
13668
|
+
1. The main task or purpose of this program
|
|
13669
|
+
2. The input-output relationship
|
|
13670
|
+
3. Any special constraints or requirements
|
|
13671
|
+
|
|
13672
|
+
Summary:`;
|
|
13673
|
+
try {
|
|
13674
|
+
const response = await ai.chat({
|
|
13675
|
+
chatPrompt: [{ role: "user", content: summaryPrompt }]
|
|
13676
|
+
});
|
|
13677
|
+
if ("results" in response) {
|
|
13678
|
+
return response.results[0]?.content?.trim() || "General language model program";
|
|
13679
|
+
}
|
|
13680
|
+
return "General language model program";
|
|
13681
|
+
} catch {
|
|
13682
|
+
return "General language model program";
|
|
13683
|
+
}
|
|
13684
|
+
}
|
|
13685
|
+
/**
|
|
13686
|
+
* Generates dataset summary for context-aware instruction generation
|
|
13687
|
+
*/
|
|
13688
|
+
async generateDatasetSummary(examples, ai) {
|
|
13689
|
+
if (examples.length === 0) return "No examples available";
|
|
13690
|
+
const sampleSize = Math.min(this.viewDataBatchSize, examples.length);
|
|
13691
|
+
const sampledExamples = examples.slice(0, sampleSize);
|
|
13692
|
+
const exampleTexts = sampledExamples.map((ex, i) => `Example ${i + 1}: ${JSON.stringify(ex)}`).join("\n");
|
|
13693
|
+
const summaryPrompt = `
|
|
13694
|
+
Analyze this dataset and provide a concise summary of its characteristics.
|
|
13695
|
+
|
|
13696
|
+
Sample Examples:
|
|
13697
|
+
${exampleTexts}
|
|
13698
|
+
|
|
13699
|
+
Provide a 2-3 sentence summary focusing on:
|
|
13700
|
+
1. The type of data and domain
|
|
13701
|
+
2. Common patterns or structures in the examples
|
|
13702
|
+
3. Key challenges or requirements for processing this data
|
|
13703
|
+
|
|
13704
|
+
Dataset Summary:`;
|
|
13705
|
+
try {
|
|
13706
|
+
const response = await ai.chat({
|
|
13707
|
+
chatPrompt: [{ role: "user", content: summaryPrompt }]
|
|
13708
|
+
});
|
|
13709
|
+
if ("results" in response) {
|
|
13710
|
+
return response.results[0]?.content?.trim() || "General dataset";
|
|
13711
|
+
}
|
|
13712
|
+
return "General dataset";
|
|
13713
|
+
} catch {
|
|
13714
|
+
return "General dataset";
|
|
13715
|
+
}
|
|
13716
|
+
}
|
|
13717
|
+
/**
|
|
13718
|
+
* Enhanced instruction generation using AI with program and data awareness
|
|
13719
|
+
*/
|
|
13720
|
+
async generateInstruction({
|
|
13721
|
+
tip,
|
|
13722
|
+
candidateIndex,
|
|
13723
|
+
ai,
|
|
13724
|
+
programSummary,
|
|
13725
|
+
datasetSummary,
|
|
13726
|
+
previousInstructions = []
|
|
13727
|
+
}) {
|
|
13728
|
+
let contextInfo = "";
|
|
13729
|
+
if (this.programAwareProposer && programSummary) {
|
|
13730
|
+
contextInfo += `
|
|
13731
|
+
Program Context: ${programSummary}`;
|
|
13732
|
+
}
|
|
13733
|
+
if (this.dataAwareProposer && datasetSummary) {
|
|
13734
|
+
contextInfo += `
|
|
13735
|
+
Dataset Context: ${datasetSummary}`;
|
|
13736
|
+
}
|
|
13737
|
+
if (this.fewshotAwareProposer && previousInstructions.length > 0) {
|
|
13738
|
+
contextInfo += `
|
|
13739
|
+
Previous Instructions (avoid repeating): ${previousInstructions.slice(-3).join("; ")}`;
|
|
13740
|
+
}
|
|
13741
|
+
const instructionPrompt = `
|
|
13742
|
+
Generate a high-quality instruction for a language model program.
|
|
13743
|
+
|
|
13744
|
+
${contextInfo}
|
|
13745
|
+
|
|
13746
|
+
${tip ? `Tip: ${tip}` : ""}
|
|
13747
|
+
|
|
13748
|
+
Requirements:
|
|
13749
|
+
1. Be specific and actionable
|
|
13750
|
+
2. Focus on accuracy and clarity
|
|
13751
|
+
3. Consider the program's purpose and data characteristics
|
|
13752
|
+
4. Make the instruction distinct from previous ones
|
|
13753
|
+
5. Keep it concise but comprehensive
|
|
13754
|
+
|
|
13755
|
+
Generate a single, well-crafted instruction:
|
|
13756
|
+
Instruction:`;
|
|
13757
|
+
try {
|
|
13758
|
+
const response = await ai.chat({
|
|
13759
|
+
chatPrompt: [
|
|
13760
|
+
{
|
|
13761
|
+
role: "user",
|
|
13762
|
+
content: instructionPrompt
|
|
13763
|
+
}
|
|
13764
|
+
]
|
|
13765
|
+
});
|
|
13766
|
+
if ("results" in response) {
|
|
13767
|
+
const instruction2 = response.results[0]?.content?.trim();
|
|
13768
|
+
if (instruction2 && instruction2.length > 10) {
|
|
13769
|
+
return instruction2;
|
|
13770
|
+
}
|
|
13771
|
+
}
|
|
13772
|
+
} catch (error) {
|
|
13773
|
+
if (this.isLoggingEnabled()) {
|
|
13774
|
+
this.getLogger()?.(`Failed to generate AI instruction: ${error}`, {
|
|
13775
|
+
tags: ["optimizer", "warning"]
|
|
13776
|
+
});
|
|
13777
|
+
}
|
|
13778
|
+
}
|
|
13779
|
+
const enhancedTemplates = [
|
|
13780
|
+
"Analyze the input systematically and provide a precise, well-reasoned response.",
|
|
13781
|
+
"Think through this step-by-step, considering all relevant factors before responding.",
|
|
13782
|
+
"Examine the input carefully and generate an accurate, detailed answer.",
|
|
13783
|
+
"Process the information methodically and deliver a clear, comprehensive response.",
|
|
13784
|
+
"Consider the context thoroughly and provide a thoughtful, accurate answer."
|
|
13785
|
+
];
|
|
13786
|
+
let instruction = enhancedTemplates[candidateIndex % enhancedTemplates.length] || enhancedTemplates[0];
|
|
13787
|
+
if (tip) {
|
|
13788
|
+
instruction = `${instruction} ${tip}`;
|
|
13789
|
+
}
|
|
13790
|
+
return instruction;
|
|
13791
|
+
}
|
|
13792
|
+
/**
|
|
13793
|
+
* Generates instruction candidates using enhanced AI-powered generation
|
|
13641
13794
|
* @param options Optional compile options that may override teacher AI
|
|
13642
13795
|
* @returns Array of generated instruction candidates
|
|
13643
13796
|
*/
|
|
13644
|
-
async proposeInstructionCandidates(options) {
|
|
13797
|
+
async proposeInstructionCandidates(program, options) {
|
|
13645
13798
|
const instructions = [];
|
|
13646
13799
|
const aiToUse = this.getTeacherOrStudentAI(options);
|
|
13800
|
+
let programSummary;
|
|
13801
|
+
let datasetSummary;
|
|
13802
|
+
if (this.programAwareProposer) {
|
|
13803
|
+
programSummary = await this.generateProgramSummary(program, aiToUse);
|
|
13804
|
+
if (this.isLoggingEnabled(options)) {
|
|
13805
|
+
this.getLogger(options)?.(`Program summary: ${programSummary}`, {
|
|
13806
|
+
tags: ["optimizer", "config"]
|
|
13807
|
+
});
|
|
13808
|
+
}
|
|
13809
|
+
}
|
|
13810
|
+
if (this.dataAwareProposer) {
|
|
13811
|
+
datasetSummary = await this.generateDatasetSummary(this.examples, aiToUse);
|
|
13812
|
+
if (this.isLoggingEnabled(options)) {
|
|
13813
|
+
this.getLogger(options)?.(`Dataset summary: ${datasetSummary}`, {
|
|
13814
|
+
tags: ["optimizer", "config"]
|
|
13815
|
+
});
|
|
13816
|
+
}
|
|
13817
|
+
}
|
|
13647
13818
|
const tips = this.tipAwareProposer ? this.generateTips() : [];
|
|
13648
13819
|
for (let i = 0; i < this.numCandidates; i++) {
|
|
13649
13820
|
const tipIndex = tips.length > 0 ? i % tips.length : -1;
|
|
13650
|
-
const tipToUse = tipIndex >= 0 ? tips[tipIndex] :
|
|
13821
|
+
const tipToUse = tipIndex >= 0 ? tips[tipIndex] : void 0;
|
|
13651
13822
|
const instruction = await this.generateInstruction({
|
|
13652
13823
|
tip: tipToUse,
|
|
13653
13824
|
candidateIndex: i,
|
|
13654
|
-
ai: aiToUse
|
|
13825
|
+
ai: aiToUse,
|
|
13826
|
+
programSummary,
|
|
13827
|
+
datasetSummary,
|
|
13828
|
+
previousInstructions: instructions
|
|
13829
|
+
// Pass previous instructions for diversity
|
|
13655
13830
|
});
|
|
13656
13831
|
instructions.push(instruction);
|
|
13657
13832
|
}
|
|
13658
13833
|
return instructions;
|
|
13659
13834
|
}
|
|
13660
|
-
async generateInstruction({
|
|
13661
|
-
tip,
|
|
13662
|
-
candidateIndex
|
|
13663
|
-
}) {
|
|
13664
|
-
const baseInstructions = [
|
|
13665
|
-
"Analyze the input carefully and provide a detailed response.",
|
|
13666
|
-
"Think step by step and provide a clear answer.",
|
|
13667
|
-
"Consider all aspects of the input before responding.",
|
|
13668
|
-
"Provide a concise but comprehensive response.",
|
|
13669
|
-
"Focus on accuracy and clarity in your response."
|
|
13670
|
-
];
|
|
13671
|
-
let instruction = baseInstructions[candidateIndex % baseInstructions.length] || baseInstructions[0];
|
|
13672
|
-
if (tip) {
|
|
13673
|
-
instruction = `${instruction} ${tip}`;
|
|
13674
|
-
}
|
|
13675
|
-
return instruction;
|
|
13676
|
-
}
|
|
13677
13835
|
/**
|
|
13678
13836
|
* Bootstraps few-shot examples for the program
|
|
13679
13837
|
*/
|
|
@@ -13718,7 +13876,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13718
13876
|
/**
|
|
13719
13877
|
* Runs optimization to find the best combination of few-shot examples and instructions
|
|
13720
13878
|
*/
|
|
13721
|
-
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions,
|
|
13879
|
+
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, validationExamples, metricFn, options) {
|
|
13722
13880
|
let bestConfig = {
|
|
13723
13881
|
instruction: instructions[0] || "",
|
|
13724
13882
|
bootstrappedDemos: Math.min(1, bootstrappedDemos.length),
|
|
@@ -13754,25 +13912,37 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13754
13912
|
);
|
|
13755
13913
|
}
|
|
13756
13914
|
for (let i = startRound; i < this.numTrials; i++) {
|
|
13757
|
-
|
|
13758
|
-
|
|
13759
|
-
|
|
13760
|
-
|
|
13761
|
-
|
|
13762
|
-
|
|
13763
|
-
|
|
13764
|
-
|
|
13765
|
-
|
|
13766
|
-
|
|
13767
|
-
|
|
13915
|
+
let config;
|
|
13916
|
+
if (this.bayesianOptimization && this.miproConfigHistory.length > 2) {
|
|
13917
|
+
config = await this.selectConfigurationViaBayesianOptimization(
|
|
13918
|
+
instructions,
|
|
13919
|
+
bootstrappedDemos,
|
|
13920
|
+
labeledExamples
|
|
13921
|
+
);
|
|
13922
|
+
} else {
|
|
13923
|
+
config = {
|
|
13924
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
13925
|
+
bootstrappedDemos: Math.min(
|
|
13926
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
13927
|
+
this.maxBootstrappedDemos
|
|
13928
|
+
),
|
|
13929
|
+
labeledExamples: Math.min(
|
|
13930
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
13931
|
+
this.maxLabeledDemos
|
|
13932
|
+
)
|
|
13933
|
+
};
|
|
13934
|
+
}
|
|
13768
13935
|
const score = await this.evaluateConfig(
|
|
13769
13936
|
program,
|
|
13770
13937
|
config,
|
|
13771
13938
|
bootstrappedDemos,
|
|
13772
13939
|
labeledExamples,
|
|
13773
|
-
|
|
13774
|
-
metricFn
|
|
13940
|
+
validationExamples,
|
|
13941
|
+
metricFn,
|
|
13942
|
+
i + 1
|
|
13943
|
+
// Pass current trial number for adaptive evaluation
|
|
13775
13944
|
);
|
|
13945
|
+
this.updateSurrogateModel(config, score);
|
|
13776
13946
|
scoreHistory.push(score);
|
|
13777
13947
|
const improvement = score - bestScore;
|
|
13778
13948
|
if (improvement > this.minImprovementThreshold) {
|
|
@@ -13854,7 +14024,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13854
14024
|
this.stats.convergenceInfo.converged = stagnationRounds < this.earlyStoppingTrials;
|
|
13855
14025
|
return { bestConfig, bestScore };
|
|
13856
14026
|
}
|
|
13857
|
-
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples,
|
|
14027
|
+
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, validationExamples, metricFn, currentTrial = 0) {
|
|
13858
14028
|
const testProgram = { ...program };
|
|
13859
14029
|
this.applyConfigToProgram(
|
|
13860
14030
|
testProgram,
|
|
@@ -13864,12 +14034,31 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13864
14034
|
);
|
|
13865
14035
|
let totalScore = 0;
|
|
13866
14036
|
let count = 0;
|
|
13867
|
-
|
|
14037
|
+
let evalSize;
|
|
14038
|
+
if (this.minibatch) {
|
|
14039
|
+
const baseSize = Math.min(this.minibatchSize, validationExamples.length);
|
|
14040
|
+
const isFullEvalTrial = currentTrial % this.minibatchFullEvalSteps === 0;
|
|
14041
|
+
if (isFullEvalTrial || currentTrial > this.numTrials * 0.8) {
|
|
14042
|
+
evalSize = Math.min(validationExamples.length, baseSize * 2);
|
|
14043
|
+
} else {
|
|
14044
|
+
evalSize = Math.max(3, Math.min(baseSize, validationExamples.length));
|
|
14045
|
+
}
|
|
14046
|
+
} else {
|
|
14047
|
+
evalSize = validationExamples.length;
|
|
14048
|
+
}
|
|
14049
|
+
const evalIndices = this.shuffleArray([
|
|
14050
|
+
...Array(validationExamples.length).keys()
|
|
14051
|
+
]).slice(0, evalSize);
|
|
14052
|
+
const evalSet = evalIndices.map((i) => validationExamples[i]);
|
|
13868
14053
|
for (const example of evalSet) {
|
|
13869
14054
|
try {
|
|
13870
14055
|
const prediction = await testProgram.forward(
|
|
13871
14056
|
this.studentAI,
|
|
13872
|
-
example
|
|
14057
|
+
example,
|
|
14058
|
+
this.sampleCount > 1 ? {
|
|
14059
|
+
sampleCount: this.sampleCount,
|
|
14060
|
+
resultPicker: axMajorityVotePicker()
|
|
14061
|
+
} : void 0
|
|
13873
14062
|
);
|
|
13874
14063
|
const score = await metricFn({ prediction, example });
|
|
13875
14064
|
totalScore += score;
|
|
@@ -13881,6 +14070,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13881
14070
|
}
|
|
13882
14071
|
return count > 0 ? totalScore / count : 0;
|
|
13883
14072
|
}
|
|
14073
|
+
/**
|
|
14074
|
+
* Fisher-Yates shuffle for stochastic evaluation
|
|
14075
|
+
*/
|
|
14076
|
+
shuffleArray(array) {
|
|
14077
|
+
const shuffled = [...array];
|
|
14078
|
+
for (let i = shuffled.length - 1; i > 0; i--) {
|
|
14079
|
+
const j = Math.floor(Math.random() * (i + 1));
|
|
14080
|
+
[shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
|
|
14081
|
+
}
|
|
14082
|
+
return shuffled;
|
|
14083
|
+
}
|
|
13884
14084
|
applyConfigToProgram(program, config, bootstrappedDemos, labeledExamples) {
|
|
13885
14085
|
if (program.setInstruction) {
|
|
13886
14086
|
program.setInstruction(config.instruction);
|
|
@@ -13902,14 +14102,14 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13902
14102
|
if (miproOptions?.auto) {
|
|
13903
14103
|
this.configureAuto(miproOptions.auto);
|
|
13904
14104
|
}
|
|
13905
|
-
const
|
|
14105
|
+
const validationExamples = this.getValidationSet(options) || (miproOptions?.validationExamples ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
|
|
13906
14106
|
if (this.isLoggingEnabled(options)) {
|
|
13907
14107
|
this.getLogger(options)?.(
|
|
13908
14108
|
`Starting MIPROv2 optimization with ${this.numTrials} trials`,
|
|
13909
14109
|
{ tags: ["optimizer", "start"] }
|
|
13910
14110
|
);
|
|
13911
14111
|
this.getLogger(options)?.(
|
|
13912
|
-
`Using ${this.examples.length} examples for training and ${
|
|
14112
|
+
`Using ${this.examples.length} examples for training and ${validationExamples.length} for validation`,
|
|
13913
14113
|
{ tags: ["optimizer", "config"] }
|
|
13914
14114
|
);
|
|
13915
14115
|
if (this.teacherAI) {
|
|
@@ -13939,7 +14139,10 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13939
14139
|
);
|
|
13940
14140
|
}
|
|
13941
14141
|
}
|
|
13942
|
-
const instructions = await this.proposeInstructionCandidates(
|
|
14142
|
+
const instructions = await this.proposeInstructionCandidates(
|
|
14143
|
+
program,
|
|
14144
|
+
options
|
|
14145
|
+
);
|
|
13943
14146
|
if (this.isLoggingEnabled(options)) {
|
|
13944
14147
|
this.getLogger(options)?.(
|
|
13945
14148
|
`Generated ${instructions.length} instruction candidates`,
|
|
@@ -13957,7 +14160,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13957
14160
|
bootstrappedDemos,
|
|
13958
14161
|
labeledExamples,
|
|
13959
14162
|
instructions,
|
|
13960
|
-
|
|
14163
|
+
validationExamples,
|
|
13961
14164
|
metricFn,
|
|
13962
14165
|
options
|
|
13963
14166
|
);
|
|
@@ -14016,7 +14219,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
14016
14219
|
bootstrappedDemos: bestConfig.bootstrappedDemos,
|
|
14017
14220
|
labeledExamples: bestConfig.labeledExamples,
|
|
14018
14221
|
numCandidates: this.numCandidates,
|
|
14019
|
-
numTrials: this.numTrials
|
|
14222
|
+
numTrials: this.numTrials,
|
|
14223
|
+
sampleCount: this.sampleCount
|
|
14020
14224
|
}
|
|
14021
14225
|
};
|
|
14022
14226
|
}
|
|
@@ -14061,7 +14265,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
14061
14265
|
minImprovementThreshold: this.minImprovementThreshold,
|
|
14062
14266
|
bayesianOptimization: this.bayesianOptimization,
|
|
14063
14267
|
acquisitionFunction: this.acquisitionFunction,
|
|
14064
|
-
explorationWeight: this.explorationWeight
|
|
14268
|
+
explorationWeight: this.explorationWeight,
|
|
14269
|
+
sampleCount: this.sampleCount
|
|
14065
14270
|
};
|
|
14066
14271
|
}
|
|
14067
14272
|
/**
|
|
@@ -14096,12 +14301,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
14096
14301
|
if (config.minImprovementThreshold !== void 0) {
|
|
14097
14302
|
this.minImprovementThreshold = config.minImprovementThreshold;
|
|
14098
14303
|
}
|
|
14304
|
+
if (config.sampleCount !== void 0) {
|
|
14305
|
+
this.sampleCount = config.sampleCount;
|
|
14306
|
+
}
|
|
14099
14307
|
}
|
|
14100
14308
|
/**
|
|
14101
14309
|
* Reset optimizer state for reuse with different programs
|
|
14102
14310
|
*/
|
|
14103
14311
|
reset() {
|
|
14104
14312
|
super.reset();
|
|
14313
|
+
this.miproConfigHistory = [];
|
|
14314
|
+
this.surrogateModel.clear();
|
|
14105
14315
|
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
14106
14316
|
}
|
|
14107
14317
|
/**
|
|
@@ -14119,8 +14329,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
14119
14329
|
"Reduce maxBootstrappedDemos or maxLabeledDemos, or provide more examples"
|
|
14120
14330
|
);
|
|
14121
14331
|
}
|
|
14122
|
-
const
|
|
14123
|
-
if (
|
|
14332
|
+
const validationSetSize = this.getValidationSet().length;
|
|
14333
|
+
if (validationSetSize < 5) {
|
|
14124
14334
|
result.issues.push(
|
|
14125
14335
|
"Validation set too small for reliable MiPRO optimization"
|
|
14126
14336
|
);
|
|
@@ -14134,6 +14344,141 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
14134
14344
|
suggestions: result.suggestions
|
|
14135
14345
|
};
|
|
14136
14346
|
}
|
|
14347
|
+
/**
|
|
14348
|
+
* Encodes a configuration into a string key for surrogate model lookup
|
|
14349
|
+
*/
|
|
14350
|
+
encodeConfiguration(config) {
|
|
14351
|
+
return `${config.instruction.length}_${config.bootstrappedDemos}_${config.labeledExamples}`;
|
|
14352
|
+
}
|
|
14353
|
+
/**
|
|
14354
|
+
* Updates the surrogate model with a new configuration-score pair
|
|
14355
|
+
*/
|
|
14356
|
+
updateSurrogateModel(config, score) {
|
|
14357
|
+
this.miproConfigHistory.push({ config: { ...config }, score });
|
|
14358
|
+
const key = this.encodeConfiguration(config);
|
|
14359
|
+
const similarConfigs = this.miproConfigHistory.filter(
|
|
14360
|
+
(entry) => this.encodeConfiguration(entry.config) === key
|
|
14361
|
+
);
|
|
14362
|
+
if (similarConfigs.length > 0) {
|
|
14363
|
+
const scores = similarConfigs.map((entry) => entry.score);
|
|
14364
|
+
const mean = scores.reduce((sum, s2) => sum + s2, 0) / scores.length;
|
|
14365
|
+
const variance = scores.length > 1 ? scores.reduce((sum, s2) => sum + Math.pow(s2 - mean, 2), 0) / (scores.length - 1) : 0.1;
|
|
14366
|
+
this.surrogateModel.set(key, { mean, variance });
|
|
14367
|
+
}
|
|
14368
|
+
}
|
|
14369
|
+
/**
|
|
14370
|
+
* Predicts performance using the surrogate model
|
|
14371
|
+
*/
|
|
14372
|
+
predictPerformance(config) {
|
|
14373
|
+
const key = this.encodeConfiguration(config);
|
|
14374
|
+
if (this.surrogateModel.has(key)) {
|
|
14375
|
+
return this.surrogateModel.get(key);
|
|
14376
|
+
}
|
|
14377
|
+
if (this.miproConfigHistory.length > 0) {
|
|
14378
|
+
const similarities = this.miproConfigHistory.map((entry) => {
|
|
14379
|
+
const diff = Math.abs(entry.config.bootstrappedDemos - config.bootstrappedDemos) + Math.abs(entry.config.labeledExamples - config.labeledExamples);
|
|
14380
|
+
return { score: entry.score, similarity: 1 / (1 + diff) };
|
|
14381
|
+
});
|
|
14382
|
+
const totalWeight = similarities.reduce((sum, s2) => sum + s2.similarity, 0);
|
|
14383
|
+
const weightedMean = similarities.reduce((sum, s2) => sum + s2.score * s2.similarity, 0) / totalWeight;
|
|
14384
|
+
return { mean: weightedMean, variance: 0.2 };
|
|
14385
|
+
}
|
|
14386
|
+
return { mean: 0.5, variance: 0.3 };
|
|
14387
|
+
}
|
|
14388
|
+
/**
|
|
14389
|
+
* Calculates acquisition function value for Bayesian optimization
|
|
14390
|
+
*/
|
|
14391
|
+
calculateAcquisitionValue(config) {
|
|
14392
|
+
const prediction = this.predictPerformance(config);
|
|
14393
|
+
const { mean, variance } = prediction;
|
|
14394
|
+
const std = Math.sqrt(variance);
|
|
14395
|
+
const bestScore = this.miproConfigHistory.length > 0 ? Math.max(...this.miproConfigHistory.map((entry) => entry.score)) : 0;
|
|
14396
|
+
switch (this.acquisitionFunction) {
|
|
14397
|
+
case "expected_improvement": {
|
|
14398
|
+
const improvement = mean - bestScore;
|
|
14399
|
+
if (std === 0) return Math.max(0, improvement);
|
|
14400
|
+
const z = improvement / std;
|
|
14401
|
+
const phi = 0.5 * (1 + this.erf(z / Math.sqrt(2)));
|
|
14402
|
+
const pdfValue = Math.exp(-0.5 * z * z) / Math.sqrt(2 * Math.PI);
|
|
14403
|
+
return improvement * phi + std * pdfValue;
|
|
14404
|
+
}
|
|
14405
|
+
case "upper_confidence_bound": {
|
|
14406
|
+
return mean + this.explorationWeight * std;
|
|
14407
|
+
}
|
|
14408
|
+
case "probability_improvement": {
|
|
14409
|
+
const improvement = mean - bestScore;
|
|
14410
|
+
if (std === 0) return improvement > 0 ? 1 : 0;
|
|
14411
|
+
const z = improvement / std;
|
|
14412
|
+
return 0.5 * (1 + this.erf(z / Math.sqrt(2)));
|
|
14413
|
+
}
|
|
14414
|
+
default:
|
|
14415
|
+
return mean;
|
|
14416
|
+
}
|
|
14417
|
+
}
|
|
14418
|
+
/**
|
|
14419
|
+
* Error function approximation for acquisition function calculations
|
|
14420
|
+
*/
|
|
14421
|
+
erf(x) {
|
|
14422
|
+
const a1 = 0.254829592;
|
|
14423
|
+
const a2 = -0.284496736;
|
|
14424
|
+
const a3 = 1.421413741;
|
|
14425
|
+
const a4 = -1.453152027;
|
|
14426
|
+
const a5 = 1.061405429;
|
|
14427
|
+
const p = 0.3275911;
|
|
14428
|
+
const sign = x >= 0 ? 1 : -1;
|
|
14429
|
+
x = Math.abs(x);
|
|
14430
|
+
const t = 1 / (1 + p * x);
|
|
14431
|
+
const y = 1 - ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t * Math.exp(-x * x);
|
|
14432
|
+
return sign * y;
|
|
14433
|
+
}
|
|
14434
|
+
/**
|
|
14435
|
+
* Selects the next configuration to evaluate using Bayesian optimization
|
|
14436
|
+
*/
|
|
14437
|
+
async selectConfigurationViaBayesianOptimization(instructions, bootstrappedDemos, labeledExamples) {
|
|
14438
|
+
const candidates = [];
|
|
14439
|
+
const numCandidates = Math.min(20, instructions.length * 3);
|
|
14440
|
+
for (let i = 0; i < numCandidates; i++) {
|
|
14441
|
+
const config = {
|
|
14442
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
14443
|
+
bootstrappedDemos: Math.min(
|
|
14444
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
14445
|
+
this.maxBootstrappedDemos
|
|
14446
|
+
),
|
|
14447
|
+
labeledExamples: Math.min(
|
|
14448
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
14449
|
+
this.maxLabeledDemos
|
|
14450
|
+
)
|
|
14451
|
+
};
|
|
14452
|
+
const acquisitionValue = this.calculateAcquisitionValue(config);
|
|
14453
|
+
candidates.push({ config, acquisitionValue });
|
|
14454
|
+
}
|
|
14455
|
+
candidates.sort((a, b) => b.acquisitionValue - a.acquisitionValue);
|
|
14456
|
+
return candidates[0].config;
|
|
14457
|
+
}
|
|
14458
|
+
};
|
|
14459
|
+
/**
 * Creates a result picker that selects the most frequent sample.
 *
 * For `data.type === "fields"`, samples are grouped by a canonical JSON
 * encoding with object keys sorted at every level, so two samples with the
 * same fields in a different key order count as the same vote. The picker
 * returns the `index` of the first occurrence of the most frequent sample.
 * Ties go to the sample seen first. For any other `data.type` it returns the
 * first result's `index`. An empty result list yields 0.
 *
 * @returns An async function `(data) => number`.
 */
// Declared with `var` to match the original declaration.
var axMajorityVotePicker = () => {
  // Rebuilds plain objects with sorted keys; values that provide their own
  // `toJSON` are left for JSON.stringify to serialize.
  const canonicalize = (value) => {
    if (Array.isArray(value)) {
      return value.map(canonicalize);
    }
    if (value !== null && typeof value === "object" && typeof value.toJSON !== "function") {
      return Object.fromEntries(
        Object.keys(value).sort().map((k) => [k, canonicalize(value[k])])
      );
    }
    return value;
  };
  return async (data) => {
    if (data.type !== "fields") {
      return data.results[0]?.index ?? 0;
    }
    // A Map keeps insertion order for every key, which makes first-seen
    // tie-breaking reliable.
    const counts = new Map();
    for (const { index, sample } of data.results) {
      const key = JSON.stringify(canonicalize(sample));
      const tally = counts.get(key);
      if (tally) {
        tally.count += 1;
      } else {
        counts.set(key, { count: 1, index });
      }
    }
    let winner;
    for (const tally of counts.values()) {
      if (winner === void 0 || tally.count > winner.count) {
        winner = tally;
      }
    }
    return winner?.index ?? 0;
  };
};
|
|
14138
14483
|
|
|
14139
14484
|
// ai/mock/api.ts
|