okb 1.1.0__py3-none-any.whl → 1.1.0a0__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
okb/llm/consolidate.py DELETED
@@ -1,685 +0,0 @@
- """Entity consolidation orchestration - clustering, relationships, and full pipeline."""
-
- from __future__ import annotations
-
- import json
- from dataclasses import dataclass, field
- from datetime import UTC, datetime
- from typing import Any
-
- import psycopg
- from psycopg.rows import dict_row
-
-
- @dataclass
- class TopicCluster:
-     """A cluster of related entities and documents."""
-
-     id: str
-     name: str
-     description: str | None
-     member_count: int
-     entities: list[dict] # [{id, name, type, distance}]
-     documents: list[dict] # [{id, title, distance}]
-
-
- @dataclass
- class EntityRelationship:
-     """A relationship between two entities."""
-
-     id: str
-     source_entity: dict # {id, name, type}
-     target_entity: dict # {id, name, type}
-     relationship_type: str # works_for, uses, belongs_to, related_to
-     confidence: float
-     context: str | None
-
-
- @dataclass
- class ConsolidationResult:
-     """Result from running consolidation pipeline."""
-
-     duplicates_found: int = 0
-     merges_pending: int = 0
-     merges_auto_approved: int = 0
-     cross_doc_candidates: int = 0
-     clusters_created: int = 0
-     relationships_found: int = 0
-     started_at: datetime = field(default_factory=lambda: datetime.now(UTC))
-     completed_at: datetime | None = None
-     errors: list[str] = field(default_factory=list)
-
-
- RELATIONSHIP_SYSTEM = """\
- You are an expert at identifying relationships between entities.
- Given pairs of entities that appear in the same documents, identify any relationships.
-
- Relationship types:
- - works_for: Person works for organization
- - uses: Project/person uses technology
- - belongs_to: Entity belongs to/is part of another
- - related_to: General association (only if no specific type applies)
-
- Return ONLY valid JSON:
- {
-   "relationships": [
-     {
-       "source": "Entity 1",
-       "target": "Entity 2",
-       "type": "uses",
-       "confidence": 0.9,
-       "reason": "Brief explanation"
-     }
-   ]
- }
-
- If no relationships found, return: {"relationships": []}
- """
-
- RELATIONSHIP_USER = """\
- Identify relationships between these entity pairs that co-occur in documents:
-
- {entity_pairs}
-
- Only include high-confidence relationships.
- """
-
- CLUSTER_NAMING_SYSTEM = """\
- You are naming a topic cluster based on its member entities.
- Create a short, descriptive name (2-5 words) and brief description.
-
- Return ONLY valid JSON:
- {"name": "Cluster Name", "description": "One sentence description"}
- """
-
- CLUSTER_NAMING_USER = """\
- Name this cluster containing these entities:
-
- {entity_list}
-
- The cluster should have a name that captures the common theme.
- """
-
-
- def run_consolidation(
-     db_url: str,
-     detect_duplicates: bool = True,
-     detect_cross_doc: bool = True,
-     build_clusters: bool = True,
-     extract_relationships: bool = True,
-     auto_merge_threshold: float = 0.95,
-     dry_run: bool = False,
- ) -> ConsolidationResult:
-     """Run the full entity consolidation pipeline.
-
-     Args:
-         db_url: Database URL
-         detect_duplicates: Run duplicate detection
-         detect_cross_doc: Run cross-document entity detection
-         build_clusters: Build topic clusters
-         extract_relationships: Extract entity relationships
-         auto_merge_threshold: Auto-approve merges above this confidence
-         dry_run: Don't make changes, just report what would happen
-
-     Returns:
-         ConsolidationResult with counts and status
-     """
-     result = ConsolidationResult()
-
-     # Log the run
-     run_id = None
-     if not dry_run:
-         with psycopg.connect(db_url) as conn:
-             r = conn.execute(
-                 "INSERT INTO consolidation_runs (run_type) VALUES ('full') RETURNING id"
-             ).fetchone()
-             run_id = r[0] if r else None
-             conn.commit()
-
-     try:
-         # Phase 1: Duplicate detection
-         if detect_duplicates:
-             from .extractors.dedup import (
-                 approve_merge,
-                 create_pending_merge,
-                 find_duplicate_entities,
-             )
-
-             pairs = find_duplicate_entities(db_url)
-             result.duplicates_found = len(pairs)
-
-             if not dry_run:
-                 for pair in pairs:
-                     if pair.confidence >= auto_merge_threshold:
-                         # Auto-approve high-confidence merges
-                         merge_id = create_pending_merge(db_url, pair)
-                         if merge_id and approve_merge(db_url, merge_id):
-                             result.merges_auto_approved += 1
-                     else:
-                         # Create pending for review
-                         if create_pending_merge(db_url, pair):
-                             result.merges_pending += 1
-
-         # Phase 2: Cross-document detection
-         if detect_cross_doc:
-             from .extractors.cross_doc import (
-                 classify_candidates,
-                 find_cross_document_entities,
-                 store_candidates,
-             )
-
-             candidates = find_cross_document_entities(db_url)
-             if candidates:
-                 # Classify with LLM
-                 classify_candidates(candidates, db_url)
-                 result.cross_doc_candidates = len(candidates)
-
-                 if not dry_run:
-                     store_candidates(db_url, candidates)
-
-         # Phase 3: Topic clustering
-         if build_clusters:
-             clusters = build_topic_clusters(db_url, dry_run=dry_run)
-             result.clusters_created = len(clusters)
-
-         # Phase 4: Entity relationships
-         if extract_relationships:
-             relationships = extract_entity_relationships(db_url, dry_run=dry_run)
-             result.relationships_found = len(relationships)
-
-         result.completed_at = datetime.now(UTC)
-
-         # Update run record
-         if run_id and not dry_run:
-             with psycopg.connect(db_url) as conn:
-                 conn.execute(
-                     """
-                     UPDATE consolidation_runs
-                     SET completed_at = NOW(),
-                         stats = %s
-                     WHERE id = %s
-                     """,
-                     (
-                         psycopg.types.json.Json(
-                             {
-                                 "duplicates_found": result.duplicates_found,
-                                 "merges_pending": result.merges_pending,
-                                 "merges_auto_approved": result.merges_auto_approved,
-                                 "cross_doc_candidates": result.cross_doc_candidates,
-                                 "clusters_created": result.clusters_created,
-                                 "relationships_found": result.relationships_found,
-                             }
-                         ),
-                         run_id,
-                     ),
-                 )
-                 conn.commit()
-
-     except Exception as e:
-         result.errors.append(str(e))
-         if run_id and not dry_run:
-             with psycopg.connect(db_url) as conn:
-                 conn.execute(
-                     "UPDATE consolidation_runs SET error = %s WHERE id = %s",
-                     (str(e), run_id),
-                 )
-                 conn.commit()
-
-     return result
-
-
- def build_topic_clusters(
-     db_url: str,
-     n_clusters: int | None = None,
-     min_cluster_size: int = 3,
-     dry_run: bool = False,
- ) -> list[TopicCluster]:
-     """Build topic clusters from entity embeddings using k-means.
-
-     Args:
-         db_url: Database URL
-         n_clusters: Number of clusters (auto-determined if None)
-         min_cluster_size: Minimum entities per cluster
-         dry_run: Don't save to database
-
-     Returns:
-         List of TopicCluster objects
-     """
-     from . import complete
-
-     clusters: list[TopicCluster] = []
-
-     with psycopg.connect(db_url, row_factory=dict_row) as conn:
-         from pgvector.psycopg import register_vector
-
-         register_vector(conn)
-
-         # Get entity embeddings
-         entities = conn.execute(
-             """
-             SELECT d.id, d.title, d.metadata->>'entity_type' as entity_type,
-                    (SELECT embedding FROM chunks WHERE document_id = d.id LIMIT 1) as embedding
-             FROM documents d
-             WHERE d.source_type = 'entity'
-               AND EXISTS (SELECT 1 FROM chunks WHERE document_id = d.id)
-             """
-         ).fetchall()
-
-         if len(entities) < min_cluster_size:
-             return []
-
-         # Auto-determine cluster count: sqrt(n/2), min 2, max 20
-         if n_clusters is None:
-             n_clusters = max(2, min(20, int((len(entities) / 2) ** 0.5)))
-
-         # Simple k-means using PostgreSQL
-         # Initialize centroids with random entities
-         import random
-
-         centroid_entities = random.sample(list(entities), min(n_clusters, len(entities)))
-         centroids = [e["embedding"] for e in centroid_entities if e["embedding"]]
-
-         if len(centroids) < 2:
-             return []
-
-         # Assign entities to nearest centroid
-         entity_clusters: dict[int, list[dict]] = {i: [] for i in range(len(centroids))}
-
-         for entity in entities:
-             if entity["embedding"] is None:
-                 continue
-
-             # Find nearest centroid
-             best_cluster = 0
-             best_distance = float("inf")
-             for i, centroid in enumerate(centroids):
-                 result = conn.execute(
-                     "SELECT %s::vector <=> %s::vector as dist",
-                     (entity["embedding"], centroid),
-                 ).fetchone()
-                 dist = result["dist"]
-                 if dist < best_distance:
-                     best_distance = dist
-                     best_cluster = i
-
-             entity_clusters[best_cluster].append(
-                 {
-                     "id": str(entity["id"]),
-                     "name": entity["title"],
-                     "type": entity["entity_type"],
-                     "distance": best_distance,
-                     "embedding": entity["embedding"],
-                 }
-             )
-
-         # Filter clusters by size and create
-         for cluster_idx, members in entity_clusters.items():
-             if len(members) < min_cluster_size:
-                 continue
-
-             # Calculate centroid as average of member embeddings
-             if not members:
-                 continue
-
-             # Get cluster name from LLM
-             entity_list = "\n".join(
-                 f"- {m['name']} ({m['type']})" for m in sorted(members, key=lambda x: x["distance"])
-             )
-             prompt = CLUSTER_NAMING_USER.format(entity_list=entity_list)
-             response = complete(prompt, system=CLUSTER_NAMING_SYSTEM, max_tokens=256, use_cache=True)
-
-             cluster_name = f"Cluster {cluster_idx + 1}"
-             cluster_desc = None
-             if response:
-                 try:
-                     content = response.content.strip()
-                     if content.startswith("```"):
-                         lines = content.split("\n")
-                         content = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
-                     data = json.loads(content)
-                     cluster_name = data.get("name", cluster_name)
-                     cluster_desc = data.get("description")
-                 except json.JSONDecodeError:
-                     pass
-
-             # Calculate centroid (average embedding)
-             # For simplicity, use first member's embedding as proxy
-             centroid = members[0]["embedding"]
-
-             cluster = TopicCluster(
-                 id="",
-                 name=cluster_name,
-                 description=cluster_desc,
-                 member_count=len(members),
-                 entities=[
-                     {"id": m["id"], "name": m["name"], "type": m["type"], "distance": m["distance"]}
-                     for m in members
-                 ],
-                 documents=[],
-             )
-
-             if not dry_run:
-                 # Save cluster
-                 result = conn.execute(
-                     """
-                     INSERT INTO topic_clusters (name, description, centroid, member_count)
-                     VALUES (%s, %s, %s, %s)
-                     RETURNING id
-                     """,
-                     (cluster_name, cluster_desc, centroid, len(members)),
-                 ).fetchone()
-
-                 if result:
-                     cluster.id = str(result["id"])
-
-                     # Add members
-                     for m in members:
-                         conn.execute(
-                             """
-                             INSERT INTO topic_cluster_members
-                                 (cluster_id, document_id, distance, is_entity)
-                             VALUES (%s, %s, %s, TRUE)
-                             ON CONFLICT DO NOTHING
-                             """,
-                             (result["id"], m["id"], m["distance"]),
-                         )
-
-                 conn.commit()
-
-             clusters.append(cluster)
-
-     return clusters
-
-
- def extract_entity_relationships(
-     db_url: str,
-     entity_ids: list[str] | None = None,
-     dry_run: bool = False,
- ) -> list[EntityRelationship]:
-     """Extract relationships between entities that co-occur in documents.
-
-     Args:
-         db_url: Database URL
-         entity_ids: Filter to specific entities (None = all)
-         dry_run: Don't save to database
-
-     Returns:
-         List of EntityRelationship objects
-     """
-     from . import complete
-
-     relationships: list[EntityRelationship] = []
-
-     with psycopg.connect(db_url, row_factory=dict_row) as conn:
-         # Find entity pairs that co-occur in documents
-         sql = """
-             SELECT
-                 e1.id as e1_id, e1.title as e1_name,
-                 e1.metadata->>'entity_type' as e1_type,
-                 e2.id as e2_id, e2.title as e2_name,
-                 e2.metadata->>'entity_type' as e2_type,
-                 COUNT(DISTINCT r1.document_id) as shared_docs
-             FROM entity_refs r1
-             JOIN entity_refs r2 ON r1.document_id = r2.document_id
-             JOIN documents e1 ON e1.id = r1.entity_id
-             JOIN documents e2 ON e2.id = r2.entity_id
-             WHERE r1.entity_id < r2.entity_id -- Avoid duplicates
-               AND e1.source_type = 'entity'
-               AND e2.source_type = 'entity'
-         """
-         params: list[Any] = []
-
-         if entity_ids:
-             sql += " AND (r1.entity_id = ANY(%s) OR r2.entity_id = ANY(%s))"
-             params.extend([entity_ids, entity_ids])
-
-         sql += """
-             GROUP BY e1.id, e1.title, e1.metadata->>'entity_type',
-                      e2.id, e2.title, e2.metadata->>'entity_type'
-             HAVING COUNT(DISTINCT r1.document_id) >= 2
-             ORDER BY shared_docs DESC
-             LIMIT 50
-         """
-
-         pairs = conn.execute(sql, params).fetchall()
-
-         if not pairs:
-             return []
-
-         # Format for LLM
-         pair_lines = []
-         pair_map = {}
-         for p in pairs:
-             key = f"{p['e1_name']}|{p['e2_name']}"
-             pair_lines.append(
-                 f"- {p['e1_name']} ({p['e1_type']}) <-> "
-                 f"{p['e2_name']} ({p['e2_type']}) "
-                 f"[{p['shared_docs']} shared docs]"
-             )
-             pair_map[key] = p
-
-         prompt = RELATIONSHIP_USER.format(entity_pairs="\n".join(pair_lines))
-         response = complete(prompt, system=RELATIONSHIP_SYSTEM, max_tokens=2048, use_cache=True)
-
-         if response is None:
-             return []
-
-         # Parse response
-         try:
-             content = response.content.strip()
-             if content.startswith("```"):
-                 lines = content.split("\n")
-                 content = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
-             data = json.loads(content)
-         except json.JSONDecodeError:
-             return []
-
-         for rel in data.get("relationships", []):
-             source_name = rel.get("source", "")
-             target_name = rel.get("target", "")
-             rel_type = rel.get("type", "related_to")
-             confidence = rel.get("confidence", 0.5)
-             reason = rel.get("reason")
-
-             # Find matching pair
-             key = f"{source_name}|{target_name}"
-             key_rev = f"{target_name}|{source_name}"
-             pair = pair_map.get(key) or pair_map.get(key_rev)
-
-             if pair:
-                 # Ensure source/target order matches pair
-                 if key_rev in pair_map:
-                     source_name, target_name = target_name, source_name
-
-                 relationship = EntityRelationship(
-                     id="",
-                     source_entity={
-                         "id": str(pair["e1_id"]),
-                         "name": pair["e1_name"],
-                         "type": pair["e1_type"],
-                     },
-                     target_entity={
-                         "id": str(pair["e2_id"]),
-                         "name": pair["e2_name"],
-                         "type": pair["e2_type"],
-                     },
-                     relationship_type=rel_type,
-                     confidence=confidence,
-                     context=reason,
-                 )
-
-                 if not dry_run:
-                     result = conn.execute(
-                         """
-                         INSERT INTO entity_relationships
-                             (source_entity_id, target_entity_id, relationship_type,
-                              confidence, context)
-                         VALUES (%s, %s, %s, %s, %s)
-                         ON CONFLICT (source_entity_id, target_entity_id, relationship_type)
-                         DO UPDATE SET confidence = EXCLUDED.confidence
-                         RETURNING id
-                         """,
-                         (pair["e1_id"], pair["e2_id"], rel_type, confidence, reason),
-                     ).fetchone()
-                     if result:
-                         relationship.id = str(result["id"])
-
-                 relationships.append(relationship)
-
-         if not dry_run:
-             conn.commit()
-
-     return relationships
-
-
- def get_topic_clusters(db_url: str, limit: int = 20) -> list[dict]:
-     """Get topic clusters with their members.
-
-     Returns list of cluster dicts.
-     """
-     with psycopg.connect(db_url, row_factory=dict_row) as conn:
-         clusters = conn.execute(
-             """
-             SELECT id, name, description, member_count, created_at
-             FROM topic_clusters
-             ORDER BY member_count DESC
-             LIMIT %s
-             """,
-             (limit,),
-         ).fetchall()
-
-         result = []
-         for c in clusters:
-             # Get members
-             members = conn.execute(
-                 """
-                 SELECT d.id, d.title, d.source_type, m.distance, m.is_entity
-                 FROM topic_cluster_members m
-                 JOIN documents d ON d.id = m.document_id
-                 WHERE m.cluster_id = %s
-                 ORDER BY m.distance
-                 LIMIT 20
-                 """,
-                 (c["id"],),
-             ).fetchall()
-
-             result.append(
-                 {
-                     "id": str(c["id"]),
-                     "name": c["name"],
-                     "description": c["description"],
-                     "member_count": c["member_count"],
-                     "members": [
-                         {
-                             "id": str(m["id"]),
-                             "title": m["title"],
-                             "type": m["source_type"],
-                             "is_entity": m["is_entity"],
-                             "distance": m["distance"],
-                         }
-                         for m in members
-                     ],
-                 }
-             )
-
-         return result
-
-
- def get_entity_relationships(
-     db_url: str,
-     entity_name: str | None = None,
-     relationship_type: str | None = None,
-     limit: int = 50,
- ) -> list[dict]:
-     """Get entity relationships.
-
-     Args:
-         db_url: Database URL
-         entity_name: Filter to relationships involving this entity
-         relationship_type: Filter by relationship type (works_for, uses, belongs_to, related_to)
-         limit: Maximum results
-
-     Returns:
-         List of relationship dicts.
-     """
-     with psycopg.connect(db_url, row_factory=dict_row) as conn:
-         sql = """
-             SELECT
-                 r.id, r.relationship_type, r.confidence, r.context,
-                 s.id as source_id, s.title as source_name,
-                 s.metadata->>'entity_type' as source_type,
-                 t.id as target_id, t.title as target_name,
-                 t.metadata->>'entity_type' as target_type
-             FROM entity_relationships r
-             JOIN documents s ON s.id = r.source_entity_id
-             JOIN documents t ON t.id = r.target_entity_id
-             WHERE 1=1
-         """
-         params: list[Any] = []
-
-         if entity_name:
-             sql += " AND (LOWER(s.title) = LOWER(%s) OR LOWER(t.title) = LOWER(%s))"
-             params.extend([entity_name, entity_name])
-
-         if relationship_type:
-             sql += " AND r.relationship_type = %s"
-             params.append(relationship_type)
-
-         sql += " ORDER BY r.confidence DESC LIMIT %s"
-         params.append(limit)
-
-         results = conn.execute(sql, params).fetchall()
-
-         return [
-             {
-                 "id": str(r["id"]),
-                 "source": {
-                     "id": str(r["source_id"]),
-                     "name": r["source_name"],
-                     "type": r["source_type"],
-                 },
-                 "target": {
-                     "id": str(r["target_id"]),
-                     "name": r["target_name"],
-                     "type": r["target_type"],
-                 },
-                 "type": r["relationship_type"],
-                 "confidence": r["confidence"],
-                 "context": r["context"],
-             }
-             for r in results
-         ]
-
-
- def format_consolidation_result(result: ConsolidationResult) -> str:
-     """Format consolidation result for display."""
-     lines = ["## Consolidation Results\n"]
-
-     if result.duplicates_found:
-         lines.append(f"**Duplicate Detection:** {result.duplicates_found} potential duplicates found")
-         if result.merges_auto_approved:
-             lines.append(f" - {result.merges_auto_approved} auto-approved (high confidence)")
-         if result.merges_pending:
-             lines.append(f" - {result.merges_pending} pending review")
-
-     if result.cross_doc_candidates:
-         lines.append(
-             f"**Cross-Document Entities:** {result.cross_doc_candidates} candidates detected"
-         )
-
-     if result.clusters_created:
-         lines.append(f"**Topic Clusters:** {result.clusters_created} clusters created")
-
-     if result.relationships_found:
-         lines.append(f"**Entity Relationships:** {result.relationships_found} relationships found")
-
-     if result.errors:
-         lines.append(f"\n**Errors:** {len(result.errors)}")
-         for err in result.errors[:5]:
-             lines.append(f" - {err}")
-
-     if result.completed_at:
-         duration = (result.completed_at - result.started_at).total_seconds()
-         lines.append(f"\nCompleted in {duration:.1f}s")
-
-     return "\n".join(lines)
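For context, the module removed above exposed `run_consolidation` as its entry point and `format_consolidation_result` for reporting. A minimal sketch of how the 1.1.0 module could be driven, assuming a PostgreSQL DSN is available; the `OKB_DB_URL` environment variable name and the driver script itself are illustrative, not part of the package:

```python
import os

from okb.llm.consolidate import format_consolidation_result, run_consolidation

# Hypothetical driver; the env var name is an assumption, not an okb convention.
db_url = os.environ["OKB_DB_URL"]  # e.g. a postgresql:// DSN

# Preview first: dry_run reports what would change without writing to the database.
preview = run_consolidation(db_url, dry_run=True)
print(format_consolidation_result(preview))

# Then apply, auto-approving only merges at or above 0.95 confidence.
applied = run_consolidation(db_url, auto_merge_threshold=0.95)
print(format_consolidation_result(applied))
```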