RouteKitAI 0.1.0 (routekitai-0.1.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- routekitai/__init__.py +53 -0
- routekitai/cli/__init__.py +18 -0
- routekitai/cli/main.py +40 -0
- routekitai/cli/replay.py +80 -0
- routekitai/cli/run.py +95 -0
- routekitai/cli/serve.py +966 -0
- routekitai/cli/test_agent.py +178 -0
- routekitai/cli/trace.py +209 -0
- routekitai/cli/trace_analyze.py +120 -0
- routekitai/cli/trace_search.py +126 -0
- routekitai/core/__init__.py +58 -0
- routekitai/core/agent.py +325 -0
- routekitai/core/errors.py +49 -0
- routekitai/core/hooks.py +174 -0
- routekitai/core/memory.py +54 -0
- routekitai/core/message.py +132 -0
- routekitai/core/model.py +91 -0
- routekitai/core/policies.py +373 -0
- routekitai/core/policy.py +85 -0
- routekitai/core/policy_adapter.py +133 -0
- routekitai/core/runtime.py +1403 -0
- routekitai/core/tool.py +148 -0
- routekitai/core/tools.py +180 -0
- routekitai/evals/__init__.py +13 -0
- routekitai/evals/dataset.py +75 -0
- routekitai/evals/metrics.py +101 -0
- routekitai/evals/runner.py +184 -0
- routekitai/graphs/__init__.py +12 -0
- routekitai/graphs/executors.py +457 -0
- routekitai/graphs/graph.py +164 -0
- routekitai/memory/__init__.py +13 -0
- routekitai/memory/episodic.py +242 -0
- routekitai/memory/kv.py +34 -0
- routekitai/memory/retrieval.py +192 -0
- routekitai/memory/vector.py +700 -0
- routekitai/memory/working.py +66 -0
- routekitai/message.py +29 -0
- routekitai/model.py +48 -0
- routekitai/observability/__init__.py +21 -0
- routekitai/observability/analyzer.py +314 -0
- routekitai/observability/exporters/__init__.py +10 -0
- routekitai/observability/exporters/base.py +30 -0
- routekitai/observability/exporters/jsonl.py +81 -0
- routekitai/observability/exporters/otel.py +119 -0
- routekitai/observability/spans.py +111 -0
- routekitai/observability/streaming.py +117 -0
- routekitai/observability/trace.py +144 -0
- routekitai/providers/__init__.py +9 -0
- routekitai/providers/anthropic.py +227 -0
- routekitai/providers/azure_openai.py +243 -0
- routekitai/providers/local.py +196 -0
- routekitai/providers/openai.py +321 -0
- routekitai/py.typed +0 -0
- routekitai/sandbox/__init__.py +12 -0
- routekitai/sandbox/filesystem.py +131 -0
- routekitai/sandbox/network.py +142 -0
- routekitai/sandbox/permissions.py +70 -0
- routekitai/tool.py +33 -0
- routekitai-0.1.0.dist-info/METADATA +328 -0
- routekitai-0.1.0.dist-info/RECORD +64 -0
- routekitai-0.1.0.dist-info/WHEEL +5 -0
- routekitai-0.1.0.dist-info/entry_points.txt +2 -0
- routekitai-0.1.0.dist-info/licenses/LICENSE +21 -0
- routekitai-0.1.0.dist-info/top_level.txt +1 -0
routekitai/memory/episodic.py
ADDED
@@ -0,0 +1,242 @@
"""Episodic memory with SQLite backend."""

import json
import sqlite3
import time
import uuid
from pathlib import Path
from typing import Any

from routekitai.core.memory import Memory


class EpisodicMemory(Memory):
    """Episodic memory with SQLite-backed persistent store.

    Stores episodes in a SQLite database with:
    - id: Unique episode ID
    - ts: Timestamp
    - content: Episode content (JSON)
    - metadata: Additional metadata (JSON)
    """

    def __init__(self, db_path: Path | str | None = None) -> None:
        """Initialize episodic memory.

        Args:
            db_path: Path to SQLite database file (defaults to .routekit/episodic.db)
        """
        if db_path is None:
            db_path = Path(".routekit") / "episodic.db"
        elif isinstance(db_path, str):
            db_path = Path(db_path)

        self.db_path = db_path
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    def _init_db(self) -> None:
        """Initialize database schema."""
        with sqlite3.connect(self.db_path) as conn:
            # Enable WAL mode for better concurrent access and file locking on Windows
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS episodes (
                    id TEXT PRIMARY KEY,
                    ts REAL NOT NULL,
                    content TEXT NOT NULL,
                    metadata TEXT NOT NULL
                )
                """
            )
            conn.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_ts ON episodes(ts)
                """
            )
            conn.commit()

    async def get(self, key: str) -> Any:
        """Get episode by ID.

        Args:
            key: Episode ID

        Returns:
            Episode data or None if not found
        """
        import asyncio

        # SQLite operations are synchronous, but we need async interface
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._get_sync, key)

    def _get_sync(self, key: str) -> Any:
        """Synchronous get operation."""
        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute(
                "SELECT id, ts, content, metadata FROM episodes WHERE id = ?",
                (key,),
            )
            row = cursor.fetchone()
            if row:
                return {
                    "id": row["id"],
                    "ts": row["ts"],
                    "content": json.loads(row["content"]),
                    "metadata": json.loads(row["metadata"]),
                }
            return None

    async def set(self, key: str, value: Any) -> None:
        """Store episode by ID.

        Args:
            key: Episode ID
            value: Episode data (dict with content and optional metadata)
        """
        import asyncio

        loop = asyncio.get_event_loop()
        await loop.run_in_executor(None, self._set_sync, key, value)

    def _set_sync(self, key: str, value: Any) -> None:
        """Synchronous set operation."""
        if not isinstance(value, dict):
            value = {"content": value}

        content = value.get("content", {})
        metadata = value.get("metadata", {})

        with sqlite3.connect(self.db_path) as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO episodes (id, ts, content, metadata)
                VALUES (?, ?, ?, ?)
                """,
                (key, time.time(), json.dumps(content), json.dumps(metadata)),
            )
            conn.commit()

    async def append(self, event: dict[str, Any]) -> None:
        """Append an event as a new episode.

        Args:
            event: Event dictionary to append
        """
        import asyncio

        episode_id = str(uuid.uuid4())
        content = event.get("content", event)
        metadata = event.get("metadata", {})

        loop = asyncio.get_event_loop()
        await loop.run_in_executor(
            None, self._set_sync, episode_id, {"content": content, "metadata": metadata}
        )

    async def search(self, query: str, k: int = 5) -> list[dict[str, Any]]:
        """Search episodes by content (simple substring search for MVP).

        Args:
            query: Search query
            k: Number of results to return

        Returns:
            List of matching episodes
        """
        import asyncio

        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._search_sync, query, k)

    def _search_sync(self, query: str, k: int) -> list[dict[str, Any]]:
        """Synchronous search operation."""
        results = []
        query_lower = query.lower()

        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute("SELECT id, ts, content, metadata FROM episodes ORDER BY ts DESC")
            for row in cursor:
                content_str = json.dumps(row["content"]).lower()
                if query_lower in content_str:
                    results.append(
                        {
                            "id": row["id"],
                            "ts": row["ts"],
                            "content": json.loads(row["content"]),
                            "metadata": json.loads(row["metadata"]),
                        }
                    )
                    if len(results) >= k:
                        break

        return results

    async def get_recent(self, limit: int = 10) -> list[dict[str, Any]]:
        """Get recent episodes.

        Args:
            limit: Maximum number of episodes to return

        Returns:
            List of recent episodes
        """
        import asyncio

        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._get_recent_sync, limit)

    def _get_recent_sync(self, limit: int) -> list[dict[str, Any]]:
        """Synchronous get_recent operation."""
        results = []
        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute(
                "SELECT id, ts, content, metadata FROM episodes ORDER BY ts DESC LIMIT ?",
                (limit,),
            )
            for row in cursor:
                results.append(
                    {
                        "id": row["id"],
                        "ts": row["ts"],
                        "content": json.loads(row["content"]),
                        "metadata": json.loads(row["metadata"]),
                    }
                )
        return results

    def close(self) -> None:
        """Close any open database connections.

        On Windows, SQLite can hold file locks briefly after connections close.
        This method forces SQLite to release locks by opening and closing a connection.
        """
        import gc
        import sys
        import time

        try:
            # Force garbage collection to ensure any lingering connections are cleaned up
            gc.collect()
            # Open and immediately close a connection to ensure locks are released
            # Use a short timeout to avoid hanging
            with sqlite3.connect(str(self.db_path), timeout=1.0) as conn:
                # Checkpoint WAL to ensure all data is written and locks are released
                try:
                    conn.execute("PRAGMA wal_checkpoint(TRUNCATE)")
                except sqlite3.Error:
                    # WAL checkpoint may fail if not in WAL mode, ignore
                    pass
                # Execute a simple query to ensure connection is fully established
                conn.execute("SELECT 1")
                # On Windows, SQLite needs a moment to release file locks
                if sys.platform == "win32":
                    time.sleep(0.05)
        except (sqlite3.Error, OSError, TimeoutError):
            # Ignore errors during cleanup - file may already be closed or locked
            pass
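Not part of the diff: a minimal usage sketch of the EpisodicMemory class above. The import path and method signatures come from this file; the episode payloads and the .routekit/episodic.db default are as shown in the code, while the surrounding script is purely illustrative.

    # usage_episodic.py (illustrative only)
    import asyncio

    from routekitai.memory.episodic import EpisodicMemory


    async def main() -> None:
        memory = EpisodicMemory()  # defaults to .routekit/episodic.db
        # append() assigns a UUID id and timestamps each episode
        await memory.append({"content": {"step": "plan", "text": "draft itinerary"}})
        await memory.append(
            {"content": {"step": "act", "text": "booked the flight"}, "metadata": {"run": 1}}
        )
        # MVP search is a case-insensitive substring match over the stored JSON content
        hits = await memory.search("flight", k=3)
        recent = await memory.get_recent(limit=5)
        print(len(hits), len(recent))
        memory.close()  # releases SQLite file locks, mainly relevant on Windows


    asyncio.run(main())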
routekitai/memory/kv.py
ADDED
@@ -0,0 +1,34 @@
"""Key-value memory for agents."""

from typing import Any

from pydantic import BaseModel, Field


class KVMemory(BaseModel):
    """Key-value memory store for agent state.

    TODO: Implement persistent key-value storage for agent memory.
    """

    store: dict[str, Any] = Field(default_factory=dict, description="In-memory key-value store")

    async def get(self, key: str) -> Any:
        """Get value by key.

        Args:
            key: Key to retrieve

        Returns:
            Stored value or None if not found
        """
        return self.store.get(key)

    async def set(self, key: str, value: Any) -> None:
        """Set value by key.

        Args:
            key: Key to set
            value: Value to store
        """
        self.store[key] = value
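Not part of the diff: a short illustrative snippet for the in-memory KVMemory model above. Only the get/set methods defined in this file are used; the keys and values are invented.

    # usage_kv.py (illustrative only)
    import asyncio

    from routekitai.memory.kv import KVMemory


    async def demo() -> None:
        kv = KVMemory()
        await kv.set("current_task", {"goal": "summarize report"})
        print(await kv.get("current_task"))  # {'goal': 'summarize report'}
        print(await kv.get("missing"))       # None


    asyncio.run(demo())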
routekitai/memory/retrieval.py
ADDED
@@ -0,0 +1,192 @@
"""Retrieval memory with TF-IDF/substring fallback."""

import re
from collections import Counter
from typing import Any

from routekitai.core.memory import Memory


class RetrievalMemory(Memory):
    """Retrieval memory with TF-IDF or substring search fallback.

    For MVP, provides simple text-based retrieval without vector embeddings.
    """

    def __init__(self, use_tfidf: bool = True) -> None:
        """Initialize retrieval memory.

        Args:
            use_tfidf: Whether to use TF-IDF (True) or simple substring search (False)
        """
        self.use_tfidf = use_tfidf
        self._documents: list[dict[str, Any]] = []
        self._idf: dict[str, float] = {}

    async def get(self, key: str) -> Any:
        """Get document by ID.

        Args:
            key: Document ID

        Returns:
            Document data or None if not found
        """
        for doc in self._documents:
            if doc.get("id") == key:
                return doc
        return None

    async def set(self, key: str, value: Any) -> None:
        """Store document by ID.

        Args:
            key: Document ID
            value: Document data
        """
        if isinstance(value, dict):
            doc = value.copy()
            doc["id"] = key
        else:
            doc = {"id": key, "content": value}

        # Update or add document
        for i, existing_doc in enumerate(self._documents):
            if existing_doc.get("id") == key:
                self._documents[i] = doc
                self._update_idf()
                return

        self._documents.append(doc)
        self._update_idf()

    async def append(self, event: dict[str, Any]) -> None:
        """Append an event as a new document.

        Args:
            event: Event dictionary to append
        """
        import uuid

        doc_id = str(uuid.uuid4())
        doc = event.copy()
        doc["id"] = doc_id
        # Ensure content field exists for search (extract from event if needed)
        if "content" not in doc:
            # Try to extract content from event
            if isinstance(event, dict) and "content" in event:
                doc["content"] = event["content"]
            else:
                doc["content"] = str(event)
        self._documents.append(doc)
        self._update_idf()

    async def search(self, query: str, k: int = 5) -> list[dict[str, Any]]:
        """Search documents using TF-IDF or substring matching.

        Args:
            query: Search query
            k: Number of results to return

        Returns:
            List of matching documents with scores
        """
        if not self._documents:
            return []

        if self.use_tfidf:
            return await self._search_tfidf(query, k)
        else:
            return await self._search_substring(query, k)

    async def _search_tfidf(self, query: str, k: int) -> list[dict[str, Any]]:
        """Search using TF-IDF scoring.

        Args:
            query: Search query
            k: Number of results

        Returns:
            List of documents with TF-IDF scores
        """
        query_terms = self._tokenize(query)

        scores = []
        for doc in self._documents:
            content = str(doc.get("content", ""))
            doc_terms = self._tokenize(content)
            doc_tf = Counter(doc_terms)

            score = 0.0
            for term in query_terms:
                if term in doc_tf:
                    tf = doc_tf[term] / len(doc_terms) if doc_terms else 0
                    idf = self._idf.get(term, 0.0)
                    score += tf * idf

            if score > 0:
                result = doc.copy()
                result["score"] = score
                scores.append(result)

        # Sort by score descending
        scores.sort(key=lambda x: x.get("score", 0), reverse=True)
        return scores[:k]

    async def _search_substring(self, query: str, k: int) -> list[dict[str, Any]]:
        """Search using simple substring matching.

        Args:
            query: Search query
            k: Number of results

        Returns:
            List of matching documents
        """
        query_lower = query.lower()
        results = []

        for doc in self._documents:
            content = str(doc.get("content", "")).lower()
            if query_lower in content:
                result = doc.copy()
                result["score"] = 1.0  # Simple binary match
                results.append(result)
                if len(results) >= k:
                    break

        return results

    def _tokenize(self, text: str) -> list[str]:
        """Tokenize text into words.

        Args:
            text: Text to tokenize

        Returns:
            List of lowercase tokens
        """
        # Simple tokenization: lowercase, split on non-word chars
        tokens = re.findall(r"\b\w+\b", text.lower())
        return tokens

    def _update_idf(self) -> None:
        """Update inverse document frequency for all terms."""
        if not self._documents:
            self._idf = {}
            return

        doc_count = len(self._documents)
        term_doc_count: dict[str, int] = {}

        for doc in self._documents:
            content = str(doc.get("content", ""))
            terms = set(self._tokenize(content))
            for term in terms:
                term_doc_count[term] = term_doc_count.get(term, 0) + 1

        # Calculate IDF: log(total_docs / docs_with_term)
        self._idf = {
            term: __import__("math").log(doc_count / count) if count > 0 else 0.0
            for term, count in term_doc_count.items()
        }
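Not part of the diff: a brief sketch of how RetrievalMemory scores a query, using only the methods defined above; the three document strings are invented. With "refund" appearing in one of three documents, its IDF is ln(3/1) ≈ 1.10 and its term frequency in the matching document is 1/8, so that document scores roughly 0.14 and is the only hit returned.

    # usage_retrieval.py (illustrative only)
    import asyncio

    from routekitai.memory.retrieval import RetrievalMemory


    async def demo() -> None:
        mem = RetrievalMemory(use_tfidf=True)
        await mem.set("a", {"content": "customer asked about a refund for order 1234"})
        await mem.set("b", {"content": "shipping delay reported by the customer"})
        await mem.set("c", {"content": "password reset instructions sent"})

        hits = await mem.search("refund policy", k=2)
        for doc in hits:
            # Each hit is a copy of the stored document plus a "score" key from _search_tfidf
            print(doc["id"], round(doc["score"], 3))


    asyncio.run(demo())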