@ax-llm/ax 12.0.7 → 12.0.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +1107 -527
- package/index.cjs.map +1 -1
- package/index.d.cts +505 -188
- package/index.d.ts +505 -188
- package/index.js +1109 -531
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.cjs
CHANGED
|
@@ -81,6 +81,7 @@ __export(index_exports, {
|
|
|
81
81
|
AxAssertionError: () => AxAssertionError,
|
|
82
82
|
AxBalancer: () => AxBalancer,
|
|
83
83
|
AxBaseAI: () => AxBaseAI,
|
|
84
|
+
AxBaseOptimizer: () => AxBaseOptimizer,
|
|
84
85
|
AxBootstrapFewShot: () => AxBootstrapFewShot,
|
|
85
86
|
AxChainOfThought: () => AxChainOfThought,
|
|
86
87
|
AxDB: () => AxDB,
|
|
@@ -90,6 +91,7 @@ __export(index_exports, {
|
|
|
90
91
|
AxDBMemory: () => AxDBMemory,
|
|
91
92
|
AxDBPinecone: () => AxDBPinecone,
|
|
92
93
|
AxDBWeaviate: () => AxDBWeaviate,
|
|
94
|
+
AxDefaultCostTracker: () => AxDefaultCostTracker,
|
|
93
95
|
AxDefaultQueryRewriter: () => AxDefaultQueryRewriter,
|
|
94
96
|
AxDefaultResultReranker: () => AxDefaultResultReranker,
|
|
95
97
|
AxDockerSession: () => AxDockerSession,
|
|
@@ -7538,8 +7540,9 @@ var AxInstanceRegistry = class {
|
|
|
7538
7540
|
this.reg.add(instance);
|
|
7539
7541
|
}
|
|
7540
7542
|
*[Symbol.iterator]() {
|
|
7541
|
-
|
|
7542
|
-
|
|
7543
|
+
const items = Array.from(this.reg);
|
|
7544
|
+
for (let i = 0; i < items.length; i++) {
|
|
7545
|
+
yield items[i];
|
|
7543
7546
|
}
|
|
7544
7547
|
}
|
|
7545
7548
|
};
|
|
@@ -8477,7 +8480,7 @@ var AxSignature = class _AxSignature {
|
|
|
8477
8480
|
this.getOutputFields().forEach((field) => {
|
|
8478
8481
|
validateField(field, "output");
|
|
8479
8482
|
});
|
|
8480
|
-
this.sigHash = (0, import_crypto3.createHash)("sha256").update(
|
|
8483
|
+
this.sigHash = (0, import_crypto3.createHash)("sha256").update(JSON.stringify(this.inputFields)).update(JSON.stringify(this.outputFields)).digest("hex");
|
|
8481
8484
|
this.sigString = renderSignature(
|
|
8482
8485
|
this.description,
|
|
8483
8486
|
this.inputFields,
|
|
@@ -8798,7 +8801,7 @@ var AxProgramWithSignature = class {
|
|
|
8798
8801
|
this.signature.validate();
|
|
8799
8802
|
this.sigHash = this.signature?.hash();
|
|
8800
8803
|
this.children = new AxInstanceRegistry();
|
|
8801
|
-
this.key = { id: this.
|
|
8804
|
+
this.key = { id: this.signature.hash() };
|
|
8802
8805
|
}
|
|
8803
8806
|
getSignature() {
|
|
8804
8807
|
return this.signature;
|
|
@@ -8818,8 +8821,8 @@ var AxProgramWithSignature = class {
|
|
|
8818
8821
|
}
|
|
8819
8822
|
setId(id) {
|
|
8820
8823
|
this.key = { id, custom: true };
|
|
8821
|
-
for (const child of this.children) {
|
|
8822
|
-
child
|
|
8824
|
+
for (const child of Array.from(this.children)) {
|
|
8825
|
+
child?.setParentId(id);
|
|
8823
8826
|
}
|
|
8824
8827
|
}
|
|
8825
8828
|
setParentId(parentId) {
|
|
@@ -8832,8 +8835,8 @@ var AxProgramWithSignature = class {
|
|
|
8832
8835
|
if (!("programId" in examples)) {
|
|
8833
8836
|
return;
|
|
8834
8837
|
}
|
|
8835
|
-
for (const child of this.children) {
|
|
8836
|
-
child
|
|
8838
|
+
for (const child of Array.from(this.children)) {
|
|
8839
|
+
child?.setExamples(examples, options);
|
|
8837
8840
|
}
|
|
8838
8841
|
}
|
|
8839
8842
|
_setExamples(examples, options) {
|
|
@@ -8866,30 +8869,37 @@ var AxProgramWithSignature = class {
|
|
|
8866
8869
|
if (this.trace) {
|
|
8867
8870
|
traces.push({ trace: this.trace, programId: this.key.id });
|
|
8868
8871
|
}
|
|
8869
|
-
for (const child of this.children) {
|
|
8870
|
-
const _traces = child
|
|
8871
|
-
traces = [...traces, ..._traces];
|
|
8872
|
+
for (const child of Array.from(this.children)) {
|
|
8873
|
+
const _traces = child?.getTraces();
|
|
8874
|
+
traces = [...traces, ..._traces ?? []];
|
|
8872
8875
|
}
|
|
8873
8876
|
return traces;
|
|
8874
8877
|
}
|
|
8875
8878
|
getUsage() {
|
|
8876
8879
|
let usage = [...this.usage ?? []];
|
|
8877
|
-
for (const child of this.children) {
|
|
8878
|
-
const cu = child
|
|
8879
|
-
usage = [...usage, ...cu];
|
|
8880
|
+
for (const child of Array.from(this.children)) {
|
|
8881
|
+
const cu = child?.getUsage();
|
|
8882
|
+
usage = [...usage, ...cu ?? []];
|
|
8880
8883
|
}
|
|
8881
8884
|
return mergeProgramUsage(usage);
|
|
8882
8885
|
}
|
|
8883
8886
|
resetUsage() {
|
|
8884
8887
|
this.usage = [];
|
|
8885
|
-
for (const child of this.children) {
|
|
8886
|
-
child
|
|
8888
|
+
for (const child of Array.from(this.children)) {
|
|
8889
|
+
child?.resetUsage();
|
|
8887
8890
|
}
|
|
8888
8891
|
}
|
|
8889
8892
|
setDemos(demos) {
|
|
8893
|
+
const hasChildren = Array.from(this.children).length > 0;
|
|
8894
|
+
const hasMatchingDemo = demos.some((demo) => demo.programId === this.key.id);
|
|
8895
|
+
if (hasChildren && !hasMatchingDemo) {
|
|
8896
|
+
throw new Error(
|
|
8897
|
+
`Program with id '${this.key.id}' has children but no matching programId found in demos`
|
|
8898
|
+
);
|
|
8899
|
+
}
|
|
8890
8900
|
this.demos = demos.filter((v) => v.programId === this.key.id).map((v) => v.traces).flat();
|
|
8891
|
-
for (const child of this.children) {
|
|
8892
|
-
child
|
|
8901
|
+
for (const child of Array.from(this.children)) {
|
|
8902
|
+
child?.setDemos(demos);
|
|
8893
8903
|
}
|
|
8894
8904
|
}
|
|
8895
8905
|
};
|
|
@@ -8917,8 +8927,8 @@ var AxProgram = class {
|
|
|
8917
8927
|
}
|
|
8918
8928
|
setId(id) {
|
|
8919
8929
|
this.key = { id, custom: true };
|
|
8920
|
-
for (const child of this.children) {
|
|
8921
|
-
child
|
|
8930
|
+
for (const child of Array.from(this.children)) {
|
|
8931
|
+
child?.setParentId(id);
|
|
8922
8932
|
}
|
|
8923
8933
|
}
|
|
8924
8934
|
setParentId(parentId) {
|
|
@@ -8930,8 +8940,8 @@ var AxProgram = class {
|
|
|
8930
8940
|
if (!("programId" in examples)) {
|
|
8931
8941
|
return;
|
|
8932
8942
|
}
|
|
8933
|
-
for (const child of this.children) {
|
|
8934
|
-
child
|
|
8943
|
+
for (const child of Array.from(this.children)) {
|
|
8944
|
+
child?.setExamples(examples, options);
|
|
8935
8945
|
}
|
|
8936
8946
|
}
|
|
8937
8947
|
getTraces() {
|
|
@@ -8939,29 +8949,36 @@ var AxProgram = class {
|
|
|
8939
8949
|
if (this.trace) {
|
|
8940
8950
|
traces.push({ trace: this.trace, programId: this.key.id });
|
|
8941
8951
|
}
|
|
8942
|
-
for (const child of this.children) {
|
|
8943
|
-
const _traces = child
|
|
8944
|
-
traces = [...traces, ..._traces];
|
|
8952
|
+
for (const child of Array.from(this.children)) {
|
|
8953
|
+
const _traces = child?.getTraces();
|
|
8954
|
+
traces = [...traces, ..._traces ?? []];
|
|
8945
8955
|
}
|
|
8946
8956
|
return traces;
|
|
8947
8957
|
}
|
|
8948
8958
|
getUsage() {
|
|
8949
8959
|
let usage = [...this.usage ?? []];
|
|
8950
|
-
for (const child of this.children) {
|
|
8951
|
-
const cu = child
|
|
8952
|
-
usage = [...usage, ...cu];
|
|
8960
|
+
for (const child of Array.from(this.children)) {
|
|
8961
|
+
const cu = child?.getUsage();
|
|
8962
|
+
usage = [...usage, ...cu ?? []];
|
|
8953
8963
|
}
|
|
8954
8964
|
return mergeProgramUsage(usage);
|
|
8955
8965
|
}
|
|
8956
8966
|
resetUsage() {
|
|
8957
8967
|
this.usage = [];
|
|
8958
|
-
for (const child of this.children) {
|
|
8959
|
-
child
|
|
8968
|
+
for (const child of Array.from(this.children)) {
|
|
8969
|
+
child?.resetUsage();
|
|
8960
8970
|
}
|
|
8961
8971
|
}
|
|
8962
8972
|
setDemos(demos) {
|
|
8963
|
-
|
|
8964
|
-
|
|
8973
|
+
const hasChildren = Array.from(this.children).length > 0;
|
|
8974
|
+
const hasMatchingDemo = demos.some((demo) => demo.programId === this.key.id);
|
|
8975
|
+
if (hasChildren && !hasMatchingDemo) {
|
|
8976
|
+
throw new Error(
|
|
8977
|
+
`Program with id '${this.key.id}' has children but no matching programId found in demos`
|
|
8978
|
+
);
|
|
8979
|
+
}
|
|
8980
|
+
for (const child of Array.from(this.children)) {
|
|
8981
|
+
child?.setDemos(demos);
|
|
8965
8982
|
}
|
|
8966
8983
|
}
|
|
8967
8984
|
};
|
|
@@ -9689,7 +9706,9 @@ var AxAgent = class {
|
|
|
9689
9706
|
description: definition ?? description
|
|
9690
9707
|
});
|
|
9691
9708
|
for (const agent of agents ?? []) {
|
|
9692
|
-
this.program.register(
|
|
9709
|
+
this.program.register(
|
|
9710
|
+
agent
|
|
9711
|
+
);
|
|
9693
9712
|
}
|
|
9694
9713
|
this.name = name;
|
|
9695
9714
|
this.func = {
|
|
@@ -10193,6 +10212,673 @@ function validateModels2(services) {
|
|
|
10193
10212
|
}
|
|
10194
10213
|
}
|
|
10195
10214
|
|
|
10215
|
+
// dsp/optimizer.ts
|
|
10216
|
+
/**
 * Default cost tracker: accumulates token counts per model and estimates spend
 * from a per-1K-token price table, optionally enforcing token/cost limits.
 */
var AxDefaultCostTracker = class {
  // Token counts keyed by model name.
  tokenUsage = {};
  // Running total of tokens across all models.
  totalTokens = 0;
  // Configuration options
  costPerModel;
  maxCost;
  maxTokens;
  /**
   * @param {{costPerModel?: Object<string, number>, maxCost?: number, maxTokens?: number}} [options]
   *   costPerModel: price per 1K tokens by model name; unlisted models fall back to 1e-3.
   *   maxCost / maxTokens: optional limits checked by isLimitReached().
   */
  constructor(options) {
    this.costPerModel = options?.costPerModel ?? {};
    this.maxCost = options?.maxCost;
    this.maxTokens = options?.maxTokens;
  }
  /**
   * Records `count` tokens against `model` and the overall total.
   */
  trackTokens(count, model) {
    // Own-property check so model names like "toString" don't pick up
    // inherited Object.prototype members.
    const previous = Object.hasOwn(this.tokenUsage, model) ? this.tokenUsage[model] : 0;
    this.tokenUsage[model] = previous + count;
    this.totalTokens += count;
  }
  /**
   * Estimated spend: sum over models of (tokens / 1000) * price-per-1K.
   * @returns {number}
   */
  getCurrentCost() {
    let totalCost = 0;
    for (const [model, tokens] of Object.entries(this.tokenUsage)) {
      const configured = Object.hasOwn(this.costPerModel, model) ? this.costPerModel[model] : undefined;
      // `??` rather than `||`: an explicitly configured price of 0 must be honored.
      const costPer1K = configured ?? 1e-3;
      totalCost += tokens / 1e3 * costPer1K;
    }
    return totalCost;
  }
  /** @returns {Object<string, number>} A shallow copy of per-model token counts. */
  getTokenUsage() {
    return { ...this.tokenUsage };
  }
  /** @returns {number} Total tokens tracked across all models. */
  getTotalTokens() {
    return this.totalTokens;
  }
  /**
   * True once the configured token limit or cost limit (either, if set) is met or exceeded.
   * @returns {boolean}
   */
  isLimitReached() {
    if (this.maxTokens !== void 0 && this.totalTokens >= this.maxTokens) {
      return true;
    }
    if (this.maxCost !== void 0) {
      const currentCost = this.getCurrentCost();
      if (currentCost >= this.maxCost) {
        return true;
      }
    }
    return false;
  }
  /** Clears all tracked token counts. */
  reset() {
    this.tokenUsage = {};
    this.totalTokens = 0;
  }
};
|
|
10263
|
+
var AxBaseOptimizer = class {
|
|
10264
|
+
// Common AxOptimizerArgs fields
|
|
10265
|
+
studentAI;
|
|
10266
|
+
teacherAI;
|
|
10267
|
+
examples;
|
|
10268
|
+
validationSet;
|
|
10269
|
+
targetScore;
|
|
10270
|
+
minSuccessRate;
|
|
10271
|
+
onProgress;
|
|
10272
|
+
onEarlyStop;
|
|
10273
|
+
costTracker;
|
|
10274
|
+
seed;
|
|
10275
|
+
// Checkpointing fields
|
|
10276
|
+
checkpointSave;
|
|
10277
|
+
checkpointLoad;
|
|
10278
|
+
checkpointInterval;
|
|
10279
|
+
resumeFromCheckpoint;
|
|
10280
|
+
// Checkpoint state
|
|
10281
|
+
currentRound = 0;
|
|
10282
|
+
scoreHistory = [];
|
|
10283
|
+
configurationHistory = [];
|
|
10284
|
+
// Common optimization statistics
|
|
10285
|
+
stats;
|
|
10286
|
+
constructor(args) {
|
|
10287
|
+
if (args.examples.length === 0) {
|
|
10288
|
+
throw new Error("No examples found");
|
|
10289
|
+
}
|
|
10290
|
+
this.studentAI = args.studentAI;
|
|
10291
|
+
this.teacherAI = args.teacherAI;
|
|
10292
|
+
this.examples = args.examples;
|
|
10293
|
+
this.validationSet = args.validationSet;
|
|
10294
|
+
this.targetScore = args.targetScore;
|
|
10295
|
+
this.minSuccessRate = args.minSuccessRate;
|
|
10296
|
+
this.onProgress = args.onProgress;
|
|
10297
|
+
this.onEarlyStop = args.onEarlyStop;
|
|
10298
|
+
this.seed = args.seed;
|
|
10299
|
+
this.checkpointSave = args.checkpointSave;
|
|
10300
|
+
this.checkpointLoad = args.checkpointLoad;
|
|
10301
|
+
this.checkpointInterval = args.checkpointInterval ?? 10;
|
|
10302
|
+
this.resumeFromCheckpoint = args.resumeFromCheckpoint;
|
|
10303
|
+
const costTracker = new AxDefaultCostTracker({
|
|
10304
|
+
maxTokens: 1e6
|
|
10305
|
+
});
|
|
10306
|
+
this.costTracker = args.costTracker ?? costTracker;
|
|
10307
|
+
this.stats = this.initializeStats();
|
|
10308
|
+
}
|
|
10309
|
+
/**
|
|
10310
|
+
* Initialize the optimization statistics structure
|
|
10311
|
+
*/
|
|
10312
|
+
initializeStats() {
|
|
10313
|
+
return {
|
|
10314
|
+
totalCalls: 0,
|
|
10315
|
+
successfulDemos: 0,
|
|
10316
|
+
estimatedTokenUsage: 0,
|
|
10317
|
+
earlyStopped: false,
|
|
10318
|
+
resourceUsage: {
|
|
10319
|
+
totalTokens: 0,
|
|
10320
|
+
totalTime: 0,
|
|
10321
|
+
avgLatencyPerEval: 0,
|
|
10322
|
+
costByModel: {}
|
|
10323
|
+
},
|
|
10324
|
+
convergenceInfo: {
|
|
10325
|
+
converged: false,
|
|
10326
|
+
finalImprovement: 0,
|
|
10327
|
+
stagnationRounds: 0,
|
|
10328
|
+
convergenceThreshold: 0.01
|
|
10329
|
+
}
|
|
10330
|
+
};
|
|
10331
|
+
}
|
|
10332
|
+
/**
|
|
10333
|
+
* Set up reproducible random seed if provided
|
|
10334
|
+
*/
|
|
10335
|
+
setupRandomSeed() {
|
|
10336
|
+
if (this.seed !== void 0) {
|
|
10337
|
+
Math.random = (() => {
|
|
10338
|
+
let seed = this.seed;
|
|
10339
|
+
return () => {
|
|
10340
|
+
seed = (seed * 9301 + 49297) % 233280;
|
|
10341
|
+
return seed / 233280;
|
|
10342
|
+
};
|
|
10343
|
+
})();
|
|
10344
|
+
}
|
|
10345
|
+
}
|
|
10346
|
+
/**
|
|
10347
|
+
* Check if optimization should stop early due to cost limits
|
|
10348
|
+
*/
|
|
10349
|
+
checkCostLimits() {
|
|
10350
|
+
return this.costTracker?.isLimitReached() ?? false;
|
|
10351
|
+
}
|
|
10352
|
+
/**
|
|
10353
|
+
* Check if target score has been reached
|
|
10354
|
+
*/
|
|
10355
|
+
checkTargetScore(currentScore) {
|
|
10356
|
+
return this.targetScore !== void 0 && currentScore >= this.targetScore;
|
|
10357
|
+
}
|
|
10358
|
+
/**
|
|
10359
|
+
* Update resource usage statistics
|
|
10360
|
+
*/
|
|
10361
|
+
updateResourceUsage(startTime, tokensUsed = 0) {
|
|
10362
|
+
this.stats.resourceUsage.totalTime = Date.now() - startTime;
|
|
10363
|
+
this.stats.resourceUsage.totalTokens += tokensUsed;
|
|
10364
|
+
if (this.stats.totalCalls > 0) {
|
|
10365
|
+
this.stats.resourceUsage.avgLatencyPerEval = this.stats.resourceUsage.totalTime / this.stats.totalCalls;
|
|
10366
|
+
}
|
|
10367
|
+
}
|
|
10368
|
+
/**
|
|
10369
|
+
* Trigger early stopping with appropriate callbacks
|
|
10370
|
+
*/
|
|
10371
|
+
triggerEarlyStopping(reason, bestScoreRound) {
|
|
10372
|
+
this.stats.earlyStopped = true;
|
|
10373
|
+
this.stats.earlyStopping = {
|
|
10374
|
+
bestScoreRound,
|
|
10375
|
+
patienceExhausted: reason.includes("improvement"),
|
|
10376
|
+
reason
|
|
10377
|
+
};
|
|
10378
|
+
if (this.onEarlyStop) {
|
|
10379
|
+
this.onEarlyStop(reason, this.stats);
|
|
10380
|
+
}
|
|
10381
|
+
}
|
|
10382
|
+
/**
|
|
10383
|
+
* Get the validation set, with fallback to a split of examples
|
|
10384
|
+
*/
|
|
10385
|
+
getValidationSet(options) {
|
|
10386
|
+
return options?.overrideValidationSet || this.validationSet || this.examples.slice(0, Math.floor(this.examples.length * 0.2));
|
|
10387
|
+
}
|
|
10388
|
+
/**
|
|
10389
|
+
* Get the AI service to use for a specific task, preferring teacher when available
|
|
10390
|
+
* @param preferTeacher Whether to prefer teacher AI over student AI
|
|
10391
|
+
* @param options Optional compile options that may override teacher AI
|
|
10392
|
+
* @returns The appropriate AI service to use
|
|
10393
|
+
*/
|
|
10394
|
+
getAIService(preferTeacher = false, options) {
|
|
10395
|
+
if (preferTeacher && options?.overrideTeacherAI) {
|
|
10396
|
+
return options.overrideTeacherAI;
|
|
10397
|
+
}
|
|
10398
|
+
if (preferTeacher && this.teacherAI) {
|
|
10399
|
+
return this.teacherAI;
|
|
10400
|
+
}
|
|
10401
|
+
return this.studentAI;
|
|
10402
|
+
}
|
|
10403
|
+
/**
|
|
10404
|
+
* Check if teacher AI is available (including overrides)
|
|
10405
|
+
* @param options Optional compile options that may override teacher AI
|
|
10406
|
+
* @returns True if teacher AI is configured or overridden
|
|
10407
|
+
*/
|
|
10408
|
+
hasTeacherAI(options) {
|
|
10409
|
+
return options?.overrideTeacherAI !== void 0 || this.teacherAI !== void 0;
|
|
10410
|
+
}
|
|
10411
|
+
/**
|
|
10412
|
+
* Get teacher AI if available, otherwise return student AI
|
|
10413
|
+
* @param options Optional compile options that may override teacher AI
|
|
10414
|
+
* @returns Teacher AI if available, otherwise student AI
|
|
10415
|
+
*/
|
|
10416
|
+
getTeacherOrStudentAI(options) {
|
|
10417
|
+
return options?.overrideTeacherAI || this.teacherAI || this.studentAI;
|
|
10418
|
+
}
|
|
10419
|
+
/**
|
|
10420
|
+
* Execute a task with teacher AI if available, otherwise use student AI
|
|
10421
|
+
* @param task Function that takes an AI service and returns a promise
|
|
10422
|
+
* @param preferTeacher Whether to prefer teacher AI (default: true)
|
|
10423
|
+
* @param options Optional compile options that may override teacher AI
|
|
10424
|
+
* @returns Result of the task execution
|
|
10425
|
+
*/
|
|
10426
|
+
async executeWithTeacher(task, preferTeacher = true, options) {
|
|
10427
|
+
const ai = this.getAIService(preferTeacher, options);
|
|
10428
|
+
return await task(ai);
|
|
10429
|
+
}
|
|
10430
|
+
/**
|
|
10431
|
+
* Get current optimization statistics
|
|
10432
|
+
*/
|
|
10433
|
+
getStats() {
|
|
10434
|
+
return { ...this.stats };
|
|
10435
|
+
}
|
|
10436
|
+
/**
|
|
10437
|
+
* Reset optimizer state for reuse with different programs
|
|
10438
|
+
*/
|
|
10439
|
+
reset() {
|
|
10440
|
+
this.stats = this.initializeStats();
|
|
10441
|
+
this.costTracker?.reset();
|
|
10442
|
+
this.currentRound = 0;
|
|
10443
|
+
this.scoreHistory = [];
|
|
10444
|
+
this.configurationHistory = [];
|
|
10445
|
+
}
|
|
10446
|
+
/**
|
|
10447
|
+
* Basic program validation that can be extended by concrete optimizers
|
|
10448
|
+
*/
|
|
10449
|
+
validateProgram(program) {
|
|
10450
|
+
const issues = [];
|
|
10451
|
+
const suggestions = [];
|
|
10452
|
+
if (!("forward" in program) || typeof program.forward !== "function") {
|
|
10453
|
+
issues.push("Program must have a forward method");
|
|
10454
|
+
}
|
|
10455
|
+
if (this.examples.length < 2) {
|
|
10456
|
+
issues.push("Need at least 2 examples for optimization");
|
|
10457
|
+
suggestions.push("Provide more training examples");
|
|
10458
|
+
}
|
|
10459
|
+
const valSetSize = this.getValidationSet().length;
|
|
10460
|
+
if (valSetSize < 1) {
|
|
10461
|
+
issues.push("Validation set is empty");
|
|
10462
|
+
suggestions.push("Provide examples or a validation set");
|
|
10463
|
+
}
|
|
10464
|
+
return {
|
|
10465
|
+
isValid: issues.length === 0,
|
|
10466
|
+
issues,
|
|
10467
|
+
suggestions
|
|
10468
|
+
};
|
|
10469
|
+
}
|
|
10470
|
+
/**
|
|
10471
|
+
* Multi-objective optimization using Pareto frontier
|
|
10472
|
+
* Default implementation that leverages the single-objective compile method
|
|
10473
|
+
* @param program The program to optimize
|
|
10474
|
+
* @param metricFn Multi-objective metric function that returns multiple scores
|
|
10475
|
+
* @param options Optional configuration options
|
|
10476
|
+
* @returns Pareto optimization result with frontier of non-dominated solutions
|
|
10477
|
+
*/
|
|
10478
|
+
async compilePareto(program, metricFn, options) {
|
|
10479
|
+
const startTime = Date.now();
|
|
10480
|
+
if (options?.verbose) {
|
|
10481
|
+
console.log("Starting Pareto optimization using base implementation");
|
|
10482
|
+
console.log("This will run multiple single-objective optimizations");
|
|
10483
|
+
}
|
|
10484
|
+
const solutions = await this.generateWeightedSolutions(
|
|
10485
|
+
program,
|
|
10486
|
+
metricFn,
|
|
10487
|
+
options
|
|
10488
|
+
);
|
|
10489
|
+
const constraintSolutions = await this.generateConstraintSolutions(
|
|
10490
|
+
program,
|
|
10491
|
+
metricFn,
|
|
10492
|
+
options
|
|
10493
|
+
);
|
|
10494
|
+
const allSolutions = [...solutions, ...constraintSolutions];
|
|
10495
|
+
if (options?.verbose) {
|
|
10496
|
+
console.log(`Generated ${allSolutions.length} candidate solutions`);
|
|
10497
|
+
}
|
|
10498
|
+
const paretoFront = this.findParetoFrontier(allSolutions);
|
|
10499
|
+
const hypervolume = this.calculateHypervolume(paretoFront);
|
|
10500
|
+
if (options?.verbose) {
|
|
10501
|
+
console.log(`Found ${paretoFront.length} non-dominated solutions`);
|
|
10502
|
+
console.log(`Hypervolume: ${hypervolume?.toFixed(4) || "N/A"}`);
|
|
10503
|
+
}
|
|
10504
|
+
this.updateResourceUsage(startTime);
|
|
10505
|
+
this.stats.convergenceInfo.converged = true;
|
|
10506
|
+
const bestScore = paretoFront.length > 0 ? Math.max(
|
|
10507
|
+
...paretoFront.map((sol) => Math.max(...Object.values(sol.scores)))
|
|
10508
|
+
) : 0;
|
|
10509
|
+
return {
|
|
10510
|
+
demos: paretoFront.length > 0 ? [...paretoFront[0].demos] : void 0,
|
|
10511
|
+
stats: this.stats,
|
|
10512
|
+
bestScore,
|
|
10513
|
+
paretoFront,
|
|
10514
|
+
hypervolume,
|
|
10515
|
+
paretoFrontSize: paretoFront.length,
|
|
10516
|
+
finalConfiguration: {
|
|
10517
|
+
paretoFrontSize: paretoFront.length,
|
|
10518
|
+
hypervolume,
|
|
10519
|
+
strategy: "weighted_combinations_and_constraints",
|
|
10520
|
+
numSolutions: allSolutions.length
|
|
10521
|
+
}
|
|
10522
|
+
};
|
|
10523
|
+
}
|
|
10524
|
+
/**
|
|
10525
|
+
* Generate solutions using different weighted combinations of objectives
|
|
10526
|
+
*/
|
|
10527
|
+
async generateWeightedSolutions(program, metricFn, options) {
|
|
10528
|
+
const solutions = [];
|
|
10529
|
+
const sampleExample = this.examples[0];
|
|
10530
|
+
const samplePrediction = await program.forward(
|
|
10531
|
+
this.studentAI,
|
|
10532
|
+
sampleExample
|
|
10533
|
+
);
|
|
10534
|
+
const sampleScores = await metricFn({
|
|
10535
|
+
prediction: samplePrediction,
|
|
10536
|
+
example: sampleExample
|
|
10537
|
+
});
|
|
10538
|
+
const objectives = Object.keys(sampleScores);
|
|
10539
|
+
if (options?.verbose) {
|
|
10540
|
+
console.log(`Detected objectives: ${objectives.join(", ")}`);
|
|
10541
|
+
}
|
|
10542
|
+
const weightCombinations = this.generateWeightCombinations(objectives);
|
|
10543
|
+
for (let i = 0; i < weightCombinations.length; i++) {
|
|
10544
|
+
const weights = weightCombinations[i];
|
|
10545
|
+
if (options?.verbose) {
|
|
10546
|
+
console.log(`Optimizing with weights: ${JSON.stringify(weights)}`);
|
|
10547
|
+
}
|
|
10548
|
+
const weightedMetric = async ({ prediction, example }) => {
|
|
10549
|
+
const scores = await metricFn({ prediction, example });
|
|
10550
|
+
let weightedScore = 0;
|
|
10551
|
+
for (const [objective, score] of Object.entries(scores)) {
|
|
10552
|
+
weightedScore += score * (weights[objective] || 0);
|
|
10553
|
+
}
|
|
10554
|
+
return weightedScore;
|
|
10555
|
+
};
|
|
10556
|
+
try {
|
|
10557
|
+
const result = await this.compile(program, weightedMetric, {
|
|
10558
|
+
...options,
|
|
10559
|
+
verbose: false
|
|
10560
|
+
// Suppress inner optimization logs
|
|
10561
|
+
});
|
|
10562
|
+
const scores = await this.evaluateWithMultiObjective(
|
|
10563
|
+
program,
|
|
10564
|
+
result,
|
|
10565
|
+
metricFn
|
|
10566
|
+
);
|
|
10567
|
+
solutions.push({
|
|
10568
|
+
scores,
|
|
10569
|
+
demos: result.demos,
|
|
10570
|
+
configuration: {
|
|
10571
|
+
...result.finalConfiguration,
|
|
10572
|
+
weights,
|
|
10573
|
+
strategy: "weighted_combination"
|
|
10574
|
+
}
|
|
10575
|
+
});
|
|
10576
|
+
} catch (error) {
|
|
10577
|
+
if (options?.verbose) {
|
|
10578
|
+
console.warn(
|
|
10579
|
+
`Failed optimization with weights ${JSON.stringify(weights)}:`,
|
|
10580
|
+
error
|
|
10581
|
+
);
|
|
10582
|
+
}
|
|
10583
|
+
continue;
|
|
10584
|
+
}
|
|
10585
|
+
}
|
|
10586
|
+
return solutions;
|
|
10587
|
+
}
|
|
10588
|
+
/**
|
|
10589
|
+
* Generate solutions using constraint-based optimization
|
|
10590
|
+
*/
|
|
10591
|
+
async generateConstraintSolutions(program, metricFn, options) {
|
|
10592
|
+
const solutions = [];
|
|
10593
|
+
const sampleExample = this.examples[0];
|
|
10594
|
+
const samplePrediction = await program.forward(
|
|
10595
|
+
this.studentAI,
|
|
10596
|
+
sampleExample
|
|
10597
|
+
);
|
|
10598
|
+
const sampleScores = await metricFn({
|
|
10599
|
+
prediction: samplePrediction,
|
|
10600
|
+
example: sampleExample
|
|
10601
|
+
});
|
|
10602
|
+
const objectives = Object.keys(sampleScores);
|
|
10603
|
+
for (const primaryObjective of objectives) {
|
|
10604
|
+
if (options?.verbose) {
|
|
10605
|
+
console.log(
|
|
10606
|
+
`Optimizing ${primaryObjective} with constraints on other objectives`
|
|
10607
|
+
);
|
|
10608
|
+
}
|
|
10609
|
+
const constraintMetric = async ({ prediction, example }) => {
|
|
10610
|
+
const scores = await metricFn({ prediction, example });
|
|
10611
|
+
const primaryScore = scores[primaryObjective] || 0;
|
|
10612
|
+
let penalty = 0;
|
|
10613
|
+
for (const [objective, score] of Object.entries(scores)) {
|
|
10614
|
+
if (objective !== primaryObjective) {
|
|
10615
|
+
if (score < 0.3) {
|
|
10616
|
+
penalty += (0.3 - score) * 2;
|
|
10617
|
+
}
|
|
10618
|
+
}
|
|
10619
|
+
}
|
|
10620
|
+
return primaryScore - penalty;
|
|
10621
|
+
};
|
|
10622
|
+
try {
|
|
10623
|
+
const result = await this.compile(program, constraintMetric, {
|
|
10624
|
+
...options,
|
|
10625
|
+
verbose: false
|
|
10626
|
+
});
|
|
10627
|
+
const scores = await this.evaluateWithMultiObjective(
|
|
10628
|
+
program,
|
|
10629
|
+
result,
|
|
10630
|
+
metricFn
|
|
10631
|
+
);
|
|
10632
|
+
solutions.push({
|
|
10633
|
+
scores,
|
|
10634
|
+
demos: result.demos,
|
|
10635
|
+
configuration: {
|
|
10636
|
+
...result.finalConfiguration,
|
|
10637
|
+
primaryObjective,
|
|
10638
|
+
strategy: "constraint_based"
|
|
10639
|
+
}
|
|
10640
|
+
});
|
|
10641
|
+
} catch (error) {
|
|
10642
|
+
if (options?.verbose) {
|
|
10643
|
+
console.warn(
|
|
10644
|
+
`Failed constraint optimization for ${primaryObjective}:`,
|
|
10645
|
+
error
|
|
10646
|
+
);
|
|
10647
|
+
}
|
|
10648
|
+
continue;
|
|
10649
|
+
}
|
|
10650
|
+
}
|
|
10651
|
+
return solutions;
|
|
10652
|
+
}
|
|
10653
|
+
/**
|
|
10654
|
+
* Generate different weight combinations for objectives
|
|
10655
|
+
*/
|
|
10656
|
+
generateWeightCombinations(objectives) {
|
|
10657
|
+
const combinations = [];
|
|
10658
|
+
for (const objective of objectives) {
|
|
10659
|
+
const weights = {};
|
|
10660
|
+
for (const obj of objectives) {
|
|
10661
|
+
weights[obj] = obj === objective ? 1 : 0;
|
|
10662
|
+
}
|
|
10663
|
+
combinations.push(weights);
|
|
10664
|
+
}
|
|
10665
|
+
const equalWeights = {};
|
|
10666
|
+
for (const objective of objectives) {
|
|
10667
|
+
equalWeights[objective] = 1 / objectives.length;
|
|
10668
|
+
}
|
|
10669
|
+
combinations.push(equalWeights);
|
|
10670
|
+
if (objectives.length === 2) {
|
|
10671
|
+
const [obj1, obj2] = objectives;
|
|
10672
|
+
for (let w1 = 0.1; w1 <= 0.9; w1 += 0.2) {
|
|
10673
|
+
const w2 = 1 - w1;
|
|
10674
|
+
combinations.push({ [obj1]: w1, [obj2]: w2 });
|
|
10675
|
+
}
|
|
10676
|
+
}
|
|
10677
|
+
if (objectives.length === 3) {
|
|
10678
|
+
const [obj1, obj2, obj3] = objectives;
|
|
10679
|
+
combinations.push(
|
|
10680
|
+
{ [obj1]: 0.5, [obj2]: 0.3, [obj3]: 0.2 },
|
|
10681
|
+
{ [obj1]: 0.3, [obj2]: 0.5, [obj3]: 0.2 },
|
|
10682
|
+
{ [obj1]: 0.2, [obj2]: 0.3, [obj3]: 0.5 }
|
|
10683
|
+
);
|
|
10684
|
+
}
|
|
10685
|
+
return combinations;
|
|
10686
|
+
}
|
|
10687
|
+
/**
|
|
10688
|
+
* Evaluate a single-objective result with multi-objective metrics
|
|
10689
|
+
*/
|
|
10690
|
+
async evaluateWithMultiObjective(program, result, metricFn) {
|
|
10691
|
+
const valSet = this.getValidationSet();
|
|
10692
|
+
const allScores = {};
|
|
10693
|
+
const testProgram = { ...program };
|
|
10694
|
+
if (result.demos && "setDemos" in testProgram) {
|
|
10695
|
+
;
|
|
10696
|
+
testProgram.setDemos(result.demos);
|
|
10697
|
+
}
|
|
10698
|
+
const evalSet = valSet.slice(0, Math.min(5, valSet.length));
|
|
10699
|
+
for (const example of evalSet) {
|
|
10700
|
+
try {
|
|
10701
|
+
const prediction = await testProgram.forward(
|
|
10702
|
+
this.studentAI,
|
|
10703
|
+
example
|
|
10704
|
+
);
|
|
10705
|
+
const scores = await metricFn({ prediction, example });
|
|
10706
|
+
for (const [objective, score] of Object.entries(scores)) {
|
|
10707
|
+
if (!allScores[objective]) {
|
|
10708
|
+
allScores[objective] = [];
|
|
10709
|
+
}
|
|
10710
|
+
allScores[objective].push(score);
|
|
10711
|
+
}
|
|
10712
|
+
} catch {
|
|
10713
|
+
continue;
|
|
10714
|
+
}
|
|
10715
|
+
}
|
|
10716
|
+
const avgScores = {};
|
|
10717
|
+
for (const [objective, scores] of Object.entries(allScores)) {
|
|
10718
|
+
avgScores[objective] = scores.length > 0 ? scores.reduce((sum, score) => sum + score, 0) / scores.length : 0;
|
|
10719
|
+
}
|
|
10720
|
+
return avgScores;
|
|
10721
|
+
}
|
|
10722
|
+
/**
|
|
10723
|
+
* Find the Pareto frontier from a set of solutions
|
|
10724
|
+
*/
|
|
10725
|
+
findParetoFrontier(solutions) {
|
|
10726
|
+
const paretoFront = [];
|
|
10727
|
+
for (let i = 0; i < solutions.length; i++) {
|
|
10728
|
+
const solutionA = solutions[i];
|
|
10729
|
+
let isDominated = false;
|
|
10730
|
+
let dominatedCount = 0;
|
|
10731
|
+
for (let j = 0; j < solutions.length; j++) {
|
|
10732
|
+
if (i === j) continue;
|
|
10733
|
+
const solutionB = solutions[j];
|
|
10734
|
+
if (this.dominates(solutionB.scores, solutionA.scores)) {
|
|
10735
|
+
isDominated = true;
|
|
10736
|
+
break;
|
|
10737
|
+
}
|
|
10738
|
+
if (this.dominates(solutionA.scores, solutionB.scores)) {
|
|
10739
|
+
dominatedCount++;
|
|
10740
|
+
}
|
|
10741
|
+
}
|
|
10742
|
+
if (!isDominated) {
|
|
10743
|
+
paretoFront.push({
|
|
10744
|
+
demos: solutionA.demos || [],
|
|
10745
|
+
scores: solutionA.scores,
|
|
10746
|
+
configuration: solutionA.configuration,
|
|
10747
|
+
dominatedSolutions: dominatedCount
|
|
10748
|
+
});
|
|
10749
|
+
}
|
|
10750
|
+
}
|
|
10751
|
+
return paretoFront;
|
|
10752
|
+
}
|
|
10753
|
+
/**
|
|
10754
|
+
* Check if solution A dominates solution B
|
|
10755
|
+
* A dominates B if A is better or equal in all objectives and strictly better in at least one
|
|
10756
|
+
*/
|
|
10757
|
+
dominates(scoresA, scoresB) {
|
|
10758
|
+
const objectives = Object.keys(scoresA);
|
|
10759
|
+
let atLeastAsGood = true;
|
|
10760
|
+
let strictlyBetter = false;
|
|
10761
|
+
for (const objective of objectives) {
|
|
10762
|
+
const scoreA = scoresA[objective] || 0;
|
|
10763
|
+
const scoreB = scoresB[objective] || 0;
|
|
10764
|
+
if (scoreA < scoreB) {
|
|
10765
|
+
atLeastAsGood = false;
|
|
10766
|
+
break;
|
|
10767
|
+
}
|
|
10768
|
+
if (scoreA > scoreB) {
|
|
10769
|
+
strictlyBetter = true;
|
|
10770
|
+
}
|
|
10771
|
+
}
|
|
10772
|
+
return atLeastAsGood && strictlyBetter;
|
|
10773
|
+
}
|
|
10774
|
+
/**
|
|
10775
|
+
* Calculate hypervolume of the Pareto frontier
|
|
10776
|
+
* Simplified implementation using reference point at origin
|
|
10777
|
+
*/
|
|
10778
|
+
calculateHypervolume(paretoFront) {
|
|
10779
|
+
if (paretoFront.length === 0) return void 0;
|
|
10780
|
+
const firstSolution = paretoFront[0];
|
|
10781
|
+
const objectives = Object.keys(firstSolution.scores);
|
|
10782
|
+
if (objectives.length === 2) {
|
|
10783
|
+
const [obj1, obj2] = objectives;
|
|
10784
|
+
let hypervolume = 0;
|
|
10785
|
+
const sortedSolutions = [...paretoFront].sort(
|
|
10786
|
+
(a, b) => (b.scores[obj1] || 0) - (a.scores[obj1] || 0)
|
|
10787
|
+
);
|
|
10788
|
+
let prevScore2 = 0;
|
|
10789
|
+
for (const solution of sortedSolutions) {
|
|
10790
|
+
const score1 = solution.scores[obj1] || 0;
|
|
10791
|
+
const score2 = solution.scores[obj2] || 0;
|
|
10792
|
+
hypervolume += score1 * (score2 - prevScore2);
|
|
10793
|
+
prevScore2 = Math.max(prevScore2, score2);
|
|
10794
|
+
}
|
|
10795
|
+
return hypervolume;
|
|
10796
|
+
}
|
|
10797
|
+
return void 0;
|
|
10798
|
+
}
|
|
10799
|
+
/**
|
|
10800
|
+
* Save current optimization state to checkpoint
|
|
10801
|
+
*/
|
|
10802
|
+
async saveCheckpoint(optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
|
|
10803
|
+
const saveFn = options?.overrideCheckpointSave || this.checkpointSave;
|
|
10804
|
+
if (!saveFn) return void 0;
|
|
10805
|
+
const checkpoint = {
|
|
10806
|
+
version: "1.0.0",
|
|
10807
|
+
timestamp: Date.now(),
|
|
10808
|
+
optimizerType,
|
|
10809
|
+
optimizerConfig,
|
|
10810
|
+
currentRound: this.currentRound,
|
|
10811
|
+
totalRounds: this.stats.resourceUsage.totalTime > 0 ? this.currentRound : 0,
|
|
10812
|
+
bestScore,
|
|
10813
|
+
bestConfiguration,
|
|
10814
|
+
scoreHistory: [...this.scoreHistory],
|
|
10815
|
+
configurationHistory: [...this.configurationHistory],
|
|
10816
|
+
stats: { ...this.stats },
|
|
10817
|
+
optimizerState,
|
|
10818
|
+
examples: this.examples,
|
|
10819
|
+
validationSet: this.validationSet
|
|
10820
|
+
};
|
|
10821
|
+
return await saveFn(checkpoint);
|
|
10822
|
+
}
|
|
10823
|
+
/**
|
|
10824
|
+
* Load optimization state from checkpoint
|
|
10825
|
+
*/
|
|
10826
|
+
async loadCheckpoint(checkpointId, options) {
|
|
10827
|
+
const loadFn = options?.overrideCheckpointLoad || this.checkpointLoad;
|
|
10828
|
+
if (!loadFn) return null;
|
|
10829
|
+
return await loadFn(checkpointId);
|
|
10830
|
+
}
|
|
10831
|
+
/**
|
|
10832
|
+
* Restore optimizer state from checkpoint
|
|
10833
|
+
*/
|
|
10834
|
+
restoreFromCheckpoint(checkpoint) {
|
|
10835
|
+
this.currentRound = checkpoint.currentRound;
|
|
10836
|
+
this.scoreHistory = [...checkpoint.scoreHistory];
|
|
10837
|
+
this.configurationHistory = [...checkpoint.configurationHistory];
|
|
10838
|
+
this.stats = { ...checkpoint.stats };
|
|
10839
|
+
}
|
|
10840
|
+
/**
|
|
10841
|
+
* Check if checkpoint should be saved
|
|
10842
|
+
*/
|
|
10843
|
+
shouldSaveCheckpoint(round, options) {
|
|
10844
|
+
const interval = options?.overrideCheckpointInterval || this.checkpointInterval;
|
|
10845
|
+
return interval !== void 0 && round % interval === 0;
|
|
10846
|
+
}
|
|
10847
|
+
/**
|
|
10848
|
+
* Update optimization progress and handle checkpointing
|
|
10849
|
+
*/
|
|
10850
|
+
async updateOptimizationProgress(round, score, configuration, optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
|
|
10851
|
+
this.currentRound = round;
|
|
10852
|
+
this.scoreHistory.push(score);
|
|
10853
|
+
this.configurationHistory.push(configuration);
|
|
10854
|
+
if (this.shouldSaveCheckpoint(round, options)) {
|
|
10855
|
+
await this.saveCheckpoint(
|
|
10856
|
+
optimizerType,
|
|
10857
|
+
optimizerConfig,
|
|
10858
|
+
bestScore,
|
|
10859
|
+
bestConfiguration,
|
|
10860
|
+
optimizerState,
|
|
10861
|
+
options
|
|
10862
|
+
);
|
|
10863
|
+
}
|
|
10864
|
+
}
|
|
10865
|
+
/**
|
|
10866
|
+
* Save final checkpoint on completion
|
|
10867
|
+
*/
|
|
10868
|
+
async saveFinalCheckpoint(optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
|
|
10869
|
+
if (options?.saveCheckpointOnComplete !== false) {
|
|
10870
|
+
await this.saveCheckpoint(
|
|
10871
|
+
optimizerType,
|
|
10872
|
+
optimizerConfig,
|
|
10873
|
+
bestScore,
|
|
10874
|
+
bestConfiguration,
|
|
10875
|
+
{ ...optimizerState, final: true },
|
|
10876
|
+
options
|
|
10877
|
+
);
|
|
10878
|
+
}
|
|
10879
|
+
}
|
|
10880
|
+
};
|
|
10881
|
+
|
|
10196
10882
|
// db/base.ts
|
|
10197
10883
|
var import_api23 = require("@opentelemetry/api");
|
|
10198
10884
|
var AxDBBase = class {
|
|
@@ -11652,11 +12338,7 @@ var AxMCPStreambleHTTPTransport = class {
|
|
|
11652
12338
|
};
|
|
11653
12339
|
|
|
11654
12340
|
// dsp/optimizers/bootstrapFewshot.ts
|
|
11655
|
-
var AxBootstrapFewShot = class {
|
|
11656
|
-
ai;
|
|
11657
|
-
teacherAI;
|
|
11658
|
-
program;
|
|
11659
|
-
examples;
|
|
12341
|
+
var AxBootstrapFewShot = class extends AxBaseOptimizer {
|
|
11660
12342
|
maxRounds;
|
|
11661
12343
|
maxDemos;
|
|
11662
12344
|
maxExamples;
|
|
@@ -11667,37 +12349,20 @@ var AxBootstrapFewShot = class {
|
|
|
11667
12349
|
verboseMode;
|
|
11668
12350
|
debugMode;
|
|
11669
12351
|
traces = [];
|
|
11670
|
-
|
|
11671
|
-
|
|
11672
|
-
|
|
11673
|
-
|
|
11674
|
-
|
|
11675
|
-
|
|
11676
|
-
|
|
11677
|
-
|
|
11678
|
-
|
|
11679
|
-
|
|
11680
|
-
options
|
|
11681
|
-
|
|
11682
|
-
|
|
11683
|
-
|
|
11684
|
-
}
|
|
11685
|
-
const bootstrapOptions = options;
|
|
11686
|
-
this.maxRounds = bootstrapOptions?.maxRounds ?? 3;
|
|
11687
|
-
this.maxDemos = bootstrapOptions?.maxDemos ?? 4;
|
|
11688
|
-
this.maxExamples = bootstrapOptions?.maxExamples ?? 16;
|
|
11689
|
-
this.batchSize = bootstrapOptions?.batchSize ?? 1;
|
|
11690
|
-
this.earlyStoppingPatience = bootstrapOptions?.earlyStoppingPatience ?? 0;
|
|
11691
|
-
this.costMonitoring = bootstrapOptions?.costMonitoring ?? false;
|
|
11692
|
-
this.maxTokensPerGeneration = bootstrapOptions?.maxTokensPerGeneration ?? 0;
|
|
11693
|
-
this.verboseMode = bootstrapOptions?.verboseMode ?? true;
|
|
11694
|
-
this.debugMode = bootstrapOptions?.debugMode ?? false;
|
|
11695
|
-
this.ai = ai;
|
|
11696
|
-
this.teacherAI = bootstrapOptions?.teacherAI;
|
|
11697
|
-
this.program = program;
|
|
11698
|
-
this.examples = examples;
|
|
11699
|
-
}
|
|
11700
|
-
async compileRound(roundIndex, metricFn, options) {
|
|
12352
|
+
constructor(args) {
|
|
12353
|
+
super(args);
|
|
12354
|
+
const options = args.options || {};
|
|
12355
|
+
this.maxRounds = options.maxRounds ?? 3;
|
|
12356
|
+
this.maxDemos = options.maxDemos ?? 4;
|
|
12357
|
+
this.maxExamples = options.maxExamples ?? 16;
|
|
12358
|
+
this.batchSize = options.batchSize ?? 1;
|
|
12359
|
+
this.earlyStoppingPatience = options.earlyStoppingPatience ?? 0;
|
|
12360
|
+
this.costMonitoring = options.costMonitoring ?? false;
|
|
12361
|
+
this.maxTokensPerGeneration = options.maxTokensPerGeneration ?? 0;
|
|
12362
|
+
this.verboseMode = options.verboseMode ?? true;
|
|
12363
|
+
this.debugMode = options.debugMode ?? false;
|
|
12364
|
+
}
|
|
12365
|
+
async compileRound(program, roundIndex, metricFn, options) {
|
|
11701
12366
|
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
11702
12367
|
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
11703
12368
|
const aiOpt = {
|
|
@@ -11720,20 +12385,20 @@ var AxBootstrapFewShot = class {
|
|
|
11720
12385
|
continue;
|
|
11721
12386
|
}
|
|
11722
12387
|
const exList = examples.filter((e) => e !== ex);
|
|
11723
|
-
|
|
11724
|
-
const aiService = this.
|
|
12388
|
+
program.setExamples(exList);
|
|
12389
|
+
const aiService = this.getTeacherOrStudentAI();
|
|
11725
12390
|
this.stats.totalCalls++;
|
|
11726
12391
|
let res;
|
|
11727
12392
|
let error;
|
|
11728
12393
|
try {
|
|
11729
|
-
res = await
|
|
12394
|
+
res = await program.forward(aiService, ex, aiOpt);
|
|
11730
12395
|
if (this.costMonitoring) {
|
|
11731
12396
|
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
11732
12397
|
}
|
|
11733
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
12398
|
+
const score = await metricFn({ prediction: res, example: ex });
|
|
11734
12399
|
const success = score >= 0.5;
|
|
11735
12400
|
if (success) {
|
|
11736
|
-
this.traces = [...this.traces, ...
|
|
12401
|
+
this.traces = [...this.traces, ...program.getTraces()];
|
|
11737
12402
|
this.stats.successfulDemos++;
|
|
11738
12403
|
}
|
|
11739
12404
|
} catch (err) {
|
|
@@ -11784,13 +12449,15 @@ var AxBootstrapFewShot = class {
|
|
|
11784
12449
|
if (!this.stats.earlyStopping) {
|
|
11785
12450
|
this.stats.earlyStopping = {
|
|
11786
12451
|
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
11787
|
-
patienceExhausted: false
|
|
12452
|
+
patienceExhausted: false,
|
|
12453
|
+
reason: "No improvement detected"
|
|
11788
12454
|
};
|
|
11789
12455
|
} else if (improvement > 0) {
|
|
11790
12456
|
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
11791
12457
|
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
11792
12458
|
this.stats.earlyStopping.patienceExhausted = true;
|
|
11793
12459
|
this.stats.earlyStopped = true;
|
|
12460
|
+
this.stats.earlyStopping.reason = `No improvement for ${this.earlyStoppingPatience} rounds`;
|
|
11794
12461
|
if (this.verboseMode || this.debugMode) {
|
|
11795
12462
|
console.log(
|
|
11796
12463
|
`
|
|
@@ -11801,37 +12468,38 @@ Early stopping triggered after ${roundIndex + 1} rounds. No improvement for ${th
|
|
|
11801
12468
|
}
|
|
11802
12469
|
}
|
|
11803
12470
|
}
|
|
11804
|
-
async compile(metricFn, options) {
|
|
11805
|
-
const
|
|
11806
|
-
const maxRounds = compileOptions?.maxRounds ?? this.maxRounds;
|
|
12471
|
+
async compile(program, metricFn, options) {
|
|
12472
|
+
const maxRounds = options?.maxIterations ?? this.maxRounds;
|
|
11807
12473
|
this.traces = [];
|
|
11808
|
-
this.
|
|
11809
|
-
totalCalls: 0,
|
|
11810
|
-
successfulDemos: 0,
|
|
11811
|
-
estimatedTokenUsage: 0,
|
|
11812
|
-
earlyStopped: false
|
|
11813
|
-
};
|
|
12474
|
+
this.reset();
|
|
11814
12475
|
for (let i = 0; i < maxRounds; i++) {
|
|
11815
|
-
await this.compileRound(i, metricFn,
|
|
12476
|
+
await this.compileRound(program, i, metricFn, options);
|
|
11816
12477
|
if (this.stats.earlyStopped) {
|
|
11817
12478
|
break;
|
|
11818
12479
|
}
|
|
11819
12480
|
}
|
|
11820
12481
|
if (this.traces.length === 0) {
|
|
11821
12482
|
throw new Error(
|
|
11822
|
-
"No demonstrations found. Either
|
|
12483
|
+
"No demonstrations found. Either provide more examples or improve the existing ones."
|
|
11823
12484
|
);
|
|
11824
12485
|
}
|
|
11825
12486
|
const demos = groupTracesByKeys(this.traces);
|
|
12487
|
+
let bestScore = 0;
|
|
12488
|
+
if (this.traces.length > 0) {
|
|
12489
|
+
bestScore = this.stats.successfulDemos / Math.max(1, this.stats.totalCalls);
|
|
12490
|
+
}
|
|
11826
12491
|
return {
|
|
11827
12492
|
demos,
|
|
11828
|
-
stats: this.stats
|
|
12493
|
+
stats: this.stats,
|
|
12494
|
+
bestScore,
|
|
12495
|
+
finalConfiguration: {
|
|
12496
|
+
maxRounds: this.maxRounds,
|
|
12497
|
+
maxDemos: this.maxDemos,
|
|
12498
|
+
batchSize: this.batchSize,
|
|
12499
|
+
successRate: bestScore
|
|
12500
|
+
}
|
|
11829
12501
|
};
|
|
11830
12502
|
}
|
|
11831
|
-
// Get optimization statistics
|
|
11832
|
-
getStats() {
|
|
11833
|
-
return this.stats;
|
|
11834
|
-
}
|
|
11835
12503
|
};
|
|
11836
12504
|
function groupTracesByKeys(programTraces) {
|
|
11837
12505
|
const groupedTraces = /* @__PURE__ */ new Map();
|
|
@@ -11846,9 +12514,12 @@ function groupTracesByKeys(programTraces) {
|
|
|
11846
12514
|
}
|
|
11847
12515
|
}
|
|
11848
12516
|
const programDemosArray = [];
|
|
11849
|
-
|
|
11850
|
-
programDemosArray.push({
|
|
11851
|
-
|
|
12517
|
+
groupedTraces.forEach((traces, programId) => {
|
|
12518
|
+
programDemosArray.push({
|
|
12519
|
+
traces,
|
|
12520
|
+
programId
|
|
12521
|
+
});
|
|
12522
|
+
});
|
|
11852
12523
|
return programDemosArray;
|
|
11853
12524
|
}
|
|
11854
12525
|
var randomSample = (array, n) => {
|
|
@@ -11867,10 +12538,8 @@ var randomSample = (array, n) => {
|
|
|
11867
12538
|
};
|
|
11868
12539
|
|
|
11869
12540
|
// dsp/optimizers/miproV2.ts
|
|
11870
|
-
var AxMiPRO = class {
|
|
11871
|
-
|
|
11872
|
-
program;
|
|
11873
|
-
examples;
|
|
12541
|
+
var AxMiPRO = class extends AxBaseOptimizer {
|
|
12542
|
+
// MiPRO-specific options
|
|
11874
12543
|
maxBootstrappedDemos;
|
|
11875
12544
|
maxLabeledDemos;
|
|
11876
12545
|
numCandidates;
|
|
@@ -11884,52 +12553,35 @@ var AxMiPRO = class {
|
|
|
11884
12553
|
viewDataBatchSize;
|
|
11885
12554
|
tipAwareProposer;
|
|
11886
12555
|
fewshotAwareProposer;
|
|
11887
|
-
seed;
|
|
11888
12556
|
verbose;
|
|
11889
|
-
bootstrapper;
|
|
11890
12557
|
earlyStoppingTrials;
|
|
11891
12558
|
minImprovementThreshold;
|
|
11892
|
-
|
|
11893
|
-
|
|
11894
|
-
|
|
11895
|
-
|
|
11896
|
-
|
|
11897
|
-
|
|
11898
|
-
|
|
11899
|
-
|
|
11900
|
-
|
|
11901
|
-
|
|
11902
|
-
this.
|
|
11903
|
-
this.
|
|
11904
|
-
this.
|
|
11905
|
-
this.
|
|
11906
|
-
this.
|
|
11907
|
-
this.
|
|
11908
|
-
this.
|
|
11909
|
-
this.
|
|
11910
|
-
this.
|
|
11911
|
-
this.
|
|
11912
|
-
this.
|
|
11913
|
-
this.
|
|
11914
|
-
this.
|
|
11915
|
-
this.
|
|
11916
|
-
this.
|
|
11917
|
-
this.
|
|
11918
|
-
this.minImprovementThreshold = miproOptions.minImprovementThreshold ?? 0.01;
|
|
11919
|
-
this.ai = ai;
|
|
11920
|
-
this.program = program;
|
|
11921
|
-
this.examples = examples;
|
|
11922
|
-
this.bootstrapper = new AxBootstrapFewShot({
|
|
11923
|
-
ai,
|
|
11924
|
-
program,
|
|
11925
|
-
examples,
|
|
11926
|
-
options: {
|
|
11927
|
-
maxDemos: this.maxBootstrappedDemos,
|
|
11928
|
-
maxRounds: 3,
|
|
11929
|
-
// Default, or adjust based on your needs
|
|
11930
|
-
verboseMode: this.verbose
|
|
11931
|
-
}
|
|
11932
|
-
});
|
|
12559
|
+
bayesianOptimization;
|
|
12560
|
+
acquisitionFunction;
|
|
12561
|
+
explorationWeight;
|
|
12562
|
+
constructor(args) {
|
|
12563
|
+
super(args);
|
|
12564
|
+
const options = args.options || {};
|
|
12565
|
+
this.numCandidates = options.numCandidates ?? 5;
|
|
12566
|
+
this.initTemperature = options.initTemperature ?? 0.7;
|
|
12567
|
+
this.maxBootstrappedDemos = options.maxBootstrappedDemos ?? 3;
|
|
12568
|
+
this.maxLabeledDemos = options.maxLabeledDemos ?? 4;
|
|
12569
|
+
this.numTrials = options.numTrials ?? 30;
|
|
12570
|
+
this.minibatch = options.minibatch ?? true;
|
|
12571
|
+
this.minibatchSize = options.minibatchSize ?? 25;
|
|
12572
|
+
this.minibatchFullEvalSteps = options.minibatchFullEvalSteps ?? 10;
|
|
12573
|
+
this.programAwareProposer = options.programAwareProposer ?? true;
|
|
12574
|
+
this.dataAwareProposer = options.dataAwareProposer ?? true;
|
|
12575
|
+
this.viewDataBatchSize = options.viewDataBatchSize ?? 10;
|
|
12576
|
+
this.tipAwareProposer = options.tipAwareProposer ?? true;
|
|
12577
|
+
this.fewshotAwareProposer = options.fewshotAwareProposer ?? true;
|
|
12578
|
+
this.verbose = options.verbose ?? false;
|
|
12579
|
+
this.earlyStoppingTrials = options.earlyStoppingTrials ?? 5;
|
|
12580
|
+
this.minImprovementThreshold = options.minImprovementThreshold ?? 0.01;
|
|
12581
|
+
this.bayesianOptimization = options.bayesianOptimization ?? false;
|
|
12582
|
+
this.acquisitionFunction = options.acquisitionFunction ?? "expected_improvement";
|
|
12583
|
+
this.explorationWeight = options.explorationWeight ?? 0.1;
|
|
12584
|
+
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
11933
12585
|
}
|
|
11934
12586
|
/**
|
|
11935
12587
|
* Configures the optimizer for light, medium, or heavy optimization
|
|
@@ -11973,123 +12625,60 @@ var AxMiPRO = class {
|
|
|
11973
12625
|
];
|
|
11974
12626
|
}
|
|
11975
12627
|
/**
|
|
11976
|
-
* Generates instruction candidates
|
|
12628
|
+
* Generates instruction candidates using the teacher model if available
|
|
12629
|
+
* @param options Optional compile options that may override teacher AI
|
|
11977
12630
|
* @returns Array of generated instruction candidates
|
|
11978
12631
|
*/
|
|
11979
|
-
async proposeInstructionCandidates() {
|
|
12632
|
+
async proposeInstructionCandidates(options) {
|
|
11980
12633
|
const instructions = [];
|
|
11981
|
-
|
|
11982
|
-
if (this.programAwareProposer) {
|
|
11983
|
-
programContext = await this.generateProgramSummary();
|
|
11984
|
-
}
|
|
11985
|
-
let dataContext = "";
|
|
11986
|
-
if (this.dataAwareProposer) {
|
|
11987
|
-
dataContext = await this.generateDataSummary();
|
|
11988
|
-
}
|
|
12634
|
+
const aiToUse = this.getTeacherOrStudentAI(options);
|
|
11989
12635
|
const tips = this.tipAwareProposer ? this.generateTips() : [];
|
|
11990
12636
|
for (let i = 0; i < this.numCandidates; i++) {
|
|
11991
12637
|
const tipIndex = tips.length > 0 ? i % tips.length : -1;
|
|
11992
12638
|
const tipToUse = tipIndex >= 0 ? tips[tipIndex] : "";
|
|
11993
12639
|
const instruction = await this.generateInstruction({
|
|
11994
|
-
programContext,
|
|
11995
|
-
dataContext,
|
|
11996
12640
|
tip: tipToUse,
|
|
11997
|
-
candidateIndex: i
|
|
12641
|
+
candidateIndex: i,
|
|
12642
|
+
ai: aiToUse
|
|
11998
12643
|
});
|
|
11999
12644
|
instructions.push(instruction);
|
|
12000
12645
|
}
|
|
12001
12646
|
return instructions;
|
|
12002
12647
|
}
|
|
12003
|
-
/**
|
|
12004
|
-
* Generates a summary of the program structure for instruction proposal
|
|
12005
|
-
*/
|
|
12006
|
-
async generateProgramSummary() {
|
|
12007
|
-
const prompt = `Summarize the following program structure. Focus on the signatures,
|
|
12008
|
-
input/output fields, and the purpose of each component. Identify key components
|
|
12009
|
-
that might benefit from better instructions.`;
|
|
12010
|
-
const programStr = JSON.stringify(this.program);
|
|
12011
|
-
const response = await this.ai.chat({
|
|
12012
|
-
chatPrompt: [
|
|
12013
|
-
{ role: "system", content: prompt },
|
|
12014
|
-
{ role: "user", content: programStr }
|
|
12015
|
-
],
|
|
12016
|
-
modelConfig: { temperature: 0.2 }
|
|
12017
|
-
});
|
|
12018
|
-
if (response instanceof ReadableStream) {
|
|
12019
|
-
return "";
|
|
12020
|
-
}
|
|
12021
|
-
return response.results[0]?.content || "";
|
|
12022
|
-
}
|
|
12023
|
-
/**
|
|
12024
|
-
* Generates a summary of the dataset for instruction proposal
|
|
12025
|
-
*/
|
|
12026
|
-
async generateDataSummary() {
|
|
12027
|
-
const sampleSize = Math.min(this.viewDataBatchSize, this.examples.length);
|
|
12028
|
-
const sample = this.examples.slice(0, sampleSize);
|
|
12029
|
-
const prompt = `Analyze the following dataset examples and provide a summary
|
|
12030
|
-
of key patterns, input-output relationships, and any specific challenges
|
|
12031
|
-
the data presents. Focus on what makes a good answer and what patterns should
|
|
12032
|
-
be followed.`;
|
|
12033
|
-
const dataStr = JSON.stringify(sample);
|
|
12034
|
-
const response = await this.ai.chat({
|
|
12035
|
-
chatPrompt: [
|
|
12036
|
-
{ role: "system", content: prompt },
|
|
12037
|
-
{ role: "user", content: dataStr }
|
|
12038
|
-
],
|
|
12039
|
-
modelConfig: { temperature: 0.2 }
|
|
12040
|
-
});
|
|
12041
|
-
if (response instanceof ReadableStream) {
|
|
12042
|
-
return "";
|
|
12043
|
-
}
|
|
12044
|
-
return response.results[0]?.content || "";
|
|
12045
|
-
}
|
|
12046
|
-
/**
|
|
12047
|
-
* Generates a specific instruction candidate
|
|
12048
|
-
*/
|
|
12049
12648
|
async generateInstruction({
|
|
12050
|
-
programContext,
|
|
12051
|
-
dataContext,
|
|
12052
12649
|
tip,
|
|
12053
12650
|
candidateIndex
|
|
12054
12651
|
}) {
|
|
12055
|
-
const
|
|
12056
|
-
|
|
12057
|
-
|
|
12058
|
-
|
|
12059
|
-
|
|
12060
|
-
|
|
12061
|
-
|
|
12062
|
-
|
|
12063
|
-
|
|
12064
|
-
|
|
12065
|
-
${tip ? `STYLE TIP: ${tip}
|
|
12066
|
-
|
|
12067
|
-
` : ""}
|
|
12068
|
-
|
|
12069
|
-
Your task is to craft a clear, effective instruction that will help the AI model generate
|
|
12070
|
-
accurate outputs for this task. Instruction #${candidateIndex + 1}/${this.numCandidates}.
|
|
12071
|
-
|
|
12072
|
-
The instruction should be detailed enough to guide the model but not overly prescriptive
|
|
12073
|
-
or restrictive. Focus on what makes a good response rather than listing exact steps.
|
|
12074
|
-
|
|
12075
|
-
INSTRUCTION:`;
|
|
12076
|
-
const response = await this.ai.chat({
|
|
12077
|
-
chatPrompt: [{ role: "user", content: prompt }],
|
|
12078
|
-
modelConfig: { temperature: 0.7 + 0.1 * candidateIndex }
|
|
12079
|
-
});
|
|
12080
|
-
if (response instanceof ReadableStream) {
|
|
12081
|
-
return "";
|
|
12652
|
+
const baseInstructions = [
|
|
12653
|
+
"Analyze the input carefully and provide a detailed response.",
|
|
12654
|
+
"Think step by step and provide a clear answer.",
|
|
12655
|
+
"Consider all aspects of the input before responding.",
|
|
12656
|
+
"Provide a concise but comprehensive response.",
|
|
12657
|
+
"Focus on accuracy and clarity in your response."
|
|
12658
|
+
];
|
|
12659
|
+
let instruction = baseInstructions[candidateIndex % baseInstructions.length] || baseInstructions[0];
|
|
12660
|
+
if (tip) {
|
|
12661
|
+
instruction = `${instruction} ${tip}`;
|
|
12082
12662
|
}
|
|
12083
|
-
return
|
|
12663
|
+
return instruction;
|
|
12084
12664
|
}
|
|
12085
12665
|
/**
|
|
12086
12666
|
* Bootstraps few-shot examples for the program
|
|
12087
12667
|
*/
|
|
12088
|
-
async bootstrapFewShotExamples(metricFn) {
|
|
12668
|
+
async bootstrapFewShotExamples(program, metricFn) {
|
|
12089
12669
|
if (this.verbose) {
|
|
12090
12670
|
console.log("Bootstrapping few-shot examples...");
|
|
12091
12671
|
}
|
|
12092
|
-
const
|
|
12672
|
+
const bootstrapper = new AxBootstrapFewShot({
|
|
12673
|
+
studentAI: this.studentAI,
|
|
12674
|
+
examples: this.examples,
|
|
12675
|
+
options: {
|
|
12676
|
+
maxDemos: this.maxBootstrappedDemos,
|
|
12677
|
+
maxRounds: 3,
|
|
12678
|
+
verboseMode: this.verbose
|
|
12679
|
+
}
|
|
12680
|
+
});
|
|
12681
|
+
const result = await bootstrapper.compile(program, metricFn, {
|
|
12093
12682
|
maxDemos: this.maxBootstrappedDemos
|
|
12094
12683
|
});
|
|
12095
12684
|
return result.demos || [];
|
|
@@ -12113,109 +12702,98 @@ ${dataContext}
|
|
|
12113
12702
|
return selectedExamples;
|
|
12114
12703
|
}
|
|
12115
12704
|
/**
|
|
12116
|
-
* Runs
|
|
12705
|
+
* Runs optimization to find the best combination of few-shot examples and instructions
|
|
12117
12706
|
*/
|
|
12118
|
-
async
|
|
12119
|
-
let bestConfig =
|
|
12120
|
-
let bestScore = Number.NEGATIVE_INFINITY;
|
|
12121
|
-
const evaluatedConfigs = [];
|
|
12122
|
-
const defaultConfig = {
|
|
12707
|
+
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, valset, metricFn, options) {
|
|
12708
|
+
let bestConfig = {
|
|
12123
12709
|
instruction: instructions[0] || "",
|
|
12124
12710
|
bootstrappedDemos: Math.min(1, bootstrappedDemos.length),
|
|
12125
12711
|
labeledExamples: Math.min(1, labeledExamples.length)
|
|
12126
12712
|
};
|
|
12127
|
-
let
|
|
12128
|
-
let
|
|
12129
|
-
const
|
|
12130
|
-
|
|
12131
|
-
|
|
12132
|
-
|
|
12133
|
-
|
|
12134
|
-
|
|
12135
|
-
|
|
12136
|
-
|
|
12713
|
+
let bestScore = 0;
|
|
12714
|
+
let stagnationRounds = 0;
|
|
12715
|
+
const scoreHistory = [];
|
|
12716
|
+
let startRound = 0;
|
|
12717
|
+
if (this.resumeFromCheckpoint) {
|
|
12718
|
+
const checkpoint = await this.loadCheckpoint(
|
|
12719
|
+
this.resumeFromCheckpoint,
|
|
12720
|
+
options
|
|
12721
|
+
);
|
|
12722
|
+
if (checkpoint && checkpoint.optimizerType === "MiPRO") {
|
|
12723
|
+
if (this.verbose || options?.verbose) {
|
|
12724
|
+
console.log(
|
|
12725
|
+
`Resuming from checkpoint at round ${checkpoint.currentRound}`
|
|
12726
|
+
);
|
|
12727
|
+
}
|
|
12728
|
+
this.restoreFromCheckpoint(checkpoint);
|
|
12729
|
+
startRound = checkpoint.currentRound;
|
|
12730
|
+
bestScore = checkpoint.bestScore;
|
|
12731
|
+
bestConfig = checkpoint.bestConfiguration || bestConfig;
|
|
12732
|
+
stagnationRounds = checkpoint.stats.convergenceInfo?.stagnationRounds || 0;
|
|
12733
|
+
}
|
|
12734
|
+
}
|
|
12735
|
+
for (let i = startRound; i < this.numTrials; i++) {
|
|
12137
12736
|
const config = {
|
|
12138
|
-
instruction:
|
|
12139
|
-
bootstrappedDemos: Math.
|
|
12140
|
-
Math.random() * (bootstrappedDemos.length + 1)
|
|
12737
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
12738
|
+
bootstrappedDemos: Math.min(
|
|
12739
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
12740
|
+
this.maxBootstrappedDemos
|
|
12141
12741
|
),
|
|
12142
|
-
labeledExamples: Math.
|
|
12143
|
-
Math.random() * (labeledExamples.length + 1)
|
|
12742
|
+
labeledExamples: Math.min(
|
|
12743
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
12744
|
+
this.maxLabeledDemos
|
|
12144
12745
|
)
|
|
12145
12746
|
};
|
|
12146
|
-
configs.push(config);
|
|
12147
|
-
}
|
|
12148
|
-
for (let i = 0; i < configs.length; i++) {
|
|
12149
|
-
const config = configs[i];
|
|
12150
|
-
if (!config) continue;
|
|
12151
12747
|
const score = await this.evaluateConfig(
|
|
12748
|
+
program,
|
|
12152
12749
|
config,
|
|
12153
12750
|
bootstrappedDemos,
|
|
12154
12751
|
labeledExamples,
|
|
12155
12752
|
valset,
|
|
12156
|
-
metricFn
|
|
12157
|
-
i
|
|
12753
|
+
metricFn
|
|
12158
12754
|
);
|
|
12159
|
-
|
|
12160
|
-
|
|
12755
|
+
scoreHistory.push(score);
|
|
12756
|
+
const improvement = score - bestScore;
|
|
12757
|
+
if (improvement > this.minImprovementThreshold) {
|
|
12161
12758
|
bestScore = score;
|
|
12162
12759
|
bestConfig = config;
|
|
12163
|
-
|
|
12164
|
-
|
|
12165
|
-
|
|
12166
|
-
);
|
|
12167
|
-
}
|
|
12760
|
+
stagnationRounds = 0;
|
|
12761
|
+
} else {
|
|
12762
|
+
stagnationRounds++;
|
|
12168
12763
|
}
|
|
12169
|
-
|
|
12764
|
+
await this.updateOptimizationProgress(
|
|
12170
12765
|
i + 1,
|
|
12171
|
-
|
|
12172
|
-
|
|
12173
|
-
|
|
12174
|
-
|
|
12175
|
-
|
|
12176
|
-
|
|
12177
|
-
|
|
12178
|
-
|
|
12179
|
-
|
|
12180
|
-
|
|
12181
|
-
|
|
12182
|
-
|
|
12183
|
-
|
|
12184
|
-
);
|
|
12185
|
-
const score = await this.evaluateConfig(
|
|
12186
|
-
nextConfig,
|
|
12187
|
-
bootstrappedDemos,
|
|
12188
|
-
labeledExamples,
|
|
12189
|
-
valset,
|
|
12190
|
-
metricFn,
|
|
12191
|
-
i
|
|
12766
|
+
score,
|
|
12767
|
+
config,
|
|
12768
|
+
"MiPRO",
|
|
12769
|
+
this.getConfiguration(),
|
|
12770
|
+
bestScore,
|
|
12771
|
+
bestConfig,
|
|
12772
|
+
{
|
|
12773
|
+
stagnationRounds,
|
|
12774
|
+
bootstrappedDemos: bootstrappedDemos.length,
|
|
12775
|
+
labeledExamples: labeledExamples.length,
|
|
12776
|
+
instructions: instructions.length
|
|
12777
|
+
},
|
|
12778
|
+
options
|
|
12192
12779
|
);
|
|
12193
|
-
|
|
12194
|
-
|
|
12195
|
-
|
|
12196
|
-
|
|
12197
|
-
|
|
12198
|
-
|
|
12199
|
-
|
|
12200
|
-
)
|
|
12201
|
-
|
|
12202
|
-
|
|
12203
|
-
|
|
12204
|
-
|
|
12205
|
-
|
|
12206
|
-
|
|
12207
|
-
|
|
12208
|
-
if (this.verbose) {
|
|
12209
|
-
console.log(
|
|
12210
|
-
`Early stopping triggered after ${i + 1} trials. No improvement for ${trialsWithoutImprovement} trials.`
|
|
12211
|
-
);
|
|
12212
|
-
}
|
|
12213
|
-
break;
|
|
12780
|
+
if (this.onProgress) {
|
|
12781
|
+
this.onProgress({
|
|
12782
|
+
round: i + 1,
|
|
12783
|
+
totalRounds: this.numTrials,
|
|
12784
|
+
currentScore: score,
|
|
12785
|
+
bestScore,
|
|
12786
|
+
tokensUsed: this.stats.resourceUsage.totalTokens,
|
|
12787
|
+
timeElapsed: Date.now(),
|
|
12788
|
+
successfulExamples: this.stats.successfulDemos,
|
|
12789
|
+
totalExamples: this.examples.length,
|
|
12790
|
+
currentConfiguration: config,
|
|
12791
|
+
convergenceInfo: {
|
|
12792
|
+
improvement,
|
|
12793
|
+
stagnationRounds,
|
|
12794
|
+
isConverging: stagnationRounds < this.earlyStoppingTrials
|
|
12214
12795
|
}
|
|
12215
|
-
}
|
|
12216
|
-
lastBestScore = bestScore;
|
|
12217
|
-
trialsWithoutImprovement = 0;
|
|
12218
|
-
}
|
|
12796
|
+
});
|
|
12219
12797
|
}
|
|
12220
12798
|
updateProgressBar(
|
|
12221
12799
|
i + 1,
|
|
@@ -12225,243 +12803,91 @@ ${dataContext}
|
|
|
12225
12803
|
"Running MIPROv2 optimization",
|
|
12226
12804
|
30
|
|
12227
12805
|
);
|
|
12228
|
-
if (this.
|
|
12229
|
-
|
|
12230
|
-
|
|
12231
|
-
`Running full evaluation on best configuration at trial ${i + 1}`
|
|
12232
|
-
);
|
|
12233
|
-
}
|
|
12234
|
-
const fullScore = await this.fullEvaluation(
|
|
12235
|
-
bestConfig,
|
|
12236
|
-
bootstrappedDemos,
|
|
12237
|
-
labeledExamples,
|
|
12238
|
-
valset,
|
|
12239
|
-
metricFn
|
|
12240
|
-
);
|
|
12241
|
-
if (this.verbose) {
|
|
12242
|
-
console.log(`Full evaluation score: ${fullScore}`);
|
|
12243
|
-
}
|
|
12244
|
-
bestScore = fullScore;
|
|
12806
|
+
if (this.checkCostLimits()) {
|
|
12807
|
+
this.triggerEarlyStopping("Cost limit reached", i + 1);
|
|
12808
|
+
break;
|
|
12245
12809
|
}
|
|
12246
|
-
|
|
12247
|
-
|
|
12248
|
-
|
|
12249
|
-
|
|
12250
|
-
"Optimization failed to find any valid configurations, using default fallback configuration"
|
|
12810
|
+
if (stagnationRounds >= this.earlyStoppingTrials) {
|
|
12811
|
+
this.triggerEarlyStopping(
|
|
12812
|
+
`No improvement for ${this.earlyStoppingTrials} trials`,
|
|
12813
|
+
i - stagnationRounds + 1
|
|
12251
12814
|
);
|
|
12815
|
+
break;
|
|
12252
12816
|
}
|
|
12253
|
-
|
|
12254
|
-
|
|
12255
|
-
|
|
12256
|
-
|
|
12257
|
-
bootstrappedDemos,
|
|
12258
|
-
labeledExamples,
|
|
12259
|
-
valset,
|
|
12260
|
-
metricFn,
|
|
12261
|
-
this.numTrials - 1
|
|
12817
|
+
if (this.checkTargetScore(bestScore)) {
|
|
12818
|
+
this.triggerEarlyStopping(
|
|
12819
|
+
`Target score ${this.targetScore} reached`,
|
|
12820
|
+
i + 1
|
|
12262
12821
|
);
|
|
12263
|
-
|
|
12264
|
-
if (this.verbose) {
|
|
12265
|
-
console.error("Error evaluating default configuration:", err);
|
|
12266
|
-
}
|
|
12267
|
-
bestScore = 0;
|
|
12822
|
+
break;
|
|
12268
12823
|
}
|
|
12269
12824
|
}
|
|
12825
|
+
this.stats.convergenceInfo.stagnationRounds = stagnationRounds;
|
|
12826
|
+
this.stats.convergenceInfo.finalImprovement = scoreHistory.length > 1 ? bestScore - scoreHistory[0] : 0;
|
|
12827
|
+
this.stats.convergenceInfo.converged = stagnationRounds < this.earlyStoppingTrials;
|
|
12270
12828
|
return { bestConfig, bestScore };
|
|
12271
12829
|
}
|
|
12272
|
-
|
|
12273
|
-
|
|
12274
|
-
*/
|
|
12275
|
-
async evaluateConfig(config, bootstrappedDemos, labeledExamples, valset, metricFn, trialIndex) {
|
|
12830
|
+
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, valset, metricFn) {
|
|
12831
|
+
const testProgram = { ...program };
|
|
12276
12832
|
this.applyConfigToProgram(
|
|
12277
|
-
|
|
12833
|
+
testProgram,
|
|
12278
12834
|
config,
|
|
12279
12835
|
bootstrappedDemos,
|
|
12280
12836
|
labeledExamples
|
|
12281
12837
|
);
|
|
12282
|
-
let
|
|
12283
|
-
|
|
12284
|
-
|
|
12285
|
-
const minibatchEvalSet = [];
|
|
12286
|
-
for (let j = 0; j < this.minibatchSize; j++) {
|
|
12287
|
-
const idx = (startIdx + j) % valset.length;
|
|
12288
|
-
const example = valset[idx];
|
|
12289
|
-
if (example) {
|
|
12290
|
-
minibatchEvalSet.push(example);
|
|
12291
|
-
}
|
|
12292
|
-
}
|
|
12293
|
-
evalSet = minibatchEvalSet;
|
|
12294
|
-
}
|
|
12295
|
-
let sumOfScores = 0;
|
|
12838
|
+
let totalScore = 0;
|
|
12839
|
+
let count = 0;
|
|
12840
|
+
const evalSet = valset.slice(0, Math.min(5, valset.length));
|
|
12296
12841
|
for (const example of evalSet) {
|
|
12297
12842
|
try {
|
|
12298
|
-
const prediction = await
|
|
12299
|
-
|
|
12300
|
-
|
|
12301
|
-
|
|
12302
|
-
|
|
12303
|
-
|
|
12304
|
-
|
|
12305
|
-
|
|
12306
|
-
|
|
12307
|
-
|
|
12308
|
-
return sumOfScores / evalSet.length;
|
|
12309
|
-
}
|
|
12310
|
-
/**
|
|
12311
|
-
* Run full evaluation on the entire validation set
|
|
12312
|
-
*/
|
|
12313
|
-
async fullEvaluation(config, bootstrappedDemos, labeledExamples, valset, metricFn) {
|
|
12314
|
-
this.applyConfigToProgram(
|
|
12315
|
-
this.program,
|
|
12316
|
-
config,
|
|
12317
|
-
bootstrappedDemos,
|
|
12318
|
-
labeledExamples
|
|
12319
|
-
);
|
|
12320
|
-
let sumOfScores = 0;
|
|
12321
|
-
for (const example of valset) {
|
|
12322
|
-
try {
|
|
12323
|
-
const prediction = await this.program.forward(this.ai, example);
|
|
12324
|
-
const score = metricFn({ prediction, example });
|
|
12325
|
-
sumOfScores += score;
|
|
12326
|
-
} catch (err) {
|
|
12327
|
-
if (this.verbose) {
|
|
12328
|
-
console.error("Error evaluating example:", err);
|
|
12329
|
-
}
|
|
12843
|
+
const prediction = await testProgram.forward(
|
|
12844
|
+
this.studentAI,
|
|
12845
|
+
example
|
|
12846
|
+
);
|
|
12847
|
+
const score = await metricFn({ prediction, example });
|
|
12848
|
+
totalScore += score;
|
|
12849
|
+
count++;
|
|
12850
|
+
this.stats.totalCalls++;
|
|
12851
|
+
} catch {
|
|
12852
|
+
continue;
|
|
12330
12853
|
}
|
|
12331
12854
|
}
|
|
12332
|
-
|
|
12333
|
-
return sumOfScores / valset.length;
|
|
12855
|
+
return count > 0 ? totalScore / count : 0;
|
|
12334
12856
|
}
|
|
12335
|
-
/**
|
|
12336
|
-
* Implements a Bayesian-inspired selection of the next configuration to try
|
|
12337
|
-
* This is a simplified version using Upper Confidence Bound (UCB) strategy
|
|
12338
|
-
*/
|
|
12339
|
-
selectNextConfiguration(evaluatedConfigs, maxBootstrappedDemos, maxLabeledExamples, instructions) {
|
|
12340
|
-
if (evaluatedConfigs.length < 5) {
|
|
12341
|
-
const instructionIndex = Math.floor(Math.random() * instructions.length);
|
|
12342
|
-
return {
|
|
12343
|
-
instruction: instructions[instructionIndex] || "",
|
|
12344
|
-
bootstrappedDemos: Math.floor(
|
|
12345
|
-
Math.random() * (maxBootstrappedDemos + 1)
|
|
12346
|
-
),
|
|
12347
|
-
labeledExamples: Math.floor(Math.random() * (maxLabeledExamples + 1))
|
|
12348
|
-
};
|
|
12349
|
-
}
|
|
12350
|
-
const sortedConfigs = [...evaluatedConfigs].sort(
|
|
12351
|
-
(a, b) => b.score - a.score
|
|
12352
|
-
);
|
|
12353
|
-
const topConfigs = sortedConfigs.slice(0, Math.min(3, sortedConfigs.length));
|
|
12354
|
-
const meanBootstrappedDemos = topConfigs.reduce((sum, c) => sum + c.config.bootstrappedDemos, 0) / topConfigs.length;
|
|
12355
|
-
const meanLabeledExamples = topConfigs.reduce((sum, c) => sum + c.config.labeledExamples, 0) / topConfigs.length;
|
|
12356
|
-
const popularInstructions = topConfigs.map((c) => c.config.instruction);
|
|
12357
|
-
const explorationFactor = Math.max(
|
|
12358
|
-
0.2,
|
|
12359
|
-
1 - evaluatedConfigs.length / this.numTrials
|
|
12360
|
-
);
|
|
12361
|
-
let newBootstrappedDemos;
|
|
12362
|
-
let newLabeledExamples;
|
|
12363
|
-
let newInstruction;
|
|
12364
|
-
if (Math.random() < 0.7) {
|
|
12365
|
-
newBootstrappedDemos = Math.min(
|
|
12366
|
-
maxBootstrappedDemos,
|
|
12367
|
-
Math.max(
|
|
12368
|
-
0,
|
|
12369
|
-
Math.round(
|
|
12370
|
-
meanBootstrappedDemos + (Math.random() * 2 - 1) * explorationFactor * 2
|
|
12371
|
-
)
|
|
12372
|
-
)
|
|
12373
|
-
);
|
|
12374
|
-
} else {
|
|
12375
|
-
newBootstrappedDemos = Math.floor(
|
|
12376
|
-
Math.random() * (maxBootstrappedDemos + 1)
|
|
12377
|
-
);
|
|
12378
|
-
}
|
|
12379
|
-
if (Math.random() < 0.7) {
|
|
12380
|
-
newLabeledExamples = Math.min(
|
|
12381
|
-
maxLabeledExamples,
|
|
12382
|
-
Math.max(
|
|
12383
|
-
0,
|
|
12384
|
-
Math.round(
|
|
12385
|
-
meanLabeledExamples + (Math.random() * 2 - 1) * explorationFactor * 2
|
|
12386
|
-
)
|
|
12387
|
-
)
|
|
12388
|
-
);
|
|
12389
|
-
} else {
|
|
12390
|
-
newLabeledExamples = Math.floor(Math.random() * (maxLabeledExamples + 1));
|
|
12391
|
-
}
|
|
12392
|
-
if (Math.random() < 0.7 && popularInstructions.length > 0) {
|
|
12393
|
-
const idx = Math.floor(Math.random() * popularInstructions.length);
|
|
12394
|
-
newInstruction = popularInstructions[idx] || "";
|
|
12395
|
-
} else {
|
|
12396
|
-
const idx = Math.floor(Math.random() * instructions.length);
|
|
12397
|
-
newInstruction = instructions[idx] || "";
|
|
12398
|
-
}
|
|
12399
|
-
return {
|
|
12400
|
-
instruction: newInstruction,
|
|
12401
|
-
bootstrappedDemos: newBootstrappedDemos,
|
|
12402
|
-
labeledExamples: newLabeledExamples
|
|
12403
|
-
};
|
|
12404
|
-
}
|
|
12405
|
-
/**
|
|
12406
|
-
* Applies a configuration to a program instance
|
|
12407
|
-
*/
|
|
12408
12857
|
applyConfigToProgram(program, config, bootstrappedDemos, labeledExamples) {
|
|
12409
|
-
|
|
12410
|
-
|
|
12858
|
+
if (program.setInstruction) {
|
|
12859
|
+
program.setInstruction(config.instruction);
|
|
12860
|
+
}
|
|
12861
|
+
if (config.bootstrappedDemos > 0 && program.setDemos) {
|
|
12411
12862
|
program.setDemos(bootstrappedDemos.slice(0, config.bootstrappedDemos));
|
|
12412
12863
|
}
|
|
12413
|
-
if (config.labeledExamples > 0) {
|
|
12864
|
+
if (config.labeledExamples > 0 && program.setExamples) {
|
|
12414
12865
|
program.setExamples(labeledExamples.slice(0, config.labeledExamples));
|
|
12415
12866
|
}
|
|
12416
12867
|
}
|
|
12417
|
-
/**
|
|
12418
|
-
* Sets instruction to a program
|
|
12419
|
-
* Note: Workaround since setInstruction may not be available directly
|
|
12420
|
-
*/
|
|
12421
|
-
setInstructionToProgram(program, instruction) {
|
|
12422
|
-
const programWithInstruction = program;
|
|
12423
|
-
programWithInstruction.setInstruction?.(instruction);
|
|
12424
|
-
}
|
|
12425
12868
|
/**
|
|
12426
12869
|
* The main compile method to run MIPROv2 optimization
|
|
12427
|
-
* @param metricFn Evaluation metric function
|
|
12428
|
-
* @param options Optional configuration options
|
|
12429
|
-
* @returns The optimization result
|
|
12430
12870
|
*/
|
|
12431
|
-
async compile(metricFn, options) {
|
|
12871
|
+
async compile(program, metricFn, options) {
|
|
12872
|
+
const startTime = Date.now();
|
|
12873
|
+
this.setupRandomSeed();
|
|
12432
12874
|
const miproOptions = options;
|
|
12433
12875
|
if (miproOptions?.auto) {
|
|
12434
12876
|
this.configureAuto(miproOptions.auto);
|
|
12435
12877
|
}
|
|
12436
|
-
const
|
|
12437
|
-
|
|
12438
|
-
if (this.verbose) {
|
|
12878
|
+
const valset = this.getValidationSet(options) || (miproOptions?.valset ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
|
|
12879
|
+
if (this.verbose || options?.verbose) {
|
|
12439
12880
|
console.log(`Starting MIPROv2 optimization with ${this.numTrials} trials`);
|
|
12440
12881
|
console.log(
|
|
12441
|
-
`Using ${
|
|
12882
|
+
`Using ${this.examples.length} examples for training and ${valset.length} for validation`
|
|
12442
12883
|
);
|
|
12443
|
-
|
|
12444
|
-
|
|
12445
|
-
if (this.verbose) {
|
|
12446
|
-
console.log("Using provided teacher to assist with bootstrapping");
|
|
12884
|
+
if (this.teacherAI) {
|
|
12885
|
+
console.log("Using separate teacher model for instruction generation");
|
|
12447
12886
|
}
|
|
12448
|
-
const bootstrapperWithTeacher = new AxBootstrapFewShot({
|
|
12449
|
-
ai: this.ai,
|
|
12450
|
-
program: this.program,
|
|
12451
|
-
examples: this.examples,
|
|
12452
|
-
options: {
|
|
12453
|
-
maxDemos: this.maxBootstrappedDemos,
|
|
12454
|
-
maxRounds: 3,
|
|
12455
|
-
verboseMode: this.verbose,
|
|
12456
|
-
teacherAI: this.ai
|
|
12457
|
-
// Use the same AI but with the teacher program
|
|
12458
|
-
}
|
|
12459
|
-
});
|
|
12460
|
-
this.bootstrapper = bootstrapperWithTeacher;
|
|
12461
12887
|
}
|
|
12462
12888
|
let bootstrappedDemos = [];
|
|
12463
12889
|
if (this.maxBootstrappedDemos > 0) {
|
|
12464
|
-
bootstrappedDemos = await this.bootstrapFewShotExamples(metricFn);
|
|
12890
|
+
bootstrappedDemos = await this.bootstrapFewShotExamples(program, metricFn);
|
|
12465
12891
|
if (this.verbose) {
|
|
12466
12892
|
console.log(
|
|
12467
12893
|
`Generated ${bootstrappedDemos.length} bootstrapped demonstrations`
|
|
@@ -12477,38 +12903,191 @@ ${dataContext}
|
|
|
12477
12903
|
);
|
|
12478
12904
|
}
|
|
12479
12905
|
}
|
|
12480
|
-
const instructions = await this.proposeInstructionCandidates();
|
|
12906
|
+
const instructions = await this.proposeInstructionCandidates(options);
|
|
12481
12907
|
if (this.verbose) {
|
|
12482
12908
|
console.log(`Generated ${instructions.length} instruction candidates`);
|
|
12909
|
+
if (this.hasTeacherAI(options)) {
|
|
12910
|
+
console.log("Using teacher AI for instruction generation");
|
|
12911
|
+
}
|
|
12483
12912
|
}
|
|
12484
|
-
const { bestConfig, bestScore } = await this.
|
|
12913
|
+
const { bestConfig, bestScore } = await this.runOptimization(
|
|
12914
|
+
program,
|
|
12485
12915
|
bootstrappedDemos,
|
|
12486
12916
|
labeledExamples,
|
|
12487
12917
|
instructions,
|
|
12488
12918
|
valset,
|
|
12489
|
-
metricFn
|
|
12919
|
+
metricFn,
|
|
12920
|
+
options
|
|
12490
12921
|
);
|
|
12491
|
-
if (this.verbose) {
|
|
12922
|
+
if (this.verbose || options?.verbose) {
|
|
12492
12923
|
console.log(`Optimization complete. Best score: ${bestScore}`);
|
|
12493
12924
|
console.log(`Best configuration: ${JSON.stringify(bestConfig)}`);
|
|
12494
12925
|
}
|
|
12495
|
-
this.
|
|
12496
|
-
this.
|
|
12926
|
+
if (this.checkTargetScore(bestScore)) {
|
|
12927
|
+
this.triggerEarlyStopping(
|
|
12928
|
+
`Target score ${this.targetScore} reached with score ${bestScore}`,
|
|
12929
|
+
this.numTrials
|
|
12930
|
+
);
|
|
12931
|
+
}
|
|
12932
|
+
let signature;
|
|
12933
|
+
if ("getSignature" in program && typeof program.getSignature === "function") {
|
|
12934
|
+
signature = program.getSignature();
|
|
12935
|
+
} else {
|
|
12936
|
+
signature = "input -> output";
|
|
12937
|
+
}
|
|
12938
|
+
const optimizedGen = new AxGen(signature);
|
|
12939
|
+
this.applyConfigToAxGen(
|
|
12940
|
+
optimizedGen,
|
|
12497
12941
|
bestConfig,
|
|
12498
12942
|
bootstrappedDemos,
|
|
12499
12943
|
labeledExamples
|
|
12500
12944
|
);
|
|
12945
|
+
this.updateResourceUsage(startTime);
|
|
12946
|
+
this.stats.convergenceInfo.converged = true;
|
|
12947
|
+
this.stats.convergenceInfo.finalImprovement = bestScore;
|
|
12948
|
+
await this.saveFinalCheckpoint(
|
|
12949
|
+
"MiPRO",
|
|
12950
|
+
this.getConfiguration(),
|
|
12951
|
+
bestScore,
|
|
12952
|
+
bestConfig,
|
|
12953
|
+
{
|
|
12954
|
+
bootstrappedDemos: bootstrappedDemos.length,
|
|
12955
|
+
labeledExamples: labeledExamples.length,
|
|
12956
|
+
instructions: instructions.length,
|
|
12957
|
+
optimizedGen: !!optimizedGen
|
|
12958
|
+
},
|
|
12959
|
+
options
|
|
12960
|
+
);
|
|
12501
12961
|
return {
|
|
12502
|
-
|
|
12503
|
-
|
|
12962
|
+
demos: bootstrappedDemos,
|
|
12963
|
+
stats: this.stats,
|
|
12964
|
+
bestScore,
|
|
12965
|
+
optimizedGen,
|
|
12966
|
+
finalConfiguration: {
|
|
12967
|
+
instruction: bestConfig.instruction,
|
|
12968
|
+
bootstrappedDemos: bestConfig.bootstrappedDemos,
|
|
12969
|
+
labeledExamples: bestConfig.labeledExamples,
|
|
12970
|
+
numCandidates: this.numCandidates,
|
|
12971
|
+
numTrials: this.numTrials
|
|
12972
|
+
}
|
|
12504
12973
|
};
|
|
12505
12974
|
}
|
|
12506
12975
|
/**
|
|
12507
|
-
*
|
|
12508
|
-
* @returns Optimization statistics or undefined if not available
|
|
12976
|
+
* Applies a configuration to an AxGen instance
|
|
12509
12977
|
*/
|
|
12510
|
-
|
|
12511
|
-
|
|
12978
|
+
applyConfigToAxGen(axgen, config, bootstrappedDemos, labeledExamples) {
|
|
12979
|
+
if ("setInstruction" in axgen && typeof axgen.setInstruction === "function") {
|
|
12980
|
+
axgen.setInstruction(config.instruction);
|
|
12981
|
+
}
|
|
12982
|
+
if (config.bootstrappedDemos > 0) {
|
|
12983
|
+
axgen.setDemos(bootstrappedDemos.slice(0, config.bootstrappedDemos));
|
|
12984
|
+
}
|
|
12985
|
+
if (config.labeledExamples > 0) {
|
|
12986
|
+
axgen.setExamples(
|
|
12987
|
+
labeledExamples.slice(
|
|
12988
|
+
0,
|
|
12989
|
+
config.labeledExamples
|
|
12990
|
+
)
|
|
12991
|
+
);
|
|
12992
|
+
}
|
|
12993
|
+
}
|
|
12994
|
+
/**
|
|
12995
|
+
* Get optimizer-specific configuration
|
|
12996
|
+
* @returns Current optimizer configuration
|
|
12997
|
+
*/
|
|
12998
|
+
getConfiguration() {
|
|
12999
|
+
return {
|
|
13000
|
+
numCandidates: this.numCandidates,
|
|
13001
|
+
initTemperature: this.initTemperature,
|
|
13002
|
+
maxBootstrappedDemos: this.maxBootstrappedDemos,
|
|
13003
|
+
maxLabeledDemos: this.maxLabeledDemos,
|
|
13004
|
+
numTrials: this.numTrials,
|
|
13005
|
+
minibatch: this.minibatch,
|
|
13006
|
+
minibatchSize: this.minibatchSize,
|
|
13007
|
+
minibatchFullEvalSteps: this.minibatchFullEvalSteps,
|
|
13008
|
+
programAwareProposer: this.programAwareProposer,
|
|
13009
|
+
dataAwareProposer: this.dataAwareProposer,
|
|
13010
|
+
tipAwareProposer: this.tipAwareProposer,
|
|
13011
|
+
fewshotAwareProposer: this.fewshotAwareProposer,
|
|
13012
|
+
earlyStoppingTrials: this.earlyStoppingTrials,
|
|
13013
|
+
minImprovementThreshold: this.minImprovementThreshold,
|
|
13014
|
+
bayesianOptimization: this.bayesianOptimization,
|
|
13015
|
+
acquisitionFunction: this.acquisitionFunction,
|
|
13016
|
+
explorationWeight: this.explorationWeight
|
|
13017
|
+
};
|
|
13018
|
+
}
|
|
13019
|
+
/**
|
|
13020
|
+
* Update optimizer configuration
|
|
13021
|
+
* @param config New configuration to merge with existing
|
|
13022
|
+
*/
|
|
13023
|
+
updateConfiguration(config) {
|
|
13024
|
+
if (config.numCandidates !== void 0) {
|
|
13025
|
+
this.numCandidates = config.numCandidates;
|
|
13026
|
+
}
|
|
13027
|
+
if (config.initTemperature !== void 0) {
|
|
13028
|
+
this.initTemperature = config.initTemperature;
|
|
13029
|
+
}
|
|
13030
|
+
if (config.maxBootstrappedDemos !== void 0) {
|
|
13031
|
+
this.maxBootstrappedDemos = config.maxBootstrappedDemos;
|
|
13032
|
+
}
|
|
13033
|
+
if (config.maxLabeledDemos !== void 0) {
|
|
13034
|
+
this.maxLabeledDemos = config.maxLabeledDemos;
|
|
13035
|
+
}
|
|
13036
|
+
if (config.numTrials !== void 0) {
|
|
13037
|
+
this.numTrials = config.numTrials;
|
|
13038
|
+
}
|
|
13039
|
+
if (config.minibatch !== void 0) {
|
|
13040
|
+
this.minibatch = config.minibatch;
|
|
13041
|
+
}
|
|
13042
|
+
if (config.minibatchSize !== void 0) {
|
|
13043
|
+
this.minibatchSize = config.minibatchSize;
|
|
13044
|
+
}
|
|
13045
|
+
if (config.earlyStoppingTrials !== void 0) {
|
|
13046
|
+
this.earlyStoppingTrials = config.earlyStoppingTrials;
|
|
13047
|
+
}
|
|
13048
|
+
if (config.minImprovementThreshold !== void 0) {
|
|
13049
|
+
this.minImprovementThreshold = config.minImprovementThreshold;
|
|
13050
|
+
}
|
|
13051
|
+
if (config.verbose !== void 0) {
|
|
13052
|
+
this.verbose = config.verbose;
|
|
13053
|
+
}
|
|
13054
|
+
}
|
|
13055
|
+
/**
|
|
13056
|
+
* Reset optimizer state for reuse with different programs
|
|
13057
|
+
*/
|
|
13058
|
+
reset() {
|
|
13059
|
+
super.reset();
|
|
13060
|
+
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
13061
|
+
}
|
|
13062
|
+
/**
|
|
13063
|
+
* Validate that the optimizer can handle the given program
|
|
13064
|
+
* @param program Program to validate
|
|
13065
|
+
* @returns Validation result with any issues found
|
|
13066
|
+
*/
|
|
13067
|
+
validateProgram(program) {
|
|
13068
|
+
const result = super.validateProgram(program);
|
|
13069
|
+
if (this.examples.length < this.maxBootstrappedDemos + this.maxLabeledDemos) {
|
|
13070
|
+
result.issues.push(
|
|
13071
|
+
`Not enough examples: need at least ${this.maxBootstrappedDemos + this.maxLabeledDemos}, got ${this.examples.length}`
|
|
13072
|
+
);
|
|
13073
|
+
result.suggestions.push(
|
|
13074
|
+
"Reduce maxBootstrappedDemos or maxLabeledDemos, or provide more examples"
|
|
13075
|
+
);
|
|
13076
|
+
}
|
|
13077
|
+
const valSetSize = this.getValidationSet().length;
|
|
13078
|
+
if (valSetSize < 5) {
|
|
13079
|
+
result.issues.push(
|
|
13080
|
+
"Validation set too small for reliable MiPRO optimization"
|
|
13081
|
+
);
|
|
13082
|
+
result.suggestions.push(
|
|
13083
|
+
"Provide more examples or a larger validation set"
|
|
13084
|
+
);
|
|
13085
|
+
}
|
|
13086
|
+
return {
|
|
13087
|
+
isValid: result.issues.length === 0,
|
|
13088
|
+
issues: result.issues,
|
|
13089
|
+
suggestions: result.suggestions
|
|
13090
|
+
};
|
|
12512
13091
|
}
|
|
12513
13092
|
};
|
|
12514
13093
|
|
|
@@ -12755,7 +13334,7 @@ var AxTestPrompt = class {
|
|
|
12755
13334
|
throw new Error("Invalid example");
|
|
12756
13335
|
}
|
|
12757
13336
|
const res = await this.program.forward(this.ai, ex);
|
|
12758
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
13337
|
+
const score = await metricFn({ prediction: res, example: ex });
|
|
12759
13338
|
sumOfScores += score;
|
|
12760
13339
|
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
12761
13340
|
updateProgressBar(i, total, sumOfScores, et, "Testing Prompt", 30);
|
|
@@ -14789,7 +15368,6 @@ var AxRAG = class extends AxChainOfThought {
|
|
|
14789
15368
|
);
|
|
14790
15369
|
this.genQuery = new AxGen(qsig);
|
|
14791
15370
|
this.queryFn = queryFn;
|
|
14792
|
-
this.register(this.genQuery);
|
|
14793
15371
|
}
|
|
14794
15372
|
async forward(ai, values, options) {
|
|
14795
15373
|
let question;
|
|
@@ -14867,6 +15445,7 @@ var AxRAG = class extends AxChainOfThought {
|
|
|
14867
15445
|
AxAssertionError,
|
|
14868
15446
|
AxBalancer,
|
|
14869
15447
|
AxBaseAI,
|
|
15448
|
+
AxBaseOptimizer,
|
|
14870
15449
|
AxBootstrapFewShot,
|
|
14871
15450
|
AxChainOfThought,
|
|
14872
15451
|
AxDB,
|
|
@@ -14876,6 +15455,7 @@ var AxRAG = class extends AxChainOfThought {
|
|
|
14876
15455
|
AxDBMemory,
|
|
14877
15456
|
AxDBPinecone,
|
|
14878
15457
|
AxDBWeaviate,
|
|
15458
|
+
AxDefaultCostTracker,
|
|
14879
15459
|
AxDefaultQueryRewriter,
|
|
14880
15460
|
AxDefaultResultReranker,
|
|
14881
15461
|
AxDockerSession,
|