agmem 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {agmem-0.1.2.dist-info → agmem-0.1.3.dist-info}/METADATA +138 -14
  2. {agmem-0.1.2.dist-info → agmem-0.1.3.dist-info}/RECORD +45 -26
  3. memvcs/cli.py +10 -0
  4. memvcs/commands/add.py +6 -0
  5. memvcs/commands/audit.py +59 -0
  6. memvcs/commands/clone.py +7 -0
  7. memvcs/commands/daemon.py +28 -0
  8. memvcs/commands/distill.py +16 -0
  9. memvcs/commands/federated.py +53 -0
  10. memvcs/commands/fsck.py +31 -0
  11. memvcs/commands/garden.py +14 -0
  12. memvcs/commands/gc.py +51 -0
  13. memvcs/commands/merge.py +55 -1
  14. memvcs/commands/prove.py +66 -0
  15. memvcs/commands/pull.py +27 -0
  16. memvcs/commands/resolve.py +130 -0
  17. memvcs/commands/verify.py +74 -23
  18. memvcs/core/audit.py +124 -0
  19. memvcs/core/consistency.py +9 -9
  20. memvcs/core/crypto_verify.py +280 -0
  21. memvcs/core/distiller.py +25 -25
  22. memvcs/core/encryption.py +169 -0
  23. memvcs/core/federated.py +86 -0
  24. memvcs/core/gardener.py +23 -24
  25. memvcs/core/ipfs_remote.py +39 -0
  26. memvcs/core/knowledge_graph.py +1 -0
  27. memvcs/core/llm/__init__.py +10 -0
  28. memvcs/core/llm/anthropic_provider.py +50 -0
  29. memvcs/core/llm/base.py +27 -0
  30. memvcs/core/llm/factory.py +30 -0
  31. memvcs/core/llm/openai_provider.py +36 -0
  32. memvcs/core/merge.py +36 -23
  33. memvcs/core/objects.py +16 -6
  34. memvcs/core/pack.py +92 -0
  35. memvcs/core/privacy_budget.py +63 -0
  36. memvcs/core/remote.py +38 -0
  37. memvcs/core/repository.py +82 -2
  38. memvcs/core/temporal_index.py +9 -0
  39. memvcs/core/trust.py +103 -0
  40. memvcs/core/vector_store.py +15 -1
  41. memvcs/core/zk_proofs.py +26 -0
  42. {agmem-0.1.2.dist-info → agmem-0.1.3.dist-info}/WHEEL +0 -0
  43. {agmem-0.1.2.dist-info → agmem-0.1.3.dist-info}/entry_points.txt +0 -0
  44. {agmem-0.1.2.dist-info → agmem-0.1.3.dist-info}/licenses/LICENSE +0 -0
  45. {agmem-0.1.2.dist-info → agmem-0.1.3.dist-info}/top_level.txt +0 -0
memvcs/core/audit.py ADDED
@@ -0,0 +1,124 @@
1
+ """
2
+ Tamper-evident audit trail for agmem.
3
+
4
+ Append-only, hash-chained log of significant operations.
5
+ """
6
+
7
+ import datetime
8
+ import hashlib
9
+ import hmac
10
+ import json
11
+ import os
12
+ from pathlib import Path
13
+ from typing import Optional, List, Dict, Any, Tuple
14
+
15
+
16
+ def _audit_dir(mem_dir: Path) -> Path:
17
+ return mem_dir / "audit"
18
+
19
+
20
def _log_path(mem_dir: Path) -> Path:
    """Full path of the append-only audit log file."""
    audit_home = _audit_dir(mem_dir)
    return audit_home / "log"
22
+
23
+
24
def _get_previous_hash(mem_dir: Path) -> str:
    """Return the entry hash of the newest audit record, or "" for the first entry.

    Scans from the end of the log; a trailing non-empty line that lacks the
    expected "hash<TAB>payload" shape yields "" (same as an empty/missing log).
    """
    log_file = _log_path(mem_dir)
    if not log_file.exists():
        return ""
    lines = log_file.read_text().strip().split("\n")
    if not lines:
        return ""
    # Each record is "entry_hash\tpayload_json".
    for raw in reversed(lines):
        candidate = raw.strip()
        if not candidate:
            continue
        if "\t" not in candidate:
            return ""
        return candidate.split("\t", 1)[0]
    return ""
41
+
42
+
43
+ def _hash_entry(prev_hash: str, payload: str) -> str:
44
+ """Compute this entry's hash: SHA-256(prev_hash + payload)."""
45
+ return hashlib.sha256((prev_hash + payload).encode()).hexdigest()
46
+
47
+
48
def append_audit(
    mem_dir: Path,
    operation: str,
    details: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Append a tamper-evident audit entry. Write synchronously.

    Each entry is one line: entry_hash TAB payload_json, where the payload
    carries timestamp, operation, details, and prev_hash (hash chaining).

    Args:
        mem_dir: Path to the .mem directory (str or Path accepted).
        operation: Name of the operation being recorded.
        details: Optional JSON-serializable metadata for the entry.
    """
    mem_dir = Path(mem_dir)
    _audit_dir(mem_dir).mkdir(parents=True, exist_ok=True)
    path = _log_path(mem_dir)
    prev_hash = _get_previous_hash(mem_dir)
    # Fix: datetime.utcnow() is deprecated (3.12+) and returns a naive value;
    # use an aware UTC timestamp but keep the trailing-"Z" format already in logs.
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    payload = {
        "timestamp": now_utc.isoformat().replace("+00:00", "") + "Z",
        "operation": operation,
        "details": details or {},
        "prev_hash": prev_hash,
    }
    payload_str = json.dumps(payload, sort_keys=True)
    entry_hash = _hash_entry(prev_hash, payload_str)
    line = f"{entry_hash}\t{payload_str}\n"
    with open(path, "a", encoding="utf-8") as f:
        f.write(line)
        f.flush()
        try:
            # Force the entry to disk so the chain survives a crash.
            os.fsync(f.fileno())
        except (AttributeError, OSError):
            pass
77
+
78
+
79
def read_audit(mem_dir: Path, max_entries: int = 1000) -> List[Dict[str, Any]]:
    """Return audit entries, newest first.

    Each dict carries entry_hash plus the stored payload fields
    (prev_hash, timestamp, operation, details). Malformed lines are skipped.
    """
    log_file = _log_path(mem_dir)
    if not log_file.exists():
        return []
    collected: List[Dict[str, Any]] = []
    for raw in reversed(log_file.read_text().strip().split("\n")):
        record = raw.strip()
        # Skip blanks and lines that do not match "hash\tjson".
        if not record or "\t" not in record:
            continue
        entry_hash, _, payload_str = record.partition("\t")
        try:
            payload = json.loads(payload_str)
        except json.JSONDecodeError:
            continue
        payload["entry_hash"] = entry_hash
        collected.append(payload)
        if len(collected) >= max_entries:
            break
    return collected
101
+
102
+
103
def verify_audit(mem_dir: Path) -> Tuple[bool, Optional[int]]:
    """
    Verify the audit log's hash chain.

    Returns (valid, first_bad_index); first_bad_index is the 0-based line
    index of the first entry that fails verification (None when valid).
    """
    log_file = _log_path(mem_dir)
    if not log_file.exists():
        return (True, None)
    prev_hash = ""
    for index, raw in enumerate(log_file.read_text().strip().split("\n")):
        record = raw.strip()
        if not record:
            continue
        if "\t" not in record:
            return (False, index)
        entry_hash, _, payload_str = record.partition("\t")
        # Constant-time compare; each entry's hash chains over the previous.
        if not hmac.compare_digest(entry_hash, _hash_entry(prev_hash, payload_str)):
            return (False, index)
        prev_hash = entry_hash
    return (True, None)
@@ -100,23 +100,23 @@ class ConsistencyChecker:
100
100
  return triples
101
101
 
102
102
  def _extract_triples_llm(self, content: str, source: str) -> List[Triple]:
103
- """Extract triples using LLM."""
103
+ """Extract triples using LLM (multi-provider)."""
104
104
  try:
105
- import openai
105
+ from .llm import get_provider
106
106
 
107
- response = openai.chat.completions.create(
108
- model="gpt-3.5-turbo",
109
- messages=[
107
+ provider = get_provider(provider_name=self.llm_provider)
108
+ if not provider:
109
+ return []
110
+ text = provider.complete(
111
+ [
110
112
  {
111
113
  "role": "system",
112
- "content": "Extract factual statements as (subject, predicate, object) triples. "
113
- "One per line, format: SUBJECT | PREDICATE | OBJECT",
114
+ "content": "Extract factual statements as (subject, predicate, object) triples. One per line, format: SUBJECT | PREDICATE | OBJECT",
114
115
  },
115
116
  {"role": "user", "content": content[:3000]},
116
117
  ],
117
118
  max_tokens=500,
118
119
  )
119
- text = response.choices[0].message.content
120
120
  triples = []
121
121
  for i, line in enumerate(text.splitlines(), 1):
122
122
  if "|" in line:
@@ -138,7 +138,7 @@ class ConsistencyChecker:
138
138
 
139
139
  def extract_triples(self, content: str, source: str, use_llm: bool = False) -> List[Triple]:
140
140
  """Extract triples from content."""
141
- if use_llm and self.llm_provider == "openai":
141
+ if use_llm and self.llm_provider:
142
142
  t = self._extract_triples_llm(content, source)
143
143
  if t:
144
144
  return t
@@ -0,0 +1,280 @@
1
+ """
2
+ Cryptographic commit verification for agmem.
3
+
4
+ Merkle tree over commit blobs, optional Ed25519 signing of Merkle root.
5
+ Verification on checkout, pull, and via verify/fsck.
6
+ """
7
+
8
+ import hashlib
9
+ import hmac
10
+ import json
11
+ import os
12
+ from pathlib import Path
13
+ from typing import Optional, List, Tuple, Any, Dict
14
+
15
+ from .objects import ObjectStore, Tree, Commit
16
+
17
+ # Ed25519 via cryptography (optional)
18
+ try:
19
+ from cryptography.hazmat.primitives.asymmetric.ed25519 import (
20
+ Ed25519PrivateKey,
21
+ Ed25519PublicKey,
22
+ )
23
+ from cryptography.exceptions import InvalidSignature
24
+ from cryptography.hazmat.primitives import serialization
25
+
26
+ ED25519_AVAILABLE = True
27
+ except ImportError:
28
+ ED25519_AVAILABLE = False
29
+
30
+
31
def _collect_blob_hashes_from_tree(store: ObjectStore, tree_hash: str) -> List[str]:
    """Recursively gather every blob hash under a tree, sorted for a deterministic Merkle order."""
    tree = Tree.load(store, tree_hash)
    if not tree:
        return []
    found: List[str] = []
    for entry in tree.entries:
        if entry.obj_type == "tree":
            found.extend(_collect_blob_hashes_from_tree(store, entry.hash))
        elif entry.obj_type == "blob":
            found.append(entry.hash)
    return sorted(found)
43
+
44
+
45
+ def _merkle_hash(data: bytes) -> str:
46
+ """SHA-256 hash for Merkle tree nodes."""
47
+ return hashlib.sha256(data).hexdigest()
48
+
49
+
50
def build_merkle_tree(blob_hashes: List[str]) -> str:
    """
    Compute the Merkle root over a list of blob hashes.

    Leaves hash each blob-hash string; internal nodes hash the concatenated
    hex digests of their children, and an odd trailing node is paired with
    itself. Returns the root as hex; an empty list yields SHA-256(b"empty").
    """
    if not blob_hashes:
        return _merkle_hash(b"empty")
    level = [_merkle_hash(h.encode()) for h in blob_hashes]
    while len(level) > 1:
        parents = []
        for i in range(0, len(level), 2):
            lhs = level[i]
            rhs = level[i + 1] if i + 1 < len(level) else level[i]
            parents.append(_merkle_hash((lhs + rhs).encode()))
        level = parents
    return level[0]
69
+
70
+
71
def build_merkle_root_for_commit(store: ObjectStore, commit_hash: str) -> Optional[str]:
    """Merkle root of a commit's tree, or None when the commit cannot be loaded."""
    commit = Commit.load(store, commit_hash)
    if not commit:
        return None
    return build_merkle_tree(_collect_blob_hashes_from_tree(store, commit.tree))
78
+
79
+
80
def merkle_proof(blob_hashes: List[str], target_blob_hash: str) -> Optional[List[Tuple[str, str]]]:
    """
    Build a Merkle inclusion proof for one blob.

    Returns a leaf-to-root list of (sibling_hash, side) pairs, where side is
    "L" or "R" for the sibling's position relative to the current node.
    Returns None when the blob is absent; a single-leaf tree yields [].
    """
    if target_blob_hash not in blob_hashes:
        return None
    ordered = sorted(blob_hashes)
    level = [_merkle_hash(h.encode()) for h in ordered]
    idx = ordered.index(target_blob_hash)
    proof: List[Tuple[str, str]] = []
    while len(level) > 1:
        parents = []
        for i in range(0, len(level), 2):
            lhs = level[i]
            # An odd trailing node is paired with itself.
            rhs = level[i + 1] if i + 1 < len(level) else level[i]
            parents.append(_merkle_hash((lhs + rhs).encode()))
            if idx == i:
                proof.append((rhs, "R"))
                idx = i // 2
            elif idx == i + 1:
                proof.append((lhs, "L"))
                idx = i // 2
        level = parents
    return proof
109
+
110
+
111
def verify_merkle_proof(blob_hash: str, proof: List[Tuple[str, str]], expected_root: str) -> bool:
    """Recompute the root from a leaf and its proof; True when it equals expected_root."""
    acc = _merkle_hash(blob_hash.encode())
    for sibling, side in proof:
        pair = sibling + acc if side == "L" else acc + sibling
        acc = _merkle_hash(pair.encode())
    return acc == expected_root
120
+
121
+
122
+ # --- Signing (Ed25519) ---
123
+
124
+
125
+ def _keys_dir(mem_dir: Path) -> Path:
126
+ return mem_dir / "keys"
127
+
128
+
129
def get_signing_key_paths(mem_dir: Path) -> Tuple[Path, Path]:
    """Return (private_key_path, public_key_path); the private file may be absent (env-only setups)."""
    keys_home = _keys_dir(mem_dir)
    return (keys_home / "private.pem", keys_home / "public.pem")
133
+
134
+
135
def ensure_keys_dir(mem_dir: Path) -> Path:
    """Create .mem/keys if needed and return its path."""
    keys_home = _keys_dir(mem_dir)
    keys_home.mkdir(parents=True, exist_ok=True)
    return keys_home
140
+
141
+
142
def generate_keypair(mem_dir: Path) -> Tuple[bytes, bytes]:
    """
    Generate a fresh Ed25519 keypair as (private_pem, public_pem).

    Raises:
        RuntimeError: when the optional 'cryptography' package is missing.
    """
    if not ED25519_AVAILABLE:
        raise RuntimeError(
            "Signing requires 'cryptography'; install with: pip install cryptography"
        )
    key = Ed25519PrivateKey.generate()
    private_pem = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    public_pem = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    return (private_pem, public_pem)
160
+
161
+
162
def save_public_key(mem_dir: Path, public_pem: bytes) -> Path:
    """Write the public key PEM to .mem/keys/public.pem and return that path."""
    ensure_keys_dir(mem_dir)
    target = _keys_dir(mem_dir) / "public.pem"
    target.write_bytes(public_pem)
    return target
168
+
169
+
170
def load_public_key(mem_dir: Path) -> Optional[bytes]:
    """
    Load the public key PEM from .mem/keys/public.pem, falling back to the
    "signing.public_key_pem" entry in .mem/config.json.

    Returns:
        PEM bytes, or None when no key is configured. The config fallback is
        normalized to bytes (JSON stores it as a string), matching the
        annotated return type.
    """
    path = _keys_dir(mem_dir) / "public.pem"
    if path.exists():
        return path.read_bytes()
    config_file = mem_dir / "config.json"
    if config_file.exists():
        try:
            config = json.loads(config_file.read_text())
            pem = config.get("signing", {}).get("public_key_pem")
            # Bug fix: the JSON value is a str, but callers (and the return
            # annotation) expect bytes — encode before returning.
            if isinstance(pem, str):
                return pem.encode()
            return pem
        except Exception:
            # Best-effort: unreadable config is treated as "no key".
            pass
    return None
183
+
184
+
185
def load_private_key_from_env() -> Optional[bytes]:
    """
    Private key PEM from AGMEM_SIGNING_PRIVATE_KEY, or from the file named by
    AGMEM_SIGNING_PRIVATE_KEY_FILE. Returns None when neither is available.
    """
    inline = os.environ.get("AGMEM_SIGNING_PRIVATE_KEY")
    if inline:
        return inline if isinstance(inline, bytes) else inline.encode()
    file_path = os.environ.get("AGMEM_SIGNING_PRIVATE_KEY_FILE")
    if file_path and os.path.isfile(file_path):
        return Path(file_path).read_bytes()
    return None
194
+
195
+
196
def sign_merkle_root(root_hex: str, private_key_pem: bytes) -> str:
    """Ed25519-sign the hex Merkle root; returns the signature hex-encoded."""
    if not ED25519_AVAILABLE:
        raise RuntimeError("Signing requires 'cryptography'")
    key = serialization.load_pem_private_key(private_key_pem, password=None)
    if not isinstance(key, Ed25519PrivateKey):
        raise TypeError("Ed25519 private key required")
    return key.sign(root_hex.encode()).hex()
205
+
206
+
207
def verify_signature(root_hex: str, signature_hex: str, public_key_pem: bytes) -> bool:
    """True when signature_hex is a valid Ed25519 signature of root_hex under the given key."""
    if not ED25519_AVAILABLE:
        return False
    try:
        key = serialization.load_pem_public_key(public_key_pem)
        if not isinstance(key, Ed25519PublicKey):
            return False
        # Raises InvalidSignature on mismatch.
        key.verify(bytes.fromhex(signature_hex), root_hex.encode())
    except InvalidSignature:
        return False
    except Exception:
        # Malformed PEM/hex input is treated as verification failure.
        return False
    return True
221
+
222
+
223
def verify_commit(
    store: ObjectStore,
    commit_hash: str,
    public_key_pem: Optional[bytes] = None,
    *,
    mem_dir: Optional[Path] = None,
) -> Tuple[bool, Optional[str]]:
    """
    Verify a commit's integrity: rebuild the Merkle tree from its blobs,
    compare against the stored root, then check the Ed25519 signature if any.

    Returns (verified, error_message); False plus a message means tampered or
    unverifiable. When public_key_pem is None and mem_dir is given, the
    public key is loaded from mem_dir.
    """
    commit = Commit.load(store, commit_hash)
    if not commit:
        return (False, "commit not found")
    metadata = commit.metadata or {}
    stored_root = metadata.get("merkle_root")
    stored_sig = metadata.get("signature")
    if not stored_root:
        return (False, "commit has no merkle_root (unverified)")
    computed_root = build_merkle_root_for_commit(store, commit_hash)
    if not computed_root:
        return (False, "could not build Merkle tree (missing tree/blobs)")
    # Constant-time comparison of the hex roots.
    if not hmac.compare_digest(computed_root, stored_root):
        return (False, "merkle_root mismatch (commit tampered)")
    if not stored_sig:
        # Root matches and there is no signature (legacy commit): accept.
        return (True, None)
    pub = public_key_pem or (load_public_key(mem_dir) if mem_dir else None)
    if not pub:
        return (False, "signature present but no public key configured")
    if isinstance(pub, str):
        pub = pub.encode()
    if verify_signature(stored_root, stored_sig, pub):
        return (True, None)
    return (False, "signature verification failed")
259
+
260
+
261
def verify_commit_optional(
    store: ObjectStore,
    commit_hash: str,
    mem_dir: Optional[Path] = None,
    *,
    strict: bool = False,
) -> None:
    """
    Verify a commit, raising ValueError on any failure when strict=True.
    When strict=False, raise only for tamper evidence (root mismatch or a
    failing signature); commits without a merkle_root pass silently.
    """
    ok, err = verify_commit(store, commit_hash, None, mem_dir=mem_dir)
    if ok or not err:
        return
    # verify_commit encodes severity in its messages; these substrings
    # identify the tamper cases that must always raise.
    tampered = (
        "tampered" in err
        or "mismatch" in err
        or "signature verification failed" in err
    )
    if tampered or strict:
        raise ValueError(f"Commit verification failed: {err}")
memvcs/core/distiller.py CHANGED
@@ -110,9 +110,32 @@ class Distiller:
110
110
  continue
111
111
  combined = "\n---\n".join(contents)
112
112
 
113
- if self.config.llm_provider == "openai" and self.config.llm_model:
113
+ if self.config.llm_provider and self.config.llm_model:
114
114
  try:
115
- return self._extract_with_openai(combined, cluster.topic)
115
+ from .llm import get_provider
116
+
117
+ config = {
118
+ "llm_provider": self.config.llm_provider,
119
+ "llm_model": self.config.llm_model,
120
+ }
121
+ provider = get_provider(config=config)
122
+ if provider:
123
+ text = provider.complete(
124
+ [
125
+ {
126
+ "role": "system",
127
+ "content": "Extract factual statements from the text. Output as bullet points (one fact per line). Focus on: user preferences, learned facts, key decisions.",
128
+ },
129
+ {
130
+ "role": "user",
131
+ "content": f"Topic: {cluster.topic}\n\n{combined[:4000]}",
132
+ },
133
+ ],
134
+ max_tokens=500,
135
+ )
136
+ return [
137
+ line.strip() for line in text.splitlines() if line.strip().startswith("-")
138
+ ][:15]
116
139
  except Exception:
117
140
  pass
118
141
 
@@ -125,29 +148,6 @@ class Distiller:
125
148
  facts.append(f"- {line[:200]}")
126
149
  return facts[:10] if facts else [f"- Learned about {cluster.topic}"]
127
150
 
128
- def _extract_with_openai(self, content: str, topic: str) -> List[str]:
129
- """Extract facts using OpenAI API."""
130
- import openai
131
-
132
- response = openai.chat.completions.create(
133
- model=self.config.llm_model or "gpt-3.5-turbo",
134
- messages=[
135
- {
136
- "role": "system",
137
- "content": "Extract factual statements from the text. "
138
- "Output as bullet points (one fact per line). "
139
- "Focus on: user preferences, learned facts, key decisions.",
140
- },
141
- {
142
- "role": "user",
143
- "content": f"Topic: {topic}\n\n{content[:4000]}",
144
- },
145
- ],
146
- max_tokens=500,
147
- )
148
- text = response.choices[0].message.content
149
- return [line.strip() for line in text.splitlines() if line.strip().startswith("-")][:15]
150
-
151
151
  def write_consolidated(self, cluster: EpisodeCluster, facts: List[str]) -> Path:
152
152
  """Write consolidated semantic file."""
153
153
  self.target_dir.mkdir(parents=True, exist_ok=True)
@@ -0,0 +1,169 @@
1
+ """
2
+ Encryption at rest for agmem object store.
3
+
4
+ AES-256-GCM for object payloads; key derived from passphrase via Argon2id.
5
+ Hash-then-encrypt so content-addressable paths stay based on plaintext hash.
6
+ """
7
+
8
+ import json
9
+ import os
10
+ import secrets
11
+ from pathlib import Path
12
+ from typing import Optional, Tuple, Dict, Any, Callable
13
+
14
+ # AES-GCM and Argon2id via cryptography
15
+ try:
16
+ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
17
+ from cryptography.hazmat.primitives.kdf.argon2 import Argon2id
18
+
19
+ ENCRYPTION_AVAILABLE = True
20
+ except ImportError:
21
+ ENCRYPTION_AVAILABLE = False
22
+
23
+ IV_LEN = 12
24
+ TAG_LEN = 16
25
+ KEY_LEN = 32
26
+
27
+
28
+ def _encryption_config_path(mem_dir: Path) -> Path:
29
+ return mem_dir / "encryption.json"
30
+
31
+
32
def load_encryption_config(mem_dir: Path) -> Optional[Dict[str, Any]]:
    """Load KDF parameters (salt_hex, time_cost, memory_cost, parallelism), or None if absent/corrupt."""
    config_path = _encryption_config_path(mem_dir)
    if not config_path.exists():
        return None
    try:
        return json.loads(config_path.read_text())
    except Exception:
        # Best-effort: a corrupt config is treated the same as no config.
        return None
41
+
42
+
43
def save_encryption_config(
    mem_dir: Path,
    salt: bytes,
    time_cost: int = 3,
    memory_cost: int = 65536,
    parallelism: int = 4,
) -> Path:
    """Persist Argon2id parameters (salt as hex) to .mem/encryption.json; returns the path."""
    mem_dir.mkdir(parents=True, exist_ok=True)
    config_path = _encryption_config_path(mem_dir)
    payload = {
        "salt_hex": salt.hex(),
        "time_cost": time_cost,
        "memory_cost": memory_cost,
        "parallelism": parallelism,
    }
    config_path.write_text(json.dumps(payload, indent=2))
    return config_path
65
+
66
+
67
def derive_key(
    passphrase: bytes,
    salt: bytes,
    time_cost: int = 3,
    memory_cost: int = 65536,
    parallelism: int = 4,
) -> bytes:
    """
    Derive a 32-byte AES key from a passphrase with Argon2id.

    Raises:
        RuntimeError: when 'cryptography' is not installed.
    """
    if not ENCRYPTION_AVAILABLE:
        raise RuntimeError("Encryption requires 'cryptography'")
    return Argon2id(
        salt=salt,
        length=KEY_LEN,
        time_cost=time_cost,
        memory_cost=memory_cost,
        parallelism=parallelism,
    ).derive(passphrase)
85
+
86
+
87
def encrypt(plaintext: bytes, key: bytes) -> Tuple[bytes, bytes]:
    """AES-256-GCM encrypt; returns (iv, ciphertext_with_tag) using a fresh random 12-byte IV."""
    if not ENCRYPTION_AVAILABLE:
        raise RuntimeError("Encryption requires 'cryptography'")
    iv = secrets.token_bytes(IV_LEN)
    # AESGCM.encrypt appends the 16-byte auth tag to the ciphertext.
    return (iv, AESGCM(key).encrypt(iv, plaintext, None))
95
+
96
+
97
def decrypt(iv: bytes, ciphertext_with_tag: bytes, key: bytes) -> bytes:
    """AES-256-GCM decrypt; raises on authentication failure."""
    if not ENCRYPTION_AVAILABLE:
        raise RuntimeError("Encryption requires 'cryptography'")
    return AESGCM(key).decrypt(iv, ciphertext_with_tag, None)
103
+
104
+
105
def init_encryption(mem_dir: Path, time_cost: int = 3, memory_cost: int = 65536) -> bytes:
    """Write a fresh encryption config with a random 16-byte salt; returns the salt
    (the caller derives the key from a passphrase)."""
    salt = secrets.token_bytes(16)
    save_encryption_config(mem_dir, salt, time_cost=time_cost, memory_cost=memory_cost)
    return salt
110
+
111
+
112
class ObjectStoreEncryptor:
    """
    Encrypts/decrypts object-store payloads (compressed bytes) with
    AES-256-GCM. The 12-byte IV is prepended to each payload and the
    16-byte auth tag travels with the ciphertext.
    """

    def __init__(self, get_key: Callable[[], Optional[bytes]]):
        # Key is fetched lazily so the passphrase need not be known at init.
        self._get_key = get_key

    def _require_key(self) -> bytes:
        """Fetch the key, raising when no passphrase is available."""
        key = self._get_key()
        if not key:
            raise ValueError("Encryption key not available (passphrase required)")
        return key

    def encrypt_payload(self, plaintext: bytes) -> bytes:
        """Encrypt a payload; returns iv (12 bytes) + ciphertext_with_tag."""
        iv, ct = encrypt(plaintext, self._require_key())
        return iv + ct

    def decrypt_payload(self, raw: bytes) -> bytes:
        """Decrypt a payload of the form iv (12 bytes) + ciphertext_with_tag."""
        key = self._require_key()
        if len(raw) < IV_LEN + TAG_LEN:
            raise ValueError("Payload too short for encrypted object")
        return decrypt(raw[:IV_LEN], raw[IV_LEN:], key)
139
+
140
+
141
def get_key_from_env_or_cache(
    mem_dir: Path,
    env_var: str = "AGMEM_ENCRYPTION_PASSPHRASE",
    cache_var: str = "_agmem_encryption_key_cache",
) -> Optional[bytes]:
    """Get key from env or process cache. Derives key if passphrase in env and config exists.

    The derived key is memoized as an attribute named *cache_var* on this
    module, so Argon2id runs at most once per process.

    NOTE(review): the cache is keyed only by *cache_var*, not by mem_dir —
    a process touching two repositories with different salts would reuse the
    first repo's key for the second. Confirm whether multi-repo use in one
    process is supported before relying on this.
    """
    # Module-level cache for session (same process)
    import sys

    # Look the module up by its canonical name rather than using globals(),
    # so the cache attribute survives however this function is imported.
    mod = sys.modules.get("memvcs.core.encryption")
    if mod and getattr(mod, cache_var, None) is not None:
        return getattr(mod, cache_var)
    passphrase = os.environ.get(env_var)
    if not passphrase:
        return None
    cfg = load_encryption_config(mem_dir)
    if not cfg:
        # No encryption config on disk: nothing to derive against.
        return None
    salt = bytes.fromhex(cfg["salt_hex"])
    key = derive_key(
        passphrase.encode() if isinstance(passphrase, str) else passphrase,
        salt,
        time_cost=cfg.get("time_cost", 3),
        memory_cost=cfg.get("memory_cost", 65536),
        parallelism=cfg.get("parallelism", 4),
    )
    if mod is not None:
        setattr(mod, cache_var, key)
    return key