rdf-starbase 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rdf_starbase/store.py ADDED
@@ -0,0 +1,1049 @@
1
+ """
2
+ Unified TripleStore implementation using FactStore/TermDict.
3
+
4
+ This refactored TripleStore uses the dictionary-encoded integer storage
5
+ internally while maintaining backward compatibility with the existing API.
6
+ The SPARQL executor and AI grounding layer continue to work unchanged.
7
+
8
+ Key design:
9
+ - FactStore holds facts as integer IDs (g, s, p, o columns)
10
+ - TermDict maps RDF terms to/from integer IDs
11
+ - _df property materializes a string-based view for SPARQL executor
12
+ - Reasoner can now work directly with the integer-based storage
13
+ """
14
+
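+ # Illustrative usage sketch (comments only, not part of the module): terms are
+ # interned to integer IDs, a fact row is appended to the FactStore, and the
+ # string-based view is rebuilt lazily through the _df property.
+ #
+ #     store = TripleStore()
+ #     prov = ProvenanceContext(source="doc-42", confidence=0.9, process="manual")
+ #     store.add_triple("http://ex.org/alice", "http://ex.org/knows",
+ #                      "http://ex.org/bob", provenance=prov)
+ #     store.get_triples(subject="http://ex.org/alice")  # -> Polars DataFrame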
15
+ from datetime import datetime, timezone
16
+ from typing import Optional, Any, Literal
17
+ from uuid import UUID, uuid4
18
+ from pathlib import Path
19
+
20
+ import polars as pl
21
+
22
+ from rdf_starbase.models import Triple, QuotedTriple, Assertion, ProvenanceContext
23
+ from rdf_starbase.storage.terms import TermDict, TermKind, Term, TermId, get_term_kind
24
+ from rdf_starbase.storage.quoted_triples import QtDict
25
+ from rdf_starbase.storage.facts import FactStore, FactFlags, DEFAULT_GRAPH_ID
26
+
27
+
28
+ class TripleStore:
29
+ """
30
+ A high-performance RDF-Star triple store backed by dictionary-encoded Polars DataFrames.
31
+
32
+ Unified architecture:
33
+ - All terms are dictionary-encoded to integer IDs (TermDict)
34
+ - Facts are stored as integer tuples for maximum join performance (FactStore)
35
+ - String-based views are materialized on demand for SPARQL compatibility
36
+ - Reasoner works directly on integer storage for efficient inference
37
+ """
38
+
39
+ def __init__(self):
40
+ """Initialize an empty triple store with unified storage."""
41
+ # Core storage components
42
+ self._term_dict = TermDict()
43
+ self._qt_dict = QtDict(self._term_dict)
44
+ self._fact_store = FactStore(self._term_dict, self._qt_dict)
45
+
46
+ # Cache for the string-based DataFrame view
47
+ self._df_cache: Optional[pl.DataFrame] = None
48
+ self._df_cache_valid = False
49
+
50
+ # Mapping from assertion UUID to (s_id, p_id, o_id, g_id) for deprecation
51
+ self._assertion_map: dict[UUID, tuple[TermId, TermId, TermId, TermId]] = {}
52
+
53
+ # Quoted triple references (for backward compatibility)
54
+ self._quoted_triples: dict[UUID, QuotedTriple] = {}
55
+
56
+ # Pre-intern common predicates and well-known IRIs
57
+ self._init_common_terms()
58
+
59
+ def _init_common_terms(self):
60
+ """Pre-intern commonly used terms for performance."""
61
+ # RDF vocabulary
62
+ self._rdf_type_id = self._term_dict.intern_iri(
63
+ "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
64
+ )
65
+ # RDFS vocabulary
66
+ self._rdfs_label_id = self._term_dict.intern_iri(
67
+ "http://www.w3.org/2000/01/rdf-schema#label"
68
+ )
69
+ self._rdfs_subclass_id = self._term_dict.intern_iri(
70
+ "http://www.w3.org/2000/01/rdf-schema#subClassOf"
71
+ )
72
+
73
+ def _invalidate_cache(self):
74
+ """Invalidate the cached DataFrame view after modifications."""
75
+ self._df_cache_valid = False
76
+
77
+ def _intern_term(self, value: Any, is_uri_hint: bool = False) -> TermId:
78
+ """
79
+ Intern a term value to a TermId.
80
+
81
+ Args:
82
+ value: The term value (string, number, bool, etc.)
83
+ is_uri_hint: If True, treat string as IRI; otherwise infer
84
+
85
+ Returns:
86
+ TermId for the interned term
87
+ """
88
+ if isinstance(value, str):
89
+ # Check if it looks like a URI
90
+ if is_uri_hint or value.startswith(("http://", "https://", "urn:", "file://")):
91
+ return self._term_dict.intern_iri(value)
92
+ elif value.startswith("_:"):
93
+ return self._term_dict.intern_bnode(value[2:])
94
+ else:
95
+ # Parse RDF literal syntax: "value"^^<datatype> or "value"@lang or "value"
96
+ return self._intern_literal_string(value)
97
+ elif isinstance(value, bool):
98
+ return self._term_dict.intern_literal(str(value).lower(),
99
+ datatype="http://www.w3.org/2001/XMLSchema#boolean")
100
+ elif isinstance(value, int):
101
+ return self._term_dict.intern_literal(str(value),
102
+ datatype="http://www.w3.org/2001/XMLSchema#integer")
103
+ elif isinstance(value, float):
104
+ return self._term_dict.intern_literal(str(value),
105
+ datatype="http://www.w3.org/2001/XMLSchema#decimal")
106
+ else:
107
+ return self._term_dict.intern_literal(str(value))
108
+
109
+ def _intern_literal_string(self, value: str) -> TermId:
110
+ """
111
+ Parse and intern a string that may be in RDF literal syntax.
112
+
113
+ Handles:
114
+ - "value"^^<http://...> -> typed literal
115
+ - "value"@en -> language-tagged literal
116
+ - "value" -> plain literal (xsd:string)
117
+ - value -> plain literal (no quotes)
118
+ """
119
+ # Check for typed literal: "value"^^<datatype>
120
+ if value.startswith('"') and '^^<' in value:
121
+ # Find the closing quote before ^^
122
+ caret_pos = value.find('^^<')
123
+ if caret_pos > 0 and value[caret_pos-1] == '"':
124
+ lex = value[1:caret_pos-1] # Extract value between quotes
125
+ datatype = value[caret_pos+3:-1] # Extract datatype IRI (strip < and >)
126
+ return self._term_dict.intern_literal(lex, datatype=datatype)
127
+
128
+ # Check for language-tagged literal: "value"@lang
129
+ if value.startswith('"') and '"@' in value:
130
+ at_pos = value.rfind('"@')
131
+ if at_pos > 0:
132
+ lex = value[1:at_pos] # Extract value between quotes
133
+ lang = value[at_pos+2:] # Extract language tag
134
+ return self._term_dict.intern_literal(lex, lang=lang)
135
+
136
+ # Check for quoted plain literal: "value"
137
+ if value.startswith('"') and value.endswith('"') and len(value) >= 2:
138
+ lex = value[1:-1] # Strip quotes
139
+ return self._term_dict.intern_literal(lex)
140
+
141
+ # Unquoted plain literal
142
+ return self._term_dict.intern_literal(value)
143
+
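+     # Dispatch sketch for _intern_literal_string() (TermIds themselves are opaque,
+     # so only the resulting intern_literal() calls are shown):
+     #
+     #     '"42"^^<http://www.w3.org/2001/XMLSchema#integer>'
+     #         -> intern_literal("42", datatype="http://www.w3.org/2001/XMLSchema#integer")
+     #     '"chat"@fr'  -> intern_literal("chat", lang="fr")
+     #     '"plain"'    -> intern_literal("plain")
+     #     'plain'      -> intern_literal("plain")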
144
+ def _term_to_string(self, term_id: TermId) -> Optional[str]:
145
+ """Convert a TermId back to its string representation."""
146
+ term = self._term_dict.lookup(term_id)
147
+ if term is None:
148
+ return None
149
+ return term.lex
150
+
151
+ @property
152
+ def _df(self) -> pl.DataFrame:
153
+ """
154
+ Materialize the string-based DataFrame view for SPARQL executor.
155
+
156
+ This is a computed property that builds a string-column DataFrame
157
+ from the integer-based FactStore. Results are cached until invalidated.
158
+ """
159
+ if self._df_cache_valid and self._df_cache is not None:
160
+ return self._df_cache
161
+
162
+ # Get raw facts - include ALL facts (deleted too, for include_deprecated support)
163
+ fact_df = self._fact_store._df
164
+
165
+ if len(fact_df) == 0:
166
+ self._df_cache = self._create_empty_dataframe()
167
+ self._df_cache_valid = True
168
+ return self._df_cache
169
+
170
+ # Build term lookup DataFrame for fast Polars join
171
+ # Include datatype_id for typed value conversion
172
+ term_rows = [
173
+ {"term_id": tid, "lex": term.lex, "kind": int(term.kind),
174
+ "datatype_id": term.datatype_id if term.datatype_id else 0}
175
+ for tid, term in self._term_dict._id_to_term.items()
176
+ ]
177
+
178
+ if term_rows:
179
+ term_df = pl.DataFrame(term_rows).cast({
180
+ "term_id": pl.UInt64,
181
+ "lex": pl.Utf8,
182
+ "kind": pl.UInt8,
183
+ "datatype_id": pl.UInt64,
184
+ })
185
+ else:
186
+ term_df = pl.DataFrame({
187
+ "term_id": pl.Series([], dtype=pl.UInt64),
188
+ "lex": pl.Series([], dtype=pl.Utf8),
189
+ "kind": pl.Series([], dtype=pl.UInt8),
190
+ "datatype_id": pl.Series([], dtype=pl.UInt64),
191
+ })
192
+
193
+ # Rename fact columns to avoid conflicts during joins
194
+ result = fact_df.rename({
195
+ "source": "source_id",
196
+ "process": "process_id"
197
+ })
198
+
199
+ # Subject - join drops right key automatically in Polars 1.x
200
+ result = result.join(
201
+ term_df.select([pl.col("term_id"), pl.col("lex").alias("subject")]),
202
+ left_on="s", right_on="term_id", how="left"
203
+ )
204
+
205
+ # Predicate
206
+ result = result.join(
207
+ term_df.select([pl.col("term_id"), pl.col("lex").alias("predicate")]),
208
+ left_on="p", right_on="term_id", how="left"
209
+ )
210
+
211
+ # Object + object type + typed object value
212
+ # Get XSD numeric datatype IDs for typed value conversion
213
+ xsd_integer_id = self._term_dict.xsd_integer_id
214
+ xsd_decimal_id = self._term_dict.xsd_decimal_id
215
+ xsd_double_id = self._term_dict.xsd_double_id
216
+ xsd_boolean_id = self._term_dict.xsd_boolean_id
217
+
218
+ obj_df = term_df.select([
219
+ pl.col("term_id"),
220
+ pl.col("lex").alias("object"),
221
+ pl.col("datatype_id").alias("obj_datatype_id"),
222
+ pl.when(pl.col("kind") == int(TermKind.IRI)).then(pl.lit("uri"))
223
+ .when(pl.col("kind") == int(TermKind.BNODE)).then(pl.lit("bnode"))
224
+ .otherwise(pl.lit("literal"))
225
+ .alias("object_type")
226
+ ])
227
+ result = result.join(obj_df, left_on="o", right_on="term_id", how="left")
228
+
229
+ # Create typed object_value column for numeric comparisons
230
+ # Cast lex to float for numeric datatypes, null otherwise
231
+ result = result.with_columns([
232
+ pl.when(
233
+ (pl.col("obj_datatype_id") == xsd_integer_id) |
234
+ (pl.col("obj_datatype_id") == xsd_decimal_id) |
235
+ (pl.col("obj_datatype_id") == xsd_double_id)
236
+ ).then(
237
+ pl.col("object").cast(pl.Float64, strict=False)
238
+ ).when(
239
+ pl.col("obj_datatype_id") == xsd_boolean_id
240
+ ).then(
241
+ pl.when(pl.col("object") == "true").then(pl.lit(1.0))
242
+ .when(pl.col("object") == "false").then(pl.lit(0.0))
243
+ .otherwise(pl.lit(None))
244
+ ).otherwise(pl.lit(None).cast(pl.Float64))
245
+ .alias("object_value")
246
+ ])
247
+
248
+ # Graph (handle 0 = default graph)
249
+ result = result.join(
250
+ term_df.select([pl.col("term_id"), pl.col("lex").alias("graph")]),
251
+ left_on="g", right_on="term_id", how="left"
252
+ )
253
+
254
+ # Source
255
+ result = result.join(
256
+ term_df.select([pl.col("term_id"), pl.col("lex").alias("source")]),
257
+ left_on="source_id", right_on="term_id", how="left"
258
+ )
259
+
260
+ # Process
261
+ result = result.join(
262
+ term_df.select([pl.col("term_id"), pl.col("lex").alias("process")]),
263
+ left_on="process_id", right_on="term_id", how="left"
264
+ )
265
+
266
+ # Build final schema
267
+ result = result.select([
268
+ pl.lit(None).cast(pl.Utf8).alias("assertion_id"), # Deferred - generate on iteration
269
+ "subject",
270
+ "predicate",
271
+ "object",
272
+ "object_type",
273
+ "object_value", # Typed numeric value for FILTER comparisons
274
+ "graph",
275
+ pl.lit(None).cast(pl.Utf8).alias("quoted_triple_id"),
276
+ "source",
277
+ # Convert t_added microseconds to datetime
278
+             pl.col("t_added").cast(pl.Int64).cast(pl.Datetime("us", "UTC")).alias("timestamp"),
279
+ "confidence",
280
+ "process",
281
+ pl.lit(None).cast(pl.Utf8).alias("version"),
282
+ pl.lit("{}").alias("metadata"),
283
+ pl.lit(None).cast(pl.Utf8).alias("superseded_by"),
284
+ ((pl.col("flags").cast(pl.Int32) & int(FactFlags.DELETED)) != 0).alias("deprecated"),
285
+ ])
286
+
287
+         # Assign assertion IDs from the row index (avoids per-row UUID generation)
288
+ n = len(result)
289
+ if n > 0:
290
+ # Use sequential IDs for now (much faster than UUID generation)
291
+ result = result.with_columns([
292
+ pl.arange(0, n, eager=True).cast(pl.Utf8).alias("assertion_id")
293
+ ])
294
+
295
+ self._df_cache = result
296
+ self._df_cache_valid = True
297
+ return self._df_cache
298
+
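+     # Shape of the materialized view (illustrative values): one row per fact with the
+     # 16 string/typed columns selected above, e.g.
+     #
+     #     subject="http://ex.org/alice", predicate="http://ex.org/age", object="42",
+     #     object_type="literal", object_value=42.0, graph=None, source="doc-42",
+     #     confidence=1.0, deprecated=False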
299
+ @_df.setter
300
+ def _df(self, value: pl.DataFrame):
301
+ """
302
+ Allow direct DataFrame assignment for backward compatibility.
303
+
304
+ This is used by some internal operations that modify _df directly.
305
+ We sync changes back to the FactStore.
306
+ """
307
+ # For backward compatibility, accept direct DataFrame assignment
308
+ # This is mainly used during persistence load
309
+ self._df_cache = value
310
+ self._df_cache_valid = True
311
+ # Note: This doesn't sync to FactStore - used only for legacy load
312
+
313
+ @staticmethod
314
+ def _create_empty_dataframe() -> pl.DataFrame:
315
+ """Create the schema for the string-based assertion DataFrame."""
316
+ return pl.DataFrame({
317
+ "assertion_id": pl.Series([], dtype=pl.Utf8),
318
+ "subject": pl.Series([], dtype=pl.Utf8),
319
+ "predicate": pl.Series([], dtype=pl.Utf8),
320
+ "object": pl.Series([], dtype=pl.Utf8),
321
+ "object_type": pl.Series([], dtype=pl.Utf8),
322
+ "object_value": pl.Series([], dtype=pl.Float64), # Typed numeric value
323
+ "graph": pl.Series([], dtype=pl.Utf8),
324
+ "quoted_triple_id": pl.Series([], dtype=pl.Utf8),
325
+ "source": pl.Series([], dtype=pl.Utf8),
326
+ "timestamp": pl.Series([], dtype=pl.Datetime("us", "UTC")),
327
+ "confidence": pl.Series([], dtype=pl.Float64),
328
+ "process": pl.Series([], dtype=pl.Utf8),
329
+ "version": pl.Series([], dtype=pl.Utf8),
330
+ "metadata": pl.Series([], dtype=pl.Utf8),
331
+ "superseded_by": pl.Series([], dtype=pl.Utf8),
332
+ "deprecated": pl.Series([], dtype=pl.Boolean),
333
+ })
334
+
335
+ def add_triple(
336
+ self,
337
+ subject: str,
338
+ predicate: str,
339
+ obj: Any,
340
+ provenance: ProvenanceContext,
341
+ graph: Optional[str] = None,
342
+ ) -> UUID:
343
+ """
344
+ Add a triple with provenance to the store.
345
+
346
+ Args:
347
+ subject: Subject URI or blank node
348
+ predicate: Predicate URI
349
+ obj: Object (URI, literal, or value)
350
+ provenance: Provenance context for this assertion
351
+ graph: Optional named graph
352
+
353
+ Returns:
354
+ UUID of the created assertion
355
+ """
356
+ # Generate assertion ID upfront
357
+ assertion_id = uuid4()
358
+
359
+ # Intern all terms
360
+ s_id = self._intern_term(subject, is_uri_hint=True)
361
+ p_id = self._intern_term(predicate, is_uri_hint=True)
362
+ o_id = self._intern_term(obj)
363
+ g_id = self._term_dict.intern_iri(graph) if graph else DEFAULT_GRAPH_ID
364
+
365
+ # Intern provenance terms
366
+ source_id = self._term_dict.intern_literal(provenance.source) if provenance.source else 0
367
+ process_id = self._term_dict.intern_literal(provenance.process) if provenance.process else 0
368
+
369
+ # Convert provenance timestamp to microseconds if provided
370
+ t_added = None
371
+ if provenance.timestamp:
372
+ t_added = int(provenance.timestamp.timestamp() * 1_000_000)
373
+
374
+ # Add to fact store
375
+ self._fact_store.add_fact(
376
+ s=s_id,
377
+ p=p_id,
378
+ o=o_id,
379
+ g=g_id,
380
+ flags=FactFlags.ASSERTED,
381
+ source=source_id,
382
+ confidence=provenance.confidence,
383
+ process=process_id,
384
+ t_added=t_added,
385
+ )
386
+
387
+ # Store mapping for deprecation
388
+ self._assertion_map[assertion_id] = (s_id, p_id, o_id, g_id)
389
+
390
+ self._invalidate_cache()
391
+ return assertion_id
392
+
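+     # Example call (sketch; ProvenanceContext keyword arguments follow the usage
+     # elsewhere in this module):
+     #
+     #     aid = store.add_triple(
+     #         subject="http://ex.org/alice",
+     #         predicate="http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
+     #         obj="http://ex.org/Person",
+     #         provenance=ProvenanceContext(source="import.ttl", confidence=1.0),
+     #         graph="http://ex.org/graphs/people",
+     #     )
+     #     store.deprecate_assertion(aid)  # soft-delete it again if needed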
393
+ def add_assertion(self, assertion: Assertion) -> UUID:
394
+ """Add a complete assertion object to the store."""
395
+ return self.add_triple(
396
+ subject=assertion.triple.subject,
397
+ predicate=assertion.triple.predicate,
398
+ obj=assertion.triple.object,
399
+ provenance=assertion.provenance,
400
+ graph=assertion.triple.graph,
401
+ )
402
+
403
+ def add_triples_batch(
404
+ self,
405
+ triples: list[dict],
406
+ ) -> int:
407
+ """
408
+ Add multiple triples in a single batch operation.
409
+
410
+ This is MUCH faster than calling add_triple() repeatedly because:
411
+ - Batch term interning
412
+         - One cache invalidation for the whole batch (the fully columnar path is add_triples_columnar)
413
+
414
+ Args:
415
+ triples: List of dicts with keys:
416
+ - subject: str
417
+ - predicate: str
418
+ - object: Any
419
+ - source: str
420
+ - confidence: float (optional, default 1.0)
421
+ - process: str (optional)
422
+ - graph: str (optional)
423
+
424
+ Returns:
425
+ Number of triples added
426
+ """
427
+ if not triples:
428
+ return 0
429
+
430
+ # Prepare batch data
431
+ facts = []
433
+
434
+ for t in triples:
435
+ # Intern terms
436
+ s_id = self._intern_term(t["subject"], is_uri_hint=True)
437
+ p_id = self._intern_term(t["predicate"], is_uri_hint=True)
438
+ o_id = self._intern_term(t.get("object", ""))
439
+
440
+ graph = t.get("graph")
441
+ g_id = self._term_dict.intern_iri(graph) if graph else DEFAULT_GRAPH_ID
442
+
443
+ source = t.get("source", "unknown")
444
+ source_id = self._term_dict.intern_literal(source) if source else 0
445
+
446
+ process = t.get("process")
447
+ process_id = self._term_dict.intern_literal(process) if process else 0
448
+
449
+ confidence = t.get("confidence", 1.0)
450
+
451
+ facts.append((g_id, s_id, p_id, o_id, source_id, confidence, process_id))
452
+
453
+ # Batch insert to FactStore
454
+ for g_id, s_id, p_id, o_id, source_id, confidence, process_id in facts:
455
+ self._fact_store.add_fact(
456
+ s=s_id,
457
+ p=p_id,
458
+ o=o_id,
459
+ g=g_id,
460
+ flags=FactFlags.ASSERTED,
461
+ source=source_id,
462
+ confidence=confidence,
463
+ process=process_id,
464
+ )
465
+
466
+ self._invalidate_cache()
467
+ return len(triples)
468
+
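+     # Example batch (sketch; dict keys as documented above):
+     #
+     #     store.add_triples_batch([
+     #         {"subject": "http://ex.org/a", "predicate": "http://ex.org/p",
+     #          "object": 1, "source": "csv-import", "confidence": 0.8},
+     #         {"subject": "http://ex.org/b", "predicate": "http://ex.org/p",
+     #          "object": '"hello"@en', "source": "csv-import"},
+     #     ])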
469
+ def add_triples_columnar(
470
+ self,
471
+ subjects: list[str],
472
+ predicates: list[str],
473
+ objects: list[Any],
474
+ source: str = "unknown",
475
+ confidence: float = 1.0,
476
+ graph: Optional[str] = None,
477
+ ) -> int:
478
+ """
479
+ Add triples from column lists (TRUE vectorized path).
480
+
481
+ This is the FASTEST ingestion method. Pass pre-built lists
482
+ of subjects, predicates, and objects.
483
+
484
+ Args:
485
+ subjects: List of subject URIs
486
+ predicates: List of predicate URIs
487
+ objects: List of object values
488
+ source: Shared source for provenance
489
+ confidence: Shared confidence score
490
+ graph: Optional graph URI
491
+
492
+ Returns:
493
+ Number of triples added
494
+ """
495
+ n = len(subjects)
496
+ if n == 0:
497
+ return 0
498
+
499
+ # Batch intern terms
500
+ g_id = self._term_dict.intern_iri(graph) if graph else DEFAULT_GRAPH_ID
501
+ source_id = self._term_dict.intern_literal(source)
502
+
503
+ # Intern subjects (all URIs)
504
+ s_col = [self._term_dict.intern_iri(s) for s in subjects]
505
+
506
+ # Intern predicates (all URIs)
507
+ p_col = [self._term_dict.intern_iri(p) for p in predicates]
508
+
509
+ # Intern objects (could be literals or URIs)
510
+ o_col = [self._intern_term(o) for o in objects]
511
+
512
+ # Graph column
513
+ g_col = [g_id] * n
514
+
515
+ # Use columnar insert
516
+ self._fact_store.add_facts_columnar(
517
+ g_col=g_col,
518
+ s_col=s_col,
519
+ p_col=p_col,
520
+ o_col=o_col,
521
+ flags=FactFlags.ASSERTED,
522
+ source=source_id,
523
+ confidence=confidence,
524
+ )
525
+
526
+ self._invalidate_cache()
527
+ return n
528
+
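+     # Columnar ingestion sketch: the three lists must be the same length; source and
+     # confidence are shared by every row.
+     #
+     #     store.add_triples_columnar(
+     #         subjects=["http://ex.org/a", "http://ex.org/b"],
+     #         predicates=["http://ex.org/p"] * 2,
+     #         objects=[3.14, "http://ex.org/c"],
+     #         source="bulk-load",
+     #         graph="http://ex.org/graphs/bulk",
+     #     )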
529
+ def get_triples(
530
+ self,
531
+ subject: Optional[str] = None,
532
+ predicate: Optional[str] = None,
533
+ obj: Optional[str] = None,
534
+ graph: Optional[str] = None,
535
+ source: Optional[str] = None,
536
+ min_confidence: float = 0.0,
537
+ include_deprecated: bool = False,
538
+ ) -> pl.DataFrame:
539
+ """
540
+ Query triples with optional filters.
541
+
542
+ Uses the string-based _df view for compatibility with existing code.
543
+ """
544
+ df = self._df.lazy()
545
+
546
+ if subject is not None:
547
+ df = df.filter(pl.col("subject") == subject)
548
+ if predicate is not None:
549
+ df = df.filter(pl.col("predicate") == predicate)
550
+ if obj is not None:
551
+ df = df.filter(pl.col("object") == str(obj))
552
+ if graph is not None:
553
+ df = df.filter(pl.col("graph") == graph)
554
+ if source is not None:
555
+ df = df.filter(pl.col("source") == source)
556
+ if min_confidence is not None:
557
+ df = df.filter(pl.col("confidence") >= min_confidence)
558
+ if not include_deprecated:
559
+ df = df.filter(~pl.col("deprecated"))
560
+
561
+ return df.collect()
562
+
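+     # Query sketch: all filters are optional and combined with AND; the result is a
+     # Polars DataFrame in the _df view schema.
+     #
+     #     df = store.get_triples(
+     #         predicate="http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
+     #         min_confidence=0.5,
+     #     )
+     #     df.select(["subject", "object", "confidence"])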
563
+ def get_competing_claims(
564
+ self,
565
+ subject: str,
566
+ predicate: str,
567
+ ) -> pl.DataFrame:
568
+ """Find competing assertions about the same subject-predicate pair."""
569
+ df = self.get_triples(subject=subject, predicate=predicate, include_deprecated=False)
570
+ df = df.sort(["confidence", "timestamp"], descending=[True, True])
571
+ return df
572
+
573
+ def deprecate_assertion(self, assertion_id: UUID, superseded_by: Optional[UUID] = None) -> None:
574
+ """Mark an assertion as deprecated."""
575
+ # Look up the assertion in our mapping
576
+ if assertion_id in self._assertion_map:
577
+ s_id, p_id, o_id, g_id = self._assertion_map[assertion_id]
578
+ self._fact_store.mark_deleted(s=s_id, p=p_id, o=o_id)
579
+ self._invalidate_cache()
580
+ return
581
+
582
+ # Fallback: try to find in cached DataFrame
583
+ if self._df_cache is not None and len(self._df_cache) > 0:
584
+ matching = self._df_cache.filter(pl.col("assertion_id") == str(assertion_id))
585
+ if len(matching) > 0:
586
+ # Mark in cache
587
+ self._df_cache = self._df_cache.with_columns([
588
+ pl.when(pl.col("assertion_id") == str(assertion_id))
589
+ .then(True)
590
+ .otherwise(pl.col("deprecated"))
591
+ .alias("deprecated"),
592
+
593
+ pl.when(pl.col("assertion_id") == str(assertion_id))
594
+                 .then(pl.lit(str(superseded_by)) if superseded_by else pl.lit(None, dtype=pl.Utf8))
595
+ .otherwise(pl.col("superseded_by"))
596
+ .alias("superseded_by"),
597
+ ])
598
+
599
+ # Also need to mark in FactStore
600
+ subject = matching["subject"][0]
601
+ predicate = matching["predicate"][0]
602
+ obj = matching["object"][0]
603
+
604
+ s_id = self._term_dict.lookup_iri(subject)
605
+ p_id = self._term_dict.lookup_iri(predicate)
606
+ o_id = self._term_dict.lookup_iri(obj)
607
+ if o_id is None:
608
+ o_id = self._term_dict.lookup_literal(obj)
609
+
610
+ if s_id is not None and p_id is not None and o_id is not None:
611
+ self._fact_store.mark_deleted(s=s_id, p=p_id, o=o_id)
612
+ self._invalidate_cache()
613
+
614
+ def get_provenance_timeline(self, subject: str, predicate: str) -> pl.DataFrame:
615
+ """Get the full history of assertions about a subject-predicate pair."""
616
+ df = self.get_triples(
617
+ subject=subject,
618
+ predicate=predicate,
619
+ include_deprecated=True
620
+ )
621
+ df = df.sort("timestamp")
622
+ return df
623
+
624
+ def mark_deleted(
625
+ self,
626
+ s: Optional[str] = None,
627
+ p: Optional[str] = None,
628
+ o: Optional[str] = None
629
+ ) -> int:
630
+ """
631
+ Mark matching triples as deprecated (soft delete).
632
+
633
+ Works on the FactStore level for correctness.
634
+ """
635
+ # Look up term IDs (if they don't exist, no triples to delete)
636
+ s_id = None
637
+ p_id = None
638
+ o_id = None
639
+
640
+ if s is not None:
641
+ s_id = self._term_dict.lookup_iri(s)
642
+ if s_id is None:
643
+ return 0
644
+ if p is not None:
645
+ p_id = self._term_dict.lookup_iri(p)
646
+ if p_id is None:
647
+ return 0
648
+ if o is not None:
649
+ # Try as IRI first, then literal
650
+ o_id = self._term_dict.lookup_iri(o)
651
+ if o_id is None:
652
+ o_id = self._term_dict.lookup_literal(o)
653
+ if o_id is None:
654
+ return 0
655
+
656
+ count = self._fact_store.mark_deleted(s=s_id, p=p_id, o=o_id)
657
+ self._invalidate_cache()
658
+ return count
659
+
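+     # Soft-delete sketch: omitted positions act as wildcards, so any combination of
+     # s/p/o narrows the match.
+     #
+     #     store.mark_deleted(s="http://ex.org/alice")   # everything asserted about alice
+     #     store.mark_deleted(p="http://ex.org/age")     # every age fact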
660
+ def save(self, path: Path | str) -> None:
661
+ """
662
+ Save the triple store to disk.
663
+
664
+ Saves all components: TermDict, QtDict, FactStore.
665
+ """
666
+ path = Path(path)
667
+ path.parent.mkdir(parents=True, exist_ok=True)
668
+
669
+ # Create a directory for the unified store
670
+ store_dir = path.parent / (path.stem + "_unified")
671
+ store_dir.mkdir(parents=True, exist_ok=True)
672
+
673
+ # Save components
674
+ self._term_dict.save(store_dir / "terms")
675
+ self._qt_dict.save(store_dir / "quoted_triples")
676
+ self._fact_store.save(store_dir / "facts")
677
+
678
+ # Also save the legacy format for backward compatibility
679
+ self._df.write_parquet(path)
680
+
681
+ @classmethod
682
+ def load(cls, path: Path | str) -> "TripleStore":
683
+ """
684
+ Load a triple store from disk.
685
+
686
+ Attempts to load unified format first, falls back to legacy.
687
+ """
688
+ path = Path(path)
689
+ store_dir = path.parent / (path.stem + "_unified")
690
+
691
+ if store_dir.exists():
692
+ # Load unified format
693
+ store = cls()
694
+ store._term_dict = TermDict.load(store_dir / "terms")
695
+ store._qt_dict = QtDict.load(store_dir / "quoted_triples", store._term_dict)
696
+ store._fact_store = FactStore.load(
697
+ store_dir / "facts",
698
+ store._term_dict,
699
+ store._qt_dict
700
+ )
701
+ return store
702
+ else:
703
+ # Load legacy format and convert
704
+ store = cls()
705
+ legacy_df = pl.read_parquet(path)
706
+
707
+ # Import each row
708
+ for row in legacy_df.iter_rows(named=True):
709
+ if not row.get("deprecated", False):
710
+ prov = ProvenanceContext(
711
+ source=row.get("source", "legacy"),
712
+ confidence=row.get("confidence", 1.0),
713
+ process=row.get("process"),
714
+ timestamp=row.get("timestamp", datetime.now(timezone.utc)),
715
+ )
716
+ store.add_triple(
717
+ subject=row["subject"],
718
+ predicate=row["predicate"],
719
+ obj=row["object"],
720
+ provenance=prov,
721
+ graph=row.get("graph"),
722
+ )
723
+
724
+ return store
725
+
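+     # Persistence round-trip sketch: save() writes both the unified directory
+     # (<stem>_unified/ with terms, quoted_triples and facts) and a legacy Parquet
+     # file at the given path; load() prefers the unified directory when it exists.
+     #
+     #     store.save("data/kb.parquet")            # also creates data/kb_unified/
+     #     restored = TripleStore.load("data/kb.parquet")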
726
+ def stats(self) -> dict[str, Any]:
727
+ """Get statistics about the triple store."""
728
+ fact_stats = self._fact_store.stats()
729
+ term_stats = self._term_dict.stats()
730
+
731
+ # Get unique subjects and predicates from actual facts (not deleted)
732
+ active_facts = self._fact_store._df.filter(
733
+ (pl.col("flags").cast(pl.Int32) & int(FactFlags.DELETED)) == 0
734
+ )
735
+ unique_subjects = active_facts.select("s").unique().height
736
+ unique_predicates = active_facts.select("p").unique().height
737
+
738
+ return {
739
+ "total_assertions": fact_stats["total_facts"],
740
+ "active_assertions": fact_stats["active_facts"],
741
+ "deprecated_assertions": fact_stats["total_facts"] - fact_stats["active_facts"],
742
+ "unique_sources": len(set(
743
+ self._term_to_string(sid)
744
+ for sid in self._fact_store._df["source"].unique().to_list()
745
+ if sid and sid != 0
746
+ )),
747
+ "unique_subjects": unique_subjects,
748
+ "unique_predicates": unique_predicates,
749
+ "term_dict": term_stats,
750
+ "fact_store": fact_stats,
751
+ }
752
+
753
+ def __len__(self) -> int:
754
+ """Return the number of active assertions."""
755
+ return self._fact_store.count_active()
756
+
757
+ def __repr__(self) -> str:
758
+ stats = self.stats()
759
+ return (
760
+ f"TripleStore("
761
+ f"assertions={stats['active_assertions']}, "
762
+ f"terms={stats['term_dict']['total_terms']})"
763
+ )
764
+
765
+ # =========================================================================
766
+ # Named Graph Management
767
+ # =========================================================================
768
+
769
+ def list_graphs(self) -> list[str]:
770
+ """List all named graphs in the store."""
771
+ # Get unique graph IDs from FactStore
772
+ graph_ids = self._fact_store._df.filter(
773
+ (pl.col("g") != DEFAULT_GRAPH_ID) &
774
+ ((pl.col("flags").cast(pl.Int32) & int(FactFlags.DELETED)) == 0)
775
+ ).select("g").unique().to_series().to_list()
776
+
777
+ graphs = []
778
+ for gid in graph_ids:
779
+ term = self._term_dict.lookup(gid)
780
+ if term is not None:
781
+ graphs.append(term.lex)
782
+
783
+ return sorted(graphs)
784
+
785
+ def create_graph(self, graph_uri: str) -> bool:
786
+ """Create an empty named graph."""
787
+ g_id = self._term_dict.intern_iri(graph_uri)
788
+ existing = self._fact_store._df.filter(
789
+ (pl.col("g") == g_id) &
790
+ ((pl.col("flags").cast(pl.Int32) & int(FactFlags.DELETED)) == 0)
791
+ ).height
792
+ return existing == 0
793
+
794
+ def drop_graph(self, graph_uri: str, silent: bool = False) -> int:
795
+ """Drop (delete) a named graph and all its triples."""
796
+ g_id = self._term_dict.lookup_iri(graph_uri)
797
+ if g_id is None:
798
+ return 0
799
+
800
+ # Mark all facts in this graph as deleted
801
+ count = 0
802
+ fact_df = self._fact_store._df
803
+ matching = fact_df.filter(
804
+ (pl.col("g") == g_id) &
805
+ ((pl.col("flags").cast(pl.Int32) & int(FactFlags.DELETED)) == 0)
806
+ )
807
+ count = matching.height
808
+
809
+ if count > 0:
810
+ # Update flags
811
+ self._fact_store._df = fact_df.with_columns([
812
+ pl.when(
813
+ (pl.col("g") == g_id) &
814
+ ((pl.col("flags").cast(pl.Int32) & int(FactFlags.DELETED)) == 0)
815
+ )
816
+ .then((pl.col("flags").cast(pl.Int32) | int(FactFlags.DELETED)).cast(pl.UInt16))
817
+ .otherwise(pl.col("flags"))
818
+ .alias("flags")
819
+ ])
820
+ self._invalidate_cache()
821
+
822
+ return count
823
+
824
+ def clear_graph(self, graph_uri: Optional[str] = None, silent: bool = False) -> int:
825
+ """Clear all triples from a graph (or default graph if None)."""
826
+ if graph_uri is None:
827
+ g_id = DEFAULT_GRAPH_ID
828
+ else:
829
+ g_id = self._term_dict.lookup_iri(graph_uri)
830
+ if g_id is None:
831
+ return 0
832
+
833
+ fact_df = self._fact_store._df
834
+ matching = fact_df.filter(
835
+ (pl.col("g") == g_id) &
836
+ ((pl.col("flags").cast(pl.Int32) & int(FactFlags.DELETED)) == 0)
837
+ )
838
+ count = matching.height
839
+
840
+ if count > 0:
841
+ self._fact_store._df = fact_df.with_columns([
842
+ pl.when(
843
+ (pl.col("g") == g_id) &
844
+ ((pl.col("flags").cast(pl.Int32) & int(FactFlags.DELETED)) == 0)
845
+ )
846
+ .then((pl.col("flags").cast(pl.Int32) | int(FactFlags.DELETED)).cast(pl.UInt16))
847
+ .otherwise(pl.col("flags"))
848
+ .alias("flags")
849
+ ])
850
+ self._invalidate_cache()
851
+
852
+ return count
853
+
854
+ def copy_graph(
855
+ self,
856
+ source_graph: Optional[str],
857
+ dest_graph: Optional[str],
858
+ silent: bool = False,
859
+ ) -> int:
860
+ """Copy all triples from source graph to destination graph."""
861
+ # Clear destination first
862
+ self.clear_graph(dest_graph, silent=True)
863
+
864
+ # Get source graph ID
865
+ if source_graph is None:
866
+ src_g_id = DEFAULT_GRAPH_ID
867
+ else:
868
+ src_g_id = self._term_dict.lookup_iri(source_graph)
869
+ if src_g_id is None:
870
+ return 0
871
+
872
+ # Get destination graph ID
873
+ if dest_graph is None:
874
+ dest_g_id = DEFAULT_GRAPH_ID
875
+ else:
876
+ dest_g_id = self._term_dict.intern_iri(dest_graph)
877
+
878
+ # Get source facts
879
+ fact_df = self._fact_store._df
880
+ source_facts = fact_df.filter(
881
+ (pl.col("g") == src_g_id) &
882
+ ((pl.col("flags").cast(pl.Int32) & int(FactFlags.DELETED)) == 0)
883
+ )
884
+
885
+ if source_facts.height == 0:
886
+ return 0
887
+
888
+ # Create copies with new graph and transaction IDs
889
+ new_txn = self._fact_store._allocate_txn()
890
+ t_now = int(datetime.now(timezone.utc).timestamp() * 1_000_000)
891
+
892
+ new_facts = source_facts.with_columns([
893
+ pl.lit(dest_g_id).cast(pl.UInt64).alias("g"),
894
+ pl.lit(new_txn).cast(pl.UInt64).alias("txn"),
895
+ pl.lit(t_now).cast(pl.UInt64).alias("t_added"),
896
+ ])
897
+
898
+ self._fact_store._df = pl.concat([self._fact_store._df, new_facts])
899
+ self._invalidate_cache()
900
+
901
+ return new_facts.height
902
+
903
+ def move_graph(
904
+ self,
905
+ source_graph: Optional[str],
906
+ dest_graph: Optional[str],
907
+ silent: bool = False,
908
+ ) -> int:
909
+ """Move all triples from source graph to destination graph."""
910
+ count = self.copy_graph(source_graph, dest_graph, silent)
911
+
912
+ # Clear source
913
+ if source_graph is None:
914
+ self.clear_graph(None, silent=True)
915
+ else:
916
+ self.clear_graph(source_graph, silent=True)
917
+
918
+ return count
919
+
920
+ def add_graph(
921
+ self,
922
+ source_graph: Optional[str],
923
+ dest_graph: Optional[str],
924
+ silent: bool = False,
925
+ ) -> int:
926
+ """Add all triples from source graph to destination graph."""
927
+ # Get source graph ID
928
+ if source_graph is None:
929
+ src_g_id = DEFAULT_GRAPH_ID
930
+ else:
931
+ src_g_id = self._term_dict.lookup_iri(source_graph)
932
+ if src_g_id is None:
933
+ return 0
934
+
935
+ # Get destination graph ID
936
+ if dest_graph is None:
937
+ dest_g_id = DEFAULT_GRAPH_ID
938
+ else:
939
+ dest_g_id = self._term_dict.intern_iri(dest_graph)
940
+
941
+ # Get source facts
942
+ fact_df = self._fact_store._df
943
+ source_facts = fact_df.filter(
944
+ (pl.col("g") == src_g_id) &
945
+ ((pl.col("flags").cast(pl.Int32) & int(FactFlags.DELETED)) == 0)
946
+ )
947
+
948
+ if source_facts.height == 0:
949
+ return 0
950
+
951
+ # Create copies with new graph
952
+ new_txn = self._fact_store._allocate_txn()
953
+ t_now = int(datetime.now(timezone.utc).timestamp() * 1_000_000)
954
+
955
+ new_facts = source_facts.with_columns([
956
+ pl.lit(dest_g_id).cast(pl.UInt64).alias("g"),
957
+ pl.lit(new_txn).cast(pl.UInt64).alias("txn"),
958
+ pl.lit(t_now).cast(pl.UInt64).alias("t_added"),
959
+ ])
960
+
961
+ self._fact_store._df = pl.concat([self._fact_store._df, new_facts])
962
+ self._invalidate_cache()
963
+
964
+ return new_facts.height
965
+
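+     # Named-graph operation sketch (semantics as implemented above): copy_graph()
+     # clears the destination first, add_graph() merges, move_graph() copies and then
+     # clears the source; None names the default graph.
+     #
+     #     store.copy_graph("http://ex.org/g1", "http://ex.org/g2")
+     #     store.add_graph("http://ex.org/g1", "http://ex.org/g2")
+     #     store.move_graph("http://ex.org/g2", None)
+     #     store.drop_graph("http://ex.org/g1")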
966
+ def load_graph(
967
+ self,
968
+ source_uri: str,
969
+ graph_uri: Optional[str] = None,
970
+ silent: bool = False,
971
+ ) -> int:
972
+ """Load RDF data from a URI into a graph."""
974
+ from urllib.parse import urlparse, unquote
975
+
976
+ # Determine file path
977
+ if source_uri.startswith("file://"):
978
+ parsed = urlparse(source_uri)
979
+ file_path_str = unquote(parsed.path)
980
+             # Strip the leading slash that urlparse leaves on Windows drive paths ("/C:/...")
+             if len(file_path_str) > 2 and file_path_str[0] == '/' and file_path_str[2] == ':':
981
+ file_path_str = file_path_str[1:]
982
+ file_path = Path(file_path_str)
983
+ elif source_uri.startswith(("http://", "https://")):
984
+ import tempfile
985
+ import urllib.request
986
+ try:
987
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".ttl") as f:
988
+ urllib.request.urlretrieve(source_uri, f.name)
989
+ file_path = Path(f.name)
990
+ except Exception as e:
991
+ if silent:
992
+ return 0
993
+ raise ValueError(f"Failed to download {source_uri}: {e}")
994
+ else:
995
+ file_path = Path(source_uri)
996
+
997
+ if not file_path.exists():
998
+ if silent:
999
+ return 0
1000
+ raise FileNotFoundError(f"Source file not found: {file_path}")
1001
+
1002
+ # Determine format from extension
1003
+ suffix = file_path.suffix.lower()
1004
+
1005
+ try:
1006
+ if suffix in (".ttl", ".turtle"):
1007
+ from rdf_starbase.formats.turtle import parse_turtle
1008
+ parsed = parse_turtle(file_path.read_text())
1009
+ triples = parsed.triples
1010
+ elif suffix in (".nt", ".ntriples"):
1011
+ from rdf_starbase.formats.ntriples import parse_ntriples
1012
+ parsed = parse_ntriples(file_path.read_text())
1013
+ triples = parsed.triples
1014
+ elif suffix in (".rdf", ".xml"):
1015
+ from rdf_starbase.formats.rdfxml import parse_rdfxml
1016
+ parsed = parse_rdfxml(file_path.read_text())
1017
+ triples = parsed.triples
1018
+ elif suffix in (".jsonld", ".json"):
1019
+ from rdf_starbase.formats.jsonld import parse_jsonld
1020
+ parsed = parse_jsonld(file_path.read_text())
1021
+ triples = parsed.triples
1022
+ else:
1023
+ from rdf_starbase.formats.turtle import parse_turtle
1024
+ parsed = parse_turtle(file_path.read_text())
1025
+ triples = parsed.triples
1026
+ except Exception as e:
1027
+ if silent:
1028
+ return 0
1029
+ raise ValueError(f"Failed to parse {file_path}: {e}")
1030
+
1031
+ # Add triples to the graph
1032
+ prov = ProvenanceContext(
1033
+ source=source_uri,
1034
+ confidence=1.0,
1035
+ process="LOAD",
1036
+ )
1037
+
1038
+ count = 0
1039
+ for triple in triples:
1040
+ self.add_triple(
1041
+ subject=triple.subject,
1042
+ predicate=triple.predicate,
1043
+ obj=triple.object,
1044
+ provenance=prov,
1045
+ graph=graph_uri,
1046
+ )
1047
+ count += 1
1048
+
1049
+ return count
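+     # LOAD sketch: the parser is picked from the file extension (.ttl, .nt, .rdf/.xml,
+     # .jsonld/.json, defaulting to Turtle) and every parsed triple is added with
+     # ProvenanceContext(source=<uri>, process="LOAD").
+     #
+     #     store.load_graph("file:///data/people.ttl", graph_uri="http://ex.org/graphs/people")
+     #     store.load_graph("https://example.org/data.ttl", silent=True)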