MindsDB 25.2.3.0__py3-none-any.whl → 25.2.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of MindsDB might be problematic.
- {MindsDB-25.2.3.0.dist-info → MindsDB-25.2.4.0.dist-info}/METADATA +224 -243
- {MindsDB-25.2.3.0.dist-info → MindsDB-25.2.4.0.dist-info}/RECORD +44 -43
- mindsdb/__about__.py +1 -1
- mindsdb/__main__.py +1 -11
- mindsdb/api/executor/datahub/datanodes/system_tables.py +4 -1
- mindsdb/api/http/initialize.py +8 -5
- mindsdb/api/http/namespaces/agents.py +0 -7
- mindsdb/api/http/namespaces/config.py +0 -48
- mindsdb/api/http/namespaces/knowledge_bases.py +1 -1
- mindsdb/api/http/namespaces/util.py +0 -28
- mindsdb/integrations/handlers/anyscale_endpoints_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/dspy_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/langchain_embedding_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/langchain_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/llama_index_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/openai_handler/constants.py +3 -1
- mindsdb/integrations/handlers/openai_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/rag_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/ray_serve_handler/ray_serve_handler.py +33 -8
- mindsdb/integrations/handlers/web_handler/urlcrawl_helpers.py +3 -2
- mindsdb/integrations/handlers/web_handler/web_handler.py +42 -33
- mindsdb/integrations/handlers/youtube_handler/__init__.py +2 -0
- mindsdb/integrations/handlers/youtube_handler/connection_args.py +32 -0
- mindsdb/integrations/libs/llm/utils.py +5 -0
- mindsdb/integrations/libs/process_cache.py +2 -2
- mindsdb/integrations/utilities/rag/chains/local_context_summarizer_chain.py +227 -0
- mindsdb/interfaces/agents/agents_controller.py +3 -3
- mindsdb/interfaces/agents/callback_handlers.py +52 -5
- mindsdb/interfaces/agents/langchain_agent.py +5 -3
- mindsdb/interfaces/database/database.py +1 -1
- mindsdb/interfaces/database/integrations.py +1 -1
- mindsdb/interfaces/jobs/scheduler.py +1 -1
- mindsdb/interfaces/knowledge_base/preprocessing/constants.py +2 -2
- mindsdb/interfaces/skills/skills_controller.py +2 -2
- mindsdb/interfaces/skills/sql_agent.py +6 -1
- mindsdb/interfaces/storage/db.py +0 -12
- mindsdb/migrations/versions/2025-02-10_6ab9903fc59a_del_log_table.py +33 -0
- mindsdb/utilities/config.py +1 -0
- mindsdb/utilities/log.py +17 -2
- mindsdb/utilities/ml_task_queue/consumer.py +4 -2
- mindsdb/utilities/render/sqlalchemy_render.py +4 -0
- mindsdb/utilities/log_controller.py +0 -39
- mindsdb/utilities/telemetry.py +0 -44
- {MindsDB-25.2.3.0.dist-info → MindsDB-25.2.4.0.dist-info}/LICENSE +0 -0
- {MindsDB-25.2.3.0.dist-info → MindsDB-25.2.4.0.dist-info}/WHEEL +0 -0
- {MindsDB-25.2.3.0.dist-info → MindsDB-25.2.4.0.dist-info}/top_level.txt +0 -0
mindsdb/integrations/handlers/ray_serve_handler/ray_serve_handler.py

@@ -1,9 +1,11 @@
+import io
 import json
 
 import requests
 from typing import Dict, Optional
 
 import pandas as pd
+import pyarrow.parquet as pq
 
 from mindsdb.integrations.libs.base import BaseMLEngine
 
@@ -37,9 +39,17 @@ class RayServeHandler(BaseMLEngine):
         args['target'] = target
         self.model_storage.json_set('args', args)
         try:
-            resp = requests.post(args['train_url'],
-                                 json={'df': df.to_json(orient='records'), 'target': target, 'args': args},
-                                 headers={'content-type': 'application/json; format=pandas-records'})
+            if args.get('is_parquet', False):
+                buffer = io.BytesIO()
+                df.to_parquet(buffer)
+                resp = requests.post(args['train_url'],
+                                     files={"df": ("df", buffer.getvalue(), "application/octet-stream")},
+                                     data={"args": json.dumps(args), "target": target},
+                                     )
+            else:
+                resp = requests.post(args['train_url'],
+                                     json={'df': df.to_json(orient='records'), 'target': target, 'args': args},
+                                     headers={'content-type': 'application/json; format=pandas-records'})
         except requests.exceptions.InvalidSchema:
             raise Exception("Error: The URL provided for the training endpoint is invalid.")
 
@@ -59,14 +69,29 @@ class RayServeHandler(BaseMLEngine):
         args = {**(self.model_storage.json_get('args')), **args}  # merge incoming args
         pred_args = args.get('predict_params', {})
         args = {**args, **pred_args}  # merge pred_args
-        resp = requests.post(args['predict_url'],
-                             json={'df': df.to_json(orient='records'), 'pred_args': pred_args},
-                             headers={'content-type': 'application/json; format=pandas-records'})
-
+        if args.get('is_parquet', False):
+            buffer = io.BytesIO()
+            df.attrs['pred_args'] = pred_args
+            df.to_parquet(buffer)
+            resp = requests.post(args['predict_url'],
+                                 files={"df": ("df", buffer.getvalue(), "application/octet-stream")},
+                                 data={"pred_args": json.dumps(pred_args)},
+                                 )
+        else:
+            resp = requests.post(args['predict_url'],
+                                 json={'df': df.to_json(orient='records'), 'pred_args': pred_args},
+                                 headers={'content-type': 'application/json; format=pandas-records'})
         try:
-            response = resp.json()
+            if args.get('is_parquet', False):
+                buffer = io.BytesIO(resp.content)
+                table = pq.read_table(buffer)
+                response = table.to_pandas()
+            else:
+                response = resp.json()
         except json.JSONDecodeError:
             error = resp.text
+        except Exception:
+            error = 'Could not decode parquet.'
         else:
             if 'prediction' in response:
                 target = args['target']
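Taken together, these ray_serve_handler.py hunks add an optional Parquet transport: when a model is created with is_parquet=True in its args, the training DataFrame is uploaded as Parquet bytes in a multipart field named "df" (with "args" and "target" as form fields), predictions post "df" plus "pred_args" the same way, and the prediction response is decoded as Parquet instead of JSON. For illustration only, a minimal sketch of a training endpoint that could accept this payload; Flask and the /train route are assumptions here, the real service only needs to be reachable at the configured train_url:

import io
import json

import pandas as pd
from flask import Flask, request

app = Flask(__name__)

@app.route('/train', methods=['POST'])
def train():
    # The handler uploads the DataFrame as Parquet bytes in a multipart field
    # named "df", with "args" (JSON) and "target" as ordinary form fields.
    df = pd.read_parquet(io.BytesIO(request.files['df'].read()))
    args = json.loads(request.form['args'])
    target = request.form['target']
    # ... fit a model on df using df[target] as the label ...
    return {'status': 'ok'}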
mindsdb/integrations/handlers/web_handler/urlcrawl_helpers.py

@@ -220,8 +220,6 @@ def get_all_website_links_recursively(url, reviewed_urls, limit=None, crawl_depth
     if limit is not None:
         if len(reviewed_urls) >= limit:
             return reviewed_urls
-    if crawl_depth == current_depth:
-        return reviewed_urls
 
     if not filters:
         matches_filter = True
@@ -241,6 +239,9 @@ def get_all_website_links_recursively(url, reviewed_urls, limit=None, crawl_depth
             "error": str(error_message),
         }
 
+    if crawl_depth is not None and crawl_depth == current_depth:
+        return reviewed_urls
+
     to_rev_url_list = []
 
     # create a list of new urls to review that don't exist in the already reviewed ones
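Net effect of the two urlcrawl_helpers.py hunks: the crawl-depth cutoff now appears to run only after the current page has been fetched and its result (or error) recorded, and only when a crawl_depth value was actually supplied, so pages at the final depth level are still included before recursion stops.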
mindsdb/integrations/handlers/web_handler/web_handler.py

@@ -1,62 +1,71 @@
+from typing import List
+
 import pandas as pd
 from mindsdb.integrations.libs.response import HandlerStatusResponse
-from
-from mindsdb.integrations.libs.api_handler import APIHandler, APITable
-from mindsdb.utilities.config import Config
-from mindsdb.integrations.utilities.sql_utils import extract_comparison_conditions, project_dataframe
+from mindsdb.utilities.config import config
 from mindsdb.utilities.security import validate_urls
 from .urlcrawl_helpers import get_all_websites
 
+from mindsdb.integrations.libs.api_handler import APIResource, APIHandler
+from mindsdb.integrations.utilities.sql_utils import (FilterCondition, FilterOperator)
 
-class CrawlerTable(APITable):
 
-
-        super().__init__(handler)
-        self.config = Config()
+class CrawlerTable(APIResource):
 
-    def
+    def list(
+        self,
+        conditions: List[FilterCondition] = None,
+        limit: int = None,
+        **kwargs
+    ) -> pd.DataFrame:
         """
         Selects data from the provided websites
 
-        Args:
-            query (ast.Select): Given SQL SELECT query
-
         Returns:
             dataframe: Dataframe containing the crawled data
 
         Raises:
             NotImplementedError: If the query is not supported
         """
-        conditions = extract_comparison_conditions(query.where)
         urls = []
-
-
-
-        if
-        if
-        urls =
-
-
-
+        crawl_depth = None
+        per_url_limit = None
+        for condition in conditions:
+            if condition.column == 'url':
+                if condition.op == FilterOperator.IN:
+                    urls = condition.value
+                elif condition.op == FilterOperator.EQUAL:
+                    urls = [condition.value]
+                condition.applied = True
+            if condition.column == 'crawl_depth' and condition.op == FilterOperator.EQUAL:
+                crawl_depth = condition.value
+                condition.applied = True
+            if condition.column == 'per_url_limit' and condition.op == FilterOperator.EQUAL:
+                per_url_limit = condition.value
+                condition.applied = True
 
         if len(urls) == 0:
             raise NotImplementedError(
-                'You must specify what url you want to crawl, for example: SELECT * FROM
+                'You must specify what url you want to crawl, for example: SELECT * FROM web.crawler WHERE url = "someurl"')
 
-        allowed_urls =
+        allowed_urls = config.get('web_crawling_allowed_sites', [])
         if allowed_urls and not validate_urls(urls, allowed_urls):
             raise ValueError(f"The provided URL is not allowed for web crawling. Please use any of {', '.join(allowed_urls)}.")
 
-        if
-
-
-
-
-
-
+        if limit is None and per_url_limit is None and crawl_depth is None:
+            per_url_limit = 1
+        if per_url_limit is not None:
+            # crawl every url separately
+            results = []
+            for url in urls:
+                results.append(get_all_websites([url], per_url_limit, crawl_depth=crawl_depth))
+            result = pd.concat(results)
+        else:
+            result = get_all_websites(urls, limit, crawl_depth=crawl_depth)
+
+        if limit is not None and len(result) > limit:
             result = result[:limit]
-
-        result = project_dataframe(result, query.targets, self.get_columns())
+
         return result
 
     def get_columns(self):
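CrawlerTable now implements the APIResource list() interface instead of APITable.select(), so url, crawl_depth and per_url_limit filters arrive as FilterCondition objects rather than a parsed SELECT AST. For illustration only (the positional FilterCondition signature matches its use elsewhere in this release; instantiating the table directly is an assumption), a query like SELECT * FROM web.crawler WHERE url = "someurl" AND crawl_depth = 2 would reach list() roughly as:

from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator

conditions = [
    FilterCondition('url', FilterOperator.EQUAL, 'someurl'),
    FilterCondition('crawl_depth', FilterOperator.EQUAL, 2),
]
# crawler_table = CrawlerTable(web_handler)  # handler wiring assumed, not shown in this diff
# df = crawler_table.list(conditions=conditions, limit=20)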
mindsdb/integrations/handlers/youtube_handler/__init__.py

@@ -5,6 +5,7 @@ from .__about__ import __version__ as version, __description__ as description
 
 try:
     from .youtube_handler import YoutubeHandler as Handler
+    from .connection_args import connection_args
     import_error = None
 except Exception as e:
     Handler = None

@@ -24,4 +25,5 @@ __all__ = [
     "description",
     "import_error",
     "icon_path",
+    "connection_args",
 ]
mindsdb/integrations/handlers/youtube_handler/connection_args.py

@@ -0,0 +1,32 @@
+from collections import OrderedDict
+
+from mindsdb.integrations.libs.const import HANDLER_CONNECTION_ARG_TYPE as ARG_TYPE
+
+
+connection_args = OrderedDict(
+    youtube_api_token={
+        'type': ARG_TYPE.STR,
+        'description': 'Youtube API Token',
+        'label': 'Youtube API Token',
+    },
+    credentials_url={
+        'type': ARG_TYPE.STR,
+        'description': 'URL to Service Account Keys',
+        'label': 'URL to Service Account Keys',
+    },
+    credentials_file={
+        'type': ARG_TYPE.STR,
+        'description': 'Location of Service Account Keys',
+        'label': 'Path to Service Account Keys',
+    },
+    credentials={
+        'type': ARG_TYPE.PATH,
+        'description': 'Service Account Keys',
+        'label': 'Upload Service Account Keys',
+    },
+    code={
+        'type': ARG_TYPE.STR,
+        'description': 'Code After Authorisation',
+        'label': 'Code After Authorisation',
+    },
+)
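The new connection_args module describes the fields the YouTube handler accepts when a connection is created (an API token, service-account credentials supplied by URL, file path or upload, and an OAuth authorisation code); exporting it from the package __init__ in the previous hunk presumably makes these fields discoverable by the server, for example for rendering a connection form.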
mindsdb/integrations/libs/llm/utils.py

@@ -115,6 +115,11 @@ def get_llm_config(provider: str, args: Dict) -> BaseLLMConfig:
     """
     temperature = min(1.0, max(0.0, args.get("temperature", 0.0)))
     if provider == "openai":
+
+        if any(x in args.get("model_name", "") for x in ['o1', 'o3']):
+            # for o1 and 03, 'temperature' does not support 0.0 with this model. Only the default (1) value is supported
+            temperature = 1
+
         return OpenAIConfig(
             model_name=args.get("model_name", DEFAULT_OPENAI_MODEL),
             temperature=temperature,
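Worth noting: the added guard is a plain substring match on the model name, so any OpenAI model whose name contains 'o1' or 'o3' has its temperature forced to 1, while other names keep the clamped user value. A quick check:

for name in ("gpt-4o", "o1-mini", "o3-mini"):
    print(name, any(x in name for x in ['o1', 'o3']))
# gpt-4o False, o1-mini True, o3-mini True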
mindsdb/integrations/libs/process_cache.py

@@ -186,7 +186,6 @@ class ProcessCache:
         self._keep_alive = {}
         self._stop_event = threading.Event()
         self.cleaner_thread = None
-        self._start_clean()
 
     def __del__(self):
         self._stop_clean()

@@ -200,7 +199,7 @@ class ProcessCache:
         ):
             return
         self._stop_event.clear()
-        self.cleaner_thread = threading.Thread(target=self._clean)
+        self.cleaner_thread = threading.Thread(target=self._clean, name='ProcessCache.clean')
         self.cleaner_thread.daemon = True
         self.cleaner_thread.start()
 

@@ -258,6 +257,7 @@ class ProcessCache:
         Returns:
             Future
         """
+        self._start_clean()
         handler_module_path = payload['handler_meta']['module_path']
         integration_id = payload['handler_meta']['integration_id']
         if task_type in (ML_TASK_TYPE.LEARN, ML_TASK_TYPE.FINETUNE):
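Net effect for process_cache.py: the cleaner thread is no longer spawned in __init__; it is started lazily from the task-scheduling method (the one whose docstring ends with "Returns: Future"), and the thread now carries the name 'ProcessCache.clean', which makes it easier to identify in thread dumps.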
mindsdb/integrations/utilities/rag/chains/local_context_summarizer_chain.py

@@ -0,0 +1,227 @@
+import asyncio
+from collections import namedtuple
+from typing import Any, Dict, List, Optional
+
+from mindsdb.interfaces.agents.langchain_agent import create_chat_model
+from langchain.chains.base import Chain
+from langchain.chains.combine_documents.stuff import StuffDocumentsChain
+from langchain.chains.llm import LLMChain
+from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain, ReduceDocumentsChain
+from langchain_core.callbacks import dispatch_custom_event
+from langchain_core.callbacks.manager import CallbackManagerForChainRun
+from langchain_core.documents import Document
+from langchain_core.prompts import PromptTemplate
+from pandas import DataFrame
+
+from mindsdb.integrations.libs.vectordatabase_handler import VectorStoreHandler
+from mindsdb.integrations.utilities.rag.settings import SummarizationConfig
+from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator
+from mindsdb.utilities import log
+
+logger = log.getLogger(__name__)
+
+Summary = namedtuple('Summary', ['source_id', 'content'])
+
+
+def create_map_reduce_documents_chain(summarization_config: SummarizationConfig, input: str) -> ReduceDocumentsChain:
+    """Creates a chain that map-reduces documents into a single consolidated summary."""
+    summarization_llm = create_chat_model({
+        'model_name': summarization_config.llm_config.model_name,
+        'provider': summarization_config.llm_config.provider,
+        **summarization_config.llm_config.params
+    })
+
+    reduce_prompt_template = summarization_config.reduce_prompt_template
+    reduce_prompt = PromptTemplate.from_template(reduce_prompt_template)
+    if 'input' in reduce_prompt.input_variables:
+        reduce_prompt = reduce_prompt.partial(input=input)
+
+    reduce_chain = LLMChain(llm=summarization_llm, prompt=reduce_prompt)
+
+    combine_documents_chain = StuffDocumentsChain(
+        llm_chain=reduce_chain,
+        document_variable_name='docs'
+    )
+
+    return ReduceDocumentsChain(
+        combine_documents_chain=combine_documents_chain,
+        collapse_documents_chain=combine_documents_chain,
+        token_max=summarization_config.max_summarization_tokens
+    )
+
+
+class LocalContextSummarizerChain(Chain):
+    """Summarizes M chunks before and after a given chunk in a document."""
+
+    doc_id_key: str = 'original_row_id'
+    chunk_index_key: str = 'chunk_index'
+
+    vector_store_handler: VectorStoreHandler
+    table_name: str = 'embeddings'
+    content_column_name: str = 'content'
+    metadata_column_name: str = 'metadata'
+
+    summarization_config: SummarizationConfig
+    map_reduce_documents_chain: Optional[ReduceDocumentsChain] = None
+
+    def _select_chunks_from_vector_store(self, doc_id: str) -> DataFrame:
+        condition = FilterCondition(
+            f"{self.metadata_column_name}->>'{self.doc_id_key}'",
+            FilterOperator.EQUAL,
+            doc_id
+        )
+        return self.vector_store_handler.select(
+            self.table_name,
+            columns=[self.content_column_name, self.metadata_column_name],
+            conditions=[condition]
+        )
+
+    async def _get_all_chunks_for_document(self, doc_id: str) -> List[Document]:
+        df = await asyncio.get_event_loop().run_in_executor(
+            None, self._select_chunks_from_vector_store, doc_id
+        )
+        chunks = []
+        for _, row in df.iterrows():
+            metadata = row.get(self.metadata_column_name, {})
+            metadata[self.chunk_index_key] = row.get('chunk_id', 0)
+            chunks.append(Document(page_content=row[self.content_column_name], metadata=metadata))
+
+        return sorted(chunks, key=lambda x: x.metadata.get(self.chunk_index_key, 0))
+
+    async def summarize_local_context(self, doc_id: str, target_chunk_index: int, M: int) -> Summary:
+        """
+        Summarizes M chunks before and after the given chunk.
+
+        Args:
+            doc_id (str): Document ID.
+            target_chunk_index (int): Index of the chunk to summarize around.
+            M (int): Number of chunks before and after to include.
+
+        Returns:
+            Summary: Summary object containing source_id and summary content.
+        """
+        logger.debug(f"Fetching chunks for document {doc_id}")
+        all_chunks = await self._get_all_chunks_for_document(doc_id)
+
+        if not all_chunks:
+            logger.warning(f"No chunks found for document {doc_id}")
+            return Summary(source_id=doc_id, content='')
+
+        # Determine window boundaries
+        start_idx = max(0, target_chunk_index - M)
+        end_idx = min(len(all_chunks), target_chunk_index + M + 1)
+        local_chunks = all_chunks[start_idx:end_idx]
+
+        logger.debug(f"Summarizing chunks {start_idx} to {end_idx - 1} for document {doc_id}")
+
+        if not self.map_reduce_documents_chain:
+            self.map_reduce_documents_chain = create_map_reduce_documents_chain(
+                self.summarization_config, input="Summarize these chunks."
+            )
+
+        summary_result = await self.map_reduce_documents_chain.ainvoke(local_chunks)
+        summary_text = summary_result.get('output_text', '')
+
+        logger.debug(f"Generated summary: {summary_text[:100]}...")
+
+        return Summary(source_id=doc_id, content=summary_text)
+
+    @property
+    def input_keys(self) -> List[str]:
+        return [self.context_key, self.question_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        return [self.context_key, self.question_key]
+
+    async def _get_source_summary(self, source_id: str, map_reduce_documents_chain: MapReduceDocumentsChain) -> Summary:
+        if not source_id:
+            logger.warning("Received empty source_id, returning empty summary")
+            return Summary(source_id='', content='')
+
+        logger.debug(f"Getting summary for source ID: {source_id}")
+        source_chunks = await self._get_all_chunks_for_document(source_id)
+
+        if not source_chunks:
+            logger.warning(f"No chunks found for source ID: {source_id}")
+            return Summary(source_id=source_id, content='')
+
+        logger.debug(f"Summarizing {len(source_chunks)} chunks for source ID: {source_id}")
+        summary = await map_reduce_documents_chain.ainvoke(source_chunks)
+        content = summary.get('output_text', '')
+        logger.debug(f"Generated summary for source ID {source_id}: {content[:100]}...")
+
+        # Stream summarization update.
+        dispatch_custom_event('summary', {'source_id': source_id, 'content': content})
+
+        return Summary(source_id=source_id, content=content)
+
+    async def _get_source_summaries(self, source_ids: List[str], map_reduce_documents_chain: MapReduceDocumentsChain) -> \
+            List[Summary]:
+        summaries = await asyncio.gather(
+            *[self._get_source_summary(source_id, map_reduce_documents_chain) for source_id in source_ids]
+        )
+        return summaries
+
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None
+    ) -> Dict[str, Any]:
+        # Step 1: Connect to vector store to ensure embeddings are accessible
+        self.vector_store_handler.connect()
+
+        context_chunks: List[Document] = inputs.get(self.context_key, [])
+        logger.debug(f"Found {len(context_chunks)} context chunks.")
+
+        # Step 2: Extract unique document IDs from the provided chunks
+        unique_document_ids = self._get_document_ids_from_chunks(context_chunks)
+        logger.debug(f"Extracted {len(unique_document_ids)} unique document IDs: {unique_document_ids}")
+
+        # Step 3: Initialize the summarization chain if not provided
+        question = inputs.get(self.question_key, '')
+        map_reduce_documents_chain = self.map_reduce_documents_chain or create_map_reduce_documents_chain(
+            self.summarization_config, question
+        )
+
+        # Step 4: Dispatch event to signal summarization start
+        if run_manager:
+            run_manager.on_text("Starting summarization for documents.", verbose=True)
+
+        # Step 5: Process each document ID to summarize chunks with local context
+        for doc_id in unique_document_ids:
+            logger.debug(f"Fetching and summarizing chunks for document ID: {doc_id}")
+
+            # Fetch all chunks for the document
+            chunks = asyncio.get_event_loop().run_until_complete(self._get_all_chunks_for_document(doc_id))
+            if not chunks:
+                logger.warning(f"No chunks found for document ID: {doc_id}")
+                continue
+
+            # Summarize each chunk with M neighboring chunks
+            M = self.neighbor_window
+            for i, chunk in enumerate(chunks):
+                window_chunks = chunks[max(0, i - M): min(len(chunks), i + M + 1)]
+                local_summary = asyncio.get_event_loop().run_until_complete(
+                    map_reduce_documents_chain.ainvoke(window_chunks)
+                )
+                chunk.metadata['summary'] = local_summary.get('output_text', '')
+                logger.debug(f"Chunk {i} summary: {chunk.metadata['summary'][:100]}...")
+
+            # Step 6: Update the original context chunks with the newly generated summaries
+            for chunk in context_chunks:
+                doc_id = str(chunk.metadata.get(self.doc_id_key, ''))
+                matching_chunk = next((c for c in chunks if c.metadata.get(self.doc_id_key) == doc_id and c.metadata.get(
+                    'chunk_index') == chunk.metadata.get('chunk_index')), None)
+                if matching_chunk:
+                    chunk.metadata['summary'] = matching_chunk.metadata.get('summary', '')
+                else:
+                    chunk.metadata['summary'] = ''
+                    logger.warning(f"No matching chunk found for doc_id: {doc_id}")
+
+        # Step 7: Signal summarization end
+        if run_manager:
+            run_manager.on_text("Summarization completed.", verbose=True)
+
+        logger.debug(f"Updated {len(context_chunks)} context chunks with summaries.")
+        return inputs
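In short, the new LocalContextSummarizerChain re-reads every chunk of a document from the vector store (filtering on metadata->>'original_row_id'), orders the chunks by index, summarizes each chunk together with its M neighbours through a map-reduce LangChain pipeline built from the configured SummarizationConfig, writes the result into the chunk's metadata['summary'], and streams per-source summaries to listeners via dispatch_custom_event('summary', ...).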
mindsdb/interfaces/agents/agents_controller.py

@@ -256,9 +256,9 @@ class AgentsController:
         if (
             is_demo and (
                 (name is not None and name != agent_name)
-                or (model_name
-                or (
-                or (isinstance(params, dict) and len(params) >
+                or (model_name is not None and existing_agent.model_name != model_name)
+                or (provider is not None and existing_agent.provider != provider)
+                or (isinstance(params, dict) and len(params) > 0 and 'prompt_template' not in params)
             )
         ):
             raise ValueError("It is forbidden to change properties of the demo object")
mindsdb/interfaces/agents/callback_handlers.py

@@ -1,9 +1,13 @@
-
+import io
 import logging
+import contextlib
+from typing import Any, Dict, List, Union, Callable
+
 from langchain_core.agents import AgentAction, AgentFinish
 from langchain_core.callbacks.base import BaseCallbackHandler
 from langchain_core.messages.base import BaseMessage
 from langchain_core.outputs import LLMResult
+from langchain_core.callbacks import StdOutCallbackHandler
 
 
 class ContextCaptureCallback(BaseCallbackHandler):

@@ -20,14 +24,49 @@ class ContextCaptureCallback(BaseCallbackHandler):
         return self.context
 
 
+class VerboseLogCallbackHandler(StdOutCallbackHandler):
+    def __init__(self, logger: logging.Logger, verbose: bool):
+        self.logger = logger
+        self.verbose = verbose
+        super().__init__()
+
+    def __call(self, method: Callable, *args: List[Any], **kwargs: Any) -> Any:
+        if self.verbose is False:
+            return
+        f = io.StringIO()
+        with contextlib.redirect_stdout(f):
+            method(*args, **kwargs)
+        output = f.getvalue()
+        self.logger.info(output)
+
+    def on_chain_start(self, *args: List[Any], **kwargs: Any) -> None:
+        self.__call(super().on_chain_start, *args, **kwargs)
+
+    def on_chain_end(self, *args: List[Any], **kwargs: Any) -> None:
+        self.__call(super().on_chain_end, *args, **kwargs)
+
+    def on_agent_action(self, *args: List[Any], **kwargs: Any) -> None:
+        self.__call(super().on_agent_action, *args, **kwargs)
+
+    def on_tool_end(self, *args: List[Any], **kwargs: Any) -> None:
+        self.__call(super().on_tool_end, *args, **kwargs)
+
+    def on_text(self, *args: List[Any], **kwargs: Any) -> None:
+        self.__call(super().on_text, *args, **kwargs)
+
+    def on_agent_finish(self, *args: List[Any], **kwargs: Any) -> None:
+        self.__call(super().on_agent_finish, *args, **kwargs)
+
+
 class LogCallbackHandler(BaseCallbackHandler):
     '''Langchain callback handler that logs agent and chain executions.'''
 
-    def __init__(self, logger: logging.Logger):
+    def __init__(self, logger: logging.Logger, verbose: bool = True):
         logger.setLevel('DEBUG')
         self.logger = logger
         self._num_running_chains = 0
         self.generated_sql = None
+        self.verbose_log_handler = VerboseLogCallbackHandler(logger, verbose)
 
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any

@@ -36,6 +75,7 @@ class LogCallbackHandler(BaseCallbackHandler):
         self.logger.debug('LLM started with prompts:')
         for prompt in prompts:
             self.logger.debug(prompt[:50])
+        self.verbose_log_handler.on_llm_start(serialized, prompts, **kwargs)
 
     def on_chat_model_start(
         self,

@@ -46,7 +86,7 @@ class LogCallbackHandler(BaseCallbackHandler):
         self.logger.debug('Chat model started with messages:')
         for message_list in messages:
             for message in message_list:
-                self.logger.debug(message.
+                self.logger.debug(message.pretty_repr())
 
     def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
         '''Run on new LLM token. Only available when streaming is enabled.'''

@@ -72,6 +112,8 @@ class LogCallbackHandler(BaseCallbackHandler):
             self._num_running_chains))
         self.logger.debug('Inputs: {}'.format(inputs))
 
+        self.verbose_log_handler.on_chain_start(serialized=serialized, inputs=inputs, **kwargs)
+
     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
         '''Run when chain ends running.'''
         self._num_running_chains -= 1

@@ -79,6 +121,8 @@ class LogCallbackHandler(BaseCallbackHandler):
             self._num_running_chains))
         self.logger.debug('Outputs: {}'.format(outputs))
 
+        self.verbose_log_handler.on_chain_end(outputs=outputs, **kwargs)
+
     def on_chain_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> Any:

@@ -96,7 +140,7 @@ class LogCallbackHandler(BaseCallbackHandler):
 
     def on_tool_end(self, output: str, **kwargs: Any) -> Any:
         '''Run when tool ends running.'''
-
+        self.verbose_log_handler.on_tool_end(output=output, **kwargs)
 
     def on_tool_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any

@@ -106,7 +150,7 @@ class LogCallbackHandler(BaseCallbackHandler):
 
     def on_text(self, text: str, **kwargs: Any) -> Any:
         '''Run on arbitrary text.'''
-
+        self.verbose_log_handler.on_text(text=text, **kwargs)
 
     def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
         '''Run on agent action.'''

@@ -124,7 +168,10 @@ class LogCallbackHandler(BaseCallbackHandler):
             # fix for mistral
             action.tool = action.tool.replace('\\', '')
 
+        self.verbose_log_handler.on_agent_action(action=action, **kwargs)
+
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         '''Run on agent end.'''
         self.logger.debug('Agent finished with return values:')
         self.logger.debug(str(finish.return_values))
+        self.verbose_log_handler.on_agent_finish(finish=finish, **kwargs)
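The idea behind VerboseLogCallbackHandler is to reuse the text that LangChain's StdOutCallbackHandler would normally print and route it into the MindsDB logger instead, gated by the new verbose flag on LogCallbackHandler. The capture trick in isolation, as a minimal sketch (the names here are illustrative, not part of the diff):

import io
import contextlib
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('demo')

def log_stdout_of(method, *args, **kwargs):
    # Run a print-based callback and forward whatever it printed to the logger,
    # mirroring what VerboseLogCallbackHandler.__call does with the parent
    # StdOutCallbackHandler methods.
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        method(*args, **kwargs)
    logger.info(buf.getvalue())

log_stdout_of(print, 'Entering new chain...')  # emitted via the logger, not stdout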