langchain 0.0.71 → 0.0.73

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. package/dist/agents/index.cjs +2 -1
  2. package/dist/agents/index.d.ts +1 -1
  3. package/dist/agents/index.js +1 -1
  4. package/dist/chains/base.cjs +1 -1
  5. package/dist/chains/base.js +1 -1
  6. package/dist/chains/conversation.cjs +3 -3
  7. package/dist/chains/conversation.d.ts +1 -0
  8. package/dist/chains/conversation.js +2 -2
  9. package/dist/chains/index.cjs +10 -1
  10. package/dist/chains/index.d.ts +4 -0
  11. package/dist/chains/index.js +4 -0
  12. package/dist/chains/retrieval_qa.cjs +3 -1
  13. package/dist/chains/retrieval_qa.d.ts +2 -1
  14. package/dist/chains/retrieval_qa.js +4 -2
  15. package/dist/chains/router/llm_router.cjs +31 -0
  16. package/dist/chains/router/llm_router.d.ts +24 -0
  17. package/dist/chains/router/llm_router.js +27 -0
  18. package/dist/chains/router/multi_prompt.cjs +76 -0
  19. package/dist/chains/router/multi_prompt.d.ts +8 -0
  20. package/dist/chains/router/multi_prompt.js +72 -0
  21. package/dist/chains/router/multi_prompt_prompt.cjs +42 -0
  22. package/dist/chains/router/multi_prompt_prompt.d.ts +2 -0
  23. package/dist/chains/router/multi_prompt_prompt.js +38 -0
  24. package/dist/chains/router/multi_retrieval_prompt.cjs +42 -0
  25. package/dist/chains/router/multi_retrieval_prompt.d.ts +2 -0
  26. package/dist/chains/router/multi_retrieval_prompt.js +38 -0
  27. package/dist/chains/router/multi_retrieval_qa.cjs +89 -0
  28. package/dist/chains/router/multi_retrieval_qa.d.ts +15 -0
  29. package/dist/chains/router/multi_retrieval_qa.js +85 -0
  30. package/dist/chains/router/multi_route.cjs +86 -0
  31. package/dist/chains/router/multi_route.d.ts +38 -0
  32. package/dist/chains/router/multi_route.js +81 -0
  33. package/dist/chains/router/utils.cjs +34 -0
  34. package/dist/chains/router/utils.d.ts +3 -0
  35. package/dist/chains/router/utils.js +30 -0
  36. package/dist/chat_models/openai.cjs +33 -19
  37. package/dist/chat_models/openai.d.ts +1 -1
  38. package/dist/chat_models/openai.js +33 -19
  39. package/dist/embeddings/openai.d.ts +1 -1
  40. package/dist/llms/openai-chat.cjs +31 -19
  41. package/dist/llms/openai-chat.d.ts +1 -1
  42. package/dist/llms/openai-chat.js +31 -19
  43. package/dist/llms/openai.cjs +29 -9
  44. package/dist/llms/openai.d.ts +1 -1
  45. package/dist/llms/openai.js +29 -9
  46. package/dist/output_parsers/index.cjs +6 -1
  47. package/dist/output_parsers/index.d.ts +3 -1
  48. package/dist/output_parsers/index.js +3 -1
  49. package/dist/output_parsers/list.cjs +46 -1
  50. package/dist/output_parsers/list.d.ts +14 -0
  51. package/dist/output_parsers/list.js +44 -0
  52. package/dist/output_parsers/router.cjs +32 -0
  53. package/dist/output_parsers/router.d.ts +11 -0
  54. package/dist/output_parsers/router.js +28 -0
  55. package/dist/output_parsers/structured.cjs +43 -3
  56. package/dist/output_parsers/structured.d.ts +11 -1
  57. package/dist/output_parsers/structured.js +41 -2
  58. package/dist/schema/index.cjs +10 -1
  59. package/dist/schema/index.d.ts +5 -0
  60. package/dist/schema/index.js +8 -0
  61. package/dist/schema/output_parser.d.ts +7 -1
  62. package/dist/stores/message/dynamodb.cjs +126 -0
  63. package/dist/stores/message/dynamodb.d.ts +23 -0
  64. package/dist/stores/message/dynamodb.js +122 -0
  65. package/dist/stores/message/in_memory.cjs +3 -6
  66. package/dist/stores/message/in_memory.d.ts +3 -4
  67. package/dist/stores/message/in_memory.js +4 -7
  68. package/dist/stores/message/utils.cjs +31 -0
  69. package/dist/stores/message/utils.d.ts +8 -0
  70. package/dist/stores/message/utils.js +26 -0
  71. package/dist/types/openai-types.cjs +2 -0
  72. package/dist/types/openai-types.d.ts +101 -0
  73. package/dist/types/openai-types.js +1 -0
  74. package/package.json +14 -1
  75. package/stores/message/dynamodb.cjs +1 -0
  76. package/stores/message/dynamodb.d.ts +1 -0
  77. package/stores/message/dynamodb.js +1 -0
@@ -0,0 +1,89 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.MultiRetrievalQAChain = void 0;
4
+ const zod_1 = require("zod");
5
+ const multi_route_js_1 = require("./multi_route.cjs");
6
+ const template_js_1 = require("../../prompts/template.cjs");
7
+ const prompt_js_1 = require("../../prompts/prompt.cjs");
8
+ const llm_router_js_1 = require("./llm_router.cjs");
9
+ const conversation_js_1 = require("../../chains/conversation.cjs");
10
+ const multi_retrieval_prompt_js_1 = require("./multi_retrieval_prompt.cjs");
11
+ const utils_js_1 = require("./utils.cjs");
12
+ const retrieval_qa_js_1 = require("../../chains/retrieval_qa.cjs");
13
+ const router_js_1 = require("../../output_parsers/router.cjs");
14
/**
 * Router chain that uses an LLM to pick one of several named retrievers and
 * answers the (possibly rewritten) query with a RetrievalQAChain over it.
 */
class MultiRetrievalQAChain extends multi_route_js_1.MultiRouteChain {
    /** The answer is exposed under the `result` key. */
    get outputKeys() {
        return ["result"];
    }
    /**
     * Builds a MultiRetrievalQAChain from parallel arrays describing each retriever.
     *
     * @param {BaseLanguageModel} llm - Model used for routing and for every QA chain.
     * @param {string[]} retrieverNames - Route names the router may answer with.
     * @param {string[]} retrieverDescriptions - Descriptions shown to the router, same order as the names.
     * @param {BaseRetriever[]} retrievers - Retriever for each name, same order.
     * @param {PromptTemplate[]} [retrieverPrompts] - Optional per-retriever QA prompts; falsy entries
     *   leave RetrievalQAChain.fromLLM to use its own prompt.
     * @param {MultiRetrievalDefaults} [defaults] - Source of the fallback chain: `defaultChain` as-is,
     *   else a RetrievalQAChain over `defaultRetriever` (optionally with `defaultPrompt`), else a
     *   ConversationChain that reads `query`.
     * @param {object} [options] - Extra MultiRouteChain fields (e.g. `silentErrors`). The router,
     *   destination and default chains built here always take precedence over same-named options.
     * @returns {MultiRetrievalQAChain}
     * @throws {Error} If `defaultPrompt` is given without `defaultRetriever`, or if the
     *   name/description/retriever/prompt arrays differ in length (via zipEntries).
     */
    static fromRetrievers(llm, retrieverNames, retrieverDescriptions, retrievers, retrieverPrompts, defaults, options) {
        const { defaultRetriever, defaultPrompt, defaultChain } = defaults ?? {};
        if (defaultPrompt && !defaultRetriever) {
            throw new Error("`defaultRetriever` must be specified if `defaultPrompt` is provided. Received only `defaultPrompt`.");
        }
        // Forward a prompt only when one was supplied, so RetrievalQAChain.fromLLM
        // keeps its own default otherwise (same rule for destinations and fallback).
        const qaOptions = (prompt) => (prompt ? { prompt } : undefined);
        const destinations = (0, utils_js_1.zipEntries)(retrieverNames, retrieverDescriptions).map(([name, desc]) => `${name}: ${desc}`);
        const structuredOutputParserSchema = zod_1.z.object({
            destination: zod_1.z
                .string()
                .optional()
                .describe('name of the question answering system to use or "DEFAULT"'),
            next_inputs: zod_1.z
                .object({
                query: zod_1.z
                    .string()
                    .describe("a potentially modified version of the original input"),
            })
                .describe("input to be fed to the next model"),
        });
        const outputParser = new router_js_1.RouterOutputParser(structuredOutputParserSchema);
        const routerTemplate = (0, template_js_1.interpolateFString)((0, multi_retrieval_prompt_js_1.STRUCTURED_MULTI_RETRIEVAL_ROUTER_TEMPLATE)(outputParser.getFormatInstructions({ interpolationDepth: 4 })), {
            destinations: destinations.join("\n"),
        });
        const routerPrompt = new prompt_js_1.PromptTemplate({
            template: routerTemplate,
            inputVariables: ["input"],
            outputParser,
        });
        const routerChain = llm_router_js_1.LLMRouterChain.fromLLM(llm, routerPrompt);
        const prompts = retrieverPrompts ?? retrievers.map(() => null);
        // Object.fromEntries defines own data properties, so a route named e.g.
        // "__proto__" becomes an ordinary key instead of replacing the prototype.
        const destinationChains = Object.fromEntries((0, utils_js_1.zipEntries)(retrieverNames, retrievers, prompts).map(([name, retriever, prompt]) => [
            name,
            retrieval_qa_js_1.RetrievalQAChain.fromLLM(llm, retriever, qaOptions(prompt)),
        ]));
        let fallbackChain;
        if (defaultChain) {
            fallbackChain = defaultChain;
        }
        else if (defaultRetriever) {
            fallbackChain = retrieval_qa_js_1.RetrievalQAChain.fromLLM(llm, defaultRetriever, qaOptions(defaultPrompt));
        }
        else {
            // The router hands the next chain a `query` key (see `next_inputs` above),
            // so rename the template's `{input}` placeholder to `{query}`.
            const prompt = new prompt_js_1.PromptTemplate({
                template: conversation_js_1.DEFAULT_TEMPLATE.replace(/\{input\}/g, "{query}"),
                inputVariables: ["history", "query"],
            });
            fallbackChain = new conversation_js_1.ConversationChain({
                llm,
                prompt,
                outputKey: "result",
            });
        }
        // Spread caller options first so the chains built above always win.
        return new MultiRetrievalQAChain({
            ...options,
            routerChain,
            destinationChains,
            defaultChain: fallbackChain,
        });
    }
    _chainType() {
        return "multi_retrieval_qa_chain";
    }
}
89
+ exports.MultiRetrievalQAChain = MultiRetrievalQAChain;
@@ -0,0 +1,15 @@
1
+ import { BaseLanguageModel } from "../../base_language/index.js";
2
+ import { MultiRouteChain, MultiRouteChainInput } from "./multi_route.js";
3
+ import { BaseChain } from "../../chains/base.js";
4
+ import { PromptTemplate } from "../../prompts/prompt.js";
5
+ import { BaseRetriever } from "../../schema/index.js";
6
/**
 * Fallbacks for `MultiRetrievalQAChain.fromRetrievers`, used to build the chain
 * that handles requests the router does not send to a named retriever.
 */
export type MultiRetrievalDefaults = {
    /** Retriever for a fallback RetrievalQAChain (ignored when `defaultChain` is set). */
    defaultRetriever?: BaseRetriever;
    /** Prompt for that fallback RetrievalQAChain; only valid together with `defaultRetriever`. */
    defaultPrompt?: PromptTemplate;
    /** Ready-made fallback chain; takes precedence over `defaultRetriever`. */
    defaultChain?: BaseChain;
};
/**
 * MultiRouteChain that routes a query to one of several retrievers and answers
 * it with a RetrievalQAChain; the answer is exposed under `result`.
 */
export declare class MultiRetrievalQAChain extends MultiRouteChain {
    get outputKeys(): string[];
    /**
     * Builds the chain from parallel arrays of retriever names, descriptions,
     * retrievers and (optionally) prompts; all must have the same length.
     *
     * `options` carries extra MultiRouteChain settings such as `silentErrors`.
     * Every field is optional because this factory builds the router,
     * destination and default chains itself.
     *
     * @throws Error if `defaults.defaultPrompt` is given without `defaults.defaultRetriever`.
     */
    static fromRetrievers(llm: BaseLanguageModel, retrieverNames: string[], retrieverDescriptions: string[], retrievers: BaseRetriever[], retrieverPrompts?: PromptTemplate[], defaults?: MultiRetrievalDefaults, options?: Partial<Omit<MultiRouteChainInput, "defaultChain">>): MultiRetrievalQAChain;
    _chainType(): string;
}
@@ -0,0 +1,85 @@
1
+ import { z } from "zod";
2
+ import { MultiRouteChain } from "./multi_route.js";
3
+ import { interpolateFString } from "../../prompts/template.js";
4
+ import { PromptTemplate } from "../../prompts/prompt.js";
5
+ import { LLMRouterChain } from "./llm_router.js";
6
+ import { ConversationChain, DEFAULT_TEMPLATE, } from "../../chains/conversation.js";
7
+ import { STRUCTURED_MULTI_RETRIEVAL_ROUTER_TEMPLATE } from "./multi_retrieval_prompt.js";
8
+ import { zipEntries } from "./utils.js";
9
+ import { RetrievalQAChain } from "../../chains/retrieval_qa.js";
10
+ import { RouterOutputParser } from "../../output_parsers/router.js";
11
/**
 * Router chain that uses an LLM to pick one of several named retrievers and
 * answers the (possibly rewritten) query with a RetrievalQAChain over it.
 */
export class MultiRetrievalQAChain extends MultiRouteChain {
    /** The answer is exposed under the `result` key. */
    get outputKeys() {
        return ["result"];
    }
    /**
     * Builds a MultiRetrievalQAChain from parallel arrays describing each retriever.
     *
     * @param {BaseLanguageModel} llm - Model used for routing and for every QA chain.
     * @param {string[]} retrieverNames - Route names the router may answer with.
     * @param {string[]} retrieverDescriptions - Descriptions shown to the router, same order as the names.
     * @param {BaseRetriever[]} retrievers - Retriever for each name, same order.
     * @param {PromptTemplate[]} [retrieverPrompts] - Optional per-retriever QA prompts; falsy entries
     *   leave RetrievalQAChain.fromLLM to use its own prompt.
     * @param {MultiRetrievalDefaults} [defaults] - Source of the fallback chain: `defaultChain` as-is,
     *   else a RetrievalQAChain over `defaultRetriever` (optionally with `defaultPrompt`), else a
     *   ConversationChain that reads `query`.
     * @param {object} [options] - Extra MultiRouteChain fields (e.g. `silentErrors`). The router,
     *   destination and default chains built here always take precedence over same-named options.
     * @returns {MultiRetrievalQAChain}
     * @throws {Error} If `defaultPrompt` is given without `defaultRetriever`, or if the
     *   name/description/retriever/prompt arrays differ in length (via zipEntries).
     */
    static fromRetrievers(llm, retrieverNames, retrieverDescriptions, retrievers, retrieverPrompts, defaults, options) {
        const { defaultRetriever, defaultPrompt, defaultChain } = defaults ?? {};
        if (defaultPrompt && !defaultRetriever) {
            throw new Error("`defaultRetriever` must be specified if `defaultPrompt` is provided. Received only `defaultPrompt`.");
        }
        // Forward a prompt only when one was supplied, so RetrievalQAChain.fromLLM
        // keeps its own default otherwise (same rule for destinations and fallback).
        const qaOptions = (prompt) => (prompt ? { prompt } : undefined);
        const destinations = zipEntries(retrieverNames, retrieverDescriptions).map(([name, desc]) => `${name}: ${desc}`);
        const structuredOutputParserSchema = z.object({
            destination: z
                .string()
                .optional()
                .describe('name of the question answering system to use or "DEFAULT"'),
            next_inputs: z
                .object({
                query: z
                    .string()
                    .describe("a potentially modified version of the original input"),
            })
                .describe("input to be fed to the next model"),
        });
        const outputParser = new RouterOutputParser(structuredOutputParserSchema);
        const routerTemplate = interpolateFString(STRUCTURED_MULTI_RETRIEVAL_ROUTER_TEMPLATE(outputParser.getFormatInstructions({ interpolationDepth: 4 })), {
            destinations: destinations.join("\n"),
        });
        const routerPrompt = new PromptTemplate({
            template: routerTemplate,
            inputVariables: ["input"],
            outputParser,
        });
        const routerChain = LLMRouterChain.fromLLM(llm, routerPrompt);
        const prompts = retrieverPrompts ?? retrievers.map(() => null);
        // Object.fromEntries defines own data properties, so a route named e.g.
        // "__proto__" becomes an ordinary key instead of replacing the prototype.
        const destinationChains = Object.fromEntries(zipEntries(retrieverNames, retrievers, prompts).map(([name, retriever, prompt]) => [
            name,
            RetrievalQAChain.fromLLM(llm, retriever, qaOptions(prompt)),
        ]));
        let fallbackChain;
        if (defaultChain) {
            fallbackChain = defaultChain;
        }
        else if (defaultRetriever) {
            fallbackChain = RetrievalQAChain.fromLLM(llm, defaultRetriever, qaOptions(defaultPrompt));
        }
        else {
            // The router hands the next chain a `query` key (see `next_inputs` above),
            // so rename the template's `{input}` placeholder to `{query}`.
            const prompt = new PromptTemplate({
                template: DEFAULT_TEMPLATE.replace(/\{input\}/g, "{query}"),
                inputVariables: ["history", "query"],
            });
            fallbackChain = new ConversationChain({
                llm,
                prompt,
                outputKey: "result",
            });
        }
        // Spread caller options first so the chains built above always win.
        return new MultiRetrievalQAChain({
            ...options,
            routerChain,
            destinationChains,
            defaultChain: fallbackChain,
        });
    }
    _chainType() {
        return "multi_retrieval_qa_chain";
    }
}
@@ -0,0 +1,86 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.MultiRouteChain = exports.RouterChain = void 0;
4
+ const base_js_1 = require("../../chains/base.cjs");
5
/**
 * Base class for chains that decide where a request should go next.
 * Running the chain yields `destination` and `next_inputs`; `route()`
 * repackages those as `{ destination, nextInputs }`.
 */
class RouterChain extends base_js_1.BaseChain {
    get outputKeys() {
        return ["destination", "next_inputs"];
    }
    /**
     * Runs this chain and returns its routing decision.
     * @param {ChainValues} inputs - Values passed straight to `call`.
     * @param {Callbacks} [callbacks] - Forwarded unchanged to `call`.
     * @returns {Promise<Route>} `{ destination, nextInputs }` read from the chain output.
     */
    async route(inputs, callbacks) {
        const { destination, next_inputs: nextInputs } = await this.call(inputs, callbacks);
        return { destination, nextInputs };
    }
}
17
+ exports.RouterChain = RouterChain;
18
/**
 * Runs a routed chain and re-throws any failure with the name of the chain
 * that produced it, keeping the original error as `cause`.
 *
 * @param {BaseChain} chain - Chain selected by the router (or the default chain).
 * @param {string} label - Name used in the error message ("default" or the destination key).
 * @param {ChainValues} inputs - Inputs produced by the router for the next chain.
 * @param {CallbackManagerForChainRun | undefined} runManager - Parent run manager, if any.
 * @returns {Promise<ChainValues>} The routed chain's outputs.
 */
async function callRoutedChain(chain, label, inputs, runManager) {
    try {
        return await chain.call(inputs, runManager?.getChild());
    }
    catch (err) {
        throw new Error(`Error in ${label} chain: ${err}`, { cause: err });
    }
}
/**
 * Chain that asks `routerChain` for a destination, then forwards the router's
 * `nextInputs` to the matching entry of `destinationChains`.
 *
 * - No destination (undefined / empty): `defaultChain` handles the inputs.
 * - Unknown destination: `defaultChain` when `silentErrors` is true, otherwise an Error.
 */
class MultiRouteChain extends base_js_1.BaseChain {
    constructor(fields) {
        super(fields);
        Object.defineProperty(this, "routerChain", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "destinationChains", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "defaultChain", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "silentErrors", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: false
        });
        this.routerChain = fields.routerChain;
        this.destinationChains = fields.destinationChains;
        this.defaultChain = fields.defaultChain;
        this.silentErrors = fields.silentErrors ?? this.silentErrors;
    }
    /** Inputs are whatever the router chain expects. */
    get inputKeys() {
        return this.routerChain.inputKeys;
    }
    /** The generic router declares no fixed outputs; subclasses override this. */
    get outputKeys() {
        return [];
    }
    async _call(values, runManager) {
        const { destination, nextInputs } = await this.routerChain.route(values, runManager?.getChild());
        await runManager?.handleText(`${destination}: ${JSON.stringify(nextInputs)}`);
        if (!destination) {
            return callRoutedChain(this.defaultChain, "default", nextInputs, runManager);
        }
        // Own-property check: the destination may come from LLM output, and a bare
        // `in` test would also accept inherited keys such as "toString".
        if (Object.prototype.hasOwnProperty.call(this.destinationChains, destination)) {
            return callRoutedChain(this.destinationChains[destination], destination, nextInputs, runManager);
        }
        if (this.silentErrors) {
            return callRoutedChain(this.defaultChain, "default", nextInputs, runManager);
        }
        throw new Error(`Destination ${destination} not found in destination chains with keys ${Object.keys(this.destinationChains)}`);
    }
    _chainType() {
        return "multi_route_chain";
    }
}
86
+ exports.MultiRouteChain = MultiRouteChain;
@@ -0,0 +1,38 @@
1
+ import { CallbackManagerForChainRun, Callbacks } from "../../callbacks/manager.js";
2
+ import { BaseChain, ChainInputs } from "../../chains/base.js";
3
+ import { ChainValues } from "../../schema/index.js";
4
/** Recursive map of router inputs: each value is a string, number, nested `Inputs`, or an array of one of those. */
+ type Inputs = {
5
+ [key: string]: Inputs | Inputs[] | string | string[] | number | number[];
6
+ };
7
/** Routing decision returned by `RouterChain.route()`. */
+ export interface Route {
8
/** Key into `MultiRouteChain.destinationChains`; when missing or empty, `MultiRouteChain` uses its `defaultChain`. */
+ destination?: string;
9
/**
 * Inputs forwarded to the selected chain (`route()` copies them from the router's `next_inputs` output).
 * NOTE(review): each value is typed as a nested `Inputs` map, yet the MultiRetrievalQAChain router
 * schema emits `next_inputs: { query: string }` — this value type may be too narrow; confirm.
 */
+ nextInputs: {
10
+ [key: string]: Inputs;
11
+ };
12
+ }
13
/** Constructor fields for `MultiRouteChain`. */
+ export interface MultiRouteChainInput extends ChainInputs {
14
/** Chain that picks the destination; its `inputKeys` become the `MultiRouteChain`'s `inputKeys`. */
+ routerChain: RouterChain;
15
/** Candidate chains keyed by destination name. */
+ destinationChains: {
16
+ [name: string]: BaseChain;
17
+ };
18
/** Used when the router returns no destination, and for unknown destinations when `silentErrors` is true. */
+ defaultChain: BaseChain;
19
/** When true, an unknown destination falls back to `defaultChain` instead of throwing. Defaults to false. */
+ silentErrors?: boolean;
20
+ }
21
/** Chain whose outputs are `destination` and `next_inputs`; `route()` runs it and returns them as a `Route`. */
+ export declare abstract class RouterChain extends BaseChain {
22
+ get outputKeys(): string[];
23
+ route(inputs: ChainValues, callbacks?: Callbacks): Promise<Route>;
24
+ }
25
/** Sends each call to the chain chosen by `routerChain`, falling back to `defaultChain` when no destination is returned. */
+ export declare class MultiRouteChain extends BaseChain {
26
+ routerChain: RouterChain;
27
+ destinationChains: {
28
+ [name: string]: BaseChain;
29
+ };
30
+ defaultChain: BaseChain;
31
+ silentErrors: boolean;
32
+ constructor(fields: MultiRouteChainInput);
33
+ get inputKeys(): string[];
34
/** Empty here; subclasses such as `MultiRetrievalQAChain` override it. */
+ get outputKeys(): string[];
35
/**
 * Routes `values` and runs the selected chain with the router's `nextInputs`.
 * Throws if the destination is unknown and `silentErrors` is false; failures in the
 * selected chain are re-thrown as an Error naming that chain.
 */
+ _call(values: ChainValues, runManager?: CallbackManagerForChainRun): Promise<ChainValues>;
36
+ _chainType(): string;
37
+ }
38
+ export {};
@@ -0,0 +1,81 @@
1
+ import { BaseChain } from "../../chains/base.js";
2
/**
 * Base class for chains that decide where a request should go next.
 * Running the chain yields `destination` and `next_inputs`; `route()`
 * repackages those as `{ destination, nextInputs }`.
 */
export class RouterChain extends BaseChain {
    get outputKeys() {
        return ["destination", "next_inputs"];
    }
    /**
     * Runs this chain and returns its routing decision.
     * @param {ChainValues} inputs - Values passed straight to `call`.
     * @param {Callbacks} [callbacks] - Forwarded unchanged to `call`.
     * @returns {Promise<Route>} `{ destination, nextInputs }` read from the chain output.
     */
    async route(inputs, callbacks) {
        const { destination, next_inputs: nextInputs } = await this.call(inputs, callbacks);
        return { destination, nextInputs };
    }
}
14
/**
 * Runs a routed chain and re-throws any failure with the name of the chain
 * that produced it, keeping the original error as `cause`.
 *
 * @param {BaseChain} chain - Chain selected by the router (or the default chain).
 * @param {string} label - Name used in the error message ("default" or the destination key).
 * @param {ChainValues} inputs - Inputs produced by the router for the next chain.
 * @param {CallbackManagerForChainRun | undefined} runManager - Parent run manager, if any.
 * @returns {Promise<ChainValues>} The routed chain's outputs.
 */
async function callRoutedChain(chain, label, inputs, runManager) {
    try {
        return await chain.call(inputs, runManager?.getChild());
    }
    catch (err) {
        throw new Error(`Error in ${label} chain: ${err}`, { cause: err });
    }
}
/**
 * Chain that asks `routerChain` for a destination, then forwards the router's
 * `nextInputs` to the matching entry of `destinationChains`.
 *
 * - No destination (undefined / empty): `defaultChain` handles the inputs.
 * - Unknown destination: `defaultChain` when `silentErrors` is true, otherwise an Error.
 */
export class MultiRouteChain extends BaseChain {
    constructor(fields) {
        super(fields);
        Object.defineProperty(this, "routerChain", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "destinationChains", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "defaultChain", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "silentErrors", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: false
        });
        this.routerChain = fields.routerChain;
        this.destinationChains = fields.destinationChains;
        this.defaultChain = fields.defaultChain;
        this.silentErrors = fields.silentErrors ?? this.silentErrors;
    }
    /** Inputs are whatever the router chain expects. */
    get inputKeys() {
        return this.routerChain.inputKeys;
    }
    /** The generic router declares no fixed outputs; subclasses override this. */
    get outputKeys() {
        return [];
    }
    async _call(values, runManager) {
        const { destination, nextInputs } = await this.routerChain.route(values, runManager?.getChild());
        await runManager?.handleText(`${destination}: ${JSON.stringify(nextInputs)}`);
        if (!destination) {
            return callRoutedChain(this.defaultChain, "default", nextInputs, runManager);
        }
        // Own-property check: the destination may come from LLM output, and a bare
        // `in` test would also accept inherited keys such as "toString".
        if (Object.prototype.hasOwnProperty.call(this.destinationChains, destination)) {
            return callRoutedChain(this.destinationChains[destination], destination, nextInputs, runManager);
        }
        if (this.silentErrors) {
            return callRoutedChain(this.defaultChain, "default", nextInputs, runManager);
        }
        throw new Error(`Destination ${destination} not found in destination chains with keys ${Object.keys(this.destinationChains)}`);
    }
    _chainType() {
        return "multi_route_chain";
    }
}
@@ -0,0 +1,34 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.zipEntries = void 0;
4
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
5
/**
 * Combines several equal-length arrays position by position: entry `i` of the
 * result is `[arrays[0][i], arrays[1][i], ...]`.
 *
 * @param {...Array} arrays - Arrays to zip; every one must have the same length.
 * @returns {Array<Array>} One tuple per index (`[]` when called with no arrays).
 * @throws {Error} If any array's length differs from the first array's length.
 */
function zipEntries(...arrays) {
    if (arrays.length === 0) {
        return [];
    }
    const expectedLength = arrays[0].length;
    if (arrays.some((candidate) => candidate.length !== expectedLength)) {
        throw new Error("All input arrays must have the same length.");
    }
    return Array.from({ length: expectedLength }, (_unused, position) => arrays.map((source) => source[position]));
}
34
+ exports.zipEntries = zipEntries;
@@ -0,0 +1,3 @@
1
/**
 * Zips equal-length arrays position by position: element `i` of the result is
 * `[arrays[0][i], arrays[1][i], ...]`, typed as the tuple `T`.
 * Returns `[]` when called with no arrays; throws an Error if the arrays'
 * lengths differ.
 */
+ export declare function zipEntries<T extends any[]>(...arrays: {
2
+ [P in keyof T]: T[P][];
3
+ }): T[];
@@ -0,0 +1,30 @@
1
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
2
/**
 * Combines several equal-length arrays position by position: entry `i` of the
 * result is `[arrays[0][i], arrays[1][i], ...]`.
 *
 * @param {...Array} arrays - Arrays to zip; every one must have the same length.
 * @returns {Array<Array>} One tuple per index (`[]` when called with no arrays).
 * @throws {Error} If any array's length differs from the first array's length.
 */
export function zipEntries(...arrays) {
    if (arrays.length === 0) {
        return [];
    }
    const expectedLength = arrays[0].length;
    if (arrays.some((candidate) => candidate.length !== expectedLength)) {
        throw new Error("All input arrays must have the same length.");
    }
    return Array.from({ length: expectedLength }, (_unused, position) => arrays.map((source) => source[position]));
}
@@ -281,6 +281,7 @@ class ChatOpenAI extends base_js_1.BaseChatModel {
281
281
  ? await new Promise((resolve, reject) => {
282
282
  let response;
283
283
  let rejected = false;
284
+ let resolved = false;
284
285
  this.completionWithRetry({
285
286
  ...params,
286
287
  messages: messagesMapped,
@@ -290,6 +291,10 @@ class ChatOpenAI extends base_js_1.BaseChatModel {
290
291
  responseType: "stream",
291
292
  onmessage: (event) => {
292
293
  if (event.data?.trim?.() === "[DONE]") {
294
+ if (resolved) {
295
+ return;
296
+ }
297
+ resolved = true;
293
298
  resolve(response);
294
299
  }
295
300
  else {
@@ -305,26 +310,35 @@ class ChatOpenAI extends base_js_1.BaseChatModel {
305
310
  };
306
311
  }
307
312
  // on all messages, update choice
308
- const part = message.choices[0];
309
- if (part != null) {
310
- let choice = response.choices.find((c) => c.index === part.index);
311
- if (!choice) {
312
- choice = {
313
- index: part.index,
314
- finish_reason: part.finish_reason ?? undefined,
315
- };
316
- response.choices.push(choice);
317
- }
318
- if (!choice.message) {
319
- choice.message = {
320
- role: part.delta
321
- ?.role,
322
- content: part.delta?.content ?? "",
323
- };
313
+ for (const part of message.choices) {
314
+ if (part != null) {
315
+ let choice = response.choices.find((c) => c.index === part.index);
316
+ if (!choice) {
317
+ choice = {
318
+ index: part.index,
319
+ finish_reason: part.finish_reason ?? undefined,
320
+ };
321
+ response.choices[part.index] = choice;
322
+ }
323
+ if (!choice.message) {
324
+ choice.message = {
325
+ role: part.delta
326
+ ?.role,
327
+ content: part.delta?.content ?? "",
328
+ };
329
+ }
330
+ choice.message.content += part.delta?.content ?? "";
331
+ // TODO this should pass part.index to the callback
332
+ // when that's supported there
333
+ // eslint-disable-next-line no-void
334
+ void runManager?.handleLLMNewToken(part.delta?.content ?? "");
324
335
  }
325
- choice.message.content += part.delta?.content ?? "";
326
- // eslint-disable-next-line no-void
327
- void runManager?.handleLLMNewToken(part.delta?.content ?? "");
336
+ }
337
+ // when all messages are finished, resolve
338
+ if (!resolved &&
339
+ message.choices.every((c) => c.finish_reason != null)) {
340
+ resolved = true;
341
+ resolve(response);
328
342
  }
329
343
  }
330
344
  },
@@ -1,5 +1,5 @@
1
1
  import { CreateChatCompletionRequest, ConfigurationParameters, CreateChatCompletionResponse } from "openai";
2
- import { AzureOpenAIInput, OpenAICallOptions, OpenAIChatInput } from "../types/open-ai-types.js";
2
+ import { AzureOpenAIInput, OpenAICallOptions, OpenAIChatInput } from "../types/openai-types.js";
3
3
  import type { StreamingAxiosConfiguration } from "../util/axios-types.js";
4
4
  import { BaseChatModel, BaseChatModelParams } from "./base.js";
5
5
  import { BaseChatMessage, ChatResult } from "../schema/index.js";
@@ -275,6 +275,7 @@ export class ChatOpenAI extends BaseChatModel {
275
275
  ? await new Promise((resolve, reject) => {
276
276
  let response;
277
277
  let rejected = false;
278
+ let resolved = false;
278
279
  this.completionWithRetry({
279
280
  ...params,
280
281
  messages: messagesMapped,
@@ -284,6 +285,10 @@ export class ChatOpenAI extends BaseChatModel {
284
285
  responseType: "stream",
285
286
  onmessage: (event) => {
286
287
  if (event.data?.trim?.() === "[DONE]") {
288
+ if (resolved) {
289
+ return;
290
+ }
291
+ resolved = true;
287
292
  resolve(response);
288
293
  }
289
294
  else {
@@ -299,26 +304,35 @@ export class ChatOpenAI extends BaseChatModel {
299
304
  };
300
305
  }
301
306
  // on all messages, update choice
302
- const part = message.choices[0];
303
- if (part != null) {
304
- let choice = response.choices.find((c) => c.index === part.index);
305
- if (!choice) {
306
- choice = {
307
- index: part.index,
308
- finish_reason: part.finish_reason ?? undefined,
309
- };
310
- response.choices.push(choice);
311
- }
312
- if (!choice.message) {
313
- choice.message = {
314
- role: part.delta
315
- ?.role,
316
- content: part.delta?.content ?? "",
317
- };
307
+ for (const part of message.choices) {
308
+ if (part != null) {
309
+ let choice = response.choices.find((c) => c.index === part.index);
310
+ if (!choice) {
311
+ choice = {
312
+ index: part.index,
313
+ finish_reason: part.finish_reason ?? undefined,
314
+ };
315
+ response.choices[part.index] = choice;
316
+ }
317
+ if (!choice.message) {
318
+ choice.message = {
319
+ role: part.delta
320
+ ?.role,
321
+ content: part.delta?.content ?? "",
322
+ };
323
+ }
324
+ choice.message.content += part.delta?.content ?? "";
325
+ // TODO this should pass part.index to the callback
326
+ // when that's supported there
327
+ // eslint-disable-next-line no-void
328
+ void runManager?.handleLLMNewToken(part.delta?.content ?? "");
318
329
  }
319
- choice.message.content += part.delta?.content ?? "";
320
- // eslint-disable-next-line no-void
321
- void runManager?.handleLLMNewToken(part.delta?.content ?? "");
330
+ }
331
+ // when all messages are finished, resolve
332
+ if (!resolved &&
333
+ message.choices.every((c) => c.finish_reason != null)) {
334
+ resolved = true;
335
+ resolve(response);
322
336
  }
323
337
  }
324
338
  },
@@ -1,5 +1,5 @@
1
1
  import { ConfigurationParameters } from "openai";
2
- import { AzureOpenAIInput } from "../types/open-ai-types.js";
2
+ import { AzureOpenAIInput } from "../types/openai-types.js";
3
3
  import { Embeddings, EmbeddingsParams } from "./base.js";
4
4
  export interface OpenAIEmbeddingsParams extends EmbeddingsParams {
5
5
  /** Model name to use */