langchain 0.1.9 → 0.1.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/agents/toolkits/sql/index.cjs +4 -1
- package/dist/agents/toolkits/sql/index.d.ts +1 -0
- package/dist/agents/toolkits/sql/index.js +1 -0
- package/dist/chains/combine_documents/base.cjs +1 -1
- package/dist/chains/combine_documents/base.js +1 -1
- package/dist/chains/openai_functions/base.d.ts +3 -3
- package/dist/chains/openai_functions/structured_output.cjs +2 -2
- package/dist/chains/openai_functions/structured_output.js +2 -2
- package/dist/chains/sql_db/index.cjs +3 -1
- package/dist/chains/sql_db/index.d.ts +2 -2
- package/dist/chains/sql_db/index.js +2 -2
- package/dist/chains/sql_db/sql_db_chain.cjs +82 -1
- package/dist/chains/sql_db/sql_db_chain.d.ts +41 -2
- package/dist/chains/sql_db/sql_db_chain.js +81 -1
- package/dist/chains/sql_db/sql_db_prompt.cjs +9 -1
- package/dist/chains/sql_db/sql_db_prompt.d.ts +3 -1
- package/dist/chains/sql_db/sql_db_prompt.js +8 -0
- package/package.json +1 -1
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.createSqlAgent = exports.SqlToolkit = void 0;
|
|
3
|
+
exports.SQL_SUFFIX = exports.SQL_PREFIX = exports.createSqlAgent = exports.SqlToolkit = void 0;
|
|
4
4
|
var sql_js_1 = require("./sql.cjs");
|
|
5
5
|
Object.defineProperty(exports, "SqlToolkit", { enumerable: true, get: function () { return sql_js_1.SqlToolkit; } });
|
|
6
6
|
Object.defineProperty(exports, "createSqlAgent", { enumerable: true, get: function () { return sql_js_1.createSqlAgent; } });
|
|
7
|
+
var prompt_js_1 = require("./prompt.cjs");
|
|
8
|
+
Object.defineProperty(exports, "SQL_PREFIX", { enumerable: true, get: function () { return prompt_js_1.SQL_PREFIX; } });
|
|
9
|
+
Object.defineProperty(exports, "SQL_SUFFIX", { enumerable: true, get: function () { return prompt_js_1.SQL_SUFFIX; } });
|
|
@@ -10,7 +10,7 @@ exports.DEFAULT_DOCUMENT_PROMPT =
|
|
|
10
10
|
async function formatDocuments({ documentPrompt, documentSeparator, documents, config, }) {
|
|
11
11
|
const formattedDocs = await Promise.all(documents.map((document) => documentPrompt
|
|
12
12
|
.withConfig({ runName: "document_formatter" })
|
|
13
|
-
.invoke({ page_content: document.pageContent }, config)));
|
|
13
|
+
.invoke({ ...document.metadata, page_content: document.pageContent }, config)));
|
|
14
14
|
return formattedDocs.join(documentSeparator);
|
|
15
15
|
}
|
|
16
16
|
exports.formatDocuments = formatDocuments;
|
|
@@ -7,6 +7,6 @@ export const DEFAULT_DOCUMENT_PROMPT =
|
|
|
7
7
|
export async function formatDocuments({ documentPrompt, documentSeparator, documents, config, }) {
|
|
8
8
|
const formattedDocs = await Promise.all(documents.map((document) => documentPrompt
|
|
9
9
|
.withConfig({ runName: "document_formatter" })
|
|
10
|
-
.invoke({ page_content: document.pageContent }, config)));
|
|
10
|
+
.invoke({ ...document.metadata, page_content: document.pageContent }, config)));
|
|
11
11
|
return formattedDocs.join(documentSeparator);
|
|
12
12
|
}
|
|
@@ -2,7 +2,7 @@ import type { z } from "zod";
|
|
|
2
2
|
import { JsonSchema7Type } from "zod-to-json-schema";
|
|
3
3
|
import type { BaseOutputParser } from "@langchain/core/output_parsers";
|
|
4
4
|
import type { BasePromptTemplate } from "@langchain/core/prompts";
|
|
5
|
-
import type { RunnableInterface } from "@langchain/core/runnables";
|
|
5
|
+
import type { Runnable, RunnableInterface } from "@langchain/core/runnables";
|
|
6
6
|
import type { BaseFunctionCallOptions, BaseLanguageModelInput, FunctionDefinition } from "@langchain/core/language_models/base";
|
|
7
7
|
import type { InputValues } from "@langchain/core/utils/types";
|
|
8
8
|
import type { BaseMessage } from "@langchain/core/messages";
|
|
@@ -77,7 +77,7 @@ export type CreateOpenAIFnRunnableConfig<RunInput extends Record<string, any>, R
|
|
|
77
77
|
* // { name: 'John Doe', age: 30, fav_food: 'chocolate chip cookies' }
|
|
78
78
|
* ```
|
|
79
79
|
*/
|
|
80
|
-
export declare function createOpenAIFnRunnable<RunInput extends Record<string, any> = Record<string, any>, RunOutput extends Record<string, any> = Record<string, any>>(config: CreateOpenAIFnRunnableConfig<RunInput, RunOutput>):
|
|
80
|
+
export declare function createOpenAIFnRunnable<RunInput extends Record<string, any> = Record<string, any>, RunOutput extends Record<string, any> = Record<string, any>>(config: CreateOpenAIFnRunnableConfig<RunInput, RunOutput>): Runnable<RunInput, RunOutput>;
|
|
81
81
|
/**
|
|
82
82
|
* Configuration params for the createStructuredOutputRunnable method.
|
|
83
83
|
*/
|
|
@@ -150,4 +150,4 @@ export type CreateStructuredOutputRunnableConfig<RunInput extends Record<string,
|
|
|
150
150
|
* // { name: 'John Doe', age: 30, fav_food: 'chocolate chip cookies' }
|
|
151
151
|
* ```
|
|
152
152
|
*/
|
|
153
|
-
export declare function createStructuredOutputRunnable<RunInput extends Record<string, any> = Record<string, any>, RunOutput extends Record<string, any> = Record<string, any>>(config: CreateStructuredOutputRunnableConfig<RunInput, RunOutput>):
|
|
153
|
+
export declare function createStructuredOutputRunnable<RunInput extends Record<string, any> = Record<string, any>, RunOutput extends Record<string, any> = Record<string, any>>(config: CreateStructuredOutputRunnableConfig<RunInput, RunOutput>): Runnable<RunInput, RunOutput>;
|
|
@@ -25,7 +25,7 @@ class FunctionCallStructuredOutputParser extends output_parsers_1.BaseLLMOutputP
|
|
|
25
25
|
fields = fieldsOrSchema;
|
|
26
26
|
}
|
|
27
27
|
if (fields.jsonSchema === undefined && fields.zodSchema === undefined) {
|
|
28
|
-
throw new Error(`Must provide one of "jsonSchema" or "zodSchema".`);
|
|
28
|
+
throw new Error(`Must provide at least one of "jsonSchema" or "zodSchema".`);
|
|
29
29
|
}
|
|
30
30
|
super(fields);
|
|
31
31
|
Object.defineProperty(this, "lc_namespace", {
|
|
@@ -55,7 +55,7 @@ class FunctionCallStructuredOutputParser extends output_parsers_1.BaseLLMOutputP
|
|
|
55
55
|
if (fields.jsonSchema !== undefined) {
|
|
56
56
|
this.jsonSchemaValidator = new json_schema_1.Validator(fields.jsonSchema, "7");
|
|
57
57
|
}
|
|
58
|
-
|
|
58
|
+
if (fields.zodSchema !== undefined) {
|
|
59
59
|
this.zodSchema = fields.zodSchema;
|
|
60
60
|
}
|
|
61
61
|
}
|
|
@@ -22,7 +22,7 @@ export class FunctionCallStructuredOutputParser extends BaseLLMOutputParser {
|
|
|
22
22
|
fields = fieldsOrSchema;
|
|
23
23
|
}
|
|
24
24
|
if (fields.jsonSchema === undefined && fields.zodSchema === undefined) {
|
|
25
|
-
throw new Error(`Must provide one of "jsonSchema" or "zodSchema".`);
|
|
25
|
+
throw new Error(`Must provide at least one of "jsonSchema" or "zodSchema".`);
|
|
26
26
|
}
|
|
27
27
|
super(fields);
|
|
28
28
|
Object.defineProperty(this, "lc_namespace", {
|
|
@@ -52,7 +52,7 @@ export class FunctionCallStructuredOutputParser extends BaseLLMOutputParser {
|
|
|
52
52
|
if (fields.jsonSchema !== undefined) {
|
|
53
53
|
this.jsonSchemaValidator = new Validator(fields.jsonSchema, "7");
|
|
54
54
|
}
|
|
55
|
-
|
|
55
|
+
if (fields.zodSchema !== undefined) {
|
|
56
56
|
this.zodSchema = fields.zodSchema;
|
|
57
57
|
}
|
|
58
58
|
}
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.SQL_SAP_HANA_PROMPT = exports.SQL_MYSQL_PROMPT = exports.SQL_MSSQL_PROMPT = exports.SQL_SQLITE_PROMPT = exports.SQL_POSTGRES_PROMPT = exports.DEFAULT_SQL_DATABASE_PROMPT = exports.SqlDatabaseChain = void 0;
|
|
3
|
+
exports.SQL_PROMPTS_MAP = exports.SQL_SAP_HANA_PROMPT = exports.SQL_MYSQL_PROMPT = exports.SQL_MSSQL_PROMPT = exports.SQL_SQLITE_PROMPT = exports.SQL_POSTGRES_PROMPT = exports.DEFAULT_SQL_DATABASE_PROMPT = exports.createSqlQueryChain = exports.SqlDatabaseChain = void 0;
|
|
4
4
|
var sql_db_chain_js_1 = require("./sql_db_chain.cjs");
|
|
5
5
|
Object.defineProperty(exports, "SqlDatabaseChain", { enumerable: true, get: function () { return sql_db_chain_js_1.SqlDatabaseChain; } });
|
|
6
|
+
Object.defineProperty(exports, "createSqlQueryChain", { enumerable: true, get: function () { return sql_db_chain_js_1.createSqlQueryChain; } });
|
|
6
7
|
var sql_db_prompt_js_1 = require("./sql_db_prompt.cjs");
|
|
7
8
|
Object.defineProperty(exports, "DEFAULT_SQL_DATABASE_PROMPT", { enumerable: true, get: function () { return sql_db_prompt_js_1.DEFAULT_SQL_DATABASE_PROMPT; } });
|
|
8
9
|
Object.defineProperty(exports, "SQL_POSTGRES_PROMPT", { enumerable: true, get: function () { return sql_db_prompt_js_1.SQL_POSTGRES_PROMPT; } });
|
|
@@ -10,3 +11,4 @@ Object.defineProperty(exports, "SQL_SQLITE_PROMPT", { enumerable: true, get: fun
|
|
|
10
11
|
Object.defineProperty(exports, "SQL_MSSQL_PROMPT", { enumerable: true, get: function () { return sql_db_prompt_js_1.SQL_MSSQL_PROMPT; } });
|
|
11
12
|
Object.defineProperty(exports, "SQL_MYSQL_PROMPT", { enumerable: true, get: function () { return sql_db_prompt_js_1.SQL_MYSQL_PROMPT; } });
|
|
12
13
|
Object.defineProperty(exports, "SQL_SAP_HANA_PROMPT", { enumerable: true, get: function () { return sql_db_prompt_js_1.SQL_SAP_HANA_PROMPT; } });
|
|
14
|
+
Object.defineProperty(exports, "SQL_PROMPTS_MAP", { enumerable: true, get: function () { return sql_db_prompt_js_1.SQL_PROMPTS_MAP; } });
|
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
export { SqlDatabaseChain, SqlDatabaseChainInput } from "./sql_db_chain.js";
|
|
2
|
-
export { DEFAULT_SQL_DATABASE_PROMPT, SQL_POSTGRES_PROMPT, SQL_SQLITE_PROMPT, SQL_MSSQL_PROMPT, SQL_MYSQL_PROMPT, SQL_SAP_HANA_PROMPT, } from "./sql_db_prompt.js";
|
|
1
|
+
export { SqlDatabaseChain, type SqlDatabaseChainInput, type CreateSqlQueryChainFields, createSqlQueryChain, } from "./sql_db_chain.js";
|
|
2
|
+
export { DEFAULT_SQL_DATABASE_PROMPT, SQL_POSTGRES_PROMPT, SQL_SQLITE_PROMPT, SQL_MSSQL_PROMPT, SQL_MYSQL_PROMPT, SQL_SAP_HANA_PROMPT, SQL_PROMPTS_MAP, } from "./sql_db_prompt.js";
|
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
export { SqlDatabaseChain } from "./sql_db_chain.js";
|
|
2
|
-
export { DEFAULT_SQL_DATABASE_PROMPT, SQL_POSTGRES_PROMPT, SQL_SQLITE_PROMPT, SQL_MSSQL_PROMPT, SQL_MYSQL_PROMPT, SQL_SAP_HANA_PROMPT, } from "./sql_db_prompt.js";
|
|
1
|
+
export { SqlDatabaseChain, createSqlQueryChain, } from "./sql_db_chain.js";
|
|
2
|
+
export { DEFAULT_SQL_DATABASE_PROMPT, SQL_POSTGRES_PROMPT, SQL_SQLITE_PROMPT, SQL_MSSQL_PROMPT, SQL_MYSQL_PROMPT, SQL_SAP_HANA_PROMPT, SQL_PROMPTS_MAP, } from "./sql_db_prompt.js";
|
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.SqlDatabaseChain = void 0;
|
|
3
|
+
exports.createSqlQueryChain = exports.SqlDatabaseChain = void 0;
|
|
4
4
|
const base_1 = require("@langchain/core/language_models/base");
|
|
5
|
+
const runnables_1 = require("@langchain/core/runnables");
|
|
6
|
+
const output_parsers_1 = require("@langchain/core/output_parsers");
|
|
5
7
|
const sql_db_prompt_js_1 = require("./sql_db_prompt.cjs");
|
|
6
8
|
const base_js_1 = require("../base.cjs");
|
|
7
9
|
const llm_chain_js_1 = require("../llm_chain.cjs");
|
|
@@ -187,3 +189,82 @@ class SqlDatabaseChain extends base_js_1.BaseChain {
|
|
|
187
189
|
}
|
|
188
190
|
}
|
|
189
191
|
exports.SqlDatabaseChain = SqlDatabaseChain;
|
|
192
|
+
/**
 * Normalizes raw model output into a bare SQL string: unescapes `\"`
 * sequences, trims surrounding whitespace, then removes one pair of
 * double quotes if the whole result is wrapped in them.
 */
const strip = (text) => {
    const unescaped = text.replace(/\\"/g, '"').trim();
    const isWrapped = unescaped.startsWith('"') && unescaped.endsWith('"');
    // `substring` (not `slice`) keeps a lone `"` intact, because substring swaps start/end.
    return isWrapped ? unescaped.substring(1, unescaped.length - 1) : unescaped;
};
|
|
201
|
+
/**
 * Returns a new Set with the members of `setA` that are not in `setB`,
 * preserving `setA`'s iteration order.
 */
const difference = (setA, setB) => {
    const onlyInA = new Set();
    for (const member of setA) {
        if (!setB.has(member)) {
            onlyInA.add(member);
        }
    }
    return onlyInA;
};
|
|
202
|
+
/**
 * Create a SQL query chain that can create SQL queries for the given database.
 * Returns a Runnable.
 *
 * The returned runnable expects an input object with a `question` property and an
 * optional `tableNamesToUse` value, which is forwarded to `db.getTableInfo`. It
 * resolves to the generated SQL query string.
 *
 * @param {BaseLanguageModel} llm The language model to use in the chain.
 * @param {SqlDatabase} db The database to use in the chain.
 * @param {BasePromptTemplate | undefined} prompt The prompt to use in the chain. It must declare the
 *   input variables `input`, `top_k` and `table_info`. If it declares `dialect`, that variable is filled from `dialect`.
 * @param {number | undefined} k The amount of docs/results to return. Passed through the prompt input value `top_k`. Defaults to 5.
 * @param {SqlDialect} dialect The SQL dialect to use in the chain. Selects a built-in prompt when `prompt` is not given.
 * @returns {Promise<RunnableSequence<Record<string, unknown>, string>>} A runnable sequence representing the chain.
 * @throws {Error} If the selected prompt is missing any of the required input variables.
 * @example ```typescript
 * const datasource = new DataSource({
 *   type: "sqlite",
 *   database: "../../../../Chinook.db",
 * });
 * const db = await SqlDatabase.fromDataSourceParams({
 *   appDataSource: datasource,
 * });
 * const llm = new ChatOpenAI({ temperature: 0 });
 * const chain = await createSqlQueryChain({
 *   llm,
 *   db,
 *   dialect: "sqlite",
 * });
 * ```
 */
async function createSqlQueryChain({ llm, db, prompt, k = 5, dialect, }) {
    let promptToUse;
    if (prompt) {
        promptToUse = prompt;
    }
    // Own-property check: a dialect such as "constructor" or "toString" must not
    // resolve to an Object.prototype member instead of falling back to the default prompt.
    else if (Object.prototype.hasOwnProperty.call(sql_db_prompt_js_1.SQL_PROMPTS_MAP, dialect) &&
        sql_db_prompt_js_1.SQL_PROMPTS_MAP[dialect]) {
        promptToUse = sql_db_prompt_js_1.SQL_PROMPTS_MAP[dialect];
    }
    else {
        promptToUse = sql_db_prompt_js_1.DEFAULT_SQL_DATABASE_PROMPT;
    }
    const missingVariables = difference(new Set(["input", "top_k", "table_info"]), new Set(promptToUse.inputVariables));
    if (missingVariables.size > 0) {
        // Interpolating the prompt object directly would typically print "[object Object]"
        // (no custom toString is visible here), so serialize it for the error message.
        let promptDescription;
        try {
            promptDescription = JSON.stringify(promptToUse, null, 2);
        }
        catch {
            promptDescription = String(promptToUse);
        }
        throw new Error(`Prompt must have input variables: 'input', 'top_k', 'table_info'. ` +
            `Missing input variables: ${[...missingVariables].join(", ")}. ` +
            `Received prompt with input variables: ${promptToUse.inputVariables}. Full prompt:\n\n${promptDescription}`);
    }
    if (promptToUse.inputVariables.includes("dialect")) {
        promptToUse = await promptToUse.partial({ dialect });
    }
    promptToUse = await promptToUse.partial({ top_k: k.toString() });
    const inputs = {
        input: (x) => {
            // Guard before `in`: it throws a TypeError on primitives and null.
            if (x !== null && typeof x === "object" && "question" in x) {
                return `${x.question}\nSQLQuery: `;
            }
            throw new Error("Input must include a question property.");
        },
        table_info: async (x) => db.getTableInfo(x.tableNamesToUse),
    };
    return runnables_1.RunnableSequence.from([
        runnables_1.RunnablePassthrough.assign(inputs),
        // Drop the raw caller keys so only prompt variables reach the template.
        (x) => {
            const newInputs = { ...x };
            delete newInputs.question;
            delete newInputs.tableNamesToUse;
            return newInputs;
        },
        promptToUse,
        // Stop before the model starts inventing query results.
        llm.bind({ stop: ["\nSQLResult:"] }),
        new output_parsers_1.StringOutputParser(),
        strip,
    ]);
}
exports.createSqlQueryChain = createSqlQueryChain;
|
|
@@ -1,7 +1,9 @@
|
|
|
1
|
-
import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base";
|
|
1
|
+
import type { BaseLanguageModel, BaseLanguageModelInterface } from "@langchain/core/language_models/base";
|
|
2
2
|
import { ChainValues } from "@langchain/core/utils/types";
|
|
3
|
-
import { PromptTemplate } from "@langchain/core/prompts";
|
|
3
|
+
import { BasePromptTemplate, PromptTemplate } from "@langchain/core/prompts";
|
|
4
4
|
import { CallbackManagerForChainRun } from "@langchain/core/callbacks/manager";
|
|
5
|
+
import { RunnableSequence } from "@langchain/core/runnables";
|
|
6
|
+
import { SqlDialect } from "./sql_db_prompt.js";
|
|
5
7
|
import { BaseChain, ChainInputs } from "../base.js";
|
|
6
8
|
import type { SqlDatabase } from "../../sql_db.js";
|
|
7
9
|
/**
|
|
@@ -74,3 +76,40 @@ export declare class SqlDatabaseChain extends BaseChain {
|
|
|
74
76
|
*/
|
|
75
77
|
private verifyNumberOfTokens;
|
|
76
78
|
}
|
|
79
|
+
/** Options accepted by `createSqlQueryChain`. */
export interface CreateSqlQueryChainFields {
    /** Language model that writes the SQL query. */
    llm: BaseLanguageModel;
    /** Database whose table info (via `getTableInfo`) is injected into the prompt. */
    db: SqlDatabase;
    /**
     * SQL dialect. Picks a built-in prompt when `prompt` is omitted, and fills the
     * prompt's `dialect` variable when the prompt declares one.
     */
    dialect: SqlDialect;
    /** Custom prompt; must declare the `input`, `top_k` and `table_info` variables. */
    prompt?: BasePromptTemplate;
    /**
     * The amount of docs/results to return, passed to the prompt as `top_k`.
     * @default 5
     */
    k?: number;
}
|
|
89
|
+
/**
 * Create a SQL query chain that can create SQL queries for the given database.
 * Returns a Runnable.
 *
 * The returned runnable expects an input object with a `question` property and an
 * optional `tableNamesToUse` value, which is forwarded to `SqlDatabase.getTableInfo`.
 * It resolves to the generated SQL query string.
 *
 * @param {BaseLanguageModel} llm The language model to use in the chain.
 * @param {SqlDatabase} db The database to use in the chain.
 * @param {BasePromptTemplate | undefined} prompt The prompt to use in the chain. It must declare the
 *   input variables `input`, `top_k` and `table_info`. If it declares `dialect`, that variable is filled from `dialect`.
 * @param {number | undefined} k The amount of docs/results to return. Passed through the prompt input value `top_k`. Defaults to 5.
 * @param {SqlDialect} dialect The SQL dialect to use in the chain. Selects a built-in prompt when `prompt` is not given.
 * @returns {Promise<RunnableSequence<Record<string, unknown>, string>>} A runnable sequence representing the chain.
 * @throws {Error} If the selected prompt is missing any of the required input variables.
 * @example ```typescript
 * const datasource = new DataSource({
 *   type: "sqlite",
 *   database: "../../../../Chinook.db",
 * });
 * const db = await SqlDatabase.fromDataSourceParams({
 *   appDataSource: datasource,
 * });
 * const llm = new ChatOpenAI({ temperature: 0 });
 * const chain = await createSqlQueryChain({
 *   llm,
 *   db,
 *   dialect: "sqlite",
 * });
 * ```
 */
export declare function createSqlQueryChain({ llm, db, prompt, k, dialect, }: CreateSqlQueryChainFields): Promise<RunnableSequence<Record<string, unknown>, string>>;
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import { calculateMaxTokens, getModelContextSize, } from "@langchain/core/language_models/base";
|
|
2
|
-
import {
|
|
2
|
+
import { RunnablePassthrough, RunnableSequence, } from "@langchain/core/runnables";
|
|
3
|
+
import { StringOutputParser } from "@langchain/core/output_parsers";
|
|
4
|
+
import { DEFAULT_SQL_DATABASE_PROMPT, SQL_PROMPTS_MAP, } from "./sql_db_prompt.js";
|
|
3
5
|
import { BaseChain } from "../base.js";
|
|
4
6
|
import { LLMChain } from "../llm_chain.js";
|
|
5
7
|
import { getPromptTemplateFromDataSource } from "../../util/sql_utils.js";
|
|
@@ -183,3 +185,81 @@ export class SqlDatabaseChain extends BaseChain {
|
|
|
183
185
|
}
|
|
184
186
|
}
|
|
185
187
|
}
|
|
188
|
+
/**
 * Normalizes raw model output into a bare SQL string: unescapes `\"`
 * sequences, trims surrounding whitespace, then removes one pair of
 * double quotes if the whole result is wrapped in them.
 */
const strip = (text) => {
    const unescaped = text.replace(/\\"/g, '"').trim();
    const isWrapped = unescaped.startsWith('"') && unescaped.endsWith('"');
    // `substring` (not `slice`) keeps a lone `"` intact, because substring swaps start/end.
    return isWrapped ? unescaped.substring(1, unescaped.length - 1) : unescaped;
};
|
|
197
|
+
/**
 * Returns a new Set with the members of `setA` that are not in `setB`,
 * preserving `setA`'s iteration order.
 */
const difference = (setA, setB) => {
    const onlyInA = new Set();
    for (const member of setA) {
        if (!setB.has(member)) {
            onlyInA.add(member);
        }
    }
    return onlyInA;
};
|
|
198
|
+
/**
 * Create a SQL query chain that can create SQL queries for the given database.
 * Returns a Runnable.
 *
 * The returned runnable expects an input object with a `question` property and an
 * optional `tableNamesToUse` value, which is forwarded to `db.getTableInfo`. It
 * resolves to the generated SQL query string.
 *
 * @param {BaseLanguageModel} llm The language model to use in the chain.
 * @param {SqlDatabase} db The database to use in the chain.
 * @param {BasePromptTemplate | undefined} prompt The prompt to use in the chain. It must declare the
 *   input variables `input`, `top_k` and `table_info`. If it declares `dialect`, that variable is filled from `dialect`.
 * @param {number | undefined} k The amount of docs/results to return. Passed through the prompt input value `top_k`. Defaults to 5.
 * @param {SqlDialect} dialect The SQL dialect to use in the chain. Selects a built-in prompt when `prompt` is not given.
 * @returns {Promise<RunnableSequence<Record<string, unknown>, string>>} A runnable sequence representing the chain.
 * @throws {Error} If the selected prompt is missing any of the required input variables.
 * @example ```typescript
 * const datasource = new DataSource({
 *   type: "sqlite",
 *   database: "../../../../Chinook.db",
 * });
 * const db = await SqlDatabase.fromDataSourceParams({
 *   appDataSource: datasource,
 * });
 * const llm = new ChatOpenAI({ temperature: 0 });
 * const chain = await createSqlQueryChain({
 *   llm,
 *   db,
 *   dialect: "sqlite",
 * });
 * ```
 */
export async function createSqlQueryChain({ llm, db, prompt, k = 5, dialect, }) {
    let promptToUse;
    if (prompt) {
        promptToUse = prompt;
    }
    // Own-property check: a dialect such as "constructor" or "toString" must not
    // resolve to an Object.prototype member instead of falling back to the default prompt.
    else if (Object.prototype.hasOwnProperty.call(SQL_PROMPTS_MAP, dialect) &&
        SQL_PROMPTS_MAP[dialect]) {
        promptToUse = SQL_PROMPTS_MAP[dialect];
    }
    else {
        promptToUse = DEFAULT_SQL_DATABASE_PROMPT;
    }
    const missingVariables = difference(new Set(["input", "top_k", "table_info"]), new Set(promptToUse.inputVariables));
    if (missingVariables.size > 0) {
        // Interpolating the prompt object directly would typically print "[object Object]"
        // (no custom toString is visible here), so serialize it for the error message.
        let promptDescription;
        try {
            promptDescription = JSON.stringify(promptToUse, null, 2);
        }
        catch {
            promptDescription = String(promptToUse);
        }
        throw new Error(`Prompt must have input variables: 'input', 'top_k', 'table_info'. ` +
            `Missing input variables: ${[...missingVariables].join(", ")}. ` +
            `Received prompt with input variables: ${promptToUse.inputVariables}. Full prompt:\n\n${promptDescription}`);
    }
    if (promptToUse.inputVariables.includes("dialect")) {
        promptToUse = await promptToUse.partial({ dialect });
    }
    promptToUse = await promptToUse.partial({ top_k: k.toString() });
    const inputs = {
        input: (x) => {
            // Guard before `in`: it throws a TypeError on primitives and null.
            if (x !== null && typeof x === "object" && "question" in x) {
                return `${x.question}\nSQLQuery: `;
            }
            throw new Error("Input must include a question property.");
        },
        table_info: async (x) => db.getTableInfo(x.tableNamesToUse),
    };
    return RunnableSequence.from([
        RunnablePassthrough.assign(inputs),
        // Drop the raw caller keys so only prompt variables reach the template.
        (x) => {
            const newInputs = { ...x };
            delete newInputs.question;
            delete newInputs.tableNamesToUse;
            return newInputs;
        },
        promptToUse,
        // Stop before the model starts inventing query results.
        llm.bind({ stop: ["\nSQLResult:"] }),
        new StringOutputParser(),
        strip,
    ]);
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.SQL_ORACLE_PROMPT = exports.SQL_SAP_HANA_PROMPT = exports.SQL_MSSQL_PROMPT = exports.SQL_MYSQL_PROMPT = exports.SQL_SQLITE_PROMPT = exports.SQL_POSTGRES_PROMPT = exports.DEFAULT_SQL_DATABASE_PROMPT = void 0;
|
|
3
|
+
exports.SQL_PROMPTS_MAP = exports.SQL_ORACLE_PROMPT = exports.SQL_SAP_HANA_PROMPT = exports.SQL_MSSQL_PROMPT = exports.SQL_MYSQL_PROMPT = exports.SQL_SQLITE_PROMPT = exports.SQL_POSTGRES_PROMPT = exports.DEFAULT_SQL_DATABASE_PROMPT = void 0;
|
|
4
4
|
/* eslint-disable spaced-comment */
|
|
5
5
|
const prompts_1 = require("@langchain/core/prompts");
|
|
6
6
|
exports.DEFAULT_SQL_DATABASE_PROMPT = new prompts_1.PromptTemplate({
|
|
@@ -139,3 +139,11 @@ Only use the following tables:
|
|
|
139
139
|
Question: {input}`,
|
|
140
140
|
inputVariables: ["dialect", "table_info", "input", "top_k"],
|
|
141
141
|
});
|
|
142
|
+
exports.SQL_PROMPTS_MAP = {
|
|
143
|
+
oracle: exports.SQL_ORACLE_PROMPT,
|
|
144
|
+
postgres: exports.SQL_POSTGRES_PROMPT,
|
|
145
|
+
sqlite: exports.SQL_SQLITE_PROMPT,
|
|
146
|
+
mysql: exports.SQL_MYSQL_PROMPT,
|
|
147
|
+
mssql: exports.SQL_MSSQL_PROMPT,
|
|
148
|
+
"sap hana": exports.SQL_SAP_HANA_PROMPT,
|
|
149
|
+
};
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { PromptTemplate } from "@langchain/core/prompts";
|
|
1
|
+
import { BasePromptTemplate, PromptTemplate } from "@langchain/core/prompts";
|
|
2
2
|
export declare const DEFAULT_SQL_DATABASE_PROMPT: PromptTemplate<{
|
|
3
3
|
input: any;
|
|
4
4
|
top_k: any;
|
|
@@ -41,3 +41,5 @@ export declare const SQL_ORACLE_PROMPT: PromptTemplate<{
|
|
|
41
41
|
dialect: any;
|
|
42
42
|
table_info: any;
|
|
43
43
|
}, any>;
|
|
44
|
+
/** SQL dialects that have a dedicated built-in prompt in `SQL_PROMPTS_MAP`. */
export type SqlDialect =
    | "oracle"
    | "postgres"
    | "sqlite"
    | "mysql"
    | "mssql"
    | "sap hana";
/** Built-in SQL generation prompts, one per supported dialect. */
export declare const SQL_PROMPTS_MAP: Record<SqlDialect, BasePromptTemplate>;
|
|
@@ -136,3 +136,11 @@ Only use the following tables:
|
|
|
136
136
|
Question: {input}`,
|
|
137
137
|
inputVariables: ["dialect", "table_info", "input", "top_k"],
|
|
138
138
|
});
|
|
139
|
+
/**
 * Built-in SQL generation prompts keyed by dialect name. `createSqlQueryChain`
 * looks up its prompt here when no custom prompt is supplied.
 */
export const SQL_PROMPTS_MAP = Object.fromEntries([
    ["oracle", SQL_ORACLE_PROMPT],
    ["postgres", SQL_POSTGRES_PROMPT],
    ["sqlite", SQL_SQLITE_PROMPT],
    ["mysql", SQL_MYSQL_PROMPT],
    ["mssql", SQL_MSSQL_PROMPT],
    ["sap hana", SQL_SAP_HANA_PROMPT],
]);
|