dhee 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dhee/__init__.py +64 -0
- dhee/api/__init__.py +6 -0
- dhee/api/app.py +195 -0
- dhee/api/server.py +34 -0
- dhee/benchmarks/__init__.py +2 -0
- dhee/benchmarks/arc_agi.py +383 -0
- dhee/benchmarks/longmemeval.py +2177 -0
- dhee/cli.py +424 -0
- dhee/cli_config.py +139 -0
- dhee/cli_mcp.py +189 -0
- dhee/cli_setup.py +75 -0
- dhee/configs/__init__.py +19 -0
- dhee/configs/active.py +65 -0
- dhee/configs/base.py +686 -0
- dhee/configs/presets.py +152 -0
- dhee/core/__init__.py +22 -0
- dhee/core/agi_loop.py +212 -0
- dhee/core/alaya.py +330 -0
- dhee/core/answer_orchestration.py +942 -0
- dhee/core/buddhi.py +934 -0
- dhee/core/category.py +796 -0
- dhee/core/code_exec_counter.py +266 -0
- dhee/core/cognition.py +481 -0
- dhee/core/conflict.py +60 -0
- dhee/core/consolidation.py +112 -0
- dhee/core/decay.py +52 -0
- dhee/core/distillation.py +232 -0
- dhee/core/echo.py +634 -0
- dhee/core/engram.py +431 -0
- dhee/core/engram_extractor.py +609 -0
- dhee/core/enrichment.py +631 -0
- dhee/core/episodic_index.py +979 -0
- dhee/core/evolution.py +356 -0
- dhee/core/forgetting.py +391 -0
- dhee/core/fusion.py +68 -0
- dhee/core/graph.py +566 -0
- dhee/core/intent.py +93 -0
- dhee/core/kernel.py +142 -0
- dhee/core/log_parser.py +197 -0
- dhee/core/metrics.py +207 -0
- dhee/core/profile.py +504 -0
- dhee/core/proposition_context.py +144 -0
- dhee/core/resolvers.py +949 -0
- dhee/core/retrieval.py +171 -0
- dhee/core/salience.py +113 -0
- dhee/core/samskara.py +510 -0
- dhee/core/scene.py +381 -0
- dhee/core/traces.py +120 -0
- dhee/core/viveka.py +708 -0
- dhee/db/__init__.py +0 -0
- dhee/db/sqlite.py +2845 -0
- dhee/db/sqlite_backup.py +2070 -0
- dhee/decay/__init__.py +5 -0
- dhee/embeddings/__init__.py +0 -0
- dhee/embeddings/base.py +21 -0
- dhee/embeddings/gemini.py +83 -0
- dhee/embeddings/nvidia.py +116 -0
- dhee/embeddings/ollama.py +66 -0
- dhee/embeddings/openai.py +47 -0
- dhee/embeddings/qwen.py +139 -0
- dhee/embeddings/simple.py +65 -0
- dhee/exceptions.py +19 -0
- dhee/integrations/__init__.py +1 -0
- dhee/llms/__init__.py +0 -0
- dhee/llms/base.py +56 -0
- dhee/llms/dhee.py +295 -0
- dhee/llms/gemini.py +60 -0
- dhee/llms/mock.py +35 -0
- dhee/llms/nvidia.py +136 -0
- dhee/llms/ollama.py +58 -0
- dhee/llms/openai.py +35 -0
- dhee/llms/teacher_logger.py +243 -0
- dhee/mcp_server.py +1025 -0
- dhee/mcp_slim.py +442 -0
- dhee/memory/__init__.py +14 -0
- dhee/memory/base.py +23 -0
- dhee/memory/core.py +440 -0
- dhee/memory/main.py +6103 -0
- dhee/memory/parallel.py +60 -0
- dhee/memory/projects.py +395 -0
- dhee/memory/smart.py +507 -0
- dhee/memory/tasks.py +683 -0
- dhee/memory/utils.py +173 -0
- dhee/observability.py +49 -0
- dhee/retrieval/__init__.py +10 -0
- dhee/retrieval/reranker.py +252 -0
- dhee/simple.py +362 -0
- dhee/skills/__init__.py +7 -0
- dhee/skills/discovery.py +59 -0
- dhee/skills/executor.py +262 -0
- dhee/skills/hashing.py +81 -0
- dhee/skills/miner.py +374 -0
- dhee/skills/outcomes.py +151 -0
- dhee/skills/schema.py +241 -0
- dhee/skills/store.py +282 -0
- dhee/skills/structure.py +498 -0
- dhee/skills/trajectory.py +260 -0
- dhee/teaching/__init__.py +17 -0
- dhee/teaching/concepts.py +307 -0
- dhee/teaching/config.py +27 -0
- dhee/teaching/student_model.py +372 -0
- dhee/teaching/teaching_memory.py +255 -0
- dhee/utils/__init__.py +0 -0
- dhee/utils/factory.py +169 -0
- dhee/utils/math.py +25 -0
- dhee/utils/prompts.py +382 -0
- dhee/utils/repo_identity.py +72 -0
- dhee/vector_stores/__init__.py +0 -0
- dhee/vector_stores/base.py +61 -0
- dhee/vector_stores/memory.py +106 -0
- dhee/vector_stores/sqlite_vec.py +391 -0
- dhee/vector_stores/zvec_store.py +402 -0
- dhee-1.0.0.dist-info/METADATA +342 -0
- dhee-1.0.0.dist-info/RECORD +130 -0
- dhee-1.0.0.dist-info/WHEEL +5 -0
- dhee-1.0.0.dist-info/entry_points.txt +4 -0
- dhee-1.0.0.dist-info/licenses/LICENSE +21 -0
- dhee-1.0.0.dist-info/top_level.txt +3 -0
- dheeModel/__init__.py +18 -0
- dheeModel/client.py +385 -0
- dheeModel/model/__init__.py +1 -0
- dheeModel/model/dhee_model.py +167 -0
- dheeModel/training/__init__.py +1 -0
- dheeModel/training/data_formatter.py +155 -0
- dheeModel/training/karma.py +272 -0
- dheeModel/training/nididhyasana.py +660 -0
- dheeModel/training/smrti.py +411 -0
- dheeModel/training/train.py +321 -0
- dhee_shared/__init__.py +1 -0
- dhee_shared/model_paths.py +63 -0
dhee/__init__.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""dhee — Cognition as a Service. The memory layer that makes ANY agent intelligent.
|
|
2
|
+
|
|
3
|
+
- FadeMem: Dual-layer (SML/LML) with natural decay
|
|
4
|
+
- EchoMem: Multi-modal encoding for stronger retention
|
|
5
|
+
- CategoryMem: Dynamic hierarchical category organization
|
|
6
|
+
- Universal Engram: Structured facts + context anchoring
|
|
7
|
+
- Cognition Engine: Memory-grounded recursive reasoning
|
|
8
|
+
- Prospective Scenes: Memory-driven future anticipation
|
|
9
|
+
|
|
10
|
+
Quick Start (zero-config, no API key):
|
|
11
|
+
from dhee import Memory
|
|
12
|
+
m = Memory()
|
|
13
|
+
m.add("User prefers Python")
|
|
14
|
+
results = m.search("programming preferences")
|
|
15
|
+
|
|
16
|
+
Tiered Memory Classes:
|
|
17
|
+
CoreMemory — lightweight: add/search/delete + decay (no LLM)
|
|
18
|
+
SmartMemory — + echo encoding, categories, knowledge graph (needs LLM)
|
|
19
|
+
FullMemory — + scenes, profiles, tasks, cognition (everything)
|
|
20
|
+
Memory — alias for CoreMemory (lightest default)
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from dhee.memory.core import CoreMemory
|
|
24
|
+
from dhee.memory.smart import SmartMemory
|
|
25
|
+
from dhee.memory.main import FullMemory
|
|
26
|
+
from dhee.simple import Engram
|
|
27
|
+
from dhee.core.category import CategoryProcessor, Category, CategoryType, CategoryMatch
|
|
28
|
+
from dhee.core.echo import EchoProcessor, EchoDepth, EchoResult
|
|
29
|
+
from dhee.configs.base import MemoryConfig, FadeMemConfig, EchoMemConfig, CategoryMemConfig, ScopeConfig
|
|
30
|
+
|
|
31
|
+
# Default: CoreMemory (lightest, zero-config)
|
|
32
|
+
Memory = CoreMemory
|
|
33
|
+
|
|
34
|
+
__version__ = "1.0.0"
|
|
35
|
+
__all__ = [
|
|
36
|
+
# Tiered memory classes
|
|
37
|
+
"CoreMemory",
|
|
38
|
+
"SmartMemory",
|
|
39
|
+
"FullMemory",
|
|
40
|
+
"Memory",
|
|
41
|
+
# Simplified interface
|
|
42
|
+
"Engram",
|
|
43
|
+
# CategoryMem
|
|
44
|
+
"CategoryProcessor",
|
|
45
|
+
"Category",
|
|
46
|
+
"CategoryType",
|
|
47
|
+
"CategoryMatch",
|
|
48
|
+
# EchoMem
|
|
49
|
+
"EchoProcessor",
|
|
50
|
+
"EchoDepth",
|
|
51
|
+
"EchoResult",
|
|
52
|
+
# Config
|
|
53
|
+
"MemoryConfig",
|
|
54
|
+
"FadeMemConfig",
|
|
55
|
+
"EchoMemConfig",
|
|
56
|
+
"CategoryMemConfig",
|
|
57
|
+
"ScopeConfig",
|
|
58
|
+
]
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _load_teaching():
    """Lazy-load teaching module to avoid import overhead when not needed.

    Returns:
        Tuple of the (ConceptStore, StudentModel, TeachingMemory,
        TeachingConfig) classes re-exported from ``dhee.teaching``.
        The import happens on each call; any caching is up to the caller.
    """
    # Deferred import keeps ``import dhee`` cheap for users who never touch
    # the teaching subsystem.
    from dhee.teaching import ConceptStore, StudentModel, TeachingMemory, TeachingConfig
    return ConceptStore, StudentModel, TeachingMemory, TeachingConfig
|
dhee/api/__init__.py
ADDED
dhee/api/app.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
"""Engram core REST API — lightweight handoff endpoints (no auth required).
|
|
2
|
+
|
|
3
|
+
These endpoints mirror the enterprise ``/v1/handoff/*`` routes but delegate
|
|
4
|
+
directly to ``engram.core.kernel`` without session/token enforcement.
|
|
5
|
+
They are intended for local development and for the ``prompt_context.py`` hook
|
|
6
|
+
which fires as a subprocess with no auth context.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
import os
|
|
13
|
+
from typing import Any, Dict, List, Optional
|
|
14
|
+
|
|
15
|
+
from fastapi import FastAPI, Query
|
|
16
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
17
|
+
from pydantic import BaseModel, Field
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
app = FastAPI(
|
|
22
|
+
title="Engram Core API",
|
|
23
|
+
version="0.1.0",
|
|
24
|
+
description="Lightweight handoff + memory endpoints.",
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
app.add_middleware(
|
|
28
|
+
CORSMiddleware,
|
|
29
|
+
allow_origins=["*"],
|
|
30
|
+
allow_methods=["*"],
|
|
31
|
+
allow_headers=["*"],
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# ---------------------------------------------------------------------------
|
|
36
|
+
# Health
|
|
37
|
+
# ---------------------------------------------------------------------------
|
|
38
|
+
|
|
39
|
+
@app.get("/health")
async def health():
    """Liveness probe: always reports the service as up."""
    payload = {"status": "ok"}
    return payload
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# ---------------------------------------------------------------------------
|
|
45
|
+
# Request / response schemas
|
|
46
|
+
# ---------------------------------------------------------------------------
|
|
47
|
+
|
|
48
|
+
class CheckpointRequest(BaseModel):
    """Payload for ``POST /v1/handoff/checkpoint``.

    Only ``event_type`` and ``agent_id`` have concrete defaults; everything
    else is optional so a hook can send partial snapshots.
    """

    # Short human-readable summary of the task in progress.
    task_summary: Optional[str] = None
    # Origin of the checkpoint (default distinguishes hook-driven saves).
    event_type: str = "hook_checkpoint"
    # Identifier of the agent writing the checkpoint.
    agent_id: str = "claude-code"
    # Free-form dump of the agent's working context, if any.
    context_snapshot: Optional[str] = None
    # Repository the session is associated with; used when creating a session.
    repo_path: Optional[str] = None
    # Session status; falls back to "active" when a session is created.
    status: Optional[str] = None
    # Structured progress fields — all normalized to [] in the stored snapshot.
    decisions_made: Optional[List[str]] = None
    files_touched: Optional[List[str]] = None
    todos_remaining: Optional[List[str]] = None
    blockers: Optional[List[str]] = None
    key_commands: Optional[List[str]] = None
    # Last test run output, if the agent captured one.
    test_results: Optional[str] = None
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class RecoverRequest(BaseModel):
    """Payload for ``POST /v1/handoff/recover`` — direct JSONL log recovery."""

    # Repository whose conversation logs should be located and parsed.
    repo_path: str
    # Agent identifier (accepted for parity with the other handoff routes;
    # NOTE(review): the recover endpoint itself does not appear to use it).
    agent_id: str = "claude-code"
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class SessionDigestRequest(BaseModel):
    """Payload for ``POST /v1/handoff/sessions/digest``.

    Mirrors the keyword arguments of ``dhee.core.kernel.save_session_digest``;
    only ``task_summary`` is required.
    """

    # Required one-line description of the session's work.
    task_summary: str
    # Repository the digest belongs to, if known.
    repo: Optional[str] = None
    # Session status; defaults to an in-progress session.
    status: str = "active"
    # Identifier of the agent saving the digest.
    agent_id: str = "claude-code"
    # Optional structured progress fields, passed through verbatim.
    decisions_made: Optional[List[str]] = None
    files_touched: Optional[List[str]] = None
    todos_remaining: Optional[List[str]] = None
    blockers: Optional[List[str]] = None
    key_commands: Optional[List[str]] = None
    test_results: Optional[str] = None
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# ---------------------------------------------------------------------------
|
|
82
|
+
# Handoff endpoints
|
|
83
|
+
# ---------------------------------------------------------------------------
|
|
84
|
+
|
|
85
|
+
@app.post("/v1/handoff/checkpoint")
async def handoff_checkpoint(request: CheckpointRequest):
    """Receive a lightweight checkpoint from the hook or an agent.

    Creates a dhee-bus session (if needed) and writes a checkpoint snapshot.

    Returns a JSON body with ``status`` "ok" plus the session/checkpoint ids,
    or ``status`` "error" with a detail string. NOTE(review): failures are
    reported in a 200 body rather than an HTTP error — presumably so the
    fire-and-forget hook never sees a failed request; confirm with callers.
    """
    from dhee.core.kernel import _get_bus

    # The bus handle is opened per request and always closed in ``finally``.
    bus = None
    try:
        bus = _get_bus()

        # Find or create a session for this agent
        session = bus.get_session(agent_id=request.agent_id)
        if session is None:
            # No prior session: seed a new one from the request fields.
            sid = bus.save_session(
                agent_id=request.agent_id,
                repo=request.repo_path,
                status=request.status or "active",
                task_summary=request.task_summary or "",
            )
        else:
            sid = session["id"]
            # Update task_summary if provided
            updates: Dict[str, Any] = {}
            if request.task_summary:
                updates["task_summary"] = request.task_summary
            if request.status:
                updates["status"] = request.status
            if updates:
                bus.update_session(sid, **updates)

        # Normalize the list fields to [] so stored snapshots have a
        # stable shape regardless of which fields the caller omitted.
        snapshot = {
            "event_type": request.event_type,
            "task_summary": request.task_summary,
            "context_snapshot": request.context_snapshot,
            "files_touched": request.files_touched or [],
            "key_commands": request.key_commands or [],
            "decisions_made": request.decisions_made or [],
            "todos_remaining": request.todos_remaining or [],
            "blockers": request.blockers or [],
            "test_results": request.test_results,
        }
        cid = bus.checkpoint(sid, request.agent_id, snapshot)

        return {"status": "ok", "session_id": sid, "checkpoint_id": cid}

    except Exception as exc:
        # Log the full traceback server-side; return only the message.
        logger.exception("Checkpoint failed")
        return {"status": "error", "detail": str(exc)}
    finally:
        if bus is not None:
            try:
                bus.close()
            except Exception:
                # Best-effort close — a failed close must not mask the response.
                pass
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
@app.get("/v1/handoff/sessions/last")
async def handoff_last_session(
    agent_id: Optional[str] = Query(default=None),
    repo: Optional[str] = Query(default=None),
    fallback_log_recovery: bool = Query(default=True),
):
    """Get the last session, falling back to JSONL log parsing."""
    from dhee.core.kernel import get_last_session

    # An unspecified agent defaults to the MCP server's identity.
    effective_agent = agent_id or "mcp-server"
    found = get_last_session(
        agent_id=effective_agent,
        repo=repo,
        fallback_log_recovery=fallback_log_recovery,
    )
    if found is not None:
        return found
    return {"status": "no_session", "message": "No previous session found."}
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
@app.post("/v1/handoff/recover")
async def handoff_recover(request: RecoverRequest):
    """Direct log recovery — parse JSONL logs without checking bus first."""
    from dhee.core.log_parser import find_latest_log, parse_conversation_log

    latest = find_latest_log(request.repo_path)
    if latest is None:
        return {"status": "no_logs", "message": "No conversation logs found."}

    digest = parse_conversation_log(latest)
    # A log with zero messages is not useful to the caller.
    if digest.get("message_count", 0) == 0:
        return {"status": "empty_log", "message": "Log file was empty."}
    return digest
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
@app.post("/v1/handoff/sessions/digest")
async def save_handoff_digest(request: SessionDigestRequest):
    """Save a session digest (lightweight, no auth)."""
    from dhee.core.kernel import save_session_digest

    # Mirror the request model field-for-field into the kernel call.
    payload = {
        "task_summary": request.task_summary,
        "agent_id": request.agent_id,
        "repo": request.repo,
        "status": request.status,
        "decisions_made": request.decisions_made,
        "files_touched": request.files_touched,
        "todos_remaining": request.todos_remaining,
        "blockers": request.blockers,
        "key_commands": request.key_commands,
        "test_results": request.test_results,
    }
    return save_session_digest(**payload)
|
dhee/api/server.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Engram core API server runner.
|
|
2
|
+
|
|
3
|
+
Starts the lightweight FastAPI app from ``dhee.api.app`` on the configured
|
|
4
|
+
host/port. This is the standalone server — the enterprise version adds auth,
|
|
5
|
+
governance, and more endpoints on top.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def run():
    """Run the Engram core API server.

    Parses ``--host``/``--port``/``--reload`` from the command line and
    serves the FastAPI app defined in ``dhee.api.app`` via uvicorn. Blocks
    until the server exits.
    """
    import argparse
    import uvicorn

    parser = argparse.ArgumentParser(description="Engram Core API Server")
    parser.add_argument("--host", default="127.0.0.1", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8100, help="Port to listen on")
    parser.add_argument("--reload", action="store_true", help="Enable auto-reload")
    args = parser.parse_args()

    print(f"Starting Engram Core API on http://{args.host}:{args.port}")
    print(f"Docs at http://{args.host}:{args.port}/docs")

    # BUG FIX: the import string must reference this package ("dhee"), not
    # "engram" — every module in this distribution lives under dhee/, so the
    # old string made uvicorn raise ModuleNotFoundError at startup.
    uvicorn.run(
        "dhee.api.app:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
    )
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
if __name__ == "__main__":
|
|
34
|
+
run()
|
|
@@ -0,0 +1,383 @@
|
|
|
1
|
+
"""ARC-AGI benchmark runner for Engram.
|
|
2
|
+
|
|
3
|
+
Tests abstract reasoning using Engram memory + LLM on Chollet's
|
|
4
|
+
Abstraction and Reasoning Corpus (ARC-AGI).
|
|
5
|
+
|
|
6
|
+
Two modes:
|
|
7
|
+
1. Direct: LLM sees training examples and predicts test output.
|
|
8
|
+
2. Memory-augmented: solved patterns stored in Engram memory;
|
|
9
|
+
similar patterns retrieved as extra context for new tasks.
|
|
10
|
+
|
|
11
|
+
Usage:
|
|
12
|
+
    python -m dhee.benchmarks.arc_agi \
|
|
13
|
+
--data-dir data/arc-agi/evaluation \
|
|
14
|
+
--max-tasks 50 \
|
|
15
|
+
--mode memory
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import argparse
|
|
21
|
+
import json
|
|
22
|
+
import logging
|
|
23
|
+
import os
|
|
24
|
+
import re
|
|
25
|
+
import sys
|
|
26
|
+
import tempfile
|
|
27
|
+
import time
|
|
28
|
+
from pathlib import Path
|
|
29
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
# ── Grid helpers ──────────────────────────────────────────
|
|
34
|
+
|
|
35
|
+
COLOR_NAMES = {
|
|
36
|
+
0: "black", 1: "blue", 2: "red", 3: "green", 4: "yellow",
|
|
37
|
+
5: "grey", 6: "magenta", 7: "orange", 8: "cyan", 9: "maroon",
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def grid_to_text(grid: List[List[int]]) -> str:
    """Render a grid as a compact text block with row/col indices."""
    n_rows = len(grid)
    n_cols = len(grid[0]) if grid else 0

    # Shape banner, then a column-index header aligned with the cells below.
    out = [f"({n_rows}x{n_cols} grid)"]
    out.append("    " + " ".join(f"{c:>2}" for c in range(n_cols)))
    # One line per row, prefixed with its row index.
    out.extend(
        f"{r:>2} " + " ".join(f"{v:>2}" for v in row)
        for r, row in enumerate(grid)
    )
    return "\n".join(out)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def grids_equal(a: List[List[int]], b: List[List[int]]) -> bool:
    """True iff both grids contain identical rows in identical order."""
    if len(a) != len(b):
        return False
    for left_row, right_row in zip(a, b):
        if left_row != right_row:
            return False
    return True
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def parse_grid_from_text(text: str, expected_rows: int = 0, expected_cols: int = 0) -> Optional[List[List[int]]]:
    """Best-effort parse a grid from LLM output text.

    Tries, in order: (1) a JSON 2D-array literal anywhere in the text,
    (2) extracting single-digit numbers line by line. Returns None when
    neither strategy yields a plausible grid.

    NOTE(review): ``expected_rows``/``expected_cols`` are accepted but not
    used by either strategy — confirm whether shape validation was intended.
    """
    # Strip thinking blocks (e.g. <think>...</think> from reasoning models)
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    # Try JSON array first
    try:
        # Find the outermost [[...]] pattern
        match = re.search(r"\[\s*\[.*?\]\s*\]", text, re.DOTALL)
        if match:
            grid = json.loads(match.group())
            # Accept only a list of lists; other JSON shapes fall through
            # to the line-based strategy below.
            if isinstance(grid, list) and all(isinstance(r, list) for r in grid):
                return grid
    except (json.JSONDecodeError, ValueError):
        pass

    # Try row-by-row number extraction
    lines = text.strip().split("\n")
    grid = []
    for line in lines:
        # Only standalone single digits count as cells (ARC colors are 0-9).
        nums = re.findall(r"\b(\d)\b", line)
        if nums and len(nums) >= 2:
            grid.append([int(n) for n in nums])
    if grid and len(grid) >= 2:
        # Normalize column counts
        # Require a rectangular result; ragged rows mean the extraction
        # picked up prose and the parse is rejected.
        max_cols = max(len(r) for r in grid)
        if all(len(r) == max_cols for r in grid):
            return grid

    return None
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# ── Prompt construction ───────────────────────────────────
|
|
90
|
+
|
|
91
|
+
def format_task_prompt(
    train_pairs: List[Dict[str, Any]],
    test_input: List[List[int]],
    memory_context: str = "",
) -> str:
    """Build the LLM prompt for one ARC task."""
    # Fixed task framing at the top of every prompt.
    lines: List[str] = [
        "You are solving an ARC-AGI abstract reasoning puzzle.",
        "Each puzzle has training examples showing input→output grid transformations.",
        "Find the pattern and apply it to the test input.",
        "",
    ]

    # Optional block of previously-solved patterns retrieved from memory.
    if memory_context:
        lines.append("Here are similar patterns you solved before:")
        lines.append(memory_context)
        lines.append("")

    # One rendered input/output pair per training example.
    for example_no, example in enumerate(train_pairs, start=1):
        lines += [
            f"=== Training Example {example_no} ===",
            "Input:",
            grid_to_text(example["input"]),
            "Output:",
            grid_to_text(example["output"]),
            "",
        ]

    # The test input plus strict output-format instructions.
    lines += [
        "=== Test ===",
        "Input:",
        grid_to_text(test_input),
        "",
        "Analyze the pattern from the training examples. "
        "Then output ONLY the predicted output grid as a JSON 2D array (e.g. [[0,1],[2,3]]). "
        "No explanation, just the JSON array.",
    ]

    return "\n".join(lines)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def describe_pattern(
    train_pairs: List[Dict[str, Any]],
    test_input: List[List[int]],
    test_output: List[List[int]],
) -> str:
    """Create a textual description of a solved task for memory storage.

    Args:
        train_pairs: Training examples, each with "input"/"output" grids.
        test_input: The test input grid.
        test_output: The test output grid; may be empty ([]) when the
            answer is not yet known (retrieval-query construction).

    Returns:
        A one-line summary of grid shapes, colors used, and example count.
    """

    def _shape(grid: List[List[int]]) -> str:
        # BUG FIX: run_arc_benchmark calls this with test_output=[] to build
        # its memory-retrieval query; the old f"{len(g)}x{len(g[0])}" raised
        # IndexError there (silently swallowed upstream), so retrieval never
        # produced a query. Empty grids now render as "0x0".
        if not grid:
            return "0x0"
        return f"{len(grid)}x{len(grid[0])}"

    in_shapes = [_shape(p["input"]) for p in train_pairs]
    out_shapes = [_shape(p["output"]) for p in train_pairs]
    test_in_shape = _shape(test_input)
    test_out_shape = _shape(test_output)

    # Collect every color value seen across all training grids.
    unique_vals = set()
    for p in train_pairs:
        for row in p["input"]:
            unique_vals.update(row)
        for row in p["output"]:
            unique_vals.update(row)

    return (
        f"ARC pattern: input shapes {in_shapes}, output shapes {out_shapes}. "
        f"Test: {test_in_shape} → {test_out_shape}. "
        f"Colors used: {sorted(unique_vals)}. "
        f"Training examples: {len(train_pairs)}."
    )
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# ── Benchmark runner ──────────────────────────────────────
|
|
159
|
+
|
|
160
|
+
def load_tasks(data_dir: str) -> Dict[str, Dict[str, Any]]:
    """Load every ARC task JSON file in *data_dir*, keyed by filename stem."""
    return {
        task_file.stem: json.loads(task_file.read_text())
        for task_file in sorted(Path(data_dir).glob("*.json"))
    }
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def run_arc_benchmark(args: argparse.Namespace) -> Dict[str, Any]:
    """Run the ARC-AGI benchmark loop and return a summary dict.

    Loads tasks from ``args.data_dir``, queries the configured LLM per test
    case, scores exact-match grid predictions, and (in memory mode) stores
    solved patterns in an Engram memory for retrieval on later tasks.
    Prints progress and a final report; optionally writes results JSON.
    """
    # Load .env (check CWD first, then project root)
    env_path = os.path.join(os.getcwd(), ".env")
    if not os.path.exists(env_path):
        env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".env")
    if os.path.exists(env_path):
        with open(env_path) as f:
            for line in f:
                line = line.strip()
                if "=" in line and not line.startswith("#"):
                    key, _, val = line.partition("=")
                    # setdefault: real environment variables win over .env.
                    os.environ.setdefault(key.strip(), val.strip().strip('"').strip("'"))

    tasks = load_tasks(args.data_dir)
    task_ids = sorted(tasks.keys())
    # max_tasks <= 0 means "run everything".
    if args.max_tasks > 0:
        task_ids = task_ids[: args.max_tasks]

    print(f"ARC-AGI Benchmark: {len(task_ids)} tasks from {args.data_dir}")
    print(f"Mode: {args.mode}")
    print(f"LLM: {args.llm_provider}/{args.llm_model}")
    print()

    # Build LLM
    from dhee.utils.factory import LLMFactory
    llm_config = {
        "model": args.llm_model,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "max_tokens": args.max_tokens,
        "timeout": args.timeout,
        "enable_thinking": args.enable_thinking,
    }
    if args.api_key:
        llm_config["api_key"] = args.api_key
    llm = LLMFactory.create(args.llm_provider, llm_config)

    # Build memory (only in memory mode)
    memory = None
    if args.mode == "memory":
        from dhee.configs.base import MemoryConfig
        from dhee.memory.main import Memory

        # Throwaway on-disk DB; only pattern storage/search is needed, so
        # all higher-level subsystems are disabled.
        tmpdir = tempfile.mkdtemp(prefix="arc_bench_")
        config = MemoryConfig(
            vector_store={"provider": "memory", "config": {}},
            llm={"provider": args.llm_provider, "config": {"model": args.llm_model, "temperature": args.temperature, "max_tokens": args.max_tokens, "timeout": args.timeout}},
            embedder={"provider": args.embedder_provider, "config": {"model": args.embedder_model}},
            history_db_path=os.path.join(tmpdir, "arc.db"),
            embedding_model_dims=args.embedding_dims,
            echo={"enable_echo": False},
            category={"enable_categories": False},
            graph={"enable_graph": False},
            scene={"enable_scenes": False},
            profile={"enable_profiles": False},
        )
        memory = Memory(config)
        print(f"Embedder: {args.embedder_provider}/{args.embedder_model}")
        print(f"Memory DB: {tmpdir}")
        print()

    solved = 0
    failed = 0
    errored = 0
    results = []
    t_start = time.time()

    for idx, task_id in enumerate(task_ids):
        task = tasks[task_id]
        train_pairs = task["train"]
        test_cases = task["test"]

        for test_idx, test_case in enumerate(test_cases):
            test_input = test_case["input"]
            expected_output = test_case["output"]

            # Memory retrieval
            # Skipped for the very first task (idx == 0): nothing stored yet.
            memory_context = ""
            if memory and idx > 0:
                try:
                    # NOTE(review): passes [] as test_output — with a
                    # describe_pattern that indexes test_output[0] this raises
                    # and is swallowed below, disabling retrieval; verify.
                    query = describe_pattern(train_pairs, test_input, [])
                    search_results = memory.search(query=query, user_id="arc", limit=3)
                    hits = search_results.get("results", [])
                    if hits:
                        memory_context = "\n".join(
                            f"- {h.get('memory', '')[:200]}" for h in hits
                        )
                except Exception as e:
                    # Retrieval is best-effort; failures only lose context.
                    logger.debug("Memory search failed: %s", e)

            prompt = format_task_prompt(train_pairs, test_input, memory_context)

            try:
                response = llm.generate(prompt)
                predicted = parse_grid_from_text(response)

                if predicted and grids_equal(predicted, expected_output):
                    solved += 1
                    status = "SOLVED"

                    # Store solved pattern in memory
                    if memory:
                        try:
                            pattern_desc = describe_pattern(
                                train_pairs, test_input, expected_output
                            )
                            # infer=False: store the description verbatim.
                            memory.add(
                                messages=[{"role": "user", "content": pattern_desc}],
                                user_id="arc",
                                infer=False,
                            )
                        except Exception:
                            # Storage is best-effort; never fail the benchmark.
                            pass
                else:
                    failed += 1
                    status = "WRONG"

                results.append({
                    "task_id": task_id,
                    "test_idx": test_idx,
                    "status": status,
                    "predicted": predicted,
                    "expected_shape": f"{len(expected_output)}x{len(expected_output[0])}",
                })

            except Exception as e:
                # LLM/parse errors count separately and skip progress printing.
                errored += 1
                results.append({
                    "task_id": task_id,
                    "test_idx": test_idx,
                    "status": "ERROR",
                    "error": str(e),
                })
                print(f" [{idx+1}/{len(task_ids)}] {task_id}: ERROR — {e}")
                continue

            total_attempted = solved + failed + errored
            score_pct = (solved / total_attempted * 100) if total_attempted else 0

            # Print on every solve, otherwise every print_every-th task.
            if (idx + 1) % args.print_every == 0 or status == "SOLVED":
                print(
                    f" [{idx+1}/{len(task_ids)}] {task_id}: {status} "
                    f"| Running: {solved}/{total_attempted} ({score_pct:.1f}%)"
                )

    elapsed = time.time() - t_start
    total_attempted = solved + failed + errored
    score = solved / total_attempted if total_attempted else 0

    print()
    print("=" * 60)
    print("ARC-AGI BENCHMARK RESULTS")
    print("=" * 60)
    print(f" Tasks attempted: {total_attempted}")
    print(f" Solved: {solved}")
    print(f" Wrong: {failed}")
    print(f" Errors: {errored}")
    print(f" Score: {score:.1%} ({solved}/{total_attempted})")
    print(f" Time: {elapsed:.0f}s ({elapsed/max(total_attempted,1):.1f}s/task)")
    print(f" Mode: {args.mode}")
    print(f" LLM: {args.llm_provider}/{args.llm_model}")
    if memory:
        print(f" Embedder: {args.embedder_provider}/{args.embedder_model}")
    print()

    summary = {
        "score": round(score, 4),
        "solved": solved,
        "failed": failed,
        "errored": errored,
        "total": total_attempted,
        "elapsed_s": round(elapsed, 1),
        "mode": args.mode,
        "llm": f"{args.llm_provider}/{args.llm_model}",
    }

    if args.output_json:
        out_path = Path(args.output_json)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with open(out_path, "w") as f:
            json.dump({"summary": summary, "results": results}, f, indent=2)
        print(f" Results saved to: {args.output_json}")

    return summary
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface for the ARC benchmark."""
    parser = argparse.ArgumentParser(description="Run ARC-AGI benchmark with Engram memory + LLM.")
    # Benchmark control.
    parser.add_argument("--data-dir", default="data/arc-agi/evaluation", help="Directory with ARC task JSON files.")
    # run_arc_benchmark treats any value <= 0 as "all tasks".
    parser.add_argument("--max-tasks", type=int, default=50, help="Max tasks to evaluate (-1 = all).")
    parser.add_argument("--mode", choices=["direct", "memory"], default="direct", help="direct: LLM only. memory: Engram memory-augmented.")
    parser.add_argument("--output-json", default=None, help="Path to save detailed results JSON.")
    parser.add_argument("--print-every", type=int, default=5, help="Progress print interval.")
    parser.add_argument("--timeout", type=int, default=120, help="LLM call timeout in seconds.")

    # LLM selection and sampling parameters.
    parser.add_argument("--llm-provider", default="nvidia", choices=["nvidia", "openai", "gemini", "ollama"])
    parser.add_argument("--llm-model", default="deepseek-ai/deepseek-r1-distill-qwen-14b")
    parser.add_argument("--api-key", default=None, help="Override LLM API key.")
    parser.add_argument("--temperature", type=float, default=0.0, help="LLM temperature.")
    parser.add_argument("--top-p", type=float, default=0.7, help="LLM top-p.")
    parser.add_argument("--max-tokens", type=int, default=2048, help="LLM max output tokens.")
    parser.add_argument("--enable-thinking", action="store_true", help="Enable thinking/CoT mode for supported models.")
    # Embedder settings — only used in --mode memory.
    parser.add_argument("--embedder-provider", default="nvidia", choices=["nvidia", "openai", "gemini", "simple"])
    parser.add_argument("--embedder-model", default="nvidia/nv-embed-v1")
    parser.add_argument("--embedding-dims", type=int, default=4096)

    return parser.parse_args()
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def main() -> None:
    """CLI entry point: parse arguments and execute the benchmark."""
    run_arc_benchmark(parse_args())
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
if __name__ == "__main__":
|
|
383
|
+
main()
|