langchain-timbr 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain_timbr/__init__.py +17 -0
- langchain_timbr/config.py +21 -0
- langchain_timbr/langchain/__init__.py +16 -0
- langchain_timbr/langchain/execute_timbr_query_chain.py +307 -0
- langchain_timbr/langchain/generate_answer_chain.py +99 -0
- langchain_timbr/langchain/generate_timbr_sql_chain.py +176 -0
- langchain_timbr/langchain/identify_concept_chain.py +138 -0
- langchain_timbr/langchain/timbr_sql_agent.py +418 -0
- langchain_timbr/langchain/validate_timbr_sql_chain.py +187 -0
- langchain_timbr/langgraph/__init__.py +13 -0
- langchain_timbr/langgraph/execute_timbr_query_node.py +108 -0
- langchain_timbr/langgraph/generate_response_node.py +59 -0
- langchain_timbr/langgraph/generate_timbr_sql_node.py +98 -0
- langchain_timbr/langgraph/identify_concept_node.py +78 -0
- langchain_timbr/langgraph/validate_timbr_query_node.py +100 -0
- langchain_timbr/llm_wrapper/llm_wrapper.py +189 -0
- langchain_timbr/llm_wrapper/timbr_llm_wrapper.py +41 -0
- langchain_timbr/timbr_llm_connector.py +398 -0
- langchain_timbr/utils/general.py +70 -0
- langchain_timbr/utils/prompt_service.py +330 -0
- langchain_timbr/utils/temperature_supported_models.json +62 -0
- langchain_timbr/utils/timbr_llm_utils.py +575 -0
- langchain_timbr/utils/timbr_utils.py +475 -0
- langchain_timbr-1.5.0.dist-info/METADATA +103 -0
- langchain_timbr-1.5.0.dist-info/RECORD +27 -0
- langchain_timbr-1.5.0.dist-info/WHEEL +4 -0
- langchain_timbr-1.5.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,575 @@
|
|
|
1
|
+
from typing import Any, Optional
|
|
2
|
+
from langchain.llms.base import LLM
|
|
3
|
+
import base64, hashlib
|
|
4
|
+
from cryptography.fernet import Fernet
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
import concurrent.futures
|
|
7
|
+
import time
|
|
8
|
+
|
|
9
|
+
from .timbr_utils import get_datasources, get_tags, get_concepts, get_concept_properties, validate_sql, get_properties_description, get_relationships_description
|
|
10
|
+
from .prompt_service import (
|
|
11
|
+
get_determine_concept_prompt_template,
|
|
12
|
+
get_generate_sql_prompt_template,
|
|
13
|
+
get_qa_prompt_template
|
|
14
|
+
)
|
|
15
|
+
from ..config import llm_timeout
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _clean_snowflake_prompt(prompt: Any) -> None:
    """
    Escape the leading prompt messages in place so they survive inside a SQL string literal.

    Callers in this module invoke this when "snowflake" appears in ``llm._llm_type``.
    Only the first two messages (system and user) are rewritten; a prompt with
    fewer messages, or a message whose ``content`` is not a string, is left
    untouched instead of raising.

    Args:
        prompt: Sequence of message objects exposing a mutable ``content`` attribute
    """
    import re

    # Curly quotes -> straight ASCII quotes.
    curly_quotes = str.maketrans({
        '\u2019': "'",
        '\u2018': "'",
        '\u201c': '"',
        '\u201d': '"',
    })

    def clean_func(prompt_content: str) -> str:
        # 1. Normalize Windows/Mac line endings to '\n'
        raw = prompt_content.replace('\r\n', '\n').replace('\r', '\n')

        # 2. Collapse runs of newlines into a single '\n'
        raw = re.sub(r'\n{2,}', '\n', raw)

        # 3. Turn every real newline into the two characters backslash + n
        raw = raw.replace('\n', '\\n')

        # 4. Normalize curly quotes to straight ASCII
        raw = raw.translate(curly_quotes)

        # 5. Collapse double backticks into a single backtick (single pass)
        raw = raw.replace('``', '`')

        # 6. Double every backslash (including the ones produced in step 3)
        raw = raw.replace('\\', '\\\\')

        # 7. Escape single quotes for a SQL string literal
        raw = raw.replace("'", "''")

        # 8. Escape double quotes (must run after step 6 so the added
        #    backslash is not doubled)
        raw = raw.replace('"', '\\"')

        return raw

    # System message first, user message second; tolerate shorter prompts.
    for message in prompt[:2]:
        content = getattr(message, "content", None)
        if isinstance(content, str):
            message.content = clean_func(content)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def generate_key() -> bytes:
    """
    Return the fixed Fernet key used by this module.

    The key is not random: it is the URL-safe base64 encoding of the MD5 hex
    digest of a hard-coded passcode, so every call yields the same value.

    NOTE(review): because the passcode ships in the source, anything encrypted
    with this key is obfuscated rather than secret.
    """
    digest_hex = hashlib.md5(b"lucylit2025").hexdigest()
    return base64.urlsafe_b64encode(digest_hex.encode('utf-8'))
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _call_llm_with_timeout(llm: "LLM", prompt: Any, timeout: int = 60) -> Any:
    """
    Call LLM with timeout to prevent hanging.

    The call runs on a single worker thread. The executor is deliberately NOT
    used as a context manager: leaving a ``with`` block calls
    ``shutdown(wait=True)``, which would block until the LLM call finishes and
    defeat the timeout. On timeout the worker thread is abandoned (Python
    cannot interrupt it) and the caller gets control back immediately.

    Args:
        llm: The LLM instance (any callable accepting the prompt)
        prompt: The prompt to send
        timeout: Timeout in seconds (default: 60)

    Returns:
        LLM response

    Raises:
        TimeoutError: If the call takes longer than timeout seconds
        Exception: Any other exception from the LLM call
    """
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = executor.submit(llm, prompt)
    try:
        return future.result(timeout=timeout)
    except concurrent.futures.TimeoutError as exc:
        # No effect on a call that is already running, but prevents a
        # still-queued call from starting.
        future.cancel()
        raise TimeoutError(f"LLM call timed out after {timeout} seconds") from exc
    finally:
        # Release the executor without waiting for an in-flight call.
        executor.shutdown(wait=False)
|
|
92
|
+
|
|
93
|
+
# Default Fernet key, computed once at import time; used as the default `key` of encrypt_prompt.
ENCRYPT_KEY = generate_key()
# Prompt fragment describing how measure columns may be used; prefixed to the measure listings.
MEASURES_DESCRIPTION = "The following columns are calculated measures and can only be aggregated with an aggregate function: COUNT/SUM/AVG/MIN/MAX (count distinct is not allowed)"
# Prompt fragment explaining transitive relationship columns; added to the SQL prompt when a relationship is transitive.
# NOTE(review): the parenthesis opened before "Cannot be null or 0" is never closed in the text below.
TRANSITIVE_RELATIONSHIP_DESCRIPTION = "Transitive relationship columns allow you to access data through multiple relationship hops. These columns follow the pattern `<relationship_name>[<table_name>*<number>].<column_name>` where the number after the asterisk (*) indicates how many relationship levels to traverse. For example, `acquired_by[company*4].company_name` means 'go through up to 4 levels of the acquired_by relationship to get the company name', while columns ending with '_transitivity_level' indicate the actual relationship depth (Cannot be null or 0 - level 1 represents direct relationships, while levels 2, 3, 4, etc. represent indirect relationships through multiple hops. To filter by relationship type, use `_transitivity_level = 1` for direct relationships only, `_transitivity_level > 1` for indirect relationships only."
|
|
96
|
+
|
|
97
|
+
def encrypt_prompt(prompt: Any, key: Optional[bytes] = ENCRYPT_KEY) -> str:
    """
    Serialize & encrypt the prompt; returns a URL-safe token.

    Args:
        prompt: A string, a list of messages, or any other object. List items
            with a ``content`` attribute are rendered as ``"<type>: <content>"``,
            other items with ``str()``; items are joined with newlines.
        key: Fernet key; ``None`` falls back to the module default ENCRYPT_KEY

    Returns:
        The Fernet token decoded to ``str`` (not ``bytes``)
    """
    if key is None:
        # The parameter is Optional, so treat an explicit None like the default.
        key = ENCRYPT_KEY

    if isinstance(prompt, str):
        text = prompt
    elif isinstance(prompt, list):
        text = "\n".join(
            f"{message.type}: {message.content}" if hasattr(message, "content") else str(message)
            for message in prompt
        )
    else:
        text = str(prompt)

    return Fernet(key).encrypt(text.encode()).decode('utf-8')
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def decrypt_prompt(token: bytes, key: bytes) -> str:
    """
    Decrypt a Fernet token and return the plaintext as a string.

    Args:
        token: The encrypted token
        key: The Fernet key the token was encrypted with

    Returns:
        The decrypted bytes decoded with the default codec
    """
    cipher = Fernet(key)
    plaintext = cipher.decrypt(token)
    return plaintext.decode()
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _prompt_to_string(prompt: Any) -> str:
    """
    Flatten a prompt into a single stripped string.

    Args:
        prompt: A string, a list of messages, or any other object. List items
            with a ``content`` attribute are rendered as ``"<type>: <content>"``,
            other items with ``str()``.

    Returns:
        The text with leading/trailing whitespace removed. List items are
        separated by a newline, so items without ``content`` no longer run
        into their neighbours.
    """
    if isinstance(prompt, str):
        return prompt.strip()

    if isinstance(prompt, list):
        # f-string formatting also copes with non-string content (e.g. a list
        # of content parts), which plain '+' concatenation would reject.
        parts = [
            f"{message.type}: {message.content}" if hasattr(message, "content") else str(message)
            for message in prompt
        ]
        return "\n".join(parts).strip()

    return str(prompt).strip()
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _calculate_token_count(llm: "LLM", prompt: Any) -> int:
    """
    Calculate the token count for a given prompt using the specified LLM.
    Falls back to tiktoken if the LLM doesn't support token counting.

    This is a best-effort estimate: every failure is swallowed on purpose so
    that token accounting can never break the actual LLM call.

    Args:
        llm: The LLM instance; ``get_num_tokens_from_messages`` is used when present
        prompt: A prompt string or a list of messages

    Returns:
        The token count, or 0 when neither counting method succeeds
    """
    token_count = 0
    try:
        if hasattr(llm, "get_num_tokens_from_messages"):
            # Treat a None/0 answer as "unknown" so the fallback below runs
            # and the function always returns an int.
            token_count = llm.get_num_tokens_from_messages(prompt) or 0
    except Exception:
        token_count = 0

    if token_count:
        return token_count

    # Use tiktoken as fallback if the primary method produced nothing
    try:
        import tiktoken
        encoding = tiktoken.get_encoding("cl100k_base")
        prompt_text = prompt if isinstance(prompt, str) else _prompt_to_string(prompt)
        return len(encoding.encode(prompt_text))
    except Exception:
        return 0
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def determine_concept(
    question: str,
    llm: LLM,
    conn_params: dict,
    concepts_list: Optional[list] = None,
    views_list: Optional[list] = None,
    include_logic_concepts: Optional[bool] = False,
    include_tags: Optional[str] = None,
    should_validate: Optional[bool] = False,
    retries: Optional[int] = 3,
    note: Optional[str] = '',
    debug: Optional[bool] = False,
    timeout: Optional[int] = None,
) -> dict[str, Any]:
    """
    Pick the concept (or view) that should answer the question.

    If exactly one concept/view name is supplied (and it is not the literal
    'none'/'null'), it is used directly. Otherwise the LLM is asked to choose
    from the available concepts, retrying up to ``retries`` times.

    Args:
        question: The user question
        llm: The LLM instance
        conn_params: Connection parameters; "token", "is_jwt" and "jwt_tenant_id" are read here
        concepts_list: Optional concept names forwarded to get_concepts
        views_list: Optional view names forwarded to get_concepts
        include_logic_concepts: Forwarded to get_concepts
        include_tags: Forwarded to get_tags
        should_validate: When true, an LLM answer that is not a known concept triggers a retry
        retries: Maximum number of LLM attempts
        note: Extra text appended to the prompt
        debug: When true, the encrypted prompt is stored under "p_hash" in the usage metadata
        timeout: LLM timeout in seconds; defaults to the configured llm_timeout

    Returns:
        Dict with "concept_metadata", "concept", "schema" ('vtimbr' for views,
        otherwise 'dtimbr') and "usage_metadata"

    Raises:
        Exception: If no concepts are available, the single supplied concept is
            unknown, or all LLM attempts ended with an error
        ValueError: If the LLM response is neither a string nor an object with ``content``
    """
    usage_metadata = {}
    determined_concept_name = None
    schema = 'dtimbr'

    # Use config default timeout if none provided
    if timeout is None:
        timeout = llm_timeout

    determine_concept_prompt = get_determine_concept_prompt_template(conn_params["token"], conn_params["is_jwt"], conn_params["jwt_tenant_id"])
    tags = get_tags(conn_params=conn_params, include_tags=include_tags)
    concepts = get_concepts(
        conn_params=conn_params,
        concepts_list=concepts_list,
        views_list=views_list,
        include_logic_concepts=include_logic_concepts,
    )

    if not concepts:
        raise Exception("No relevant concepts found for the query.")

    # Tolerate a missing tag section instead of failing on None.get(...)
    concept_tags_by_name = (tags or {}).get('concept_tags') or {}
    view_tags_by_name = (tags or {}).get('view_tags') or {}

    concepts_desc_arr = []
    for concept in concepts.values():
        concept_name = concept.get('concept')
        concept_desc = concept.get('description')
        tag_source = concept_tags_by_name if concept.get('is_view') == 'false' else view_tags_by_name
        concept_tags = tag_source.get(concept_name)

        if concept_tags:
            # Flatten the dict repr into plain "key: value" text
            concept_tags = str(concept_tags).replace('{', '').replace('}', '').replace("'", '')

        concept_verbose = f"`{concept_name}`"
        if concept_desc:
            concept_verbose += f" (description: {concept_desc})"
        if concept_tags:
            concept_verbose += f" [tags: {concept_tags}]"
            # Stored on the concept metadata so it travels with the result
            concepts[concept_name]['tags'] = f"- Annotations and constraints: {concept_tags}\n"

        concepts_desc_arr.append(concept_verbose)

    # Both lists default to None, so normalise before concatenating
    combined_list = (concepts_list or []) + (views_list or [])

    if len(combined_list) == 1 and combined_list[0].lower() not in ('none', 'null'):
        # If only one concept is provided, return it directly
        determined_concept_name = combined_list[0]

        if determined_concept_name not in concepts:
            raise Exception(f"'{determined_concept_name}' was not found in the ontology.")

    else:
        # Use LLM to determine the concept based on the question
        iteration = 0
        error = ''
        while determined_concept_name is None and iteration < retries:
            iteration += 1
            err_txt = f"\nLast try got an error: {error}" if error else ""
            prompt = determine_concept_prompt.format_messages(
                question=question.strip(),
                concepts=",".join(concepts_desc_arr),
                note=(note or '') + err_txt,
            )

            # Count tokens before the Snowflake escaping alters the prompt
            apx_token_count = _calculate_token_count(llm, prompt)
            if "snowflake" in llm._llm_type:
                _clean_snowflake_prompt(prompt)

            try:
                response = _call_llm_with_timeout(llm, prompt, timeout=timeout)
            except TimeoutError as e:
                error = f"LLM call timed out: {str(e)}"
                continue
            except Exception as e:
                error = f"LLM call failed: {str(e)}"
                continue

            usage_metadata['determine_concept'] = {
                "approximate": apx_token_count,
                # Plain-string responses carry no usage_metadata attribute
                **(getattr(response, "usage_metadata", None) or {}),
            }
            if debug:
                usage_metadata['determine_concept']["p_hash"] = encrypt_prompt(prompt)

            if hasattr(response, "content"):
                response_text = response.content
            elif isinstance(response, str):
                response_text = response
            else:
                raise ValueError("Unexpected response format from LLM.")

            candidate = response_text.strip()
            if should_validate and candidate not in concepts:
                # Report the rejected answer itself so the retry prompt is useful
                error = f"Concept '{candidate}' not found in the list of concepts."
                continue

            determined_concept_name = candidate
            error = ''

        if determined_concept_name is None and error != '':
            raise Exception(f"Failed to determine concept: {error}")

    concept_metadata = concepts.get(determined_concept_name) if determined_concept_name else None
    if determined_concept_name:
        # Without validation the LLM may name a concept that is not in the
        # ontology; fall back to the default schema rather than crashing.
        schema = 'vtimbr' if (concept_metadata or {}).get('is_view') == 'true' else 'dtimbr'

    return {
        "concept_metadata": concept_metadata,
        "concept": determined_concept_name,
        "schema": schema,
        "usage_metadata": usage_metadata,
    }
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def _build_columns_str(
    columns: list[dict],
    columns_tags: Optional[dict] = None,
    exclude: Optional[list] = None,
) -> str:
    """
    Render column metadata as a comma-separated description string for the prompt.

    Each column becomes ``"`<name>` (type: ..., description: ..., annotations and
    constraints: ...)"``; the description and annotations parts are omitted
    when empty.

    Args:
        columns: Column dicts; the keys read are "name", "col_name",
            "description", "comment" and "data_type"
        columns_tags: Mapping of column name (without the "measure." prefix)
            to its tags; ``None`` is treated as empty
        exclude: Column names to skip; a column is also skipped when its name
            ends with ``"." + <excluded name>``

    Returns:
        The joined descriptions, or '' when nothing is left to describe
    """
    columns_tags = columns_tags or {}
    columns_desc_arr = []
    for col in columns or []:
        full_name = col.get('name') or col.get('col_name')  # When rel column, it can be `relationship_name[column_name]`

        # Strip only the leading "measure." marker; a later occurrence is
        # part of the actual column name.
        col_name = (col.get('col_name') or '').removeprefix("measure.")

        if exclude and (col_name in exclude or any(col_name.endswith('.' + exc) for exc in exclude)):
            continue

        col_tags = None
        raw_tags = columns_tags.get(col_name)
        if raw_tags:
            # Flatten the dict repr into "key - value. key - value" text
            col_tags = str(raw_tags).replace('{', '').replace('}', '').replace("'", '').replace(": ", " - ").replace(",", ". ").strip()

        description = col.get('description') or col.get('comment', '')

        # A missing, None or empty data type is reported as 'string'
        data_type = (col.get('data_type') or 'string').lower()

        col_meta = [f"type: {data_type}"]
        if description:
            col_meta.append(f"description: {description}")
        if col_tags:
            col_meta.append(f"annotations and constraints: {col_tags}")

        columns_desc_arr.append(f"`{full_name}` ({', '.join(col_meta)})")

    return ", ".join(columns_desc_arr)
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def _build_rel_columns_str(relationships: dict[str, dict], columns_tags: Optional[dict] = None, exclude_properties: Optional[list] = None) -> str:
    """
    Describe relationship columns and measures as prompt bullet lines.

    Args:
        relationships: Mapping of relationship name to its metadata; the keys
            read are "description", "columns" and "measures"
        columns_tags: Column-name -> tags mapping passed to _build_columns_str;
            ``None`` is treated as empty
        exclude_properties: Column names to leave out

    Returns:
        One "- ..." line per relationship column group / measure group, joined
        with ".\\n", or '' when there is nothing to describe. A group whose
        columns are all excluded is skipped rather than emitted empty.
    """
    if not relationships:
        return ''

    columns_tags = columns_tags or {}
    rel_str_arr = []
    for rel_name, rel in relationships.items():
        rel_description = rel.get('description', '')
        rel_description = f" which described as \"{rel_description}\"" if rel_description else ""
        rel_columns = rel.get('columns', [])
        rel_measures = rel.get('measures', [])

        if rel_columns:
            joined_columns_str = _build_columns_str(rel_columns, columns_tags=columns_tags, exclude=exclude_properties)
            if joined_columns_str:
                rel_str_arr.append(f"- The following columns are part of {rel_name} relationship{rel_description}, and must be used as is wrapped with quotes: {joined_columns_str}")
        if rel_measures:
            joined_measures_str = _build_columns_str(rel_measures, columns_tags=columns_tags, exclude=exclude_properties)
            if joined_measures_str:
                rel_str_arr.append(f"- {MEASURES_DESCRIPTION}, are part of {rel_name} relationship{rel_description}: {joined_measures_str}")

    return '.\n'.join(rel_str_arr)
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def _parse_sql_from_llm_response(response: Any) -> str:
    """
    Extract the bare SQL text from an LLM response.

    Markdown code fences are removed (the "sql" language tag is matched
    case-insensitively), a line break directly after ``SELECT `` is folded
    away, every semicolon is dropped and the result is stripped.

    Args:
        response: A string or an object exposing ``content``

    Returns:
        The cleaned SQL string

    Raises:
        ValueError: If the response is neither a string nor has ``content``
    """
    import re

    if hasattr(response, "content"):
        response_text = response.content
    elif isinstance(response, str):
        response_text = response
    else:
        raise ValueError("Unexpected response format from LLM.")

    # "```SQL" and "```Sql" are as common as "```sql"; strip them all so the
    # tag is not left behind as a stray token in the query.
    without_tagged_fence = re.sub(r"```sql", "", response_text, flags=re.IGNORECASE)

    # NOTE(review): this removes semicolons everywhere, including inside
    # string literals - kept as is because callers may rely on it.
    return (without_tagged_fence
            .replace("```", "")
            .replace('SELECT \n', 'SELECT ')
            .replace(';', '')
            .strip())
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
def _get_active_datasource(conn_params: dict) -> Optional[dict]:
    """
    Return the first active datasource for the connection.

    Args:
        conn_params: Connection parameters forwarded to get_datasources

    Returns:
        The first entry returned by ``get_datasources(..., filter_active=True)``,
        or ``None`` when there is none - callers must handle the ``None`` case.
    """
    datasources = get_datasources(conn_params, filter_active=True)
    if not datasources:
        return None
    return datasources[0]
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def generate_sql(
    question: str,
    llm: LLM,
    conn_params: dict,
    concept: str,
    schema: Optional[str] = None,
    concepts_list: Optional[list] = None,
    views_list: Optional[list] = None,
    include_logic_concepts: Optional[bool] = False,
    include_tags: Optional[str] = None,
    exclude_properties: Optional[list] = None,
    should_validate_sql: Optional[bool] = False,
    retries: Optional[int] = 3,
    max_limit: Optional[int] = 500,
    note: Optional[str] = '',
    db_is_case_sensitive: Optional[bool] = False,
    graph_depth: Optional[int] = 1,
    debug: Optional[bool] = False,
    timeout: Optional[int] = None,
) -> dict[str, Any]:
    """
    Generate a SQL query for the question against the determined concept.

    The concept is resolved via determine_concept, its columns, measures and
    relationships are rendered into the prompt, and the LLM is asked for SQL.
    With ``should_validate_sql`` the query is validated and regenerated (with
    the previous error fed back) for at most ``retries`` attempts in total;
    without it a single attempt is made and LLM failures are raised.

    Args:
        question: The user question
        llm: The LLM instance
        conn_params: Connection parameters; "token", "is_jwt" and "jwt_tenant_id" are read here
        concept: Concept (or view) name; when set it replaces concepts_list/views_list
        schema: 'vtimbr' marks ``concept`` as a view
        concepts_list: Optional concept names forwarded to determine_concept
        views_list: Optional view names forwarded to determine_concept
        include_logic_concepts: Forwarded to determine_concept
        include_tags: Forwarded to determine_concept and get_tags
        exclude_properties: Column names left out of the prompt
        should_validate_sql: Validate the generated SQL and retry on failure
        retries: Maximum number of attempts when validation is enabled
        max_limit: Passed to the prompt template as ``max_limit``
        note: Extra text appended to the prompt
        db_is_case_sensitive: Adds a case-insensitive comparison hint to the prompt
        graph_depth: Forwarded to get_concept_properties
        debug: When true, the encrypted prompt is stored under "p_hash" in the usage metadata
        timeout: LLM timeout in seconds; defaults to the configured llm_timeout

    Returns:
        Dict with "sql", "concept", "schema", "error" (set only when the SQL is
        invalid), "is_sql_valid" (None when validation is off) and "usage_metadata"

    Raises:
        Exception: If no concept is found, the LLM call fails without
            validation enabled, or every attempt failed before producing SQL
    """
    usage_metadata = {}
    concept_metadata = None

    # Use config default timeout if none provided
    if timeout is None:
        timeout = llm_timeout

    generate_sql_prompt = get_generate_sql_prompt_template(conn_params["token"], conn_params["is_jwt"], conn_params["jwt_tenant_id"])

    # An explicit concept narrows the search to that single concept or view
    if concept:
        if schema == "vtimbr":
            views_list = [concept]
        else:
            concepts_list = [concept]

    determine_concept_res = determine_concept(
        question=question,
        llm=llm,
        conn_params=conn_params,
        concepts_list=concepts_list,
        views_list=views_list,
        include_logic_concepts=include_logic_concepts,
        include_tags=include_tags,
        should_validate=should_validate_sql,
        retries=retries,
        note=note,
        debug=debug,
        timeout=timeout,
    )
    concept = determine_concept_res.get('concept')
    schema = determine_concept_res.get('schema')
    concept_metadata = determine_concept_res.get('concept_metadata')
    usage_metadata.update(determine_concept_res.get('usage_metadata', {}))

    if not concept:
        raise Exception("No relevant concept found for the query.")

    # No active datasource -> fall back to the generic dialect below
    datasource_type = (_get_active_datasource(conn_params) or {}).get('target_type')

    properties_desc = get_properties_description(conn_params=conn_params)
    relationships_desc = get_relationships_description(conn_params=conn_params)

    concept_properties_metadata = get_concept_properties(schema=schema, concept_name=concept, conn_params=conn_params, properties_desc=properties_desc, relationships_desc=relationships_desc, graph_depth=graph_depth)
    columns = concept_properties_metadata.get('columns', [])
    measures = concept_properties_metadata.get('measures', [])
    relationships = concept_properties_metadata.get('relationships', {})
    tags = get_tags(conn_params=conn_params, include_tags=include_tags).get('property_tags') or {}

    columns_str = _build_columns_str(columns, columns_tags=tags, exclude=exclude_properties)
    measures_str = _build_columns_str(measures, tags, exclude=exclude_properties)
    rel_prop_str = _build_rel_columns_str(relationships, columns_tags=tags, exclude_properties=exclude_properties)

    if rel_prop_str:
        measures_str += f"\n{rel_prop_str}"

    # Prompt fragments that do not change between attempts
    sensitivity_txt = "- Ensure value comparisons are case-insensitive, e.g., use LOWER(column) = 'value'.\n" if db_is_case_sensitive else ""
    measures_context = f"- {MEASURES_DESCRIPTION}: {measures_str}\n" if measures_str else ""
    has_transitive_relationships = any(
        rel.get('is_transitive')
        for rel in relationships.values()
    ) if relationships else False
    transitive_context = f"- {TRANSITIVE_RELATIONSHIP_DESCRIPTION}\n" if has_transitive_relationships else ""
    concept_description = f"- Description: {concept_metadata.get('description')}\n" if concept_metadata and concept_metadata.get('description') else ""
    concept_tags = concept_metadata.get('tags') if concept_metadata and concept_metadata.get('tags') else ""
    cur_date = datetime.now().strftime("%Y-%m-%d")

    sql_query = None
    iteration = 0
    is_sql_valid = True
    error = ''

    # Always try once; retry only when validation is enabled. Bounding the
    # loop by attempts (not by "sql_query is None") prevents an endless loop
    # when every LLM call fails.
    max_attempts = max(retries or 0, 1) if should_validate_sql else 1
    while iteration < max_attempts and (sql_query is None or (should_validate_sql and not is_sql_valid)):
        iteration += 1

        err_txt = ""
        if error and "snowflake" not in llm._llm_type:
            if sql_query is not None:
                err_txt = f"\nThe original SQL (`{sql_query}`) was invalid with error: {error}. Please generate a corrected query."
            else:
                # The previous attempt failed before any SQL was produced
                err_txt = f"\nLast try got an error: {error}"

        prompt = generate_sql_prompt.format_messages(
            current_date=cur_date,
            datasource_type=datasource_type or 'standard sql',
            schema=schema,
            concept=f"`{concept}`",
            description=concept_description or "",
            tags=concept_tags or "",
            question=question,
            columns=columns_str,
            measures_context=measures_context,
            transitive_context=transitive_context,
            sensitivity_context=sensitivity_txt,
            max_limit=max_limit,
            note=(note or '') + err_txt,
        )

        # Count tokens before the Snowflake escaping alters the prompt
        apx_token_count = _calculate_token_count(llm, prompt)
        if "snowflake" in llm._llm_type:
            _clean_snowflake_prompt(prompt)

        try:
            response = _call_llm_with_timeout(llm, prompt, timeout=timeout)
        except TimeoutError as e:
            error = f"LLM call timed out: {str(e)}"
            if should_validate_sql:
                continue
            raise Exception(error)
        except Exception as e:
            error = f"LLM call failed: {str(e)}"
            if should_validate_sql:
                continue
            raise Exception(error)

        usage_metadata['generate_sql'] = {
            "approximate": apx_token_count,
            # Plain-string responses carry no usage_metadata attribute
            **(getattr(response, "usage_metadata", None) or {}),
        }
        if debug:
            usage_metadata['generate_sql']["p_hash"] = encrypt_prompt(prompt)

        sql_query = _parse_sql_from_llm_response(response)

        if should_validate_sql:
            is_sql_valid, error = validate_sql(sql_query, conn_params)

    if sql_query is None:
        # Every attempt failed before the LLM returned anything
        raise Exception(f"Failed to generate SQL: {error}")

    return {
        "sql": sql_query,
        "concept": concept,
        "schema": schema,
        "error": error if not is_sql_valid else None,
        "is_sql_valid": is_sql_valid if should_validate_sql else None,
        "usage_metadata": usage_metadata,
    }
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
def answer_question(
    question: str,
    llm: LLM,
    conn_params: dict,
    results: str,
    sql: Optional[str] = None,
    debug: Optional[bool] = False,
    timeout: Optional[int] = None,
) -> dict[str, Any]:
    """
    Ask the LLM to answer the question from the query results.

    Args:
        question: The user question
        llm: The LLM instance
        conn_params: Connection parameters; "token", "is_jwt" and "jwt_tenant_id" are read here
        results: The query results, passed to the prompt as ``formatted_rows``
        sql: Optional SQL query, added to the prompt as extra context
        debug: When true, the encrypted prompt is stored under "p_hash" in the usage metadata
        timeout: LLM timeout in seconds; defaults to the configured llm_timeout

    Returns:
        Dict with "answer" (the response text) and "usage_metadata"

    Raises:
        TimeoutError: If the LLM call times out
        Exception: If the LLM call fails for any other reason
        ValueError: If the response is neither a string nor has ``content``
    """
    # Use config default timeout if none provided
    if timeout is None:
        timeout = llm_timeout

    qa_prompt = get_qa_prompt_template(conn_params["token"], conn_params["is_jwt"], conn_params["jwt_tenant_id"])

    prompt = qa_prompt.format_messages(
        question=question,
        formatted_rows=results,
        additional_context=f"SQL QUERY:\n{sql}\n\n" if sql else "",
    )

    # Count tokens before the Snowflake escaping alters the prompt
    apx_token_count = _calculate_token_count(llm, prompt)

    if "snowflake" in llm._llm_type:
        _clean_snowflake_prompt(prompt)

    try:
        response = _call_llm_with_timeout(llm, prompt, timeout=timeout)
    except TimeoutError as e:
        raise TimeoutError(f"LLM call timed out while answering question: {str(e)}") from e
    except Exception as e:
        raise Exception(f"LLM call failed while answering question: {str(e)}") from e

    if hasattr(response, "content"):
        response_text = response.content
    elif isinstance(response, str):
        response_text = response
    else:
        raise ValueError("Unexpected response format from LLM.")

    answer_usage = {
        "approximate": apx_token_count,
        # Plain-string responses are accepted above but carry no
        # usage_metadata attribute, so read it defensively.
        **(getattr(response, "usage_metadata", None) or {}),
    }
    if debug:
        answer_usage["p_hash"] = encrypt_prompt(prompt)

    return {
        "answer": response_text,
        "usage_metadata": {"answer_question": answer_usage},
    }
|
|
575
|
+
|