@ax-llm/ax 12.0.3 → 12.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +270 -225
- package/index.cjs.map +1 -1
- package/index.d.cts +146 -66
- package/index.d.ts +146 -66
- package/index.js +270 -225
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.js
CHANGED
|
@@ -3992,7 +3992,8 @@ var axModelInfoMistral = [
|
|
|
3992
3992
|
// ai/mistral/api.ts
|
|
3993
3993
|
var axAIMistralDefaultConfig = () => structuredClone({
|
|
3994
3994
|
model: "mistral-small-latest" /* MistralSmall */,
|
|
3995
|
-
...axBaseAIDefaultConfig()
|
|
3995
|
+
...axBaseAIDefaultConfig(),
|
|
3996
|
+
topP: 1
|
|
3996
3997
|
});
|
|
3997
3998
|
var axAIMistralBestConfig = () => structuredClone({
|
|
3998
3999
|
...axAIMistralDefaultConfig(),
|
|
@@ -4020,6 +4021,15 @@ var AxAIMistral = class extends AxAIOpenAIBase {
|
|
|
4020
4021
|
hasThinkingBudget: false,
|
|
4021
4022
|
hasShowThoughts: false
|
|
4022
4023
|
};
|
|
4024
|
+
const chatReqUpdater = (req) => {
|
|
4025
|
+
const { max_completion_tokens, stream_options, messages, ...result } = req;
|
|
4026
|
+
return {
|
|
4027
|
+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
4028
|
+
...result,
|
|
4029
|
+
messages: this.updateMessages(messages),
|
|
4030
|
+
max_tokens: max_completion_tokens
|
|
4031
|
+
};
|
|
4032
|
+
};
|
|
4023
4033
|
super({
|
|
4024
4034
|
apiKey,
|
|
4025
4035
|
config: _config,
|
|
@@ -4027,10 +4037,32 @@ var AxAIMistral = class extends AxAIOpenAIBase {
|
|
|
4027
4037
|
apiURL: "https://api.mistral.ai/v1",
|
|
4028
4038
|
modelInfo,
|
|
4029
4039
|
models,
|
|
4030
|
-
supportFor
|
|
4040
|
+
supportFor,
|
|
4041
|
+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
4042
|
+
chatReqUpdater
|
|
4031
4043
|
});
|
|
4032
4044
|
super.setName("Mistral");
|
|
4033
4045
|
}
|
|
4046
|
+
updateMessages(messages) {
|
|
4047
|
+
const messagesUpdated = [];
|
|
4048
|
+
if (!Array.isArray(messages)) {
|
|
4049
|
+
return messages;
|
|
4050
|
+
}
|
|
4051
|
+
for (const message of messages) {
|
|
4052
|
+
if (message.role === "user" && Array.isArray(message.content)) {
|
|
4053
|
+
const contentUpdated = message.content.map((item) => {
|
|
4054
|
+
if (typeof item === "object" && item !== null && item.type === "image_url") {
|
|
4055
|
+
return { type: "image_url", image_url: item.image_url?.url };
|
|
4056
|
+
}
|
|
4057
|
+
return item;
|
|
4058
|
+
});
|
|
4059
|
+
messagesUpdated.push({ ...message, content: contentUpdated });
|
|
4060
|
+
} else {
|
|
4061
|
+
messagesUpdated.push(message);
|
|
4062
|
+
}
|
|
4063
|
+
}
|
|
4064
|
+
return messagesUpdated;
|
|
4065
|
+
}
|
|
4034
4066
|
};
|
|
4035
4067
|
|
|
4036
4068
|
// ai/ollama/api.ts
|
|
@@ -8164,7 +8196,7 @@ var AxSignature = class _AxSignature {
|
|
|
8164
8196
|
});
|
|
8165
8197
|
this.inputFields = parsedFields;
|
|
8166
8198
|
this.invalidateValidationCache();
|
|
8167
|
-
this.
|
|
8199
|
+
this.updateHashLight();
|
|
8168
8200
|
} catch (error) {
|
|
8169
8201
|
if (error instanceof AxSignatureValidationError) {
|
|
8170
8202
|
throw error;
|
|
@@ -8190,7 +8222,7 @@ var AxSignature = class _AxSignature {
|
|
|
8190
8222
|
});
|
|
8191
8223
|
this.outputFields = parsedFields;
|
|
8192
8224
|
this.invalidateValidationCache();
|
|
8193
|
-
this.
|
|
8225
|
+
this.updateHashLight();
|
|
8194
8226
|
} catch (error) {
|
|
8195
8227
|
if (error instanceof AxSignatureValidationError) {
|
|
8196
8228
|
throw error;
|
|
@@ -9958,219 +9990,6 @@ function validateModels2(services) {
|
|
|
9958
9990
|
}
|
|
9959
9991
|
}
|
|
9960
9992
|
|
|
9961
|
-
// dsp/optimize.ts
|
|
9962
|
-
var AxBootstrapFewShot = class {
|
|
9963
|
-
ai;
|
|
9964
|
-
teacherAI;
|
|
9965
|
-
program;
|
|
9966
|
-
examples;
|
|
9967
|
-
maxRounds;
|
|
9968
|
-
maxDemos;
|
|
9969
|
-
maxExamples;
|
|
9970
|
-
batchSize;
|
|
9971
|
-
earlyStoppingPatience;
|
|
9972
|
-
costMonitoring;
|
|
9973
|
-
maxTokensPerGeneration;
|
|
9974
|
-
verboseMode;
|
|
9975
|
-
debugMode;
|
|
9976
|
-
traces = [];
|
|
9977
|
-
stats = {
|
|
9978
|
-
totalCalls: 0,
|
|
9979
|
-
successfulDemos: 0,
|
|
9980
|
-
estimatedTokenUsage: 0,
|
|
9981
|
-
earlyStopped: false
|
|
9982
|
-
};
|
|
9983
|
-
constructor({
|
|
9984
|
-
ai,
|
|
9985
|
-
program,
|
|
9986
|
-
examples = [],
|
|
9987
|
-
options
|
|
9988
|
-
}) {
|
|
9989
|
-
if (examples.length === 0) {
|
|
9990
|
-
throw new Error("No examples found");
|
|
9991
|
-
}
|
|
9992
|
-
this.maxRounds = options?.maxRounds ?? 3;
|
|
9993
|
-
this.maxDemos = options?.maxDemos ?? 4;
|
|
9994
|
-
this.maxExamples = options?.maxExamples ?? 16;
|
|
9995
|
-
this.batchSize = options?.batchSize ?? 1;
|
|
9996
|
-
this.earlyStoppingPatience = options?.earlyStoppingPatience ?? 0;
|
|
9997
|
-
this.costMonitoring = options?.costMonitoring ?? false;
|
|
9998
|
-
this.maxTokensPerGeneration = options?.maxTokensPerGeneration ?? 0;
|
|
9999
|
-
this.verboseMode = options?.verboseMode ?? true;
|
|
10000
|
-
this.debugMode = options?.debugMode ?? false;
|
|
10001
|
-
this.ai = ai;
|
|
10002
|
-
this.teacherAI = options?.teacherAI;
|
|
10003
|
-
this.program = program;
|
|
10004
|
-
this.examples = examples;
|
|
10005
|
-
}
|
|
10006
|
-
async compileRound(roundIndex, metricFn, options) {
|
|
10007
|
-
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
10008
|
-
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
10009
|
-
const aiOpt = {
|
|
10010
|
-
modelConfig: {
|
|
10011
|
-
temperature: 0.7
|
|
10012
|
-
}
|
|
10013
|
-
};
|
|
10014
|
-
if (this.maxTokensPerGeneration > 0) {
|
|
10015
|
-
aiOpt.modelConfig.max_tokens = this.maxTokensPerGeneration;
|
|
10016
|
-
}
|
|
10017
|
-
const examples = randomSample(this.examples, this.maxExamples);
|
|
10018
|
-
const previousSuccessCount = this.traces.length;
|
|
10019
|
-
for (let i = 0; i < examples.length; i += this.batchSize) {
|
|
10020
|
-
if (i > 0) {
|
|
10021
|
-
aiOpt.modelConfig.temperature = 0.7 + 1e-3 * i;
|
|
10022
|
-
}
|
|
10023
|
-
const batch = examples.slice(i, i + this.batchSize);
|
|
10024
|
-
for (const ex of batch) {
|
|
10025
|
-
if (!ex) {
|
|
10026
|
-
continue;
|
|
10027
|
-
}
|
|
10028
|
-
const exList = examples.filter((e) => e !== ex);
|
|
10029
|
-
this.program.setExamples(exList);
|
|
10030
|
-
const aiService = this.teacherAI || this.ai;
|
|
10031
|
-
this.stats.totalCalls++;
|
|
10032
|
-
let res;
|
|
10033
|
-
let error;
|
|
10034
|
-
try {
|
|
10035
|
-
res = await this.program.forward(aiService, ex, aiOpt);
|
|
10036
|
-
if (this.costMonitoring) {
|
|
10037
|
-
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
10038
|
-
}
|
|
10039
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
10040
|
-
const success = score >= 0.5;
|
|
10041
|
-
if (success) {
|
|
10042
|
-
this.traces = [...this.traces, ...this.program.getTraces()];
|
|
10043
|
-
this.stats.successfulDemos++;
|
|
10044
|
-
}
|
|
10045
|
-
} catch (err) {
|
|
10046
|
-
error = err;
|
|
10047
|
-
res = {};
|
|
10048
|
-
}
|
|
10049
|
-
const current = i + examples.length * roundIndex + (batch.indexOf(ex) + 1);
|
|
10050
|
-
const total = examples.length * this.maxRounds;
|
|
10051
|
-
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
10052
|
-
if (this.verboseMode || this.debugMode) {
|
|
10053
|
-
const configInfo = {
|
|
10054
|
-
maxRounds: this.maxRounds,
|
|
10055
|
-
batchSize: this.batchSize,
|
|
10056
|
-
earlyStoppingPatience: this.earlyStoppingPatience,
|
|
10057
|
-
costMonitoring: this.costMonitoring,
|
|
10058
|
-
verboseMode: this.verboseMode,
|
|
10059
|
-
debugMode: this.debugMode
|
|
10060
|
-
};
|
|
10061
|
-
updateDetailedProgress(
|
|
10062
|
-
roundIndex,
|
|
10063
|
-
current,
|
|
10064
|
-
total,
|
|
10065
|
-
et,
|
|
10066
|
-
ex,
|
|
10067
|
-
this.stats,
|
|
10068
|
-
configInfo,
|
|
10069
|
-
res,
|
|
10070
|
-
error
|
|
10071
|
-
);
|
|
10072
|
-
} else {
|
|
10073
|
-
updateProgressBar(
|
|
10074
|
-
current,
|
|
10075
|
-
total,
|
|
10076
|
-
this.traces.length,
|
|
10077
|
-
et,
|
|
10078
|
-
"Tuning Prompt",
|
|
10079
|
-
30
|
|
10080
|
-
);
|
|
10081
|
-
}
|
|
10082
|
-
if (this.traces.length >= maxDemos) {
|
|
10083
|
-
return;
|
|
10084
|
-
}
|
|
10085
|
-
}
|
|
10086
|
-
}
|
|
10087
|
-
if (this.earlyStoppingPatience > 0) {
|
|
10088
|
-
const newSuccessCount = this.traces.length;
|
|
10089
|
-
const improvement = newSuccessCount - previousSuccessCount;
|
|
10090
|
-
if (!this.stats.earlyStopping) {
|
|
10091
|
-
this.stats.earlyStopping = {
|
|
10092
|
-
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
10093
|
-
patienceExhausted: false
|
|
10094
|
-
};
|
|
10095
|
-
} else if (improvement > 0) {
|
|
10096
|
-
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
10097
|
-
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
10098
|
-
this.stats.earlyStopping.patienceExhausted = true;
|
|
10099
|
-
this.stats.earlyStopped = true;
|
|
10100
|
-
if (this.verboseMode || this.debugMode) {
|
|
10101
|
-
console.log(
|
|
10102
|
-
`
|
|
10103
|
-
Early stopping triggered after ${roundIndex + 1} rounds. No improvement for ${this.earlyStoppingPatience} rounds.`
|
|
10104
|
-
);
|
|
10105
|
-
}
|
|
10106
|
-
return;
|
|
10107
|
-
}
|
|
10108
|
-
}
|
|
10109
|
-
}
|
|
10110
|
-
async compile(metricFn, options) {
|
|
10111
|
-
const maxRounds = options?.maxRounds ?? this.maxRounds;
|
|
10112
|
-
this.traces = [];
|
|
10113
|
-
this.stats = {
|
|
10114
|
-
totalCalls: 0,
|
|
10115
|
-
successfulDemos: 0,
|
|
10116
|
-
estimatedTokenUsage: 0,
|
|
10117
|
-
earlyStopped: false
|
|
10118
|
-
};
|
|
10119
|
-
for (let i = 0; i < maxRounds; i++) {
|
|
10120
|
-
await this.compileRound(i, metricFn, options);
|
|
10121
|
-
if (this.stats.earlyStopped) {
|
|
10122
|
-
break;
|
|
10123
|
-
}
|
|
10124
|
-
}
|
|
10125
|
-
if (this.traces.length === 0) {
|
|
10126
|
-
throw new Error(
|
|
10127
|
-
"No demonstrations found. Either provider more examples or improve the existing ones."
|
|
10128
|
-
);
|
|
10129
|
-
}
|
|
10130
|
-
const demos = groupTracesByKeys(this.traces);
|
|
10131
|
-
return {
|
|
10132
|
-
demos,
|
|
10133
|
-
stats: this.stats
|
|
10134
|
-
};
|
|
10135
|
-
}
|
|
10136
|
-
// Get optimization statistics
|
|
10137
|
-
getStats() {
|
|
10138
|
-
return this.stats;
|
|
10139
|
-
}
|
|
10140
|
-
};
|
|
10141
|
-
function groupTracesByKeys(programTraces) {
|
|
10142
|
-
const groupedTraces = /* @__PURE__ */ new Map();
|
|
10143
|
-
for (const programTrace of programTraces) {
|
|
10144
|
-
if (groupedTraces.has(programTrace.programId)) {
|
|
10145
|
-
const traces = groupedTraces.get(programTrace.programId);
|
|
10146
|
-
if (traces) {
|
|
10147
|
-
traces.push(programTrace.trace);
|
|
10148
|
-
}
|
|
10149
|
-
} else {
|
|
10150
|
-
groupedTraces.set(programTrace.programId, [programTrace.trace]);
|
|
10151
|
-
}
|
|
10152
|
-
}
|
|
10153
|
-
const programDemosArray = [];
|
|
10154
|
-
for (const [programId, traces] of groupedTraces.entries()) {
|
|
10155
|
-
programDemosArray.push({ traces, programId });
|
|
10156
|
-
}
|
|
10157
|
-
return programDemosArray;
|
|
10158
|
-
}
|
|
10159
|
-
var randomSample = (array, n) => {
|
|
10160
|
-
const clonedArray = [...array];
|
|
10161
|
-
for (let i = clonedArray.length - 1; i > 0; i--) {
|
|
10162
|
-
const j = Math.floor(Math.random() * (i + 1));
|
|
10163
|
-
const caI = clonedArray[i];
|
|
10164
|
-
const caJ = clonedArray[j];
|
|
10165
|
-
if (!caI || !caJ) {
|
|
10166
|
-
throw new Error("Invalid array elements");
|
|
10167
|
-
}
|
|
10168
|
-
;
|
|
10169
|
-
[clonedArray[i], clonedArray[j]] = [caJ, caI];
|
|
10170
|
-
}
|
|
10171
|
-
return clonedArray.slice(0, n);
|
|
10172
|
-
};
|
|
10173
|
-
|
|
10174
9993
|
// db/base.ts
|
|
10175
9994
|
import { SpanKind as SpanKind3 } from "@opentelemetry/api";
|
|
10176
9995
|
var AxDBBase = class {
|
|
@@ -11629,7 +11448,222 @@ var AxMCPStreambleHTTPTransport = class {
|
|
|
11629
11448
|
}
|
|
11630
11449
|
};
|
|
11631
11450
|
|
|
11632
|
-
// dsp/
|
|
11451
|
+
// dsp/optimizers/bootstrapFewshot.ts
|
|
11452
|
+
var AxBootstrapFewShot = class {
|
|
11453
|
+
ai;
|
|
11454
|
+
teacherAI;
|
|
11455
|
+
program;
|
|
11456
|
+
examples;
|
|
11457
|
+
maxRounds;
|
|
11458
|
+
maxDemos;
|
|
11459
|
+
maxExamples;
|
|
11460
|
+
batchSize;
|
|
11461
|
+
earlyStoppingPatience;
|
|
11462
|
+
costMonitoring;
|
|
11463
|
+
maxTokensPerGeneration;
|
|
11464
|
+
verboseMode;
|
|
11465
|
+
debugMode;
|
|
11466
|
+
traces = [];
|
|
11467
|
+
stats = {
|
|
11468
|
+
totalCalls: 0,
|
|
11469
|
+
successfulDemos: 0,
|
|
11470
|
+
estimatedTokenUsage: 0,
|
|
11471
|
+
earlyStopped: false
|
|
11472
|
+
};
|
|
11473
|
+
constructor({
|
|
11474
|
+
ai,
|
|
11475
|
+
program,
|
|
11476
|
+
examples = [],
|
|
11477
|
+
options
|
|
11478
|
+
}) {
|
|
11479
|
+
if (examples.length === 0) {
|
|
11480
|
+
throw new Error("No examples found");
|
|
11481
|
+
}
|
|
11482
|
+
const bootstrapOptions = options;
|
|
11483
|
+
this.maxRounds = bootstrapOptions?.maxRounds ?? 3;
|
|
11484
|
+
this.maxDemos = bootstrapOptions?.maxDemos ?? 4;
|
|
11485
|
+
this.maxExamples = bootstrapOptions?.maxExamples ?? 16;
|
|
11486
|
+
this.batchSize = bootstrapOptions?.batchSize ?? 1;
|
|
11487
|
+
this.earlyStoppingPatience = bootstrapOptions?.earlyStoppingPatience ?? 0;
|
|
11488
|
+
this.costMonitoring = bootstrapOptions?.costMonitoring ?? false;
|
|
11489
|
+
this.maxTokensPerGeneration = bootstrapOptions?.maxTokensPerGeneration ?? 0;
|
|
11490
|
+
this.verboseMode = bootstrapOptions?.verboseMode ?? true;
|
|
11491
|
+
this.debugMode = bootstrapOptions?.debugMode ?? false;
|
|
11492
|
+
this.ai = ai;
|
|
11493
|
+
this.teacherAI = bootstrapOptions?.teacherAI;
|
|
11494
|
+
this.program = program;
|
|
11495
|
+
this.examples = examples;
|
|
11496
|
+
}
|
|
11497
|
+
async compileRound(roundIndex, metricFn, options) {
|
|
11498
|
+
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
11499
|
+
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
11500
|
+
const aiOpt = {
|
|
11501
|
+
modelConfig: {
|
|
11502
|
+
temperature: 0.7
|
|
11503
|
+
}
|
|
11504
|
+
};
|
|
11505
|
+
if (this.maxTokensPerGeneration > 0) {
|
|
11506
|
+
aiOpt.modelConfig.max_tokens = this.maxTokensPerGeneration;
|
|
11507
|
+
}
|
|
11508
|
+
const examples = randomSample(this.examples, this.maxExamples);
|
|
11509
|
+
const previousSuccessCount = this.traces.length;
|
|
11510
|
+
for (let i = 0; i < examples.length; i += this.batchSize) {
|
|
11511
|
+
if (i > 0) {
|
|
11512
|
+
aiOpt.modelConfig.temperature = 0.7 + 1e-3 * i;
|
|
11513
|
+
}
|
|
11514
|
+
const batch = examples.slice(i, i + this.batchSize);
|
|
11515
|
+
for (const ex of batch) {
|
|
11516
|
+
if (!ex) {
|
|
11517
|
+
continue;
|
|
11518
|
+
}
|
|
11519
|
+
const exList = examples.filter((e) => e !== ex);
|
|
11520
|
+
this.program.setExamples(exList);
|
|
11521
|
+
const aiService = this.teacherAI || this.ai;
|
|
11522
|
+
this.stats.totalCalls++;
|
|
11523
|
+
let res;
|
|
11524
|
+
let error;
|
|
11525
|
+
try {
|
|
11526
|
+
res = await this.program.forward(aiService, ex, aiOpt);
|
|
11527
|
+
if (this.costMonitoring) {
|
|
11528
|
+
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
11529
|
+
}
|
|
11530
|
+
const score = metricFn({ prediction: res, example: ex });
|
|
11531
|
+
const success = score >= 0.5;
|
|
11532
|
+
if (success) {
|
|
11533
|
+
this.traces = [...this.traces, ...this.program.getTraces()];
|
|
11534
|
+
this.stats.successfulDemos++;
|
|
11535
|
+
}
|
|
11536
|
+
} catch (err) {
|
|
11537
|
+
error = err;
|
|
11538
|
+
res = {};
|
|
11539
|
+
}
|
|
11540
|
+
const current = i + examples.length * roundIndex + (batch.indexOf(ex) + 1);
|
|
11541
|
+
const total = examples.length * this.maxRounds;
|
|
11542
|
+
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
11543
|
+
if (this.verboseMode || this.debugMode) {
|
|
11544
|
+
const configInfo = {
|
|
11545
|
+
maxRounds: this.maxRounds,
|
|
11546
|
+
batchSize: this.batchSize,
|
|
11547
|
+
earlyStoppingPatience: this.earlyStoppingPatience,
|
|
11548
|
+
costMonitoring: this.costMonitoring,
|
|
11549
|
+
verboseMode: this.verboseMode,
|
|
11550
|
+
debugMode: this.debugMode
|
|
11551
|
+
};
|
|
11552
|
+
updateDetailedProgress(
|
|
11553
|
+
roundIndex,
|
|
11554
|
+
current,
|
|
11555
|
+
total,
|
|
11556
|
+
et,
|
|
11557
|
+
ex,
|
|
11558
|
+
this.stats,
|
|
11559
|
+
configInfo,
|
|
11560
|
+
res,
|
|
11561
|
+
error
|
|
11562
|
+
);
|
|
11563
|
+
} else {
|
|
11564
|
+
updateProgressBar(
|
|
11565
|
+
current,
|
|
11566
|
+
total,
|
|
11567
|
+
this.traces.length,
|
|
11568
|
+
et,
|
|
11569
|
+
"Tuning Prompt",
|
|
11570
|
+
30
|
|
11571
|
+
);
|
|
11572
|
+
}
|
|
11573
|
+
if (this.traces.length >= maxDemos) {
|
|
11574
|
+
return;
|
|
11575
|
+
}
|
|
11576
|
+
}
|
|
11577
|
+
}
|
|
11578
|
+
if (this.earlyStoppingPatience > 0) {
|
|
11579
|
+
const newSuccessCount = this.traces.length;
|
|
11580
|
+
const improvement = newSuccessCount - previousSuccessCount;
|
|
11581
|
+
if (!this.stats.earlyStopping) {
|
|
11582
|
+
this.stats.earlyStopping = {
|
|
11583
|
+
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
11584
|
+
patienceExhausted: false
|
|
11585
|
+
};
|
|
11586
|
+
} else if (improvement > 0) {
|
|
11587
|
+
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
11588
|
+
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
11589
|
+
this.stats.earlyStopping.patienceExhausted = true;
|
|
11590
|
+
this.stats.earlyStopped = true;
|
|
11591
|
+
if (this.verboseMode || this.debugMode) {
|
|
11592
|
+
console.log(
|
|
11593
|
+
`
|
|
11594
|
+
Early stopping triggered after ${roundIndex + 1} rounds. No improvement for ${this.earlyStoppingPatience} rounds.`
|
|
11595
|
+
);
|
|
11596
|
+
}
|
|
11597
|
+
return;
|
|
11598
|
+
}
|
|
11599
|
+
}
|
|
11600
|
+
}
|
|
11601
|
+
async compile(metricFn, options) {
|
|
11602
|
+
const compileOptions = options;
|
|
11603
|
+
const maxRounds = compileOptions?.maxRounds ?? this.maxRounds;
|
|
11604
|
+
this.traces = [];
|
|
11605
|
+
this.stats = {
|
|
11606
|
+
totalCalls: 0,
|
|
11607
|
+
successfulDemos: 0,
|
|
11608
|
+
estimatedTokenUsage: 0,
|
|
11609
|
+
earlyStopped: false
|
|
11610
|
+
};
|
|
11611
|
+
for (let i = 0; i < maxRounds; i++) {
|
|
11612
|
+
await this.compileRound(i, metricFn, compileOptions);
|
|
11613
|
+
if (this.stats.earlyStopped) {
|
|
11614
|
+
break;
|
|
11615
|
+
}
|
|
11616
|
+
}
|
|
11617
|
+
if (this.traces.length === 0) {
|
|
11618
|
+
throw new Error(
|
|
11619
|
+
"No demonstrations found. Either provider more examples or improve the existing ones."
|
|
11620
|
+
);
|
|
11621
|
+
}
|
|
11622
|
+
const demos = groupTracesByKeys(this.traces);
|
|
11623
|
+
return {
|
|
11624
|
+
demos,
|
|
11625
|
+
stats: this.stats
|
|
11626
|
+
};
|
|
11627
|
+
}
|
|
11628
|
+
// Get optimization statistics
|
|
11629
|
+
getStats() {
|
|
11630
|
+
return this.stats;
|
|
11631
|
+
}
|
|
11632
|
+
};
|
|
11633
|
+
function groupTracesByKeys(programTraces) {
|
|
11634
|
+
const groupedTraces = /* @__PURE__ */ new Map();
|
|
11635
|
+
for (const programTrace of programTraces) {
|
|
11636
|
+
if (groupedTraces.has(programTrace.programId)) {
|
|
11637
|
+
const traces = groupedTraces.get(programTrace.programId);
|
|
11638
|
+
if (traces) {
|
|
11639
|
+
traces.push(programTrace.trace);
|
|
11640
|
+
}
|
|
11641
|
+
} else {
|
|
11642
|
+
groupedTraces.set(programTrace.programId, [programTrace.trace]);
|
|
11643
|
+
}
|
|
11644
|
+
}
|
|
11645
|
+
const programDemosArray = [];
|
|
11646
|
+
for (const [programId, traces] of groupedTraces.entries()) {
|
|
11647
|
+
programDemosArray.push({ traces, programId });
|
|
11648
|
+
}
|
|
11649
|
+
return programDemosArray;
|
|
11650
|
+
}
|
|
11651
|
+
var randomSample = (array, n) => {
|
|
11652
|
+
const clonedArray = [...array];
|
|
11653
|
+
for (let i = clonedArray.length - 1; i > 0; i--) {
|
|
11654
|
+
const j = Math.floor(Math.random() * (i + 1));
|
|
11655
|
+
const caI = clonedArray[i];
|
|
11656
|
+
const caJ = clonedArray[j];
|
|
11657
|
+
if (!caI || !caJ) {
|
|
11658
|
+
throw new Error("Invalid array elements");
|
|
11659
|
+
}
|
|
11660
|
+
;
|
|
11661
|
+
[clonedArray[i], clonedArray[j]] = [caJ, caI];
|
|
11662
|
+
}
|
|
11663
|
+
return clonedArray.slice(0, n);
|
|
11664
|
+
};
|
|
11665
|
+
|
|
11666
|
+
// dsp/optimizers/miproV2.ts
|
|
11633
11667
|
var AxMiPRO = class {
|
|
11634
11668
|
ai;
|
|
11635
11669
|
program;
|
|
@@ -11855,7 +11889,7 @@ ${dataContext}
|
|
|
11855
11889
|
const result = await this.bootstrapper.compile(metricFn, {
|
|
11856
11890
|
maxDemos: this.maxBootstrappedDemos
|
|
11857
11891
|
});
|
|
11858
|
-
return result.demos;
|
|
11892
|
+
return result.demos || [];
|
|
11859
11893
|
}
|
|
11860
11894
|
/**
|
|
11861
11895
|
* Selects labeled examples directly from the training set
|
|
@@ -12189,21 +12223,22 @@ ${dataContext}
|
|
|
12189
12223
|
* The main compile method to run MIPROv2 optimization
|
|
12190
12224
|
* @param metricFn Evaluation metric function
|
|
12191
12225
|
* @param options Optional configuration options
|
|
12192
|
-
* @returns The
|
|
12226
|
+
* @returns The optimization result
|
|
12193
12227
|
*/
|
|
12194
12228
|
async compile(metricFn, options) {
|
|
12195
|
-
|
|
12196
|
-
|
|
12229
|
+
const miproOptions = options;
|
|
12230
|
+
if (miproOptions?.auto) {
|
|
12231
|
+
this.configureAuto(miproOptions.auto);
|
|
12197
12232
|
}
|
|
12198
12233
|
const trainset = this.examples;
|
|
12199
|
-
const valset =
|
|
12234
|
+
const valset = miproOptions?.valset || this.examples.slice(0, Math.floor(this.examples.length * 0.8));
|
|
12200
12235
|
if (this.verbose) {
|
|
12201
12236
|
console.log(`Starting MIPROv2 optimization with ${this.numTrials} trials`);
|
|
12202
12237
|
console.log(
|
|
12203
12238
|
`Using ${trainset.length} examples for training and ${valset.length} for validation`
|
|
12204
12239
|
);
|
|
12205
12240
|
}
|
|
12206
|
-
if (
|
|
12241
|
+
if (miproOptions?.teacher) {
|
|
12207
12242
|
if (this.verbose) {
|
|
12208
12243
|
console.log("Using provided teacher to assist with bootstrapping");
|
|
12209
12244
|
}
|
|
@@ -12260,7 +12295,17 @@ ${dataContext}
|
|
|
12260
12295
|
bootstrappedDemos,
|
|
12261
12296
|
labeledExamples
|
|
12262
12297
|
);
|
|
12263
|
-
return
|
|
12298
|
+
return {
|
|
12299
|
+
program: this.program,
|
|
12300
|
+
demos: bootstrappedDemos
|
|
12301
|
+
};
|
|
12302
|
+
}
|
|
12303
|
+
/**
|
|
12304
|
+
* Get optimization statistics from the internal bootstrapper
|
|
12305
|
+
* @returns Optimization statistics or undefined if not available
|
|
12306
|
+
*/
|
|
12307
|
+
getStats() {
|
|
12308
|
+
return this.bootstrapper.getStats();
|
|
12264
12309
|
}
|
|
12265
12310
|
};
|
|
12266
12311
|
|