local-deep-research 0.5.9__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in a supported public registry. The information is provided for informational purposes only.
Files changed (90)
  1. local_deep_research/__version__.py +1 -1
  2. local_deep_research/advanced_search_system/candidate_exploration/progressive_explorer.py +11 -1
  3. local_deep_research/advanced_search_system/questions/browsecomp_question.py +32 -6
  4. local_deep_research/advanced_search_system/strategies/focused_iteration_strategy.py +32 -8
  5. local_deep_research/advanced_search_system/strategies/source_based_strategy.py +2 -0
  6. local_deep_research/api/__init__.py +2 -0
  7. local_deep_research/api/research_functions.py +177 -3
  8. local_deep_research/benchmarks/graders.py +150 -5
  9. local_deep_research/benchmarks/models/__init__.py +19 -0
  10. local_deep_research/benchmarks/models/benchmark_models.py +283 -0
  11. local_deep_research/benchmarks/ui/__init__.py +1 -0
  12. local_deep_research/benchmarks/web_api/__init__.py +6 -0
  13. local_deep_research/benchmarks/web_api/benchmark_routes.py +862 -0
  14. local_deep_research/benchmarks/web_api/benchmark_service.py +920 -0
  15. local_deep_research/config/llm_config.py +106 -21
  16. local_deep_research/defaults/default_settings.json +447 -2
  17. local_deep_research/error_handling/report_generator.py +10 -0
  18. local_deep_research/llm/__init__.py +19 -0
  19. local_deep_research/llm/llm_registry.py +155 -0
  20. local_deep_research/metrics/db_models.py +3 -7
  21. local_deep_research/metrics/search_tracker.py +25 -11
  22. local_deep_research/search_system.py +12 -9
  23. local_deep_research/utilities/log_utils.py +23 -10
  24. local_deep_research/utilities/thread_context.py +99 -0
  25. local_deep_research/web/app_factory.py +32 -8
  26. local_deep_research/web/database/benchmark_schema.py +230 -0
  27. local_deep_research/web/database/convert_research_id_to_string.py +161 -0
  28. local_deep_research/web/database/models.py +55 -1
  29. local_deep_research/web/database/schema_upgrade.py +397 -2
  30. local_deep_research/web/database/uuid_migration.py +265 -0
  31. local_deep_research/web/routes/api_routes.py +62 -31
  32. local_deep_research/web/routes/history_routes.py +13 -6
  33. local_deep_research/web/routes/metrics_routes.py +264 -4
  34. local_deep_research/web/routes/research_routes.py +45 -18
  35. local_deep_research/web/routes/route_registry.py +352 -0
  36. local_deep_research/web/routes/settings_routes.py +382 -22
  37. local_deep_research/web/services/research_service.py +22 -29
  38. local_deep_research/web/services/settings_manager.py +53 -0
  39. local_deep_research/web/services/settings_service.py +2 -0
  40. local_deep_research/web/static/css/styles.css +8 -0
  41. local_deep_research/web/static/js/components/detail.js +7 -14
  42. local_deep_research/web/static/js/components/details.js +8 -10
  43. local_deep_research/web/static/js/components/fallback/ui.js +4 -4
  44. local_deep_research/web/static/js/components/history.js +6 -6
  45. local_deep_research/web/static/js/components/logpanel.js +14 -11
  46. local_deep_research/web/static/js/components/progress.js +51 -46
  47. local_deep_research/web/static/js/components/research.js +250 -89
  48. local_deep_research/web/static/js/components/results.js +5 -7
  49. local_deep_research/web/static/js/components/settings.js +32 -26
  50. local_deep_research/web/static/js/components/settings_sync.js +24 -23
  51. local_deep_research/web/static/js/config/urls.js +285 -0
  52. local_deep_research/web/static/js/main.js +8 -8
  53. local_deep_research/web/static/js/research_form.js +267 -12
  54. local_deep_research/web/static/js/services/api.js +18 -18
  55. local_deep_research/web/static/js/services/keyboard.js +8 -8
  56. local_deep_research/web/static/js/services/socket.js +53 -35
  57. local_deep_research/web/static/js/services/ui.js +1 -1
  58. local_deep_research/web/templates/base.html +4 -1
  59. local_deep_research/web/templates/components/custom_dropdown.html +5 -3
  60. local_deep_research/web/templates/components/mobile_nav.html +3 -3
  61. local_deep_research/web/templates/components/sidebar.html +9 -3
  62. local_deep_research/web/templates/pages/benchmark.html +2697 -0
  63. local_deep_research/web/templates/pages/benchmark_results.html +1136 -0
  64. local_deep_research/web/templates/pages/benchmark_simple.html +453 -0
  65. local_deep_research/web/templates/pages/cost_analytics.html +1 -1
  66. local_deep_research/web/templates/pages/metrics.html +212 -39
  67. local_deep_research/web/templates/pages/research.html +8 -6
  68. local_deep_research/web/templates/pages/star_reviews.html +1 -1
  69. local_deep_research/web_search_engines/engines/search_engine_arxiv.py +14 -1
  70. local_deep_research/web_search_engines/engines/search_engine_brave.py +15 -1
  71. local_deep_research/web_search_engines/engines/search_engine_ddg.py +20 -1
  72. local_deep_research/web_search_engines/engines/search_engine_google_pse.py +26 -2
  73. local_deep_research/web_search_engines/engines/search_engine_pubmed.py +15 -1
  74. local_deep_research/web_search_engines/engines/search_engine_retriever.py +192 -0
  75. local_deep_research/web_search_engines/engines/search_engine_tavily.py +307 -0
  76. local_deep_research/web_search_engines/rate_limiting/__init__.py +14 -0
  77. local_deep_research/web_search_engines/rate_limiting/__main__.py +9 -0
  78. local_deep_research/web_search_engines/rate_limiting/cli.py +209 -0
  79. local_deep_research/web_search_engines/rate_limiting/exceptions.py +21 -0
  80. local_deep_research/web_search_engines/rate_limiting/tracker.py +506 -0
  81. local_deep_research/web_search_engines/retriever_registry.py +108 -0
  82. local_deep_research/web_search_engines/search_engine_base.py +161 -43
  83. local_deep_research/web_search_engines/search_engine_factory.py +14 -0
  84. local_deep_research/web_search_engines/search_engines_config.py +20 -0
  85. local_deep_research-0.6.0.dist-info/METADATA +374 -0
  86. {local_deep_research-0.5.9.dist-info → local_deep_research-0.6.0.dist-info}/RECORD +89 -64
  87. local_deep_research-0.5.9.dist-info/METADATA +0 -420
  88. {local_deep_research-0.5.9.dist-info → local_deep_research-0.6.0.dist-info}/WHEEL +0 -0
  89. {local_deep_research-0.5.9.dist-info → local_deep_research-0.6.0.dist-info}/entry_points.txt +0 -0
  90. {local_deep_research-0.5.9.dist-info → local_deep_research-0.6.0.dist-info}/licenses/LICENSE +0 -0
local_deep_research/web_search_engines/rate_limiting/tracker.py (new file)
@@ -0,0 +1,506 @@
+"""
+Adaptive rate limit tracker that learns optimal retry wait times for each search engine.
+"""
+
+import time
+import random
+import logging
+from collections import deque
+from typing import Dict, Optional, Tuple, List
+
+
+from ...utilities.db_utils import get_db_session
+from ...web.database.models import RateLimitAttempt, RateLimitEstimate
+
+logger = logging.getLogger(__name__)
+
+
+class AdaptiveRateLimitTracker:
+    """
+    Tracks and learns optimal retry wait times for each search engine.
+    Persists learned patterns to the main application database using SQLAlchemy.
+    """
+
+    def __init__(self):
+        # Load configuration from database settings
+        from ...utilities.db_utils import get_db_setting
+
+        self.memory_window = int(
+            get_db_setting("rate_limiting.memory_window", 100)
+        )
+        self.exploration_rate = float(
+            get_db_setting("rate_limiting.exploration_rate", 0.1)
+        )
+        self.learning_rate = float(
+            get_db_setting("rate_limiting.learning_rate", 0.3)
+        )
+        self.decay_per_day = float(
+            get_db_setting("rate_limiting.decay_per_day", 0.95)
+        )
+        self.enabled = bool(get_db_setting("rate_limiting.enabled", True))
+
+        # Apply rate limiting profile
+        self._apply_profile(get_db_setting("rate_limiting.profile", "balanced"))
+
+        # In-memory cache for fast access
+        self.recent_attempts: Dict[str, deque] = {}
+        self.current_estimates: Dict[str, Dict[str, float]] = {}
+
+        # Load estimates from database
+        self._load_estimates()
+
+        logger.info(
+            f"AdaptiveRateLimitTracker initialized: enabled={self.enabled}, profile={get_db_setting('rate_limiting.profile', 'balanced')}"
+        )
+
+    def _apply_profile(self, profile: str) -> None:
+        """Apply rate limiting profile settings."""
+        if profile == "conservative":
+            # More conservative: lower exploration, slower learning
+            self.exploration_rate = min(
+                self.exploration_rate * 0.5, 0.05
+            )  # 5% max exploration
+            self.learning_rate = min(
+                self.learning_rate * 0.7, 0.2
+            )  # Slower learning
+            logger.info("Applied conservative rate limiting profile")
+        elif profile == "aggressive":
+            # More aggressive: higher exploration, faster learning
+            self.exploration_rate = min(
+                self.exploration_rate * 1.5, 0.2
+            )  # Up to 20% exploration
+            self.learning_rate = min(
+                self.learning_rate * 1.3, 0.5
+            )  # Faster learning
+            logger.info("Applied aggressive rate limiting profile")
+        else:  # balanced
+            # Use settings as-is
+            logger.info("Applied balanced rate limiting profile")
+
+    def _load_estimates(self) -> None:
+        """Load estimates from database into memory."""
+        try:
+            session = get_db_session()
+            estimates = session.query(RateLimitEstimate).all()
+
+            for estimate in estimates:
+                # Apply decay for old estimates
+                age_hours = (time.time() - estimate.last_updated) / 3600
+                decay = self.decay_per_day ** (age_hours / 24)
+
+                self.current_estimates[estimate.engine_type] = {
+                    "base": estimate.base_wait_seconds,
+                    "min": estimate.min_wait_seconds,
+                    "max": estimate.max_wait_seconds,
+                    "confidence": decay,
+                }
+
+                logger.debug(
+                    f"Loaded estimate for {estimate.engine_type}: base={estimate.base_wait_seconds:.2f}s, confidence={decay:.2f}"
+                )
+
+        except Exception as e:
+            logger.warning(f"Could not load rate limit estimates: {e}")
+            # Continue with empty estimates - they'll be learned
+
+    def get_wait_time(self, engine_type: str) -> float:
+        """
+        Get adaptive wait time for a search engine.
+        Includes exploration to discover better rates.
+
+        Args:
+            engine_type: Name of the search engine
+
+        Returns:
+            Wait time in seconds
+        """
+        # If rate limiting is disabled, return minimal wait time
+        if not self.enabled:
+            return 0.1
+
+        if engine_type not in self.current_estimates:
+            # First time seeing this engine - start optimistic and learn from real responses
+            # Use engine-specific optimistic defaults only for what we know for sure
+            optimistic_defaults = {
+                "LocalSearchEngine": 0.0,  # No network calls
+                "SearXNGSearchEngine": 0.1,  # Self-hosted default engine
+            }
+
+            wait_time = optimistic_defaults.get(
+                engine_type, 0.5
+            )  # Default optimistic for others
+            logger.info(
+                f"No rate limit data for {engine_type}, starting optimistic with {wait_time}s"
+            )
+            return wait_time
+
+        estimate = self.current_estimates[engine_type]
+        base_wait = estimate["base"]
+
+        # Exploration vs exploitation
+        if random.random() < self.exploration_rate:
+            # Explore: try a faster rate to see if API limits have relaxed
+            wait_time = base_wait * random.uniform(0.5, 0.9)
+            logger.debug(
+                f"Exploring faster rate for {engine_type}: {wait_time:.2f}s"
+            )
+        else:
+            # Exploit: use learned estimate with jitter
+            wait_time = base_wait * random.uniform(0.9, 1.1)
+
+        # Enforce bounds
+        wait_time = max(estimate["min"], min(wait_time, estimate["max"]))
+        return wait_time
+
+    def record_outcome(
+        self,
+        engine_type: str,
+        wait_time: float,
+        success: bool,
+        retry_count: int,
+        error_type: Optional[str] = None,
+        search_result_count: Optional[int] = None,
+    ) -> None:
+        """
+        Record the outcome of a retry attempt.
+
+        Args:
+            engine_type: Name of the search engine
+            wait_time: How long we waited before this attempt
+            success: Whether the attempt succeeded
+            retry_count: Which retry attempt this was (1, 2, 3, etc.)
+            error_type: Type of error if failed
+            search_result_count: Number of search results returned (for quality monitoring)
+        """
+        # If rate limiting is disabled, don't record outcomes
+        if not self.enabled:
+            return
+        timestamp = time.time()
+
+        try:
+            # Save to database
+            session = get_db_session()
+            attempt = RateLimitAttempt(
+                engine_type=engine_type,
+                timestamp=timestamp,
+                wait_time=wait_time,
+                retry_count=retry_count,
+                success=success,
+                error_type=error_type,
+            )
+            session.add(attempt)
+            session.commit()
+        except Exception as e:
+            logger.error(f"Failed to record rate limit outcome: {e}")
+
+        # Update in-memory tracking
+        if engine_type not in self.recent_attempts:
+            # Get current memory window setting
+            from ...utilities.db_utils import get_db_setting
+
+            current_memory_window = int(
+                get_db_setting(
+                    "rate_limiting.memory_window", self.memory_window
+                )
+            )
+            self.recent_attempts[engine_type] = deque(
+                maxlen=current_memory_window
+            )
+
+        self.recent_attempts[engine_type].append(
+            {
+                "wait_time": wait_time,
+                "success": success,
+                "timestamp": timestamp,
+                "retry_count": retry_count,
+                "search_result_count": search_result_count,
+            }
+        )
+
+        # Update estimates
+        self._update_estimate(engine_type)
+
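Together, get_wait_time() and record_outcome() define the feedback loop an engine wrapper is expected to drive: wait, attempt, report. Below is a minimal sketch of that call pattern, inferred from the signatures above; the engine name and run_search() are hypothetical placeholders, and the real retry integration presumably lives in search_engine_base.py, also changed in this release.

    import time
    from local_deep_research.web_search_engines.rate_limiting.tracker import (
        get_tracker,
    )

    tracker = get_tracker()
    wait = tracker.get_wait_time("ExampleSearchEngine")  # hypothetical engine name
    time.sleep(wait)
    try:
        results = run_search("my query")  # hypothetical stand-in for the engine call
        tracker.record_outcome(
            "ExampleSearchEngine",
            wait,
            success=True,
            retry_count=1,
            search_result_count=len(results),
        )
    except Exception as e:
        tracker.record_outcome(
            "ExampleSearchEngine",
            wait,
            success=False,
            retry_count=1,
            error_type=type(e).__name__,
        )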
+    def _update_estimate(self, engine_type: str) -> None:
+        """Update wait time estimate based on recent attempts."""
+        if (
+            engine_type not in self.recent_attempts
+            or len(self.recent_attempts[engine_type]) < 3
+        ):
+            return
+
+        attempts = list(self.recent_attempts[engine_type])
+
+        # Calculate success rate and optimal wait time
+        successful_waits = [a["wait_time"] for a in attempts if a["success"]]
+        failed_waits = [a["wait_time"] for a in attempts if not a["success"]]
+
+        if not successful_waits:
+            # All attempts failed - increase wait time with a cap
+            new_base = max(failed_waits) * 1.5 if failed_waits else 10.0
+            # Cap the base wait time to prevent runaway growth
+            new_base = min(new_base, 10.0)  # Max 10 seconds base when all fail
+        else:
+            # Use 75th percentile of successful waits
+            successful_waits.sort()
+            percentile_75 = successful_waits[int(len(successful_waits) * 0.75)]
+            new_base = percentile_75
+
+        # Update estimate with learning rate (exponential moving average)
+        if engine_type in self.current_estimates:
+            old_base = self.current_estimates[engine_type]["base"]
+            # Get current learning rate from settings
+            from ...utilities.db_utils import get_db_setting
+
+            current_learning_rate = float(
+                get_db_setting(
+                    "rate_limiting.learning_rate", self.learning_rate
+                )
+            )
+            new_base = (
+                1 - current_learning_rate
+            ) * old_base + current_learning_rate * new_base
+
+        # Apply absolute cap to prevent extreme wait times
+        new_base = min(new_base, 10.0)  # Cap base at 10 seconds
+
+        # Calculate bounds with more reasonable limits
+        min_wait = max(0.5, new_base * 0.5)
+        max_wait = min(10.0, new_base * 3.0)  # Max 10 seconds absolute cap
+
+        # Update in memory
+        self.current_estimates[engine_type] = {
+            "base": new_base,
+            "min": min_wait,
+            "max": max_wait,
+            "confidence": min(len(attempts) / 20.0, 1.0),
+        }
+
+        # Persist to database
+        success_rate = len(successful_waits) / len(attempts) if attempts else 0
+
+        try:
+            session = get_db_session()
+
+            # Check if estimate exists
+            estimate = (
+                session.query(RateLimitEstimate)
+                .filter_by(engine_type=engine_type)
+                .first()
+            )
+
+            if estimate:
+                # Update existing estimate
+                estimate.base_wait_seconds = new_base
+                estimate.min_wait_seconds = min_wait
+                estimate.max_wait_seconds = max_wait
+                estimate.last_updated = time.time()
+                estimate.total_attempts = len(attempts)
+                estimate.success_rate = success_rate
+            else:
+                # Create new estimate
+                estimate = RateLimitEstimate(
+                    engine_type=engine_type,
+                    base_wait_seconds=new_base,
+                    min_wait_seconds=min_wait,
+                    max_wait_seconds=max_wait,
+                    last_updated=time.time(),
+                    total_attempts=len(attempts),
+                    success_rate=success_rate,
+                )
+                session.add(estimate)
+
+            session.commit()
+
+        except Exception as e:
+            logger.error(f"Failed to persist rate limit estimate: {e}")
+
+        logger.info(
+            f"Updated rate limit for {engine_type}: {new_base:.2f}s "
+            f"(success rate: {success_rate:.1%})"
+        )
+
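The update above is an exponential moving average: the learning rate controls how far the estimate moves toward the newest observation. A quick worked example with the default learning rate of 0.3 (all values illustrative):

    old_base = 2.0       # previously learned base wait (seconds)
    observed = 4.0       # 75th percentile of recent successful waits
    learning_rate = 0.3  # rate_limiting.learning_rate default
    new_base = (1 - learning_rate) * old_base + learning_rate * observed
    # 0.7 * 2.0 + 0.3 * 4.0 = 2.6 seconds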
+    def get_stats(
+        self, engine_type: Optional[str] = None
+    ) -> List[Tuple[str, float, float, float, float, int, float]]:
+        """
+        Get statistics for monitoring.
+
+        Args:
+            engine_type: Specific engine to get stats for, or None for all
+
+        Returns:
+            List of tuples with engine statistics
+        """
+        try:
+            session = get_db_session()
+
+            if engine_type:
+                estimates = (
+                    session.query(RateLimitEstimate)
+                    .filter_by(engine_type=engine_type)
+                    .all()
+                )
+            else:
+                estimates = (
+                    session.query(RateLimitEstimate)
+                    .order_by(RateLimitEstimate.engine_type)
+                    .all()
+                )
+
+            return [
+                (
+                    est.engine_type,
+                    est.base_wait_seconds,
+                    est.min_wait_seconds,
+                    est.max_wait_seconds,
+                    est.last_updated,
+                    est.total_attempts,
+                    est.success_rate,
+                )
+                for est in estimates
+            ]
+        except Exception as e:
+            logger.error(f"Failed to get rate limit stats: {e}")
+            return []
+
+    def reset_engine(self, engine_type: str) -> None:
+        """
+        Reset learned values for a specific engine.
+
+        Args:
+            engine_type: Engine to reset
+        """
+        try:
+            session = get_db_session()
+
+            # Delete historical attempts
+            session.query(RateLimitAttempt).filter_by(
+                engine_type=engine_type
+            ).delete()
+
+            # Delete estimates
+            session.query(RateLimitEstimate).filter_by(
+                engine_type=engine_type
+            ).delete()
+
+            session.commit()
+
+            # Clear from memory
+            if engine_type in self.recent_attempts:
+                del self.recent_attempts[engine_type]
+            if engine_type in self.current_estimates:
+                del self.current_estimates[engine_type]
+
+            logger.info(f"Reset rate limit data for {engine_type}")
+
+        except Exception as e:
+            logger.error(
+                f"Failed to reset rate limit data for {engine_type}: {e}"
+            )
+            # Still try to clear from memory even if database operation failed
+            if engine_type in self.recent_attempts:
+                del self.recent_attempts[engine_type]
+            if engine_type in self.current_estimates:
+                del self.current_estimates[engine_type]
+            # Re-raise the exception so callers know it failed
+            raise
+
+    def get_search_quality_stats(
+        self, engine_type: Optional[str] = None
+    ) -> List[Dict]:
+        """
+        Get basic search quality statistics for monitoring.
+
+        Args:
+            engine_type: Specific engine to get stats for, or None for all
+
+        Returns:
+            List of dictionaries with search quality metrics
+        """
+        stats = []
+
+        engines_to_check = (
+            [engine_type] if engine_type else list(self.recent_attempts.keys())
+        )
+
+        for engine in engines_to_check:
+            if engine not in self.recent_attempts:
+                continue
+
+            recent = list(self.recent_attempts[engine])
+            search_counts = [
+                attempt.get("search_result_count", 0)
+                for attempt in recent
+                if attempt.get("search_result_count") is not None
+            ]
+
+            if not search_counts:
+                continue
+
+            recent_avg = sum(search_counts) / len(search_counts)
+
+            stats.append(
+                {
+                    "engine_type": engine,
+                    "recent_avg_results": recent_avg,
+                    "min_recent_results": min(search_counts),
+                    "max_recent_results": max(search_counts),
+                    "sample_size": len(search_counts),
+                    "total_attempts": len(recent),
+                    "status": self._get_quality_status(recent_avg),
+                }
+            )
+
+        return stats
+
+    def _get_quality_status(self, recent_avg: float) -> str:
+        """Get quality status string based on average results."""
+        if recent_avg < 1:
+            return "CRITICAL"
+        elif recent_avg < 3:
+            return "WARNING"
+        elif recent_avg < 5:
+            return "CAUTION"
+        elif recent_avg >= 10:
+            return "EXCELLENT"
+        else:
+            return "GOOD"
+
+    def cleanup_old_data(self, days: int = 30) -> None:
+        """
+        Remove old retry attempt data to prevent database bloat.
+
+        Args:
+            days: Remove data older than this many days
+        """
+        cutoff_time = time.time() - (days * 24 * 3600)
+
+        try:
+            session = get_db_session()
+
+            # Count and delete old attempts
+            old_attempts = session.query(RateLimitAttempt).filter(
+                RateLimitAttempt.timestamp < cutoff_time
+            )
+            deleted_count = old_attempts.count()
+            old_attempts.delete()
+
+            session.commit()
+
+            if deleted_count > 0:
+                logger.info(f"Cleaned up {deleted_count} old retry attempts")
+
+        except Exception as e:
+            logger.error(f"Failed to cleanup old rate limit data: {e}")
+
+
+# Create a singleton instance
+_tracker_instance: Optional[AdaptiveRateLimitTracker] = None
+
+
+def get_tracker() -> AdaptiveRateLimitTracker:
+    """Get the global rate limit tracker instance."""
+    global _tracker_instance
+    if _tracker_instance is None:
+        _tracker_instance = AdaptiveRateLimitTracker()
+    return _tracker_instance
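For monitoring, the learned estimates can be read back through the module-level singleton. A small sketch (the field order follows the tuple documented in get_stats above; the new rate_limiting/cli.py presumably offers the same view from the command line):

    from local_deep_research.web_search_engines.rate_limiting.tracker import (
        get_tracker,
    )

    tracker = get_tracker()
    for engine, base, min_w, max_w, updated, attempts, success in tracker.get_stats():
        print(
            f"{engine}: base={base:.2f}s (min={min_w:.2f}s, max={max_w:.2f}s, "
            f"attempts={attempts}, success={success:.0%})"
        )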
local_deep_research/web_search_engines/retriever_registry.py (new file)
@@ -0,0 +1,108 @@
+"""
+Registry for dynamically registering LangChain retrievers as search engines.
+"""
+
+from typing import Dict, Optional
+from threading import Lock
+from langchain.schema import BaseRetriever
+from loguru import logger
+
+
+class RetrieverRegistry:
+    """
+    Thread-safe registry for LangChain retrievers.
+
+    This allows users to register retrievers programmatically and use them
+    as search engines within LDR.
+    """
+
+    def __init__(self):
+        self._retrievers: Dict[str, BaseRetriever] = {}
+        self._lock = Lock()
+
+    def register(self, name: str, retriever: BaseRetriever) -> None:
+        """
+        Register a retriever with a given name.
+
+        Args:
+            name: Name to register the retriever under
+            retriever: LangChain BaseRetriever instance
+        """
+        with self._lock:
+            self._retrievers[name] = retriever
+            logger.info(
+                f"Registered retriever '{name}' of type {type(retriever).__name__}"
+            )
+
+    def register_multiple(self, retrievers: Dict[str, BaseRetriever]) -> None:
+        """
+        Register multiple retrievers at once.
+
+        Args:
+            retrievers: Dictionary of {name: retriever} pairs
+        """
+        with self._lock:
+            for name, retriever in retrievers.items():
+                self._retrievers[name] = retriever
+                logger.info(
+                    f"Registered retriever '{name}' of type {type(retriever).__name__}"
+                )
+
+    def get(self, name: str) -> Optional[BaseRetriever]:
+        """
+        Get a registered retriever by name.
+
+        Args:
+            name: Name of the retriever
+
+        Returns:
+            The retriever if found, None otherwise
+        """
+        with self._lock:
+            return self._retrievers.get(name)
+
+    def unregister(self, name: str) -> None:
+        """
+        Remove a registered retriever.
+
+        Args:
+            name: Name of the retriever to remove
+        """
+        with self._lock:
+            if name in self._retrievers:
+                del self._retrievers[name]
+                logger.info(f"Unregistered retriever '{name}'")
+
+    def clear(self) -> None:
+        """Clear all registered retrievers."""
+        with self._lock:
+            count = len(self._retrievers)
+            self._retrievers.clear()
+            logger.info(f"Cleared {count} registered retrievers")
+
+    def is_registered(self, name: str) -> bool:
+        """
+        Check if a retriever is registered.
+
+        Args:
+            name: Name of the retriever
+
+        Returns:
+            True if registered, False otherwise
+        """
+        with self._lock:
+            return name in self._retrievers
+
+    def list_registered(self) -> list[str]:
+        """
+        Get list of all registered retriever names.
+
+        Returns:
+            List of retriever names
+        """
+        with self._lock:
+            return list(self._retrievers.keys())
+
+
+# Global registry instance
+retriever_registry = RetrieverRegistry()
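A minimal usage sketch, assuming you already have a LangChain retriever to hand; my_vectorstore is a hypothetical stand-in for any vector store exposing as_retriever(). The registered name can then be selected as a search engine inside LDR, presumably via the new search_engine_retriever.py listed above:

    from local_deep_research.web_search_engines.retriever_registry import (
        retriever_registry,
    )

    retriever = my_vectorstore.as_retriever()  # hypothetical: any BaseRetriever works
    retriever_registry.register("my_docs", retriever)

    assert retriever_registry.is_registered("my_docs")
    print(retriever_registry.list_registered())  # ['my_docs']

    # Later, a consumer can look the retriever up by name:
    r = retriever_registry.get("my_docs")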