@ax-llm/ax 12.0.12 → 12.0.13

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
package/index.cjs CHANGED
@@ -6257,15 +6257,24 @@ var MemoryImpl = class {
6257
6257
  debugRequest(items, this.options?.debugHideSystemPrompt);
6258
6258
  }
6259
6259
  }
6260
+ addFunctionResults(results) {
6261
+ const chat = results.map(({ index, ...value }) => ({
6262
+ index,
6263
+ value: structuredClone(value)
6264
+ }));
6265
+ const lastItem = this.getLast();
6266
+ if (lastItem?.role === "function") {
6267
+ lastItem.chat.push(...chat);
6268
+ } else {
6269
+ this.data.push({ role: "function", chat });
6270
+ }
6271
+ }
6260
6272
  addResponse(results) {
6261
- const chat = results.map((result) => ({
6262
- index: result.index,
6263
- value: structuredClone(result)
6273
+ const chat = results.map(({ index, ...value }) => ({
6274
+ index,
6275
+ value: structuredClone(value)
6264
6276
  }));
6265
- this.data.push({
6266
- role: "assistant",
6267
- chat
6268
- });
6277
+ this.data.push({ role: "assistant", chat });
6269
6278
  if (this.options?.debug) {
6270
6279
  for (const result of results) {
6271
6280
  debugResponse(result);
@@ -6353,9 +6362,20 @@ var MemoryImpl = class {
6353
6362
  history(index) {
6354
6363
  const result = [];
6355
6364
  for (const { role, chat } of this.data) {
6356
- const value = chat.find((v) => v.index === index)?.value;
6357
- if (value) {
6358
- result.push({ role, ...value });
6365
+ let values;
6366
+ if (role === "function") {
6367
+ values = chat.filter((v) => v.index === index).map((v) => v.value);
6368
+ } else {
6369
+ values = chat.find((v) => v.index === index)?.value;
6370
+ }
6371
+ if (Array.isArray(values)) {
6372
+ result.push(
6373
+ ...values.map(
6374
+ (v) => ({ ...v, role })
6375
+ )
6376
+ );
6377
+ } else if (values) {
6378
+ result.push({ ...values, role });
6359
6379
  }
6360
6380
  }
6361
6381
  return result;
@@ -6393,20 +6413,8 @@ var AxMemory = class {
6393
6413
  axValidateChatResponseResult(results);
6394
6414
  this.getMemory(sessionId).addResponse(results);
6395
6415
  }
6396
- addFunctionResult({
6397
- functionId,
6398
- isError,
6399
- index,
6400
- result
6401
- }, sessionId) {
6402
- const functionMessage = {
6403
- role: "function",
6404
- functionId,
6405
- isError,
6406
- result
6407
- };
6408
- axValidateChatRequestMessage(functionMessage);
6409
- this.getMemory(sessionId).addRequest([functionMessage], index);
6416
+ addFunctionResults(results, sessionId) {
6417
+ this.getMemory(sessionId).addFunctionResults(results);
6410
6418
  }
6411
6419
  updateResult(result, sessionId) {
6412
6420
  this.getMemory(sessionId).updateResult(result);
@@ -7643,51 +7651,48 @@ var processFunctions = async ({
7643
7651
  }
7644
7652
  return {
7645
7653
  result: functionResult ?? "",
7654
+ role: "function",
7646
7655
  functionId: func.id,
7647
7656
  index
7648
7657
  };
7649
7658
  }).catch((e) => {
7650
- if (e instanceof FunctionError) {
7651
- const result = e.getFixingInstructions();
7652
- if (span) {
7653
- const errorEventData = {
7654
- name: func.name,
7655
- message: e.toString()
7656
- };
7657
- if (!excludeContentFromTrace) {
7658
- errorEventData.args = func.args;
7659
- errorEventData.fixing_instructions = result;
7660
- }
7661
- span.addEvent("function.error", errorEventData);
7659
+ if (!(e instanceof FunctionError)) {
7660
+ throw e;
7661
+ }
7662
+ const result = e.getFixingInstructions();
7663
+ if (span) {
7664
+ const errorEventData = {
7665
+ name: func.name,
7666
+ message: e.toString()
7667
+ };
7668
+ if (!excludeContentFromTrace) {
7669
+ errorEventData.args = func.args;
7670
+ errorEventData.fixing_instructions = result;
7662
7671
  }
7663
- mem.addFunctionResult(
7664
- {
7665
- functionId: func.id,
7666
- isError: true,
7667
- index,
7668
- result
7669
- },
7670
- sessionId
7671
- );
7672
- mem.addTag("error", sessionId);
7673
- if (ai.getOptions().debug) {
7674
- const logger = ai.getLogger();
7675
- logger(`\u274C Function Error Correction:
7672
+ span.addEvent("function.error", errorEventData);
7673
+ }
7674
+ if (ai.getOptions().debug) {
7675
+ const logger = ai.getLogger();
7676
+ logger(`\u274C Function Error Correction:
7676
7677
  ${result}`, {
7677
- tags: ["error"]
7678
- });
7679
- }
7680
- } else {
7681
- throw e;
7678
+ tags: ["error"]
7679
+ });
7682
7680
  }
7681
+ return {
7682
+ functionId: func.id,
7683
+ isError: true,
7684
+ index,
7685
+ result,
7686
+ role: "function"
7687
+ };
7683
7688
  });
7684
7689
  return promise;
7685
7690
  });
7686
7691
  const results = await Promise.all(promises);
7687
- for (const result of results) {
7688
- if (result) {
7689
- mem.addFunctionResult(result, sessionId);
7690
- }
7692
+ const functionResults = results.filter((result) => result !== void 0);
7693
+ mem.addFunctionResults(functionResults, sessionId);
7694
+ if (functionResults.some((result) => result.isError)) {
7695
+ mem.addTag("error", sessionId);
7691
7696
  }
7692
7697
  return functionsExecuted;
7693
7698
  };
@@ -9961,6 +9966,96 @@ var AxProgram = class {
9961
9966
  }
9962
9967
  };
9963
9968
 
9969
+ // dsp/samples.ts
9970
+ function checkForFunctionCalls(mem, sessionId) {
9971
+ const history = mem.history(0, sessionId);
9972
+ const hasFunctionResults = history.some((msg) => msg.role === "function");
9973
+ const hasFunctionCalls = history.some(
9974
+ (msg) => msg.role === "assistant" && "functionCalls" in msg && Array.isArray(msg.functionCalls) && msg.functionCalls.length > 0
9975
+ );
9976
+ return hasFunctionCalls && hasFunctionResults;
9977
+ }
9978
+ function extractFunctionResults(mem, sessionId) {
9979
+ const history = mem.history(0, sessionId);
9980
+ const results = [];
9981
+ const assistantMessages = history.filter(
9982
+ (msg) => msg.role === "assistant" && "functionCalls" in msg && Array.isArray(msg.functionCalls) && msg.functionCalls.length > 0
9983
+ );
9984
+ const functionMessages = history.filter((msg) => msg.role === "function");
9985
+ for (const assistantMsg of assistantMessages) {
9986
+ if ("functionCalls" in assistantMsg && assistantMsg.functionCalls) {
9987
+ for (const funcCall of assistantMsg.functionCalls) {
9988
+ const funcResult = functionMessages.find(
9989
+ (msg) => "functionId" in msg && msg.functionId === funcCall.id
9990
+ );
9991
+ if (funcResult && "result" in funcResult && "functionId" in funcResult) {
9992
+ results.push({
9993
+ index: results.length,
9994
+ // Use sequential index for function results
9995
+ functionName: funcCall.function.name,
9996
+ functionId: funcCall.id,
9997
+ args: funcCall.function.params || "",
9998
+ result: String(funcResult.result),
9999
+ isError: "isError" in funcResult ? Boolean(funcResult.isError) : false
10000
+ });
10001
+ }
10002
+ }
10003
+ }
10004
+ }
10005
+ return results;
10006
+ }
10007
+ async function selectFromSamples(buffer, options, mem, sessionId) {
10008
+ if (!options?.resultPicker || buffer.length <= 1) {
10009
+ return 0;
10010
+ }
10011
+ const resultPicker = options.resultPicker;
10012
+ const hasFunctionCalls = mem ? checkForFunctionCalls(mem, sessionId) : false;
10013
+ if (hasFunctionCalls && mem) {
10014
+ const functionResults = extractFunctionResults(mem, sessionId);
10015
+ const selectedIndex = await resultPicker({
10016
+ type: "function",
10017
+ results: functionResults
10018
+ });
10019
+ if (selectedIndex < 0 || selectedIndex >= functionResults.length) {
10020
+ throw new Error(
10021
+ `Result picker returned invalid index: ${selectedIndex}. Must be between 0 and ${functionResults.length - 1}`
10022
+ );
10023
+ }
10024
+ return selectedIndex;
10025
+ } else {
10026
+ const fieldResults = buffer.map((b, index) => ({
10027
+ index,
10028
+ sample: b.delta
10029
+ }));
10030
+ const selectedIndex = await resultPicker({
10031
+ type: "fields",
10032
+ results: fieldResults
10033
+ });
10034
+ if (selectedIndex < 0 || selectedIndex >= buffer.length) {
10035
+ throw new Error(
10036
+ `Result picker returned invalid index: ${selectedIndex}. Must be between 0 and ${buffer.length - 1}`
10037
+ );
10038
+ }
10039
+ return selectedIndex;
10040
+ }
10041
+ }
10042
+ async function selectFromSamplesInMemory(mem, sessionId, options) {
10043
+ const lastMemory = mem?.getLast(sessionId);
10044
+ if (!lastMemory || lastMemory.role !== "assistant") {
10045
+ return 0;
10046
+ }
10047
+ if (lastMemory.chat.length <= 1) {
10048
+ return 0;
10049
+ }
10050
+ const buffer = lastMemory.chat.map((chat) => ({
10051
+ version: 0,
10052
+ index: chat.index,
10053
+ delta: chat.value
10054
+ }));
10055
+ const selectedIndex = await selectFromSamples(buffer, options, mem, sessionId);
10056
+ return selectedIndex;
10057
+ }
10058
+
9964
10059
  // dsp/validate.ts
9965
10060
  function handleValidationError(mem, errorFields, ai, promptTemplate, sessionId) {
9966
10061
  mem.addRequest(
@@ -10076,7 +10171,10 @@ var AxGen = class extends AxProgramWithSignature {
10076
10171
  thinkingTokenBudget,
10077
10172
  showThoughts
10078
10173
  } = options ?? {};
10079
- const chatPrompt = mem?.history(0, sessionId) ?? [];
10174
+ const selectedIndex = await selectFromSamplesInMemory(mem, sessionId, {
10175
+ resultPicker: options?.resultPicker
10176
+ });
10177
+ const chatPrompt = mem?.history(selectedIndex, sessionId) ?? [];
10080
10178
  if (chatPrompt.length === 0) {
10081
10179
  throw new Error("No chat prompt found");
10082
10180
  }
@@ -10087,7 +10185,8 @@ var AxGen = class extends AxProgramWithSignature {
10087
10185
  }
10088
10186
  const modelConfig = {
10089
10187
  ...options?.modelConfig,
10090
- ...options?.sampleCount ? { n: options.sampleCount } : {}
10188
+ ...options?.sampleCount ? { n: options.sampleCount } : {},
10189
+ ...options?.sampleCount && options?.modelConfig?.temperature == 1 ? { temperature: 0.8 } : {}
10091
10190
  };
10092
10191
  const res = await ai.chat(
10093
10192
  {
@@ -10364,15 +10463,58 @@ var AxGen = class extends AxProgramWithSignature {
10364
10463
  currentVersion = delta.version;
10365
10464
  buffer = mergeDeltas(buffer, delta);
10366
10465
  }
10367
- const result = buffer[0]?.delta ?? {};
10466
+ const selectedIndex = await selectFromSamples(
10467
+ buffer,
10468
+ {
10469
+ resultPicker: options?.resultPicker
10470
+ },
10471
+ // Pass memory to enable function result selection
10472
+ options?.mem,
10473
+ options?.sessionId
10474
+ );
10475
+ const selectedResult = buffer[selectedIndex];
10476
+ const result = selectedResult?.delta ?? {};
10368
10477
  this.trace = { ...values, ...result };
10369
10478
  return result;
10370
10479
  }
10371
10480
  async *streamingForward(ai, values, options) {
10372
- yield* this._forward1(ai, values, {
10481
+ if (!options?.resultPicker) {
10482
+ yield* this._forward1(ai, values, {
10483
+ ...options,
10484
+ stream: true
10485
+ });
10486
+ return;
10487
+ }
10488
+ const generator = this._forward1(ai, values, {
10373
10489
  ...options,
10374
10490
  stream: true
10375
10491
  });
10492
+ let buffer = [];
10493
+ let currentVersion = 0;
10494
+ for await (const delta of generator) {
10495
+ if (delta.version !== currentVersion) {
10496
+ buffer = [];
10497
+ }
10498
+ currentVersion = delta.version;
10499
+ buffer = mergeDeltas(buffer, delta);
10500
+ }
10501
+ const selectedIndex = await selectFromSamples(
10502
+ buffer,
10503
+ {
10504
+ resultPicker: options?.resultPicker
10505
+ },
10506
+ // Pass memory to enable function result selection
10507
+ options?.mem,
10508
+ options?.sessionId
10509
+ );
10510
+ const selectedResult = buffer[selectedIndex];
10511
+ if (selectedResult) {
10512
+ yield {
10513
+ version: currentVersion,
10514
+ index: selectedIndex,
10515
+ delta: selectedResult.delta
10516
+ };
10517
+ }
10376
10518
  }
10377
10519
  setExamples(examples, options) {
10378
10520
  super.setExamples(examples, options);