dhee 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. dhee/__init__.py +64 -0
  2. dhee/api/__init__.py +6 -0
  3. dhee/api/app.py +195 -0
  4. dhee/api/server.py +34 -0
  5. dhee/benchmarks/__init__.py +2 -0
  6. dhee/benchmarks/arc_agi.py +383 -0
  7. dhee/benchmarks/longmemeval.py +2177 -0
  8. dhee/cli.py +424 -0
  9. dhee/cli_config.py +139 -0
  10. dhee/cli_mcp.py +189 -0
  11. dhee/cli_setup.py +75 -0
  12. dhee/configs/__init__.py +19 -0
  13. dhee/configs/active.py +65 -0
  14. dhee/configs/base.py +686 -0
  15. dhee/configs/presets.py +152 -0
  16. dhee/core/__init__.py +22 -0
  17. dhee/core/agi_loop.py +212 -0
  18. dhee/core/alaya.py +330 -0
  19. dhee/core/answer_orchestration.py +942 -0
  20. dhee/core/buddhi.py +934 -0
  21. dhee/core/category.py +796 -0
  22. dhee/core/code_exec_counter.py +266 -0
  23. dhee/core/cognition.py +481 -0
  24. dhee/core/conflict.py +60 -0
  25. dhee/core/consolidation.py +112 -0
  26. dhee/core/decay.py +52 -0
  27. dhee/core/distillation.py +232 -0
  28. dhee/core/echo.py +634 -0
  29. dhee/core/engram.py +431 -0
  30. dhee/core/engram_extractor.py +609 -0
  31. dhee/core/enrichment.py +631 -0
  32. dhee/core/episodic_index.py +979 -0
  33. dhee/core/evolution.py +356 -0
  34. dhee/core/forgetting.py +391 -0
  35. dhee/core/fusion.py +68 -0
  36. dhee/core/graph.py +566 -0
  37. dhee/core/intent.py +93 -0
  38. dhee/core/kernel.py +142 -0
  39. dhee/core/log_parser.py +197 -0
  40. dhee/core/metrics.py +207 -0
  41. dhee/core/profile.py +504 -0
  42. dhee/core/proposition_context.py +144 -0
  43. dhee/core/resolvers.py +949 -0
  44. dhee/core/retrieval.py +171 -0
  45. dhee/core/salience.py +113 -0
  46. dhee/core/samskara.py +510 -0
  47. dhee/core/scene.py +381 -0
  48. dhee/core/traces.py +120 -0
  49. dhee/core/viveka.py +708 -0
  50. dhee/db/__init__.py +0 -0
  51. dhee/db/sqlite.py +2845 -0
  52. dhee/db/sqlite_backup.py +2070 -0
  53. dhee/decay/__init__.py +5 -0
  54. dhee/embeddings/__init__.py +0 -0
  55. dhee/embeddings/base.py +21 -0
  56. dhee/embeddings/gemini.py +83 -0
  57. dhee/embeddings/nvidia.py +116 -0
  58. dhee/embeddings/ollama.py +66 -0
  59. dhee/embeddings/openai.py +47 -0
  60. dhee/embeddings/qwen.py +139 -0
  61. dhee/embeddings/simple.py +65 -0
  62. dhee/exceptions.py +19 -0
  63. dhee/integrations/__init__.py +1 -0
  64. dhee/llms/__init__.py +0 -0
  65. dhee/llms/base.py +56 -0
  66. dhee/llms/dhee.py +295 -0
  67. dhee/llms/gemini.py +60 -0
  68. dhee/llms/mock.py +35 -0
  69. dhee/llms/nvidia.py +136 -0
  70. dhee/llms/ollama.py +58 -0
  71. dhee/llms/openai.py +35 -0
  72. dhee/llms/teacher_logger.py +243 -0
  73. dhee/mcp_server.py +1025 -0
  74. dhee/mcp_slim.py +442 -0
  75. dhee/memory/__init__.py +14 -0
  76. dhee/memory/base.py +23 -0
  77. dhee/memory/core.py +440 -0
  78. dhee/memory/main.py +6103 -0
  79. dhee/memory/parallel.py +60 -0
  80. dhee/memory/projects.py +395 -0
  81. dhee/memory/smart.py +507 -0
  82. dhee/memory/tasks.py +683 -0
  83. dhee/memory/utils.py +173 -0
  84. dhee/observability.py +49 -0
  85. dhee/retrieval/__init__.py +10 -0
  86. dhee/retrieval/reranker.py +252 -0
  87. dhee/simple.py +362 -0
  88. dhee/skills/__init__.py +7 -0
  89. dhee/skills/discovery.py +59 -0
  90. dhee/skills/executor.py +262 -0
  91. dhee/skills/hashing.py +81 -0
  92. dhee/skills/miner.py +374 -0
  93. dhee/skills/outcomes.py +151 -0
  94. dhee/skills/schema.py +241 -0
  95. dhee/skills/store.py +282 -0
  96. dhee/skills/structure.py +498 -0
  97. dhee/skills/trajectory.py +260 -0
  98. dhee/teaching/__init__.py +17 -0
  99. dhee/teaching/concepts.py +307 -0
  100. dhee/teaching/config.py +27 -0
  101. dhee/teaching/student_model.py +372 -0
  102. dhee/teaching/teaching_memory.py +255 -0
  103. dhee/utils/__init__.py +0 -0
  104. dhee/utils/factory.py +169 -0
  105. dhee/utils/math.py +25 -0
  106. dhee/utils/prompts.py +382 -0
  107. dhee/utils/repo_identity.py +72 -0
  108. dhee/vector_stores/__init__.py +0 -0
  109. dhee/vector_stores/base.py +61 -0
  110. dhee/vector_stores/memory.py +106 -0
  111. dhee/vector_stores/sqlite_vec.py +391 -0
  112. dhee/vector_stores/zvec_store.py +402 -0
  113. dhee-1.0.0.dist-info/METADATA +342 -0
  114. dhee-1.0.0.dist-info/RECORD +130 -0
  115. dhee-1.0.0.dist-info/WHEEL +5 -0
  116. dhee-1.0.0.dist-info/entry_points.txt +4 -0
  117. dhee-1.0.0.dist-info/licenses/LICENSE +21 -0
  118. dhee-1.0.0.dist-info/top_level.txt +3 -0
  119. dheeModel/__init__.py +18 -0
  120. dheeModel/client.py +385 -0
  121. dheeModel/model/__init__.py +1 -0
  122. dheeModel/model/dhee_model.py +167 -0
  123. dheeModel/training/__init__.py +1 -0
  124. dheeModel/training/data_formatter.py +155 -0
  125. dheeModel/training/karma.py +272 -0
  126. dheeModel/training/nididhyasana.py +660 -0
  127. dheeModel/training/smrti.py +411 -0
  128. dheeModel/training/train.py +321 -0
  129. dhee_shared/__init__.py +1 -0
  130. dhee_shared/model_paths.py +63 -0
dhee/__init__.py ADDED
@@ -0,0 +1,64 @@
1
+ """dhee — Cognition as a Service. The memory layer that makes ANY agent intelligent.
2
+
3
+ - FadeMem: Dual-layer (SML/LML) with natural decay
4
+ - EchoMem: Multi-modal encoding for stronger retention
5
+ - CategoryMem: Dynamic hierarchical category organization
6
+ - Universal Engram: Structured facts + context anchoring
7
+ - Cognition Engine: Memory-grounded recursive reasoning
8
+ - Prospective Scenes: Memory-driven future anticipation
9
+
10
+ Quick Start (zero-config, no API key):
11
+ from dhee import Memory
12
+ m = Memory()
13
+ m.add("User prefers Python")
14
+ results = m.search("programming preferences")
15
+
16
+ Tiered Memory Classes:
17
+ CoreMemory — lightweight: add/search/delete + decay (no LLM)
18
+ SmartMemory — + echo encoding, categories, knowledge graph (needs LLM)
19
+ FullMemory — + scenes, profiles, tasks, cognition (everything)
20
+ Memory — alias for CoreMemory (lightest default)
21
+ """
22
+
23
+ from dhee.memory.core import CoreMemory
24
+ from dhee.memory.smart import SmartMemory
25
+ from dhee.memory.main import FullMemory
26
+ from dhee.simple import Engram
27
+ from dhee.core.category import CategoryProcessor, Category, CategoryType, CategoryMatch
28
+ from dhee.core.echo import EchoProcessor, EchoDepth, EchoResult
29
+ from dhee.configs.base import MemoryConfig, FadeMemConfig, EchoMemConfig, CategoryMemConfig, ScopeConfig
30
+
31
+ # Default: CoreMemory (lightest, zero-config)
32
+ Memory = CoreMemory
33
+
34
+ __version__ = "1.0.0"
35
+ __all__ = [
36
+ # Tiered memory classes
37
+ "CoreMemory",
38
+ "SmartMemory",
39
+ "FullMemory",
40
+ "Memory",
41
+ # Simplified interface
42
+ "Engram",
43
+ # CategoryMem
44
+ "CategoryProcessor",
45
+ "Category",
46
+ "CategoryType",
47
+ "CategoryMatch",
48
+ # EchoMem
49
+ "EchoProcessor",
50
+ "EchoDepth",
51
+ "EchoResult",
52
+ # Config
53
+ "MemoryConfig",
54
+ "FadeMemConfig",
55
+ "EchoMemConfig",
56
+ "CategoryMemConfig",
57
+ "ScopeConfig",
58
+ ]
59
+
60
+
61
+ def _load_teaching():
62
+ """Lazy-load teaching module to avoid import overhead when not needed."""
63
+ from dhee.teaching import ConceptStore, StudentModel, TeachingMemory, TeachingConfig
64
+ return ConceptStore, StudentModel, TeachingMemory, TeachingConfig
dhee/api/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """Engram REST API module."""
2
+
3
+ from dhee.api.app import app
4
+ from dhee.api.server import run
5
+
6
+ __all__ = ["app", "run"]
dhee/api/app.py ADDED
@@ -0,0 +1,195 @@
1
+ """Engram core REST API — lightweight handoff endpoints (no auth required).
2
+
3
+ These endpoints mirror the enterprise ``/v1/handoff/*`` routes but delegate
4
+ directly to ``engram.core.kernel`` without session/token enforcement.
5
+ They are intended for local development and for the ``prompt_context.py`` hook
6
+ which fires as a subprocess with no auth context.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import logging
12
+ import os
13
+ from typing import Any, Dict, List, Optional
14
+
15
+ from fastapi import FastAPI, Query
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from pydantic import BaseModel, Field
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ app = FastAPI(
22
+ title="Engram Core API",
23
+ version="0.1.0",
24
+ description="Lightweight handoff + memory endpoints.",
25
+ )
26
+
27
+ app.add_middleware(
28
+ CORSMiddleware,
29
+ allow_origins=["*"],
30
+ allow_methods=["*"],
31
+ allow_headers=["*"],
32
+ )
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Health
37
+ # ---------------------------------------------------------------------------
38
+
39
+ @app.get("/health")
40
+ async def health():
41
+ return {"status": "ok"}
42
+
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # Request / response schemas
46
+ # ---------------------------------------------------------------------------
47
+
48
+ class CheckpointRequest(BaseModel):
49
+ task_summary: Optional[str] = None
50
+ event_type: str = "hook_checkpoint"
51
+ agent_id: str = "claude-code"
52
+ context_snapshot: Optional[str] = None
53
+ repo_path: Optional[str] = None
54
+ status: Optional[str] = None
55
+ decisions_made: Optional[List[str]] = None
56
+ files_touched: Optional[List[str]] = None
57
+ todos_remaining: Optional[List[str]] = None
58
+ blockers: Optional[List[str]] = None
59
+ key_commands: Optional[List[str]] = None
60
+ test_results: Optional[str] = None
61
+
62
+
63
+ class RecoverRequest(BaseModel):
64
+ repo_path: str
65
+ agent_id: str = "claude-code"
66
+
67
+
68
+ class SessionDigestRequest(BaseModel):
69
+ task_summary: str
70
+ repo: Optional[str] = None
71
+ status: str = "active"
72
+ agent_id: str = "claude-code"
73
+ decisions_made: Optional[List[str]] = None
74
+ files_touched: Optional[List[str]] = None
75
+ todos_remaining: Optional[List[str]] = None
76
+ blockers: Optional[List[str]] = None
77
+ key_commands: Optional[List[str]] = None
78
+ test_results: Optional[str] = None
79
+
80
+
81
+ # ---------------------------------------------------------------------------
82
+ # Handoff endpoints
83
+ # ---------------------------------------------------------------------------
84
+
85
+ @app.post("/v1/handoff/checkpoint")
86
+ async def handoff_checkpoint(request: CheckpointRequest):
87
+ """Receive a lightweight checkpoint from the hook or an agent.
88
+
89
+ Creates a dhee-bus session (if needed) and writes a checkpoint snapshot.
90
+ """
91
+ from dhee.core.kernel import _get_bus
92
+
93
+ bus = None
94
+ try:
95
+ bus = _get_bus()
96
+
97
+ # Find or create a session for this agent
98
+ session = bus.get_session(agent_id=request.agent_id)
99
+ if session is None:
100
+ sid = bus.save_session(
101
+ agent_id=request.agent_id,
102
+ repo=request.repo_path,
103
+ status=request.status or "active",
104
+ task_summary=request.task_summary or "",
105
+ )
106
+ else:
107
+ sid = session["id"]
108
+ # Update task_summary if provided
109
+ updates: Dict[str, Any] = {}
110
+ if request.task_summary:
111
+ updates["task_summary"] = request.task_summary
112
+ if request.status:
113
+ updates["status"] = request.status
114
+ if updates:
115
+ bus.update_session(sid, **updates)
116
+
117
+ snapshot = {
118
+ "event_type": request.event_type,
119
+ "task_summary": request.task_summary,
120
+ "context_snapshot": request.context_snapshot,
121
+ "files_touched": request.files_touched or [],
122
+ "key_commands": request.key_commands or [],
123
+ "decisions_made": request.decisions_made or [],
124
+ "todos_remaining": request.todos_remaining or [],
125
+ "blockers": request.blockers or [],
126
+ "test_results": request.test_results,
127
+ }
128
+ cid = bus.checkpoint(sid, request.agent_id, snapshot)
129
+
130
+ return {"status": "ok", "session_id": sid, "checkpoint_id": cid}
131
+
132
+ except Exception as exc:
133
+ logger.exception("Checkpoint failed")
134
+ return {"status": "error", "detail": str(exc)}
135
+ finally:
136
+ if bus is not None:
137
+ try:
138
+ bus.close()
139
+ except Exception:
140
+ pass
141
+
142
+
143
+ @app.get("/v1/handoff/sessions/last")
144
+ async def handoff_last_session(
145
+ agent_id: Optional[str] = Query(default=None),
146
+ repo: Optional[str] = Query(default=None),
147
+ fallback_log_recovery: bool = Query(default=True),
148
+ ):
149
+ """Get the last session, falling back to JSONL log parsing."""
150
+ from dhee.core.kernel import get_last_session
151
+
152
+ session = get_last_session(
153
+ agent_id=agent_id or "mcp-server",
154
+ repo=repo,
155
+ fallback_log_recovery=fallback_log_recovery,
156
+ )
157
+ if session is None:
158
+ return {"status": "no_session", "message": "No previous session found."}
159
+ return session
160
+
161
+
162
+ @app.post("/v1/handoff/recover")
163
+ async def handoff_recover(request: RecoverRequest):
164
+ """Direct log recovery — parse JSONL logs without checking bus first."""
165
+ from dhee.core.log_parser import find_latest_log, parse_conversation_log
166
+
167
+ log_path = find_latest_log(request.repo_path)
168
+ if log_path is None:
169
+ return {"status": "no_logs", "message": "No conversation logs found."}
170
+
171
+ digest = parse_conversation_log(log_path)
172
+ if digest.get("message_count", 0) == 0:
173
+ return {"status": "empty_log", "message": "Log file was empty."}
174
+
175
+ return digest
176
+
177
+
178
+ @app.post("/v1/handoff/sessions/digest")
179
+ async def save_handoff_digest(request: SessionDigestRequest):
180
+ """Save a session digest (lightweight, no auth)."""
181
+ from dhee.core.kernel import save_session_digest
182
+
183
+ result = save_session_digest(
184
+ task_summary=request.task_summary,
185
+ agent_id=request.agent_id,
186
+ repo=request.repo,
187
+ status=request.status,
188
+ decisions_made=request.decisions_made,
189
+ files_touched=request.files_touched,
190
+ todos_remaining=request.todos_remaining,
191
+ blockers=request.blockers,
192
+ key_commands=request.key_commands,
193
+ test_results=request.test_results,
194
+ )
195
+ return result
dhee/api/server.py ADDED
@@ -0,0 +1,34 @@
1
+ """Engram core API server runner.
2
+
3
+ Starts the lightweight FastAPI app from ``engram.api.app`` on the configured
4
+ host/port. This is the standalone server — the enterprise version adds auth,
5
+ governance, and more endpoints on top.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+
11
+ def run():
12
+ """Run the Engram core API server."""
13
+ import argparse
14
+ import uvicorn
15
+
16
+ parser = argparse.ArgumentParser(description="Engram Core API Server")
17
+ parser.add_argument("--host", default="127.0.0.1", help="Host to bind to")
18
+ parser.add_argument("--port", type=int, default=8100, help="Port to listen on")
19
+ parser.add_argument("--reload", action="store_true", help="Enable auto-reload")
20
+ args = parser.parse_args()
21
+
22
+ print(f"Starting Engram Core API on http://{args.host}:{args.port}")
23
+ print(f"Docs at http://{args.host}:{args.port}/docs")
24
+
25
+ uvicorn.run(
26
+ "engram.api.app:app",
27
+ host=args.host,
28
+ port=args.port,
29
+ reload=args.reload,
30
+ )
31
+
32
+
33
+ if __name__ == "__main__":
34
+ run()
@@ -0,0 +1,2 @@
1
+ """Benchmark runners for Engram."""
2
+
@@ -0,0 +1,383 @@
1
+ """ARC-AGI benchmark runner for Engram.
2
+
3
+ Tests abstract reasoning using Engram memory + LLM on Chollet's
4
+ Abstraction and Reasoning Corpus (ARC-AGI).
5
+
6
+ Two modes:
7
+ 1. Direct: LLM sees training examples and predicts test output.
8
+ 2. Memory-augmented: solved patterns stored in Engram memory;
9
+ similar patterns retrieved as extra context for new tasks.
10
+
11
+ Usage:
12
+ python -m engram.benchmarks.arc_agi \
13
+ --data-dir data/arc-agi/evaluation \
14
+ --max-tasks 50 \
15
+ --mode memory
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import argparse
21
+ import json
22
+ import logging
23
+ import os
24
+ import re
25
+ import sys
26
+ import tempfile
27
+ import time
28
+ from pathlib import Path
29
+ from typing import Any, Dict, List, Optional, Tuple
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # ── Grid helpers ──────────────────────────────────────────
34
+
35
+ COLOR_NAMES = {
36
+ 0: "black", 1: "blue", 2: "red", 3: "green", 4: "yellow",
37
+ 5: "grey", 6: "magenta", 7: "orange", 8: "cyan", 9: "maroon",
38
+ }
39
+
40
+
41
+ def grid_to_text(grid: List[List[int]]) -> str:
42
+ """Render a grid as a compact text block with row/col indices."""
43
+ rows = len(grid)
44
+ cols = len(grid[0]) if grid else 0
45
+ header = " " + " ".join(f"{c:>2}" for c in range(cols))
46
+ lines = [f"({rows}x{cols} grid)", header]
47
+ for r, row in enumerate(grid):
48
+ lines.append(f"{r:>2} " + " ".join(f"{v:>2}" for v in row))
49
+ return "\n".join(lines)
50
+
51
+
52
+ def grids_equal(a: List[List[int]], b: List[List[int]]) -> bool:
53
+ if len(a) != len(b):
54
+ return False
55
+ return all(row_a == row_b for row_a, row_b in zip(a, b))
56
+
57
+
58
+ def parse_grid_from_text(text: str, expected_rows: int = 0, expected_cols: int = 0) -> Optional[List[List[int]]]:
59
+ """Best-effort parse a grid from LLM output text."""
60
+ # Strip thinking blocks (e.g. <think>...</think> from reasoning models)
61
+ text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
62
+ # Try JSON array first
63
+ try:
64
+ # Find the outermost [[...]] pattern
65
+ match = re.search(r"\[\s*\[.*?\]\s*\]", text, re.DOTALL)
66
+ if match:
67
+ grid = json.loads(match.group())
68
+ if isinstance(grid, list) and all(isinstance(r, list) for r in grid):
69
+ return grid
70
+ except (json.JSONDecodeError, ValueError):
71
+ pass
72
+
73
+ # Try row-by-row number extraction
74
+ lines = text.strip().split("\n")
75
+ grid = []
76
+ for line in lines:
77
+ nums = re.findall(r"\b(\d)\b", line)
78
+ if nums and len(nums) >= 2:
79
+ grid.append([int(n) for n in nums])
80
+ if grid and len(grid) >= 2:
81
+ # Normalize column counts
82
+ max_cols = max(len(r) for r in grid)
83
+ if all(len(r) == max_cols for r in grid):
84
+ return grid
85
+
86
+ return None
87
+
88
+
89
+ # ── Prompt construction ───────────────────────────────────
90
+
91
+ def format_task_prompt(
92
+ train_pairs: List[Dict[str, Any]],
93
+ test_input: List[List[int]],
94
+ memory_context: str = "",
95
+ ) -> str:
96
+ """Build the LLM prompt for one ARC task."""
97
+ parts = [
98
+ "You are solving an ARC-AGI abstract reasoning puzzle.",
99
+ "Each puzzle has training examples showing input→output grid transformations.",
100
+ "Find the pattern and apply it to the test input.",
101
+ "",
102
+ ]
103
+
104
+ if memory_context:
105
+ parts.extend([
106
+ "Here are similar patterns you solved before:",
107
+ memory_context,
108
+ "",
109
+ ])
110
+
111
+ for i, pair in enumerate(train_pairs):
112
+ parts.append(f"=== Training Example {i+1} ===")
113
+ parts.append("Input:")
114
+ parts.append(grid_to_text(pair["input"]))
115
+ parts.append("Output:")
116
+ parts.append(grid_to_text(pair["output"]))
117
+ parts.append("")
118
+
119
+ parts.append("=== Test ===")
120
+ parts.append("Input:")
121
+ parts.append(grid_to_text(test_input))
122
+ parts.append("")
123
+ parts.append(
124
+ "Analyze the pattern from the training examples. "
125
+ "Then output ONLY the predicted output grid as a JSON 2D array (e.g. [[0,1],[2,3]]). "
126
+ "No explanation, just the JSON array."
127
+ )
128
+
129
+ return "\n".join(parts)
130
+
131
+
132
+ def describe_pattern(
133
+ train_pairs: List[Dict[str, Any]],
134
+ test_input: List[List[int]],
135
+ test_output: List[List[int]],
136
+ ) -> str:
137
+ """Create a textual description of a solved task for memory storage."""
138
+ in_shapes = [f"{len(p['input'])}x{len(p['input'][0])}" for p in train_pairs]
139
+ out_shapes = [f"{len(p['output'])}x{len(p['output'][0])}" for p in train_pairs]
140
+ test_in_shape = f"{len(test_input)}x{len(test_input[0])}"
141
+ test_out_shape = f"{len(test_output)}x{len(test_output[0])}"
142
+
143
+ unique_vals = set()
144
+ for p in train_pairs:
145
+ for row in p["input"]:
146
+ unique_vals.update(row)
147
+ for row in p["output"]:
148
+ unique_vals.update(row)
149
+
150
+ return (
151
+ f"ARC pattern: input shapes {in_shapes}, output shapes {out_shapes}. "
152
+ f"Test: {test_in_shape} → {test_out_shape}. "
153
+ f"Colors used: {sorted(unique_vals)}. "
154
+ f"Training examples: {len(train_pairs)}."
155
+ )
156
+
157
+
158
+ # ── Benchmark runner ──────────────────────────────────────
159
+
160
+ def load_tasks(data_dir: str) -> Dict[str, Dict[str, Any]]:
161
+ tasks = {}
162
+ for path in sorted(Path(data_dir).glob("*.json")):
163
+ with open(path) as f:
164
+ tasks[path.stem] = json.load(f)
165
+ return tasks
166
+
167
+
168
+ def run_arc_benchmark(args: argparse.Namespace) -> Dict[str, Any]:
169
+ # Load .env (check CWD first, then project root)
170
+ env_path = os.path.join(os.getcwd(), ".env")
171
+ if not os.path.exists(env_path):
172
+ env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".env")
173
+ if os.path.exists(env_path):
174
+ with open(env_path) as f:
175
+ for line in f:
176
+ line = line.strip()
177
+ if "=" in line and not line.startswith("#"):
178
+ key, _, val = line.partition("=")
179
+ os.environ.setdefault(key.strip(), val.strip().strip('"').strip("'"))
180
+
181
+ tasks = load_tasks(args.data_dir)
182
+ task_ids = sorted(tasks.keys())
183
+ if args.max_tasks > 0:
184
+ task_ids = task_ids[: args.max_tasks]
185
+
186
+ print(f"ARC-AGI Benchmark: {len(task_ids)} tasks from {args.data_dir}")
187
+ print(f"Mode: {args.mode}")
188
+ print(f"LLM: {args.llm_provider}/{args.llm_model}")
189
+ print()
190
+
191
+ # Build LLM
192
+ from dhee.utils.factory import LLMFactory
193
+ llm_config = {
194
+ "model": args.llm_model,
195
+ "temperature": args.temperature,
196
+ "top_p": args.top_p,
197
+ "max_tokens": args.max_tokens,
198
+ "timeout": args.timeout,
199
+ "enable_thinking": args.enable_thinking,
200
+ }
201
+ if args.api_key:
202
+ llm_config["api_key"] = args.api_key
203
+ llm = LLMFactory.create(args.llm_provider, llm_config)
204
+
205
+ # Build memory (only in memory mode)
206
+ memory = None
207
+ if args.mode == "memory":
208
+ from dhee.configs.base import MemoryConfig
209
+ from dhee.memory.main import Memory
210
+
211
+ tmpdir = tempfile.mkdtemp(prefix="arc_bench_")
212
+ config = MemoryConfig(
213
+ vector_store={"provider": "memory", "config": {}},
214
+ llm={"provider": args.llm_provider, "config": {"model": args.llm_model, "temperature": args.temperature, "max_tokens": args.max_tokens, "timeout": args.timeout}},
215
+ embedder={"provider": args.embedder_provider, "config": {"model": args.embedder_model}},
216
+ history_db_path=os.path.join(tmpdir, "arc.db"),
217
+ embedding_model_dims=args.embedding_dims,
218
+ echo={"enable_echo": False},
219
+ category={"enable_categories": False},
220
+ graph={"enable_graph": False},
221
+ scene={"enable_scenes": False},
222
+ profile={"enable_profiles": False},
223
+ )
224
+ memory = Memory(config)
225
+ print(f"Embedder: {args.embedder_provider}/{args.embedder_model}")
226
+ print(f"Memory DB: {tmpdir}")
227
+ print()
228
+
229
+ solved = 0
230
+ failed = 0
231
+ errored = 0
232
+ results = []
233
+ t_start = time.time()
234
+
235
+ for idx, task_id in enumerate(task_ids):
236
+ task = tasks[task_id]
237
+ train_pairs = task["train"]
238
+ test_cases = task["test"]
239
+
240
+ for test_idx, test_case in enumerate(test_cases):
241
+ test_input = test_case["input"]
242
+ expected_output = test_case["output"]
243
+
244
+ # Memory retrieval
245
+ memory_context = ""
246
+ if memory and idx > 0:
247
+ try:
248
+ query = describe_pattern(train_pairs, test_input, [])
249
+ search_results = memory.search(query=query, user_id="arc", limit=3)
250
+ hits = search_results.get("results", [])
251
+ if hits:
252
+ memory_context = "\n".join(
253
+ f"- {h.get('memory', '')[:200]}" for h in hits
254
+ )
255
+ except Exception as e:
256
+ logger.debug("Memory search failed: %s", e)
257
+
258
+ prompt = format_task_prompt(train_pairs, test_input, memory_context)
259
+
260
+ try:
261
+ response = llm.generate(prompt)
262
+ predicted = parse_grid_from_text(response)
263
+
264
+ if predicted and grids_equal(predicted, expected_output):
265
+ solved += 1
266
+ status = "SOLVED"
267
+
268
+ # Store solved pattern in memory
269
+ if memory:
270
+ try:
271
+ pattern_desc = describe_pattern(
272
+ train_pairs, test_input, expected_output
273
+ )
274
+ memory.add(
275
+ messages=[{"role": "user", "content": pattern_desc}],
276
+ user_id="arc",
277
+ infer=False,
278
+ )
279
+ except Exception:
280
+ pass
281
+ else:
282
+ failed += 1
283
+ status = "WRONG"
284
+
285
+ results.append({
286
+ "task_id": task_id,
287
+ "test_idx": test_idx,
288
+ "status": status,
289
+ "predicted": predicted,
290
+ "expected_shape": f"{len(expected_output)}x{len(expected_output[0])}",
291
+ })
292
+
293
+ except Exception as e:
294
+ errored += 1
295
+ results.append({
296
+ "task_id": task_id,
297
+ "test_idx": test_idx,
298
+ "status": "ERROR",
299
+ "error": str(e),
300
+ })
301
+ print(f" [{idx+1}/{len(task_ids)}] {task_id}: ERROR — {e}")
302
+ continue
303
+
304
+ total_attempted = solved + failed + errored
305
+ score_pct = (solved / total_attempted * 100) if total_attempted else 0
306
+
307
+ if (idx + 1) % args.print_every == 0 or status == "SOLVED":
308
+ print(
309
+ f" [{idx+1}/{len(task_ids)}] {task_id}: {status} "
310
+ f"| Running: {solved}/{total_attempted} ({score_pct:.1f}%)"
311
+ )
312
+
313
+ elapsed = time.time() - t_start
314
+ total_attempted = solved + failed + errored
315
+ score = solved / total_attempted if total_attempted else 0
316
+
317
+ print()
318
+ print("=" * 60)
319
+ print("ARC-AGI BENCHMARK RESULTS")
320
+ print("=" * 60)
321
+ print(f" Tasks attempted: {total_attempted}")
322
+ print(f" Solved: {solved}")
323
+ print(f" Wrong: {failed}")
324
+ print(f" Errors: {errored}")
325
+ print(f" Score: {score:.1%} ({solved}/{total_attempted})")
326
+ print(f" Time: {elapsed:.0f}s ({elapsed/max(total_attempted,1):.1f}s/task)")
327
+ print(f" Mode: {args.mode}")
328
+ print(f" LLM: {args.llm_provider}/{args.llm_model}")
329
+ if memory:
330
+ print(f" Embedder: {args.embedder_provider}/{args.embedder_model}")
331
+ print()
332
+
333
+ summary = {
334
+ "score": round(score, 4),
335
+ "solved": solved,
336
+ "failed": failed,
337
+ "errored": errored,
338
+ "total": total_attempted,
339
+ "elapsed_s": round(elapsed, 1),
340
+ "mode": args.mode,
341
+ "llm": f"{args.llm_provider}/{args.llm_model}",
342
+ }
343
+
344
+ if args.output_json:
345
+ out_path = Path(args.output_json)
346
+ out_path.parent.mkdir(parents=True, exist_ok=True)
347
+ with open(out_path, "w") as f:
348
+ json.dump({"summary": summary, "results": results}, f, indent=2)
349
+ print(f" Results saved to: {args.output_json}")
350
+
351
+ return summary
352
+
353
+
354
+ def parse_args() -> argparse.Namespace:
355
+ parser = argparse.ArgumentParser(description="Run ARC-AGI benchmark with Engram memory + LLM.")
356
+ parser.add_argument("--data-dir", default="data/arc-agi/evaluation", help="Directory with ARC task JSON files.")
357
+ parser.add_argument("--max-tasks", type=int, default=50, help="Max tasks to evaluate (-1 = all).")
358
+ parser.add_argument("--mode", choices=["direct", "memory"], default="direct", help="direct: LLM only. memory: Engram memory-augmented.")
359
+ parser.add_argument("--output-json", default=None, help="Path to save detailed results JSON.")
360
+ parser.add_argument("--print-every", type=int, default=5, help="Progress print interval.")
361
+ parser.add_argument("--timeout", type=int, default=120, help="LLM call timeout in seconds.")
362
+
363
+ parser.add_argument("--llm-provider", default="nvidia", choices=["nvidia", "openai", "gemini", "ollama"])
364
+ parser.add_argument("--llm-model", default="deepseek-ai/deepseek-r1-distill-qwen-14b")
365
+ parser.add_argument("--api-key", default=None, help="Override LLM API key.")
366
+ parser.add_argument("--temperature", type=float, default=0.0, help="LLM temperature.")
367
+ parser.add_argument("--top-p", type=float, default=0.7, help="LLM top-p.")
368
+ parser.add_argument("--max-tokens", type=int, default=2048, help="LLM max output tokens.")
369
+ parser.add_argument("--enable-thinking", action="store_true", help="Enable thinking/CoT mode for supported models.")
370
+ parser.add_argument("--embedder-provider", default="nvidia", choices=["nvidia", "openai", "gemini", "simple"])
371
+ parser.add_argument("--embedder-model", default="nvidia/nv-embed-v1")
372
+ parser.add_argument("--embedding-dims", type=int, default=4096)
373
+
374
+ return parser.parse_args()
375
+
376
+
377
+ def main() -> None:
378
+ args = parse_args()
379
+ run_arc_benchmark(args)
380
+
381
+
382
+ if __name__ == "__main__":
383
+ main()