chuk-ai-session-manager 0.7.1__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- chuk_ai_session_manager/__init__.py +84 -40
- chuk_ai_session_manager/api/__init__.py +1 -1
- chuk_ai_session_manager/api/simple_api.py +53 -59
- chuk_ai_session_manager/exceptions.py +31 -17
- chuk_ai_session_manager/guards/__init__.py +118 -0
- chuk_ai_session_manager/guards/bindings.py +217 -0
- chuk_ai_session_manager/guards/cache.py +163 -0
- chuk_ai_session_manager/guards/manager.py +819 -0
- chuk_ai_session_manager/guards/models.py +498 -0
- chuk_ai_session_manager/guards/ungrounded.py +159 -0
- chuk_ai_session_manager/infinite_conversation.py +86 -79
- chuk_ai_session_manager/memory/__init__.py +247 -0
- chuk_ai_session_manager/memory/artifacts_bridge.py +469 -0
- chuk_ai_session_manager/memory/context_packer.py +347 -0
- chuk_ai_session_manager/memory/fault_handler.py +507 -0
- chuk_ai_session_manager/memory/manifest.py +307 -0
- chuk_ai_session_manager/memory/models.py +1084 -0
- chuk_ai_session_manager/memory/mutation_log.py +186 -0
- chuk_ai_session_manager/memory/pack_cache.py +206 -0
- chuk_ai_session_manager/memory/page_table.py +275 -0
- chuk_ai_session_manager/memory/prefetcher.py +192 -0
- chuk_ai_session_manager/memory/tlb.py +247 -0
- chuk_ai_session_manager/memory/vm_prompts.py +238 -0
- chuk_ai_session_manager/memory/working_set.py +574 -0
- chuk_ai_session_manager/models/__init__.py +21 -9
- chuk_ai_session_manager/models/event_source.py +3 -1
- chuk_ai_session_manager/models/event_type.py +10 -1
- chuk_ai_session_manager/models/session.py +103 -68
- chuk_ai_session_manager/models/session_event.py +69 -68
- chuk_ai_session_manager/models/session_metadata.py +9 -10
- chuk_ai_session_manager/models/session_run.py +21 -22
- chuk_ai_session_manager/models/token_usage.py +76 -76
- chuk_ai_session_manager/procedural_memory/__init__.py +70 -0
- chuk_ai_session_manager/procedural_memory/formatter.py +407 -0
- chuk_ai_session_manager/procedural_memory/manager.py +523 -0
- chuk_ai_session_manager/procedural_memory/models.py +371 -0
- chuk_ai_session_manager/sample_tools.py +79 -46
- chuk_ai_session_manager/session_aware_tool_processor.py +27 -16
- chuk_ai_session_manager/session_manager.py +259 -232
- chuk_ai_session_manager/session_prompt_builder.py +163 -111
- chuk_ai_session_manager/session_storage.py +45 -52
- {chuk_ai_session_manager-0.7.1.dist-info → chuk_ai_session_manager-0.8.1.dist-info}/METADATA +80 -4
- chuk_ai_session_manager-0.8.1.dist-info/RECORD +45 -0
- {chuk_ai_session_manager-0.7.1.dist-info → chuk_ai_session_manager-0.8.1.dist-info}/WHEEL +1 -1
- chuk_ai_session_manager-0.7.1.dist-info/RECORD +0 -22
- {chuk_ai_session_manager-0.7.1.dist-info → chuk_ai_session_manager-0.8.1.dist-info}/top_level.txt +0 -0
chuk_ai_session_manager/memory/fault_handler.py (new file)

@@ -0,0 +1,507 @@

```python
# chuk_ai_session_manager/memory/fault_handler.py
"""
Page Fault Handler for AI Virtual Memory.

The PageFaultHandler resolves requests for pages not currently in L0.
When the model calls page_fault(page_id, target_level), this handler:
1. Looks up the page location
2. Loads from the appropriate tier
3. Compresses to the requested level
4. Returns the canonical tool result envelope

Design principles:
- Async-native: All I/O operations are async
- Pydantic-native: All models are BaseModel subclasses
- No magic strings: Uses enums for all categorical values
- Metrics-aware: Tracks fault counts and latencies
"""

import time
from typing import Callable, Dict, List, Optional, Protocol

from pydantic import BaseModel, Field, PrivateAttr

from .models import (
    ALL_COMPRESSION_LEVELS,
    AudioContent,
    CompressionLevel,
    FaultEffects,
    FaultMetrics,
    ImageContent,
    MemoryPage,
    Modality,
    PageData,
    PageMeta,
    SearchResultEntry,
    StorageTier,
    StructuredContent,
    TextContent,
    VideoContent,
    VMMetrics,
)
from .page_table import PageTable
from .tlb import PageTLB


class PageLoader(Protocol):
    """Protocol for loading pages from storage tiers."""

    async def load(
        self,
        page_id: str,
        tier: StorageTier,
        artifact_id: Optional[str] = None,
    ) -> Optional[MemoryPage]:
        """Load a page from storage."""
        ...


class PageCompressor(Protocol):
    """Protocol for compressing pages to different levels."""

    async def compress(
        self,
        page: MemoryPage,
        target_level: CompressionLevel,
    ) -> MemoryPage:
        """Compress a page to the target level."""
        ...


class FaultResult(BaseModel):
    """Result of a page fault resolution."""

    success: bool
    page: Optional[MemoryPage] = None
    error: Optional[str] = None
    source_tier: Optional[StorageTier] = None
    latency_ms: float = 0.0
    was_compressed: bool = False


class VMToolResult(BaseModel):
    """
    Canonical envelope for VM tool results.

    This is the format returned to the model via role="tool" messages.
    """

    page: PageData
    effects: FaultEffects = Field(default_factory=FaultEffects)

    def to_json(self) -> str:
        """Serialize to JSON for tool response."""
        return self.model_dump_json()


class VMToolError(BaseModel):
    """Error response for VM tool calls."""

    error: str
    page_id: Optional[str] = None


class PageFaultHandler(BaseModel):
    """
    Handles page fault resolution.

    When the model needs a page that's not in the working set,
    this handler fetches it from the appropriate tier.
    """

    # Dependencies (set after construction)
    page_table: Optional[PageTable] = None
    tlb: Optional[PageTLB] = None

    # Page storage (maps page_id -> MemoryPage for L2+ pages)
    # In production, this would be replaced by ArtifactsBridge
    page_store: Dict[str, MemoryPage] = Field(default_factory=dict)

    # Metrics
    metrics: VMMetrics = Field(default_factory=VMMetrics)

    # Configuration
    max_faults_per_turn: int = Field(
        default=2, description="Maximum faults allowed per turn"
    )
    faults_this_turn: int = Field(default=0, description="Faults issued this turn")

    # Optional async loader/compressor (private attrs)
    _loader: Optional[PageLoader] = PrivateAttr(default=None)
    _compressor: Optional[PageCompressor] = PrivateAttr(default=None)

    model_config = {"arbitrary_types_allowed": True}

    def configure(
        self,
        page_table: PageTable,
        tlb: Optional[PageTLB] = None,
        loader: Optional[PageLoader] = None,
        compressor: Optional[PageCompressor] = None,
    ) -> None:
        """Configure the handler with dependencies."""
        self.page_table = page_table
        self.tlb = tlb
        self._loader = loader
        self._compressor = compressor

    def new_turn(self) -> None:
        """Reset per-turn counters."""
        self.faults_this_turn = 0
        self.metrics.new_turn()

    def can_fault(self) -> bool:
        """Check if more faults are allowed this turn."""
        return self.faults_this_turn < self.max_faults_per_turn

    async def handle_fault(
        self,
        page_id: str,
        target_level: int = 2,
    ) -> FaultResult:
        """
        Handle a page fault request.

        Args:
            page_id: ID of the page to load
            target_level: Compression level (0=full, 1=reduced, 2=abstract, 3=ref)

        Returns:
            FaultResult with the loaded page or error
        """
        start_time = time.time()

        # Check fault limit
        if not self.can_fault():
            return FaultResult(
                success=False,
                error=f"Fault limit exceeded ({self.max_faults_per_turn} per turn)",
                latency_ms=(time.time() - start_time) * 1000,
            )

        # Ensure we have a page table
        if self.page_table is None:
            return FaultResult(
                success=False,
                error="PageFaultHandler not configured",
                latency_ms=(time.time() - start_time) * 1000,
            )

        # Look up page entry (TLB first, then page table)
        entry = None
        if self.tlb:
            entry = self.tlb.lookup(page_id)
            if entry:
                self.metrics.record_tlb_hit()
            else:
                self.metrics.record_tlb_miss()
                entry = self.page_table.lookup(page_id)
                if entry:
                    self.tlb.insert(entry)
        else:
            entry = self.page_table.lookup(page_id)

        if not entry:
            return FaultResult(
                success=False,
                error=f"Page not found: {page_id}",
                latency_ms=(time.time() - start_time) * 1000,
            )

        # Load the page
        page = await self._load_page(page_id, entry.tier, entry.artifact_id)
        if not page:
            return FaultResult(
                success=False,
                error=f"Failed to load page from {entry.tier.value}",
                source_tier=entry.tier,
                latency_ms=(time.time() - start_time) * 1000,
            )

        # Compress if needed
        target = CompressionLevel(target_level)
        was_compressed = False
        if page.compression_level != target:
            page = await self._compress_page(page, target)
            was_compressed = True

        # Update metrics
        self.faults_this_turn += 1
        self.metrics.record_fault()
        self.page_table.mark_accessed(page_id)

        latency = (time.time() - start_time) * 1000
        return FaultResult(
            success=True,
            page=page,
            source_tier=entry.tier,
            latency_ms=latency,
            was_compressed=was_compressed,
        )

    async def _load_page(
        self,
        page_id: str,
        tier: StorageTier,
        artifact_id: Optional[str],
    ) -> Optional[MemoryPage]:
        """Load a page from storage."""
        # Use custom loader if available
        if self._loader:
            return await self._loader.load(page_id, tier, artifact_id)

        # Default: check in-memory store
        return self.page_store.get(page_id)

    async def _compress_page(
        self,
        page: MemoryPage,
        target_level: CompressionLevel,
    ) -> MemoryPage:
        """Compress a page to the target level."""
        # Use custom compressor if available
        if self._compressor:
            return await self._compressor.compress(page, target_level)

        # Default: stub compression (just update level, don't transform)
        page.compression_level = target_level
        return page

    def store_page(self, page: MemoryPage) -> None:
        """Store a page in the local page store (for testing/simple usage)."""
        self.page_store[page.page_id] = page

    def build_tool_result(
        self,
        fault_result: FaultResult,
        evictions: Optional[List[str]] = None,
    ) -> VMToolResult:
        """
        Build the canonical tool result envelope for a fault result.

        This is what gets returned to the model in the tool response.
        """
        if not fault_result.success or not fault_result.page:
            # Return error as a minimal PageData
            return VMToolResult(
                page=PageData(
                    page_id="error",
                    modality=Modality.TEXT.value,
                    level=0,
                    tier=StorageTier.L0.value,
                    content=TextContent(text=fault_result.error or "Unknown error"),
                    meta=PageMeta(),
                ),
                effects=FaultEffects(promoted_to_working_set=False),
            )

        page = fault_result.page
        content = self._format_content_for_modality(page)
        meta = self._build_meta(page, fault_result)

        return VMToolResult(
            page=PageData(
                page_id=page.page_id,
                modality=page.modality.value,
                level=page.compression_level
                if isinstance(page.compression_level, int)
                else page.compression_level.value,
                tier=StorageTier.L1.value,  # Promoted to L1 after fault
                content=content,
                meta=meta,
            ),
            effects=FaultEffects(
                promoted_to_working_set=True,
                tokens_est=page.size_tokens or page.estimate_tokens(),
                evictions=evictions,
            ),
        )

    def _format_content_for_modality(
        self, page: MemoryPage
    ) -> TextContent | ImageContent | AudioContent | VideoContent | StructuredContent:
        """Format page content based on modality."""
        if page.modality == Modality.TEXT:
            return TextContent(text=page.content or "")

        elif page.modality == Modality.IMAGE:
            result = ImageContent()
            if page.caption:
                result.caption = page.caption
            if page.content and isinstance(page.content, str):
                if page.content.startswith("http"):
                    result.url = page.content
                elif page.content.startswith("data:"):
                    result.base64 = page.content
                else:
                    result.caption = page.content
            return result

        elif page.modality == Modality.AUDIO:
            result = AudioContent()
            if page.transcript:
                result.transcript = page.transcript
            elif page.content and isinstance(page.content, str):
                result.transcript = page.content
            if page.duration_seconds:
                result.duration_seconds = page.duration_seconds
            return result

        elif page.modality == Modality.VIDEO:
            result = VideoContent()
            if page.transcript:
                result.transcript = page.transcript
            if page.duration_seconds:
                result.duration_seconds = page.duration_seconds
            scenes = page.metadata.get("scenes", [])
            if scenes:
                result.scenes = scenes
            return result

        elif page.modality == Modality.STRUCTURED:
            return StructuredContent(
                data=page.content if isinstance(page.content, dict) else {}
            )

        else:
            return TextContent(text=str(page.content) if page.content else "")

    def _build_meta(
        self,
        page: MemoryPage,
        fault_result: FaultResult,
    ) -> PageMeta:
        """Build metadata for tool result."""
        meta = PageMeta(
            source_tier=fault_result.source_tier.value
            if fault_result.source_tier
            else "unknown",
        )

        if page.mime_type:
            meta.mime_type = page.mime_type

        if page.size_bytes:
            meta.size_bytes = page.size_bytes

        if page.dimensions:
            meta.dimensions = list(page.dimensions)

        if page.duration_seconds:
            meta.duration_seconds = page.duration_seconds

        if fault_result.latency_ms:
            meta.latency_ms = round(fault_result.latency_ms, 2)

        return meta

    def get_metrics(self) -> FaultMetrics:
        """Get fault handler metrics."""
        return FaultMetrics(
            faults_this_turn=self.faults_this_turn,
            max_faults_per_turn=self.max_faults_per_turn,
            faults_remaining=self.max_faults_per_turn - self.faults_this_turn,
            total_faults=self.metrics.faults_total,
            tlb_hit_rate=self.metrics.tlb_hit_rate,
        )


class SearchResult(BaseModel):
    """Result of a page search operation."""

    results: List[SearchResultEntry] = Field(default_factory=list)
    total_available: int = 0

    def to_json(self) -> str:
        """Serialize to JSON for tool response."""
        return self.model_dump_json()


class PageSearchHandler(BaseModel):
    """
    Handles search_pages tool calls.

    Searches available pages by query and returns metadata (not content).
    """

    page_table: Optional[PageTable] = None

    # Optional search function
    _search_fn: Optional[Callable] = PrivateAttr(default=None)

    # Page hints (for simple text search)
    page_hints: Dict[str, str] = Field(default_factory=dict)

    model_config = {"arbitrary_types_allowed": True}

    def configure(
        self,
        page_table: PageTable,
        search_fn: Optional[Callable] = None,
    ) -> None:
        """Configure the search handler."""
        self.page_table = page_table
        self._search_fn = search_fn

    def set_hint(self, page_id: str, hint: str) -> None:
        """Set a search hint for a page."""
        self.page_hints[page_id] = hint

    async def search(
        self,
        query: str,
        modality: Optional[str] = None,
        limit: int = 5,
    ) -> SearchResult:
        """
        Search for pages matching a query.

        Args:
            query: Search query (keyword or semantic)
            modality: Optional filter by modality
            limit: Maximum results

        Returns:
            SearchResult with matching page metadata
        """
        if self.page_table is None:
            return SearchResult(results=[], total_available=0)

        # Use custom search if available
        if self._search_fn:
            return await self._search_fn(query, modality, limit)

        # Default: simple text matching on hints
        results: List[SearchResultEntry] = []
        query_lower = query.lower()

        for page_id, entry in self.page_table.entries.items():
            # Filter by modality if specified
            if modality and entry.modality.value != modality:
                continue

            # Check hint for match
            hint = self.page_hints.get(page_id, "")
            if query_lower in hint.lower() or query_lower in page_id.lower():
                relevance = 1.0 if query_lower in page_id.lower() else 0.8
                results.append(
                    SearchResultEntry(
                        page_id=page_id,
                        modality=entry.modality.value,
                        tier=entry.tier.value,
                        levels=ALL_COMPRESSION_LEVELS,
                        hint=hint,
                        relevance=relevance,
                    )
                )

                if len(results) >= limit:
                    break

        # Sort by relevance
        results.sort(key=lambda x: x.relevance, reverse=True)

        return SearchResult(
            results=results[:limit],
            total_available=len(self.page_table.entries),
        )
```
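For orientation, here is a minimal, hypothetical sketch of how the new handler might be wired up and driven for one turn. It only uses names that appear in this hunk (`PageFaultHandler`, `configure`, `new_turn`, `can_fault`, `handle_fault`, `build_tool_result`, and the `PageLoader` protocol). How `PageTable` and `PageTLB` are constructed and how pages get registered in the page table live in `page_table.py` and `tlb.py`, which are not shown here, so the default-constructed instances below and the `ArchiveLoader` class are assumptions for illustration only.

```python
# Hypothetical usage sketch, not taken from the package docs. Assumes PageTable()
# and PageTLB() can be built with defaults; because no page is registered in the
# page table here, the fault below resolves to a "Page not found" error envelope,
# which still demonstrates the call pattern.
import asyncio
from typing import Optional

from chuk_ai_session_manager.memory.fault_handler import PageFaultHandler
from chuk_ai_session_manager.memory.models import MemoryPage, StorageTier
from chuk_ai_session_manager.memory.page_table import PageTable
from chuk_ai_session_manager.memory.tlb import PageTLB


class ArchiveLoader:
    """Example PageLoader implementation: would fetch cold pages from storage."""

    async def load(
        self,
        page_id: str,
        tier: StorageTier,
        artifact_id: Optional[str] = None,
    ) -> Optional[MemoryPage]:
        return None  # placeholder; a real loader would read the artifact here


async def demo() -> None:
    handler = PageFaultHandler(max_faults_per_turn=2)
    handler.configure(page_table=PageTable(), tlb=PageTLB(), loader=ArchiveLoader())

    handler.new_turn()  # reset the per-turn fault budget
    if handler.can_fault():
        result = await handler.handle_fault("doc-0042", target_level=2)
        envelope = handler.build_tool_result(result)
        # This JSON is what goes back to the model in the role="tool" message.
        print(envelope.to_json())


asyncio.run(demo())
```

Note that `handle_fault` only succeeds for pages the page table already knows about, so real usage also registers pages there (see `page_table.py` in this release).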