AstrBot: astrbot-4.5.8-py3-none-any.whl → astrbot-4.6.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/core/agent/mcp_client.py +152 -26
- astrbot/core/agent/message.py +7 -0
- astrbot/core/config/default.py +31 -1
- astrbot/core/core_lifecycle.py +8 -0
- astrbot/core/db/__init__.py +50 -1
- astrbot/core/db/migration/migra_webchat_session.py +131 -0
- astrbot/core/db/po.py +49 -13
- astrbot/core/db/sqlite.py +102 -3
- astrbot/core/knowledge_base/kb_helper.py +314 -33
- astrbot/core/knowledge_base/kb_mgr.py +45 -1
- astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
- astrbot/core/knowledge_base/prompts.py +65 -0
- astrbot/core/pipeline/process_stage/method/llm_request.py +28 -14
- astrbot/core/pipeline/process_stage/utils.py +60 -16
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +13 -10
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -4
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +0 -4
- astrbot/core/provider/entities.py +22 -9
- astrbot/core/provider/func_tool_manager.py +12 -9
- astrbot/core/provider/manager.py +4 -0
- astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
- astrbot/core/provider/sources/gemini_source.py +25 -8
- astrbot/core/provider/sources/openai_source.py +9 -16
- astrbot/dashboard/routes/chat.py +134 -77
- astrbot/dashboard/routes/knowledge_base.py +172 -0
- {astrbot-4.5.8.dist-info → astrbot-4.6.1.dist-info}/METADATA +5 -4
- {astrbot-4.5.8.dist-info → astrbot-4.6.1.dist-info}/RECORD +30 -26
- {astrbot-4.5.8.dist-info → astrbot-4.6.1.dist-info}/WHEEL +0 -0
- {astrbot-4.5.8.dist-info → astrbot-4.6.1.dist-info}/entry_points.txt +0 -0
- {astrbot-4.5.8.dist-info → astrbot-4.6.1.dist-info}/licenses/LICENSE +0 -0
astrbot/core/db/po.py
CHANGED
@@ -3,13 +3,7 @@ from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from typing import TypedDict

-from sqlmodel import (
-    JSON,
-    Field,
-    SQLModel,
-    Text,
-    UniqueConstraint,
-)
+from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint


 class PlatformStat(SQLModel, table=True):
@@ -18,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
     Note: In astrbot v4, we moved `platform` table to here.
     """

-    __tablename__ = "platform_stats"
+    __tablename__ = "platform_stats"  # type: ignore

     id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
     timestamp: datetime = Field(nullable=False)
@@ -37,7 +31,7 @@ class PlatformStat(SQLModel, table=True):


 class ConversationV2(SQLModel, table=True):
-    __tablename__ = "conversations"
+    __tablename__ = "conversations"  # type: ignore

     inner_conversation_id: int = Field(
         primary_key=True,
@@ -74,7 +68,7 @@ class Persona(SQLModel, table=True):
     It can be used to customize the behavior of LLMs.
     """

-    __tablename__ = "personas"
+    __tablename__ = "personas"  # type: ignore

     id: int | None = Field(
         primary_key=True,
@@ -104,7 +98,7 @@ class Persona(SQLModel, table=True):
 class Preference(SQLModel, table=True):
     """This class represents preferences for bots."""

-    __tablename__ = "preferences"
+    __tablename__ = "preferences"  # type: ignore

     id: int | None = Field(
         default=None,
@@ -140,7 +134,7 @@ class PlatformMessageHistory(SQLModel, table=True):
     or platform-specific messages.
     """

-    __tablename__ = "platform_message_history"
+    __tablename__ = "platform_message_history"  # type: ignore

     id: int | None = Field(
         primary_key=True,
@@ -161,13 +155,55 @@ class PlatformMessageHistory(SQLModel, table=True):
     )


+class PlatformSession(SQLModel, table=True):
+    """Platform session table for managing user sessions across different platforms.
+
+    A session represents a chat window for a specific user on a specific platform.
+    Each session can have multiple conversations (对话) associated with it.
+    """
+
+    __tablename__ = "platform_sessions"  # type: ignore
+
+    inner_id: int | None = Field(
+        primary_key=True,
+        sa_column_kwargs={"autoincrement": True},
+        default=None,
+    )
+    session_id: str = Field(
+        max_length=100,
+        nullable=False,
+        unique=True,
+        default_factory=lambda: f"webchat_{uuid.uuid4()}",
+    )
+    platform_id: str = Field(default="webchat", nullable=False)
+    """Platform identifier (e.g., 'webchat', 'qq', 'discord')"""
+    creator: str = Field(nullable=False)
+    """Username of the session creator"""
+    display_name: str | None = Field(default=None, max_length=255)
+    """Display name for the session"""
+    is_group: int = Field(default=0, nullable=False)
+    """0 for private chat, 1 for group chat (not implemented yet)"""
+    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+    updated_at: datetime = Field(
+        default_factory=lambda: datetime.now(timezone.utc),
+        sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
+    )
+
+    __table_args__ = (
+        UniqueConstraint(
+            "session_id",
+            name="uix_platform_session_id",
+        ),
+    )
+
+
 class Attachment(SQLModel, table=True):
     """This class represents attachments for messages in AstrBot.

     Attachments can be images, files, or other media types.
     """

-    __tablename__ = "attachments"
+    __tablename__ = "attachments"  # type: ignore

     inner_attachment_id: int | None = Field(
         primary_key=True,
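For orientation, the PlatformSession model added above is an ordinary SQLModel table. Below is a minimal standalone sketch, not AstrBot code: DemoPlatformSession, the trimmed field set, and the in-memory engine are illustrative assumptions, but it shows the same pattern of a unique, auto-generated webchat_<uuid> session_id.

# Standalone sketch (assumes the sqlmodel package is installed).
import uuid
from datetime import datetime, timezone

from sqlmodel import Field, Session, SQLModel, create_engine, select


class DemoPlatformSession(SQLModel, table=True):
    inner_id: int | None = Field(default=None, primary_key=True)
    session_id: str = Field(
        unique=True, default_factory=lambda: f"webchat_{uuid.uuid4()}"
    )
    platform_id: str = Field(default="webchat")
    creator: str
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))


engine = create_engine("sqlite:///:memory:")
SQLModel.metadata.create_all(engine)

with Session(engine) as session:
    # The session_id default is generated client-side, so it exists before commit.
    session.add(DemoPlatformSession(creator="alice"))
    session.commit()
    row = session.exec(
        select(DemoPlatformSession).where(DemoPlatformSession.creator == "alice")
    ).one()
    print(row.session_id, row.platform_id)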
astrbot/core/db/sqlite.py
CHANGED
@@ -1,7 +1,7 @@
 import asyncio
 import threading
 import typing as T
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone

 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlmodel import col, delete, desc, func, or_, select, text, update
@@ -12,6 +12,7 @@ from astrbot.core.db.po import (
     ConversationV2,
     Persona,
     PlatformMessageHistory,
+    PlatformSession,
     PlatformStat,
     Preference,
     SQLModel,
@@ -412,7 +413,7 @@ class SQLiteDatabase(BaseDatabase):
         user_id,
         offset_sec=86400,
     ):
-        """Delete platform message history records
+        """Delete platform message history records newer than the specified offset."""
         async with self.get_db() as session:
             session: AsyncSession
             async with session.begin():
@@ -422,7 +423,7 @@
                 delete(PlatformMessageHistory).where(
                     col(PlatformMessageHistory.platform_id) == platform_id,
                     col(PlatformMessageHistory.user_id) == user_id,
-                    col(PlatformMessageHistory.created_at)
+                    col(PlatformMessageHistory.created_at) >= cutoff_time,
                 ),
             )

@@ -709,3 +710,101 @@
         t.start()
         t.join()
         return result
+
+    # ====
+    # Platform Session Management
+    # ====
+
+    async def create_platform_session(
+        self,
+        creator: str,
+        platform_id: str = "webchat",
+        session_id: str | None = None,
+        display_name: str | None = None,
+        is_group: int = 0,
+    ) -> PlatformSession:
+        """Create a new Platform session."""
+        kwargs = {}
+        if session_id:
+            kwargs["session_id"] = session_id
+
+        async with self.get_db() as session:
+            session: AsyncSession
+            async with session.begin():
+                new_session = PlatformSession(
+                    creator=creator,
+                    platform_id=platform_id,
+                    display_name=display_name,
+                    is_group=is_group,
+                    **kwargs,
+                )
+                session.add(new_session)
+                await session.flush()
+                await session.refresh(new_session)
+                return new_session
+
+    async def get_platform_session_by_id(
+        self, session_id: str
+    ) -> PlatformSession | None:
+        """Get a Platform session by its ID."""
+        async with self.get_db() as session:
+            session: AsyncSession
+            query = select(PlatformSession).where(
+                PlatformSession.session_id == session_id,
+            )
+            result = await session.execute(query)
+            return result.scalar_one_or_none()
+
+    async def get_platform_sessions_by_creator(
+        self,
+        creator: str,
+        platform_id: str | None = None,
+        page: int = 1,
+        page_size: int = 20,
+    ) -> list[PlatformSession]:
+        """Get all Platform sessions for a specific creator (username) and optionally platform."""
+        async with self.get_db() as session:
+            session: AsyncSession
+            offset = (page - 1) * page_size
+            query = select(PlatformSession).where(PlatformSession.creator == creator)
+
+            if platform_id:
+                query = query.where(PlatformSession.platform_id == platform_id)
+
+            query = (
+                query.order_by(desc(PlatformSession.updated_at))
+                .offset(offset)
+                .limit(page_size)
+            )
+            result = await session.execute(query)
+            return list(result.scalars().all())
+
+    async def update_platform_session(
+        self,
+        session_id: str,
+        display_name: str | None = None,
+    ) -> None:
+        """Update a Platform session's updated_at timestamp and optionally display_name."""
+        async with self.get_db() as session:
+            session: AsyncSession
+            async with session.begin():
+                values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
+                if display_name is not None:
+                    values["display_name"] = display_name
+
+                await session.execute(
+                    update(PlatformSession)
+                    .where(col(PlatformSession.session_id == session_id))
+                    .values(**values),
+                )
+
+    async def delete_platform_session(self, session_id: str) -> None:
+        """Delete a Platform session by its ID."""
+        async with self.get_db() as session:
+            session: AsyncSession
+            async with session.begin():
+                await session.execute(
+                    delete(PlatformSession).where(
+                        col(PlatformSession.session_id == session_id),
+                    ),
+                )
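The new get_platform_sessions_by_creator pages with a plain offset, offset = (page - 1) * page_size, ordered by updated_at descending. Below is a standalone sketch of that query pattern; DemoSession, the in-memory aiosqlite database, and the seeded rows are illustrative assumptions, not package code (assumes sqlmodel and aiosqlite are installed).

# Standalone sketch of offset/limit paging over an async SQLModel table.
import asyncio
from datetime import datetime, timezone

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlmodel import Field, SQLModel, desc, select


class DemoSession(SQLModel, table=True):
    inner_id: int | None = Field(default=None, primary_key=True)
    creator: str
    platform_id: str = "webchat"
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)

    async with AsyncSession(engine) as session:
        session.add_all([DemoSession(creator="alice") for _ in range(5)])
        await session.commit()

        page, page_size = 1, 2
        query = (
            select(DemoSession)
            .where(DemoSession.creator == "alice")
            .order_by(desc(DemoSession.updated_at))
            .offset((page - 1) * page_size)  # same paging arithmetic as the diff
            .limit(page_size)
        )
        rows = (await session.execute(query)).scalars().all()
        print(len(rows))  # 2


asyncio.run(main())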
astrbot/core/knowledge_base/kb_helper.py
CHANGED
@@ -1,4 +1,7 @@
+import asyncio
 import json
+import re
+import time
 import uuid
 from pathlib import Path

@@ -8,12 +11,98 @@ from astrbot.core import logger
 from astrbot.core.db.vec_db.base import BaseVecDB
 from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
 from astrbot.core.provider.manager import ProviderManager
-from astrbot.core.provider.provider import
+from astrbot.core.provider.provider import (
+    EmbeddingProvider,
+    RerankProvider,
+)
+from astrbot.core.provider.provider import (
+    Provider as LLMProvider,
+)

 from .chunking.base import BaseChunker
+from .chunking.recursive import RecursiveCharacterChunker
 from .kb_db_sqlite import KBSQLiteDatabase
 from .models import KBDocument, KBMedia, KnowledgeBase
+from .parsers.url_parser import extract_text_from_url
 from .parsers.util import select_parser
+from .prompts import TEXT_REPAIR_SYSTEM_PROMPT
+
+
+class RateLimiter:
+    """一个简单的速率限制器"""
+
+    def __init__(self, max_rpm: int):
+        self.max_per_minute = max_rpm
+        self.interval = 60.0 / max_rpm if max_rpm > 0 else 0
+        self.last_call_time = 0
+
+    async def __aenter__(self):
+        if self.interval == 0:
+            return
+
+        now = time.monotonic()
+        elapsed = now - self.last_call_time
+
+        if elapsed < self.interval:
+            await asyncio.sleep(self.interval - elapsed)
+
+        self.last_call_time = time.monotonic()
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+
+async def _repair_and_translate_chunk_with_retry(
+    chunk: str,
+    repair_llm_service: LLMProvider,
+    rate_limiter: RateLimiter,
+    max_retries: int = 2,
+) -> list[str]:
+    """
+    Repairs, translates, and optionally re-chunks a single text chunk using the small LLM, with rate limiting.
+    """
+    # 为了防止 LLM 上下文污染,在 user_prompt 中也加入明确的指令
+    user_prompt = f"""IGNORE ALL PREVIOUS INSTRUCTIONS. Your ONLY task is to process the following text chunk according to the system prompt provided.
+
+Text chunk to process:
+---
+{chunk}
+---
+"""
+    for attempt in range(max_retries + 1):
+        try:
+            async with rate_limiter:
+                response = await repair_llm_service.text_chat(
+                    prompt=user_prompt, system_prompt=TEXT_REPAIR_SYSTEM_PROMPT
+                )
+
+            llm_output = response.completion_text
+
+            if "<discard_chunk />" in llm_output:
+                return []  # Signal to discard this chunk
+
+            # More robust regex to handle potential LLM formatting errors (spaces, newlines in tags)
+            matches = re.findall(
+                r"<\s*repaired_text\s*>\s*(.*?)\s*<\s*/\s*repaired_text\s*>",
+                llm_output,
+                re.DOTALL,
+            )
+
+            if matches:
+                # Further cleaning to ensure no empty strings are returned
+                return [m.strip() for m in matches if m.strip()]
+            else:
+                # If no valid tags and not explicitly discarded, discard it to be safe.
+                return []
+        except Exception as e:
+            logger.warning(
+                f" - LLM call failed on attempt {attempt + 1}/{max_retries + 1}. Error: {str(e)}"
+            )
+
+    logger.error(
+        f" - Failed to process chunk after {max_retries + 1} attempts. Using original text."
+    )
+    return [chunk]


 class KBHelper:
@@ -100,7 +189,7 @@ class KBHelper:
     async def upload_document(
         self,
         file_name: str,
-        file_content: bytes,
+        file_content: bytes | None,
         file_type: str,
         chunk_size: int = 512,
         chunk_overlap: int = 50,
@@ -108,6 +197,7 @@
         tasks_limit: int = 3,
         max_retries: int = 3,
         progress_callback=None,
+        pre_chunked_text: list[str] | None = None,
     ) -> KBDocument:
         """上传并处理文档(带原子性保证和失败清理)

@@ -130,46 +220,63 @@
         await self._ensure_vec_db()
         doc_id = str(uuid.uuid4())
         media_paths: list[Path] = []
+        file_size = 0

         # file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
         # async with aiofiles.open(file_path, "wb") as f:
         #     await f.write(file_content)

         try:
-
-
-            await progress_callback("parsing", 0, 100)
-
-            parser = await select_parser(f".{file_type}")
-            parse_result = await parser.parse(file_content, file_name)
-            text_content = parse_result.text
-            media_items = parse_result.media
+            chunks_text = []
+            saved_media = []

-            if
-
+            if pre_chunked_text is not None:
+                # 如果提供了预分块文本,直接使用
+                chunks_text = pre_chunked_text
+                file_size = sum(len(chunk) for chunk in chunks_text)
+                logger.info(f"使用预分块文本进行上传,共 {len(chunks_text)} 个块。")
+            else:
+                # 否则,执行标准的文件解析和分块流程
+                if file_content is None:
+                    raise ValueError(
+                        "当未提供 pre_chunked_text 时,file_content 不能为空。"
+                    )
+
+                file_size = len(file_content)
+
+                # 阶段1: 解析文档
+                if progress_callback:
+                    await progress_callback("parsing", 0, 100)

-
-
-
-
-                    doc_id=doc_id,
-                    media_type=media_item.media_type,
-                    file_name=media_item.file_name,
-                    content=media_item.content,
-                    mime_type=media_item.mime_type,
-                )
-                saved_media.append(media)
-                media_paths.append(Path(media.file_path))
+                parser = await select_parser(f".{file_type}")
+                parse_result = await parser.parse(file_content, file_name)
+                text_content = parse_result.text
+                media_items = parse_result.media

-
-
-
+                if progress_callback:
+                    await progress_callback("parsing", 100, 100)
+
+                # 保存媒体文件
+                for media_item in media_items:
+                    media = await self._save_media(
+                        doc_id=doc_id,
+                        media_type=media_item.media_type,
+                        file_name=media_item.file_name,
+                        content=media_item.content,
+                        mime_type=media_item.mime_type,
+                    )
+                    saved_media.append(media)
+                    media_paths.append(Path(media.file_path))
+
+                # 阶段2: 分块
+                if progress_callback:
+                    await progress_callback("chunking", 0, 100)

-
-
-
-
-
+                chunks_text = await self.chunker.chunk(
+                    text_content,
+                    chunk_size=chunk_size,
+                    chunk_overlap=chunk_overlap,
+                )
             contents = []
             metadatas = []
             for idx, chunk_text in enumerate(chunks_text):
@@ -205,7 +312,7 @@
             kb_id=self.kb.kb_id,
             doc_name=file_name,
             file_type=file_type,
-            file_size=
+            file_size=file_size,
             # file_path=str(file_path),
             file_path="",
             chunk_count=len(chunks_text),
@@ -359,3 +466,177 @@
         )

         return media
+
+    async def upload_from_url(
+        self,
+        url: str,
+        chunk_size: int = 512,
+        chunk_overlap: int = 50,
+        batch_size: int = 32,
+        tasks_limit: int = 3,
+        max_retries: int = 3,
+        progress_callback=None,
+        enable_cleaning: bool = False,
+        cleaning_provider_id: str | None = None,
+    ) -> KBDocument:
+        """从 URL 上传并处理文档(带原子性保证和失败清理)
+        Args:
+            url: 要提取内容的网页 URL
+            chunk_size: 文本块大小
+            chunk_overlap: 文本块重叠大小
+            batch_size: 批处理大小
+            tasks_limit: 并发任务限制
+            max_retries: 最大重试次数
+            progress_callback: 进度回调函数,接收参数 (stage, current, total)
+                - stage: 当前阶段 ('extracting', 'cleaning', 'parsing', 'chunking', 'embedding')
+                - current: 当前进度
+                - total: 总数
+        Returns:
+            KBDocument: 上传的文档对象
+        Raises:
+            ValueError: 如果 URL 为空或无法提取内容
+            IOError: 如果网络请求失败
+        """
+        # 获取 Tavily API 密钥
+        config = self.prov_mgr.acm.default_conf
+        tavily_keys = config.get("provider_settings", {}).get(
+            "websearch_tavily_key", []
+        )
+        if not tavily_keys:
+            raise ValueError(
+                "Error: Tavily API key is not configured in provider_settings."
+            )
+
+        # 阶段1: 从 URL 提取内容
+        if progress_callback:
+            await progress_callback("extracting", 0, 100)
+
+        try:
+            text_content = await extract_text_from_url(url, tavily_keys)
+        except Exception as e:
+            logger.error(f"Failed to extract content from URL {url}: {e}")
+            raise OSError(f"Failed to extract content from URL {url}: {e}") from e
+
+        if not text_content:
+            raise ValueError(f"No content extracted from URL: {url}")
+
+        if progress_callback:
+            await progress_callback("extracting", 100, 100)
+
+        # 阶段2: (可选)清洗内容并分块
+        final_chunks = await self._clean_and_rechunk_content(
+            content=text_content,
+            url=url,
+            progress_callback=progress_callback,
+            enable_cleaning=enable_cleaning,
+            cleaning_provider_id=cleaning_provider_id,
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+        )
+
+        if enable_cleaning and not final_chunks:
+            raise ValueError(
+                "内容清洗后未提取到有效文本。请尝试关闭内容清洗功能,或更换更高性能的LLM模型后重试。"
+            )
+
+        # 创建一个虚拟文件名
+        file_name = url.split("/")[-1] or f"document_from_{url}"
+        if not Path(file_name).suffix:
+            file_name += ".url"
+
+        # 复用现有的 upload_document 方法,但传入预分块文本
+        return await self.upload_document(
+            file_name=file_name,
+            file_content=None,
+            file_type="url",  # 使用 'url' 作为特殊文件类型
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+            batch_size=batch_size,
+            tasks_limit=tasks_limit,
+            max_retries=max_retries,
+            progress_callback=progress_callback,
+            pre_chunked_text=final_chunks,
+        )
+
+    async def _clean_and_rechunk_content(
+        self,
+        content: str,
+        url: str,
+        progress_callback=None,
+        enable_cleaning: bool = False,
+        cleaning_provider_id: str | None = None,
+        repair_max_rpm: int = 60,
+        chunk_size: int = 512,
+        chunk_overlap: int = 50,
+    ) -> list[str]:
+        """
+        对从 URL 获取的内容进行清洗、修复、翻译和重新分块。
+        """
+        if not enable_cleaning:
+            # 如果不启用清洗,则使用从前端传递的参数进行分块
+            logger.info(
+                f"内容清洗未启用,使用指定参数进行分块: chunk_size={chunk_size}, chunk_overlap={chunk_overlap}"
+            )
+            return await self.chunker.chunk(
+                content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
+            )

+        if not cleaning_provider_id:
+            logger.warning(
+                "启用了内容清洗,但未提供 cleaning_provider_id,跳过清洗并使用默认分块。"
+            )
+            return await self.chunker.chunk(content)
+
+        if progress_callback:
+            await progress_callback("cleaning", 0, 100)
+
+        try:
+            # 获取指定的 LLM Provider
+            llm_provider = await self.prov_mgr.get_provider_by_id(cleaning_provider_id)
+            if not llm_provider or not isinstance(llm_provider, LLMProvider):
+                raise ValueError(
+                    f"无法找到 ID 为 {cleaning_provider_id} 的 LLM Provider 或类型不正确"
+                )
+
+            # 初步分块
+            # 优化分隔符,优先按段落分割,以获得更高质量的文本块
+            text_splitter = RecursiveCharacterChunker(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+                separators=["\n\n", "\n", " "],  # 优先使用段落分隔符
+            )
+            initial_chunks = await text_splitter.chunk(content)
+            logger.info(f"初步分块完成,生成 {len(initial_chunks)} 个块用于修复。")
+
+            # 并发处理所有块
+            rate_limiter = RateLimiter(repair_max_rpm)
+            tasks = [
+                _repair_and_translate_chunk_with_retry(
+                    chunk, llm_provider, rate_limiter
+                )
+                for chunk in initial_chunks
+            ]
+
+            repaired_results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            final_chunks = []
+            for i, result in enumerate(repaired_results):
+                if isinstance(result, Exception):
+                    logger.warning(f"块 {i} 处理异常: {str(result)}. 回退到原始块。")
+                    final_chunks.append(initial_chunks[i])
+                elif isinstance(result, list):
+                    final_chunks.extend(result)
+
+            logger.info(
+                f"文本修复完成: {len(initial_chunks)} 个原始块 -> {len(final_chunks)} 个最终块。"
+            )
+
+            if progress_callback:
+                await progress_callback("cleaning", 100, 100)
+
+            return final_chunks
+
+        except Exception as e:
+            logger.error(f"使用 Provider '{cleaning_provider_id}' 清洗内容失败: {e}")
+            # 清洗失败,返回默认分块结果,保证流程不中断
+            return await self.chunker.chunk(content)