langchain-timbr 2.1.6.tar.gz → 2.1.7.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/PKG-INFO +1 -1
  2. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/__init__.py +1 -1
  3. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/_version.py +2 -2
  4. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/config.py +1 -1
  5. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/langchain/execute_timbr_query_chain.py +4 -4
  6. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/langchain/generate_timbr_sql_chain.py +4 -4
  7. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/langchain/timbr_sql_agent.py +6 -6
  8. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/langchain/validate_timbr_sql_chain.py +11 -1
  9. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/langgraph/execute_timbr_query_node.py +3 -3
  10. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/langgraph/generate_timbr_sql_node.py +3 -3
  11. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/langgraph/validate_timbr_query_node.py +6 -0
  12. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/utils/prompt_service.py +26 -0
  13. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/utils/timbr_llm_utils.py +124 -132
  14. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/utils/timbr_utils.py +39 -0
  15. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/conftest.py +1 -0
  16. langchain_timbr-2.1.7/tests/integration/test_chain_reasoning.py +84 -0
  17. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/integration/test_langchain_chains.py +9 -8
  18. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/.github/dependabot.yml +0 -0
  19. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/.github/pull_request_template.md +0 -0
  20. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/.github/workflows/_codespell.yml +0 -0
  21. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/.github/workflows/_fossa.yml +0 -0
  22. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/.github/workflows/install-dependencies-and-run-tests.yml +0 -0
  23. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/.github/workflows/publish.yml +0 -0
  24. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/.gitignore +0 -0
  25. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/LICENSE +0 -0
  26. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/README.md +0 -0
  27. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/SECURITY.md +0 -0
  28. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/pyproject.toml +0 -0
  29. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/pytest.ini +0 -0
  30. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/requirements.txt +0 -0
  31. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/requirements310.txt +0 -0
  32. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/requirements311.txt +0 -0
  33. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/langchain/__init__.py +0 -0
  34. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/langchain/generate_answer_chain.py +0 -0
  35. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/langchain/identify_concept_chain.py +0 -0
  36. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/langgraph/__init__.py +0 -0
  37. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/langgraph/generate_response_node.py +0 -0
  38. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/langgraph/identify_concept_node.py +0 -0
  39. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/llm_wrapper/llm_wrapper.py +0 -0
  40. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/llm_wrapper/timbr_llm_wrapper.py +0 -0
  41. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/timbr_llm_connector.py +0 -0
  42. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/utils/general.py +0 -0
  43. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/src/langchain_timbr/utils/temperature_supported_models.json +0 -0
  44. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/README.md +0 -0
  45. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/integration/test_agent_integration.py +0 -0
  46. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/integration/test_azure_databricks_provider.py +0 -0
  47. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/integration/test_azure_openai_model.py +0 -0
  48. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/integration/test_chain_pipeline.py +0 -0
  49. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/integration/test_jwt_token.py +0 -0
  50. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/integration/test_langgraph_nodes.py +0 -0
  51. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/integration/test_timeout_functionality.py +0 -0
  52. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/standard/conftest.py +0 -0
  53. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/standard/test_chain_documentation.py +0 -0
  54. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/standard/test_connection_validation.py +0 -0
  55. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/standard/test_llm_wrapper_optional_params.py +0 -0
  56. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/standard/test_optional_llm_integration.py +0 -0
  57. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/standard/test_standard_chain_requirements.py +0 -0
  58. {langchain_timbr-2.1.6 → langchain_timbr-2.1.7}/tests/standard/test_unit_tests.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langchain-timbr
3
- Version: 2.1.6
3
+ Version: 2.1.7
4
4
  Summary: LangChain & LangGraph extensions that parse LLM prompts into Timbr semantic SQL and execute them.
5
5
  Project-URL: Homepage, https://github.com/WPSemantix/langchain-timbr
6
6
  Project-URL: Documentation, https://docs.timbr.ai/doc/docs/integration/langchain-sdk/
@@ -12,7 +12,7 @@
12
12
  from .timbr_llm_connector import TimbrLlmConnector
13
13
  from .llm_wrapper.llm_wrapper import LlmWrapper, LlmTypes
14
14
 
15
- from .utils.timbr_llm_utils import (
15
+ from .utils.timbr_utils import (
16
16
  generate_key,
17
17
  encrypt_prompt,
18
18
  decrypt_prompt,
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '2.1.6'
32
- __version_tuple__ = version_tuple = (2, 1, 6)
31
+ __version__ = version = '2.1.7'
32
+ __version_tuple__ = version_tuple = (2, 1, 7)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -30,5 +30,5 @@ llm_api_version = os.environ.get('LLM_API_VERSION', None)
30
30
  llm_scope = os.environ.get('LLM_SCOPE', "https://cognitiveservices.azure.com/.default") # e.g. "api://<your-client-id>/.default"
31
31
 
32
32
  # Whether to enable reasoning during SQL generation
33
- with_reasoning = to_boolean(os.environ.get('WITH_REASONING', 'false'))
33
+ enable_reasoning = to_boolean(os.environ.get('ENABLE_REASONING', 'false'))
34
34
  reasoning_steps = to_integer(os.environ.get('REASONING_STEPS', 2))
@@ -42,7 +42,7 @@ class ExecuteTimbrQueryChain(Chain):
42
42
  is_jwt: Optional[bool] = False,
43
43
  jwt_tenant_id: Optional[str] = None,
44
44
  conn_params: Optional[dict] = None,
45
- with_reasoning: Optional[bool] = config.with_reasoning,
45
+ enable_reasoning: Optional[bool] = config.enable_reasoning,
46
46
  reasoning_steps: Optional[int] = config.reasoning_steps,
47
47
  debug: Optional[bool] = False,
48
48
  **kwargs,
@@ -71,7 +71,7 @@ class ExecuteTimbrQueryChain(Chain):
71
71
  :param is_jwt: Whether to use JWT authentication (default is False).
72
72
  :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
73
73
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
74
- :param with_reasoning: Whether to enable reasoning during SQL generation (default is False).
74
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
75
75
  :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
76
76
  :param kwargs: Additional arguments to pass to the base
77
77
  :return: A list of rows from the Timbr query
@@ -141,7 +141,7 @@ class ExecuteTimbrQueryChain(Chain):
141
141
  self._jwt_tenant_id = jwt_tenant_id
142
142
  self._debug = to_boolean(debug)
143
143
  self._conn_params = conn_params or {}
144
- self._with_reasoning = to_boolean(with_reasoning)
144
+ self._enable_reasoning = to_boolean(enable_reasoning)
145
145
  self._reasoning_steps = to_integer(reasoning_steps)
146
146
 
147
147
 
@@ -215,7 +215,7 @@ class ExecuteTimbrQueryChain(Chain):
215
215
  note=(self._note or '') + err_txt,
216
216
  db_is_case_sensitive=self._db_is_case_sensitive,
217
217
  graph_depth=self._graph_depth,
218
- with_reasoning=self._with_reasoning,
218
+ enable_reasoning=self._enable_reasoning,
219
219
  reasoning_steps=self._reasoning_steps,
220
220
  debug=self._debug,
221
221
  )
@@ -39,7 +39,7 @@ class GenerateTimbrSqlChain(Chain):
39
39
  is_jwt: Optional[bool] = False,
40
40
  jwt_tenant_id: Optional[str] = None,
41
41
  conn_params: Optional[dict] = None,
42
- with_reasoning: Optional[bool] = config.with_reasoning,
42
+ enable_reasoning: Optional[bool] = config.enable_reasoning,
43
43
  reasoning_steps: Optional[int] = config.reasoning_steps,
44
44
  debug: Optional[bool] = False,
45
45
  **kwargs,
@@ -66,7 +66,7 @@ class GenerateTimbrSqlChain(Chain):
66
66
  :param is_jwt: Whether to use JWT authentication (default is False).
67
67
  :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
68
68
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
69
- :param with_reasoning: Whether to enable reasoning during SQL generation (default is False).
69
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
70
70
  :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
71
71
  :param debug: Whether to enable debug mode for detailed logging
72
72
  :param kwargs: Additional arguments to pass to the base
@@ -134,7 +134,7 @@ class GenerateTimbrSqlChain(Chain):
134
134
  self._jwt_tenant_id = jwt_tenant_id
135
135
  self._debug = to_boolean(debug)
136
136
  self._conn_params = conn_params or {}
137
- self._with_reasoning = to_boolean(with_reasoning)
137
+ self._enable_reasoning = to_boolean(enable_reasoning)
138
138
  self._reasoning_steps = to_integer(reasoning_steps)
139
139
 
140
140
 
@@ -191,7 +191,7 @@ class GenerateTimbrSqlChain(Chain):
191
191
  note=self._note,
192
192
  db_is_case_sensitive=self._db_is_case_sensitive,
193
193
  graph_depth=self._graph_depth,
194
- with_reasoning=self._with_reasoning,
194
+ enable_reasoning=self._enable_reasoning,
195
195
  reasoning_steps=self._reasoning_steps,
196
196
  debug=self._debug,
197
197
  )
@@ -35,7 +35,7 @@ class TimbrSqlAgent(BaseSingleActionAgent):
35
35
  is_jwt: Optional[bool] = False,
36
36
  jwt_tenant_id: Optional[str] = None,
37
37
  conn_params: Optional[dict] = None,
38
- with_reasoning: Optional[bool] = config.with_reasoning,
38
+ enable_reasoning: Optional[bool] = config.enable_reasoning,
39
39
  reasoning_steps: Optional[int] = config.reasoning_steps,
40
40
  debug: Optional[bool] = False
41
41
  ):
@@ -64,7 +64,7 @@ class TimbrSqlAgent(BaseSingleActionAgent):
64
64
  :param is_jwt: Whether to use JWT authentication (default is False).
65
65
  :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
66
66
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
67
- :param with_reasoning: Whether to enable reasoning during SQL generation (default is False).
67
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
68
68
  :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
69
69
 
70
70
  ## Example
@@ -118,7 +118,7 @@ class TimbrSqlAgent(BaseSingleActionAgent):
118
118
  is_jwt=to_boolean(is_jwt),
119
119
  jwt_tenant_id=jwt_tenant_id,
120
120
  conn_params=conn_params,
121
- with_reasoning=to_boolean(with_reasoning),
121
+ enable_reasoning=to_boolean(enable_reasoning),
122
122
  reasoning_steps=to_integer(reasoning_steps),
123
123
  debug=to_boolean(debug),
124
124
  )
@@ -345,7 +345,7 @@ def create_timbr_sql_agent(
345
345
  is_jwt: Optional[bool] = False,
346
346
  jwt_tenant_id: Optional[str] = None,
347
347
  conn_params: Optional[dict] = None,
348
- with_reasoning: Optional[bool] = config.with_reasoning,
348
+ enable_reasoning: Optional[bool] = config.enable_reasoning,
349
349
  reasoning_steps: Optional[int] = config.reasoning_steps,
350
350
  debug: Optional[bool] = False
351
351
  ) -> AgentExecutor:
@@ -376,7 +376,7 @@ def create_timbr_sql_agent(
376
376
  :param is_jwt: Whether to use JWT authentication (default is False).
377
377
  :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
378
378
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
379
- :param with_reasoning: Whether to enable reasoning during SQL generation (default is False).
379
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
380
380
  :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
381
381
 
382
382
  Returns:
@@ -444,7 +444,7 @@ def create_timbr_sql_agent(
444
444
  is_jwt=is_jwt,
445
445
  jwt_tenant_id=jwt_tenant_id,
446
446
  conn_params=conn_params,
447
- with_reasoning=with_reasoning,
447
+ enable_reasoning=enable_reasoning,
448
448
  reasoning_steps=reasoning_steps,
449
449
  debug=debug,
450
450
  )
@@ -40,6 +40,8 @@ class ValidateTimbrSqlChain(Chain):
40
40
  is_jwt: Optional[bool] = False,
41
41
  jwt_tenant_id: Optional[str] = None,
42
42
  conn_params: Optional[dict] = None,
43
+ enable_reasoning: Optional[bool] = config.enable_reasoning,
44
+ reasoning_steps: Optional[int] = config.reasoning_steps,
43
45
  debug: Optional[bool] = False,
44
46
  **kwargs,
45
47
  ):
@@ -64,6 +66,8 @@ class ValidateTimbrSqlChain(Chain):
64
66
  :param is_jwt: Whether to use JWT authentication (default is False).
65
67
  :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
66
68
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
69
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
70
+ :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
67
71
  :param kwargs: Additional arguments to pass to the base
68
72
 
69
73
  ## Example
@@ -127,8 +131,10 @@ class ValidateTimbrSqlChain(Chain):
127
131
  self._verify_ssl = to_boolean(verify_ssl)
128
132
  self._is_jwt = to_boolean(is_jwt)
129
133
  self._jwt_tenant_id = jwt_tenant_id
130
- self._debug = to_boolean(debug)
131
134
  self._conn_params = conn_params or {}
135
+ self._enable_reasoning = to_boolean(enable_reasoning)
136
+ self._reasoning_steps = to_integer(reasoning_steps)
137
+ self._debug = to_boolean(debug)
132
138
 
133
139
 
134
140
  @property
@@ -193,6 +199,8 @@ class ValidateTimbrSqlChain(Chain):
193
199
  note=f"{prompt_extension}The original SQL query (`{sql}`) was invalid with this error from query {error}. Please take this in consideration while generating the query.",
194
200
  db_is_case_sensitive=self._db_is_case_sensitive,
195
201
  graph_depth=self._graph_depth,
202
+ enable_reasoning=self._enable_reasoning,
203
+ reasoning_steps=self._reasoning_steps,
196
204
  debug=self._debug,
197
205
  )
198
206
  sql = generate_res.get("sql", "")
@@ -200,6 +208,7 @@ class ValidateTimbrSqlChain(Chain):
200
208
  concept = generate_res.get("concept", self._concept)
201
209
  usage_metadata.update(generate_res.get("usage_metadata", {}))
202
210
  is_sql_valid = generate_res.get("is_sql_valid")
211
+ reasoning_status = generate_res.get("reasoning_status")
203
212
  error = generate_res.get("error")
204
213
 
205
214
  return {
@@ -208,5 +217,6 @@ class ValidateTimbrSqlChain(Chain):
208
217
  "concept": concept,
209
218
  "is_sql_valid": is_sql_valid,
210
219
  "error": error,
220
+ "reasoning_status": reasoning_status,
211
221
  self.usage_metadata_key: usage_metadata,
212
222
  }
@@ -36,7 +36,7 @@ class ExecuteSemanticQueryNode:
36
36
  is_jwt: Optional[bool] = False,
37
37
  jwt_tenant_id: Optional[str] = None,
38
38
  conn_params: Optional[dict] = None,
39
- with_reasoning: Optional[bool] = config.with_reasoning,
39
+ enable_reasoning: Optional[bool] = config.enable_reasoning,
40
40
  reasoning_steps: Optional[int] = config.reasoning_steps,
41
41
  debug: Optional[bool] = False,
42
42
  **kwargs,
@@ -65,7 +65,7 @@ class ExecuteSemanticQueryNode:
65
65
  :param is_jwt: Whether to use JWT authentication (default is False).
66
66
  :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
67
67
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
68
- :param with_reasoning: Whether to enable reasoning during SQL generation (default is False).
68
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
69
69
  :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
70
70
  :return: A list of rows from the Timbr query
71
71
  """
@@ -93,7 +93,7 @@ class ExecuteSemanticQueryNode:
93
93
  is_jwt=is_jwt,
94
94
  jwt_tenant_id=jwt_tenant_id,
95
95
  conn_params=conn_params,
96
- with_reasoning=with_reasoning,
96
+ enable_reasoning=enable_reasoning,
97
97
  reasoning_steps=reasoning_steps,
98
98
  debug=debug,
99
99
  **kwargs,
@@ -33,7 +33,7 @@ class GenerateTimbrSqlNode:
33
33
  is_jwt: Optional[bool] = False,
34
34
  jwt_tenant_id: Optional[str] = None,
35
35
  conn_params: Optional[dict] = None,
36
- with_reasoning: Optional[bool] = config.with_reasoning,
36
+ enable_reasoning: Optional[bool] = config.enable_reasoning,
37
37
  reasoning_steps: Optional[int] = config.reasoning_steps,
38
38
  debug: Optional[bool] = False,
39
39
  **kwargs,
@@ -60,7 +60,7 @@ class GenerateTimbrSqlNode:
60
60
  :param is_jwt: Whether to use JWT authentication (default: False)
61
61
  :param jwt_tenant_id: Tenant ID for JWT authentication when using multi-tenant setup
62
62
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
63
- :param with_reasoning: Whether to enable reasoning during SQL generation (default is False).
63
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
64
64
  :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
65
65
  """
66
66
  self.chain = GenerateTimbrSqlChain(
@@ -85,7 +85,7 @@ class GenerateTimbrSqlNode:
85
85
  is_jwt=is_jwt,
86
86
  jwt_tenant_id=jwt_tenant_id,
87
87
  conn_params=conn_params,
88
- with_reasoning=with_reasoning,
88
+ enable_reasoning=enable_reasoning,
89
89
  reasoning_steps=reasoning_steps,
90
90
  debug=debug,
91
91
  **kwargs,
@@ -33,6 +33,8 @@ class ValidateSemanticSqlNode:
33
33
  is_jwt: Optional[bool] = False,
34
34
  jwt_tenant_id: Optional[str] = None,
35
35
  conn_params: Optional[dict] = None,
36
+ enable_reasoning: Optional[bool] = False,
37
+ reasoning_steps: Optional[int] = 2,
36
38
  debug: Optional[bool] = False,
37
39
  **kwargs,
38
40
  ):
@@ -57,6 +59,8 @@ class ValidateSemanticSqlNode:
57
59
  :param is_jwt: Whether to use JWT authentication (default: False)
58
60
  :param jwt_tenant_id: Tenant ID for JWT authentication when using multi-tenant setup
59
61
  :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
62
+ :param enable_reasoning: Whether to enable reasoning during SQL generation (default is False).
63
+ :param reasoning_steps: Number of reasoning steps to perform if reasoning is enabled (default is 2).
60
64
  """
61
65
  self.chain = ValidateTimbrSqlChain(
62
66
  llm=llm,
@@ -79,6 +83,8 @@ class ValidateSemanticSqlNode:
79
83
  is_jwt=is_jwt,
80
84
  jwt_tenant_id=jwt_tenant_id,
81
85
  conn_params=conn_params,
86
+ enable_reasoning=enable_reasoning,
87
+ reasoning_steps=reasoning_steps,
82
88
  debug=debug,
83
89
  **kwargs,
84
90
  )
@@ -184,6 +184,16 @@ class PromptService:
184
184
  return self._fetch_template("llm_prompts/generate_sql")
185
185
 
186
186
 
187
+ def get_generate_sql_reasoning_template(self) -> ChatPromptTemplate:
188
+ """
189
+ Get generate SQL reasoning template from API service (cached)
190
+
191
+ Returns:
192
+ ChatPromptTemplate object
193
+ """
194
+ return self._fetch_template("llm_prompts/generate_sql_reasoning")
195
+
196
+
187
197
  def get_generate_answer_template(self) -> ChatPromptTemplate:
188
198
  """
189
199
  Get generate answer template from API service (cached)
@@ -264,6 +274,22 @@ def get_generate_sql_prompt_template(
264
274
  return PromptTemplateWrapper(prompt_service, "get_generate_sql_template")
265
275
 
266
276
 
277
+ def get_generate_sql_reasoning_prompt_template(
278
+ conn_params: Optional[dict] = None
279
+ ) -> PromptTemplateWrapper:
280
+ """
281
+ Get generate SQL reasoning prompt template wrapper
282
+
283
+ Args:
284
+ conn_params: Connection parameters including url, token, is_jwt, and jwt_tenant_id
285
+
286
+ Returns:
287
+ PromptTemplateWrapper for generate SQL reasoning
288
+ """
289
+ prompt_service = PromptService(conn_params=conn_params)
290
+ return PromptTemplateWrapper(prompt_service, "get_generate_sql_reasoning_template")
291
+
292
+
267
293
  def get_qa_prompt_template(
268
294
  conn_params: Optional[dict] = None
269
295
  ) -> PromptTemplateWrapper:
@@ -1,21 +1,18 @@
1
1
  from typing import Any, Optional
2
2
  from langchain.llms.base import LLM
3
- import base64, hashlib
4
- from cryptography.fernet import Fernet
5
3
  from datetime import datetime
6
4
  import concurrent.futures
7
5
  import json
8
- from langchain_core.messages import HumanMessage, SystemMessage
9
6
 
10
- from .timbr_utils import get_datasources, get_tags, get_concepts, get_concept_properties, validate_sql, get_properties_description, get_relationships_description
7
+ from .timbr_utils import get_datasources, get_tags, get_concepts, get_concept_properties, validate_sql, get_properties_description, get_relationships_description, cache_with_version_check, encrypt_prompt
11
8
  from .prompt_service import (
12
9
  get_determine_concept_prompt_template,
13
10
  get_generate_sql_prompt_template,
11
+ get_generate_sql_reasoning_prompt_template,
14
12
  get_qa_prompt_template
15
13
  )
16
14
  from ..config import llm_timeout
17
15
 
18
-
19
16
  def _clean_snowflake_prompt(prompt: Any) -> None:
20
17
  import re
21
18
 
@@ -55,14 +52,6 @@ def _clean_snowflake_prompt(prompt: Any) -> None:
55
52
  prompt[1].content = clean_func(prompt[1].content) # User message
56
53
 
57
54
 
58
- def generate_key() -> bytes:
59
- """Generate a new Fernet secret key."""
60
- passcode = b"lucylit2025"
61
- hlib = hashlib.md5()
62
- hlib.update(passcode)
63
- return base64.urlsafe_b64encode(hlib.hexdigest().encode('utf-8'))
64
-
65
-
66
55
  def _call_llm_with_timeout(llm: LLM, prompt: Any, timeout: int = 60) -> Any:
67
56
  """
68
57
  Call LLM with timeout to prevent hanging.
@@ -91,35 +80,9 @@ def _call_llm_with_timeout(llm: LLM, prompt: Any, timeout: int = 60) -> Any:
91
80
  except Exception as e:
92
81
  raise e
93
82
 
94
- ENCRYPT_KEY = generate_key()
95
83
  MEASURES_DESCRIPTION = "The following columns are calculated measures and can only be aggregated with an aggregate function: COUNT/SUM/AVG/MIN/MAX (count distinct is not allowed)"
96
84
  TRANSITIVE_RELATIONSHIP_DESCRIPTION = "Transitive relationship columns allow you to access data through multiple relationship hops. These columns follow the pattern `<relationship_name>[<table_name>*<number>].<column_name>` where the number after the asterisk (*) indicates how many relationship levels to traverse. For example, `acquired_by[company*4].company_name` means 'go through up to 4 levels of the acquired_by relationship to get the company name', while columns ending with '_transitivity_level' indicate the actual relationship depth (Cannot be null or 0 - level 1 represents direct relationships, while levels 2, 3, 4, etc. represent indirect relationships through multiple hops. To filter by relationship type, use `_transitivity_level = 1` for direct relationships only, `_transitivity_level > 1` for indirect relationships only."
97
85
 
98
- def encrypt_prompt(prompt: Any, key: Optional[bytes] = ENCRYPT_KEY) -> bytes:
99
- """Serialize & encrypt the prompt; returns a URL-safe token."""
100
- # build prompt_text as before…
101
- if isinstance(prompt, str):
102
- text = prompt
103
- elif isinstance(prompt, list):
104
- parts = []
105
- for message in prompt:
106
- if hasattr(message, "content"):
107
- parts.append(f"{message.type}: {message.content}")
108
- else:
109
- parts.append(str(message))
110
- text = "\n".join(parts)
111
- else:
112
- text = str(prompt)
113
-
114
- f = Fernet(key)
115
- return f.encrypt(text.encode()).decode('utf-8')
116
-
117
-
118
- def decrypt_prompt(token: bytes, key: bytes) -> str:
119
- """Decrypt the token and return the original prompt string."""
120
- f = Fernet(key)
121
- return f.decrypt(token).decode()
122
-
123
86
 
124
87
  def _prompt_to_string(prompt: Any) -> str:
125
88
  prompt_text = ''
@@ -187,6 +150,7 @@ def _get_response_text(response: Any) -> str:
187
150
 
188
151
  return response_text
189
152
 
153
+
190
154
  def _extract_usage_metadata(response: Any) -> dict:
191
155
  """
192
156
  Extract usage metadata from LLM response across different providers.
@@ -287,6 +251,7 @@ def _extract_usage_metadata(response: Any) -> dict:
287
251
 
288
252
  return usage_metadata
289
253
 
254
+
290
255
  def determine_concept(
291
256
  question: str,
292
257
  llm: LLM,
@@ -479,10 +444,11 @@ def _get_active_datasource(conn_params: dict) -> dict:
479
444
  return datasources[0] if datasources else None
480
445
 
481
446
 
482
- def _evaluate_sql_with_reasoning(
447
+ def _evaluate_sql_enable_reasoning(
483
448
  question: str,
484
449
  sql_query: str,
485
450
  llm: LLM,
451
+ conn_params: dict,
486
452
  timeout: int,
487
453
  ) -> dict:
488
454
  """
@@ -491,49 +457,17 @@ def _evaluate_sql_with_reasoning(
491
457
  Returns:
492
458
  dict with 'assessment' ('correct'|'partial'|'incorrect') and 'reasoning'
493
459
  """
494
- system_prompt = """You are an expert SQL and data analysis evaluator for Timbr.ai knowledge graph queries.
495
-
496
- **IMPORTANT CONTEXT:**
497
- - This system uses Timbr.ai, which extends SQL with semantic graph layer, including traversals, measures and more
498
- - Field names may use special Timbr syntax that is NOT standard SQL but is VALID in this system:
499
- * `measure.<measure_name>` - References computed measures (e.g., measure.total_balance_amount)
500
- * `<relationship>[target_table].<property>` - Graph traversal syntax (e.g., has_account[Account].account_name)
501
- * These are translated by Timbr to standard SQL before execution
502
- - DO NOT mark queries as incorrect based on field name syntax - Timbr validates this before execution
503
-
504
- Evaluate whether the generated query correctly addresses the business question:
505
- - **correct**: The query fully and accurately answers the question
506
- - **partial**: The query is partially correct or incomplete
507
- - **incorrect**: The query does not address the question or is wrong
508
-
509
- Return your evaluation as a JSON object with this exact structure:
510
- {
511
- "assessment": "<correct|partial|incorrect>",
512
- "reasoning": "<short but precise sentence explaining your assessment>"
513
- }
514
-
515
- Be concise and objective."""
516
-
517
- user_prompt = f"""**Business Question:**
518
- {question}
519
-
520
- **Generated SQL Query:**
521
- ```sql
522
- {sql_query}
523
- ```
524
-
525
- Please evaluate this result."""
526
-
527
- messages = [
528
- SystemMessage(content=system_prompt),
529
- HumanMessage(content=user_prompt)
530
- ]
460
+ generate_sql_reasoning_template = get_generate_sql_reasoning_prompt_template(conn_params)
461
+ prompt = generate_sql_reasoning_template.format_messages(
462
+ question=question.strip(),
463
+ sql_query=sql_query.strip(),
464
+ )
531
465
 
532
- apx_token_count = _calculate_token_count(llm, messages)
466
+ apx_token_count = _calculate_token_count(llm, prompt)
533
467
  if hasattr(llm, "_llm_type") and "snowflake" in llm._llm_type:
534
- _clean_snowflake_prompt(messages)
468
+ _clean_snowflake_prompt(prompt)
535
469
 
536
- response = _call_llm_with_timeout(llm, messages, timeout=timeout)
470
+ response = _call_llm_with_timeout(llm, prompt, timeout=timeout)
537
471
 
538
472
  # Extract JSON from response content (handle markdown code blocks)
539
473
  content = response.content.strip()
@@ -559,6 +493,87 @@ Please evaluate this result."""
559
493
  }
560
494
 
561
495
 
496
+ @cache_with_version_check
497
+ def _build_sql_generation_context(
498
+ conn_params: dict,
499
+ schema: str,
500
+ concept: str,
501
+ concept_metadata: dict,
502
+ graph_depth: int,
503
+ include_tags: Optional[str],
504
+ exclude_properties: Optional[list],
505
+ db_is_case_sensitive: bool,
506
+ max_limit: int,
507
+ ) -> dict:
508
+ """
509
+ Prepare the complete SQL generation context by gathering all necessary metadata.
510
+
511
+ This includes:
512
+ - Datasource information
513
+ - Concept properties (columns, measures, relationships)
514
+ - Property tags
515
+ - Building column/measure/relationship descriptions
516
+ - Assembling the final context dictionary
517
+
518
+ Returns:
519
+ dict containing all context needed for SQL generation prompts
520
+ """
521
+ datasource_type = _get_active_datasource(conn_params).get('target_type')
522
+
523
+ properties_desc = get_properties_description(conn_params=conn_params)
524
+ relationships_desc = get_relationships_description(conn_params=conn_params)
525
+
526
+ concept_properties_metadata = get_concept_properties(
527
+ schema=schema,
528
+ concept_name=concept,
529
+ conn_params=conn_params,
530
+ properties_desc=properties_desc,
531
+ relationships_desc=relationships_desc,
532
+ graph_depth=graph_depth
533
+ )
534
+ columns = concept_properties_metadata.get('columns', [])
535
+ measures = concept_properties_metadata.get('measures', [])
536
+ relationships = concept_properties_metadata.get('relationships', {})
537
+ tags = get_tags(conn_params=conn_params, include_tags=include_tags).get('property_tags')
538
+
539
+ columns_str = _build_columns_str(columns, columns_tags=tags, exclude=exclude_properties)
540
+ measures_str = _build_columns_str(measures, tags, exclude=exclude_properties)
541
+ rel_prop_str = _build_rel_columns_str(relationships, columns_tags=tags, exclude_properties=exclude_properties)
542
+
543
+ if rel_prop_str:
544
+ measures_str += f"\n{rel_prop_str}"
545
+
546
+ # Determine if relationships have transitive properties
547
+ has_transitive_relationships = any(
548
+ rel.get('is_transitive')
549
+ for rel in relationships.values()
550
+ ) if relationships else False
551
+
552
+ concept_description = f"- Description: {concept_metadata.get('description')}\n" if concept_metadata and concept_metadata.get('description') else ""
553
+ concept_tags = concept_metadata.get('tags') if concept_metadata and concept_metadata.get('tags') else ""
554
+
555
+ cur_date = datetime.now().strftime("%Y-%m-%d")
556
+
557
+ # Build context descriptions
558
+ sensitivity_txt = "- Ensure value comparisons are case-insensitive, e.g., use LOWER(column) = 'value'.\n" if db_is_case_sensitive else ""
559
+ measures_context = f"- {MEASURES_DESCRIPTION}: {measures_str}\n" if measures_str else ""
560
+ transitive_context = f"- {TRANSITIVE_RELATIONSHIP_DESCRIPTION}\n" if has_transitive_relationships else ""
561
+
562
+ return {
563
+ 'cur_date': cur_date,
564
+ 'datasource_type': datasource_type or 'standard sql',
565
+ 'schema': schema,
566
+ 'concept': concept,
567
+ 'concept_description': concept_description or "",
568
+ 'concept_tags': concept_tags or "",
569
+ 'columns_str': columns_str,
570
+ 'measures_context': measures_context,
571
+ 'transitive_context': transitive_context,
572
+ 'sensitivity_txt': sensitivity_txt,
573
+ 'max_limit': max_limit,
574
+ }
575
+
576
+
562
577
  def _generate_sql_with_llm(
563
578
  question: str,
564
579
  llm: LLM,
@@ -639,7 +654,7 @@ def generate_sql(
639
654
  note: Optional[str] = '',
640
655
  db_is_case_sensitive: Optional[bool] = False,
641
656
  graph_depth: Optional[int] = 1,
642
- with_reasoning: Optional[bool] = False,
657
+ enable_reasoning: Optional[bool] = False,
643
658
  reasoning_steps: Optional[int] = 2,
644
659
  debug: Optional[bool] = False,
645
660
  timeout: Optional[int] = None,
@@ -679,50 +694,6 @@ def generate_sql(
679
694
  if not concept:
680
695
  raise Exception("No relevant concept found for the query.")
681
696
 
682
- datasource_type = _get_active_datasource(conn_params).get('target_type')
683
-
684
- properties_desc = get_properties_description(conn_params=conn_params)
685
- relationships_desc = get_relationships_description(conn_params=conn_params)
686
-
687
- concept_properties_metadata = get_concept_properties(schema=schema, concept_name=concept, conn_params=conn_params, properties_desc=properties_desc, relationships_desc=relationships_desc, graph_depth=graph_depth)
688
- columns, measures, relationships = concept_properties_metadata.get('columns', []), concept_properties_metadata.get('measures', []), concept_properties_metadata.get('relationships', {})
689
- tags = get_tags(conn_params=conn_params, include_tags=include_tags).get('property_tags')
690
-
691
- columns_str = _build_columns_str(columns, columns_tags=tags, exclude=exclude_properties)
692
- measures_str = _build_columns_str(measures, tags, exclude=exclude_properties)
693
- rel_prop_str = _build_rel_columns_str(relationships, columns_tags=tags, exclude_properties=exclude_properties)
694
-
695
- if rel_prop_str:
696
- measures_str += f"\n{rel_prop_str}"
697
-
698
-
699
- # Build context descriptions
700
- sensitivity_txt = "- Ensure value comparisons are case-insensitive, e.g., use LOWER(column) = 'value'.\n" if db_is_case_sensitive else ""
701
- measures_context = f"- {MEASURES_DESCRIPTION}: {measures_str}\n" if measures_str else ""
702
- has_transitive_relationships = any(
703
- rel.get('is_transitive')
704
- for rel in relationships.values()
705
- ) if relationships else False
706
- transitive_context = f"- {TRANSITIVE_RELATIONSHIP_DESCRIPTION}\n" if has_transitive_relationships else ""
707
- concept_description = f"- Description: {concept_metadata.get('description')}\n" if concept_metadata and concept_metadata.get('description') else ""
708
- concept_tags = concept_metadata.get('tags') if concept_metadata and concept_metadata.get('tags') else ""
709
- cur_date = datetime.now().strftime("%Y-%m-%d")
710
-
711
- # Build context dict for SQL generation
712
- current_context = {
713
- 'cur_date': cur_date,
714
- 'datasource_type': datasource_type or 'standard sql',
715
- 'schema': schema,
716
- 'concept': concept,
717
- 'concept_description': concept_description or "",
718
- 'concept_tags': concept_tags or "",
719
- 'columns_str': columns_str,
720
- 'measures_context': measures_context,
721
- 'transitive_context': transitive_context,
722
- 'sensitivity_txt': sensitivity_txt,
723
- 'max_limit': max_limit,
724
- }
725
-
726
697
  sql_query = None
727
698
  iteration = 0
728
699
  is_sql_valid = True
@@ -737,7 +708,16 @@ def generate_sql(
737
708
  llm=llm,
738
709
  conn_params=conn_params,
739
710
  generate_sql_prompt=generate_sql_prompt,
740
- current_context=current_context,
711
+ current_context=_build_sql_generation_context(
712
+ conn_params=conn_params,
713
+ schema=schema,
714
+ concept=concept,
715
+ concept_metadata=concept_metadata,
716
+ graph_depth=graph_depth,
717
+ include_tags=include_tags,
718
+ exclude_properties=exclude_properties,
719
+ db_is_case_sensitive=db_is_case_sensitive,
720
+ max_limit=max_limit),
741
721
  note=note + err_txt,
742
722
  should_validate_sql=should_validate_sql,
743
723
  timeout=timeout,
@@ -766,18 +746,19 @@ def generate_sql(
766
746
  raise Exception(error)
767
747
 
768
748
 
769
- if with_reasoning and sql_query is not None:
749
+ if enable_reasoning and sql_query is not None:
770
750
  for step in range(reasoning_steps):
771
751
  try:
772
752
  # Step 1: Evaluate the current SQL
773
- eval_result = _evaluate_sql_with_reasoning(
753
+ eval_result = _evaluate_sql_enable_reasoning(
774
754
  question=question,
775
755
  sql_query=sql_query,
776
756
  llm=llm,
757
+ conn_params=conn_params,
777
758
  timeout=timeout,
778
759
  )
779
760
 
780
- usage_metadata[f'sql_reasoning_step_{step}'] = {
761
+ usage_metadata[f'sql_reasoning_step_{step + 1}'] = {
781
762
  "approximate": eval_result['apx_token_count'],
782
763
  **eval_result['usage_metadata'],
783
764
  }
@@ -791,24 +772,35 @@ def generate_sql(
791
772
  # Step 2: Regenerate SQL with feedback
792
773
  evaluation_note = note + f"\n\nThe previously generated SQL: `{sql_query}` was assessed as '{evaluation.get('assessment')}' because: {evaluation.get('reasoning', '*could not determine cause*')}. Please provide a corrected SQL query that better answers the question: '{question}'."
793
774
 
775
+ # Increase graph depth for 2nd+ reasoning attempts, up to max of 3
776
+ context_graph_depth = min(3, int(graph_depth) + step) if graph_depth < 3 and step > 0 else graph_depth
794
777
  regen_result = _generate_sql_with_llm(
795
778
  question=question,
796
779
  llm=llm,
797
780
  conn_params=conn_params,
798
781
  generate_sql_prompt=generate_sql_prompt,
799
- current_context=current_context,
782
+ current_context=_build_sql_generation_context(
783
+ conn_params=conn_params,
784
+ schema=schema,
785
+ concept=concept,
786
+ concept_metadata=concept_metadata,
787
+ graph_depth=context_graph_depth,
788
+ include_tags=include_tags,
789
+ exclude_properties=exclude_properties,
790
+ db_is_case_sensitive=db_is_case_sensitive,
791
+ max_limit=max_limit),
800
792
  note=evaluation_note,
801
793
  should_validate_sql=should_validate_sql,
802
794
  timeout=timeout,
803
795
  debug=debug,
804
796
  )
805
797
 
806
- usage_metadata[f'generate_sql_reasoning_step_{step}'] = {
798
+ usage_metadata[f'generate_sql_reasoning_step_{step + 1}'] = {
807
799
  "approximate": regen_result['apx_token_count'],
808
800
  **regen_result['usage_metadata'],
809
801
  }
810
802
  if debug and 'p_hash' in regen_result:
811
- usage_metadata[f'generate_sql_reasoning_step_{step}']['p_hash'] = regen_result['p_hash']
803
+ usage_metadata[f'generate_sql_reasoning_step_{step + 1}']['p_hash'] = regen_result['p_hash']
812
804
 
813
805
  sql_query = regen_result['sql']
814
806
  is_sql_valid = regen_result['is_valid']
@@ -1,7 +1,10 @@
1
1
  from typing import Optional, Any
2
2
  import time
3
+ import base64
4
+ import hashlib
3
5
  from pytimbr_api import timbr_http_connector
4
6
  from functools import wraps
7
+ from cryptography.fernet import Fernet
5
8
 
6
9
  from ..config import cache_timeout, ignore_tags, ignore_tags_prefix
7
10
  from .general import to_boolean
@@ -40,6 +43,42 @@ def _serialize_cache_key(*args, **kwargs):
40
43
  return (tuple(serialize(arg) for arg in args), tuple((k, serialize(v)) for k, v in kwargs.items()))
41
44
 
42
45
 
46
def generate_key() -> bytes:
    """Return the fixed Fernet key used to obfuscate prompt text.

    Despite the name, no random key is created: every call derives the same
    key from a hard-coded passcode. The MD5 hex digest of that passcode is 32
    ASCII characters, so it is URL-safe base64 encoded to obtain a key in the
    32-byte, base64-encoded form that ``Fernet`` accepts.

    NOTE(review): because the passcode is embedded in the source, data
    encrypted with this key is obfuscated rather than secret.
    """
    digest_hex = hashlib.md5(b"lucylit2025").hexdigest()
    return base64.urlsafe_b64encode(digest_hex.encode('utf-8'))


# Default key for encrypt_prompt(); derived once when the module is imported.
ENCRYPT_KEY = generate_key()
55
+
56
+
57
def encrypt_prompt(prompt: Any, key: Optional[bytes] = ENCRYPT_KEY) -> str:
    """Serialize a prompt to text and encrypt it with Fernet.

    Args:
        prompt: The prompt to encrypt, rendered to text as follows:
            - A ``str`` is used as-is.
            - For a ``list``, each item with a ``content`` attribute becomes
              ``"<type>: <content>"``. The class name stands in for ``type``
              when the item has none.
            - Any other list item is rendered with ``str()``, and the list
              items are joined with newlines.
            - Any other object is rendered with ``str()``.
        key: URL-safe base64-encoded Fernet key. Defaults to ``ENCRYPT_KEY``.
            Passing ``None`` explicitly also selects ``ENCRYPT_KEY``.

    Returns:
        The Fernet token decoded to a UTF-8 ``str``. The token itself is URL-safe.
    """
    if key is None:
        # The signature advertises Optional; Fernet(None) would raise, so fall back.
        key = ENCRYPT_KEY

    if isinstance(prompt, str):
        text = prompt
    elif isinstance(prompt, list):
        parts = []
        for message in prompt:
            if hasattr(message, "content"):
                # Not every object with `content` has a `type`; avoid AttributeError.
                msg_type = getattr(message, "type", type(message).__name__)
                parts.append(f"{msg_type}: {message.content}")
            else:
                parts.append(str(message))
        text = "\n".join(parts)
    else:
        text = str(prompt)

    token = Fernet(key).encrypt(text.encode('utf-8'))
    return token.decode('utf-8')
74
+
75
+
76
def decrypt_prompt(token: "str | bytes", key: Optional[bytes] = None) -> str:
    """Decrypt a Fernet token produced by encrypt_prompt() back into prompt text.

    Args:
        token: The Fernet token. Accepts the ``str`` returned by
            encrypt_prompt() or raw bytes. A ``str`` is encoded to bytes first,
            because older ``cryptography`` releases accept only bytes here.
        key: The Fernet key used for encryption. When omitted or ``None``,
            ``ENCRYPT_KEY`` is used, matching encrypt_prompt()'s default.

    Returns:
        The decrypted prompt decoded as UTF-8.

    Raises:
        cryptography.fernet.InvalidToken: if the token is malformed or was not
            produced with ``key`` (per the ``cryptography`` Fernet docs).
    """
    if key is None:
        key = ENCRYPT_KEY
    if isinstance(token, str):
        token = token.encode('utf-8')
    return Fernet(key).decrypt(token).decode('utf-8')
80
+
81
+
43
82
  def cache_with_version_check(func):
44
83
  """Decorator to cache function results and invalidate if ontology version changes."""
45
84
 
@@ -16,6 +16,7 @@ def config():
16
16
  "test_prompt": os.environ.get("TEST_PROMPT", "What are the total sales for consumer customers?"),
17
17
  "test_prompt_2": os.environ.get("TEST_PROMPT_2", "Get all customers"),
18
18
  "test_prompt_3": os.environ.get("TEST_PROMPT_3", "Get all products and materials"),
19
+ "test_reasoning_prompt": os.environ.get("TEST_REASONING_PROMPT", "show me 10 orders in 2021 that contain metal"),
19
20
  "verify_ssl": os.environ.get("VERIFY_SSL", "true"),
20
21
  "jwt_timbr_url": os.environ.get("JWT_TIMBR_URL", "https://staging.timbr.ai:443/"),
21
22
  "jwt_timbr_ontology": os.environ.get("JWT_TIMBR_ONTOLOGY", "supply_metrics"),
@@ -0,0 +1,84 @@
1
+ from langchain_timbr import (
2
+ GenerateTimbrSqlChain,
3
+ ValidateTimbrSqlChain,
4
+ ExecuteTimbrQueryChain,
5
+ create_timbr_sql_agent
6
+ )
7
+
8
+ class TestLangchainChainsReasoningIntegration:
9
+ def _assert_reasoning(self, chain, result, usage_metadata_key=None):
10
+ assert 'reasoning_status' in result
11
+ assert result['reasoning_status'] in ['correct', 'partial', 'incorrect']
12
+
13
+ usage_metadata = result.get(usage_metadata_key or chain.usage_metadata_key, {})
14
+ assert 'sql_reasoning_step_1' in usage_metadata
15
+
16
+ # if first reasoning was incorrect, there must be a re-generating sql & a second reasoning step
17
+ if 'sql_reasoning_step_2' in usage_metadata:
18
+ assert 'generate_sql_reasoning_step_1' in usage_metadata
19
+
20
+ # if the final result was incorrect - there must have two re-generation steps
21
+ if result['reasoning_status'] == 'incorrect':
22
+ assert 'generate_sql_reasoning_step_2' in usage_metadata
23
+
24
+ # SKIP THIS TESTS UNTIL API WILL BE UPDATED
25
+
26
+ def skip_test_generate_timbr_sql_chain(self, llm, config):
27
+ chain = GenerateTimbrSqlChain(
28
+ llm=llm,
29
+ url=config["timbr_url"],
30
+ token=config["timbr_token"],
31
+ ontology=config["timbr_ontology"],
32
+ verify_ssl=config["verify_ssl"],
33
+ enable_reasoning=True,
34
+ )
35
+ result = chain.invoke({ "prompt": config["test_reasoning_prompt"] })
36
+ print("GenerateTimbrSqlChain result:", result)
37
+ self._assert_reasoning(chain, result)
38
+
39
+ def skip_test_validate_timbr_sql_chain(self, llm, config):
40
+ chain = ValidateTimbrSqlChain(
41
+ llm=llm,
42
+ url=config["timbr_url"],
43
+ token=config["timbr_token"],
44
+ ontology=config["timbr_ontology"],
45
+ retries=1, # Use a single retry for test speed
46
+ verify_ssl=config["verify_ssl"],
47
+ enable_reasoning=True,
48
+ )
49
+ inputs = {
50
+ "prompt": config["test_reasoning_prompt"],
51
+ "sql": "SELECT * FROM invalid_table", # Intentionally invalid (or test with a valid one if available)
52
+ }
53
+ result = chain.invoke(inputs)
54
+ print("ValidateTimbrSqlChain result:", result)
55
+ self._assert_reasoning(chain, result)
56
+
57
+ def skip_test_execute_timbr_query_chain(self, llm, config):
58
+ chain = ExecuteTimbrQueryChain(
59
+ llm=llm,
60
+ url=config["timbr_url"],
61
+ token=config["timbr_token"],
62
+ ontology=config["timbr_ontology"],
63
+ verify_ssl=config["verify_ssl"],
64
+ enable_reasoning=True,
65
+ )
66
+ inputs = {
67
+ "prompt": config["test_reasoning_prompt"],
68
+ }
69
+ result = chain.invoke(inputs)
70
+ print("ExecuteTimbrQueryChain result:", result)
71
+ self._assert_reasoning(chain, result)
72
+
73
+ def skip_test_create_timbr_sql_agent(self, llm, config):
74
+ agent = create_timbr_sql_agent(
75
+ llm=llm,
76
+ url=config["timbr_url"],
77
+ token=config["timbr_token"],
78
+ ontology=config["timbr_ontology"],
79
+ verify_ssl=config["verify_ssl"],
80
+ enable_reasoning=True,
81
+ )
82
+ result = agent.invoke(config["test_reasoning_prompt"])
83
+ print("Timbr SQL Agent result:", result)
84
+ self._assert_reasoning(agent, result, usage_metadata_key="usage_metadata")
@@ -5,9 +5,10 @@ from langchain_timbr import (
5
5
  GenerateTimbrSqlChain,
6
6
  ValidateTimbrSqlChain,
7
7
  ExecuteTimbrQueryChain,
8
- GenerateAnswerChain
8
+ GenerateAnswerChain,
9
+ generate_key,
10
+ decrypt_prompt,
9
11
  )
10
- from langchain_timbr.utils import timbr_llm_utils
11
12
 
12
13
 
13
14
  class TestIdentifyTimbrConceptChain:
@@ -147,7 +148,7 @@ class TestGenerateTimbrSqlChain:
147
148
  print("GenerateTimbrSqlChain result:", result)
148
149
  assert "sql" in result and result["sql"], "SQL should be generated"
149
150
  assert "concept" in result and result["concept"] == "customer", "Concept customer should be returned"
150
- prompt = timbr_llm_utils.decrypt_prompt(result["generate_sql_usage_metadata"]["generate_sql"]["p_hash"], timbr_llm_utils.generate_key())
151
+ prompt = decrypt_prompt(result["generate_sql_usage_metadata"]["generate_sql"]["p_hash"], generate_key())
151
152
  assert "customer related info" in prompt, "Customer description should be in prompt"
152
153
  assert "concat of first and last name" in prompt, "Customer name description should be in prompt"
153
154
  assert "continent name" in prompt, "Order market description should be in prompt"
@@ -170,7 +171,7 @@ class TestGenerateTimbrSqlChain:
170
171
  print("GenerateTimbrSqlChain result:", result)
171
172
  assert "sql" in result and result["sql"], "SQL should be generated"
172
173
  assert "concept" in result and result["concept"] == "customer", "Concept customer should be returned"
173
- prompt = timbr_llm_utils.decrypt_prompt(result["generate_sql_usage_metadata"]["generate_sql"]["p_hash"], timbr_llm_utils.generate_key())
174
+ prompt = decrypt_prompt(result["generate_sql_usage_metadata"]["generate_sql"]["p_hash"], generate_key())
174
175
  assert "customer related info" in prompt, "Customer description should be in prompt"
175
176
  assert "concat of first and last name" in prompt, "Customer name description should be in prompt"
176
177
  assert "continent name" in prompt, "Order market description should be in prompt"
@@ -193,7 +194,7 @@ class TestGenerateTimbrSqlChain:
193
194
  print("GenerateTimbrSqlChain result:", result)
194
195
  assert "sql" in result and result["sql"], "SQL should be generated"
195
196
  assert "concept" in result and result["concept"] == "customer_cube", "Concept customer_cube should be returned"
196
- prompt = timbr_llm_utils.decrypt_prompt(result["generate_sql_usage_metadata"]["generate_sql"]["p_hash"], timbr_llm_utils.generate_key())
197
+ prompt = decrypt_prompt(result["generate_sql_usage_metadata"]["generate_sql"]["p_hash"], generate_key())
197
198
  assert "customer cube related info" in prompt, "Customer description should be in prompt"
198
199
  assert "concat of first and last name" in prompt, "Customer name description should be in prompt"
199
200
  assert "continent name" in prompt, "Order market description should be in prompt"
@@ -216,7 +217,7 @@ class TestGenerateTimbrSqlChain:
216
217
  assert "sql" in result and result["sql"], "SQL should be generated"
217
218
  assert "concept" in result and result["concept"] == "product", "Concept product should be returned"
218
219
  assert chain.usage_metadata_key in result, "Chain should return 'usage_metadata'"
219
- prompt = timbr_llm_utils.decrypt_prompt(result["generate_sql_usage_metadata"]["generate_sql"]["p_hash"], timbr_llm_utils.generate_key())
220
+ prompt = decrypt_prompt(result["generate_sql_usage_metadata"]["generate_sql"]["p_hash"], generate_key())
220
221
  assert "commodity" in prompt, "Product tag value of synonym should be in prompt"
221
222
  assert "synonym" in prompt, "Product tag synonym should be in prompt"
222
223
  assert "length" in prompt, "Material tag should be in prompt"
@@ -240,7 +241,7 @@ class TestGenerateTimbrSqlChain:
240
241
  assert "sql" in result and result["sql"], "SQL should be generated"
241
242
  assert "concept" in result and result["concept"] == "product", "Concept product should be returned"
242
243
  assert chain.usage_metadata_key in result, "Chain should return 'usage_metadata'"
243
- prompt = timbr_llm_utils.decrypt_prompt(result["generate_sql_usage_metadata"]["generate_sql"]["p_hash"], timbr_llm_utils.generate_key())
244
+ prompt = decrypt_prompt(result["generate_sql_usage_metadata"]["generate_sql"]["p_hash"], generate_key())
244
245
  assert "commodity" in prompt, "Product tag value of synonym should be in prompt"
245
246
  assert "synonym" in prompt, "Product tag synonym should be in prompt"
246
247
  assert "length" in prompt, "Material tag should be in prompt"
@@ -265,7 +266,7 @@ class TestGenerateTimbrSqlChain:
265
266
  assert "sql" in result and result["sql"], "SQL should be generated"
266
267
  assert "concept" in result and result["concept"] == "product_cube", "Concept product_cube should be returned"
267
268
  assert chain.usage_metadata_key in result, "Chain should return 'usage_metadata'"
268
- prompt = timbr_llm_utils.decrypt_prompt(result["generate_sql_usage_metadata"]["generate_sql"]["p_hash"], timbr_llm_utils.generate_key())
269
+ prompt = decrypt_prompt(result["generate_sql_usage_metadata"]["generate_sql"]["p_hash"], generate_key())
269
270
  assert "commodity cube" in prompt, "Product tag value of synonym should be in prompt"
270
271
  assert "synonym" in prompt, "Product tag synonym should be in prompt"
271
272
  assert "length" in prompt, "Material tag should be in prompt"
File without changes