@ax-llm/ax 12.0.7 → 12.0.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +1107 -527
- package/index.cjs.map +1 -1
- package/index.d.cts +505 -188
- package/index.d.ts +505 -188
- package/index.js +1109 -531
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.js
CHANGED
|
@@ -69,7 +69,7 @@ var AxSpanKindValues = /* @__PURE__ */ ((AxSpanKindValues2) => {
|
|
|
69
69
|
// util/apicall.ts
|
|
70
70
|
import crypto from "crypto";
|
|
71
71
|
import {
|
|
72
|
-
ReadableStream
|
|
72
|
+
ReadableStream,
|
|
73
73
|
TextDecoderStream as TextDecoderStreamNative,
|
|
74
74
|
TransformStream as TransformStream3
|
|
75
75
|
} from "stream/web";
|
|
@@ -486,7 +486,7 @@ var apiCall = async (api, json) => {
|
|
|
486
486
|
}
|
|
487
487
|
});
|
|
488
488
|
let closed = false;
|
|
489
|
-
return new
|
|
489
|
+
return new ReadableStream({
|
|
490
490
|
start(controller) {
|
|
491
491
|
const reader = res.body.pipeThrough(new textDecoderStream()).pipeThrough(new SSEParser()).pipeThrough(trackingStream).getReader();
|
|
492
492
|
async function read() {
|
|
@@ -5498,7 +5498,7 @@ var AxAIGrok = class extends AxAIOpenAIBase {
|
|
|
5498
5498
|
};
|
|
5499
5499
|
|
|
5500
5500
|
// dsp/generate.ts
|
|
5501
|
-
import { ReadableStream as
|
|
5501
|
+
import { ReadableStream as ReadableStream2 } from "stream/web";
|
|
5502
5502
|
import {
|
|
5503
5503
|
context as context2,
|
|
5504
5504
|
SpanKind as SpanKind2,
|
|
@@ -7370,8 +7370,9 @@ var AxInstanceRegistry = class {
|
|
|
7370
7370
|
this.reg.add(instance);
|
|
7371
7371
|
}
|
|
7372
7372
|
*[Symbol.iterator]() {
|
|
7373
|
-
|
|
7374
|
-
|
|
7373
|
+
const items = Array.from(this.reg);
|
|
7374
|
+
for (let i = 0; i < items.length; i++) {
|
|
7375
|
+
yield items[i];
|
|
7375
7376
|
}
|
|
7376
7377
|
}
|
|
7377
7378
|
};
|
|
@@ -8309,7 +8310,7 @@ var AxSignature = class _AxSignature {
|
|
|
8309
8310
|
this.getOutputFields().forEach((field) => {
|
|
8310
8311
|
validateField(field, "output");
|
|
8311
8312
|
});
|
|
8312
|
-
this.sigHash = createHash("sha256").update(
|
|
8313
|
+
this.sigHash = createHash("sha256").update(JSON.stringify(this.inputFields)).update(JSON.stringify(this.outputFields)).digest("hex");
|
|
8313
8314
|
this.sigString = renderSignature(
|
|
8314
8315
|
this.description,
|
|
8315
8316
|
this.inputFields,
|
|
@@ -8630,7 +8631,7 @@ var AxProgramWithSignature = class {
|
|
|
8630
8631
|
this.signature.validate();
|
|
8631
8632
|
this.sigHash = this.signature?.hash();
|
|
8632
8633
|
this.children = new AxInstanceRegistry();
|
|
8633
|
-
this.key = { id: this.
|
|
8634
|
+
this.key = { id: this.signature.hash() };
|
|
8634
8635
|
}
|
|
8635
8636
|
getSignature() {
|
|
8636
8637
|
return this.signature;
|
|
@@ -8650,8 +8651,8 @@ var AxProgramWithSignature = class {
|
|
|
8650
8651
|
}
|
|
8651
8652
|
setId(id) {
|
|
8652
8653
|
this.key = { id, custom: true };
|
|
8653
|
-
for (const child of this.children) {
|
|
8654
|
-
child
|
|
8654
|
+
for (const child of Array.from(this.children)) {
|
|
8655
|
+
child?.setParentId(id);
|
|
8655
8656
|
}
|
|
8656
8657
|
}
|
|
8657
8658
|
setParentId(parentId) {
|
|
@@ -8664,8 +8665,8 @@ var AxProgramWithSignature = class {
|
|
|
8664
8665
|
if (!("programId" in examples)) {
|
|
8665
8666
|
return;
|
|
8666
8667
|
}
|
|
8667
|
-
for (const child of this.children) {
|
|
8668
|
-
child
|
|
8668
|
+
for (const child of Array.from(this.children)) {
|
|
8669
|
+
child?.setExamples(examples, options);
|
|
8669
8670
|
}
|
|
8670
8671
|
}
|
|
8671
8672
|
_setExamples(examples, options) {
|
|
@@ -8698,30 +8699,37 @@ var AxProgramWithSignature = class {
|
|
|
8698
8699
|
if (this.trace) {
|
|
8699
8700
|
traces.push({ trace: this.trace, programId: this.key.id });
|
|
8700
8701
|
}
|
|
8701
|
-
for (const child of this.children) {
|
|
8702
|
-
const _traces = child
|
|
8703
|
-
traces = [...traces, ..._traces];
|
|
8702
|
+
for (const child of Array.from(this.children)) {
|
|
8703
|
+
const _traces = child?.getTraces();
|
|
8704
|
+
traces = [...traces, ..._traces ?? []];
|
|
8704
8705
|
}
|
|
8705
8706
|
return traces;
|
|
8706
8707
|
}
|
|
8707
8708
|
getUsage() {
|
|
8708
8709
|
let usage = [...this.usage ?? []];
|
|
8709
|
-
for (const child of this.children) {
|
|
8710
|
-
const cu = child
|
|
8711
|
-
usage = [...usage, ...cu];
|
|
8710
|
+
for (const child of Array.from(this.children)) {
|
|
8711
|
+
const cu = child?.getUsage();
|
|
8712
|
+
usage = [...usage, ...cu ?? []];
|
|
8712
8713
|
}
|
|
8713
8714
|
return mergeProgramUsage(usage);
|
|
8714
8715
|
}
|
|
8715
8716
|
resetUsage() {
|
|
8716
8717
|
this.usage = [];
|
|
8717
|
-
for (const child of this.children) {
|
|
8718
|
-
child
|
|
8718
|
+
for (const child of Array.from(this.children)) {
|
|
8719
|
+
child?.resetUsage();
|
|
8719
8720
|
}
|
|
8720
8721
|
}
|
|
8721
8722
|
setDemos(demos) {
|
|
8723
|
+
const hasChildren = Array.from(this.children).length > 0;
|
|
8724
|
+
const hasMatchingDemo = demos.some((demo) => demo.programId === this.key.id);
|
|
8725
|
+
if (hasChildren && !hasMatchingDemo) {
|
|
8726
|
+
throw new Error(
|
|
8727
|
+
`Program with id '${this.key.id}' has children but no matching programId found in demos`
|
|
8728
|
+
);
|
|
8729
|
+
}
|
|
8722
8730
|
this.demos = demos.filter((v) => v.programId === this.key.id).map((v) => v.traces).flat();
|
|
8723
|
-
for (const child of this.children) {
|
|
8724
|
-
child
|
|
8731
|
+
for (const child of Array.from(this.children)) {
|
|
8732
|
+
child?.setDemos(demos);
|
|
8725
8733
|
}
|
|
8726
8734
|
}
|
|
8727
8735
|
};
|
|
@@ -8749,8 +8757,8 @@ var AxProgram = class {
|
|
|
8749
8757
|
}
|
|
8750
8758
|
setId(id) {
|
|
8751
8759
|
this.key = { id, custom: true };
|
|
8752
|
-
for (const child of this.children) {
|
|
8753
|
-
child
|
|
8760
|
+
for (const child of Array.from(this.children)) {
|
|
8761
|
+
child?.setParentId(id);
|
|
8754
8762
|
}
|
|
8755
8763
|
}
|
|
8756
8764
|
setParentId(parentId) {
|
|
@@ -8762,8 +8770,8 @@ var AxProgram = class {
|
|
|
8762
8770
|
if (!("programId" in examples)) {
|
|
8763
8771
|
return;
|
|
8764
8772
|
}
|
|
8765
|
-
for (const child of this.children) {
|
|
8766
|
-
child
|
|
8773
|
+
for (const child of Array.from(this.children)) {
|
|
8774
|
+
child?.setExamples(examples, options);
|
|
8767
8775
|
}
|
|
8768
8776
|
}
|
|
8769
8777
|
getTraces() {
|
|
@@ -8771,29 +8779,36 @@ var AxProgram = class {
|
|
|
8771
8779
|
if (this.trace) {
|
|
8772
8780
|
traces.push({ trace: this.trace, programId: this.key.id });
|
|
8773
8781
|
}
|
|
8774
|
-
for (const child of this.children) {
|
|
8775
|
-
const _traces = child
|
|
8776
|
-
traces = [...traces, ..._traces];
|
|
8782
|
+
for (const child of Array.from(this.children)) {
|
|
8783
|
+
const _traces = child?.getTraces();
|
|
8784
|
+
traces = [...traces, ..._traces ?? []];
|
|
8777
8785
|
}
|
|
8778
8786
|
return traces;
|
|
8779
8787
|
}
|
|
8780
8788
|
getUsage() {
|
|
8781
8789
|
let usage = [...this.usage ?? []];
|
|
8782
|
-
for (const child of this.children) {
|
|
8783
|
-
const cu = child
|
|
8784
|
-
usage = [...usage, ...cu];
|
|
8790
|
+
for (const child of Array.from(this.children)) {
|
|
8791
|
+
const cu = child?.getUsage();
|
|
8792
|
+
usage = [...usage, ...cu ?? []];
|
|
8785
8793
|
}
|
|
8786
8794
|
return mergeProgramUsage(usage);
|
|
8787
8795
|
}
|
|
8788
8796
|
resetUsage() {
|
|
8789
8797
|
this.usage = [];
|
|
8790
|
-
for (const child of this.children) {
|
|
8791
|
-
child
|
|
8798
|
+
for (const child of Array.from(this.children)) {
|
|
8799
|
+
child?.resetUsage();
|
|
8792
8800
|
}
|
|
8793
8801
|
}
|
|
8794
8802
|
setDemos(demos) {
|
|
8795
|
-
|
|
8796
|
-
|
|
8803
|
+
const hasChildren = Array.from(this.children).length > 0;
|
|
8804
|
+
const hasMatchingDemo = demos.some((demo) => demo.programId === this.key.id);
|
|
8805
|
+
if (hasChildren && !hasMatchingDemo) {
|
|
8806
|
+
throw new Error(
|
|
8807
|
+
`Program with id '${this.key.id}' has children but no matching programId found in demos`
|
|
8808
|
+
);
|
|
8809
|
+
}
|
|
8810
|
+
for (const child of Array.from(this.children)) {
|
|
8811
|
+
child?.setDemos(demos);
|
|
8797
8812
|
}
|
|
8798
8813
|
}
|
|
8799
8814
|
};
|
|
@@ -8931,7 +8946,7 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
8931
8946
|
traceContext,
|
|
8932
8947
|
firstStep
|
|
8933
8948
|
});
|
|
8934
|
-
if (res instanceof
|
|
8949
|
+
if (res instanceof ReadableStream2) {
|
|
8935
8950
|
yield* this.processStreamingResponse({
|
|
8936
8951
|
ai,
|
|
8937
8952
|
model,
|
|
@@ -9521,7 +9536,9 @@ var AxAgent = class {
|
|
|
9521
9536
|
description: definition ?? description
|
|
9522
9537
|
});
|
|
9523
9538
|
for (const agent of agents ?? []) {
|
|
9524
|
-
this.program.register(
|
|
9539
|
+
this.program.register(
|
|
9540
|
+
agent
|
|
9541
|
+
);
|
|
9525
9542
|
}
|
|
9526
9543
|
this.name = name;
|
|
9527
9544
|
this.func = {
|
|
@@ -10025,6 +10042,673 @@ function validateModels2(services) {
|
|
|
10025
10042
|
}
|
|
10026
10043
|
}
|
|
10027
10044
|
|
|
10045
|
+
// dsp/optimizer.ts
|
|
10046
|
+
/**
 * In-memory tracker for token usage and estimated spend, keyed by model name.
 *
 * Options (all optional):
 *  - costPerModel: map of model name -> cost per 1,000 tokens
 *  - maxCost:      spend ceiling; reaching it makes isLimitReached() true
 *  - maxTokens:    token ceiling; reaching it makes isLimitReached() true
 */
var AxDefaultCostTracker = class {
  // Tokens recorded so far, per model name.
  tokenUsage = {};
  // Tokens recorded so far across all models.
  totalTokens = 0;
  // Configuration options
  costPerModel;
  maxCost;
  maxTokens;
  constructor(options) {
    this.costPerModel = options?.costPerModel ?? {};
    this.maxCost = options?.maxCost;
    this.maxTokens = options?.maxTokens;
  }
  /** Record `count` tokens consumed by `model`. */
  trackTokens(count, model) {
    this.tokenUsage[model] = (this.tokenUsage[model] ?? 0) + count;
    this.totalTokens += count;
  }
  /**
   * Estimated spend for everything tracked so far.
   * Models without a configured price are charged 0.001 per 1,000 tokens.
   */
  getCurrentCost() {
    let totalCost = 0;
    for (const [model, tokens] of Object.entries(this.tokenUsage)) {
      // `??` rather than `||`: an explicitly configured price of 0 (a free
      // model) must stay 0 instead of silently falling back to the default.
      const costPer1K = this.costPerModel[model] ?? 1e-3;
      totalCost += tokens / 1e3 * costPer1K;
    }
    return totalCost;
  }
  /** Shallow copy of the per-model token counts. */
  getTokenUsage() {
    return { ...this.tokenUsage };
  }
  /** Total tokens recorded across all models. */
  getTotalTokens() {
    return this.totalTokens;
  }
  /** True once either the token ceiling or the cost ceiling has been reached. */
  isLimitReached() {
    if (this.maxTokens !== void 0 && this.totalTokens >= this.maxTokens) {
      return true;
    }
    if (this.maxCost !== void 0 && this.getCurrentCost() >= this.maxCost) {
      return true;
    }
    return false;
  }
  /** Forget all recorded usage; the configured limits and prices are kept. */
  reset() {
    this.tokenUsage = {};
    this.totalTokens = 0;
  }
};
|
|
10093
|
+
var AxBaseOptimizer = class {
|
|
10094
|
+
// Common AxOptimizerArgs fields
|
|
10095
|
+
studentAI;
|
|
10096
|
+
teacherAI;
|
|
10097
|
+
examples;
|
|
10098
|
+
validationSet;
|
|
10099
|
+
targetScore;
|
|
10100
|
+
minSuccessRate;
|
|
10101
|
+
onProgress;
|
|
10102
|
+
onEarlyStop;
|
|
10103
|
+
costTracker;
|
|
10104
|
+
seed;
|
|
10105
|
+
// Checkpointing fields
|
|
10106
|
+
checkpointSave;
|
|
10107
|
+
checkpointLoad;
|
|
10108
|
+
checkpointInterval;
|
|
10109
|
+
resumeFromCheckpoint;
|
|
10110
|
+
// Checkpoint state
|
|
10111
|
+
currentRound = 0;
|
|
10112
|
+
scoreHistory = [];
|
|
10113
|
+
configurationHistory = [];
|
|
10114
|
+
// Common optimization statistics
|
|
10115
|
+
stats;
|
|
10116
|
+
constructor(args) {
|
|
10117
|
+
if (args.examples.length === 0) {
|
|
10118
|
+
throw new Error("No examples found");
|
|
10119
|
+
}
|
|
10120
|
+
this.studentAI = args.studentAI;
|
|
10121
|
+
this.teacherAI = args.teacherAI;
|
|
10122
|
+
this.examples = args.examples;
|
|
10123
|
+
this.validationSet = args.validationSet;
|
|
10124
|
+
this.targetScore = args.targetScore;
|
|
10125
|
+
this.minSuccessRate = args.minSuccessRate;
|
|
10126
|
+
this.onProgress = args.onProgress;
|
|
10127
|
+
this.onEarlyStop = args.onEarlyStop;
|
|
10128
|
+
this.seed = args.seed;
|
|
10129
|
+
this.checkpointSave = args.checkpointSave;
|
|
10130
|
+
this.checkpointLoad = args.checkpointLoad;
|
|
10131
|
+
this.checkpointInterval = args.checkpointInterval ?? 10;
|
|
10132
|
+
this.resumeFromCheckpoint = args.resumeFromCheckpoint;
|
|
10133
|
+
const costTracker = new AxDefaultCostTracker({
|
|
10134
|
+
maxTokens: 1e6
|
|
10135
|
+
});
|
|
10136
|
+
this.costTracker = args.costTracker ?? costTracker;
|
|
10137
|
+
this.stats = this.initializeStats();
|
|
10138
|
+
}
|
|
10139
|
+
/**
|
|
10140
|
+
* Initialize the optimization statistics structure
|
|
10141
|
+
*/
|
|
10142
|
+
initializeStats() {
|
|
10143
|
+
return {
|
|
10144
|
+
totalCalls: 0,
|
|
10145
|
+
successfulDemos: 0,
|
|
10146
|
+
estimatedTokenUsage: 0,
|
|
10147
|
+
earlyStopped: false,
|
|
10148
|
+
resourceUsage: {
|
|
10149
|
+
totalTokens: 0,
|
|
10150
|
+
totalTime: 0,
|
|
10151
|
+
avgLatencyPerEval: 0,
|
|
10152
|
+
costByModel: {}
|
|
10153
|
+
},
|
|
10154
|
+
convergenceInfo: {
|
|
10155
|
+
converged: false,
|
|
10156
|
+
finalImprovement: 0,
|
|
10157
|
+
stagnationRounds: 0,
|
|
10158
|
+
convergenceThreshold: 0.01
|
|
10159
|
+
}
|
|
10160
|
+
};
|
|
10161
|
+
}
|
|
10162
|
+
/**
|
|
10163
|
+
* Set up reproducible random seed if provided
|
|
10164
|
+
*/
|
|
10165
|
+
setupRandomSeed() {
|
|
10166
|
+
if (this.seed !== void 0) {
|
|
10167
|
+
Math.random = (() => {
|
|
10168
|
+
let seed = this.seed;
|
|
10169
|
+
return () => {
|
|
10170
|
+
seed = (seed * 9301 + 49297) % 233280;
|
|
10171
|
+
return seed / 233280;
|
|
10172
|
+
};
|
|
10173
|
+
})();
|
|
10174
|
+
}
|
|
10175
|
+
}
|
|
10176
|
+
/**
|
|
10177
|
+
* Check if optimization should stop early due to cost limits
|
|
10178
|
+
*/
|
|
10179
|
+
checkCostLimits() {
|
|
10180
|
+
return this.costTracker?.isLimitReached() ?? false;
|
|
10181
|
+
}
|
|
10182
|
+
/**
|
|
10183
|
+
* Check if target score has been reached
|
|
10184
|
+
*/
|
|
10185
|
+
checkTargetScore(currentScore) {
|
|
10186
|
+
return this.targetScore !== void 0 && currentScore >= this.targetScore;
|
|
10187
|
+
}
|
|
10188
|
+
/**
|
|
10189
|
+
* Update resource usage statistics
|
|
10190
|
+
*/
|
|
10191
|
+
updateResourceUsage(startTime, tokensUsed = 0) {
|
|
10192
|
+
this.stats.resourceUsage.totalTime = Date.now() - startTime;
|
|
10193
|
+
this.stats.resourceUsage.totalTokens += tokensUsed;
|
|
10194
|
+
if (this.stats.totalCalls > 0) {
|
|
10195
|
+
this.stats.resourceUsage.avgLatencyPerEval = this.stats.resourceUsage.totalTime / this.stats.totalCalls;
|
|
10196
|
+
}
|
|
10197
|
+
}
|
|
10198
|
+
/**
|
|
10199
|
+
* Trigger early stopping with appropriate callbacks
|
|
10200
|
+
*/
|
|
10201
|
+
triggerEarlyStopping(reason, bestScoreRound) {
|
|
10202
|
+
this.stats.earlyStopped = true;
|
|
10203
|
+
this.stats.earlyStopping = {
|
|
10204
|
+
bestScoreRound,
|
|
10205
|
+
patienceExhausted: reason.includes("improvement"),
|
|
10206
|
+
reason
|
|
10207
|
+
};
|
|
10208
|
+
if (this.onEarlyStop) {
|
|
10209
|
+
this.onEarlyStop(reason, this.stats);
|
|
10210
|
+
}
|
|
10211
|
+
}
|
|
10212
|
+
/**
|
|
10213
|
+
* Get the validation set, with fallback to a split of examples
|
|
10214
|
+
*/
|
|
10215
|
+
getValidationSet(options) {
|
|
10216
|
+
return options?.overrideValidationSet || this.validationSet || this.examples.slice(0, Math.floor(this.examples.length * 0.2));
|
|
10217
|
+
}
|
|
10218
|
+
/**
|
|
10219
|
+
* Get the AI service to use for a specific task, preferring teacher when available
|
|
10220
|
+
* @param preferTeacher Whether to prefer teacher AI over student AI
|
|
10221
|
+
* @param options Optional compile options that may override teacher AI
|
|
10222
|
+
* @returns The appropriate AI service to use
|
|
10223
|
+
*/
|
|
10224
|
+
getAIService(preferTeacher = false, options) {
|
|
10225
|
+
if (preferTeacher && options?.overrideTeacherAI) {
|
|
10226
|
+
return options.overrideTeacherAI;
|
|
10227
|
+
}
|
|
10228
|
+
if (preferTeacher && this.teacherAI) {
|
|
10229
|
+
return this.teacherAI;
|
|
10230
|
+
}
|
|
10231
|
+
return this.studentAI;
|
|
10232
|
+
}
|
|
10233
|
+
/**
|
|
10234
|
+
* Check if teacher AI is available (including overrides)
|
|
10235
|
+
* @param options Optional compile options that may override teacher AI
|
|
10236
|
+
* @returns True if teacher AI is configured or overridden
|
|
10237
|
+
*/
|
|
10238
|
+
hasTeacherAI(options) {
|
|
10239
|
+
return options?.overrideTeacherAI !== void 0 || this.teacherAI !== void 0;
|
|
10240
|
+
}
|
|
10241
|
+
/**
|
|
10242
|
+
* Get teacher AI if available, otherwise return student AI
|
|
10243
|
+
* @param options Optional compile options that may override teacher AI
|
|
10244
|
+
* @returns Teacher AI if available, otherwise student AI
|
|
10245
|
+
*/
|
|
10246
|
+
getTeacherOrStudentAI(options) {
|
|
10247
|
+
return options?.overrideTeacherAI || this.teacherAI || this.studentAI;
|
|
10248
|
+
}
|
|
10249
|
+
/**
|
|
10250
|
+
* Execute a task with teacher AI if available, otherwise use student AI
|
|
10251
|
+
* @param task Function that takes an AI service and returns a promise
|
|
10252
|
+
* @param preferTeacher Whether to prefer teacher AI (default: true)
|
|
10253
|
+
* @param options Optional compile options that may override teacher AI
|
|
10254
|
+
* @returns Result of the task execution
|
|
10255
|
+
*/
|
|
10256
|
+
async executeWithTeacher(task, preferTeacher = true, options) {
|
|
10257
|
+
const ai = this.getAIService(preferTeacher, options);
|
|
10258
|
+
return await task(ai);
|
|
10259
|
+
}
|
|
10260
|
+
/**
|
|
10261
|
+
* Get current optimization statistics
|
|
10262
|
+
*/
|
|
10263
|
+
getStats() {
|
|
10264
|
+
return { ...this.stats };
|
|
10265
|
+
}
|
|
10266
|
+
/**
|
|
10267
|
+
* Reset optimizer state for reuse with different programs
|
|
10268
|
+
*/
|
|
10269
|
+
reset() {
|
|
10270
|
+
this.stats = this.initializeStats();
|
|
10271
|
+
this.costTracker?.reset();
|
|
10272
|
+
this.currentRound = 0;
|
|
10273
|
+
this.scoreHistory = [];
|
|
10274
|
+
this.configurationHistory = [];
|
|
10275
|
+
}
|
|
10276
|
+
/**
|
|
10277
|
+
* Basic program validation that can be extended by concrete optimizers
|
|
10278
|
+
*/
|
|
10279
|
+
validateProgram(program) {
|
|
10280
|
+
const issues = [];
|
|
10281
|
+
const suggestions = [];
|
|
10282
|
+
if (!("forward" in program) || typeof program.forward !== "function") {
|
|
10283
|
+
issues.push("Program must have a forward method");
|
|
10284
|
+
}
|
|
10285
|
+
if (this.examples.length < 2) {
|
|
10286
|
+
issues.push("Need at least 2 examples for optimization");
|
|
10287
|
+
suggestions.push("Provide more training examples");
|
|
10288
|
+
}
|
|
10289
|
+
const valSetSize = this.getValidationSet().length;
|
|
10290
|
+
if (valSetSize < 1) {
|
|
10291
|
+
issues.push("Validation set is empty");
|
|
10292
|
+
suggestions.push("Provide examples or a validation set");
|
|
10293
|
+
}
|
|
10294
|
+
return {
|
|
10295
|
+
isValid: issues.length === 0,
|
|
10296
|
+
issues,
|
|
10297
|
+
suggestions
|
|
10298
|
+
};
|
|
10299
|
+
}
|
|
10300
|
+
/**
|
|
10301
|
+
* Multi-objective optimization using Pareto frontier
|
|
10302
|
+
* Default implementation that leverages the single-objective compile method
|
|
10303
|
+
* @param program The program to optimize
|
|
10304
|
+
* @param metricFn Multi-objective metric function that returns multiple scores
|
|
10305
|
+
* @param options Optional configuration options
|
|
10306
|
+
* @returns Pareto optimization result with frontier of non-dominated solutions
|
|
10307
|
+
*/
|
|
10308
|
+
async compilePareto(program, metricFn, options) {
|
|
10309
|
+
const startTime = Date.now();
|
|
10310
|
+
if (options?.verbose) {
|
|
10311
|
+
console.log("Starting Pareto optimization using base implementation");
|
|
10312
|
+
console.log("This will run multiple single-objective optimizations");
|
|
10313
|
+
}
|
|
10314
|
+
const solutions = await this.generateWeightedSolutions(
|
|
10315
|
+
program,
|
|
10316
|
+
metricFn,
|
|
10317
|
+
options
|
|
10318
|
+
);
|
|
10319
|
+
const constraintSolutions = await this.generateConstraintSolutions(
|
|
10320
|
+
program,
|
|
10321
|
+
metricFn,
|
|
10322
|
+
options
|
|
10323
|
+
);
|
|
10324
|
+
const allSolutions = [...solutions, ...constraintSolutions];
|
|
10325
|
+
if (options?.verbose) {
|
|
10326
|
+
console.log(`Generated ${allSolutions.length} candidate solutions`);
|
|
10327
|
+
}
|
|
10328
|
+
const paretoFront = this.findParetoFrontier(allSolutions);
|
|
10329
|
+
const hypervolume = this.calculateHypervolume(paretoFront);
|
|
10330
|
+
if (options?.verbose) {
|
|
10331
|
+
console.log(`Found ${paretoFront.length} non-dominated solutions`);
|
|
10332
|
+
console.log(`Hypervolume: ${hypervolume?.toFixed(4) || "N/A"}`);
|
|
10333
|
+
}
|
|
10334
|
+
this.updateResourceUsage(startTime);
|
|
10335
|
+
this.stats.convergenceInfo.converged = true;
|
|
10336
|
+
const bestScore = paretoFront.length > 0 ? Math.max(
|
|
10337
|
+
...paretoFront.map((sol) => Math.max(...Object.values(sol.scores)))
|
|
10338
|
+
) : 0;
|
|
10339
|
+
return {
|
|
10340
|
+
demos: paretoFront.length > 0 ? [...paretoFront[0].demos] : void 0,
|
|
10341
|
+
stats: this.stats,
|
|
10342
|
+
bestScore,
|
|
10343
|
+
paretoFront,
|
|
10344
|
+
hypervolume,
|
|
10345
|
+
paretoFrontSize: paretoFront.length,
|
|
10346
|
+
finalConfiguration: {
|
|
10347
|
+
paretoFrontSize: paretoFront.length,
|
|
10348
|
+
hypervolume,
|
|
10349
|
+
strategy: "weighted_combinations_and_constraints",
|
|
10350
|
+
numSolutions: allSolutions.length
|
|
10351
|
+
}
|
|
10352
|
+
};
|
|
10353
|
+
}
|
|
10354
|
+
/**
|
|
10355
|
+
* Generate solutions using different weighted combinations of objectives
|
|
10356
|
+
*/
|
|
10357
|
+
async generateWeightedSolutions(program, metricFn, options) {
|
|
10358
|
+
const solutions = [];
|
|
10359
|
+
const sampleExample = this.examples[0];
|
|
10360
|
+
const samplePrediction = await program.forward(
|
|
10361
|
+
this.studentAI,
|
|
10362
|
+
sampleExample
|
|
10363
|
+
);
|
|
10364
|
+
const sampleScores = await metricFn({
|
|
10365
|
+
prediction: samplePrediction,
|
|
10366
|
+
example: sampleExample
|
|
10367
|
+
});
|
|
10368
|
+
const objectives = Object.keys(sampleScores);
|
|
10369
|
+
if (options?.verbose) {
|
|
10370
|
+
console.log(`Detected objectives: ${objectives.join(", ")}`);
|
|
10371
|
+
}
|
|
10372
|
+
const weightCombinations = this.generateWeightCombinations(objectives);
|
|
10373
|
+
for (let i = 0; i < weightCombinations.length; i++) {
|
|
10374
|
+
const weights = weightCombinations[i];
|
|
10375
|
+
if (options?.verbose) {
|
|
10376
|
+
console.log(`Optimizing with weights: ${JSON.stringify(weights)}`);
|
|
10377
|
+
}
|
|
10378
|
+
const weightedMetric = async ({ prediction, example }) => {
|
|
10379
|
+
const scores = await metricFn({ prediction, example });
|
|
10380
|
+
let weightedScore = 0;
|
|
10381
|
+
for (const [objective, score] of Object.entries(scores)) {
|
|
10382
|
+
weightedScore += score * (weights[objective] || 0);
|
|
10383
|
+
}
|
|
10384
|
+
return weightedScore;
|
|
10385
|
+
};
|
|
10386
|
+
try {
|
|
10387
|
+
const result = await this.compile(program, weightedMetric, {
|
|
10388
|
+
...options,
|
|
10389
|
+
verbose: false
|
|
10390
|
+
// Suppress inner optimization logs
|
|
10391
|
+
});
|
|
10392
|
+
const scores = await this.evaluateWithMultiObjective(
|
|
10393
|
+
program,
|
|
10394
|
+
result,
|
|
10395
|
+
metricFn
|
|
10396
|
+
);
|
|
10397
|
+
solutions.push({
|
|
10398
|
+
scores,
|
|
10399
|
+
demos: result.demos,
|
|
10400
|
+
configuration: {
|
|
10401
|
+
...result.finalConfiguration,
|
|
10402
|
+
weights,
|
|
10403
|
+
strategy: "weighted_combination"
|
|
10404
|
+
}
|
|
10405
|
+
});
|
|
10406
|
+
} catch (error) {
|
|
10407
|
+
if (options?.verbose) {
|
|
10408
|
+
console.warn(
|
|
10409
|
+
`Failed optimization with weights ${JSON.stringify(weights)}:`,
|
|
10410
|
+
error
|
|
10411
|
+
);
|
|
10412
|
+
}
|
|
10413
|
+
continue;
|
|
10414
|
+
}
|
|
10415
|
+
}
|
|
10416
|
+
return solutions;
|
|
10417
|
+
}
|
|
10418
|
+
/**
|
|
10419
|
+
* Generate solutions using constraint-based optimization
|
|
10420
|
+
*/
|
|
10421
|
+
async generateConstraintSolutions(program, metricFn, options) {
|
|
10422
|
+
const solutions = [];
|
|
10423
|
+
const sampleExample = this.examples[0];
|
|
10424
|
+
const samplePrediction = await program.forward(
|
|
10425
|
+
this.studentAI,
|
|
10426
|
+
sampleExample
|
|
10427
|
+
);
|
|
10428
|
+
const sampleScores = await metricFn({
|
|
10429
|
+
prediction: samplePrediction,
|
|
10430
|
+
example: sampleExample
|
|
10431
|
+
});
|
|
10432
|
+
const objectives = Object.keys(sampleScores);
|
|
10433
|
+
for (const primaryObjective of objectives) {
|
|
10434
|
+
if (options?.verbose) {
|
|
10435
|
+
console.log(
|
|
10436
|
+
`Optimizing ${primaryObjective} with constraints on other objectives`
|
|
10437
|
+
);
|
|
10438
|
+
}
|
|
10439
|
+
const constraintMetric = async ({ prediction, example }) => {
|
|
10440
|
+
const scores = await metricFn({ prediction, example });
|
|
10441
|
+
const primaryScore = scores[primaryObjective] || 0;
|
|
10442
|
+
let penalty = 0;
|
|
10443
|
+
for (const [objective, score] of Object.entries(scores)) {
|
|
10444
|
+
if (objective !== primaryObjective) {
|
|
10445
|
+
if (score < 0.3) {
|
|
10446
|
+
penalty += (0.3 - score) * 2;
|
|
10447
|
+
}
|
|
10448
|
+
}
|
|
10449
|
+
}
|
|
10450
|
+
return primaryScore - penalty;
|
|
10451
|
+
};
|
|
10452
|
+
try {
|
|
10453
|
+
const result = await this.compile(program, constraintMetric, {
|
|
10454
|
+
...options,
|
|
10455
|
+
verbose: false
|
|
10456
|
+
});
|
|
10457
|
+
const scores = await this.evaluateWithMultiObjective(
|
|
10458
|
+
program,
|
|
10459
|
+
result,
|
|
10460
|
+
metricFn
|
|
10461
|
+
);
|
|
10462
|
+
solutions.push({
|
|
10463
|
+
scores,
|
|
10464
|
+
demos: result.demos,
|
|
10465
|
+
configuration: {
|
|
10466
|
+
...result.finalConfiguration,
|
|
10467
|
+
primaryObjective,
|
|
10468
|
+
strategy: "constraint_based"
|
|
10469
|
+
}
|
|
10470
|
+
});
|
|
10471
|
+
} catch (error) {
|
|
10472
|
+
if (options?.verbose) {
|
|
10473
|
+
console.warn(
|
|
10474
|
+
`Failed constraint optimization for ${primaryObjective}:`,
|
|
10475
|
+
error
|
|
10476
|
+
);
|
|
10477
|
+
}
|
|
10478
|
+
continue;
|
|
10479
|
+
}
|
|
10480
|
+
}
|
|
10481
|
+
return solutions;
|
|
10482
|
+
}
|
|
10483
|
+
/**
|
|
10484
|
+
* Generate different weight combinations for objectives
|
|
10485
|
+
*/
|
|
10486
|
+
generateWeightCombinations(objectives) {
|
|
10487
|
+
const combinations = [];
|
|
10488
|
+
for (const objective of objectives) {
|
|
10489
|
+
const weights = {};
|
|
10490
|
+
for (const obj of objectives) {
|
|
10491
|
+
weights[obj] = obj === objective ? 1 : 0;
|
|
10492
|
+
}
|
|
10493
|
+
combinations.push(weights);
|
|
10494
|
+
}
|
|
10495
|
+
const equalWeights = {};
|
|
10496
|
+
for (const objective of objectives) {
|
|
10497
|
+
equalWeights[objective] = 1 / objectives.length;
|
|
10498
|
+
}
|
|
10499
|
+
combinations.push(equalWeights);
|
|
10500
|
+
if (objectives.length === 2) {
|
|
10501
|
+
const [obj1, obj2] = objectives;
|
|
10502
|
+
for (let w1 = 0.1; w1 <= 0.9; w1 += 0.2) {
|
|
10503
|
+
const w2 = 1 - w1;
|
|
10504
|
+
combinations.push({ [obj1]: w1, [obj2]: w2 });
|
|
10505
|
+
}
|
|
10506
|
+
}
|
|
10507
|
+
if (objectives.length === 3) {
|
|
10508
|
+
const [obj1, obj2, obj3] = objectives;
|
|
10509
|
+
combinations.push(
|
|
10510
|
+
{ [obj1]: 0.5, [obj2]: 0.3, [obj3]: 0.2 },
|
|
10511
|
+
{ [obj1]: 0.3, [obj2]: 0.5, [obj3]: 0.2 },
|
|
10512
|
+
{ [obj1]: 0.2, [obj2]: 0.3, [obj3]: 0.5 }
|
|
10513
|
+
);
|
|
10514
|
+
}
|
|
10515
|
+
return combinations;
|
|
10516
|
+
}
|
|
10517
|
+
/**
|
|
10518
|
+
* Evaluate a single-objective result with multi-objective metrics
|
|
10519
|
+
*/
|
|
10520
|
+
async evaluateWithMultiObjective(program, result, metricFn) {
|
|
10521
|
+
const valSet = this.getValidationSet();
|
|
10522
|
+
const allScores = {};
|
|
10523
|
+
const testProgram = { ...program };
|
|
10524
|
+
if (result.demos && "setDemos" in testProgram) {
|
|
10525
|
+
;
|
|
10526
|
+
testProgram.setDemos(result.demos);
|
|
10527
|
+
}
|
|
10528
|
+
const evalSet = valSet.slice(0, Math.min(5, valSet.length));
|
|
10529
|
+
for (const example of evalSet) {
|
|
10530
|
+
try {
|
|
10531
|
+
const prediction = await testProgram.forward(
|
|
10532
|
+
this.studentAI,
|
|
10533
|
+
example
|
|
10534
|
+
);
|
|
10535
|
+
const scores = await metricFn({ prediction, example });
|
|
10536
|
+
for (const [objective, score] of Object.entries(scores)) {
|
|
10537
|
+
if (!allScores[objective]) {
|
|
10538
|
+
allScores[objective] = [];
|
|
10539
|
+
}
|
|
10540
|
+
allScores[objective].push(score);
|
|
10541
|
+
}
|
|
10542
|
+
} catch {
|
|
10543
|
+
continue;
|
|
10544
|
+
}
|
|
10545
|
+
}
|
|
10546
|
+
const avgScores = {};
|
|
10547
|
+
for (const [objective, scores] of Object.entries(allScores)) {
|
|
10548
|
+
avgScores[objective] = scores.length > 0 ? scores.reduce((sum, score) => sum + score, 0) / scores.length : 0;
|
|
10549
|
+
}
|
|
10550
|
+
return avgScores;
|
|
10551
|
+
}
|
|
10552
|
+
/**
|
|
10553
|
+
* Find the Pareto frontier from a set of solutions
|
|
10554
|
+
*/
|
|
10555
|
+
findParetoFrontier(solutions) {
|
|
10556
|
+
const paretoFront = [];
|
|
10557
|
+
for (let i = 0; i < solutions.length; i++) {
|
|
10558
|
+
const solutionA = solutions[i];
|
|
10559
|
+
let isDominated = false;
|
|
10560
|
+
let dominatedCount = 0;
|
|
10561
|
+
for (let j = 0; j < solutions.length; j++) {
|
|
10562
|
+
if (i === j) continue;
|
|
10563
|
+
const solutionB = solutions[j];
|
|
10564
|
+
if (this.dominates(solutionB.scores, solutionA.scores)) {
|
|
10565
|
+
isDominated = true;
|
|
10566
|
+
break;
|
|
10567
|
+
}
|
|
10568
|
+
if (this.dominates(solutionA.scores, solutionB.scores)) {
|
|
10569
|
+
dominatedCount++;
|
|
10570
|
+
}
|
|
10571
|
+
}
|
|
10572
|
+
if (!isDominated) {
|
|
10573
|
+
paretoFront.push({
|
|
10574
|
+
demos: solutionA.demos || [],
|
|
10575
|
+
scores: solutionA.scores,
|
|
10576
|
+
configuration: solutionA.configuration,
|
|
10577
|
+
dominatedSolutions: dominatedCount
|
|
10578
|
+
});
|
|
10579
|
+
}
|
|
10580
|
+
}
|
|
10581
|
+
return paretoFront;
|
|
10582
|
+
}
|
|
10583
|
+
/**
|
|
10584
|
+
* Check if solution A dominates solution B
|
|
10585
|
+
* A dominates B if A is better or equal in all objectives and strictly better in at least one
|
|
10586
|
+
*/
|
|
10587
|
+
dominates(scoresA, scoresB) {
|
|
10588
|
+
const objectives = Object.keys(scoresA);
|
|
10589
|
+
let atLeastAsGood = true;
|
|
10590
|
+
let strictlyBetter = false;
|
|
10591
|
+
for (const objective of objectives) {
|
|
10592
|
+
const scoreA = scoresA[objective] || 0;
|
|
10593
|
+
const scoreB = scoresB[objective] || 0;
|
|
10594
|
+
if (scoreA < scoreB) {
|
|
10595
|
+
atLeastAsGood = false;
|
|
10596
|
+
break;
|
|
10597
|
+
}
|
|
10598
|
+
if (scoreA > scoreB) {
|
|
10599
|
+
strictlyBetter = true;
|
|
10600
|
+
}
|
|
10601
|
+
}
|
|
10602
|
+
return atLeastAsGood && strictlyBetter;
|
|
10603
|
+
}
|
|
10604
|
+
/**
|
|
10605
|
+
* Calculate hypervolume of the Pareto frontier
|
|
10606
|
+
* Simplified implementation using reference point at origin
|
|
10607
|
+
*/
|
|
10608
|
+
calculateHypervolume(paretoFront) {
|
|
10609
|
+
if (paretoFront.length === 0) return void 0;
|
|
10610
|
+
const firstSolution = paretoFront[0];
|
|
10611
|
+
const objectives = Object.keys(firstSolution.scores);
|
|
10612
|
+
if (objectives.length === 2) {
|
|
10613
|
+
const [obj1, obj2] = objectives;
|
|
10614
|
+
let hypervolume = 0;
|
|
10615
|
+
const sortedSolutions = [...paretoFront].sort(
|
|
10616
|
+
(a, b) => (b.scores[obj1] || 0) - (a.scores[obj1] || 0)
|
|
10617
|
+
);
|
|
10618
|
+
let prevScore2 = 0;
|
|
10619
|
+
for (const solution of sortedSolutions) {
|
|
10620
|
+
const score1 = solution.scores[obj1] || 0;
|
|
10621
|
+
const score2 = solution.scores[obj2] || 0;
|
|
10622
|
+
hypervolume += score1 * (score2 - prevScore2);
|
|
10623
|
+
prevScore2 = Math.max(prevScore2, score2);
|
|
10624
|
+
}
|
|
10625
|
+
return hypervolume;
|
|
10626
|
+
}
|
|
10627
|
+
return void 0;
|
|
10628
|
+
}
|
|
10629
|
+
/**
|
|
10630
|
+
* Save current optimization state to checkpoint
|
|
10631
|
+
*/
|
|
10632
|
+
async saveCheckpoint(optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
|
|
10633
|
+
const saveFn = options?.overrideCheckpointSave || this.checkpointSave;
|
|
10634
|
+
if (!saveFn) return void 0;
|
|
10635
|
+
const checkpoint = {
|
|
10636
|
+
version: "1.0.0",
|
|
10637
|
+
timestamp: Date.now(),
|
|
10638
|
+
optimizerType,
|
|
10639
|
+
optimizerConfig,
|
|
10640
|
+
currentRound: this.currentRound,
|
|
10641
|
+
totalRounds: this.stats.resourceUsage.totalTime > 0 ? this.currentRound : 0,
|
|
10642
|
+
bestScore,
|
|
10643
|
+
bestConfiguration,
|
|
10644
|
+
scoreHistory: [...this.scoreHistory],
|
|
10645
|
+
configurationHistory: [...this.configurationHistory],
|
|
10646
|
+
stats: { ...this.stats },
|
|
10647
|
+
optimizerState,
|
|
10648
|
+
examples: this.examples,
|
|
10649
|
+
validationSet: this.validationSet
|
|
10650
|
+
};
|
|
10651
|
+
return await saveFn(checkpoint);
|
|
10652
|
+
}
|
|
10653
|
+
/**
|
|
10654
|
+
* Load optimization state from checkpoint
|
|
10655
|
+
*/
|
|
10656
|
+
async loadCheckpoint(checkpointId, options) {
|
|
10657
|
+
const loadFn = options?.overrideCheckpointLoad || this.checkpointLoad;
|
|
10658
|
+
if (!loadFn) return null;
|
|
10659
|
+
return await loadFn(checkpointId);
|
|
10660
|
+
}
|
|
10661
|
+
/**
|
|
10662
|
+
* Restore optimizer state from checkpoint
|
|
10663
|
+
*/
|
|
10664
|
+
restoreFromCheckpoint(checkpoint) {
|
|
10665
|
+
this.currentRound = checkpoint.currentRound;
|
|
10666
|
+
this.scoreHistory = [...checkpoint.scoreHistory];
|
|
10667
|
+
this.configurationHistory = [...checkpoint.configurationHistory];
|
|
10668
|
+
this.stats = { ...checkpoint.stats };
|
|
10669
|
+
}
|
|
10670
|
+
/**
|
|
10671
|
+
* Check if checkpoint should be saved
|
|
10672
|
+
*/
|
|
10673
|
+
shouldSaveCheckpoint(round, options) {
|
|
10674
|
+
const interval = options?.overrideCheckpointInterval || this.checkpointInterval;
|
|
10675
|
+
return interval !== void 0 && round % interval === 0;
|
|
10676
|
+
}
|
|
10677
|
+
/**
|
|
10678
|
+
* Update optimization progress and handle checkpointing
|
|
10679
|
+
*/
|
|
10680
|
+
async updateOptimizationProgress(round, score, configuration, optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
|
|
10681
|
+
this.currentRound = round;
|
|
10682
|
+
this.scoreHistory.push(score);
|
|
10683
|
+
this.configurationHistory.push(configuration);
|
|
10684
|
+
if (this.shouldSaveCheckpoint(round, options)) {
|
|
10685
|
+
await this.saveCheckpoint(
|
|
10686
|
+
optimizerType,
|
|
10687
|
+
optimizerConfig,
|
|
10688
|
+
bestScore,
|
|
10689
|
+
bestConfiguration,
|
|
10690
|
+
optimizerState,
|
|
10691
|
+
options
|
|
10692
|
+
);
|
|
10693
|
+
}
|
|
10694
|
+
}
|
|
10695
|
+
/**
|
|
10696
|
+
* Save final checkpoint on completion
|
|
10697
|
+
*/
|
|
10698
|
+
async saveFinalCheckpoint(optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
|
|
10699
|
+
if (options?.saveCheckpointOnComplete !== false) {
|
|
10700
|
+
await this.saveCheckpoint(
|
|
10701
|
+
optimizerType,
|
|
10702
|
+
optimizerConfig,
|
|
10703
|
+
bestScore,
|
|
10704
|
+
bestConfiguration,
|
|
10705
|
+
{ ...optimizerState, final: true },
|
|
10706
|
+
options
|
|
10707
|
+
);
|
|
10708
|
+
}
|
|
10709
|
+
}
|
|
10710
|
+
};
|
|
10711
|
+
|
|
10028
10712
|
// db/base.ts
|
|
10029
10713
|
import { SpanKind as SpanKind3 } from "@opentelemetry/api";
|
|
10030
10714
|
var AxDBBase = class {
|
|
@@ -11484,11 +12168,7 @@ var AxMCPStreambleHTTPTransport = class {
|
|
|
11484
12168
|
};
|
|
11485
12169
|
|
|
11486
12170
|
// dsp/optimizers/bootstrapFewshot.ts
|
|
11487
|
-
var AxBootstrapFewShot = class {
|
|
11488
|
-
ai;
|
|
11489
|
-
teacherAI;
|
|
11490
|
-
program;
|
|
11491
|
-
examples;
|
|
12171
|
+
var AxBootstrapFewShot = class extends AxBaseOptimizer {
|
|
11492
12172
|
maxRounds;
|
|
11493
12173
|
maxDemos;
|
|
11494
12174
|
maxExamples;
|
|
@@ -11499,37 +12179,20 @@ var AxBootstrapFewShot = class {
|
|
|
11499
12179
|
verboseMode;
|
|
11500
12180
|
debugMode;
|
|
11501
12181
|
traces = [];
|
|
11502
|
-
|
|
11503
|
-
|
|
11504
|
-
|
|
11505
|
-
|
|
11506
|
-
|
|
11507
|
-
|
|
11508
|
-
|
|
11509
|
-
|
|
11510
|
-
|
|
11511
|
-
|
|
11512
|
-
options
|
|
11513
|
-
|
|
11514
|
-
|
|
11515
|
-
|
|
11516
|
-
}
|
|
11517
|
-
const bootstrapOptions = options;
|
|
11518
|
-
this.maxRounds = bootstrapOptions?.maxRounds ?? 3;
|
|
11519
|
-
this.maxDemos = bootstrapOptions?.maxDemos ?? 4;
|
|
11520
|
-
this.maxExamples = bootstrapOptions?.maxExamples ?? 16;
|
|
11521
|
-
this.batchSize = bootstrapOptions?.batchSize ?? 1;
|
|
11522
|
-
this.earlyStoppingPatience = bootstrapOptions?.earlyStoppingPatience ?? 0;
|
|
11523
|
-
this.costMonitoring = bootstrapOptions?.costMonitoring ?? false;
|
|
11524
|
-
this.maxTokensPerGeneration = bootstrapOptions?.maxTokensPerGeneration ?? 0;
|
|
11525
|
-
this.verboseMode = bootstrapOptions?.verboseMode ?? true;
|
|
11526
|
-
this.debugMode = bootstrapOptions?.debugMode ?? false;
|
|
11527
|
-
this.ai = ai;
|
|
11528
|
-
this.teacherAI = bootstrapOptions?.teacherAI;
|
|
11529
|
-
this.program = program;
|
|
11530
|
-
this.examples = examples;
|
|
11531
|
-
}
|
|
11532
|
-
async compileRound(roundIndex, metricFn, options) {
|
|
12182
|
+
constructor(args) {
|
|
12183
|
+
super(args);
|
|
12184
|
+
const options = args.options || {};
|
|
12185
|
+
this.maxRounds = options.maxRounds ?? 3;
|
|
12186
|
+
this.maxDemos = options.maxDemos ?? 4;
|
|
12187
|
+
this.maxExamples = options.maxExamples ?? 16;
|
|
12188
|
+
this.batchSize = options.batchSize ?? 1;
|
|
12189
|
+
this.earlyStoppingPatience = options.earlyStoppingPatience ?? 0;
|
|
12190
|
+
this.costMonitoring = options.costMonitoring ?? false;
|
|
12191
|
+
this.maxTokensPerGeneration = options.maxTokensPerGeneration ?? 0;
|
|
12192
|
+
this.verboseMode = options.verboseMode ?? true;
|
|
12193
|
+
this.debugMode = options.debugMode ?? false;
|
|
12194
|
+
}
|
|
12195
|
+
async compileRound(program, roundIndex, metricFn, options) {
|
|
11533
12196
|
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
11534
12197
|
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
11535
12198
|
const aiOpt = {
|
|
@@ -11552,20 +12215,20 @@ var AxBootstrapFewShot = class {
|
|
|
11552
12215
|
continue;
|
|
11553
12216
|
}
|
|
11554
12217
|
const exList = examples.filter((e) => e !== ex);
|
|
11555
|
-
|
|
11556
|
-
const aiService = this.
|
|
12218
|
+
program.setExamples(exList);
|
|
12219
|
+
const aiService = this.getTeacherOrStudentAI();
|
|
11557
12220
|
this.stats.totalCalls++;
|
|
11558
12221
|
let res;
|
|
11559
12222
|
let error;
|
|
11560
12223
|
try {
|
|
11561
|
-
res = await
|
|
12224
|
+
res = await program.forward(aiService, ex, aiOpt);
|
|
11562
12225
|
if (this.costMonitoring) {
|
|
11563
12226
|
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
11564
12227
|
}
|
|
11565
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
12228
|
+
const score = await metricFn({ prediction: res, example: ex });
|
|
11566
12229
|
const success = score >= 0.5;
|
|
11567
12230
|
if (success) {
|
|
11568
|
-
this.traces = [...this.traces, ...
|
|
12231
|
+
this.traces = [...this.traces, ...program.getTraces()];
|
|
11569
12232
|
this.stats.successfulDemos++;
|
|
11570
12233
|
}
|
|
11571
12234
|
} catch (err) {
|
|
@@ -11616,13 +12279,15 @@ var AxBootstrapFewShot = class {
|
|
|
11616
12279
|
if (!this.stats.earlyStopping) {
|
|
11617
12280
|
this.stats.earlyStopping = {
|
|
11618
12281
|
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
11619
|
-
patienceExhausted: false
|
|
12282
|
+
patienceExhausted: false,
|
|
12283
|
+
reason: "No improvement detected"
|
|
11620
12284
|
};
|
|
11621
12285
|
} else if (improvement > 0) {
|
|
11622
12286
|
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
11623
12287
|
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
11624
12288
|
this.stats.earlyStopping.patienceExhausted = true;
|
|
11625
12289
|
this.stats.earlyStopped = true;
|
|
12290
|
+
this.stats.earlyStopping.reason = `No improvement for ${this.earlyStoppingPatience} rounds`;
|
|
11626
12291
|
if (this.verboseMode || this.debugMode) {
|
|
11627
12292
|
console.log(
|
|
11628
12293
|
`
|
|
@@ -11633,37 +12298,38 @@ Early stopping triggered after ${roundIndex + 1} rounds. No improvement for ${th
|
|
|
11633
12298
|
}
|
|
11634
12299
|
}
|
|
11635
12300
|
}
|
|
11636
|
-
async compile(metricFn, options) {
|
|
11637
|
-
const
|
|
11638
|
-
const maxRounds = compileOptions?.maxRounds ?? this.maxRounds;
|
|
12301
|
+
async compile(program, metricFn, options) {
|
|
12302
|
+
const maxRounds = options?.maxIterations ?? this.maxRounds;
|
|
11639
12303
|
this.traces = [];
|
|
11640
|
-
this.
|
|
11641
|
-
totalCalls: 0,
|
|
11642
|
-
successfulDemos: 0,
|
|
11643
|
-
estimatedTokenUsage: 0,
|
|
11644
|
-
earlyStopped: false
|
|
11645
|
-
};
|
|
12304
|
+
this.reset();
|
|
11646
12305
|
for (let i = 0; i < maxRounds; i++) {
|
|
11647
|
-
await this.compileRound(i, metricFn,
|
|
12306
|
+
await this.compileRound(program, i, metricFn, options);
|
|
11648
12307
|
if (this.stats.earlyStopped) {
|
|
11649
12308
|
break;
|
|
11650
12309
|
}
|
|
11651
12310
|
}
|
|
11652
12311
|
if (this.traces.length === 0) {
|
|
11653
12312
|
throw new Error(
|
|
11654
|
-
"No demonstrations found. Either
|
|
12313
|
+
"No demonstrations found. Either provide more examples or improve the existing ones."
|
|
11655
12314
|
);
|
|
11656
12315
|
}
|
|
11657
12316
|
const demos = groupTracesByKeys(this.traces);
|
|
12317
|
+
let bestScore = 0;
|
|
12318
|
+
if (this.traces.length > 0) {
|
|
12319
|
+
bestScore = this.stats.successfulDemos / Math.max(1, this.stats.totalCalls);
|
|
12320
|
+
}
|
|
11658
12321
|
return {
|
|
11659
12322
|
demos,
|
|
11660
|
-
stats: this.stats
|
|
12323
|
+
stats: this.stats,
|
|
12324
|
+
bestScore,
|
|
12325
|
+
finalConfiguration: {
|
|
12326
|
+
maxRounds: this.maxRounds,
|
|
12327
|
+
maxDemos: this.maxDemos,
|
|
12328
|
+
batchSize: this.batchSize,
|
|
12329
|
+
successRate: bestScore
|
|
12330
|
+
}
|
|
11661
12331
|
};
|
|
11662
12332
|
}
|
|
11663
|
-
// Get optimization statistics
|
|
11664
|
-
getStats() {
|
|
11665
|
-
return this.stats;
|
|
11666
|
-
}
|
|
11667
12333
|
};
|
|
11668
12334
|
function groupTracesByKeys(programTraces) {
|
|
11669
12335
|
const groupedTraces = /* @__PURE__ */ new Map();
|
|
@@ -11678,9 +12344,12 @@ function groupTracesByKeys(programTraces) {
|
|
|
11678
12344
|
}
|
|
11679
12345
|
}
|
|
11680
12346
|
const programDemosArray = [];
|
|
11681
|
-
|
|
11682
|
-
programDemosArray.push({
|
|
11683
|
-
|
|
12347
|
+
groupedTraces.forEach((traces, programId) => {
|
|
12348
|
+
programDemosArray.push({
|
|
12349
|
+
traces,
|
|
12350
|
+
programId
|
|
12351
|
+
});
|
|
12352
|
+
});
|
|
11684
12353
|
return programDemosArray;
|
|
11685
12354
|
}
|
|
11686
12355
|
var randomSample = (array, n) => {
|
|
@@ -11699,10 +12368,8 @@ var randomSample = (array, n) => {
|
|
|
11699
12368
|
};
|
|
11700
12369
|
|
|
11701
12370
|
// dsp/optimizers/miproV2.ts
|
|
11702
|
-
var AxMiPRO = class {
|
|
11703
|
-
|
|
11704
|
-
program;
|
|
11705
|
-
examples;
|
|
12371
|
+
var AxMiPRO = class extends AxBaseOptimizer {
|
|
12372
|
+
// MiPRO-specific options
|
|
11706
12373
|
maxBootstrappedDemos;
|
|
11707
12374
|
maxLabeledDemos;
|
|
11708
12375
|
numCandidates;
|
|
@@ -11716,52 +12383,35 @@ var AxMiPRO = class {
|
|
|
11716
12383
|
viewDataBatchSize;
|
|
11717
12384
|
tipAwareProposer;
|
|
11718
12385
|
fewshotAwareProposer;
|
|
11719
|
-
seed;
|
|
11720
12386
|
verbose;
|
|
11721
|
-
bootstrapper;
|
|
11722
12387
|
earlyStoppingTrials;
|
|
11723
12388
|
minImprovementThreshold;
|
|
11724
|
-
|
|
11725
|
-
|
|
11726
|
-
|
|
11727
|
-
|
|
11728
|
-
|
|
11729
|
-
|
|
11730
|
-
|
|
11731
|
-
|
|
11732
|
-
|
|
11733
|
-
|
|
11734
|
-
this.
|
|
11735
|
-
this.
|
|
11736
|
-
this.
|
|
11737
|
-
this.
|
|
11738
|
-
this.
|
|
11739
|
-
this.
|
|
11740
|
-
this.
|
|
11741
|
-
this.
|
|
11742
|
-
this.
|
|
11743
|
-
this.
|
|
11744
|
-
this.
|
|
11745
|
-
this.
|
|
11746
|
-
this.
|
|
11747
|
-
this.
|
|
11748
|
-
this.
|
|
11749
|
-
this.
|
|
11750
|
-
this.minImprovementThreshold = miproOptions.minImprovementThreshold ?? 0.01;
|
|
11751
|
-
this.ai = ai;
|
|
11752
|
-
this.program = program;
|
|
11753
|
-
this.examples = examples;
|
|
11754
|
-
this.bootstrapper = new AxBootstrapFewShot({
|
|
11755
|
-
ai,
|
|
11756
|
-
program,
|
|
11757
|
-
examples,
|
|
11758
|
-
options: {
|
|
11759
|
-
maxDemos: this.maxBootstrappedDemos,
|
|
11760
|
-
maxRounds: 3,
|
|
11761
|
-
// Default, or adjust based on your needs
|
|
11762
|
-
verboseMode: this.verbose
|
|
11763
|
-
}
|
|
11764
|
-
});
|
|
12389
|
+
bayesianOptimization;
|
|
12390
|
+
acquisitionFunction;
|
|
12391
|
+
explorationWeight;
|
|
12392
|
+
constructor(args) {
|
|
12393
|
+
super(args);
|
|
12394
|
+
const options = args.options || {};
|
|
12395
|
+
this.numCandidates = options.numCandidates ?? 5;
|
|
12396
|
+
this.initTemperature = options.initTemperature ?? 0.7;
|
|
12397
|
+
this.maxBootstrappedDemos = options.maxBootstrappedDemos ?? 3;
|
|
12398
|
+
this.maxLabeledDemos = options.maxLabeledDemos ?? 4;
|
|
12399
|
+
this.numTrials = options.numTrials ?? 30;
|
|
12400
|
+
this.minibatch = options.minibatch ?? true;
|
|
12401
|
+
this.minibatchSize = options.minibatchSize ?? 25;
|
|
12402
|
+
this.minibatchFullEvalSteps = options.minibatchFullEvalSteps ?? 10;
|
|
12403
|
+
this.programAwareProposer = options.programAwareProposer ?? true;
|
|
12404
|
+
this.dataAwareProposer = options.dataAwareProposer ?? true;
|
|
12405
|
+
this.viewDataBatchSize = options.viewDataBatchSize ?? 10;
|
|
12406
|
+
this.tipAwareProposer = options.tipAwareProposer ?? true;
|
|
12407
|
+
this.fewshotAwareProposer = options.fewshotAwareProposer ?? true;
|
|
12408
|
+
this.verbose = options.verbose ?? false;
|
|
12409
|
+
this.earlyStoppingTrials = options.earlyStoppingTrials ?? 5;
|
|
12410
|
+
this.minImprovementThreshold = options.minImprovementThreshold ?? 0.01;
|
|
12411
|
+
this.bayesianOptimization = options.bayesianOptimization ?? false;
|
|
12412
|
+
this.acquisitionFunction = options.acquisitionFunction ?? "expected_improvement";
|
|
12413
|
+
this.explorationWeight = options.explorationWeight ?? 0.1;
|
|
12414
|
+
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
11765
12415
|
}
|
|
11766
12416
|
/**
|
|
11767
12417
|
* Configures the optimizer for light, medium, or heavy optimization
|
|
@@ -11805,123 +12455,60 @@ var AxMiPRO = class {
|
|
|
11805
12455
|
];
|
|
11806
12456
|
}
|
|
11807
12457
|
/**
|
|
11808
|
-
* Generates instruction candidates
|
|
12458
|
+
* Generates instruction candidates using the teacher model if available
|
|
12459
|
+
* @param options Optional compile options that may override teacher AI
|
|
11809
12460
|
* @returns Array of generated instruction candidates
|
|
11810
12461
|
*/
|
|
11811
|
-
async proposeInstructionCandidates() {
|
|
12462
|
+
async proposeInstructionCandidates(options) {
|
|
11812
12463
|
const instructions = [];
|
|
11813
|
-
|
|
11814
|
-
if (this.programAwareProposer) {
|
|
11815
|
-
programContext = await this.generateProgramSummary();
|
|
11816
|
-
}
|
|
11817
|
-
let dataContext = "";
|
|
11818
|
-
if (this.dataAwareProposer) {
|
|
11819
|
-
dataContext = await this.generateDataSummary();
|
|
11820
|
-
}
|
|
12464
|
+
const aiToUse = this.getTeacherOrStudentAI(options);
|
|
11821
12465
|
const tips = this.tipAwareProposer ? this.generateTips() : [];
|
|
11822
12466
|
for (let i = 0; i < this.numCandidates; i++) {
|
|
11823
12467
|
const tipIndex = tips.length > 0 ? i % tips.length : -1;
|
|
11824
12468
|
const tipToUse = tipIndex >= 0 ? tips[tipIndex] : "";
|
|
11825
12469
|
const instruction = await this.generateInstruction({
|
|
11826
|
-
programContext,
|
|
11827
|
-
dataContext,
|
|
11828
12470
|
tip: tipToUse,
|
|
11829
|
-
candidateIndex: i
|
|
12471
|
+
candidateIndex: i,
|
|
12472
|
+
ai: aiToUse
|
|
11830
12473
|
});
|
|
11831
12474
|
instructions.push(instruction);
|
|
11832
12475
|
}
|
|
11833
12476
|
return instructions;
|
|
11834
12477
|
}
|
|
11835
|
-
/**
|
|
11836
|
-
* Generates a summary of the program structure for instruction proposal
|
|
11837
|
-
*/
|
|
11838
|
-
async generateProgramSummary() {
|
|
11839
|
-
const prompt = `Summarize the following program structure. Focus on the signatures,
|
|
11840
|
-
input/output fields, and the purpose of each component. Identify key components
|
|
11841
|
-
that might benefit from better instructions.`;
|
|
11842
|
-
const programStr = JSON.stringify(this.program);
|
|
11843
|
-
const response = await this.ai.chat({
|
|
11844
|
-
chatPrompt: [
|
|
11845
|
-
{ role: "system", content: prompt },
|
|
11846
|
-
{ role: "user", content: programStr }
|
|
11847
|
-
],
|
|
11848
|
-
modelConfig: { temperature: 0.2 }
|
|
11849
|
-
});
|
|
11850
|
-
if (response instanceof ReadableStream) {
|
|
11851
|
-
return "";
|
|
11852
|
-
}
|
|
11853
|
-
return response.results[0]?.content || "";
|
|
11854
|
-
}
|
|
11855
|
-
/**
|
|
11856
|
-
* Generates a summary of the dataset for instruction proposal
|
|
11857
|
-
*/
|
|
11858
|
-
async generateDataSummary() {
|
|
11859
|
-
const sampleSize = Math.min(this.viewDataBatchSize, this.examples.length);
|
|
11860
|
-
const sample = this.examples.slice(0, sampleSize);
|
|
11861
|
-
const prompt = `Analyze the following dataset examples and provide a summary
|
|
11862
|
-
of key patterns, input-output relationships, and any specific challenges
|
|
11863
|
-
the data presents. Focus on what makes a good answer and what patterns should
|
|
11864
|
-
be followed.`;
|
|
11865
|
-
const dataStr = JSON.stringify(sample);
|
|
11866
|
-
const response = await this.ai.chat({
|
|
11867
|
-
chatPrompt: [
|
|
11868
|
-
{ role: "system", content: prompt },
|
|
11869
|
-
{ role: "user", content: dataStr }
|
|
11870
|
-
],
|
|
11871
|
-
modelConfig: { temperature: 0.2 }
|
|
11872
|
-
});
|
|
11873
|
-
if (response instanceof ReadableStream) {
|
|
11874
|
-
return "";
|
|
11875
|
-
}
|
|
11876
|
-
return response.results[0]?.content || "";
|
|
11877
|
-
}
|
|
11878
|
-
/**
|
|
11879
|
-
* Generates a specific instruction candidate
|
|
11880
|
-
*/
|
|
11881
12478
|
async generateInstruction({
|
|
11882
|
-
programContext,
|
|
11883
|
-
dataContext,
|
|
11884
12479
|
tip,
|
|
11885
12480
|
candidateIndex
|
|
11886
12481
|
}) {
|
|
11887
|
-
const
|
|
11888
|
-
|
|
11889
|
-
|
|
11890
|
-
|
|
11891
|
-
|
|
11892
|
-
|
|
11893
|
-
|
|
11894
|
-
|
|
11895
|
-
|
|
11896
|
-
|
|
11897
|
-
${tip ? `STYLE TIP: ${tip}
|
|
11898
|
-
|
|
11899
|
-
` : ""}
|
|
11900
|
-
|
|
11901
|
-
Your task is to craft a clear, effective instruction that will help the AI model generate
|
|
11902
|
-
accurate outputs for this task. Instruction #${candidateIndex + 1}/${this.numCandidates}.
|
|
11903
|
-
|
|
11904
|
-
The instruction should be detailed enough to guide the model but not overly prescriptive
|
|
11905
|
-
or restrictive. Focus on what makes a good response rather than listing exact steps.
|
|
11906
|
-
|
|
11907
|
-
INSTRUCTION:`;
|
|
11908
|
-
const response = await this.ai.chat({
|
|
11909
|
-
chatPrompt: [{ role: "user", content: prompt }],
|
|
11910
|
-
modelConfig: { temperature: 0.7 + 0.1 * candidateIndex }
|
|
11911
|
-
});
|
|
11912
|
-
if (response instanceof ReadableStream) {
|
|
11913
|
-
return "";
|
|
12482
|
+
const baseInstructions = [
|
|
12483
|
+
"Analyze the input carefully and provide a detailed response.",
|
|
12484
|
+
"Think step by step and provide a clear answer.",
|
|
12485
|
+
"Consider all aspects of the input before responding.",
|
|
12486
|
+
"Provide a concise but comprehensive response.",
|
|
12487
|
+
"Focus on accuracy and clarity in your response."
|
|
12488
|
+
];
|
|
12489
|
+
let instruction = baseInstructions[candidateIndex % baseInstructions.length] || baseInstructions[0];
|
|
12490
|
+
if (tip) {
|
|
12491
|
+
instruction = `${instruction} ${tip}`;
|
|
11914
12492
|
}
|
|
11915
|
-
return
|
|
12493
|
+
return instruction;
|
|
11916
12494
|
}
|
|
11917
12495
|
/**
|
|
11918
12496
|
* Bootstraps few-shot examples for the program
|
|
11919
12497
|
*/
|
|
11920
|
-
async bootstrapFewShotExamples(metricFn) {
|
|
12498
|
+
async bootstrapFewShotExamples(program, metricFn) {
|
|
11921
12499
|
if (this.verbose) {
|
|
11922
12500
|
console.log("Bootstrapping few-shot examples...");
|
|
11923
12501
|
}
|
|
11924
|
-
const
|
|
12502
|
+
const bootstrapper = new AxBootstrapFewShot({
|
|
12503
|
+
studentAI: this.studentAI,
|
|
12504
|
+
examples: this.examples,
|
|
12505
|
+
options: {
|
|
12506
|
+
maxDemos: this.maxBootstrappedDemos,
|
|
12507
|
+
maxRounds: 3,
|
|
12508
|
+
verboseMode: this.verbose
|
|
12509
|
+
}
|
|
12510
|
+
});
|
|
12511
|
+
const result = await bootstrapper.compile(program, metricFn, {
|
|
11925
12512
|
maxDemos: this.maxBootstrappedDemos
|
|
11926
12513
|
});
|
|
11927
12514
|
return result.demos || [];
|
|
@@ -11945,109 +12532,98 @@ ${dataContext}
|
|
|
11945
12532
|
return selectedExamples;
|
|
11946
12533
|
}
|
|
11947
12534
|
/**
|
|
11948
|
-
* Runs
|
|
12535
|
+
* Runs optimization to find the best combination of few-shot examples and instructions
|
|
11949
12536
|
*/
|
|
11950
|
-
async
|
|
11951
|
-
let bestConfig =
|
|
11952
|
-
let bestScore = Number.NEGATIVE_INFINITY;
|
|
11953
|
-
const evaluatedConfigs = [];
|
|
11954
|
-
const defaultConfig = {
|
|
12537
|
+
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, valset, metricFn, options) {
|
|
12538
|
+
let bestConfig = {
|
|
11955
12539
|
instruction: instructions[0] || "",
|
|
11956
12540
|
bootstrappedDemos: Math.min(1, bootstrappedDemos.length),
|
|
11957
12541
|
labeledExamples: Math.min(1, labeledExamples.length)
|
|
11958
12542
|
};
|
|
11959
|
-
let
|
|
11960
|
-
let
|
|
11961
|
-
const
|
|
11962
|
-
|
|
11963
|
-
|
|
11964
|
-
|
|
11965
|
-
|
|
11966
|
-
|
|
11967
|
-
|
|
11968
|
-
|
|
12543
|
+
let bestScore = 0;
|
|
12544
|
+
let stagnationRounds = 0;
|
|
12545
|
+
const scoreHistory = [];
|
|
12546
|
+
let startRound = 0;
|
|
12547
|
+
if (this.resumeFromCheckpoint) {
|
|
12548
|
+
const checkpoint = await this.loadCheckpoint(
|
|
12549
|
+
this.resumeFromCheckpoint,
|
|
12550
|
+
options
|
|
12551
|
+
);
|
|
12552
|
+
if (checkpoint && checkpoint.optimizerType === "MiPRO") {
|
|
12553
|
+
if (this.verbose || options?.verbose) {
|
|
12554
|
+
console.log(
|
|
12555
|
+
`Resuming from checkpoint at round ${checkpoint.currentRound}`
|
|
12556
|
+
);
|
|
12557
|
+
}
|
|
12558
|
+
this.restoreFromCheckpoint(checkpoint);
|
|
12559
|
+
startRound = checkpoint.currentRound;
|
|
12560
|
+
bestScore = checkpoint.bestScore;
|
|
12561
|
+
bestConfig = checkpoint.bestConfiguration || bestConfig;
|
|
12562
|
+
stagnationRounds = checkpoint.stats.convergenceInfo?.stagnationRounds || 0;
|
|
12563
|
+
}
|
|
12564
|
+
}
|
|
12565
|
+
for (let i = startRound; i < this.numTrials; i++) {
|
|
11969
12566
|
const config = {
|
|
11970
|
-
instruction:
|
|
11971
|
-
bootstrappedDemos: Math.
|
|
11972
|
-
Math.random() * (bootstrappedDemos.length + 1)
|
|
12567
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
12568
|
+
bootstrappedDemos: Math.min(
|
|
12569
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
12570
|
+
this.maxBootstrappedDemos
|
|
11973
12571
|
),
|
|
11974
|
-
labeledExamples: Math.
|
|
11975
|
-
Math.random() * (labeledExamples.length + 1)
|
|
12572
|
+
labeledExamples: Math.min(
|
|
12573
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
12574
|
+
this.maxLabeledDemos
|
|
11976
12575
|
)
|
|
11977
12576
|
};
|
|
11978
|
-
configs.push(config);
|
|
11979
|
-
}
|
|
11980
|
-
for (let i = 0; i < configs.length; i++) {
|
|
11981
|
-
const config = configs[i];
|
|
11982
|
-
if (!config) continue;
|
|
11983
12577
|
const score = await this.evaluateConfig(
|
|
12578
|
+
program,
|
|
11984
12579
|
config,
|
|
11985
12580
|
bootstrappedDemos,
|
|
11986
12581
|
labeledExamples,
|
|
11987
12582
|
valset,
|
|
11988
|
-
metricFn
|
|
11989
|
-
i
|
|
12583
|
+
metricFn
|
|
11990
12584
|
);
|
|
11991
|
-
|
|
11992
|
-
|
|
12585
|
+
scoreHistory.push(score);
|
|
12586
|
+
const improvement = score - bestScore;
|
|
12587
|
+
if (improvement > this.minImprovementThreshold) {
|
|
11993
12588
|
bestScore = score;
|
|
11994
12589
|
bestConfig = config;
|
|
11995
|
-
|
|
11996
|
-
|
|
11997
|
-
|
|
11998
|
-
);
|
|
11999
|
-
}
|
|
12590
|
+
stagnationRounds = 0;
|
|
12591
|
+
} else {
|
|
12592
|
+
stagnationRounds++;
|
|
12000
12593
|
}
|
|
12001
|
-
|
|
12594
|
+
await this.updateOptimizationProgress(
|
|
12002
12595
|
i + 1,
|
|
12003
|
-
|
|
12004
|
-
|
|
12005
|
-
|
|
12006
|
-
|
|
12007
|
-
|
|
12008
|
-
|
|
12009
|
-
|
|
12010
|
-
|
|
12011
|
-
|
|
12012
|
-
|
|
12013
|
-
|
|
12014
|
-
|
|
12015
|
-
|
|
12016
|
-
);
|
|
12017
|
-
const score = await this.evaluateConfig(
|
|
12018
|
-
nextConfig,
|
|
12019
|
-
bootstrappedDemos,
|
|
12020
|
-
labeledExamples,
|
|
12021
|
-
valset,
|
|
12022
|
-
metricFn,
|
|
12023
|
-
i
|
|
12596
|
+
score,
|
|
12597
|
+
config,
|
|
12598
|
+
"MiPRO",
|
|
12599
|
+
this.getConfiguration(),
|
|
12600
|
+
bestScore,
|
|
12601
|
+
bestConfig,
|
|
12602
|
+
{
|
|
12603
|
+
stagnationRounds,
|
|
12604
|
+
bootstrappedDemos: bootstrappedDemos.length,
|
|
12605
|
+
labeledExamples: labeledExamples.length,
|
|
12606
|
+
instructions: instructions.length
|
|
12607
|
+
},
|
|
12608
|
+
options
|
|
12024
12609
|
);
|
|
12025
|
-
|
|
12026
|
-
|
|
12027
|
-
|
|
12028
|
-
|
|
12029
|
-
|
|
12030
|
-
|
|
12031
|
-
|
|
12032
|
-
)
|
|
12033
|
-
|
|
12034
|
-
|
|
12035
|
-
|
|
12036
|
-
|
|
12037
|
-
|
|
12038
|
-
|
|
12039
|
-
|
|
12040
|
-
if (this.verbose) {
|
|
12041
|
-
console.log(
|
|
12042
|
-
`Early stopping triggered after ${i + 1} trials. No improvement for ${trialsWithoutImprovement} trials.`
|
|
12043
|
-
);
|
|
12044
|
-
}
|
|
12045
|
-
break;
|
|
12610
|
+
if (this.onProgress) {
|
|
12611
|
+
this.onProgress({
|
|
12612
|
+
round: i + 1,
|
|
12613
|
+
totalRounds: this.numTrials,
|
|
12614
|
+
currentScore: score,
|
|
12615
|
+
bestScore,
|
|
12616
|
+
tokensUsed: this.stats.resourceUsage.totalTokens,
|
|
12617
|
+
timeElapsed: Date.now(),
|
|
12618
|
+
successfulExamples: this.stats.successfulDemos,
|
|
12619
|
+
totalExamples: this.examples.length,
|
|
12620
|
+
currentConfiguration: config,
|
|
12621
|
+
convergenceInfo: {
|
|
12622
|
+
improvement,
|
|
12623
|
+
stagnationRounds,
|
|
12624
|
+
isConverging: stagnationRounds < this.earlyStoppingTrials
|
|
12046
12625
|
}
|
|
12047
|
-
}
|
|
12048
|
-
lastBestScore = bestScore;
|
|
12049
|
-
trialsWithoutImprovement = 0;
|
|
12050
|
-
}
|
|
12626
|
+
});
|
|
12051
12627
|
}
|
|
12052
12628
|
updateProgressBar(
|
|
12053
12629
|
i + 1,
|
|
@@ -12057,243 +12633,91 @@ ${dataContext}
|
|
|
12057
12633
|
"Running MIPROv2 optimization",
|
|
12058
12634
|
30
|
|
12059
12635
|
);
|
|
12060
|
-
if (this.
|
|
12061
|
-
|
|
12062
|
-
|
|
12063
|
-
`Running full evaluation on best configuration at trial ${i + 1}`
|
|
12064
|
-
);
|
|
12065
|
-
}
|
|
12066
|
-
const fullScore = await this.fullEvaluation(
|
|
12067
|
-
bestConfig,
|
|
12068
|
-
bootstrappedDemos,
|
|
12069
|
-
labeledExamples,
|
|
12070
|
-
valset,
|
|
12071
|
-
metricFn
|
|
12072
|
-
);
|
|
12073
|
-
if (this.verbose) {
|
|
12074
|
-
console.log(`Full evaluation score: ${fullScore}`);
|
|
12075
|
-
}
|
|
12076
|
-
bestScore = fullScore;
|
|
12636
|
+
if (this.checkCostLimits()) {
|
|
12637
|
+
this.triggerEarlyStopping("Cost limit reached", i + 1);
|
|
12638
|
+
break;
|
|
12077
12639
|
}
|
|
12078
|
-
|
|
12079
|
-
|
|
12080
|
-
|
|
12081
|
-
|
|
12082
|
-
"Optimization failed to find any valid configurations, using default fallback configuration"
|
|
12640
|
+
if (stagnationRounds >= this.earlyStoppingTrials) {
|
|
12641
|
+
this.triggerEarlyStopping(
|
|
12642
|
+
`No improvement for ${this.earlyStoppingTrials} trials`,
|
|
12643
|
+
i - stagnationRounds + 1
|
|
12083
12644
|
);
|
|
12645
|
+
break;
|
|
12084
12646
|
}
|
|
12085
|
-
|
|
12086
|
-
|
|
12087
|
-
|
|
12088
|
-
|
|
12089
|
-
bootstrappedDemos,
|
|
12090
|
-
labeledExamples,
|
|
12091
|
-
valset,
|
|
12092
|
-
metricFn,
|
|
12093
|
-
this.numTrials - 1
|
|
12647
|
+
if (this.checkTargetScore(bestScore)) {
|
|
12648
|
+
this.triggerEarlyStopping(
|
|
12649
|
+
`Target score ${this.targetScore} reached`,
|
|
12650
|
+
i + 1
|
|
12094
12651
|
);
|
|
12095
|
-
|
|
12096
|
-
if (this.verbose) {
|
|
12097
|
-
console.error("Error evaluating default configuration:", err);
|
|
12098
|
-
}
|
|
12099
|
-
bestScore = 0;
|
|
12652
|
+
break;
|
|
12100
12653
|
}
|
|
12101
12654
|
}
|
|
12655
|
+
this.stats.convergenceInfo.stagnationRounds = stagnationRounds;
|
|
12656
|
+
this.stats.convergenceInfo.finalImprovement = scoreHistory.length > 1 ? bestScore - scoreHistory[0] : 0;
|
|
12657
|
+
this.stats.convergenceInfo.converged = stagnationRounds < this.earlyStoppingTrials;
|
|
12102
12658
|
return { bestConfig, bestScore };
|
|
12103
12659
|
}
|
|
12104
|
-
|
|
12105
|
-
|
|
12106
|
-
*/
|
|
12107
|
-
async evaluateConfig(config, bootstrappedDemos, labeledExamples, valset, metricFn, trialIndex) {
|
|
12660
|
+
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, valset, metricFn) {
|
|
12661
|
+
const testProgram = { ...program };
|
|
12108
12662
|
this.applyConfigToProgram(
|
|
12109
|
-
|
|
12663
|
+
testProgram,
|
|
12110
12664
|
config,
|
|
12111
12665
|
bootstrappedDemos,
|
|
12112
12666
|
labeledExamples
|
|
12113
12667
|
);
|
|
12114
|
-
let
|
|
12115
|
-
|
|
12116
|
-
|
|
12117
|
-
const minibatchEvalSet = [];
|
|
12118
|
-
for (let j = 0; j < this.minibatchSize; j++) {
|
|
12119
|
-
const idx = (startIdx + j) % valset.length;
|
|
12120
|
-
const example = valset[idx];
|
|
12121
|
-
if (example) {
|
|
12122
|
-
minibatchEvalSet.push(example);
|
|
12123
|
-
}
|
|
12124
|
-
}
|
|
12125
|
-
evalSet = minibatchEvalSet;
|
|
12126
|
-
}
|
|
12127
|
-
let sumOfScores = 0;
|
|
12668
|
+
let totalScore = 0;
|
|
12669
|
+
let count = 0;
|
|
12670
|
+
const evalSet = valset.slice(0, Math.min(5, valset.length));
|
|
12128
12671
|
for (const example of evalSet) {
|
|
12129
12672
|
try {
|
|
12130
|
-
const prediction = await
|
|
12131
|
-
|
|
12132
|
-
|
|
12133
|
-
|
|
12134
|
-
|
|
12135
|
-
|
|
12136
|
-
|
|
12137
|
-
|
|
12138
|
-
|
|
12139
|
-
|
|
12140
|
-
return sumOfScores / evalSet.length;
|
|
12141
|
-
}
|
|
12142
|
-
/**
|
|
12143
|
-
* Run full evaluation on the entire validation set
|
|
12144
|
-
*/
|
|
12145
|
-
async fullEvaluation(config, bootstrappedDemos, labeledExamples, valset, metricFn) {
|
|
12146
|
-
this.applyConfigToProgram(
|
|
12147
|
-
this.program,
|
|
12148
|
-
config,
|
|
12149
|
-
bootstrappedDemos,
|
|
12150
|
-
labeledExamples
|
|
12151
|
-
);
|
|
12152
|
-
let sumOfScores = 0;
|
|
12153
|
-
for (const example of valset) {
|
|
12154
|
-
try {
|
|
12155
|
-
const prediction = await this.program.forward(this.ai, example);
|
|
12156
|
-
const score = metricFn({ prediction, example });
|
|
12157
|
-
sumOfScores += score;
|
|
12158
|
-
} catch (err) {
|
|
12159
|
-
if (this.verbose) {
|
|
12160
|
-
console.error("Error evaluating example:", err);
|
|
12161
|
-
}
|
|
12673
|
+
const prediction = await testProgram.forward(
|
|
12674
|
+
this.studentAI,
|
|
12675
|
+
example
|
|
12676
|
+
);
|
|
12677
|
+
const score = await metricFn({ prediction, example });
|
|
12678
|
+
totalScore += score;
|
|
12679
|
+
count++;
|
|
12680
|
+
this.stats.totalCalls++;
|
|
12681
|
+
} catch {
|
|
12682
|
+
continue;
|
|
12162
12683
|
}
|
|
12163
12684
|
}
|
|
12164
|
-
|
|
12165
|
-
return sumOfScores / valset.length;
|
|
12685
|
+
return count > 0 ? totalScore / count : 0;
|
|
12166
12686
|
}
|
|
12167
|
-
/**
|
|
12168
|
-
* Implements a Bayesian-inspired selection of the next configuration to try
|
|
12169
|
-
* This is a simplified version using Upper Confidence Bound (UCB) strategy
|
|
12170
|
-
*/
|
|
12171
|
-
selectNextConfiguration(evaluatedConfigs, maxBootstrappedDemos, maxLabeledExamples, instructions) {
|
|
12172
|
-
if (evaluatedConfigs.length < 5) {
|
|
12173
|
-
const instructionIndex = Math.floor(Math.random() * instructions.length);
|
|
12174
|
-
return {
|
|
12175
|
-
instruction: instructions[instructionIndex] || "",
|
|
12176
|
-
bootstrappedDemos: Math.floor(
|
|
12177
|
-
Math.random() * (maxBootstrappedDemos + 1)
|
|
12178
|
-
),
|
|
12179
|
-
labeledExamples: Math.floor(Math.random() * (maxLabeledExamples + 1))
|
|
12180
|
-
};
|
|
12181
|
-
}
|
|
12182
|
-
const sortedConfigs = [...evaluatedConfigs].sort(
|
|
12183
|
-
(a, b) => b.score - a.score
|
|
12184
|
-
);
|
|
12185
|
-
const topConfigs = sortedConfigs.slice(0, Math.min(3, sortedConfigs.length));
|
|
12186
|
-
const meanBootstrappedDemos = topConfigs.reduce((sum, c) => sum + c.config.bootstrappedDemos, 0) / topConfigs.length;
|
|
12187
|
-
const meanLabeledExamples = topConfigs.reduce((sum, c) => sum + c.config.labeledExamples, 0) / topConfigs.length;
|
|
12188
|
-
const popularInstructions = topConfigs.map((c) => c.config.instruction);
|
|
12189
|
-
const explorationFactor = Math.max(
|
|
12190
|
-
0.2,
|
|
12191
|
-
1 - evaluatedConfigs.length / this.numTrials
|
|
12192
|
-
);
|
|
12193
|
-
let newBootstrappedDemos;
|
|
12194
|
-
let newLabeledExamples;
|
|
12195
|
-
let newInstruction;
|
|
12196
|
-
if (Math.random() < 0.7) {
|
|
12197
|
-
newBootstrappedDemos = Math.min(
|
|
12198
|
-
maxBootstrappedDemos,
|
|
12199
|
-
Math.max(
|
|
12200
|
-
0,
|
|
12201
|
-
Math.round(
|
|
12202
|
-
meanBootstrappedDemos + (Math.random() * 2 - 1) * explorationFactor * 2
|
|
12203
|
-
)
|
|
12204
|
-
)
|
|
12205
|
-
);
|
|
12206
|
-
} else {
|
|
12207
|
-
newBootstrappedDemos = Math.floor(
|
|
12208
|
-
Math.random() * (maxBootstrappedDemos + 1)
|
|
12209
|
-
);
|
|
12210
|
-
}
|
|
12211
|
-
if (Math.random() < 0.7) {
|
|
12212
|
-
newLabeledExamples = Math.min(
|
|
12213
|
-
maxLabeledExamples,
|
|
12214
|
-
Math.max(
|
|
12215
|
-
0,
|
|
12216
|
-
Math.round(
|
|
12217
|
-
meanLabeledExamples + (Math.random() * 2 - 1) * explorationFactor * 2
|
|
12218
|
-
)
|
|
12219
|
-
)
|
|
12220
|
-
);
|
|
12221
|
-
} else {
|
|
12222
|
-
newLabeledExamples = Math.floor(Math.random() * (maxLabeledExamples + 1));
|
|
12223
|
-
}
|
|
12224
|
-
if (Math.random() < 0.7 && popularInstructions.length > 0) {
|
|
12225
|
-
const idx = Math.floor(Math.random() * popularInstructions.length);
|
|
12226
|
-
newInstruction = popularInstructions[idx] || "";
|
|
12227
|
-
} else {
|
|
12228
|
-
const idx = Math.floor(Math.random() * instructions.length);
|
|
12229
|
-
newInstruction = instructions[idx] || "";
|
|
12230
|
-
}
|
|
12231
|
-
return {
|
|
12232
|
-
instruction: newInstruction,
|
|
12233
|
-
bootstrappedDemos: newBootstrappedDemos,
|
|
12234
|
-
labeledExamples: newLabeledExamples
|
|
12235
|
-
};
|
|
12236
|
-
}
|
|
12237
|
-
/**
|
|
12238
|
-
* Applies a configuration to a program instance
|
|
12239
|
-
*/
|
|
12240
12687
|
applyConfigToProgram(program, config, bootstrappedDemos, labeledExamples) {
|
|
12241
|
-
|
|
12242
|
-
|
|
12688
|
+
if (program.setInstruction) {
|
|
12689
|
+
program.setInstruction(config.instruction);
|
|
12690
|
+
}
|
|
12691
|
+
if (config.bootstrappedDemos > 0 && program.setDemos) {
|
|
12243
12692
|
program.setDemos(bootstrappedDemos.slice(0, config.bootstrappedDemos));
|
|
12244
12693
|
}
|
|
12245
|
-
if (config.labeledExamples > 0) {
|
|
12694
|
+
if (config.labeledExamples > 0 && program.setExamples) {
|
|
12246
12695
|
program.setExamples(labeledExamples.slice(0, config.labeledExamples));
|
|
12247
12696
|
}
|
|
12248
12697
|
}
|
|
12249
|
-
/**
|
|
12250
|
-
* Sets instruction to a program
|
|
12251
|
-
* Note: Workaround since setInstruction may not be available directly
|
|
12252
|
-
*/
|
|
12253
|
-
setInstructionToProgram(program, instruction) {
|
|
12254
|
-
const programWithInstruction = program;
|
|
12255
|
-
programWithInstruction.setInstruction?.(instruction);
|
|
12256
|
-
}
|
|
12257
12698
|
/**
|
|
12258
12699
|
* The main compile method to run MIPROv2 optimization
|
|
12259
|
-
* @param metricFn Evaluation metric function
|
|
12260
|
-
* @param options Optional configuration options
|
|
12261
|
-
* @returns The optimization result
|
|
12262
12700
|
*/
|
|
12263
|
-
async compile(metricFn, options) {
|
|
12701
|
+
async compile(program, metricFn, options) {
|
|
12702
|
+
const startTime = Date.now();
|
|
12703
|
+
this.setupRandomSeed();
|
|
12264
12704
|
const miproOptions = options;
|
|
12265
12705
|
if (miproOptions?.auto) {
|
|
12266
12706
|
this.configureAuto(miproOptions.auto);
|
|
12267
12707
|
}
|
|
12268
|
-
const
|
|
12269
|
-
|
|
12270
|
-
if (this.verbose) {
|
|
12708
|
+
const valset = this.getValidationSet(options) || (miproOptions?.valset ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
|
|
12709
|
+
if (this.verbose || options?.verbose) {
|
|
12271
12710
|
console.log(`Starting MIPROv2 optimization with ${this.numTrials} trials`);
|
|
12272
12711
|
console.log(
|
|
12273
|
-
`Using ${
|
|
12712
|
+
`Using ${this.examples.length} examples for training and ${valset.length} for validation`
|
|
12274
12713
|
);
|
|
12275
|
-
|
|
12276
|
-
|
|
12277
|
-
if (this.verbose) {
|
|
12278
|
-
console.log("Using provided teacher to assist with bootstrapping");
|
|
12714
|
+
if (this.teacherAI) {
|
|
12715
|
+
console.log("Using separate teacher model for instruction generation");
|
|
12279
12716
|
}
|
|
12280
|
-
const bootstrapperWithTeacher = new AxBootstrapFewShot({
|
|
12281
|
-
ai: this.ai,
|
|
12282
|
-
program: this.program,
|
|
12283
|
-
examples: this.examples,
|
|
12284
|
-
options: {
|
|
12285
|
-
maxDemos: this.maxBootstrappedDemos,
|
|
12286
|
-
maxRounds: 3,
|
|
12287
|
-
verboseMode: this.verbose,
|
|
12288
|
-
teacherAI: this.ai
|
|
12289
|
-
// Use the same AI but with the teacher program
|
|
12290
|
-
}
|
|
12291
|
-
});
|
|
12292
|
-
this.bootstrapper = bootstrapperWithTeacher;
|
|
12293
12717
|
}
|
|
12294
12718
|
let bootstrappedDemos = [];
|
|
12295
12719
|
if (this.maxBootstrappedDemos > 0) {
|
|
12296
|
-
bootstrappedDemos = await this.bootstrapFewShotExamples(metricFn);
|
|
12720
|
+
bootstrappedDemos = await this.bootstrapFewShotExamples(program, metricFn);
|
|
12297
12721
|
if (this.verbose) {
|
|
12298
12722
|
console.log(
|
|
12299
12723
|
`Generated ${bootstrappedDemos.length} bootstrapped demonstrations`
|
|
@@ -12309,38 +12733,191 @@ ${dataContext}
|
|
|
12309
12733
|
);
|
|
12310
12734
|
}
|
|
12311
12735
|
}
|
|
12312
|
-
const instructions = await this.proposeInstructionCandidates();
|
|
12736
|
+
const instructions = await this.proposeInstructionCandidates(options);
|
|
12313
12737
|
if (this.verbose) {
|
|
12314
12738
|
console.log(`Generated ${instructions.length} instruction candidates`);
|
|
12739
|
+
if (this.hasTeacherAI(options)) {
|
|
12740
|
+
console.log("Using teacher AI for instruction generation");
|
|
12741
|
+
}
|
|
12315
12742
|
}
|
|
12316
|
-
const { bestConfig, bestScore } = await this.
|
|
12743
|
+
const { bestConfig, bestScore } = await this.runOptimization(
|
|
12744
|
+
program,
|
|
12317
12745
|
bootstrappedDemos,
|
|
12318
12746
|
labeledExamples,
|
|
12319
12747
|
instructions,
|
|
12320
12748
|
valset,
|
|
12321
|
-
metricFn
|
|
12749
|
+
metricFn,
|
|
12750
|
+
options
|
|
12322
12751
|
);
|
|
12323
|
-
if (this.verbose) {
|
|
12752
|
+
if (this.verbose || options?.verbose) {
|
|
12324
12753
|
console.log(`Optimization complete. Best score: ${bestScore}`);
|
|
12325
12754
|
console.log(`Best configuration: ${JSON.stringify(bestConfig)}`);
|
|
12326
12755
|
}
|
|
12327
|
-
this.
|
|
12328
|
-
this.
|
|
12756
|
+
if (this.checkTargetScore(bestScore)) {
|
|
12757
|
+
this.triggerEarlyStopping(
|
|
12758
|
+
`Target score ${this.targetScore} reached with score ${bestScore}`,
|
|
12759
|
+
this.numTrials
|
|
12760
|
+
);
|
|
12761
|
+
}
|
|
12762
|
+
let signature;
|
|
12763
|
+
if ("getSignature" in program && typeof program.getSignature === "function") {
|
|
12764
|
+
signature = program.getSignature();
|
|
12765
|
+
} else {
|
|
12766
|
+
signature = "input -> output";
|
|
12767
|
+
}
|
|
12768
|
+
const optimizedGen = new AxGen(signature);
|
|
12769
|
+
this.applyConfigToAxGen(
|
|
12770
|
+
optimizedGen,
|
|
12329
12771
|
bestConfig,
|
|
12330
12772
|
bootstrappedDemos,
|
|
12331
12773
|
labeledExamples
|
|
12332
12774
|
);
|
|
12775
|
+
this.updateResourceUsage(startTime);
|
|
12776
|
+
this.stats.convergenceInfo.converged = true;
|
|
12777
|
+
this.stats.convergenceInfo.finalImprovement = bestScore;
|
|
12778
|
+
await this.saveFinalCheckpoint(
|
|
12779
|
+
"MiPRO",
|
|
12780
|
+
this.getConfiguration(),
|
|
12781
|
+
bestScore,
|
|
12782
|
+
bestConfig,
|
|
12783
|
+
{
|
|
12784
|
+
bootstrappedDemos: bootstrappedDemos.length,
|
|
12785
|
+
labeledExamples: labeledExamples.length,
|
|
12786
|
+
instructions: instructions.length,
|
|
12787
|
+
optimizedGen: !!optimizedGen
|
|
12788
|
+
},
|
|
12789
|
+
options
|
|
12790
|
+
);
|
|
12333
12791
|
return {
|
|
12334
|
-
|
|
12335
|
-
|
|
12792
|
+
demos: bootstrappedDemos,
|
|
12793
|
+
stats: this.stats,
|
|
12794
|
+
bestScore,
|
|
12795
|
+
optimizedGen,
|
|
12796
|
+
finalConfiguration: {
|
|
12797
|
+
instruction: bestConfig.instruction,
|
|
12798
|
+
bootstrappedDemos: bestConfig.bootstrappedDemos,
|
|
12799
|
+
labeledExamples: bestConfig.labeledExamples,
|
|
12800
|
+
numCandidates: this.numCandidates,
|
|
12801
|
+
numTrials: this.numTrials
|
|
12802
|
+
}
|
|
12336
12803
|
};
|
|
12337
12804
|
}
|
|
12338
12805
|
/**
|
|
12339
|
-
*
|
|
12340
|
-
* @returns Optimization statistics or undefined if not available
|
|
12806
|
+
* Applies a configuration to an AxGen instance
|
|
12341
12807
|
*/
|
|
12342
|
-
|
|
12343
|
-
|
|
12808
|
+
applyConfigToAxGen(axgen, config, bootstrappedDemos, labeledExamples) {
|
|
12809
|
+
if ("setInstruction" in axgen && typeof axgen.setInstruction === "function") {
|
|
12810
|
+
axgen.setInstruction(config.instruction);
|
|
12811
|
+
}
|
|
12812
|
+
if (config.bootstrappedDemos > 0) {
|
|
12813
|
+
axgen.setDemos(bootstrappedDemos.slice(0, config.bootstrappedDemos));
|
|
12814
|
+
}
|
|
12815
|
+
if (config.labeledExamples > 0) {
|
|
12816
|
+
axgen.setExamples(
|
|
12817
|
+
labeledExamples.slice(
|
|
12818
|
+
0,
|
|
12819
|
+
config.labeledExamples
|
|
12820
|
+
)
|
|
12821
|
+
);
|
|
12822
|
+
}
|
|
12823
|
+
}
|
|
12824
|
+
/**
|
|
12825
|
+
* Get optimizer-specific configuration
|
|
12826
|
+
* @returns Current optimizer configuration
|
|
12827
|
+
*/
|
|
12828
|
+
getConfiguration() {
|
|
12829
|
+
return {
|
|
12830
|
+
numCandidates: this.numCandidates,
|
|
12831
|
+
initTemperature: this.initTemperature,
|
|
12832
|
+
maxBootstrappedDemos: this.maxBootstrappedDemos,
|
|
12833
|
+
maxLabeledDemos: this.maxLabeledDemos,
|
|
12834
|
+
numTrials: this.numTrials,
|
|
12835
|
+
minibatch: this.minibatch,
|
|
12836
|
+
minibatchSize: this.minibatchSize,
|
|
12837
|
+
minibatchFullEvalSteps: this.minibatchFullEvalSteps,
|
|
12838
|
+
programAwareProposer: this.programAwareProposer,
|
|
12839
|
+
dataAwareProposer: this.dataAwareProposer,
|
|
12840
|
+
tipAwareProposer: this.tipAwareProposer,
|
|
12841
|
+
fewshotAwareProposer: this.fewshotAwareProposer,
|
|
12842
|
+
earlyStoppingTrials: this.earlyStoppingTrials,
|
|
12843
|
+
minImprovementThreshold: this.minImprovementThreshold,
|
|
12844
|
+
bayesianOptimization: this.bayesianOptimization,
|
|
12845
|
+
acquisitionFunction: this.acquisitionFunction,
|
|
12846
|
+
explorationWeight: this.explorationWeight
|
|
12847
|
+
};
|
|
12848
|
+
}
|
|
12849
|
+
/**
|
|
12850
|
+
* Update optimizer configuration
|
|
12851
|
+
* @param config New configuration to merge with existing
|
|
12852
|
+
*/
|
|
12853
|
+
updateConfiguration(config) {
|
|
12854
|
+
if (config.numCandidates !== void 0) {
|
|
12855
|
+
this.numCandidates = config.numCandidates;
|
|
12856
|
+
}
|
|
12857
|
+
if (config.initTemperature !== void 0) {
|
|
12858
|
+
this.initTemperature = config.initTemperature;
|
|
12859
|
+
}
|
|
12860
|
+
if (config.maxBootstrappedDemos !== void 0) {
|
|
12861
|
+
this.maxBootstrappedDemos = config.maxBootstrappedDemos;
|
|
12862
|
+
}
|
|
12863
|
+
if (config.maxLabeledDemos !== void 0) {
|
|
12864
|
+
this.maxLabeledDemos = config.maxLabeledDemos;
|
|
12865
|
+
}
|
|
12866
|
+
if (config.numTrials !== void 0) {
|
|
12867
|
+
this.numTrials = config.numTrials;
|
|
12868
|
+
}
|
|
12869
|
+
if (config.minibatch !== void 0) {
|
|
12870
|
+
this.minibatch = config.minibatch;
|
|
12871
|
+
}
|
|
12872
|
+
if (config.minibatchSize !== void 0) {
|
|
12873
|
+
this.minibatchSize = config.minibatchSize;
|
|
12874
|
+
}
|
|
12875
|
+
if (config.earlyStoppingTrials !== void 0) {
|
|
12876
|
+
this.earlyStoppingTrials = config.earlyStoppingTrials;
|
|
12877
|
+
}
|
|
12878
|
+
if (config.minImprovementThreshold !== void 0) {
|
|
12879
|
+
this.minImprovementThreshold = config.minImprovementThreshold;
|
|
12880
|
+
}
|
|
12881
|
+
if (config.verbose !== void 0) {
|
|
12882
|
+
this.verbose = config.verbose;
|
|
12883
|
+
}
|
|
12884
|
+
}
|
|
12885
|
+
/**
|
|
12886
|
+
* Reset optimizer state for reuse with different programs
|
|
12887
|
+
*/
|
|
12888
|
+
reset() {
|
|
12889
|
+
super.reset();
|
|
12890
|
+
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
12891
|
+
}
|
|
12892
|
+
/**
|
|
12893
|
+
* Validate that the optimizer can handle the given program
|
|
12894
|
+
* @param program Program to validate
|
|
12895
|
+
* @returns Validation result with any issues found
|
|
12896
|
+
*/
|
|
12897
|
+
validateProgram(program) {
|
|
12898
|
+
const result = super.validateProgram(program);
|
|
12899
|
+
if (this.examples.length < this.maxBootstrappedDemos + this.maxLabeledDemos) {
|
|
12900
|
+
result.issues.push(
|
|
12901
|
+
`Not enough examples: need at least ${this.maxBootstrappedDemos + this.maxLabeledDemos}, got ${this.examples.length}`
|
|
12902
|
+
);
|
|
12903
|
+
result.suggestions.push(
|
|
12904
|
+
"Reduce maxBootstrappedDemos or maxLabeledDemos, or provide more examples"
|
|
12905
|
+
);
|
|
12906
|
+
}
|
|
12907
|
+
const valSetSize = this.getValidationSet().length;
|
|
12908
|
+
if (valSetSize < 5) {
|
|
12909
|
+
result.issues.push(
|
|
12910
|
+
"Validation set too small for reliable MiPRO optimization"
|
|
12911
|
+
);
|
|
12912
|
+
result.suggestions.push(
|
|
12913
|
+
"Provide more examples or a larger validation set"
|
|
12914
|
+
);
|
|
12915
|
+
}
|
|
12916
|
+
return {
|
|
12917
|
+
isValid: result.issues.length === 0,
|
|
12918
|
+
issues: result.issues,
|
|
12919
|
+
suggestions: result.suggestions
|
|
12920
|
+
};
|
|
12344
12921
|
}
|
|
12345
12922
|
};
|
|
12346
12923
|
|
|
@@ -12587,7 +13164,7 @@ var AxTestPrompt = class {
|
|
|
12587
13164
|
throw new Error("Invalid example");
|
|
12588
13165
|
}
|
|
12589
13166
|
const res = await this.program.forward(this.ai, ex);
|
|
12590
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
13167
|
+
const score = await metricFn({ prediction: res, example: ex });
|
|
12591
13168
|
sumOfScores += score;
|
|
12592
13169
|
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
12593
13170
|
updateProgressBar(i, total, sumOfScores, et, "Testing Prompt", 30);
|
|
@@ -14621,7 +15198,6 @@ var AxRAG = class extends AxChainOfThought {
|
|
|
14621
15198
|
);
|
|
14622
15199
|
this.genQuery = new AxGen(qsig);
|
|
14623
15200
|
this.queryFn = queryFn;
|
|
14624
|
-
this.register(this.genQuery);
|
|
14625
15201
|
}
|
|
14626
15202
|
async forward(ai, values, options) {
|
|
14627
15203
|
let question;
|
|
@@ -14698,6 +15274,7 @@ export {
|
|
|
14698
15274
|
AxAssertionError,
|
|
14699
15275
|
AxBalancer,
|
|
14700
15276
|
AxBaseAI,
|
|
15277
|
+
AxBaseOptimizer,
|
|
14701
15278
|
AxBootstrapFewShot,
|
|
14702
15279
|
AxChainOfThought,
|
|
14703
15280
|
AxDB,
|
|
@@ -14707,6 +15284,7 @@ export {
|
|
|
14707
15284
|
AxDBMemory,
|
|
14708
15285
|
AxDBPinecone,
|
|
14709
15286
|
AxDBWeaviate,
|
|
15287
|
+
AxDefaultCostTracker,
|
|
14710
15288
|
AxDefaultQueryRewriter,
|
|
14711
15289
|
AxDefaultResultReranker,
|
|
14712
15290
|
AxDockerSession,
|