langchain 0.0.76 → 0.0.77
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/chains/query_constructor/ir.cjs +1 -0
- package/chains/query_constructor/ir.d.ts +1 -0
- package/chains/query_constructor/ir.js +1 -0
- package/chains/query_constructor.cjs +1 -0
- package/chains/query_constructor.d.ts +1 -0
- package/chains/query_constructor.js +1 -0
- package/dist/agents/chat_convo/index.cjs +27 -11
- package/dist/agents/chat_convo/index.d.ts +4 -1
- package/dist/agents/chat_convo/index.js +28 -12
- package/dist/agents/chat_convo/outputParser.cjs +79 -7
- package/dist/agents/chat_convo/outputParser.d.ts +25 -13
- package/dist/agents/chat_convo/outputParser.js +77 -6
- package/dist/agents/chat_convo/prompt.cjs +11 -8
- package/dist/agents/chat_convo/prompt.d.ts +2 -2
- package/dist/agents/chat_convo/prompt.js +11 -8
- package/dist/callbacks/handlers/tracer_langchain.cjs +12 -4
- package/dist/callbacks/handlers/tracer_langchain.d.ts +4 -1
- package/dist/callbacks/handlers/tracer_langchain.js +12 -4
- package/dist/callbacks/manager.cjs +6 -2
- package/dist/callbacks/manager.js +6 -2
- package/dist/chains/query_constructor/index.cjs +105 -0
- package/dist/chains/query_constructor/index.d.ts +37 -0
- package/dist/chains/query_constructor/index.js +95 -0
- package/dist/chains/query_constructor/ir.cjs +116 -0
- package/dist/chains/query_constructor/ir.d.ts +60 -0
- package/dist/chains/query_constructor/ir.js +107 -0
- package/dist/chains/query_constructor/parser.cjs +103 -0
- package/dist/chains/query_constructor/parser.d.ts +12 -0
- package/dist/chains/query_constructor/parser.js +99 -0
- package/dist/chains/query_constructor/prompt.cjs +127 -0
- package/dist/chains/query_constructor/prompt.d.ts +15 -0
- package/dist/chains/query_constructor/prompt.js +124 -0
- package/dist/chains/sql_db/sql_db_chain.cjs +13 -0
- package/dist/chains/sql_db/sql_db_chain.d.ts +2 -0
- package/dist/chains/sql_db/sql_db_chain.js +13 -0
- package/dist/client/langchainplus.cjs +21 -15
- package/dist/client/langchainplus.d.ts +4 -2
- package/dist/client/langchainplus.js +21 -15
- package/dist/llms/sagemaker_endpoint.cjs +123 -0
- package/dist/llms/sagemaker_endpoint.d.ts +82 -0
- package/dist/llms/sagemaker_endpoint.js +118 -0
- package/dist/memory/buffer_memory.cjs +1 -1
- package/dist/memory/buffer_memory.js +1 -1
- package/dist/memory/buffer_window_memory.cjs +1 -1
- package/dist/memory/buffer_window_memory.js +1 -1
- package/dist/output_parsers/expression.cjs +68 -0
- package/dist/output_parsers/expression.d.ts +25 -0
- package/dist/output_parsers/expression.js +49 -0
- package/dist/output_parsers/expression_type_handlers/array_literal_expression_handler.cjs +26 -0
- package/dist/output_parsers/expression_type_handlers/array_literal_expression_handler.d.ts +7 -0
- package/dist/output_parsers/expression_type_handlers/array_literal_expression_handler.js +22 -0
- package/dist/output_parsers/expression_type_handlers/base.cjs +67 -0
- package/dist/output_parsers/expression_type_handlers/base.d.ts +23 -0
- package/dist/output_parsers/expression_type_handlers/base.js +62 -0
- package/dist/output_parsers/expression_type_handlers/boolean_literal_handler.cjs +24 -0
- package/dist/output_parsers/expression_type_handlers/boolean_literal_handler.d.ts +7 -0
- package/dist/output_parsers/expression_type_handlers/boolean_literal_handler.js +20 -0
- package/dist/output_parsers/expression_type_handlers/call_expression_handler.cjs +52 -0
- package/dist/output_parsers/expression_type_handlers/call_expression_handler.d.ts +7 -0
- package/dist/output_parsers/expression_type_handlers/call_expression_handler.js +48 -0
- package/dist/output_parsers/expression_type_handlers/factory.cjs +56 -0
- package/dist/output_parsers/expression_type_handlers/factory.d.ts +9 -0
- package/dist/output_parsers/expression_type_handlers/factory.js +52 -0
- package/dist/output_parsers/expression_type_handlers/identifier_handler.cjs +22 -0
- package/dist/output_parsers/expression_type_handlers/identifier_handler.d.ts +7 -0
- package/dist/output_parsers/expression_type_handlers/identifier_handler.js +18 -0
- package/dist/output_parsers/expression_type_handlers/member_expression_handler.cjs +45 -0
- package/dist/output_parsers/expression_type_handlers/member_expression_handler.d.ts +7 -0
- package/dist/output_parsers/expression_type_handlers/member_expression_handler.js +41 -0
- package/dist/output_parsers/expression_type_handlers/numeric_literal_handler.cjs +24 -0
- package/dist/output_parsers/expression_type_handlers/numeric_literal_handler.d.ts +7 -0
- package/dist/output_parsers/expression_type_handlers/numeric_literal_handler.js +20 -0
- package/dist/output_parsers/expression_type_handlers/object_literal_expression_handler.cjs +29 -0
- package/dist/output_parsers/expression_type_handlers/object_literal_expression_handler.d.ts +7 -0
- package/dist/output_parsers/expression_type_handlers/object_literal_expression_handler.js +25 -0
- package/dist/output_parsers/expression_type_handlers/property_assignment_handler.cjs +36 -0
- package/dist/output_parsers/expression_type_handlers/property_assignment_handler.d.ts +7 -0
- package/dist/output_parsers/expression_type_handlers/property_assignment_handler.js +32 -0
- package/dist/output_parsers/expression_type_handlers/string_literal_handler.cjs +22 -0
- package/dist/output_parsers/expression_type_handlers/string_literal_handler.d.ts +7 -0
- package/dist/output_parsers/expression_type_handlers/string_literal_handler.js +18 -0
- package/dist/output_parsers/expression_type_handlers/types.cjs +2 -0
- package/dist/output_parsers/expression_type_handlers/types.d.ts +41 -0
- package/dist/output_parsers/expression_type_handlers/types.js +1 -0
- package/dist/output_parsers/index.cjs +2 -1
- package/dist/output_parsers/index.d.ts +1 -1
- package/dist/output_parsers/index.js +1 -1
- package/dist/output_parsers/structured.cjs +81 -23
- package/dist/output_parsers/structured.d.ts +18 -0
- package/dist/output_parsers/structured.js +79 -22
- package/dist/retrievers/self_query/index.cjs +79 -0
- package/dist/retrievers/self_query/index.d.ts +33 -0
- package/dist/retrievers/self_query/index.js +74 -0
- package/dist/retrievers/self_query/translator.cjs +72 -0
- package/dist/retrievers/self_query/translator.d.ts +14 -0
- package/dist/retrievers/self_query/translator.js +67 -0
- package/dist/schema/query_constructor.cjs +26 -0
- package/dist/schema/query_constructor.d.ts +6 -0
- package/dist/schema/query_constructor.js +22 -0
- package/dist/tools/json.cjs +3 -1
- package/dist/tools/json.js +3 -1
- package/dist/util/event-source-parse.cjs +31 -5
- package/dist/util/event-source-parse.d.ts +3 -3
- package/dist/util/event-source-parse.js +31 -5
- package/llms/sagemaker_endpoint.cjs +1 -0
- package/llms/sagemaker_endpoint.d.ts +1 -0
- package/llms/sagemaker_endpoint.js +1 -0
- package/output_parsers/expression.cjs +1 -0
- package/output_parsers/expression.d.ts +1 -0
- package/output_parsers/expression.js +1 -0
- package/package.json +61 -3
- package/retrievers/self_query.cjs +1 -0
- package/retrievers/self_query.d.ts +1 -0
- package/retrievers/self_query.js +1 -0
- package/schema/query_constructor.cjs +1 -0
- package/schema/query_constructor.d.ts +1 -0
- package/schema/query_constructor.js +1 -0
|
@@ -3,18 +3,20 @@ Object.defineProperty(exports, "__esModule", { value: true });
|
|
|
3
3
|
exports.LangChainPlusClient = exports.isChain = exports.isChatModel = exports.isLLM = void 0;
|
|
4
4
|
const tracer_langchain_js_1 = require("../callbacks/handlers/tracer_langchain.cjs");
|
|
5
5
|
const utils_js_1 = require("../stores/message/utils.cjs");
|
|
6
|
+
const async_caller_js_1 = require("../util/async_caller.cjs");
|
|
6
7
|
// utility functions
|
|
7
8
|
const isLocalhost = (url) => {
|
|
8
9
|
const strippedUrl = url.replace("http://", "").replace("https://", "");
|
|
9
10
|
const hostname = strippedUrl.split("/")[0].split(":")[0];
|
|
10
11
|
return (hostname === "localhost" || hostname === "127.0.0.1" || hostname === "::1");
|
|
11
12
|
};
|
|
12
|
-
const getSeededTenantId = async (apiUrl, apiKey) => {
|
|
13
|
+
const getSeededTenantId = async (apiUrl, apiKey, callerOptions = undefined) => {
|
|
13
14
|
// Get the tenant ID from the seeded tenant
|
|
15
|
+
const caller = new async_caller_js_1.AsyncCaller(callerOptions ?? {});
|
|
14
16
|
const url = `${apiUrl}/tenants`;
|
|
15
17
|
let response;
|
|
16
18
|
try {
|
|
17
|
-
response = await fetch
|
|
19
|
+
response = await caller.call(fetch, url, {
|
|
18
20
|
method: "GET",
|
|
19
21
|
headers: apiKey ? { authorization: `Bearer ${apiKey}` } : undefined,
|
|
20
22
|
});
|
|
@@ -80,7 +82,7 @@ async function getModelOrFactoryType(llm) {
|
|
|
80
82
|
throw new Error("Unknown model or factory type");
|
|
81
83
|
}
|
|
82
84
|
class LangChainPlusClient {
|
|
83
|
-
constructor(apiUrl, tenantId, apiKey) {
|
|
85
|
+
constructor(apiUrl, tenantId, apiKey, callerOptions) {
|
|
84
86
|
Object.defineProperty(this, "apiKey", {
|
|
85
87
|
enumerable: true,
|
|
86
88
|
configurable: true,
|
|
@@ -99,17 +101,21 @@ class LangChainPlusClient {
|
|
|
99
101
|
writable: true,
|
|
100
102
|
value: void 0
|
|
101
103
|
});
|
|
104
|
+
Object.defineProperty(this, "caller", {
|
|
105
|
+
enumerable: true,
|
|
106
|
+
configurable: true,
|
|
107
|
+
writable: true,
|
|
108
|
+
value: void 0
|
|
109
|
+
});
|
|
102
110
|
this.apiUrl = apiUrl;
|
|
103
111
|
this.apiKey = apiKey;
|
|
104
112
|
this.tenantId = tenantId;
|
|
105
113
|
this.validateApiKeyIfHosted();
|
|
114
|
+
this.caller = new async_caller_js_1.AsyncCaller(callerOptions ?? {});
|
|
106
115
|
}
|
|
107
|
-
static async create(apiUrl, apiKey = undefined
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
tenantId_ = await getSeededTenantId(apiUrl, apiKey);
|
|
111
|
-
}
|
|
112
|
-
return new LangChainPlusClient(apiUrl, tenantId_, apiKey);
|
|
116
|
+
static async create(apiUrl, apiKey = undefined) {
|
|
117
|
+
const tenantId = await getSeededTenantId(apiUrl, apiKey);
|
|
118
|
+
return new LangChainPlusClient(apiUrl, tenantId, apiKey);
|
|
113
119
|
}
|
|
114
120
|
validateApiKeyIfHosted() {
|
|
115
121
|
const isLocal = isLocalhost(this.apiUrl);
|
|
@@ -138,7 +144,7 @@ class LangChainPlusClient {
|
|
|
138
144
|
}
|
|
139
145
|
}
|
|
140
146
|
const url = `${this.apiUrl}${path}${queryString ? `?${queryString}` : ""}`;
|
|
141
|
-
const response = await fetch
|
|
147
|
+
const response = await this.caller.call(fetch, url, {
|
|
142
148
|
method: "GET",
|
|
143
149
|
headers: this.headers,
|
|
144
150
|
});
|
|
@@ -155,7 +161,7 @@ class LangChainPlusClient {
|
|
|
155
161
|
formData.append("output_keys", outputKeys.join(","));
|
|
156
162
|
formData.append("description", description);
|
|
157
163
|
formData.append("tenant_id", this.tenantId);
|
|
158
|
-
const response = await fetch
|
|
164
|
+
const response = await this.caller.call(fetch, url, {
|
|
159
165
|
method: "POST",
|
|
160
166
|
headers: this.headers,
|
|
161
167
|
body: formData,
|
|
@@ -171,7 +177,7 @@ class LangChainPlusClient {
|
|
|
171
177
|
return result;
|
|
172
178
|
}
|
|
173
179
|
async createDataset(name, description) {
|
|
174
|
-
const response = await fetch
|
|
180
|
+
const response = await this.caller.call(fetch, `${this.apiUrl}/datasets`, {
|
|
175
181
|
method: "POST",
|
|
176
182
|
headers: { ...this.headers, "Content-Type": "application/json" },
|
|
177
183
|
body: JSON.stringify({
|
|
@@ -245,7 +251,7 @@ class LangChainPlusClient {
|
|
|
245
251
|
else {
|
|
246
252
|
throw new Error("Must provide datasetName or datasetId");
|
|
247
253
|
}
|
|
248
|
-
const response = await fetch
|
|
254
|
+
const response = await this.caller.call(fetch, this.apiUrl + path, {
|
|
249
255
|
method: "DELETE",
|
|
250
256
|
headers: this.headers,
|
|
251
257
|
});
|
|
@@ -274,7 +280,7 @@ class LangChainPlusClient {
|
|
|
274
280
|
outputs,
|
|
275
281
|
created_at: createdAt_.toISOString(),
|
|
276
282
|
};
|
|
277
|
-
const response = await fetch
|
|
283
|
+
const response = await this.caller.call(fetch, `${this.apiUrl}/examples`, {
|
|
278
284
|
method: "POST",
|
|
279
285
|
headers: { ...this.headers, "Content-Type": "application/json" },
|
|
280
286
|
body: JSON.stringify(data),
|
|
@@ -314,7 +320,7 @@ class LangChainPlusClient {
|
|
|
314
320
|
}
|
|
315
321
|
async deleteExample(exampleId) {
|
|
316
322
|
const path = `/examples/${exampleId}`;
|
|
317
|
-
const response = await fetch
|
|
323
|
+
const response = await this.caller.call(fetch, this.apiUrl + path, {
|
|
318
324
|
method: "DELETE",
|
|
319
325
|
headers: this.headers,
|
|
320
326
|
});
|
|
@@ -5,6 +5,7 @@ import { BaseLanguageModel } from "../base_language/index.js";
|
|
|
5
5
|
import { BaseChain } from "../chains/base.js";
|
|
6
6
|
import { BaseLLM } from "../llms/base.js";
|
|
7
7
|
import { BaseChatModel } from "../chat_models/base.js";
|
|
8
|
+
import { AsyncCallerParams } from "../util/async_caller.js";
|
|
8
9
|
export interface RunResult extends BaseRun {
|
|
9
10
|
name: string;
|
|
10
11
|
session_id: string;
|
|
@@ -43,8 +44,9 @@ export declare class LangChainPlusClient {
|
|
|
43
44
|
private apiKey?;
|
|
44
45
|
private apiUrl;
|
|
45
46
|
private tenantId;
|
|
46
|
-
|
|
47
|
-
|
|
47
|
+
private caller;
|
|
48
|
+
constructor(apiUrl: string, tenantId: string, apiKey?: string, callerOptions?: AsyncCallerParams);
|
|
49
|
+
static create(apiUrl: string, apiKey?: string | undefined): Promise<LangChainPlusClient>;
|
|
48
50
|
private validateApiKeyIfHosted;
|
|
49
51
|
private get headers();
|
|
50
52
|
private get queryParams();
|
|
@@ -1,17 +1,19 @@
|
|
|
1
1
|
import { LangChainTracer } from "../callbacks/handlers/tracer_langchain.js";
|
|
2
2
|
import { mapStoredMessagesToChatMessages } from "../stores/message/utils.js";
|
|
3
|
+
import { AsyncCaller } from "../util/async_caller.js";
|
|
3
4
|
// utility functions
|
|
4
5
|
const isLocalhost = (url) => {
|
|
5
6
|
const strippedUrl = url.replace("http://", "").replace("https://", "");
|
|
6
7
|
const hostname = strippedUrl.split("/")[0].split(":")[0];
|
|
7
8
|
return (hostname === "localhost" || hostname === "127.0.0.1" || hostname === "::1");
|
|
8
9
|
};
|
|
9
|
-
const getSeededTenantId = async (apiUrl, apiKey) => {
|
|
10
|
+
const getSeededTenantId = async (apiUrl, apiKey, callerOptions = undefined) => {
|
|
10
11
|
// Get the tenant ID from the seeded tenant
|
|
12
|
+
const caller = new AsyncCaller(callerOptions ?? {});
|
|
11
13
|
const url = `${apiUrl}/tenants`;
|
|
12
14
|
let response;
|
|
13
15
|
try {
|
|
14
|
-
response = await fetch
|
|
16
|
+
response = await caller.call(fetch, url, {
|
|
15
17
|
method: "GET",
|
|
16
18
|
headers: apiKey ? { authorization: `Bearer ${apiKey}` } : undefined,
|
|
17
19
|
});
|
|
@@ -74,7 +76,7 @@ async function getModelOrFactoryType(llm) {
|
|
|
74
76
|
throw new Error("Unknown model or factory type");
|
|
75
77
|
}
|
|
76
78
|
export class LangChainPlusClient {
|
|
77
|
-
constructor(apiUrl, tenantId, apiKey) {
|
|
79
|
+
constructor(apiUrl, tenantId, apiKey, callerOptions) {
|
|
78
80
|
Object.defineProperty(this, "apiKey", {
|
|
79
81
|
enumerable: true,
|
|
80
82
|
configurable: true,
|
|
@@ -93,17 +95,21 @@ export class LangChainPlusClient {
|
|
|
93
95
|
writable: true,
|
|
94
96
|
value: void 0
|
|
95
97
|
});
|
|
98
|
+
Object.defineProperty(this, "caller", {
|
|
99
|
+
enumerable: true,
|
|
100
|
+
configurable: true,
|
|
101
|
+
writable: true,
|
|
102
|
+
value: void 0
|
|
103
|
+
});
|
|
96
104
|
this.apiUrl = apiUrl;
|
|
97
105
|
this.apiKey = apiKey;
|
|
98
106
|
this.tenantId = tenantId;
|
|
99
107
|
this.validateApiKeyIfHosted();
|
|
108
|
+
this.caller = new AsyncCaller(callerOptions ?? {});
|
|
100
109
|
}
|
|
101
|
-
static async create(apiUrl, apiKey = undefined
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
tenantId_ = await getSeededTenantId(apiUrl, apiKey);
|
|
105
|
-
}
|
|
106
|
-
return new LangChainPlusClient(apiUrl, tenantId_, apiKey);
|
|
110
|
+
static async create(apiUrl, apiKey = undefined) {
|
|
111
|
+
const tenantId = await getSeededTenantId(apiUrl, apiKey);
|
|
112
|
+
return new LangChainPlusClient(apiUrl, tenantId, apiKey);
|
|
107
113
|
}
|
|
108
114
|
validateApiKeyIfHosted() {
|
|
109
115
|
const isLocal = isLocalhost(this.apiUrl);
|
|
@@ -132,7 +138,7 @@ export class LangChainPlusClient {
|
|
|
132
138
|
}
|
|
133
139
|
}
|
|
134
140
|
const url = `${this.apiUrl}${path}${queryString ? `?${queryString}` : ""}`;
|
|
135
|
-
const response = await fetch
|
|
141
|
+
const response = await this.caller.call(fetch, url, {
|
|
136
142
|
method: "GET",
|
|
137
143
|
headers: this.headers,
|
|
138
144
|
});
|
|
@@ -149,7 +155,7 @@ export class LangChainPlusClient {
|
|
|
149
155
|
formData.append("output_keys", outputKeys.join(","));
|
|
150
156
|
formData.append("description", description);
|
|
151
157
|
formData.append("tenant_id", this.tenantId);
|
|
152
|
-
const response = await fetch
|
|
158
|
+
const response = await this.caller.call(fetch, url, {
|
|
153
159
|
method: "POST",
|
|
154
160
|
headers: this.headers,
|
|
155
161
|
body: formData,
|
|
@@ -165,7 +171,7 @@ export class LangChainPlusClient {
|
|
|
165
171
|
return result;
|
|
166
172
|
}
|
|
167
173
|
async createDataset(name, description) {
|
|
168
|
-
const response = await fetch
|
|
174
|
+
const response = await this.caller.call(fetch, `${this.apiUrl}/datasets`, {
|
|
169
175
|
method: "POST",
|
|
170
176
|
headers: { ...this.headers, "Content-Type": "application/json" },
|
|
171
177
|
body: JSON.stringify({
|
|
@@ -239,7 +245,7 @@ export class LangChainPlusClient {
|
|
|
239
245
|
else {
|
|
240
246
|
throw new Error("Must provide datasetName or datasetId");
|
|
241
247
|
}
|
|
242
|
-
const response = await fetch
|
|
248
|
+
const response = await this.caller.call(fetch, this.apiUrl + path, {
|
|
243
249
|
method: "DELETE",
|
|
244
250
|
headers: this.headers,
|
|
245
251
|
});
|
|
@@ -268,7 +274,7 @@ export class LangChainPlusClient {
|
|
|
268
274
|
outputs,
|
|
269
275
|
created_at: createdAt_.toISOString(),
|
|
270
276
|
};
|
|
271
|
-
const response = await fetch
|
|
277
|
+
const response = await this.caller.call(fetch, `${this.apiUrl}/examples`, {
|
|
272
278
|
method: "POST",
|
|
273
279
|
headers: { ...this.headers, "Content-Type": "application/json" },
|
|
274
280
|
body: JSON.stringify(data),
|
|
@@ -308,7 +314,7 @@ export class LangChainPlusClient {
|
|
|
308
314
|
}
|
|
309
315
|
async deleteExample(exampleId) {
|
|
310
316
|
const path = `/examples/${exampleId}`;
|
|
311
|
-
const response = await fetch
|
|
317
|
+
const response = await this.caller.call(fetch, this.apiUrl + path, {
|
|
312
318
|
method: "DELETE",
|
|
313
319
|
headers: this.headers,
|
|
314
320
|
});
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.SageMakerEndpoint = exports.BaseSageMakerContentHandler = void 0;
|
|
4
|
+
const client_sagemaker_runtime_1 = require("@aws-sdk/client-sagemaker-runtime");
|
|
5
|
+
const base_js_1 = require("./base.cjs");
|
|
6
|
+
/**
|
|
7
|
+
* A handler class to transform input from LLM to a format that SageMaker
|
|
8
|
+
* endpoint expects. Similarily, the class also handles transforming output from
|
|
9
|
+
* the SageMaker endpoint to a format that LLM class expects.
|
|
10
|
+
*
|
|
11
|
+
* Example:
|
|
12
|
+
* ```
|
|
13
|
+
* class ContentHandler implements ContentHandlerBase<string, string> {
|
|
14
|
+
* contentType = "application/json"
|
|
15
|
+
* accepts = "application/json"
|
|
16
|
+
*
|
|
17
|
+
* transformInput(prompt: string, modelKwargs: Record<string, unknown>) {
|
|
18
|
+
* const inputString = JSON.stringify({
|
|
19
|
+
* prompt,
|
|
20
|
+
* ...modelKwargs
|
|
21
|
+
* })
|
|
22
|
+
* return Buffer.from(inputString)
|
|
23
|
+
* }
|
|
24
|
+
*
|
|
25
|
+
* transformOutput(output: Uint8Array) {
|
|
26
|
+
* const responseJson = JSON.parse(Buffer.from(output).toString("utf-8"))
|
|
27
|
+
* return responseJson[0].generated_text
|
|
28
|
+
* }
|
|
29
|
+
*
|
|
30
|
+
* }
|
|
31
|
+
* ```
|
|
32
|
+
*/
|
|
33
|
+
class BaseSageMakerContentHandler {
|
|
34
|
+
constructor() {
|
|
35
|
+
/** The MIME type of the input data passed to endpoint */
|
|
36
|
+
Object.defineProperty(this, "contentType", {
|
|
37
|
+
enumerable: true,
|
|
38
|
+
configurable: true,
|
|
39
|
+
writable: true,
|
|
40
|
+
value: "text/plain"
|
|
41
|
+
});
|
|
42
|
+
/** The MIME type of the response data returned from endpoint */
|
|
43
|
+
Object.defineProperty(this, "accepts", {
|
|
44
|
+
enumerable: true,
|
|
45
|
+
configurable: true,
|
|
46
|
+
writable: true,
|
|
47
|
+
value: "text/plain"
|
|
48
|
+
});
|
|
49
|
+
}
|
|
50
|
+
}
|
|
51
|
+
exports.BaseSageMakerContentHandler = BaseSageMakerContentHandler;
|
|
52
|
+
class SageMakerEndpoint extends base_js_1.LLM {
|
|
53
|
+
constructor(fields) {
|
|
54
|
+
super(fields ?? {});
|
|
55
|
+
Object.defineProperty(this, "endpointName", {
|
|
56
|
+
enumerable: true,
|
|
57
|
+
configurable: true,
|
|
58
|
+
writable: true,
|
|
59
|
+
value: void 0
|
|
60
|
+
});
|
|
61
|
+
Object.defineProperty(this, "contentHandler", {
|
|
62
|
+
enumerable: true,
|
|
63
|
+
configurable: true,
|
|
64
|
+
writable: true,
|
|
65
|
+
value: void 0
|
|
66
|
+
});
|
|
67
|
+
Object.defineProperty(this, "modelKwargs", {
|
|
68
|
+
enumerable: true,
|
|
69
|
+
configurable: true,
|
|
70
|
+
writable: true,
|
|
71
|
+
value: void 0
|
|
72
|
+
});
|
|
73
|
+
Object.defineProperty(this, "endpointKwargs", {
|
|
74
|
+
enumerable: true,
|
|
75
|
+
configurable: true,
|
|
76
|
+
writable: true,
|
|
77
|
+
value: void 0
|
|
78
|
+
});
|
|
79
|
+
Object.defineProperty(this, "client", {
|
|
80
|
+
enumerable: true,
|
|
81
|
+
configurable: true,
|
|
82
|
+
writable: true,
|
|
83
|
+
value: void 0
|
|
84
|
+
});
|
|
85
|
+
const regionName = fields.clientOptions.region;
|
|
86
|
+
if (!regionName) {
|
|
87
|
+
throw new Error(`Please pass a "clientOptions" object with a "region" field to the constructor`);
|
|
88
|
+
}
|
|
89
|
+
const endpointName = fields?.endpointName;
|
|
90
|
+
if (!endpointName) {
|
|
91
|
+
throw new Error(`Please pass an "endpointName" field to the constructor`);
|
|
92
|
+
}
|
|
93
|
+
const contentHandler = fields?.contentHandler;
|
|
94
|
+
if (!contentHandler) {
|
|
95
|
+
throw new Error(`Please pass a "contentHandler" field to the constructor`);
|
|
96
|
+
}
|
|
97
|
+
this.endpointName = fields.endpointName;
|
|
98
|
+
this.contentHandler = fields.contentHandler;
|
|
99
|
+
this.endpointKwargs = fields.endpointKwargs;
|
|
100
|
+
this.modelKwargs = fields.modelKwargs;
|
|
101
|
+
this.client = new client_sagemaker_runtime_1.SageMakerRuntimeClient(fields.clientOptions);
|
|
102
|
+
}
|
|
103
|
+
_llmType() {
|
|
104
|
+
return "sagemaker_endpoint";
|
|
105
|
+
}
|
|
106
|
+
/** @ignore */
|
|
107
|
+
async _call(prompt, options) {
|
|
108
|
+
const body = await this.contentHandler.transformInput(prompt, this.modelKwargs ?? {});
|
|
109
|
+
const { contentType, accepts } = this.contentHandler;
|
|
110
|
+
const response = await this.caller.call(() => this.client.send(new client_sagemaker_runtime_1.InvokeEndpointCommand({
|
|
111
|
+
EndpointName: this.endpointName,
|
|
112
|
+
Body: body,
|
|
113
|
+
ContentType: contentType,
|
|
114
|
+
Accept: accepts,
|
|
115
|
+
...this.endpointKwargs,
|
|
116
|
+
}), { abortSignal: options.signal }));
|
|
117
|
+
if (response.Body === undefined) {
|
|
118
|
+
throw new Error("Inference result missing Body");
|
|
119
|
+
}
|
|
120
|
+
return this.contentHandler.transformOutput(response.Body);
|
|
121
|
+
}
|
|
122
|
+
}
|
|
123
|
+
exports.SageMakerEndpoint = SageMakerEndpoint;
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import { SageMakerRuntimeClient, SageMakerRuntimeClientConfig } from "@aws-sdk/client-sagemaker-runtime";
|
|
2
|
+
import { LLM, BaseLLMParams } from "./base.js";
|
|
3
|
+
/**
|
|
4
|
+
* A handler class to transform input from LLM to a format that SageMaker
|
|
5
|
+
* endpoint expects. Similarily, the class also handles transforming output from
|
|
6
|
+
* the SageMaker endpoint to a format that LLM class expects.
|
|
7
|
+
*
|
|
8
|
+
* Example:
|
|
9
|
+
* ```
|
|
10
|
+
* class ContentHandler implements ContentHandlerBase<string, string> {
|
|
11
|
+
* contentType = "application/json"
|
|
12
|
+
* accepts = "application/json"
|
|
13
|
+
*
|
|
14
|
+
* transformInput(prompt: string, modelKwargs: Record<string, unknown>) {
|
|
15
|
+
* const inputString = JSON.stringify({
|
|
16
|
+
* prompt,
|
|
17
|
+
* ...modelKwargs
|
|
18
|
+
* })
|
|
19
|
+
* return Buffer.from(inputString)
|
|
20
|
+
* }
|
|
21
|
+
*
|
|
22
|
+
* transformOutput(output: Uint8Array) {
|
|
23
|
+
* const responseJson = JSON.parse(Buffer.from(output).toString("utf-8"))
|
|
24
|
+
* return responseJson[0].generated_text
|
|
25
|
+
* }
|
|
26
|
+
*
|
|
27
|
+
* }
|
|
28
|
+
* ```
|
|
29
|
+
*/
|
|
30
|
+
export declare abstract class BaseSageMakerContentHandler<InputType, OutputType> {
|
|
31
|
+
/** The MIME type of the input data passed to endpoint */
|
|
32
|
+
contentType: string;
|
|
33
|
+
/** The MIME type of the response data returned from endpoint */
|
|
34
|
+
accepts: string;
|
|
35
|
+
/**
|
|
36
|
+
* Transforms the input to a format that model can accept as the request Body.
|
|
37
|
+
* Should return bytes or seekable file like object in the format specified in
|
|
38
|
+
* the contentType request header.
|
|
39
|
+
*/
|
|
40
|
+
abstract transformInput(prompt: InputType, modelKwargs: Record<string, unknown>): Promise<Uint8Array>;
|
|
41
|
+
/**
|
|
42
|
+
* Transforms the output from the model to string that the LLM class expects.
|
|
43
|
+
*/
|
|
44
|
+
abstract transformOutput(output: Uint8Array): Promise<OutputType>;
|
|
45
|
+
}
|
|
46
|
+
/** Content handler for LLM class. */
|
|
47
|
+
export type SageMakerLLMContentHandler = BaseSageMakerContentHandler<string, string>;
|
|
48
|
+
export interface SageMakerEndpointInput extends BaseLLMParams {
|
|
49
|
+
/**
|
|
50
|
+
* The name of the endpoint from the deployed SageMaker model. Must be unique
|
|
51
|
+
* within an AWS Region.
|
|
52
|
+
*/
|
|
53
|
+
endpointName: string;
|
|
54
|
+
/**
|
|
55
|
+
* Options passed to the SageMaker client.
|
|
56
|
+
*/
|
|
57
|
+
clientOptions: SageMakerRuntimeClientConfig;
|
|
58
|
+
/**
|
|
59
|
+
* The content handler class that provides an input and output transform
|
|
60
|
+
* functions to handle formats between LLM and the endpoint.
|
|
61
|
+
*/
|
|
62
|
+
contentHandler: SageMakerLLMContentHandler;
|
|
63
|
+
/**
|
|
64
|
+
* Key word arguments to pass to the model.
|
|
65
|
+
*/
|
|
66
|
+
modelKwargs?: Record<string, unknown>;
|
|
67
|
+
/**
|
|
68
|
+
* Optional attributes passed to the InvokeEndpointCommand
|
|
69
|
+
*/
|
|
70
|
+
endpointKwargs?: Record<string, unknown>;
|
|
71
|
+
}
|
|
72
|
+
export declare class SageMakerEndpoint extends LLM {
|
|
73
|
+
endpointName: string;
|
|
74
|
+
contentHandler: SageMakerLLMContentHandler;
|
|
75
|
+
modelKwargs?: Record<string, unknown>;
|
|
76
|
+
endpointKwargs?: Record<string, unknown>;
|
|
77
|
+
client: SageMakerRuntimeClient;
|
|
78
|
+
constructor(fields: SageMakerEndpointInput);
|
|
79
|
+
_llmType(): string;
|
|
80
|
+
/** @ignore */
|
|
81
|
+
_call(prompt: string, options: this["ParsedCallOptions"]): Promise<string>;
|
|
82
|
+
}
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
import { SageMakerRuntimeClient, InvokeEndpointCommand, } from "@aws-sdk/client-sagemaker-runtime";
|
|
2
|
+
import { LLM } from "./base.js";
|
|
3
|
+
/**
|
|
4
|
+
* A handler class to transform input from LLM to a format that SageMaker
|
|
5
|
+
* endpoint expects. Similarily, the class also handles transforming output from
|
|
6
|
+
* the SageMaker endpoint to a format that LLM class expects.
|
|
7
|
+
*
|
|
8
|
+
* Example:
|
|
9
|
+
* ```
|
|
10
|
+
* class ContentHandler implements ContentHandlerBase<string, string> {
|
|
11
|
+
* contentType = "application/json"
|
|
12
|
+
* accepts = "application/json"
|
|
13
|
+
*
|
|
14
|
+
* transformInput(prompt: string, modelKwargs: Record<string, unknown>) {
|
|
15
|
+
* const inputString = JSON.stringify({
|
|
16
|
+
* prompt,
|
|
17
|
+
* ...modelKwargs
|
|
18
|
+
* })
|
|
19
|
+
* return Buffer.from(inputString)
|
|
20
|
+
* }
|
|
21
|
+
*
|
|
22
|
+
* transformOutput(output: Uint8Array) {
|
|
23
|
+
* const responseJson = JSON.parse(Buffer.from(output).toString("utf-8"))
|
|
24
|
+
* return responseJson[0].generated_text
|
|
25
|
+
* }
|
|
26
|
+
*
|
|
27
|
+
* }
|
|
28
|
+
* ```
|
|
29
|
+
*/
|
|
30
|
+
export class BaseSageMakerContentHandler {
|
|
31
|
+
constructor() {
|
|
32
|
+
/** The MIME type of the input data passed to endpoint */
|
|
33
|
+
Object.defineProperty(this, "contentType", {
|
|
34
|
+
enumerable: true,
|
|
35
|
+
configurable: true,
|
|
36
|
+
writable: true,
|
|
37
|
+
value: "text/plain"
|
|
38
|
+
});
|
|
39
|
+
/** The MIME type of the response data returned from endpoint */
|
|
40
|
+
Object.defineProperty(this, "accepts", {
|
|
41
|
+
enumerable: true,
|
|
42
|
+
configurable: true,
|
|
43
|
+
writable: true,
|
|
44
|
+
value: "text/plain"
|
|
45
|
+
});
|
|
46
|
+
}
|
|
47
|
+
}
|
|
48
|
+
export class SageMakerEndpoint extends LLM {
|
|
49
|
+
constructor(fields) {
|
|
50
|
+
super(fields ?? {});
|
|
51
|
+
Object.defineProperty(this, "endpointName", {
|
|
52
|
+
enumerable: true,
|
|
53
|
+
configurable: true,
|
|
54
|
+
writable: true,
|
|
55
|
+
value: void 0
|
|
56
|
+
});
|
|
57
|
+
Object.defineProperty(this, "contentHandler", {
|
|
58
|
+
enumerable: true,
|
|
59
|
+
configurable: true,
|
|
60
|
+
writable: true,
|
|
61
|
+
value: void 0
|
|
62
|
+
});
|
|
63
|
+
Object.defineProperty(this, "modelKwargs", {
|
|
64
|
+
enumerable: true,
|
|
65
|
+
configurable: true,
|
|
66
|
+
writable: true,
|
|
67
|
+
value: void 0
|
|
68
|
+
});
|
|
69
|
+
Object.defineProperty(this, "endpointKwargs", {
|
|
70
|
+
enumerable: true,
|
|
71
|
+
configurable: true,
|
|
72
|
+
writable: true,
|
|
73
|
+
value: void 0
|
|
74
|
+
});
|
|
75
|
+
Object.defineProperty(this, "client", {
|
|
76
|
+
enumerable: true,
|
|
77
|
+
configurable: true,
|
|
78
|
+
writable: true,
|
|
79
|
+
value: void 0
|
|
80
|
+
});
|
|
81
|
+
const regionName = fields.clientOptions.region;
|
|
82
|
+
if (!regionName) {
|
|
83
|
+
throw new Error(`Please pass a "clientOptions" object with a "region" field to the constructor`);
|
|
84
|
+
}
|
|
85
|
+
const endpointName = fields?.endpointName;
|
|
86
|
+
if (!endpointName) {
|
|
87
|
+
throw new Error(`Please pass an "endpointName" field to the constructor`);
|
|
88
|
+
}
|
|
89
|
+
const contentHandler = fields?.contentHandler;
|
|
90
|
+
if (!contentHandler) {
|
|
91
|
+
throw new Error(`Please pass a "contentHandler" field to the constructor`);
|
|
92
|
+
}
|
|
93
|
+
this.endpointName = fields.endpointName;
|
|
94
|
+
this.contentHandler = fields.contentHandler;
|
|
95
|
+
this.endpointKwargs = fields.endpointKwargs;
|
|
96
|
+
this.modelKwargs = fields.modelKwargs;
|
|
97
|
+
this.client = new SageMakerRuntimeClient(fields.clientOptions);
|
|
98
|
+
}
|
|
99
|
+
_llmType() {
|
|
100
|
+
return "sagemaker_endpoint";
|
|
101
|
+
}
|
|
102
|
+
/** @ignore */
|
|
103
|
+
async _call(prompt, options) {
|
|
104
|
+
const body = await this.contentHandler.transformInput(prompt, this.modelKwargs ?? {});
|
|
105
|
+
const { contentType, accepts } = this.contentHandler;
|
|
106
|
+
const response = await this.caller.call(() => this.client.send(new InvokeEndpointCommand({
|
|
107
|
+
EndpointName: this.endpointName,
|
|
108
|
+
Body: body,
|
|
109
|
+
ContentType: contentType,
|
|
110
|
+
Accept: accepts,
|
|
111
|
+
...this.endpointKwargs,
|
|
112
|
+
}), { abortSignal: options.signal }));
|
|
113
|
+
if (response.Body === undefined) {
|
|
114
|
+
throw new Error("Inference result missing Body");
|
|
115
|
+
}
|
|
116
|
+
return this.contentHandler.transformOutput(response.Body);
|
|
117
|
+
}
|
|
118
|
+
}
|
|
@@ -45,7 +45,7 @@ class BufferMemory extends chat_memory_js_1.BaseChatMemory {
|
|
|
45
45
|
return result;
|
|
46
46
|
}
|
|
47
47
|
const result = {
|
|
48
|
-
[this.memoryKey]: (0, base_js_1.getBufferString)(messages),
|
|
48
|
+
[this.memoryKey]: (0, base_js_1.getBufferString)(messages, this.humanPrefix, this.aiPrefix),
|
|
49
49
|
};
|
|
50
50
|
return result;
|
|
51
51
|
}
|
|
@@ -52,7 +52,7 @@ class BufferWindowMemory extends chat_memory_js_1.BaseChatMemory {
|
|
|
52
52
|
return result;
|
|
53
53
|
}
|
|
54
54
|
const result = {
|
|
55
|
-
[this.memoryKey]: (0, base_js_1.getBufferString)(messages.slice(-this.k * 2)),
|
|
55
|
+
[this.memoryKey]: (0, base_js_1.getBufferString)(messages.slice(-this.k * 2), this.humanPrefix, this.aiPrefix),
|
|
56
56
|
};
|
|
57
57
|
return result;
|
|
58
58
|
}
|
|
@@ -49,7 +49,7 @@ export class BufferWindowMemory extends BaseChatMemory {
|
|
|
49
49
|
return result;
|
|
50
50
|
}
|
|
51
51
|
const result = {
|
|
52
|
-
[this.memoryKey]: getBufferString(messages.slice(-this.k * 2)),
|
|
52
|
+
[this.memoryKey]: getBufferString(messages.slice(-this.k * 2), this.humanPrefix, this.aiPrefix),
|
|
53
53
|
};
|
|
54
54
|
return result;
|
|
55
55
|
}
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
// TypeScript-emitted CommonJS interop helpers. Each one reuses a helper that
// already exists on `this` (when `this` is truthy and has it) and otherwise
// installs the local implementation below.

/**
 * Re-exports `mod[key]` on `target` under `alias` (defaults to `key`).
 *
 * If `Object.create` is available, the source descriptor is copied as-is when
 * it is an accessor on a module flagged `__esModule`, or a data property that
 * is neither writable nor configurable. Every other case, including a missing
 * property, gets an enumerable getter that reads `mod[key]` on access.
 * Without `Object.create`, the current value is simply assigned.
 */
var __createBinding = (this && this.__createBinding) || (Object.create
    ? function (target, mod, key, alias) {
        var exportName = alias === undefined ? key : alias;
        var descriptor = Object.getOwnPropertyDescriptor(mod, key);
        var useLiveGetter;
        if (descriptor === undefined) {
            useLiveGetter = true;
        } else if ("get" in descriptor) {
            useLiveGetter = !mod.__esModule;
        } else {
            useLiveGetter = descriptor.writable || descriptor.configurable;
        }
        if (useLiveGetter) {
            descriptor = { enumerable: true, get: function () { return mod[key]; } };
        }
        Object.defineProperty(target, exportName, descriptor);
    }
    : function (target, mod, key, alias) {
        target[alias === undefined ? key : alias] = mod[key];
    });

/**
 * Binds every enumerable string key of `mod` (own or inherited, as visited
 * by `for...in`) onto `target` via `__createBinding`. It skips `default` and
 * any key `target` already owns.
 */
var __exportStar = (this && this.__exportStar) || function (mod, target) {
    for (var name in mod) {
        var isDefault = name === "default";
        if (isDefault || Object.prototype.hasOwnProperty.call(target, name)) {
            continue;
        }
        __createBinding(target, mod, name);
    }
};
|
|
16
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
17
|
+
exports.MasterHandler = exports.ExpressionParser = void 0;
|
|
18
|
+
const factory_js_1 = require("./expression_type_handlers/factory.cjs");
|
|
19
|
+
const output_parser_js_1 = require("../schema/output_parser.cjs");
|
|
20
|
+
const base_js_1 = require("./expression_type_handlers/base.cjs");
|
|
21
|
+
/**
 * Output parser for model output that is a single JavaScript call
 * expression. It parses the text into an AST and passes the call-expression
 * node to the handler tree returned by `MasterHandler.createMasterHandler()`.
 * Whatever that handler produces is returned.
 *
 * Node shapes the handlers are expected to support:
 *   ExpressionStatement
 *     CallExpression
 *       Identifier | MemberExpression
 *       ExpressionLiterals, one of:
 *         CallExpression
 *         StringLiteral
 *         NumericLiteral
 *         ArrayLiteralExpression (elements: ExpressionLiterals)
 *         ObjectLiteralExpression
 *           PropertyAssignment
 *             Identifier
 *             ExpressionLiterals
 */
class ExpressionParser extends output_parser_js_1.BaseOutputParser {
    /**
     * Parses `text` as exactly one expression statement whose expression is a
     * call, and returns the master handler's result for that call node.
     *
     * @param {string} text - Raw model output.
     * @returns {Promise<*>} Result of `masterHandler.handle(callExpression)`.
     * @throws {Error} `Error parsing <reason>: <text>` when the text fails to
     *   parse, does not contain exactly one statement, that statement is not
     *   an ExpressionStatement, its expression is not a CallExpression, or the
     *   handler throws. The underlying error is attached as `cause`.
     */
    async parse(text) {
        const parse = await base_js_1.ASTParser.importASTParser();
        try {
            const program = parse(text);
            // Reject empty programs too: with zero statements `node` would be
            // undefined and the checks below would fail with an opaque TypeError.
            if (program.body.length !== 1) {
                throw new Error(`Expected 1 statement, got ${program.body.length}`);
            }
            const [node] = program.body;
            if (!base_js_1.ASTParser.isExpressionStatement(node)) {
                throw new Error(`Expected ExpressionStatement, got ${node.type}`);
            }
            const { expression: callExpression } = node;
            if (!base_js_1.ASTParser.isCallExpression(callExpression)) {
                throw new Error("Expected CallExpression");
            }
            const masterHandler = factory_js_1.MasterHandler.createMasterHandler();
            return await masterHandler.handle(callExpression);
        }
        catch (err) {
            // Same message as before, but keep the original error (and its
            // stack) reachable for debugging.
            throw new Error(`Error parsing ${err}: ${text}`, { cause: err });
        }
    }
    /**
     * This parser adds no format instructions to the prompt.
     * @returns {string} Always the empty string.
     */
    getFormatInstructions() {
        return "";
    }
}
|
|
65
|
+
exports.ExpressionParser = ExpressionParser;
|
|
66
|
+
__exportStar(require("./expression_type_handlers/types.cjs"), exports);
|
|
67
|
+
var factory_js_2 = require("./expression_type_handlers/factory.cjs");
|
|
68
|
+
Object.defineProperty(exports, "MasterHandler", { enumerable: true, get: function () { return factory_js_2.MasterHandler; } });
|