langchain-timbr 2.1.5__py3-none-any.whl → 2.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,7 +12,7 @@
12
12
  from .timbr_llm_connector import TimbrLlmConnector
13
13
  from .llm_wrapper.llm_wrapper import LlmWrapper, LlmTypes
14
14
 
15
- from .utils.timbr_llm_utils import (
15
+ from .utils.timbr_utils import (
16
16
  generate_key,
17
17
  encrypt_prompt,
18
18
  decrypt_prompt,
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '2.1.5'
32
- __version_tuple__ = version_tuple = (2, 1, 5)
31
+ __version__ = version = '2.1.7'
32
+ __version_tuple__ = version_tuple = (2, 1, 7)
33
33
 
34
34
  __commit_id__ = commit_id = None
langchain_timbr/config.py CHANGED
@@ -27,4 +27,8 @@ llm_client_id = os.environ.get('LLM_CLIENT_ID', None)
27
27
  llm_client_secret = os.environ.get('LLM_CLIENT_SECRET', None)
28
28
  llm_endpoint = os.environ.get('LLM_ENDPOINT', None)
29
29
  llm_api_version = os.environ.get('LLM_API_VERSION', None)
30
- llm_scope = os.environ.get('LLM_SCOPE', "https://cognitiveservices.azure.com/.default") # e.g. "api://<your-client-id>/.default"
30
+ llm_scope = os.environ.get('LLM_SCOPE', "https://cognitiveservices.azure.com/.default") # e.g. "api://<your-client-id>/.default"
31
+
32
+ # Whether to enable reasoning during SQL generation
33
+ enable_reasoning = to_boolean(os.environ.get('ENABLE_REASONING', 'false'))
34
+ reasoning_steps = to_integer(os.environ.get('REASONING_STEPS', 2))
@@ -42,6 +42,8 @@ class ExecuteTimbrQueryChain(Chain):
42
42
  is_jwt: Optional[bool] = False,
43
43
  jwt_tenant_id: Optional[str] = None,
44
44
  conn_params: Optional[dict] = None,
45
+ enable_reasoning: Optional[bool] = config.enable_reasoning,
46
+ reasoning_steps: Optional[int] = config.reasoning_steps,
45
47
  debug: Optional[bool] = False,
46
48
  **kwargs,
47
49
  ):
@@ -69,6 +71,8 @@ class ExecuteTimbrQueryChain(Chain):
69
71
  :param is_jwt: Whether to use JWT authentication (default is False).
70
72
  :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
71
73
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
74
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
75
+ :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
72
76
  :param kwargs: Additional arguments to pass to the base
73
77
  :return: A list of rows from the Timbr query
74
78
 
@@ -137,6 +141,8 @@ class ExecuteTimbrQueryChain(Chain):
137
141
  self._jwt_tenant_id = jwt_tenant_id
138
142
  self._debug = to_boolean(debug)
139
143
  self._conn_params = conn_params or {}
144
+ self._enable_reasoning = to_boolean(enable_reasoning)
145
+ self._reasoning_steps = to_integer(reasoning_steps)
140
146
 
141
147
 
142
148
  @property
@@ -209,6 +215,8 @@ class ExecuteTimbrQueryChain(Chain):
209
215
  note=(self._note or '') + err_txt,
210
216
  db_is_case_sensitive=self._db_is_case_sensitive,
211
217
  graph_depth=self._graph_depth,
218
+ enable_reasoning=self._enable_reasoning,
219
+ reasoning_steps=self._reasoning_steps,
212
220
  debug=self._debug,
213
221
  )
214
222
 
@@ -239,6 +247,7 @@ class ExecuteTimbrQueryChain(Chain):
239
247
  concept_name = inputs.get("concept", self._concept)
240
248
  is_sql_valid = True
241
249
  error = None
250
+ reasoning_status = None
242
251
  usage_metadata = {}
243
252
 
244
253
  if sql and self._should_validate_sql:
@@ -255,6 +264,7 @@ class ExecuteTimbrQueryChain(Chain):
255
264
  schema_name = generate_res.get("schema", schema_name)
256
265
  concept_name = generate_res.get("concept", concept_name)
257
266
  is_sql_valid = generate_res.get("is_sql_valid")
267
+ reasoning_status = generate_res.get("reasoning_status")
258
268
  if not is_sql_valid and not self._should_validate_sql:
259
269
  is_sql_valid = True
260
270
 
@@ -293,6 +303,7 @@ class ExecuteTimbrQueryChain(Chain):
293
303
  "schema": schema_name,
294
304
  "concept": concept_name,
295
305
  "error": error if not is_sql_valid else None,
306
+ "reasoning_status": reasoning_status,
296
307
  self.usage_metadata_key: usage_metadata,
297
308
  }
298
309
 
@@ -39,6 +39,8 @@ class GenerateTimbrSqlChain(Chain):
39
39
  is_jwt: Optional[bool] = False,
40
40
  jwt_tenant_id: Optional[str] = None,
41
41
  conn_params: Optional[dict] = None,
42
+ enable_reasoning: Optional[bool] = config.enable_reasoning,
43
+ reasoning_steps: Optional[int] = config.reasoning_steps,
42
44
  debug: Optional[bool] = False,
43
45
  **kwargs,
44
46
  ):
@@ -64,6 +66,9 @@ class GenerateTimbrSqlChain(Chain):
64
66
  :param is_jwt: Whether to use JWT authentication (default is False).
65
67
  :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
66
68
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
69
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
70
+ :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
71
+ :param debug: Whether to enable debug mode for detailed logging
67
72
  :param kwargs: Additional arguments to pass to the base
68
73
 
69
74
  ## Example
@@ -129,6 +134,8 @@ class GenerateTimbrSqlChain(Chain):
129
134
  self._jwt_tenant_id = jwt_tenant_id
130
135
  self._debug = to_boolean(debug)
131
136
  self._conn_params = conn_params or {}
137
+ self._enable_reasoning = to_boolean(enable_reasoning)
138
+ self._reasoning_steps = to_integer(reasoning_steps)
132
139
 
133
140
 
134
141
  @property
@@ -184,6 +191,8 @@ class GenerateTimbrSqlChain(Chain):
184
191
  note=self._note,
185
192
  db_is_case_sensitive=self._db_is_case_sensitive,
186
193
  graph_depth=self._graph_depth,
194
+ enable_reasoning=self._enable_reasoning,
195
+ reasoning_steps=self._reasoning_steps,
187
196
  debug=self._debug,
188
197
  )
189
198
 
@@ -197,5 +206,6 @@ class GenerateTimbrSqlChain(Chain):
197
206
  "concept": concept,
198
207
  "is_sql_valid": generate_res.get("is_sql_valid"),
199
208
  "error": generate_res.get("error"),
209
+ "reasoning_status": generate_res.get("reasoning_status"),
200
210
  self.usage_metadata_key: generate_res.get("usage_metadata"),
201
211
  }
@@ -6,6 +6,7 @@ from langchain.schema import AgentAction, AgentFinish
6
6
  from ..utils.general import parse_list, to_boolean, to_integer
7
7
  from .execute_timbr_query_chain import ExecuteTimbrQueryChain
8
8
  from .generate_answer_chain import GenerateAnswerChain
9
+ from .. import config
9
10
 
10
11
  class TimbrSqlAgent(BaseSingleActionAgent):
11
12
  def __init__(
@@ -34,6 +35,8 @@ class TimbrSqlAgent(BaseSingleActionAgent):
34
35
  is_jwt: Optional[bool] = False,
35
36
  jwt_tenant_id: Optional[str] = None,
36
37
  conn_params: Optional[dict] = None,
38
+ enable_reasoning: Optional[bool] = config.enable_reasoning,
39
+ reasoning_steps: Optional[int] = config.reasoning_steps,
37
40
  debug: Optional[bool] = False
38
41
  ):
39
42
  """
@@ -61,6 +64,8 @@ class TimbrSqlAgent(BaseSingleActionAgent):
61
64
  :param is_jwt: Whether to use JWT authentication (default is False).
62
65
  :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
63
66
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
67
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
68
+ :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
64
69
 
65
70
  ## Example
66
71
  ```
@@ -113,6 +118,8 @@ class TimbrSqlAgent(BaseSingleActionAgent):
113
118
  is_jwt=to_boolean(is_jwt),
114
119
  jwt_tenant_id=jwt_tenant_id,
115
120
  conn_params=conn_params,
121
+ enable_reasoning=to_boolean(enable_reasoning),
122
+ reasoning_steps=to_integer(reasoning_steps),
116
123
  debug=to_boolean(debug),
117
124
  )
118
125
  self._generate_answer = to_boolean(generate_answer)
@@ -173,6 +180,7 @@ class TimbrSqlAgent(BaseSingleActionAgent):
173
180
  "sql": None,
174
181
  "schema": None,
175
182
  "concept": None,
183
+ "reasoning_status": None,
176
184
  "usage_metadata": {},
177
185
  },
178
186
  log="Empty input received"
@@ -200,6 +208,7 @@ class TimbrSqlAgent(BaseSingleActionAgent):
200
208
  "schema": result.get("schema", ""),
201
209
  "concept": result.get("concept", ""),
202
210
  "error": result.get("error", None),
211
+ "reasoning_status": result.get("reasoning_status", None),
203
212
  "usage_metadata": usage_metadata,
204
213
  },
205
214
  log=f"Successfully executed query on concept: {result.get('concept', '')}"
@@ -214,6 +223,7 @@ class TimbrSqlAgent(BaseSingleActionAgent):
214
223
  "sql": None,
215
224
  "schema": None,
216
225
  "concept": None,
226
+ "reasoning_status": None,
217
227
  "usage_metadata": {},
218
228
  },
219
229
  log=error_context
@@ -234,6 +244,7 @@ class TimbrSqlAgent(BaseSingleActionAgent):
234
244
  "sql": None,
235
245
  "schema": None,
236
246
  "concept": None,
247
+ "reasoning_status": None,
237
248
  "usage_metadata": {},
238
249
  },
239
250
  log="Empty or whitespace-only input received"
@@ -274,6 +285,7 @@ class TimbrSqlAgent(BaseSingleActionAgent):
274
285
  "schema": result.get("schema", ""),
275
286
  "concept": result.get("concept", ""),
276
287
  "error": result.get("error", None),
288
+ "reasoning_status": result.get("reasoning_status", None),
277
289
  "usage_metadata": usage_metadata,
278
290
  },
279
291
  log=f"Successfully executed query on concept: {result.get('concept', '')}"
@@ -288,6 +300,7 @@ class TimbrSqlAgent(BaseSingleActionAgent):
288
300
  "sql": None,
289
301
  "schema": None,
290
302
  "concept": None,
303
+ "reasoning_status": None,
291
304
  "usage_metadata": {},
292
305
  },
293
306
  log=error_context
@@ -332,6 +345,8 @@ def create_timbr_sql_agent(
332
345
  is_jwt: Optional[bool] = False,
333
346
  jwt_tenant_id: Optional[str] = None,
334
347
  conn_params: Optional[dict] = None,
348
+ enable_reasoning: Optional[bool] = config.enable_reasoning,
349
+ reasoning_steps: Optional[int] = config.reasoning_steps,
335
350
  debug: Optional[bool] = False
336
351
  ) -> AgentExecutor:
337
352
  """
@@ -361,6 +376,8 @@ def create_timbr_sql_agent(
361
376
  :param is_jwt: Whether to use JWT authentication (default is False).
362
377
  :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
363
378
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
379
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
380
+ :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
364
381
 
365
382
  Returns:
366
383
  AgentExecutor: Configured agent executor ready to use
@@ -427,6 +444,8 @@ def create_timbr_sql_agent(
427
444
  is_jwt=is_jwt,
428
445
  jwt_tenant_id=jwt_tenant_id,
429
446
  conn_params=conn_params,
447
+ enable_reasoning=enable_reasoning,
448
+ reasoning_steps=reasoning_steps,
430
449
  debug=debug,
431
450
  )
432
451
 
@@ -40,6 +40,8 @@ class ValidateTimbrSqlChain(Chain):
40
40
  is_jwt: Optional[bool] = False,
41
41
  jwt_tenant_id: Optional[str] = None,
42
42
  conn_params: Optional[dict] = None,
43
+ enable_reasoning: Optional[bool] = config.enable_reasoning,
44
+ reasoning_steps: Optional[int] = config.reasoning_steps,
43
45
  debug: Optional[bool] = False,
44
46
  **kwargs,
45
47
  ):
@@ -64,6 +66,8 @@ class ValidateTimbrSqlChain(Chain):
64
66
  :param is_jwt: Whether to use JWT authentication (default is False).
65
67
  :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
66
68
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
69
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
70
+ :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
67
71
  :param kwargs: Additional arguments to pass to the base
68
72
 
69
73
  ## Example
@@ -127,8 +131,10 @@ class ValidateTimbrSqlChain(Chain):
127
131
  self._verify_ssl = to_boolean(verify_ssl)
128
132
  self._is_jwt = to_boolean(is_jwt)
129
133
  self._jwt_tenant_id = jwt_tenant_id
130
- self._debug = to_boolean(debug)
131
134
  self._conn_params = conn_params or {}
135
+ self._enable_reasoning = to_boolean(enable_reasoning)
136
+ self._reasoning_steps = to_integer(reasoning_steps)
137
+ self._debug = to_boolean(debug)
132
138
 
133
139
 
134
140
  @property
@@ -193,6 +199,8 @@ class ValidateTimbrSqlChain(Chain):
193
199
  note=f"{prompt_extension}The original SQL query (`{sql}`) was invalid with this error from query {error}. Please take this in consideration while generating the query.",
194
200
  db_is_case_sensitive=self._db_is_case_sensitive,
195
201
  graph_depth=self._graph_depth,
202
+ enable_reasoning=self._enable_reasoning,
203
+ reasoning_steps=self._reasoning_steps,
196
204
  debug=self._debug,
197
205
  )
198
206
  sql = generate_res.get("sql", "")
@@ -200,6 +208,7 @@ class ValidateTimbrSqlChain(Chain):
200
208
  concept = generate_res.get("concept", self._concept)
201
209
  usage_metadata.update(generate_res.get("usage_metadata", {}))
202
210
  is_sql_valid = generate_res.get("is_sql_valid")
211
+ reasoning_status = generate_res.get("reasoning_status")
203
212
  error = generate_res.get("error")
204
213
 
205
214
  return {
@@ -208,5 +217,6 @@ class ValidateTimbrSqlChain(Chain):
208
217
  "concept": concept,
209
218
  "is_sql_valid": is_sql_valid,
210
219
  "error": error,
220
+ "reasoning_status": reasoning_status,
211
221
  self.usage_metadata_key: usage_metadata,
212
222
  }
@@ -3,7 +3,7 @@ from langchain.llms.base import LLM
3
3
  from langgraph.graph import StateGraph
4
4
 
5
5
  from ..langchain.execute_timbr_query_chain import ExecuteTimbrQueryChain
6
-
6
+ from .. import config
7
7
 
8
8
  class ExecuteSemanticQueryNode:
9
9
  """
@@ -36,6 +36,8 @@ class ExecuteSemanticQueryNode:
36
36
  is_jwt: Optional[bool] = False,
37
37
  jwt_tenant_id: Optional[str] = None,
38
38
  conn_params: Optional[dict] = None,
39
+ enable_reasoning: Optional[bool] = config.enable_reasoning,
40
+ reasoning_steps: Optional[int] = config.reasoning_steps,
39
41
  debug: Optional[bool] = False,
40
42
  **kwargs,
41
43
  ):
@@ -63,6 +65,8 @@ class ExecuteSemanticQueryNode:
63
65
  :param is_jwt: Whether to use JWT authentication (default is False).
64
66
  :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
65
67
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
68
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
69
+ :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
66
70
  :return: A list of rows from the Timbr query
67
71
  """
68
72
  self.chain = ExecuteTimbrQueryChain(
@@ -89,6 +93,8 @@ class ExecuteSemanticQueryNode:
89
93
  is_jwt=is_jwt,
90
94
  jwt_tenant_id=jwt_tenant_id,
91
95
  conn_params=conn_params,
96
+ enable_reasoning=enable_reasoning,
97
+ reasoning_steps=reasoning_steps,
92
98
  debug=debug,
93
99
  **kwargs,
94
100
  )
@@ -3,6 +3,7 @@ from langchain.llms.base import LLM
3
3
  from langgraph.graph import StateGraph
4
4
 
5
5
  from ..langchain.generate_timbr_sql_chain import GenerateTimbrSqlChain
6
+ from .. import config
6
7
 
7
8
  class GenerateTimbrSqlNode:
8
9
  """
@@ -32,6 +33,8 @@ class GenerateTimbrSqlNode:
32
33
  is_jwt: Optional[bool] = False,
33
34
  jwt_tenant_id: Optional[str] = None,
34
35
  conn_params: Optional[dict] = None,
36
+ enable_reasoning: Optional[bool] = config.enable_reasoning,
37
+ reasoning_steps: Optional[int] = config.reasoning_steps,
35
38
  debug: Optional[bool] = False,
36
39
  **kwargs,
37
40
  ):
@@ -57,6 +60,8 @@ class GenerateTimbrSqlNode:
57
60
  :param is_jwt: Whether to use JWT authentication (default: False)
58
61
  :param jwt_tenant_id: Tenant ID for JWT authentication when using multi-tenant setup
59
62
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
63
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
64
+ :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
60
65
  """
61
66
  self.chain = GenerateTimbrSqlChain(
62
67
  llm=llm,
@@ -80,6 +85,8 @@ class GenerateTimbrSqlNode:
80
85
  is_jwt=is_jwt,
81
86
  jwt_tenant_id=jwt_tenant_id,
82
87
  conn_params=conn_params,
88
+ enable_reasoning=enable_reasoning,
89
+ reasoning_steps=reasoning_steps,
83
90
  debug=debug,
84
91
  **kwargs,
85
92
  )
@@ -33,6 +33,8 @@ class ValidateSemanticSqlNode:
33
33
  is_jwt: Optional[bool] = False,
34
34
  jwt_tenant_id: Optional[str] = None,
35
35
  conn_params: Optional[dict] = None,
36
+ enable_reasoning: Optional[bool] = False,
37
+ reasoning_steps: Optional[int] = 2,
36
38
  debug: Optional[bool] = False,
37
39
  **kwargs,
38
40
  ):
@@ -57,6 +59,8 @@ class ValidateSemanticSqlNode:
57
59
  :param is_jwt: Whether to use JWT authentication (default: False)
58
60
  :param jwt_tenant_id: Tenant ID for JWT authentication when using multi-tenant setup
59
61
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
62
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
63
+ :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
60
64
  """
61
65
  self.chain = ValidateTimbrSqlChain(
62
66
  llm=llm,
@@ -79,6 +83,8 @@ class ValidateSemanticSqlNode:
79
83
  is_jwt=is_jwt,
80
84
  jwt_tenant_id=jwt_tenant_id,
81
85
  conn_params=conn_params,
86
+ enable_reasoning=enable_reasoning,
87
+ reasoning_steps=reasoning_steps,
82
88
  debug=debug,
83
89
  **kwargs,
84
90
  )
@@ -184,6 +184,16 @@ class PromptService:
184
184
  return self._fetch_template("llm_prompts/generate_sql")
185
185
 
186
186
 
187
+ def get_generate_sql_reasoning_template(self) -> ChatPromptTemplate:
188
+ """
189
+ Fetch the generate-SQL reasoning template stored under "llm_prompts/generate_sql_reasoning".
190
+
191
+ Returns:
192
+ Whatever `_fetch_template` yields for that path (annotated as ChatPromptTemplate; any caching happens inside `_fetch_template`).
193
+ """
194
+ return self._fetch_template("llm_prompts/generate_sql_reasoning")
195
+
196
+
187
197
  def get_generate_answer_template(self) -> ChatPromptTemplate:
188
198
  """
189
199
  Get generate answer template from API service (cached)
@@ -264,6 +274,22 @@ def get_generate_sql_prompt_template(
264
274
  return PromptTemplateWrapper(prompt_service, "get_generate_sql_template")
265
275
 
266
276
 
277
+ def get_generate_sql_reasoning_prompt_template(
278
+ conn_params: Optional[dict] = None
279
+ ) -> PromptTemplateWrapper:
280
+ """
281
+ Build a PromptTemplateWrapper bound to the "get_generate_sql_reasoning_template" getter.
282
+
283
+ Args:
284
+ conn_params: Connection parameters forwarded unchanged to PromptService (presumably url, token, is_jwt, jwt_tenant_id; confirm against PromptService).
285
+
286
+ Returns:
287
+ PromptTemplateWrapper created from a new PromptService and the getter name above.
288
+ """
289
+ prompt_service = PromptService(conn_params=conn_params)
290
+ return PromptTemplateWrapper(prompt_service, "get_generate_sql_reasoning_template")
291
+
292
+
267
293
  def get_qa_prompt_template(
268
294
  conn_params: Optional[dict] = None
269
295
  ) -> PromptTemplateWrapper:
@@ -1,20 +1,18 @@
1
1
  from typing import Any, Optional
2
2
  from langchain.llms.base import LLM
3
- import base64, hashlib
4
- from cryptography.fernet import Fernet
5
3
  from datetime import datetime
6
4
  import concurrent.futures
7
- import time
5
+ import json
8
6
 
9
- from .timbr_utils import get_datasources, get_tags, get_concepts, get_concept_properties, validate_sql, get_properties_description, get_relationships_description
7
+ from .timbr_utils import get_datasources, get_tags, get_concepts, get_concept_properties, validate_sql, get_properties_description, get_relationships_description, cache_with_version_check, encrypt_prompt
10
8
  from .prompt_service import (
11
9
  get_determine_concept_prompt_template,
12
10
  get_generate_sql_prompt_template,
11
+ get_generate_sql_reasoning_prompt_template,
13
12
  get_qa_prompt_template
14
13
  )
15
14
  from ..config import llm_timeout
16
15
 
17
-
18
16
  def _clean_snowflake_prompt(prompt: Any) -> None:
19
17
  import re
20
18
 
@@ -54,14 +52,6 @@ def _clean_snowflake_prompt(prompt: Any) -> None:
54
52
  prompt[1].content = clean_func(prompt[1].content) # User message
55
53
 
56
54
 
57
- def generate_key() -> bytes:
58
- """Generate a new Fernet secret key."""
59
- passcode = b"lucylit2025"
60
- hlib = hashlib.md5()
61
- hlib.update(passcode)
62
- return base64.urlsafe_b64encode(hlib.hexdigest().encode('utf-8'))
63
-
64
-
65
55
  def _call_llm_with_timeout(llm: LLM, prompt: Any, timeout: int = 60) -> Any:
66
56
  """
67
57
  Call LLM with timeout to prevent hanging.
@@ -90,35 +80,9 @@ def _call_llm_with_timeout(llm: LLM, prompt: Any, timeout: int = 60) -> Any:
90
80
  except Exception as e:
91
81
  raise e
92
82
 
93
- ENCRYPT_KEY = generate_key()
94
83
  MEASURES_DESCRIPTION = "The following columns are calculated measures and can only be aggregated with an aggregate function: COUNT/SUM/AVG/MIN/MAX (count distinct is not allowed)"
95
84
  TRANSITIVE_RELATIONSHIP_DESCRIPTION = "Transitive relationship columns allow you to access data through multiple relationship hops. These columns follow the pattern `<relationship_name>[<table_name>*<number>].<column_name>` where the number after the asterisk (*) indicates how many relationship levels to traverse. For example, `acquired_by[company*4].company_name` means 'go through up to 4 levels of the acquired_by relationship to get the company name', while columns ending with '_transitivity_level' indicate the actual relationship depth (Cannot be null or 0 - level 1 represents direct relationships, while levels 2, 3, 4, etc. represent indirect relationships through multiple hops. To filter by relationship type, use `_transitivity_level = 1` for direct relationships only, `_transitivity_level > 1` for indirect relationships only."
96
85
 
97
- def encrypt_prompt(prompt: Any, key: Optional[bytes] = ENCRYPT_KEY) -> bytes:
98
- """Serialize & encrypt the prompt; returns a URL-safe token."""
99
- # build prompt_text as before…
100
- if isinstance(prompt, str):
101
- text = prompt
102
- elif isinstance(prompt, list):
103
- parts = []
104
- for message in prompt:
105
- if hasattr(message, "content"):
106
- parts.append(f"{message.type}: {message.content}")
107
- else:
108
- parts.append(str(message))
109
- text = "\n".join(parts)
110
- else:
111
- text = str(prompt)
112
-
113
- f = Fernet(key)
114
- return f.encrypt(text.encode()).decode('utf-8')
115
-
116
-
117
- def decrypt_prompt(token: bytes, key: bytes) -> str:
118
- """Decrypt the token and return the original prompt string."""
119
- f = Fernet(key)
120
- return f.decrypt(token).decode()
121
-
122
86
 
123
87
  def _prompt_to_string(prompt: Any) -> str:
124
88
  prompt_text = ''
@@ -135,7 +99,7 @@ def _prompt_to_string(prompt: Any) -> str:
135
99
  return prompt_text.strip()
136
100
 
137
101
 
138
- def _calculate_token_count(llm: LLM, prompt: str) -> int:
102
+ def _calculate_token_count(llm: LLM, prompt: str | list[Any]) -> int:
139
103
  """
140
104
  Calculate the token count for a given prompt text using the specified LLM.
141
105
  Falls back to basic if the LLM doesn't support token counting.
@@ -186,24 +150,108 @@ def _get_response_text(response: Any) -> str:
186
150
 
187
151
  return response_text
188
152
 
189
- def _extract_usage_metadata(response: Any) -> dict:
190
- usage_metadata = response.response_metadata
191
-
192
- if usage_metadata and 'usage' in usage_metadata:
193
- usage_metadata = usage_metadata['usage']
194
-
195
- if not usage_metadata and 'usage_metadata' in response:
196
- usage_metadata = response.usage_metadata
197
- if usage_metadata and 'usage' in usage_metadata:
198
- usage_metadata = usage_metadata['usage']
199
-
200
- if not usage_metadata and 'usage' in response:
201
- usage_metadata = response.usage
202
- if usage_metadata and 'usage' in usage_metadata:
203
- usage_metadata = usage_metadata['usage']
204
153
 
154
+ def _extract_usage_metadata(response: Any) -> dict:
155
+ """
156
+ Extract usage metadata from LLM response across different providers.
157
+
158
+ Different providers return usage data in different formats:
159
+ - OpenAI/AzureOpenAI: response.response_metadata['token_usage'] or response.usage_metadata
160
+ - Anthropic: response.response_metadata['usage'] or response.usage_metadata
161
+ - Google/VertexAI: response.usage_metadata
162
+ - Bedrock: response.response_metadata['usage'] or response.response_metadata (direct ResponseMetadata)
163
+ - Snowflake: response.response_metadata['usage']
164
+ - Databricks: response.usage_metadata or response.response_metadata
165
+ """
166
+ usage_metadata = {}
167
+
168
+ # Try to get response_metadata first (most common)
169
+ if hasattr(response, 'response_metadata') and response.response_metadata:
170
+ resp_meta = response.response_metadata
171
+
172
+ # Check for 'usage' key (Anthropic, Bedrock, Snowflake)
173
+ if 'usage' in resp_meta:
174
+ usage_metadata = resp_meta['usage']
175
+ # Check for 'token_usage' key (OpenAI/AzureOpenAI)
176
+ elif 'token_usage' in resp_meta:
177
+ usage_metadata = resp_meta['token_usage']
178
+ # Check for direct token fields in response_metadata (some Bedrock responses)
179
+ elif any(key in resp_meta for key in ['input_tokens', 'output_tokens', 'total_tokens',
180
+ 'prompt_tokens', 'completion_tokens']):
181
+ usage_metadata = {
182
+ k: v for k, v in resp_meta.items()
183
+ if k in ['input_tokens', 'output_tokens', 'total_tokens',
184
+ 'prompt_tokens', 'completion_tokens']
185
+ }
186
+
187
+ # Try usage_metadata attribute (Google, VertexAI, some others)
188
+ if not usage_metadata and hasattr(response, 'usage_metadata') and response.usage_metadata:
189
+ usage_meta = response.usage_metadata
190
+ if isinstance(usage_meta, dict):
191
+ # If it has a nested 'usage' key
192
+ if 'usage' in usage_meta:
193
+ usage_metadata = usage_meta['usage']
194
+ else:
195
+ usage_metadata = usage_meta
196
+ else:
197
+ # Handle case where usage_metadata is an object with attributes
198
+ usage_metadata = {
199
+ k: getattr(usage_meta, k)
200
+ for k in dir(usage_meta)
201
+ if not k.startswith('_') and not callable(getattr(usage_meta, k))
202
+ }
203
+
204
+ # Try direct usage attribute (fallback)
205
+ if not usage_metadata and hasattr(response, 'usage') and response.usage:
206
+ usage = response.usage
207
+ if isinstance(usage, dict):
208
+ if 'usage' in usage:
209
+ usage_metadata = usage['usage']
210
+ else:
211
+ usage_metadata = usage
212
+ else:
213
+ # Handle case where usage is an object with attributes
214
+ usage_metadata = {
215
+ k: getattr(usage, k)
216
+ for k in dir(usage)
217
+ if not k.startswith('_') and not callable(getattr(usage, k))
218
+ }
219
+
220
+ # Normalize token field names to standard format
221
+ # Different providers use different names: input_tokens vs prompt_tokens, etc.
222
+ if usage_metadata:
223
+ normalized = {}
224
+
225
+ # Map various input token field names
226
+ if 'input_tokens' in usage_metadata:
227
+ normalized['input_tokens'] = usage_metadata['input_tokens']
228
+ elif 'prompt_tokens' in usage_metadata:
229
+ normalized['input_tokens'] = usage_metadata['prompt_tokens']
230
+
231
+ # Map various output token field names
232
+ if 'output_tokens' in usage_metadata:
233
+ normalized['output_tokens'] = usage_metadata['output_tokens']
234
+ elif 'completion_tokens' in usage_metadata:
235
+ normalized['output_tokens'] = usage_metadata['completion_tokens']
236
+
237
+ # Map total tokens
238
+ if 'total_tokens' in usage_metadata:
239
+ normalized['total_tokens'] = usage_metadata['total_tokens']
240
+ elif 'input_tokens' in normalized and 'output_tokens' in normalized:
241
+ # Calculate total if not provided
242
+ normalized['total_tokens'] = normalized['input_tokens'] + normalized['output_tokens']
243
+
244
+ # Keep any other metadata fields that don't conflict
245
+ for key, value in usage_metadata.items():
246
+ if key not in ['input_tokens', 'prompt_tokens', 'output_tokens',
247
+ 'completion_tokens', 'total_tokens']:
248
+ normalized[key] = value
249
+
250
+ return normalized if normalized else usage_metadata
251
+
205
252
  return usage_metadata
206
253
 
254
+
207
255
  def determine_concept(
208
256
  question: str,
209
257
  llm: LLM,
@@ -396,6 +444,199 @@ def _get_active_datasource(conn_params: dict) -> dict:
396
444
  return datasources[0] if datasources else None
397
445
 
398
446
 
447
def _evaluate_sql_enable_reasoning(
    question: str,
    sql_query: str,
    llm: LLM,
    conn_params: dict,
    timeout: int,
) -> dict:
    """
    Ask the LLM to evaluate whether the generated SQL answers the business question.

    Args:
        question: The business question; surrounding whitespace is stripped.
        sql_query: The SQL generated for the question; surrounding whitespace is stripped.
        llm: The LLM used for the evaluation call.
        conn_params: Connection parameters used to fetch the reasoning prompt template.
        timeout: Timeout forwarded to the LLM call.

    Returns:
        dict with:
        - 'evaluation': the JSON value parsed from the LLM reply, expected to hold
          'assessment' ('correct'|'partial'|'incorrect') and 'reasoning'
        - 'apx_token_count': approximate token count of the evaluation prompt
        - 'usage_metadata': usage metadata extracted from the LLM response

    Raises:
        json.JSONDecodeError: if no JSON could be parsed from the LLM reply.
    """
    generate_sql_reasoning_template = get_generate_sql_reasoning_prompt_template(conn_params)
    prompt = generate_sql_reasoning_template.format_messages(
        question=question.strip(),
        sql_query=sql_query.strip(),
    )

    # The token count is taken before the Snowflake-specific prompt clean-up.
    apx_token_count = _calculate_token_count(llm, prompt)
    if hasattr(llm, "_llm_type") and "snowflake" in llm._llm_type:
        _clean_snowflake_prompt(prompt)

    response = _call_llm_with_timeout(llm, prompt, timeout=timeout)

    # Extract JSON from response content (handle markdown code blocks)
    content = response.content.strip()

    # Remove markdown code block markers if present
    if content.startswith("```json"):
        content = content[7:]  # Remove ```json
    elif content.startswith("```"):
        content = content[3:]  # Remove ```

    if content.endswith("```"):
        content = content[:-3]  # Remove closing ```

    content = content.strip()

    # Parse JSON response
    try:
        evaluation = json.loads(content)
    except json.JSONDecodeError:
        # The reply may surround the JSON object with prose or use a fence
        # label other than "json"; retry on the outermost {...} span.
        start = content.find("{")
        end = content.rfind("}")
        if start == -1 or end <= start:
            raise
        evaluation = json.loads(content[start:end + 1])

    return {
        "evaluation": evaluation,
        "apx_token_count": apx_token_count,
        "usage_metadata": _extract_usage_metadata(response),
    }
494
+
495
+
496
@cache_with_version_check
def _build_sql_generation_context(
    conn_params: dict,
    schema: str,
    concept: str,
    concept_metadata: dict,
    graph_depth: int,
    include_tags: Optional[str],
    exclude_properties: Optional[list],
    db_is_case_sensitive: bool,
    max_limit: int,
) -> dict:
    """
    Prepare the complete SQL generation context by gathering all necessary metadata.

    This includes:
    - Datasource information
    - Concept properties (columns, measures, relationships)
    - Property tags
    - Building column/measure/relationship descriptions
    - Assembling the final context dictionary

    Args:
        conn_params: Connection parameters passed to every metadata lookup.
        schema: Schema the concept is read from.
        concept: Name of the concept to describe.
        concept_metadata: Optional metadata; its 'description' and 'tags' are used when present.
        graph_depth: Depth passed to the concept-properties lookup.
        include_tags: Tag filter passed to the tags lookup.
        exclude_properties: Properties to leave out of the column/measure descriptions.
        db_is_case_sensitive: When truthy, a case-insensitive comparison hint is added.
        max_limit: Row limit copied into the context unchanged.

    Returns:
        dict containing all context needed for SQL generation prompts, with keys
        'cur_date', 'datasource_type', 'schema', 'concept', 'concept_description',
        'concept_tags', 'columns_str', 'measures_context', 'transitive_context',
        'sensitivity_txt' and 'max_limit'.
    """
    # _get_active_datasource returns None when no datasource is available;
    # fall back to an empty mapping so the 'standard sql' default below applies.
    active_datasource = _get_active_datasource(conn_params) or {}
    datasource_type = active_datasource.get('target_type')

    properties_desc = get_properties_description(conn_params=conn_params)
    relationships_desc = get_relationships_description(conn_params=conn_params)

    concept_properties_metadata = get_concept_properties(
        schema=schema,
        concept_name=concept,
        conn_params=conn_params,
        properties_desc=properties_desc,
        relationships_desc=relationships_desc,
        graph_depth=graph_depth
    )
    columns = concept_properties_metadata.get('columns', [])
    measures = concept_properties_metadata.get('measures', [])
    relationships = concept_properties_metadata.get('relationships', {})
    tags = get_tags(conn_params=conn_params, include_tags=include_tags).get('property_tags')

    columns_str = _build_columns_str(columns, columns_tags=tags, exclude=exclude_properties)
    measures_str = _build_columns_str(measures, tags, exclude=exclude_properties)
    rel_prop_str = _build_rel_columns_str(relationships, columns_tags=tags, exclude_properties=exclude_properties)

    # Relationship properties are listed together with the measures.
    if rel_prop_str:
        measures_str += f"\n{rel_prop_str}"

    # Determine if relationships have transitive properties
    has_transitive_relationships = any(
        rel.get('is_transitive')
        for rel in relationships.values()
    ) if relationships else False

    concept_description = f"- Description: {concept_metadata.get('description')}\n" if concept_metadata and concept_metadata.get('description') else ""
    concept_tags = concept_metadata.get('tags') if concept_metadata and concept_metadata.get('tags') else ""

    # NOTE(review): this function is wrapped by cache_with_version_check, so the
    # date below is presumably frozen for as long as the entry stays cached -- confirm.
    cur_date = datetime.now().strftime("%Y-%m-%d")

    # Build context descriptions
    sensitivity_txt = "- Ensure value comparisons are case-insensitive, e.g., use LOWER(column) = 'value'.\n" if db_is_case_sensitive else ""
    measures_context = f"- {MEASURES_DESCRIPTION}: {measures_str}\n" if measures_str else ""
    transitive_context = f"- {TRANSITIVE_RELATIONSHIP_DESCRIPTION}\n" if has_transitive_relationships else ""

    return {
        'cur_date': cur_date,
        'datasource_type': datasource_type or 'standard sql',
        'schema': schema,
        'concept': concept,
        'concept_description': concept_description or "",
        'concept_tags': concept_tags or "",
        'columns_str': columns_str,
        'measures_context': measures_context,
        'transitive_context': transitive_context,
        'sensitivity_txt': sensitivity_txt,
        'max_limit': max_limit,
    }
575
+
576
+
577
def _generate_sql_with_llm(
    question: str,
    llm: LLM,
    conn_params: dict,
    generate_sql_prompt: Any,
    current_context: dict,
    note: str,
    should_validate_sql: bool,
    timeout: int,
    debug: bool = False,
) -> dict:
    """
    Run one SQL-generation round trip against the LLM.

    The same routine serves the first generation attempt and every later
    regeneration; only the `note` text differs between them.

    Args:
        question: The question the SQL should answer.
        llm: The LLM to call.
        conn_params: Connection parameters, used only when validating the SQL.
        generate_sql_prompt: Prompt template exposing `format_messages`.
        current_context: dict containing cur_date, datasource_type, schema, concept,
                         concept_description, concept_tags, columns_str,
                         measures_context, transitive_context, sensitivity_txt
                         and max_limit
        note: Additional instructions/feedback to include in the prompt
        should_validate_sql: When true, the parsed SQL is validated and the outcome
                             overwrites 'is_valid' and 'error'.
        timeout: Timeout forwarded to the LLM call.
        debug: When true, the encrypted prompt is returned under 'p_hash'.

    Returns:
        dict with 'sql', 'apx_token_count', 'usage_metadata', 'is_valid', 'error'
        and, if debug, 'p_hash'
    """
    ctx = current_context
    prompt = generate_sql_prompt.format_messages(
        current_date=ctx['cur_date'],
        datasource_type=ctx['datasource_type'],
        schema=ctx['schema'],
        concept=f"`{ctx['concept']}`",
        description=ctx['concept_description'],
        tags=ctx['concept_tags'],
        question=question,
        columns=ctx['columns_str'],
        measures_context=ctx['measures_context'],
        transitive_context=ctx['transitive_context'],
        sensitivity_context=ctx['sensitivity_txt'],
        max_limit=ctx['max_limit'],
        note=note,
    )

    # Count tokens first; the Snowflake clean-up below modifies the prompt afterwards.
    token_estimate = _calculate_token_count(llm, prompt)
    targets_snowflake = hasattr(llm, "_llm_type") and "snowflake" in llm._llm_type
    if targets_snowflake:
        _clean_snowflake_prompt(prompt)

    response = _call_llm_with_timeout(llm, prompt, timeout=timeout)

    sql_text = _parse_sql_from_llm_response(response)
    usage = _extract_usage_metadata(response)
    result = {
        "sql": sql_text,
        "apx_token_count": token_estimate,
        "usage_metadata": usage,
        "is_valid": True,
        "error": None,
    }

    if debug:
        result["p_hash"] = encrypt_prompt(prompt)

    if should_validate_sql:
        validity, validation_error = validate_sql(sql_text, conn_params)
        result["is_valid"] = validity
        result["error"] = validation_error

    return result
638
+
639
+
399
640
  def generate_sql(
400
641
  question: str,
401
642
  llm: LLM,
@@ -413,11 +654,14 @@ def generate_sql(
413
654
  note: Optional[str] = '',
414
655
  db_is_case_sensitive: Optional[bool] = False,
415
656
  graph_depth: Optional[int] = 1,
657
+ enable_reasoning: Optional[bool] = False,
658
+ reasoning_steps: Optional[int] = 2,
416
659
  debug: Optional[bool] = False,
417
660
  timeout: Optional[int] = None,
418
661
  ) -> dict[str, str]:
419
662
  usage_metadata = {}
420
663
  concept_metadata = None
664
+ reasoning_status = 'correct'
421
665
 
422
666
  # Use config default timeout if none provided
423
667
  if timeout is None:
@@ -450,22 +694,6 @@ def generate_sql(
450
694
  if not concept:
451
695
  raise Exception("No relevant concept found for the query.")
452
696
 
453
- datasource_type = _get_active_datasource(conn_params).get('target_type')
454
-
455
- properties_desc = get_properties_description(conn_params=conn_params)
456
- relationships_desc = get_relationships_description(conn_params=conn_params)
457
-
458
- concept_properties_metadata = get_concept_properties(schema=schema, concept_name=concept, conn_params=conn_params, properties_desc=properties_desc, relationships_desc=relationships_desc, graph_depth=graph_depth)
459
- columns, measures, relationships = concept_properties_metadata.get('columns', []), concept_properties_metadata.get('measures', []), concept_properties_metadata.get('relationships', {})
460
- tags = get_tags(conn_params=conn_params, include_tags=include_tags).get('property_tags')
461
-
462
- columns_str = _build_columns_str(columns, columns_tags=tags, exclude=exclude_properties)
463
- measures_str = _build_columns_str(measures, tags, exclude=exclude_properties)
464
- rel_prop_str = _build_rel_columns_str(relationships, columns_tags=tags, exclude_properties=exclude_properties)
465
-
466
- if rel_prop_str:
467
- measures_str += f"\n{rel_prop_str}"
468
-
469
697
  sql_query = None
470
698
  iteration = 0
471
699
  is_sql_valid = True
@@ -474,39 +702,39 @@ def generate_sql(
474
702
  iteration += 1
475
703
  err_txt = f"\nThe original SQL (`{sql_query}`) was invalid with error: {error}. Please generate a corrected query." if error and "snowflake" not in llm._llm_type else ""
476
704
 
477
- sensitivity_txt = "- Ensure value comparisons are case-insensitive, e.g., use LOWER(column) = 'value'.\n" if db_is_case_sensitive else ""
478
-
479
- measures_context = f"- {MEASURES_DESCRIPTION}: {measures_str}\n" if measures_str else ""
480
- has_transitive_relationships = any(
481
- rel.get('is_transitive')
482
- for rel in relationships.values()
483
- ) if relationships else False
484
- transitive_context = f"- {TRANSITIVE_RELATIONSHIP_DESCRIPTION}\n" if has_transitive_relationships else ""
485
- concept_description = f"- Description: {concept_metadata.get('description')}\n" if concept_metadata and concept_metadata.get('description') else ""
486
- concept_tags = concept_metadata.get('tags') if concept_metadata and concept_metadata.get('tags') else ""
487
- cur_date = datetime.now().strftime("%Y-%m-%d")
488
- prompt = generate_sql_prompt.format_messages(
489
- current_date=cur_date,
490
- datasource_type=datasource_type or 'standard sql',
491
- schema=schema,
492
- concept=f"`{concept}`",
493
- description=concept_description or "",
494
- tags=concept_tags or "",
495
- question=question,
496
- columns=columns_str,
497
- measures_context=measures_context,
498
- transitive_context=transitive_context,
499
- sensitivity_context=sensitivity_txt,
500
- max_limit=max_limit,
501
- note=note + err_txt,
502
- )
503
-
504
- apx_token_count = _calculate_token_count(llm, prompt)
505
- if "snowflake" in llm._llm_type:
506
- _clean_snowflake_prompt(prompt)
507
-
508
705
  try:
509
- response = _call_llm_with_timeout(llm, prompt, timeout=timeout)
706
+ result = _generate_sql_with_llm(
707
+ question=question,
708
+ llm=llm,
709
+ conn_params=conn_params,
710
+ generate_sql_prompt=generate_sql_prompt,
711
+ current_context=_build_sql_generation_context(
712
+ conn_params=conn_params,
713
+ schema=schema,
714
+ concept=concept,
715
+ concept_metadata=concept_metadata,
716
+ graph_depth=graph_depth,
717
+ include_tags=include_tags,
718
+ exclude_properties=exclude_properties,
719
+ db_is_case_sensitive=db_is_case_sensitive,
720
+ max_limit=max_limit),
721
+ note=note + err_txt,
722
+ should_validate_sql=should_validate_sql,
723
+ timeout=timeout,
724
+ debug=debug,
725
+ )
726
+
727
+ usage_metadata['generate_sql'] = {
728
+ "approximate": result['apx_token_count'],
729
+ **result['usage_metadata'],
730
+ }
731
+ if debug and 'p_hash' in result:
732
+ usage_metadata['generate_sql']["p_hash"] = result['p_hash']
733
+
734
+ sql_query = result['sql']
735
+ is_sql_valid = result['is_valid']
736
+ error = result['error']
737
+
510
738
  except TimeoutError as e:
511
739
  error = f"LLM call timed out: {str(e)}"
512
740
  raise Exception(error)
@@ -516,18 +744,73 @@ def generate_sql(
516
744
  continue
517
745
  else:
518
746
  raise Exception(error)
519
-
520
- usage_metadata['generate_sql'] = {
521
- "approximate": apx_token_count,
522
- **_extract_usage_metadata(response),
523
- }
524
- if debug:
525
- usage_metadata['generate_sql']["p_hash"] = encrypt_prompt(prompt)
526
-
527
- sql_query = _parse_sql_from_llm_response(response)
528
-
529
- if should_validate_sql:
530
- is_sql_valid, error = validate_sql(sql_query, conn_params)
747
+
748
+
749
+ if enable_reasoning and sql_query is not None:
750
+ for step in range(reasoning_steps):
751
+ try:
752
+ # Step 1: Evaluate the current SQL
753
+ eval_result = _evaluate_sql_enable_reasoning(
754
+ question=question,
755
+ sql_query=sql_query,
756
+ llm=llm,
757
+ conn_params=conn_params,
758
+ timeout=timeout,
759
+ )
760
+
761
+ usage_metadata[f'sql_reasoning_step_{step + 1}'] = {
762
+ "approximate": eval_result['apx_token_count'],
763
+ **eval_result['usage_metadata'],
764
+ }
765
+
766
+ evaluation = eval_result['evaluation']
767
+ reasoning_status = evaluation.get("assessment", "partial").lower()
768
+
769
+ if reasoning_status == "correct":
770
+ break
771
+
772
+ # Step 2: Regenerate SQL with feedback
773
+ evaluation_note = note + f"\n\nThe previously generated SQL: `{sql_query}` was assessed as '{evaluation.get('assessment')}' because: {evaluation.get('reasoning', '*could not determine cause*')}. Please provide a corrected SQL query that better answers the question: '{question}'."
774
+
775
+ # Increase graph depth for 2nd+ reasoning attempts, up to max of 3
776
+ context_graph_depth = min(3, int(graph_depth) + step) if graph_depth < 3 and step > 0 else graph_depth
777
+ regen_result = _generate_sql_with_llm(
778
+ question=question,
779
+ llm=llm,
780
+ conn_params=conn_params,
781
+ generate_sql_prompt=generate_sql_prompt,
782
+ current_context=_build_sql_generation_context(
783
+ conn_params=conn_params,
784
+ schema=schema,
785
+ concept=concept,
786
+ concept_metadata=concept_metadata,
787
+ graph_depth=context_graph_depth,
788
+ include_tags=include_tags,
789
+ exclude_properties=exclude_properties,
790
+ db_is_case_sensitive=db_is_case_sensitive,
791
+ max_limit=max_limit),
792
+ note=evaluation_note,
793
+ should_validate_sql=should_validate_sql,
794
+ timeout=timeout,
795
+ debug=debug,
796
+ )
797
+
798
+ usage_metadata[f'generate_sql_reasoning_step_{step + 1}'] = {
799
+ "approximate": regen_result['apx_token_count'],
800
+ **regen_result['usage_metadata'],
801
+ }
802
+ if debug and 'p_hash' in regen_result:
803
+ usage_metadata[f'generate_sql_reasoning_step_{step + 1}']['p_hash'] = regen_result['p_hash']
804
+
805
+ sql_query = regen_result['sql']
806
+ is_sql_valid = regen_result['is_valid']
807
+ error = regen_result['error']
808
+
809
+ except TimeoutError as e:
810
+ raise Exception(f"LLM call timed out: {str(e)}")
811
+ except Exception as e:
812
+ print(f"Warning: LLM reasoning failed: {e}")
813
+ break
531
814
 
532
815
  return {
533
816
  "sql": sql_query,
@@ -535,6 +818,7 @@ def generate_sql(
535
818
  "schema": schema,
536
819
  "error": error if not is_sql_valid else None,
537
820
  "is_sql_valid": is_sql_valid if should_validate_sql else None,
821
+ "reasoning_status": reasoning_status,
538
822
  "usage_metadata": usage_metadata,
539
823
  }
540
824
 
@@ -1,7 +1,10 @@
1
1
  from typing import Optional, Any
2
2
  import time
3
+ import base64
4
+ import hashlib
3
5
  from pytimbr_api import timbr_http_connector
4
6
  from functools import wraps
7
+ from cryptography.fernet import Fernet
5
8
 
6
9
  from ..config import cache_timeout, ignore_tags, ignore_tags_prefix
7
10
  from .general import to_boolean
@@ -40,6 +43,42 @@ def _serialize_cache_key(*args, **kwargs):
40
43
  return (tuple(serialize(arg) for arg in args), tuple((k, serialize(v)) for k, v in kwargs.items()))
41
44
 
42
45
 
46
def generate_key() -> bytes:
    """Derive the module's fixed Fernet key.

    No randomness is involved: the key is the URL-safe base64 encoding of the
    32-character MD5 hex digest of a passphrase hard-coded below, so every
    call returns the same bytes.
    """
    # NOTE(review): the passphrase ships in the source, so tokens produced
    # with this key are obfuscated rather than secret.
    digest_hex = hashlib.md5(b"lucylit2025").hexdigest()
    return base64.urlsafe_b64encode(digest_hex.encode('utf-8'))
52
+
53
+
54
+ ENCRYPT_KEY = generate_key()
55
+
56
+
57
def encrypt_prompt(prompt: Any, key: Optional[bytes] = ENCRYPT_KEY) -> str:
    """Serialize & encrypt the prompt; returns a URL-safe token.

    Args:
        prompt: A string (used as is), a list of messages (each rendered as
            "<type>: <content>" when it has a `content` attribute, otherwise
            via `str()`, joined by newlines), or any other object (rendered
            via `str()`).
        key: Fernet key to encrypt with. `None` falls back to the module key.

    Returns:
        The Fernet token decoded to a UTF-8 string.
    """
    # The signature allows None, but Fernet itself rejects it.
    if key is None:
        key = ENCRYPT_KEY

    if isinstance(prompt, str):
        text = prompt
    elif isinstance(prompt, list):
        text = "\n".join(
            f"{message.type}: {message.content}" if hasattr(message, "content") else str(message)
            for message in prompt
        )
    else:
        text = str(prompt)

    f = Fernet(key)
    return f.encrypt(text.encode()).decode('utf-8')
74
+
75
+
76
def decrypt_prompt(token: "bytes | str", key: Optional[bytes] = ENCRYPT_KEY) -> str:
    """Decrypt the token and return the original prompt string.

    Args:
        token: A Fernet token, as bytes or as the string form that
            `encrypt_prompt` returns.
        key: Fernet key the token was encrypted with. Defaults to the module
            key, mirroring `encrypt_prompt`; `None` also falls back to it.

    Returns:
        The decrypted prompt text.
    """
    if key is None:
        key = ENCRYPT_KEY
    f = Fernet(key)
    return f.decrypt(token).decode()
80
+
81
+
43
82
  def cache_with_version_check(func):
44
83
  """Decorator to cache function results and invalidate if ontology version changes."""
45
84
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langchain-timbr
3
- Version: 2.1.5
3
+ Version: 2.1.7
4
4
  Summary: LangChain & LangGraph extensions that parse LLM prompts into Timbr semantic SQL and execute them.
5
5
  Project-URL: Homepage, https://github.com/WPSemantix/langchain-timbr
6
6
  Project-URL: Documentation, https://docs.timbr.ai/doc/docs/integration/langchain-sdk/
@@ -23,14 +23,13 @@ Requires-Dist: cryptography==45.0.7; python_version >= '3.11'
23
23
  Requires-Dist: cryptography>=44.0.3; python_version == '3.10'
24
24
  Requires-Dist: langchain-community==0.3.30; python_version >= '3.11'
25
25
  Requires-Dist: langchain-community>=0.3.27; python_version == '3.10'
26
- Requires-Dist: langchain-core>=0.3.58; python_version == '3.10'
27
- Requires-Dist: langchain-core>=0.3.80; python_version >= '3.11'
26
+ Requires-Dist: langchain-core>=0.3.80
28
27
  Requires-Dist: langchain==0.3.27; python_version >= '3.11'
29
28
  Requires-Dist: langchain>=0.3.25; python_version == '3.10'
30
29
  Requires-Dist: langgraph==0.6.8; python_version >= '3.11'
31
30
  Requires-Dist: langgraph>=0.3.20; python_version == '3.10'
32
31
  Requires-Dist: pydantic==2.10.4
33
- Requires-Dist: pytimbr-api>=2.0.0; python_version >= '3.11'
32
+ Requires-Dist: pytimbr-api>=2.1.0
34
33
  Requires-Dist: tiktoken==0.8.0
35
34
  Requires-Dist: transformers==4.57.0; python_version >= '3.11'
36
35
  Requires-Dist: transformers>=4.53; python_version == '3.10'
@@ -0,0 +1,28 @@
1
+ langchain_timbr/__init__.py,sha256=qNyk3Rt-8oWr_OGuU_E-6siNZXuCnvVEkj65EIuVbbQ,824
2
+ langchain_timbr/_version.py,sha256=MO-pKnEzeW3zl7_c60hrWWrquNyRCaWfqqY8EMXwxVA,704
3
+ langchain_timbr/config.py,sha256=c3A_HIw1b1Y6tc4EaXHZRJ_OptGLzl5bQGzx4ec_tgM,1605
4
+ langchain_timbr/timbr_llm_connector.py,sha256=mdkWskpvmXZre5AzVFn6KfPnVH5YN5MIwfEoXWBLMgY,13170
5
+ langchain_timbr/langchain/__init__.py,sha256=ejcsZKP9PK0j4WrrCCcvBXpDpP-TeRiVb21OIUJqix8,580
6
+ langchain_timbr/langchain/execute_timbr_query_chain.py,sha256=snfx22QE0hM3pjvoApUBFGvtIqNt2Yy1PFnx4uuRScs,16316
7
+ langchain_timbr/langchain/generate_answer_chain.py,sha256=nteA4QZp9CAOskTBl_CokwaMlqnR2g2GvKz2mLs9WVY,4871
8
+ langchain_timbr/langchain/generate_timbr_sql_chain.py,sha256=G2uEPB4NbZcDIkWtoxVlPNcuaaUcEmUcP39_4U5aYK0,9811
9
+ langchain_timbr/langchain/identify_concept_chain.py,sha256=kuzg0jJQpFGIiaxtNhdQ5K4HXveLVwONFNsoipPCteE,7169
10
+ langchain_timbr/langchain/timbr_sql_agent.py,sha256=fCYAYjvE_9xVtp-xicOGKRbvtjK8gRipXkSl9mKTcio,20810
11
+ langchain_timbr/langchain/validate_timbr_sql_chain.py,sha256=rq3fVmNyt_D1vfRbvn21f_lACSqPiHdVMri4Gofw5nY,10264
12
+ langchain_timbr/langgraph/__init__.py,sha256=mKBFd0x01jWpRujUWe-suX3FFhenPoDxrvzs8I0mum0,457
13
+ langchain_timbr/langgraph/execute_timbr_query_node.py,sha256=Tz7N3tCGJgot7v23SxYZaoG2o0kxXanv0wLcT1pJwyc,5994
14
+ langchain_timbr/langgraph/generate_response_node.py,sha256=opwscNEXabaSyCFLbzGQFkDFEymJurhNU9aAtm1rnOk,2375
15
+ langchain_timbr/langgraph/generate_timbr_sql_node.py,sha256=kd_soT3w3V0fGDI52R_iL0S1jlTApUawWxRJFQ6yo_Q,5370
16
+ langchain_timbr/langgraph/identify_concept_node.py,sha256=aiLDFEcz_vM4zZ_ULe1SvJKmI-e4Fb2SibZQaEPz_eY,3649
17
+ langchain_timbr/langgraph/validate_timbr_query_node.py,sha256=utdy8cj3Pe6cd3OSAUjpCh7gwQI8EaenUW7bmptC7iQ,5191
18
+ langchain_timbr/llm_wrapper/llm_wrapper.py,sha256=j94DqIGECXyfAVayLC7VaNxs_8n1qYFiHY2Qvt2B3Bc,17537
19
+ langchain_timbr/llm_wrapper/timbr_llm_wrapper.py,sha256=sDqDOz0qu8b4WWlagjNceswMVyvEJ8yBWZq2etBh-T0,1362
20
+ langchain_timbr/utils/general.py,sha256=KkehHvIj8GoQ_0KVXLcUVeaYaTtkuzgXmYYx2TXJhI4,10253
21
+ langchain_timbr/utils/prompt_service.py,sha256=QVmfA9cHO2IPVsKG8V5cuMm2gPfvRq2VzLcx04sqT88,12197
22
+ langchain_timbr/utils/temperature_supported_models.json,sha256=d3UmBUpG38zDjjB42IoGpHTUaf0pHMBRSPY99ao1a3g,1832
23
+ langchain_timbr/utils/timbr_llm_utils.py,sha256=7jg5TxTMkdLKZPvGbS49bTSmX6z8gbP84iKb_JALTy8,34646
24
+ langchain_timbr/utils/timbr_utils.py,sha256=x-z46NQn8nhRR8PJ5l23uh0qEpDWbhD8x4r-DyY4IYY,18864
25
+ langchain_timbr-2.1.7.dist-info/METADATA,sha256=aV12qFam0Klb2tGTr5we4lb9PzFRQFEEMsOZ0wcMYgQ,10724
26
+ langchain_timbr-2.1.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
27
+ langchain_timbr-2.1.7.dist-info/licenses/LICENSE,sha256=0ITGFk2alkC7-e--bRGtuzDrv62USIiVyV2Crf3_L_0,1065
28
+ langchain_timbr-2.1.7.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.27.0
2
+ Generator: hatchling 1.28.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,28 +0,0 @@
1
- langchain_timbr/__init__.py,sha256=gxd6Y6QDmYZtPlYVdXtPIy501hMOZXHjWh2qq4qzt_s,828
2
- langchain_timbr/_version.py,sha256=maW2cYW4nRdLTImWeNoFLjtvs_U5VO5bEhsRV9_0Ivk,704
3
- langchain_timbr/config.py,sha256=PEtvNgvnA9UseZJjKgup_O6xdG-VYk3N11nH8p8W1Kg,1410
4
- langchain_timbr/timbr_llm_connector.py,sha256=mdkWskpvmXZre5AzVFn6KfPnVH5YN5MIwfEoXWBLMgY,13170
5
- langchain_timbr/langchain/__init__.py,sha256=ejcsZKP9PK0j4WrrCCcvBXpDpP-TeRiVb21OIUJqix8,580
6
- langchain_timbr/langchain/execute_timbr_query_chain.py,sha256=pedMajyKDI2ZaoyVp1r64nHX015Wy-r96HoJrRlCh48,15579
7
- langchain_timbr/langchain/generate_answer_chain.py,sha256=nteA4QZp9CAOskTBl_CokwaMlqnR2g2GvKz2mLs9WVY,4871
8
- langchain_timbr/langchain/generate_timbr_sql_chain.py,sha256=3Z0ut78AFCNHKwLwOYH44hzJDIOA-zNF0x8Tjyrvzp4,9098
9
- langchain_timbr/langchain/identify_concept_chain.py,sha256=kuzg0jJQpFGIiaxtNhdQ5K4HXveLVwONFNsoipPCteE,7169
10
- langchain_timbr/langchain/timbr_sql_agent.py,sha256=HntpalzCZ-PlHd7na5V0syCMqrREFUpppGM4eHstaZQ,19574
11
- langchain_timbr/langchain/validate_timbr_sql_chain.py,sha256=OcE_7yfb9xpD-I4OS7RG1bY4-yi1UicjvGegOv_vkQU,9567
12
- langchain_timbr/langgraph/__init__.py,sha256=mKBFd0x01jWpRujUWe-suX3FFhenPoDxrvzs8I0mum0,457
13
- langchain_timbr/langgraph/execute_timbr_query_node.py,sha256=rPx_V3OOh-JTGOwrEopHmOmFuM-ngBZdswkW9oZ43hU,5536
14
- langchain_timbr/langgraph/generate_response_node.py,sha256=opwscNEXabaSyCFLbzGQFkDFEymJurhNU9aAtm1rnOk,2375
15
- langchain_timbr/langgraph/generate_timbr_sql_node.py,sha256=wkau-NajblSVzNIro9IyqawULvz3XaCYSEdYW95vWco,4911
16
- langchain_timbr/langgraph/identify_concept_node.py,sha256=aiLDFEcz_vM4zZ_ULe1SvJKmI-e4Fb2SibZQaEPz_eY,3649
17
- langchain_timbr/langgraph/validate_timbr_query_node.py,sha256=-2fuieCz1hv6ua-17zfonme8LQ_OoPnoOBTdGSXkJgs,4793
18
- langchain_timbr/llm_wrapper/llm_wrapper.py,sha256=j94DqIGECXyfAVayLC7VaNxs_8n1qYFiHY2Qvt2B3Bc,17537
19
- langchain_timbr/llm_wrapper/timbr_llm_wrapper.py,sha256=sDqDOz0qu8b4WWlagjNceswMVyvEJ8yBWZq2etBh-T0,1362
20
- langchain_timbr/utils/general.py,sha256=KkehHvIj8GoQ_0KVXLcUVeaYaTtkuzgXmYYx2TXJhI4,10253
21
- langchain_timbr/utils/prompt_service.py,sha256=QT7kiq72rQno77z1-tvGGD7HlH_wdTQAl_1teSoKEv4,11373
22
- langchain_timbr/utils/temperature_supported_models.json,sha256=d3UmBUpG38zDjjB42IoGpHTUaf0pHMBRSPY99ao1a3g,1832
23
- langchain_timbr/utils/timbr_llm_utils.py,sha256=_4Qz5SX5cXW1Rl_fSBcE9P3uPEaI8DBg3GpXA4uQGoI,23102
24
- langchain_timbr/utils/timbr_utils.py,sha256=SvmQ0wYicODNhmo8c-5_KPDBAfrBVBkUfoO8sPItQhk,17759
25
- langchain_timbr-2.1.5.dist-info/METADATA,sha256=tBsAUFHHGDM8i9tBirHqEmtZ7-ft2um55Rif_vVbtF0,10840
26
- langchain_timbr-2.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
27
- langchain_timbr-2.1.5.dist-info/licenses/LICENSE,sha256=0ITGFk2alkC7-e--bRGtuzDrv62USIiVyV2Crf3_L_0,1065
28
- langchain_timbr-2.1.5.dist-info/RECORD,,