@ax-llm/ax 12.0.3 → 12.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +270 -225
- package/index.cjs.map +1 -1
- package/index.d.cts +146 -66
- package/index.d.ts +146 -66
- package/index.js +270 -225
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.cjs
CHANGED
|
@@ -4164,7 +4164,8 @@ var axModelInfoMistral = [
|
|
|
4164
4164
|
// ai/mistral/api.ts
|
|
4165
4165
|
var axAIMistralDefaultConfig = () => structuredClone({
|
|
4166
4166
|
model: "mistral-small-latest" /* MistralSmall */,
|
|
4167
|
-
...axBaseAIDefaultConfig()
|
|
4167
|
+
...axBaseAIDefaultConfig(),
|
|
4168
|
+
topP: 1
|
|
4168
4169
|
});
|
|
4169
4170
|
var axAIMistralBestConfig = () => structuredClone({
|
|
4170
4171
|
...axAIMistralDefaultConfig(),
|
|
@@ -4192,6 +4193,15 @@ var AxAIMistral = class extends AxAIOpenAIBase {
|
|
|
4192
4193
|
hasThinkingBudget: false,
|
|
4193
4194
|
hasShowThoughts: false
|
|
4194
4195
|
};
|
|
4196
|
+
const chatReqUpdater = (req) => {
|
|
4197
|
+
const { max_completion_tokens, stream_options, messages, ...result } = req;
|
|
4198
|
+
return {
|
|
4199
|
+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
4200
|
+
...result,
|
|
4201
|
+
messages: this.updateMessages(messages),
|
|
4202
|
+
max_tokens: max_completion_tokens
|
|
4203
|
+
};
|
|
4204
|
+
};
|
|
4195
4205
|
super({
|
|
4196
4206
|
apiKey,
|
|
4197
4207
|
config: _config,
|
|
@@ -4199,10 +4209,32 @@ var AxAIMistral = class extends AxAIOpenAIBase {
|
|
|
4199
4209
|
apiURL: "https://api.mistral.ai/v1",
|
|
4200
4210
|
modelInfo,
|
|
4201
4211
|
models,
|
|
4202
|
-
supportFor
|
|
4212
|
+
supportFor,
|
|
4213
|
+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
4214
|
+
chatReqUpdater
|
|
4203
4215
|
});
|
|
4204
4216
|
super.setName("Mistral");
|
|
4205
4217
|
}
|
|
4218
|
+
updateMessages(messages) {
|
|
4219
|
+
const messagesUpdated = [];
|
|
4220
|
+
if (!Array.isArray(messages)) {
|
|
4221
|
+
return messages;
|
|
4222
|
+
}
|
|
4223
|
+
for (const message of messages) {
|
|
4224
|
+
if (message.role === "user" && Array.isArray(message.content)) {
|
|
4225
|
+
const contentUpdated = message.content.map((item) => {
|
|
4226
|
+
if (typeof item === "object" && item !== null && item.type === "image_url") {
|
|
4227
|
+
return { type: "image_url", image_url: item.image_url?.url };
|
|
4228
|
+
}
|
|
4229
|
+
return item;
|
|
4230
|
+
});
|
|
4231
|
+
messagesUpdated.push({ ...message, content: contentUpdated });
|
|
4232
|
+
} else {
|
|
4233
|
+
messagesUpdated.push(message);
|
|
4234
|
+
}
|
|
4235
|
+
}
|
|
4236
|
+
return messagesUpdated;
|
|
4237
|
+
}
|
|
4206
4238
|
};
|
|
4207
4239
|
|
|
4208
4240
|
// ai/ollama/api.ts
|
|
@@ -8332,7 +8364,7 @@ var AxSignature = class _AxSignature {
|
|
|
8332
8364
|
});
|
|
8333
8365
|
this.inputFields = parsedFields;
|
|
8334
8366
|
this.invalidateValidationCache();
|
|
8335
|
-
this.
|
|
8367
|
+
this.updateHashLight();
|
|
8336
8368
|
} catch (error) {
|
|
8337
8369
|
if (error instanceof AxSignatureValidationError) {
|
|
8338
8370
|
throw error;
|
|
@@ -8358,7 +8390,7 @@ var AxSignature = class _AxSignature {
|
|
|
8358
8390
|
});
|
|
8359
8391
|
this.outputFields = parsedFields;
|
|
8360
8392
|
this.invalidateValidationCache();
|
|
8361
|
-
this.
|
|
8393
|
+
this.updateHashLight();
|
|
8362
8394
|
} catch (error) {
|
|
8363
8395
|
if (error instanceof AxSignatureValidationError) {
|
|
8364
8396
|
throw error;
|
|
@@ -10126,219 +10158,6 @@ function validateModels2(services) {
|
|
|
10126
10158
|
}
|
|
10127
10159
|
}
|
|
10128
10160
|
|
|
10129
|
-
// dsp/optimize.ts
|
|
10130
|
-
var AxBootstrapFewShot = class {
|
|
10131
|
-
ai;
|
|
10132
|
-
teacherAI;
|
|
10133
|
-
program;
|
|
10134
|
-
examples;
|
|
10135
|
-
maxRounds;
|
|
10136
|
-
maxDemos;
|
|
10137
|
-
maxExamples;
|
|
10138
|
-
batchSize;
|
|
10139
|
-
earlyStoppingPatience;
|
|
10140
|
-
costMonitoring;
|
|
10141
|
-
maxTokensPerGeneration;
|
|
10142
|
-
verboseMode;
|
|
10143
|
-
debugMode;
|
|
10144
|
-
traces = [];
|
|
10145
|
-
stats = {
|
|
10146
|
-
totalCalls: 0,
|
|
10147
|
-
successfulDemos: 0,
|
|
10148
|
-
estimatedTokenUsage: 0,
|
|
10149
|
-
earlyStopped: false
|
|
10150
|
-
};
|
|
10151
|
-
constructor({
|
|
10152
|
-
ai,
|
|
10153
|
-
program,
|
|
10154
|
-
examples = [],
|
|
10155
|
-
options
|
|
10156
|
-
}) {
|
|
10157
|
-
if (examples.length === 0) {
|
|
10158
|
-
throw new Error("No examples found");
|
|
10159
|
-
}
|
|
10160
|
-
this.maxRounds = options?.maxRounds ?? 3;
|
|
10161
|
-
this.maxDemos = options?.maxDemos ?? 4;
|
|
10162
|
-
this.maxExamples = options?.maxExamples ?? 16;
|
|
10163
|
-
this.batchSize = options?.batchSize ?? 1;
|
|
10164
|
-
this.earlyStoppingPatience = options?.earlyStoppingPatience ?? 0;
|
|
10165
|
-
this.costMonitoring = options?.costMonitoring ?? false;
|
|
10166
|
-
this.maxTokensPerGeneration = options?.maxTokensPerGeneration ?? 0;
|
|
10167
|
-
this.verboseMode = options?.verboseMode ?? true;
|
|
10168
|
-
this.debugMode = options?.debugMode ?? false;
|
|
10169
|
-
this.ai = ai;
|
|
10170
|
-
this.teacherAI = options?.teacherAI;
|
|
10171
|
-
this.program = program;
|
|
10172
|
-
this.examples = examples;
|
|
10173
|
-
}
|
|
10174
|
-
async compileRound(roundIndex, metricFn, options) {
|
|
10175
|
-
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
10176
|
-
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
10177
|
-
const aiOpt = {
|
|
10178
|
-
modelConfig: {
|
|
10179
|
-
temperature: 0.7
|
|
10180
|
-
}
|
|
10181
|
-
};
|
|
10182
|
-
if (this.maxTokensPerGeneration > 0) {
|
|
10183
|
-
aiOpt.modelConfig.max_tokens = this.maxTokensPerGeneration;
|
|
10184
|
-
}
|
|
10185
|
-
const examples = randomSample(this.examples, this.maxExamples);
|
|
10186
|
-
const previousSuccessCount = this.traces.length;
|
|
10187
|
-
for (let i = 0; i < examples.length; i += this.batchSize) {
|
|
10188
|
-
if (i > 0) {
|
|
10189
|
-
aiOpt.modelConfig.temperature = 0.7 + 1e-3 * i;
|
|
10190
|
-
}
|
|
10191
|
-
const batch = examples.slice(i, i + this.batchSize);
|
|
10192
|
-
for (const ex of batch) {
|
|
10193
|
-
if (!ex) {
|
|
10194
|
-
continue;
|
|
10195
|
-
}
|
|
10196
|
-
const exList = examples.filter((e) => e !== ex);
|
|
10197
|
-
this.program.setExamples(exList);
|
|
10198
|
-
const aiService = this.teacherAI || this.ai;
|
|
10199
|
-
this.stats.totalCalls++;
|
|
10200
|
-
let res;
|
|
10201
|
-
let error;
|
|
10202
|
-
try {
|
|
10203
|
-
res = await this.program.forward(aiService, ex, aiOpt);
|
|
10204
|
-
if (this.costMonitoring) {
|
|
10205
|
-
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
10206
|
-
}
|
|
10207
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
10208
|
-
const success = score >= 0.5;
|
|
10209
|
-
if (success) {
|
|
10210
|
-
this.traces = [...this.traces, ...this.program.getTraces()];
|
|
10211
|
-
this.stats.successfulDemos++;
|
|
10212
|
-
}
|
|
10213
|
-
} catch (err) {
|
|
10214
|
-
error = err;
|
|
10215
|
-
res = {};
|
|
10216
|
-
}
|
|
10217
|
-
const current = i + examples.length * roundIndex + (batch.indexOf(ex) + 1);
|
|
10218
|
-
const total = examples.length * this.maxRounds;
|
|
10219
|
-
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
10220
|
-
if (this.verboseMode || this.debugMode) {
|
|
10221
|
-
const configInfo = {
|
|
10222
|
-
maxRounds: this.maxRounds,
|
|
10223
|
-
batchSize: this.batchSize,
|
|
10224
|
-
earlyStoppingPatience: this.earlyStoppingPatience,
|
|
10225
|
-
costMonitoring: this.costMonitoring,
|
|
10226
|
-
verboseMode: this.verboseMode,
|
|
10227
|
-
debugMode: this.debugMode
|
|
10228
|
-
};
|
|
10229
|
-
updateDetailedProgress(
|
|
10230
|
-
roundIndex,
|
|
10231
|
-
current,
|
|
10232
|
-
total,
|
|
10233
|
-
et,
|
|
10234
|
-
ex,
|
|
10235
|
-
this.stats,
|
|
10236
|
-
configInfo,
|
|
10237
|
-
res,
|
|
10238
|
-
error
|
|
10239
|
-
);
|
|
10240
|
-
} else {
|
|
10241
|
-
updateProgressBar(
|
|
10242
|
-
current,
|
|
10243
|
-
total,
|
|
10244
|
-
this.traces.length,
|
|
10245
|
-
et,
|
|
10246
|
-
"Tuning Prompt",
|
|
10247
|
-
30
|
|
10248
|
-
);
|
|
10249
|
-
}
|
|
10250
|
-
if (this.traces.length >= maxDemos) {
|
|
10251
|
-
return;
|
|
10252
|
-
}
|
|
10253
|
-
}
|
|
10254
|
-
}
|
|
10255
|
-
if (this.earlyStoppingPatience > 0) {
|
|
10256
|
-
const newSuccessCount = this.traces.length;
|
|
10257
|
-
const improvement = newSuccessCount - previousSuccessCount;
|
|
10258
|
-
if (!this.stats.earlyStopping) {
|
|
10259
|
-
this.stats.earlyStopping = {
|
|
10260
|
-
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
10261
|
-
patienceExhausted: false
|
|
10262
|
-
};
|
|
10263
|
-
} else if (improvement > 0) {
|
|
10264
|
-
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
10265
|
-
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
10266
|
-
this.stats.earlyStopping.patienceExhausted = true;
|
|
10267
|
-
this.stats.earlyStopped = true;
|
|
10268
|
-
if (this.verboseMode || this.debugMode) {
|
|
10269
|
-
console.log(
|
|
10270
|
-
`
|
|
10271
|
-
Early stopping triggered after ${roundIndex + 1} rounds. No improvement for ${this.earlyStoppingPatience} rounds.`
|
|
10272
|
-
);
|
|
10273
|
-
}
|
|
10274
|
-
return;
|
|
10275
|
-
}
|
|
10276
|
-
}
|
|
10277
|
-
}
|
|
10278
|
-
async compile(metricFn, options) {
|
|
10279
|
-
const maxRounds = options?.maxRounds ?? this.maxRounds;
|
|
10280
|
-
this.traces = [];
|
|
10281
|
-
this.stats = {
|
|
10282
|
-
totalCalls: 0,
|
|
10283
|
-
successfulDemos: 0,
|
|
10284
|
-
estimatedTokenUsage: 0,
|
|
10285
|
-
earlyStopped: false
|
|
10286
|
-
};
|
|
10287
|
-
for (let i = 0; i < maxRounds; i++) {
|
|
10288
|
-
await this.compileRound(i, metricFn, options);
|
|
10289
|
-
if (this.stats.earlyStopped) {
|
|
10290
|
-
break;
|
|
10291
|
-
}
|
|
10292
|
-
}
|
|
10293
|
-
if (this.traces.length === 0) {
|
|
10294
|
-
throw new Error(
|
|
10295
|
-
"No demonstrations found. Either provider more examples or improve the existing ones."
|
|
10296
|
-
);
|
|
10297
|
-
}
|
|
10298
|
-
const demos = groupTracesByKeys(this.traces);
|
|
10299
|
-
return {
|
|
10300
|
-
demos,
|
|
10301
|
-
stats: this.stats
|
|
10302
|
-
};
|
|
10303
|
-
}
|
|
10304
|
-
// Get optimization statistics
|
|
10305
|
-
getStats() {
|
|
10306
|
-
return this.stats;
|
|
10307
|
-
}
|
|
10308
|
-
};
|
|
10309
|
-
function groupTracesByKeys(programTraces) {
|
|
10310
|
-
const groupedTraces = /* @__PURE__ */ new Map();
|
|
10311
|
-
for (const programTrace of programTraces) {
|
|
10312
|
-
if (groupedTraces.has(programTrace.programId)) {
|
|
10313
|
-
const traces = groupedTraces.get(programTrace.programId);
|
|
10314
|
-
if (traces) {
|
|
10315
|
-
traces.push(programTrace.trace);
|
|
10316
|
-
}
|
|
10317
|
-
} else {
|
|
10318
|
-
groupedTraces.set(programTrace.programId, [programTrace.trace]);
|
|
10319
|
-
}
|
|
10320
|
-
}
|
|
10321
|
-
const programDemosArray = [];
|
|
10322
|
-
for (const [programId, traces] of groupedTraces.entries()) {
|
|
10323
|
-
programDemosArray.push({ traces, programId });
|
|
10324
|
-
}
|
|
10325
|
-
return programDemosArray;
|
|
10326
|
-
}
|
|
10327
|
-
var randomSample = (array, n) => {
|
|
10328
|
-
const clonedArray = [...array];
|
|
10329
|
-
for (let i = clonedArray.length - 1; i > 0; i--) {
|
|
10330
|
-
const j = Math.floor(Math.random() * (i + 1));
|
|
10331
|
-
const caI = clonedArray[i];
|
|
10332
|
-
const caJ = clonedArray[j];
|
|
10333
|
-
if (!caI || !caJ) {
|
|
10334
|
-
throw new Error("Invalid array elements");
|
|
10335
|
-
}
|
|
10336
|
-
;
|
|
10337
|
-
[clonedArray[i], clonedArray[j]] = [caJ, caI];
|
|
10338
|
-
}
|
|
10339
|
-
return clonedArray.slice(0, n);
|
|
10340
|
-
};
|
|
10341
|
-
|
|
10342
10161
|
// db/base.ts
|
|
10343
10162
|
var import_api23 = require("@opentelemetry/api");
|
|
10344
10163
|
var AxDBBase = class {
|
|
@@ -11797,7 +11616,222 @@ var AxMCPStreambleHTTPTransport = class {
|
|
|
11797
11616
|
}
|
|
11798
11617
|
};
|
|
11799
11618
|
|
|
11800
|
-
// dsp/
|
|
11619
|
+
// dsp/optimizers/bootstrapFewshot.ts
|
|
11620
|
+
var AxBootstrapFewShot = class {
|
|
11621
|
+
ai;
|
|
11622
|
+
teacherAI;
|
|
11623
|
+
program;
|
|
11624
|
+
examples;
|
|
11625
|
+
maxRounds;
|
|
11626
|
+
maxDemos;
|
|
11627
|
+
maxExamples;
|
|
11628
|
+
batchSize;
|
|
11629
|
+
earlyStoppingPatience;
|
|
11630
|
+
costMonitoring;
|
|
11631
|
+
maxTokensPerGeneration;
|
|
11632
|
+
verboseMode;
|
|
11633
|
+
debugMode;
|
|
11634
|
+
traces = [];
|
|
11635
|
+
stats = {
|
|
11636
|
+
totalCalls: 0,
|
|
11637
|
+
successfulDemos: 0,
|
|
11638
|
+
estimatedTokenUsage: 0,
|
|
11639
|
+
earlyStopped: false
|
|
11640
|
+
};
|
|
11641
|
+
constructor({
|
|
11642
|
+
ai,
|
|
11643
|
+
program,
|
|
11644
|
+
examples = [],
|
|
11645
|
+
options
|
|
11646
|
+
}) {
|
|
11647
|
+
if (examples.length === 0) {
|
|
11648
|
+
throw new Error("No examples found");
|
|
11649
|
+
}
|
|
11650
|
+
const bootstrapOptions = options;
|
|
11651
|
+
this.maxRounds = bootstrapOptions?.maxRounds ?? 3;
|
|
11652
|
+
this.maxDemos = bootstrapOptions?.maxDemos ?? 4;
|
|
11653
|
+
this.maxExamples = bootstrapOptions?.maxExamples ?? 16;
|
|
11654
|
+
this.batchSize = bootstrapOptions?.batchSize ?? 1;
|
|
11655
|
+
this.earlyStoppingPatience = bootstrapOptions?.earlyStoppingPatience ?? 0;
|
|
11656
|
+
this.costMonitoring = bootstrapOptions?.costMonitoring ?? false;
|
|
11657
|
+
this.maxTokensPerGeneration = bootstrapOptions?.maxTokensPerGeneration ?? 0;
|
|
11658
|
+
this.verboseMode = bootstrapOptions?.verboseMode ?? true;
|
|
11659
|
+
this.debugMode = bootstrapOptions?.debugMode ?? false;
|
|
11660
|
+
this.ai = ai;
|
|
11661
|
+
this.teacherAI = bootstrapOptions?.teacherAI;
|
|
11662
|
+
this.program = program;
|
|
11663
|
+
this.examples = examples;
|
|
11664
|
+
}
|
|
11665
|
+
async compileRound(roundIndex, metricFn, options) {
|
|
11666
|
+
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
11667
|
+
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
11668
|
+
const aiOpt = {
|
|
11669
|
+
modelConfig: {
|
|
11670
|
+
temperature: 0.7
|
|
11671
|
+
}
|
|
11672
|
+
};
|
|
11673
|
+
if (this.maxTokensPerGeneration > 0) {
|
|
11674
|
+
aiOpt.modelConfig.max_tokens = this.maxTokensPerGeneration;
|
|
11675
|
+
}
|
|
11676
|
+
const examples = randomSample(this.examples, this.maxExamples);
|
|
11677
|
+
const previousSuccessCount = this.traces.length;
|
|
11678
|
+
for (let i = 0; i < examples.length; i += this.batchSize) {
|
|
11679
|
+
if (i > 0) {
|
|
11680
|
+
aiOpt.modelConfig.temperature = 0.7 + 1e-3 * i;
|
|
11681
|
+
}
|
|
11682
|
+
const batch = examples.slice(i, i + this.batchSize);
|
|
11683
|
+
for (const ex of batch) {
|
|
11684
|
+
if (!ex) {
|
|
11685
|
+
continue;
|
|
11686
|
+
}
|
|
11687
|
+
const exList = examples.filter((e) => e !== ex);
|
|
11688
|
+
this.program.setExamples(exList);
|
|
11689
|
+
const aiService = this.teacherAI || this.ai;
|
|
11690
|
+
this.stats.totalCalls++;
|
|
11691
|
+
let res;
|
|
11692
|
+
let error;
|
|
11693
|
+
try {
|
|
11694
|
+
res = await this.program.forward(aiService, ex, aiOpt);
|
|
11695
|
+
if (this.costMonitoring) {
|
|
11696
|
+
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
11697
|
+
}
|
|
11698
|
+
const score = metricFn({ prediction: res, example: ex });
|
|
11699
|
+
const success = score >= 0.5;
|
|
11700
|
+
if (success) {
|
|
11701
|
+
this.traces = [...this.traces, ...this.program.getTraces()];
|
|
11702
|
+
this.stats.successfulDemos++;
|
|
11703
|
+
}
|
|
11704
|
+
} catch (err) {
|
|
11705
|
+
error = err;
|
|
11706
|
+
res = {};
|
|
11707
|
+
}
|
|
11708
|
+
const current = i + examples.length * roundIndex + (batch.indexOf(ex) + 1);
|
|
11709
|
+
const total = examples.length * this.maxRounds;
|
|
11710
|
+
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
11711
|
+
if (this.verboseMode || this.debugMode) {
|
|
11712
|
+
const configInfo = {
|
|
11713
|
+
maxRounds: this.maxRounds,
|
|
11714
|
+
batchSize: this.batchSize,
|
|
11715
|
+
earlyStoppingPatience: this.earlyStoppingPatience,
|
|
11716
|
+
costMonitoring: this.costMonitoring,
|
|
11717
|
+
verboseMode: this.verboseMode,
|
|
11718
|
+
debugMode: this.debugMode
|
|
11719
|
+
};
|
|
11720
|
+
updateDetailedProgress(
|
|
11721
|
+
roundIndex,
|
|
11722
|
+
current,
|
|
11723
|
+
total,
|
|
11724
|
+
et,
|
|
11725
|
+
ex,
|
|
11726
|
+
this.stats,
|
|
11727
|
+
configInfo,
|
|
11728
|
+
res,
|
|
11729
|
+
error
|
|
11730
|
+
);
|
|
11731
|
+
} else {
|
|
11732
|
+
updateProgressBar(
|
|
11733
|
+
current,
|
|
11734
|
+
total,
|
|
11735
|
+
this.traces.length,
|
|
11736
|
+
et,
|
|
11737
|
+
"Tuning Prompt",
|
|
11738
|
+
30
|
|
11739
|
+
);
|
|
11740
|
+
}
|
|
11741
|
+
if (this.traces.length >= maxDemos) {
|
|
11742
|
+
return;
|
|
11743
|
+
}
|
|
11744
|
+
}
|
|
11745
|
+
}
|
|
11746
|
+
if (this.earlyStoppingPatience > 0) {
|
|
11747
|
+
const newSuccessCount = this.traces.length;
|
|
11748
|
+
const improvement = newSuccessCount - previousSuccessCount;
|
|
11749
|
+
if (!this.stats.earlyStopping) {
|
|
11750
|
+
this.stats.earlyStopping = {
|
|
11751
|
+
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
11752
|
+
patienceExhausted: false
|
|
11753
|
+
};
|
|
11754
|
+
} else if (improvement > 0) {
|
|
11755
|
+
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
11756
|
+
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
11757
|
+
this.stats.earlyStopping.patienceExhausted = true;
|
|
11758
|
+
this.stats.earlyStopped = true;
|
|
11759
|
+
if (this.verboseMode || this.debugMode) {
|
|
11760
|
+
console.log(
|
|
11761
|
+
`
|
|
11762
|
+
Early stopping triggered after ${roundIndex + 1} rounds. No improvement for ${this.earlyStoppingPatience} rounds.`
|
|
11763
|
+
);
|
|
11764
|
+
}
|
|
11765
|
+
return;
|
|
11766
|
+
}
|
|
11767
|
+
}
|
|
11768
|
+
}
|
|
11769
|
+
async compile(metricFn, options) {
|
|
11770
|
+
const compileOptions = options;
|
|
11771
|
+
const maxRounds = compileOptions?.maxRounds ?? this.maxRounds;
|
|
11772
|
+
this.traces = [];
|
|
11773
|
+
this.stats = {
|
|
11774
|
+
totalCalls: 0,
|
|
11775
|
+
successfulDemos: 0,
|
|
11776
|
+
estimatedTokenUsage: 0,
|
|
11777
|
+
earlyStopped: false
|
|
11778
|
+
};
|
|
11779
|
+
for (let i = 0; i < maxRounds; i++) {
|
|
11780
|
+
await this.compileRound(i, metricFn, compileOptions);
|
|
11781
|
+
if (this.stats.earlyStopped) {
|
|
11782
|
+
break;
|
|
11783
|
+
}
|
|
11784
|
+
}
|
|
11785
|
+
if (this.traces.length === 0) {
|
|
11786
|
+
throw new Error(
|
|
11787
|
+
"No demonstrations found. Either provider more examples or improve the existing ones."
|
|
11788
|
+
);
|
|
11789
|
+
}
|
|
11790
|
+
const demos = groupTracesByKeys(this.traces);
|
|
11791
|
+
return {
|
|
11792
|
+
demos,
|
|
11793
|
+
stats: this.stats
|
|
11794
|
+
};
|
|
11795
|
+
}
|
|
11796
|
+
// Get optimization statistics
|
|
11797
|
+
getStats() {
|
|
11798
|
+
return this.stats;
|
|
11799
|
+
}
|
|
11800
|
+
};
|
|
11801
|
+
function groupTracesByKeys(programTraces) {
|
|
11802
|
+
const groupedTraces = /* @__PURE__ */ new Map();
|
|
11803
|
+
for (const programTrace of programTraces) {
|
|
11804
|
+
if (groupedTraces.has(programTrace.programId)) {
|
|
11805
|
+
const traces = groupedTraces.get(programTrace.programId);
|
|
11806
|
+
if (traces) {
|
|
11807
|
+
traces.push(programTrace.trace);
|
|
11808
|
+
}
|
|
11809
|
+
} else {
|
|
11810
|
+
groupedTraces.set(programTrace.programId, [programTrace.trace]);
|
|
11811
|
+
}
|
|
11812
|
+
}
|
|
11813
|
+
const programDemosArray = [];
|
|
11814
|
+
for (const [programId, traces] of groupedTraces.entries()) {
|
|
11815
|
+
programDemosArray.push({ traces, programId });
|
|
11816
|
+
}
|
|
11817
|
+
return programDemosArray;
|
|
11818
|
+
}
|
|
11819
|
+
var randomSample = (array, n) => {
|
|
11820
|
+
const clonedArray = [...array];
|
|
11821
|
+
for (let i = clonedArray.length - 1; i > 0; i--) {
|
|
11822
|
+
const j = Math.floor(Math.random() * (i + 1));
|
|
11823
|
+
const caI = clonedArray[i];
|
|
11824
|
+
const caJ = clonedArray[j];
|
|
11825
|
+
if (!caI || !caJ) {
|
|
11826
|
+
throw new Error("Invalid array elements");
|
|
11827
|
+
}
|
|
11828
|
+
;
|
|
11829
|
+
[clonedArray[i], clonedArray[j]] = [caJ, caI];
|
|
11830
|
+
}
|
|
11831
|
+
return clonedArray.slice(0, n);
|
|
11832
|
+
};
|
|
11833
|
+
|
|
11834
|
+
// dsp/optimizers/miproV2.ts
|
|
11801
11835
|
var AxMiPRO = class {
|
|
11802
11836
|
ai;
|
|
11803
11837
|
program;
|
|
@@ -12023,7 +12057,7 @@ ${dataContext}
|
|
|
12023
12057
|
const result = await this.bootstrapper.compile(metricFn, {
|
|
12024
12058
|
maxDemos: this.maxBootstrappedDemos
|
|
12025
12059
|
});
|
|
12026
|
-
return result.demos;
|
|
12060
|
+
return result.demos || [];
|
|
12027
12061
|
}
|
|
12028
12062
|
/**
|
|
12029
12063
|
* Selects labeled examples directly from the training set
|
|
@@ -12357,21 +12391,22 @@ ${dataContext}
|
|
|
12357
12391
|
* The main compile method to run MIPROv2 optimization
|
|
12358
12392
|
* @param metricFn Evaluation metric function
|
|
12359
12393
|
* @param options Optional configuration options
|
|
12360
|
-
* @returns The
|
|
12394
|
+
* @returns The optimization result
|
|
12361
12395
|
*/
|
|
12362
12396
|
async compile(metricFn, options) {
|
|
12363
|
-
|
|
12364
|
-
|
|
12397
|
+
const miproOptions = options;
|
|
12398
|
+
if (miproOptions?.auto) {
|
|
12399
|
+
this.configureAuto(miproOptions.auto);
|
|
12365
12400
|
}
|
|
12366
12401
|
const trainset = this.examples;
|
|
12367
|
-
const valset =
|
|
12402
|
+
const valset = miproOptions?.valset || this.examples.slice(0, Math.floor(this.examples.length * 0.8));
|
|
12368
12403
|
if (this.verbose) {
|
|
12369
12404
|
console.log(`Starting MIPROv2 optimization with ${this.numTrials} trials`);
|
|
12370
12405
|
console.log(
|
|
12371
12406
|
`Using ${trainset.length} examples for training and ${valset.length} for validation`
|
|
12372
12407
|
);
|
|
12373
12408
|
}
|
|
12374
|
-
if (
|
|
12409
|
+
if (miproOptions?.teacher) {
|
|
12375
12410
|
if (this.verbose) {
|
|
12376
12411
|
console.log("Using provided teacher to assist with bootstrapping");
|
|
12377
12412
|
}
|
|
@@ -12428,7 +12463,17 @@ ${dataContext}
|
|
|
12428
12463
|
bootstrappedDemos,
|
|
12429
12464
|
labeledExamples
|
|
12430
12465
|
);
|
|
12431
|
-
return
|
|
12466
|
+
return {
|
|
12467
|
+
program: this.program,
|
|
12468
|
+
demos: bootstrappedDemos
|
|
12469
|
+
};
|
|
12470
|
+
}
|
|
12471
|
+
/**
|
|
12472
|
+
* Get optimization statistics from the internal bootstrapper
|
|
12473
|
+
* @returns Optimization statistics or undefined if not available
|
|
12474
|
+
*/
|
|
12475
|
+
getStats() {
|
|
12476
|
+
return this.bootstrapper.getStats();
|
|
12432
12477
|
}
|
|
12433
12478
|
};
|
|
12434
12479
|
|