@ax-llm/ax 12.0.6 → 12.0.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +1348 -727
- package/index.cjs.map +1 -1
- package/index.d.cts +525 -195
- package/index.d.ts +525 -195
- package/index.js +1350 -731
- package/index.js.map +1 -1
- package/package.json +1 -1
package/index.js
CHANGED
|
@@ -69,7 +69,7 @@ var AxSpanKindValues = /* @__PURE__ */ ((AxSpanKindValues2) => {
|
|
|
69
69
|
// util/apicall.ts
|
|
70
70
|
import crypto from "crypto";
|
|
71
71
|
import {
|
|
72
|
-
ReadableStream
|
|
72
|
+
ReadableStream,
|
|
73
73
|
TextDecoderStream as TextDecoderStreamNative,
|
|
74
74
|
TransformStream as TransformStream3
|
|
75
75
|
} from "stream/web";
|
|
@@ -486,7 +486,7 @@ var apiCall = async (api, json) => {
|
|
|
486
486
|
}
|
|
487
487
|
});
|
|
488
488
|
let closed = false;
|
|
489
|
-
return new
|
|
489
|
+
return new ReadableStream({
|
|
490
490
|
start(controller) {
|
|
491
491
|
const reader = res.body.pipeThrough(new textDecoderStream()).pipeThrough(new SSEParser()).pipeThrough(trackingStream).getReader();
|
|
492
492
|
async function read() {
|
|
@@ -2451,48 +2451,59 @@ function createMessages2(req) {
|
|
|
2451
2451
|
case "system":
|
|
2452
2452
|
return { role: "system", content: msg.content };
|
|
2453
2453
|
case "user":
|
|
2454
|
-
|
|
2454
|
+
const content = Array.isArray(msg.content) ? msg.content.map((c) => {
|
|
2455
|
+
switch (c.type) {
|
|
2456
|
+
case "text":
|
|
2457
|
+
return { type: "text", text: c.text };
|
|
2458
|
+
case "image": {
|
|
2459
|
+
const url = `data:${c.mimeType};base64,` + c.image;
|
|
2460
|
+
return {
|
|
2461
|
+
type: "image_url",
|
|
2462
|
+
image_url: { url, details: c.details ?? "auto" }
|
|
2463
|
+
};
|
|
2464
|
+
}
|
|
2465
|
+
case "audio": {
|
|
2466
|
+
const data = c.data;
|
|
2467
|
+
return {
|
|
2468
|
+
type: "input_audio",
|
|
2469
|
+
input_audio: { data, format: c.format ?? "wav" }
|
|
2470
|
+
};
|
|
2471
|
+
}
|
|
2472
|
+
default:
|
|
2473
|
+
throw new Error("Invalid content type");
|
|
2474
|
+
}
|
|
2475
|
+
}) : msg.content;
|
|
2476
|
+
return {
|
|
2477
|
+
role: "user",
|
|
2478
|
+
...msg.name ? { name: msg.name } : {},
|
|
2479
|
+
content
|
|
2480
|
+
};
|
|
2481
|
+
case "assistant":
|
|
2482
|
+
const toolCalls = msg.functionCalls?.map((v) => ({
|
|
2483
|
+
id: v.id,
|
|
2484
|
+
type: "function",
|
|
2485
|
+
function: {
|
|
2486
|
+
name: v.function.name,
|
|
2487
|
+
arguments: typeof v.function.params === "object" ? JSON.stringify(v.function.params) : v.function.params
|
|
2488
|
+
}
|
|
2489
|
+
}));
|
|
2490
|
+
if (toolCalls && toolCalls.length > 0) {
|
|
2455
2491
|
return {
|
|
2456
|
-
role: "
|
|
2492
|
+
role: "assistant",
|
|
2493
|
+
...msg.content ? { content: msg.content } : {},
|
|
2457
2494
|
name: msg.name,
|
|
2458
|
-
|
|
2459
|
-
switch (c.type) {
|
|
2460
|
-
case "text":
|
|
2461
|
-
return { type: "text", text: c.text };
|
|
2462
|
-
case "image": {
|
|
2463
|
-
const url = `data:${c.mimeType};base64,` + c.image;
|
|
2464
|
-
return {
|
|
2465
|
-
type: "image_url",
|
|
2466
|
-
image_url: { url, details: c.details ?? "auto" }
|
|
2467
|
-
};
|
|
2468
|
-
}
|
|
2469
|
-
case "audio": {
|
|
2470
|
-
const data = c.data;
|
|
2471
|
-
return {
|
|
2472
|
-
type: "input_audio",
|
|
2473
|
-
input_audio: { data, format: c.format ?? "wav" }
|
|
2474
|
-
};
|
|
2475
|
-
}
|
|
2476
|
-
default:
|
|
2477
|
-
throw new Error("Invalid content type");
|
|
2478
|
-
}
|
|
2479
|
-
})
|
|
2495
|
+
tool_calls: toolCalls
|
|
2480
2496
|
};
|
|
2481
2497
|
}
|
|
2482
|
-
|
|
2483
|
-
|
|
2498
|
+
if (!msg.content) {
|
|
2499
|
+
throw new Error(
|
|
2500
|
+
"Assistant content is required when no tool calls are provided"
|
|
2501
|
+
);
|
|
2502
|
+
}
|
|
2484
2503
|
return {
|
|
2485
2504
|
role: "assistant",
|
|
2486
2505
|
content: msg.content,
|
|
2487
|
-
name: msg.name
|
|
2488
|
-
tool_calls: msg.functionCalls?.map((v) => ({
|
|
2489
|
-
id: v.id,
|
|
2490
|
-
type: "function",
|
|
2491
|
-
function: {
|
|
2492
|
-
name: v.function.name,
|
|
2493
|
-
arguments: typeof v.function.params === "object" ? JSON.stringify(v.function.params) : v.function.params
|
|
2494
|
-
}
|
|
2495
|
-
}))
|
|
2506
|
+
...msg.name ? { name: msg.name } : {}
|
|
2496
2507
|
};
|
|
2497
2508
|
case "function":
|
|
2498
2509
|
return {
|
|
@@ -5487,7 +5498,7 @@ var AxAIGrok = class extends AxAIOpenAIBase {
|
|
|
5487
5498
|
};
|
|
5488
5499
|
|
|
5489
5500
|
// dsp/generate.ts
|
|
5490
|
-
import { ReadableStream as
|
|
5501
|
+
import { ReadableStream as ReadableStream2 } from "stream/web";
|
|
5491
5502
|
import {
|
|
5492
5503
|
context as context2,
|
|
5493
5504
|
SpanKind as SpanKind2,
|
|
@@ -7359,8 +7370,9 @@ var AxInstanceRegistry = class {
|
|
|
7359
7370
|
this.reg.add(instance);
|
|
7360
7371
|
}
|
|
7361
7372
|
*[Symbol.iterator]() {
|
|
7362
|
-
|
|
7363
|
-
|
|
7373
|
+
const items = Array.from(this.reg);
|
|
7374
|
+
for (let i = 0; i < items.length; i++) {
|
|
7375
|
+
yield items[i];
|
|
7364
7376
|
}
|
|
7365
7377
|
}
|
|
7366
7378
|
};
|
|
@@ -8047,11 +8059,41 @@ var AxSignature = class _AxSignature {
|
|
|
8047
8059
|
if (signature.validatedAtHash === this.sigHash) {
|
|
8048
8060
|
this.validatedAtHash = this.sigHash;
|
|
8049
8061
|
}
|
|
8062
|
+
} else if (typeof signature === "object" && signature !== null) {
|
|
8063
|
+
if (!("inputs" in signature) || !("outputs" in signature)) {
|
|
8064
|
+
throw new AxSignatureValidationError(
|
|
8065
|
+
"Invalid signature object: missing inputs or outputs",
|
|
8066
|
+
void 0,
|
|
8067
|
+
'Signature object must have "inputs" and "outputs" arrays. Example: { inputs: [...], outputs: [...] }'
|
|
8068
|
+
);
|
|
8069
|
+
}
|
|
8070
|
+
if (!Array.isArray(signature.inputs) || !Array.isArray(signature.outputs)) {
|
|
8071
|
+
throw new AxSignatureValidationError(
|
|
8072
|
+
"Invalid signature object: inputs and outputs must be arrays",
|
|
8073
|
+
void 0,
|
|
8074
|
+
'Both "inputs" and "outputs" must be arrays of AxField objects'
|
|
8075
|
+
);
|
|
8076
|
+
}
|
|
8077
|
+
try {
|
|
8078
|
+
this.description = signature.description;
|
|
8079
|
+
this.inputFields = signature.inputs.map((v) => this.parseField(v));
|
|
8080
|
+
this.outputFields = signature.outputs.map((v) => this.parseField(v));
|
|
8081
|
+
[this.sigHash, this.sigString] = this.updateHash();
|
|
8082
|
+
} catch (error) {
|
|
8083
|
+
if (error instanceof AxSignatureValidationError) {
|
|
8084
|
+
throw error;
|
|
8085
|
+
}
|
|
8086
|
+
throw new AxSignatureValidationError(
|
|
8087
|
+
`Failed to create signature from object: ${error instanceof Error ? error.message : "Unknown error"}`,
|
|
8088
|
+
void 0,
|
|
8089
|
+
"Check that all fields in inputs and outputs arrays are valid AxField objects"
|
|
8090
|
+
);
|
|
8091
|
+
}
|
|
8050
8092
|
} else {
|
|
8051
8093
|
throw new AxSignatureValidationError(
|
|
8052
8094
|
"Invalid signature argument type",
|
|
8053
8095
|
void 0,
|
|
8054
|
-
"Signature must be a string
|
|
8096
|
+
"Signature must be a string, another AxSignature instance, or an object with inputs and outputs arrays"
|
|
8055
8097
|
);
|
|
8056
8098
|
}
|
|
8057
8099
|
}
|
|
@@ -8094,7 +8136,7 @@ var AxSignature = class _AxSignature {
|
|
|
8094
8136
|
}
|
|
8095
8137
|
this.description = desc;
|
|
8096
8138
|
this.invalidateValidationCache();
|
|
8097
|
-
this.
|
|
8139
|
+
this.updateHashLight();
|
|
8098
8140
|
};
|
|
8099
8141
|
addInputField = (field) => {
|
|
8100
8142
|
try {
|
|
@@ -8268,7 +8310,7 @@ var AxSignature = class _AxSignature {
|
|
|
8268
8310
|
this.getOutputFields().forEach((field) => {
|
|
8269
8311
|
validateField(field, "output");
|
|
8270
8312
|
});
|
|
8271
|
-
this.sigHash = createHash("sha256").update(
|
|
8313
|
+
this.sigHash = createHash("sha256").update(JSON.stringify(this.inputFields)).update(JSON.stringify(this.outputFields)).digest("hex");
|
|
8272
8314
|
this.sigString = renderSignature(
|
|
8273
8315
|
this.description,
|
|
8274
8316
|
this.inputFields,
|
|
@@ -8589,7 +8631,7 @@ var AxProgramWithSignature = class {
|
|
|
8589
8631
|
this.signature.validate();
|
|
8590
8632
|
this.sigHash = this.signature?.hash();
|
|
8591
8633
|
this.children = new AxInstanceRegistry();
|
|
8592
|
-
this.key = { id: this.
|
|
8634
|
+
this.key = { id: this.signature.hash() };
|
|
8593
8635
|
}
|
|
8594
8636
|
getSignature() {
|
|
8595
8637
|
return this.signature;
|
|
@@ -8609,8 +8651,8 @@ var AxProgramWithSignature = class {
|
|
|
8609
8651
|
}
|
|
8610
8652
|
setId(id) {
|
|
8611
8653
|
this.key = { id, custom: true };
|
|
8612
|
-
for (const child of this.children) {
|
|
8613
|
-
child
|
|
8654
|
+
for (const child of Array.from(this.children)) {
|
|
8655
|
+
child?.setParentId(id);
|
|
8614
8656
|
}
|
|
8615
8657
|
}
|
|
8616
8658
|
setParentId(parentId) {
|
|
@@ -8623,8 +8665,8 @@ var AxProgramWithSignature = class {
|
|
|
8623
8665
|
if (!("programId" in examples)) {
|
|
8624
8666
|
return;
|
|
8625
8667
|
}
|
|
8626
|
-
for (const child of this.children) {
|
|
8627
|
-
child
|
|
8668
|
+
for (const child of Array.from(this.children)) {
|
|
8669
|
+
child?.setExamples(examples, options);
|
|
8628
8670
|
}
|
|
8629
8671
|
}
|
|
8630
8672
|
_setExamples(examples, options) {
|
|
@@ -8657,30 +8699,37 @@ var AxProgramWithSignature = class {
|
|
|
8657
8699
|
if (this.trace) {
|
|
8658
8700
|
traces.push({ trace: this.trace, programId: this.key.id });
|
|
8659
8701
|
}
|
|
8660
|
-
for (const child of this.children) {
|
|
8661
|
-
const _traces = child
|
|
8662
|
-
traces = [...traces, ..._traces];
|
|
8702
|
+
for (const child of Array.from(this.children)) {
|
|
8703
|
+
const _traces = child?.getTraces();
|
|
8704
|
+
traces = [...traces, ..._traces ?? []];
|
|
8663
8705
|
}
|
|
8664
8706
|
return traces;
|
|
8665
8707
|
}
|
|
8666
8708
|
getUsage() {
|
|
8667
8709
|
let usage = [...this.usage ?? []];
|
|
8668
|
-
for (const child of this.children) {
|
|
8669
|
-
const cu = child
|
|
8670
|
-
usage = [...usage, ...cu];
|
|
8710
|
+
for (const child of Array.from(this.children)) {
|
|
8711
|
+
const cu = child?.getUsage();
|
|
8712
|
+
usage = [...usage, ...cu ?? []];
|
|
8671
8713
|
}
|
|
8672
8714
|
return mergeProgramUsage(usage);
|
|
8673
8715
|
}
|
|
8674
8716
|
resetUsage() {
|
|
8675
8717
|
this.usage = [];
|
|
8676
|
-
for (const child of this.children) {
|
|
8677
|
-
child
|
|
8718
|
+
for (const child of Array.from(this.children)) {
|
|
8719
|
+
child?.resetUsage();
|
|
8678
8720
|
}
|
|
8679
8721
|
}
|
|
8680
8722
|
setDemos(demos) {
|
|
8723
|
+
const hasChildren = Array.from(this.children).length > 0;
|
|
8724
|
+
const hasMatchingDemo = demos.some((demo) => demo.programId === this.key.id);
|
|
8725
|
+
if (hasChildren && !hasMatchingDemo) {
|
|
8726
|
+
throw new Error(
|
|
8727
|
+
`Program with id '${this.key.id}' has children but no matching programId found in demos`
|
|
8728
|
+
);
|
|
8729
|
+
}
|
|
8681
8730
|
this.demos = demos.filter((v) => v.programId === this.key.id).map((v) => v.traces).flat();
|
|
8682
|
-
for (const child of this.children) {
|
|
8683
|
-
child
|
|
8731
|
+
for (const child of Array.from(this.children)) {
|
|
8732
|
+
child?.setDemos(demos);
|
|
8684
8733
|
}
|
|
8685
8734
|
}
|
|
8686
8735
|
};
|
|
@@ -8708,8 +8757,8 @@ var AxProgram = class {
|
|
|
8708
8757
|
}
|
|
8709
8758
|
setId(id) {
|
|
8710
8759
|
this.key = { id, custom: true };
|
|
8711
|
-
for (const child of this.children) {
|
|
8712
|
-
child
|
|
8760
|
+
for (const child of Array.from(this.children)) {
|
|
8761
|
+
child?.setParentId(id);
|
|
8713
8762
|
}
|
|
8714
8763
|
}
|
|
8715
8764
|
setParentId(parentId) {
|
|
@@ -8721,8 +8770,8 @@ var AxProgram = class {
|
|
|
8721
8770
|
if (!("programId" in examples)) {
|
|
8722
8771
|
return;
|
|
8723
8772
|
}
|
|
8724
|
-
for (const child of this.children) {
|
|
8725
|
-
child
|
|
8773
|
+
for (const child of Array.from(this.children)) {
|
|
8774
|
+
child?.setExamples(examples, options);
|
|
8726
8775
|
}
|
|
8727
8776
|
}
|
|
8728
8777
|
getTraces() {
|
|
@@ -8730,29 +8779,36 @@ var AxProgram = class {
|
|
|
8730
8779
|
if (this.trace) {
|
|
8731
8780
|
traces.push({ trace: this.trace, programId: this.key.id });
|
|
8732
8781
|
}
|
|
8733
|
-
for (const child of this.children) {
|
|
8734
|
-
const _traces = child
|
|
8735
|
-
traces = [...traces, ..._traces];
|
|
8782
|
+
for (const child of Array.from(this.children)) {
|
|
8783
|
+
const _traces = child?.getTraces();
|
|
8784
|
+
traces = [...traces, ..._traces ?? []];
|
|
8736
8785
|
}
|
|
8737
8786
|
return traces;
|
|
8738
8787
|
}
|
|
8739
8788
|
getUsage() {
|
|
8740
8789
|
let usage = [...this.usage ?? []];
|
|
8741
|
-
for (const child of this.children) {
|
|
8742
|
-
const cu = child
|
|
8743
|
-
usage = [...usage, ...cu];
|
|
8790
|
+
for (const child of Array.from(this.children)) {
|
|
8791
|
+
const cu = child?.getUsage();
|
|
8792
|
+
usage = [...usage, ...cu ?? []];
|
|
8744
8793
|
}
|
|
8745
8794
|
return mergeProgramUsage(usage);
|
|
8746
8795
|
}
|
|
8747
8796
|
resetUsage() {
|
|
8748
8797
|
this.usage = [];
|
|
8749
|
-
for (const child of this.children) {
|
|
8750
|
-
child
|
|
8798
|
+
for (const child of Array.from(this.children)) {
|
|
8799
|
+
child?.resetUsage();
|
|
8751
8800
|
}
|
|
8752
8801
|
}
|
|
8753
8802
|
setDemos(demos) {
|
|
8754
|
-
|
|
8755
|
-
|
|
8803
|
+
const hasChildren = Array.from(this.children).length > 0;
|
|
8804
|
+
const hasMatchingDemo = demos.some((demo) => demo.programId === this.key.id);
|
|
8805
|
+
if (hasChildren && !hasMatchingDemo) {
|
|
8806
|
+
throw new Error(
|
|
8807
|
+
`Program with id '${this.key.id}' has children but no matching programId found in demos`
|
|
8808
|
+
);
|
|
8809
|
+
}
|
|
8810
|
+
for (const child of Array.from(this.children)) {
|
|
8811
|
+
child?.setDemos(demos);
|
|
8756
8812
|
}
|
|
8757
8813
|
}
|
|
8758
8814
|
};
|
|
@@ -8770,11 +8826,9 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
8770
8826
|
values = {};
|
|
8771
8827
|
excludeContentFromTrace = false;
|
|
8772
8828
|
thoughtFieldName;
|
|
8773
|
-
logger;
|
|
8774
8829
|
constructor(signature, options) {
|
|
8775
8830
|
super(signature, { description: options?.description });
|
|
8776
8831
|
this.options = options;
|
|
8777
|
-
this.logger = options?.logger;
|
|
8778
8832
|
this.thoughtFieldName = options?.thoughtFieldName ?? "thought";
|
|
8779
8833
|
const promptTemplateOptions = {
|
|
8780
8834
|
functions: options?.functions,
|
|
@@ -8864,6 +8918,7 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
8864
8918
|
rateLimiter,
|
|
8865
8919
|
stream,
|
|
8866
8920
|
debug: false,
|
|
8921
|
+
// we do our own debug logging
|
|
8867
8922
|
thinkingTokenBudget,
|
|
8868
8923
|
showThoughts,
|
|
8869
8924
|
traceContext,
|
|
@@ -8891,7 +8946,7 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
8891
8946
|
traceContext,
|
|
8892
8947
|
firstStep
|
|
8893
8948
|
});
|
|
8894
|
-
if (res instanceof
|
|
8949
|
+
if (res instanceof ReadableStream2) {
|
|
8895
8950
|
yield* this.processStreamingResponse({
|
|
8896
8951
|
ai,
|
|
8897
8952
|
model,
|
|
@@ -8903,6 +8958,7 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
8903
8958
|
fastFail,
|
|
8904
8959
|
span
|
|
8905
8960
|
});
|
|
8961
|
+
this.getLogger(ai, options)?.("", { tags: ["responseEnd"] });
|
|
8906
8962
|
} else {
|
|
8907
8963
|
yield await this.processResponse({
|
|
8908
8964
|
ai,
|
|
@@ -8939,7 +8995,6 @@ var AxGen = class extends AxProgramWithSignature {
|
|
|
8939
8995
|
mem.addResult(
|
|
8940
8996
|
{
|
|
8941
8997
|
content: "",
|
|
8942
|
-
name: "initial",
|
|
8943
8998
|
functionCalls: []
|
|
8944
8999
|
},
|
|
8945
9000
|
sessionId
|
|
@@ -9072,10 +9127,6 @@ Content: ${content}`
|
|
|
9072
9127
|
xstate
|
|
9073
9128
|
);
|
|
9074
9129
|
}
|
|
9075
|
-
if (ai.getOptions().debug) {
|
|
9076
|
-
const logger = ai.getLogger();
|
|
9077
|
-
logger("", { tags: ["responseEnd"] });
|
|
9078
|
-
}
|
|
9079
9130
|
}
|
|
9080
9131
|
async processResponse({
|
|
9081
9132
|
ai,
|
|
@@ -9147,9 +9198,11 @@ Content: ${result.content}`
|
|
|
9147
9198
|
const stopFunction = (options?.stopFunction ?? this.options?.stopFunction)?.toLowerCase();
|
|
9148
9199
|
const maxRetries = options.maxRetries ?? this.options?.maxRetries ?? 10;
|
|
9149
9200
|
const maxSteps = options.maxSteps ?? this.options?.maxSteps ?? 10;
|
|
9150
|
-
const debug = options.debug ?? ai.getOptions().debug;
|
|
9151
9201
|
const debugHideSystemPrompt = options.debugHideSystemPrompt;
|
|
9152
|
-
const memOptions = {
|
|
9202
|
+
const memOptions = {
|
|
9203
|
+
debug: this.isDebug(ai, options),
|
|
9204
|
+
debugHideSystemPrompt
|
|
9205
|
+
};
|
|
9153
9206
|
const mem = options.mem ?? this.options?.mem ?? new AxMemory(1e4, memOptions);
|
|
9154
9207
|
let err;
|
|
9155
9208
|
if (options?.functions && options.functions.length > 0) {
|
|
@@ -9203,10 +9256,7 @@ Content: ${result.content}`
|
|
|
9203
9256
|
if (shouldContinue) {
|
|
9204
9257
|
continue multiStepLoop;
|
|
9205
9258
|
}
|
|
9206
|
-
|
|
9207
|
-
const logger = options.logger ?? this.logger ?? ai.getLogger();
|
|
9208
|
-
logger("", { tags: ["responseEnd"] });
|
|
9209
|
-
}
|
|
9259
|
+
this.getLogger(ai, options)?.("", { tags: ["responseEnd"] });
|
|
9210
9260
|
return;
|
|
9211
9261
|
} catch (e) {
|
|
9212
9262
|
let errorFields;
|
|
@@ -9348,6 +9398,12 @@ Content: ${result.content}`
|
|
|
9348
9398
|
setExamples(examples, options) {
|
|
9349
9399
|
super.setExamples(examples, options);
|
|
9350
9400
|
}
|
|
9401
|
+
isDebug(ai, options) {
|
|
9402
|
+
return options?.debug ?? this.options?.debug ?? ai.getOptions().debug ?? false;
|
|
9403
|
+
}
|
|
9404
|
+
getLogger(ai, options) {
|
|
9405
|
+
return options?.logger ?? this.options?.logger ?? ai.getLogger();
|
|
9406
|
+
}
|
|
9351
9407
|
};
|
|
9352
9408
|
var AxGenerateError = class extends Error {
|
|
9353
9409
|
details;
|
|
@@ -9480,7 +9536,9 @@ var AxAgent = class {
|
|
|
9480
9536
|
description: definition ?? description
|
|
9481
9537
|
});
|
|
9482
9538
|
for (const agent of agents ?? []) {
|
|
9483
|
-
this.program.register(
|
|
9539
|
+
this.program.register(
|
|
9540
|
+
agent
|
|
9541
|
+
);
|
|
9484
9542
|
}
|
|
9485
9543
|
this.name = name;
|
|
9486
9544
|
this.func = {
|
|
@@ -9984,171 +10042,838 @@ function validateModels2(services) {
|
|
|
9984
10042
|
}
|
|
9985
10043
|
}
|
|
9986
10044
|
|
|
9987
|
-
//
|
|
9988
|
-
|
|
9989
|
-
|
|
9990
|
-
|
|
9991
|
-
|
|
9992
|
-
|
|
9993
|
-
|
|
9994
|
-
|
|
9995
|
-
|
|
9996
|
-
|
|
9997
|
-
|
|
9998
|
-
|
|
9999
|
-
tracer
|
|
10000
|
-
}) {
|
|
10001
|
-
this.name = name;
|
|
10002
|
-
this.fetch = fetch2;
|
|
10003
|
-
this.tracer = tracer;
|
|
10045
|
+
// dsp/optimizer.ts
|
|
10046
|
+
var AxDefaultCostTracker = class {
|
|
10047
|
+
tokenUsage = {};
|
|
10048
|
+
totalTokens = 0;
|
|
10049
|
+
// Configuration options
|
|
10050
|
+
costPerModel;
|
|
10051
|
+
maxCost;
|
|
10052
|
+
maxTokens;
|
|
10053
|
+
constructor(options) {
|
|
10054
|
+
this.costPerModel = options?.costPerModel ?? {};
|
|
10055
|
+
this.maxCost = options?.maxCost;
|
|
10056
|
+
this.maxTokens = options?.maxTokens;
|
|
10004
10057
|
}
|
|
10005
|
-
|
|
10006
|
-
|
|
10007
|
-
|
|
10008
|
-
}
|
|
10009
|
-
if (!this.tracer) {
|
|
10010
|
-
return await this._upsert(req, update);
|
|
10011
|
-
}
|
|
10012
|
-
return await this.tracer.startActiveSpan(
|
|
10013
|
-
"DB Upsert Request",
|
|
10014
|
-
{
|
|
10015
|
-
kind: SpanKind3.SERVER,
|
|
10016
|
-
attributes: {
|
|
10017
|
-
[axSpanAttributes.DB_SYSTEM]: this.name,
|
|
10018
|
-
[axSpanAttributes.DB_OPERATION_NAME]: "upsert",
|
|
10019
|
-
[axSpanAttributes.DB_TABLE]: req.table,
|
|
10020
|
-
[axSpanAttributes.DB_NAMESPACE]: req.namespace,
|
|
10021
|
-
[axSpanAttributes.DB_OPERATION_NAME]: update ? "update" : "insert"
|
|
10022
|
-
}
|
|
10023
|
-
},
|
|
10024
|
-
async (span) => {
|
|
10025
|
-
try {
|
|
10026
|
-
return await this._upsert(req, update, { span });
|
|
10027
|
-
} finally {
|
|
10028
|
-
span.end();
|
|
10029
|
-
}
|
|
10030
|
-
}
|
|
10031
|
-
);
|
|
10058
|
+
trackTokens(count, model) {
|
|
10059
|
+
this.tokenUsage[model] = (this.tokenUsage[model] || 0) + count;
|
|
10060
|
+
this.totalTokens += count;
|
|
10032
10061
|
}
|
|
10033
|
-
|
|
10034
|
-
|
|
10035
|
-
|
|
10036
|
-
|
|
10037
|
-
|
|
10038
|
-
throw new Error("Batch request is empty");
|
|
10039
|
-
}
|
|
10040
|
-
if (!req[0]) {
|
|
10041
|
-
throw new Error("Batch request is invalid first element is undefined");
|
|
10042
|
-
}
|
|
10043
|
-
if (!this.tracer) {
|
|
10044
|
-
return await this._batchUpsert(req, update);
|
|
10062
|
+
getCurrentCost() {
|
|
10063
|
+
let totalCost = 0;
|
|
10064
|
+
for (const [model, tokens] of Object.entries(this.tokenUsage)) {
|
|
10065
|
+
const costPer1K = this.costPerModel[model] || 1e-3;
|
|
10066
|
+
totalCost += tokens / 1e3 * costPer1K;
|
|
10045
10067
|
}
|
|
10046
|
-
return
|
|
10047
|
-
"DB Batch Upsert Request",
|
|
10048
|
-
{
|
|
10049
|
-
kind: SpanKind3.SERVER,
|
|
10050
|
-
attributes: {
|
|
10051
|
-
[axSpanAttributes.DB_SYSTEM]: this.name,
|
|
10052
|
-
[axSpanAttributes.DB_OPERATION_NAME]: "upsert",
|
|
10053
|
-
[axSpanAttributes.DB_TABLE]: req[0].table,
|
|
10054
|
-
[axSpanAttributes.DB_NAMESPACE]: req[0].namespace,
|
|
10055
|
-
[axSpanAttributes.DB_OPERATION_NAME]: update ? "update" : "insert"
|
|
10056
|
-
}
|
|
10057
|
-
},
|
|
10058
|
-
async (span) => {
|
|
10059
|
-
try {
|
|
10060
|
-
return await this._batchUpsert(req, update, { span });
|
|
10061
|
-
} finally {
|
|
10062
|
-
span.end();
|
|
10063
|
-
}
|
|
10064
|
-
}
|
|
10065
|
-
);
|
|
10068
|
+
return totalCost;
|
|
10066
10069
|
}
|
|
10067
|
-
|
|
10068
|
-
|
|
10069
|
-
|
|
10070
|
-
|
|
10071
|
-
|
|
10072
|
-
|
|
10070
|
+
getTokenUsage() {
|
|
10071
|
+
return { ...this.tokenUsage };
|
|
10072
|
+
}
|
|
10073
|
+
getTotalTokens() {
|
|
10074
|
+
return this.totalTokens;
|
|
10075
|
+
}
|
|
10076
|
+
isLimitReached() {
|
|
10077
|
+
if (this.maxTokens !== void 0 && this.totalTokens >= this.maxTokens) {
|
|
10078
|
+
return true;
|
|
10073
10079
|
}
|
|
10074
|
-
|
|
10075
|
-
|
|
10076
|
-
{
|
|
10077
|
-
|
|
10078
|
-
attributes: {
|
|
10079
|
-
[axSpanAttributes.DB_SYSTEM]: this.name,
|
|
10080
|
-
[axSpanAttributes.DB_OPERATION_NAME]: "upsert",
|
|
10081
|
-
[axSpanAttributes.DB_TABLE]: req.table,
|
|
10082
|
-
[axSpanAttributes.DB_NAMESPACE]: req.namespace,
|
|
10083
|
-
[axSpanAttributes.DB_OPERATION_NAME]: "query"
|
|
10084
|
-
}
|
|
10085
|
-
},
|
|
10086
|
-
async (span) => {
|
|
10087
|
-
try {
|
|
10088
|
-
return await this._query(req, { span });
|
|
10089
|
-
} finally {
|
|
10090
|
-
span.end();
|
|
10091
|
-
}
|
|
10080
|
+
if (this.maxCost !== void 0) {
|
|
10081
|
+
const currentCost = this.getCurrentCost();
|
|
10082
|
+
if (currentCost >= this.maxCost) {
|
|
10083
|
+
return true;
|
|
10092
10084
|
}
|
|
10093
|
-
|
|
10085
|
+
}
|
|
10086
|
+
return false;
|
|
10087
|
+
}
|
|
10088
|
+
reset() {
|
|
10089
|
+
this.tokenUsage = {};
|
|
10090
|
+
this.totalTokens = 0;
|
|
10094
10091
|
}
|
|
10095
10092
|
};
|
|
10096
|
-
|
|
10097
|
-
//
|
|
10098
|
-
|
|
10099
|
-
|
|
10100
|
-
|
|
10101
|
-
|
|
10102
|
-
|
|
10103
|
-
|
|
10104
|
-
|
|
10105
|
-
|
|
10106
|
-
|
|
10107
|
-
|
|
10108
|
-
|
|
10109
|
-
|
|
10093
|
+
var AxBaseOptimizer = class {
|
|
10094
|
+
// Common AxOptimizerArgs fields
|
|
10095
|
+
studentAI;
|
|
10096
|
+
teacherAI;
|
|
10097
|
+
examples;
|
|
10098
|
+
validationSet;
|
|
10099
|
+
targetScore;
|
|
10100
|
+
minSuccessRate;
|
|
10101
|
+
onProgress;
|
|
10102
|
+
onEarlyStop;
|
|
10103
|
+
costTracker;
|
|
10104
|
+
seed;
|
|
10105
|
+
// Checkpointing fields
|
|
10106
|
+
checkpointSave;
|
|
10107
|
+
checkpointLoad;
|
|
10108
|
+
checkpointInterval;
|
|
10109
|
+
resumeFromCheckpoint;
|
|
10110
|
+
// Checkpoint state
|
|
10111
|
+
currentRound = 0;
|
|
10112
|
+
scoreHistory = [];
|
|
10113
|
+
configurationHistory = [];
|
|
10114
|
+
// Common optimization statistics
|
|
10115
|
+
stats;
|
|
10116
|
+
constructor(args) {
|
|
10117
|
+
if (args.examples.length === 0) {
|
|
10118
|
+
throw new Error("No examples found");
|
|
10110
10119
|
}
|
|
10111
|
-
|
|
10112
|
-
this.
|
|
10113
|
-
this.
|
|
10120
|
+
this.studentAI = args.studentAI;
|
|
10121
|
+
this.teacherAI = args.teacherAI;
|
|
10122
|
+
this.examples = args.examples;
|
|
10123
|
+
this.validationSet = args.validationSet;
|
|
10124
|
+
this.targetScore = args.targetScore;
|
|
10125
|
+
this.minSuccessRate = args.minSuccessRate;
|
|
10126
|
+
this.onProgress = args.onProgress;
|
|
10127
|
+
this.onEarlyStop = args.onEarlyStop;
|
|
10128
|
+
this.seed = args.seed;
|
|
10129
|
+
this.checkpointSave = args.checkpointSave;
|
|
10130
|
+
this.checkpointLoad = args.checkpointLoad;
|
|
10131
|
+
this.checkpointInterval = args.checkpointInterval ?? 10;
|
|
10132
|
+
this.resumeFromCheckpoint = args.resumeFromCheckpoint;
|
|
10133
|
+
const costTracker = new AxDefaultCostTracker({
|
|
10134
|
+
maxTokens: 1e6
|
|
10135
|
+
});
|
|
10136
|
+
this.costTracker = args.costTracker ?? costTracker;
|
|
10137
|
+
this.stats = this.initializeStats();
|
|
10114
10138
|
}
|
|
10115
|
-
|
|
10116
|
-
|
|
10117
|
-
|
|
10118
|
-
|
|
10119
|
-
|
|
10120
|
-
|
|
10121
|
-
|
|
10122
|
-
|
|
10123
|
-
|
|
10124
|
-
|
|
10125
|
-
|
|
10126
|
-
|
|
10139
|
+
/**
|
|
10140
|
+
* Initialize the optimization statistics structure
|
|
10141
|
+
*/
|
|
10142
|
+
initializeStats() {
|
|
10143
|
+
return {
|
|
10144
|
+
totalCalls: 0,
|
|
10145
|
+
successfulDemos: 0,
|
|
10146
|
+
estimatedTokenUsage: 0,
|
|
10147
|
+
earlyStopped: false,
|
|
10148
|
+
resourceUsage: {
|
|
10149
|
+
totalTokens: 0,
|
|
10150
|
+
totalTime: 0,
|
|
10151
|
+
avgLatencyPerEval: 0,
|
|
10152
|
+
costByModel: {}
|
|
10127
10153
|
},
|
|
10128
|
-
{
|
|
10129
|
-
|
|
10130
|
-
|
|
10131
|
-
|
|
10132
|
-
|
|
10154
|
+
convergenceInfo: {
|
|
10155
|
+
converged: false,
|
|
10156
|
+
finalImprovement: 0,
|
|
10157
|
+
stagnationRounds: 0,
|
|
10158
|
+
convergenceThreshold: 0.01
|
|
10133
10159
|
}
|
|
10134
|
-
);
|
|
10135
|
-
if (res.errors) {
|
|
10136
|
-
throw new Error(
|
|
10137
|
-
`Cloudflare upsert failed: ${res.errors.map(({ message }) => message).join(", ")}`
|
|
10138
|
-
);
|
|
10139
|
-
}
|
|
10140
|
-
return {
|
|
10141
|
-
ids: res.result.ids
|
|
10142
10160
|
};
|
|
10143
|
-
}
|
|
10144
|
-
|
|
10145
|
-
|
|
10146
|
-
|
|
10161
|
+
}
|
|
10162
|
+
/**
|
|
10163
|
+
* Set up reproducible random seed if provided
|
|
10164
|
+
*/
|
|
10165
|
+
setupRandomSeed() {
|
|
10166
|
+
if (this.seed !== void 0) {
|
|
10167
|
+
Math.random = (() => {
|
|
10168
|
+
let seed = this.seed;
|
|
10169
|
+
return () => {
|
|
10170
|
+
seed = (seed * 9301 + 49297) % 233280;
|
|
10171
|
+
return seed / 233280;
|
|
10172
|
+
};
|
|
10173
|
+
})();
|
|
10147
10174
|
}
|
|
10148
|
-
|
|
10149
|
-
|
|
10175
|
+
}
|
|
10176
|
+
/**
|
|
10177
|
+
* Check if optimization should stop early due to cost limits
|
|
10178
|
+
*/
|
|
10179
|
+
checkCostLimits() {
|
|
10180
|
+
return this.costTracker?.isLimitReached() ?? false;
|
|
10181
|
+
}
|
|
10182
|
+
/**
|
|
10183
|
+
* Check if target score has been reached
|
|
10184
|
+
*/
|
|
10185
|
+
checkTargetScore(currentScore) {
|
|
10186
|
+
return this.targetScore !== void 0 && currentScore >= this.targetScore;
|
|
10187
|
+
}
|
|
10188
|
+
/**
|
|
10189
|
+
* Update resource usage statistics
|
|
10190
|
+
*/
|
|
10191
|
+
updateResourceUsage(startTime, tokensUsed = 0) {
|
|
10192
|
+
this.stats.resourceUsage.totalTime = Date.now() - startTime;
|
|
10193
|
+
this.stats.resourceUsage.totalTokens += tokensUsed;
|
|
10194
|
+
if (this.stats.totalCalls > 0) {
|
|
10195
|
+
this.stats.resourceUsage.avgLatencyPerEval = this.stats.resourceUsage.totalTime / this.stats.totalCalls;
|
|
10150
10196
|
}
|
|
10151
|
-
|
|
10197
|
+
}
|
|
10198
|
+
/**
|
|
10199
|
+
* Trigger early stopping with appropriate callbacks
|
|
10200
|
+
*/
|
|
10201
|
+
triggerEarlyStopping(reason, bestScoreRound) {
|
|
10202
|
+
this.stats.earlyStopped = true;
|
|
10203
|
+
this.stats.earlyStopping = {
|
|
10204
|
+
bestScoreRound,
|
|
10205
|
+
patienceExhausted: reason.includes("improvement"),
|
|
10206
|
+
reason
|
|
10207
|
+
};
|
|
10208
|
+
if (this.onEarlyStop) {
|
|
10209
|
+
this.onEarlyStop(reason, this.stats);
|
|
10210
|
+
}
|
|
10211
|
+
}
|
|
10212
|
+
/**
|
|
10213
|
+
* Get the validation set, with fallback to a split of examples
|
|
10214
|
+
*/
|
|
10215
|
+
getValidationSet(options) {
|
|
10216
|
+
return options?.overrideValidationSet || this.validationSet || this.examples.slice(0, Math.floor(this.examples.length * 0.2));
|
|
10217
|
+
}
|
|
10218
|
+
/**
|
|
10219
|
+
* Get the AI service to use for a specific task, preferring teacher when available
|
|
10220
|
+
* @param preferTeacher Whether to prefer teacher AI over student AI
|
|
10221
|
+
* @param options Optional compile options that may override teacher AI
|
|
10222
|
+
* @returns The appropriate AI service to use
|
|
10223
|
+
*/
|
|
10224
|
+
getAIService(preferTeacher = false, options) {
|
|
10225
|
+
if (preferTeacher && options?.overrideTeacherAI) {
|
|
10226
|
+
return options.overrideTeacherAI;
|
|
10227
|
+
}
|
|
10228
|
+
if (preferTeacher && this.teacherAI) {
|
|
10229
|
+
return this.teacherAI;
|
|
10230
|
+
}
|
|
10231
|
+
return this.studentAI;
|
|
10232
|
+
}
|
|
10233
|
+
/**
|
|
10234
|
+
* Check if teacher AI is available (including overrides)
|
|
10235
|
+
* @param options Optional compile options that may override teacher AI
|
|
10236
|
+
* @returns True if teacher AI is configured or overridden
|
|
10237
|
+
*/
|
|
10238
|
+
hasTeacherAI(options) {
|
|
10239
|
+
return options?.overrideTeacherAI !== void 0 || this.teacherAI !== void 0;
|
|
10240
|
+
}
|
|
10241
|
+
/**
|
|
10242
|
+
* Get teacher AI if available, otherwise return student AI
|
|
10243
|
+
* @param options Optional compile options that may override teacher AI
|
|
10244
|
+
* @returns Teacher AI if available, otherwise student AI
|
|
10245
|
+
*/
|
|
10246
|
+
getTeacherOrStudentAI(options) {
|
|
10247
|
+
return options?.overrideTeacherAI || this.teacherAI || this.studentAI;
|
|
10248
|
+
}
|
|
10249
|
+
/**
|
|
10250
|
+
* Execute a task with teacher AI if available, otherwise use student AI
|
|
10251
|
+
* @param task Function that takes an AI service and returns a promise
|
|
10252
|
+
* @param preferTeacher Whether to prefer teacher AI (default: true)
|
|
10253
|
+
* @param options Optional compile options that may override teacher AI
|
|
10254
|
+
* @returns Result of the task execution
|
|
10255
|
+
*/
|
|
10256
|
+
async executeWithTeacher(task, preferTeacher = true, options) {
|
|
10257
|
+
const ai = this.getAIService(preferTeacher, options);
|
|
10258
|
+
return await task(ai);
|
|
10259
|
+
}
|
|
10260
|
+
/**
|
|
10261
|
+
* Get current optimization statistics
|
|
10262
|
+
*/
|
|
10263
|
+
getStats() {
|
|
10264
|
+
return { ...this.stats };
|
|
10265
|
+
}
|
|
10266
|
+
/**
|
|
10267
|
+
* Reset optimizer state for reuse with different programs
|
|
10268
|
+
*/
|
|
10269
|
+
reset() {
|
|
10270
|
+
this.stats = this.initializeStats();
|
|
10271
|
+
this.costTracker?.reset();
|
|
10272
|
+
this.currentRound = 0;
|
|
10273
|
+
this.scoreHistory = [];
|
|
10274
|
+
this.configurationHistory = [];
|
|
10275
|
+
}
|
|
10276
|
+
/**
|
|
10277
|
+
* Basic program validation that can be extended by concrete optimizers
|
|
10278
|
+
*/
|
|
10279
|
+
validateProgram(program) {
|
|
10280
|
+
const issues = [];
|
|
10281
|
+
const suggestions = [];
|
|
10282
|
+
if (!("forward" in program) || typeof program.forward !== "function") {
|
|
10283
|
+
issues.push("Program must have a forward method");
|
|
10284
|
+
}
|
|
10285
|
+
if (this.examples.length < 2) {
|
|
10286
|
+
issues.push("Need at least 2 examples for optimization");
|
|
10287
|
+
suggestions.push("Provide more training examples");
|
|
10288
|
+
}
|
|
10289
|
+
const valSetSize = this.getValidationSet().length;
|
|
10290
|
+
if (valSetSize < 1) {
|
|
10291
|
+
issues.push("Validation set is empty");
|
|
10292
|
+
suggestions.push("Provide examples or a validation set");
|
|
10293
|
+
}
|
|
10294
|
+
return {
|
|
10295
|
+
isValid: issues.length === 0,
|
|
10296
|
+
issues,
|
|
10297
|
+
suggestions
|
|
10298
|
+
};
|
|
10299
|
+
}
|
|
10300
|
+
/**
|
|
10301
|
+
* Multi-objective optimization using Pareto frontier
|
|
10302
|
+
* Default implementation that leverages the single-objective compile method
|
|
10303
|
+
* @param program The program to optimize
|
|
10304
|
+
* @param metricFn Multi-objective metric function that returns multiple scores
|
|
10305
|
+
* @param options Optional configuration options
|
|
10306
|
+
* @returns Pareto optimization result with frontier of non-dominated solutions
|
|
10307
|
+
*/
|
|
10308
|
+
async compilePareto(program, metricFn, options) {
|
|
10309
|
+
const startTime = Date.now();
|
|
10310
|
+
if (options?.verbose) {
|
|
10311
|
+
console.log("Starting Pareto optimization using base implementation");
|
|
10312
|
+
console.log("This will run multiple single-objective optimizations");
|
|
10313
|
+
}
|
|
10314
|
+
const solutions = await this.generateWeightedSolutions(
|
|
10315
|
+
program,
|
|
10316
|
+
metricFn,
|
|
10317
|
+
options
|
|
10318
|
+
);
|
|
10319
|
+
const constraintSolutions = await this.generateConstraintSolutions(
|
|
10320
|
+
program,
|
|
10321
|
+
metricFn,
|
|
10322
|
+
options
|
|
10323
|
+
);
|
|
10324
|
+
const allSolutions = [...solutions, ...constraintSolutions];
|
|
10325
|
+
if (options?.verbose) {
|
|
10326
|
+
console.log(`Generated ${allSolutions.length} candidate solutions`);
|
|
10327
|
+
}
|
|
10328
|
+
const paretoFront = this.findParetoFrontier(allSolutions);
|
|
10329
|
+
const hypervolume = this.calculateHypervolume(paretoFront);
|
|
10330
|
+
if (options?.verbose) {
|
|
10331
|
+
console.log(`Found ${paretoFront.length} non-dominated solutions`);
|
|
10332
|
+
console.log(`Hypervolume: ${hypervolume?.toFixed(4) || "N/A"}`);
|
|
10333
|
+
}
|
|
10334
|
+
this.updateResourceUsage(startTime);
|
|
10335
|
+
this.stats.convergenceInfo.converged = true;
|
|
10336
|
+
const bestScore = paretoFront.length > 0 ? Math.max(
|
|
10337
|
+
...paretoFront.map((sol) => Math.max(...Object.values(sol.scores)))
|
|
10338
|
+
) : 0;
|
|
10339
|
+
return {
|
|
10340
|
+
demos: paretoFront.length > 0 ? [...paretoFront[0].demos] : void 0,
|
|
10341
|
+
stats: this.stats,
|
|
10342
|
+
bestScore,
|
|
10343
|
+
paretoFront,
|
|
10344
|
+
hypervolume,
|
|
10345
|
+
paretoFrontSize: paretoFront.length,
|
|
10346
|
+
finalConfiguration: {
|
|
10347
|
+
paretoFrontSize: paretoFront.length,
|
|
10348
|
+
hypervolume,
|
|
10349
|
+
strategy: "weighted_combinations_and_constraints",
|
|
10350
|
+
numSolutions: allSolutions.length
|
|
10351
|
+
}
|
|
10352
|
+
};
|
|
10353
|
+
}
|
|
10354
|
+
/**
|
|
10355
|
+
* Generate solutions using different weighted combinations of objectives
|
|
10356
|
+
*/
|
|
10357
|
+
async generateWeightedSolutions(program, metricFn, options) {
|
|
10358
|
+
const solutions = [];
|
|
10359
|
+
const sampleExample = this.examples[0];
|
|
10360
|
+
const samplePrediction = await program.forward(
|
|
10361
|
+
this.studentAI,
|
|
10362
|
+
sampleExample
|
|
10363
|
+
);
|
|
10364
|
+
const sampleScores = await metricFn({
|
|
10365
|
+
prediction: samplePrediction,
|
|
10366
|
+
example: sampleExample
|
|
10367
|
+
});
|
|
10368
|
+
const objectives = Object.keys(sampleScores);
|
|
10369
|
+
if (options?.verbose) {
|
|
10370
|
+
console.log(`Detected objectives: ${objectives.join(", ")}`);
|
|
10371
|
+
}
|
|
10372
|
+
const weightCombinations = this.generateWeightCombinations(objectives);
|
|
10373
|
+
for (let i = 0; i < weightCombinations.length; i++) {
|
|
10374
|
+
const weights = weightCombinations[i];
|
|
10375
|
+
if (options?.verbose) {
|
|
10376
|
+
console.log(`Optimizing with weights: ${JSON.stringify(weights)}`);
|
|
10377
|
+
}
|
|
10378
|
+
const weightedMetric = async ({ prediction, example }) => {
|
|
10379
|
+
const scores = await metricFn({ prediction, example });
|
|
10380
|
+
let weightedScore = 0;
|
|
10381
|
+
for (const [objective, score] of Object.entries(scores)) {
|
|
10382
|
+
weightedScore += score * (weights[objective] || 0);
|
|
10383
|
+
}
|
|
10384
|
+
return weightedScore;
|
|
10385
|
+
};
|
|
10386
|
+
try {
|
|
10387
|
+
const result = await this.compile(program, weightedMetric, {
|
|
10388
|
+
...options,
|
|
10389
|
+
verbose: false
|
|
10390
|
+
// Suppress inner optimization logs
|
|
10391
|
+
});
|
|
10392
|
+
const scores = await this.evaluateWithMultiObjective(
|
|
10393
|
+
program,
|
|
10394
|
+
result,
|
|
10395
|
+
metricFn
|
|
10396
|
+
);
|
|
10397
|
+
solutions.push({
|
|
10398
|
+
scores,
|
|
10399
|
+
demos: result.demos,
|
|
10400
|
+
configuration: {
|
|
10401
|
+
...result.finalConfiguration,
|
|
10402
|
+
weights,
|
|
10403
|
+
strategy: "weighted_combination"
|
|
10404
|
+
}
|
|
10405
|
+
});
|
|
10406
|
+
} catch (error) {
|
|
10407
|
+
if (options?.verbose) {
|
|
10408
|
+
console.warn(
|
|
10409
|
+
`Failed optimization with weights ${JSON.stringify(weights)}:`,
|
|
10410
|
+
error
|
|
10411
|
+
);
|
|
10412
|
+
}
|
|
10413
|
+
continue;
|
|
10414
|
+
}
|
|
10415
|
+
}
|
|
10416
|
+
return solutions;
|
|
10417
|
+
}
|
|
10418
|
+
/**
|
|
10419
|
+
* Generate solutions using constraint-based optimization
|
|
10420
|
+
*/
|
|
10421
|
+
async generateConstraintSolutions(program, metricFn, options) {
|
|
10422
|
+
const solutions = [];
|
|
10423
|
+
const sampleExample = this.examples[0];
|
|
10424
|
+
const samplePrediction = await program.forward(
|
|
10425
|
+
this.studentAI,
|
|
10426
|
+
sampleExample
|
|
10427
|
+
);
|
|
10428
|
+
const sampleScores = await metricFn({
|
|
10429
|
+
prediction: samplePrediction,
|
|
10430
|
+
example: sampleExample
|
|
10431
|
+
});
|
|
10432
|
+
const objectives = Object.keys(sampleScores);
|
|
10433
|
+
for (const primaryObjective of objectives) {
|
|
10434
|
+
if (options?.verbose) {
|
|
10435
|
+
console.log(
|
|
10436
|
+
`Optimizing ${primaryObjective} with constraints on other objectives`
|
|
10437
|
+
);
|
|
10438
|
+
}
|
|
10439
|
+
const constraintMetric = async ({ prediction, example }) => {
|
|
10440
|
+
const scores = await metricFn({ prediction, example });
|
|
10441
|
+
const primaryScore = scores[primaryObjective] || 0;
|
|
10442
|
+
let penalty = 0;
|
|
10443
|
+
for (const [objective, score] of Object.entries(scores)) {
|
|
10444
|
+
if (objective !== primaryObjective) {
|
|
10445
|
+
if (score < 0.3) {
|
|
10446
|
+
penalty += (0.3 - score) * 2;
|
|
10447
|
+
}
|
|
10448
|
+
}
|
|
10449
|
+
}
|
|
10450
|
+
return primaryScore - penalty;
|
|
10451
|
+
};
|
|
10452
|
+
try {
|
|
10453
|
+
const result = await this.compile(program, constraintMetric, {
|
|
10454
|
+
...options,
|
|
10455
|
+
verbose: false
|
|
10456
|
+
});
|
|
10457
|
+
const scores = await this.evaluateWithMultiObjective(
|
|
10458
|
+
program,
|
|
10459
|
+
result,
|
|
10460
|
+
metricFn
|
|
10461
|
+
);
|
|
10462
|
+
solutions.push({
|
|
10463
|
+
scores,
|
|
10464
|
+
demos: result.demos,
|
|
10465
|
+
configuration: {
|
|
10466
|
+
...result.finalConfiguration,
|
|
10467
|
+
primaryObjective,
|
|
10468
|
+
strategy: "constraint_based"
|
|
10469
|
+
}
|
|
10470
|
+
});
|
|
10471
|
+
} catch (error) {
|
|
10472
|
+
if (options?.verbose) {
|
|
10473
|
+
console.warn(
|
|
10474
|
+
`Failed constraint optimization for ${primaryObjective}:`,
|
|
10475
|
+
error
|
|
10476
|
+
);
|
|
10477
|
+
}
|
|
10478
|
+
continue;
|
|
10479
|
+
}
|
|
10480
|
+
}
|
|
10481
|
+
return solutions;
|
|
10482
|
+
}
|
|
10483
|
+
/**
|
|
10484
|
+
* Generate different weight combinations for objectives
|
|
10485
|
+
*/
|
|
10486
|
+
generateWeightCombinations(objectives) {
|
|
10487
|
+
const combinations = [];
|
|
10488
|
+
for (const objective of objectives) {
|
|
10489
|
+
const weights = {};
|
|
10490
|
+
for (const obj of objectives) {
|
|
10491
|
+
weights[obj] = obj === objective ? 1 : 0;
|
|
10492
|
+
}
|
|
10493
|
+
combinations.push(weights);
|
|
10494
|
+
}
|
|
10495
|
+
const equalWeights = {};
|
|
10496
|
+
for (const objective of objectives) {
|
|
10497
|
+
equalWeights[objective] = 1 / objectives.length;
|
|
10498
|
+
}
|
|
10499
|
+
combinations.push(equalWeights);
|
|
10500
|
+
if (objectives.length === 2) {
|
|
10501
|
+
const [obj1, obj2] = objectives;
|
|
10502
|
+
for (let w1 = 0.1; w1 <= 0.9; w1 += 0.2) {
|
|
10503
|
+
const w2 = 1 - w1;
|
|
10504
|
+
combinations.push({ [obj1]: w1, [obj2]: w2 });
|
|
10505
|
+
}
|
|
10506
|
+
}
|
|
10507
|
+
if (objectives.length === 3) {
|
|
10508
|
+
const [obj1, obj2, obj3] = objectives;
|
|
10509
|
+
combinations.push(
|
|
10510
|
+
{ [obj1]: 0.5, [obj2]: 0.3, [obj3]: 0.2 },
|
|
10511
|
+
{ [obj1]: 0.3, [obj2]: 0.5, [obj3]: 0.2 },
|
|
10512
|
+
{ [obj1]: 0.2, [obj2]: 0.3, [obj3]: 0.5 }
|
|
10513
|
+
);
|
|
10514
|
+
}
|
|
10515
|
+
return combinations;
|
|
10516
|
+
}
|
|
10517
|
+
/**
|
|
10518
|
+
* Evaluate a single-objective result with multi-objective metrics
|
|
10519
|
+
*/
|
|
10520
|
+
async evaluateWithMultiObjective(program, result, metricFn) {
|
|
10521
|
+
const valSet = this.getValidationSet();
|
|
10522
|
+
const allScores = {};
|
|
10523
|
+
const testProgram = { ...program };
|
|
10524
|
+
if (result.demos && "setDemos" in testProgram) {
|
|
10525
|
+
;
|
|
10526
|
+
testProgram.setDemos(result.demos);
|
|
10527
|
+
}
|
|
10528
|
+
const evalSet = valSet.slice(0, Math.min(5, valSet.length));
|
|
10529
|
+
for (const example of evalSet) {
|
|
10530
|
+
try {
|
|
10531
|
+
const prediction = await testProgram.forward(
|
|
10532
|
+
this.studentAI,
|
|
10533
|
+
example
|
|
10534
|
+
);
|
|
10535
|
+
const scores = await metricFn({ prediction, example });
|
|
10536
|
+
for (const [objective, score] of Object.entries(scores)) {
|
|
10537
|
+
if (!allScores[objective]) {
|
|
10538
|
+
allScores[objective] = [];
|
|
10539
|
+
}
|
|
10540
|
+
allScores[objective].push(score);
|
|
10541
|
+
}
|
|
10542
|
+
} catch {
|
|
10543
|
+
continue;
|
|
10544
|
+
}
|
|
10545
|
+
}
|
|
10546
|
+
const avgScores = {};
|
|
10547
|
+
for (const [objective, scores] of Object.entries(allScores)) {
|
|
10548
|
+
avgScores[objective] = scores.length > 0 ? scores.reduce((sum, score) => sum + score, 0) / scores.length : 0;
|
|
10549
|
+
}
|
|
10550
|
+
return avgScores;
|
|
10551
|
+
}
|
|
10552
|
+
/**
|
|
10553
|
+
* Find the Pareto frontier from a set of solutions
|
|
10554
|
+
*/
|
|
10555
|
+
findParetoFrontier(solutions) {
|
|
10556
|
+
const paretoFront = [];
|
|
10557
|
+
for (let i = 0; i < solutions.length; i++) {
|
|
10558
|
+
const solutionA = solutions[i];
|
|
10559
|
+
let isDominated = false;
|
|
10560
|
+
let dominatedCount = 0;
|
|
10561
|
+
for (let j = 0; j < solutions.length; j++) {
|
|
10562
|
+
if (i === j) continue;
|
|
10563
|
+
const solutionB = solutions[j];
|
|
10564
|
+
if (this.dominates(solutionB.scores, solutionA.scores)) {
|
|
10565
|
+
isDominated = true;
|
|
10566
|
+
break;
|
|
10567
|
+
}
|
|
10568
|
+
if (this.dominates(solutionA.scores, solutionB.scores)) {
|
|
10569
|
+
dominatedCount++;
|
|
10570
|
+
}
|
|
10571
|
+
}
|
|
10572
|
+
if (!isDominated) {
|
|
10573
|
+
paretoFront.push({
|
|
10574
|
+
demos: solutionA.demos || [],
|
|
10575
|
+
scores: solutionA.scores,
|
|
10576
|
+
configuration: solutionA.configuration,
|
|
10577
|
+
dominatedSolutions: dominatedCount
|
|
10578
|
+
});
|
|
10579
|
+
}
|
|
10580
|
+
}
|
|
10581
|
+
return paretoFront;
|
|
10582
|
+
}
|
|
10583
|
+
/**
|
|
10584
|
+
* Check if solution A dominates solution B
|
|
10585
|
+
* A dominates B if A is better or equal in all objectives and strictly better in at least one
|
|
10586
|
+
*/
|
|
10587
|
+
dominates(scoresA, scoresB) {
|
|
10588
|
+
const objectives = Object.keys(scoresA);
|
|
10589
|
+
let atLeastAsGood = true;
|
|
10590
|
+
let strictlyBetter = false;
|
|
10591
|
+
for (const objective of objectives) {
|
|
10592
|
+
const scoreA = scoresA[objective] || 0;
|
|
10593
|
+
const scoreB = scoresB[objective] || 0;
|
|
10594
|
+
if (scoreA < scoreB) {
|
|
10595
|
+
atLeastAsGood = false;
|
|
10596
|
+
break;
|
|
10597
|
+
}
|
|
10598
|
+
if (scoreA > scoreB) {
|
|
10599
|
+
strictlyBetter = true;
|
|
10600
|
+
}
|
|
10601
|
+
}
|
|
10602
|
+
return atLeastAsGood && strictlyBetter;
|
|
10603
|
+
}
|
|
10604
|
+
/**
|
|
10605
|
+
* Calculate hypervolume of the Pareto frontier
|
|
10606
|
+
* Simplified implementation using reference point at origin
|
|
10607
|
+
*/
|
|
10608
|
+
calculateHypervolume(paretoFront) {
|
|
10609
|
+
if (paretoFront.length === 0) return void 0;
|
|
10610
|
+
const firstSolution = paretoFront[0];
|
|
10611
|
+
const objectives = Object.keys(firstSolution.scores);
|
|
10612
|
+
if (objectives.length === 2) {
|
|
10613
|
+
const [obj1, obj2] = objectives;
|
|
10614
|
+
let hypervolume = 0;
|
|
10615
|
+
const sortedSolutions = [...paretoFront].sort(
|
|
10616
|
+
(a, b) => (b.scores[obj1] || 0) - (a.scores[obj1] || 0)
|
|
10617
|
+
);
|
|
10618
|
+
let prevScore2 = 0;
|
|
10619
|
+
for (const solution of sortedSolutions) {
|
|
10620
|
+
const score1 = solution.scores[obj1] || 0;
|
|
10621
|
+
const score2 = solution.scores[obj2] || 0;
|
|
10622
|
+
hypervolume += score1 * (score2 - prevScore2);
|
|
10623
|
+
prevScore2 = Math.max(prevScore2, score2);
|
|
10624
|
+
}
|
|
10625
|
+
return hypervolume;
|
|
10626
|
+
}
|
|
10627
|
+
return void 0;
|
|
10628
|
+
}
|
|
10629
|
+
/**
|
|
10630
|
+
* Save current optimization state to checkpoint
|
|
10631
|
+
*/
|
|
10632
|
+
async saveCheckpoint(optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
|
|
10633
|
+
const saveFn = options?.overrideCheckpointSave || this.checkpointSave;
|
|
10634
|
+
if (!saveFn) return void 0;
|
|
10635
|
+
const checkpoint = {
|
|
10636
|
+
version: "1.0.0",
|
|
10637
|
+
timestamp: Date.now(),
|
|
10638
|
+
optimizerType,
|
|
10639
|
+
optimizerConfig,
|
|
10640
|
+
currentRound: this.currentRound,
|
|
10641
|
+
totalRounds: this.stats.resourceUsage.totalTime > 0 ? this.currentRound : 0,
|
|
10642
|
+
bestScore,
|
|
10643
|
+
bestConfiguration,
|
|
10644
|
+
scoreHistory: [...this.scoreHistory],
|
|
10645
|
+
configurationHistory: [...this.configurationHistory],
|
|
10646
|
+
stats: { ...this.stats },
|
|
10647
|
+
optimizerState,
|
|
10648
|
+
examples: this.examples,
|
|
10649
|
+
validationSet: this.validationSet
|
|
10650
|
+
};
|
|
10651
|
+
return await saveFn(checkpoint);
|
|
10652
|
+
}
|
|
10653
|
+
/**
|
|
10654
|
+
* Load optimization state from checkpoint
|
|
10655
|
+
*/
|
|
10656
|
+
async loadCheckpoint(checkpointId, options) {
|
|
10657
|
+
const loadFn = options?.overrideCheckpointLoad || this.checkpointLoad;
|
|
10658
|
+
if (!loadFn) return null;
|
|
10659
|
+
return await loadFn(checkpointId);
|
|
10660
|
+
}
|
|
10661
|
+
/**
|
|
10662
|
+
* Restore optimizer state from checkpoint
|
|
10663
|
+
*/
|
|
10664
|
+
restoreFromCheckpoint(checkpoint) {
|
|
10665
|
+
this.currentRound = checkpoint.currentRound;
|
|
10666
|
+
this.scoreHistory = [...checkpoint.scoreHistory];
|
|
10667
|
+
this.configurationHistory = [...checkpoint.configurationHistory];
|
|
10668
|
+
this.stats = { ...checkpoint.stats };
|
|
10669
|
+
}
|
|
10670
|
+
/**
|
|
10671
|
+
* Check if checkpoint should be saved
|
|
10672
|
+
*/
|
|
10673
|
+
shouldSaveCheckpoint(round, options) {
|
|
10674
|
+
const interval = options?.overrideCheckpointInterval || this.checkpointInterval;
|
|
10675
|
+
return interval !== void 0 && round % interval === 0;
|
|
10676
|
+
}
|
|
10677
|
+
/**
|
|
10678
|
+
* Update optimization progress and handle checkpointing
|
|
10679
|
+
*/
|
|
10680
|
+
async updateOptimizationProgress(round, score, configuration, optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
|
|
10681
|
+
this.currentRound = round;
|
|
10682
|
+
this.scoreHistory.push(score);
|
|
10683
|
+
this.configurationHistory.push(configuration);
|
|
10684
|
+
if (this.shouldSaveCheckpoint(round, options)) {
|
|
10685
|
+
await this.saveCheckpoint(
|
|
10686
|
+
optimizerType,
|
|
10687
|
+
optimizerConfig,
|
|
10688
|
+
bestScore,
|
|
10689
|
+
bestConfiguration,
|
|
10690
|
+
optimizerState,
|
|
10691
|
+
options
|
|
10692
|
+
);
|
|
10693
|
+
}
|
|
10694
|
+
}
|
|
10695
|
+
/**
|
|
10696
|
+
* Save final checkpoint on completion
|
|
10697
|
+
*/
|
|
10698
|
+
async saveFinalCheckpoint(optimizerType, optimizerConfig, bestScore, bestConfiguration, optimizerState = {}, options) {
|
|
10699
|
+
if (options?.saveCheckpointOnComplete !== false) {
|
|
10700
|
+
await this.saveCheckpoint(
|
|
10701
|
+
optimizerType,
|
|
10702
|
+
optimizerConfig,
|
|
10703
|
+
bestScore,
|
|
10704
|
+
bestConfiguration,
|
|
10705
|
+
{ ...optimizerState, final: true },
|
|
10706
|
+
options
|
|
10707
|
+
);
|
|
10708
|
+
}
|
|
10709
|
+
}
|
|
10710
|
+
};
|
|
10711
|
+
|
|
10712
|
+
// db/base.ts
|
|
10713
|
+
import { SpanKind as SpanKind3 } from "@opentelemetry/api";
|
|
10714
|
+
/**
 * Base class for vector-database adapters.
 *
 * Subclasses provide `_upsert`, `_batchUpsert` and/or `_query`. The public
 * `upsert`, `batchUpsert` and `query` methods validate the request and call
 * the matching implementation. When a tracer is configured, the call is
 * wrapped in a span that is always ended.
 */
var AxDBBase = class {
  name;
  fetch;
  tracer;
  _upsert;
  _batchUpsert;
  _query;
  /**
   * @param args.name Database system name, recorded as the DB system span attribute
   * @param args.fetch Fetch implementation stored for use by subclasses
   * @param args.tracer Optional tracer; without it calls are not traced
   */
  constructor({
    name,
    fetch: fetch2,
    tracer
  }) {
    this.name = name;
    this.fetch = fetch2;
    this.tracer = tracer;
  }
  /**
   * Run `run(span)` inside an active span named `spanName`, ending the span
   * whether `run` resolves or rejects. `attributes` are added after the DB
   * system attribute.
   */
  async #withSpan(spanName, attributes, run) {
    return await this.tracer.startActiveSpan(
      spanName,
      {
        kind: SpanKind3.SERVER,
        attributes: {
          [axSpanAttributes.DB_SYSTEM]: this.name,
          ...attributes
        }
      },
      async (span) => {
        try {
          return await run(span);
        } finally {
          span.end();
        }
      }
    );
  }
  /**
   * Insert (or, when `update` is truthy, update) a single record.
   * @throws Error if the subclass does not implement `_upsert`
   */
  async upsert(req, update) {
    if (!this._upsert) {
      throw new Error("upsert() not implemented");
    }
    if (!this.tracer) {
      return await this._upsert(req, update);
    }
    // Each attribute key appears once; earlier versions listed the operation
    // name twice, and the first value was dead code.
    return await this.#withSpan(
      "DB Upsert Request",
      {
        [axSpanAttributes.DB_OPERATION_NAME]: update ? "update" : "insert",
        [axSpanAttributes.DB_TABLE]: req.table,
        [axSpanAttributes.DB_NAMESPACE]: req.namespace
      },
      (span) => this._upsert(req, update, { span })
    );
  }
  /**
   * Insert (or update) a batch of records. The span's table and namespace are
   * taken from the first request in the batch.
   * @throws Error if `_batchUpsert` is missing, the batch is empty, or its
   *   first element is undefined
   */
  async batchUpsert(req, update) {
    if (!this._batchUpsert) {
      throw new Error("batchUpsert() not implemented");
    }
    if (req.length === 0) {
      throw new Error("Batch request is empty");
    }
    if (!req[0]) {
      throw new Error("Batch request is invalid first element is undefined");
    }
    if (!this.tracer) {
      return await this._batchUpsert(req, update);
    }
    return await this.#withSpan(
      "DB Batch Upsert Request",
      {
        [axSpanAttributes.DB_OPERATION_NAME]: update ? "update" : "insert",
        [axSpanAttributes.DB_TABLE]: req[0].table,
        [axSpanAttributes.DB_NAMESPACE]: req[0].namespace
      },
      (span) => this._batchUpsert(req, update, { span })
    );
  }
  /**
   * Run a query.
   * @throws Error if the subclass does not implement `_query`
   */
  async query(req) {
    if (!this._query) {
      throw new Error("query() not implemented");
    }
    if (!this.tracer) {
      return await this._query(req);
    }
    return await this.#withSpan(
      "DB Query Request",
      {
        [axSpanAttributes.DB_OPERATION_NAME]: "query",
        [axSpanAttributes.DB_TABLE]: req.table,
        [axSpanAttributes.DB_NAMESPACE]: req.namespace
      },
      (span) => this._query(req, { span })
    );
  }
};
|
|
10821
|
+
|
|
10822
|
+
// db/cloudflare.ts
|
|
10823
|
+
// Base of the Cloudflare REST API (v4) account endpoints; request paths such
// as `${accountId}/vectorize/indexes/...` are resolved against it with
// `new URL(path, baseURL)`.
var baseURL = "https://api.cloudflare.com/client/v4/accounts/";
|
|
10824
|
+
var AxDBCloudflare = class extends AxDBBase {
|
|
10825
|
+
apiKey;
|
|
10826
|
+
accountId;
|
|
10827
|
+
constructor({
|
|
10828
|
+
apiKey,
|
|
10829
|
+
accountId,
|
|
10830
|
+
fetch: fetch2,
|
|
10831
|
+
tracer
|
|
10832
|
+
}) {
|
|
10833
|
+
if (!apiKey || !accountId) {
|
|
10834
|
+
throw new Error("Cloudflare credentials not set");
|
|
10835
|
+
}
|
|
10836
|
+
super({ name: "Cloudflare", fetch: fetch2, tracer });
|
|
10837
|
+
this.apiKey = apiKey;
|
|
10838
|
+
this.accountId = accountId;
|
|
10839
|
+
}
|
|
10840
|
+
_upsert = async (req, _update, options) => {
|
|
10841
|
+
const res = await apiCall(
|
|
10842
|
+
{
|
|
10843
|
+
url: new URL(
|
|
10844
|
+
`${this.accountId}/vectorize/indexes/${req.table}/upsert`,
|
|
10845
|
+
baseURL
|
|
10846
|
+
),
|
|
10847
|
+
headers: {
|
|
10848
|
+
"X-Auth-Key": this.apiKey
|
|
10849
|
+
},
|
|
10850
|
+
fetch: this.fetch,
|
|
10851
|
+
span: options?.span
|
|
10852
|
+
},
|
|
10853
|
+
{
|
|
10854
|
+
id: req.id,
|
|
10855
|
+
values: req.values,
|
|
10856
|
+
namespace: req.namespace,
|
|
10857
|
+
metadata: req.metadata
|
|
10858
|
+
}
|
|
10859
|
+
);
|
|
10860
|
+
if (res.errors) {
|
|
10861
|
+
throw new Error(
|
|
10862
|
+
`Cloudflare upsert failed: ${res.errors.map(({ message }) => message).join(", ")}`
|
|
10863
|
+
);
|
|
10864
|
+
}
|
|
10865
|
+
return {
|
|
10866
|
+
ids: res.result.ids
|
|
10867
|
+
};
|
|
10868
|
+
};
|
|
10869
|
+
batchUpsert = async (batchReq, update, options) => {
|
|
10870
|
+
if (update) {
|
|
10871
|
+
throw new Error("Weaviate does not support batch update");
|
|
10872
|
+
}
|
|
10873
|
+
if (batchReq.length < 1) {
|
|
10874
|
+
throw new Error("Batch request is empty");
|
|
10875
|
+
}
|
|
10876
|
+
if (!batchReq[0] || !batchReq[0].table) {
|
|
10152
10877
|
throw new Error("Table name is empty");
|
|
10153
10878
|
}
|
|
10154
10879
|
const table2 = batchReq[0].table;
|
|
@@ -11443,11 +12168,7 @@ var AxMCPStreambleHTTPTransport = class {
|
|
|
11443
12168
|
};
|
|
11444
12169
|
|
|
11445
12170
|
// dsp/optimizers/bootstrapFewshot.ts
|
|
11446
|
-
var AxBootstrapFewShot = class {
|
|
11447
|
-
ai;
|
|
11448
|
-
teacherAI;
|
|
11449
|
-
program;
|
|
11450
|
-
examples;
|
|
12171
|
+
var AxBootstrapFewShot = class extends AxBaseOptimizer {
|
|
11451
12172
|
maxRounds;
|
|
11452
12173
|
maxDemos;
|
|
11453
12174
|
maxExamples;
|
|
@@ -11458,37 +12179,20 @@ var AxBootstrapFewShot = class {
|
|
|
11458
12179
|
verboseMode;
|
|
11459
12180
|
debugMode;
|
|
11460
12181
|
traces = [];
|
|
11461
|
-
|
|
11462
|
-
|
|
11463
|
-
|
|
11464
|
-
|
|
11465
|
-
|
|
11466
|
-
|
|
11467
|
-
|
|
11468
|
-
|
|
11469
|
-
|
|
11470
|
-
|
|
11471
|
-
options
|
|
11472
|
-
|
|
11473
|
-
|
|
11474
|
-
|
|
11475
|
-
}
|
|
11476
|
-
const bootstrapOptions = options;
|
|
11477
|
-
this.maxRounds = bootstrapOptions?.maxRounds ?? 3;
|
|
11478
|
-
this.maxDemos = bootstrapOptions?.maxDemos ?? 4;
|
|
11479
|
-
this.maxExamples = bootstrapOptions?.maxExamples ?? 16;
|
|
11480
|
-
this.batchSize = bootstrapOptions?.batchSize ?? 1;
|
|
11481
|
-
this.earlyStoppingPatience = bootstrapOptions?.earlyStoppingPatience ?? 0;
|
|
11482
|
-
this.costMonitoring = bootstrapOptions?.costMonitoring ?? false;
|
|
11483
|
-
this.maxTokensPerGeneration = bootstrapOptions?.maxTokensPerGeneration ?? 0;
|
|
11484
|
-
this.verboseMode = bootstrapOptions?.verboseMode ?? true;
|
|
11485
|
-
this.debugMode = bootstrapOptions?.debugMode ?? false;
|
|
11486
|
-
this.ai = ai;
|
|
11487
|
-
this.teacherAI = bootstrapOptions?.teacherAI;
|
|
11488
|
-
this.program = program;
|
|
11489
|
-
this.examples = examples;
|
|
11490
|
-
}
|
|
11491
|
-
async compileRound(roundIndex, metricFn, options) {
|
|
12182
|
+
constructor(args) {
|
|
12183
|
+
super(args);
|
|
12184
|
+
const options = args.options || {};
|
|
12185
|
+
this.maxRounds = options.maxRounds ?? 3;
|
|
12186
|
+
this.maxDemos = options.maxDemos ?? 4;
|
|
12187
|
+
this.maxExamples = options.maxExamples ?? 16;
|
|
12188
|
+
this.batchSize = options.batchSize ?? 1;
|
|
12189
|
+
this.earlyStoppingPatience = options.earlyStoppingPatience ?? 0;
|
|
12190
|
+
this.costMonitoring = options.costMonitoring ?? false;
|
|
12191
|
+
this.maxTokensPerGeneration = options.maxTokensPerGeneration ?? 0;
|
|
12192
|
+
this.verboseMode = options.verboseMode ?? true;
|
|
12193
|
+
this.debugMode = options.debugMode ?? false;
|
|
12194
|
+
}
|
|
12195
|
+
async compileRound(program, roundIndex, metricFn, options) {
|
|
11492
12196
|
const st = (/* @__PURE__ */ new Date()).getTime();
|
|
11493
12197
|
const maxDemos = options?.maxDemos ?? this.maxDemos;
|
|
11494
12198
|
const aiOpt = {
|
|
@@ -11511,20 +12215,20 @@ var AxBootstrapFewShot = class {
|
|
|
11511
12215
|
continue;
|
|
11512
12216
|
}
|
|
11513
12217
|
const exList = examples.filter((e) => e !== ex);
|
|
11514
|
-
|
|
11515
|
-
const aiService = this.
|
|
12218
|
+
program.setExamples(exList);
|
|
12219
|
+
const aiService = this.getTeacherOrStudentAI();
|
|
11516
12220
|
this.stats.totalCalls++;
|
|
11517
12221
|
let res;
|
|
11518
12222
|
let error;
|
|
11519
12223
|
try {
|
|
11520
|
-
res = await
|
|
12224
|
+
res = await program.forward(aiService, ex, aiOpt);
|
|
11521
12225
|
if (this.costMonitoring) {
|
|
11522
12226
|
this.stats.estimatedTokenUsage += JSON.stringify(ex).length / 4 + JSON.stringify(res).length / 4;
|
|
11523
12227
|
}
|
|
11524
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
12228
|
+
const score = await metricFn({ prediction: res, example: ex });
|
|
11525
12229
|
const success = score >= 0.5;
|
|
11526
12230
|
if (success) {
|
|
11527
|
-
this.traces = [...this.traces, ...
|
|
12231
|
+
this.traces = [...this.traces, ...program.getTraces()];
|
|
11528
12232
|
this.stats.successfulDemos++;
|
|
11529
12233
|
}
|
|
11530
12234
|
} catch (err) {
|
|
@@ -11575,13 +12279,15 @@ var AxBootstrapFewShot = class {
|
|
|
11575
12279
|
if (!this.stats.earlyStopping) {
|
|
11576
12280
|
this.stats.earlyStopping = {
|
|
11577
12281
|
bestScoreRound: improvement > 0 ? roundIndex : 0,
|
|
11578
|
-
patienceExhausted: false
|
|
12282
|
+
patienceExhausted: false,
|
|
12283
|
+
reason: "No improvement detected"
|
|
11579
12284
|
};
|
|
11580
12285
|
} else if (improvement > 0) {
|
|
11581
12286
|
this.stats.earlyStopping.bestScoreRound = roundIndex;
|
|
11582
12287
|
} else if (roundIndex - this.stats.earlyStopping.bestScoreRound >= this.earlyStoppingPatience) {
|
|
11583
12288
|
this.stats.earlyStopping.patienceExhausted = true;
|
|
11584
12289
|
this.stats.earlyStopped = true;
|
|
12290
|
+
this.stats.earlyStopping.reason = `No improvement for ${this.earlyStoppingPatience} rounds`;
|
|
11585
12291
|
if (this.verboseMode || this.debugMode) {
|
|
11586
12292
|
console.log(
|
|
11587
12293
|
`
|
|
@@ -11592,37 +12298,38 @@ Early stopping triggered after ${roundIndex + 1} rounds. No improvement for ${th
|
|
|
11592
12298
|
}
|
|
11593
12299
|
}
|
|
11594
12300
|
}
|
|
11595
|
-
async compile(metricFn, options) {
|
|
11596
|
-
const
|
|
11597
|
-
const maxRounds = compileOptions?.maxRounds ?? this.maxRounds;
|
|
12301
|
+
async compile(program, metricFn, options) {
|
|
12302
|
+
const maxRounds = options?.maxIterations ?? this.maxRounds;
|
|
11598
12303
|
this.traces = [];
|
|
11599
|
-
this.
|
|
11600
|
-
totalCalls: 0,
|
|
11601
|
-
successfulDemos: 0,
|
|
11602
|
-
estimatedTokenUsage: 0,
|
|
11603
|
-
earlyStopped: false
|
|
11604
|
-
};
|
|
12304
|
+
this.reset();
|
|
11605
12305
|
for (let i = 0; i < maxRounds; i++) {
|
|
11606
|
-
await this.compileRound(i, metricFn,
|
|
12306
|
+
await this.compileRound(program, i, metricFn, options);
|
|
11607
12307
|
if (this.stats.earlyStopped) {
|
|
11608
12308
|
break;
|
|
11609
12309
|
}
|
|
11610
12310
|
}
|
|
11611
12311
|
if (this.traces.length === 0) {
|
|
11612
12312
|
throw new Error(
|
|
11613
|
-
"No demonstrations found. Either
|
|
12313
|
+
"No demonstrations found. Either provide more examples or improve the existing ones."
|
|
11614
12314
|
);
|
|
11615
12315
|
}
|
|
11616
12316
|
const demos = groupTracesByKeys(this.traces);
|
|
12317
|
+
let bestScore = 0;
|
|
12318
|
+
if (this.traces.length > 0) {
|
|
12319
|
+
bestScore = this.stats.successfulDemos / Math.max(1, this.stats.totalCalls);
|
|
12320
|
+
}
|
|
11617
12321
|
return {
|
|
11618
12322
|
demos,
|
|
11619
|
-
stats: this.stats
|
|
12323
|
+
stats: this.stats,
|
|
12324
|
+
bestScore,
|
|
12325
|
+
finalConfiguration: {
|
|
12326
|
+
maxRounds: this.maxRounds,
|
|
12327
|
+
maxDemos: this.maxDemos,
|
|
12328
|
+
batchSize: this.batchSize,
|
|
12329
|
+
successRate: bestScore
|
|
12330
|
+
}
|
|
11620
12331
|
};
|
|
11621
12332
|
}
|
|
11622
|
-
// Get optimization statistics
|
|
11623
|
-
getStats() {
|
|
11624
|
-
return this.stats;
|
|
11625
|
-
}
|
|
11626
12333
|
};
|
|
11627
12334
|
function groupTracesByKeys(programTraces) {
|
|
11628
12335
|
const groupedTraces = /* @__PURE__ */ new Map();
|
|
@@ -11637,9 +12344,12 @@ function groupTracesByKeys(programTraces) {
|
|
|
11637
12344
|
}
|
|
11638
12345
|
}
|
|
11639
12346
|
const programDemosArray = [];
|
|
11640
|
-
|
|
11641
|
-
programDemosArray.push({
|
|
11642
|
-
|
|
12347
|
+
groupedTraces.forEach((traces, programId) => {
|
|
12348
|
+
programDemosArray.push({
|
|
12349
|
+
traces,
|
|
12350
|
+
programId
|
|
12351
|
+
});
|
|
12352
|
+
});
|
|
11643
12353
|
return programDemosArray;
|
|
11644
12354
|
}
|
|
11645
12355
|
var randomSample = (array, n) => {
|
|
@@ -11658,10 +12368,8 @@ var randomSample = (array, n) => {
|
|
|
11658
12368
|
};
|
|
11659
12369
|
|
|
11660
12370
|
// dsp/optimizers/miproV2.ts
|
|
11661
|
-
var AxMiPRO = class {
|
|
11662
|
-
|
|
11663
|
-
program;
|
|
11664
|
-
examples;
|
|
12371
|
+
var AxMiPRO = class extends AxBaseOptimizer {
|
|
12372
|
+
// MiPRO-specific options
|
|
11665
12373
|
maxBootstrappedDemos;
|
|
11666
12374
|
maxLabeledDemos;
|
|
11667
12375
|
numCandidates;
|
|
@@ -11675,52 +12383,35 @@ var AxMiPRO = class {
|
|
|
11675
12383
|
viewDataBatchSize;
|
|
11676
12384
|
tipAwareProposer;
|
|
11677
12385
|
fewshotAwareProposer;
|
|
11678
|
-
seed;
|
|
11679
12386
|
verbose;
|
|
11680
|
-
bootstrapper;
|
|
11681
12387
|
earlyStoppingTrials;
|
|
11682
12388
|
minImprovementThreshold;
|
|
11683
|
-
|
|
11684
|
-
|
|
11685
|
-
|
|
11686
|
-
|
|
11687
|
-
|
|
11688
|
-
|
|
11689
|
-
|
|
11690
|
-
|
|
11691
|
-
|
|
11692
|
-
|
|
11693
|
-
this.
|
|
11694
|
-
this.
|
|
11695
|
-
this.
|
|
11696
|
-
this.
|
|
11697
|
-
this.
|
|
11698
|
-
this.
|
|
11699
|
-
this.
|
|
11700
|
-
this.
|
|
11701
|
-
this.
|
|
11702
|
-
this.
|
|
11703
|
-
this.
|
|
11704
|
-
this.
|
|
11705
|
-
this.
|
|
11706
|
-
this.
|
|
11707
|
-
this.
|
|
11708
|
-
this.
|
|
11709
|
-
this.minImprovementThreshold = miproOptions.minImprovementThreshold ?? 0.01;
|
|
11710
|
-
this.ai = ai;
|
|
11711
|
-
this.program = program;
|
|
11712
|
-
this.examples = examples;
|
|
11713
|
-
this.bootstrapper = new AxBootstrapFewShot({
|
|
11714
|
-
ai,
|
|
11715
|
-
program,
|
|
11716
|
-
examples,
|
|
11717
|
-
options: {
|
|
11718
|
-
maxDemos: this.maxBootstrappedDemos,
|
|
11719
|
-
maxRounds: 3,
|
|
11720
|
-
// Default, or adjust based on your needs
|
|
11721
|
-
verboseMode: this.verbose
|
|
11722
|
-
}
|
|
11723
|
-
});
|
|
12389
|
+
bayesianOptimization;
|
|
12390
|
+
acquisitionFunction;
|
|
12391
|
+
explorationWeight;
|
|
12392
|
+
constructor(args) {
|
|
12393
|
+
super(args);
|
|
12394
|
+
const options = args.options || {};
|
|
12395
|
+
this.numCandidates = options.numCandidates ?? 5;
|
|
12396
|
+
this.initTemperature = options.initTemperature ?? 0.7;
|
|
12397
|
+
this.maxBootstrappedDemos = options.maxBootstrappedDemos ?? 3;
|
|
12398
|
+
this.maxLabeledDemos = options.maxLabeledDemos ?? 4;
|
|
12399
|
+
this.numTrials = options.numTrials ?? 30;
|
|
12400
|
+
this.minibatch = options.minibatch ?? true;
|
|
12401
|
+
this.minibatchSize = options.minibatchSize ?? 25;
|
|
12402
|
+
this.minibatchFullEvalSteps = options.minibatchFullEvalSteps ?? 10;
|
|
12403
|
+
this.programAwareProposer = options.programAwareProposer ?? true;
|
|
12404
|
+
this.dataAwareProposer = options.dataAwareProposer ?? true;
|
|
12405
|
+
this.viewDataBatchSize = options.viewDataBatchSize ?? 10;
|
|
12406
|
+
this.tipAwareProposer = options.tipAwareProposer ?? true;
|
|
12407
|
+
this.fewshotAwareProposer = options.fewshotAwareProposer ?? true;
|
|
12408
|
+
this.verbose = options.verbose ?? false;
|
|
12409
|
+
this.earlyStoppingTrials = options.earlyStoppingTrials ?? 5;
|
|
12410
|
+
this.minImprovementThreshold = options.minImprovementThreshold ?? 0.01;
|
|
12411
|
+
this.bayesianOptimization = options.bayesianOptimization ?? false;
|
|
12412
|
+
this.acquisitionFunction = options.acquisitionFunction ?? "expected_improvement";
|
|
12413
|
+
this.explorationWeight = options.explorationWeight ?? 0.1;
|
|
12414
|
+
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
11724
12415
|
}
|
|
11725
12416
|
/**
|
|
11726
12417
|
* Configures the optimizer for light, medium, or heavy optimization
|
|
@@ -11764,123 +12455,60 @@ var AxMiPRO = class {
|
|
|
11764
12455
|
];
|
|
11765
12456
|
}
|
|
11766
12457
|
/**
|
|
11767
|
-
* Generates instruction candidates
|
|
12458
|
+
* Generates instruction candidates using the teacher model if available
|
|
12459
|
+
* @param options Optional compile options that may override teacher AI
|
|
11768
12460
|
* @returns Array of generated instruction candidates
|
|
11769
12461
|
*/
|
|
11770
|
-
async proposeInstructionCandidates() {
|
|
12462
|
+
async proposeInstructionCandidates(options) {
|
|
11771
12463
|
const instructions = [];
|
|
11772
|
-
|
|
11773
|
-
if (this.programAwareProposer) {
|
|
11774
|
-
programContext = await this.generateProgramSummary();
|
|
11775
|
-
}
|
|
11776
|
-
let dataContext = "";
|
|
11777
|
-
if (this.dataAwareProposer) {
|
|
11778
|
-
dataContext = await this.generateDataSummary();
|
|
11779
|
-
}
|
|
12464
|
+
const aiToUse = this.getTeacherOrStudentAI(options);
|
|
11780
12465
|
const tips = this.tipAwareProposer ? this.generateTips() : [];
|
|
11781
12466
|
for (let i = 0; i < this.numCandidates; i++) {
|
|
11782
12467
|
const tipIndex = tips.length > 0 ? i % tips.length : -1;
|
|
11783
12468
|
const tipToUse = tipIndex >= 0 ? tips[tipIndex] : "";
|
|
11784
12469
|
const instruction = await this.generateInstruction({
|
|
11785
|
-
programContext,
|
|
11786
|
-
dataContext,
|
|
11787
12470
|
tip: tipToUse,
|
|
11788
|
-
candidateIndex: i
|
|
12471
|
+
candidateIndex: i,
|
|
12472
|
+
ai: aiToUse
|
|
11789
12473
|
});
|
|
11790
12474
|
instructions.push(instruction);
|
|
11791
12475
|
}
|
|
11792
12476
|
return instructions;
|
|
11793
12477
|
}
|
|
11794
|
-
/**
|
|
11795
|
-
* Generates a summary of the program structure for instruction proposal
|
|
11796
|
-
*/
|
|
11797
|
-
async generateProgramSummary() {
|
|
11798
|
-
const prompt = `Summarize the following program structure. Focus on the signatures,
|
|
11799
|
-
input/output fields, and the purpose of each component. Identify key components
|
|
11800
|
-
that might benefit from better instructions.`;
|
|
11801
|
-
const programStr = JSON.stringify(this.program);
|
|
11802
|
-
const response = await this.ai.chat({
|
|
11803
|
-
chatPrompt: [
|
|
11804
|
-
{ role: "system", content: prompt },
|
|
11805
|
-
{ role: "user", content: programStr }
|
|
11806
|
-
],
|
|
11807
|
-
modelConfig: { temperature: 0.2 }
|
|
11808
|
-
});
|
|
11809
|
-
if (response instanceof ReadableStream) {
|
|
11810
|
-
return "";
|
|
11811
|
-
}
|
|
11812
|
-
return response.results[0]?.content || "";
|
|
11813
|
-
}
|
|
11814
|
-
/**
|
|
11815
|
-
* Generates a summary of the dataset for instruction proposal
|
|
11816
|
-
*/
|
|
11817
|
-
async generateDataSummary() {
|
|
11818
|
-
const sampleSize = Math.min(this.viewDataBatchSize, this.examples.length);
|
|
11819
|
-
const sample = this.examples.slice(0, sampleSize);
|
|
11820
|
-
const prompt = `Analyze the following dataset examples and provide a summary
|
|
11821
|
-
of key patterns, input-output relationships, and any specific challenges
|
|
11822
|
-
the data presents. Focus on what makes a good answer and what patterns should
|
|
11823
|
-
be followed.`;
|
|
11824
|
-
const dataStr = JSON.stringify(sample);
|
|
11825
|
-
const response = await this.ai.chat({
|
|
11826
|
-
chatPrompt: [
|
|
11827
|
-
{ role: "system", content: prompt },
|
|
11828
|
-
{ role: "user", content: dataStr }
|
|
11829
|
-
],
|
|
11830
|
-
modelConfig: { temperature: 0.2 }
|
|
11831
|
-
});
|
|
11832
|
-
if (response instanceof ReadableStream) {
|
|
11833
|
-
return "";
|
|
11834
|
-
}
|
|
11835
|
-
return response.results[0]?.content || "";
|
|
11836
|
-
}
|
|
11837
|
-
/**
|
|
11838
|
-
* Generates a specific instruction candidate
|
|
11839
|
-
*/
|
|
11840
12478
|
async generateInstruction({
|
|
11841
|
-
programContext,
|
|
11842
|
-
dataContext,
|
|
11843
12479
|
tip,
|
|
11844
12480
|
candidateIndex
|
|
11845
12481
|
}) {
|
|
11846
|
-
const
|
|
11847
|
-
|
|
11848
|
-
|
|
11849
|
-
|
|
11850
|
-
|
|
11851
|
-
|
|
11852
|
-
|
|
11853
|
-
|
|
11854
|
-
|
|
11855
|
-
|
|
11856
|
-
${tip ? `STYLE TIP: ${tip}
|
|
11857
|
-
|
|
11858
|
-
` : ""}
|
|
11859
|
-
|
|
11860
|
-
Your task is to craft a clear, effective instruction that will help the AI model generate
|
|
11861
|
-
accurate outputs for this task. Instruction #${candidateIndex + 1}/${this.numCandidates}.
|
|
11862
|
-
|
|
11863
|
-
The instruction should be detailed enough to guide the model but not overly prescriptive
|
|
11864
|
-
or restrictive. Focus on what makes a good response rather than listing exact steps.
|
|
11865
|
-
|
|
11866
|
-
INSTRUCTION:`;
|
|
11867
|
-
const response = await this.ai.chat({
|
|
11868
|
-
chatPrompt: [{ role: "user", content: prompt }],
|
|
11869
|
-
modelConfig: { temperature: 0.7 + 0.1 * candidateIndex }
|
|
11870
|
-
});
|
|
11871
|
-
if (response instanceof ReadableStream) {
|
|
11872
|
-
return "";
|
|
12482
|
+
const baseInstructions = [
|
|
12483
|
+
"Analyze the input carefully and provide a detailed response.",
|
|
12484
|
+
"Think step by step and provide a clear answer.",
|
|
12485
|
+
"Consider all aspects of the input before responding.",
|
|
12486
|
+
"Provide a concise but comprehensive response.",
|
|
12487
|
+
"Focus on accuracy and clarity in your response."
|
|
12488
|
+
];
|
|
12489
|
+
let instruction = baseInstructions[candidateIndex % baseInstructions.length] || baseInstructions[0];
|
|
12490
|
+
if (tip) {
|
|
12491
|
+
instruction = `${instruction} ${tip}`;
|
|
11873
12492
|
}
|
|
11874
|
-
return
|
|
12493
|
+
return instruction;
|
|
11875
12494
|
}
|
|
11876
12495
|
/**
|
|
11877
12496
|
* Bootstraps few-shot examples for the program
|
|
11878
12497
|
*/
|
|
11879
|
-
async bootstrapFewShotExamples(metricFn) {
|
|
12498
|
+
async bootstrapFewShotExamples(program, metricFn) {
|
|
11880
12499
|
if (this.verbose) {
|
|
11881
12500
|
console.log("Bootstrapping few-shot examples...");
|
|
11882
12501
|
}
|
|
11883
|
-
const
|
|
12502
|
+
const bootstrapper = new AxBootstrapFewShot({
|
|
12503
|
+
studentAI: this.studentAI,
|
|
12504
|
+
examples: this.examples,
|
|
12505
|
+
options: {
|
|
12506
|
+
maxDemos: this.maxBootstrappedDemos,
|
|
12507
|
+
maxRounds: 3,
|
|
12508
|
+
verboseMode: this.verbose
|
|
12509
|
+
}
|
|
12510
|
+
});
|
|
12511
|
+
const result = await bootstrapper.compile(program, metricFn, {
|
|
11884
12512
|
maxDemos: this.maxBootstrappedDemos
|
|
11885
12513
|
});
|
|
11886
12514
|
return result.demos || [];
|
|
@@ -11904,109 +12532,98 @@ ${dataContext}
|
|
|
11904
12532
|
return selectedExamples;
|
|
11905
12533
|
}
|
|
11906
12534
|
/**
|
|
11907
|
-
* Runs
|
|
12535
|
+
* Runs optimization to find the best combination of few-shot examples and instructions
|
|
11908
12536
|
*/
|
|
11909
|
-
async
|
|
11910
|
-
let bestConfig =
|
|
11911
|
-
let bestScore = Number.NEGATIVE_INFINITY;
|
|
11912
|
-
const evaluatedConfigs = [];
|
|
11913
|
-
const defaultConfig = {
|
|
12537
|
+
async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, valset, metricFn, options) {
|
|
12538
|
+
let bestConfig = {
|
|
11914
12539
|
instruction: instructions[0] || "",
|
|
11915
12540
|
bootstrappedDemos: Math.min(1, bootstrappedDemos.length),
|
|
11916
12541
|
labeledExamples: Math.min(1, labeledExamples.length)
|
|
11917
12542
|
};
|
|
11918
|
-
let
|
|
11919
|
-
let
|
|
11920
|
-
const
|
|
11921
|
-
|
|
11922
|
-
|
|
11923
|
-
|
|
11924
|
-
|
|
11925
|
-
|
|
11926
|
-
|
|
11927
|
-
|
|
12543
|
+
let bestScore = 0;
|
|
12544
|
+
let stagnationRounds = 0;
|
|
12545
|
+
const scoreHistory = [];
|
|
12546
|
+
let startRound = 0;
|
|
12547
|
+
if (this.resumeFromCheckpoint) {
|
|
12548
|
+
const checkpoint = await this.loadCheckpoint(
|
|
12549
|
+
this.resumeFromCheckpoint,
|
|
12550
|
+
options
|
|
12551
|
+
);
|
|
12552
|
+
if (checkpoint && checkpoint.optimizerType === "MiPRO") {
|
|
12553
|
+
if (this.verbose || options?.verbose) {
|
|
12554
|
+
console.log(
|
|
12555
|
+
`Resuming from checkpoint at round ${checkpoint.currentRound}`
|
|
12556
|
+
);
|
|
12557
|
+
}
|
|
12558
|
+
this.restoreFromCheckpoint(checkpoint);
|
|
12559
|
+
startRound = checkpoint.currentRound;
|
|
12560
|
+
bestScore = checkpoint.bestScore;
|
|
12561
|
+
bestConfig = checkpoint.bestConfiguration || bestConfig;
|
|
12562
|
+
stagnationRounds = checkpoint.stats.convergenceInfo?.stagnationRounds || 0;
|
|
12563
|
+
}
|
|
12564
|
+
}
|
|
12565
|
+
for (let i = startRound; i < this.numTrials; i++) {
|
|
11928
12566
|
const config = {
|
|
11929
|
-
instruction:
|
|
11930
|
-
bootstrappedDemos: Math.
|
|
11931
|
-
Math.random() * (bootstrappedDemos.length + 1)
|
|
12567
|
+
instruction: instructions[i % instructions.length] || instructions[0] || "",
|
|
12568
|
+
bootstrappedDemos: Math.min(
|
|
12569
|
+
Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
|
|
12570
|
+
this.maxBootstrappedDemos
|
|
11932
12571
|
),
|
|
11933
|
-
labeledExamples: Math.
|
|
11934
|
-
Math.random() * (labeledExamples.length + 1)
|
|
12572
|
+
labeledExamples: Math.min(
|
|
12573
|
+
Math.floor(Math.random() * (labeledExamples.length + 1)),
|
|
12574
|
+
this.maxLabeledDemos
|
|
11935
12575
|
)
|
|
11936
12576
|
};
|
|
11937
|
-
configs.push(config);
|
|
11938
|
-
}
|
|
11939
|
-
for (let i = 0; i < configs.length; i++) {
|
|
11940
|
-
const config = configs[i];
|
|
11941
|
-
if (!config) continue;
|
|
11942
12577
|
const score = await this.evaluateConfig(
|
|
12578
|
+
program,
|
|
11943
12579
|
config,
|
|
11944
12580
|
bootstrappedDemos,
|
|
11945
12581
|
labeledExamples,
|
|
11946
12582
|
valset,
|
|
11947
|
-
metricFn
|
|
11948
|
-
i
|
|
12583
|
+
metricFn
|
|
11949
12584
|
);
|
|
11950
|
-
|
|
11951
|
-
|
|
12585
|
+
scoreHistory.push(score);
|
|
12586
|
+
const improvement = score - bestScore;
|
|
12587
|
+
if (improvement > this.minImprovementThreshold) {
|
|
11952
12588
|
bestScore = score;
|
|
11953
12589
|
bestConfig = config;
|
|
11954
|
-
|
|
11955
|
-
|
|
11956
|
-
|
|
11957
|
-
);
|
|
11958
|
-
}
|
|
12590
|
+
stagnationRounds = 0;
|
|
12591
|
+
} else {
|
|
12592
|
+
stagnationRounds++;
|
|
11959
12593
|
}
|
|
11960
|
-
|
|
12594
|
+
await this.updateOptimizationProgress(
|
|
11961
12595
|
i + 1,
|
|
11962
|
-
|
|
11963
|
-
|
|
11964
|
-
|
|
11965
|
-
|
|
11966
|
-
|
|
11967
|
-
|
|
11968
|
-
|
|
11969
|
-
|
|
11970
|
-
|
|
11971
|
-
|
|
11972
|
-
|
|
11973
|
-
|
|
11974
|
-
|
|
11975
|
-
);
|
|
11976
|
-
const score = await this.evaluateConfig(
|
|
11977
|
-
nextConfig,
|
|
11978
|
-
bootstrappedDemos,
|
|
11979
|
-
labeledExamples,
|
|
11980
|
-
valset,
|
|
11981
|
-
metricFn,
|
|
11982
|
-
i
|
|
12596
|
+
score,
|
|
12597
|
+
config,
|
|
12598
|
+
"MiPRO",
|
|
12599
|
+
this.getConfiguration(),
|
|
12600
|
+
bestScore,
|
|
12601
|
+
bestConfig,
|
|
12602
|
+
{
|
|
12603
|
+
stagnationRounds,
|
|
12604
|
+
bootstrappedDemos: bootstrappedDemos.length,
|
|
12605
|
+
labeledExamples: labeledExamples.length,
|
|
12606
|
+
instructions: instructions.length
|
|
12607
|
+
},
|
|
12608
|
+
options
|
|
11983
12609
|
);
|
|
11984
|
-
|
|
11985
|
-
|
|
11986
|
-
|
|
11987
|
-
|
|
11988
|
-
|
|
11989
|
-
|
|
11990
|
-
|
|
11991
|
-
)
|
|
11992
|
-
|
|
11993
|
-
|
|
11994
|
-
|
|
11995
|
-
|
|
11996
|
-
|
|
11997
|
-
|
|
11998
|
-
|
|
11999
|
-
if (this.verbose) {
|
|
12000
|
-
console.log(
|
|
12001
|
-
`Early stopping triggered after ${i + 1} trials. No improvement for ${trialsWithoutImprovement} trials.`
|
|
12002
|
-
);
|
|
12003
|
-
}
|
|
12004
|
-
break;
|
|
12610
|
+
if (this.onProgress) {
|
|
12611
|
+
this.onProgress({
|
|
12612
|
+
round: i + 1,
|
|
12613
|
+
totalRounds: this.numTrials,
|
|
12614
|
+
currentScore: score,
|
|
12615
|
+
bestScore,
|
|
12616
|
+
tokensUsed: this.stats.resourceUsage.totalTokens,
|
|
12617
|
+
timeElapsed: Date.now(),
|
|
12618
|
+
successfulExamples: this.stats.successfulDemos,
|
|
12619
|
+
totalExamples: this.examples.length,
|
|
12620
|
+
currentConfiguration: config,
|
|
12621
|
+
convergenceInfo: {
|
|
12622
|
+
improvement,
|
|
12623
|
+
stagnationRounds,
|
|
12624
|
+
isConverging: stagnationRounds < this.earlyStoppingTrials
|
|
12005
12625
|
}
|
|
12006
|
-
}
|
|
12007
|
-
lastBestScore = bestScore;
|
|
12008
|
-
trialsWithoutImprovement = 0;
|
|
12009
|
-
}
|
|
12626
|
+
});
|
|
12010
12627
|
}
|
|
12011
12628
|
updateProgressBar(
|
|
12012
12629
|
i + 1,
|
|
@@ -12016,243 +12633,91 @@ ${dataContext}
|
|
|
12016
12633
|
"Running MIPROv2 optimization",
|
|
12017
12634
|
30
|
|
12018
12635
|
);
|
|
12019
|
-
if (this.
|
|
12020
|
-
|
|
12021
|
-
|
|
12022
|
-
`Running full evaluation on best configuration at trial ${i + 1}`
|
|
12023
|
-
);
|
|
12024
|
-
}
|
|
12025
|
-
const fullScore = await this.fullEvaluation(
|
|
12026
|
-
bestConfig,
|
|
12027
|
-
bootstrappedDemos,
|
|
12028
|
-
labeledExamples,
|
|
12029
|
-
valset,
|
|
12030
|
-
metricFn
|
|
12031
|
-
);
|
|
12032
|
-
if (this.verbose) {
|
|
12033
|
-
console.log(`Full evaluation score: ${fullScore}`);
|
|
12034
|
-
}
|
|
12035
|
-
bestScore = fullScore;
|
|
12636
|
+
if (this.checkCostLimits()) {
|
|
12637
|
+
this.triggerEarlyStopping("Cost limit reached", i + 1);
|
|
12638
|
+
break;
|
|
12036
12639
|
}
|
|
12037
|
-
|
|
12038
|
-
|
|
12039
|
-
|
|
12040
|
-
|
|
12041
|
-
"Optimization failed to find any valid configurations, using default fallback configuration"
|
|
12640
|
+
if (stagnationRounds >= this.earlyStoppingTrials) {
|
|
12641
|
+
this.triggerEarlyStopping(
|
|
12642
|
+
`No improvement for ${this.earlyStoppingTrials} trials`,
|
|
12643
|
+
i - stagnationRounds + 1
|
|
12042
12644
|
);
|
|
12645
|
+
break;
|
|
12043
12646
|
}
|
|
12044
|
-
|
|
12045
|
-
|
|
12046
|
-
|
|
12047
|
-
|
|
12048
|
-
bootstrappedDemos,
|
|
12049
|
-
labeledExamples,
|
|
12050
|
-
valset,
|
|
12051
|
-
metricFn,
|
|
12052
|
-
this.numTrials - 1
|
|
12647
|
+
if (this.checkTargetScore(bestScore)) {
|
|
12648
|
+
this.triggerEarlyStopping(
|
|
12649
|
+
`Target score ${this.targetScore} reached`,
|
|
12650
|
+
i + 1
|
|
12053
12651
|
);
|
|
12054
|
-
|
|
12055
|
-
if (this.verbose) {
|
|
12056
|
-
console.error("Error evaluating default configuration:", err);
|
|
12057
|
-
}
|
|
12058
|
-
bestScore = 0;
|
|
12652
|
+
break;
|
|
12059
12653
|
}
|
|
12060
12654
|
}
|
|
12655
|
+
this.stats.convergenceInfo.stagnationRounds = stagnationRounds;
|
|
12656
|
+
this.stats.convergenceInfo.finalImprovement = scoreHistory.length > 1 ? bestScore - scoreHistory[0] : 0;
|
|
12657
|
+
this.stats.convergenceInfo.converged = stagnationRounds < this.earlyStoppingTrials;
|
|
12061
12658
|
return { bestConfig, bestScore };
|
|
12062
12659
|
}
|
|
12063
|
-
|
|
12064
|
-
|
|
12065
|
-
*/
|
|
12066
|
-
async evaluateConfig(config, bootstrappedDemos, labeledExamples, valset, metricFn, trialIndex) {
|
|
12660
|
+
async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, valset, metricFn) {
|
|
12661
|
+
const testProgram = { ...program };
|
|
12067
12662
|
this.applyConfigToProgram(
|
|
12068
|
-
|
|
12663
|
+
testProgram,
|
|
12069
12664
|
config,
|
|
12070
12665
|
bootstrappedDemos,
|
|
12071
12666
|
labeledExamples
|
|
12072
12667
|
);
|
|
12073
|
-
let
|
|
12074
|
-
|
|
12075
|
-
|
|
12076
|
-
const minibatchEvalSet = [];
|
|
12077
|
-
for (let j = 0; j < this.minibatchSize; j++) {
|
|
12078
|
-
const idx = (startIdx + j) % valset.length;
|
|
12079
|
-
const example = valset[idx];
|
|
12080
|
-
if (example) {
|
|
12081
|
-
minibatchEvalSet.push(example);
|
|
12082
|
-
}
|
|
12083
|
-
}
|
|
12084
|
-
evalSet = minibatchEvalSet;
|
|
12085
|
-
}
|
|
12086
|
-
let sumOfScores = 0;
|
|
12668
|
+
let totalScore = 0;
|
|
12669
|
+
let count = 0;
|
|
12670
|
+
const evalSet = valset.slice(0, Math.min(5, valset.length));
|
|
12087
12671
|
for (const example of evalSet) {
|
|
12088
12672
|
try {
|
|
12089
|
-
const prediction = await
|
|
12090
|
-
|
|
12091
|
-
|
|
12092
|
-
|
|
12093
|
-
|
|
12094
|
-
|
|
12095
|
-
|
|
12096
|
-
|
|
12097
|
-
|
|
12098
|
-
|
|
12099
|
-
return sumOfScores / evalSet.length;
|
|
12100
|
-
}
|
|
12101
|
-
/**
|
|
12102
|
-
* Run full evaluation on the entire validation set
|
|
12103
|
-
*/
|
|
12104
|
-
async fullEvaluation(config, bootstrappedDemos, labeledExamples, valset, metricFn) {
|
|
12105
|
-
this.applyConfigToProgram(
|
|
12106
|
-
this.program,
|
|
12107
|
-
config,
|
|
12108
|
-
bootstrappedDemos,
|
|
12109
|
-
labeledExamples
|
|
12110
|
-
);
|
|
12111
|
-
let sumOfScores = 0;
|
|
12112
|
-
for (const example of valset) {
|
|
12113
|
-
try {
|
|
12114
|
-
const prediction = await this.program.forward(this.ai, example);
|
|
12115
|
-
const score = metricFn({ prediction, example });
|
|
12116
|
-
sumOfScores += score;
|
|
12117
|
-
} catch (err) {
|
|
12118
|
-
if (this.verbose) {
|
|
12119
|
-
console.error("Error evaluating example:", err);
|
|
12120
|
-
}
|
|
12673
|
+
const prediction = await testProgram.forward(
|
|
12674
|
+
this.studentAI,
|
|
12675
|
+
example
|
|
12676
|
+
);
|
|
12677
|
+
const score = await metricFn({ prediction, example });
|
|
12678
|
+
totalScore += score;
|
|
12679
|
+
count++;
|
|
12680
|
+
this.stats.totalCalls++;
|
|
12681
|
+
} catch {
|
|
12682
|
+
continue;
|
|
12121
12683
|
}
|
|
12122
12684
|
}
|
|
12123
|
-
|
|
12124
|
-
return sumOfScores / valset.length;
|
|
12125
|
-
}
|
|
12126
|
-
/**
|
|
12127
|
-
* Implements a Bayesian-inspired selection of the next configuration to try
|
|
12128
|
-
* This is a simplified version using Upper Confidence Bound (UCB) strategy
|
|
12129
|
-
*/
|
|
12130
|
-
selectNextConfiguration(evaluatedConfigs, maxBootstrappedDemos, maxLabeledExamples, instructions) {
|
|
12131
|
-
if (evaluatedConfigs.length < 5) {
|
|
12132
|
-
const instructionIndex = Math.floor(Math.random() * instructions.length);
|
|
12133
|
-
return {
|
|
12134
|
-
instruction: instructions[instructionIndex] || "",
|
|
12135
|
-
bootstrappedDemos: Math.floor(
|
|
12136
|
-
Math.random() * (maxBootstrappedDemos + 1)
|
|
12137
|
-
),
|
|
12138
|
-
labeledExamples: Math.floor(Math.random() * (maxLabeledExamples + 1))
|
|
12139
|
-
};
|
|
12140
|
-
}
|
|
12141
|
-
const sortedConfigs = [...evaluatedConfigs].sort(
|
|
12142
|
-
(a, b) => b.score - a.score
|
|
12143
|
-
);
|
|
12144
|
-
const topConfigs = sortedConfigs.slice(0, Math.min(3, sortedConfigs.length));
|
|
12145
|
-
const meanBootstrappedDemos = topConfigs.reduce((sum, c) => sum + c.config.bootstrappedDemos, 0) / topConfigs.length;
|
|
12146
|
-
const meanLabeledExamples = topConfigs.reduce((sum, c) => sum + c.config.labeledExamples, 0) / topConfigs.length;
|
|
12147
|
-
const popularInstructions = topConfigs.map((c) => c.config.instruction);
|
|
12148
|
-
const explorationFactor = Math.max(
|
|
12149
|
-
0.2,
|
|
12150
|
-
1 - evaluatedConfigs.length / this.numTrials
|
|
12151
|
-
);
|
|
12152
|
-
let newBootstrappedDemos;
|
|
12153
|
-
let newLabeledExamples;
|
|
12154
|
-
let newInstruction;
|
|
12155
|
-
if (Math.random() < 0.7) {
|
|
12156
|
-
newBootstrappedDemos = Math.min(
|
|
12157
|
-
maxBootstrappedDemos,
|
|
12158
|
-
Math.max(
|
|
12159
|
-
0,
|
|
12160
|
-
Math.round(
|
|
12161
|
-
meanBootstrappedDemos + (Math.random() * 2 - 1) * explorationFactor * 2
|
|
12162
|
-
)
|
|
12163
|
-
)
|
|
12164
|
-
);
|
|
12165
|
-
} else {
|
|
12166
|
-
newBootstrappedDemos = Math.floor(
|
|
12167
|
-
Math.random() * (maxBootstrappedDemos + 1)
|
|
12168
|
-
);
|
|
12169
|
-
}
|
|
12170
|
-
if (Math.random() < 0.7) {
|
|
12171
|
-
newLabeledExamples = Math.min(
|
|
12172
|
-
maxLabeledExamples,
|
|
12173
|
-
Math.max(
|
|
12174
|
-
0,
|
|
12175
|
-
Math.round(
|
|
12176
|
-
meanLabeledExamples + (Math.random() * 2 - 1) * explorationFactor * 2
|
|
12177
|
-
)
|
|
12178
|
-
)
|
|
12179
|
-
);
|
|
12180
|
-
} else {
|
|
12181
|
-
newLabeledExamples = Math.floor(Math.random() * (maxLabeledExamples + 1));
|
|
12182
|
-
}
|
|
12183
|
-
if (Math.random() < 0.7 && popularInstructions.length > 0) {
|
|
12184
|
-
const idx = Math.floor(Math.random() * popularInstructions.length);
|
|
12185
|
-
newInstruction = popularInstructions[idx] || "";
|
|
12186
|
-
} else {
|
|
12187
|
-
const idx = Math.floor(Math.random() * instructions.length);
|
|
12188
|
-
newInstruction = instructions[idx] || "";
|
|
12189
|
-
}
|
|
12190
|
-
return {
|
|
12191
|
-
instruction: newInstruction,
|
|
12192
|
-
bootstrappedDemos: newBootstrappedDemos,
|
|
12193
|
-
labeledExamples: newLabeledExamples
|
|
12194
|
-
};
|
|
12685
|
+
return count > 0 ? totalScore / count : 0;
|
|
12195
12686
|
}
|
|
12196
|
-
/**
|
|
12197
|
-
* Applies a configuration to a program instance
|
|
12198
|
-
*/
|
|
12199
12687
|
applyConfigToProgram(program, config, bootstrappedDemos, labeledExamples) {
|
|
12200
|
-
|
|
12201
|
-
|
|
12688
|
+
if (program.setInstruction) {
|
|
12689
|
+
program.setInstruction(config.instruction);
|
|
12690
|
+
}
|
|
12691
|
+
if (config.bootstrappedDemos > 0 && program.setDemos) {
|
|
12202
12692
|
program.setDemos(bootstrappedDemos.slice(0, config.bootstrappedDemos));
|
|
12203
12693
|
}
|
|
12204
|
-
if (config.labeledExamples > 0) {
|
|
12694
|
+
if (config.labeledExamples > 0 && program.setExamples) {
|
|
12205
12695
|
program.setExamples(labeledExamples.slice(0, config.labeledExamples));
|
|
12206
12696
|
}
|
|
12207
12697
|
}
|
|
12208
|
-
/**
|
|
12209
|
-
* Sets instruction to a program
|
|
12210
|
-
* Note: Workaround since setInstruction may not be available directly
|
|
12211
|
-
*/
|
|
12212
|
-
setInstructionToProgram(program, instruction) {
|
|
12213
|
-
const programWithInstruction = program;
|
|
12214
|
-
programWithInstruction.setInstruction?.(instruction);
|
|
12215
|
-
}
|
|
12216
12698
|
/**
|
|
12217
12699
|
* The main compile method to run MIPROv2 optimization
|
|
12218
|
-
* @param metricFn Evaluation metric function
|
|
12219
|
-
* @param options Optional configuration options
|
|
12220
|
-
* @returns The optimization result
|
|
12221
12700
|
*/
|
|
12222
|
-
async compile(metricFn, options) {
|
|
12701
|
+
async compile(program, metricFn, options) {
|
|
12702
|
+
const startTime = Date.now();
|
|
12703
|
+
this.setupRandomSeed();
|
|
12223
12704
|
const miproOptions = options;
|
|
12224
12705
|
if (miproOptions?.auto) {
|
|
12225
12706
|
this.configureAuto(miproOptions.auto);
|
|
12226
12707
|
}
|
|
12227
|
-
const
|
|
12228
|
-
|
|
12229
|
-
if (this.verbose) {
|
|
12708
|
+
const valset = this.getValidationSet(options) || (miproOptions?.valset ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
|
|
12709
|
+
if (this.verbose || options?.verbose) {
|
|
12230
12710
|
console.log(`Starting MIPROv2 optimization with ${this.numTrials} trials`);
|
|
12231
12711
|
console.log(
|
|
12232
|
-
`Using ${
|
|
12712
|
+
`Using ${this.examples.length} examples for training and ${valset.length} for validation`
|
|
12233
12713
|
);
|
|
12234
|
-
|
|
12235
|
-
|
|
12236
|
-
if (this.verbose) {
|
|
12237
|
-
console.log("Using provided teacher to assist with bootstrapping");
|
|
12714
|
+
if (this.teacherAI) {
|
|
12715
|
+
console.log("Using separate teacher model for instruction generation");
|
|
12238
12716
|
}
|
|
12239
|
-
const bootstrapperWithTeacher = new AxBootstrapFewShot({
|
|
12240
|
-
ai: this.ai,
|
|
12241
|
-
program: this.program,
|
|
12242
|
-
examples: this.examples,
|
|
12243
|
-
options: {
|
|
12244
|
-
maxDemos: this.maxBootstrappedDemos,
|
|
12245
|
-
maxRounds: 3,
|
|
12246
|
-
verboseMode: this.verbose,
|
|
12247
|
-
teacherAI: this.ai
|
|
12248
|
-
// Use the same AI but with the teacher program
|
|
12249
|
-
}
|
|
12250
|
-
});
|
|
12251
|
-
this.bootstrapper = bootstrapperWithTeacher;
|
|
12252
12717
|
}
|
|
12253
12718
|
let bootstrappedDemos = [];
|
|
12254
12719
|
if (this.maxBootstrappedDemos > 0) {
|
|
12255
|
-
bootstrappedDemos = await this.bootstrapFewShotExamples(metricFn);
|
|
12720
|
+
bootstrappedDemos = await this.bootstrapFewShotExamples(program, metricFn);
|
|
12256
12721
|
if (this.verbose) {
|
|
12257
12722
|
console.log(
|
|
12258
12723
|
`Generated ${bootstrappedDemos.length} bootstrapped demonstrations`
|
|
@@ -12268,38 +12733,191 @@ ${dataContext}
|
|
|
12268
12733
|
);
|
|
12269
12734
|
}
|
|
12270
12735
|
}
|
|
12271
|
-
const instructions = await this.proposeInstructionCandidates();
|
|
12736
|
+
const instructions = await this.proposeInstructionCandidates(options);
|
|
12272
12737
|
if (this.verbose) {
|
|
12273
12738
|
console.log(`Generated ${instructions.length} instruction candidates`);
|
|
12739
|
+
if (this.hasTeacherAI(options)) {
|
|
12740
|
+
console.log("Using teacher AI for instruction generation");
|
|
12741
|
+
}
|
|
12274
12742
|
}
|
|
12275
|
-
const { bestConfig, bestScore } = await this.
|
|
12743
|
+
const { bestConfig, bestScore } = await this.runOptimization(
|
|
12744
|
+
program,
|
|
12276
12745
|
bootstrappedDemos,
|
|
12277
12746
|
labeledExamples,
|
|
12278
12747
|
instructions,
|
|
12279
12748
|
valset,
|
|
12280
|
-
metricFn
|
|
12749
|
+
metricFn,
|
|
12750
|
+
options
|
|
12281
12751
|
);
|
|
12282
|
-
if (this.verbose) {
|
|
12752
|
+
if (this.verbose || options?.verbose) {
|
|
12283
12753
|
console.log(`Optimization complete. Best score: ${bestScore}`);
|
|
12284
12754
|
console.log(`Best configuration: ${JSON.stringify(bestConfig)}`);
|
|
12285
12755
|
}
|
|
12286
|
-
this.
|
|
12287
|
-
this.
|
|
12756
|
+
if (this.checkTargetScore(bestScore)) {
|
|
12757
|
+
this.triggerEarlyStopping(
|
|
12758
|
+
`Target score ${this.targetScore} reached with score ${bestScore}`,
|
|
12759
|
+
this.numTrials
|
|
12760
|
+
);
|
|
12761
|
+
}
|
|
12762
|
+
let signature;
|
|
12763
|
+
if ("getSignature" in program && typeof program.getSignature === "function") {
|
|
12764
|
+
signature = program.getSignature();
|
|
12765
|
+
} else {
|
|
12766
|
+
signature = "input -> output";
|
|
12767
|
+
}
|
|
12768
|
+
const optimizedGen = new AxGen(signature);
|
|
12769
|
+
this.applyConfigToAxGen(
|
|
12770
|
+
optimizedGen,
|
|
12288
12771
|
bestConfig,
|
|
12289
12772
|
bootstrappedDemos,
|
|
12290
12773
|
labeledExamples
|
|
12291
12774
|
);
|
|
12775
|
+
this.updateResourceUsage(startTime);
|
|
12776
|
+
this.stats.convergenceInfo.converged = true;
|
|
12777
|
+
this.stats.convergenceInfo.finalImprovement = bestScore;
|
|
12778
|
+
await this.saveFinalCheckpoint(
|
|
12779
|
+
"MiPRO",
|
|
12780
|
+
this.getConfiguration(),
|
|
12781
|
+
bestScore,
|
|
12782
|
+
bestConfig,
|
|
12783
|
+
{
|
|
12784
|
+
bootstrappedDemos: bootstrappedDemos.length,
|
|
12785
|
+
labeledExamples: labeledExamples.length,
|
|
12786
|
+
instructions: instructions.length,
|
|
12787
|
+
optimizedGen: !!optimizedGen
|
|
12788
|
+
},
|
|
12789
|
+
options
|
|
12790
|
+
);
|
|
12292
12791
|
return {
|
|
12293
|
-
|
|
12294
|
-
|
|
12792
|
+
demos: bootstrappedDemos,
|
|
12793
|
+
stats: this.stats,
|
|
12794
|
+
bestScore,
|
|
12795
|
+
optimizedGen,
|
|
12796
|
+
finalConfiguration: {
|
|
12797
|
+
instruction: bestConfig.instruction,
|
|
12798
|
+
bootstrappedDemos: bestConfig.bootstrappedDemos,
|
|
12799
|
+
labeledExamples: bestConfig.labeledExamples,
|
|
12800
|
+
numCandidates: this.numCandidates,
|
|
12801
|
+
numTrials: this.numTrials
|
|
12802
|
+
}
|
|
12295
12803
|
};
|
|
12296
12804
|
}
|
|
12297
12805
|
/**
|
|
12298
|
-
*
|
|
12299
|
-
* @returns Optimization statistics or undefined if not available
|
|
12806
|
+
* Applies a configuration to an AxGen instance
|
|
12300
12807
|
*/
|
|
12301
|
-
|
|
12302
|
-
|
|
12808
|
+
applyConfigToAxGen(axgen, config, bootstrappedDemos, labeledExamples) {
|
|
12809
|
+
if ("setInstruction" in axgen && typeof axgen.setInstruction === "function") {
|
|
12810
|
+
axgen.setInstruction(config.instruction);
|
|
12811
|
+
}
|
|
12812
|
+
if (config.bootstrappedDemos > 0) {
|
|
12813
|
+
axgen.setDemos(bootstrappedDemos.slice(0, config.bootstrappedDemos));
|
|
12814
|
+
}
|
|
12815
|
+
if (config.labeledExamples > 0) {
|
|
12816
|
+
axgen.setExamples(
|
|
12817
|
+
labeledExamples.slice(
|
|
12818
|
+
0,
|
|
12819
|
+
config.labeledExamples
|
|
12820
|
+
)
|
|
12821
|
+
);
|
|
12822
|
+
}
|
|
12823
|
+
}
|
|
12824
|
+
/**
|
|
12825
|
+
* Get optimizer-specific configuration
|
|
12826
|
+
* @returns Current optimizer configuration
|
|
12827
|
+
*/
|
|
12828
|
+
getConfiguration() {
|
|
12829
|
+
return {
|
|
12830
|
+
numCandidates: this.numCandidates,
|
|
12831
|
+
initTemperature: this.initTemperature,
|
|
12832
|
+
maxBootstrappedDemos: this.maxBootstrappedDemos,
|
|
12833
|
+
maxLabeledDemos: this.maxLabeledDemos,
|
|
12834
|
+
numTrials: this.numTrials,
|
|
12835
|
+
minibatch: this.minibatch,
|
|
12836
|
+
minibatchSize: this.minibatchSize,
|
|
12837
|
+
minibatchFullEvalSteps: this.minibatchFullEvalSteps,
|
|
12838
|
+
programAwareProposer: this.programAwareProposer,
|
|
12839
|
+
dataAwareProposer: this.dataAwareProposer,
|
|
12840
|
+
tipAwareProposer: this.tipAwareProposer,
|
|
12841
|
+
fewshotAwareProposer: this.fewshotAwareProposer,
|
|
12842
|
+
earlyStoppingTrials: this.earlyStoppingTrials,
|
|
12843
|
+
minImprovementThreshold: this.minImprovementThreshold,
|
|
12844
|
+
bayesianOptimization: this.bayesianOptimization,
|
|
12845
|
+
acquisitionFunction: this.acquisitionFunction,
|
|
12846
|
+
explorationWeight: this.explorationWeight
|
|
12847
|
+
};
|
|
12848
|
+
}
|
|
12849
|
+
/**
|
|
12850
|
+
* Update optimizer configuration
|
|
12851
|
+
* @param config New configuration to merge with existing
|
|
12852
|
+
*/
|
|
12853
|
+
updateConfiguration(config) {
|
|
12854
|
+
if (config.numCandidates !== void 0) {
|
|
12855
|
+
this.numCandidates = config.numCandidates;
|
|
12856
|
+
}
|
|
12857
|
+
if (config.initTemperature !== void 0) {
|
|
12858
|
+
this.initTemperature = config.initTemperature;
|
|
12859
|
+
}
|
|
12860
|
+
if (config.maxBootstrappedDemos !== void 0) {
|
|
12861
|
+
this.maxBootstrappedDemos = config.maxBootstrappedDemos;
|
|
12862
|
+
}
|
|
12863
|
+
if (config.maxLabeledDemos !== void 0) {
|
|
12864
|
+
this.maxLabeledDemos = config.maxLabeledDemos;
|
|
12865
|
+
}
|
|
12866
|
+
if (config.numTrials !== void 0) {
|
|
12867
|
+
this.numTrials = config.numTrials;
|
|
12868
|
+
}
|
|
12869
|
+
if (config.minibatch !== void 0) {
|
|
12870
|
+
this.minibatch = config.minibatch;
|
|
12871
|
+
}
|
|
12872
|
+
if (config.minibatchSize !== void 0) {
|
|
12873
|
+
this.minibatchSize = config.minibatchSize;
|
|
12874
|
+
}
|
|
12875
|
+
if (config.earlyStoppingTrials !== void 0) {
|
|
12876
|
+
this.earlyStoppingTrials = config.earlyStoppingTrials;
|
|
12877
|
+
}
|
|
12878
|
+
if (config.minImprovementThreshold !== void 0) {
|
|
12879
|
+
this.minImprovementThreshold = config.minImprovementThreshold;
|
|
12880
|
+
}
|
|
12881
|
+
if (config.verbose !== void 0) {
|
|
12882
|
+
this.verbose = config.verbose;
|
|
12883
|
+
}
|
|
12884
|
+
}
|
|
12885
|
+
/**
|
|
12886
|
+
* Reset optimizer state for reuse with different programs
|
|
12887
|
+
*/
|
|
12888
|
+
reset() {
|
|
12889
|
+
super.reset();
|
|
12890
|
+
this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
|
|
12891
|
+
}
|
|
12892
|
+
/**
|
|
12893
|
+
* Validate that the optimizer can handle the given program
|
|
12894
|
+
* @param program Program to validate
|
|
12895
|
+
* @returns Validation result with any issues found
|
|
12896
|
+
*/
|
|
12897
|
+
validateProgram(program) {
|
|
12898
|
+
const result = super.validateProgram(program);
|
|
12899
|
+
if (this.examples.length < this.maxBootstrappedDemos + this.maxLabeledDemos) {
|
|
12900
|
+
result.issues.push(
|
|
12901
|
+
`Not enough examples: need at least ${this.maxBootstrappedDemos + this.maxLabeledDemos}, got ${this.examples.length}`
|
|
12902
|
+
);
|
|
12903
|
+
result.suggestions.push(
|
|
12904
|
+
"Reduce maxBootstrappedDemos or maxLabeledDemos, or provide more examples"
|
|
12905
|
+
);
|
|
12906
|
+
}
|
|
12907
|
+
const valSetSize = this.getValidationSet().length;
|
|
12908
|
+
if (valSetSize < 5) {
|
|
12909
|
+
result.issues.push(
|
|
12910
|
+
"Validation set too small for reliable MiPRO optimization"
|
|
12911
|
+
);
|
|
12912
|
+
result.suggestions.push(
|
|
12913
|
+
"Provide more examples or a larger validation set"
|
|
12914
|
+
);
|
|
12915
|
+
}
|
|
12916
|
+
return {
|
|
12917
|
+
isValid: result.issues.length === 0,
|
|
12918
|
+
issues: result.issues,
|
|
12919
|
+
suggestions: result.suggestions
|
|
12920
|
+
};
|
|
12303
12921
|
}
|
|
12304
12922
|
};
|
|
12305
12923
|
|
|
@@ -12546,7 +13164,7 @@ var AxTestPrompt = class {
|
|
|
12546
13164
|
throw new Error("Invalid example");
|
|
12547
13165
|
}
|
|
12548
13166
|
const res = await this.program.forward(this.ai, ex);
|
|
12549
|
-
const score = metricFn({ prediction: res, example: ex });
|
|
13167
|
+
const score = await metricFn({ prediction: res, example: ex });
|
|
12550
13168
|
sumOfScores += score;
|
|
12551
13169
|
const et = (/* @__PURE__ */ new Date()).getTime() - st;
|
|
12552
13170
|
updateProgressBar(i, total, sumOfScores, et, "Testing Prompt", 30);
|
|
@@ -14580,7 +15198,6 @@ var AxRAG = class extends AxChainOfThought {
|
|
|
14580
15198
|
);
|
|
14581
15199
|
this.genQuery = new AxGen(qsig);
|
|
14582
15200
|
this.queryFn = queryFn;
|
|
14583
|
-
this.register(this.genQuery);
|
|
14584
15201
|
}
|
|
14585
15202
|
async forward(ai, values, options) {
|
|
14586
15203
|
let question;
|
|
@@ -14657,6 +15274,7 @@ export {
|
|
|
14657
15274
|
AxAssertionError,
|
|
14658
15275
|
AxBalancer,
|
|
14659
15276
|
AxBaseAI,
|
|
15277
|
+
AxBaseOptimizer,
|
|
14660
15278
|
AxBootstrapFewShot,
|
|
14661
15279
|
AxChainOfThought,
|
|
14662
15280
|
AxDB,
|
|
@@ -14666,6 +15284,7 @@ export {
|
|
|
14666
15284
|
AxDBMemory,
|
|
14667
15285
|
AxDBPinecone,
|
|
14668
15286
|
AxDBWeaviate,
|
|
15287
|
+
AxDefaultCostTracker,
|
|
14669
15288
|
AxDefaultQueryRewriter,
|
|
14670
15289
|
AxDefaultResultReranker,
|
|
14671
15290
|
AxDockerSession,
|