lfx-nightly 0.1.13.dev7__py3-none-any.whl → 0.1.13.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lfx-nightly might be problematic. Click here for more details.
- lfx/_assets/component_index.json +1 -1
- lfx/base/agents/agent.py +35 -2
- lfx/base/models/model_input_constants.py +74 -7
- lfx/base/models/ollama_constants.py +3 -0
- lfx/components/agents/agent.py +37 -2
- lfx/components/agents/mcp_component.py +16 -2
- lfx/components/ibm/watsonx.py +25 -21
- lfx/components/ollama/ollama.py +221 -14
- lfx/schema/schema.py +5 -0
- {lfx_nightly-0.1.13.dev7.dist-info → lfx_nightly-0.1.13.dev9.dist-info}/METADATA +1 -1
- {lfx_nightly-0.1.13.dev7.dist-info → lfx_nightly-0.1.13.dev9.dist-info}/RECORD +13 -13
- {lfx_nightly-0.1.13.dev7.dist-info → lfx_nightly-0.1.13.dev9.dist-info}/WHEEL +0 -0
- {lfx_nightly-0.1.13.dev7.dist-info → lfx_nightly-0.1.13.dev9.dist-info}/entry_points.txt +0 -0
lfx/base/agents/agent.py
CHANGED
|
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, cast
|
|
|
5
5
|
|
|
6
6
|
from langchain.agents import AgentExecutor, BaseMultiActionAgent, BaseSingleActionAgent
|
|
7
7
|
from langchain.agents.agent import RunnableAgent
|
|
8
|
+
from langchain.callbacks.base import BaseCallbackHandler
|
|
8
9
|
from langchain_core.messages import HumanMessage
|
|
9
10
|
from langchain_core.runnables import Runnable
|
|
10
11
|
|
|
@@ -75,6 +76,12 @@ class LCAgentComponent(Component):
|
|
|
75
76
|
Output(display_name="Response", name="response", method="message_response"),
|
|
76
77
|
]
|
|
77
78
|
|
|
79
|
+
# Get shared callbacks for tracing and save them to self.shared_callbacks
|
|
80
|
+
def _get_shared_callbacks(self) -> list[BaseCallbackHandler]:
|
|
81
|
+
if not hasattr(self, "shared_callbacks"):
|
|
82
|
+
self.shared_callbacks = self.get_langchain_callbacks()
|
|
83
|
+
return self.shared_callbacks
|
|
84
|
+
|
|
78
85
|
@abstractmethod
|
|
79
86
|
def build_agent(self) -> AgentExecutor:
|
|
80
87
|
"""Create the agent."""
|
|
@@ -209,7 +216,8 @@ class LCAgentComponent(Component):
|
|
|
209
216
|
result = await process_agent_events(
|
|
210
217
|
runnable.astream_events(
|
|
211
218
|
input_dict,
|
|
212
|
-
|
|
219
|
+
# here we use the shared callbacks because the AgentExecutor uses the tools
|
|
220
|
+
config={"callbacks": [AgentAsyncHandler(self.log), *self._get_shared_callbacks()]},
|
|
213
221
|
version="v2",
|
|
214
222
|
),
|
|
215
223
|
agent_message,
|
|
@@ -285,15 +293,40 @@ class LCToolsAgentComponent(LCAgentComponent):
|
|
|
285
293
|
tools_names = ", ".join([tool.name for tool in self.tools])
|
|
286
294
|
return tools_names
|
|
287
295
|
|
|
296
|
+
# Set shared callbacks for tracing
|
|
297
|
+
def set_tools_callbacks(self, tools_list: list[Tool], callbacks_list: list[BaseCallbackHandler]):
|
|
298
|
+
"""Set shared callbacks for tracing to the tools.
|
|
299
|
+
|
|
300
|
+
If we do not pass down the same callbacks to each tool
|
|
301
|
+
used by the agent, then each tool will instantiate a new callback.
|
|
302
|
+
For some tracing services, this will cause
|
|
303
|
+
the callback handler to lose the id of its parent run (Agent)
|
|
304
|
+
and thus throw an error in the tracing service client.
|
|
305
|
+
|
|
306
|
+
Args:
|
|
307
|
+
tools_list: list of tools to set the callbacks for
|
|
308
|
+
callbacks_list: list of callbacks to set for the tools
|
|
309
|
+
Returns:
|
|
310
|
+
None
|
|
311
|
+
"""
|
|
312
|
+
for tool in tools_list or []:
|
|
313
|
+
if hasattr(tool, "callbacks"):
|
|
314
|
+
tool.callbacks = callbacks_list
|
|
315
|
+
|
|
288
316
|
async def _get_tools(self) -> list[Tool]:
|
|
289
317
|
component_toolkit = _get_component_toolkit()
|
|
290
318
|
tools_names = self._build_tools_names()
|
|
291
319
|
agent_description = self.get_tool_description()
|
|
292
320
|
# TODO: Agent Description Depreciated Feature to be removed
|
|
293
321
|
description = f"{agent_description}{tools_names}"
|
|
322
|
+
|
|
294
323
|
tools = component_toolkit(component=self).get_tools(
|
|
295
|
-
tool_name=self.get_tool_name(),
|
|
324
|
+
tool_name=self.get_tool_name(),
|
|
325
|
+
tool_description=description,
|
|
326
|
+
# here we do not use the shared callbacks as we are exposing the agent as a tool
|
|
327
|
+
callbacks=self.get_langchain_callbacks(),
|
|
296
328
|
)
|
|
297
329
|
if hasattr(self, "tools_metadata"):
|
|
298
330
|
tools = component_toolkit(component=self, metadata=self.tools_metadata).update_tools_metadata(tools=tools)
|
|
331
|
+
|
|
299
332
|
return tools
|
|
@@ -14,14 +14,18 @@ class ModelProvidersDict(TypedDict):
|
|
|
14
14
|
is_active: bool
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
def get_filtered_inputs(component_class):
|
|
17
|
+
def get_filtered_inputs(component_class, provider_name: str | None = None):
|
|
18
18
|
base_input_names = {field.name for field in LCModelComponent.get_base_inputs()}
|
|
19
19
|
component_instance = component_class()
|
|
20
20
|
|
|
21
|
-
return [
|
|
21
|
+
return [
|
|
22
|
+
process_inputs(input_, provider_name)
|
|
23
|
+
for input_ in component_instance.inputs
|
|
24
|
+
if input_.name not in base_input_names
|
|
25
|
+
]
|
|
22
26
|
|
|
23
27
|
|
|
24
|
-
def process_inputs(component_data: Input):
|
|
28
|
+
def process_inputs(component_data: Input, provider_name: str | None = None):
|
|
25
29
|
"""Processes and modifies an input configuration based on its type or name.
|
|
26
30
|
|
|
27
31
|
Adjusts properties such as value, advanced status, real-time refresh, and additional information for specific
|
|
@@ -29,6 +33,7 @@ def process_inputs(component_data: Input):
|
|
|
29
33
|
|
|
30
34
|
Args:
|
|
31
35
|
component_data: The input configuration to process.
|
|
36
|
+
provider_name: The name of the provider to process the inputs for.
|
|
32
37
|
|
|
33
38
|
Returns:
|
|
34
39
|
The modified input configuration.
|
|
@@ -43,9 +48,11 @@ def process_inputs(component_data: Input):
|
|
|
43
48
|
component_data.advanced = True
|
|
44
49
|
component_data.value = True
|
|
45
50
|
elif component_data.name in {"temperature", "base_url"}:
|
|
46
|
-
|
|
51
|
+
if provider_name not in ["IBM watsonx.ai", "Ollama"]:
|
|
52
|
+
component_data = set_advanced_true(component_data)
|
|
47
53
|
elif component_data.name == "model_name":
|
|
48
|
-
|
|
54
|
+
if provider_name not in ["IBM watsonx.ai"]:
|
|
55
|
+
component_data = set_real_time_refresh_false(component_data)
|
|
49
56
|
component_data = add_combobox_true(component_data)
|
|
50
57
|
component_data = add_info(
|
|
51
58
|
component_data,
|
|
@@ -79,6 +86,28 @@ def create_input_fields_dict(inputs: list[Input], prefix: str) -> dict[str, Inpu
|
|
|
79
86
|
return {f"{prefix}{input_.name}": input_.to_dict() for input_ in inputs}
|
|
80
87
|
|
|
81
88
|
|
|
89
|
+
def _get_ollama_inputs_and_fields():
|
|
90
|
+
try:
|
|
91
|
+
from lfx.components.ollama.ollama import ChatOllamaComponent
|
|
92
|
+
|
|
93
|
+
ollama_inputs = get_filtered_inputs(ChatOllamaComponent, provider_name="Ollama")
|
|
94
|
+
except ImportError as e:
|
|
95
|
+
msg = "Ollama is not installed. Please install it with `pip install langchain-ollama`."
|
|
96
|
+
raise ImportError(msg) from e
|
|
97
|
+
return ollama_inputs, create_input_fields_dict(ollama_inputs, "")
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _get_watsonx_inputs_and_fields():
|
|
101
|
+
try:
|
|
102
|
+
from lfx.components.ibm.watsonx import WatsonxAIComponent
|
|
103
|
+
|
|
104
|
+
watsonx_inputs = get_filtered_inputs(WatsonxAIComponent, provider_name="IBM watsonx.ai")
|
|
105
|
+
except ImportError as e:
|
|
106
|
+
msg = "IBM watsonx.ai is not installed. Please install it with `pip install langchain-ibm-watsonx`."
|
|
107
|
+
raise ImportError(msg) from e
|
|
108
|
+
return watsonx_inputs, create_input_fields_dict(watsonx_inputs, "")
|
|
109
|
+
|
|
110
|
+
|
|
82
111
|
def _get_google_generative_ai_inputs_and_fields():
|
|
83
112
|
try:
|
|
84
113
|
from lfx.components.google.google_generative_ai import GoogleGenerativeAIComponent
|
|
@@ -293,6 +322,36 @@ try:
|
|
|
293
322
|
except ImportError:
|
|
294
323
|
pass
|
|
295
324
|
|
|
325
|
+
try:
|
|
326
|
+
from lfx.components.ibm.watsonx import WatsonxAIComponent
|
|
327
|
+
|
|
328
|
+
watsonx_inputs, watsonx_fields = _get_watsonx_inputs_and_fields()
|
|
329
|
+
MODEL_PROVIDERS_DICT["IBM watsonx.ai"] = {
|
|
330
|
+
"fields": watsonx_fields,
|
|
331
|
+
"inputs": watsonx_inputs,
|
|
332
|
+
"prefix": "",
|
|
333
|
+
"component_class": WatsonxAIComponent(),
|
|
334
|
+
"icon": WatsonxAIComponent.icon,
|
|
335
|
+
"is_active": True,
|
|
336
|
+
}
|
|
337
|
+
except ImportError:
|
|
338
|
+
pass
|
|
339
|
+
|
|
340
|
+
try:
|
|
341
|
+
from lfx.components.ollama.ollama import ChatOllamaComponent
|
|
342
|
+
|
|
343
|
+
ollama_inputs, ollama_fields = _get_ollama_inputs_and_fields()
|
|
344
|
+
MODEL_PROVIDERS_DICT["Ollama"] = {
|
|
345
|
+
"fields": ollama_fields,
|
|
346
|
+
"inputs": ollama_inputs,
|
|
347
|
+
"prefix": "",
|
|
348
|
+
"component_class": ChatOllamaComponent(),
|
|
349
|
+
"icon": ChatOllamaComponent.icon,
|
|
350
|
+
"is_active": True,
|
|
351
|
+
}
|
|
352
|
+
except ImportError:
|
|
353
|
+
pass
|
|
354
|
+
|
|
296
355
|
# Expose only active providers ----------------------------------------------
|
|
297
356
|
ACTIVE_MODEL_PROVIDERS_DICT: dict[str, ModelProvidersDict] = {
|
|
298
357
|
name: prov for name, prov in MODEL_PROVIDERS_DICT.items() if prov.get("is_active", True)
|
|
@@ -302,10 +361,18 @@ MODEL_PROVIDERS: list[str] = list(ACTIVE_MODEL_PROVIDERS_DICT.keys())
|
|
|
302
361
|
|
|
303
362
|
ALL_PROVIDER_FIELDS: list[str] = [field for prov in ACTIVE_MODEL_PROVIDERS_DICT.values() for field in prov["fields"]]
|
|
304
363
|
|
|
305
|
-
MODEL_DYNAMIC_UPDATE_FIELDS = [
|
|
364
|
+
MODEL_DYNAMIC_UPDATE_FIELDS = [
|
|
365
|
+
"api_key",
|
|
366
|
+
"model",
|
|
367
|
+
"tool_model_enabled",
|
|
368
|
+
"base_url",
|
|
369
|
+
"model_name",
|
|
370
|
+
"watsonx_endpoint",
|
|
371
|
+
"url",
|
|
372
|
+
]
|
|
306
373
|
|
|
307
374
|
MODELS_METADATA = {name: {"icon": prov["icon"]} for name, prov in ACTIVE_MODEL_PROVIDERS_DICT.items()}
|
|
308
375
|
|
|
309
|
-
MODEL_PROVIDERS_LIST = ["Anthropic", "Google Generative AI", "OpenAI"]
|
|
376
|
+
MODEL_PROVIDERS_LIST = ["Anthropic", "Google Generative AI", "OpenAI", "IBM watsonx.ai", "Ollama"]
|
|
310
377
|
|
|
311
378
|
MODEL_OPTIONS_METADATA = [MODELS_METADATA[key] for key in MODEL_PROVIDERS_LIST if key in MODELS_METADATA]
|
lfx/components/agents/agent.py
CHANGED
|
@@ -20,7 +20,7 @@ from lfx.components.langchain_utilities.tool_calling import ToolCallingAgentComp
|
|
|
20
20
|
from lfx.custom.custom_component.component import get_component_toolkit
|
|
21
21
|
from lfx.custom.utils import update_component_build_config
|
|
22
22
|
from lfx.helpers.base_model import build_model_from_schema
|
|
23
|
-
from lfx.inputs.inputs import BoolInput
|
|
23
|
+
from lfx.inputs.inputs import BoolInput, SecretStrInput, StrInput
|
|
24
24
|
from lfx.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output, TableInput
|
|
25
25
|
from lfx.log.logger import logger
|
|
26
26
|
from lfx.schema.data import Data
|
|
@@ -77,6 +77,32 @@ class AgentComponent(ToolCallingAgentComponent):
|
|
|
77
77
|
},
|
|
78
78
|
},
|
|
79
79
|
),
|
|
80
|
+
SecretStrInput(
|
|
81
|
+
name="api_key",
|
|
82
|
+
display_name="API Key",
|
|
83
|
+
info="The API key to use for the model.",
|
|
84
|
+
required=True,
|
|
85
|
+
),
|
|
86
|
+
StrInput(
|
|
87
|
+
name="base_url",
|
|
88
|
+
display_name="Base URL",
|
|
89
|
+
info="The base URL of the API.",
|
|
90
|
+
required=True,
|
|
91
|
+
show=False,
|
|
92
|
+
),
|
|
93
|
+
StrInput(
|
|
94
|
+
name="project_id",
|
|
95
|
+
display_name="Project ID",
|
|
96
|
+
info="The project ID of the model.",
|
|
97
|
+
required=True,
|
|
98
|
+
show=False,
|
|
99
|
+
),
|
|
100
|
+
IntInput(
|
|
101
|
+
name="max_output_tokens",
|
|
102
|
+
display_name="Max Output Tokens",
|
|
103
|
+
info="The maximum number of tokens to generate.",
|
|
104
|
+
show=False,
|
|
105
|
+
),
|
|
80
106
|
*openai_inputs_filtered,
|
|
81
107
|
MultilineInput(
|
|
82
108
|
name="system_prompt",
|
|
@@ -195,10 +221,15 @@ class AgentComponent(ToolCallingAgentComponent):
|
|
|
195
221
|
if not isinstance(self.tools, list): # type: ignore[has-type]
|
|
196
222
|
self.tools = []
|
|
197
223
|
current_date_tool = (await CurrentDateComponent(**self.get_base_args()).to_toolkit()).pop(0)
|
|
224
|
+
|
|
198
225
|
if not isinstance(current_date_tool, StructuredTool):
|
|
199
226
|
msg = "CurrentDateComponent must be converted to a StructuredTool"
|
|
200
227
|
raise TypeError(msg)
|
|
201
228
|
self.tools.append(current_date_tool)
|
|
229
|
+
|
|
230
|
+
# Set shared callbacks for tracing the tools used by the agent
|
|
231
|
+
self.set_tools_callbacks(self.tools, self._get_shared_callbacks())
|
|
232
|
+
|
|
202
233
|
return llm_model, self.chat_history, self.tools
|
|
203
234
|
|
|
204
235
|
async def message_response(self) -> Message:
|
|
@@ -471,7 +502,8 @@ class AgentComponent(ToolCallingAgentComponent):
|
|
|
471
502
|
def delete_fields(self, build_config: dotdict, fields: dict | list[str]) -> None:
|
|
472
503
|
"""Delete specified fields from build_config."""
|
|
473
504
|
for field in fields:
|
|
474
|
-
build_config
|
|
505
|
+
if build_config is not None and field in build_config:
|
|
506
|
+
build_config.pop(field, None)
|
|
475
507
|
|
|
476
508
|
def update_input_types(self, build_config: dotdict) -> dotdict:
|
|
477
509
|
"""Update input types for all fields in build_config."""
|
|
@@ -599,11 +631,14 @@ class AgentComponent(ToolCallingAgentComponent):
|
|
|
599
631
|
agent_description = self.get_tool_description()
|
|
600
632
|
# TODO: Agent Description Depreciated Feature to be removed
|
|
601
633
|
description = f"{agent_description}{tools_names}"
|
|
634
|
+
|
|
602
635
|
tools = component_toolkit(component=self).get_tools(
|
|
603
636
|
tool_name="Call_Agent",
|
|
604
637
|
tool_description=description,
|
|
638
|
+
# here we do not use the shared callbacks as we are exposing the agent as a tool
|
|
605
639
|
callbacks=self.get_langchain_callbacks(),
|
|
606
640
|
)
|
|
607
641
|
if hasattr(self, "tools_metadata"):
|
|
608
642
|
tools = component_toolkit(component=self, metadata=self.tools_metadata).update_tools_metadata(tools=tools)
|
|
643
|
+
|
|
609
644
|
return tools
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import json
|
|
4
5
|
import uuid
|
|
5
6
|
from typing import Any
|
|
6
7
|
|
|
@@ -520,7 +521,6 @@ class MCPToolsComponent(ComponentWithCache):
|
|
|
520
521
|
if session_context:
|
|
521
522
|
self.stdio_client.set_session_context(session_context)
|
|
522
523
|
self.streamable_http_client.set_session_context(session_context)
|
|
523
|
-
|
|
524
524
|
exec_tool = self._tool_cache[self.tool]
|
|
525
525
|
tool_args = self.get_inputs_for_all_tools(self.tools)[self.tool]
|
|
526
526
|
kwargs = {}
|
|
@@ -535,11 +535,14 @@ class MCPToolsComponent(ComponentWithCache):
|
|
|
535
535
|
unflattened_kwargs = maybe_unflatten_dict(kwargs)
|
|
536
536
|
|
|
537
537
|
output = await exec_tool.coroutine(**unflattened_kwargs)
|
|
538
|
-
|
|
539
538
|
tool_content = []
|
|
540
539
|
for item in output.content:
|
|
541
540
|
item_dict = item.model_dump()
|
|
541
|
+
item_dict = self.process_output_item(item_dict)
|
|
542
542
|
tool_content.append(item_dict)
|
|
543
|
+
|
|
544
|
+
if isinstance(tool_content, list) and all(isinstance(x, dict) for x in tool_content):
|
|
545
|
+
return DataFrame(tool_content)
|
|
543
546
|
return DataFrame(data=tool_content)
|
|
544
547
|
return DataFrame(data=[{"error": "You must select a tool"}])
|
|
545
548
|
except Exception as e:
|
|
@@ -547,6 +550,17 @@ class MCPToolsComponent(ComponentWithCache):
|
|
|
547
550
|
await logger.aexception(msg)
|
|
548
551
|
raise ValueError(msg) from e
|
|
549
552
|
|
|
553
|
+
def process_output_item(self, item_dict):
|
|
554
|
+
"""Process the output of a tool."""
|
|
555
|
+
if item_dict.get("type") == "text":
|
|
556
|
+
text = item_dict.get("text")
|
|
557
|
+
try:
|
|
558
|
+
return json.loads(text)
|
|
559
|
+
# convert it to dict
|
|
560
|
+
except json.JSONDecodeError:
|
|
561
|
+
return item_dict
|
|
562
|
+
return item_dict
|
|
563
|
+
|
|
550
564
|
def _get_session_context(self) -> str | None:
|
|
551
565
|
"""Get the Langflow session ID for MCP session caching."""
|
|
552
566
|
# Try to get session ID from the component's execution context
|
lfx/components/ibm/watsonx.py
CHANGED
|
@@ -21,23 +21,24 @@ class WatsonxAIComponent(LCModelComponent):
|
|
|
21
21
|
beta = False
|
|
22
22
|
|
|
23
23
|
_default_models = ["ibm/granite-3-2b-instruct", "ibm/granite-3-8b-instruct", "ibm/granite-13b-instruct-v2"]
|
|
24
|
-
|
|
24
|
+
_urls = [
|
|
25
|
+
"https://us-south.ml.cloud.ibm.com",
|
|
26
|
+
"https://eu-de.ml.cloud.ibm.com",
|
|
27
|
+
"https://eu-gb.ml.cloud.ibm.com",
|
|
28
|
+
"https://au-syd.ml.cloud.ibm.com",
|
|
29
|
+
"https://jp-tok.ml.cloud.ibm.com",
|
|
30
|
+
"https://ca-tor.ml.cloud.ibm.com",
|
|
31
|
+
]
|
|
25
32
|
inputs = [
|
|
26
33
|
*LCModelComponent.get_base_inputs(),
|
|
27
34
|
DropdownInput(
|
|
28
|
-
name="
|
|
35
|
+
name="base_url",
|
|
29
36
|
display_name="watsonx API Endpoint",
|
|
30
37
|
info="The base URL of the API.",
|
|
31
|
-
value=
|
|
32
|
-
options=
|
|
33
|
-
"https://us-south.ml.cloud.ibm.com",
|
|
34
|
-
"https://eu-de.ml.cloud.ibm.com",
|
|
35
|
-
"https://eu-gb.ml.cloud.ibm.com",
|
|
36
|
-
"https://au-syd.ml.cloud.ibm.com",
|
|
37
|
-
"https://jp-tok.ml.cloud.ibm.com",
|
|
38
|
-
"https://ca-tor.ml.cloud.ibm.com",
|
|
39
|
-
],
|
|
38
|
+
value=[],
|
|
39
|
+
options=_urls,
|
|
40
40
|
real_time_refresh=True,
|
|
41
|
+
required=True,
|
|
41
42
|
),
|
|
42
43
|
StrInput(
|
|
43
44
|
name="project_id",
|
|
@@ -56,8 +57,9 @@ class WatsonxAIComponent(LCModelComponent):
|
|
|
56
57
|
display_name="Model Name",
|
|
57
58
|
options=[],
|
|
58
59
|
value=None,
|
|
59
|
-
|
|
60
|
+
real_time_refresh=True,
|
|
60
61
|
required=True,
|
|
62
|
+
refresh_button=True,
|
|
61
63
|
),
|
|
62
64
|
IntInput(
|
|
63
65
|
name="max_tokens",
|
|
@@ -155,18 +157,20 @@ class WatsonxAIComponent(LCModelComponent):
|
|
|
155
157
|
|
|
156
158
|
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
|
157
159
|
"""Update model options when URL or API key changes."""
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
if field_name == "url" and field_value:
|
|
160
|
+
if field_name == "base_url" and field_value:
|
|
161
161
|
try:
|
|
162
|
-
models = self.fetch_models(base_url=
|
|
163
|
-
build_config
|
|
164
|
-
if build_config
|
|
165
|
-
build_config
|
|
166
|
-
info_message = f"Updated model options: {len(models)} models found in {
|
|
162
|
+
models = self.fetch_models(base_url=field_value)
|
|
163
|
+
build_config["model_name"]["options"] = models
|
|
164
|
+
if build_config["model_name"]["value"]:
|
|
165
|
+
build_config["model_name"]["value"] = models[0]
|
|
166
|
+
info_message = f"Updated model options: {len(models)} models found in {field_value}"
|
|
167
167
|
logger.info(info_message)
|
|
168
168
|
except Exception: # noqa: BLE001
|
|
169
169
|
logger.exception("Error updating model options.")
|
|
170
|
+
if field_name == "model_name" and field_value and field_value in WatsonxAIComponent._urls:
|
|
171
|
+
build_config["model_name"]["options"] = self.fetch_models(base_url=field_value)
|
|
172
|
+
build_config["model_name"]["value"] = ""
|
|
173
|
+
return build_config
|
|
170
174
|
|
|
171
175
|
def build_model(self) -> LanguageModel:
|
|
172
176
|
# Parse logit_bias from JSON string if provided
|
|
@@ -195,7 +199,7 @@ class WatsonxAIComponent(LCModelComponent):
|
|
|
195
199
|
|
|
196
200
|
return ChatWatsonx(
|
|
197
201
|
apikey=SecretStr(self.api_key).get_secret_value(),
|
|
198
|
-
url=self.
|
|
202
|
+
url=self.base_url,
|
|
199
203
|
project_id=self.project_id,
|
|
200
204
|
model_id=self.model_name,
|
|
201
205
|
params=chat_params,
|