@ax-llm/ax 12.0.12 → 12.0.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +586 -115
- package/index.cjs.map +1 -1
- package/index.d.cts +79 -17
- package/index.d.ts +79 -17
- package/index.js +586 -115
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.cjs
CHANGED
|
@@ -6257,15 +6257,24 @@ var MemoryImpl = class {
|
|
|
6257
6257
|
debugRequest(items, this.options?.debugHideSystemPrompt);
|
|
6258
6258
|
}
|
|
6259
6259
|
}
|
|
6260
|
+
addFunctionResults(results) {
|
|
6261
|
+
const chat = results.map(({ index, ...value }) => ({
|
|
6262
|
+
index,
|
|
6263
|
+
value: structuredClone(value)
|
|
6264
|
+
}));
|
|
6265
|
+
const lastItem = this.getLast();
|
|
6266
|
+
if (lastItem?.role === "function") {
|
|
6267
|
+
lastItem.chat.push(...chat);
|
|
6268
|
+
} else {
|
|
6269
|
+
this.data.push({ role: "function", chat });
|
|
6270
|
+
}
|
|
6271
|
+
}
|
|
6260
6272
|
addResponse(results) {
|
|
6261
|
-
const chat = results.map((
|
|
6262
|
-
index
|
|
6263
|
-
value: structuredClone(
|
|
6273
|
+
const chat = results.map(({ index, ...value }) => ({
|
|
6274
|
+
index,
|
|
6275
|
+
value: structuredClone(value)
|
|
6264
6276
|
}));
|
|
6265
|
-
this.data.push({
|
|
6266
|
-
role: "assistant",
|
|
6267
|
-
chat
|
|
6268
|
-
});
|
|
6277
|
+
this.data.push({ role: "assistant", chat });
|
|
6269
6278
|
if (this.options?.debug) {
|
|
6270
6279
|
for (const result of results) {
|
|
6271
6280
|
debugResponse(result);
|
|
@@ -6353,9 +6362,20 @@ var MemoryImpl = class {
|
|
|
6353
6362
|
history(index) {
|
|
6354
6363
|
const result = [];
|
|
6355
6364
|
for (const { role, chat } of this.data) {
|
|
6356
|
-
|
|
6357
|
-
if (
|
|
6358
|
-
|
|
6365
|
+
let values;
|
|
6366
|
+
if (role === "function") {
|
|
6367
|
+
values = chat.filter((v) => v.index === index).map((v) => v.value);
|
|
6368
|
+
} else {
|
|
6369
|
+
values = chat.find((v) => v.index === index)?.value;
|
|
6370
|
+
}
|
|
6371
|
+
if (Array.isArray(values)) {
|
|
6372
|
+
result.push(
|
|
6373
|
+
...values.map(
|
|
6374
|
+
(v) => ({ ...v, role })
|
|
6375
|
+
)
|
|
6376
|
+
);
|
|
6377
|
+
} else if (values) {
|
|
6378
|
+
result.push({ ...values, role });
|
|
6359
6379
|
}
|
|
6360
6380
|
}
|
|
6361
6381
|
return result;
|
|
@@ -6393,20 +6413,8 @@ var AxMemory = class {
|
|
|
6393
6413
|
axValidateChatResponseResult(results);
|
|
6394
6414
|
this.getMemory(sessionId).addResponse(results);
|
|
6395
6415
|
}
|
|
6396
|
-
|
|
6397
|
-
|
|
6398
|
-
isError,
|
|
6399
|
-
index,
|
|
6400
|
-
result
|
|
6401
|
-
}, sessionId) {
|
|
6402
|
-
const functionMessage = {
|
|
6403
|
-
role: "function",
|
|
6404
|
-
functionId,
|
|
6405
|
-
isError,
|
|
6406
|
-
result
|
|
6407
|
-
};
|
|
6408
|
-
axValidateChatRequestMessage(functionMessage);
|
|
6409
|
-
this.getMemory(sessionId).addRequest([functionMessage], index);
|
|
6416
|
+
addFunctionResults(results, sessionId) {
|
|
6417
|
+
this.getMemory(sessionId).addFunctionResults(results);
|
|
6410
6418
|
}
|
|
6411
6419
|
updateResult(result, sessionId) {
|
|
6412
6420
|
this.getMemory(sessionId).updateResult(result);
|
|
@@ -7643,51 +7651,48 @@ var processFunctions = async ({
|
|
|
7643
7651
|
}
|
|
7644
7652
|
return {
|
|
7645
7653
|
result: functionResult ?? "",
|
|
7654
|
+
role: "function",
|
|
7646
7655
|
functionId: func.id,
|
|
7647
7656
|
index
|
|
7648
7657
|
};
|
|
7649
7658
|
}).catch((e) => {
|
|
7650
|
-
if (e instanceof FunctionError) {
|
|
7651
|
-
|
|
7652
|
-
|
|
7653
|
-
|
|
7654
|
-
|
|
7655
|
-
|
|
7656
|
-
|
|
7657
|
-
|
|
7658
|
-
|
|
7659
|
-
|
|
7660
|
-
|
|
7661
|
-
|
|
7659
|
+
if (!(e instanceof FunctionError)) {
|
|
7660
|
+
throw e;
|
|
7661
|
+
}
|
|
7662
|
+
const result = e.getFixingInstructions();
|
|
7663
|
+
if (span) {
|
|
7664
|
+
const errorEventData = {
|
|
7665
|
+
name: func.name,
|
|
7666
|
+
message: e.toString()
|
|
7667
|
+
};
|
|
7668
|
+
if (!excludeContentFromTrace) {
|
|
7669
|
+
errorEventData.args = func.args;
|
|
7670
|
+
errorEventData.fixing_instructions = result;
|
|
7662
7671
|
}
|
|
7663
|
-
|
|
7664
|
-
|
|
7665
|
-
|
|
7666
|
-
|
|
7667
|
-
|
|
7668
|
-
result
|
|
7669
|
-
},
|
|
7670
|
-
sessionId
|
|
7671
|
-
);
|
|
7672
|
-
mem.addTag("error", sessionId);
|
|
7673
|
-
if (ai.getOptions().debug) {
|
|
7674
|
-
const logger = ai.getLogger();
|
|
7675
|
-
logger(`\u274C Function Error Correction:
|
|
7672
|
+
span.addEvent("function.error", errorEventData);
|
|
7673
|
+
}
|
|
7674
|
+
if (ai.getOptions().debug) {
|
|
7675
|
+
const logger = ai.getLogger();
|
|
7676
|
+
logger(`\u274C Function Error Correction:
|
|
7676
7677
|
${result}`, {
|
|
7677
|
-
|
|
7678
|
-
|
|
7679
|
-
}
|
|
7680
|
-
} else {
|
|
7681
|
-
throw e;
|
|
7678
|
+
tags: ["error"]
|
|
7679
|
+
});
|
|
7682
7680
|
}
|
|
7681
|
+
return {
|
|
7682
|
+
functionId: func.id,
|
|
7683
|
+
isError: true,
|
|
7684
|
+
index,
|
|
7685
|
+
result,
|
|
7686
|
+
role: "function"
|
|
7687
|
+
};
|
|
7683
7688
|
});
|
|
7684
7689
|
return promise;
|
|
7685
7690
|
});
|
|
7686
7691
|
const results = await Promise.all(promises);
|
|
7687
|
-
|
|
7688
|
-
|
|
7689
|
-
|
|
7690
|
-
|
|
7692
|
+
const functionResults = results.filter((result) => result !== void 0);
|
|
7693
|
+
mem.addFunctionResults(functionResults, sessionId);
|
|
7694
|
+
if (functionResults.some((result) => result.isError)) {
|
|
7695
|
+
mem.addTag("error", sessionId);
|
|
7691
7696
|
}
|
|
7692
7697
|
return functionsExecuted;
|
|
7693
7698
|
};
|
|
@@ -9961,6 +9966,96 @@ var AxProgram = class {
|
|
|
9961
9966
|
}
|
|
9962
9967
|
};
|
|
9963
9968
|
|
|
9969
|
+
// dsp/samples.ts
|
|
9970
|
+
function checkForFunctionCalls(mem, sessionId) {
|
|
9971
|
+
const history = mem.history(0, sessionId);
|
|
9972
|
+
const hasFunctionResults = history.some((msg) => msg.role === "function");
|
|
9973
|
+
const hasFunctionCalls = history.some(
|
|
9974
|
+
(msg) => msg.role === "assistant" && "functionCalls" in msg && Array.isArray(msg.functionCalls) && msg.functionCalls.length > 0
|
|
9975
|
+
);
|
|
9976
|
+
return hasFunctionCalls && hasFunctionResults;
|
|
9977
|
+
}
|
|
9978
|
+
function extractFunctionResults(mem, sessionId) {
|
|
9979
|
+
const history = mem.history(0, sessionId);
|
|
9980
|
+
const results = [];
|
|
9981
|
+
const assistantMessages = history.filter(
|
|
9982
|
+
(msg) => msg.role === "assistant" && "functionCalls" in msg && Array.isArray(msg.functionCalls) && msg.functionCalls.length > 0
|
|
9983
|
+
);
|
|
9984
|
+
const functionMessages = history.filter((msg) => msg.role === "function");
|
|
9985
|
+
for (const assistantMsg of assistantMessages) {
|
|
9986
|
+
if ("functionCalls" in assistantMsg && assistantMsg.functionCalls) {
|
|
9987
|
+
for (const funcCall of assistantMsg.functionCalls) {
|
|
9988
|
+
const funcResult = functionMessages.find(
|
|
9989
|
+
(msg) => "functionId" in msg && msg.functionId === funcCall.id
|
|
9990
|
+
);
|
|
9991
|
+
if (funcResult && "result" in funcResult && "functionId" in funcResult) {
|
|
9992
|
+
results.push({
|
|
9993
|
+
index: results.length,
|
|
9994
|
+
// Use sequential index for function results
|
|
9995
|
+
functionName: funcCall.function.name,
|
|
9996
|
+
functionId: funcCall.id,
|
|
9997
|
+
args: funcCall.function.params || "",
|
|
9998
|
+
result: String(funcResult.result),
|
|
9999
|
+
isError: "isError" in funcResult ? Boolean(funcResult.isError) : false
|
|
10000
|
+
});
|
|
10001
|
+
}
|
|
10002
|
+
}
|
|
10003
|
+
}
|
|
10004
|
+
}
|
|
10005
|
+
return results;
|
|
10006
|
+
}
|
|
10007
|
+
async function selectFromSamples(buffer, options, mem, sessionId) {
|
|
10008
|
+
if (!options?.resultPicker || buffer.length <= 1) {
|
|
10009
|
+
return 0;
|
|
10010
|
+
}
|
|
10011
|
+
const resultPicker = options.resultPicker;
|
|
10012
|
+
const hasFunctionCalls = mem ? checkForFunctionCalls(mem, sessionId) : false;
|
|
10013
|
+
if (hasFunctionCalls && mem) {
|
|
10014
|
+
const functionResults = extractFunctionResults(mem, sessionId);
|
|
10015
|
+
const selectedIndex = await resultPicker({
|
|
10016
|
+
type: "function",
|
|
10017
|
+
results: functionResults
|
|
10018
|
+
});
|
|
10019
|
+
if (selectedIndex < 0 || selectedIndex >= functionResults.length) {
|
|
10020
|
+
throw new Error(
|
|
10021
|
+
`Result picker returned invalid index: ${selectedIndex}. Must be between 0 and ${functionResults.length - 1}`
|
|
10022
|
+
);
|
|
10023
|
+
}
|
|
10024
|
+
return selectedIndex;
|
|
10025
|
+
} else {
|
|
10026
|
+
const fieldResults = buffer.map((b, index) => ({
|
|
10027
|
+
index,
|
|
10028
|
+
sample: b.delta
|
|
10029
|
+
}));
|
|
10030
|
+
const selectedIndex = await resultPicker({
|
|
10031
|
+
type: "fields",
|
|
10032
|
+
results: fieldResults
|
|
10033
|
+
});
|
|
10034
|
+
if (selectedIndex < 0 || selectedIndex >= buffer.length) {
|
|
10035
|
+
throw new Error(
|
|
10036
|
+
`Result picker returned invalid index: ${selectedIndex}. Must be between 0 and ${buffer.length - 1}`
|
|
10037
|
+
);
|
|
10038
|
+
}
|
|
10039
|
+
return selectedIndex;
|
|
10040
|
+
}
|
|
10041
|
+
}
|
|
10042
|
+
async function selectFromSamplesInMemory(mem, sessionId, options) {
|
|
10043
|
+
const lastMemory = mem?.getLast(sessionId);
|
|
10044
|
+
if (!lastMemory || lastMemory.role !== "assistant") {
|
|
10045
|
+
return 0;
|
|
10046
|
+
}
|
|
10047
|
+
if (lastMemory.chat.length <= 1) {
|
|
10048
|
+
return 0;
|
|
10049
|
+
}
|
|
10050
|
+
const buffer = lastMemory.chat.map((chat) => ({
|
|
10051
|
+
version: 0,
|
|
10052
|
+
index: chat.index,
|
|
10053
|
+
delta: chat.value
|
|
10054
|
+
}));
|
|
10055
|
+
const selectedIndex = await selectFromSamples(buffer, options, mem, sessionId);
|
|
10056
|
+
return selectedIndex;
|
|
10057
|
+
}
|
|
10058
|
+
|
|
9964
10059
|
// dsp/validate.ts
|
|
9965
10060
|
function handleValidationError(mem, errorFields, ai, promptTemplate, sessionId) {
|
|
9966
10061
|
mem.addRequest(
|
|
@@ -10076,7 +10171,10 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
10076
10171
|
thinkingTokenBudget,
|
|
10077
10172
|
showThoughts
|
|
10078
10173
|
} = options ?? {};
|
|
10079
|
-
const
|
|
10174
|
+
const selectedIndex = await selectFromSamplesInMemory(mem, sessionId, {
|
|
10175
|
+
resultPicker: options?.resultPicker
|
|
10176
|
+
});
|
|
10177
|
+
const chatPrompt = mem?.history(selectedIndex, sessionId) ?? [];
|
|
10080
10178
|
if (chatPrompt.length === 0) {
|
|
10081
10179
|
throw new Error("No chat prompt found");
|
|
10082
10180
|
}
|
|
@@ -10087,7 +10185,8 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
10087
10185
|
}
|
|
10088
10186
|
const modelConfig = {
|
|
10089
10187
|
...options?.modelConfig,
|
|
10090
|
-
...options?.sampleCount ? { n: options.sampleCount } : {}
|
|
10188
|
+
...options?.sampleCount ? { n: options.sampleCount } : {},
|
|
10189
|
+
...options?.sampleCount && options?.modelConfig?.temperature == 1 ? { temperature: 0.8 } : {}
|
|
10091
10190
|
};
|
|
10092
10191
|
const res = await ai.chat(
|
|
10093
10192
|
{
|
|
@@ -10364,15 +10463,58 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
10364
10463
|
currentVersion = delta.version;
|
|
10365
10464
|
buffer = mergeDeltas(buffer, delta);
|
|
10366
10465
|
}
|
|
10367
|
-
const
|
|
10466
|
+
const selectedIndex = await selectFromSamples(
|
|
10467
|
+
buffer,
|
|
10468
|
+
{
|
|
10469
|
+
resultPicker: options?.resultPicker
|
|
10470
|
+
},
|
|
10471
|
+
// Pass memory to enable function result selection
|
|
10472
|
+
options?.mem,
|
|
10473
|
+
options?.sessionId
|
|
10474
|
+
);
|
|
10475
|
+
const selectedResult = buffer[selectedIndex];
|
|
10476
|
+
const result = selectedResult?.delta ?? {};
|
|
10368
10477
|
this.trace = { ...values, ...result };
|
|
10369
10478
|
return result;
|
|
10370
10479
|
}
|
|
10371
10480
|
async *streamingForward(ai, values, options) {
|
|
10372
|
-
|
|
10481
|
+
if (!options?.resultPicker) {
|
|
10482
|
+
yield* this._forward1(ai, values, {
|
|
10483
|
+
...options,
|
|
10484
|
+
stream: true
|
|
10485
|
+
});
|
|
10486
|
+
return;
|
|
10487
|
+
}
|
|
10488
|
+
const generator = this._forward1(ai, values, {
|
|
10373
10489
|
...options,
|
|
10374
10490
|
stream: true
|
|
10375
10491
|
});
|
|
10492
|
+
let buffer = [];
|
|
10493
|
+
let currentVersion = 0;
|
|
10494
|
+
for await (const delta of generator) {
|
|
10495
|
+
if (delta.version !== currentVersion) {
|
|
10496
|
+
buffer = [];
|
|
10497
|
+
}
|
|
10498
|
+
currentVersion = delta.version;
|
|
10499
|
+
buffer = mergeDeltas(buffer, delta);
|
|
10500
|
+
}
|
|
10501
|
+
const selectedIndex = await selectFromSamples(
|
|
10502
|
+
buffer,
|
|
10503
|
+
{
|
|
10504
|
+
resultPicker: options?.resultPicker
|
|
10505
|
+
},
|
|
10506
|
+
// Pass memory to enable function result selection
|
|
10507
|
+
options?.mem,
|
|
10508
|
+
options?.sessionId
|
|
10509
|
+
);
|
|
10510
|
+
const selectedResult = buffer[selectedIndex];
|
|
10511
|
+
if (selectedResult) {
|
|
10512
|
+
yield {
|
|
10513
|
+
version: currentVersion,
|
|
10514
|
+
index: selectedIndex,
|
|
10515
|
+
delta: selectedResult.delta
|
|
10516
|
+
};
|
|
10517
|
+
}
|
|
10376
10518
|
}
|
|
10377
10519
|
setExamples(examples, options) {
|
|
10378
10520
|
super.setExamples(examples, options);
|
|
@@ -11732,13 +11874,6 @@ var AxBaseOptimizer = class {
|
|
|
11732
11874
|
if (this.logger) {
|
|
11733
11875
|
return this.logger;
|
|
11734
11876
|
}
|
|
11735
|
-
try {
|
|
11736
|
-
const aiLogger = this.studentAI.getLogger();
|
|
11737
|
-
if (aiLogger) {
|
|
11738
|
-
return aiLogger;
|
|
11739
|
-
}
|
|
11740
|
-
} catch {
|
|
11741
|
-
}
|
|
11742
11877
|
return axDefaultOptimizerLogger;
|
|
11743
11878
|
}
|
|
11744
11879
|
/**
|
|
@@ -13430,6 +13565,11 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13430
13565
|
bayesianOptimization;
|
|
13431
13566
|
acquisitionFunction;
|
|
13432
13567
|
explorationWeight;
|
|
13568
|
+
// Self-consistency / multiple sampling
|
|
13569
|
+
sampleCount;
|
|
13570
|
+
// Surrogate model state for Bayesian optimization
|
|
13571
|
+
miproConfigHistory = [];
|
|
13572
|
+
surrogateModel = /* @__PURE__ */ new Map();
|
|
13433
13573
|
constructor(args) {
|
|
13434
13574
|
super(args);
|
|
13435
13575
|
const options = args.options || {};
|
|
@@ -13451,6 +13591,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13451
13591
|
this.bayesianOptimization = options.bayesianOptimization ?? false;
|
|
13452
13592
|
this.acquisitionFunction = options.acquisitionFunction ?? "expected_improvement";
|
|
13453
13593
|
this.explorationWeight = options.explorationWeight ?? 0.1;
|
|
13594
|
+
this.sampleCount = options.sampleCount ?? 1;
|
|
13454
13595
|
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
13455
13596
|
}
|
|
13456
13597
|
/**
|
|
@@ -13495,43 +13636,186 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13495
13636
|
];
|
|
13496
13637
|
}
|
|
13497
13638
|
/**
|
|
13498
|
-
* Generates
|
|
13639
|
+
* Generates program summary for context-aware instruction generation
|
|
13640
|
+
*/
|
|
13641
|
+
async generateProgramSummary(program, ai) {
|
|
13642
|
+
let signature = "input -> output";
|
|
13643
|
+
if ("getSignature" in program && typeof program.getSignature === "function") {
|
|
13644
|
+
signature = program.getSignature();
|
|
13645
|
+
}
|
|
13646
|
+
const summaryPrompt = `
|
|
13647
|
+
Analyze this language model program and provide a concise summary of its purpose and structure.
|
|
13648
|
+
|
|
13649
|
+
Program Signature: ${signature}
|
|
13650
|
+
|
|
13651
|
+
Provide a 2-3 sentence summary focusing on:
|
|
13652
|
+
1. The main task or purpose of this program
|
|
13653
|
+
2. The input-output relationship
|
|
13654
|
+
3. Any special constraints or requirements
|
|
13655
|
+
|
|
13656
|
+
Summary:`;
|
|
13657
|
+
try {
|
|
13658
|
+
const response = await ai.chat({
|
|
13659
|
+
chatPrompt: [{ role: "user", content: summaryPrompt }]
|
|
13660
|
+
});
|
|
13661
|
+
if ("results" in response) {
|
|
13662
|
+
return response.results[0]?.content?.trim() || "General language model program";
|
|
13663
|
+
}
|
|
13664
|
+
return "General language model program";
|
|
13665
|
+
} catch {
|
|
13666
|
+
return "General language model program";
|
|
13667
|
+
}
|
|
13668
|
+
}
|
|
13669
|
+
/**
|
|
13670
|
+
* Generates dataset summary for context-aware instruction generation
|
|
13671
|
+
*/
|
|
13672
|
+
async generateDatasetSummary(examples, ai) {
|
|
13673
|
+
if (examples.length === 0) return "No examples available";
|
|
13674
|
+
const sampleSize = Math.min(this.viewDataBatchSize, examples.length);
|
|
13675
|
+
const sampledExamples = examples.slice(0, sampleSize);
|
|
13676
|
+
const exampleTexts = sampledExamples.map((ex, i) => `Example ${i + 1}: ${JSON.stringify(ex)}`).join("\n");
|
|
13677
|
+
const summaryPrompt = `
|
|
13678
|
+
Analyze this dataset and provide a concise summary of its characteristics.
|
|
13679
|
+
|
|
13680
|
+
Sample Examples:
|
|
13681
|
+
${exampleTexts}
|
|
13682
|
+
|
|
13683
|
+
Provide a 2-3 sentence summary focusing on:
|
|
13684
|
+
1. The type of data and domain
|
|
13685
|
+
2. Common patterns or structures in the examples
|
|
13686
|
+
3. Key challenges or requirements for processing this data
|
|
13687
|
+
|
|
13688
|
+
Dataset Summary:`;
|
|
13689
|
+
try {
|
|
13690
|
+
const response = await ai.chat({
|
|
13691
|
+
chatPrompt: [{ role: "user", content: summaryPrompt }]
|
|
13692
|
+
});
|
|
13693
|
+
if ("results" in response) {
|
|
13694
|
+
return response.results[0]?.content?.trim() || "General dataset";
|
|
13695
|
+
}
|
|
13696
|
+
return "General dataset";
|
|
13697
|
+
} catch {
|
|
13698
|
+
return "General dataset";
|
|
13699
|
+
}
|
|
13700
|
+
}
|
|
13701
|
+
/**
|
|
13702
|
+
* Enhanced instruction generation using AI with program and data awareness
|
|
13703
|
+
*/
|
|
13704
|
+
async generateInstruction({
|
|
13705
|
+
tip,
|
|
13706
|
+
candidateIndex,
|
|
13707
|
+
ai,
|
|
13708
|
+
programSummary,
|
|
13709
|
+
datasetSummary,
|
|
13710
|
+
previousInstructions = []
|
|
13711
|
+
}) {
|
|
13712
|
+
let contextInfo = "";
|
|
13713
|
+
if (this.programAwareProposer && programSummary) {
|
|
13714
|
+
contextInfo += `
|
|
13715
|
+
Program Context: ${programSummary}`;
|
|
13716
|
+
}
|
|
13717
|
+
if (this.dataAwareProposer && datasetSummary) {
|
|
13718
|
+
contextInfo += `
|
|
13719
|
+
Dataset Context: ${datasetSummary}`;
|
|
13720
|
+
}
|
|
13721
|
+
if (this.fewshotAwareProposer && previousInstructions.length > 0) {
|
|
13722
|
+
contextInfo += `
|
|
13723
|
+
Previous Instructions (avoid repeating): ${previousInstructions.slice(-3).join("; ")}`;
|
|
13724
|
+
}
|
|
13725
|
+
const instructionPrompt = `
|
|
13726
|
+
Generate a high-quality instruction for a language model program.
|
|
13727
|
+
|
|
13728
|
+
${contextInfo}
|
|
13729
|
+
|
|
13730
|
+
${tip ? `Tip: ${tip}` : ""}
|
|
13731
|
+
|
|
13732
|
+
Requirements:
|
|
13733
|
+
1. Be specific and actionable
|
|
13734
|
+
2. Focus on accuracy and clarity
|
|
13735
|
+
3. Consider the program's purpose and data characteristics
|
|
13736
|
+
4. Make the instruction distinct from previous ones
|
|
13737
|
+
5. Keep it concise but comprehensive
|
|
13738
|
+
|
|
13739
|
+
Generate a single, well-crafted instruction:
|
|
13740
|
+
Instruction:`;
|
|
13741
|
+
try {
|
|
13742
|
+
const response = await ai.chat({
|
|
13743
|
+
chatPrompt: [
|
|
13744
|
+
{
|
|
13745
|
+
role: "user",
|
|
13746
|
+
content: instructionPrompt
|
|
13747
|
+
}
|
|
13748
|
+
]
|
|
13749
|
+
});
|
|
13750
|
+
if ("results" in response) {
|
|
13751
|
+
const instruction2 = response.results[0]?.content?.trim();
|
|
13752
|
+
if (instruction2 && instruction2.length > 10) {
|
|
13753
|
+
return instruction2;
|
|
13754
|
+
}
|
|
13755
|
+
}
|
|
13756
|
+
} catch (error) {
|
|
13757
|
+
if (this.isLoggingEnabled()) {
|
|
13758
|
+
this.getLogger()?.(`Failed to generate AI instruction: ${error}`, {
|
|
13759
|
+
tags: ["optimizer", "warning"]
|
|
13760
|
+
});
|
|
13761
|
+
}
|
|
13762
|
+
}
|
|
13763
|
+
const enhancedTemplates = [
|
|
13764
|
+
"Analyze the input systematically and provide a precise, well-reasoned response.",
|
|
13765
|
+
"Think through this step-by-step, considering all relevant factors before responding.",
|
|
13766
|
+
"Examine the input carefully and generate an accurate, detailed answer.",
|
|
13767
|
+
"Process the information methodically and deliver a clear, comprehensive response.",
|
|
13768
|
+
"Consider the context thoroughly and provide a thoughtful, accurate answer."
|
|
13769
|
+
];
|
|
13770
|
+
let instruction = enhancedTemplates[candidateIndex % enhancedTemplates.length] || enhancedTemplates[0];
|
|
13771
|
+
if (tip) {
|
|
13772
|
+
instruction = `${instruction} ${tip}`;
|
|
13773
|
+
}
|
|
13774
|
+
return instruction;
|
|
13775
|
+
}
|
|
13776
|
+
/**
|
|
13777
|
+
* Generates instruction candidates using enhanced AI-powered generation
|
|
13499
13778
|
* @param options Optional compile options that may override teacher AI
|
|
13500
13779
|
* @returns Array of generated instruction candidates
|
|
13501
13780
|
*/
|
|
13502
|
-
async proposeInstructionCandidates(options) {
|
|
13781
|
+
async proposeInstructionCandidates(program, options) {
|
|
13503
13782
|
const instructions = [];
|
|
13504
13783
|
const aiToUse = this.getTeacherOrStudentAI(options);
|
|
13784
|
+
let programSummary;
|
|
13785
|
+
let datasetSummary;
|
|
13786
|
+
if (this.programAwareProposer) {
|
|
13787
|
+
programSummary = await this.generateProgramSummary(program, aiToUse);
|
|
13788
|
+
if (this.isLoggingEnabled(options)) {
|
|
13789
|
+
this.getLogger(options)?.(`Program summary: ${programSummary}`, {
|
|
13790
|
+
tags: ["optimizer", "config"]
|
|
13791
|
+
});
|
|
13792
|
+
}
|
|
13793
|
+
}
|
|
13794
|
+
if (this.dataAwareProposer) {
|
|
13795
|
+
datasetSummary = await this.generateDatasetSummary(this.examples, aiToUse);
|
|
13796
|
+
if (this.isLoggingEnabled(options)) {
|
|
13797
|
+
this.getLogger(options)?.(`Dataset summary: ${datasetSummary}`, {
|
|
13798
|
+
tags: ["optimizer", "config"]
|
|
13799
|
+
});
|
|
13800
|
+
}
|
|
13801
|
+
}
|
|
13505
13802
|
const tips = this.tipAwareProposer ? this.generateTips() : [];
|
|
13506
13803
|
for (let i = 0; i < this.numCandidates; i++) {
|
|
13507
13804
|
const tipIndex = tips.length > 0 ? i % tips.length : -1;
|
|
13508
|
-
const tipToUse = tipIndex >= 0 ? tips[tipIndex] :
|
|
13805
|
+
const tipToUse = tipIndex >= 0 ? tips[tipIndex] : void 0;
|
|
13509
13806
|
const instruction = await this.generateInstruction({
|
|
13510
13807
|
tip: tipToUse,
|
|
13511
13808
|
candidateIndex: i,
|
|
13512
|
-
ai: aiToUse
|
|
13809
|
+
ai: aiToUse,
|
|
13810
|
+
programSummary,
|
|
13811
|
+
datasetSummary,
|
|
13812
|
+
previousInstructions: instructions
|
|
13813
|
+
// Pass previous instructions for diversity
|
|
13513
13814
|
});
|
|
13514
13815
|
instructions.push(instruction);
|
|
13515
13816
|
}
|
|
13516
13817
|
return instructions;
|
|
13517
13818
|
}
|
|
13518
|
-
async generateInstruction({
|
|
13519
|
-
tip,
|
|
13520
|
-
candidateIndex
|
|
13521
|
-
}) {
|
|
13522
|
-
const baseInstructions = [
|
|
13523
|
-
"Analyze the input carefully and provide a detailed response.",
|
|
13524
|
-
"Think step by step and provide a clear answer.",
|
|
13525
|
-
"Consider all aspects of the input before responding.",
|
|
13526
|
-
"Provide a concise but comprehensive response.",
|
|
13527
|
-
"Focus on accuracy and clarity in your response."
|
|
13528
|
-
];
|
|
13529
|
-
let instruction = baseInstructions[candidateIndex % baseInstructions.length] || baseInstructions[0];
|
|
13530
|
-
if (tip) {
|
|
13531
|
-
instruction = `${instruction} ${tip}`;
|
|
13532
|
-
}
|
|
13533
|
-
return instruction;
|
|
13534
|
-
}
|
|
13535
13819
|
/**
|
|
13536
13820
|
* Bootstraps few-shot examples for the program
|
|
13537
13821
|
*/
|
|
@@ -13576,7 +13860,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13576
13860
|
/**
|
|
13577
13861
|
* Runs optimization to find the best combination of few-shot examples and instructions
|
|
13578
13862
|
*/
|
|
13579
|
-
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions,
|
|
13863
|
+
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, validationExamples, metricFn, options) {
|
|
13580
13864
|
let bestConfig = {
|
|
13581
13865
|
instruction: instructions[0] || "",
|
|
13582
13866
|
bootstrappedDemos: Math.min(1, bootstrappedDemos.length),
|
|
@@ -13612,25 +13896,37 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13612
13896
|
);
|
|
13613
13897
|
}
|
|
13614
13898
|
for (let i = startRound; i < this.numTrials; i++) {
|
|
13615
|
-
|
|
13616
|
-
|
|
13617
|
-
|
|
13618
|
-
|
|
13619
|
-
|
|
13620
|
-
|
|
13621
|
-
|
|
13622
|
-
|
|
13623
|
-
|
|
13624
|
-
|
|
13625
|
-
|
|
13899
|
+
let config;
|
|
13900
|
+
if (this.bayesianOptimization && this.miproConfigHistory.length > 2) {
|
|
13901
|
+
config = await this.selectConfigurationViaBayesianOptimization(
|
|
13902
|
+
instructions,
|
|
13903
|
+
bootstrappedDemos,
|
|
13904
|
+
labeledExamples
|
|
13905
|
+
);
|
|
13906
|
+
} else {
|
|
13907
|
+
config = {
|
|
13908
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
13909
|
+
bootstrappedDemos: Math.min(
|
|
13910
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
13911
|
+
this.maxBootstrappedDemos
|
|
13912
|
+
),
|
|
13913
|
+
labeledExamples: Math.min(
|
|
13914
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
13915
|
+
this.maxLabeledDemos
|
|
13916
|
+
)
|
|
13917
|
+
};
|
|
13918
|
+
}
|
|
13626
13919
|
const score = await this.evaluateConfig(
|
|
13627
13920
|
program,
|
|
13628
13921
|
config,
|
|
13629
13922
|
bootstrappedDemos,
|
|
13630
13923
|
labeledExamples,
|
|
13631
|
-
|
|
13632
|
-
metricFn
|
|
13924
|
+
validationExamples,
|
|
13925
|
+
metricFn,
|
|
13926
|
+
i + 1
|
|
13927
|
+
// Pass current trial number for adaptive evaluation
|
|
13633
13928
|
);
|
|
13929
|
+
this.updateSurrogateModel(config, score);
|
|
13634
13930
|
scoreHistory.push(score);
|
|
13635
13931
|
const improvement = score - bestScore;
|
|
13636
13932
|
if (improvement > this.minImprovementThreshold) {
|
|
@@ -13712,7 +14008,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13712
14008
|
this.stats.convergenceInfo.converged = stagnationRounds < this.earlyStoppingTrials;
|
|
13713
14009
|
return { bestConfig, bestScore };
|
|
13714
14010
|
}
|
|
13715
|
-
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples,
|
|
14011
|
+
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, validationExamples, metricFn, currentTrial = 0) {
|
|
13716
14012
|
const testProgram = { ...program };
|
|
13717
14013
|
this.applyConfigToProgram(
|
|
13718
14014
|
testProgram,
|
|
@@ -13722,12 +14018,31 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13722
14018
|
);
|
|
13723
14019
|
let totalScore = 0;
|
|
13724
14020
|
let count = 0;
|
|
13725
|
-
|
|
14021
|
+
let evalSize;
|
|
14022
|
+
if (this.minibatch) {
|
|
14023
|
+
const baseSize = Math.min(this.minibatchSize, validationExamples.length);
|
|
14024
|
+
const isFullEvalTrial = currentTrial % this.minibatchFullEvalSteps === 0;
|
|
14025
|
+
if (isFullEvalTrial || currentTrial > this.numTrials * 0.8) {
|
|
14026
|
+
evalSize = Math.min(validationExamples.length, baseSize * 2);
|
|
14027
|
+
} else {
|
|
14028
|
+
evalSize = Math.max(3, Math.min(baseSize, validationExamples.length));
|
|
14029
|
+
}
|
|
14030
|
+
} else {
|
|
14031
|
+
evalSize = validationExamples.length;
|
|
14032
|
+
}
|
|
14033
|
+
const evalIndices = this.shuffleArray([
|
|
14034
|
+
...Array(validationExamples.length).keys()
|
|
14035
|
+
]).slice(0, evalSize);
|
|
14036
|
+
const evalSet = evalIndices.map((i) => validationExamples[i]);
|
|
13726
14037
|
for (const example of evalSet) {
|
|
13727
14038
|
try {
|
|
13728
14039
|
const prediction = await testProgram.forward(
|
|
13729
14040
|
this.studentAI,
|
|
13730
|
-
example
|
|
14041
|
+
example,
|
|
14042
|
+
this.sampleCount > 1 ? {
|
|
14043
|
+
sampleCount: this.sampleCount,
|
|
14044
|
+
resultPicker: axMajorityVotePicker()
|
|
14045
|
+
} : void 0
|
|
13731
14046
|
);
|
|
13732
14047
|
const score = await metricFn({ prediction, example });
|
|
13733
14048
|
totalScore += score;
|
|
@@ -13739,6 +14054,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13739
14054
|
}
|
|
13740
14055
|
return count > 0 ? totalScore / count : 0;
|
|
13741
14056
|
}
|
|
14057
|
+
/**
|
|
14058
|
+
* Fisher-Yates shuffle for stochastic evaluation
|
|
14059
|
+
*/
|
|
14060
|
+
shuffleArray(array) {
|
|
14061
|
+
const shuffled = [...array];
|
|
14062
|
+
for (let i = shuffled.length - 1; i > 0; i--) {
|
|
14063
|
+
const j = Math.floor(Math.random() * (i + 1));
|
|
14064
|
+
[shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
|
|
14065
|
+
}
|
|
14066
|
+
return shuffled;
|
|
14067
|
+
}
|
|
13742
14068
|
applyConfigToProgram(program, config, bootstrappedDemos, labeledExamples) {
|
|
13743
14069
|
if (program.setInstruction) {
|
|
13744
14070
|
program.setInstruction(config.instruction);
|
|
@@ -13760,14 +14086,14 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13760
14086
|
if (miproOptions?.auto) {
|
|
13761
14087
|
this.configureAuto(miproOptions.auto);
|
|
13762
14088
|
}
|
|
13763
|
-
const
|
|
14089
|
+
const validationExamples = this.getValidationSet(options) || (miproOptions?.validationExamples ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
|
|
13764
14090
|
if (this.isLoggingEnabled(options)) {
|
|
13765
14091
|
this.getLogger(options)?.(
|
|
13766
14092
|
`Starting MIPROv2 optimization with ${this.numTrials} trials`,
|
|
13767
14093
|
{ tags: ["optimizer", "start"] }
|
|
13768
14094
|
);
|
|
13769
14095
|
this.getLogger(options)?.(
|
|
13770
|
-
`Using ${this.examples.length} examples for training and ${
|
|
14096
|
+
`Using ${this.examples.length} examples for training and ${validationExamples.length} for validation`,
|
|
13771
14097
|
{ tags: ["optimizer", "config"] }
|
|
13772
14098
|
);
|
|
13773
14099
|
if (this.teacherAI) {
|
|
@@ -13797,7 +14123,10 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13797
14123
|
);
|
|
13798
14124
|
}
|
|
13799
14125
|
}
|
|
13800
|
-
const instructions = await this.proposeInstructionCandidates(
|
|
14126
|
+
const instructions = await this.proposeInstructionCandidates(
|
|
14127
|
+
program,
|
|
14128
|
+
options
|
|
14129
|
+
);
|
|
13801
14130
|
if (this.isLoggingEnabled(options)) {
|
|
13802
14131
|
this.getLogger(options)?.(
|
|
13803
14132
|
`Generated ${instructions.length} instruction candidates`,
|
|
@@ -13815,7 +14144,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13815
14144
|
bootstrappedDemos,
|
|
13816
14145
|
labeledExamples,
|
|
13817
14146
|
instructions,
|
|
13818
|
-
|
|
14147
|
+
validationExamples,
|
|
13819
14148
|
metricFn,
|
|
13820
14149
|
options
|
|
13821
14150
|
);
|
|
@@ -13874,7 +14203,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13874
14203
|
bootstrappedDemos: bestConfig.bootstrappedDemos,
|
|
13875
14204
|
labeledExamples: bestConfig.labeledExamples,
|
|
13876
14205
|
numCandidates: this.numCandidates,
|
|
13877
|
-
numTrials: this.numTrials
|
|
14206
|
+
numTrials: this.numTrials,
|
|
14207
|
+
sampleCount: this.sampleCount
|
|
13878
14208
|
}
|
|
13879
14209
|
};
|
|
13880
14210
|
}
|
|
@@ -13919,7 +14249,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13919
14249
|
minImprovementThreshold: this.minImprovementThreshold,
|
|
13920
14250
|
bayesianOptimization: this.bayesianOptimization,
|
|
13921
14251
|
acquisitionFunction: this.acquisitionFunction,
|
|
13922
|
-
explorationWeight: this.explorationWeight
|
|
14252
|
+
explorationWeight: this.explorationWeight,
|
|
14253
|
+
sampleCount: this.sampleCount
|
|
13923
14254
|
};
|
|
13924
14255
|
}
|
|
13925
14256
|
/**
|
|
@@ -13954,12 +14285,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13954
14285
|
if (config.minImprovementThreshold !== void 0) {
|
|
13955
14286
|
this.minImprovementThreshold = config.minImprovementThreshold;
|
|
13956
14287
|
}
|
|
14288
|
+
if (config.sampleCount !== void 0) {
|
|
14289
|
+
this.sampleCount = config.sampleCount;
|
|
14290
|
+
}
|
|
13957
14291
|
}
|
|
13958
14292
|
/**
|
|
13959
14293
|
* Reset optimizer state for reuse with different programs
|
|
13960
14294
|
*/
|
|
13961
14295
|
reset() {
|
|
13962
14296
|
super.reset();
|
|
14297
|
+
this.miproConfigHistory = [];
|
|
14298
|
+
this.surrogateModel.clear();
|
|
13963
14299
|
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
13964
14300
|
}
|
|
13965
14301
|
/**
|
|
@@ -13977,8 +14313,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13977
14313
|
"Reduce maxBootstrappedDemos or maxLabeledDemos, or provide more examples"
|
|
13978
14314
|
);
|
|
13979
14315
|
}
|
|
13980
|
-
const
|
|
13981
|
-
if (
|
|
14316
|
+
const validationSetSize = this.getValidationSet().length;
|
|
14317
|
+
if (validationSetSize < 5) {
|
|
13982
14318
|
result.issues.push(
|
|
13983
14319
|
"Validation set too small for reliable MiPRO optimization"
|
|
13984
14320
|
);
|
|
@@ -13992,6 +14328,141 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13992
14328
|
suggestions: result.suggestions
|
|
13993
14329
|
};
|
|
13994
14330
|
}
|
|
14331
|
+
/**
|
|
14332
|
+
* Encodes a configuration into a string key for surrogate model lookup
|
|
14333
|
+
*/
|
|
14334
|
+
encodeConfiguration(config) {
|
|
14335
|
+
return `${config.instruction.length}_${config.bootstrappedDemos}_${config.labeledExamples}`;
|
|
14336
|
+
}
|
|
14337
|
+
/**
|
|
14338
|
+
* Updates the surrogate model with a new configuration-score pair
|
|
14339
|
+
*/
|
|
14340
|
+
updateSurrogateModel(config, score) {
|
|
14341
|
+
this.miproConfigHistory.push({ config: { ...config }, score });
|
|
14342
|
+
const key = this.encodeConfiguration(config);
|
|
14343
|
+
const similarConfigs = this.miproConfigHistory.filter(
|
|
14344
|
+
(entry) => this.encodeConfiguration(entry.config) === key
|
|
14345
|
+
);
|
|
14346
|
+
if (similarConfigs.length > 0) {
|
|
14347
|
+
const scores = similarConfigs.map((entry) => entry.score);
|
|
14348
|
+
const mean = scores.reduce((sum, s2) => sum + s2, 0) / scores.length;
|
|
14349
|
+
const variance = scores.length > 1 ? scores.reduce((sum, s2) => sum + Math.pow(s2 - mean, 2), 0) / (scores.length - 1) : 0.1;
|
|
14350
|
+
this.surrogateModel.set(key, { mean, variance });
|
|
14351
|
+
}
|
|
14352
|
+
}
|
|
14353
|
+
/**
|
|
14354
|
+
* Predicts performance using the surrogate model
|
|
14355
|
+
*/
|
|
14356
|
+
predictPerformance(config) {
|
|
14357
|
+
const key = this.encodeConfiguration(config);
|
|
14358
|
+
if (this.surrogateModel.has(key)) {
|
|
14359
|
+
return this.surrogateModel.get(key);
|
|
14360
|
+
}
|
|
14361
|
+
if (this.miproConfigHistory.length > 0) {
|
|
14362
|
+
const similarities = this.miproConfigHistory.map((entry) => {
|
|
14363
|
+
const diff = Math.abs(entry.config.bootstrappedDemos - config.bootstrappedDemos) + Math.abs(entry.config.labeledExamples - config.labeledExamples);
|
|
14364
|
+
return { score: entry.score, similarity: 1 / (1 + diff) };
|
|
14365
|
+
});
|
|
14366
|
+
const totalWeight = similarities.reduce((sum, s2) => sum + s2.similarity, 0);
|
|
14367
|
+
const weightedMean = similarities.reduce((sum, s2) => sum + s2.score * s2.similarity, 0) / totalWeight;
|
|
14368
|
+
return { mean: weightedMean, variance: 0.2 };
|
|
14369
|
+
}
|
|
14370
|
+
return { mean: 0.5, variance: 0.3 };
|
|
14371
|
+
}
|
|
14372
|
+
/**
|
|
14373
|
+
* Calculates acquisition function value for Bayesian optimization
|
|
14374
|
+
*/
|
|
14375
|
+
calculateAcquisitionValue(config) {
|
|
14376
|
+
const prediction = this.predictPerformance(config);
|
|
14377
|
+
const { mean, variance } = prediction;
|
|
14378
|
+
const std = Math.sqrt(variance);
|
|
14379
|
+
const bestScore = this.miproConfigHistory.length > 0 ? Math.max(...this.miproConfigHistory.map((entry) => entry.score)) : 0;
|
|
14380
|
+
switch (this.acquisitionFunction) {
|
|
14381
|
+
case "expected_improvement": {
|
|
14382
|
+
const improvement = mean - bestScore;
|
|
14383
|
+
if (std === 0) return Math.max(0, improvement);
|
|
14384
|
+
const z = improvement / std;
|
|
14385
|
+
const phi = 0.5 * (1 + this.erf(z / Math.sqrt(2)));
|
|
14386
|
+
const pdfValue = Math.exp(-0.5 * z * z) / Math.sqrt(2 * Math.PI);
|
|
14387
|
+
return improvement * phi + std * pdfValue;
|
|
14388
|
+
}
|
|
14389
|
+
case "upper_confidence_bound": {
|
|
14390
|
+
return mean + this.explorationWeight * std;
|
|
14391
|
+
}
|
|
14392
|
+
case "probability_improvement": {
|
|
14393
|
+
const improvement = mean - bestScore;
|
|
14394
|
+
if (std === 0) return improvement > 0 ? 1 : 0;
|
|
14395
|
+
const z = improvement / std;
|
|
14396
|
+
return 0.5 * (1 + this.erf(z / Math.sqrt(2)));
|
|
14397
|
+
}
|
|
14398
|
+
default:
|
|
14399
|
+
return mean;
|
|
14400
|
+
}
|
|
14401
|
+
}
|
|
14402
|
+
/**
|
|
14403
|
+
* Error function approximation for acquisition function calculations
|
|
14404
|
+
*/
|
|
14405
|
+
erf(x) {
|
|
14406
|
+
const a1 = 0.254829592;
|
|
14407
|
+
const a2 = -0.284496736;
|
|
14408
|
+
const a3 = 1.421413741;
|
|
14409
|
+
const a4 = -1.453152027;
|
|
14410
|
+
const a5 = 1.061405429;
|
|
14411
|
+
const p = 0.3275911;
|
|
14412
|
+
const sign = x >= 0 ? 1 : -1;
|
|
14413
|
+
x = Math.abs(x);
|
|
14414
|
+
const t = 1 / (1 + p * x);
|
|
14415
|
+
const y = 1 - ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t * Math.exp(-x * x);
|
|
14416
|
+
return sign * y;
|
|
14417
|
+
}
|
|
14418
|
+
/**
|
|
14419
|
+
* Selects the next configuration to evaluate using Bayesian optimization
|
|
14420
|
+
*/
|
|
14421
|
+
async selectConfigurationViaBayesianOptimization(instructions, bootstrappedDemos, labeledExamples) {
|
|
14422
|
+
const candidates = [];
|
|
14423
|
+
const numCandidates = Math.min(20, instructions.length * 3);
|
|
14424
|
+
for (let i = 0; i < numCandidates; i++) {
|
|
14425
|
+
const config = {
|
|
14426
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
14427
|
+
bootstrappedDemos: Math.min(
|
|
14428
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
14429
|
+
this.maxBootstrappedDemos
|
|
14430
|
+
),
|
|
14431
|
+
labeledExamples: Math.min(
|
|
14432
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
14433
|
+
this.maxLabeledDemos
|
|
14434
|
+
)
|
|
14435
|
+
};
|
|
14436
|
+
const acquisitionValue = this.calculateAcquisitionValue(config);
|
|
14437
|
+
candidates.push({ config, acquisitionValue });
|
|
14438
|
+
}
|
|
14439
|
+
candidates.sort((a, b) => b.acquisitionValue - a.acquisitionValue);
|
|
14440
|
+
return candidates[0].config;
|
|
14441
|
+
}
|
|
14442
|
+
};
|
|
14443
|
+
/**
 * Creates a result picker that selects the most frequent sample.
 *
 * For `data.type === "fields"`, samples are grouped by a canonical
 * serialization in which object keys are sorted recursively. Samples that
 * differ only in property order therefore count as the same vote. The picker
 * returns the `index` of the first result in the largest group; ties go to the
 * group seen first. For any other `data.type`, or when there is nothing to
 * vote on, it returns the first result's index (or 0).
 *
 * @returns {(data: {type: string, results: Array<{index: number, sample?: unknown}>}) => Promise<number>}
 */
var axMajorityVotePicker = () => {
  // JSON.stringify with sorted object keys, so `{a, b}` and `{b, a}` match.
  // Array element order is preserved.
  const canonicalKey = (value) =>
    JSON.stringify(value, (_key, val) =>
      val !== null && typeof val === "object" && !Array.isArray(val)
        ? Object.fromEntries(
            Object.keys(val)
              .sort()
              .map((k) => [k, val[k]])
          )
        : val
    );

  return async (data) => {
    if (data.type !== "fields") {
      return data.results[0]?.index ?? 0;
    }

    // A Map iterates in insertion order for every key type. A plain object
    // would enumerate integer-like keys first, breaking first-seen tie-breaks.
    const tallies = new Map();
    for (const { index, sample } of data.results) {
      const key = canonicalKey(sample);
      const tally = tallies.get(key);
      if (tally) {
        tally.count += 1;
      } else {
        tallies.set(key, { count: 1, index });
      }
    }

    let winner;
    for (const tally of tallies.values()) {
      if (winner === undefined || tally.count > winner.count) {
        winner = tally;
      }
    }
    return winner?.index ?? 0;
  };
};
|
|
13996
14467
|
|
|
13997
14468
|
// ai/mock/api.ts
|