@ax-llm/ax 12.0.13 → 12.0.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +409 -64
- package/index.cjs.map +1 -1
- package/index.d.cts +47 -4
- package/index.d.ts +47 -4
- package/index.js +409 -64
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.js
CHANGED
|
@@ -7536,6 +7536,22 @@ function parseFunctionCalls(ai, functionCalls, values, model) {
|
|
|
7536
7536
|
}));
|
|
7537
7537
|
return funcs;
|
|
7538
7538
|
}
|
|
7539
|
+
function createFunctionConfig(functionList, definedFunctionCall, firstStep) {
|
|
7540
|
+
let functionCall = definedFunctionCall;
|
|
7541
|
+
if (!firstStep && (functionCall === "required" || typeof functionCall === "function")) {
|
|
7542
|
+
return { functions: [], functionCall: void 0 };
|
|
7543
|
+
}
|
|
7544
|
+
if (!functionList) {
|
|
7545
|
+
return { functions: [], functionCall };
|
|
7546
|
+
}
|
|
7547
|
+
const functions = functionList.map((f2) => {
|
|
7548
|
+
if ("toFunction" in f2) {
|
|
7549
|
+
return f2.toFunction();
|
|
7550
|
+
}
|
|
7551
|
+
return f2;
|
|
7552
|
+
}).flat();
|
|
7553
|
+
return { functions, functionCall };
|
|
7554
|
+
}
|
|
7539
7555
|
|
|
7540
7556
|
// dsp/processResponse.ts
|
|
7541
7557
|
import "stream/web";
|
|
@@ -9984,7 +10000,8 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
9984
10000
|
mem,
|
|
9985
10001
|
options,
|
|
9986
10002
|
traceContext,
|
|
9987
|
-
|
|
10003
|
+
functions,
|
|
10004
|
+
functionCall
|
|
9988
10005
|
}) {
|
|
9989
10006
|
const {
|
|
9990
10007
|
sessionId,
|
|
@@ -9992,8 +10009,6 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
9992
10009
|
model,
|
|
9993
10010
|
rateLimiter,
|
|
9994
10011
|
stream,
|
|
9995
|
-
functions: _functions,
|
|
9996
|
-
functionCall: _functionCall,
|
|
9997
10012
|
thinkingTokenBudget,
|
|
9998
10013
|
showThoughts
|
|
9999
10014
|
} = options ?? {};
|
|
@@ -10004,11 +10019,6 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
10004
10019
|
if (chatPrompt.length === 0) {
|
|
10005
10020
|
throw new Error("No chat prompt found");
|
|
10006
10021
|
}
|
|
10007
|
-
const functions = _functions?.map((f2) => "toFunction" in f2 ? f2.toFunction() : f2)?.flat();
|
|
10008
|
-
let functionCall = _functionCall ?? this.options?.functionCall;
|
|
10009
|
-
if (!firstStep && (functionCall === "required" || typeof functionCall === "function")) {
|
|
10010
|
-
functionCall = void 0;
|
|
10011
|
-
}
|
|
10012
10022
|
const modelConfig = {
|
|
10013
10023
|
...options?.modelConfig,
|
|
10014
10024
|
...options?.sampleCount ? { n: options.sampleCount } : {},
|
|
@@ -10045,18 +10055,24 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
10045
10055
|
span,
|
|
10046
10056
|
traceContext
|
|
10047
10057
|
}) {
|
|
10048
|
-
const { sessionId, traceId, functions:
|
|
10058
|
+
const { sessionId, traceId, functions: functionList } = options ?? {};
|
|
10059
|
+
const definedFunctionCall = options?.functionCall ?? this.options?.functionCall;
|
|
10049
10060
|
const strictMode = options?.strictMode ?? false;
|
|
10050
10061
|
const model = options.model;
|
|
10051
10062
|
const states = this.createStates(options.sampleCount ?? 1);
|
|
10052
10063
|
const usage = this.usage;
|
|
10053
|
-
const functions
|
|
10064
|
+
const { functions, functionCall } = createFunctionConfig(
|
|
10065
|
+
functionList,
|
|
10066
|
+
definedFunctionCall,
|
|
10067
|
+
firstStep
|
|
10068
|
+
);
|
|
10054
10069
|
const res = await this.forwardSendRequest({
|
|
10055
10070
|
ai,
|
|
10056
10071
|
mem,
|
|
10057
10072
|
options,
|
|
10058
10073
|
traceContext,
|
|
10059
|
-
|
|
10074
|
+
functions,
|
|
10075
|
+
functionCall
|
|
10060
10076
|
});
|
|
10061
10077
|
if (res instanceof ReadableStream3) {
|
|
10062
10078
|
yield* processStreamingResponse({
|
|
@@ -11700,13 +11716,6 @@ var AxBaseOptimizer = class {
|
|
|
11700
11716
|
if (this.logger) {
|
|
11701
11717
|
return this.logger;
|
|
11702
11718
|
}
|
|
11703
|
-
try {
|
|
11704
|
-
const aiLogger = this.studentAI.getLogger();
|
|
11705
|
-
if (aiLogger) {
|
|
11706
|
-
return aiLogger;
|
|
11707
|
-
}
|
|
11708
|
-
} catch {
|
|
11709
|
-
}
|
|
11710
11719
|
return axDefaultOptimizerLogger;
|
|
11711
11720
|
}
|
|
11712
11721
|
/**
|
|
@@ -13398,6 +13407,11 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13398
13407
|
bayesianOptimization;
|
|
13399
13408
|
acquisitionFunction;
|
|
13400
13409
|
explorationWeight;
|
|
13410
|
+
// Self-consistency / multiple sampling
|
|
13411
|
+
sampleCount;
|
|
13412
|
+
// Surrogate model state for Bayesian optimization
|
|
13413
|
+
miproConfigHistory = [];
|
|
13414
|
+
surrogateModel = /* @__PURE__ */ new Map();
|
|
13401
13415
|
constructor(args) {
|
|
13402
13416
|
super(args);
|
|
13403
13417
|
const options = args.options || {};
|
|
@@ -13419,6 +13433,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13419
13433
|
this.bayesianOptimization = options.bayesianOptimization ?? false;
|
|
13420
13434
|
this.acquisitionFunction = options.acquisitionFunction ?? "expected_improvement";
|
|
13421
13435
|
this.explorationWeight = options.explorationWeight ?? 0.1;
|
|
13436
|
+
this.sampleCount = options.sampleCount ?? 1;
|
|
13422
13437
|
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
13423
13438
|
}
|
|
13424
13439
|
/**
|
|
@@ -13463,43 +13478,186 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13463
13478
|
];
|
|
13464
13479
|
}
|
|
13465
13480
|
/**
|
|
13466
|
-
* Generates
|
|
13481
|
+
* Generates program summary for context-aware instruction generation
|
|
13482
|
+
*/
|
|
13483
|
+
async generateProgramSummary(program, ai) {
|
|
13484
|
+
let signature = "input -> output";
|
|
13485
|
+
if ("getSignature" in program && typeof program.getSignature === "function") {
|
|
13486
|
+
signature = program.getSignature();
|
|
13487
|
+
}
|
|
13488
|
+
const summaryPrompt = `
|
|
13489
|
+
Analyze this language model program and provide a concise summary of its purpose and structure.
|
|
13490
|
+
|
|
13491
|
+
Program Signature: ${signature}
|
|
13492
|
+
|
|
13493
|
+
Provide a 2-3 sentence summary focusing on:
|
|
13494
|
+
1. The main task or purpose of this program
|
|
13495
|
+
2. The input-output relationship
|
|
13496
|
+
3. Any special constraints or requirements
|
|
13497
|
+
|
|
13498
|
+
Summary:`;
|
|
13499
|
+
try {
|
|
13500
|
+
const response = await ai.chat({
|
|
13501
|
+
chatPrompt: [{ role: "user", content: summaryPrompt }]
|
|
13502
|
+
});
|
|
13503
|
+
if ("results" in response) {
|
|
13504
|
+
return response.results[0]?.content?.trim() || "General language model program";
|
|
13505
|
+
}
|
|
13506
|
+
return "General language model program";
|
|
13507
|
+
} catch {
|
|
13508
|
+
return "General language model program";
|
|
13509
|
+
}
|
|
13510
|
+
}
|
|
13511
|
+
/**
|
|
13512
|
+
* Generates dataset summary for context-aware instruction generation
|
|
13513
|
+
*/
|
|
13514
|
+
async generateDatasetSummary(examples, ai) {
|
|
13515
|
+
if (examples.length === 0) return "No examples available";
|
|
13516
|
+
const sampleSize = Math.min(this.viewDataBatchSize, examples.length);
|
|
13517
|
+
const sampledExamples = examples.slice(0, sampleSize);
|
|
13518
|
+
const exampleTexts = sampledExamples.map((ex, i) => `Example ${i + 1}: ${JSON.stringify(ex)}`).join("\n");
|
|
13519
|
+
const summaryPrompt = `
|
|
13520
|
+
Analyze this dataset and provide a concise summary of its characteristics.
|
|
13521
|
+
|
|
13522
|
+
Sample Examples:
|
|
13523
|
+
${exampleTexts}
|
|
13524
|
+
|
|
13525
|
+
Provide a 2-3 sentence summary focusing on:
|
|
13526
|
+
1. The type of data and domain
|
|
13527
|
+
2. Common patterns or structures in the examples
|
|
13528
|
+
3. Key challenges or requirements for processing this data
|
|
13529
|
+
|
|
13530
|
+
Dataset Summary:`;
|
|
13531
|
+
try {
|
|
13532
|
+
const response = await ai.chat({
|
|
13533
|
+
chatPrompt: [{ role: "user", content: summaryPrompt }]
|
|
13534
|
+
});
|
|
13535
|
+
if ("results" in response) {
|
|
13536
|
+
return response.results[0]?.content?.trim() || "General dataset";
|
|
13537
|
+
}
|
|
13538
|
+
return "General dataset";
|
|
13539
|
+
} catch {
|
|
13540
|
+
return "General dataset";
|
|
13541
|
+
}
|
|
13542
|
+
}
|
|
13543
|
+
/**
|
|
13544
|
+
* Enhanced instruction generation using AI with program and data awareness
|
|
13545
|
+
*/
|
|
13546
|
+
async generateInstruction({
|
|
13547
|
+
tip,
|
|
13548
|
+
candidateIndex,
|
|
13549
|
+
ai,
|
|
13550
|
+
programSummary,
|
|
13551
|
+
datasetSummary,
|
|
13552
|
+
previousInstructions = []
|
|
13553
|
+
}) {
|
|
13554
|
+
let contextInfo = "";
|
|
13555
|
+
if (this.programAwareProposer && programSummary) {
|
|
13556
|
+
contextInfo += `
|
|
13557
|
+
Program Context: ${programSummary}`;
|
|
13558
|
+
}
|
|
13559
|
+
if (this.dataAwareProposer && datasetSummary) {
|
|
13560
|
+
contextInfo += `
|
|
13561
|
+
Dataset Context: ${datasetSummary}`;
|
|
13562
|
+
}
|
|
13563
|
+
if (this.fewshotAwareProposer && previousInstructions.length > 0) {
|
|
13564
|
+
contextInfo += `
|
|
13565
|
+
Previous Instructions (avoid repeating): ${previousInstructions.slice(-3).join("; ")}`;
|
|
13566
|
+
}
|
|
13567
|
+
const instructionPrompt = `
|
|
13568
|
+
Generate a high-quality instruction for a language model program.
|
|
13569
|
+
|
|
13570
|
+
${contextInfo}
|
|
13571
|
+
|
|
13572
|
+
${tip ? `Tip: ${tip}` : ""}
|
|
13573
|
+
|
|
13574
|
+
Requirements:
|
|
13575
|
+
1. Be specific and actionable
|
|
13576
|
+
2. Focus on accuracy and clarity
|
|
13577
|
+
3. Consider the program's purpose and data characteristics
|
|
13578
|
+
4. Make the instruction distinct from previous ones
|
|
13579
|
+
5. Keep it concise but comprehensive
|
|
13580
|
+
|
|
13581
|
+
Generate a single, well-crafted instruction:
|
|
13582
|
+
Instruction:`;
|
|
13583
|
+
try {
|
|
13584
|
+
const response = await ai.chat({
|
|
13585
|
+
chatPrompt: [
|
|
13586
|
+
{
|
|
13587
|
+
role: "user",
|
|
13588
|
+
content: instructionPrompt
|
|
13589
|
+
}
|
|
13590
|
+
]
|
|
13591
|
+
});
|
|
13592
|
+
if ("results" in response) {
|
|
13593
|
+
const instruction2 = response.results[0]?.content?.trim();
|
|
13594
|
+
if (instruction2 && instruction2.length > 10) {
|
|
13595
|
+
return instruction2;
|
|
13596
|
+
}
|
|
13597
|
+
}
|
|
13598
|
+
} catch (error) {
|
|
13599
|
+
if (this.isLoggingEnabled()) {
|
|
13600
|
+
this.getLogger()?.(`Failed to generate AI instruction: ${error}`, {
|
|
13601
|
+
tags: ["optimizer", "warning"]
|
|
13602
|
+
});
|
|
13603
|
+
}
|
|
13604
|
+
}
|
|
13605
|
+
const enhancedTemplates = [
|
|
13606
|
+
"Analyze the input systematically and provide a precise, well-reasoned response.",
|
|
13607
|
+
"Think through this step-by-step, considering all relevant factors before responding.",
|
|
13608
|
+
"Examine the input carefully and generate an accurate, detailed answer.",
|
|
13609
|
+
"Process the information methodically and deliver a clear, comprehensive response.",
|
|
13610
|
+
"Consider the context thoroughly and provide a thoughtful, accurate answer."
|
|
13611
|
+
];
|
|
13612
|
+
let instruction = enhancedTemplates[candidateIndex % enhancedTemplates.length] || enhancedTemplates[0];
|
|
13613
|
+
if (tip) {
|
|
13614
|
+
instruction = `${instruction} ${tip}`;
|
|
13615
|
+
}
|
|
13616
|
+
return instruction;
|
|
13617
|
+
}
|
|
13618
|
+
/**
|
|
13619
|
+
* Generates instruction candidates using enhanced AI-powered generation
|
|
13467
13620
|
* @param options Optional compile options that may override teacher AI
|
|
13468
13621
|
* @returns Array of generated instruction candidates
|
|
13469
13622
|
*/
|
|
13470
|
-
async proposeInstructionCandidates(options) {
|
|
13623
|
+
async proposeInstructionCandidates(program, options) {
|
|
13471
13624
|
const instructions = [];
|
|
13472
13625
|
const aiToUse = this.getTeacherOrStudentAI(options);
|
|
13626
|
+
let programSummary;
|
|
13627
|
+
let datasetSummary;
|
|
13628
|
+
if (this.programAwareProposer) {
|
|
13629
|
+
programSummary = await this.generateProgramSummary(program, aiToUse);
|
|
13630
|
+
if (this.isLoggingEnabled(options)) {
|
|
13631
|
+
this.getLogger(options)?.(`Program summary: ${programSummary}`, {
|
|
13632
|
+
tags: ["optimizer", "config"]
|
|
13633
|
+
});
|
|
13634
|
+
}
|
|
13635
|
+
}
|
|
13636
|
+
if (this.dataAwareProposer) {
|
|
13637
|
+
datasetSummary = await this.generateDatasetSummary(this.examples, aiToUse);
|
|
13638
|
+
if (this.isLoggingEnabled(options)) {
|
|
13639
|
+
this.getLogger(options)?.(`Dataset summary: ${datasetSummary}`, {
|
|
13640
|
+
tags: ["optimizer", "config"]
|
|
13641
|
+
});
|
|
13642
|
+
}
|
|
13643
|
+
}
|
|
13473
13644
|
const tips = this.tipAwareProposer ? this.generateTips() : [];
|
|
13474
13645
|
for (let i = 0; i < this.numCandidates; i++) {
|
|
13475
13646
|
const tipIndex = tips.length > 0 ? i % tips.length : -1;
|
|
13476
|
-
const tipToUse = tipIndex >= 0 ? tips[tipIndex] :
|
|
13647
|
+
const tipToUse = tipIndex >= 0 ? tips[tipIndex] : void 0;
|
|
13477
13648
|
const instruction = await this.generateInstruction({
|
|
13478
13649
|
tip: tipToUse,
|
|
13479
13650
|
candidateIndex: i,
|
|
13480
|
-
ai: aiToUse
|
|
13651
|
+
ai: aiToUse,
|
|
13652
|
+
programSummary,
|
|
13653
|
+
datasetSummary,
|
|
13654
|
+
previousInstructions: instructions
|
|
13655
|
+
// Pass previous instructions for diversity
|
|
13481
13656
|
});
|
|
13482
13657
|
instructions.push(instruction);
|
|
13483
13658
|
}
|
|
13484
13659
|
return instructions;
|
|
13485
13660
|
}
|
|
13486
|
-
async generateInstruction({
|
|
13487
|
-
tip,
|
|
13488
|
-
candidateIndex
|
|
13489
|
-
}) {
|
|
13490
|
-
const baseInstructions = [
|
|
13491
|
-
"Analyze the input carefully and provide a detailed response.",
|
|
13492
|
-
"Think step by step and provide a clear answer.",
|
|
13493
|
-
"Consider all aspects of the input before responding.",
|
|
13494
|
-
"Provide a concise but comprehensive response.",
|
|
13495
|
-
"Focus on accuracy and clarity in your response."
|
|
13496
|
-
];
|
|
13497
|
-
let instruction = baseInstructions[candidateIndex % baseInstructions.length] || baseInstructions[0];
|
|
13498
|
-
if (tip) {
|
|
13499
|
-
instruction = `${instruction} ${tip}`;
|
|
13500
|
-
}
|
|
13501
|
-
return instruction;
|
|
13502
|
-
}
|
|
13503
13661
|
/**
|
|
13504
13662
|
* Bootstraps few-shot examples for the program
|
|
13505
13663
|
*/
|
|
@@ -13544,7 +13702,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13544
13702
|
/**
|
|
13545
13703
|
* Runs optimization to find the best combination of few-shot examples and instructions
|
|
13546
13704
|
*/
|
|
13547
|
-
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions,
|
|
13705
|
+
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, validationExamples, metricFn, options) {
|
|
13548
13706
|
let bestConfig = {
|
|
13549
13707
|
instruction: instructions[0] || "",
|
|
13550
13708
|
bootstrappedDemos: Math.min(1, bootstrappedDemos.length),
|
|
@@ -13580,25 +13738,37 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13580
13738
|
);
|
|
13581
13739
|
}
|
|
13582
13740
|
for (let i = startRound; i < this.numTrials; i++) {
|
|
13583
|
-
|
|
13584
|
-
|
|
13585
|
-
|
|
13586
|
-
|
|
13587
|
-
|
|
13588
|
-
|
|
13589
|
-
|
|
13590
|
-
|
|
13591
|
-
|
|
13592
|
-
|
|
13593
|
-
|
|
13741
|
+
let config;
|
|
13742
|
+
if (this.bayesianOptimization && this.miproConfigHistory.length > 2) {
|
|
13743
|
+
config = await this.selectConfigurationViaBayesianOptimization(
|
|
13744
|
+
instructions,
|
|
13745
|
+
bootstrappedDemos,
|
|
13746
|
+
labeledExamples
|
|
13747
|
+
);
|
|
13748
|
+
} else {
|
|
13749
|
+
config = {
|
|
13750
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
13751
|
+
bootstrappedDemos: Math.min(
|
|
13752
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
13753
|
+
this.maxBootstrappedDemos
|
|
13754
|
+
),
|
|
13755
|
+
labeledExamples: Math.min(
|
|
13756
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
13757
|
+
this.maxLabeledDemos
|
|
13758
|
+
)
|
|
13759
|
+
};
|
|
13760
|
+
}
|
|
13594
13761
|
const score = await this.evaluateConfig(
|
|
13595
13762
|
program,
|
|
13596
13763
|
config,
|
|
13597
13764
|
bootstrappedDemos,
|
|
13598
13765
|
labeledExamples,
|
|
13599
|
-
|
|
13600
|
-
metricFn
|
|
13766
|
+
validationExamples,
|
|
13767
|
+
metricFn,
|
|
13768
|
+
i + 1
|
|
13769
|
+
// Pass current trial number for adaptive evaluation
|
|
13601
13770
|
);
|
|
13771
|
+
this.updateSurrogateModel(config, score);
|
|
13602
13772
|
scoreHistory.push(score);
|
|
13603
13773
|
const improvement = score - bestScore;
|
|
13604
13774
|
if (improvement > this.minImprovementThreshold) {
|
|
@@ -13680,7 +13850,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13680
13850
|
this.stats.convergenceInfo.converged = stagnationRounds < this.earlyStoppingTrials;
|
|
13681
13851
|
return { bestConfig, bestScore };
|
|
13682
13852
|
}
|
|
13683
|
-
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples,
|
|
13853
|
+
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, validationExamples, metricFn, currentTrial = 0) {
|
|
13684
13854
|
const testProgram = { ...program };
|
|
13685
13855
|
this.applyConfigToProgram(
|
|
13686
13856
|
testProgram,
|
|
@@ -13690,12 +13860,31 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13690
13860
|
);
|
|
13691
13861
|
let totalScore = 0;
|
|
13692
13862
|
let count = 0;
|
|
13693
|
-
|
|
13863
|
+
let evalSize;
|
|
13864
|
+
if (this.minibatch) {
|
|
13865
|
+
const baseSize = Math.min(this.minibatchSize, validationExamples.length);
|
|
13866
|
+
const isFullEvalTrial = currentTrial % this.minibatchFullEvalSteps === 0;
|
|
13867
|
+
if (isFullEvalTrial || currentTrial > this.numTrials * 0.8) {
|
|
13868
|
+
evalSize = Math.min(validationExamples.length, baseSize * 2);
|
|
13869
|
+
} else {
|
|
13870
|
+
evalSize = Math.max(3, Math.min(baseSize, validationExamples.length));
|
|
13871
|
+
}
|
|
13872
|
+
} else {
|
|
13873
|
+
evalSize = validationExamples.length;
|
|
13874
|
+
}
|
|
13875
|
+
const evalIndices = this.shuffleArray([
|
|
13876
|
+
...Array(validationExamples.length).keys()
|
|
13877
|
+
]).slice(0, evalSize);
|
|
13878
|
+
const evalSet = evalIndices.map((i) => validationExamples[i]);
|
|
13694
13879
|
for (const example of evalSet) {
|
|
13695
13880
|
try {
|
|
13696
13881
|
const prediction = await testProgram.forward(
|
|
13697
13882
|
this.studentAI,
|
|
13698
|
-
example
|
|
13883
|
+
example,
|
|
13884
|
+
this.sampleCount > 1 ? {
|
|
13885
|
+
sampleCount: this.sampleCount,
|
|
13886
|
+
resultPicker: axMajorityVotePicker()
|
|
13887
|
+
} : void 0
|
|
13699
13888
|
);
|
|
13700
13889
|
const score = await metricFn({ prediction, example });
|
|
13701
13890
|
totalScore += score;
|
|
@@ -13707,6 +13896,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13707
13896
|
}
|
|
13708
13897
|
return count > 0 ? totalScore / count : 0;
|
|
13709
13898
|
}
|
|
13899
|
+
/**
|
|
13900
|
+
* Fisher-Yates shuffle for stochastic evaluation
|
|
13901
|
+
*/
|
|
13902
|
+
shuffleArray(array) {
|
|
13903
|
+
const shuffled = [...array];
|
|
13904
|
+
for (let i = shuffled.length - 1; i > 0; i--) {
|
|
13905
|
+
const j = Math.floor(Math.random() * (i + 1));
|
|
13906
|
+
[shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
|
|
13907
|
+
}
|
|
13908
|
+
return shuffled;
|
|
13909
|
+
}
|
|
13710
13910
|
applyConfigToProgram(program, config, bootstrappedDemos, labeledExamples) {
|
|
13711
13911
|
if (program.setInstruction) {
|
|
13712
13912
|
program.setInstruction(config.instruction);
|
|
@@ -13728,14 +13928,14 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13728
13928
|
if (miproOptions?.auto) {
|
|
13729
13929
|
this.configureAuto(miproOptions.auto);
|
|
13730
13930
|
}
|
|
13731
|
-
const
|
|
13931
|
+
const validationExamples = this.getValidationSet(options) || (miproOptions?.validationExamples ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
|
|
13732
13932
|
if (this.isLoggingEnabled(options)) {
|
|
13733
13933
|
this.getLogger(options)?.(
|
|
13734
13934
|
`Starting MIPROv2 optimization with ${this.numTrials} trials`,
|
|
13735
13935
|
{ tags: ["optimizer", "start"] }
|
|
13736
13936
|
);
|
|
13737
13937
|
this.getLogger(options)?.(
|
|
13738
|
-
`Using ${this.examples.length} examples for training and ${
|
|
13938
|
+
`Using ${this.examples.length} examples for training and ${validationExamples.length} for validation`,
|
|
13739
13939
|
{ tags: ["optimizer", "config"] }
|
|
13740
13940
|
);
|
|
13741
13941
|
if (this.teacherAI) {
|
|
@@ -13765,7 +13965,10 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13765
13965
|
);
|
|
13766
13966
|
}
|
|
13767
13967
|
}
|
|
13768
|
-
const instructions = await this.proposeInstructionCandidates(
|
|
13968
|
+
const instructions = await this.proposeInstructionCandidates(
|
|
13969
|
+
program,
|
|
13970
|
+
options
|
|
13971
|
+
);
|
|
13769
13972
|
if (this.isLoggingEnabled(options)) {
|
|
13770
13973
|
this.getLogger(options)?.(
|
|
13771
13974
|
`Generated ${instructions.length} instruction candidates`,
|
|
@@ -13783,7 +13986,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13783
13986
|
bootstrappedDemos,
|
|
13784
13987
|
labeledExamples,
|
|
13785
13988
|
instructions,
|
|
13786
|
-
|
|
13989
|
+
validationExamples,
|
|
13787
13990
|
metricFn,
|
|
13788
13991
|
options
|
|
13789
13992
|
);
|
|
@@ -13842,7 +14045,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13842
14045
|
bootstrappedDemos: bestConfig.bootstrappedDemos,
|
|
13843
14046
|
labeledExamples: bestConfig.labeledExamples,
|
|
13844
14047
|
numCandidates: this.numCandidates,
|
|
13845
|
-
numTrials: this.numTrials
|
|
14048
|
+
numTrials: this.numTrials,
|
|
14049
|
+
sampleCount: this.sampleCount
|
|
13846
14050
|
}
|
|
13847
14051
|
};
|
|
13848
14052
|
}
|
|
@@ -13887,7 +14091,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13887
14091
|
minImprovementThreshold: this.minImprovementThreshold,
|
|
13888
14092
|
bayesianOptimization: this.bayesianOptimization,
|
|
13889
14093
|
acquisitionFunction: this.acquisitionFunction,
|
|
13890
|
-
explorationWeight: this.explorationWeight
|
|
14094
|
+
explorationWeight: this.explorationWeight,
|
|
14095
|
+
sampleCount: this.sampleCount
|
|
13891
14096
|
};
|
|
13892
14097
|
}
|
|
13893
14098
|
/**
|
|
@@ -13922,12 +14127,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13922
14127
|
if (config.minImprovementThreshold !== void 0) {
|
|
13923
14128
|
this.minImprovementThreshold = config.minImprovementThreshold;
|
|
13924
14129
|
}
|
|
14130
|
+
if (config.sampleCount !== void 0) {
|
|
14131
|
+
this.sampleCount = config.sampleCount;
|
|
14132
|
+
}
|
|
13925
14133
|
}
|
|
13926
14134
|
/**
|
|
13927
14135
|
* Reset optimizer state for reuse with different programs
|
|
13928
14136
|
*/
|
|
13929
14137
|
reset() {
|
|
13930
14138
|
super.reset();
|
|
14139
|
+
this.miproConfigHistory = [];
|
|
14140
|
+
this.surrogateModel.clear();
|
|
13931
14141
|
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
13932
14142
|
}
|
|
13933
14143
|
/**
|
|
@@ -13945,8 +14155,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13945
14155
|
"Reduce maxBootstrappedDemos or maxLabeledDemos, or provide more examples"
|
|
13946
14156
|
);
|
|
13947
14157
|
}
|
|
13948
|
-
const
|
|
13949
|
-
if (
|
|
14158
|
+
const validationSetSize = this.getValidationSet().length;
|
|
14159
|
+
if (validationSetSize < 5) {
|
|
13950
14160
|
result.issues.push(
|
|
13951
14161
|
"Validation set too small for reliable MiPRO optimization"
|
|
13952
14162
|
);
|
|
@@ -13960,6 +14170,141 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13960
14170
|
suggestions: result.suggestions
|
|
13961
14171
|
};
|
|
13962
14172
|
}
|
|
14173
|
+
/**
|
|
14174
|
+
* Encodes a configuration into a string key for surrogate model lookup
|
|
14175
|
+
*/
|
|
14176
|
+
encodeConfiguration(config) {
|
|
14177
|
+
return `${config.instruction.length}_${config.bootstrappedDemos}_${config.labeledExamples}`;
|
|
14178
|
+
}
|
|
14179
|
+
/**
|
|
14180
|
+
* Updates the surrogate model with a new configuration-score pair
|
|
14181
|
+
*/
|
|
14182
|
+
updateSurrogateModel(config, score) {
|
|
14183
|
+
this.miproConfigHistory.push({ config: { ...config }, score });
|
|
14184
|
+
const key = this.encodeConfiguration(config);
|
|
14185
|
+
const similarConfigs = this.miproConfigHistory.filter(
|
|
14186
|
+
(entry) => this.encodeConfiguration(entry.config) === key
|
|
14187
|
+
);
|
|
14188
|
+
if (similarConfigs.length > 0) {
|
|
14189
|
+
const scores = similarConfigs.map((entry) => entry.score);
|
|
14190
|
+
const mean = scores.reduce((sum, s2) => sum + s2, 0) / scores.length;
|
|
14191
|
+
const variance = scores.length > 1 ? scores.reduce((sum, s2) => sum + Math.pow(s2 - mean, 2), 0) / (scores.length - 1) : 0.1;
|
|
14192
|
+
this.surrogateModel.set(key, { mean, variance });
|
|
14193
|
+
}
|
|
14194
|
+
}
|
|
14195
|
+
/**
|
|
14196
|
+
* Predicts performance using the surrogate model
|
|
14197
|
+
*/
|
|
14198
|
+
predictPerformance(config) {
|
|
14199
|
+
const key = this.encodeConfiguration(config);
|
|
14200
|
+
if (this.surrogateModel.has(key)) {
|
|
14201
|
+
return this.surrogateModel.get(key);
|
|
14202
|
+
}
|
|
14203
|
+
if (this.miproConfigHistory.length > 0) {
|
|
14204
|
+
const similarities = this.miproConfigHistory.map((entry) => {
|
|
14205
|
+
const diff = Math.abs(entry.config.bootstrappedDemos - config.bootstrappedDemos) + Math.abs(entry.config.labeledExamples - config.labeledExamples);
|
|
14206
|
+
return { score: entry.score, similarity: 1 / (1 + diff) };
|
|
14207
|
+
});
|
|
14208
|
+
const totalWeight = similarities.reduce((sum, s2) => sum + s2.similarity, 0);
|
|
14209
|
+
const weightedMean = similarities.reduce((sum, s2) => sum + s2.score * s2.similarity, 0) / totalWeight;
|
|
14210
|
+
return { mean: weightedMean, variance: 0.2 };
|
|
14211
|
+
}
|
|
14212
|
+
return { mean: 0.5, variance: 0.3 };
|
|
14213
|
+
}
|
|
14214
|
+
/**
|
|
14215
|
+
* Calculates acquisition function value for Bayesian optimization
|
|
14216
|
+
*/
|
|
14217
|
+
calculateAcquisitionValue(config) {
|
|
14218
|
+
const prediction = this.predictPerformance(config);
|
|
14219
|
+
const { mean, variance } = prediction;
|
|
14220
|
+
const std = Math.sqrt(variance);
|
|
14221
|
+
const bestScore = this.miproConfigHistory.length > 0 ? Math.max(...this.miproConfigHistory.map((entry) => entry.score)) : 0;
|
|
14222
|
+
switch (this.acquisitionFunction) {
|
|
14223
|
+
case "expected_improvement": {
|
|
14224
|
+
const improvement = mean - bestScore;
|
|
14225
|
+
if (std === 0) return Math.max(0, improvement);
|
|
14226
|
+
const z = improvement / std;
|
|
14227
|
+
const phi = 0.5 * (1 + this.erf(z / Math.sqrt(2)));
|
|
14228
|
+
const pdfValue = Math.exp(-0.5 * z * z) / Math.sqrt(2 * Math.PI);
|
|
14229
|
+
return improvement * phi + std * pdfValue;
|
|
14230
|
+
}
|
|
14231
|
+
case "upper_confidence_bound": {
|
|
14232
|
+
return mean + this.explorationWeight * std;
|
|
14233
|
+
}
|
|
14234
|
+
case "probability_improvement": {
|
|
14235
|
+
const improvement = mean - bestScore;
|
|
14236
|
+
if (std === 0) return improvement > 0 ? 1 : 0;
|
|
14237
|
+
const z = improvement / std;
|
|
14238
|
+
return 0.5 * (1 + this.erf(z / Math.sqrt(2)));
|
|
14239
|
+
}
|
|
14240
|
+
default:
|
|
14241
|
+
return mean;
|
|
14242
|
+
}
|
|
14243
|
+
}
|
|
14244
|
+
/**
|
|
14245
|
+
* Error function approximation for acquisition function calculations
|
|
14246
|
+
*/
|
|
14247
|
+
erf(x) {
|
|
14248
|
+
const a1 = 0.254829592;
|
|
14249
|
+
const a2 = -0.284496736;
|
|
14250
|
+
const a3 = 1.421413741;
|
|
14251
|
+
const a4 = -1.453152027;
|
|
14252
|
+
const a5 = 1.061405429;
|
|
14253
|
+
const p = 0.3275911;
|
|
14254
|
+
const sign = x >= 0 ? 1 : -1;
|
|
14255
|
+
x = Math.abs(x);
|
|
14256
|
+
const t = 1 / (1 + p * x);
|
|
14257
|
+
const y = 1 - ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t * Math.exp(-x * x);
|
|
14258
|
+
return sign * y;
|
|
14259
|
+
}
|
|
14260
|
+
/**
|
|
14261
|
+
* Selects the next configuration to evaluate using Bayesian optimization
|
|
14262
|
+
*/
|
|
14263
|
+
async selectConfigurationViaBayesianOptimization(instructions, bootstrappedDemos, labeledExamples) {
|
|
14264
|
+
const candidates = [];
|
|
14265
|
+
const numCandidates = Math.min(20, instructions.length * 3);
|
|
14266
|
+
for (let i = 0; i < numCandidates; i++) {
|
|
14267
|
+
const config = {
|
|
14268
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
14269
|
+
bootstrappedDemos: Math.min(
|
|
14270
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
14271
|
+
this.maxBootstrappedDemos
|
|
14272
|
+
),
|
|
14273
|
+
labeledExamples: Math.min(
|
|
14274
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
14275
|
+
this.maxLabeledDemos
|
|
14276
|
+
)
|
|
14277
|
+
};
|
|
14278
|
+
const acquisitionValue = this.calculateAcquisitionValue(config);
|
|
14279
|
+
candidates.push({ config, acquisitionValue });
|
|
14280
|
+
}
|
|
14281
|
+
candidates.sort((a, b) => b.acquisitionValue - a.acquisitionValue);
|
|
14282
|
+
return candidates[0].config;
|
|
14283
|
+
}
|
|
14284
|
+
};
|
|
14285
|
+
var axMajorityVotePicker = () => {
|
|
14286
|
+
return async (data) => {
|
|
14287
|
+
if (data.type === "fields") {
|
|
14288
|
+
const counts = {};
|
|
14289
|
+
for (const { index, sample } of data.results) {
|
|
14290
|
+
const key = JSON.stringify(sample);
|
|
14291
|
+
if (!counts[key]) {
|
|
14292
|
+
counts[key] = { count: 0, index };
|
|
14293
|
+
}
|
|
14294
|
+
counts[key].count += 1;
|
|
14295
|
+
}
|
|
14296
|
+
let bestKey;
|
|
14297
|
+
let bestCount = -1;
|
|
14298
|
+
for (const [k, v] of Object.entries(counts)) {
|
|
14299
|
+
if (v.count > bestCount) {
|
|
14300
|
+
bestCount = v.count;
|
|
14301
|
+
bestKey = k;
|
|
14302
|
+
}
|
|
14303
|
+
}
|
|
14304
|
+
return counts[bestKey]?.index ?? 0;
|
|
14305
|
+
}
|
|
14306
|
+
return data.results[0]?.index ?? 0;
|
|
14307
|
+
};
|
|
13963
14308
|
};
|
|
13964
14309
|
|
|
13965
14310
|
// ai/mock/api.ts
|