MindsDB 25.2.2.2__py3-none-any.whl → 25.2.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of MindsDB might be problematic.
- {MindsDB-25.2.2.2.dist-info → MindsDB-25.2.4.0.dist-info}/METADATA +209 -228
- {MindsDB-25.2.2.2.dist-info → MindsDB-25.2.4.0.dist-info}/RECORD +52 -50
- mindsdb/__about__.py +1 -1
- mindsdb/__main__.py +1 -11
- mindsdb/api/executor/datahub/datanodes/system_tables.py +4 -1
- mindsdb/api/http/initialize.py +8 -5
- mindsdb/api/http/namespaces/agents.py +0 -7
- mindsdb/api/http/namespaces/config.py +0 -48
- mindsdb/api/http/namespaces/databases.py +69 -1
- mindsdb/api/http/namespaces/knowledge_bases.py +1 -1
- mindsdb/api/http/namespaces/util.py +0 -28
- mindsdb/integrations/handlers/anyscale_endpoints_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/dspy_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/file_handler/file_handler.py +28 -46
- mindsdb/integrations/handlers/file_handler/tests/test_file_handler.py +8 -11
- mindsdb/integrations/handlers/langchain_embedding_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/langchain_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/llama_index_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/ms_one_drive_handler/ms_one_drive_tables.py +1 -1
- mindsdb/integrations/handlers/openai_handler/constants.py +3 -1
- mindsdb/integrations/handlers/openai_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/rag_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/ray_serve_handler/ray_serve_handler.py +33 -8
- mindsdb/integrations/handlers/timegpt_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/web_handler/urlcrawl_helpers.py +3 -2
- mindsdb/integrations/handlers/web_handler/web_handler.py +42 -33
- mindsdb/integrations/handlers/youtube_handler/__init__.py +2 -0
- mindsdb/integrations/handlers/youtube_handler/connection_args.py +32 -0
- mindsdb/integrations/libs/llm/utils.py +5 -0
- mindsdb/integrations/libs/process_cache.py +2 -2
- mindsdb/integrations/utilities/files/file_reader.py +66 -14
- mindsdb/integrations/utilities/rag/chains/local_context_summarizer_chain.py +227 -0
- mindsdb/interfaces/agents/agents_controller.py +3 -3
- mindsdb/interfaces/agents/callback_handlers.py +52 -5
- mindsdb/interfaces/agents/langchain_agent.py +5 -3
- mindsdb/interfaces/database/database.py +1 -1
- mindsdb/interfaces/database/integrations.py +1 -1
- mindsdb/interfaces/file/file_controller.py +140 -11
- mindsdb/interfaces/jobs/scheduler.py +1 -1
- mindsdb/interfaces/knowledge_base/preprocessing/constants.py +2 -2
- mindsdb/interfaces/skills/skills_controller.py +2 -2
- mindsdb/interfaces/skills/sql_agent.py +6 -1
- mindsdb/interfaces/storage/db.py +1 -12
- mindsdb/migrations/versions/2025-02-09_4943359e354a_file_metadata.py +31 -0
- mindsdb/migrations/versions/2025-02-10_6ab9903fc59a_del_log_table.py +33 -0
- mindsdb/utilities/config.py +1 -0
- mindsdb/utilities/log.py +17 -2
- mindsdb/utilities/ml_task_queue/consumer.py +4 -2
- mindsdb/utilities/render/sqlalchemy_render.py +15 -5
- mindsdb/utilities/log_controller.py +0 -39
- mindsdb/utilities/telemetry.py +0 -44
- {MindsDB-25.2.2.2.dist-info → MindsDB-25.2.4.0.dist-info}/LICENSE +0 -0
- {MindsDB-25.2.2.2.dist-info → MindsDB-25.2.4.0.dist-info}/WHEEL +0 -0
- {MindsDB-25.2.2.2.dist-info → MindsDB-25.2.4.0.dist-info}/top_level.txt +0 -0
mindsdb/integrations/utilities/files/file_reader.py

@@ -4,6 +4,7 @@ import csv
 from io import BytesIO, StringIO, IOBase
 from pathlib import Path
 import codecs
+from typing import List
 
 import filetype
 import pandas as pd
@@ -65,6 +66,7 @@ def decode(file_obj: IOBase) -> StringIO:
 class FormatDetector:
 
     supported_formats = ['parquet', 'csv', 'xlsx', 'pdf', 'json', 'txt']
+    multipage_formats = ['xlsx']
 
     def __init__(
         self,
@@ -200,16 +202,62 @@ class FormatDetector:
 
 class FileReader(FormatDetector):
 
-    def
+    def _get_fnc(self):
         format = self.get_format()
-
         func = getattr(self, f'read_{format}', None)
         if func is None:
             raise FileDetectError(f'Unsupported format: {format}')
+        return func
+
+    def get_pages(self, **kwargs) -> List[str]:
+        """
+        Get list of tables in file
+        """
+        format = self.get_format()
+        if format not in self.multipage_formats:
+            # only one table
+            return ['main']
+
+        func = self._get_fnc()
+        self.file_obj.seek(0)
 
+        return [
+            name for name, _ in
+            func(self.file_obj, only_names=True, **kwargs)
+        ]
+
+    def get_contents(self, **kwargs):
+        """
+        Get all info(pages with content) from file as dict: {tablename, content}
+        """
+        func = self._get_fnc()
         self.file_obj.seek(0)
-
-
+
+        format = self.get_format()
+        if format not in self.multipage_formats:
+            # only one table
+            return {'main': func(self.file_obj, name=self.name, **kwargs)}
+
+        return {
+            name: df
+            for name, df in
+            func(self.file_obj, **kwargs)
+        }
+
+    def get_page_content(self, page_name: str = None, **kwargs) -> pd.DataFrame:
+        """
+        Get content of a single table
+        """
+        func = self._get_fnc()
+        self.file_obj.seek(0)
+
+        format = self.get_format()
+        if format not in self.multipage_formats:
+            # only one table
+            return func(self.file_obj, name=self.name, **kwargs)
+
+        for _, df in func(self.file_obj, name=self.name, page_name=page_name, **kwargs):
+            return df
 
     @staticmethod
     def _get_csv_dialect(buffer, delimiter=None) -> csv.Dialect:
@@ -304,14 +352,18 @@ class FileReader(FormatDetector):
         return pd.read_parquet(file_obj)
 
     @staticmethod
-    def read_xlsx(file_obj: BytesIO,
-
-        file_obj.seek(0)
+    def read_xlsx(file_obj: BytesIO, page_name=None, only_names=False, **kwargs):
         with pd.ExcelFile(file_obj) as xls:
-
-
-
-
-
-
-
+
+            if page_name is not None:
+                # return specific page
+                yield page_name, pd.read_excel(xls, sheet_name=page_name)
+
+            for page_name in xls.sheet_names:
+
+                if only_names:
+                    # extract only pages names
+                    df = None
+                else:
+                    df = pd.read_excel(xls, sheet_name=page_name)
+                yield page_name, df
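Taken together, these hunks make FileReader page-aware for multi-sheet formats (currently only xlsx): get_pages() lists sheet names, get_contents() returns every page as a dict of DataFrames, and get_page_content() returns a single page. A minimal usage sketch follows; the file_obj/name constructor arguments are assumptions, since the full __init__ signature is not part of this diff.

    from io import BytesIO
    from mindsdb.integrations.utilities.files.file_reader import FileReader

    # Hypothetical usage; the keyword arguments to FileReader are assumed, not confirmed by this diff.
    with open('report.xlsx', 'rb') as f:
        reader = FileReader(file_obj=BytesIO(f.read()), name='report.xlsx')

    print(reader.get_pages())                  # sheet names for xlsx, ['main'] for single-page formats
    contents = reader.get_contents()           # {page name: DataFrame}
    sheet_df = reader.get_page_content(page_name=reader.get_pages()[0])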
mindsdb/integrations/utilities/rag/chains/local_context_summarizer_chain.py

@@ -0,0 +1,227 @@
+import asyncio
+from collections import namedtuple
+from typing import Any, Dict, List, Optional
+
+from mindsdb.interfaces.agents.langchain_agent import create_chat_model
+from langchain.chains.base import Chain
+from langchain.chains.combine_documents.stuff import StuffDocumentsChain
+from langchain.chains.llm import LLMChain
+from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain, ReduceDocumentsChain
+from langchain_core.callbacks import dispatch_custom_event
+from langchain_core.callbacks.manager import CallbackManagerForChainRun
+from langchain_core.documents import Document
+from langchain_core.prompts import PromptTemplate
+from pandas import DataFrame
+
+from mindsdb.integrations.libs.vectordatabase_handler import VectorStoreHandler
+from mindsdb.integrations.utilities.rag.settings import SummarizationConfig
+from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator
+from mindsdb.utilities import log
+
+logger = log.getLogger(__name__)
+
+Summary = namedtuple('Summary', ['source_id', 'content'])
+
+
+def create_map_reduce_documents_chain(summarization_config: SummarizationConfig, input: str) -> ReduceDocumentsChain:
+    """Creates a chain that map-reduces documents into a single consolidated summary."""
+    summarization_llm = create_chat_model({
+        'model_name': summarization_config.llm_config.model_name,
+        'provider': summarization_config.llm_config.provider,
+        **summarization_config.llm_config.params
+    })
+
+    reduce_prompt_template = summarization_config.reduce_prompt_template
+    reduce_prompt = PromptTemplate.from_template(reduce_prompt_template)
+    if 'input' in reduce_prompt.input_variables:
+        reduce_prompt = reduce_prompt.partial(input=input)
+
+    reduce_chain = LLMChain(llm=summarization_llm, prompt=reduce_prompt)
+
+    combine_documents_chain = StuffDocumentsChain(
+        llm_chain=reduce_chain,
+        document_variable_name='docs'
+    )
+
+    return ReduceDocumentsChain(
+        combine_documents_chain=combine_documents_chain,
+        collapse_documents_chain=combine_documents_chain,
+        token_max=summarization_config.max_summarization_tokens
+    )
+
+
+class LocalContextSummarizerChain(Chain):
+    """Summarizes M chunks before and after a given chunk in a document."""
+
+    doc_id_key: str = 'original_row_id'
+    chunk_index_key: str = 'chunk_index'
+
+    vector_store_handler: VectorStoreHandler
+    table_name: str = 'embeddings'
+    content_column_name: str = 'content'
+    metadata_column_name: str = 'metadata'
+
+    summarization_config: SummarizationConfig
+    map_reduce_documents_chain: Optional[ReduceDocumentsChain] = None
+
+    def _select_chunks_from_vector_store(self, doc_id: str) -> DataFrame:
+        condition = FilterCondition(
+            f"{self.metadata_column_name}->>'{self.doc_id_key}'",
+            FilterOperator.EQUAL,
+            doc_id
+        )
+        return self.vector_store_handler.select(
+            self.table_name,
+            columns=[self.content_column_name, self.metadata_column_name],
+            conditions=[condition]
+        )
+
+    async def _get_all_chunks_for_document(self, doc_id: str) -> List[Document]:
+        df = await asyncio.get_event_loop().run_in_executor(
+            None, self._select_chunks_from_vector_store, doc_id
+        )
+        chunks = []
+        for _, row in df.iterrows():
+            metadata = row.get(self.metadata_column_name, {})
+            metadata[self.chunk_index_key] = row.get('chunk_id', 0)
+            chunks.append(Document(page_content=row[self.content_column_name], metadata=metadata))
+
+        return sorted(chunks, key=lambda x: x.metadata.get(self.chunk_index_key, 0))
+
+    async def summarize_local_context(self, doc_id: str, target_chunk_index: int, M: int) -> Summary:
+        """
+        Summarizes M chunks before and after the given chunk.
+
+        Args:
+            doc_id (str): Document ID.
+            target_chunk_index (int): Index of the chunk to summarize around.
+            M (int): Number of chunks before and after to include.
+
+        Returns:
+            Summary: Summary object containing source_id and summary content.
+        """
+        logger.debug(f"Fetching chunks for document {doc_id}")
+        all_chunks = await self._get_all_chunks_for_document(doc_id)
+
+        if not all_chunks:
+            logger.warning(f"No chunks found for document {doc_id}")
+            return Summary(source_id=doc_id, content='')
+
+        # Determine window boundaries
+        start_idx = max(0, target_chunk_index - M)
+        end_idx = min(len(all_chunks), target_chunk_index + M + 1)
+        local_chunks = all_chunks[start_idx:end_idx]
+
+        logger.debug(f"Summarizing chunks {start_idx} to {end_idx - 1} for document {doc_id}")
+
+        if not self.map_reduce_documents_chain:
+            self.map_reduce_documents_chain = create_map_reduce_documents_chain(
+                self.summarization_config, input="Summarize these chunks."
+            )
+
+        summary_result = await self.map_reduce_documents_chain.ainvoke(local_chunks)
+        summary_text = summary_result.get('output_text', '')
+
+        logger.debug(f"Generated summary: {summary_text[:100]}...")
+
+        return Summary(source_id=doc_id, content=summary_text)
+
+    @property
+    def input_keys(self) -> List[str]:
+        return [self.context_key, self.question_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        return [self.context_key, self.question_key]
+
+    async def _get_source_summary(self, source_id: str, map_reduce_documents_chain: MapReduceDocumentsChain) -> Summary:
+        if not source_id:
+            logger.warning("Received empty source_id, returning empty summary")
+            return Summary(source_id='', content='')
+
+        logger.debug(f"Getting summary for source ID: {source_id}")
+        source_chunks = await self._get_all_chunks_for_document(source_id)
+
+        if not source_chunks:
+            logger.warning(f"No chunks found for source ID: {source_id}")
+            return Summary(source_id=source_id, content='')
+
+        logger.debug(f"Summarizing {len(source_chunks)} chunks for source ID: {source_id}")
+        summary = await map_reduce_documents_chain.ainvoke(source_chunks)
+        content = summary.get('output_text', '')
+        logger.debug(f"Generated summary for source ID {source_id}: {content[:100]}...")
+
+        # Stream summarization update.
+        dispatch_custom_event('summary', {'source_id': source_id, 'content': content})
+
+        return Summary(source_id=source_id, content=content)
+
+    async def _get_source_summaries(self, source_ids: List[str], map_reduce_documents_chain: MapReduceDocumentsChain) -> \
+            List[Summary]:
+        summaries = await asyncio.gather(
+            *[self._get_source_summary(source_id, map_reduce_documents_chain) for source_id in source_ids]
+        )
+        return summaries
+
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None
+    ) -> Dict[str, Any]:
+        # Step 1: Connect to vector store to ensure embeddings are accessible
+        self.vector_store_handler.connect()
+
+        context_chunks: List[Document] = inputs.get(self.context_key, [])
+        logger.debug(f"Found {len(context_chunks)} context chunks.")
+
+        # Step 2: Extract unique document IDs from the provided chunks
+        unique_document_ids = self._get_document_ids_from_chunks(context_chunks)
+        logger.debug(f"Extracted {len(unique_document_ids)} unique document IDs: {unique_document_ids}")
+
+        # Step 3: Initialize the summarization chain if not provided
+        question = inputs.get(self.question_key, '')
+        map_reduce_documents_chain = self.map_reduce_documents_chain or create_map_reduce_documents_chain(
+            self.summarization_config, question
+        )
+
+        # Step 4: Dispatch event to signal summarization start
+        if run_manager:
+            run_manager.on_text("Starting summarization for documents.", verbose=True)
+
+        # Step 5: Process each document ID to summarize chunks with local context
+        for doc_id in unique_document_ids:
+            logger.debug(f"Fetching and summarizing chunks for document ID: {doc_id}")
+
+            # Fetch all chunks for the document
+            chunks = asyncio.get_event_loop().run_until_complete(self._get_all_chunks_for_document(doc_id))
+            if not chunks:
+                logger.warning(f"No chunks found for document ID: {doc_id}")
+                continue
+
+            # Summarize each chunk with M neighboring chunks
+            M = self.neighbor_window
+            for i, chunk in enumerate(chunks):
+                window_chunks = chunks[max(0, i - M): min(len(chunks), i + M + 1)]
+                local_summary = asyncio.get_event_loop().run_until_complete(
+                    map_reduce_documents_chain.ainvoke(window_chunks)
+                )
+                chunk.metadata['summary'] = local_summary.get('output_text', '')
+                logger.debug(f"Chunk {i} summary: {chunk.metadata['summary'][:100]}...")
+
+        # Step 6: Update the original context chunks with the newly generated summaries
+        for chunk in context_chunks:
+            doc_id = str(chunk.metadata.get(self.doc_id_key, ''))
+            matching_chunk = next((c for c in chunks if c.metadata.get(self.doc_id_key) == doc_id and c.metadata.get(
+                'chunk_index') == chunk.metadata.get('chunk_index')), None)
+            if matching_chunk:
+                chunk.metadata['summary'] = matching_chunk.metadata.get('summary', '')
+            else:
+                chunk.metadata['summary'] = ''
+                logger.warning(f"No matching chunk found for doc_id: {doc_id}")
+
+        # Step 7: Signal summarization end
+        if run_manager:
+            run_manager.on_text("Summarization completed.", verbose=True)
+
+        logger.debug(f"Updated {len(context_chunks)} context chunks with summaries.")
+        return inputs
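The windowing logic in summarize_local_context and _call is the core of the new chain: for chunk i it gathers M neighbours on each side, clamped to the document bounds, and map-reduces that slice into a summary stored on the chunk's metadata. A standalone sketch of just the window selection (plain Python, not MindsDB code):

    def local_window(chunks, i, M):
        # M neighbours on each side of chunk i, clamped to the list bounds
        return chunks[max(0, i - M): min(len(chunks), i + M + 1)]

    chunks = list('abcdefg')
    assert local_window(chunks, 0, 2) == ['a', 'b', 'c']             # clamped at the start
    assert local_window(chunks, 3, 2) == ['b', 'c', 'd', 'e', 'f']
    assert local_window(chunks, 6, 2) == ['e', 'f', 'g']             # clamped at the end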
mindsdb/interfaces/agents/agents_controller.py

@@ -256,9 +256,9 @@ class AgentsController:
         if (
             is_demo and (
                 (name is not None and name != agent_name)
-                or (model_name
-                or (
-                or (isinstance(params, dict) and len(params) >
+                or (model_name is not None and existing_agent.model_name != model_name)
+                or (provider is not None and existing_agent.provider != provider)
+                or (isinstance(params, dict) and len(params) > 0 and 'prompt_template' not in params)
             )
         ):
             raise ValueError("It is forbidden to change properties of the demo object")
mindsdb/interfaces/agents/callback_handlers.py

@@ -1,9 +1,13 @@
-
+import io
 import logging
+import contextlib
+from typing import Any, Dict, List, Union, Callable
+
 from langchain_core.agents import AgentAction, AgentFinish
 from langchain_core.callbacks.base import BaseCallbackHandler
 from langchain_core.messages.base import BaseMessage
 from langchain_core.outputs import LLMResult
+from langchain_core.callbacks import StdOutCallbackHandler
 
 
 class ContextCaptureCallback(BaseCallbackHandler):
@@ -20,14 +24,49 @@ class ContextCaptureCallback(BaseCallbackHandler):
         return self.context
 
 
+class VerboseLogCallbackHandler(StdOutCallbackHandler):
+    def __init__(self, logger: logging.Logger, verbose: bool):
+        self.logger = logger
+        self.verbose = verbose
+        super().__init__()
+
+    def __call(self, method: Callable, *args: List[Any], **kwargs: Any) -> Any:
+        if self.verbose is False:
+            return
+        f = io.StringIO()
+        with contextlib.redirect_stdout(f):
+            method(*args, **kwargs)
+        output = f.getvalue()
+        self.logger.info(output)
+
+    def on_chain_start(self, *args: List[Any], **kwargs: Any) -> None:
+        self.__call(super().on_chain_start, *args, **kwargs)
+
+    def on_chain_end(self, *args: List[Any], **kwargs: Any) -> None:
+        self.__call(super().on_chain_end, *args, **kwargs)
+
+    def on_agent_action(self, *args: List[Any], **kwargs: Any) -> None:
+        self.__call(super().on_agent_action, *args, **kwargs)
+
+    def on_tool_end(self, *args: List[Any], **kwargs: Any) -> None:
+        self.__call(super().on_tool_end, *args, **kwargs)
+
+    def on_text(self, *args: List[Any], **kwargs: Any) -> None:
+        self.__call(super().on_text, *args, **kwargs)
+
+    def on_agent_finish(self, *args: List[Any], **kwargs: Any) -> None:
+        self.__call(super().on_agent_finish, *args, **kwargs)
+
+
 class LogCallbackHandler(BaseCallbackHandler):
     '''Langchain callback handler that logs agent and chain executions.'''
 
-    def __init__(self, logger: logging.Logger):
+    def __init__(self, logger: logging.Logger, verbose: bool = True):
         logger.setLevel('DEBUG')
         self.logger = logger
         self._num_running_chains = 0
         self.generated_sql = None
+        self.verbose_log_handler = VerboseLogCallbackHandler(logger, verbose)
 
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
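VerboseLogCallbackHandler reuses LangChain's StdOutCallbackHandler output but redirects it into the MindsDB logger instead of stdout, and LogCallbackHandler now delegates to it when verbose is enabled. The capture pattern, reduced to a standalone sketch:

    import io
    import logging
    import contextlib

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger('agent')

    def capture_to_logger(method, *args, **kwargs):
        buf = io.StringIO()
        with contextlib.redirect_stdout(buf):
            method(*args, **kwargs)      # whatever the callback prints is captured here
        logger.info(buf.getvalue())

    capture_to_logger(print, 'entering chain')   # ends up in the log, not on stdout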
@@ -36,6 +75,7 @@ class LogCallbackHandler(BaseCallbackHandler):
         self.logger.debug('LLM started with prompts:')
         for prompt in prompts:
             self.logger.debug(prompt[:50])
+        self.verbose_log_handler.on_llm_start(serialized, prompts, **kwargs)
 
     def on_chat_model_start(
         self,
@@ -46,7 +86,7 @@ class LogCallbackHandler(BaseCallbackHandler):
         self.logger.debug('Chat model started with messages:')
         for message_list in messages:
             for message in message_list:
-                self.logger.debug(message.
+                self.logger.debug(message.pretty_repr())
 
     def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
         '''Run on new LLM token. Only available when streaming is enabled.'''
@@ -72,6 +112,8 @@ class LogCallbackHandler(BaseCallbackHandler):
             self._num_running_chains))
         self.logger.debug('Inputs: {}'.format(inputs))
 
+        self.verbose_log_handler.on_chain_start(serialized=serialized, inputs=inputs, **kwargs)
+
     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
         '''Run when chain ends running.'''
         self._num_running_chains -= 1
@@ -79,6 +121,8 @@ class LogCallbackHandler(BaseCallbackHandler):
             self._num_running_chains))
         self.logger.debug('Outputs: {}'.format(outputs))
 
+        self.verbose_log_handler.on_chain_end(outputs=outputs, **kwargs)
+
     def on_chain_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> Any:
@@ -96,7 +140,7 @@ class LogCallbackHandler(BaseCallbackHandler):
 
     def on_tool_end(self, output: str, **kwargs: Any) -> Any:
         '''Run when tool ends running.'''
-
+        self.verbose_log_handler.on_tool_end(output=output, **kwargs)
 
     def on_tool_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@@ -106,7 +150,7 @@ class LogCallbackHandler(BaseCallbackHandler):
 
     def on_text(self, text: str, **kwargs: Any) -> Any:
         '''Run on arbitrary text.'''
-
+        self.verbose_log_handler.on_text(text=text, **kwargs)
 
     def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
         '''Run on agent action.'''
@@ -124,7 +168,10 @@ class LogCallbackHandler(BaseCallbackHandler):
         # fix for mistral
         action.tool = action.tool.replace('\\', '')
 
+        self.verbose_log_handler.on_agent_action(action=action, **kwargs)
+
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         '''Run on agent end.'''
         self.logger.debug('Agent finished with return values:')
         self.logger.debug(str(finish.return_values))
+        self.verbose_log_handler.on_agent_finish(finish=finish, **kwargs)
mindsdb/interfaces/agents/langchain_agent.py

@@ -400,7 +400,7 @@ class LangchainAgent:
                 "max_iterations", args.get("max_iterations", DEFAULT_MAX_ITERATIONS)
             ),
             memory=memory,
-            verbose=args.get("verbose", args.get("verbose",
+            verbose=args.get("verbose", args.get("verbose", False))
         )
         return agent_executor
 
@@ -435,7 +435,7 @@ class LangchainAgent:
         all_callbacks = []
 
         if self.log_callback_handler is None:
-            self.log_callback_handler = LogCallbackHandler(logger)
+            self.log_callback_handler = LogCallbackHandler(logger, verbose=args.get("verbose", True))
 
         all_callbacks.append(self.log_callback_handler)
 
@@ -599,7 +599,9 @@ AI: {response}"""
            agent_executor_finished_event.set()
 
        # Enqueue Langchain agent streaming chunks in a separate thread to not block event chunks.
-       executor_stream_thread = threading.Thread(
+       executor_stream_thread = threading.Thread(
+           target=stream_worker, daemon=True, args=(ctx.dump(),), name='LangchainAgent.stream_worker'
+       )
        executor_stream_thread.start()
 
        while not agent_executor_finished_event.is_set():
mindsdb/interfaces/database/database.py

@@ -106,7 +106,7 @@ class DatabaseController:
         }
 
     def exists(self, db_name: str) -> bool:
-        return db_name in self.get_dict()
+        return db_name.lower() in self.get_dict()
 
     def get_project(self, name: str):
         return self.project_controller.get(name=name)
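The exists fix lower-cases the requested name before the lookup, which gives case-insensitive behaviour only if get_dict() already keys databases by lower-cased names (implied by the change, not shown in this diff). The pattern in isolation:

    databases = {'mindsdb': ..., 'files': ...}   # assumed lower-cased keys

    def exists(db_name: str) -> bool:
        return db_name.lower() in databases

    assert exists('MindsDB') and exists('files')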
mindsdb/interfaces/database/integrations.py

@@ -64,7 +64,7 @@ class HandlersCache:
         ):
             return
         self._stop_event.clear()
-        self.cleaner_thread = threading.Thread(target=self._clean)
+        self.cleaner_thread = threading.Thread(target=self._clean, name='HandlersCache.clean')
         self.cleaner_thread.daemon = True
         self.cleaner_thread.start()
 
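Both this hunk and the stream_worker change in langchain_agent.py give worker threads explicit names, which makes them identifiable in thread dumps and debuggers. A small illustration of the effect (plain Python, not MindsDB code):

    import threading
    import time

    t = threading.Thread(target=time.sleep, args=(1,), name='HandlersCache.clean', daemon=True)
    t.start()
    print([th.name for th in threading.enumerate()])   # includes 'HandlersCache.clean' while the worker runs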