agent-runtime-core 0.6.0-py3-none-any.whl → 0.7.1-py3-none-any.whl
This diff shows the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- agent_runtime_core/__init__.py +118 -2
- agent_runtime_core/agentic_loop.py +254 -0
- agent_runtime_core/config.py +54 -4
- agent_runtime_core/config_schema.py +307 -0
- agent_runtime_core/contexts.py +348 -0
- agent_runtime_core/interfaces.py +106 -0
- agent_runtime_core/json_runtime.py +509 -0
- agent_runtime_core/llm/__init__.py +80 -7
- agent_runtime_core/llm/anthropic.py +133 -12
- agent_runtime_core/llm/models_config.py +180 -0
- agent_runtime_core/memory/__init__.py +70 -0
- agent_runtime_core/memory/manager.py +554 -0
- agent_runtime_core/memory/mixin.py +294 -0
- agent_runtime_core/multi_agent.py +569 -0
- agent_runtime_core/persistence/__init__.py +2 -0
- agent_runtime_core/persistence/file.py +277 -0
- agent_runtime_core/rag/__init__.py +65 -0
- agent_runtime_core/rag/chunking.py +224 -0
- agent_runtime_core/rag/indexer.py +253 -0
- agent_runtime_core/rag/retriever.py +261 -0
- agent_runtime_core/runner.py +193 -15
- agent_runtime_core/tool_calling_agent.py +88 -130
- agent_runtime_core/tools.py +179 -0
- agent_runtime_core/vectorstore/__init__.py +193 -0
- agent_runtime_core/vectorstore/base.py +138 -0
- agent_runtime_core/vectorstore/embeddings.py +242 -0
- agent_runtime_core/vectorstore/sqlite_vec.py +328 -0
- agent_runtime_core/vectorstore/vertex.py +295 -0
- {agent_runtime_core-0.6.0.dist-info → agent_runtime_core-0.7.1.dist-info}/METADATA +202 -1
- agent_runtime_core-0.7.1.dist-info/RECORD +57 -0
- agent_runtime_core-0.6.0.dist-info/RECORD +0 -38
- {agent_runtime_core-0.6.0.dist-info → agent_runtime_core-0.7.1.dist-info}/WHEEL +0 -0
- {agent_runtime_core-0.6.0.dist-info → agent_runtime_core-0.7.1.dist-info}/licenses/LICENSE +0 -0
agent_runtime_core/persistence/file.py

@@ -20,6 +20,7 @@ from agent_runtime_core.persistence.base import (
     ConversationStore,
     TaskStore,
     PreferencesStore,
+    KnowledgeStore,
     Scope,
     Conversation,
     ConversationMessage,
@@ -28,6 +29,10 @@ from agent_runtime_core.persistence.base import (
     TaskList,
     Task,
     TaskState,
+    Fact,
+    FactType,
+    Summary,
+    Embedding,
 )
 
 
@@ -505,3 +510,275 @@ class FilePreferencesStore(PreferencesStore):
 
     async def get_all(self, scope: Scope = Scope.GLOBAL) -> dict[str, Any]:
         return await self._load_preferences(scope)
+
+
+class FileKnowledgeStore(KnowledgeStore):
+    """
+    File-based knowledge store with optional vector store integration.
+
+    Stores facts and summaries in JSON files:
+    - {base_path}/knowledge/facts/{fact_id}.json
+    - {base_path}/knowledge/summaries/{summary_id}.json
+
+    Embeddings are stored via an optional VectorStore backend.
+    """
+
+    def __init__(
+        self,
+        project_dir: Optional[Path] = None,
+        vector_store: Optional["VectorStore"] = None,
+        embedding_client: Optional["EmbeddingClient"] = None,
+    ):
+        """
+        Initialize file-based knowledge store.
+
+        Args:
+            project_dir: Base directory for file storage
+            vector_store: Optional VectorStore for embedding storage
+            embedding_client: Optional EmbeddingClient for generating embeddings
+        """
+        self._project_dir = project_dir
+        self._vector_store = vector_store
+        self._embedding_client = embedding_client
+
+    def _get_facts_path(self, scope: Scope) -> Path:
+        return _get_base_path(scope, self._project_dir) / "knowledge" / "facts"
+
+    def _get_summaries_path(self, scope: Scope) -> Path:
+        return _get_base_path(scope, self._project_dir) / "knowledge" / "summaries"
+
+    def _get_fact_path(self, fact_id: UUID, scope: Scope) -> Path:
+        return self._get_facts_path(scope) / f"{fact_id}.json"
+
+    def _get_summary_path(self, summary_id: UUID, scope: Scope) -> Path:
+        return self._get_summaries_path(scope) / f"{summary_id}.json"
+
+    def _serialize_fact(self, fact: Fact) -> dict:
+        return {
+            "id": str(fact.id),
+            "key": fact.key,
+            "value": fact.value,
+            "fact_type": fact.fact_type.value,
+            "confidence": fact.confidence,
+            "source": fact.source,
+            "created_at": fact.created_at.isoformat(),
+            "updated_at": fact.updated_at.isoformat(),
+            "expires_at": fact.expires_at.isoformat() if fact.expires_at else None,
+            "metadata": fact.metadata,
+        }
+
+    def _deserialize_fact(self, data: dict) -> Fact:
+        return Fact(
+            id=_parse_uuid(data["id"]),
+            key=data["key"],
+            value=data["value"],
+            fact_type=FactType(data["fact_type"]),
+            confidence=data.get("confidence", 1.0),
+            source=data.get("source"),
+            created_at=_parse_datetime(data["created_at"]),
+            updated_at=_parse_datetime(data["updated_at"]),
+            expires_at=_parse_datetime(data["expires_at"]) if data.get("expires_at") else None,
+            metadata=data.get("metadata", {}),
+        )
+
+    def _serialize_summary(self, summary: Summary) -> dict:
+        return {
+            "id": str(summary.id),
+            "content": summary.content,
+            "conversation_id": str(summary.conversation_id) if summary.conversation_id else None,
+            "conversation_ids": [str(cid) for cid in summary.conversation_ids],
+            "start_time": summary.start_time.isoformat() if summary.start_time else None,
+            "end_time": summary.end_time.isoformat() if summary.end_time else None,
+            "created_at": summary.created_at.isoformat(),
+            "metadata": summary.metadata,
+        }
+
+    def _deserialize_summary(self, data: dict) -> Summary:
+        return Summary(
+            id=_parse_uuid(data["id"]),
+            content=data["content"],
+            conversation_id=_parse_uuid(data["conversation_id"]) if data.get("conversation_id") else None,
+            conversation_ids=[_parse_uuid(cid) for cid in data.get("conversation_ids", [])],
+            start_time=_parse_datetime(data["start_time"]) if data.get("start_time") else None,
+            end_time=_parse_datetime(data["end_time"]) if data.get("end_time") else None,
+            created_at=_parse_datetime(data["created_at"]),
+            metadata=data.get("metadata", {}),
+        )
+
+    # Fact operations
+    async def save_fact(self, fact: Fact, scope: Scope = Scope.PROJECT) -> None:
+        path = self._get_fact_path(fact.id, scope)
+        _ensure_dir(path.parent)
+        fact.updated_at = datetime.utcnow()
+        with open(path, "w") as f:
+            f.write(_json_dumps(self._serialize_fact(fact)))
+
+    async def get_fact(self, fact_id: UUID, scope: Scope = Scope.PROJECT) -> Optional[Fact]:
+        path = self._get_fact_path(fact_id, scope)
+        if not path.exists():
+            return None
+        try:
+            with open(path, "r") as f:
+                data = json.load(f)
+            return self._deserialize_fact(data)
+        except (json.JSONDecodeError, IOError):
+            return None
+
+    async def get_fact_by_key(self, key: str, scope: Scope = Scope.PROJECT) -> Optional[Fact]:
+        facts_path = self._get_facts_path(scope)
+        if not facts_path.exists():
+            return None
+        for file in facts_path.glob("*.json"):
+            try:
+                with open(file, "r") as f:
+                    data = json.load(f)
+                if data.get("key") == key:
+                    return self._deserialize_fact(data)
+            except (json.JSONDecodeError, IOError):
+                continue
+        return None
+
+    async def list_facts(
+        self,
+        scope: Scope = Scope.PROJECT,
+        fact_type: Optional[FactType] = None,
+        limit: int = 100,
+    ) -> list[Fact]:
+        facts_path = self._get_facts_path(scope)
+        if not facts_path.exists():
+            return []
+        facts = []
+        for file in facts_path.glob("*.json"):
+            try:
+                with open(file, "r") as f:
+                    data = json.load(f)
+                fact = self._deserialize_fact(data)
+                if fact_type is None or fact.fact_type == fact_type:
+                    facts.append(fact)
+            except (json.JSONDecodeError, IOError):
+                continue
+        facts.sort(key=lambda f: f.updated_at, reverse=True)
+        return facts[:limit]
+
+    async def delete_fact(self, fact_id: UUID, scope: Scope = Scope.PROJECT) -> bool:
+        path = self._get_fact_path(fact_id, scope)
+        if path.exists():
+            path.unlink()
+            return True
+        return False
+
+    # Summary operations
+    async def save_summary(self, summary: Summary, scope: Scope = Scope.PROJECT) -> None:
+        path = self._get_summary_path(summary.id, scope)
+        _ensure_dir(path.parent)
+        with open(path, "w") as f:
+            f.write(_json_dumps(self._serialize_summary(summary)))
+
+    async def get_summary(self, summary_id: UUID, scope: Scope = Scope.PROJECT) -> Optional[Summary]:
+        path = self._get_summary_path(summary_id, scope)
+        if not path.exists():
+            return None
+        try:
+            with open(path, "r") as f:
+                data = json.load(f)
+            return self._deserialize_summary(data)
+        except (json.JSONDecodeError, IOError):
+            return None
+
+    async def get_summaries_for_conversation(
+        self,
+        conversation_id: UUID,
+        scope: Scope = Scope.PROJECT,
+    ) -> list[Summary]:
+        summaries_path = self._get_summaries_path(scope)
+        if not summaries_path.exists():
+            return []
+        summaries = []
+        for file in summaries_path.glob("*.json"):
+            try:
+                with open(file, "r") as f:
+                    data = json.load(f)
+                summary = self._deserialize_summary(data)
+                if (
+                    summary.conversation_id == conversation_id
+                    or conversation_id in summary.conversation_ids
+                ):
+                    summaries.append(summary)
+            except (json.JSONDecodeError, IOError):
+                continue
+        summaries.sort(key=lambda s: s.created_at, reverse=True)
+        return summaries
+
+    async def delete_summary(self, summary_id: UUID, scope: Scope = Scope.PROJECT) -> bool:
+        path = self._get_summary_path(summary_id, scope)
+        if path.exists():
+            path.unlink()
+            return True
+        return False
+
+    # Embedding operations (using VectorStore)
+    async def save_embedding(self, embedding: Embedding, scope: Scope = Scope.PROJECT) -> None:
+        if not self._vector_store:
+            raise NotImplementedError("No vector store configured")
+        await self._vector_store.add(
+            id=str(embedding.id),
+            vector=embedding.vector,
+            content=embedding.content,
+            metadata={
+                "content_type": embedding.content_type,
+                "source_id": str(embedding.source_id) if embedding.source_id else None,
+                "model": embedding.model,
+                "dimensions": embedding.dimensions,
+                **embedding.metadata,
+            },
+        )
+
+    async def search_similar(
+        self,
+        query_vector: list[float],
+        limit: int = 10,
+        scope: Scope = Scope.PROJECT,
+        content_type: Optional[str] = None,
+    ) -> list[tuple[Embedding, float]]:
+        if not self._vector_store:
+            raise NotImplementedError("No vector store configured")
+
+        filter_dict = {"content_type": content_type} if content_type else None
+        results = await self._vector_store.search(query_vector, limit=limit, filter=filter_dict)
+
+        return [
+            (
+                Embedding(
+                    id=_parse_uuid(r.id),
+                    vector=[],  # Don't return full vector
+                    content=r.content,
+                    content_type=r.metadata.get("content_type", "text"),
+                    source_id=_parse_uuid(r.metadata["source_id"]) if r.metadata.get("source_id") else None,
+                    model=r.metadata.get("model"),
+                    dimensions=r.metadata.get("dimensions", 0),
+                    created_at=datetime.utcnow(),  # Not stored in vector store
+                    metadata={k: v for k, v in r.metadata.items()
+                              if k not in ("content_type", "source_id", "model", "dimensions")},
+                ),
+                r.score,
+            )
+            for r in results
+        ]
+
+    async def delete_embedding(self, embedding_id: UUID, scope: Scope = Scope.PROJECT) -> bool:
+        if not self._vector_store:
+            raise NotImplementedError("No vector store configured")
+        return await self._vector_store.delete(str(embedding_id))
+
+    async def close(self) -> None:
+        if self._vector_store:
+            await self._vector_store.close()
+
+
+# Type hints for optional imports
+try:
+    from agent_runtime_core.vectorstore.base import VectorStore
+    from agent_runtime_core.vectorstore.embeddings import EmbeddingClient
+except ImportError:
+    VectorStore = None  # type: ignore
+    EmbeddingClient = None  # type: ignore
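For orientation, here is a minimal usage sketch of the new FileKnowledgeStore (an editorial illustration, not part of the diff). It assumes the store lives in agent_runtime_core.persistence.file, consistent with the +277 -0 entry in the file list, and that Fact, FactType, and Scope are importable from agent_runtime_core.persistence.base as the import hunk above indicates; FactType.PREFERENCE is a placeholder, since the enum's members are not visible in this diff:

    import asyncio
    from datetime import datetime
    from pathlib import Path
    from uuid import uuid4

    from agent_runtime_core.persistence.base import Fact, FactType, Scope
    from agent_runtime_core.persistence.file import FileKnowledgeStore

    async def main():
        store = FileKnowledgeStore(project_dir=Path(".agent"))  # no vector store configured

        now = datetime.utcnow()
        fact = Fact(
            id=uuid4(),
            key="favorite_color",
            value="teal",
            fact_type=FactType.PREFERENCE,  # placeholder member name
            confidence=0.9,
            source="conversation",
            created_at=now,
            updated_at=now,
            expires_at=None,
            metadata={},
        )

        await store.save_fact(fact, scope=Scope.PROJECT)        # writes knowledge/facts/<id>.json
        match = await store.get_fact_by_key("favorite_color")   # linear scan over fact files
        recent = await store.list_facts(limit=10)               # newest updated_at first
        print(match.value if match else None, len(recent))

    asyncio.run(main())

Note that the embedding operations (save_embedding, search_similar, delete_embedding) raise NotImplementedError unless a VectorStore is passed to the constructor, and that get_fact_by_key scans every fact file in the scope.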
agent_runtime_core/rag/__init__.py (new file)

@@ -0,0 +1,65 @@
+"""
+RAG (Retrieval Augmented Generation) module for agent_runtime_core.
+
+This module provides portable RAG services that work without Django:
+- KnowledgeIndexer: Service to chunk, embed, and store knowledge in vector stores
+- KnowledgeRetriever: Service to retrieve relevant knowledge at runtime
+- Text chunking utilities
+
+Example usage:
+    from agent_runtime_core.rag import (
+        KnowledgeIndexer,
+        KnowledgeRetriever,
+        chunk_text,
+        ChunkingConfig,
+    )
+    from agent_runtime_core.vectorstore import get_vector_store, get_embedding_client
+
+    # Setup
+    vector_store = get_vector_store("sqlite_vec", path="./vectors.db")
+    embedding_client = get_embedding_client("openai")
+
+    # Index content
+    indexer = KnowledgeIndexer(vector_store, embedding_client)
+    await indexer.index_text(
+        text="Your knowledge content here...",
+        source_id="doc-1",
+        metadata={"name": "My Document"},
+    )
+
+    # Retrieve at runtime
+    retriever = KnowledgeRetriever(vector_store, embedding_client)
+    results = await retriever.retrieve(
+        query="What is the return policy?",
+        top_k=5,
+    )
+"""
+
+from agent_runtime_core.rag.chunking import (
+    chunk_text,
+    ChunkingConfig,
+    TextChunk,
+)
+
+
+# Lazy imports to avoid circular dependencies
+def __getattr__(name: str):
+    if name == "KnowledgeIndexer":
+        from agent_runtime_core.rag.indexer import KnowledgeIndexer
+        return KnowledgeIndexer
+    elif name == "KnowledgeRetriever":
+        from agent_runtime_core.rag.retriever import KnowledgeRetriever
+        return KnowledgeRetriever
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+__all__ = [
+    # Chunking
+    "chunk_text",
+    "ChunkingConfig",
+    "TextChunk",
+    # Services
+    "KnowledgeIndexer",
+    "KnowledgeRetriever",
+]
+
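The module-level __getattr__ above is the lazy-import hook from PEP 562 (Python 3.7+): KnowledgeIndexer and KnowledgeRetriever are imported only on first attribute access, which avoids the circular-import problem the comment mentions. A generic, self-contained sketch of the same pattern (module and class names here are illustrative, not from this package):

    # mypkg/__init__.py -- minimal PEP 562 lazy-import sketch
    __all__ = ["HeavyService"]

    def __getattr__(name: str):
        # Invoked only when normal module attribute lookup fails.
        if name == "HeavyService":
            from mypkg.heavy import HeavyService  # deferred until first access
            globals()[name] = HeavyService  # optional: cache so the hook runs once
            return HeavyService
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

The package's version does not cache the result in globals(), so each access re-enters the hook; the underlying import is still served from sys.modules, so the repeated cost is small.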
agent_runtime_core/rag/chunking.py (new file)

@@ -0,0 +1,224 @@
+"""
+Text chunking utilities for RAG.
+
+Provides functions to split text into chunks suitable for embedding and retrieval.
+This module has no external dependencies and can be used standalone.
+"""
+
+from dataclasses import dataclass
+from typing import Optional
+import re
+
+
+@dataclass
+class ChunkingConfig:
+    """Configuration for text chunking."""
+
+    chunk_size: int = 500
+    """Target size of each chunk in characters."""
+
+    chunk_overlap: int = 50
+    """Number of characters to overlap between chunks."""
+
+    separator: str = "\n\n"
+    """Primary separator to split on (paragraphs by default)."""
+
+    fallback_separators: list[str] = None
+    """Fallback separators if primary doesn't work: ["\n", ". ", " "]"""
+
+    min_chunk_size: int = 100
+    """Minimum chunk size - smaller chunks are merged with neighbors."""
+
+    def __post_init__(self):
+        if self.fallback_separators is None:
+            self.fallback_separators = ["\n", ". ", " "]
+
+
+@dataclass
+class TextChunk:
+    """A chunk of text with metadata."""
+
+    text: str
+    """The chunk text content."""
+
+    index: int
+    """Index of this chunk (0-based)."""
+
+    start_char: int
+    """Starting character position in original text."""
+
+    end_char: int
+    """Ending character position in original text."""
+
+    metadata: dict = None
+    """Optional metadata for this chunk."""
+
+    def __post_init__(self):
+        if self.metadata is None:
+            self.metadata = {}
+
+
+def chunk_text(
+    text: str,
+    config: Optional[ChunkingConfig] = None,
+    metadata: Optional[dict] = None,
+) -> list[TextChunk]:
+    """
+    Split text into chunks suitable for embedding.
+
+    Uses a recursive approach:
+    1. Try to split on primary separator (paragraphs)
+    2. If chunks are too large, split on fallback separators
+    3. Merge small chunks with neighbors
+    4. Add overlap between chunks
+
+    Args:
+        text: The text to chunk
+        config: Chunking configuration (uses defaults if not provided)
+        metadata: Optional metadata to include in each chunk
+
+    Returns:
+        List of TextChunk objects
+    """
+    if config is None:
+        config = ChunkingConfig()
+
+    if metadata is None:
+        metadata = {}
+
+    if not text or not text.strip():
+        return []
+
+    # Normalize whitespace
+    text = text.strip()
+
+    # Split into initial segments using primary separator
+    segments = _split_on_separator(text, config.separator)
+
+    # Recursively split segments that are too large
+    all_separators = [config.separator] + config.fallback_separators
+    segments = _recursive_split(segments, all_separators, config.chunk_size)
+
+    # Merge small segments
+    segments = _merge_small_segments(segments, config.min_chunk_size, config.chunk_size)
+
+    # Create chunks with overlap
+    chunks = _create_chunks_with_overlap(
+        segments,
+        config.chunk_overlap,
+        text,
+        metadata,
+    )
+
+    return chunks
+
+
+def _split_on_separator(text: str, separator: str) -> list[str]:
+    """Split text on a separator, keeping non-empty segments."""
+    if separator == ". ":
+        # Special handling for sentence splitting - keep the period
+        parts = re.split(r'(?<=\.)\s+', text)
+    else:
+        parts = text.split(separator)
+    return [p.strip() for p in parts if p.strip()]
+
+
+def _recursive_split(
+    segments: list[str],
+    separators: list[str],
+    max_size: int,
+    sep_index: int = 0,
+) -> list[str]:
+    """Recursively split segments that are too large."""
+    if sep_index >= len(separators):
+        # No more separators - just return as is (will be split by character if needed)
+        return segments
+
+    result = []
+    separator = separators[sep_index]
+
+    for segment in segments:
+        if len(segment) <= max_size:
+            result.append(segment)
+        else:
+            # Try to split on this separator
+            sub_segments = _split_on_separator(segment, separator)
+            if len(sub_segments) > 1:
+                # Recursively process sub-segments
+                result.extend(_recursive_split(sub_segments, separators, max_size, sep_index))
+            else:
+                # This separator didn't help, try next one
+                result.extend(_recursive_split([segment], separators, max_size, sep_index + 1))
+
+    return result
+
+
+def _merge_small_segments(
+    segments: list[str],
+    min_size: int,
+    max_size: int,
+) -> list[str]:
+    """Merge segments that are too small with their neighbors."""
+    if not segments:
+        return []
+
+    result = []
+    current = segments[0]
+
+    for segment in segments[1:]:
+        combined = current + "\n\n" + segment
+        if len(current) < min_size and len(combined) <= max_size:
+            # Merge with current
+            current = combined
+        else:
+            # Save current and start new
+            result.append(current)
+            current = segment
+
+    result.append(current)
+    return result
+
+
+def _create_chunks_with_overlap(
+    segments: list[str],
+    overlap: int,
+    original_text: str,
+    base_metadata: dict,
+) -> list[TextChunk]:
+    """Create TextChunk objects with overlap between chunks."""
+    chunks = []
+    current_pos = 0
+
+    for i, segment in enumerate(segments):
+        # Find the actual position in original text
+        start_pos = original_text.find(segment[:50], current_pos)
+        if start_pos == -1:
+            start_pos = current_pos
+
+        end_pos = start_pos + len(segment)
+
+        # Add overlap from previous chunk if not first
+        if i > 0 and overlap > 0:
+            # Get overlap text from end of previous segment
+            prev_segment = segments[i - 1]
+            overlap_text = prev_segment[-overlap:] if len(prev_segment) > overlap else prev_segment
+            segment_with_overlap = overlap_text + " " + segment
+        else:
+            segment_with_overlap = segment
+
+        chunk = TextChunk(
+            text=segment_with_overlap,
+            index=i,
+            start_char=start_pos,
+            end_char=end_pos,
+            metadata={
+                **base_metadata,
+                "chunk_index": i,
+                "total_chunks": len(segments),
+            },
+        )
+        chunks.append(chunk)
+        current_pos = end_pos
+
+    return chunks
+
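A quick, self-contained exercise of the new chunker (the sample text and parameter values are arbitrary; everything imported below is defined in the file above):

    from agent_runtime_core.rag.chunking import ChunkingConfig, chunk_text

    # Six ~230-character paragraphs, so each one exceeds chunk_size and is
    # re-split on sentence boundaries, then small pieces are merged back up.
    text = "\n\n".join(("Sentence number %d. " % i) * 12 for i in range(6))
    config = ChunkingConfig(chunk_size=200, chunk_overlap=20, min_chunk_size=50)

    for chunk in chunk_text(text, config=config, metadata={"source": "demo"}):
        # chunk.metadata is {"source": "demo", "chunk_index": i, "total_chunks": n}
        print(chunk.index, chunk.start_char, chunk.end_char, len(chunk.text))

One behavior worth noting: start_char and end_char locate the segment before the overlap is prepended, so chunk.text can be longer than end_char - start_char by up to chunk_overlap + 1 characters.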