@ax-llm/ax 12.0.13 → 12.0.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +382 -53
- package/index.cjs.map +1 -1
- package/index.d.cts +47 -4
- package/index.d.ts +47 -4
- package/index.js +382 -53
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.cjs
CHANGED
|
@@ -11874,13 +11874,6 @@ var AxBaseOptimizer = class {
|
|
|
11874
11874
|
if (this.logger) {
|
|
11875
11875
|
return this.logger;
|
|
11876
11876
|
}
|
|
11877
|
-
try {
|
|
11878
|
-
const aiLogger = this.studentAI.getLogger();
|
|
11879
|
-
if (aiLogger) {
|
|
11880
|
-
return aiLogger;
|
|
11881
|
-
}
|
|
11882
|
-
} catch {
|
|
11883
|
-
}
|
|
11884
11877
|
return axDefaultOptimizerLogger;
|
|
11885
11878
|
}
|
|
11886
11879
|
/**
|
|
@@ -13572,6 +13565,11 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13572
13565
|
bayesianOptimization;
|
|
13573
13566
|
acquisitionFunction;
|
|
13574
13567
|
explorationWeight;
|
|
13568
|
+
// Self-consistency / multiple sampling
|
|
13569
|
+
sampleCount;
|
|
13570
|
+
// Surrogate model state for Bayesian optimization
|
|
13571
|
+
miproConfigHistory = [];
|
|
13572
|
+
surrogateModel = /* @__PURE__ */ new Map();
|
|
13575
13573
|
constructor(args) {
|
|
13576
13574
|
super(args);
|
|
13577
13575
|
const options = args.options || {};
|
|
@@ -13593,6 +13591,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13593
13591
|
this.bayesianOptimization = options.bayesianOptimization ?? false;
|
|
13594
13592
|
this.acquisitionFunction = options.acquisitionFunction ?? "expected_improvement";
|
|
13595
13593
|
this.explorationWeight = options.explorationWeight ?? 0.1;
|
|
13594
|
+
this.sampleCount = options.sampleCount ?? 1;
|
|
13596
13595
|
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
13597
13596
|
}
|
|
13598
13597
|
/**
|
|
@@ -13637,43 +13636,186 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13637
13636
|
];
|
|
13638
13637
|
}
|
|
13639
13638
|
/**
|
|
13640
|
-
* Generates
|
|
13639
|
+
* Generates program summary for context-aware instruction generation
|
|
13640
|
+
*/
|
|
13641
|
+
async generateProgramSummary(program, ai) {
|
|
13642
|
+
let signature = "input -> output";
|
|
13643
|
+
if ("getSignature" in program && typeof program.getSignature === "function") {
|
|
13644
|
+
signature = program.getSignature();
|
|
13645
|
+
}
|
|
13646
|
+
const summaryPrompt = `
|
|
13647
|
+
Analyze this language model program and provide a concise summary of its purpose and structure.
|
|
13648
|
+
|
|
13649
|
+
Program Signature: ${signature}
|
|
13650
|
+
|
|
13651
|
+
Provide a 2-3 sentence summary focusing on:
|
|
13652
|
+
1. The main task or purpose of this program
|
|
13653
|
+
2. The input-output relationship
|
|
13654
|
+
3. Any special constraints or requirements
|
|
13655
|
+
|
|
13656
|
+
Summary:`;
|
|
13657
|
+
try {
|
|
13658
|
+
const response = await ai.chat({
|
|
13659
|
+
chatPrompt: [{ role: "user", content: summaryPrompt }]
|
|
13660
|
+
});
|
|
13661
|
+
if ("results" in response) {
|
|
13662
|
+
return response.results[0]?.content?.trim() || "General language model program";
|
|
13663
|
+
}
|
|
13664
|
+
return "General language model program";
|
|
13665
|
+
} catch {
|
|
13666
|
+
return "General language model program";
|
|
13667
|
+
}
|
|
13668
|
+
}
|
|
13669
|
+
/**
|
|
13670
|
+
* Generates dataset summary for context-aware instruction generation
|
|
13671
|
+
*/
|
|
13672
|
+
async generateDatasetSummary(examples, ai) {
|
|
13673
|
+
if (examples.length === 0) return "No examples available";
|
|
13674
|
+
const sampleSize = Math.min(this.viewDataBatchSize, examples.length);
|
|
13675
|
+
const sampledExamples = examples.slice(0, sampleSize);
|
|
13676
|
+
const exampleTexts = sampledExamples.map((ex, i) => `Example ${i + 1}: ${JSON.stringify(ex)}`).join("\n");
|
|
13677
|
+
const summaryPrompt = `
|
|
13678
|
+
Analyze this dataset and provide a concise summary of its characteristics.
|
|
13679
|
+
|
|
13680
|
+
Sample Examples:
|
|
13681
|
+
${exampleTexts}
|
|
13682
|
+
|
|
13683
|
+
Provide a 2-3 sentence summary focusing on:
|
|
13684
|
+
1. The type of data and domain
|
|
13685
|
+
2. Common patterns or structures in the examples
|
|
13686
|
+
3. Key challenges or requirements for processing this data
|
|
13687
|
+
|
|
13688
|
+
Dataset Summary:`;
|
|
13689
|
+
try {
|
|
13690
|
+
const response = await ai.chat({
|
|
13691
|
+
chatPrompt: [{ role: "user", content: summaryPrompt }]
|
|
13692
|
+
});
|
|
13693
|
+
if ("results" in response) {
|
|
13694
|
+
return response.results[0]?.content?.trim() || "General dataset";
|
|
13695
|
+
}
|
|
13696
|
+
return "General dataset";
|
|
13697
|
+
} catch {
|
|
13698
|
+
return "General dataset";
|
|
13699
|
+
}
|
|
13700
|
+
}
|
|
13701
|
+
/**
|
|
13702
|
+
* Enhanced instruction generation using AI with program and data awareness
|
|
13703
|
+
*/
|
|
13704
|
+
async generateInstruction({
|
|
13705
|
+
tip,
|
|
13706
|
+
candidateIndex,
|
|
13707
|
+
ai,
|
|
13708
|
+
programSummary,
|
|
13709
|
+
datasetSummary,
|
|
13710
|
+
previousInstructions = []
|
|
13711
|
+
}) {
|
|
13712
|
+
let contextInfo = "";
|
|
13713
|
+
if (this.programAwareProposer && programSummary) {
|
|
13714
|
+
contextInfo += `
|
|
13715
|
+
Program Context: ${programSummary}`;
|
|
13716
|
+
}
|
|
13717
|
+
if (this.dataAwareProposer && datasetSummary) {
|
|
13718
|
+
contextInfo += `
|
|
13719
|
+
Dataset Context: ${datasetSummary}`;
|
|
13720
|
+
}
|
|
13721
|
+
if (this.fewshotAwareProposer && previousInstructions.length > 0) {
|
|
13722
|
+
contextInfo += `
|
|
13723
|
+
Previous Instructions (avoid repeating): ${previousInstructions.slice(-3).join("; ")}`;
|
|
13724
|
+
}
|
|
13725
|
+
const instructionPrompt = `
|
|
13726
|
+
Generate a high-quality instruction for a language model program.
|
|
13727
|
+
|
|
13728
|
+
${contextInfo}
|
|
13729
|
+
|
|
13730
|
+
${tip ? `Tip: ${tip}` : ""}
|
|
13731
|
+
|
|
13732
|
+
Requirements:
|
|
13733
|
+
1. Be specific and actionable
|
|
13734
|
+
2. Focus on accuracy and clarity
|
|
13735
|
+
3. Consider the program's purpose and data characteristics
|
|
13736
|
+
4. Make the instruction distinct from previous ones
|
|
13737
|
+
5. Keep it concise but comprehensive
|
|
13738
|
+
|
|
13739
|
+
Generate a single, well-crafted instruction:
|
|
13740
|
+
Instruction:`;
|
|
13741
|
+
try {
|
|
13742
|
+
const response = await ai.chat({
|
|
13743
|
+
chatPrompt: [
|
|
13744
|
+
{
|
|
13745
|
+
role: "user",
|
|
13746
|
+
content: instructionPrompt
|
|
13747
|
+
}
|
|
13748
|
+
]
|
|
13749
|
+
});
|
|
13750
|
+
if ("results" in response) {
|
|
13751
|
+
const instruction2 = response.results[0]?.content?.trim();
|
|
13752
|
+
if (instruction2 && instruction2.length > 10) {
|
|
13753
|
+
return instruction2;
|
|
13754
|
+
}
|
|
13755
|
+
}
|
|
13756
|
+
} catch (error) {
|
|
13757
|
+
if (this.isLoggingEnabled()) {
|
|
13758
|
+
this.getLogger()?.(`Failed to generate AI instruction: ${error}`, {
|
|
13759
|
+
tags: ["optimizer", "warning"]
|
|
13760
|
+
});
|
|
13761
|
+
}
|
|
13762
|
+
}
|
|
13763
|
+
const enhancedTemplates = [
|
|
13764
|
+
"Analyze the input systematically and provide a precise, well-reasoned response.",
|
|
13765
|
+
"Think through this step-by-step, considering all relevant factors before responding.",
|
|
13766
|
+
"Examine the input carefully and generate an accurate, detailed answer.",
|
|
13767
|
+
"Process the information methodically and deliver a clear, comprehensive response.",
|
|
13768
|
+
"Consider the context thoroughly and provide a thoughtful, accurate answer."
|
|
13769
|
+
];
|
|
13770
|
+
let instruction = enhancedTemplates[candidateIndex % enhancedTemplates.length] || enhancedTemplates[0];
|
|
13771
|
+
if (tip) {
|
|
13772
|
+
instruction = `${instruction} ${tip}`;
|
|
13773
|
+
}
|
|
13774
|
+
return instruction;
|
|
13775
|
+
}
|
|
13776
|
+
/**
|
|
13777
|
+
* Generates instruction candidates using enhanced AI-powered generation
|
|
13641
13778
|
* @param options Optional compile options that may override teacher AI
|
|
13642
13779
|
* @returns Array of generated instruction candidates
|
|
13643
13780
|
*/
|
|
13644
|
-
async proposeInstructionCandidates(options) {
|
|
13781
|
+
async proposeInstructionCandidates(program, options) {
|
|
13645
13782
|
const instructions = [];
|
|
13646
13783
|
const aiToUse = this.getTeacherOrStudentAI(options);
|
|
13784
|
+
let programSummary;
|
|
13785
|
+
let datasetSummary;
|
|
13786
|
+
if (this.programAwareProposer) {
|
|
13787
|
+
programSummary = await this.generateProgramSummary(program, aiToUse);
|
|
13788
|
+
if (this.isLoggingEnabled(options)) {
|
|
13789
|
+
this.getLogger(options)?.(`Program summary: ${programSummary}`, {
|
|
13790
|
+
tags: ["optimizer", "config"]
|
|
13791
|
+
});
|
|
13792
|
+
}
|
|
13793
|
+
}
|
|
13794
|
+
if (this.dataAwareProposer) {
|
|
13795
|
+
datasetSummary = await this.generateDatasetSummary(this.examples, aiToUse);
|
|
13796
|
+
if (this.isLoggingEnabled(options)) {
|
|
13797
|
+
this.getLogger(options)?.(`Dataset summary: ${datasetSummary}`, {
|
|
13798
|
+
tags: ["optimizer", "config"]
|
|
13799
|
+
});
|
|
13800
|
+
}
|
|
13801
|
+
}
|
|
13647
13802
|
const tips = this.tipAwareProposer ? this.generateTips() : [];
|
|
13648
13803
|
for (let i = 0; i < this.numCandidates; i++) {
|
|
13649
13804
|
const tipIndex = tips.length > 0 ? i % tips.length : -1;
|
|
13650
|
-
const tipToUse = tipIndex >= 0 ? tips[tipIndex] :
|
|
13805
|
+
const tipToUse = tipIndex >= 0 ? tips[tipIndex] : void 0;
|
|
13651
13806
|
const instruction = await this.generateInstruction({
|
|
13652
13807
|
tip: tipToUse,
|
|
13653
13808
|
candidateIndex: i,
|
|
13654
|
-
ai: aiToUse
|
|
13809
|
+
ai: aiToUse,
|
|
13810
|
+
programSummary,
|
|
13811
|
+
datasetSummary,
|
|
13812
|
+
previousInstructions: instructions
|
|
13813
|
+
// Pass previous instructions for diversity
|
|
13655
13814
|
});
|
|
13656
13815
|
instructions.push(instruction);
|
|
13657
13816
|
}
|
|
13658
13817
|
return instructions;
|
|
13659
13818
|
}
|
|
13660
|
-
async generateInstruction({
|
|
13661
|
-
tip,
|
|
13662
|
-
candidateIndex
|
|
13663
|
-
}) {
|
|
13664
|
-
const baseInstructions = [
|
|
13665
|
-
"Analyze the input carefully and provide a detailed response.",
|
|
13666
|
-
"Think step by step and provide a clear answer.",
|
|
13667
|
-
"Consider all aspects of the input before responding.",
|
|
13668
|
-
"Provide a concise but comprehensive response.",
|
|
13669
|
-
"Focus on accuracy and clarity in your response."
|
|
13670
|
-
];
|
|
13671
|
-
let instruction = baseInstructions[candidateIndex % baseInstructions.length] || baseInstructions[0];
|
|
13672
|
-
if (tip) {
|
|
13673
|
-
instruction = `${instruction} ${tip}`;
|
|
13674
|
-
}
|
|
13675
|
-
return instruction;
|
|
13676
|
-
}
|
|
13677
13819
|
/**
|
|
13678
13820
|
* Bootstraps few-shot examples for the program
|
|
13679
13821
|
*/
|
|
@@ -13718,7 +13860,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13718
13860
|
/**
|
|
13719
13861
|
* Runs optimization to find the best combination of few-shot examples and instructions
|
|
13720
13862
|
*/
|
|
13721
|
-
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions,
|
|
13863
|
+
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, validationExamples, metricFn, options) {
|
|
13722
13864
|
let bestConfig = {
|
|
13723
13865
|
instruction: instructions[0] || "",
|
|
13724
13866
|
bootstrappedDemos: Math.min(1, bootstrappedDemos.length),
|
|
@@ -13754,25 +13896,37 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13754
13896
|
);
|
|
13755
13897
|
}
|
|
13756
13898
|
for (let i = startRound; i < this.numTrials; i++) {
|
|
13757
|
-
|
|
13758
|
-
|
|
13759
|
-
|
|
13760
|
-
|
|
13761
|
-
|
|
13762
|
-
|
|
13763
|
-
|
|
13764
|
-
|
|
13765
|
-
|
|
13766
|
-
|
|
13767
|
-
|
|
13899
|
+
let config;
|
|
13900
|
+
if (this.bayesianOptimization && this.miproConfigHistory.length > 2) {
|
|
13901
|
+
config = await this.selectConfigurationViaBayesianOptimization(
|
|
13902
|
+
instructions,
|
|
13903
|
+
bootstrappedDemos,
|
|
13904
|
+
labeledExamples
|
|
13905
|
+
);
|
|
13906
|
+
} else {
|
|
13907
|
+
config = {
|
|
13908
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
13909
|
+
bootstrappedDemos: Math.min(
|
|
13910
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
13911
|
+
this.maxBootstrappedDemos
|
|
13912
|
+
),
|
|
13913
|
+
labeledExamples: Math.min(
|
|
13914
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
13915
|
+
this.maxLabeledDemos
|
|
13916
|
+
)
|
|
13917
|
+
};
|
|
13918
|
+
}
|
|
13768
13919
|
const score = await this.evaluateConfig(
|
|
13769
13920
|
program,
|
|
13770
13921
|
config,
|
|
13771
13922
|
bootstrappedDemos,
|
|
13772
13923
|
labeledExamples,
|
|
13773
|
-
|
|
13774
|
-
metricFn
|
|
13924
|
+
validationExamples,
|
|
13925
|
+
metricFn,
|
|
13926
|
+
i + 1
|
|
13927
|
+
// Pass current trial number for adaptive evaluation
|
|
13775
13928
|
);
|
|
13929
|
+
this.updateSurrogateModel(config, score);
|
|
13776
13930
|
scoreHistory.push(score);
|
|
13777
13931
|
const improvement = score - bestScore;
|
|
13778
13932
|
if (improvement > this.minImprovementThreshold) {
|
|
@@ -13854,7 +14008,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13854
14008
|
this.stats.convergenceInfo.converged = stagnationRounds < this.earlyStoppingTrials;
|
|
13855
14009
|
return { bestConfig, bestScore };
|
|
13856
14010
|
}
|
|
13857
|
-
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples,
|
|
14011
|
+
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, validationExamples, metricFn, currentTrial = 0) {
|
|
13858
14012
|
const testProgram = { ...program };
|
|
13859
14013
|
this.applyConfigToProgram(
|
|
13860
14014
|
testProgram,
|
|
@@ -13864,12 +14018,31 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13864
14018
|
);
|
|
13865
14019
|
let totalScore = 0;
|
|
13866
14020
|
let count = 0;
|
|
13867
|
-
|
|
14021
|
+
let evalSize;
|
|
14022
|
+
if (this.minibatch) {
|
|
14023
|
+
const baseSize = Math.min(this.minibatchSize, validationExamples.length);
|
|
14024
|
+
const isFullEvalTrial = currentTrial % this.minibatchFullEvalSteps === 0;
|
|
14025
|
+
if (isFullEvalTrial || currentTrial > this.numTrials * 0.8) {
|
|
14026
|
+
evalSize = Math.min(validationExamples.length, baseSize * 2);
|
|
14027
|
+
} else {
|
|
14028
|
+
evalSize = Math.max(3, Math.min(baseSize, validationExamples.length));
|
|
14029
|
+
}
|
|
14030
|
+
} else {
|
|
14031
|
+
evalSize = validationExamples.length;
|
|
14032
|
+
}
|
|
14033
|
+
const evalIndices = this.shuffleArray([
|
|
14034
|
+
...Array(validationExamples.length).keys()
|
|
14035
|
+
]).slice(0, evalSize);
|
|
14036
|
+
const evalSet = evalIndices.map((i) => validationExamples[i]);
|
|
13868
14037
|
for (const example of evalSet) {
|
|
13869
14038
|
try {
|
|
13870
14039
|
const prediction = await testProgram.forward(
|
|
13871
14040
|
this.studentAI,
|
|
13872
|
-
example
|
|
14041
|
+
example,
|
|
14042
|
+
this.sampleCount > 1 ? {
|
|
14043
|
+
sampleCount: this.sampleCount,
|
|
14044
|
+
resultPicker: axMajorityVotePicker()
|
|
14045
|
+
} : void 0
|
|
13873
14046
|
);
|
|
13874
14047
|
const score = await metricFn({ prediction, example });
|
|
13875
14048
|
totalScore += score;
|
|
@@ -13881,6 +14054,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13881
14054
|
}
|
|
13882
14055
|
return count > 0 ? totalScore / count : 0;
|
|
13883
14056
|
}
|
|
14057
|
+
/**
|
|
14058
|
+
* Fisher-Yates shuffle for stochastic evaluation
|
|
14059
|
+
*/
|
|
14060
|
+
shuffleArray(array) {
|
|
14061
|
+
const shuffled = [...array];
|
|
14062
|
+
for (let i = shuffled.length - 1; i > 0; i--) {
|
|
14063
|
+
const j = Math.floor(Math.random() * (i + 1));
|
|
14064
|
+
[shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
|
|
14065
|
+
}
|
|
14066
|
+
return shuffled;
|
|
14067
|
+
}
|
|
13884
14068
|
applyConfigToProgram(program, config, bootstrappedDemos, labeledExamples) {
|
|
13885
14069
|
if (program.setInstruction) {
|
|
13886
14070
|
program.setInstruction(config.instruction);
|
|
@@ -13902,14 +14086,14 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13902
14086
|
if (miproOptions?.auto) {
|
|
13903
14087
|
this.configureAuto(miproOptions.auto);
|
|
13904
14088
|
}
|
|
13905
|
-
const
|
|
14089
|
+
const validationExamples = this.getValidationSet(options) || (miproOptions?.validationExamples ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
|
|
13906
14090
|
if (this.isLoggingEnabled(options)) {
|
|
13907
14091
|
this.getLogger(options)?.(
|
|
13908
14092
|
`Starting MIPROv2 optimization with ${this.numTrials} trials`,
|
|
13909
14093
|
{ tags: ["optimizer", "start"] }
|
|
13910
14094
|
);
|
|
13911
14095
|
this.getLogger(options)?.(
|
|
13912
|
-
`Using ${this.examples.length} examples for training and ${
|
|
14096
|
+
`Using ${this.examples.length} examples for training and ${validationExamples.length} for validation`,
|
|
13913
14097
|
{ tags: ["optimizer", "config"] }
|
|
13914
14098
|
);
|
|
13915
14099
|
if (this.teacherAI) {
|
|
@@ -13939,7 +14123,10 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13939
14123
|
);
|
|
13940
14124
|
}
|
|
13941
14125
|
}
|
|
13942
|
-
const instructions = await this.proposeInstructionCandidates(
|
|
14126
|
+
const instructions = await this.proposeInstructionCandidates(
|
|
14127
|
+
program,
|
|
14128
|
+
options
|
|
14129
|
+
);
|
|
13943
14130
|
if (this.isLoggingEnabled(options)) {
|
|
13944
14131
|
this.getLogger(options)?.(
|
|
13945
14132
|
`Generated ${instructions.length} instruction candidates`,
|
|
@@ -13957,7 +14144,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13957
14144
|
bootstrappedDemos,
|
|
13958
14145
|
labeledExamples,
|
|
13959
14146
|
instructions,
|
|
13960
|
-
|
|
14147
|
+
validationExamples,
|
|
13961
14148
|
metricFn,
|
|
13962
14149
|
options
|
|
13963
14150
|
);
|
|
@@ -14016,7 +14203,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
14016
14203
|
bootstrappedDemos: bestConfig.bootstrappedDemos,
|
|
14017
14204
|
labeledExamples: bestConfig.labeledExamples,
|
|
14018
14205
|
numCandidates: this.numCandidates,
|
|
14019
|
-
numTrials: this.numTrials
|
|
14206
|
+
numTrials: this.numTrials,
|
|
14207
|
+
sampleCount: this.sampleCount
|
|
14020
14208
|
}
|
|
14021
14209
|
};
|
|
14022
14210
|
}
|
|
@@ -14061,7 +14249,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
14061
14249
|
minImprovementThreshold: this.minImprovementThreshold,
|
|
14062
14250
|
bayesianOptimization: this.bayesianOptimization,
|
|
14063
14251
|
acquisitionFunction: this.acquisitionFunction,
|
|
14064
|
-
explorationWeight: this.explorationWeight
|
|
14252
|
+
explorationWeight: this.explorationWeight,
|
|
14253
|
+
sampleCount: this.sampleCount
|
|
14065
14254
|
};
|
|
14066
14255
|
}
|
|
14067
14256
|
/**
|
|
@@ -14096,12 +14285,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
14096
14285
|
if (config.minImprovementThreshold !== void 0) {
|
|
14097
14286
|
this.minImprovementThreshold = config.minImprovementThreshold;
|
|
14098
14287
|
}
|
|
14288
|
+
if (config.sampleCount !== void 0) {
|
|
14289
|
+
this.sampleCount = config.sampleCount;
|
|
14290
|
+
}
|
|
14099
14291
|
}
|
|
14100
14292
|
/**
|
|
14101
14293
|
* Reset optimizer state for reuse with different programs
|
|
14102
14294
|
*/
|
|
14103
14295
|
reset() {
|
|
14104
14296
|
super.reset();
|
|
14297
|
+
this.miproConfigHistory = [];
|
|
14298
|
+
this.surrogateModel.clear();
|
|
14105
14299
|
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
14106
14300
|
}
|
|
14107
14301
|
/**
|
|
@@ -14119,8 +14313,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
14119
14313
|
"Reduce maxBootstrappedDemos or maxLabeledDemos, or provide more examples"
|
|
14120
14314
|
);
|
|
14121
14315
|
}
|
|
14122
|
-
const
|
|
14123
|
-
if (
|
|
14316
|
+
const validationSetSize = this.getValidationSet().length;
|
|
14317
|
+
if (validationSetSize < 5) {
|
|
14124
14318
|
result.issues.push(
|
|
14125
14319
|
"Validation set too small for reliable MiPRO optimization"
|
|
14126
14320
|
);
|
|
@@ -14134,6 +14328,141 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
14134
14328
|
suggestions: result.suggestions
|
|
14135
14329
|
};
|
|
14136
14330
|
}
|
|
14331
|
+
/**
|
|
14332
|
+
* Encodes a configuration into a string key for surrogate model lookup
|
|
14333
|
+
*/
|
|
14334
|
+
encodeConfiguration(config) {
|
|
14335
|
+
return `${config.instruction.length}_${config.bootstrappedDemos}_${config.labeledExamples}`;
|
|
14336
|
+
}
|
|
14337
|
+
/**
|
|
14338
|
+
* Updates the surrogate model with a new configuration-score pair
|
|
14339
|
+
*/
|
|
14340
|
+
updateSurrogateModel(config, score) {
|
|
14341
|
+
this.miproConfigHistory.push({ config: { ...config }, score });
|
|
14342
|
+
const key = this.encodeConfiguration(config);
|
|
14343
|
+
const similarConfigs = this.miproConfigHistory.filter(
|
|
14344
|
+
(entry) => this.encodeConfiguration(entry.config) === key
|
|
14345
|
+
);
|
|
14346
|
+
if (similarConfigs.length > 0) {
|
|
14347
|
+
const scores = similarConfigs.map((entry) => entry.score);
|
|
14348
|
+
const mean = scores.reduce((sum, s2) => sum + s2, 0) / scores.length;
|
|
14349
|
+
const variance = scores.length > 1 ? scores.reduce((sum, s2) => sum + Math.pow(s2 - mean, 2), 0) / (scores.length - 1) : 0.1;
|
|
14350
|
+
this.surrogateModel.set(key, { mean, variance });
|
|
14351
|
+
}
|
|
14352
|
+
}
|
|
14353
|
+
/**
|
|
14354
|
+
* Predicts performance using the surrogate model
|
|
14355
|
+
*/
|
|
14356
|
+
predictPerformance(config) {
|
|
14357
|
+
const key = this.encodeConfiguration(config);
|
|
14358
|
+
if (this.surrogateModel.has(key)) {
|
|
14359
|
+
return this.surrogateModel.get(key);
|
|
14360
|
+
}
|
|
14361
|
+
if (this.miproConfigHistory.length > 0) {
|
|
14362
|
+
const similarities = this.miproConfigHistory.map((entry) => {
|
|
14363
|
+
const diff = Math.abs(entry.config.bootstrappedDemos - config.bootstrappedDemos) + Math.abs(entry.config.labeledExamples - config.labeledExamples);
|
|
14364
|
+
return { score: entry.score, similarity: 1 / (1 + diff) };
|
|
14365
|
+
});
|
|
14366
|
+
const totalWeight = similarities.reduce((sum, s2) => sum + s2.similarity, 0);
|
|
14367
|
+
const weightedMean = similarities.reduce((sum, s2) => sum + s2.score * s2.similarity, 0) / totalWeight;
|
|
14368
|
+
return { mean: weightedMean, variance: 0.2 };
|
|
14369
|
+
}
|
|
14370
|
+
return { mean: 0.5, variance: 0.3 };
|
|
14371
|
+
}
|
|
14372
|
+
/**
|
|
14373
|
+
* Calculates acquisition function value for Bayesian optimization
|
|
14374
|
+
*/
|
|
14375
|
+
calculateAcquisitionValue(config) {
|
|
14376
|
+
const prediction = this.predictPerformance(config);
|
|
14377
|
+
const { mean, variance } = prediction;
|
|
14378
|
+
const std = Math.sqrt(variance);
|
|
14379
|
+
const bestScore = this.miproConfigHistory.length > 0 ? Math.max(...this.miproConfigHistory.map((entry) => entry.score)) : 0;
|
|
14380
|
+
switch (this.acquisitionFunction) {
|
|
14381
|
+
case "expected_improvement": {
|
|
14382
|
+
const improvement = mean - bestScore;
|
|
14383
|
+
if (std === 0) return Math.max(0, improvement);
|
|
14384
|
+
const z = improvement / std;
|
|
14385
|
+
const phi = 0.5 * (1 + this.erf(z / Math.sqrt(2)));
|
|
14386
|
+
const pdfValue = Math.exp(-0.5 * z * z) / Math.sqrt(2 * Math.PI);
|
|
14387
|
+
return improvement * phi + std * pdfValue;
|
|
14388
|
+
}
|
|
14389
|
+
case "upper_confidence_bound": {
|
|
14390
|
+
return mean + this.explorationWeight * std;
|
|
14391
|
+
}
|
|
14392
|
+
case "probability_improvement": {
|
|
14393
|
+
const improvement = mean - bestScore;
|
|
14394
|
+
if (std === 0) return improvement > 0 ? 1 : 0;
|
|
14395
|
+
const z = improvement / std;
|
|
14396
|
+
return 0.5 * (1 + this.erf(z / Math.sqrt(2)));
|
|
14397
|
+
}
|
|
14398
|
+
default:
|
|
14399
|
+
return mean;
|
|
14400
|
+
}
|
|
14401
|
+
}
|
|
14402
|
+
/**
|
|
14403
|
+
* Error function approximation for acquisition function calculations
|
|
14404
|
+
*/
|
|
14405
|
+
erf(x) {
|
|
14406
|
+
const a1 = 0.254829592;
|
|
14407
|
+
const a2 = -0.284496736;
|
|
14408
|
+
const a3 = 1.421413741;
|
|
14409
|
+
const a4 = -1.453152027;
|
|
14410
|
+
const a5 = 1.061405429;
|
|
14411
|
+
const p = 0.3275911;
|
|
14412
|
+
const sign = x >= 0 ? 1 : -1;
|
|
14413
|
+
x = Math.abs(x);
|
|
14414
|
+
const t = 1 / (1 + p * x);
|
|
14415
|
+
const y = 1 - ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t * Math.exp(-x * x);
|
|
14416
|
+
return sign * y;
|
|
14417
|
+
}
|
|
14418
|
+
/**
|
|
14419
|
+
* Selects the next configuration to evaluate using Bayesian optimization
|
|
14420
|
+
*/
|
|
14421
|
+
async selectConfigurationViaBayesianOptimization(instructions, bootstrappedDemos, labeledExamples) {
|
|
14422
|
+
const candidates = [];
|
|
14423
|
+
const numCandidates = Math.min(20, instructions.length * 3);
|
|
14424
|
+
for (let i = 0; i < numCandidates; i++) {
|
|
14425
|
+
const config = {
|
|
14426
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
14427
|
+
bootstrappedDemos: Math.min(
|
|
14428
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
14429
|
+
this.maxBootstrappedDemos
|
|
14430
|
+
),
|
|
14431
|
+
labeledExamples: Math.min(
|
|
14432
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
14433
|
+
this.maxLabeledDemos
|
|
14434
|
+
)
|
|
14435
|
+
};
|
|
14436
|
+
const acquisitionValue = this.calculateAcquisitionValue(config);
|
|
14437
|
+
candidates.push({ config, acquisitionValue });
|
|
14438
|
+
}
|
|
14439
|
+
candidates.sort((a, b) => b.acquisitionValue - a.acquisitionValue);
|
|
14440
|
+
return candidates[0].config;
|
|
14441
|
+
}
|
|
14442
|
+
};
|
|
14443
|
+
/**
 * Creates a result picker that implements majority voting over samples.
 *
 * For `data.type === "fields"` the picker groups samples by their
 * `JSON.stringify` form and returns the `index` of the first sample of the
 * most frequent group; ties go to the group seen first. For any other type
 * it returns the index of the first result. With no results it returns 0.
 *
 * Tallies are kept in a `Map` so groups are always visited in first-seen
 * order; a plain object would enumerate integer-like keys numerically and
 * break ties out of sample order.
 *
 * @returns Async function `(data) => index` where `data` has a `type` and a
 *   `results` array of `{ index, sample }` entries.
 */
var axMajorityVotePicker = () => {
  return async (data) => {
    if (data.type !== "fields") {
      return data.results[0]?.index ?? 0;
    }
    const tallies = new Map();
    for (const { index, sample } of data.results) {
      const key = JSON.stringify(sample);
      const tally = tallies.get(key);
      if (tally) {
        tally.count += 1;
      } else {
        tallies.set(key, { count: 1, index });
      }
    }
    let winner;
    for (const tally of tallies.values()) {
      // Strict comparison keeps the earliest group on ties.
      if (winner === void 0 || tally.count > winner.count) {
        winner = tally;
      }
    }
    return winner?.index ?? 0;
  };
};
|
|
14138
14467
|
|
|
14139
14468
|
// ai/mock/api.ts
|