langchain-timbr 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain_timbr/__init__.py +17 -0
- langchain_timbr/config.py +21 -0
- langchain_timbr/langchain/__init__.py +16 -0
- langchain_timbr/langchain/execute_timbr_query_chain.py +307 -0
- langchain_timbr/langchain/generate_answer_chain.py +99 -0
- langchain_timbr/langchain/generate_timbr_sql_chain.py +176 -0
- langchain_timbr/langchain/identify_concept_chain.py +138 -0
- langchain_timbr/langchain/timbr_sql_agent.py +418 -0
- langchain_timbr/langchain/validate_timbr_sql_chain.py +187 -0
- langchain_timbr/langgraph/__init__.py +13 -0
- langchain_timbr/langgraph/execute_timbr_query_node.py +108 -0
- langchain_timbr/langgraph/generate_response_node.py +59 -0
- langchain_timbr/langgraph/generate_timbr_sql_node.py +98 -0
- langchain_timbr/langgraph/identify_concept_node.py +78 -0
- langchain_timbr/langgraph/validate_timbr_query_node.py +100 -0
- langchain_timbr/llm_wrapper/llm_wrapper.py +189 -0
- langchain_timbr/llm_wrapper/timbr_llm_wrapper.py +41 -0
- langchain_timbr/timbr_llm_connector.py +398 -0
- langchain_timbr/utils/general.py +70 -0
- langchain_timbr/utils/prompt_service.py +330 -0
- langchain_timbr/utils/temperature_supported_models.json +62 -0
- langchain_timbr/utils/timbr_llm_utils.py +575 -0
- langchain_timbr/utils/timbr_utils.py +475 -0
- langchain_timbr-1.5.0.dist-info/METADATA +103 -0
- langchain_timbr-1.5.0.dist-info/RECORD +27 -0
- langchain_timbr-1.5.0.dist-info/WHEEL +4 -0
- langchain_timbr-1.5.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
import requests
|
|
2
|
+
from typing import Dict, Any, Optional, List, Union
|
|
3
|
+
from langchain.schema import SystemMessage, HumanMessage
|
|
4
|
+
from langchain.prompts.chat import ChatPromptTemplate
|
|
5
|
+
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
|
|
9
|
+
from ..config import url, token as default_token, is_jwt, jwt_tenant_id as default_jwt_tenant_id, llm_timeout
|
|
10
|
+
|
|
11
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Process-wide cache of parsed prompt templates, keyed by API endpoint path.
# It lives at module scope so every PromptService instance reads and writes
# the same entries; the key does not include the base URL or credentials.
_global_template_cache: Dict[str, Any] = {}
|
|
15
|
+
|
|
16
|
+
class PromptService:
    """
    Client for the prompt-template endpoints of the Timbr API.

    Templates are requested from ``{base_url}/timbr/api/<endpoint>``, converted
    to LangChain ``ChatPromptTemplate`` objects and stored in the module-level
    ``_global_template_cache``. The cache key is the endpoint only, so cached
    templates are shared by all instances regardless of base URL or credentials.
    """

    # Configuration defaults captured when the class is created. __init__
    # cannot read the module-level ``is_jwt`` directly because its own
    # parameter of the same name shadows it.
    _DEFAULT_BASE_URL = url
    _DEFAULT_TOKEN = default_token
    _DEFAULT_IS_JWT = is_jwt
    _DEFAULT_JWT_TENANT_ID = default_jwt_tenant_id

    def __init__(
        self,
        base_url: Optional[str] = url,
        token: Optional[str] = default_token,
        is_jwt: Optional[bool] = is_jwt,
        jwt_tenant_id: Optional[str] = default_jwt_tenant_id,
        timeout: Optional[int] = llm_timeout,
    ):
        """
        Store connection settings for later requests; performs no I/O.

        Args:
            base_url: Root URL of the Timbr server; trailing slashes are removed.
            token: API key, or JWT when ``is_jwt`` is true.
            is_jwt: Whether ``token`` should be sent as a JWT.
            jwt_tenant_id: Tenant id sent alongside a JWT.
            timeout: Timeout passed to ``requests``; ``None`` is forwarded as-is.

        For ``base_url``, ``token``, ``is_jwt`` and ``jwt_tenant_id`` an explicit
        ``None`` is treated like an omitted argument and replaced by the
        configured default. The module-level factory functions always forward
        their own ``None`` defaults, which would otherwise override the
        configuration.
        """
        if base_url is None:
            base_url = self._DEFAULT_BASE_URL
        if token is None:
            token = self._DEFAULT_TOKEN
        if is_jwt is None:
            is_jwt = self._DEFAULT_IS_JWT
        if jwt_tenant_id is None:
            jwt_tenant_id = self._DEFAULT_JWT_TENANT_ID

        # A missing URL must not crash construction with an AttributeError;
        # _fetch_template reports it with a clear message instead.
        self.base_url = base_url.rstrip('/') if base_url else base_url
        self.token = token
        self.is_jwt = is_jwt
        self.jwt_tenant_id = jwt_tenant_id
        self.timeout = timeout

    def _get_headers(self) -> Dict[str, str]:
        """
        Build the HTTP headers for an API request.

        Returns:
            A dict with ``Content-Type`` plus either the JWT headers
            (``x-jwt-token`` and, when set, ``x-jwt-tenant-id``) or the
            ``x-api-key`` header when a token is available.
        """
        headers = {"Content-Type": "application/json"}

        if self.is_jwt:
            headers["x-jwt-token"] = self.token
            if self.jwt_tenant_id:
                headers["x-jwt-tenant-id"] = self.jwt_tenant_id
        elif self.token:
            headers["x-api-key"] = self.token

        return headers

    def _convert_template_response_to_chat_prompt(self, api_response: Union[List, Dict[str, Any]]) -> ChatPromptTemplate:
        """
        Convert an API response containing template strings to a ChatPromptTemplate.

        Expected API response format:
        {
            "status": "success",
            "data": [
                {"type": "SystemMessage", "template": "You are a helpful SQL expert...", "role": "system"},
                {"type": "HumanMessage", "template": "BUSINESS QUESTION: {question}...", "role": "human"}
            ]
        }

        Args:
            api_response: Decoded JSON body returned by the prompt service.

        Returns:
            ChatPromptTemplate built from the ``data`` array.

        Raises:
            Exception: The service reported ``status == "error"``.
            ValueError: The response does not have the expected shape.
        """
        if not isinstance(api_response, dict) or "status" not in api_response:
            raise ValueError("Invalid API response format: expected object with 'status' and 'data' fields")

        status = api_response["status"]
        if status == "error":
            error_msg = api_response.get("data", "Unknown error from prompt service")
            raise Exception(f"Prompt service error: {error_msg}")

        if status != "success" or "data" not in api_response:
            raise ValueError(f"Invalid API response: unexpected status '{status}' or missing 'data' field")

        data = api_response["data"]
        if not isinstance(data, list):
            raise ValueError("Expected 'data' to be an array of template objects")
        return self._parse_template_array(data)

    def _parse_template_array(self, templates: List[Dict[str, Any]]) -> ChatPromptTemplate:
        """
        Parse an array of template dictionaries into a ChatPromptTemplate.

        Entries that are not dicts or lack a ``template`` key are skipped.
        An entry becomes a system message when its ``type`` is
        ``SystemMessage`` or its ``role`` is ``system`` (case-insensitive);
        every other entry becomes a human message.

        Args:
            templates: List of template dictionaries.

        Returns:
            ChatPromptTemplate object.

        Raises:
            ValueError: No entry in ``templates`` was usable.
        """
        message_templates = []
        for tmpl in templates:
            if not isinstance(tmpl, dict) or "template" not in tmpl:
                continue

            # ``or ""`` covers a JSON null, which .get() returns as None
            # instead of falling back to its default.
            msg_type = str(tmpl.get("type") or "").lower()
            msg_role = str(tmpl.get("role") or "").lower()
            template_str = tmpl["template"]

            if msg_type == "systemmessage" or msg_role == "system":
                message_templates.append(SystemMessagePromptTemplate.from_template(template_str))
            else:
                # Human messages and unknown types are both treated as human.
                message_templates.append(HumanMessagePromptTemplate.from_template(template_str))

        if not message_templates:
            # Fail here rather than cache an empty prompt for the whole process.
            raise ValueError("Prompt service returned no usable templates")

        return ChatPromptTemplate.from_messages(message_templates)

    def _fetch_template(self, endpoint: str) -> ChatPromptTemplate:
        """
        Fetch a template from the API service, using the global cache.

        The request is a POST without a body. A successful result is stored in
        ``_global_template_cache`` under ``endpoint`` and returned directly on
        later calls.

        Args:
            endpoint: The API endpoint to call, relative to ``/timbr/api/``.

        Returns:
            ChatPromptTemplate object.

        Raises:
            ValueError: No base URL is configured.
            Exception: The request failed, the body was not valid JSON, or the
                response could not be converted. The original error is
                attached as ``__cause__``.
        """
        if endpoint in _global_template_cache:
            logger.debug(f"Using cached template for endpoint: {endpoint}")
            return _global_template_cache[endpoint]

        if not self.base_url:
            raise ValueError("Prompt service base URL is not configured")

        # Named request_url so the imported config value ``url`` is not shadowed.
        request_url = f"{self.base_url}/timbr/api/{endpoint}"
        headers = self._get_headers()

        try:
            response = requests.post(
                request_url,
                headers=headers,
                timeout=self.timeout
            )
            response.raise_for_status()

            api_response = response.json()
            chat_prompt = self._convert_template_response_to_chat_prompt(api_response)

            _global_template_cache[endpoint] = chat_prompt
            logger.debug(f"Cached template for endpoint: {endpoint}")

            return chat_prompt

        # This handler must precede RequestException: in recent requests
        # releases the error raised by response.json() subclasses both, and
        # the RequestException branch would otherwise report it as a
        # request failure.
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON response from prompt service: {str(e)}")
            raise Exception(f"Invalid response from prompt service: {str(e)}") from e
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to get template from service {request_url}: {str(e)}")
            raise Exception(f"Prompt service request failed: {str(e)}") from e
        except Exception as e:
            logger.error(f"Error processing prompt service response: {str(e)}")
            raise Exception(f"Error processing prompt service response: {str(e)}") from e

    def get_identify_concept_template(self) -> ChatPromptTemplate:
        """
        Get the identify-concept template from the API service (cached).

        Returns:
            ChatPromptTemplate object.
        """
        return self._fetch_template("llm_prompts/identify_concept")

    def get_generate_sql_template(self) -> ChatPromptTemplate:
        """
        Get the generate-SQL template from the API service (cached).

        Returns:
            ChatPromptTemplate object.
        """
        return self._fetch_template("llm_prompts/generate_sql")

    def get_generate_answer_template(self) -> ChatPromptTemplate:
        """
        Get the generate-answer template from the API service (cached).

        Returns:
            ChatPromptTemplate object.
        """
        return self._fetch_template("llm_prompts/generate_answer")

    def clear_cache(self):
        """Clear the global template cache (affects all instances)."""
        _global_template_cache.clear()
        logger.info("Global prompt template cache cleared")
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class PromptTemplateWrapper:
    """
    Lazy stand-in for a ChatPromptTemplate.

    The wrapper holds a prompt service and the name of one of its getter
    methods. The getter is invoked on the first ``format_messages`` call and
    its result is kept on the wrapper for all later calls.
    """

    def __init__(self, prompt_service: PromptService, template_method: str):
        # Nothing is fetched here; the template is resolved on first use.
        self._cached_template = None
        self.template_method = template_method
        self.prompt_service = prompt_service

    def format_messages(self, **kwargs) -> List:
        """
        Format the wrapped template with the given parameters.

        Args:
            **kwargs: Parameters for the prompt template

        Returns:
            Whatever the underlying template's ``format_messages`` returns
        """
        template = self._cached_template
        if template is None:
            # Look the getter up by name on the service and remember its result.
            fetch = getattr(self.prompt_service, self.template_method)
            template = fetch()
            self._cached_template = template

        return template.format_messages(**kwargs)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# Individual prompt template getter functions
|
|
222
|
+
def get_determine_concept_prompt_template(
    token: Optional[str] = None,
    is_jwt: Optional[bool] = None,
    jwt_tenant_id: Optional[str] = None
) -> PromptTemplateWrapper:
    """
    Build the prompt template wrapper used to determine a concept.

    A new PromptService is created from the given credentials (passed through
    unchanged) and bound to its ``get_identify_concept_template`` getter.

    Args:
        token: Authentication token
        is_jwt: Whether the token is a JWT
        jwt_tenant_id: JWT tenant ID

    Returns:
        PromptTemplateWrapper for determine concept
    """
    credentials = {
        "token": token,
        "is_jwt": is_jwt,
        "jwt_tenant_id": jwt_tenant_id,
    }
    return PromptTemplateWrapper(PromptService(**credentials), "get_identify_concept_template")
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def get_generate_sql_prompt_template(
    token: Optional[str] = None,
    is_jwt: Optional[bool] = None,
    jwt_tenant_id: Optional[str] = None
) -> PromptTemplateWrapper:
    """
    Build the prompt template wrapper used to generate SQL.

    A new PromptService is created from the given credentials (passed through
    unchanged) and bound to its ``get_generate_sql_template`` getter.

    Args:
        token: Authentication token
        is_jwt: Whether the token is a JWT
        jwt_tenant_id: JWT tenant ID

    Returns:
        PromptTemplateWrapper for generate SQL
    """
    credentials = {
        "token": token,
        "is_jwt": is_jwt,
        "jwt_tenant_id": jwt_tenant_id,
    }
    return PromptTemplateWrapper(PromptService(**credentials), "get_generate_sql_template")
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def get_qa_prompt_template(
    token: Optional[str] = None,
    is_jwt: Optional[bool] = None,
    jwt_tenant_id: Optional[str] = None
) -> PromptTemplateWrapper:
    """
    Build the prompt template wrapper used for question answering.

    A new PromptService is created from the given credentials (passed through
    unchanged) and bound to its ``get_generate_answer_template`` getter.

    Args:
        token: Authentication token
        is_jwt: Whether the token is a JWT
        jwt_tenant_id: JWT tenant ID

    Returns:
        PromptTemplateWrapper for QA
    """
    credentials = {
        "token": token,
        "is_jwt": is_jwt,
        "jwt_tenant_id": jwt_tenant_id,
    }
    return PromptTemplateWrapper(PromptService(**credentials), "get_generate_answer_template")
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
# Prompt service factory (builds a new PromptService on every call)
|
|
295
|
+
def get_prompt_service(
    token: Optional[str] = None,
    is_jwt: Optional[bool] = None,
    jwt_tenant_id: Optional[str] = None
) -> PromptService:
    """
    Create a prompt service instance.

    A new PromptService is constructed on every call; no instance is stored
    or reused by this function.

    Args:
        token: Authentication token (API key or JWT token)
        is_jwt: Whether the token is a JWT
        jwt_tenant_id: JWT tenant ID

    Returns:
        PromptService instance
    """
    return PromptService(
        token=token,
        is_jwt=is_jwt,
        jwt_tenant_id=jwt_tenant_id
    )
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
# Global cache management functions
|
|
319
|
+
def clear_global_template_cache() -> None:
    """
    Empty the module-level template cache and log that it was cleared.

    The cache dict is emptied in place, so existing references to it stay valid.
    """
    _global_template_cache.clear()
    logger.info("Global prompt template cache cleared")
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def get_cache_status() -> Dict[str, Any]:
    """
    Report the current contents of the global template cache.

    Returns:
        Dict with ``cached_endpoints`` (list of cached endpoint keys) and
        ``cache_size`` (number of cached entries).
    """
    endpoints = list(_global_template_cache)
    status = {
        "cached_endpoints": endpoints,
        "cache_size": len(endpoints),
    }
    return status
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
{
|
|
2
|
+
"OpenAI": [
|
|
3
|
+
"gpt-4",
|
|
4
|
+
"gpt-4-turbo",
|
|
5
|
+
"gpt-4o"
|
|
6
|
+
],
|
|
7
|
+
"Anthropic": [
|
|
8
|
+
"claude-opus-4-20250514",
|
|
9
|
+
"claude-sonnet-4-20250514",
|
|
10
|
+
"claude-3-7-sonnet-20250219",
|
|
11
|
+
"claude-3-5-sonnet-20241022",
|
|
12
|
+
"claude-3-5-haiku-20241022",
|
|
13
|
+
"claude-3-5-sonnet-20240620",
|
|
14
|
+
"claude-3-haiku-20240307",
|
|
15
|
+
"claude-3-opus-20240229",
|
|
16
|
+
"claude-3-sonnet-20240229",
|
|
17
|
+
"claude-2.1",
|
|
18
|
+
"claude-2.0"
|
|
19
|
+
],
|
|
20
|
+
"Google": [
|
|
21
|
+
"gemini-1.5-flash-latest",
|
|
22
|
+
"gemini-1.5-flash",
|
|
23
|
+
"gemini-1.5-flash-002",
|
|
24
|
+
"gemini-1.5-flash-8b",
|
|
25
|
+
"gemini-1.5-flash-8b-001",
|
|
26
|
+
"gemini-1.5-flash-8b-latest",
|
|
27
|
+
"gemini-2.5-flash-preview-04-17",
|
|
28
|
+
"gemini-2.5-flash-preview-05-20",
|
|
29
|
+
"gemini-2.5-flash",
|
|
30
|
+
"gemini-2.5-flash-preview-04-17-thinking",
|
|
31
|
+
"gemini-2.5-flash-lite-preview-06-17",
|
|
32
|
+
"gemini-2.5-pro",
|
|
33
|
+
"gemini-2.0-flash-exp",
|
|
34
|
+
"gemini-2.0-flash",
|
|
35
|
+
"gemini-2.0-flash-001",
|
|
36
|
+
"gemini-2.0-flash-exp-image-generation",
|
|
37
|
+
"gemini-2.0-flash-lite-001",
|
|
38
|
+
"gemini-2.0-flash-lite",
|
|
39
|
+
"gemini-2.0-flash-lite-preview-02-05",
|
|
40
|
+
"gemini-2.0-flash-lite-preview",
|
|
41
|
+
"gemini-2.0-flash-thinking-exp-01-21",
|
|
42
|
+
"gemini-2.0-flash-thinking-exp",
|
|
43
|
+
"gemini-2.0-flash-thinking-exp-1219",
|
|
44
|
+
"learnlm-2.0-flash-experimental",
|
|
45
|
+
"gemma-3-1b-it",
|
|
46
|
+
"gemma-3-4b-it",
|
|
47
|
+
"gemma-3-12b-it",
|
|
48
|
+
"gemma-3-27b-it",
|
|
49
|
+
"gemma-3n-e4b-it",
|
|
50
|
+
"gemma-3n-e2b-it"
|
|
51
|
+
],
|
|
52
|
+
"AzureOpenAI": [
|
|
53
|
+
"gpt-4o"
|
|
54
|
+
],
|
|
55
|
+
"Snowflake": [
|
|
56
|
+
"openai-gpt-4.1",
|
|
57
|
+
"mistral-large2",
|
|
58
|
+
"llama3.1-70b",
|
|
59
|
+
"llama3.1-405b"
|
|
60
|
+
],
|
|
61
|
+
"Timbr": []
|
|
62
|
+
}
|