@ax-llm/ax 12.0.4 → 12.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +236 -223
- package/index.cjs.map +1 -1
- package/index.d.cts +113 -66
- package/index.d.ts +113 -66
- package/index.js +236 -223
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.cjs
CHANGED
|
@@ -8364,7 +8364,7 @@ var AxSignature = class _AxSignature {
|
|
|
8364
8364
|
});
|
|
8365
8365
|
this.inputFields = parsedFields;
|
|
8366
8366
|
this.invalidateValidationCache();
|
|
8367
|
-
this.
|
|
8367
|
+
this.updateHashLight();
|
|
8368
8368
|
} catch (error) {
|
|
8369
8369
|
if (error instanceof AxSignatureValidationError) {
|
|
8370
8370
|
throw error;
|
|
@@ -8390,7 +8390,7 @@ var AxSignature = class _AxSignature {
|
|
|
8390
8390
|
});
|
|
8391
8391
|
this.outputFields = parsedFields;
|
|
8392
8392
|
this.invalidateValidationCache();
|
|
8393
|
-
this.
|
|
8393
|
+
this.updateHashLight();
|
|
8394
8394
|
} catch (error) {
|
|
8395
8395
|
if (error instanceof AxSignatureValidationError) {
|
|
8396
8396
|
throw error;
|
|
@@ -10158,219 +10158,6 @@ function validateModels2(services) {
|
|
|
10158
10158
|
}
|
|
10159
10159
|
}
|
|
10160
10160
|
|
|
10161
|
-
// dsp/optimize.ts
|
|
10162
|
-
var AxBootstrapFewShot = class {
|
|
10163
|
-
ai;
|
|
10164
|
-
teacherAI;
|
|
10165
|
-
program;
|
|
10166
|
-
examples;
|
|
10167
|
-
maxRounds;
|
|
10168
|
-
maxDemos;
|
|
10169
|
-
maxExamples;
|
|
10170
|
-
batchSize;
|
|
10171
|
-
earlyStoppingPatience;
|
|
10172
|
-
costMonitoring;
|
|
10173
|
-
maxTokensPerGeneration;
|
|
10174
|
-
verboseMode;
|
|
10175
|
-
debugMode;
|
|
10176
|
-
traces = [];
|
|
10177
|
-
stats = {
|
|
10178
|
-
totalCalls: 0,
|
|
10179
|
-
successfulDemos: 0,
|
|
10180
|
-
estimatedTokenUsage: 0,
|
|
10181
|
-
earlyStopped: false
|
|
10182
|
-
};
|
|
10183
|
-
constructor({
|
|
10184
|
-
ai,
|
|
10185
|
-
program,
|
|
10186
|
-
examples = [],
|
|
10187
|
-
options
|
|
10188
|
-
}) {
|
|
10189
|
-
if (examples.length === 0) {
|
|
10190
|
-
throw new Error("No examples found");
|
|
10191
|
-
}
|
|
10192
|
-
this.maxRounds = options?.maxRounds ?? 3;
|
|
10193
|
-
this.maxDemos = options?.maxDemos ?? 4;
|
|
10194
|
-
this.maxExamples = options?.maxExamples ?? 16;
|
|
10195
|
-
this.batchSize = options?.batchSize ?? 1;
|
|
10196
|
-
this.earlyStoppingPatience = options?.earlyStoppingPatience ?? 0;
|
|
10197
|
-
this.costMonitoring = options?.costMonitoring ?? false;
|
|
10198
|
-
this.maxTokensPerGeneration = options?.maxTokensPerGeneration ?? 0;
|
|
10199
|
-
this.verboseMode = options?.verboseMode ?? true;
|
|
10200
|
-
this.debugMode = options?.debugMode ?? false;
|
|
10201
|
-
this.ai = ai;
|
|
10202
|
-
this.teacherAI = options?.teacherAI;
|
|
10203
|
-
this.program = program;
|
|
10204
|
-
this.examples = examples;
|
|
10205
|
-
}
|
|
10206
|
-
async compileRound(roundIndex, metricFn, options) {
|
|
10207
|
-
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
10208
|
-
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
10209
|
-
const aiOpt = {
|
|
10210
|
-
modelConfig: {
|
|
10211
|
-
temperature: 0.7
|
|
10212
|
-
}
|
|
10213
|
-
};
|
|
10214
|
-
if (this.maxTokensPerGeneration > 0) {
|
|
10215
|
-
aiOpt.modelConfig.max_tokens = this.maxTokensPerGeneration;
|
|
10216
|
-
}
|
|
10217
|
-
const examples = randomSample(this.examples, this.maxExamples);
|
|
10218
|
-
const previousSuccessCount = this.traces.length;
|
|
10219
|
-
for (let i = 0; i < examples.length; i += this.batchSize) {
|
|
10220
|
-
if (i > 0) {
|
|
10221
|
-
aiOpt.modelConfig.temperature = 0.7 + 1e-3 * i;
|
|
10222
|
-
}
|
|
10223
|
-
const batch = examples.slice(i, i + this.batchSize);
|
|
10224
|
-
for (const ex of batch) {
|
|
10225
|
-
if (!ex) {
|
|
10226
|
-
continue;
|
|
10227
|
-
}
|
|
10228
|
-
const exList = examples.filter((e) => e !== ex);
|
|
10229
|
-
this.program.setExamples(exList);
|
|
10230
|
-
const aiService = this.teacherAI || this.ai;
|
|
10231
|
-
this.stats.totalCalls++;
|
|
10232
|
-
let res;
|
|
10233
|
-
let error;
|
|
10234
|
-
try {
|
|
10235
|
-
res = await this.program.forward(aiService, ex, aiOpt);
|
|
10236
|
-
if (this.costMonitoring) {
|
|
10237
|
-
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
10238
|
-
}
|
|
10239
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
10240
|
-
const success = score >= 0.5;
|
|
10241
|
-
if (success) {
|
|
10242
|
-
this.traces = [...this.traces, ...this.program.getTraces()];
|
|
10243
|
-
this.stats.successfulDemos++;
|
|
10244
|
-
}
|
|
10245
|
-
} catch (err) {
|
|
10246
|
-
error = err;
|
|
10247
|
-
res = {};
|
|
10248
|
-
}
|
|
10249
|
-
const current = i + examples.length * roundIndex + (batch.indexOf(ex) + 1);
|
|
10250
|
-
const total = examples.length * this.maxRounds;
|
|
10251
|
-
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
10252
|
-
if (this.verboseMode || this.debugMode) {
|
|
10253
|
-
const configInfo = {
|
|
10254
|
-
maxRounds: this.maxRounds,
|
|
10255
|
-
batchSize: this.batchSize,
|
|
10256
|
-
earlyStoppingPatience: this.earlyStoppingPatience,
|
|
10257
|
-
costMonitoring: this.costMonitoring,
|
|
10258
|
-
verboseMode: this.verboseMode,
|
|
10259
|
-
debugMode: this.debugMode
|
|
10260
|
-
};
|
|
10261
|
-
updateDetailedProgress(
|
|
10262
|
-
roundIndex,
|
|
10263
|
-
current,
|
|
10264
|
-
total,
|
|
10265
|
-
et,
|
|
10266
|
-
ex,
|
|
10267
|
-
this.stats,
|
|
10268
|
-
configInfo,
|
|
10269
|
-
res,
|
|
10270
|
-
error
|
|
10271
|
-
);
|
|
10272
|
-
} else {
|
|
10273
|
-
updateProgressBar(
|
|
10274
|
-
current,
|
|
10275
|
-
total,
|
|
10276
|
-
this.traces.length,
|
|
10277
|
-
et,
|
|
10278
|
-
"Tuning Prompt",
|
|
10279
|
-
30
|
|
10280
|
-
);
|
|
10281
|
-
}
|
|
10282
|
-
if (this.traces.length >= maxDemos) {
|
|
10283
|
-
return;
|
|
10284
|
-
}
|
|
10285
|
-
}
|
|
10286
|
-
}
|
|
10287
|
-
if (this.earlyStoppingPatience > 0) {
|
|
10288
|
-
const newSuccessCount = this.traces.length;
|
|
10289
|
-
const improvement = newSuccessCount - previousSuccessCount;
|
|
10290
|
-
if (!this.stats.earlyStopping) {
|
|
10291
|
-
this.stats.earlyStopping = {
|
|
10292
|
-
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
10293
|
-
patienceExhausted: false
|
|
10294
|
-
};
|
|
10295
|
-
} else if (improvement > 0) {
|
|
10296
|
-
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
10297
|
-
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
10298
|
-
this.stats.earlyStopping.patienceExhausted = true;
|
|
10299
|
-
this.stats.earlyStopped = true;
|
|
10300
|
-
if (this.verboseMode || this.debugMode) {
|
|
10301
|
-
console.log(
|
|
10302
|
-
`
|
|
10303
|
-
Early stopping triggered after ${roundIndex + 1} rounds. No improvement for ${this.earlyStoppingPatience} rounds.`
|
|
10304
|
-
);
|
|
10305
|
-
}
|
|
10306
|
-
return;
|
|
10307
|
-
}
|
|
10308
|
-
}
|
|
10309
|
-
}
|
|
10310
|
-
async compile(metricFn, options) {
|
|
10311
|
-
const maxRounds = options?.maxRounds ?? this.maxRounds;
|
|
10312
|
-
this.traces = [];
|
|
10313
|
-
this.stats = {
|
|
10314
|
-
totalCalls: 0,
|
|
10315
|
-
successfulDemos: 0,
|
|
10316
|
-
estimatedTokenUsage: 0,
|
|
10317
|
-
earlyStopped: false
|
|
10318
|
-
};
|
|
10319
|
-
for (let i = 0; i < maxRounds; i++) {
|
|
10320
|
-
await this.compileRound(i, metricFn, options);
|
|
10321
|
-
if (this.stats.earlyStopped) {
|
|
10322
|
-
break;
|
|
10323
|
-
}
|
|
10324
|
-
}
|
|
10325
|
-
if (this.traces.length === 0) {
|
|
10326
|
-
throw new Error(
|
|
10327
|
-
"No demonstrations found. Either provider more examples or improve the existing ones."
|
|
10328
|
-
);
|
|
10329
|
-
}
|
|
10330
|
-
const demos = groupTracesByKeys(this.traces);
|
|
10331
|
-
return {
|
|
10332
|
-
demos,
|
|
10333
|
-
stats: this.stats
|
|
10334
|
-
};
|
|
10335
|
-
}
|
|
10336
|
-
// Get optimization statistics
|
|
10337
|
-
getStats() {
|
|
10338
|
-
return this.stats;
|
|
10339
|
-
}
|
|
10340
|
-
};
|
|
10341
|
-
function groupTracesByKeys(programTraces) {
|
|
10342
|
-
const groupedTraces = /* @__PURE__ */ new Map();
|
|
10343
|
-
for (const programTrace of programTraces) {
|
|
10344
|
-
if (groupedTraces.has(programTrace.programId)) {
|
|
10345
|
-
const traces = groupedTraces.get(programTrace.programId);
|
|
10346
|
-
if (traces) {
|
|
10347
|
-
traces.push(programTrace.trace);
|
|
10348
|
-
}
|
|
10349
|
-
} else {
|
|
10350
|
-
groupedTraces.set(programTrace.programId, [programTrace.trace]);
|
|
10351
|
-
}
|
|
10352
|
-
}
|
|
10353
|
-
const programDemosArray = [];
|
|
10354
|
-
for (const [programId, traces] of groupedTraces.entries()) {
|
|
10355
|
-
programDemosArray.push({ traces, programId });
|
|
10356
|
-
}
|
|
10357
|
-
return programDemosArray;
|
|
10358
|
-
}
|
|
10359
|
-
var randomSample = (array, n) => {
|
|
10360
|
-
const clonedArray = [...array];
|
|
10361
|
-
for (let i = clonedArray.length - 1; i > 0; i--) {
|
|
10362
|
-
const j = Math.floor(Math.random() * (i + 1));
|
|
10363
|
-
const caI = clonedArray[i];
|
|
10364
|
-
const caJ = clonedArray[j];
|
|
10365
|
-
if (!caI || !caJ) {
|
|
10366
|
-
throw new Error("Invalid array elements");
|
|
10367
|
-
}
|
|
10368
|
-
;
|
|
10369
|
-
[clonedArray[i], clonedArray[j]] = [caJ, caI];
|
|
10370
|
-
}
|
|
10371
|
-
return clonedArray.slice(0, n);
|
|
10372
|
-
};
|
|
10373
|
-
|
|
10374
10161
|
// db/base.ts
|
|
10375
10162
|
var import_api23 = require("@opentelemetry/api");
|
|
10376
10163
|
var AxDBBase = class {
|
|
@@ -11829,7 +11616,222 @@ var AxMCPStreambleHTTPTransport = class {
|
|
|
11829
11616
|
}
|
|
11830
11617
|
};
|
|
11831
11618
|
|
|
11832
|
-
// dsp/
|
|
11619
|
+
// dsp/optimizers/bootstrapFewshot.ts
|
|
11620
|
+
var AxBootstrapFewShot = class {
|
|
11621
|
+
ai;
|
|
11622
|
+
teacherAI;
|
|
11623
|
+
program;
|
|
11624
|
+
examples;
|
|
11625
|
+
maxRounds;
|
|
11626
|
+
maxDemos;
|
|
11627
|
+
maxExamples;
|
|
11628
|
+
batchSize;
|
|
11629
|
+
earlyStoppingPatience;
|
|
11630
|
+
costMonitoring;
|
|
11631
|
+
maxTokensPerGeneration;
|
|
11632
|
+
verboseMode;
|
|
11633
|
+
debugMode;
|
|
11634
|
+
traces = [];
|
|
11635
|
+
stats = {
|
|
11636
|
+
totalCalls: 0,
|
|
11637
|
+
successfulDemos: 0,
|
|
11638
|
+
estimatedTokenUsage: 0,
|
|
11639
|
+
earlyStopped: false
|
|
11640
|
+
};
|
|
11641
|
+
constructor({
|
|
11642
|
+
ai,
|
|
11643
|
+
program,
|
|
11644
|
+
examples = [],
|
|
11645
|
+
options
|
|
11646
|
+
}) {
|
|
11647
|
+
if (examples.length === 0) {
|
|
11648
|
+
throw new Error("No examples found");
|
|
11649
|
+
}
|
|
11650
|
+
const bootstrapOptions = options;
|
|
11651
|
+
this.maxRounds = bootstrapOptions?.maxRounds ?? 3;
|
|
11652
|
+
this.maxDemos = bootstrapOptions?.maxDemos ?? 4;
|
|
11653
|
+
this.maxExamples = bootstrapOptions?.maxExamples ?? 16;
|
|
11654
|
+
this.batchSize = bootstrapOptions?.batchSize ?? 1;
|
|
11655
|
+
this.earlyStoppingPatience = bootstrapOptions?.earlyStoppingPatience ?? 0;
|
|
11656
|
+
this.costMonitoring = bootstrapOptions?.costMonitoring ?? false;
|
|
11657
|
+
this.maxTokensPerGeneration = bootstrapOptions?.maxTokensPerGeneration ?? 0;
|
|
11658
|
+
this.verboseMode = bootstrapOptions?.verboseMode ?? true;
|
|
11659
|
+
this.debugMode = bootstrapOptions?.debugMode ?? false;
|
|
11660
|
+
this.ai = ai;
|
|
11661
|
+
this.teacherAI = bootstrapOptions?.teacherAI;
|
|
11662
|
+
this.program = program;
|
|
11663
|
+
this.examples = examples;
|
|
11664
|
+
}
|
|
11665
|
+
async compileRound(roundIndex, metricFn, options) {
|
|
11666
|
+
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
11667
|
+
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
11668
|
+
const aiOpt = {
|
|
11669
|
+
modelConfig: {
|
|
11670
|
+
temperature: 0.7
|
|
11671
|
+
}
|
|
11672
|
+
};
|
|
11673
|
+
if (this.maxTokensPerGeneration > 0) {
|
|
11674
|
+
aiOpt.modelConfig.max_tokens = this.maxTokensPerGeneration;
|
|
11675
|
+
}
|
|
11676
|
+
const examples = randomSample(this.examples, this.maxExamples);
|
|
11677
|
+
const previousSuccessCount = this.traces.length;
|
|
11678
|
+
for (let i = 0; i < examples.length; i += this.batchSize) {
|
|
11679
|
+
if (i > 0) {
|
|
11680
|
+
aiOpt.modelConfig.temperature = 0.7 + 1e-3 * i;
|
|
11681
|
+
}
|
|
11682
|
+
const batch = examples.slice(i, i + this.batchSize);
|
|
11683
|
+
for (const ex of batch) {
|
|
11684
|
+
if (!ex) {
|
|
11685
|
+
continue;
|
|
11686
|
+
}
|
|
11687
|
+
const exList = examples.filter((e) => e !== ex);
|
|
11688
|
+
this.program.setExamples(exList);
|
|
11689
|
+
const aiService = this.teacherAI || this.ai;
|
|
11690
|
+
this.stats.totalCalls++;
|
|
11691
|
+
let res;
|
|
11692
|
+
let error;
|
|
11693
|
+
try {
|
|
11694
|
+
res = await this.program.forward(aiService, ex, aiOpt);
|
|
11695
|
+
if (this.costMonitoring) {
|
|
11696
|
+
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
11697
|
+
}
|
|
11698
|
+
const score = metricFn({ prediction: res, example: ex });
|
|
11699
|
+
const success = score >= 0.5;
|
|
11700
|
+
if (success) {
|
|
11701
|
+
this.traces = [...this.traces, ...this.program.getTraces()];
|
|
11702
|
+
this.stats.successfulDemos++;
|
|
11703
|
+
}
|
|
11704
|
+
} catch (err) {
|
|
11705
|
+
error = err;
|
|
11706
|
+
res = {};
|
|
11707
|
+
}
|
|
11708
|
+
const current = i + examples.length * roundIndex + (batch.indexOf(ex) + 1);
|
|
11709
|
+
const total = examples.length * this.maxRounds;
|
|
11710
|
+
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
11711
|
+
if (this.verboseMode || this.debugMode) {
|
|
11712
|
+
const configInfo = {
|
|
11713
|
+
maxRounds: this.maxRounds,
|
|
11714
|
+
batchSize: this.batchSize,
|
|
11715
|
+
earlyStoppingPatience: this.earlyStoppingPatience,
|
|
11716
|
+
costMonitoring: this.costMonitoring,
|
|
11717
|
+
verboseMode: this.verboseMode,
|
|
11718
|
+
debugMode: this.debugMode
|
|
11719
|
+
};
|
|
11720
|
+
updateDetailedProgress(
|
|
11721
|
+
roundIndex,
|
|
11722
|
+
current,
|
|
11723
|
+
total,
|
|
11724
|
+
et,
|
|
11725
|
+
ex,
|
|
11726
|
+
this.stats,
|
|
11727
|
+
configInfo,
|
|
11728
|
+
res,
|
|
11729
|
+
error
|
|
11730
|
+
);
|
|
11731
|
+
} else {
|
|
11732
|
+
updateProgressBar(
|
|
11733
|
+
current,
|
|
11734
|
+
total,
|
|
11735
|
+
this.traces.length,
|
|
11736
|
+
et,
|
|
11737
|
+
"Tuning Prompt",
|
|
11738
|
+
30
|
|
11739
|
+
);
|
|
11740
|
+
}
|
|
11741
|
+
if (this.traces.length >= maxDemos) {
|
|
11742
|
+
return;
|
|
11743
|
+
}
|
|
11744
|
+
}
|
|
11745
|
+
}
|
|
11746
|
+
if (this.earlyStoppingPatience > 0) {
|
|
11747
|
+
const newSuccessCount = this.traces.length;
|
|
11748
|
+
const improvement = newSuccessCount - previousSuccessCount;
|
|
11749
|
+
if (!this.stats.earlyStopping) {
|
|
11750
|
+
this.stats.earlyStopping = {
|
|
11751
|
+
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
11752
|
+
patienceExhausted: false
|
|
11753
|
+
};
|
|
11754
|
+
} else if (improvement > 0) {
|
|
11755
|
+
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
11756
|
+
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
11757
|
+
this.stats.earlyStopping.patienceExhausted = true;
|
|
11758
|
+
this.stats.earlyStopped = true;
|
|
11759
|
+
if (this.verboseMode || this.debugMode) {
|
|
11760
|
+
console.log(
|
|
11761
|
+
`
|
|
11762
|
+
Early stopping triggered after ${roundIndex + 1} rounds. No improvement for ${this.earlyStoppingPatience} rounds.`
|
|
11763
|
+
);
|
|
11764
|
+
}
|
|
11765
|
+
return;
|
|
11766
|
+
}
|
|
11767
|
+
}
|
|
11768
|
+
}
|
|
11769
|
+
async compile(metricFn, options) {
|
|
11770
|
+
const compileOptions = options;
|
|
11771
|
+
const maxRounds = compileOptions?.maxRounds ?? this.maxRounds;
|
|
11772
|
+
this.traces = [];
|
|
11773
|
+
this.stats = {
|
|
11774
|
+
totalCalls: 0,
|
|
11775
|
+
successfulDemos: 0,
|
|
11776
|
+
estimatedTokenUsage: 0,
|
|
11777
|
+
earlyStopped: false
|
|
11778
|
+
};
|
|
11779
|
+
for (let i = 0; i < maxRounds; i++) {
|
|
11780
|
+
await this.compileRound(i, metricFn, compileOptions);
|
|
11781
|
+
if (this.stats.earlyStopped) {
|
|
11782
|
+
break;
|
|
11783
|
+
}
|
|
11784
|
+
}
|
|
11785
|
+
if (this.traces.length === 0) {
|
|
11786
|
+
throw new Error(
|
|
11787
|
+
"No demonstrations found. Either provider more examples or improve the existing ones."
|
|
11788
|
+
);
|
|
11789
|
+
}
|
|
11790
|
+
const demos = groupTracesByKeys(this.traces);
|
|
11791
|
+
return {
|
|
11792
|
+
demos,
|
|
11793
|
+
stats: this.stats
|
|
11794
|
+
};
|
|
11795
|
+
}
|
|
11796
|
+
// Get optimization statistics
|
|
11797
|
+
getStats() {
|
|
11798
|
+
return this.stats;
|
|
11799
|
+
}
|
|
11800
|
+
};
|
|
11801
|
+
function groupTracesByKeys(programTraces) {
|
|
11802
|
+
const groupedTraces = /* @__PURE__ */ new Map();
|
|
11803
|
+
for (const programTrace of programTraces) {
|
|
11804
|
+
if (groupedTraces.has(programTrace.programId)) {
|
|
11805
|
+
const traces = groupedTraces.get(programTrace.programId);
|
|
11806
|
+
if (traces) {
|
|
11807
|
+
traces.push(programTrace.trace);
|
|
11808
|
+
}
|
|
11809
|
+
} else {
|
|
11810
|
+
groupedTraces.set(programTrace.programId, [programTrace.trace]);
|
|
11811
|
+
}
|
|
11812
|
+
}
|
|
11813
|
+
const programDemosArray = [];
|
|
11814
|
+
for (const [programId, traces] of groupedTraces.entries()) {
|
|
11815
|
+
programDemosArray.push({ traces, programId });
|
|
11816
|
+
}
|
|
11817
|
+
return programDemosArray;
|
|
11818
|
+
}
|
|
11819
|
+
var randomSample = (array, n) => {
|
|
11820
|
+
const clonedArray = [...array];
|
|
11821
|
+
for (let i = clonedArray.length - 1; i > 0; i--) {
|
|
11822
|
+
const j = Math.floor(Math.random() * (i + 1));
|
|
11823
|
+
const caI = clonedArray[i];
|
|
11824
|
+
const caJ = clonedArray[j];
|
|
11825
|
+
if (!caI || !caJ) {
|
|
11826
|
+
throw new Error("Invalid array elements");
|
|
11827
|
+
}
|
|
11828
|
+
;
|
|
11829
|
+
[clonedArray[i], clonedArray[j]] = [caJ, caI];
|
|
11830
|
+
}
|
|
11831
|
+
return clonedArray.slice(0, n);
|
|
11832
|
+
};
|
|
11833
|
+
|
|
11834
|
+
// dsp/optimizers/miproV2.ts
|
|
11833
11835
|
var AxMiPRO = class {
|
|
11834
11836
|
ai;
|
|
11835
11837
|
program;
|
|
@@ -12055,7 +12057,7 @@ ${dataContext}
|
|
|
12055
12057
|
const result = await this.bootstrapper.compile(metricFn, {
|
|
12056
12058
|
maxDemos: this.maxBootstrappedDemos
|
|
12057
12059
|
});
|
|
12058
|
-
return result.demos;
|
|
12060
|
+
return result.demos || [];
|
|
12059
12061
|
}
|
|
12060
12062
|
/**
|
|
12061
12063
|
* Selects labeled examples directly from the training set
|
|
@@ -12389,21 +12391,22 @@ ${dataContext}
|
|
|
12389
12391
|
* The main compile method to run MIPROv2 optimization
|
|
12390
12392
|
* @param metricFn Evaluation metric function
|
|
12391
12393
|
* @param options Optional configuration options
|
|
12392
|
-
* @returns The
|
|
12394
|
+
* @returns The optimization result
|
|
12393
12395
|
*/
|
|
12394
12396
|
async compile(metricFn, options) {
|
|
12395
|
-
|
|
12396
|
-
|
|
12397
|
+
const miproOptions = options;
|
|
12398
|
+
if (miproOptions?.auto) {
|
|
12399
|
+
this.configureAuto(miproOptions.auto);
|
|
12397
12400
|
}
|
|
12398
12401
|
const trainset = this.examples;
|
|
12399
|
-
const valset =
|
|
12402
|
+
const valset = miproOptions?.valset || this.examples.slice(0, Math.floor(this.examples.length * 0.8));
|
|
12400
12403
|
if (this.verbose) {
|
|
12401
12404
|
console.log(`Starting MIPROv2 optimization with ${this.numTrials} trials`);
|
|
12402
12405
|
console.log(
|
|
12403
12406
|
`Using ${trainset.length} examples for training and ${valset.length} for validation`
|
|
12404
12407
|
);
|
|
12405
12408
|
}
|
|
12406
|
-
if (
|
|
12409
|
+
if (miproOptions?.teacher) {
|
|
12407
12410
|
if (this.verbose) {
|
|
12408
12411
|
console.log("Using provided teacher to assist with bootstrapping");
|
|
12409
12412
|
}
|
|
@@ -12460,7 +12463,17 @@ ${dataContext}
|
|
|
12460
12463
|
bootstrappedDemos,
|
|
12461
12464
|
labeledExamples
|
|
12462
12465
|
);
|
|
12463
|
-
return
|
|
12466
|
+
return {
|
|
12467
|
+
program: this.program,
|
|
12468
|
+
demos: bootstrappedDemos
|
|
12469
|
+
};
|
|
12470
|
+
}
|
|
12471
|
+
/**
|
|
12472
|
+
* Get optimization statistics from the internal bootstrapper
|
|
12473
|
+
* @returns Optimization statistics or undefined if not available
|
|
12474
|
+
*/
|
|
12475
|
+
getStats() {
|
|
12476
|
+
return this.bootstrapper.getStats();
|
|
12464
12477
|
}
|
|
12465
12478
|
};
|
|
12466
12479
|
|