langchain-timbr 1.5.2__tar.gz → 1.5.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/.github/workflows/publish.yml +1 -1
  2. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/PKG-INFO +23 -5
  3. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/README.md +9 -1
  4. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/pyproject.toml +13 -4
  5. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/requirements.txt +5 -2
  6. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/_version.py +2 -2
  7. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/llm_wrapper/llm_wrapper.py +23 -3
  8. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/utils/general.py +2 -1
  9. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/utils/prompt_service.py +42 -61
  10. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/utils/temperature_supported_models.json +35 -27
  11. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/utils/timbr_llm_utils.py +43 -25
  12. langchain_timbr-1.5.4/tests/integration/test_azure_databricks_provider.py +42 -0
  13. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/.github/dependabot.yml +0 -0
  14. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/.github/pull_request_template.md +0 -0
  15. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/.github/workflows/_codespell.yml +0 -0
  16. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/.github/workflows/_fossa.yml +0 -0
  17. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/.github/workflows/install-dependencies-and-run-tests.yml +0 -0
  18. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/.gitignore +0 -0
  19. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/LICENSE +0 -0
  20. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/pytest.ini +0 -0
  21. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/__init__.py +0 -0
  22. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/config.py +0 -0
  23. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/langchain/__init__.py +0 -0
  24. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/langchain/execute_timbr_query_chain.py +0 -0
  25. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/langchain/generate_answer_chain.py +0 -0
  26. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/langchain/generate_timbr_sql_chain.py +0 -0
  27. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/langchain/identify_concept_chain.py +0 -0
  28. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/langchain/timbr_sql_agent.py +0 -0
  29. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/langchain/validate_timbr_sql_chain.py +0 -0
  30. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/langgraph/__init__.py +0 -0
  31. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/langgraph/execute_timbr_query_node.py +0 -0
  32. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/langgraph/generate_response_node.py +0 -0
  33. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/langgraph/generate_timbr_sql_node.py +0 -0
  34. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/langgraph/identify_concept_node.py +0 -0
  35. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/langgraph/validate_timbr_query_node.py +0 -0
  36. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/llm_wrapper/timbr_llm_wrapper.py +0 -0
  37. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/timbr_llm_connector.py +0 -0
  38. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/src/langchain_timbr/utils/timbr_utils.py +0 -0
  39. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/tests/README.md +0 -0
  40. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/tests/conftest.py +0 -0
  41. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/tests/integration/test_agent_integration.py +0 -0
  42. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/tests/integration/test_azure_openai_model.py +0 -0
  43. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/tests/integration/test_chain_pipeline.py +0 -0
  44. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/tests/integration/test_jwt_token.py +0 -0
  45. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/tests/integration/test_langchain_chains.py +0 -0
  46. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/tests/integration/test_langgraph_nodes.py +0 -0
  47. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/tests/integration/test_timeout_functionality.py +0 -0
  48. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/tests/standard/conftest.py +0 -0
  49. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/tests/standard/test_chain_documentation.py +0 -0
  50. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/tests/standard/test_standard_chain_requirements.py +0 -0
  51. {langchain_timbr-1.5.2 → langchain_timbr-1.5.4}/tests/standard/test_unit_tests.py +0 -0
@@ -16,7 +16,7 @@ jobs:
  with: { fetch-depth: 0 } # IMPORTANT: tags available to hatch-vcs
  - uses: actions/setup-python@v5
  with: { python-version: "3.12" }
- - run: python -m pip install --upgrade build twine
+ - run: python -m pip install --upgrade build
  - run: python -m build # hatch-vcs injects version from tag
  - name: Publish to PyPI
  uses: pypa/gh-action-pypi-publish@release/v1
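(Dropping `twine` here is consistent with the rest of the workflow: the `pypa/gh-action-pypi-publish` action performs the upload itself, so only `build` is needed to produce the distribution.)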
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: langchain-timbr
- Version: 1.5.2
+ Version: 1.5.4
  Summary: LangChain & LangGraph extensions that parse LLM prompts into Timbr semantic SQL and execute them.
  Project-URL: Homepage, https://github.com/WPSemantix/langchain-timbr
  Project-URL: Documentation, https://docs.timbr.ai/doc/docs/integration/langchain-sdk/
@@ -28,15 +28,18 @@ Requires-Dist: langgraph>=0.3.20
  Requires-Dist: pydantic==2.10.4
  Requires-Dist: pytimbr-api>=2.0.0
  Requires-Dist: tiktoken==0.8.0
- Requires-Dist: transformers>=4.51.3
+ Requires-Dist: transformers>=4.53
  Provides-Extra: all
  Requires-Dist: anthropic==0.42.0; extra == 'all'
+ Requires-Dist: databricks-langchain==0.3.0; (python_version < '3.10') and extra == 'all'
+ Requires-Dist: databricks-langchain==0.7.1; (python_version >= '3.10') and extra == 'all'
+ Requires-Dist: databricks-sdk==0.64.0; extra == 'all'
  Requires-Dist: google-generativeai==0.8.4; extra == 'all'
  Requires-Dist: langchain-anthropic>=0.3.1; extra == 'all'
  Requires-Dist: langchain-google-genai>=2.0.9; extra == 'all'
  Requires-Dist: langchain-openai>=0.3.16; extra == 'all'
  Requires-Dist: langchain-tests>=0.3.20; extra == 'all'
- Requires-Dist: openai==1.77.0; extra == 'all'
+ Requires-Dist: openai>=1.77.0; extra == 'all'
  Requires-Dist: pyarrow<19.0.0; extra == 'all'
  Requires-Dist: pytest==8.3.4; extra == 'all'
  Requires-Dist: snowflake-snowpark-python>=1.6.0; extra == 'all'
@@ -45,6 +48,13 @@ Requires-Dist: uvicorn==0.34.0; extra == 'all'
  Provides-Extra: anthropic
  Requires-Dist: anthropic==0.42.0; extra == 'anthropic'
  Requires-Dist: langchain-anthropic>=0.3.1; extra == 'anthropic'
+ Provides-Extra: azure-openai
+ Requires-Dist: langchain-openai>=0.3.16; extra == 'azure-openai'
+ Requires-Dist: openai>=1.77.0; extra == 'azure-openai'
+ Provides-Extra: databricks
+ Requires-Dist: databricks-langchain==0.3.0; (python_version < '3.10') and extra == 'databricks'
+ Requires-Dist: databricks-langchain==0.7.1; (python_version >= '3.10') and extra == 'databricks'
+ Requires-Dist: databricks-sdk==0.64.0; extra == 'databricks'
  Provides-Extra: dev
  Requires-Dist: langchain-tests>=0.3.20; extra == 'dev'
  Requires-Dist: pyarrow<19.0.0; extra == 'dev'
@@ -55,7 +65,7 @@ Requires-Dist: google-generativeai==0.8.4; extra == 'google'
  Requires-Dist: langchain-google-genai>=2.0.9; extra == 'google'
  Provides-Extra: openai
  Requires-Dist: langchain-openai>=0.3.16; extra == 'openai'
- Requires-Dist: openai==1.77.0; extra == 'openai'
+ Requires-Dist: openai>=1.77.0; extra == 'openai'
  Provides-Extra: snowflake
  Requires-Dist: snowflake-snowpark-python>=1.6.0; extra == 'snowflake'
  Requires-Dist: snowflake>=0.8.0; extra == 'snowflake'
@@ -80,15 +90,23 @@ Timbr LangChain LLM SDK is a Python SDK that extends LangChain and LangGraph wit
 
  ## Dependencies
  - Access to a timbr-server
- - Python from 3.9.13 or newer
+ - Python 3.9.13 or newer
 
  ## Installation
 
  ### Using pip
+
  ```bash
  python -m pip install langchain-timbr
  ```
 
+ ### Install with selected LLM providers
+ #### One of: openai, anthropic, google, azure_openai, snowflake, databricks (or 'all')
+
+ ```bash
+ python -m pip install 'langchain-timbr[<your selected providers, separated by comma w/o space]'
+ ```
+
  ### Using pip from github
  ```bash
  pip install git+https://github.com/WPSemantix/langchain-timbr
@@ -17,15 +17,23 @@ Timbr LangChain LLM SDK is a Python SDK that extends LangChain and LangGraph wit
 
  ## Dependencies
  - Access to a timbr-server
- - Python from 3.9.13 or newer
+ - Python 3.9.13 or newer
 
  ## Installation
 
  ### Using pip
+
  ```bash
  python -m pip install langchain-timbr
  ```
 
+ ### Install with selected LLM providers
+ #### One of: openai, anthropic, google, azure_openai, snowflake, databricks (or 'all')
+
+ ```bash
+ python -m pip install 'langchain-timbr[<your selected providers, separated by comma w/o space]'
+ ```
+
  ### Using pip from github
  ```bash
  pip install git+https://github.com/WPSemantix/langchain-timbr
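For example, assuming you want the OpenAI and Databricks extras, the placeholder above (whose closing `>` is missing in the snippet as published) expands to:

```bash
python -m pip install 'langchain-timbr[openai,databricks]'
```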
@@ -32,15 +32,21 @@ dependencies = [
  "pydantic==2.10.4",
  "pytimbr-api>=2.0.0",
  "tiktoken==0.8.0",
- "transformers>=4.51.3"
+ "transformers>=4.53"
  ]
 
  [project.optional-dependencies]
  # LLM providers
- openai = ["openai==1.77.0", "langchain-openai>=0.3.16"]
+ openai = ["openai>=1.77.0", "langchain-openai>=0.3.16"]
+ azure_openai = ["openai>=1.77.0", "langchain-openai>=0.3.16"]
  anthropic = ["anthropic==0.42.0", "langchain-anthropic>=0.3.1"]
  google = ["langchain-google-genai>=2.0.9", "google-generativeai==0.8.4"]
  snowflake = ["snowflake>=0.8.0", "snowflake-snowpark-python>=1.6.0"]
+ databricks = [
+ "databricks-langchain==0.3.0; python_version < '3.10'",
+ "databricks-langchain==0.7.1; python_version >= '3.10'",
+ "databricks-sdk==0.64.0"
+ ]
 
  # Development and testing
  dev = [
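(The environment markers split `databricks-langchain` by interpreter: 0.3.0 on Python 3.9, 0.7.1 on 3.10 and later, presumably because newer `databricks-langchain` releases require Python 3.10+.)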
@@ -55,7 +61,7 @@ all = [
  "anthropic==0.42.0",
  "google-generativeai==0.8.4",
  "langchain-anthropic>=0.3.1",
- "openai==1.77.0",
+ "openai>=1.77.0",
  "langchain-openai>=0.3.16",
  "langchain-google-genai>=2.0.9",
  "snowflake>=0.8.0",
@@ -63,7 +69,10 @@ all = [
  "pytest==8.3.4",
  "langchain-tests>=0.3.20",
  "pyarrow<19.0.0",
- "uvicorn==0.34.0"
+ "uvicorn==0.34.0",
+ "databricks-langchain==0.3.0; python_version < '3.10'",
+ "databricks-langchain==0.7.1; python_version >= '3.10'",
+ "databricks-sdk==0.64.0"
  ]
 
  [project.urls]
@@ -1,5 +1,8 @@
  anthropic==0.42.0
  cryptography>=44.0.3
+ databricks-langchain==0.3.0; python_version < '3.10'
+ databricks-langchain==0.7.1; python_version >= '3.10'
+ databricks-sdk==0.64.0
  google-generativeai==0.8.4
  langchain>=0.3.25
  langchain-anthropic>=0.3.1
@@ -9,7 +12,7 @@ langchain-google-genai>=2.0.9
  langchain-openai>=0.3.16
  langchain-tests>=0.3.20
  langgraph>=0.3.20
- openai==1.77.0
+ openai>=1.77.0
  pyarrow<19.0.0
  pydantic==2.10.4
  pytest==8.3.4
@@ -17,5 +20,5 @@ pytimbr-api>=2.0.0
  snowflake>=0.8.0
  snowflake-snowpark-python>=1.6.0
  tiktoken==0.8.0
- transformers>=4.51.3
+ transformers>=4.53
  uvicorn==0.34.0
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID
 
- __version__ = version = '1.5.2'
- __version_tuple__ = version_tuple = (1, 5, 2)
+ __version__ = version = '1.5.4'
+ __version_tuple__ = version_tuple = (1, 5, 4)
 
  __commit_id__ = commit_id = None
@@ -12,13 +12,14 @@ class LlmTypes(Enum):
  Google = 'chat-google-generative-ai'
  AzureOpenAI = 'azure-openai-chat'
  Snowflake = 'snowflake-cortex'
+ Databricks = 'chat-databricks'
  Timbr = 'timbr'
 
 
  class LlmWrapper(LLM):
  """
  LlmWrapper is a unified interface for connecting to various Large Language Model (LLM) providers
- (OpenAI, Anthropic, Google, Azure OpenAI, Snowflake Cortex, etc.) using LangChain. It abstracts
+ (OpenAI, Anthropic, Google, Azure OpenAI, Snowflake Cortex, Databricks, etc.) using LangChain. It abstracts
  the initialization and connection logic for each provider, allowing you to switch between them
  with a consistent API.
  """
@@ -95,12 +96,14 @@ class LlmWrapper(LLM):
  **params,
  )
  elif is_llm_type(llm_type, LlmTypes.Snowflake):
- from langchain_community.chat_models import ChatSnowflakeCortex
+ from langchain_community.chat_models import ChatSnowflakeCortex
  llm_model = model or "openai-gpt-4.1"
  params = self._add_temperature(LlmTypes.Snowflake.name, llm_model, **llm_params)
+ snowflake_password = params.pop('snowflake_api_key', params.pop('snowflake_password', api_key))
 
  return ChatSnowflakeCortex(
  model=llm_model,
+ snowflake_password=snowflake_password,
  **params,
  )
  elif is_llm_type(llm_type, LlmTypes.AzureOpenAI):
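A note on the new `snowflake_password` line: Python evaluates the inner `params.pop('snowflake_password', api_key)` before the outer pop, so both keys are always removed from `params` (which keeps either alias from leaking into `**params`), with `snowflake_api_key` taking precedence when both are set.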
@@ -116,6 +119,19 @@
  openai_api_version=azure_api_version,
  **params,
  )
+ elif is_llm_type(llm_type, LlmTypes.Databricks):
+ from databricks.sdk import WorkspaceClient
+ from databricks_langchain import ChatDatabricks
+ llm_model = model or "databricks-claude-sonnet-4"
+ params = self._add_temperature(LlmTypes.Databricks.name, llm_model, **llm_params)
+
+ host = params.pop('databricks_host', params.pop('host', None))
+ w = WorkspaceClient(host=host, token=api_key)
+ return ChatDatabricks(
+ endpoint=llm_model,
+ workspace_client=w, # Using authenticated client
+ **params,
+ )
  else:
  raise ValueError(f"Unsupported LLM type: {llm_type}")
@@ -163,12 +179,16 @@ class LlmWrapper(LLM):
  "llama3.1-70b",
  "llama3.1-405b"
  ]
+ elif is_llm_type(self._llm_type, LlmTypes.Databricks):
+ w = self.client.workspace_client
+ models = [ep.name for ep in w.serving_endpoints.list()]
+
  # elif self._is_llm_type(self._llm_type, LlmTypes.Timbr):
 
  except Exception as e:
  models = []
 
- return models
+ return sorted(models)
 
 
  def _call(self, prompt, **kwargs):
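The listing call the new branch relies on comes from `databricks-sdk`; a minimal standalone sketch with placeholder credentials:

```python
from databricks.sdk import WorkspaceClient

# Placeholders; inside LlmWrapper these come from api_key / databricks_host.
w = WorkspaceClient(host="https://<workspace>.azuredatabricks.net", token="<pat>")

# Mirrors the model-listing code above: serving endpoint names
# double as model names for ChatDatabricks.
models = sorted(ep.name for ep in w.serving_endpoints.list())
```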
@@ -44,7 +44,8 @@ def is_llm_type(llm_type, enum_value):
  llm_type_lower == enum_name_lower or
  llm_type_lower == enum_value_lower or
  llm_type_lower.startswith(enum_name_lower) or # Usecase for snowflake which its type is the provider name + the model name
- llm_type_lower.startswith(enum_value_lower)
+ llm_type_lower.startswith(enum_value_lower) or
+ llm_type_lower in enum_value_lower # Check if the enum value includes the llm type - when providing partial name
  )
 
  return False
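A quick illustration of what the new containment clause buys, assuming `LlmTypes` from `llm_wrapper.py` and the module path implied by this package's layout:

```python
from langchain_timbr import LlmTypes
from langchain_timbr.utils.general import is_llm_type

# Matched by the existing exact-value check:
assert is_llm_type("azure-openai-chat", LlmTypes.AzureOpenAI)

# Partial name: neither equality nor startswith matches (the enum name
# lowers to "azureopenai", without a hyphen), but the new
# `llm_type_lower in enum_value_lower` clause does, since
# "azure-openai" is a substring of "azure-openai-chat".
assert is_llm_type("azure-openai", LlmTypes.AzureOpenAI)
```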
@@ -1,32 +1,45 @@
  import requests
  from typing import Dict, Any, Optional, List, Union
- from langchain.schema import SystemMessage, HumanMessage
  from langchain.prompts.chat import ChatPromptTemplate
  from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate
  import json
  import logging
-
- from ..config import url, token as default_token, is_jwt, jwt_tenant_id as default_jwt_tenant_id, llm_timeout
+ from ..config import url as default_url, token as default_token, is_jwt, jwt_tenant_id as default_jwt_tenant_id, llm_timeout
 
  logger = logging.getLogger(__name__)
 
  # Global template cache shared across all PromptService instances
- _global_template_cache = {}
+ _global_template_cache: dict[Any, Any] = {}
 
  class PromptService:
  def __init__(
  self,
- base_url: Optional[str] = url,
- token: Optional[str] = default_token,
- is_jwt: Optional[bool] = is_jwt,
- jwt_tenant_id: Optional[str] = default_jwt_tenant_id,
- timeout: Optional[int] = llm_timeout,
+ conn_params: Optional[Dict[str, Any]] = None,
+ **kwargs
  ):
- self.base_url = base_url.rstrip('/')
- self.token = token
- self.is_jwt = is_jwt
- self.jwt_tenant_id = jwt_tenant_id
- self.timeout = timeout
+ """
+ Initialize PromptService with connection parameters.
+
+ Args:
+ conn_params: Dictionary containing connection parameters
+ **kwargs: Additional parameters for backward compatibility
+ """
+ # Extract relevant parameters from conn_params or use defaults
+ if conn_params:
+ url_value = conn_params.get('url') or default_url
+ self.base_url = url_value.rstrip('/') if url_value else ''
+ self.token = conn_params.get('token') or default_token
+ self.is_jwt = conn_params.get('is_jwt', is_jwt)
+ self.jwt_tenant_id = conn_params.get('jwt_tenant_id') or default_jwt_tenant_id
+ self.timeout = conn_params.get('timeout') or llm_timeout
+ else:
+ # Fallback to kwargs for backward compatibility
+ url_value = kwargs.get('url') or default_url
+ self.base_url = url_value.rstrip('/') if url_value else ''
+ self.token = str(kwargs.get('token') or default_token)
+ self.is_jwt = kwargs.get('is_jwt', is_jwt)
+ self.jwt_tenant_id = kwargs.get('jwt_tenant_id') or default_jwt_tenant_id
+ self.timeout = kwargs.get('timeout') or llm_timeout
 
 
  def _get_headers(self) -> Dict[str, str]:
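Under the new signature, callers hand over a single connection dict; a minimal sketch with placeholder values:

```python
# Keys mirror what __init__ reads from conn_params; values are placeholders.
conn_params = {
    "url": "https://timbr.example.com/",  # trailing slash stripped via rstrip('/')
    "token": "<timbr-token>",
    "is_jwt": False,
    "jwt_tenant_id": None,
    "timeout": 60,
}
service = PromptService(conn_params=conn_params)
```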
@@ -220,99 +233,67 @@ class PromptTemplateWrapper:
 
  # Individual prompt template getter functions
  def get_determine_concept_prompt_template(
- token: Optional[str] = None,
- is_jwt: Optional[bool] = None,
- jwt_tenant_id: Optional[str] = None
+ conn_params: Optional[dict] = None
  ) -> PromptTemplateWrapper:
  """
  Get determine concept prompt template wrapper
 
  Args:
- token: Authentication token
- is_jwt: Whether the token is a JWT
- jwt_tenant_id: JWT tenant ID
-
+ conn_params: Connection parameters including url, token, is_jwt, and jwt_tenant_id
+
  Returns:
  PromptTemplateWrapper for determine concept
  """
- prompt_service = PromptService(
- token=token,
- is_jwt=is_jwt,
- jwt_tenant_id=jwt_tenant_id
- )
+ prompt_service = PromptService(conn_params=conn_params)
  return PromptTemplateWrapper(prompt_service, "get_identify_concept_template")
 
 
  def get_generate_sql_prompt_template(
- token: Optional[str] = None,
- is_jwt: Optional[bool] = None,
- jwt_tenant_id: Optional[str] = None
+ conn_params: Optional[dict] = None
  ) -> PromptTemplateWrapper:
  """
  Get generate SQL prompt template wrapper
 
  Args:
- token: Authentication token
- is_jwt: Whether the token is a JWT
- jwt_tenant_id: JWT tenant ID
+ conn_params: Connection parameters including url, token, is_jwt, and jwt_tenant_id
 
  Returns:
  PromptTemplateWrapper for generate SQL
  """
- prompt_service = PromptService(
- token=token,
- is_jwt=is_jwt,
- jwt_tenant_id=jwt_tenant_id
- )
+ prompt_service = PromptService(conn_params=conn_params)
  return PromptTemplateWrapper(prompt_service, "get_generate_sql_template")
 
 
  def get_qa_prompt_template(
- token: Optional[str] = None,
- is_jwt: Optional[bool] = None,
- jwt_tenant_id: Optional[str] = None
+ conn_params: Optional[dict] = None
  ) -> PromptTemplateWrapper:
  """
  Get QA prompt template wrapper
 
  Args:
- token: Authentication token
- is_jwt: Whether the token is a JWT
- jwt_tenant_id: JWT tenant ID
-
+ conn_params: Connection parameters including url, token, is_jwt, and jwt_tenant_id
+
  Returns:
  PromptTemplateWrapper for QA
  """
- prompt_service = PromptService(
- token=token,
- is_jwt=is_jwt,
- jwt_tenant_id=jwt_tenant_id
- )
+ prompt_service = PromptService(conn_params=conn_params)
  return PromptTemplateWrapper(prompt_service, "get_generate_answer_template")
 
 
  # Global prompt service instance (updated signature)
  def get_prompt_service(
- token: str = None,
- is_jwt: bool = None,
- jwt_tenant_id: str = None
+ conn_params: Optional[dict] = None
  ) -> PromptService:
  """
  Get or create a prompt service instance
 
  Args:
- token: Authentication token (API key or JWT token)
- is_jwt: Whether the token is a JWT
- jwt_tenant_id: JWT tenant ID
-
+ conn_params: Connection parameters including url, token, is_jwt, and jwt_tenant_id
+
  Returns:
  PromptService instance
  """
- return PromptService(
- token=token,
- is_jwt=is_jwt,
- jwt_tenant_id=jwt_tenant_id
- )
+ return PromptService(conn_params=conn_params)
 
 
  # Global cache management functions
@@ -2,61 +2,69 @@
  "OpenAI": [
  "gpt-4",
  "gpt-4-turbo",
- "gpt-4o"
+ "gpt-4o",
+ "gpt-5",
+ "gpt-5-chat-latest",
+ "gpt-5-mini"
  ],
  "Anthropic": [
- "claude-opus-4-20250514",
- "claude-sonnet-4-20250514",
- "claude-3-7-sonnet-20250219",
- "claude-3-5-sonnet-20241022",
  "claude-3-5-haiku-20241022",
  "claude-3-5-sonnet-20240620",
+ "claude-3-5-sonnet-20241022",
+ "claude-3-7-sonnet-20250219",
  "claude-3-haiku-20240307",
  "claude-3-opus-20240229",
- "claude-3-sonnet-20240229",
- "claude-2.1",
- "claude-2.0"
+ "claude-opus-4-20250514",
+ "claude-sonnet-4-20250514"
  ],
  "Google": [
- "gemini-1.5-flash-latest",
  "gemini-1.5-flash",
  "gemini-1.5-flash-002",
  "gemini-1.5-flash-8b",
  "gemini-1.5-flash-8b-001",
  "gemini-1.5-flash-8b-latest",
- "gemini-2.5-flash-preview-04-17",
- "gemini-2.5-flash-preview-05-20",
- "gemini-2.5-flash",
- "gemini-2.5-flash-preview-04-17-thinking",
- "gemini-2.5-flash-lite-preview-06-17",
- "gemini-2.5-pro",
- "gemini-2.0-flash-exp",
+ "gemini-1.5-flash-latest",
  "gemini-2.0-flash",
  "gemini-2.0-flash-001",
+ "gemini-2.0-flash-exp",
  "gemini-2.0-flash-exp-image-generation",
- "gemini-2.0-flash-lite-001",
  "gemini-2.0-flash-lite",
- "gemini-2.0-flash-lite-preview-02-05",
+ "gemini-2.0-flash-lite-001",
  "gemini-2.0-flash-lite-preview",
- "gemini-2.0-flash-thinking-exp-01-21",
+ "gemini-2.0-flash-lite-preview-02-05",
  "gemini-2.0-flash-thinking-exp",
+ "gemini-2.0-flash-thinking-exp-01-21",
  "gemini-2.0-flash-thinking-exp-1219",
- "learnlm-2.0-flash-experimental",
- "gemma-3-1b-it",
- "gemma-3-4b-it",
+ "gemini-2.5-flash",
+ "gemini-2.5-flash-lite",
+ "gemini-2.5-flash-lite-preview-06-17",
+ "gemini-2.5-flash-preview-05-20",
  "gemma-3-12b-it",
+ "gemma-3-1b-it",
  "gemma-3-27b-it",
- "gemma-3n-e4b-it",
- "gemma-3n-e2b-it"
+ "gemma-3-4b-it",
+ "gemma-3n-e2b-it",
+ "gemma-3n-e4b-it"
  ],
  "AzureOpenAI": [
  "gpt-4o"
  ],
  "Snowflake": [
- "openai-gpt-4.1",
- "mistral-large2",
+ "llama3.1-405b",
  "llama3.1-70b",
- "llama3.1-405b"
+ "mistral-large2",
+ "openai-gpt-4.1"
+ ],
+ "Databricks": [
+ "databricks-claude-3-7-sonnet",
+ "databricks-claude-sonnet-4",
+ "databricks-gemma-3-12b",
+ "databricks-gpt-oss-120b",
+ "databricks-gpt-oss-20b",
+ "databricks-llama-4-maverick",
+ "databricks-meta-llama-3-1-405b-instruct",
+ "databricks-meta-llama-3-1-8b-instruct",
+ "databricks-meta-llama-3-3-70b-instruct"
  ],
  "Timbr": []
  }
@@ -165,6 +165,40 @@ def _calculate_token_count(llm: LLM, prompt: str) -> int:
  return token_count
 
 
+ def _get_response_text(response: Any) -> str:
+ if hasattr(response, "content"):
+ response_text = response.content
+
+ # Handle Databricks gpt-oss type of responses (having list of dicts with type + summary for reasoning or type + text for result)
+ if isinstance(response_text, list):
+ response_text = next(filter(lambda x: x.get('type') == 'text', response.content), None)
+ if isinstance(response_text, dict):
+ response_text = response_text.get('text', '')
+ elif isinstance(response, str):
+ response_text = response
+ else:
+ raise ValueError("Unexpected response format from LLM.")
+
+ return response_text
+
+ def _extract_usage_metadata(response: Any) -> dict:
+ usage_metadata = response.response_metadata
+
+ if usage_metadata and 'usage' in usage_metadata:
+ usage_metadata = usage_metadata['usage']
+
+ if not usage_metadata and 'usage_metadata' in response:
+ usage_metadata = response.usage_metadata
+ if usage_metadata and 'usage' in usage_metadata:
+ usage_metadata = usage_metadata['usage']
+
+ if not usage_metadata and 'usage' in response:
+ usage_metadata = response.usage
+ if usage_metadata and 'usage' in usage_metadata:
+ usage_metadata = usage_metadata['usage']
+
+ return usage_metadata
+
  def determine_concept(
  question: str,
  llm: LLM,
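For orientation, the list-of-blocks content that `_get_response_text` flattens looks roughly like this (hypothetical values):

```python
# Hypothetical Databricks gpt-oss style content: a list of typed blocks.
content = [
    {"type": "reasoning", "summary": "reasoning summary ..."},
    {"type": "text", "text": "SELECT ... FROM dtimbr.sales"},
]
# _get_response_text keeps the first block whose type is 'text'
# and returns its 'text' field: "SELECT ... FROM dtimbr.sales".
```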
@@ -178,7 +212,7 @@ def determine_concept(
  note: Optional[str] = '',
  debug: Optional[bool] = False,
  timeout: Optional[int] = None,
- ) -> dict[str, any]:
+ ) -> dict[str, Any]:
  usage_metadata = {}
  determined_concept_name = None
  schema = 'dtimbr'
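(This fix is more than cosmetic: lowercase `any` is the builtin function, not a type, so `dict[str, any]` was a misleading annotation; `typing.Any` is the intended one.)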
@@ -187,7 +221,7 @@ def determine_concept(
  if timeout is None:
  timeout = llm_timeout
 
- determine_concept_prompt = get_determine_concept_prompt_template(conn_params["token"], conn_params["is_jwt"], conn_params["jwt_tenant_id"])
+ determine_concept_prompt = get_determine_concept_prompt_template(conn_params)
  tags = get_tags(conn_params=conn_params, include_tags=include_tags)
  concepts = get_concepts(
  conn_params=conn_params,
@@ -253,20 +287,12 @@ def determine_concept(
  continue
  usage_metadata['determine_concept'] = {
  "approximate": apx_token_count,
- # **(response.usage_metadata or response.usage or {}),
- **(response.usage_metadata or {}),
+ **_extract_usage_metadata(response),
  }
  if debug:
  usage_metadata['determine_concept']["p_hash"] = encrypt_prompt(prompt)
 
- if hasattr(response, "content"):
- response_text = response.content
- elif isinstance(response, str):
- response_text = response
- else:
- raise ValueError("Unexpected response format from LLM.")
-
-
+ response_text = _get_response_text(response)
  candidate = response_text.strip()
  if should_validate and candidate not in concepts.keys():
  error = f"Concept '{determined_concept_name}' not found in the list of concepts."
@@ -351,13 +377,7 @@ def _build_rel_columns_str(relationships: list[dict], columns_tags: Optional[dic
 
 
  def _parse_sql_from_llm_response(response: Any) -> str:
- if hasattr(response, "content"):
- response_text = response.content
- elif isinstance(response, str):
- response_text = response
- else:
- raise ValueError("Unexpected response format from LLM.")
-
+ response_text = _get_response_text(response)
  return (response_text
  .replace("```sql", "")
  .replace("```", "")
@@ -398,7 +418,7 @@ def generate_sql(
  if timeout is None:
  timeout = llm_timeout
 
- generate_sql_prompt = get_generate_sql_prompt_template(conn_params["token"], conn_params["is_jwt"], conn_params["jwt_tenant_id"])
+ generate_sql_prompt = get_generate_sql_prompt_template(conn_params)
 
  if concept and concept != "" and (schema is None or schema != "vtimbr"):
  concepts_list = [concept]
@@ -497,8 +517,7 @@ def generate_sql(
 
  usage_metadata['generate_sql'] = {
  "approximate": apx_token_count,
- # **(response.usage_metadata or response.usage or {}),
- **(response.usage_metadata or {}),
+ **_extract_usage_metadata(response),
  }
  if debug:
  usage_metadata['generate_sql']["p_hash"] = encrypt_prompt(prompt)
@@ -531,7 +550,7 @@ def answer_question(
  if timeout is None:
  timeout = llm_timeout
 
- qa_prompt = get_qa_prompt_template(conn_params["token"], conn_params["is_jwt"], conn_params["jwt_tenant_id"])
+ qa_prompt = get_qa_prompt_template(conn_params)
 
  prompt = qa_prompt.format_messages(
  question=question,
@@ -561,8 +580,7 @@ def answer_question(
  usage_metadata = {
  "answer_question": {
  "approximate": apx_token_count,
- # **(response.usage_metadata or response.usage or {}),
- **(response.usage_metadata or {}),
+ **_extract_usage_metadata(response),
  },
  }
  if debug:
@@ -0,0 +1,42 @@
+ from langchain_timbr import LlmWrapper, LlmTypes, ExecuteTimbrQueryChain
+ import os
+
+ class TestAzureDatabricksProvider:
+ """Test suite for Azure Databricks provider integration with Timbr chains."""
+
+ def skip_test_databricks_connection(self, llm, config):
+ DATABRICKS_LLM_HOST = os.getenv('databricks_host')
+ AZURE_TOKEN = os.getenv('databricks_token')
+
+ llm_models_test = [
+ "databricks-claude-sonnet-4",
+ "databricks-gpt-oss-20b",
+ ]
+
+ for llm_model in llm_models_test:
+ llm_instance = LlmWrapper(
+ llm_type=LlmTypes.Databricks,
+ api_key=AZURE_TOKEN,
+ databricks_host=DATABRICKS_LLM_HOST,
+ model=llm_model,
+ )
+
+ chain = ExecuteTimbrQueryChain(
+ llm=llm_instance,
+ url=config["timbr_url"],
+ token=config["timbr_token"],
+ ontology=config["timbr_ontology"],
+ verify_ssl=config["verify_ssl"],
+ )
+
+ inputs = {
+ "prompt": config["test_prompt"],
+ }
+ result = chain.invoke(inputs)
+
+ print("ExecuteTimbrQueryChain result:", result)
+ assert "rows" in result, "Result should contain 'rows'"
+ assert isinstance(result["rows"], list), "'rows' should be a list"
+ assert result["sql"], "SQL should be present in the result"
+