@ax-llm/ax 12.0.12 → 12.0.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/index.js CHANGED
@@ -6083,15 +6083,24 @@ var MemoryImpl = class {
6083
6083
  debugRequest(items, this.options?.debugHideSystemPrompt);
6084
6084
  }
6085
6085
  }
6086
+ addFunctionResults(results) {
6087
+ const chat = results.map(({ index, ...value }) => ({
6088
+ index,
6089
+ value: structuredClone(value)
6090
+ }));
6091
+ const lastItem = this.getLast();
6092
+ if (lastItem?.role === "function") {
6093
+ lastItem.chat.push(...chat);
6094
+ } else {
6095
+ this.data.push({ role: "function", chat });
6096
+ }
6097
+ }
6086
6098
  addResponse(results) {
6087
- const chat = results.map((result) => ({
6088
- index: result.index,
6089
- value: structuredClone(result)
6099
+ const chat = results.map(({ index, ...value }) => ({
6100
+ index,
6101
+ value: structuredClone(value)
6090
6102
  }));
6091
- this.data.push({
6092
- role: "assistant",
6093
- chat
6094
- });
6103
+ this.data.push({ role: "assistant", chat });
6095
6104
  if (this.options?.debug) {
6096
6105
  for (const result of results) {
6097
6106
  debugResponse(result);
@@ -6179,9 +6188,20 @@ var MemoryImpl = class {
6179
6188
  history(index) {
6180
6189
  const result = [];
6181
6190
  for (const { role, chat } of this.data) {
6182
- const value = chat.find((v) => v.index === index)?.value;
6183
- if (value) {
6184
- result.push({ role, ...value });
6191
+ let values;
6192
+ if (role === "function") {
6193
+ values = chat.filter((v) => v.index === index).map((v) => v.value);
6194
+ } else {
6195
+ values = chat.find((v) => v.index === index)?.value;
6196
+ }
6197
+ if (Array.isArray(values)) {
6198
+ result.push(
6199
+ ...values.map(
6200
+ (v) => ({ ...v, role })
6201
+ )
6202
+ );
6203
+ } else if (values) {
6204
+ result.push({ ...values, role });
6185
6205
  }
6186
6206
  }
6187
6207
  return result;
@@ -6219,20 +6239,8 @@ var AxMemory = class {
6219
6239
  axValidateChatResponseResult(results);
6220
6240
  this.getMemory(sessionId).addResponse(results);
6221
6241
  }
6222
- addFunctionResult({
6223
- functionId,
6224
- isError,
6225
- index,
6226
- result
6227
- }, sessionId) {
6228
- const functionMessage = {
6229
- role: "function",
6230
- functionId,
6231
- isError,
6232
- result
6233
- };
6234
- axValidateChatRequestMessage(functionMessage);
6235
- this.getMemory(sessionId).addRequest([functionMessage], index);
6242
+ addFunctionResults(results, sessionId) {
6243
+ this.getMemory(sessionId).addFunctionResults(results);
6236
6244
  }
6237
6245
  updateResult(result, sessionId) {
6238
6246
  this.getMemory(sessionId).updateResult(result);
@@ -7469,51 +7477,48 @@ var processFunctions = async ({
7469
7477
  }
7470
7478
  return {
7471
7479
  result: functionResult ?? "",
7480
+ role: "function",
7472
7481
  functionId: func.id,
7473
7482
  index
7474
7483
  };
7475
7484
  }).catch((e) => {
7476
- if (e instanceof FunctionError) {
7477
- const result = e.getFixingInstructions();
7478
- if (span) {
7479
- const errorEventData = {
7480
- name: func.name,
7481
- message: e.toString()
7482
- };
7483
- if (!excludeContentFromTrace) {
7484
- errorEventData.args = func.args;
7485
- errorEventData.fixing_instructions = result;
7486
- }
7487
- span.addEvent("function.error", errorEventData);
7485
+ if (!(e instanceof FunctionError)) {
7486
+ throw e;
7487
+ }
7488
+ const result = e.getFixingInstructions();
7489
+ if (span) {
7490
+ const errorEventData = {
7491
+ name: func.name,
7492
+ message: e.toString()
7493
+ };
7494
+ if (!excludeContentFromTrace) {
7495
+ errorEventData.args = func.args;
7496
+ errorEventData.fixing_instructions = result;
7488
7497
  }
7489
- mem.addFunctionResult(
7490
- {
7491
- functionId: func.id,
7492
- isError: true,
7493
- index,
7494
- result
7495
- },
7496
- sessionId
7497
- );
7498
- mem.addTag("error", sessionId);
7499
- if (ai.getOptions().debug) {
7500
- const logger = ai.getLogger();
7501
- logger(`\u274C Function Error Correction:
7498
+ span.addEvent("function.error", errorEventData);
7499
+ }
7500
+ if (ai.getOptions().debug) {
7501
+ const logger = ai.getLogger();
7502
+ logger(`\u274C Function Error Correction:
7502
7503
  ${result}`, {
7503
- tags: ["error"]
7504
- });
7505
- }
7506
- } else {
7507
- throw e;
7504
+ tags: ["error"]
7505
+ });
7508
7506
  }
7507
+ return {
7508
+ functionId: func.id,
7509
+ isError: true,
7510
+ index,
7511
+ result,
7512
+ role: "function"
7513
+ };
7509
7514
  });
7510
7515
  return promise;
7511
7516
  });
7512
7517
  const results = await Promise.all(promises);
7513
- for (const result of results) {
7514
- if (result) {
7515
- mem.addFunctionResult(result, sessionId);
7516
- }
7518
+ const functionResults = results.filter((result) => result !== void 0);
7519
+ mem.addFunctionResults(functionResults, sessionId);
7520
+ if (functionResults.some((result) => result.isError)) {
7521
+ mem.addTag("error", sessionId);
7517
7522
  }
7518
7523
  return functionsExecuted;
7519
7524
  };
@@ -9787,6 +9792,96 @@ var AxProgram = class {
9787
9792
  }
9788
9793
  };
9789
9794
 
9795
+ // dsp/samples.ts
9796
+ function checkForFunctionCalls(mem, sessionId) {
9797
+ const history = mem.history(0, sessionId);
9798
+ const hasFunctionResults = history.some((msg) => msg.role === "function");
9799
+ const hasFunctionCalls = history.some(
9800
+ (msg) => msg.role === "assistant" && "functionCalls" in msg && Array.isArray(msg.functionCalls) && msg.functionCalls.length > 0
9801
+ );
9802
+ return hasFunctionCalls && hasFunctionResults;
9803
+ }
9804
+ function extractFunctionResults(mem, sessionId) {
9805
+ const history = mem.history(0, sessionId);
9806
+ const results = [];
9807
+ const assistantMessages = history.filter(
9808
+ (msg) => msg.role === "assistant" && "functionCalls" in msg && Array.isArray(msg.functionCalls) && msg.functionCalls.length > 0
9809
+ );
9810
+ const functionMessages = history.filter((msg) => msg.role === "function");
9811
+ for (const assistantMsg of assistantMessages) {
9812
+ if ("functionCalls" in assistantMsg && assistantMsg.functionCalls) {
9813
+ for (const funcCall of assistantMsg.functionCalls) {
9814
+ const funcResult = functionMessages.find(
9815
+ (msg) => "functionId" in msg && msg.functionId === funcCall.id
9816
+ );
9817
+ if (funcResult && "result" in funcResult && "functionId" in funcResult) {
9818
+ results.push({
9819
+ index: results.length,
9820
+ // Use sequential index for function results
9821
+ functionName: funcCall.function.name,
9822
+ functionId: funcCall.id,
9823
+ args: funcCall.function.params || "",
9824
+ result: String(funcResult.result),
9825
+ isError: "isError" in funcResult ? Boolean(funcResult.isError) : false
9826
+ });
9827
+ }
9828
+ }
9829
+ }
9830
+ }
9831
+ return results;
9832
+ }
9833
+ async function selectFromSamples(buffer, options, mem, sessionId) {
9834
+ if (!options?.resultPicker || buffer.length <= 1) {
9835
+ return 0;
9836
+ }
9837
+ const resultPicker = options.resultPicker;
9838
+ const hasFunctionCalls = mem ? checkForFunctionCalls(mem, sessionId) : false;
9839
+ if (hasFunctionCalls && mem) {
9840
+ const functionResults = extractFunctionResults(mem, sessionId);
9841
+ const selectedIndex = await resultPicker({
9842
+ type: "function",
9843
+ results: functionResults
9844
+ });
9845
+ if (selectedIndex < 0 || selectedIndex >= functionResults.length) {
9846
+ throw new Error(
9847
+ `Result picker returned invalid index: ${selectedIndex}. Must be between 0 and ${functionResults.length - 1}`
9848
+ );
9849
+ }
9850
+ return selectedIndex;
9851
+ } else {
9852
+ const fieldResults = buffer.map((b, index) => ({
9853
+ index,
9854
+ sample: b.delta
9855
+ }));
9856
+ const selectedIndex = await resultPicker({
9857
+ type: "fields",
9858
+ results: fieldResults
9859
+ });
9860
+ if (selectedIndex < 0 || selectedIndex >= buffer.length) {
9861
+ throw new Error(
9862
+ `Result picker returned invalid index: ${selectedIndex}. Must be between 0 and ${buffer.length - 1}`
9863
+ );
9864
+ }
9865
+ return selectedIndex;
9866
+ }
9867
+ }
9868
+ async function selectFromSamplesInMemory(mem, sessionId, options) {
9869
+ const lastMemory = mem?.getLast(sessionId);
9870
+ if (!lastMemory || lastMemory.role !== "assistant") {
9871
+ return 0;
9872
+ }
9873
+ if (lastMemory.chat.length <= 1) {
9874
+ return 0;
9875
+ }
9876
+ const buffer = lastMemory.chat.map((chat) => ({
9877
+ version: 0,
9878
+ index: chat.index,
9879
+ delta: chat.value
9880
+ }));
9881
+ const selectedIndex = await selectFromSamples(buffer, options, mem, sessionId);
9882
+ return selectedIndex;
9883
+ }
9884
+
9790
9885
  // dsp/validate.ts
9791
9886
  function handleValidationError(mem, errorFields, ai, promptTemplate, sessionId) {
9792
9887
  mem.addRequest(
@@ -9902,7 +9997,10 @@ var AxGen = class extends AxProgramWithSignature {
9902
9997
  thinkingTokenBudget,
9903
9998
  showThoughts
9904
9999
  } = options ?? {};
9905
- const chatPrompt = mem?.history(0, sessionId) ?? [];
10000
+ const selectedIndex = await selectFromSamplesInMemory(mem, sessionId, {
10001
+ resultPicker: options?.resultPicker
10002
+ });
10003
+ const chatPrompt = mem?.history(selectedIndex, sessionId) ?? [];
9906
10004
  if (chatPrompt.length === 0) {
9907
10005
  throw new Error("No chat prompt found");
9908
10006
  }
@@ -9913,7 +10011,8 @@ var AxGen = class extends AxProgramWithSignature {
9913
10011
  }
9914
10012
  const modelConfig = {
9915
10013
  ...options?.modelConfig,
9916
- ...options?.sampleCount ? { n: options.sampleCount } : {}
10014
+ ...options?.sampleCount ? { n: options.sampleCount } : {},
10015
+ ...options?.sampleCount && options?.modelConfig?.temperature == 1 ? { temperature: 0.8 } : {}
9917
10016
  };
9918
10017
  const res = await ai.chat(
9919
10018
  {
@@ -10190,15 +10289,58 @@ var AxGen = class extends AxProgramWithSignature {
10190
10289
  currentVersion = delta.version;
10191
10290
  buffer = mergeDeltas(buffer, delta);
10192
10291
  }
10193
- const result = buffer[0]?.delta ?? {};
10292
+ const selectedIndex = await selectFromSamples(
10293
+ buffer,
10294
+ {
10295
+ resultPicker: options?.resultPicker
10296
+ },
10297
+ // Pass memory to enable function result selection
10298
+ options?.mem,
10299
+ options?.sessionId
10300
+ );
10301
+ const selectedResult = buffer[selectedIndex];
10302
+ const result = selectedResult?.delta ?? {};
10194
10303
  this.trace = { ...values, ...result };
10195
10304
  return result;
10196
10305
  }
10197
10306
  async *streamingForward(ai, values, options) {
10198
- yield* this._forward1(ai, values, {
10307
+ if (!options?.resultPicker) {
10308
+ yield* this._forward1(ai, values, {
10309
+ ...options,
10310
+ stream: true
10311
+ });
10312
+ return;
10313
+ }
10314
+ const generator = this._forward1(ai, values, {
10199
10315
  ...options,
10200
10316
  stream: true
10201
10317
  });
10318
+ let buffer = [];
10319
+ let currentVersion = 0;
10320
+ for await (const delta of generator) {
10321
+ if (delta.version !== currentVersion) {
10322
+ buffer = [];
10323
+ }
10324
+ currentVersion = delta.version;
10325
+ buffer = mergeDeltas(buffer, delta);
10326
+ }
10327
+ const selectedIndex = await selectFromSamples(
10328
+ buffer,
10329
+ {
10330
+ resultPicker: options?.resultPicker
10331
+ },
10332
+ // Pass memory to enable function result selection
10333
+ options?.mem,
10334
+ options?.sessionId
10335
+ );
10336
+ const selectedResult = buffer[selectedIndex];
10337
+ if (selectedResult) {
10338
+ yield {
10339
+ version: currentVersion,
10340
+ index: selectedIndex,
10341
+ delta: selectedResult.delta
10342
+ };
10343
+ }
10202
10344
  }
10203
10345
  setExamples(examples, options) {
10204
10346
  super.setExamples(examples, options);
@@ -11558,13 +11700,6 @@ var AxBaseOptimizer = class {
11558
11700
  if (this.logger) {
11559
11701
  return this.logger;
11560
11702
  }
11561
- try {
11562
- const aiLogger = this.studentAI.getLogger();
11563
- if (aiLogger) {
11564
- return aiLogger;
11565
- }
11566
- } catch {
11567
- }
11568
11703
  return axDefaultOptimizerLogger;
11569
11704
  }
11570
11705
  /**
@@ -13256,6 +13391,11 @@ var AxMiPRO = class extends AxBaseOptimizer {
13256
13391
  bayesianOptimization;
13257
13392
  acquisitionFunction;
13258
13393
  explorationWeight;
13394
+ // Self-consistency / multiple sampling
13395
+ sampleCount;
13396
+ // Surrogate model state for Bayesian optimization
13397
+ miproConfigHistory = [];
13398
+ surrogateModel = /* @__PURE__ */ new Map();
13259
13399
  constructor(args) {
13260
13400
  super(args);
13261
13401
  const options = args.options || {};
@@ -13277,6 +13417,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
13277
13417
  this.bayesianOptimization = options.bayesianOptimization ?? false;
13278
13418
  this.acquisitionFunction = options.acquisitionFunction ?? "expected_improvement";
13279
13419
  this.explorationWeight = options.explorationWeight ?? 0.1;
13420
+ this.sampleCount = options.sampleCount ?? 1;
13280
13421
  this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
13281
13422
  }
13282
13423
  /**
@@ -13321,43 +13462,186 @@ var AxMiPRO = class extends AxBaseOptimizer {
13321
13462
  ];
13322
13463
  }
13323
13464
  /**
13324
- * Generates instruction candidates using the teacher model if available
13465
+ * Generates program summary for context-aware instruction generation
13466
+ */
13467
+ async generateProgramSummary(program, ai) {
13468
+ let signature = "input -> output";
13469
+ if ("getSignature" in program && typeof program.getSignature === "function") {
13470
+ signature = program.getSignature();
13471
+ }
13472
+ const summaryPrompt = `
13473
+ Analyze this language model program and provide a concise summary of its purpose and structure.
13474
+
13475
+ Program Signature: ${signature}
13476
+
13477
+ Provide a 2-3 sentence summary focusing on:
13478
+ 1. The main task or purpose of this program
13479
+ 2. The input-output relationship
13480
+ 3. Any special constraints or requirements
13481
+
13482
+ Summary:`;
13483
+ try {
13484
+ const response = await ai.chat({
13485
+ chatPrompt: [{ role: "user", content: summaryPrompt }]
13486
+ });
13487
+ if ("results" in response) {
13488
+ return response.results[0]?.content?.trim() || "General language model program";
13489
+ }
13490
+ return "General language model program";
13491
+ } catch {
13492
+ return "General language model program";
13493
+ }
13494
+ }
13495
+ /**
13496
+ * Generates dataset summary for context-aware instruction generation
13497
+ */
13498
+ async generateDatasetSummary(examples, ai) {
13499
+ if (examples.length === 0) return "No examples available";
13500
+ const sampleSize = Math.min(this.viewDataBatchSize, examples.length);
13501
+ const sampledExamples = examples.slice(0, sampleSize);
13502
+ const exampleTexts = sampledExamples.map((ex, i) => `Example ${i + 1}: ${JSON.stringify(ex)}`).join("\n");
13503
+ const summaryPrompt = `
13504
+ Analyze this dataset and provide a concise summary of its characteristics.
13505
+
13506
+ Sample Examples:
13507
+ ${exampleTexts}
13508
+
13509
+ Provide a 2-3 sentence summary focusing on:
13510
+ 1. The type of data and domain
13511
+ 2. Common patterns or structures in the examples
13512
+ 3. Key challenges or requirements for processing this data
13513
+
13514
+ Dataset Summary:`;
13515
+ try {
13516
+ const response = await ai.chat({
13517
+ chatPrompt: [{ role: "user", content: summaryPrompt }]
13518
+ });
13519
+ if ("results" in response) {
13520
+ return response.results[0]?.content?.trim() || "General dataset";
13521
+ }
13522
+ return "General dataset";
13523
+ } catch {
13524
+ return "General dataset";
13525
+ }
13526
+ }
13527
+ /**
13528
+ * Enhanced instruction generation using AI with program and data awareness
13529
+ */
13530
+ async generateInstruction({
13531
+ tip,
13532
+ candidateIndex,
13533
+ ai,
13534
+ programSummary,
13535
+ datasetSummary,
13536
+ previousInstructions = []
13537
+ }) {
13538
+ let contextInfo = "";
13539
+ if (this.programAwareProposer && programSummary) {
13540
+ contextInfo += `
13541
+ Program Context: ${programSummary}`;
13542
+ }
13543
+ if (this.dataAwareProposer && datasetSummary) {
13544
+ contextInfo += `
13545
+ Dataset Context: ${datasetSummary}`;
13546
+ }
13547
+ if (this.fewshotAwareProposer && previousInstructions.length > 0) {
13548
+ contextInfo += `
13549
+ Previous Instructions (avoid repeating): ${previousInstructions.slice(-3).join("; ")}`;
13550
+ }
13551
+ const instructionPrompt = `
13552
+ Generate a high-quality instruction for a language model program.
13553
+
13554
+ ${contextInfo}
13555
+
13556
+ ${tip ? `Tip: ${tip}` : ""}
13557
+
13558
+ Requirements:
13559
+ 1. Be specific and actionable
13560
+ 2. Focus on accuracy and clarity
13561
+ 3. Consider the program's purpose and data characteristics
13562
+ 4. Make the instruction distinct from previous ones
13563
+ 5. Keep it concise but comprehensive
13564
+
13565
+ Generate a single, well-crafted instruction:
13566
+ Instruction:`;
13567
+ try {
13568
+ const response = await ai.chat({
13569
+ chatPrompt: [
13570
+ {
13571
+ role: "user",
13572
+ content: instructionPrompt
13573
+ }
13574
+ ]
13575
+ });
13576
+ if ("results" in response) {
13577
+ const instruction2 = response.results[0]?.content?.trim();
13578
+ if (instruction2 && instruction2.length > 10) {
13579
+ return instruction2;
13580
+ }
13581
+ }
13582
+ } catch (error) {
13583
+ if (this.isLoggingEnabled()) {
13584
+ this.getLogger()?.(`Failed to generate AI instruction: ${error}`, {
13585
+ tags: ["optimizer", "warning"]
13586
+ });
13587
+ }
13588
+ }
13589
+ const enhancedTemplates = [
13590
+ "Analyze the input systematically and provide a precise, well-reasoned response.",
13591
+ "Think through this step-by-step, considering all relevant factors before responding.",
13592
+ "Examine the input carefully and generate an accurate, detailed answer.",
13593
+ "Process the information methodically and deliver a clear, comprehensive response.",
13594
+ "Consider the context thoroughly and provide a thoughtful, accurate answer."
13595
+ ];
13596
+ let instruction = enhancedTemplates[candidateIndex % enhancedTemplates.length] || enhancedTemplates[0];
13597
+ if (tip) {
13598
+ instruction = `${instruction} ${tip}`;
13599
+ }
13600
+ return instruction;
13601
+ }
13602
+ /**
13603
+ * Generates instruction candidates using enhanced AI-powered generation
13325
13604
  * @param options Optional compile options that may override teacher AI
13326
13605
  * @returns Array of generated instruction candidates
13327
13606
  */
13328
- async proposeInstructionCandidates(options) {
13607
+ async proposeInstructionCandidates(program, options) {
13329
13608
  const instructions = [];
13330
13609
  const aiToUse = this.getTeacherOrStudentAI(options);
13610
+ let programSummary;
13611
+ let datasetSummary;
13612
+ if (this.programAwareProposer) {
13613
+ programSummary = await this.generateProgramSummary(program, aiToUse);
13614
+ if (this.isLoggingEnabled(options)) {
13615
+ this.getLogger(options)?.(`Program summary: ${programSummary}`, {
13616
+ tags: ["optimizer", "config"]
13617
+ });
13618
+ }
13619
+ }
13620
+ if (this.dataAwareProposer) {
13621
+ datasetSummary = await this.generateDatasetSummary(this.examples, aiToUse);
13622
+ if (this.isLoggingEnabled(options)) {
13623
+ this.getLogger(options)?.(`Dataset summary: ${datasetSummary}`, {
13624
+ tags: ["optimizer", "config"]
13625
+ });
13626
+ }
13627
+ }
13331
13628
  const tips = this.tipAwareProposer ? this.generateTips() : [];
13332
13629
  for (let i = 0; i < this.numCandidates; i++) {
13333
13630
  const tipIndex = tips.length > 0 ? i % tips.length : -1;
13334
- const tipToUse = tipIndex >= 0 ? tips[tipIndex] : "";
13631
+ const tipToUse = tipIndex >= 0 ? tips[tipIndex] : void 0;
13335
13632
  const instruction = await this.generateInstruction({
13336
13633
  tip: tipToUse,
13337
13634
  candidateIndex: i,
13338
- ai: aiToUse
13635
+ ai: aiToUse,
13636
+ programSummary,
13637
+ datasetSummary,
13638
+ previousInstructions: instructions
13639
+ // Pass previous instructions for diversity
13339
13640
  });
13340
13641
  instructions.push(instruction);
13341
13642
  }
13342
13643
  return instructions;
13343
13644
  }
13344
- async generateInstruction({
13345
- tip,
13346
- candidateIndex
13347
- }) {
13348
- const baseInstructions = [
13349
- "Analyze the input carefully and provide a detailed response.",
13350
- "Think step by step and provide a clear answer.",
13351
- "Consider all aspects of the input before responding.",
13352
- "Provide a concise but comprehensive response.",
13353
- "Focus on accuracy and clarity in your response."
13354
- ];
13355
- let instruction = baseInstructions[candidateIndex % baseInstructions.length] || baseInstructions[0];
13356
- if (tip) {
13357
- instruction = `${instruction} ${tip}`;
13358
- }
13359
- return instruction;
13360
- }
13361
13645
  /**
13362
13646
  * Bootstraps few-shot examples for the program
13363
13647
  */
@@ -13402,7 +13686,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
13402
13686
  /**
13403
13687
  * Runs optimization to find the best combination of few-shot examples and instructions
13404
13688
  */
13405
- async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, valset, metricFn, options) {
13689
+ async runOptimization(program, bootstrappedDemos, labeledExamples, instructions, validationExamples, metricFn, options) {
13406
13690
  let bestConfig = {
13407
13691
  instruction: instructions[0] || "",
13408
13692
  bootstrappedDemos: Math.min(1, bootstrappedDemos.length),
@@ -13438,25 +13722,37 @@ var AxMiPRO = class extends AxBaseOptimizer {
13438
13722
  );
13439
13723
  }
13440
13724
  for (let i = startRound; i < this.numTrials; i++) {
13441
- const config = {
13442
- instruction: instructions[i % instructions.length] || instructions[0] || "",
13443
- bootstrappedDemos: Math.min(
13444
- Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
13445
- this.maxBootstrappedDemos
13446
- ),
13447
- labeledExamples: Math.min(
13448
- Math.floor(Math.random() * (labeledExamples.length + 1)),
13449
- this.maxLabeledDemos
13450
- )
13451
- };
13725
+ let config;
13726
+ if (this.bayesianOptimization && this.miproConfigHistory.length > 2) {
13727
+ config = await this.selectConfigurationViaBayesianOptimization(
13728
+ instructions,
13729
+ bootstrappedDemos,
13730
+ labeledExamples
13731
+ );
13732
+ } else {
13733
+ config = {
13734
+ instruction: instructions[i % instructions.length] || instructions[0] || "",
13735
+ bootstrappedDemos: Math.min(
13736
+ Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
13737
+ this.maxBootstrappedDemos
13738
+ ),
13739
+ labeledExamples: Math.min(
13740
+ Math.floor(Math.random() * (labeledExamples.length + 1)),
13741
+ this.maxLabeledDemos
13742
+ )
13743
+ };
13744
+ }
13452
13745
  const score = await this.evaluateConfig(
13453
13746
  program,
13454
13747
  config,
13455
13748
  bootstrappedDemos,
13456
13749
  labeledExamples,
13457
- valset,
13458
- metricFn
13750
+ validationExamples,
13751
+ metricFn,
13752
+ i + 1
13753
+ // Pass current trial number for adaptive evaluation
13459
13754
  );
13755
+ this.updateSurrogateModel(config, score);
13460
13756
  scoreHistory.push(score);
13461
13757
  const improvement = score - bestScore;
13462
13758
  if (improvement > this.minImprovementThreshold) {
@@ -13538,7 +13834,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
13538
13834
  this.stats.convergenceInfo.converged = stagnationRounds < this.earlyStoppingTrials;
13539
13835
  return { bestConfig, bestScore };
13540
13836
  }
13541
- async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, valset, metricFn) {
13837
+ async evaluateConfig(program, config, bootstrappedDemos, labeledExamples, validationExamples, metricFn, currentTrial = 0) {
13542
13838
  const testProgram = { ...program };
13543
13839
  this.applyConfigToProgram(
13544
13840
  testProgram,
@@ -13548,12 +13844,31 @@ var AxMiPRO = class extends AxBaseOptimizer {
13548
13844
  );
13549
13845
  let totalScore = 0;
13550
13846
  let count = 0;
13551
- const evalSet = valset.slice(0, Math.min(5, valset.length));
13847
+ let evalSize;
13848
+ if (this.minibatch) {
13849
+ const baseSize = Math.min(this.minibatchSize, validationExamples.length);
13850
+ const isFullEvalTrial = currentTrial % this.minibatchFullEvalSteps === 0;
13851
+ if (isFullEvalTrial || currentTrial > this.numTrials * 0.8) {
13852
+ evalSize = Math.min(validationExamples.length, baseSize * 2);
13853
+ } else {
13854
+ evalSize = Math.max(3, Math.min(baseSize, validationExamples.length));
13855
+ }
13856
+ } else {
13857
+ evalSize = validationExamples.length;
13858
+ }
13859
+ const evalIndices = this.shuffleArray([
13860
+ ...Array(validationExamples.length).keys()
13861
+ ]).slice(0, evalSize);
13862
+ const evalSet = evalIndices.map((i) => validationExamples[i]);
13552
13863
  for (const example of evalSet) {
13553
13864
  try {
13554
13865
  const prediction = await testProgram.forward(
13555
13866
  this.studentAI,
13556
- example
13867
+ example,
13868
+ this.sampleCount > 1 ? {
13869
+ sampleCount: this.sampleCount,
13870
+ resultPicker: axMajorityVotePicker()
13871
+ } : void 0
13557
13872
  );
13558
13873
  const score = await metricFn({ prediction, example });
13559
13874
  totalScore += score;
@@ -13565,6 +13880,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
13565
13880
  }
13566
13881
  return count > 0 ? totalScore / count : 0;
13567
13882
  }
13883
+ /**
13884
+ * Fisher-Yates shuffle for stochastic evaluation
13885
+ */
13886
+ shuffleArray(array) {
13887
+ const shuffled = [...array];
13888
+ for (let i = shuffled.length - 1; i > 0; i--) {
13889
+ const j = Math.floor(Math.random() * (i + 1));
13890
+ [shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
13891
+ }
13892
+ return shuffled;
13893
+ }
13568
13894
  applyConfigToProgram(program, config, bootstrappedDemos, labeledExamples) {
13569
13895
  if (program.setInstruction) {
13570
13896
  program.setInstruction(config.instruction);
@@ -13586,14 +13912,14 @@ var AxMiPRO = class extends AxBaseOptimizer {
13586
13912
  if (miproOptions?.auto) {
13587
13913
  this.configureAuto(miproOptions.auto);
13588
13914
  }
13589
- const valset = this.getValidationSet(options) || (miproOptions?.valset ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
13915
+ const validationExamples = this.getValidationSet(options) || (miproOptions?.validationExamples ?? this.examples.slice(0, Math.floor(this.examples.length * 0.2)));
13590
13916
  if (this.isLoggingEnabled(options)) {
13591
13917
  this.getLogger(options)?.(
13592
13918
  `Starting MIPROv2 optimization with ${this.numTrials} trials`,
13593
13919
  { tags: ["optimizer", "start"] }
13594
13920
  );
13595
13921
  this.getLogger(options)?.(
13596
- `Using ${this.examples.length} examples for training and ${valset.length} for validation`,
13922
+ `Using ${this.examples.length} examples for training and ${validationExamples.length} for validation`,
13597
13923
  { tags: ["optimizer", "config"] }
13598
13924
  );
13599
13925
  if (this.teacherAI) {
@@ -13623,7 +13949,10 @@ var AxMiPRO = class extends AxBaseOptimizer {
13623
13949
  );
13624
13950
  }
13625
13951
  }
13626
- const instructions = await this.proposeInstructionCandidates(options);
13952
+ const instructions = await this.proposeInstructionCandidates(
13953
+ program,
13954
+ options
13955
+ );
13627
13956
  if (this.isLoggingEnabled(options)) {
13628
13957
  this.getLogger(options)?.(
13629
13958
  `Generated ${instructions.length} instruction candidates`,
@@ -13641,7 +13970,7 @@ var AxMiPRO = class extends AxBaseOptimizer {
13641
13970
  bootstrappedDemos,
13642
13971
  labeledExamples,
13643
13972
  instructions,
13644
- valset,
13973
+ validationExamples,
13645
13974
  metricFn,
13646
13975
  options
13647
13976
  );
@@ -13700,7 +14029,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
13700
14029
  bootstrappedDemos: bestConfig.bootstrappedDemos,
13701
14030
  labeledExamples: bestConfig.labeledExamples,
13702
14031
  numCandidates: this.numCandidates,
13703
- numTrials: this.numTrials
14032
+ numTrials: this.numTrials,
14033
+ sampleCount: this.sampleCount
13704
14034
  }
13705
14035
  };
13706
14036
  }
@@ -13745,7 +14075,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
13745
14075
  minImprovementThreshold: this.minImprovementThreshold,
13746
14076
  bayesianOptimization: this.bayesianOptimization,
13747
14077
  acquisitionFunction: this.acquisitionFunction,
13748
- explorationWeight: this.explorationWeight
14078
+ explorationWeight: this.explorationWeight,
14079
+ sampleCount: this.sampleCount
13749
14080
  };
13750
14081
  }
13751
14082
  /**
@@ -13780,12 +14111,17 @@ var AxMiPRO = class extends AxBaseOptimizer {
13780
14111
  if (config.minImprovementThreshold !== void 0) {
13781
14112
  this.minImprovementThreshold = config.minImprovementThreshold;
13782
14113
  }
14114
+ if (config.sampleCount !== void 0) {
14115
+ this.sampleCount = config.sampleCount;
14116
+ }
13783
14117
  }
13784
14118
  /**
13785
14119
  * Reset optimizer state for reuse with different programs
13786
14120
  */
13787
14121
  reset() {
13788
14122
  super.reset();
14123
+ this.miproConfigHistory = [];
14124
+ this.surrogateModel.clear();
13789
14125
  this.stats.convergenceInfo.convergenceThreshold = this.minImprovementThreshold;
13790
14126
  }
13791
14127
  /**
@@ -13803,8 +14139,8 @@ var AxMiPRO = class extends AxBaseOptimizer {
13803
14139
  "Reduce maxBootstrappedDemos or maxLabeledDemos, or provide more examples"
13804
14140
  );
13805
14141
  }
13806
- const valSetSize = this.getValidationSet().length;
13807
- if (valSetSize < 5) {
14142
+ const validationSetSize = this.getValidationSet().length;
14143
+ if (validationSetSize < 5) {
13808
14144
  result.issues.push(
13809
14145
  "Validation set too small for reliable MiPRO optimization"
13810
14146
  );
@@ -13818,6 +14154,141 @@ var AxMiPRO = class extends AxBaseOptimizer {
13818
14154
  suggestions: result.suggestions
13819
14155
  };
13820
14156
  }
14157
+ /**
14158
+ * Encodes a configuration into a string key for surrogate model lookup
14159
+ */
14160
+ encodeConfiguration(config) {
14161
+ return `${config.instruction.length}_${config.bootstrappedDemos}_${config.labeledExamples}`;
14162
+ }
14163
+ /**
14164
+ * Updates the surrogate model with a new configuration-score pair
14165
+ */
14166
+ updateSurrogateModel(config, score) {
14167
+ this.miproConfigHistory.push({ config: { ...config }, score });
14168
+ const key = this.encodeConfiguration(config);
14169
+ const similarConfigs = this.miproConfigHistory.filter(
14170
+ (entry) => this.encodeConfiguration(entry.config) === key
14171
+ );
14172
+ if (similarConfigs.length > 0) {
14173
+ const scores = similarConfigs.map((entry) => entry.score);
14174
+ const mean = scores.reduce((sum, s2) => sum + s2, 0) / scores.length;
14175
+ const variance = scores.length > 1 ? scores.reduce((sum, s2) => sum + Math.pow(s2 - mean, 2), 0) / (scores.length - 1) : 0.1;
14176
+ this.surrogateModel.set(key, { mean, variance });
14177
+ }
14178
+ }
14179
+ /**
14180
+ * Predicts performance using the surrogate model
14181
+ */
14182
+ predictPerformance(config) {
14183
+ const key = this.encodeConfiguration(config);
14184
+ if (this.surrogateModel.has(key)) {
14185
+ return this.surrogateModel.get(key);
14186
+ }
14187
+ if (this.miproConfigHistory.length > 0) {
14188
+ const similarities = this.miproConfigHistory.map((entry) => {
14189
+ const diff = Math.abs(entry.config.bootstrappedDemos - config.bootstrappedDemos) + Math.abs(entry.config.labeledExamples - config.labeledExamples);
14190
+ return { score: entry.score, similarity: 1 / (1 + diff) };
14191
+ });
14192
+ const totalWeight = similarities.reduce((sum, s2) => sum + s2.similarity, 0);
14193
+ const weightedMean = similarities.reduce((sum, s2) => sum + s2.score * s2.similarity, 0) / totalWeight;
14194
+ return { mean: weightedMean, variance: 0.2 };
14195
+ }
14196
+ return { mean: 0.5, variance: 0.3 };
14197
+ }
14198
+ /**
14199
+ * Calculates acquisition function value for Bayesian optimization
14200
+ */
14201
+ calculateAcquisitionValue(config) {
14202
+ const prediction = this.predictPerformance(config);
14203
+ const { mean, variance } = prediction;
14204
+ const std = Math.sqrt(variance);
14205
+ const bestScore = this.miproConfigHistory.length > 0 ? Math.max(...this.miproConfigHistory.map((entry) => entry.score)) : 0;
14206
+ switch (this.acquisitionFunction) {
14207
+ case "expected_improvement": {
14208
+ const improvement = mean - bestScore;
14209
+ if (std === 0) return Math.max(0, improvement);
14210
+ const z = improvement / std;
14211
+ const phi = 0.5 * (1 + this.erf(z / Math.sqrt(2)));
14212
+ const pdfValue = Math.exp(-0.5 * z * z) / Math.sqrt(2 * Math.PI);
14213
+ return improvement * phi + std * pdfValue;
14214
+ }
14215
+ case "upper_confidence_bound": {
14216
+ return mean + this.explorationWeight * std;
14217
+ }
14218
+ case "probability_improvement": {
14219
+ const improvement = mean - bestScore;
14220
+ if (std === 0) return improvement > 0 ? 1 : 0;
14221
+ const z = improvement / std;
14222
+ return 0.5 * (1 + this.erf(z / Math.sqrt(2)));
14223
+ }
14224
+ default:
14225
+ return mean;
14226
+ }
14227
+ }
14228
+ /**
14229
+ * Error function approximation for acquisition function calculations
14230
+ */
14231
+ erf(x) {
14232
+ const a1 = 0.254829592;
14233
+ const a2 = -0.284496736;
14234
+ const a3 = 1.421413741;
14235
+ const a4 = -1.453152027;
14236
+ const a5 = 1.061405429;
14237
+ const p = 0.3275911;
14238
+ const sign = x >= 0 ? 1 : -1;
14239
+ x = Math.abs(x);
14240
+ const t = 1 / (1 + p * x);
14241
+ const y = 1 - ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t * Math.exp(-x * x);
14242
+ return sign * y;
14243
+ }
14244
+ /**
14245
+ * Selects the next configuration to evaluate using Bayesian optimization
14246
+ */
14247
+ async selectConfigurationViaBayesianOptimization(instructions, bootstrappedDemos, labeledExamples) {
14248
+ const candidates = [];
14249
+ const numCandidates = Math.min(20, instructions.length * 3);
14250
+ for (let i = 0; i < numCandidates; i++) {
14251
+ const config = {
14252
+ instruction: instructions[i % instructions.length] || instructions[0] || "",
14253
+ bootstrappedDemos: Math.min(
14254
+ Math.floor(Math.random() * (bootstrappedDemos.length + 1)),
14255
+ this.maxBootstrappedDemos
14256
+ ),
14257
+ labeledExamples: Math.min(
14258
+ Math.floor(Math.random() * (labeledExamples.length + 1)),
14259
+ this.maxLabeledDemos
14260
+ )
14261
+ };
14262
+ const acquisitionValue = this.calculateAcquisitionValue(config);
14263
+ candidates.push({ config, acquisitionValue });
14264
+ }
14265
+ candidates.sort((a, b) => b.acquisitionValue - a.acquisitionValue);
14266
+ return candidates[0].config;
14267
+ }
14268
+ };
14269
/**
 * Creates a result picker that chooses a sample by majority vote.
 *
 * For `data.type === "fields"`, each `sample` is serialized with object keys
 * in sorted order. Samples that differ only in property order therefore count
 * as the same answer. The picker resolves to the `index` of the first-seen
 * sample of the most frequent answer, and ties go to the answer seen first.
 * For any other `data.type` it resolves to the first result's index.
 * It resolves to 0 when there is nothing to pick from.
 *
 * @returns {(data: {type: string, results: Array<{index: number, sample?: unknown}>}) => Promise<number>}
 */
var axMajorityVotePicker = () => {
  // JSON.stringify with object keys sorted at every level. The replacer runs
  // after any `toJSON` conversion, so values like Dates serialize as usual.
  const canonicalKey = (value) =>
    JSON.stringify(value, (_key, v) => {
      if (v === null || typeof v !== "object" || Array.isArray(v)) {
        return v;
      }
      return Object.fromEntries(
        Object.keys(v).sort().map((k) => [k, v[k]])
      );
    });

  return async (data) => {
    if (data.type !== "fields") {
      return data.results[0]?.index ?? 0;
    }
    // Map keeps insertion order, so ties resolve to the first-seen answer
    // (a plain object would move integer-like keys to the front).
    const tallies = new Map();
    for (const { index, sample } of data.results) {
      const key = canonicalKey(sample);
      const tally = tallies.get(key);
      if (tally) {
        tally.count += 1;
      } else {
        tallies.set(key, { count: 1, index });
      }
    }
    let best;
    for (const tally of tallies.values()) {
      if (best === undefined || tally.count > best.count) {
        best = tally;
      }
    }
    return best?.index ?? 0;
  };
};
13822
14293
 
13823
14294
  // ai/mock/api.ts