hindsight-api 0.0.20__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/api/__init__.py +2 -4
- hindsight_api/api/http.py +28 -78
- hindsight_api/api/mcp.py +2 -1
- hindsight_api/cli.py +0 -1
- hindsight_api/engine/cross_encoder.py +6 -1
- hindsight_api/engine/embeddings.py +6 -1
- hindsight_api/engine/entity_resolver.py +56 -29
- hindsight_api/engine/llm_wrapper.py +97 -5
- hindsight_api/engine/memory_engine.py +264 -139
- hindsight_api/engine/response_models.py +15 -17
- hindsight_api/engine/retain/bank_utils.py +23 -33
- hindsight_api/engine/retain/entity_processing.py +5 -5
- hindsight_api/engine/retain/fact_extraction.py +85 -23
- hindsight_api/engine/retain/fact_storage.py +1 -1
- hindsight_api/engine/retain/link_creation.py +12 -6
- hindsight_api/engine/retain/link_utils.py +50 -56
- hindsight_api/engine/retain/observation_regeneration.py +264 -0
- hindsight_api/engine/retain/orchestrator.py +31 -44
- hindsight_api/engine/retain/types.py +14 -0
- hindsight_api/engine/search/retrieval.py +2 -2
- hindsight_api/engine/search/think_utils.py +59 -30
- hindsight_api/migrations.py +54 -32
- hindsight_api/models.py +1 -2
- hindsight_api/pg0.py +17 -36
- {hindsight_api-0.0.20.dist-info → hindsight_api-0.1.0.dist-info}/METADATA +2 -3
- hindsight_api-0.1.0.dist-info/RECORD +51 -0
- hindsight_api-0.0.20.dist-info/RECORD +0 -50
- {hindsight_api-0.0.20.dist-info → hindsight_api-0.1.0.dist-info}/WHEEL +0 -0
- {hindsight_api-0.0.20.dist-info → hindsight_api-0.1.0.dist-info}/entry_points.txt +0 -0
hindsight_api/api/__init__.py
CHANGED

@@ -17,18 +17,17 @@ def create_app(
     http_api_enabled: bool = True,
     mcp_api_enabled: bool = False,
     mcp_mount_path: str = "/mcp",
-    run_migrations: bool = True,
     initialize_memory: bool = True
 ) -> FastAPI:
     """
     Create and configure the unified Hindsight API application.

     Args:
-        memory: MemoryEngine instance (already initialized with required parameters)
+        memory: MemoryEngine instance (already initialized with required parameters).
+            Migrations are controlled by the MemoryEngine's run_migrations parameter.
         http_api_enabled: Whether to enable HTTP REST API endpoints (default: True)
         mcp_api_enabled: Whether to enable MCP server (default: False)
         mcp_mount_path: Path to mount MCP server (default: /mcp)
-        run_migrations: Whether to run database migrations on startup (default: True)
        initialize_memory: Whether to initialize memory system on startup (default: True)

     Returns:
@@ -50,7 +49,6 @@ def create_app(
         from .http import create_app as create_http_app
         app = create_http_app(
             memory=memory,
-            run_migrations=run_migrations,
             initialize_memory=initialize_memory
         )
         logger.info("HTTP REST API enabled")
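The net effect is that callers no longer pass run_migrations to create_app; the flag now lives on MemoryEngine. A minimal sketch of the new wiring, assuming MemoryEngine accepts a database URL and a run_migrations keyword (the names below other than create_app's own parameters are assumptions, not taken from the wheel):

# Sketch only: shows where run_migrations lives after this change.
# MemoryEngine's exact constructor signature is assumed.
from hindsight_api import MemoryEngine
from hindsight_api.api import create_app

memory = MemoryEngine(
    db_url="postgresql://localhost/hindsight",  # hypothetical connection string
    run_migrations=True,                        # migrations are now configured here
)

# create_app no longer takes run_migrations; it only decides whether to
# initialize the memory system on startup.
app = create_app(memory, http_api_enabled=True, initialize_memory=True)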
hindsight_api/api/http.py
CHANGED

@@ -36,27 +36,13 @@ from pydantic import BaseModel, Field, ConfigDict
 from hindsight_api import MemoryEngine
 from hindsight_api.engine.memory_engine import Budget
 from hindsight_api.engine.db_utils import acquire_with_retry
+from hindsight_api.engine.response_models import VALID_RECALL_FACT_TYPES
 from hindsight_api.metrics import get_metrics_collector, initialize_metrics, create_metrics_collector


 logger = logging.getLogger(__name__)


-class MetadataFilter(BaseModel):
-    """Filter for metadata fields. Matches records where (key=value) OR (key not set) when match_unset=True."""
-    model_config = ConfigDict(json_schema_extra={
-        "example": {
-            "key": "source",
-            "value": "slack",
-            "match_unset": True
-        }
-    })
-
-    key: str = Field(description="Metadata key to filter on")
-    value: Optional[str] = Field(default=None, description="Value to match. If None with match_unset=True, matches any record where key is not set.")
-    match_unset: bool = Field(default=True, description="If True, also match records where this metadata key is not set")
-
-
 class EntityIncludeOptions(BaseModel):
     """Options for including entity observations in recall results."""
     max_tokens: int = Field(default=500, description="Maximum tokens for entity observations")
@@ -89,7 +75,6 @@ class RecallRequest(BaseModel):
             "max_tokens": 4096,
             "trace": True,
             "query_timestamp": "2023-05-30T23:40:00",
-            "filters": [{"key": "source", "value": "slack", "match_unset": True}],
             "include": {
                 "entities": {
                     "max_tokens": 500
@@ -104,7 +89,6 @@ class RecallRequest(BaseModel):
     max_tokens: int = 4096
     trace: bool = False
     query_timestamp: Optional[str] = Field(default=None, description="ISO format date string (e.g., '2023-05-30T23:40:00')")
-    filters: Optional[List[MetadataFilter]] = Field(default=None, description="Filter by metadata. Multiple filters are ANDed together.")
     include: IncludeOptions = Field(default_factory=IncludeOptions, description="Options for including additional data (entities are included by default)")


@@ -362,7 +346,6 @@ class ReflectRequest(BaseModel):
             "query": "What do you think about artificial intelligence?",
             "budget": "low",
             "context": "This is for a research paper on AI ethics",
-            "filters": [{"key": "source", "value": "slack", "match_unset": True}],
             "include": {
                 "facts": {}
             }
@@ -372,7 +355,6 @@ class ReflectRequest(BaseModel):
     query: str
     budget: Budget = Budget.LOW
     context: Optional[str] = None
-    filters: Optional[List[MetadataFilter]] = Field(default=None, description="Filter by metadata. Multiple filters are ANDed together.")
     include: ReflectIncludeOptions = Field(default_factory=ReflectIncludeOptions, description="Options for including additional data (disabled by default)")

@@ -439,24 +421,18 @@ class BanksResponse(BaseModel):


 class DispositionTraits(BaseModel):
-    """Disposition traits
+    """Disposition traits that influence how memories are formed and interpreted."""
     model_config = ConfigDict(json_schema_extra={
         "example": {
-            "
-            "
-            "
-            "agreeableness": 0.7,
-            "neuroticism": 0.3,
-            "bias_strength": 0.7
+            "skepticism": 3,
+            "literalism": 3,
+            "empathy": 3
         }
     })

-
-
-
-    agreeableness: float = Field(ge=0.0, le=1.0, description="Agreeableness (0-1)")
-    neuroticism: float = Field(ge=0.0, le=1.0, description="Neuroticism (0-1)")
-    bias_strength: float = Field(ge=0.0, le=1.0, description="How strongly disposition influences opinions (0-1)")
+    skepticism: int = Field(ge=1, le=5, description="How skeptical vs trusting (1=trusting, 5=skeptical)")
+    literalism: int = Field(ge=1, le=5, description="How literally to interpret information (1=flexible, 5=literal)")
+    empathy: int = Field(ge=1, le=5, description="How much to consider emotional context (1=detached, 5=empathetic)")


 class BankProfileResponse(BaseModel):
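This is the core schema change in 0.1.0: the float 0-1 traits are replaced by three integer traits on a 1-5 scale. A minimal standalone sketch of the new model as it appears in this diff, with a hypothetical payload to show how the new bounds reject the old-style values:

# Sketch of the new DispositionTraits schema copied from this diff;
# the payloads below are illustrative only.
from pydantic import BaseModel, Field, ValidationError

class DispositionTraits(BaseModel):
    skepticism: int = Field(ge=1, le=5, description="How skeptical vs trusting (1=trusting, 5=skeptical)")
    literalism: int = Field(ge=1, le=5, description="How literally to interpret information (1=flexible, 5=literal)")
    empathy: int = Field(ge=1, le=5, description="How much to consider emotional context (1=detached, 5=empathetic)")

DispositionTraits(skepticism=3, literalism=3, empathy=3)   # valid

try:
    DispositionTraits(skepticism=0.7, literalism=3, empathy=3)  # old-style float is rejected
except ValidationError as exc:
    print(exc)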
@@ -466,12 +442,9 @@ class BankProfileResponse(BaseModel):
             "bank_id": "user123",
             "name": "Alice",
             "disposition": {
-                "
-                "
-                "
-                "agreeableness": 0.7,
-                "neuroticism": 0.3,
-                "bias_strength": 0.7
+                "skepticism": 3,
+                "literalism": 3,
+                "empathy": 3
             },
             "background": "I am a software engineer with 10 years of experience in startups"
         }
@@ -500,7 +473,7 @@ class AddBackgroundRequest(BaseModel):
     content: str = Field(description="New background information to add or merge")
     update_disposition: bool = Field(
         default=True,
-        description="If true, infer
+        description="If true, infer disposition traits from the merged background (default: true)"
     )


@@ -510,12 +483,9 @@ class BackgroundResponse(BaseModel):
         "example": {
             "background": "I was born in Texas. I am a software engineer with 10 years of experience.",
             "disposition": {
-                "
-                "
-                "
-                "agreeableness": 0.8,
-                "neuroticism": 0.4,
-                "bias_strength": 0.6
+                "skepticism": 3,
+                "literalism": 3,
+                "empathy": 3
             }
         }
     })
@@ -543,12 +513,9 @@ class BankListResponse(BaseModel):
             "bank_id": "user123",
             "name": "Alice",
             "disposition": {
-                "
-                "
-                "
-                "agreeableness": 0.5,
-                "neuroticism": 0.5,
-                "bias_strength": 0.5
+                "skepticism": 3,
+                "literalism": 3,
+                "empathy": 3
             },
             "background": "I am a software engineer",
             "created_at": "2024-01-15T10:30:00Z",
@@ -567,12 +534,9 @@ class CreateBankRequest(BaseModel):
         "example": {
             "name": "Alice",
             "disposition": {
-                "
-                "
-                "
-                "agreeableness": 0.7,
-                "neuroticism": 0.3,
-                "bias_strength": 0.7
+                "skepticism": 3,
+                "literalism": 3,
+                "empathy": 3
             },
             "background": "I am a creative software engineer with 10 years of experience"
         }
@@ -715,13 +679,13 @@ class DeleteResponse(BaseModel):
     success: bool


-def create_app(memory: MemoryEngine,
+def create_app(memory: MemoryEngine, initialize_memory: bool = True) -> FastAPI:
     """
     Create and configure the FastAPI application.

     Args:
-        memory: MemoryEngine instance (already initialized with required parameters)
-
+        memory: MemoryEngine instance (already initialized with required parameters).
+            Migrations are controlled by the MemoryEngine's run_migrations parameter.
         initialize_memory: Whether to initialize memory system on startup (default: True)

     Returns:
@@ -752,16 +716,11 @@ def create_app(memory: MemoryEngine, run_migrations: bool = True, initialize_mem
         app.state.prometheus_reader = None
         # Metrics collector is already initialized as no-op by default

-        # Startup: Initialize database and memory system
+        # Startup: Initialize database and memory system (migrations run inside initialize if enabled)
         if initialize_memory:
             await memory.initialize()
             logging.info("Memory system initialized")

-        if run_migrations:
-            from hindsight_api.migrations import run_migrations as do_migrations
-            do_migrations(memory.db_url)
-            logging.info("Database migrations applied")
-

         yield
@@ -913,17 +872,8 @@ def _register_routes(app: FastAPI):
         metrics = get_metrics_collector()

         try:
-            # Validate types
-            valid_fact_types = ["world", "experience", "opinion"]
-
             # Default to world, experience, opinion if not specified (exclude observation by default)
-            fact_types = request.types if request.types else
-            for ft in fact_types:
-                if ft not in valid_fact_types:
-                    raise HTTPException(
-                        status_code=400,
-                        detail=f"Invalid type '{ft}'. Must be one of: {', '.join(valid_fact_types)}"
-                    )
+            fact_types = request.types if request.types else list(VALID_RECALL_FACT_TYPES)

             # Parse query_timestamp if provided
             question_date = None
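Both http.py and mcp.py now take the allowed recall types from a single constant instead of repeating the list and the validation loop inline. A minimal sketch of the pattern; the tuple value is an assumption inferred from the removed inline list, and the helper below folds the fallback kept in the route together with the kind of check the removed loop performed, so it is illustrative rather than a copy of the new code:

# Sketch of centralizing the allowed recall fact types.
# The tuple value and the helper name are assumptions.
VALID_RECALL_FACT_TYPES = ("world", "experience", "opinion")

def resolve_fact_types(requested: list[str] | None) -> list[str]:
    """Fall back to every valid type; reject anything outside the shared constant."""
    fact_types = requested if requested else list(VALID_RECALL_FACT_TYPES)
    invalid = [ft for ft in fact_types if ft not in VALID_RECALL_FACT_TYPES]
    if invalid:
        raise ValueError(f"Invalid type(s) {invalid}. Must be one of: {', '.join(VALID_RECALL_FACT_TYPES)}")
    return fact_types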
@@ -1605,7 +1555,7 @@ This operation cannot be undone.
         "/v1/default/banks/{bank_id}/profile",
         response_model=BankProfileResponse,
         summary="Update memory bank disposition",
-        description="Update bank's
+        description="Update bank's disposition traits (skepticism, literalism, empathy)",
         operation_id="update_bank_disposition"
     )
     async def api_update_bank_disposition(bank_id: str,
@@ -1852,7 +1802,7 @@ This operation cannot be undone.
         "/v1/default/banks/{bank_id}/memories",
         response_model=DeleteResponse,
         summary="Clear memory bank memories",
-        description="Delete memory units for a memory bank. Optionally filter by type (world, experience, opinion) to delete only specific types. This is a destructive operation that cannot be undone. The bank profile (
+        description="Delete memory units for a memory bank. Optionally filter by type (world, experience, opinion) to delete only specific types. This is a destructive operation that cannot be undone. The bank profile (disposition and background) will be preserved.",
         operation_id="clear_bank_memories"
     )
     async def api_clear_bank_memories(bank_id: str,
hindsight_api/api/mcp.py
CHANGED

@@ -8,6 +8,7 @@ from typing import Optional

 from fastmcp import FastMCP
 from hindsight_api import MemoryEngine
+from hindsight_api.engine.response_models import VALID_RECALL_FACT_TYPES

 # Configure logging from HINDSIGHT_API_LOG_LEVEL environment variable
 _log_level_str = os.environ.get("HINDSIGHT_API_LOG_LEVEL", "info").lower()
@@ -90,7 +91,7 @@ def create_mcp_server(memory: MemoryEngine) -> FastMCP:
         search_result = await memory.recall_async(
             bank_id=bank_id,
             query=query,
-            fact_type=
+            fact_type=list(VALID_RECALL_FACT_TYPES),
             budget=Budget.LOW
         )

hindsight_api/cli.py
CHANGED

hindsight_api/engine/cross_encoder.py
CHANGED

@@ -78,7 +78,12 @@ class SentenceTransformersCrossEncoder(CrossEncoderModel):
         )

         logger.info(f"Loading cross-encoder model: {self.model_name}...")
-
+        # Disable lazy loading (meta tensors) which causes issues with newer transformers/accelerate
+        # Setting low_cpu_mem_usage=False and device_map=None ensures tensors are fully materialized
+        self._model = CrossEncoder(
+            self.model_name,
+            model_kwargs={"low_cpu_mem_usage": False, "device_map": None},
+        )
         logger.info("Cross-encoder model loaded")

     def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:

hindsight_api/engine/embeddings.py
CHANGED

@@ -84,7 +84,12 @@ class SentenceTransformersEmbeddings(Embeddings):
         )

         logger.info(f"Loading embedding model: {self.model_name}...")
-
+        # Disable lazy loading (meta tensors) which causes issues with newer transformers/accelerate
+        # Setting low_cpu_mem_usage=False and device_map=None ensures tensors are fully materialized
+        self._model = SentenceTransformer(
+            self.model_name,
+            model_kwargs={"low_cpu_mem_usage": False, "device_map": None},
+        )

         # Validate dimension matches database schema
         model_dim = self._model.get_sentence_embedding_dimension()
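Both loaders now pass the same model_kwargs to force eager weight materialization. A standalone sketch of that loading pattern; the model names are placeholders and only the kwargs come from this diff:

# Sketch of the eager-loading pattern added above; model names are placeholders.
from sentence_transformers import CrossEncoder, SentenceTransformer

eager_kwargs = {"low_cpu_mem_usage": False, "device_map": None}

embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", model_kwargs=eager_kwargs)
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", model_kwargs=eager_kwargs)

print(embedder.get_sentence_embedding_dimension())          # e.g. 384
print(reranker.predict([("query", "candidate passage")]))   # relevance score(s)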
hindsight_api/engine/entity_resolver.py
CHANGED

@@ -126,18 +126,20 @@ class EntityResolver:

         # Resolve each entity using pre-fetched candidates
         entity_ids = [None] * len(entities_data)
-        entities_to_update = []  # (entity_id,
-        entities_to_create = []  # (idx, entity_data)
+        entities_to_update = []  # (entity_id, event_date)
+        entities_to_create = []  # (idx, entity_data, event_date)

         for idx, entity_data in enumerate(entities_data):
             entity_text = entity_data['text']
             nearby_entities = entity_data.get('nearby_entities', [])
+            # Use per-entity date if available, otherwise fall back to batch-level date
+            entity_event_date = entity_data.get('event_date', unit_event_date)

             candidates = all_candidates.get(entity_text, [])

             if not candidates:
                 # Will create new entity
-                entities_to_create.append((idx, entity_data))
+                entities_to_create.append((idx, entity_data, entity_event_date))
                 continue

             # Score candidates
@@ -165,9 +167,9 @@ class EntityResolver:
                     score += co_entity_score * 0.3

                 # 3. Temporal proximity (0-0.2)
-                if last_seen:
+                if last_seen and entity_event_date:
                     # Normalize timezone awareness for comparison
-                    event_date_utc =
+                    event_date_utc = entity_event_date if entity_event_date.tzinfo else entity_event_date.replace(tzinfo=timezone.utc)
                     last_seen_utc = last_seen if last_seen.tzinfo else last_seen.replace(tzinfo=timezone.utc)
                     days_diff = abs((event_date_utc - last_seen_utc).total_seconds() / 86400)
                     if days_diff < 7:
@@ -183,9 +185,9 @@ class EntityResolver:

             if best_score > threshold:
                 entity_ids[idx] = best_candidate
-                entities_to_update.append((best_candidate,
+                entities_to_update.append((best_candidate, entity_event_date))
             else:
-                entities_to_create.append((idx, entity_data))
+                entities_to_create.append((idx, entity_data, entity_event_date))

         # Batch update existing entities
         if entities_to_update:
@@ -199,29 +201,54 @@ class EntityResolver:
                 entities_to_update
             )

-        #
-        # This
-        # only one succeeds and the other gets the existing ID
+        # Batch create new entities using COPY + INSERT for maximum speed
+        # This handles duplicates via ON CONFLICT and returns all IDs
         if entities_to_create:
-            [19 removed lines not captured in this rendering]
+            # Group entities by canonical name (lowercase) to handle duplicates within batch
+            # For duplicates, we only insert once and reuse the ID
+            unique_entities = {}  # lowercase_name -> (entity_data, event_date, [indices])
+            for idx, entity_data, event_date in entities_to_create:
+                name_lower = entity_data['text'].lower()
+                if name_lower not in unique_entities:
+                    unique_entities[name_lower] = (entity_data, event_date, [idx])
+                else:
+                    # Same entity appears multiple times - add index to list
+                    unique_entities[name_lower][2].append(idx)
+
+            # Batch insert unique entities and get their IDs
+            # Use a single query with unnest for speed
+            entity_names = []
+            entity_dates = []
+            indices_map = []  # Maps result index -> list of original indices
+
+            for name_lower, (entity_data, event_date, indices) in unique_entities.items():
+                entity_names.append(entity_data['text'])
+                entity_dates.append(event_date)
+                indices_map.append(indices)
+
+            # Batch INSERT ... ON CONFLICT with RETURNING
+            # This is much faster than individual inserts
+            rows = await conn.fetch(
+                """
+                INSERT INTO entities (bank_id, canonical_name, first_seen, last_seen, mention_count)
+                SELECT $1, name, event_date, event_date, 1
+                FROM unnest($2::text[], $3::timestamptz[]) AS t(name, event_date)
+                ON CONFLICT (bank_id, LOWER(canonical_name))
+                DO UPDATE SET
+                    mention_count = entities.mention_count + 1,
+                    last_seen = EXCLUDED.last_seen
+                RETURNING id
+                """,
+                bank_id,
+                entity_names,
+                entity_dates
+            )
+
+            # Map returned IDs back to original indices
+            for result_idx, row in enumerate(rows):
+                entity_id = row['id']
+                for original_idx in indices_map[result_idx]:
+                    entity_ids[original_idx] = entity_id

         return entity_ids

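The new creation path replaces per-entity inserts with one set-based statement. A self-contained sketch of the same unnest + ON CONFLICT ... RETURNING pattern with asyncpg; the connection string is a placeholder and the table is assumed to have a unique index on (bank_id, LOWER(canonical_name)), which the conflict target requires:

# Sketch of the batched upsert-and-return pattern used above, with asyncpg.
# DSN, table setup, and the helper name are assumptions for illustration.
import asyncio
from datetime import datetime, timezone
import asyncpg

async def upsert_entities(dsn: str, bank_id: str, names: list[str]) -> list[int]:
    now = datetime.now(timezone.utc)
    conn = await asyncpg.connect(dsn)
    try:
        rows = await conn.fetch(
            """
            INSERT INTO entities (bank_id, canonical_name, first_seen, last_seen, mention_count)
            SELECT $1, name, event_date, event_date, 1
            FROM unnest($2::text[], $3::timestamptz[]) AS t(name, event_date)
            ON CONFLICT (bank_id, LOWER(canonical_name))
            DO UPDATE SET mention_count = entities.mention_count + 1,
                          last_seen = EXCLUDED.last_seen
            RETURNING id
            """,
            bank_id, names, [now] * len(names),
        )
        return [r["id"] for r in rows]
    finally:
        await conn.close()

# asyncio.run(upsert_entities("postgresql://localhost/hindsight", "user123", ["Alice", "Bob"]))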
hindsight_api/engine/llm_wrapper.py
CHANGED

@@ -5,12 +5,15 @@ import os
 import time
 import asyncio
 from typing import Optional, Any, Dict, List
-from openai import AsyncOpenAI, RateLimitError, APIError, APIStatusError, LengthFinishReasonError
+from openai import AsyncOpenAI, RateLimitError, APIError, APIStatusError, APIConnectionError, LengthFinishReasonError
 from google import genai
 from google.genai import types as genai_types
 from google.genai import errors as genai_errors
 import logging

+# Seed applied to every Groq request for deterministic behavior.
+DEFAULT_LLM_SEED = 4242
+
 logger = logging.getLogger(__name__)

 # Disable httpx logging
@@ -40,6 +43,7 @@ class LLMConfig:
         api_key: str,
         base_url: str,
         model: str,
+        reasoning_effort: str = "low",
     ):
         """
         Initialize LLM configuration.
@@ -54,6 +58,7 @@ class LLMConfig:
         self.api_key = api_key
         self.base_url = base_url
         self.model = model
+        self.reasoning_effort = reasoning_effort

         # Validate provider
         if self.provider not in ["openai", "groq", "ollama", "gemini"]:
@@ -136,10 +141,14 @@ class LLMConfig:
             "messages": messages,
             **kwargs
         }
+
+        if self.provider == "groq":
+            call_params["seed"] = DEFAULT_LLM_SEED
+
         if self.provider == "groq":
             call_params["extra_body"] = {
                 "service_tier": "auto",
-                "reasoning_effort":
+                "reasoning_effort": self.reasoning_effort,
                 "include_reasoning": False,  # Disable hidden reasoning tokens
             }

@@ -187,10 +196,15 @@ class LLMConfig:
             usage = response.usage
             if duration > 10.0:
                 ratio = max(1, usage.completion_tokens) / usage.prompt_tokens
+                # Check for cached tokens (OpenAI/Groq may include this)
+                cached_tokens = 0
+                if hasattr(usage, 'prompt_tokens_details') and usage.prompt_tokens_details:
+                    cached_tokens = getattr(usage.prompt_tokens_details, 'cached_tokens', 0) or 0
+                cache_info = f", cached_tokens={cached_tokens}" if cached_tokens > 0 else ""
                 logger.info(
                     f"slow llm call: model={self.provider}/{self.model}, "
                     f"input_tokens={usage.prompt_tokens}, output_tokens={usage.completion_tokens}, "
-                    f"total_tokens={usage.total_tokens}, time={duration:.3f}s, ratio out/in={ratio:.2f}"
+                    f"total_tokens={usage.total_tokens}{cache_info}, time={duration:.3f}s, ratio out/in={ratio:.2f}"
                 )

             return result
@@ -202,6 +216,18 @@ class LLMConfig:
                     f"LLM output exceeded token limits. Input may need to be split into smaller chunks."
                 ) from e

+            except APIConnectionError as e:
+                # Handle connection errors (server disconnected, network issues) with retry
+                last_exception = e
+                if attempt < max_retries:
+                    logger.warning(f"Connection error, retrying... (attempt {attempt + 1}/{max_retries + 1})")
+                    backoff = min(initial_backoff * (2 ** attempt), max_backoff)
+                    await asyncio.sleep(backoff)
+                    continue
+                else:
+                    logger.error(f"Connection error after {max_retries + 1} attempts: {str(e)}")
+                    raise
+
             except APIStatusError as e:
                 last_exception = e
                 if attempt < max_retries:
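The connection-error handler above and the empty-response and invalid-JSON handlers below all share the same capped exponential backoff. A minimal sketch of that retry shape; the retry count and backoff bounds are assumptions, since the diff does not show the values this module uses:

# Sketch of the retry/backoff pattern used by the new exception handlers.
# max_retries, initial_backoff, and max_backoff defaults are assumptions.
import asyncio
import logging

logger = logging.getLogger(__name__)

async def call_with_retry(call, max_retries: int = 3, initial_backoff: float = 1.0, max_backoff: float = 30.0):
    last_exception = None
    for attempt in range(max_retries + 1):
        try:
            return await call()
        except ConnectionError as e:  # stand-in for openai.APIConnectionError
            last_exception = e
            if attempt < max_retries:
                backoff = min(initial_backoff * (2 ** attempt), max_backoff)
                logger.warning(f"Connection error, retrying in {backoff:.1f}s (attempt {attempt + 1}/{max_retries + 1})")
                await asyncio.sleep(backoff)
            else:
                raise
    raise last_exception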
@@ -238,7 +264,7 @@ class LLMConfig:
         skip_validation: bool,
         start_time: float,
         **kwargs
-
+    ) -> Any:
         """Handle Gemini-specific API calls using google-genai SDK."""
         import json

@@ -287,6 +313,8 @@ class LLMConfig:
             config_kwargs['max_output_tokens'] = kwargs['max_tokens']
         if response_format is not None:
             config_kwargs['response_mime_type'] = 'application/json'
+            # Pass the Pydantic model directly as response_schema for structured output
+            config_kwargs['response_schema'] = response_format

         generation_config = genai_types.GenerateContentConfig(**config_kwargs) if config_kwargs else None

@@ -302,6 +330,23 @@ class LLMConfig:

                 content = response.text

+                # Handle empty/None response (can happen with content filtering or timeouts)
+                if content is None:
+                    # Check if there's a block reason
+                    block_reason = None
+                    if hasattr(response, 'candidates') and response.candidates:
+                        candidate = response.candidates[0]
+                        if hasattr(candidate, 'finish_reason'):
+                            block_reason = candidate.finish_reason
+
+                    if attempt < max_retries:
+                        logger.warning(f"Gemini returned empty response (reason: {block_reason}), retrying... (attempt {attempt + 1}/{max_retries + 1})")
+                        backoff = min(initial_backoff * (2 ** attempt), max_backoff)
+                        await asyncio.sleep(backoff)
+                        continue
+                    else:
+                        raise RuntimeError(f"Gemini returned empty response after {max_retries + 1} attempts (reason: {block_reason})")
+
                 if response_format is not None:
                     # Parse the JSON response
                     json_data = json.loads(content)
@@ -318,14 +363,29 @@ class LLMConfig:
                 duration = time.time() - start_time
                 if duration > 10.0 and hasattr(response, 'usage_metadata') and response.usage_metadata:
                     usage = response.usage_metadata
+                    # Check for cached tokens (Gemini uses cached_content_token_count)
+                    cached_tokens = getattr(usage, 'cached_content_token_count', 0) or 0
+                    cache_info = f", cached_tokens={cached_tokens}" if cached_tokens > 0 else ""
                     logger.info(
                         f"slow llm call: model={self.provider}/{self.model}, "
-                        f"input_tokens={usage.prompt_token_count}, output_tokens={usage.candidates_token_count}, "
+                        f"input_tokens={usage.prompt_token_count}, output_tokens={usage.candidates_token_count}{cache_info}, "
                         f"time={duration:.3f}s"
                     )

                 return result

+            except json.JSONDecodeError as e:
+                # Handle truncated JSON responses (often from MAX_TOKENS) with retry
+                last_exception = e
+                if attempt < max_retries:
+                    logger.warning(f"Gemini returned invalid JSON (truncated response?), retrying... (attempt {attempt + 1}/{max_retries + 1})")
+                    backoff = min(initial_backoff * (2 ** attempt), max_backoff)
+                    await asyncio.sleep(backoff)
+                    continue
+                else:
+                    logger.error(f"Gemini returned invalid JSON after {max_retries + 1} attempts: {str(e)}")
+                    raise
+
             except genai_errors.APIError as e:
                 # Handle rate limits and server errors with retry
                 if e.code in (429, 503, 500):
@@ -372,6 +432,37 @@ class LLMConfig:
             api_key=api_key,
             base_url=base_url,
             model=model,
+            reasoning_effort="low"
+        )
+
+    @classmethod
+    def for_answer_generation(cls) -> "LLMConfig":
+        """
+        Create configuration for answer generation operations from environment variables.
+
+        Falls back to memory LLM config if answer-specific config not set.
+        """
+        # Check if answer-specific config exists, otherwise fall back to memory config
+        provider = os.getenv("HINDSIGHT_API_ANSWER_LLM_PROVIDER", os.getenv("HINDSIGHT_API_LLM_PROVIDER", "groq"))
+        api_key = os.getenv("HINDSIGHT_API_ANSWER_LLM_API_KEY", os.getenv("HINDSIGHT_API_LLM_API_KEY"))
+        base_url = os.getenv("HINDSIGHT_API_ANSWER_LLM_BASE_URL", os.getenv("HINDSIGHT_API_LLM_BASE_URL"))
+        model = os.getenv("HINDSIGHT_API_ANSWER_LLM_MODEL", os.getenv("HINDSIGHT_API_LLM_MODEL", "openai/gpt-oss-120b"))
+
+        # Set default base URL if not provided
+        if not base_url:
+            if provider == "groq":
+                base_url = "https://api.groq.com/openai/v1"
+            elif provider == "ollama":
+                base_url = "http://localhost:11434/v1"
+            else:
+                base_url = ""
+
+        return cls(
+            provider=provider,
+            api_key=api_key,
+            base_url=base_url,
+            model=model,
+            reasoning_effort="high"
         )

     @classmethod
@@ -401,4 +492,5 @@ class LLMConfig:
             api_key=api_key,
             base_url=base_url,
             model=model,
+            reasoning_effort="high"
         )
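The new for_answer_generation constructor layers answer-specific environment variables over the shared HINDSIGHT_API_LLM_* ones, so answer generation can run on a different model while defaulting to the memory configuration. A hedged usage sketch; the environment values are examples only:

# Sketch: configuring answer generation via the new HINDSIGHT_API_ANSWER_LLM_* variables.
# All values below are placeholders.
import os

os.environ["HINDSIGHT_API_LLM_PROVIDER"] = "groq"
os.environ["HINDSIGHT_API_LLM_API_KEY"] = "gsk_example"           # placeholder key
os.environ["HINDSIGHT_API_LLM_MODEL"] = "openai/gpt-oss-120b"
# Only override the model for answer generation; provider and key fall back to the shared config.
os.environ["HINDSIGHT_API_ANSWER_LLM_MODEL"] = "openai/gpt-oss-120b"

from hindsight_api.engine.llm_wrapper import LLMConfig

answer_cfg = LLMConfig.for_answer_generation()
print(answer_cfg.model, answer_cfg.reasoning_effort)  # reasoning_effort is "high" per this diff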