@ax-llm/ax 12.0.13 → 12.0.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +382 -53
- package/index.cjs.map +1 -1
- package/index.d.cts +47 -4
- package/index.d.ts +47 -4
- package/index.js +382 -53
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.js
CHANGED
|
@@ -11700,13 +11700,6 @@ var AxBaseOptimizer = class {
|
|
|
11700
11700
|
if (this.logger) {
|
|
11701
11701
|
return this.logger;
|
|
11702
11702
|
}
|
|
11703
|
-
try {
|
|
11704
|
-
const aiLogger = this.studentAI.getLogger();
|
|
11705
|
-
if (aiLogger) {
|
|
11706
|
-
return aiLogger;
|
|
11707
|
-
}
|
|
11708
|
-
} catch {
|
|
11709
|
-
}
|
|
11710
11703
|
return axDefaultOptimizerLogger;
|
|
11711
11704
|
}
|
|
11712
11705
|
/**
|
|
@@ -13398,6 +13391,11 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13398
13391
|
bayesianOptimization;
|
|
13399
13392
|
acquisitionFunction;
|
|
13400
13393
|
explorationWeight;
|
|
13394
|
+
// Self-consistency / multiple sampling
|
|
13395
|
+
sampleCount;
|
|
13396
|
+
// Surrogate model state for Bayesian optimization
|
|
13397
|
+
miproConfigHistory = [];
|
|
13398
|
+
surrogateModel = /* @__PURE__ */ new Map();
|
|
13401
13399
|
constructor(args) {
|
|
13402
13400
|
super(args);
|
|
13403
13401
|
const options = args.options || {};
|
|
@@ -13419,6 +13417,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13419
13417
|
this.bayesianOptimization = options.bayesianOptimization ?? false;
|
|
13420
13418
|
this.acquisitionFunction = options.acquisitionFunction ?? "expected_improvement";
|
|
13421
13419
|
this.explorationWeight = options.explorationWeight ?? 0.1;
|
|
13420
|
+
this.sampleCount = options.sampleCount ?? 1;
|
|
13422
13421
|
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
13423
13422
|
}
|
|
13424
13423
|
/**
|
|
@@ -13463,43 +13462,186 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13463
13462
|
];
|
|
13464
13463
|
}
|
|
13465
13464
|
/**
|
|
13466
|
-
* Generates
|
|
13465
|
+
* Generates program summary for context-aware instruction generation
|
|
13466
|
+
*/
|
|
13467
|
+
async generateProgramSummary(program, ai) {
|
|
13468
|
+
let signature = "input -> output";
|
|
13469
|
+
if ("getSignature" in program && typeof program.getSignature === "function") {
|
|
13470
|
+
signature = program.getSignature();
|
|
13471
|
+
}
|
|
13472
|
+
const summaryPrompt = `
|
|
13473
|
+
Analyze this language model program and provide a concise summary of its purpose and structure.
|
|
13474
|
+
|
|
13475
|
+
Program Signature: ${signature}
|
|
13476
|
+
|
|
13477
|
+
Provide a 2-3 sentence summary focusing on:
|
|
13478
|
+
1. The main task or purpose of this program
|
|
13479
|
+
2. The input-output relationship
|
|
13480
|
+
3. Any special constraints or requirements
|
|
13481
|
+
|
|
13482
|
+
Summary:`;
|
|
13483
|
+
try {
|
|
13484
|
+
const response = await ai.chat({
|
|
13485
|
+
chatPrompt: [{ role: "user", content: summaryPrompt }]
|
|
13486
|
+
});
|
|
13487
|
+
if ("results" in response) {
|
|
13488
|
+
return response.results[0]?.content?.trim() || "General language model program";
|
|
13489
|
+
}
|
|
13490
|
+
return "General language model program";
|
|
13491
|
+
} catch {
|
|
13492
|
+
return "General language model program";
|
|
13493
|
+
}
|
|
13494
|
+
}
|
|
13495
|
+
/**
|
|
13496
|
+
* Generates dataset summary for context-aware instruction generation
|
|
13497
|
+
*/
|
|
13498
|
+
async generateDatasetSummary(examples, ai) {
|
|
13499
|
+
if (examples.length === 0) return "No examples available";
|
|
13500
|
+
const sampleSize = Math.min(this.viewDataBatchSize, examples.length);
|
|
13501
|
+
const sampledExamples = examples.slice(0, sampleSize);
|
|
13502
|
+
const exampleTexts = sampledExamples.map((ex, i) => `Example ${i + 1}: ${JSON.stringify(ex)}`).join("\n");
|
|
13503
|
+
const summaryPrompt = `
|
|
13504
|
+
Analyze this dataset and provide a concise summary of its characteristics.
|
|
13505
|
+
|
|
13506
|
+
Sample Examples:
|
|
13507
|
+
${exampleTexts}
|
|
13508
|
+
|
|
13509
|
+
Provide a 2-3 sentence summary focusing on:
|
|
13510
|
+
1. The type of data and domain
|
|
13511
|
+
2. Common patterns or structures in the examples
|
|
13512
|
+
3. Key challenges or requirements for processing this data
|
|
13513
|
+
|
|
13514
|
+
Dataset Summary:`;
|
|
13515
|
+
try {
|
|
13516
|
+
const response = await ai.chat({
|
|
13517
|
+
chatPrompt: [{ role: "user", content: summaryPrompt }]
|
|
13518
|
+
});
|
|
13519
|
+
if ("results" in response) {
|
|
13520
|
+
return response.results[0]?.content?.trim() || "General dataset";
|
|
13521
|
+
}
|
|
13522
|
+
return "General dataset";
|
|
13523
|
+
} catch {
|
|
13524
|
+
return "General dataset";
|
|
13525
|
+
}
|
|
13526
|
+
}
|
|
13527
|
+
/**
|
|
13528
|
+
* Enhanced instruction generation using AI with program and data awareness
|
|
13529
|
+
*/
|
|
13530
|
+
async generateInstruction({
|
|
13531
|
+
tip,
|
|
13532
|
+
candidateIndex,
|
|
13533
|
+
ai,
|
|
13534
|
+
programSummary,
|
|
13535
|
+
datasetSummary,
|
|
13536
|
+
previousInstructions = []
|
|
13537
|
+
}) {
|
|
13538
|
+
let contextInfo = "";
|
|
13539
|
+
if (this.programAwareProposer && programSummary) {
|
|
13540
|
+
contextInfo += `
|
|
13541
|
+
Program Context: ${programSummary}`;
|
|
13542
|
+
}
|
|
13543
|
+
if (this.dataAwareProposer && datasetSummary) {
|
|
13544
|
+
contextInfo += `
|
|
13545
|
+
Dataset Context: ${datasetSummary}`;
|
|
13546
|
+
}
|
|
13547
|
+
if (this.fewshotAwareProposer && previousInstructions.length > 0) {
|
|
13548
|
+
contextInfo += `
|
|
13549
|
+
Previous Instructions (avoid repeating): ${previousInstructions.slice(-3).join("; ")}`;
|
|
13550
|
+
}
|
|
13551
|
+
const instructionPrompt = `
|
|
13552
|
+
Generate a high-quality instruction for a language model program.
|
|
13553
|
+
|
|
13554
|
+
${contextInfo}
|
|
13555
|
+
|
|
13556
|
+
${tip ? `Tip: ${tip}` : ""}
|
|
13557
|
+
|
|
13558
|
+
Requirements:
|
|
13559
|
+
1. Be specific and actionable
|
|
13560
|
+
2. Focus on accuracy and clarity
|
|
13561
|
+
3. Consider the program's purpose and data characteristics
|
|
13562
|
+
4. Make the instruction distinct from previous ones
|
|
13563
|
+
5. Keep it concise but comprehensive
|
|
13564
|
+
|
|
13565
|
+
Generate a single, well-crafted instruction:
|
|
13566
|
+
Instruction:`;
|
|
13567
|
+
try {
|
|
13568
|
+
const response = await ai.chat({
|
|
13569
|
+
chatPrompt: [
|
|
13570
|
+
{
|
|
13571
|
+
role: "user",
|
|
13572
|
+
content: instructionPrompt
|
|
13573
|
+
}
|
|
13574
|
+
]
|
|
13575
|
+
});
|
|
13576
|
+
if ("results" in response) {
|
|
13577
|
+
const instruction2 = response.results[0]?.content?.trim();
|
|
13578
|
+
if (instruction2 && instruction2.length > 10) {
|
|
13579
|
+
return instruction2;
|
|
13580
|
+
}
|
|
13581
|
+
}
|
|
13582
|
+
} catch (error) {
|
|
13583
|
+
if (this.isLoggingEnabled()) {
|
|
13584
|
+
this.getLogger()?.(`Failed to generate AI instruction: ${error}`, {
|
|
13585
|
+
tags: ["optimizer", "warning"]
|
|
13586
|
+
});
|
|
13587
|
+
}
|
|
13588
|
+
}
|
|
13589
|
+
const enhancedTemplates = [
|
|
13590
|
+
"Analyze the input systematically and provide a precise, well-reasoned response.",
|
|
13591
|
+
"Think through this step-by-step, considering all relevant factors before responding.",
|
|
13592
|
+
"Examine the input carefully and generate an accurate, detailed answer.",
|
|
13593
|
+
"Process the information methodically and deliver a clear, comprehensive response.",
|
|
13594
|
+
"Consider the context thoroughly and provide a thoughtful, accurate answer."
|
|
13595
|
+
];
|
|
13596
|
+
let instruction = enhancedTemplates[candidateIndex % enhancedTemplates.length] || enhancedTemplates[0];
|
|
13597
|
+
if (tip) {
|
|
13598
|
+
instruction = `${instruction} ${tip}`;
|
|
13599
|
+
}
|
|
13600
|
+
return instruction;
|
|
13601
|
+
}
|
|
13602
|
+
/**
|
|
13603
|
+
* Generates instruction candidates using enhanced AI-powered generation
|
|
13467
13604
|
* @param options Optional compile options that may override teacher AI
|
|
13468
13605
|
* @returns Array of generated instruction candidates
|
|
13469
13606
|
*/
|
|
13470
|
-
async proposeInstructionCandidates(options) {
|
|
13607
|
+
async proposeInstructionCandidates(program, options) {
|
|
13471
13608
|
const instructions = [];
|
|
13472
13609
|
const aiToUse = this.getTeacherOrStudentAI(options);
|
|
13610
|
+
let programSummary;
|
|
13611
|
+
let datasetSummary;
|
|
13612
|
+
if (this.programAwareProposer) {
|
|
13613
|
+
programSummary = await this.generateProgramSummary(program, aiToUse);
|
|
13614
|
+
if (this.isLoggingEnabled(options)) {
|
|
13615
|
+
this.getLogger(options)?.(`Program summary: ${programSummary}`, {
|
|
13616
|
+
tags: ["optimizer", "config"]
|
|
13617
|
+
});
|
|
13618
|
+
}
|
|
13619
|
+
}
|
|
13620
|
+
if (this.dataAwareProposer) {
|
|
13621
|
+
datasetSummary = await this.generateDatasetSummary(this.examples, aiToUse);
|
|
13622
|
+
if (this.isLoggingEnabled(options)) {
|
|
13623
|
+
this.getLogger(options)?.(`Dataset summary: ${datasetSummary}`, {
|
|
13624
|
+
tags: ["optimizer", "config"]
|
|
13625
|
+
});
|
|
13626
|
+
}
|
|
13627
|
+
}
|
|
13473
13628
|
const tips = this.tipAwareProposer ? this.generateTips() : [];
|
|
13474
13629
|
for (let i = 0; i < this.numCandidates; i++) {
|
|
13475
13630
|
const tipIndex = tips.length > 0 ? i % tips.length : -1;
|
|
13476
|
-
const tipToUse = tipIndex >= 0 ? tips[tipIndex] :
|
|
13631
|
+
const tipToUse = tipIndex >= 0 ? tips[tipIndex] : void 0;
|
|
13477
13632
|
const instruction = await this.generateInstruction({
|
|
13478
13633
|
tip: tipToUse,
|
|
13479
13634
|
candidateIndex: i,
|
|
13480
|
-
ai: aiToUse
|
|
13635
|
+
ai: aiToUse,
|
|
13636
|
+
programSummary,
|
|
13637
|
+
datasetSummary,
|
|
13638
|
+
previousInstructions: instructions
|
|
13639
|
+
// Pass previous instructions for diversity
|
|
13481
13640
|
});
|
|
13482
13641
|
instructions.push(instruction);
|
|
13483
13642
|
}
|
|
13484
13643
|
return instructions;
|
|
13485
13644
|
}
|
|
13486
|
-
async generateInstruction({
|
|
13487
|
-
tip,
|
|
13488
|
-
candidateIndex
|
|
13489
|
-
}) {
|
|
13490
|
-
const baseInstructions = [
|
|
13491
|
-
"Analyze the input carefully and provide a detailed response.",
|
|
13492
|
-
"Think step by step and provide a clear answer.",
|
|
13493
|
-
"Consider all aspects of the input before responding.",
|
|
13494
|
-
"Provide a concise but comprehensive response.",
|
|
13495
|
-
"Focus on accuracy and clarity in your response."
|
|
13496
|
-
];
|
|
13497
|
-
let instruction = baseInstructions[candidateIndex % baseInstructions.length] || baseInstructions[0];
|
|
13498
|
-
if (tip) {
|
|
13499
|
-
instruction = `${instruction} ${tip}`;
|
|
13500
|
-
}
|
|
13501
|
-
return instruction;
|
|
13502
|
-
}
|
|
13503
13645
|
/**
|
|
13504
13646
|
* Bootstraps few-shot examples for the program
|
|
13505
13647
|
*/
|
|
@@ -13544,7 +13686,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13544
13686
|
/**
|
|
13545
13687
|
* Runs optimization to find the best combination of few-shot examples and instructions
|
|
13546
13688
|
*/
|
|
13547
|
-
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions,
|
|
13689
|
+
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, validationExamples, metricFn, options) {
|
|
13548
13690
|
let bestConfig = {
|
|
13549
13691
|
instruction: instructions[0] || "",
|
|
13550
13692
|
bootstrappedDemos: Math.min(1, bootstrappedDemos.length),
|
|
@@ -13580,25 +13722,37 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13580
13722
|
);
|
|
13581
13723
|
}
|
|
13582
13724
|
for (let i = startRound; i < this.numTrials; i++) {
|
|
13583
|
-
|
|
13584
|
-
|
|
13585
|
-
|
|
13586
|
-
|
|
13587
|
-
|
|
13588
|
-
|
|
13589
|
-
|
|
13590
|
-
|
|
13591
|
-
|
|
13592
|
-
|
|
13593
|
-
|
|
13725
|
+
let config;
|
|
13726
|
+
if (this.bayesianOptimization && this.miproConfigHistory.length > 2) {
|
|
13727
|
+
config = await this.selectConfigurationViaBayesianOptimization(
|
|
13728
|
+
instructions,
|
|
13729
|
+
bootstrappedDemos,
|
|
13730
|
+
labeledExamples
|
|
13731
|
+
);
|
|
13732
|
+
} else {
|
|
13733
|
+
config = {
|
|
13734
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
13735
|
+
bootstrappedDemos: Math.min(
|
|
13736
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
13737
|
+
this.maxBootstrappedDemos
|
|
13738
|
+
),
|
|
13739
|
+
labeledExamples: Math.min(
|
|
13740
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
13741
|
+
this.maxLabeledDemos
|
|
13742
|
+
)
|
|
13743
|
+
};
|
|
13744
|
+
}
|
|
13594
13745
|
const score = await this.evaluateConfig(
|
|
13595
13746
|
program,
|
|
13596
13747
|
config,
|
|
13597
13748
|
bootstrappedDemos,
|
|
13598
13749
|
labeledExamples,
|
|
13599
|
-
|
|
13600
|
-
metricFn
|
|
13750
|
+
validationExamples,
|
|
13751
|
+
metricFn,
|
|
13752
|
+
i + 1
|
|
13753
|
+
// Pass current trial number for adaptive evaluation
|
|
13601
13754
|
);
|
|
13755
|
+
this.updateSurrogateModel(config, score);
|
|
13602
13756
|
scoreHistory.push(score);
|
|
13603
13757
|
const improvement = score - bestScore;
|
|
13604
13758
|
if (improvement > this.minImprovementThreshold) {
|
|
@@ -13680,7 +13834,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13680
13834
|
this.stats.convergenceInfo.converged = stagnationRounds < this.earlyStoppingTrials;
|
|
13681
13835
|
return { bestConfig, bestScore };
|
|
13682
13836
|
}
|
|
13683
|
-
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples,
|
|
13837
|
+
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, validationExamples, metricFn, currentTrial = 0) {
|
|
13684
13838
|
const testProgram = { ...program };
|
|
13685
13839
|
this.applyConfigToProgram(
|
|
13686
13840
|
testProgram,
|
|
@@ -13690,12 +13844,31 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13690
13844
|
);
|
|
13691
13845
|
let totalScore = 0;
|
|
13692
13846
|
let count = 0;
|
|
13693
|
-
|
|
13847
|
+
let evalSize;
|
|
13848
|
+
if (this.minibatch) {
|
|
13849
|
+
const baseSize = Math.min(this.minibatchSize, validationExamples.length);
|
|
13850
|
+
const isFullEvalTrial = currentTrial % this.minibatchFullEvalSteps === 0;
|
|
13851
|
+
if (isFullEvalTrial || currentTrial > this.numTrials * 0.8) {
|
|
13852
|
+
evalSize = Math.min(validationExamples.length, baseSize * 2);
|
|
13853
|
+
} else {
|
|
13854
|
+
evalSize = Math.max(3, Math.min(baseSize, validationExamples.length));
|
|
13855
|
+
}
|
|
13856
|
+
} else {
|
|
13857
|
+
evalSize = validationExamples.length;
|
|
13858
|
+
}
|
|
13859
|
+
const evalIndices = this.shuffleArray([
|
|
13860
|
+
...Array(validationExamples.length).keys()
|
|
13861
|
+
]).slice(0, evalSize);
|
|
13862
|
+
const evalSet = evalIndices.map((i) => validationExamples[i]);
|
|
13694
13863
|
for (const example of evalSet) {
|
|
13695
13864
|
try {
|
|
13696
13865
|
const prediction = await testProgram.forward(
|
|
13697
13866
|
this.studentAI,
|
|
13698
|
-
example
|
|
13867
|
+
example,
|
|
13868
|
+
this.sampleCount > 1 ? {
|
|
13869
|
+
sampleCount: this.sampleCount,
|
|
13870
|
+
resultPicker: axMajorityVotePicker()
|
|
13871
|
+
} : void 0
|
|
13699
13872
|
);
|
|
13700
13873
|
const score = await metricFn({ prediction, example });
|
|
13701
13874
|
totalScore += score;
|
|
@@ -13707,6 +13880,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13707
13880
|
}
|
|
13708
13881
|
return count > 0 ? totalScore / count : 0;
|
|
13709
13882
|
}
|
|
13883
|
+
/**
|
|
13884
|
+
* Fisher-Yates shuffle for stochastic evaluation
|
|
13885
|
+
*/
|
|
13886
|
+
shuffleArray(array) {
|
|
13887
|
+
const shuffled = [...array];
|
|
13888
|
+
for (let i = shuffled.length - 1; i > 0; i--) {
|
|
13889
|
+
const j = Math.floor(Math.random() * (i + 1));
|
|
13890
|
+
[shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
|
|
13891
|
+
}
|
|
13892
|
+
return shuffled;
|
|
13893
|
+
}
|
|
13710
13894
|
applyConfigToProgram(program, config, bootstrappedDemos, labeledExamples) {
|
|
13711
13895
|
if (program.setInstruction) {
|
|
13712
13896
|
program.setInstruction(config.instruction);
|
|
@@ -13728,14 +13912,14 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13728
13912
|
if (miproOptions?.auto) {
|
|
13729
13913
|
this.configureAuto(miproOptions.auto);
|
|
13730
13914
|
}
|
|
13731
|
-
const
|
|
13915
|
+
const validationExamples = this.getValidationSet(options) || (miproOptions?.validationExamples ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
|
|
13732
13916
|
if (this.isLoggingEnabled(options)) {
|
|
13733
13917
|
this.getLogger(options)?.(
|
|
13734
13918
|
`Starting MIPROv2 optimization with ${this.numTrials} trials`,
|
|
13735
13919
|
{ tags: ["optimizer", "start"] }
|
|
13736
13920
|
);
|
|
13737
13921
|
this.getLogger(options)?.(
|
|
13738
|
-
`Using ${this.examples.length} examples for training and ${
|
|
13922
|
+
`Using ${this.examples.length} examples for training and ${validationExamples.length} for validation`,
|
|
13739
13923
|
{ tags: ["optimizer", "config"] }
|
|
13740
13924
|
);
|
|
13741
13925
|
if (this.teacherAI) {
|
|
@@ -13765,7 +13949,10 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13765
13949
|
);
|
|
13766
13950
|
}
|
|
13767
13951
|
}
|
|
13768
|
-
const instructions = await this.proposeInstructionCandidates(
|
|
13952
|
+
const instructions = await this.proposeInstructionCandidates(
|
|
13953
|
+
program,
|
|
13954
|
+
options
|
|
13955
|
+
);
|
|
13769
13956
|
if (this.isLoggingEnabled(options)) {
|
|
13770
13957
|
this.getLogger(options)?.(
|
|
13771
13958
|
`Generated ${instructions.length} instruction candidates`,
|
|
@@ -13783,7 +13970,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13783
13970
|
bootstrappedDemos,
|
|
13784
13971
|
labeledExamples,
|
|
13785
13972
|
instructions,
|
|
13786
|
-
|
|
13973
|
+
validationExamples,
|
|
13787
13974
|
metricFn,
|
|
13788
13975
|
options
|
|
13789
13976
|
);
|
|
@@ -13842,7 +14029,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13842
14029
|
bootstrappedDemos: bestConfig.bootstrappedDemos,
|
|
13843
14030
|
labeledExamples: bestConfig.labeledExamples,
|
|
13844
14031
|
numCandidates: this.numCandidates,
|
|
13845
|
-
numTrials: this.numTrials
|
|
14032
|
+
numTrials: this.numTrials,
|
|
14033
|
+
sampleCount: this.sampleCount
|
|
13846
14034
|
}
|
|
13847
14035
|
};
|
|
13848
14036
|
}
|
|
@@ -13887,7 +14075,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13887
14075
|
minImprovementThreshold: this.minImprovementThreshold,
|
|
13888
14076
|
bayesianOptimization: this.bayesianOptimization,
|
|
13889
14077
|
acquisitionFunction: this.acquisitionFunction,
|
|
13890
|
-
explorationWeight: this.explorationWeight
|
|
14078
|
+
explorationWeight: this.explorationWeight,
|
|
14079
|
+
sampleCount: this.sampleCount
|
|
13891
14080
|
};
|
|
13892
14081
|
}
|
|
13893
14082
|
/**
|
|
@@ -13922,12 +14111,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13922
14111
|
if (config.minImprovementThreshold !== void 0) {
|
|
13923
14112
|
this.minImprovementThreshold = config.minImprovementThreshold;
|
|
13924
14113
|
}
|
|
14114
|
+
if (config.sampleCount !== void 0) {
|
|
14115
|
+
this.sampleCount = config.sampleCount;
|
|
14116
|
+
}
|
|
13925
14117
|
}
|
|
13926
14118
|
/**
|
|
13927
14119
|
* Reset optimizer state for reuse with different programs
|
|
13928
14120
|
*/
|
|
13929
14121
|
reset() {
|
|
13930
14122
|
super.reset();
|
|
14123
|
+
this.miproConfigHistory = [];
|
|
14124
|
+
this.surrogateModel.clear();
|
|
13931
14125
|
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
13932
14126
|
}
|
|
13933
14127
|
/**
|
|
@@ -13945,8 +14139,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13945
14139
|
"Reduce maxBootstrappedDemos or maxLabeledDemos, or provide more examples"
|
|
13946
14140
|
);
|
|
13947
14141
|
}
|
|
13948
|
-
const
|
|
13949
|
-
if (
|
|
14142
|
+
const validationSetSize = this.getValidationSet().length;
|
|
14143
|
+
if (validationSetSize < 5) {
|
|
13950
14144
|
result.issues.push(
|
|
13951
14145
|
"Validation set too small for reliable MiPRO optimization"
|
|
13952
14146
|
);
|
|
@@ -13960,6 +14154,141 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13960
14154
|
suggestions: result.suggestions
|
|
13961
14155
|
};
|
|
13962
14156
|
}
|
|
14157
|
+
/**
|
|
14158
|
+
* Encodes a configuration into a string key for surrogate model lookup
|
|
14159
|
+
*/
|
|
14160
|
+
encodeConfiguration(config) {
|
|
14161
|
+
return `${config.instruction.length}_${config.bootstrappedDemos}_${config.labeledExamples}`;
|
|
14162
|
+
}
|
|
14163
|
+
/**
|
|
14164
|
+
* Updates the surrogate model with a new configuration-score pair
|
|
14165
|
+
*/
|
|
14166
|
+
updateSurrogateModel(config, score) {
|
|
14167
|
+
this.miproConfigHistory.push({ config: { ...config }, score });
|
|
14168
|
+
const key = this.encodeConfiguration(config);
|
|
14169
|
+
const similarConfigs = this.miproConfigHistory.filter(
|
|
14170
|
+
(entry) => this.encodeConfiguration(entry.config) === key
|
|
14171
|
+
);
|
|
14172
|
+
if (similarConfigs.length > 0) {
|
|
14173
|
+
const scores = similarConfigs.map((entry) => entry.score);
|
|
14174
|
+
const mean = scores.reduce((sum, s2) => sum + s2, 0) / scores.length;
|
|
14175
|
+
const variance = scores.length > 1 ? scores.reduce((sum, s2) => sum + Math.pow(s2 - mean, 2), 0) / (scores.length - 1) : 0.1;
|
|
14176
|
+
this.surrogateModel.set(key, { mean, variance });
|
|
14177
|
+
}
|
|
14178
|
+
}
|
|
14179
|
+
/**
|
|
14180
|
+
* Predicts performance using the surrogate model
|
|
14181
|
+
*/
|
|
14182
|
+
predictPerformance(config) {
|
|
14183
|
+
const key = this.encodeConfiguration(config);
|
|
14184
|
+
if (this.surrogateModel.has(key)) {
|
|
14185
|
+
return this.surrogateModel.get(key);
|
|
14186
|
+
}
|
|
14187
|
+
if (this.miproConfigHistory.length > 0) {
|
|
14188
|
+
const similarities = this.miproConfigHistory.map((entry) => {
|
|
14189
|
+
const diff = Math.abs(entry.config.bootstrappedDemos - config.bootstrappedDemos) + Math.abs(entry.config.labeledExamples - config.labeledExamples);
|
|
14190
|
+
return { score: entry.score, similarity: 1 / (1 + diff) };
|
|
14191
|
+
});
|
|
14192
|
+
const totalWeight = similarities.reduce((sum, s2) => sum + s2.similarity, 0);
|
|
14193
|
+
const weightedMean = similarities.reduce((sum, s2) => sum + s2.score * s2.similarity, 0) / totalWeight;
|
|
14194
|
+
return { mean: weightedMean, variance: 0.2 };
|
|
14195
|
+
}
|
|
14196
|
+
return { mean: 0.5, variance: 0.3 };
|
|
14197
|
+
}
|
|
14198
|
+
/**
|
|
14199
|
+
* Calculates acquisition function value for Bayesian optimization
|
|
14200
|
+
*/
|
|
14201
|
+
calculateAcquisitionValue(config) {
|
|
14202
|
+
const prediction = this.predictPerformance(config);
|
|
14203
|
+
const { mean, variance } = prediction;
|
|
14204
|
+
const std = Math.sqrt(variance);
|
|
14205
|
+
const bestScore = this.miproConfigHistory.length > 0 ? Math.max(...this.miproConfigHistory.map((entry) => entry.score)) : 0;
|
|
14206
|
+
switch (this.acquisitionFunction) {
|
|
14207
|
+
case "expected_improvement": {
|
|
14208
|
+
const improvement = mean - bestScore;
|
|
14209
|
+
if (std === 0) return Math.max(0, improvement);
|
|
14210
|
+
const z = improvement / std;
|
|
14211
|
+
const phi = 0.5 * (1 + this.erf(z / Math.sqrt(2)));
|
|
14212
|
+
const pdfValue = Math.exp(-0.5 * z * z) / Math.sqrt(2 * Math.PI);
|
|
14213
|
+
return improvement * phi + std * pdfValue;
|
|
14214
|
+
}
|
|
14215
|
+
case "upper_confidence_bound": {
|
|
14216
|
+
return mean + this.explorationWeight * std;
|
|
14217
|
+
}
|
|
14218
|
+
case "probability_improvement": {
|
|
14219
|
+
const improvement = mean - bestScore;
|
|
14220
|
+
if (std === 0) return improvement > 0 ? 1 : 0;
|
|
14221
|
+
const z = improvement / std;
|
|
14222
|
+
return 0.5 * (1 + this.erf(z / Math.sqrt(2)));
|
|
14223
|
+
}
|
|
14224
|
+
default:
|
|
14225
|
+
return mean;
|
|
14226
|
+
}
|
|
14227
|
+
}
|
|
14228
|
+
/**
|
|
14229
|
+
* Error function approximation for acquisition function calculations
|
|
14230
|
+
*/
|
|
14231
|
+
erf(x) {
|
|
14232
|
+
const a1 = 0.254829592;
|
|
14233
|
+
const a2 = -0.284496736;
|
|
14234
|
+
const a3 = 1.421413741;
|
|
14235
|
+
const a4 = -1.453152027;
|
|
14236
|
+
const a5 = 1.061405429;
|
|
14237
|
+
const p = 0.3275911;
|
|
14238
|
+
const sign = x >= 0 ? 1 : -1;
|
|
14239
|
+
x = Math.abs(x);
|
|
14240
|
+
const t = 1 / (1 + p * x);
|
|
14241
|
+
const y = 1 - ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t * Math.exp(-x * x);
|
|
14242
|
+
return sign * y;
|
|
14243
|
+
}
|
|
14244
|
+
/**
|
|
14245
|
+
* Selects the next configuration to evaluate using Bayesian optimization
|
|
14246
|
+
*/
|
|
14247
|
+
async selectConfigurationViaBayesianOptimization(instructions, bootstrappedDemos, labeledExamples) {
|
|
14248
|
+
const candidates = [];
|
|
14249
|
+
const numCandidates = Math.min(20, instructions.length * 3);
|
|
14250
|
+
for (let i = 0; i < numCandidates; i++) {
|
|
14251
|
+
const config = {
|
|
14252
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
14253
|
+
bootstrappedDemos: Math.min(
|
|
14254
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
14255
|
+
this.maxBootstrappedDemos
|
|
14256
|
+
),
|
|
14257
|
+
labeledExamples: Math.min(
|
|
14258
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
14259
|
+
this.maxLabeledDemos
|
|
14260
|
+
)
|
|
14261
|
+
};
|
|
14262
|
+
const acquisitionValue = this.calculateAcquisitionValue(config);
|
|
14263
|
+
candidates.push({ config, acquisitionValue });
|
|
14264
|
+
}
|
|
14265
|
+
candidates.sort((a, b) => b.acquisitionValue - a.acquisitionValue);
|
|
14266
|
+
return candidates[0].config;
|
|
14267
|
+
}
|
|
14268
|
+
};
|
|
14269
|
+
/**
 * Creates a result picker that selects the most frequent sample (majority vote).
 *
 * The returned async function receives `data` with a `type` and a `results`
 * array and resolves to the `index` of the chosen result:
 *  - for `type === "fields"`, samples are grouped by value and the index of the
 *    first member of the largest group is returned; ties go to the group seen
 *    first;
 *  - for any other type, the index of the first result is returned.
 * When there are no results (or the chosen index is missing) it resolves to 0.
 *
 * Grouping uses a key-order-independent JSON encoding, so two samples holding
 * the same fields in a different property order count as the same answer.
 */
var axMajorityVotePicker = () => {
  // JSON encoding with plain-object keys sorted at every nesting level.
  // The replacer runs after toJSON, so values such as Date are already strings.
  const stableKey = (value) => JSON.stringify(value, (_name, inner) => {
    if (inner === null || typeof inner !== "object" || Array.isArray(inner)) {
      return inner;
    }
    // Null prototype so a field literally named "__proto__" stays a plain key.
    const ordered = Object.create(null);
    for (const name of Object.keys(inner).sort()) {
      ordered[name] = inner[name];
    }
    return ordered;
  });

  return async (data) => {
    if (data.type !== "fields") {
      return data.results[0]?.index ?? 0;
    }
    // Map keeps insertion order, which gives the first-seen tie-break.
    const tally = new Map();
    for (const { index, sample } of data.results) {
      const key = stableKey(sample);
      const entry = tally.get(key);
      if (entry) {
        entry.count += 1;
      } else {
        tally.set(key, { count: 1, index });
      }
    }
    let winner;
    for (const entry of tally.values()) {
      if (winner === void 0 || entry.count > winner.count) {
        winner = entry;
      }
    }
    return winner?.index ?? 0;
  };
};
|
|
13964
14293
|
|
|
13965
14294
|
// ai/mock/api.ts
|