langchain-timbr 2.1.14__tar.gz → 2.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/PKG-INFO +1 -1
  2. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/_version.py +2 -2
  3. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/langchain/execute_timbr_query_chain.py +6 -0
  4. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/langchain/generate_timbr_sql_chain.py +2 -0
  5. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/langchain/timbr_sql_agent.py +12 -0
  6. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/langchain/validate_timbr_sql_chain.py +6 -0
  7. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/utils/timbr_llm_utils.py +333 -166
  8. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/integration/test_langchain_chains.py +7 -0
  9. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/.github/dependabot.yml +0 -0
  10. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/.github/pull_request_template.md +0 -0
  11. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/.github/workflows/_codespell.yml +0 -0
  12. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/.github/workflows/_fossa.yml +0 -0
  13. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/.github/workflows/install-dependencies-and-run-tests.yml +0 -0
  14. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/.github/workflows/publish.yml +0 -0
  15. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/.gitignore +0 -0
  16. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/LICENSE +0 -0
  17. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/README.md +0 -0
  18. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/SECURITY.md +0 -0
  19. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/pyproject.toml +0 -0
  20. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/pytest.ini +0 -0
  21. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/requirements.txt +0 -0
  22. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/requirements310.txt +0 -0
  23. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/requirements311.txt +0 -0
  24. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/__init__.py +0 -0
  25. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/config.py +0 -0
  26. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/langchain/__init__.py +0 -0
  27. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/langchain/generate_answer_chain.py +0 -0
  28. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/langchain/identify_concept_chain.py +0 -0
  29. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/langgraph/__init__.py +0 -0
  30. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/langgraph/execute_timbr_query_node.py +0 -0
  31. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/langgraph/generate_response_node.py +0 -0
  32. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/langgraph/generate_timbr_sql_node.py +0 -0
  33. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/langgraph/identify_concept_node.py +0 -0
  34. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/langgraph/validate_timbr_query_node.py +0 -0
  35. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/llm_wrapper/llm_wrapper.py +0 -0
  36. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/llm_wrapper/timbr_llm_wrapper.py +0 -0
  37. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/timbr_llm_connector.py +0 -0
  38. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/utils/general.py +0 -0
  39. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/utils/prompt_service.py +0 -0
  40. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/utils/temperature_supported_models.json +0 -0
  41. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/src/langchain_timbr/utils/timbr_utils.py +0 -0
  42. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/README.md +0 -0
  43. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/conftest.py +0 -0
  44. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/integration/test_agent_integration.py +0 -0
  45. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/integration/test_azure_databricks_provider.py +0 -0
  46. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/integration/test_azure_openai_model.py +0 -0
  47. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/integration/test_chain_pipeline.py +0 -0
  48. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/integration/test_chain_reasoning.py +0 -0
  49. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/integration/test_jwt_token.py +0 -0
  50. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/integration/test_langgraph_nodes.py +0 -0
  51. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/integration/test_timeout_functionality.py +0 -0
  52. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/standard/conftest.py +0 -0
  53. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/standard/test_chain_documentation.py +0 -0
  54. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/standard/test_connection_validation.py +0 -0
  55. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/standard/test_llm_wrapper_optional_params.py +0 -0
  56. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/standard/test_optional_llm_integration.py +0 -0
  57. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/standard/test_standard_chain_requirements.py +0 -0
  58. {langchain_timbr-2.1.14 → langchain_timbr-2.2.0}/tests/standard/test_unit_tests.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: langchain-timbr
- Version: 2.1.14
+ Version: 2.2.0
  Summary: LangChain & LangGraph extensions that parse LLM prompts into Timbr semantic SQL and execute them.
  Project-URL: Homepage, https://github.com/WPSemantix/langchain-timbr
  Project-URL: Documentation, https://docs.timbr.ai/doc/docs/integration/langchain-sdk/

src/langchain_timbr/_version.py
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID

- __version__ = version = '2.1.14'
- __version_tuple__ = version_tuple = (2, 1, 14)
+ __version__ = version = '2.2.0'
+ __version_tuple__ = version_tuple = (2, 2, 0)

  __commit_id__ = commit_id = None
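
Both version markers move from 2.1.14 to 2.2.0. A quick way to confirm which release is actually installed after upgrading (importlib.metadata is standard library; the comments show expected values, not guaranteed output):

    from importlib.metadata import version
    from langchain_timbr import _version

    print(version("langchain-timbr"))      # expected: 2.2.0
    print(_version.__version_tuple__)      # expected: (2, 2, 0), per the hunk above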

src/langchain_timbr/langchain/execute_timbr_query_chain.py
@@ -247,6 +247,8 @@ class ExecuteTimbrQueryChain(Chain):
  concept_name = inputs.get("concept", self._concept)
  is_sql_valid = True
  error = None
+ identify_concept_reason = None
+ generate_sql_reason = None
  reasoning_status = None
  rows = []
  usage_metadata = {}
@@ -270,6 +272,8 @@ class ExecuteTimbrQueryChain(Chain):
  is_sql_valid = True

  error = generate_res.get("error")
+ identify_concept_reason = generate_res.get("identify_concept_reason")
+ generate_sql_reason = generate_res.get("generate_sql_reason")
  usage_metadata = self._summarize_usage_metadata(usage_metadata, generate_res.get("usage_metadata", {}))

  is_sql_not_tried = not any(sql.lower().strip() == gen.lower().strip() for gen in generated)
@@ -305,6 +309,8 @@ class ExecuteTimbrQueryChain(Chain):
  "concept": concept_name,
  "error": error if not is_sql_valid else None,
  "reasoning_status": reasoning_status,
+ "identify_concept_reason": identify_concept_reason,
+ "generate_sql_reason": generate_sql_reason,
  self.usage_metadata_key: usage_metadata,
  }
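
For callers, the net effect of the ExecuteTimbrQueryChain hunks above is two extra keys in the chain's output dict; the same pair is threaded through GenerateTimbrSqlChain, TimbrSqlAgent, and ValidateTimbrSqlChain in the hunks that follow. A minimal consumer-side sketch (the result dict is whatever the chain returned; the function and variable names here are illustrative, not part of the package):

    def log_reasoning(result: dict) -> None:
        # Both keys may be None when the LLM replied in plain text (the backwards-compatible path).
        identify_reason = result.get("identify_concept_reason")
        generate_reason = result.get("generate_sql_reason")
        if identify_reason:
            print(f"Concept {result.get('concept')!r} was chosen because: {identify_reason}")
        if generate_reason:
            print(f"SQL rationale: {generate_reason}")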
 

src/langchain_timbr/langchain/generate_timbr_sql_chain.py
@@ -206,6 +206,8 @@ class GenerateTimbrSqlChain(Chain):
  "concept": concept,
  "is_sql_valid": generate_res.get("is_sql_valid"),
  "error": generate_res.get("error"),
+ "identify_concept_reason": generate_res.get("identify_concept_reason"),
+ "generate_sql_reason": generate_res.get("generate_sql_reason"),
  "reasoning_status": generate_res.get("reasoning_status"),
  self.usage_metadata_key: generate_res.get("usage_metadata"),
  }

src/langchain_timbr/langchain/timbr_sql_agent.py
@@ -181,6 +181,8 @@ class TimbrSqlAgent(BaseSingleActionAgent):
  "schema": None,
  "concept": None,
  "reasoning_status": None,
+ "identify_concept_reason": None,
+ "generate_sql_reason": None,
  "usage_metadata": {},
  },
  log="Empty input received"
@@ -210,6 +212,8 @@ class TimbrSqlAgent(BaseSingleActionAgent):
  "error": result.get("error", None),
  "reasoning_status": result.get("reasoning_status", None),
  "usage_metadata": usage_metadata,
+ "identify_concept_reason": None,
+ "generate_sql_reason": None,
  },
  log=f"Successfully executed query on concept: {result.get('concept', '')}"
  )
@@ -224,6 +228,8 @@ class TimbrSqlAgent(BaseSingleActionAgent):
  "schema": None,
  "concept": None,
  "reasoning_status": None,
+ "identify_concept_reason": None,
+ "generate_sql_reason": None,
  "usage_metadata": {},
  },
  log=error_context
@@ -245,6 +251,8 @@ class TimbrSqlAgent(BaseSingleActionAgent):
  "schema": None,
  "concept": None,
  "reasoning_status": None,
+ "identify_concept_reason": None,
+ "generate_sql_reason": None,
  "usage_metadata": {},
  },
  log="Empty or whitespace-only input received"
@@ -286,6 +294,8 @@ class TimbrSqlAgent(BaseSingleActionAgent):
  "concept": result.get("concept", ""),
  "error": result.get("error", None),
  "reasoning_status": result.get("reasoning_status", None),
+ "identify_concept_reason": result.get("identify_concept_reason", None),
+ "generate_sql_reason": result.get("generate_sql_reason", None),
  "usage_metadata": usage_metadata,
  },
  log=f"Successfully executed query on concept: {result.get('concept', '')}"
@@ -301,6 +311,8 @@ class TimbrSqlAgent(BaseSingleActionAgent):
  "schema": None,
  "concept": None,
  "reasoning_status": None,
+ "identify_concept_reason": None,
+ "generate_sql_reason": None,
  "usage_metadata": {},
  },
  log=error_context

src/langchain_timbr/langchain/validate_timbr_sql_chain.py
@@ -179,6 +179,8 @@ class ValidateTimbrSqlChain(Chain):
  schema = self._schema
  concept = self._concept
  reasoning_status = None
+ identify_concept_reason = None
+ generate_sql_reason = None

  is_sql_valid, error, sql = validate_sql(sql, self._get_conn_params())
  if not is_sql_valid:
@@ -211,6 +213,8 @@ class ValidateTimbrSqlChain(Chain):
  is_sql_valid = generate_res.get("is_sql_valid")
  reasoning_status = generate_res.get("reasoning_status")
  error = generate_res.get("error")
+ identify_concept_reason = generate_res.get("identify_concept_reason")
+ generate_sql_reason = generate_res.get("generate_sql_reason")

  return {
  "sql": sql,
@@ -219,5 +223,7 @@ class ValidateTimbrSqlChain(Chain):
  "is_sql_valid": is_sql_valid,
  "error": error,
  "reasoning_status": reasoning_status,
+ "identify_concept_reason": identify_concept_reason,
+ "generate_sql_reason": generate_sql_reason,
  self.usage_metadata_key: usage_metadata,
  }

src/langchain_timbr/utils/timbr_llm_utils.py
@@ -268,6 +268,7 @@ def determine_concept(
  ) -> dict[str, Any]:
  usage_metadata = {}
  determined_concept_name = None
+ identify_concept_reason = None
  schema = 'dtimbr'

  # Use config default timeout if none provided
@@ -339,10 +340,21 @@ def determine_concept(
  if debug:
  usage_metadata['determine_concept']["p_hash"] = encrypt_prompt(prompt)

- response_text = _get_response_text(response)
- candidate = response_text.strip()
+ # Try to parse as JSON first (with 'result' and 'reason' keys)
+ try:
+ parsed_response = _parse_json_from_llm_response(response)
+ if isinstance(parsed_response, dict) and 'result' in parsed_response:
+ candidate = parsed_response.get('result', '').strip()
+ identify_concept_reason = parsed_response.get('reason', None)
+ else:
+ # Fallback to plain text if JSON doesn't have expected structure
+ candidate = _get_response_text(response).strip()
+ except (json.JSONDecodeError, ValueError):
+ # If not JSON, treat as plain text (backwards compatibility)
+ candidate = _get_response_text(response).strip()
+
  if should_validate and candidate not in concepts_and_views.keys():
- error = f"Concept '{determined_concept_name}' not found in the list of concepts."
+ error = f"Concept '{candidate}' not found in the list of concepts."
  continue

  determined_concept_name = candidate
@@ -356,6 +368,7 @@ def determine_concept(
  return {
  "concept_metadata": concepts_and_views.get(determined_concept_name) if determined_concept_name else None,
  "concept": determined_concept_name,
+ "identify_concept_reason": identify_concept_reason,
  "schema": schema,
  "usage_metadata": usage_metadata,
  }
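
determine_concept now first tries to read a structured reply of the form {"result": "<concept name>", "reason": "<why>"} and only then falls back to the old plain-text behaviour. A self-contained mirror of that fallback logic, operating on a plain string instead of an LLM response object (names and values here are illustrative):

    import json
    from typing import Optional, Tuple

    def parse_concept_reply(raw: str) -> Tuple[str, Optional[str]]:
        # Structured reply, e.g. '{"result": "customer", "reason": "The question asks about customers."}'
        try:
            parsed = json.loads(raw)
            if isinstance(parsed, dict) and "result" in parsed:
                return parsed.get("result", "").strip(), parsed.get("reason")
        except (json.JSONDecodeError, ValueError):
            pass
        # Plain-text replies such as "customer" keep working; no reason is available in that case.
        return raw.strip(), None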
@@ -423,14 +436,45 @@ def _build_rel_columns_str(relationships: list[dict], columns_tags: Optional[dic
  return '.\n'.join(rel_str_arr) if rel_str_arr else ''


- def _parse_sql_from_llm_response(response: Any) -> str:
+ def _parse_sql_and_reason_from_llm_response(response: Any) -> dict:
+ """
+ Parse SQL & reason from LLM response. Handles both plain SQL strings and JSON format with 'result' and 'reason' keys.
+
+ Returns:
+ dict with 'sql' and 'reason' keys (reason may be None if not provided)
+ """
+ # Try to parse as JSON first
+ try:
+ parsed_json = _parse_json_from_llm_response(response)
+
+ # Extract SQL from 'result' key and reason from 'reason' key
+ if isinstance(parsed_json, dict) and 'result' in parsed_json:
+ sql = parsed_json.get('result', '')
+ reason = parsed_json.get('reason', None)
+
+ # Clean the SQL
+ sql = (sql
+ .replace("```sql", "")
+ .replace("```", "")
+ .replace('SELECT \n', 'SELECT ')
+ .replace(';', '')
+ .strip())
+
+ return {'sql': sql, 'reason': reason}
+ except (json.JSONDecodeError, ValueError):
+ # If not JSON, treat as plain SQL string (backwards compatibility)
+ pass
+
+ # Fallback to plain text parsing
  response_text = _get_response_text(response)
- return (response_text
- .replace("```sql", "")
- .replace("```", "")
- .replace('SELECT \n', 'SELECT ')
- .replace(';', '')
- .strip())
+ sql = (response_text
+ .replace("```sql", "")
+ .replace("```", "")
+ .replace('SELECT \n', 'SELECT ')
+ .replace(';', '')
+ .strip())
+
+ return {'sql': sql, 'reason': None}


  def _get_active_datasource(conn_params: dict) -> dict:
@@ -438,6 +482,38 @@ def _get_active_datasource(conn_params: dict) -> dict:
  return datasources[0] if datasources else None


+ def _parse_json_from_llm_response(response: Any) -> dict:
+ """
+ Parse JSON from LLM response. Handles markdown code blocks and extracts valid JSON.
+
+ Args:
+ response: LLM response object
+
+ Returns:
+ dict containing parsed JSON
+
+ Raises:
+ json.JSONDecodeError: If response cannot be parsed as JSON
+ ValueError: If response format is unexpected
+ """
+ response_text = _get_response_text(response)
+
+ # Remove markdown code block markers if present
+ content = response_text.strip()
+ if content.startswith("```json"):
+ content = content[7:]  # Remove ```json
+ elif content.startswith("```"):
+ content = content[3:]  # Remove ```
+
+ if content.endswith("```"):
+ content = content[:-3]  # Remove closing ```
+
+ content = content.strip()
+
+ # Parse and return JSON
+ return json.loads(content)
+
+
  def _evaluate_sql_enable_reasoning(
  question: str,
  sql_query: str,
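
_parse_json_from_llm_response centralizes the markdown-fence stripping that _evaluate_sql_enable_reasoning previously did inline, and _parse_sql_and_reason_from_llm_response builds on it. For illustration only (the reply string and values are invented), a fenced JSON reply ends up as a cleaned SQL string plus a reason, while a bare SQL string still goes through the plain-text fallback:

    # Hypothetical LLM reply, shown as a plain string rather than a response object:
    reply = '```json\n{"result": "SELECT customer_name FROM dtimbr.customer;", "reason": "Only names were requested."}\n```'
    # After the fences are stripped and json.loads runs, the SQL cleanup drops the trailing ';':
    expected = {"sql": "SELECT customer_name FROM dtimbr.customer", "reason": "Only names were requested."}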
@@ -463,22 +539,8 @@ def _evaluate_sql_enable_reasoning(

  response = _call_llm_with_timeout(llm, prompt, timeout=timeout)

- # Extract JSON from response content (handle markdown code blocks)
- content = response.content.strip()
-
- # Remove markdown code block markers if present
- if content.startswith("```json"):
- content = content[7:]  # Remove ```json
- elif content.startswith("```"):
- content = content[3:]  # Remove ```
-
- if content.endswith("```"):
- content = content[:-3]  # Remove closing ```
-
- content = content.strip()
-
  # Parse JSON response
- evaluation = json.loads(content)
+ evaluation = _parse_json_from_llm_response(response)

  return {
  "evaluation": evaluation,
@@ -571,11 +633,9 @@ def _build_sql_generation_context(
  def _generate_sql_with_llm(
  question: str,
  llm: LLM,
- conn_params: dict,
  generate_sql_prompt: Any,
  current_context: dict,
  note: str,
- should_validate_sql: bool,
  timeout: int,
  debug: bool = False,
  ) -> dict:
@@ -614,8 +674,12 @@ def _generate_sql_with_llm(

  response = _call_llm_with_timeout(llm, prompt, timeout=timeout)

+ # Parse response which now includes both SQL and reason
+ parsed_response = _parse_sql_and_reason_from_llm_response(response)
+
  result = {
- "sql": _parse_sql_from_llm_response(response),
+ "sql": parsed_response['sql'],
+ "generate_sql_reason": parsed_response['reason'],
  "apx_token_count": apx_token_count,
  "usage_metadata": _extract_usage_metadata(response),
  "is_valid": True,
@@ -625,11 +689,163 @@ def _generate_sql_with_llm(
  if debug:
  result["p_hash"] = encrypt_prompt(prompt)

- if should_validate_sql:
- result["is_valid"], result["error"], result["sql"] = validate_sql(result["sql"], conn_params)

  return result

+ def handle_generate_sql_reasoning(
+ sql_query: str,
+ question: str,
+ llm: LLM,
+ conn_params: dict,
+ schema: str,
+ concept: str,
+ concept_metadata: dict,
+ include_tags: bool,
+ exclude_properties: list,
+ db_is_case_sensitive: bool,
+ max_limit: int,
+ reasoning_steps: int,
+ note: str,
+ graph_depth: int,
+ usage_metadata: dict,
+ timeout: int,
+ debug: bool,
+ ) -> tuple[str, int, str]:
+ generate_sql_prompt = get_generate_sql_prompt_template(conn_params)
+ context_graph_depth = graph_depth
+ reasoned_sql = sql_query
+ reasoned_sql_reason = None
+ for step in range(reasoning_steps):
+ try:
+ # Step 1: Evaluate the current SQL
+ eval_result = _evaluate_sql_enable_reasoning(
+ question=question,
+ sql_query=reasoned_sql,
+ llm=llm,
+ conn_params=conn_params,
+ timeout=timeout,
+ )
+
+ usage_metadata[f'sql_reasoning_step_{step + 1}'] = {
+ "approximate": eval_result['apx_token_count'],
+ **eval_result['usage_metadata'],
+ }
+
+ evaluation = eval_result['evaluation']
+ reasoning_status = evaluation.get("assessment", "partial").lower()
+
+ if reasoning_status == "correct":
+ break
+
+ # Step 2: Regenerate SQL with feedback
+ evaluation_note = note + f"\n\nThe previously generated SQL: `{reasoned_sql}` was assessed as '{evaluation.get('assessment')}' because: {evaluation.get('reasoning', '*could not determine cause*')}. Please provide a corrected SQL query that better answers the question: '{question}'.\n\nCRITICAL: Return ONLY the SQL query without any explanation or comments."
+
+ # Increase graph depth for 2nd+ reasoning attempts, up to max of 3
+ context_graph_depth = min(3, int(graph_depth) + step) if graph_depth < 3 and step > 0 else graph_depth
+ regen_result = _generate_sql_with_llm(
+ question=question,
+ llm=llm,
+ generate_sql_prompt=generate_sql_prompt,
+ current_context=_build_sql_generation_context(
+ conn_params=conn_params,
+ schema=schema,
+ concept=concept,
+ concept_metadata=concept_metadata,
+ graph_depth=context_graph_depth,
+ include_tags=include_tags,
+ exclude_properties=exclude_properties,
+ db_is_case_sensitive=db_is_case_sensitive,
+ max_limit=max_limit),
+ note=evaluation_note,
+ timeout=timeout,
+ debug=debug,
+ )
+
+ reasoned_sql = regen_result['sql']
+ reasoned_sql_reason = regen_result['generate_sql_reason']
+ error = regen_result['error']
+
+ step_key = f'generate_sql_reasoning_step_{step + 1}'
+ usage_metadata[step_key] = {
+ "approximate": regen_result['apx_token_count'],
+ **regen_result['usage_metadata'],
+ }
+ if debug and 'p_hash' in regen_result:
+ usage_metadata[step_key]['p_hash'] = regen_result['p_hash']
+
+ if error:
+ raise Exception(error)
+
+ except TimeoutError as e:
+ raise Exception(f"LLM call timed out: {str(e)}")
+ except Exception as e:
+ print(f"Warning: LLM reasoning failed: {e}")
+ break
+
+ return reasoned_sql, context_graph_depth, reasoned_sql_reason
+
+ def handle_validate_generate_sql(
+ sql_query: str,
+ question: str,
+ llm: LLM,
+ conn_params: dict,
+ generate_sql_prompt: Any,
+ schema: str,
+ concept: str,
+ concept_metadata: dict,
+ include_tags: bool,
+ exclude_properties: list,
+ db_is_case_sensitive: bool,
+ max_limit: int,
+ graph_depth: int,
+ retries: int,
+ timeout: int,
+ debug: bool,
+ usage_metadata: dict,
+ ) -> tuple[bool, str, str]:
+ is_sql_valid, error, sql_query = validate_sql(sql_query, conn_params)
+ validation_attempt = 0
+
+ while validation_attempt < retries and not is_sql_valid:
+ validation_attempt += 1
+ validation_err_txt = f"\nThe generated SQL (`{sql_query}`) was invalid with error: {error}. Please generate a corrected query that achieves the intended result." if error and "snowflake" not in llm._llm_type else ""
+
+ regen_result = _generate_sql_with_llm(
+ question=question,
+ llm=llm,
+ generate_sql_prompt=generate_sql_prompt,
+ current_context=_build_sql_generation_context(
+ conn_params=conn_params,
+ schema=schema,
+ concept=concept,
+ concept_metadata=concept_metadata,
+ graph_depth=graph_depth,
+ include_tags=include_tags,
+ exclude_properties=exclude_properties,
+ db_is_case_sensitive=db_is_case_sensitive,
+ max_limit=max_limit),
+ note=validation_err_txt,
+ timeout=timeout,
+ debug=debug,
+ )
+
+ regen_error = regen_result['error']
+ sql_query = regen_result['sql']
+
+ validation_key = f'generate_sql_validation_regen_{validation_attempt}'
+ usage_metadata[validation_key] = {
+ "approximate": regen_result['apx_token_count'],
+ **regen_result['usage_metadata'],
+ }
+ if debug and 'p_hash' in regen_result:
+ usage_metadata[validation_key]['p_hash'] = regen_result['p_hash']
+
+ if regen_error:
+ raise Exception(regen_error)
+
+ is_sql_valid, error, sql_query = validate_sql(sql_query, conn_params)
+
+ return is_sql_valid, error, sql_query

  def generate_sql(
  question: str,
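
These two helpers lift the retry loops that previously lived inline in generate_sql; the next hunk wires them in. As a hedged sketch of the intended call order only, with placeholder callables standing in for the real helpers and their long argument lists:

    def orchestrate(generate, reason_loop, validate_loop, enable_reasoning, should_validate_sql, retries):
        sql, reason = generate()                              # single initial generation; no retry loop here anymore
        if enable_reasoning and sql is not None:
            sql, depth, reason = reason_loop(sql)             # evaluate + regenerate up to reasoning_steps times
        if should_validate_sql or enable_reasoning:
            attempts = retries if should_validate_sql else 1  # reasoning-only runs validate/regenerate once
            is_valid, error, sql = validate_loop(sql, attempts)
            return sql, reason, is_valid, error
        return sql, reason, True, ''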
@@ -656,13 +872,11 @@ def generate_sql(
  usage_metadata = {}
  concept_metadata = None
  reasoning_status = 'correct'
-
+
  # Use config default timeout if none provided
  if timeout is None:
  timeout = config.llm_timeout

- generate_sql_prompt = get_generate_sql_prompt_template(conn_params)
-
  if concept and concept != "" and (schema is None or schema != "vtimbr"):
  concepts_list = [concept]
  elif concept and concept != "" and schema == "vtimbr":
@@ -682,154 +896,105 @@ def generate_sql(
  debug=debug,
  timeout=timeout,
  )
- concept, schema, concept_metadata = determine_concept_res.get('concept'), determine_concept_res.get('schema'), determine_concept_res.get('concept_metadata')
+
+ concept = determine_concept_res.get('concept')
+ identify_concept_reason = determine_concept_res.get('identify_concept_reason', None)
+ schema = determine_concept_res.get('schema')
+ concept_metadata = determine_concept_res.get('concept_metadata')
  usage_metadata.update(determine_concept_res.get('usage_metadata', {}))

  if not concept:
  raise Exception("No relevant concept found for the query.")

+ generate_sql_prompt = get_generate_sql_prompt_template(conn_params)
  sql_query = None
- iteration = 0
- is_sql_valid = True
+ generate_sql_reason = None
+ is_sql_valid = True # Assume valid by default; set to False only if validation fails
  error = ''
- while sql_query is None or (should_validate_sql and iteration < retries and not is_sql_valid):
- iteration += 1
- err_txt = f"\nThe original SQL (`{sql_query}`) was invalid with error: {error}. Please generate a corrected query." if error and "snowflake" not in llm._llm_type else ""

- try:
- result = _generate_sql_with_llm(
+ try:
+ result = _generate_sql_with_llm(
+ question=question,
+ llm=llm,
+ generate_sql_prompt=generate_sql_prompt,
+ current_context=_build_sql_generation_context(
+ conn_params=conn_params,
+ schema=schema,
+ concept=concept,
+ concept_metadata=concept_metadata,
+ graph_depth=graph_depth,
+ include_tags=include_tags,
+ exclude_properties=exclude_properties,
+ db_is_case_sensitive=db_is_case_sensitive,
+ max_limit=max_limit),
+ note=note,
+ timeout=timeout,
+ debug=debug,
+ )
+
+ usage_metadata['generate_sql'] = {
+ "approximate": result['apx_token_count'],
+ **result['usage_metadata'],
+ }
+ if debug and 'p_hash' in result:
+ usage_metadata['generate_sql']["p_hash"] = result['p_hash']
+
+ sql_query = result['sql']
+ generate_sql_reason = result.get('generate_sql_reason', None)
+ error = result['error']
+
+ if error:
+ raise Exception(error)
+
+ if enable_reasoning and sql_query is not None:
+ sql_query, graph_depth, generate_sql_reason = handle_generate_sql_reasoning(
+ sql_query=sql_query,
+ question=question,
+ llm=llm,
+ conn_params=conn_params,
+ schema=schema,
+ concept=concept,
+ concept_metadata=concept_metadata,
+ include_tags=include_tags,
+ exclude_properties=exclude_properties,
+ db_is_case_sensitive=db_is_case_sensitive,
+ max_limit=max_limit,
+ reasoning_steps=reasoning_steps,
+ note=note,
+ graph_depth=graph_depth,
+ usage_metadata=usage_metadata,
+ timeout=timeout,
+ debug=debug,
+ )
+
+ if should_validate_sql or enable_reasoning:
+ # Validate & regenerate only once if reasoning enabled and validation is disabled
+ validate_retries = 1 if not should_validate_sql else retries
+ is_sql_valid, error, sql_query = handle_validate_generate_sql(
+ sql_query=sql_query,
  question=question,
  llm=llm,
  conn_params=conn_params,
  generate_sql_prompt=generate_sql_prompt,
- current_context=_build_sql_generation_context(
- conn_params=conn_params,
- schema=schema,
- concept=concept,
- concept_metadata=concept_metadata,
- graph_depth=graph_depth,
- include_tags=include_tags,
- exclude_properties=exclude_properties,
- db_is_case_sensitive=db_is_case_sensitive,
- max_limit=max_limit),
- note=note + err_txt,
- should_validate_sql=should_validate_sql,
+ schema=schema,
+ concept=concept,
+ concept_metadata=concept_metadata,
+ include_tags=include_tags,
+ exclude_properties=exclude_properties,
+ db_is_case_sensitive=db_is_case_sensitive,
+ max_limit=max_limit,
+ graph_depth=graph_depth,
+ retries=validate_retries,
  timeout=timeout,
  debug=debug,
+ usage_metadata=usage_metadata,
  )
-
- usage_metadata['generate_sql'] = {
- "approximate": result['apx_token_count'],
- **result['usage_metadata'],
- }
- if debug and 'p_hash' in result:
- usage_metadata['generate_sql']["p_hash"] = result['p_hash']
-
- sql_query = result['sql']
- is_sql_valid = result['is_valid']
- error = result['error']
-
- except TimeoutError as e:
- error = f"LLM call timed out: {str(e)}"
- raise Exception(error)
- except Exception as e:
- error = f"LLM call failed: {str(e)}"
- if should_validate_sql:
- continue
- else:
- raise Exception(error)
-
-
- if enable_reasoning and sql_query is not None:
- for step in range(reasoning_steps):
- try:
- # Step 1: Evaluate the current SQL
- eval_result = _evaluate_sql_enable_reasoning(
- question=question,
- sql_query=sql_query,
- llm=llm,
- conn_params=conn_params,
- timeout=timeout,
- )
-
- usage_metadata[f'sql_reasoning_step_{step + 1}'] = {
- "approximate": eval_result['apx_token_count'],
- **eval_result['usage_metadata'],
- }
-
- evaluation = eval_result['evaluation']
- reasoning_status = evaluation.get("assessment", "partial").lower()
-
- if reasoning_status == "correct":
- break
-
- # Step 2: Regenerate SQL with feedback (with validation retries)
- evaluation_note = note + f"\n\nThe previously generated SQL: `{sql_query}` was assessed as '{evaluation.get('assessment')}' because: {evaluation.get('reasoning', '*could not determine cause*')}. Please provide a corrected SQL query that better answers the question: '{question}'.\n\nCRITICAL: Return ONLY the SQL query without any explanation or comments."
-
- # Increase graph depth for 2nd+ reasoning attempts, up to max of 3
- context_graph_depth = min(3, int(graph_depth) + step) if graph_depth < 3 and step > 0 else graph_depth
-
- # Regenerate SQL with validation retries
- # Always validate during reasoning to ensure quality, regardless of global should_validate_sql flag
- validation_iteration = 0
- regen_is_valid = False
- regen_error = ''
- regen_sql = None
-
- while validation_iteration < retries and (regen_sql is None or not regen_is_valid):
- validation_iteration += 1
- validation_err_txt = f"\nThe regenerated SQL (`{regen_sql}`) was invalid with error: {regen_error}. Please generate a corrected query." if regen_error and "snowflake" not in llm._llm_type else ""
-
- regen_result = _generate_sql_with_llm(
- question=question,
- llm=llm,
- conn_params=conn_params,
- generate_sql_prompt=generate_sql_prompt,
- current_context=_build_sql_generation_context(
- conn_params=conn_params,
- schema=schema,
- concept=concept,
- concept_metadata=concept_metadata,
- graph_depth=context_graph_depth,
- include_tags=include_tags,
- exclude_properties=exclude_properties,
- db_is_case_sensitive=db_is_case_sensitive,
- max_limit=max_limit),
- note=evaluation_note + validation_err_txt,
- should_validate_sql=True, # Always validate during reasoning
- timeout=timeout,
- debug=debug,
- )
-
- regen_sql = regen_result['sql']
- regen_is_valid = regen_result['is_valid']
- regen_error = regen_result['error']
-
- # Track token usage for each validation iteration
- if validation_iteration == 1:
- usage_metadata[f'generate_sql_reasoning_step_{step + 1}'] = {
- "approximate": regen_result['apx_token_count'],
- **regen_result['usage_metadata'],
- }
- if debug and 'p_hash' in regen_result:
- usage_metadata[f'generate_sql_reasoning_step_{step + 1}']['p_hash'] = regen_result['p_hash']
- else:
- usage_metadata[f'generate_sql_reasoning_step_{step + 1}_validation_{validation_iteration}'] = {
- "approximate": regen_result['apx_token_count'],
- **regen_result['usage_metadata'],
- }
- if debug and 'p_hash' in regen_result:
- usage_metadata[f'generate_sql_reasoning_step_{step + 1}_validation_{validation_iteration}']['p_hash'] = regen_result['p_hash']
-
- sql_query = regen_sql
- is_sql_valid = regen_is_valid
- error = regen_error
-
- except TimeoutError as e:
- raise Exception(f"LLM call timed out: {str(e)}")
- except Exception as e:
- print(f"Warning: LLM reasoning failed: {e}")
- break
+ except TimeoutError as e:
+ error = f"LLM call timed out: {str(e)}"
+ raise Exception(error)
+ except Exception as e:
+ error = f"LLM call failed: {str(e)}"
+ raise Exception(error)

  return {
  "sql": sql_query,
@@ -837,6 +1002,8 @@ def generate_sql(
  "schema": schema,
  "error": error if not is_sql_valid else None,
  "is_sql_valid": is_sql_valid if should_validate_sql else None,
+ "identify_concept_reason": identify_concept_reason,
+ "generate_sql_reason": generate_sql_reason,
  "reasoning_status": reasoning_status,
  "usage_metadata": usage_metadata,
  }
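
Putting the timbr_llm_utils.py changes together, generate_sql now reports the two reasoning fields alongside the existing ones. The keys below are the ones visible in the hunks above; the values are purely illustrative:

    result = {
        "sql": "SELECT customer_name FROM dtimbr.customer",
        "schema": "dtimbr",
        "error": None,                        # populated only when validation failed
        "is_sql_valid": True,                 # None when should_validate_sql is False
        "identify_concept_reason": "The question is about customers.",  # new in 2.2.0; None for plain-text replies
        "generate_sql_reason": "Only the name column was requested.",   # new in 2.2.0; None for plain-text replies
        "reasoning_status": "correct",
        "usage_metadata": {"determine_concept": {}, "generate_sql": {}},
    }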

tests/integration/test_langchain_chains.py
@@ -27,6 +27,8 @@ class TestIdentifyTimbrConceptChain:
  print("IdentifyTimbrConceptChain result:", result)
  assert "concept" in result, "Chain should return a 'concept'"
  assert result["concept"], "Returned concept should not be empty"
+ assert "identify_concept_reason" in result, "Chain should return a 'identify_concept_reason'"
+ assert result["identify_concept_reason"], "Returned identify_concept_reason should not be empty"
  assert chain.usage_metadata_key in result, "Chain should return 'usage_metadata'"
  assert len(result[chain.usage_metadata_key]) == 1 and 'determine_concept' in result[chain.usage_metadata_key], "Usage metadata should contain only 'determine_concept'"

@@ -63,6 +65,11 @@ class TestGenerateTimbrSqlChain:
  assert "concept" in result and result["concept"], "Concept name should be returned"
  assert chain.usage_metadata_key in result, "Chain should return 'usage_metadata'"
  assert len(result[chain.usage_metadata_key]) == 2 and 'determine_concept' in result[chain.usage_metadata_key] and 'generate_sql' in result[chain.usage_metadata_key], "Usage metadata should contain both 'determine_concept' and 'generate_sql'"
+ assert "identify_concept_reason" in result, "Chain should return a 'identify_concept_reason'"
+ assert result["identify_concept_reason"], "Returned identify_concept_reason should not be empty"
+ assert "generate_sql_reason" in result, "Chain should return a 'generate_sql_reason'"
+ assert result["generate_sql_reason"], "Returned generate_sql_reason should not be empty"
+

  def test_generate_timbr_sql_with_limit_chain(self, llm, config):
  """Test SQL generation with row limit."""