@ax-llm/ax 12.0.6 → 12.0.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +1348 -727
- package/index.cjs.map +1 -1
- package/index.d.cts +525 -195
- package/index.d.ts +525 -195
- package/index.js +1350 -731
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.cjs
CHANGED
|
@@ -81,6 +81,7 @@ __export(index_exports, {
|
|
|
81
81
|
AxAssertionError: () => AxAssertionError,
|
|
82
82
|
AxBalancer: () => AxBalancer,
|
|
83
83
|
AxBaseAI: () => AxBaseAI,
|
|
84
|
+
AxBaseOptimizer: () => AxBaseOptimizer,
|
|
84
85
|
AxBootstrapFewShot: () => AxBootstrapFewShot,
|
|
85
86
|
AxChainOfThought: () => AxChainOfThought,
|
|
86
87
|
AxDB: () => AxDB,
|
|
@@ -90,6 +91,7 @@ __export(index_exports, {
|
|
|
90
91
|
AxDBMemory: () => AxDBMemory,
|
|
91
92
|
AxDBPinecone: () => AxDBPinecone,
|
|
92
93
|
AxDBWeaviate: () => AxDBWeaviate,
|
|
94
|
+
AxDefaultCostTracker: () => AxDefaultCostTracker,
|
|
93
95
|
AxDefaultQueryRewriter: () => AxDefaultQueryRewriter,
|
|
94
96
|
AxDefaultResultReranker: () => AxDefaultResultReranker,
|
|
95
97
|
AxDockerSession: () => AxDockerSession,
|
|
@@ -2623,48 +2625,59 @@ function createMessages2(req) {
|
|
|
2623
2625
|
case "system":
|
|
2624
2626
|
return { role: "system", content: msg.content };
|
|
2625
2627
|
case "user":
|
|
2626
|
-
|
|
2628
|
+
const content = Array.isArray(msg.content) ? msg.content.map((c) => {
|
|
2629
|
+
switch (c.type) {
|
|
2630
|
+
case "text":
|
|
2631
|
+
return { type: "text", text: c.text };
|
|
2632
|
+
case "image": {
|
|
2633
|
+
const url = `data:${c.mimeType};base64,` + c.image;
|
|
2634
|
+
return {
|
|
2635
|
+
type: "image_url",
|
|
2636
|
+
image_url: { url, details: c.details ?? "auto" }
|
|
2637
|
+
};
|
|
2638
|
+
}
|
|
2639
|
+
case "audio": {
|
|
2640
|
+
const data = c.data;
|
|
2641
|
+
return {
|
|
2642
|
+
type: "input_audio",
|
|
2643
|
+
input_audio: { data, format: c.format ?? "wav" }
|
|
2644
|
+
};
|
|
2645
|
+
}
|
|
2646
|
+
default:
|
|
2647
|
+
throw new Error("Invalid content type");
|
|
2648
|
+
}
|
|
2649
|
+
}) : msg.content;
|
|
2650
|
+
return {
|
|
2651
|
+
role: "user",
|
|
2652
|
+
...msg.name ? { name: msg.name } : {},
|
|
2653
|
+
content
|
|
2654
|
+
};
|
|
2655
|
+
case "assistant":
|
|
2656
|
+
const toolCalls = msg.functionCalls?.map((v) => ({
|
|
2657
|
+
id: v.id,
|
|
2658
|
+
type: "function",
|
|
2659
|
+
function: {
|
|
2660
|
+
name: v.function.name,
|
|
2661
|
+
arguments: typeof v.function.params === "object" ? JSON.stringify(v.function.params) : v.function.params
|
|
2662
|
+
}
|
|
2663
|
+
}));
|
|
2664
|
+
if (toolCalls && toolCalls.length > 0) {
|
|
2627
2665
|
return {
|
|
2628
|
-
role: "
|
|
2666
|
+
role: "assistant",
|
|
2667
|
+
...msg.content ? { content: msg.content } : {},
|
|
2629
2668
|
name: msg.name,
|
|
2630
|
-
|
|
2631
|
-
switch (c.type) {
|
|
2632
|
-
case "text":
|
|
2633
|
-
return { type: "text", text: c.text };
|
|
2634
|
-
case "image": {
|
|
2635
|
-
const url = `data:${c.mimeType};base64,` + c.image;
|
|
2636
|
-
return {
|
|
2637
|
-
type: "image_url",
|
|
2638
|
-
image_url: { url, details: c.details ?? "auto" }
|
|
2639
|
-
};
|
|
2640
|
-
}
|
|
2641
|
-
case "audio": {
|
|
2642
|
-
const data = c.data;
|
|
2643
|
-
return {
|
|
2644
|
-
type: "input_audio",
|
|
2645
|
-
input_audio: { data, format: c.format ?? "wav" }
|
|
2646
|
-
};
|
|
2647
|
-
}
|
|
2648
|
-
default:
|
|
2649
|
-
throw new Error("Invalid content type");
|
|
2650
|
-
}
|
|
2651
|
-
})
|
|
2669
|
+
tool_calls: toolCalls
|
|
2652
2670
|
};
|
|
2653
2671
|
}
|
|
2654
|
-
|
|
2655
|
-
|
|
2672
|
+
if (!msg.content) {
|
|
2673
|
+
throw new Error(
|
|
2674
|
+
"Assistant content is required when no tool calls are provided"
|
|
2675
|
+
);
|
|
2676
|
+
}
|
|
2656
2677
|
return {
|
|
2657
2678
|
role: "assistant",
|
|
2658
2679
|
content: msg.content,
|
|
2659
|
-
name: msg.name
|
|
2660
|
-
tool_calls: msg.functionCalls?.map((v) => ({
|
|
2661
|
-
id: v.id,
|
|
2662
|
-
type: "function",
|
|
2663
|
-
function: {
|
|
2664
|
-
name: v.function.name,
|
|
2665
|
-
arguments: typeof v.function.params === "object" ? JSON.stringify(v.function.params) : v.function.params
|
|
2666
|
-
}
|
|
2667
|
-
}))
|
|
2680
|
+
...msg.name ? { name: msg.name } : {}
|
|
2668
2681
|
};
|
|
2669
2682
|
case "function":
|
|
2670
2683
|
return {
|
|
@@ -7527,8 +7540,9 @@ var AxInstanceRegistry = class {
|
|
|
7527
7540
|
this.reg.add(instance);
|
|
7528
7541
|
}
|
|
7529
7542
|
*[Symbol.iterator]() {
|
|
7530
|
-
|
|
7531
|
-
|
|
7543
|
+
const items = Array.from(this.reg);
|
|
7544
|
+
for (let i = 0; i < items.length; i++) {
|
|
7545
|
+
yield items[i];
|
|
7532
7546
|
}
|
|
7533
7547
|
}
|
|
7534
7548
|
};
|
|
@@ -8215,11 +8229,41 @@ var AxSignature = class _AxSignature {
|
|
|
8215
8229
|
if (signature.validatedAtHash === this.sigHash) {
|
|
8216
8230
|
this.validatedAtHash = this.sigHash;
|
|
8217
8231
|
}
|
|
8232
|
+
} else if (typeof signature === "object" && signature !== null) {
|
|
8233
|
+
if (!("inputs" in signature) || !("outputs" in signature)) {
|
|
8234
|
+
throw new AxSignatureValidationError(
|
|
8235
|
+
"Invalid signature object: missing inputs or outputs",
|
|
8236
|
+
void 0,
|
|
8237
|
+
'Signature object must have "inputs" and "outputs" arrays. Example: { inputs: [...], outputs: [...] }'
|
|
8238
|
+
);
|
|
8239
|
+
}
|
|
8240
|
+
if (!Array.isArray(signature.inputs) || !Array.isArray(signature.outputs)) {
|
|
8241
|
+
throw new AxSignatureValidationError(
|
|
8242
|
+
"Invalid signature object: inputs and outputs must be arrays",
|
|
8243
|
+
void 0,
|
|
8244
|
+
'Both "inputs" and "outputs" must be arrays of AxField objects'
|
|
8245
|
+
);
|
|
8246
|
+
}
|
|
8247
|
+
try {
|
|
8248
|
+
this.description = signature.description;
|
|
8249
|
+
this.inputFields = signature.inputs.map((v) => this.parseField(v));
|
|
8250
|
+
this.outputFields = signature.outputs.map((v) => this.parseField(v));
|
|
8251
|
+
[this.sigHash, this.sigString] = this.updateHash();
|
|
8252
|
+
} catch (error) {
|
|
8253
|
+
if (error instanceof AxSignatureValidationError) {
|
|
8254
|
+
throw error;
|
|
8255
|
+
}
|
|
8256
|
+
throw new AxSignatureValidationError(
|
|
8257
|
+
`Failed to create signature from object: ${error instanceof Error ? error.message : "Unknown error"}`,
|
|
8258
|
+
void 0,
|
|
8259
|
+
"Check that all fields in inputs and outputs arrays are valid AxField objects"
|
|
8260
|
+
);
|
|
8261
|
+
}
|
|
8218
8262
|
} else {
|
|
8219
8263
|
throw new AxSignatureValidationError(
|
|
8220
8264
|
"Invalid signature argument type",
|
|
8221
8265
|
void 0,
|
|
8222
|
-
"Signature must be a string
|
|
8266
|
+
"Signature must be a string, another AxSignature instance, or an object with inputs and outputs arrays"
|
|
8223
8267
|
);
|
|
8224
8268
|
}
|
|
8225
8269
|
}
|
|
@@ -8262,7 +8306,7 @@ var AxSignature = class _AxSignature {
|
|
|
8262
8306
|
}
|
|
8263
8307
|
this.description = desc;
|
|
8264
8308
|
this.invalidateValidationCache();
|
|
8265
|
-
this.
|
|
8309
|
+
this.updateHashLight();
|
|
8266
8310
|
};
|
|
8267
8311
|
addInputField = (field) => {
|
|
8268
8312
|
try {
|
|
@@ -8436,7 +8480,7 @@ var AxSignature = class _AxSignature {
|
|
|
8436
8480
|
this.getOutputFields().forEach((field) => {
|
|
8437
8481
|
validateField(field, "output");
|
|
8438
8482
|
});
|
|
8439
|
-
this.sigHash = (0, import_crypto3.createHash)("sha256").update(
|
|
8483
|
+
this.sigHash = (0, import_crypto3.createHash)("sha256").update(JSON.stringify(this.inputFields)).update(JSON.stringify(this.outputFields)).digest("hex");
|
|
8440
8484
|
this.sigString = renderSignature(
|
|
8441
8485
|
this.description,
|
|
8442
8486
|
this.inputFields,
|
|
@@ -8757,7 +8801,7 @@ var AxProgramWithSignature = class {
|
|
|
8757
8801
|
this.signature.validate();
|
|
8758
8802
|
this.sigHash = this.signature?.hash();
|
|
8759
8803
|
this.children = new AxInstanceRegistry();
|
|
8760
|
-
this.key = { id: this.
|
|
8804
|
+
this.key = { id: this.signature.hash() };
|
|
8761
8805
|
}
|
|
8762
8806
|
getSignature() {
|
|
8763
8807
|
return this.signature;
|
|
@@ -8777,8 +8821,8 @@ var AxProgramWithSignature = class {
|
|
|
8777
8821
|
}
|
|
8778
8822
|
setId(id) {
|
|
8779
8823
|
this.key = { id, custom: true };
|
|
8780
|
-
for (const child of this.children) {
|
|
8781
|
-
child
|
|
8824
|
+
for (const child of Array.from(this.children)) {
|
|
8825
|
+
child?.setParentId(id);
|
|
8782
8826
|
}
|
|
8783
8827
|
}
|
|
8784
8828
|
setParentId(parentId) {
|
|
@@ -8791,8 +8835,8 @@ var AxProgramWithSignature = class {
|
|
|
8791
8835
|
if (!("programId" in examples)) {
|
|
8792
8836
|
return;
|
|
8793
8837
|
}
|
|
8794
|
-
for (const child of this.children) {
|
|
8795
|
-
child
|
|
8838
|
+
for (const child of Array.from(this.children)) {
|
|
8839
|
+
child?.setExamples(examples, options);
|
|
8796
8840
|
}
|
|
8797
8841
|
}
|
|
8798
8842
|
_setExamples(examples, options) {
|
|
@@ -8825,30 +8869,37 @@ var AxProgramWithSignature = class {
|
|
|
8825
8869
|
if (this.trace) {
|
|
8826
8870
|
traces.push({ trace: this.trace, programId: this.key.id });
|
|
8827
8871
|
}
|
|
8828
|
-
for (const child of this.children) {
|
|
8829
|
-
const _traces = child
|
|
8830
|
-
traces = [...traces, ..._traces];
|
|
8872
|
+
for (const child of Array.from(this.children)) {
|
|
8873
|
+
const _traces = child?.getTraces();
|
|
8874
|
+
traces = [...traces, ..._traces ?? []];
|
|
8831
8875
|
}
|
|
8832
8876
|
return traces;
|
|
8833
8877
|
}
|
|
8834
8878
|
getUsage() {
|
|
8835
8879
|
let usage = [...this.usage ?? []];
|
|
8836
|
-
for (const child of this.children) {
|
|
8837
|
-
const cu = child
|
|
8838
|
-
usage = [...usage, ...cu];
|
|
8880
|
+
for (const child of Array.from(this.children)) {
|
|
8881
|
+
const cu = child?.getUsage();
|
|
8882
|
+
usage = [...usage, ...cu ?? []];
|
|
8839
8883
|
}
|
|
8840
8884
|
return mergeProgramUsage(usage);
|
|
8841
8885
|
}
|
|
8842
8886
|
resetUsage() {
|
|
8843
8887
|
this.usage = [];
|
|
8844
|
-
for (const child of this.children) {
|
|
8845
|
-
child
|
|
8888
|
+
for (const child of Array.from(this.children)) {
|
|
8889
|
+
child?.resetUsage();
|
|
8846
8890
|
}
|
|
8847
8891
|
}
|
|
8848
8892
|
setDemos(demos) {
|
|
8893
|
+
const hasChildren = Array.from(this.children).length > 0;
|
|
8894
|
+
const hasMatchingDemo = demos.some((demo) => demo.programId === this.key.id);
|
|
8895
|
+
if (hasChildren && !hasMatchingDemo) {
|
|
8896
|
+
throw new Error(
|
|
8897
|
+
`Program with id '${this.key.id}' has children but no matching programId found in demos`
|
|
8898
|
+
);
|
|
8899
|
+
}
|
|
8849
8900
|
this.demos = demos.filter((v) => v.programId === this.key.id).map((v) => v.traces).flat();
|
|
8850
|
-
for (const child of this.children) {
|
|
8851
|
-
child
|
|
8901
|
+
for (const child of Array.from(this.children)) {
|
|
8902
|
+
child?.setDemos(demos);
|
|
8852
8903
|
}
|
|
8853
8904
|
}
|
|
8854
8905
|
};
|
|
@@ -8876,8 +8927,8 @@ var AxProgram = class {
|
|
|
8876
8927
|
}
|
|
8877
8928
|
setId(id) {
|
|
8878
8929
|
this.key = { id, custom: true };
|
|
8879
|
-
for (const child of this.children) {
|
|
8880
|
-
child
|
|
8930
|
+
for (const child of Array.from(this.children)) {
|
|
8931
|
+
child?.setParentId(id);
|
|
8881
8932
|
}
|
|
8882
8933
|
}
|
|
8883
8934
|
setParentId(parentId) {
|
|
@@ -8889,8 +8940,8 @@ var AxProgram = class {
|
|
|
8889
8940
|
if (!("programId" in examples)) {
|
|
8890
8941
|
return;
|
|
8891
8942
|
}
|
|
8892
|
-
for (const child of this.children) {
|
|
8893
|
-
child
|
|
8943
|
+
for (const child of Array.from(this.children)) {
|
|
8944
|
+
child?.setExamples(examples, options);
|
|
8894
8945
|
}
|
|
8895
8946
|
}
|
|
8896
8947
|
getTraces() {
|
|
@@ -8898,29 +8949,36 @@ var AxProgram = class {
|
|
|
8898
8949
|
if (this.trace) {
|
|
8899
8950
|
traces.push({ trace: this.trace, programId: this.key.id });
|
|
8900
8951
|
}
|
|
8901
|
-
for (const child of this.children) {
|
|
8902
|
-
const _traces = child
|
|
8903
|
-
traces = [...traces, ..._traces];
|
|
8952
|
+
for (const child of Array.from(this.children)) {
|
|
8953
|
+
const _traces = child?.getTraces();
|
|
8954
|
+
traces = [...traces, ..._traces ?? []];
|
|
8904
8955
|
}
|
|
8905
8956
|
return traces;
|
|
8906
8957
|
}
|
|
8907
8958
|
getUsage() {
|
|
8908
8959
|
let usage = [...this.usage ?? []];
|
|
8909
|
-
for (const child of this.children) {
|
|
8910
|
-
const cu = child
|
|
8911
|
-
usage = [...usage, ...cu];
|
|
8960
|
+
for (const child of Array.from(this.children)) {
|
|
8961
|
+
const cu = child?.getUsage();
|
|
8962
|
+
usage = [...usage, ...cu ?? []];
|
|
8912
8963
|
}
|
|
8913
8964
|
return mergeProgramUsage(usage);
|
|
8914
8965
|
}
|
|
8915
8966
|
resetUsage() {
|
|
8916
8967
|
this.usage = [];
|
|
8917
|
-
for (const child of this.children) {
|
|
8918
|
-
child
|
|
8968
|
+
for (const child of Array.from(this.children)) {
|
|
8969
|
+
child?.resetUsage();
|
|
8919
8970
|
}
|
|
8920
8971
|
}
|
|
8921
8972
|
setDemos(demos) {
|
|
8922
|
-
|
|
8923
|
-
|
|
8973
|
+
const hasChildren = Array.from(this.children).length > 0;
|
|
8974
|
+
const hasMatchingDemo = demos.some((demo) => demo.programId === this.key.id);
|
|
8975
|
+
if (hasChildren && !hasMatchingDemo) {
|
|
8976
|
+
throw new Error(
|
|
8977
|
+
`Program with id '${this.key.id}' has children but no matching programId found in demos`
|
|
8978
|
+
);
|
|
8979
|
+
}
|
|
8980
|
+
for (const child of Array.from(this.children)) {
|
|
8981
|
+
child?.setDemos(demos);
|
|
8924
8982
|
}
|
|
8925
8983
|
}
|
|
8926
8984
|
};
|
|
@@ -8938,11 +8996,9 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
8938
8996
|
values = {};
|
|
8939
8997
|
excludeContentFromTrace = false;
|
|
8940
8998
|
thoughtFieldName;
|
|
8941
|
-
logger;
|
|
8942
8999
|
constructor(signature, options) {
|
|
8943
9000
|
super(signature, { description: options?.description });
|
|
8944
9001
|
this.options = options;
|
|
8945
|
-
this.logger = options?.logger;
|
|
8946
9002
|
this.thoughtFieldName = options?.thoughtFieldName ?? "thought";
|
|
8947
9003
|
const promptTemplateOptions = {
|
|
8948
9004
|
functions: options?.functions,
|
|
@@ -9032,6 +9088,7 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
9032
9088
|
rateLimiter,
|
|
9033
9089
|
stream,
|
|
9034
9090
|
debug: false,
|
|
9091
|
+
// we do our own debug logging
|
|
9035
9092
|
thinkingTokenBudget,
|
|
9036
9093
|
showThoughts,
|
|
9037
9094
|
traceContext,
|
|
@@ -9071,6 +9128,7 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
9071
9128
|
fastFail,
|
|
9072
9129
|
span
|
|
9073
9130
|
});
|
|
9131
|
+
this.getLogger(ai, options)?.("", { tags: ["responseEnd"] });
|
|
9074
9132
|
} else {
|
|
9075
9133
|
yield await this.processResponse({
|
|
9076
9134
|
ai,
|
|
@@ -9107,7 +9165,6 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
9107
9165
|
mem.addResult(
|
|
9108
9166
|
{
|
|
9109
9167
|
content: "",
|
|
9110
|
-
name: "initial",
|
|
9111
9168
|
functionCalls: []
|
|
9112
9169
|
},
|
|
9113
9170
|
sessionId
|
|
@@ -9240,10 +9297,6 @@ Content: ${content}`
|
|
|
9240
9297
|
xstate
|
|
9241
9298
|
);
|
|
9242
9299
|
}
|
|
9243
|
-
if (ai.getOptions().debug) {
|
|
9244
|
-
const logger = ai.getLogger();
|
|
9245
|
-
logger("", { tags: ["responseEnd"] });
|
|
9246
|
-
}
|
|
9247
9300
|
}
|
|
9248
9301
|
async processResponse({
|
|
9249
9302
|
ai,
|
|
@@ -9315,9 +9368,11 @@ Content: ${result.content}`
|
|
|
9315
9368
|
const stopFunction = (options?.stopFunction ?? this.options?.stopFunction)?.toLowerCase();
|
|
9316
9369
|
const maxRetries = options.maxRetries ?? this.options?.maxRetries ?? 10;
|
|
9317
9370
|
const maxSteps = options.maxSteps ?? this.options?.maxSteps ?? 10;
|
|
9318
|
-
const debug = options.debug ?? ai.getOptions().debug;
|
|
9319
9371
|
const debugHideSystemPrompt = options.debugHideSystemPrompt;
|
|
9320
|
-
const memOptions = {
|
|
9372
|
+
const memOptions = {
|
|
9373
|
+
debug: this.isDebug(ai, options),
|
|
9374
|
+
debugHideSystemPrompt
|
|
9375
|
+
};
|
|
9321
9376
|
const mem = options.mem ?? this.options?.mem ?? new AxMemory(1e4, memOptions);
|
|
9322
9377
|
let err;
|
|
9323
9378
|
if (options?.functions && options.functions.length > 0) {
|
|
@@ -9371,10 +9426,7 @@ Content: ${result.content}`
|
|
|
9371
9426
|
if (shouldContinue) {
|
|
9372
9427
|
continue multiStepLoop;
|
|
9373
9428
|
}
|
|
9374
|
-
|
|
9375
|
-
const logger = options.logger ?? this.logger ?? ai.getLogger();
|
|
9376
|
-
logger("", { tags: ["responseEnd"] });
|
|
9377
|
-
}
|
|
9429
|
+
this.getLogger(ai, options)?.("", { tags: ["responseEnd"] });
|
|
9378
9430
|
return;
|
|
9379
9431
|
} catch (e) {
|
|
9380
9432
|
let errorFields;
|
|
@@ -9516,6 +9568,12 @@ Content: ${result.content}`
|
|
|
9516
9568
|
setExamples(examples, options) {
|
|
9517
9569
|
super.setExamples(examples, options);
|
|
9518
9570
|
}
|
|
9571
|
+
isDebug(ai, options) {
|
|
9572
|
+
return options?.debug ?? this.options?.debug ?? ai.getOptions().debug ?? false;
|
|
9573
|
+
}
|
|
9574
|
+
getLogger(ai, options) {
|
|
9575
|
+
return options?.logger ?? this.options?.logger ?? ai.getLogger();
|
|
9576
|
+
}
|
|
9519
9577
|
};
|
|
9520
9578
|
var AxGenerateError = class extends Error {
|
|
9521
9579
|
details;
|
|
@@ -9648,7 +9706,9 @@ var AxAgent = class {
|
|
|
9648
9706
|
description: definition ?? description
|
|
9649
9707
|
});
|
|
9650
9708
|
for (const agent of agents ?? []) {
|
|
9651
|
-
this.program.register(
|
|
9709
|
+
this.program.register(
|
|
9710
|
+
agent
|
|
9711
|
+
);
|
|
9652
9712
|
}
|
|
9653
9713
|
this.name = name;
|
|
9654
9714
|
this.func = {
|
|
@@ -10152,171 +10212,838 @@ function validateModels2(services) {
|
|
|
10152
10212
|
}
|
|
10153
10213
|
}
|
|
10154
10214
|
|
|
10155
|
-
//
|
|
10156
|
-
var
|
|
10157
|
-
|
|
10158
|
-
|
|
10159
|
-
|
|
10160
|
-
|
|
10161
|
-
|
|
10162
|
-
|
|
10163
|
-
|
|
10164
|
-
|
|
10165
|
-
|
|
10166
|
-
|
|
10167
|
-
tracer
|
|
10168
|
-
}) {
|
|
10169
|
-
this.name = name;
|
|
10170
|
-
this.fetch = fetch2;
|
|
10171
|
-
this.tracer = tracer;
|
|
10215
|
+
// dsp/optimizer.ts
|
|
10216
|
+
var AxDefaultCostTracker = class {
|
|
10217
|
+
tokenUsage = {};
|
|
10218
|
+
totalTokens = 0;
|
|
10219
|
+
// Configuration options
|
|
10220
|
+
costPerModel;
|
|
10221
|
+
maxCost;
|
|
10222
|
+
maxTokens;
|
|
10223
|
+
constructor(options) {
|
|
10224
|
+
this.costPerModel = options?.costPerModel ?? {};
|
|
10225
|
+
this.maxCost = options?.maxCost;
|
|
10226
|
+
this.maxTokens = options?.maxTokens;
|
|
10172
10227
|
}
|
|
10173
|
-
|
|
10174
|
-
|
|
10175
|
-
|
|
10176
|
-
}
|
|
10177
|
-
if (!this.tracer) {
|
|
10178
|
-
return await this._upsert(req, update);
|
|
10179
|
-
}
|
|
10180
|
-
return await this.tracer.startActiveSpan(
|
|
10181
|
-
"DB Upsert Request",
|
|
10182
|
-
{
|
|
10183
|
-
kind: import_api23.SpanKind.SERVER,
|
|
10184
|
-
attributes: {
|
|
10185
|
-
[axSpanAttributes.DB_SYSTEM]: this.name,
|
|
10186
|
-
[axSpanAttributes.DB_OPERATION_NAME]: "upsert",
|
|
10187
|
-
[axSpanAttributes.DB_TABLE]: req.table,
|
|
10188
|
-
[axSpanAttributes.DB_NAMESPACE]: req.namespace,
|
|
10189
|
-
[axSpanAttributes.DB_OPERATION_NAME]: update ? "update" : "insert"
|
|
10190
|
-
}
|
|
10191
|
-
},
|
|
10192
|
-
async (span) => {
|
|
10193
|
-
try {
|
|
10194
|
-
return await this._upsert(req, update, { span });
|
|
10195
|
-
} finally {
|
|
10196
|
-
span.end();
|
|
10197
|
-
}
|
|
10198
|
-
}
|
|
10199
|
-
);
|
|
10228
|
+
trackTokens(count, model) {
|
|
10229
|
+
this.tokenUsage[model] = (this.tokenUsage[model] || 0) + count;
|
|
10230
|
+
this.totalTokens += count;
|
|
10200
10231
|
}
|
|
10201
|
-
|
|
10202
|
-
|
|
10203
|
-
|
|
10204
|
-
|
|
10205
|
-
|
|
10206
|
-
throw new Error("Batch request is empty");
|
|
10207
|
-
}
|
|
10208
|
-
if (!req[0]) {
|
|
10209
|
-
throw new Error("Batch request is invalid first element is undefined");
|
|
10210
|
-
}
|
|
10211
|
-
if (!this.tracer) {
|
|
10212
|
-
return await this._batchUpsert(req, update);
|
|
10232
|
+
getCurrentCost() {
|
|
10233
|
+
let totalCost = 0;
|
|
10234
|
+
for (const [model, tokens] of Object.entries(this.tokenUsage)) {
|
|
10235
|
+
const costPer1K = this.costPerModel[model] || 1e-3;
|
|
10236
|
+
totalCost += tokens / 1e3 * costPer1K;
|
|
10213
10237
|
}
|
|
10214
|
-
return
|
|
10215
|
-
"DB Batch Upsert Request",
|
|
10216
|
-
{
|
|
10217
|
-
kind: import_api23.SpanKind.SERVER,
|
|
10218
|
-
attributes: {
|
|
10219
|
-
[axSpanAttributes.DB_SYSTEM]: this.name,
|
|
10220
|
-
[axSpanAttributes.DB_OPERATION_NAME]: "upsert",
|
|
10221
|
-
[axSpanAttributes.DB_TABLE]: req[0].table,
|
|
10222
|
-
[axSpanAttributes.DB_NAMESPACE]: req[0].namespace,
|
|
10223
|
-
[axSpanAttributes.DB_OPERATION_NAME]: update ? "update" : "insert"
|
|
10224
|
-
}
|
|
10225
|
-
},
|
|
10226
|
-
async (span) => {
|
|
10227
|
-
try {
|
|
10228
|
-
return await this._batchUpsert(req, update, { span });
|
|
10229
|
-
} finally {
|
|
10230
|
-
span.end();
|
|
10231
|
-
}
|
|
10232
|
-
}
|
|
10233
|
-
);
|
|
10238
|
+
return totalCost;
|
|
10234
10239
|
}
|
|
10235
|
-
|
|
10236
|
-
|
|
10237
|
-
|
|
10238
|
-
|
|
10239
|
-
|
|
10240
|
-
|
|
10240
|
+
getTokenUsage() {
|
|
10241
|
+
return { ...this.tokenUsage };
|
|
10242
|
+
}
|
|
10243
|
+
getTotalTokens() {
|
|
10244
|
+
return this.totalTokens;
|
|
10245
|
+
}
|
|
10246
|
+
isLimitReached() {
|
|
10247
|
+
if (this.maxTokens !== void 0 && this.totalTokens >= this.maxTokens) {
|
|
10248
|
+
return true;
|
|
10241
10249
|
}
|
|
10242
|
-
|
|
10243
|
-
|
|
10244
|
-
{
|
|
10245
|
-
|
|
10246
|
-
attributes: {
|
|
10247
|
-
[axSpanAttributes.DB_SYSTEM]: this.name,
|
|
10248
|
-
[axSpanAttributes.DB_OPERATION_NAME]: "upsert",
|
|
10249
|
-
[axSpanAttributes.DB_TABLE]: req.table,
|
|
10250
|
-
[axSpanAttributes.DB_NAMESPACE]: req.namespace,
|
|
10251
|
-
[axSpanAttributes.DB_OPERATION_NAME]: "query"
|
|
10252
|
-
}
|
|
10253
|
-
},
|
|
10254
|
-
async (span) => {
|
|
10255
|
-
try {
|
|
10256
|
-
return await this._query(req, { span });
|
|
10257
|
-
} finally {
|
|
10258
|
-
span.end();
|
|
10259
|
-
}
|
|
10250
|
+
if (this.maxCost !== void 0) {
|
|
10251
|
+
const currentCost = this.getCurrentCost();
|
|
10252
|
+
if (currentCost >= this.maxCost) {
|
|
10253
|
+
return true;
|
|
10260
10254
|
}
|
|
10261
|
-
|
|
10255
|
+
}
|
|
10256
|
+
return false;
|
|
10257
|
+
}
|
|
10258
|
+
reset() {
|
|
10259
|
+
this.tokenUsage = {};
|
|
10260
|
+
this.totalTokens = 0;
|
|
10262
10261
|
}
|
|
10263
10262
|
};
|
|
10264
|
-
|
|
10265
|
-
//
|
|
10266
|
-
|
|
10267
|
-
|
|
10268
|
-
|
|
10269
|
-
|
|
10270
|
-
|
|
10271
|
-
|
|
10272
|
-
|
|
10273
|
-
|
|
10274
|
-
|
|
10275
|
-
|
|
10276
|
-
|
|
10277
|
-
|
|
10263
|
+
var AxBaseOptimizer = class {
|
|
10264
|
+
// Common AxOptimizerArgs fields
|
|
10265
|
+
studentAI;
|
|
10266
|
+
teacherAI;
|
|
10267
|
+
examples;
|
|
10268
|
+
validationSet;
|
|
10269
|
+
targetScore;
|
|
10270
|
+
minSuccessRate;
|
|
10271
|
+
onProgress;
|
|
10272
|
+
onEarlyStop;
|
|
10273
|
+
costTracker;
|
|
10274
|
+
seed;
|
|
10275
|
+
// Checkpointing fields
|
|
10276
|
+
checkpointSave;
|
|
10277
|
+
checkpointLoad;
|
|
10278
|
+
checkpointInterval;
|
|
10279
|
+
resumeFromCheckpoint;
|
|
10280
|
+
// Checkpoint state
|
|
10281
|
+
currentRound = 0;
|
|
10282
|
+
scoreHistory = [];
|
|
10283
|
+
configurationHistory = [];
|
|
10284
|
+
// Common optimization statistics
|
|
10285
|
+
stats;
|
|
10286
|
+
constructor(args) {
|
|
10287
|
+
if (args.examples.length === 0) {
|
|
10288
|
+
throw new Error("No examples found");
|
|
10278
10289
|
}
|
|
10279
|
-
|
|
10280
|
-
this.
|
|
10281
|
-
this.
|
|
10290
|
+
this.studentAI = args.studentAI;
|
|
10291
|
+
this.teacherAI = args.teacherAI;
|
|
10292
|
+
this.examples = args.examples;
|
|
10293
|
+
this.validationSet = args.validationSet;
|
|
10294
|
+
this.targetScore = args.targetScore;
|
|
10295
|
+
this.minSuccessRate = args.minSuccessRate;
|
|
10296
|
+
this.onProgress = args.onProgress;
|
|
10297
|
+
this.onEarlyStop = args.onEarlyStop;
|
|
10298
|
+
this.seed = args.seed;
|
|
10299
|
+
this.checkpointSave = args.checkpointSave;
|
|
10300
|
+
this.checkpointLoad = args.checkpointLoad;
|
|
10301
|
+
this.checkpointInterval = args.checkpointInterval ?? 10;
|
|
10302
|
+
this.resumeFromCheckpoint = args.resumeFromCheckpoint;
|
|
10303
|
+
const costTracker = new AxDefaultCostTracker({
|
|
10304
|
+
maxTokens: 1e6
|
|
10305
|
+
});
|
|
10306
|
+
this.costTracker = args.costTracker ?? costTracker;
|
|
10307
|
+
this.stats = this.initializeStats();
|
|
10282
10308
|
}
|
|
10283
|
-
|
|
10284
|
-
|
|
10285
|
-
|
|
10286
|
-
|
|
10287
|
-
|
|
10288
|
-
|
|
10289
|
-
|
|
10290
|
-
|
|
10291
|
-
|
|
10292
|
-
|
|
10293
|
-
|
|
10294
|
-
|
|
10309
|
+
/**
|
|
10310
|
+
* Initialize the optimization statistics structure
|
|
10311
|
+
*/
|
|
10312
|
+
initializeStats() {
|
|
10313
|
+
return {
|
|
10314
|
+
totalCalls: 0,
|
|
10315
|
+
successfulDemos: 0,
|
|
10316
|
+
estimatedTokenUsage: 0,
|
|
10317
|
+
earlyStopped: false,
|
|
10318
|
+
resourceUsage: {
|
|
10319
|
+
totalTokens: 0,
|
|
10320
|
+
totalTime: 0,
|
|
10321
|
+
avgLatencyPerEval: 0,
|
|
10322
|
+
costByModel: {}
|
|
10295
10323
|
},
|
|
10296
|
-
{
|
|
10297
|
-
|
|
10298
|
-
|
|
10299
|
-
|
|
10300
|
-
|
|
10324
|
+
convergenceInfo: {
|
|
10325
|
+
converged: false,
|
|
10326
|
+
finalImprovement: 0,
|
|
10327
|
+
stagnationRounds: 0,
|
|
10328
|
+
convergenceThreshold: 0.01
|
|
10301
10329
|
}
|
|
10302
|
-
);
|
|
10303
|
-
if (res.errors) {
|
|
10304
|
-
throw new Error(
|
|
10305
|
-
`Cloudflare upsert failed: ${res.errors.map(({ message }) => message).join(", ")}`
|
|
10306
|
-
);
|
|
10307
|
-
}
|
|
10308
|
-
return {
|
|
10309
|
-
ids: res.result.ids
|
|
10310
10330
|
};
|
|
10311
|
-
}
|
|
10312
|
-
|
|
10313
|
-
|
|
10314
|
-
|
|
10331
|
+
}
|
|
10332
|
+
/**
|
|
10333
|
+
* Set up reproducible random seed if provided
|
|
10334
|
+
*/
|
|
10335
|
+
setupRandomSeed() {
|
|
10336
|
+
if (this.seed !== void 0) {
|
|
10337
|
+
Math.random = (() => {
|
|
10338
|
+
let seed = this.seed;
|
|
10339
|
+
return () => {
|
|
10340
|
+
seed = (seed * 9301 + 49297) % 233280;
|
|
10341
|
+
return seed / 233280;
|
|
10342
|
+
};
|
|
10343
|
+
})();
|
|
10315
10344
|
}
|
|
10316
|
-
|
|
10317
|
-
|
|
10345
|
+
}
|
|
10346
|
+
/**
|
|
10347
|
+
* Check if optimization should stop early due to cost limits
|
|
10348
|
+
*/
|
|
10349
|
+
checkCostLimits() {
|
|
10350
|
+
return this.costTracker?.isLimitReached() ?? false;
|
|
10351
|
+
}
|
|
10352
|
+
/**
|
|
10353
|
+
* Check if target score has been reached
|
|
10354
|
+
*/
|
|
10355
|
+
checkTargetScore(currentScore) {
|
|
10356
|
+
return this.targetScore !== void 0 && currentScore >= this.targetScore;
|
|
10357
|
+
}
|
|
10358
|
+
/**
|
|
10359
|
+
* Update resource usage statistics
|
|
10360
|
+
*/
|
|
10361
|
+
updateResourceUsage(startTime, tokensUsed = 0) {
|
|
10362
|
+
this.stats.resourceUsage.totalTime = Date.now() - startTime;
|
|
10363
|
+
this.stats.resourceUsage.totalTokens += tokensUsed;
|
|
10364
|
+
if (this.stats.totalCalls > 0) {
|
|
10365
|
+
this.stats.resourceUsage.avgLatencyPerEval = this.stats.resourceUsage.totalTime / this.stats.totalCalls;
|
|
10318
10366
|
}
|
|
10319
|
-
|
|
10367
|
+
}
|
|
10368
|
+
/**
|
|
10369
|
+
* Trigger early stopping with appropriate callbacks
|
|
10370
|
+
*/
|
|
10371
|
+
triggerEarlyStopping(reason, bestScoreRound) {
|
|
10372
|
+
this.stats.earlyStopped = true;
|
|
10373
|
+
this.stats.earlyStopping = {
|
|
10374
|
+
bestScoreRound,
|
|
10375
|
+
patienceExhausted: reason.includes("improvement"),
|
|
10376
|
+
reason
|
|
10377
|
+
};
|
|
10378
|
+
if (this.onEarlyStop) {
|
|
10379
|
+
this.onEarlyStop(reason, this.stats);
|
|
10380
|
+
}
|
|
10381
|
+
}
|
|
10382
|
+
/**
|
|
10383
|
+
* Get the validation set, with fallback to a split of examples
|
|
10384
|
+
*/
|
|
10385
|
+
getValidationSet(options) {
|
|
10386
|
+
return options?.overrideValidationSet || this.validationSet || this.examples.slice(0, Math.floor(this.examples.length * 0.2));
|
|
10387
|
+
}
|
|
10388
|
+
/**
|
|
10389
|
+
* Get the AI service to use for a specific task, preferring teacher when available
|
|
10390
|
+
* @param preferTeacher Whether to prefer teacher AI over student AI
|
|
10391
|
+
* @param options Optional compile options that may override teacher AI
|
|
10392
|
+
* @returns The appropriate AI service to use
|
|
10393
|
+
*/
|
|
10394
|
+
getAIService(preferTeacher = false, options) {
|
|
10395
|
+
if (preferTeacher && options?.overrideTeacherAI) {
|
|
10396
|
+
return options.overrideTeacherAI;
|
|
10397
|
+
}
|
|
10398
|
+
if (preferTeacher && this.teacherAI) {
|
|
10399
|
+
return this.teacherAI;
|
|
10400
|
+
}
|
|
10401
|
+
return this.studentAI;
|
|
10402
|
+
}
|
|
10403
|
+
/**
|
|
10404
|
+
* Check if teacher AI is available (including overrides)
|
|
10405
|
+
* @param options Optional compile options that may override teacher AI
|
|
10406
|
+
* @returns True if teacher AI is configured or overridden
|
|
10407
|
+
*/
|
|
10408
|
+
hasTeacherAI(options) {
|
|
10409
|
+
return options?.overrideTeacherAI !== void 0 || this.teacherAI !== void 0;
|
|
10410
|
+
}
|
|
10411
|
+
/**
|
|
10412
|
+
* Get teacher AI if available, otherwise return student AI
|
|
10413
|
+
* @param options Optional compile options that may override teacher AI
|
|
10414
|
+
* @returns Teacher AI if available, otherwise student AI
|
|
10415
|
+
*/
|
|
10416
|
+
getTeacherOrStudentAI(options) {
|
|
10417
|
+
return options?.overrideTeacherAI || this.teacherAI || this.studentAI;
|
|
10418
|
+
}
|
|
10419
|
+
/**
|
|
10420
|
+
* Execute a task with teacher AI if available, otherwise use student AI
|
|
10421
|
+
* @param task Function that takes an AI service and returns a promise
|
|
10422
|
+
* @param preferTeacher Whether to prefer teacher AI (default: true)
|
|
10423
|
+
* @param options Optional compile options that may override teacher AI
|
|
10424
|
+
* @returns Result of the task execution
|
|
10425
|
+
*/
|
|
10426
|
+
async executeWithTeacher(task, preferTeacher = true, options) {
|
|
10427
|
+
const ai = this.getAIService(preferTeacher, options);
|
|
10428
|
+
return await task(ai);
|
|
10429
|
+
}
|
|
10430
|
+
/**
|
|
10431
|
+
* Get current optimization statistics
|
|
10432
|
+
*/
|
|
10433
|
+
getStats() {
|
|
10434
|
+
return { ...this.stats };
|
|
10435
|
+
}
|
|
10436
|
+
/**
|
|
10437
|
+
* Reset optimizer state for reuse with different programs
|
|
10438
|
+
*/
|
|
10439
|
+
reset() {
|
|
10440
|
+
this.stats = this.initializeStats();
|
|
10441
|
+
this.costTracker?.reset();
|
|
10442
|
+
this.currentRound = 0;
|
|
10443
|
+
this.scoreHistory = [];
|
|
10444
|
+
this.configurationHistory = [];
|
|
10445
|
+
}
|
|
10446
|
+
/**
|
|
10447
|
+
* Basic program validation that can be extended by concrete optimizers
|
|
10448
|
+
*/
|
|
10449
|
+
validateProgram(program) {
|
|
10450
|
+
const issues = [];
|
|
10451
|
+
const suggestions = [];
|
|
10452
|
+
if (!("forward" in program) || typeof program.forward !== "function") {
|
|
10453
|
+
issues.push("Program must have a forward method");
|
|
10454
|
+
}
|
|
10455
|
+
if (this.examples.length < 2) {
|
|
10456
|
+
issues.push("Need at least 2 examples for optimization");
|
|
10457
|
+
suggestions.push("Provide more training examples");
|
|
10458
|
+
}
|
|
10459
|
+
const valSetSize = this.getValidationSet().length;
|
|
10460
|
+
if (valSetSize < 1) {
|
|
10461
|
+
issues.push("Validation set is empty");
|
|
10462
|
+
suggestions.push("Provide examples or a validation set");
|
|
10463
|
+
}
|
|
10464
|
+
return {
|
|
10465
|
+
isValid: issues.length === 0,
|
|
10466
|
+
issues,
|
|
10467
|
+
suggestions
|
|
10468
|
+
};
|
|
10469
|
+
}
|
|
10470
|
+
/**
|
|
10471
|
+
* Multi-objective optimization using Pareto frontier
|
|
10472
|
+
* Default implementation that leverages the single-objective compile method
|
|
10473
|
+
* @param program The program to optimize
|
|
10474
|
+
* @param metricFn Multi-objective metric function that returns multiple scores
|
|
10475
|
+
* @param options Optional configuration options
|
|
10476
|
+
* @returns Pareto optimization result with frontier of non-dominated solutions
|
|
10477
|
+
*/
|
|
10478
|
+
async compilePareto(program, metricFn, options) {
|
|
10479
|
+
const startTime = Date.now();
|
|
10480
|
+
if (options?.verbose) {
|
|
10481
|
+
console.log("Starting Pareto optimization using base implementation");
|
|
10482
|
+
console.log("This will run multiple single-objective optimizations");
|
|
10483
|
+
}
|
|
10484
|
+
const solutions = await this.generateWeightedSolutions(
|
|
10485
|
+
program,
|
|
10486
|
+
metricFn,
|
|
10487
|
+
options
|
|
10488
|
+
);
|
|
10489
|
+
const constraintSolutions = await this.generateConstraintSolutions(
|
|
10490
|
+
program,
|
|
10491
|
+
metricFn,
|
|
10492
|
+
options
|
|
10493
|
+
);
|
|
10494
|
+
const allSolutions = [...solutions, ...constraintSolutions];
|
|
10495
|
+
if (options?.verbose) {
|
|
10496
|
+
console.log(`Generated ${allSolutions.length} candidate solutions`);
|
|
10497
|
+
}
|
|
10498
|
+
const paretoFront = this.findParetoFrontier(allSolutions);
|
|
10499
|
+
const hypervolume = this.calculateHypervolume(paretoFront);
|
|
10500
|
+
if (options?.verbose) {
|
|
10501
|
+
console.log(`Found ${paretoFront.length} non-dominated solutions`);
|
|
10502
|
+
console.log(`Hypervolume: ${hypervolume?.toFixed(4) || "N/A"}`);
|
|
10503
|
+
}
|
|
10504
|
+
this.updateResourceUsage(startTime);
|
|
10505
|
+
this.stats.convergenceInfo.converged = true;
|
|
10506
|
+
const bestScore = paretoFront.length > 0 ? Math.max(
|
|
10507
|
+
...paretoFront.map((sol) => Math.max(...Object.values(sol.scores)))
|
|
10508
|
+
) : 0;
|
|
10509
|
+
return {
|
|
10510
|
+
demos: paretoFront.length > 0 ? [...paretoFront[0].demos] : void 0,
|
|
10511
|
+
stats: this.stats,
|
|
10512
|
+
bestScore,
|
|
10513
|
+
paretoFront,
|
|
10514
|
+
hypervolume,
|
|
10515
|
+
paretoFrontSize: paretoFront.length,
|
|
10516
|
+
finalConfiguration: {
|
|
10517
|
+
paretoFrontSize: paretoFront.length,
|
|
10518
|
+
hypervolume,
|
|
10519
|
+
strategy: "weighted_combinations_and_constraints",
|
|
10520
|
+
numSolutions: allSolutions.length
|
|
10521
|
+
}
|
|
10522
|
+
};
|
|
10523
|
+
}
|
|
10524
|
+
/**
|
|
10525
|
+
* Generate solutions using different weighted combinations of objectives
|
|
10526
|
+
*/
|
|
10527
|
+
async generateWeightedSolutions(program, metricFn, options) {
|
|
10528
|
+
const solutions = [];
|
|
10529
|
+
const sampleExample = this.examples[0];
|
|
10530
|
+
const samplePrediction = await program.forward(
|
|
10531
|
+
this.studentAI,
|
|
10532
|
+
sampleExample
|
|
10533
|
+
);
|
|
10534
|
+
const sampleScores = await metricFn({
|
|
10535
|
+
prediction: samplePrediction,
|
|
10536
|
+
example: sampleExample
|
|
10537
|
+
});
|
|
10538
|
+
const objectives = Object.keys(sampleScores);
|
|
10539
|
+
if (options?.verbose) {
|
|
10540
|
+
console.log(`Detected objectives: ${objectives.join(", ")}`);
|
|
10541
|
+
}
|
|
10542
|
+
const weightCombinations = this.generateWeightCombinations(objectives);
|
|
10543
|
+
for (let i = 0; i < weightCombinations.length; i++) {
|
|
10544
|
+
const weights = weightCombinations[i];
|
|
10545
|
+
if (options?.verbose) {
|
|
10546
|
+
console.log(`Optimizing with weights: ${JSON.stringify(weights)}`);
|
|
10547
|
+
}
|
|
10548
|
+
const weightedMetric = async ({ prediction, example }) => {
|
|
10549
|
+
const scores = await metricFn({ prediction, example });
|
|
10550
|
+
let weightedScore = 0;
|
|
10551
|
+
for (const [objective, score] of Object.entries(scores)) {
|
|
10552
|
+
weightedScore += score * (weights[objective] || 0);
|
|
10553
|
+
}
|
|
10554
|
+
return weightedScore;
|
|
10555
|
+
};
|
|
10556
|
+
try {
|
|
10557
|
+
const result = await this.compile(program, weightedMetric, {
|
|
10558
|
+
...options,
|
|
10559
|
+
verbose: false
|
|
10560
|
+
// Suppress inner optimization logs
|
|
10561
|
+
});
|
|
10562
|
+
const scores = await this.evaluateWithMultiObjective(
|
|
10563
|
+
program,
|
|
10564
|
+
result,
|
|
10565
|
+
metricFn
|
|
10566
|
+
);
|
|
10567
|
+
solutions.push({
|
|
10568
|
+
scores,
|
|
10569
|
+
demos: result.demos,
|
|
10570
|
+
configuration: {
|
|
10571
|
+
...result.finalConfiguration,
|
|
10572
|
+
weights,
|
|
10573
|
+
strategy: "weighted_combination"
|
|
10574
|
+
}
|
|
10575
|
+
});
|
|
10576
|
+
} catch (error) {
|
|
10577
|
+
if (options?.verbose) {
|
|
10578
|
+
console.warn(
|
|
10579
|
+
`Failed optimization with weights ${JSON.stringify(weights)}:`,
|
|
10580
|
+
error
|
|
10581
|
+
);
|
|
10582
|
+
}
|
|
10583
|
+
continue;
|
|
10584
|
+
}
|
|
10585
|
+
}
|
|
10586
|
+
return solutions;
|
|
10587
|
+
}
|
|
10588
|
+
/**
|
|
10589
|
+
* Generate solutions using constraint-based optimization
|
|
10590
|
+
*/
|
|
10591
|
+
async generateConstraintSolutions(program, metricFn, options) {
|
|
10592
|
+
const solutions = [];
|
|
10593
|
+
const sampleExample = this.examples[0];
|
|
10594
|
+
const samplePrediction = await program.forward(
|
|
10595
|
+
this.studentAI,
|
|
10596
|
+
sampleExample
|
|
10597
|
+
);
|
|
10598
|
+
const sampleScores = await metricFn({
|
|
10599
|
+
prediction: samplePrediction,
|
|
10600
|
+
example: sampleExample
|
|
10601
|
+
});
|
|
10602
|
+
const objectives = Object.keys(sampleScores);
|
|
10603
|
+
for (const primaryObjective of objectives) {
|
|
10604
|
+
if (options?.verbose) {
|
|
10605
|
+
console.log(
|
|
10606
|
+
`Optimizing ${primaryObjective} with constraints on other objectives`
|
|
10607
|
+
);
|
|
10608
|
+
}
|
|
10609
|
+
const constraintMetric = async ({ prediction, example }) => {
|
|
10610
|
+
const scores = await metricFn({ prediction, example });
|
|
10611
|
+
const primaryScore = scores[primaryObjective] || 0;
|
|
10612
|
+
let penalty = 0;
|
|
10613
|
+
for (const [objective, score] of Object.entries(scores)) {
|
|
10614
|
+
if (objective !== primaryObjective) {
|
|
10615
|
+
if (score < 0.3) {
|
|
10616
|
+
penalty += (0.3 - score) * 2;
|
|
10617
|
+
}
|
|
10618
|
+
}
|
|
10619
|
+
}
|
|
10620
|
+
return primaryScore - penalty;
|
|
10621
|
+
};
|
|
10622
|
+
try {
|
|
10623
|
+
const result = await this.compile(program, constraintMetric, {
|
|
10624
|
+
...options,
|
|
10625
|
+
verbose: false
|
|
10626
|
+
});
|
|
10627
|
+
const scores = await this.evaluateWithMultiObjective(
|
|
10628
|
+
program,
|
|
10629
|
+
result,
|
|
10630
|
+
metricFn
|
|
10631
|
+
);
|
|
10632
|
+
solutions.push({
|
|
10633
|
+
scores,
|
|
10634
|
+
demos: result.demos,
|
|
10635
|
+
configuration: {
|
|
10636
|
+
...result.finalConfiguration,
|
|
10637
|
+
primaryObjective,
|
|
10638
|
+
strategy: "constraint_based"
|
|
10639
|
+
}
|
|
10640
|
+
});
|
|
10641
|
+
} catch (error) {
|
|
10642
|
+
if (options?.verbose) {
|
|
10643
|
+
console.warn(
|
|
10644
|
+
`Failed constraint optimization for ${primaryObjective}:`,
|
|
10645
|
+
error
|
|
10646
|
+
);
|
|
10647
|
+
}
|
|
10648
|
+
continue;
|
|
10649
|
+
}
|
|
10650
|
+
}
|
|
10651
|
+
return solutions;
|
|
10652
|
+
}
|
|
10653
|
+
/**
|
|
10654
|
+
* Generate different weight combinations for objectives
|
|
10655
|
+
*/
|
|
10656
|
+
generateWeightCombinations(objectives) {
|
|
10657
|
+
const combinations = [];
|
|
10658
|
+
for (const objective of objectives) {
|
|
10659
|
+
const weights = {};
|
|
10660
|
+
for (const obj of objectives) {
|
|
10661
|
+
weights[obj] = obj === objective ? 1 : 0;
|
|
10662
|
+
}
|
|
10663
|
+
combinations.push(weights);
|
|
10664
|
+
}
|
|
10665
|
+
const equalWeights = {};
|
|
10666
|
+
for (const objective of objectives) {
|
|
10667
|
+
equalWeights[objective] = 1 / objectives.length;
|
|
10668
|
+
}
|
|
10669
|
+
combinations.push(equalWeights);
|
|
10670
|
+
if (objectives.length === 2) {
|
|
10671
|
+
const [obj1, obj2] = objectives;
|
|
10672
|
+
for (let w1 = 0.1; w1 <= 0.9; w1 += 0.2) {
|
|
10673
|
+
const w2 = 1 - w1;
|
|
10674
|
+
combinations.push({ [obj1]: w1, [obj2]: w2 });
|
|
10675
|
+
}
|
|
10676
|
+
}
|
|
10677
|
+
if (objectives.length === 3) {
|
|
10678
|
+
const [obj1, obj2, obj3] = objectives;
|
|
10679
|
+
combinations.push(
|
|
10680
|
+
{ [obj1]: 0.5, [obj2]: 0.3, [obj3]: 0.2 },
|
|
10681
|
+
{ [obj1]: 0.3, [obj2]: 0.5, [obj3]: 0.2 },
|
|
10682
|
+
{ [obj1]: 0.2, [obj2]: 0.3, [obj3]: 0.5 }
|
|
10683
|
+
);
|
|
10684
|
+
}
|
|
10685
|
+
return combinations;
|
|
10686
|
+
}
|
|
10687
|
+
/**
|
|
10688
|
+
* Evaluate a single-objective result with multi-objective metrics
|
|
10689
|
+
*/
|
|
10690
|
+
async evaluateWithMultiObjective(program, result, metricFn) {
|
|
10691
|
+
const valSet = this.getValidationSet();
|
|
10692
|
+
const allScores = {};
|
|
10693
|
+
const testProgram = { ...program };
|
|
10694
|
+
if (result.demos && "setDemos" in testProgram) {
|
|
10695
|
+
;
|
|
10696
|
+
testProgram.setDemos(result.demos);
|
|
10697
|
+
}
|
|
10698
|
+
const evalSet = valSet.slice(0, Math.min(5, valSet.length));
|
|
10699
|
+
for (const example of evalSet) {
|
|
10700
|
+
try {
|
|
10701
|
+
const prediction = await testProgram.forward(
|
|
10702
|
+
this.studentAI,
|
|
10703
|
+
example
|
|
10704
|
+
);
|
|
10705
|
+
const scores = await metricFn({ prediction, example });
|
|
10706
|
+
for (const [objective, score] of Object.entries(scores)) {
|
|
10707
|
+
if (!allScores[objective]) {
|
|
10708
|
+
allScores[objective] = [];
|
|
10709
|
+
}
|
|
10710
|
+
allScores[objective].push(score);
|
|
10711
|
+
}
|
|
10712
|
+
} catch {
|
|
10713
|
+
continue;
|
|
10714
|
+
}
|
|
10715
|
+
}
|
|
10716
|
+
const avgScores = {};
|
|
10717
|
+
for (const [objective, scores] of Object.entries(allScores)) {
|
|
10718
|
+
avgScores[objective] = scores.length > 0 ? scores.reduce((sum, score) => sum + score, 0) / scores.length : 0;
|
|
10719
|
+
}
|
|
10720
|
+
return avgScores;
|
|
10721
|
+
}
|
|
10722
|
+
/**
|
|
10723
|
+
* Find the Pareto frontier from a set of solutions
|
|
10724
|
+
*/
|
|
10725
|
+
findParetoFrontier(solutions) {
|
|
10726
|
+
const paretoFront = [];
|
|
10727
|
+
for (let i = 0; i < solutions.length; i++) {
|
|
10728
|
+
const solutionA = solutions[i];
|
|
10729
|
+
let isDominated = false;
|
|
10730
|
+
let dominatedCount = 0;
|
|
10731
|
+
for (let j = 0; j < solutions.length; j++) {
|
|
10732
|
+
if (i === j) continue;
|
|
10733
|
+
const solutionB = solutions[j];
|
|
10734
|
+
if (this.dominates(solutionB.scores, solutionA.scores)) {
|
|
10735
|
+
isDominated = true;
|
|
10736
|
+
break;
|
|
10737
|
+
}
|
|
10738
|
+
if (this.dominates(solutionA.scores, solutionB.scores)) {
|
|
10739
|
+
dominatedCount++;
|
|
10740
|
+
}
|
|
10741
|
+
}
|
|
10742
|
+
if (!isDominated) {
|
|
10743
|
+
paretoFront.push({
|
|
10744
|
+
demos: solutionA.demos || [],
|
|
10745
|
+
scores: solutionA.scores,
|
|
10746
|
+
configuration: solutionA.configuration,
|
|
10747
|
+
dominatedSolutions: dominatedCount
|
|
10748
|
+
});
|
|
10749
|
+
}
|
|
10750
|
+
}
|
|
10751
|
+
return paretoFront;
|
|
10752
|
+
}
|
|
10753
|
+
/**
|
|
10754
|
+
* Check if solution A dominates solution B
|
|
10755
|
+
* A dominates B if A is better or equal in all objectives and strictly better in at least one
|
|
10756
|
+
*/
|
|
10757
|
+
dominates(scoresA, scoresB) {
|
|
10758
|
+
const objectives = Object.keys(scoresA);
|
|
10759
|
+
let atLeastAsGood = true;
|
|
10760
|
+
let strictlyBetter = false;
|
|
10761
|
+
for (const objective of objectives) {
|
|
10762
|
+
const scoreA = scoresA[objective] || 0;
|
|
10763
|
+
const scoreB = scoresB[objective] || 0;
|
|
10764
|
+
if (scoreA < scoreB) {
|
|
10765
|
+
atLeastAsGood = false;
|
|
10766
|
+
break;
|
|
10767
|
+
}
|
|
10768
|
+
if (scoreA > scoreB) {
|
|
10769
|
+
strictlyBetter = true;
|
|
10770
|
+
}
|
|
10771
|
+
}
|
|
10772
|
+
return atLeastAsGood && strictlyBetter;
|
|
10773
|
+
}
|
|
10774
|
+
/**
|
|
10775
|
+
* Calculate hypervolume of the Pareto frontier
|
|
10776
|
+
* Simplified implementation using reference point at origin
|
|
10777
|
+
*/
|
|
10778
|
+
calculateHypervolume(paretoFront) {
|
|
10779
|
+
if (paretoFront.length === 0) return void 0;
|
|
10780
|
+
const firstSolution = paretoFront[0];
|
|
10781
|
+
const objectives = Object.keys(firstSolution.scores);
|
|
10782
|
+
if (objectives.length === 2) {
|
|
10783
|
+
const [obj1, obj2] = objectives;
|
|
10784
|
+
let hypervolume = 0;
|
|
10785
|
+
const sortedSolutions = [...paretoFront].sort(
|
|
10786
|
+
(a, b) => (b.scores[obj1] || 0) - (a.scores[obj1] || 0)
|
|
10787
|
+
);
|
|
10788
|
+
let prevScore2 = 0;
|
|
10789
|
+
for (const solution of sortedSolutions) {
|
|
10790
|
+
const score1 = solution.scores[obj1] || 0;
|
|
10791
|
+
const score2 = solution.scores[obj2] || 0;
|
|
10792
|
+
hypervolume += score1 * (score2 - prevScore2);
|
|
10793
|
+
prevScore2 = Math.max(prevScore2, score2);
|
|
10794
|
+
}
|
|
10795
|
+
return hypervolume;
|
|
10796
|
+
}
|
|
10797
|
+
return void 0;
|
|
10798
|
+
}
|
|
10799
|
+
/**
|
|
10800
|
+
* Save current optimization state to checkpoint
|
|
10801
|
+
*/
|
|
10802
|
+
async saveCheckpoint(optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
|
|
10803
|
+
const saveFn = options?.overrideCheckpointSave || this.checkpointSave;
|
|
10804
|
+
if (!saveFn) return void 0;
|
|
10805
|
+
const checkpoint = {
|
|
10806
|
+
version: "1.0.0",
|
|
10807
|
+
timestamp: Date.now(),
|
|
10808
|
+
optimizerType,
|
|
10809
|
+
optimizerConfig,
|
|
10810
|
+
currentRound: this.currentRound,
|
|
10811
|
+
totalRounds: this.stats.resourceUsage.totalTime > 0 ? this.currentRound : 0,
|
|
10812
|
+
bestScore,
|
|
10813
|
+
bestConfiguration,
|
|
10814
|
+
scoreHistory: [...this.scoreHistory],
|
|
10815
|
+
configurationHistory: [...this.configurationHistory],
|
|
10816
|
+
stats: { ...this.stats },
|
|
10817
|
+
optimizerState,
|
|
10818
|
+
examples: this.examples,
|
|
10819
|
+
validationSet: this.validationSet
|
|
10820
|
+
};
|
|
10821
|
+
return await saveFn(checkpoint);
|
|
10822
|
+
}
|
|
10823
|
+
/**
|
|
10824
|
+
* Load optimization state from checkpoint
|
|
10825
|
+
*/
|
|
10826
|
+
async loadCheckpoint(checkpointId, options) {
|
|
10827
|
+
const loadFn = options?.overrideCheckpointLoad || this.checkpointLoad;
|
|
10828
|
+
if (!loadFn) return null;
|
|
10829
|
+
return await loadFn(checkpointId);
|
|
10830
|
+
}
|
|
10831
|
+
/**
|
|
10832
|
+
* Restore optimizer state from checkpoint
|
|
10833
|
+
*/
|
|
10834
|
+
restoreFromCheckpoint(checkpoint) {
|
|
10835
|
+
this.currentRound = checkpoint.currentRound;
|
|
10836
|
+
this.scoreHistory = [...checkpoint.scoreHistory];
|
|
10837
|
+
this.configurationHistory = [...checkpoint.configurationHistory];
|
|
10838
|
+
this.stats = { ...checkpoint.stats };
|
|
10839
|
+
}
|
|
10840
|
+
/**
|
|
10841
|
+
* Check if checkpoint should be saved
|
|
10842
|
+
*/
|
|
10843
|
+
shouldSaveCheckpoint(round, options) {
|
|
10844
|
+
const interval = options?.overrideCheckpointInterval || this.checkpointInterval;
|
|
10845
|
+
return interval !== void 0 && round % interval === 0;
|
|
10846
|
+
}
|
|
10847
|
+
/**
|
|
10848
|
+
* Update optimization progress and handle checkpointing
|
|
10849
|
+
*/
|
|
10850
|
+
async updateOptimizationProgress(round, score, configuration, optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
|
|
10851
|
+
this.currentRound = round;
|
|
10852
|
+
this.scoreHistory.push(score);
|
|
10853
|
+
this.configurationHistory.push(configuration);
|
|
10854
|
+
if (this.shouldSaveCheckpoint(round, options)) {
|
|
10855
|
+
await this.saveCheckpoint(
|
|
10856
|
+
optimizerType,
|
|
10857
|
+
optimizerConfig,
|
|
10858
|
+
bestScore,
|
|
10859
|
+
bestConfiguration,
|
|
10860
|
+
optimizerState,
|
|
10861
|
+
options
|
|
10862
|
+
);
|
|
10863
|
+
}
|
|
10864
|
+
}
|
|
10865
|
+
/**
|
|
10866
|
+
* Save final checkpoint on completion
|
|
10867
|
+
*/
|
|
10868
|
+
async saveFinalCheckpoint(optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
|
|
10869
|
+
if (options?.saveCheckpointOnComplete !== false) {
|
|
10870
|
+
await this.saveCheckpoint(
|
|
10871
|
+
optimizerType,
|
|
10872
|
+
optimizerConfig,
|
|
10873
|
+
bestScore,
|
|
10874
|
+
bestConfiguration,
|
|
10875
|
+
{ ...optimizerState, final: true },
|
|
10876
|
+
options
|
|
10877
|
+
);
|
|
10878
|
+
}
|
|
10879
|
+
}
|
|
10880
|
+
};
|
|
10881
|
+
|
|
10882
|
+
// db/base.ts
|
|
10883
|
+
var import_api23 = require("@opentelemetry/api");
|
|
10884
|
+
var AxDBBase = class {
  name;
  fetch;
  tracer;
  // Backend-specific implementations, assigned by subclasses (e.g. AxDBCloudflare
  // sets `_upsert`). The public methods throw when the matching one is missing.
  _upsert;
  _batchUpsert;
  _query;
  constructor({
    name,
    fetch: fetch2,
    tracer
  }) {
    this.name = name;
    this.fetch = fetch2;
    this.tracer = tracer;
  }
  /**
   * Insert or update one record through the subclass `_upsert`, inside a
   * "DB Upsert Request" span when a tracer is configured.
   *
   * @param req Record request; `table` and `namespace` are recorded as span
   *   attributes
   * @param update Truthy records the operation as "update", otherwise "insert"
   * @throws Error when the subclass provides no `_upsert`
   */
  async upsert(req, update) {
    if (!this._upsert) {
      throw new Error("upsert() not implemented");
    }
    if (!this.tracer) {
      return await this._upsert(req, update);
    }
    return await this._withSpan(
      "DB Upsert Request",
      {
        [axSpanAttributes.DB_SYSTEM]: this.name,
        [axSpanAttributes.DB_OPERATION_NAME]: update ? "update" : "insert",
        [axSpanAttributes.DB_TABLE]: req.table,
        [axSpanAttributes.DB_NAMESPACE]: req.namespace
      },
      (span) => this._upsert(req, update, { span })
    );
  }
  /**
   * Insert or update several records through the subclass `_batchUpsert`,
   * inside a "DB Batch Upsert Request" span when a tracer is configured.
   * The first request's `table`/`namespace` are used as span attributes.
   *
   * @throws Error when `_batchUpsert` is missing, the batch is empty, or its
   *   first element is undefined
   */
  async batchUpsert(req, update) {
    if (!this._batchUpsert) {
      throw new Error("batchUpsert() not implemented");
    }
    if (req.length === 0) {
      throw new Error("Batch request is empty");
    }
    if (!req[0]) {
      throw new Error("Batch request is invalid first element is undefined");
    }
    if (!this.tracer) {
      return await this._batchUpsert(req, update);
    }
    return await this._withSpan(
      "DB Batch Upsert Request",
      {
        [axSpanAttributes.DB_SYSTEM]: this.name,
        [axSpanAttributes.DB_OPERATION_NAME]: update ? "update" : "insert",
        [axSpanAttributes.DB_TABLE]: req[0].table,
        [axSpanAttributes.DB_NAMESPACE]: req[0].namespace
      },
      (span) => this._batchUpsert(req, update, { span })
    );
  }
  /**
   * Query through the subclass `_query`, inside a "DB Query Request" span
   * when a tracer is configured.
   *
   * @throws Error when the subclass provides no `_query`
   */
  async query(req) {
    if (!this._query) {
      throw new Error("query() not implemented");
    }
    if (!this.tracer) {
      return await this._query(req);
    }
    // The attribute objects previously listed DB_OPERATION_NAME twice (the
    // first value was dead, and here it was a copy-pasted "upsert").
    return await this._withSpan(
      "DB Query Request",
      {
        [axSpanAttributes.DB_SYSTEM]: this.name,
        [axSpanAttributes.DB_OPERATION_NAME]: "query",
        [axSpanAttributes.DB_TABLE]: req.table,
        [axSpanAttributes.DB_NAMESPACE]: req.namespace
      },
      (span) => this._query(req, { span })
    );
  }
  /**
   * Run `run(span)` inside an active server-kind span named `spanName` with
   * the given attributes. The span is always ended, even if `run` throws.
   */
  async _withSpan(spanName, attributes, run) {
    return await this.tracer.startActiveSpan(
      spanName,
      {
        kind: import_api23.SpanKind.SERVER,
        attributes
      },
      async (span) => {
        try {
          return await run(span);
        } finally {
          span.end();
        }
      }
    );
  }
};
|
|
10991
|
+
|
|
10992
|
+
// db/cloudflare.ts
|
|
10993
|
+
// Base endpoint of the Cloudflare accounts REST API. Per-account Vectorize paths
// are resolved against it with `new URL(path, baseURL)`; the trailing slash is
// required so relative paths are appended instead of replacing "accounts".
var baseURL = "https://api.cloudflare.com/client/v4/accounts/";
|
|
10994
|
+
var AxDBCloudflare = class extends AxDBBase {
|
|
10995
|
+
apiKey;
|
|
10996
|
+
accountId;
|
|
10997
|
+
constructor({
|
|
10998
|
+
apiKey,
|
|
10999
|
+
accountId,
|
|
11000
|
+
fetch: fetch2,
|
|
11001
|
+
tracer
|
|
11002
|
+
}) {
|
|
11003
|
+
if (!apiKey || !accountId) {
|
|
11004
|
+
throw new Error("Cloudflare credentials not set");
|
|
11005
|
+
}
|
|
11006
|
+
super({ name: "Cloudflare", fetch: fetch2, tracer });
|
|
11007
|
+
this.apiKey = apiKey;
|
|
11008
|
+
this.accountId = accountId;
|
|
11009
|
+
}
|
|
11010
|
+
_upsert = async (req, _update, options) => {
|
|
11011
|
+
const res = await apiCall(
|
|
11012
|
+
{
|
|
11013
|
+
url: new URL(
|
|
11014
|
+
`${this.accountId}/vectorize/indexes/${req.table}/upsert`,
|
|
11015
|
+
baseURL
|
|
11016
|
+
),
|
|
11017
|
+
headers: {
|
|
11018
|
+
"X-Auth-Key": this.apiKey
|
|
11019
|
+
},
|
|
11020
|
+
fetch: this.fetch,
|
|
11021
|
+
span: options?.span
|
|
11022
|
+
},
|
|
11023
|
+
{
|
|
11024
|
+
id: req.id,
|
|
11025
|
+
values: req.values,
|
|
11026
|
+
namespace: req.namespace,
|
|
11027
|
+
metadata: req.metadata
|
|
11028
|
+
}
|
|
11029
|
+
);
|
|
11030
|
+
if (res.errors) {
|
|
11031
|
+
throw new Error(
|
|
11032
|
+
`Cloudflare upsert failed: ${res.errors.map(({ message }) => message).join(", ")}`
|
|
11033
|
+
);
|
|
11034
|
+
}
|
|
11035
|
+
return {
|
|
11036
|
+
ids: res.result.ids
|
|
11037
|
+
};
|
|
11038
|
+
};
|
|
11039
|
+
batchUpsert = async (batchReq, update, options) => {
|
|
11040
|
+
if (update) {
|
|
11041
|
+
throw new Error("Weaviate does not support batch update");
|
|
11042
|
+
}
|
|
11043
|
+
if (batchReq.length < 1) {
|
|
11044
|
+
throw new Error("Batch request is empty");
|
|
11045
|
+
}
|
|
11046
|
+
if (!batchReq[0] || !batchReq[0].table) {
|
|
10320
11047
|
throw new Error("Table name is empty");
|
|
10321
11048
|
}
|
|
10322
11049
|
const table2 = batchReq[0].table;
|
|
@@ -11611,11 +12338,7 @@ var AxMCPStreambleHTTPTransport = class {
|
|
|
11611
12338
|
};
|
|
11612
12339
|
|
|
11613
12340
|
// dsp/optimizers/bootstrapFewshot.ts
|
|
11614
|
-
var AxBootstrapFewShot = class {
|
|
11615
|
-
ai;
|
|
11616
|
-
teacherAI;
|
|
11617
|
-
program;
|
|
11618
|
-
examples;
|
|
12341
|
+
var AxBootstrapFewShot = class extends AxBaseOptimizer {
|
|
11619
12342
|
maxRounds;
|
|
11620
12343
|
maxDemos;
|
|
11621
12344
|
maxExamples;
|
|
@@ -11626,37 +12349,20 @@ var AxBootstrapFewShot = class {
|
|
|
11626
12349
|
verboseMode;
|
|
11627
12350
|
debugMode;
|
|
11628
12351
|
traces = [];
|
|
11629
|
-
|
|
11630
|
-
|
|
11631
|
-
|
|
11632
|
-
|
|
11633
|
-
|
|
11634
|
-
|
|
11635
|
-
|
|
11636
|
-
|
|
11637
|
-
|
|
11638
|
-
|
|
11639
|
-
options
|
|
11640
|
-
|
|
11641
|
-
|
|
11642
|
-
|
|
11643
|
-
}
|
|
11644
|
-
const bootstrapOptions = options;
|
|
11645
|
-
this.maxRounds = bootstrapOptions?.maxRounds ?? 3;
|
|
11646
|
-
this.maxDemos = bootstrapOptions?.maxDemos ?? 4;
|
|
11647
|
-
this.maxExamples = bootstrapOptions?.maxExamples ?? 16;
|
|
11648
|
-
this.batchSize = bootstrapOptions?.batchSize ?? 1;
|
|
11649
|
-
this.earlyStoppingPatience = bootstrapOptions?.earlyStoppingPatience ?? 0;
|
|
11650
|
-
this.costMonitoring = bootstrapOptions?.costMonitoring ?? false;
|
|
11651
|
-
this.maxTokensPerGeneration = bootstrapOptions?.maxTokensPerGeneration ?? 0;
|
|
11652
|
-
this.verboseMode = bootstrapOptions?.verboseMode ?? true;
|
|
11653
|
-
this.debugMode = bootstrapOptions?.debugMode ?? false;
|
|
11654
|
-
this.ai = ai;
|
|
11655
|
-
this.teacherAI = bootstrapOptions?.teacherAI;
|
|
11656
|
-
this.program = program;
|
|
11657
|
-
this.examples = examples;
|
|
11658
|
-
}
|
|
11659
|
-
async compileRound(roundIndex, metricFn, options) {
|
|
12352
|
+
constructor(args) {
|
|
12353
|
+
super(args);
|
|
12354
|
+
const options = args.options || {};
|
|
12355
|
+
this.maxRounds = options.maxRounds ?? 3;
|
|
12356
|
+
this.maxDemos = options.maxDemos ?? 4;
|
|
12357
|
+
this.maxExamples = options.maxExamples ?? 16;
|
|
12358
|
+
this.batchSize = options.batchSize ?? 1;
|
|
12359
|
+
this.earlyStoppingPatience = options.earlyStoppingPatience ?? 0;
|
|
12360
|
+
this.costMonitoring = options.costMonitoring ?? false;
|
|
12361
|
+
this.maxTokensPerGeneration = options.maxTokensPerGeneration ?? 0;
|
|
12362
|
+
this.verboseMode = options.verboseMode ?? true;
|
|
12363
|
+
this.debugMode = options.debugMode ?? false;
|
|
12364
|
+
}
|
|
12365
|
+
async compileRound(program, roundIndex, metricFn, options) {
|
|
11660
12366
|
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
11661
12367
|
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
11662
12368
|
const aiOpt = {
|
|
@@ -11679,20 +12385,20 @@ var AxBootstrapFewShot = class {
|
|
|
11679
12385
|
continue;
|
|
11680
12386
|
}
|
|
11681
12387
|
const exList = examples.filter((e) => e !== ex);
|
|
11682
|
-
|
|
11683
|
-
const aiService = this.
|
|
12388
|
+
program.setExamples(exList);
|
|
12389
|
+
const aiService = this.getTeacherOrStudentAI();
|
|
11684
12390
|
this.stats.totalCalls++;
|
|
11685
12391
|
let res;
|
|
11686
12392
|
let error;
|
|
11687
12393
|
try {
|
|
11688
|
-
res = await
|
|
12394
|
+
res = await program.forward(aiService, ex, aiOpt);
|
|
11689
12395
|
if (this.costMonitoring) {
|
|
11690
12396
|
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
11691
12397
|
}
|
|
11692
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
12398
|
+
const score = await metricFn({ prediction: res, example: ex });
|
|
11693
12399
|
const success = score >= 0.5;
|
|
11694
12400
|
if (success) {
|
|
11695
|
-
this.traces = [...this.traces, ...
|
|
12401
|
+
this.traces = [...this.traces, ...program.getTraces()];
|
|
11696
12402
|
this.stats.successfulDemos++;
|
|
11697
12403
|
}
|
|
11698
12404
|
} catch (err) {
|
|
@@ -11743,13 +12449,15 @@ var AxBootstrapFewShot = class {
|
|
|
11743
12449
|
if (!this.stats.earlyStopping) {
|
|
11744
12450
|
this.stats.earlyStopping = {
|
|
11745
12451
|
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
11746
|
-
patienceExhausted: false
|
|
12452
|
+
patienceExhausted: false,
|
|
12453
|
+
reason: "No improvement detected"
|
|
11747
12454
|
};
|
|
11748
12455
|
} else if (improvement > 0) {
|
|
11749
12456
|
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
11750
12457
|
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
11751
12458
|
this.stats.earlyStopping.patienceExhausted = true;
|
|
11752
12459
|
this.stats.earlyStopped = true;
|
|
12460
|
+
this.stats.earlyStopping.reason = `No improvement for ${this.earlyStoppingPatience} rounds`;
|
|
11753
12461
|
if (this.verboseMode || this.debugMode) {
|
|
11754
12462
|
console.log(
|
|
11755
12463
|
`
|
|
@@ -11760,37 +12468,38 @@ Early stopping triggered after ${roundIndex + 1} rounds. No improvement for ${th
|
|
|
11760
12468
|
}
|
|
11761
12469
|
}
|
|
11762
12470
|
}
|
|
11763
|
-
async compile(metricFn, options) {
|
|
11764
|
-
const
|
|
11765
|
-
const maxRounds = compileOptions?.maxRounds ?? this.maxRounds;
|
|
12471
|
+
async compile(program, metricFn, options) {
|
|
12472
|
+
const maxRounds = options?.maxIterations ?? this.maxRounds;
|
|
11766
12473
|
this.traces = [];
|
|
11767
|
-
this.
|
|
11768
|
-
totalCalls: 0,
|
|
11769
|
-
successfulDemos: 0,
|
|
11770
|
-
estimatedTokenUsage: 0,
|
|
11771
|
-
earlyStopped: false
|
|
11772
|
-
};
|
|
12474
|
+
this.reset();
|
|
11773
12475
|
for (let i = 0; i < maxRounds; i++) {
|
|
11774
|
-
await this.compileRound(i, metricFn,
|
|
12476
|
+
await this.compileRound(program, i, metricFn, options);
|
|
11775
12477
|
if (this.stats.earlyStopped) {
|
|
11776
12478
|
break;
|
|
11777
12479
|
}
|
|
11778
12480
|
}
|
|
11779
12481
|
if (this.traces.length === 0) {
|
|
11780
12482
|
throw new Error(
|
|
11781
|
-
"No demonstrations found. Either
|
|
12483
|
+
"No demonstrations found. Either provide more examples or improve the existing ones."
|
|
11782
12484
|
);
|
|
11783
12485
|
}
|
|
11784
12486
|
const demos = groupTracesByKeys(this.traces);
|
|
12487
|
+
let bestScore = 0;
|
|
12488
|
+
if (this.traces.length > 0) {
|
|
12489
|
+
bestScore = this.stats.successfulDemos / Math.max(1, this.stats.totalCalls);
|
|
12490
|
+
}
|
|
11785
12491
|
return {
|
|
11786
12492
|
demos,
|
|
11787
|
-
stats: this.stats
|
|
12493
|
+
stats: this.stats,
|
|
12494
|
+
bestScore,
|
|
12495
|
+
finalConfiguration: {
|
|
12496
|
+
maxRounds: this.maxRounds,
|
|
12497
|
+
maxDemos: this.maxDemos,
|
|
12498
|
+
batchSize: this.batchSize,
|
|
12499
|
+
successRate: bestScore
|
|
12500
|
+
}
|
|
11788
12501
|
};
|
|
11789
12502
|
}
|
|
11790
|
-
// Get optimization statistics
|
|
11791
|
-
getStats() {
|
|
11792
|
-
return this.stats;
|
|
11793
|
-
}
|
|
11794
12503
|
};
|
|
11795
12504
|
function groupTracesByKeys(programTraces) {
|
|
11796
12505
|
const groupedTraces = /* @__PURE__ */ new Map();
|
|
@@ -11805,9 +12514,12 @@ function groupTracesByKeys(programTraces) {
|
|
|
11805
12514
|
}
|
|
11806
12515
|
}
|
|
11807
12516
|
const programDemosArray = [];
|
|
11808
|
-
|
|
11809
|
-
programDemosArray.push({
|
|
11810
|
-
|
|
12517
|
+
groupedTraces.forEach((traces, programId) => {
|
|
12518
|
+
programDemosArray.push({
|
|
12519
|
+
traces,
|
|
12520
|
+
programId
|
|
12521
|
+
});
|
|
12522
|
+
});
|
|
11811
12523
|
return programDemosArray;
|
|
11812
12524
|
}
|
|
11813
12525
|
var randomSample = (array, n) => {
|
|
@@ -11826,10 +12538,8 @@ var randomSample = (array, n) => {
|
|
|
11826
12538
|
};
|
|
11827
12539
|
|
|
11828
12540
|
// dsp/optimizers/miproV2.ts
|
|
11829
|
-
var AxMiPRO = class {
|
|
11830
|
-
|
|
11831
|
-
program;
|
|
11832
|
-
examples;
|
|
12541
|
+
var AxMiPRO = class extends AxBaseOptimizer {
|
|
12542
|
+
// MiPRO-specific options
|
|
11833
12543
|
maxBootstrappedDemos;
|
|
11834
12544
|
maxLabeledDemos;
|
|
11835
12545
|
numCandidates;
|
|
@@ -11843,52 +12553,35 @@ var AxMiPRO = class {
|
|
|
11843
12553
|
viewDataBatchSize;
|
|
11844
12554
|
tipAwareProposer;
|
|
11845
12555
|
fewshotAwareProposer;
|
|
11846
|
-
seed;
|
|
11847
12556
|
verbose;
|
|
11848
|
-
bootstrapper;
|
|
11849
12557
|
earlyStoppingTrials;
|
|
11850
12558
|
minImprovementThreshold;
|
|
11851
|
-
|
|
11852
|
-
|
|
11853
|
-
|
|
11854
|
-
|
|
11855
|
-
|
|
11856
|
-
|
|
11857
|
-
|
|
11858
|
-
|
|
11859
|
-
|
|
11860
|
-
|
|
11861
|
-
this.
|
|
11862
|
-
this.
|
|
11863
|
-
this.
|
|
11864
|
-
this.
|
|
11865
|
-
this.
|
|
11866
|
-
this.
|
|
11867
|
-
this.
|
|
11868
|
-
this.
|
|
11869
|
-
this.
|
|
11870
|
-
this.
|
|
11871
|
-
this.
|
|
11872
|
-
this.
|
|
11873
|
-
this.
|
|
11874
|
-
this.
|
|
11875
|
-
this.
|
|
11876
|
-
this.
|
|
11877
|
-
this.minImprovementThreshold = miproOptions.minImprovementThreshold ?? 0.01;
|
|
11878
|
-
this.ai = ai;
|
|
11879
|
-
this.program = program;
|
|
11880
|
-
this.examples = examples;
|
|
11881
|
-
this.bootstrapper = new AxBootstrapFewShot({
|
|
11882
|
-
ai,
|
|
11883
|
-
program,
|
|
11884
|
-
examples,
|
|
11885
|
-
options: {
|
|
11886
|
-
maxDemos: this.maxBootstrappedDemos,
|
|
11887
|
-
maxRounds: 3,
|
|
11888
|
-
// Default, or adjust based on your needs
|
|
11889
|
-
verboseMode: this.verbose
|
|
11890
|
-
}
|
|
11891
|
-
});
|
|
12559
|
+
bayesianOptimization;
|
|
12560
|
+
acquisitionFunction;
|
|
12561
|
+
explorationWeight;
|
|
12562
|
+
constructor(args) {
|
|
12563
|
+
super(args);
|
|
12564
|
+
const options = args.options || {};
|
|
12565
|
+
this.numCandidates = options.numCandidates ?? 5;
|
|
12566
|
+
this.initTemperature = options.initTemperature ?? 0.7;
|
|
12567
|
+
this.maxBootstrappedDemos = options.maxBootstrappedDemos ?? 3;
|
|
12568
|
+
this.maxLabeledDemos = options.maxLabeledDemos ?? 4;
|
|
12569
|
+
this.numTrials = options.numTrials ?? 30;
|
|
12570
|
+
this.minibatch = options.minibatch ?? true;
|
|
12571
|
+
this.minibatchSize = options.minibatchSize ?? 25;
|
|
12572
|
+
this.minibatchFullEvalSteps = options.minibatchFullEvalSteps ?? 10;
|
|
12573
|
+
this.programAwareProposer = options.programAwareProposer ?? true;
|
|
12574
|
+
this.dataAwareProposer = options.dataAwareProposer ?? true;
|
|
12575
|
+
this.viewDataBatchSize = options.viewDataBatchSize ?? 10;
|
|
12576
|
+
this.tipAwareProposer = options.tipAwareProposer ?? true;
|
|
12577
|
+
this.fewshotAwareProposer = options.fewshotAwareProposer ?? true;
|
|
12578
|
+
this.verbose = options.verbose ?? false;
|
|
12579
|
+
this.earlyStoppingTrials = options.earlyStoppingTrials ?? 5;
|
|
12580
|
+
this.minImprovementThreshold = options.minImprovementThreshold ?? 0.01;
|
|
12581
|
+
this.bayesianOptimization = options.bayesianOptimization ?? false;
|
|
12582
|
+
this.acquisitionFunction = options.acquisitionFunction ?? "expected_improvement";
|
|
12583
|
+
this.explorationWeight = options.explorationWeight ?? 0.1;
|
|
12584
|
+
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
11892
12585
|
}
|
|
11893
12586
|
/**
|
|
11894
12587
|
* Configures the optimizer for light, medium, or heavy optimization
|
|
@@ -11932,123 +12625,60 @@ var AxMiPRO = class {
|
|
|
11932
12625
|
];
|
|
11933
12626
|
}
|
|
11934
12627
|
/**
|
|
11935
|
-
* Generates instruction candidates
|
|
12628
|
+
* Generates instruction candidates using the teacher model if available
|
|
12629
|
+
* @param options Optional compile options that may override teacher AI
|
|
11936
12630
|
* @returns Array of generated instruction candidates
|
|
11937
12631
|
*/
|
|
11938
|
-
async proposeInstructionCandidates() {
|
|
12632
|
+
async proposeInstructionCandidates(options) {
|
|
11939
12633
|
const instructions = [];
|
|
11940
|
-
|
|
11941
|
-
if (this.programAwareProposer) {
|
|
11942
|
-
programContext = await this.generateProgramSummary();
|
|
11943
|
-
}
|
|
11944
|
-
let dataContext = "";
|
|
11945
|
-
if (this.dataAwareProposer) {
|
|
11946
|
-
dataContext = await this.generateDataSummary();
|
|
11947
|
-
}
|
|
12634
|
+
const aiToUse = this.getTeacherOrStudentAI(options);
|
|
11948
12635
|
const tips = this.tipAwareProposer ? this.generateTips() : [];
|
|
11949
12636
|
for (let i = 0; i < this.numCandidates; i++) {
|
|
11950
12637
|
const tipIndex = tips.length > 0 ? i % tips.length : -1;
|
|
11951
12638
|
const tipToUse = tipIndex >= 0 ? tips[tipIndex] : "";
|
|
11952
12639
|
const instruction = await this.generateInstruction({
|
|
11953
|
-
programContext,
|
|
11954
|
-
dataContext,
|
|
11955
12640
|
tip: tipToUse,
|
|
11956
|
-
candidateIndex: i
|
|
12641
|
+
candidateIndex: i,
|
|
12642
|
+
ai: aiToUse
|
|
11957
12643
|
});
|
|
11958
12644
|
instructions.push(instruction);
|
|
11959
12645
|
}
|
|
11960
12646
|
return instructions;
|
|
11961
12647
|
}
|
|
11962
|
-
/**
|
|
11963
|
-
* Generates a summary of the program structure for instruction proposal
|
|
11964
|
-
*/
|
|
11965
|
-
async generateProgramSummary() {
|
|
11966
|
-
const prompt = `Summarize the following program structure. Focus on the signatures,
|
|
11967
|
-
input/output fields, and the purpose of each component. Identify key components
|
|
11968
|
-
that might benefit from better instructions.`;
|
|
11969
|
-
const programStr = JSON.stringify(this.program);
|
|
11970
|
-
const response = await this.ai.chat({
|
|
11971
|
-
chatPrompt: [
|
|
11972
|
-
{ role: "system", content: prompt },
|
|
11973
|
-
{ role: "user", content: programStr }
|
|
11974
|
-
],
|
|
11975
|
-
modelConfig: { temperature: 0.2 }
|
|
11976
|
-
});
|
|
11977
|
-
if (response instanceof ReadableStream) {
|
|
11978
|
-
return "";
|
|
11979
|
-
}
|
|
11980
|
-
return response.results[0]?.content || "";
|
|
11981
|
-
}
|
|
11982
|
-
/**
|
|
11983
|
-
* Generates a summary of the dataset for instruction proposal
|
|
11984
|
-
*/
|
|
11985
|
-
async generateDataSummary() {
|
|
11986
|
-
const sampleSize = Math.min(this.viewDataBatchSize, this.examples.length);
|
|
11987
|
-
const sample = this.examples.slice(0, sampleSize);
|
|
11988
|
-
const prompt = `Analyze the following dataset examples and provide a summary
|
|
11989
|
-
of key patterns, input-output relationships, and any specific challenges
|
|
11990
|
-
the data presents. Focus on what makes a good answer and what patterns should
|
|
11991
|
-
be followed.`;
|
|
11992
|
-
const dataStr = JSON.stringify(sample);
|
|
11993
|
-
const response = await this.ai.chat({
|
|
11994
|
-
chatPrompt: [
|
|
11995
|
-
{ role: "system", content: prompt },
|
|
11996
|
-
{ role: "user", content: dataStr }
|
|
11997
|
-
],
|
|
11998
|
-
modelConfig: { temperature: 0.2 }
|
|
11999
|
-
});
|
|
12000
|
-
if (response instanceof ReadableStream) {
|
|
12001
|
-
return "";
|
|
12002
|
-
}
|
|
12003
|
-
return response.results[0]?.content || "";
|
|
12004
|
-
}
|
|
12005
|
-
/**
|
|
12006
|
-
* Generates a specific instruction candidate
|
|
12007
|
-
*/
|
|
12008
12648
|
async generateInstruction({
|
|
12009
|
-
programContext,
|
|
12010
|
-
dataContext,
|
|
12011
12649
|
tip,
|
|
12012
12650
|
candidateIndex
|
|
12013
12651
|
}) {
|
|
12014
|
-
const
|
|
12015
|
-
|
|
12016
|
-
|
|
12017
|
-
|
|
12018
|
-
|
|
12019
|
-
|
|
12020
|
-
|
|
12021
|
-
|
|
12022
|
-
|
|
12023
|
-
|
|
12024
|
-
${tip ? `STYLE TIP: ${tip}
|
|
12025
|
-
|
|
12026
|
-
` : ""}
|
|
12027
|
-
|
|
12028
|
-
Your task is to craft a clear, effective instruction that will help the AI model generate
|
|
12029
|
-
accurate outputs for this task. Instruction #${candidateIndex + 1}/${this.numCandidates}.
|
|
12030
|
-
|
|
12031
|
-
The instruction should be detailed enough to guide the model but not overly prescriptive
|
|
12032
|
-
or restrictive. Focus on what makes a good response rather than listing exact steps.
|
|
12033
|
-
|
|
12034
|
-
INSTRUCTION:`;
|
|
12035
|
-
const response = await this.ai.chat({
|
|
12036
|
-
chatPrompt: [{ role: "user", content: prompt }],
|
|
12037
|
-
modelConfig: { temperature: 0.7 + 0.1 * candidateIndex }
|
|
12038
|
-
});
|
|
12039
|
-
if (response instanceof ReadableStream) {
|
|
12040
|
-
return "";
|
|
12652
|
+
const baseInstructions = [
|
|
12653
|
+
"Analyze the input carefully and provide a detailed response.",
|
|
12654
|
+
"Think step by step and provide a clear answer.",
|
|
12655
|
+
"Consider all aspects of the input before responding.",
|
|
12656
|
+
"Provide a concise but comprehensive response.",
|
|
12657
|
+
"Focus on accuracy and clarity in your response."
|
|
12658
|
+
];
|
|
12659
|
+
let instruction = baseInstructions[candidateIndex % baseInstructions.length] || baseInstructions[0];
|
|
12660
|
+
if (tip) {
|
|
12661
|
+
instruction = `${instruction} ${tip}`;
|
|
12041
12662
|
}
|
|
12042
|
-
return
|
|
12663
|
+
return instruction;
|
|
12043
12664
|
}
|
|
12044
12665
|
/**
|
|
12045
12666
|
* Bootstraps few-shot examples for the program
|
|
12046
12667
|
*/
|
|
12047
|
-
async bootstrapFewShotExamples(metricFn) {
|
|
12668
|
+
async bootstrapFewShotExamples(program, metricFn) {
|
|
12048
12669
|
if (this.verbose) {
|
|
12049
12670
|
console.log("Bootstrapping few-shot examples...");
|
|
12050
12671
|
}
|
|
12051
|
-
const
|
|
12672
|
+
const bootstrapper = new AxBootstrapFewShot({
|
|
12673
|
+
studentAI: this.studentAI,
|
|
12674
|
+
examples: this.examples,
|
|
12675
|
+
options: {
|
|
12676
|
+
maxDemos: this.maxBootstrappedDemos,
|
|
12677
|
+
maxRounds: 3,
|
|
12678
|
+
verboseMode: this.verbose
|
|
12679
|
+
}
|
|
12680
|
+
});
|
|
12681
|
+
const result = await bootstrapper.compile(program, metricFn, {
|
|
12052
12682
|
maxDemos: this.maxBootstrappedDemos
|
|
12053
12683
|
});
|
|
12054
12684
|
return result.demos || [];
|
|
@@ -12072,109 +12702,98 @@ ${dataContext}
|
|
|
12072
12702
|
return selectedExamples;
|
|
12073
12703
|
}
|
|
12074
12704
|
/**
|
|
12075
|
-
* Runs
|
|
12705
|
+
* Runs optimization to find the best combination of few-shot examples and instructions
|
|
12076
12706
|
*/
|
|
12077
|
-
async
|
|
12078
|
-
let bestConfig =
|
|
12079
|
-
let bestScore = Number.NEGATIVE_INFINITY;
|
|
12080
|
-
const evaluatedConfigs = [];
|
|
12081
|
-
const defaultConfig = {
|
|
12707
|
+
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, valset, metricFn, options) {
|
|
12708
|
+
let bestConfig = {
|
|
12082
12709
|
instruction: instructions[0] || "",
|
|
12083
12710
|
bootstrappedDemos: Math.min(1, bootstrappedDemos.length),
|
|
12084
12711
|
labeledExamples: Math.min(1, labeledExamples.length)
|
|
12085
12712
|
};
|
|
12086
|
-
let
|
|
12087
|
-
let
|
|
12088
|
-
const
|
|
12089
|
-
|
|
12090
|
-
|
|
12091
|
-
|
|
12092
|
-
|
|
12093
|
-
|
|
12094
|
-
|
|
12095
|
-
|
|
12713
|
+
let bestScore = 0;
|
|
12714
|
+
let stagnationRounds = 0;
|
|
12715
|
+
const scoreHistory = [];
|
|
12716
|
+
let startRound = 0;
|
|
12717
|
+
if (this.resumeFromCheckpoint) {
|
|
12718
|
+
const checkpoint = await this.loadCheckpoint(
|
|
12719
|
+
this.resumeFromCheckpoint,
|
|
12720
|
+
options
|
|
12721
|
+
);
|
|
12722
|
+
if (checkpoint && checkpoint.optimizerType === "MiPRO") {
|
|
12723
|
+
if (this.verbose || options?.verbose) {
|
|
12724
|
+
console.log(
|
|
12725
|
+
`Resuming from checkpoint at round ${checkpoint.currentRound}`
|
|
12726
|
+
);
|
|
12727
|
+
}
|
|
12728
|
+
this.restoreFromCheckpoint(checkpoint);
|
|
12729
|
+
startRound = checkpoint.currentRound;
|
|
12730
|
+
bestScore = checkpoint.bestScore;
|
|
12731
|
+
bestConfig = checkpoint.bestConfiguration || bestConfig;
|
|
12732
|
+
stagnationRounds = checkpoint.stats.convergenceInfo?.stagnationRounds || 0;
|
|
12733
|
+
}
|
|
12734
|
+
}
|
|
12735
|
+
for (let i = startRound; i < this.numTrials; i++) {
|
|
12096
12736
|
const config = {
|
|
12097
|
-
instruction:
|
|
12098
|
-
bootstrappedDemos: Math.
|
|
12099
|
-
Math.random() * (bootstrappedDemos.length + 1)
|
|
12737
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
12738
|
+
bootstrappedDemos: Math.min(
|
|
12739
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
12740
|
+
this.maxBootstrappedDemos
|
|
12100
12741
|
),
|
|
12101
|
-
labeledExamples: Math.
|
|
12102
|
-
Math.random() * (labeledExamples.length + 1)
|
|
12742
|
+
labeledExamples: Math.min(
|
|
12743
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
12744
|
+
this.maxLabeledDemos
|
|
12103
12745
|
)
|
|
12104
12746
|
};
|
|
12105
|
-
configs.push(config);
|
|
12106
|
-
}
|
|
12107
|
-
for (let i = 0; i < configs.length; i++) {
|
|
12108
|
-
const config = configs[i];
|
|
12109
|
-
if (!config) continue;
|
|
12110
12747
|
const score = await this.evaluateConfig(
|
|
12748
|
+
program,
|
|
12111
12749
|
config,
|
|
12112
12750
|
bootstrappedDemos,
|
|
12113
12751
|
labeledExamples,
|
|
12114
12752
|
valset,
|
|
12115
|
-
metricFn
|
|
12116
|
-
i
|
|
12753
|
+
metricFn
|
|
12117
12754
|
);
|
|
12118
|
-
|
|
12119
|
-
|
|
12755
|
+
scoreHistory.push(score);
|
|
12756
|
+
const improvement = score - bestScore;
|
|
12757
|
+
if (improvement > this.minImprovementThreshold) {
|
|
12120
12758
|
bestScore = score;
|
|
12121
12759
|
bestConfig = config;
|
|
12122
|
-
|
|
12123
|
-
|
|
12124
|
-
|
|
12125
|
-
);
|
|
12126
|
-
}
|
|
12760
|
+
stagnationRounds = 0;
|
|
12761
|
+
} else {
|
|
12762
|
+
stagnationRounds++;
|
|
12127
12763
|
}
|
|
12128
|
-
|
|
12764
|
+
await this.updateOptimizationProgress(
|
|
12129
12765
|
i + 1,
|
|
12130
|
-
|
|
12131
|
-
|
|
12132
|
-
|
|
12133
|
-
|
|
12134
|
-
|
|
12135
|
-
|
|
12136
|
-
|
|
12137
|
-
|
|
12138
|
-
|
|
12139
|
-
|
|
12140
|
-
|
|
12141
|
-
|
|
12142
|
-
|
|
12143
|
-
);
|
|
12144
|
-
const score = await this.evaluateConfig(
|
|
12145
|
-
nextConfig,
|
|
12146
|
-
bootstrappedDemos,
|
|
12147
|
-
labeledExamples,
|
|
12148
|
-
valset,
|
|
12149
|
-
metricFn,
|
|
12150
|
-
i
|
|
12766
|
+
score,
|
|
12767
|
+
config,
|
|
12768
|
+
"MiPRO",
|
|
12769
|
+
this.getConfiguration(),
|
|
12770
|
+
bestScore,
|
|
12771
|
+
bestConfig,
|
|
12772
|
+
{
|
|
12773
|
+
stagnationRounds,
|
|
12774
|
+
bootstrappedDemos: bootstrappedDemos.length,
|
|
12775
|
+
labeledExamples: labeledExamples.length,
|
|
12776
|
+
instructions: instructions.length
|
|
12777
|
+
},
|
|
12778
|
+
options
|
|
12151
12779
|
);
|
|
12152
|
-
|
|
12153
|
-
|
|
12154
|
-
|
|
12155
|
-
|
|
12156
|
-
|
|
12157
|
-
|
|
12158
|
-
|
|
12159
|
-
)
|
|
12160
|
-
|
|
12161
|
-
|
|
12162
|
-
|
|
12163
|
-
|
|
12164
|
-
|
|
12165
|
-
|
|
12166
|
-
|
|
12167
|
-
if (this.verbose) {
|
|
12168
|
-
console.log(
|
|
12169
|
-
`Early stopping triggered after ${i + 1} trials. No improvement for ${trialsWithoutImprovement} trials.`
|
|
12170
|
-
);
|
|
12171
|
-
}
|
|
12172
|
-
break;
|
|
12780
|
+
if (this.onProgress) {
|
|
12781
|
+
this.onProgress({
|
|
12782
|
+
round: i + 1,
|
|
12783
|
+
totalRounds: this.numTrials,
|
|
12784
|
+
currentScore: score,
|
|
12785
|
+
bestScore,
|
|
12786
|
+
tokensUsed: this.stats.resourceUsage.totalTokens,
|
|
12787
|
+
timeElapsed: Date.now(),
|
|
12788
|
+
successfulExamples: this.stats.successfulDemos,
|
|
12789
|
+
totalExamples: this.examples.length,
|
|
12790
|
+
currentConfiguration: config,
|
|
12791
|
+
convergenceInfo: {
|
|
12792
|
+
improvement,
|
|
12793
|
+
stagnationRounds,
|
|
12794
|
+
isConverging: stagnationRounds < this.earlyStoppingTrials
|
|
12173
12795
|
}
|
|
12174
|
-
}
|
|
12175
|
-
lastBestScore = bestScore;
|
|
12176
|
-
trialsWithoutImprovement = 0;
|
|
12177
|
-
}
|
|
12796
|
+
});
|
|
12178
12797
|
}
|
|
12179
12798
|
updateProgressBar(
|
|
12180
12799
|
i + 1,
|
|
@@ -12184,243 +12803,91 @@ ${dataContext}
|
|
|
12184
12803
|
"Running MIPROv2 optimization",
|
|
12185
12804
|
30
|
|
12186
12805
|
);
|
|
12187
|
-
if (this.
|
|
12188
|
-
|
|
12189
|
-
|
|
12190
|
-
`Running full evaluation on best configuration at trial ${i + 1}`
|
|
12191
|
-
);
|
|
12192
|
-
}
|
|
12193
|
-
const fullScore = await this.fullEvaluation(
|
|
12194
|
-
bestConfig,
|
|
12195
|
-
bootstrappedDemos,
|
|
12196
|
-
labeledExamples,
|
|
12197
|
-
valset,
|
|
12198
|
-
metricFn
|
|
12199
|
-
);
|
|
12200
|
-
if (this.verbose) {
|
|
12201
|
-
console.log(`Full evaluation score: ${fullScore}`);
|
|
12202
|
-
}
|
|
12203
|
-
bestScore = fullScore;
|
|
12806
|
+
if (this.checkCostLimits()) {
|
|
12807
|
+
this.triggerEarlyStopping("Cost limit reached", i + 1);
|
|
12808
|
+
break;
|
|
12204
12809
|
}
|
|
12205
|
-
|
|
12206
|
-
|
|
12207
|
-
|
|
12208
|
-
|
|
12209
|
-
"Optimization failed to find any valid configurations, using default fallback configuration"
|
|
12810
|
+
if (stagnationRounds >= this.earlyStoppingTrials) {
|
|
12811
|
+
this.triggerEarlyStopping(
|
|
12812
|
+
`No improvement for ${this.earlyStoppingTrials} trials`,
|
|
12813
|
+
i - stagnationRounds + 1
|
|
12210
12814
|
);
|
|
12815
|
+
break;
|
|
12211
12816
|
}
|
|
12212
|
-
|
|
12213
|
-
|
|
12214
|
-
|
|
12215
|
-
|
|
12216
|
-
bootstrappedDemos,
|
|
12217
|
-
labeledExamples,
|
|
12218
|
-
valset,
|
|
12219
|
-
metricFn,
|
|
12220
|
-
this.numTrials - 1
|
|
12817
|
+
if (this.checkTargetScore(bestScore)) {
|
|
12818
|
+
this.triggerEarlyStopping(
|
|
12819
|
+
`Target score ${this.targetScore} reached`,
|
|
12820
|
+
i + 1
|
|
12221
12821
|
);
|
|
12222
|
-
|
|
12223
|
-
if (this.verbose) {
|
|
12224
|
-
console.error("Error evaluating default configuration:", err);
|
|
12225
|
-
}
|
|
12226
|
-
bestScore = 0;
|
|
12822
|
+
break;
|
|
12227
12823
|
}
|
|
12228
12824
|
}
|
|
12825
|
+
this.stats.convergenceInfo.stagnationRounds = stagnationRounds;
|
|
12826
|
+
this.stats.convergenceInfo.finalImprovement = scoreHistory.length > 1 ? bestScore - scoreHistory[0] : 0;
|
|
12827
|
+
this.stats.convergenceInfo.converged = stagnationRounds < this.earlyStoppingTrials;
|
|
12229
12828
|
return { bestConfig, bestScore };
|
|
12230
12829
|
}
|
|
12231
|
-
|
|
12232
|
-
|
|
12233
|
-
*/
|
|
12234
|
-
async evaluateConfig(config, bootstrappedDemos, labeledExamples, valset, metricFn, trialIndex) {
|
|
12830
|
+
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, valset, metricFn) {
|
|
12831
|
+
const testProgram = { ...program };
|
|
12235
12832
|
this.applyConfigToProgram(
|
|
12236
|
-
|
|
12833
|
+
testProgram,
|
|
12237
12834
|
config,
|
|
12238
12835
|
bootstrappedDemos,
|
|
12239
12836
|
labeledExamples
|
|
12240
12837
|
);
|
|
12241
|
-
let
|
|
12242
|
-
|
|
12243
|
-
|
|
12244
|
-
const minibatchEvalSet = [];
|
|
12245
|
-
for (let j = 0; j < this.minibatchSize; j++) {
|
|
12246
|
-
const idx = (startIdx + j) % valset.length;
|
|
12247
|
-
const example = valset[idx];
|
|
12248
|
-
if (example) {
|
|
12249
|
-
minibatchEvalSet.push(example);
|
|
12250
|
-
}
|
|
12251
|
-
}
|
|
12252
|
-
evalSet = minibatchEvalSet;
|
|
12253
|
-
}
|
|
12254
|
-
let sumOfScores = 0;
|
|
12838
|
+
let totalScore = 0;
|
|
12839
|
+
let count = 0;
|
|
12840
|
+
const evalSet = valset.slice(0, Math.min(5, valset.length));
|
|
12255
12841
|
for (const example of evalSet) {
|
|
12256
12842
|
try {
|
|
12257
|
-
const prediction = await
|
|
12258
|
-
|
|
12259
|
-
|
|
12260
|
-
|
|
12261
|
-
|
|
12262
|
-
|
|
12263
|
-
|
|
12264
|
-
|
|
12265
|
-
|
|
12266
|
-
|
|
12267
|
-
return sumOfScores / evalSet.length;
|
|
12268
|
-
}
|
|
12269
|
-
/**
|
|
12270
|
-
* Run full evaluation on the entire validation set
|
|
12271
|
-
*/
|
|
12272
|
-
async fullEvaluation(config, bootstrappedDemos, labeledExamples, valset, metricFn) {
|
|
12273
|
-
this.applyConfigToProgram(
|
|
12274
|
-
this.program,
|
|
12275
|
-
config,
|
|
12276
|
-
bootstrappedDemos,
|
|
12277
|
-
labeledExamples
|
|
12278
|
-
);
|
|
12279
|
-
let sumOfScores = 0;
|
|
12280
|
-
for (const example of valset) {
|
|
12281
|
-
try {
|
|
12282
|
-
const prediction = await this.program.forward(this.ai, example);
|
|
12283
|
-
const score = metricFn({ prediction, example });
|
|
12284
|
-
sumOfScores += score;
|
|
12285
|
-
} catch (err) {
|
|
12286
|
-
if (this.verbose) {
|
|
12287
|
-
console.error("Error evaluating example:", err);
|
|
12288
|
-
}
|
|
12843
|
+
const prediction = await testProgram.forward(
|
|
12844
|
+
this.studentAI,
|
|
12845
|
+
example
|
|
12846
|
+
);
|
|
12847
|
+
const score = await metricFn({ prediction, example });
|
|
12848
|
+
totalScore += score;
|
|
12849
|
+
count++;
|
|
12850
|
+
this.stats.totalCalls++;
|
|
12851
|
+
} catch {
|
|
12852
|
+
continue;
|
|
12289
12853
|
}
|
|
12290
12854
|
}
|
|
12291
|
-
|
|
12292
|
-
return sumOfScores / valset.length;
|
|
12293
|
-
}
|
|
12294
|
-
/**
|
|
12295
|
-
* Implements a Bayesian-inspired selection of the next configuration to try
|
|
12296
|
-
* This is a simplified version using Upper Confidence Bound (UCB) strategy
|
|
12297
|
-
*/
|
|
12298
|
-
selectNextConfiguration(evaluatedConfigs, maxBootstrappedDemos, maxLabeledExamples, instructions) {
|
|
12299
|
-
if (evaluatedConfigs.length < 5) {
|
|
12300
|
-
const instructionIndex = Math.floor(Math.random() * instructions.length);
|
|
12301
|
-
return {
|
|
12302
|
-
instruction: instructions[instructionIndex] || "",
|
|
12303
|
-
bootstrappedDemos: Math.floor(
|
|
12304
|
-
Math.random() * (maxBootstrappedDemos + 1)
|
|
12305
|
-
),
|
|
12306
|
-
labeledExamples: Math.floor(Math.random() * (maxLabeledExamples + 1))
|
|
12307
|
-
};
|
|
12308
|
-
}
|
|
12309
|
-
const sortedConfigs = [...evaluatedConfigs].sort(
|
|
12310
|
-
(a, b) => b.score - a.score
|
|
12311
|
-
);
|
|
12312
|
-
const topConfigs = sortedConfigs.slice(0, Math.min(3, sortedConfigs.length));
|
|
12313
|
-
const meanBootstrappedDemos = topConfigs.reduce((sum, c) => sum + c.config.bootstrappedDemos, 0) / topConfigs.length;
|
|
12314
|
-
const meanLabeledExamples = topConfigs.reduce((sum, c) => sum + c.config.labeledExamples, 0) / topConfigs.length;
|
|
12315
|
-
const popularInstructions = topConfigs.map((c) => c.config.instruction);
|
|
12316
|
-
const explorationFactor = Math.max(
|
|
12317
|
-
0.2,
|
|
12318
|
-
1 - evaluatedConfigs.length / this.numTrials
|
|
12319
|
-
);
|
|
12320
|
-
let newBootstrappedDemos;
|
|
12321
|
-
let newLabeledExamples;
|
|
12322
|
-
let newInstruction;
|
|
12323
|
-
if (Math.random() < 0.7) {
|
|
12324
|
-
newBootstrappedDemos = Math.min(
|
|
12325
|
-
maxBootstrappedDemos,
|
|
12326
|
-
Math.max(
|
|
12327
|
-
0,
|
|
12328
|
-
Math.round(
|
|
12329
|
-
meanBootstrappedDemos + (Math.random() * 2 - 1) * explorationFactor * 2
|
|
12330
|
-
)
|
|
12331
|
-
)
|
|
12332
|
-
);
|
|
12333
|
-
} else {
|
|
12334
|
-
newBootstrappedDemos = Math.floor(
|
|
12335
|
-
Math.random() * (maxBootstrappedDemos + 1)
|
|
12336
|
-
);
|
|
12337
|
-
}
|
|
12338
|
-
if (Math.random() < 0.7) {
|
|
12339
|
-
newLabeledExamples = Math.min(
|
|
12340
|
-
maxLabeledExamples,
|
|
12341
|
-
Math.max(
|
|
12342
|
-
0,
|
|
12343
|
-
Math.round(
|
|
12344
|
-
meanLabeledExamples + (Math.random() * 2 - 1) * explorationFactor * 2
|
|
12345
|
-
)
|
|
12346
|
-
)
|
|
12347
|
-
);
|
|
12348
|
-
} else {
|
|
12349
|
-
newLabeledExamples = Math.floor(Math.random() * (maxLabeledExamples + 1));
|
|
12350
|
-
}
|
|
12351
|
-
if (Math.random() < 0.7 && popularInstructions.length > 0) {
|
|
12352
|
-
const idx = Math.floor(Math.random() * popularInstructions.length);
|
|
12353
|
-
newInstruction = popularInstructions[idx] || "";
|
|
12354
|
-
} else {
|
|
12355
|
-
const idx = Math.floor(Math.random() * instructions.length);
|
|
12356
|
-
newInstruction = instructions[idx] || "";
|
|
12357
|
-
}
|
|
12358
|
-
return {
|
|
12359
|
-
instruction: newInstruction,
|
|
12360
|
-
bootstrappedDemos: newBootstrappedDemos,
|
|
12361
|
-
labeledExamples: newLabeledExamples
|
|
12362
|
-
};
|
|
12855
|
+
return count > 0 ? totalScore / count : 0;
|
|
12363
12856
|
}
|
|
12364
|
-
/**
|
|
12365
|
-
* Applies a configuration to a program instance
|
|
12366
|
-
*/
|
|
12367
12857
|
applyConfigToProgram(program, config, bootstrappedDemos, labeledExamples) {
|
|
12368
|
-
|
|
12369
|
-
|
|
12858
|
+
if (program.setInstruction) {
|
|
12859
|
+
program.setInstruction(config.instruction);
|
|
12860
|
+
}
|
|
12861
|
+
if (config.bootstrappedDemos > 0 && program.setDemos) {
|
|
12370
12862
|
program.setDemos(bootstrappedDemos.slice(0, config.bootstrappedDemos));
|
|
12371
12863
|
}
|
|
12372
|
-
if (config.labeledExamples > 0) {
|
|
12864
|
+
if (config.labeledExamples > 0 && program.setExamples) {
|
|
12373
12865
|
program.setExamples(labeledExamples.slice(0, config.labeledExamples));
|
|
12374
12866
|
}
|
|
12375
12867
|
}
|
|
12376
|
-
/**
|
|
12377
|
-
* Sets instruction to a program
|
|
12378
|
-
* Note: Workaround since setInstruction may not be available directly
|
|
12379
|
-
*/
|
|
12380
|
-
setInstructionToProgram(program, instruction) {
|
|
12381
|
-
const programWithInstruction = program;
|
|
12382
|
-
programWithInstruction.setInstruction?.(instruction);
|
|
12383
|
-
}
|
|
12384
12868
|
/**
|
|
12385
12869
|
* The main compile method to run MIPROv2 optimization
|
|
12386
|
-
* @param metricFn Evaluation metric function
|
|
12387
|
-
* @param options Optional configuration options
|
|
12388
|
-
* @returns The optimization result
|
|
12389
12870
|
*/
|
|
12390
|
-
async compile(metricFn, options) {
|
|
12871
|
+
async compile(program, metricFn, options) {
|
|
12872
|
+
const startTime = Date.now();
|
|
12873
|
+
this.setupRandomSeed();
|
|
12391
12874
|
const miproOptions = options;
|
|
12392
12875
|
if (miproOptions?.auto) {
|
|
12393
12876
|
this.configureAuto(miproOptions.auto);
|
|
12394
12877
|
}
|
|
12395
|
-
const
|
|
12396
|
-
|
|
12397
|
-
if (this.verbose) {
|
|
12878
|
+
const valset = this.getValidationSet(options) || (miproOptions?.valset ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
|
|
12879
|
+
if (this.verbose || options?.verbose) {
|
|
12398
12880
|
console.log(`Starting MIPROv2 optimization with ${this.numTrials} trials`);
|
|
12399
12881
|
console.log(
|
|
12400
|
-
`Using ${
|
|
12882
|
+
`Using ${this.examples.length} examples for training and ${valset.length} for validation`
|
|
12401
12883
|
);
|
|
12402
|
-
|
|
12403
|
-
|
|
12404
|
-
if (this.verbose) {
|
|
12405
|
-
console.log("Using provided teacher to assist with bootstrapping");
|
|
12884
|
+
if (this.teacherAI) {
|
|
12885
|
+
console.log("Using separate teacher model for instruction generation");
|
|
12406
12886
|
}
|
|
12407
|
-
const bootstrapperWithTeacher = new AxBootstrapFewShot({
|
|
12408
|
-
ai: this.ai,
|
|
12409
|
-
program: this.program,
|
|
12410
|
-
examples: this.examples,
|
|
12411
|
-
options: {
|
|
12412
|
-
maxDemos: this.maxBootstrappedDemos,
|
|
12413
|
-
maxRounds: 3,
|
|
12414
|
-
verboseMode: this.verbose,
|
|
12415
|
-
teacherAI: this.ai
|
|
12416
|
-
// Use the same AI but with the teacher program
|
|
12417
|
-
}
|
|
12418
|
-
});
|
|
12419
|
-
this.bootstrapper = bootstrapperWithTeacher;
|
|
12420
12887
|
}
|
|
12421
12888
|
let bootstrappedDemos = [];
|
|
12422
12889
|
if (this.maxBootstrappedDemos > 0) {
|
|
12423
|
-
bootstrappedDemos = await this.bootstrapFewShotExamples(metricFn);
|
|
12890
|
+
bootstrappedDemos = await this.bootstrapFewShotExamples(program, metricFn);
|
|
12424
12891
|
if (this.verbose) {
|
|
12425
12892
|
console.log(
|
|
12426
12893
|
`Generated ${bootstrappedDemos.length} bootstrapped demonstrations`
|
|
@@ -12436,38 +12903,191 @@ ${dataContext}
|
|
|
12436
12903
|
);
|
|
12437
12904
|
}
|
|
12438
12905
|
}
|
|
12439
|
-
const instructions = await this.proposeInstructionCandidates();
|
|
12906
|
+
const instructions = await this.proposeInstructionCandidates(options);
|
|
12440
12907
|
if (this.verbose) {
|
|
12441
12908
|
console.log(`Generated ${instructions.length} instruction candidates`);
|
|
12909
|
+
if (this.hasTeacherAI(options)) {
|
|
12910
|
+
console.log("Using teacher AI for instruction generation");
|
|
12911
|
+
}
|
|
12442
12912
|
}
|
|
12443
|
-
const { bestConfig, bestScore } = await this.
|
|
12913
|
+
const { bestConfig, bestScore } = await this.runOptimization(
|
|
12914
|
+
program,
|
|
12444
12915
|
bootstrappedDemos,
|
|
12445
12916
|
labeledExamples,
|
|
12446
12917
|
instructions,
|
|
12447
12918
|
valset,
|
|
12448
|
-
metricFn
|
|
12919
|
+
metricFn,
|
|
12920
|
+
options
|
|
12449
12921
|
);
|
|
12450
|
-
if (this.verbose) {
|
|
12922
|
+
if (this.verbose || options?.verbose) {
|
|
12451
12923
|
console.log(`Optimization complete. Best score: ${bestScore}`);
|
|
12452
12924
|
console.log(`Best configuration: ${JSON.stringify(bestConfig)}`);
|
|
12453
12925
|
}
|
|
12454
|
-
this.
|
|
12455
|
-
this.
|
|
12926
|
+
if (this.checkTargetScore(bestScore)) {
|
|
12927
|
+
this.triggerEarlyStopping(
|
|
12928
|
+
`Target score ${this.targetScore} reached with score ${bestScore}`,
|
|
12929
|
+
this.numTrials
|
|
12930
|
+
);
|
|
12931
|
+
}
|
|
12932
|
+
let signature;
|
|
12933
|
+
if ("getSignature" in program && typeof program.getSignature === "function") {
|
|
12934
|
+
signature = program.getSignature();
|
|
12935
|
+
} else {
|
|
12936
|
+
signature = "input -> output";
|
|
12937
|
+
}
|
|
12938
|
+
const optimizedGen = new AxGen(signature);
|
|
12939
|
+
this.applyConfigToAxGen(
|
|
12940
|
+
optimizedGen,
|
|
12456
12941
|
bestConfig,
|
|
12457
12942
|
bootstrappedDemos,
|
|
12458
12943
|
labeledExamples
|
|
12459
12944
|
);
|
|
12945
|
+
this.updateResourceUsage(startTime);
|
|
12946
|
+
this.stats.convergenceInfo.converged = true;
|
|
12947
|
+
this.stats.convergenceInfo.finalImprovement = bestScore;
|
|
12948
|
+
await this.saveFinalCheckpoint(
|
|
12949
|
+
"MiPRO",
|
|
12950
|
+
this.getConfiguration(),
|
|
12951
|
+
bestScore,
|
|
12952
|
+
bestConfig,
|
|
12953
|
+
{
|
|
12954
|
+
bootstrappedDemos: bootstrappedDemos.length,
|
|
12955
|
+
labeledExamples: labeledExamples.length,
|
|
12956
|
+
instructions: instructions.length,
|
|
12957
|
+
optimizedGen: !!optimizedGen
|
|
12958
|
+
},
|
|
12959
|
+
options
|
|
12960
|
+
);
|
|
12460
12961
|
return {
|
|
12461
|
-
|
|
12462
|
-
|
|
12962
|
+
demos: bootstrappedDemos,
|
|
12963
|
+
stats: this.stats,
|
|
12964
|
+
bestScore,
|
|
12965
|
+
optimizedGen,
|
|
12966
|
+
finalConfiguration: {
|
|
12967
|
+
instruction: bestConfig.instruction,
|
|
12968
|
+
bootstrappedDemos: bestConfig.bootstrappedDemos,
|
|
12969
|
+
labeledExamples: bestConfig.labeledExamples,
|
|
12970
|
+
numCandidates: this.numCandidates,
|
|
12971
|
+
numTrials: this.numTrials
|
|
12972
|
+
}
|
|
12463
12973
|
};
|
|
12464
12974
|
}
|
|
12465
12975
|
/**
|
|
12466
|
-
*
|
|
12467
|
-
* @returns Optimization statistics or undefined if not available
|
|
12976
|
+
* Applies a configuration to an AxGen instance
|
|
12468
12977
|
*/
|
|
12469
|
-
|
|
12470
|
-
|
|
12978
|
+
applyConfigToAxGen(axgen, config, bootstrappedDemos, labeledExamples) {
|
|
12979
|
+
if ("setInstruction" in axgen && typeof axgen.setInstruction === "function") {
|
|
12980
|
+
axgen.setInstruction(config.instruction);
|
|
12981
|
+
}
|
|
12982
|
+
if (config.bootstrappedDemos > 0) {
|
|
12983
|
+
axgen.setDemos(bootstrappedDemos.slice(0, config.bootstrappedDemos));
|
|
12984
|
+
}
|
|
12985
|
+
if (config.labeledExamples > 0) {
|
|
12986
|
+
axgen.setExamples(
|
|
12987
|
+
labeledExamples.slice(
|
|
12988
|
+
0,
|
|
12989
|
+
config.labeledExamples
|
|
12990
|
+
)
|
|
12991
|
+
);
|
|
12992
|
+
}
|
|
12993
|
+
}
|
|
12994
|
+
/**
|
|
12995
|
+
* Get optimizer-specific configuration
|
|
12996
|
+
* @returns Current optimizer configuration
|
|
12997
|
+
*/
|
|
12998
|
+
getConfiguration() {
|
|
12999
|
+
return {
|
|
13000
|
+
numCandidates: this.numCandidates,
|
|
13001
|
+
initTemperature: this.initTemperature,
|
|
13002
|
+
maxBootstrappedDemos: this.maxBootstrappedDemos,
|
|
13003
|
+
maxLabeledDemos: this.maxLabeledDemos,
|
|
13004
|
+
numTrials: this.numTrials,
|
|
13005
|
+
minibatch: this.minibatch,
|
|
13006
|
+
minibatchSize: this.minibatchSize,
|
|
13007
|
+
minibatchFullEvalSteps: this.minibatchFullEvalSteps,
|
|
13008
|
+
programAwareProposer: this.programAwareProposer,
|
|
13009
|
+
dataAwareProposer: this.dataAwareProposer,
|
|
13010
|
+
tipAwareProposer: this.tipAwareProposer,
|
|
13011
|
+
fewshotAwareProposer: this.fewshotAwareProposer,
|
|
13012
|
+
earlyStoppingTrials: this.earlyStoppingTrials,
|
|
13013
|
+
minImprovementThreshold: this.minImprovementThreshold,
|
|
13014
|
+
bayesianOptimization: this.bayesianOptimization,
|
|
13015
|
+
acquisitionFunction: this.acquisitionFunction,
|
|
13016
|
+
explorationWeight: this.explorationWeight
|
|
13017
|
+
};
|
|
13018
|
+
}
|
|
13019
|
+
/**
|
|
13020
|
+
* Update optimizer configuration
|
|
13021
|
+
* @param config New configuration to merge with existing
|
|
13022
|
+
*/
|
|
13023
|
+
updateConfiguration(config) {
|
|
13024
|
+
if (config.numCandidates !== void 0) {
|
|
13025
|
+
this.numCandidates = config.numCandidates;
|
|
13026
|
+
}
|
|
13027
|
+
if (config.initTemperature !== void 0) {
|
|
13028
|
+
this.initTemperature = config.initTemperature;
|
|
13029
|
+
}
|
|
13030
|
+
if (config.maxBootstrappedDemos !== void 0) {
|
|
13031
|
+
this.maxBootstrappedDemos = config.maxBootstrappedDemos;
|
|
13032
|
+
}
|
|
13033
|
+
if (config.maxLabeledDemos !== void 0) {
|
|
13034
|
+
this.maxLabeledDemos = config.maxLabeledDemos;
|
|
13035
|
+
}
|
|
13036
|
+
if (config.numTrials !== void 0) {
|
|
13037
|
+
this.numTrials = config.numTrials;
|
|
13038
|
+
}
|
|
13039
|
+
if (config.minibatch !== void 0) {
|
|
13040
|
+
this.minibatch = config.minibatch;
|
|
13041
|
+
}
|
|
13042
|
+
if (config.minibatchSize !== void 0) {
|
|
13043
|
+
this.minibatchSize = config.minibatchSize;
|
|
13044
|
+
}
|
|
13045
|
+
if (config.earlyStoppingTrials !== void 0) {
|
|
13046
|
+
this.earlyStoppingTrials = config.earlyStoppingTrials;
|
|
13047
|
+
}
|
|
13048
|
+
if (config.minImprovementThreshold !== void 0) {
|
|
13049
|
+
this.minImprovementThreshold = config.minImprovementThreshold;
|
|
13050
|
+
}
|
|
13051
|
+
if (config.verbose !== void 0) {
|
|
13052
|
+
this.verbose = config.verbose;
|
|
13053
|
+
}
|
|
13054
|
+
}
|
|
13055
|
+
/**
|
|
13056
|
+
* Reset optimizer state for reuse with different programs
|
|
13057
|
+
*/
|
|
13058
|
+
reset() {
|
|
13059
|
+
super.reset();
|
|
13060
|
+
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
13061
|
+
}
|
|
13062
|
+
/**
|
|
13063
|
+
* Validate that the optimizer can handle the given program
|
|
13064
|
+
* @param program Program to validate
|
|
13065
|
+
* @returns Validation result with any issues found
|
|
13066
|
+
*/
|
|
13067
|
+
validateProgram(program) {
|
|
13068
|
+
const result = super.validateProgram(program);
|
|
13069
|
+
if (this.examples.length < this.maxBootstrappedDemos + this.maxLabeledDemos) {
|
|
13070
|
+
result.issues.push(
|
|
13071
|
+
`Not enough examples: need at least ${this.maxBootstrappedDemos + this.maxLabeledDemos}, got ${this.examples.length}`
|
|
13072
|
+
);
|
|
13073
|
+
result.suggestions.push(
|
|
13074
|
+
"Reduce maxBootstrappedDemos or maxLabeledDemos, or provide more examples"
|
|
13075
|
+
);
|
|
13076
|
+
}
|
|
13077
|
+
const valSetSize = this.getValidationSet().length;
|
|
13078
|
+
if (valSetSize < 5) {
|
|
13079
|
+
result.issues.push(
|
|
13080
|
+
"Validation set too small for reliable MiPRO optimization"
|
|
13081
|
+
);
|
|
13082
|
+
result.suggestions.push(
|
|
13083
|
+
"Provide more examples or a larger validation set"
|
|
13084
|
+
);
|
|
13085
|
+
}
|
|
13086
|
+
return {
|
|
13087
|
+
isValid: result.issues.length === 0,
|
|
13088
|
+
issues: result.issues,
|
|
13089
|
+
suggestions: result.suggestions
|
|
13090
|
+
};
|
|
12471
13091
|
}
|
|
12472
13092
|
};
|
|
12473
13093
|
|
|
@@ -12714,7 +13334,7 @@ var AxTestPrompt = class {
|
|
|
12714
13334
|
throw new Error("Invalid example");
|
|
12715
13335
|
}
|
|
12716
13336
|
const res = await this.program.forward(this.ai, ex);
|
|
12717
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
13337
|
+
const score = await metricFn({ prediction: res, example: ex });
|
|
12718
13338
|
sumOfScores += score;
|
|
12719
13339
|
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
12720
13340
|
updateProgressBar(i, total, sumOfScores, et, "Testing Prompt", 30);
|
|
@@ -14748,7 +15368,6 @@ var AxRAG = class extends AxChainOfThought {
|
|
|
14748
15368
|
);
|
|
14749
15369
|
this.genQuery = new AxGen(qsig);
|
|
14750
15370
|
this.queryFn = queryFn;
|
|
14751
|
-
this.register(this.genQuery);
|
|
14752
15371
|
}
|
|
14753
15372
|
async forward(ai, values, options) {
|
|
14754
15373
|
let question;
|
|
@@ -14826,6 +15445,7 @@ var AxRAG = class extends AxChainOfThought {
|
|
|
14826
15445
|
AxAssertionError,
|
|
14827
15446
|
AxBalancer,
|
|
14828
15447
|
AxBaseAI,
|
|
15448
|
+
AxBaseOptimizer,
|
|
14829
15449
|
AxBootstrapFewShot,
|
|
14830
15450
|
AxChainOfThought,
|
|
14831
15451
|
AxDB,
|
|
@@ -14835,6 +15455,7 @@ var AxRAG = class extends AxChainOfThought {
|
|
|
14835
15455
|
AxDBMemory,
|
|
14836
15456
|
AxDBPinecone,
|
|
14837
15457
|
AxDBWeaviate,
|
|
15458
|
+
AxDefaultCostTracker,
|
|
14838
15459
|
AxDefaultQueryRewriter,
|
|
14839
15460
|
AxDefaultResultReranker,
|
|
14840
15461
|
AxDockerSession,
|