langchain-timbr 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain_timbr/__init__.py +17 -0
- langchain_timbr/config.py +21 -0
- langchain_timbr/langchain/__init__.py +16 -0
- langchain_timbr/langchain/execute_timbr_query_chain.py +307 -0
- langchain_timbr/langchain/generate_answer_chain.py +99 -0
- langchain_timbr/langchain/generate_timbr_sql_chain.py +176 -0
- langchain_timbr/langchain/identify_concept_chain.py +138 -0
- langchain_timbr/langchain/timbr_sql_agent.py +418 -0
- langchain_timbr/langchain/validate_timbr_sql_chain.py +187 -0
- langchain_timbr/langgraph/__init__.py +13 -0
- langchain_timbr/langgraph/execute_timbr_query_node.py +108 -0
- langchain_timbr/langgraph/generate_response_node.py +59 -0
- langchain_timbr/langgraph/generate_timbr_sql_node.py +98 -0
- langchain_timbr/langgraph/identify_concept_node.py +78 -0
- langchain_timbr/langgraph/validate_timbr_query_node.py +100 -0
- langchain_timbr/llm_wrapper/llm_wrapper.py +189 -0
- langchain_timbr/llm_wrapper/timbr_llm_wrapper.py +41 -0
- langchain_timbr/timbr_llm_connector.py +398 -0
- langchain_timbr/utils/general.py +70 -0
- langchain_timbr/utils/prompt_service.py +330 -0
- langchain_timbr/utils/temperature_supported_models.json +62 -0
- langchain_timbr/utils/timbr_llm_utils.py +575 -0
- langchain_timbr/utils/timbr_utils.py +475 -0
- langchain_timbr-1.5.0.dist-info/METADATA +103 -0
- langchain_timbr-1.5.0.dist-info/RECORD +27 -0
- langchain_timbr-1.5.0.dist-info/WHEEL +4 -0
- langchain_timbr-1.5.0.dist-info/licenses/LICENSE +21 -0
langchain_timbr/llm_wrapper/llm_wrapper.py
@@ -0,0 +1,189 @@
from enum import Enum
from langchain.llms.base import LLM
from pydantic import Field

from .timbr_llm_wrapper import TimbrLlmWrapper
from ..utils.general import is_llm_type, is_support_temperature
from ..config import llm_temperature


class LlmTypes(Enum):
    OpenAI = 'openai-chat'
    Anthropic = 'anthropic-chat'
    Google = 'chat-google-generative-ai'
    AzureOpenAI = 'azure-openai-chat'
    Snowflake = 'snowflake-cortex'
    Timbr = 'timbr'


class LlmWrapper(LLM):
    """
    LlmWrapper is a unified interface for connecting to various Large Language Model (LLM) providers
    (OpenAI, Anthropic, Google, Azure OpenAI, Snowflake Cortex, etc.) using LangChain. It abstracts
    the initialization and connection logic for each provider, allowing you to switch between them
    with a consistent API.
    """
    client: LLM = Field(default=None, exclude=True)

    def __init__(
        self,
        llm_type: str,
        api_key: str,
        model: str = None,
        **llm_params,
    ):
        """
        :param llm_type (str): The type of LLM provider (e.g., 'openai-chat', 'anthropic-chat').
        :param api_key (str): The API key for authenticating with the LLM provider.
        :param model (str): The model name or deployment to use. Defaults to provider-specific values (Optional).
        :param **llm_params: Additional parameters for the LLM (e.g., temperature, endpoint, etc.).
        """
        super().__init__()
        self.client = self._connect_to_llm(
            llm_type,
            api_key,
            model,
            **llm_params,
        )

    @property
    def _llm_type(self):
        return self.client._llm_type

    def _add_temperature(self, llm_type, llm_model, **llm_params):
        """
        Add temperature to the LLM parameters if the LLM model supports it.
        """
        if "temperature" not in llm_params:
            if llm_temperature is not None and is_support_temperature(llm_type, llm_model):
                llm_params["temperature"] = llm_temperature
        return llm_params

    def _connect_to_llm(self, llm_type, api_key, model, **llm_params):
        if is_llm_type(llm_type, LlmTypes.OpenAI):
            from langchain_openai import ChatOpenAI as OpenAI
            llm_model = model or "gpt-4o-2024-11-20"
            params = self._add_temperature(LlmTypes.OpenAI.name, llm_model, **llm_params)
            return OpenAI(
                openai_api_key=api_key,
                model_name=llm_model,
                **params,
            )
        elif is_llm_type(llm_type, LlmTypes.Anthropic):
            from langchain_anthropic import ChatAnthropic as Claude
            llm_model = model or "claude-3-5-sonnet-20241022"
            params = self._add_temperature(LlmTypes.Anthropic.name, llm_model, **llm_params)
            return Claude(
                anthropic_api_key=api_key,
                model=llm_model,
                **params,
            )
        elif is_llm_type(llm_type, LlmTypes.Google):
            from langchain_google_genai import ChatGoogleGenerativeAI
            llm_model = model or "gemini-2.0-flash-exp"
            params = self._add_temperature(LlmTypes.Google.name, llm_model, **llm_params)
            return ChatGoogleGenerativeAI(
                google_api_key=api_key,
                model=llm_model,
                **params,
            )
        elif is_llm_type(llm_type, LlmTypes.Timbr):
            return TimbrLlmWrapper(
                api_key=api_key,
                **llm_params,  # was **params, which is undefined in this branch
            )
        elif is_llm_type(llm_type, LlmTypes.Snowflake):
            from langchain_community.chat_models import ChatSnowflakeCortex
            llm_model = model or "openai-gpt-4.1"
            params = self._add_temperature(LlmTypes.Snowflake.name, llm_model, **llm_params)

            return ChatSnowflakeCortex(
                model=llm_model,
                **params,
            )
        elif is_llm_type(llm_type, LlmTypes.AzureOpenAI):
            from langchain_openai import AzureChatOpenAI
            # Pop the Azure-only keys before building the common params
            # (was params.pop, but params is not defined at this point).
            azure_endpoint = llm_params.pop('azure_endpoint', None)
            azure_api_version = llm_params.pop('azure_openai_api_version', None)
            llm_model = model or "gpt-4o-2024-11-20"
            params = self._add_temperature(LlmTypes.AzureOpenAI.name, llm_model, **llm_params)
            return AzureChatOpenAI(
                openai_api_key=api_key,
                azure_deployment=llm_model,
                azure_endpoint=azure_endpoint,
                openai_api_version=azure_api_version,
                **params,
            )
        else:
            raise ValueError(f"Unsupported LLM type: {llm_type}")

    def get_model_list(self) -> list[str]:
        """Return the list of available models for the LLM."""
        models = []
        try:
            if is_llm_type(self._llm_type, LlmTypes.OpenAI):
                from openai import OpenAI
                client = OpenAI(api_key=self.client.openai_api_key._secret_value)
                models = [model.id for model in client.models.list()]
            elif is_llm_type(self._llm_type, LlmTypes.Anthropic):
                import anthropic
                client = anthropic.Anthropic(api_key=self.client.anthropic_api_key._secret_value)
                models = [model.id for model in client.models.list()]
            elif is_llm_type(self._llm_type, LlmTypes.Google):
                import google.generativeai as genai
                genai.configure(api_key=self.client.google_api_key._secret_value)
                models = [m.name.replace('models/', '') for m in genai.list_models()]
            elif is_llm_type(self._llm_type, LlmTypes.AzureOpenAI):
                from openai import AzureOpenAI
                # Get Azure-specific attributes from the client
                azure_endpoint = getattr(self.client, 'azure_endpoint', None)
                api_version = getattr(self.client, 'openai_api_version', None)
                api_key = self.client.openai_api_key._secret_value

                if azure_endpoint and api_version and api_key:
                    client = AzureOpenAI(
                        api_key=api_key,
                        azure_endpoint=azure_endpoint,
                        api_version=api_version
                    )
                    # For Azure, get the deployments instead of models
                    try:
                        models = [model.id for model in client.models.list()]
                    except Exception:
                        # If listing models fails, provide some common deployment names
                        models = ["gpt-4o", "Other (Custom)"]
            elif is_llm_type(self._llm_type, LlmTypes.Snowflake):
                # Snowflake Cortex available models
                models = [
                    "openai-gpt-4.1",
                    "mistral-large2",
                    "llama3.1-70b",
                    "llama3.1-405b"
                ]
            # elif self._is_llm_type(self._llm_type, LlmTypes.Timbr):

        except Exception:
            models = []

        return models

    def _call(self, prompt, **kwargs):
        return self.client(prompt, **kwargs)

    def __call__(self, prompt, **kwargs):
        """
        Override the default __call__ method to handle input preprocessing.
        This bypasses the prompt input validation pydantic performs, so a list
        of AIMessages can be passed instead of only a string.
        """
        return self._call(prompt, **kwargs)

    def query(self, prompt, **kwargs):
        return self._call(prompt, **kwargs)
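For reference, a minimal usage sketch of `LlmWrapper`. It assumes the package re-exports `LlmWrapper` and `LlmTypes` from its top-level `__init__.py` (not shown in this diff); the API key, model, and prompt values are placeholders:

```python
# Minimal sketch; key/model/prompt are placeholders, and the top-level
# re-exports are assumed from the package's __init__.py (not shown here).
from langchain_timbr import LlmWrapper, LlmTypes

llm = LlmWrapper(
    llm_type=LlmTypes.OpenAI.value,  # 'openai-chat'; the enum name also matches via is_llm_type
    api_key="sk-...",                # placeholder API key
    model="gpt-4o-2024-11-20",       # optional; matches the provider default above
)

print(llm.query("Reply with one word: hello"))  # routes through _call to the underlying client
print(llm.get_model_list())                     # model ids available for the configured provider
```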
langchain_timbr/llm_wrapper/timbr_llm_wrapper.py
@@ -0,0 +1,41 @@
from langchain.llms.base import LLM
import requests
from typing import Optional, List

class TimbrLlmWrapper(LLM):
    # Declared as pydantic fields so the assignments in __init__ are valid
    # on the LLM base model.
    url: Optional[str] = None
    api_key: Optional[str] = None
    temperature: Optional[float] = 0

    def __init__(self, url: str, api_key: str, temperature: Optional[float] = 0):
        """
        ***TBD, Not ready yet.***

        Custom LLM implementation for timbr LLM wrapped with a proxy server.

        :param url: URL of the proxy server wrapping timbr LLM.
        :param api_key: API key for authentication with the proxy server.
        :param temperature: Sampling temperature for the model.
        """
        super().__init__()  # initialize the pydantic model before assigning fields
        self.url = url
        self.api_key = api_key
        self.temperature = temperature

    @property
    def _llm_type(self) -> str:
        return "timbr"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """
        Sends the prompt to the proxy server and returns the response.
        """
        headers = { "Authorization": f"Bearer {self.api_key}" }
        payload = {
            "prompt": prompt,
            "temperature": self.temperature,
            "stop": stop,
        }

        response = requests.post(self.url, json=payload, headers=headers)
        if response.status_code == 200:
            return response.json().get("response", "")
        else:
            raise ValueError(f"Error communicating with timbr proxy: {response.text}")
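The class is explicitly marked TBD, but the `_call` body above pins down the proxy contract it expects. A hedged sketch of that exchange, with a hypothetical endpoint and a placeholder key:

```python
# Hypothetical usage; the endpoint and key are placeholders, and the class's
# own docstring marks it "TBD, Not ready yet".
wrapper = TimbrLlmWrapper(
    url="https://timbr-proxy.example.com/llm",  # hypothetical proxy endpoint
    api_key="timbr-key-...",                    # placeholder
)

# _call POSTs {"prompt", "temperature", "stop"} with a Bearer token and
# expects a JSON body whose "response" field holds the completion:
answer = wrapper._call("Summarize the sales ontology in one sentence.")
```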
langchain_timbr/timbr_llm_connector.py
@@ -0,0 +1,398 @@
from typing import Optional, Any, Literal
from typing_extensions import TypedDict
from langchain.llms.base import LLM
from langgraph.graph import StateGraph, END

from .utils.general import to_boolean, to_integer
from .llm_wrapper.llm_wrapper import LlmWrapper
from .utils.timbr_utils import get_ontologies, get_concepts
from .langchain import IdentifyTimbrConceptChain, GenerateTimbrSqlChain, ValidateTimbrSqlChain, ExecuteTimbrQueryChain, create_timbr_sql_agent
from .langgraph import GenerateTimbrSqlNode, ValidateSemanticSqlNode, ExecuteSemanticQueryNode, GenerateResponseNode


from .config import (
    url as default_url,
    token as default_token,
    ontology as default_ontology,
    llm_type,
    llm_model,
    llm_api_key,
    llm_temperature,
)

class TimbrLanggraphState(TypedDict):
    prompt: str
    sql: str
    concept: str
    rows: list
    response: str
    error: str
    is_sql_valid: bool
    usage_metadata: dict[str, Any]


class TimbrLlmConnector:
    def __init__(
        self,
        llm: LLM,
        url: Optional[str] = default_url,
        token: Optional[str] = default_token,
        ontology: Optional[str] = default_ontology,
        max_limit: Optional[int] = 500,
        verify_ssl: Optional[bool] = True,
        is_jwt: Optional[bool] = False,
        jwt_tenant_id: Optional[str] = None,
        conn_params: Optional[dict] = None,
    ):
        """
        :param url: Timbr server url
        :param token: Timbr password or token value
        :param ontology: The name of the ontology/knowledge graph
        :param llm: An LLM instance or a function that takes a prompt string and returns the LLM's response
        :param max_limit: Maximum number of rows to return
        :param verify_ssl: Whether to verify SSL certificates (default is True).
        :param is_jwt: Whether to use JWT authentication (default is False).
        :param jwt_tenant_id: Tenant ID for JWT authentication (if applicable).
        :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').

        ## Example
        ```
        timbr_llm_wrapper = LlmWrapper(
            llm_type=LlmTypes.OpenAI,
            model="gpt-4o",
            api_key=<openai_api_key>,
        )

        llm_connector = TimbrLlmConnector(
            url=<url>,
            token=<token>,
            llm=timbr_llm_wrapper,
        )

        # Show ontology list at timbr instance from url connection
        ontologies = llm_connector.get_ontologies()

        # Find which concept & schema will be queried by the user input
        determine_concept_res = llm_connector.determine_concept(llm_input)
        query_concept, query_schema = determine_concept_res.get("concept"), determine_concept_res.get("schema")

        # Generate timbr SQL query from user input
        sql_query = llm_connector.generate_sql(llm_input).get("sql")

        # Run timbr SQL query
        results = llm_connector.run_timbr_query(sql_query).get("rows", [])

        # Parse & Run LLM question
        results = llm_connector.run_llm_query(llm_input).get("rows", [])
        ```
        """
        self.url = url
        self.token = token
        self.ontology = ontology
        self.max_limit = to_integer(max_limit)
        self.verify_ssl = to_boolean(verify_ssl)
        self.is_jwt = to_boolean(is_jwt)
        self.jwt_tenant_id = jwt_tenant_id
        self.conn_params = conn_params or {}

        if llm is not None:
            self._llm = llm
        elif llm_type is not None and llm_api_key is not None:
            llm_params = {}
            if llm_temperature is not None:
                llm_params["temperature"] = llm_temperature

            self._llm = LlmWrapper(
                llm_type=llm_type,
                api_key=llm_api_key,
                model=llm_model,
                **llm_params,
            )

    # TODO: Make this function a decorator and use it on relevant methods
    # def _is_ontology_set(self):
    #     return self.ontology != 'system_db'

    def _get_conn_params(self):
        return {
            "url": self.url,
            "token": self.token,
            "ontology": self.ontology,
            "verify_ssl": self.verify_ssl,
            "is_jwt": self.is_jwt,
            "jwt_tenant_id": self.jwt_tenant_id,
            **self.conn_params,
        }

    def get_ontologies(self) -> list[str]:
        return get_ontologies(conn_params=self._get_conn_params())

    def get_concepts(self) -> dict:
        """
        Get the list of concepts from the Timbr server.
        """
        return get_concepts(
            conn_params=self._get_conn_params(),
            concepts_list="*",
        )

    def get_views(self) -> dict:
        """
        Get the list of views from the Timbr server.
        """
        return get_concepts(
            conn_params=self._get_conn_params(),
            views_list="*",
        )

    def set_ontology(self, ontology: str):
        self.ontology = ontology

    def determine_concept(
        self,
        question: str,
        concepts_list: Optional[list] = None,
        views_list: Optional[list] = None,
        include_logic_concepts: Optional[bool] = False,
        include_tags: Optional[str] = None,
        should_validate: Optional[bool] = False,
        retries: Optional[int] = 3,
        note: Optional[str] = '',
        **chain_kwargs: Any,
    ) -> dict[str, Any]:
        determine_concept_chain = IdentifyTimbrConceptChain(
            **self._get_conn_params(),
            llm=self._llm,
            concepts_list=concepts_list,
            views_list=views_list,
            include_logic_concepts=include_logic_concepts,
            include_tags=include_tags,
            should_validate=should_validate,
            retries=retries,
            note=note,
            **chain_kwargs,
        )

        return determine_concept_chain.invoke({ "prompt": question })

    def generate_sql(
        self,
        question: str,
        concept_name: Optional[str] = None,
        schema: Optional[str] = None,
        concepts_list: Optional[list] = None,
        views_list: Optional[list] = None,
        include_logic_concepts: Optional[bool] = False,
        include_tags: Optional[str] = None,
        should_validate_sql: Optional[bool] = False,
        retries: Optional[int] = 3,
        note: Optional[str] = '',
        **chain_kwargs: Any,
    ) -> dict[str, Any]:
        generate_timbr_llm_chain = GenerateTimbrSqlChain(
            llm=self._llm,
            **self._get_conn_params(),
            schema=schema,
            concept=concept_name,
            concepts_list=concepts_list,
            views_list=views_list,
            include_logic_concepts=include_logic_concepts,
            include_tags=include_tags,
            should_validate_sql=should_validate_sql,
            retries=retries,
            max_limit=self.max_limit,
            note=note,
            **chain_kwargs,
        )

        return generate_timbr_llm_chain.invoke({ "prompt": question })

    def validate_sql(
        self,
        question: str,
        sql_query: str,
        retries: Optional[int] = 3,
        concepts_list: Optional[list] = None,
        views_list: Optional[list] = None,
        include_logic_concepts: Optional[bool] = False,
        include_tags: Optional[str] = None,
        note: Optional[str] = '',
        **chain_kwargs: Any,
    ) -> dict[str, Any]:
        validate_timbr_sql_chain = ValidateTimbrSqlChain(
            llm=self._llm,
            **self._get_conn_params(),
            retries=retries,
            concepts_list=concepts_list,
            views_list=views_list,
            include_logic_concepts=include_logic_concepts,
            include_tags=include_tags,
            max_limit=self.max_limit,
            note=note,
            **chain_kwargs,
        )
        return validate_timbr_sql_chain.invoke({ "sql": sql_query, "prompt": question })

    def run_timbr_query(
        self,
        sql_query: str,
        concepts_list: Optional[list] = None,
        views_list: Optional[list] = None,
        include_logic_concepts: Optional[bool] = False,
        include_tags: Optional[str] = None,
        should_validate_sql: Optional[bool] = True,
        retries: Optional[int] = 3,
        note: Optional[str] = '',
        **chain_kwargs: Any,
    ) -> dict[str, Any]:
        execute_timbr_query_chain = ExecuteTimbrQueryChain(
            llm=self._llm,
            **self._get_conn_params(),
            concepts_list=concepts_list,
            views_list=views_list,
            include_logic_concepts=include_logic_concepts,
            include_tags=include_tags,
            should_validate_sql=should_validate_sql,
            retries=retries,
            max_limit=self.max_limit,
            note=note,
            **chain_kwargs,
        )

        return execute_timbr_query_chain.invoke({ "sql": sql_query })

    def run_llm_query(
        self,
        question: str,
        concepts_list: Optional[list] = None,
        views_list: Optional[list] = None,
        include_logic_concepts: Optional[bool] = False,
        include_tags: Optional[str] = None,
        should_validate_sql: Optional[bool] = True,
        retries: Optional[int] = 3,
        note: Optional[str] = '',
        **agent_kwargs: Any,
    ) -> dict[str, Any]:
        agent = create_timbr_sql_agent(
            llm=self._llm,
            **self._get_conn_params(),
            concept=None,
            concepts_list=concepts_list,
            views_list=views_list,
            include_logic_concepts=include_logic_concepts,
            include_tags=include_tags,
            should_validate_sql=should_validate_sql,
            retries=retries,
            max_limit=self.max_limit,
            note=note,
            **agent_kwargs,
        )

        return agent.invoke(question)

    def run_llm_query_graph(
        self,
        question: str,
        concepts_list: Optional[list] = None,
        views_list: Optional[list] = None,
        include_logic_concepts: Optional[bool] = False,
        include_tags: Optional[str] = None,
        should_validate_sql: Optional[bool] = True,
        retries: Optional[int] = 3,
        note: Optional[str] = '',
        **nodes_kwargs: Any,
    ) -> dict[str, Any]:
        generate_sql_node = GenerateTimbrSqlNode(
            llm=self._llm,
            **self._get_conn_params(),
            concepts_list=concepts_list,
            views_list=views_list,
            include_logic_concepts=include_logic_concepts,
            include_tags=include_tags,
            max_limit=self.max_limit,
            note=note,
            **nodes_kwargs,
        )
        validate_sql_node = ValidateSemanticSqlNode(
            llm=self._llm,
            **self._get_conn_params(),
            retries=retries,
            concepts_list=concepts_list,
            views_list=views_list,
            include_logic_concepts=include_logic_concepts,
            include_tags=include_tags,
            max_limit=self.max_limit,
            note=note,
            **nodes_kwargs,
        )
        execute_sql_node = ExecuteSemanticQueryNode(
            llm=self._llm,
            **self._get_conn_params(),
            concepts_list=concepts_list,
            views_list=views_list,
            include_logic_concepts=include_logic_concepts,
            include_tags=include_tags,
            should_validate_sql=should_validate_sql,
            retries=retries,
            max_limit=self.max_limit,
            note=note,
            **nodes_kwargs,
        )
        generate_response_node = GenerateResponseNode()

        graph_builder = StateGraph(TimbrLanggraphState)

        graph_builder.add_node("generate_sql", generate_sql_node)
        graph_builder.add_node("validate_sql", validate_sql_node)
        graph_builder.add_node("execute_sql", execute_sql_node)
        graph_builder.add_node("generate_response", generate_response_node)

        graph_builder.add_edge("generate_sql", "validate_sql")

        def route_validation(state: dict) -> Literal["execute_sql", "end"]:
            # If validation is successful, proceed to execute the query.
            # Otherwise, stop the flow.
            if state.get("is_sql_valid"):
                return "execute_sql"
            else:
                return "end"

        graph_builder.add_conditional_edges(
            "validate_sql",
            route_validation,
            {
                "execute_sql": "execute_sql",
                "end": END
            }
        )

        graph_builder.add_edge("execute_sql", "generate_response")
        graph_builder.set_entry_point("generate_sql")

        compiled_graph = graph_builder.compile()

        initial_state = {
            "prompt": question,
            "sql": "",
            "concept": "",
            "rows": [],
            "response": "",
            "error": "",
            "is_sql_valid": False,
            "usage_metadata": {}
        }

        result = compiled_graph.invoke(initial_state)
        return result
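A hedged end-to-end sketch of the LangGraph path above. The server URL, token, ontology name, and API key are placeholders, and the returned dict mirrors `TimbrLanggraphState`:

```python
# Sketch only; url/token/ontology/key values are placeholders.
connector = TimbrLlmConnector(
    llm=LlmWrapper(llm_type="openai-chat", api_key="sk-..."),  # placeholder key
    url="https://my-timbr.example.com",                        # hypothetical Timbr server
    token="tk_...",                                            # placeholder token
    ontology="sales_ontology",                                 # hypothetical ontology name
)

result = connector.run_llm_query_graph("Top 5 customers by total revenue?")
print(result["sql"])           # Timbr SQL produced by generate_sql
print(result["is_sql_valid"])  # gate that routed the graph to execute_sql or END
print(result["rows"])          # query results (left empty when validation fails)
print(result["response"])      # natural-language answer from generate_response
```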
langchain_timbr/utils/general.py
@@ -0,0 +1,70 @@
import os
import json

### Global helper functions used across the project

def parse_list(input_value, separator=',') -> list[str]:
    try:
        if isinstance(input_value, str):
            return [item.strip() for item in input_value.split(separator) if item.strip()]
        elif isinstance(input_value, list):
            return [item.strip() for item in input_value if item.strip()]
        return []
    except Exception as e:
        raise ValueError(f"Failed to parse list value: {e}")


def to_boolean(value) -> bool:
    try:
        if isinstance(value, str):
            return value.lower() in ['true', '1']
        return bool(value)
    except Exception as e:
        raise ValueError(f"Failed to parse boolean value: {e}")


def to_integer(value) -> int:
    try:
        return int(value)
    except (ValueError, TypeError) as e:
        raise ValueError(f"Failed to parse integer value: {e}")


def is_llm_type(llm_type, enum_value):
    """Check if llm_type equals the enum value or its name, case-insensitive."""
    if llm_type == enum_value:
        return True

    if isinstance(llm_type, str):
        llm_type_lower = llm_type.lower()
        enum_name_lower = enum_value.name.lower() if enum_value.name else ""
        enum_value_lower = enum_value.value.lower() if isinstance(enum_value.value, str) else ""

        return (
            llm_type_lower == enum_name_lower or
            llm_type_lower == enum_value_lower or
            llm_type_lower.startswith(enum_name_lower) or  # covers Snowflake, whose type string is the provider name plus the model name
            llm_type_lower.startswith(enum_value_lower)
        )

    return False


def is_support_temperature(llm_type: str, llm_model: str) -> bool:
    """
    Check if the LLM model supports temperature setting.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    json_file_path = os.path.join(current_dir, 'temperature_supported_models.json')

    try:
        with open(json_file_path, 'r') as f:
            temperature_supported_models = json.load(f)

        # Check if llm_type exists and llm_model is in its list
        if llm_type in temperature_supported_models:
            return llm_model in temperature_supported_models[llm_type]

        return False
    except (FileNotFoundError, json.JSONDecodeError, KeyError):
        return False
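`is_support_temperature` only consults the bundled `temperature_supported_models.json` (62 lines in this release; its contents are not part of this diff). Since `_add_temperature` passes `LlmTypes.<Provider>.name` as `llm_type`, the file presumably maps enum names to model lists; a hypothetical sketch of that shape and the resulting lookups:

```python
# Hypothetical shape of temperature_supported_models.json; the real file ships
# in the wheel but is not shown in this diff. Keys match what _add_temperature
# passes, i.e. LlmTypes enum names such as "OpenAI":
#
#   {
#     "OpenAI": ["gpt-4o-2024-11-20"],
#     "Anthropic": ["claude-3-5-sonnet-20241022"]
#   }
#
# With a file like that in place:
is_support_temperature("OpenAI", "gpt-4o-2024-11-20")  # -> True
is_support_temperature("OpenAI", "some-other-model")   # -> False (not listed)
```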