lfx-nightly: lfx_nightly-0.1.13.dev3-py3-none-any.whl → lfx_nightly-0.1.13.dev5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
The registry has flagged this version of lfx-nightly as possibly problematic; see the package's registry page for details.
- lfx/_assets/component_index.json +1 -1
- lfx/base/composio/composio_base.py +24 -9
- lfx/base/datastax/astradb_base.py +3 -2
- lfx/base/io/chat.py +5 -4
- lfx/base/mcp/util.py +71 -5
- lfx/base/models/watsonx_constants.py +12 -0
- lfx/cli/commands.py +1 -1
- lfx/components/agents/agent.py +1 -1
- lfx/components/agents/cuga_agent.py +1 -1
- lfx/components/agents/mcp_component.py +16 -0
- lfx/components/amazon/amazon_bedrock_converse.py +1 -1
- lfx/components/apify/apify_actor.py +3 -3
- lfx/components/datastax/astradb_vectorstore.py +1 -1
- lfx/components/mistral/mistral_embeddings.py +1 -1
- lfx/components/models/embedding_model.py +85 -7
- lfx/components/openrouter/openrouter.py +49 -147
- lfx/custom/custom_component/component.py +3 -2
- lfx/graph/graph/base.py +1 -1
- lfx/graph/graph/schema.py +3 -2
- lfx/graph/vertex/vertex_types.py +1 -1
- lfx/io/schema.py +6 -0
- {lfx_nightly-0.1.13.dev3.dist-info → lfx_nightly-0.1.13.dev5.dist-info}/METADATA +1 -1
- {lfx_nightly-0.1.13.dev3.dist-info → lfx_nightly-0.1.13.dev5.dist-info}/RECORD +25 -24
- {lfx_nightly-0.1.13.dev3.dist-info → lfx_nightly-0.1.13.dev5.dist-info}/WHEEL +0 -0
- {lfx_nightly-0.1.13.dev3.dist-info → lfx_nightly-0.1.13.dev5.dist-info}/entry_points.txt +0 -0
|
@@ -284,6 +284,21 @@ class ComposioBaseComponent(Component):
|
|
|
284
284
|
# Track all auth field names discovered across all toolkits
|
|
285
285
|
_all_auth_field_names: set[str] = set()
|
|
286
286
|
|
|
287
|
+
@classmethod
|
|
288
|
+
def get_actions_cache(cls) -> dict[str, dict[str, Any]]:
|
|
289
|
+
"""Get the class-level actions cache."""
|
|
290
|
+
return cls._actions_cache
|
|
291
|
+
|
|
292
|
+
@classmethod
|
|
293
|
+
def get_action_schema_cache(cls) -> dict[str, dict[str, Any]]:
|
|
294
|
+
"""Get the class-level action schema cache."""
|
|
295
|
+
return cls._action_schema_cache
|
|
296
|
+
|
|
297
|
+
@classmethod
|
|
298
|
+
def get_all_auth_field_names(cls) -> set[str]:
|
|
299
|
+
"""Get all auth field names discovered across toolkits."""
|
|
300
|
+
return cls._all_auth_field_names
|
|
301
|
+
|
|
287
302
|
outputs = [
|
|
288
303
|
Output(name="dataFrame", display_name="DataFrame", method="as_dataframe"),
|
|
289
304
|
]
|
|
@@ -403,11 +418,11 @@ class ComposioBaseComponent(Component):
|
|
|
403
418
|
|
|
404
419
|
# Try to load from the class-level cache
|
|
405
420
|
toolkit_slug = self.app_name.lower()
|
|
406
|
-
if toolkit_slug in self.__class__.
|
|
421
|
+
if toolkit_slug in self.__class__.get_actions_cache():
|
|
407
422
|
# Deep-copy so that any mutation on this instance does not affect the
|
|
408
423
|
# cached master copy.
|
|
409
|
-
self._actions_data = copy.deepcopy(self.__class__.
|
|
410
|
-
self._action_schemas = copy.deepcopy(self.__class__.
|
|
424
|
+
self._actions_data = copy.deepcopy(self.__class__.get_actions_cache()[toolkit_slug])
|
|
425
|
+
self._action_schemas = copy.deepcopy(self.__class__.get_action_schema_cache().get(toolkit_slug, {}))
|
|
411
426
|
logger.debug(f"Loaded actions for {toolkit_slug} from in-process cache")
|
|
412
427
|
return
|
|
413
428
|
|
|
@@ -630,8 +645,8 @@ class ComposioBaseComponent(Component):
|
|
|
630
645
|
|
|
631
646
|
# Cache actions for this toolkit so subsequent component instances
|
|
632
647
|
# can reuse them without hitting the Composio API again.
|
|
633
|
-
self.__class__.
|
|
634
|
-
self.__class__.
|
|
648
|
+
self.__class__.get_actions_cache()[toolkit_slug] = copy.deepcopy(self._actions_data)
|
|
649
|
+
self.__class__.get_action_schema_cache()[toolkit_slug] = copy.deepcopy(self._action_schemas)
|
|
635
650
|
|
|
636
651
|
except ValueError as e:
|
|
637
652
|
logger.debug(f"Could not populate Composio actions for {self.app_name}: {e}")
|
|
@@ -1313,7 +1328,7 @@ class ComposioBaseComponent(Component):
|
|
|
1313
1328
|
|
|
1314
1329
|
self._auth_dynamic_fields.add(name)
|
|
1315
1330
|
# Also add to class-level cache for better tracking
|
|
1316
|
-
self.__class__.
|
|
1331
|
+
self.__class__.get_all_auth_field_names().add(name)
|
|
1317
1332
|
|
|
1318
1333
|
def _render_custom_auth_fields(self, build_config: dict, schema: dict[str, Any], mode: str) -> None:
|
|
1319
1334
|
"""Render fields for custom auth based on schema auth_config_details sections."""
|
|
@@ -1378,7 +1393,7 @@ class ComposioBaseComponent(Component):
|
|
|
1378
1393
|
if name:
|
|
1379
1394
|
names.add(name)
|
|
1380
1395
|
# Add to class-level cache for tracking all discovered auth fields
|
|
1381
|
-
self.__class__.
|
|
1396
|
+
self.__class__.get_all_auth_field_names().add(name)
|
|
1382
1397
|
# Only use names discovered from the toolkit schema; do not add aliases
|
|
1383
1398
|
return names
|
|
1384
1399
|
|
|
@@ -1443,7 +1458,7 @@ class ComposioBaseComponent(Component):
|
|
|
1443
1458
|
# Check if we need to populate actions - but also check cache availability
|
|
1444
1459
|
actions_available = bool(self._actions_data)
|
|
1445
1460
|
toolkit_slug = getattr(self, "app_name", "").lower()
|
|
1446
|
-
cached_actions_available = toolkit_slug in self.__class__.
|
|
1461
|
+
cached_actions_available = toolkit_slug in self.__class__.get_actions_cache()
|
|
1447
1462
|
|
|
1448
1463
|
should_populate = False
|
|
1449
1464
|
|
|
@@ -2623,7 +2638,7 @@ class ComposioBaseComponent(Component):
|
|
|
2623
2638
|
# Add all dynamic auth fields to protected set
|
|
2624
2639
|
protected.update(self._auth_dynamic_fields)
|
|
2625
2640
|
# Also protect any auth fields discovered across all instances
|
|
2626
|
-
protected.update(self.__class__.
|
|
2641
|
+
protected.update(self.__class__.get_all_auth_field_names())
|
|
2627
2642
|
|
|
2628
2643
|
for key, cfg in list(build_config.items()):
|
|
2629
2644
|
if key in protected:
|
|
@@ -14,6 +14,7 @@ from lfx.io import (
|
|
|
14
14
|
SecretStrInput,
|
|
15
15
|
StrInput,
|
|
16
16
|
)
|
|
17
|
+
from lfx.log.logger import logger
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class AstraDBBaseComponent(Component):
|
|
@@ -364,8 +365,8 @@ class AstraDBBaseComponent(Component):
|
|
|
364
365
|
"status": db.status if db.status != "ACTIVE" else None,
|
|
365
366
|
"org_id": db.org_id if db.org_id else None,
|
|
366
367
|
}
|
|
367
|
-
except Exception: # noqa: BLE001
|
|
368
|
-
|
|
368
|
+
except Exception as e: # noqa: BLE001
|
|
369
|
+
logger.debug("Failed to get metadata for database %s: %s", db.name, e)
|
|
369
370
|
|
|
370
371
|
return db_info_dict
|
|
371
372
|
|
lfx/base/io/chat.py
CHANGED
|
@@ -6,8 +6,9 @@ class ChatComponent(Component):
|
|
|
6
6
|
description = "Use as base for chat components."
|
|
7
7
|
|
|
8
8
|
def get_properties_from_source_component(self):
|
|
9
|
-
|
|
10
|
-
|
|
9
|
+
vertex = self.get_vertex()
|
|
10
|
+
if vertex and hasattr(vertex, "incoming_edges") and vertex.incoming_edges:
|
|
11
|
+
source_id = vertex.incoming_edges[0].source_id
|
|
11
12
|
source_vertex = self.graph.get_vertex(source_id)
|
|
12
13
|
component = source_vertex.custom_component
|
|
13
14
|
source = component.display_name
|
|
@@ -15,6 +16,6 @@ class ChatComponent(Component):
|
|
|
15
16
|
possible_attributes = ["model_name", "model_id", "model"]
|
|
16
17
|
for attribute in possible_attributes:
|
|
17
18
|
if hasattr(component, attribute) and getattr(component, attribute):
|
|
18
|
-
return getattr(component, attribute), icon, source, component.
|
|
19
|
-
return source, icon, component.display_name, component.
|
|
19
|
+
return getattr(component, attribute), icon, source, component.get_id()
|
|
20
|
+
return source, icon, component.display_name, component.get_id()
|
|
20
21
|
return None, None, None, None
|
lfx/base/mcp/util.py
CHANGED
|
@@ -85,6 +85,46 @@ ALLOWED_HEADERS = {
|
|
|
85
85
|
}
|
|
86
86
|
|
|
87
87
|
|
|
88
|
+
def create_mcp_http_client_with_ssl_option(
|
|
89
|
+
headers: dict[str, str] | None = None,
|
|
90
|
+
timeout: httpx.Timeout | None = None,
|
|
91
|
+
auth: httpx.Auth | None = None,
|
|
92
|
+
*,
|
|
93
|
+
verify_ssl: bool = True,
|
|
94
|
+
) -> httpx.AsyncClient:
|
|
95
|
+
"""Create an httpx AsyncClient with configurable SSL verification.
|
|
96
|
+
|
|
97
|
+
This is a custom factory that extends the standard MCP client factory
|
|
98
|
+
to support disabling SSL verification for self-signed certificates.
|
|
99
|
+
|
|
100
|
+
Args:
|
|
101
|
+
headers: Optional headers to include with all requests.
|
|
102
|
+
timeout: Request timeout as httpx.Timeout object.
|
|
103
|
+
auth: Optional authentication handler.
|
|
104
|
+
verify_ssl: Whether to verify SSL certificates (default: True).
|
|
105
|
+
|
|
106
|
+
Returns:
|
|
107
|
+
Configured httpx.AsyncClient instance.
|
|
108
|
+
"""
|
|
109
|
+
kwargs: dict[str, Any] = {
|
|
110
|
+
"follow_redirects": True,
|
|
111
|
+
"verify": verify_ssl,
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
if timeout is None:
|
|
115
|
+
kwargs["timeout"] = httpx.Timeout(30.0)
|
|
116
|
+
else:
|
|
117
|
+
kwargs["timeout"] = timeout
|
|
118
|
+
|
|
119
|
+
if headers is not None:
|
|
120
|
+
kwargs["headers"] = headers
|
|
121
|
+
|
|
122
|
+
if auth is not None:
|
|
123
|
+
kwargs["auth"] = auth
|
|
124
|
+
|
|
125
|
+
return httpx.AsyncClient(**kwargs)
|
|
126
|
+
|
|
127
|
+
|
|
88
128
|
def validate_headers(headers: dict[str, str]) -> dict[str, str]:
|
|
89
129
|
"""Validate and sanitize HTTP headers according to RFC 7230.
|
|
90
130
|
|
|
@@ -695,7 +735,7 @@ class MCPSessionManager:
|
|
|
695
735
|
|
|
696
736
|
Args:
|
|
697
737
|
session_id: Unique identifier for this session
|
|
698
|
-
connection_params: Connection parameters including URL, headers, timeouts
|
|
738
|
+
connection_params: Connection parameters including URL, headers, timeouts, verify_ssl
|
|
699
739
|
preferred_transport: If set to "sse", skip Streamable HTTP and go directly to SSE
|
|
700
740
|
|
|
701
741
|
Returns:
|
|
@@ -711,6 +751,19 @@ class MCPSessionManager:
|
|
|
711
751
|
# Track which transport succeeded
|
|
712
752
|
used_transport: list[str] = []
|
|
713
753
|
|
|
754
|
+
# Get verify_ssl option from connection params, default to True
|
|
755
|
+
verify_ssl = connection_params.get("verify_ssl", True)
|
|
756
|
+
|
|
757
|
+
# Create custom httpx client factory with SSL verification option
|
|
758
|
+
def custom_httpx_factory(
|
|
759
|
+
headers: dict[str, str] | None = None,
|
|
760
|
+
timeout: httpx.Timeout | None = None,
|
|
761
|
+
auth: httpx.Auth | None = None,
|
|
762
|
+
) -> httpx.AsyncClient:
|
|
763
|
+
return create_mcp_http_client_with_ssl_option(
|
|
764
|
+
headers=headers, timeout=timeout, auth=auth, verify_ssl=verify_ssl
|
|
765
|
+
)
|
|
766
|
+
|
|
714
767
|
async def session_task():
|
|
715
768
|
"""Background task that keeps the session alive."""
|
|
716
769
|
streamable_error = None
|
|
@@ -725,6 +778,7 @@ class MCPSessionManager:
|
|
|
725
778
|
url=connection_params["url"],
|
|
726
779
|
headers=connection_params["headers"],
|
|
727
780
|
timeout=connection_params["timeout_seconds"],
|
|
781
|
+
httpx_client_factory=custom_httpx_factory,
|
|
728
782
|
) as (read, write, _):
|
|
729
783
|
session = ClientSession(read, write)
|
|
730
784
|
async with session:
|
|
@@ -765,6 +819,7 @@ class MCPSessionManager:
|
|
|
765
819
|
connection_params["headers"],
|
|
766
820
|
connection_params["timeout_seconds"],
|
|
767
821
|
sse_read_timeout,
|
|
822
|
+
httpx_client_factory=custom_httpx_factory,
|
|
768
823
|
) as (read, write):
|
|
769
824
|
session = ClientSession(read, write)
|
|
770
825
|
async with session:
|
|
@@ -1216,6 +1271,8 @@ class MCPStreamableHttpClient:
|
|
|
1216
1271
|
headers: dict[str, str] | None = None,
|
|
1217
1272
|
timeout_seconds: int = 30,
|
|
1218
1273
|
sse_read_timeout_seconds: int = 30,
|
|
1274
|
+
*,
|
|
1275
|
+
verify_ssl: bool = True,
|
|
1219
1276
|
) -> list[StructuredTool]:
|
|
1220
1277
|
"""Connect to MCP server using Streamable HTTP transport with SSE fallback (SDK style)."""
|
|
1221
1278
|
# Validate and sanitize headers early
|
|
@@ -1233,12 +1290,13 @@ class MCPStreamableHttpClient:
|
|
|
1233
1290
|
msg = f"Invalid Streamable HTTP or SSE URL ({url}): {error_msg}"
|
|
1234
1291
|
raise ValueError(msg)
|
|
1235
1292
|
# Store connection parameters for later use in run_tool
|
|
1236
|
-
# Include SSE read timeout for fallback
|
|
1293
|
+
# Include SSE read timeout for fallback and SSL verification option
|
|
1237
1294
|
self._connection_params = {
|
|
1238
1295
|
"url": url,
|
|
1239
1296
|
"headers": validated_headers,
|
|
1240
1297
|
"timeout_seconds": timeout_seconds,
|
|
1241
1298
|
"sse_read_timeout_seconds": sse_read_timeout_seconds,
|
|
1299
|
+
"verify_ssl": verify_ssl,
|
|
1242
1300
|
}
|
|
1243
1301
|
elif headers:
|
|
1244
1302
|
self._connection_params["headers"] = validated_headers
|
|
@@ -1258,11 +1316,18 @@ class MCPStreamableHttpClient:
|
|
|
1258
1316
|
return response.tools
|
|
1259
1317
|
|
|
1260
1318
|
async def connect_to_server(
|
|
1261
|
-
self,
|
|
1319
|
+
self,
|
|
1320
|
+
url: str,
|
|
1321
|
+
headers: dict[str, str] | None = None,
|
|
1322
|
+
sse_read_timeout_seconds: int = 30,
|
|
1323
|
+
*,
|
|
1324
|
+
verify_ssl: bool = True,
|
|
1262
1325
|
) -> list[StructuredTool]:
|
|
1263
1326
|
"""Connect to MCP server using Streamable HTTP with SSE fallback transport (SDK style)."""
|
|
1264
1327
|
return await asyncio.wait_for(
|
|
1265
|
-
self._connect_to_server(
|
|
1328
|
+
self._connect_to_server(
|
|
1329
|
+
url, headers, sse_read_timeout_seconds=sse_read_timeout_seconds, verify_ssl=verify_ssl
|
|
1330
|
+
),
|
|
1266
1331
|
timeout=get_settings_service().settings.mcp_server_timeout,
|
|
1267
1332
|
)
|
|
1268
1333
|
|
|
@@ -1493,7 +1558,8 @@ async def update_tools(
|
|
|
1493
1558
|
client = mcp_stdio_client
|
|
1494
1559
|
elif mode in ["Streamable_HTTP", "SSE"]:
|
|
1495
1560
|
# Streamable HTTP connection with SSE fallback
|
|
1496
|
-
|
|
1561
|
+
verify_ssl = server_config.get("verify_ssl", True)
|
|
1562
|
+
tools = await mcp_streamable_http_client.connect_to_server(url, headers=headers, verify_ssl=verify_ssl)
|
|
1497
1563
|
client = mcp_streamable_http_client
|
|
1498
1564
|
else:
|
|
1499
1565
|
logger.error(f"Invalid MCP server mode for '{server_name}': {mode}")
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from .model_metadata import create_model_metadata
|
|
2
|
+
|
|
3
|
+
# Granite Embedding models
|
|
4
|
+
WATSONX_EMBEDDING_MODELS_DETAILED = [
|
|
5
|
+
create_model_metadata(provider="IBM Watsonx", name="ibm/granite-embedding-125m-english", icon="IBMWatsonx"),
|
|
6
|
+
create_model_metadata(provider="IBM Watsonx", name="ibm/granite-embedding-278m-multilingual", icon="IBMWatsonx"),
|
|
7
|
+
create_model_metadata(provider="IBM Watsonx", name="ibm/granite-embedding-30m-english", icon="IBMWatsonx"),
|
|
8
|
+
create_model_metadata(provider="IBM Watsonx", name="ibm/granite-embedding-107m-multilingual", icon="IBMWatsonx"),
|
|
9
|
+
create_model_metadata(provider="IBM Watsonx", name="ibm/granite-embedding-30m-sparse", icon="IBMWatsonx"),
|
|
10
|
+
]
|
|
11
|
+
|
|
12
|
+
WATSONX_EMBEDDING_MODEL_NAMES = [metadata["name"] for metadata in WATSONX_EMBEDDING_MODELS_DETAILED]
|
lfx/cli/commands.py
CHANGED
|
@@ -43,7 +43,7 @@ def serve_command(
|
|
|
43
43
|
host: str = typer.Option("127.0.0.1", "--host", "-h", help="Host to bind the server to"),
|
|
44
44
|
port: int = typer.Option(8000, "--port", "-p", help="Port to bind the server to"),
|
|
45
45
|
verbose: bool = typer.Option(False, "--verbose", "-v", help="Show diagnostic output and execution details"), # noqa: FBT001, FBT003
|
|
46
|
-
env_file: Path | None = typer.Option(
|
|
46
|
+
env_file: Path | None = typer.Option(
|
|
47
47
|
None,
|
|
48
48
|
"--env-file",
|
|
49
49
|
help="Path to the .env file containing environment variables",
|
lfx/components/agents/mcp_component.py
CHANGED
|
@@ -61,6 +61,7 @@ class MCPToolsComponent(ComponentWithCache):
|
|
|
61
61
|
"mcp_server",
|
|
62
62
|
"tool",
|
|
63
63
|
"use_cache",
|
|
64
|
+
"verify_ssl",
|
|
64
65
|
]
|
|
65
66
|
|
|
66
67
|
display_name = "MCP Tools"
|
|
@@ -86,6 +87,16 @@ class MCPToolsComponent(ComponentWithCache):
|
|
|
86
87
|
value=False,
|
|
87
88
|
advanced=True,
|
|
88
89
|
),
|
|
90
|
+
BoolInput(
|
|
91
|
+
name="verify_ssl",
|
|
92
|
+
display_name="Verify SSL Certificate",
|
|
93
|
+
info=(
|
|
94
|
+
"Enable SSL certificate verification for HTTPS connections. "
|
|
95
|
+
"Disable only for development/testing with self-signed certificates."
|
|
96
|
+
),
|
|
97
|
+
value=True,
|
|
98
|
+
advanced=True,
|
|
99
|
+
),
|
|
89
100
|
DropdownInput(
|
|
90
101
|
name="tool",
|
|
91
102
|
display_name="Tool",
|
|
@@ -210,6 +221,11 @@ class MCPToolsComponent(ComponentWithCache):
|
|
|
210
221
|
self.tools = []
|
|
211
222
|
return [], {"name": server_name, "config": server_config}
|
|
212
223
|
|
|
224
|
+
# Add verify_ssl option to server config if not present
|
|
225
|
+
if "verify_ssl" not in server_config:
|
|
226
|
+
verify_ssl = getattr(self, "verify_ssl", True)
|
|
227
|
+
server_config["verify_ssl"] = verify_ssl
|
|
228
|
+
|
|
213
229
|
_, tool_list, tool_cache = await update_tools(
|
|
214
230
|
server_name=server_name,
|
|
215
231
|
server_config=server_config,
|
|
@@ -92,7 +92,7 @@ class ApifyActorsComponent(Component):
|
|
|
92
92
|
"""Run the Actor and return node output."""
|
|
93
93
|
input_ = json.loads(self.run_input)
|
|
94
94
|
fields = ApifyActorsComponent.parse_dataset_fields(self.dataset_fields) if self.dataset_fields else None
|
|
95
|
-
res = self.
|
|
95
|
+
res = self.run_actor(self.actor_id, input_, fields=fields)
|
|
96
96
|
if self.flatten_dataset:
|
|
97
97
|
res = [ApifyActorsComponent.flatten(item) for item in res]
|
|
98
98
|
data = [Data(data=item) for item in res]
|
|
@@ -159,7 +159,7 @@ class ApifyActorsComponent(Component):
|
|
|
159
159
|
# retrieve if nested, just in case
|
|
160
160
|
input_dict = input_dict.get("run_input", input_dict)
|
|
161
161
|
|
|
162
|
-
res = parent.
|
|
162
|
+
res = parent.run_actor(actor_id, input_dict)
|
|
163
163
|
return "\n\n".join([ApifyActorsComponent.dict_to_json_str(item) for item in res])
|
|
164
164
|
|
|
165
165
|
return ApifyActorRun
|
|
@@ -256,7 +256,7 @@ class ApifyActorsComponent(Component):
|
|
|
256
256
|
valid_chars = string.ascii_letters + string.digits + "_-"
|
|
257
257
|
return "".join(char if char in valid_chars else "_" for char in actor_id)
|
|
258
258
|
|
|
259
|
-
def
|
|
259
|
+
def run_actor(self, actor_id: str, run_input: dict, fields: list[str] | None = None) -> list[dict]:
|
|
260
260
|
"""Run an Apify Actor and return the output dataset.
|
|
261
261
|
|
|
262
262
|
Args:
|
|
@@ -227,7 +227,7 @@ class AstraDBVectorStoreComponent(AstraDBBaseComponent, LCVectorStoreComponent):
|
|
|
227
227
|
for provider in providers.reranking_providers.values()
|
|
228
228
|
for model in provider.models
|
|
229
229
|
]
|
|
230
|
-
except
|
|
230
|
+
except Exception as e: # noqa: BLE001
|
|
231
231
|
self.log(f"Hybrid search not available: {e}")
|
|
232
232
|
return {
|
|
233
233
|
"available": False,
|
|
@@ -3,7 +3,9 @@ from typing import Any
|
|
|
3
3
|
from langchain_openai import OpenAIEmbeddings
|
|
4
4
|
|
|
5
5
|
from lfx.base.embeddings.model import LCEmbeddingsModel
|
|
6
|
+
from lfx.base.models.ollama_constants import OLLAMA_EMBEDDING_MODELS
|
|
6
7
|
from lfx.base.models.openai_constants import OPENAI_EMBEDDING_MODEL_NAMES
|
|
8
|
+
from lfx.base.models.watsonx_constants import WATSONX_EMBEDDING_MODEL_NAMES
|
|
7
9
|
from lfx.field_typing import Embeddings
|
|
8
10
|
from lfx.io import (
|
|
9
11
|
BoolInput,
|
|
@@ -29,11 +31,11 @@ class EmbeddingModelComponent(LCEmbeddingsModel):
|
|
|
29
31
|
DropdownInput(
|
|
30
32
|
name="provider",
|
|
31
33
|
display_name="Model Provider",
|
|
32
|
-
options=["OpenAI"],
|
|
34
|
+
options=["OpenAI", "Ollama", "WatsonX"],
|
|
33
35
|
value="OpenAI",
|
|
34
36
|
info="Select the embedding model provider",
|
|
35
37
|
real_time_refresh=True,
|
|
36
|
-
options_metadata=[{"icon": "OpenAI"}],
|
|
38
|
+
options_metadata=[{"icon": "OpenAI"}, {"icon": "Ollama"}, {"icon": "WatsonxAI"}],
|
|
37
39
|
),
|
|
38
40
|
DropdownInput(
|
|
39
41
|
name="model",
|
|
@@ -56,6 +58,13 @@ class EmbeddingModelComponent(LCEmbeddingsModel):
|
|
|
56
58
|
info="Base URL for the API. Leave empty for default.",
|
|
57
59
|
advanced=True,
|
|
58
60
|
),
|
|
61
|
+
# Watson-specific inputs
|
|
62
|
+
MessageTextInput(
|
|
63
|
+
name="project_id",
|
|
64
|
+
display_name="Project ID",
|
|
65
|
+
info="Watson AI Project ID (required for WatsonX)",
|
|
66
|
+
show=False,
|
|
67
|
+
),
|
|
59
68
|
IntInput(
|
|
60
69
|
name="dimensions",
|
|
61
70
|
display_name="Dimensions",
|
|
@@ -102,13 +111,82 @@ class EmbeddingModelComponent(LCEmbeddingsModel):
|
|
|
102
111
|
show_progress_bar=show_progress_bar,
|
|
103
112
|
model_kwargs=model_kwargs,
|
|
104
113
|
)
|
|
114
|
+
|
|
115
|
+
if provider == "Ollama":
|
|
116
|
+
try:
|
|
117
|
+
from langchain_ollama import OllamaEmbeddings
|
|
118
|
+
except ImportError:
|
|
119
|
+
try:
|
|
120
|
+
from langchain_community.embeddings import OllamaEmbeddings
|
|
121
|
+
except ImportError:
|
|
122
|
+
msg = "Please install langchain-ollama: pip install langchain-ollama"
|
|
123
|
+
raise ImportError(msg) from None
|
|
124
|
+
|
|
125
|
+
return OllamaEmbeddings(
|
|
126
|
+
model=model,
|
|
127
|
+
base_url=api_base or "http://localhost:11434",
|
|
128
|
+
**model_kwargs,
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
if provider == "WatsonX":
|
|
132
|
+
try:
|
|
133
|
+
from langchain_ibm import WatsonxEmbeddings
|
|
134
|
+
except ImportError:
|
|
135
|
+
msg = "Please install langchain-ibm: pip install langchain-ibm"
|
|
136
|
+
raise ImportError(msg) from None
|
|
137
|
+
|
|
138
|
+
if not api_key:
|
|
139
|
+
msg = "Watson AI API key is required when using WatsonX provider"
|
|
140
|
+
raise ValueError(msg)
|
|
141
|
+
|
|
142
|
+
project_id = self.project_id
|
|
143
|
+
|
|
144
|
+
if not project_id:
|
|
145
|
+
msg = "Project ID is required for WatsonX"
|
|
146
|
+
raise ValueError(msg)
|
|
147
|
+
|
|
148
|
+
params = {
|
|
149
|
+
"model_id": model,
|
|
150
|
+
"url": api_base or "https://us-south.ml.cloud.ibm.com",
|
|
151
|
+
"apikey": api_key,
|
|
152
|
+
}
|
|
153
|
+
|
|
154
|
+
params["project_id"] = project_id
|
|
155
|
+
|
|
156
|
+
return WatsonxEmbeddings(**params)
|
|
157
|
+
|
|
105
158
|
msg = f"Unknown provider: {provider}"
|
|
106
159
|
raise ValueError(msg)
|
|
107
160
|
|
|
108
161
|
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None) -> dotdict:
|
|
109
|
-
if field_name == "provider"
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
162
|
+
if field_name == "provider":
|
|
163
|
+
if field_value == "OpenAI":
|
|
164
|
+
build_config["model"]["options"] = OPENAI_EMBEDDING_MODEL_NAMES
|
|
165
|
+
build_config["model"]["value"] = OPENAI_EMBEDDING_MODEL_NAMES[0]
|
|
166
|
+
build_config["api_key"]["display_name"] = "OpenAI API Key"
|
|
167
|
+
build_config["api_key"]["required"] = True
|
|
168
|
+
build_config["api_key"]["show"] = True
|
|
169
|
+
build_config["api_base"]["display_name"] = "OpenAI API Base URL"
|
|
170
|
+
build_config["project_id"]["show"] = False
|
|
171
|
+
|
|
172
|
+
elif field_value == "Ollama":
|
|
173
|
+
build_config["model"]["options"] = OLLAMA_EMBEDDING_MODELS
|
|
174
|
+
build_config["model"]["value"] = OLLAMA_EMBEDDING_MODELS[0]
|
|
175
|
+
build_config["api_key"]["display_name"] = "API Key (Optional)"
|
|
176
|
+
build_config["api_key"]["required"] = False
|
|
177
|
+
build_config["api_key"]["show"] = False
|
|
178
|
+
build_config["api_base"]["display_name"] = "Ollama Base URL"
|
|
179
|
+
build_config["api_base"]["value"] = "http://localhost:11434"
|
|
180
|
+
build_config["project_id"]["show"] = False
|
|
181
|
+
|
|
182
|
+
elif field_value == "WatsonX":
|
|
183
|
+
build_config["model"]["options"] = WATSONX_EMBEDDING_MODEL_NAMES
|
|
184
|
+
build_config["model"]["value"] = WATSONX_EMBEDDING_MODEL_NAMES[0]
|
|
185
|
+
build_config["api_key"]["display_name"] = "Watson AI API Key"
|
|
186
|
+
build_config["api_key"]["required"] = True
|
|
187
|
+
build_config["api_key"]["show"] = True
|
|
188
|
+
build_config["api_base"]["display_name"] = "Watson AI URL"
|
|
189
|
+
build_config["api_base"]["value"] = "https://us-south.ml.cloud.ibm.com"
|
|
190
|
+
build_config["project_id"]["show"] = True
|
|
191
|
+
|
|
114
192
|
return build_config
|