MindsDB 25.5.4.0__py3-none-any.whl → 25.5.4.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.

Potentially problematic release: this version of MindsDB might be problematic.

@@ -11,10 +11,8 @@ import pandas as pd
 from langchain.agents import AgentExecutor
 from langchain.agents.initialize import initialize_agent
 from langchain.chains.conversation.memory import ConversationSummaryBufferMemory
-from langchain_community.chat_models import (
-    ChatAnyscale,
-    ChatLiteLLM,
-    ChatOllama)
+from langchain_community.chat_models import ChatAnyscale, ChatLiteLLM, ChatOllama
+from langchain_writer import ChatWriter
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_core.agents import AgentAction, AgentStep
 from langchain_core.callbacks.base import BaseCallbackHandler
@@ -27,7 +25,9 @@ from langchain_core.tools import Tool
 from mindsdb.integrations.libs.llm.utils import get_llm_config
 from mindsdb.integrations.utilities.handler_utils import get_api_key
 from mindsdb.integrations.utilities.rag.settings import DEFAULT_RAG_PROMPT_TEMPLATE
-from mindsdb.interfaces.agents.event_dispatch_callback_handler import EventDispatchCallbackHandler
+from mindsdb.interfaces.agents.event_dispatch_callback_handler import (
+    EventDispatchCallbackHandler,
+)
 from mindsdb.interfaces.agents.constants import AGENT_CHUNK_POLLING_INTERVAL_SECONDS
 from mindsdb.utilities import log
 from mindsdb.utilities.context_executor import ContextThreadPoolExecutor
@@ -54,8 +54,10 @@ from mindsdb.interfaces.agents.constants import (
     NVIDIA_NIM_CHAT_MODELS,
     USER_COLUMN,
     ASSISTANT_COLUMN,
-    CONTEXT_COLUMN, TRACE_ID_COLUMN,
-    DEFAULT_AGENT_SYSTEM_PROMPT
+    CONTEXT_COLUMN,
+    TRACE_ID_COLUMN,
+    DEFAULT_AGENT_SYSTEM_PROMPT,
+    WRITER_CHAT_MODELS,
 )
 from mindsdb.interfaces.skills.skill_tool import skill_tool, SkillData
 from langchain_anthropic import ChatAnthropic
@@ -88,6 +90,9 @@ def get_llm_provider(args: Dict) -> str:
         return "nvidia_nim"
     if args["model_name"] in GOOGLE_GEMINI_CHAT_MODELS:
         return "google"
+    # Check for writer models
+    if args["model_name"] in WRITER_CHAT_MODELS:
+        return "writer"
 
     # For vLLM, require explicit provider specification
     raise ValueError("Invalid model name. Please define a supported llm provider")
@@ -101,21 +106,21 @@ def get_embedding_model_provider(args: Dict) -> str:
     # Check for explicit embedding model provider
     if "embedding_model_provider" in args:
         provider = args["embedding_model_provider"]
-        if provider == 'vllm':
-            if not (args.get('openai_api_base') and args.get('model')):
+        if provider == "vllm":
+            if not (args.get("openai_api_base") and args.get("model")):
                 raise ValueError(
                     "VLLM embeddings configuration error:\n"
                     "- Missing required parameters: 'openai_api_base' and/or 'model'\n"
                     "- Example: openai_api_base='http://localhost:8003/v1', model='your-model-name'"
                 )
             logger.info("Using custom VLLMEmbeddings class")
-            return 'vllm'
+            return "vllm"
         return provider
 
     # Check if LLM provider is vLLM
-    llm_provider = args.get('provider', DEFAULT_EMBEDDINGS_MODEL_PROVIDER)
-    if llm_provider == 'vllm':
-        if not (args.get('openai_api_base') and args.get('model')):
+    llm_provider = args.get("provider", DEFAULT_EMBEDDINGS_MODEL_PROVIDER)
+    if llm_provider == "vllm":
+        if not (args.get("openai_api_base") and args.get("model")):
             raise ValueError(
                 "VLLM embeddings configuration error:\n"
                 "- Missing required parameters: 'openai_api_base' and/or 'model'\n"
@@ -123,7 +128,7 @@ def get_embedding_model_provider(args: Dict) -> str:
                 "- Example: openai_api_base='http://localhost:8003/v1', model='your-model-name'"
             )
         logger.info("Using custom VLLMEmbeddings class")
-        return 'vllm'
+        return "vllm"
 
     # Default to LLM provider
     return llm_provider
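Note: both vLLM branches validate the same two parameters. A minimal sketch of args that pass the check, reusing the placeholders from the error message itself:

    args = {
        "embedding_model_provider": "vllm",
        "openai_api_base": "http://localhost:8003/v1",
        "model": "your-model-name",
    }
    assert get_embedding_model_provider(args) == "vllm"
    # Dropping "model" or "openai_api_base" raises the ValueError shown above.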
@@ -132,14 +137,15 @@ def get_embedding_model_provider(args: Dict) -> str:
 def get_chat_model_params(args: Dict) -> Dict:
     model_config = args.copy()
     # Include API keys.
-    model_config["api_keys"] = {
-        p: get_api_key(p, model_config, None, strict=False) for p in SUPPORTED_PROVIDERS
-    }
-    llm_config = get_llm_config(
-        args.get("provider", get_llm_provider(args)), model_config
-    )
-    config_dict = llm_config.model_dump()
+    model_config["api_keys"] = {p: get_api_key(p, model_config, None, strict=False) for p in SUPPORTED_PROVIDERS}
+    llm_config = get_llm_config(args.get("provider", get_llm_provider(args)), model_config)
+    config_dict = llm_config.model_dump(by_alias=True)
     config_dict = {k: v for k, v in config_dict.items() if v is not None}
+
+    # If provider is writer, ensure the API key is passed as 'api_key'
+    if args.get("provider") == "writer" and "writer_api_key" in config_dict:
+        config_dict["api_key"] = config_dict.pop("writer_api_key")
+
     return config_dict
 
 
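Note: the new remap exists because the aliased model dump (by_alias=True) apparently emits the Writer key as "writer_api_key", while the chat model expects it under "api_key". A self-contained sketch of just that step (key value and model name are placeholders):

    config_dict = {"model": "palmyra-x-004", "writer_api_key": "wr-placeholder"}
    if "writer_api_key" in config_dict:
        config_dict["api_key"] = config_dict.pop("writer_api_key")
    # config_dict == {"model": "palmyra-x-004", "api_key": "wr-placeholder"}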
@@ -167,9 +173,11 @@ def create_chat_model(args: Dict):
         return ChatNVIDIA(**model_kwargs)
     if args["provider"] == "google":
         return ChatGoogleGenerativeAI(**model_kwargs)
+    if args["provider"] == "writer":
+        return ChatWriter(**model_kwargs)
     if args["provider"] == "mindsdb":
         return ChatMindsdb(**model_kwargs)
-    raise ValueError(f'Unknown provider: {args["provider"]}')
+    raise ValueError(f"Unknown provider: {args['provider']}")
 
 
 def prepare_prompts(df, base_template, input_variables, user_column=USER_COLUMN):
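Note: a hedged end-to-end sketch of the new dispatch branch; the model name and key below are placeholders, and how model_kwargs is assembled (presumably via get_chat_model_params) is outside this hunk:

    args = {"provider": "writer", "model_name": "palmyra-x-004", "writer_api_key": "wr-placeholder"}
    chat_model = create_chat_model(args)  # reaches the new branch and returns ChatWriter(**model_kwargs)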
@@ -178,13 +186,13 @@ def prepare_prompts(df, base_template, input_variables, user_column=USER_COLUMN)
     # Combine system prompt with user-provided template
     base_template = f"{DEFAULT_AGENT_SYSTEM_PROMPT}\n\n{base_template}"
 
-    base_template = base_template.replace('{{', '{').replace('}}', '}')
+    base_template = base_template.replace("{{", "{").replace("}}", "}")
     prompts = []
 
     for i, row in df.iterrows():
         if i not in empty_prompt_ids:
             prompt = PromptTemplate(input_variables=input_variables, template=base_template)
-            kwargs = {col: row[col] if row[col] is not None else '' for col in input_variables}
+            kwargs = {col: row[col] if row[col] is not None else "" for col in input_variables}
             prompts.append(prompt.format(**kwargs))
         elif row.get(user_column):
             prompts.append(row[user_column])
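Note: the double-brace replacement converts the agent's "{{variable}}" placeholders (the same form run_agent extracts with re.findall(r"{{(.*?)}}")) into the single-brace form PromptTemplate interpolates. A two-line sketch:

    template = "Answer using {{context}}: {{question}}"
    template = template.replace("{{", "{").replace("}}", "}")
    # -> "Answer using {context}: {question}", ready for PromptTemplate(...).format(...)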
@@ -218,9 +226,7 @@ def process_chunk(chunk):
 
 
 class LangchainAgent:
-
     def __init__(self, agent: db.Agents, model: dict = None):
-
         self.agent = agent
         self.model = model
 
@@ -243,18 +249,14 @@ class LangchainAgent:
         args = self.agent.params.copy()
         args["model_name"] = self.agent.model_name
         args["provider"] = self.agent.provider
-        args["embedding_model_provider"] = args.get(
-            "embedding_model", get_embedding_model_provider(args)
-        )
+        args["embedding_model_provider"] = args.get("embedding_model", get_embedding_model_provider(args))
 
         # agent is using current langchain model
         if self.agent.provider == "mindsdb":
             args["model_name"] = self.agent.model_name
 
         # get prompt
-        prompt_template = (
-            self.model["problem_definition"].get("using", {}).get("prompt_template")
-        )
+        prompt_template = self.model["problem_definition"].get("using", {}).get("prompt_template")
         if prompt_template is not None:
             # only update prompt_template if it is set on the model
             args["prompt_template"] = prompt_template
@@ -263,24 +265,23 @@ class LangchainAgent:
         if args.get("mode") == "retrieval":
             args["prompt_template"] = DEFAULT_RAG_PROMPT_TEMPLATE
         else:
-            raise ValueError(
-                "Please provide a `prompt_template` or set `mode=retrieval`"
-            )
+            raise ValueError("Please provide a `prompt_template` or set `mode=retrieval`")
 
         return args
 
     def get_metadata(self) -> Dict:
         return {
-            'provider': self.provider,
-            'model_name': self.args["model_name"],
-            'embedding_model_provider': self.args.get('embedding_model_provider',
-                                                      get_embedding_model_provider(self.args)),
-            'skills': get_skills(self.agent),
-            'user_id': ctx.user_id,
-            'session_id': ctx.session_id,
-            'company_id': ctx.company_id,
-            'user_class': ctx.user_class,
-            'email_confirmed': ctx.email_confirmed
+            "provider": self.provider,
+            "model_name": self.args["model_name"],
+            "embedding_model_provider": self.args.get(
+                "embedding_model_provider", get_embedding_model_provider(self.args)
+            ),
+            "skills": get_skills(self.agent),
+            "user_id": ctx.user_id,
+            "session_id": ctx.session_id,
+            "company_id": ctx.company_id,
+            "user_class": ctx.user_class,
+            "email_confirmed": ctx.email_confirmed,
         }
 
     def get_tags(self) -> List:
@@ -289,14 +290,13 @@ class LangchainAgent:
         ]
 
     def get_completion(self, messages, stream: bool = False):
-
         # Get metadata and tags to be used in the trace
         metadata = self.get_metadata()
         tags = self.get_tags()
 
         # Set up trace for the API completion in Langfuse
         self.langfuse_client_wrapper.setup_trace(
-            name='api-completion',
+            name="api-completion",
             input=messages,
             tags=tags,
             metadata=metadata,
@@ -305,9 +305,7 @@ class LangchainAgent:
         )
 
         # Set up trace for the run completion in Langfuse
-        self.run_completion_span = self.langfuse_client_wrapper.start_span(
-            name='run-completion',
-            input=messages)
+        self.run_completion_span = self.langfuse_client_wrapper.start_span(name="run-completion", input=messages)
 
         if stream:
             return self._get_completion_stream(messages)
@@ -345,7 +343,7 @@ class LangchainAgent:
 
         df = pd.DataFrame(messages)
 
-        self.embedding_model_provider = args.get('embedding_model_provider', get_embedding_model_provider(args))
+        self.embedding_model_provider = args.get("embedding_model_provider", get_embedding_model_provider(args))
         # Back compatibility for old models
         self.provider = args.get("provider", get_llm_provider(args))
 
@@ -398,7 +396,7 @@ class LangchainAgent:
             agent=agent_type,
             # Use custom output parser to handle flaky LLMs that don't ALWAYS conform to output format.
             agent_kwargs={"output_parser": SafeOutputParser()},
-            # Calls the agents LLM Chain one final time to generate a final answer based on the previous steps
+            # Calls the agent's LLM Chain one final time to generate a final answer based on the previous steps
             early_stopping_method="generate",
             handle_parsing_errors=self._handle_parsing_errors,
             # Timeout per agent invocation.
@@ -406,11 +404,9 @@ class LangchainAgent:
                 "timeout_seconds",
                 args.get("timeout_seconds", DEFAULT_AGENT_TIMEOUT_SECONDS),
             ),
-            max_iterations=args.get(
-                "max_iterations", args.get("max_iterations", DEFAULT_MAX_ITERATIONS)
-            ),
+            max_iterations=args.get("max_iterations", args.get("max_iterations", DEFAULT_MAX_ITERATIONS)),
             memory=memory,
-            verbose=args.get("verbose", args.get("verbose", False))
+            verbose=args.get("verbose", args.get("verbose", False)),
         )
         return agent_executor
 
@@ -422,7 +418,7 @@ class LangchainAgent:
                 type=rel.skill.type,
                 params=rel.skill.params,
                 project_id=rel.skill.project_id,
-                agent_tables_list=(rel.parameters or {}).get('tables')
+                agent_tables_list=(rel.parameters or {}).get("tables"),
             )
             for rel in self.agent.skills_relationships
         ]
@@ -513,21 +509,22 @@ AI: {response}"""
         return f"Agent failed with error:\n{str(error)}..."
 
     def run_agent(self, df: pd.DataFrame, agent: AgentExecutor, args: Dict) -> pd.DataFrame:
-        base_template = args.get('prompt_template', args['prompt_template'])
-        return_context = args.get('return_context', True)
+        base_template = args.get("prompt_template", args["prompt_template"])
+        return_context = args.get("return_context", True)
         input_variables = re.findall(r"{{(.*?)}}", base_template)
 
-        prompts, empty_prompt_ids = prepare_prompts(df, base_template, input_variables,
-                                                    args.get('user_column', USER_COLUMN))
+        prompts, empty_prompt_ids = prepare_prompts(
+            df, base_template, input_variables, args.get("user_column", USER_COLUMN)
+        )
 
         def _invoke_agent_executor_with_prompt(agent_executor, prompt):
             if not prompt:
                 return {CONTEXT_COLUMN: [], ASSISTANT_COLUMN: ""}
             try:
                 callbacks, context_callback = prepare_callbacks(self, args)
-                result = agent_executor.invoke(prompt, config={'callbacks': callbacks})
+                result = agent_executor.invoke(prompt, config={"callbacks": callbacks})
                 captured_context = context_callback.get_contexts()
-                output = result['output'] if isinstance(result, dict) and 'output' in result else str(result)
+                output = result["output"] if isinstance(result, dict) and "output" in result else str(result)
                 return {CONTEXT_COLUMN: captured_context, ASSISTANT_COLUMN: output}
             except Exception as e:
                 error_message = str(e)
@@ -536,7 +533,10 @@ AI: {response}"""
                     # Format API key error more clearly
                     logger.error(f"API Key Error: {error_message}")
                     error_message = f"API Key Error: {error_message}"
-                return {CONTEXT_COLUMN: [], ASSISTANT_COLUMN: handle_agent_error(e, error_message)}
+                return {
+                    CONTEXT_COLUMN: [],
+                    ASSISTANT_COLUMN: handle_agent_error(e, error_message),
+                }
 
         completions = []
         contexts = []
@@ -545,10 +545,7 @@ AI: {response}"""
         agent_timeout_seconds = args.get("timeout", DEFAULT_AGENT_TIMEOUT_SECONDS)
 
         with ContextThreadPoolExecutor(max_workers=max_workers) as executor:
-            futures = [
-                executor.submit(_invoke_agent_executor_with_prompt, agent, prompt)
-                for prompt in prompts
-            ]
+            futures = [executor.submit(_invoke_agent_executor_with_prompt, agent, prompt) for prompt in prompts]
             try:
                 for future in as_completed(futures, timeout=agent_timeout_seconds):
                     result = future.result()
@@ -562,9 +559,7 @@ AI: {response}"""
                     contexts.append(result[CONTEXT_COLUMN])
             except TimeoutError:
                 timeout_message = "I'm sorry! I couldn't come up with a response in time. Please try again."
-                logger.warning(
-                    f"Agent execution timed out after {agent_timeout_seconds} seconds"
-                )
+                logger.warning(f"Agent execution timed out after {agent_timeout_seconds} seconds")
                 for _ in range(len(futures) - len(completions)):
                     completions.append(timeout_message)
                     contexts.append([])
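Note: the fan-out above submits one agent invocation per prompt and collects results as they complete, under a single overall timeout. A standalone sketch of the same pattern, using the standard-library executor in place of MindsDB's ContextThreadPoolExecutor:

    from concurrent.futures import ThreadPoolExecutor, TimeoutError, as_completed

    def invoke(prompt):  # stand-in for _invoke_agent_executor_with_prompt
        return prompt.upper()

    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(invoke, p) for p in ["hello", "world"]]
        try:
            for future in as_completed(futures, timeout=10):
                print(future.result())
        except TimeoutError:
            print("timed out; remaining slots get a fallback message")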
@@ -578,10 +573,8 @@ AI: {response}"""
         pred_df = pd.DataFrame(
             {
                 ASSISTANT_COLUMN: completions,
-                CONTEXT_COLUMN: [
-                    json.dumps(ctx) for ctx in contexts
-                ],  # Serialize context to JSON string
-                TRACE_ID_COLUMN: self.langfuse_client_wrapper.get_trace_id()
+                CONTEXT_COLUMN: [json.dumps(ctx) for ctx in contexts],  # Serialize context to JSON string
+                TRACE_ID_COLUMN: self.langfuse_client_wrapper.get_trace_id(),
             }
         )
 
@@ -591,17 +584,22 @@ AI: {response}"""
         return pred_df
 
     def add_chunk_metadata(self, chunk: Dict) -> Dict:
-        logger.debug(f'Adding metadata to chunk: {chunk}')
-        logger.debug(f'Trace ID: {self.langfuse_client_wrapper.get_trace_id()}')
+        logger.debug(f"Adding metadata to chunk: {chunk}")
+        logger.debug(f"Trace ID: {self.langfuse_client_wrapper.get_trace_id()}")
         chunk["trace_id"] = self.langfuse_client_wrapper.get_trace_id()
         return chunk
 
-    def _stream_agent_executor(self, agent_executor: AgentExecutor, prompt: str, callbacks: List[BaseCallbackHandler]):
+    def _stream_agent_executor(
+        self,
+        agent_executor: AgentExecutor,
+        prompt: str,
+        callbacks: List[BaseCallbackHandler],
+    ):
         chunk_queue = queue.Queue()
         # Add event dispatch callback handler only to streaming completions.
         event_dispatch_callback_handler = EventDispatchCallbackHandler(chunk_queue)
         callbacks.append(event_dispatch_callback_handler)
-        stream_iterator = agent_executor.stream(prompt, config={'callbacks': callbacks})
+        stream_iterator = agent_executor.stream(prompt, config={"callbacks": callbacks})
 
         agent_executor_finished_event = threading.Event()
 
@@ -616,7 +614,10 @@ AI: {response}"""
 
         # Enqueue Langchain agent streaming chunks in a separate thread to not block event chunks.
         executor_stream_thread = threading.Thread(
-            target=stream_worker, daemon=True, args=(ctx.dump(),), name='LangchainAgent.stream_worker'
+            target=stream_worker,
+            daemon=True,
+            args=(ctx.dump(),),
+            name="LangchainAgent.stream_worker",
         )
         executor_stream_thread.start()
 
@@ -625,24 +626,24 @@ AI: {response}"""
                 chunk = chunk_queue.get(block=True, timeout=AGENT_CHUNK_POLLING_INTERVAL_SECONDS)
             except queue.Empty:
                 continue
-            logger.debug(f'Processing streaming chunk {chunk}')
+            logger.debug(f"Processing streaming chunk {chunk}")
             processed_chunk = self.process_chunk(chunk)
-            logger.info(f'Processed chunk: {processed_chunk}')
+            logger.info(f"Processed chunk: {processed_chunk}")
             yield self.add_chunk_metadata(processed_chunk)
             chunk_queue.task_done()
 
     def stream_agent(self, df: pd.DataFrame, agent_executor: AgentExecutor, args: Dict) -> Iterable[Dict]:
-        base_template = args.get('prompt_template', args['prompt_template'])
+        base_template = args.get("prompt_template", args["prompt_template"])
         input_variables = re.findall(r"{{(.*?)}}", base_template)
-        return_context = args.get('return_context', True)
+        return_context = args.get("return_context", True)
 
-        prompts, _ = prepare_prompts(df, base_template, input_variables, args.get('user_column', USER_COLUMN))
+        prompts, _ = prepare_prompts(df, base_template, input_variables, args.get("user_column", USER_COLUMN))
 
         callbacks, context_callback = prepare_callbacks(self, args)
 
         yield self.add_chunk_metadata({"type": "start", "prompt": prompts[0]})
 
-        if not hasattr(agent_executor, 'stream') or not callable(agent_executor.stream):
+        if not hasattr(agent_executor, "stream") or not callable(agent_executor.stream):
             raise AttributeError("The agent_executor does not have a 'stream' method")
 
         stream_iterator = self._stream_agent_executor(agent_executor, prompts[0], callbacks)
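Note: _stream_agent_executor bridges a producer thread and a consumer loop through a queue, polling with a timeout so event chunks are never blocked behind the agent stream. A self-contained sketch of that bridge (the 0.1-second interval is a literal stand-in for AGENT_CHUNK_POLLING_INTERVAL_SECONDS):

    import queue
    import threading

    chunk_queue = queue.Queue()
    finished = threading.Event()

    def stream_worker():
        for chunk in ("a", "b", "c"):  # stand-in for agent_executor.stream(...)
            chunk_queue.put(chunk)
        finished.set()

    threading.Thread(target=stream_worker, daemon=True).start()
    while not (finished.is_set() and chunk_queue.empty()):
        try:
            chunk = chunk_queue.get(block=True, timeout=0.1)
        except queue.Empty:
            continue
        print(chunk)  # in the real code: yield self.add_chunk_metadata(self.process_chunk(chunk))
        chunk_queue.task_done()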
@@ -671,21 +672,19 @@ AI: {response}"""
         if isinstance(chunk, AgentAction):
             # Format agent actions properly for streaming.
             return {
-                'tool': LangchainAgent.process_chunk(chunk.tool),
-                'tool_input': LangchainAgent.process_chunk(chunk.tool_input),
-                'log': LangchainAgent.process_chunk(chunk.log)
+                "tool": LangchainAgent.process_chunk(chunk.tool),
+                "tool_input": LangchainAgent.process_chunk(chunk.tool_input),
+                "log": LangchainAgent.process_chunk(chunk.log),
             }
         if isinstance(chunk, AgentStep):
             # Format agent steps properly for streaming.
             return {
-                'action': LangchainAgent.process_chunk(chunk.action),
-                'observation': LangchainAgent.process_chunk(chunk.observation) if chunk.observation else ''
+                "action": LangchainAgent.process_chunk(chunk.action),
+                "observation": LangchainAgent.process_chunk(chunk.observation) if chunk.observation else "",
             }
         if issubclass(chunk.__class__, BaseMessage):
             # Extract content from message subclasses properly for streaming.
-            return {
-                'content': chunk.content
-            }
+            return {"content": chunk.content}
         if isinstance(chunk, (str, int, float, bool, type(None))):
             return chunk
         return str(chunk)
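Note: process_chunk is a recursive normalizer: structured Langchain objects become plain dicts, scalars pass through unchanged, and anything else falls back to str(). A short sketch of the base cases:

    assert LangchainAgent.process_chunk("hello") == "hello"  # scalar passthrough
    assert LangchainAgent.process_chunk(3.14) == 3.14
    assert LangchainAgent.process_chunk(object()).startswith("<object")  # str() fallback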