@ax-llm/ax 12.0.12 → 12.0.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +586 -115
- package/index.cjs.map +1 -1
- package/index.d.cts +79 -17
- package/index.d.ts +79 -17
- package/index.js +586 -115
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.js
CHANGED
|
@@ -6083,15 +6083,24 @@ var MemoryImpl = class {
|
|
|
6083
6083
|
debugRequest(items, this.options?.debugHideSystemPrompt);
|
|
6084
6084
|
}
|
|
6085
6085
|
}
|
|
6086
|
+
addFunctionResults(results) {
|
|
6087
|
+
const chat = results.map(({ index, ...value }) => ({
|
|
6088
|
+
index,
|
|
6089
|
+
value: structuredClone(value)
|
|
6090
|
+
}));
|
|
6091
|
+
const lastItem = this.getLast();
|
|
6092
|
+
if (lastItem?.role === "function") {
|
|
6093
|
+
lastItem.chat.push(...chat);
|
|
6094
|
+
} else {
|
|
6095
|
+
this.data.push({ role: "function", chat });
|
|
6096
|
+
}
|
|
6097
|
+
}
|
|
6086
6098
|
addResponse(results) {
|
|
6087
|
-
const chat = results.map((
|
|
6088
|
-
index
|
|
6089
|
-
value: structuredClone(
|
|
6099
|
+
const chat = results.map(({ index, ...value }) => ({
|
|
6100
|
+
index,
|
|
6101
|
+
value: structuredClone(value)
|
|
6090
6102
|
}));
|
|
6091
|
-
this.data.push({
|
|
6092
|
-
role: "assistant",
|
|
6093
|
-
chat
|
|
6094
|
-
});
|
|
6103
|
+
this.data.push({ role: "assistant", chat });
|
|
6095
6104
|
if (this.options?.debug) {
|
|
6096
6105
|
for (const result of results) {
|
|
6097
6106
|
debugResponse(result);
|
|
@@ -6179,9 +6188,20 @@ var MemoryImpl = class {
|
|
|
6179
6188
|
history(index) {
|
|
6180
6189
|
const result = [];
|
|
6181
6190
|
for (const { role, chat } of this.data) {
|
|
6182
|
-
|
|
6183
|
-
if (
|
|
6184
|
-
|
|
6191
|
+
let values;
|
|
6192
|
+
if (role === "function") {
|
|
6193
|
+
values = chat.filter((v) => v.index === index).map((v) => v.value);
|
|
6194
|
+
} else {
|
|
6195
|
+
values = chat.find((v) => v.index === index)?.value;
|
|
6196
|
+
}
|
|
6197
|
+
if (Array.isArray(values)) {
|
|
6198
|
+
result.push(
|
|
6199
|
+
...values.map(
|
|
6200
|
+
(v) => ({ ...v, role })
|
|
6201
|
+
)
|
|
6202
|
+
);
|
|
6203
|
+
} else if (values) {
|
|
6204
|
+
result.push({ ...values, role });
|
|
6185
6205
|
}
|
|
6186
6206
|
}
|
|
6187
6207
|
return result;
|
|
@@ -6219,20 +6239,8 @@ var AxMemory = class {
|
|
|
6219
6239
|
axValidateChatResponseResult(results);
|
|
6220
6240
|
this.getMemory(sessionId).addResponse(results);
|
|
6221
6241
|
}
|
|
6222
|
-
|
|
6223
|
-
|
|
6224
|
-
isError,
|
|
6225
|
-
index,
|
|
6226
|
-
result
|
|
6227
|
-
}, sessionId) {
|
|
6228
|
-
const functionMessage = {
|
|
6229
|
-
role: "function",
|
|
6230
|
-
functionId,
|
|
6231
|
-
isError,
|
|
6232
|
-
result
|
|
6233
|
-
};
|
|
6234
|
-
axValidateChatRequestMessage(functionMessage);
|
|
6235
|
-
this.getMemory(sessionId).addRequest([functionMessage], index);
|
|
6242
|
+
addFunctionResults(results, sessionId) {
|
|
6243
|
+
this.getMemory(sessionId).addFunctionResults(results);
|
|
6236
6244
|
}
|
|
6237
6245
|
updateResult(result, sessionId) {
|
|
6238
6246
|
this.getMemory(sessionId).updateResult(result);
|
|
@@ -7469,51 +7477,48 @@ var processFunctions = async ({
|
|
|
7469
7477
|
}
|
|
7470
7478
|
return {
|
|
7471
7479
|
result: functionResult ?? "",
|
|
7480
|
+
role: "function",
|
|
7472
7481
|
functionId: func.id,
|
|
7473
7482
|
index
|
|
7474
7483
|
};
|
|
7475
7484
|
}).catch((e) => {
|
|
7476
|
-
if (e instanceof FunctionError) {
|
|
7477
|
-
|
|
7478
|
-
|
|
7479
|
-
|
|
7480
|
-
|
|
7481
|
-
|
|
7482
|
-
|
|
7483
|
-
|
|
7484
|
-
|
|
7485
|
-
|
|
7486
|
-
|
|
7487
|
-
|
|
7485
|
+
if (!(e instanceof FunctionError)) {
|
|
7486
|
+
throw e;
|
|
7487
|
+
}
|
|
7488
|
+
const result = e.getFixingInstructions();
|
|
7489
|
+
if (span) {
|
|
7490
|
+
const errorEventData = {
|
|
7491
|
+
name: func.name,
|
|
7492
|
+
message: e.toString()
|
|
7493
|
+
};
|
|
7494
|
+
if (!excludeContentFromTrace) {
|
|
7495
|
+
errorEventData.args = func.args;
|
|
7496
|
+
errorEventData.fixing_instructions = result;
|
|
7488
7497
|
}
|
|
7489
|
-
|
|
7490
|
-
|
|
7491
|
-
|
|
7492
|
-
|
|
7493
|
-
|
|
7494
|
-
result
|
|
7495
|
-
},
|
|
7496
|
-
sessionId
|
|
7497
|
-
);
|
|
7498
|
-
mem.addTag("error", sessionId);
|
|
7499
|
-
if (ai.getOptions().debug) {
|
|
7500
|
-
const logger = ai.getLogger();
|
|
7501
|
-
logger(`\u274C Function Error Correction:
|
|
7498
|
+
span.addEvent("function.error", errorEventData);
|
|
7499
|
+
}
|
|
7500
|
+
if (ai.getOptions().debug) {
|
|
7501
|
+
const logger = ai.getLogger();
|
|
7502
|
+
logger(`\u274C Function Error Correction:
|
|
7502
7503
|
${result}`, {
|
|
7503
|
-
|
|
7504
|
-
|
|
7505
|
-
}
|
|
7506
|
-
} else {
|
|
7507
|
-
throw e;
|
|
7504
|
+
tags: ["error"]
|
|
7505
|
+
});
|
|
7508
7506
|
}
|
|
7507
|
+
return {
|
|
7508
|
+
functionId: func.id,
|
|
7509
|
+
isError: true,
|
|
7510
|
+
index,
|
|
7511
|
+
result,
|
|
7512
|
+
role: "function"
|
|
7513
|
+
};
|
|
7509
7514
|
});
|
|
7510
7515
|
return promise;
|
|
7511
7516
|
});
|
|
7512
7517
|
const results = await Promise.all(promises);
|
|
7513
|
-
|
|
7514
|
-
|
|
7515
|
-
|
|
7516
|
-
|
|
7518
|
+
const functionResults = results.filter((result) => result !== void 0);
|
|
7519
|
+
mem.addFunctionResults(functionResults, sessionId);
|
|
7520
|
+
if (functionResults.some((result) => result.isError)) {
|
|
7521
|
+
mem.addTag("error", sessionId);
|
|
7517
7522
|
}
|
|
7518
7523
|
return functionsExecuted;
|
|
7519
7524
|
};
|
|
@@ -9787,6 +9792,96 @@ var AxProgram = class {
|
|
|
9787
9792
|
}
|
|
9788
9793
|
};
|
|
9789
9794
|
|
|
9795
|
+
// dsp/samples.ts
|
|
9796
|
+
function checkForFunctionCalls(mem, sessionId) {
|
|
9797
|
+
const history = mem.history(0, sessionId);
|
|
9798
|
+
const hasFunctionResults = history.some((msg) => msg.role === "function");
|
|
9799
|
+
const hasFunctionCalls = history.some(
|
|
9800
|
+
(msg) => msg.role === "assistant" && "functionCalls" in msg && Array.isArray(msg.functionCalls) && msg.functionCalls.length > 0
|
|
9801
|
+
);
|
|
9802
|
+
return hasFunctionCalls && hasFunctionResults;
|
|
9803
|
+
}
|
|
9804
|
+
function extractFunctionResults(mem, sessionId) {
|
|
9805
|
+
const history = mem.history(0, sessionId);
|
|
9806
|
+
const results = [];
|
|
9807
|
+
const assistantMessages = history.filter(
|
|
9808
|
+
(msg) => msg.role === "assistant" && "functionCalls" in msg && Array.isArray(msg.functionCalls) && msg.functionCalls.length > 0
|
|
9809
|
+
);
|
|
9810
|
+
const functionMessages = history.filter((msg) => msg.role === "function");
|
|
9811
|
+
for (const assistantMsg of assistantMessages) {
|
|
9812
|
+
if ("functionCalls" in assistantMsg && assistantMsg.functionCalls) {
|
|
9813
|
+
for (const funcCall of assistantMsg.functionCalls) {
|
|
9814
|
+
const funcResult = functionMessages.find(
|
|
9815
|
+
(msg) => "functionId" in msg && msg.functionId === funcCall.id
|
|
9816
|
+
);
|
|
9817
|
+
if (funcResult && "result" in funcResult && "functionId" in funcResult) {
|
|
9818
|
+
results.push({
|
|
9819
|
+
index: results.length,
|
|
9820
|
+
// Use sequential index for function results
|
|
9821
|
+
functionName: funcCall.function.name,
|
|
9822
|
+
functionId: funcCall.id,
|
|
9823
|
+
args: funcCall.function.params || "",
|
|
9824
|
+
result: String(funcResult.result),
|
|
9825
|
+
isError: "isError" in funcResult ? Boolean(funcResult.isError) : false
|
|
9826
|
+
});
|
|
9827
|
+
}
|
|
9828
|
+
}
|
|
9829
|
+
}
|
|
9830
|
+
}
|
|
9831
|
+
return results;
|
|
9832
|
+
}
|
|
9833
|
+
async function selectFromSamples(buffer, options, mem, sessionId) {
|
|
9834
|
+
if (!options?.resultPicker || buffer.length <= 1) {
|
|
9835
|
+
return 0;
|
|
9836
|
+
}
|
|
9837
|
+
const resultPicker = options.resultPicker;
|
|
9838
|
+
const hasFunctionCalls = mem ? checkForFunctionCalls(mem, sessionId) : false;
|
|
9839
|
+
if (hasFunctionCalls && mem) {
|
|
9840
|
+
const functionResults = extractFunctionResults(mem, sessionId);
|
|
9841
|
+
const selectedIndex = await resultPicker({
|
|
9842
|
+
type: "function",
|
|
9843
|
+
results: functionResults
|
|
9844
|
+
});
|
|
9845
|
+
if (selectedIndex < 0 || selectedIndex >= functionResults.length) {
|
|
9846
|
+
throw new Error(
|
|
9847
|
+
`Result picker returned invalid index: ${selectedIndex}. Must be between 0 and ${functionResults.length - 1}`
|
|
9848
|
+
);
|
|
9849
|
+
}
|
|
9850
|
+
return selectedIndex;
|
|
9851
|
+
} else {
|
|
9852
|
+
const fieldResults = buffer.map((b, index) => ({
|
|
9853
|
+
index,
|
|
9854
|
+
sample: b.delta
|
|
9855
|
+
}));
|
|
9856
|
+
const selectedIndex = await resultPicker({
|
|
9857
|
+
type: "fields",
|
|
9858
|
+
results: fieldResults
|
|
9859
|
+
});
|
|
9860
|
+
if (selectedIndex < 0 || selectedIndex >= buffer.length) {
|
|
9861
|
+
throw new Error(
|
|
9862
|
+
`Result picker returned invalid index: ${selectedIndex}. Must be between 0 and ${buffer.length - 1}`
|
|
9863
|
+
);
|
|
9864
|
+
}
|
|
9865
|
+
return selectedIndex;
|
|
9866
|
+
}
|
|
9867
|
+
}
|
|
9868
|
+
async function selectFromSamplesInMemory(mem, sessionId, options) {
|
|
9869
|
+
const lastMemory = mem?.getLast(sessionId);
|
|
9870
|
+
if (!lastMemory || lastMemory.role !== "assistant") {
|
|
9871
|
+
return 0;
|
|
9872
|
+
}
|
|
9873
|
+
if (lastMemory.chat.length <= 1) {
|
|
9874
|
+
return 0;
|
|
9875
|
+
}
|
|
9876
|
+
const buffer = lastMemory.chat.map((chat) => ({
|
|
9877
|
+
version: 0,
|
|
9878
|
+
index: chat.index,
|
|
9879
|
+
delta: chat.value
|
|
9880
|
+
}));
|
|
9881
|
+
const selectedIndex = await selectFromSamples(buffer, options, mem, sessionId);
|
|
9882
|
+
return selectedIndex;
|
|
9883
|
+
}
|
|
9884
|
+
|
|
9790
9885
|
// dsp/validate.ts
|
|
9791
9886
|
function handleValidationError(mem, errorFields, ai, promptTemplate, sessionId) {
|
|
9792
9887
|
mem.addRequest(
|
|
@@ -9902,7 +9997,10 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
9902
9997
|
thinkingTokenBudget,
|
|
9903
9998
|
showThoughts
|
|
9904
9999
|
} = options ?? {};
|
|
9905
|
-
const
|
|
10000
|
+
const selectedIndex = await selectFromSamplesInMemory(mem, sessionId, {
|
|
10001
|
+
resultPicker: options?.resultPicker
|
|
10002
|
+
});
|
|
10003
|
+
const chatPrompt = mem?.history(selectedIndex, sessionId) ?? [];
|
|
9906
10004
|
if (chatPrompt.length === 0) {
|
|
9907
10005
|
throw new Error("No chat prompt found");
|
|
9908
10006
|
}
|
|
@@ -9913,7 +10011,8 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
9913
10011
|
}
|
|
9914
10012
|
const modelConfig = {
|
|
9915
10013
|
...options?.modelConfig,
|
|
9916
|
-
...options?.sampleCount ? { n: options.sampleCount } : {}
|
|
10014
|
+
...options?.sampleCount ? { n: options.sampleCount } : {},
|
|
10015
|
+
...options?.sampleCount && options?.modelConfig?.temperature == 1 ? { temperature: 0.8 } : {}
|
|
9917
10016
|
};
|
|
9918
10017
|
const res = await ai.chat(
|
|
9919
10018
|
{
|
|
@@ -10190,15 +10289,58 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
10190
10289
|
currentVersion = delta.version;
|
|
10191
10290
|
buffer = mergeDeltas(buffer, delta);
|
|
10192
10291
|
}
|
|
10193
|
-
const
|
|
10292
|
+
const selectedIndex = await selectFromSamples(
|
|
10293
|
+
buffer,
|
|
10294
|
+
{
|
|
10295
|
+
resultPicker: options?.resultPicker
|
|
10296
|
+
},
|
|
10297
|
+
// Pass memory to enable function result selection
|
|
10298
|
+
options?.mem,
|
|
10299
|
+
options?.sessionId
|
|
10300
|
+
);
|
|
10301
|
+
const selectedResult = buffer[selectedIndex];
|
|
10302
|
+
const result = selectedResult?.delta ?? {};
|
|
10194
10303
|
this.trace = { ...values, ...result };
|
|
10195
10304
|
return result;
|
|
10196
10305
|
}
|
|
10197
10306
|
async *streamingForward(ai, values, options) {
|
|
10198
|
-
|
|
10307
|
+
if (!options?.resultPicker) {
|
|
10308
|
+
yield* this._forward1(ai, values, {
|
|
10309
|
+
...options,
|
|
10310
|
+
stream: true
|
|
10311
|
+
});
|
|
10312
|
+
return;
|
|
10313
|
+
}
|
|
10314
|
+
const generator = this._forward1(ai, values, {
|
|
10199
10315
|
...options,
|
|
10200
10316
|
stream: true
|
|
10201
10317
|
});
|
|
10318
|
+
let buffer = [];
|
|
10319
|
+
let currentVersion = 0;
|
|
10320
|
+
for await (const delta of generator) {
|
|
10321
|
+
if (delta.version !== currentVersion) {
|
|
10322
|
+
buffer = [];
|
|
10323
|
+
}
|
|
10324
|
+
currentVersion = delta.version;
|
|
10325
|
+
buffer = mergeDeltas(buffer, delta);
|
|
10326
|
+
}
|
|
10327
|
+
const selectedIndex = await selectFromSamples(
|
|
10328
|
+
buffer,
|
|
10329
|
+
{
|
|
10330
|
+
resultPicker: options?.resultPicker
|
|
10331
|
+
},
|
|
10332
|
+
// Pass memory to enable function result selection
|
|
10333
|
+
options?.mem,
|
|
10334
|
+
options?.sessionId
|
|
10335
|
+
);
|
|
10336
|
+
const selectedResult = buffer[selectedIndex];
|
|
10337
|
+
if (selectedResult) {
|
|
10338
|
+
yield {
|
|
10339
|
+
version: currentVersion,
|
|
10340
|
+
index: selectedIndex,
|
|
10341
|
+
delta: selectedResult.delta
|
|
10342
|
+
};
|
|
10343
|
+
}
|
|
10202
10344
|
}
|
|
10203
10345
|
setExamples(examples, options) {
|
|
10204
10346
|
super.setExamples(examples, options);
|
|
@@ -11558,13 +11700,6 @@ var AxBaseOptimizer = class {
|
|
|
11558
11700
|
if (this.logger) {
|
|
11559
11701
|
return this.logger;
|
|
11560
11702
|
}
|
|
11561
|
-
try {
|
|
11562
|
-
const aiLogger = this.studentAI.getLogger();
|
|
11563
|
-
if (aiLogger) {
|
|
11564
|
-
return aiLogger;
|
|
11565
|
-
}
|
|
11566
|
-
} catch {
|
|
11567
|
-
}
|
|
11568
11703
|
return axDefaultOptimizerLogger;
|
|
11569
11704
|
}
|
|
11570
11705
|
/**
|
|
@@ -13256,6 +13391,11 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13256
13391
|
bayesianOptimization;
|
|
13257
13392
|
acquisitionFunction;
|
|
13258
13393
|
explorationWeight;
|
|
13394
|
+
// Self-consistency / multiple sampling
|
|
13395
|
+
sampleCount;
|
|
13396
|
+
// Surrogate model state for Bayesian optimization
|
|
13397
|
+
miproConfigHistory = [];
|
|
13398
|
+
surrogateModel = /* @__PURE__ */ new Map();
|
|
13259
13399
|
constructor(args) {
|
|
13260
13400
|
super(args);
|
|
13261
13401
|
const options = args.options || {};
|
|
@@ -13277,6 +13417,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13277
13417
|
this.bayesianOptimization = options.bayesianOptimization ?? false;
|
|
13278
13418
|
this.acquisitionFunction = options.acquisitionFunction ?? "expected_improvement";
|
|
13279
13419
|
this.explorationWeight = options.explorationWeight ?? 0.1;
|
|
13420
|
+
this.sampleCount = options.sampleCount ?? 1;
|
|
13280
13421
|
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
13281
13422
|
}
|
|
13282
13423
|
/**
|
|
@@ -13321,43 +13462,186 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13321
13462
|
];
|
|
13322
13463
|
}
|
|
13323
13464
|
/**
|
|
13324
|
-
* Generates
|
|
13465
|
+
* Generates program summary for context-aware instruction generation
|
|
13466
|
+
*/
|
|
13467
|
+
async generateProgramSummary(program, ai) {
|
|
13468
|
+
let signature = "input -> output";
|
|
13469
|
+
if ("getSignature" in program && typeof program.getSignature === "function") {
|
|
13470
|
+
signature = program.getSignature();
|
|
13471
|
+
}
|
|
13472
|
+
const summaryPrompt = `
|
|
13473
|
+
Analyze this language model program and provide a concise summary of its purpose and structure.
|
|
13474
|
+
|
|
13475
|
+
Program Signature: ${signature}
|
|
13476
|
+
|
|
13477
|
+
Provide a 2-3 sentence summary focusing on:
|
|
13478
|
+
1. The main task or purpose of this program
|
|
13479
|
+
2. The input-output relationship
|
|
13480
|
+
3. Any special constraints or requirements
|
|
13481
|
+
|
|
13482
|
+
Summary:`;
|
|
13483
|
+
try {
|
|
13484
|
+
const response = await ai.chat({
|
|
13485
|
+
chatPrompt: [{ role: "user", content: summaryPrompt }]
|
|
13486
|
+
});
|
|
13487
|
+
if ("results" in response) {
|
|
13488
|
+
return response.results[0]?.content?.trim() || "General language model program";
|
|
13489
|
+
}
|
|
13490
|
+
return "General language model program";
|
|
13491
|
+
} catch {
|
|
13492
|
+
return "General language model program";
|
|
13493
|
+
}
|
|
13494
|
+
}
|
|
13495
|
+
/**
|
|
13496
|
+
* Generates dataset summary for context-aware instruction generation
|
|
13497
|
+
*/
|
|
13498
|
+
async generateDatasetSummary(examples, ai) {
|
|
13499
|
+
if (examples.length === 0) return "No examples available";
|
|
13500
|
+
const sampleSize = Math.min(this.viewDataBatchSize, examples.length);
|
|
13501
|
+
const sampledExamples = examples.slice(0, sampleSize);
|
|
13502
|
+
const exampleTexts = sampledExamples.map((ex, i) => `Example ${i + 1}: ${JSON.stringify(ex)}`).join("\n");
|
|
13503
|
+
const summaryPrompt = `
|
|
13504
|
+
Analyze this dataset and provide a concise summary of its characteristics.
|
|
13505
|
+
|
|
13506
|
+
Sample Examples:
|
|
13507
|
+
${exampleTexts}
|
|
13508
|
+
|
|
13509
|
+
Provide a 2-3 sentence summary focusing on:
|
|
13510
|
+
1. The type of data and domain
|
|
13511
|
+
2. Common patterns or structures in the examples
|
|
13512
|
+
3. Key challenges or requirements for processing this data
|
|
13513
|
+
|
|
13514
|
+
Dataset Summary:`;
|
|
13515
|
+
try {
|
|
13516
|
+
const response = await ai.chat({
|
|
13517
|
+
chatPrompt: [{ role: "user", content: summaryPrompt }]
|
|
13518
|
+
});
|
|
13519
|
+
if ("results" in response) {
|
|
13520
|
+
return response.results[0]?.content?.trim() || "General dataset";
|
|
13521
|
+
}
|
|
13522
|
+
return "General dataset";
|
|
13523
|
+
} catch {
|
|
13524
|
+
return "General dataset";
|
|
13525
|
+
}
|
|
13526
|
+
}
|
|
13527
|
+
/**
|
|
13528
|
+
* Enhanced instruction generation using AI with program and data awareness
|
|
13529
|
+
*/
|
|
13530
|
+
async generateInstruction({
|
|
13531
|
+
tip,
|
|
13532
|
+
candidateIndex,
|
|
13533
|
+
ai,
|
|
13534
|
+
programSummary,
|
|
13535
|
+
datasetSummary,
|
|
13536
|
+
previousInstructions = []
|
|
13537
|
+
}) {
|
|
13538
|
+
let contextInfo = "";
|
|
13539
|
+
if (this.programAwareProposer && programSummary) {
|
|
13540
|
+
contextInfo += `
|
|
13541
|
+
Program Context: ${programSummary}`;
|
|
13542
|
+
}
|
|
13543
|
+
if (this.dataAwareProposer && datasetSummary) {
|
|
13544
|
+
contextInfo += `
|
|
13545
|
+
Dataset Context: ${datasetSummary}`;
|
|
13546
|
+
}
|
|
13547
|
+
if (this.fewshotAwareProposer && previousInstructions.length > 0) {
|
|
13548
|
+
contextInfo += `
|
|
13549
|
+
Previous Instructions (avoid repeating): ${previousInstructions.slice(-3).join("; ")}`;
|
|
13550
|
+
}
|
|
13551
|
+
const instructionPrompt = `
|
|
13552
|
+
Generate a high-quality instruction for a language model program.
|
|
13553
|
+
|
|
13554
|
+
${contextInfo}
|
|
13555
|
+
|
|
13556
|
+
${tip ? `Tip: ${tip}` : ""}
|
|
13557
|
+
|
|
13558
|
+
Requirements:
|
|
13559
|
+
1. Be specific and actionable
|
|
13560
|
+
2. Focus on accuracy and clarity
|
|
13561
|
+
3. Consider the program's purpose and data characteristics
|
|
13562
|
+
4. Make the instruction distinct from previous ones
|
|
13563
|
+
5. Keep it concise but comprehensive
|
|
13564
|
+
|
|
13565
|
+
Generate a single, well-crafted instruction:
|
|
13566
|
+
Instruction:`;
|
|
13567
|
+
try {
|
|
13568
|
+
const response = await ai.chat({
|
|
13569
|
+
chatPrompt: [
|
|
13570
|
+
{
|
|
13571
|
+
role: "user",
|
|
13572
|
+
content: instructionPrompt
|
|
13573
|
+
}
|
|
13574
|
+
]
|
|
13575
|
+
});
|
|
13576
|
+
if ("results" in response) {
|
|
13577
|
+
const instruction2 = response.results[0]?.content?.trim();
|
|
13578
|
+
if (instruction2 && instruction2.length > 10) {
|
|
13579
|
+
return instruction2;
|
|
13580
|
+
}
|
|
13581
|
+
}
|
|
13582
|
+
} catch (error) {
|
|
13583
|
+
if (this.isLoggingEnabled()) {
|
|
13584
|
+
this.getLogger()?.(`Failed to generate AI instruction: ${error}`, {
|
|
13585
|
+
tags: ["optimizer", "warning"]
|
|
13586
|
+
});
|
|
13587
|
+
}
|
|
13588
|
+
}
|
|
13589
|
+
const enhancedTemplates = [
|
|
13590
|
+
"Analyze the input systematically and provide a precise, well-reasoned response.",
|
|
13591
|
+
"Think through this step-by-step, considering all relevant factors before responding.",
|
|
13592
|
+
"Examine the input carefully and generate an accurate, detailed answer.",
|
|
13593
|
+
"Process the information methodically and deliver a clear, comprehensive response.",
|
|
13594
|
+
"Consider the context thoroughly and provide a thoughtful, accurate answer."
|
|
13595
|
+
];
|
|
13596
|
+
let instruction = enhancedTemplates[candidateIndex % enhancedTemplates.length] || enhancedTemplates[0];
|
|
13597
|
+
if (tip) {
|
|
13598
|
+
instruction = `${instruction} ${tip}`;
|
|
13599
|
+
}
|
|
13600
|
+
return instruction;
|
|
13601
|
+
}
|
|
13602
|
+
/**
|
|
13603
|
+
* Generates instruction candidates using enhanced AI-powered generation
|
|
13325
13604
|
* @param options Optional compile options that may override teacher AI
|
|
13326
13605
|
* @returns Array of generated instruction candidates
|
|
13327
13606
|
*/
|
|
13328
|
-
async proposeInstructionCandidates(options) {
|
|
13607
|
+
async proposeInstructionCandidates(program, options) {
|
|
13329
13608
|
const instructions = [];
|
|
13330
13609
|
const aiToUse = this.getTeacherOrStudentAI(options);
|
|
13610
|
+
let programSummary;
|
|
13611
|
+
let datasetSummary;
|
|
13612
|
+
if (this.programAwareProposer) {
|
|
13613
|
+
programSummary = await this.generateProgramSummary(program, aiToUse);
|
|
13614
|
+
if (this.isLoggingEnabled(options)) {
|
|
13615
|
+
this.getLogger(options)?.(`Program summary: ${programSummary}`, {
|
|
13616
|
+
tags: ["optimizer", "config"]
|
|
13617
|
+
});
|
|
13618
|
+
}
|
|
13619
|
+
}
|
|
13620
|
+
if (this.dataAwareProposer) {
|
|
13621
|
+
datasetSummary = await this.generateDatasetSummary(this.examples, aiToUse);
|
|
13622
|
+
if (this.isLoggingEnabled(options)) {
|
|
13623
|
+
this.getLogger(options)?.(`Dataset summary: ${datasetSummary}`, {
|
|
13624
|
+
tags: ["optimizer", "config"]
|
|
13625
|
+
});
|
|
13626
|
+
}
|
|
13627
|
+
}
|
|
13331
13628
|
const tips = this.tipAwareProposer ? this.generateTips() : [];
|
|
13332
13629
|
for (let i = 0; i < this.numCandidates; i++) {
|
|
13333
13630
|
const tipIndex = tips.length > 0 ? i % tips.length : -1;
|
|
13334
|
-
const tipToUse = tipIndex >= 0 ? tips[tipIndex] :
|
|
13631
|
+
const tipToUse = tipIndex >= 0 ? tips[tipIndex] : void 0;
|
|
13335
13632
|
const instruction = await this.generateInstruction({
|
|
13336
13633
|
tip: tipToUse,
|
|
13337
13634
|
candidateIndex: i,
|
|
13338
|
-
ai: aiToUse
|
|
13635
|
+
ai: aiToUse,
|
|
13636
|
+
programSummary,
|
|
13637
|
+
datasetSummary,
|
|
13638
|
+
previousInstructions: instructions
|
|
13639
|
+
// Pass previous instructions for diversity
|
|
13339
13640
|
});
|
|
13340
13641
|
instructions.push(instruction);
|
|
13341
13642
|
}
|
|
13342
13643
|
return instructions;
|
|
13343
13644
|
}
|
|
13344
|
-
async generateInstruction({
|
|
13345
|
-
tip,
|
|
13346
|
-
candidateIndex
|
|
13347
|
-
}) {
|
|
13348
|
-
const baseInstructions = [
|
|
13349
|
-
"Analyze the input carefully and provide a detailed response.",
|
|
13350
|
-
"Think step by step and provide a clear answer.",
|
|
13351
|
-
"Consider all aspects of the input before responding.",
|
|
13352
|
-
"Provide a concise but comprehensive response.",
|
|
13353
|
-
"Focus on accuracy and clarity in your response."
|
|
13354
|
-
];
|
|
13355
|
-
let instruction = baseInstructions[candidateIndex % baseInstructions.length] || baseInstructions[0];
|
|
13356
|
-
if (tip) {
|
|
13357
|
-
instruction = `${instruction} ${tip}`;
|
|
13358
|
-
}
|
|
13359
|
-
return instruction;
|
|
13360
|
-
}
|
|
13361
13645
|
/**
|
|
13362
13646
|
* Bootstraps few-shot examples for the program
|
|
13363
13647
|
*/
|
|
@@ -13402,7 +13686,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13402
13686
|
/**
|
|
13403
13687
|
* Runs optimization to find the best combination of few-shot examples and instructions
|
|
13404
13688
|
*/
|
|
13405
|
-
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions,
|
|
13689
|
+
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, validationExamples, metricFn, options) {
|
|
13406
13690
|
let bestConfig = {
|
|
13407
13691
|
instruction: instructions[0] || "",
|
|
13408
13692
|
bootstrappedDemos: Math.min(1, bootstrappedDemos.length),
|
|
@@ -13438,25 +13722,37 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13438
13722
|
);
|
|
13439
13723
|
}
|
|
13440
13724
|
for (let i = startRound; i < this.numTrials; i++) {
|
|
13441
|
-
|
|
13442
|
-
|
|
13443
|
-
|
|
13444
|
-
|
|
13445
|
-
|
|
13446
|
-
|
|
13447
|
-
|
|
13448
|
-
|
|
13449
|
-
|
|
13450
|
-
|
|
13451
|
-
|
|
13725
|
+
let config;
|
|
13726
|
+
if (this.bayesianOptimization && this.miproConfigHistory.length > 2) {
|
|
13727
|
+
config = await this.selectConfigurationViaBayesianOptimization(
|
|
13728
|
+
instructions,
|
|
13729
|
+
bootstrappedDemos,
|
|
13730
|
+
labeledExamples
|
|
13731
|
+
);
|
|
13732
|
+
} else {
|
|
13733
|
+
config = {
|
|
13734
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
13735
|
+
bootstrappedDemos: Math.min(
|
|
13736
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
13737
|
+
this.maxBootstrappedDemos
|
|
13738
|
+
),
|
|
13739
|
+
labeledExamples: Math.min(
|
|
13740
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
13741
|
+
this.maxLabeledDemos
|
|
13742
|
+
)
|
|
13743
|
+
};
|
|
13744
|
+
}
|
|
13452
13745
|
const score = await this.evaluateConfig(
|
|
13453
13746
|
program,
|
|
13454
13747
|
config,
|
|
13455
13748
|
bootstrappedDemos,
|
|
13456
13749
|
labeledExamples,
|
|
13457
|
-
|
|
13458
|
-
metricFn
|
|
13750
|
+
validationExamples,
|
|
13751
|
+
metricFn,
|
|
13752
|
+
i + 1
|
|
13753
|
+
// Pass current trial number for adaptive evaluation
|
|
13459
13754
|
);
|
|
13755
|
+
this.updateSurrogateModel(config, score);
|
|
13460
13756
|
scoreHistory.push(score);
|
|
13461
13757
|
const improvement = score - bestScore;
|
|
13462
13758
|
if (improvement > this.minImprovementThreshold) {
|
|
@@ -13538,7 +13834,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13538
13834
|
this.stats.convergenceInfo.converged = stagnationRounds < this.earlyStoppingTrials;
|
|
13539
13835
|
return { bestConfig, bestScore };
|
|
13540
13836
|
}
|
|
13541
|
-
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples,
|
|
13837
|
+
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, validationExamples, metricFn, currentTrial = 0) {
|
|
13542
13838
|
const testProgram = { ...program };
|
|
13543
13839
|
this.applyConfigToProgram(
|
|
13544
13840
|
testProgram,
|
|
@@ -13548,12 +13844,31 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13548
13844
|
);
|
|
13549
13845
|
let totalScore = 0;
|
|
13550
13846
|
let count = 0;
|
|
13551
|
-
|
|
13847
|
+
let evalSize;
|
|
13848
|
+
if (this.minibatch) {
|
|
13849
|
+
const baseSize = Math.min(this.minibatchSize, validationExamples.length);
|
|
13850
|
+
const isFullEvalTrial = currentTrial % this.minibatchFullEvalSteps === 0;
|
|
13851
|
+
if (isFullEvalTrial || currentTrial > this.numTrials * 0.8) {
|
|
13852
|
+
evalSize = Math.min(validationExamples.length, baseSize * 2);
|
|
13853
|
+
} else {
|
|
13854
|
+
evalSize = Math.max(3, Math.min(baseSize, validationExamples.length));
|
|
13855
|
+
}
|
|
13856
|
+
} else {
|
|
13857
|
+
evalSize = validationExamples.length;
|
|
13858
|
+
}
|
|
13859
|
+
const evalIndices = this.shuffleArray([
|
|
13860
|
+
...Array(validationExamples.length).keys()
|
|
13861
|
+
]).slice(0, evalSize);
|
|
13862
|
+
const evalSet = evalIndices.map((i) => validationExamples[i]);
|
|
13552
13863
|
for (const example of evalSet) {
|
|
13553
13864
|
try {
|
|
13554
13865
|
const prediction = await testProgram.forward(
|
|
13555
13866
|
this.studentAI,
|
|
13556
|
-
example
|
|
13867
|
+
example,
|
|
13868
|
+
this.sampleCount > 1 ? {
|
|
13869
|
+
sampleCount: this.sampleCount,
|
|
13870
|
+
resultPicker: axMajorityVotePicker()
|
|
13871
|
+
} : void 0
|
|
13557
13872
|
);
|
|
13558
13873
|
const score = await metricFn({ prediction, example });
|
|
13559
13874
|
totalScore += score;
|
|
@@ -13565,6 +13880,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13565
13880
|
}
|
|
13566
13881
|
return count > 0 ? totalScore / count : 0;
|
|
13567
13882
|
}
|
|
13883
|
+
/**
|
|
13884
|
+
* Fisher-Yates shuffle for stochastic evaluation
|
|
13885
|
+
*/
|
|
13886
|
+
shuffleArray(array) {
|
|
13887
|
+
const shuffled = [...array];
|
|
13888
|
+
for (let i = shuffled.length - 1; i > 0; i--) {
|
|
13889
|
+
const j = Math.floor(Math.random() * (i + 1));
|
|
13890
|
+
[shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
|
|
13891
|
+
}
|
|
13892
|
+
return shuffled;
|
|
13893
|
+
}
|
|
13568
13894
|
applyConfigToProgram(program, config, bootstrappedDemos, labeledExamples) {
|
|
13569
13895
|
if (program.setInstruction) {
|
|
13570
13896
|
program.setInstruction(config.instruction);
|
|
@@ -13586,14 +13912,14 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13586
13912
|
if (miproOptions?.auto) {
|
|
13587
13913
|
this.configureAuto(miproOptions.auto);
|
|
13588
13914
|
}
|
|
13589
|
-
const
|
|
13915
|
+
const validationExamples = this.getValidationSet(options) || (miproOptions?.validationExamples ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
|
|
13590
13916
|
if (this.isLoggingEnabled(options)) {
|
|
13591
13917
|
this.getLogger(options)?.(
|
|
13592
13918
|
`Starting MIPROv2 optimization with ${this.numTrials} trials`,
|
|
13593
13919
|
{ tags: ["optimizer", "start"] }
|
|
13594
13920
|
);
|
|
13595
13921
|
this.getLogger(options)?.(
|
|
13596
|
-
`Using ${this.examples.length} examples for training and ${
|
|
13922
|
+
`Using ${this.examples.length} examples for training and ${validationExamples.length} for validation`,
|
|
13597
13923
|
{ tags: ["optimizer", "config"] }
|
|
13598
13924
|
);
|
|
13599
13925
|
if (this.teacherAI) {
|
|
@@ -13623,7 +13949,10 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13623
13949
|
);
|
|
13624
13950
|
}
|
|
13625
13951
|
}
|
|
13626
|
-
const instructions = await this.proposeInstructionCandidates(
|
|
13952
|
+
const instructions = await this.proposeInstructionCandidates(
|
|
13953
|
+
program,
|
|
13954
|
+
options
|
|
13955
|
+
);
|
|
13627
13956
|
if (this.isLoggingEnabled(options)) {
|
|
13628
13957
|
this.getLogger(options)?.(
|
|
13629
13958
|
`Generated ${instructions.length} instruction candidates`,
|
|
@@ -13641,7 +13970,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13641
13970
|
bootstrappedDemos,
|
|
13642
13971
|
labeledExamples,
|
|
13643
13972
|
instructions,
|
|
13644
|
-
|
|
13973
|
+
validationExamples,
|
|
13645
13974
|
metricFn,
|
|
13646
13975
|
options
|
|
13647
13976
|
);
|
|
@@ -13700,7 +14029,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13700
14029
|
bootstrappedDemos: bestConfig.bootstrappedDemos,
|
|
13701
14030
|
labeledExamples: bestConfig.labeledExamples,
|
|
13702
14031
|
numCandidates: this.numCandidates,
|
|
13703
|
-
numTrials: this.numTrials
|
|
14032
|
+
numTrials: this.numTrials,
|
|
14033
|
+
sampleCount: this.sampleCount
|
|
13704
14034
|
}
|
|
13705
14035
|
};
|
|
13706
14036
|
}
|
|
@@ -13745,7 +14075,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13745
14075
|
minImprovementThreshold: this.minImprovementThreshold,
|
|
13746
14076
|
bayesianOptimization: this.bayesianOptimization,
|
|
13747
14077
|
acquisitionFunction: this.acquisitionFunction,
|
|
13748
|
-
explorationWeight: this.explorationWeight
|
|
14078
|
+
explorationWeight: this.explorationWeight,
|
|
14079
|
+
sampleCount: this.sampleCount
|
|
13749
14080
|
};
|
|
13750
14081
|
}
|
|
13751
14082
|
/**
|
|
@@ -13780,12 +14111,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13780
14111
|
if (config.minImprovementThreshold !== void 0) {
|
|
13781
14112
|
this.minImprovementThreshold = config.minImprovementThreshold;
|
|
13782
14113
|
}
|
|
14114
|
+
if (config.sampleCount !== void 0) {
|
|
14115
|
+
this.sampleCount = config.sampleCount;
|
|
14116
|
+
}
|
|
13783
14117
|
}
|
|
13784
14118
|
/**
|
|
13785
14119
|
* Reset optimizer state for reuse with different programs
|
|
13786
14120
|
*/
|
|
13787
14121
|
reset() {
|
|
13788
14122
|
super.reset();
|
|
14123
|
+
this.miproConfigHistory = [];
|
|
14124
|
+
this.surrogateModel.clear();
|
|
13789
14125
|
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
13790
14126
|
}
|
|
13791
14127
|
/**
|
|
@@ -13803,8 +14139,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13803
14139
|
"Reduce maxBootstrappedDemos or maxLabeledDemos, or provide more examples"
|
|
13804
14140
|
);
|
|
13805
14141
|
}
|
|
13806
|
-
const
|
|
13807
|
-
if (
|
|
14142
|
+
const validationSetSize = this.getValidationSet().length;
|
|
14143
|
+
if (validationSetSize < 5) {
|
|
13808
14144
|
result.issues.push(
|
|
13809
14145
|
"Validation set too small for reliable MiPRO optimization"
|
|
13810
14146
|
);
|
|
@@ -13818,6 +14154,141 @@ var AxMiPRO = class extends AxBaseOptimizer {
|
|
|
13818
14154
|
suggestions: result.suggestions
|
|
13819
14155
|
};
|
|
13820
14156
|
}
|
|
14157
|
+
/**
|
|
14158
|
+
* Encodes a configuration into a string key for surrogate model lookup
|
|
14159
|
+
*/
|
|
14160
|
+
encodeConfiguration(config) {
|
|
14161
|
+
return `${config.instruction.length}_${config.bootstrappedDemos}_${config.labeledExamples}`;
|
|
14162
|
+
}
|
|
14163
|
+
/**
|
|
14164
|
+
* Updates the surrogate model with a new configuration-score pair
|
|
14165
|
+
*/
|
|
14166
|
+
updateSurrogateModel(config, score) {
|
|
14167
|
+
this.miproConfigHistory.push({ config: { ...config }, score });
|
|
14168
|
+
const key = this.encodeConfiguration(config);
|
|
14169
|
+
const similarConfigs = this.miproConfigHistory.filter(
|
|
14170
|
+
(entry) => this.encodeConfiguration(entry.config) === key
|
|
14171
|
+
);
|
|
14172
|
+
if (similarConfigs.length > 0) {
|
|
14173
|
+
const scores = similarConfigs.map((entry) => entry.score);
|
|
14174
|
+
const mean = scores.reduce((sum, s2) => sum + s2, 0) / scores.length;
|
|
14175
|
+
const variance = scores.length > 1 ? scores.reduce((sum, s2) => sum + Math.pow(s2 - mean, 2), 0) / (scores.length - 1) : 0.1;
|
|
14176
|
+
this.surrogateModel.set(key, { mean, variance });
|
|
14177
|
+
}
|
|
14178
|
+
}
|
|
14179
|
+
/**
|
|
14180
|
+
* Predicts performance using the surrogate model
|
|
14181
|
+
*/
|
|
14182
|
+
predictPerformance(config) {
|
|
14183
|
+
const key = this.encodeConfiguration(config);
|
|
14184
|
+
if (this.surrogateModel.has(key)) {
|
|
14185
|
+
return this.surrogateModel.get(key);
|
|
14186
|
+
}
|
|
14187
|
+
if (this.miproConfigHistory.length > 0) {
|
|
14188
|
+
const similarities = this.miproConfigHistory.map((entry) => {
|
|
14189
|
+
const diff = Math.abs(entry.config.bootstrappedDemos - config.bootstrappedDemos) + Math.abs(entry.config.labeledExamples - config.labeledExamples);
|
|
14190
|
+
return { score: entry.score, similarity: 1 / (1 + diff) };
|
|
14191
|
+
});
|
|
14192
|
+
const totalWeight = similarities.reduce((sum, s2) => sum + s2.similarity, 0);
|
|
14193
|
+
const weightedMean = similarities.reduce((sum, s2) => sum + s2.score * s2.similarity, 0) / totalWeight;
|
|
14194
|
+
return { mean: weightedMean, variance: 0.2 };
|
|
14195
|
+
}
|
|
14196
|
+
return { mean: 0.5, variance: 0.3 };
|
|
14197
|
+
}
|
|
14198
|
+
/**
|
|
14199
|
+
* Calculates acquisition function value for Bayesian optimization
|
|
14200
|
+
*/
|
|
14201
|
+
calculateAcquisitionValue(config) {
|
|
14202
|
+
const prediction = this.predictPerformance(config);
|
|
14203
|
+
const { mean, variance } = prediction;
|
|
14204
|
+
const std = Math.sqrt(variance);
|
|
14205
|
+
const bestScore = this.miproConfigHistory.length > 0 ? Math.max(...this.miproConfigHistory.map((entry) => entry.score)) : 0;
|
|
14206
|
+
switch (this.acquisitionFunction) {
|
|
14207
|
+
case "expected_improvement": {
|
|
14208
|
+
const improvement = mean - bestScore;
|
|
14209
|
+
if (std === 0) return Math.max(0, improvement);
|
|
14210
|
+
const z = improvement / std;
|
|
14211
|
+
const phi = 0.5 * (1 + this.erf(z / Math.sqrt(2)));
|
|
14212
|
+
const pdfValue = Math.exp(-0.5 * z * z) / Math.sqrt(2 * Math.PI);
|
|
14213
|
+
return improvement * phi + std * pdfValue;
|
|
14214
|
+
}
|
|
14215
|
+
case "upper_confidence_bound": {
|
|
14216
|
+
return mean + this.explorationWeight * std;
|
|
14217
|
+
}
|
|
14218
|
+
case "probability_improvement": {
|
|
14219
|
+
const improvement = mean - bestScore;
|
|
14220
|
+
if (std === 0) return improvement > 0 ? 1 : 0;
|
|
14221
|
+
const z = improvement / std;
|
|
14222
|
+
return 0.5 * (1 + this.erf(z / Math.sqrt(2)));
|
|
14223
|
+
}
|
|
14224
|
+
default:
|
|
14225
|
+
return mean;
|
|
14226
|
+
}
|
|
14227
|
+
}
|
|
14228
|
+
/**
|
|
14229
|
+
* Error function approximation for acquisition function calculations
|
|
14230
|
+
*/
|
|
14231
|
+
erf(x) {
|
|
14232
|
+
const a1 = 0.254829592;
|
|
14233
|
+
const a2 = -0.284496736;
|
|
14234
|
+
const a3 = 1.421413741;
|
|
14235
|
+
const a4 = -1.453152027;
|
|
14236
|
+
const a5 = 1.061405429;
|
|
14237
|
+
const p = 0.3275911;
|
|
14238
|
+
const sign = x >= 0 ? 1 : -1;
|
|
14239
|
+
x = Math.abs(x);
|
|
14240
|
+
const t = 1 / (1 + p * x);
|
|
14241
|
+
const y = 1 - ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t * Math.exp(-x * x);
|
|
14242
|
+
return sign * y;
|
|
14243
|
+
}
|
|
14244
|
+
/**
|
|
14245
|
+
* Selects the next configuration to evaluate using Bayesian optimization
|
|
14246
|
+
*/
|
|
14247
|
+
async selectConfigurationViaBayesianOptimization(instructions, bootstrappedDemos, labeledExamples) {
|
|
14248
|
+
const candidates = [];
|
|
14249
|
+
const numCandidates = Math.min(20, instructions.length * 3);
|
|
14250
|
+
for (let i = 0; i < numCandidates; i++) {
|
|
14251
|
+
const config = {
|
|
14252
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
14253
|
+
bootstrappedDemos: Math.min(
|
|
14254
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
14255
|
+
this.maxBootstrappedDemos
|
|
14256
|
+
),
|
|
14257
|
+
labeledExamples: Math.min(
|
|
14258
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
14259
|
+
this.maxLabeledDemos
|
|
14260
|
+
)
|
|
14261
|
+
};
|
|
14262
|
+
const acquisitionValue = this.calculateAcquisitionValue(config);
|
|
14263
|
+
candidates.push({ config, acquisitionValue });
|
|
14264
|
+
}
|
|
14265
|
+
candidates.sort((a, b) => b.acquisitionValue - a.acquisitionValue);
|
|
14266
|
+
return candidates[0].config;
|
|
14267
|
+
}
|
|
14268
|
+
};
|
|
14269
|
+
/**
 * Creates a sample picker that selects the most common result.
 *
 * For `data.type === "fields"`, each result's `sample` is serialized with
 * object keys sorted at every depth. Samples carrying the same field values
 * therefore count as one vote even when their properties were produced in a
 * different order. The picker returns the `index` of the first result in the
 * group with the most votes; ties go to the group seen first.
 *
 * For any other `data.type`, it returns the first result's index. In both
 * cases it returns 0 when there are no results.
 *
 * @returns {(data: {type: string, results: Array<{index: number, sample?: unknown}>}) => Promise<number>}
 */
var axMajorityVotePicker = () => {
  // Canonical JSON: plain objects are re-emitted with keys in code-unit order;
  // arrays and primitives pass through unchanged.
  const canonicalKey = (value) =>
    JSON.stringify(value, (_key, val) => {
      if (val === null || typeof val !== "object" || Array.isArray(val)) {
        return val;
      }
      const sortedEntries = Object.entries(val).sort(([a], [b]) =>
        a < b ? -1 : a > b ? 1 : 0
      );
      return Object.fromEntries(sortedEntries);
    });

  return async (data) => {
    if (data.type !== "fields") {
      return data.results[0]?.index ?? 0;
    }
    // Map iteration follows insertion order, which keeps tie-breaking deterministic.
    const tally = new Map();
    for (const { index, sample } of data.results) {
      const key = canonicalKey(sample);
      const group = tally.get(key);
      if (group) {
        group.count += 1;
      } else {
        tally.set(key, { count: 1, index });
      }
    }
    let winner;
    for (const group of tally.values()) {
      if (winner === undefined || group.count > winner.count) {
        winner = group;
      }
    }
    return winner?.index ?? 0;
  };
};
|
|
13822
14293
|
|
|
13823
14294
|
// ai/mock/api.ts
|