langchain-kinetica 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain_kinetica/__init__.py +31 -6
- langchain_kinetica/chat_models.py +537 -0
- langchain_kinetica/document_loaders.py +89 -0
- langchain_kinetica/py.typed +0 -0
- langchain_kinetica/vectorstores.py +934 -0
- langchain_kinetica-1.2.0.dist-info/METADATA +69 -0
- langchain_kinetica-1.2.0.dist-info/RECORD +8 -0
- {langchain_kinetica-1.0.0.dist-info → langchain_kinetica-1.2.0.dist-info}/WHEEL +1 -2
- langchain_kinetica/llm_chat.py +0 -183
- langchain_kinetica/sa_datafile.py +0 -60
- langchain_kinetica/sa_dto.py +0 -111
- langchain_kinetica/sql_output.py +0 -45
- langchain_kinetica-1.0.0.dist-info/LICENSE +0 -21
- langchain_kinetica-1.0.0.dist-info/METADATA +0 -110
- langchain_kinetica-1.0.0.dist-info/RECORD +0 -10
- langchain_kinetica-1.0.0.dist-info/top_level.txt +0 -1
langchain_kinetica/__init__.py
CHANGED
|
@@ -1,8 +1,33 @@
|
|
|
1
|
-
|
|
2
|
-
# Copyright (c) 2024, Chad Juliano, Kinetica DB Inc.
|
|
3
|
-
##
|
|
1
|
+
"""An integration package connecting Kinetica and LangChain."""
|
|
4
2
|
|
|
5
|
-
|
|
3
|
+
from importlib import metadata
|
|
6
4
|
|
|
7
|
-
from .
|
|
8
|
-
|
|
5
|
+
from langchain_kinetica.chat_models import (
|
|
6
|
+
ChatKinetica,
|
|
7
|
+
KineticaSqlOutputParser,
|
|
8
|
+
KineticaSqlResponse,
|
|
9
|
+
)
|
|
10
|
+
from langchain_kinetica.document_loaders import KineticaLoader
|
|
11
|
+
from langchain_kinetica.vectorstores import (
|
|
12
|
+
DistanceStrategy,
|
|
13
|
+
KineticaSettings,
|
|
14
|
+
KineticaVectorstore,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
try:
    __version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
    # No installed distribution metadata (e.g. running from a source tree).
    __version__ = ""
# Drop the helper module so it does not show up in dir() of the package.
del metadata

# Public API of the package.
__all__ = [
    "ChatKinetica",
    "DistanceStrategy",
    "KineticaLoader",
    "KineticaSettings",
    "KineticaSqlOutputParser",
    "KineticaSqlResponse",
    "KineticaVectorstore",
    "__version__",
]
|
|
@@ -0,0 +1,537 @@
|
|
|
1
|
+
##
|
|
2
|
+
# Copyright (c) 2024, Chad Juliano, Kinetica DB Inc.
|
|
3
|
+
##
|
|
4
|
+
"""Kinetica SQL generation LLM API."""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
import re
|
|
10
|
+
from importlib.metadata import version
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from re import Pattern
|
|
13
|
+
from typing import Any, cast, override
|
|
14
|
+
|
|
15
|
+
from gpudb import GPUdb
|
|
16
|
+
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
17
|
+
from langchain_core.language_models.chat_models import BaseChatModel
|
|
18
|
+
from langchain_core.messages import (
|
|
19
|
+
AIMessage,
|
|
20
|
+
BaseMessage,
|
|
21
|
+
HumanMessage,
|
|
22
|
+
SystemMessage,
|
|
23
|
+
)
|
|
24
|
+
from langchain_core.output_parsers.transform import BaseOutputParser
|
|
25
|
+
from langchain_core.outputs import ChatGeneration, ChatResult, Generation
|
|
26
|
+
from langchain_core.utils import pre_init
|
|
27
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
28
|
+
|
|
29
|
+
# Module-level logger named after this module.
LOG = logging.getLogger(name=__name__)
|
|
30
|
+
|
|
31
|
+
# Kinetica pydantic API datatypes
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class _KdtSuggestContext(BaseModel):
    """pydantic API request type describing one table of an LLM context."""

    table: str | None = Field(default=None, title="Name of table")
    description: str | None = Field(default=None, title="Table description")
    columns: list[str] = Field(default_factory=list, title="Table columns list")
    rules: list[str] | None = Field(
        default=None, title="Rules that apply to the table."
    )
    samples: dict | None = Field(
        default=None, title="Samples that apply to the entire context."
    )

    def to_system_str(self) -> str:
        """Render this table context as DDL-style text for the system prompt.

        The output is a ``CREATE TABLE`` statement listing the columns, followed
        by an optional ``COMMENT ON TABLE`` line (when a description is set) and
        ``--`` comment lines for each rule (when rules are set).

        Returns:
            The rendered text, lines joined with newlines.

        Raises:
            ValueError: If ``columns`` is empty.
        """
        # Validate up front, before building any output.
        if not self.columns:
            msg = "columns list can't be null."
            raise ValueError(msg)

        column_lines = []
        for column in self.columns:
            # Double quotes are removed from each column definition.
            cleaned = column.replace('"', "").strip()
            column_lines.append(f" {cleaned}")

        lines = [f"CREATE TABLE {self.table} AS", "(", ",\n".join(column_lines), ");"]

        if self.description:
            lines.append(f"COMMENT ON TABLE {self.table} IS '{self.description}';")

        if self.rules:
            lines.append(
                f"-- When querying table {self.table} the following rules apply:"
            )
            lines.extend(f"-- * {rule}" for rule in self.rules)

        return "\n".join(lines)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class _KdtSuggestPayload(BaseModel):
    """pydantic API request type: a question plus the table contexts for it."""

    question: str | None = None
    context: list[_KdtSuggestContext]

    def get_system_str(self) -> str:
        """Render every context that names a table as system-prompt text.

        Contexts whose ``table`` is ``None`` are skipped. The rendered contexts
        are separated by blank lines.
        """
        return "\n\n".join(
            table_context.to_system_str()
            for table_context in self.context
            if table_context.table is not None
        )

    def get_messages(self) -> list[dict]:
        """Convert context samples into alternating user/assistant message dicts.

        Each sample key becomes a ``user`` message and its value an
        ``assistant`` message. SQL-escaped single quotes (``''``) in the answer
        are unescaped to ``'``. A ``None`` question or answer becomes ``""``.
        """
        messages: list[dict] = []
        for context in self.context:
            if context.samples is None:
                continue
            for question, answer in context.samples.items():
                # unescape doubled single quotes ('' -> ')
                answer_new = (answer or "").replace("''", "'")

                messages.append({"role": "user", "content": question or ""})
                messages.append({"role": "assistant", "content": answer_new})
        return messages

    def to_completion(self) -> dict:
        """Build a chat-completion request body from this payload.

        Returns:
            ``{"messages": [...]}``: the system message, then the sample pairs,
            then the payload question as the final user message.
        """
        messages = [{"role": "system", "content": self.get_system_str()}]
        messages.extend(self.get_messages())
        messages.append({"role": "user", "content": self.question or ""})
        return {"messages": messages}
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class _KdtoSuggestRequest(BaseModel):
    """pydantic API request type: envelope that carries a ``_KdtSuggestPayload``."""

    # The wrapped payload (question and table contexts).
    payload: _KdtSuggestPayload
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class _KdtMessage(BaseModel):
    """pydantic API response type holding a single chat message."""

    # Speaker of the message; empty string when the API omits it.
    role: str = Field(default="", title="One of [user|assistant|system]")
    # Message text (required).
    content: str
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class _KdtChoice(BaseModel):
    """pydantic API response type for one completion choice."""

    # Position of this choice in the response's ``choices`` list.
    index: int
    # Generated message; may be absent, in which case it is None.
    message: _KdtMessage | None = Field(default=None, title="The generated SQL")
    # Reason reported by the API for ending generation.
    finish_reason: str
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class _KdtUsage(BaseModel):
    """pydantic API response type with token accounting for a completion."""

    # Tokens consumed by the prompt.
    prompt_tokens: int
    # Tokens produced by the completion.
    completion_tokens: int
    # Total tokens reported by the API.
    total_tokens: int
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class _KdtSqlResponse(BaseModel):
    """pydantic API response type for a chat-completion result."""

    # ``id`` and ``object`` mirror the API's JSON keys, so they keep these names
    # even though they shadow builtins.
    id: str
    object: str
    created: int
    model: str
    # Generated choices; ``_generate`` reads the first one.
    choices: list[_KdtChoice]
    usage: _KdtUsage
    prompt: str = Field(default="", title="The input question")
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class _KdtCompletionResponse(BaseModel):
    """pydantic API response type wrapping a status and the SQL response."""

    # Status string; callers treat anything other than "OK" as a failure.
    status: str
    # The completion payload.
    data: _KdtSqlResponse
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class _KineticaLlmFileContextParser:
|
|
160
|
+
"""Parser for Kinetica LLM context datafiles."""
|
|
161
|
+
|
|
162
|
+
# parse line into a dict containing role and content
|
|
163
|
+
PARSER: Pattern = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL)
|
|
164
|
+
|
|
165
|
+
@classmethod
|
|
166
|
+
def _removesuffix(cls, text: str, suffix: str) -> str:
|
|
167
|
+
if suffix and text.endswith(suffix):
|
|
168
|
+
return text[: -len(suffix)]
|
|
169
|
+
return text
|
|
170
|
+
|
|
171
|
+
@classmethod
|
|
172
|
+
def parse_dialogue_file(cls, input_file: os.PathLike) -> dict:
|
|
173
|
+
path = Path(input_file)
|
|
174
|
+
# schema = path.name.removesuffix(".txt") python 3.9
|
|
175
|
+
schema = cls._removesuffix(path.name, ".txt")
|
|
176
|
+
|
|
177
|
+
with Path(input_file).open("r") as fp:
|
|
178
|
+
lines = fp.read()
|
|
179
|
+
|
|
180
|
+
return cls.parse_dialogue(lines, schema)
|
|
181
|
+
|
|
182
|
+
@classmethod
|
|
183
|
+
def parse_dialogue(cls, text: str, schema: str) -> dict:
|
|
184
|
+
messages = []
|
|
185
|
+
system = None
|
|
186
|
+
|
|
187
|
+
lines = text.split("<|end|>")
|
|
188
|
+
user_message = None
|
|
189
|
+
|
|
190
|
+
for line_in in lines:
|
|
191
|
+
line = line_in.strip()
|
|
192
|
+
|
|
193
|
+
if len(line) == 0:
|
|
194
|
+
continue
|
|
195
|
+
|
|
196
|
+
match = cls.PARSER.match(line)
|
|
197
|
+
if match is None:
|
|
198
|
+
msg = f"Could not find starting token in: {line}" # type: ignore[no-redef]
|
|
199
|
+
raise ValueError(msg)
|
|
200
|
+
|
|
201
|
+
groupdict = match.groupdict()
|
|
202
|
+
role = groupdict["role"]
|
|
203
|
+
|
|
204
|
+
if role == "system":
|
|
205
|
+
if system is not None:
|
|
206
|
+
msg = f"Only one system token allowed in: {line}"
|
|
207
|
+
raise ValueError(msg)
|
|
208
|
+
system = groupdict["content"]
|
|
209
|
+
elif role == "user":
|
|
210
|
+
if user_message is not None:
|
|
211
|
+
msg: str = f"Found user token without assistant token: {line}" # type: ignore[no-redef]
|
|
212
|
+
raise ValueError(msg)
|
|
213
|
+
user_message = groupdict
|
|
214
|
+
elif role == "assistant":
|
|
215
|
+
if user_message is None:
|
|
216
|
+
msg = "Found assistant token without user token: {line}"
|
|
217
|
+
raise ValueError(msg)
|
|
218
|
+
messages.append(user_message)
|
|
219
|
+
messages.append(groupdict)
|
|
220
|
+
user_message = None
|
|
221
|
+
else:
|
|
222
|
+
msg = f"Unknown token: {role}"
|
|
223
|
+
raise ValueError(msg)
|
|
224
|
+
|
|
225
|
+
return {"schema": schema, "system": system, "messages": messages}
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
class ChatKinetica(BaseChatModel):
    """Kinetica LLM Chat Model API.

    Prerequisites for using this API:

    * The ``gpudb`` and ``typeguard`` packages installed.
    * A Kinetica DB instance.
    * Kinetica host specified in ``KINETICA_URL``
    * Kinetica login specified in ``KINETICA_USER``, and ``KINETICA_PASSWD``.
    * An LLM context that specifies the tables and samples to use for inferencing.

    This API is intended to interact with the Kinetica SqlAssist LLM that supports
    generation of SQL from natural language.

    In the Kinetica LLM workflow you create an LLM context in the database that provides
    information needed for inferencing that includes tables, annotations, rules, and
    samples. Invoking ``load_messages_from_context()`` will retrieve the context
    information from the database so that it can be used to create a chat prompt.

    The chat prompt consists of a ``SystemMessage`` and pairs of
    ``HumanMessage``/``AIMessage`` that contain the samples which are question/SQL
    pairs. You can append sample pairs to this list but it is not intended to
    facilitate a typical natural language conversation.

    When you create a chain from the chat prompt and execute it, the Kinetica LLM will
    generate SQL from the input. Optionally you can use ``KineticaSqlOutputParser`` to
    execute the SQL and return the result as a dataframe.

    The following example creates an LLM using the environment variables for the
    Kinetica connection. This will fail if the API is unable to connect to the database.

    Example:
        .. code-block:: python

            from langchain_kinetica import ChatKinetica

            kinetica_llm = ChatKinetica()

    If you prefer to pass connection information directly then you can create a
    connection using ``GPUdb.get_connection()``.

    Example:
        .. code-block:: python

            from gpudb import GPUdb
            from langchain_kinetica import ChatKinetica

            kdbc = GPUdb.get_connection()
            kinetica_llm = ChatKinetica(kdbc=kdbc)
    """

    kdbc: GPUdb | None = Field(exclude=True)
    """ Kinetica DB connection. """

    @pre_init
    def validate_environment(cls, values: dict) -> dict:  # noqa: N805
        """Pydantic object validator.

        Opens a connection with ``GPUdb.get_connection()`` when no ``kdbc`` was
        supplied.
        """
        kdbc = values.get("kdbc")
        if kdbc is None:
            kdbc = GPUdb.get_connection()
            values["kdbc"] = kdbc
        return values

    @property
    def _llm_type(self) -> str:
        return "kinetica-sqlassist"

    def _get_kdbc(self) -> GPUdb:
        """Return the connection, raising ValueError if it is not set."""
        if self.kdbc is None:
            msg = "Kinetica DB connection is not initialized."
            raise ValueError(msg)
        return self.kdbc

    @property
    def _identifying_params(self) -> dict[str, Any]:
        return {
            "kinetica_version": str(self._get_kdbc().server_version),
            "api_version": version("gpudb"),
        }

    @override
    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Submit the messages to Kinetica and wrap the first choice as a result.

        Raises:
            ValueError: If ``stop`` is given, or if the response contains no
                choices or a choice without a message.
        """
        if stop is not None:
            msg = "stop kwargs are not permitted."
            raise ValueError(msg)

        dict_messages = [self._convert_message_to_dict(m) for m in messages]
        sql_response = self._submit_completion(dict_messages)

        # Guard against an empty choices list or a missing message so callers
        # get a clear error rather than an IndexError/AttributeError.
        response_message = (
            sql_response.choices[0].message if sql_response.choices else None
        )
        if response_message is None:
            msg = "Kinetica returned no generated message."
            raise ValueError(msg)

        generated_message = self._convert_message_from_dict(
            response_message.model_dump()
        )

        llm_output = {
            "input_tokens": sql_response.usage.prompt_tokens,
            "output_tokens": sql_response.usage.completion_tokens,
            "model_name": sql_response.model,
        }
        return ChatResult(
            generations=[ChatGeneration(message=generated_message)],
            llm_output=llm_output,
        )

    def load_messages_from_context(self, context_name: str) -> list:
        """Load a LangChain prompt from a Kinetica context.

        A Kinetica Context is an object created with the Kinetica Workbench UI or with
        SQL syntax. This function will convert the data in the context to a list of
        messages that can be used as a prompt. The messages will contain a
        ``SystemMessage`` followed by pairs of ``HumanMessage``/``AIMessage`` that
        contain the samples.

        Args:
            context_name: The name of an LLM context in the database.

        Returns:
            A list of messages containing the information from the context.
        """
        # query kinetica for the prompt; double any single quotes so the name
        # cannot terminate the SQL string literal early.
        escaped_name = context_name.replace("'", "''")
        sql = f"GENERATE PROMPT WITH OPTIONS (CONTEXT_NAMES = '{escaped_name}')"

        result = self._execute_sql(sql)
        prompt = result["Prompt"]
        prompt_json = json.loads(prompt)

        # convert the prompt to messages
        request = _KdtoSuggestRequest.model_validate(prompt_json)
        payload = request.payload

        dict_messages = [{"role": "system", "content": payload.get_system_str()}]
        dict_messages.extend(payload.get_messages())
        return [self._convert_message_from_dict(m) for m in dict_messages]

    def _submit_completion(self, messages: list[dict]) -> _KdtSqlResponse:
        """Submit a /chat/completions request to Kinetica.

        Raises:
            ValueError: If the request or the SQL generation reports a non-OK
                status.
        """
        request = {"messages": messages}
        request_json = json.dumps(request)
        # Uses GPUdb's name-mangled private request method.
        response_raw = self._get_kdbc()._GPUdb__submit_request_json(  # noqa: SLF001
            "/chat/completions", request_json
        )
        response_json = json.loads(response_raw)

        status = response_json["status"]
        if status != "OK":
            message = response_json["message"]
            # Prefer the nested error message embedded as "response:{...}".
            match_resp = re.compile(r"response:({.*})")
            result = match_resp.search(message)
            if result is not None:
                response = result.group(1)
                response_json = json.loads(response)
                message = response_json["message"]
            raise ValueError(message)

        data = response_json["data"]
        response = _KdtCompletionResponse.model_validate(data)
        if response.status != "OK":
            msg = "SQL Generation failed."
            raise ValueError(msg)
        return response.data

    def _execute_sql(self, sql: str) -> dict:
        """Execute an SQL query and return its single result record as a dict.

        Raises:
            ValueError: If the query fails or does not return exactly one record.
        """
        response = self._get_kdbc().execute_sql_and_decode(
            sql, limit=1, get_column_major=False
        )

        status_info = response["status_info"]
        if status_info["status"] != "OK":
            message = status_info["message"]
            raise ValueError(message)

        records = response["records"]
        if len(records) != 1:
            msg = "No records returned."
            raise ValueError(msg)

        record = records[0]
        return dict(record)

    @classmethod
    def load_messages_from_datafile(cls, sa_datafile: Path) -> list[BaseMessage]:
        """Load a LangChain prompt from a Kinetica context datafile.

        Args:
            sa_datafile: Path to the datafile.

        Returns:
            A ``SystemMessage`` followed by the sample messages from the file.
        """
        datafile_dict = _KineticaLlmFileContextParser.parse_dialogue_file(sa_datafile)
        return cls._convert_dict_to_messages(datafile_dict)

    @classmethod
    def _convert_message_to_dict(cls, message: BaseMessage) -> dict:
        """Convert a BaseMessage into a ``{"role": ..., "content": ...}`` dict.

        Raises:
            TypeError: If the message is not a Human, AI or System message.
        """
        # NOTE(review): content is assumed to be a plain string; list (multimodal)
        # content is passed through unchanged.
        content = cast("str", message.content)
        if isinstance(message, HumanMessage):
            role = "user"
        elif isinstance(message, AIMessage):
            role = "assistant"
        elif isinstance(message, SystemMessage):
            role = "system"
        else:
            msg = f"Got unsupported message type: {message}"
            raise TypeError(msg)

        return {"role": role, "content": content}

    @classmethod
    def _convert_message_from_dict(cls, message: dict) -> BaseMessage:
        """Convert a ``{"role": ..., "content": ...}`` dict into a BaseMessage.

        Raises:
            ValueError: If the role is not user, assistant or system.
        """
        role = message["role"]
        content = message["content"]
        if role == "user":
            return HumanMessage(content=content)
        if role == "assistant":
            return AIMessage(content=content)
        if role == "system":
            return SystemMessage(content=content)
        msg = f"Got unsupported role: {role}"
        raise ValueError(msg)

    @classmethod
    def _convert_dict_to_messages(cls, sa_data: dict) -> list[BaseMessage]:
        """Convert a parsed datafile dict to a list of BaseMessages.

        Raises:
            ValueError: If the data has no system prompt.
        """
        schema = sa_data["schema"]
        system = sa_data["system"]
        messages = sa_data["messages"]
        LOG.info("Importing prompt for schema: %s", schema)

        # A SystemMessage cannot hold None content; report the real cause.
        if system is None:
            msg = f"No system token found in prompt for schema: {schema}"
            raise ValueError(msg)

        result_list: list[BaseMessage] = [SystemMessage(content=system)]
        result_list.extend(cls._convert_message_from_dict(m) for m in messages)
        return result_list
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
class KineticaSqlResponse(BaseModel):
    """Result of running generated SQL: the SQL text and the fetched data.

    A chain ending in ``KineticaSqlOutputParser`` produces this object. It
    carries the SQL produced by the LLM together with the Pandas Dataframe
    fetched from the database for that SQL.
    """

    sql: str = ""
    """The generated SQL."""

    dataframe: Any = None
    """The Pandas dataframe containing the fetched data."""

    # Dataframes are not pydantic types, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
class KineticaSqlOutputParser(BaseOutputParser[KineticaSqlResponse]):
    """Fetch and return data from the Kinetica LLM.

    This object is used as the last element of a chain to execute generated SQL and it
    will output a ``KineticaSqlResponse`` containing the SQL and a pandas dataframe with
    the fetched data.

    Example:
        .. code-block:: python

            import pandas as pd
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_kinetica import (
                ChatKinetica,
                KineticaSqlOutputParser,
                KineticaSqlResponse,
            )

            kinetica_llm = ChatKinetica()

            # create chain
            ctx_messages = kinetica_llm.load_messages_from_context("my_context")
            ctx_messages.append(("human", "{input}"))
            prompt_template = ChatPromptTemplate.from_messages(ctx_messages)
            chain = (
                prompt_template
                | kinetica_llm
                | KineticaSqlOutputParser(kdbc=kinetica_llm.kdbc)
            )
            sql_response: KineticaSqlResponse = chain.invoke(
                {"input": "What are the female users ordered by username?"}
            )

            assert isinstance(sql_response, KineticaSqlResponse)
            print(f"SQL Response: {sql_response.sql}")
            assert isinstance(sql_response.dataframe, pd.DataFrame)
    """

    kdbc: Any = Field(exclude=True)
    """ Kinetica DB connection. """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @override
    def parse(self, text: str) -> KineticaSqlResponse:
        """Execute the generated SQL and wrap it with the fetched dataframe.

        Args:
            text: The SQL text produced by the LLM.

        Returns:
            A ``KineticaSqlResponse`` whose ``sql`` is ``text`` and whose
            ``dataframe`` is the result of ``kdbc.to_df(text)``.
        """
        df = self.kdbc.to_df(text)
        return KineticaSqlResponse(sql=text, dataframe=df)

    @override
    def parse_result(
        self, result: list[Generation], *, partial: bool = False
    ) -> KineticaSqlResponse:
        """Parse the first LLM generation to fetch data from Kinetica.

        Args:
            result: Generations from the LLM; only the first is used.
            partial: Unused; accepted for interface compatibility.

        Raises:
            ValueError: If ``result`` is empty.
        """
        if not result:
            msg = "No generations to parse."
            raise ValueError(msg)
        return self.parse(result[0].text)

    @property
    def _type(self) -> str:
        return "kinetica_sql_output_parser"
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""Kinetica Document Loader API."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
from gpudb import GPUdb, GPUdbSqlIterator
|
|
8
|
+
from langchain_core.document_loaders.base import BaseLoader
|
|
9
|
+
from langchain_core.documents import Document
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from collections.abc import Iterator
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class KineticaLoader(BaseLoader):
    """Load from `Kinetica` API.

    Each document represents one row of the result. The `page_content_columns`
    are written into the `page_content` of the document. The `metadata_columns`
    are written into the `metadata` of the document. By default, all columns
    are written into the `page_content` and none into the `metadata`.
    """

    def __init__(
        self,
        query: str,
        kdbc: GPUdb | None = None,
        parameters: dict[str, Any] | None = None,
        page_content_columns: list[str] | None = None,
        metadata_columns: list[str] | None = None,
    ) -> None:
        """Initialize Kinetica document loader.

        Args:
            query: The query to run in Kinetica.
            kdbc (GPUdb, optional): An optional GPUdb connection instance. If not
                provided, the connection will be established using environment
                variables.
            parameters: Optional. Parameters to pass to the query.
                NOTE(review): stored on the instance but not currently forwarded
                to the query by ``_execute_query``.
            page_content_columns: Optional. Columns written to Document `page_content`.
                A list containing ``"*"`` selects every column.
            metadata_columns: Optional. Columns written to Document `metadata`.
        """
        self.query = query
        self.parameters = parameters
        self.page_content_columns = page_content_columns
        self.metadata_columns = metadata_columns if metadata_columns is not None else []

        if kdbc is None:
            kdbc = GPUdb.get_connection()
        self.kdbc = kdbc

    def _execute_query(self) -> list[dict[str, Any]]:
        """Run the query and return every row as a column-name -> value dict."""
        with GPUdbSqlIterator(self.kdbc, self.query) as records:
            column_names = records.type_map.keys()
            return [dict(zip(column_names, record, strict=False)) for record in records]

    def _get_columns(
        self, query_result: list[dict[str, Any]]
    ) -> tuple[list[str], list[str]]:
        """Resolve the page-content and metadata column lists for a result.

        When no page-content columns were configured, every column of the first
        row is used (or none if the result is empty).
        """
        page_content_columns = self.page_content_columns
        metadata_columns = self.metadata_columns

        if page_content_columns is None and query_result:
            page_content_columns = list(query_result[0].keys())
        if metadata_columns is None:
            metadata_columns = []
        return page_content_columns or [], metadata_columns

    def lazy_load(self) -> Iterator[Document]:
        """Lazily load data into document objects.

        Yields one Document per result row. ``page_content`` holds
        ``"column: value"`` lines for the selected columns in row order and
        ``metadata`` holds the selected metadata columns. An empty result
        yields nothing.
        """
        query_result = self._execute_query()
        # Nothing to load; also avoids indexing query_result[0] below.
        if not query_result:
            return

        page_content_columns, metadata_columns = self._get_columns(query_result)
        if "*" in page_content_columns:
            page_content_columns = list(query_result[0].keys())

        # Sets give O(1) membership tests for each column of each row.
        content_keys = set(page_content_columns)
        metadata_keys = set(metadata_columns)

        for row in query_result:
            page_content = "\n".join(
                f"{k}: {v}" for k, v in row.items() if k in content_keys
            )
            metadata = {k: v for k, v in row.items() if k in metadata_keys}
            yield Document(page_content=page_content, metadata=metadata)

    def load(self) -> list[Document]:
        """Load data into document objects."""
        return list(self.lazy_load())
|
|
File without changes
|