AstrBot 4.5.7__py3-none-any.whl → 4.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. astrbot/core/agent/mcp_client.py +152 -26
  2. astrbot/core/agent/message.py +8 -1
  3. astrbot/core/config/default.py +8 -1
  4. astrbot/core/core_lifecycle.py +8 -0
  5. astrbot/core/db/__init__.py +50 -1
  6. astrbot/core/db/migration/migra_webchat_session.py +131 -0
  7. astrbot/core/db/po.py +49 -13
  8. astrbot/core/db/sqlite.py +102 -3
  9. astrbot/core/knowledge_base/kb_helper.py +314 -33
  10. astrbot/core/knowledge_base/kb_mgr.py +45 -1
  11. astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
  12. astrbot/core/knowledge_base/prompts.py +65 -0
  13. astrbot/core/pipeline/process_stage/method/llm_request.py +28 -14
  14. astrbot/core/pipeline/process_stage/utils.py +60 -16
  15. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +13 -10
  16. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -4
  17. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +0 -4
  18. astrbot/core/provider/entities.py +22 -9
  19. astrbot/core/provider/func_tool_manager.py +12 -9
  20. astrbot/core/provider/sources/gemini_source.py +25 -8
  21. astrbot/core/provider/sources/openai_source.py +9 -16
  22. astrbot/dashboard/routes/chat.py +134 -77
  23. astrbot/dashboard/routes/knowledge_base.py +172 -0
  24. {astrbot-4.5.7.dist-info → astrbot-4.6.0.dist-info}/METADATA +4 -3
  25. {astrbot-4.5.7.dist-info → astrbot-4.6.0.dist-info}/RECORD +28 -25
  26. {astrbot-4.5.7.dist-info → astrbot-4.6.0.dist-info}/WHEEL +0 -0
  27. {astrbot-4.5.7.dist-info → astrbot-4.6.0.dist-info}/entry_points.txt +0 -0
  28. {astrbot-4.5.7.dist-info → astrbot-4.6.0.dist-info}/licenses/LICENSE +0 -0
astrbot/core/db/sqlite.py CHANGED
@@ -1,7 +1,7 @@
 import asyncio
 import threading
 import typing as T
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlmodel import col, delete, desc, func, or_, select, text, update
@@ -12,6 +12,7 @@ from astrbot.core.db.po import (
     ConversationV2,
     Persona,
     PlatformMessageHistory,
+    PlatformSession,
     PlatformStat,
     Preference,
     SQLModel,
@@ -412,7 +413,7 @@ class SQLiteDatabase(BaseDatabase):
         user_id,
         offset_sec=86400,
     ):
-        """Delete platform message history records older than the specified offset."""
+        """Delete platform message history records newer than the specified offset."""
         async with self.get_db() as session:
             session: AsyncSession
             async with session.begin():
@@ -422,7 +423,7 @@
                     delete(PlatformMessageHistory).where(
                         col(PlatformMessageHistory.platform_id) == platform_id,
                         col(PlatformMessageHistory.user_id) == user_id,
-                        col(PlatformMessageHistory.created_at) < cutoff_time,
+                        col(PlatformMessageHistory.created_at) >= cutoff_time,
                     ),
                 )
 
@@ -709,3 +710,101 @@ class SQLiteDatabase(BaseDatabase):
         t.start()
         t.join()
         return result
+
+    # ====
+    # Platform Session Management
+    # ====
+
+    async def create_platform_session(
+        self,
+        creator: str,
+        platform_id: str = "webchat",
+        session_id: str | None = None,
+        display_name: str | None = None,
+        is_group: int = 0,
+    ) -> PlatformSession:
+        """Create a new Platform session."""
+        kwargs = {}
+        if session_id:
+            kwargs["session_id"] = session_id
+
+        async with self.get_db() as session:
+            session: AsyncSession
+            async with session.begin():
+                new_session = PlatformSession(
+                    creator=creator,
+                    platform_id=platform_id,
+                    display_name=display_name,
+                    is_group=is_group,
+                    **kwargs,
+                )
+                session.add(new_session)
+                await session.flush()
+                await session.refresh(new_session)
+                return new_session
+
+    async def get_platform_session_by_id(
+        self, session_id: str
+    ) -> PlatformSession | None:
+        """Get a Platform session by its ID."""
+        async with self.get_db() as session:
+            session: AsyncSession
+            query = select(PlatformSession).where(
+                PlatformSession.session_id == session_id,
+            )
+            result = await session.execute(query)
+            return result.scalar_one_or_none()
+
+    async def get_platform_sessions_by_creator(
+        self,
+        creator: str,
+        platform_id: str | None = None,
+        page: int = 1,
+        page_size: int = 20,
+    ) -> list[PlatformSession]:
+        """Get all Platform sessions for a specific creator (username) and optionally platform."""
+        async with self.get_db() as session:
+            session: AsyncSession
+            offset = (page - 1) * page_size
+            query = select(PlatformSession).where(PlatformSession.creator == creator)
+
+            if platform_id:
+                query = query.where(PlatformSession.platform_id == platform_id)
+
+            query = (
+                query.order_by(desc(PlatformSession.updated_at))
+                .offset(offset)
+                .limit(page_size)
+            )
+            result = await session.execute(query)
+            return list(result.scalars().all())
+
+    async def update_platform_session(
+        self,
+        session_id: str,
+        display_name: str | None = None,
+    ) -> None:
+        """Update a Platform session's updated_at timestamp and optionally display_name."""
+        async with self.get_db() as session:
+            session: AsyncSession
+            async with session.begin():
+                values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
+                if display_name is not None:
+                    values["display_name"] = display_name
+
+                await session.execute(
+                    update(PlatformSession)
+                    .where(col(PlatformSession.session_id == session_id))
+                    .values(**values),
+                )
+
+    async def delete_platform_session(self, session_id: str) -> None:
+        """Delete a Platform session by its ID."""
+        async with self.get_db() as session:
+            session: AsyncSession
+            async with session.begin():
+                await session.execute(
+                    delete(PlatformSession).where(
+                        col(PlatformSession.session_id == session_id),
+                    ),
+                )
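
For orientation, a minimal sketch of how these new PlatformSession helpers might be called from application code. The method names and signatures are taken from the diff above; the surrounding setup (how the SQLiteDatabase instance `db` is obtained and driven by an event loop) is assumed for illustration only:

from astrbot.core.db.sqlite import SQLiteDatabase


async def demo(db: SQLiteDatabase) -> None:
    # Create a webchat session owned by the dashboard user "admin".
    created = await db.create_platform_session(creator="admin", display_name="My chat")

    # Rename it later; update_platform_session also bumps updated_at to the current UTC time.
    await db.update_platform_session(created.session_id, display_name="Renamed chat")

    # Page through the creator's sessions, most recently updated first.
    sessions = await db.get_platform_sessions_by_creator(
        "admin", platform_id="webchat", page=1, page_size=20
    )
    print([s.display_name for s in sessions])

    # Remove the session again.
    await db.delete_platform_session(created.session_id)
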
astrbot/core/knowledge_base/kb_helper.py CHANGED
@@ -1,4 +1,7 @@
+import asyncio
 import json
+import re
+import time
 import uuid
 from pathlib import Path
 
@@ -8,12 +11,98 @@ from astrbot.core import logger
 from astrbot.core.db.vec_db.base import BaseVecDB
 from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
 from astrbot.core.provider.manager import ProviderManager
-from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
+from astrbot.core.provider.provider import (
+    EmbeddingProvider,
+    RerankProvider,
+)
+from astrbot.core.provider.provider import (
+    Provider as LLMProvider,
+)
 
 from .chunking.base import BaseChunker
+from .chunking.recursive import RecursiveCharacterChunker
 from .kb_db_sqlite import KBSQLiteDatabase
 from .models import KBDocument, KBMedia, KnowledgeBase
+from .parsers.url_parser import extract_text_from_url
 from .parsers.util import select_parser
+from .prompts import TEXT_REPAIR_SYSTEM_PROMPT
+
+
+class RateLimiter:
+    """A simple rate limiter."""
+
+    def __init__(self, max_rpm: int):
+        self.max_per_minute = max_rpm
+        self.interval = 60.0 / max_rpm if max_rpm > 0 else 0
+        self.last_call_time = 0
+
+    async def __aenter__(self):
+        if self.interval == 0:
+            return
+
+        now = time.monotonic()
+        elapsed = now - self.last_call_time
+
+        if elapsed < self.interval:
+            await asyncio.sleep(self.interval - elapsed)
+
+        self.last_call_time = time.monotonic()
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+
+async def _repair_and_translate_chunk_with_retry(
+    chunk: str,
+    repair_llm_service: LLMProvider,
+    rate_limiter: RateLimiter,
+    max_retries: int = 2,
+) -> list[str]:
+    """
+    Repairs, translates, and optionally re-chunks a single text chunk using the small LLM, with rate limiting.
+    """
+    # To guard against LLM context contamination, repeat explicit instructions in the user_prompt as well
+    user_prompt = f"""IGNORE ALL PREVIOUS INSTRUCTIONS. Your ONLY task is to process the following text chunk according to the system prompt provided.
+
+Text chunk to process:
+---
+{chunk}
+---
+"""
+    for attempt in range(max_retries + 1):
+        try:
+            async with rate_limiter:
+                response = await repair_llm_service.text_chat(
+                    prompt=user_prompt, system_prompt=TEXT_REPAIR_SYSTEM_PROMPT
+                )
+
+            llm_output = response.completion_text
+
+            if "<discard_chunk />" in llm_output:
+                return []  # Signal to discard this chunk
+
+            # More robust regex to handle potential LLM formatting errors (spaces, newlines in tags)
+            matches = re.findall(
+                r"<\s*repaired_text\s*>\s*(.*?)\s*<\s*/\s*repaired_text\s*>",
+                llm_output,
+                re.DOTALL,
+            )
+
+            if matches:
+                # Further cleaning to ensure no empty strings are returned
+                return [m.strip() for m in matches if m.strip()]
+            else:
+                # If no valid tags and not explicitly discarded, discard it to be safe.
+                return []
+        except Exception as e:
+            logger.warning(
+                f" - LLM call failed on attempt {attempt + 1}/{max_retries + 1}. Error: {str(e)}"
+            )
+
+    logger.error(
+        f" - Failed to process chunk after {max_retries + 1} attempts. Using original text."
+    )
+    return [chunk]
 
 
 class KBHelper:
@@ -100,7 +189,7 @@ class KBHelper:
     async def upload_document(
         self,
         file_name: str,
-        file_content: bytes,
+        file_content: bytes | None,
         file_type: str,
         chunk_size: int = 512,
         chunk_overlap: int = 50,
@@ -108,6 +197,7 @@
         tasks_limit: int = 3,
        max_retries: int = 3,
         progress_callback=None,
+        pre_chunked_text: list[str] | None = None,
     ) -> KBDocument:
         """Upload and process a document (with atomicity guarantees and failure cleanup)
 
@@ -130,46 +220,63 @@ class KBHelper:
         await self._ensure_vec_db()
         doc_id = str(uuid.uuid4())
         media_paths: list[Path] = []
+        file_size = 0
 
         # file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
         # async with aiofiles.open(file_path, "wb") as f:
         #     await f.write(file_content)
 
         try:
-            # Stage 1: parse the document
-            if progress_callback:
-                await progress_callback("parsing", 0, 100)
-
-            parser = await select_parser(f".{file_type}")
-            parse_result = await parser.parse(file_content, file_name)
-            text_content = parse_result.text
-            media_items = parse_result.media
+            chunks_text = []
+            saved_media = []
 
-            if progress_callback:
-                await progress_callback("parsing", 100, 100)
+            if pre_chunked_text is not None:
+                # If pre-chunked text is provided, use it directly
+                chunks_text = pre_chunked_text
+                file_size = sum(len(chunk) for chunk in chunks_text)
+                logger.info(f"Uploading with pre-chunked text, {len(chunks_text)} chunks in total.")
+            else:
+                # Otherwise, run the standard file parsing and chunking flow
+                if file_content is None:
+                    raise ValueError(
+                        "file_content must not be empty when pre_chunked_text is not provided."
+                    )
+
+                file_size = len(file_content)
+
+                # Stage 1: parse the document
+                if progress_callback:
+                    await progress_callback("parsing", 0, 100)
 
-            # Save media files
-            saved_media = []
-            for media_item in media_items:
-                media = await self._save_media(
-                    doc_id=doc_id,
-                    media_type=media_item.media_type,
-                    file_name=media_item.file_name,
-                    content=media_item.content,
-                    mime_type=media_item.mime_type,
-                )
-                saved_media.append(media)
-                media_paths.append(Path(media.file_path))
+                parser = await select_parser(f".{file_type}")
+                parse_result = await parser.parse(file_content, file_name)
+                text_content = parse_result.text
+                media_items = parse_result.media
 
-            # Stage 2: chunking
-            if progress_callback:
-                await progress_callback("chunking", 0, 100)
+                if progress_callback:
+                    await progress_callback("parsing", 100, 100)
+
+                # Save media files
+                for media_item in media_items:
+                    media = await self._save_media(
+                        doc_id=doc_id,
+                        media_type=media_item.media_type,
+                        file_name=media_item.file_name,
+                        content=media_item.content,
+                        mime_type=media_item.mime_type,
+                    )
+                    saved_media.append(media)
+                    media_paths.append(Path(media.file_path))
+
+                # Stage 2: chunking
+                if progress_callback:
+                    await progress_callback("chunking", 0, 100)
 
-            chunks_text = await self.chunker.chunk(
-                text_content,
-                chunk_size=chunk_size,
-                chunk_overlap=chunk_overlap,
-            )
+                chunks_text = await self.chunker.chunk(
+                    text_content,
+                    chunk_size=chunk_size,
+                    chunk_overlap=chunk_overlap,
+                )
             contents = []
             metadatas = []
             for idx, chunk_text in enumerate(chunks_text):
@@ -205,7 +312,7 @@ class KBHelper:
                 kb_id=self.kb.kb_id,
                 doc_name=file_name,
                 file_type=file_type,
-                file_size=len(file_content),
+                file_size=file_size,
                 # file_path=str(file_path),
                 file_path="",
                 chunk_count=len(chunks_text),
@@ -359,3 +466,177 @@ class KBHelper:
         )
 
         return media
+
+    async def upload_from_url(
+        self,
+        url: str,
+        chunk_size: int = 512,
+        chunk_overlap: int = 50,
+        batch_size: int = 32,
+        tasks_limit: int = 3,
+        max_retries: int = 3,
+        progress_callback=None,
+        enable_cleaning: bool = False,
+        cleaning_provider_id: str | None = None,
+    ) -> KBDocument:
+        """Upload and process a document from a URL (with atomicity guarantees and failure cleanup)
+        Args:
+            url: the web page URL to extract content from
+            chunk_size: text chunk size
+            chunk_overlap: text chunk overlap size
+            batch_size: batch size
+            tasks_limit: concurrent task limit
+            max_retries: maximum number of retries
+            progress_callback: progress callback, receives (stage, current, total)
+                - stage: current stage ('extracting', 'cleaning', 'parsing', 'chunking', 'embedding')
+                - current: current progress
+                - total: total count
+        Returns:
+            KBDocument: the uploaded document object
+        Raises:
+            ValueError: if the URL is empty or no content can be extracted
+            IOError: if the network request fails
+        """
+        # Get the Tavily API keys
+        config = self.prov_mgr.acm.default_conf
+        tavily_keys = config.get("provider_settings", {}).get(
+            "websearch_tavily_key", []
+        )
+        if not tavily_keys:
+            raise ValueError(
+                "Error: Tavily API key is not configured in provider_settings."
+            )
+
+        # Stage 1: extract content from the URL
+        if progress_callback:
+            await progress_callback("extracting", 0, 100)
+
+        try:
+            text_content = await extract_text_from_url(url, tavily_keys)
+        except Exception as e:
+            logger.error(f"Failed to extract content from URL {url}: {e}")
+            raise OSError(f"Failed to extract content from URL {url}: {e}") from e
+
+        if not text_content:
+            raise ValueError(f"No content extracted from URL: {url}")
+
+        if progress_callback:
+            await progress_callback("extracting", 100, 100)
+
+        # Stage 2: (optionally) clean the content and chunk it
+        final_chunks = await self._clean_and_rechunk_content(
+            content=text_content,
+            url=url,
+            progress_callback=progress_callback,
+            enable_cleaning=enable_cleaning,
+            cleaning_provider_id=cleaning_provider_id,
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+        )
+
+        if enable_cleaning and not final_chunks:
+            raise ValueError(
+                "No valid text was left after content cleaning. Try disabling content cleaning, or retry with a more capable LLM model."
+            )
+
+        # Create a placeholder file name
+        file_name = url.split("/")[-1] or f"document_from_{url}"
+        if not Path(file_name).suffix:
+            file_name += ".url"
+
+        # Reuse the existing upload_document method, but pass in the pre-chunked text
+        return await self.upload_document(
+            file_name=file_name,
+            file_content=None,
+            file_type="url",  # use 'url' as a special file type
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+            batch_size=batch_size,
+            tasks_limit=tasks_limit,
+            max_retries=max_retries,
+            progress_callback=progress_callback,
+            pre_chunked_text=final_chunks,
+        )
+
+    async def _clean_and_rechunk_content(
+        self,
+        content: str,
+        url: str,
+        progress_callback=None,
+        enable_cleaning: bool = False,
+        cleaning_provider_id: str | None = None,
+        repair_max_rpm: int = 60,
+        chunk_size: int = 512,
+        chunk_overlap: int = 50,
+    ) -> list[str]:
+        """
+        Clean, repair, translate, and re-chunk content fetched from a URL.
+        """
+        if not enable_cleaning:
+            # If cleaning is not enabled, chunk with the parameters passed from the frontend
+            logger.info(
+                f"Content cleaning disabled, chunking with the given parameters: chunk_size={chunk_size}, chunk_overlap={chunk_overlap}"
+            )
+            return await self.chunker.chunk(
+                content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
+            )
+
+        if not cleaning_provider_id:
+            logger.warning(
+                "Content cleaning is enabled but no cleaning_provider_id was given; skipping cleaning and using default chunking."
+            )
+            return await self.chunker.chunk(content)
+
+        if progress_callback:
+            await progress_callback("cleaning", 0, 100)
+
+        try:
+            # Get the specified LLM Provider
+            llm_provider = await self.prov_mgr.get_provider_by_id(cleaning_provider_id)
+            if not llm_provider or not isinstance(llm_provider, LLMProvider):
+                raise ValueError(
+                    f"Could not find an LLM Provider with id {cleaning_provider_id}, or it has the wrong type"
+                )
+
+            # Initial chunking
+            # Tune the separators to split by paragraph first, for higher-quality chunks
+            text_splitter = RecursiveCharacterChunker(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+                separators=["\n\n", "\n", " "],  # prefer paragraph separators
+            )
+            initial_chunks = await text_splitter.chunk(content)
+            logger.info(f"Initial chunking finished, produced {len(initial_chunks)} chunks for repair.")
+
+            # Process all chunks concurrently
+            rate_limiter = RateLimiter(repair_max_rpm)
+            tasks = [
+                _repair_and_translate_chunk_with_retry(
+                    chunk, llm_provider, rate_limiter
+                )
+                for chunk in initial_chunks
+            ]
+
+            repaired_results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            final_chunks = []
+            for i, result in enumerate(repaired_results):
+                if isinstance(result, Exception):
+                    logger.warning(f"Chunk {i} failed with an exception: {str(result)}. Falling back to the original chunk.")
+                    final_chunks.append(initial_chunks[i])
+                elif isinstance(result, list):
+                    final_chunks.extend(result)
+
+            logger.info(
+                f"Text repair finished: {len(initial_chunks)} original chunks -> {len(final_chunks)} final chunks."
+            )
+
+            if progress_callback:
+                await progress_callback("cleaning", 100, 100)
+
+            return final_chunks
+
+        except Exception as e:
+            logger.error(f"Failed to clean content with Provider '{cleaning_provider_id}': {e}")
+            # Cleaning failed; return the default chunking result so the pipeline is not interrupted
+            return await self.chunker.chunk(content)
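
The cleaning path above relies on a simple tag protocol: the repair LLM is expected to wrap usable text in repaired_text tags or emit <discard_chunk /> for noise, and the deliberately lenient regex tolerates stray spaces and newlines inside the tags. A small self-contained sketch of that parsing step, using an invented sample of model output purely for illustration:

import re

llm_output = """<repaired_text>
First cleaned paragraph.
</repaired_text>
< repaired_text >Second cleaned paragraph.</ repaired_text >"""

if "<discard_chunk />" in llm_output:
    chunks = []  # the model judged the chunk to be noise
else:
    # Same lenient pattern as in the diff: spaces/newlines inside the tags are tolerated.
    chunks = [
        m.strip()
        for m in re.findall(
            r"<\s*repaired_text\s*>\s*(.*?)\s*<\s*/\s*repaired_text\s*>",
            llm_output,
            re.DOTALL,
        )
        if m.strip()
    ]

print(chunks)  # ['First cleaned paragraph.', 'Second cleaned paragraph.']
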
astrbot/core/knowledge_base/kb_mgr.py CHANGED
@@ -8,7 +8,7 @@ from astrbot.core.provider.manager import ProviderManager
 from .chunking.recursive import RecursiveCharacterChunker
 from .kb_db_sqlite import KBSQLiteDatabase
 from .kb_helper import KBHelper
-from .models import KnowledgeBase
+from .models import KBDocument, KnowledgeBase
 from .retrieval.manager import RetrievalManager, RetrievalResult
 from .retrieval.rank_fusion import RankFusion
 from .retrieval.sparse_retriever import SparseRetriever
@@ -284,3 +284,47 @@ class KnowledgeBaseManager:
             await self.kb_db.close()
         except Exception as e:
             logger.error(f"Failed to close the knowledge base metadata database: {e}")
+
+    async def upload_from_url(
+        self,
+        kb_id: str,
+        url: str,
+        chunk_size: int = 512,
+        chunk_overlap: int = 50,
+        batch_size: int = 32,
+        tasks_limit: int = 3,
+        max_retries: int = 3,
+        progress_callback=None,
+    ) -> KBDocument:
+        """Upload a document from a URL into the specified knowledge base
+
+        Args:
+            kb_id: knowledge base ID
+            url: the web page URL to extract content from
+            chunk_size: text chunk size
+            chunk_overlap: text chunk overlap size
+            batch_size: batch size
+            tasks_limit: concurrent task limit
+            max_retries: maximum number of retries
+            progress_callback: progress callback function
+
+        Returns:
+            KBDocument: the uploaded document object
+
+        Raises:
+            ValueError: if the knowledge base does not exist or the URL is empty
+            IOError: if the network request fails
+        """
+        kb_helper = await self.get_kb(kb_id)
+        if not kb_helper:
+            raise ValueError(f"Knowledge base with id {kb_id} not found.")
+
+        return await kb_helper.upload_from_url(
+            url=url,
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+            batch_size=batch_size,
+            tasks_limit=tasks_limit,
+            max_retries=max_retries,
+            progress_callback=progress_callback,
+        )
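
A hypothetical call site for this manager-level entry point, assuming `kb_mgr` is an initialized KnowledgeBaseManager and "my-kb-id" names an existing knowledge base; the keyword arguments and the (stage, current, total) callback contract come from the docstrings above:

async def ingest_page(kb_mgr) -> None:
    async def report(stage: str, current: int, total: int) -> None:
        # Mirrors the awaited progress_callback used in the diff.
        print(f"[{stage}] {current}/{total}")

    doc = await kb_mgr.upload_from_url(
        kb_id="my-kb-id",
        url="https://example.com/article",
        chunk_size=512,
        chunk_overlap=50,
        progress_callback=report,
    )
    print(doc.doc_name, doc.chunk_count)
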
astrbot/core/knowledge_base/parsers/url_parser.py CHANGED
@@ -0,0 +1,103 @@
+import asyncio
+
+import aiohttp
+
+
+class URLExtractor:
+    """URL content extractor that wraps Tavily API calls and key management."""
+
+    def __init__(self, tavily_keys: list[str]):
+        """
+        Initialize the URL extractor
+
+        Args:
+            tavily_keys: list of Tavily API keys
+        """
+        if not tavily_keys:
+            raise ValueError("Error: Tavily API keys are not configured.")
+
+        self.tavily_keys = tavily_keys
+        self.tavily_key_index = 0
+        self.tavily_key_lock = asyncio.Lock()
+
+    async def _get_tavily_key(self) -> str:
+        """Fetch and rotate a Tavily API key from the list in a concurrency-safe way."""
+        async with self.tavily_key_lock:
+            key = self.tavily_keys[self.tavily_key_index]
+            self.tavily_key_index = (self.tavily_key_index + 1) % len(self.tavily_keys)
+            return key
+
+    async def extract_text_from_url(self, url: str) -> str:
+        """
+        Use the Tavily API to extract the main text content from a URL.
+        This is a simplified version of the tavily_extract_web_page method in the web_searcher plugin,
+        designed specifically for the knowledge base module and independent of AstrMessageEvent.
+
+        Args:
+            url: the web page URL to extract content from
+
+        Returns:
+            The extracted text content
+
+        Raises:
+            ValueError: if the URL is empty or the API keys are not configured
+            IOError: if the request fails or returns an error
+        """
+        if not url:
+            raise ValueError("Error: url must be a non-empty string.")
+
+        tavily_key = await self._get_tavily_key()
+        api_url = "https://api.tavily.com/extract"
+        headers = {
+            "Authorization": f"Bearer {tavily_key}",
+            "Content-Type": "application/json",
+        }
+
+        payload = {
+            "urls": [url],
+            "extract_depth": "basic",  # use basic extraction depth
+        }
+
+        try:
+            async with aiohttp.ClientSession(trust_env=True) as session:
+                async with session.post(
+                    api_url,
+                    json=payload,
+                    headers=headers,
+                    timeout=30.0,  # longer timeout, since content extraction can take a while
+                ) as response:
+                    if response.status != 200:
+                        reason = await response.text()
+                        raise OSError(
+                            f"Tavily web extraction failed: {reason}, status: {response.status}"
+                        )
+
+                    data = await response.json()
+                    results = data.get("results", [])
+
+                    if not results:
+                        raise ValueError(f"No content extracted from URL: {url}")
+
+                    # Return the content of the first result
+                    return results[0].get("raw_content", "")
+
+        except aiohttp.ClientError as e:
+            raise OSError(f"Failed to fetch URL {url}: {e}") from e
+        except Exception as e:
+            raise OSError(f"Failed to extract content from URL {url}: {e}") from e
+
+
+# For backward compatibility, provide a simple function interface
+async def extract_text_from_url(url: str, tavily_keys: list[str]) -> str:
+    """
+    A simple function interface for extracting text content from a URL
+
+    Args:
+        url: the web page URL to extract content from
+        tavily_keys: list of Tavily API keys
+
+    Returns:
+        The extracted text content
+    """
+    extractor = URLExtractor(tavily_keys)
+    return await extractor.extract_text_from_url(url)
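
A minimal sketch of driving the backward-compatible function interface; the key value is a placeholder, and the OSError/ValueError handling mirrors the contract documented in the docstrings above:

import asyncio

from astrbot.core.knowledge_base.parsers.url_parser import extract_text_from_url


async def main() -> None:
    keys = ["tvly-placeholder-key"]  # one or more Tavily API keys, rotated per request
    try:
        text = await extract_text_from_url("https://example.com/article", keys)
        print(text[:200])
    except (OSError, ValueError) as e:
        print(f"extraction failed: {e}")


asyncio.run(main())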