langchain-timbr 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,189 @@
+ from enum import Enum
+ from langchain.llms.base import LLM
+ from pydantic import Field
+
+ from .timbr_llm_wrapper import TimbrLlmWrapper
+ from ..utils.general import is_llm_type, is_support_temperature
+ from ..config import llm_temperature
+
+ class LlmTypes(Enum):
+     OpenAI = 'openai-chat'
+     Anthropic = 'anthropic-chat'
+     Google = 'chat-google-generative-ai'
+     AzureOpenAI = 'azure-openai-chat'
+     Snowflake = 'snowflake-cortex'
+     Timbr = 'timbr'
+
+
+ class LlmWrapper(LLM):
+     """
+     LlmWrapper is a unified interface for connecting to various Large Language Model (LLM) providers
+     (OpenAI, Anthropic, Google, Azure OpenAI, Snowflake Cortex, etc.) using LangChain. It abstracts
+     the initialization and connection logic for each provider, allowing you to switch between them
+     with a consistent API.
+     """
+     client: LLM = Field(default=None, exclude=True)
+
+     def __init__(
+         self,
+         llm_type: str,
+         api_key: str,
+         model: str = None,
+         **llm_params,
+     ):
+         """
+         :param llm_type (str): The type of LLM provider (e.g., 'openai-chat', 'anthropic-chat').
+         :param api_key (str): The API key for authenticating with the LLM provider.
+         :param model (str): The model name or deployment to use. Defaults to provider-specific values (Optional).
+         :param **llm_params: Additional parameters for the LLM (e.g., temperature, endpoint, etc.).
+         """
+         super().__init__()
+         self.client = self._connect_to_llm(
+             llm_type,
+             api_key,
+             model,
+             **llm_params,
+         )
+
+
+     @property
+     def _llm_type(self):
+         return self.client._llm_type
+
+
+     def _add_temperature(self, llm_type, llm_model, **llm_params):
+         """
+         Add temperature to the LLM parameters if the LLM model supports it.
+         """
+         if "temperature" not in llm_params:
+             if llm_temperature is not None and is_support_temperature(llm_type, llm_model):
+                 llm_params["temperature"] = llm_temperature
+         return llm_params
+
+
+     def _connect_to_llm(self, llm_type, api_key, model, **llm_params):
+         if is_llm_type(llm_type, LlmTypes.OpenAI):
+             from langchain_openai import ChatOpenAI as OpenAI
+             llm_model = model or "gpt-4o-2024-11-20"
+             params = self._add_temperature(LlmTypes.OpenAI.name, llm_model, **llm_params)
+             return OpenAI(
+                 openai_api_key=api_key,
+                 model_name=llm_model,
+                 **params,
+             )
+         elif is_llm_type(llm_type, LlmTypes.Anthropic):
+             from langchain_anthropic import ChatAnthropic as Claude
+             llm_model = model or "claude-3-5-sonnet-20241022"
+             params = self._add_temperature(LlmTypes.Anthropic.name, llm_model, **llm_params)
+             return Claude(
+                 anthropic_api_key=api_key,
+                 model=llm_model,
+                 **params,
+             )
+         elif is_llm_type(llm_type, LlmTypes.Google):
+             from langchain_google_genai import ChatGoogleGenerativeAI
+             llm_model = model or "gemini-2.0-flash-exp"
+             params = self._add_temperature(LlmTypes.Google.name, llm_model, **llm_params)
+             return ChatGoogleGenerativeAI(
+                 google_api_key=api_key,
+                 model=llm_model,
+                 **params,
+             )
+         elif is_llm_type(llm_type, LlmTypes.Timbr):
+             return TimbrLlmWrapper(
+                 api_key=api_key,
+                 **llm_params,
+             )
+         elif is_llm_type(llm_type, LlmTypes.Snowflake):
+             from langchain_community.chat_models import ChatSnowflakeCortex
+             llm_model = model or "openai-gpt-4.1"
+             params = self._add_temperature(LlmTypes.Snowflake.name, llm_model, **llm_params)
+
+             return ChatSnowflakeCortex(
+                 model=llm_model,
+                 **params,
+             )
+         elif is_llm_type(llm_type, LlmTypes.AzureOpenAI):
+             from langchain_openai import AzureChatOpenAI
+             # Pop Azure-specific connection options so they are not forwarded twice
+             azure_endpoint = llm_params.pop('azure_endpoint', None)
+             azure_api_version = llm_params.pop('azure_openai_api_version', None)
+             llm_model = model or "gpt-4o-2024-11-20"
+             params = self._add_temperature(LlmTypes.AzureOpenAI.name, llm_model, **llm_params)
+             return AzureChatOpenAI(
+                 openai_api_key=api_key,
+                 azure_deployment=llm_model,
+                 azure_endpoint=azure_endpoint,
+                 openai_api_version=azure_api_version,
+                 **params,
+             )
+         else:
+             raise ValueError(f"Unsupported LLM type: {llm_type}")
+
+
+     def get_model_list(self) -> list[str]:
+         """Return the list of available models for the LLM."""
+         models = []
+         try:
+             if is_llm_type(self._llm_type, LlmTypes.OpenAI):
+                 from openai import OpenAI
+                 client = OpenAI(api_key=self.client.openai_api_key._secret_value)
+                 models = [model.id for model in client.models.list()]
+             elif is_llm_type(self._llm_type, LlmTypes.Anthropic):
+                 import anthropic
+                 client = anthropic.Anthropic(api_key=self.client.anthropic_api_key._secret_value)
+                 models = [model.id for model in client.models.list()]
+             elif is_llm_type(self._llm_type, LlmTypes.Google):
+                 import google.generativeai as genai
+                 genai.configure(api_key=self.client.google_api_key._secret_value)
+                 models = [m.name.replace('models/', '') for m in genai.list_models()]
+             elif is_llm_type(self._llm_type, LlmTypes.AzureOpenAI):
+                 from openai import AzureOpenAI
+                 # Get Azure-specific attributes from the client
+                 azure_endpoint = getattr(self.client, 'azure_endpoint', None)
+                 api_version = getattr(self.client, 'openai_api_version', None)
+                 api_key = self.client.openai_api_key._secret_value
+
+                 if azure_endpoint and api_version and api_key:
+                     client = AzureOpenAI(
+                         api_key=api_key,
+                         azure_endpoint=azure_endpoint,
+                         api_version=api_version,
+                     )
+                     # For Azure, get the deployments instead of models
+                     try:
+                         models = [model.id for model in client.models.list()]
+                     except Exception:
+                         # If listing models fails, provide some common deployment names
+                         models = ["gpt-4o", "Other (Custom)"]
+             elif is_llm_type(self._llm_type, LlmTypes.Snowflake):
+                 # Snowflake Cortex available models
+                 models = [
+                     "openai-gpt-4.1",
+                     "mistral-large2",
+                     "llama3.1-70b",
+                     "llama3.1-405b",
+                 ]
+             # elif self._is_llm_type(self._llm_type, LlmTypes.Timbr):
+
+         except Exception:
+             models = []
+
+         return models
+
+
+     def _call(self, prompt, **kwargs):
+         return self.client(prompt, **kwargs)
+
+
+     def __call__(self, prompt, **kwargs):
+         """
+         Override the default __call__ method to handle input preprocessing.
+         This bypasses the prompt input validation performed by pydantic,
+         allowing a list of AIMessages to be passed instead of only a string.
+         """
+         return self._call(prompt, **kwargs)
+
+
+     def query(self, prompt, **kwargs):
+         return self._call(prompt, **kwargs)
+
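Taken as a whole, `_connect_to_llm` resolves the provider from `llm_type` (matched loosely by `is_llm_type`), falls back to a provider-specific default model, injects the configured default temperature only when the model supports it, and imports the provider package lazily. A minimal usage sketch, assuming the wheel exposes `LlmWrapper` and `LlmTypes` from `langchain_timbr`; the API key is a placeholder:

```python
# Minimal sketch, assuming these names are importable from langchain_timbr;
# the key is a placeholder credential.
from langchain_timbr import LlmWrapper, LlmTypes

llm = LlmWrapper(
    llm_type=LlmTypes.OpenAI,      # an enum member or a string like 'openai-chat' both match
    api_key="sk-...",              # placeholder
    model="gpt-4o-2024-11-20",     # optional; this is the OpenAI default used above
    temperature=0,                 # forwarded via **llm_params
)

print(llm.get_model_list())        # model ids reachable with the configured key
print(llm.query("Reply with one word: hello"))
```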
@@ -0,0 +1,41 @@
+ from langchain.llms.base import LLM
+ import requests
+ from typing import Optional, List
+
+ class TimbrLlmWrapper(LLM):
+     url: str
+     api_key: str
+     temperature: Optional[float] = 0
+
+     def __init__(self, url: str, api_key: str, temperature: Optional[float] = 0):
+         """
+         ***TBD, Not ready yet.***
+
+         Custom LLM implementation for timbr LLM wrapped with a proxy server.
+
+         :param url: URL of the proxy server wrapping timbr LLM.
+         :param api_key: API key for authentication with the proxy server.
+         :param temperature: Sampling temperature for the model.
+         """
+         # LLM is a pydantic model, so the fields declared above must be set
+         # through the base initializer rather than assigned directly.
+         super().__init__(url=url, api_key=api_key, temperature=temperature)
+
+     @property
+     def _llm_type(self) -> str:
+         return "timbr"
+
+
+     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+         """
+         Sends the prompt to the proxy server and returns the response.
+         """
+         headers = { "Authorization": f"Bearer {self.api_key}" }
+         payload = {
+             "prompt": prompt,
+             "temperature": self.temperature,
+             "stop": stop,
+         }
+
+         response = requests.post(self.url, json=payload, headers=headers)
+         if response.status_code == 200:
+             return response.json().get("response", "")
+         else:
+             raise ValueError(f"Error communicating with timbr proxy: {response.text}")
+
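`TimbrLlmWrapper._call` implies a simple proxy contract: a POST carrying a bearer token and a JSON body with `prompt`, `temperature`, and `stop`, answered by a JSON object whose `response` field holds the completion. A hypothetical round trip (the class is explicitly marked TBD, and the endpoint and key below are placeholders):

```python
# Hypothetical usage sketch; the endpoint and key are placeholders.
llm = TimbrLlmWrapper(
    url="https://timbr-proxy.example.com/llm",
    api_key="timbr-secret",
    temperature=0,
)

# Sends {"prompt": ..., "temperature": 0, "stop": None} with an
# Authorization: Bearer header; expects {"response": "..."} back.
answer = llm("Summarize the orders table in one sentence.")
```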
@@ -0,0 +1,398 @@
+ from typing import Optional, Any, Literal
+ from typing_extensions import TypedDict
+ from langchain.llms.base import LLM
+ from langgraph.graph import StateGraph, END
+
+ from .utils.general import to_boolean, to_integer
+ from .llm_wrapper.llm_wrapper import LlmWrapper
+ from .utils.timbr_utils import get_ontologies, get_concepts
+ from .langchain import IdentifyTimbrConceptChain, GenerateTimbrSqlChain, ValidateTimbrSqlChain, ExecuteTimbrQueryChain, create_timbr_sql_agent
+ from .langgraph import GenerateTimbrSqlNode, ValidateSemanticSqlNode, ExecuteSemanticQueryNode, GenerateResponseNode
+
+
+ from .config import (
+     url as default_url,
+     token as default_token,
+     ontology as default_ontology,
+     llm_type,
+     llm_model,
+     llm_api_key,
+     llm_temperature,
+ )
+
+ class TimbrLanggraphState(TypedDict):
+     prompt: str
+     sql: str
+     concept: str
+     rows: list
+     response: str
+     error: str
+     is_sql_valid: bool
+     usage_metadata: dict[str, Any]
+
+
+ class TimbrLlmConnector:
+     def __init__(
+         self,
+         llm: LLM,
+         url: Optional[str] = default_url,
+         token: Optional[str] = default_token,
+         ontology: Optional[str] = default_ontology,
+         max_limit: Optional[int] = 500,
+         verify_ssl: Optional[bool] = True,
+         is_jwt: Optional[bool] = False,
+         jwt_tenant_id: Optional[str] = None,
+         conn_params: Optional[dict] = None,
+     ):
+         """
+         :param url: Timbr server URL
+         :param token: Timbr password or token value
+         :param ontology: The name of the ontology/knowledge graph
+         :param llm: An LLM instance or a function that takes a prompt string and returns the LLM’s response
+         :param max_limit: Maximum number of rows to return
+         :param verify_ssl: Whether to verify SSL certificates (default is True).
+         :param is_jwt: Whether to use JWT authentication (default is False).
+         :param jwt_tenant_id: Tenant ID for JWT authentication (if applicable).
+         :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
+
+         ## Example
+         ```
+         timbr_llm_wrapper = LlmWrapper(
+             llm_type=LlmTypes.OpenAI,
+             model="gpt-4o",
+             api_key=<openai_api_key>,
+         )
+
+         llm_connector = TimbrLlmConnector(
+             url=<url>,
+             token=<token>,
+             llm=timbr_llm_wrapper,
+         )
+
+         # Show the list of ontologies available at the connected timbr instance
+         ontologies = llm_connector.get_ontologies()
+
+         # Find which concept & schema will be queried for the user input
+         determine_concept_res = llm_connector.determine_concept(llm_input)
+         query_concept, query_schema = determine_concept_res.get("concept"), determine_concept_res.get("schema")
+
+         # Generate a timbr SQL query from the user input
+         sql_query = llm_connector.generate_sql(llm_input).get("sql")
+
+         # Run the timbr SQL query
+         results = llm_connector.run_timbr_query(sql_query).get("rows", [])
+
+         # Parse & run an LLM question end to end
+         results = llm_connector.run_llm_query(llm_input).get("rows", [])
+         ```
+         """
+         self.url = url
+         self.token = token
+         self.ontology = ontology
+         self.max_limit = to_integer(max_limit)
+         self.verify_ssl = to_boolean(verify_ssl)
+         self.is_jwt = to_boolean(is_jwt)
+         self.jwt_tenant_id = jwt_tenant_id
+         self.conn_params = conn_params or {}
+
+         if llm is not None:
+             self._llm = llm
+         elif llm_type is not None and llm_api_key is not None:
+             llm_params = {}
+             if llm_temperature is not None:
+                 llm_params["temperature"] = llm_temperature
+
+             self._llm = LlmWrapper(
+                 llm_type=llm_type,
+                 api_key=llm_api_key,
+                 model=llm_model,
+                 **llm_params,
+             )
+         else:
+             raise ValueError("No LLM provided: pass an `llm` instance or configure `llm_type` and `llm_api_key`.")
+
+
+     # TODO: Make this function a decorator and use it on relevant methods
+     # def _is_ontology_set(self):
+     #     return self.ontology != 'system_db'
+
+
+     def _get_conn_params(self):
+         return {
+             "url": self.url,
+             "token": self.token,
+             "ontology": self.ontology,
+             "verify_ssl": self.verify_ssl,
+             "is_jwt": self.is_jwt,
+             "jwt_tenant_id": self.jwt_tenant_id,
+             **self.conn_params,
+         }
+
+
+     def get_ontologies(self) -> list[str]:
+         return get_ontologies(conn_params=self._get_conn_params())
+
+
+     def get_concepts(self) -> dict:
+         """
+         Get the list of concepts from the Timbr server.
+         """
+         return get_concepts(
+             conn_params=self._get_conn_params(),
+             concepts_list="*",
+         )
+
+
+     def get_views(self) -> dict:
+         """
+         Get the list of views from the Timbr server.
+         """
+         return get_concepts(
+             conn_params=self._get_conn_params(),
+             views_list="*",
+         )
+
+
+     def set_ontology(self, ontology: str):
+         self.ontology = ontology
+
+
+     def determine_concept(
+         self,
+         question: str,
+         concepts_list: Optional[list] = None,
+         views_list: Optional[list] = None,
+         include_logic_concepts: Optional[bool] = False,
+         include_tags: Optional[str] = None,
+         should_validate: Optional[bool] = False,
+         retries: Optional[int] = 3,
+         note: Optional[str] = '',
+         **chain_kwargs: Any,
+     ) -> dict[str, Any]:
+         determine_concept_chain = IdentifyTimbrConceptChain(
+             **self._get_conn_params(),
+             llm=self._llm,
+             concepts_list=concepts_list,
+             views_list=views_list,
+             include_logic_concepts=include_logic_concepts,
+             include_tags=include_tags,
+             should_validate=should_validate,
+             retries=retries,
+             note=note,
+             **chain_kwargs,
+         )
+
+         return determine_concept_chain.invoke({ "prompt": question })
+
+
+     def generate_sql(
+         self,
+         question: str,
+         concept_name: Optional[str] = None,
+         schema: Optional[str] = None,
+         concepts_list: Optional[list] = None,
+         views_list: Optional[list] = None,
+         include_logic_concepts: Optional[bool] = False,
+         include_tags: Optional[str] = None,
+         should_validate_sql: Optional[bool] = False,
+         retries: Optional[int] = 3,
+         note: Optional[str] = '',
+         **chain_kwargs: Any,
+     ) -> dict[str, Any]:
+         generate_timbr_llm_chain = GenerateTimbrSqlChain(
+             llm=self._llm,
+             **self._get_conn_params(),
+             schema=schema,
+             concept=concept_name,
+             concepts_list=concepts_list,
+             views_list=views_list,
+             include_logic_concepts=include_logic_concepts,
+             include_tags=include_tags,
+             should_validate_sql=should_validate_sql,
+             retries=retries,
+             max_limit=self.max_limit,
+             note=note,
+             **chain_kwargs,
+         )
+
+         return generate_timbr_llm_chain.invoke({ "prompt": question })
+
+
+     def validate_sql(
+         self,
+         question: str,
+         sql_query: str,
+         retries: Optional[int] = 3,
+         concepts_list: Optional[list] = None,
+         views_list: Optional[list] = None,
+         include_logic_concepts: Optional[bool] = False,
+         include_tags: Optional[str] = None,
+         note: Optional[str] = '',
+         **chain_kwargs: Any,
+     ) -> dict[str, Any]:
+         validate_timbr_sql_chain = ValidateTimbrSqlChain(
+             llm=self._llm,
+             **self._get_conn_params(),
+             retries=retries,
+             concepts_list=concepts_list,
+             views_list=views_list,
+             include_logic_concepts=include_logic_concepts,
+             include_tags=include_tags,
+             max_limit=self.max_limit,
+             note=note,
+             **chain_kwargs,
+         )
+         return validate_timbr_sql_chain.invoke({ "sql": sql_query, "prompt": question })
+
+
+     def run_timbr_query(
+         self,
+         sql_query: str,
+         concepts_list: Optional[list] = None,
+         views_list: Optional[list] = None,
+         include_logic_concepts: Optional[bool] = False,
+         include_tags: Optional[str] = None,
+         should_validate_sql: Optional[bool] = True,
+         retries: Optional[int] = 3,
+         note: Optional[str] = '',
+         **chain_kwargs: Any,
+     ) -> dict[str, Any]:
+         execute_timbr_query_chain = ExecuteTimbrQueryChain(
+             llm=self._llm,
+             **self._get_conn_params(),
+             concepts_list=concepts_list,
+             views_list=views_list,
+             include_logic_concepts=include_logic_concepts,
+             include_tags=include_tags,
+             should_validate_sql=should_validate_sql,
+             retries=retries,
+             max_limit=self.max_limit,
+             note=note,
+             **chain_kwargs,
+         )
+
+         return execute_timbr_query_chain.invoke({ "sql": sql_query })
+
+
+     def run_llm_query(
+         self,
+         question: str,
+         concepts_list: Optional[list] = None,
+         views_list: Optional[list] = None,
+         include_logic_concepts: Optional[bool] = False,
+         include_tags: Optional[str] = None,
+         should_validate_sql: Optional[bool] = True,
+         retries: Optional[int] = 3,
+         note: Optional[str] = '',
+         **agent_kwargs: Any,
+     ) -> dict[str, Any]:
+         agent = create_timbr_sql_agent(
+             llm=self._llm,
+             **self._get_conn_params(),
+             concept=None,
+             concepts_list=concepts_list,
+             views_list=views_list,
+             include_logic_concepts=include_logic_concepts,
+             include_tags=include_tags,
+             should_validate_sql=should_validate_sql,
+             retries=retries,
+             max_limit=self.max_limit,
+             note=note,
+             **agent_kwargs,
+         )
+
+         return agent.invoke(question)
+
+
+     def run_llm_query_graph(
+         self,
+         question: str,
+         concepts_list: Optional[list] = None,
+         views_list: Optional[list] = None,
+         include_logic_concepts: Optional[bool] = False,
+         include_tags: Optional[str] = None,
+         should_validate_sql: Optional[bool] = True,
+         retries: Optional[int] = 3,
+         note: Optional[str] = '',
+         **nodes_kwargs: Any,
+     ) -> dict[str, Any]:
+         generate_sql_node = GenerateTimbrSqlNode(
+             llm=self._llm,
+             **self._get_conn_params(),
+             concepts_list=concepts_list,
+             views_list=views_list,
+             include_logic_concepts=include_logic_concepts,
+             include_tags=include_tags,
+             max_limit=self.max_limit,
+             note=note,
+             **nodes_kwargs,
+         )
+         validate_sql_node = ValidateSemanticSqlNode(
+             llm=self._llm,
+             **self._get_conn_params(),
+             retries=retries,
+             concepts_list=concepts_list,
+             views_list=views_list,
+             include_logic_concepts=include_logic_concepts,
+             include_tags=include_tags,
+             max_limit=self.max_limit,
+             note=note,
+             **nodes_kwargs,
+         )
+         execute_sql_node = ExecuteSemanticQueryNode(
+             llm=self._llm,
+             **self._get_conn_params(),
+             concepts_list=concepts_list,
+             views_list=views_list,
+             include_logic_concepts=include_logic_concepts,
+             include_tags=include_tags,
+             should_validate_sql=should_validate_sql,
+             retries=retries,
+             max_limit=self.max_limit,
+             note=note,
+             **nodes_kwargs,
+         )
+         generate_response_node = GenerateResponseNode()
+
+         graph_builder = StateGraph(TimbrLanggraphState)
+
+         graph_builder.add_node("generate_sql", generate_sql_node)
+         graph_builder.add_node("validate_sql", validate_sql_node)
+         graph_builder.add_node("execute_sql", execute_sql_node)
+         graph_builder.add_node("generate_response", generate_response_node)
+
+         graph_builder.add_edge("generate_sql", "validate_sql")
+
+         def route_validation(state: dict) -> Literal["execute_sql", "end"]:
+             # If validation is successful, proceed to execute the query.
+             # Otherwise, stop the flow.
+             if state.get("is_sql_valid"):
+                 return "execute_sql"
+             else:
+                 return "end"
+
+         graph_builder.add_conditional_edges(
+             "validate_sql",
+             route_validation,
+             {
+                 "execute_sql": "execute_sql",
+                 "end": END,
+             }
+         )
+
+         graph_builder.add_edge("execute_sql", "generate_response")
+         graph_builder.set_entry_point("generate_sql")
+
+         compiled_graph = graph_builder.compile()
+
+         initial_state = {
+             "prompt": question,
+             "sql": "",
+             "concept": "",
+             "rows": [],
+             "response": "",
+             "error": "",
+             "is_sql_valid": False,
+             "usage_metadata": {}
+         }
+
+         result = compiled_graph.invoke(initial_state)
+         return result
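`run_llm_query_graph` wires the chains into a four-node LangGraph pipeline: `generate_sql` feeds `validate_sql`, a conditional edge continues to `execute_sql` only when `is_sql_valid` is set, and `generate_response` produces the final answer before the full `TimbrLanggraphState` is returned. An end-to-end sketch, assuming `TimbrLlmConnector` is exported from `langchain_timbr` alongside the wrapper; all connection details are placeholders:

```python
# Sketch only; url, token, and ontology values are placeholders.
from langchain_timbr import LlmWrapper, LlmTypes, TimbrLlmConnector  # assumed exports

connector = TimbrLlmConnector(
    llm=LlmWrapper(llm_type=LlmTypes.OpenAI, api_key="sk-..."),
    url="https://my-timbr-host:443",
    token="tk_...",
    ontology="sales_kg",
)

state = connector.run_llm_query_graph("Top 5 customers by total revenue")
if state["is_sql_valid"]:
    print(state["sql"])        # generated semantic SQL
    print(state["rows"])       # query results
    print(state["response"])   # natural-language answer
else:
    print("Validation failed:", state["error"])
```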
@@ -0,0 +1,70 @@
+ import os
+ import json
+
+ ### Global helper functions used across the project
+
+ def parse_list(input_value, separator=',') -> list[str]:
+     try:
+         if isinstance(input_value, str):
+             return [item.strip() for item in input_value.split(separator) if item.strip()]
+         elif isinstance(input_value, list):
+             return [item.strip() for item in input_value if item.strip()]
+         return []
+     except Exception as e:
+         raise ValueError(f"Failed to parse list value: {e}")
+
+
+ def to_boolean(value) -> bool:
+     try:
+         if isinstance(value, str):
+             return value.lower() in ['true', '1']
+         return bool(value)
+     except Exception as e:
+         raise ValueError(f"Failed to parse boolean value: {e}")
+
+
+ def to_integer(value) -> int:
+     try:
+         return int(value)
+     except (ValueError, TypeError) as e:
+         raise ValueError(f"Failed to parse integer value: {e}")
+
+
+ def is_llm_type(llm_type, enum_value):
+     """Check if llm_type equals the enum value or its name, case-insensitive."""
+     if llm_type == enum_value:
+         return True
+
+     if isinstance(llm_type, str):
+         llm_type_lower = llm_type.lower()
+         enum_name_lower = enum_value.name.lower() if enum_value.name else ""
+         enum_value_lower = enum_value.value.lower() if isinstance(enum_value.value, str) else ""
+
+         return (
+             llm_type_lower == enum_name_lower or
+             llm_type_lower == enum_value_lower or
+             llm_type_lower.startswith(enum_name_lower) or  # covers Snowflake, whose type is the provider name plus the model name
+             llm_type_lower.startswith(enum_value_lower)
+         )
+
+     return False
+
+
+ def is_support_temperature(llm_type: str, llm_model: str) -> bool:
+     """
+     Check if the LLM model supports the temperature setting.
+     """
+     current_dir = os.path.dirname(os.path.abspath(__file__))
+     json_file_path = os.path.join(current_dir, 'temperature_supported_models.json')
+
+     try:
+         with open(json_file_path, 'r') as f:
+             temperature_supported_models = json.load(f)
+
+         # Check if llm_type exists and llm_model is in its list
+         if llm_type in temperature_supported_models:
+             return llm_model in temperature_supported_models[llm_type]
+
+         return False
+     except (FileNotFoundError, json.JSONDecodeError, KeyError):
+         return False
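Two contracts worth noting here: `is_llm_type` matches an enum member by identity, name, value, or a prefix of either (the prefix rule exists for Snowflake, whose type string appends the model name), and `is_support_temperature` looks the model up in the bundled `temperature_supported_models.json`, keyed by the enum name that `_connect_to_llm` passes in. A hypothetical fragment showing the expected shape of that file; the model lists are illustrative, not the shipped data:

```python
# Illustrative shape of temperature_supported_models.json, expressed as the
# dict that json.load() would return; keys are LlmTypes member names.
temperature_supported_models = {
    "OpenAI": ["gpt-4o-2024-11-20", "gpt-4o-mini"],
    "Anthropic": ["claude-3-5-sonnet-20241022"],
    "Google": ["gemini-2.0-flash-exp"],
}

# Matching behavior of is_llm_type:
#   is_llm_type("openai-chat", LlmTypes.OpenAI)                      -> True (value match)
#   is_llm_type("OpenAI", LlmTypes.OpenAI)                           -> True (name match)
#   is_llm_type("snowflake-cortex-llama3.1-70b", LlmTypes.Snowflake) -> True (prefix match)
```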