langchain 0.0.76 → 0.0.77

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117) hide show
  1. package/chains/query_constructor/ir.cjs +1 -0
  2. package/chains/query_constructor/ir.d.ts +1 -0
  3. package/chains/query_constructor/ir.js +1 -0
  4. package/chains/query_constructor.cjs +1 -0
  5. package/chains/query_constructor.d.ts +1 -0
  6. package/chains/query_constructor.js +1 -0
  7. package/dist/agents/chat_convo/index.cjs +27 -11
  8. package/dist/agents/chat_convo/index.d.ts +4 -1
  9. package/dist/agents/chat_convo/index.js +28 -12
  10. package/dist/agents/chat_convo/outputParser.cjs +79 -7
  11. package/dist/agents/chat_convo/outputParser.d.ts +25 -13
  12. package/dist/agents/chat_convo/outputParser.js +77 -6
  13. package/dist/agents/chat_convo/prompt.cjs +11 -8
  14. package/dist/agents/chat_convo/prompt.d.ts +2 -2
  15. package/dist/agents/chat_convo/prompt.js +11 -8
  16. package/dist/callbacks/handlers/tracer_langchain.cjs +12 -4
  17. package/dist/callbacks/handlers/tracer_langchain.d.ts +4 -1
  18. package/dist/callbacks/handlers/tracer_langchain.js +12 -4
  19. package/dist/callbacks/manager.cjs +6 -2
  20. package/dist/callbacks/manager.js +6 -2
  21. package/dist/chains/query_constructor/index.cjs +105 -0
  22. package/dist/chains/query_constructor/index.d.ts +37 -0
  23. package/dist/chains/query_constructor/index.js +95 -0
  24. package/dist/chains/query_constructor/ir.cjs +116 -0
  25. package/dist/chains/query_constructor/ir.d.ts +60 -0
  26. package/dist/chains/query_constructor/ir.js +107 -0
  27. package/dist/chains/query_constructor/parser.cjs +103 -0
  28. package/dist/chains/query_constructor/parser.d.ts +12 -0
  29. package/dist/chains/query_constructor/parser.js +99 -0
  30. package/dist/chains/query_constructor/prompt.cjs +127 -0
  31. package/dist/chains/query_constructor/prompt.d.ts +15 -0
  32. package/dist/chains/query_constructor/prompt.js +124 -0
  33. package/dist/chains/sql_db/sql_db_chain.cjs +13 -0
  34. package/dist/chains/sql_db/sql_db_chain.d.ts +2 -0
  35. package/dist/chains/sql_db/sql_db_chain.js +13 -0
  36. package/dist/client/langchainplus.cjs +21 -15
  37. package/dist/client/langchainplus.d.ts +4 -2
  38. package/dist/client/langchainplus.js +21 -15
  39. package/dist/llms/sagemaker_endpoint.cjs +123 -0
  40. package/dist/llms/sagemaker_endpoint.d.ts +82 -0
  41. package/dist/llms/sagemaker_endpoint.js +118 -0
  42. package/dist/memory/buffer_memory.cjs +1 -1
  43. package/dist/memory/buffer_memory.js +1 -1
  44. package/dist/memory/buffer_window_memory.cjs +1 -1
  45. package/dist/memory/buffer_window_memory.js +1 -1
  46. package/dist/output_parsers/expression.cjs +68 -0
  47. package/dist/output_parsers/expression.d.ts +25 -0
  48. package/dist/output_parsers/expression.js +49 -0
  49. package/dist/output_parsers/expression_type_handlers/array_literal_expression_handler.cjs +26 -0
  50. package/dist/output_parsers/expression_type_handlers/array_literal_expression_handler.d.ts +7 -0
  51. package/dist/output_parsers/expression_type_handlers/array_literal_expression_handler.js +22 -0
  52. package/dist/output_parsers/expression_type_handlers/base.cjs +67 -0
  53. package/dist/output_parsers/expression_type_handlers/base.d.ts +23 -0
  54. package/dist/output_parsers/expression_type_handlers/base.js +62 -0
  55. package/dist/output_parsers/expression_type_handlers/boolean_literal_handler.cjs +24 -0
  56. package/dist/output_parsers/expression_type_handlers/boolean_literal_handler.d.ts +7 -0
  57. package/dist/output_parsers/expression_type_handlers/boolean_literal_handler.js +20 -0
  58. package/dist/output_parsers/expression_type_handlers/call_expression_handler.cjs +52 -0
  59. package/dist/output_parsers/expression_type_handlers/call_expression_handler.d.ts +7 -0
  60. package/dist/output_parsers/expression_type_handlers/call_expression_handler.js +48 -0
  61. package/dist/output_parsers/expression_type_handlers/factory.cjs +56 -0
  62. package/dist/output_parsers/expression_type_handlers/factory.d.ts +9 -0
  63. package/dist/output_parsers/expression_type_handlers/factory.js +52 -0
  64. package/dist/output_parsers/expression_type_handlers/identifier_handler.cjs +22 -0
  65. package/dist/output_parsers/expression_type_handlers/identifier_handler.d.ts +7 -0
  66. package/dist/output_parsers/expression_type_handlers/identifier_handler.js +18 -0
  67. package/dist/output_parsers/expression_type_handlers/member_expression_handler.cjs +45 -0
  68. package/dist/output_parsers/expression_type_handlers/member_expression_handler.d.ts +7 -0
  69. package/dist/output_parsers/expression_type_handlers/member_expression_handler.js +41 -0
  70. package/dist/output_parsers/expression_type_handlers/numeric_literal_handler.cjs +24 -0
  71. package/dist/output_parsers/expression_type_handlers/numeric_literal_handler.d.ts +7 -0
  72. package/dist/output_parsers/expression_type_handlers/numeric_literal_handler.js +20 -0
  73. package/dist/output_parsers/expression_type_handlers/object_literal_expression_handler.cjs +29 -0
  74. package/dist/output_parsers/expression_type_handlers/object_literal_expression_handler.d.ts +7 -0
  75. package/dist/output_parsers/expression_type_handlers/object_literal_expression_handler.js +25 -0
  76. package/dist/output_parsers/expression_type_handlers/property_assignment_handler.cjs +36 -0
  77. package/dist/output_parsers/expression_type_handlers/property_assignment_handler.d.ts +7 -0
  78. package/dist/output_parsers/expression_type_handlers/property_assignment_handler.js +32 -0
  79. package/dist/output_parsers/expression_type_handlers/string_literal_handler.cjs +22 -0
  80. package/dist/output_parsers/expression_type_handlers/string_literal_handler.d.ts +7 -0
  81. package/dist/output_parsers/expression_type_handlers/string_literal_handler.js +18 -0
  82. package/dist/output_parsers/expression_type_handlers/types.cjs +2 -0
  83. package/dist/output_parsers/expression_type_handlers/types.d.ts +41 -0
  84. package/dist/output_parsers/expression_type_handlers/types.js +1 -0
  85. package/dist/output_parsers/index.cjs +2 -1
  86. package/dist/output_parsers/index.d.ts +1 -1
  87. package/dist/output_parsers/index.js +1 -1
  88. package/dist/output_parsers/structured.cjs +81 -23
  89. package/dist/output_parsers/structured.d.ts +18 -0
  90. package/dist/output_parsers/structured.js +79 -22
  91. package/dist/retrievers/self_query/index.cjs +79 -0
  92. package/dist/retrievers/self_query/index.d.ts +33 -0
  93. package/dist/retrievers/self_query/index.js +74 -0
  94. package/dist/retrievers/self_query/translator.cjs +72 -0
  95. package/dist/retrievers/self_query/translator.d.ts +14 -0
  96. package/dist/retrievers/self_query/translator.js +67 -0
  97. package/dist/schema/query_constructor.cjs +26 -0
  98. package/dist/schema/query_constructor.d.ts +6 -0
  99. package/dist/schema/query_constructor.js +22 -0
  100. package/dist/tools/json.cjs +3 -1
  101. package/dist/tools/json.js +3 -1
  102. package/dist/util/event-source-parse.cjs +31 -5
  103. package/dist/util/event-source-parse.d.ts +3 -3
  104. package/dist/util/event-source-parse.js +31 -5
  105. package/llms/sagemaker_endpoint.cjs +1 -0
  106. package/llms/sagemaker_endpoint.d.ts +1 -0
  107. package/llms/sagemaker_endpoint.js +1 -0
  108. package/output_parsers/expression.cjs +1 -0
  109. package/output_parsers/expression.d.ts +1 -0
  110. package/output_parsers/expression.js +1 -0
  111. package/package.json +61 -3
  112. package/retrievers/self_query.cjs +1 -0
  113. package/retrievers/self_query.d.ts +1 -0
  114. package/retrievers/self_query.js +1 -0
  115. package/schema/query_constructor.cjs +1 -0
  116. package/schema/query_constructor.d.ts +1 -0
  117. package/schema/query_constructor.js +1 -0
@@ -3,18 +3,20 @@ Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.LangChainPlusClient = exports.isChain = exports.isChatModel = exports.isLLM = void 0;
4
4
  const tracer_langchain_js_1 = require("../callbacks/handlers/tracer_langchain.cjs");
5
5
  const utils_js_1 = require("../stores/message/utils.cjs");
6
+ const async_caller_js_1 = require("../util/async_caller.cjs");
6
7
  // utility functions
7
8
  const isLocalhost = (url) => {
8
9
  const strippedUrl = url.replace("http://", "").replace("https://", "");
9
10
  const hostname = strippedUrl.split("/")[0].split(":")[0];
10
11
  return (hostname === "localhost" || hostname === "127.0.0.1" || hostname === "::1");
11
12
  };
12
- const getSeededTenantId = async (apiUrl, apiKey) => {
13
+ const getSeededTenantId = async (apiUrl, apiKey, callerOptions = undefined) => {
13
14
  // Get the tenant ID from the seeded tenant
15
+ const caller = new async_caller_js_1.AsyncCaller(callerOptions ?? {});
14
16
  const url = `${apiUrl}/tenants`;
15
17
  let response;
16
18
  try {
17
- response = await fetch(url, {
19
+ response = await caller.call(fetch, url, {
18
20
  method: "GET",
19
21
  headers: apiKey ? { authorization: `Bearer ${apiKey}` } : undefined,
20
22
  });
@@ -80,7 +82,7 @@ async function getModelOrFactoryType(llm) {
80
82
  throw new Error("Unknown model or factory type");
81
83
  }
82
84
  class LangChainPlusClient {
83
- constructor(apiUrl, tenantId, apiKey) {
85
+ constructor(apiUrl, tenantId, apiKey, callerOptions) {
84
86
  Object.defineProperty(this, "apiKey", {
85
87
  enumerable: true,
86
88
  configurable: true,
@@ -99,17 +101,21 @@ class LangChainPlusClient {
99
101
  writable: true,
100
102
  value: void 0
101
103
  });
104
+ Object.defineProperty(this, "caller", {
105
+ enumerable: true,
106
+ configurable: true,
107
+ writable: true,
108
+ value: void 0
109
+ });
102
110
  this.apiUrl = apiUrl;
103
111
  this.apiKey = apiKey;
104
112
  this.tenantId = tenantId;
105
113
  this.validateApiKeyIfHosted();
114
+ this.caller = new async_caller_js_1.AsyncCaller(callerOptions ?? {});
106
115
  }
107
- static async create(apiUrl, apiKey = undefined, tenantId = undefined) {
108
- let tenantId_ = tenantId;
109
- if (!tenantId_) {
110
- tenantId_ = await getSeededTenantId(apiUrl, apiKey);
111
- }
112
- return new LangChainPlusClient(apiUrl, tenantId_, apiKey);
116
+ static async create(apiUrl, apiKey = undefined) {
117
+ const tenantId = await getSeededTenantId(apiUrl, apiKey);
118
+ return new LangChainPlusClient(apiUrl, tenantId, apiKey);
113
119
  }
114
120
  validateApiKeyIfHosted() {
115
121
  const isLocal = isLocalhost(this.apiUrl);
@@ -138,7 +144,7 @@ class LangChainPlusClient {
138
144
  }
139
145
  }
140
146
  const url = `${this.apiUrl}${path}${queryString ? `?${queryString}` : ""}`;
141
- const response = await fetch(url, {
147
+ const response = await this.caller.call(fetch, url, {
142
148
  method: "GET",
143
149
  headers: this.headers,
144
150
  });
@@ -155,7 +161,7 @@ class LangChainPlusClient {
155
161
  formData.append("output_keys", outputKeys.join(","));
156
162
  formData.append("description", description);
157
163
  formData.append("tenant_id", this.tenantId);
158
- const response = await fetch(url, {
164
+ const response = await this.caller.call(fetch, url, {
159
165
  method: "POST",
160
166
  headers: this.headers,
161
167
  body: formData,
@@ -171,7 +177,7 @@ class LangChainPlusClient {
171
177
  return result;
172
178
  }
173
179
  async createDataset(name, description) {
174
- const response = await fetch(`${this.apiUrl}/datasets`, {
180
+ const response = await this.caller.call(fetch, `${this.apiUrl}/datasets`, {
175
181
  method: "POST",
176
182
  headers: { ...this.headers, "Content-Type": "application/json" },
177
183
  body: JSON.stringify({
@@ -245,7 +251,7 @@ class LangChainPlusClient {
245
251
  else {
246
252
  throw new Error("Must provide datasetName or datasetId");
247
253
  }
248
- const response = await fetch(this.apiUrl + path, {
254
+ const response = await this.caller.call(fetch, this.apiUrl + path, {
249
255
  method: "DELETE",
250
256
  headers: this.headers,
251
257
  });
@@ -274,7 +280,7 @@ class LangChainPlusClient {
274
280
  outputs,
275
281
  created_at: createdAt_.toISOString(),
276
282
  };
277
- const response = await fetch(`${this.apiUrl}/examples`, {
283
+ const response = await this.caller.call(fetch, `${this.apiUrl}/examples`, {
278
284
  method: "POST",
279
285
  headers: { ...this.headers, "Content-Type": "application/json" },
280
286
  body: JSON.stringify(data),
@@ -314,7 +320,7 @@ class LangChainPlusClient {
314
320
  }
315
321
  async deleteExample(exampleId) {
316
322
  const path = `/examples/${exampleId}`;
317
- const response = await fetch(this.apiUrl + path, {
323
+ const response = await this.caller.call(fetch, this.apiUrl + path, {
318
324
  method: "DELETE",
319
325
  headers: this.headers,
320
326
  });
@@ -5,6 +5,7 @@ import { BaseLanguageModel } from "../base_language/index.js";
5
5
  import { BaseChain } from "../chains/base.js";
6
6
  import { BaseLLM } from "../llms/base.js";
7
7
  import { BaseChatModel } from "../chat_models/base.js";
8
+ import { AsyncCallerParams } from "../util/async_caller.js";
8
9
  export interface RunResult extends BaseRun {
9
10
  name: string;
10
11
  session_id: string;
@@ -43,8 +44,9 @@ export declare class LangChainPlusClient {
43
44
  private apiKey?;
44
45
  private apiUrl;
45
46
  private tenantId;
46
- constructor(apiUrl: string, tenantId: string, apiKey?: string);
47
- static create(apiUrl: string, apiKey?: string | undefined, tenantId?: string | undefined): Promise<LangChainPlusClient>;
47
+ private caller;
48
+ constructor(apiUrl: string, tenantId: string, apiKey?: string, callerOptions?: AsyncCallerParams);
49
+ static create(apiUrl: string, apiKey?: string | undefined): Promise<LangChainPlusClient>;
48
50
  private validateApiKeyIfHosted;
49
51
  private get headers();
50
52
  private get queryParams();
@@ -1,17 +1,19 @@
1
1
  import { LangChainTracer } from "../callbacks/handlers/tracer_langchain.js";
2
2
  import { mapStoredMessagesToChatMessages } from "../stores/message/utils.js";
3
+ import { AsyncCaller } from "../util/async_caller.js";
3
4
  // utility functions
4
5
  const isLocalhost = (url) => {
5
6
  const strippedUrl = url.replace("http://", "").replace("https://", "");
6
7
  const hostname = strippedUrl.split("/")[0].split(":")[0];
7
8
  return (hostname === "localhost" || hostname === "127.0.0.1" || hostname === "::1");
8
9
  };
9
- const getSeededTenantId = async (apiUrl, apiKey) => {
10
+ const getSeededTenantId = async (apiUrl, apiKey, callerOptions = undefined) => {
10
11
  // Get the tenant ID from the seeded tenant
12
+ const caller = new AsyncCaller(callerOptions ?? {});
11
13
  const url = `${apiUrl}/tenants`;
12
14
  let response;
13
15
  try {
14
- response = await fetch(url, {
16
+ response = await caller.call(fetch, url, {
15
17
  method: "GET",
16
18
  headers: apiKey ? { authorization: `Bearer ${apiKey}` } : undefined,
17
19
  });
@@ -74,7 +76,7 @@ async function getModelOrFactoryType(llm) {
74
76
  throw new Error("Unknown model or factory type");
75
77
  }
76
78
  export class LangChainPlusClient {
77
- constructor(apiUrl, tenantId, apiKey) {
79
+ constructor(apiUrl, tenantId, apiKey, callerOptions) {
78
80
  Object.defineProperty(this, "apiKey", {
79
81
  enumerable: true,
80
82
  configurable: true,
@@ -93,17 +95,21 @@ export class LangChainPlusClient {
93
95
  writable: true,
94
96
  value: void 0
95
97
  });
98
+ Object.defineProperty(this, "caller", {
99
+ enumerable: true,
100
+ configurable: true,
101
+ writable: true,
102
+ value: void 0
103
+ });
96
104
  this.apiUrl = apiUrl;
97
105
  this.apiKey = apiKey;
98
106
  this.tenantId = tenantId;
99
107
  this.validateApiKeyIfHosted();
108
+ this.caller = new AsyncCaller(callerOptions ?? {});
100
109
  }
101
- static async create(apiUrl, apiKey = undefined, tenantId = undefined) {
102
- let tenantId_ = tenantId;
103
- if (!tenantId_) {
104
- tenantId_ = await getSeededTenantId(apiUrl, apiKey);
105
- }
106
- return new LangChainPlusClient(apiUrl, tenantId_, apiKey);
110
+ static async create(apiUrl, apiKey = undefined) {
111
+ const tenantId = await getSeededTenantId(apiUrl, apiKey);
112
+ return new LangChainPlusClient(apiUrl, tenantId, apiKey);
107
113
  }
108
114
  validateApiKeyIfHosted() {
109
115
  const isLocal = isLocalhost(this.apiUrl);
@@ -132,7 +138,7 @@ export class LangChainPlusClient {
132
138
  }
133
139
  }
134
140
  const url = `${this.apiUrl}${path}${queryString ? `?${queryString}` : ""}`;
135
- const response = await fetch(url, {
141
+ const response = await this.caller.call(fetch, url, {
136
142
  method: "GET",
137
143
  headers: this.headers,
138
144
  });
@@ -149,7 +155,7 @@ export class LangChainPlusClient {
149
155
  formData.append("output_keys", outputKeys.join(","));
150
156
  formData.append("description", description);
151
157
  formData.append("tenant_id", this.tenantId);
152
- const response = await fetch(url, {
158
+ const response = await this.caller.call(fetch, url, {
153
159
  method: "POST",
154
160
  headers: this.headers,
155
161
  body: formData,
@@ -165,7 +171,7 @@ export class LangChainPlusClient {
165
171
  return result;
166
172
  }
167
173
  async createDataset(name, description) {
168
- const response = await fetch(`${this.apiUrl}/datasets`, {
174
+ const response = await this.caller.call(fetch, `${this.apiUrl}/datasets`, {
169
175
  method: "POST",
170
176
  headers: { ...this.headers, "Content-Type": "application/json" },
171
177
  body: JSON.stringify({
@@ -239,7 +245,7 @@ export class LangChainPlusClient {
239
245
  else {
240
246
  throw new Error("Must provide datasetName or datasetId");
241
247
  }
242
- const response = await fetch(this.apiUrl + path, {
248
+ const response = await this.caller.call(fetch, this.apiUrl + path, {
243
249
  method: "DELETE",
244
250
  headers: this.headers,
245
251
  });
@@ -268,7 +274,7 @@ export class LangChainPlusClient {
268
274
  outputs,
269
275
  created_at: createdAt_.toISOString(),
270
276
  };
271
- const response = await fetch(`${this.apiUrl}/examples`, {
277
+ const response = await this.caller.call(fetch, `${this.apiUrl}/examples`, {
272
278
  method: "POST",
273
279
  headers: { ...this.headers, "Content-Type": "application/json" },
274
280
  body: JSON.stringify(data),
@@ -308,7 +314,7 @@ export class LangChainPlusClient {
308
314
  }
309
315
  async deleteExample(exampleId) {
310
316
  const path = `/examples/${exampleId}`;
311
- const response = await fetch(this.apiUrl + path, {
317
+ const response = await this.caller.call(fetch, this.apiUrl + path, {
312
318
  method: "DELETE",
313
319
  headers: this.headers,
314
320
  });
@@ -0,0 +1,123 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.SageMakerEndpoint = exports.BaseSageMakerContentHandler = void 0;
4
+ const client_sagemaker_runtime_1 = require("@aws-sdk/client-sagemaker-runtime");
5
+ const base_js_1 = require("./base.cjs");
6
+ /**
7
+ * A handler class to transform input from LLM to a format that SageMaker
8
+ * endpoint expects. Similarily, the class also handles transforming output from
9
+ * the SageMaker endpoint to a format that LLM class expects.
10
+ *
11
+ * Example:
12
+ * ```
13
+ * class ContentHandler implements ContentHandlerBase<string, string> {
14
+ * contentType = "application/json"
15
+ * accepts = "application/json"
16
+ *
17
+ * transformInput(prompt: string, modelKwargs: Record<string, unknown>) {
18
+ * const inputString = JSON.stringify({
19
+ * prompt,
20
+ * ...modelKwargs
21
+ * })
22
+ * return Buffer.from(inputString)
23
+ * }
24
+ *
25
+ * transformOutput(output: Uint8Array) {
26
+ * const responseJson = JSON.parse(Buffer.from(output).toString("utf-8"))
27
+ * return responseJson[0].generated_text
28
+ * }
29
+ *
30
+ * }
31
+ * ```
32
+ */
33
+ class BaseSageMakerContentHandler {
34
+ constructor() {
35
+ /** The MIME type of the input data passed to endpoint */
36
+ Object.defineProperty(this, "contentType", {
37
+ enumerable: true,
38
+ configurable: true,
39
+ writable: true,
40
+ value: "text/plain"
41
+ });
42
+ /** The MIME type of the response data returned from endpoint */
43
+ Object.defineProperty(this, "accepts", {
44
+ enumerable: true,
45
+ configurable: true,
46
+ writable: true,
47
+ value: "text/plain"
48
+ });
49
+ }
50
+ }
51
+ exports.BaseSageMakerContentHandler = BaseSageMakerContentHandler;
52
+ class SageMakerEndpoint extends base_js_1.LLM {
53
+ constructor(fields) {
54
+ super(fields ?? {});
55
+ Object.defineProperty(this, "endpointName", {
56
+ enumerable: true,
57
+ configurable: true,
58
+ writable: true,
59
+ value: void 0
60
+ });
61
+ Object.defineProperty(this, "contentHandler", {
62
+ enumerable: true,
63
+ configurable: true,
64
+ writable: true,
65
+ value: void 0
66
+ });
67
+ Object.defineProperty(this, "modelKwargs", {
68
+ enumerable: true,
69
+ configurable: true,
70
+ writable: true,
71
+ value: void 0
72
+ });
73
+ Object.defineProperty(this, "endpointKwargs", {
74
+ enumerable: true,
75
+ configurable: true,
76
+ writable: true,
77
+ value: void 0
78
+ });
79
+ Object.defineProperty(this, "client", {
80
+ enumerable: true,
81
+ configurable: true,
82
+ writable: true,
83
+ value: void 0
84
+ });
85
+ const regionName = fields.clientOptions.region;
86
+ if (!regionName) {
87
+ throw new Error(`Please pass a "clientOptions" object with a "region" field to the constructor`);
88
+ }
89
+ const endpointName = fields?.endpointName;
90
+ if (!endpointName) {
91
+ throw new Error(`Please pass an "endpointName" field to the constructor`);
92
+ }
93
+ const contentHandler = fields?.contentHandler;
94
+ if (!contentHandler) {
95
+ throw new Error(`Please pass a "contentHandler" field to the constructor`);
96
+ }
97
+ this.endpointName = fields.endpointName;
98
+ this.contentHandler = fields.contentHandler;
99
+ this.endpointKwargs = fields.endpointKwargs;
100
+ this.modelKwargs = fields.modelKwargs;
101
+ this.client = new client_sagemaker_runtime_1.SageMakerRuntimeClient(fields.clientOptions);
102
+ }
103
+ _llmType() {
104
+ return "sagemaker_endpoint";
105
+ }
106
+ /** @ignore */
107
+ async _call(prompt, options) {
108
+ const body = await this.contentHandler.transformInput(prompt, this.modelKwargs ?? {});
109
+ const { contentType, accepts } = this.contentHandler;
110
+ const response = await this.caller.call(() => this.client.send(new client_sagemaker_runtime_1.InvokeEndpointCommand({
111
+ EndpointName: this.endpointName,
112
+ Body: body,
113
+ ContentType: contentType,
114
+ Accept: accepts,
115
+ ...this.endpointKwargs,
116
+ }), { abortSignal: options.signal }));
117
+ if (response.Body === undefined) {
118
+ throw new Error("Inference result missing Body");
119
+ }
120
+ return this.contentHandler.transformOutput(response.Body);
121
+ }
122
+ }
123
+ exports.SageMakerEndpoint = SageMakerEndpoint;
@@ -0,0 +1,82 @@
1
+ import { SageMakerRuntimeClient, SageMakerRuntimeClientConfig } from "@aws-sdk/client-sagemaker-runtime";
2
+ import { LLM, BaseLLMParams } from "./base.js";
3
+ /**
4
+ * A handler class to transform input from LLM to a format that SageMaker
5
+ * endpoint expects. Similarily, the class also handles transforming output from
6
+ * the SageMaker endpoint to a format that LLM class expects.
7
+ *
8
+ * Example:
9
+ * ```
10
+ * class ContentHandler implements ContentHandlerBase<string, string> {
11
+ * contentType = "application/json"
12
+ * accepts = "application/json"
13
+ *
14
+ * transformInput(prompt: string, modelKwargs: Record<string, unknown>) {
15
+ * const inputString = JSON.stringify({
16
+ * prompt,
17
+ * ...modelKwargs
18
+ * })
19
+ * return Buffer.from(inputString)
20
+ * }
21
+ *
22
+ * transformOutput(output: Uint8Array) {
23
+ * const responseJson = JSON.parse(Buffer.from(output).toString("utf-8"))
24
+ * return responseJson[0].generated_text
25
+ * }
26
+ *
27
+ * }
28
+ * ```
29
+ */
30
+ export declare abstract class BaseSageMakerContentHandler<InputType, OutputType> {
31
+ /** The MIME type of the input data passed to endpoint */
32
+ contentType: string;
33
+ /** The MIME type of the response data returned from endpoint */
34
+ accepts: string;
35
+ /**
36
+ * Transforms the input to a format that model can accept as the request Body.
37
+ * Should return bytes or seekable file like object in the format specified in
38
+ * the contentType request header.
39
+ */
40
+ abstract transformInput(prompt: InputType, modelKwargs: Record<string, unknown>): Promise<Uint8Array>;
41
+ /**
42
+ * Transforms the output from the model to string that the LLM class expects.
43
+ */
44
+ abstract transformOutput(output: Uint8Array): Promise<OutputType>;
45
+ }
46
+ /** Content handler for LLM class. */
47
+ export type SageMakerLLMContentHandler = BaseSageMakerContentHandler<string, string>;
48
+ export interface SageMakerEndpointInput extends BaseLLMParams {
49
+ /**
50
+ * The name of the endpoint from the deployed SageMaker model. Must be unique
51
+ * within an AWS Region.
52
+ */
53
+ endpointName: string;
54
+ /**
55
+ * Options passed to the SageMaker client.
56
+ */
57
+ clientOptions: SageMakerRuntimeClientConfig;
58
+ /**
59
+ * The content handler class that provides an input and output transform
60
+ * functions to handle formats between LLM and the endpoint.
61
+ */
62
+ contentHandler: SageMakerLLMContentHandler;
63
+ /**
64
+ * Key word arguments to pass to the model.
65
+ */
66
+ modelKwargs?: Record<string, unknown>;
67
+ /**
68
+ * Optional attributes passed to the InvokeEndpointCommand
69
+ */
70
+ endpointKwargs?: Record<string, unknown>;
71
+ }
72
+ export declare class SageMakerEndpoint extends LLM {
73
+ endpointName: string;
74
+ contentHandler: SageMakerLLMContentHandler;
75
+ modelKwargs?: Record<string, unknown>;
76
+ endpointKwargs?: Record<string, unknown>;
77
+ client: SageMakerRuntimeClient;
78
+ constructor(fields: SageMakerEndpointInput);
79
+ _llmType(): string;
80
+ /** @ignore */
81
+ _call(prompt: string, options: this["ParsedCallOptions"]): Promise<string>;
82
+ }
@@ -0,0 +1,118 @@
1
+ import { SageMakerRuntimeClient, InvokeEndpointCommand, } from "@aws-sdk/client-sagemaker-runtime";
2
+ import { LLM } from "./base.js";
3
+ /**
4
+ * A handler class to transform input from LLM to a format that SageMaker
5
+ * endpoint expects. Similarily, the class also handles transforming output from
6
+ * the SageMaker endpoint to a format that LLM class expects.
7
+ *
8
+ * Example:
9
+ * ```
10
+ * class ContentHandler implements ContentHandlerBase<string, string> {
11
+ * contentType = "application/json"
12
+ * accepts = "application/json"
13
+ *
14
+ * transformInput(prompt: string, modelKwargs: Record<string, unknown>) {
15
+ * const inputString = JSON.stringify({
16
+ * prompt,
17
+ * ...modelKwargs
18
+ * })
19
+ * return Buffer.from(inputString)
20
+ * }
21
+ *
22
+ * transformOutput(output: Uint8Array) {
23
+ * const responseJson = JSON.parse(Buffer.from(output).toString("utf-8"))
24
+ * return responseJson[0].generated_text
25
+ * }
26
+ *
27
+ * }
28
+ * ```
29
+ */
30
+ export class BaseSageMakerContentHandler {
31
+ constructor() {
32
+ /** The MIME type of the input data passed to endpoint */
33
+ Object.defineProperty(this, "contentType", {
34
+ enumerable: true,
35
+ configurable: true,
36
+ writable: true,
37
+ value: "text/plain"
38
+ });
39
+ /** The MIME type of the response data returned from endpoint */
40
+ Object.defineProperty(this, "accepts", {
41
+ enumerable: true,
42
+ configurable: true,
43
+ writable: true,
44
+ value: "text/plain"
45
+ });
46
+ }
47
+ }
48
+ export class SageMakerEndpoint extends LLM {
49
+ constructor(fields) {
50
+ super(fields ?? {});
51
+ Object.defineProperty(this, "endpointName", {
52
+ enumerable: true,
53
+ configurable: true,
54
+ writable: true,
55
+ value: void 0
56
+ });
57
+ Object.defineProperty(this, "contentHandler", {
58
+ enumerable: true,
59
+ configurable: true,
60
+ writable: true,
61
+ value: void 0
62
+ });
63
+ Object.defineProperty(this, "modelKwargs", {
64
+ enumerable: true,
65
+ configurable: true,
66
+ writable: true,
67
+ value: void 0
68
+ });
69
+ Object.defineProperty(this, "endpointKwargs", {
70
+ enumerable: true,
71
+ configurable: true,
72
+ writable: true,
73
+ value: void 0
74
+ });
75
+ Object.defineProperty(this, "client", {
76
+ enumerable: true,
77
+ configurable: true,
78
+ writable: true,
79
+ value: void 0
80
+ });
81
+ const regionName = fields.clientOptions.region;
82
+ if (!regionName) {
83
+ throw new Error(`Please pass a "clientOptions" object with a "region" field to the constructor`);
84
+ }
85
+ const endpointName = fields?.endpointName;
86
+ if (!endpointName) {
87
+ throw new Error(`Please pass an "endpointName" field to the constructor`);
88
+ }
89
+ const contentHandler = fields?.contentHandler;
90
+ if (!contentHandler) {
91
+ throw new Error(`Please pass a "contentHandler" field to the constructor`);
92
+ }
93
+ this.endpointName = fields.endpointName;
94
+ this.contentHandler = fields.contentHandler;
95
+ this.endpointKwargs = fields.endpointKwargs;
96
+ this.modelKwargs = fields.modelKwargs;
97
+ this.client = new SageMakerRuntimeClient(fields.clientOptions);
98
+ }
99
+ _llmType() {
100
+ return "sagemaker_endpoint";
101
+ }
102
+ /** @ignore */
103
+ async _call(prompt, options) {
104
+ const body = await this.contentHandler.transformInput(prompt, this.modelKwargs ?? {});
105
+ const { contentType, accepts } = this.contentHandler;
106
+ const response = await this.caller.call(() => this.client.send(new InvokeEndpointCommand({
107
+ EndpointName: this.endpointName,
108
+ Body: body,
109
+ ContentType: contentType,
110
+ Accept: accepts,
111
+ ...this.endpointKwargs,
112
+ }), { abortSignal: options.signal }));
113
+ if (response.Body === undefined) {
114
+ throw new Error("Inference result missing Body");
115
+ }
116
+ return this.contentHandler.transformOutput(response.Body);
117
+ }
118
+ }
@@ -45,7 +45,7 @@ class BufferMemory extends chat_memory_js_1.BaseChatMemory {
45
45
  return result;
46
46
  }
47
47
  const result = {
48
- [this.memoryKey]: (0, base_js_1.getBufferString)(messages),
48
+ [this.memoryKey]: (0, base_js_1.getBufferString)(messages, this.humanPrefix, this.aiPrefix),
49
49
  };
50
50
  return result;
51
51
  }
@@ -42,7 +42,7 @@ export class BufferMemory extends BaseChatMemory {
42
42
  return result;
43
43
  }
44
44
  const result = {
45
- [this.memoryKey]: getBufferString(messages),
45
+ [this.memoryKey]: getBufferString(messages, this.humanPrefix, this.aiPrefix),
46
46
  };
47
47
  return result;
48
48
  }
@@ -52,7 +52,7 @@ class BufferWindowMemory extends chat_memory_js_1.BaseChatMemory {
52
52
  return result;
53
53
  }
54
54
  const result = {
55
- [this.memoryKey]: (0, base_js_1.getBufferString)(messages.slice(-this.k * 2)),
55
+ [this.memoryKey]: (0, base_js_1.getBufferString)(messages.slice(-this.k * 2), this.humanPrefix, this.aiPrefix),
56
56
  };
57
57
  return result;
58
58
  }
@@ -49,7 +49,7 @@ export class BufferWindowMemory extends BaseChatMemory {
49
49
  return result;
50
50
  }
51
51
  const result = {
52
- [this.memoryKey]: getBufferString(messages.slice(-this.k * 2)),
52
+ [this.memoryKey]: getBufferString(messages.slice(-this.k * 2), this.humanPrefix, this.aiPrefix),
53
53
  };
54
54
  return result;
55
55
  }
@@ -0,0 +1,68 @@
1
+ "use strict";
2
+ var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
3
+ if (k2 === undefined) k2 = k;
4
+ var desc = Object.getOwnPropertyDescriptor(m, k);
5
+ if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
6
+ desc = { enumerable: true, get: function() { return m[k]; } };
7
+ }
8
+ Object.defineProperty(o, k2, desc);
9
+ }) : (function(o, m, k, k2) {
10
+ if (k2 === undefined) k2 = k;
11
+ o[k2] = m[k];
12
+ }));
13
+ var __exportStar = (this && this.__exportStar) || function(m, exports) {
14
+ for (var p in m) if (p !== "default" && !Object.prototype.hasOwnProperty.call(exports, p)) __createBinding(exports, m, p);
15
+ };
16
+ Object.defineProperty(exports, "__esModule", { value: true });
17
+ exports.MasterHandler = exports.ExpressionParser = void 0;
18
+ const factory_js_1 = require("./expression_type_handlers/factory.cjs");
19
+ const output_parser_js_1 = require("../schema/output_parser.cjs");
20
+ const base_js_1 = require("./expression_type_handlers/base.cjs");
21
+ /**
22
+ * okay so we need to be able to handle the following cases:
23
+ * ExpressionStatement
24
+ * CallExpression
25
+ * Identifier | MemberExpression
26
+ * ExpressionLiterals: [
27
+ * CallExpression
28
+ * StringLiteral
29
+ * NumericLiteral
30
+ * ArrayLiteralExpression
31
+ * ExpressionLiterals
32
+ * ObjectLiteralExpression
33
+ * PropertyAssignment
34
+ * Identifier
35
+ * ExpressionLiterals
36
+ * ]
37
+ */
38
+ class ExpressionParser extends output_parser_js_1.BaseOutputParser {
39
+ async parse(text) {
40
+ const parse = await base_js_1.ASTParser.importASTParser();
41
+ try {
42
+ const program = parse(text);
43
+ if (program.body.length > 1) {
44
+ throw new Error(`Expected 1 statement, got ${program.body.length}`);
45
+ }
46
+ const [node] = program.body;
47
+ if (!base_js_1.ASTParser.isExpressionStatement(node)) {
48
+ throw new Error(`Expected ExpressionStatement, got ${node.type}`);
49
+ }
50
+ const { expression: expressionStatement } = node;
51
+ if (!base_js_1.ASTParser.isCallExpression(expressionStatement)) {
52
+ throw new Error("Expected CallExpression");
53
+ }
54
+ const masterHandler = factory_js_1.MasterHandler.createMasterHandler();
55
+ return await masterHandler.handle(expressionStatement);
56
+ }
57
+ catch (err) {
58
+ throw new Error(`Error parsing ${err}: ${text}`);
59
+ }
60
+ }
61
+ getFormatInstructions() {
62
+ return "";
63
+ }
64
+ }
65
+ exports.ExpressionParser = ExpressionParser;
66
+ __exportStar(require("./expression_type_handlers/types.cjs"), exports);
67
+ var factory_js_2 = require("./expression_type_handlers/factory.cjs");
68
+ Object.defineProperty(exports, "MasterHandler", { enumerable: true, get: function () { return factory_js_2.MasterHandler; } });