langchain-timbr 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,330 @@
+ import requests
+ from typing import Dict, Any, Optional, List, Union
+ from langchain.schema import SystemMessage, HumanMessage
+ from langchain.prompts.chat import ChatPromptTemplate
+ from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate
+ import json
+ import logging
+
+ from ..config import url, token as default_token, is_jwt, jwt_tenant_id as default_jwt_tenant_id, llm_timeout
+
+ logger = logging.getLogger(__name__)
+
+ # Global template cache shared across all PromptService instances
+ _global_template_cache = {}
+
+ class PromptService:
+     def __init__(
+         self,
+         base_url: Optional[str] = url,
+         token: Optional[str] = default_token,
+         is_jwt: Optional[bool] = is_jwt,
+         jwt_tenant_id: Optional[str] = default_jwt_tenant_id,
+         timeout: Optional[int] = llm_timeout,
+     ):
+         self.base_url = base_url.rstrip('/')
+         self.token = token
+         self.is_jwt = is_jwt
+         self.jwt_tenant_id = jwt_tenant_id
+         self.timeout = timeout
+
+
+     def _get_headers(self) -> Dict[str, str]:
+         """Get headers for API requests"""
+         headers = {"Content-Type": "application/json"}
+
+         if self.is_jwt:
+             headers["x-jwt-token"] = self.token
+             if self.jwt_tenant_id:
+                 headers["x-jwt-tenant-id"] = self.jwt_tenant_id
+         elif self.token:
+             headers["x-api-key"] = self.token
+
+         return headers
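
The two auth modes produce different headers. A quick sketch of the expected output; both tokens and the tenant ID below are placeholders, not real credentials:

PromptService(token="my-api-key", is_jwt=False)._get_headers()
# -> {"Content-Type": "application/json", "x-api-key": "my-api-key"}

PromptService(token="eyJhbGciOi...", is_jwt=True, jwt_tenant_id="tenant-123")._get_headers()
# -> {"Content-Type": "application/json",
#     "x-jwt-token": "eyJhbGciOi...",
#     "x-jwt-tenant-id": "tenant-123"}
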
+
+
+     def _convert_template_response_to_chat_prompt(self, api_response: Union[List, Dict[str, Any]]) -> ChatPromptTemplate:
+         """
+         Convert API response containing template strings to ChatPromptTemplate
+
+         Expected API response format:
+         {
+             "status": "success",
+             "data": [
+                 {"type": "SystemMessage", "template": "You are a helpful SQL expert...", "role": "system"},
+                 {"type": "HumanMessage", "template": "BUSINESS QUESTION: {question}...", "role": "human"}
+             ]
+         }
+         """
+         # Handle response with status and data fields
+         if isinstance(api_response, dict) and "status" in api_response:
+             if api_response["status"] == "error":
+                 error_msg = api_response.get("data", "Unknown error from prompt service")
+                 raise Exception(f"Prompt service error: {error_msg}")
+             elif api_response["status"] == "success" and "data" in api_response:
+                 # Process the data array
+                 data = api_response["data"]
+                 if isinstance(data, list):
+                     return self._parse_template_array(data)
+                 else:
+                     raise ValueError("Expected 'data' to be an array of template objects")
+             else:
+                 raise ValueError(f"Invalid API response: unexpected status '{api_response['status']}' or missing 'data' field")
+
+         raise ValueError("Invalid API response format: expected object with 'status' and 'data' fields")
+
+
+     def _parse_template_array(self, templates: List[Dict[str, Any]]) -> ChatPromptTemplate:
+         """
+         Parse an array of template dictionaries into a ChatPromptTemplate
+
+         Args:
+             templates: List of template dictionaries
+
+         Returns:
+             ChatPromptTemplate object
+         """
+         message_templates = []
+         for tmpl in templates:
+             if not isinstance(tmpl, dict) or "template" not in tmpl:
+                 continue
+
+             # Check type field (SystemMessage/HumanMessage) or role field (system/human)
+             msg_type = tmpl.get("type", "").lower()
+             msg_role = tmpl.get("role", "").lower()
+             template_str = tmpl["template"]
+
+             if msg_type == "systemmessage" or msg_role == "system":
+                 message_templates.append(SystemMessagePromptTemplate.from_template(template_str))
+             elif msg_type == "humanmessage" or msg_role == "human":
+                 message_templates.append(HumanMessagePromptTemplate.from_template(template_str))
+             else:
+                 # Default to HumanMessage for unknown types
+                 message_templates.append(HumanMessagePromptTemplate.from_template(template_str))
+
+         return ChatPromptTemplate.from_messages(message_templates)
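
For illustration, a minimal standalone sketch of how a response in the documented shape becomes a usable prompt. The sample response data is hypothetical; the parsing mirrors _parse_template_array above:

from langchain.prompts.chat import ChatPromptTemplate
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate

# Hypothetical response in the shape documented above
api_response = {
    "status": "success",
    "data": [
        {"type": "SystemMessage", "role": "system", "template": "You are a helpful SQL expert."},
        {"type": "HumanMessage", "role": "human", "template": "BUSINESS QUESTION: {question}"},
    ],
}

# Mirrors _parse_template_array: system entries become system prompts, everything else human
messages = []
for tmpl in api_response["data"]:
    if tmpl.get("role") == "system":
        messages.append(SystemMessagePromptTemplate.from_template(tmpl["template"]))
    else:
        messages.append(HumanMessagePromptTemplate.from_template(tmpl["template"]))

prompt = ChatPromptTemplate.from_messages(messages)
print(prompt.format_messages(question="How many orders shipped last week?"))
# -> [SystemMessage(...), HumanMessage(content='BUSINESS QUESTION: How many orders shipped last week?')]
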
+
+
+     def _fetch_template(self, endpoint: str) -> ChatPromptTemplate:
+         """
+         Fetch a template from the API service (no request payload); results are cached globally
+
+         Args:
+             endpoint: The API endpoint to call
+
+         Returns:
+             ChatPromptTemplate object
+         """
+         # Check global cache first
+         if endpoint in _global_template_cache:
+             logger.debug(f"Using cached template for endpoint: {endpoint}")
+             return _global_template_cache[endpoint]
+
+         url = f"{self.base_url}/timbr/api/{endpoint}"
+         headers = self._get_headers()
+
+         try:
+             response = requests.post(
+                 url,
+                 headers=headers,
+                 timeout=self.timeout
+             )
+             response.raise_for_status()
+
+             api_response = response.json()
+             chat_prompt = self._convert_template_response_to_chat_prompt(api_response)
+
+             # Cache the template globally
+             _global_template_cache[endpoint] = chat_prompt
+             logger.debug(f"Cached template for endpoint: {endpoint}")
+
+             return chat_prompt
+
+         except requests.exceptions.RequestException as e:
+             logger.error(f"Failed to get template from service {url}: {str(e)}")
+             raise Exception(f"Prompt service request failed: {str(e)}")
+         except json.JSONDecodeError as e:
+             logger.error(f"Invalid JSON response from prompt service: {str(e)}")
+             raise Exception(f"Invalid response from prompt service: {str(e)}")
+         except Exception as e:
+             logger.error(f"Error processing prompt service response: {str(e)}")
+             raise Exception(f"Error processing prompt service response: {str(e)}")
+
+
+     def get_identify_concept_template(self) -> ChatPromptTemplate:
+         """
+         Get identify concept template from API service (cached)
+
+         Returns:
+             ChatPromptTemplate object
+         """
+         return self._fetch_template("llm_prompts/identify_concept")
+
+
+     def get_generate_sql_template(self) -> ChatPromptTemplate:
+         """
+         Get generate SQL template from API service (cached)
+
+         Returns:
+             ChatPromptTemplate object
+         """
+         return self._fetch_template("llm_prompts/generate_sql")
+
+
+     def get_generate_answer_template(self) -> ChatPromptTemplate:
+         """
+         Get generate answer template from API service (cached)
+
+         Returns:
+             ChatPromptTemplate object
+         """
+         return self._fetch_template("llm_prompts/generate_answer")
+
+
+     def clear_cache(self):
+         """Clear the global template cache"""
+         _global_template_cache.clear()
+         logger.info("Global prompt template cache cleared")
+
+
+ class PromptTemplateWrapper:
+     """
+     Wrapper class that mimics the original ChatPromptTemplate behavior
+     but uses cached templates from the external API service
+     """
+
+     def __init__(self, prompt_service: PromptService, template_method: str):
+         self.prompt_service = prompt_service
+         self.template_method = template_method
+         self._cached_template = None
+
+
+     def format_messages(self, **kwargs) -> List:
+         """
+         Format messages using the cached template
+
+         Args:
+             **kwargs: Parameters for the prompt template
+
+         Returns:
+             List of LangChain message objects
+         """
+         # Get the cached template
+         if self._cached_template is None:
+             method = getattr(self.prompt_service, self.template_method)
+             self._cached_template = method()
+
+         # Format the template with the provided kwargs
+         return self._cached_template.format_messages(**kwargs)
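
The wrapper defers the network fetch until the first format_messages call. A minimal sketch, assuming the fetched template exposes a {question} variable as in the documented sample response; credentials are placeholders:

service = PromptService(base_url="https://demo.timbr.ai", token="my-api-key")  # placeholders
wrapper = PromptTemplateWrapper(service, "get_generate_sql_template")

# No HTTP request has happened yet; the template is fetched lazily on first use
messages = wrapper.format_messages(question="Total revenue by region in 2024?")

# Subsequent calls reuse self._cached_template and only re-render the messages
messages = wrapper.format_messages(question="Top ten customers by order count?")
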
+
+
+ # Individual prompt template getter functions
+ def get_determine_concept_prompt_template(
+     token: Optional[str] = None,
+     is_jwt: Optional[bool] = None,
+     jwt_tenant_id: Optional[str] = None
+ ) -> PromptTemplateWrapper:
+     """
+     Get determine concept prompt template wrapper
+
+     Args:
+         token: Authentication token
+         is_jwt: Whether the token is a JWT
+         jwt_tenant_id: JWT tenant ID
+
+     Returns:
+         PromptTemplateWrapper for determine concept
+     """
+     prompt_service = PromptService(
+         token=token,
+         is_jwt=is_jwt,
+         jwt_tenant_id=jwt_tenant_id
+     )
+     return PromptTemplateWrapper(prompt_service, "get_identify_concept_template")
+
+
+ def get_generate_sql_prompt_template(
+     token: Optional[str] = None,
+     is_jwt: Optional[bool] = None,
+     jwt_tenant_id: Optional[str] = None
+ ) -> PromptTemplateWrapper:
+     """
+     Get generate SQL prompt template wrapper
+
+     Args:
+         token: Authentication token
+         is_jwt: Whether the token is a JWT
+         jwt_tenant_id: JWT tenant ID
+
+     Returns:
+         PromptTemplateWrapper for generate SQL
+     """
+     prompt_service = PromptService(
+         token=token,
+         is_jwt=is_jwt,
+         jwt_tenant_id=jwt_tenant_id
+     )
+     return PromptTemplateWrapper(prompt_service, "get_generate_sql_template")
+
+
+ def get_qa_prompt_template(
+     token: Optional[str] = None,
+     is_jwt: Optional[bool] = None,
+     jwt_tenant_id: Optional[str] = None
+ ) -> PromptTemplateWrapper:
+     """
+     Get QA prompt template wrapper
+
+     Args:
+         token: Authentication token
+         is_jwt: Whether the token is a JWT
+         jwt_tenant_id: JWT tenant ID
+
+     Returns:
+         PromptTemplateWrapper for QA
+     """
+     prompt_service = PromptService(
+         token=token,
+         is_jwt=is_jwt,
+         jwt_tenant_id=jwt_tenant_id
+     )
+     return PromptTemplateWrapper(prompt_service, "get_generate_answer_template")
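
These getters are convenience constructors. A sketch of JWT-mode usage; the token and tenant ID are placeholders, and the {question} variable is assumed from the documented sample template:

qa_prompt = get_qa_prompt_template(
    token="eyJhbGciOi...",       # placeholder JWT, sent as the x-jwt-token header
    is_jwt=True,
    jwt_tenant_id="tenant-123",  # placeholder, sent as the x-jwt-tenant-id header
)
messages = qa_prompt.format_messages(question="Which products are low on stock?")
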
+
+
+ # Factory for prompt service instances
+ def get_prompt_service(
+     token: Optional[str] = None,
+     is_jwt: Optional[bool] = None,
+     jwt_tenant_id: Optional[str] = None
+ ) -> PromptService:
+     """
+     Create a prompt service instance
+
+     Args:
+         token: Authentication token (API key or JWT token)
+         is_jwt: Whether the token is a JWT
+         jwt_tenant_id: JWT tenant ID
+
+     Returns:
+         PromptService instance
+     """
+     return PromptService(
+         token=token,
+         is_jwt=is_jwt,
+         jwt_tenant_id=jwt_tenant_id
+     )
+
+
+ # Global cache management functions
+ def clear_global_template_cache():
+     """Clear the global template cache"""
+     _global_template_cache.clear()
+     logger.info("Global prompt template cache cleared")
+
+
+ def get_cache_status():
+     """Get information about the global template cache"""
+     return {
+         "cached_endpoints": list(_global_template_cache.keys()),
+         "cache_size": len(_global_template_cache)
+     }
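
Cache state can be inspected and reset at module level, e.g.:

status = get_cache_status()
print(status)
# e.g. {'cached_endpoints': ['llm_prompts/generate_sql'], 'cache_size': 1}

clear_global_template_cache()
assert get_cache_status()["cache_size"] == 0
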
@@ -0,0 +1,62 @@
+ {
+     "OpenAI": [
+         "gpt-4",
+         "gpt-4-turbo",
+         "gpt-4o"
+     ],
+     "Anthropic": [
+         "claude-opus-4-20250514",
+         "claude-sonnet-4-20250514",
+         "claude-3-7-sonnet-20250219",
+         "claude-3-5-sonnet-20241022",
+         "claude-3-5-haiku-20241022",
+         "claude-3-5-sonnet-20240620",
+         "claude-3-haiku-20240307",
+         "claude-3-opus-20240229",
+         "claude-3-sonnet-20240229",
+         "claude-2.1",
+         "claude-2.0"
+     ],
+     "Google": [
+         "gemini-1.5-flash-latest",
+         "gemini-1.5-flash",
+         "gemini-1.5-flash-002",
+         "gemini-1.5-flash-8b",
+         "gemini-1.5-flash-8b-001",
+         "gemini-1.5-flash-8b-latest",
+         "gemini-2.5-flash-preview-04-17",
+         "gemini-2.5-flash-preview-05-20",
+         "gemini-2.5-flash",
+         "gemini-2.5-flash-preview-04-17-thinking",
+         "gemini-2.5-flash-lite-preview-06-17",
+         "gemini-2.5-pro",
+         "gemini-2.0-flash-exp",
+         "gemini-2.0-flash",
+         "gemini-2.0-flash-001",
+         "gemini-2.0-flash-exp-image-generation",
+         "gemini-2.0-flash-lite-001",
+         "gemini-2.0-flash-lite",
+         "gemini-2.0-flash-lite-preview-02-05",
+         "gemini-2.0-flash-lite-preview",
+         "gemini-2.0-flash-thinking-exp-01-21",
+         "gemini-2.0-flash-thinking-exp",
+         "gemini-2.0-flash-thinking-exp-1219",
+         "learnlm-2.0-flash-experimental",
+         "gemma-3-1b-it",
+         "gemma-3-4b-it",
+         "gemma-3-12b-it",
+         "gemma-3-27b-it",
+         "gemma-3n-e4b-it",
+         "gemma-3n-e2b-it"
+     ],
+     "AzureOpenAI": [
+         "gpt-4o"
+     ],
+     "Snowflake": [
+         "openai-gpt-4.1",
+         "mistral-large2",
+         "llama3.1-70b",
+         "llama3.1-405b"
+     ],
+     "Timbr": []
+ }
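
This second file is a provider-to-models map. A sketch of how such a file might be consumed to validate a model choice; the filename supported_models.json is an assumption, since the hunk does not show the file's path:

import json

with open("supported_models.json") as f:  # assumed filename
    supported = json.load(f)

def is_supported(provider: str, model: str) -> bool:
    """Return True if the model is listed for the given provider."""
    return model in supported.get(provider, [])

print(is_supported("Anthropic", "claude-3-5-sonnet-20241022"))  # True
print(is_supported("OpenAI", "gpt-2"))                          # False
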