MindsDB 25.5.4.0__py3-none-any.whl → 25.5.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MindsDB might be problematic.
- mindsdb/__about__.py +8 -8
- mindsdb/api/a2a/__main__.py +38 -8
- mindsdb/api/a2a/run_a2a.py +10 -53
- mindsdb/api/a2a/task_manager.py +19 -53
- mindsdb/api/executor/command_executor.py +147 -291
- mindsdb/api/http/namespaces/config.py +61 -86
- mindsdb/integrations/handlers/byom_handler/requirements.txt +1 -2
- mindsdb/integrations/handlers/lancedb_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/litellm_handler/litellm_handler.py +37 -20
- mindsdb/integrations/libs/llm/config.py +13 -0
- mindsdb/integrations/libs/llm/utils.py +37 -65
- mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +230 -227
- mindsdb/interfaces/agents/constants.py +17 -13
- mindsdb/interfaces/agents/langchain_agent.py +93 -94
- mindsdb/interfaces/knowledge_base/controller.py +230 -221
- mindsdb/utilities/config.py +43 -84
- {mindsdb-25.5.4.0.dist-info → mindsdb-25.5.4.1.dist-info}/METADATA +261 -259
- {mindsdb-25.5.4.0.dist-info → mindsdb-25.5.4.1.dist-info}/RECORD +21 -25
- mindsdb/api/a2a/a2a_client.py +0 -439
- mindsdb/api/a2a/common/client/__init__.py +0 -4
- mindsdb/api/a2a/common/client/card_resolver.py +0 -21
- mindsdb/api/a2a/common/client/client.py +0 -86
- {mindsdb-25.5.4.0.dist-info → mindsdb-25.5.4.1.dist-info}/WHEEL +0 -0
- {mindsdb-25.5.4.0.dist-info → mindsdb-25.5.4.1.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.5.4.0.dist-info → mindsdb-25.5.4.1.dist-info}/top_level.txt +0 -0
mindsdb/interfaces/agents/langchain_agent.py

@@ -11,10 +11,8 @@ import pandas as pd
 from langchain.agents import AgentExecutor
 from langchain.agents.initialize import initialize_agent
 from langchain.chains.conversation.memory import ConversationSummaryBufferMemory
-from langchain_community.chat_models import
-
-    ChatLiteLLM,
-    ChatOllama)
+from langchain_community.chat_models import ChatAnyscale, ChatLiteLLM, ChatOllama
+from langchain_writer import ChatWriter
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_core.agents import AgentAction, AgentStep
 from langchain_core.callbacks.base import BaseCallbackHandler
@@ -27,7 +25,9 @@ from langchain_core.tools import Tool
 from mindsdb.integrations.libs.llm.utils import get_llm_config
 from mindsdb.integrations.utilities.handler_utils import get_api_key
 from mindsdb.integrations.utilities.rag.settings import DEFAULT_RAG_PROMPT_TEMPLATE
-from mindsdb.interfaces.agents.event_dispatch_callback_handler import
+from mindsdb.interfaces.agents.event_dispatch_callback_handler import (
+    EventDispatchCallbackHandler,
+)
 from mindsdb.interfaces.agents.constants import AGENT_CHUNK_POLLING_INTERVAL_SECONDS
 from mindsdb.utilities import log
 from mindsdb.utilities.context_executor import ContextThreadPoolExecutor
@@ -54,8 +54,10 @@ from mindsdb.interfaces.agents.constants import (
     NVIDIA_NIM_CHAT_MODELS,
     USER_COLUMN,
     ASSISTANT_COLUMN,
-    CONTEXT_COLUMN,
-
+    CONTEXT_COLUMN,
+    TRACE_ID_COLUMN,
+    DEFAULT_AGENT_SYSTEM_PROMPT,
+    WRITER_CHAT_MODELS,
 )
 from mindsdb.interfaces.skills.skill_tool import skill_tool, SkillData
 from langchain_anthropic import ChatAnthropic
@@ -88,6 +90,9 @@ def get_llm_provider(args: Dict) -> str:
         return "nvidia_nim"
     if args["model_name"] in GOOGLE_GEMINI_CHAT_MODELS:
         return "google"
+    # Check for writer models
+    if args["model_name"] in WRITER_CHAT_MODELS:
+        return "writer"

     # For vLLM, require explicit provider specification
     raise ValueError("Invalid model name. Please define a supported llm provider")
@@ -101,21 +106,21 @@ def get_embedding_model_provider(args: Dict) -> str:
     # Check for explicit embedding model provider
     if "embedding_model_provider" in args:
         provider = args["embedding_model_provider"]
-        if provider ==
-            if not (args.get(
+        if provider == "vllm":
+            if not (args.get("openai_api_base") and args.get("model")):
                 raise ValueError(
                     "VLLM embeddings configuration error:\n"
                     "- Missing required parameters: 'openai_api_base' and/or 'model'\n"
                     "- Example: openai_api_base='http://localhost:8003/v1', model='your-model-name'"
                 )
             logger.info("Using custom VLLMEmbeddings class")
-            return
+            return "vllm"
         return provider

     # Check if LLM provider is vLLM
-    llm_provider = args.get(
-    if llm_provider ==
-        if not (args.get(
+    llm_provider = args.get("provider", DEFAULT_EMBEDDINGS_MODEL_PROVIDER)
+    if llm_provider == "vllm":
+        if not (args.get("openai_api_base") and args.get("model")):
             raise ValueError(
                 "VLLM embeddings configuration error:\n"
                 "- Missing required parameters: 'openai_api_base' and/or 'model'\n"
@@ -123,7 +128,7 @@ get_embedding_model_provider(args: Dict) -> str:
                 "- Example: openai_api_base='http://localhost:8003/v1', model='your-model-name'"
             )
         logger.info("Using custom VLLMEmbeddings class")
-        return
+        return "vllm"

     # Default to LLM provider
     return llm_provider
@@ -132,14 +137,15 @@ def get_embedding_model_provider(args: Dict) -> str:
 def get_chat_model_params(args: Dict) -> Dict:
     model_config = args.copy()
     # Include API keys.
-    model_config["api_keys"] = {
-
-
-    llm_config = get_llm_config(
-        args.get("provider", get_llm_provider(args)), model_config
-    )
-    config_dict = llm_config.model_dump()
+    model_config["api_keys"] = {p: get_api_key(p, model_config, None, strict=False) for p in SUPPORTED_PROVIDERS}
+    llm_config = get_llm_config(args.get("provider", get_llm_provider(args)), model_config)
+    config_dict = llm_config.model_dump(by_alias=True)
     config_dict = {k: v for k, v in config_dict.items() if v is not None}
+
+    # If provider is writer, ensure the API key is passed as 'api_key'
+    if args.get("provider") == "writer" and "writer_api_key" in config_dict:
+        config_dict["api_key"] = config_dict.pop("writer_api_key")
+
     return config_dict


@@ -167,9 +173,11 @@ def create_chat_model(args: Dict):
         return ChatNVIDIA(**model_kwargs)
     if args["provider"] == "google":
         return ChatGoogleGenerativeAI(**model_kwargs)
+    if args["provider"] == "writer":
+        return ChatWriter(**model_kwargs)
     if args["provider"] == "mindsdb":
         return ChatMindsdb(**model_kwargs)
-    raise ValueError(f
+    raise ValueError(f"Unknown provider: {args['provider']}")


 def prepare_prompts(df, base_template, input_variables, user_column=USER_COLUMN):
@@ -178,13 +186,13 @@ def prepare_prompts(df, base_template, input_variables, user_column=USER_COLUMN)
     # Combine system prompt with user-provided template
     base_template = f"{DEFAULT_AGENT_SYSTEM_PROMPT}\n\n{base_template}"

-    base_template = base_template.replace(
+    base_template = base_template.replace("{{", "{").replace("}}", "}")
     prompts = []

     for i, row in df.iterrows():
         if i not in empty_prompt_ids:
             prompt = PromptTemplate(input_variables=input_variables, template=base_template)
-            kwargs = {col: row[col] if row[col] is not None else
+            kwargs = {col: row[col] if row[col] is not None else "" for col in input_variables}
             prompts.append(prompt.format(**kwargs))
         elif row.get(user_column):
             prompts.append(row[user_column])
|
|
|
218
226
|
|
|
219
227
|
|
|
220
228
|
class LangchainAgent:
|
|
221
|
-
|
|
222
229
|
def __init__(self, agent: db.Agents, model: dict = None):
|
|
223
|
-
|
|
224
230
|
self.agent = agent
|
|
225
231
|
self.model = model
|
|
226
232
|
|
|
@@ -243,18 +249,14 @@ class LangchainAgent:
         args = self.agent.params.copy()
         args["model_name"] = self.agent.model_name
         args["provider"] = self.agent.provider
-        args["embedding_model_provider"] = args.get(
-            "embedding_model", get_embedding_model_provider(args)
-        )
+        args["embedding_model_provider"] = args.get("embedding_model", get_embedding_model_provider(args))

         # agent is using current langchain model
         if self.agent.provider == "mindsdb":
             args["model_name"] = self.agent.model_name

         # get prompt
-        prompt_template = (
-            self.model["problem_definition"].get("using", {}).get("prompt_template")
-        )
+        prompt_template = self.model["problem_definition"].get("using", {}).get("prompt_template")
         if prompt_template is not None:
             # only update prompt_template if it is set on the model
             args["prompt_template"] = prompt_template
@@ -263,24 +265,23 @@ class LangchainAgent:
         if args.get("mode") == "retrieval":
             args["prompt_template"] = DEFAULT_RAG_PROMPT_TEMPLATE
         else:
-            raise ValueError(
-                "Please provide a `prompt_template` or set `mode=retrieval`"
-            )
+            raise ValueError("Please provide a `prompt_template` or set `mode=retrieval`")

         return args

     def get_metadata(self) -> Dict:
         return {
-
-
-
-
-
-
-
-
-
-
+            "provider": self.provider,
+            "model_name": self.args["model_name"],
+            "embedding_model_provider": self.args.get(
+                "embedding_model_provider", get_embedding_model_provider(self.args)
+            ),
+            "skills": get_skills(self.agent),
+            "user_id": ctx.user_id,
+            "session_id": ctx.session_id,
+            "company_id": ctx.company_id,
+            "user_class": ctx.user_class,
+            "email_confirmed": ctx.email_confirmed,
         }

     def get_tags(self) -> List:
@@ -289,14 +290,13 @@ class LangchainAgent:
         ]

     def get_completion(self, messages, stream: bool = False):
-
         # Get metadata and tags to be used in the trace
         metadata = self.get_metadata()
         tags = self.get_tags()

         # Set up trace for the API completion in Langfuse
         self.langfuse_client_wrapper.setup_trace(
-            name=
+            name="api-completion",
             input=messages,
             tags=tags,
             metadata=metadata,
@@ -305,9 +305,7 @@ class LangchainAgent:
         )

         # Set up trace for the run completion in Langfuse
-        self.run_completion_span = self.langfuse_client_wrapper.start_span(
-            name='run-completion',
-            input=messages)
+        self.run_completion_span = self.langfuse_client_wrapper.start_span(name="run-completion", input=messages)

         if stream:
             return self._get_completion_stream(messages)
@@ -345,7 +343,7 @@ class LangchainAgent:

         df = pd.DataFrame(messages)

-        self.embedding_model_provider = args.get(
+        self.embedding_model_provider = args.get("embedding_model_provider", get_embedding_model_provider(args))
         # Back compatibility for old models
         self.provider = args.get("provider", get_llm_provider(args))

@@ -398,7 +396,7 @@ class LangchainAgent:
             agent=agent_type,
             # Use custom output parser to handle flaky LLMs that don't ALWAYS conform to output format.
             agent_kwargs={"output_parser": SafeOutputParser()},
-            # Calls the agent
+            # Calls the agent's LLM Chain one final time to generate a final answer based on the previous steps
             early_stopping_method="generate",
             handle_parsing_errors=self._handle_parsing_errors,
             # Timeout per agent invocation.
@@ -406,11 +404,9 @@ class LangchainAgent:
                 "timeout_seconds",
                 args.get("timeout_seconds", DEFAULT_AGENT_TIMEOUT_SECONDS),
             ),
-            max_iterations=args.get(
-                "max_iterations", args.get("max_iterations", DEFAULT_MAX_ITERATIONS)
-            ),
+            max_iterations=args.get("max_iterations", args.get("max_iterations", DEFAULT_MAX_ITERATIONS)),
             memory=memory,
-            verbose=args.get("verbose", args.get("verbose", False))
+            verbose=args.get("verbose", args.get("verbose", False)),
         )
         return agent_executor

@@ -422,7 +418,7 @@ class LangchainAgent:
                 type=rel.skill.type,
                 params=rel.skill.params,
                 project_id=rel.skill.project_id,
-                agent_tables_list=(rel.parameters or {}).get(
+                agent_tables_list=(rel.parameters or {}).get("tables"),
             )
             for rel in self.agent.skills_relationships
         ]
@@ -513,21 +509,22 @@ AI: {response}"""
         return f"Agent failed with error:\n{str(error)}..."

     def run_agent(self, df: pd.DataFrame, agent: AgentExecutor, args: Dict) -> pd.DataFrame:
-        base_template = args.get(
-        return_context = args.get(
+        base_template = args.get("prompt_template", args["prompt_template"])
+        return_context = args.get("return_context", True)
         input_variables = re.findall(r"{{(.*?)}}", base_template)

-        prompts, empty_prompt_ids = prepare_prompts(
-
+        prompts, empty_prompt_ids = prepare_prompts(
+            df, base_template, input_variables, args.get("user_column", USER_COLUMN)
+        )

         def _invoke_agent_executor_with_prompt(agent_executor, prompt):
             if not prompt:
                 return {CONTEXT_COLUMN: [], ASSISTANT_COLUMN: ""}
             try:
                 callbacks, context_callback = prepare_callbacks(self, args)
-                result = agent_executor.invoke(prompt, config={
+                result = agent_executor.invoke(prompt, config={"callbacks": callbacks})
                 captured_context = context_callback.get_contexts()
-                output = result[
+                output = result["output"] if isinstance(result, dict) and "output" in result else str(result)
                 return {CONTEXT_COLUMN: captured_context, ASSISTANT_COLUMN: output}
             except Exception as e:
                 error_message = str(e)
@@ -536,7 +533,10 @@ AI: {response}"""
                 # Format API key error more clearly
                 logger.error(f"API Key Error: {error_message}")
                 error_message = f"API Key Error: {error_message}"
-                return {
+                return {
+                    CONTEXT_COLUMN: [],
+                    ASSISTANT_COLUMN: handle_agent_error(e, error_message),
+                }

         completions = []
         contexts = []
@@ -545,10 +545,7 @@ AI: {response}"""
         agent_timeout_seconds = args.get("timeout", DEFAULT_AGENT_TIMEOUT_SECONDS)

         with ContextThreadPoolExecutor(max_workers=max_workers) as executor:
-            futures = [
-                executor.submit(_invoke_agent_executor_with_prompt, agent, prompt)
-                for prompt in prompts
-            ]
+            futures = [executor.submit(_invoke_agent_executor_with_prompt, agent, prompt) for prompt in prompts]
             try:
                 for future in as_completed(futures, timeout=agent_timeout_seconds):
                     result = future.result()
@@ -562,9 +559,7 @@ AI: {response}"""
                     contexts.append(result[CONTEXT_COLUMN])
             except TimeoutError:
                 timeout_message = "I'm sorry! I couldn't come up with a response in time. Please try again."
-                logger.warning(
-                    f"Agent execution timed out after {agent_timeout_seconds} seconds"
-                )
+                logger.warning(f"Agent execution timed out after {agent_timeout_seconds} seconds")
                 for _ in range(len(futures) - len(completions)):
                     completions.append(timeout_message)
                     contexts.append([])
@@ -578,10 +573,8 @@ AI: {response}"""
         pred_df = pd.DataFrame(
             {
                 ASSISTANT_COLUMN: completions,
-                CONTEXT_COLUMN: [
-
-                ],  # Serialize context to JSON string
-                TRACE_ID_COLUMN: self.langfuse_client_wrapper.get_trace_id()
+                CONTEXT_COLUMN: [json.dumps(ctx) for ctx in contexts],  # Serialize context to JSON string
+                TRACE_ID_COLUMN: self.langfuse_client_wrapper.get_trace_id(),
             }
         )

@@ -591,17 +584,22 @@ AI: {response}"""
         return pred_df

     def add_chunk_metadata(self, chunk: Dict) -> Dict:
-        logger.debug(f
-        logger.debug(f
+        logger.debug(f"Adding metadata to chunk: {chunk}")
+        logger.debug(f"Trace ID: {self.langfuse_client_wrapper.get_trace_id()}")
         chunk["trace_id"] = self.langfuse_client_wrapper.get_trace_id()
         return chunk

-    def _stream_agent_executor(
+    def _stream_agent_executor(
+        self,
+        agent_executor: AgentExecutor,
+        prompt: str,
+        callbacks: List[BaseCallbackHandler],
+    ):
         chunk_queue = queue.Queue()
         # Add event dispatch callback handler only to streaming completions.
         event_dispatch_callback_handler = EventDispatchCallbackHandler(chunk_queue)
         callbacks.append(event_dispatch_callback_handler)
-        stream_iterator = agent_executor.stream(prompt, config={
+        stream_iterator = agent_executor.stream(prompt, config={"callbacks": callbacks})

         agent_executor_finished_event = threading.Event()

|
|
|
616
614
|
|
|
617
615
|
# Enqueue Langchain agent streaming chunks in a separate thread to not block event chunks.
|
|
618
616
|
executor_stream_thread = threading.Thread(
|
|
619
|
-
target=stream_worker,
|
|
617
|
+
target=stream_worker,
|
|
618
|
+
daemon=True,
|
|
619
|
+
args=(ctx.dump(),),
|
|
620
|
+
name="LangchainAgent.stream_worker",
|
|
620
621
|
)
|
|
621
622
|
executor_stream_thread.start()
|
|
622
623
|
|
|
@@ -625,24 +626,24 @@ AI: {response}"""
                 chunk = chunk_queue.get(block=True, timeout=AGENT_CHUNK_POLLING_INTERVAL_SECONDS)
             except queue.Empty:
                 continue
-            logger.debug(f
+            logger.debug(f"Processing streaming chunk {chunk}")
             processed_chunk = self.process_chunk(chunk)
-            logger.info(f
+            logger.info(f"Processed chunk: {processed_chunk}")
             yield self.add_chunk_metadata(processed_chunk)
             chunk_queue.task_done()

     def stream_agent(self, df: pd.DataFrame, agent_executor: AgentExecutor, args: Dict) -> Iterable[Dict]:
-        base_template = args.get(
+        base_template = args.get("prompt_template", args["prompt_template"])
         input_variables = re.findall(r"{{(.*?)}}", base_template)
-        return_context = args.get(
+        return_context = args.get("return_context", True)

-        prompts, _ = prepare_prompts(df, base_template, input_variables, args.get(
+        prompts, _ = prepare_prompts(df, base_template, input_variables, args.get("user_column", USER_COLUMN))

         callbacks, context_callback = prepare_callbacks(self, args)

         yield self.add_chunk_metadata({"type": "start", "prompt": prompts[0]})

-        if not hasattr(agent_executor,
+        if not hasattr(agent_executor, "stream") or not callable(agent_executor.stream):
             raise AttributeError("The agent_executor does not have a 'stream' method")

         stream_iterator = self._stream_agent_executor(agent_executor, prompts[0], callbacks)
@@ -671,21 +672,19 @@ AI: {response}"""
         if isinstance(chunk, AgentAction):
             # Format agent actions properly for streaming.
             return {
-
-
-
+                "tool": LangchainAgent.process_chunk(chunk.tool),
+                "tool_input": LangchainAgent.process_chunk(chunk.tool_input),
+                "log": LangchainAgent.process_chunk(chunk.log),
             }
         if isinstance(chunk, AgentStep):
             # Format agent steps properly for streaming.
             return {
-
-
+                "action": LangchainAgent.process_chunk(chunk.action),
+                "observation": LangchainAgent.process_chunk(chunk.observation) if chunk.observation else "",
             }
         if issubclass(chunk.__class__, BaseMessage):
             # Extract content from message subclasses properly for streaming.
-            return {
-                'content': chunk.content
-            }
+            return {"content": chunk.content}
         if isinstance(chunk, (str, int, float, bool, type(None))):
             return chunk
         return str(chunk)