lfx-nightly 0.1.13.dev8__py3-none-any.whl → 0.1.13.dev10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lfx-nightly might be problematic. Click here for more details.

lfx/base/agents/agent.py CHANGED
@@ -5,12 +5,13 @@ from typing import TYPE_CHECKING, cast
5
5
 
6
6
  from langchain.agents import AgentExecutor, BaseMultiActionAgent, BaseSingleActionAgent
7
7
  from langchain.agents.agent import RunnableAgent
8
- from langchain_core.messages import HumanMessage
8
+ from langchain.callbacks.base import BaseCallbackHandler
9
+ from langchain_core.messages import BaseMessage, HumanMessage
9
10
  from langchain_core.runnables import Runnable
10
11
 
11
12
  from lfx.base.agents.callback import AgentAsyncHandler
12
13
  from lfx.base.agents.events import ExceptionWithMessageError, process_agent_events
13
- from lfx.base.agents.utils import data_to_messages, get_chat_output_sender_name
14
+ from lfx.base.agents.utils import get_chat_output_sender_name
14
15
  from lfx.custom.custom_component.component import Component, _get_component_toolkit
15
16
  from lfx.field_typing import Tool
16
17
  from lfx.inputs.inputs import InputTypes, MultilineInput
@@ -19,14 +20,13 @@ from lfx.log.logger import logger
19
20
  from lfx.memory import delete_message
20
21
  from lfx.schema.content_block import ContentBlock
21
22
  from lfx.schema.data import Data
23
+ from lfx.schema.log import OnTokenFunctionType
22
24
  from lfx.schema.message import Message
23
25
  from lfx.template.field.base import Output
24
26
  from lfx.utils.constants import MESSAGE_SENDER_AI
25
27
 
26
28
  if TYPE_CHECKING:
27
- from langchain_core.messages import BaseMessage
28
-
29
- from lfx.schema.log import SendMessageFunctionType
29
+ from lfx.schema.log import OnTokenFunctionType, SendMessageFunctionType
30
30
 
31
31
 
32
32
  DEFAULT_TOOLS_DESCRIPTION = "A helpful assistant with access to the following tools:"
@@ -75,6 +75,12 @@ class LCAgentComponent(Component):
75
75
  Output(display_name="Response", name="response", method="message_response"),
76
76
  ]
77
77
 
78
+ # Get shared callbacks for tracing and save them to self.shared_callbacks
79
+ def _get_shared_callbacks(self) -> list[BaseCallbackHandler]:
80
+ if not hasattr(self, "shared_callbacks"):
81
+ self.shared_callbacks = self.get_langchain_callbacks()
82
+ return self.shared_callbacks
83
+
78
84
  @abstractmethod
79
85
  def build_agent(self) -> AgentExecutor:
80
86
  """Create the agent."""
@@ -119,6 +125,24 @@ class LCAgentComponent(Component):
119
125
  # might be overridden in subclasses
120
126
  return None
121
127
 
128
+ def _data_to_messages_skip_empty(self, data: list[Data]) -> list[BaseMessage]:
129
+ """Convert data to messages, filtering only empty text while preserving non-text content.
130
+
131
+ Note: added to fix issue with certain providers failing when given empty text as input.
132
+ """
133
+ messages = []
134
+ for value in data:
135
+ # Only skip if the message has a text attribute that is empty/whitespace
136
+ text = getattr(value, "text", None)
137
+ if isinstance(text, str) and not text.strip():
138
+ # Skip only messages with empty/whitespace-only text strings
139
+ continue
140
+
141
+ lc_message = value.to_lc_message()
142
+ messages.append(lc_message)
143
+
144
+ return messages
145
+
122
146
  async def run_agent(
123
147
  self,
124
148
  agent: Runnable | BaseSingleActionAgent | BaseMultiActionAgent | AgentExecutor,
@@ -138,57 +162,64 @@ class LCAgentComponent(Component):
138
162
  max_iterations=max_iterations,
139
163
  )
140
164
  # Convert input_value to proper format for agent
141
- if hasattr(self.input_value, "to_lc_message") and callable(self.input_value.to_lc_message):
165
+ lc_message = None
166
+ if isinstance(self.input_value, Message):
142
167
  lc_message = self.input_value.to_lc_message()
143
- input_text = lc_message.content if hasattr(lc_message, "content") else str(lc_message)
168
+ input_dict: dict[str, str | list[BaseMessage] | BaseMessage] = {"input": lc_message}
144
169
  else:
145
- lc_message = None
146
- input_text = self.input_value
170
+ input_dict = {"input": self.input_value}
147
171
 
148
- input_dict: dict[str, str | list[BaseMessage]] = {}
149
172
  if hasattr(self, "system_prompt"):
150
173
  input_dict["system_prompt"] = self.system_prompt
151
- if hasattr(self, "chat_history") and self.chat_history:
152
- if (
153
- hasattr(self.chat_history, "to_data")
154
- and callable(self.chat_history.to_data)
155
- and self.chat_history.__class__.__name__ == "Data"
156
- ):
157
- input_dict["chat_history"] = data_to_messages(self.chat_history)
158
- # Handle both lfx.schema.message.Message and langflow.schema.message.Message types
159
- if all(hasattr(m, "to_data") and callable(m.to_data) and "text" in m.data for m in self.chat_history):
160
- input_dict["chat_history"] = data_to_messages(self.chat_history)
161
- if all(isinstance(m, Message) for m in self.chat_history):
162
- input_dict["chat_history"] = data_to_messages([m.to_data() for m in self.chat_history])
163
- if hasattr(lc_message, "content") and isinstance(lc_message.content, list):
164
- # ! Because the input has to be a string, we must pass the images in the chat_history
165
174
 
175
+ if hasattr(self, "chat_history") and self.chat_history:
176
+ if isinstance(self.chat_history, Data):
177
+ input_dict["chat_history"] = self._data_to_messages_skip_empty([self.chat_history])
178
+ elif all(hasattr(m, "to_data") and callable(m.to_data) and "text" in m.data for m in self.chat_history):
179
+ input_dict["chat_history"] = self._data_to_messages_skip_empty(self.chat_history)
180
+ elif all(isinstance(m, Message) for m in self.chat_history):
181
+ input_dict["chat_history"] = self._data_to_messages_skip_empty([m.to_data() for m in self.chat_history])
182
+
183
+ # Handle multimodal input (images + text)
184
+ # Note: Agent input must be a string, so we extract text and move images to chat_history
185
+ if lc_message is not None and hasattr(lc_message, "content") and isinstance(lc_message.content, list):
186
+ # Extract images and text from the text content items
166
187
  image_dicts = [item for item in lc_message.content if item.get("type") == "image"]
167
- text_items = [item for item in lc_message.content if item.get("type") != "image"]
188
+ text_content = [item for item in lc_message.content if item.get("type") != "image"]
168
189
 
169
- # Extract text content from remaining items
170
- if text_items:
171
- # If there are text items, extract their text content
172
- input_text = " ".join(item.get("text", "") for item in text_items if item.get("type") == "text").strip()
190
+ text_strings = [
191
+ item.get("text", "")
192
+ for item in text_content
193
+ if item.get("type") == "text" and item.get("text", "").strip()
194
+ ]
173
195
 
174
- # If input_text is still a list or empty, provide a default
175
- if isinstance(input_text, list) or not input_text:
176
- input_text = "Process the provided images."
196
+ # Set input to concatenated text or empty string
197
+ input_dict["input"] = " ".join(text_strings) if text_strings else ""
198
+
199
+ # If input is still a list or empty, provide a default
200
+ if isinstance(input_dict["input"], list) or not input_dict["input"]:
201
+ input_dict["input"] = "Process the provided images."
177
202
 
178
203
  if "chat_history" not in input_dict:
179
204
  input_dict["chat_history"] = []
205
+
180
206
  if isinstance(input_dict["chat_history"], list):
181
207
  input_dict["chat_history"].extend(HumanMessage(content=[image_dict]) for image_dict in image_dicts)
182
208
  else:
183
209
  input_dict["chat_history"] = [HumanMessage(content=[image_dict]) for image_dict in image_dicts]
184
210
 
185
- # Final safety check: ensure input_text is never empty (prevents Anthropic API errors)
186
- if not input_text or (isinstance(input_text, (list, str)) and not str(input_text).strip()):
187
- input_text = "Continue the conversation."
188
- # Ensure the agent input is a string
189
- if not isinstance(input_text, str):
190
- input_text = " ".join(map(str, input_text)) if isinstance(input_text, list) else str(input_text)
191
- input_dict["input"] = input_text
211
+ # Final safety check: ensure input is never empty (prevents Anthropic API errors)
212
+ current_input = input_dict.get("input", "")
213
+ if isinstance(current_input, list):
214
+ current_input = " ".join(map(str, current_input))
215
+ elif not isinstance(current_input, str):
216
+ current_input = str(current_input)
217
+
218
+ if not current_input.strip():
219
+ input_dict["input"] = "Continue the conversation."
220
+ else:
221
+ input_dict["input"] = current_input
222
+
192
223
  if hasattr(self, "graph"):
193
224
  session_id = self.graph.session_id
194
225
  elif hasattr(self, "_session_id"):
@@ -197,7 +228,6 @@ class LCAgentComponent(Component):
197
228
  session_id = None
198
229
 
199
230
  sender_name = get_chat_output_sender_name(self) or self.display_name or "AI"
200
-
201
231
  agent_message = Message(
202
232
  sender=MESSAGE_SENDER_AI,
203
233
  sender_name=sender_name,
@@ -205,15 +235,24 @@ class LCAgentComponent(Component):
205
235
  content_blocks=[ContentBlock(title="Agent Steps", contents=[])],
206
236
  session_id=session_id or uuid.uuid4(),
207
237
  )
238
+
239
+ # Create token callback if event_manager is available
240
+ # This wraps the event_manager's on_token method to match OnTokenFunctionType Protocol
241
+ on_token_callback: OnTokenFunctionType | None = None
242
+ if self._event_manager:
243
+ on_token_callback = cast("OnTokenFunctionType", self._event_manager.on_token)
244
+
208
245
  try:
209
246
  result = await process_agent_events(
210
247
  runnable.astream_events(
211
248
  input_dict,
212
- config={"callbacks": [AgentAsyncHandler(self.log), *self.get_langchain_callbacks()]},
249
+ # here we use the shared callbacks because the AgentExecutor uses the tools
250
+ config={"callbacks": [AgentAsyncHandler(self.log), *self._get_shared_callbacks()]},
213
251
  version="v2",
214
252
  ),
215
253
  agent_message,
216
254
  cast("SendMessageFunctionType", self.send_message),
255
+ on_token_callback,
217
256
  )
218
257
  except ExceptionWithMessageError as e:
219
258
  if hasattr(e, "agent_message") and hasattr(e.agent_message, "id"):
@@ -285,15 +324,40 @@ class LCToolsAgentComponent(LCAgentComponent):
285
324
  tools_names = ", ".join([tool.name for tool in self.tools])
286
325
  return tools_names
287
326
 
327
+ # Set shared callbacks for tracing
328
+ def set_tools_callbacks(self, tools_list: list[Tool], callbacks_list: list[BaseCallbackHandler]):
329
+ """Set shared callbacks for tracing to the tools.
330
+
331
+ If we do not pass down the same callbacks to each tool
332
+ used by the agent, then each tool will instantiate a new callback.
333
+ For some tracing services, this will cause
334
+ the callback handler to lose the id of its parent run (Agent)
335
+ and thus throw an error in the tracing service client.
336
+
337
+ Args:
338
+ tools_list: list of tools to set the callbacks for
339
+ callbacks_list: list of callbacks to set for the tools
340
+ Returns:
341
+ None
342
+ """
343
+ for tool in tools_list or []:
344
+ if hasattr(tool, "callbacks"):
345
+ tool.callbacks = callbacks_list
346
+
288
347
  async def _get_tools(self) -> list[Tool]:
289
348
  component_toolkit = _get_component_toolkit()
290
349
  tools_names = self._build_tools_names()
291
350
  agent_description = self.get_tool_description()
292
351
  # TODO: Agent Description Depreciated Feature to be removed
293
352
  description = f"{agent_description}{tools_names}"
353
+
294
354
  tools = component_toolkit(component=self).get_tools(
295
- tool_name=self.get_tool_name(), tool_description=description, callbacks=self.get_langchain_callbacks()
355
+ tool_name=self.get_tool_name(),
356
+ tool_description=description,
357
+ # here we do not use the shared callbacks as we are exposing the agent as a tool
358
+ callbacks=self.get_langchain_callbacks(),
296
359
  )
297
360
  if hasattr(self, "tools_metadata"):
298
361
  tools = component_toolkit(component=self, metadata=self.tools_metadata).update_tools_metadata(tools=tools)
362
+
299
363
  return tools
lfx/base/agents/events.py CHANGED
@@ -1,4 +1,5 @@
1
1
  # Add helper functions for each event type
2
+ import asyncio
2
3
  from collections.abc import AsyncIterator
3
4
  from time import perf_counter
4
5
  from typing import Any, Protocol
@@ -9,7 +10,7 @@ from typing_extensions import TypedDict
9
10
 
10
11
  from lfx.schema.content_block import ContentBlock
11
12
  from lfx.schema.content_types import TextContent, ToolContent
12
- from lfx.schema.log import SendMessageFunctionType
13
+ from lfx.schema.log import OnTokenFunctionType, SendMessageFunctionType
13
14
  from lfx.schema.message import Message
14
15
 
15
16
 
@@ -53,7 +54,14 @@ def _calculate_duration(start_time: float) -> int:
53
54
 
54
55
 
55
56
  async def handle_on_chain_start(
56
- event: dict[str, Any], agent_message: Message, send_message_method: SendMessageFunctionType, start_time: float
57
+ event: dict[str, Any],
58
+ agent_message: Message,
59
+ send_message_callback: SendMessageFunctionType,
60
+ send_token_callback: OnTokenFunctionType | None, # noqa: ARG001
61
+ start_time: float,
62
+ *,
63
+ had_streaming: bool = False, # noqa: ARG001
64
+ message_id: str | None = None, # noqa: ARG001
57
65
  ) -> tuple[Message, float]:
58
66
  # Create content blocks if they don't exist
59
67
  if not agent_message.content_blocks:
@@ -80,7 +88,7 @@ async def handle_on_chain_start(
80
88
  header={"title": "Input", "icon": "MessageSquare"},
81
89
  )
82
90
  agent_message.content_blocks[0].contents.append(text_content)
83
- agent_message = await send_message_method(message=agent_message, skip_db_update=True)
91
+ agent_message = await send_message_callback(message=agent_message, skip_db_update=True)
84
92
  start_time = perf_counter()
85
93
  return agent_message, start_time
86
94
 
@@ -101,15 +109,23 @@ def _extract_output_text(output: str | list) -> str:
101
109
  if isinstance(item, dict):
102
110
  if "text" in item:
103
111
  return item["text"] or ""
104
- # If the item's type is "tool_use", return an empty string.
105
- if item.get("type") == "tool_use":
106
- return ""
107
- # Handle items with only 'index' key (from ChatBedrockConverse)
108
- if "index" in item and len(item) == 1:
109
- return ""
110
- # This is a workaround to deal with function calling by Anthropic
111
- if "partial_json" in item:
112
+ if "content" in item:
113
+ return str(item["content"])
114
+ if "message" in item:
115
+ return str(item["message"])
116
+
117
+ # Special case handling for non-text-like dicts
118
+ if (
119
+ item.get("type") == "tool_use" # Handle tool use items
120
+ or ("index" in item and len(item) == 1) # Handle index-only items
121
+ or "partial_json" in item # Handle partial json items
122
+ # Handle index-only items
123
+ or ("index" in item and not any(k in item for k in ("text", "content", "message")))
124
+ # Handle other metadata-only chunks that don't contain meaningful text
125
+ or not any(key in item for key in ["text", "content", "message"])
126
+ ):
112
127
  return ""
128
+
113
129
  # For any other dict format, return empty string
114
130
  return ""
115
131
  # For any other single item type (not str or dict), return empty string
@@ -133,7 +149,14 @@ def _extract_output_text(output: str | list) -> str:
133
149
 
134
150
 
135
151
  async def handle_on_chain_end(
136
- event: dict[str, Any], agent_message: Message, send_message_method: SendMessageFunctionType, start_time: float
152
+ event: dict[str, Any],
153
+ agent_message: Message,
154
+ send_message_callback: SendMessageFunctionType,
155
+ send_token_callback: OnTokenFunctionType | None, # noqa: ARG001
156
+ start_time: float,
157
+ *,
158
+ had_streaming: bool = False,
159
+ message_id: str | None = None, # noqa: ARG001
137
160
  ) -> tuple[Message, float]:
138
161
  data_output = event["data"].get("output")
139
162
  if data_output and isinstance(data_output, AgentFinish) and data_output.return_values.get("output"):
@@ -151,7 +174,11 @@ async def handle_on_chain_end(
151
174
  header={"title": "Output", "icon": "MessageSquare"},
152
175
  )
153
176
  agent_message.content_blocks[0].contents.append(text_content)
154
- agent_message = await send_message_method(message=agent_message, skip_db_update=True)
177
+
178
+ # Only send final message if we didn't have streaming chunks
179
+ # If we had streaming, frontend already accumulated the chunks
180
+ if not had_streaming:
181
+ agent_message = await send_message_callback(message=agent_message)
155
182
  start_time = perf_counter()
156
183
  return agent_message, start_time
157
184
 
@@ -160,7 +187,7 @@ async def handle_on_tool_start(
160
187
  event: dict[str, Any],
161
188
  agent_message: Message,
162
189
  tool_blocks_map: dict[str, ToolContent],
163
- send_message_method: SendMessageFunctionType,
190
+ send_message_callback: SendMessageFunctionType,
164
191
  start_time: float,
165
192
  ) -> tuple[Message, float]:
166
193
  tool_name = event["name"]
@@ -190,7 +217,7 @@ async def handle_on_tool_start(
190
217
  tool_blocks_map[tool_key] = tool_content
191
218
  agent_message.content_blocks[0].contents.append(tool_content)
192
219
 
193
- agent_message = await send_message_method(message=agent_message, skip_db_update=True)
220
+ agent_message = await send_message_callback(message=agent_message, skip_db_update=True)
194
221
  if agent_message.content_blocks and agent_message.content_blocks[0].contents:
195
222
  tool_blocks_map[tool_key] = agent_message.content_blocks[0].contents[-1]
196
223
  return agent_message, new_start_time
@@ -200,7 +227,7 @@ async def handle_on_tool_end(
200
227
  event: dict[str, Any],
201
228
  agent_message: Message,
202
229
  tool_blocks_map: dict[str, ToolContent],
203
- send_message_method: SendMessageFunctionType,
230
+ send_message_callback: SendMessageFunctionType,
204
231
  start_time: float,
205
232
  ) -> tuple[Message, float]:
206
233
  run_id = event.get("run_id", "")
@@ -209,8 +236,8 @@ async def handle_on_tool_end(
209
236
  tool_content = tool_blocks_map.get(tool_key)
210
237
 
211
238
  if tool_content and isinstance(tool_content, ToolContent):
212
- # Call send_message_method first to get the updated message structure
213
- agent_message = await send_message_method(message=agent_message, skip_db_update=True)
239
+ # Call send_message_callback first to get the updated message structure
240
+ agent_message = await send_message_callback(message=agent_message, skip_db_update=True)
214
241
  new_start_time = perf_counter()
215
242
 
216
243
  # Now find and update the tool content in the current message
@@ -246,7 +273,7 @@ async def handle_on_tool_error(
246
273
  event: dict[str, Any],
247
274
  agent_message: Message,
248
275
  tool_blocks_map: dict[str, ToolContent],
249
- send_message_method: SendMessageFunctionType,
276
+ send_message_callback: SendMessageFunctionType,
250
277
  start_time: float,
251
278
  ) -> tuple[Message, float]:
252
279
  run_id = event.get("run_id", "")
@@ -258,7 +285,7 @@ async def handle_on_tool_error(
258
285
  tool_content.error = event["data"].get("error", "Unknown error")
259
286
  tool_content.duration = _calculate_duration(start_time)
260
287
  tool_content.header = {"title": f"Error using **{tool_content.name}**", "icon": "Hammer"}
261
- agent_message = await send_message_method(message=agent_message, skip_db_update=True)
288
+ agent_message = await send_message_callback(message=agent_message, skip_db_update=True)
262
289
  start_time = perf_counter()
263
290
  return agent_message, start_time
264
291
 
@@ -266,8 +293,12 @@ async def handle_on_tool_error(
266
293
  async def handle_on_chain_stream(
267
294
  event: dict[str, Any],
268
295
  agent_message: Message,
269
- send_message_method: SendMessageFunctionType,
296
+ send_message_callback: SendMessageFunctionType, # noqa: ARG001
297
+ send_token_callback: OnTokenFunctionType | None,
270
298
  start_time: float,
299
+ *,
300
+ had_streaming: bool = False, # noqa: ARG001
301
+ message_id: str | None = None,
271
302
  ) -> tuple[Message, float]:
272
303
  data_chunk = event["data"].get("chunk", {})
273
304
  if isinstance(data_chunk, dict) and data_chunk.get("output"):
@@ -275,15 +306,26 @@ async def handle_on_chain_stream(
275
306
  if output and isinstance(output, str | list):
276
307
  agent_message.text = _extract_output_text(output)
277
308
  agent_message.properties.state = "complete"
278
- agent_message = await send_message_method(message=agent_message, skip_db_update=True)
309
+ # Don't call send_message_callback here - we must update in place
310
+ # in order to keep the message id consistent throughout the stream.
311
+ # The final message will be sent after the loop completes
279
312
  start_time = perf_counter()
280
313
  elif isinstance(data_chunk, AIMessageChunk):
281
314
  output_text = _extract_output_text(data_chunk.content)
282
- if output_text and isinstance(agent_message.text, str):
283
- agent_message.text += output_text
284
- agent_message.properties.state = "partial"
285
- agent_message = await send_message_method(message=agent_message, skip_db_update=True)
315
+ # For streaming, send token event if callback is available
316
+ # Note: we should expect the callback, but we keep it optional for backwards compatibility
317
+ # as of v1.6.5
318
+ if output_text and output_text.strip() and send_token_callback and message_id:
319
+ await asyncio.to_thread(
320
+ send_token_callback,
321
+ data={
322
+ "chunk": output_text,
323
+ "id": str(message_id),
324
+ },
325
+ )
326
+
286
327
  if not agent_message.text:
328
+ # Starts the timer when the first message is starting to be generated
287
329
  start_time = perf_counter()
288
330
  return agent_message, start_time
289
331
 
@@ -294,7 +336,7 @@ class ToolEventHandler(Protocol):
294
336
  event: dict[str, Any],
295
337
  agent_message: Message,
296
338
  tool_blocks_map: dict[str, ContentBlock],
297
- send_message_method: SendMessageFunctionType,
339
+ send_message_callback: SendMessageFunctionType,
298
340
  start_time: float,
299
341
  ) -> tuple[Message, float]: ...
300
342
 
@@ -304,8 +346,12 @@ class ChainEventHandler(Protocol):
304
346
  self,
305
347
  event: dict[str, Any],
306
348
  agent_message: Message,
307
- send_message_method: SendMessageFunctionType,
349
+ send_message_callback: SendMessageFunctionType,
350
+ send_token_callback: OnTokenFunctionType | None,
308
351
  start_time: float,
352
+ *,
353
+ had_streaming: bool = False,
354
+ message_id: str | None = None,
309
355
  ) -> tuple[Message, float]: ...
310
356
 
311
357
 
@@ -329,7 +375,8 @@ TOOL_EVENT_HANDLERS: dict[str, ToolEventHandler] = {
329
375
  async def process_agent_events(
330
376
  agent_executor: AsyncIterator[dict[str, Any]],
331
377
  agent_message: Message,
332
- send_message_method: SendMessageFunctionType,
378
+ send_message_callback: SendMessageFunctionType,
379
+ send_token_callback: OnTokenFunctionType | None = None,
333
380
  ) -> Message:
334
381
  """Process agent events and return the final output."""
335
382
  if isinstance(agent_message.properties, dict):
@@ -337,26 +384,46 @@ async def process_agent_events(
337
384
  else:
338
385
  agent_message.properties.icon = "Bot"
339
386
  agent_message.properties.state = "partial"
340
- # Store the initial message
341
- agent_message = await send_message_method(message=agent_message)
387
+ # Store the initial message and capture the message id
388
+ agent_message = await send_message_callback(message=agent_message)
389
+ # Capture the original message id - this must stay consistent throughout if streaming
390
+ initial_message_id = agent_message.id
342
391
  try:
343
392
  # Create a mapping of run_ids to tool contents
344
393
  tool_blocks_map: dict[str, ToolContent] = {}
394
+ had_streaming = False
345
395
  start_time = perf_counter()
396
+
346
397
  async for event in agent_executor:
347
398
  if event["event"] in TOOL_EVENT_HANDLERS:
348
399
  tool_handler = TOOL_EVENT_HANDLERS[event["event"]]
349
400
  # Use skip_db_update=True during streaming to avoid DB round-trips
350
401
  agent_message, start_time = await tool_handler(
351
- event, agent_message, tool_blocks_map, send_message_method, start_time
402
+ event, agent_message, tool_blocks_map, send_message_callback, start_time
352
403
  )
353
404
  elif event["event"] in CHAIN_EVENT_HANDLERS:
354
405
  chain_handler = CHAIN_EVENT_HANDLERS[event["event"]]
355
- # Use skip_db_update=True during streaming to avoid DB round-trips
356
- agent_message, start_time = await chain_handler(event, agent_message, send_message_method, start_time)
406
+
407
+ # Check if this is a streaming event
408
+ if event["event"] in ("on_chain_stream", "on_chat_model_stream"):
409
+ had_streaming = True
410
+ agent_message, start_time = await chain_handler(
411
+ event,
412
+ agent_message,
413
+ send_message_callback,
414
+ send_token_callback,
415
+ start_time,
416
+ had_streaming=had_streaming,
417
+ message_id=initial_message_id,
418
+ )
419
+ else:
420
+ agent_message, start_time = await chain_handler(
421
+ event, agent_message, send_message_callback, None, start_time, had_streaming=had_streaming
422
+ )
423
+
357
424
  agent_message.properties.state = "complete"
358
425
  # Final DB update with the complete message (skip_db_update=False by default)
359
- agent_message = await send_message_method(message=agent_message)
426
+ agent_message = await send_message_callback(message=agent_message)
360
427
  except Exception as e:
361
428
  raise ExceptionWithMessageError(agent_message, str(e)) from e
362
429
  return await Message.create(**agent_message.model_dump())
@@ -14,14 +14,18 @@ class ModelProvidersDict(TypedDict):
14
14
  is_active: bool
15
15
 
16
16
 
17
- def get_filtered_inputs(component_class):
17
+ def get_filtered_inputs(component_class, provider_name: str | None = None):
18
18
  base_input_names = {field.name for field in LCModelComponent.get_base_inputs()}
19
19
  component_instance = component_class()
20
20
 
21
- return [process_inputs(input_) for input_ in component_instance.inputs if input_.name not in base_input_names]
21
+ return [
22
+ process_inputs(input_, provider_name)
23
+ for input_ in component_instance.inputs
24
+ if input_.name not in base_input_names
25
+ ]
22
26
 
23
27
 
24
- def process_inputs(component_data: Input):
28
+ def process_inputs(component_data: Input, provider_name: str | None = None):
25
29
  """Processes and modifies an input configuration based on its type or name.
26
30
 
27
31
  Adjusts properties such as value, advanced status, real-time refresh, and additional information for specific
@@ -29,6 +33,7 @@ def process_inputs(component_data: Input):
29
33
 
30
34
  Args:
31
35
  component_data: The input configuration to process.
36
+ provider_name: The name of the provider to process the inputs for.
32
37
 
33
38
  Returns:
34
39
  The modified input configuration.
@@ -43,9 +48,11 @@ def process_inputs(component_data: Input):
43
48
  component_data.advanced = True
44
49
  component_data.value = True
45
50
  elif component_data.name in {"temperature", "base_url"}:
46
- component_data = set_advanced_true(component_data)
51
+ if provider_name not in ["IBM watsonx.ai", "Ollama"]:
52
+ component_data = set_advanced_true(component_data)
47
53
  elif component_data.name == "model_name":
48
- component_data = set_real_time_refresh_false(component_data)
54
+ if provider_name not in ["IBM watsonx.ai"]:
55
+ component_data = set_real_time_refresh_false(component_data)
49
56
  component_data = add_combobox_true(component_data)
50
57
  component_data = add_info(
51
58
  component_data,
@@ -79,6 +86,28 @@ def create_input_fields_dict(inputs: list[Input], prefix: str) -> dict[str, Inpu
79
86
  return {f"{prefix}{input_.name}": input_.to_dict() for input_ in inputs}
80
87
 
81
88
 
89
+ def _get_ollama_inputs_and_fields():
90
+ try:
91
+ from lfx.components.ollama.ollama import ChatOllamaComponent
92
+
93
+ ollama_inputs = get_filtered_inputs(ChatOllamaComponent, provider_name="Ollama")
94
+ except ImportError as e:
95
+ msg = "Ollama is not installed. Please install it with `pip install langchain-ollama`."
96
+ raise ImportError(msg) from e
97
+ return ollama_inputs, create_input_fields_dict(ollama_inputs, "")
98
+
99
+
100
+ def _get_watsonx_inputs_and_fields():
101
+ try:
102
+ from lfx.components.ibm.watsonx import WatsonxAIComponent
103
+
104
+ watsonx_inputs = get_filtered_inputs(WatsonxAIComponent, provider_name="IBM watsonx.ai")
105
+ except ImportError as e:
106
+ msg = "IBM watsonx.ai is not installed. Please install it with `pip install langchain-ibm-watsonx`."
107
+ raise ImportError(msg) from e
108
+ return watsonx_inputs, create_input_fields_dict(watsonx_inputs, "")
109
+
110
+
82
111
  def _get_google_generative_ai_inputs_and_fields():
83
112
  try:
84
113
  from lfx.components.google.google_generative_ai import GoogleGenerativeAIComponent
@@ -293,6 +322,36 @@ try:
293
322
  except ImportError:
294
323
  pass
295
324
 
325
+ try:
326
+ from lfx.components.ibm.watsonx import WatsonxAIComponent
327
+
328
+ watsonx_inputs, watsonx_fields = _get_watsonx_inputs_and_fields()
329
+ MODEL_PROVIDERS_DICT["IBM watsonx.ai"] = {
330
+ "fields": watsonx_fields,
331
+ "inputs": watsonx_inputs,
332
+ "prefix": "",
333
+ "component_class": WatsonxAIComponent(),
334
+ "icon": WatsonxAIComponent.icon,
335
+ "is_active": True,
336
+ }
337
+ except ImportError:
338
+ pass
339
+
340
+ try:
341
+ from lfx.components.ollama.ollama import ChatOllamaComponent
342
+
343
+ ollama_inputs, ollama_fields = _get_ollama_inputs_and_fields()
344
+ MODEL_PROVIDERS_DICT["Ollama"] = {
345
+ "fields": ollama_fields,
346
+ "inputs": ollama_inputs,
347
+ "prefix": "",
348
+ "component_class": ChatOllamaComponent(),
349
+ "icon": ChatOllamaComponent.icon,
350
+ "is_active": True,
351
+ }
352
+ except ImportError:
353
+ pass
354
+
296
355
  # Expose only active providers ----------------------------------------------
297
356
  ACTIVE_MODEL_PROVIDERS_DICT: dict[str, ModelProvidersDict] = {
298
357
  name: prov for name, prov in MODEL_PROVIDERS_DICT.items() if prov.get("is_active", True)
@@ -302,10 +361,18 @@ MODEL_PROVIDERS: list[str] = list(ACTIVE_MODEL_PROVIDERS_DICT.keys())
302
361
 
303
362
  ALL_PROVIDER_FIELDS: list[str] = [field for prov in ACTIVE_MODEL_PROVIDERS_DICT.values() for field in prov["fields"]]
304
363
 
305
- MODEL_DYNAMIC_UPDATE_FIELDS = ["api_key", "model", "tool_model_enabled", "base_url", "model_name"]
364
+ MODEL_DYNAMIC_UPDATE_FIELDS = [
365
+ "api_key",
366
+ "model",
367
+ "tool_model_enabled",
368
+ "base_url",
369
+ "model_name",
370
+ "watsonx_endpoint",
371
+ "url",
372
+ ]
306
373
 
307
374
  MODELS_METADATA = {name: {"icon": prov["icon"]} for name, prov in ACTIVE_MODEL_PROVIDERS_DICT.items()}
308
375
 
309
- MODEL_PROVIDERS_LIST = ["Anthropic", "Google Generative AI", "OpenAI"]
376
+ MODEL_PROVIDERS_LIST = ["Anthropic", "Google Generative AI", "OpenAI", "IBM watsonx.ai", "Ollama"]
310
377
 
311
378
  MODEL_OPTIONS_METADATA = [MODELS_METADATA[key] for key in MODEL_PROVIDERS_LIST if key in MODELS_METADATA]
@@ -47,3 +47,6 @@ URL_LIST = [
47
47
  "http://127.0.0.1:11434",
48
48
  "http://0.0.0.0:11434",
49
49
  ]
50
+
51
+
52
+ DEFAULT_OLLAMA_API_URL = "https://ollama.com"