langchain-timbr 1.5.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain_timbr/__init__.py +17 -0
- langchain_timbr/config.py +21 -0
- langchain_timbr/langchain/__init__.py +16 -0
- langchain_timbr/langchain/execute_timbr_query_chain.py +307 -0
- langchain_timbr/langchain/generate_answer_chain.py +99 -0
- langchain_timbr/langchain/generate_timbr_sql_chain.py +176 -0
- langchain_timbr/langchain/identify_concept_chain.py +138 -0
- langchain_timbr/langchain/timbr_sql_agent.py +418 -0
- langchain_timbr/langchain/validate_timbr_sql_chain.py +187 -0
- langchain_timbr/langgraph/__init__.py +13 -0
- langchain_timbr/langgraph/execute_timbr_query_node.py +108 -0
- langchain_timbr/langgraph/generate_response_node.py +59 -0
- langchain_timbr/langgraph/generate_timbr_sql_node.py +98 -0
- langchain_timbr/langgraph/identify_concept_node.py +78 -0
- langchain_timbr/langgraph/validate_timbr_query_node.py +100 -0
- langchain_timbr/llm_wrapper/llm_wrapper.py +189 -0
- langchain_timbr/llm_wrapper/timbr_llm_wrapper.py +41 -0
- langchain_timbr/timbr_llm_connector.py +398 -0
- langchain_timbr/utils/general.py +70 -0
- langchain_timbr/utils/prompt_service.py +330 -0
- langchain_timbr/utils/temperature_supported_models.json +62 -0
- langchain_timbr/utils/timbr_llm_utils.py +575 -0
- langchain_timbr/utils/timbr_utils.py +475 -0
- langchain_timbr-1.5.0.dist-info/METADATA +103 -0
- langchain_timbr-1.5.0.dist-info/RECORD +27 -0
- langchain_timbr-1.5.0.dist-info/WHEEL +4 -0
- langchain_timbr-1.5.0.dist-info/licenses/LICENSE +21 -0

langchain_timbr/__init__.py

@@ -0,0 +1,17 @@
+#
+# *### ., @%
+# *%## `#// %%%* *@ `` @%
+# #*. * .%%%` @@@@* @@ @@@@,@@@@ @&@@@@ .&@@@*
+# #%%# .. *@ @@ @` @@` ,@ @% #@ @@
+# ,, .,%(##/./%%#, *@ @@ @` @@` ,@ @% #@ @@
+# ,%##% `` `/@@* @@ @` @@` ,@ (/@@@#/ @@
+# ``
+# ``````````````````````````````````````````````````````````````
+# Copyright (C) 2018-2025 timbr.ai
+
+__version__ = "1.4.3"
+from .timbr_llm_connector import TimbrLlmConnector
+from .llm_wrapper.llm_wrapper import LlmWrapper, LlmTypes
+
+from .langchain import *
+from .langgraph import *
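
Taken together with the `langchain` and `langgraph` star imports, the package root exposes the connector, the LLM wrapper, and all of the chain and node classes, so downstream code only needs top-level imports. A small illustrative sketch (note that the module reports `__version__ = "1.4.3"` even though the wheel itself is versioned 1.5.0):

```python
# Everything below is re-exported from the package root via the imports shown above.
from langchain_timbr import (
    TimbrLlmConnector,
    LlmWrapper,
    LlmTypes,
    ExecuteTimbrQueryChain,
    GenerateAnswerChain,
)

import langchain_timbr
print(langchain_timbr.__version__)  # prints "1.4.3" (the wheel metadata says 1.5.0)
```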

langchain_timbr/config.py

@@ -0,0 +1,21 @@
+import os
+from .utils.general import to_boolean, to_integer, parse_list
+
+# MUST HAVE VARIABLES
+url = os.environ.get('TIMBR_URL')
+token = os.environ.get('TIMBR_TOKEN')
+ontology = os.environ.get('ONTOLOGY', 'system_db')
+
+# OPTIONAL VARIABLES
+is_jwt = to_boolean(os.environ.get('IS_JWT', 'false'))
+jwt_tenant_id = os.environ.get('JWT_TENANT_ID', None)
+
+cache_timeout = to_integer(os.environ.get('CACHE_TIMEOUT', 120))
+ignore_tags = parse_list(os.environ.get('IGNORE_TAGS', 'icon'))
+ignore_tags_prefix = parse_list(os.environ.get('IGNORE_TAGS_PREFIX', 'mdx.,bli.'))
+
+llm_type = os.environ.get('LLM_TYPE')
+llm_model = os.environ.get('LLM_MODEL')
+llm_api_key = os.environ.get('LLM_API_KEY')
+llm_temperature = os.environ.get('LLM_TEMPERATURE', 0.0)
+llm_timeout = to_integer(os.environ.get('LLM_TIMEOUT', 60))  # Default 60 seconds timeout
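
`config.py` resolves everything from environment variables at import time, so a typical setup exports the Timbr connection and LLM settings before importing the package. A minimal sketch with placeholder values (the `LLM_TYPE` string shown is an assumption; the accepted names are defined by `LlmTypes`):

```python
import os

# Required connection settings (read by langchain_timbr/config.py when it is imported)
os.environ["TIMBR_URL"] = "https://your-timbr-host/"   # placeholder
os.environ["TIMBR_TOKEN"] = "tk_..."                    # placeholder
os.environ["ONTOLOGY"] = "my_knowledge_graph"           # defaults to 'system_db' if unset

# Optional LLM settings
os.environ["LLM_TYPE"] = "openai-chat"                  # assumed value; see LlmTypes for the supported names
os.environ["LLM_MODEL"] = "gpt-4o"
os.environ["LLM_API_KEY"] = "sk-..."
os.environ["LLM_TIMEOUT"] = "120"                       # overrides the 60-second default

from langchain_timbr import config
print(config.url, config.ontology, config.llm_timeout)
```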

langchain_timbr/langchain/__init__.py

@@ -0,0 +1,16 @@
+from .identify_concept_chain import IdentifyTimbrConceptChain
+from .generate_timbr_sql_chain import GenerateTimbrSqlChain
+from .validate_timbr_sql_chain import ValidateTimbrSqlChain
+from .execute_timbr_query_chain import ExecuteTimbrQueryChain
+from .generate_answer_chain import GenerateAnswerChain
+from .timbr_sql_agent import TimbrSqlAgent, create_timbr_sql_agent
+
+__all__ = [
+    "IdentifyTimbrConceptChain",
+    "GenerateTimbrSqlChain",
+    "ValidateTimbrSqlChain",
+    "ExecuteTimbrQueryChain",
+    "GenerateAnswerChain",
+    "TimbrSqlAgent",
+    "create_timbr_sql_agent",
+]

langchain_timbr/langchain/execute_timbr_query_chain.py

@@ -0,0 +1,307 @@
+from typing import Optional, Union, Dict, Any
+from langchain.chains.base import Chain
+from langchain.llms.base import LLM
+
+from ..utils.general import parse_list, to_boolean, to_integer
+from ..utils.timbr_utils import run_query, validate_sql
+from ..utils.timbr_llm_utils import generate_sql
+
+class ExecuteTimbrQueryChain(Chain):
+    """
+    LangChain chain for executing SQL queries against Timbr knowledge graph databases.
+
+    This chain executes SQL queries on Timbr ontology/knowledge graph databases and
+    returns the query results, handling retries and result validation. It uses an LLM
+    for query generation and connects to Timbr via URL and token.
+    """
+
+    def __init__(
+        self,
+        llm: LLM,
+        url: str,
+        token: str,
+        ontology: str,
+        schema: Optional[str] = 'dtimbr',
+        concept: Optional[str] = None,
+        concepts_list: Optional[Union[list[str], str]] = None,
+        views_list: Optional[Union[list[str], str]] = None,
+        include_logic_concepts: Optional[bool] = False,
+        include_tags: Optional[Union[list[str], str]] = None,
+        exclude_properties: Optional[Union[list[str], str]] = ['entity_id', 'entity_type', 'entity_label'],
+        should_validate_sql: Optional[bool] = False,
+        retries: Optional[int] = 3,
+        max_limit: Optional[int] = 500,
+        retry_if_no_results: Optional[bool] = False,
+        no_results_max_retries: Optional[int] = 2,
+        note: Optional[str] = '',
+        db_is_case_sensitive: Optional[bool] = False,
+        graph_depth: Optional[int] = 1,
+        verify_ssl: Optional[bool] = True,
+        is_jwt: Optional[bool] = False,
+        jwt_tenant_id: Optional[str] = None,
+        conn_params: Optional[dict] = None,
+        debug: Optional[bool] = False,
+        **kwargs,
+    ):
+        """
+        :param llm: An LLM instance or a function that takes a prompt string and returns the LLM's response
+        :param url: Timbr server URL
+        :param token: Timbr password or token value
+        :param ontology: The name of the ontology/knowledge graph
+        :param schema: The name of the schema to query
+        :param concept: The name of the concept to query
+        :param concepts_list: Optional list of specific concepts to query
+        :param views_list: Optional list of specific views to query
+        :param include_logic_concepts: Whether to include logic concepts (concepts without unique properties that only inherit from an upper-level concept with filter logic) in the query.
+        :param include_tags: Optional concept & property tags to use in the query. Disabled by default; use '*' to enable all tags, or pass a comma-separated string (e.g. 'tag1,tag2') or a list of tags.
+        :param exclude_properties: Optional specific properties to exclude from the query (entity_id, entity_type & entity_label by default).
+        :param should_validate_sql: Whether to validate the SQL before executing it
+        :param retries: Number of retry attempts if the generated SQL is invalid
+        :param max_limit: Maximum number of rows to return
+        :param retry_if_no_results: Whether to retry when the query returns no rows; if so, the SQL is re-generated and re-run.
+        :param no_results_max_retries: Number of retry attempts when the query returns no rows
+        :param note: Optional additional note appended to the LLM prompt
+        :param db_is_case_sensitive: Whether the database is case sensitive (default is False).
+        :param graph_depth: Maximum number of relationship hops to traverse from the source concept during schema exploration (default is 1).
+        :param verify_ssl: Whether to verify SSL certificates (default is True).
+        :param is_jwt: Whether to use JWT authentication (default is False).
+        :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
+        :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
+        :param kwargs: Additional arguments passed to the base Chain class
+        :return: A list of rows from the Timbr query
+
+        ## Example
+        ```
+        execute_timbr_query_chain = ExecuteTimbrQueryChain(
+            url=<url>,
+            token=<token>,
+            llm=<llm or timbr_llm_wrapper instance>,
+            ontology=<ontology_name>,
+            schema=<schema_name>,
+            concept=<concept_name>,
+            concepts_list=<concepts>,
+            views_list=<views>,
+            should_validate_sql=False,
+            note=<note>,
+        )
+
+        return execute_timbr_query_chain.invoke({ "prompt": question }).get("rows", [])
+        ```
+        """
+        super().__init__(**kwargs)
+        self._llm = llm
+        self._url = url
+        self._token = token
+        self._ontology = ontology
+        self._schema = schema
+        self._concept = concept
+        self._concepts_list = parse_list(concepts_list)
+        self._views_list = parse_list(views_list)
+        self._include_tags = parse_list(include_tags)
+        self._include_logic_concepts = to_boolean(include_logic_concepts)
+        self._exclude_properties = parse_list(exclude_properties)
+        self._should_validate_sql = to_boolean(should_validate_sql)
+        self._retries = to_integer(retries)
+        self._max_limit = to_integer(max_limit)
+        self._retry_if_no_results = to_boolean(retry_if_no_results)
+        self._no_results_max_retries = to_integer(no_results_max_retries)
+        self._note = note
+        self._db_is_case_sensitive = to_boolean(db_is_case_sensitive)
+        self._graph_depth = to_integer(graph_depth)
+        self._verify_ssl = to_boolean(verify_ssl)
+        self._is_jwt = to_boolean(is_jwt)
+        self._jwt_tenant_id = jwt_tenant_id
+        self._debug = to_boolean(debug)
+        self._conn_params = conn_params or {}
+
+
+    @property
+    def usage_metadata_key(self) -> str:
+        return "execute_timbr_usage_metadata"
+
+
+    @property
+    def input_keys(self) -> list:
+        return ["prompt"]
+
+
+    @property
+    def output_keys(self) -> list:
+        return [
+            "rows",
+            "sql",
+            "schema",
+            "concept",
+            "error",
+            self.usage_metadata_key,
+        ]
+
+
+    def _get_conn_params(self) -> dict:
+        return {
+            "url": self._url,
+            "token": self._token,
+            "ontology": self._ontology,
+            "verify_ssl": self._verify_ssl,
+            "is_jwt": self._is_jwt,
+            "jwt_tenant_id": self._jwt_tenant_id,
+            **self._conn_params,
+        }
+
+
+    def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
+        if (not inputs.get("sql")) and (not inputs.get("prompt")):
+            raise ValueError("Timbr SQL or user prompt is required for executing the chain.")
+
+
+    def _generate_sql(
+        self,
+        prompt: str,
+        sql: Optional[str] = None,
+        concept_name: Optional[str] = None,
+        schema_name: Optional[str] = None,
+        error: Optional[str] = None,
+    ) -> Dict[str, Any]:
+
+        if not prompt:
+            raise ValueError("Timbr SQL or user prompt is required for executing the chain.")
+
+        err_txt = f"\nThe original SQL (`{sql}`) was invalid with error: {error}. Please generate a corrected query." if error else ""
+        generate_res = generate_sql(
+            prompt,
+            self._llm,
+            self._get_conn_params(),
+            concept=concept_name,
+            schema=schema_name,
+            concepts_list=self._concepts_list,
+            views_list=self._views_list,
+            include_tags=self._include_tags,
+            include_logic_concepts=self._include_logic_concepts,
+            exclude_properties=self._exclude_properties,
+            should_validate_sql=self._should_validate_sql,
+            retries=self._retries,
+            max_limit=self._max_limit,
+            note=self._note + err_txt,
+            db_is_case_sensitive=self._db_is_case_sensitive,
+            graph_depth=self._graph_depth,
+            debug=self._debug,
+        )
+
+        return generate_res
+
+
+    def _has_no_meaningful_results(self, rows: list) -> bool:
+        """
+        Check if the rows returned from the query are empty or do not contain meaningful data.
+        This can be customized based on specific criteria for what constitutes "meaningful" results.
+        """
+        if not rows:
+            return True
+
+        # Check if all rows have all None values
+        for row in rows:
+            if any(value is not None for value in row.values()):
+                return False
+
+        return True
+
+
+    def _call(self, inputs: Dict[str, Any], run_manager=None) -> Dict[str, Any]:
+        try:
+            prompt = inputs.get("prompt")
+            sql = inputs.get("sql", None)
+            schema_name = inputs.get("schema", self._schema)
+            concept_name = inputs.get("concept", self._concept)
+            is_sql_valid = True
+            error = None
+            usage_metadata = {}
+
+            if sql and self._should_validate_sql:
+                is_sql_valid, error = validate_sql(sql, self._get_conn_params())
+
+            is_infered = False
+            iteration = 0
+            generated = []
+            while not is_infered and iteration <= self._no_results_max_retries:
+                if prompt is not None and not sql or not is_sql_valid:
+                    generate_res = self._generate_sql(prompt, sql, concept_name, schema_name, error)
+
+                    sql = generate_res.get("sql", "")
+                    schema_name = generate_res.get("schema", schema_name)
+                    concept_name = generate_res.get("concept", concept_name)
+                    is_sql_valid = generate_res.get("is_sql_valid")
+                    if not is_sql_valid and not self._should_validate_sql:
+                        is_sql_valid = True
+
+                    error = generate_res.get("error")
+                    usage_metadata = self._summarize_usage_metadata(usage_metadata, generate_res.get("usage_metadata", {}))
+
+                is_sql_not_tried = not any(sql.lower().strip() == gen.lower().strip() for gen in generated)
+
+                rows = run_query(
+                    sql,
+                    self._get_conn_params(),
+                    llm_prompt=prompt
+                ) if is_sql_valid and is_sql_not_tried else []
+
+                if iteration < self._no_results_max_retries:
+                    # If no rows are returned and retry_if_no_results is set, try to re-generate the SQL query
+                    if prompt is not None and self._retry_if_no_results and self._has_no_meaningful_results(rows):
+                        if is_sql_not_tried:
+                            generated.append(sql)
+                            # If the SQL is valid but no rows are returned, create an error message to be sent to the LLM
+                            if is_sql_valid:
+                                error = f"No rows returned. Please revise the SQL considering if the question was ambiguous (e.g., which ID or name to use), and try alternative columns in the WHERE clause that could match the user's intent, without adding new columns with new filters."
+                                error += "\nConsider that these queries were already generated and returned 0 rows:\n" + "\n".join(generated)
+                                is_sql_valid = False
+                        else:
+                            # The same SQL was generated twice, so stop the loop
+                            is_infered = True
+                    else:
+                        is_infered = True
+                iteration += 1
+
+            return {
+                "rows": rows,
+                "sql": sql,
+                "schema": schema_name,
+                "concept": concept_name,
+                "error": error if not is_sql_valid else None,
+                self.usage_metadata_key: usage_metadata,
+            }
+
+        except Exception as e:
+            raise RuntimeError(f"Error executing the chain: {str(e)}")
+
+    def _summarize_usage_metadata(self, current_metadata: dict, new_metadata: dict) -> dict:
+        """
+        Summarize usage metadata by aggregating specific numeric keys and overriding others.
+
+        :param current_metadata: The existing usage metadata dictionary
+        :param new_metadata: The new usage metadata to be added
+        :return: Updated usage metadata dictionary
+        """
+        keys_to_sum = ['approximate', 'input_tokens', 'output_tokens', 'total_tokens']
+
+        for outer_key, outer_value in new_metadata.items():
+            if isinstance(outer_value, dict):
+                if outer_key not in current_metadata:
+                    current_metadata[outer_key] = {}
+
+                for inner_key, inner_value in outer_value.items():
+                    if inner_key in keys_to_sum:
+                        # Sum the numeric values
+                        current_val = current_metadata[outer_key].get(inner_key, 0)
+                        if isinstance(inner_value, (int, float)) and isinstance(current_val, (int, float)):
+                            current_metadata[outer_key][inner_key] = current_val + inner_value
+                        else:
+                            current_metadata[outer_key][inner_key] = inner_value
+                    else:
+                        # Override other keys
+                        current_metadata[outer_key][inner_key] = inner_value
+            else:
+                # If the outer value is not a dict, just override it
+                current_metadata[outer_key] = outer_value
+
+        return current_metadata
+
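
An end-to-end call then looks roughly like the sketch below. The chain construction follows the docstring above; the `LlmWrapper(llm_type=..., model=..., api_key=...)` arguments are an assumption inferred from the LLM_* settings in `config.py`, not a verified signature.

```python
from langchain_timbr import LlmWrapper, LlmTypes, ExecuteTimbrQueryChain

# Assumed wrapper arguments -- adjust to the actual LlmWrapper/LlmTypes API.
llm = LlmWrapper(llm_type=LlmTypes.OpenAI, model="gpt-4o", api_key="sk-...")

chain = ExecuteTimbrQueryChain(
    llm=llm,
    url="https://your-timbr-host/",
    token="tk_...",
    ontology="my_knowledge_graph",
    should_validate_sql=True,
    retry_if_no_results=True,   # re-generate the SQL up to no_results_max_retries times on empty results
)

result = chain.invoke({"prompt": "What were total sales per region last quarter?"})
print(result["sql"])                            # generated Timbr SQL
print(result["rows"])                           # query result rows
print(result["execute_timbr_usage_metadata"])   # aggregated token usage
```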

langchain_timbr/langchain/generate_answer_chain.py

@@ -0,0 +1,99 @@
+from typing import Optional, Dict, Any
+from langchain.chains.base import Chain
+from langchain.llms.base import LLM
+
+from ..utils.general import to_boolean
+from ..utils.timbr_llm_utils import answer_question
+
+class GenerateAnswerChain(Chain):
+    """
+    Chain that generates an answer based on a given prompt and rows of data.
+    It uses the LLM to build a human-readable answer.
+
+    This chain connects to a Timbr server via the provided URL and token to generate contextual answers from query results using an LLM.
+    """
+    def __init__(
+        self,
+        llm: LLM,
+        url: str,
+        token: str,
+        verify_ssl: Optional[bool] = True,
+        is_jwt: Optional[bool] = False,
+        jwt_tenant_id: Optional[str] = None,
+        conn_params: Optional[dict] = None,
+        debug: Optional[bool] = False,
+        **kwargs,
+    ):
+        """
+        :param llm: An LLM instance or a function that takes a prompt string and returns the LLM's response
+        :param url: Timbr server URL
+        :param token: Timbr password or token value
+        :param verify_ssl: Whether to verify SSL certificates (default is True).
+        :param is_jwt: Whether to use JWT authentication (default is False).
+        :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
+        :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
+
+        ## Example
+        ```
+        generate_answer_chain = GenerateAnswerChain(
+            llm=<llm or timbr_llm_wrapper instance>, url=<url>, token=<token>,
+        )
+
+        return generate_answer_chain.invoke({ "prompt": prompt, "rows": rows }).get("answer", "")
+        ```
+        """
+        super().__init__(**kwargs)
+        self._llm = llm
+        self._url = url
+        self._token = token
+        self._verify_ssl = to_boolean(verify_ssl)
+        self._is_jwt = to_boolean(is_jwt)
+        self._jwt_tenant_id = jwt_tenant_id
+        self._debug = to_boolean(debug)
+        self._conn_params = conn_params or {}
+
+
+    @property
+    def usage_metadata_key(self) -> str:
+        return "generate_answer_usage_metadata"
+
+
+    @property
+    def input_keys(self) -> list:
+        return ["prompt", "rows"]
+
+
+    @property
+    def output_keys(self) -> list:
+        return ["answer", self.usage_metadata_key]
+
+    def _get_conn_params(self) -> dict:
+        return {
+            "url": self._url,
+            "token": self._token,
+            # "ontology": self._ontology,
+            "verify_ssl": self._verify_ssl,
+            "is_jwt": self._is_jwt,
+            "jwt_tenant_id": self._jwt_tenant_id,
+            **self._conn_params,
+        }
+
+
+    def _call(self, inputs: Dict[str, Any], run_manager=None) -> Dict[str, str]:
+        prompt = inputs["prompt"]
+        rows = inputs["rows"]
+        sql = inputs['sql'] if 'sql' in inputs else None
+
+        res = answer_question(
+            question=prompt,
+            llm=self._llm,
+            conn_params=self._get_conn_params(),
+            results=rows,
+            sql=sql,
+            debug=self._debug,
+        )
+
+        return {
+            "answer": res.get("answer", ""),
+            self.usage_metadata_key: res.get("usage_metadata", {}),
+        }
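
Because `ExecuteTimbrQueryChain` emits `rows` and `sql` while this chain consumes `prompt`, `rows`, and an optional `sql`, the two compose directly into a question-to-answer pipeline. A minimal sketch, reusing the hypothetical `llm` wrapper from the previous example and placeholder connection values:

```python
from langchain_timbr import ExecuteTimbrQueryChain, GenerateAnswerChain

url, token, ontology = "https://your-timbr-host/", "tk_...", "my_knowledge_graph"  # placeholders

execute_chain = ExecuteTimbrQueryChain(llm=llm, url=url, token=token, ontology=ontology)
answer_chain = GenerateAnswerChain(llm=llm, url=url, token=token)

question = "Which customers churned in the last 30 days?"
query_result = execute_chain.invoke({"prompt": question})

# Pass the rows (and the generated SQL, for extra context) on to the answer chain.
answer = answer_chain.invoke({
    "prompt": question,
    "rows": query_result["rows"],
    "sql": query_result["sql"],
}).get("answer", "")
print(answer)
```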

langchain_timbr/langchain/generate_timbr_sql_chain.py

@@ -0,0 +1,176 @@
+from typing import Optional, Union, Dict, Any
+from langchain.chains.base import Chain
+from langchain.llms.base import LLM
+
+from ..utils.general import parse_list, to_boolean, to_integer
+from ..utils.timbr_llm_utils import generate_sql
+
+class GenerateTimbrSqlChain(Chain):
+    """
+    LangChain chain for generating SQL queries from natural language prompts using Timbr knowledge graphs.
+
+    This chain takes user prompts and generates corresponding SQL queries that can be executed
+    against Timbr ontology/knowledge graph databases. It uses an LLM to process prompts and
+    connects to Timbr via URL and token for SQL generation.
+    """
+
+    def __init__(
+        self,
+        llm: LLM,
+        url: str,
+        token: str,
+        ontology: str,
+        schema: Optional[str] = 'dtimbr',
+        concept: Optional[str] = None,
+        concepts_list: Optional[Union[list[str], str]] = None,
+        views_list: Optional[Union[list[str], str]] = None,
+        include_tags: Optional[Union[list[str], str]] = None,
+        include_logic_concepts: Optional[bool] = False,
+        exclude_properties: Optional[Union[list[str], str]] = ['entity_id', 'entity_type', 'entity_label'],
+        should_validate_sql: Optional[bool] = False,
+        retries: Optional[int] = 3,
+        max_limit: Optional[int] = 500,
+        note: Optional[str] = '',
+        db_is_case_sensitive: Optional[bool] = False,
+        graph_depth: Optional[int] = 1,
+        verify_ssl: Optional[bool] = True,
+        is_jwt: Optional[bool] = False,
+        jwt_tenant_id: Optional[str] = None,
+        conn_params: Optional[dict] = None,
+        debug: Optional[bool] = False,
+        **kwargs,
+    ):
+        """
+        :param llm: An LLM instance or a function that takes a prompt string and returns the LLM's response
+        :param url: Timbr server URL
+        :param token: Timbr password or token value
+        :param ontology: The name of the ontology/knowledge graph
+        :param schema: The name of the schema to query
+        :param concept: The name of the concept to query
+        :param concepts_list: Optional list of specific concepts to query
+        :param views_list: Optional list of specific views to query
+        :param include_tags: Optional concept & property tags to use in the query. Disabled by default; use '*' to enable all tags, or pass a comma-separated string (e.g. 'tag1,tag2') or a list of tags.
+        :param include_logic_concepts: Whether to include logic concepts (concepts without unique properties that only inherit from an upper-level concept with filter logic) in the query.
+        :param exclude_properties: Optional specific properties to exclude from the query (entity_id, entity_type & entity_label by default).
+        :param should_validate_sql: Whether to validate the SQL before executing it
+        :param retries: Number of retry attempts if the generated SQL is invalid
+        :param max_limit: Maximum number of rows to query
+        :param note: Optional additional note appended to the LLM prompt
+        :param db_is_case_sensitive: Whether the database is case sensitive (default is False).
+        :param graph_depth: Maximum number of relationship hops to traverse from the source concept during schema exploration (default is 1).
+        :param verify_ssl: Whether to verify SSL certificates (default is True).
+        :param is_jwt: Whether to use JWT authentication (default is False).
+        :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
+        :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
+        :param kwargs: Additional arguments passed to the base Chain class
+
+        ## Example
+        ```
+        generate_timbr_sql_chain = GenerateTimbrSqlChain(
+            url=<url>,
+            token=<token>,
+            llm=<llm or timbr_llm_wrapper instance>,
+            ontology=<ontology_name>,
+            schema=<schema_name>,
+            concept=<concept_name>,
+            concepts_list=<concepts>,
+            views_list=<views>,
+            include_tags=<tags>,
+            note=<note>,
+        )
+
+        return generate_timbr_sql_chain.invoke({ "prompt": question }).get("sql", "")
+        ```
+        """
+        super().__init__(**kwargs)
+        self._llm = llm
+        self._url = url
+        self._token = token
+        self._ontology = ontology
+        self._schema = schema
+        self._concept = concept
+        self._concepts_list = parse_list(concepts_list)
+        self._views_list = parse_list(views_list)
+        self._include_tags = parse_list(include_tags)
+        self._include_logic_concepts = to_boolean(include_logic_concepts)
+        self._should_validate_sql = to_boolean(should_validate_sql)
+        self._exclude_properties = parse_list(exclude_properties)
+        self._retries = to_integer(retries)
+        self._max_limit = to_integer(max_limit)
+        self._note = note
+        self._db_is_case_sensitive = to_boolean(db_is_case_sensitive)
+        self._graph_depth = to_integer(graph_depth)
+        self._verify_ssl = to_boolean(verify_ssl)
+        self._is_jwt = to_boolean(is_jwt)
+        self._jwt_tenant_id = jwt_tenant_id
+        self._debug = to_boolean(debug)
+        self._conn_params = conn_params or {}
+
+
+    @property
+    def usage_metadata_key(self) -> str:
+        return "generate_sql_usage_metadata"
+
+
+    @property
+    def input_keys(self) -> list:
+        return ["prompt"]
+
+    @property
+    def output_keys(self) -> list:
+        return [
+            "sql",
+            "schema",
+            "concept",
+            "is_sql_valid",
+            "error",
+            self.usage_metadata_key,
+        ]
+
+
+    def _get_conn_params(self) -> dict:
+        return {
+            "url": self._url,
+            "token": self._token,
+            "ontology": self._ontology,
+            "verify_ssl": self._verify_ssl,
+            "is_jwt": self._is_jwt,
+            "jwt_tenant_id": self._jwt_tenant_id,
+            **self._conn_params,
+        }
+
+
+    def _call(self, inputs: Dict[str, Any], run_manager=None) -> Dict[str, str]:
+        prompt = inputs["prompt"]
+        generate_res = generate_sql(
+            question=prompt,
+            llm=self._llm,
+            conn_params=self._get_conn_params(),
+            schema=self._schema,
+            concept=self._concept,
+            concepts_list=self._concepts_list,
+            views_list=self._views_list,
+            include_tags=self._include_tags,
+            include_logic_concepts=self._include_logic_concepts,
+            exclude_properties=self._exclude_properties,
+            should_validate_sql=self._should_validate_sql,
+            retries=self._retries,
+            max_limit=self._max_limit,
+            note=self._note,
+            db_is_case_sensitive=self._db_is_case_sensitive,
+            graph_depth=self._graph_depth,
+            debug=self._debug,
+        )
+
+        sql = generate_res.get("sql", "")
+        schema = generate_res.get("schema", self._schema)
+        concept = generate_res.get("concept", self._concept)
+
+        return {
+            "sql": sql,
+            "schema": schema,
+            "concept": concept,
+            "is_sql_valid": generate_res.get("is_sql_valid"),
+            "error": generate_res.get("error"),
+            self.usage_metadata_key: generate_res.get("usage_metadata"),
+        }
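
When only the SQL is needed (for example to review it before running it, or to execute it elsewhere), this chain can be used on its own; with `should_validate_sql=True` the output also carries the `is_sql_valid` flag and any validation `error`. A minimal sketch under the same assumptions as the earlier examples:

```python
from langchain_timbr import GenerateTimbrSqlChain

generate_chain = GenerateTimbrSqlChain(
    llm=llm,                          # e.g. the LlmWrapper instance from the earlier sketches
    url="https://your-timbr-host/",
    token="tk_...",
    ontology="my_knowledge_graph",
    should_validate_sql=True,         # also populates "is_sql_valid" / "error" in the output
)

res = generate_chain.invoke({"prompt": "Top 10 products by revenue in 2024"})
if res.get("is_sql_valid"):
    print(res["sql"])
else:
    print("SQL generation failed:", res.get("error"))
```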