@ax-llm/ax 12.0.12 → 12.0.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/index.cjs CHANGED
@@ -6257,15 +6257,24 @@ var MemoryImpl = class {
6257
6257
  debugRequest(items, this.options?.debugHideSystemPrompt);
6258
6258
  }
6259
6259
  }
6260
+ addFunctionResults(results) {
6261
+ const chat = results.map(({ index, ...value }) => ({
6262
+ index,
6263
+ value: structuredClone(value)
6264
+ }));
6265
+ const lastItem = this.getLast();
6266
+ if (lastItem?.role === "function") {
6267
+ lastItem.chat.push(...chat);
6268
+ } else {
6269
+ this.data.push({ role: "function", chat });
6270
+ }
6271
+ }
6260
6272
  addResponse(results) {
6261
- const chat = results.map((result) => ({
6262
- index: result.index,
6263
- value: structuredClone(result)
6273
+ const chat = results.map(({ index, ...value }) => ({
6274
+ index,
6275
+ value: structuredClone(value)
6264
6276
  }));
6265
- this.data.push({
6266
- role: "assistant",
6267
- chat
6268
- });
6277
+ this.data.push({ role: "assistant", chat });
6269
6278
  if (this.options?.debug) {
6270
6279
  for (const result of results) {
6271
6280
  debugResponse(result);
@@ -6353,9 +6362,20 @@ var MemoryImpl = class {
6353
6362
  history(index) {
6354
6363
  const result = [];
6355
6364
  for (const { role, chat } of this.data) {
6356
- const value = chat.find((v) => v.index === index)?.value;
6357
- if (value) {
6358
- result.push({ role, ...value });
6365
+ let values;
6366
+ if (role === "function") {
6367
+ values = chat.filter((v) => v.index === index).map((v) => v.value);
6368
+ } else {
6369
+ values = chat.find((v) => v.index === index)?.value;
6370
+ }
6371
+ if (Array.isArray(values)) {
6372
+ result.push(
6373
+ ...values.map(
6374
+ (v) => ({ ...v, role })
6375
+ )
6376
+ );
6377
+ } else if (values) {
6378
+ result.push({ ...values, role });
6359
6379
  }
6360
6380
  }
6361
6381
  return result;
@@ -6393,20 +6413,8 @@ var AxMemory = class {
6393
6413
  axValidateChatResponseResult(results);
6394
6414
  this.getMemory(sessionId).addResponse(results);
6395
6415
  }
6396
- addFunctionResult({
6397
- functionId,
6398
- isError,
6399
- index,
6400
- result
6401
- }, sessionId) {
6402
- const functionMessage = {
6403
- role: "function",
6404
- functionId,
6405
- isError,
6406
- result
6407
- };
6408
- axValidateChatRequestMessage(functionMessage);
6409
- this.getMemory(sessionId).addRequest([functionMessage], index);
6416
+ addFunctionResults(results, sessionId) {
6417
+ this.getMemory(sessionId).addFunctionResults(results);
6410
6418
  }
6411
6419
  updateResult(result, sessionId) {
6412
6420
  this.getMemory(sessionId).updateResult(result);
@@ -7643,51 +7651,48 @@ var processFunctions = async ({
7643
7651
  }
7644
7652
  return {
7645
7653
  result: functionResult ?? "",
7654
+ role: "function",
7646
7655
  functionId: func.id,
7647
7656
  index
7648
7657
  };
7649
7658
  }).catch((e) => {
7650
- if (e instanceof FunctionError) {
7651
- const result = e.getFixingInstructions();
7652
- if (span) {
7653
- const errorEventData = {
7654
- name: func.name,
7655
- message: e.toString()
7656
- };
7657
- if (!excludeContentFromTrace) {
7658
- errorEventData.args = func.args;
7659
- errorEventData.fixing_instructions = result;
7660
- }
7661
- span.addEvent("function.error", errorEventData);
7659
+ if (!(e instanceof FunctionError)) {
7660
+ throw e;
7661
+ }
7662
+ const result = e.getFixingInstructions();
7663
+ if (span) {
7664
+ const errorEventData = {
7665
+ name: func.name,
7666
+ message: e.toString()
7667
+ };
7668
+ if (!excludeContentFromTrace) {
7669
+ errorEventData.args = func.args;
7670
+ errorEventData.fixing_instructions = result;
7662
7671
  }
7663
- mem.addFunctionResult(
7664
- {
7665
- functionId: func.id,
7666
- isError: true,
7667
- index,
7668
- result
7669
- },
7670
- sessionId
7671
- );
7672
- mem.addTag("error", sessionId);
7673
- if (ai.getOptions().debug) {
7674
- const logger = ai.getLogger();
7675
- logger(`\u274C Function Error Correction:
7672
+ span.addEvent("function.error", errorEventData);
7673
+ }
7674
+ if (ai.getOptions().debug) {
7675
+ const logger = ai.getLogger();
7676
+ logger(`\u274C Function Error Correction:
7676
7677
  ${result}`, {
7677
- tags: ["error"]
7678
- });
7679
- }
7680
- } else {
7681
- throw e;
7678
+ tags: ["error"]
7679
+ });
7682
7680
  }
7681
+ return {
7682
+ functionId: func.id,
7683
+ isError: true,
7684
+ index,
7685
+ result,
7686
+ role: "function"
7687
+ };
7683
7688
  });
7684
7689
  return promise;
7685
7690
  });
7686
7691
  const results = await Promise.all(promises);
7687
- for (const result of results) {
7688
- if (result) {
7689
- mem.addFunctionResult(result, sessionId);
7690
- }
7692
+ const functionResults = results.filter((result) => result !== void 0);
7693
+ mem.addFunctionResults(functionResults, sessionId);
7694
+ if (functionResults.some((result) => result.isError)) {
7695
+ mem.addTag("error", sessionId);
7691
7696
  }
7692
7697
  return functionsExecuted;
7693
7698
  };
@@ -9961,6 +9966,96 @@ var AxProgram = class {
9961
9966
  }
9962
9967
  };
9963
9968
 
9969
+ // dsp/samples.ts
9970
+ function checkForFunctionCalls(mem, sessionId) {
9971
+ const history = mem.history(0, sessionId);
9972
+ const hasFunctionResults = history.some((msg) => msg.role === "function");
9973
+ const hasFunctionCalls = history.some(
9974
+ (msg) => msg.role === "assistant" && "functionCalls" in msg && Array.isArray(msg.functionCalls) && msg.functionCalls.length > 0
9975
+ );
9976
+ return hasFunctionCalls && hasFunctionResults;
9977
+ }
9978
+ function extractFunctionResults(mem, sessionId) {
9979
+ const history = mem.history(0, sessionId);
9980
+ const results = [];
9981
+ const assistantMessages = history.filter(
9982
+ (msg) => msg.role === "assistant" && "functionCalls" in msg && Array.isArray(msg.functionCalls) && msg.functionCalls.length > 0
9983
+ );
9984
+ const functionMessages = history.filter((msg) => msg.role === "function");
9985
+ for (const assistantMsg of assistantMessages) {
9986
+ if ("functionCalls" in assistantMsg && assistantMsg.functionCalls) {
9987
+ for (const funcCall of assistantMsg.functionCalls) {
9988
+ const funcResult = functionMessages.find(
9989
+ (msg) => "functionId" in msg && msg.functionId === funcCall.id
9990
+ );
9991
+ if (funcResult && "result" in funcResult && "functionId" in funcResult) {
9992
+ results.push({
9993
+ index: results.length,
9994
+ // Use sequential index for function results
9995
+ functionName: funcCall.function.name,
9996
+ functionId: funcCall.id,
9997
+ args: funcCall.function.params || "",
9998
+ result: String(funcResult.result),
9999
+ isError: "isError" in funcResult ? Boolean(funcResult.isError) : false
10000
+ });
10001
+ }
10002
+ }
10003
+ }
10004
+ }
10005
+ return results;
10006
+ }
10007
+ async function selectFromSamples(buffer, options, mem, sessionId) {
10008
+ if (!options?.resultPicker || buffer.length <= 1) {
10009
+ return 0;
10010
+ }
10011
+ const resultPicker = options.resultPicker;
10012
+ const hasFunctionCalls = mem ? checkForFunctionCalls(mem, sessionId) : false;
10013
+ if (hasFunctionCalls && mem) {
10014
+ const functionResults = extractFunctionResults(mem, sessionId);
10015
+ const selectedIndex = await resultPicker({
10016
+ type: "function",
10017
+ results: functionResults
10018
+ });
10019
+ if (selectedIndex < 0 || selectedIndex >= functionResults.length) {
10020
+ throw new Error(
10021
+ `Result picker returned invalid index: ${selectedIndex}. Must be between 0 and ${functionResults.length - 1}`
10022
+ );
10023
+ }
10024
+ return selectedIndex;
10025
+ } else {
10026
+ const fieldResults = buffer.map((b, index) => ({
10027
+ index,
10028
+ sample: b.delta
10029
+ }));
10030
+ const selectedIndex = await resultPicker({
10031
+ type: "fields",
10032
+ results: fieldResults
10033
+ });
10034
+ if (selectedIndex < 0 || selectedIndex >= buffer.length) {
10035
+ throw new Error(
10036
+ `Result picker returned invalid index: ${selectedIndex}. Must be between 0 and ${buffer.length - 1}`
10037
+ );
10038
+ }
10039
+ return selectedIndex;
10040
+ }
10041
+ }
10042
+ async function selectFromSamplesInMemory(mem, sessionId, options) {
10043
+ const lastMemory = mem?.getLast(sessionId);
10044
+ if (!lastMemory || lastMemory.role !== "assistant") {
10045
+ return 0;
10046
+ }
10047
+ if (lastMemory.chat.length <= 1) {
10048
+ return 0;
10049
+ }
10050
+ const buffer = lastMemory.chat.map((chat) => ({
10051
+ version: 0,
10052
+ index: chat.index,
10053
+ delta: chat.value
10054
+ }));
10055
+ const selectedIndex = await selectFromSamples(buffer, options, mem, sessionId);
10056
+ return selectedIndex;
10057
+ }
10058
+
9964
10059
  // dsp/validate.ts
9965
10060
  function handleValidationError(mem, errorFields, ai, promptTemplate, sessionId) {
9966
10061
  mem.addRequest(
@@ -10076,7 +10171,10 @@ var AxGen = class extends AxProgramWithSignature {
10076
10171
  thinkingTokenBudget,
10077
10172
  showThoughts
10078
10173
  } = options ?? {};
10079
- const chatPrompt = mem?.history(0, sessionId) ?? [];
10174
+ const selectedIndex = await selectFromSamplesInMemory(mem, sessionId, {
10175
+ resultPicker: options?.resultPicker
10176
+ });
10177
+ const chatPrompt = mem?.history(selectedIndex, sessionId) ?? [];
10080
10178
  if (chatPrompt.length === 0) {
10081
10179
  throw new Error("No chat prompt found");
10082
10180
  }
@@ -10087,7 +10185,8 @@ var AxGen = class extends AxProgramWithSignature {
10087
10185
  }
10088
10186
  const modelConfig = {
10089
10187
  ...options?.modelConfig,
10090
- ...options?.sampleCount ? { n: options.sampleCount } : {}
10188
+ ...options?.sampleCount ? { n: options.sampleCount } : {},
10189
+ ...options?.sampleCount && options?.modelConfig?.temperature == 1 ? { temperature: 0.8 } : {}
10091
10190
  };
10092
10191
  const res = await ai.chat(
10093
10192
  {
@@ -10364,15 +10463,58 @@ var AxGen = class extends AxProgramWithSignature {
10364
10463
  currentVersion = delta.version;
10365
10464
  buffer = mergeDeltas(buffer, delta);
10366
10465
  }
10367
- const result = buffer[0]?.delta ?? {};
10466
+ const selectedIndex = await selectFromSamples(
10467
+ buffer,
10468
+ {
10469
+ resultPicker: options?.resultPicker
10470
+ },
10471
+ // Pass memory to enable function result selection
10472
+ options?.mem,
10473
+ options?.sessionId
10474
+ );
10475
+ const selectedResult = buffer[selectedIndex];
10476
+ const result = selectedResult?.delta ?? {};
10368
10477
  this.trace = { ...values, ...result };
10369
10478
  return result;
10370
10479
  }
10371
10480
  async *streamingForward(ai, values, options) {
10372
- yield* this._forward1(ai, values, {
10481
+ if (!options?.resultPicker) {
10482
+ yield* this._forward1(ai, values, {
10483
+ ...options,
10484
+ stream: true
10485
+ });
10486
+ return;
10487
+ }
10488
+ const generator = this._forward1(ai, values, {
10373
10489
  ...options,
10374
10490
  stream: true
10375
10491
  });
10492
+ let buffer = [];
10493
+ let currentVersion = 0;
10494
+ for await (const delta of generator) {
10495
+ if (delta.version !== currentVersion) {
10496
+ buffer = [];
10497
+ }
10498
+ currentVersion = delta.version;
10499
+ buffer = mergeDeltas(buffer, delta);
10500
+ }
10501
+ const selectedIndex = await selectFromSamples(
10502
+ buffer,
10503
+ {
10504
+ resultPicker: options?.resultPicker
10505
+ },
10506
+ // Pass memory to enable function result selection
10507
+ options?.mem,
10508
+ options?.sessionId
10509
+ );
10510
+ const selectedResult = buffer[selectedIndex];
10511
+ if (selectedResult) {
10512
+ yield {
10513
+ version: currentVersion,
10514
+ index: selectedIndex,
10515
+ delta: selectedResult.delta
10516
+ };
10517
+ }
10376
10518
  }
10377
10519
  setExamples(examples, options) {
10378
10520
  super.setExamples(examples, options);
@@ -11732,13 +11874,6 @@ var AxBaseOptimizer = class {
11732
11874
  if (this.logger) {
11733
11875
  return this.logger;
11734
11876
  }
11735
- try {
11736
- const aiLogger = this.studentAI.getLogger();
11737
- if (aiLogger) {
11738
- return aiLogger;
11739
- }
11740
- } catch {
11741
- }
11742
11877
  return axDefaultOptimizerLogger;
11743
11878
  }
11744
11879
  /**
@@ -13430,6 +13565,11 @@ var AxMiPRO = class extends AxBaseOptimizer {
13430
13565
  bayesianOptimization;
13431
13566
  acquisitionFunction;
13432
13567
  explorationWeight;
13568
+ // Self-consistency / multiple sampling
13569
+ sampleCount;
13570
+ // Surrogate model state for Bayesian optimization
13571
+ miproConfigHistory = [];
13572
+ surrogateModel = /* @__PURE__ */ new Map();
13433
13573
  constructor(args) {
13434
13574
  super(args);
13435
13575
  const options = args.options || {};
@@ -13451,6 +13591,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
13451
13591
  this.bayesianOptimization = options.bayesianOptimization ?? false;
13452
13592
  this.acquisitionFunction = options.acquisitionFunction ?? "expected_improvement";
13453
13593
  this.explorationWeight = options.explorationWeight ?? 0.1;
13594
+ this.sampleCount = options.sampleCount ?? 1;
13454
13595
  this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
13455
13596
  }
13456
13597
  /**
@@ -13495,43 +13636,186 @@ var AxMiPRO = class extends AxBaseOptimizer {
13495
13636
  ];
13496
13637
  }
13497
13638
  /**
13498
- * Generates instruction candidates using the teacher model if available
13639
+ * Generates program summary for context-aware instruction generation
13640
+ */
13641
+ async generateProgramSummary(program, ai) {
13642
+ let signature = "input -> output";
13643
+ if ("getSignature" in program && typeof program.getSignature === "function") {
13644
+ signature = program.getSignature();
13645
+ }
13646
+ const summaryPrompt = `
13647
+ Analyze this language model program and provide a concise summary of its purpose and structure.
13648
+
13649
+ Program Signature: ${signature}
13650
+
13651
+ Provide a 2-3 sentence summary focusing on:
13652
+ 1. The main task or purpose of this program
13653
+ 2. The input-output relationship
13654
+ 3. Any special constraints or requirements
13655
+
13656
+ Summary:`;
13657
+ try {
13658
+ const response = await ai.chat({
13659
+ chatPrompt: [{ role: "user", content: summaryPrompt }]
13660
+ });
13661
+ if ("results" in response) {
13662
+ return response.results[0]?.content?.trim() || "General language model program";
13663
+ }
13664
+ return "General language model program";
13665
+ } catch {
13666
+ return "General language model program";
13667
+ }
13668
+ }
13669
+ /**
13670
+ * Generates dataset summary for context-aware instruction generation
13671
+ */
13672
+ async generateDatasetSummary(examples, ai) {
13673
+ if (examples.length === 0) return "No examples available";
13674
+ const sampleSize = Math.min(this.viewDataBatchSize, examples.length);
13675
+ const sampledExamples = examples.slice(0, sampleSize);
13676
+ const exampleTexts = sampledExamples.map((ex, i) => `Example ${i + 1}: ${JSON.stringify(ex)}`).join("\n");
13677
+ const summaryPrompt = `
13678
+ Analyze this dataset and provide a concise summary of its characteristics.
13679
+
13680
+ Sample Examples:
13681
+ ${exampleTexts}
13682
+
13683
+ Provide a 2-3 sentence summary focusing on:
13684
+ 1. The type of data and domain
13685
+ 2. Common patterns or structures in the examples
13686
+ 3. Key challenges or requirements for processing this data
13687
+
13688
+ Dataset Summary:`;
13689
+ try {
13690
+ const response = await ai.chat({
13691
+ chatPrompt: [{ role: "user", content: summaryPrompt }]
13692
+ });
13693
+ if ("results" in response) {
13694
+ return response.results[0]?.content?.trim() || "General dataset";
13695
+ }
13696
+ return "General dataset";
13697
+ } catch {
13698
+ return "General dataset";
13699
+ }
13700
+ }
13701
+ /**
13702
+ * Enhanced instruction generation using AI with program and data awareness
13703
+ */
13704
+ async generateInstruction({
13705
+ tip,
13706
+ candidateIndex,
13707
+ ai,
13708
+ programSummary,
13709
+ datasetSummary,
13710
+ previousInstructions = []
13711
+ }) {
13712
+ let contextInfo = "";
13713
+ if (this.programAwareProposer && programSummary) {
13714
+ contextInfo += `
13715
+ Program Context: ${programSummary}`;
13716
+ }
13717
+ if (this.dataAwareProposer && datasetSummary) {
13718
+ contextInfo += `
13719
+ Dataset Context: ${datasetSummary}`;
13720
+ }
13721
+ if (this.fewshotAwareProposer && previousInstructions.length > 0) {
13722
+ contextInfo += `
13723
+ Previous Instructions (avoid repeating): ${previousInstructions.slice(-3).join("; ")}`;
13724
+ }
13725
+ const instructionPrompt = `
13726
+ Generate a high-quality instruction for a language model program.
13727
+
13728
+ ${contextInfo}
13729
+
13730
+ ${tip ? `Tip: ${tip}` : ""}
13731
+
13732
+ Requirements:
13733
+ 1. Be specific and actionable
13734
+ 2. Focus on accuracy and clarity
13735
+ 3. Consider the program's purpose and data characteristics
13736
+ 4. Make the instruction distinct from previous ones
13737
+ 5. Keep it concise but comprehensive
13738
+
13739
+ Generate a single, well-crafted instruction:
13740
+ Instruction:`;
13741
+ try {
13742
+ const response = await ai.chat({
13743
+ chatPrompt: [
13744
+ {
13745
+ role: "user",
13746
+ content: instructionPrompt
13747
+ }
13748
+ ]
13749
+ });
13750
+ if ("results" in response) {
13751
+ const instruction2 = response.results[0]?.content?.trim();
13752
+ if (instruction2 && instruction2.length > 10) {
13753
+ return instruction2;
13754
+ }
13755
+ }
13756
+ } catch (error) {
13757
+ if (this.isLoggingEnabled()) {
13758
+ this.getLogger()?.(`Failed to generate AI instruction: ${error}`, {
13759
+ tags: ["optimizer", "warning"]
13760
+ });
13761
+ }
13762
+ }
13763
+ const enhancedTemplates = [
13764
+ "Analyze the input systematically and provide a precise, well-reasoned response.",
13765
+ "Think through this step-by-step, considering all relevant factors before responding.",
13766
+ "Examine the input carefully and generate an accurate, detailed answer.",
13767
+ "Process the information methodically and deliver a clear, comprehensive response.",
13768
+ "Consider the context thoroughly and provide a thoughtful, accurate answer."
13769
+ ];
13770
+ let instruction = enhancedTemplates[candidateIndex % enhancedTemplates.length] || enhancedTemplates[0];
13771
+ if (tip) {
13772
+ instruction = `${instruction} ${tip}`;
13773
+ }
13774
+ return instruction;
13775
+ }
13776
+ /**
13777
+ * Generates instruction candidates using enhanced AI-powered generation
13499
13778
  * @param options Optional compile options that may override teacher AI
13500
13779
  * @returns Array of generated instruction candidates
13501
13780
  */
13502
- async proposeInstructionCandidates(options) {
13781
+ async proposeInstructionCandidates(program, options) {
13503
13782
  const instructions = [];
13504
13783
  const aiToUse = this.getTeacherOrStudentAI(options);
13784
+ let programSummary;
13785
+ let datasetSummary;
13786
+ if (this.programAwareProposer) {
13787
+ programSummary = await this.generateProgramSummary(program, aiToUse);
13788
+ if (this.isLoggingEnabled(options)) {
13789
+ this.getLogger(options)?.(`Program summary: ${programSummary}`, {
13790
+ tags: ["optimizer", "config"]
13791
+ });
13792
+ }
13793
+ }
13794
+ if (this.dataAwareProposer) {
13795
+ datasetSummary = await this.generateDatasetSummary(this.examples, aiToUse);
13796
+ if (this.isLoggingEnabled(options)) {
13797
+ this.getLogger(options)?.(`Dataset summary: ${datasetSummary}`, {
13798
+ tags: ["optimizer", "config"]
13799
+ });
13800
+ }
13801
+ }
13505
13802
  const tips = this.tipAwareProposer ? this.generateTips() : [];
13506
13803
  for (let i = 0; i < this.numCandidates; i++) {
13507
13804
  const tipIndex = tips.length > 0 ? i % tips.length : -1;
13508
- const tipToUse = tipIndex >= 0 ? tips[tipIndex] : "";
13805
+ const tipToUse = tipIndex >= 0 ? tips[tipIndex] : void 0;
13509
13806
  const instruction = await this.generateInstruction({
13510
13807
  tip: tipToUse,
13511
13808
  candidateIndex: i,
13512
- ai: aiToUse
13809
+ ai: aiToUse,
13810
+ programSummary,
13811
+ datasetSummary,
13812
+ previousInstructions: instructions
13813
+ // Pass previous instructions for diversity
13513
13814
  });
13514
13815
  instructions.push(instruction);
13515
13816
  }
13516
13817
  return instructions;
13517
13818
  }
13518
- async generateInstruction({
13519
- tip,
13520
- candidateIndex
13521
- }) {
13522
- const baseInstructions = [
13523
- "Analyze the input carefully and provide a detailed response.",
13524
- "Think step by step and provide a clear answer.",
13525
- "Consider all aspects of the input before responding.",
13526
- "Provide a concise but comprehensive response.",
13527
- "Focus on accuracy and clarity in your response."
13528
- ];
13529
- let instruction = baseInstructions[candidateIndex % baseInstructions.length] || baseInstructions[0];
13530
- if (tip) {
13531
- instruction = `${instruction} ${tip}`;
13532
- }
13533
- return instruction;
13534
- }
13535
13819
  /**
13536
13820
  * Bootstraps few-shot examples for the program
13537
13821
  */
@@ -13576,7 +13860,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
13576
13860
  /**
13577
13861
  * Runs optimization to find the best combination of few-shot examples and instructions
13578
13862
  */
13579
- async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, valset, metricFn, options) {
13863
+ async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, validationExamples, metricFn, options) {
13580
13864
  let bestConfig = {
13581
13865
  instruction: instructions[0] || "",
13582
13866
  bootstrappedDemos: Math.min(1, bootstrappedDemos.length),
@@ -13612,25 +13896,37 @@ var AxMiPRO = class extends AxBaseOptimizer {
13612
13896
  );
13613
13897
  }
13614
13898
  for (let i = startRound; i < this.numTrials; i++) {
13615
- const config = {
13616
- instruction: instructions[i % instructions.length] || instructions[0] || "",
13617
- bootstrappedDemos: Math.min(
13618
- Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
13619
- this.maxBootstrappedDemos
13620
- ),
13621
- labeledExamples: Math.min(
13622
- Math.floor(Math.random() * (labeledExamples.length + 1)),
13623
- this.maxLabeledDemos
13624
- )
13625
- };
13899
+ let config;
13900
+ if (this.bayesianOptimization && this.miproConfigHistory.length > 2) {
13901
+ config = await this.selectConfigurationViaBayesianOptimization(
13902
+ instructions,
13903
+ bootstrappedDemos,
13904
+ labeledExamples
13905
+ );
13906
+ } else {
13907
+ config = {
13908
+ instruction: instructions[i % instructions.length] || instructions[0] || "",
13909
+ bootstrappedDemos: Math.min(
13910
+ Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
13911
+ this.maxBootstrappedDemos
13912
+ ),
13913
+ labeledExamples: Math.min(
13914
+ Math.floor(Math.random() * (labeledExamples.length + 1)),
13915
+ this.maxLabeledDemos
13916
+ )
13917
+ };
13918
+ }
13626
13919
  const score = await this.evaluateConfig(
13627
13920
  program,
13628
13921
  config,
13629
13922
  bootstrappedDemos,
13630
13923
  labeledExamples,
13631
- valset,
13632
- metricFn
13924
+ validationExamples,
13925
+ metricFn,
13926
+ i + 1
13927
+ // Pass current trial number for adaptive evaluation
13633
13928
  );
13929
+ this.updateSurrogateModel(config, score);
13634
13930
  scoreHistory.push(score);
13635
13931
  const improvement = score - bestScore;
13636
13932
  if (improvement > this.minImprovementThreshold) {
@@ -13712,7 +14008,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
13712
14008
  this.stats.convergenceInfo.converged = stagnationRounds < this.earlyStoppingTrials;
13713
14009
  return { bestConfig, bestScore };
13714
14010
  }
13715
- async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, valset, metricFn) {
14011
+ async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, validationExamples, metricFn, currentTrial = 0) {
13716
14012
  const testProgram = { ...program };
13717
14013
  this.applyConfigToProgram(
13718
14014
  testProgram,
@@ -13722,12 +14018,31 @@ var AxMiPRO = class extends AxBaseOptimizer {
13722
14018
  );
13723
14019
  let totalScore = 0;
13724
14020
  let count = 0;
13725
- const evalSet = valset.slice(0, Math.min(5, valset.length));
14021
+ let evalSize;
14022
+ if (this.minibatch) {
14023
+ const baseSize = Math.min(this.minibatchSize, validationExamples.length);
14024
+ const isFullEvalTrial = currentTrial % this.minibatchFullEvalSteps === 0;
14025
+ if (isFullEvalTrial || currentTrial > this.numTrials * 0.8) {
14026
+ evalSize = Math.min(validationExamples.length, baseSize * 2);
14027
+ } else {
14028
+ evalSize = Math.max(3, Math.min(baseSize, validationExamples.length));
14029
+ }
14030
+ } else {
14031
+ evalSize = validationExamples.length;
14032
+ }
14033
+ const evalIndices = this.shuffleArray([
14034
+ ...Array(validationExamples.length).keys()
14035
+ ]).slice(0, evalSize);
14036
+ const evalSet = evalIndices.map((i) => validationExamples[i]);
13726
14037
  for (const example of evalSet) {
13727
14038
  try {
13728
14039
  const prediction = await testProgram.forward(
13729
14040
  this.studentAI,
13730
- example
14041
+ example,
14042
+ this.sampleCount > 1 ? {
14043
+ sampleCount: this.sampleCount,
14044
+ resultPicker: axMajorityVotePicker()
14045
+ } : void 0
13731
14046
  );
13732
14047
  const score = await metricFn({ prediction, example });
13733
14048
  totalScore += score;
@@ -13739,6 +14054,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
13739
14054
  }
13740
14055
  return count > 0 ? totalScore / count : 0;
13741
14056
  }
14057
+ /**
14058
+ * Fisher-Yates shuffle for stochastic evaluation
14059
+ */
14060
+ shuffleArray(array) {
14061
+ const shuffled = [...array];
14062
+ for (let i = shuffled.length - 1; i > 0; i--) {
14063
+ const j = Math.floor(Math.random() * (i + 1));
14064
+ [shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
14065
+ }
14066
+ return shuffled;
14067
+ }
13742
14068
  applyConfigToProgram(program, config, bootstrappedDemos, labeledExamples) {
13743
14069
  if (program.setInstruction) {
13744
14070
  program.setInstruction(config.instruction);
@@ -13760,14 +14086,14 @@ var AxMiPRO = class extends AxBaseOptimizer {
13760
14086
  if (miproOptions?.auto) {
13761
14087
  this.configureAuto(miproOptions.auto);
13762
14088
  }
13763
- const valset = this.getValidationSet(options) || (miproOptions?.valset ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
14089
+ const validationExamples = this.getValidationSet(options) || (miproOptions?.validationExamples ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
13764
14090
  if (this.isLoggingEnabled(options)) {
13765
14091
  this.getLogger(options)?.(
13766
14092
  `Starting MIPROv2 optimization with ${this.numTrials} trials`,
13767
14093
  { tags: ["optimizer", "start"] }
13768
14094
  );
13769
14095
  this.getLogger(options)?.(
13770
- `Using ${this.examples.length} examples for training and ${valset.length} for validation`,
14096
+ `Using ${this.examples.length} examples for training and ${validationExamples.length} for validation`,
13771
14097
  { tags: ["optimizer", "config"] }
13772
14098
  );
13773
14099
  if (this.teacherAI) {
@@ -13797,7 +14123,10 @@ var AxMiPRO = class extends AxBaseOptimizer {
13797
14123
  );
13798
14124
  }
13799
14125
  }
13800
- const instructions = await this.proposeInstructionCandidates(options);
14126
+ const instructions = await this.proposeInstructionCandidates(
14127
+ program,
14128
+ options
14129
+ );
13801
14130
  if (this.isLoggingEnabled(options)) {
13802
14131
  this.getLogger(options)?.(
13803
14132
  `Generated ${instructions.length} instruction candidates`,
@@ -13815,7 +14144,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
13815
14144
  bootstrappedDemos,
13816
14145
  labeledExamples,
13817
14146
  instructions,
13818
- valset,
14147
+ validationExamples,
13819
14148
  metricFn,
13820
14149
  options
13821
14150
  );
@@ -13874,7 +14203,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
13874
14203
  bootstrappedDemos: bestConfig.bootstrappedDemos,
13875
14204
  labeledExamples: bestConfig.labeledExamples,
13876
14205
  numCandidates: this.numCandidates,
13877
- numTrials: this.numTrials
14206
+ numTrials: this.numTrials,
14207
+ sampleCount: this.sampleCount
13878
14208
  }
13879
14209
  };
13880
14210
  }
@@ -13919,7 +14249,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
13919
14249
  minImprovementThreshold: this.minImprovementThreshold,
13920
14250
  bayesianOptimization: this.bayesianOptimization,
13921
14251
  acquisitionFunction: this.acquisitionFunction,
13922
- explorationWeight: this.explorationWeight
14252
+ explorationWeight: this.explorationWeight,
14253
+ sampleCount: this.sampleCount
13923
14254
  };
13924
14255
  }
13925
14256
  /**
@@ -13954,12 +14285,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
13954
14285
  if (config.minImprovementThreshold !== void 0) {
13955
14286
  this.minImprovementThreshold = config.minImprovementThreshold;
13956
14287
  }
14288
+ if (config.sampleCount !== void 0) {
14289
+ this.sampleCount = config.sampleCount;
14290
+ }
13957
14291
  }
13958
14292
  /**
13959
14293
  * Reset optimizer state for reuse with different programs
13960
14294
  */
13961
14295
  reset() {
13962
14296
  super.reset();
14297
+ this.miproConfigHistory = [];
14298
+ this.surrogateModel.clear();
13963
14299
  this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
13964
14300
  }
13965
14301
  /**
@@ -13977,8 +14313,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
13977
14313
  "Reduce maxBootstrappedDemos or maxLabeledDemos, or provide more examples"
13978
14314
  );
13979
14315
  }
13980
- const valSetSize = this.getValidationSet().length;
13981
- if (valSetSize < 5) {
14316
+ const validationSetSize = this.getValidationSet().length;
14317
+ if (validationSetSize < 5) {
13982
14318
  result.issues.push(
13983
14319
  "Validation set too small for reliable MiPRO optimization"
13984
14320
  );
@@ -13992,6 +14328,141 @@ var AxMiPRO = class extends AxBaseOptimizer {
13992
14328
  suggestions: result.suggestions
13993
14329
  };
13994
14330
  }
14331
+ /**
14332
+ * Encodes a configuration into a string key for surrogate model lookup
14333
+ */
14334
+ encodeConfiguration(config) {
14335
+ return `${config.instruction.length}_${config.bootstrappedDemos}_${config.labeledExamples}`;
14336
+ }
14337
+ /**
14338
+ * Updates the surrogate model with a new configuration-score pair
14339
+ */
14340
+ updateSurrogateModel(config, score) {
14341
+ this.miproConfigHistory.push({ config: { ...config }, score });
14342
+ const key = this.encodeConfiguration(config);
14343
+ const similarConfigs = this.miproConfigHistory.filter(
14344
+ (entry) => this.encodeConfiguration(entry.config) === key
14345
+ );
14346
+ if (similarConfigs.length > 0) {
14347
+ const scores = similarConfigs.map((entry) => entry.score);
14348
+ const mean = scores.reduce((sum, s2) => sum + s2, 0) / scores.length;
14349
+ const variance = scores.length > 1 ? scores.reduce((sum, s2) => sum + Math.pow(s2 - mean, 2), 0) / (scores.length - 1) : 0.1;
14350
+ this.surrogateModel.set(key, { mean, variance });
14351
+ }
14352
+ }
14353
+ /**
14354
+ * Predicts performance using the surrogate model
14355
+ */
14356
+ predictPerformance(config) {
14357
+ const key = this.encodeConfiguration(config);
14358
+ if (this.surrogateModel.has(key)) {
14359
+ return this.surrogateModel.get(key);
14360
+ }
14361
+ if (this.miproConfigHistory.length > 0) {
14362
+ const similarities = this.miproConfigHistory.map((entry) => {
14363
+ const diff = Math.abs(entry.config.bootstrappedDemos - config.bootstrappedDemos) + Math.abs(entry.config.labeledExamples - config.labeledExamples);
14364
+ return { score: entry.score, similarity: 1 / (1 + diff) };
14365
+ });
14366
+ const totalWeight = similarities.reduce((sum, s2) => sum + s2.similarity, 0);
14367
+ const weightedMean = similarities.reduce((sum, s2) => sum + s2.score * s2.similarity, 0) / totalWeight;
14368
+ return { mean: weightedMean, variance: 0.2 };
14369
+ }
14370
+ return { mean: 0.5, variance: 0.3 };
14371
+ }
14372
+ /**
14373
+ * Calculates acquisition function value for Bayesian optimization
14374
+ */
14375
+ calculateAcquisitionValue(config) {
14376
+ const prediction = this.predictPerformance(config);
14377
+ const { mean, variance } = prediction;
14378
+ const std = Math.sqrt(variance);
14379
+ const bestScore = this.miproConfigHistory.length > 0 ? Math.max(...this.miproConfigHistory.map((entry) => entry.score)) : 0;
14380
+ switch (this.acquisitionFunction) {
14381
+ case "expected_improvement": {
14382
+ const improvement = mean - bestScore;
14383
+ if (std === 0) return Math.max(0, improvement);
14384
+ const z = improvement / std;
14385
+ const phi = 0.5 * (1 + this.erf(z / Math.sqrt(2)));
14386
+ const pdfValue = Math.exp(-0.5 * z * z) / Math.sqrt(2 * Math.PI);
14387
+ return improvement * phi + std * pdfValue;
14388
+ }
14389
+ case "upper_confidence_bound": {
14390
+ return mean + this.explorationWeight * std;
14391
+ }
14392
+ case "probability_improvement": {
14393
+ const improvement = mean - bestScore;
14394
+ if (std === 0) return improvement > 0 ? 1 : 0;
14395
+ const z = improvement / std;
14396
+ return 0.5 * (1 + this.erf(z / Math.sqrt(2)));
14397
+ }
14398
+ default:
14399
+ return mean;
14400
+ }
14401
+ }
14402
+ /**
14403
+ * Error function approximation for acquisition function calculations
14404
+ */
14405
+ erf(x) {
14406
+ const a1 = 0.254829592;
14407
+ const a2 = -0.284496736;
14408
+ const a3 = 1.421413741;
14409
+ const a4 = -1.453152027;
14410
+ const a5 = 1.061405429;
14411
+ const p = 0.3275911;
14412
+ const sign = x >= 0 ? 1 : -1;
14413
+ x = Math.abs(x);
14414
+ const t = 1 / (1 + p * x);
14415
+ const y = 1 - ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t * Math.exp(-x * x);
14416
+ return sign * y;
14417
+ }
14418
+ /**
14419
+ * Selects the next configuration to evaluate using Bayesian optimization
14420
+ */
14421
+ async selectConfigurationViaBayesianOptimization(instructions, bootstrappedDemos, labeledExamples) {
14422
+ const candidates = [];
14423
+ const numCandidates = Math.min(20, instructions.length * 3);
14424
+ for (let i = 0; i < numCandidates; i++) {
14425
+ const config = {
14426
+ instruction: instructions[i % instructions.length] || instructions[0] || "",
14427
+ bootstrappedDemos: Math.min(
14428
+ Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
14429
+ this.maxBootstrappedDemos
14430
+ ),
14431
+ labeledExamples: Math.min(
14432
+ Math.floor(Math.random() * (labeledExamples.length + 1)),
14433
+ this.maxLabeledDemos
14434
+ )
14435
+ };
14436
+ const acquisitionValue = this.calculateAcquisitionValue(config);
14437
+ candidates.push({ config, acquisitionValue });
14438
+ }
14439
+ candidates.sort((a, b) => b.acquisitionValue - a.acquisitionValue);
14440
+ return candidates[0].config;
14441
+ }
14442
+ };
14443
/**
 * Creates a result picker that returns the index of the most frequent sample.
 *
 * For `data.type === "fields"`, samples are compared by their
 * `JSON.stringify` representation, and the picker resolves to the `index` of
 * the first result carrying the most common sample. When several samples are
 * equally common, the one seen first wins. For any other `data.type`, the
 * index of the first result is returned. In every case the fallback is 0 when
 * there is no result or no index.
 *
 * Tallies are kept in a `Map` rather than a plain object: plain objects
 * enumerate integer-like keys in ascending numeric order, which made
 * tie-breaking depend on the sample's value instead of arrival order.
 *
 * Declared with `var` to match the other top-level declarations of this bundle.
 *
 * @returns {(data: {type: string, results: Array<{index: number, sample?: unknown}>}) => Promise<number>}
 *   An async picker function.
 */
var axMajorityVotePicker = () => {
  return async (data) => {
    if (data.type !== "fields") {
      return data.results[0]?.index ?? 0;
    }
    const tallies = new Map();
    for (const { index, sample } of data.results) {
      const key = JSON.stringify(sample);
      const tally = tallies.get(key);
      if (tally) {
        tally.count += 1;
      } else {
        // Remember the index of the first result that produced this sample.
        tallies.set(key, { count: 1, index });
      }
    }
    let winner;
    for (const tally of tallies.values()) {
      if (winner === undefined || tally.count > winner.count) {
        winner = tally;
      }
    }
    return winner?.index ?? 0;
  };
};
13996
14467
 
13997
14468
  // ai/mock/api.ts